Пример #1
0
  def test_functions_have_same_trace(self):
    """Checks that functions in one call collection end up with equal traces.

    Two functions are added to the same `LayerCallCollection`, but only the
    wrapper for `call` is ever invoked. The wrapper for `call2` is still
    expected to report two concrete functions covering the same input shapes.
    """

    class Layer(keras.engine.base_layer.Layer):

      def call(self, inputs):
        return inputs

      def call2(self, inputs):
        return inputs * 2

    def concrete_functions(fn):
      # Single access point for the private listing API used by the assertions.
      return fn.wrapped_call._list_all_concrete_functions_for_serialization()

    layer = Layer()
    collection = keras_save.LayerCallCollection(layer)
    call_fn = collection.add_function(layer.call, 'call', True)
    call2_fn = collection.add_function(layer.call2, 'call2', True)

    # Only `call_fn` is invoked; `call2_fn` is never called directly.
    with keras_save.tracing_scope():
      for input_shape in ((2, 3), (4, 5)):
        call_fn(np.ones(input_shape))

    for fn in (call_fn, call2_fn):
      self.assertLen(concrete_functions(fn), 2)

    # Check that the shapes recorded for the uncalled function are correct.
    traced_shapes = {
        tuple(trace.structured_input_signature[0][0].shape.as_list())
        for trace in concrete_functions(call2_fn)
    }
    self.assertEqual({(2, 3), (4, 5)}, traced_shapes)
Пример #2
0
    def assert_num_traces(layer_cls, training_keyword):
      """Asserts the concrete-function count after each differently shaped call.

      Every call below uses a new input shape and is expected to raise the
      number of concrete functions by two (presumably one per `training`
      value -- confirm against `LayerCallCollection`).

      Args:
        layer_cls: Zero-argument callable returning the layer under test.
        training_keyword: If truthy, additionally exercises passing `training`
          positionally and omitting it altogether.
      """
      instance = layer_cls()
      collection = keras_save.LayerCallCollection(instance)
      fn = collection.add_function(instance.call, 'call', True)

      def trace_and_check(expected_count, input_shape, *args, **kwargs):
        # Run one call inside a tracing scope, then verify the running total.
        with keras_save.tracing_scope():
          fn(np.ones(input_shape), *args, **kwargs)
        self.assertLen(
            fn.wrapped_call._list_all_concrete_functions_for_serialization(),
            expected_count)

      trace_and_check(2, (2, 3), training=True)
      trace_and_check(4, (2, 4), training=False)

      if not training_keyword:
        return

      # `training` passed positionally, then left out entirely.
      trace_and_check(6, (2, 5), True)
      trace_and_check(8, (2, 6))