Beispiel #1
0
  def _init_from_metadata(cls, metadata):
    """Builds a revived layer instance from SavedModel metadata.

    Args:
      cls: The class to instantiate.
      metadata: Mapping decoded from the SavedModel proto. `name`, `trainable`
        and `expects_training_arg` are read unconditionally; `dtype`,
        `batch_input_shape`, `config`, `input_spec`, `activity_regularizer`,
        `_is_feature_layer` and `stateful` are only applied when present and
        not None (`config` must also pass `generic_utils.validate_config`).

    Returns:
      A `(revived_obj, _revive_setter)` tuple.
    """
    # Required constructor kwargs first; optional ones are forwarded only when
    # present and not None.
    constructor_kwargs = {
        'name': metadata['name'],
        'trainable': metadata['trainable'],
    }
    for optional_key in ('dtype', 'batch_input_shape'):
      optional_value = metadata.get(optional_key)
      if optional_value is not None:
        constructor_kwargs[optional_key] = optional_value

    layer = cls(**constructor_kwargs)

    # Attributes are assigned inside `no_automatic_dependency_tracking_scope`
    # (presumably so they are not recorded as tracked dependencies).
    with utils.no_automatic_dependency_tracking_scope(layer):
      # pylint:disable=protected-access
      layer._expects_training_arg = metadata['expects_training_arg']

      saved_config = metadata.get('config')
      if generic_utils.validate_config(saved_config):
        layer._config = saved_config

      saved_spec = metadata.get('input_spec')
      if saved_spec is not None:
        layer.input_spec = recursively_deserialize_keras_object(
            saved_spec,
            module_objects={'InputSpec': input_spec.InputSpec})

      saved_regularizer = metadata.get('activity_regularizer')
      if saved_regularizer is not None:
        layer.activity_regularizer = regularizers.deserialize(
            saved_regularizer)

      is_feature_layer = metadata.get('_is_feature_layer')
      if is_feature_layer is not None:
        layer._is_feature_layer = is_feature_layer

      is_stateful = metadata.get('stateful')
      if is_stateful is not None:
        layer.stateful = is_stateful
      # pylint:enable=protected-access

    return layer, _revive_setter
Beispiel #2
0
def _maybe_add_serialized_attributes(layer, metadata):
  """Attaches an un-tracked `_serialized_attributes` dict to `layer` once.

  The dict is meant to hold attributes revived from SerializedAttributes (those
  listed in CommonEndpoints, or under "keras_api" for Keras-specific
  attributes). It is seeded with `{'metadata': metadata}`. If `layer` already
  has a `_serialized_attributes` attribute, nothing is changed.

  Args:
    layer: Revived layer object to annotate.
    metadata: Value stored under the 'metadata' key of the new dict.
  """
  if hasattr(layer, '_serialized_attributes'):
    return  # Already populated; keep the existing dict.
  # Assigned inside the no-tracking scope so the dict stays un-tracked.
  with utils.no_automatic_dependency_tracking_scope(layer):
    layer._serialized_attributes = dict(metadata=metadata)  # pylint: disable=protected-access
Beispiel #3
0
def _set_network_attributes_from_metadata(revived_obj):
  """Applies the dtype policy and trainability recorded in the metadata.

  Reads `revived_obj._serialized_attributes['metadata']`. If its 'dtype' entry
  is not None it is passed to `revived_obj._set_dtype_policy`. `_trainable` is
  then set from the required 'trainable' entry.

  Args:
    revived_obj: Revived object that already carries `_serialized_attributes`.
  """
  with utils.no_automatic_dependency_tracking_scope(revived_obj):
    # pylint:disable=protected-access
    saved = revived_obj._serialized_attributes['metadata']
    dtype_policy = saved.get('dtype')
    if dtype_policy is not None:
      revived_obj._set_dtype_policy(dtype_policy)
    revived_obj._trainable = saved['trainable']
Beispiel #4
0
def _restore_child_layer_functions(original_fns):
    """Restores attributes replaced with `_replace_child_layer_functions`.

    Args:
        original_fns: Mapping of child layer -> {attribute name: original
            value}. Each value is put back on its layer with `setattr`.
    """
    for layer, saved_fns in original_fns.items():
        with utils.no_automatic_dependency_tracking_scope(layer):
            for attr_name, original_fn in saved_fns.items():
                try:
                    setattr(layer, attr_name, original_fn)
                except AttributeError:
                    # Setting some attributes (e.g. `_activity_regularizer`)
                    # may be disallowed; such attributes are skipped.
                    pass
Beispiel #5
0
 def replace_metric_functions(child_layer, serialized_fns):
     """Replaces metric functions with wrapped functions.

     The current `__call__`, `result` and `update_state` of `child_layer` are
     first saved in `original_fns[child_layer]` (`original_fns` comes from the
     enclosing scope). They are then replaced by the entries of
     `serialized_fns` under the same names.

     NOTE(review): assigning `__call__` on an instance does not change what
     `child_layer(...)` runs, because implicit special-method lookup goes
     through the type. Confirm callers invoke `child_layer.__call__` explicitly.
     """
     method_names = ("__call__", "result", "update_state")
     original_fns[child_layer] = {
         attr: getattr(child_layer, attr) for attr in method_names
     }
     with utils.no_automatic_dependency_tracking_scope(child_layer):
         for attr in method_names:
             setattr(child_layer, attr, serialized_fns[attr])
Beispiel #6
0
 def replace_metric_functions(child_layer, serialized_fns):
     """Replaces metric functions with wrapped functions.

     The current `__call__`, `result` and `update_state` of `child_layer` are
     first saved in `original_fns[child_layer]` (`original_fns` comes from the
     enclosing scope). They are then replaced by the entries of
     `serialized_fns` under the same names.

     NOTE(review): assigning `__call__` on an instance does not change what
     `child_layer(...)` runs, because implicit special-method lookup goes
     through the type. Confirm callers invoke `child_layer.__call__` explicitly.
     """
     method_names = ('__call__', 'result', 'update_state')
     original_fns[child_layer] = {
         attr: getattr(child_layer, attr) for attr in method_names
     }
     with utils.no_automatic_dependency_tracking_scope(child_layer):
         for attr in method_names:
             setattr(child_layer, attr, serialized_fns[attr])
Beispiel #7
0
def _restore_child_layer_functions(original_fns):
    """Restores attributes replaced with `_replace_child_layer_functions`.

    Args:
        original_fns: Mapping of child layer -> {attribute name: original
            value}. Each value is put back on its layer with `setattr`.
    """
    for layer, saved_fns in original_fns.items():
        with utils.no_automatic_dependency_tracking_scope(layer):
            for attr_name, original_fn in saved_fns.items():
                try:
                    setattr(layer, attr_name, original_fn)
                except AttributeError:
                    # Setting some attributes (e.g. `_activity_regularizer`)
                    # may be disallowed; such attributes are skipped.
                    pass
Beispiel #8
0
def _reset_layer_losses(parent_layer):
    """Resets losses of layer and its sublayers, and returns original losses.

    Args:
        parent_layer: Layer whose own and nested layers' losses are cleared.

    Returns:
        Dict mapping each visited layer to {"losses": ..., "eager_losses": ...},
        holding slice copies of `_losses` / `_eager_losses` taken before they
        were replaced with empty lists.
    """
    saved_losses = {}
    for layer in utils.list_all_layers_and_sublayers(parent_layer):
        # Snapshot first, then clear. If a layer were yielded twice, its second
        # snapshot would overwrite the first with the already-cleared lists.
        snapshot = dict(
            losses=layer._losses[:],
            eager_losses=layer._eager_losses[:],
        )
        saved_losses[layer] = snapshot
        with utils.no_automatic_dependency_tracking_scope(layer):
            layer._losses = []
            layer._eager_losses = []
    return saved_losses
Beispiel #9
0
  def _init_from_metadata(cls, metadata):
    """Revives the saved InputLayer from its SavedModel metadata.

    `name`, `dtype`, `sparse`, `ragged` and `batch_input_shape` are required
    metadata entries and are passed to `cls` as keyword arguments. The required
    `config` entry is then stored on the new instance as `_config`.

    Returns:
      A `(revived_obj, setattr)` tuple.
    """
    # Keys are read in this order, so a missing key raises the same KeyError
    # as spelling them out one by one.
    required_keys = ('name', 'dtype', 'sparse', 'ragged', 'batch_input_shape')
    input_layer = cls(**{key: metadata[key] for key in required_keys})
    with utils.no_automatic_dependency_tracking_scope(input_layer):
      input_layer._config = metadata['config']  # pylint:disable=protected-access

    return input_layer, setattr
Beispiel #10
0
 def replace_layer_functions(child_layer, serialized_fns):
     """Replaces layer call and activity regularizer with wrapped functions.

     The current `call` and `_activity_regularizer` of `child_layer` are saved
     in `original_fns[child_layer]` (`original_fns` comes from the enclosing
     scope) before being replaced.

     Args:
         child_layer: Layer whose functions are replaced.
         serialized_fns: Mapping of wrapped functions. The optional
             'activity_regularizer_fn' entry (None when absent) becomes
             `_activity_regularizer`. The required
             'call_and_return_conditional_losses' entry is wrapped with
             `utils.use_wrapped_call` (default_training_value=False) and
             becomes `call`.
     """
     original_fns[child_layer] = dict(
         call=child_layer.call,
         _activity_regularizer=child_layer._activity_regularizer,
     )
     with utils.no_automatic_dependency_tracking_scope(child_layer):
         try:
             child_layer._activity_regularizer = serialized_fns.get(
                 'activity_regularizer_fn')
         except AttributeError:
             # Some layers have an unsettable activity regularizer; the
             # assignment is skipped for them.
             pass
         wrapped_call = utils.use_wrapped_call(
             child_layer,
             serialized_fns['call_and_return_conditional_losses'],
             default_training_value=False)
         child_layer.call = wrapped_call
Beispiel #11
0
  def _init_from_metadata(cls, metadata):
    """Create revived network from metadata stored in the SavedModel proto.

    Only `name` is passed to the constructor. `expects_training_arg` is
    required. `config` is stored as `_config` only if
    `generic_utils.validate_config` accepts it. `activity_regularizer` is
    deserialized only when present and not None.

    Returns:
      A `(revived_obj, _revive_setter)` tuple.
    """
    network = cls(name=metadata['name'])

    with utils.no_automatic_dependency_tracking_scope(network):
      # pylint:disable=protected-access
      network._expects_training_arg = metadata['expects_training_arg']

      saved_config = metadata.get('config')
      if generic_utils.validate_config(saved_config):
        network._config = saved_config

      saved_regularizer = metadata.get('activity_regularizer')
      if saved_regularizer is not None:
        network.activity_regularizer = regularizers.deserialize(
            saved_regularizer)
      # pylint:enable=protected-access

    return network, _revive_setter
Beispiel #12
0
def _restore_layer_losses(losses_dict):
    """Puts saved losses back onto their layers.

    Args:
        losses_dict: Mapping of layer -> {"losses": ..., "eager_losses": ...},
            in the shape returned by `_reset_layer_losses`. The saved values
            are assigned to `_losses` and `_eager_losses` as-is.
    """
    for layer, saved in losses_dict.items():
        with utils.no_automatic_dependency_tracking_scope(layer):
            layer._losses = saved["losses"]
            layer._eager_losses = saved["eager_losses"]