Example #1
0
  def _revive_graph_network(self, metadata, node_id):
    """Creates a blank Sequential/Functional model for a saved graph network.

    The returned model is only a placeholder. Its network is rebuilt later,
    after all of the model's child layers have been revived.

    Args:
      metadata: Decoded metadata dict for the saved object.
      node_id: Node id of the object in the SavedModel object graph.

    Returns:
      The blank model. Returns None when the metadata has no valid config,
      when `class_name` is a registered object, or when the metadata describes
      neither a graph network nor a `Sequential`/`Functional` model.
    """
    config = metadata.get('config')
    if not generic_utils.validate_config(config):
      return None

    class_name = compat.as_str(metadata['class_name'])
    # Registered classes are left to the other revival paths.
    if generic_utils.get_registered_object(class_name) is not None:
      return None

    is_graph_network = metadata.get('is_graph_network', False)
    if not (is_graph_network or class_name in ('Sequential', 'Functional')):
      return None

    # The model object must already exist so that setattr tracking and
    # attribute caching work. Its layers are wired in later.
    model_name = config['name']
    if class_name == 'Sequential':
      model = models_lib.Sequential(name=model_name)
    else:
      model = models_lib.Functional(inputs=[], outputs=[], name=model_name)

    # Record the model and its child layer ids for deferred reconstruction.
    # A model without child layers is queued in `_models_to_reconstruct`
    # right away.
    child_layer_ids = self._get_child_layer_node_ids(node_id)
    self.model_layer_dependencies[node_id] = (model, child_layer_ids)
    if not child_layer_ids:
      self._models_to_reconstruct.append(node_id)
    return model
Example #2
0
  def _init_from_metadata(cls, metadata):
    """Builds a revived layer from the metadata stored in the SavedModel proto.

    Args:
      metadata: Decoded metadata dict. It must contain 'name', 'trainable'
        and 'expects_training_arg'. The keys 'dtype', 'batch_input_shape',
        'config', 'input_spec', 'activity_regularizer', '_is_feature_layer'
        and 'stateful' are optional.

    Returns:
      A `(revived_obj, _revive_setter)` tuple.
    """
    init_args = {'name': metadata['name'], 'trainable': metadata['trainable']}
    # Constructor arguments that are forwarded only when present.
    for optional_key in ('dtype', 'batch_input_shape'):
      value = metadata.get(optional_key)
      if value is not None:
        init_args[optional_key] = value

    revived_obj = cls(**init_args)

    # The remaining attributes are assigned inside
    # `no_automatic_dependency_tracking_scope`, presumably so that they are
    # not recorded as tracked dependencies. TODO confirm.
    with trackable.no_automatic_dependency_tracking_scope(revived_obj):
      # pylint:disable=protected-access
      revived_obj._expects_training_arg = metadata['expects_training_arg']

      config = metadata.get('config')
      if generic_utils.validate_config(config):
        revived_obj._config = config

      spec = metadata.get('input_spec')
      if spec is not None:
        revived_obj.input_spec = recursively_deserialize_keras_object(
            spec, module_objects={'InputSpec': input_spec.InputSpec})

      regularizer_config = metadata.get('activity_regularizer')
      if regularizer_config is not None:
        revived_obj.activity_regularizer = regularizers.deserialize(
            regularizer_config)

      is_feature_layer = metadata.get('_is_feature_layer')
      if is_feature_layer is not None:
        revived_obj._is_feature_layer = is_feature_layer

      stateful = metadata.get('stateful')
      if stateful is not None:
        revived_obj.stateful = stateful
      # pylint:enable=protected-access

    return revived_obj, _revive_setter
Example #3
0
  def _load_layer(self, node_id, identifier, metadata):
    """Loads one layer described by a SavedUserObject proto.

    Args:
      node_id: Node id of the object in the SavedModel object graph.
      identifier: Identifier string of the saved user object.
      metadata: Serialized metadata. It is decoded with `json_utils.decode`.

    Returns:
      An `(object, setter)` tuple.
    """
    metadata = json_utils.decode(metadata)

    if node_id not in self.loaded_nodes:
      # First try to revive from the config. Fall back to reviving from the
      # SavedModel itself when that is not possible.
      obj, setter = self._revive_from_config(identifier, metadata, node_id)
      if obj is None:
        obj, setter = revive_custom_object(identifier, metadata)
      # Attach the extra functions/objects saved in the SavedModel. Most are
      # ignored, but some are used later in loading (for example
      # regularization losses or the training config of compiled models).
      _maybe_add_serialized_attributes(obj, metadata)
      return obj, setter

    # The node was created earlier. The revive setter needs a
    # `_serialized_attributes` property, so make sure it exists.
    existing_node, existing_setter = self.loaded_nodes[node_id]
    _maybe_add_serialized_attributes(existing_node, metadata)

    config = metadata.get('config')
    if (_is_graph_network(existing_node) and
        generic_utils.validate_config(config)):
      child_ids = self._get_child_layer_node_ids(node_id)
      self.model_layer_dependencies[node_id] = (existing_node, child_ids)
      if not child_ids:
        self._models_to_reconstruct.append(node_id)
    return existing_node, existing_setter
Example #4
0
  def _revive_layer_or_model_from_config(self, metadata, node_id):
    """Revives a layer/custom model from config; returns None if infeasible.

    Reviving from config requires that:
      1. The object can be deserialized from its config.
      2. If the object needs to be built, the build input shape can be found.

    Args:
      metadata: Decoded metadata dict for the saved object.
      node_id: Node id of the object in the SavedModel object graph.

    Returns:
      The revived and built layer or model. Returns None when the config is
      invalid, when deserialization raises `ValueError` and
      `must_restore_from_config` is not set, or when the object cannot be
      built.

    Raises:
      RuntimeError: If deserialization raises `ValueError` for an object whose
        metadata sets `must_restore_from_config`. The original `ValueError`
        is chained as `__cause__`.
    """
    class_name = metadata.get('class_name')
    config = metadata.get('config')
    shared_object_id = metadata.get('shared_object_id')
    must_restore_from_config = metadata.get('must_restore_from_config')
    if not generic_utils.validate_config(config):
      return None

    try:
      obj = layers_module.deserialize(
          generic_utils.serialize_keras_class_and_config(
              class_name, config, shared_object_id=shared_object_id))
    except ValueError as e:
      if must_restore_from_config:
        # Chain explicitly so the underlying deserialization error is
        # reported as the direct cause.
        raise RuntimeError(
            'Unable to restore a layer of class {cls}. Layers of '
            'class {cls} require that the class be provided to '
            'the model loading code, either by registering the '
            'class using @keras.utils.register_keras_serializable '
            'on the class def and including that file in your '
            'program, or by passing the class in a '
            'keras.utils.CustomObjectScope that wraps this load '
            'call.'.format(cls=class_name)) from e
      return None

    # Use the dtype, name, and trainable status. Often times these are not
    # specified in custom configs, so retrieve their values from the metadata.
    # pylint: disable=protected-access
    obj._name = metadata['name']
    if metadata.get('trainable') is not None:
      obj.trainable = metadata['trainable']
    if metadata.get('dtype') is not None:
      obj._set_dtype_policy(metadata['dtype'])
    if metadata.get('stateful') is not None:
      obj.stateful = metadata['stateful']
    # Restore the model save spec for subclassed models. Layers do not store
    # a SaveSpec.
    if isinstance(obj, training_lib.Model):
      save_spec = metadata.get('save_spec')
      if save_spec is not None:
        obj._set_save_spec(save_spec)
    # pylint: enable=protected-access

    build_input_shape = metadata.get('build_input_shape')
    built = self._try_build_layer(obj, node_id, build_input_shape)

    if not built:
      # If the layer cannot be built, revive a custom layer instead.
      return None
    return obj
Example #5
0
  def _revive_metric_from_config(self, metadata, node_id):
    """Tries to rebuild a metric object from its saved Keras config.

    Args:
      metadata: Decoded metadata dict. It must contain 'class_name'. The keys
        'config' and 'build_input_shape' are optional.
      node_id: Node id of the metric in the object graph. This method does
        not use it.

    Returns:
      The deserialized metric. It is built when a build input shape is
      present and the metric has a `_build` method. Returns None when the
      config is invalid or deserialization raises `ValueError`.
    """
    del node_id  # Unused.
    class_name = compat.as_str(metadata['class_name'])
    config = metadata.get('config')
    if not generic_utils.validate_config(config):
      return None

    try:
      serialized = generic_utils.serialize_keras_class_and_config(
          class_name, config)
      metric = metrics.deserialize(serialized)
    except ValueError:
      return None

    shape = metadata.get('build_input_shape')
    if shape is None or not hasattr(metric, '_build'):
      return metric
    metric._build(shape)  # pylint: disable=protected-access
    return metric
Example #6
0
  def _init_from_metadata(cls, metadata):
    """Builds a revived network from the metadata stored in the SavedModel.

    Args:
      metadata: Decoded metadata dict. It must contain 'name' and
        'expects_training_arg'. The keys 'config' and 'activity_regularizer'
        are optional.

    Returns:
      A `(revived_obj, _revive_setter)` tuple.
    """
    revived_obj = cls(name=metadata['name'])

    # Attributes are assigned inside `no_automatic_dependency_tracking_scope`,
    # presumably so that they are not recorded as tracked dependencies.
    # TODO confirm.
    with trackable.no_automatic_dependency_tracking_scope(revived_obj):
      # pylint:disable=protected-access
      revived_obj._expects_training_arg = metadata['expects_training_arg']

      config = metadata.get('config')
      if generic_utils.validate_config(config):
        revived_obj._config = config

      regularizer_config = metadata.get('activity_regularizer')
      if regularizer_config is not None:
        revived_obj.activity_regularizer = regularizers.deserialize(
            regularizer_config)
      # pylint:enable=protected-access

    return revived_obj, _revive_setter
Example #7
0
    def _revive_layer_from_config(self, metadata, node_id):
        """Revives a layer from its config; returns None if that is infeasible.

        Reviving from config requires that the object can be deserialized
        from its config and, if it needs to be built, that a build input
        shape can be found.

        Args:
            metadata: Decoded metadata dict for the saved layer.
            node_id: Node id of the layer in the SavedModel object graph.

        Returns:
            The revived and built layer. Returns None when the config is
            invalid, when deserialization raises `ValueError`, or when the
            layer cannot be built.
        """
        class_name = metadata.get('class_name')
        config = metadata.get('config')
        if not generic_utils.validate_config(config):
            return None

        try:
            serialized = generic_utils.serialize_keras_class_and_config(
                class_name, config)
            layer = layers_module.deserialize(serialized)
        except ValueError:
            return None

        # Custom configs often omit name, trainable status and dtype, so take
        # them from the metadata instead.
        # pylint: disable=protected-access
        layer._name = metadata['name']
        trainable = metadata.get('trainable')
        if trainable is not None:
            layer.trainable = trainable
        dtype = metadata.get('dtype')
        if dtype is not None:
            layer._set_dtype_policy(dtype)
        stateful = metadata.get('stateful')
        if stateful is not None:
            layer.stateful = stateful
        # pylint: enable=protected-access

        if not self._try_build_layer(
                layer, node_id, metadata.get('build_input_shape')):
            # If the layer cannot be built, a custom layer is revived instead.
            return None
        return layer