| """Prediction helpers for different model types. |
| |
| This module provides `ModelPredictor`, a lightweight wrapper that unifies |
| inference for SetFit, scikit-learn RandomForest pipelines, and HuggingFace |
| transformer sequence classification models. It standardizes inputs/outputs |
| to a NumPy array of shape (n_samples, n_labels). |
| """ |
|
|
| import os |
| from typing import List, Union |
|
|
| import joblib |
| import numpy as np |
| from setfit import SetFitModel |
| import torch |
| from transformers import AutoModelForSequenceClassification, AutoTokenizer |
|
|
# Accepted prediction input: a single text or a batch of texts.
TextInput = Union[str, List[str]]
|
|
|
|
class ModelPredictor:
    """Unified predictor for SetFit, Random Forest and Transformer models.

    Expected directory layout::

        models/
        ├── java/
        │   ├── setfit/               # SetFit saved model directory
        │   ├── random_forest.joblib  # sklearn pipeline
        │   └── transformer/          # HF model + tokenizer (config.json, etc.)
        ├── python/
        │   ├── setfit/
        │   ├── random_forest.joblib
        │   └── transformer/
        └── pharo/
            ├── setfit/
            ├── random_forest.joblib
            └── transformer/
    """

    def __init__(
        self,
        lang: str,
        model_type: str,
        model_root: str = "models",
        threshold: float = 0.5,
        max_length: int = 128,
    ) -> None:
        """Load the requested model artifact from disk.

        Parameters
        ----------
        lang : str
            One of {"java", "python", "pharo"}.
        model_type : str
            One of {"setfit", "random_forest", "transformer"}.
        model_root : str
            Root directory where models are stored.
        threshold : float
            Decision threshold for multi-label Transformer predictions.
            Ignored for SetFit and Random Forest (they already output labels).
        max_length : int
            Max sequence length for Transformer tokenization.

        Raises
        ------
        FileNotFoundError
            If the expected model file/directory does not exist.
        ValueError
            If ``model_type`` is not one of the supported values.
        """
        self.lang = lang
        self.model_type = model_type
        self.model_root = model_root
        self.threshold = float(threshold)
        self.max_length = int(max_length)

        # Only the Transformer path uses this device explicitly; SetFit and
        # sklearn manage their own placement.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        if model_type == "setfit":
            model_path = os.path.join(self.model_root, self.lang, "setfit")
            if not os.path.isdir(model_path):
                raise FileNotFoundError(f"SetFit model not found at: {model_path}")
            self.model = SetFitModel.from_pretrained(model_path)

        elif model_type == "random_forest":
            model_path = os.path.join(self.model_root, self.lang, "random_forest.joblib")
            if not os.path.isfile(model_path):
                raise FileNotFoundError(f"Random Forest model not found at: {model_path}")
            self.model = joblib.load(model_path)

        elif model_type == "transformer":
            model_path = os.path.join(self.model_root, self.lang, "transformer")
            if not os.path.isdir(model_path):
                raise FileNotFoundError(f"Transformer model not found at: {model_path}")

            self.tokenizer = AutoTokenizer.from_pretrained(model_path)
            self.model = AutoModelForSequenceClassification.from_pretrained(model_path).to(
                self.device
            )
            # Inference-only: disable dropout etc.
            self.model.eval()

        else:
            raise ValueError(f"Unsupported model_type: {model_type}")

    def predict(self, texts: Union[str, List[str]]) -> np.ndarray:
        """Run prediction on one or many text samples.

        Parameters
        ----------
        texts : str | list[str]
            A single text or a list of texts.

        Returns
        -------
        np.ndarray
            Array of shape (n_samples, n_labels) with integer (typically
            binary) values.

        Raises
        ------
        ValueError
            If ``self.model_type`` is not one of the supported values.
        """
        if isinstance(texts, str):
            texts = [texts]

        if self.model_type == "setfit":
            # SetFitModel is callable; returns one label (vector) per sample.
            raw_outputs = self.model(texts)
            outputs = np.array(list(raw_outputs), dtype=int)

        elif self.model_type == "random_forest":
            raw_outputs = self.model.predict(texts)
            outputs = np.array(list(raw_outputs), dtype=int)

        elif self.model_type == "transformer":
            enc = self.tokenizer(
                texts,
                padding=True,
                truncation=True,
                max_length=self.max_length,
                return_tensors="pt",
            )
            enc = {k: v.to(self.device) for k, v in enc.items()}

            with torch.no_grad():
                logits = self.model(**enc).logits
                # Multi-label head: independent sigmoid per label, thresholded.
                probs = torch.sigmoid(logits)
                preds = (probs > self.threshold).long().cpu().numpy()

            outputs = preds.astype(int)
        else:
            raise ValueError(f"Unsupported model_type: {self.model_type}")

        # Normalize to (n_samples, n_labels). A 1-D result means either one
        # sample with several labels or a batch of single-label outputs;
        # reshaping by the number of input samples handles both, whereas
        # reshape(1, -1) would transpose a batch of single-label predictions
        # into shape (1, n_samples).
        if outputs.ndim == 1:
            outputs = outputs.reshape(len(texts), -1)

        return outputs
|
|