def _generate_test_data(self, run_name, experiment_name):
    """Generates the test data for one run, both as event files and in the DB.

    Writes to two places:
      - an event file under ``self.logdir/run_name`` containing either a
        metagraph definition (when ``self._only_use_meta_graph`` is set) or a
        plain graph definition. The graph has two nodes, 'a' and 'b', where
        'b' carries a 2 KB string attribute.
      - the database at ``self.db_path``, which receives a single scalar
        summary 'mytag' with value 1 for this experiment/run.

    Arguments:
      run_name: The directory under self.logdir into which to write events,
          and the run name used for the database writer.
      experiment_name: Experiment name passed to the database writer.
    """
    run_path = os.path.join(self.logdir, run_name)
    writer = tf.summary.FileWriter(run_path)

    # Add a simple graph event.
    graph_def = tf.GraphDef()
    node1 = graph_def.node.add()
    node1.name = 'a'
    node2 = graph_def.node.add()
    node2.name = 'b'
    node2.attr['very_large_attr'].s = b'a' * 2048  # 2 KB attribute

    meta_graph_def = tf.MetaGraphDef(graph_def=graph_def)

    # Exactly one of the two graph forms is written, depending on the fixture mode.
    if self._only_use_meta_graph:
        writer.add_meta_graph(meta_graph_def)
    else:
        writer.add_graph(graph_def)

    writer.flush()
    writer.close()

    # Write data for the run to the database.
    # TODO(nickfelt): Figure out why resetting the graph is necessary.
    tf.reset_default_graph()
    db_writer = tf.contrib.summary.create_db_writer(
        db_uri=self.db_path,
        experiment_name=experiment_name,
        run_name=run_name,
        user_name='user')
    # Only builds the summary op in the default graph; it runs in the session below.
    with db_writer.as_default(), tf.contrib.summary.always_record_summaries():
        tf.contrib.summary.scalar('mytag', 1)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.contrib.summary.summary_writer_initializer_op())
        sess.run(tf.contrib.summary.all_summary_ops())
def _load_graph_using_meta(self, model):
    """Restore a checkpoint described by a ``.meta`` file and return it frozen.

    The meta graph at ``model`` is imported into the (freshly reset) default
    graph. Variables are restored from the checkpoint whose prefix is
    ``model`` minus a trailing ``.meta``. They are then folded into constants,
    keeping only the subgraph needed for ``self._config_outputs``. The frozen
    GraphDef is imported into a new ``tf.Graph``, which is returned.
    """
    tf.reset_default_graph()
    frozen_graph = tf.Graph()

    meta_graph_def = tf.MetaGraphDef()
    with open(model, "rb") as meta_file:
        meta_graph_def.ParseFromString(meta_file.read())

    checkpoint_prefix = re.sub(r'\.meta$', '', model)
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph(meta_graph_def)
        saver.restore(sess, checkpoint_prefix)
        # Must run while the session is open: variable values are read from it.
        constant_graph_def = tf.graph_util.convert_variables_to_constants(
            sess, meta_graph_def.graph_def, self._config_outputs
        )

    with frozen_graph.as_default():
        tf.import_graph_def(constant_graph_def, name='')
    return frozen_graph
def main():
    """Parse several serialized TensorFlow protos from fixed /tmp paths and print them.

    Debugging helper that checks three files parse:
      * a checkpoint ``.meta`` file, both via ``meta_graph.read_meta_graph_file``
        and as a raw ``MetaGraphDef``;
      * a binary ``GraphDef`` (printed in full);
      * a ``saved_model.pb`` (schema version, meta-graph count and the first
        meta graph's GraphDef are printed).

    Every file is opened with a context manager so its handle is closed
    promptly; the original left all three handles open.
    """
    meta_path = "/tmp/tf_training/ckpt/rjmodel.ckpt.meta"
    graph_path = "/tmp/stylize_quantized.pb"
    saved_model_path = "/tmp/SavedModel/saved_model.pb"

    # Parse via TF's own reader; the returned proto is not used here.
    meta_graph.read_meta_graph_file(meta_path)

    g = tf.MetaGraphDef()
    with open(meta_path, "rb") as f:
        g.ParseFromString(f.read())

    g = tf.GraphDef()
    with open(graph_path, "rb") as f:
        g.ParseFromString(f.read())
    print("GraphDef: ", g)
    # To check particular nodes, e.g.:
    #   [n for n in g.node if n.name.find("input") != -1]
    # (same for output or any other node you want to make sure is ok)

    saved_model = saved_model_pb2.SavedModel()
    with open(saved_model_path, "rb") as f:
        saved_model.ParseFromString(f.read())
    print("saved_model parsed", saved_model.saved_model_schema_version,
          len(saved_model.meta_graphs))
    print("GraphDef from SavedModel file:", saved_model.meta_graphs[0].graph_def)
def predict_func(rows, graph_json, prediction, graph_weights, inp, activation,
                 tf_input, tf_dropout=None, to_keep_dropout=False):
    """Run a serialized TF model over a partition of Spark rows.

    Args:
        rows: iterable of Spark ``Row`` objects (each must support ``asDict``).
        graph_json: JSON-encoded ``MetaGraphDef`` of the model.
        prediction: name of the column to write each prediction into.
        graph_weights: JSON list of weight arrays, applied with
            ``tensorflow_set_weights`` after variable initialization.
        inp: name of the input column whose values are fed to ``tf_input``.
        activation: name of the output tensor to evaluate.
        tf_input: name of the input placeholder tensor.
        tf_dropout: optional name of a dropout placeholder.
        to_keep_dropout: when ``tf_dropout`` is given, feed 1.0 if true,
            otherwise 0.0.

    Returns:
        A list of ``Row`` objects with ``prediction`` added. A prediction that
        can be converted with ``float`` is stored as a float; otherwise it is
        stored as ``Vectors.dense``. Returns ``[]`` for an empty partition.
    """
    rows = [r.asDict() for r in rows]
    if not rows:
        return []

    graph = json_format.Parse(graph_json, tf.MetaGraphDef())
    loaded_weights = [np.asarray(x) for x in json.loads(graph_weights)]
    A = [np.asarray(row[inp]) for row in rows]

    new_graph = tf.Graph()
    with tf.Session(graph=new_graph) as sess:
        tf.train.import_meta_graph(graph)
        sess.run(tf.global_variables_initializer())
        tensorflow_set_weights(loaded_weights)
        out_node = tf.get_default_graph().get_tensor_by_name(activation)
        if tf_dropout is None:
            feed_dict = {tf_input: A}
        else:
            dropout_v = 1.0 if to_keep_dropout else 0.0
            feed_dict = {tf_input: A, tf_dropout: dropout_v}
        pred = sess.run(out_node, feed_dict=feed_dict)

    for i, row in enumerate(rows):
        try:
            # Scalar outputs become floats (Vectors.dense is handled
            # differently in Python 3).
            row[prediction] = float(pred[i])
        except (TypeError, ValueError):
            # Multi-element outputs cannot be converted to a float.
            row[prediction] = Vectors.dense(pred[i])
    return [Row(**a) for a in rows]
def _tf_meta_graph_def_to_lgf_meta_graph_info(self, meta_graph_def):
    """Build an ``lgf_pb2.MetaGraphInfo`` summarizing ``meta_graph_def``.

    The info stores a serialized "partial" copy of ``meta_graph_def`` under
    ``ImportTFSavedModelBase.PARTIAL_META_GRAPH_DEF``. The copy holds the
    saver_def, all collection_defs and signature_defs, and the graph_def with
    every node removed. Each string found in that partial copy, including
    strings inside proto-typed ``bytes_list`` collections, is interpreted as a
    tensor/node reference. Names that match a node in ``self._graph_def`` are
    added to ``self._required_nodes``.

    Args:
        meta_graph_def: the TF MetaGraphDef to summarize.

    Returns:
        The MetaGraphInfo. Its ``required_nodes`` is extended with the whole
        contents of ``self._required_nodes``, including entries added before
        this call.
    """
    # Store the serialized partial meta graph def
    partial_meta_graph_def = tf.MetaGraphDef()
    partial_meta_graph_def.saver_def.CopyFrom(meta_graph_def.saver_def)
    for k, v in meta_graph_def.collection_def.items():
        partial_meta_graph_def.collection_def[k].CopyFrom(v)
    for k, v in meta_graph_def.signature_def.items():
        partial_meta_graph_def.signature_def[k].CopyFrom(v)
    # Keep graph-level fields (versions, library, ...) but drop the nodes.
    partial_meta_graph_def.graph_def.CopyFrom(meta_graph_def.graph_def)
    del partial_meta_graph_def.graph_def.node[:]

    meta_graph_info = lgf_pb2.MetaGraphInfo()
    meta_graph_info.original_graph_info[
        ImportTFSavedModelBase.PARTIAL_META_GRAPH_DEF].v = (
            partial_meta_graph_def.SerializeToString())

    # Get all strings from the partial meta graph def
    proto_strings = self.get_strings_from_proto(partial_meta_graph_def)

    # Need to manually deserialize the collection defs
    for k, collection_def in partial_meta_graph_def.collection_def.items():
        if not collection_def.HasField("bytes_list"):
            continue
        proto_type = tf_ops.get_collection_proto_type(k)
        if proto_type is None:
            # No proto registered for this collection: its entries are raw
            # bytes, not serialized protos (TF's own meta-graph import treats
            # them the same way). Calling proto_type() here would raise
            # TypeError.
            continue
        for serialized_proto in collection_def.bytes_list.value:
            proto = proto_type()
            proto.ParseFromString(serialized_proto)
            proto_strings.extend(self.get_strings_from_proto(proto))

    # Keep only strings that name an existing node in the graph.
    node_names = {n.name for n in self._graph_def.node}
    for string in proto_strings:
        if string != "":
            name = self.get_node_name_and_output_index(string)[0]
            if name in node_names:
                self._required_nodes.add(name)

    # Add required nodes to meta_graph_info
    meta_graph_info.required_nodes.extend(self._required_nodes)
    return meta_graph_info
def test_add_new_inits_to_collection(self):
    """New table-initializer names end up after the existing ones in the collection."""
    key = tf.GraphKeys.TABLE_INITIALIZERS
    existing = ['t1', 't2']
    added = ['t3', 't4']

    proto = tf.MetaGraphDef()
    proto.collection_def[key].node_list.value.extend(existing)

    # Separate list objects for the input and the expectation, as before.
    meta_graph_editor._add_new_inits_to_collection(proto, {key: existing + added})

    self.assertEqual(proto.collection_def[key].node_list.value,
                     existing + added)
def load_metagraph(model, url_prefix, metagraph_cache):
    """Loads and caches a trained model metagraph.

    Reads ``<metagraph_cache>/<model>.metagraph``. If the file does not exist,
    it is downloaded from ``<url_prefix>/<model>.metagraph`` and written to
    the cache. The metagraph is then imported into the current default graph.

    Args:
        model: model name used to build the file name and URL.
        url_prefix: base URL to download from on a cache miss.
        metagraph_cache: directory holding cached metagraph files.

    Returns:
        The ``signature_def`` map of the loaded MetaGraphDef.
    """
    filename = os.path.join(metagraph_cache, model + ".metagraph")
    try:
        with tf.io.gfile.GFile(filename, "rb") as f:
            serialized = f.read()
    except tf.errors.NotFoundError:
        url = url_prefix + "/" + model + ".metagraph"
        # The context manager closes the response even on error. The previous
        # try/finally referenced `request` even when urlopen itself had
        # raised, which replaced the real error with UnboundLocalError.
        with urllib.request.urlopen(url) as request:
            serialized = request.read()
        tf.io.gfile.makedirs(os.path.dirname(filename))
        with tf.io.gfile.GFile(filename, "wb") as f:
            f.write(serialized)
    metagraph = tf.MetaGraphDef()
    metagraph.ParseFromString(serialized)
    tf.train.import_meta_graph(metagraph)
    return metagraph.signature_def
def _ProcessEvent(self, event):
    """Called whenever an event is loaded.

    Updates accumulator state from a single ``Event`` proto:
      * records the wall time of the first event seen;
      * tracks ``file_version``, warning if it changes;
      * gives ``_MaybePurgeOrphanedData`` a chance to act on the event;
      * stores the graph, meta graph, tagged run metadata, or summary values
        carried by the event (these are mutually exclusive branches).

    Args:
      event: the ``Event`` proto to process.
    """
    if self._first_event_timestamp is None:
        self._first_event_timestamp = event.wall_time

    if event.HasField('file_version'):
        new_file_version = _ParseFileVersion(event.file_version)
        if self.file_version and self.file_version != new_file_version:
            ## This should not happen.
            tf.logging.warn(
                ('Found new file_version for event.proto. This will '
                 'affect purging logic for TensorFlow restarts. '
                 'Old: {0} New: {1}').format(self.file_version,
                                             new_file_version))
        self.file_version = new_file_version

    self._MaybePurgeOrphanedData(event)

    ## Process the event.
    # GraphDef and MetaGraphDef are handled in a special way:
    # If no graph_def Event is available, but a meta_graph_def is, and it
    # contains a graph_def, then use the meta_graph_def.graph_def as our graph.
    # If a graph_def Event is available, always prefer it to the graph_def
    # inside the meta_graph_def.
    if event.HasField('graph_def'):
        if self._graph is not None:
            tf.logging.warn(
                ('Found more than one graph event per run, or there was '
                 'a metagraph containing a graph_def, as well as one or '
                 'more graph events. Overwriting the graph with the '
                 'newest event.'))
        self._graph = event.graph_def
        self._graph_from_metagraph = False
    elif event.HasField('meta_graph_def'):
        if self._meta_graph is not None:
            tf.logging.warn(
                ('Found more than one metagraph event per run. '
                 'Overwriting the metagraph with the newest event.'))
        self._meta_graph = event.meta_graph_def
        if self._graph is None or self._graph_from_metagraph:
            # We may have a graph_def in the metagraph. If so, and no
            # graph_def is directly available, use this one instead.
            meta_graph = tf.MetaGraphDef()
            meta_graph.ParseFromString(self._meta_graph)
            if meta_graph.graph_def:
                if self._graph is not None:
                    # The space after the comma was previously missing, so the
                    # message read "graph_defs,but".
                    tf.logging.warn(
                        ('Found multiple metagraphs containing graph_defs, '
                         'but did not find any graph events.  Overwriting the '
                         'graph with the newest metagraph version.'))
                self._graph_from_metagraph = True
                self._graph = meta_graph.graph_def.SerializeToString()
    elif event.HasField('tagged_run_metadata'):
        tag = event.tagged_run_metadata.tag
        if tag in self._tagged_metadata:
            tf.logging.warn(
                'Found more than one "run metadata" event with tag ' + tag +
                '. Overwriting it with the newest event.')
        self._tagged_metadata[tag] = event.tagged_run_metadata.run_metadata
    elif event.HasField('summary'):
        for value in event.summary.value:
            value = data_compat.migrate_value(value)

            if value.HasField('metadata'):
                tag = value.tag
                # We only store the first instance of the metadata. This check
                # is important: the `FileWriter` does strip metadata from all
                # values except the first one per each tag, but a new
                # `FileWriter` is created every time a training job stops and
                # restarts. Hence, we must also ignore non-initial metadata in
                # this logic.
                if tag not in self.summary_metadata:
                    self.summary_metadata[tag] = value.metadata
                    plugin_data = value.metadata.plugin_data
                    if plugin_data.plugin_name:
                        self._plugin_to_tag_to_content[
                            plugin_data.plugin_name][tag] = (
                                plugin_data.content)
                    else:
                        tf.logging.warn(
                            ('This summary with tag %r is oddly not associated with a '
                             'plugin.'), tag)

            for summary_type, summary_func in SUMMARY_TYPES.items():
                if value.HasField(summary_type):
                    datum = getattr(value, summary_type)
                    tag = value.tag
                    if summary_type == 'tensor' and not tag:
                        # This tensor summary was created using the old method that used
                        # plugin assets. We must still continue to support it.
                        tag = value.node_name
                    getattr(self, summary_func)(tag, event.wall_time, event.step, datum)
def _ProcessEvent(self, event):
    """Called whenever an event is loaded.

    Updates accumulator state from a single ``Event`` proto:
      * records the wall time of the first event seen;
      * tracks ``file_version``, warning if it changes;
      * gives ``_MaybePurgeOrphanedData`` a chance to act on the event;
      * stores the graph, meta graph, or tagged run metadata carried by the
        event, or dispatches its summary values. Tensor values whose tag
        starts with ``HEALTH_PILL_EVENT_TAG_PREFIX`` go to
        ``_ProcessHealthPillSummary``; other values go to the handler named in
        ``SUMMARY_TYPES``.

    Args:
      event: the ``Event`` proto to process.
    """
    if self._first_event_timestamp is None:
        self._first_event_timestamp = event.wall_time

    if event.HasField('file_version'):
        new_file_version = _ParseFileVersion(event.file_version)
        if self.file_version and self.file_version != new_file_version:
            ## This should not happen.
            tf.logging.warn(('Found new file_version for event.proto. This will '
                             'affect purging logic for TensorFlow restarts. '
                             'Old: {0} New: {1}').format(self.file_version,
                                                         new_file_version))
        self.file_version = new_file_version

    self._MaybePurgeOrphanedData(event)

    ## Process the event.
    # GraphDef and MetaGraphDef are handled in a special way:
    # If no graph_def Event is available, but a meta_graph_def is, and it
    # contains a graph_def, then use the meta_graph_def.graph_def as our graph.
    # If a graph_def Event is available, always prefer it to the graph_def
    # inside the meta_graph_def.
    if event.HasField('graph_def'):
        if self._graph is not None:
            tf.logging.warn(
                ('Found more than one graph event per run, or there was '
                 'a metagraph containing a graph_def, as well as one or '
                 'more graph events. Overwriting the graph with the '
                 'newest event.'))
        self._graph = event.graph_def
        self._graph_from_metagraph = False
    elif event.HasField('meta_graph_def'):
        if self._meta_graph is not None:
            tf.logging.warn(('Found more than one metagraph event per run. '
                             'Overwriting the metagraph with the newest event.'))
        self._meta_graph = event.meta_graph_def
        if self._graph is None or self._graph_from_metagraph:
            # We may have a graph_def in the metagraph. If so, and no
            # graph_def is directly available, use this one instead.
            meta_graph = tf.MetaGraphDef()
            meta_graph.ParseFromString(self._meta_graph)
            if meta_graph.graph_def:
                if self._graph is not None:
                    # The space after the comma was previously missing, so the
                    # message read "graph_defs,but".
                    tf.logging.warn(
                        ('Found multiple metagraphs containing graph_defs, '
                         'but did not find any graph events.  Overwriting the '
                         'graph with the newest metagraph version.'))
                self._graph_from_metagraph = True
                self._graph = meta_graph.graph_def.SerializeToString()
    elif event.HasField('tagged_run_metadata'):
        tag = event.tagged_run_metadata.tag
        if tag in self._tagged_metadata:
            tf.logging.warn('Found more than one "run metadata" event with tag ' +
                            tag + '. Overwriting it with the newest event.')
        self._tagged_metadata[tag] = event.tagged_run_metadata.run_metadata
    elif event.HasField('summary'):
        for value in event.summary.value:
            if (value.HasField('tensor') and
                    value.tag.startswith(HEALTH_PILL_EVENT_TAG_PREFIX)):
                self._ProcessHealthPillSummary(value, event)
            else:
                for summary_type, summary_func in SUMMARY_TYPES.items():
                    if value.HasField(summary_type):
                        datum = getattr(value, summary_type)
                        # Tensor summaries are keyed by node name, others by tag.
                        tag = value.node_name if summary_type == 'tensor' else value.tag
                        getattr(self, summary_func)(tag, event.wall_time, event.step,
                                                    datum)
def __init__(self, meta_file):
    """Load the binary ``MetaGraphDef`` stored at ``meta_file`` into ``self._meta_def``."""
    self._meta_def = tf.MetaGraphDef()
    with open(meta_file, 'rb') as handle:
        payload = handle.read()
        self._meta_def.ParseFromString(payload)
def start_service(self, tensorflowGraph, optimizer):
    """Build the parameter-server graph and serve it with Flask (blocks).

    The server starts here rather than in ``__init__`` because this method
    runs in a separate process; when Python forks, the fork should happen from
    this thread and not from the master thread.

    Args:
        tensorflowGraph: JSON-encoded ``MetaGraphDef`` of the model. The first
            entry of its ``LOSSES`` collection is used as the loss.
        optimizer: optimizer whose ``apply_gradients`` builds ``train_op``.

    Routes:
        ``GET /``: liveness check.
        ``GET /parameters``: pickled ``self.weights``.
        ``POST /update``: pickled list of gradients, in the same order as
            ``tf.trainable_variables()``, applied with ``train_op``. Failures
            are counted; once the count reaches ``self.iters`` the handler
            raises.
    """
    app = Flask(__name__)
    self.app = app
    max_errors = self.iters
    lock = RWLock()
    server = tf.train.Server.create_local_server()
    metagraph = json_format.Parse(tensorflowGraph, tf.MetaGraphDef())

    ng = tf.Graph()
    with ng.as_default():
        tf.train.import_meta_graph(metagraph)
        loss_variable = tf.get_collection(tf.GraphKeys.LOSSES)[0]
        grads = tf.gradients(loss_variable, tf.trainable_variables())
        grads = list(zip(grads, tf.trainable_variables()))
        train_op = optimizer.apply_gradients(grads)
        init = tf.global_variables_initializer()

    glob_session = tf.Session(server.target, graph=ng)
    with ng.as_default():
        with glob_session.as_default():
            glob_session.run(init)
            self.weights = tensorflow_get_weights()

    # Failure counter. next(cont) works on Python 2 and 3; the old cont.next()
    # is Python 2 only and raised AttributeError on every failed update.
    cont = itertools.count()
    lock_acquired = self.acquire_lock

    @app.route('/')
    def home():
        return 'Lifeomic'

    @app.route('/parameters', methods=['GET'])
    def get_parameters():
        if lock_acquired:
            lock.acquire_read()
        try:
            vs = pickle.dumps(self.weights)
        finally:
            # Release even if pickling fails, otherwise writers block forever.
            if lock_acquired:
                lock.release()
        return vs

    @app.route('/update', methods=['POST'])
    def update_parameters():
        with ng.as_default():
            # SECURITY NOTE: pickle.loads on a request body can execute
            # arbitrary code sent by a malicious client. Only expose this
            # port to trusted workers.
            gradients = pickle.loads(request.data)
            nu_feed = {grad_var[0]: gradients[x] for x, grad_var in enumerate(grads)}
            if lock_acquired:
                lock.acquire_write()
            with glob_session.as_default():
                try:
                    glob_session.run(train_op, feed_dict=nu_feed)
                    self.weights = tensorflow_get_weights()
                except Exception:
                    # Best effort: tolerate a bounded number of failed updates.
                    error_cnt = next(cont)
                    if error_cnt >= max_errors:
                        raise Exception("Too many failures during training")
                finally:
                    if lock_acquired:
                        lock.release()
        return 'completed'

    # NOTE(review): debug=True on host 0.0.0.0 exposes Werkzeug's interactive
    # debugger to the network; consider disabling it outside development.
    self.app.run(host='0.0.0.0', debug=True, threaded=True, use_reloader=False, port=5000)
def _write_model_output(path, output):
    """Reverse the channel axis of a 4-D batch, scale by 255 and save image 0 to ``path``."""
    # NOTE(review): assumes the model emits values in [0, 1] with channels in
    # the reverse of OpenCV's BGR order, mirroring the /255 input scaling.
    output = output[:, :, :, ::-1] * 255
    print("Writing " + path)
    cv2.imwrite(path, output[0])


def _run_tflite_io_graph(sess, images):
    """Feed ``images`` to ``TFLiteInput:0`` of ``sess.graph`` and return ``TFLiteOutput:0``."""
    output = sess.graph.get_tensor_by_name("TFLiteOutput:0")
    return sess.run(output, feed_dict={'TFLiteInput:0': images})


def main():
    """Run inference on one image with a TFLite, frozen-graph, and/or exported model.

    Each of ``--tflite``, ``--frozen`` and ``--export`` that is given runs in
    turn on the resized (256x256), [0, 1]-scaled ``--input`` image. Each one
    writes ``--output``, so a later mode overwrites an earlier mode's file.
    The frozen and exported models must expose ``TFLiteInput:0`` and
    ``TFLiteOutput:0``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--export", default=None, help="output_dir from --mode export run")
    parser.add_argument("--frozen", default=None, help="frozen graph pb file")
    parser.add_argument("--tflite", default=None, help="tflite file")
    parser.add_argument("--input", default='Untitled.png', help="input image")
    parser.add_argument("--output", default='out.png', help="output image")
    a = parser.parse_args()

    out_file = a.output
    images_cv = [cv2.resize(cv2.imread(f), (256, 256)) for f in [a.input]]
    images = np.array(images_cv, dtype=np.float32) / 255.0

    if a.tflite:
        interpreter = tf.lite.Interpreter(model_path=a.tflite)
        interpreter.allocate_tensors()
        input_details = interpreter.get_input_details()
        output_details = interpreter.get_output_details()
        print(input_details)
        print(output_details)
        interpreter.set_tensor(input_details[0]['index'], images)
        interpreter.invoke()
        _write_model_output(out_file, interpreter.get_tensor(output_details[0]['index']))

    if a.frozen:
        # A fresh graph per mode: with the shared default graph, running both
        # --frozen and --export collided on node names. Entering the session
        # also makes its graph the default, so the old bare
        # `sess.graph.as_default()` call (a no-op without `with`) is gone.
        with tf.Session(graph=tf.Graph()) as sess:
            with gfile.FastGFile(a.frozen, 'rb') as f:
                graph_def = tf.GraphDef()
                graph_def.ParseFromString(f.read())
            tf.import_graph_def(graph_def, name='')
            _write_model_output(out_file, _run_tflite_io_graph(sess, images))

    if a.export:
        with tf.Session(graph=tf.Graph()) as sess:
            with gfile.FastGFile(os.path.join(a.export, 'export.meta'), 'rb') as f:
                meta_graph_def = tf.MetaGraphDef()
                meta_graph_def.ParseFromString(f.read())
            tf.train.import_meta_graph(meta_graph_def)
            checkpoint = tf.train.latest_checkpoint(a.export)
            restore_saver = tf.train.Saver()
            restore_saver.restore(sess, checkpoint)
            _write_model_output(out_file, _run_tflite_io_graph(sess, images))
def get_meta_graph_copy(self, tags=None):
    """Return an independent copy of the MetaGraphDef that ``get_meta_graph(tags)`` returns.

    Mutating the result does not affect the stored meta graph.
    """
    source = self.get_meta_graph(tags)
    duplicate = tf.MetaGraphDef()
    duplicate.CopyFrom(source)
    return duplicate
def test_freeze_then_sparsify(self, freeze_mock, graph_transform_mock):
    """meta_graph_editor applies freeze_graph then sparsify_gather as expected.

    ``freeze_mock`` and ``graph_transform_mock`` stand in for the freeze and
    graph-transform steps (presumably injected by patch decorators outside
    this view). Both return ``transformed_graph_def``. The test checks the
    resulting MetaGraphDef and the exact arguments each step receives.
    """
    tag_name = 'tag'
    input_nodes = 'input_nodes'
    output_nodes = 'output_nodes'
    freeze_transform = 'freeze_graph'
    sparsify_transform = 'sparsify_gather'

    base_meta_graph_def = tf.MetaGraphDef()

    # Add a table initializer.
    table_init_name = 'table_init'
    node_def = tf.NodeDef(name=table_init_name, op='InitializeTableV2')
    base_meta_graph_def.graph_def.node.extend([node_def])

    # Add a group_deps node.
    group_deps_name = 'group_deps'
    node_def = tf.NodeDef(name=group_deps_name, op='NoOp')
    node_def.input.extend(['^table_init'])
    base_meta_graph_def.graph_def.node.extend([node_def])

    base_meta_graph_def.collection_def[
        tf.GraphKeys.TABLE_INITIALIZERS].node_list.value.extend(
            [table_init_name])
    base_meta_graph_def.collection_def[
        tf.saved_model.constants.
        LEGACY_INIT_OP_KEY].node_list.value.extend([group_deps_name])

    # Expected metagraphdef.
    # The snapshot is taken here, BEFORE the unsaved init node and the saver
    # are added below, so the expectation excludes both.
    expected_meta_graph_def = tf.MetaGraphDef()
    expected_meta_graph_def.CopyFrom(base_meta_graph_def)
    expected_meta_graph_def.meta_info_def.tags.append(tag_name)

    transformed_graph_def = tf.GraphDef()
    transformed_graph_def.CopyFrom(expected_meta_graph_def.graph_def)
    freeze_mock.return_value = transformed_graph_def
    graph_transform_mock.return_value = transformed_graph_def

    # Add unsaved init node.
    unsaved_init_name = 'unsaved_node'
    node_def = tf.NodeDef(name=unsaved_init_name, op='NoOp')
    base_meta_graph_def.graph_def.node.extend([node_def])

    # Add a saver.
    base_meta_graph_def.saver_def.filename_tensor_name = 'node1'
    base_meta_graph_def.saver_def.save_tensor_name = 'node3'
    base_meta_graph_def.saver_def.restore_op_name = 'node6'

    transformed_meta_graph_def = meta_graph_editor.meta_graph_editor(
        base_meta_graph_def, [input_nodes], [output_nodes],
        [freeze_transform, sparsify_transform], [tag_name])

    self.assertEqual(expected_meta_graph_def, transformed_meta_graph_def)

    # The freeze step receives the full base graph (including the unsaved node)
    # together with the table initializer, the group_deps init op and the saver.
    freeze_mock.assert_called_once_with(base_meta_graph_def.graph_def,
                                        [output_nodes], [table_init_name],
                                        group_deps_name,
                                        base_meta_graph_def.saver_def, None)
    # NOTE(review): the 'sparify_gather_init_op' spelling presumably mirrors
    # what meta_graph_editor emits; confirm before "fixing" it here.
    graph_transform_mock.assert_called_once_with(
        transformed_graph_def, [input_nodes],
        [output_nodes, group_deps_name, table_init_name],
        [sparsify_transform + '(group_init_node="sparify_gather_init_op")'])
def _compute_and_push_gradients(sess, grads, feed_dict, master_url, partition_id):
    """Evaluate every gradient for one feed in a single run and push them to the server.

    One ``sess.run`` computes all gradients from the same forward pass. The
    previous per-gradient ``.eval`` re-ran the graph once per variable, and
    any random ops (e.g. dropout) differed between gradients. A failed push
    is reported and otherwise ignored (best effort).
    """
    gradients = sess.run([grad for grad, _ in grads], feed_dict=feed_dict)
    try:
        put_deltas_to_server(gradients, master_url)
    except Exception:
        print("Timeout error from partition %s" % partition_id)


def handle_model(data, graph_json, tfInput, tfLabel=None,
                 master_url='localhost:5000', iters=1000,
                 mini_batch_size=-1, shuffle=True,
                 mini_stochastic_iters=-1, verbose=0, loss_callback=None):
    """Train on one data partition against a parameter server at ``master_url``.

    Each of the ``iters`` iterations pulls the current weights from the
    server, optionally shuffles the partition, computes gradients of the
    graph's first ``LOSSES`` entry, and pushes them back:
      * ``mini_stochastic_iters >= 1``: that many feeds, each built by
        ``handle_feed_dict`` with ``mini_batch_size`` and no index;
      * else ``mini_batch_size >= 1``: one feed per consecutive batch starting
        at ``0, mini_batch_size, ...``, refreshing the weights before each;
      * otherwise: a single feed.
    When ``verbose`` or ``loss_callback`` is set, the loss over a feed built
    with size -1 is computed after every iteration. It is printed and/or
    passed to ``loss_callback(loss, i, partition_id)``.

    Args:
        data: partition data, converted by ``handle_features``.
        graph_json: JSON-encoded ``MetaGraphDef`` of the model.
        tfInput: name of the input placeholder.
        tfLabel: name of the label placeholder, or None for unsupervised.
        master_url: parameter server address.
        iters: number of outer training iterations.
        mini_batch_size: batch size passed to ``handle_feed_dict``.
        shuffle: whether to shuffle features/labels each iteration.
        mini_stochastic_iters: number of feeds per iteration (if >= 1).
        verbose: print the loss each iteration when truthy.
        loss_callback: optional callable receiving ``(loss, i, partition_id)``.
    """
    is_supervised = tfLabel is not None
    features, labels = handle_features(data, is_supervised)

    gd = json_format.Parse(graph_json, tf.MetaGraphDef())
    new_graph = tf.Graph()
    with tf.Session(graph=new_graph) as sess:
        tf.train.import_meta_graph(gd)
        loss_variable = tf.get_collection(tf.GraphKeys.LOSSES)[0]
        sess.run(tf.global_variables_initializer())
        trainable_variables = tf.trainable_variables()
        grads = tf.gradients(loss_variable, trainable_variables)
        grads = list(zip(grads, trainable_variables))
        partition_id = uuid.uuid4().hex

        for i in range(iters):
            weights = get_server_weights(master_url)
            tensorflow_set_weights(weights, vs=trainable_variables)
            if shuffle:
                features, labels = handle_shuffle(features, labels)

            if mini_stochastic_iters >= 1:
                for _ in range(mini_stochastic_iters):
                    feed_dict = handle_feed_dict(features, tfInput, tfLabel, labels,
                                                 mini_batch_size)
                    _compute_and_push_gradients(sess, grads, feed_dict, master_url,
                                                partition_id)
            elif mini_batch_size >= 1:
                for r in range(0, len(features), mini_batch_size):
                    weights = get_server_weights(master_url)
                    tensorflow_set_weights(weights, vs=trainable_variables)
                    feed_dict = handle_feed_dict(features, tfInput, tfLabel, labels,
                                                 mini_batch_size, idx=r)
                    _compute_and_push_gradients(sess, grads, feed_dict, master_url,
                                                partition_id)
            else:
                feed_dict = handle_feed_dict(features, tfInput, tfLabel, labels,
                                             mini_batch_size)
                _compute_and_push_gradients(sess, grads, feed_dict, master_url,
                                            partition_id)

            if verbose or loss_callback:
                feed_dict = handle_feed_dict(features, tfInput, tfLabel, labels, -1)
                loss = sess.run(loss_variable, feed_dict=feed_dict)
                if verbose:
                    print("Partition Id: %s, Iteration: %i, Loss: %f" % (partition_id, i, loss))
                if loss_callback:
                    loss_callback(loss, i, partition_id)
def load_tf_graph_def(graph_file_name: str = "", is_binary: bool = True, checkpoint: str = "",
                      model_dir: str = "", saved_model_tags: list = None, meta_graph_file: str = "",
                      user_output_node_names_list: list = None):
    """Load a TensorFlow model and return it as a frozen GraphDef.

    Supported configurations, checked in this order:
      1. ``graph_file_name`` only: a frozen graph, read as-is.
      2. ``graph_file_name`` + ``checkpoint`` (a file or a directory): an
         inference graph frozen with the checkpoint.
      3. ``meta_graph_file`` only: the meta graph is imported, variables are
         restored from the checkpoint next to it, then frozen.
      4. ``model_dir``: a SavedModel directory, loaded with
         ``saved_model_tags`` (defaults to ``[SERVING]`` when None) and frozen.

    Args:
        user_output_node_names_list: output node names to keep. None (the
            default) behaves like an empty list.

    Returns:
        ``(graph_def, variables_values)``. ``variables_values`` is non-empty
        only when a checkpoint directory is frozen via ``freeze_checkpoints``.

    Raises:
        FrameworkError: wrapping any exception raised while loading.
        Error: when no supported configuration matches the arguments.
    """
    # None defaults avoid shared mutable default lists. Previously the `[]`
    # default for saved_model_tags also made the SERVING fallback below
    # unreachable.
    if user_output_node_names_list is None:
        user_output_node_names_list = []

    # As a provisional solution, use a native TF methods to load a model protobuf
    graph_def = tf.GraphDef()
    if isinstance(graph_file_name, str) and (re.match(r'.*\.(ckpt|meta)$', graph_file_name)):
        print(
            '[ WARNING ] The value for the --input_model command line parameter ends with ".ckpt" or ".meta" '
            'extension.\n'
            'It means that the model is not frozen.\n'
            'To load non frozen model to Model Optimizer run:'
            '\n\n1. For "*.ckpt" file:'
            '\n- if inference graph is in binary format'
            '\npython3 mo_tf.py --input_model "path/to/inference_graph.pb" --input_checkpoint "path/to/*.ckpt"'
            '\n- if inference graph is in text format'
            '\npython3 mo_tf.py --input_model "path/to/inference_graph.pbtxt" --input_model_is_text '
            '--input_checkpoint "path/to/*.ckpt"'
            '\n\n2. For "*.meta" file:'
            '\npython3 mo_tf.py --input_meta_graph "path/to/*.meta"')
    variables_values = {}
    try:
        if graph_file_name and not meta_graph_file and not checkpoint:
            # frozen graph
            return read_file_to_graph_def(graph_def, graph_file_name, is_binary), variables_values
        if graph_file_name and not meta_graph_file and checkpoint:
            # inference graph and checkpoint
            graph_def = read_file_to_graph_def(graph_def, graph_file_name, is_binary)
            outputs = get_output_node_names_list(graph_def, user_output_node_names_list)
            if os.path.isfile(checkpoint):
                graph_def = freeze_checkpoint(graph_def=graph_def, checkpoint=checkpoint,
                                              output_node_names=outputs)
            elif os.path.isdir(checkpoint):
                graph_def, variables_values = freeze_checkpoints(
                    graph_def=graph_def, checkpoint_dir=checkpoint, output_node_names=outputs)
            # we are sure that checkpoint is existing file or directory due to cli_parser configuration
            return graph_def, variables_values
        if not graph_file_name and meta_graph_file:
            meta_graph_file = deducing_metagraph_path(meta_graph_file)
            input_meta_graph_def = read_file_to_graph_def(
                tf.MetaGraphDef(), meta_graph_file, is_binary)  # pylint: disable=no-member
            with tf.Session() as sess:
                restorer = tf.train.import_meta_graph(input_meta_graph_def)
                # The checkpoint prefix is the meta file path without ".meta".
                restorer.restore(sess, re.sub(r'\.meta$', '', meta_graph_file))
                outputs = get_output_node_names_list(
                    input_meta_graph_def.graph_def, user_output_node_names_list)
                graph_def = tf.graph_util.convert_variables_to_constants(
                    sess, input_meta_graph_def.graph_def, outputs)
                return graph_def, variables_values
        if model_dir:
            # saved model directory
            tags = saved_model_tags if saved_model_tags is not None else [
                tf.saved_model.tag_constants.SERVING
            ]
            with tf.Session() as sess:
                meta_graph_def = tf.saved_model.loader.load(sess, tags, model_dir)
                outputs = get_output_node_names_list(
                    meta_graph_def.graph_def, user_output_node_names_list)
                graph_def = tf.graph_util.convert_variables_to_constants(
                    sess, meta_graph_def.graph_def, outputs)
                return graph_def, variables_values
    except Exception as e:
        raise FrameworkError('Cannot load input model: {}', e) from e
    raise Error("Unknown configuration of input model parameters")
def _ProcessEvent(self, event):
    """Called whenever an event is loaded.

    Updates accumulator state from a single ``Event`` proto:
      * records the wall time of the first event seen;
      * tracks ``file_version``, warning if it changes;
      * gives ``_MaybePurgeOrphanedData`` a chance to act on the event;
      * stores the graph, meta graph, or tagged run metadata carried by the
        event. Other event payloads are not handled here.

    Args:
      event: the ``Event`` proto to process.
    """
    if self._first_event_timestamp is None:
        self._first_event_timestamp = event.wall_time

    if event.HasField('file_version'):
        new_file_version = _ParseFileVersion(event.file_version)
        if self.file_version and self.file_version != new_file_version:
            ## This should not happen.
            tf.logging.warn(
                ('Found new file_version for event.proto. This will '
                 'affect purging logic for TensorFlow restarts. '
                 'Old: {0} New: {1}').format(self.file_version,
                                             new_file_version))
        self.file_version = new_file_version

    self._MaybePurgeOrphanedData(event)

    ## Process the event.
    # GraphDef and MetaGraphDef are handled in a special way:
    # If no graph_def Event is available, but a meta_graph_def is, and it
    # contains a graph_def, then use the meta_graph_def.graph_def as our graph.
    # If a graph_def Event is available, always prefer it to the graph_def
    # inside the meta_graph_def.
    if event.HasField('graph_def'):
        if self._graph is not None:
            tf.logging.warn(
                ('Found more than one graph event per run, or there was '
                 'a metagraph containing a graph_def, as well as one or '
                 'more graph events. Overwriting the graph with the '
                 'newest event.'))
        self._graph = event.graph_def
        self._graph_from_metagraph = False
    elif event.HasField('meta_graph_def'):
        if self._meta_graph is not None:
            tf.logging.warn(
                ('Found more than one metagraph event per run. '
                 'Overwriting the metagraph with the newest event.'))
        self._meta_graph = event.meta_graph_def
        if self._graph is None or self._graph_from_metagraph:
            # We may have a graph_def in the metagraph. If so, and no
            # graph_def is directly available, use this one instead.
            meta_graph = tf.MetaGraphDef()
            meta_graph.ParseFromString(self._meta_graph)
            if meta_graph.graph_def:
                if self._graph is not None:
                    # The space after the comma was previously missing, so the
                    # message read "graph_defs,but".
                    tf.logging.warn(
                        ('Found multiple metagraphs containing graph_defs, '
                         'but did not find any graph events.  Overwriting the '
                         'graph with the newest metagraph version.'))
                self._graph_from_metagraph = True
                self._graph = meta_graph.graph_def.SerializeToString()
    elif event.HasField('tagged_run_metadata'):
        tag = event.tagged_run_metadata.tag
        if tag in self._tagged_metadata:
            tf.logging.warn(
                'Found more than one "run metadata" event with tag ' + tag +
                '. Overwriting it with the newest event.')
        self._tagged_metadata[tag] = event.tagged_run_metadata.run_metadata
# Print a binary MetaGraphDef file (path given as the first CLI argument)
# in protobuf text form.
import tensorflow as tf
import sys

gf = tf.MetaGraphDef()
# Use a context manager so the file handle is closed (the original never closed it).
with open(sys.argv[1], "rb") as meta_file:
    gf.ParseFromString(meta_file.read())
print(gf)
def convert_int8(input_model_dir, output_model_dir, batch_size, precision_mode,
                 calib_image_dir, input_tensor, output_tensor, epochs):
    """Convert a SavedModel to a TF-TRT INT8 SavedModel via calibration.

    Steps:
      1. Build a TF-TRT calibration graph from ``input_model_dir``.
      2. Run ``epochs`` batches of calibration data (TFRecord files matching
         'validation*' under ``calib_image_dir``, mapped with ``preprocess``)
         through it.
      3. Convert the calibrated graph to an inference graph.
      4. Write it to ``output_model_dir`` as a SavedModel tagged 'serve',
         reusing the signature_defs of the input model's SERVING meta graph.

    Args:
        input_model_dir: SavedModel directory to convert.
        output_model_dir: destination for the converted SavedModel.
        batch_size: calibration batch size and TRT max batch size.
        precision_mode: TF-TRT precision mode passed to create_inference_graph.
        calib_image_dir: directory holding the calibration TFRecords.
        input_tensor: name of the input tensor fed during calibration.
        output_tensor: name of the output tensor.
        epochs: number of calibration batches to run.
    """
    # (TODO) Need to check if we need Tesla T4 when conversion.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    # Get path to calibration data.
    calibration_files = get_calibration_files(calib_image_dir, 'validation*')

    # Create dataset and apply preprocess
    # (TODO) Get num cpus to set appropriate number to num_parallel_calls
    # NOTE(review): the dataset is built in the outer default graph, while the
    # iterator below is created inside a session with a fresh graph; confirm
    # the TF version in use builds dataset ops lazily in the current graph.
    dataset = tf.data.TFRecordDataset(calibration_files)
    dataset = dataset.apply(
        tf.contrib.data.map_and_batch(
            map_func=preprocess,
            batch_size=batch_size,
            num_parallel_calls=multiprocessing.cpu_count()))

    """
    Step 1: Creating the calibration graph.
    """

    # Create TF-TRT INT8 calibration graph.
    trt_int8_calib_graph = trt.create_inference_graph(
        input_graph_def=None,
        outputs=[output_tensor],
        max_batch_size=batch_size,
        input_saved_model_dir=input_model_dir,
        precision_mode=precision_mode)

    # Calibrate graph.
    with tf.Session(graph=tf.Graph(), config=config) as sess:
        tf.logging.info('preparing calibration data...')
        iterator = dataset.make_one_shot_iterator()
        next_element = iterator.get_next()

        tf.logging.info('Loading INT8 calibration graph...')
        output_node = tf.import_graph_def(
            trt_int8_calib_graph, return_elements=[output_tensor], name='')

        tf.logging.info('Calibrate model with calibration data...')
        # One batch per "epoch"; element [0] of each dataset element is fed
        # to input_tensor.
        for _ in range(epochs):
            sess.run(output_node,
                     feed_dict={input_tensor: sess.run(next_element)[0]})

    """
    Step 2: Converting the calibration graph to inference graph
    """
    tf.logging.info('Creating TF-TRT INT8 inference engine...')
    trt_int8_calibrated_graph = trt.calib_graph_to_infer_graph(
        trt_int8_calib_graph)

    # Copy MetaGraph from base model.
    # NOTE(review): only metagraph.signature_def is used when saving below;
    # the copied graph_def, collection_defs and meta_info_def are not written.
    with tf.Session(graph=tf.Graph(), config=config) as sess:
        base_model = tf.saved_model.loader.load(
            sess, [tf.saved_model.tag_constants.SERVING], input_model_dir)

        metagraph = tf.MetaGraphDef()
        metagraph.graph_def.CopyFrom(trt_int8_calibrated_graph)
        # Skip variable/training collections from the base model.
        for key in base_model.collection_def:
            if key not in [
                    'variables', 'local_variables', 'model_variables',
                    'trainable_variables', 'train_op', 'table_initializer'
            ]:
                metagraph.collection_def[key].CopyFrom(
                    base_model.collection_def[key])

        metagraph.meta_info_def.CopyFrom(base_model.meta_info_def)
        for key in base_model.signature_def:
            metagraph.signature_def[key].CopyFrom(
                base_model.signature_def[key])

    saved_model_builder = (
        tf.saved_model.builder.SavedModelBuilder(output_model_dir))

    # Write SavedModel with INT8 precision.
    with tf.Graph().as_default():
        tf.graph_util.import_graph_def(trt_int8_calibrated_graph,
                                       return_elements=[output_tensor],
                                       name='')
        with tf.Session(config=config) as sess:
            saved_model_builder.add_meta_graph_and_variables(
                sess, ('serve', ), signature_def_map=metagraph.signature_def)

    # Ignore other meta graphs from the input SavedModel.
    saved_model_builder.save()
def export_graph(self, saved_model_dir):
    """Export ``self._light_graph`` (after collapsing subgraphs) as a TF SavedModel.

    The partial MetaGraphDef stored in the light graph's meta graph info is
    parsed. One NodeDef per light-graph node is appended to it, and
    placeholders are added for input edges whose node is not in the light
    graph. The result is loaded into a fresh session and variables are
    restored from the values gathered by ``_original_node_def``. It is then
    saved via ``self.save_model``.

    Args:
        saved_model_dir: directory passed to ``self.save_model``.

    Raises:
        ValueError: if the meta graph info lacks the partial meta graph def,
            or a graph variable has no collected value.
    """
    # Collapse supported subgraphs
    light_graph = self.get_collapsed_light_graph(self._light_graph)

    # Check meta graph info
    meta_graph_info = light_graph.meta_graph_info()
    if (tf_saved_model_base_importer.ImportTFSavedModelBase.
            PARTIAL_META_GRAPH_DEF not in meta_graph_info.original_graph_info):
        raise ValueError(
            "Could not find partial meta graph def in meta graph info: {}".
            format(meta_graph_info))

    # Convert to TF Graph Def
    meta_graph = tf.MetaGraphDef()
    meta_graph.ParseFromString(meta_graph_info.original_graph_info[
        tf_saved_model_base_importer.ImportTFSavedModelBase.
        PARTIAL_META_GRAPH_DEF].v)
    # graph_def is a reference into meta_graph: nodes added to it below also
    # appear in meta_graph.graph_def.
    graph_def = meta_graph.graph_def
    # Filled by _original_node_def (node name -> value), presumably only for
    # variable nodes; see the lookup in the variable-loading loop below.
    variable_values = {}
    for lnf in light_graph.nodes():
        node_def = graph_def.node.add()
        if lnf.HasField(lgf_pb2.LNF.subgraph.DESCRIPTOR.name):
            self._subgraph_node_to_node_def(node_def, lnf)
        else:
            self._original_node_def(node_def, lnf, variable_values)

    # Add placeholders for nodes that don't exist in the light_graph
    # NOTE(review): node_names is updated but never read afterwards; if an
    # input-edge name can repeat, a duplicate placeholder may be added.
    node_names = set(n.name for n in graph_def.node)
    for e in light_graph.input_edges():
        if not light_graph.has_node(e.name):
            graph_def.node.add().CopyFrom(
                self._edge_to_node_def_placeholder(e))
            node_names.add(e.name)

    # Load the graph def into a session (load custom ops first)
    tf_ops.load_ops()
    with tf.Session(graph=tf.Graph()) as sess:
        if (meta_graph.saver_def.save_tensor_name != ""):
            # TF only allows importing a meta graph def when the saver def
            # is fully initialized. This should only be the case for graphs
            # with TF variables
            # NOTE(review): these asserts are skipped under `python -O`.
            assert (len(variable_values) > 0)
            tf.train.import_meta_graph(meta_graph)
        else:
            assert (len(variable_values) == 0)
            tf.import_graph_def(meta_graph.graph_def, name="")

        # Load variables
        # The four collections can overlap, so a variable may be loaded more
        # than once with the same value.
        for v in (tf.global_variables() + tf.local_variables() +
                  tf.trainable_variables() + tf.model_variables()):
            node_name, _, _ = tf_saved_model_base_importer.ImportTFSavedModelBase.\
                get_node_name_and_output_index(v.name)
            if node_name not in variable_values:
                raise ValueError(
                    "Could not find value for variable {}".format(
                        node_name))
            v.load(variable_values[node_name])

        # Get input and output tensors from the graph
        input_tensors = self._get_tf_tensors_from_graph(
            sess.graph, light_graph.input_edges())
        output_tensors = self._get_tf_tensors_from_graph(
            sess.graph, light_graph.output_edges())

        # Create the saved model
        self.save_model(saved_model_dir,
                        sess,
                        input_tensors=input_tensors,
                        output_tensors=output_tensors,
                        output_node_names=light_graph.output_node_names())