def write_vocabulary(vocab_file, words):
  with gfile.GFile(vocab_file, mode="wb") as f:
    for w in words:
      f.write(bytes(w, 'utf-8') + b"\n")

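# A minimal usage sketch for write_vocabulary() above. The output path and
# token list are hypothetical, and `gfile` is assumed to be imported from
# tensorflow.python.platform.
write_vocabulary("/tmp/vocab.txt", ["hello", "world", "<unk>"])
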
def freeze_graph_with_def_protos(input_graph_def,
                                 input_saver_def,
                                 input_checkpoint,
                                 output_node_names,
                                 restore_op_name,
                                 filename_tensor_name,
                                 output_graph,
                                 clear_devices,
                                 initializer_nodes,
                                 variable_names_whitelist="",
                                 variable_names_blacklist="",
                                 input_meta_graph_def=None,
                                 input_saved_model_dir=None,
                                 saved_model_tags=None,
                                 checkpoint_version=saver_pb2.SaverDef.V2):
  """Converts all variables in a graph and checkpoint into constants."""
  del restore_op_name, filename_tensor_name  # Unused by updated loading code.

  # 'input_checkpoint' may be a prefix if we're using Saver V2 format
  if (not input_saved_model_dir and
      not saver_lib.checkpoint_exists(input_checkpoint)):
    print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
    return -1

  if not output_node_names:
    print("You need to supply the name of a node to --output_node_names.")
    return -1

  # Remove all the explicit device specifications for this node. This helps to
  # make the graph more portable.
  if clear_devices:
    if input_meta_graph_def:
      for node in input_meta_graph_def.graph_def.node:
        node.device = ""
    elif input_graph_def:
      for node in input_graph_def.node:
        node.device = ""

  if input_graph_def:
    _ = importer.import_graph_def(input_graph_def, name="")

  with session.Session() as sess:
    if input_saver_def:
      saver = saver_lib.Saver(
          saver_def=input_saver_def, write_version=checkpoint_version)
      saver.restore(sess, input_checkpoint)
    elif input_meta_graph_def:
      restorer = saver_lib.import_meta_graph(
          input_meta_graph_def, clear_devices=True)
      restorer.restore(sess, input_checkpoint)
      if initializer_nodes:
        sess.run(initializer_nodes.replace(" ", "").split(","))
    elif input_saved_model_dir:
      if saved_model_tags is None:
        saved_model_tags = []
      loader.load(sess, saved_model_tags, input_saved_model_dir)
    else:
      var_list = {}
      reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
      var_to_shape_map = reader.get_variable_to_shape_map()

      # List of all partition variables. Because the condition is heuristic
      # based, the list could include false positives.
      all_partition_variable_names = [
          tensor.name.split(":")[0]
          for op in sess.graph.get_operations()
          for tensor in op.values()
          if re.search(r"/part_\d+/", tensor.name)
      ]
      has_partition_var = False

      for key in var_to_shape_map:
        try:
          tensor = sess.graph.get_tensor_by_name(key + ":0")
          if any(key in name for name in all_partition_variable_names):
            has_partition_var = True
        except KeyError:
          # This tensor doesn't exist in the graph (for example it's
          # 'global_step' or a similar housekeeping element) so skip it.
          continue
        var_list[key] = tensor

      try:
        saver = saver_lib.Saver(
            var_list=var_list, write_version=checkpoint_version)
      except TypeError as e:
        # `var_list` is required to be a map of variable names to Variable
        # tensors. Partition variables are Identity tensors that cannot be
        # handled by Saver.
        if has_partition_var:
          print("Models containing partition variables cannot be converted "
                "from checkpoint files. Please pass in a SavedModel using "
                "the flag --input_saved_model_dir.")
          return -1
        else:
          raise e

      saver.restore(sess, input_checkpoint)
      if initializer_nodes:
        sess.run(initializer_nodes.replace(" ", "").split(","))

    variable_names_whitelist = (
        variable_names_whitelist.replace(" ", "").split(",")
        if variable_names_whitelist else None)
    variable_names_blacklist = (
        variable_names_blacklist.replace(" ", "").split(",")
        if variable_names_blacklist else None)

    if input_meta_graph_def:
      output_graph_def = graph_util.convert_variables_to_constants(
          sess,
          input_meta_graph_def.graph_def,
          output_node_names.replace(" ", "").split(","),
          variable_names_whitelist=variable_names_whitelist,
          variable_names_blacklist=variable_names_blacklist)
    else:
      output_graph_def = graph_util.convert_variables_to_constants(
          sess,
          input_graph_def,
          output_node_names.replace(" ", "").split(","),
          variable_names_whitelist=variable_names_whitelist,
          variable_names_blacklist=variable_names_blacklist)

  # Write GraphDef to file if output path has been given.
  if output_graph:
    with gfile.GFile(output_graph, "wb") as f:
      f.write(output_graph_def.SerializeToString())

  return output_graph_def

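# A hedged usage sketch for freeze_graph_with_def_protos() above, assuming
# `import tensorflow as tf` and a graph already built in the default graph.
# The checkpoint prefix, output node name, and output path are hypothetical;
# in a real tool these would come from command-line flags.
graph_def = freeze_graph_with_def_protos(
    input_graph_def=tf.get_default_graph().as_graph_def(),
    input_saver_def=None,
    input_checkpoint="/tmp/model.ckpt",  # hypothetical checkpoint prefix
    output_node_names="softmax",         # hypothetical output node
    restore_op_name=None,                # unused by the updated loading code
    filename_tensor_name=None,           # unused by the updated loading code
    output_graph="/tmp/frozen.pb",
    clear_devices=True,
    initializer_nodes="")
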
def read_benchmark_entry(f):
  # Note: this helper is defined inside a test method in the original source,
  # so `self` is captured from the enclosing scope.
  s = gfile.GFile(f, "rb").read()
  entries = test_log_pb2.BenchmarkEntries.FromString(s)
  self.assertEquals(1, len(entries.entry))
  return entries.entry[0]

def doBasicsOneExportPath(self,
                          export_path,
                          clear_devices=False,
                          global_step=GLOBAL_STEP,
                          sharded=True,
                          export_count=1):
  # Build a graph with 2 parameter nodes on different devices.
  ops.reset_default_graph()
  with session.Session(
      target="",
      config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
    # v2 is an unsaved variable derived from v0 and v1. It is used to
    # exercise the ability to run an init op when restoring a graph.
    with sess.graph.device("/cpu:0"):
      v0 = variables.VariableV1(10, name="v0")
    with sess.graph.device("/cpu:1"):
      v1 = variables.VariableV1(20, name="v1")
    v2 = variables.VariableV1(1, name="v2", trainable=False, collections=[])
    assign_v2 = state_ops.assign(v2, math_ops.add(v0, v1))
    init_op = control_flow_ops.group(assign_v2, name="init_op")

    ops.add_to_collection("v", v0)
    ops.add_to_collection("v", v1)
    ops.add_to_collection("v", v2)

    named_tensor_bindings = {"logical_input_A": v0, "logical_input_B": v1}
    signatures = {
        "foo": exporter.regression_signature(
            input_tensor=v0, output_tensor=v1),
        "generic": exporter.generic_signature(named_tensor_bindings)
    }

    asset_filepath_orig = os.path.join(test.get_temp_dir(), "hello42.txt")
    asset_file = constant_op.constant(asset_filepath_orig, name="filename42")
    ops.add_to_collection(ops.GraphKeys.ASSET_FILEPATHS, asset_file)

    with gfile.GFile(asset_filepath_orig, "w") as f:
      f.write("your data here")
    assets_collection = ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS)

    ignored_asset = os.path.join(test.get_temp_dir(), "ignored.txt")
    with gfile.GFile(ignored_asset, "w") as f:
      f.write("additional data here")

    variables.global_variables_initializer().run()

    # Run an export.
    save = saver.Saver(
        {
            "v0": v0,
            "v1": v1
        },
        restore_sequentially=True,
        sharded=sharded,
        write_version=saver_pb2.SaverDef.V1)
    export = exporter.Exporter(save)
    compare_def = ops.get_default_graph().as_graph_def()
    export.init(
        compare_def,
        init_op=init_op,
        clear_devices=clear_devices,
        default_graph_signature=exporter.classification_signature(
            input_tensor=v0),
        named_graph_signatures=signatures,
        assets_collection=assets_collection)

    for x in range(export_count):
      export.export(
          export_path,
          constant_op.constant(global_step + x),
          sess,
          exports_to_keep=gc.largest_export_versions(2))
    # Set global_step to the last exported version, as the rest of the test
    # uses it to construct model export path, loads model from it, and does
    # verifications. We want to make sure to always use the last exported
    # version, as old ones may have been garbage-collected.
    global_step += export_count - 1

  # Restore graph.
  ops.reset_default_graph()
  with session.Session(
      target="",
      config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
    save = saver.import_meta_graph(
        os.path.join(export_path,
                     constants.VERSION_FORMAT_SPECIFIER % global_step,
                     constants.META_GRAPH_DEF_FILENAME))
    self.assertIsNotNone(save)
    meta_graph_def = save.export_meta_graph()
    collection_def = meta_graph_def.collection_def

    # Validate custom graph_def.
    graph_def_any = collection_def[constants.GRAPH_KEY].any_list.value
    self.assertEquals(len(graph_def_any), 1)
    graph_def = graph_pb2.GraphDef()
    graph_def_any[0].Unpack(graph_def)
    if clear_devices:
      for node in compare_def.node:
        node.device = ""
    self.assertProtoEquals(compare_def, graph_def)

    # Validate init_op.
    init_ops = collection_def[constants.INIT_OP_KEY].node_list.value
    self.assertEquals(len(init_ops), 1)
    self.assertEquals(init_ops[0], "init_op")

    # Validate signatures.
    signatures_any = collection_def[constants.SIGNATURES_KEY].any_list.value
    self.assertEquals(len(signatures_any), 1)
    signatures = manifest_pb2.Signatures()
    signatures_any[0].Unpack(signatures)
    default_signature = signatures.default_signature
    self.assertEqual(
        default_signature.classification_signature.input.tensor_name, "v0:0")
    bindings = signatures.named_signatures["generic"].generic_signature.map
    self.assertEquals(bindings["logical_input_A"].tensor_name, "v0:0")
    self.assertEquals(bindings["logical_input_B"].tensor_name, "v1:0")
    read_foo_signature = (
        signatures.named_signatures["foo"].regression_signature)
    self.assertEquals(read_foo_signature.input.tensor_name, "v0:0")
    self.assertEquals(read_foo_signature.output.tensor_name, "v1:0")

    # Validate the assets.
    assets_any = collection_def[constants.ASSETS_KEY].any_list.value
    self.assertEquals(len(assets_any), 1)
    asset = manifest_pb2.AssetFile()
    assets_any[0].Unpack(asset)
    assets_path = os.path.join(export_path,
                               constants.VERSION_FORMAT_SPECIFIER %
                               global_step,
                               constants.ASSETS_DIRECTORY, "hello42.txt")
    asset_contents = gfile.GFile(assets_path).read()
    self.assertEqual(asset_contents, "your data here")
    self.assertEquals("hello42.txt", asset.filename)
    self.assertEquals("filename42:0", asset.tensor_binding.tensor_name)
    ignored_asset_path = os.path.join(export_path,
                                      constants.VERSION_FORMAT_SPECIFIER %
                                      global_step,
                                      constants.ASSETS_DIRECTORY,
                                      "ignored.txt")
    self.assertFalse(gfile.Exists(ignored_asset_path))

    # Validate graph restoration.
    if sharded:
      save.restore(sess,
                   os.path.join(export_path,
                                constants.VERSION_FORMAT_SPECIFIER %
                                global_step,
                                constants.VARIABLES_FILENAME_PATTERN))
    else:
      save.restore(sess,
                   os.path.join(export_path,
                                constants.VERSION_FORMAT_SPECIFIER %
                                global_step,
                                constants.VARIABLES_FILENAME))
    self.assertEqual(10, ops.get_collection("v")[0].eval())
    self.assertEqual(20, ops.get_collection("v")[1].eval())
    ops.get_collection(constants.INIT_OP_KEY)[0].run()
    self.assertEqual(30, ops.get_collection("v")[2].eval())

def train(self,
          learn_rate,
          dropout_rate,
          save_step,
          batch_size,
          eval_step,
          training_time,
          rate_step,
          display_step,
          train_data,
          Validation_data,
          init=False):
    self._save_step = save_step
    self._training_time = training_time
    assert type(learn_rate) == list, \
        "learn_rate should be a list, e.g. [.001, .0001]"
    self._ground_truth_input = tf.compat.v1.placeholder(
        tf.int64, [None], name='groundtruth_input')
    with tf.compat.v1.name_scope('cross_entropy'):
        self._cross_entropy_mean = tf.compat.v1.losses.sparse_softmax_cross_entropy(
            labels=self._ground_truth_input, logits=self._softmax_layer)
    learning_rate_input = tf.compat.v1.placeholder(
        tf.float32, [], name='learning_rate_input')
    train_step = tf.compat.v1.train.GradientDescentOptimizer(
        learning_rate_input).minimize(self._cross_entropy_mean)
    self._predicted = tf.argmax(input=self._softmax_layer, axis=1)
    correct_prediction = tf.equal(self._predicted, self._ground_truth_input)
    self._evaluation_step = tf.reduce_mean(
        input_tensor=tf.cast(correct_prediction, tf.float32))
    saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())
    if self._loaded is False and self._start_step == 0:
        self._global_step = tf.compat.v1.train.get_or_create_global_step()
        tf.compat.v1.global_variables_initializer().run()
        # self._loaded = True
    increment_global_step = tf.compat.v1.assign(self._global_step,
                                                self._global_step + 1)
    if init is False:
        tf.io.write_graph(self._sess.graph_def, self._model_dir,
                          "model" + '.pbtxt')
        with gfile.GFile(
                os.path.join(self._model_dir, "commands" + '_labels.txt'),
                'wb') as f:
            f.write('\n'.join(self.commands))
    if training_time <= self._start_step and self._loaded is True:
        print(f"Loaded checkpoint has already been trained for "
              f"{self._start_step} epochs.\nNew training starts from "
              f"{self._start_step}; please increase training_time to train "
              f"the model.")
    if init is False:
        if tf.config.list_physical_devices('GPU'):
            strategy = tf.distribute.MirroredStrategy()
        else:
            # Use the default strategy.
            strategy = tf.distribute.get_strategy()
        with strategy.scope():
            history = {
                "categorical_accuracy": [],
                "loss": [],
                "val_categorical_accuracy": [],
                "val_loss": []
            }
            learning_rate = learn_rate[0]
            for training_step in xrange(self._start_step, training_time):
                if training_step == int(rate_step):
                    learning_rate = learn_rate[1]
                x_train, y_train = self.get_next_batch(batch_size, train_data)
                train_accuracy, cross_entropy_value, _, _ = self._sess.run(
                    [
                        self._evaluation_step,
                        self._cross_entropy_mean,
                        train_step,
                        increment_global_step,
                    ],
                    feed_dict={
                        self._fingerprint_input: x_train,
                        self._ground_truth_input: y_train,
                        learning_rate_input: learning_rate,
                        self._dropout_placeholder: dropout_rate
                    })
                if training_step % int(display_step) == 0:
                    print(
                        'Step #%d: learning rate %f, accuracy %.1f%%, cross entropy %f'
                        % (training_step, learning_rate, train_accuracy * 100,
                           cross_entropy_value))
                    history["categorical_accuracy"].append(train_accuracy)
                    history["loss"].append(cross_entropy_value)
                if training_step % int(eval_step) == 0:
                    x_val, y_val = self.get_next_batch(batch_size * 4,
                                                       Validation_data)
                    validation_accuracy, val_crossentropy_value = self._sess.run(
                        [self._evaluation_step, self._cross_entropy_mean],
                        feed_dict={
                            self._fingerprint_input: x_val,
                            self._ground_truth_input: y_val,
                            self._dropout_placeholder: 0.0
                        })
                    history["val_categorical_accuracy"].append(
                        validation_accuracy)
                    history["val_loss"].append(val_crossentropy_value)
                    print(
                        'Step %d: Validation accuracy = %.1f%% (Val Size=%d), Validation loss = %f'
                        % (training_step, validation_accuracy * 100,
                           batch_size * 4, val_crossentropy_value))
                if (training_step % int(save_step) == 0) or (
                        training_step == training_time - 1):
                    path_to_save = os.path.join(self._model_dir,
                                                "model_checkpoint" + '.ckpt')
                    if training_step == training_time - 1:
                        training_step = training_time
                    saver.save(self._sess, path_to_save,
                               global_step=training_step)
                    self._start_step = self._global_step.eval(
                        session=self._sess)
            return history

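# A hedged usage sketch for the train() method above, assuming `model` is an
# instance of the (unnamed) class that defines it, already wired up with a
# session, placeholders, and a softmax layer. All values are illustrative.
history = model.train(
    learn_rate=[.001, .0001],  # the second rate kicks in at rate_step
    dropout_rate=0.5,
    save_step=100,
    batch_size=32,
    eval_step=50,
    training_time=1000,
    rate_step=600,
    display_step=10,
    train_data=train_data,               # hypothetical dataset object
    Validation_data=validation_data)     # hypothetical dataset object
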
"Save model after this many steps (default: 500)") # Misc Parameters tf.flags.DEFINE_boolean("allow_soft_placement", True, "Allow device soft device placement") tf.flags.DEFINE_boolean("log_device_placement", False, "Log placement of ops on devices") FLAGS = tf.flags.FLAGS FLAGS._parse_flags() if __name__ == "__main__": ids_path, vocab_path = data_utils.prepare_data(FLAGS.data_path, FLAGS.vocab_size) dataset = [] with gfile.GFile(ids_path, mode="r") as ids_file: ids = ids_file.readline() while ids: ids = map(int, ids.split()) dataset.append(ids) ids = ids_file.readline() dataset = [one for one in dataset if len(one) <= FLAGS.max_seq_len] train_val_split_point = len(dataset) - int( len(dataset) * FLAGS.validation_ratio) train_set = dataset[0:train_val_split_point] validation_set = dataset[train_val_split_point:] model = seq2seq_network.Seq2SeqAutoEncoder( FLAGS.vocab_size, FLAGS.embedding_dim, FLAGS.num_units, FLAGS.num_layers, FLAGS.max_seq_len, FLAGS.max_gradient_norm,
def freeze_graph_with_def_protos(input_graph_def,
                                 input_saver_def,
                                 input_checkpoint,
                                 output_node_names,
                                 restore_op_name,
                                 filename_tensor_name,
                                 output_graph,
                                 clear_devices,
                                 initializer_nodes,
                                 variable_names_whitelist="",
                                 variable_names_blacklist="",
                                 input_meta_graph_def=None,
                                 input_saved_model_dir=None,
                                 saved_model_tags=None):
  """Converts all variables in a graph and checkpoint into constants."""
  del restore_op_name, filename_tensor_name  # Unused by updated loading code.

  # 'input_checkpoint' may be a prefix if we're using Saver V2 format
  if (not input_saved_model_dir and
      not saver_lib.checkpoint_exists(input_checkpoint)):
    print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
    return -1

  if not output_node_names:
    print("You need to supply the name of a node to --output_node_names.")
    return -1

  # Remove all the explicit device specifications for this node. This helps to
  # make the graph more portable.
  if clear_devices:
    if input_meta_graph_def:
      for node in input_meta_graph_def.graph_def.node:
        node.device = ""
    elif input_graph_def:
      for node in input_graph_def.node:
        node.device = ""

  if input_graph_def:
    _ = importer.import_graph_def(input_graph_def, name="")

  with session.Session() as sess:
    if input_saver_def:
      saver = saver_lib.Saver(saver_def=input_saver_def)
      saver.restore(sess, input_checkpoint)
    elif input_meta_graph_def:
      restorer = saver_lib.import_meta_graph(
          input_meta_graph_def, clear_devices=True)
      restorer.restore(sess, input_checkpoint)
      if initializer_nodes:
        sess.run(initializer_nodes.split(","))
    elif input_saved_model_dir:
      if saved_model_tags is None:
        saved_model_tags = []
      loader.load(sess, saved_model_tags, input_saved_model_dir)
    else:
      var_list = {}
      reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
      var_to_shape_map = reader.get_variable_to_shape_map()
      for key in var_to_shape_map:
        try:
          tensor = sess.graph.get_tensor_by_name(key + ":0")
        except KeyError:
          # This tensor doesn't exist in the graph (for example it's
          # 'global_step' or a similar housekeeping element) so skip it.
          continue
        var_list[key] = tensor
      saver = saver_lib.Saver(var_list=var_list)
      saver.restore(sess, input_checkpoint)
      if initializer_nodes:
        sess.run(initializer_nodes.split(","))

    variable_names_whitelist = (variable_names_whitelist.split(",")
                                if variable_names_whitelist else None)
    variable_names_blacklist = (variable_names_blacklist.split(",")
                                if variable_names_blacklist else None)

    if input_meta_graph_def:
      output_graph_def = graph_util.convert_variables_to_constants(
          sess,
          input_meta_graph_def.graph_def,
          output_node_names.split(","),
          variable_names_whitelist=variable_names_whitelist,
          variable_names_blacklist=variable_names_blacklist)
    else:
      output_graph_def = graph_util.convert_variables_to_constants(
          sess,
          input_graph_def,
          output_node_names.split(","),
          variable_names_whitelist=variable_names_whitelist,
          variable_names_blacklist=variable_names_blacklist)

  # Write GraphDef to file if output path has been given.
  if output_graph:
    with gfile.GFile(output_graph, "wb") as f:
      f.write(output_graph_def.SerializeToString())

  return output_graph_def

def test_train_infer(self):
  """Tests training and inference scripts."""
  # Create dummy data
  sources_train, targets_train = test_utils.create_temp_parallel_data(
      sources=["a a a a", "b b b b", "c c c c", "笑 笑 笑 笑"],
      targets=["b b b b", "a a a a", "c c c c", "泣 泣 泣 泣"])
  sources_dev, targets_dev = test_utils.create_temp_parallel_data(
      sources=["a a", "b b", "c c c", "笑 笑 笑"],
      targets=["b b", "a a", "c c c", "泣 泣 泣"])
  vocab_source = test_utils.create_temporary_vocab_file(["a", "b", "c", "笑"])
  vocab_target = test_utils.create_temporary_vocab_file(["a", "b", "c", "泣"])

  _clear_flags()
  tf.reset_default_graph()
  train_script = imp.load_source("seq2seq.test.train_bin",
                                 os.path.join(BIN_FOLDER, "train.py"))

  # Set training flags
  tf.app.flags.FLAGS.output_dir = self.output_dir
  tf.app.flags.FLAGS.train_source = sources_train.name
  tf.app.flags.FLAGS.train_target = targets_train.name
  tf.app.flags.FLAGS.vocab_source = vocab_source.name
  tf.app.flags.FLAGS.vocab_target = vocab_target.name
  tf.app.flags.FLAGS.model = "AttentionSeq2Seq"
  tf.app.flags.FLAGS.batch_size = 2

  # We pass a few flags via a config file
  config_path = os.path.join(self.output_dir, "train_config.yml")
  with gfile.GFile(config_path, "w") as config_file:
    yaml.dump({
        "dev_source": sources_dev.name,
        "dev_target": targets_dev.name,
        "train_steps": 50,
        "hparams": {
            "embedding.dim": 64,
            "attention.dim": 16,
            "decoder.rnn_cell.cell_spec": {
                "class": "GRUCell",
                "num_units": 32
            }
        }
    }, config_file)
  tf.app.flags.FLAGS.config_path = config_path

  # Run training
  tf.logging.set_verbosity(tf.logging.INFO)
  train_script.main([])

  # Make sure a checkpoint was written
  expected_checkpoint = os.path.join(self.output_dir,
                                     "model.ckpt-50.data-00000-of-00001")
  self.assertTrue(os.path.exists(expected_checkpoint))

  # Reset flags and import inference script
  _clear_flags()
  tf.reset_default_graph()
  infer_script = imp.load_source("seq2seq.test.infer_bin",
                                 os.path.join(BIN_FOLDER, "infer.py"))

  # Set inference flags
  attention_dir = os.path.join(self.output_dir, "att")
  tf.app.flags.FLAGS.model_dir = self.output_dir
  tf.app.flags.FLAGS.source = sources_dev.name
  tf.app.flags.FLAGS.batch_size = 2
  tf.app.flags.FLAGS.checkpoint_path = os.path.join(self.output_dir,
                                                    "model.ckpt-50")
  tf.app.flags.FLAGS.dump_attention_dir = attention_dir

  # Make sure inference runs successfully
  infer_script.main([])

  # Make sure attention scores and visualizations exist
  self.assertTrue(
      os.path.exists(os.path.join(attention_dir, "attention_scores.npz")))
  self.assertTrue(os.path.exists(os.path.join(attention_dir, "00002.png")))

  # Load attention scores and assert shape
  scores = np.load(os.path.join(attention_dir, "attention_scores.npz"))
  self.assertIn("arr_0", scores)
  self.assertEqual(scores["arr_0"].shape[1], 3)
  self.assertIn("arr_1", scores)
  self.assertEqual(scores["arr_1"].shape[1], 3)
  self.assertIn("arr_2", scores)
  self.assertEqual(scores["arr_2"].shape[1], 4)
  self.assertIn("arr_3", scores)
  self.assertEqual(scores["arr_3"].shape[1], 4)

def save_int8_frezon_pb(q_model, path):
  from tensorflow.python.platform import gfile
  # Use a context manager so the file is flushed and closed.
  with gfile.GFile(path, 'wb') as f:
    f.write(q_model.as_graph_def().SerializeToString())
  print("Save to {}".format(path))

def skip_gram_sample_with_text_vocab(input_tensor,
                                     vocab_freq_file,
                                     vocab_token_index=0,
                                     vocab_token_dtype=dtypes.string,
                                     vocab_freq_index=1,
                                     vocab_freq_dtype=dtypes.float64,
                                     vocab_delimiter=",",
                                     vocab_min_count=0,
                                     vocab_subsampling=None,
                                     corpus_size=None,
                                     min_skips=1,
                                     max_skips=5,
                                     start=0,
                                     limit=-1,
                                     emit_self_as_target=False,
                                     batch_size=None,
                                     batch_capacity=None,
                                     seed=None,
                                     name=None):
  """Skip-gram sampling with a text vocabulary file.

  Wrapper around `skip_gram_sample()` for use with a text vocabulary file.
  The vocabulary file is expected to be a plain-text file, with lines of
  `vocab_delimiter`-separated columns. The `vocab_token_index` column should
  contain the vocabulary term, while the `vocab_freq_index` column should
  contain the number of times that term occurs in the corpus. For example,
  with a text vocabulary file of:

  ```
  bonjour,fr,42
  hello,en,777
  hola,es,99
  ```

  You should set `vocab_delimiter=","`, `vocab_token_index=0`, and
  `vocab_freq_index=2`.

  See `skip_gram_sample()` documentation for more details about the skip-gram
  sampling process.

  Args:
    input_tensor: A rank-1 `Tensor` from which to generate skip-gram
      candidates.
    vocab_freq_file: `string` specifying full file path to the text vocab
      file.
    vocab_token_index: `int` specifying which column in the text vocab file
      contains the tokens.
    vocab_token_dtype: `DType` specifying the format of the tokens in the
      text vocab file.
    vocab_freq_index: `int` specifying which column in the text vocab file
      contains the frequency counts of the tokens.
    vocab_freq_dtype: `DType` specifying the format of the frequency counts
      in the text vocab file.
    vocab_delimiter: `string` specifying the delimiter used in the text vocab
      file.
    vocab_min_count: `int`, `float`, or scalar `Tensor` specifying minimum
      frequency threshold (from `vocab_freq_file`) for a token to be kept in
      `input_tensor`. This should correspond with `vocab_freq_dtype`.
    vocab_subsampling: (Optional) `float` specifying frequency proportion
      threshold for tokens from `input_tensor`. Tokens that occur more
      frequently will be randomly down-sampled. Reasonable starting values
      may be around 1e-3 or 1e-5. See Eq. 5 in http://arxiv.org/abs/1310.4546
      for more details.
    corpus_size: (Optional) `int`, `float`, or scalar `Tensor` specifying the
      total number of tokens in the corpus (e.g., sum of all the frequency
      counts of `vocab_freq_file`). Used with `vocab_subsampling` for
      down-sampling frequently occurring tokens. If this is specified,
      `vocab_freq_file` and `vocab_subsampling` must also be specified. If
      `corpus_size` is needed but not supplied, then it will be calculated
      from `vocab_freq_file`. You might want to supply your own value if you
      have already eliminated infrequent tokens from your vocabulary files
      (where frequency < vocab_min_count) to save memory in the internal
      token lookup table. Otherwise, the unused tokens' variables will waste
      memory. The user-supplied `corpus_size` value must be greater than or
      equal to the sum of all the frequency counts of `vocab_freq_file`.
    min_skips: `int` or scalar `Tensor` specifying the minimum window size to
      randomly use for each token. Must be >= 0 and <= `max_skips`. If
      `min_skips` and `max_skips` are both 0, the only label outputted will
      be the token itself.
    max_skips: `int` or scalar `Tensor` specifying the maximum window size to
      randomly use for each token. Must be >= 0.
    start: `int` or scalar `Tensor` specifying the position in `input_tensor`
      from which to start generating skip-gram candidates.
    limit: `int` or scalar `Tensor` specifying the maximum number of elements
      in `input_tensor` to use in generating skip-gram candidates. -1 means
      to use the rest of the `Tensor` after `start`.
    emit_self_as_target: `bool` or scalar `Tensor` specifying whether to emit
      each token as a label for itself.
    batch_size: (Optional) `int` specifying batch size of returned `Tensors`.
    batch_capacity: (Optional) `int` specifying batch capacity for the queue
      used for batching returned `Tensors`. Only has an effect if
      `batch_size` > 0. Defaults to 100 * `batch_size` if not specified.
    seed: (Optional) `int` used to create a random seed for window size and
      subsampling. See
      [`set_random_seed`](../../g3doc/python/constant_op.md#set_random_seed)
      for behavior.
    name: (Optional) A `string` name or a name scope for the operations.

  Returns:
    A `tuple` containing (token, label) `Tensors`. Each output `Tensor` is of
    rank-1 and has the same type as `input_tensor`. The `Tensors` will be of
    length `batch_size`; if `batch_size` is not specified, they will be of
    random length, though they will be in sync with each other as long as
    they are evaluated together.

  Raises:
    ValueError: If `vocab_token_index` or `vocab_freq_index` is less than 0
      or exceeds the number of columns in `vocab_freq_file`. If
      `vocab_token_index` and `vocab_freq_index` are both set to the same
      column. If any token in `vocab_freq_file` has a negative frequency.
  """
  if vocab_token_index < 0 or vocab_freq_index < 0:
    raise ValueError(
        "vocab_token_index={} and vocab_freq_index={} must both be >= 0.".
        format(vocab_token_index, vocab_freq_index))
  if vocab_token_index == vocab_freq_index:
    raise ValueError(
        "vocab_token_index and vocab_freq_index should be different, but are "
        "both {}.".format(vocab_token_index))

  # Iterates through the vocab file and calculates the number of vocab terms
  # as well as the total corpus size (by summing the frequency counts of all
  # the vocab terms).
  calculated_corpus_size = 0.0
  vocab_size = 0
  with gfile.GFile(vocab_freq_file, mode="r") as f:
    reader = csv.reader(f, delimiter=vocab_delimiter)
    for row in reader:
      if vocab_token_index >= len(row) or vocab_freq_index >= len(row):
        raise ValueError(
            "Row in vocab file only has {} columns, so vocab_token_index={} "
            "or vocab_freq_index={} is out of bounds. Row content: {}".format(
                len(row), vocab_token_index, vocab_freq_index, row))
      vocab_size += 1
      freq = vocab_freq_dtype.as_numpy_dtype(row[vocab_freq_index])
      if freq < 0:
        raise ValueError(
            "Row in vocab file has negative frequency of {}. Row content: {}"
            .format(freq, row))
      # Note: tokens whose frequencies are below vocab_min_count will still
      # contribute to the total corpus size used for vocab subsampling.
      calculated_corpus_size += freq

  if not corpus_size:
    corpus_size = calculated_corpus_size
  elif calculated_corpus_size - corpus_size > 1e-6:
    raise ValueError(
        "`corpus_size`={} must be greater than or equal to the sum of all the "
        "frequency counts ({}) of `vocab_freq_file` ({}).".format(
            corpus_size, calculated_corpus_size, vocab_freq_file))

  vocab_freq_table = lookup.HashTable(
      lookup.TextFileInitializer(
          filename=vocab_freq_file,
          key_dtype=vocab_token_dtype,
          key_index=vocab_token_index,
          value_dtype=vocab_freq_dtype,
          value_index=vocab_freq_index,
          vocab_size=vocab_size,
          delimiter=vocab_delimiter),
      # For vocab terms not in vocab file, use a default value of -1.
      default_value=-1)

  return skip_gram_sample(
      input_tensor,
      min_skips=min_skips,
      max_skips=max_skips,
      start=start,
      limit=limit,
      emit_self_as_target=emit_self_as_target,
      vocab_freq_table=vocab_freq_table,
      vocab_min_count=vocab_min_count,
      vocab_subsampling=vocab_subsampling,
      # corpus_size is not used unless vocab_subsampling is specified.
      corpus_size=None if vocab_subsampling is None else corpus_size,
      batch_size=batch_size,
      batch_capacity=batch_capacity,
      seed=seed,
      name=name)

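# A hedged usage sketch for skip_gram_sample_with_text_vocab() above, using
# the three-column vocab layout from the docstring example. The file path and
# input tokens are hypothetical, and `import tensorflow as tf` is assumed.
input_tokens = tf.constant(["hello", "bonjour", "hola", "hello"])
tokens, labels = skip_gram_sample_with_text_vocab(
    input_tensor=input_tokens,
    vocab_freq_file="/tmp/vocab.csv",  # lines like: bonjour,fr,42
    vocab_token_index=0,
    vocab_freq_index=2,
    vocab_delimiter=",",
    vocab_min_count=10,
    min_skips=1,
    max_skips=2)
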
def _global_report_benchmark(name,
                             iters=None,
                             cpu_time=None,
                             wall_time=None,
                             throughput=None,
                             extras=None):
  """Method for recording a benchmark directly.

  Args:
    name: The BenchmarkEntry name.
    iters: (optional) How many iterations were run
    cpu_time: (optional) Total cpu time in seconds
    wall_time: (optional) Total wall time in seconds
    throughput: (optional) Throughput (in MB/s)
    extras: (optional) Dict mapping string keys to additional benchmark info.

  Raises:
    TypeError: if extras is not a dict.
    IOError: if the benchmark output file already exists.
  """
  if extras is not None:
    if not isinstance(extras, dict):
      raise TypeError("extras must be a dict")

  logging.info(
      "Benchmark [%s] iters: %d, wall_time: %g, cpu_time: %g, "
      "throughput: %g %s", name, iters if iters is not None else -1,
      wall_time if wall_time is not None else -1,
      cpu_time if cpu_time is not None else -1,
      throughput if throughput is not None else -1,
      str(extras) if extras else "")

  entries = test_log_pb2.BenchmarkEntries()
  entry = entries.entry.add()
  entry.name = name
  if iters is not None:
    entry.iters = iters
  if cpu_time is not None:
    entry.cpu_time = cpu_time
  if wall_time is not None:
    entry.wall_time = wall_time
  if throughput is not None:
    entry.throughput = throughput
  if extras is not None:
    for (k, v) in extras.items():
      if isinstance(v, numbers.Number):
        entry.extras[k].double_value = v
      else:
        entry.extras[k].string_value = str(v)

  test_env = os.environ.get(TEST_REPORTER_TEST_ENV, None)
  if test_env is None:
    # Reporting was not requested, just print the proto
    print(str(entries))
    return

  serialized_entry = entries.SerializeToString()
  mangled_name = name.replace("/", "__")
  output_path = "%s%s" % (test_env, mangled_name)
  if gfile.Exists(output_path):
    raise IOError("File already exists: %s" % output_path)
  with gfile.GFile(output_path, "wb") as out:
    out.write(serialized_entry)

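# A hedged usage sketch for _global_report_benchmark() above. With the test
# reporter env var unset it just prints the proto; when set (simulated here
# with a hypothetical prefix), it writes a serialized BenchmarkEntries proto
# that helpers like read_benchmark_entry() earlier in this file can parse.
os.environ[TEST_REPORTER_TEST_ENV] = "/tmp/benchmark_"  # hypothetical prefix
_global_report_benchmark(
    "my/benchmark", iters=100, wall_time=0.5, extras={"batch_size": 32})
# The entry lands in /tmp/benchmark_my__benchmark ("/" is mangled to "__").
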
def run_training(sess, trainers, annotator, evaluator, pretrain_steps,
                 train_steps, train_corpus, eval_corpus, eval_gold,
                 batch_size, summary_writer, report_every, saver,
                 checkpoint_filename, checkpoint_stats=None):
  """Runs multi-task DRAGNN training on a single corpus.

  Arguments:
    sess: TF session to use.
    trainers: List of training ops to use.
    annotator: Annotation op.
    evaluator: Function taking two serialized corpora and returning a dict of
      scalar summaries representing evaluation metrics. The 'eval_metric'
      summary will be used for early stopping.
    pretrain_steps: List of the no. of pre-training steps for each train op.
    train_steps: List of the total no. of steps for each train op.
    train_corpus: Training corpus to use.
    eval_corpus: Holdout corpus for early stopping.
    eval_gold: Reference of eval_corpus for computing accuracy. eval_corpus
      and eval_gold are allowed to be the same if eval_corpus already
      contains gold annotation. Note that for segmentation, eval_corpus and
      eval_gold are always different, since eval_corpus contains sentences
      whose tokens are utf8-characters while eval_gold's tokens are gold
      words.
    batch_size: How many examples to send to the train op at a time.
    summary_writer: TF SummaryWriter to use to write summaries.
    report_every: How often to compute summaries (in steps).
    saver: TF saver op to save variables.
    checkpoint_filename: File to save checkpoints to.
    checkpoint_stats: Stats of checkpoint.
  """
  random.seed(0x31337)

  if not checkpoint_stats:
    checkpoint_stats = [0] * (len(train_steps) + 1)

  tf.logging.info('Determining the training schedule...')
  target_for_step = []
  for target_idx in xrange(len(pretrain_steps)):
    target_for_step += [target_idx] * pretrain_steps[target_idx]
  while sum(train_steps) > 0:
    step = random.randint(0, sum(train_steps) - 1)
    cumulative_steps = 0
    for target_idx in xrange(len(train_steps)):
      cumulative_steps += train_steps[target_idx]
      if step < cumulative_steps:
        break
    assert train_steps[target_idx] > 0
    train_steps[target_idx] -= 1
    target_for_step.append(target_idx)
  tf.logging.info('Training schedule defined!')

  best_eval_metric = -1.0
  tf.logging.info('Starting training...')
  actual_step = sum(checkpoint_stats[1:])
  for step, target_idx in enumerate(target_for_step):
    run_training_step(sess, trainers[target_idx], train_corpus, batch_size)
    checkpoint_stats[target_idx + 1] += 1
    if step % 100 == 0:
      tf.logging.info('training step: %d, actual: %d', step,
                      actual_step + step)
    if step % report_every == 0:
      tf.logging.info('finished step: %d, actual: %d', step,
                      actual_step + step)

      annotated = annotate_dataset(sess, annotator, eval_corpus)
      summaries = evaluator(eval_gold, annotated)
      for label, metric in summaries.iteritems():
        write_summary(summary_writer, label, metric, actual_step + step)
      eval_metric = summaries['eval_metric']
      tf.logging.info('Current eval metric: %.2f', eval_metric)

      if best_eval_metric < eval_metric:
        tf.logging.info('Updating best eval to %.2f, saving checkpoint.',
                        eval_metric)
        best_eval_metric = eval_metric
        saver.save(sess, checkpoint_filename)

        with gfile.GFile('%s.stats' % checkpoint_filename, 'w') as f:
          stats_str = ','.join([str(x) for x in checkpoint_stats])
          f.write(stats_str)
          tf.logging.info('Writing stats: %s', stats_str)

  tf.logging.info('Finished training!')

def train():
  """Train an en->fr translation model using WMT data."""
  from_train = None
  to_train = None
  from_dev = None
  to_dev = None
  if FLAGS.from_train_data and FLAGS.to_train_data:
    from_train_data = FLAGS.from_train_data
    to_train_data = FLAGS.to_train_data
    from_dev_data = from_train_data
    to_dev_data = to_train_data
    if FLAGS.from_dev_data and FLAGS.to_dev_data:
      from_dev_data = FLAGS.from_dev_data
      to_dev_data = FLAGS.to_dev_data
    from_train, to_train, from_dev, to_dev, _, _ = data_utils.prepare_data(
        FLAGS.data_dir, from_train_data, to_train_data, from_dev_data,
        to_dev_data, FLAGS.from_vocab_size, FLAGS.to_vocab_size)
  else:
    # Prepare WMT data.
    print("Preparing WMT data in %s" % FLAGS.data_dir)
    from_train, to_train, from_dev, to_dev, _, _ = data_utils.prepare_wmt_data(
        FLAGS.data_dir, FLAGS.from_vocab_size, FLAGS.to_vocab_size)

  with tf.Session() as sess:
    # Create model.
    print("Creating %d layers of %d units." % (FLAGS.num_layers, FLAGS.size))
    model = create_model(sess, False)

    # Read data into buckets and compute their sizes.
    print("Reading development and training data (limit: %d)." %
          FLAGS.max_train_data_size)
    dev_set = read_data(from_dev, to_dev)
    train_set = read_data(from_train, to_train, FLAGS.max_train_data_size)
    train_bucket_sizes = [len(train_set[b]) for b in xrange(len(_buckets))]
    train_total_size = float(sum(train_bucket_sizes))

    # A bucket scale is a list of increasing numbers from 0 to 1 that we'll
    # use to select a bucket. Length of [scale[i], scale[i+1]] is proportional
    # to the size of the i-th training bucket, as used later.
    train_buckets_scale = [
        sum(train_bucket_sizes[:i + 1]) / train_total_size
        for i in xrange(len(train_bucket_sizes))
    ]

    # This is the training loop.
    step_time, loss = 0.0, 0.0
    current_step = 0
    previous_losses = []
    while current_step < FLAGS.num_train_step:
      # Choose a bucket according to data distribution. We pick a random
      # number in [0, 1] and use the corresponding interval in
      # train_buckets_scale.
      random_number_01 = np.random.random_sample()
      bucket_id = min([
          i for i in xrange(len(train_buckets_scale))
          if train_buckets_scale[i] > random_number_01
      ])

      # Get a batch and make a step.
      start_time = time.time()
      encoder_inputs, decoder_inputs, target_weights = model.get_batch(
          train_set, bucket_id)
      _, step_loss, _, enc_init_states, enc_all_outputs = model.step(
          sess, encoder_inputs, decoder_inputs,  # MK change
          target_weights, bucket_id, False, 1)
      step_time += (time.time() - start_time) / FLAGS.steps_per_checkpoint
      loss += step_loss / FLAGS.steps_per_checkpoint
      current_step += 1

      # Once in a while, we save checkpoint, print statistics, and run evals.
      if current_step % FLAGS.steps_per_checkpoint == 0:
        # Print statistics for the previous epoch.
        first_layer = np.array(enc_init_states[0])
        mat_first_layer = np.matrix(first_layer)
        with open('first_layer_states.txt', 'wb') as f:
          for line in mat_first_layer:
            np.savetxt(f, line, fmt='%.2f')
        second_layer = np.array(enc_init_states[1])
        mat_second_layer = np.matrix(second_layer)
        with open('second_layer_states.txt', 'wb') as f:
          for line in mat_second_layer:
            np.savetxt(f, line, fmt='%.2f')
        perplexity = math.exp(float(loss)) if loss < 300 else float("inf")
        print("global step %d learning rate %.4f step-time %.5f perplexity "
              "%.5f" % (model.global_step.eval(),
                        model.learning_rate.eval(), step_time, perplexity))
        # Decrease learning rate if no improvement was seen over last 3 times.
        if len(previous_losses) > 2 and loss > max(previous_losses[-3:]):
          sess.run(model.learning_rate_decay_op)
        previous_losses.append(loss)
        # Save checkpoint and zero timer and loss.
        checkpoint_path = os.path.join(FLAGS.train_dir, "translate.ckpt")
        model.saver.save(sess, checkpoint_path,
                         global_step=model.global_step)
        step_time, loss = 0.0, 0.0
        # Run evals on development set and print their perplexity.
        for bucket_id in xrange(len(_buckets)):
          if len(dev_set[bucket_id]) == 0:
            print("  eval: empty bucket %d" % (bucket_id))
            continue
          encoder_inputs, decoder_inputs, target_weights = model.get_batch(
              dev_set, bucket_id)
          _, eval_loss, _, _, _ = model.step(sess, encoder_inputs,
                                             decoder_inputs, target_weights,
                                             bucket_id, False, 0)
          eval_ppx = math.exp(
              float(eval_loss)) if eval_loss < 300 else float("inf")
          print("  eval: bucket %d perplexity %.5f" % (bucket_id, eval_ppx))
        sys.stdout.flush()

    en_vocab_path = os.path.join(FLAGS.data_dir,
                                 "vocab%d.from" % FLAGS.from_vocab_size)
    fr_vocab_path = os.path.join(FLAGS.data_dir,
                                 "vocab%d.to" % FLAGS.to_vocab_size)
    en_vocab, _ = data_utils.initialize_vocabulary(en_vocab_path)
    _, rev_fr_vocab = data_utils.initialize_vocabulary(fr_vocab_path)
    max_iter = 100
    count = 0
    model.batch_size = 1
    with gfile.GFile(FLAGS.from_train_data, mode="rb") as f:
      for sentence in f:
        count = count + 1
        if max_iter < count:
          break
        # sentence = sys.stdin.readline()
        # while sentence:
        print(sentence)
        # Get token-ids for the input sentence.
        token_ids = data_utils.sentence_to_token_ids(
            tf.compat.as_bytes(sentence), en_vocab)
        # Which bucket does it belong to?
        bucket_id = len(_buckets) - 1
        for i, bucket in enumerate(_buckets):
          if bucket[0] >= len(token_ids):
            bucket_id = i
            break
        else:
          logging.warning("Sentence truncated: %s", sentence)
        # Get a 1-element batch to feed the sentence to the model.
        encoder_inputs, decoder_inputs, target_weights = model.get_batch(
            {bucket_id: [(token_ids, [])]}, bucket_id)
        # Get output logits for the sentence.
        _, _, output_logits, enc_all_state, _ = model.step(
            sess, encoder_inputs, decoder_inputs, target_weights, bucket_id,
            False, 0)
    quit()

def testExists(self):
  self.assertFalse(gfile.Exists(self.tmp + "test_exists"))
  with gfile.GFile(self.tmp + "test_exists", "w"):
    pass
  self.assertTrue(gfile.Exists(self.tmp + "test_exists"))

def test_export_savedmodel_assets(self):
  tmpdir = tempfile.mkdtemp()
  est = estimator.Estimator(model_fn=_model_fn_for_export_tests)
  est.train(input_fn=dummy_input_fn, steps=1)
  feature_spec = {
      'x': parsing_ops.VarLenFeature(dtype=dtypes.int64),
      'y': parsing_ops.VarLenFeature(dtype=dtypes.int64)
  }
  serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
      feature_spec)

  # Create a fake asset.
  vocab_file_name = os.path.join(
      compat.as_bytes(tmpdir), compat.as_bytes('my_vocab_file'))
  vocab_file = gfile.GFile(vocab_file_name, mode='w')
  vocab_file.write(_VOCAB_FILE_CONTENT)
  vocab_file.close()

  # hack in an op that uses the asset, in order to test asset export.
  # this is not actually valid, of course.
  def serving_input_receiver_with_asset_fn():
    features, receiver_tensor = serving_input_receiver_fn()
    filename = ops.convert_to_tensor(
        vocab_file_name, dtypes.string, name='asset_filepath')
    ops.add_to_collection(ops.GraphKeys.ASSET_FILEPATHS, filename)
    features['bogus_filename'] = filename
    return export.ServingInputReceiver(features, receiver_tensor)

  # Perform the export.
  export_dir_base = os.path.join(
      compat.as_bytes(tmpdir), compat.as_bytes('export'))
  export_dir = est.export_savedmodel(export_dir_base,
                                     serving_input_receiver_with_asset_fn)

  # Check that the asset files are in the right places.
  expected_vocab_file_name = os.path.join(
      compat.as_bytes(export_dir), compat.as_bytes('assets/my_vocab_file'))
  self.assertTrue(
      gfile.Exists(
          os.path.join(compat.as_bytes(export_dir),
                       compat.as_bytes('assets'))))
  self.assertTrue(gfile.Exists(expected_vocab_file_name))
  self.assertEqual(
      compat.as_bytes(_VOCAB_FILE_CONTENT),
      compat.as_bytes(gfile.GFile(expected_vocab_file_name).read()))

  # Restore, to validate that the export was well-formed.
  with ops.Graph().as_default() as graph:
    with session.Session(graph=graph) as sess:
      loader.load(sess, [tag_constants.SERVING], export_dir)
      assets = [
          x.eval()
          for x in graph.get_collection(ops.GraphKeys.ASSET_FILEPATHS)
      ]
      self.assertItemsEqual([vocab_file_name], assets)
      graph_ops = [x.name for x in graph.get_operations()]
      self.assertTrue('input_example_tensor' in graph_ops)
      self.assertTrue('ParseExample/ParseExample' in graph_ops)
      self.assertTrue('asset_filepath' in graph_ops)
      self.assertTrue('weight' in graph_ops)

  # cleanup
  gfile.DeleteRecursively(tmpdir)

def wav_to_features(sample_rate, clip_duration_ms, window_size_ms,
                    window_stride_ms, feature_bin_count, quantize, preprocess,
                    input_wav, output_c_file):
  """Converts an audio file into its corresponding feature map.

  Args:
    sample_rate: Expected sample rate of the wavs.
    clip_duration_ms: Expected duration in milliseconds of the wavs.
    window_size_ms: How long each spectrogram timeslice is.
    window_stride_ms: How far to move in time between spectrogram timeslices.
    feature_bin_count: How many bins to use for the feature fingerprint.
    quantize: Whether to train the model for eight-bit deployment.
    preprocess: Spectrogram processing mode; "mfcc", "average" or "micro".
    input_wav: Path to the audio WAV file to read.
    output_c_file: Where to save the generated C source file.
  """
  # Start a new TensorFlow session.
  sess = tf.compat.v1.InteractiveSession()

  model_settings = models.prepare_model_settings(
      0, sample_rate, clip_duration_ms, window_size_ms, window_stride_ms,
      feature_bin_count, preprocess)
  audio_processor = input_data.AudioProcessor(None, None, 0, 0, '', 0, 0,
                                              model_settings, None)

  results = audio_processor.get_features_for_wav(input_wav, model_settings,
                                                 sess)
  features = results[0]

  variable_base = os.path.splitext(os.path.basename(input_wav).lower())[0]

  # Save a C source file containing the feature data as an array.
  with gfile.GFile(output_c_file, 'w') as f:
    f.write('/* File automatically created by\n')
    f.write(' * tensorflow/examples/speech_commands/wav_to_features.py \\\n')
    f.write(' * --sample_rate=%d \\\n' % sample_rate)
    f.write(' * --clip_duration_ms=%d \\\n' % clip_duration_ms)
    f.write(' * --window_size_ms=%d \\\n' % window_size_ms)
    f.write(' * --window_stride_ms=%d \\\n' % window_stride_ms)
    f.write(' * --feature_bin_count=%d \\\n' % feature_bin_count)
    if quantize:
      f.write(' * --quantize=1 \\\n')
    f.write(' * --preprocess="%s" \\\n' % preprocess)
    f.write(' * --input_wav="%s" \\\n' % input_wav)
    f.write(' * --output_c_file="%s" \\\n' % output_c_file)
    f.write(' */\n\n')
    f.write('const int g_%s_width = %d;\n' %
            (variable_base, model_settings['fingerprint_width']))
    f.write('const int g_%s_height = %d;\n' %
            (variable_base, model_settings['spectrogram_length']))
    if quantize:
      features_min, features_max = input_data.get_features_range(
          model_settings)
      f.write('const unsigned char g_%s_data[] = {' % variable_base)
      i = 0
      for value in features.flatten():
        quantized_value = int(
            round((255 * (value - features_min)) /
                  (features_max - features_min)))
        if quantized_value < 0:
          quantized_value = 0
        if quantized_value > 255:
          quantized_value = 255
        if i == 0:
          f.write('\n  ')
        f.write('%d, ' % (quantized_value))
        i = (i + 1) % 10
    else:
      f.write('const float g_%s_data[] = {\n' % variable_base)
      i = 0
      for value in features.flatten():
        if i == 0:
          f.write('\n  ')
        f.write('%f, ' % value)
        i = (i + 1) % 10
    f.write('\n};\n')

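# A hedged usage sketch for wav_to_features() above, using settings typical
# of the speech_commands example it comes from. The WAV and output paths are
# hypothetical.
wav_to_features(
    sample_rate=16000,
    clip_duration_ms=1000,
    window_size_ms=30,
    window_stride_ms=10,
    feature_bin_count=40,
    quantize=True,
    preprocess='average',
    input_wav='/tmp/yes.wav',              # hypothetical input
    output_c_file='/tmp/yes_features.cc')  # hypothetical output
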
def benchmark_sentencepiece_tokenizer(self):
  model = gfile.GFile(_SENTENCEPIECE_MODEL_FILE, "rb").read()
  tokenizer = text_ops.SentencepieceTokenizer(model)
  self._run(tokenizer)

if __name__ == "__main__":
    os.environ['CUDA_VISIBLE_DEVICES'] = "0"
    output_path = './outputs_replica'
    checkpoint_to_use = os.path.join(output_path, 'model.ckpt')
    output_graph_path = './export/exported_graph_replica_v2.pb'
    output_node_names = 'NetOutput'

    tf.reset_default_graph()
    with tf.Session() as sess:
        saver_restore = tf.train.import_meta_graph(
            os.path.join(output_path, 'model.ckpt.meta'))
        saver_restore.restore(sess, tf.train.latest_checkpoint(output_path))
        saver_export = tf.train.Saver(max_to_keep=10)

        frozen_graph_def = freeze_graph_with_def_protos(
            input_graph_def=tf.get_default_graph().as_graph_def(
                add_shapes=True),
            input_saver_def=saver_export.as_saver_def(),
            input_checkpoint=checkpoint_to_use,
            output_node_names=output_node_names,
            restore_op_name=None,
            filename_tensor_name=None,
            output_graph=None,  # the GraphDef is written manually below
            clear_devices=True,
            initializer_nodes=None)

        with gfile.GFile(output_graph_path, 'wb') as f:
            f.write(frozen_graph_def.SerializeToString())

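# A hedged follow-up sketch: loading the frozen GraphDef written above back
# into a fresh graph for inference. The tensor name 'NetOutput:0' follows
# from the output_node_names used when freezing.
with gfile.GFile(output_graph_path, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
with tf.Graph().as_default() as g:
    tf.import_graph_def(graph_def, name='')
    output_tensor = g.get_tensor_by_name('NetOutput:0')
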
def test_writing_canned_variants(self):
  """Tests writing all the variants that are 'canned' in our tfrecord file."""
  # This file is in the TF record format
  tfrecord_file = test_utils.genomics_core_testdata(
      'test_samples.vcf.golden.tfrecord')

  writer_options = variants_pb2.VcfWriterOptions()
  header = variants_pb2.VcfHeader(
      contigs=[
          reference_pb2.ContigInfo(name='chr1', n_bases=248956422),
          reference_pb2.ContigInfo(name='chr2', n_bases=242193529),
          reference_pb2.ContigInfo(name='chr3', n_bases=198295559),
          reference_pb2.ContigInfo(name='chrX', n_bases=156040895)
      ],
      sample_names=['NA12878_18_99'],
      filters=[
          variants_pb2.VcfFilterInfo(
              id='PASS', description='All filters passed'),
          variants_pb2.VcfFilterInfo(id='LowQual', description=''),
          variants_pb2.VcfFilterInfo(id='VQSRTrancheINDEL95.00to96.00'),
          variants_pb2.VcfFilterInfo(id='VQSRTrancheINDEL96.00to97.00'),
          variants_pb2.VcfFilterInfo(id='VQSRTrancheINDEL97.00to99.00'),
          variants_pb2.VcfFilterInfo(id='VQSRTrancheINDEL99.00to99.50'),
          variants_pb2.VcfFilterInfo(id='VQSRTrancheINDEL99.50to99.90'),
          variants_pb2.VcfFilterInfo(id='VQSRTrancheINDEL99.90to99.95'),
          variants_pb2.VcfFilterInfo(id='VQSRTrancheINDEL99.95to100.00+'),
          variants_pb2.VcfFilterInfo(id='VQSRTrancheINDEL99.95to100.00'),
          variants_pb2.VcfFilterInfo(id='VQSRTrancheSNP99.50to99.60'),
          variants_pb2.VcfFilterInfo(id='VQSRTrancheSNP99.60to99.80'),
          variants_pb2.VcfFilterInfo(id='VQSRTrancheSNP99.80to99.90'),
          variants_pb2.VcfFilterInfo(id='VQSRTrancheSNP99.90to99.95'),
          variants_pb2.VcfFilterInfo(id='VQSRTrancheSNP99.95to100.00+'),
          variants_pb2.VcfFilterInfo(id='VQSRTrancheSNP99.95to100.00'),
      ],
      infos=[
          variants_pb2.VcfInfo(
              id='END',
              number='1',
              type='Integer',
              description='Stop position of the interval')
      ],
      formats=[
          variants_pb2.VcfFormatInfo(
              id='GT', number='1', type='String', description='Genotype'),
          variants_pb2.VcfFormatInfo(
              id='GQ',
              number='1',
              type='Integer',
              description='Genotype Quality'),
          variants_pb2.VcfFormatInfo(
              id='DP',
              number='1',
              type='Integer',
              description='Read depth of all passing filters reads.'),
          variants_pb2.VcfFormatInfo(
              id='MIN_DP',
              number='1',
              type='Integer',
              description='Minimum DP observed within the GVCF block.'),
          variants_pb2.VcfFormatInfo(
              id='AD',
              number='R',
              type='Integer',
              description=
              'Read depth of all passing filters reads for each allele.'),
          variants_pb2.VcfFormatInfo(
              id='VAF',
              number='A',
              type='Float',
              description='Variant allele fractions.'),
          variants_pb2.VcfFormatInfo(
              id='PL',
              number='G',
              type='Integer',
              description='Genotype likelihoods, Phred encoded'),
      ],
  )
  variant_records = list(
      tfrecord.read_tfrecords(tfrecord_file, proto=variants_pb2.Variant))
  out_fname = test_utils.test_tmpfile('output.vcf')
  with vcf_writer.VcfWriter.to_file(out_fname, header,
                                    writer_options) as writer:
    for record in variant_records[:5]:
      writer.write(record)

  # Check: are the variants written as expected?
  # pylint: disable=line-too-long
  expected_vcf_content = [
      '##fileformat=VCFv4.2\n',
      '##FILTER=<ID=PASS,Description="All filters passed">\n',
      '##FILTER=<ID=LowQual,Description="">\n',
      '##FILTER=<ID=VQSRTrancheINDEL95.00to96.00,Description="">\n',
      '##FILTER=<ID=VQSRTrancheINDEL96.00to97.00,Description="">\n',
      '##FILTER=<ID=VQSRTrancheINDEL97.00to99.00,Description="">\n',
      '##FILTER=<ID=VQSRTrancheINDEL99.00to99.50,Description="">\n',
      '##FILTER=<ID=VQSRTrancheINDEL99.50to99.90,Description="">\n',
      '##FILTER=<ID=VQSRTrancheINDEL99.90to99.95,Description="">\n',
      '##FILTER=<ID=VQSRTrancheINDEL99.95to100.00+,Description="">\n',
      '##FILTER=<ID=VQSRTrancheINDEL99.95to100.00,Description="">\n',
      '##FILTER=<ID=VQSRTrancheSNP99.50to99.60,Description="">\n',
      '##FILTER=<ID=VQSRTrancheSNP99.60to99.80,Description="">\n',
      '##FILTER=<ID=VQSRTrancheSNP99.80to99.90,Description="">\n',
      '##FILTER=<ID=VQSRTrancheSNP99.90to99.95,Description="">\n',
      '##FILTER=<ID=VQSRTrancheSNP99.95to100.00+,Description="">\n',
      '##FILTER=<ID=VQSRTrancheSNP99.95to100.00,Description="">\n',
      '##INFO=<ID=END,Number=1,Type=Integer,Description="Stop position of '
      'the interval">\n',
      '##FORMAT=<ID=GT,Number=1,Type=String,Description="Genotype">\n',
      '##FORMAT=<ID=GQ,Number=1,Type=Integer,Description="Genotype Quality">\n',
      '##FORMAT=<ID=DP,Number=1,Type=Integer,Description="Read depth of all '
      'passing filters reads.">\n',
      '##FORMAT=<ID=MIN_DP,Number=1,Type=Integer,Description="Minimum DP '
      'observed within the GVCF block.">\n',
      '##FORMAT=<ID=AD,Number=R,Type=Integer,Description="Read depth of all '
      'passing filters reads for each allele.">\n',
      '##FORMAT=<ID=VAF,Number=A,Type=Float,Description="Variant allele '
      'fractions.">\n',
      '##FORMAT=<ID=PL,Number=G,Type=Integer,Description="Genotype '
      'likelihoods, Phred encoded">\n',
      '##contig=<ID=chr1,length=248956422>\n',
      '##contig=<ID=chr2,length=242193529>\n',
      '##contig=<ID=chr3,length=198295559>\n',
      '##contig=<ID=chrX,length=156040895>\n',
      '#CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\tNA12878_18_99\n',
      'chr1\t13613\t.\tT\tA\t39.88\tVQSRTrancheSNP99.90to99.95\t.\tGT:GQ:DP:AD:PL\t0/1:16:4:1,3:68,0,16\n',
      'chr1\t13813\t.\tT\tG\t90.28\tPASS\t.\tGT:GQ:DP:AD:PL\t1/1:9:3:0,3:118,9,0\n',
      'chr1\t13838\trs28428499\tC\tT\t62.74\tPASS\t.\tGT:GQ:DP:AD:PL\t1/1:6:2:0,2:90,6,0\n',
      'chr1\t14397\trs756427959\tCTGT\tC\t37.73\tPASS\t.\tGT:GQ:DP:AD:PL\t0/1:75:5:3,2:75,0,152\n',
      'chr1\t14522\t.\tG\tA\t49.77\tVQSRTrancheSNP99.60to99.80\t.\tGT:GQ:DP:AD:PL\t0/1:78:10:6,4:78,0,118\n'
  ]
  # pylint: enable=line-too-long

  with gfile.GFile(out_fname, 'r') as f:
    self.assertEqual(f.readlines(), expected_vcf_content)

def __init__(self, frozen_file, inputshape, in_nodes, dest_nodes):
    if LooseVersion(tensorflow.__version__) < LooseVersion('1.8.0'):
        raise ImportError(
            'Your TensorFlow version %s is outdated. '
            'MMdnn requires tensorflow>=1.8.0' % tensorflow.__version__)

    super(TensorflowParser2, self).__init__()
    self.weight_loaded = True

    # load model files into TensorFlow graph
    with open(frozen_file, 'rb') as f:
        serialized = f.read()
    tensorflow.reset_default_graph()
    original_gdef = tensorflow.GraphDef()
    original_gdef.ParseFromString(serialized)

    in_type_list = {}
    for n in original_gdef.node:
        if n.name in in_nodes:
            in_type_list[n.name] = n.attr['dtype'].type

    from tensorflow.python.tools import strip_unused_lib
    from tensorflow.python.framework import dtypes
    from tensorflow.python.platform import gfile
    original_gdef = strip_unused_lib.strip_unused(
        input_graph_def=original_gdef,
        input_node_names=in_nodes,
        output_node_names=dest_nodes,
        placeholder_type_enum=dtypes.float32.as_datatype_enum)

    # Save it to an output file
    frozen_model_file = './frozen.pb'
    with gfile.GFile(frozen_model_file, "wb") as f:
        f.write(original_gdef.SerializeToString())
    with open(frozen_model_file, 'rb') as f:
        serialized = f.read()
    tensorflow.reset_default_graph()
    model = tensorflow.GraphDef()
    model.ParseFromString(serialized)

    output_shape_map = dict()
    input_shape_map = dict()
    dtype = tensorflow.float32

    with tensorflow.Graph().as_default() as g:
        input_map = {}
        for i in range(len(inputshape)):
            if in_type_list[in_nodes[i]] == 1 or in_type_list[in_nodes[i]] == 0:
                dtype = tensorflow.float32
                x = tensorflow.placeholder(dtype, shape=[None] + inputshape[i])
            elif in_type_list[in_nodes[i]] == 3:
                dtype = tensorflow.int32
                x = tensorflow.placeholder(dtype, shape=inputshape[i])
            elif in_type_list[in_nodes[i]] == 10:
                dtype = tensorflow.bool
                x = tensorflow.placeholder(dtype)
            input_map[in_nodes[i] + ':0'] = x
        tensorflow.import_graph_def(model, name='', input_map=input_map)

    # graph_options = tensorflow.GraphOptions(
    #     optimizer_options=tensorflow.OptimizerOptions(
    #         opt_level=tensorflow.OptimizerOptions.L0,
    #         do_function_inlining=False))
    # config = tensorflow.ConfigProto(graph_options=graph_options)
    # with tensorflow.Session(graph=g, config=config) as sess:
    with tensorflow.Session(graph=g) as sess:
        meta_graph_def = tensorflow.train.export_meta_graph(
            filename='./my-model.meta')
        model = meta_graph_def.graph_def

    self.tf_graph = TensorflowGraph(model)
    self.tf_graph.build()

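# A hedged usage sketch for the TensorflowParser2 constructor above. The
# frozen-graph path, input shape, and node names are hypothetical and depend
# entirely on the model being converted.
parser = TensorflowParser2(
    frozen_file='./model_frozen.pb',  # hypothetical frozen GraphDef
    inputshape=[[224, 224, 3]],       # one shape per input node
    in_nodes=['input'],               # hypothetical input node name
    dest_nodes=['output'])            # hypothetical output node name
# After construction, parser.tf_graph holds the parsed TensorflowGraph.
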
def main(_):
  best_acc = 0
  best_step = 0
  best_acc_istrain = 0
  best_step_istrain = 0
  # We want to see all the logging messages for this tutorial.
  tf.logging.set_verbosity(tf.logging.INFO)

  # Start a new TensorFlow session.
  sess = tf.InteractiveSession()

  # Begin by making sure we have the training data we need. If you already have
  # training data of your own, use `--data_url= ` on the command line to avoid
  # downloading.
  model_settings = models.prepare_model_settings(
      len(input_data_filler.prepare_words_list_my(FLAGS.wanted_words.split(','))),
      FLAGS.sample_rate, FLAGS.clip_duration_ms, FLAGS.window_size_ms,
      FLAGS.window_stride_ms, FLAGS.dct_coefficient_count)
  audio_processor = input_data_filler.AudioProcessor(
      FLAGS.data_url, FLAGS.data_dir, FLAGS.silence_percentage,
      FLAGS.unknown_percentage, FLAGS.wanted_words.split(','),
      FLAGS.validation_percentage, FLAGS.testing_percentage, model_settings)
  fingerprint_size = model_settings['fingerprint_size']
  label_count = model_settings['label_count']
  time_shift_samples = int((FLAGS.time_shift_ms * FLAGS.sample_rate) / 1000)
  training_steps_list = list(map(int, FLAGS.how_many_training_steps.split(',')))
  learning_rates_list = list(map(float, FLAGS.learning_rate.split(',')))
  if len(training_steps_list) != len(learning_rates_list):
    raise Exception(
        '--how_many_training_steps and --learning_rate must be equal length '
        'lists, but are %d and %d long instead' % (len(training_steps_list),
                                                   len(learning_rates_list)))

  ##############################################
  ############ TensorFlow modules ##############
  ##############################################
  fingerprint_input = tf.placeholder(
      tf.float32, [None, fingerprint_size], name='fingerprint_input')

  ############ Model creation ##########
  istrain = tf.placeholder(tf.bool, name='istrain')
  logits = models.create_model(
      fingerprint_input,
      model_settings,
      FLAGS.model_architecture,
      is_training=istrain)
  # logits, dropout_prob = models.create_model(
  #     fingerprint_input,
  #     model_settings,
  #     FLAGS.model_architecture,
  #     is_training=True)

  # Define loss and optimizer
  ############ Ground truth ##########
  ground_truth_input = tf.placeholder(
      tf.float32, [None, label_count], name='groundtruth_input')

  # Optionally we can add runtime checks to spot when NaNs or other symptoms of
  # numerical errors start occurring during training.
  control_dependencies = []
  if FLAGS.check_nans:
    checks = tf.add_check_numerics_ops()
    control_dependencies = [checks]

  # Create the back propagation and training evaluation machinery in the graph.
  ############ Cross-entropy computation ##########
  # with tf.name_scope('cross_entropy'):
  #   cross_entropy_mean = tf.reduce_mean(
  #       tf.nn.softmax_cross_entropy_with_logits(
  #           labels=ground_truth_input, logits=logits)) + beta * loss_norm
  with tf.name_scope('cross_entropy'):
    cross_entropy_mean = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(
            labels=ground_truth_input, logits=logits))
  tf.summary.scalar('cross_entropy', cross_entropy_mean)

  ############ Learning rate, accuracy, confusion matrix ##########
  # learning_rate_input     learning-rate input (tf.placeholder)
  # train_step              training op (optimizer)
  # predicted_indices       predicted output indices
  # expected_indices        expected output indices
  # correct_prediction      matrix of correct predictions
  # confusion_matrix        confusion matrix
  # evaluation_step         classification accuracy (per evaluation)
  # global_step             global training step
  # increment_global_step   global training step increment
  learning_rate_input = tf.placeholder(
      tf.float32, [], name='learning_rate_input')
  update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  with tf.control_dependencies(update_ops):
    train_step = tf.train.AdamOptimizer(learning_rate_input).minimize(
        cross_entropy_mean)
  # with tf.name_scope('train'), tf.control_dependencies(control_dependencies):
  #   learning_rate_input = tf.placeholder(
  #       tf.float32, [], name='learning_rate_input')
  #   # train_step = tf.train.GradientDescentOptimizer(
  #   #     learning_rate_input).minimize(cross_entropy_mean)
  #   with tf.control_dependencies(update_ops):
  #     train_step = tf.train.AdamOptimizer(
  #         learning_rate_input).minimize(cross_entropy_mean)
  predicted_indices = tf.argmax(logits, 1)
  expected_indices = tf.argmax(ground_truth_input, 1)
  correct_prediction = tf.equal(predicted_indices, expected_indices)
  confusion_matrix = tf.confusion_matrix(
      expected_indices, predicted_indices, num_classes=label_count)
  evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  acc = tf.summary.scalar('accuracy', evaluation_step)

  global_step = tf.train.get_or_create_global_step()
  increment_global_step = tf.assign(global_step, global_step + 1)
  # max_to_keep=None keeps every checkpoint file (the default is 5).
  saver = tf.train.Saver(tf.global_variables(), max_to_keep=None)

  # Merge all the summaries and write them out to /tmp/retrain_logs (by default)
  merged_summaries = tf.summary.merge_all()
  validation_merged_summaries = tf.summary.merge([
      tf.get_collection(tf.GraphKeys.SUMMARIES, 'accuracy'),
      tf.get_collection(tf.GraphKeys.SUMMARIES, 'cross_entropy')
  ])
  test_summaries = tf.summary.merge([acc])
  test_summaries_istrain = tf.summary.merge([
      tf.get_collection(tf.GraphKeys.SUMMARIES, 'accuracy'),
      tf.get_collection(tf.GraphKeys.SUMMARIES, 'cross_entropy')
  ])
  # test_summaries_istrain = tf.summary.merge([acc])
  train_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/train',
                                       sess.graph)
  # validation_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/validation')
  test_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/test')
  test_istrain_writer = tf.summary.FileWriter(FLAGS.summaries_dir +
                                              '/test_istrain')
  tf.global_variables_initializer().run()

  start_step = 1
  if FLAGS.start_checkpoint:
    models.load_variables_from_checkpoint(sess, FLAGS.start_checkpoint)
    start_step = global_step.eval(session=sess)
  tf.logging.info('Training from step: %d ', start_step)

  # Save graph.pbtxt.
  tf.train.write_graph(sess.graph_def, FLAGS.train_dir,
                       FLAGS.model_architecture + '.pbtxt')

  # Save list of words.
  with gfile.GFile(
      os.path.join(FLAGS.train_dir, FLAGS.model_architecture + '_labels.txt'),
      'w') as f:
    f.write('\n'.join(audio_processor.words_list))

  # model1: fc
  # model2: conv (~940k parameters)
  # model3: low_latency_conv (~model1)
  # model4: 750k
  # Training loop.
  #############################################
  ############### Main loop ###################
  #############################################
  training_steps_max = np.sum(training_steps_list)
  for training_step in xrange(start_step, training_steps_max + 1):
    # Figure out what the current learning rate is.
    ####### Automatic learning-rate decay (see the sketch after this function) #######
    if training_step < 12000 + 1:
      # Use a float divisor so the decay is smooth under Python 2 as well.
      learning_rate_value = learning_rates_list[0] * 0.02**(
          training_step / 12000.0)
    else:
      learning_rate_value = learning_rates_list[0] * 0.02  # 0.015, 12000
    # training_steps_sum = 0
    # for i in range(len(training_steps_list)):
    #   training_steps_sum += training_steps_list[i]
    #   if training_step <= training_steps_sum:
    #     learning_rate_value = learning_rates_list[i]
    #     break

    # Pull the audio samples we'll use for training.
    ####### The audio processor loads the data ##########################
    # get_data(self, how_many, offset, model_settings, background_frequency,
    #          background_volume_range, time_shift, mode, sess)
    ######################################################################
    train_fingerprints, train_ground_truth = audio_processor.get_data_my(
        FLAGS.batch_size, 0, model_settings, FLAGS.background_frequency,
        FLAGS.background_volume, time_shift_samples, 'training', sess)
    # mid = np.abs(np.max(train_fingerprints) + np.min(train_fingerprints)) / 2
    # half = np.max(train_fingerprints) - np.min(train_fingerprints)
    # train_fingerprints = ((train_fingerprints + mid) / half * 255).astype(int)
    # train_fingerprints = np_round_and_clip_5bit(train_fingerprints)
    train_summary, train_accuracy, cross_entropy_value, _, _ = sess.run(
        [
            merged_summaries, evaluation_step, cross_entropy_mean, train_step,
            increment_global_step
        ],
        feed_dict={
            fingerprint_input: train_fingerprints,
            ground_truth_input: train_ground_truth,
            learning_rate_input: learning_rate_value,
            istrain: True
        })
    train_writer.add_summary(train_summary, training_step)
    tf.logging.info('Step #%d: rate %f, accuracy %.1f%%, cross entropy %f' %
                    (training_step, learning_rate_value, train_accuracy * 100,
                     cross_entropy_value))
    is_last_step = (training_step == training_steps_max)
    if (training_step % FLAGS.eval_step_interval) == 0 or is_last_step:
      ##############################################################
      ### Recompute accuracy and confusion matrix on the test set ##
      set_size = audio_processor.set_size('testing')
      tf.logging.info('set_size=%d', set_size)
      test_fingerprints, test_ground_truth = audio_processor.get_data_my(
          -3, 0, model_settings, 0.0, 0.0, 0, 'testing', sess)
      # mid = np.abs(np.max(test_fingerprints) + np.min(test_fingerprints)) / 2
      # half = np.max(test_fingerprints) - np.min(test_fingerprints)
      # test_fingerprints = np_round_and_clip_5bit(test_fingerprints)
      final_summary, test_accuracy, conf_matrix = sess.run(
          [test_summaries, evaluation_step, confusion_matrix],
          feed_dict={
              fingerprint_input: test_fingerprints,
              ground_truth_input: test_ground_truth,
              istrain: False
          })
      final_summary_istrain, test_accuracy_istrain = sess.run(
          [test_summaries_istrain, evaluation_step],
          feed_dict={
              fingerprint_input: test_fingerprints,
              ground_truth_input: test_ground_truth,
              istrain: True
          })
      if test_accuracy > best_acc:
        best_acc = test_accuracy
        best_step = training_step
      if test_accuracy_istrain > best_acc_istrain:
        best_acc_istrain = test_accuracy_istrain
        best_step_istrain = training_step
      test_writer.add_summary(final_summary, training_step)
      test_istrain_writer.add_summary(final_summary_istrain, training_step)
      tf.logging.info('Confusion Matrix:\n %s' % (conf_matrix))
      tf.logging.info('test accuracy = %.1f%% (N=%d)' %
                      (test_accuracy * 100, 6882))
      tf.logging.info('test_istrain accuracy = %.1f%% (N=%d)' %
                      (test_accuracy_istrain * 100, 6882))
      tf.logging.info('Best test accuracy so far = %.1f%% (N=%d)' %
                      (best_acc * 100, 6882) + ' at step ' + str(best_step))
      tf.logging.info('Best test_istrain accuracy so far = %.1f%% (N=%d)' %
                      (best_acc_istrain * 100, 6882) + ' at step ' +
                      str(best_step_istrain))

    # Save the model checkpoint periodically.
    if (training_step % FLAGS.save_step_interval == 0 or
        training_step == training_steps_max):
      checkpoint_path = os.path.join(
          FLAGS.train_dir + '/' + FLAGS.model_architecture,
          FLAGS.model_architecture + '.ckpt')
      tf.logging.info('Saving to "%s-%d"', checkpoint_path, training_step)
      saver.save(sess, checkpoint_path, global_step=training_step)

    print_line = (
        'Best test accuracy so far = %.1f%% (N=%d)' % (best_acc * 100, 6882) +
        ' at step ' + str(best_step) + '\n' +
        'Best test_istrain accuracy so far = %.1f%% (N=%d)' %
        (best_acc_istrain * 100, 6882) + ' at step ' + str(best_step_istrain))
    if training_step == training_steps_max:
      with open(
          FLAGS.train_dir + '/' + FLAGS.model_architecture + '/details.txt',
          'w') as f:
        f.write(print_line)
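# A minimal standalone sketch (not part of the training script above) of the
# exponential learning-rate schedule used in its main loop: the rate decays
# geometrically from learning_rates_list[0] down to 2% of it over the first
# 12000 steps, then stays constant. The names `initial_lr`, `decay_steps`, and
# `final_fraction` are illustrative, chosen to match the constants above.
def decayed_learning_rate(step,
                          initial_lr=0.001,
                          decay_steps=12000,
                          final_fraction=0.02):
  """Returns the learning rate for a given global step."""
  if step <= decay_steps:
    # Geometric interpolation: lr = lr0 * final_fraction**(step / decay_steps).
    return initial_lr * final_fraction**(step / float(decay_steps))
  return initial_lr * final_fraction

# For example, with initial_lr=0.001: decayed_learning_rate(0) == 0.001,
# decayed_learning_rate(6000) is about 1.41e-4, and
# decayed_learning_rate(12000) == 2e-5, constant thereafter.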
def build_buckets(buckets, max_vocab, source_languages, target_languages):
    """
    Build bucketed versions of the data, tokenizing it and writing it to a
    file in the process.

    :param buckets: List of pairs representing (source, target) bucket lengths.
    :param max_vocab: Dictionary mapping language id to maximum vocabulary size.
    :param source_languages: List of source language ids.
    :param target_languages: List of target language ids.
    """
    if len(os.listdir(os.path.join(FLAGS.bucket_dir, str(0)))) <= 1:
        for target_id in target_languages:
            target_vocab, _ = init_vocab(target_id, max_vocab[target_id])
            for source_id in source_languages:
                source_vocab, _ = init_vocab(source_id, max_vocab[source_id])
                base_fp = "%s-%s." % tuple(sorted([source_id, target_id]))
                eval_fp = "%s-%s-eval." % tuple(sorted([source_id, target_id]))
                print(base_fp, source_id, target_id)
                data_set = {k: [] for k in range(len(buckets))}
                with gfile.GFile(
                        os.path.join(FLAGS.raw_dir, base_fp + source_id),
                        'r') as src:
                    with gfile.GFile(
                            os.path.join(FLAGS.raw_dir, base_fp + target_id),
                            'r') as trg:
                        source, target = src.readline(), trg.readline()
                        while source and target:
                            src_ids = [
                                source_vocab.get(w, UNK_ID)
                                for w in basic_tokenizer(source)
                            ]
                            trg_ids = [
                                target_vocab.get(w, UNK_ID)
                                for w in basic_tokenizer(target)
                            ]
                            trg_ids.append(EOS_ID)
                            # Find a bucket.
                            for bucket_id, (source_size,
                                            target_size) in enumerate(buckets):
                                if (len(src_ids) < source_size and
                                        len(trg_ids) < target_size):
                                    data_set[bucket_id].append(
                                        [src_ids, trg_ids])
                                    break
                            source, target = src.readline(), trg.readline()
                for k in data_set:
                    counter = 0
                    data_train = data_set[k][:-200]
                    data_eval = data_set[k][-200:]
                    bucket_path = os.path.join(FLAGS.bucket_dir, str(k))
                    with gfile.GFile(
                            os.path.join(bucket_path, base_fp + source_id),
                            'w') as src:
                        with gfile.GFile(
                                os.path.join(bucket_path, base_fp + target_id),
                                'w') as trg:
                            for i in range(len(data_train)):
                                counter += 1
                                s, t = data_train[i]
                                src.write(" ".join(map(str, s)) + "\n")
                                trg.write(" ".join(map(str, t)) + "\n")
                    with gfile.GFile(
                            os.path.join(bucket_path, eval_fp + source_id),
                            'w') as src:
                        with gfile.GFile(
                                os.path.join(bucket_path, eval_fp + target_id),
                                'w') as trg:
                            for i in range(len(data_eval)):
                                s, t = data_eval[i]
                                src.write(" ".join(map(str, s)) + "\n")
                                trg.write(" ".join(map(str, t)) + "\n")
                    print("Bucket %s for %s-%s has %d examples!" %
                          (k, source_id, target_id, counter))
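# A small, self-contained illustration (made-up values) of the bucket-selection
# rule in build_buckets: a sentence pair lands in the first bucket whose
# (source, target) limits are strictly larger than the tokenized lengths.
buckets_example = [(5, 10), (10, 15), (20, 25)]
src_len, trg_len = 7, 12  # lengths after tokenization, EOS already appended
for bucket_id, (source_size, target_size) in enumerate(buckets_example):
    if src_len < source_size and trg_len < target_size:
        print("pair goes to bucket", bucket_id)  # -> bucket 1
        break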
def _read_words(filename):
    with gfile.GFile(filename, "r") as f:
        return f.read().replace("\n", "<eos>").split()
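# A quick, self-contained illustration of _read_words' newline handling:
# "\n" becomes the literal token "<eos>", which only splits off as its own
# token when whitespace separates it from the neighbouring words, as in the
# PTB-style data this reader targets. (The strings below are made up.)
assert "the cat \n sat \n".replace("\n", "<eos>").split() == \
    ['the', 'cat', '<eos>', 'sat', '<eos>']
assert "the cat\nsat\n".replace("\n", "<eos>").split() == \
    ['the', 'cat<eos>sat<eos>']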
def export_cnn():
    # (Model construction is omitted in this excerpt; `pred_node_names` is
    # assumed to hold the names of the prediction nodes.)
    with tf.Session() as sess:
        constant_graph = graph_util.convert_variables_to_constants(
            sess, sess.graph.as_graph_def(), pred_node_names)
        frozen = graph_util.remove_training_nodes(constant_graph)
        output = "cnn.pb"
        graph_io.write_graph(frozen, ".", output, as_text=False)


export_cnn()

tf.reset_default_graph()

model_filename = "cnn.pb"
with gfile.GFile(model_filename, "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

config = tfe.LocalConfig([
    "server0",
    "server1",
    "crypto-producer",
    "prediction-client",
    "weights-provider",
])


def provide_input() -> tf.Tensor:
    return tf.constant(np.random.normal(size=(1, 1, 28, 28)), tf.float32)


def receive_output(tensor: tf.Tensor) -> tf.Tensor:
    # tf.Print returns its first argument unchanged while logging the data
    # list, which matches the declared tf.Tensor return type.
    return tf.Print(tensor, [tensor])
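# A minimal sketch (assumed names) of how a frozen GraphDef like cnn.pb can be
# imported back and run. The tensor names "input:0" and "output:0" are
# placeholders for illustration, not names taken from the graph above.
with tf.Graph().as_default() as g:
    tf.import_graph_def(graph_def, name="")
    with tf.Session(graph=g) as sess:
        x = g.get_tensor_by_name("input:0")
        y = g.get_tensor_by_name("output:0")
        print(sess.run(y, feed_dict={x: np.zeros((1, 1, 28, 28), np.float32)}))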
def main(_):
  # We want to see all the logging messages for this tutorial.
  tf.logging.set_verbosity(tf.logging.INFO)

  # Start a new TensorFlow session.
  sess = tf.InteractiveSession()

  model_settings = models.prepare_model_settings(
      len(input_data.prepare_words_list(FLAGS.wanted_words.split(','))),
      FLAGS.sample_rate, FLAGS.clip_duration_ms, FLAGS.window_size_ms,
      FLAGS.window_stride_ms, FLAGS.dct_coefficient_count)
  audio_processor = input_data.AudioProcessor(
      FLAGS.data_url, FLAGS.data_dir, FLAGS.silence_percentage,
      FLAGS.unknown_percentage, FLAGS.wanted_words.split(','),
      FLAGS.validation_percentage, FLAGS.testing_percentage, model_settings)
  fingerprint_size = model_settings['fingerprint_size']
  label_count = model_settings['label_count']
  time_shift_samples = int((FLAGS.time_shift_ms * FLAGS.sample_rate) / 1000)
  training_steps = FLAGS.how_many_training_steps
  learning_rate = FLAGS.learning_rate

  # -----------------------------------------------------------------------
  # -----------------------------Placeholder-------------------------------
  # -----------------------------------------------------------------------
  fingerprint_input = tf.placeholder(
      tf.float32, [None, fingerprint_size], name='fingerprint_input')
  logits, dropout_prob = models.create_model(
      fingerprint_input,
      model_settings,
      FLAGS.model_architecture,
      is_training=True)
  # Define loss and optimizer.
  ground_truth_input = tf.placeholder(
      tf.int64, [None], name='groundtruth_input')

  # Optionally we can add runtime checks to spot when NaNs or other symptoms of
  # numerical errors start occurring during training.
  control_dependencies = []
  if FLAGS.check_nans:
    checks = tf.add_check_numerics_ops()
    control_dependencies = [checks]

  # -----------------------------------------------------------------------
  # -----------------Back propagation and training evaluation--------------
  # -----------------------------------------------------------------------
  # Create the back propagation and training evaluation machinery in the graph.
  with tf.name_scope('cross_entropy'):
    cross_entropy_mean = tf.losses.sparse_softmax_cross_entropy(
        labels=ground_truth_input, logits=logits)
  tf.summary.scalar('cross_entropy', cross_entropy_mean)
  with tf.name_scope('train'), tf.control_dependencies(control_dependencies):
    train_step = tf.train.AdamOptimizer(learning_rate).minimize(
        cross_entropy_mean)
  predicted_indices = tf.argmax(logits, 1)
  correct_prediction = tf.equal(predicted_indices, ground_truth_input)
  confusion_matrix = tf.confusion_matrix(
      ground_truth_input, predicted_indices, num_classes=label_count)
  evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  tf.summary.scalar('accuracy', evaluation_step)

  global_step = tf.train.get_or_create_global_step()
  increment_global_step = tf.assign(global_step, global_step + 1)

  saver = tf.train.Saver(tf.global_variables())

  # Merge all the summaries and write them out to /tmp/retrain_logs (by default)
  merged_summaries = tf.summary.merge_all()
  train_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/train',
                                       sess.graph)
  validation_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/validation')

  tf.global_variables_initializer().run()

  start_step = 1
  if FLAGS.start_checkpoint:
    models.load_variables_from_checkpoint(sess, FLAGS.start_checkpoint)
    start_step = global_step.eval(session=sess)
  tf.logging.info('Training from step: %d ', start_step)

  # Save graph.pbtxt.
  tf.train.write_graph(sess.graph_def, FLAGS.train_dir,
                       FLAGS.model_architecture + '.pbtxt')

  # Save list of words.
  with gfile.GFile(
      os.path.join(FLAGS.train_dir, FLAGS.model_architecture + '_labels.txt'),
      'w') as f:
    f.write('\n'.join(audio_processor.words_list))

  # -----------------------------------------------------------------------
  # -----------------Training and validation-------------------------------
  # -----------------------------------------------------------------------
  # Training loop.
  training_steps_max = training_steps
  # Record and print the local time at which training begins.
  beg_time = datetime.datetime.now()
  print("Beginning time : " + str(beg_time))
  for training_step in xrange(start_step, training_steps_max + 1):
    # Pull the audio samples we'll use for training.
    train_fingerprints, train_ground_truth = audio_processor.get_data(
        FLAGS.batch_size, 0, model_settings, FLAGS.background_frequency,
        FLAGS.background_volume, time_shift_samples, 'training', sess)
    # Run the graph with this batch of training data.
    train_summary, train_accuracy, cross_entropy_value, _, _ = sess.run(
        [
            merged_summaries, evaluation_step, cross_entropy_mean, train_step,
            increment_global_step
        ],
        feed_dict={
            fingerprint_input: train_fingerprints,
            ground_truth_input: train_ground_truth,
            dropout_prob: 0.5
        })
    train_writer.add_summary(train_summary, training_step)
    tf.logging.info('Step #%d: rate %f, accuracy %.1f%%, cross entropy %f' %
                    (training_step, learning_rate, train_accuracy * 100,
                     cross_entropy_value))
    is_last_step = (training_step == training_steps_max)
    # Validation
    if (training_step % FLAGS.eval_step_interval) == 0 or is_last_step:
      set_size = audio_processor.set_size('validation')
      total_accuracy = 0
      total_conf_matrix = None
      for i in xrange(0, set_size, FLAGS.batch_size):
        validation_fingerprints, validation_ground_truth = (
            audio_processor.get_data(FLAGS.batch_size, i, model_settings, 0.0,
                                     0.0, 0, 'validation', sess))
        # Run a validation step and capture training summaries for TensorBoard
        # with the `merged` op.
        validation_summary, validation_accuracy, conf_matrix = sess.run(
            [merged_summaries, evaluation_step, confusion_matrix],
            feed_dict={
                fingerprint_input: validation_fingerprints,
                ground_truth_input: validation_ground_truth,
                dropout_prob: 1.0
            })
        validation_writer.add_summary(validation_summary, training_step)
        batch_size = min(FLAGS.batch_size, set_size - i)
        total_accuracy += (validation_accuracy * batch_size) / set_size
        if total_conf_matrix is None:
          total_conf_matrix = conf_matrix
        else:
          total_conf_matrix += conf_matrix
      tf.logging.info('Confusion Matrix:\n %s' % (total_conf_matrix))
      tf.logging.info('Step %d: Validation accuracy = %.1f%% (N=%d)' %
                      (training_step, total_accuracy * 100, set_size))

    # Save the model checkpoint periodically.
    if (training_step % FLAGS.save_step_interval == 0 or
        training_step == training_steps_max):
      checkpoint_path = os.path.join(FLAGS.train_dir,
                                     FLAGS.model_architecture + '.ckpt')
      tf.logging.info('Saving to "%s-%d"', checkpoint_path, training_step)
      saver.save(sess, checkpoint_path, global_step=training_step)

  # Print the local times at which training began and ended.
  print("Beginning time : " + str(beg_time))
  print("Ending time : " + str(datetime.datetime.now()))

  # -----------------------------------------------------------------------
  # ------------------------------Test-------------------------------------
  # -----------------------------------------------------------------------
  set_size = audio_processor.set_size('testing')
  tf.logging.info('set_size=%d', set_size)
  total_accuracy = 0
  total_conf_matrix = None
  for i in xrange(0, set_size, FLAGS.batch_size):
    test_fingerprints, test_ground_truth = audio_processor.get_data(
        FLAGS.batch_size, i, model_settings, 0.0, 0.0, 0, 'testing', sess)
    test_accuracy, conf_matrix = sess.run(
        [evaluation_step, confusion_matrix],
        feed_dict={
            fingerprint_input: test_fingerprints,
            ground_truth_input: test_ground_truth,
            dropout_prob: 1.0
        })
    batch_size = min(FLAGS.batch_size, set_size - i)
    total_accuracy += (test_accuracy * batch_size) / set_size
    if total_conf_matrix is None:
      total_conf_matrix = conf_matrix
    else:
      total_conf_matrix += conf_matrix
  tf.logging.info('Confusion Matrix:\n %s' % (total_conf_matrix))
  tf.logging.info('Final test accuracy = %.1f%% (N=%d)' %
                  (total_accuracy * 100, set_size))
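# A small standalone check (made-up numbers) of the batch-weighted accuracy
# accumulation used in the validation and test loops above: weighting each
# batch's accuracy by batch_size / set_size reproduces the overall accuracy
# even when the final batch is short.
set_size_demo = 10
batch_results = [(4, 0.75), (4, 1.0), (2, 0.5)]  # (batch_size, accuracy)
total = sum(acc * bs / float(set_size_demo) for bs, acc in batch_results)
assert abs(total - 0.8) < 1e-9  # (3 + 4 + 1) correct out of 10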
def test_export_savedmodel(self):
  tmpdir = tempfile.mkdtemp()
  est, serving_input_fn = _build_estimator_for_export_tests(tmpdir)

  extra_file_name = os.path.join(
      compat.as_bytes(tmpdir), compat.as_bytes('my_extra_file'))
  extra_file = gfile.GFile(extra_file_name, mode='w')
  extra_file.write(EXTRA_FILE_CONTENT)
  extra_file.close()
  assets_extra = {'some/sub/directory/my_extra_file': extra_file_name}

  export_dir_base = os.path.join(
      compat.as_bytes(tmpdir), compat.as_bytes('export'))
  export_dir = est.export_savedmodel(
      export_dir_base, serving_input_fn, assets_extra=assets_extra)

  self.assertTrue(gfile.Exists(export_dir_base))
  self.assertTrue(gfile.Exists(export_dir))
  self.assertTrue(
      gfile.Exists(
          os.path.join(
              compat.as_bytes(export_dir), compat.as_bytes('saved_model.pb'))))
  self.assertTrue(
      gfile.Exists(
          os.path.join(
              compat.as_bytes(export_dir), compat.as_bytes('variables'))))
  self.assertTrue(
      gfile.Exists(
          os.path.join(
              compat.as_bytes(export_dir),
              compat.as_bytes('variables/variables.index'))))
  self.assertTrue(
      gfile.Exists(
          os.path.join(
              compat.as_bytes(export_dir),
              compat.as_bytes('variables/variables.data-00000-of-00001'))))

  self.assertTrue(
      gfile.Exists(
          os.path.join(
              compat.as_bytes(export_dir), compat.as_bytes('assets'))))
  self.assertTrue(
      gfile.Exists(
          os.path.join(
              compat.as_bytes(export_dir),
              compat.as_bytes('assets/my_vocab_file'))))
  self.assertEqual(
      compat.as_bytes(VOCAB_FILE_CONTENT),
      compat.as_bytes(
          gfile.GFile(
              os.path.join(
                  compat.as_bytes(export_dir),
                  compat.as_bytes('assets/my_vocab_file'))).read()))

  expected_extra_path = os.path.join(
      compat.as_bytes(export_dir),
      compat.as_bytes('assets.extra/some/sub/directory/my_extra_file'))
  self.assertTrue(
      gfile.Exists(
          os.path.join(
              compat.as_bytes(export_dir), compat.as_bytes('assets.extra'))))
  self.assertTrue(gfile.Exists(expected_extra_path))
  self.assertEqual(
      compat.as_bytes(EXTRA_FILE_CONTENT),
      compat.as_bytes(gfile.GFile(expected_extra_path).read()))

  expected_vocab_file = os.path.join(
      compat.as_bytes(tmpdir), compat.as_bytes('my_vocab_file'))
  # Restore, to validate that the export was well-formed.
  with ops.Graph().as_default() as graph:
    with session_lib.Session(graph=graph) as sess:
      loader.load(sess, [tag_constants.SERVING], export_dir)
      assets = [
          x.eval()
          for x in graph.get_collection(ops.GraphKeys.ASSET_FILEPATHS)
      ]
      self.assertItemsEqual([expected_vocab_file], assets)
      graph_ops = [x.name for x in graph.get_operations()]
      self.assertTrue('input_example_tensor' in graph_ops)
      self.assertTrue('ParseExample/ParseExample' in graph_ops)
      self.assertTrue('linear/linear/feature/matmul' in graph_ops)

  # cleanup
  gfile.DeleteRecursively(tmpdir)
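# For reference, the on-disk layout the test above asserts: the standard
# SavedModel structure, plus the assets.extra tree produced by the
# assets_extra mapping passed to export_savedmodel.
#
#   export_dir/
#     saved_model.pb
#     variables/
#       variables.index
#       variables.data-00000-of-00001
#     assets/
#       my_vocab_file
#     assets.extra/
#       some/sub/directory/my_extra_file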
def main(_):
  # We want to see all the logging messages for this tutorial.
  tf.logging.set_verbosity(tf.logging.INFO)

  # Start a new TensorFlow session.
  sess = tf.InteractiveSession()

  # Begin by making sure we have the training data we need. If you already have
  # training data of your own, use `--data_url= ` on the command line to avoid
  # downloading.
  model_settings = models.prepare_model_settings(
      len(input_data.prepare_words_list(FLAGS.wanted_words.split(','))),
      FLAGS.sample_rate, FLAGS.clip_duration_ms, FLAGS.window_size_ms,
      FLAGS.window_stride_ms, FLAGS.dct_coefficient_count)
  audio_processor = input_data.AudioProcessor(
      FLAGS.data_url, FLAGS.data_dir, FLAGS.silence_percentage,
      FLAGS.unknown_percentage, FLAGS.wanted_words.split(','),
      FLAGS.validation_percentage, FLAGS.testing_percentage, model_settings)
  fingerprint_size = model_settings['fingerprint_size']
  label_count = model_settings['label_count']
  time_shift_samples = int((FLAGS.time_shift_ms * FLAGS.sample_rate) / 1000)
  # Figure out the learning rates for each training phase. Since it's often
  # effective to have high learning rates at the start of training, followed by
  # lower levels towards the end, the number of steps and learning rates can be
  # specified as comma-separated lists to define the rate at each stage. For
  # example --how_many_training_steps=10000,3000 --learning_rate=0.001,0.0001
  # will run 13,000 training loops in total, with a rate of 0.001 for the first
  # 10,000, and 0.0001 for the final 3,000. (See the standalone check after
  # this function.)
  training_steps_list = list(map(int, FLAGS.how_many_training_steps.split(',')))
  learning_rates_list = list(map(float, FLAGS.learning_rate.split(',')))
  if len(training_steps_list) != len(learning_rates_list):
    raise Exception(
        '--how_many_training_steps and --learning_rate must be equal length '
        'lists, but are %d and %d long instead' % (len(training_steps_list),
                                                   len(learning_rates_list)))

  fingerprint_input = tf.placeholder(
      tf.float32, [None, fingerprint_size], name='fingerprint_input')

  logits, dropout_prob = models.create_model(
      fingerprint_input,
      model_settings,
      FLAGS.model_architecture,
      is_training=True)

  # Define loss and optimizer
  ground_truth_input = tf.placeholder(
      tf.int64, [None], name='groundtruth_input')

  # Optionally we can add runtime checks to spot when NaNs or other symptoms of
  # numerical errors start occurring during training.
  control_dependencies = []
  if FLAGS.check_nans:
    checks = tf.add_check_numerics_ops()
    control_dependencies = [checks]

  # Create the back propagation and training evaluation machinery in the graph.
  with tf.name_scope('cross_entropy'):
    cross_entropy_mean = tf.losses.sparse_softmax_cross_entropy(
        labels=ground_truth_input, logits=logits)
  tf.summary.scalar('cross_entropy', cross_entropy_mean)
  with tf.name_scope('train'), tf.control_dependencies(control_dependencies):
    learning_rate_input = tf.placeholder(
        tf.float32, [], name='learning_rate_input')
    train_step = tf.train.GradientDescentOptimizer(
        learning_rate_input).minimize(cross_entropy_mean)
  predicted_indices = tf.argmax(logits, 1)
  correct_prediction = tf.equal(predicted_indices, ground_truth_input)
  confusion_matrix = tf.confusion_matrix(
      ground_truth_input, predicted_indices, num_classes=label_count)
  evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  tf.summary.scalar('accuracy', evaluation_step)

  global_step = tf.train.get_or_create_global_step()
  increment_global_step = tf.assign(global_step, global_step + 1)

  saver = tf.train.Saver(tf.global_variables())

  # Merge all the summaries and write them out to /tmp/retrain_logs (by default)
  merged_summaries = tf.summary.merge_all()
  train_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/train',
                                       sess.graph)
  validation_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/validation')

  tf.global_variables_initializer().run()

  start_step = 1
  if FLAGS.start_checkpoint:
    models.load_variables_from_checkpoint(sess, FLAGS.start_checkpoint)
    start_step = global_step.eval(session=sess)
  tf.logging.info('Training from step: %d ', start_step)

  # Save graph.pbtxt.
  tf.train.write_graph(sess.graph_def, FLAGS.train_dir,
                       FLAGS.model_architecture + '.pbtxt')

  # Save list of words.
  with gfile.GFile(
      os.path.join(FLAGS.train_dir, FLAGS.model_architecture + '_labels.txt'),
      'w') as f:
    f.write('\n'.join(audio_processor.words_list))

  # Training loop.
  training_steps_max = np.sum(training_steps_list)
  for training_step in xrange(start_step, training_steps_max + 1):
    # Figure out what the current learning rate is.
    training_steps_sum = 0
    for i in range(len(training_steps_list)):
      training_steps_sum += training_steps_list[i]
      if training_step <= training_steps_sum:
        learning_rate_value = learning_rates_list[i]
        break
    # Pull the audio samples we'll use for training.
    train_fingerprints, train_ground_truth = audio_processor.get_data(
        FLAGS.batch_size, 0, model_settings, FLAGS.background_frequency,
        FLAGS.background_volume, time_shift_samples, 'training', sess)
    # Run the graph with this batch of training data.
    train_summary, train_accuracy, cross_entropy_value, _, _ = sess.run(
        [
            merged_summaries, evaluation_step, cross_entropy_mean, train_step,
            increment_global_step
        ],
        feed_dict={
            fingerprint_input: train_fingerprints,
            ground_truth_input: train_ground_truth,
            learning_rate_input: learning_rate_value,
            dropout_prob: 0.5
        })
    train_writer.add_summary(train_summary, training_step)
    tf.logging.info('Step #%d: rate %f, accuracy %.1f%%, cross entropy %f' %
                    (training_step, learning_rate_value, train_accuracy * 100,
                     cross_entropy_value))
    is_last_step = (training_step == training_steps_max)
    if (training_step % FLAGS.eval_step_interval) == 0 or is_last_step:
      set_size = audio_processor.set_size('validation')
      total_accuracy = 0
      total_conf_matrix = None
      for i in xrange(0, set_size, FLAGS.batch_size):
        validation_fingerprints, validation_ground_truth = (
            audio_processor.get_data(FLAGS.batch_size, i, model_settings, 0.0,
                                     0.0, 0, 'validation', sess))
        # Run a validation step and capture training summaries for TensorBoard
        # with the `merged` op.
        validation_summary, validation_accuracy, conf_matrix = sess.run(
            [merged_summaries, evaluation_step, confusion_matrix],
            feed_dict={
                fingerprint_input: validation_fingerprints,
                ground_truth_input: validation_ground_truth,
                dropout_prob: 1.0
            })
        validation_writer.add_summary(validation_summary, training_step)
        batch_size = min(FLAGS.batch_size, set_size - i)
        total_accuracy += (validation_accuracy * batch_size) / set_size
        if total_conf_matrix is None:
          total_conf_matrix = conf_matrix
        else:
          total_conf_matrix += conf_matrix
      tf.logging.info('Confusion Matrix:\n %s' % (total_conf_matrix))
      tf.logging.info('Step %d: Validation accuracy = %.1f%% (N=%d)' %
                      (training_step, total_accuracy * 100, set_size))

    # Save the model checkpoint periodically.
    if (training_step % FLAGS.save_step_interval == 0 or
        training_step == training_steps_max):
      checkpoint_path = os.path.join(FLAGS.train_dir,
                                     FLAGS.model_architecture + '.ckpt')
      tf.logging.info('Saving to "%s-%d"', checkpoint_path, training_step)
      saver.save(sess, checkpoint_path, global_step=training_step)

  set_size = audio_processor.set_size('testing')
  tf.logging.info('set_size=%d', set_size)
  total_accuracy = 0
  total_conf_matrix = None
  for i in xrange(0, set_size, FLAGS.batch_size):
    test_fingerprints, test_ground_truth = audio_processor.get_data(
        FLAGS.batch_size, i, model_settings, 0.0, 0.0, 0, 'testing', sess)
    test_accuracy, conf_matrix = sess.run(
        [evaluation_step, confusion_matrix],
        feed_dict={
            fingerprint_input: test_fingerprints,
            ground_truth_input: test_ground_truth,
            dropout_prob: 1.0
        })
    batch_size = min(FLAGS.batch_size, set_size - i)
    total_accuracy += (test_accuracy * batch_size) / set_size
    if total_conf_matrix is None:
      total_conf_matrix = conf_matrix
    else:
      total_conf_matrix += conf_matrix
  tf.logging.info('Confusion Matrix:\n %s' % (total_conf_matrix))
  tf.logging.info('Final test accuracy = %.1f%% (N=%d)' %
                  (total_accuracy * 100, set_size))
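# A tiny standalone check of the staged learning-rate lookup used in the
# training loop above, using the values from the comment in that function:
# with --how_many_training_steps=10000,3000 and --learning_rate=0.001,0.0001,
# steps 1..10000 train at 0.001 and steps 10001..13000 at 0.0001.
steps_per_phase = [10000, 3000]
rates_per_phase = [0.001, 0.0001]

def rate_for_step(step):
  steps_sum = 0
  for steps, rate in zip(steps_per_phase, rates_per_phase):
    steps_sum += steps
    if step <= steps_sum:
      return rate

assert rate_for_step(10000) == 0.001
assert rate_for_step(10001) == 0.0001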
def freeze_graph_with_def_protos(input_graph_def,
                                 input_saver_def,
                                 input_checkpoint,
                                 output_node_names,
                                 restore_op_name,
                                 filename_tensor_name,
                                 output_graph,
                                 clear_devices,
                                 initializer_nodes,
                                 variable_names_whitelist="",
                                 variable_names_blacklist="",
                                 input_meta_graph_def=None,
                                 input_saved_model_dir=None,
                                 saved_model_tags=None,
                                 checkpoint_version=saver_pb2.SaverDef.V2):
  """Converts all variables in a graph and checkpoint into constants."""
  del restore_op_name, filename_tensor_name  # Unused by updated loading code.

  # 'input_checkpoint' may be a prefix if we're using Saver V2 format
  if (not input_saved_model_dir and
      not saver_lib.checkpoint_exists(input_checkpoint)):
    print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
    return -1

  if not output_node_names:
    print("You need to supply the name of a node to --output_node_names.")
    return -1

  # Remove all the explicit device specifications for this node. This helps to
  # make the graph more portable.
  if clear_devices:
    if input_meta_graph_def:
      for node in input_meta_graph_def.graph_def.node:
        node.device = ""
    elif input_graph_def:
      for node in input_graph_def.node:
        node.device = ""

  if input_graph_def:
    _ = importer.import_graph_def(input_graph_def, name="")
  with session.Session() as sess:
    if input_saver_def:
      saver = saver_lib.Saver(
          saver_def=input_saver_def, write_version=checkpoint_version)
      saver.restore(sess, input_checkpoint)
    elif input_meta_graph_def:
      restorer = saver_lib.import_meta_graph(
          input_meta_graph_def, clear_devices=True)
      restorer.restore(sess, input_checkpoint)
      if initializer_nodes:
        sess.run(initializer_nodes.replace(" ", "").split(","))
    elif input_saved_model_dir:
      if saved_model_tags is None:
        saved_model_tags = []
      loader.load(sess, saved_model_tags, input_saved_model_dir)
    else:
      var_list = {}
      reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
      # ruiyingy: start
      # checkpointkey_to_node = {}
      # object_graph_string = reader.get_tensor(base.OBJECT_GRAPH_PROTO_KEY)
      # object_graph_proto = trackable_object_graph_pb2.TrackableObjectGraph()
      # object_graph_proto.ParseFromString(object_graph_string)
      # for node in object_graph_proto.nodes:
      #   for attr in node.attributes:
      #     checkpointkey_to_node[attr.checkpoint_key] = attr.full_name
      # ruiyingy: end
      var_to_shape_map = reader.get_variable_to_shape_map()
      for key in var_to_shape_map:
        try:
          # ruiyingy
          # key = checkpointkey_to_node[key]
          tensor = sess.graph.get_tensor_by_name(key + ":0")
        except KeyError:
          # This tensor doesn't exist in the graph (for example it's
          # 'global_step' or a similar housekeeping element) so skip it.
          continue
        var_list[key] = tensor
      saver = saver_lib.Saver(
          var_list=var_list, write_version=checkpoint_version)
      saver.restore(sess, input_checkpoint)
      if initializer_nodes:
        sess.run(initializer_nodes.replace(" ", "").split(","))

    variable_names_whitelist = (
        variable_names_whitelist.replace(" ", "").split(",")
        if variable_names_whitelist else None)
    variable_names_blacklist = (
        variable_names_blacklist.replace(" ", "").split(",")
        if variable_names_blacklist else None)

    if input_meta_graph_def:
      output_graph_def = graph_util.convert_variables_to_constants(
          sess,
          input_meta_graph_def.graph_def,
          output_node_names.replace(" ", "").split(","),
          variable_names_whitelist=variable_names_whitelist,
          variable_names_blacklist=variable_names_blacklist)
    else:
      output_graph_def = graph_util.convert_variables_to_constants(
          sess,
          input_graph_def,
          output_node_names.replace(" ", "").split(","),
          variable_names_whitelist=variable_names_whitelist,
          variable_names_blacklist=variable_names_blacklist)

  # Write GraphDef to file if output path has been given.
  if output_graph:
    with gfile.GFile(output_graph, "wb") as f:
      f.write(output_graph_def.SerializeToString())
    # with gfile.GFile(output_graph, "w") as f:
    #   f.write(text_format.MessageToString(output_graph_def))
    # tensorflow.train.write_graph(output_graph, "./", name, as_text=True)

  return output_graph_def
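# A minimal usage sketch for the function above: freeze a graph from a
# GraphDef plus a V2 checkpoint prefix. The checkpoint prefix
# "train_dir/model.ckpt-1000", the output node name "logits", and
# `my_graph_def` are hypothetical values for illustration, not taken from
# this document.
frozen_graph_def = freeze_graph_with_def_protos(
    input_graph_def=my_graph_def,          # a tf.GraphDef built elsewhere
    input_saver_def=None,
    input_checkpoint="train_dir/model.ckpt-1000",
    output_node_names="logits",
    restore_op_name=None,                  # unused by the updated loading code
    filename_tensor_name=None,             # unused by the updated loading code
    output_graph="frozen.pb",
    clear_devices=True,
    initializer_nodes="")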
def create_vocabulary(vocabulary_path,
                      data_path,
                      max_vocabulary_size,
                      tokenizer=None,
                      normalize_digits=True,
                      _DIGIT_RE=re.compile(br"\d"),
                      _START_VOCAB=None):
    r"""Create vocabulary file (if it does not exist yet) from data file.

    Data file is assumed to contain one sentence per line. Each sentence is
    tokenized and digits are normalized (if normalize_digits is set).
    Vocabulary contains the most-frequent tokens up to max_vocabulary_size.
    We write it to vocabulary_path in a one-token-per-line format, so that
    the token in the first line gets id=0, the token in the second line gets
    id=1, and so on.

    Parameters
    -----------
    vocabulary_path : str
        Path where the vocabulary will be created.
    data_path : str
        Data file that will be used to create vocabulary.
    max_vocabulary_size : int
        Limit on the size of the created vocabulary.
    tokenizer : function
        A function to use to tokenize each data sentence. If None,
        basic_tokenizer will be used.
    normalize_digits : boolean
        If true, all digits are replaced by `0`.
    _DIGIT_RE : regular expression
        Default is ``re.compile(br"\d")``.
    _START_VOCAB : list of str
        The pad, go, eos and unk tokens, default is
        ``[b"_PAD", b"_GO", b"_EOS", b"_UNK"]``.

    References
    ----------
    - Code from ``/tensorflow/models/rnn/translation/data_utils.py``

    """
    if _START_VOCAB is None:
        _START_VOCAB = [b"_PAD", b"_GO", b"_EOS", b"_UNK"]
    if not gfile.Exists(vocabulary_path):
        tl.logging.info("Creating vocabulary %s from data %s" %
                        (vocabulary_path, data_path))
        vocab = {}
        with gfile.GFile(data_path, mode="rb") as f:
            counter = 0
            for line in f:
                counter += 1
                if counter % 100000 == 0:
                    tl.logging.info("  processing line %d" % counter)
                tokens = tokenizer(line) if tokenizer else basic_tokenizer(line)
                for w in tokens:
                    word = re.sub(_DIGIT_RE, b"0", w) if normalize_digits else w
                    if word in vocab:
                        vocab[word] += 1
                    else:
                        vocab[word] = 1
            vocab_list = _START_VOCAB + sorted(
                vocab, key=vocab.get, reverse=True)
            if len(vocab_list) > max_vocabulary_size:
                vocab_list = vocab_list[:max_vocabulary_size]
            with gfile.GFile(vocabulary_path, mode="wb") as vocab_file:
                for w in vocab_list:
                    vocab_file.write(w + b"\n")
    else:
        tl.logging.info("Vocabulary %s from data %s exists" %
                        (vocabulary_path, data_path))
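# A short usage sketch (hypothetical paths) for create_vocabulary: build a
# 40000-token vocabulary from a training corpus. Because the function is a
# no-op when vocabulary_path already exists, it is safe to call on every run.
create_vocabulary(
    vocabulary_path="data/vocab40000.en",
    data_path="data/train.en",
    max_vocabulary_size=40000)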
def freeze_graph_with_def_protos(input_graph_def,
                                 input_saver_def,
                                 input_checkpoint,
                                 output_node_names,
                                 restore_op_name,
                                 filename_tensor_name,
                                 output_graph,
                                 clear_devices,
                                 initializer_nodes,
                                 variable_names_whitelist="",
                                 variable_names_blacklist="",
                                 input_meta_graph_def=None,
                                 input_saved_model_dir=None,
                                 saved_model_tags=None,
                                 checkpoint_version=saver_pb2.SaverDef.V2):
  """Converts all variables in a graph and checkpoint into constants.

  Args:
    input_graph_def: A `GraphDef`.
    input_saver_def: A `SaverDef` (optional).
    input_checkpoint: The prefix of a V1 or V2 checkpoint, with V2 taking
      priority. Typically the result of `Saver.save()` or that of
      `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or
      V1/V2.
    output_node_names: The name(s) of the output nodes, comma separated.
    restore_op_name: Unused.
    filename_tensor_name: Unused.
    output_graph: String where to write the frozen `GraphDef`.
    clear_devices: A Bool whether to remove device specifications.
    initializer_nodes: Comma separated string of initializer nodes to run
      before freezing.
    variable_names_whitelist: The set of variable names to convert (optional,
      by default, all variables are converted).
    variable_names_blacklist: The set of variable names to omit converting to
      constants (optional).
    input_meta_graph_def: A `MetaGraphDef` (optional).
    input_saved_model_dir: Path to the dir with TensorFlow 'SavedModel' file
      and variables (optional).
    saved_model_tags: Group of comma separated tag(s) of the MetaGraphDef to
      load, in string format (optional).
    checkpoint_version: Tensorflow variable file format (saver_pb2.SaverDef.V1
      or saver_pb2.SaverDef.V2).

  Returns:
    The frozen `GraphDef`, or a status code (-1 on invalid input, 0 if the
    graph was already frozen).
  """
  del restore_op_name, filename_tensor_name  # Unused by updated loading code.

  # 'input_checkpoint' may be a prefix if we're using Saver V2 format
  if (not input_saved_model_dir and
      not checkpoint_management.checkpoint_exists(input_checkpoint)):
    print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
    return -1

  if not output_node_names:
    print("You need to supply the name of a node to --output_node_names.")
    return -1

  # Remove all the explicit device specifications for this node. This helps to
  # make the graph more portable.
  if clear_devices:
    if input_meta_graph_def:
      for node in input_meta_graph_def.graph_def.node:
        node.device = ""
    elif input_graph_def:
      for node in input_graph_def.node:
        node.device = ""

  if input_graph_def:
    _ = importer.import_graph_def(input_graph_def, name="")
  with session.Session() as sess:
    if input_saver_def:
      saver = saver_lib.Saver(
          saver_def=input_saver_def, write_version=checkpoint_version)
      saver.restore(sess, input_checkpoint)
    elif input_meta_graph_def:
      restorer = saver_lib.import_meta_graph(
          input_meta_graph_def, clear_devices=True)
      restorer.restore(sess, input_checkpoint)
      if initializer_nodes:
        sess.run(initializer_nodes.replace(" ", "").split(","))
    elif input_saved_model_dir:
      if saved_model_tags is None:
        saved_model_tags = []
      loader.load(sess, saved_model_tags, input_saved_model_dir)
    else:
      var_list = {}
      reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
      var_to_shape_map = reader.get_variable_to_shape_map()

      # List of all partition variables. Because the condition is heuristic
      # based, the list could include false positives.
      all_partition_variable_names = [
          tensor.name.split(":")[0]
          for op in sess.graph.get_operations()
          for tensor in op.values()
          if re.search(r"/part_\d+/", tensor.name)
      ]
      has_partition_var = False

      for key in var_to_shape_map:
        try:
          tensor = sess.graph.get_tensor_by_name(key + ":0")
          if any(key in name for name in all_partition_variable_names):
            has_partition_var = True
        except KeyError:
          # This tensor doesn't exist in the graph (for example it's
          # 'global_step' or a similar housekeeping element) so skip it.
          continue
        var_list[key] = tensor

      try:
        saver = saver_lib.Saver(
            var_list=var_list, write_version=checkpoint_version)
      except TypeError as e:
        # `var_list` is required to be a map of variable names to Variable
        # tensors. Partition variables are Identity tensors that cannot be
        # handled by Saver.
        if has_partition_var:
          print("Models containing partition variables cannot be converted "
                "from checkpoint files. Please pass in a SavedModel using "
                "the flag --input_saved_model_dir.")
          return -1
        # Models that have been frozen previously do not contain Variables.
        elif _has_no_variables(sess):
          print("No variables were found in this model. It is likely the "
                "model was frozen previously. You cannot freeze a graph "
                "twice.")
          return 0
        else:
          raise e

      saver.restore(sess, input_checkpoint)
      if initializer_nodes:
        sess.run(initializer_nodes.replace(" ", "").split(","))

    variable_names_whitelist = (
        variable_names_whitelist.replace(" ", "").split(",")
        if variable_names_whitelist else None)
    variable_names_blacklist = (
        variable_names_blacklist.replace(" ", "").split(",")
        if variable_names_blacklist else None)

    if input_meta_graph_def:
      output_graph_def = graph_util.convert_variables_to_constants(
          sess,
          input_meta_graph_def.graph_def,
          output_node_names.replace(" ", "").split(","),
          variable_names_whitelist=variable_names_whitelist,
          variable_names_blacklist=variable_names_blacklist)
    else:
      output_graph_def = graph_util.convert_variables_to_constants(
          sess,
          input_graph_def,
          output_node_names.replace(" ", "").split(","),
          variable_names_whitelist=variable_names_whitelist,
          variable_names_blacklist=variable_names_blacklist)

  # Write GraphDef to file if output path has been given.
  if output_graph:
    with gfile.GFile(output_graph, "wb") as f:
      f.write(output_graph_def.SerializeToString())

  return output_graph_def
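# The helper _has_no_variables referenced above is not included in this
# excerpt. A minimal sketch consistent with how it is called (returning True
# when the session's graph contains no variable ops) might look like this:
def _has_no_variables(sess):
  """Determines if the graph has any variables."""
  for op in sess.graph.get_operations():
    if op.type.startswith("Variable") or op.type.endswith("VariableOp"):
      return False
  return True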