def test_deserialized_batch_norm(self, bn_multinetwork):
    """Round-trip `bn_multinetwork` through `dumps`/`loads` and compare models.

    `trainable_models` is reset to `None` before serializing; afterwards the
    `model` attribute of the original and of the restored object are compared
    with `assert_model_dicts_same`.
    """
    original = bn_multinetwork
    original.trainable_models = None
    restored = loads(dumps(original))
    self.assert_model_dicts_same(original.model, restored.model)
def test_deserialized_gradients(self, ef_multinetwork):
    """Check that the restored 'combined' model yields no `None` gradients.

    The model is initialised via `_init_model()`, serialized and restored,
    and `K.gradients` of its `total_loss` with respect to its
    `trainable_weights` must not contain `None`.
    """
    ef_multinetwork._init_model()
    restored = loads(dumps(ef_multinetwork))
    combined = restored.model['combined']
    gradients = K.gradients(combined.total_loss, combined.trainable_weights)
    assert all(grad is not None for grad in gradients)
def test_layer_sharing(self, ef_multinetwork):
    """Round-trip `ef_multinetwork` and compare its models with the original's.

    The `model` attribute of the original and of the restored object are
    compared with `assert_model_dicts_same`.
    """
    original = ef_multinetwork
    restored = loads(dumps(original))
    self.assert_model_dicts_same(original.model, restored.model)
def test_optimizer_state(self, multinetwork):
    """Check that a modified optimizer learning rate survives a round trip.

    The 'forecaster' model's optimizer `lr` is bumped by 1, the multinetwork
    is serialized and restored, and the restored `lr` must match the bumped
    value.
    """
    # NOTE(review): every statement is kept inside the training context;
    # confirm this matches the intended extent of the `with` block.
    with multinetwork._training_context():
        optimizer = multinetwork.model['forecaster'].optimizer
        lr_changed = K.get_value(optimizer.lr) + 1.
        K.set_value(multinetwork.model['forecaster'].optimizer.lr, lr_changed)

        restored = loads(dumps(multinetwork))
        lr_new = K.get_value(restored.model['forecaster'].optimizer.lr)
        assert np.allclose(lr_new, lr_changed), 'Optimizer loaded incorrectly'
def test_serialize_no_keras(self, multinetwork, mocker):
    """Exercise `dumps`/`loads` while `keras` is patched out of the module.

    With only `keras` patched to `None`, loading is expected to raise
    `AttributeError`; once `HABEMUS_KERAS` is also patched to `False`,
    loading succeeds and `get_params()` matches the original's.
    """
    import timeserio.keras.multinetwork

    target_module = timeserio.keras.multinetwork
    # replace `keras` with a broken non-module
    mocker.patch.object(target_module, "keras", None)
    expected_params = multinetwork.get_params()
    payload = dumps(multinetwork)

    # loading fails because HABEMUS_KERAS is still True
    with pytest.raises(AttributeError):
        loads(payload)

    # if both are mocked correctly, "unsafe" unpickling works again
    mocker.patch.object(target_module, "HABEMUS_KERAS", False)
    restored = loads(payload)
    assert restored.get_params() == expected_params
def test_history_preserved(self, multinetwork):
    """Check that the `history` attribute survives a `dumps`/`loads` round trip."""
    expected_history = [{'model': None}]
    multinetwork.history = expected_history
    restored = loads(dumps(multinetwork))
    assert restored.history == expected_history
def test_deserialized_params(self, ef_multinetwork):
    """Check that `get_params()` is unchanged by a `dumps`/`loads` round trip."""
    expected_params = ef_multinetwork.get_params()
    restored = loads(dumps(ef_multinetwork))
    assert restored.get_params() == expected_params
def test_pickle_to_string(self, obj, assert_object_same):
    """Round-trip `obj` through `dumps`/`loads` and compare it with the result.

    The comparison is delegated to the `assert_object_same` callable, which
    receives the original object first and the restored one second.
    """
    restored = loads(dumps(obj))
    assert_object_same(obj, restored)