Example #1
0
    def test_maintains_losses(self):
        """Wrapping `call` in a LayerCallCollection must leave `losses` unchanged.

        Calls a LayerWithLoss once directly and snapshots its `losses`. It then
        registers `layer.call` with `LayerCallCollection.add_function` and
        invokes the returned wrapper. The layer's `losses` must still equal the
        snapshot taken before wrapping.
        """
        input_shape = (2, 3)
        layer = LayerWithLoss()
        layer(np.ones(input_shape))
        # Shallow-copy the losses so a later change to layer.losses is detectable.
        losses_snapshot = layer.losses[:]

        collection = keras_save.LayerCallCollection(layer)
        wrapped_call = collection.add_function(layer.call, 'call')
        wrapped_call(np.ones(input_shape))

        self.assertAllEqual(losses_snapshot, layer.losses)
    def assert_num_traces(layer_cls, training_keyword):
      """Assert how many concrete functions get traced for a layer's `call`.

      Builds `layer_cls()` and registers its `call` with
      `keras_save.LayerCallCollection.add_function`. It then calls the returned
      wrapper with inputs of different shapes. After each call it asserts the
      length of `fn._list_all_concrete_functions_for_serialization()`.

      Args:
        layer_cls: Zero-argument callable that builds the layer under test
          (presumably a Keras layer class -- confirm against callers).
        training_keyword: If truthy, also checks two more cases: passing
          `training` positionally and omitting it entirely.

      NOTE(review): this helper reads `self` but does not take it as a
      parameter. It is presumably a closure nested inside a test method whose
      `def` line is not visible here -- confirm before moving it.
      """
      layer = layer_cls()
      call_collection = keras_save.LayerCallCollection(layer)
      fn = call_collection.add_function(layer.call, 'call')

      # Each call below uses a new input shape, and the expected count grows
      # by 2 every time. Presumably each shape traces one concrete function
      # per training value (True/False) -- confirm.
      fn(np.ones((2, 3)), training=True)
      self.assertLen(fn._list_all_concrete_functions_for_serialization(), 2)

      fn(np.ones((2, 4)), training=False)
      self.assertLen(fn._list_all_concrete_functions_for_serialization(), 4)

      if training_keyword:
        # Second positional argument, presumably `training`, passed without
        # the keyword.
        fn(np.ones((2, 5)), True)
        self.assertLen(fn._list_all_concrete_functions_for_serialization(), 6)
        # No training argument supplied at all.
        fn(np.ones((2, 6)))
        self.assertLen(fn._list_all_concrete_functions_for_serialization(), 8)