""" OM_reg_pair.py — Paired registration using OMorpher with external dataset. Loads fixed/moving pairs from a Learn2Reg-style JSON dataset file (e.g. HippocampusMR_dataset.json) and registers each moving image to its paired fixed image. Saves registered images, masks, DDFs, source originals, and evaluation metrics (DSC, ASD, HD) per organ label. Usage: python Scripts/OM_reg_pair.py -C Config/config_om.yaml \ --dataset-json /path/to/HippocampusMR_dataset.json \ --split val python Scripts/OM_reg_pair.py -C Config/config_om.yaml \ --dataset-json /path/to/HippocampusMR_dataset.json \ --split test -N 10 """ import os import sys # Add project root to path so imports work from Scripts/ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) import csv import json import numpy as np import torch import torch.nn.functional as F import nibabel as nib import yaml import SimpleITK as sitk from scipy.ndimage import distance_transform_edt, binary_erosion from tqdm import tqdm import utils from Dataloader.dataLoader import reverse_axis_order from OMorpher import OMorpher # ========== CLI ========== import argparse parser = argparse.ArgumentParser() parser.add_argument( "--config", "-C", help="Path for the config file", type=str, default="Config/config_om.yaml", required=False, ) parser.add_argument( "--dataset-json", help="Path to the Learn2Reg-style dataset JSON", type=str, default="~/rds/rds-airr-p51-TWhPgQVLKbA/Code/Registration/Dataset/HippocampusMR/HippocampusMR_dataset.json", ) parser.add_argument( "--split", help="Which registration split to use: 'val' or 'test'", type=str, choices=["val", "test"], default="val", ) parser.add_argument( "--max-samples", "-N", help="Max number of pairs to register (0 = all)", type=int, default=0, ) args = parser.parse_args() # ========== Config ========== with open(args.config, "r") as file: hyp_parameters = yaml.safe_load(file) print(hyp_parameters) hyp_parameters["batchsize"] = 1 model_img_sz = 
hyp_parameters["img_size"] timesteps = hyp_parameters["timesteps"] condition_type = hyp_parameters["condition_type"] ndims = hyp_parameters["ndims"] # ========== Load external dataset JSON ========== dataset_json_path = os.path.expanduser(args.dataset_json) dataset_root = os.path.dirname(dataset_json_path) with open(dataset_json_path, "r") as f: dataset_meta = json.load(f) dataset_name = dataset_meta.get("name", "UnknownDataset") print(f"Dataset: {dataset_name}") # Select registration split if args.split == "val": pairs = dataset_meta.get("registration_val", []) elif args.split == "test": pairs = dataset_meta.get("registration_test", []) else: raise ValueError(f"Unknown split: {args.split}") if args.max_samples > 0: pairs = pairs[: args.max_samples] print(f"Split: {args.split}, Pairs: {len(pairs)}") # Build label lookup: image basename -> label relative path # from the "training" entries in the JSON _label_lookup = {} for entry in dataset_meta.get("training", []): img_base = os.path.basename(entry["image"]) _label_lookup[img_base] = entry.get("label") # Label class names (from JSON: "0": "background", "1": "head", "2": "tail") _label_names = dataset_meta.get("labels", {}).get("0", {}) # Organ labels are all non-background classes organ_label_ids = {int(k): v for k, v in _label_names.items() if int(k) > 0} print(f"Organ labels for evaluation: {organ_label_ids}") # ========== OMorpher setup ========== epoch = f'{hyp_parameters["model_id_str"]}_{hyp_parameters["data_name"]}_{hyp_parameters["net_name"]}' model_save_path = os.path.join( f'Models/{hyp_parameters["data_name"]}_{hyp_parameters["net_name"]}/', str(epoch) + ".pth", ) print("Loading model from:", model_save_path) om = OMorpher( config=hyp_parameters, checkpoint_path=model_save_path, device=str(hyp_parameters.get("device", "cpu")), ) print(om) # ========== Output directories ========== reg_img_savepath = hyp_parameters["reg_img_savepath"] reg_msk_savepath = hyp_parameters["reg_msk_savepath"] reg_ddf_savepath = 
hyp_parameters["reg_ddf_savepath"] reg_img_savepath_fullres = reg_img_savepath.rstrip("/") + "_fullres/" reg_msk_savepath_fullres = reg_msk_savepath.rstrip("/") + "_fullres/" reg_ddf_savepath_fullres = reg_ddf_savepath.rstrip("/") + "_fullres/" eval_dir = os.path.join(reg_img_savepath, "..", "eval") for p in [ reg_img_savepath, reg_msk_savepath, reg_ddf_savepath, reg_img_savepath_fullres, reg_msk_savepath_fullres, reg_ddf_savepath_fullres, eval_dir, ]: os.makedirs(p, exist_ok=True) # ========== Helper functions ========== def resolve_path(rel_path): """Resolve a relative path from the dataset JSON to an absolute path.""" if os.path.isabs(rel_path): return rel_path return os.path.normpath(os.path.join(dataset_root, rel_path)) def load_volume(nifti_path): """Load a NIfTI volume: axis reorder only. OMorpher._standardize_img handles: normalize → pad-to-cube → resize to model res. """ volume = sitk.ReadImage(nifti_path) volume = sitk.GetArrayFromImage(volume) volume = reverse_axis_order(volume) if volume.ndim == 4: volume = volume[:, :, :, 0] return volume def load_label(nifti_path): """Load a NIfTI label map: axis reorder only. OMorpher._standardize_label handles: pad-to-cube → resize to model res (nearest). """ label = sitk.ReadImage(nifti_path) label = sitk.GetArrayFromImage(label) label = reverse_axis_order(label) if label.ndim > 3: label = label[:, :, :, 0] return label def get_label_path_for_image(image_rel_path): """Find the label path for an image by looking up the training entries.""" img_base = os.path.basename(image_rel_path) label_rel = _label_lookup.get(img_base) if label_rel is None: return None return resolve_path(label_rel) def split_label_classes(label_map, class_ids): """Split a multi-class label map into per-class binary masks. Returns a dict {class_id: binary_numpy_array}. 
""" masks = {} for cid in class_ids: masks[cid] = (label_map == cid).astype(np.float32) return masks def get_volume_name(path): """Extract a short name from a NIfTI file path.""" name = os.path.basename(path) for ext in [".nii.gz", ".nii"]: if name.endswith(ext): name = name[: -len(ext)] break return name # ---------- Evaluation metrics ---------- def _surface_distances(pred, gt): """Compute directed surface distances between two binary masks.""" pred_bool = pred > 0.5 gt_bool = gt > 0.5 if not np.any(pred_bool) or not np.any(gt_bool): return None, None struct = None pred_surface = pred_bool ^ binary_erosion(pred_bool, structure=struct) gt_surface = gt_bool ^ binary_erosion(gt_bool, structure=struct) if not np.any(pred_surface): pred_surface = pred_bool if not np.any(gt_surface): gt_surface = gt_bool dt_gt = distance_transform_edt(~gt_surface) dt_pred = distance_transform_edt(~pred_surface) return dt_gt[pred_surface], dt_pred[gt_surface] def compute_dsc(pred, gt): """Dice Similarity Coefficient.""" pred_bool = pred > 0.5 gt_bool = gt > 0.5 intersection = np.sum(pred_bool & gt_bool) denom = np.sum(pred_bool) + np.sum(gt_bool) if denom == 0: return 1.0 return 2.0 * float(intersection) / float(denom) def compute_asd(pred, gt): """Average (symmetric) Surface Distance.""" d1, d2 = _surface_distances(pred, gt) if d1 is None: return float("nan") return (np.mean(d1) + np.mean(d2)) / 2.0 def compute_hd(pred, gt): """Hausdorff Distance (maximum of directed HDs).""" d1, d2 = _surface_distances(pred, gt) if d1 is None: return float("nan") return float(max(np.max(d1), np.max(d2))) def compute_negdetj_pct(ddf, ndims=3): """Percent of voxels with negative Jacobian determinant. Args: ddf: displacement field tensor [1, ndims, ...] or numpy array. ndims: 2 or 3. Returns: Percentage of voxels where det(Jacobian) < 0. """ if isinstance(ddf, torch.Tensor): ddf = ddf.detach().cpu().numpy() # ddf shape: [1, C, ...] or [C, ...] 
if ddf.ndim == ndims + 2: ddf = ddf[0] # remove batch dim -> [C, ...] # Compute spatial gradients via finite differences (forward diff, clipped) if ndims == 3: # ddf: [3, D, H, W] # Derivatives along each spatial axis dux_dx = np.diff(ddf[0], axis=0, append=ddf[0, -1:, :, :]) duy_dx = np.diff(ddf[1], axis=0, append=ddf[1, -1:, :, :]) duz_dx = np.diff(ddf[2], axis=0, append=ddf[2, -1:, :, :]) dux_dy = np.diff(ddf[0], axis=1, append=ddf[0, :, -1:, :]) duy_dy = np.diff(ddf[1], axis=1, append=ddf[1, :, -1:, :]) duz_dy = np.diff(ddf[2], axis=1, append=ddf[2, :, -1:, :]) dux_dz = np.diff(ddf[0], axis=2, append=ddf[0, :, :, -1:]) duy_dz = np.diff(ddf[1], axis=2, append=ddf[1, :, :, -1:]) duz_dz = np.diff(ddf[2], axis=2, append=ddf[2, :, :, -1:]) # Jacobian = I + du/dx j11 = 1.0 + dux_dx; j12 = dux_dy; j13 = dux_dz j21 = duy_dx; j22 = 1.0 + duy_dy; j23 = duy_dz j31 = duz_dx; j32 = duz_dy; j33 = 1.0 + duz_dz detj = ( j11 * (j22 * j33 - j23 * j32) - j12 * (j21 * j33 - j23 * j31) + j13 * (j21 * j32 - j22 * j31) ) elif ndims == 2: dux_dx = np.diff(ddf[0], axis=0, append=ddf[0, -1:, :]) duy_dx = np.diff(ddf[1], axis=0, append=ddf[1, -1:, :]) dux_dy = np.diff(ddf[0], axis=1, append=ddf[0, :, -1:]) duy_dy = np.diff(ddf[1], axis=1, append=ddf[1, :, -1:]) detj = (1.0 + dux_dx) * (1.0 + duy_dy) - dux_dy * duy_dx else: raise ValueError(f"Unsupported ndims={ndims}") n_neg = np.sum(detj < 0) n_total = detj.size return 100.0 * float(n_neg) / float(n_total) # ========== Prepare evaluation structures ========== # metrics[class_id][metric_name][pair_idx] = value (post-registration) metrics = { cid: {"dsc": {}, "asd": {}, "hd": {}} for cid in organ_label_ids } # metrics_pre: same structure but for pre-registration (source vs target, no deformation) metrics_pre = { cid: {"dsc": {}, "asd": {}, "hd": {}} for cid in organ_label_ids } # Per-pair DDF quality metric (not per-class) negdetj_pct = {} # pair_idx -> percentage of negative Jacobian determinant # Also collect per-pair info for the CSV 
pair_info = []  # list of (pair_idx, fixed_name, moving_name)


# ========== Per-pair helpers ==========
def _save_nifti(tensor, directory, filename):
    """Write `tensor` as NIfTI to directory/filename via utils.converet_to_nibabel."""
    nib.save(
        utils.converet_to_nibabel(tensor, ndims=ndims),
        os.path.join(directory, filename),
    )


def _load_label_if_available(image_rel_path):
    """Return the label map paired with an image, or None if none exists on disk."""
    label_path = get_label_path_for_image(image_rel_path)
    if label_path is None or not os.path.exists(label_path):
        return None
    return load_label(label_path)


def _standardize_class_masks(ref_vol, label_map):
    """Split `label_map` into per-organ binary masks and standardize each via OMorpher.

    `ref_vol` is set as OMorpher's init image first so that _standardize_label
    uses that volume's shape. Returns (model_res, full_res) tensors with one
    channel per organ label, in ascending label-id order, concatenated on dim 1.
    """
    class_masks = split_label_classes(label_map, organ_label_ids.keys())
    om.set_init_img(ref_vol)  # reset so _standardize_label uses correct shape
    masks_model, masks_fullres = [], []
    for cid in sorted(organ_label_ids.keys()):
        m_model, m_fullres = om._standardize_label(class_masks[cid])
        masks_model.append(m_model)
        masks_fullres.append(m_fullres)
    return torch.cat(masks_model, dim=1), torch.cat(masks_fullres, dim=1)


def _evaluate_pair(pair_idx, src_mask_fullres, tgt_mask_fullres, msk_rec_fullres):
    """Record pre/post-registration DSC, ASD and HD per organ label for one pair.

    Channel i of each mask tensor corresponds to the i-th organ label id in
    ascending order. Results are stored in the module-level `metrics_pre` and
    `metrics` dicts, keyed by label id, metric name and pair index.
    """
    for ch_idx, cid in enumerate(sorted(organ_label_ids.keys())):
        lk = organ_label_ids[cid]
        # NOTE(review): src and tgt full-res masks come from different volumes;
        # element-wise comparison assumes they share a grid shape — confirm.
        tgt_mask_np = tgt_mask_fullres[0, ch_idx].cpu().numpy()
        src_mask_np = src_mask_fullres[0, ch_idx].cpu().numpy()
        # NOTE(review): binary masks are never all-negative, so this skip
        # presumably targets a negative "missing" sentinel — confirm intent.
        if np.all(tgt_mask_np < 0) or np.all(src_mask_np < 0):
            continue

        # Pre-registration: source vs target (no deformation)
        pre_dsc = compute_dsc(src_mask_np, tgt_mask_np)
        pre_asd = compute_asd(src_mask_np, tgt_mask_np)
        pre_hd = compute_hd(src_mask_np, tgt_mask_np)
        metrics_pre[cid]["dsc"][pair_idx] = pre_dsc
        metrics_pre[cid]["asd"][pair_idx] = pre_asd
        metrics_pre[cid]["hd"][pair_idx] = pre_hd

        # Post-registration: registered mask vs target
        if msk_rec_fullres is not None:
            reg_mask_np = msk_rec_fullres[0, ch_idx].cpu().numpy()
            post_dsc = compute_dsc(reg_mask_np, tgt_mask_np)
            post_asd = compute_asd(reg_mask_np, tgt_mask_np)
            post_hd = compute_hd(reg_mask_np, tgt_mask_np)
        else:
            post_dsc = post_asd = post_hd = float("nan")
        metrics[cid]["dsc"][pair_idx] = post_dsc
        metrics[cid]["asd"][pair_idx] = post_asd
        metrics[cid]["hd"][pair_idx] = post_hd

        print(f" [{lk}] PRE DSC={pre_dsc:.4f} ASD={pre_asd:.2f} HD={pre_hd:.2f}")
        print(f" [{lk}] POST DSC={post_dsc:.4f} ASD={post_asd:.2f} HD={post_hd:.2f}")


# ========== Paired registration ==========
# F.interpolate needs "bilinear" for 4D (2D) inputs and "trilinear" for 5D (3D).
ddf_interp_mode = "trilinear" if ndims == 3 else "bilinear"

with torch.no_grad():
    for pair_idx, pair in enumerate(tqdm(pairs, desc="Pairs")):
        fixed_rel = pair["fixed"]
        moving_rel = pair["moving"]
        fixed_path = resolve_path(fixed_rel)
        moving_path = resolve_path(moving_rel)
        fixed_name = get_volume_name(fixed_rel)
        moving_name = get_volume_name(moving_rel)
        pair_tag = f"Tgt{pair_idx:04d}_Src{pair_idx:04d}"
        src_tag = f"Src{pair_idx:04d}"
        pair_info.append((pair_idx, fixed_name, moving_name))
        print(f"\n [{pair_idx}] Fixed: {fixed_name}, Moving: {moving_name}")

        # --- Load volumes and labels (labels only if available) ---
        fixed_vol = load_volume(fixed_path)
        moving_vol = load_volume(moving_path)
        fixed_label_map = _load_label_if_available(fixed_rel)
        moving_label_map = _load_label_if_available(moving_rel)

        # --- Prepare tensors via OMorpher ---
        # Moving image = source to be deformed.
        om.set_init_img(moving_vol)
        src_img_model = om._init_img.clone()
        src_img_fullres = om._init_img_raw.clone()
        src_orig_sz = list(src_img_fullres.shape[2:])

        # Fixed image = conditioning target.
        om.set_init_img(fixed_vol)
        tgt_img_model = om._init_img.clone()
        tgt_img_fullres = om._init_img_raw.clone()

        # Per-organ mask channels; skipped when there are no organ labels,
        # since concatenating an empty list of masks would fail.
        src_mask_model, src_mask_fullres = None, None
        tgt_mask_model, tgt_mask_fullres = None, None
        if moving_label_map is not None and organ_label_ids:
            src_mask_model, src_mask_fullres = _standardize_class_masks(
                moving_vol, moving_label_map
            )
        if fixed_label_map is not None and organ_label_ids:
            tgt_mask_model, tgt_mask_fullres = _standardize_class_masks(
                fixed_vol, fixed_label_map
            )

        # --- Save originals (model resolution) ---
        _save_nifti(tgt_img_model, reg_img_savepath, f"{pair_tag}_TGT_ORG.nii.gz")
        if tgt_mask_model is not None:
            _save_nifti(tgt_mask_model, reg_msk_savepath, f"{pair_tag}_TGT_ORG_GT.nii.gz")
        _save_nifti(src_img_model, reg_img_savepath, f"{src_tag}_ORG.nii.gz")
        if src_mask_model is not None:
            _save_nifti(src_mask_model, reg_msk_savepath, f"{src_tag}_ORG_GT.nii.gz")

        # --- Save originals (full resolution) ---
        _save_nifti(tgt_img_fullres, reg_img_savepath_fullres, f"{pair_tag}_TGT_ORG.nii.gz")
        if tgt_mask_fullres is not None:
            _save_nifti(tgt_mask_fullres, reg_msk_savepath_fullres, f"{pair_tag}_TGT_ORG_GT.nii.gz")
        _save_nifti(src_img_fullres, reg_img_savepath_fullres, f"{src_tag}_ORG.nii.gz")
        if src_mask_fullres is not None:
            _save_nifti(src_mask_fullres, reg_msk_savepath_fullres, f"{src_tag}_ORG_GT.nii.gz")

        # --- Register moving to fixed ---
        om.set_init_img(src_img_model)
        om.set_cond_img(tgt_img_model.clone().detach())
        om.predict(T=[None, timesteps], proc_type=condition_type)
        ddf_comp = om.get_def()

        # --- DDF quality: percent negative Jacobian determinant ---
        neg_pct = compute_negdetj_pct(ddf_comp, ndims=ndims)
        negdetj_pct[pair_idx] = neg_pct
        print(f" %|J|<0 = {neg_pct:.4f}%")

        # --- Model-resolution outputs ---
        img_rec = om.apply_def(img=src_img_model, ddf=ddf_comp, padding_mode="zeros")
        _save_nifti(img_rec, reg_img_savepath, f"{pair_tag}.nii.gz")

        msk_rec = None
        if src_mask_model is not None:
            msk_rec = om.apply_def(
                img=src_mask_model,
                ddf=ddf_comp,
                padding_mode="zeros",
                resample_mode="nearest",
            )
            _save_nifti(msk_rec, reg_msk_savepath, f"{pair_tag}_GT.nii.gz")

        _save_nifti(ddf_comp, reg_ddf_savepath, f"{pair_tag}.nii.gz")

        # --- Full-resolution outputs ---
        img_rec_fullres = om.apply_def(img=src_img_fullres, ddf=ddf_comp, padding_mode="border")
        _save_nifti(img_rec_fullres, reg_img_savepath_fullres, f"{pair_tag}.nii.gz")

        msk_rec_fullres = None
        if src_mask_fullres is not None:
            msk_rec_fullres = om.apply_def(
                img=src_mask_fullres,
                ddf=ddf_comp,
                padding_mode="zeros",
                resample_mode="nearest",
            )
            _save_nifti(msk_rec_fullres, reg_msk_savepath_fullres, f"{pair_tag}_GT.nii.gz")

        # NOTE(review): displacement values are resampled but not rescaled;
        # this is only correct if the DDF is in normalized (resolution-free)
        # coordinates — confirm OMorpher's DDF units.
        ddf_fullres = F.interpolate(
            ddf_comp,
            size=src_orig_sz,
            mode=ddf_interp_mode,
            align_corners=False,
        )
        _save_nifti(ddf_fullres, reg_ddf_savepath_fullres, f"{pair_tag}.nii.gz")

        # --- Evaluation metrics (full-res organ labels) ---
        if organ_label_ids and src_mask_fullres is not None and tgt_mask_fullres is not None:
            _evaluate_pair(pair_idx, src_mask_fullres, tgt_mask_fullres, msk_rec_fullres)

print("\nPaired registration complete.")

# ========== Write evaluation CSVs ==========
n_pairs = len(pairs)


def _fmt(val):
    """Format a metric for CSV: '' when missing, 'NaN' for NaN, else 6 decimals."""
    if val is None:
        return ""
    if np.isnan(val):
        return "NaN"
    return f"{val:.6f}"


def _mean_std(values):
    """Return (mean, std) of `values`, or (nan, nan) when the list is empty."""
    if not values:
        return float("nan"), float("nan")
    return np.mean(values), np.std(values)


def _finite(values):
    """Drop NaN entries from an iterable of metric values."""
    return [v for v in values if not np.isnan(v)]


# --- Per-pair %|J|<0 CSV ---
negdetj_csv_path = os.path.join(eval_dir, "negdetj_pct.csv")
with open(negdetj_csv_path, "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["pair_idx", "fixed", "moving", "negdetj_pct"])
    for pi, fixed_name, moving_name in pair_info:
        writer.writerow([pi, fixed_name, moving_name, _fmt(negdetj_pct.get(pi))])
print(f"Saved {negdetj_csv_path}")

# --- Per-label, per-metric CSVs (pre vs post registration) ---
for cid in sorted(organ_label_ids.keys()):
    lk = organ_label_ids[cid]
    # File names are only prefixed with the label name when there are several labels.
    prefix = f"{lk}_" if len(organ_label_ids) > 1 else ""
    for metric_name in ["dsc", "asd", "hd"]:
        mn_upper = metric_name.upper()
        csv_path = os.path.join(eval_dir, f"{prefix}{metric_name}.csv")
        with open(csv_path, "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow([
                "pair_idx", "fixed", "moving",
                f"pre_{mn_upper}", f"post_{mn_upper}",
            ])
            for pi, fixed_name, moving_name in pair_info:
                pre_val = metrics_pre[cid][metric_name].get(pi)
                post_val = metrics[cid][metric_name].get(pi)
                writer.writerow([
                    pi, fixed_name, moving_name,
                    _fmt(pre_val), _fmt(post_val),
                ])
        print(f"Saved {csv_path}")

# --- Overall summary ---
overall_path = os.path.join(eval_dir, "overall.csv")
with open(overall_path, "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow([
        "label", "metric",
        "pre_mean", "pre_std",
        "post_mean", "post_std",
        "n_pairs",
    ])

    # %|J|<0 summary (not per-label; no pre-registration counterpart)
    negdetj_vals = _finite(negdetj_pct.values())
    negdetj_mean, negdetj_std = _mean_std(negdetj_vals)
    writer.writerow([
        "ALL", "%|J|<0",
        "", "",
        _fmt(negdetj_mean), _fmt(negdetj_std),
        len(negdetj_vals),
    ])

    for cid in sorted(organ_label_ids.keys()):
        lk = organ_label_ids[cid]
        for metric_name in ["dsc", "asd", "hd"]:
            pre_vals = _finite(metrics_pre[cid][metric_name].values())
            post_vals = _finite(metrics[cid][metric_name].values())
            pre_mean, pre_std = _mean_std(pre_vals)
            post_mean, post_std = _mean_std(post_vals)
            writer.writerow([
                lk, metric_name.upper(),
                _fmt(pre_mean), _fmt(pre_std),
                _fmt(post_mean), _fmt(post_std),
                max(len(pre_vals), len(post_vals)),
            ])
print(f"Saved {overall_path}")