def testFunctionCaching(self):
  """A plain dict and a `_DictWrapper` with the same structure share a trace."""

  @def_function.function
  def add_one(dict_input):
    return dict_input["x"] + constant_op.constant(1.)

  # Trace once with a builtin dict argument.
  plain_dict_trace = add_one.get_concrete_function(
      {"x": constant_op.constant(2.)})
  # Trace again with the tracked wrapper; this should hit the same cache entry.
  wrapped_input = data_structures._DictWrapper(  # pylint: disable=protected-access
      {"x": constant_op.constant(3.)})
  wrapped_dict_trace = add_one.get_concrete_function(wrapped_input)
  self.assertIs(plain_dict_trace, wrapped_dict_trace)
def testFunctionCaching(self):
  """Checks that tracing does not distinguish dict from `_DictWrapper`."""

  @def_function.function
  def increment(dict_input):
    one = constant_op.constant(1.)
    return dict_input["x"] + one

  dict_arg = {"x": constant_op.constant(2.)}
  trace_for_dict = increment.get_concrete_function(dict_arg)
  wrapper_arg = data_structures._DictWrapper(  # pylint: disable=protected-access
      {"x": constant_op.constant(3.)})
  trace_for_wrapper = increment.get_concrete_function(wrapper_arg)
  # Same structure and dtypes, so the concrete function must be reused.
  self.assertIs(trace_for_dict, trace_for_wrapper)
def wrap_layer_objects(layer, serialization_cache):
  """Returns extra trackable objects to attach to the serialized layer.

  Args:
    layer: Keras Layer object.
    serialization_cache: Dictionary shared between all objects during
      serialization. Wrapped loss functions are stored under its
      `'keras_losses'` key so that each loss function is wrapped only once
      across all serialized layers.

  Returns:
    A dictionary containing all trackable objects from a
    SerializedAttributes object. See LayerAttributes and ModelAttributes for
    the entire list of objects. Keys are `variables`, `trainable_variables`,
    `non_trainable_variables`, `layers`, `metrics`, `regularization_losses`
    (wrapped losses of this layer and its child layers),
    `layer_regularization_losses` (wrapped losses of this layer only) and
    `layer_metrics` (this layer's metrics keyed by metric name).
  """
  # Wrap all regularization losses as tf.functions.
  # First, generate list of all regularization losses in this layer and
  # sublayers.
  all_losses = layer._callable_losses[:]  # pylint: disable=protected-access
  # Computed once: used here to collect child losses and again below as the
  # `layers` attribute (the original code traversed the layers twice).
  all_layers = utils.list_all_layers(layer)
  for child_layer in all_layers:
    all_losses.extend(child_layer._callable_losses)  # pylint: disable=protected-access

  # Next, wrap all loss functions as tf.functions. Use the serialization cache
  # to store already-wrapped functions.
  keras_loss_cache = serialization_cache.setdefault('keras_losses', {})
  wrapped_loss_functions = []
  for loss_fn in all_losses:
    if loss_fn not in keras_loss_cache:
      # The right-hand side is evaluated before the new key is inserted, so
      # the second argument is the cache size *before* this entry (presumably
      # a unique index for the wrapper -- confirm in _wrap_unconditional_loss).
      keras_loss_cache[loss_fn] = _wrap_unconditional_loss(
          loss_fn, len(keras_loss_cache))
    wrapped_loss_functions.append(keras_loss_cache[loss_fn])

  # Every loss of this layer was wrapped in the loop above, so the lookup
  # cannot miss. No copy of the list is needed just to iterate it.
  wrapped_layer_losses = [
      keras_loss_cache[fn]
      for fn in layer._callable_losses  # pylint: disable=protected-access
  ]

  layer_metrics = data_structures._DictWrapper(  # pylint: disable=protected-access
      {m.name: m for m in layer._metrics})  # pylint: disable=protected-access
  return dict(
      variables=data_structures.ListWrapper(layer.variables),
      trainable_variables=data_structures.ListWrapper(
          layer.trainable_variables),
      non_trainable_variables=data_structures.ListWrapper(
          layer.non_trainable_variables),
      layers=data_structures.ListWrapper(all_layers),
      metrics=data_structures.ListWrapper(layer.metrics),
      regularization_losses=data_structures.ListWrapper(
          wrapped_loss_functions),
      layer_regularization_losses=data_structures.ListWrapper(
          wrapped_layer_losses),
      layer_metrics=layer_metrics)
def testSameStructure(self):
  """A `_DictWrapper` around a copy has the same nest structure as the dict."""
  original = {1: "a"}
  wrapped = data_structures._DictWrapper(dict(original))  # pylint: disable=protected-access
  nest.assert_same_structure(original, wrapped)
def testConstructableFromSequence(self):
  """`_DictWrapper` accepts key/value pairs just like the `dict` constructor."""
  pairs = [(1, 2), (3, 4)]
  wrapped = data_structures._DictWrapper(pairs)  # pylint: disable=protected-access
  self.assertIsInstance(wrapped, dict)
  self.assertEqual({1: 2, 3: 4}, wrapped)
def testPickle(self):
  """A `_DictWrapper` survives a pickle round trip as an equal mapping."""
  wrapper = data_structures._DictWrapper(dict(a=1, b=2))  # pylint: disable=protected-access
  payload = pickle.dumps(wrapper)
  # Drop the source object so unpickling cannot lean on it still existing.
  del wrapper
  restored = pickle.loads(payload)
  self.assertEqual(dict(a=1, b=2), restored)
def testSameStructure(self):
  """Wrapping a dict copy must not alter its structure as seen by `nest`."""
  plain = {1: "a"}
  nest.assert_same_structure(
      plain,
      data_structures._DictWrapper(dict(plain)))  # pylint: disable=protected-access
def testConstructableFromSequence(self):
  """Building a `_DictWrapper` from a list of pairs yields a real dict."""
  constructed = data_structures._DictWrapper(  # pylint: disable=protected-access
      [(1, 2), (3, 4)])
  expected = {1: 2, 3: 4}
  self.assertIsInstance(constructed, dict)
  self.assertEqual(expected, constructed)