def test_functions_have_same_trace(self):
    """Functions added to one LayerCallCollection share their traces.

    Only ``fn`` is called (with two different input shapes), yet ``fn2`` is
    expected to end up with two concrete functions for those same shapes.
    """

    class Layer(keras.engine.base_layer.Layer):

        def call(self, inputs):
            return inputs

        def call2(self, inputs):
            return inputs * 2

    layer = Layer()
    call_collection = keras_save.LayerCallCollection(layer)
    fn = call_collection.add_function(layer.call, 'call')
    fn2 = call_collection.add_function(layer.call2, 'call2')

    # Trace only `fn`; `fn2` is never called directly.
    fn(np.ones((2, 3)))
    fn(np.ones((4, 5)))

    self.assertLen(fn._list_all_concrete_functions_for_serialization(), 2)
    self.assertLen(fn2._list_all_concrete_functions_for_serialization(), 2)

    # Check that the shapes are correct. `structured_input_signature[0][0]`
    # is presumably the spec of the `inputs` argument -- confirm against the
    # ConcreteFunction API.
    self.assertEqual(
        {(2, 3), (4, 5)},
        {
            tuple(c.structured_input_signature[0][0].shape.as_list())
            for c in fn2._list_all_concrete_functions_for_serialization()
        })
def test_maintains_losses(self):
    """Tracing `call` through a LayerCallCollection leaves `losses` unchanged."""
    loss_layer = LayerWithLoss()
    loss_layer(np.ones((2, 3)))
    # Shallow-copy the losses recorded by the direct call above so later
    # mutation of `loss_layer.losses` cannot affect the snapshot.
    losses_before = loss_layer.losses[:]

    collection = keras_save.LayerCallCollection(loss_layer)
    traced_call = collection.add_function(loss_layer.call, 'call')
    traced_call(np.ones((2, 3)))

    self.assertAllEqual(losses_before, loss_layer.losses)
def assert_num_traces(layer_cls, training_keyword):
    """Check how many concrete functions a traced `call` accumulates.

    Args:
      layer_cls: Layer class, instantiated here with no arguments.
      training_keyword: If truthy, additionally call with `training` passed
        positionally and with it omitted, checking the counts after each.
        (Presumably set when `layer_cls.call` accepts `training` directly --
        confirm against the callers.)

    NOTE(review): `self` is not a parameter of this function; it is
    presumably captured from an enclosing test method -- confirm.
    """
    layer = layer_cls()
    collection = keras_save.LayerCallCollection(layer)
    traced = collection.add_function(layer.call, 'call')

    def expect_count(expected):
        self.assertLen(
            traced._list_all_concrete_functions_for_serialization(), expected)

    # Per the asserted counts, each call below adds two concrete functions
    # (presumably one per `training` value -- confirm).
    traced(np.ones((2, 3)), training=True)
    expect_count(2)
    traced(np.ones((2, 4)), training=False)
    expect_count(4)

    if not training_keyword:
        return
    traced(np.ones((2, 5)), True)
    expect_count(6)
    traced(np.ones((2, 6)))
    expect_count(8)