prediff_code / scripts /train_diffusion /prediff_lightning_module.py
weatherforecast1024's picture
Upload folder using huggingface_hub
7667a87 verified
Raw
History Blame Contribute Delete
51.6 kB
from omegaconf import OmegaConf
import os
from shutil import copyfile
import warnings
from typing import Dict,Sequence,Union
import inspect
import numpy as np
import torch
from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR, SequentialLR
import torchmetrics
from lightning.pytorch import Trainer, loggers as pl_loggers
from lightning.pytorch.profilers import PyTorchProfiler
from lightning.pytorch.strategies import DDPStrategy
from lightning.pytorch.callbacks import (
Callback, LearningRateMonitor, DeviceStatsMonitor,
EarlyStopping, ModelCheckpoint
)
from lightning.pytorch.utilities import grad_norm
from einops import rearrange
from models.vae import AutoencoderKL
from models.knowledge_alignment import SEVIRAvgIntensityAlignment,get_alignment_kwargs_avg_x
from models.diffusion import LatentDiffusion
from models.core_model.cuboid_transformer import CuboidTransformerUNet
from datamodule import SEVIRLightningDataModule,vis_sevir_seq
from utils.path import (
default_exps_dir,
default_pretrained_vae_dir,default_pretrained_alignment_dir
)
from utils.optim import disable_train,warmup_lambda
from utils.layout import step_layout_to_in_out_slice
from evaluation import FrechetVideoDistance,SEVIRSkillScore
class PreDiffSEVIRPLModule(LatentDiffusion):
def __init__(self,
             total_num_steps: int,
             oc_file: str = None,
             save_dir: str = None):
    r"""
    Build the latent model, the first-stage VAE, the optional knowledge
    alignment model and the evaluation metrics from an OmegaConf config.

    Parameters
    ----------
    total_num_steps: int
        Total number of optimization steps. Read by `configure_optimizers`
        (warmup / cosine schedule lengths) and `set_trainer_kwargs`
        (logging frequency).
    oc_file: str
        Path to a YAML config merged on top of the defaults returned by
        `get_base_config`. If given, it is also copied into the save dir.
    save_dir: str
        Experiment directory name; `configure_save` joins it with
        `default_exps_dir`, so `None` is not usable there.
    """
    self.total_num_steps = total_num_steps
    if oc_file is not None:
        # Close the handle deterministically instead of leaking it.
        with open(oc_file, "r") as cfg_stream:
            oc_from_file = OmegaConf.load(cfg_stream)
    else:
        oc_from_file = None
    oc = self.get_base_config(oc_from_file=oc_from_file)
    self.save_hyperparameters(oc)
    self.oc = oc

    # ---- latent denoising network ----
    latent_model_cfg = OmegaConf.to_object(oc.model.latent_model)
    num_blocks = len(latent_model_cfg["depth"])
    self_pattern = latent_model_cfg["self_pattern"]
    if isinstance(self_pattern, str):
        block_attn_patterns = [self_pattern] * num_blocks
    else:
        # `to_object` already returned plain containers; calling
        # `OmegaConf.to_container` on a plain list would raise ValueError.
        block_attn_patterns = list(self_pattern)
    latent_model = CuboidTransformerUNet(
        input_shape=latent_model_cfg["input_shape"],
        target_shape=latent_model_cfg["target_shape"],
        base_units=latent_model_cfg["base_units"],
        scale_alpha=latent_model_cfg["scale_alpha"],
        num_heads=latent_model_cfg["num_heads"],
        attn_drop=latent_model_cfg["attn_drop"],
        proj_drop=latent_model_cfg["proj_drop"],
        ffn_drop=latent_model_cfg["ffn_drop"],
        # inter-attn downsample/upsample
        downsample=latent_model_cfg["downsample"],
        downsample_type=latent_model_cfg["downsample_type"],
        upsample_type=latent_model_cfg["upsample_type"],
        upsample_kernel_size=latent_model_cfg["upsample_kernel_size"],
        # attention
        depth=latent_model_cfg["depth"],
        block_attn_patterns=block_attn_patterns,
        # global vectors
        num_global_vectors=latent_model_cfg["num_global_vectors"],
        use_global_vector_ffn=latent_model_cfg["use_global_vector_ffn"],
        use_global_self_attn=latent_model_cfg["use_global_self_attn"],
        separate_global_qkv=latent_model_cfg["separate_global_qkv"],
        global_dim_ratio=latent_model_cfg["global_dim_ratio"],
        # misc
        ffn_activation=latent_model_cfg["ffn_activation"],
        gated_ffn=latent_model_cfg["gated_ffn"],
        norm_layer=latent_model_cfg["norm_layer"],
        padding_type=latent_model_cfg["padding_type"],
        checkpoint_level=latent_model_cfg["checkpoint_level"],
        pos_embed_type=latent_model_cfg["pos_embed_type"],
        use_relative_pos=latent_model_cfg["use_relative_pos"],
        self_attn_use_final_proj=latent_model_cfg["self_attn_use_final_proj"],
        # initialization
        attn_linear_init_mode=latent_model_cfg["attn_linear_init_mode"],
        ffn_linear_init_mode=latent_model_cfg["ffn_linear_init_mode"],
        ffn2_linear_init_mode=latent_model_cfg["ffn2_linear_init_mode"],
        attn_proj_linear_init_mode=latent_model_cfg["attn_proj_linear_init_mode"],
        conv_init_mode=latent_model_cfg["conv_init_mode"],
        # a single config entry drives both the down- and up-sampling linears
        down_linear_init_mode=latent_model_cfg["down_up_linear_init_mode"],
        up_linear_init_mode=latent_model_cfg["down_up_linear_init_mode"],
        global_proj_linear_init_mode=latent_model_cfg["global_proj_linear_init_mode"],
        norm_init_mode=latent_model_cfg["norm_init_mode"],
        # timestep embedding for diffusion
        time_embed_channels_mult=latent_model_cfg["time_embed_channels_mult"],
        time_embed_use_scale_shift_norm=latent_model_cfg["time_embed_use_scale_shift_norm"],
        time_embed_dropout=latent_model_cfg["time_embed_dropout"],
        unet_res_connect=latent_model_cfg["unet_res_connect"]
    )

    # ---- first-stage autoencoder ----
    vae_cfg = OmegaConf.to_object(oc.model.vae)
    first_stage_model = AutoencoderKL(
        down_block_types=vae_cfg["down_block_types"],
        in_channels=vae_cfg["in_channels"],
        block_out_channels=vae_cfg["block_out_channels"],
        act_fn=vae_cfg["act_fn"],
        latent_channels=vae_cfg["latent_channels"],
        up_block_types=vae_cfg["up_block_types"],
        norm_num_groups=vae_cfg["norm_num_groups"],
        layers_per_block=vae_cfg["layers_per_block"],
        out_channels=vae_cfg["out_channels"], )
    # `.get` so that a config without this key takes the "not set" branch
    # below instead of failing with KeyError.
    pretrained_ckpt_path = vae_cfg.get("pretrained_ckpt_path")
    if pretrained_ckpt_path is not None:
        state_dict = torch.load(os.path.join(default_pretrained_vae_dir, pretrained_ckpt_path),
                                map_location=torch.device("cpu"))
        first_stage_model.load_state_dict(state_dict=state_dict)
    else:
        warnings.warn("Pretrained weights for `AutoencoderKL` not set. Run for sanity check only.")

    # ---- diffusion wrapper (parent class) ----
    diffusion_cfg = OmegaConf.to_object(oc.model.diffusion)
    super(PreDiffSEVIRPLModule, self).__init__(
        torch_nn_module=latent_model,
        layout=oc.layout.layout,
        data_shape=diffusion_cfg["data_shape"],
        timesteps=diffusion_cfg["timesteps"],
        beta_schedule=diffusion_cfg["beta_schedule"],
        loss_type=self.oc.optim.loss_type,
        monitor=self.oc.optim.monitor,
        use_ema=diffusion_cfg["use_ema"],
        log_every_t=diffusion_cfg["log_every_t"],
        clip_denoised=diffusion_cfg["clip_denoised"],
        linear_start=diffusion_cfg["linear_start"],
        linear_end=diffusion_cfg["linear_end"],
        cosine_s=diffusion_cfg["cosine_s"],
        given_betas=diffusion_cfg["given_betas"],
        original_elbo_weight=diffusion_cfg["original_elbo_weight"],
        v_posterior=diffusion_cfg["v_posterior"],
        l_simple_weight=diffusion_cfg["l_simple_weight"],
        parameterization=diffusion_cfg["parameterization"],
        learn_logvar=diffusion_cfg["learn_logvar"],
        logvar_init=diffusion_cfg["logvar_init"],
        # latent diffusion
        latent_shape=diffusion_cfg["latent_shape"],
        first_stage_model=first_stage_model,
        cond_stage_model=diffusion_cfg["cond_stage_model"],
        num_timesteps_cond=diffusion_cfg["num_timesteps_cond"],
        cond_stage_trainable=diffusion_cfg["cond_stage_trainable"],
        cond_stage_forward=diffusion_cfg["cond_stage_forward"],
        scale_by_std=diffusion_cfg["scale_by_std"],
        scale_factor=diffusion_cfg["scale_factor"], )

    # ---- knowledge alignment ----
    knowledge_alignment_cfg = OmegaConf.to_object(oc.model.align)
    self.alignment_type = knowledge_alignment_cfg["alignment_type"]
    self.use_alignment = self.alignment_type is not None
    if self.use_alignment:
        alignment_ckpt_path = os.path.join(default_pretrained_alignment_dir,
                                           knowledge_alignment_cfg["model_ckpt_path"])
        self.alignment_obj = SEVIRAvgIntensityAlignment(
            alignment_type=knowledge_alignment_cfg["alignment_type"],
            guide_scale=knowledge_alignment_cfg["guide_scale"],
            model_type=knowledge_alignment_cfg["model_type"],
            model_args=knowledge_alignment_cfg["model_args"],
            model_ckpt_path=alignment_ckpt_path
        )
        disable_train(self.alignment_obj.model)
        self.alignment_model = self.alignment_obj.model
        alignment_fn = self.alignment_obj.get_mean_shift
    else:
        alignment_fn = None
    self.set_alignment(alignment_fn=alignment_fn)

    # lr_scheduler (re-assigned after the parent constructor has run)
    self.total_num_steps = total_num_steps
    # logging
    self.save_dir = save_dir
    self.logging_prefix = oc.logging.logging_prefix
    # visualization
    self.train_example_data_idx_list = list(oc.eval.train_example_data_idx_list)
    self.val_example_data_idx_list = list(oc.eval.val_example_data_idx_list)
    self.test_example_data_idx_list = list(oc.eval.test_example_data_idx_list)
    self.eval_example_only = oc.eval.eval_example_only

    # ---- metrics ----
    def make_skill_score():
        # One fresh SEVIRSkillScore per split; all share the dataset settings.
        return SEVIRSkillScore(
            mode=self.oc.dataset.metrics_mode,
            seq_len=self.oc.layout.out_len,
            layout=self.layout,
            threshold_list=self.oc.dataset.threshold_list,
            metrics_list=self.oc.dataset.metrics_list,
            eps=1e-4, )

    def make_fvd():
        # One fresh FrechetVideoDistance per evaluated variant.
        return FrechetVideoDistance(
            feature=self.oc.eval.fvd_features,
            layout=self.layout,
            reset_real_features=False,
            normalize=False,
            auto_t=True, )

    if self.oc.eval.eval_unaligned:
        self.valid_mse = torchmetrics.MeanSquaredError()
        self.valid_mae = torchmetrics.MeanAbsoluteError()
        self.valid_score = make_skill_score()
        self.test_mse = torchmetrics.MeanSquaredError()
        self.test_mae = torchmetrics.MeanAbsoluteError()
        self.test_ssim = torchmetrics.image.StructuralSimilarityIndexMeasure()
        self.test_score = make_skill_score()
        self.test_fvd = make_fvd()
    if self.oc.eval.eval_aligned:
        self.valid_aligned_mse = torchmetrics.MeanSquaredError()
        self.valid_aligned_mae = torchmetrics.MeanAbsoluteError()
        self.valid_aligned_score = make_skill_score()
        self.test_aligned_mse = torchmetrics.MeanSquaredError()
        self.test_aligned_mae = torchmetrics.MeanAbsoluteError()
        self.test_aligned_ssim = torchmetrics.image.StructuralSimilarityIndexMeasure()
        self.test_aligned_score = make_skill_score()
        self.test_aligned_fvd = make_fvd()
    self.configure_save(cfg_file_path=oc_file)
def configure_save(self, cfg_file_path=None):
    r"""
    Resolve the experiment directory under `default_exps_dir`, create it
    together with its `examples` and `npy` sub-directories, and copy the
    config file there as `cfg.yaml`.

    Parameters
    ----------
    cfg_file_path: str
        Optional path of the config file to copy. The copy is skipped only
        when `cfg.yaml` already exists and is the very same file.
    """
    root_dir = os.path.join(default_exps_dir, self.save_dir)
    self.save_dir = root_dir
    os.makedirs(root_dir, exist_ok=True)

    if cfg_file_path is not None:
        target_path = os.path.join(root_dir, "cfg.yaml")
        already_in_place = os.path.exists(target_path) and \
            os.path.samefile(cfg_file_path, target_path)
        if not already_in_place:
            copyfile(cfg_file_path, target_path)

    self.example_save_dir = os.path.join(root_dir, "examples")
    os.makedirs(self.example_save_dir, exist_ok=True)
    self.npy_save_dir = os.path.join(root_dir, "npy")
    os.makedirs(self.npy_save_dir, exist_ok=True)
# region Get Default Config
def get_base_config(self, oc_from_file=None):
    r"""
    Assemble the full default config from the per-section getters and
    merge `oc_from_file` on top of it when one is given.

    Parameters
    ----------
    oc_from_file
        Config loaded from file, or `None` to keep the pure defaults.

    Returns
    -------
    oc
        The merged OmegaConf config.
    """
    section_getters = (
        ("layout", self.get_layout_config),
        ("optim", self.get_optim_config),
        ("logging", self.get_logging_config),
        ("trainer", self.get_trainer_config),
        ("eval", self.get_eval_config),
        ("model", self.get_model_config),
        ("dataset", self.get_dataset_config),
    )
    oc = OmegaConf.create()
    for section_name, getter in section_getters:
        oc[section_name] = getter()
    if oc_from_file is None:
        return oc
    return OmegaConf.merge(oc, oc_from_file)
@staticmethod
def get_layout_config():
    r"""
    Default layout section: input/output sequence lengths and steps,
    frame size, channel count and the axis layout string.
    """
    return OmegaConf.create({
        "in_len": 7,
        "out_len": 6,
        "in_step": 1,
        "out_step": 1,
        "in_out_diff": 1,
        "img_height": 128,
        "img_width": 128,
        "data_channels": 4,
        "layout": "NTHWC",
    })
@staticmethod
def get_model_config():
    r"""
    Default model section with four sub-sections: `diffusion` (diffusion
    process and latent-diffusion wrapper), `align` (knowledge alignment
    model), `latent_model` (denoising network) and `vae` (first-stage
    autoencoder).

    `vae.pretrained_ckpt_path` defaults to `None` so that the constructor,
    which reads this key, works when no config file overrides it.
    """
    cfg = OmegaConf.create()
    layout_cfg = PreDiffSEVIRPLModule.get_layout_config()

    cfg.diffusion = OmegaConf.create()
    cfg.diffusion.data_shape = (layout_cfg.out_len,
                                layout_cfg.img_height,
                                layout_cfg.img_width,
                                layout_cfg.data_channels)
    cfg.diffusion.timesteps = 1000
    cfg.diffusion.beta_schedule = "linear"
    cfg.diffusion.use_ema = True
    cfg.diffusion.log_every_t = 100  # log every `log_every_t` timesteps. Must be smaller than `timesteps`.
    cfg.diffusion.clip_denoised = False
    cfg.diffusion.linear_start = 1e-4
    cfg.diffusion.linear_end = 2e-2
    cfg.diffusion.cosine_s = 8e-3
    cfg.diffusion.given_betas = None
    cfg.diffusion.original_elbo_weight = 0.
    cfg.diffusion.v_posterior = 0.
    cfg.diffusion.l_simple_weight = 1.
    cfg.diffusion.parameterization = "eps"
    cfg.diffusion.learn_logvar = None
    cfg.diffusion.logvar_init = 0.
    # latent diffusion
    cfg.diffusion.latent_shape = [10, 16, 16, 4]
    cfg.diffusion.cond_stage_model = "__is_first_stage__"
    cfg.diffusion.num_timesteps_cond = None
    cfg.diffusion.cond_stage_trainable = False
    cfg.diffusion.cond_stage_forward = None
    cfg.diffusion.scale_by_std = False
    cfg.diffusion.scale_factor = 1.0
    cfg.diffusion.latent_cond_shape = [10, 16, 16, 4]

    # knowledge alignment
    cfg.align = OmegaConf.create()
    cfg.align.alignment_type = None
    cfg.align.guide_scale = 1.0
    cfg.align.model_type = "cuboid"
    cfg.align.model_ckpt_path = "tmp.pt"
    cfg.align.model_args = OmegaConf.create()
    # Earthformer
    cfg.align.model_args.input_shape = [6, 16, 16, 4]
    cfg.align.model_args.out_channels = 2
    cfg.align.model_args.base_units = 16
    cfg.align.model_args.block_units = None
    cfg.align.model_args.scale_alpha = 1.0
    cfg.align.model_args.depth = [1, 1]
    cfg.align.model_args.downsample = 2
    cfg.align.model_args.downsample_type = "patch_merge"
    cfg.align.model_args.block_attn_patterns = "axial"
    cfg.align.model_args.num_heads = 4
    cfg.align.model_args.attn_drop = 0.0
    cfg.align.model_args.proj_drop = 0.0
    cfg.align.model_args.ffn_drop = 0.0
    cfg.align.model_args.ffn_activation = "gelu"
    cfg.align.model_args.gated_ffn = False
    cfg.align.model_args.norm_layer = "layer_norm"
    cfg.align.model_args.use_inter_ffn = True
    cfg.align.model_args.hierarchical_pos_embed = False
    cfg.align.model_args.pos_embed_type = 't+h+w'
    # NOTE(review): "zero" here vs "zeros" for `latent_model.padding_type`
    # below — confirm both spellings are accepted by the respective models.
    cfg.align.model_args.padding_type = "zero"
    cfg.align.model_args.checkpoint_level = 0
    cfg.align.model_args.use_relative_pos = True
    cfg.align.model_args.self_attn_use_final_proj = True
    # global vectors
    cfg.align.model_args.num_global_vectors = 0
    cfg.align.model_args.use_global_vector_ffn = True
    cfg.align.model_args.use_global_self_attn = False
    cfg.align.model_args.separate_global_qkv = False
    cfg.align.model_args.global_dim_ratio = 1
    # initialization
    cfg.align.model_args.attn_linear_init_mode = "0"
    cfg.align.model_args.ffn_linear_init_mode = "0"
    cfg.align.model_args.ffn2_linear_init_mode = "2"
    cfg.align.model_args.attn_proj_linear_init_mode = "2"
    cfg.align.model_args.conv_init_mode = "0"
    cfg.align.model_args.down_linear_init_mode = "0"
    cfg.align.model_args.global_proj_linear_init_mode = "2"
    cfg.align.model_args.norm_init_mode = "0"
    # timestep embedding for diffusion
    cfg.align.model_args.time_embed_channels_mult = 4
    cfg.align.model_args.time_embed_use_scale_shift_norm = False
    cfg.align.model_args.time_embed_dropout = 0.0
    # readout
    cfg.align.model_args.pool = "attention"
    cfg.align.model_args.readout_seq = True
    cfg.align.model_args.out_len = 6

    cfg.latent_model = OmegaConf.create()
    cfg.latent_model.input_shape = [10, 16, 16, 4]
    cfg.latent_model.target_shape = [10, 16, 16, 4]
    cfg.latent_model.base_units = 4
    # block_units = null
    cfg.latent_model.scale_alpha = 1.0
    cfg.latent_model.num_heads = 4
    cfg.latent_model.attn_drop = 0.1
    cfg.latent_model.proj_drop = 0.1
    cfg.latent_model.ffn_drop = 0.1
    # inter-attn downsample/upsample
    cfg.latent_model.downsample = 2
    cfg.latent_model.downsample_type = "patch_merge"
    cfg.latent_model.upsample_type = "upsample"
    cfg.latent_model.upsample_kernel_size = 3
    # cuboid attention
    cfg.latent_model.depth = [1, 1]
    cfg.latent_model.self_pattern = "axial"
    # global vectors
    cfg.latent_model.num_global_vectors = 0
    cfg.latent_model.use_dec_self_global = False
    cfg.latent_model.dec_self_update_global = True
    cfg.latent_model.use_dec_cross_global = False
    cfg.latent_model.use_global_vector_ffn = False
    cfg.latent_model.use_global_self_attn = True
    cfg.latent_model.separate_global_qkv = True
    cfg.latent_model.global_dim_ratio = 1
    # misc
    cfg.latent_model.ffn_activation = "gelu"
    cfg.latent_model.gated_ffn = False
    cfg.latent_model.norm_layer = "layer_norm"
    cfg.latent_model.padding_type = "zeros"
    cfg.latent_model.pos_embed_type = "t+h+w"
    cfg.latent_model.checkpoint_level = 0
    cfg.latent_model.use_relative_pos = True
    cfg.latent_model.self_attn_use_final_proj = True
    # initialization
    cfg.latent_model.attn_linear_init_mode = "0"
    cfg.latent_model.ffn_linear_init_mode = "0"
    cfg.latent_model.ffn2_linear_init_mode = "2"
    cfg.latent_model.attn_proj_linear_init_mode = "2"
    cfg.latent_model.conv_init_mode = "0"
    cfg.latent_model.down_up_linear_init_mode = "0"
    cfg.latent_model.global_proj_linear_init_mode = "2"
    cfg.latent_model.norm_init_mode = "0"
    # timestep embedding for diffusion
    cfg.latent_model.time_embed_channels_mult = 4
    cfg.latent_model.time_embed_use_scale_shift_norm = False
    cfg.latent_model.time_embed_dropout = 0.0
    cfg.latent_model.unet_res_connect = True

    cfg.vae = OmegaConf.create()
    cfg.vae.data_channels = layout_cfg.data_channels
    # from stable-diffusion-v1-5
    cfg.vae.down_block_types = ['DownEncoderBlock2D', 'DownEncoderBlock2D', 'DownEncoderBlock2D', 'DownEncoderBlock2D']
    cfg.vae.in_channels = cfg.vae.data_channels
    cfg.vae.block_out_channels = [128, 256, 512, 512]
    cfg.vae.act_fn = 'silu'
    cfg.vae.latent_channels = 4
    cfg.vae.up_block_types = ['UpDecoderBlock2D', 'UpDecoderBlock2D', 'UpDecoderBlock2D', 'UpDecoderBlock2D']
    cfg.vae.norm_num_groups = 32
    cfg.vae.layers_per_block = 2
    cfg.vae.out_channels = cfg.vae.data_channels
    # `None` means "no pretrained VAE weights"; the constructor then only warns.
    cfg.vae.pretrained_ckpt_path = None
    return cfg
@staticmethod
def get_dataset_config():
    r"""
    Default dataset section: dataset name, frame size, sequence lengths,
    sampling settings, date splits and the metric settings used by the
    skill-score metrics.

    NOTE(review): `get_sevir_datamodule` reads `val_ratio`, which has no
    default here, and does not forward `train_val_split_date` — confirm
    which of the two the datamodule expects.
    """
    out_len = 6
    return OmegaConf.create({
        "dataset_name": "sevir_lr",
        "img_height": 128,
        "img_width": 128,
        "in_len": 7,
        "out_len": out_len,
        "seq_len": 13,
        "plot_stride": 1,
        "interval_real_time": 10,
        "sample_mode": "sequent",
        # stride defaults to the output length
        "stride": out_len,
        "layout": "NTHWC",
        "start_date": None,
        "train_val_split_date": (2019, 1, 1),
        "train_test_split_date": (2019, 6, 1),
        "end_date": None,
        "metrics_mode": "0",
        "metrics_list": ('csi', 'pod', 'sucr', 'bias'),
        "threshold_list": (16, 74, 133, 160, 181, 219),
        "aug_mode": "1",
    })
@staticmethod
def get_optim_config():
    r"""
    Default optimization section: batch sizes, optimizer hyper-parameters,
    learning-rate schedule, and checkpoint / early-stopping settings.
    """
    return OmegaConf.create({
        "seed": None,
        "total_batch_size": 32,
        "micro_batch_size": 8,
        "float32_matmul_precision": "high",
        "method": "adamw",
        "lr": 1.0E-6,
        "wd": 1.0E-2,
        "betas": (0.9, 0.999),
        "gradient_clip_val": 1.0,
        "max_epochs": 50,
        "loss_type": "l2",
        # scheduler
        "warmup_percentage": 0.2,
        # `configure_optimizers` handles 'cosine' and 'none'; anything else
        # raises NotImplementedError there.
        "lr_scheduler_mode": "cosine",
        "min_lr_ratio": 1.0E-3,
        "warmup_min_lr_ratio": 0.0,
        # early stopping
        "monitor": "valid_loss_epoch",
        "early_stop": False,
        "early_stop_mode": "min",
        "early_stop_patience": 5,
        "save_top_k": 1,
    })
@staticmethod
def get_logging_config():
    r"""
    Default logging section: name prefix, which monitors to attach,
    whether to use wandb / a profiler, and whether to dump `.npy` outputs.
    """
    return OmegaConf.create({
        "logging_prefix": "PreDiff",
        "monitor_lr": True,
        "monitor_device": False,
        "track_grad_norm": -1,
        "use_wandb": False,
        "profiler": None,
        "save_npy": False,
    })
@staticmethod
def get_trainer_config():
    r"""
    Default trainer section. Entries whose names match `Trainer`
    constructor parameters are forwarded by `set_trainer_kwargs`.
    """
    return OmegaConf.create({
        "check_val_every_n_epoch": 1,
        # Fraction of `total_num_steps` between two log calls:
        # `log_every_n_steps = max(1, int(log_step_ratio * total_num_steps))`.
        "log_step_ratio": 0.001,
        "precision": 32,
        "find_unused_parameters": True,
        "num_sanity_val_steps": 2,
    })
@staticmethod
def get_eval_config():
    r"""
    Default evaluation section: which data indices are visualized per
    split, which sampling variants (aligned / unaligned) are evaluated,
    how many samples are drawn per context, and plotting / FVD settings.
    """
    return OmegaConf.create({
        "train_example_data_idx_list": [0, ],
        "val_example_data_idx_list": [0, ],
        "test_example_data_idx_list": [0, ],
        "eval_example_only": False,
        "eval_aligned": True,
        "eval_unaligned": True,
        "num_samples_per_context": 1,
        "font_size": 20,
        "label_offset": (-0.5, 0.5),
        "label_avg_int": False,
        "fvd_features": 400,
    })
# endregion
# region Trainer and Optimizer Config
def configure_optimizers(self):
    r"""
    Create the optimizer and, unless `optim.lr_scheduler_mode == 'none'`,
    a per-step learning-rate schedule (warmup followed by cosine annealing).

    Returns
    -------
    ret: dict
        `{'optimizer': ...}` or `{'optimizer': ..., 'lr_scheduler': ...}`.

    Raises
    ------
    NotImplementedError
        If `optim.method` is not "adamw" or `optim.lr_scheduler_mode` is
        neither 'none' nor 'cosine'.
    """
    optim_cfg = self.oc.optim
    params = list(self.torch_nn_module.parameters())
    if self.cond_stage_trainable:
        print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
        params = params + list(self.cond_stage_model.parameters())
    if self.learn_logvar:
        print('Diffusion model optimizing logvar')
        params.append(self.logvar)

    if optim_cfg.method == "adamw":
        optimizer = torch.optim.AdamW(
            params,
            lr=optim_cfg.lr,
            # plain tuple rather than an OmegaConf list container
            betas=tuple(optim_cfg.betas),
            # honor the configured weight decay instead of silently
            # falling back to the optimizer's built-in default
            weight_decay=optim_cfg.wd)
    else:
        raise NotImplementedError(f"optimization method {optim_cfg.method} not supported.")

    if optim_cfg.lr_scheduler_mode == 'none':
        return {'optimizer': optimizer}
    if optim_cfg.lr_scheduler_mode != 'cosine':
        raise NotImplementedError(
            f"lr_scheduler_mode {optim_cfg.lr_scheduler_mode} not supported.")

    warmup_iter = int(np.round(optim_cfg.warmup_percentage * self.total_num_steps))
    warmup_scheduler = LambdaLR(optimizer,
                                lr_lambda=warmup_lambda(warmup_steps=warmup_iter,
                                                        min_lr_ratio=optim_cfg.warmup_min_lr_ratio))
    cosine_scheduler = CosineAnnealingLR(optimizer,
                                         # keep T_max positive even if warmup spans every step
                                         T_max=max(1, self.total_num_steps - warmup_iter),
                                         eta_min=optim_cfg.min_lr_ratio * optim_cfg.lr)
    lr_scheduler = SequentialLR(optimizer, schedulers=[warmup_scheduler, cosine_scheduler],
                                milestones=[warmup_iter])
    lr_scheduler_config = {
        'scheduler': lr_scheduler,
        'interval': 'step',
        'frequency': 1,
    }
    return {'optimizer': optimizer, 'lr_scheduler': lr_scheduler_config}
def set_trainer_kwargs(self, **kwargs):
    r"""
    Build the default kwargs used when initializing `Trainer`.

    Parameters
    ----------
    kwargs
        Extra `Trainer` kwargs. `callbacks` (a list of `Callback`) and
        `logger` (a logger or a list of loggers) are extended with the
        defaults created here; the caller's lists are not modified. Every
        other entry overrides the corresponding default.

    Returns
    -------
    ret: dict
        Keyword arguments for `Trainer`.

    Raises
    ------
    NotImplementedError
        If `logging.profiler` is neither `None` nor "pytorch".
    """
    if self.oc.logging.profiler is None:
        profiler = None
    elif self.oc.logging.profiler == "pytorch":
        profiler = PyTorchProfiler(filename=f"{self.oc.logging.logging_prefix}_PyTorchProfiler.log")
    else:
        raise NotImplementedError
    checkpoint_callback = ModelCheckpoint(
        monitor=self.oc.optim.monitor,
        dirpath=os.path.join(self.save_dir, "checkpoints"),
        filename="{epoch:03d}_{val/loss:.4f}",
        auto_insert_metric_name=False,
        save_top_k=self.oc.optim.save_top_k,
        save_last=True,
        mode="min",
    )

    callbacks = kwargs.pop("callbacks", [])
    assert isinstance(callbacks, list)
    for ele in callbacks:
        assert isinstance(ele, Callback)
    # Work on a copy so the list passed in by the caller is left untouched.
    callbacks = list(callbacks)
    callbacks.append(checkpoint_callback)
    if self.oc.logging.monitor_lr:
        callbacks.append(LearningRateMonitor(logging_interval='step'))
    if self.oc.logging.monitor_device:
        callbacks.append(DeviceStatsMonitor())
    if self.oc.optim.early_stop:
        callbacks.append(EarlyStopping(monitor=self.oc.optim.monitor,
                                       min_delta=0.0,
                                       patience=self.oc.optim.early_stop_patience,
                                       verbose=False,
                                       mode=self.oc.optim.early_stop_mode))

    logger = kwargs.pop("logger", [])
    if logger is None:
        logger = []
    elif isinstance(logger, (list, tuple)):
        # copy: do not mutate the caller's sequence
        logger = list(logger)
    else:
        # a single logger object was passed
        logger = [logger]
    tb_logger = pl_loggers.TensorBoardLogger(save_dir=self.save_dir)
    csv_logger = pl_loggers.CSVLogger(save_dir=self.save_dir)
    logger += [tb_logger, csv_logger]
    if self.oc.logging.use_wandb:
        # `logging_name` / `run_id` have no defaults in `get_logging_config`;
        # fall back to None when the config file does not provide them.
        wandb_logger = pl_loggers.WandbLogger(
            name=self.oc.logging.get("logging_name", None),
            id=self.oc.logging.get("run_id", None),
            project=self.oc.logging.logging_prefix,
            save_dir=self.save_dir
        )
        logger.append(wandb_logger)

    log_every_n_steps = max(1, int(self.oc.trainer.log_step_ratio * self.total_num_steps))
    trainer_init_keys = inspect.signature(Trainer).parameters.keys()
    ret = dict(
        callbacks=callbacks,
        # log
        logger=logger,
        log_every_n_steps=log_every_n_steps,
        profiler=profiler,
        # save
        default_root_dir=self.save_dir,
        # ddp
        accelerator="gpu",
        strategy=DDPStrategy(find_unused_parameters=self.oc.trainer.find_unused_parameters),
        # optimization
        max_epochs=self.oc.optim.max_epochs,
        check_val_every_n_epoch=self.oc.trainer.check_val_every_n_epoch,
        gradient_clip_val=self.oc.optim.gradient_clip_val,
        # NVIDIA amp
        precision=self.oc.trainer.precision,
        # misc
        num_sanity_val_steps=self.oc.trainer.num_sanity_val_steps,
        inference_mode=False,
    )
    # Forward only those `oc.trainer` entries that `Trainer.__init__` accepts.
    oc_trainer_kwargs = OmegaConf.to_object(self.oc.trainer)
    oc_trainer_kwargs = {key: val for key, val in oc_trainer_kwargs.items() if key in trainer_init_keys}
    ret.update(oc_trainer_kwargs)
    # Explicit caller kwargs win over everything else.
    ret.update(kwargs)
    return ret
# endregion
# region Properties Extraction and Misc Calc
@classmethod
def get_total_num_steps(
cls,
num_samples: int,
total_batch_size: int,
epoch: int = None):
r"""
Parameters
----------
num_samples: int
The number of samples of the datasets. `num_samples / micro_batch_size` is the number of steps per epoch.
total_batch_size: int
`total_batch_size == micro_batch_size * world_size * grad_accum`
epoch: int
"""
if epoch is None:
epoch = cls.get_optim_config().max_epochs
return int(epoch * num_samples / total_batch_size)
@staticmethod
def get_sevir_datamodule(dataset_cfg,
                         micro_batch_size: int = 1,
                         num_workers: int = 4):
    r"""
    Instantiate `SEVIRLightningDataModule` from the dataset config.

    Parameters
    ----------
    dataset_cfg
        Mapping with the dataset settings (see `get_dataset_config`).
        NOTE(review): `val_ratio` is read here but has no default in
        `get_dataset_config`, so it must come from the config file.
    micro_batch_size: int
        Passed to the datamodule as `batch_size`.
    num_workers: int
        Passed to the datamodule as `num_workers`.

    Returns
    -------
    dm: SEVIRLightningDataModule
    """
    datamodule_kwargs = dict(
        seq_len=dataset_cfg["seq_len"],
        sample_mode=dataset_cfg["sample_mode"],
        stride=dataset_cfg["stride"],
        batch_size=micro_batch_size,
        layout=dataset_cfg["layout"],
        output_type=np.float32,
        preprocess=True,
        rescale_method="01",
        verbose=False,
        aug_mode=dataset_cfg["aug_mode"],
        ret_contiguous=False,
        # datamodule_only
        dataset_name=dataset_cfg["dataset_name"],
        start_date=dataset_cfg["start_date"],
        train_test_split_date=dataset_cfg["train_test_split_date"],
        end_date=dataset_cfg["end_date"],
        val_ratio=dataset_cfg["val_ratio"],
        num_workers=num_workers,
    )
    return SEVIRLightningDataModule(**datamodule_kwargs)
@property
def in_slice(self):
    r"""
    Slice selecting the input (context) part of a sequence.

    Computed on first access from the layout config via
    `step_layout_to_in_out_slice`; both the input and the output slice
    are cached on the instance at that point.
    """
    if hasattr(self, "_in_slice"):
        return self._in_slice
    layout_cfg = self.oc.layout
    slices = step_layout_to_in_out_slice(
        layout=layout_cfg.layout,
        in_len=layout_cfg.in_len, in_step=layout_cfg.in_step,
        out_len=layout_cfg.out_len, out_step=layout_cfg.out_step,
        in_out_diff=layout_cfg.in_out_diff
    )
    self._in_slice, self._out_slice = slices
    return self._in_slice
@property
def out_slice(self):
    r"""
    Slice selecting the output (target) part of a sequence.

    Computed on first access from the layout config via
    `step_layout_to_in_out_slice`; both the input and the output slice
    are cached on the instance at that point.
    """
    if hasattr(self, "_out_slice"):
        return self._out_slice
    layout_cfg = self.oc.layout
    slices = step_layout_to_in_out_slice(
        layout=layout_cfg.layout,
        in_len=layout_cfg.in_len, in_step=layout_cfg.in_step,
        out_len=layout_cfg.out_len, out_step=layout_cfg.out_step,
        in_out_diff=layout_cfg.in_out_diff
    )
    self._in_slice, self._out_slice = slices
    return self._out_slice
@torch.no_grad()
def get_input(self, batch, **kwargs):
    r"""
    Split a raw batch into generation target and conditioning.
    Dataset dependent; this implementation delegates to the SEVIR-LR
    specific `_get_input_sevirlr`.

    Parameters
    ----------
    batch: Any
        raw data batch from specific dataloader
    kwargs
        Only `return_verbose` (default `False`) is read; it is forwarded
        to `_get_input_sevirlr`.

    Returns
    -------
    out: Sequence[torch.Tensor, Dict[str, Any]]
        out[0] should be a torch.Tensor which is the target to generate
        out[1] should be a dict consists of several key-value pairs for conditioning
    """
    verbose = kwargs.get("return_verbose", False)
    return self._get_input_sevirlr(batch=batch, return_verbose=verbose)
@torch.no_grad()
def _get_input_sevirlr(self, batch, return_verbose=False):
    r"""
    Split a SEVIR-LR batch with `in_slice` / `out_slice`.

    Parameters
    ----------
    batch
        The full sequence; indexed with `self.in_slice` and `self.out_slice`.
    return_verbose: bool
        If `True`, additionally return the input sequence as a third item.

    Returns
    -------
    out
        `(out_seq, {"y": in_seq})`, or `(out_seq, {"y": in_seq}, in_seq)`
        when `return_verbose` is set. `out_seq` is made contiguous.
    """
    context = batch[self.in_slice]
    target = batch[self.out_slice].contiguous()
    cond = {"y": context}
    if not return_verbose:
        return target, cond
    return target, cond, context
# endregion
# region Operation Step
def training_step(self, batch, batch_idx):
    r"""
    Compute and log the training loss. On local rank 0, in epochs that are
    a multiple of `trainer.check_val_every_n_epoch`, batches whose
    `batch_idx * micro_batch_size` is listed in
    `train_example_data_idx_list` are additionally sampled from and handed
    to `save_vis_step_end` for visualization.

    Parameters
    ----------
    batch
        Raw batch from the dataloader.
    batch_idx: int
        Index of the batch within the epoch.

    Returns
    -------
    loss
        The loss returned by the forward call.
    """
    loss, loss_dict = self(batch)
    self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=False)
    micro_batch_size = batch.shape[self.batch_axis]
    data_idx = int(batch_idx * micro_batch_size)

    # Guard clauses: everything below is visualization only.
    on_vis_epoch = self.current_epoch % self.oc.trainer.check_val_every_n_epoch == 0
    if not (on_vis_epoch and self.local_rank == 0):
        return loss
    if data_idx not in self.train_example_data_idx_list:
        return loss

    def first_as_numpy(tensor):
        # first element of the batch as a float numpy array on CPU
        return tensor[0].detach().float().cpu().numpy()

    target_seq, cond, context_seq = self.get_input(batch, return_verbose=True)
    eval_cfg = self.oc.eval
    aligned_seqs, aligned_labels = [], []
    plain_seqs, plain_labels = [], []
    for sample_idx in range(eval_cfg.num_samples_per_context):
        # aligned sampling
        if self.use_alignment and eval_cfg.eval_aligned:
            if self.alignment_type != "avg_x":
                raise NotImplementedError
            alignment_kwargs = get_alignment_kwargs_avg_x(context_seq=context_seq,
                                                          target_seq=target_seq)
            pred_seq = self.sample(
                cond=cond,
                batch_size=micro_batch_size,
                return_intermediates=False,
                use_alignment=True,
                alignment_kwargs=alignment_kwargs,
                verbose=False, ).contiguous()
            aligned_seqs.append(first_as_numpy(pred_seq))
            aligned_labels.append(f"{self.oc.logging.logging_prefix}_aligned_pred_{sample_idx}")
        # no alignment
        if eval_cfg.eval_unaligned:
            pred_seq = self.sample(
                cond=cond,
                batch_size=micro_batch_size,
                return_intermediates=False,
                verbose=False, ).contiguous()
            plain_seqs.append(first_as_numpy(pred_seq))
            plain_labels.append(f"{self.oc.logging.logging_prefix}_pred_{sample_idx}")

    # Aligned predictions are listed before the unaligned ones.
    self.save_vis_step_end(
        data_idx=data_idx,
        context_seq=first_as_numpy(context_seq),
        target_seq=first_as_numpy(target_seq),
        pred_seq=aligned_seqs + plain_seqs,
        pred_label=aligned_labels + plain_labels,
        mode="train", )
    return loss
def validation_step(self, batch, batch_idx):
    r"""
    Log the validation losses with and without EMA weights, then draw
    samples, update the validation metrics and save a visualization.

    Sampling is skipped only when `eval_example_only` is set and
    `batch_idx * micro_batch_size` is not in `val_example_data_idx_list`.

    Parameters
    ----------
    batch
        Raw batch from the dataloader.
    batch_idx: int
        Index of the batch within the validation loop.
    """
    _, loss_dict_no_ema = self(batch)
    with self.ema_scope():
        _, ema_losses = self(batch)
        loss_dict_ema = {key + '_ema': value for key, value in ema_losses.items()}
    self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True, sync_dist=True)
    self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True, sync_dist=True)

    micro_batch_size = batch.shape[self.batch_axis]
    data_idx = int(batch_idx * micro_batch_size)
    if self.eval_example_only and data_idx not in self.val_example_data_idx_list:
        return

    def first_as_numpy(tensor):
        # first element of the batch as a float numpy array on CPU
        return tensor[0].detach().float().cpu().numpy()

    target_seq, cond, context_seq = self.get_input(batch, return_verbose=True)
    eval_cfg = self.oc.eval
    aligned_seqs, aligned_labels = [], []
    plain_seqs, plain_labels = [], []
    for sample_idx in range(eval_cfg.num_samples_per_context):
        # aligned sampling
        if self.use_alignment and eval_cfg.eval_aligned:
            if self.alignment_type != "avg_x":
                raise NotImplementedError
            alignment_kwargs = get_alignment_kwargs_avg_x(context_seq=context_seq,
                                                          target_seq=target_seq)
            pred_seq = self.sample(
                cond=cond,
                batch_size=micro_batch_size,
                return_intermediates=False,
                use_alignment=True,
                alignment_kwargs=alignment_kwargs,
                verbose=False, ).contiguous()
            aligned_seqs.append(first_as_numpy(pred_seq))
            aligned_labels.append(f"{self.oc.logging.logging_prefix}_aligned_pred_{sample_idx}")
            # metrics are fed float32 predictions
            if pred_seq.dtype is not torch.float:
                pred_seq = pred_seq.float()
            self.valid_aligned_mse(pred_seq, target_seq)
            self.valid_aligned_mae(pred_seq, target_seq)
            self.valid_aligned_score.update(pred_seq, target_seq)
        # no alignment
        if eval_cfg.eval_unaligned:
            pred_seq = self.sample(
                cond=cond,
                batch_size=micro_batch_size,
                return_intermediates=False,
                verbose=False, ).contiguous()
            plain_seqs.append(first_as_numpy(pred_seq))
            plain_labels.append(f"{self.oc.logging.logging_prefix}_pred_{sample_idx}")
            if pred_seq.dtype is not torch.float:
                pred_seq = pred_seq.float()
            self.valid_mse(pred_seq, target_seq)
            self.valid_mae(pred_seq, target_seq)
            self.valid_score.update(pred_seq, target_seq)

    # Aligned predictions are listed before the unaligned ones.
    self.save_vis_step_end(
        data_idx=data_idx,
        context_seq=first_as_numpy(context_seq),
        target_seq=first_as_numpy(target_seq),
        pred_seq=aligned_seqs + plain_seqs,
        pred_label=aligned_labels + plain_labels,
        mode="val",
        suffix=f"_rank{self.local_rank}", )
def on_validation_epoch_end(self):
    r"""
    Compute, log and reset the validation metrics accumulated during the
    epoch. For each evaluated variant the logged keys are
    `<prefix>_loss_epoch` (negative average CSI), `<prefix>_mse_epoch` and
    `<prefix>_mae_epoch`, followed by the per-threshold scores via
    `log_score_epoch_end`.

    The aligned variant is handled only when alignment is actually in use:
    `validation_step` updates the aligned metrics under the condition
    `use_alignment and eval_aligned`, so computing them otherwise would
    log values of metrics that never received an update.
    """
    metric_groups = []
    if self.oc.eval.eval_unaligned:
        metric_groups.append(
            ("valid", self.valid_mse, self.valid_mae, self.valid_score))
    if self.use_alignment and self.oc.eval.eval_aligned:
        metric_groups.append(
            ("valid_aligned", self.valid_aligned_mse, self.valid_aligned_mae, self.valid_aligned_score))

    for prefix, mse_metric, mae_metric, score_metric in metric_groups:
        valid_mse = mse_metric.compute()
        valid_mae = mae_metric.compute()
        valid_score = score_metric.compute()
        # negated so that "lower is better" holds for the monitored value
        valid_loss = -valid_score["avg"]["csi"]
        self.log(f'{prefix}_loss_epoch', valid_loss, prog_bar=True, on_step=False, on_epoch=True, sync_dist=True)
        self.log(f'{prefix}_mse_epoch', valid_mse, prog_bar=True, on_step=False, on_epoch=True, sync_dist=True)
        self.log(f'{prefix}_mae_epoch', valid_mae, prog_bar=True, on_step=False, on_epoch=True, sync_dist=True)
        self.log_score_epoch_end(score_dict=valid_score, prefix=prefix)
        mse_metric.reset()
        mae_metric.reset()
        score_metric.reset()
def test_step(self, batch, batch_idx):
    r"""
    Sample forecasts for one test batch and update the test metrics.

    For each of ``oc.eval.num_samples_per_context`` draws, a forecast is sampled
    with knowledge alignment (when ``self.use_alignment`` and
    ``oc.eval.eval_aligned``) and/or without it (when ``oc.eval.eval_unaligned``).
    Every draw updates the matching MSE / MAE / skill-score / SSIM metrics and,
    if ``oc.logging.save_npy`` is set, is written to ``self.npy_save_dir``.
    The first sequence of the batch from every draw is passed to
    ``save_vis_step_end`` for visualization.

    Parameters
    ----------
    batch
        Test batch; the batch size is read from ``batch.shape[self.batch_axis]``.
    batch_idx: int
        Index of the batch; ``data_idx = batch_idx * batch_size`` identifies
        the first sample of the batch.
    """
    micro_batch_size = batch.shape[self.batch_axis]
    data_idx = int(batch_idx * micro_batch_size)
    # In example-only mode, skip every batch that is not a *test* example.
    # The test list is used here (not the validation one) so that this filter
    # agrees with the list `save_vis_step_end(mode="test")` checks.
    if self.eval_example_only and data_idx not in self.test_example_data_idx_list:
        return
    target_seq, cond, context_seq = \
        self.get_input(batch, return_verbose=True)
    # SSIM is fed individual frames: fold the time axis into the batch axis.
    target_seq_bchw = rearrange(target_seq, "b t h w c -> (b t) c h w")
    aligned_pred_seq_list = []
    aligned_pred_label_list = []
    pred_seq_list = []
    pred_label_list = []
    for i in range(self.oc.eval.num_samples_per_context):
        # aligned sampling
        if self.use_alignment and self.oc.eval.eval_aligned:
            if self.alignment_type == "avg_x":
                alignment_kwargs = get_alignment_kwargs_avg_x(context_seq=context_seq,
                                                              target_seq=target_seq)
            else:
                raise NotImplementedError
            pred_seq = self.sample(
                cond=cond,
                batch_size=micro_batch_size,
                return_intermediates=False,
                use_alignment=True,
                alignment_kwargs=alignment_kwargs,
                verbose=False, ).contiguous()
            if self.oc.logging.save_npy:
                npy_path = os.path.join(self.npy_save_dir,
                                        f"batch{batch_idx}_rank{self.local_rank}_sample{i}_aligned.npy")
                np.save(npy_path, pred_seq.detach().float().cpu().numpy())
            # Only the first sequence of the batch is kept for visualization.
            aligned_pred_seq_list.append(pred_seq[0].detach().float().cpu().numpy())
            aligned_pred_label_list.append(f"{self.oc.logging.logging_prefix}_aligned_pred_{i}")
            # Metrics are updated with float32 predictions.
            if pred_seq.dtype is not torch.float:
                pred_seq = pred_seq.float()
            self.test_aligned_mse(pred_seq, target_seq)
            self.test_aligned_mae(pred_seq, target_seq)
            self.test_aligned_score.update(pred_seq, target_seq)
            # NOTE: FVD is disabled; re-enabling needs
            # self.test_aligned_fvd.update(pred_seq, real=False)
            pred_seq_bchw = rearrange(pred_seq, "b t h w c -> (b t) c h w")
            self.test_aligned_ssim(pred_seq_bchw, target_seq_bchw)
        # no alignment
        if self.oc.eval.eval_unaligned:
            pred_seq = self.sample(
                cond=cond,
                batch_size=micro_batch_size,
                return_intermediates=False,
                verbose=False, ).contiguous()
            if self.oc.logging.save_npy:
                npy_path = os.path.join(self.npy_save_dir,
                                        f"batch{batch_idx}_rank{self.local_rank}_sample{i}.npy")
                np.save(npy_path, pred_seq.detach().float().cpu().numpy())
            pred_seq_list.append(pred_seq[0].detach().float().cpu().numpy())
            pred_label_list.append(f"{self.oc.logging.logging_prefix}_pred_{i}")
            if pred_seq.dtype is not torch.float:
                pred_seq = pred_seq.float()
            self.test_mse(pred_seq, target_seq)
            self.test_mae(pred_seq, target_seq)
            self.test_score.update(pred_seq, target_seq)
            # NOTE: FVD is disabled; re-enabling needs
            # self.test_fvd.update(pred_seq, real=False)
            pred_seq_bchw = rearrange(pred_seq, "b t h w c -> (b t) c h w")
            self.test_ssim(pred_seq_bchw, target_seq_bchw)
    # NOTE: FVD is disabled; re-enabling also needs the real sequences:
    # if self.use_alignment and self.oc.eval.eval_aligned:
    #     self.test_aligned_fvd.update(target_seq, real=True)
    # if self.oc.eval.eval_unaligned:
    #     self.test_fvd.update(target_seq, real=True)
    # Aligned predictions are shown first, followed by the unaligned ones.
    pred_seq_list = aligned_pred_seq_list + pred_seq_list
    pred_label_list = aligned_pred_label_list + pred_label_list
    self.save_vis_step_end(
        data_idx=data_idx,
        context_seq=context_seq[0].detach().float().cpu().numpy(),
        target_seq=target_seq[0].detach().float().cpu().numpy(),
        pred_seq=pred_seq_list,
        pred_label=pred_label_list,
        mode="test",
        suffix=f"_rank{self.local_rank}", )
def on_test_epoch_end(self):
    r"""
    Compute, log and reset the test metrics accumulated during the test run.

    Two independent groups of metrics are handled:

    * unaligned sampling (``oc.eval.eval_unaligned``): ``test_mse``,
      ``test_mae``, ``test_ssim`` and ``test_score``, logged with the ``test``
      prefix;
    * aligned sampling (``self.use_alignment`` and ``oc.eval.eval_aligned``):
      the ``test_aligned_*`` counterparts, logged with the ``test_aligned``
      prefix.

    The aligned group uses the same condition as ``test_step``, which only
    updates the aligned metrics when ``self.use_alignment`` is set. Without it,
    ``compute()`` would be called on metrics that never received an update.

    NOTE: FVD evaluation (``test_fvd`` / ``test_aligned_fvd``) is currently
    disabled and therefore neither computed, logged nor reset here.
    """
    # Shared by every `self.log` call in this hook.
    log_kwargs = dict(prog_bar=True, on_step=False, on_epoch=True, sync_dist=True)
    if self.oc.eval.eval_unaligned:
        test_mse = self.test_mse.compute()
        test_mae = self.test_mae.compute()
        test_ssim = self.test_ssim.compute()
        test_score = self.test_score.compute()
        self.log('test_mse_epoch', test_mse, **log_kwargs)
        self.log('test_mae_epoch', test_mae, **log_kwargs)
        self.log('test_ssim_epoch', test_ssim, **log_kwargs)
        self.log_score_epoch_end(score_dict=test_score, prefix="test")
        self.test_mse.reset()
        self.test_mae.reset()
        self.test_ssim.reset()
        self.test_score.reset()
    if self.use_alignment and self.oc.eval.eval_aligned:
        test_mse = self.test_aligned_mse.compute()
        test_mae = self.test_aligned_mae.compute()
        test_ssim = self.test_aligned_ssim.compute()
        test_score = self.test_aligned_score.compute()
        self.log('test_aligned_mse_epoch', test_mse, **log_kwargs)
        self.log('test_aligned_mae_epoch', test_mae, **log_kwargs)
        self.log('test_aligned_ssim_epoch', test_ssim, **log_kwargs)
        self.log_score_epoch_end(score_dict=test_score, prefix="test_aligned")
        self.test_aligned_mse.reset()
        self.test_aligned_mae.reset()
        self.test_aligned_ssim.reset()
        self.test_aligned_score.reset()
# endregion
def save_vis_step_end(
        self,
        data_idx: int,
        context_seq: np.ndarray,
        target_seq: np.ndarray,
        pred_seq: Union[np.ndarray, Sequence[np.ndarray]],
        pred_label: Union[str, Sequence[str]] = None,
        label_mode: str = "name",
        mode: str = "train",
        prefix: str = "",
        suffix: str = "", ):
    r"""
    Save a PNG comparing context, target and prediction(s) for an example sample.

    The figure is only written when ``data_idx`` is one of the example indices
    configured for ``mode``; arguments are validated in every case.

    Parameters
    ----------
    data_idx: int
        Index of the sample; also part of the output file name.
    context_seq, target_seq, pred_seq: np.ndarray
        layout should not include batch. ``pred_seq`` may also be a sequence
        of arrays, one per prediction.
    pred_label: str or Sequence[str]
        Label of the prediction, or one label per prediction when ``pred_seq``
        is a sequence.
    label_mode: str
        ``"name"`` uses the labels as given; ``"avg_int"`` appends the mean
        value of the corresponding sequence to every label.
    mode: str
        One of ``"train"``, ``"val"``, ``"test"``; selects the example index
        list and is part of the output file name.
    prefix, suffix: str
        Prepended / appended to the output file name (before ``.png``).

    Raises
    ------
    ValueError
        If ``mode`` is not one of the supported values.
    NotImplementedError
        If ``label_mode`` is not one of the supported values.
    TypeError
        If ``pred_label`` does not match ``pred_seq`` (e.g. a single string or
        ``None`` given together with a sequence of predictions).
    """
    if mode == "train":
        example_data_idx_list = self.train_example_data_idx_list
    elif mode == "val":
        example_data_idx_list = self.val_example_data_idx_list
    elif mode == "test":
        example_data_idx_list = self.test_example_data_idx_list
    else:
        raise ValueError(f"Wrong mode {mode}! Must be in ['train', 'val', 'test'].")
    if label_mode == "name":
        # use the given label
        context_label = "context"
        target_label = "target"
    elif label_mode == "avg_int":
        context_label = f"context\navg_int={np.mean(context_seq):.4f}"
        target_label = f"target\navg_int={np.mean(target_seq):.4f}"
        # `str` is itself a Sequence, so it has to be tested first; otherwise a
        # single label would be zipped character by character.
        if isinstance(pred_label, str):
            pred_label = f"{pred_label}\navg_int={np.mean(pred_seq):.4f}"
        elif isinstance(pred_label, Sequence):
            pred_label = [f"{label}\navg_int={np.mean(seq):.4f}" for label, seq in zip(pred_label, pred_seq)]
        else:
            raise TypeError(f"Wrong pred_label type {type(pred_label)}! must be in [str, Sequence[str]].")
    else:
        raise NotImplementedError
    if isinstance(pred_seq, Sequence):
        # Several predictions need one label each; a bare string or None cannot
        # be concatenated to the label list.
        if isinstance(pred_label, str) or not isinstance(pred_label, Sequence):
            raise TypeError(f"Wrong pred_label type {type(pred_label)}! "
                            f"must be Sequence[str] when pred_seq is a sequence.")
        seq_list = [context_seq, target_seq] + list(pred_seq)
        # list(...) so that tuples (any Sequence[str]) are accepted as well.
        label_list = [context_label, target_label] + list(pred_label)
    else:
        seq_list = [context_seq, target_seq, pred_seq]
        label_list = [context_label, target_label, pred_label]
    if data_idx in example_data_idx_list:
        png_save_name = f"{prefix}{mode}_epoch_{self.current_epoch}_data_{data_idx}{suffix}.png"
        vis_sevir_seq(
            save_path=os.path.join(self.example_save_dir, png_save_name),
            seq=seq_list,
            label=label_list,
            interval_real_time=10,
            plot_stride=1, fs=self.oc.eval.fs,
            label_offset=self.oc.eval.label_offset,
            label_avg_int=self.oc.eval.label_avg_int, )
def log_score_epoch_end(self, score_dict: Dict, prefix: str = "valid"):
    r"""
    Log the epoch-level skill scores held in ``score_dict``.

    For every metric in ``oc.dataset.metrics_list`` the mean of
    ``score_dict[thresh][metric]`` is logged as
    ``{prefix}_{metric}_{thresh}_epoch`` for each threshold in
    ``oc.dataset.threshold_list``. If ``score_dict`` has an ``"avg"`` entry,
    the mean of ``score_dict["avg"][metric]`` is logged as
    ``{prefix}_{metric}_avg_epoch`` as well.

    Parameters
    ----------
    score_dict: Dict
        Scores indexed first by threshold (or ``"avg"``), then by metric name.
    prefix: str
        Prefix of every logged key.
    """
    log_options = dict(prog_bar=True, on_step=False, on_epoch=True, sync_dist=True)
    averaged = score_dict.get("avg", None)
    for metric_name in self.oc.dataset.metrics_list:
        for threshold in self.oc.dataset.threshold_list:
            threshold_value = np.mean(score_dict[threshold][metric_name]).item()
            self.log(f"{prefix}_{metric_name}_{threshold}_epoch", threshold_value,
                     **log_options)
        if averaged is None:
            continue
        averaged_value = np.mean(averaged[metric_name]).item()
        self.log(f"{prefix}_{metric_name}_avg_epoch", averaged_value,
                 **log_options)
def on_before_optimizer_step(self, optimizer):
    r"""
    Log per-layer gradient norms of ``self.torch_nn_module``.

    The norm order is taken from ``oc.logging.track_grad_norm``; the value
    ``-1`` turns tracking off. If using mixed precision, the gradients are
    already unscaled at this point.

    Reference:
    https://lightning.ai/docs/pytorch/2.0.9/debug/debugging_intermediate.html#look-out-for-exploding-gradients
    """
    norm_order = self.oc.logging.track_grad_norm
    if norm_order == -1:
        # Gradient-norm tracking is switched off.
        return
    layer_norms = grad_norm(self.torch_nn_module, norm_type=norm_order)
    self.log_dict(layer_norms)