| """Utility functions for model training and evaluation.""" |
|
|
| import os |
| from typing import List |
|
|
# Default language prefixes for the dataset CSV files (e.g. "java_train.csv").
LANGS: List[str] = [
    "java",
    "python",
    "pharo",
]
|
|
|
|
def load_dataset_splits(base_dir=None, langs=None, splits=("train", "test")):
    """Load dataset splits from CSV files under ``data/raw``.

    Expects one file per (language, split) pair named ``{lang}_{split}.csv``,
    e.g. ``data/raw/java_train.csv`` and ``data/raw/java_test.csv``.

    Args:
        base_dir: Directory containing the CSV files. Defaults to the
            relative path ``data/raw``.
        langs: Iterable of language prefixes, or a single prefix string.
            Defaults to the module-level ``LANGS``.
        splits: Split names to load for every language (iterable, or a single
            string). Defaults to ``("train", "test")``.

    Returns:
        A dict mapping split names (e.g. ``"java_test"``) to pandas DataFrames,
        inserted in ``langs`` x ``splits`` order.

    Raises:
        FileNotFoundError: If the base directory or any expected file does not
            exist. All files are checked before any CSV is read.
        ImportError: If pandas is not installed.
    """
    if base_dir is None:
        base_dir = os.path.join("data", "raw")

    if langs is None:
        langs = LANGS
    elif isinstance(langs, str):
        # A bare string would otherwise be iterated character by character.
        langs = [langs]

    if isinstance(splits, str):
        splits = [splits]

    if not os.path.isdir(base_dir):
        raise FileNotFoundError(
            f"CSV datasets not found under {base_dir}; cannot load dataset splits."
        )

    # Imported lazily so the rest of the module is usable without pandas.
    # Only a genuinely missing package is reported as "pandas is required";
    # other import-time errors propagate unchanged.
    try:
        import pandas as pd
    except ImportError as e:
        raise ImportError("pandas is required to load dataset splits") from e

    # Resolve and validate every path up front so a missing file is reported
    # before any (potentially large) CSV has been parsed.
    paths = {
        f"{lang}_{split}": os.path.join(base_dir, f"{lang}_{split}.csv")
        for lang in langs
        for split in splits
    }
    for path in paths.values():
        if not os.path.isfile(path):
            raise FileNotFoundError(f"Expected dataset file missing: {path}")

    return {name: pd.read_csv(path) for name, path in paths.items()}
|
|
|
|
def parse_labels_column(df):
    """Parse the ``labels`` column of *df* into lists of integers.

    Each cell may be:

    * a string such as ``"[1 0 1]"`` (NumPy-style) or ``"[1, 0, 1]"``
      (Python-list-style); the surrounding brackets are optional and an
      empty string or ``"[]"`` yields ``[]``;
    * a ``numpy.ndarray``, ``list`` or ``tuple`` of integer-like values.

    The column is replaced on *df* in place, and the same DataFrame object is
    returned for convenience.

    Raises:
        ValueError: If a cell has an unsupported type, or a string token is
            not an integer literal.
    """
    # Resolve numpy once per call instead of once per cell; it is optional,
    # so ndarray cells are only recognized when it is importable.
    try:
        import numpy as np
    except ImportError:
        np = None

    def _parse_one(x):
        if isinstance(x, str):
            s = x.strip()
            if s.startswith("[") and s.endswith("]"):
                s = s[1:-1]
            # Commas are treated as separators too, so both "[1 0]" and
            # "[1, 0]" (str(list) output) parse to [1, 0].
            return [int(tok) for tok in s.replace(",", " ").split()]
        if np is not None and isinstance(x, np.ndarray):
            return [int(v) for v in x.tolist()]
        if isinstance(x, (list, tuple)):
            return [int(v) for v in x]
        raise ValueError(f"Formato labels non gestito: {type(x)} -> {x!r}")

    df["labels"] = df["labels"].apply(_parse_one)
    return df
|
|