Example #1
def wrap_layer_objects(layer, serialization_cache):
    """Returns extra trackable objects to attach to the serialized layer.

    Args:
      layer: Keras Layer object.
      serialization_cache: Dictionary shared between all objects during
        serialization.

    Returns:
      A dictionary containing all checkpointable objects from a
      SerializedAttributes object. See LayerAttributes and ModelAttributes for
      the entire list of objects.
    """
    # Wrap all regularization losses as tf.functions.
    # First, generate list of all regularization losses in this layer and
    # sublayers.
    all_losses = layer._callable_losses[:]  # pylint: disable=protected-access
    for child_layer in utils.list_all_layers(layer):
        all_losses.extend(child_layer._callable_losses)  # pylint: disable=protected-access
    # Next, wrap all loss functions as tf.functions. Use the serialization cache
    # to store already-wrapped functions.
    keras_loss_cache = serialization_cache.setdefault("keras_losses", {})
    wrapped_loss_functions = []
    for loss_fn in all_losses:
        if loss_fn in keras_loss_cache:
            wrapped_loss_functions.append(keras_loss_cache[loss_fn])
        else:
            wrapped_loss = _wrap_unconditional_loss(loss_fn,
                                                    len(keras_loss_cache))
            keras_loss_cache[loss_fn] = wrapped_loss
            wrapped_loss_functions.append(wrapped_loss)
    wrapped_layer_losses = [
        keras_loss_cache[fn] for fn in layer._callable_losses[:]  # pylint: disable=protected-access
    ]

    layer_metrics = tf.__internal__.tracking.wrap(
        {m.name: m
         for m in layer._metrics})  # pylint: disable=protected-access

    # Avoid duplicate creation of shard Variables on loading.
    # `layer.variables` will return the shard Variables rather than the
    # ShardedVariables (b/224541446), but Keras loading will create new
    # ShardedVariables (and thus shard Variables) from Keras metadata if needed.
    # There's no need to also save the shard Variables here, so filter them out.
    variables = _filter_shards(layer.variables)
    trainable_variables = _filter_shards(layer.trainable_variables)
    non_trainable_variables = _filter_shards(layer.non_trainable_variables)
    return dict(
        variables=tf.__internal__.tracking.wrap(variables),
        trainable_variables=tf.__internal__.tracking.wrap(trainable_variables),
        non_trainable_variables=tf.__internal__.tracking.wrap(
            non_trainable_variables),
        layers=tf.__internal__.tracking.wrap(utils.list_all_layers(layer)),
        metrics=tf.__internal__.tracking.wrap(layer.metrics),
        regularization_losses=tf.__internal__.tracking.wrap(
            wrapped_loss_functions),
        layer_regularization_losses=tf.__internal__.tracking.wrap(
            wrapped_layer_losses),
        layer_metrics=layer_metrics,
    )
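Note: `_filter_shards` is referenced above but not included in this snippet. A minimal sketch of what it might look like, assuming component variables of a ShardedVariable are tagged with a `_sharded_container` back-reference (an internal detail that may vary across TF versions):

def _filter_shards(variables):
    """Drops variables that are shards of a ShardedVariable (sketch)."""
    # Assumption: shard variables carry a `_sharded_container` attribute
    # pointing back at their ShardedVariable; ordinary variables do not.
    return [var for var in variables if not hasattr(var, "_sharded_container")]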
Example #2
def wrap_layer_objects(layer, serialization_cache):
    """Returns extra trackable objects to attach to the serialized layer.

  Args:
    layer: Keras Layer object.
    serialization_cache: Dictionary shared between all objects during
      serialization.

  Returns:
    A dictionary containing all checkpointable objects from a
    SerializedAttributes object. See LayerAttributes and ModelAttributes for
    entire list of objects
  """
    # Wrap all regularization losses as tf.functions.
    # First, generate list of all regularization losses in this layer and
    # sublayers.
    all_losses = layer._callable_losses[:]  # pylint: disable=protected-access
    for child_layer in utils.list_all_layers(layer):
        all_losses.extend(child_layer._callable_losses)  # pylint: disable=protected-access
    # Next, wrap all loss functions as tf.functions. Use the serialization cache
    # to store already-wrapped functions.
    keras_loss_cache = serialization_cache.setdefault('keras_losses', {})
    wrapped_loss_functions = []
    for loss_fn in all_losses:
        if loss_fn in keras_loss_cache:
            wrapped_loss_functions.append(keras_loss_cache[loss_fn])
        else:
            wrapped_loss = _wrap_unconditional_loss(loss_fn,
                                                    len(keras_loss_cache))
            keras_loss_cache[loss_fn] = wrapped_loss
            wrapped_loss_functions.append(wrapped_loss)
    wrapped_layer_losses = [
        keras_loss_cache[fn] for fn in layer._callable_losses[:]  # pylint: disable=protected-access
    ]

    layer_metrics = data_structures._DictWrapper(  # pylint: disable=protected-access
        {m.name: m
         for m in layer._metrics})  # pylint: disable=protected-access
    return dict(variables=data_structures.ListWrapper(layer.variables),
                trainable_variables=data_structures.ListWrapper(
                    layer.trainable_variables),
                non_trainable_variables=data_structures.ListWrapper(
                    layer.non_trainable_variables),
                layers=data_structures.ListWrapper(
                    utils.list_all_layers(layer)),
                metrics=data_structures.ListWrapper(layer.metrics),
                regularization_losses=data_structures.ListWrapper(
                    wrapped_loss_functions),
                layer_regularization_losses=data_structures.ListWrapper(
                    wrapped_layer_losses),
                layer_metrics=layer_metrics)
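Both examples depend on `_wrap_unconditional_loss`, which is not shown here. A plausible sketch, assuming the helper unwraps `functools.partial` objects and traces the zero-argument loss callable as a `tf.function` (the real helper presumably also uses `index` to give each wrapped loss a unique name, which the public `tf.function` API does not expose, so it goes unused below):

import functools

import tensorflow as tf


def _wrap_unconditional_loss(loss_fn, index):
    """Wraps a zero-argument loss callable as a tf.function (sketch)."""
    del index  # See note above: public tf.function has no name argument.
    # Unwrap functools.partial objects to reach the underlying callable.
    fn = loss_fn.func if isinstance(loss_fn, functools.partial) else loss_fn
    # Reuse callables that are already tf.functions (duck-typed check).
    if hasattr(fn, "get_concrete_function"):
        return fn
    return tf.function(fn, input_signature=[])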
Example #3
def _replace_child_layer_functions(layer, serialization_cache):
  """Replaces functions in the children layers with wrapped tf.functions.

  This step allows functions from parent layers to reference the wrapped
  functions from their children layers instead of retracing the ops.

  The original functions and regularizers are stored in the returned
  dictionary. Use `_restore_child_layer_functions` to restore the original
  attributes.

  Args:
    layer: Keras Layer object.
    serialization_cache: Dictionary shared between all objects during
      serialization.

  Returns:
    Dictionary mapping layer objects -> original functions:
      { Child layer 1: {
          'call': Original call function,
          '_activity_regularizer': Original activity regularizer},
        Child layer 2: ...
      }
  """
  # pylint: disable=protected-access
  original_fns = {}

  def replace_layer_functions(child_layer, serialized_fns):
    """Replaces layer call and activity regularizer with wrapped functions."""
    original_fns[child_layer] = {
        'call': child_layer.call,
        '_activity_regularizer': child_layer._activity_regularizer
    }
    with trackable.no_automatic_dependency_tracking_scope(child_layer):
      try:
        child_layer._activity_regularizer = serialized_fns.get(
            'activity_regularizer_fn')
      except AttributeError:
        # Some layers have an unsettable activity regularizer.
        pass
      child_layer.call = utils.use_wrapped_call(
          child_layer,
          serialized_fns['call_and_return_conditional_losses'],
          default_training_value=False)

  def replace_metric_functions(child_layer, serialized_fns):
    """Replaces metric functions with wrapped functions."""
    original_fns[child_layer] = {
        '__call__': child_layer.__call__,
        'result': child_layer.result,
        'update_state': child_layer.update_state
    }
    with trackable.no_automatic_dependency_tracking_scope(child_layer):
      child_layer.__call__ = serialized_fns['__call__']
      child_layer.result = serialized_fns['result']
      child_layer.update_state = serialized_fns['update_state']

  for child_layer in utils.list_all_layers(layer):
    if isinstance(child_layer, input_layer.InputLayer):
      continue

    if child_layer not in serialization_cache[constants.KERAS_CACHE_KEY]:
      serialized_functions = (
          child_layer._trackable_saved_model_saver._get_serialized_attributes(
              serialization_cache).functions)
    else:
      serialized_functions = (
          serialization_cache[constants.KERAS_CACHE_KEY][child_layer].functions)
    if not serialized_functions:
      # This indicates either:
      #   - A circular dependency, which means the current layer's functions
      #     should be wrapped first.
      #   - The child layer's inputs are not defined, so its functions have not
      #     been wrapped. In this case, no replacement is necessary, so move on
      #     to the next child.
      continue

    if isinstance(child_layer, metrics.Metric):
      replace_metric_functions(child_layer, serialized_functions)
    else:
      replace_layer_functions(child_layer, serialized_functions)

  return original_fns
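The docstring above points to `_restore_child_layer_functions` for undoing these replacements; it is not part of this snippet. A minimal sketch that mirrors the replacement path, re-assigning the saved attributes inside the same no-dependency-tracking scope and tolerating unsettable activity regularizers:

def _restore_child_layer_functions(original_fns):
  """Restores attributes replaced by `_replace_child_layer_functions` (sketch)."""
  for child_layer, fns in original_fns.items():
    with trackable.no_automatic_dependency_tracking_scope(child_layer):
      for fn_name, fn in fns.items():
        try:
          setattr(child_layer, fn_name, fn)
        except AttributeError:
          # Some layers have an unsettable activity regularizer.
          pass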