Example #1
0
def _finalize_saved_model_layers(layers):
  """Completes the revival of Keras layers loaded from a SavedModel.

  Works in two passes over `layers`. The first pass marks every layer as built
  and attaches a `call` method: the wrapped, restored
  `call_and_return_conditional_losses` function when the layer has one, or a
  placeholder bound to `_unable_to_call_layer_due_to_serialization_issue`
  otherwise. The second pass sets metadata and inputs on revived networks and
  restores the losses and metrics that are not produced by `call`.

  Args:
    layers: Revived layer objects. Iterated twice, so this should be a list or
      another re-iterable collection.
  """
  # pylint: disable=protected-access
  absent = object()  # Distinguishes "attribute missing" from a None value.

  # Pass 1: give every layer a `call` (skipped for nothing; networks included).
  for revived in layers:
    revived.built = True
    restored_call = getattr(_get_keras_attr(revived),
                            'call_and_return_conditional_losses', absent)
    if restored_call is absent:
      revived.call = types.MethodType(
          _unable_to_call_layer_due_to_serialization_issue, revived)
      continue
    revived.call = utils.use_wrapped_call(
        revived, restored_call, return_method=True)
    revived._init_call_fn_args()

  # Pass 2: network inputs/outputs, then losses and metrics outside `call`.
  for revived in layers:
    if isinstance(revived, RevivedNetwork):
      _set_network_attributes_from_metadata(revived)
      restored_call = getattr(_get_keras_attr(revived),
                              'call_and_return_conditional_losses', absent)
      if restored_call is not absent:
        signature = restored_call.input_signature
        revived._set_inputs(
            infer_inputs_from_restored_call_function(restored_call)
            if signature is None else signature[0])

    _restore_layer_unconditional_losses(revived)
    _restore_layer_activation_loss(revived)
    _restore_layer_metrics(revived)
Example #2
0
def _replace_child_layer_functions(layer, serialization_cache):
  """Replaces functions in the children layers with wrapped tf.functions.

  This step allows functions from parent layers to reference the wrapped
  functions from their children layers instead of retracing the ops.

  For every child that gets replaced, its current `call` and
  `_activity_regularizer` are recorded in the returned dictionary before being
  overwritten. Use `_restore_child_layer_functions` to restore the original
  attributes.

  Args:
    layer: Keras Layer object.
    serialization_cache: Dictionary shared between all objects during
      serialization.

  Returns:
    Dictionary mapping layer objects -> original functions:
      { Child layer 1: {
          'call': Original call function,
          'activity_regularizer': Original activity regularizer},
        Child layer 2: ...
      }
    `InputLayer` children, and children whose serialized functions are empty,
    are skipped and do not appear in the dictionary.
  """
  # pylint: disable=protected-access
  original_fns = {}
  for child_layer in utils.list_all_layers(layer):
    if isinstance(child_layer, input_layer.InputLayer):
      continue

    keras_cache = serialization_cache[constants.KERAS_CACHE_KEY]
    if child_layer in keras_cache:
      layer_fns = keras_cache[child_layer].functions
    else:
      layer_fns = (
          child_layer._trackable_saved_model_saver._get_serialized_attributes(
              serialization_cache).functions)
    if not layer_fns:
      # This indicates either:
      #   - circular dependency, which means the current layer's functions
      #     should be wrapped first.
      #   - Child layer's inputs are not defined, so its functions have not been
      #     wrapped. In this case, no replacement is necessary so move on to the
      #     next child.
      continue
    original_fns[child_layer] = {
        'call': child_layer.call,
        'activity_regularizer': child_layer._activity_regularizer
    }
    with trackable.no_automatic_dependency_tracking_scope(child_layer):
      activity_regularizer_fn = layer_fns.get('activity_regularizer_fn')
      try:
        child_layer._activity_regularizer = activity_regularizer_fn
      except AttributeError:
        # Some layers have an unsettable activity regularizer.
        pass
      child_layer.call = utils.use_wrapped_call(
          child_layer, layer_fns['call_and_return_conditional_losses'],
          default_training_value=False)
  return original_fns
Example #3
0
    def _finalize(self):
        """Finishes reviving the Keras objects collected in `self._nodes`.

        Pass 1 installs the restored `call_and_return_conditional_losses`
        function as `call` on every revived layer that is not a revived
        Sequential model, when `keras_api` provides that function.

        Pass 2 sets inputs on revived models, re-populating revived Sequential
        models layer by layer first. It then re-adds regularization losses and
        the activity regularizer on revived layers. Finally, it calls
        `delete_tracking` for every name in `PUBLIC_ATTRIBUTES`.
        """
        # pylint: disable=protected-access
        absent = object()  # Marks an attribute that does not exist.

        # Pass 1: wire up call functions (Sequential models are rebuilt below).
        for node in self._nodes:
            wants_call = (isinstance(node, RevivedLayer)
                          and not isinstance(node, RevivedSequential))
            if not wants_call:
                continue
            restored_call = getattr(node.keras_api,
                                    'call_and_return_conditional_losses',
                                    absent)
            if restored_call is absent:
                continue
            node.call = utils.use_wrapped_call(
                node, restored_call, return_method=True)
            node._init_call_fn_args()

        # Pass 2: model inputs, losses, and tracking cleanup.
        for node in self._nodes:
            if isinstance(node, RevivedModel):
                restored_call = (
                    node.keras_api.call_and_return_conditional_losses)
                signature = restored_call.input_signature
                if signature is not None:
                    model_inputs = signature[0]
                else:
                    model_inputs = infer_inputs_from_restored_call_function(
                        restored_call)

                if isinstance(node, RevivedSequential):
                    # Rebuild the layer stack from the serialized layer list.
                    with trackable.no_automatic_dependency_tracking_scope(
                            node):
                        node._layers = []
                    for sublayer in node.keras_api.layers:
                        node.add(sublayer)

                if not node.inputs:
                    # The revived object is technically a subclassed model,
                    # even when the original was functional/sequential, so its
                    # inputs have to be set explicitly.
                    node._set_inputs(model_inputs)

            if isinstance(node, RevivedLayer):
                losses = getattr(node.keras_api,
                                 'layer_regularization_losses', absent)
                if losses is absent:
                    # Older SavedModels may lack a separate
                    # layer_regularization_losses entry; fall back to the
                    # regularization_losses list.
                    losses = node._serialized_attributes.get(
                        'regularization_losses', [])
                for loss in losses:
                    node.add_loss(loss)

                # Only fall back to the wrapped activity regularizer when none
                # was created during initialization.
                if node.activity_regularizer is None:
                    node.activity_regularizer = getattr(
                        node.keras_api, 'activity_regularizer_fn', None)

                # The node is fully loaded and restored from the checkpoint, so
                # the objects added from SerializedAttributes no longer need
                # to be tracked. Layers and variables are tracked separately
                # by the Layer object, so training checkpoints still work.
                # TODO(kathywu): Instead of outright deleting these nodes (which
                # would make restoring from a different checkpoint tricky),
                # mark them as extra dependencies that are OK to overwrite.
                for attribute_name in PUBLIC_ATTRIBUTES:
                    delete_tracking(node, attribute_name)
Example #4
0
 def replace_layer_functions(child_layer, serialized_fns):
     """Swaps in the wrapped `call` and activity regularizer for a child layer.

     The child's current `call` and `_activity_regularizer` are first stored in
     the enclosing scope's `original_fns` mapping, keyed by `child_layer`, so
     the originals are kept.

     Args:
       child_layer: Layer whose functions are replaced.
       serialized_fns: Mapping of serialized functions. It must contain
         'call_and_return_conditional_losses' (indexed directly) and may
         contain 'activity_regularizer_fn' (read with `.get`).
     """
     backup = {'call': child_layer.call}
     backup['_activity_regularizer'] = child_layer._activity_regularizer
     original_fns[child_layer] = backup

     with trackable.no_automatic_dependency_tracking_scope(child_layer):
         try:
             regularizer_fn = serialized_fns.get('activity_regularizer_fn')
             child_layer._activity_regularizer = regularizer_fn
         except AttributeError:
             # Some layers expose an activity regularizer that cannot be set.
             pass
         traced_call = serialized_fns['call_and_return_conditional_losses']
         child_layer.call = utils.use_wrapped_call(
             child_layer, traced_call, default_training_value=False)
Example #5
0
def _finalize_saved_model_layers(layers):
  """Runs the final steps of loading Keras Layers from SavedModel.

  Pass 1 marks each layer as built and attaches its `call` method. A layer
  whose restored `call_and_return_conditional_losses` has at least one
  concrete function gets the wrapped restored function. All other layers get
  a placeholder bound to `_unable_to_call_layer_due_to_serialization_issue`.

  Pass 2 sets network attributes on revived networks. Their inputs are set
  only when a traced call function exists. Then, for every layer, it restores
  unconditional losses, the activity-regularization loss and the metrics list.

  Args:
    layers: Revived layer objects. Iterated twice, so this should be a list or
      another re-iterable collection.
  """
  # pylint: disable=protected-access
  # 1. Set up call functions for all layers initialized from the SavedModel
  # (and not from the config).
  for layer in layers:
    layer.built = True
    layer_call = getattr(_get_keras_attr(layer),
                         'call_and_return_conditional_losses', None)
    if layer_call and layer_call.concrete_functions:
      layer.call = utils.use_wrapped_call(
          layer, layer_call, return_method=True)
      expects_training_arg = layer._serialized_attributes['metadata'][
          'expects_training_arg']
      if 'training' in layer_call.function_spec.arg_names:
        # This could change the value of `expects_training_arg` if this layer
        # doesn't expect a training arg, but has a child layer that does.
        expects_training_arg = True
      layer._init_call_fn_args(expects_training_arg)
    else:
      layer.call = types.MethodType(
          _unable_to_call_layer_due_to_serialization_issue, layer)

  for layer in layers:
    # 2. Set model inputs and outputs.
    if isinstance(layer, RevivedNetwork):
      _set_network_attributes_from_metadata(layer)

      call_fn = getattr(_get_keras_attr(layer),
                        'call_and_return_conditional_losses', None)
      # Inputs can only be derived from a traced call function. A network
      # without one still needs steps 3 and 4 below, so only `_set_inputs`
      # is skipped here; do not `continue` the loop.
      if call_fn and call_fn.concrete_functions:
        if call_fn.input_signature is None:
          inputs = infer_inputs_from_restored_call_function(call_fn)
        else:
          inputs = call_fn.input_signature[0]
        layer._set_inputs(inputs)  # pylint: disable=protected-access

    # 3. Add losses that aren't generated by the layer.call function.
    _restore_layer_unconditional_losses(layer)
    _restore_layer_activation_loss(layer)

    # 4. Restore metrics list
    _restore_layer_metrics(layer)
Example #6
0
    def _finalize(self):
        """Finishes reviving the Keras objects collected in `self._nodes`.

        Pass 1 marks every revived layer as built. For layers that are
        neither Sequential models nor graph (functional) networks, it also
        installs the restored `call_and_return_conditional_losses` function as
        `call`, when `keras_api` provides it.

        Pass 2 rebuilds revived networks:
          - Sequential models are re-populated with `add`.
          - Functional models are reconstructed from their config.
          - Subclassed models get their inputs from the restored call
            function.
        For every revived layer it then re-adds unconditional losses and the
        activity regularizer, and calls `delete_tracking` for each name in
        `PUBLIC_ATTRIBUTES`.
        """
        # pylint: disable=protected-access

        # Set up call functions for all layers (skip this step for Sequential
        # and Functional models).
        for node in self._nodes:
            if isinstance(node, RevivedLayer):
                node.built = True
                is_graph_network = node._serialized_attributes['metadata'].get(
                    'is_graph_network', False)
                if not (isinstance(node, models_lib.Sequential)
                        or is_graph_network):
                    if hasattr(node.keras_api,
                               'call_and_return_conditional_losses'):
                        node.call = utils.use_wrapped_call(
                            node,
                            node.keras_api.call_and_return_conditional_losses,
                            return_method=True)
                        node._init_call_fn_args()

        for node in self._nodes:
            if isinstance(node, RevivedNetwork):
                # Set model inputs and outputs.
                is_graph_network = node._serialized_attributes['metadata'].get(
                    'is_graph_network', False)
                if isinstance(node, models_lib.Sequential):
                    with trackable.no_automatic_dependency_tracking_scope(
                            node):
                        node._layers = []
                    for layer in node.keras_api.layers:
                        node.add(layer)
                elif is_graph_network:
                    # Reconstruct functional model from the config and layers
                    # loaded from the SavedModel.
                    inputs, outputs, _ = network_lib.reconstruct_from_config(
                        node.get_config(),
                        created_layers={
                            layer.name: layer
                            for layer in node.layers
                        })
                    node._init_graph_network(
                        inputs,
                        outputs,
                        name=node._serialized_attributes['metadata']['name'])
                    # Set the metadata attributes once more, since
                    # _init_graph_network resets these attributes.
                    _set_network_attributes_from_metadata(node)
                else:  # Model is subclassed.
                    # The restored call function is consulted only here. The
                    # Sequential branch never uses these inputs, and the
                    # functional branch gets its own from
                    # `reconstruct_from_config`. Looking it up (and possibly
                    # inferring inputs) for them was wasted work and an
                    # unguarded attribute access.
                    call_fn = node.keras_api.call_and_return_conditional_losses
                    if call_fn.input_signature is None:
                        inputs = infer_inputs_from_restored_call_function(
                            call_fn)
                    else:
                        inputs = call_fn.input_signature[0]
                    node._set_inputs(inputs)

            # Add unconditional losses.
            if isinstance(node, RevivedLayer):
                if hasattr(node.keras_api, 'layer_regularization_losses'):
                    losses = getattr(node.keras_api,
                                     'layer_regularization_losses', [])
                else:
                    # Some earlier SavedModels may not have
                    # layer_regularization_losses serialized separately. Fall
                    # back to using the regularization_losses list if it does
                    # not exist.
                    losses = node._serialized_attributes.get(
                        'regularization_losses', [])
                for loss in losses:
                    node.add_loss(loss)

                # Use wrapped activity regularizer function if the layer's
                # activity regularizer wasn't created during initialization.
                if node.activity_regularizer is None:
                    node.activity_regularizer = getattr(
                        node.keras_api, 'activity_regularizer_fn', None)

                # Now that the node object has been fully loaded and restored
                # from the checkpoint, the object no longer needs to track
                # objects added from SerializedAttributes. (Note that saving a
                # training checkpoint still functions correctly, because layers
                # and variables are tracked separately by the Layer object.)
                # TODO(kathywu): Instead of outright deleting these nodes (which
                # would make restoring from a different checkpoint tricky),
                # mark them as extra dependencies that are OK to overwrite.
                for name in PUBLIC_ATTRIBUTES:
                    delete_tracking(node, name)