#!/bin/bash -l
#SBATCH --job-name=om-single
#SBATCH --account=AIRR-P51-DAWN-GPU
#SBATCH --partition=pvc9
#SBATCH --nodes=1
#SBATCH --gres=gpu:4
#SBATCH --ntasks-per-node=8
#SBATCH --cpus-per-task=12
#SBATCH --time=36:00:00
#SBATCH --output=Logs/train_single_%j.out
#SBATCH --error=Logs/train_single_%j.err
#SBATCH --exclude=pvc-s-135,pvc-s-[194-256]

# 1 node x 4 XPU cards x 2 tiles/card = 8 XPU tiles
# Same settings as the multi-node script, but on a single node so training can start sooner.
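# With --ntasks-per-node=8, each srun task is expected to drive one XPU tile (local ranks 0-7).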

cd /rds/project/rds-TWhPgQVLKbA/Code/OmniMorph

. /etc/profile.d/modules.sh
module purge
module load rhel9/default-dawn

source ~/miniconda3/etc/profile.d/conda.sh
conda activate ~/rds/rds-airr-p51-TWhPgQVLKbA/Env/pub_env/pytorch-xpu

# --- CCL/MPI setup ---
# Point Intel MPI at Slurm's PMI2 library and let it bootstrap ranks through Slurm.
export I_MPI_PMI_LIBRARY=/usr/local/software/slurm/current-rhel8/lib/libpmi2.so
export I_MPI_HYDRA_BOOTSTRAP=slurm
# Let oneCCL choose worker-thread affinity; raise the cap on cached Level Zero
# IPC handles so they are not closed and reopened during collectives.
export CCL_WORKER_AFFINITY=auto
export CCL_ZE_CACHE_OPEN_IPC_HANDLES_THRESHOLD=100000

# --- XPU memory allocator ---
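# Allocator options (same names as the CUDA caching allocator): expandable_segments
# lets segments grow instead of fragmenting on variable-size allocations, and
# max_split_size_mb:512 keeps blocks above 512 MB from being split.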
export PYTORCH_ALLOC_CONF=expandable_segments:True,max_split_size_mb:512

# --- Single-node setup ---
export MASTER_ADDR=$(scontrol show hostname "$SLURM_NODELIST" | head -n1)
export MASTER_PORT=12355
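# The fixed port assumes no other job on this node is already using 12355; pick another free port if that ever clashes.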

SRUN_TIMEOUT=43200  # 12h timeout per srun — epoch takes ~10h with 8 tiles
MAX_ITERATIONS=50   # Safety net: max crash-recovery attempts within the 36h walltime
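# 36 h of walltime over the 12 h per-srun timeout allows at most ~3 full-length attempts;
# the 50-iteration cap mainly guards against rapid crash loops.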

echo "============================================"
echo "Job ID:        $SLURM_JOB_ID"
echo "Nodes:         $SLURM_NNODES"
echo "Tasks/node:    $SLURM_NTASKS_PER_NODE"
echo "Total tasks:   $SLURM_NTASKS"
echo "CPUs/task:     $SLURM_CPUS_PER_TASK"
echo "Master:        $MASTER_ADDR:$MASTER_PORT"
echo "Node list:     $SLURM_NODELIST"
echo "Walltime:      36:00:00"
echo "Timeout/srun:  ${SRUN_TIMEOUT}s"
echo "Max iters:     $MAX_ITERATIONS"
echo "============================================"

for ITER in $(seq 1 $MAX_ITERATIONS); do
    echo "=== Restart iteration $ITER at $(date) ==="

    timeout $SRUN_TIMEOUT srun --kill-on-bad-exit=1 bash -c '
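# Map Slurm per-task variables onto the env vars torch.distributed env:// initialization expects.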
export LOCAL_RANK=$SLURM_LOCALID
export RANK=$SLURM_PROCID
export WORLD_SIZE=$SLURM_NTASKS
export MASTER_ADDR='"$MASTER_ADDR"'
export MASTER_PORT='"$MASTER_PORT"'
python OM_train_3modes.py -C Config/config_om.yaml --batchsize 2 --max-steps-before-restart 0
'
    EXIT_CODE=$?
    echo "=== srun exit code: $EXIT_CODE at $(date) ==="

    if [ $EXIT_CODE -eq 0 ]; then
        echo "=== Training completed successfully ==="
        break
    elif [ $EXIT_CODE -eq 42 ]; then
        echo "=== Proactive restart (code 42), resuming in 5s ==="
        sleep 5
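    # GNU timeout exits with 124 when it has to kill the command, i.e. srun hung past SRUN_TIMEOUT.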
    elif [ $EXIT_CODE -eq 124 ]; then
        echo "=== Timeout (CCL hang), resuming in 10s ==="
        sleep 10
    else
        echo "=== Crash (exit code $EXIT_CODE), resuming in 10s ==="
        sleep 10
    fi
done

echo "=== Job finished at $(date), last exit code: $EXIT_CODE ==="