def test_ddp_sgd_comm_hook(tmpdir):
    """Test for the DDP PowerSGD comm hook."""
    model = BoringModel()
    training_type_plugin = DDPPlugin(
        ddp_comm_state=powerSGD.PowerSGDState(process_group=None),
        ddp_comm_hook=powerSGD.powerSGD_hook,
        sync_batchnorm=True,
    )
    trainer = Trainer(
        max_epochs=1,
        gpus=2,
        plugins=[training_type_plugin],
        default_root_dir=tmpdir,
        sync_batchnorm=True,
        fast_dev_run=True,
    )
    trainer.fit(model)
    trainer_comm_hook = trainer.accelerator.training_type_plugin._model.get_ddp_logging_data().comm_hook
    expected_comm_hook = powerSGD.powerSGD_hook.__qualname__
    assert trainer_comm_hook == expected_comm_hook
    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
def test_ddp_fp16_compress_comm_hook(tmpdir):
    """Test for DDP FP16 compress hook."""
    model = BoringModel()
    training_type_plugin = DDPPlugin(
        ddp_comm_hook=default.fp16_compress_hook,
        sync_batchnorm=True,
    )
    trainer = Trainer(
        max_epochs=1,
        gpus=2,
        plugins=[training_type_plugin],
        default_root_dir=tmpdir,
        sync_batchnorm=True,
        fast_dev_run=True,
    )
    trainer.fit(model)
    trainer_comm_hook = trainer.accelerator.training_type_plugin._model.get_ddp_logging_data().comm_hook
    expected_comm_hook = default.fp16_compress_hook.__qualname__
    assert trainer_comm_hook == expected_comm_hook
    assert trainer.state.finished, f"Training failed with {trainer.state}"
def test_ddp_fp16_compress_wrap_sgd_comm_hook(tmpdir):
    """Test for the DDP FP16 compress wrapper around the PowerSGD hook."""
    model = BoringModel()
    training_type_plugin = DDPPlugin(
        ddp_comm_state=powerSGD.PowerSGDState(process_group=None),
        ddp_comm_hook=powerSGD.powerSGD_hook,
        ddp_comm_wrapper=default.fp16_compress_wrapper,
        sync_batchnorm=True,
    )
    trainer = Trainer(
        max_epochs=1,
        gpus=2,
        strategy=training_type_plugin,
        default_root_dir=tmpdir,
        sync_batchnorm=True,
        fast_dev_run=True,
    )
    trainer.fit(model)
    trainer_comm_hook = trainer.accelerator.training_type_plugin._model.get_ddp_logging_data().comm_hook
    expected_comm_hook = default.fp16_compress_wrapper(powerSGD.powerSGD_hook).__qualname__
    assert trainer_comm_hook == expected_comm_hook
    assert trainer.state.finished, f"Training failed with {trainer.state}"
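# NOTE: DDPPlugin's ddp_comm_state / ddp_comm_hook / ddp_comm_wrapper arguments
# (used in the three tests above) map onto PyTorch's own comm-hook API. A
# minimal sketch of that underlying call, assuming a 2-process NCCL group has
# already been initialized via torch.distributed.init_process_group:
import torch
from torch.nn.parallel import DistributedDataParallel
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks as default
from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook as powerSGD

ddp_model = DistributedDataParallel(torch.nn.Linear(4, 4).cuda())

# PowerSGD gradient compression wrapped in FP16 compression, mirroring
# test_ddp_fp16_compress_wrap_sgd_comm_hook above. Only one comm hook may be
# registered per DDP module.
state = powerSGD.PowerSGDState(process_group=None)
ddp_model.register_comm_hook(state, default.fp16_compress_wrapper(powerSGD.powerSGD_hook))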
def get_trainer(wandb_logger, callbacks, config):
    gpus = []
    if config.gpu0:
        gpus.append(0)
    if config.gpu1:
        gpus.append(1)
    logging.info("gpus active: %s", gpus)
    if len(gpus) >= 2:
        distributed_backend = "ddp"
        accelerator = "ddp"
        plugins = DDPPlugin(find_unused_parameters=False)
    else:
        distributed_backend = None
        accelerator = None
        plugins = None
    trainer = pl.Trainer(
        logger=wandb_logger,
        gpus=gpus,
        max_epochs=config.NUM_EPOCHS,
        precision=config.precision_compute,
        # limit_train_batches=0.1,  # only to debug
        # limit_val_batches=0.1,  # only to debug
        # limit_test_batches=0.1,
        # val_check_interval=1,
        auto_lr_find=config.AUTO_LR,
        log_gpu_memory=True,
        # distributed_backend=distributed_backend,
        # accelerator=accelerator,
        # plugins=plugins,
        callbacks=callbacks,
        progress_bar_refresh_rate=5,
    )
    return trainer
def main():
    system = configure_system(hyperparameter_defaults["system"])(hyperparameter_defaults)
    logger = TensorBoardLogger(
        'experiments_logs',
        name=str(hyperparameter_defaults['system']) + "_" +
        str(system.model.__class__.__name__) + "_" +
        str(hyperparameter_defaults['criterion']) + "_" +
        str(hyperparameter_defaults['scheduler']))
    early_stop = EarlyStopping(monitor="valid_iou",
                               mode="max",
                               verbose=True,
                               patience=hyperparameter_defaults["patience"])
    model_checkpoint = ModelCheckpoint(
        monitor="valid_iou",
        mode="max",
        verbose=True,
        filename='Model-{epoch:02d}-{valid_iou:.5f}',
        save_top_k=3,
        save_last=True)
    trainer = pl.Trainer(
        gpus=[0, 1],
        plugins=DDPPlugin(find_unused_parameters=True),
        max_epochs=hyperparameter_defaults['epochs'],
        logger=logger,
        check_val_every_n_epoch=1,
        accelerator='ddp',
        callbacks=[early_stop, model_checkpoint],
        num_sanity_val_steps=0,
        limit_train_batches=1.0,
        deterministic=True,
    )
    trainer.fit(system)
    trainer.test(system)
def cli_main():
    parser = ArgumentParser()
    parser.add_argument("--batch-size", default=6, type=int)
    parser.add_argument("--dataset", default="librimix", type=str, choices=["wsj0-mix", "librimix"])
    parser.add_argument(
        "--root-dir",
        type=Path,
        help="The path to the directory where the directory ``Libri2Mix`` or ``Libri3Mix`` is stored.",
    )
    parser.add_argument(
        "--librimix-tr-split",
        default="train-360",
        choices=["train-360", "train-100"],
        help="The training partition of librimix dataset. (default: ``train-360``)",
    )
    parser.add_argument(
        "--librimix-task",
        default="sep_clean",
        type=str,
        choices=["sep_clean", "sep_noisy", "enh_single", "enh_both"],
        help="The task to perform (separation or enhancement, noisy or clean). (default: ``sep_clean``)",
    )
    parser.add_argument(
        "--num-speakers", default=2, type=int, help="The number of speakers in the mixture. (default: 2)"
    )
    parser.add_argument(
        "--sample-rate",
        default=8000,
        type=int,
        help="Sample rate of audio files in the given dataset. (default: 8000)",
    )
    parser.add_argument(
        "--exp-dir", default=Path("./exp"), type=Path, help="The directory to save checkpoints and logs."
    )
    parser.add_argument(
        "--epochs",
        metavar="NUM_EPOCHS",
        default=200,
        type=int,
        help="The number of epochs to train. (default: 200)",
    )
    parser.add_argument(
        "--learning-rate",
        default=1e-3,
        type=float,
        help="Initial learning rate. (default: 1e-3)",
    )
    parser.add_argument(
        "--num-gpu",
        default=1,
        type=int,
        help="The number of GPUs for training. (default: 1)",
    )
    parser.add_argument(
        "--num-node",
        default=1,
        type=int,
        help="The number of nodes in the cluster for training. (default: 1)",
    )
    parser.add_argument(
        "--num-workers",
        default=4,
        type=int,
        help="The number of workers for dataloader. (default: 4)",
    )
    args = parser.parse_args()

    model = _get_model(num_sources=args.num_speakers)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=5
    )
    train_loader, valid_loader, eval_loader = _get_dataloader(
        args.dataset,
        args.root_dir,
        args.num_speakers,
        args.sample_rate,
        args.batch_size,
        args.num_workers,
        args.librimix_task,
        args.librimix_tr_split,
    )
    loss = si_sdr_loss
    metric_dict = {
        "sdri": sdri_metric,
        "sisdri": sisdri_metric,
    }
    model = ConvTasNetModule(
        model=model,
        train_loader=train_loader,
        val_loader=valid_loader,
        loss=loss,
        optim=optimizer,
        metrics=metric_dict,
        lr_scheduler=lr_scheduler,
    )
    checkpoint_dir = args.exp_dir / "checkpoints"
    checkpoint = ModelCheckpoint(
        checkpoint_dir, monitor="Losses/val_loss", mode="min", save_top_k=5, save_weights_only=True, verbose=True
    )
    callbacks = [
        checkpoint,
        EarlyStopping(monitor="Losses/val_loss", mode="min", patience=30, verbose=True),
    ]
    trainer = Trainer(
        default_root_dir=args.exp_dir,
        max_epochs=args.epochs,
        gpus=args.num_gpu,
        num_nodes=args.num_node,
        accelerator="ddp",
        plugins=DDPPlugin(find_unused_parameters=False),  # make sure there are no unused params
        limit_train_batches=1.0,  # useful for fast experiments
        gradient_clip_val=5.0,
        callbacks=callbacks,
    )
    trainer.fit(model)
    model.load_from_checkpoint(checkpoint.best_model_path)
    state_dict = torch.load(checkpoint.best_model_path, map_location="cpu")
    state_dict = {k.replace("model.", ""): v for k, v in state_dict["state_dict"].items()}
    torch.save(state_dict, args.exp_dir / "best_model.pth")
    trainer.test(model, eval_loader)
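# NOTE: the best_model.pth written above holds a plain state dict with the
# LightningModule's "model." prefix already stripped, so it loads straight
# into the bare network. A hedged sketch (num_sources=2 and the ./exp default
# path are assumptions mirroring the argparse defaults above):
import torch

plain_model = _get_model(num_sources=2)
plain_model.load_state_dict(torch.load("exp/best_model.pth", map_location="cpu"))
plain_model.eval()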
def main():
    # parse arguments
    args = parse_args()
    rank_zero_only(pprint.pprint)(vars(args))

    # init default-cfg and merge it with the main- and data-cfg
    config = get_cfg_defaults()
    config.merge_from_file(args.main_cfg_path)
    config.merge_from_file(args.data_cfg_path)
    pl.seed_everything(config.TRAINER.SEED)  # reproducibility
    # TODO: Use a different seed for each dataloader worker.
    # This is needed for data augmentation.

    # scale lr and warmup-step automatically
    args.gpus = _n_gpus = setup_gpus(args.gpus)
    config.TRAINER.WORLD_SIZE = _n_gpus * args.num_nodes
    config.TRAINER.TRUE_BATCH_SIZE = config.TRAINER.WORLD_SIZE * args.batch_size
    _scaling = config.TRAINER.TRUE_BATCH_SIZE / config.TRAINER.CANONICAL_BS
    config.TRAINER.SCALING = _scaling
    config.TRAINER.TRUE_LR = config.TRAINER.CANONICAL_LR * _scaling
    config.TRAINER.WARMUP_STEP = math.floor(config.TRAINER.WARMUP_STEP / _scaling)

    # lightning module
    profiler = build_profiler(args.profiler_name)
    model = PL_LoFTR(config, pretrained_ckpt=args.ckpt_path, profiler=profiler)
    loguru_logger.info("LoFTR LightningModule initialized!")

    # lightning data
    data_module = MultiSceneDataModule(args, config)
    loguru_logger.info("LoFTR DataModule initialized!")

    # TensorBoard Logger
    logger = TensorBoardLogger(save_dir='logs/tb_logs', name=args.exp_name, default_hp_metric=False)
    ckpt_dir = Path(logger.log_dir) / 'checkpoints'

    # Callbacks
    # TODO: update ModelCheckpoint to monitor multiple metrics
    ckpt_callback = ModelCheckpoint(monitor='auc@10',
                                    verbose=True,
                                    save_top_k=5,
                                    mode='max',
                                    save_last=True,
                                    dirpath=str(ckpt_dir),
                                    filename='{epoch}-{auc@5:.3f}-{auc@10:.3f}-{auc@20:.3f}')
    lr_monitor = LearningRateMonitor(logging_interval='step')
    callbacks = [lr_monitor]
    if not args.disable_ckpt:
        callbacks.append(ckpt_callback)

    # Lightning Trainer
    trainer = pl.Trainer.from_argparse_args(
        args,
        plugins=DDPPlugin(find_unused_parameters=False,
                          num_nodes=args.num_nodes,
                          sync_batchnorm=config.TRAINER.WORLD_SIZE > 0),
        gradient_clip_val=config.TRAINER.GRADIENT_CLIPPING,
        callbacks=callbacks,
        logger=logger,
        sync_batchnorm=config.TRAINER.WORLD_SIZE > 0,
        replace_sampler_ddp=False,  # use custom sampler
        reload_dataloaders_every_epoch=False,  # avoid repeated samples!
        weights_summary='full',
        profiler=profiler)
    loguru_logger.info("Trainer initialized!")
    loguru_logger.info("Start training!")
    trainer.fit(model, datamodule=data_module)
def cli_main():
    # ------------
    # args
    # ------------
    parser = argparse.ArgumentParser()
    parser.add_argument('--config',
                        action='store',
                        dest='config',
                        help='config.yaml',
                        required=True)
    parser.add_argument('--ckpt',
                        action='store',
                        dest='ckpt',
                        help='checkpoint to load',
                        required=True)
    args = parser.parse_args()
    with open(args.config, 'r') as ymlfile:
        config = yaml.load(ymlfile, Loader=yaml.FullLoader)
    config = DotMap(config)
    assert (config.name in [
        "lstur", "nrms", "naml", "naml_simple", "sentirec", "robust_sentirec"
    ])

    pl.seed_everything(1234)

    # ------------
    # logging
    # ------------
    logger = TensorBoardLogger(**config.logger)

    # ------------
    # data
    # ------------
    test_dataset = BaseDataset(path.join(config.test_behavior),
                               path.join(config.test_news), config)
    test_loader = DataLoader(test_dataset, **config.test_dataloader)
    # print(len(dataset), len(train_dataset), len(val_dataset))

    # ------------
    # init model
    # ------------
    # load pre-trained embedding weights
    embedding_weights = []
    with open(config.embedding_weights, 'r') as file:
        lines = file.readlines()
        for line in tqdm(lines):
            weights = [float(w) for w in line.split(" ")]
            embedding_weights.append(weights)
    pretrained_word_embedding = torch.from_numpy(
        np.array(embedding_weights, dtype=np.float32))

    if config.name == "lstur":
        model = LSTUR.load_from_checkpoint(
            args.ckpt,
            config=config,
            pretrained_word_embedding=pretrained_word_embedding)
    elif config.name == "nrms":
        model = NRMS.load_from_checkpoint(
            args.ckpt,
            config=config,
            pretrained_word_embedding=pretrained_word_embedding)
    elif config.name == "naml":
        model = NAML.load_from_checkpoint(
            args.ckpt,
            config=config,
            pretrained_word_embedding=pretrained_word_embedding)
    elif config.name == "naml_simple":
        model = NAML_Simple.load_from_checkpoint(
            args.ckpt,
            config=config,
            pretrained_word_embedding=pretrained_word_embedding)
    elif config.name == "sentirec":
        model = SENTIREC.load_from_checkpoint(
            args.ckpt,
            config=config,
            pretrained_word_embedding=pretrained_word_embedding)
    elif config.name == "robust_sentirec":
        model = ROBUST_SENTIREC.load_from_checkpoint(
            args.ckpt,
            config=config,
            pretrained_word_embedding=pretrained_word_embedding)
    # elif: # UPCOMING MODELS

    # ------------
    # test
    # ------------
    trainer = Trainer(**config.trainer,
                      logger=logger,
                      plugins=DDPPlugin(find_unused_parameters=False))
    trainer.test(model=model, test_dataloaders=test_loader)
checkpoint_callback = ModelCheckpoint(
    dirpath=args.save_dir,
    filename='{epoch}-{val_loss:.3f}-{train_loss:.3f}',
    save_top_k=-1)
logger = CometLogger(
    api_key="YOUR-API-KEY",
    project_name=proj_name,
)
model = lit_gazetrack_model(args.dataset_dir, args.save_dir, args.batch_size, logger)
if args.checkpoint:
    if args.gpus == 0:
        w = torch.load(args.checkpoint, map_location=torch.device('cpu'))['state_dict']
    else:
        w = torch.load(args.checkpoint)['state_dict']
    model.load_state_dict(w)
    print("Loaded checkpoint")

trainer = pl.Trainer(gpus=args.gpus,
                     logger=logger,
                     accelerator="ddp",
                     max_epochs=args.epochs,
                     default_root_dir=args.save_dir,
                     progress_bar_refresh_rate=1,
                     callbacks=[checkpoint_callback],
                     plugins=DDPPlugin(find_unused_parameters=False))
trainer.fit(model)
print("DONE")
def process(args):
    torch.multiprocessing.set_sharing_strategy('file_system')

    # Pretraining data
    if args.dataset == "ZINC5k":
        dataset = ZINC5K("../data/torchdrug/molecule-datasets/",
                         node_feature="pretrain",
                         edge_feature="pretrain",
                         lazy=True)
    elif args.dataset == "ZINC250k":
        dataset = datasets.ZINC250k("../data/torchdrug/molecule-datasets/",
                                    node_feature="pretrain",
                                    edge_feature="pretrain",
                                    lazy=True)
    elif args.dataset == "ZINC2m":  # defaults to lazy load
        dataset = datasets.ZINC2m("../data/torchdrug/molecule-datasets/",
                                  node_feature="pretrain",
                                  edge_feature="pretrain")

    # CTRP smiles to embed
    ctrp = pd.read_csv("../data/drug_screens/CTRP/v20.meta.per_compound.txt", sep="\t")
    ctrp_ds = MoleculeDataset()
    ctrp_ds.load_smiles(smiles_list=ctrp['cpd_smiles'],
                        targets=dict(),
                        node_feature='pretrain',
                        edge_feature='pretrain')

    # Self-supervised pretraining
    dm = ChemGraphDataModule.from_argparse_args(args, train=dataset, predict=ctrp_ds)
    model = ChemGraphEmbeddingNetwork(task=args.task,
                                      input_dim=dataset.node_feature_dim,
                                      hidden_dims=[512] * 5,
                                      edge_input_dim=dataset.edge_feature_dim,
                                      batch_norm=True,
                                      readout="mean",
                                      mask_rate=0.15)

    # Callbacks
    fname = f"{args.name}_{args.task}_{args.dataset}"
    logger = TensorBoardLogger(save_dir=args.default_root_dir, version=fname, name='lightning_logs')
    early_stop = EarlyStopping(monitor='accuracy', min_delta=0.001, patience=5, verbose=False, mode='max')
    checkpoint_callback = ModelCheckpoint(monitor='accuracy', mode='max')
    trainer = Trainer.from_argparse_args(
        args,
        default_root_dir=logger.log_dir,
        logger=logger,
        callbacks=[early_stop, checkpoint_callback],
        strategy=DDPPlugin(find_unused_parameters=False),
        profiler='simple')
    trainer.fit(model, dm)

    # Generate CTRP embeddings
    model.to('cpu')
    model.eval()
    dl = DataLoader(ctrp_ds, batch_size=len(ctrp_ds))
    graph_embeds = []
    node_embeds = []
    for batch in dl:
        graph_feature, node_feature = model(batch)
        graph_embeds.append(graph_feature.detach())
        node_embeds.append(node_feature.detach())
    graph_embeds = torch.cat(graph_embeds).numpy()
    node_embeds = torch.cat(node_embeds).numpy()

    # Write out
    node_cpd_ids = [
        np.repeat(cpd_id, n['graph'].num_node)
        for n, cpd_id in zip(ctrp_ds, ctrp['broad_cpd_id'])
    ]
    node_cpd_ids = np.concatenate(node_cpd_ids)
    node_embeds = pd.DataFrame(node_embeds, index=node_cpd_ids)
    node_embeds['atom_type'] = np.concatenate(
        [[ATOM_SYMBOL[a] for a in n['graph'].atom_type] for n in ctrp_ds])
    graph_embeds = pd.DataFrame(graph_embeds, index=ctrp['broad_cpd_id'])
    node_embeds.to_csv(
        f"../data/torchdrug/molecule-datasets/{fname}_ctrp_node_embeds.csv", sep=",")
    graph_embeds.to_csv(
        f"../data/torchdrug/molecule-datasets/{fname}_ctrp_graph_embeds.csv", sep=",")
def train_model(model, model_dir):
    # Setup trainer
    tb_logger = pl_loggers.TensorBoardLogger('{}/logs/'.format(model_dir))
    chkpt1 = ModelCheckpoint(save_last=True)
    chkpt2 = ModelCheckpoint(every_n_train_steps=10000)  # save every 10000 steps
    if Constants.n_gpus != 0:
        trainer = Trainer(gpus=Constants.n_gpus,
                          callbacks=[chkpt1, chkpt2],
                          accelerator='ddp_spawn',
                          plugins=DDPPlugin(find_unused_parameters=False),
                          precision=16,
                          logger=tb_logger,
                          default_root_dir=model_dir,
                          max_epochs=n_epochs)
    else:
        trainer = Trainer(gpus=0,
                          default_root_dir=model_dir,
                          logger=tb_logger,
                          callbacks=[chkpt1, chkpt2],
                          max_epochs=n_epochs)
    trainer.fit(model)
def train_model(
    train_config: TrainConfig,
    video_loader_config: Optional[VideoLoaderConfig] = None,
):
    """Trains a model.

    Args:
        train_config (TrainConfig): Pydantic config for training.
        video_loader_config (VideoLoaderConfig, optional): Pydantic config for preprocessing videos.
            If None, will use default for model specified in TrainConfig.
    """
    # get default VLC for model if not specified
    if video_loader_config is None:
        video_loader_config = ModelConfig(
            train_config=train_config, video_loader_config=video_loader_config
        ).video_loader_config

    # set up model
    model = instantiate_model(
        checkpoint=train_config.checkpoint,
        scheduler_config=train_config.scheduler_config,
        weight_download_region=train_config.weight_download_region,
        model_cache_dir=train_config.model_cache_dir,
        labels=train_config.labels,
        from_scratch=train_config.from_scratch,
        model_name=train_config.model_name,
        predict_all_zamba_species=train_config.predict_all_zamba_species,
    )

    data_module = ZambaDataModule(
        video_loader_config=video_loader_config,
        transform=MODEL_MAPPING[model.__class__.__name__]["transform"],
        train_metadata=train_config.labels,
        batch_size=train_config.batch_size,
        num_workers=train_config.num_workers,
    )

    validate_species(model, data_module)

    train_config.save_dir.mkdir(parents=True, exist_ok=True)

    # add folder version_n that auto increments if we are not overwriting
    tensorboard_version = train_config.save_dir.name if train_config.overwrite else None
    tensorboard_save_dir = (
        train_config.save_dir.parent if train_config.overwrite else train_config.save_dir
    )

    tensorboard_logger = TensorBoardLogger(
        save_dir=tensorboard_save_dir,
        name=None,
        version=tensorboard_version,
        default_hp_metric=False,
    )

    logging_and_save_dir = (
        tensorboard_logger.log_dir if not train_config.overwrite else train_config.save_dir
    )

    model_checkpoint = ModelCheckpoint(
        dirpath=logging_and_save_dir,
        filename=train_config.model_name,
        monitor=train_config.early_stopping_config.monitor
        if train_config.early_stopping_config is not None
        else None,
        mode=train_config.early_stopping_config.mode
        if train_config.early_stopping_config is not None
        else "min",
    )

    callbacks = [model_checkpoint]

    if train_config.early_stopping_config is not None:
        callbacks.append(EarlyStopping(**train_config.early_stopping_config.dict()))

    if train_config.backbone_finetune_config is not None:
        callbacks.append(BackboneFinetuning(**train_config.backbone_finetune_config.dict()))

    trainer = pl.Trainer(
        gpus=train_config.gpus,
        max_epochs=train_config.max_epochs,
        auto_lr_find=train_config.auto_lr_find,
        logger=tensorboard_logger,
        callbacks=callbacks,
        fast_dev_run=train_config.dry_run,
        accelerator="ddp" if data_module.multiprocessing_context is not None else None,
        plugins=DDPPlugin(find_unused_parameters=False)
        if data_module.multiprocessing_context is not None
        else None,
    )

    if video_loader_config.cache_dir is None:
        logger.info("No cache dir is specified. Videos will not be cached.")
    else:
        logger.info(f"Videos will be cached to {video_loader_config.cache_dir}.")

    if train_config.auto_lr_find:
        logger.info("Finding best learning rate.")
        trainer.tune(model, data_module)

    try:
        git_hash = git.Repo(search_parent_directories=True).head.object.hexsha
    except git.exc.InvalidGitRepositoryError:
        git_hash = None

    configuration = {
        "git_hash": git_hash,
        "model_class": model.model_class,
        "species": model.species,
        "starting_learning_rate": model.lr,
        "train_config": json.loads(train_config.json(exclude={"labels"})),
        "training_start_time": datetime.utcnow().isoformat(),
        "video_loader_config": json.loads(video_loader_config.json()),
    }

    if not train_config.dry_run:
        config_path = Path(logging_and_save_dir) / "train_configuration.yaml"
        config_path.parent.mkdir(exist_ok=True, parents=True)
        logger.info(f"Writing out full configuration to {config_path}.")
        with config_path.open("w") as fp:
            yaml.dump(configuration, fp)

    logger.info("Starting training...")
    trainer.fit(model, data_module)

    if not train_config.dry_run:
        if trainer.datamodule.test_dataloader() is not None:
            logger.info("Calculating metrics on holdout set.")
            test_metrics = trainer.test(dataloaders=trainer.datamodule.test_dataloader())[0]
            with (Path(logging_and_save_dir) / "test_metrics.json").open("w") as fp:
                json.dump(test_metrics, fp, indent=2)

        if trainer.datamodule.val_dataloader() is not None:
            logger.info("Calculating metrics on validation set.")
            val_metrics = trainer.validate(dataloaders=trainer.datamodule.val_dataloader())[0]
            with (Path(logging_and_save_dir) / "val_metrics.json").open("w") as fp:
                json.dump(val_metrics, fp, indent=2)

    return trainer
def cli_main():
    # ------------
    # args
    # ------------
    parser = argparse.ArgumentParser()
    parser.add_argument('--config',
                        action='store',
                        dest='config',
                        help='config.yaml',
                        required=True)
    parser.add_argument('--resume',
                        action='store',
                        dest='resume',
                        help='resume training from ckpt',
                        required=False)
    args = parser.parse_args()
    with open(args.config, 'r') as ymlfile:
        config = yaml.load(ymlfile, Loader=yaml.FullLoader)
    config = DotMap(config)
    assert (config.name in [
        "lstur", "nrms", "naml", "naml_simple", "sentirec", "robust_sentirec"
    ])

    pl.seed_everything(1234)

    # ------------
    # init callbacks & logging
    # ------------
    checkpoint_callback = ModelCheckpoint(**config.checkpoint)
    logger = TensorBoardLogger(**config.logger)

    # ------------
    # data
    # ------------
    train_dataset = BaseDataset(path.join(config.train_behavior),
                                path.join(config.train_news), config)
    val_dataset = BaseDataset(path.join(config.val_behavior),
                              path.join(config.train_news), config)
    train_loader = DataLoader(train_dataset, **config.train_dataloader)
    val_loader = DataLoader(val_dataset, **config.val_dataloader)

    # ------------
    # init model
    # ------------
    # load pre-trained embedding weights
    embedding_weights = []
    with open(config.embedding_weights, 'r') as file:
        lines = file.readlines()
        for line in tqdm(lines):
            weights = [float(w) for w in line.split(" ")]
            embedding_weights.append(weights)
    pretrained_word_embedding = torch.from_numpy(
        np.array(embedding_weights, dtype=np.float32))

    if config.name == "lstur":
        model = LSTUR(config, pretrained_word_embedding)
    elif config.name == "nrms":
        model = NRMS(config, pretrained_word_embedding)
    elif config.name == "naml":
        model = NAML(config, pretrained_word_embedding)
    elif config.name == "naml_simple":
        model = NAML_Simple(config, pretrained_word_embedding)
    elif config.name == "sentirec":
        model = SENTIREC(config, pretrained_word_embedding)
    elif config.name == "robust_sentirec":
        model = ROBUST_SENTIREC(config, pretrained_word_embedding)
    # elif: # UPCOMING MODELS

    # ------------
    # training
    # ------------
    early_stop_callback = EarlyStopping(**config.early_stop)
    if args.resume is not None:
        model = model.load_from_checkpoint(
            args.resume,
            config=config,
            pretrained_word_embedding=pretrained_word_embedding)
        trainer = Trainer(**config.trainer,
                          callbacks=[early_stop_callback, checkpoint_callback],
                          logger=logger,
                          plugins=DDPPlugin(
                              find_unused_parameters=config.find_unused_parameters),
                          resume_from_checkpoint=args.resume)
    else:
        trainer = Trainer(**config.trainer,
                          callbacks=[early_stop_callback, checkpoint_callback],
                          logger=logger,
                          plugins=DDPPlugin(
                              find_unused_parameters=config.find_unused_parameters))
    trainer.fit(model=model,
                train_dataloader=train_loader,
                val_dataloaders=val_loader)
def run(config):
    # build hooks
    loss_fn = build_loss(config)
    metric_fn = build_metrics(config)
    hooks = build_hooks(config)
    hooks.update({"loss_fn": loss_fn, "metric_fn": metric_fn})

    # build model
    model = build_model(config)

    # build callbacks
    callbacks = build_callbacks(config)

    # build logger
    logger = build_logger(config)

    # debug
    if config.debug:
        logger = None
        OmegaConf.set_struct(config, True)
        with open_dict(config):
            config.trainer.trainer.max_epochs = None
            config.trainer.trainer.max_steps = 10

    # logging for wandb or mlflow
    if hasattr(logger, "log_hyperparams"):
        for k, v in config.trainer.items():
            if k not in ("metrics", "inference"):
                logger.log_hyperparams(params=v)
        logger.log_hyperparams(params=config.dataset)
        logger.log_hyperparams(params=config.augmentation)

    # last linear training
    if (hasattr(config.trainer.model, "last_linear")
            and config.trainer.model.last_linear.training
            and config.trainer.model.params.pretrained):
        model = train_last_linear(config, model, hooks, logger)

    # initialize model
    model, params = kvt.utils.initialize_model(config, model)

    # build optimizer
    optimizer = build_optimizer(config, model=model, params=params)

    # build scheduler
    scheduler = build_scheduler(config, optimizer=optimizer)

    # build dataloaders
    dataloaders = build_dataloaders(config)

    # build strong transform
    strong_transform, storong_transform_p = build_strong_transform(config)

    # build lightning module
    lightning_module = build_lightning_module(
        config,
        model=model,
        optimizer=optimizer,
        scheduler=scheduler,
        hooks=hooks,
        dataloaders=dataloaders,
        strong_transform=strong_transform,
        storong_transform_p=storong_transform_p,
    )

    # build plugins; workaround for:
    # https://github.com/PyTorchLightning/pytorch-lightning/discussions/6219
    plugins = []
    if hasattr(config.trainer.trainer, "accelerator") and (
            config.trainer.trainer.accelerator in ("ddp", "ddp2")):
        if hasattr(config.trainer, "find_unused_parameters"):
            plugins.append(
                DDPPlugin(find_unused_parameters=config.trainer.find_unused_parameters))
        else:
            plugins.append(DDPPlugin(find_unused_parameters=False))

    # best model path
    dir_path = config.trainer.callbacks.ModelCheckpoint.dirpath
    if isinstance(OmegaConf.to_container(config.dataset.dataset), list):
        idx_fold = config.dataset.dataset[0].params.idx_fold
    else:
        idx_fold = config.dataset.dataset.params.idx_fold
    filename = f"fold_{idx_fold}_best.ckpt"
    best_model_path = os.path.join(dir_path, filename)

    # train loop
    trainer = pl.Trainer(
        logger=logger,
        callbacks=callbacks,
        plugins=plugins,
        **config.trainer.trainer,
    )
    if not config.trainer.skip_training:
        trainer.fit(lightning_module)
        path = trainer.checkpoint_callback.best_model_path
        if path:
            print(f"Best model: {path}")
            print("Renaming...")
            # copy best model
            subprocess.run(f"mv {path} {best_model_path}",
                           shell=True,
                           stdout=PIPE,
                           stderr=PIPE)
        else:
            # if there is no best_model_path, e.g. no valid dataloader
            print("Saving current trainer...")
            trainer.save_checkpoint(best_model_path)

    # log best model
    if hasattr(logger, "log_hyperparams"):
        logger.log_hyperparams(params={"best_model_path": best_model_path})

    # load best checkpoint
    if os.path.exists(best_model_path):
        print(f"Loading best model: {best_model_path}")
        state_dict = torch.load(best_model_path)["state_dict"]
        # if using dp, it is necessary to fix state dict keys
        if (hasattr(config.trainer.trainer, "sync_batchnorm")
                and config.trainer.trainer.sync_batchnorm):
            state_dict = kvt.utils.fix_dp_model_state_dict(state_dict)
        lightning_module.model.load_state_dict(state_dict)
    else:
        print(f"Best model {best_model_path} does not exist.")

    # evaluate
    metric_dict = evaluate(lightning_module, hooks, config, mode=["validation"])
    print("Result:")
    print(metric_dict)
    if hasattr(logger, "log_metrics"):
        logger.log_metrics(metric_dict)
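# NOTE: the GitHub discussion referenced in run() above is about
# find_unused_parameters defaulting to True and slowing DDP training.
# Stripped of the config machinery, the workaround reduces to passing an
# explicit DDPPlugin (a minimal sketch, not part of the original code):
import pytorch_lightning as pl
from pytorch_lightning.plugins import DDPPlugin

trainer = pl.Trainer(
    accelerator="ddp",
    gpus=2,
    plugins=[DDPPlugin(find_unused_parameters=False)],
)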
    # (tail of the LitAutoEncoder class definition)
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer


dataset = MNIST(os.getcwd(), download=False, transform=transforms.ToTensor())
train_loader = DataLoader(dataset)

# init model
autoencoder = LitAutoEncoder()

# most basic trainer, uses good defaults (auto-tensorboard, checkpoints, logs, and more)
parallel_devices = [torch.device(i) for i in range(torch.cuda.device_count())]
acc = GPUAccelerator(precision_plugin=NativeMixedPrecisionPlugin(),
                     training_type_plugin=DDPPlugin(
                         parallel_devices=parallel_devices,
                         cluster_environment=LSFEnvironment()))

targs = {
    'max_epochs': 1,
    'num_nodes': 2,
    'accumulate_grad_batches': 1,
    'gpus': 6,
    'accelerator': acc,
    'limit_train_batches': 10,
    'limit_val_batches': 5,
    'log_every_n_steps': 1
}
# trainer = pl.Trainer(gpus=8)  (if you have GPUs)
trainer = pl.Trainer(**targs)
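# NOTE: the snippet above stops after constructing the Trainer; the presumed
# next step (not in the original) would launch training with the model and
# loader defined earlier:
trainer.fit(autoencoder, train_loader)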
@RunIf(min_gpus=2)
@mock.patch.dict(
    os.environ,
    {
        "CUDA_VISIBLE_DEVICES": "0,1",
        "SLURM_NTASKS": "2",
        "SLURM_JOB_NAME": "SOME_NAME",
        "SLURM_NODEID": "0",
        "SLURM_PROCID": "1",
        "SLURM_LOCALID": "1",
    },
)
@mock.patch("pytorch_lightning.plugins.DDPPlugin.setup_distributed", autospec=True)
@pytest.mark.parametrize("strategy", ["ddp", DDPPlugin()])
def test_strategy_choice_ddp_slurm(setup_distributed_mock, strategy):
    class CB(Callback):
        def on_fit_start(self, trainer, pl_module):
            assert trainer._accelerator_connector._is_slurm_managing_tasks()
            assert isinstance(trainer.accelerator, GPUAccelerator)
            assert isinstance(trainer.training_type_plugin, DDPPlugin)
            assert isinstance(trainer.training_type_plugin.cluster_environment, SLURMEnvironment)
            assert trainer.training_type_plugin.cluster_environment.local_rank() == 1
            assert trainer.training_type_plugin.local_rank == 1
            raise SystemExit()

    model = BoringModel()
    trainer = Trainer(fast_dev_run=True, strategy=strategy, gpus=2, callbacks=[CB()])

    with pytest.raises(SystemExit):
        trainer.fit(model)
def create_lightning_trainer(container: LightningContainer,
                             resume_from_checkpoint: Optional[Path] = None,
                             num_nodes: int = 1,
                             multiple_trainloader_mode: str = "max_size_cycle") -> \
        Tuple[Trainer, StoringLogger]:
    """
    Creates a Pytorch Lightning Trainer object for the given model configuration. It creates checkpoint
    handlers and loggers. That includes a diagnostic logger for use in unit tests, which is also returned
    as the second return value.
    :param container: The container with model and data.
    :param resume_from_checkpoint: If provided, training resumes from this checkpoint.
    :param num_nodes: The number of nodes to use in distributed training.
    :param multiple_trainloader_mode: The mode for iterating multiple training dataloaders, passed through
        to the Trainer.
    :return: A tuple [Trainer object, diagnostic logger]
    """
    logging.debug(f"resume_from_checkpoint: {resume_from_checkpoint}")
    num_gpus = container.num_gpus_per_node()
    effective_num_gpus = num_gpus * num_nodes
    strategy = None
    if effective_num_gpus == 0:
        accelerator = "cpu"
        devices = 1
        message = "CPU"
    else:
        accelerator = "gpu"
        devices = num_gpus
        message = f"{devices} GPU"
        if effective_num_gpus > 1:
            # Accelerator should be "ddp" when running large models in AzureML (when using DDP_spawn,
            # we get out of GPU memory).
            # Initialize the DDP plugin. The default for pl_find_unused_parameters is False. If True,
            # the plugin prints out lengthy warnings about the performance impact of find_unused_parameters.
            strategy = DDPPlugin(find_unused_parameters=container.pl_find_unused_parameters)
            message += "s per node with DDP"
    logging.info(f"Using {message}")
    tensorboard_logger = TensorBoardLogger(save_dir=str(container.logs_folder), name="Lightning", version="")
    loggers = [tensorboard_logger, AzureMLLogger(False)]
    storing_logger = StoringLogger()
    loggers.append(storing_logger)
    # Use 32bit precision when running on CPU. Otherwise, make it depend on the use_mixed_precision flag.
    precision = 32 if num_gpus == 0 else 16 if container.use_mixed_precision else 32
    # The next two flags control the settings in torch.backends.cudnn.deterministic and
    # torch.backends.cudnn.benchmark, see https://pytorch.org/docs/stable/notes/randomness.html
    # Note that switching to deterministic models can have a large performance downside.
    if container.pl_deterministic:
        deterministic = True
        benchmark = False
    else:
        deterministic = False
        benchmark = True
    # The last checkpoint is considered the "best" checkpoint. For large segmentation models, this still
    # appears to be the best way of choosing them because validation loss on the relatively small training
    # patches is not stable enough. Going by the validation loss somehow works for the Prostate model, but
    # not for the HeadAndNeck model.
    # Note that "last" is somewhat of a misnomer; it should rather be "latest". There is a "last" checkpoint
    # written in every epoch. We could use that for recovery too, but the job could get preempted right while
    # writing that file, and we would end up with an invalid file.
    last_checkpoint_callback = ModelCheckpoint(dirpath=str(container.checkpoint_folder),
                                               save_last=True,
                                               save_top_k=0)
    recovery_checkpoint_callback = ModelCheckpoint(dirpath=str(container.checkpoint_folder),
                                                   filename=AUTOSAVE_CHECKPOINT_FILE_NAME,
                                                   every_n_val_epochs=container.autosave_every_n_val_epochs,
                                                   save_last=False)
    callbacks: List[Callback] = [
        last_checkpoint_callback,
        recovery_checkpoint_callback,
    ]
    if container.monitor_loading:
        # TODO antonsc: Remove after fixing the callback.
        raise NotImplementedError("Monitoring batch loading times has been temporarily disabled.")
        # callbacks.append(BatchTimeCallback())
    if num_gpus > 0 and container.monitor_gpu:
        logging.info("Adding monitoring for GPU utilization")
        callbacks.append(GPUStatsMonitor(intra_step_time=True, inter_step_time=True))
    # Add the additional callbacks that were specified in get_trainer_arguments for LightningContainers
    additional_args = container.get_trainer_arguments()
    # Callbacks can be specified via the "callbacks" argument (the legacy behaviour) or the new
    # get_callbacks method
    if "callbacks" in additional_args:
        more_callbacks = additional_args.pop("callbacks")
        if isinstance(more_callbacks, list):
            callbacks.extend(more_callbacks)  # type: ignore
        else:
            callbacks.append(more_callbacks)  # type: ignore
    callbacks.extend(container.get_callbacks())
    is_azureml_run = not is_offline_run_context(RUN_CONTEXT)
    progress_bar_refresh_rate = container.pl_progress_bar_refresh_rate
    if progress_bar_refresh_rate is None:
        progress_bar_refresh_rate = 50
        logging.info(f"The progress bar refresh rate is not set. Using a default of {progress_bar_refresh_rate}. "
                     f"To change, modify the pl_progress_bar_refresh_rate field of the container.")
    if is_azureml_run:
        callbacks.append(AzureMLProgressBar(refresh_rate=progress_bar_refresh_rate,
                                            write_to_logging_info=True,
                                            print_timestamp=False))
    else:
        callbacks.append(TQDMProgressBar(refresh_rate=progress_bar_refresh_rate))
    # Read out additional model-specific args here.
    # We probably want to keep essential ones like num_gpus and logging.
    trainer = Trainer(default_root_dir=str(container.outputs_folder),
                      deterministic=deterministic,
                      benchmark=benchmark,
                      accelerator=accelerator,
                      strategy=strategy,
                      max_epochs=container.num_epochs,
                      # Both these arguments can be integers or floats. If integers, it is the number of
                      # batches. If float, it's the fraction of batches. We default to 1.0 (all batches).
                      limit_train_batches=container.pl_limit_train_batches or 1.0,
                      limit_val_batches=container.pl_limit_val_batches or 1.0,
                      num_sanity_val_steps=container.pl_num_sanity_val_steps,
                      check_val_every_n_epoch=container.pl_check_val_every_n_epoch,
                      callbacks=callbacks,
                      logger=loggers,
                      num_nodes=num_nodes,
                      devices=devices,
                      precision=precision,
                      sync_batchnorm=True,
                      detect_anomaly=container.detect_anomaly,
                      profiler=container.pl_profiler,
                      resume_from_checkpoint=str(resume_from_checkpoint) if resume_from_checkpoint else None,
                      multiple_trainloader_mode=multiple_trainloader_mode,
                      **additional_args)
    return trainer, storing_logger
def main(args): backbone = "bert-base-uncased-itokens" tokenizer = BertTokenizerFast.from_pretrained(backbone) # encoder_decoder_config = EncoderDecoderConfig.from_pretrained("bert-base-uncased-itokens") # model = EncoderDecoderModel.from_pretrained( # "bert-base-uncased-itokens", config=encoder_decoder_config # ) # model = EncoderDecoderModel.from_encoder_decoder_pretrained( # "bert-base-uncased-itokens", "bert-base-uncased-itokens", tie_encoder_decoder=True # ) # generator = Generator(model) # discriminator = Discriminator( # AutoModel.from_pretrained("bert-base-uncased-itokens") # ) if args.test: model = GAN.load_from_checkpoint(args.load_checkpoint, args=args, tokenizer=tokenizer, backbone=backbone) model.cuda() model.eval() model.inference(args.scene_graphs_json) return # train if args.gpus > 1: dm = VGDataModule(args, tokenizer, 2) else: dm = VGDataModule(args, tokenizer) if args.load_checkpoint != "": model = GAN.load_from_checkpoint(args.load_checkpoint, args=args, tokenizer=tokenizer, backbone=backbone) else: model = GAN(args, tokenizer, backbone) training_args = { "gpus": args.gpus, "fast_dev_run": False, "max_steps": args.num_iterations, "precision": 32, "gradient_clip_val": 1, } if args.gpus > 1: additional_args = { "accelerator": "ddp", "plugins": [DDPPlugin(find_unused_parameters=True)] # "plugins": [my_ddp] } training_args.update(additional_args) trainer = pl.Trainer(**training_args) trainer.fit(model, dm)
def main(): """ Main training loop. """ parser = ArgumentParser() parser = UNet.add_model_specific_args(parser) parser = Trainer.add_argparse_args(parser) args = parser.parse_args() prod = bool(os.getenv("PROD")) logging.getLogger(__name__).setLevel(logging.INFO) if prod: logging.info( "Training i production mode, disabling all debugging APIs") torch.autograd.set_detect_anomaly(False) torch.autograd.profiler.profile(enabled=False) torch.autograd.profiler.emit_nvtx(enabled=False) else: logging.info("Training i development mode, debugging APIs active.") torch.autograd.set_detect_anomaly(True) torch.autograd.profiler.profile(enabled=True, use_cuda=True, record_shapes=True, profile_memory=True) torch.autograd.profiler.emit_nvtx(enabled=True, record_shapes=True) model = UNet(**vars(args)) logging.info( f"Network:\n" f"\t{model.hparams.n_channels} input channels\n" f"\t{model.hparams.n_classes} output channels (classes)\n" f'\t{"Bilinear" if model.hparams.bilinear else "Transposed conv"} upscaling' ) cudnn.benchmark = True # cudnn Autotuner cudnn.enabled = True # look for optimal algorithms early_stop_callback = EarlyStopping( monitor="val_loss", min_delta=0.00, mode="min", patience=10 if not os.getenv("EARLY_STOP") else int( os.getenv("EARLY_STOP")), verbose=True, ) lr_monitor = LearningRateMonitor() run_name = "{}_LR{}_BS{}_IS{}".format( datetime.now().strftime("%d-%m-%Y-%H-%M-%S"), args.lr, args.batch_size, args.image_size, ).replace(".", "_") log_folder = ("./logs" if not os.getenv("DIR_ROOT_DIR") else os.getenv("DIR_ROOT_DIR")) if not os.path.isdir(log_folder): os.mkdir(log_folder) logger = TensorBoardLogger(log_folder, name=run_name) checkpoint_callback = ModelCheckpoint( monitor='val_loss', dirpath='./checkpoints', filename='unet-{epoch:02d}-{val_loss:.2f}', save_top_k=3, mode='min', ) try: trainer = Trainer.from_argparse_args( args, gpus=-1, accelerator="ddp", plugins=DDPPlugin(find_unused_parameters=False), precision=16, auto_lr_find="learning_rate" if float(os.getenv("LRN_RATE")) == 0.0 else False, logger=logger, callbacks=[early_stop_callback, lr_monitor, checkpoint_callback], accumulate_grad_batches=1.0 if not os.getenv("ACC_GRAD") else int( os.getenv("ACC_GRAD")), gradient_clip_val=0.0 if not os.getenv("GRAD_CLIP") else float( os.getenv("GRAD_CLIP")), max_epochs=100 if not os.getenv("EPOCHS") else int( os.getenv("EPOCHS")), val_check_interval=0.1 if not os.getenv("VAL_INT_PER") else float( os.getenv("VAL_INT_PER")), default_root_dir=os.getcwd() if not os.getenv("DIR_ROOT_DIR") else os.getenv("DIR_ROOT_DIR"), fast_dev_run=True if os.getenv("FAST_DEV_RUN") == "True" else False, ) if float(os.getenv("LRN_RATE")) == 0.0: trainer.tune(model) trainer.fit(model) trainer.test(model) except KeyboardInterrupt: torch.save(model.state_dict(), "INTERRUPTED.pth") logging.info("Saved interrupt") try: sys.exit(0) except SystemExit: os._exit(0)
def run(cfg: DictConfig):
    local_rank = int(os.environ.get('LOCAL_RANK', 0))

    # The logs & checkpoints are dumped in: ${cfg.output_dir}/${cfg.experiment_name}/vN, where vN
    # is v0, v1, .... The version number increases automatically.
    script_dir = Path.cwd()
    experiment_dir = script_dir / cfg.output_dir / cfg.experiment_name
    experiment_dir.mkdir(parents=True, exist_ok=True)
    existing_ver = list()
    for d in experiment_dir.iterdir():
        if d.name.startswith('v') and d.name[1:].isdecimal() and d.is_dir():
            existing_ver.append(int(d.name[1:]))
    if local_rank == 0:
        current_ver = max(existing_ver) + 1 if existing_ver else 0
        output_dir = experiment_dir / f'v{current_ver}'
        output_dir.mkdir()
    else:
        # Use the same output directory as the main process.
        current_ver = max(existing_ver)
        output_dir = experiment_dir / f'v{current_ver}'

    pl_logger = logging.getLogger('lightning')
    logging.config.fileConfig(
        script_dir / 'logging.conf',
        disable_existing_loggers=False,
        defaults={'log_filename': output_dir / f'run_rank{local_rank}.log'})
    # Only the process with LOCAL_RANK = 0 prints logs on the console.
    # All processes write logs to their own log files.
    if local_rank != 0:
        root_logger = logging.getLogger()
        root_logger.removeHandler(root_logger.handlers[0])
    pl_logger.info(f'Output logs & checkpoints in: {output_dir}')

    # Dump experiment configurations for reproducibility
    if local_rank == 0:
        with open(output_dir / 'cfg.yaml', 'w') as yaml_file:
            yaml_file.write(OmegaConf.to_yaml(cfg))
        pl_logger.info('The final experiment setup is dumped as: ./cfg.yaml')

    pl.seed_everything(cfg.seed, workers=True)

    # Create model
    net = load_obj(cfg.model.class_name, 'torchvision.models')(**cfg.model.params)
    pl_logger.info(f'Create model "{type(net)}". You can view its graph using TensorBoard.')

    # Inject quantizers into the model
    net = nz.quantizer_inject(net, cfg.quan)
    quan_cnt, quan_dict = nz.quantizer_stat(net)
    msg = f'Inject {quan_cnt} quantizers into the model:'
    for k, v in quan_dict.items():
        msg += f'\n {k} = {len(v)}'
    yaml.safe_dump(quan_dict, open(output_dir / 'quan_stat.yaml', 'w'))
    pl_logger.info(msg)
    pl_logger.info('A complete list of injected quantizers is dumped as: ./quan_stat.yaml')

    # Prepare the dataset
    dm = apputil.get_datamodule(cfg)
    pl_logger.info(f'Prepare the "{cfg.dataset.name}" dataset from: {cfg.dataset.data_dir}')
    msg = f'The dataset samples are split into three sets:' \
          f'\n Train = {len(dm.train_dataloader())} batches (batch size = {dm.train_dataloader().batch_size})' \
          f'\n Val = {len(dm.val_dataloader())} batches (batch size = {dm.val_dataloader().batch_size})' \
          f'\n Test = {len(dm.test_dataloader())} batches (batch size = {dm.test_dataloader().batch_size})'
    pl_logger.info(msg)

    progressbar_cb = apputil.ProgressBar(pl_logger)
    # gpu_stats_cb = pl.callbacks.GPUStatsMonitor()

    if cfg.checkpoint.path:
        assert Path(cfg.checkpoint.path).is_file(), \
            f'Checkpoint path is not a file: {cfg.checkpoint.path}'
        pl_logger.info(f'Resume training checkpoint from: {cfg.checkpoint.path}')

    if cfg.eval:
        pl_logger.info('Training process skipped. Evaluate the resumed model.')
        assert cfg.checkpoint.path is not None, \
            'Try to evaluate the model resumed from the checkpoint, but got None'
        # Initialize the Trainer
        trainer = pl.Trainer(callbacks=[progressbar_cb], **cfg.trainer)
        pl_logger.info(
            f'The model is distributed to {trainer.num_gpus} GPUs with {cfg.trainer.accelerator} backend.')
        pretrained_lit = LitModuleWrapper.load_from_checkpoint(
            checkpoint_path=cfg.checkpoint.path, model=net, cfg=cfg)
        trainer.test(pretrained_lit, datamodule=dm, verbose=False)
    else:  # train + eval
        tb_logger = TensorBoardLogger(output_dir / 'tb_runs',
                                      name=cfg.experiment_name,
                                      log_graph=True)
        pl_logger.info('Tensorboard logger initialized in: ./tb_runs')

        lr_monitor_cb = pl.callbacks.LearningRateMonitor()
        checkpoint_cb = pl.callbacks.ModelCheckpoint(
            dirpath=output_dir / 'checkpoints',
            filename='{epoch}-{val_loss_epoch:.4f}-{val_acc_epoch:.4f}',
            monitor='val_loss_epoch',
            mode='min',
            save_top_k=3,
            save_last=True)
        pl_logger.info(
            'Checkpoints of the best 3 models as well as the last one will be saved to: ./checkpoints')

        # Wrap model with LightningModule
        lit = LitModuleWrapper(net, cfg)
        # A fake input array for TensorBoard to generate the graph
        lit.example_input_array = t.rand(dm.size()).unsqueeze(dim=0)

        # Initialize the Trainer
        trainer = pl.Trainer(
            logger=[tb_logger],
            callbacks=[checkpoint_cb, lr_monitor_cb, progressbar_cb],
            resume_from_checkpoint=cfg.checkpoint.path,
            plugins=DDPPlugin(find_unused_parameters=False),
            **cfg.trainer)
        pl_logger.info(
            f'The model is distributed to {trainer.num_gpus} GPUs with {cfg.trainer.accelerator} backend.')

        pl_logger.info('Training process begins.')
        trainer.fit(model=lit, datamodule=dm)

        pl_logger.info('Evaluate the best trained model.')
        trainer.test(datamodule=dm, ckpt_path='best', verbose=False)

    pl_logger.info('Program completed successfully. Exiting...')
    pl_logger.info(
        'If you have any questions or suggestions, please visit: github.com/zhutmost/neuralzip')
def main(conf):
    train_set = PodcastMixDataloader(
        csv_dir=conf["data"]["train_dir"],
        sample_rate=conf["data"]["sample_rate"],
        original_sample_rate=conf["data"]["original_sample_rate"],
        segment=conf["data"]["segment"],
        shuffle_tracks=True,
        multi_speakers=conf["training"]["multi_speakers"])
    val_set = PodcastMixDataloader(
        csv_dir=conf["data"]["valid_dir"],
        sample_rate=conf["data"]["sample_rate"],
        original_sample_rate=conf["data"]["original_sample_rate"],
        segment=conf["data"]["segment"],
        shuffle_tracks=True,
        multi_speakers=conf["training"]["multi_speakers"])
    train_loader = DataLoader(train_set,
                              shuffle=True,
                              batch_size=conf["training"]["batch_size"],
                              num_workers=conf["training"]["num_workers"],
                              drop_last=True,
                              pin_memory=True)
    val_loader = DataLoader(val_set,
                            shuffle=False,
                            batch_size=conf["training"]["batch_size"],
                            num_workers=conf["training"]["num_workers"],
                            drop_last=True,
                            pin_memory=True)

    if conf["model"]["name"] == "ConvTasNet":
        sys.path.append('ConvTasNet_model')
        from conv_tasnet_norm import ConvTasNetNorm
        conf["masknet"].update({"n_src": conf["data"]["n_src"]})
        model = ConvTasNetNorm(**conf["filterbank"],
                               **conf["masknet"],
                               sample_rate=conf["data"]["sample_rate"])
        loss_func = LogL2Time()
        plugins = None
    elif conf["model"]["name"] == "UNet":
        # UNet with LogL2 time loss and normalization inside the model
        sys.path.append('UNet_model')
        from unet_model import UNet
        model = UNet(conf["data"]["sample_rate"],
                     conf["data"]["fft_size"],
                     conf["data"]["hop_size"],
                     conf["data"]["window_size"],
                     conf["convolution"]["kernel_size"],
                     conf["convolution"]["stride"])
        loss_func = LogL2Time()
        plugins = DDPPlugin(find_unused_parameters=False)

    optimizer = make_optimizer(model.parameters(), **conf["optim"])
    if conf["training"]["half_lr"]:
        scheduler = ReduceLROnPlateau(optimizer=optimizer, factor=0.5, patience=5)

    # Just after instantiating, save the args. Easy loading in the future.
    exp_dir = conf["model"]["name"] + "_model/" + conf["main_args"]["exp_dir"]
    os.makedirs(exp_dir, exist_ok=True)
    conf_path = os.path.join(exp_dir, "conf.yml")
    with open(conf_path, "w") as outfile:
        yaml.safe_dump(conf, outfile)

    system = System(model=model,
                    loss_func=loss_func,
                    optimizer=optimizer,
                    train_loader=train_loader,
                    val_loader=val_loader,
                    scheduler=scheduler,
                    config=conf)

    # Define callbacks
    callbacks = []
    checkpoint_dir = os.path.join(exp_dir, "checkpoints/")
    checkpoint = ModelCheckpoint(checkpoint_dir,
                                 monitor="val_loss",
                                 mode="min",
                                 save_top_k=5,
                                 verbose=True)
    callbacks.append(checkpoint)
    if conf["training"]["early_stop"]:
        callbacks.append(
            EarlyStopping(monitor="val_loss", mode="min", patience=100, verbose=True))

    # Don't ask for GPUs if they are not available.
    gpus = -1 if torch.cuda.is_available() else None
    distributed_backend = "ddp" if torch.cuda.is_available() else None
    trainer = pl.Trainer(
        max_epochs=conf["training"]["epochs"],
        callbacks=callbacks,
        default_root_dir=exp_dir,
        gpus=gpus,
        distributed_backend=distributed_backend,
        gradient_clip_val=5.0,
        resume_from_checkpoint=conf["main_args"]["resume_from"],
        precision=32,
        plugins=plugins)
    trainer.fit(system)

    best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()}
    with open(os.path.join(exp_dir, "best_k_models.json"), "w") as f:
        json.dump(best_k, f, indent=0)

    print(checkpoint.best_model_path)
    state_dict = torch.load(checkpoint.best_model_path)
    system.load_state_dict(state_dict=state_dict["state_dict"])
    system.cpu()

    to_save = system.model.serialize()
    to_save.update(train_set.get_infos())
    torch.save(to_save, os.path.join(exp_dir, "best_model.pth"))