Example #1
0
def wrap_layer_objects(layer, serialization_cache):
    """Returns extra trackable objects to attach to the serialized layer.

    Args:
      layer: Keras Layer object.
      serialization_cache: Dictionary shared between all objects during
        serialization.

    Returns:
      A dictionary containing all checkpointable objects from a
      SerializedAttributes object. See LayerAttributes and ModelAttributes for
      the entire list of objects.
    """
    # pylint: disable=protected-access
    # Gather every regularization loss callable from this layer and all of
    # its sublayers.
    all_losses = list(layer._callable_losses)
    for sublayer in utils.list_all_layers(layer):
        all_losses.extend(sublayer._callable_losses)

    # Wrap each loss callable as a tf.function. The shared serialization
    # cache keeps one wrapper per callable so repeated layers reuse it.
    keras_loss_cache = serialization_cache.setdefault('keras_losses', {})
    wrapped_loss_functions = []
    for loss_fn in all_losses:
        if loss_fn not in keras_loss_cache:
            keras_loss_cache[loss_fn] = _wrap_unconditional_loss(
                loss_fn, len(keras_loss_cache))
        wrapped_loss_functions.append(keras_loss_cache[loss_fn])
    # Wrappers for only this layer's own losses (excluding sublayers).
    wrapped_layer_losses = [
        keras_loss_cache[fn] for fn in layer._callable_losses
    ]

    # Metrics keyed by name, wrapped so they are tracked as a unit.
    layer_metrics = data_structures._DictWrapper(
        {m.name: m for m in layer._metrics})
    return dict(
        variables=data_structures.ListWrapper(layer.variables),
        trainable_variables=data_structures.ListWrapper(
            layer.trainable_variables),
        non_trainable_variables=data_structures.ListWrapper(
            layer.non_trainable_variables),
        layers=data_structures.ListWrapper(utils.list_all_layers(layer)),
        metrics=data_structures.ListWrapper(layer.metrics),
        regularization_losses=data_structures.ListWrapper(
            wrapped_loss_functions),
        layer_regularization_losses=data_structures.ListWrapper(
            wrapped_layer_losses),
        layer_metrics=layer_metrics)
Example #2
0
def _replace_child_layer_functions(layer, serialization_cache):
  """Replaces functions in the children layers with wrapped tf.functions.

  This step allows functions from parent layers to reference the wrapped
  functions from their children layers instead of retracing the ops.

  Each child layer's `call` and `_activity_regularizer` attributes are
  replaced in place; the original attributes are stored in the returned
  dictionary. Use `_restore_child_layer_functions` to restore the original
  attributes.

  Args:
    layer: Keras Layer object.
    serialization_cache: Dictionary shared between all objects during
      serialization.

  Returns:
    Dictionary mapping layer objects -> original functions:
      { Child layer 1: {
          'call': Original call function
          'activity_regularizer': Original activity regularizer},
        Child layer 2: ...
      }
  """
  # pylint: disable=protected-access
  original_fns = {}
  for child_layer in utils.list_all_layers(layer):
    # InputLayers have no call function to replace.
    if isinstance(child_layer, input_layer.InputLayer):
      continue

    # Fetch the child's serialized attributes, generating them now if they
    # are not already in the shared cache.
    if child_layer not in serialization_cache[constants.KERAS_CACHE_KEY]:
      layer_fns = (
          child_layer._trackable_saved_model_saver._get_serialized_attributes(
              serialization_cache).functions)
    else:
      layer_fns = (
          serialization_cache[constants.KERAS_CACHE_KEY][child_layer].functions)
    if not layer_fns:
      # This indicates either:
      #   - circular dependency, which means the current layer's functions
      #     should be wrapped first.
      #   - Child layer's inputs are not defined, so its functions have not been
      #     wrapped. In this case, no replacement is necessary so move on to the
      #     next child.
      continue
    # Snapshot the originals so the caller can undo the patching later.
    original_fns[child_layer] = {
        'call': child_layer.call,
        'activity_regularizer': child_layer._activity_regularizer
    }
    # NOTE(review): the scope presumably keeps these temporary assignments
    # from being registered as new checkpoint dependencies — confirm against
    # `trackable.no_automatic_dependency_tracking_scope` docs.
    with trackable.no_automatic_dependency_tracking_scope(child_layer):
      try:
        child_layer._activity_regularizer = layer_fns.get(
            'activity_regularizer_fn')
      except AttributeError:
        # Some layers have an unsettable activity regularizer.
        pass
      child_layer.call = utils.use_wrapped_call(
          child_layer, layer_fns['call_and_return_conditional_losses'],
          default_training_value=False)
  return original_fns
Example #3
0
def _reset_layer_losses(parent_layer):
    """Resets losses of layer and its sublayers, and returns original losses."""
    losses_dict = {}
    # Visit every sublayer, then the parent layer itself.
    for sublayer in utils.list_all_layers(parent_layer) + [parent_layer]:
        # Snapshot current losses so the caller can restore them later.
        losses_dict[sublayer] = {
            'losses': list(sublayer._losses),  # pylint: disable=protected-access
            'eager_losses': list(sublayer._eager_losses)  # pylint: disable=protected-access
        }
        # Clear the losses without registering new trackable dependencies.
        with trackable.no_automatic_dependency_tracking_scope(sublayer):
            sublayer._losses = []  # pylint: disable=protected-access
            sublayer._eager_losses = []  # pylint: disable=protected-access
    return losses_dict