Exemplo n.º 1
0
  def test_functions_have_same_trace(self):
    """Tracing one function of a call collection also traces its sibling.

    Two functions are registered on the same `LayerCallCollection`. Only the
    first one is invoked, with two different input shapes, inside a tracing
    scope. Both are then expected to hold two concrete functions, and the
    second one's traces must be specialised to exactly those two shapes.
    """

    class Layer(keras.engine.base_layer.Layer):

      def call(self, inputs):
        return inputs

      def call2(self, inputs):
        return inputs * 2

    def concrete_functions_of(wrapped):
      # Concrete functions recorded on the underlying wrapped call.
      return wrapped.wrapped_call._list_all_concrete_functions_for_serialization()

    layer = Layer()

    collection = keras_save.LayerCallCollection(layer)
    traced_fn = collection.add_function(layer.call, 'call', True)
    sibling_fn = collection.add_function(layer.call2, 'call2', True)

    # Only `traced_fn` is called; `sibling_fn` is never invoked directly.
    with keras_save.tracing_scope():
      for input_shape in ((2, 3), (4, 5)):
        traced_fn(np.ones(input_shape))

    self.assertLen(concrete_functions_of(traced_fn), 2)
    self.assertLen(concrete_functions_of(sibling_fn), 2)

    # The sibling's traces must carry the same input shapes as the calls above.
    expected_shapes = {(2, 3), (4, 5)}
    sibling_shapes = {
        tuple(concrete.structured_input_signature[0][0].shape.as_list())
        for concrete in concrete_functions_of(sibling_fn)
    }
    self.assertEqual(expected_shapes, sibling_shapes)
Exemplo n.º 2
0
  def test_maintains_losses(self):
    """Calling a collection-wrapped `call` leaves `layer.losses` unchanged."""
    layer = LayerWithLoss()
    layer(np.ones((2, 3)))
    # Shallow copy of the losses recorded by the direct layer call above.
    losses_before = layer.losses[:]

    collection = keras_save.LayerCallCollection(layer)
    wrapped_call = collection.add_function(layer.call, 'call', True)
    wrapped_call(np.ones((2, 3)))

    # The wrapped invocation must not add, drop, or alter any layer loss.
    self.assertAllEqual(losses_before, layer.losses)
Exemplo n.º 3
0
    def assert_num_traces(layer_cls, training_keyword):
      """Asserts how many concrete functions a wrapped `call` accumulates.

      Args:
        layer_cls: Layer class to instantiate; its `call` method is registered
          on a fresh `keras_save.LayerCallCollection`.
        training_keyword: If truthy, additionally checks a call that passes
          `True` as a second positional argument (presumably `training`) and
          a call that omits that argument entirely.
      """
      # NOTE: `self` is not a parameter here; it is captured from the
      # enclosing scope (presumably the surrounding test case).
      layer = layer_cls()
      call_collection = keras_save.LayerCallCollection(layer)
      fn = call_collection.add_function(layer.call, 'call')

      # Each call uses a new input width, so it cannot reuse an earlier trace.
      # The expected count grows by 2 per call (2, 4, 6, 8); presumably one
      # concrete function per training value -- confirm against
      # LayerCallCollection.
      fn(np.ones((2, 3)), training=True)
      self.assertLen(fn._list_all_concrete_functions_for_serialization(), 2)

      fn(np.ones((2, 4)), training=False)
      self.assertLen(fn._list_all_concrete_functions_for_serialization(), 4)

      if training_keyword:
        # Second argument passed positionally rather than as `training=`.
        fn(np.ones((2, 5)), True)
        self.assertLen(fn._list_all_concrete_functions_for_serialization(), 6)
        # No training argument supplied at all.
        fn(np.ones((2, 6)))
        self.assertLen(fn._list_all_concrete_functions_for_serialization(), 8)