"""Evaluate a trained classification model (ViT + LoRA) on the TN5000 test set. Produces: - AUC-ROC, F1-Score, Sensitivity, Specificity, ECE - Confusion matrix - Per-class classification report - Inference latency measurements (Teacher vs Student comparison) Usage:: python scripts/evaluate_classification.py --checkpoint outputs/classification/best python scripts/evaluate_classification.py --checkpoint outputs/classification/best --split test """ from __future__ import annotations import argparse import json import time from pathlib import Path import numpy as np import torch import torch.nn.functional as F from peft import PeftModel from torch.utils.data import DataLoader from tqdm import tqdm from transformers import AutoImageProcessor, AutoModelForImageClassification from thyroid_vfm.config import ClassificationConfig, load_yaml_config from thyroid_vfm.data.transforms import build_ultrasound_transform from thyroid_vfm.data.voc import Tn5000ClassificationDataset from thyroid_vfm.evaluation.metrics import compute_classification_metrics def _build_test_loader(config: ClassificationConfig, processor, split: str = "test"): class_to_id = {name: idx for idx, name in enumerate(config.dataset.class_names)} dataset = Tn5000ClassificationDataset( root_dir=config.dataset.root_dir, split=split, class_to_id=class_to_id, transform=build_ultrasound_transform(config.dataset.image_size, train=False), images_dir=config.dataset.images_dir, annotations_dir=config.dataset.annotations_dir, splits_dir=config.dataset.splits_dir, ) def collate_fn(batch: list[dict]) -> dict[str, torch.Tensor]: images = [item["image"] for item in batch] labels = torch.tensor([item["label"] for item in batch], dtype=torch.long) encoded = processor(images=images, return_tensors="pt") encoded["labels"] = labels return encoded return DataLoader( dataset, batch_size=config.batch_size, shuffle=False, num_workers=config.num_workers, collate_fn=collate_fn, ) @torch.no_grad() def evaluate( checkpoint_dir: Path, config: ClassificationConfig, split: str = "test", ) -> None: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f" Device: {device}") # Load model and processor print(f" Loading model from {checkpoint_dir} ...") base_model = AutoModelForImageClassification.from_pretrained( config.model_name, num_labels=config.num_labels, ignore_mismatched_sizes=True, ) model = PeftModel.from_pretrained(base_model, str(checkpoint_dir)) # Load full model state (includes re-initialized classifier head) full_state = checkpoint_dir / "full_model.pt" if full_state.exists(): state = torch.load(str(full_state), map_location="cpu", weights_only=True) model.load_state_dict(state, strict=False) print(" Loaded full model state (with classifier head).") model.to(device).eval() processor = AutoImageProcessor.from_pretrained(str(checkpoint_dir)) loader = _build_test_loader(config, processor, split=split) print(f" Evaluating on {len(loader.dataset)} samples (split={split!r})...\n") all_labels = [] all_preds = [] all_probs = [] latencies = [] for batch in tqdm(loader, desc="Inference"): batch = {k: v.to(device) for k, v in batch.items()} t0 = time.perf_counter() outputs = model(**batch) if device.type == "cuda": torch.cuda.synchronize() latencies.append(time.perf_counter() - t0) logits = outputs.logits probs = F.softmax(logits, dim=-1) preds = logits.argmax(dim=-1) all_labels.append(batch["labels"].cpu().numpy()) all_preds.append(preds.cpu().numpy()) all_probs.append(probs.cpu().numpy()) labels = np.concatenate(all_labels) preds = np.concatenate(all_preds) probs = np.concatenate(all_probs) metrics = compute_classification_metrics(labels, preds, probs, config.dataset.class_names) print(metrics.summary()) # Latency stats total_samples = len(labels) total_time = sum(latencies) per_sample_ms = (total_time / total_samples) * 1000 print(f"\n Inference Latency:") print(f" Total time : {total_time:.3f} s") print(f" Per sample : {per_sample_ms:.2f} ms") print(f" Throughput : {total_samples / total_time:.1f} samples/s") # Save results results = { "split": split, "num_samples": total_samples, "accuracy": metrics.accuracy, "auc_roc": metrics.auc_roc, "f1": metrics.f1, "sensitivity": metrics.sensitivity, "specificity": metrics.specificity, "ece": metrics.ece, "latency_per_sample_ms": round(per_sample_ms, 3), "throughput_samples_per_s": round(total_samples / total_time, 1), } results_path = checkpoint_dir.parent / f"eval_results_{split}.json" results_path.write_text(json.dumps(results, indent=2), encoding="utf-8") print(f"\n Results saved to {results_path}") def main() -> None: parser = argparse.ArgumentParser(description="Evaluate a trained ViT+LoRA model.") parser.add_argument( "--checkpoint", type=Path, default=Path("outputs/classification/best"), help="Path to the saved PEFT checkpoint.", ) parser.add_argument( "--config", type=Path, default=Path("configs/classification_vit_lora.yaml"), help="Path to the classification config.", ) parser.add_argument( "--split", type=str, default="test", choices=["train", "val", "test"], help="Which split to evaluate on.", ) args = parser.parse_args() config = load_yaml_config(args.config, ClassificationConfig) evaluate(args.checkpoint, config, args.split) if __name__ == "__main__": main()