def test_is_overridden_recursive(tmpdir):
    """Check ``DataPipeline._is_overridden_recursive`` against a minimal InputTransform.

    ``TestInputTransform`` overrides ``collate`` and ``val_collate`` only. The test expects:

    * ``collate`` to be reported as overridden for both the ``val`` and the ``train``
      prefixes (``train_collate`` is not defined, yet the lookup is still expected to
      succeed because the plain ``collate`` hook is overridden);
    * ``per_batch_transform_on_device`` to be reported as not overridden, with or
      without a prefix;
    * an unknown hook name (``"chocolate"``) to raise ``MisconfigurationException``.

    ``tmpdir`` is unused; it is kept so the test signature stays unchanged.
    """

    class TestInputTransform(InputTransform):
        @staticmethod
        def custom_transform(x):
            return x

        def collate(self):
            return self.custom_transform

        def val_collate(self):
            return self.custom_transform

    input_transform = TestInputTransform()

    assert DataPipeline._is_overridden_recursive("collate",
                                                 input_transform,
                                                 InputTransform,
                                                 prefix="val")
    assert DataPipeline._is_overridden_recursive("collate",
                                                 input_transform,
                                                 InputTransform,
                                                 prefix="train")
    assert not DataPipeline._is_overridden_recursive(
        "per_batch_transform_on_device",
        input_transform,
        InputTransform,
        prefix="train")
    assert not DataPipeline._is_overridden_recursive(
        "per_batch_transform_on_device", input_transform, InputTransform)

    # The call is expected to raise, so any assertion wrapped around it would be
    # dead code; invoke it directly inside the ``raises`` context.
    with pytest.raises(
            MisconfigurationException,
            match="This function doesn't belong to the parent class"):
        DataPipeline._is_overridden_recursive(
            "chocolate", input_transform, InputTransform)
# Example No. 2
# 0
def __configure_worker_and_device_collate_fn(
        running_stage: RunningStage,
        input_transform: InputTransform) -> Tuple[Callable, Callable]:
    """Return the ``(worker_collate_fn, device_collate_fn)`` pair for ``running_stage``.

    The stage's ``collate_in_worker_from_transform`` setting decides where collation
    happens when it is a bool. Otherwise the decision falls back to whether
    ``per_sample_transform_on_device`` is overridden for the stage's prefix.

    Args:
        running_stage: the stage whose collate functions are being built.
        input_transform: the transform providing ``_transform`` and ``_collate``.

    Returns:
        The worker-side and device-side collate callables produced by
        ``__make_collates``. If the worker-side one is an
        ``_InputTransformProcessor``, its ``collate_fn`` is returned instead.

    Raises:
        MisconfigurationException: when no explicit collate location is set and both
            ``per_batch_transform`` and ``per_sample_transform_on_device`` are
            overridden for the stage.
    """
    # Local import; presumably avoids a circular import with data_pipeline (confirm).
    from flash.core.data.data_pipeline import DataPipeline

    stage_prefix: str = _STAGES_PREFIX[running_stage]
    stage_transform: _InputTransformPerStage = input_transform._transform[running_stage]

    def hook_overridden(hook_name: str) -> bool:
        # Check the hook on the user's transform relative to the InputTransform base.
        return DataPipeline._is_overridden_recursive(
            hook_name, input_transform, InputTransform, prefix=stage_prefix)

    # Both queries are always evaluated, in this order.
    batch_hook_overridden = hook_overridden("per_batch_transform")
    device_hook_overridden = hook_overridden("per_sample_transform_on_device")

    explicit_in_worker = stage_transform.collate_in_worker_from_transform

    if explicit_in_worker is None and (batch_hook_overridden and device_hook_overridden):
        raise MisconfigurationException(
            f"{input_transform.__class__.__name__}: `per_batch_transform` and `per_sample_transform_on_device` "
            f"are mutually exclusive for stage {running_stage}")

    # An explicit bool setting wins (inverted, since it says "in worker");
    # otherwise defer to whether the on-device per-sample hook is overridden.
    collate_on_device = (not explicit_in_worker
                         if isinstance(explicit_in_worker, bool)
                         else device_hook_overridden)

    worker_fn, device_fn = __make_collates(
        input_transform, collate_on_device, input_transform._collate)

    # Unwrap a processor so the worker side receives the raw collate callable.
    if isinstance(worker_fn, _InputTransformProcessor):
        worker_fn = worker_fn.collate_fn

    return worker_fn, device_fn