def test_dm_init_from_argparse_args(tmpdir):
    """A BoringDataModule can be built from parsed argparse arguments."""
    data_path = str(tmpdir)
    parser = BoringDataModule.add_argparse_args(ArgumentParser())
    args = parser.parse_args(['--data_dir', data_path])
    dm = BoringDataModule.from_argparse_args(args)
    dm.prepare_data()
    dm.setup()
    # the CLI value must round-trip onto the datamodule
    assert dm.data_dir == args.data_dir
    assert args.data_dir == data_path
def test_can_prepare_data(local_rank, node_rank):
    """Exercise ``DataConnector.can_prepare_data`` across rank configurations.

    ``local_rank`` and ``node_rank`` are mock objects (rank accessors patched
    by decorators outside this view); setting their ``return_value`` simulates
    which process/node the trainer believes it is running on.
    """
    dm = BoringDataModule()
    trainer = Trainer()
    trainer.datamodule = dm

    # 1 no DM
    # prepare_data_per_node = True
    # local rank = 0   (True)
    trainer.prepare_data_per_node = True

    local_rank.return_value = 0
    assert trainer.local_rank == 0
    assert trainer.data_connector.can_prepare_data()

    # local rank = 1   (False)
    local_rank.return_value = 1
    assert trainer.local_rank == 1
    assert not trainer.data_connector.can_prepare_data()

    # prepare_data_per_node = False (prepare across all nodes)
    # global rank = 0   (True)
    trainer.prepare_data_per_node = False
    node_rank.return_value = 0
    local_rank.return_value = 0
    assert trainer.data_connector.can_prepare_data()

    # global rank = 1   (False)
    node_rank.return_value = 1
    local_rank.return_value = 0
    assert not trainer.data_connector.can_prepare_data()
    # non-zero local rank also blocks preparation when preparing per cluster
    node_rank.return_value = 0
    local_rank.return_value = 1
    assert not trainer.data_connector.can_prepare_data()

    # 2 dm
    # prepar per node = True
    # local rank = 0 (True)
    trainer.prepare_data_per_node = True
    local_rank.return_value = 0

    # is_overridden prepare data = True
    # has been called
    # False
    dm._has_prepared_data = True
    assert not trainer.data_connector.can_prepare_data()

    # has not been called
    # True
    dm._has_prepared_data = False
    assert trainer.data_connector.can_prepare_data()

    # is_overridden prepare data = False
    # True
    dm.prepare_data = None
    assert trainer.data_connector.can_prepare_data()
# Example #3
def test_v1_5_0_datamodule_setter():
    """Assigning ``LightningModule.datamodule`` is silent; reading it warns."""
    boring_model = BoringModel()
    boring_dm = BoringDataModule()
    # setter path: no deprecation expected
    with no_deprecated_call(match="The `LightningModule.datamodule`"):
        boring_model.datamodule = boring_dm
    # getter path: deprecation warning expected
    with pytest.deprecated_call(match="The `LightningModule.datamodule`"):
        _ = boring_model.datamodule
def test_trainer_attached_to_dm(tmpdir):
    """After fit and test, the datamodule holds a reference to the trainer."""
    reset_seed()

    model = BoringModel()
    dm = BoringDataModule()

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=2,
        limit_val_batches=2,
        limit_test_batches=2,
        weights_summary=None,
        deterministic=True,
    )

    # fit attaches the trainer to the datamodule
    trainer.fit(model, dm)
    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
    assert dm.trainer is not None

    # testing keeps the attachment in place
    result = trainer.test(datamodule=dm)[0]
    assert dm.trainer is not None
# Example #5
def test_can_prepare_data(local_rank, node_rank):
    """Verify ``DataConnector.prepare_data`` runs only on the appropriate ranks.

    ``local_rank``/``node_rank`` are mocked rank accessors (patched outside this
    view). ``dm.random_full`` is reset to ``None`` before each call; it becoming
    non-None afterwards shows that ``prepare_data`` actually executed
    (presumably ``BoringDataModule.prepare_data`` populates it — confirm).
    """
    dm = BoringDataModule()
    trainer = Trainer()
    trainer.datamodule = dm

    # 1 no DM
    # prepare_data_per_node = True
    # local rank = 0   (True)
    dm.random_full = None
    local_rank.return_value = 0
    assert trainer.local_rank == 0

    trainer._data_connector.prepare_data()
    assert dm.random_full is not None

    # local rank = 1   (False)
    dm.random_full = None
    local_rank.return_value = 1
    assert trainer.local_rank == 1

    trainer._data_connector.prepare_data()
    assert dm.random_full is None

    # prepare_data_per_node = False (prepare across all nodes)
    # global rank = 0   (True)
    dm.random_full = None
    dm.prepare_data_per_node = False
    node_rank.return_value = 0
    local_rank.return_value = 0

    trainer._data_connector.prepare_data()
    assert dm.random_full is not None

    # global rank = 1   (False)
    dm.random_full = None
    node_rank.return_value = 1
    local_rank.return_value = 0

    trainer._data_connector.prepare_data()
    assert dm.random_full is None

    # non-zero local rank must also skip preparation in per-cluster mode
    node_rank.return_value = 0
    local_rank.return_value = 1

    trainer._data_connector.prepare_data()
    assert dm.random_full is None

    # 2 dm
    # prepar per node = True
    # local rank = 0 (True)
    dm.prepare_data_per_node = True
    local_rank.return_value = 0

    with mock.patch.object(trainer.datamodule, "prepare_data") as dm_mock:
        # is_overridden prepare data = True
        trainer._data_connector.prepare_data()
        dm_mock.assert_called_once()
def test_v1_5_0_datamodule_setter():
    """Reading ``LightningModule.datamodule`` records a deprecation warning."""
    model = BoringModel()
    dm = BoringDataModule()
    # assigning must not trigger the deprecation
    with no_deprecated_call(match="The `LightningModule.datamodule`"):
        model.datamodule = dm
    from pytorch_lightning.core.lightning import warning_cache
    warning_cache.clear()
    _ = model.datamodule
    # the getter must have placed the deprecation message into the cache
    assert any("The `LightningModule.datamodule`" in msg for msg in warning_cache)
# Example #7
def test_is_overridden():
    """Cover ``is_overridden`` edge cases, normal usage, and wrapper support.

    Verified behaviors: ``None`` instance and unknown methods return False;
    a non-Lightning parent raises; methods absent from the parent raise;
    ``functools.wraps``-wrapped, ``Mock(wraps=...)``-wrapped and
    ``functools.partial``-wrapped methods are still detected correctly.
    """
    model = BoringModel()
    datamodule = BoringDataModule()

    # edge cases
    assert not is_overridden("whatever", None)
    with pytest.raises(ValueError, match="Expected a parent"):
        is_overridden("whatever", object())
    assert not is_overridden("whatever", model)
    assert not is_overridden("whatever", model, parent=LightningDataModule)

    class TestModel(BoringModel):
        def foo(self):
            pass

        def bar(self):
            return 1

    # `foo`/`bar` exist only on the subclass, not the parent
    with pytest.raises(ValueError, match="The parent should define the method"):
        is_overridden("foo", TestModel())

    # normal usage
    assert is_overridden("training_step", model)
    assert is_overridden("train_dataloader", datamodule)

    class WrappedModel(TestModel):
        # wraps the bound methods at instance-construction time so that the
        # instance attributes shadow the class methods
        def __new__(cls, *args, **kwargs):
            obj = super().__new__(cls)
            obj.foo = cls.wrap(obj.foo)
            obj.bar = cls.wrap(obj.bar)
            return obj

        @staticmethod
        def wrap(fn):
            @wraps(fn)
            def wrapper():
                fn()

            return wrapper

        def bar(self):
            return 2

    # `functools.wraps()` support
    assert not is_overridden("foo", WrappedModel(), parent=TestModel)
    assert is_overridden("bar", WrappedModel(), parent=TestModel)

    # `Mock` support
    mock = Mock(spec=BoringModel, wraps=model)
    assert is_overridden("training_step", mock)
    mock = Mock(spec=BoringDataModule, wraps=datamodule)
    assert is_overridden("train_dataloader", mock)

    # `partial` support
    model.training_step = partial(model.training_step)
    assert is_overridden("training_step", model)
def test_data_hooks_called_with_stage_kwarg(tmpdir):
    """``setup`` accepts ``stage`` as a keyword and flips only that stage's flag."""
    dm = BoringDataModule()
    dm.prepare_data()
    assert dm.has_prepared_data

    dm.setup(stage='fit')
    assert dm.has_setup_fit
    assert not dm.has_setup_test

    dm.setup(stage='test')
    assert dm.has_setup_fit
    assert dm.has_setup_test
# Example #9
def test_inconsistent_prepare_data_per_node(tmpdir):
    """Mismatched `prepare_data_per_node` between Trainer and DataModule raises."""
    with pytest.raises(
            MisconfigurationException,
            match="Inconsistent settings found for `prepare_data_per_node`."):
        trainer = Trainer(prepare_data_per_node=False)
        trainer.model = BoringModel()
        trainer.datamodule = BoringDataModule()
        trainer.data_connector.prepare_data()
def test_test_loop_only(tmpdir):
    """The test loop alone runs on a fresh model with a datamodule."""
    reset_seed()

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        weights_summary=None,
    )
    trainer.test(BoringModel(), datamodule=BoringDataModule())
def test_data_hooks_called():
    """Each lifecycle flag flips exactly when its hook runs."""
    dm = BoringDataModule()
    stages = ('fit', 'test', 'validate', 'predict')

    # fresh datamodule: nothing prepared, set up, or torn down
    assert not dm.has_prepared_data
    for stage in stages:
        assert not getattr(dm, f'has_setup_{stage}')
        assert not getattr(dm, f'has_teardown_{stage}')

    # prepare_data flips only the prepared flag
    dm.prepare_data()
    assert dm.has_prepared_data
    for stage in stages:
        assert not getattr(dm, f'has_setup_{stage}')
        assert not getattr(dm, f'has_teardown_{stage}')

    # stage-less setup covers fit/test/validate but not predict
    dm.setup()
    assert dm.has_prepared_data
    for stage in ('fit', 'test', 'validate'):
        assert getattr(dm, f'has_setup_{stage}')
    assert not dm.has_setup_predict
    for stage in stages:
        assert not getattr(dm, f'has_teardown_{stage}')

    # stage-less teardown mirrors setup: fit/test/validate but not predict
    dm.teardown()
    assert dm.has_prepared_data
    for stage in ('fit', 'test', 'validate'):
        assert getattr(dm, f'has_setup_{stage}')
        assert getattr(dm, f'has_teardown_{stage}')
    assert not dm.has_setup_predict
    assert not dm.has_teardown_predict
# Example #12
def test_is_overridden():
    """Cover ``is_overridden`` edge cases, wrapper support, and patched loaders.

    Verified behaviors: ``None`` instance and unknown methods return False;
    a non-Lightning parent raises; methods absent from the parent raise;
    ``Mock(wraps=...)`` and ``functools.partial`` wrappers are detected; and a
    ``train_dataloader`` replaced by the trainer's loader patching still counts
    as overridden from within ``on_fit_start``.
    """
    model = BoringModel()
    datamodule = BoringDataModule()

    # edge cases
    assert not is_overridden("whatever", None)
    with pytest.raises(ValueError, match="Expected a parent"):
        is_overridden("whatever", object())
    assert not is_overridden("whatever", model)
    assert not is_overridden("whatever", model, parent=LightningDataModule)

    class TestModel(BoringModel):

        def foo(self):
            pass

    # `foo` exists only on the subclass, not the parent
    with pytest.raises(ValueError, match="The parent should define the method"):
        is_overridden("foo", TestModel())

    # normal usage
    assert is_overridden("training_step", model)
    assert is_overridden("train_dataloader", datamodule)

    # `Mock` support
    mock = Mock(spec=BoringModel, wraps=model)
    assert is_overridden("training_step", mock)
    mock = Mock(spec=BoringDataModule, wraps=datamodule)
    assert is_overridden("train_dataloader", mock)

    # `partial` support
    model.training_step = partial(model.training_step)
    assert is_overridden("training_step", model)

    # `_PatchDataLoader.patch_loader_code` support
    class TestModel(BoringModel):

        def on_fit_start(self):
            assert is_overridden("train_dataloader", self)
            self.on_fit_start_called = True

    model = TestModel()
    trainer = Trainer(fast_dev_run=1)
    trainer.fit(model, train_dataloader=model.train_dataloader())
    assert model.on_fit_start_called
def test_data_hooks_called_verbose(tmpdir):
    """Flags flip in order: prepare, then setup('fit'), then setup('test')."""
    dm = BoringDataModule()

    def flags():
        # snapshot of (prepared, setup_fit, setup_test)
        return (dm.has_prepared_data, dm.has_setup_fit, dm.has_setup_test)

    assert flags() == (False, False, False)

    dm.prepare_data()
    assert flags() == (True, False, False)

    dm.setup('fit')
    assert flags() == (True, True, False)

    dm.setup('test')
    assert flags() == (True, True, True)
# Example #14
def test_v1_6_0_datamodule_lifecycle_properties(tmpdir):
    """Every lifecycle flag property emits its v1.4 deprecation warning."""
    dm = BoringDataModule()
    deprecated_properties = (
        "has_prepared_data",
        "has_setup_fit",
        "has_setup_validate",
        "has_setup_test",
        "has_setup_predict",
        "has_teardown_fit",
        "has_teardown_validate",
        "has_teardown_test",
        "has_teardown_predict",
    )
    for prop in deprecated_properties:
        with pytest.deprecated_call(match=rf"DataModule property `{prop}` was deprecated in v1.4"):
            getattr(dm, prop)
def test_helper_boringdatamodule_with_verbose_setup():
    """Run prepare plus explicit per-stage setup calls without error."""
    dm = BoringDataModule()
    dm.prepare_data()
    for stage in ('fit', 'test'):
        dm.setup(stage)
# Example #16
def test_write_predictions(tmpdir, option: int, do_train: bool, gpus: int):
    """Exercise ``write_prediction``/``write_prediction_dict`` payload variants.

    ``option`` (parametrized outside this view) selects which payload type the
    custom ``test_step`` writes: tensors, mismatched-length tensors,
    multi-dimensional tensors, str/int/nested/dict lists, or a whole dict.
    After fit (optional) and test, the prediction file must exist and contain
    one entry per test sample.
    """

    class CustomBoringModel(BoringModel):

        def test_step(self, batch, batch_idx, optimizer_idx=None):
            output = self(batch)
            test_loss = self.loss(batch, output)
            self.log('test_loss', test_loss)

            # build one payload of each supported type for this batch
            batch_size = batch.size(0)
            lst_of_str = [random.choice(['dog', 'cat']) for i in range(batch_size)]
            lst_of_int = [random.randint(500, 1000) for i in range(batch_size)]
            lst_of_lst = [[x] for x in lst_of_int]
            lst_of_dict = [{k: v} for k, v in zip(lst_of_str, lst_of_int)]

            prediction_file = getattr(self, 'prediction_file', 'predictions.pt')

            # global sample indices for this batch
            lazy_ids = torch.arange(batch_idx * batch_size, batch_idx * batch_size + batch_size)

            # Base
            if option == 0:
                self.write_prediction('idxs', lazy_ids, prediction_file)
                self.write_prediction('preds', output, prediction_file)

            # Check mismatching tensor len
            elif option == 1:
                self.write_prediction('idxs', torch.cat((lazy_ids, lazy_ids)), prediction_file)
                self.write_prediction('preds', output, prediction_file)

            # write multi-dimension
            elif option == 2:
                self.write_prediction('idxs', lazy_ids, prediction_file)
                self.write_prediction('preds', output, prediction_file)
                self.write_prediction('x', batch, prediction_file)

            # write str list
            elif option == 3:
                self.write_prediction('idxs', lazy_ids, prediction_file)
                self.write_prediction('vals', lst_of_str, prediction_file)

            # write int list
            elif option == 4:
                self.write_prediction('idxs', lazy_ids, prediction_file)
                self.write_prediction('vals', lst_of_int, prediction_file)

            # write nested list
            elif option == 5:
                self.write_prediction('idxs', lazy_ids, prediction_file)
                self.write_prediction('vals', lst_of_lst, prediction_file)

            # write dict list
            elif option == 6:
                self.write_prediction('idxs', lazy_ids, prediction_file)
                self.write_prediction('vals', lst_of_dict, prediction_file)

            elif option == 7:
                self.write_prediction_dict({'idxs': lazy_ids, 'preds': output}, prediction_file)

    prediction_file = Path(tmpdir) / 'predictions.pt'

    dm = BoringDataModule()
    model = CustomBoringModel()
    # disable epoch-end aggregation so test_step results are written as-is
    model.test_epoch_end = None
    model.prediction_file = prediction_file.as_posix()

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=3,
        weights_summary=None,
        deterministic=True,
        gpus=gpus,
    )

    # Prediction file shouldn't exist yet because we haven't done anything
    assert not prediction_file.exists()

    if do_train:
        trainer.fit(model, dm)
        assert trainer.state.finished, f"Training failed with {trainer.state}"
        trainer.test(datamodule=dm)
    else:
        trainer.test(model, datamodule=dm)

    # check prediction file now exists and is of expected length
    assert prediction_file.exists()
    predictions = torch.load(prediction_file)
    assert len(predictions) == len(dm.random_test)
# Example #17
def test_data_hooks_called_verbose(use_kwarg):
    """Setup/teardown flags accumulate stage by stage for both call styles.

    ``use_kwarg`` toggles between ``setup(stage=...)`` and positional
    ``setup(...)``; the flag behavior must be identical either way.
    """
    dm = BoringDataModule()
    dm.prepare_data()
    stages = ("fit", "validate", "test", "predict")

    # nothing set up or torn down immediately after prepare_data
    for stage in stages:
        assert not getattr(dm, f"has_setup_{stage}")
        assert not getattr(dm, f"has_teardown_{stage}")

    # set up each stage in turn; flags for earlier stages stay set
    for idx, stage in enumerate(stages):
        if use_kwarg:
            dm.setup(stage=stage)
        else:
            dm.setup(stage)
        for done in stages[:idx + 1]:
            assert getattr(dm, f"has_setup_{done}")
        for pending in stages[idx + 1:]:
            assert not getattr(dm, f"has_setup_{pending}")

    # tear down each stage in turn; same accumulation pattern
    for idx, stage in enumerate(stages):
        if use_kwarg:
            dm.teardown(stage=stage)
        else:
            dm.teardown(stage)
        for done in stages[:idx + 1]:
            assert getattr(dm, f"has_teardown_{done}")
        for pending in stages[idx + 1:]:
            assert not getattr(dm, f"has_teardown_{pending}")
def test_dm_pickle_after_init(tmpdir):
    """A freshly constructed datamodule must be picklable."""
    datamodule = BoringDataModule()
    pickle.dumps(datamodule)
# Example #19
    def __init__(self):
        """Model with every dataloader hook explicitly disabled."""
        super().__init__()
        # replacing the hooks with None marks each dataloader as "not defined"
        self.train_dataloader = None
        self.val_dataloader = None
        self.test_dataloader = None
        self.predict_dataloader = None


@pytest.mark.parametrize(
    "instance,available",
    [
        (None, True),
        (BoringModel().train_dataloader(), True),
        (BoringModel(), True),
        (NoDataLoaderModel(), False),
        (BoringDataModule(), True),
    ],
)
def test_dataloader_source_available(instance, available):
    """Test the availability check for _DataLoaderSource.

    Per the parametrization: ``None``, a raw dataloader, and modules with a
    working ``train_dataloader`` all count as defined; only a module whose
    hook was disabled (``NoDataLoaderModel``) does not.
    """
    source = _DataLoaderSource(instance=instance, name="train_dataloader")
    assert source.is_defined() is available


def test_dataloader_source_direct_access():
    """A source wrapping a raw dataloader hands that object back verbatim."""
    loader = BoringModel().train_dataloader()
    source = _DataLoaderSource(instance=loader, name="any")
    # a plain dataloader is not a module but is still "defined"
    assert source.is_defined()
    assert not source.is_module()
    assert source.dataloader() is loader
def test_dm_add_argparse_args(tmpdir):
    """`add_argparse_args` exposes `--data_dir` on the parser."""
    expected = str(tmpdir)
    parser = BoringDataModule.add_argparse_args(ArgumentParser())
    args = parser.parse_args(['--data_dir', expected])
    assert args.data_dir == expected
# Example #21
 def reset_instances(self):
     """Clear the module-level warning cache and return fresh test fixtures.

     Returns:
         A ``(BoringDataModule, BoringModel, Trainer)`` triple of new instances.
     """
     warning_cache.clear()
     return BoringDataModule(), BoringModel(), Trainer()
def test_base_datamodule_with_verbose_setup(tmpdir):
    """Lifecycle with explicit per-stage setup calls completes cleanly."""
    datamodule = BoringDataModule()
    datamodule.prepare_data()
    for stage in ('fit', 'test'):
        datamodule.setup(stage)
def test_base_datamodule(tmpdir):
    """Smoke-test the plain prepare/setup lifecycle with no stage argument."""
    datamodule = BoringDataModule()
    datamodule.prepare_data()
    datamodule.setup()
def test_data_hooks_called_verbose(use_kwarg):
    """Setup/teardown flags accumulate per stage in both call styles."""
    dm = BoringDataModule()
    dm.prepare_data()
    stage_order = ('fit', 'validate', 'test', 'predict')

    # after prepare_data, no stage flag is set yet
    for stage in stage_order:
        assert not getattr(dm, 'has_setup_' + stage)
        assert not getattr(dm, 'has_teardown_' + stage)

    # setup each stage in order; earlier stages stay flagged, later ones don't
    for idx, stage in enumerate(stage_order):
        dm.setup(stage=stage) if use_kwarg else dm.setup(stage)
        for seen in stage_order[:idx + 1]:
            assert getattr(dm, 'has_setup_' + seen)
        for unseen in stage_order[idx + 1:]:
            assert not getattr(dm, 'has_setup_' + unseen)

    # teardown follows the same accumulation pattern
    for idx, stage in enumerate(stage_order):
        dm.teardown(stage=stage) if use_kwarg else dm.teardown(stage)
        for seen in stage_order[:idx + 1]:
            assert getattr(dm, 'has_teardown_' + seen)
        for unseen in stage_order[idx + 1:]:
            assert not getattr(dm, 'has_teardown_' + unseen)
# Example #25
def test_result_obj_predictions(tmpdir, test_option, do_train, gpus):
    """Exercise ``write_prediction``/``write_prediction_dict`` payload variants.

    ``test_option`` (parametrized outside this view) selects which payload type
    the custom ``test_step`` writes. After fit (optional) and test, the
    prediction file must exist and contain one entry per test sample.
    """

    class CustomBoringModel(BoringModel):

        def test_step(self, batch, batch_idx, optimizer_idx=None):
            output = self(batch)
            test_loss = self.loss(batch, output)
            self.log('test_loss', test_loss)

            # build one payload of each supported type for this batch
            batch_size = batch.size(0)
            lst_of_str = [random.choice(['dog', 'cat']) for i in range(batch_size)]
            lst_of_int = [random.randint(500, 1000) for i in range(batch_size)]
            lst_of_lst = [[x] for x in lst_of_int]
            lst_of_dict = [{k: v} for k, v in zip(lst_of_str, lst_of_int)]

            # This is passed in from pytest via parameterization
            option = getattr(self, 'test_option', 0)
            prediction_file = getattr(self, 'prediction_file', 'predictions.pt')

            # global sample indices for this batch
            lazy_ids = torch.arange(batch_idx * batch_size, batch_idx * batch_size + batch_size)

            # Base
            if option == 0:
                self.write_prediction('idxs', lazy_ids, prediction_file)
                self.write_prediction('preds', output, prediction_file)

            # Check mismatching tensor len
            elif option == 1:
                self.write_prediction('idxs', torch.cat((lazy_ids, lazy_ids)), prediction_file)
                self.write_prediction('preds', output, prediction_file)

            # write multi-dimension
            elif option == 2:
                self.write_prediction('idxs', lazy_ids, prediction_file)
                self.write_prediction('preds', output, prediction_file)
                self.write_prediction('x', batch, prediction_file)

            # write str list
            elif option == 3:
                self.write_prediction('idxs', lazy_ids, prediction_file)
                self.write_prediction('vals', lst_of_str, prediction_file)

            # write int list
            elif option == 4:
                self.write_prediction('idxs', lazy_ids, prediction_file)
                self.write_prediction('vals', lst_of_int, prediction_file)

            # write nested list
            elif option == 5:
                self.write_prediction('idxs', lazy_ids, prediction_file)
                self.write_prediction('vals', lst_of_lst, prediction_file)

            # write dict list
            elif option == 6:
                self.write_prediction('idxs', lazy_ids, prediction_file)
                self.write_prediction('vals', lst_of_dict, prediction_file)

            elif option == 7:
                self.write_prediction_dict({'idxs': lazy_ids, 'preds': output}, prediction_file)

    # NOTE(review): CustomBoringDataModule is defined here but never
    # instantiated below — `dm` uses the plain BoringDataModule. Confirm
    # whether this class is intentionally unused.
    class CustomBoringDataModule(BoringDataModule):

        def train_dataloader(self):
            return DataLoader(self.random_train, batch_size=4)

        def val_dataloader(self):
            return DataLoader(self.random_val, batch_size=4)

        def test_dataloader(self):
            return DataLoader(self.random_test, batch_size=4)

    tutils.reset_seed()
    prediction_file = Path(tmpdir) / 'predictions.pt'

    dm = BoringDataModule()
    model = CustomBoringModel()
    # disable every aggregation hook so test_step results are written as-is
    model.test_step_end = None
    model.test_epoch_end = None
    model.test_end = None

    model.test_option = test_option
    model.prediction_file = prediction_file.as_posix()

    # start from a clean slate if a previous run left a file behind
    if prediction_file.exists():
        prediction_file.unlink()

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=3,
        weights_summary=None,
        deterministic=True,
        gpus=gpus,
    )

    # Prediction file shouldn't exist yet because we haven't done anything
    assert not prediction_file.exists()

    if do_train:
        result = trainer.fit(model, dm)
        assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
        assert result
        result = trainer.test(datamodule=dm)
        # TODO: add end-to-end test
        # assert result[0]['test_loss'] < 0.6
    else:
        result = trainer.test(model, datamodule=dm)

    # check prediction file now exists and is of expected length
    assert prediction_file.exists()
    predictions = torch.load(prediction_file)
    assert len(predictions) == len(dm.random_test)
def test_helper_boringdatamodule():
    """Minimal lifecycle: prepare then stage-less setup."""
    datamodule = BoringDataModule()
    datamodule.prepare_data()
    datamodule.setup()