def test_overfit_batch_limits(tmpdir):
    """Check batch counts and shuffling under ``overfit_batches`` and ``limit_*_batches``.

    The test takes the first batch of an unshuffled loader over the training
    dataset as a reference, installs a *shuffled* train loader on the model,
    and then asserts that:

    * ``reset_train_dataloader`` reports the requested number of training
      batches (int and fractional ``overfit_batches``) and yields the
      reference batch first;
    * ``_reset_eval_dataloader`` for ``'val'`` and ``'test'`` reports the
      requested number of batches, uses a ``SequentialSampler`` and yields
      the reference batch first when ``overfit_batches`` is fractional;
    * ``limit_val_batches`` / ``limit_test_batches`` are honoured both as a
      fraction and as an int.
    """

    def _assert_first_batch_equals(loader, expected_x, expected_y):
        # The first batch drawn from ``loader`` must match the reference batch.
        (got_x, got_y) = next(iter(loader))
        assert torch.eq(expected_x, got_x).all()
        assert torch.eq(expected_y, got_y).all()

    # ------------------------------------------------------
    # Initial sampler types: train shuffled, val/test sequential
    # ------------------------------------------------------
    model = EvalModelTemplate()
    # NOTE(review): the result of this first call is discarded; presumably it
    # is redundant with the call right below - confirm before removing.
    model.train_dataloader()

    default_train_loader = model.train_dataloader()
    assert isinstance(default_train_loader.sampler, RandomSampler)
    assert isinstance(model.val_dataloader().sampler, SequentialSampler)
    assert isinstance(model.test_dataloader().sampler, SequentialSampler)

    # ------------------------------------------------------
    # Reference batch from an unshuffled loader over the train dataset
    # ------------------------------------------------------
    reference_loader = DataLoader(model.train_dataloader().dataset, shuffle=False)
    (ref_x, ref_y) = next(iter(reference_loader))

    # The loader actually handed to the trainer is shuffled on purpose.
    shuffled_train_loader = DataLoader(model.train_dataloader().dataset, shuffle=True)
    total_train_batches = len(shuffled_train_loader)
    expected_fraction_batches = int(0.11 * total_train_batches)

    # ------------------------------------------------------
    # Plain val / test loaders, then install all three on the model
    # ------------------------------------------------------
    plain_val_loader = DataLoader(model.val_dataloader().dataset, shuffle=False)
    plain_test_loader = DataLoader(model.test_dataloader().dataset, shuffle=False)

    model.train_dataloader = lambda: shuffled_train_loader
    model.val_dataloader = lambda: plain_val_loader
    model.test_dataloader = lambda: plain_test_loader

    # ------------------------------------------------------
    # Train loader: overfit_batches as int
    # ------------------------------------------------------
    trainer = Trainer(overfit_batches=4)
    trainer.reset_train_dataloader(model)
    assert trainer.num_training_batches == 4
    _assert_first_batch_equals(trainer.train_dataloader, ref_x, ref_y)

    # ------------------------------------------------------
    # Train loader: overfit_batches as fraction
    # ------------------------------------------------------
    trainer = Trainer(overfit_batches=0.11)
    trainer.reset_train_dataloader(model)
    # The trainer must not be using the shuffled loader object we installed.
    assert trainer.train_dataloader is not shuffled_train_loader
    assert trainer.num_training_batches == expected_fraction_batches
    _assert_first_batch_equals(trainer.train_dataloader, ref_x, ref_y)

    # ------------------------------------------------------
    # Same checks for the val and test splits
    # ------------------------------------------------------
    for split, split_loader in (('val', plain_val_loader), ('test', plain_test_loader)):
        # overfit_batches as fraction
        batch_counts, eval_loaders = Trainer(overfit_batches=0.11)._reset_eval_dataloader(model, split)
        assert batch_counts[0] == expected_fraction_batches

        # shuffle must have been turned off for the user
        assert isinstance(eval_loaders[0].sampler, SequentialSampler)
        _assert_first_batch_equals(eval_loaders[0], ref_x, ref_y)

        # overfit_batches as int
        for requested in (1, 5):
            batch_counts, eval_loaders = Trainer(overfit_batches=requested)._reset_eval_dataloader(model, split)
            assert batch_counts[0] == requested

        # limit_<split>_batches as fraction, then as int
        limit_kwarg = 'limit_{}_batches'.format(split)

        batch_counts, eval_loaders = Trainer(**{limit_kwarg: 0.1})._reset_eval_dataloader(model, split)
        assert batch_counts[0] == int(0.1 * len(split_loader))

        batch_counts, eval_loaders = Trainer(**{limit_kwarg: 10})._reset_eval_dataloader(model, split)
        assert batch_counts[0] == 10
def test_cpu_slurm_save_load(tmpdir):
    """Verify model save/load/checkpoint on CPU.

    Flow:

    1. Fit a model for one epoch and remember ``trainer.global_step``.
    2. Run one test batch through the trained model and keep the prediction.
    3. Write an HPC checkpoint via ``trainer.hpc_save`` and check the file exists.
    4. Build a fresh trainer/model pair (same logger version) and install an
       ``on_epoch_start`` hook that asserts the restored global step and the
       prediction both match the values recorded before saving.
    5. Call ``fit`` again so the hook gets a chance to run.

    NOTE(review): nothing here checks that the ``on_epoch_start`` hook was
    actually invoked; if it never fires the comparison is silently skipped -
    confirm that the second ``fit`` runs at least one epoch.
    """
    hparams = EvalModelTemplate.get_default_hparams()
    model = EvalModelTemplate(**hparams)

    # logger file to get meta
    logger = tutils.get_default_logger(tmpdir)
    version = logger.version

    # fit model
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        logger=logger,
        limit_train_batches=0.2,
        limit_val_batches=0.2,
        checkpoint_callback=ModelCheckpoint(tmpdir),
    )
    result = trainer.fit(model)
    real_global_step = trainer.global_step

    # training complete
    assert result == 1, 'cpu model failed to complete'

    # predict with trained model before saving
    dataloaders = model.test_dataloader()
    if not isinstance(dataloaders, list):
        dataloaders = [dataloaders]

    # Use the first batch of the last test dataloader; the targets are not
    # needed for the prediction comparison.
    x, _ = next(iter(dataloaders[-1]))
    x = x.view(x.size(0), -1)

    model.eval()
    pred_before_saving = model(x)

    # test HPC saving
    # simulate snapshot on slurm
    saved_filepath = trainer.hpc_save(trainer.weights_save_path, logger)
    assert os.path.exists(saved_filepath)

    # new logger file to get meta (same version as the first run)
    logger = tutils.get_default_logger(tmpdir, version=version)

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        logger=logger,
        checkpoint_callback=ModelCheckpoint(tmpdir),
    )
    model = EvalModelTemplate(**hparams)

    # set the epoch start hook so we can predict before the model does the full training
    def assert_pred_same():
        # separate asserts so a failure names the condition that broke
        assert trainer.global_step == real_global_step
        assert trainer.global_step > 0

        # predict with loaded model to make sure answers are the same
        trainer.model.eval()
        new_pred = trainer.model(x)
        assert torch.equal(pred_before_saving, new_pred)

    model.on_epoch_start = assert_pred_same

    # by calling fit again, we trigger training, loading weights from the cluster
    # and our hook to predict using current model before any more weight updates
    trainer.fit(model)
def test_dataloaders_with_limit_num_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches):
    """Verify num_batches for train, val & test dataloaders passed with batch limit as number.

    Args:
        tmpdir: directory used as the trainer's ``default_root_dir``.
        limit_train_batches: number of training batches the trainer is limited to.
        limit_val_batches: number of batches each validation dataloader is limited to.
        limit_test_batches: number of batches each test dataloader is limited to;
            a value above ``1e10`` is treated as "larger than any loader", in
            which case the full loader lengths are expected.

    ``PL_DEV_DEBUG`` is set to ``'1'`` for the duration of the test and put
    back to its previous state afterwards, so the flag does not leak into
    tests that run later in the same process.
    """
    previous_debug_flag = os.environ.get('PL_DEV_DEBUG')
    os.environ['PL_DEV_DEBUG'] = '1'
    try:
        model = EvalModelTemplate()
        model.val_dataloader = model.val_dataloader__multiple_mixed_length
        model.test_dataloader = model.test_dataloader__multiple_mixed_length
        model.validation_step = model.validation_step__multiple_dataloaders
        model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders
        model.test_step = model.test_step__multiple_dataloaders
        model.test_epoch_end = model.test_epoch_end__multiple_dataloaders

        # train, multiple val and multiple test loaders, each limited to a number of batches
        trainer = Trainer(
            default_root_dir=tmpdir,
            max_epochs=1,
            limit_train_batches=limit_train_batches,
            limit_val_batches=limit_val_batches,
            limit_test_batches=limit_test_batches,
        )
        trainer.fit(model)

        # -------------------------------------------
        # MAKE SURE THE TRAINER SET THE CORRECT VALUES
        # -------------------------------------------
        assert trainer.num_training_batches == limit_train_batches
        assert trainer.num_val_batches == [limit_val_batches] * len(trainer.val_dataloaders)
        trainer.test(ckpt_path=None)

        # when the limit is greater than the number of test batches it should be the num in loaders
        test_dataloader_lengths = [len(x) for x in model.test_dataloader()]
        if limit_test_batches > 1e10:
            assert trainer.num_test_batches == test_dataloader_lengths
        else:
            assert trainer.num_test_batches == [limit_test_batches] * len(trainer.test_dataloaders)

        # -------------------------------------------
        # make sure we actually saw the expected num of batches
        # -------------------------------------------
        num_val_dataloaders = len(model.val_dataloader())
        num_test_dataloaders = len(model.test_dataloader())

        # NOTE(review): both the val and the test batch-count checks are kept
        # under this guard - confirm that the test checks are meant to be
        # skipped as well when no training batches are requested.
        if limit_train_batches > 0:
            # make sure val batches are as expected
            seen_val_batches = trainer.dev_debugger.num_seen_val_check_batches
            assert len(seen_val_batches) == num_val_dataloaders
            for num_batches in seen_val_batches.values():
                assert num_batches == limit_val_batches

            # make sure test batches are as expected
            seen_test_batches = trainer.dev_debugger.num_seen_test_check_batches
            assert len(seen_test_batches) == num_test_dataloaders
            for dataloader_idx, num_batches in seen_test_batches.items():
                if limit_test_batches > 1e10:
                    assert num_batches == test_dataloader_lengths[dataloader_idx]
                else:
                    assert num_batches == limit_test_batches
    finally:
        # restore the debug flag exactly as it was before the test
        if previous_debug_flag is None:
            os.environ.pop('PL_DEV_DEBUG', None)
        else:
            os.environ['PL_DEV_DEBUG'] = previous_debug_flag