"""
OM_train_3modes.py — Three-mode training (diffusion + registration + contrastive)
using OMorpher.

Drop-in replacement for the earlier procedural three-mode training script: it
uses the OMorpher object-oriented wrapper instead of procedural DeformDDPM
calls, while preserving the same training logic, DDP support, loss functions,
and checkpoint format.

DDP is enabled automatically when more than one CUDA device is visible (see
``use_distributed`` below); restrict ``CUDA_VISIBLE_DEVICES`` to a single GPU
to force single-process training.

Usage:
    # Single-GPU
    python Scripts/OM_train_3modes.py -C Config/config_om.yaml

    # Multi-GPU (DDP)
    CUDA_VISIBLE_DEVICES=0,1 python Scripts/OM_train_3modes.py -C Config/config_om.yaml

    # Dummy data for testing (no real dataset needed)
    python Scripts/OM_train_3modes.py -C Config/config_om.yaml --dummy-samples 20
"""
import os
import sys

# Add project root to path so imports work from Scripts/
# (must run before the project-local imports below).
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, ROOT_DIR)

import gc
import glob
import random

import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Adam
from torch.utils.data import DataLoader
from tqdm import tqdm
import argparse

from OMorpher import OMorpher
from Diffusion.networks import DefRec_MutAttnNet
from Diffusion.losses import Grad, LNCC, LMSE, MSLNCC
from Dataloader.dataLoader import OMDataset_indiv, OMDataset_pair
from Dataloader.dataloader_utils import thresh_img
import utils

# ========================== Constants ==========================
EPS = 1e-5
# Small offset added inside the sqrt(...) loss weights derived from cond_ratio,
# keeping the sqrt argument away from zero.
MSK_EPS = 0.01
# Probability of conditioning the network on the text embedding in a step.
TEXT_EMBED_PROB = 0.7
# Probability used by the random resample / intensity-threshold augmentations.
AUG_RESAMPLE_PROB = 0.6
LOSS_WEIGHTS_DIFF = [2.0, 1.0, 16]  # [ang, dist, reg]
LOSS_WEIGHTS_REGIST = [1.0, 0.05, 128]  # [imgsim, imgmse, ddf]
# The paired (registration) loader uses batchsize // DIFF_REG_BATCH_RATIO.
DIFF_REG_BATCH_RATIO = 2
LOSS_WEIGHT_CONTRASTIVE = 1.0
# The contrastive update runs once every CONTRASTIVE_STEP_RATIO steps.
CONTRASTIVE_STEP_RATIO = 2

# Auto-detect: use DDP only when multiple CUDA GPUs are available.
use_distributed = torch.cuda.is_available() and torch.cuda.device_count() > 1
# ========================== Arguments ==========================
parser = argparse.ArgumentParser()
parser.add_argument(
    "--config",
    "-C",
    help="Path for the config file",
    type=str,
    default="Config/config_all.yaml",
    required=False,
)
parser.add_argument(
    "--dummy-samples",
    type=int,
    default=0,
    help="Use dummy random data for testing (0=use real data)",
)
parser.add_argument(
    "--batchsize",
    type=int,
    default=0,
    help="Override batch size from config (0=use config value)",
)
args = parser.parse_args()


# ========================== Dummy Datasets ==========================
class _DummyIndiv(torch.utils.data.Dataset):
    """Random stand-in for ``OMDataset_indiv`` (used with ``--dummy-samples``).

    Each item is ``(image, text_embedding)``: a float64 array of shape
    ``(1, sz, sz, sz)`` with values in [0, 1) and a float32 standard-normal
    vector of length ``embd_dim``.
    """

    def __init__(self, n, sz, embd_dim=1024):
        self.n, self.sz, self.embd_dim = n, sz, embd_dim

    def __len__(self):
        return self.n

    def __getitem__(self, i):
        img = np.random.rand(1, self.sz, self.sz, self.sz).astype(np.float64)
        embd = np.random.randn(self.embd_dim).astype(np.float32)
        return img, embd


class _DummyPair(torch.utils.data.Dataset):
    """Random stand-in for ``OMDataset_pair`` (used with ``--dummy-samples``).

    Each item is ``(img_a, img_b, embd_a, embd_b)``: two float64 arrays of
    shape ``(1, sz, sz, sz)`` in [0, 1) and two float32 standard-normal
    vectors of length ``embd_dim``. The training loop unpacks them as
    ``(x1, y1, _, embd_y)``.
    """

    def __init__(self, n, sz, embd_dim=1024):
        self.n, self.sz, self.embd_dim = n, sz, embd_dim

    def __len__(self):
        return self.n

    def __getitem__(self, i):
        img_a = np.random.rand(1, self.sz, self.sz, self.sz).astype(np.float64)
        img_b = np.random.rand(1, self.sz, self.sz, self.sz).astype(np.float64)
        embd_a = np.random.randn(self.embd_dim).astype(np.float32)
        embd_b = np.random.randn(self.embd_dim).astype(np.float32)
        return img_a, img_b, embd_a, embd_b


# ========================== DDP Setup ==========================
def ddp_setup(rank, world_size):
    """Initialise the NCCL process group on localhost and pin this process's GPU.

    Args:
        rank: Unique identifier of each process (also used as CUDA device index).
        world_size: Total number of processes.
    """
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


# ========================== Helpers ==========================
def reverse_diffuse_train(network, om, img_org, cond_imgs, T, text=None):
    """Registration reverse diffusion with selective gradient control.

    Mirrors DeformDDPM.diff_recover() with T=[None, T_regist]. Only the last
    k=2 timesteps of the schedule run with gradients enabled (through the DDP
    wrapper, so their gradients are synchronised); all earlier timesteps run
    under ``torch.no_grad()`` on the underlying module.

    Args:
        network: DDP-wrapped (or raw) network module.
        om: OMorpher instance (provides STN instances and device info).
        img_org: Source image [B, 1, S, S, S].
        cond_imgs: Processed conditioning image [B, 1, S, S, S].
        T: [T_init, T_schedule]. T_init=None means no forward diffusion.
            T_schedule is a list of batched timestep lists from the training loop.
        text: Optional text embedding [B, 1024].

    Returns:
        (ddf_comp, img_rec): Composed DDF and recovered image.
    """
    B = img_org.shape[0]
    S = om.img_size

    # T[0] = None -> no forward diffusion, start from the original image.
    ddf_comp = torch.zeros(
        [B, om.ndims] + [S] * om.ndims,
        dtype=torch.float32,
        device=om.device,
    )
    img_rec = img_org.clone().detach()

    time_steps = T[1]
    k = 2
    # Only membership is tested below, so the last k entries suffice.
    trainable_iterations = time_steps[-k:]

    net_module = network.module if isinstance(network, DDP) else network

    for i in time_steps:
        t = torch.tensor(np.array([i])).to(om.device)
        if i in trainable_iterations:
            # Gradients enabled — call through DDP wrapper for gradient sync.
            pre_dvf = network(x=img_rec, y=cond_imgs, t=t, rec_num=2, text=text)
        else:
            # No gradients — call the underlying module directly.
            with torch.no_grad():
                pre_dvf = net_module(x=img_rec, y=cond_imgs, t=t, rec_num=2, text=text)
        # Compose the new increment with the accumulated DDF, then re-warp the
        # (detached) source image with the composed field.
        ddf_comp = om.stn_full(ddf_comp, pre_dvf) + pre_dvf
        img_rec = om.img_stn(img_org.clone().detach(), ddf_comp)

    return ddf_comp, img_rec


def _strip_prefix(key, prefix):
    """Return ``key`` without a leading ``prefix`` (unchanged if absent)."""
    return key[len(prefix):] if key.startswith(prefix) else key


def ddp_load_checkpoint(gpu_id, network, optimizer, model_file, use_dist=True, load_strict=False):
    """Load checkpoint with DDP-aware parameter broadcast.

    State-dict keys may carry a DDP ``module.`` and/or a DeformDDPM
    ``network.`` prefix (``save_checkpoint`` writes ``network.``). Those
    leading prefixes are removed and entries that do not match a key of the
    bare network are dropped before ``load_state_dict``. Under DDP the
    parameters of rank 0 are then broadcast to every rank.

    Args:
        gpu_id: Rank of this process (only rank 0 prints diagnostics).
        network: DDP-wrapped network when ``use_dist`` else the raw network.
        optimizer: Optimizer; its state is restored only when ``load_strict``.
        model_file: Checkpoint path; its basename must start with a 6-digit epoch.
        use_dist: Whether ``network`` is DDP-wrapped and a process group exists.
        load_strict: Passed to ``load_state_dict(strict=...)``; also enables
            restoring the optimizer state.

    Returns:
        int: Epoch to resume from (checkpoint epoch + 1).
    """
    if gpu_id == 0:
        utils.print_memory_usage("Before Loading Model")
    if torch.cuda.is_available():
        gc.collect()
        torch.cuda.empty_cache()

    checkpoint = torch.load(model_file, map_location='cpu')
    state_dict = checkpoint['model_state_dict']

    net = network.module if use_dist else network
    net_keys = set(net.state_dict().keys())

    filtered = {}
    for key, value in state_dict.items():
        # Strip only *leading* wrapper prefixes ("module.network.x" and
        # "network.module.x" both map to "x"). str.replace() would also mangle
        # submodule names that merely contain "module." (e.g. "attn_module.w").
        key = _strip_prefix(key, "module.")
        key = _strip_prefix(key, "network.")
        key = _strip_prefix(key, "module.")
        if key in net_keys:
            filtered[key] = value

    if gpu_id == 0 and len(filtered) < len(state_dict):
        print(f"Skipped {len(state_dict) - len(filtered)} checkpoint entries "
              f"that do not match the network.")

    net.load_state_dict(filtered, strict=load_strict)
    if load_strict:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    if gpu_id == 0:
        utils.print_memory_usage("After Loading Checkpoint on GPU")

    if use_dist:
        # Broadcast model weights from rank 0 to all other GPUs.
        dist.barrier()
        for param in network.parameters():
            dist.broadcast(param.data, src=0)
        dist.barrier()
        for param_group in optimizer.param_groups:
            for param in param_group['params']:
                if param.grad is not None:
                    dist.broadcast(param.grad, src=0)

    initial_epoch = int(os.path.basename(model_file).split('.')[0][:6]) + 1
    return initial_epoch


def save_checkpoint(network, optimizer, epoch, save_path, use_dist=True):
    """Save checkpoint with 'network.' key prefix for backward compatibility."""
    net = network.module if use_dist and isinstance(network, DDP) else network
    state_dict = {f"network.{k}": v for k, v in net.state_dict().items()}
    torch.save({
        'model_state_dict': state_dict,
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': epoch,
    }, save_path)


# ========================== Main Training ==========================
def main_train(rank=0, world_size=1, train_mode_ratio=1, thresh_imgsim=0.01):
    """Train the OMorpher network with three interleaved modes.

    Per step of each epoch:
      1. Diffusion: deform a single image with ``om._get_random_ddf`` and
         train the network's DVF prediction against the ground-truth DVF.
      2. Contrastive (every ``CONTRASTIVE_STEP_RATIO`` steps): maximise cosine
         similarity between the network's ``img_embd`` and the text embedding.
      3. Registration (every ``train_mode_ratio`` steps): reverse-diffuse a
         moving image towards a processed fixed image and penalise image
         dissimilarity plus DDF irregularity.

    Resumes from the last ``*.pth`` (sorted by name) in the model directory
    and saves a checkpoint every ``epoch_per_save`` epochs.

    Args:
        rank: Process rank; also the CUDA device index under DDP.
        world_size: Number of DDP processes.
        train_mode_ratio: Run registration (mode 3) every this many steps.
        thresh_imgsim: Intensity threshold on the fixed image used to build
            the ``label`` mask passed to the image-similarity loss.

    Raises:
        ValueError: When more than ``max_nan_steps`` non-finite losses occur
            within one epoch.
    """
    if use_distributed:
        ddp_setup(rank, world_size)
    if torch.distributed.is_initialized():
        print(f"World size: {torch.distributed.get_world_size()}")
        print(f"Communication backend: {torch.distributed.get_backend()}")

    gpu_id = rank
    device = f"cuda:{rank}" if use_distributed else None

    # Abort once more than this many non-finite losses occur in one epoch.
    max_nan_steps = 5
    # Number of distinct timesteps in each registration reverse diffusion.
    regist_num_steps = 16

    # ---- OMorpher initialisation (config, network, STN, losses, auto-checkpoint) ----
    om = OMorpher(config=args.config, device=device)
    config = om.config
    if args.batchsize > 0:
        config['batchsize'] = args.batchsize
    if gpu_id == 0:
        print(config)

    epoch_per_save = config['epoch_per_save']
    suffix_pth = f"_{config['data_name']}_{config['net_name']}.pth"
    model_dir = os.path.join(ROOT_DIR, 'Models', f"{config['data_name']}_{config['net_name']}/")

    # ---- Additional loss functions for the three training modes ----
    # Diffusion losses reused from OMorpher: om._loss_dist (MRSE), om._loss_ang (NCC)
    loss_reg = Grad(
        penalty=['l1', 'negdetj', 'range'], ndims=om.ndims,
        outrange_thresh=0.2, outrange_weight=1e3,
    )
    loss_reg1 = Grad(
        penalty=['l1', 'negdetj', 'range'], ndims=om.ndims,
        outrange_thresh=0.6, outrange_weight=1e3,
    )
    # NOTE(review): the previous code assigned MSLNCC() to loss_imgmse and
    # immediately overwrote it, leaving loss_imgsim undefined (NameError at
    # the first registration step). MSLNCC is assumed to be the intended
    # similarity loss — confirm it accepts the ``label=`` keyword.
    loss_imgsim = MSLNCC()
    loss_imgmse = LMSE()

    # ---- DDP wrapping ----
    if use_distributed:
        om.network.to(rank)
        om.stn_full.to(rank)
        om.stn_ctl.to(rank)
        om.img_stn.to(rank)
        om.msk_stn.to(rank)
        network = DDP(om.network, device_ids=[rank])
    else:
        om.network.to(om.device)
        network = om.network

    # ---- Optimizer ----
    optimizer = Adam(network.parameters(), lr=config["lr"])

    # ---- Data loaders ----
    if args.dummy_samples > 0:
        dataset = _DummyIndiv(args.dummy_samples, config['img_size'])
        datasetp = _DummyPair(args.dummy_samples, config['img_size'])
    else:
        dataset = OMDataset_indiv(transform=None)
        datasetp = OMDataset_pair(transform=None)

    train_loader = DataLoader(
        dataset, batch_size=config['batchsize'], shuffle=True, drop_last=True,
    )
    train_loader_p = DataLoader(
        datasetp,
        batch_size=max(1, config['batchsize'] // DIFF_REG_BATCH_RATIO),
        shuffle=True,
        drop_last=True,
    )

    # ---- Auto-resume from checkpoint ----
    os.makedirs(model_dir, exist_ok=True)
    model_files = sorted(glob.glob(os.path.join(model_dir, "*.pth")))
    if model_files:
        if gpu_id == 0:
            print(model_files)
        initial_epoch = ddp_load_checkpoint(
            gpu_id, network, optimizer, model_files[-1], use_distributed,
        )
    else:
        initial_epoch = 0

    if gpu_id == 0:
        print('len_train_data: ', len(dataset))

    is_defrec = isinstance(om.network, DefRec_MutAttnNet)
    # Unwrapped network, used by the contrastive pass to read ``img_embd``.
    raw_network = network.module if isinstance(network, DDP) else network

    # ---- Training loop ----
    for epoch in range(initial_epoch, config["epoch"]):
        epoch_loss_tot = 0.0
        epoch_loss_gen_d = 0.0
        epoch_loss_gen_a = 0.0
        epoch_loss_reg_ = 0.0
        epoch_loss_regist = 0.0
        epoch_loss_imgsim_ = 0.0
        epoch_loss_imgmse_ = 0.0
        epoch_loss_ddfreg = 0.0
        epoch_loss_contrastive = 0.0

        network.train()
        loss_nan_step = 0
        total = min(len(train_loader), len(train_loader_p))

        for step, (batch, batch_p) in tqdm(
            enumerate(zip(train_loader, train_loader_p)), total=total,
        ):
            # ==========================================================
            # Mode 1: Diffusion training on single image
            # ==========================================================
            x0, embd = batch
            x0 = x0.to(om.device).type(torch.float32)
            embd_dev = embd.to(om.device).type(torch.float32)
            if np.random.uniform(0, 1) < TEXT_EMBED_PROB:
                embd_in = embd_dev
            else:
                embd_in = None

            n = x0.size()[0]
            blind_mask = utils.get_random_deformed_mask(
                x0.shape[2:], apply_possibility=0.6,
            ).to(om.device)

            # Data augmentation
            if om.ndims > 2:
                if np.random.uniform(0, 1) < AUG_RESAMPLE_PROB:
                    x0 = utils.random_resample(x0, deform_scale=0)
                else:
                    [x0] = utils.random_permute([x0], select_dims=[-1, -2, -3])
            if config['noise_scale'] > 0:
                if np.random.uniform(0, 1) < AUG_RESAMPLE_PROB:
                    x0 = thresh_img(x0, [0, 1 * config['noise_scale']])
                x0 = x0 * (np.random.normal(1, config['noise_scale'] * 1)) + np.random.normal(0, config['noise_scale'] * 1)

            t = torch.randint(0, om.timesteps, (n,)).to(om.device)
            proc_type = random.choice(
                ['adding', 'downsample', 'slice', 'slice1', 'none', 'uncon', 'uncon', 'uncon'],
            )
            cond_img, _, cond_ratio = om._proc_cond_img(x0, proc_type=proc_type)

            # Forward diffusion + network prediction
            noisy_img, dvf_gt, _ = om._get_random_ddf(x0, t)
            # DefRec_MutAttnNet takes the timestep wrapped in a list.
            t_in = [t] if is_defrec else t
            pre_dvf_I = network(
                x=noisy_img * blind_mask, y=cond_img, t=t_in, rec_num=2, text=embd_in,
            )

            # Diffusion losses
            loss_ddf = loss_reg(pre_dvf_I, img=x0)
            trm_pred = om.stn_full(pre_dvf_I, dvf_gt)
            loss_gen_d = om._loss_dist(
                pred=trm_pred, inv_lab=dvf_gt, ddf_stn=None, mask=blind_mask,
            )
            loss_gen_a = om._loss_ang(
                pred=trm_pred, inv_lab=dvf_gt, ddf_stn=None, mask=blind_mask,
            )
            loss_tot = (
                LOSS_WEIGHTS_DIFF[0] * loss_gen_a
                + LOSS_WEIGHTS_DIFF[1] * loss_gen_d
                + LOSS_WEIGHTS_DIFF[2] * loss_ddf
            )
            loss_tot = torch.sqrt(1. + MSK_EPS - cond_ratio) * loss_tot

            # NaN / divergence checks
            if torch.isnan(x0).any():
                print(f"*** Encountered NaN in input image x0 at epoch {epoch}, step {step}.")
            if loss_ddf > 0.001:
                print(f"*** High diffusion DDF loss at epoch {epoch}, step {step}: {loss_ddf.item()}.")
            if torch.isnan(loss_tot) or torch.isinf(loss_tot):
                print(f"*** Encountered NaN or Inf loss at epoch {epoch}, step {step}. Skipping this batch.")
                loss_nan_step += 1
                # Checked here (not after `continue`) so the abort actually
                # triggers on the offending step.
                if loss_nan_step > max_nan_steps:
                    print(f"*** Too many NaN or Inf losses ({loss_nan_step} times) at epoch {epoch}, step {step}. Stopping training.")
                    raise ValueError("Too many NaN losses detected in loss_tot. Code terminated.")
                # NOTE(review): under DDP, skipping backward on only one rank
                # may desynchronise gradient all-reduce — verify multi-GPU.
                continue

            optimizer.zero_grad()
            loss_tot.backward()
            optimizer.step()

            epoch_loss_tot += loss_tot.item() / total
            epoch_loss_gen_d += loss_gen_d.item() / total
            epoch_loss_gen_a += loss_gen_a.item() / total
            epoch_loss_reg_ += loss_ddf.item() / total

            # ==========================================================
            # Mode 2: Contrastive training (text-image alignment)
            # ==========================================================
            loss_contra_val = None
            if step % CONTRASTIVE_STEP_RATIO == 0:
                # NOTE(review): this forward bypasses the DDP wrapper, so under
                # DDP its gradients are presumably not all-reduced across
                # ranks — verify.
                n_contra = x0.size()[0]
                t_contra = torch.randint(0, om.timesteps, (n_contra,)).to(om.device)
                _ = raw_network(
                    x=(x0 * blind_mask).detach(), y=cond_img.detach(),
                    t=t_contra, text=embd_dev.detach(),
                )
                if hasattr(raw_network, 'img_embd') and raw_network.img_embd is not None:
                    img_embd = raw_network.img_embd  # [B, 1024]
                    loss_contra = LOSS_WEIGHT_CONTRASTIVE * (
                        1 - F.cosine_similarity(img_embd, embd_dev, dim=-1).mean()
                    )
                    optimizer.zero_grad()
                    loss_contra.backward()
                    torch.nn.utils.clip_grad_norm_(network.parameters(), max_norm=0.05)
                    optimizer.step()
                    loss_contra_val = loss_contra.item()
                    epoch_loss_contrastive += loss_contra_val / total

            # ==========================================================
            # Mode 3: Registration training on paired images
            # ==========================================================
            # (loss_regist, loss_sim, loss_mse, loss_ddf1) of a completed
            # registration update in this step, for logging; None otherwise.
            regist_log = None
            if step % train_mode_ratio == 0:
                x1, y1, _, embd_y = batch_p
                if np.random.uniform(0, 1) < TEXT_EMBED_PROB:
                    embd_y = embd_y.to(om.device).type(torch.float32)
                else:
                    embd_y = None
                x1 = x1.to(om.device).type(torch.float32)
                y1 = y1.to(om.device).type(torch.float32)
                n = x1.size()[0]

                # Augmentation; the intensity scale/shift is shared by x1 and y1.
                [x1, y1] = utils.random_permute([x1, y1], select_dims=[-1, -2, -3])
                if config['noise_scale'] > 0:
                    [x1, y1] = thresh_img([x1, y1], [0, 2 * config['noise_scale']])
                    random_scale = np.random.normal(1, config['noise_scale'] * 1)
                    random_shift = np.random.normal(0, config['noise_scale'] * 1)
                    x1 = x1 * random_scale + random_shift
                    y1 = y1 * random_scale + random_shift

                # Timestep schedule for reverse diffusion: up to
                # regist_num_steps distinct timesteps from
                # [timesteps * scale_regist, timesteps), descending, each
                # repeated once per sample of the actual paired batch.
                scale_regist = np.random.uniform(0.0, 0.7)
                t_candidates = range(int(om.timesteps * scale_regist), om.timesteps)
                T_regist = sorted(
                    random.sample(t_candidates, min(regist_num_steps, len(t_candidates))),
                    reverse=True,
                )
                T_regist = [[t_val] * n for t_val in T_regist]

                proc_type = random.choice(['downsample', 'slice', 'slice1', 'none', 'none'])
                y1_proc, msk_tgt, cond_ratio = om._proc_cond_img(
                    y1, proc_type=proc_type,
                )

                # Reverse diffusion for registration (via OMorpher's network & STN)
                ddf_comp, img_rec = reverse_diffuse_train(
                    network, om, x1, y1_proc,
                    T=[None, T_regist],
                    text=embd_y,
                )

                # Registration losses
                loss_sim = loss_imgsim(img_rec, y1, label=(y1 > thresh_imgsim))
                loss_mse = loss_imgmse(img_rec, y1)
                loss_ddf1 = loss_reg1(ddf_comp, img=y1)
                loss_regist = (
                    LOSS_WEIGHTS_REGIST[0] * loss_sim
                    + LOSS_WEIGHTS_REGIST[1] * loss_mse
                    + LOSS_WEIGHTS_REGIST[2] * loss_ddf1
                )

                # NaN / divergence checks (on the registration inputs)
                if torch.isnan(x1).any() or torch.isnan(y1).any():
                    print(f"*** Encountered NaN in registration inputs x1/y1 at epoch {epoch}, step {step}.")
                if loss_ddf1 > 0.002:
                    print(f"*** High registration DDF loss at epoch {epoch}, step {step}: {loss_ddf1.item()}.")

                loss_regist = torch.sqrt(cond_ratio + MSK_EPS) * loss_regist

                if torch.isnan(loss_regist) or torch.isinf(loss_regist):
                    # Never step the optimizer on a non-finite loss.
                    print(f"*** Encountered NaN or Inf registration loss at epoch {epoch}, step {step}. Skipping registration update.")
                    loss_nan_step += 1
                    if loss_nan_step > max_nan_steps:
                        print(f"*** Too many NaN or Inf losses ({loss_nan_step} times) at epoch {epoch}, step {step}. Stopping training.")
                        raise ValueError("Too many NaN losses detected in loss_regist. Code terminated.")
                else:
                    optimizer.zero_grad()
                    loss_regist.backward()
                    torch.nn.utils.clip_grad_norm_(network.parameters(), max_norm=0.4)
                    optimizer.step()

                    epoch_loss_regist += loss_regist.item() / total
                    epoch_loss_imgsim_ += loss_sim.item() / total
                    epoch_loss_imgmse_ += loss_mse.item() / total
                    epoch_loss_ddfreg += loss_ddf1.item() / total
                    regist_log = (loss_regist.item(), loss_sim.item(), loss_mse.item(), loss_ddf1.item())

            if step % 10 == 0:
                print('step:', step, ':', loss_tot.item(), '=', loss_gen_a.item(), '+', loss_gen_d.item(), '+', loss_ddf.item())
                if loss_contra_val is not None:
                    print(f' loss_contrastive: {loss_contra_val:.6f}')
                if regist_log is not None:
                    print(f' loss_regist: {regist_log[0]} = {regist_log[1]} (imgsim) + {regist_log[2]} (imgmse) + {regist_log[3]} (ddf)')

        print('==================')
        print(epoch, ':', epoch_loss_tot, '=', epoch_loss_gen_a, '+', epoch_loss_gen_d, '+', epoch_loss_reg_, ' (ang+dist+regul)')
        print(f' loss_contrastive: {epoch_loss_contrastive}')
        print(f' loss_regist: {epoch_loss_regist} = {epoch_loss_imgsim_} (imgsim) + {epoch_loss_imgmse_} (imgmse) + {epoch_loss_ddfreg} (ddf)')
        print('==================')

        if 0 == epoch % epoch_per_save:
            save_path = os.path.join(model_dir, str(epoch).rjust(6, '0') + suffix_pth)
            os.makedirs(model_dir, exist_ok=True)
            # Only one process writes the checkpoint (rank 0 under DDP).
            if not use_distributed or gpu_id == 0:
                print(f"saved in {save_path}")
                save_checkpoint(network, optimizer, epoch, save_path, use_dist=use_distributed)

        # Resource cleanup
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

    if use_distributed and dist.is_initialized():
        dist.destroy_process_group()


if __name__ == "__main__":
    if use_distributed:
        world_size = torch.cuda.device_count()
        print(f"Distributed GPU number = {world_size}")
        mp.spawn(main_train, args=(world_size,), nprocs=world_size)
    else:
        main_train(0, 1)