def generate_keras_metadata(saved_nodes, node_paths):
    """Builds a SavedMetadata proto with one entry per Keras layer node.

    Args:
        saved_nodes: Sequence of saved objects. An object's position in this
            sequence is recorded as its `node_id`.
        node_paths: Mapping from each node to the references leading to it.
            Every reference exposes a `name`; a falsy path denotes the root.

    Returns:
        A `saved_metadata_pb2.SavedMetadata` message holding the identifier,
        tracking metadata, dotted path and version of every node that is a
        `base_layer.Layer`. Other nodes are skipped.
    """
    result = saved_metadata_pb2.SavedMetadata()

    for index, obj in enumerate(saved_nodes):
        if not isinstance(obj, base_layer.Layer):
            continue

        refs = node_paths[obj]
        if refs:
            obj_path = "root." + ".".join([ref.name for ref in refs])
        else:
            obj_path = "root"

        result.nodes.add(
            node_id=index,
            node_path=obj_path,
            version=versions_pb2.VersionDef(
                producer=2, min_consumer=1, bad_consumers=[]),
            # pylint: disable=protected-access
            identifier=obj._object_identifier,
            metadata=obj._tracking_metadata)
        # pylint: enable=protected-access

        # Warn when a layer class shares its name with a Keras built-in layer
        # without actually being an instance of that built-in.
        cls_name = obj.__class__.__name__
        builtin_cls = serialization.get_builtin_layer(cls_name)
        if builtin_cls and not isinstance(obj, builtin_cls):
            logging.warning(
                "%s has the same name '%s' as a built-in Keras "
                "object. Consider renaming %s to avoid naming "
                "conflicts when loading with "
                "`tf.keras.models.load_model`. If renaming is not possible, pass "
                "the object in the `custom_objects` parameter of the load "
                "function.", obj, cls_name, obj.__class__)

    return result
def testAddFullSaveSpec(self):
    """A producer-1 model node exposes `full_save_spec` after updating.

    Encodes a single `save_spec` into the node metadata, runs the metadata
    through `keras_load._update_to_current_version`, and checks that the
    decoded result carries the spec as an `(args, kwargs)` pair.
    """
    spec = tf.TensorSpec([3, 5], dtype=tf.int32)
    encoded = json_utils.Encoder().encode({'save_spec': spec})

    legacy_version = versions_pb2.VersionDef(
        producer=1, min_consumer=1, bad_consumers=[])
    legacy = saved_metadata_pb2.SavedMetadata()
    legacy.nodes.add(
        version=legacy_version,
        identifier='_tf_keras_model',
        metadata=encoded)

    # pylint: disable-next=protected-access
    updated = keras_load._update_to_current_version(legacy)
    decoded = json_utils.decode(updated.nodes[0].metadata)

    want = ([tf.TensorSpec(shape=(3, 5), dtype=tf.int32)], {})
    self.assertAllEqual(want, decoded.get('full_save_spec'))
def _read_legacy_metadata(object_graph_def, metadata):
    """Builds a KerasMetadata proto from the SavedModel ObjectGraphDef.

    Older SavedModels store the metadata directly in the object graph proto
    instead of the separate pb file. Every `user_object` node whose identifier
    is listed in `constants.KERAS_OBJECT_IDENTIFIERS` is copied into
    `metadata`.

    Args:
        object_graph_def: Object graph proto whose `nodes` are scanned.
        metadata: Metadata proto that receives the entries. It is modified in
            place; nothing is returned.

    Raises:
        ValueError: If a Keras object node carries no metadata, so a Keras
            model cannot be rebuilt from it.
    """
    node_paths = _generate_object_paths(object_graph_def)

    for node_id, proto in enumerate(object_graph_def.nodes):
        if proto.WhichOneof('kind') != 'user_object':
            continue
        user_object = proto.user_object
        if user_object.identifier not in constants.KERAS_OBJECT_IDENTIFIERS:
            continue

        # Fail early with an actionable message rather than recording an
        # empty metadata string for a Keras object.
        if not user_object.metadata:
            raise ValueError(
                'Unable to create a Keras model from this SavedModel. '
                'This SavedModel was created with '
                '`tf.saved_model.save`, and lacks the Keras metadata. '
                'Please save your Keras model by calling `model.save` '
                'or `tf.keras.models.save_model`.')

        metadata.nodes.add(
            node_id=node_id,
            node_path=node_paths[node_id],
            version=versions_pb2.VersionDef(
                producer=1, min_consumer=1, bad_consumers=[]),
            identifier=user_object.identifier,
            metadata=user_object.metadata)
def _read_legacy_metadata(object_graph_def, metadata):
    """Builds a KerasMetadata proto from the SavedModel ObjectGraphDef.

    Older SavedModels store the metadata directly in the object graph proto
    instead of the separate pb file. Every `user_object` node whose identifier
    is listed in `constants.KERAS_OBJECT_IDENTIFIERS` is copied into
    `metadata`.

    Args:
        object_graph_def: Object graph proto whose `nodes` are scanned.
        metadata: Metadata proto that receives the entries. It is modified in
            place; nothing is returned.

    Raises:
        ValueError: If a Keras object node carries no metadata, so a Keras
            model cannot be rebuilt from it.
    """
    node_paths = _generate_object_paths(object_graph_def)

    for node_id, proto in enumerate(object_graph_def.nodes):
        if (proto.WhichOneof('kind') == 'user_object' and
                proto.user_object.identifier
                in constants.KERAS_OBJECT_IDENTIFIERS):
            if not proto.user_object.metadata:
                # Adjacent literals are concatenated verbatim, so each
                # fragment ends with the space that separates it from the
                # next sentence or word.
                raise ValueError(
                    'Unable to create a Keras model from this SavedModel. '
                    'This SavedModel was created with '
                    '`tf.saved_model.save`, and lacks the Keras metadata. '
                    'Please save your Keras model by calling `model.save` '
                    'or `tf.keras.models.save_model`.')

            metadata.nodes.add(
                node_id=node_id,
                node_path=node_paths[node_id],
                version=versions_pb2.VersionDef(
                    producer=1, min_consumer=1, bad_consumers=[]),
                identifier=proto.user_object.identifier,
                metadata=proto.user_object.metadata)
def generate_keras_metadata(saved_nodes, node_paths):
    """Collects the metadata of every Keras layer into a SavedMetadata proto.

    Args:
        saved_nodes: Sequence of saved objects. An object's position in this
            sequence is recorded as its `node_id`.
        node_paths: Mapping from each node to the references leading to it.
            Every reference exposes a `name`; a falsy path denotes the root.

    Returns:
        A `saved_metadata_pb2.SavedMetadata` message with one entry for each
        node that is a `base_layer.Layer`. Other nodes are skipped.
    """
    proto = saved_metadata_pb2.SavedMetadata()

    # Lazily pair each layer with its index in `saved_nodes`.
    layers = (
        (position, candidate)
        for position, candidate in enumerate(saved_nodes)
        if isinstance(candidate, base_layer.Layer)
    )

    for position, layer in layers:
        trail = node_paths[layer]
        dotted = (
            "root.{}".format(".".join(ref.name for ref in trail))
            if trail else "root"
        )
        proto.nodes.add(
            node_id=position,
            node_path=dotted,
            version=versions_pb2.VersionDef(
                producer=2, min_consumer=1, bad_consumers=[]),
            # pylint: disable=protected-access
            identifier=layer._object_identifier,
            metadata=layer._tracking_metadata)
        # pylint: enable=protected-access

    return proto