def test_maintains_losses(self):
    """Check that calling a collected call function leaves layer.losses as-is.

    The layer is called once directly, its losses are snapshotted, and then
    the function returned by `LayerCallCollection.add_function` is called.
    The losses on the layer afterwards must equal the snapshot.
    """
    inputs = np.ones((2, 3))

    layer = LayerWithLoss()
    layer(inputs)
    # Copy the list so later changes to layer.losses cannot alias the snapshot.
    losses_before = layer.losses[:]

    collection = keras_save.LayerCallCollection(layer)
    wrapped_call = collection.add_function(layer.call, 'call')
    wrapped_call(inputs)

    self.assertAllEqual(losses_before, layer.losses)
def assert_num_traces(layer_cls, training_keyword):
    """Assert how many concrete functions each new call adds.

    Every call below uses a new input shape, and the number of concrete
    functions listed for serialization is expected to grow by two each time.

    Note: `self` is not a parameter; it is taken from the enclosing scope
    (presumably the surrounding test method -- not visible in this chunk).

    Args:
      layer_cls: Zero-argument callable that builds the layer under test.
      training_keyword: When truthy, also exercise passing the training value
        positionally and omitting it altogether.
    """
    layer = layer_cls()
    collection = keras_save.LayerCallCollection(layer)
    fn = collection.add_function(layer.call, 'call')

    def assert_trace_count(expected):
        # Number of concrete functions currently listed for serialization.
        self.assertLen(
            fn._list_all_concrete_functions_for_serialization(), expected)

    fn(np.ones((2, 3)), training=True)
    assert_trace_count(2)

    fn(np.ones((2, 4)), training=False)
    assert_trace_count(4)

    if not training_keyword:
        return

    # Training value passed positionally.
    fn(np.ones((2, 5)), True)
    assert_trace_count(6)

    # Training value omitted.
    fn(np.ones((2, 6)))
    assert_trace_count(8)