예제 #1
0
def generate_keras_metadata(saved_nodes, node_paths):
  """Collects per-layer Keras metadata into a single SavedMetadata proto.

  Every entry of `saved_nodes` that is a `base_layer.Layer` contributes one
  node record; all other objects are skipped. A warning is logged when a
  layer's class name matches a built-in Keras layer that the layer is not an
  instance of.

  Args:
    saved_nodes: Iterable of saved objects. The enumeration index of each
      object is recorded as its `node_id`.
    node_paths: Mapping from a saved object to its sequence of references from
      the root object; each reference must expose a `.name` attribute.

  Returns:
    A `saved_metadata_pb2.SavedMetadata` proto with one node per layer.
  """
  result = saved_metadata_pb2.SavedMetadata()

  for idx, obj in enumerate(saved_nodes):
    if not isinstance(obj, base_layer.Layer):
      continue

    # An empty reference path maps to "root"; otherwise the reference names
    # are dotted onto it: "root.<name>.<name>...".
    refs = node_paths[obj]
    obj_path = ("root.{}".format(".".join([ref.name for ref in refs]))
                if refs else "root")

    result.nodes.add(
        node_id=idx,
        node_path=obj_path,
        version=versions_pb2.VersionDef(
            producer=2, min_consumer=1, bad_consumers=[]),
        identifier=obj._object_identifier,  # pylint: disable=protected-access
        metadata=obj._tracking_metadata)  # pylint: disable=protected-access

    # A custom class that reuses a built-in Keras class name can conflict with
    # the built-in when loading, so point the user at renaming or at
    # `custom_objects`.
    cls = obj.__class__
    builtin_cls = serialization.get_builtin_layer(cls.__name__)
    if builtin_cls and not isinstance(obj, builtin_cls):
      logging.warning(
          "%s has the same name '%s' as a built-in Keras "
          "object. Consider renaming %s to avoid naming "
          "conflicts when loading with "
          "`tf.keras.models.load_model`. If renaming is not possible, pass "
          "the object in the `custom_objects` parameter of the load "
          "function.", obj, cls.__name__, cls)

  return result
예제 #2
0
  def testAddFullSaveSpec(self):
    """Upgrading producer-1 metadata derives `full_save_spec` from `save_spec`.

    A node whose metadata only contains `save_spec` should, after
    `_update_to_current_version`, also carry `full_save_spec` equal to
    `([save_spec], {})` (positional specs, keyword specs).
    """
    save_spec = tf.TensorSpec([3, 5], dtype=tf.int32)
    node_metadata = json_utils.Encoder().encode({'save_spec': save_spec})

    # Build a legacy (producer=1) metadata proto with a single model node.
    metadata = saved_metadata_pb2.SavedMetadata()
    metadata.nodes.add(
        version=versions_pb2.VersionDef(
            producer=1, min_consumer=1, bad_consumers=[]),
        identifier='_tf_keras_model',
        metadata=node_metadata)

    new_metadata = keras_load._update_to_current_version(metadata)  # pylint: disable=protected-access
    node_metadata = json_utils.decode(new_metadata.nodes[0].metadata)
    expected_full_spec = ([tf.TensorSpec(shape=(3, 5), dtype=tf.int32)], {})
    self.assertAllEqual(expected_full_spec, node_metadata.get('full_save_spec'))
예제 #3
0
def _read_legacy_metadata(object_graph_def, metadata):
  """Builds a KerasMetadata proto from the SavedModel ObjectGraphDef.

  Older SavedModels store the Keras metadata directly in the object graph
  proto instead of in a separate pb file. This walks the graph and appends a
  node record to `metadata` for every user object whose identifier is a Keras
  object identifier.

  Args:
    object_graph_def: SavedModel object graph whose `nodes` are scanned; the
      enumeration index of each node is used as its `node_id`.
    metadata: SavedMetadata proto that is appended to in place.

  Raises:
    ValueError: If a Keras user object has empty metadata, i.e. the SavedModel
      was written by `tf.saved_model.save` rather than a Keras save API.
  """
  # Paths are looked up by node id below.
  node_paths = _generate_object_paths(object_graph_def)
  for node_id, proto in enumerate(object_graph_def.nodes):
    if (proto.WhichOneof('kind') == 'user_object' and
        proto.user_object.identifier in constants.KERAS_OBJECT_IDENTIFIERS):
      # Without metadata the Keras object cannot be reconstructed; fail early
      # with an actionable message instead of recording an empty node.
      if not proto.user_object.metadata:
        raise ValueError('Unable to create a Keras model from this SavedModel. '
                         'This SavedModel was created with '
                         '`tf.saved_model.save`, and lacks the Keras metadata. '
                         'Please save your Keras model by calling `model.save` '
                         'or `tf.keras.models.save_model`.')
      metadata.nodes.add(
          node_id=node_id,
          node_path=node_paths[node_id],
          version=versions_pb2.VersionDef(
              producer=1, min_consumer=1, bad_consumers=[]),
          identifier=proto.user_object.identifier,
          metadata=proto.user_object.metadata)
예제 #4
0
def _read_legacy_metadata(object_graph_def, metadata):
  """Builds a KerasMetadata proto from the SavedModel ObjectGraphDef.

  Older SavedModels store the Keras metadata directly in the object graph
  proto instead of in a separate pb file. This walks the graph and appends a
  node record to `metadata` for every user object whose identifier is a Keras
  object identifier.

  Args:
    object_graph_def: SavedModel object graph whose `nodes` are scanned; the
      enumeration index of each node is used as its `node_id`.
    metadata: SavedMetadata proto that is appended to in place.

  Raises:
    ValueError: If a Keras user object has empty metadata, i.e. the SavedModel
      was written by `tf.saved_model.save` rather than a Keras save API.
  """
  # Paths are looked up by node id below.
  node_paths = _generate_object_paths(object_graph_def)
  for node_id, proto in enumerate(object_graph_def.nodes):
    if (proto.WhichOneof('kind') == 'user_object' and
        proto.user_object.identifier in constants.KERAS_OBJECT_IDENTIFIERS):
      # Without metadata the Keras object cannot be reconstructed; fail early
      # with an actionable message instead of recording an empty node.
      if not proto.user_object.metadata:
        # Adjacent literals are concatenated, so each piece must end with a
        # space where words meet.
        raise ValueError('Unable to create a Keras model from this SavedModel. '
                         'This SavedModel was created with '
                         '`tf.saved_model.save`, and lacks the Keras metadata. '
                         'Please save your Keras model by calling `model.save` '
                         'or `tf.keras.models.save_model`.')
      metadata.nodes.add(
          node_id=node_id,
          node_path=node_paths[node_id],
          version=versions_pb2.VersionDef(
              producer=1, min_consumer=1, bad_consumers=[]),
          identifier=proto.user_object.identifier,
          metadata=proto.user_object.metadata)
예제 #5
0
def generate_keras_metadata(saved_nodes, node_paths):
  """Collects per-layer Keras metadata into a single SavedMetadata proto.

  Only entries of `saved_nodes` that are `base_layer.Layer` instances produce
  a node record; everything else is ignored.

  Args:
    saved_nodes: Iterable of saved objects. The enumeration index of each
      object is recorded as its `node_id`.
    node_paths: Mapping from a saved object to its sequence of references from
      the root object; each reference must expose a `.name` attribute.

  Returns:
    A `saved_metadata_pb2.SavedMetadata` proto with one node per layer.
  """
  metadata_proto = saved_metadata_pb2.SavedMetadata()

  layer_entries = ((index, candidate)
                   for index, candidate in enumerate(saved_nodes)
                   if isinstance(candidate, base_layer.Layer))

  for node_id, layer in layer_entries:
    # An empty reference path maps to "root"; otherwise the reference names
    # are dotted onto it: "root.<name>.<name>...".
    path_refs = node_paths[layer]
    dotted_path = "root"
    if path_refs:
      dotted_path = "root.{}".format(".".join(r.name for r in path_refs))

    version_def = versions_pb2.VersionDef(
        producer=2, min_consumer=1, bad_consumers=[])
    metadata_proto.nodes.add(
        node_id=node_id,
        node_path=dotted_path,
        version=version_def,
        identifier=layer._object_identifier,  # pylint: disable=protected-access
        metadata=layer._tracking_metadata)  # pylint: disable=protected-access

  return metadata_proto