| """
|
| OM_reg_pair.py — Paired registration using OMorpher with external dataset.
|
|
|
| Loads fixed/moving pairs from a Learn2Reg-style JSON dataset file
|
| (e.g. HippocampusMR_dataset.json) and registers each moving image to its
|
| paired fixed image. Saves registered images, masks, DDFs, source originals,
|
| and evaluation metrics (DSC, ASD, HD) per organ label.
|
|
|
| Usage:
|
| python Scripts/OM_reg_pair.py -C Config/config_om.yaml \
|
| --dataset-json /path/to/HippocampusMR_dataset.json \
|
| --split val
|
|
|
| python Scripts/OM_reg_pair.py -C Config/config_om.yaml \
|
| --dataset-json /path/to/HippocampusMR_dataset.json \
|
| --split test -N 10
|
| """
|
|
|
| import os
|
| import sys
|
|
|
|
|
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
|
| import csv
|
| import json
|
| import numpy as np
|
| import torch
|
| import torch.nn.functional as F
|
| import nibabel as nib
|
| import yaml
|
| import SimpleITK as sitk
|
| from scipy.ndimage import distance_transform_edt, binary_erosion
|
| from tqdm import tqdm
|
|
|
| import utils
|
| from Dataloader.dataLoader import reverse_axis_order
|
| from OMorpher import OMorpher
|
|
|
|
|
|
|
| import argparse
|
|
|
# Command-line options as (flags, add_argument keyword arguments) pairs,
# registered in this order on a single parser.
_CLI_OPTIONS = (
    (
        ("--config", "-C"),
        {
            "help": "Path for the config file",
            "type": str,
            "default": "Config/config_om.yaml",
            "required": False,
        },
    ),
    (
        ("--dataset-json",),
        {
            "help": "Path to the Learn2Reg-style dataset JSON",
            "type": str,
            # NOTE(review): user/cluster-specific default; pass --dataset-json
            # when running anywhere else.
            "default": "~/rds/rds-airr-p51-TWhPgQVLKbA/Code/Registration/Dataset/HippocampusMR/HippocampusMR_dataset.json",
        },
    ),
    (
        ("--split",),
        {
            "help": "Which registration split to use: 'val' or 'test'",
            "type": str,
            "choices": ["val", "test"],
            "default": "val",
        },
    ),
    (
        ("--max-samples", "-N"),
        {
            "help": "Max number of pairs to register (0 = all)",
            "type": int,
            "default": 0,
        },
    ),
)

parser = argparse.ArgumentParser()
for _flags, _kwargs in _CLI_OPTIONS:
    parser.add_argument(*_flags, **_kwargs)
args = parser.parse_args()
|
|
|
|
|
|
|
# Load the run configuration; yaml.safe_load only builds plain Python objects.
with open(args.config, "r") as cfg_file:
    hyp_parameters = yaml.safe_load(cfg_file)
print(hyp_parameters)

# Pairs are registered one at a time.
hyp_parameters["batchsize"] = 1
model_img_sz, timesteps, condition_type, ndims = (
    hyp_parameters[key]
    for key in ("img_size", "timesteps", "condition_type", "ndims")
)
|
|
|
|
|
|
|
# Parse the dataset description. Relative paths inside it are resolved
# against the directory that contains the JSON file (see resolve_path).
dataset_json_path = os.path.expanduser(args.dataset_json)
dataset_root = os.path.dirname(dataset_json_path)

with open(dataset_json_path, "r") as f:
    dataset_meta = json.load(f)

dataset_name = dataset_meta.get("name", "UnknownDataset")
print(f"Dataset: {dataset_name}")

# CLI split name -> JSON key holding that split's list of
# {"fixed": ..., "moving": ...} pairs; a missing key yields no pairs.
_SPLIT_KEYS = {"val": "registration_val", "test": "registration_test"}
if args.split not in _SPLIT_KEYS:
    raise ValueError(f"Unknown split: {args.split}")
pairs = dataset_meta.get(_SPLIT_KEYS[args.split], [])

# Keep only the first N pairs when N > 0; otherwise use every pair.
pairs = pairs[: args.max_samples] if args.max_samples > 0 else pairs

print(f"Split: {args.split}, Pairs: {len(pairs)}")
|
|
|
|
|
|
|
# Image basename -> label path (as written in the JSON) for every "training"
# entry; later entries with the same basename overwrite earlier ones.
_label_lookup = {
    os.path.basename(entry["image"]): entry.get("label")
    for entry in dataset_meta.get("training", [])
}

# Label-id -> name table stored under labels["0"] in the dataset JSON.
_label_names = dataset_meta.get("labels", {}).get("0", {})

# Evaluate only strictly positive ids (id 0 is presumably background).
organ_label_ids = {}
for _key, _name in _label_names.items():
    _label_id = int(_key)
    if _label_id > 0:
        organ_label_ids[_label_id] = _name
print(f"Organ labels for evaluation: {organ_label_ids}")
|
|
|
|
|
|
|
# Checkpoint location: Models/<data>_<net>/<model_id>_<data>_<net>.pth
_model_id = hyp_parameters["model_id_str"]
_data_name = hyp_parameters["data_name"]
_net_name = hyp_parameters["net_name"]
epoch = f"{_model_id}_{_data_name}_{_net_name}"
model_save_path = os.path.join(
    f"Models/{_data_name}_{_net_name}/",
    str(epoch) + ".pth",
)
print("Loading model from:", model_save_path)

# Build the registration model; the device falls back to CPU when the
# config does not name one.
_device = str(hyp_parameters.get("device", "cpu"))
om = OMorpher(config=hyp_parameters, checkpoint_path=model_save_path, device=_device)
print(om)
|
|
|
|
|
|
|
# Model-resolution output directories come straight from the config.
reg_img_savepath, reg_msk_savepath, reg_ddf_savepath = (
    hyp_parameters[key]
    for key in ("reg_img_savepath", "reg_msk_savepath", "reg_ddf_savepath")
)

# Full-resolution outputs go to sibling directories with a "_fullres" suffix.
reg_img_savepath_fullres, reg_msk_savepath_fullres, reg_ddf_savepath_fullres = (
    path.rstrip("/") + "_fullres/"
    for path in (reg_img_savepath, reg_msk_savepath, reg_ddf_savepath)
)

# Evaluation CSVs live next to the image output directory.
eval_dir = os.path.join(reg_img_savepath, "..", "eval")

# reg_img_savepath is created before eval_dir, so the ".." component exists
# by the time os.makedirs walks it.
for _out_dir in (
    reg_img_savepath,
    reg_msk_savepath,
    reg_ddf_savepath,
    reg_img_savepath_fullres,
    reg_msk_savepath_fullres,
    reg_ddf_savepath_fullres,
    eval_dir,
):
    os.makedirs(_out_dir, exist_ok=True)
|
|
|
|
|
|
|
|
|
|
|
def resolve_path(rel_path, root=None):
    """Resolve a path taken from the dataset JSON.

    A leading "~" is expanded to the user's home directory first. Absolute
    paths are then returned as-is; relative ones are joined onto ``root`` and
    normalized (so "./" and ".." components are collapsed).

    Args:
        rel_path: Path string as written in the dataset JSON.
        root: Directory that relative paths are resolved against. Defaults to
            the module-level ``dataset_root`` (the directory holding the
            dataset JSON).

    Returns:
        The resolved path. It is absolute only if ``rel_path`` or ``root`` is.
    """
    path = os.path.expanduser(rel_path)
    if os.path.isabs(path):
        return path
    base = dataset_root if root is None else root
    return os.path.normpath(os.path.join(base, path))
|
|
|
|
|
def load_volume(nifti_path):
    """Read a NIfTI image into a numpy array, only reordering its axes.

    Intensity normalisation, padding to a cube and resizing to the model
    resolution are left to OMorpher._standardize_img. When the reordered
    array is 4-D, only index 0 of the last axis is kept.
    """
    # NOTE(review): reverse_axis_order presumably converts SimpleITK's
    # (z, y, x) array order to (x, y, z) — confirm in Dataloader.
    array = reverse_axis_order(sitk.GetArrayFromImage(sitk.ReadImage(nifti_path)))
    return array[:, :, :, 0] if array.ndim == 4 else array
|
|
|
|
|
def load_label(nifti_path):
    """Read a NIfTI label map into a numpy array, only reordering its axes.

    Padding to a cube and nearest-neighbour resizing to the model resolution
    are left to OMorpher._standardize_label. Arrays with more than three
    dimensions keep only index 0 of axis 3.
    """
    image = sitk.ReadImage(nifti_path)
    array = reverse_axis_order(sitk.GetArrayFromImage(image))
    if array.ndim <= 3:
        return array
    return array[:, :, :, 0]
|
|
|
|
|
def get_label_path_for_image(image_rel_path):
    """Return the resolved label path paired with an image, or None.

    The image is matched by file basename against the "training" entries
    collected in ``_label_lookup``. Images with no entry, or whose entry has
    no "label" field, yield None.
    """
    label_rel = _label_lookup.get(os.path.basename(image_rel_path))
    return None if label_rel is None else resolve_path(label_rel)
|
|
|
|
|
def split_label_classes(label_map, class_ids):
    """Turn a multi-class label map into one float32 0/1 mask per class.

    Args:
        label_map: integer label array.
        class_ids: iterable of class ids to extract.

    Returns:
        dict mapping each class id to a float32 array, 1.0 where
        ``label_map == class_id`` and 0.0 elsewhere.
    """
    return {
        class_id: (label_map == class_id).astype(np.float32)
        for class_id in class_ids
    }
|
|
|
|
|
def get_volume_name(path):
    """Return the file's basename with a trailing ".nii.gz" or ".nii" removed."""
    base = os.path.basename(path)
    # ".nii.gz" is checked before ".nii" so the longer suffix wins.
    suffix = next((ext for ext in (".nii.gz", ".nii") if base.endswith(ext)), "")
    return base[: len(base) - len(suffix)]
|
|
|
|
|
|
|
|
|
|
|
| def _surface_distances(pred, gt):
|
| """Compute directed surface distances between two binary masks."""
|
| pred_bool = pred > 0.5
|
| gt_bool = gt > 0.5
|
|
|
| if not np.any(pred_bool) or not np.any(gt_bool):
|
| return None, None
|
|
|
| struct = None
|
| pred_surface = pred_bool ^ binary_erosion(pred_bool, structure=struct)
|
| gt_surface = gt_bool ^ binary_erosion(gt_bool, structure=struct)
|
|
|
| if not np.any(pred_surface):
|
| pred_surface = pred_bool
|
| if not np.any(gt_surface):
|
| gt_surface = gt_bool
|
|
|
| dt_gt = distance_transform_edt(~gt_surface)
|
| dt_pred = distance_transform_edt(~pred_surface)
|
|
|
| return dt_gt[pred_surface], dt_pred[gt_surface]
|
|
|
|
|
def compute_dsc(pred, gt):
    """Dice similarity coefficient of two masks thresholded at 0.5.

    Returns 1.0 when both masks are empty (they agree that nothing is there).
    """
    p = pred > 0.5
    g = gt > 0.5
    total = np.sum(p) + np.sum(g)
    if total == 0:
        return 1.0
    overlap = np.sum(p & g)
    return 2.0 * float(overlap) / float(total)
|
|
|
|
|
def compute_asd(pred, gt):
    """Average symmetric surface distance between two masks.

    The mean of the two directed mean surface distances, in voxel units (no
    spacing is passed). NaN when either mask is empty.
    """
    directed = _surface_distances(pred, gt)
    if directed[0] is None:
        return float("nan")
    forward_mean, backward_mean = (np.mean(d) for d in directed)
    return (forward_mean + backward_mean) / 2.0
|
|
|
|
|
def compute_hd(pred, gt):
    """Symmetric Hausdorff distance between two masks.

    The largest value over both directed surface-distance sets, in voxel
    units (no spacing is passed). NaN when either mask is empty.
    """
    directed = _surface_distances(pred, gt)
    if directed[0] is None:
        return float("nan")
    return float(max(np.max(d) for d in directed))
|
|
|
|
|
def compute_negdetj_pct(ddf, ndims=3):
    """Return the percentage of voxels whose Jacobian determinant is negative.

    The Jacobian of x -> x + u(x) is built from forward differences of the
    displacement field: J[i][j] = delta_ij + d(channel i)/d(axis j). The last
    slice along each axis is repeated before differencing, so the derivative
    there is 0 and the result keeps the input's spatial shape.

    NOTE(review): this treats channel i as the displacement along array axis
    i. If the DDF stores components in the reverse order of the array axes,
    the determinant computed here differs from the true one. Confirm against
    OMorpher.get_def.

    Args:
        ddf: displacement field (torch.Tensor or numpy array) shaped
            [ndims, *spatial] or [B, ndims, *spatial]; with a batch axis only
            item 0 is used.
        ndims: number of spatial dimensions, 2 or 3.

    Returns:
        Percentage (0-100) of voxels with det(J) < 0.

    Raises:
        ValueError: if ndims is neither 2 nor 3.
    """
    field = ddf.detach().cpu().numpy() if isinstance(ddf, torch.Tensor) else ddf
    if field.ndim == ndims + 2:
        field = field[0]

    def _fwd_diff(component, axis):
        # Append the last slice along `axis` so np.diff preserves the shape.
        last = [slice(None)] * component.ndim
        last[axis] = slice(-1, None)
        return np.diff(component, axis=axis, append=component[tuple(last)])

    if ndims == 3:
        grad = [[_fwd_diff(field[i], j) for j in range(3)] for i in range(3)]
        a11, a12, a13 = 1.0 + grad[0][0], grad[0][1], grad[0][2]
        a21, a22, a23 = grad[1][0], 1.0 + grad[1][1], grad[1][2]
        a31, a32, a33 = grad[2][0], grad[2][1], 1.0 + grad[2][2]
        # Cofactor expansion along the first row.
        det = (
            a11 * (a22 * a33 - a23 * a32)
            - a12 * (a21 * a33 - a23 * a31)
            + a13 * (a21 * a32 - a22 * a31)
        )
    elif ndims == 2:
        grad = [[_fwd_diff(field[i], j) for j in range(2)] for i in range(2)]
        det = (1.0 + grad[0][0]) * (1.0 + grad[1][1]) - grad[0][1] * grad[1][0]
    else:
        raise ValueError(f"Unsupported ndims={ndims}")

    negative = np.sum(det < 0)
    return 100.0 * float(negative) / float(det.size)
|
|
|
|
|
|
|
|
|
|
|
# Per-organ result tables: metrics[cid][metric_name][pair_idx] -> value.
# `metrics` holds post-registration values, `metrics_pre` the values
# measured before registration (source mask vs. target mask).
_METRIC_NAMES = ("dsc", "asd", "hd")
metrics = {cid: {name: {} for name in _METRIC_NAMES} for cid in organ_label_ids}
metrics_pre = {cid: {name: {} for name in _METRIC_NAMES} for cid in organ_label_ids}

# pair_idx -> percentage of voxels with a negative Jacobian determinant.
negdetj_pct = {}

# One (pair_idx, fixed_name, moving_name) tuple per processed pair.
pair_info = []
|
|
|
|
|
|
|
def _save_nifti(volume, directory, filename):
    """Convert ``volume`` with utils.converet_to_nibabel and write it to disk.

    Args:
        volume: image, mask or DDF tensor, converted with the configured
            ``ndims``.
        directory: output directory.
        filename: file name inside ``directory``.
    """
    nib.save(
        utils.converet_to_nibabel(volume, ndims=ndims),
        os.path.join(directory, filename),
    )


def _standardize_class_masks(label_map, reference_vol):
    """Split ``label_map`` into per-organ masks and standardize each one.

    Args:
        label_map: integer label array for ``reference_vol``.
        reference_vol: the image volume the labels belong to.

    Returns:
        ``(mask_model, mask_fullres)``: the outputs of
        ``om._standardize_label`` for every organ, concatenated along dim 1
        in ascending label-id order.
    """
    class_masks = split_label_classes(label_map, organ_label_ids.keys())
    # NOTE(review): set_init_img runs before _standardize_label, presumably
    # so the label is padded/resized with reference_vol's state. Confirm
    # against OMorpher.
    om.set_init_img(reference_vol)
    masks_model, masks_fullres = [], []
    for cid in sorted(organ_label_ids.keys()):
        m_model, m_fullres = om._standardize_label(class_masks[cid])
        masks_model.append(m_model)
        masks_fullres.append(m_fullres)
    return torch.cat(masks_model, dim=1), torch.cat(masks_fullres, dim=1)


# F.interpolate needs "trilinear" for 3-D spatial (5-D) tensors and
# "bilinear" for 2-D spatial (4-D) tensors; a hard-coded "trilinear" fails
# when ndims == 2.
_DDF_RESIZE_MODE = "trilinear" if ndims == 3 else "bilinear"


with torch.no_grad():
    for pair_idx, pair in enumerate(tqdm(pairs, desc="Pairs")):
        # --- Resolve paths and names --------------------------------------
        fixed_rel = pair["fixed"]
        moving_rel = pair["moving"]

        fixed_path = resolve_path(fixed_rel)
        moving_path = resolve_path(moving_rel)

        fixed_name = get_volume_name(fixed_rel)
        moving_name = get_volume_name(moving_rel)
        pair_tag = f"Tgt{pair_idx:04d}_Src{pair_idx:04d}"

        pair_info.append((pair_idx, fixed_name, moving_name))
        print(f"\n [{pair_idx}] Fixed: {fixed_name}, Moving: {moving_name}")

        # --- Load images and (optional) label maps -------------------------
        fixed_vol = load_volume(fixed_path)
        moving_vol = load_volume(moving_path)

        fixed_label_path = get_label_path_for_image(fixed_rel)
        moving_label_path = get_label_path_for_image(moving_rel)

        fixed_label_map = None
        moving_label_map = None
        if fixed_label_path is not None and os.path.exists(fixed_label_path):
            fixed_label_map = load_label(fixed_label_path)
        if moving_label_path is not None and os.path.exists(moving_label_path):
            moving_label_map = load_label(moving_label_path)

        # --- Standardize images --------------------------------------------
        # After set_init_img, om._init_img / om._init_img_raw hold the
        # model-resolution and full-resolution images; clone them because the
        # next set_init_img call presumably overwrites these attributes.
        om.set_init_img(moving_vol)
        src_img_model = om._init_img.clone()
        src_img_fullres = om._init_img_raw.clone()
        src_orig_sz = list(src_img_fullres.shape[2:])

        om.set_init_img(fixed_vol)
        tgt_img_model = om._init_img.clone()
        tgt_img_fullres = om._init_img_raw.clone()

        # --- Standardize per-organ masks -----------------------------------
        # Guard on organ_label_ids: with no organ labels, torch.cat([]) would
        # raise inside _standardize_class_masks.
        src_mask_model, src_mask_fullres = None, None
        tgt_mask_model, tgt_mask_fullres = None, None
        if moving_label_map is not None and organ_label_ids:
            src_mask_model, src_mask_fullres = _standardize_class_masks(
                moving_label_map, moving_vol
            )
        if fixed_label_map is not None and organ_label_ids:
            tgt_mask_model, tgt_mask_fullres = _standardize_class_masks(
                fixed_label_map, fixed_vol
            )

        # --- Save un-registered inputs (model res, then full res) ----------
        _save_nifti(tgt_img_model, reg_img_savepath, f"{pair_tag}_TGT_ORG.nii.gz")
        if tgt_mask_model is not None:
            _save_nifti(tgt_mask_model, reg_msk_savepath, f"{pair_tag}_TGT_ORG_GT.nii.gz")

        _save_nifti(src_img_model, reg_img_savepath, f"Src{pair_idx:04d}_ORG.nii.gz")
        if src_mask_model is not None:
            _save_nifti(src_mask_model, reg_msk_savepath, f"Src{pair_idx:04d}_ORG_GT.nii.gz")

        _save_nifti(tgt_img_fullres, reg_img_savepath_fullres, f"{pair_tag}_TGT_ORG.nii.gz")
        if tgt_mask_fullres is not None:
            _save_nifti(tgt_mask_fullres, reg_msk_savepath_fullres, f"{pair_tag}_TGT_ORG_GT.nii.gz")

        _save_nifti(src_img_fullres, reg_img_savepath_fullres, f"Src{pair_idx:04d}_ORG.nii.gz")
        if src_mask_fullres is not None:
            _save_nifti(src_mask_fullres, reg_msk_savepath_fullres, f"Src{pair_idx:04d}_ORG_GT.nii.gz")

        # --- Predict the deformation (moving -> fixed) ---------------------
        om.set_init_img(src_img_model)
        om.set_cond_img(tgt_img_model.clone().detach())
        om.predict(
            T=[None, timesteps],
            proc_type=condition_type,
        )
        ddf_comp = om.get_def()

        neg_pct = compute_negdetj_pct(ddf_comp, ndims=ndims)
        negdetj_pct[pair_idx] = neg_pct
        print(f" %|J|<0 = {neg_pct:.4f}%")

        # --- Warp and save at model resolution -----------------------------
        img_rec = om.apply_def(img=src_img_model, ddf=ddf_comp, padding_mode="zeros")
        _save_nifti(img_rec, reg_img_savepath, f"{pair_tag}.nii.gz")

        msk_rec = None
        if src_mask_model is not None:
            msk_rec = om.apply_def(
                img=src_mask_model, ddf=ddf_comp,
                padding_mode="zeros", resample_mode="nearest",
            )
            _save_nifti(msk_rec, reg_msk_savepath, f"{pair_tag}_GT.nii.gz")

        _save_nifti(ddf_comp, reg_ddf_savepath, f"{pair_tag}.nii.gz")

        # --- Warp and save at full resolution ------------------------------
        img_rec_fullres = om.apply_def(img=src_img_fullres, ddf=ddf_comp, padding_mode="border")
        _save_nifti(img_rec_fullres, reg_img_savepath_fullres, f"{pair_tag}.nii.gz")

        msk_rec_fullres = None
        if src_mask_fullres is not None:
            msk_rec_fullres = om.apply_def(
                img=src_mask_fullres, ddf=ddf_comp,
                padding_mode="zeros", resample_mode="nearest",
            )
            _save_nifti(msk_rec_fullres, reg_msk_savepath_fullres, f"{pair_tag}_GT.nii.gz")

        # NOTE(review): the DDF grid is resampled to the source size but its
        # values are not rescaled. That is only correct if displacements are
        # in normalized coordinates rather than voxels; confirm OMorpher's
        # DDF convention.
        ddf_fullres = F.interpolate(
            ddf_comp, size=src_orig_sz, mode=_DDF_RESIZE_MODE, align_corners=False,
        )
        _save_nifti(ddf_fullres, reg_ddf_savepath_fullres, f"{pair_tag}.nii.gz")

        # --- Evaluate per organ at full resolution -------------------------
        # NOTE(review): pre-registration metrics compare the source and target
        # full-res masks element-wise, which assumes both share one grid
        # shape. ASD/HD are in voxel units (no spacing is passed).
        if (
            organ_label_ids
            and src_mask_fullres is not None
            and tgt_mask_fullres is not None
        ):
            for ch_idx, cid in enumerate(sorted(organ_label_ids.keys())):
                lk = organ_label_ids[cid]
                tgt_mask_np = tgt_mask_fullres[0, ch_idx].cpu().numpy()
                src_mask_np = src_mask_fullres[0, ch_idx].cpu().numpy()

                # NOTE(review): split_label_classes yields 0/1 masks, so this
                # presumably catches a negative sentinel produced by
                # _standardize_label. Confirm it is still needed.
                if np.all(tgt_mask_np < 0) or np.all(src_mask_np < 0):
                    continue

                pre_dsc = compute_dsc(src_mask_np, tgt_mask_np)
                pre_asd = compute_asd(src_mask_np, tgt_mask_np)
                pre_hd = compute_hd(src_mask_np, tgt_mask_np)

                metrics_pre[cid]["dsc"][pair_idx] = pre_dsc
                metrics_pre[cid]["asd"][pair_idx] = pre_asd
                metrics_pre[cid]["hd"][pair_idx] = pre_hd

                if msk_rec_fullres is not None:
                    reg_mask_np = msk_rec_fullres[0, ch_idx].cpu().numpy()
                    post_dsc = compute_dsc(reg_mask_np, tgt_mask_np)
                    post_asd = compute_asd(reg_mask_np, tgt_mask_np)
                    post_hd = compute_hd(reg_mask_np, tgt_mask_np)
                else:
                    post_dsc = float("nan")
                    post_asd = float("nan")
                    post_hd = float("nan")

                metrics[cid]["dsc"][pair_idx] = post_dsc
                metrics[cid]["asd"][pair_idx] = post_asd
                metrics[cid]["hd"][pair_idx] = post_hd

                print(
                    f" [{lk}] PRE DSC={pre_dsc:.4f} ASD={pre_asd:.2f} HD={pre_hd:.2f}"
                )
                print(
                    f" [{lk}] POST DSC={post_dsc:.4f} ASD={post_asd:.2f} HD={post_hd:.2f}"
                )

print("\nPaired registration complete.")

# Number of pairs selected for this run.
n_pairs = len(pairs)
|
|
|
| def _fmt(val):
|
| if val is None:
|
| return ""
|
| if np.isnan(val):
|
| return "NaN"
|
| return f"{val:.6f}"
|
|
|
|
|
|
|
# One row per pair: share of voxels with a negative Jacobian determinant.
negdetj_csv_path = os.path.join(eval_dir, "negdetj_pct.csv")
with open(negdetj_csv_path, "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["pair_idx", "fixed", "moving", "negdetj_pct"])
    writer.writerows(
        [pi, fixed_name, moving_name, _fmt(negdetj_pct.get(pi))]
        for pi, fixed_name, moving_name in pair_info
    )
print(f"Saved {negdetj_csv_path}")
|
|
|
# One CSV per (organ, metric) with pre/post values for every pair. File
# names get a "<label>_" prefix only when more than one organ is evaluated.
_multi_label = len(organ_label_ids) > 1
for cid in sorted(organ_label_ids.keys()):
    lk = organ_label_ids[cid]
    prefix = f"{lk}_" if _multi_label else ""

    for metric_name in ("dsc", "asd", "hd"):
        mn_upper = metric_name.upper()
        csv_path = os.path.join(eval_dir, f"{prefix}{metric_name}.csv")
        pre_by_pair = metrics_pre[cid][metric_name]
        post_by_pair = metrics[cid][metric_name]

        with open(csv_path, "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(
                ["pair_idx", "fixed", "moving", f"pre_{mn_upper}", f"post_{mn_upper}"]
            )
            writer.writerows(
                [
                    pi, fixed_name, moving_name,
                    _fmt(pre_by_pair.get(pi)), _fmt(post_by_pair.get(pi)),
                ]
                for pi, fixed_name, moving_name in pair_info
            )
        print(f"Saved {csv_path}")
|
|
|
|
|
def _mean_std(values):
    """Return (mean, population std) of ``values``, or (nan, nan) if empty."""
    if not values:
        return float("nan"), float("nan")
    return np.mean(values), np.std(values)


# Summary table: one %|J|<0 row over all pairs, then one row per
# (organ, metric) with NaN entries excluded from mean/std.
overall_path = os.path.join(eval_dir, "overall.csv")
with open(overall_path, "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow([
        "label", "metric",
        "pre_mean", "pre_std",
        "post_mean", "post_std",
        "n_pairs",
    ])

    negdetj_vals = [v for v in negdetj_pct.values() if not np.isnan(v)]
    _neg_mean, _neg_std = _mean_std(negdetj_vals)
    writer.writerow(
        ["ALL", "%|J|<0", "", "", _fmt(_neg_mean), _fmt(_neg_std), len(negdetj_vals)]
    )

    for cid in sorted(organ_label_ids.keys()):
        lk = organ_label_ids[cid]
        for metric_name in ("dsc", "asd", "hd"):
            pre_vals = [
                v for v in metrics_pre[cid][metric_name].values() if not np.isnan(v)
            ]
            post_vals = [
                v for v in metrics[cid][metric_name].values() if not np.isnan(v)
            ]
            pre_mean, pre_std = _mean_std(pre_vals)
            post_mean, post_std = _mean_std(post_vals)
            writer.writerow([
                lk,
                metric_name.upper(),
                _fmt(pre_mean), _fmt(pre_std),
                _fmt(post_mean), _fmt(post_std),
                max(len(pre_vals), len(post_vals)),
            ])
print(f"Saved {overall_path}")
|
|
|