def testExportMonitorRegressionSignature(self):
    """Checks that ExportMonitor writes an export with a regression signature.

    A LinearRegressor is fitted for 10 steps with an ExportMonitor that
    exports on every step and keeps one export. The test then checks which
    export checkpoints are present and that the default signature of the
    step-10 export has its `regression_signature` field set.
    """

    def _regression_signature(examples, unused_features, predictions):
        # The same signature object is returned as the default signature and
        # as the only entry of the named-signatures dict.
        regression = exporter.regression_signature(examples, predictions)
        return regression, {'regression': regression}

    # NOTE(review): the data below is drawn from `np.random`; if `random` is
    # the stdlib module this call does not seed it -- confirm the intent.
    random.seed(42)
    inputs = np.random.rand(1000)
    targets = 2 * inputs + 3

    feature_columns = [feature_column.real_valued_column('', dimension=1)]
    regressor = learn.LinearRegressor(feature_columns=feature_columns)

    export_dir = os.path.join(tempfile.mkdtemp(), 'export')
    export_monitor = learn.monitors.ExportMonitor(
        every_n_steps=1,
        export_dir=export_dir,
        exports_to_keep=1,
        signature_fn=_regression_signature)
    regressor.fit(inputs, targets, steps=10, monitors=[export_monitor])

    self.assertTrue(gfile.Exists(export_dir))

    # Probing the step-0 export is expected to raise NotFoundError.
    step0_export = os.path.join(export_dir, '00000000', 'export')
    with self.assertRaises(errors.NotFoundError):
        saver.checkpoint_exists(step0_export)

    step10_export = os.path.join(export_dir, '00000010', 'export')
    self.assertTrue(saver.checkpoint_exists(step10_export))

    # Validate the signature
    signature = self._get_default_signature(
        os.path.join(export_dir, '00000010', 'export.meta'))
    self.assertTrue(signature.HasField('regression_signature'))
def testExportMonitorRegressionSignature(self):
    """Fits a linear regressor with an ExportMonitor and checks the export.

    Asserts that the export directory exists, that probing the step-0 export
    raises NotFoundError, that the step-10 export checkpoint exists, and that
    its default signature carries a `regression_signature` field.
    """

    def _regression_signature(examples, unused_features, predictions):
        # Returns (default signature, named signatures); both refer to the
        # same regression signature object.
        named = {
            'regression': exporter.regression_signature(examples, predictions)
        }
        return named['regression'], named

    # NOTE(review): `np.random.rand` below is not affected by seeding the
    # stdlib `random` module -- verify which generator should be seeded.
    random.seed(42)
    x = np.random.rand(1000)
    y = 2 * x + 3

    regressor = learn.LinearRegressor(
        feature_columns=[feature_column.real_valued_column('', dimension=1)])

    export_dir = os.path.join(tempfile.mkdtemp(), 'export')
    monitor = learn.monitors.ExportMonitor(
        every_n_steps=1,
        export_dir=export_dir,
        exports_to_keep=1,
        signature_fn=_regression_signature)
    regressor.fit(x, y, steps=10, monitors=[monitor])

    def _export_path(step_dir, filename):
        return os.path.join(export_dir, step_dir, filename)

    self.assertTrue(gfile.Exists(export_dir))
    with self.assertRaises(errors.NotFoundError):
        saver.checkpoint_exists(_export_path('00000000', 'export'))
    self.assertTrue(
        saver.checkpoint_exists(_export_path('00000010', 'export')))

    # Validate the signature
    signature = self._get_default_signature(
        _export_path('00000010', 'export.meta'))
    self.assertTrue(signature.HasField('regression_signature'))
def freeze_graph_with_def_protos(
        input_graph_def,
        input_saver_def,
        input_checkpoint,
        output_node_names,
        restore_op_name,
        filename_tensor_name,
        clear_devices,
        initializer_nodes,
        variable_names_blacklist=''):
    """Converts all variables in a graph and checkpoint into constants.

    Args:
      input_graph_def: GraphDef to import and freeze.
      input_saver_def: Optional SaverDef used to build the restoring Saver.
      input_checkpoint: Checkpoint path (may be a prefix for Saver V2).
      output_node_names: Comma-separated output node names.
      restore_op_name: Unused.
      filename_tensor_name: Unused.
      clear_devices: If truthy, every node's device field is cleared.
      initializer_nodes: Fetched via `sess.run` after restoring when no
        `input_saver_def` is given.
      variable_names_blacklist: Comma-separated variable names passed on to
        `convert_variables_to_constants` as a list (None when empty).

    Returns:
      The GraphDef returned by `graph_util.convert_variables_to_constants`.

    Raises:
      ValueError: If the checkpoint does not exist or no output node names
        were supplied.
    """
    del restore_op_name, filename_tensor_name  # Unused by updated loading code.

    # 'input_checkpoint' may be a prefix if we're using Saver V2 format
    checkpoint_found = saver_lib.checkpoint_exists(input_checkpoint)
    if not checkpoint_found:
        raise ValueError(
            'Input checkpoint "' + input_checkpoint + '" does not exist!')

    if not output_node_names:
        raise ValueError(
            'You must supply the name of a node to --output_node_names.')

    # Remove all the explicit device specifications for this node. This helps
    # to make the graph more portable.
    if clear_devices:
        for graph_node in input_graph_def.node:
            graph_node.device = ''

    _ = importer.import_graph_def(input_graph_def, name='')

    with session.Session() as sess:
        if input_saver_def:
            restorer = saver_lib.Saver(saver_def=input_saver_def)
            restorer.restore(sess, input_checkpoint)
        else:
            restorable = {}
            checkpoint_reader = pywrap_tensorflow.NewCheckpointReader(
                input_checkpoint)
            for var_name in checkpoint_reader.get_variable_to_shape_map():
                try:
                    restorable[var_name] = sess.graph.get_tensor_by_name(
                        var_name + ':0')
                except KeyError:
                    # This tensor doesn't exist in the graph (for example it's
                    # 'global_step' or a similar housekeeping element) so
                    # skip it.
                    continue
            restorer = saver_lib.Saver(var_list=restorable)
            restorer.restore(sess, input_checkpoint)
            if initializer_nodes:
                sess.run(initializer_nodes)

        if variable_names_blacklist:
            variable_names_blacklist = variable_names_blacklist.split(',')
        else:
            variable_names_blacklist = None

        output_graph_def = graph_util.convert_variables_to_constants(
            sess,
            input_graph_def,
            output_node_names.split(','),
            variable_names_blacklist=variable_names_blacklist)

    return output_graph_def
def create_model(session, state):
    """Create translation model and initialize or load parameters in session.

    Rebinds the module-level `_buckets` to the four-bucket configuration
    before building the model.

    Args:
      session: Session used to restore the checkpoint or to run the global
        variable initializer.
      state: Forwarded to `paradet_model.QualVecModel` as `state`.

    Returns:
      The constructed `QualVecModel`, restored from the checkpoint recorded
      in `FLAGS.train_dir` when one exists, otherwise freshly initialized.
    """
    global _buckets
    # if FLAGS.qualvec or FLAGS.translate:
    #     _buckets = [(5, 10), (10, 15), (20, 25), (40, 50), (60, 70),
    #                 (70, 80), (90, 100)]
    # else:
    _buckets = [(5, 10), (10, 15), (20, 25), (40, 50)]

    dtype = tf.float16 if FLAGS.use_fp16 else tf.float32
    model = paradet_model.QualVecModel(
        FLAGS.source_vocab_size,
        FLAGS.target_vocab_size,
        FLAGS.embedding_size,
        _buckets,
        FLAGS.size,
        FLAGS.maxout_size,
        FLAGS.num_layers,
        FLAGS.max_gradient_norm,
        FLAGS.batch_size,
        FLAGS.learning_rate,
        FLAGS.learning_rate_decay_factor,
        state=state,
        dtype=dtype)

    ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)
    if ckpt and save_mod.checkpoint_exists(ckpt.model_checkpoint_path):
        print("Reading model parameters from %s" % ckpt.model_checkpoint_path)
        model.saver.restore(session, ckpt.model_checkpoint_path)
    else:
        print("Created model with fresh parameters.")
        session.run(tf.global_variables_initializer())
    return model
def create_model(session, state):
    """Create autoencoder model and initialize or load parameters in session.

    Args:
      session: Session used either to restore saved parameters or to run the
        global variable initializer.
      state: Forwarded to `autoencode_model.AutoencodeModel` as `state`.

    Returns:
      The constructed `AutoencodeModel`.
    """
    # dtype = tf.float16 if FLAGS.use_fp16 else tf.float32
    dtype = tf.float32
    model = autoencode_model.AutoencodeModel(
        FLAGS.source_vocab_size,
        FLAGS.target_vocab_size,
        FLAGS.embedding_size,
        _buckets,
        FLAGS.size,
        FLAGS.num_layers,
        FLAGS.max_gradient_norm,
        FLAGS.batch_size,
        FLAGS.learning_rate,
        FLAGS.learning_rate_decay_factor,
        state=state,
        dtype=dtype)

    checkpoint_state = tf.train.get_checkpoint_state(FLAGS.train_dir)
    can_restore = checkpoint_state and save_mod.checkpoint_exists(
        checkpoint_state.model_checkpoint_path)

    # Guard clause: without a usable checkpoint, start from fresh variables.
    if not can_restore:
        print("Created model with fresh parameters.")
        session.run(tf.global_variables_initializer())
        return model

    print("Reading model parameters from %s"
          % checkpoint_state.model_checkpoint_path)
    model.saver.restore(session, checkpoint_state.model_checkpoint_path)
    return model
def _read_latest_config_files(self, run_path_pairs):
    """Reads and returns the projector config files in every run directory.

    Args:
      run_path_pairs: Iterable of `(run_name, logdir)` pairs.

    Returns:
      A `(configs, config_fpaths)` tuple of dicts keyed by run name. A run is
      omitted when it has neither a checkpoint nor any embedding with a
      `tensor_path`, or when its configured checkpoint does not exist.
    """
    configs = {}
    config_fpaths = {}
    for run_name, logdir in run_path_pairs:
        config = ProjectorConfig()
        config_fpath = os.path.join(logdir, PROJECTOR_FILENAME)
        if file_io.file_exists(config_fpath):
            text_format.Merge(
                file_io.read_file_to_string(config_fpath), config)

        has_tensor_files = any(
            embedding.tensor_path for embedding in config.embeddings)

        if not config.model_checkpoint_path:
            # See if you can find a checkpoint file in the logdir.
            ckpt_path = _find_latest_checkpoint(logdir)
            if ckpt_path:
                config.model_checkpoint_path = ckpt_path
            elif not has_tensor_files:
                continue

        # Sanity check for the checkpoint file.
        checkpoint_path = config.model_checkpoint_path
        if checkpoint_path and not checkpoint_exists(checkpoint_path):
            logging.warning('Checkpoint file %s not found', checkpoint_path)
            continue

        configs[run_name] = config
        config_fpaths[run_name] = config_fpath
    return configs, config_fpaths
def freeze_graph(input_graph: str,
                 input_checkpoint: str,
                 output_node_names: Iterable[str],
                 output_graph: str):
    """Freeze a graph: convert its variables to constants and save it.

    Converts all variables in a graph and checkpoint into constants and
    saves the new graph to the specified file. Additionally, the graph is
    pruned of all nodes that are not needed for the specified outputs.

    NOTE: this function creates new nodes in the default graph, you may want
    to wrap the call with the following code::

        with tf.Graph().as_default():
            freeze_graph(...)

    :param input_graph: path to the input graph file
    :param input_checkpoint: path to the input checkpoint
    :param output_node_names: iterable collection of output node names
    :param output_graph: path to the output frozen graph file

    Raises:
      ValueError: if any of the specified files does not exist
    """
    if not gfile.Exists(input_graph):
        raise ValueError(
            'Input graph file `{}` does not exist!'.format(input_graph))

    # 'input_checkpoint' may be a prefix if we're using Saver V2 format
    if not saver_lib.checkpoint_exists(input_checkpoint):
        raise ValueError(
            'Input checkpoint `{}` does not exist!'.format(input_checkpoint))

    # read the graph definition
    input_graph_def = graph_pb2.GraphDef()
    with gfile.FastGFile(input_graph, 'rb') as file:
        input_graph_def.ParseFromString(file.read())

    # remove all the explicit device specifications for this node. This helps
    # to make the graph more portable.
    for node in input_graph_def.node:
        node.device = ''

    # restore the input graph and checkpoint
    _ = importer.import_graph_def(input_graph_def, name='')

    with session.Session() as sess:
        var_list = {}
        reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
        var_to_shape_map = reader.get_variable_to_shape_map()
        for key in var_to_shape_map:
            try:
                tensor = sess.graph.get_tensor_by_name(key + ':0')
            except KeyError:
                # The checkpoint can hold entries with no matching tensor in
                # the graph (for example 'global_step' or a similar
                # housekeeping element); skip them instead of aborting.
                continue
            var_list[key] = tensor
        saver = saver_lib.Saver(var_list=var_list)
        saver.restore(sess, input_checkpoint)

        # convert all the variables to constants
        with DisabledLogger('tensorflow'), DisabledPrint():
            output_graph_def = graph_util.convert_variables_to_constants(
                sess, input_graph_def, output_node_names)

    with gfile.GFile(output_graph, 'wb') as file:
        file.write(output_graph_def.SerializeToString())
def _read_latest_config_files(self, run_path_pairs):
    """Reads and returns the projector config files in every run directory.

    Args:
      run_path_pairs: Iterable of `(run_name, assets_dir)` pairs.

    Returns:
      A `(configs, config_fpaths)` tuple of dicts keyed by run name. Runs
      with neither a checkpoint nor an embedding `tensor_path`, and runs
      whose configured checkpoint is missing, are left out.
    """
    configs = {}
    config_fpaths = {}
    for run_name, assets_dir in run_path_pairs:
        config = projector_config_pb2.ProjectorConfig()
        config_fpath = os.path.join(assets_dir, PROJECTOR_FILENAME)
        if file_io.file_exists(config_fpath):
            text_format.Merge(
                file_io.read_file_to_string(config_fpath), config)

        has_tensor_files = any(
            embedding.tensor_path for embedding in config.embeddings)

        if not config.model_checkpoint_path:
            # See if you can find a checkpoint file in the logdir.
            ckpt_path = _find_latest_checkpoint(
                _assets_dir_to_logdir(assets_dir))
            if ckpt_path:
                config.model_checkpoint_path = ckpt_path
            elif not has_tensor_files:
                continue

        # Sanity check for the checkpoint file.
        checkpoint_path = config.model_checkpoint_path
        if checkpoint_path and not checkpoint_exists(checkpoint_path):
            logging.warning('Checkpoint file "%s" not found', checkpoint_path)
            continue

        configs[run_name] = config
        config_fpaths[run_name] = config_fpath
    return configs, config_fpaths
def _assert_export(self, export_monitor, export_dir, expected_signature):
    """Asserts that the step-1 and step-10 exports exist and are well-formed.

    Args:
      export_monitor: Monitor whose `last_export_dir` is compared against
        the step-10 export directory.
      export_dir: Base export directory; a trailing separator is optional.
      expected_signature: Field name expected to be set on the default
        signature of the step-10 export.
    """
    self.assertTrue(gfile.Exists(export_dir))

    # os.path.join inserts a separator only when `export_dir` lacks one, so
    # these paths agree with the `last_export_dir` comparison below either way.
    first_export = os.path.join(export_dir, '00000001/export')
    last_export = os.path.join(export_dir, '00000010/export')

    # Only the written checkpoints are exported.
    self.assertTrue(
        saver.checkpoint_exists(first_export),
        'Exported checkpoint expected but not found: %s' % first_export)
    self.assertTrue(
        saver.checkpoint_exists(last_export),
        'Exported checkpoint expected but not found: %s' % last_export)

    # `assertEquals` is a deprecated alias; use `assertEqual`.
    self.assertEqual(
        six.b(os.path.join(export_dir, '00000010')),
        export_monitor.last_export_dir)

    # Validate the signature
    signature = self._get_default_signature(
        os.path.join(export_dir, '00000010/export.meta'))
    self.assertTrue(signature.HasField(expected_signature))
def _assert_export(self, export_monitor, export_dir, expected_signature):
    """Verifies the exports written under `export_dir`.

    Checks that the export directory exists, that export checkpoints for
    steps 1 and 10 exist, that the monitor reports the step-10 directory as
    its last export, and that the step-10 default signature has the
    `expected_signature` field set.

    Args:
      export_monitor: Monitor exposing `last_export_dir`.
      export_dir: Base export directory, with or without a trailing
        separator.
      expected_signature: Name of the signature field that must be present.
    """
    self.assertTrue(gfile.Exists(export_dir))

    # Only the written checkpoints are exported.
    for step_dir in ('00000001', '00000010'):
        # Joining (rather than concatenating) keeps the path correct when
        # `export_dir` has no trailing separator.
        export_path = os.path.join(export_dir, step_dir + '/export')
        self.assertTrue(
            saver.checkpoint_exists(export_path),
            'Exported checkpoint expected but not found: %s' % export_path)

    # `assertEqual` replaces the deprecated `assertEquals` alias.
    expected_last_dir = six.b(os.path.join(export_dir, '00000010'))
    self.assertEqual(expected_last_dir, export_monitor.last_export_dir)

    # Validate the signature
    meta_path = os.path.join(export_dir, '00000010/export.meta')
    signature = self._get_default_signature(meta_path)
    self.assertTrue(signature.HasField(expected_signature))
def freeze_graph(layer, unit, input_names, output_names, accuracy):
    """Freezes the `data/{layer}layer{unit}unit.ckpt` checkpoint to a .pb file.

    Loads the meta graph and checkpoint, converts variables to constants,
    runs `opt_inference` on the result, and writes both a binary graph and a
    text dump under `data/`.

    Args:
      layer: Layer count; used in the file names and, converted with
        `int()`, to enumerate the RNN cells whose weights are dumped.
      unit: Unit count; used in the file names.
      input_names: Comma-separated input node names.
      output_names: Comma-separated output node names.
      accuracy: Inserted into the text dump's file name.

    The process exits with status -1 when the checkpoint does not exist.
    """
    frozen_model_path = "data/{}layer{}unit.pb".format(layer, unit)
    checkpoint_file = "data/{}layer{}unit.ckpt".format(layer, unit)
    if not saver_lib.checkpoint_exists(checkpoint_file):
        print("Checkpoint file '" + checkpoint_file + "' doesn't exist!")
        exit(-1)

    print("begin loading model")
    saver = tf.train.import_meta_graph(checkpoint_file + '.meta',
                                       clear_devices=True)

    # We retrieve the protobuf graph definition
    graph = tf.get_default_graph()
    input_graph_def = graph.as_graph_def()

    with tf.Session() as sess:
        saver.restore(sess, checkpoint_file)
        print("model loaded")

        # export network weights and biases to text files
        # (`__debug__` is False only under `python -O`, so this dump runs
        # only in optimized mode)
        if not __debug__:
            output_nodes = "w_in,b_in,w_out,b_out"
            # Use this function's own `layer` argument instead of the global
            # `args`, so the dumped cells match the checkpoint named above.
            rnn_nodes = [
                ",rnn/multi_rnn_cell/cell_{}/basic_lstm_cell/weights,"
                "rnn/multi_rnn_cell/cell_{}/basic_lstm_cell/biases".format(
                    i, i)
                for i in range(int(layer))
            ]
            weights = output_nodes + "".join(rnn_nodes)
            for name in weights.split(","):
                v = sess.run("{}:0".format(name))
                var_file_name = "data/{}.csv".format(name.replace("/", "_"))
                print("save {} to file: {}".format(name, var_file_name))
                np.savetxt(var_file_name, v, delimiter=",")

        # We use a built-in TF helper to export variables to constants
        output_graph_def = graph_util.convert_variables_to_constants(
            sess,  # The session is used to retrieve the weights
            input_graph_def,  # The graph_def is used to retrieve the nodes
            # The output node names are used to select the useful nodes
            output_node_names=output_names.split(","))

        # optimize graph
        output_graph_def = opt_inference(output_graph_def,
                                         input_names.split(","),
                                         output_names.split(","),
                                         dtypes.float32.as_datatype_enum)

        # Finally we serialize and dump the output graph to the filesystem
        with tf.gfile.GFile(frozen_model_path, "wb") as f:
            f.write(output_graph_def.SerializeToString())
        print("frozen graph binary saved to: {}".format(frozen_model_path))

        frozen_model_text_path = "data/{}layer{}unit{}.pb.txt".format(
            layer, unit, accuracy)
        with tf.gfile.FastGFile(frozen_model_text_path, "wb") as f:
            f.write(str(output_graph_def))
        print("frozen graph text saved to: {}".format(
            frozen_model_text_path))

        print("%d ops in the final graph." % len(output_graph_def.node))
def freeze_graph(config):
    """Freezes the checkpoint at `config.input_path` into a frozen graph.

    Reads `input_names`, `output_names`, `output_path`, `input_path` and
    `clear_device` from `config`. Writes `frozen_model.pb` and
    `frozen_model.pb.txt` into `config.output_path`, creating that directory
    when missing. When `config.input_names` is not None the frozen graph is
    additionally passed through `opt_inference`. Exits with status -1 when
    the checkpoint does not exist.
    """
    input_names = config.input_names
    output_names = config.output_names
    if output_names is None:
        output_names = [
            "eval_concat/yp",
            "eval_concat/yp2",
            "eval_concat/wy",
            "eval_concat/loss",
        ]

    if not os.path.exists(config.output_path):
        os.makedirs(config.output_path)
    frozen_model_path = os.path.join(config.output_path, "frozen_model.pb")

    checkpoint_file = config.input_path
    if not saver_lib.checkpoint_exists(checkpoint_file):
        print("Checkpoint file '" + checkpoint_file + "' doesn't exist!")
        exit(-1)

    print("begin loading model")
    saver = tf.train.import_meta_graph(checkpoint_file + '.meta',
                                       clear_devices=config.clear_device)
    input_graph_def = tf.get_default_graph().as_graph_def()

    session_config = tf.ConfigProto(allow_soft_placement=True)
    with tf.Session(config=session_config) as sess:
        saver.restore(sess, checkpoint_file)
        print("model loaded")

        frozen_graph_def = graph_util.convert_variables_to_constants(
            sess, input_graph_def, output_node_names=output_names)

        # The inference optimization pass is applied only when input names
        # were supplied.
        if input_names is not None:
            frozen_graph_def = opt_inference(frozen_graph_def, input_names,
                                             output_names,
                                             dtypes.float32.as_datatype_enum)

        with tf.gfile.GFile(frozen_model_path, "wb") as binary_file:
            binary_file.write(frozen_graph_def.SerializeToString())
        print("frozen graph binary saved to: {}".format(frozen_model_path))

        frozen_model_text_path = "{}.txt".format(frozen_model_path)
        with tf.gfile.FastGFile(frozen_model_text_path, "wb") as text_file:
            text_file.write(str(frozen_graph_def))
        print("frozen graph text saved to: {}".format(
            frozen_model_text_path))

        print("%d ops in the final graph." % len(frozen_graph_def.node))
def _wait_for_glob(self, pattern, timeout_secs, for_checkpoint=True):
    """Polls until `pattern` matches, failing the test on timeout.

    Args:
      pattern: A string. Checked with `saver_lib.checkpoint_exists` when
        `for_checkpoint` is true, otherwise with `gfile.Glob`.
      timeout_secs: How long to keep polling, in seconds.
      for_checkpoint: whether we're globbing for checkpoints.
    """
    deadline = time.time() + timeout_secs
    while time.time() < deadline:
        if for_checkpoint:
            matched = saver_lib.checkpoint_exists(pattern)
        else:
            matched = len(gfile.Glob(pattern)) >= 1
        if matched:
            return
        # Short pause between polls.
        time.sleep(0.05)
    self.assertFalse(True, "Glob never matched any file: %s" % pattern)
def scan_graph(input_checkpoint=None,
               input_saved_model_dir=None,
               saved_model_tags=tag_constants.SERVING):
    """Extracts the graph to scan from a model file and runs `detect_ops`.

    Exactly one of `input_checkpoint` and `input_saved_model_dir` must be
    given.

    Args:
      input_checkpoint: Checkpoint path; its graph is read from the
        `<input_checkpoint>.meta` file.
      input_saved_model_dir: SavedModel directory.
      saved_model_tags: Tags passed to `saved_model_utils.get_meta_graph_def`.

    Returns:
      The result of `detect_ops` on the extracted GraphDef, or -1 when the
      arguments are invalid, a file is missing, or parsing fails.
    """
    if not input_saved_model_dir and not input_checkpoint:
        print("Please specify a checkpoint or 'SavedModel' file!")
        return -1
    if input_saved_model_dir and input_checkpoint:
        # Adjacent literals instead of a backslash continuation inside the
        # string, which leaked stray characters into the message.
        print("Please specify only *One* model file type: "
              "checkpoint or 'SavedModel'!")
        return -1

    input_graph_def = None

    if input_checkpoint:
        # The variables file is not used here, but is still checked for
        # completeness.
        if not saver_lib.checkpoint_exists(input_checkpoint):
            print("Input checkpoint '" + input_checkpoint +
                  "' doesn't exist!")
            return -1

        # Build meta file path for a checkpoint
        meta_file = input_checkpoint + ".meta"
        if not gfile.Exists(meta_file):
            print("Input checkpoint meta file '" + meta_file +
                  "' doesn't exist!")
            return -1

        try:
            input_graph_def = _parse_input_meta_graph_proto(
                meta_file, True).graph_def
        except Exception as err:
            # `Exception` rather than a bare `except`, so KeyboardInterrupt
            # and SystemExit still propagate.
            print("Parse checkpoint meta-graph file '%s' failed: %s(%s)" %
                  (meta_file, type(err), err))
            return -1

    if input_saved_model_dir:
        try:
            input_graph_def = saved_model_utils.get_meta_graph_def(
                input_saved_model_dir, saved_model_tags).graph_def
        except Exception as err:
            print("Parse SaveModel '%s' meta-graph file failed: %s(%s)" %
                  (input_saved_model_dir, type(err), err))
            return -1

    return detect_ops(input_graph_def)
def _read_config_files(self, run_paths, logdir):
    """Reads the projector config of every run directory.

    Args:
      run_paths: Dict mapping run name to run directory. When empty, the
        entry `'.' -> logdir` is added to it in place.
      logdir: Fallback directory used when `run_paths` is empty.

    Returns:
      A `(configs, config_fpaths)` tuple of dicts keyed by run name. Runs
      with neither a checkpoint nor an embedding `tensor_path`, and runs
      whose configured checkpoint is missing, are left out.
    """
    # If there are no summary event files, the projector can still work,
    # thus treating the `logdir` as the model checkpoint directory.
    if not run_paths:
        run_paths['.'] = logdir

    configs = {}
    config_fpaths = {}
    for run_name, run_logdir in run_paths.items():
        config = ProjectorConfig()
        config_fpath = os.path.join(run_logdir, PROJECTOR_FILENAME)
        if file_io.file_exists(config_fpath):
            file_content = file_io.read_file_to_string(
                config_fpath).decode('utf-8')
            text_format.Merge(file_content, config)

        has_tensor_files = any(
            embedding.tensor_path for embedding in config.embeddings)

        if not config.model_checkpoint_path:
            # See if you can find a checkpoint file in the logdir.
            ckpt_path = latest_checkpoint(run_logdir)
            if not ckpt_path:
                # Or in the parent of logdir. The parent is `<logdir>/..`;
                # joining '../' in front of the path does not yield it (and
                # is a no-op for absolute paths).
                ckpt_path = latest_checkpoint(
                    os.path.join(run_logdir, os.pardir))
            if not ckpt_path and not has_tensor_files:
                logging.warning('Cannot find model checkpoint in %s',
                                run_logdir)
                continue
            if ckpt_path:
                config.model_checkpoint_path = ckpt_path

        # Sanity check for the checkpoint file.
        if (config.model_checkpoint_path and
                not checkpoint_exists(config.model_checkpoint_path)):
            logging.warning('Checkpoint file %s not found',
                            config.model_checkpoint_path)
            continue

        configs[run_name] = config
        config_fpaths[run_name] = config_fpath
    return configs, config_fpaths
def _read_config_files(self, run_paths, logdir):
    """Loads projector configs for all runs, locating checkpoints as needed.

    Args:
      run_paths: Dict of run name to run directory. If it is empty it is
        updated in place with `'.'` mapped to `logdir`.
      logdir: Directory used as the only run when `run_paths` is empty.

    Returns:
      `(configs, config_fpaths)`: two dicts keyed by run name holding the
      parsed config and the path of its config file. A run is skipped when
      no checkpoint can be found and no embedding has a `tensor_path`, or
      when the checkpoint named in its config does not exist.
    """
    # If there are no summary event files, the projector can still work,
    # thus treating the `logdir` as the model checkpoint directory.
    if not run_paths:
        run_paths['.'] = logdir

    configs = {}
    config_fpaths = {}
    for run_name, run_dir in run_paths.items():
        config = ProjectorConfig()
        config_fpath = os.path.join(run_dir, PROJECTOR_FILENAME)
        if file_io.file_exists(config_fpath):
            file_content = file_io.read_file_to_string(config_fpath).decode(
                'utf-8')
            text_format.Merge(file_content, config)

        has_tensor_files = False
        for embedding in config.embeddings:
            if embedding.tensor_path:
                has_tensor_files = True
                break

        if not config.model_checkpoint_path:
            # See if you can find a checkpoint file in the logdir.
            ckpt_path = latest_checkpoint(run_dir)
            if not ckpt_path:
                # Or in the parent of logdir. Append the parent reference to
                # the run directory; prefixing '../' resolves relative to
                # the working directory and is ignored for absolute paths.
                parent_dir = os.path.join(run_dir, os.pardir)
                ckpt_path = latest_checkpoint(parent_dir)
            if not ckpt_path and not has_tensor_files:
                logging.warning('Cannot find model checkpoint in %s', run_dir)
                continue
            if ckpt_path:
                config.model_checkpoint_path = ckpt_path

        # Sanity check for the checkpoint file.
        if (config.model_checkpoint_path and
                not checkpoint_exists(config.model_checkpoint_path)):
            logging.warning('Checkpoint file %s not found',
                            config.model_checkpoint_path)
            continue

        configs[run_name] = config
        config_fpaths[run_name] = config_fpath
    return configs, config_fpaths
def freeza_graph_with_def_protos(input_graph_def,input_saver_def,input_checkpoint,output_node_names,restore_op_name,filename_tensor_name,output_graph,clear_devices,initializer_nodes,variable_names_blacklist=""): """Converts all variables in a graph and checkpoint into constants.""" del restore_op_name,filename_tensor_name if not saver_lib.checkpoint_exists(input_checkpoint): #???? print("Input checkpoint'" + input_checkpoint + " 'doesn't exist!") return -1 if not output_node_names: print("you need to supply the name of a node to --output_node_names") return -1 if clear_devices: #??? for node in input_graph_def.node: #input_graph_def存的是图结构??? node.device = "" _ = importer.import_graph_def(input_graph_def,name="") #??? #options ops:图上的结点 with session.Session() as sess: if input_saver_def: #input_saver_def存的是variable???
def freeze_graph_with_def_protos(input_graph_def,
                                 input_saver_def,
                                 input_checkpoint,
                                 output_node_names,
                                 restore_op_name,
                                 filename_tensor_name,
                                 output_graph,
                                 clear_devices,
                                 initializer_nodes,
                                 variable_names_whitelist="",
                                 variable_names_blacklist="",
                                 input_meta_graph_def=None,
                                 input_saved_model_dir=None,
                                 saved_model_tags=None,
                                 checkpoint_version=saver_pb2.SaverDef.V2):
    """Converts all variables in a graph and checkpoint into constants.

    Variables are loaded from, in order of precedence: `input_saver_def`,
    `input_meta_graph_def`, `input_saved_model_dir`, or a Saver built from
    the checkpoint's own variable list. Comma-separated name arguments have
    spaces stripped before being split.

    Returns:
      The frozen GraphDef, or -1 when the checkpoint is missing (and no
      SavedModel directory was given) or `output_node_names` is empty. The
      GraphDef is also written to `output_graph` when that is non-empty.
    """
    del restore_op_name, filename_tensor_name  # Unused by updated loading code.

    def _split_names(csv_names):
        # "a, b,c" -> ["a", "b", "c"]
        return csv_names.replace(" ", "").split(",")

    # 'input_checkpoint' may be a prefix if we're using Saver V2 format
    if not (input_saved_model_dir or
            saver_lib.checkpoint_exists(input_checkpoint)):
        print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
        return -1

    if not output_node_names:
        print("You need to supply the name of a node to --output_node_names.")
        return -1

    # Remove all the explicit device specifications for this node. This helps
    # to make the graph more portable.
    if clear_devices:
        if input_meta_graph_def:
            for graph_node in input_meta_graph_def.graph_def.node:
                graph_node.device = ""
        elif input_graph_def:
            for graph_node in input_graph_def.node:
                graph_node.device = ""

    if input_graph_def:
        _ = importer.import_graph_def(input_graph_def, name="")

    with session.Session() as sess:
        if input_saver_def:
            restorer = saver_lib.Saver(
                saver_def=input_saver_def, write_version=checkpoint_version)
            restorer.restore(sess, input_checkpoint)
        elif input_meta_graph_def:
            restorer = saver_lib.import_meta_graph(
                input_meta_graph_def, clear_devices=True)
            restorer.restore(sess, input_checkpoint)
            if initializer_nodes:
                sess.run(_split_names(initializer_nodes))
        elif input_saved_model_dir:
            if saved_model_tags is None:
                saved_model_tags = []
            loader.load(sess, saved_model_tags, input_saved_model_dir)
        else:
            restorable = {}
            checkpoint_reader = pywrap_tensorflow.NewCheckpointReader(
                input_checkpoint)
            for var_name in checkpoint_reader.get_variable_to_shape_map():
                try:
                    restorable[var_name] = sess.graph.get_tensor_by_name(
                        var_name + ":0")
                except KeyError:
                    # This tensor doesn't exist in the graph (for example it's
                    # 'global_step' or a similar housekeeping element) so
                    # skip it.
                    continue
            restorer = saver_lib.Saver(
                var_list=restorable, write_version=checkpoint_version)
            restorer.restore(sess, input_checkpoint)
            if initializer_nodes:
                sess.run(_split_names(initializer_nodes))

        if variable_names_whitelist:
            variable_names_whitelist = _split_names(variable_names_whitelist)
        else:
            variable_names_whitelist = None
        if variable_names_blacklist:
            variable_names_blacklist = _split_names(variable_names_blacklist)
        else:
            variable_names_blacklist = None

        # The meta graph's GraphDef takes precedence over `input_graph_def`.
        if input_meta_graph_def:
            source_graph_def = input_meta_graph_def.graph_def
        else:
            source_graph_def = input_graph_def
        output_graph_def = graph_util.convert_variables_to_constants(
            sess,
            source_graph_def,
            _split_names(output_node_names),
            variable_names_whitelist=variable_names_whitelist,
            variable_names_blacklist=variable_names_blacklist)

    # Write GraphDef to file if output path has been given.
    if output_graph:
        with gfile.GFile(output_graph, "wb") as out_file:
            out_file.write(output_graph_def.SerializeToString())

    return output_graph_def
def check_input_checkpoint(input_checkpoint):
    """Exits with status -1 unless `input_checkpoint` exists.

    Existence is decided by `saver_lib.checkpoint_exists`, so the argument
    may be a path or a path prefix.
    """
    if saver_lib.checkpoint_exists(input_checkpoint):
        return
    print("Input checkpoint '{}' doesn't exist!".format(input_checkpoint))
    exit(-1)
def freeze_graph_with_def_protos(
        input_graph_def,
        input_saver_def,
        input_checkpoint,
        output_node_names,
        restore_op_name,
        filename_tensor_name,
        output_graph,
        clear_devices,
        initializer_nodes,
        variable_names_blacklist=""):
    """Converts all variables in a graph and checkpoint into constants.

    The frozen GraphDef is serialized to `output_graph` and the number of
    nodes it contains is printed.

    Returns:
      -1 when the checkpoint does not exist or `output_node_names` is empty;
      otherwise nothing is returned explicitly.
    """
    del restore_op_name, filename_tensor_name  # Unused by updated loading code.

    # 'input_checkpoint' may be a prefix if we're using Saver V2 format
    checkpoint_found = saver_lib.checkpoint_exists(input_checkpoint)
    if not checkpoint_found:
        print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
        return -1

    if not output_node_names:
        print("You need to supply the name of a node to --output_node_names.")
        return -1

    # Remove all the explicit device specifications for this node. This helps
    # to make the graph more portable.
    if clear_devices:
        for graph_node in input_graph_def.node:
            graph_node.device = ""

    _ = importer.import_graph_def(input_graph_def, name="")

    with session.Session() as sess:
        if input_saver_def:
            restorer = saver_lib.Saver(saver_def=input_saver_def)
            restorer.restore(sess, input_checkpoint)
        else:
            restorable = {}
            checkpoint_reader = pywrap_tensorflow.NewCheckpointReader(
                input_checkpoint)
            for var_name in checkpoint_reader.get_variable_to_shape_map():
                try:
                    restorable[var_name] = sess.graph.get_tensor_by_name(
                        var_name + ":0")
                except KeyError:
                    # This tensor doesn't exist in the graph (for example it's
                    # 'global_step' or a similar housekeeping element) so
                    # skip it.
                    continue
            restorer = saver_lib.Saver(var_list=restorable)
            restorer.restore(sess, input_checkpoint)
            if initializer_nodes:
                sess.run(initializer_nodes)

        if variable_names_blacklist:
            variable_names_blacklist = variable_names_blacklist.split(",")
        else:
            variable_names_blacklist = None

        output_graph_def = graph_util.convert_variables_to_constants(
            sess,
            input_graph_def,
            output_node_names.split(","),
            variable_names_blacklist=variable_names_blacklist)

    with gfile.GFile(output_graph, "wb") as out_file:
        out_file.write(output_graph_def.SerializeToString())
    print("%d ops in the final graph." % len(output_graph_def.node))
def freeze_graph_with_def_protos(input_graph_def,
                                 input_saver_def,
                                 input_checkpoint,
                                 output_node_names,
                                 restore_op_name,
                                 filename_tensor_name,
                                 clear_devices,
                                 initializer_nodes,
                                 optimize_graph=True,
                                 variable_names_blacklist=''):
    """Converts all variables in a graph and checkpoint into constants.

    The graph is imported into a fresh `tf.Graph`. When `optimize_graph` is
    true the session is configured with the 'pruning', 'constfold' and
    'layout' rewriter optimizers.

    Returns:
      The GraphDef returned by `graph_util.convert_variables_to_constants`.

    Raises:
      ValueError: If the checkpoint does not exist or `output_node_names`
        is empty.
    """
    del restore_op_name, filename_tensor_name  # Unused by updated loading code.

    # 'input_checkpoint' may be a prefix if we're using Saver V2 format
    checkpoint_found = saver_lib.checkpoint_exists(input_checkpoint)
    if not checkpoint_found:
        raise ValueError('Input checkpoint "' + input_checkpoint +
                         '" does not exist!')

    if not output_node_names:
        raise ValueError(
            'You must supply the name of a node to --output_node_names.')

    # Remove all the explicit device specifications for this node. This helps
    # to make the graph more portable.
    if clear_devices:
        for graph_node in input_graph_def.node:
            graph_node.device = ''

    with tf.Graph().as_default():
        tf.import_graph_def(input_graph_def, name='')

        if not optimize_graph:
            logging.info('Graph Rewriter optimizations disabled')
            graph_options = tf.GraphOptions()
        else:
            logging.info('Graph Rewriter optimizations enabled')
            rewrite_options = rewriter_config_pb2.RewriterConfig(
                layout_optimizer=rewriter_config_pb2.RewriterConfig.ON)
            for optimizer_name in ('pruning', 'constfold', 'layout'):
                rewrite_options.optimizers.append(optimizer_name)
            graph_options = tf.GraphOptions(
                rewrite_options=rewrite_options, infer_shapes=True)

        session_config = tf.ConfigProto(graph_options=graph_options)
        with session.Session(config=session_config) as sess:
            if input_saver_def:
                restorer = saver_lib.Saver(saver_def=input_saver_def)
                restorer.restore(sess, input_checkpoint)
            else:
                restorable = {}
                checkpoint_reader = pywrap_tensorflow.NewCheckpointReader(
                    input_checkpoint)
                for var_name in checkpoint_reader.get_variable_to_shape_map():
                    try:
                        restorable[var_name] = sess.graph.get_tensor_by_name(
                            var_name + ':0')
                    except KeyError:
                        # This tensor doesn't exist in the graph (for example
                        # it's 'global_step' or a similar housekeeping
                        # element) so skip it.
                        continue
                restorer = saver_lib.Saver(var_list=restorable)
                restorer.restore(sess, input_checkpoint)
                if initializer_nodes:
                    sess.run(initializer_nodes)

            if variable_names_blacklist:
                variable_names_blacklist = variable_names_blacklist.split(',')
            else:
                variable_names_blacklist = None

            output_graph_def = graph_util.convert_variables_to_constants(
                sess,
                input_graph_def,
                output_node_names.split(','),
                variable_names_blacklist=variable_names_blacklist)

    return output_graph_def
def freeze_graph_with_def_protos(input_graph_def,
                                 input_saver_def,
                                 input_checkpoint,
                                 output_node_names,
                                 restore_op_name,
                                 filename_tensor_name,
                                 output_graph,
                                 clear_devices,
                                 initializer_nodes,
                                 variable_names_whitelist="",
                                 variable_names_blacklist="",
                                 input_meta_graph_def=None,
                                 input_saved_model_dir=None,
                                 saved_model_tags=None,
                                 checkpoint_version=saver_pb2.SaverDef.V2):
    """Converts all variables in a graph and checkpoint into constants.

    Variables are loaded from, in order of precedence: `input_saver_def`,
    `input_meta_graph_def`, `input_saved_model_dir`, or a Saver built from
    the checkpoint's own variable list. Comma-separated name arguments have
    spaces stripped before being split.

    Returns:
      The frozen GraphDef (also written to `output_graph` when non-empty),
      or -1 when the checkpoint is missing and no SavedModel directory was
      given, when `output_node_names` is empty, or when building the Saver
      raises TypeError and a partition variable was detected.
    """
    del restore_op_name, filename_tensor_name  # Unused by updated loading code.

    def _split_names(csv_names):
        # "a, b,c" -> ["a", "b", "c"]
        return csv_names.replace(" ", "").split(",")

    # 'input_checkpoint' may be a prefix if we're using Saver V2 format
    if not (input_saved_model_dir or
            saver_lib.checkpoint_exists(input_checkpoint)):
        print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
        return -1

    if not output_node_names:
        print("You need to supply the name of a node to --output_node_names.")
        return -1

    # Remove all the explicit device specifications for this node. This helps
    # to make the graph more portable.
    if clear_devices:
        if input_meta_graph_def:
            for graph_node in input_meta_graph_def.graph_def.node:
                graph_node.device = ""
        elif input_graph_def:
            for graph_node in input_graph_def.node:
                graph_node.device = ""

    if input_graph_def:
        _ = importer.import_graph_def(input_graph_def, name="")

    with session.Session() as sess:
        if input_saver_def:
            restorer = saver_lib.Saver(
                saver_def=input_saver_def, write_version=checkpoint_version)
            restorer.restore(sess, input_checkpoint)
        elif input_meta_graph_def:
            restorer = saver_lib.import_meta_graph(
                input_meta_graph_def, clear_devices=True)
            restorer.restore(sess, input_checkpoint)
            if initializer_nodes:
                sess.run(_split_names(initializer_nodes))
        elif input_saved_model_dir:
            if saved_model_tags is None:
                saved_model_tags = []
            loader.load(sess, saved_model_tags, input_saved_model_dir)
        else:
            restorable = {}
            checkpoint_reader = pywrap_tensorflow.NewCheckpointReader(
                input_checkpoint)
            var_to_shape_map = checkpoint_reader.get_variable_to_shape_map()

            # List of all partition variables. Because the condition is
            # heuristic based, the list could include false positives.
            partition_variable_names = [
                tensor.name.split(":")[0]
                for op in sess.graph.get_operations()
                for tensor in op.values()
                if re.search(r"/part_\d+/", tensor.name)
            ]

            has_partition_var = False
            for var_name in var_to_shape_map:
                try:
                    tensor = sess.graph.get_tensor_by_name(var_name + ":0")
                    if any(var_name in partition_name
                           for partition_name in partition_variable_names):
                        has_partition_var = True
                except KeyError:
                    # This tensor doesn't exist in the graph (for example it's
                    # 'global_step' or a similar housekeeping element) so
                    # skip it.
                    continue
                restorable[var_name] = tensor

            try:
                restorer = saver_lib.Saver(
                    var_list=restorable, write_version=checkpoint_version)
            except TypeError as e:
                # `var_list` is required to be a map of variable names to
                # Variable tensors. Partition variables are Identity tensors
                # that cannot be handled by Saver.
                if not has_partition_var:
                    raise e
                print("Models containing partition variables cannot be converted "
                      "from checkpoint files. Please pass in a SavedModel using "
                      "the flag --input_saved_model_dir.")
                return -1

            restorer.restore(sess, input_checkpoint)
            if initializer_nodes:
                sess.run(_split_names(initializer_nodes))

        if variable_names_whitelist:
            variable_names_whitelist = _split_names(variable_names_whitelist)
        else:
            variable_names_whitelist = None
        if variable_names_blacklist:
            variable_names_blacklist = _split_names(variable_names_blacklist)
        else:
            variable_names_blacklist = None

        # The meta graph's GraphDef takes precedence over `input_graph_def`.
        if input_meta_graph_def:
            source_graph_def = input_meta_graph_def.graph_def
        else:
            source_graph_def = input_graph_def
        output_graph_def = graph_util.convert_variables_to_constants(
            sess,
            source_graph_def,
            _split_names(output_node_names),
            variable_names_whitelist=variable_names_whitelist,
            variable_names_blacklist=variable_names_blacklist)

    # Write GraphDef to file if output path has been given.
    if output_graph:
        with gfile.GFile(output_graph, "wb") as out_file:
            out_file.write(output_graph_def.SerializeToString())

    return output_graph_def
def freeze_graph_with_def_protos(input_graph_def,
                                 input_saver_def,
                                 input_checkpoint,
                                 output_node_names,
                                 restore_op_name,
                                 filename_tensor_name,
                                 output_graph,
                                 clear_devices,
                                 initializer_nodes,
                                 variable_names_whitelist="",
                                 variable_names_blacklist="",
                                 input_meta_graph_def=None,
                                 input_saved_model_dir=None,
                                 saved_model_tags=None,
                                 checkpoint_version=saver_pb2.SaverDef.V2):
    """Converts all variables in a graph and checkpoint into constants.

    All comma-separated string arguments tolerate spaces around the commas
    (for example "a, b"); spaces are removed before the string is split.

    Args:
      input_graph_def: GraphDef imported into the default graph when truthy.
      input_saver_def: Optional SaverDef. When truthy, a Saver built from it
        restores `input_checkpoint`.
      input_checkpoint: Checkpoint to restore; may be a prefix when the Saver
        V2 format is used.
      output_node_names: Comma-separated names of the output nodes.
      restore_op_name: Unused.
      filename_tensor_name: Unused.
      output_graph: Path the serialized result is written to. Nothing is
        written when this is falsy.
      clear_devices: When truthy, the `device` field of every node is blanked.
      initializer_nodes: Comma-separated names of nodes to run after the
        checkpoint is restored. Only used on the meta-graph path and on the
        checkpoint-reader fallback path.
      variable_names_whitelist: Comma-separated names forwarded to
        `graph_util.convert_variables_to_constants`; empty is sent as None.
      variable_names_blacklist: Comma-separated names forwarded to
        `graph_util.convert_variables_to_constants`; empty is sent as None.
      input_meta_graph_def: Optional MetaGraphDef. Used for restoring when no
        `input_saver_def` is given, and as the graph to convert.
      input_saved_model_dir: Optional SavedModel directory. When truthy the
        checkpoint existence check is skipped.
      saved_model_tags: Tags passed to `loader.load`; None is treated as an
        empty list.
      checkpoint_version: Passed as `write_version` to the Savers built here.

    Returns:
      The GraphDef produced by `convert_variables_to_constants`, or -1 when
      the checkpoint does not exist or `output_node_names` is empty.
    """
    del restore_op_name, filename_tensor_name  # Unused by updated loading code.

    # 'input_checkpoint' may be a prefix if we're using Saver V2 format.
    if (not input_saved_model_dir and
            not saver_lib.checkpoint_exists(input_checkpoint)):
        print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
        return -1

    if not output_node_names:
        print("You need to supply the name of a node to --output_node_names.")
        return -1

    # Remove all the explicit device specifications for this node. This helps
    # to make the graph more portable.
    if clear_devices:
        if input_meta_graph_def:
            for node in input_meta_graph_def.graph_def.node:
                node.device = ""
        elif input_graph_def:
            for node in input_graph_def.node:
                node.device = ""

    if input_graph_def:
        _ = importer.import_graph_def(input_graph_def, name="")

    with session.Session() as sess:
        if input_saver_def:
            saver = saver_lib.Saver(saver_def=input_saver_def,
                                    write_version=checkpoint_version)
            saver.restore(sess, input_checkpoint)
        elif input_meta_graph_def:
            restorer = saver_lib.import_meta_graph(
                input_meta_graph_def, clear_devices=True)
            restorer.restore(sess, input_checkpoint)
            if initializer_nodes:
                sess.run(initializer_nodes.replace(" ", "").split(","))
        elif input_saved_model_dir:
            if saved_model_tags is None:
                saved_model_tags = []
            loader.load(sess, saved_model_tags, input_saved_model_dir)
        else:
            # No saver information was supplied: map every checkpoint key
            # that has a matching "<key>:0" tensor in the graph.
            var_list = {}
            reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
            var_to_shape_map = reader.get_variable_to_shape_map()
            for key in var_to_shape_map:
                try:
                    tensor = sess.graph.get_tensor_by_name(key + ":0")
                except KeyError:
                    # This tensor doesn't exist in the graph (for example
                    # it's 'global_step' or a similar housekeeping element)
                    # so skip it.
                    continue
                var_list[key] = tensor
            saver = saver_lib.Saver(var_list=var_list,
                                    write_version=checkpoint_version)
            saver.restore(sess, input_checkpoint)
            if initializer_nodes:
                sess.run(initializer_nodes.replace(" ", "").split(","))

        variable_names_whitelist = (
            variable_names_whitelist.replace(" ", "").split(",")
            if variable_names_whitelist else None)
        variable_names_blacklist = (
            variable_names_blacklist.replace(" ", "").split(",")
            if variable_names_blacklist else None)

        # The meta graph, when present, takes precedence as the graph to
        # convert.
        if input_meta_graph_def:
            output_graph_def = graph_util.convert_variables_to_constants(
                sess,
                input_meta_graph_def.graph_def,
                output_node_names.replace(" ", "").split(","),
                variable_names_whitelist=variable_names_whitelist,
                variable_names_blacklist=variable_names_blacklist)
        else:
            output_graph_def = graph_util.convert_variables_to_constants(
                sess,
                input_graph_def,
                output_node_names.replace(" ", "").split(","),
                variable_names_whitelist=variable_names_whitelist,
                variable_names_blacklist=variable_names_blacklist)

    # Write GraphDef to file if output path has been given.
    if output_graph:
        with gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())

    return output_graph_def
def freeze_graph(input_graph,
                 input_saver,
                 input_binary,
                 input_checkpoint,
                 output_node_names,
                 restore_op_name,
                 filename_tensor_name,
                 output_graph,
                 clear_devices,
                 initializer_nodes,
                 variable_names_blacklist=""):
    """Converts all variables in a graph and checkpoint into constants.

    Args:
      input_graph: Path of the GraphDef file to freeze.
      input_saver: Optional path of a SaverDef file used to restore the
        checkpoint.
      input_binary: Truthy when `input_graph` (and `input_saver`) are binary
        protos; otherwise they are parsed as text protos.
      input_checkpoint: Checkpoint to restore; may be a prefix when the Saver
        V2 format is used.
      output_node_names: Comma-separated names of the output nodes.
      restore_op_name: Unused.
      filename_tensor_name: Unused.
      output_graph: Path the serialized frozen graph is written to.
      clear_devices: When truthy, the `device` field of every node is blanked.
      initializer_nodes: Comma-separated names of nodes to run after the
        checkpoint is restored on the no-saver path. A value without a
        `split` method is handed to `sess.run` unchanged.
      variable_names_blacklist: Comma-separated names forwarded to
        `graph_util.convert_variables_to_constants`; empty is sent as None.

    Returns:
      -1 when an input file or the checkpoint does not exist, or when
      `output_node_names` is empty. Otherwise None, after the frozen graph
      has been written to `output_graph`.
    """
    del restore_op_name, filename_tensor_name  # Unused by updated loading code.

    if not gfile.Exists(input_graph):
        print("Input graph file '" + input_graph + "' does not exist!")
        return -1

    if input_saver and not gfile.Exists(input_saver):
        print("Input saver file '" + input_saver + "' does not exist!")
        return -1

    # 'input_checkpoint' may be a prefix if we're using Saver V2 format.
    if not saver_lib.checkpoint_exists(input_checkpoint):
        print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
        return -1

    if not output_node_names:
        print("You need to supply the name of a node to --output_node_names.")
        return -1

    input_graph_def = graph_pb2.GraphDef()
    mode = "rb" if input_binary else "r"
    with gfile.FastGFile(input_graph, mode) as f:
        if input_binary:
            input_graph_def.ParseFromString(f.read())
        else:
            text_format.Merge(f.read(), input_graph_def)

    # Remove all the explicit device specifications for this node. This helps
    # to make the graph more portable.
    if clear_devices:
        for node in input_graph_def.node:
            node.device = ""

    _ = importer.import_graph_def(input_graph_def, name="")

    with session.Session() as sess:
        if input_saver:
            # The saver file is read with the same mode as the graph file.
            with gfile.FastGFile(input_saver, mode) as f:
                saver_def = saver_pb2.SaverDef()
                if input_binary:
                    saver_def.ParseFromString(f.read())
                else:
                    text_format.Merge(f.read(), saver_def)
                saver = saver_lib.Saver(saver_def=saver_def)
                saver.restore(sess, input_checkpoint)
        else:
            # No saver file: map every checkpoint key that has a matching
            # "<key>:0" tensor in the graph.
            var_list = {}
            reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
            var_to_shape_map = reader.get_variable_to_shape_map()
            for key in var_to_shape_map:
                try:
                    tensor = sess.graph.get_tensor_by_name(key + ":0")
                except KeyError:
                    # This tensor doesn't exist in the graph (for example
                    # it's 'global_step' or a similar housekeeping element)
                    # so skip it.
                    continue
                var_list[key] = tensor
            saver = saver_lib.Saver(var_list=var_list)
            saver.restore(sess, input_checkpoint)
            if initializer_nodes:
                # A comma-separated string names several nodes and must be
                # split into individual fetches; anything that is not
                # string-like is passed through untouched.
                if hasattr(initializer_nodes, "split"):
                    initializer_nodes = initializer_nodes.split(",")
                sess.run(initializer_nodes)

        variable_names_blacklist = (variable_names_blacklist.split(",")
                                    if variable_names_blacklist else None)
        output_graph_def = graph_util.convert_variables_to_constants(
            sess,
            input_graph_def,
            output_node_names.split(","),
            variable_names_blacklist=variable_names_blacklist)

    with gfile.GFile(output_graph, "wb") as f:
        f.write(output_graph_def.SerializeToString())
    print("%d ops in the final graph." % len(output_graph_def.node))
def freeze_graph(input_graph,
                 input_saver,
                 input_binary,
                 input_checkpoint,
                 output_node_names,
                 restore_op_name,
                 filename_tensor_name,
                 output_graph,
                 clear_devices,
                 initializer_nodes,
                 variable_names_blacklist=""):
    """Converts all variables in a graph and checkpoint into constants.

    Args:
      input_graph: Path of the GraphDef file to freeze.
      input_saver: Optional path of a SaverDef file used to restore the
        checkpoint.
      input_binary: Truthy when `input_graph` (and `input_saver`) are binary
        protos; otherwise they are parsed as text protos.
      input_checkpoint: Checkpoint to restore; may be a prefix when the Saver
        V2 format is used.
      output_node_names: Comma-separated names of the output nodes.
      restore_op_name: Unused.
      filename_tensor_name: Unused.
      output_graph: Path the serialized frozen graph is written to.
      clear_devices: When truthy, the `device` field of every node is blanked.
      initializer_nodes: Comma-separated names of nodes to run after the
        checkpoint is restored on the no-saver path. A value without a
        `split` method is handed to `sess.run` unchanged.
      variable_names_blacklist: Comma-separated variable names. Parsed, but
        currently NOT forwarded to the conversion call (see note in the
        body).

    Returns:
      -1 when an input file or the checkpoint does not exist, or when
      `output_node_names` is empty. Otherwise None, after the frozen graph
      has been written to `output_graph`.
    """
    # Unused by updated loading code.
    del restore_op_name, filename_tensor_name

    if not gfile.Exists(input_graph):
        print("Input graph file '" + input_graph + "' does not exist!")
        return -1

    if input_saver and not gfile.Exists(input_saver):
        print("Input saver file '" + input_saver + "' does not exist!")
        return -1

    # 'input_checkpoint' may be a prefix if we're using Saver V2 format.
    if not saver_lib.checkpoint_exists(input_checkpoint):
        print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
        return -1

    if not output_node_names:
        print("You need to supply the name of a node to --output_node_names.")
        return -1

    input_graph_def = graph_pb2.GraphDef()
    mode = "rb" if input_binary else "r"
    with gfile.FastGFile(input_graph, mode) as f:
        if input_binary:
            input_graph_def.ParseFromString(f.read())
        else:
            graph_text = f.read()
            # A text-mode read may already yield `str`, which has no
            # `.decode()` on Python 3; only raw bytes are decoded.
            if isinstance(graph_text, bytes):
                graph_text = graph_text.decode("utf-8")
            text_format.Merge(graph_text, input_graph_def)

    # Remove all the explicit device specifications for this node. This helps
    # to make the graph more portable.
    if clear_devices:
        for node in input_graph_def.node:
            node.device = ""

    _ = importer.import_graph_def(input_graph_def, name="")

    with session.Session() as sess:
        if input_saver:
            # The saver file is read with the same mode as the graph file.
            with gfile.FastGFile(input_saver, mode) as f:
                saver_def = saver_pb2.SaverDef()
                if input_binary:
                    saver_def.ParseFromString(f.read())
                else:
                    text_format.Merge(f.read(), saver_def)
                saver = saver_lib.Saver(saver_def=saver_def)
                saver.restore(sess, input_checkpoint)
        else:
            # No saver file: map every checkpoint key that has a matching
            # "<key>:0" tensor in the graph.
            var_list = {}
            reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
            var_to_shape_map = reader.get_variable_to_shape_map()
            for key in var_to_shape_map:
                try:
                    tensor = sess.graph.get_tensor_by_name(key + ":0")
                except KeyError:
                    # This tensor doesn't exist in the graph (for example
                    # it's 'global_step' or a similar housekeeping element)
                    # so skip it, reporting the skipped checkpoint key.
                    print(key)
                    continue
                var_list[key] = tensor
            saver = saver_lib.Saver(var_list=var_list)
            saver.restore(sess, input_checkpoint)
            if initializer_nodes:
                # A comma-separated string names several nodes and must be
                # split into individual fetches; anything that is not
                # string-like is passed through untouched.
                if hasattr(initializer_nodes, "split"):
                    initializer_nodes = initializer_nodes.split(",")
                sess.run(initializer_nodes)

        variable_names_blacklist = (variable_names_blacklist.split(",")
                                    if variable_names_blacklist else None)
        # NOTE(review): the parsed blacklist is deliberately not passed
        # below, so `variable_names_blacklist` has no effect on the result.
        # Presumably the keyword was disabled for a graph_util version that
        # does not accept it -- confirm before re-enabling.
        output_graph_def = graph_util.convert_variables_to_constants(
            sess, input_graph_def, output_node_names.split(","))

    with gfile.GFile(output_graph, "wb") as f:
        f.write(output_graph_def.SerializeToString())
    print("%d ops in the final graph." % len(output_graph_def.node))
def freeze_graph(input_graph,
                 input_saver,
                 input_binary,
                 input_checkpoint,
                 output_node_names,
                 restore_op_name,
                 filename_tensor_name,
                 output_graph,
                 clear_devices,
                 initializer_nodes,
                 variable_names_blacklist=""):
    """Converts all variables in a graph and checkpoint into constants.

    Args:
      input_graph: Path of the GraphDef file to freeze.
      input_saver: Optional path of a SaverDef file used to restore the
        checkpoint.
      input_binary: Truthy when `input_graph` (and `input_saver`) are binary
        protos; otherwise they are parsed as text protos.
      input_checkpoint: Checkpoint to restore; may be a prefix when the Saver
        V2 format is used.
      output_node_names: Comma-separated names of the output nodes.
      restore_op_name: Name of the op run to restore the checkpoint when no
        `input_saver` is given.
      filename_tensor_name: Name of the tensor fed with `input_checkpoint`
        when `restore_op_name` is run.
      output_graph: Path the serialized frozen graph is written to.
      clear_devices: When truthy, the `device` field of every node is blanked.
      initializer_nodes: Comma-separated names of nodes to run after the
        restore op on the no-saver path. A value without a `split` method is
        handed to `sess.run` unchanged.
      variable_names_blacklist: Comma-separated names forwarded to
        `graph_util.convert_variables_to_constants`; empty is sent as None.

    Returns:
      -1 when an input file or the checkpoint does not exist, or when
      `output_node_names` is empty. Otherwise None, after the frozen graph
      has been written to `output_graph`.
    """
    if not gfile.Exists(input_graph):
        print("Input graph file '" + input_graph + "' does not exist!")
        return -1

    if input_saver and not gfile.Exists(input_saver):
        print("Input saver file '" + input_saver + "' does not exist!")
        return -1

    # 'input_checkpoint' may be a prefix if we're using Saver V2 format.
    if not saver_lib.checkpoint_exists(input_checkpoint):
        print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
        return -1

    if not output_node_names:
        print("You need to supply the name of a node to --output_node_names.")
        return -1

    input_graph_def = graph_pb2.GraphDef()
    mode = "rb" if input_binary else "r"
    with gfile.FastGFile(input_graph, mode) as f:
        if input_binary:
            input_graph_def.ParseFromString(f.read())
        else:
            graph_text = f.read()
            # A text-mode read may already yield `str`, which has no
            # `.decode()` on Python 3; only raw bytes are decoded.
            if isinstance(graph_text, bytes):
                graph_text = graph_text.decode("utf-8")
            text_format.Merge(graph_text, input_graph_def)

    # Remove all the explicit device specifications for this node. This helps
    # to make the graph more portable.
    if clear_devices:
        for node in input_graph_def.node:
            node.device = ""

    _ = importer.import_graph_def(input_graph_def, name="")

    with session.Session() as sess:
        if input_saver:
            # The saver file is read with the same mode as the graph file.
            with gfile.FastGFile(input_saver, mode) as f:
                saver_def = saver_pb2.SaverDef()
                if input_binary:
                    saver_def.ParseFromString(f.read())
                else:
                    text_format.Merge(f.read(), saver_def)
                saver = saver_lib.Saver(saver_def=saver_def)
                saver.restore(sess, input_checkpoint)
        else:
            # No saver file: run the graph's own restore op, feeding it the
            # checkpoint path.
            sess.run([restore_op_name],
                     {filename_tensor_name: input_checkpoint})
            if initializer_nodes:
                # A comma-separated string names several nodes and must be
                # split into individual fetches; anything that is not
                # string-like is passed through untouched.
                if hasattr(initializer_nodes, "split"):
                    initializer_nodes = initializer_nodes.split(",")
                sess.run(initializer_nodes)

        variable_names_blacklist = (variable_names_blacklist.split(",")
                                    if variable_names_blacklist else None)
        output_graph_def = graph_util.convert_variables_to_constants(
            sess,
            input_graph_def,
            output_node_names.split(","),
            variable_names_blacklist=variable_names_blacklist)

    with gfile.GFile(output_graph, "wb") as f:
        f.write(output_graph_def.SerializeToString())
    print("%d ops in the final graph." % len(output_graph_def.node))
def freeze_graph_with_def_protos(
        input_graph_def,
        input_saver_def,
        input_checkpoint,
        output_node_names,
        restore_op_name,
        filename_tensor_name,
        clear_devices,
        initializer_nodes,
        optimize_graph=False,
        variable_names_blacklist=''):
    """Freezes a GraphDef by turning its checkpointed variables into constants.

    Args:
      input_graph_def: GraphDef imported into a fresh graph and converted.
      input_saver_def: Optional SaverDef. When truthy, a Saver built from it
        restores `input_checkpoint`.
      input_checkpoint: Checkpoint to restore; may be a prefix when the Saver
        V2 format is used.
      output_node_names: Comma-separated names of the output nodes.
      restore_op_name: Unused.
      filename_tensor_name: Unused.
      clear_devices: When truthy, the `device` field of every node is blanked.
      initializer_nodes: Fetches passed as-is to `sess.run` after restoring
        on the no-saver path; skipped when falsy.
      optimize_graph: When truthy, the session is configured with the
        'pruning', 'constfold' and 'layout' rewriter optimizers.
      variable_names_blacklist: Comma-separated names forwarded to
        `graph_util.convert_variables_to_constants`; empty is sent as None.

    Returns:
      The GraphDef produced by `convert_variables_to_constants`.

    Raises:
      ValueError: If the checkpoint does not exist or `output_node_names` is
        empty.
    """
    # Both arguments are kept only for signature compatibility.
    del restore_op_name, filename_tensor_name

    # With the Saver V2 format 'input_checkpoint' may be just a prefix.
    if not saver_lib.checkpoint_exists(input_checkpoint):
        raise ValueError(
            'Input checkpoint "' + input_checkpoint + '" does not exist!')

    if not output_node_names:
        raise ValueError(
            'You must supply the name of a node to --output_node_names.')

    # Dropping explicit device placements keeps the graph portable.
    if clear_devices:
        for graph_node in input_graph_def.node:
            graph_node.device = ''

    with tf.Graph().as_default():
        tf.import_graph_def(input_graph_def, name='')

        if not optimize_graph:
            logging.info('Graph Rewriter optimizations disabled')
            graph_options = tf.GraphOptions()
        else:
            logging.info('Graph Rewriter optimizations enabled')
            rewrite_options = rewriter_config_pb2.RewriterConfig(
                optimize_tensor_layout=True)
            rewrite_options.optimizers.extend(
                ['pruning', 'constfold', 'layout'])
            graph_options = tf.GraphOptions(
                rewrite_options=rewrite_options, infer_shapes=True)

        session_config = tf.ConfigProto(graph_options=graph_options)
        with session.Session(config=session_config) as sess:
            if input_saver_def:
                saver_lib.Saver(saver_def=input_saver_def).restore(
                    sess, input_checkpoint)
            else:
                # Pair each checkpoint entry with its "<name>:0" graph
                # tensor. Entries with no such tensor (for example
                # 'global_step' or similar housekeeping elements) are left
                # out.
                ckpt_reader = pywrap_tensorflow.NewCheckpointReader(
                    input_checkpoint)
                restorable = {}
                for var_name in ckpt_reader.get_variable_to_shape_map():
                    try:
                        graph_tensor = sess.graph.get_tensor_by_name(
                            var_name + ':0')
                    except KeyError:
                        continue
                    else:
                        restorable[var_name] = graph_tensor
                saver_lib.Saver(var_list=restorable).restore(
                    sess, input_checkpoint)
                if initializer_nodes:
                    sess.run(initializer_nodes)

            if variable_names_blacklist:
                blacklist = variable_names_blacklist.split(',')
            else:
                blacklist = None
            output_graph_def = graph_util.convert_variables_to_constants(
                sess,
                input_graph_def,
                output_node_names.split(','),
                variable_names_blacklist=blacklist)

    return output_graph_def