コード例 #1
0
ファイル: test_marshalling.py プロジェクト: databand-ai/dbnd
    def test_history_marshalling_target_to_value(self, history, temp_target):
        """Round-trip a ``History`` through ``temp_target`` and compare contents.

        The history is written with ``value_to_target`` and read back with
        ``target_to_value`` using the same marshaller instance. The objects
        themselves are not compared; their inner ``history`` dictionaries are,
        so the check is value-based rather than identity-based.
        """
        marshaller = TensorflowKerasHistoryMarshaller()
        marshaller.value_to_target(history, target=temp_target)
        restored = marshaller.target_to_value(target=temp_target)

        restored_dict = restored.history
        assert restored_dict == history.history
コード例 #2
0
ファイル: _plugin.py プロジェクト: kalebinn/dbnd
def dbnd_setup_plugin():
    """Register the TensorFlow/Keras marshallers and value types with dbnd.

    Registers a marshaller for Keras ``Model`` objects under
    ``FileFormat.tfmodel`` and one for Keras ``History`` objects under
    ``FileFormat.tfhistory``, then registers the matching value types.

    The classes come from the public ``tensorflow.keras`` namespace rather
    than the private ``tensorflow.python.keras`` package. From TensorFlow 2.6
    onward ``tf.keras`` is backed by the standalone ``keras`` package, so the
    classes under ``tensorflow.python.keras`` are a separate legacy copy. Models
    and histories created through ``tf.keras`` are not instances of them, and
    marshallers registered on them would never apply.
    """
    # Imported inside the function so TensorFlow is only loaded when the
    # plugin is actually set up.
    from tensorflow.keras import models
    from tensorflow.keras.callbacks import History

    from dbnd_tensorflow.marshalling.tensorflow_marshaller import (
        TensorflowKerasHistoryMarshaller,
        TensorflowKerasModelMarshaller,
    )
    from dbnd_tensorflow.marshalling.tensorflow_values import (
        TensorflowHistoryValueType,
        TensorflowModelValueType,
    )

    register_marshaller(
        models.Model,
        FileFormat.tfmodel,
        TensorflowKerasModelMarshaller(),
    )
    register_marshaller(
        History,
        FileFormat.tfhistory,
        TensorflowKerasHistoryMarshaller(),
    )

    register_value_type(TensorflowModelValueType())
    register_value_type(TensorflowHistoryValueType())
コード例 #3
0
def dbnd_setup_plugin():
    """Wire TensorFlow/Keras marshallers and value types into dbnd's registries.

    Keras ``Model`` objects are paired with ``FileFormat.tfmodel`` and Keras
    ``History`` objects with ``FileFormat.tfhistory``. Each pair gets a freshly
    constructed marshaller. Afterwards the model and history value types are
    registered, in that order.
    """
    from tensorflow.keras import models
    from tensorflow.keras.callbacks import History

    from dbnd_tensorflow.marshalling.tensorflow_marshaller import (
        TensorflowKerasHistoryMarshaller,
        TensorflowKerasModelMarshaller,
    )
    from dbnd_tensorflow.marshalling.tensorflow_values import (
        TensorflowHistoryValueType,
        TensorflowModelValueType,
    )

    # (python type, storage format, marshaller class). Each marshaller is
    # instantiated immediately before it is registered.
    marshaller_specs = (
        (models.Model, FileFormat.tfmodel, TensorflowKerasModelMarshaller),
        (History, FileFormat.tfhistory, TensorflowKerasHistoryMarshaller),
    )
    for value_cls, file_format, marshaller_cls in marshaller_specs:
        register_marshaller(value_cls, file_format, marshaller_cls())

    for value_type_cls in (TensorflowModelValueType, TensorflowHistoryValueType):
        register_value_type(value_type_cls())
コード例 #4
0
ファイル: test_marshalling.py プロジェクト: databand-ai/dbnd
 def test_history_marshalling_value_to_target(self, history, temp_target):
     """Writing a ``History`` with the marshaller must leave something at the target path."""
     marshaller = TensorflowKerasHistoryMarshaller()
     marshaller.value_to_target(history, target=temp_target)

     written_path = temp_target.path
     assert os.path.exists(written_path)