Example #1
0
    def testFunctionCaching(self):
        """Checks that a dict and a `_DictWrapper` argument yield one trace.

        The same `def_function.function` is asked for a concrete function
        twice, once with a plain dict and once with a `_DictWrapper` around a
        dict with the same key; the two results must be the same object.
        """

        @def_function.function
        def f(dict_input):
            return dict_input["x"] + constant_op.constant(1.)

        # Trace with a built-in dict argument.
        plain_arg = {"x": constant_op.constant(2.)}
        trace_from_dict = f.get_concrete_function(plain_arg)

        # Trace again with the wrapped equivalent (different tensor value).
        wrapped_arg = data_structures._DictWrapper(
            {"x": constant_op.constant(3.)})
        trace_from_wrapper = f.get_concrete_function(wrapped_arg)

        self.assertIs(trace_from_dict, trace_from_wrapper)
  def testFunctionCaching(self):
    """Asserts dict and `_DictWrapper` inputs return the same concrete function."""

    @def_function.function
    def f(dict_input):
      return dict_input["x"] + constant_op.constant(1.)

    # First request: plain dict input.
    dict_input_trace = f.get_concrete_function(
        {"x": constant_op.constant(2.)})

    # Second request: the same structure wrapped in a `_DictWrapper`.
    wrapper_input = data_structures._DictWrapper(
        {"x": constant_op.constant(3.)})
    wrapper_input_trace = f.get_concrete_function(wrapper_input)

    self.assertIs(dict_input_trace, wrapper_input_trace)
Example #3
0
def wrap_layer_objects(layer, serialization_cache):
    """Builds the extra trackable objects to attach to the serialized layer.

    Callable regularization losses of `layer` and of every layer returned by
    `utils.list_all_layers(layer)` are wrapped with `_wrap_unconditional_loss`.
    Wrappers are stored in `serialization_cache['keras_losses']`, keyed by the
    original loss function, so a loss that is already cached is reused instead
    of being wrapped again.

    Args:
      layer: Keras Layer object.
      serialization_cache: Dictionary shared between all objects during
        serialization.

    Returns:
      A dictionary containing all checkpointable objects from a
      SerializedAttributes object. See LayerAttributes and ModelAttributes for
      entire list of objects.
    """
    # Gather the callable losses of this layer followed by those of the layers
    # reported by `utils.list_all_layers`.
    loss_fns = layer._callable_losses[:]  # pylint: disable=protected-access
    for sublayer in utils.list_all_layers(layer):
        loss_fns.extend(sublayer._callable_losses)  # pylint: disable=protected-access

    # Wrap each loss once; the current cache size is passed as the index of a
    # newly wrapped loss.
    loss_cache = serialization_cache.setdefault('keras_losses', {})
    all_wrapped_losses = []
    for fn in loss_fns:
        if fn not in loss_cache:
            loss_cache[fn] = _wrap_unconditional_loss(fn, len(loss_cache))
        all_wrapped_losses.append(loss_cache[fn])

    # Wrappers for the losses owned directly by `layer`; all of them are in
    # the cache after the loop above.
    own_wrapped_losses = [
        loss_cache[fn]
        for fn in layer._callable_losses[:]  # pylint: disable=protected-access
    ]

    # Map each metric in `layer._metrics` by its name.
    metrics_by_name = {}
    for metric in layer._metrics:  # pylint: disable=protected-access
        metrics_by_name[metric.name] = metric
    layer_metrics = data_structures._DictWrapper(metrics_by_name)  # pylint: disable=protected-access

    list_wrapper = data_structures.ListWrapper
    return {
        'variables': list_wrapper(layer.variables),
        'trainable_variables': list_wrapper(layer.trainable_variables),
        'non_trainable_variables': list_wrapper(
            layer.non_trainable_variables),
        'layers': list_wrapper(utils.list_all_layers(layer)),
        'metrics': list_wrapper(layer.metrics),
        'regularization_losses': list_wrapper(all_wrapped_losses),
        'layer_regularization_losses': list_wrapper(own_wrapped_losses),
        'layer_metrics': layer_metrics,
    }
Example #4
0
 def testSameStructure(self):
     """Checks `nest` treats a dict and a `_DictWrapper` of its copy as the same structure."""
     plain = {1: "a"}
     wrapped = data_structures._DictWrapper(plain.copy())
     nest.assert_same_structure(plain, wrapped)
Example #5
0
 def testConstructableFromSequence(self):
     """Checks a `_DictWrapper` built from key/value pairs is a dict with those items."""
     pairs = [(1, 2), (3, 4)]
     wrapper = data_structures._DictWrapper(pairs)
     self.assertIsInstance(wrapper, dict)
     self.assertEqual({1: 2, 3: 4}, wrapper)
 def testPickle(self):
     """Round-trips a `_DictWrapper` through pickle and compares it to a plain dict."""
     wrapper = data_structures._DictWrapper({"a": 1, "b": 2})
     blob = pickle.dumps(wrapper)
     # Drop the local reference to the source object before unpickling.
     del wrapper
     restored = pickle.loads(blob)
     self.assertEqual({"a": 1, "b": 2}, restored)
 def testSameStructure(self):
   """Asserts a dict and a `_DictWrapper` around its copy have the same nest structure."""
   reference = {1: "a"}
   wrapped_copy = data_structures._DictWrapper(reference.copy())
   nest.assert_same_structure(reference, wrapped_copy)
 def testConstructableFromSequence(self):
   """Asserts a `_DictWrapper` can be built from a sequence of key/value pairs."""
   items = [(1, 2), (3, 4)]
   built = data_structures._DictWrapper(items)
   self.assertIsInstance(built, dict)
   self.assertEqual({1: 2, 3: 4}, built)