def __init__(self):
  super(HasMapping, self).__init__()
  self.layer_dict = data_structures.wrap_or_unwrap(dict(output=core.Dense(7)))
  self.layer_dict["norm"] = data_structures.wrap_or_unwrap([])
  self.layer_dict["dense"] = data_structures.wrap_or_unwrap([])
  self.layer_dict["dense"].extend(
      [core.Dense(5), core.Dense(6, kernel_regularizer=tf.reduce_sum)])
  self.layer_dict["norm"].append(
      normalization.BatchNormalization())
  self.layer_dict["norm"].append(
      normalization.BatchNormalization())
def _get_serialized_attributes_internal(self, serialization_cache):
  objects, functions = (super(
      RNNSavedModelSaver, self)._get_serialized_attributes_internal(
          serialization_cache))
  states = data_structures.wrap_or_unwrap(self.obj.states)
  # SavedModel requires all the objects to be Trackable when saving.
  # If the states is still a tuple after wrap_or_unwrap, it means it doesn't
  # contain any trackable item within it, e.g. an empty tuple or (None, None)
  # for a stateless ConvLSTM2D. We convert them to a list so that
  # wrap_or_unwrap can make it a Trackable again for saving. When loaded,
  # ConvLSTM2D is able to handle the tuple/list conversion.
  if isinstance(states, tuple):
    states = data_structures.wrap_or_unwrap(list(states))
  objects['states'] = states
  return objects, functions
def _get_serialized_attributes_internal(self, serialization_cache):
  objects, functions = (super(
      RNNSavedModelSaver, self)._get_serialized_attributes_internal(
          serialization_cache))
  objects['states'] = data_structures.wrap_or_unwrap(self.obj.states)
  return objects, functions
def testDictWrapperBadKeys(self):
  a = tf.Module()
  a.d = {}
  a.d[1] = data_structures.wrap_or_unwrap([])
  model = training.Model()
  model.sub = a
  save_path = os.path.join(self.get_temp_dir(), "ckpt")
  with self.assertRaisesRegex(ValueError, "non-string key"):
    model.save_weights(save_path)
def __init__(self):
  super(HasList, self).__init__()
  self.layer_list = data_structures.wrap_or_unwrap([core.Dense(3)])
  self.layer_list.append(core.Dense(4))
  self.layer_list.extend(
      [core.Dense(5), core.Dense(6, kernel_regularizer=tf.reduce_sum)])
  self.layer_list += [
      core.Dense(7, bias_regularizer=tf.reduce_sum),
      core.Dense(8)
  ]
  self.layer_list += (
      data_structures.wrap_or_unwrap([core.Dense(9)]) +
      data_structures.wrap_or_unwrap([core.Dense(10)]))
  self.layer_list.extend(
      data_structures.wrap_or_unwrap(
          list([core.Dense(11)]) + [core.Dense(12)]))
  self.layers_with_updates = data_structures.wrap_or_unwrap(
      [normalization.BatchNormalization()])
def wrap_layer_objects(layer, serialization_cache):
  """Returns extra trackable objects to attach to the serialized layer.

  Args:
    layer: Keras Layer object.
    serialization_cache: Dictionary shared between all objects during
      serialization.

  Returns:
    A dictionary containing all checkpointable objects from a
    SerializedAttributes object. See LayerAttributes and ModelAttributes for
    the entire list of objects.
  """
  # Wrap all regularization losses as tf.functions.
  # First, generate the list of all regularization losses in this layer and
  # its sublayers.
  all_losses = layer._callable_losses[:]  # pylint: disable=protected-access
  for child_layer in utils.list_all_layers(layer):
    all_losses.extend(child_layer._callable_losses)  # pylint: disable=protected-access
  # Next, wrap all loss functions as tf.functions. Use the serialization cache
  # to store already-wrapped functions.
  keras_loss_cache = serialization_cache.setdefault('keras_losses', {})
  wrapped_loss_functions = []
  for loss_fn in all_losses:
    if loss_fn in keras_loss_cache:
      wrapped_loss_functions.append(keras_loss_cache[loss_fn])
    else:
      wrapped_loss = _wrap_unconditional_loss(loss_fn, len(keras_loss_cache))
      keras_loss_cache[loss_fn] = wrapped_loss
      wrapped_loss_functions.append(wrapped_loss)
  wrapped_layer_losses = [
      keras_loss_cache[fn]
      for fn in layer._callable_losses[:]  # pylint: disable=protected-access
  ]

  layer_metrics = data_structures.wrap_or_unwrap(
      {m.name: m for m in layer._metrics})  # pylint: disable=protected-access

  return dict(
      variables=data_structures.wrap_or_unwrap(layer.variables),
      trainable_variables=data_structures.wrap_or_unwrap(
          layer.trainable_variables),
      non_trainable_variables=data_structures.wrap_or_unwrap(
          layer.non_trainable_variables),
      layers=data_structures.wrap_or_unwrap(utils.list_all_layers(layer)),
      metrics=data_structures.wrap_or_unwrap(layer.metrics),
      regularization_losses=data_structures.wrap_or_unwrap(
          wrapped_loss_functions),
      layer_regularization_losses=data_structures.wrap_or_unwrap(
          wrapped_layer_losses),
      layer_metrics=layer_metrics)
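A hypothetical usage sketch of the function above, not taken from the library: it assumes the same module aliases as the surrounding snippets (core, tf) and that it runs inside the module where wrap_layer_objects and _wrap_unconditional_loss are defined.

# Hypothetical usage sketch; the layer and cache below are made up for
# illustration only.
layer = core.Dense(2, kernel_regularizer=tf.reduce_sum)
layer.build((None, 3))  # building creates the kernel and registers its regularization loss
serialization_cache = {}
tracked = wrap_layer_objects(layer, serialization_cache)
# Every value in the returned dict is a trackable wrapper or tf.function that
# can be attached to the SavedModel object graph; the wrapped loss is also
# memoized under 'keras_losses' so sibling layers reuse it instead of
# re-wrapping the same callable.
assert 'regularization_losses' in tracked
assert 'keras_losses' in serialization_cache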
def convert_to_trackable(obj, parent=None):
  """Converts `obj` to `Trackable`."""
  if isinstance(obj, base.Trackable):
    return obj
  obj = data_structures.wrap_or_unwrap(obj)
  if (tensor_util.is_tf_type(obj) and
      obj.dtype not in (dtypes.variant, dtypes.resource) and
      not resource_variable_ops.is_resource_variable(obj)):
    return function_saved_model_utils.TrackableConstant(obj, parent)
  if not isinstance(obj, base.Trackable):
    raise ValueError(f"Cannot convert {obj} to Trackable.")
  return obj
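A minimal sketch of what the helper does with a non-Trackable container, assuming the same module aliases (tf, base) used in the function above are in scope.

# Minimal sketch: a plain Python list is not Trackable itself, so
# convert_to_trackable passes it through wrap_or_unwrap, which turns it into a
# trackable list wrapper that can participate in object-graph saving.
wrapped = convert_to_trackable([tf.Variable(1.0)])
assert isinstance(wrapped, base.Trackable)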
def _get_serialized_attributes_internal(self, serialization_cache):
  objects, functions = (super(
      RNNSavedModelSaver, self)._get_serialized_attributes_internal(
          serialization_cache))
  states = data_structures.wrap_or_unwrap(self.obj.states)
  # Force the tuple into TupleWrapper, which is a trackable object. The
  # save/load code requires all the objects to be trackable.
  # A tuple is not converted to TupleWrapper by data_structures.wrap_or_unwrap()
  # if it doesn't contain any trackable objects.
  if isinstance(states, tuple):
    states = data_structures._TupleWrapper(states)  # pylint: disable=protected-access
  objects['states'] = states
  return objects, functions
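A short sketch of the corner case the comment above describes, reusing the data_structures alias from the surrounding code; the (None, None) value stands in for the states of a stateless ConvLSTM2D.

# A list is always wrapped into a trackable list wrapper...
wrapped_list = data_structures.wrap_or_unwrap([None, None])
# ...but a tuple with no trackable contents comes back unchanged, so it has to
# be forced into _TupleWrapper before it can be attached for saving.
maybe_states = data_structures.wrap_or_unwrap((None, None))
if isinstance(maybe_states, tuple):
  maybe_states = data_structures._TupleWrapper(maybe_states)  # pylint: disable=protected-access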
def _get_serialized_attributes_internal(self, unused_serialization_cache):
  return (
      dict(variables=data_structures.wrap_or_unwrap(self.obj.variables)),
      dict())  # TODO(b/135550038): save functions to enable saving
def testLayerCollectionWithExternalMutation(self):
  l = []
  l_wrapper = data_structures.wrap_or_unwrap(l)
  layer = core.Dense(1)
  l.append(layer)
  self.assertEqual([layer], l_wrapper.layers)