コード例 #1
0
ファイル: test_multinetwork.py プロジェクト: yivash/timeserio
 def test_deserialized_batch_norm(self, bn_multinetwork):
     """Round-trip the ``bn_multinetwork`` fixture through ``dumps``/``loads``.

     ``trainable_models`` is set to ``None`` before serializing; the
     ``model`` attributes of the original and the restored object are then
     compared with ``assert_model_dicts_same``.
     """
     # NOTE(review): presumably clears any trainable-model selection so the
     # comparison is not affected by it - confirm against the fixture.
     bn_multinetwork.trainable_models = None
     restored = loads(dumps(bn_multinetwork))
     self.assert_model_dicts_same(bn_multinetwork.model, restored.model)
コード例 #2
0
ファイル: test_multinetwork.py プロジェクト: yivash/timeserio
 def test_deserialized_gradients(self, ef_multinetwork):
     """Check that a restored ``'combined'`` model yields gradients.

     The fixture's model is initialised, the multinetwork is round-tripped
     through ``dumps``/``loads``, and ``K.gradients`` of the restored
     model's ``total_loss`` is taken with respect to its
     ``trainable_weights``. The test fails if any gradient is ``None``.
     """
     ef_multinetwork._init_model()
     s = dumps(ef_multinetwork)
     # Bind the restored object to its own name instead of shadowing the
     # fixture argument, matching the naming used by the sibling tests.
     new_multinetwork = loads(s)
     model = new_multinetwork.model['combined']
     grads = K.gradients(model.total_loss, model.trainable_weights)
     # Collect the offending weights so a failure names them rather than
     # reporting only that "some" gradient was missing.
     missing = [
         weight
         for weight, grad in zip(model.trainable_weights, grads)
         if grad is None
     ]
     assert not missing, 'No gradient for weights: {}'.format(missing)
コード例 #3
0
 def test_layer_sharing(self, ef_multinetwork):
     """Compare the models of ``ef_multinetwork`` before and after a round trip.

     The fixture is serialized with ``dumps``, restored with ``loads``, and
     both ``model`` attributes are handed to ``assert_model_dicts_same``.
     """
     # NOTE(review): the test name suggests the fixture's models share
     # layers - confirm against the fixture definition.
     serialized = dumps(ef_multinetwork)
     restored = loads(serialized)
     original_models = ef_multinetwork.model
     self.assert_model_dicts_same(original_models, restored.model)
コード例 #4
0
ファイル: test_multinetwork.py プロジェクト: yivash/timeserio
    def test_serialize_no_keras(self, multinetwork, mocker):
        """Check unpickling when ``keras`` is patched out of the module.

        The ``keras`` attribute of ``timeserio.keras.multinetwork`` is
        patched to ``None``. Loading is first expected to raise
        ``AttributeError`` while ``HABEMUS_KERAS`` is left untouched; once
        that flag is patched to ``False`` as well, loading must succeed and
        the restored object's ``get_params()`` must equal the original's.
        """
        import timeserio.keras.multinetwork

        # replace `keras` with a broken non-module
        mocker.patch.object(timeserio.keras.multinetwork, "keras", None)

        params = multinetwork.get_params()
        s = dumps(multinetwork)

        # loading fails because HABEMUS_KERAS is still True
        with pytest.raises(AttributeError):
            # Only the raised exception matters here, so the result is
            # deliberately not bound to a name.
            loads(s)

        # if both are mocked correctly, "unsafe" unpickling works again
        mocker.patch.object(
            timeserio.keras.multinetwork, "HABEMUS_KERAS", False
        )
        new_multinetwork = loads(s)
        new_params = new_multinetwork.get_params()

        assert new_params == params
コード例 #5
0
ファイル: test_multinetwork.py プロジェクト: yivash/timeserio
 def test_optimizer_state(self, multinetwork):
     """Check that a changed optimizer learning rate survives a round trip.

     Inside ``_training_context`` the ``lr`` of the ``'forecaster'``
     model's optimizer is read and then overwritten with that value plus
     one. After ``dumps``/``loads`` the restored ``lr`` must be close to
     the overwritten value.
     """
     with multinetwork._training_context():
         original_lr = K.get_value(
             multinetwork.model['forecaster'].optimizer.lr
         )
         bumped_lr = original_lr + 1.0
         K.set_value(
             multinetwork.model['forecaster'].optimizer.lr, bumped_lr
         )
     restored = loads(dumps(multinetwork))
     restored_lr = K.get_value(restored.model['forecaster'].optimizer.lr)
     assert np.allclose(restored_lr, bumped_lr), 'Optimizer loaded incorrectly'
コード例 #6
0
ファイル: test_multinetwork.py プロジェクト: yivash/timeserio
 def test_history_preserved(self, multinetwork):
     """Check that the ``history`` attribute survives a ``dumps``/``loads`` trip.

     A single-entry history list is assigned to the fixture, and the
     restored object's ``history`` must compare equal to the same content.
     """
     entry = {'model': None}
     multinetwork.history = [entry]
     restored = loads(dumps(multinetwork))
     assert restored.history == [entry]
コード例 #7
0
ファイル: test_multinetwork.py プロジェクト: yivash/timeserio
 def test_deserialized_params(self, ef_multinetwork):
     """Check that ``get_params()`` is unchanged by a ``dumps``/``loads`` trip.

     Parameters are captured from the fixture before serializing and
     compared for equality with those reported by the restored object.
     """
     expected_params = ef_multinetwork.get_params()
     restored = loads(dumps(ef_multinetwork))
     assert restored.get_params() == expected_params
コード例 #8
0
 def test_pickle_to_string(self, obj, assert_object_same):
     """Round-trip ``obj`` through ``dumps``/``loads`` and compare the result.

     The comparison is delegated to the ``assert_object_same`` fixture,
     which receives the original object and the restored one.
     """
     assert_object_same(obj, loads(dumps(obj)))