# Omini3D / Scripts / OM_reg_pair_ext.py
# Author: maxmo2009
# Commit 2af0e94 (verified): "Sync from local: code + epoch-110 checkpoint, clean README"
"""
OM_reg_pair.py — Paired registration using OMorpher with external dataset.
Loads fixed/moving pairs from a Learn2Reg-style JSON dataset file
(e.g. HippocampusMR_dataset.json) and registers each moving image to its
paired fixed image. Saves registered images, masks, DDFs, source originals,
and evaluation metrics (DSC, ASD, HD) per organ label.
Usage:
python Scripts/OM_reg_pair.py -C Config/config_om.yaml \
--dataset-json /path/to/HippocampusMR_dataset.json \
--split val
python Scripts/OM_reg_pair.py -C Config/config_om.yaml \
--dataset-json /path/to/HippocampusMR_dataset.json \
--split test -N 10
"""
import os
import sys
# Add project root to path so imports work from Scripts/
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import csv
import json
import numpy as np
import torch
import torch.nn.functional as F
import nibabel as nib
import yaml
import SimpleITK as sitk
from scipy.ndimage import distance_transform_edt, binary_erosion
from tqdm import tqdm
import utils
from Dataloader.dataLoader import reverse_axis_order
from OMorpher import OMorpher
# ========== CLI ==========
import argparse

parser = argparse.ArgumentParser()
# YAML config holding model / path hyper-parameters.
parser.add_argument("--config", "-C", type=str, required=False,
                    default="Config/config_om.yaml",
                    help="Path for the config file")
# NOTE: the default below is a user-specific cluster path; pass
# --dataset-json explicitly on other machines.
parser.add_argument("--dataset-json", type=str,
                    default="~/rds/rds-airr-p51-TWhPgQVLKbA/Code/Registration/Dataset/HippocampusMR/HippocampusMR_dataset.json",
                    help="Path to the Learn2Reg-style dataset JSON")
parser.add_argument("--split", type=str, choices=["val", "test"], default="val",
                    help="Which registration split to use: 'val' or 'test'")
parser.add_argument("--max-samples", "-N", type=int, default=0,
                    help="Max number of pairs to register (0 = all)")
args = parser.parse_args()
# ========== Config ==========
with open(args.config) as cfg_file:
    hyp_parameters = yaml.safe_load(cfg_file)
print(hyp_parameters)

# Force a batch size of 1: pairs are registered one at a time below.
hyp_parameters["batchsize"] = 1
model_img_sz, timesteps, condition_type, ndims = (
    hyp_parameters[key]
    for key in ("img_size", "timesteps", "condition_type", "ndims")
)
# ========== Load external dataset JSON ==========
dataset_json_path = os.path.expanduser(args.dataset_json)
# Relative paths inside the JSON are resolved against the JSON's own directory.
dataset_root = os.path.dirname(dataset_json_path)
with open(dataset_json_path, "r") as f:
    dataset_meta = json.load(f)
dataset_name = dataset_meta.get("name", "UnknownDataset")
print(f"Dataset: {dataset_name}")

# Select registration split (CLI value -> JSON key holding the pair list).
_split_keys = {"val": "registration_val", "test": "registration_test"}
if args.split not in _split_keys:
    raise ValueError(f"Unknown split: {args.split}")
pairs = dataset_meta.get(_split_keys[args.split], [])
if args.max_samples > 0:
    pairs = pairs[: args.max_samples]
print(f"Split: {args.split}, Pairs: {len(pairs)}")

# Build label lookup: image basename -> label relative path, taken from the
# "training" entries of the JSON (later duplicates win).
_label_lookup = {
    os.path.basename(entry["image"]): entry.get("label")
    for entry in dataset_meta.get("training", [])
}

# Label class names. Learn2Reg JSONs nest them one level down, e.g.
#   "labels": {"0": {"0": "background", "1": "head", "2": "tail"}}
# Some datasets use a flat {"0": "background", "1": ...} mapping instead;
# there "labels"["0"] is a string, which previously crashed on `.items()`.
_labels_field = dataset_meta.get("labels", {})
_first_label_entry = _labels_field.get("0")
if isinstance(_first_label_entry, dict):
    _label_names = _first_label_entry
else:
    # Flat layout: keep only entries whose value is a class-name string.
    _label_names = {
        k: v for k, v in _labels_field.items() if isinstance(v, str)
    }
# Organ labels are all non-background classes
organ_label_ids = {int(k): v for k, v in _label_names.items() if int(k) > 0}
print(f"Organ labels for evaluation: {organ_label_ids}")
# ========== OMorpher setup ==========
# Checkpoint layout: Models/<data_name>_<net_name>/<model_id_str>_<data_name>_<net_name>.pth
_model_id = hyp_parameters["model_id_str"]
_data_net = f'{hyp_parameters["data_name"]}_{hyp_parameters["net_name"]}'
epoch = f"{_model_id}_{_data_net}"
model_save_path = os.path.join(f"Models/{_data_net}/", f"{epoch}.pth")
print("Loading model from:", model_save_path)

om = OMorpher(
    config=hyp_parameters,
    checkpoint_path=model_save_path,
    device=str(hyp_parameters.get("device", "cpu")),
)
print(om)
# ========== Output directories ==========
reg_img_savepath = hyp_parameters["reg_img_savepath"]
reg_msk_savepath = hyp_parameters["reg_msk_savepath"]
reg_ddf_savepath = hyp_parameters["reg_ddf_savepath"]


def _fullres_dir(path):
    """Return the full-resolution sibling of *path* (trailing '/' stripped, '_fullres/' appended)."""
    return path.rstrip("/") + "_fullres/"


reg_img_savepath_fullres = _fullres_dir(reg_img_savepath)
reg_msk_savepath_fullres = _fullres_dir(reg_msk_savepath)
reg_ddf_savepath_fullres = _fullres_dir(reg_ddf_savepath)
# Evaluation CSVs live next to the image output directory.
eval_dir = os.path.join(reg_img_savepath, "..", "eval")

for _out_dir in (
    reg_img_savepath,
    reg_msk_savepath,
    reg_ddf_savepath,
    reg_img_savepath_fullres,
    reg_msk_savepath_fullres,
    reg_ddf_savepath_fullres,
    eval_dir,
):
    os.makedirs(_out_dir, exist_ok=True)
# ========== Helper functions ==========
def resolve_path(rel_path):
    """Turn a path taken from the dataset JSON into an absolute path.

    Absolute paths are returned unchanged. Relative paths are joined onto
    ``dataset_root`` (the directory containing the JSON file) and normalized.
    """
    return (
        rel_path
        if os.path.isabs(rel_path)
        else os.path.normpath(os.path.join(dataset_root, rel_path))
    )
def load_volume(nifti_path):
    """Read a NIfTI volume into a numpy array, applying axis reordering only.

    OMorpher._standardize_img handles: normalize → pad-to-cube → resize to model res.
    A 4-D result is reduced to 3-D by taking index 0 of its last axis.
    """
    # reverse_axis_order is a project helper; presumably it undoes SimpleITK's
    # reversed (z, y, x) array ordering — TODO confirm.
    array = reverse_axis_order(sitk.GetArrayFromImage(sitk.ReadImage(nifti_path)))
    return array[..., 0] if array.ndim == 4 else array
def load_label(nifti_path):
    """Read a NIfTI label map into a numpy array, applying axis reordering only.

    OMorpher._standardize_label handles: pad-to-cube → resize to model res (nearest).
    For arrays with more than 3 dims, index 0 of axis 3 is kept.
    """
    image = sitk.ReadImage(nifti_path)
    label = reverse_axis_order(sitk.GetArrayFromImage(image))
    if label.ndim <= 3:
        return label
    return label[:, :, :, 0]
def get_label_path_for_image(image_rel_path):
    """Return the absolute label path for an image, or None if it has no label.

    The image is matched by basename against the "training" entries collected
    in ``_label_lookup``.
    """
    label_rel = _label_lookup.get(os.path.basename(image_rel_path))
    return None if label_rel is None else resolve_path(label_rel)
def split_label_classes(label_map, class_ids):
    """Split a multi-class label map into one float32 binary mask per class.

    Returns:
        dict mapping each id in *class_ids* (in iteration order) to a mask that
        is 1.0 where ``label_map == id`` and 0.0 elsewhere.
    """
    return {cid: (label_map == cid).astype(np.float32) for cid in class_ids}
def get_volume_name(path):
    """Return the basename of *path* without a trailing ``.nii.gz`` / ``.nii`` suffix."""
    base = os.path.basename(path)
    for suffix in (".nii.gz", ".nii"):
        if base.endswith(suffix):
            return base[: -len(suffix)]
    return base
# ---------- Evaluation metrics ----------
def _surface_distances(pred, gt):
"""Compute directed surface distances between two binary masks."""
pred_bool = pred > 0.5
gt_bool = gt > 0.5
if not np.any(pred_bool) or not np.any(gt_bool):
return None, None
struct = None
pred_surface = pred_bool ^ binary_erosion(pred_bool, structure=struct)
gt_surface = gt_bool ^ binary_erosion(gt_bool, structure=struct)
if not np.any(pred_surface):
pred_surface = pred_bool
if not np.any(gt_surface):
gt_surface = gt_bool
dt_gt = distance_transform_edt(~gt_surface)
dt_pred = distance_transform_edt(~pred_surface)
return dt_gt[pred_surface], dt_pred[gt_surface]
def compute_dsc(pred, gt):
    """Dice similarity coefficient of the masks ``pred > 0.5`` and ``gt > 0.5``.

    Returns 1.0 when both masks are empty.
    """
    p = pred > 0.5
    g = gt > 0.5
    total = int(p.sum()) + int(g.sum())
    if total == 0:
        return 1.0
    overlap = int(np.logical_and(p, g).sum())
    return 2.0 * overlap / total
def compute_asd(pred, gt):
    """Average symmetric surface distance: mean of the two directed mean distances.

    Distances come from ``_surface_distances`` without spacing, i.e. voxel
    units. Returns NaN when either mask is empty.
    """
    pred_to_gt, gt_to_pred = _surface_distances(pred, gt)
    if pred_to_gt is None:
        return float("nan")
    return (np.mean(pred_to_gt) + np.mean(gt_to_pred)) / 2.0
def compute_hd(pred, gt):
    """Hausdorff distance: the larger of the two directed maximum surface distances.

    Distances are in voxel units. Returns NaN when either mask is empty.
    """
    pred_to_gt, gt_to_pred = _surface_distances(pred, gt)
    if pred_to_gt is None:
        return float("nan")
    return float(max(pred_to_gt.max(), gt_to_pred.max()))
def compute_negdetj_pct(ddf, ndims=3):
    """Percentage of voxels whose Jacobian determinant of ``x + u(x)`` is negative.

    Args:
        ddf: displacement field as a torch tensor or numpy array, either
            ``[1, C, *spatial]`` (leading batch index 0 is used) or ``[C, *spatial]``.
        ndims: number of spatial dimensions, 2 or 3.

    Returns:
        ``100 * (#voxels with det(J) < 0) / #voxels``.

    Raises:
        ValueError: if *ndims* is neither 2 nor 3.
    """
    if isinstance(ddf, torch.Tensor):
        ddf = ddf.detach().cpu().numpy()
    if ddf.ndim == ndims + 2:
        ddf = ddf[0]  # drop batch dim -> [C, *spatial]
    if ndims not in (2, 3):
        raise ValueError(f"Unsupported ndims={ndims}")

    def _fwd_diff(component, axis):
        # Forward difference; repeating the last slice makes the final derivative 0.
        return np.diff(component, axis=axis, append=np.take(component, [-1], axis=axis))

    # jac[r][c] = d(u_r)/d(x_c), plus 1 on the diagonal (J = I + du/dx).
    # NOTE(review): assumes channel r displaces along spatial axis r — TODO
    # confirm against OMorpher's DDF convention.
    jac = [
        [
            _fwd_diff(ddf[r], c) + 1.0 if r == c else _fwd_diff(ddf[r], c)
            for c in range(ndims)
        ]
        for r in range(ndims)
    ]

    if ndims == 3:
        (j11, j12, j13), (j21, j22, j23), (j31, j32, j33) = jac
        detj = (
            j11 * (j22 * j33 - j23 * j32)
            - j12 * (j21 * j33 - j23 * j31)
            + j13 * (j21 * j32 - j22 * j31)
        )
    else:
        (j11, j12), (j21, j22) = jac
        detj = j11 * j22 - j12 * j21

    return 100.0 * float(np.count_nonzero(detj < 0)) / float(detj.size)
# ========== Prepare evaluation structures ==========


def _new_metric_table():
    """Return ``{class_id: {"dsc": {}, "asd": {}, "hd": {}}}`` for every organ class."""
    return {
        cid: {name: {} for name in ("dsc", "asd", "hd")}
        for cid in organ_label_ids
    }


# metrics[class_id][metric_name][pair_idx] = value (post-registration)
metrics = _new_metric_table()
# metrics_pre: same layout, source vs target with no deformation applied
metrics_pre = _new_metric_table()
# pair_idx -> percentage of voxels with negative Jacobian determinant (per DDF)
negdetj_pct = {}
# (pair_idx, fixed_name, moving_name) tuples, in processing order, for the CSVs
pair_info = []
# ========== Paired registration ==========


def _save_nifti(tensor, directory, filename):
    """Save *tensor* to ``directory/filename`` via ``utils.converet_to_nibabel``."""
    nib.save(
        utils.converet_to_nibabel(tensor, ndims=ndims),
        os.path.join(directory, filename),
    )


def _load_label_if_available(image_rel_path):
    """Return the label map paired with *image_rel_path*, or None.

    Returns None when the image has no labelled "training" entry in the
    dataset JSON, or when the resolved label file does not exist.
    """
    label_path = get_label_path_for_image(image_rel_path)
    if label_path is None or not os.path.exists(label_path):
        return None
    return load_label(label_path)


def _standardize_class_masks(volume, label_map):
    """Standardize every organ class of *label_map* through OMorpher.

    Resets the init image to *volume* first, so ``om._standardize_label``
    uses that image's geometry. Then standardizes one binary mask per organ
    class (ascending class id) and concatenates the results along dim 1.

    Returns:
        Tuple ``(masks_model_res, masks_fullres)``.
    """
    class_masks = split_label_classes(label_map, organ_label_ids.keys())
    om.set_init_img(volume)
    masks_model, masks_fullres = [], []
    for cid in sorted(organ_label_ids.keys()):
        m_model, m_fullres = om._standardize_label(class_masks[cid])
        masks_model.append(m_model)
        masks_fullres.append(m_fullres)
    return torch.cat(masks_model, dim=1), torch.cat(masks_fullres, dim=1)


def _metric_triplet(pred, gt):
    """Return ``(DSC, ASD, HD)`` of *pred* vs *gt*.

    The metric helpers combine the masks element-wise, so masks on different
    grids cannot be compared. In that case a warning is printed and NaNs are
    returned, instead of a broadcasting error aborting the whole run.
    """
    if pred.shape != gt.shape:
        print(
            f"    WARNING: mask shapes differ {pred.shape} vs {gt.shape}; "
            "metrics set to NaN"
        )
        nan = float("nan")
        return nan, nan, nan
    return compute_dsc(pred, gt), compute_asd(pred, gt), compute_hd(pred, gt)


def _evaluate_pair(pair_idx, src_mask_fullres, tgt_mask_fullres, msk_rec_fullres):
    """Record and print pre/post-registration metrics for every organ class.

    Channel ``ch_idx`` of each mask tensor holds the class at position
    ``ch_idx`` of ``sorted(organ_label_ids)``, the order used by
    ``_standardize_class_masks``. Results go into ``metrics_pre`` (source vs
    target, no deformation) and ``metrics`` (warped source vs target).
    """
    for ch_idx, cid in enumerate(sorted(organ_label_ids.keys())):
        lk = organ_label_ids[cid]
        tgt_mask_np = tgt_mask_fullres[0, ch_idx].cpu().numpy()
        src_mask_np = src_mask_fullres[0, ch_idx].cpu().numpy()
        # NOTE(review): presumably an all-negative mask marks a missing label
        # — TODO confirm. The 0/1 masks produced by split_label_classes are
        # never negative before standardization.
        if np.all(tgt_mask_np < 0) or np.all(src_mask_np < 0):
            continue

        # Pre-registration: source vs target (no deformation)
        pre_dsc, pre_asd, pre_hd = _metric_triplet(src_mask_np, tgt_mask_np)
        metrics_pre[cid]["dsc"][pair_idx] = pre_dsc
        metrics_pre[cid]["asd"][pair_idx] = pre_asd
        metrics_pre[cid]["hd"][pair_idx] = pre_hd

        # Post-registration: registered mask vs target
        if msk_rec_fullres is not None:
            reg_mask_np = msk_rec_fullres[0, ch_idx].cpu().numpy()
            post_dsc, post_asd, post_hd = _metric_triplet(reg_mask_np, tgt_mask_np)
        else:
            post_dsc = post_asd = post_hd = float("nan")
        metrics[cid]["dsc"][pair_idx] = post_dsc
        metrics[cid]["asd"][pair_idx] = post_asd
        metrics[cid]["hd"][pair_idx] = post_hd

        print(
            f" [{lk}] PRE DSC={pre_dsc:.4f} ASD={pre_asd:.2f} HD={pre_hd:.2f}"
        )
        print(
            f" [{lk}] POST DSC={post_dsc:.4f} ASD={post_asd:.2f} HD={post_hd:.2f}"
        )


# F.interpolate needs "trilinear" for 5-D (3 spatial dims) input and
# "bilinear" for 4-D (2 spatial dims) input. The hard-coded "trilinear"
# failed whenever ndims == 2.
_ddf_interp_mode = "trilinear" if ndims == 3 else "bilinear"

with torch.no_grad():
    for pair_idx, pair in enumerate(tqdm(pairs, desc="Pairs")):
        fixed_rel = pair["fixed"]
        moving_rel = pair["moving"]
        fixed_path = resolve_path(fixed_rel)
        moving_path = resolve_path(moving_rel)
        fixed_name = get_volume_name(fixed_rel)
        moving_name = get_volume_name(moving_rel)
        pair_tag = f"Tgt{pair_idx:04d}_Src{pair_idx:04d}"
        src_tag = f"Src{pair_idx:04d}"
        pair_info.append((pair_idx, fixed_name, moving_name))
        print(f"\n [{pair_idx}] Fixed: {fixed_name}, Moving: {moving_name}")

        # --- Load volumes and (optional) label maps ---
        fixed_vol = load_volume(fixed_path)
        moving_vol = load_volume(moving_path)
        fixed_label_map = _load_label_if_available(fixed_rel)
        moving_label_map = _load_label_if_available(moving_rel)

        # --- Standardize images via OMorpher ---
        # om is stateful: each set_init_img() overwrites _init_img (used here
        # as the model-resolution copy) and _init_img_raw (used as the
        # full-resolution copy), so both are cloned immediately.
        om.set_init_img(moving_vol)
        src_img_model = om._init_img.clone()
        src_img_fullres = om._init_img_raw.clone()
        src_orig_sz = list(src_img_fullres.shape[2:])
        om.set_init_img(fixed_vol)
        tgt_img_model = om._init_img.clone()
        tgt_img_fullres = om._init_img_raw.clone()

        # --- Standardize per-class label masks (stacked as channels) ---
        src_mask_model = src_mask_fullres = None
        tgt_mask_model = tgt_mask_fullres = None
        if moving_label_map is not None:
            src_mask_model, src_mask_fullres = _standardize_class_masks(
                moving_vol, moving_label_map
            )
        if fixed_label_map is not None:
            tgt_mask_model, tgt_mask_fullres = _standardize_class_masks(
                fixed_vol, fixed_label_map
            )

        # --- Save originals (model resolution, then full resolution) ---
        _save_nifti(tgt_img_model, reg_img_savepath, f"{pair_tag}_TGT_ORG.nii.gz")
        if tgt_mask_model is not None:
            _save_nifti(tgt_mask_model, reg_msk_savepath, f"{pair_tag}_TGT_ORG_GT.nii.gz")
        _save_nifti(src_img_model, reg_img_savepath, f"{src_tag}_ORG.nii.gz")
        if src_mask_model is not None:
            _save_nifti(src_mask_model, reg_msk_savepath, f"{src_tag}_ORG_GT.nii.gz")
        _save_nifti(tgt_img_fullres, reg_img_savepath_fullres, f"{pair_tag}_TGT_ORG.nii.gz")
        if tgt_mask_fullres is not None:
            _save_nifti(tgt_mask_fullres, reg_msk_savepath_fullres, f"{pair_tag}_TGT_ORG_GT.nii.gz")
        _save_nifti(src_img_fullres, reg_img_savepath_fullres, f"{src_tag}_ORG.nii.gz")
        if src_mask_fullres is not None:
            _save_nifti(src_mask_fullres, reg_msk_savepath_fullres, f"{src_tag}_ORG_GT.nii.gz")

        # --- Register moving to fixed ---
        om.set_init_img(src_img_model)
        om.set_cond_img(tgt_img_model.clone().detach())
        om.predict(
            T=[None, timesteps],
            proc_type=condition_type,
        )
        ddf_comp = om.get_def()

        # --- DDF quality: percent negative Jacobian determinant ---
        neg_pct = compute_negdetj_pct(ddf_comp, ndims=ndims)
        negdetj_pct[pair_idx] = neg_pct
        print(f" %|J|<0 = {neg_pct:.4f}%")

        # --- Model-resolution outputs: image, mask, DDF ---
        img_rec = om.apply_def(
            img=src_img_model, ddf=ddf_comp, padding_mode="zeros",
        )
        _save_nifti(img_rec, reg_img_savepath, f"{pair_tag}.nii.gz")
        msk_rec = None
        if src_mask_model is not None:
            msk_rec = om.apply_def(
                img=src_mask_model, ddf=ddf_comp,
                padding_mode="zeros", resample_mode="nearest",
            )
            _save_nifti(msk_rec, reg_msk_savepath, f"{pair_tag}_GT.nii.gz")
        _save_nifti(ddf_comp, reg_ddf_savepath, f"{pair_tag}.nii.gz")

        # --- Full-resolution outputs: image, mask, DDF ---
        img_rec_fullres = om.apply_def(
            img=src_img_fullres, ddf=ddf_comp, padding_mode="border",
        )
        _save_nifti(img_rec_fullres, reg_img_savepath_fullres, f"{pair_tag}.nii.gz")
        msk_rec_fullres = None
        if src_mask_fullres is not None:
            msk_rec_fullres = om.apply_def(
                img=src_mask_fullres, ddf=ddf_comp,
                padding_mode="zeros", resample_mode="nearest",
            )
            _save_nifti(msk_rec_fullres, reg_msk_savepath_fullres, f"{pair_tag}_GT.nii.gz")
        # NOTE(review): displacement values are resampled but not rescaled.
        # This is only correct if the DDF is in resolution-independent (e.g.
        # normalized) units — TODO confirm OMorpher's convention.
        ddf_fullres = F.interpolate(
            ddf_comp, size=src_orig_sz, mode=_ddf_interp_mode, align_corners=False,
        )
        _save_nifti(ddf_fullres, reg_ddf_savepath_fullres, f"{pair_tag}.nii.gz")

        # --- Evaluation metrics (full-res organ labels) ---
        if (
            organ_label_ids
            and src_mask_fullres is not None
            and tgt_mask_fullres is not None
        ):
            _evaluate_pair(pair_idx, src_mask_fullres, tgt_mask_fullres, msk_rec_fullres)

print("\nPaired registration complete.")
# ========== Write evaluation CSVs ==========
n_pairs = len(pairs)
def _fmt(val):
if val is None:
return ""
if np.isnan(val):
return "NaN"
return f"{val:.6f}"
# --- Per-pair %|J|<0 CSV ---
negdetj_csv_path = os.path.join(eval_dir, "negdetj_pct.csv")
with open(negdetj_csv_path, "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["pair_idx", "fixed", "moving", "negdetj_pct"])
    writer.writerows(
        [idx, fixed, moving, _fmt(negdetj_pct.get(idx))]
        for idx, fixed, moving in pair_info
    )
print(f"Saved {negdetj_csv_path}")

# --- Per-label, per-metric CSVs (pre vs post registration) ---
# File names get a "<label>_" prefix only when there is more than one organ label.
_use_label_prefix = len(organ_label_ids) > 1
for cid, lk in sorted(organ_label_ids.items()):
    prefix = f"{lk}_" if _use_label_prefix else ""
    for metric_name in ("dsc", "asd", "hd"):
        column = metric_name.upper()
        csv_path = os.path.join(eval_dir, f"{prefix}{metric_name}.csv")
        pre_by_pair = metrics_pre[cid][metric_name]
        post_by_pair = metrics[cid][metric_name]
        with open(csv_path, "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(
                ["pair_idx", "fixed", "moving", f"pre_{column}", f"post_{column}"]
            )
            # Pairs without a recorded value produce empty cells via _fmt(None).
            writer.writerows(
                [
                    idx, fixed, moving,
                    _fmt(pre_by_pair.get(idx)), _fmt(post_by_pair.get(idx)),
                ]
                for idx, fixed, moving in pair_info
            )
        print(f"Saved {csv_path}")
# --- Overall summary ---


def _drop_nan(values):
    """Return *values* as a list with NaN entries removed."""
    return [v for v in values if not np.isnan(v)]


def _mean_std(vals):
    """Return ``(mean, std)`` of *vals*, or ``(nan, nan)`` if *vals* is empty."""
    if not vals:
        return float("nan"), float("nan")
    return np.mean(vals), np.std(vals)


overall_path = os.path.join(eval_dir, "overall.csv")
with open(overall_path, "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(
        ["label", "metric", "pre_mean", "pre_std", "post_mean", "post_std", "n_pairs"]
    )

    # %|J|<0 is a per-DDF quantity: no label and no pre-registration value.
    negdetj_vals = _drop_nan(negdetj_pct.values())
    negdetj_mean, negdetj_std = _mean_std(negdetj_vals)
    writer.writerow(
        ["ALL", "%|J|<0", "", "", _fmt(negdetj_mean), _fmt(negdetj_std), len(negdetj_vals)]
    )

    for cid, lk in sorted(organ_label_ids.items()):
        for metric_name in ("dsc", "asd", "hd"):
            pre_vals = _drop_nan(metrics_pre[cid][metric_name].values())
            post_vals = _drop_nan(metrics[cid][metric_name].values())
            pre_mean, pre_std = _mean_std(pre_vals)
            post_mean, post_std = _mean_std(post_vals)
            writer.writerow([
                lk,
                metric_name.upper(),
                _fmt(pre_mean), _fmt(pre_std),
                _fmt(post_mean), _fmt(post_std),
                max(len(pre_vals), len(post_vals)),
            ])
print(f"Saved {overall_path}")