# AI_Project/plot_metrics.py
# Plot training metrics for the classification (ViT+LoRA) and detection (YOLOv8) runs.
# Uploaded by Untraceable09 ("Add files using upload-large-folder tool"), commit ad34663 (1.99 kB).
import csv
import json
import os
import sys

import matplotlib.pyplot as plt
import pandas as pd
# Classification metrics (ViT+LoRA): train/val loss and validation AUC-ROC per epoch.


def load_classification_history(path):
    """Read a classification training-history JSON file and split it into series.

    The file is expected to hold a JSON array of per-epoch objects, each with
    the keys ``'epoch'``, ``'train_loss'``, ``'val_loss'`` and ``'val_auc_roc'``.

    Args:
        path: Path to the JSON history file.

    Returns:
        A tuple ``(epochs, train_loss, val_loss, val_auc)`` of lists, in file order.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If the file is not valid JSON (``json.JSONDecodeError``).
        KeyError: If an entry lacks one of the required keys.
    """
    # JSON is UTF-8 by specification; don't depend on the platform's locale encoding.
    with open(path, 'r', encoding='utf-8') as f:
        history = json.load(f)
    epochs = [entry['epoch'] for entry in history]
    train_loss = [entry['train_loss'] for entry in history]
    val_loss = [entry['val_loss'] for entry in history]
    val_auc = [entry['val_auc_roc'] for entry in history]
    return epochs, train_loss, val_loss, val_auc


def plot_classification_metrics(
    history_path='outputs/classification/training_history.json',
    output_path='classification_metrics.png',
):
    """Plot loss curves and validation AUC-ROC side by side and save the figure.

    Failures are reported on stderr instead of being raised, so a problem with
    this plot does not stop the rest of the script.

    Args:
        history_path: JSON history file read by ``load_classification_history``.
        output_path: Where the PNG figure is written.

    Returns:
        True if the figure was saved, False if any step failed.
    """
    try:
        epochs, train_loss, val_loss, val_auc = load_classification_history(history_path)
        fig, (ax_loss, ax_auc) = plt.subplots(1, 2, figsize=(12, 5))
        try:
            ax_loss.plot(epochs, train_loss, label='Train Loss')
            ax_loss.plot(epochs, val_loss, label='Val Loss')
            ax_loss.set_title('Classification Loss (ViT+LoRA)')
            ax_loss.set_xlabel('Epoch')
            ax_loss.set_ylabel('Loss')
            ax_loss.legend()

            ax_auc.plot(epochs, val_auc, label='Val AUC-ROC', color='green')
            ax_auc.set_title('Classification AUC-ROC')
            ax_auc.set_xlabel('Epoch')
            ax_auc.set_ylabel('AUC-ROC')
            ax_auc.legend()

            fig.tight_layout()
            fig.savefig(output_path)
        finally:
            # Always release the figure, even if drawing or saving failed;
            # pyplot otherwise keeps it alive in its global figure registry.
            plt.close(fig)
    except Exception as e:  # top-level script boundary: report and carry on
        print("Error plotting classification:", repr(e), file=sys.stderr)
        return False
    return True


plot_classification_metrics()
# Detection metrics (YOLOv8): mAP@50 and mAP@50-95 per epoch from a results CSV.


def load_detection_results(path):
    """Parse a detection results CSV into per-epoch mAP series.

    Header names are whitespace-stripped before lookup because the header may
    pad them with spaces. The columns read are ``'epoch'``,
    ``'metrics/mAP50(B)'`` and ``'metrics/mAP50-95(B)'``.

    Args:
        path: Path to the results CSV.

    Returns:
        A tuple ``(epochs, map50, map50_95)`` of lists, in file order.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If the file has no header row, or a value cannot be parsed.
        KeyError: If a required column is missing.
    """
    epochs, map50, map50_95 = [], [], []
    # newline='' is what the csv module expects for files it reads.
    with open(path, 'r', newline='', encoding='utf-8') as f:
        reader = csv.DictReader(f)
        if reader.fieldnames is None:
            # An empty file has no header; fail with a clear message instead of
            # a TypeError from iterating None below.
            raise ValueError(f"{path!r} is empty (no CSV header row)")
        reader.fieldnames = [name.strip() for name in reader.fieldnames]
        for row in reader:
            epochs.append(int(row['epoch']))
            map50.append(float(row['metrics/mAP50(B)']))
            map50_95.append(float(row['metrics/mAP50-95(B)']))
    return epochs, map50, map50_95


def plot_detection_metrics(
    results_path='runs/detect/outputs/detection/yolov8_tn50002/results.csv',
    output_path='detection_metrics.png',
):
    """Plot mAP@50 and mAP@50-95 against epoch and save the figure.

    Failures are reported on stderr instead of being raised, so a problem with
    this plot does not stop the rest of the script.

    Args:
        results_path: Results CSV read by ``load_detection_results``.
        output_path: Where the PNG figure is written.

    Returns:
        True if the figure was saved, False if any step failed.
    """
    try:
        epochs, map50, map50_95 = load_detection_results(results_path)
        fig, ax = plt.subplots(figsize=(8, 5))
        try:
            ax.plot(epochs, map50, label='mAP@50')
            ax.plot(epochs, map50_95, label='mAP@50-95')
            ax.set_title('Detection mAP (YOLOv8)')
            ax.set_xlabel('Epoch')
            ax.set_ylabel('mAP')
            ax.legend()
            fig.tight_layout()
            fig.savefig(output_path)
        finally:
            # Always release the figure, even if drawing or saving failed.
            plt.close(fig)
    except Exception as e:  # top-level script boundary: report and carry on
        print("Error plotting detection:", repr(e), file=sys.stderr)
        return False
    return True


plot_detection_metrics()
print("Plotting complete.")