Exemple #1
0
 def __init__(self):
   super(HasMapping, self).__init__()
   wrap = data_structures.wrap_or_unwrap
   # Wrapped mapping seeded with a single "output" Dense layer.
   self.layer_dict = wrap(dict(output=core.Dense(7)))
   # Attach two initially-empty wrapped lists, "norm" first, then "dense".
   for key in ("norm", "dense"):
     self.layer_dict[key] = wrap([])
   # Fill the nested lists in place through the mapping lookups.
   self.layer_dict["dense"].extend(
       [core.Dense(5),
        core.Dense(6, kernel_regularizer=tf.reduce_sum)])
   for _ in range(2):
     self.layer_dict["norm"].append(normalization.BatchNormalization())
 def _get_serialized_attributes_internal(self, serialization_cache):
     parent_saver = super(RNNSavedModelSaver, self)
     objects, functions = parent_saver._get_serialized_attributes_internal(
         serialization_cache)
     # SavedModel requires every saved object to be Trackable. A value that is
     # still a tuple after wrap_or_unwrap holds no trackable items (e.g. an
     # empty tuple, or (None, None) for a stateless ConvLSTM2D), so it is
     # converted to a list, which wrap_or_unwrap can make Trackable. On load,
     # ConvLSTM2D handles the tuple/list difference.
     wrapped_states = data_structures.wrap_or_unwrap(self.obj.states)
     if isinstance(wrapped_states, tuple):
         wrapped_states = data_structures.wrap_or_unwrap(list(wrapped_states))
     objects['states'] = wrapped_states
     return objects, functions
    def _get_serialized_attributes_internal(self, serialization_cache):
        """Returns the parent saver's objects/functions plus the RNN states.

        Args:
          serialization_cache: Cache forwarded unchanged to the parent saver.

        Returns:
          `(objects, functions)` from the parent saver, with
          `objects['states']` set to a Trackable form of `self.obj.states`.
        """
        objects, functions = (super(
            RNNSavedModelSaver,
            self)._get_serialized_attributes_internal(serialization_cache))

        states = data_structures.wrap_or_unwrap(self.obj.states)
        # SavedModel requires every saved object to be Trackable, but
        # wrap_or_unwrap leaves a tuple with no trackable items (e.g. an empty
        # tuple, or (None, None) for a stateless RNN) as a plain tuple.
        # Converting it to a list lets wrap_or_unwrap return a Trackable.
        if isinstance(states, tuple):
            states = data_structures.wrap_or_unwrap(list(states))
        objects['states'] = states
        return objects, functions
Exemple #4
0
 def testDictWrapperBadKeys(self):
   # A dict attribute keyed by a non-string (int) value, reachable from a
   # model, should make save_weights raise a "non-string key" ValueError.
   holder = tf.Module()
   holder.d = {}
   holder.d[1] = data_structures.wrap_or_unwrap([])
   parent_model = training.Model()
   parent_model.sub = holder
   checkpoint_prefix = os.path.join(self.get_temp_dir(), "ckpt")
   with self.assertRaisesRegex(ValueError, "non-string key"):
     parent_model.save_weights(checkpoint_prefix)
Exemple #5
0
 def __init__(self):
   # Fixture that builds one wrapped list of Dense layers by going through
   # several list-mutation paths in turn. After construction, `layer_list`
   # holds Dense(3) through Dense(12) in construction order. The exact
   # operators used (append / extend / += / +) are the point of the fixture.
   super(HasList, self).__init__()
   # Start from a wrapped one-element list.
   self.layer_list = data_structures.wrap_or_unwrap([core.Dense(3)])
   self.layer_list.append(core.Dense(4))
   # extend() with a plain list; Dense(6) carries a kernel regularizer.
   self.layer_list.extend(
       [core.Dense(5),
        core.Dense(6, kernel_regularizer=tf.reduce_sum)])
   # += with a plain list; Dense(7) carries a bias regularizer.
   self.layer_list += [
       core.Dense(7, bias_regularizer=tf.reduce_sum),
       core.Dense(8)
   ]
   # += with the concatenation of two wrapped lists.
   self.layer_list += (
       data_structures.wrap_or_unwrap([core.Dense(9)]) +
       data_structures.wrap_or_unwrap([core.Dense(10)]))
   # extend() with a wrapped list built from two concatenated plain lists.
   self.layer_list.extend(
       data_structures.wrap_or_unwrap(
           list([core.Dense(11)]) + [core.Dense(12)]))
   # Separate wrapped list holding a BatchNormalization layer (presumably
   # named for its update ops -- NOTE(review): confirm against the test).
   self.layers_with_updates = data_structures.wrap_or_unwrap(
       [normalization.BatchNormalization()])
Exemple #6
0
def wrap_layer_objects(layer, serialization_cache):
    """Returns extra trackable objects to attach to the serialized layer.

    Args:
      layer: Keras Layer object.
      serialization_cache: Dictionary shared between all objects during
        serialization.

    Returns:
      A dictionary containing all trackable objects from a
      SerializedAttributes object. See LayerAttributes and ModelAttributes for
      the entire list of objects.
    """
    # pylint: disable=protected-access
    # Wrap all regularization losses as tf.functions.
    # First, collect the regularization losses of this layer and all of its
    # sublayers. Copy the layer's own list so extending it does not mutate
    # the layer's attribute.
    all_losses = layer._callable_losses[:]
    for child_layer in utils.list_all_layers(layer):
        all_losses.extend(child_layer._callable_losses)

    # Next, wrap every loss function. The cache lives in serialization_cache,
    # so a loss function already wrapped earlier in this serialization is
    # reused instead of being wrapped again.
    keras_loss_cache = serialization_cache.setdefault('keras_losses', {})
    wrapped_loss_functions = []
    for loss_fn in all_losses:
        if loss_fn in keras_loss_cache:
            wrapped_loss = keras_loss_cache[loss_fn]
        else:
            # The current cache size is passed as an index, presumably so each
            # newly wrapped function gets a unique name -- NOTE(review):
            # confirm against _wrap_unconditional_loss.
            wrapped_loss = _wrap_unconditional_loss(loss_fn,
                                                    len(keras_loss_cache))
            keras_loss_cache[loss_fn] = wrapped_loss
        wrapped_loss_functions.append(wrapped_loss)

    # `all_losses` starts with the layer's own losses, so each of them is in
    # the cache by now.
    wrapped_layer_losses = [
        keras_loss_cache[fn] for fn in layer._callable_losses
    ]

    layer_metrics = data_structures.wrap_or_unwrap(
        {m.name: m for m in layer._metrics})
    # pylint: enable=protected-access
    return dict(variables=data_structures.wrap_or_unwrap(layer.variables),
                trainable_variables=data_structures.wrap_or_unwrap(
                    layer.trainable_variables),
                non_trainable_variables=data_structures.wrap_or_unwrap(
                    layer.non_trainable_variables),
                layers=data_structures.wrap_or_unwrap(
                    utils.list_all_layers(layer)),
                metrics=data_structures.wrap_or_unwrap(layer.metrics),
                regularization_losses=data_structures.wrap_or_unwrap(
                    wrapped_loss_functions),
                layer_regularization_losses=data_structures.wrap_or_unwrap(
                    wrapped_layer_losses),
                layer_metrics=layer_metrics)
Exemple #7
0
def convert_to_trackable(obj, parent=None):
    """Converts `obj` to `Trackable`.

    Trackable inputs are returned as-is. Everything else is first passed
    through `data_structures.wrap_or_unwrap`. A TF tensor whose dtype is
    neither variant nor resource, and which is not a resource variable, is
    wrapped in a `TrackableConstant` owned by `parent`.

    Raises:
      ValueError: if the wrapped value is still not Trackable.
    """
    if isinstance(obj, base.Trackable):
        return obj
    wrapped = data_structures.wrap_or_unwrap(obj)
    is_constant_candidate = (
        tensor_util.is_tf_type(wrapped)
        and wrapped.dtype not in (dtypes.variant, dtypes.resource)
        and not resource_variable_ops.is_resource_variable(wrapped))
    if is_constant_candidate:
        return function_saved_model_utils.TrackableConstant(wrapped, parent)
    if isinstance(wrapped, base.Trackable):
        return wrapped
    raise ValueError(f"Cannot convert {wrapped} to Trackable.")
Exemple #8
0
 def _get_serialized_attributes_internal(self, serialization_cache):
     parent_saver = super(RNNSavedModelSaver, self)
     objects, functions = parent_saver._get_serialized_attributes_internal(
         serialization_cache)
     wrapped_states = data_structures.wrap_or_unwrap(self.obj.states)
     # The save/load code requires every object to be trackable, but
     # wrap_or_unwrap() returns a tuple unchanged when it contains no
     # trackable objects. Such a tuple is forced into a TupleWrapper, which
     # is a trackable object.
     if isinstance(wrapped_states, tuple):
         wrapped_states = data_structures._TupleWrapper(wrapped_states)  # pylint: disable=protected-access
     objects['states'] = wrapped_states
     return objects, functions
Exemple #9
0
 def _get_serialized_attributes_internal(self, unused_serialization_cache):
     # Only the wrapped variables are exported; no functions yet.
     # TODO(b/135550038): save functions to enable saving
     wrapped_variables = data_structures.wrap_or_unwrap(self.obj.variables)
     objects = {'variables': wrapped_variables}
     functions = {}
     return objects, functions
Exemple #10
0
 def testLayerCollectionWithExternalMutation(self):
   # A layer appended to the underlying list after wrapping must show up in
   # the wrapper's `layers` property.
   backing_list = []
   wrapper = data_structures.wrap_or_unwrap(backing_list)
   dense_layer = core.Dense(1)
   backing_list.append(dense_layer)
   self.assertEqual([dense_layer], wrapper.layers)