| """Main API for Code Comment Classification using FastAPI.""" |
| from contextlib import asynccontextmanager |
| from datetime import datetime |
| from functools import lru_cache, wraps |
| from http import HTTPStatus |
| import inspect |
| import logging |
| import os |
| from pathlib import Path |
|
|
| from api.monitoring import instrumentator, prediction_metric |
| from api.schemas import PredictRequest |
| from api.sync_models import sync_best_models_to_disk |
| from fastapi import FastAPI, Request, Response |
| from fastapi.middleware.cors import CORSMiddleware |
| from fastapi.responses import JSONResponse |
|
|
| from codecommentclassification import ModelPredictor |
|
|
# Root directory for on-disk model artifacts; overridable via the MODELS_DIR env var.
MODELS_DIR = Path(os.getenv("MODELS_DIR", "models/api"))




# Module-wide logging: timestamped, INFO-level messages tagged with the logger name.
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
|
|
|
|
@lru_cache(maxsize=3)
def get_predictor(lang: str, model_type: str) -> ModelPredictor:
    """Lazily load and cache a predictor for a (lang, model_type) pair.

    At most 3 predictors are kept in memory; requesting a 4th distinct
    pair evicts the least-recently-used one.

    Args:
        lang: Language identifier matching a directory under MODELS_DIR.
        model_type: Model variant name within that language's directory.

    Returns:
        A ModelPredictor loaded from MODELS_DIR.
    """
    # Lazy %-style args defer string formatting until the record is emitted.
    logger.info("Loading model for %s - %s...", lang, model_type)
    return ModelPredictor(lang=lang, model_type=model_type, model_root=str(MODELS_DIR))
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Sync champion models at startup; clear the predictor cache at shutdown.

    A failed sync is logged but non-fatal: the API can still serve requests
    with whatever models are already present on disk.
    """
    try:
        logger.info("Syncing champion models from MLflow to %s...", MODELS_DIR)
        sync_best_models_to_disk(
            models_root=MODELS_DIR.parent,
            api_subdir=MODELS_DIR.name,
        )
    except Exception:
        # logger.exception records the full traceback, not just the message.
        logger.exception("Failed to sync models at startup")

    if not MODELS_DIR.exists():
        logger.warning("Models directory not found at: %s", MODELS_DIR.resolve())
    else:
        logger.info("Using models from: %s", MODELS_DIR.resolve())
    yield
    # Shutdown: drop cached predictors so their resources can be released.
    get_predictor.cache_clear()
|
|
|
|
# FastAPI application; `lifespan` handles model sync at startup and cache teardown.
app = FastAPI(
    title="Code Comment Classification API",
    description="API for classifying code comments using SetFit models.",
    version="0.1",
    lifespan=lifespan,
)


# Attach Prometheus request instrumentation and expose the /metrics endpoint.
instrumentator.instrument(app)
instrumentator.expose(app)


# Comma-separated list of allowed CORS origins, e.g. "https://a.com,https://b.com".
frontend_origins = os.getenv("FRONTEND_ORIGINS")


if frontend_origins:
    origins = [o.strip() for o in frontend_origins.split(",") if o.strip()]
else:
    # Local-development defaults used when FRONTEND_ORIGINS is unset.
    origins = [
        "http://localhost:5173",
        "http://127.0.0.1:5173",
        "http://localhost",
    ]


app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
def _build_response(results: dict, request: Request):
    """Wrap an endpoint's result dict in the API's standard response envelope.

    If the endpoint already returned a Response/JSONResponse, it is passed
    through untouched so endpoints can fully control special replies.

    Args:
        results: Dict with "message" and "status-code" keys and an optional
            "data" payload, or a ready-made Response object.
        request: Incoming request, echoed back as method and URL.

    Returns:
        The standardized response dict (or the original Response object).
    """
    if isinstance(results, (Response, JSONResponse)):
        return results

    response = {
        "message": results["message"],
        "method": request.method,
        "status-code": results["status-code"],
        # NOTE(review): naive local time; confirm clients don't expect UTC.
        "timestamp": datetime.now().isoformat(),
        # str(request.url) is Starlette's public API; avoid the private `_url`.
        "url": str(request.url),
    }

    if "data" in results:
        response["data"] = results["data"]

    return response
|
|
|
|
def construct_response(f):
    """Decorate an endpoint so its results are wrapped in the standard envelope.

    Handles both plain and coroutine endpoint functions transparently.
    """
    if not inspect.iscoroutinefunction(f):

        @wraps(f)
        def wrap(request: Request, *args, **kwargs):
            return _build_response(f(request, *args, **kwargs), request)

        return wrap

    @wraps(f)
    async def async_wrap(request: Request, *args, **kwargs):
        return _build_response(await f(request, *args, **kwargs), request)

    return async_wrap
|
|
|
|
@app.get("/", tags=["General"])
@construct_response
def _index(request: Request):
    """Root endpoint: welcome message and pointer to the API docs."""
    welcome = (
        "Welcome to the Code Comment Classification API! "
        "Please use /docs for API documentation."
    )
    return {
        "message": HTTPStatus.OK.phrase,
        "status-code": HTTPStatus.OK,
        "data": {"message": welcome},
    }
|
|
|
|
@app.get("/privacy", tags=["General"])
@construct_response
async def get_privacy_notice(request: Request):
    """Return the Privacy Notice for the API."""
    notice = {
        "policy": (
            "This API processes text data for classification purposes only. "
            "No data is permanently stored."
        ),
        "compliance_link": "https://behavizapi.peopleware.ai/api/docs#section/Getting-Started/Privacy-Notice",
    }
    return {"message": "Privacy Notice", "status-code": HTTPStatus.OK, "data": notice}
|
|
|
|
@app.get("/status")
def get_status():
    """Liveness check: confirm the API process is up and responding."""
    return dict(status="API is running")
|
|
|
|
@app.get("/models", tags=["Prediction"])
@construct_response
def _get_models_list(request: Request):
    """List the (language, model_type) pairs found under the models directory."""
    available_languages = []
    if MODELS_DIR.exists():
        # One entry per <language>/<model_type> subdirectory pair.
        for lang_dir in MODELS_DIR.iterdir():
            if not lang_dir.is_dir():
                continue
            for model_dir in lang_dir.iterdir():
                if model_dir.is_dir():
                    available_languages.append(
                        {"language": lang_dir.name, "model_types": model_dir.name}
                    )

    return {
        "message": HTTPStatus.OK.phrase,
        "status-code": HTTPStatus.OK,
        "data": available_languages,
    }
|
|
|
|
@app.post("/predict", tags=["Prediction"])
@construct_response
def predict(
    request: Request,
    payload: PredictRequest,
):
    """Classify `payload.text` using the model for (language, model_type).

    Returns:
        400 BAD_REQUEST when model_type is missing or the input is invalid,
        404 NOT_FOUND when no model exists for the requested language,
        500 INTERNAL_SERVER_ERROR for any unexpected failure,
        200 OK with the predictions otherwise.
    """
    if payload.model_type is None:
        return {
            "message": "Model type must be specified.",
            "status-code": HTTPStatus.BAD_REQUEST,
        }

    try:
        predictor = get_predictor(payload.language.value, payload.model_type.value)
        result = predictor.predict(payload.text)
        # Predictors may return numpy arrays; convert to plain lists for JSON.
        predictions_list = result.tolist() if hasattr(result, "tolist") else result

        # Count successful predictions per (language, model_type) for Prometheus.
        prediction_metric.labels(
            language=payload.language.value,
            model_type=payload.model_type.value,
        ).inc()

        return {
            "message": HTTPStatus.OK.phrase,
            "status-code": HTTPStatus.OK,
            "data": {
                "language": payload.language,
                "model_type": payload.model_type,
                "predictions": predictions_list,
            },
        }

    except FileNotFoundError:
        return {
            "message": f"Model for language '{payload.language}' not found.",
            "status-code": HTTPStatus.NOT_FOUND,
        }
    except ValueError as e:
        return {
            "message": str(e),
            "status-code": HTTPStatus.BAD_REQUEST,
        }
    except Exception as e:
        # Log the full traceback server-side; clients only see the message.
        logger.exception("Unexpected error during prediction")
        return {
            "message": f"Internal Error: {str(e)}",
            "status-code": HTTPStatus.INTERNAL_SERVER_ERROR,
        }
|
|