def test_data_pipeline_predict_worker_preprocessor_and_device_preprocessor():
    """Building worker preprocessors for ``CustomPreprocess`` is only allowed for some stages.

    Training and predicting succeed; validating and testing must raise a
    ``MisconfigurationException`` whose message contains "are mutual exclusive".
    """
    pipeline = DataPipeline(CustomPreprocess())

    pipeline.worker_preprocessor(RunningStage.TRAINING)
    # Stages are visited in the same order as TRAINING -> VALIDATING -> TESTING -> PREDICTING.
    for rejected_stage in (RunningStage.VALIDATING, RunningStage.TESTING):
        with pytest.raises(MisconfigurationException, match="are mutual exclusive"):
            pipeline.worker_preprocessor(rejected_stage)
    pipeline.worker_preprocessor(RunningStage.PREDICTING)


def test_data_pipeline_is_overriden_and_resolve_function_hierarchy(tmpdir):
    """Check stage-aware hook resolution performed by ``DataPipeline``.

    ``CustomPreprocess`` overrides a mix of unprefixed hooks and hooks prefixed with
    ``train_`` / ``val_`` / ``test_`` / ``predict_``. For each running stage the test checks
    which method name ``_resolve_function_hierarchy`` picks for every hook. It then checks that
    the worker preprocessor built for that stage is wired to the matching bound methods.
    ``tmpdir`` is accepted but not used.
    """

    class CustomPreprocess(Preprocess):

        def load_data(self, *_, **__):
            pass

        def test_load_data(self, *_, **__):
            pass

        def predict_load_data(self, *_, **__):
            pass

        def predict_load_sample(self, *_, **__):
            pass

        def val_load_sample(self, *_, **__):
            pass

        def val_pre_tensor_transform(self, *_, **__):
            pass

        def predict_to_tensor_transform(self, *_, **__):
            pass

        def train_post_tensor_transform(self, *_, **__):
            pass

        def test_collate(self, *_, **__):
            pass

        def val_per_sample_transform_on_device(self, *_, **__):
            pass

        def train_per_batch_transform_on_device(self, *_, **__):
            pass

        def test_per_batch_transform_on_device(self, *_, **__):
            pass

    preprocess = CustomPreprocess()
    data_pipeline = DataPipeline(preprocess)
    stages = (
        RunningStage.TRAINING,
        RunningStage.VALIDATING,
        RunningStage.TESTING,
        RunningStage.PREDICTING,
    )

    # stage -> {hook name from PREPROCESS_FUNCS -> resolved method name}
    resolved = {
        stage: {
            hook: data_pipeline._resolve_function_hierarchy(
                hook, data_pipeline._preprocess_pipeline, stage, Preprocess
            )
            for hook in data_pipeline.PREPROCESS_FUNCS
        }
        for stage in stages
    }

    # Expected resolved name per hook, ordered (train, val, test, predict) like ``stages``.
    expected_names = {
        "load_data": ("load_data", "load_data", "test_load_data", "predict_load_data"),
        "load_sample": ("load_sample", "val_load_sample", "load_sample", "predict_load_sample"),
        "pre_tensor_transform": (
            "pre_tensor_transform",
            "val_pre_tensor_transform",
            "pre_tensor_transform",
            "pre_tensor_transform",
        ),
        "to_tensor_transform": (
            "to_tensor_transform",
            "to_tensor_transform",
            "to_tensor_transform",
            "predict_to_tensor_transform",
        ),
        "post_tensor_transform": (
            "train_post_tensor_transform",
            "post_tensor_transform",
            "post_tensor_transform",
            "post_tensor_transform",
        ),
        "collate": ("collate", "collate", "test_collate", "collate"),
        "per_sample_transform_on_device": (
            "per_sample_transform_on_device",
            "val_per_sample_transform_on_device",
            "per_sample_transform_on_device",
            "per_sample_transform_on_device",
        ),
        "per_batch_transform_on_device": (
            "train_per_batch_transform_on_device",
            "per_batch_transform_on_device",
            "test_per_batch_transform_on_device",
            "per_batch_transform_on_device",
        ),
    }
    for hook, names_per_stage in expected_names.items():
        for stage, expected in zip(stages, names_per_stage):
            assert resolved[stage][hook] == expected, (hook, stage)

    # Built in the order TRAINING, VALIDATING, TESTING, PREDICTING.
    workers = {stage: data_pipeline.worker_preprocessor(stage) for stage in stages}

    # stage -> (pre_tensor_transform, to_tensor_transform, post_tensor_transform, collate_fn)
    # NOTE(review): the validating worker's collate is ``_identity``, presumably because
    # ``val_per_sample_transform_on_device`` is overridden — confirm against DataPipeline.
    expected_wiring = {
        RunningStage.TRAINING: (
            preprocess.pre_tensor_transform,
            preprocess.to_tensor_transform,
            preprocess.train_post_tensor_transform,
            default_collate,
        ),
        RunningStage.VALIDATING: (
            preprocess.val_pre_tensor_transform,
            preprocess.to_tensor_transform,
            preprocess.post_tensor_transform,
            data_pipeline._identity,
        ),
        RunningStage.TESTING: (
            preprocess.pre_tensor_transform,
            preprocess.to_tensor_transform,
            preprocess.post_tensor_transform,
            preprocess.test_collate,
        ),
        RunningStage.PREDICTING: (
            preprocess.pre_tensor_transform,
            preprocess.predict_to_tensor_transform,
            preprocess.post_tensor_transform,
            default_collate,
        ),
    }
    for stage, (pre_fn, to_fn, post_fn, collate_fn) in expected_wiring.items():
        worker = workers[stage]
        per_sample = worker.per_sample_transform
        assert per_sample.pre_tensor_transform.func == pre_fn
        assert per_sample.to_tensor_transform.func == to_fn
        assert per_sample.post_tensor_transform.func == post_fn
        assert worker.collate_fn.func == collate_fn
        # No stage overrides per_batch_transform, so every worker uses the base hook.
        assert worker.per_batch_transform.func == preprocess.per_batch_transform
# Example No. 3
# 0
def test_preprocess_transforms(tmpdir):
    """
    This test makes sure that when a preprocess is given its transforms as dictionaries,
    those dictionaries are validated properly and ``collate_in_worker_from_transform`` is
    extracted correctly for every stage.

    ``tmpdir`` is accepted but not used.
    """

    # ``match`` is a regular expression, so literal ``.``, ``{`` and ``}`` are escaped.
    with pytest.raises(MisconfigurationException, match=r"Transform should be a dict\."):
        DefaultPreprocess(train_transform="choco")

    with pytest.raises(
        MisconfigurationException,
        match=r"train_transform contains \{'choco'\}\. Only",
    ):
        DefaultPreprocess(train_transform={"choco": None})

    preprocess = DefaultPreprocess(train_transform={"to_tensor_transform": torch.nn.Linear(1, 1)})
    # Only the train stage was given a transform dict; every other stage leaves the flag as None.
    assert preprocess._train_collate_in_worker_from_transform is True
    assert preprocess._val_collate_in_worker_from_transform is None
    assert preprocess._test_collate_in_worker_from_transform is None
    assert preprocess._predict_collate_in_worker_from_transform is None

    # A single stage's transform dict may not contain both of these hooks.
    with pytest.raises(
        MisconfigurationException,
        match="`per_batch_transform` and `per_sample_transform_on_device`",
    ):
        DefaultPreprocess(
            train_transform={
                "per_batch_transform": torch.nn.Linear(1, 1),
                "per_sample_transform_on_device": torch.nn.Linear(1, 1),
            }
        )

    preprocess = DefaultPreprocess(
        train_transform={"per_batch_transform": torch.nn.Linear(1, 1)},
        predict_transform={"per_sample_transform_on_device": torch.nn.Linear(1, 1)},
    )
    # train (``per_batch_transform``) -> True, predict (``per_sample_transform_on_device``)
    # -> False, stages without a transform dict -> None.
    assert preprocess._train_collate_in_worker_from_transform is True
    assert preprocess._val_collate_in_worker_from_transform is None
    assert preprocess._test_collate_in_worker_from_transform is None
    assert preprocess._predict_collate_in_worker_from_transform is False

    # Each stage gets its own, freshly built DataPipeline.
    train_preprocessor = DataPipeline(preprocess=preprocess).worker_preprocessor(
        RunningStage.TRAINING
    )
    val_preprocessor = DataPipeline(preprocess=preprocess).worker_preprocessor(
        RunningStage.VALIDATING
    )
    test_preprocessor = DataPipeline(preprocess=preprocess).worker_preprocessor(
        RunningStage.TESTING
    )
    predict_preprocessor = DataPipeline(preprocess=preprocess).worker_preprocessor(
        RunningStage.PREDICTING
    )

    # Only the predict stage (flag False) swaps ``default_collate`` for ``_identity``.
    assert train_preprocessor.collate_fn.func == default_collate
    assert val_preprocessor.collate_fn.func == default_collate
    assert test_preprocessor.collate_fn.func == default_collate
    assert predict_preprocessor.collate_fn.func == DataPipeline._identity

    class CustomPreprocess(DefaultPreprocess):
        def per_sample_transform_on_device(self, sample: Any) -> Any:
            return super().per_sample_transform_on_device(sample)

        def per_batch_transform(self, batch: Any) -> Any:
            return super().per_batch_transform(batch)

    preprocess = CustomPreprocess(
        train_transform={"per_batch_transform": torch.nn.Linear(1, 1)},
        predict_transform={"per_sample_transform_on_device": torch.nn.Linear(1, 1)},
    )
    # The class-level overrides leave the flags extracted from the transform dicts unchanged.
    assert preprocess._train_collate_in_worker_from_transform is True
    assert preprocess._val_collate_in_worker_from_transform is None
    assert preprocess._test_collate_in_worker_from_transform is None
    assert preprocess._predict_collate_in_worker_from_transform is False

    data_pipeline = DataPipeline(preprocess=preprocess)

    train_preprocessor = data_pipeline.worker_preprocessor(RunningStage.TRAINING)
    # Validating and testing have no transform dict (flag None) while the class overrides both
    # hooks, and building their worker preprocessor must raise.
    # NOTE(review): that a None flag falls back to override detection is inferred from these
    # assertions, not shown here.
    for stage in (RunningStage.VALIDATING, RunningStage.TESTING):
        with pytest.raises(
            MisconfigurationException,
            match="`per_batch_transform` and `per_sample_transform_on_device`",
        ):
            data_pipeline.worker_preprocessor(stage)
    predict_preprocessor = data_pipeline.worker_preprocessor(RunningStage.PREDICTING)

    assert train_preprocessor.collate_fn.func == default_collate
    assert predict_preprocessor.collate_fn.func == DataPipeline._identity