def test_warning(self):
        """Setting a state on an already-initialized pipeline emits a ``UserWarning``."""
        pipeline_state = DataPipelineState()
        # Force the initialized flag without going through the normal lifecycle.
        pipeline_state._initialized = True

        expected_message = "data pipeline has already been initialized"
        with pytest.warns(UserWarning, match=expected_message):
            pipeline_state.set_state(ProcessState())
Example #2
0
def test_properties_data_pipeline_state():
    """Tests that ``get_state`` and ``set_state`` work for properties and that ``DataPipelineState`` is attached
    correctly."""

    class MyProcessState1(ProcessState):
        pass

    class MyProcessState2(ProcessState):
        pass

    class OtherProcessState(ProcessState):
        pass

    props = Properties()
    props.set_state(MyProcessState1())
    # Internal storage is keyed by the state's own class.
    assert props._state == {MyProcessState1: MyProcessState1()}
    assert props.get_state(OtherProcessState) is None

    pipeline_state = DataPipelineState()
    pipeline_state.set_state(OtherProcessState())
    props.attach_data_pipeline_state(pipeline_state)
    # After attaching, states held by the pipeline become visible through the properties.
    assert props.get_state(OtherProcessState) == OtherProcessState()

    # ...and states set afterwards on the properties propagate back to the pipeline.
    props.set_state(MyProcessState2())
    assert pipeline_state.get_state(MyProcessState2) == MyProcessState2()
    def test_str(self):
        """``str`` lists the initialization flag and the mapping of held states."""
        pipeline_state = DataPipelineState()
        pipeline_state.set_state(ProcessState())

        expected = (
            "DataPipelineState(initialized=False, "
            "state={<class 'flash.core.data.properties.ProcessState'>: ProcessState()})"
        )
        assert str(pipeline_state) == expected
Example #4
0
def test_serializer_mapping():
    """Tests that ``SerializerMapping`` correctly passes its inputs to the underlying serializers.

    Also checks that state is retrieved / loaded correctly.
    """

    class Serializer1State(ProcessState):
        pass

    class Serializer2State(ProcessState):
        pass

    serializer1 = Serializer()
    serializer1.serialize = Mock(return_value="test1")
    serializer2 = Serializer()
    serializer2.serialize = Mock(return_value="test2")

    mapping = SerializerMapping({"key1": serializer1, "key2": serializer2})

    # Each value is routed to the serializer registered under its key.
    outputs = mapping({"key1": "serializer1", "key2": "serializer2"})
    assert outputs == {"key1": "test1", "key2": "test2"}
    serializer1.serialize.assert_called_once_with("serializer1")
    serializer2.serialize.assert_called_once_with("serializer2")

    # A non-mapping input is rejected.
    with pytest.raises(ValueError, match="output must be a mapping"):
        mapping("not a mapping")

    serializer1_state = Serializer1State()
    serializer2_state = Serializer2State()
    serializer1.set_state(serializer1_state)
    serializer2.set_state(serializer2_state)

    pipeline_state = DataPipelineState()
    mapping.attach_data_pipeline_state(pipeline_state)

    # Attaching propagates the shared state object to every child serializer...
    assert serializer1._data_pipeline_state is pipeline_state
    assert serializer2._data_pipeline_state is pipeline_state

    # ...and lifts each child's state into the shared pipeline state.
    assert pipeline_state.get_state(Serializer1State) is serializer1_state
    assert pipeline_state.get_state(Serializer2State) is serializer2_state
Example #5
0
def test_serializer_mapping():
    """Tests that ``SerializerMapping`` correctly passes its inputs to the underlying serializers. Also checks that
    state is retrieved / loaded correctly."""

    class Serializer1State(ProcessState):
        pass

    class Serializer2State(ProcessState):
        pass

    first = Serializer()
    first.serialize = Mock(return_value='test1')
    second = Serializer()
    second.serialize = Mock(return_value='test2')

    mapping = SerializerMapping({'key1': first, 'key2': second})

    # Inputs are dispatched per-key to the matching serializer.
    result = mapping({'key1': 'serializer1', 'key2': 'serializer2'})
    assert result == {'key1': 'test1', 'key2': 'test2'}
    first.serialize.assert_called_once_with('serializer1')
    second.serialize.assert_called_once_with('serializer2')

    # Anything other than a mapping raises.
    with pytest.raises(ValueError, match='output must be a mapping'):
        mapping('not a mapping')

    first_state = Serializer1State()
    second_state = Serializer2State()
    first.set_state(first_state)
    second.set_state(second_state)

    pipeline_state = DataPipelineState()
    mapping.attach_data_pipeline_state(pipeline_state)

    # The shared pipeline state reaches each child serializer...
    assert first._data_pipeline_state is pipeline_state
    assert second._data_pipeline_state is pipeline_state

    # ...and each child's state is registered on the shared object.
    assert pipeline_state.get_state(Serializer1State) is first_state
    assert pipeline_state.get_state(Serializer2State) is second_state
 def test_get_state(self):
     """Querying a state class that was never set yields ``None``."""
     assert DataPipelineState().get_state(ProcessState) is None
Example #7
0
    def build_data_pipeline(
        self,
        data_source: Optional[str] = None,
        deserializer: Optional[Deserializer] = None,
        data_pipeline: Optional[DataPipeline] = None,
    ) -> Optional[DataPipeline]:
        """Build a :class:`.DataPipeline` incorporating available
        :class:`~flash.core.data.process.Preprocess` and :class:`~flash.core.data.process.Postprocess`
        objects. These will be overridden in the following resolution order (lowest priority first):

        - Lightning ``Datamodule``, either attached to the :class:`.Trainer` or to the :class:`.Task`.
        - :class:`.Task` defaults given to :meth:`.Task.__init__`.
        - :class:`.Task` manual overrides by setting :py:attr:`~data_pipeline`.
        - :class:`.DataPipeline` passed to this method.

        Args:
            data_source: Optional name of the data source to use; when ``None`` the
                data source from the datamodule's pipeline (if any) is used instead.
            deserializer: Optional :class:`~flash.core.data.process.Deserializer`.
            data_pipeline: Optional highest priority source of
                :class:`~flash.core.data.process.Preprocess` and :class:`~flash.core.data.process.Postprocess`.

        Returns:
            The fully resolved :class:`.DataPipeline`.
        """
        # NOTE(review): this tuple assignment rebinds ``deserializer``, discarding the
        # value passed as an argument — confirm whether the argument should instead
        # seed the resolution below.
        deserializer, old_data_source, preprocess, postprocess, serializer = None, None, None, None, None

        # Stage 1 (lowest priority): pull components from a datamodule, preferring one
        # attached to the trainer over one attached directly to the task.
        datamodule = None
        if self.trainer is not None and hasattr(self.trainer, "datamodule"):
            datamodule = self.trainer.datamodule
        elif getattr(self, "datamodule", None) is not None:
            datamodule = self.datamodule

        if getattr(datamodule, "data_pipeline", None) is not None:
            old_data_source = getattr(datamodule.data_pipeline, "data_source",
                                      None)
            preprocess = getattr(datamodule.data_pipeline,
                                 "_preprocess_pipeline", None)
            postprocess = getattr(datamodule.data_pipeline,
                                  "_postprocess_pipeline", None)
            serializer = getattr(datamodule.data_pipeline, "_serializer", None)
            deserializer = getattr(datamodule.data_pipeline, "_deserializer",
                                   None)

        # Stage 2: task-level attributes/defaults override the datamodule's components.
        deserializer, preprocess, postprocess, serializer = Task._resolve(
            deserializer,
            preprocess,
            postprocess,
            serializer,
            self._deserializer,
            self._preprocess,
            self._postprocess,
            self._serializer,
        )

        # Stage 3 (highest priority): components from an explicitly passed pipeline.
        if data_pipeline is not None:
            deserializer, preprocess, postprocess, serializer = Task._resolve(
                deserializer,
                preprocess,
                postprocess,
                serializer,
                getattr(data_pipeline, "_deserializer", None),
                getattr(data_pipeline, "_preprocess_pipeline", None),
                getattr(data_pipeline, "_postprocess_pipeline", None),
                getattr(data_pipeline, "_serializer", None),
            )

        # Fall back to the datamodule's data source when none was given explicitly.
        data_source = data_source or old_data_source

        # A string data source must be resolved to an actual DataSource object, using
        # the preprocess as the registry when one is available.
        if isinstance(data_source, str):
            if preprocess is None:
                data_source = DataSource(
                )  # TODO: warn the user that we are not using the specified data source
            else:
                data_source = preprocess.data_source_of_name(data_source)

        # Prefer the preprocess's own deserializer when none was resolved (or only the
        # generic base Deserializer was).
        if deserializer is None or type(deserializer) is Deserializer:
            deserializer = getattr(preprocess, "deserializer", deserializer)

        # Assemble the pipeline, attach (or create) the shared pipeline state, and
        # initialize the pipeline with it.
        data_pipeline = DataPipeline(data_source, preprocess, postprocess,
                                     deserializer, serializer)
        self._data_pipeline_state = self._data_pipeline_state or DataPipelineState(
        )
        self.attach_data_pipeline_state(self._data_pipeline_state)
        self._data_pipeline_state = data_pipeline.initialize(
            self._data_pipeline_state)
        return data_pipeline