def main(args):
    hp = OmegaConf.load(args.config)
    train_name = "%s_%s" % (hp.model.name.upper(), Path(hp.data.file).stem)
    logger = TensorBoardLogger('logs/', name=train_name)
    logger.log_hyperparams(OmegaConf.to_container(hp))

    if hp.model.name == 'gcn':
        net = MeshRefineGCN(hp)
    elif hp.model.name == 'pn':
        net = MeshRefinePN(hp)
    elif hp.model.name == 'zero':
        net = CondenseMesh(hp)
    else:
        raise ValueError("Invalid model name: %s" % hp.model.name)

    trainer = Trainer(logger=logger, max_epochs=args.max_epochs, gpus=-1, default_root_dir='logs')
    trainer.fit(net)

    mesh, pcd = net.current_mesh, net.source_pcd
    if hp.model.name != 'zero':
        net.get_loss.show(mesh)
    utils.show_overlay(mesh, pcd)
    utils.save_result(os.path.join(logger.log_dir, 'objects'), -1, mesh, pcd)
    print("Done!")
def test_gcs_logging(tmpdir):
    dir_path = gcs_path_join(tmpdir)
    name = "tb_versioning"
    log_dir = os.path.join(dir_path, name)
    gcs_fs.mkdir(log_dir)
    expected_version = "101"

    logger = TensorBoardLogger(save_dir=dir_path, name=name, version=expected_version)
    logger.log_hyperparams({"a": 1, "b": 2, 123: 3, 3.5: 4, 5j: 5})

    assert logger.version == expected_version
    gcs_paths = [
        os.path.basename(path) for path in gcs_fs.listdir(log_dir, detail=False)
    ]
    gcs_paths = list(filter(lambda x: len(x) > 0, gcs_paths))
    assert gcs_paths == [expected_version]
    assert gcs_fs.listdir(os.path.join(log_dir, expected_version), detail=False)
    assert gcs_rm_dir(dir_path)
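# Hedged sketch (not part of the original source) of the helpers the test above
# assumes: gcs_fs, gcs_path_join, and gcs_rm_dir. It uses the real gcsfs package;
# the TEST_GCS_BUCKET environment variable and the path layout are assumptions.
import os
import gcsfs

gcs_fs = gcsfs.GCSFileSystem()  # authenticated via default Google credentials

def gcs_path_join(local_tmpdir):
    # Map a pytest tmpdir onto a unique prefix inside the test bucket (hypothetical layout).
    bucket = os.environ["TEST_GCS_BUCKET"]
    return "/".join([bucket, os.path.basename(str(local_tmpdir))])

def gcs_rm_dir(path):
    # Recursively delete the prefix and report whether it is gone.
    gcs_fs.rm(path, recursive=True)
    return not gcs_fs.exists(path)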
def train_tune(hparams, rdm):
    model = get_model(hparams)
    logger = TensorBoardLogger(save_dir=tune.get_trial_dir(), name="", version=".",
                               default_hp_metric=False)
    logger.log_hyperparams(
        hparams, {
            'train_acc': 0,
            'train_f1': 0,
            'train_loss': 0,
            'valid_acc': 0,
            'valid_f1': 0,
            'valid_loss': 0,
        })
    trainer = pl.Trainer(max_epochs=hparams['n_epochs'],
                         gpus=1,
                         logger=logger,
                         progress_bar_refresh_rate=0,
                         callbacks=[
                             TuneReportCallback(
                                 ['valid_acc', 'valid_f1', 'valid_loss'],
                                 on="validation_end")
                         ])
    trainer.fit(model, rdm)
def main():
    args = parse_args()
    cfg = Config.fromfile(args.config)
    setup_seed(cfg.random_seed)
    model = LightningModel(cfg)
    checkpoint_callback = ModelCheckpoint(
        filepath=f"{cfg.checkpoint_path}/{cfg.name}/{cfg.version}/"
                 f"{cfg.name}_{cfg.version}_{{epoch}}_{{avg_val_loss:.3f}}_{{ade:.3f}}_{{fde:.3f}}_{{fiou:.3f}}",
        save_last=True,
        save_top_k=8,
        verbose=True,
        monitor='fiou',
        mode='max',
        prefix='')
    lr_logger_callback = LearningRateLogger(logging_interval='step')
    logger = TensorBoardLogger(save_dir=cfg.log_path, name=cfg.name, version=cfg.version)
    logger.log_hyperparams(model.hparams)
    profiler = SimpleProfiler() if cfg.simple_profiler else AdvancedProfiler()
    check_val_every_n_epoch = cfg.check_val_every_n_epoch if hasattr(
        cfg, 'check_val_every_n_epoch') else 1
    trainer = pl.Trainer(
        gpus=cfg.num_gpus,
        max_epochs=cfg.max_epochs,
        logger=logger,
        profiler=profiler,  # this line won't work in multi-gpu setting.
        weights_summary="top",
        gradient_clip_val=cfg.gradient_clip_val,
        callbacks=[lr_logger_callback],
        checkpoint_callback=checkpoint_callback,
        resume_from_checkpoint=cfg.resume_from_checkpoint,
        accumulate_grad_batches=cfg.batch_size_times,
        check_val_every_n_epoch=check_val_every_n_epoch)

    if (not (args.train or args.test)) or args.train:
        shutil.copy(
            args.config,
            os.path.join(cfg.log_path, cfg.name, cfg.version,
                         args.config.split('/')[-1]))
        if cfg.load_from_checkpoint is not None:
            model_ckpt = partial_state_dict(model, cfg.load_from_checkpoint)
            model.load_state_dict(model_ckpt)
        trainer.fit(model)

    if args.test:
        if cfg.test_checkpoint is not None:
            model_ckpt = partial_state_dict(model, cfg.test_checkpoint)
            model.load_state_dict(model_ckpt)
        trainer.test(model)
def test_tensorboard_log_hyperparams(tmpdir):
    logger = TensorBoardLogger(tmpdir)
    hparams = {
        "float": 0.3,
        "int": 1,
        "string": "abc",
        "bool": True
    }
    logger.log_hyperparams(hparams)
def test_tensorboard_named_version(tmpdir):
    """Verify that manual versioning works for string versions, e.g. '2020-02-05-162402'."""
    tmpdir.mkdir("tb_versioning")
    expected_version = "2020-02-05-162402"

    logger = TensorBoardLogger(save_dir=tmpdir, name="tb_versioning", version=expected_version)
    logger.log_hyperparams({"a": 1, "b": 2})  # Force data to be written

    assert logger.version == expected_version
def test_tensorboard_no_name(tmpdir, name):
    """Verify that None or empty name works."""
    logger = TensorBoardLogger(save_dir=tmpdir, name=name)
    logger.log_hyperparams({
        "a": 1,
        "b": 2,
        123: 3,
        3.5: 4,
        5j: 5
    })  # Force data to be written
    assert logger.root_dir == tmpdir
    assert os.listdir(tmpdir / "version_0")
def test_tensorboard_log_hyperparams(tmpdir):
    logger = TensorBoardLogger(tmpdir)
    hparams = {
        "float": 0.3,
        "int": 1,
        "string": "abc",
        "bool": True,
        "list": [1, 2, 3],
        "namespace": Namespace(foo=3),
        "layer": torch.nn.BatchNorm1d
    }
    logger.log_hyperparams(hparams)
def test_tensorboard_named_version(tmpdir):
    """Verify that manual versioning works for string versions, e.g. '2020-02-05-162402'."""
    name = "tb_versioning"
    (tmpdir / name).mkdir()
    expected_version = "2020-02-05-162402"

    logger = TensorBoardLogger(save_dir=tmpdir, name=name, version=expected_version)
    logger.log_hyperparams({"a": 1, "b": 2, 123: 3, 3.5: 4, 5j: 5})  # Force data to be written

    assert logger.version == expected_version
    assert os.listdir(tmpdir / name) == [expected_version]
    assert os.listdir(tmpdir / name / expected_version)
def test_tensorboard_no_name(tmpdir, name):
    """Verify that None or empty name works."""
    logger = TensorBoardLogger(save_dir=tmpdir, name=name)
    logger.log_hyperparams({
        "a": 1,
        "b": 2,
        123: 3,
        3.5: 4,
        5j: 5
    })  # Force data to be written
    assert os.path.normpath(logger.root_dir) == tmpdir  # use os.path.normpath to handle trailing /
    assert os.listdir(tmpdir / "version_0")
def train_and_test(args: argparse.Namespace):
    dict_args = vars(args)
    seed = args.rng_seed
    log_dir = args.log_dir
    early_stop = args.early_stop
    early_stop_min_delta = args.early_stop_min_delta
    early_stop_patience = args.early_stop_patience

    pl.seed_everything(seed)

    callbacks: List[pl.callbacks.Callback] = [
        LearningRateMonitor(logging_interval='step')
    ]
    if early_stop:
        # Should give enough time for the lr_scheduler to do its thing.
        callbacks.append(
            EarlyStopping(monitor='val_loss',
                          mode='min',
                          min_delta=early_stop_min_delta,
                          patience=early_stop_patience,
                          verbose=True,
                          strict=True))

    checkpoint_callback = ModelCheckpoint(
        # TODO: Is low val_loss the best choice for choosing the best model?
        monitor='val_loss',
        mode='min',
        filepath='./checkpoints/snn-omniglot-{epoch}-{val_loss:.2f}',
        save_top_k=3)

    logger = TensorBoardLogger(log_dir, name='snn')

    trainer = pl.Trainer.from_argparse_args(
        args,
        logger=logger,
        progress_bar_refresh_rate=20,
        deterministic=True,
        auto_lr_find=True,
        checkpoint_callback=checkpoint_callback,
        callbacks=callbacks)
    model = TwinNet(**dict_args)

    # Tune learning rate.
    trainer.tune(model)
    logger.log_hyperparams(params=model.hparams)

    # Train model.
    trainer.fit(model)
    print('Best model saved to: ', checkpoint_callback.best_model_path)

    # Test using best checkpoint.
    trainer.test()
def test_tensorboard_log_hparams_and_metrics(tmpdir):
    logger = TensorBoardLogger(tmpdir)
    hparams = {
        "float": 0.3,
        "int": 1,
        "string": "abc",
        "bool": True,
        "dict": {"a": {"b": "c"}},
        "list": [1, 2, 3],
        "namespace": Namespace(foo=Namespace(bar="buzz")),
        "layer": torch.nn.BatchNorm1d,
    }
    metrics = {"abc": torch.tensor([0.54])}
    logger.log_hyperparams(hparams, metrics)
def test_tensorboard_log_hyperparams(tmpdir):
    logger = TensorBoardLogger(tmpdir)
    hparams = {
        "float": 0.3,
        "int": 1,
        "string": "abc",
        "bool": True,
        "dict": {"a": {"b": "c"}},
        "list": [1, 2, 3],
        "namespace": Namespace(foo=Namespace(bar="buzz")),
        "layer": torch.nn.BatchNorm1d,
        "tensor": torch.empty(2, 2, 2),
        "array": np.empty([2, 2, 2]),
    }
    logger.log_hyperparams(hparams)
def test_tensorboard_log_omegaconf_hparams_and_metrics(tmpdir):
    logger = TensorBoardLogger(tmpdir, default_hp_metric=False)
    hparams = {
        "float": 0.3,
        "int": 1,
        "string": "abc",
        "bool": True,
        "dict": {"a": {"b": "c"}},
        "list": [1, 2, 3],
        # "namespace": Namespace(foo=Namespace(bar="buzz")),
        # "layer": torch.nn.BatchNorm1d,
    }
    hparams = OmegaConf.create(hparams)

    metrics = {"abc": torch.tensor([0.54])}
    logger.log_hyperparams(hparams, metrics)
def test_tensorboard_log_omegaconf_hparams_and_metrics(tmpdir):
    logger = TensorBoardLogger(tmpdir, default_hp_metric=False)
    hparams = {
        "float": 0.3,
        "int": 1,
        "string": "abc",
        "bool": True,
        "dict": {
            "a": {
                "b": "c"
            }
        },
        "list": [1, 2, 3],
    }
    hparams = OmegaConf.create(hparams)

    metrics = {"abc": torch.tensor([0.54])}
    logger.log_hyperparams(hparams, metrics)
def main():
    args = parse_args()
    paths = Paths()
    checkpoints_path = str(paths.CHECKPOINTS_PATH)
    logging_path = str(paths.LOG_PATH)
    callbacks = [PrintCallback()]
    checkpoint_callback = ModelCheckpoint(filepath=checkpoints_path + '/{epoch}-{val_acc:.3f}',
                                          save_top_k=True,
                                          verbose=True,
                                          monitor='val_acc',
                                          mode='max',
                                          prefix='')
    early_stop_callback = EarlyStopping(monitor='val_acc',
                                        mode='max',
                                        verbose=False,
                                        strict=False,
                                        min_delta=0.0,
                                        patience=2)
    gpus = gpu_count()
    log_save_interval = args.log_save_interval
    logger = TensorBoardLogger(save_dir=logging_path, name='tuna-log')
    logger.log_hyperparams(args)
    max_epochs = args.epochs
    model = LeNet(hparams=args, paths=paths)
    trainer = Trainer(
        callbacks=callbacks,
        checkpoint_callback=checkpoint_callback,
        early_stop_callback=early_stop_callback,
        fast_dev_run=True,
        gpus=gpus,
        log_save_interval=log_save_interval,
        logger=logger,
        max_epochs=max_epochs,
        min_epochs=1,
        show_progress_bar=True,
        weights_summary='full',
    )
    trainer.fit(model)
def main():
    args = parse_args()
    cfg = Config.fromfile(args.config)
    setup_seed(cfg.random_seed)
    model = LightningTransformer(cfg)
    checkpoint_callback = ModelCheckpoint(filepath=os.path.join(
        cfg.checkpoint_path, cfg.name, cfg.version,
        "{}_{}_{{epoch}}_{{val_loss_per_word}}".format(cfg.name, cfg.version)),
                                          save_last=True,
                                          save_top_k=8,
                                          verbose=True,
                                          monitor='val_loss_per_word',
                                          mode='min',
                                          prefix='')
    lr_logger_callback = LearningRateLogger(logging_interval='step')
    logger = TensorBoardLogger(save_dir=cfg.log_path, name=cfg.name, version=cfg.version)
    logger.log_hyperparams(model.hparams)
    profiler = SimpleProfiler() if cfg.simple_profiler else AdvancedProfiler()
    trainer = pl.Trainer(gpus=cfg.num_gpus,
                         max_epochs=cfg.max_epochs,
                         logger=logger,
                         profiler=profiler,
                         weights_summary="top",
                         callbacks=[lr_logger_callback],
                         checkpoint_callback=checkpoint_callback,
                         resume_from_checkpoint=cfg.resume_from_checkpoint,
                         accumulate_grad_batches=cfg.batch_size_times)
    if cfg.load_from_checkpoint is not None:
        ckpt = torch.load(cfg.load_from_checkpoint,
                          map_location=lambda storage, loc: storage)
        model.load_state_dict(ckpt['state_dict'])
    trainer.fit(model)
def main(dm_cls, model_cls, model_args, logger_name):
    parser = ArgumentParser()
    parser.add_argument("--seed", type=int, default=None, help="random seed")
    parser.add_argument("--logger_name", type=str, default=logger_name,
                        help="logger name to identify")
    parser.add_argument("--save_top_k", type=int, default=1,
                        help="num of best models to save")
    parser = pl.Trainer.add_argparse_args(parser)
    parser = dm_cls.add_argparse_args(parser)
    parser = model_cls.add_argparse_args(parser)
    args = parser.parse_args()
    print(args)
    validate_args(args)

    seed_everything(args.seed)

    dm = dm_cls.from_argparse_args(args)
    byol_args = model_cls.extract_kwargs_from_argparse_args(args, **model_args)
    model = model_cls(**byol_args)

    logger = TensorBoardLogger('tb_logs', name=args.logger_name)
    logger.log_hyperparams(args)
    checkpoint = ModelCheckpoint(monitor='val_acc',
                                 filepath=None,
                                 save_top_k=args.save_top_k)
    trainer = pl.Trainer.from_argparse_args(args,
                                            deterministic=True,
                                            callbacks=[TargetNetworkUpdator()],
                                            checkpoint_callback=checkpoint,
                                            logger=logger)
    trainer.fit(model, dm)
def train(hparams):
    rdm = RetinalDataModule()
    model = get_model(hparams)
    logger = TensorBoardLogger('logs', name=get_exp_name(hparams), default_hp_metric=False)
    # log hparams to tensorboard
    logger.log_hyperparams(hparams, {
        'train_acc': 0,
        'train_f1': 0,
        'train_loss': 0,
        'valid_acc': 0,
        'valid_f1': 0,
        'valid_loss': 0,
    })
    trainer = pl.Trainer(gpus=1,
                         min_epochs=50,
                         max_epochs=hparams['n_epochs'],
                         logger=logger,
                         callbacks=[
                             EarlyStopping(monitor='valid_loss', patience=10, mode='min'),
                             ModelCheckpoint(monitor='valid_loss')
                         ])
    trainer.fit(model, rdm)
def main(hparams):
    # ------------------------
    # 1 INIT LIGHTNING MODEL
    # ------------------------
    model = GAN(hparams)

    # ------------------------
    # 2 INIT TRAINER
    # ------------------------
    logger = TensorBoardLogger(hparams.logdir, name="DA_Reduction_GAN")
    logger.log_hyperparams = dontloghparams
    trainer = pl.Trainer(
        logger=logger,  # At the moment, PTL breaks when using builtin logger
        max_nb_epochs=100,
        distributed_backend="dp",
        gpus=[0])

    # ------------------------
    # 3 START TRAINING
    # ------------------------
    trainer.fit(model)
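# Hedged sketch (not shown in the original source) of what the dontloghparams stub
# assigned above could look like: a no-op that accepts whatever log_hyperparams
# would normally receive, so hyper-parameters are never written to TensorBoard.
def dontloghparams(*args, **kwargs):
    # Intentionally do nothing.
    return None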
checkpointer = ModelCheckpoint(filepath=checkpoint_filename + '.ckpt',
                               monitor='swa_loss_no_reg')

trainer = Trainer(gpus=1,
                  num_nodes=1,
                  max_epochs=epochs,
                  logger=logger,
                  callbacks=[lr_logger],
                  checkpoint_callback=checkpointer,
                  benchmark=True,
                  terminate_on_nan=True,
                  gradient_clip_val=max_l2_norm)

try:
    trainer.fit(swag_model)
except ValueError:
    print("Model", checkpoint_filename, 'exited early!', flush=True)
    exit(1)

# Save model:
logger.log_hyperparams(
    params=swag_model.hparams,
    metrics={'swa_loss_no_reg': checkpointer.best_model_score.item()})
logger.save()
logger.finalize('success')

spock_reg_model.save_swag(swag_model, output_filename + '.pkl')

import pickle as pkl
pkl.dump(swag_model.ssX, open(output_filename + '_ssX.pkl', 'wb'))
def train_and_test(args: argparse.Namespace):
    dict_args = vars(args)
    seed = args.rng_seed
    log_dir = args.log_dir
    early_stop = args.early_stop
    early_stop_min_delta = args.early_stop_min_delta
    early_stop_patience = args.early_stop_patience
    checkpoint_dir = args.checkpoint_dir

    pl.seed_everything(seed)

    callbacks: List[pl.callbacks.Callback] = [
        LearningRateMonitor(logging_interval='step')
    ]
    if early_stop:
        # Should give enough time for the lr_scheduler to do its thing.
        callbacks.append(
            EarlyStopping(monitor='val_loss',
                          mode='min',
                          min_delta=early_stop_min_delta,
                          patience=early_stop_patience,
                          verbose=True,
                          strict=True))

    checkpoint_callback = ModelCheckpoint(
        monitor='val_loss' if args.fast_dev_run else 'val_eer',
        mode='min',
        filepath=checkpoint_dir + args.model + '-{epoch}-{val_loss:.2f}-{val_eer:.2f}',
        save_top_k=3)

    logger = TensorBoardLogger(log_dir, name=args.model, log_graph=True,
                               default_hp_metric=False)

    trainer = pl.Trainer.from_argparse_args(
        args,
        logger=logger,
        progress_bar_refresh_rate=20,
        deterministic=True,
        auto_lr_find=False,  # Do this manually.
        checkpoint_callback=checkpoint_callback,
        callbacks=callbacks)

    model: BaseNet
    if args.model == 'snn':
        model = SNN(**dict_args)
        datamodule = LibriSpeechDataModule(train_set_type='pair', **dict_args)
    elif args.model == 'snn-capsnet':
        model = SNNCapsNet(**dict_args)
        datamodule = LibriSpeechDataModule(train_set_type='pair', **dict_args)
    elif args.model == 'snn-angularproto':
        model = SNNAngularProto(**dict_args)
        datamodule = LibriSpeechDataModule(train_set_type='nshotkway', **dict_args)
    elif args.model == 'snn-softmaxproto':
        model = SNNSoftmaxProto(**dict_args)
        datamodule = LibriSpeechDataModule(train_set_type='nshotkway', **dict_args)

    # Tune.
    trainer.tune(model, datamodule=datamodule)

    # Prefer provided LR over lr_finder.
    new_lr: float
    if args.learning_rate:
        new_lr = args.learning_rate
    elif args.fast_dev_run:
        new_lr = 1e-3
    else:
        lr_finder = trainer.tuner.lr_find(model)
        new_lr = lr_finder.suggestion()  # Could also try max(lr_finder.results['lr'])
    max_lr = new_lr * (args.max_lr_multiplier or 3)
    model.hparams.max_learning_rate = max_lr  # type: ignore
    model.hparams.learning_rate = new_lr  # type: ignore
    print('Learning rate set to {}.'.format(new_lr))

    logger.log_hyperparams(params=model.hparams)

    # Train model.
    trainer.fit(model, datamodule=datamodule)
    print('Best model saved to: ', checkpoint_callback.best_model_path)
    trainer.save_checkpoint(checkpoint_dir + args.model + '-last.ckpt')

    # Test using best checkpoint.
    trainer.test(datamodule=datamodule)
def objective(trial):
    if hparams.version is None:
        hparams.version = str(uuid1())

    # main LightningModule
    pretrain_system = PreTrainSystem(
        learning_rate=trial.suggest_loguniform("learning_rate", 1e-5, 1e-2),
        beta_1=hparams.beta_1,
        beta_2=hparams.beta_2,
        weight_decay=trial.suggest_uniform("weight_decay", 1e-5, 1e-2),
        optimizer=hparams.optimizer,
        batch_size=hparams.batch_size,
        multiplier=hparams.multiplier,
        scheduler_patience=hparams.scheduler_patience,
    )

    pretrain_checkpoints = ModelCheckpoint(
        dirpath=MODEL_CHECKPOINTS_DIR,
        monitor="Val/loss_epoch",
        verbose=True,
        mode="min",
        save_top_k=hparams.save_top_k,
    )
    pretrain_early_stopping = EarlyStopping(
        monitor="Val/loss_epoch",
        min_delta=0.00,
        patience=hparams.patience,
        verbose=False,
        mode="min",
    )
    pretrain_gpu_stats_monitor = GPUStatsMonitor(temperature=True)
    log_recoloring_to_tensorboard = LogPairRecoloringToTensorboard()
    optuna_pruning = PyTorchLightningPruningCallback(monitor="Val/loss_epoch", trial=trial)

    logger = TensorBoardLogger(
        S3_LIGHTNING_LOGS_DIR,
        name=hparams.name,
        version=hparams.version,
        log_graph=True,
        default_hp_metric=False,
    )

    trainer = Trainer.from_argparse_args(
        hparams,
        logger=logger,
        checkpoint_callback=pretrain_checkpoints,
        callbacks=[
            pretrain_early_stopping,
            log_recoloring_to_tensorboard,
            pretrain_gpu_stats_monitor,
            optuna_pruning,
        ],
        profiler="simple",
    )

    datamodule = PreTrainDataModule(
        batch_size=pretrain_system.hparams.batch_size,
        multiplier=pretrain_system.hparams.multiplier,
        shuffle=hparams.shuffle,
        num_workers=hparams.num_workers,
        size=hparams.size,
        pin_memory=hparams.pin_memory,
        train_batch_from_same_image=hparams.train_batch_from_same_image,
        val_batch_from_same_image=hparams.val_batch_from_same_image,
        test_batch_from_same_image=hparams.test_batch_from_same_image,
    )

    # trainer.tune(pretrain_system, datamodule=datamodule)
    trainer.fit(pretrain_system, datamodule=datamodule)

    # get best checkpoint
    best_model_path = pretrain_checkpoints.best_model_path
    pretrain_system = PreTrainSystem.load_from_checkpoint(best_model_path)

    test_result = trainer.test(pretrain_system, datamodule=datamodule)
    pretrain_system.hparams.test_metric_name = test_result[0]["Test/loss_epoch"]

    logger.log_hyperparams(pretrain_system.hparams)
    logger.finalize(status="success")

    # upload best model to S3
    S3_best_model_path = os.path.join(
        S3_MODEL_CHECKPOINTS_RELATIVE_DIR,
        hparams.name,
        ".".join([hparams.version, best_model_path.split(".")[-1]]),
    )
    upload_to_s3(best_model_path, S3_best_model_path)

    return test_result[0]["Test/loss_epoch"]
name = 'full_swag_pre_' + checkpoint_filename
logger = TensorBoardLogger("tb_logs", name=name)
checkpointer = ModelCheckpoint(filepath=checkpoint_filename + '/{version}')

model = spock_reg_model.VarModel(args)
model.make_dataloaders()

labels = ['time', 'e+_near', 'e-_near', 'max_strength_mmr_near', 'e+_far', 'e-_far',
          'max_strength_mmr_far', 'megno', 'a1', 'e1', 'i1', 'cos_Omega1', 'sin_Omega1',
          'cos_pomega1', 'sin_pomega1', 'cos_theta1', 'sin_theta1', 'a2', 'e2', 'i2',
          'cos_Omega2', 'sin_Omega2', 'cos_pomega2', 'sin_pomega2', 'cos_theta2',
          'sin_theta2', 'a3', 'e3', 'i3', 'cos_Omega3', 'sin_Omega3', 'cos_pomega3',
          'sin_pomega3', 'cos_theta3', 'sin_theta3', 'm1', 'm2', 'm3', 'nan_mmr_near',
          'nan_mmr_far', 'nan_megno']

max_l2_norm = args['gradient_clip'] * sum(
    p.numel() for p in model.parameters() if p.requires_grad)

trainer = Trainer(
    gpus=1,
    num_nodes=1,
    max_epochs=args['epochs'],
    logger=logger,
    checkpoint_callback=checkpointer,
    benchmark=True,
    terminate_on_nan=True,
    gradient_clip_val=max_l2_norm
)

try:
    trainer.fit(model)
except ValueError:
    model.load_state_dict(torch.load(checkpointer.best_model_path)['state_dict'])

logger.log_hyperparams(params=model.hparams,
                       metrics={'val_loss': checkpointer.best_model_score.item()})
logger.save()
logger.finalize('success')
logger.save()

model.load_state_dict(torch.load(checkpointer.best_model_path)['state_dict'])
model.make_dataloaders()
    is_train=False
)

## Testing Dataloader
testloader = data.DataLoader(
    test_set,
    batch_size=HPARAMS['data_batch_size'],
    shuffle=False,
    num_workers=hparams.n_workers
)

print('Dataset Split (Train, Validation, Test)=', len(train_set), len(valid_set), len(test_set))

# Training the Model
logger = TensorBoardLogger('NISP_logs', name='')
logger.log_hyperparams(HPARAMS)

model = LightningModel(HPARAMS)

checkpoint_callback = ModelCheckpoint(
    monitor='v_loss',
    mode='min',
    verbose=1)

trainer = pl.Trainer(fast_dev_run=hparams.dev,
                     gpus=hparams.gpu,
                     max_epochs=hparams.epochs,
                     checkpoint_callback=checkpoint_callback,
                     callbacks=[
                         EarlyStopping(
                             monitor='v_loss',
def main(arguments: argparse.Namespace) -> None:
    """Train the model.

    Args:
        arguments: Model hyper-parameters

    Note:
        For the sake of the example, the images dataset will be downloaded to
        a temporary directory.
    """
    print_system_info()
    print("Using following configuration: ")
    pprint(vars(arguments))

    for fold in range(arguments.folds):
        if arguments.only_fold != -1:
            fold = arguments.only_fold
        print(f"Fold {fold}: Training is starting...")
        arguments.fold = fold
        model = OneCycleModule(arguments)

        logger = TensorBoardLogger("../logs", name=f"{arguments.backbone}-fold-{fold}")

        early_stop_callback = EarlyStopping(monitor='val_f1',
                                            min_delta=0.00,
                                            patience=5,
                                            verbose=True,
                                            mode='max')
        checkpoint_callback = ModelCheckpoint(filepath=os.path.join(
            arguments.save_model_path,
            f"checkpoint-{arguments.backbone}-fold-{fold}" + "-{epoch:02d}-{val_f1:.2f}"),
                                              save_top_k=arguments.save_top_k,
                                              monitor="val_f1",
                                              mode="max",
                                              verbose=True)

        trainer = pl.Trainer(
            weights_summary=None,
            num_sanity_val_steps=0,
            gpus=arguments.gpus,
            min_epochs=arguments.epochs,
            max_epochs=arguments.epochs,
            logger=logger,
            deterministic=True,
            benchmark=True,
            early_stop_callback=early_stop_callback,
            checkpoint_callback=checkpoint_callback,
            callbacks=[lr_logger],
            precision=arguments.precision,
            row_log_interval=10,
            # val_check_interval=0.5,
            accumulate_grad_batches=1
            # fast_dev_run=True
        )

        trainer.fit(model)
        logger.log_hyperparams(
            arguments,
            {"hparams/val_f1": checkpoint_callback.best_model_score.item()})
        logger.save()

        print("-" * 80)
        print(f"Testing the model on fold: {fold}")
        trainer.test(model)

        model.cpu()
        del model
        del trainer
        del logger
        del early_stop_callback
        del checkpoint_callback

        # end CV loop if we only train on one fold
        if arguments.only_fold != -1:
            break