Ejemplo n.º 1
0
  def _revive_layer_or_model_from_config(self, metadata, node_id):
    """Revives a layer/custom model from config; returns None if infeasible.

    Reviving from config requires that:
      1. The object can be deserialized from its stored config.
      2. If the object needs to be built, the build input shape can be found.

    Args:
      metadata: Saved metadata for the object. Keys read here: 'class_name',
        'config', 'shared_object_id', 'must_restore_from_config', 'name'
        (read with `[]`, so it is expected to be present), 'trainable',
        'dtype', 'stateful', 'save_spec' (only consulted for
        `training_lib.Model` instances) and 'build_input_shape'.
      node_id: Identifier of the node being revived; passed through to
        `self._try_build_layer`.

    Returns:
      The revived layer or model, or None when the config does not pass
      `generic_utils.validate_config`, when deserialization raises ValueError
      for an object that is not required to be restored from config, or when
      `self._try_build_layer` reports that the object could not be built.

    Raises:
      RuntimeError: If deserialization raises ValueError and the metadata sets
        'must_restore_from_config'. The ValueError is chained as the cause.
    """
    class_name = metadata.get('class_name')
    config = metadata.get('config')
    shared_object_id = metadata.get('shared_object_id')
    must_restore_from_config = metadata.get('must_restore_from_config')
    if not generic_utils.validate_config(config):
      return None

    try:
      obj = layers_module.deserialize(
          generic_utils.serialize_keras_class_and_config(
              class_name, config, shared_object_id=shared_object_id))
    except ValueError as e:
      if must_restore_from_config:
        # Chain explicitly so the deserialization failure is reported as the
        # direct cause of this RuntimeError rather than as incidental context.
        raise RuntimeError(
            'Unable to restore a layer of class {cls}. Layers of '
            'class {cls} require that the class be provided to '
            'the model loading code, either by registering the '
            'class using @keras.utils.register_keras_serializable '
            'on the class def and including that file in your '
            'program, or by passing the class in a '
            'keras.utils.CustomObjectScope that wraps this load '
            'call.'.format(cls=class_name)) from e
      return None

    # Use the dtype, name, and trainable status. Often times these are not
    # specified in custom configs, so retrieve their values from the metadata.
    # pylint: disable=protected-access
    obj._name = metadata['name']
    if metadata.get('trainable') is not None:
      obj.trainable = metadata['trainable']
    if metadata.get('dtype') is not None:
      obj._set_dtype_policy(metadata['dtype'])
    if metadata.get('stateful') is not None:
      obj.stateful = metadata['stateful']
    # Restore model save spec for subclassed models. (layers do not store a
    # SaveSpec)
    if isinstance(obj, training_lib.Model):
      save_spec = metadata.get('save_spec')
      if save_spec is not None:
        obj._set_save_spec(save_spec)
    # pylint: enable=protected-access

    build_input_shape = metadata.get('build_input_shape')
    built = self._try_build_layer(obj, node_id, build_input_shape)

    if not built:
      # If the layer cannot be built, revive a custom layer instead.
      return None
    return obj
Ejemplo n.º 2
0
def serialize_feature_column(fc):
    """Serializes a FeatureColumn or a raw string key.

    This method should only be used to serialize parent FeatureColumns when
    implementing FeatureColumn._get_config(), else serialize_feature_columns()
    is preferable.

    This serialization also keeps information of the FeatureColumn class, so
    deserialization is possible without knowing the class type. For example:

    a = numeric_column('price')
    a._get_config() gives:
    {
        'key': 'price',
        'shape': (1,),
        'default_value': None,
        'dtype': 'float32',
        'normalizer_fn': None
    }
    While serialize_feature_column(a) gives:
    {
        'class_name': 'NumericColumn',
        'config': {
            'key': 'price',
            'shape': (1,),
            'default_value': None,
            'dtype': 'float32',
            'normalizer_fn': None
        }
    }

    Args:
      fc: A FeatureColumn or raw feature key string.

    Returns:
      Keras serialization for FeatureColumns; string keys are returned
      unchanged.

    Raises:
      ValueError: If called with input that is neither a string nor a
        FeatureColumn.
    """
    # Import here to avoid circular imports.
    from tensorflow.python.keras.utils import generic_utils  # pylint: disable=g-import-not-at-top

    if isinstance(fc, six.string_types):
        return fc
    elif isinstance(fc, fc_lib.FeatureColumn):
        return generic_utils.serialize_keras_class_and_config(
            fc.__class__.__name__, fc._get_config())  # pylint: disable=protected-access
    else:
        raise ValueError('Instance: {} is not a FeatureColumn'.format(fc))
Ejemplo n.º 3
0
  def _revive_metric_from_config(self, metadata, node_id):
    """Tries to recreate a metric object from its saved Keras config.

    Args:
      metadata: Saved metadata for the metric. 'class_name' is required (read
        with `[]`); 'config' and 'build_input_shape' are optional.
      node_id: Not used by this method.

    Returns:
      The deserialized metric, or None when the config does not pass
      `generic_utils.validate_config` or deserialization raises ValueError.
      If a build input shape was saved and the metric has a `_build` method,
      that method is called with the shape before the metric is returned.
    """
    class_name = compat.as_str(metadata['class_name'])
    config = metadata.get('config')
    if not generic_utils.validate_config(config):
      return None

    try:
      metric = metrics.deserialize(
          generic_utils.serialize_keras_class_and_config(class_name, config))
    except ValueError:
      return None

    input_shape = metadata.get('build_input_shape')
    needs_build = input_shape is not None and hasattr(metric, '_build')
    if needs_build:
      metric._build(input_shape)  # pylint: disable=protected-access
    return metric
Ejemplo n.º 4
0
    def _revive_layer_from_config(self, metadata, node_id):
        """Rebuilds a layer from its saved config, or gives up with None.

        The layer is only revived from config when (1) it can be deserialized
        from the stored config and (2) if it needs building, a build input
        shape can be found.

        Args:
            metadata: Saved metadata for the layer. Keys read here:
                'class_name', 'config', 'name' (read with `[]`),
                'trainable', 'dtype', 'stateful' and 'build_input_shape'.
            node_id: Identifier of the node being revived; passed through to
                `self._try_build_layer`.

        Returns:
            The revived layer, or None if the config does not pass
            `generic_utils.validate_config`, deserialization raises
            ValueError, or `self._try_build_layer` reports that the layer was
            not built (so that a custom layer can be revived instead).
        """
        class_name = metadata.get('class_name')
        config = metadata.get('config')
        if not generic_utils.validate_config(config):
            return None

        try:
            layer = layers_module.deserialize(
                generic_utils.serialize_keras_class_and_config(
                    class_name, config))
        except ValueError:
            return None

        # Custom configs frequently omit the name, trainable flag, dtype and
        # statefulness, so those are taken from the saved metadata instead.
        # pylint: disable=protected-access
        layer._name = metadata['name']
        trainable = metadata.get('trainable')
        if trainable is not None:
            layer.trainable = trainable
        dtype = metadata.get('dtype')
        if dtype is not None:
            layer._set_dtype_policy(dtype)
        stateful = metadata.get('stateful')
        if stateful is not None:
            layer.stateful = stateful
        # pylint: enable=protected-access

        built = self._try_build_layer(
            layer, node_id, metadata.get('build_input_shape'))
        return layer if built else None
Ejemplo n.º 5
0
    def _revive_layer_from_config(self, metadata, node_id):
        """Revives a layer from config, or returns None if infeasible."""
        # Check that the following requirements are met for reviving from config:
        #    1. Object can be deserialized from config.
        #    2. If the object needs to be built, then the build input shape can be
        #       found.
        class_name = metadata.get('class_name')
        config = metadata.get('config')
        if config is None or generic_utils.LAYER_UNDEFINED_CONFIG_KEY in config:
            return None

        try:
            obj = layers_module.deserialize(
                generic_utils.serialize_keras_class_and_config(
                    class_name, config))
        except ValueError:
            return None

        # Use the dtype, name, and trainable status. Often times these are not
        # specified in custom configs, so retrieve their values from the metadata.
        # pylint: disable=protected-access
        obj._name = metadata['name']
        if metadata.get('trainable') is not None:
            obj.trainable = metadata['trainable']
        if metadata.get('dtype') is not None:
            obj._set_dtype_policy(metadata['dtype'])
        # pylint: enable=protected-access

        input_shape = None
        if not isinstance(obj, input_layer.InputLayer):
            input_shape = self._infer_inputs(node_id, convert_to_shapes=True)
            if input_shape is None:
                return None
        obj.build(input_shape)
        obj.built = True

        return obj