| import os |
| import GPUtil |
| import torch |
| import sys |
| import hydra |
| import wandb |
|
|
| |
| from pytorch_lightning import LightningDataModule, LightningModule, Trainer |
| from pytorch_lightning.loggers.wandb import WandbLogger |
| from pytorch_lightning.trainer import Trainer |
| from pytorch_lightning.callbacks import ModelCheckpoint |
|
|
| from omegaconf import DictConfig, OmegaConf |
| from data.pdb_dataloader import PdbDataModule |
| from models.flow_module import FlowModule |
| from experiments import utils as eu |
|
|
|
|
# Run Weights & Biases in offline mode; this is set at import time, before
# any WandbLogger is constructed further down in this file.
os.environ.update(WANDB_MODE="offline")

# Module-level logger obtained from the project's experiment utilities.
log = eu.get_pylogger(__name__)

# 'high' lets float32 matmuls use faster reduced-precision internals
# (e.g. TF32) where the backend supports them.
torch.set_float32_matmul_precision('high')
|
|
|
|
class Experiment:
    """Training run that wires a data module and a model into a Lightning Trainer.

    The constructor builds a ``PdbDataModule`` from ``cfg.data`` and a
    ``FlowModule`` from the full config; :meth:`train` then sets up logging,
    checkpointing and device selection from ``cfg.experiment`` and calls
    ``Trainer.fit``.
    """

    def __init__(self, *, cfg: DictConfig):
        """Build the data module and model from the composed config.

        Args:
            cfg: Full config. ``cfg.data`` is passed to ``PdbDataModule``,
                the whole ``cfg`` is passed to ``FlowModule``, and
                ``cfg.experiment`` is read by :meth:`train`.
        """
        self._cfg = cfg
        self._data_cfg = cfg.data
        self._exp_cfg = cfg.experiment
        self._datamodule: LightningDataModule = PdbDataModule(self._data_cfg)
        self._model: LightningModule = FlowModule(self._cfg)

    def train(self):
        """Configure logging/checkpointing, pick GPUs, and fit the model.

        In debug mode (``experiment.debug``) no logger or checkpoint callback
        is created and the config is mutated to use one device and zero
        dataloader workers. Otherwise a ``WandbLogger`` and a
        ``ModelCheckpoint`` callback are created, and the config is written to
        ``<checkpointer.dirpath>/config.yaml``.

        If ``experiment.warm_start`` is set, the model is replaced by one
        loaded from that checkpoint (non-strict, mapped to CPU) before fitting.
        """
        callbacks = []
        if self._exp_cfg.debug:
            log.info("Debug mode.")
            logger = None
            self._exp_cfg.num_devices = 1
            self._data_cfg.loader.num_workers = 0
        else:
            logger = WandbLogger(
                **self._exp_cfg.wandb,
            )

            # Checkpoint directory.
            ckpt_dir = self._exp_cfg.checkpointer.dirpath
            os.makedirs(ckpt_dir, exist_ok=True)
            log.info(f"Checkpoints saved to {ckpt_dir}")

            # Model checkpoints.
            callbacks.append(ModelCheckpoint(**self._exp_cfg.checkpointer))

            # Save the config next to the checkpoints. OmegaConf.save accepts
            # a path directly, so there is no need to open the file ourselves
            # just to hand over its name.
            cfg_path = os.path.join(ckpt_dir, 'config.yaml')
            OmegaConf.save(config=self._cfg, f=cfg_path)

            # Mirror the resolved, flattened config into the W&B run config.
            cfg_dict = OmegaConf.to_container(self._cfg, resolve=True)
            flat_cfg = dict(eu.flatten_dict(cfg_dict))
            if isinstance(logger.experiment.config, wandb.sdk.wandb_config.Config):
                logger.experiment.config.update(flat_cfg)

        # Take the first `num_devices` GPUs from GPUtil's memory-ordered list
        # of available devices (at most 8 are considered).
        devices = GPUtil.getAvailable(order='memory', limit=8)[:self._exp_cfg.num_devices]
        log.info(f"Using devices: {devices}")
        trainer = Trainer(
            **self._exp_cfg.trainer,
            callbacks=callbacks,
            logger=logger,
            use_distributed_sampler=False,
            enable_progress_bar=True,
            enable_model_summary=True,
            devices=devices,
        )

        if self._exp_cfg.warm_start is not None:
            # load_from_checkpoint is a classmethod: call it on the model's
            # class rather than on the instance (Lightning 2.x raises a
            # TypeError for instance calls) and keep the returned module.
            model_cls = type(self._model)
            self._model = model_cls.load_from_checkpoint(
                self._exp_cfg.warm_start,
                strict=False,
                map_location="cpu",
            )

        trainer.fit(
            model=self._model,
            datamodule=self._datamodule,
        )
|
|
|
|
@hydra.main(version_base=None, config_path="../configs", config_name="base.yaml")
def main(cfg: DictConfig):
    """Hydra entry point: optionally reuse a warm-start model config, then train.

    When ``experiment.warm_start`` is set and
    ``experiment.warm_start_cfg_override`` is truthy, the ``config.yaml``
    sitting next to the warm-start checkpoint is loaded and its ``model``
    section is merged over ``cfg.model``. An :class:`Experiment` is then
    built from ``cfg`` and trained.

    Args:
        cfg: Config composed by Hydra from ``../configs/base.yaml``.
    """
    exp_cfg = cfg.experiment
    if exp_cfg.warm_start is not None and exp_cfg.warm_start_cfg_override:
        # The warm-start config is expected beside the checkpoint file.
        ckpt_dir = os.path.dirname(exp_cfg.warm_start)
        prior_cfg_file = os.path.join(ckpt_dir, 'config.yaml')
        prior_cfg = OmegaConf.load(prior_cfg_file)

        # Clear the struct flag on both model nodes before merging, then
        # restore it on the merged result. In OmegaConf.merge the later
        # argument wins, so warm-start values take precedence.
        for node in (cfg.model, prior_cfg.model):
            OmegaConf.set_struct(node, False)
        cfg.model = OmegaConf.merge(cfg.model, prior_cfg.model)
        OmegaConf.set_struct(cfg.model, True)
        log.info(f'Loaded warm start config from {prior_cfg_file}')

    Experiment(cfg=cfg).train()


if __name__ == "__main__":
    main()
|
|