Example No. 1
class CustomDataModule(DataModule):

    @classmethod
    def from_inputs(cls, train_data: Any, val_data: Any, test_data: Any, predict_data: Any) -> "CustomDataModule":

        preprocess = DefaultPreprocess()

        return cls.from_data_source(
            "default",
            train_data=train_data,
            val_data=val_data,
            test_data=test_data,
            predict_data=predict_data,
            preprocess=preprocess,
            batch_size=5,
        )
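
A minimal usage sketch for the class above; the random-tensor inputs are illustrative placeholders, not part of the original example:

import torch

# Any list of samples understood by the "default" data source works here.
data = [torch.rand(1), torch.rand(1)]
datamodule = CustomDataModule.from_inputs(data, data, data, data)
batch = next(iter(datamodule.train_dataloader()))  # samples flow through DefaultPreprocess
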
Example No. 2
def test_classificationtask_task_predict():
    model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.Softmax())
    task = ClassificationTask(model, preprocess=DefaultPreprocess())
    ds = DummyDataset()
    expected = list(range(10))
    # single item
    x0, _ = ds[0]
    pred0 = task.predict(x0)
    assert pred0[0] in expected
    # list
    x1, _ = ds[1]
    pred1 = task.predict([x0, x1])
    assert all(c in expected for c in pred1)
    assert pred0[0] == pred1[0]
Example No. 3
    # DataPipeline constructor: each pipeline component falls back to its default
    # (for example DefaultPreprocess) when it is not supplied.
    def __init__(
        self,
        data_source: Optional[DataSource] = None,
        preprocess: Optional[Preprocess] = None,
        postprocess: Optional[Postprocess] = None,
        deserializer: Optional[Deserializer] = None,
        serializer: Optional[Serializer] = None,
    ) -> None:
        self.data_source = data_source

        self._preprocess_pipeline = preprocess or DefaultPreprocess()
        self._postprocess_pipeline = postprocess or Postprocess()
        self._serializer = serializer or Serializer()
        self._deserializer = deserializer or Deserializer()
        self._running_stage = None
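
Because each argument above falls back to a default, constructing the pipeline with no arguments still yields a working DefaultPreprocess. A minimal sketch, assuming the import paths of the flash version these examples come from:

from flash.core.data.data_pipeline import DataPipeline
from flash.core.data.process import DefaultPreprocess

pipeline = DataPipeline()  # nothing supplied
assert isinstance(pipeline._preprocess_pipeline, DefaultPreprocess)  # preprocess or DefaultPreprocess()
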
Example No. 4
def test_multicrop_input_transform():
    batch_size = 8
    total_crops = 6
    num_crops = [2, 4]
    size_crops = [160, 96]
    crop_scales = [[0.4, 1], [0.05, 0.4]]

    multi_crop_transform = TRANSFORM_REGISTRY["multicrop_ssl_transform"](
        total_crops, num_crops, size_crops, crop_scales)

    to_tensor_transform = ApplyToKeys(
        DefaultDataKeys.INPUT,
        multi_crop_transform,
    )
    preprocess = DefaultPreprocess(train_transform={
        "to_tensor_transform": to_tensor_transform,
        "collate": vissl_collate_fn,
    })

    datamodule = ImageClassificationData.from_datasets(
        train_dataset=FakeData(),
        preprocess=preprocess,
        batch_size=batch_size,
    )

    train_dataloader = datamodule._train_dataloader()
    batch = next(iter(train_dataloader))

    assert len(batch[DefaultDataKeys.INPUT]) == total_crops
    assert batch[DefaultDataKeys.INPUT][0].shape == (batch_size, 3,
                                                     size_crops[0],
                                                     size_crops[0])
    assert batch[DefaultDataKeys.INPUT][-1].shape == (batch_size, 3,
                                                      size_crops[-1],
                                                      size_crops[-1])
    assert list(batch[DefaultDataKeys.TARGET].shape) == [batch_size]
Example No. 5
def test_preprocess_transforms(tmpdir):
    """
    This test makes sure that when a preprocess is given its transforms as dictionaries,
    the dictionaries are validated properly and collate_in_worker_from_transform is correctly extracted.
    """

    with pytest.raises(MisconfigurationException,
                       match="Transform should be a dict."):
        DefaultPreprocess(train_transform="choco")

    with pytest.raises(MisconfigurationException,
                       match="train_transform contains {'choco'}. Only"):
        DefaultPreprocess(train_transform={"choco": None})

    preprocess = DefaultPreprocess(
        train_transform={"to_tensor_transform": torch.nn.Linear(1, 1)})
    # stages that were not given a transform keep None
    assert preprocess._train_collate_in_worker_from_transform is True
    assert preprocess._val_collate_in_worker_from_transform is None
    assert preprocess._test_collate_in_worker_from_transform is None
    assert preprocess._predict_collate_in_worker_from_transform is None

    with pytest.raises(
            MisconfigurationException,
            match="`per_batch_transform` and `per_sample_transform_on_device`"
    ):
        preprocess = DefaultPreprocess(
            train_transform={
                "per_batch_transform": torch.nn.Linear(1, 1),
                "per_sample_transform_on_device": torch.nn.Linear(1, 1)
            })

    preprocess = DefaultPreprocess(
        train_transform={"per_batch_transform": torch.nn.Linear(1, 1)},
        predict_transform={
            "per_sample_transform_on_device": torch.nn.Linear(1, 1)
        })
    # stages that were not given a transform keep None
    assert preprocess._train_collate_in_worker_from_transform is True
    assert preprocess._val_collate_in_worker_from_transform is None
    assert preprocess._test_collate_in_worker_from_transform is None
    assert preprocess._predict_collate_in_worker_from_transform is False

    train_preprocessor = DataPipeline(
        preprocess=preprocess).worker_preprocessor(RunningStage.TRAINING)
    val_preprocessor = DataPipeline(preprocess=preprocess).worker_preprocessor(
        RunningStage.VALIDATING)
    test_preprocessor = DataPipeline(
        preprocess=preprocess).worker_preprocessor(RunningStage.TESTING)
    predict_preprocessor = DataPipeline(
        preprocess=preprocess).worker_preprocessor(RunningStage.PREDICTING)

    assert train_preprocessor.collate_fn.func == preprocess.collate
    assert val_preprocessor.collate_fn.func == preprocess.collate
    assert test_preprocessor.collate_fn.func == preprocess.collate
    assert predict_preprocessor.collate_fn.func == DataPipeline._identity

    class CustomPreprocess(DefaultPreprocess):
        def per_sample_transform_on_device(self, sample: Any) -> Any:
            return super().per_sample_transform_on_device(sample)

        def per_batch_transform(self, batch: Any) -> Any:
            return super().per_batch_transform(batch)

    preprocess = CustomPreprocess(
        train_transform={"per_batch_transform": torch.nn.Linear(1, 1)},
        predict_transform={
            "per_sample_transform_on_device": torch.nn.Linear(1, 1)
        })
    # stages that were not given a transform keep None
    assert preprocess._train_collate_in_worker_from_transform is True
    assert preprocess._val_collate_in_worker_from_transform is None
    assert preprocess._test_collate_in_worker_from_transform is None
    assert preprocess._predict_collate_in_worker_from_transform is False

    data_pipeline = DataPipeline(preprocess=preprocess)

    train_preprocessor = data_pipeline.worker_preprocessor(
        RunningStage.TRAINING)
    with pytest.raises(
            MisconfigurationException,
            match="`per_batch_transform` and `per_sample_transform_on_device`"
    ):
        val_preprocessor = data_pipeline.worker_preprocessor(
            RunningStage.VALIDATING)
    with pytest.raises(
            MisconfigurationException,
            match="`per_batch_transform` and `per_sample_transform_on_device`"
    ):
        test_preprocessor = data_pipeline.worker_preprocessor(
            RunningStage.TESTING)
    predict_preprocessor = data_pipeline.worker_preprocessor(
        RunningStage.PREDICTING)

    assert train_preprocessor.collate_fn.func == preprocess.collate
    assert predict_preprocessor.collate_fn.func == DataPipeline._identity
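
The rule these assertions pin down: a stage whose transform dictionary only uses worker-side hooks keeps the collate in the DataLoader worker, a stage that uses per_sample_transform_on_device defers the collate to the device, and declaring both hooks for the same stage raises a MisconfigurationException. A minimal standalone sketch of that rule; the import path is an assumption about the flash version targeted by these tests:

import torch
from flash.core.data.process import DefaultPreprocess

# Worker-side hook only: the collate stays in the DataLoader worker.
worker_side = DefaultPreprocess(train_transform={"per_batch_transform": torch.nn.Linear(1, 1)})
assert worker_side._train_collate_in_worker_from_transform is True

# Device-side hook only: the collate is deferred until the batch reaches the device.
device_side = DefaultPreprocess(predict_transform={"per_sample_transform_on_device": torch.nn.Linear(1, 1)})
assert device_side._predict_collate_in_worker_from_transform is False
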
Example No. 6
def test_flash_callback(_, __, tmpdir):
    """Test the callback hook system for fit."""
    # The two leading underscore parameters receive objects injected by mock.patch
    # decorators in the original test module (omitted from this snippet).

    callback_mock = MagicMock()

    inputs = [[torch.rand(1), torch.rand(1)]]
    dm = DataModule.from_data_source("default",
                                     inputs,
                                     inputs,
                                     inputs,
                                     None,
                                     preprocess=DefaultPreprocess(),
                                     batch_size=1,
                                     num_workers=0)
    dm.preprocess.callbacks += [callback_mock]

    _ = next(iter(dm.train_dataloader()))

    assert callback_mock.method_calls == [
        call.on_load_sample(ANY, RunningStage.TRAINING),
        call.on_pre_tensor_transform(ANY, RunningStage.TRAINING),
        call.on_to_tensor_transform(ANY, RunningStage.TRAINING),
        call.on_post_tensor_transform(ANY, RunningStage.TRAINING),
        call.on_collate(ANY, RunningStage.TRAINING),
        call.on_per_batch_transform(ANY, RunningStage.TRAINING),
    ]

    class CustomModel(Task):
        def __init__(self):
            super().__init__(model=torch.nn.Linear(1, 1),
                             loss_fn=torch.nn.MSELoss())

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_val_batches=1,
        limit_train_batches=1,
        progress_bar_refresh_rate=0,
    )
    dm = DataModule.from_data_source("default",
                                     inputs,
                                     inputs,
                                     inputs,
                                     None,
                                     preprocess=DefaultPreprocess(),
                                     batch_size=1,
                                     num_workers=0)
    dm.preprocess.callbacks += [callback_mock]
    trainer.fit(CustomModel(), datamodule=dm)

    assert callback_mock.method_calls == [
        call.on_load_sample(ANY, RunningStage.TRAINING),
        call.on_pre_tensor_transform(ANY, RunningStage.TRAINING),
        call.on_to_tensor_transform(ANY, RunningStage.TRAINING),
        call.on_post_tensor_transform(ANY, RunningStage.TRAINING),
        call.on_collate(ANY, RunningStage.TRAINING),
        call.on_per_batch_transform(ANY, RunningStage.TRAINING),
        call.on_load_sample(ANY, RunningStage.VALIDATING),
        call.on_pre_tensor_transform(ANY, RunningStage.VALIDATING),
        call.on_to_tensor_transform(ANY, RunningStage.VALIDATING),
        call.on_post_tensor_transform(ANY, RunningStage.VALIDATING),
        call.on_collate(ANY, RunningStage.VALIDATING),
        call.on_per_batch_transform(ANY, RunningStage.VALIDATING),
        call.on_per_batch_transform_on_device(ANY, RunningStage.VALIDATING),
        call.on_load_sample(ANY, RunningStage.TRAINING),
        call.on_pre_tensor_transform(ANY, RunningStage.TRAINING),
        call.on_to_tensor_transform(ANY, RunningStage.TRAINING),
        call.on_post_tensor_transform(ANY, RunningStage.TRAINING),
        call.on_collate(ANY, RunningStage.TRAINING),
        call.on_per_batch_transform(ANY, RunningStage.TRAINING),
        call.on_per_batch_transform_on_device(ANY, RunningStage.TRAINING),
        call.on_load_sample(ANY, RunningStage.VALIDATING),
        call.on_pre_tensor_transform(ANY, RunningStage.VALIDATING),
        call.on_to_tensor_transform(ANY, RunningStage.VALIDATING),
        call.on_post_tensor_transform(ANY, RunningStage.VALIDATING),
        call.on_collate(ANY, RunningStage.VALIDATING),
        call.on_per_batch_transform(ANY, RunningStage.VALIDATING),
        call.on_per_batch_transform_on_device(ANY, RunningStage.VALIDATING),
    ]
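
The MagicMock above stands in for any object that implements these hooks. A hypothetical sketch of a real callback; the FlashCallback base class and its import path are assumptions about the flash version these tests target:

from flash.core.data.callback import FlashCallback

class PrintStageCallback(FlashCallback):
    # Called once per sample as it is loaded; running_stage identifies the active loop.
    def on_load_sample(self, sample, running_stage):
        print(f"loaded one sample during {running_stage}")

# Registered exactly like the mock above:
# dm.preprocess.callbacks += [PrintStageCallback()]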