Example #1
    def __init__(
        self,
        preprocess: Optional[Preprocess] = None,
        postprocess: Optional[Postprocess] = None,
        serializer: Optional[Serializer] = None,
    ) -> None:
        # Fall back to the default implementations when no components are provided.
        self._preprocess_pipeline = preprocess or DefaultPreprocess()
        self._postprocess_pipeline = postprocess or Postprocess()

        self._serializer = serializer or Serializer()

        self._running_stage = None
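
A minimal usage sketch, assuming this ``__init__`` comes from Flash's ``DataPipeline`` (Example #6 below constructs it the same way); the import paths are assumptions, and any component left as ``None`` falls back to its default implementation:

from flash.data.data_pipeline import DataPipeline  # assumed import path
from flash.data.process import DefaultPreprocess, Postprocess  # assumed import path

# Only the preprocess is passed explicitly; postprocess and serializer fall back to defaults.
pipeline = DataPipeline(preprocess=DefaultPreprocess())
assert isinstance(pipeline._postprocess_pipeline, Postprocess)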
Example #2
        @classmethod
        def from_inputs(cls, train_data: Any, val_data: Any, test_data: Any, predict_data: Any) -> "CustomDataModule":

            preprocess = DefaultPreprocess()

            return cls.from_load_data_inputs(
                train_load_data_input=train_data,
                val_load_data_input=val_data,
                test_load_data_input=test_data,
                predict_load_data_input=predict_data,
                preprocess=preprocess,
                batch_size=5
            )
Example #3
        @classmethod
        def from_inputs(cls, train_data: Any, val_data: Any, test_data: Any, predict_data: Any) -> "CustomDataModule":

            preprocess = DefaultPreprocess()

            return cls.from_data_source(
                "default",
                train_data=train_data,
                val_data=val_data,
                test_data=test_data,
                predict_data=predict_data,
                preprocess=preprocess,
                batch_size=5,
            )
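
A hypothetical call to either ``from_inputs`` classmethod above, assuming ``CustomDataModule`` subclasses Flash's ``DataModule``; the random tensor pairs are placeholder data:

import torch

# Placeholder data: a list containing a single (input, target) pair of random tensors.
data = [[torch.rand(1), torch.rand(1)]]
dm = CustomDataModule.from_inputs(data, data, data, data)
batch = next(iter(dm.train_dataloader()))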
Example #4
def test_classificationtask_task_predict():
    model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.Softmax())
    task = ClassificationTask(model, preprocess=DefaultPreprocess())
    ds = DummyDataset()
    expected = list(range(10))
    # single item
    x0, _ = ds[0]
    pred0 = task.predict(x0)
    assert pred0[0] in expected
    # list
    x1, _ = ds[1]
    pred1 = task.predict([x0, x1])
    assert all(c in expected for c in pred1)
    assert pred0[0] == pred1[0]
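
A sketch of a ``DummyDataset`` that would satisfy this test, assuming it yields ``(input, target)`` pairs shaped for the ``28 * 28`` input layer (the real fixture may differ):

import torch
from torch.utils.data import Dataset


class DummyDataset(Dataset):

    def __getitem__(self, index):
        # A random "image" that nn.Flatten() turns into 28 * 28 features, plus an integer class label.
        return torch.rand(1, 28, 28), torch.randint(10, size=(1,)).item()

    def __len__(self):
        return 9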
Example #5
def test_flash_callback(_, tmpdir):
    """Test the callback hook system for fit."""
    # The leading ``_`` argument absorbs whatever a ``mock.patch`` decorator
    # (not shown in this snippet) injects into the test.

    callback_mock = MagicMock()

    inputs = [[torch.rand(1), torch.rand(1)]]
    dm = DataModule.from_data_source("default",
                                     inputs,
                                     inputs,
                                     inputs,
                                     None,
                                     preprocess=DefaultPreprocess(),
                                     batch_size=1,
                                     num_workers=0)
    dm.preprocess.callbacks += [callback_mock]

    _ = next(iter(dm.train_dataloader()))

    assert callback_mock.method_calls == [
        call.on_load_sample(ANY, RunningStage.TRAINING),
        call.on_pre_tensor_transform(ANY, RunningStage.TRAINING),
        call.on_to_tensor_transform(ANY, RunningStage.TRAINING),
        call.on_post_tensor_transform(ANY, RunningStage.TRAINING),
        call.on_collate(ANY, RunningStage.TRAINING),
        call.on_per_batch_transform(ANY, RunningStage.TRAINING),
    ]

    class CustomModel(Task):
        def __init__(self):
            super().__init__(model=torch.nn.Linear(1, 1),
                             loss_fn=torch.nn.MSELoss())

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_val_batches=1,
        limit_train_batches=1,
        progress_bar_refresh_rate=0,
    )
    dm = DataModule.from_data_source("default",
                                     inputs,
                                     inputs,
                                     inputs,
                                     None,
                                     preprocess=DefaultPreprocess(),
                                     batch_size=1,
                                     num_workers=0)
    dm.preprocess.callbacks += [callback_mock]
    trainer.fit(CustomModel(), datamodule=dm)

    assert callback_mock.method_calls == [
        call.on_load_sample(ANY, RunningStage.TRAINING),
        call.on_pre_tensor_transform(ANY, RunningStage.TRAINING),
        call.on_to_tensor_transform(ANY, RunningStage.TRAINING),
        call.on_post_tensor_transform(ANY, RunningStage.TRAINING),
        call.on_collate(ANY, RunningStage.TRAINING),
        call.on_per_batch_transform(ANY, RunningStage.TRAINING),
        call.on_load_sample(ANY, RunningStage.VALIDATING),
        call.on_pre_tensor_transform(ANY, RunningStage.VALIDATING),
        call.on_to_tensor_transform(ANY, RunningStage.VALIDATING),
        call.on_post_tensor_transform(ANY, RunningStage.VALIDATING),
        call.on_collate(ANY, RunningStage.VALIDATING),
        call.on_per_batch_transform(ANY, RunningStage.VALIDATING),
        call.on_per_batch_transform_on_device(ANY, RunningStage.VALIDATING),
        call.on_load_sample(ANY, RunningStage.TRAINING),
        call.on_pre_tensor_transform(ANY, RunningStage.TRAINING),
        call.on_to_tensor_transform(ANY, RunningStage.TRAINING),
        call.on_post_tensor_transform(ANY, RunningStage.TRAINING),
        call.on_collate(ANY, RunningStage.TRAINING),
        call.on_per_batch_transform(ANY, RunningStage.TRAINING),
        call.on_per_batch_transform_on_device(ANY, RunningStage.TRAINING),
        call.on_load_sample(ANY, RunningStage.VALIDATING),
        call.on_pre_tensor_transform(ANY, RunningStage.VALIDATING),
        call.on_to_tensor_transform(ANY, RunningStage.VALIDATING),
        call.on_post_tensor_transform(ANY, RunningStage.VALIDATING),
        call.on_collate(ANY, RunningStage.VALIDATING),
        call.on_per_batch_transform(ANY, RunningStage.VALIDATING),
        call.on_per_batch_transform_on_device(ANY, RunningStage.VALIDATING)
    ]
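
The hook sequence above is asserted with ``unittest.mock``; as a side note, members of ``method_calls`` are ``call`` objects that unpack to ``(name, args, kwargs)``, which makes it easy to check just the hook names, as in this small standalone sketch:

from unittest.mock import MagicMock

m = MagicMock()
m.on_load_sample("sample", "train")
# Each recorded method call unpacks into its name, positional args, and keyword args.
hook_names = [name for name, args, kwargs in m.method_calls]
assert hook_names == ["on_load_sample"]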
Example #6
def test_preprocess_transforms(tmpdir):
    """
    This test makes sure that when a preprocess is being provided transforms as dictionaries,
    checking is done properly, and collate_in_worker_from_transform is properly extracted.
    """

    with pytest.raises(MisconfigurationException,
                       match="Transform should be a dict."):
        DefaultPreprocess(train_transform="choco")

    with pytest.raises(MisconfigurationException,
                       match="train_transform contains {'choco'}. Only"):
        DefaultPreprocess(train_transform={"choco": None})

    preprocess = DefaultPreprocess(
        train_transform={"to_tensor_transform": torch.nn.Linear(1, 1)})
    # keep is None
    assert preprocess._train_collate_in_worker_from_transform is True
    assert preprocess._val_collate_in_worker_from_transform is None
    assert preprocess._test_collate_in_worker_from_transform is None
    assert preprocess._predict_collate_in_worker_from_transform is None

    with pytest.raises(
            MisconfigurationException,
            match="`per_batch_transform` and `per_sample_transform_on_device`"
    ):
        preprocess = DefaultPreprocess(
            train_transform={
                "per_batch_transform": torch.nn.Linear(1, 1),
                "per_sample_transform_on_device": torch.nn.Linear(1, 1)
            })

    preprocess = DefaultPreprocess(
        train_transform={"per_batch_transform": torch.nn.Linear(1, 1)},
        predict_transform={
            "per_sample_transform_on_device": torch.nn.Linear(1, 1)
        })
    # keep is None
    assert preprocess._train_collate_in_worker_from_transform is True
    assert preprocess._val_collate_in_worker_from_transform is None
    assert preprocess._test_collate_in_worker_from_transform is None
    assert preprocess._predict_collate_in_worker_from_transform is False

    train_preprocessor = DataPipeline(
        preprocess=preprocess).worker_preprocessor(RunningStage.TRAINING)
    val_preprocessor = DataPipeline(preprocess=preprocess).worker_preprocessor(
        RunningStage.VALIDATING)
    test_preprocessor = DataPipeline(
        preprocess=preprocess).worker_preprocessor(RunningStage.TESTING)
    predict_preprocessor = DataPipeline(
        preprocess=preprocess).worker_preprocessor(RunningStage.PREDICTING)

    assert train_preprocessor.collate_fn.func == default_collate
    assert val_preprocessor.collate_fn.func == default_collate
    assert test_preprocessor.collate_fn.func == default_collate
    assert predict_preprocessor.collate_fn.func == DataPipeline._identity

    class CustomPreprocess(DefaultPreprocess):
        def per_sample_transform_on_device(self, sample: Any) -> Any:
            return super().per_sample_transform_on_device(sample)

        def per_batch_transform(self, batch: Any) -> Any:
            return super().per_batch_transform(batch)

    preprocess = CustomPreprocess(
        train_transform={"per_batch_transform": torch.nn.Linear(1, 1)},
        predict_transform={
            "per_sample_transform_on_device": torch.nn.Linear(1, 1)
        })
    # keep is None
    assert preprocess._train_collate_in_worker_from_transform is True
    assert preprocess._val_collate_in_worker_from_transform is None
    assert preprocess._test_collate_in_worker_from_transform is None
    assert preprocess._predict_collate_in_worker_from_transform is False

    data_pipeline = DataPipeline(preprocess=preprocess)

    train_preprocessor = data_pipeline.worker_preprocessor(
        RunningStage.TRAINING)
    with pytest.raises(
            MisconfigurationException,
            match="`per_batch_transform` and `per_sample_transform_on_device`"
    ):
        val_preprocessor = data_pipeline.worker_preprocessor(
            RunningStage.VALIDATING)
    with pytest.raises(
            MisconfigurationException,
            match="`per_batch_transform` and `per_sample_transform_on_device`"
    ):
        test_preprocessor = data_pipeline.worker_preprocessor(
            RunningStage.TESTING)
    predict_preprocessor = data_pipeline.worker_preprocessor(
        RunningStage.PREDICTING)

    assert train_preprocessor.collate_fn.func == default_collate
    assert predict_preprocessor.collate_fn.func == DataPipeline._identity
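
For reference, a hypothetical well-formed configuration that combines only transform keys already seen in this test (other keys may exist in the real API but are not assumed here; the import path is also an assumption):

import torch
from flash.data.process import DefaultPreprocess  # assumed import path

preprocess = DefaultPreprocess(
    train_transform={
        # Both hooks run in the worker, so collation stays in the worker as well.
        "to_tensor_transform": torch.nn.Identity(),
        "per_batch_transform": torch.nn.Identity(),
    },
)
assert preprocess._train_collate_in_worker_from_transform is True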