def build_flash_serve_model_component(model, serve_input, output, transform,
                                      transform_kwargs):
    # TODO: Resolve this hack
    data_module = DataModule(
        predict_input=serve_input,
        batch_size=1,
        transform=transform,
        transform_kwargs=transform_kwargs,
    )

    class MockTrainer(Trainer):
        def __init__(self):
            super().__init__()
            self.state.stage = RunningStage.PREDICTING

        @property
        def lightning_module(self):
            return model

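    # attach a stub trainer pinned to the PREDICTING stage so that
    # ``predict_dataloader()`` can be built and its collate function extracted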
    data_module.trainer = MockTrainer()
    dataloader = data_module.predict_dataloader()

    collate_fn = dataloader.collate_fn

    class FlashServeModelComponent(ModelComponent):
        def __init__(self, model):
            self.model = model
            self.model.eval()
            self.serve_input = serve_input
            self.on_after_batch_transfer = data_module.on_after_batch_transfer
            self.output_transform = getattr(model, "_output_transform",
                                            None) or OutputTransform()
            # TODO (@tchaton) Remove this hack
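            # (checks whether the bound ``transfer_batch_to_device`` uses the
            # newer 3-parameter form ``(batch, device, dataloader_idx)``)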
            self.extra_arguments = len(
                inspect.signature(
                    self.model.transfer_batch_to_device).parameters) == 3
            self.device = self.model.device

        @expose(
            inputs={
                "inputs":
                FlashInputs(_ServeInputProcessor(serve_input, collate_fn))
            },
            outputs={"outputs": FlashOutputs(output)},
        )
        def predict(self, inputs):
            with torch.no_grad():
                if self.extra_arguments:
                    inputs = self.model.transfer_batch_to_device(
                        inputs, self.device, 0)
                else:
                    inputs = self.model.transfer_batch_to_device(
                        inputs, self.device)
                inputs = self.on_after_batch_transfer(inputs, 0)
                preds = self.model.predict_step(inputs, 0)
                preds = self.output_transform(preds)
                return preds

    return FlashServeModelComponent(model)
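
# A hedged usage sketch: given a trained Flash ``model`` and matching
# ``serve_input``/``output``/``transform`` objects, the component built above
# could be mounted in a ``flash.core.serve.Composition``. The variable names
# here are illustrative, not defined by the function above.
#
# from flash.core.serve import Composition
#
# component = build_flash_serve_model_component(
#     model, serve_input, output, transform, transform_kwargs=None
# )
# composition = Composition(predict=component)
# composition.serve(host="0.0.0.0", port=8000)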
Example #2
def test_data_loaders_num_workers_to_0(tmpdir):
    """
    num_workers should be set to `0` internally for visualization and not for training.
    """

    datamodule = DataModule(train_dataset=range(10), num_workers=3)
    iterator = datamodule._reset_iterator(RunningStage.TRAINING)
    assert isinstance(iterator, torch.utils.data.dataloader._SingleProcessDataLoaderIter)
    iterator = iter(datamodule.train_dataloader())
    assert isinstance(iterator, torch.utils.data.dataloader._MultiProcessingDataLoaderIter)
    assert datamodule.num_workers == 3
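
# A standalone sketch of the distinction under test, using plain torch
# DataLoaders: ``num_workers=0`` yields a single-process iterator, while
# ``num_workers > 0`` yields a multiprocessing one.
from torch.utils.data import DataLoader

viz_iter = iter(DataLoader(range(10), num_workers=0))
assert type(viz_iter).__name__ == "_SingleProcessDataLoaderIter"
train_iter = iter(DataLoader(range(10), num_workers=3))
assert type(train_iter).__name__ == "_MultiProcessingDataLoaderIter"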
Example #3
def test_available_data_sources():
    preprocess = CustomPreprocess()

    assert DefaultDataSources.TENSORS in preprocess.available_data_sources()
    assert "test" in preprocess.available_data_sources()
    assert len(preprocess.available_data_sources()) == 3

    data_module = DataModule(preprocess=preprocess)

    assert DefaultDataSources.TENSORS in data_module.available_data_sources()
    assert "test" in data_module.available_data_sources()
    assert len(data_module.available_data_sources()) == 3
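
# A hedged sketch of what ``CustomPreprocess`` might look like (flash 0.3-era
# API; the NUMPY entry and the import paths are assumptions): data sources are
# registered in the ``Preprocess`` constructor, which is why exactly three
# appear in the assertions above.
from flash.core.data.data_source import DataSource, DefaultDataSources
from flash.core.data.process import Preprocess


class CustomPreprocess(Preprocess):
    def __init__(self):
        super().__init__(
            data_sources={
                DefaultDataSources.TENSORS: DataSource(),
                DefaultDataSources.NUMPY: DataSource(),
                "test": DataSource(),
            },
            default_data_source=DefaultDataSources.TENSORS,
        )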
Example #4
def test_datapipeline_transformations(tmpdir):

    datamodule = DataModule.from_data_source(
        "default", 1, 1, 1, 1, batch_size=2, num_workers=0, preprocess=TestPreprocessTransformations()
    )

    assert datamodule.train_dataloader().dataset[0] == (0, 1, 2, 3)
    batch = next(iter(datamodule.train_dataloader()))
    assert torch.equal(batch, tensor([[0, 1, 2, 3, 5], [0, 1, 2, 3, 5]]))

    assert datamodule.val_dataloader().dataset[0] == {'a': 0, 'b': 1}
    assert datamodule.val_dataloader().dataset[1] == {'a': 1, 'b': 2}
    with pytest.raises(MisconfigurationException, match="When ``to_tensor_transform``"):
        batch = next(iter(datamodule.val_dataloader()))

    datamodule = DataModule.from_data_source(
        "default", 1, 1, 1, 1, batch_size=2, num_workers=0, preprocess=TestPreprocessTransformations2()
    )
    batch = next(iter(datamodule.val_dataloader()))
    assert torch.equal(batch["a"], tensor([0, 1]))
    assert torch.equal(batch["b"], tensor([1, 2]))

    model = CustomModel()
    trainer = Trainer(
        max_epochs=1,
        limit_train_batches=2,
        limit_val_batches=1,
        limit_test_batches=2,
        limit_predict_batches=2,
        num_sanity_val_steps=1
    )
    trainer.fit(model, datamodule=datamodule)
    trainer.test(model)
    trainer.predict(model)

    preprocess = model._preprocess
    data_source = preprocess.data_source_of_name("default")
    assert data_source.train_load_data_called
    assert preprocess.train_pre_tensor_transform_called
    assert preprocess.train_collate_called
    assert preprocess.train_per_batch_transform_on_device_called
    assert data_source.val_load_data_called
    assert data_source.val_load_sample_called
    assert preprocess.val_to_tensor_transform_called
    assert preprocess.val_collate_called
    assert preprocess.val_per_batch_transform_on_device_called
    assert data_source.test_load_data_called
    assert preprocess.test_to_tensor_transform_called
    assert preprocess.test_post_tensor_transform_called
    assert data_source.predict_load_data_called
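
# For reference, the per-stage hook order implied by the assertions above
# (flash 0.3-era ``Preprocess`` pipeline; a sketch, not the exact call graph):
#
#   load_data -> load_sample          # produce individual samples
#   pre_tensor_transform              # per sample, before tensor conversion
#   to_tensor_transform               # per sample, convert to tensors
#   post_tensor_transform             # per sample, on tensors
#   collate                           # samples -> batch (in the workers)
#   per_batch_transform               # per batch, on CPU
#   per_batch_transform_on_device     # per batch, after device transfer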
Example #5
def test_datapipeline_transformations_overridden_by_task():
    # define input transforms
    class ImageInput(Input):
        def load_data(self, folder):
            # from folder -> return files paths
            return ["a.jpg", "b.jpg"]

        def load_sample(self, path):
            # from a file path, load the associated image
            return np.random.uniform(0, 1, (64, 64, 3))

    class ImageClassificationInputTransform(InputTransform):
        def per_sample_transform(self) -> Callable:
            return T.Compose([T.ToTensor()])

        def per_batch_transform_on_device(self) -> Callable:
            return T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    class OverrideInputTransform(InputTransform):
        def per_sample_transform(self) -> Callable:
            return T.Compose([T.ToTensor(), T.Resize(128)])

    # define a task which overrides the default transform via its
    # ``input_transform`` attribute
    class CustomModel(Task):
        def __init__(self):
            super().__init__(model=torch.nn.Linear(1, 1),
                             loss_fn=torch.nn.MSELoss())

            # override default transform to resize images
            self.input_transform = OverrideInputTransform

        def training_step(self, batch, batch_idx):
            assert batch.shape == torch.Size([2, 3, 128, 128])
            assert torch.max(batch) <= 1.0
            assert torch.min(batch) >= 0.0

        def validation_step(self, batch, batch_idx):
            assert batch.shape == torch.Size([2, 3, 128, 128])
            assert torch.max(batch) <= 1.0
            assert torch.min(batch) >= 0.0

    transform = ImageClassificationInputTransform()
    datamodule = DataModule(
        ImageInput(RunningStage.TRAINING, [1]),
        ImageInput(RunningStage.VALIDATING, [1]),
        transform=transform,
        batch_size=2,
        num_workers=0,
    )

    # call trainer
    model = CustomModel()
    trainer = Trainer(
        max_epochs=1,
        limit_train_batches=2,
        limit_val_batches=1,
        num_sanity_val_steps=1,
    )
    trainer.fit(model, datamodule=datamodule)
Example #6
def test_transformations(tmpdir):

    transform = TestInputTransform()
    datamodule = DataModule(
        TestInput(RunningStage.TRAINING, [1]),
        TestInput(RunningStage.VALIDATING, [1]),
        TestInput(RunningStage.TESTING, [1]),
        transform=transform,
        batch_size=2,
        num_workers=0,
    )

    assert datamodule.train_dataloader().dataset[0] == (0, 1, 2, 3)
    batch = next(iter(datamodule.train_dataloader()))
    assert torch.equal(batch, torch.tensor([[0, 1, 2, 3, 5], [0, 1, 2, 3, 5]]))

    assert datamodule.val_dataloader().dataset[0] == {"a": 0, "b": 1}
    assert datamodule.val_dataloader().dataset[1] == {"a": 1, "b": 2}
    batch = next(iter(datamodule.val_dataloader()))

    datamodule = DataModule(
        TestInput(RunningStage.TRAINING, [1]),
        TestInput(RunningStage.VALIDATING, [1]),
        TestInput(RunningStage.TESTING, [1]),
        transform=TestInputTransform2,
        batch_size=2,
        num_workers=0,
    )
    batch = next(iter(datamodule.val_dataloader()))
    assert torch.equal(batch["a"], torch.tensor([0, 1]))
    assert torch.equal(batch["b"], torch.tensor([1, 2]))

    model = CustomModel()
    trainer = Trainer(
        max_epochs=1,
        limit_train_batches=2,
        limit_val_batches=1,
        limit_test_batches=2,
        limit_predict_batches=2,
        num_sanity_val_steps=1,
    )
    trainer.fit(model, datamodule=datamodule)
    trainer.test(model, datamodule=datamodule)

    assert datamodule.input_transform.train_per_sample_transform_called
    assert datamodule.input_transform.train_collate_called
    assert datamodule.input_transform.train_per_batch_transform_on_device_called
    assert datamodule.input_transform.val_per_sample_transform_called
    assert datamodule.input_transform.val_collate_called
    assert datamodule.input_transform.val_per_batch_transform_on_device_called
    assert datamodule.input_transform.test_per_sample_transform_called
Example #7
def test_val_split():
    datamodule = DataModule(
        Input(RunningStage.TRAINING, [1] * 100),
        batch_size=2,
        num_workers=0,
        val_split=0.2,
    )

    assert len(datamodule.train_dataset) == 80
    assert len(datamodule.val_dataset) == 20
Example #8
def test_dataloaders_with_sampler(mock_dataloader, sampler, callable):
    train_input = TestInput(RunningStage.TRAINING, [1])
    datamodule = DataModule(
        train_input,
        TestInput(RunningStage.VALIDATING, [1]),
        TestInput(RunningStage.TESTING, [1]),
        batch_size=2,
        num_workers=0,
        sampler=sampler,
    )

    assert datamodule.sampler is sampler
    dl = datamodule.train_dataloader()

    if callable:
        sampler.assert_called_once_with(train_input)

    kwargs = mock_dataloader.call_args[1]
    assert "sampler" in kwargs
    assert kwargs["sampler"] is (sampler.return_value if callable else sampler)
    for dl in [datamodule.val_dataloader(), datamodule.test_dataloader()]:
        kwargs = mock_dataloader.call_args[1]
        assert "sampler" not in kwargs
Example #9
def test_predict_dataset(tmpdir):
    """Tests that we can generate embeddings from a pytorch geometric dataset."""
    tudataset = datasets.TUDataset(root=tmpdir, name="KKI")
    model = GraphEmbedder(
        GraphClassifier(num_features=tudataset.num_features,
                        num_classes=tudataset.num_classes).backbone)
    datamodule = DataModule(
        predict_input=GraphClassificationDatasetInput(RunningStage.PREDICTING,
                                                      tudataset),
        transform=GraphClassificationInputTransform,
        batch_size=4,
    )
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    out = trainer.predict(model, datamodule=datamodule)
    assert isinstance(out[0][0], torch.Tensor)
Example #10
def test_split_dataset(tmpdir):

    train_ds, val_ds = DataModule._split_train_val(range(100), val_split=0.1)
    assert len(train_ds) == 90
    assert len(val_ds) == 10
    assert len(np.unique(train_ds.indices)) == len(train_ds.indices)

    with pytest.raises(MisconfigurationException, match="[0, 99]"):
        SplitDataset(range(100), indices=[100])

    with pytest.raises(MisconfigurationException, match="[0, 49]"):
        SplitDataset(range(50), indices=[-1])

    with pytest.raises(MisconfigurationException, match="[0, 49]"):
        SplitDataset(list(range(50)) + list(range(50)), indices=[-1])

    with pytest.raises(MisconfigurationException, match="[0, 99]"):
        SplitDataset(list(range(50)) + list(range(50)),
                     indices=[-1],
                     use_duplicated_indices=True)

    class Dataset:
        def __init__(self):
            self.data = [0, 1, 2]
            self.name = "something"

        def __getitem__(self, index):
            return self.data[index]

        def __len__(self):
            return len(self.data)

    split_dataset = SplitDataset(Dataset(), indices=[0])
    assert split_dataset.name == "something"

    assert split_dataset._INTERNAL_KEYS == ("dataset", "indices", "data")

    split_dataset.is_passed_down = True
    assert split_dataset.dataset.is_passed_down
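
# A minimal sketch of the attribute-proxy behavior exercised above (not the
# library's implementation): anything outside ``_INTERNAL_KEYS`` is forwarded
# to the wrapped dataset, which is why ``is_passed_down`` reaches it.
class _ProxySketch:
    _INTERNAL_KEYS = ("dataset",)

    def __init__(self, dataset):
        object.__setattr__(self, "dataset", dataset)

    def __getattr__(self, name):
        # only called when normal lookup fails: forward to the dataset
        return getattr(object.__getattribute__(self, "dataset"), name)

    def __setattr__(self, name, value):
        if name in self._INTERNAL_KEYS:
            object.__setattr__(self, name, value)
        else:
            setattr(self.dataset, name, value)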
Example #11
def test_split_dataset():
    train_ds, val_ds = DataModule._split_train_val(range(100), val_split=0.1)
    assert len(train_ds) == 90
    assert len(val_ds) == 10
    assert len(np.unique(train_ds.indices)) == len(train_ds.indices)

    class Dataset:
        def __init__(self):
            self.data = [0, 1, 2]
            self.name = "something"
            self.is_passed_down = False

        def __getitem__(self, index):
            return self.data[index]

        def __len__(self):
            return len(self.data)

    split_dataset = SplitDataset(Dataset(), indices=[0])
    assert split_dataset.name == "something"

    split_dataset.is_passed_down = True
    assert not split_dataset.dataset.is_passed_down
Example #12
def test_not_trainable(tmpdir):
    """Tests that the model gives an error when training, validating, or testing."""
    tudataset = datasets.TUDataset(root=tmpdir, name="KKI")
    model = GraphEmbedder(
        GraphClassifier(num_features=1, num_classes=1).backbone)
    datamodule = DataModule(
        GraphClassificationDatasetInput(RunningStage.TRAINING, tudataset),
        GraphClassificationDatasetInput(RunningStage.VALIDATING, tudataset),
        GraphClassificationDatasetInput(RunningStage.TESTING, tudataset),
        transform=GraphClassificationInputTransform,
        batch_size=4,
    )
    trainer = Trainer(default_root_dir=tmpdir, num_sanity_val_steps=0)
    with pytest.raises(NotImplementedError,
                       match="Training a `GraphEmbedder` is not supported."):
        trainer.fit(model, datamodule=datamodule)

    with pytest.raises(NotImplementedError,
                       match="Validating a `GraphEmbedder` is not supported."):
        trainer.validate(model, datamodule=datamodule)

    with pytest.raises(NotImplementedError,
                       match="Testing a `GraphEmbedder` is not supported."):
        trainer.test(model, datamodule=datamodule)
Example #13

# NOTE: the opening of this example was truncated in the source; the class
# header and the start of ``per_sample_transform`` are reconstructed below so
# the surviving fragment parses (``BaseImageInputTransform`` is assumed to
# define ``image_size``).
@dataclass
class ImageRandomRotationInputTransform(BaseImageInputTransform):
    rotation: float = 45  # assumed default

    def per_sample_transform(self) -> Callable:
        transforms = [
            T.Resize(self.image_size),
            T.ToTensor(),
            T.RandomRotation(self.rotation),
        ]
        return T.Compose(transforms)

    def input_per_sample_transform(self) -> Callable:
        # this will be used to transform only the input value associated with
        # the `input` key within each sample.
        transforms = [T.Resize(self.image_size), T.ToTensor()]
        return T.Compose(transforms)
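
# Illustrative: ``input_per_sample_transform`` is applied only to the value
# under the "input" key of each sample, while ``per_sample_transform`` sees
# the whole sample. A tiny standalone sketch of that key-scoped application:
from PIL import Image

sample = {"input": Image.new("RGB", (96, 96)), "target": 0}
sample["input"] = T.Compose([T.Resize(64), T.ToTensor()])(sample["input"])
assert sample["input"].shape == (3, 64, 64)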


# Register your transform within the InputTransform registry of the Flash DataModule.
# Note: registries can be shared by multiple datasets.
DataModule.register_input_transform("base", BaseImageInputTransform)
DataModule.register_input_transform("random_rotation",
                                    ImageRandomRotationInputTransform)
DataModule.register_input_transform(
    "random_90_def_rotation",
    partial(ImageRandomRotationInputTransform, rotation=90))

#############################################################################################
#                       Step 3 / 3: Create a DataModule (Part 1)                            #
#                                                                                           #
# The `DataModule` class collects the `Input` objects for the various stages,               #
# together with an `InputTransform`; you can pass them directly to its constructor.         #
#                                                                                           #
#############################################################################################

datamodule = DataModule(
    ...  # the remainder of this example was truncated in the source
)
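
# A plausible completion, following the ``DataModule`` pattern from Example #5
# (the ``ImageInput`` class and folder arguments are assumptions; passing a
# registered name as ``transform`` is inferred from the registry set up above):
#
# datamodule = DataModule(
#     ImageInput(RunningStage.TRAINING, "./train_folder"),
#     ImageInput(RunningStage.VALIDATING, "./val_folder"),
#     transform="random_90_def_rotation",
#     batch_size=2,
# )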
Example #14
def test_flash_callback(_, __, tmpdir):
    """Test the callback hook system for fit."""

    callback_mock = mock.MagicMock()

    inputs = [(torch.rand(1), torch.rand(1))]
    transform = InputTransform()
    dm = DataModule(
        DatasetInput(RunningStage.TRAINING, inputs),
        DatasetInput(RunningStage.VALIDATING, inputs),
        DatasetInput(RunningStage.TESTING, inputs),
        transform=transform,
        batch_size=1,
        num_workers=0,
        data_fetcher=callback_mock,
    )

    _ = next(iter(dm.train_dataloader()))

    assert callback_mock.method_calls == [
        mock.call.on_load_sample(mock.ANY, RunningStage.TRAINING),
        mock.call.on_per_sample_transform(mock.ANY, RunningStage.TRAINING),
        mock.call.on_collate(mock.ANY, RunningStage.TRAINING),
        mock.call.on_per_batch_transform(mock.ANY, RunningStage.TRAINING),
    ]

    class CustomModel(Task):
        def __init__(self):
            super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss())

        def training_step(self, batch, batch_idx):
            batch = (batch[DataKeys.INPUT], batch[DataKeys.TARGET])
            return super().training_step(batch, batch_idx)

        def validation_step(self, batch, batch_idx):
            batch = (batch[DataKeys.INPUT], batch[DataKeys.TARGET])
            return super().validation_step(batch, batch_idx)

        def test_step(self, batch, batch_idx):
            batch = (batch[DataKeys.INPUT], batch[DataKeys.TARGET])
            return super().test_step(batch, batch_idx)

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_val_batches=1,
        limit_train_batches=1,
        progress_bar_refresh_rate=0,
    )
    transform = InputTransform()
    dm = DataModule(
        DatasetInput(RunningStage.TRAINING, inputs),
        DatasetInput(RunningStage.VALIDATING, inputs),
        DatasetInput(RunningStage.TESTING, inputs),
        transform=transform,
        batch_size=1,
        num_workers=0,
        data_fetcher=callback_mock,
    )
    trainer.fit(CustomModel(), datamodule=dm)

    assert callback_mock.method_calls == [
        mock.call.on_load_sample(mock.ANY, RunningStage.TRAINING),
        mock.call.on_per_sample_transform(mock.ANY, RunningStage.TRAINING),
        mock.call.on_collate(mock.ANY, RunningStage.TRAINING),
        mock.call.on_per_batch_transform(mock.ANY, RunningStage.TRAINING),
        mock.call.on_load_sample(mock.ANY, RunningStage.VALIDATING),
        mock.call.on_per_sample_transform(mock.ANY, RunningStage.VALIDATING),
        mock.call.on_collate(mock.ANY, RunningStage.VALIDATING),
        mock.call.on_per_batch_transform(mock.ANY, RunningStage.VALIDATING),
        mock.call.on_per_batch_transform_on_device(mock.ANY, RunningStage.VALIDATING),
        mock.call.on_load_sample(mock.ANY, RunningStage.TRAINING),
        mock.call.on_per_sample_transform(mock.ANY, RunningStage.TRAINING),
        mock.call.on_collate(mock.ANY, RunningStage.TRAINING),
        mock.call.on_per_batch_transform(mock.ANY, RunningStage.TRAINING),
        mock.call.on_per_batch_transform_on_device(mock.ANY, RunningStage.TRAINING),
        mock.call.on_load_sample(mock.ANY, RunningStage.VALIDATING),
        mock.call.on_per_sample_transform(mock.ANY, RunningStage.VALIDATING),
        mock.call.on_collate(mock.ANY, RunningStage.VALIDATING),
        mock.call.on_per_batch_transform(mock.ANY, RunningStage.VALIDATING),
        mock.call.on_per_batch_transform_on_device(mock.ANY, RunningStage.VALIDATING),
    ]
Example #15
def test_deepcopy():
    """Tests that deepcopy works with the ``SplitDataset``."""
    dataset = list(range(100))
    train_ds, val_ds = DataModule._split_train_val(dataset, val_split=0.1)
    deepcopy(train_ds)
Example #16

def test_dataset_data_source():

    dm = DataModule.from_datasets(range(10), range(10))
    assert dm.train_dataset.sample == {DefaultDataKeys.INPUT: 0}
Example #17
def test_data_module():
    seed_everything(42)

    def train_fn(data):
        return data - 100

    def val_fn(data):
        return data + 100

    def test_fn(data):
        return data - 1000

    def predict_fn(data):
        return data + 1000

    @dataclass
    class TestTransform(InputTransform):
        def per_sample_transform(self):
            def fn(x):
                return x

            return fn

        def train_per_batch_transform_on_device(self) -> Callable:
            return train_fn

        def val_per_batch_transform_on_device(self) -> Callable:
            return val_fn

        def test_per_batch_transform_on_device(self) -> Callable:
            return test_fn

        def predict_per_batch_transform_on_device(self) -> Callable:
            return predict_fn

    transform = TestTransform()
    assert transform._transform is not None

    train_dataset = Input(RunningStage.TRAINING, np.arange(10,
                                                           dtype=np.float32))
    assert train_dataset.running_stage == RunningStage.TRAINING

    val_dataset = Input(RunningStage.VALIDATING, np.arange(10,
                                                           dtype=np.float32))
    assert val_dataset.running_stage == RunningStage.VALIDATING

    test_dataset = Input(RunningStage.TESTING, np.arange(10, dtype=np.float32))
    assert test_dataset.running_stage == RunningStage.TESTING

    predict_dataset = Input(RunningStage.PREDICTING,
                            np.arange(10, dtype=np.float32))
    assert predict_dataset.running_stage == RunningStage.PREDICTING

    dm = DataModule(
        train_input=train_dataset,
        val_input=val_dataset,
        test_input=test_dataset,
        predict_input=predict_dataset,
        transform=transform,
        batch_size=2,
    )
    assert len(dm.train_dataloader()) == 5
    batch = next(iter(dm.train_dataloader()))
    assert batch.shape == torch.Size([2])
    assert batch.min() >= 0 and batch.max() < 10

    assert len(dm.val_dataloader()) == 5
    batch = next(iter(dm.val_dataloader()))
    assert batch.shape == torch.Size([2])
    assert batch.min() >= 0 and batch.max() < 10

    class TestModel(Task):
        def training_step(self, batch, batch_idx):
            assert sum(batch < 0) == 2

        def validation_step(self, batch, batch_idx):
            assert sum(batch > 0) == 2

        def test_step(self, batch, batch_idx):
            assert sum(batch < 500) == 2

        def predict_step(self, batch, *args, **kwargs):
            assert sum(batch > 500) == 2
            assert torch.equal(batch, torch.tensor([1000.0, 1001.0]))

        def on_train_dataloader(self) -> None:
            pass

        def on_val_dataloader(self) -> None:
            pass

        def on_test_dataloader(self, *_) -> None:
            pass

        def on_predict_dataloader(self) -> None:
            pass

        def on_predict_end(self) -> None:
            pass

        def on_fit_end(self) -> None:
            pass

    model = TestModel(torch.nn.Linear(1, 1))
    trainer = Trainer(fast_dev_run=True)
    trainer.fit(model, datamodule=dm)
    trainer.validate(model, datamodule=dm)
    trainer.test(model, datamodule=dm)
    trainer.predict(model, datamodule=dm)

    # Test that plain lightning module works with FlashDataModule
    class SampleBoringModel(BoringModel):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(2, 1)

    model = SampleBoringModel()
    trainer = Trainer(fast_dev_run=True)
    trainer.fit(model, datamodule=dm)
    trainer.validate(model, datamodule=dm)
    trainer.test(model, datamodule=dm)
    trainer.predict(model, datamodule=dm)

    transform = TestTransform()
    input = Input(RunningStage.TRAINING)
    dm = DataModule(train_input=input, batch_size=1, transform=transform)
    assert isinstance(dm.input_transform, TestTransform)

    class RandomDataset(Dataset):
        def __init__(self, size: int, length: int):
            self.len = length
            self.data = torch.ones(length, size)

        def __getitem__(self, index):
            return self.data[index]

        def __len__(self):
            return self.len

    def _add_hundred(x):
        if isinstance(x, Dict):
            x["input"] += 100
        else:
            x += 100
        return x

    class TrainInputTransform(InputTransform):
        def _add_one(self, x):
            if isinstance(x, Dict):
                x["input"] += 1
            else:
                x += 1
            return x

        def per_sample_transform(self) -> Callable:
            return self._add_one

        def val_per_sample_transform(self) -> Callable:
            return _add_hundred

    dm = DataModule(
        train_input=DatasetInput(RunningStage.TRAINING, RandomDataset(64, 32)),
        val_input=DatasetInput(RunningStage.VALIDATING, RandomDataset(64, 32)),
        test_input=DatasetInput(RunningStage.TESTING, RandomDataset(64, 32)),
        batch_size=3,
        transform=TrainInputTransform(),
    )
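    # ``per_sample_transform`` (adds 1) applies to train and, by fallback, to
    # test; ``val_per_sample_transform`` (adds 100) overrides it for
    # validation only. Hence the 2 / 101 / 2 values checked below.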
    batch = next(iter(dm.train_dataloader()))
    assert batch["input"][0][0] == 2
    batch = next(iter(dm.val_dataloader()))
    assert batch["input"][0][0] == 101
    batch = next(iter(dm.test_dataloader()))
    assert batch["input"][0][0] == 2
Example #18
def test_flash_callback(_, __, tmpdir):
    """Test the callback hook system for fit."""

    callback_mock = MagicMock()

    inputs = [[torch.rand(1), torch.rand(1)]]
    dm = DataModule.from_data_source("default",
                                     inputs,
                                     inputs,
                                     inputs,
                                     None,
                                     preprocess=DefaultPreprocess(),
                                     batch_size=1,
                                     num_workers=0)
    dm.preprocess.callbacks += [callback_mock]

    _ = next(iter(dm.train_dataloader()))

    assert callback_mock.method_calls == [
        call.on_load_sample(ANY, RunningStage.TRAINING),
        call.on_pre_tensor_transform(ANY, RunningStage.TRAINING),
        call.on_to_tensor_transform(ANY, RunningStage.TRAINING),
        call.on_post_tensor_transform(ANY, RunningStage.TRAINING),
        call.on_collate(ANY, RunningStage.TRAINING),
        call.on_per_batch_transform(ANY, RunningStage.TRAINING),
    ]

    class CustomModel(Task):
        def __init__(self):
            super().__init__(model=torch.nn.Linear(1, 1),
                             loss_fn=torch.nn.MSELoss())

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_val_batches=1,
        limit_train_batches=1,
        progress_bar_refresh_rate=0,
    )
    dm = DataModule.from_data_source("default",
                                     inputs,
                                     inputs,
                                     inputs,
                                     None,
                                     preprocess=DefaultPreprocess(),
                                     batch_size=1,
                                     num_workers=0)
    dm.preprocess.callbacks += [callback_mock]
    trainer.fit(CustomModel(), datamodule=dm)

    assert callback_mock.method_calls == [
        call.on_load_sample(ANY, RunningStage.TRAINING),
        call.on_pre_tensor_transform(ANY, RunningStage.TRAINING),
        call.on_to_tensor_transform(ANY, RunningStage.TRAINING),
        call.on_post_tensor_transform(ANY, RunningStage.TRAINING),
        call.on_collate(ANY, RunningStage.TRAINING),
        call.on_per_batch_transform(ANY, RunningStage.TRAINING),
        call.on_load_sample(ANY, RunningStage.VALIDATING),
        call.on_pre_tensor_transform(ANY, RunningStage.VALIDATING),
        call.on_to_tensor_transform(ANY, RunningStage.VALIDATING),
        call.on_post_tensor_transform(ANY, RunningStage.VALIDATING),
        call.on_collate(ANY, RunningStage.VALIDATING),
        call.on_per_batch_transform(ANY, RunningStage.VALIDATING),
        call.on_per_batch_transform_on_device(ANY, RunningStage.VALIDATING),
        call.on_load_sample(ANY, RunningStage.TRAINING),
        call.on_pre_tensor_transform(ANY, RunningStage.TRAINING),
        call.on_to_tensor_transform(ANY, RunningStage.TRAINING),
        call.on_post_tensor_transform(ANY, RunningStage.TRAINING),
        call.on_collate(ANY, RunningStage.TRAINING),
        call.on_per_batch_transform(ANY, RunningStage.TRAINING),
        call.on_per_batch_transform_on_device(ANY, RunningStage.TRAINING),
        call.on_load_sample(ANY, RunningStage.VALIDATING),
        call.on_pre_tensor_transform(ANY, RunningStage.VALIDATING),
        call.on_to_tensor_transform(ANY, RunningStage.VALIDATING),
        call.on_post_tensor_transform(ANY, RunningStage.VALIDATING),
        call.on_collate(ANY, RunningStage.VALIDATING),
        call.on_per_batch_transform(ANY, RunningStage.VALIDATING),
        call.on_per_batch_transform_on_device(ANY, RunningStage.VALIDATING),
    ]