def test_add_pruned_signature(self):
  """Checks `_add_pruned_signature` on one kept and one pruned signature.

  The base MetaGraphDef holds two signatures: one naming only
  `input_1`/`output_1`, and one whose input is `node2`, which appears in
  `removed_op_names`. The test expects the first to be copied unchanged and
  the second to be absent from the output MetaGraphDef.
  """
  keep_name = 'test_signature_keep'
  remove_name = 'test_signature_remove'

  base_meta_graph_def = meta_graph_pb2.MetaGraphDef()
  base_sig_keep = base_meta_graph_def.signature_def[keep_name]
  base_sig_keep.inputs['input_1'].name = 'input_1'
  base_sig_keep.outputs['output_1'].name = 'output_1'
  base_sig_remove = base_meta_graph_def.signature_def[remove_name]
  base_sig_remove.inputs['node2'].name = 'node2'
  base_sig_remove.outputs['output_1'].name = 'output_1'

  meta_graph_def = meta_graph_pb2.MetaGraphDef()
  removed_op_names = ['node2', 'node4', 'node5']
  for signature_name in (keep_name, remove_name):
    meta_graph_transform._add_pruned_signature(
        base_meta_graph_def, meta_graph_def, signature_name,
        removed_op_names)

  # assertIn/assertNotIn name the offending key on failure, which a bare
  # assertTrue/assertFalse over an `in` expression does not.
  self.assertIn(keep_name, meta_graph_def.signature_def)
  self.assertEqual(base_sig_keep, meta_graph_def.signature_def[keep_name])
  self.assertNotIn(remove_name, meta_graph_def.signature_def)
def testAssertProtoEqualsAny(self):
  """Tests assertProtoEquals with a protobuf.Any field.

  An "inner" MetaGraphDef is packed into the `any_info` field of an "outer"
  one. The outer message must compare equal to its text form and to itself,
  and a failing comparison must print the inner message's content.
  """
  meta_graph_def_str = (
      ' meta_info_def { meta_graph_version: "outer" any_info { '
      '[type.googleapis.com/tensorflow.MetaGraphDef] { '
      'meta_info_def { meta_graph_version: "inner" } } } } ')

  meta_graph_def_outer = meta_graph_pb2.MetaGraphDef()
  meta_graph_def_outer.meta_info_def.meta_graph_version = "outer"
  meta_graph_def_inner = meta_graph_pb2.MetaGraphDef()
  meta_graph_def_inner.meta_info_def.meta_graph_version = "inner"
  meta_graph_def_outer.meta_info_def.any_info.Pack(meta_graph_def_inner)

  self.assertProtoEquals(meta_graph_def_str, meta_graph_def_outer)
  self.assertProtoEquals(meta_graph_def_outer, meta_graph_def_outer)

  # Check if the assertion failure message contains the content of
  # the inner proto. assertRaisesRegex replaces the deprecated
  # assertRaisesRegexp alias.
  with self.assertRaisesRegex(AssertionError,
                              r'meta_graph_version: "inner"'):
    self.assertProtoEquals("", meta_graph_def_outer)
def test_add_pruned_collection_proto_in_any_list(self):
  """Prunes an `any_list` collection of asset protos by removed op names.

  Note: This also tests _is_removed_mentioned().
  """
  collection_name = 'proto_collection'
  all_names = ('node1', 'node2', 'node3', 'node4', '/a/a_1', '/b/b_1')
  surviving_names = ('node1', 'node3', '/a/a_1')

  base_meta_graph_def = meta_graph_pb2.MetaGraphDef()
  base_values = (
      base_meta_graph_def.collection_def[collection_name].any_list.value)
  base_values.extend([_make_asset_file_def_any(n) for n in all_names])

  meta_graph_def = meta_graph_pb2.MetaGraphDef()
  removed_op_names = ['node2', 'node4', 'node5', '/a', '/b/b_1']
  meta_graph_transform._add_pruned_collection(
      base_meta_graph_def, meta_graph_def, collection_name,
      removed_op_names)

  pruned = meta_graph_def.collection_def[collection_name]
  expected_protos = [_make_asset_file_def_any(n) for n in surviving_names]
  self.assertEqual(expected_protos, pruned.any_list.value[:])
def test_add_pruned_collection_node(self):
  """Prunes a `node_list` collection: removed node names must be dropped."""
  collection_name = 'node_collection'

  base_meta_graph_def = meta_graph_pb2.MetaGraphDef()
  base_nodes = (
      base_meta_graph_def.collection_def[collection_name].node_list.value)
  base_nodes.extend(['node1', 'node2', 'node3', 'node4'])

  meta_graph_def = meta_graph_pb2.MetaGraphDef()
  removed_op_names = ['node2', 'node4', 'node5']
  meta_graph_transform._add_pruned_collection(
      base_meta_graph_def, meta_graph_def, collection_name,
      removed_op_names)

  pruned = meta_graph_def.collection_def[collection_name]
  self.assertEqual(['node1', 'node3'], pruned.node_list.value)
def test_add_pruned_saver(self):
  """Expects `_add_pruned_saver` to copy the base saver_def unchanged."""
  base_meta_graph_def = meta_graph_pb2.MetaGraphDef()
  base_saver = base_meta_graph_def.saver_def
  base_saver.filename_tensor_name = 'node1'
  base_saver.save_tensor_name = 'node3'
  base_saver.restore_op_name = 'node6'

  meta_graph_def = meta_graph_pb2.MetaGraphDef()
  removed_op_names = ['node2', 'node4', 'node5']
  meta_graph_transform._add_pruned_saver(
      base_meta_graph_def, meta_graph_def, removed_op_names)

  # TODO(b/63447631): For now the saver is just copied unchanged
  self.assertEqual(base_meta_graph_def.saver_def, meta_graph_def.saver_def)
def test_add_pruned_collection_int(self):
  """An `int64_list` collection is expected to come through untouched."""
  collection_name = 'int_collection'

  base_meta_graph_def = meta_graph_pb2.MetaGraphDef()
  base_collection = base_meta_graph_def.collection_def[collection_name]
  base_collection.int64_list.value[:] = [10, 20, 30, 40]

  meta_graph_def = meta_graph_pb2.MetaGraphDef()
  removed_op_names = ['node2', 'node4', 'node5']
  meta_graph_transform._add_pruned_collection(
      base_meta_graph_def, meta_graph_def, collection_name,
      removed_op_names)

  pruned = meta_graph_def.collection_def[collection_name]
  self.assertEqual([10, 20, 30, 40], pruned.int64_list.value)
def testGetNodeMapMultiple(self):
  """get_node_map joins several sub-collections under each encoded key.

  Three parallel `bytes_list` collections are stored under the
  `my_collection/` prefix: the encoded keys, `fruits` and `animals`. The
  expected result maps each key to its fruit and animal.
  """
  meta_graph_def = meta_graph_pb2.MetaGraphDef()
  collections = meta_graph_def.collection_def

  # Materialize the encoded keys before the slice assignment rather than
  # handing the repeated field a lazy `map` object.
  collections['my_collection/%s' %
              encoding.KEY_SUFFIX].bytes_list.value[:] = [
                  encoding.encode_key(key)
                  for key in ['alpha', 'bravo', 'charlie']
              ]
  # `bytes_list` stores bytes, so the fixtures and the expectations use
  # bytes literals, matching testGetNodeMapBasic.
  collections['my_collection/fruits'].bytes_list.value[:] = [
      b'apple', b'banana', b'cherry'
  ]
  collections['my_collection/animals'].bytes_list.value[:] = [
      b'aardvark', b'badger', b'camel'
  ]

  expected = {
      'alpha': {
          'fruits': b'apple',
          'animals': b'aardvark'
      },
      'bravo': {
          'fruits': b'banana',
          'animals': b'badger'
      },
      'charlie': {
          'fruits': b'cherry',
          'animals': b'camel'
      }
  }
  self.assertDictEqual(
      expected,
      graph_ref.get_node_map(meta_graph_def, 'my_collection',
                             ['fruits', 'animals']))
def create_apply_graph(self, signature, input_tensors, name):
  """See `ModuleImpl.create_apply_graph`."""
  signature_def = self._meta_graph.signature_def.get(signature)

  # Start from the state map and layer the caller's inputs on top, so that
  # an input argument can override a tensor coming from the state-graph.
  feed_map = dict(self._state_map)
  input_map = tensor_info.build_input_map(signature_def.inputs,
                                          input_tensors)
  feed_map.update(input_map)

  # Route every fed tensor through the current control flow context (if
  # any), so the Module can be applied inside e.g. a while_loop.
  control_flow = self._graph._get_control_flow_context()  # pylint: disable=protected-access
  if control_flow:
    for key in sorted(feed_map):
      feed_map[key] = control_flow.AddValue(feed_map[key])

  # The name is only reserved, not marked as used: import_scoped_meta_graph
  # will start using it.
  absolute_scope_name = self._graph.unique_name(name, mark_as_used=False)
  relative_scope_name = absolute_scope_name.rsplit("/", 1)[-1]

  # ASSET_FILEPATHS are mostly used by the TABLE_INITIALIZERS ops, but a
  # graph may use an asset at any other time. The tensor holding the asset
  # filename is therefore always annotated on import, so that later
  # re-exports keep that semantic information and can handle it.
  import_collections = [
      tf.GraphKeys.ASSET_FILEPATHS,
      tf.GraphKeys.COND_CONTEXT,
      tf.GraphKeys.WHILE_CONTEXT,
  ]
  if self._trainable:
    import_collections.append(tf.GraphKeys.UPDATE_OPS)

  # Work on a private copy so the stored meta graph is left untouched.
  meta_graph = meta_graph_pb2.MetaGraphDef()
  meta_graph.CopyFrom(self._meta_graph)
  meta_graph_lib.filter_collections(meta_graph, import_collections)
  meta_graph_lib.prefix_shared_name_attributes(meta_graph,
                                               absolute_scope_name)

  tf.train.import_meta_graph(
      meta_graph, input_map=feed_map, import_scope=relative_scope_name)
  fix_colocation_after_import(
      input_map=feed_map, absolute_import_scope=absolute_scope_name)

  def get_tensor(name):
    # An output that is directly an input has no node inside the apply
    # scope, so the feed map is consulted before the graph.
    if name in feed_map:
      return feed_map[name]
    scoped_name = meta_graph_lib.prepend_name_scope(
        name, import_scope=absolute_scope_name)
    return self._graph.get_tensor_by_name(scoped_name)

  return tensor_info.build_output_map(signature_def.outputs, get_tensor)
def _build_meta_graph(obj, export_dir, signatures, options,
                      meta_graph_def=None):
  """Creates a MetaGraph containing the resources and functions of an object.

  Args:
    obj: The object to export; must be an instance of `base.Trackable`.
    export_dir: Export directory. In this function it is only used to locate
      the debug directory when `options.save_debug_info` is set.
    signatures: Signatures to export. When `None`, they are looked up via
      `signature_serialization.find_function_to_export`.
    options: Save options. This function reads its `namespace_whitelist`,
      `function_aliases` and `save_debug_info` attributes.
    meta_graph_def: Optional `MetaGraphDef` to fill in. A new one is created
      when the argument is falsy.

  Returns:
    A tuple `(meta_graph_def, exported_graph, object_saver, asset_info)`.

  Raises:
    AssertionError: If called while `ops.inside_function()` is true.
    ValueError: If `obj` is not a `base.Trackable`.
  """
  if ops.inside_function():
    raise AssertionError(
        "tf.saved_model.save is not supported inside a traced "
        "@tf.function. Move the call to the outer eagerly-executed "
        "context.")
  # pylint: enable=line-too-long
  if not isinstance(obj, base.Trackable):
    raise ValueError(
        "Expected a Trackable object for export, got {}.".format(obj))
  meta_graph_def = meta_graph_def or meta_graph_pb2.MetaGraphDef()
  checkpoint_graph_view = _AugmentedGraphView(obj)
  if signatures is None:
    signatures = signature_serialization.find_function_to_export(
        checkpoint_graph_view)
  signatures, wrapped_functions = (
      signature_serialization.canonicalize_signatures(signatures))
  signature_serialization.validate_saveable_view(checkpoint_graph_view)
  signature_map = signature_serialization.create_signature_map(signatures)
  # Attach the signature map to the root of the graph view under the
  # dedicated signature attribute name.
  checkpoint_graph_view.add_object(
      parent_node=checkpoint_graph_view.root,
      name_in_parent=signature_serialization.SIGNATURE_ATTRIBUTE_NAME,
      subgraph_root=signature_map)
  # Use _SaveableView to provide a frozen listing of properties and functions.
  # Note we run this twice since, while constructing the view the first time
  # there can be side effects of creating variables.
  _ = _SaveableView(checkpoint_graph_view)
  saveable_view = _SaveableView(checkpoint_graph_view, wrapped_functions)
  object_saver = util.TrackableSaver(checkpoint_graph_view)
  asset_info, exported_graph = _fill_meta_graph_def(
      meta_graph_def, saveable_view, signatures, options.namespace_whitelist)
  if options.function_aliases:
    function_aliases = meta_graph_def.meta_info_def.function_aliases
    # Record the alias under the name of every function definition found in
    # both function caches of the aliased function.
    for alias, func in options.function_aliases.items():
      for fdef in func._stateful_fn._function_cache.all_values():  # pylint: disable=protected-access
        function_aliases[fdef.name] = alias
      for fdef in func._stateless_fn._function_cache.all_values():  # pylint: disable=protected-access
        function_aliases[fdef.name] = alias
  object_graph_proto = _serialize_object_graph(saveable_view,
                                               asset_info.asset_index)
  meta_graph_def.object_graph_def.CopyFrom(object_graph_proto)
  # Save debug info, if requested.
  if options.save_debug_info:
    graph_debug_info = _export_debug_info(exported_graph)
    file_io.atomic_write_string_to_file(
        os.path.join(
            utils_impl.get_or_create_debug_dir(export_dir),
            constants.DEBUG_INFO_FILENAME_PB),
        graph_debug_info.SerializeToString(deterministic=True))
  return meta_graph_def, exported_graph, object_saver, asset_info
def main(_):
  """Builds a MetaGraphDef from the flags and prints its cost report.

  The MetaGraphDef is read directly from `--metagraphdef` when that flag is
  set. Otherwise a GraphDef is read from `--graphdef` (text format for
  ".pbtxt" files, binary otherwise), imported into the default graph, and
  exported as a MetaGraphDef with the `--fetch` operation registered in the
  "train_op" collection. If `--rewriter_config` is given, the graph is
  optimized with it before the report is generated.

  Args:
    _: Unused positional arguments.
  """
  if FLAGS.metagraphdef:
    metagraph = meta_graph_pb2.MetaGraphDef()
    # Serialized protos are binary data: read them as bytes, not text.
    with gfile.GFile(FLAGS.metagraphdef, "rb") as meta_file:
      metagraph.ParseFromString(meta_file.read())
  else:
    graph_def = graph_pb2.GraphDef()
    if FLAGS.graphdef.endswith(".pbtxt"):
      with gfile.GFile(FLAGS.graphdef, "r") as graph_file:
        text_format.Merge(graph_file.read(), graph_def)
    else:
      with gfile.GFile(FLAGS.graphdef, "rb") as graph_file:
        graph_def.ParseFromString(graph_file.read())
    importer.import_graph_def(graph_def, name="")
    graph = ops.get_default_graph()
    fetch = graph.get_operation_by_name(FLAGS.fetch)
    graph.add_to_collection("train_op", fetch)
    metagraph = saver.export_meta_graph(
        graph_def=graph.as_graph_def(), graph=graph)

  if FLAGS.rewriter_config is not None:
    rewriter_config = rewriter_config_pb2.RewriterConfig()
    text_format.Merge(FLAGS.rewriter_config, rewriter_config)
    optimized_graph = tf_optimizer.OptimizeGraph(rewriter_config, metagraph)
    metagraph.graph_def.CopyFrom(optimized_graph)

  report = cost_analyzer.GenerateCostReport(metagraph, FLAGS.per_node_report)
  print(report)
def testGetSignatureDefByKeyRegression(self):
  """Stores a regression signature and reads it back by its key."""
  signature_key = "my_regression"
  regression_input = constant_op.constant("a", name="input-1")
  regression_output = constant_op.constant("b", name="output-1")
  regression_signature = signature_def_utils.regression_signature_def(
      regression_input, regression_output)

  meta_graph_def = meta_graph_pb2.MetaGraphDef()
  self._add_to_signature_def_map(meta_graph_def,
                                 {signature_key: regression_signature})

  # Fetch the signature under the same key it was saved with.
  fetched = signature_def_contrib_utils.get_signature_def_by_key(
      meta_graph_def, signature_key)

  # The method name must be the regression method constant.
  self.assertEqual(signature_constants.REGRESS_METHOD_NAME,
                   fetched.method_name)

  # Exactly one input, bound to the "input-1" tensor.
  self.assertEqual(1, len(fetched.inputs))
  self._check_tensor_info(fetched.inputs,
                          signature_constants.REGRESS_INPUTS, "input-1:0")

  # Exactly one output, bound to the "output-1" tensor.
  self.assertEqual(1, len(fetched.outputs))
  self._check_tensor_info(fetched.outputs,
                          signature_constants.REGRESS_OUTPUTS, "output-1:0")
def _GenerateTestData(self):
  """Writes the test data directory and returns its path.

  A single run, `run1`, is created. Depending on
  `self._only_use_meta_graph` it receives either a metagraph event or a
  plain graph event for a two-node graph, one of whose nodes carries a
  2 KB attribute.

  Returns:
    temp_dir: The directory the test data is generated under.
  """
  temp_dir = tempfile.mkdtemp(prefix=self.get_temp_dir())
  self.addCleanup(shutil.rmtree, temp_dir)

  run1_path = os.path.join(temp_dir, 'run1')
  os.makedirs(run1_path)
  writer = tf.summary.FileWriter(run1_path)

  # Build a simple two-node graph.
  graph_def = tf.GraphDef()
  small_node = graph_def.node.add()
  small_node.name = 'a'
  large_node = graph_def.node.add()
  large_node.name = 'b'
  large_node.attr['very_large_attr'].s = b'a' * 2048  # 2 KB attribute

  meta_graph_def = meta_graph_pb2.MetaGraphDef(graph_def=graph_def)

  if not self._only_use_meta_graph:
    writer.add_graph(graph_def)
  else:
    writer.add_meta_graph(meta_graph_def)

  writer.flush()
  writer.close()
  return temp_dir
def read_meta_graph_file(filename):
  """Reads a file containing `MetaGraphDef` and returns the protocol buffer.

  The file content is first parsed as a binary-serialized proto. If that
  fails, it is decoded as UTF-8 and parsed as a text-format proto.

  Args:
    filename: `meta_graph_def` filename including the path.

  Returns:
    A `MetaGraphDef` protocol buffer.

  Raises:
    IOError: If the file doesn't exist, or cannot be successfully parsed.
  """
  if not file_io.file_exists(filename):
    raise IOError(f"File does not exist. Received: {filename}.")

  with file_io.FileIO(filename, "rb") as f:
    file_content = f.read()

  # First try to read it as a binary file. Any failure here is deliberately
  # swallowed because the text format is attempted next.
  meta_graph_def = meta_graph_pb2.MetaGraphDef()
  try:
    meta_graph_def.ParseFromString(file_content)
    return meta_graph_def
  except Exception:  # pylint: disable=broad-except
    pass

  # Next try to read it as a text file. A fresh message is used so that
  # nothing left behind by the failed binary parse is merged into the result.
  meta_graph_def = meta_graph_pb2.MetaGraphDef()
  try:
    text_format.Merge(file_content.decode("utf-8"), meta_graph_def)
  except (text_format.ParseError, UnicodeDecodeError) as e:
    # Content that is not valid UTF-8 is reported as the documented IOError
    # instead of leaking a UnicodeDecodeError to the caller.
    raise IOError(f"Cannot parse file {filename}: {str(e)}.") from e
  return meta_graph_def
def testGetNodeMapBasic(self):
  """get_node_map with a single sub-collection stored next to the keys."""
  meta_graph_def = meta_graph_pb2.MetaGraphDef()
  keys_collection = 'my_collection/%s' % encoding.KEY_SUFFIX
  meta_graph_def.collection_def[keys_collection].bytes_list.value[:] = map(
      encoding.encode_key, ['alpha', 'bravo', 'charlie'])
  meta_graph_def.collection_def[
      'my_collection/fruits'].bytes_list.value[:] = [
          b'apple', b'banana', b'cherry'
      ]

  expected = {
      'alpha': {'fruits': b'apple'},
      'bravo': {'fruits': b'banana'},
      'charlie': {'fruits': b'cherry'},
  }
  actual = graph_ref.get_node_map(meta_graph_def, 'my_collection',
                                  ['fruits'])
  self.assertDictEqual(expected, actual)
def testGetSignatureDefByKeyClassification(self):
  """Stores a classification signature and reads it back by its key."""
  signature_key = "my_classification"
  examples = constant_op.constant("a", name="input-1")
  classes = constant_op.constant("b", name="output-1")
  scores = constant_op.constant("c", name="output-2")
  classification_signature = (
      signature_def_utils.classification_signature_def(
          examples, classes, scores))

  meta_graph_def = meta_graph_pb2.MetaGraphDef()
  self._add_to_signature_def_map(meta_graph_def,
                                 {signature_key: classification_signature})

  # Fetch the signature under the same key it was saved with.
  fetched = signature_def_contrib_utils.get_signature_def_by_key(
      meta_graph_def, signature_key)

  # The method name must be the classification method constant.
  self.assertEqual(signature_constants.CLASSIFY_METHOD_NAME,
                   fetched.method_name)

  # Exactly one input, bound to the "input-1" tensor.
  self.assertEqual(1, len(fetched.inputs))
  self._check_tensor_info(fetched.inputs,
                          signature_constants.CLASSIFY_INPUTS, "input-1:0")

  # Two outputs: the classes and the scores.
  self.assertEqual(2, len(fetched.outputs))
  self._check_tensor_info(fetched.outputs,
                          signature_constants.CLASSIFY_OUTPUT_CLASSES,
                          "output-1:0")
  self._check_tensor_info(fetched.outputs,
                          signature_constants.CLASSIFY_OUTPUT_SCORES,
                          "output-2:0")
def _read_proto_file(path, message):
  """Parses the proto stored at `path` into `message` and returns it.

  Files whose name ends in ".pbtxt" are read as text and merged with
  `text_format`; every other file is read as bytes and parsed as a binary
  serialization.
  """
  if path.endswith(".pbtxt"):
    with gfile.GFile(path, "r") as proto_file:
      text_format.Merge(proto_file.read(), message)
  else:
    # Serialized protos are binary data: read them as bytes, not text.
    with gfile.GFile(path, "rb") as proto_file:
      message.ParseFromString(proto_file.read())
  return message


def get_metagraph():
  """Constructs and returns a MetaGraphDef from the input file.

  With `--metagraphdef` set, the MetaGraphDef is read from that file and, if
  `--fetch` is given, its "train_op" collection is replaced by the
  comma-separated fetch node names. Otherwise the GraphDef in `--graphdef`
  is imported into the default graph, each fetch operation is added to the
  "train_op" collection, and the graph is exported as a MetaGraphDef.

  Returns:
    The resulting `MetaGraphDef`.

  Raises:
    ValueError: If the graph is loaded from `--graphdef` and `--fetch` is
      not set.
  """
  if FLAGS.metagraphdef:
    metagraph = _read_proto_file(FLAGS.metagraphdef,
                                 meta_graph_pb2.MetaGraphDef())
    if FLAGS.fetch is not None:
      fetch_collection = meta_graph_pb2.CollectionDef()
      for fetch in FLAGS.fetch.split(","):
        fetch_collection.node_list.value.append(fetch)
      metagraph.collection_def["train_op"].CopyFrom(fetch_collection)
    return metagraph

  # This path previously failed with an AttributeError on `None.split`;
  # fail early with an actionable message instead.
  if FLAGS.fetch is None:
    raise ValueError(
        "--fetch must be specified when the graph is loaded from --graphdef.")

  graph_def = _read_proto_file(FLAGS.graphdef, graph_pb2.GraphDef())
  importer.import_graph_def(graph_def, name="")
  graph = ops.get_default_graph()
  for fetch in FLAGS.fetch.split(","):
    fetch_op = graph.get_operation_by_name(fetch)
    graph.add_to_collection("train_op", fetch_op)
  return saver.export_meta_graph(graph_def=graph.as_graph_def(), graph=graph)
def _assertEventsWithGraph(self, test_dir, g, has_shapes):
  """Asserts that `test_dir` holds exactly version, graph, metagraph events.

  Args:
    test_dir: Directory handed to `self._EventsReader` to read events from.
    g: Graph whose `as_graph_def(add_shapes=has_shapes)` is the expected
      graph, and from which the expected MetaGraphDef is built.
    has_shapes: Forwarded as `add_shapes` when serializing `g`.
  """
  meta_graph_def = meta_graph.create_meta_graph_def(
      graph_def=g.as_graph_def(add_shapes=has_shapes))

  rr = self._EventsReader(test_dir)

  # The first event should list the file_version.
  ev = next(rr)
  self._assertRecent(ev.wall_time)
  self.assertEqual("brain.Event:2", ev.file_version)

  # The next event should have the graph.
  ev = next(rr)
  self._assertRecent(ev.wall_time)
  self.assertEqual(0, ev.step)
  ev_graph = graph_pb2.GraphDef()
  ev_graph.ParseFromString(ev.graph_def)
  self.assertProtoEquals(g.as_graph_def(add_shapes=has_shapes), ev_graph)

  # The next event should have the metagraph.
  ev = next(rr)
  self._assertRecent(ev.wall_time)
  self.assertEqual(0, ev.step)
  ev_meta_graph = meta_graph_pb2.MetaGraphDef()
  ev_meta_graph.ParseFromString(ev.meta_graph_def)
  self.assertProtoEquals(meta_graph_def, ev_meta_graph)

  # We should be done.
  self.assertRaises(StopIteration, next, rr)
def meta_graph_transform(base_meta_graph_def,
                         input_names,
                         output_names,
                         transforms,
                         tags,
                         checkpoint_path=None):
  """Apply the Graph Transform tool to a MetaGraphDef.

  The transformed graph is placed in a new MetaGraphDef. The saver,
  collections and signature_defs of the base MetaGraphDef are then carried
  over through the `_add_pruned_*` helpers, which receive the names of the
  ops that the transforms removed.

  Args:
    base_meta_graph_def: A MetaGraphDef protocol buffer to transform.
    input_names: Names of input nodes.
    output_names: Names of output nodes.
    transforms: A list of strings naming the graph transforms to be applied
      in order. These transform names are exactly those supported by the
      Graph Transform Tool, with the addition of the 'freeze_graph'
      transform.
    tags: A list of tags with which to annotate the transformed
      MetaGraphDef.
    checkpoint_path: A path to a checkpoint to restore during freezing, if
      needed (default None).

  Returns:
    A new transformed MetaGraphDef protocol buffer.
  """
  meta_graph_def = _meta_graph_pb2.MetaGraphDef()

  initializer_names = _find_all_mandatory_retain_ops(base_meta_graph_def)
  transformed_graph_def = _do_transforms(
      base_meta_graph_def.graph_def, input_names, output_names,
      initializer_names, transforms, base_meta_graph_def.saver_def,
      checkpoint_path)
  meta_graph_def.graph_def.CopyFrom(transformed_graph_def)

  # Take over the meta info, but swap the tags for the requested ones.
  meta_info = meta_graph_def.meta_info_def
  meta_info.CopyFrom(base_meta_graph_def.meta_info_def)
  meta_info.ClearField('tags')
  for tag in tags:
    meta_info.tags.append(tag)

  # Ops that were in the base graph but are missing after the transforms.
  base_op_names = {
      compat.as_str(node.name) for node in base_meta_graph_def.graph_def.node
  }
  retained_op_names = {
      compat.as_str(node.name) for node in meta_graph_def.graph_def.node
  }
  removed_op_names = base_op_names - retained_op_names

  # Copy saver, excluding any pruned nodes
  _add_pruned_saver(base_meta_graph_def, meta_graph_def, removed_op_names)

  # Copy collections, excluding any pruned nodes
  for collection_name in base_meta_graph_def.collection_def:
    _add_pruned_collection(base_meta_graph_def, meta_graph_def,
                           collection_name, removed_op_names)

  # Copy signature_defs, excluding any pruned nodes
  for signature_name in base_meta_graph_def.signature_def:
    _add_pruned_signature(base_meta_graph_def, meta_graph_def,
                          signature_name, removed_op_names)

  return meta_graph_def
def testNoLogdirButExplicitSummaryWriter(self):
  """A Supervisor with an empty logdir writes through an explicit writer.

  The event file in `logdir` is expected to hold, in order: the file
  version, the graph, the MetaGraphDef, the three scalar summaries and a
  STOP session log.
  """
  logdir = self._test_dir("explicit_summary_writer")
  with ops.Graph().as_default():
    summary.scalar("c1", constant_op.constant(1))
    summary.scalar("c2", constant_op.constant(2))
    summary.scalar("c3", constant_op.constant(3))
    summ = summary.merge_all()
    sw = writer.FileWriter(logdir)
    sv = supervisor.Supervisor(logdir="", summary_op=None, summary_writer=sw)
    meta_graph_def = meta_graph.create_meta_graph_def()
    sess = sv.prepare_or_wait_for_session("")
    sv.summary_computed(sess, sess.run(summ))
    sess.close()
    # Wait to make sure everything is written to file before stopping.
    time.sleep(1)
    sv.stop()

  # Check the summary was written to 'logdir'
  rr = _summary_iterator(logdir)

  # The first event should list the file_version.
  ev = next(rr)
  self.assertEqual("brain.Event:2", ev.file_version)

  # The next one has the graph.
  ev = next(rr)
  ev_graph = graph_pb2.GraphDef()
  ev_graph.ParseFromString(ev.graph_def)
  self.assertProtoEquals(sess.graph.as_graph_def(add_shapes=True), ev_graph)

  # Stored MetaGraphDef
  ev = next(rr)
  ev_meta_graph = meta_graph_pb2.MetaGraphDef()
  ev_meta_graph.ParseFromString(ev.meta_graph_def)
  self.assertProtoEquals(meta_graph_def, ev_meta_graph)
  self.assertProtoEquals(
      sess.graph.as_graph_def(add_shapes=True), ev_meta_graph.graph_def)

  # The next one should have the values from the summary.
  ev = next(rr)
  self.assertProtoEquals(
      " value { tag: 'c1' simple_value: 1.0 }"
      " value { tag: 'c2' simple_value: 2.0 }"
      " value { tag: 'c3' simple_value: 3.0 } ", ev.summary)

  # The next one should be a stop message if we closed cleanly.
  ev = next(rr)
  self.assertEqual(event_pb2.SessionLog.STOP, ev.session_log.status)

  # We should be done.
  self.assertRaises(StopIteration, next, rr)
def testNoModuleAttachments(self):
  """Exporting with no attachments yields an empty map and no collection."""
  meta_graph = meta_graph_pb2.MetaGraphDef()
  with tf.Graph().as_default():
    # No calls to attach_bytes.
    saved_model_lib._export_module_attachments(meta_graph)
  actual = saved_model_lib.get_attached_bytes_map(meta_graph)
  self.assertDictEqual({}, actual)
  # Check there were no unwarranted subscript operations. assertNotIn names
  # the offending key on failure, unlike assertFalse on an `in` expression.
  self.assertNotIn(saved_model_lib.ATTACHMENT_COLLECTION_SAVED,
                   meta_graph.collection_def)
def main(unused_argv=None):
  """Logs the contents of the event file named by `FLAGS.event_file`.

  Graph, tagged run metadata and metagraph events are deserialized and
  logged in full. Summary, log message and session log events are only
  reported as not yet implemented.

  Args:
    unused_argv: Unused.

  Returns:
    -1 if no event file was specified, otherwise None.

  Raises:
    NotImplementedError: If an event carries a payload kind that is not
      handled here.
  """
  # Validate the flag before expanding it, so that a missing value reaches
  # the usage message instead of failing inside `expanduser`.
  if not FLAGS.event_file:
    msg = ('The path to an event_file must be specified. '
           'Run `inspect_summary.py --help` for usage instructions.')
    logging.error(msg)
    return -1
  event_file = os.path.expanduser(FLAGS.event_file)

  banner = '=' * 80

  def log_serialized_proto(message, serialized):
    """Parses `serialized` into `message`, logs it and a closing banner."""
    message.ParseFromString(serialized)
    logging.info(str(message))
    logging.info(banner)

  for event in tf.train.summary_iterator(event_file):
    # Yields a sequence of `tensorflow.core.util.event_pb2.Event`
    # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/util/event.proto
    what = event.WhichOneof('what')
    if what == 'file_version':
      assert event.file_version == "brain.Event:2"
    elif what == 'graph_def':
      # An encoded version of a GraphDef.
      logging.info(banner)
      logging.info('GraphDef[wall_time=%s, step=%s]', event.wall_time,
                   event.step)
      log_serialized_proto(graph_pb2.GraphDef(), event.graph_def)
    elif what == 'summary':
      logging.warning('NYI: summary')
    elif what == 'log_message':
      # The user output a log message.
      logging.warning('NYI: log_message')
    elif what == 'session_log':
      logging.warning('NYI: session_log')
    elif what == 'tagged_run_metadata':
      trn = event.tagged_run_metadata
      logging.info(banner)
      logging.info('TaggedRunMetadata [wall_time=%s, step=%s, tag=%s]',
                   event.wall_time, event.step, trn.tag)
      # The run metadata is stored as a byte buffer and decoded lazily here.
      log_serialized_proto(config_pb2.RunMetadata(), trn.run_metadata)
    elif what == 'meta_graph_def':
      # An encoded version of a MetaGraphDef.
      logging.info(banner)
      logging.info('MetaGraphDef [wall_time=%s, step=%s]', event.wall_time,
                   event.step)
      log_serialized_proto(meta_graph_pb2.MetaGraphDef(),
                           event.meta_graph_def)
    else:
      # `NotImplemented` is a comparison sentinel, not an exception; the
      # exception type to raise is NotImplementedError.
      raise NotImplementedError('Unhandled event payload: %r' % (what,))
def _convert_saved_model(self):
  """Convert the input SavedModel.

  Loads the SavedModel into a fresh graph, freezes its variables, stores
  the result in `self._grappler_meta_graph_def` together with the kept
  collections, meta info and the selected signature, and then runs the
  conversion.
  """
  signature_key = self._input_saved_model_signature_key
  graph = ops.Graph()
  with session.Session(graph=graph, config=self._session_config) as sess:
    input_meta_graph_def = loader.load(sess, self._input_saved_model_tags,
                                       self._input_saved_model_dir)
    input_signature_def = input_meta_graph_def.signature_def[signature_key]

    def _gather_names(tensor_info):
      """Get the node names from a TensorInfo."""
      return {tensor_info[key].name.split(":")[0] for key in tensor_info}

    # Inputs and outputs of the signature are all kept as output nodes.
    output_node_names = (
        _gather_names(input_signature_def.inputs)
        | _gather_names(input_signature_def.outputs))

    # Operations referenced by the kept collections are preserved as well.
    for collection_key in self._collections_to_keep(
        input_meta_graph_def.collection_def):
      for op in sess.graph.get_collection(collection_key):
        if not isinstance(op, ops.Operation):
          continue
        output_node_names.add(op.name.split(":")[0])

    # Freeze the variables in the SavedModel graph and copy the frozen
    # graph over.
    frozen_graph_def = graph_util.convert_variables_to_constants(
        sess, sess.graph.as_graph_def(add_shapes=True),
        list(output_node_names))
    self._grappler_meta_graph_def = meta_graph_pb2.MetaGraphDef()
    self._grappler_meta_graph_def.graph_def.CopyFrom(frozen_graph_def)

    # Copy the collections that are not variables.
    for collection_key in self._collections_to_keep(
        input_meta_graph_def.collection_def):
      source_collection = input_meta_graph_def.collection_def[collection_key]
      self._grappler_meta_graph_def.collection_def[collection_key].CopyFrom(
          source_collection)

    self._add_nodes_blacklist()

    # Copy other information.
    self._grappler_meta_graph_def.meta_info_def.CopyFrom(
        input_meta_graph_def.meta_info_def)
    self._grappler_meta_graph_def.signature_def[signature_key].CopyFrom(
        input_signature_def)
    # TODO(laigd): maybe add back AssetFileDef.

  self._run_conversion()
def testGetNodeInGraph(self):
  """A tensor encoded into a collection is found again in its graph."""
  graph = tf.Graph()
  with graph.as_default():
    apple = tf.constant(1.0)

  meta_graph_def = meta_graph_pb2.MetaGraphDef()
  encoded_apple = encoding.encode_tensor_node(apple)
  meta_graph_def.collection_def['fruit_node'].any_list.value.extend(
      [encoded_apple])

  found = graph_ref.get_node_in_graph(meta_graph_def, 'fruit_node', graph)
  self.assertEqual(apple, found)
def testModuleAttachments(self):
  """Attached bytes are exported; the last value attached per key wins."""
  meta_graph = meta_graph_pb2.MetaGraphDef()
  # "key1" is attached twice on purpose: the later value is expected to
  # replace the earlier one.
  attachments = (
      ("key1", "oops"),
      ("key2", "value2"),
      ("key1", "value1"),
  )
  with tf.Graph().as_default():
    for key, value in attachments:
      saved_model_lib.attach_bytes(key, tf.compat.as_bytes(value))
    saved_model_lib._export_module_attachments(meta_graph)

  actual = saved_model_lib.get_attached_bytes_map(meta_graph)
  expected = {
      "key1": tf.compat.as_bytes("value1"),
      "key2": tf.compat.as_bytes("value2"),
  }
  self.assertDictEqual(expected, actual)
def test_add_pruned_collection_proto_in_bytes_list(self):
  """Prunes a `bytes_list` collection holding stringified asset protos."""

  def _asset_bytes(node_name):
    # Bytes form of the asset proto that mentions `node_name`.
    return compat.as_bytes(
        compat.as_str_any(_make_asset_file_def_any(node_name)))

  collection_name = 'proto_collection'

  base_meta_graph_def = meta_graph_pb2.MetaGraphDef()
  base_values = (
      base_meta_graph_def.collection_def[collection_name].bytes_list.value)
  base_values.extend(
      [_asset_bytes(n) for n in ('node1', 'node2', 'node3', 'node4')])

  meta_graph_def = meta_graph_pb2.MetaGraphDef()
  removed_op_names = ['node2', 'node4', 'node5']
  meta_graph_transform._add_pruned_collection(
      base_meta_graph_def, meta_graph_def, collection_name,
      removed_op_names)

  pruned = meta_graph_def.collection_def[collection_name]
  expected_values = [_asset_bytes(n) for n in ('node1', 'node3')]
  self.assertEqual(expected_values, pruned.bytes_list.value[:])
def testGetSignatureDefByKey(self):
  """Looks up missing and present keys in the SignatureDef map.

  Two signatures, "foo" and "bar", are stored. A missing key must raise
  ValueError, and each stored signature must come back with its own method
  name, inputs and outputs.
  """
  x = array_ops.placeholder(dtypes.float32, 1, name="x")
  x_tensor_info = utils.build_tensor_info(x)

  y = array_ops.placeholder(dtypes.float32, name="y")
  y_tensor_info = utils.build_tensor_info(y)

  foo_signature_def = signature_def_utils.build_signature_def(
      {"foo-input": x_tensor_info}, {"foo-output": y_tensor_info},
      "foo-method-name")
  bar_signature_def = signature_def_utils.build_signature_def(
      {"bar-input": x_tensor_info}, {"bar-output": y_tensor_info},
      "bar-method-name")
  meta_graph_def = meta_graph_pb2.MetaGraphDef()
  self._add_to_signature_def_map(meta_graph_def, {
      "foo": foo_signature_def,
      "bar": bar_signature_def
  })

  # Look up a key that does not exist in the SignatureDefMap.
  missing_key = "missing-key"
  with self.assertRaisesRegex(
      ValueError,
      "No SignatureDef with key '%s' found in MetaGraphDef" % missing_key):
    signature_def_contrib_utils.get_signature_def_by_key(
        meta_graph_def, missing_key)

  # Look up the key, `foo` which exists in the SignatureDefMap.
  foo_signature_def = signature_def_contrib_utils.get_signature_def_by_key(
      meta_graph_def, "foo")
  # assertEqual actually compares the method name; assertTrue with two
  # arguments treats the second as the failure message and always passes.
  self.assertEqual("foo-method-name", foo_signature_def.method_name)

  # Check inputs in signature def.
  self.assertEqual(1, len(foo_signature_def.inputs))
  self._check_tensor_info(foo_signature_def.inputs, "foo-input", "x:0")

  # Check outputs in signature def.
  self.assertEqual(1, len(foo_signature_def.outputs))
  self._check_tensor_info(foo_signature_def.outputs, "foo-output", "y:0")

  # Look up the key, `bar` which exists in the SignatureDefMap.
  bar_signature_def = signature_def_contrib_utils.get_signature_def_by_key(
      meta_graph_def, "bar")
  self.assertEqual("bar-method-name", bar_signature_def.method_name)

  # Check inputs in signature def.
  self.assertEqual(1, len(bar_signature_def.inputs))
  self._check_tensor_info(bar_signature_def.inputs, "bar-input", "x:0")

  # Check outputs in signature def.
  self.assertEqual(1, len(bar_signature_def.outputs))
  self._check_tensor_info(bar_signature_def.outputs, "bar-output", "y:0")
def _convert_saved_model(self):
  """Convert the input SavedModel.

  Loads the SavedModel into a fresh graph, freezes its variables, stores
  the result in `self._grappler_meta_graph_def` together with the
  collections whose key is not in a fixed exclusion list, the meta info and
  the selected signature, and then runs the conversion.
  """
  signature_key = self._input_saved_model_signature_key
  # TODO(laigd): currently we use the collection key to filter out
  # collections that depend on variable ops, but this may miss some
  # other user-defined collections. A better way would be to use
  # CollectionDef::NodeList for the filtering.
  skipped_collection_keys = (
      "variables", "local_variables", "model_variables",
      "trainable_variables", "train_op", "table_initializer"
  )

  graph = ops.Graph()
  with session.Session(graph=graph, config=self._session_config) as sess:
    input_meta_graph_def = loader.load(sess, self._input_saved_model_tags,
                                       self._input_saved_model_dir)
    input_signature_def = input_meta_graph_def.signature_def[signature_key]

    def _gather_names(tensor_info):
      """Get the node names from a TensorInfo."""
      return {tensor_info[key].name.split(":")[0] for key in tensor_info}

    # Inputs and outputs of the signature are all kept as output nodes.
    output_node_names = (
        _gather_names(input_signature_def.inputs)
        | _gather_names(input_signature_def.outputs))

    # Freeze the variables in the SavedModel graph and copy the frozen
    # graph over.
    frozen_graph_def = graph_util.convert_variables_to_constants(
        sess, sess.graph.as_graph_def(add_shapes=True),
        list(output_node_names))
    self._grappler_meta_graph_def = meta_graph_pb2.MetaGraphDef()
    self._grappler_meta_graph_def.graph_def.CopyFrom(frozen_graph_def)

    # Copy the collections that are not variables.
    for key in input_meta_graph_def.collection_def:
      if key in skipped_collection_keys:
        continue
      self._grappler_meta_graph_def.collection_def[key].CopyFrom(
          input_meta_graph_def.collection_def[key])

    self._add_nodes_blacklist()

    # Copy other information.
    self._grappler_meta_graph_def.meta_info_def.CopyFrom(
        input_meta_graph_def.meta_info_def)
    self._grappler_meta_graph_def.signature_def[signature_key].CopyFrom(
        input_signature_def)
    # TODO(laigd): maybe add back AssetFileDef.

  self._run_conversion()
def testChiefCanWriteEvents(self):
  """A chief Supervisor writes graph, metagraph and summary events.

  The event file in `logdir` is expected to hold, in order: the file
  version, the graph, the MetaGraphDef, the three scalar summaries and a
  STOP session log.
  """
  logdir = _test_dir("can_write")
  with tf.Graph().as_default():
    tf.summary.scalar("c1", tf.constant(1))
    tf.summary.scalar("c2", tf.constant(2))
    tf.summary.scalar("c3", tf.constant(3))
    summ = tf.summary.merge_all()
    sv = tf.train.Supervisor(is_chief=True, logdir=logdir, summary_op=None)
    meta_graph_def = meta_graph.create_meta_graph_def()
    sess = sv.prepare_or_wait_for_session("")
    sv.summary_computed(sess, sess.run(summ))
    sess.close()
    # Wait to make sure everything is written to file before stopping.
    time.sleep(1)
    sv.stop()

  rr = _summary_iterator(logdir)

  # The first event should list the file_version.
  ev = next(rr)
  self.assertEqual("brain.Event:2", ev.file_version)

  # The next one has the graph.
  ev = next(rr)
  ev_graph = tf.GraphDef()
  ev_graph.ParseFromString(ev.graph_def)
  self.assertProtoEquals(sess.graph.as_graph_def(add_shapes=True), ev_graph)

  # Stored MetaGraphDef
  ev = next(rr)
  ev_meta_graph = meta_graph_pb2.MetaGraphDef()
  ev_meta_graph.ParseFromString(ev.meta_graph_def)
  self.assertProtoEquals(meta_graph_def, ev_meta_graph)
  self.assertProtoEquals(
      sess.graph.as_graph_def(add_shapes=True), ev_meta_graph.graph_def)

  # The next one should have the values from the summary.
  ev = next(rr)
  self.assertProtoEquals(
      " value { tag: 'c1' simple_value: 1.0 }"
      " value { tag: 'c2' simple_value: 2.0 }"
      " value { tag: 'c3' simple_value: 3.0 } ", ev.summary)

  # The next one should be a stop message if we closed cleanly.
  ev = next(rr)
  self.assertEqual(tf.SessionLog.STOP, ev.session_log.status)

  # We should be done.
  self.assertRaises(StopIteration, next, rr)
def main(_):
  """Reads a serialized MetaGraphDef and prints its cost report.

  The MetaGraphDef is parsed from the binary file named by `--input`. If
  `--rewriter_config` is given, the graph is optimized with it before the
  report is generated.

  Args:
    _: Unused positional arguments.
  """
  metagraph = meta_graph_pb2.MetaGraphDef()
  # Serialized protos are binary data: read them as bytes, not text.
  with gfile.GFile(FLAGS.input, "rb") as input_file:
    metagraph.ParseFromString(input_file.read())

  if FLAGS.rewriter_config is not None:
    rewriter_config = rewriter_config_pb2.RewriterConfig()
    text_format.Merge(FLAGS.rewriter_config, rewriter_config)
    optimized_graph = tf_optimizer.OptimizeGraph(rewriter_config, metagraph)
    metagraph.graph_def.CopyFrom(optimized_graph)

  report = cost_analyzer.GenerateCostReport(metagraph, FLAGS.per_node_report)
  print(report)
def testStandardServicesWithoutGlobalStep(self):
  """Supervisor services still write events and a checkpoint.

  The graph under test contains a variable `foo` and a scalar summary but
  no global step. The event file is expected to hold, in order: the file
  version, the graph, the MetaGraphDef, the summary of `v` and a STOP
  session log. The checkpoint must restore `foo` to 1.0.
  """
  logdir = _test_dir("standard_services_without_global_step")
  # Create a checkpoint.
  with tf.Graph().as_default():
    v = tf.Variable([1.0], name="foo")
    tf.scalar_summary(["v"], v)
    sv = tf.train.Supervisor(logdir=logdir)
    meta_graph_def = meta_graph.create_meta_graph_def(
        saver_def=sv.saver.saver_def)
    sess = sv.prepare_or_wait_for_session("")
    save_path = sv.save_path
    self._wait_for_glob(save_path, 3.0)
    self._wait_for_glob(
        os.path.join(logdir, "*events*"), 3.0, for_checkpoint=False)
    # Wait to make sure everything is written to file before stopping.
    time.sleep(1)
    sv.stop()

  # There should be an event file with a version number.
  rr = _summary_iterator(logdir)
  ev = next(rr)
  self.assertEqual("brain.Event:2", ev.file_version)

  ev = next(rr)
  ev_graph = tf.GraphDef()
  ev_graph.ParseFromString(ev.graph_def)
  self.assertProtoEquals(sess.graph.as_graph_def(add_shapes=True), ev_graph)

  # Stored MetaGraphDef
  ev = next(rr)
  ev_meta_graph = meta_graph_pb2.MetaGraphDef()
  ev_meta_graph.ParseFromString(ev.meta_graph_def)
  self.assertProtoEquals(meta_graph_def, ev_meta_graph)
  self.assertProtoEquals(
      sess.graph.as_graph_def(add_shapes=True), ev_meta_graph.graph_def)

  ev = next(rr)
  self.assertProtoEquals("value { tag: 'v' simple_value: 1.0 }", ev.summary)

  ev = next(rr)
  self.assertEqual(tf.SessionLog.STOP, ev.session_log.status)

  self.assertRaises(StopIteration, next, rr)

  # There should be a checkpoint file with the variable "foo". The variable
  # is created with a different initial value, so reading 1.0 back shows
  # that the value came from the checkpoint.
  with tf.Graph().as_default(), self.test_session() as sess:
    v = tf.Variable([10.10], name="foo")
    sav = tf.train.Saver([v])
    sav.restore(sess, save_path)
    self.assertEqual(1.0, v.eval()[0])