def test_dm_init_from_argparse_args(tmpdir):
    """A datamodule built via ``from_argparse_args`` picks up the CLI-provided data_dir."""
    parser = BoringDataModule.add_argparse_args(ArgumentParser())
    args = parser.parse_args(["--data_dir", str(tmpdir)])
    dm = BoringDataModule.from_argparse_args(args)
    dm.prepare_data()
    dm.setup()
    assert dm.data_dir == args.data_dir == str(tmpdir)
def test_can_prepare_data(local_rank, node_rank):
    """Verify ``data_connector.can_prepare_data()`` gating by rank and by the
    datamodule's ``_has_prepared_data`` flag.

    ``local_rank`` and ``node_rank`` are mocked rank getters (presumably
    ``unittest.mock`` patches supplied by decorators outside this view —
    confirm against the test's decorators).
    """
    dm = BoringDataModule()
    trainer = Trainer()
    trainer.datamodule = dm

    # 1 no DM
    # prepare_data_per_node = True
    # local rank = 0 (True)
    trainer.prepare_data_per_node = True
    local_rank.return_value = 0
    assert trainer.local_rank == 0
    assert trainer.data_connector.can_prepare_data()

    # local rank = 1 (False)
    local_rank.return_value = 1
    assert trainer.local_rank == 1
    assert not trainer.data_connector.can_prepare_data()

    # prepare_data_per_node = False (prepare across all nodes)
    # global rank = 0 (True)
    trainer.prepare_data_per_node = False
    node_rank.return_value = 0
    local_rank.return_value = 0
    assert trainer.data_connector.can_prepare_data()

    # global rank = 1 (False)
    node_rank.return_value = 1
    local_rank.return_value = 0
    assert not trainer.data_connector.can_prepare_data()

    node_rank.return_value = 0
    local_rank.return_value = 1
    assert not trainer.data_connector.can_prepare_data()

    # 2 dm
    # prepar per node = True
    # local rank = 0 (True)
    trainer.prepare_data_per_node = True
    local_rank.return_value = 0

    # is_overridden prepare data = True
    # has been called
    # False
    dm._has_prepared_data = True
    assert not trainer.data_connector.can_prepare_data()

    # has not been called
    # True
    dm._has_prepared_data = False
    assert trainer.data_connector.can_prepare_data()

    # is_overridden prepare data = False
    # True
    dm.prepare_data = None
    assert trainer.data_connector.can_prepare_data()
def test_v1_5_0_datamodule_setter():
    """Assigning ``.datamodule`` is silent; reading it emits the deprecation warning."""
    model = BoringModel()
    dm = BoringDataModule()

    # the setter itself must not warn
    with no_deprecated_call(match="The `LightningModule.datamodule`"):
        model.datamodule = dm

    # the getter is where the deprecation fires
    with pytest.deprecated_call(match="The `LightningModule.datamodule`"):
        _ = model.datamodule
def test_trainer_attached_to_dm(tmpdir):
    """The trainer reference on the datamodule survives both fit and test.

    Fix: the original assigned ``result = trainer.test(...)`` and then
    ``result = result[0]`` without ever using either value — the dead
    assignments are removed.
    """
    reset_seed()

    dm = BoringDataModule()
    model = BoringModel()

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=2,
        limit_val_batches=2,
        limit_test_batches=2,
        weights_summary=None,
        deterministic=True,
    )

    # fit model
    trainer.fit(model, dm)
    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
    assert dm.trainer is not None

    # test: the datamodule must still hold the trainer afterwards
    trainer.test(datamodule=dm)
    assert dm.trainer is not None
def test_can_prepare_data(local_rank, node_rank):
    """Verify ``_data_connector.prepare_data()`` only runs on the right ranks.

    ``local_rank`` and ``node_rank`` are mocked rank getters. Before each
    call ``dm.random_full`` is reset to ``None`` so a non-None value
    afterwards proves ``prepare_data`` actually executed.
    """
    dm = BoringDataModule()
    trainer = Trainer()
    trainer.datamodule = dm

    # 1 no DM
    # prepare_data_per_node = True
    # local rank = 0 (True)
    dm.random_full = None
    local_rank.return_value = 0
    assert trainer.local_rank == 0
    trainer._data_connector.prepare_data()
    assert dm.random_full is not None

    # local rank = 1 (False)
    dm.random_full = None
    local_rank.return_value = 1
    assert trainer.local_rank == 1
    trainer._data_connector.prepare_data()
    assert dm.random_full is None

    # prepare_data_per_node = False (prepare across all nodes)
    # global rank = 0 (True)
    dm.random_full = None
    dm.prepare_data_per_node = False
    node_rank.return_value = 0
    local_rank.return_value = 0
    trainer._data_connector.prepare_data()
    assert dm.random_full is not None

    # global rank = 1 (False)
    dm.random_full = None
    node_rank.return_value = 1
    local_rank.return_value = 0
    trainer._data_connector.prepare_data()
    assert dm.random_full is None

    # non-zero local rank on node 0 must also skip (random_full stays None)
    node_rank.return_value = 0
    local_rank.return_value = 1
    trainer._data_connector.prepare_data()
    assert dm.random_full is None

    # 2 dm
    # prepar per node = True
    # local rank = 0 (True)
    dm.prepare_data_per_node = True
    local_rank.return_value = 0
    with mock.patch.object(trainer.datamodule, "prepare_data") as dm_mock:
        # is_overridden prepare data = True
        trainer._data_connector.prepare_data()
        dm_mock.assert_called_once()
def test_v1_5_0_datamodule_setter():
    """Setting ``.datamodule`` stays silent; reading it records the deprecation
    message in the module-level warning cache."""
    model = BoringModel()
    dm = BoringDataModule()

    # assignment must not trigger the deprecation
    with no_deprecated_call(match="The `LightningModule.datamodule`"):
        model.datamodule = dm

    from pytorch_lightning.core.lightning import warning_cache

    # start from a clean cache so only the read below can populate it
    warning_cache.clear()
    _ = model.datamodule
    assert any("The `LightningModule.datamodule`" in warning for warning in warning_cache)
def test_is_overridden():
    """``is_overridden`` detects user overrides for models and datamodules,
    including wrapped (``functools.wraps``), mocked and ``partial``-ed methods,
    and raises on methods the parent does not define."""
    model = BoringModel()
    datamodule = BoringDataModule()

    # edge cases
    assert not is_overridden("whatever", None)
    with pytest.raises(ValueError, match="Expected a parent"):
        is_overridden("whatever", object())
    assert not is_overridden("whatever", model)
    assert not is_overridden("whatever", model, parent=LightningDataModule)

    class TestModel(BoringModel):
        def foo(self):
            pass

        def bar(self):
            return 1

    # "foo" exists on the child but not on the parent -> error
    with pytest.raises(ValueError, match="The parent should define the method"):
        is_overridden("foo", TestModel())

    # normal usage
    assert is_overridden("training_step", model)
    assert is_overridden("train_dataloader", datamodule)

    class WrappedModel(TestModel):
        # wraps the bound methods at construction time to simulate decorators
        def __new__(cls, *args, **kwargs):
            obj = super().__new__(cls)
            obj.foo = cls.wrap(obj.foo)
            obj.bar = cls.wrap(obj.bar)
            return obj

        @staticmethod
        def wrap(fn):
            @wraps(fn)
            def wrapper():
                fn()

            return wrapper

        def bar(self):
            return 2

    # `functools.wraps()` support: "foo" is wrapped but not overridden,
    # "bar" is both wrapped and overridden
    assert not is_overridden("foo", WrappedModel(), parent=TestModel)
    assert is_overridden("bar", WrappedModel(), parent=TestModel)

    # `Mock` support
    mock = Mock(spec=BoringModel, wraps=model)
    assert is_overridden("training_step", mock)
    mock = Mock(spec=BoringDataModule, wraps=datamodule)
    assert is_overridden("train_dataloader", mock)

    # `partial` support
    model.training_step = partial(model.training_step)
    assert is_overridden("training_step", model)
def test_data_hooks_called_with_stage_kwarg(tmpdir):
    """``setup`` accepts the stage as a keyword and flips only that stage's flag."""
    dm = BoringDataModule()

    dm.prepare_data()
    assert dm.has_prepared_data is True

    dm.setup(stage='fit')
    assert dm.has_setup_fit is True
    assert dm.has_setup_test is False

    dm.setup(stage='test')
    assert dm.has_setup_fit is True
    assert dm.has_setup_test is True
def test_inconsistent_prepare_data_per_node(tmpdir):
    """``prepare_data`` raises when Trainer and DataModule disagree on
    ``prepare_data_per_node``.

    Fix: only the call under test is inside the ``pytest.raises`` context.
    The original wrapped the entire setup in the context, so an unrelated
    exception from the constructors could have masqueraded as a pass.
    """
    model = BoringModel()
    dm = BoringDataModule()
    trainer = Trainer(prepare_data_per_node=False)
    trainer.model = model
    trainer.datamodule = dm

    with pytest.raises(
        MisconfigurationException, match="Inconsistent settings found for `prepare_data_per_node`."
    ):
        trainer.data_connector.prepare_data()
def test_test_loop_only(tmpdir):
    """The test loop can run standalone, without a preceding fit."""
    reset_seed()

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        weights_summary=None,
    )
    trainer.test(BoringModel(), datamodule=BoringDataModule())
def test_data_hooks_called():
    """Lifecycle flags start False and flip as prepare_data/setup/teardown run.

    The stage-less ``setup()``/``teardown()`` calls cover fit, test and
    validate but never touch the predict flags.
    """
    dm = BoringDataModule()

    def check_flags(prepared, setup_done, teardown_done):
        # the predict flags are never set by the argument-less calls below
        assert bool(dm.has_prepared_data) == prepared
        for flag in (dm.has_setup_fit, dm.has_setup_test, dm.has_setup_validate):
            assert bool(flag) == setup_done
        assert not dm.has_setup_predict
        for flag in (dm.has_teardown_fit, dm.has_teardown_test, dm.has_teardown_validate):
            assert bool(flag) == teardown_done
        assert not dm.has_teardown_predict

    check_flags(prepared=False, setup_done=False, teardown_done=False)
    dm.prepare_data()
    check_flags(prepared=True, setup_done=False, teardown_done=False)
    dm.setup()
    check_flags(prepared=True, setup_done=True, teardown_done=False)
    dm.teardown()
    check_flags(prepared=True, setup_done=True, teardown_done=True)
def test_is_overridden():
    """``is_overridden`` detects user overrides for models and datamodules,
    including mocked and ``partial``-ed methods, and keeps working after the
    trainer patches ``train_dataloader`` via ``_PatchDataLoader``."""
    model = BoringModel()
    datamodule = BoringDataModule()

    # edge cases
    assert not is_overridden("whatever", None)
    with pytest.raises(ValueError, match="Expected a parent"):
        is_overridden("whatever", object())
    assert not is_overridden("whatever", model)
    assert not is_overridden("whatever", model, parent=LightningDataModule)

    class TestModel(BoringModel):
        def foo(self):
            pass

    # "foo" exists on the child but not on the parent -> error
    with pytest.raises(ValueError, match="The parent should define the method"):
        is_overridden("foo", TestModel())

    # normal usage
    assert is_overridden("training_step", model)
    assert is_overridden("train_dataloader", datamodule)

    # `Mock` support
    mock = Mock(spec=BoringModel, wraps=model)
    assert is_overridden("training_step", mock)
    mock = Mock(spec=BoringDataModule, wraps=datamodule)
    assert is_overridden("train_dataloader", mock)

    # `partial` support
    model.training_step = partial(model.training_step)
    assert is_overridden("training_step", model)

    # `_PatchDataLoader.patch_loader_code` support: the check runs from inside
    # on_fit_start, after the trainer has replaced train_dataloader
    class TestModel(BoringModel):
        def on_fit_start(self):
            assert is_overridden("train_dataloader", self)
            self.on_fit_start_called = True

    model = TestModel()
    trainer = Trainer(fast_dev_run=1)
    trainer.fit(model, train_dataloader=model.train_dataloader())
    assert model.on_fit_start_called
def test_data_hooks_called_verbose(tmpdir):
    """prepare_data/setup flip their flags one hook at a time, in order."""
    dm = BoringDataModule()

    def flags():
        # snapshot of (prepared, setup_fit, setup_test)
        return bool(dm.has_prepared_data), bool(dm.has_setup_fit), bool(dm.has_setup_test)

    assert flags() == (False, False, False)
    dm.prepare_data()
    assert flags() == (True, False, False)
    dm.setup('fit')
    assert flags() == (True, True, False)
    dm.setup('test')
    assert flags() == (True, True, True)
def test_v1_6_0_datamodule_lifecycle_properties(tmpdir):
    """Every deprecated lifecycle property emits the v1.4 deprecation warning on access."""
    dm = BoringDataModule()
    deprecated_properties = (
        "has_prepared_data",
        "has_setup_fit",
        "has_setup_validate",
        "has_setup_test",
        "has_setup_predict",
        "has_teardown_fit",
        "has_teardown_validate",
        "has_teardown_test",
        "has_teardown_predict",
    )
    for name in deprecated_properties:
        with pytest.deprecated_call(match=rf"DataModule property `{name}` was deprecated in v1.4"):
            getattr(dm, name)
def test_helper_boringdatamodule_with_verbose_setup():
    """Smoke test: prepare_data plus per-stage setup calls run without error."""
    datamodule = BoringDataModule()
    datamodule.prepare_data()
    for stage in ('fit', 'test'):
        datamodule.setup(stage)
def test_write_predictions(tmpdir, option: int, do_train: bool, gpus: int):
    """Exercise ``write_prediction``/``write_prediction_dict`` across payload types.

    ``option`` selects which payload shape the test step writes, ``do_train``
    chooses fit-then-test versus test-only, and ``gpus`` is forwarded to the
    Trainer. Afterwards the prediction file must exist and contain one entry
    per test sample.
    """
    class CustomBoringModel(BoringModel):
        def test_step(self, batch, batch_idx, optimizer_idx=None):
            output = self(batch)
            test_loss = self.loss(batch, output)
            self.log('test_loss', test_loss)

            # candidate payloads of various shapes, one entry per sample
            batch_size = batch.size(0)
            lst_of_str = [random.choice(['dog', 'cat']) for i in range(batch_size)]
            lst_of_int = [random.randint(500, 1000) for i in range(batch_size)]
            lst_of_lst = [[x] for x in lst_of_int]
            lst_of_dict = [{k: v} for k, v in zip(lst_of_str, lst_of_int)]

            prediction_file = getattr(self, 'prediction_file', 'predictions.pt')

            # ids matching this batch's position in the dataset
            lazy_ids = torch.arange(batch_idx * batch_size, batch_idx * batch_size + batch_size)

            # Base
            if option == 0:
                self.write_prediction('idxs', lazy_ids, prediction_file)
                self.write_prediction('preds', output, prediction_file)

            # Check mismatching tensor len
            elif option == 1:
                self.write_prediction('idxs', torch.cat((lazy_ids, lazy_ids)), prediction_file)
                self.write_prediction('preds', output, prediction_file)

            # write multi-dimension
            elif option == 2:
                self.write_prediction('idxs', lazy_ids, prediction_file)
                self.write_prediction('preds', output, prediction_file)
                self.write_prediction('x', batch, prediction_file)

            # write str list
            elif option == 3:
                self.write_prediction('idxs', lazy_ids, prediction_file)
                self.write_prediction('vals', lst_of_str, prediction_file)

            # write int list
            elif option == 4:
                self.write_prediction('idxs', lazy_ids, prediction_file)
                self.write_prediction('vals', lst_of_int, prediction_file)

            # write nested list
            elif option == 5:
                self.write_prediction('idxs', lazy_ids, prediction_file)
                self.write_prediction('vals', lst_of_lst, prediction_file)

            # write dict list
            elif option == 6:
                self.write_prediction('idxs', lazy_ids, prediction_file)
                self.write_prediction('vals', lst_of_dict, prediction_file)

            elif option == 7:
                self.write_prediction_dict({'idxs': lazy_ids, 'preds': output}, prediction_file)

    prediction_file = Path(tmpdir) / 'predictions.pt'

    dm = BoringDataModule()
    model = CustomBoringModel()
    # disable epoch-end aggregation so per-step writes are the only output
    model.test_epoch_end = None
    model.prediction_file = prediction_file.as_posix()

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=3,
        weights_summary=None,
        deterministic=True,
        gpus=gpus,
    )

    # Prediction file shouldn't exist yet because we haven't done anything
    assert not prediction_file.exists()

    if do_train:
        trainer.fit(model, dm)
        assert trainer.state.finished, f"Training failed with {trainer.state}"
        trainer.test(datamodule=dm)
    else:
        trainer.test(model, datamodule=dm)

    # check prediction file now exists and is of expected length
    assert prediction_file.exists()
    predictions = torch.load(prediction_file)
    assert len(predictions) == len(dm.random_test)
def test_data_hooks_called_verbose(use_kwarg):
    """Per-stage setup/teardown calls flip exactly one lifecycle flag at a time.

    Runs the stages in order (fit, validate, test, predict) and checks the
    flags accumulate, for both the keyword and the positional call form.
    """
    dm = BoringDataModule()
    dm.prepare_data()

    # nothing has been set up or torn down yet
    untouched = (
        "has_setup_fit", "has_setup_test", "has_setup_validate", "has_setup_predict",
        "has_teardown_fit", "has_teardown_test", "has_teardown_validate", "has_teardown_predict",
    )
    assert not any(getattr(dm, name) for name in untouched)

    stages = ("fit", "validate", "test", "predict")

    def call(hook, stage):
        # exercise whichever calling convention the parametrization asks for
        if use_kwarg:
            hook(stage=stage)
        else:
            hook(stage)

    for done, stage in enumerate(stages, start=1):
        call(dm.setup, stage)
        observed = [bool(getattr(dm, f"has_setup_{s}")) for s in stages]
        assert observed == [True] * done + [False] * (len(stages) - done)

    for done, stage in enumerate(stages, start=1):
        call(dm.teardown, stage)
        observed = [bool(getattr(dm, f"has_teardown_{s}")) for s in stages]
        assert observed == [True] * done + [False] * (len(stages) - done)
def test_dm_pickle_after_init(tmpdir):
    """A freshly constructed datamodule must be picklable."""
    pickle.dumps(BoringDataModule())
# NOTE(review): the `__init__` below appears to be a method of the
# `NoDataLoaderModel` class referenced in the parametrization further down;
# its `class` header falls outside this view — confirm before relocating.
    def __init__(self):
        super().__init__()
        # Shadow the dataloader hook methods with plain None attributes so the
        # instance reports no usable dataloaders.
        self.train_dataloader = None
        self.val_dataloader = None
        self.test_dataloader = None
        self.predict_dataloader = None


@pytest.mark.parametrize(
    "instance,available",
    [
        (None, True),
        (BoringModel().train_dataloader(), True),
        (BoringModel(), True),
        (NoDataLoaderModel(), False),
        (BoringDataModule(), True),
    ],
)
def test_dataloader_source_available(instance, available):
    """Test the availability check for _DataLoaderSource."""
    source = _DataLoaderSource(instance=instance, name="train_dataloader")
    assert source.is_defined() is available


def test_dataloader_source_direct_access():
    """Test requesting a dataloader when the source is already a dataloader."""
    dataloader = BoringModel().train_dataloader()
    source = _DataLoaderSource(instance=dataloader, name="any")
    assert not source.is_module()
    assert source.is_defined()
    assert source.dataloader() is dataloader
def test_dm_add_argparse_args(tmpdir):
    """``add_argparse_args`` exposes ``--data_dir`` on the parser."""
    parser = BoringDataModule.add_argparse_args(ArgumentParser())
    parsed = parser.parse_args(["--data_dir", str(tmpdir)])
    assert parsed.data_dir == str(tmpdir)
def reset_instances(self):
    """Empty the shared warning cache, then hand back fresh core objects.

    Returns a (datamodule, model, trainer) triple.
    """
    warning_cache.clear()
    datamodule = BoringDataModule()
    model = BoringModel()
    trainer = Trainer()
    return datamodule, model, trainer
def test_base_datamodule_with_verbose_setup(tmpdir):
    """Smoke test: prepare_data plus per-stage setup calls run without error."""
    datamodule = BoringDataModule()
    datamodule.prepare_data()
    for stage in ('fit', 'test'):
        datamodule.setup(stage)
def test_base_datamodule(tmpdir):
    """Smoke test: prepare_data followed by a stage-less setup completes."""
    datamodule = BoringDataModule()
    datamodule.prepare_data()
    datamodule.setup()
def test_data_hooks_called_verbose(use_kwarg):
    """Stage-specific setup/teardown flags accumulate one stage at a time,
    for both the keyword and the positional call forms."""
    dm = BoringDataModule()
    dm.prepare_data()

    stages = ('fit', 'validate', 'test', 'predict')

    # no lifecycle flag may be set before any setup/teardown call
    assert not any(
        getattr(dm, f"has_{kind}_{stage}") for kind in ('setup', 'teardown') for stage in stages
    )

    completed = []
    for stage in stages:
        dm.setup(stage=stage) if use_kwarg else dm.setup(stage)
        completed.append(stage)
        for s in stages:
            assert bool(getattr(dm, f"has_setup_{s}")) == (s in completed)

    completed = []
    for stage in stages:
        dm.teardown(stage=stage) if use_kwarg else dm.teardown(stage)
        completed.append(stage)
        for s in stages:
            assert bool(getattr(dm, f"has_teardown_{s}")) == (s in completed)
def test_result_obj_predictions(tmpdir, test_option, do_train, gpus):
    """Exercise ``write_prediction``/``write_prediction_dict`` across payload
    types, with the option routed through a model attribute set from the
    parametrization (``test_option``)."""
    class CustomBoringModel(BoringModel):
        def test_step(self, batch, batch_idx, optimizer_idx=None):
            output = self(batch)
            test_loss = self.loss(batch, output)
            self.log('test_loss', test_loss)

            # candidate payloads of various shapes, one entry per sample
            batch_size = batch.size(0)
            lst_of_str = [random.choice(['dog', 'cat']) for i in range(batch_size)]
            lst_of_int = [random.randint(500, 1000) for i in range(batch_size)]
            lst_of_lst = [[x] for x in lst_of_int]
            lst_of_dict = [{k: v} for k, v in zip(lst_of_str, lst_of_int)]

            # This is passed in from pytest via parameterization
            option = getattr(self, 'test_option', 0)
            prediction_file = getattr(self, 'prediction_file', 'predictions.pt')

            # ids matching this batch's position in the dataset
            lazy_ids = torch.arange(batch_idx * batch_size, batch_idx * batch_size + batch_size)

            # Base
            if option == 0:
                self.write_prediction('idxs', lazy_ids, prediction_file)
                self.write_prediction('preds', output, prediction_file)

            # Check mismatching tensor len
            elif option == 1:
                self.write_prediction('idxs', torch.cat((lazy_ids, lazy_ids)), prediction_file)
                self.write_prediction('preds', output, prediction_file)

            # write multi-dimension
            elif option == 2:
                self.write_prediction('idxs', lazy_ids, prediction_file)
                self.write_prediction('preds', output, prediction_file)
                self.write_prediction('x', batch, prediction_file)

            # write str list
            elif option == 3:
                self.write_prediction('idxs', lazy_ids, prediction_file)
                self.write_prediction('vals', lst_of_str, prediction_file)

            # write int list
            elif option == 4:
                self.write_prediction('idxs', lazy_ids, prediction_file)
                self.write_prediction('vals', lst_of_int, prediction_file)

            # write nested list
            elif option == 5:
                self.write_prediction('idxs', lazy_ids, prediction_file)
                self.write_prediction('vals', lst_of_lst, prediction_file)

            # write dict list
            elif option == 6:
                self.write_prediction('idxs', lazy_ids, prediction_file)
                self.write_prediction('vals', lst_of_dict, prediction_file)

            elif option == 7:
                self.write_prediction_dict({'idxs': lazy_ids, 'preds': output}, prediction_file)

    # NOTE(review): defined but never instantiated below — `dm` is a plain
    # BoringDataModule; confirm whether this subclass was meant to be used.
    class CustomBoringDataModule(BoringDataModule):
        def train_dataloader(self):
            return DataLoader(self.random_train, batch_size=4)

        def val_dataloader(self):
            return DataLoader(self.random_val, batch_size=4)

        def test_dataloader(self):
            return DataLoader(self.random_test, batch_size=4)

    tutils.reset_seed()
    prediction_file = Path(tmpdir) / 'predictions.pt'

    dm = BoringDataModule()
    model = CustomBoringModel()
    # disable step-end/epoch-end aggregation so per-step writes are the only output
    model.test_step_end = None
    model.test_epoch_end = None
    model.test_end = None
    model.test_option = test_option
    model.prediction_file = prediction_file.as_posix()

    if prediction_file.exists():
        prediction_file.unlink()

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=3,
        weights_summary=None,
        deterministic=True,
        gpus=gpus,
    )

    # Prediction file shouldn't exist yet because we haven't done anything
    assert not prediction_file.exists()

    if do_train:
        result = trainer.fit(model, dm)
        assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
        assert result
        result = trainer.test(datamodule=dm)
        # TODO: add end-to-end test
        # assert result[0]['test_loss'] < 0.6
    else:
        result = trainer.test(model, datamodule=dm)

    # check prediction file now exists and is of expected length
    assert prediction_file.exists()
    predictions = torch.load(prediction_file)
    assert len(predictions) == len(dm.random_test)
def test_helper_boringdatamodule():
    """Smoke test: prepare_data followed by a stage-less setup completes."""
    datamodule = BoringDataModule()
    datamodule.prepare_data()
    datamodule.setup()