def test_rl_early_stopping():
    gata_double_dqn = GATADoubleDQN()
    trainer = Trainer()
    es = RLEarlyStopping("val_monitor", "train_monitor", 0.95, patience=3)

    # if the val score and the train score are both below the threshold 0.95, don't stop
    trainer.callback_metrics = {"val_monitor": 0.1, "train_monitor": 0.1}
    es._run_early_stopping_check(trainer, gata_double_dqn)
    assert not trainer.should_stop

    # if the val score is 1.0 and the train score is at or above the threshold, stop
    trainer.callback_metrics = {"val_monitor": 1.0, "train_monitor": 0.95}
    trainer.current_epoch = 1
    es._run_early_stopping_check(trainer, gata_double_dqn)
    assert trainer.should_stop
    assert es.stopped_epoch == 1

    # if the train score stays at or above the threshold for `patience` checks,
    # stop even though the val score is not 1.0
    trainer.should_stop = False
    es.wait_count = 0
    es.stopped_epoch = 0
    for i in range(3):
        trainer.current_epoch = i
        trainer.callback_metrics = {"val_monitor": 0.9, "train_monitor": 0.95}
        es._run_early_stopping_check(trainer, gata_double_dqn)
        if i == 2:
            assert trainer.should_stop
            assert es.stopped_epoch == 2
        else:
            assert not trainer.should_stop
            assert es.stopped_epoch == 0
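# The test above pins down the behavior of `RLEarlyStopping` without showing its
# implementation. The following is a minimal standalone sketch consistent with the
# assertions; the attribute names (`wait_count`, `stopped_epoch`) are taken from the
# test, but this is an illustration, not the actual class (which would normally
# subclass a Lightning callback).
class RLEarlyStoppingSketch:
    def __init__(self, val_monitor, train_monitor, threshold, patience=3):
        self.val_monitor = val_monitor
        self.train_monitor = train_monitor
        self.threshold = threshold
        self.patience = patience
        self.wait_count = 0
        self.stopped_epoch = 0

    def _run_early_stopping_check(self, trainer, pl_module):
        val_score = trainer.callback_metrics[self.val_monitor]
        train_score = trainer.callback_metrics[self.train_monitor]
        # stop immediately once validation is perfect and training
        # has cleared the threshold
        if val_score == 1.0 and train_score >= self.threshold:
            trainer.should_stop = True
            self.stopped_epoch = trainer.current_epoch
            return
        # otherwise stop only after the train score stays at or above
        # the threshold for `patience` consecutive checks
        if train_score >= self.threshold:
            self.wait_count += 1
            if self.wait_count >= self.patience:
                trainer.should_stop = True
                self.stopped_epoch = trainer.current_epoch
        else:
            self.wait_count = 0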
def test_model_checkpoint_options(tmpdir, save_top_k, save_last, file_prefix, expected_files):
    """Test ModelCheckpoint options."""
    def mock_save_function(filepath, *args):
        open(filepath, 'a').close()

    # simulated losses
    losses = [10, 9, 2.8, 5, 2.5]

    checkpoint_callback = ModelCheckpoint(tmpdir, save_top_k=save_top_k, save_last=save_last,
                                          prefix=file_prefix, verbose=1)
    checkpoint_callback.save_function = mock_save_function
    trainer = Trainer()

    # emulate the callback's calls during training
    for i, loss in enumerate(losses):
        trainer.current_epoch = i
        trainer.callback_metrics = {'val_loss': loss}
        checkpoint_callback.on_validation_end(trainer, trainer.get_model())

    file_lists = set(os.listdir(tmpdir))

    assert len(file_lists) == len(expected_files), \
        "Should save %i models when save_top_k=%i" % (len(expected_files), save_top_k)

    # verify correct naming
    for fname in expected_files:
        assert fname in file_lists
def test_model_checkpoint_options(tmpdir, save_top_k, save_last, file_prefix, expected_files):
    """Test ModelCheckpoint options."""
    def mock_save_function(filepath, *args):
        open(filepath, 'a').close()

    # simulated losses
    losses = [10, 9, 2.8, 5, 2.5]

    checkpoint_callback = ModelCheckpoint(tmpdir, monitor='checkpoint_on', save_top_k=save_top_k,
                                          save_last=save_last, prefix=file_prefix, verbose=1)
    checkpoint_callback.save_function = mock_save_function
    trainer = Trainer()

    # emulate the callback's calls during training
    for i, loss in enumerate(losses):
        trainer.current_epoch = i
        trainer.logger_connector.callback_metrics = {
            'checkpoint_on': torch.tensor(loss)
        }
        checkpoint_callback.on_validation_end(trainer, trainer.get_model())

    file_lists = set(os.listdir(tmpdir))

    assert len(file_lists) == len(expected_files), (
        f"Should save {len(expected_files)} models when save_top_k={save_top_k}"
        f" but found={file_lists}"
    )

    # verify correct naming
    for fname in expected_files:
        assert fname in file_lists
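# Both variants of `test_model_checkpoint_options` above take `save_top_k`,
# `save_last`, `file_prefix` and `expected_files` as arguments, which implies a
# `pytest.mark.parametrize` decorator supplying the cases. A hypothetical
# parametrization is sketched below; the expected filenames follow the
# `_ckpt_epoch_{n}.ckpt` convention asserted elsewhere in this file, but the
# exact names (especially `last.ckpt`) are assumptions that vary across
# Lightning versions.
import pytest

@pytest.mark.parametrize(
    "save_top_k, save_last, file_prefix, expected_files",
    [
        # keep every epoch
        (-1, False, '', {f'_ckpt_epoch_{i}.ckpt' for i in range(5)}),
        # keep nothing
        (0, False, '', set()),
        # keep only the best loss (2.5, epoch 4)
        (1, False, 'test_prefix', {'test_prefix_ckpt_epoch_4.ckpt'}),
        # keep the top 2 plus the last checkpoint
        (2, True, '', {'_ckpt_epoch_4.ckpt', '_ckpt_epoch_2.ckpt', 'last.ckpt'}),
    ],
)
def test_model_checkpoint_options(tmpdir, save_top_k, save_last, file_prefix, expected_files):
    ...  # stands in for either test body above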
def trainWithTune(config, checkpoint_dir=None, datamodule=None, num_epochs=10, num_gpus=0):
    trainer = Trainer(
        max_epochs=num_epochs,
        # If fractional GPUs passed in, convert to int.
        gpus=math.ceil(num_gpus),
        logger=TensorBoardLogger(save_dir=tune.get_trial_dir(), name="", version="."),
        progress_bar_refresh_rate=0,
        callbacks=[
            TuneReportCheckpointCallback(
                metrics={
                    "loss": "val_loss",
                    "mean_accuracy": "val_acc",
                    "mean_iou": "val_iou",
                },
                filename="checkpoint",
                on="validation_end")
        ])

    if checkpoint_dir:
        # Currently, this leads to errors:
        # model = LightningMNISTClassifier.load_from_checkpoint(
        #     os.path.join(checkpoint, "checkpoint"))
        # Workaround:
        ckpt = pl_load(
            os.path.join(checkpoint_dir, "checkpoint"),
            map_location=lambda storage, loc: storage)
        model = MMETrainingModule._load_model_state(
            ckpt,
            lr=10**config['log_lr'],
            lrRatio=10**config['log_lrRatio'],
            decay=10**config['log_decay'],
            num_cls=NUM_CLS)
        trainer.current_epoch = ckpt["epoch"]
    else:
        model = MMETrainingModule(
            lr=10**config['log_lr'],
            lrRatio=10**config['log_lrRatio'],
            decay=10**config['log_decay'],
            num_cls=NUM_CLS)

    trainer.fit(model, datamodule=datamodule)
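# `trainWithTune` is written as a Ray Tune trainable, so it would typically be
# launched through `tune.run` with `tune.with_parameters` binding the non-config
# arguments. A minimal launch sketch; the search ranges and resource settings
# below are illustrative assumptions, not values from the original code.
from ray import tune

analysis = tune.run(
    tune.with_parameters(
        trainWithTune,
        datamodule=data_module,  # assumed to be constructed elsewhere
        num_epochs=10,
        num_gpus=1),
    resources_per_trial={"cpu": 4, "gpu": 1},
    config={
        "log_lr": tune.uniform(-4, -1),      # lr = 10 ** log_lr
        "log_lrRatio": tune.uniform(-3, 0),
        "log_decay": tune.uniform(-6, -3),
    },
    metric="mean_iou",
    mode="max",
    num_samples=10)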
    default_root_dir=args.results,
    resume_from_checkpoint=ckpt_path,
    accelerator="ddp" if args.gpus > 1 else None,
    limit_train_batches=1.0 if args.train_batches == 0 else args.train_batches,
    limit_val_batches=1.0 if args.test_batches == 0 else args.test_batches,
    limit_test_batches=1.0 if args.test_batches == 0 else args.test_batches,
)

if args.benchmark:
    if args.exec_mode == "train":
        trainer.fit(model, train_dataloader=data_module.train_dataloader())
    else:
        # warmup
        trainer.test(model, test_dataloaders=data_module.test_dataloader())
        # benchmark run
        trainer.current_epoch = 1
        trainer.test(model, test_dataloaders=data_module.test_dataloader())
elif args.exec_mode == "train":
    trainer.fit(model, data_module)
    if is_main_process():
        logname = args.logname if args.logname is not None else "train_log.json"
        log(logname, torch.tensor(model.best_mean_dice), results=args.results)
elif args.exec_mode == "evaluate":
    model.args = args
    trainer.test(model, test_dataloaders=data_module.val_dataloader())
    if is_main_process():
        logname = args.logname if args.logname is not None else "eval_log.json"
        log(logname, model.eval_dice, results=args.results)
elif args.exec_mode == "predict":
    if args.save_preds:
        ckpt_name = "_".join(args.ckpt_path.split("/")[-1].split(".")[:-1])
def test_model_checkpoint_options(tmp_path):
    """Test ModelCheckpoint options."""
    def mock_save_function(filepath):
        open(filepath, 'a').close()

    hparams = tutils.get_hparams()
    _ = LightningTestModel(hparams)

    # simulated losses
    losses = [10, 9, 2.8, 5, 2.5]

    save_dir = tmp_path / "1"
    save_dir.mkdir()

    # -----------------
    # CASE K=-1 (all)
    checkpoint_callback = ModelCheckpoint(save_dir, save_top_k=-1, verbose=1)
    checkpoint_callback.save_function = mock_save_function
    trainer = Trainer()

    # emulate the callback's calls during training
    for i, loss in enumerate(losses):
        trainer.current_epoch = i
        trainer.callback_metrics = {'val_loss': loss}
        checkpoint_callback.on_validation_end(trainer, trainer.get_model())

    file_lists = set(os.listdir(save_dir))

    assert len(file_lists) == len(losses), "Should save all models when save_top_k=-1"

    # verify correct naming
    for i in range(len(losses)):
        assert f"_ckpt_epoch_{i}.ckpt" in file_lists

    save_dir = tmp_path / "2"
    save_dir.mkdir()

    # -----------------
    # CASE K=0 (none)
    checkpoint_callback = ModelCheckpoint(save_dir, save_top_k=0, verbose=1)
    checkpoint_callback.save_function = mock_save_function
    trainer = Trainer()

    # emulate the callback's calls during training
    for i, loss in enumerate(losses):
        trainer.current_epoch = i
        trainer.callback_metrics = {'val_loss': loss}
        checkpoint_callback.on_validation_end(trainer, trainer.get_model())

    file_lists = os.listdir(save_dir)

    assert len(file_lists) == 0, "Should save 0 models when save_top_k=0"

    save_dir = tmp_path / "3"
    save_dir.mkdir()

    # -----------------
    # CASE K=1 (2.5, epoch 4)
    checkpoint_callback = ModelCheckpoint(save_dir, save_top_k=1, verbose=1,
                                          prefix='test_prefix')
    checkpoint_callback.save_function = mock_save_function
    trainer = Trainer()

    # emulate the callback's calls during training
    for i, loss in enumerate(losses):
        trainer.current_epoch = i
        trainer.callback_metrics = {'val_loss': loss}
        checkpoint_callback.on_validation_end(trainer, trainer.get_model())

    file_lists = set(os.listdir(save_dir))

    assert len(file_lists) == 1, "Should save 1 model when save_top_k=1"
    assert 'test_prefix_ckpt_epoch_4.ckpt' in file_lists

    save_dir = tmp_path / "4"
    save_dir.mkdir()

    # -----------------
    # CASE K=2 (2.5 epoch 4, 2.8 epoch 2)
    # make sure other files don't get deleted
    checkpoint_callback = ModelCheckpoint(save_dir, save_top_k=2, verbose=1)
    open(f"{save_dir}/other_file.ckpt", 'a').close()
    checkpoint_callback.save_function = mock_save_function
    trainer = Trainer()

    # emulate the callback's calls during training
    for i, loss in enumerate(losses):
        trainer.current_epoch = i
        trainer.callback_metrics = {'val_loss': loss}
        checkpoint_callback.on_validation_end(trainer, trainer.get_model())

    file_lists = set(os.listdir(save_dir))

    assert len(file_lists) == 3, 'Should save 2 models when save_top_k=2'
    assert '_ckpt_epoch_4.ckpt' in file_lists
    assert '_ckpt_epoch_2.ckpt' in file_lists
    assert 'other_file.ckpt' in file_lists

    save_dir = tmp_path / "5"
    save_dir.mkdir()

    # -----------------
    # CASE K=4 (save all 4 models)
    # multiple checkpoints within the same epoch
    checkpoint_callback = ModelCheckpoint(save_dir, save_top_k=4, verbose=1)
    checkpoint_callback.save_function = mock_save_function
    trainer = Trainer()

    # emulate the callback's calls during training
    for loss in losses:
        trainer.current_epoch = 0
        trainer.callback_metrics = {'val_loss': loss}
        checkpoint_callback.on_validation_end(trainer, trainer.get_model())

    file_lists = set(os.listdir(save_dir))

    assert len(file_lists) == 4, 'Should save all 4 models when save_top_k=4 within the same epoch'

    save_dir = tmp_path / "6"
    save_dir.mkdir()

    # -----------------
    # CASE K=3 (save the 2nd, 3rd, 4th models, 0-indexed)
    # multiple checkpoints within the same epoch
    checkpoint_callback = ModelCheckpoint(save_dir, save_top_k=3, verbose=1)
    checkpoint_callback.save_function = mock_save_function
    trainer = Trainer()

    # emulate the callback's calls during training
    for loss in losses:
        trainer.current_epoch = 0
        trainer.callback_metrics = {'val_loss': loss}
        checkpoint_callback.on_validation_end(trainer, trainer.get_model())

    file_lists = set(os.listdir(save_dir))

    assert len(file_lists) == 3, 'Should save 3 models when save_top_k=3'
    assert '_ckpt_epoch_0_v2.ckpt' in file_lists
    assert '_ckpt_epoch_0_v1.ckpt' in file_lists
    assert '_ckpt_epoch_0.ckpt' in file_lists