Example #1
def write_vocabulary(vocab_file, words):
    with gfile.GFile(vocab_file, mode="wb") as vocab_file:
        for w in words:
            vocab_file.write(bytes(w, 'utf-8') + b"\n")
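A minimal usage sketch for write_vocabulary (the path and word list below are made up). Because the file is written in binary mode, gfile.GFile returns bytes lines on read, so they are decoded before comparison.

# Hypothetical round-trip check for write_vocabulary; vocab_path is an assumed
# scratch location and the word list is illustrative only.
from tensorflow.python.platform import gfile

words = ["hello", "world", "vocab"]
vocab_path = "/tmp/vocab.txt"
write_vocabulary(vocab_path, words)

with gfile.GFile(vocab_path, mode="rb") as f:
    read_back = [line.rstrip(b"\n").decode("utf-8") for line in f]
assert read_back == words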
def freeze_graph_with_def_protos(input_graph_def,
                                 input_saver_def,
                                 input_checkpoint,
                                 output_node_names,
                                 restore_op_name,
                                 filename_tensor_name,
                                 output_graph,
                                 clear_devices,
                                 initializer_nodes,
                                 variable_names_whitelist="",
                                 variable_names_blacklist="",
                                 input_meta_graph_def=None,
                                 input_saved_model_dir=None,
                                 saved_model_tags=None,
                                 checkpoint_version=saver_pb2.SaverDef.V2):
  """Converts all variables in a graph and checkpoint into constants."""
  del restore_op_name, filename_tensor_name  # Unused by updated loading code.

  # 'input_checkpoint' may be a prefix if we're using Saver V2 format
  if (not input_saved_model_dir and
      not saver_lib.checkpoint_exists(input_checkpoint)):
    print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
    return -1

  if not output_node_names:
    print("You need to supply the name of a node to --output_node_names.")
    return -1

  # Remove all the explicit device specifications for this node. This helps to
  # make the graph more portable.
  if clear_devices:
    if input_meta_graph_def:
      for node in input_meta_graph_def.graph_def.node:
        node.device = ""
    elif input_graph_def:
      for node in input_graph_def.node:
        node.device = ""

  if input_graph_def:
    _ = importer.import_graph_def(input_graph_def, name="")
  with session.Session() as sess:
    if input_saver_def:
      saver = saver_lib.Saver(
          saver_def=input_saver_def, write_version=checkpoint_version)
      saver.restore(sess, input_checkpoint)
    elif input_meta_graph_def:
      restorer = saver_lib.import_meta_graph(
          input_meta_graph_def, clear_devices=True)
      restorer.restore(sess, input_checkpoint)
      if initializer_nodes:
        sess.run(initializer_nodes.replace(" ", "").split(","))
    elif input_saved_model_dir:
      if saved_model_tags is None:
        saved_model_tags = []
      loader.load(sess, saved_model_tags, input_saved_model_dir)
    else:
      var_list = {}
      reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
      var_to_shape_map = reader.get_variable_to_shape_map()

      # List of all partition variables. Because the condition is heuristic
      # based, the list could include false positives.
      all_partition_variable_names = [
          tensor.name.split(":")[0]
          for op in sess.graph.get_operations()
          for tensor in op.values()
          if re.search(r"/part_\d+/", tensor.name)
      ]
      has_partition_var = False

      for key in var_to_shape_map:
        try:
          tensor = sess.graph.get_tensor_by_name(key + ":0")
          if any(key in name for name in all_partition_variable_names):
            has_partition_var = True
        except KeyError:
          # This tensor doesn't exist in the graph (for example it's
          # 'global_step' or a similar housekeeping element) so skip it.
          continue
        var_list[key] = tensor

      try:
        saver = saver_lib.Saver(
            var_list=var_list, write_version=checkpoint_version)
      except TypeError as e:
        # `var_list` is required to be a map of variable names to Variable
        # tensors. Partition variables are Identity tensors that cannot be
        # handled by Saver.
        if has_partition_var:
          print("Models containing partition variables cannot be converted "
                "from checkpoint files. Please pass in a SavedModel using "
                "the flag --input_saved_model_dir.")
          return -1
        else:
          raise e

      saver.restore(sess, input_checkpoint)
      if initializer_nodes:
        sess.run(initializer_nodes.replace(" ", "").split(","))

    variable_names_whitelist = (
        variable_names_whitelist.replace(" ", "").split(",")
        if variable_names_whitelist else None)
    variable_names_blacklist = (
        variable_names_blacklist.replace(" ", "").split(",")
        if variable_names_blacklist else None)

    if input_meta_graph_def:
      output_graph_def = graph_util.convert_variables_to_constants(
          sess,
          input_meta_graph_def.graph_def,
          output_node_names.replace(" ", "").split(","),
          variable_names_whitelist=variable_names_whitelist,
          variable_names_blacklist=variable_names_blacklist)
    else:
      output_graph_def = graph_util.convert_variables_to_constants(
          sess,
          input_graph_def,
          output_node_names.replace(" ", "").split(","),
          variable_names_whitelist=variable_names_whitelist,
          variable_names_blacklist=variable_names_blacklist)

  # Write GraphDef to file if output path has been given.
  if output_graph:
    with gfile.GFile(output_graph, "wb") as f:
      f.write(output_graph_def.SerializeToString())

  return output_graph_def
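As a companion sketch (not part of the snippet above), the frozen GraphDef written to output_graph can be read back the same way; load_frozen_graph_def is a hypothetical helper name.

# Sketch: parse a frozen GraphDef previously written with GFile(path, "wb").
from tensorflow.core.framework import graph_pb2
from tensorflow.python.platform import gfile

def load_frozen_graph_def(frozen_graph_path):
    graph_def = graph_pb2.GraphDef()
    with gfile.GFile(frozen_graph_path, "rb") as f:
        graph_def.ParseFromString(f.read())
    return graph_def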
Example #3
 def read_benchmark_entry(f):
   s = gfile.GFile(f, "rb").read()
   entries = test_log_pb2.BenchmarkEntries.FromString(s)
   self.assertEquals(1, len(entries.entry))
   return entries.entry[0]
Example #4
  def doBasicsOneExportPath(self,
                            export_path,
                            clear_devices=False,
                            global_step=GLOBAL_STEP,
                            sharded=True,
                            export_count=1):
    # Build a graph with 2 parameter nodes on different devices.
    ops.reset_default_graph()
    with session.Session(
        target="",
        config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
      # v2 is an unsaved variable derived from v0 and v1.  It is used to
      # exercise the ability to run an init op when restoring a graph.
      with sess.graph.device("/cpu:0"):
        v0 = variables.VariableV1(10, name="v0")
      with sess.graph.device("/cpu:1"):
        v1 = variables.VariableV1(20, name="v1")
      v2 = variables.VariableV1(1, name="v2", trainable=False, collections=[])
      assign_v2 = state_ops.assign(v2, math_ops.add(v0, v1))
      init_op = control_flow_ops.group(assign_v2, name="init_op")

      ops.add_to_collection("v", v0)
      ops.add_to_collection("v", v1)
      ops.add_to_collection("v", v2)

      named_tensor_bindings = {"logical_input_A": v0, "logical_input_B": v1}
      signatures = {
          "foo":
              exporter.regression_signature(
                  input_tensor=v0, output_tensor=v1),
          "generic":
              exporter.generic_signature(named_tensor_bindings)
      }

      asset_filepath_orig = os.path.join(test.get_temp_dir(), "hello42.txt")
      asset_file = constant_op.constant(asset_filepath_orig, name="filename42")
      ops.add_to_collection(ops.GraphKeys.ASSET_FILEPATHS, asset_file)

      with gfile.GFile(asset_filepath_orig, "w") as f:
        f.write("your data here")
      assets_collection = ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS)

      ignored_asset = os.path.join(test.get_temp_dir(), "ignored.txt")
      with gfile.GFile(ignored_asset, "w") as f:
        f.write("additional data here")

      variables.global_variables_initializer().run()

      # Run an export.
      save = saver.Saver(
          {
              "v0": v0,
              "v1": v1
          },
          restore_sequentially=True,
          sharded=sharded,
          write_version=saver_pb2.SaverDef.V1)
      export = exporter.Exporter(save)
      compare_def = ops.get_default_graph().as_graph_def()
      export.init(
          compare_def,
          init_op=init_op,
          clear_devices=clear_devices,
          default_graph_signature=exporter.classification_signature(
              input_tensor=v0),
          named_graph_signatures=signatures,
          assets_collection=assets_collection)

      for x in range(export_count):
        export.export(
            export_path,
            constant_op.constant(global_step + x),
            sess,
            exports_to_keep=gc.largest_export_versions(2))
      # Set global_step to the last exported version, as the rest of the test
      # uses it to construct model export path, loads model from it, and does
      # verifications. We want to make sure to always use the last exported
      # version, as old ones may have been garbage-collected.
      global_step += export_count - 1

    # Restore graph.
    ops.reset_default_graph()
    with session.Session(
        target="",
        config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
      save = saver.import_meta_graph(
          os.path.join(export_path, constants.VERSION_FORMAT_SPECIFIER %
                       global_step, constants.META_GRAPH_DEF_FILENAME))
      self.assertIsNotNone(save)
      meta_graph_def = save.export_meta_graph()
      collection_def = meta_graph_def.collection_def

      # Validate custom graph_def.
      graph_def_any = collection_def[constants.GRAPH_KEY].any_list.value
      self.assertEquals(len(graph_def_any), 1)
      graph_def = graph_pb2.GraphDef()
      graph_def_any[0].Unpack(graph_def)
      if clear_devices:
        for node in compare_def.node:
          node.device = ""
      self.assertProtoEquals(compare_def, graph_def)

      # Validate init_op.
      init_ops = collection_def[constants.INIT_OP_KEY].node_list.value
      self.assertEquals(len(init_ops), 1)
      self.assertEquals(init_ops[0], "init_op")

      # Validate signatures.
      signatures_any = collection_def[constants.SIGNATURES_KEY].any_list.value
      self.assertEquals(len(signatures_any), 1)
      signatures = manifest_pb2.Signatures()
      signatures_any[0].Unpack(signatures)
      default_signature = signatures.default_signature
      self.assertEqual(
          default_signature.classification_signature.input.tensor_name, "v0:0")
      bindings = signatures.named_signatures["generic"].generic_signature.map
      self.assertEquals(bindings["logical_input_A"].tensor_name, "v0:0")
      self.assertEquals(bindings["logical_input_B"].tensor_name, "v1:0")
      read_foo_signature = (
          signatures.named_signatures["foo"].regression_signature)
      self.assertEquals(read_foo_signature.input.tensor_name, "v0:0")
      self.assertEquals(read_foo_signature.output.tensor_name, "v1:0")

      # Validate the assets.
      assets_any = collection_def[constants.ASSETS_KEY].any_list.value
      self.assertEquals(len(assets_any), 1)
      asset = manifest_pb2.AssetFile()
      assets_any[0].Unpack(asset)
      assets_path = os.path.join(export_path,
                                 constants.VERSION_FORMAT_SPECIFIER %
                                 global_step, constants.ASSETS_DIRECTORY,
                                 "hello42.txt")
      asset_contents = gfile.GFile(assets_path).read()
      self.assertEqual(asset_contents, "your data here")
      self.assertEquals("hello42.txt", asset.filename)
      self.assertEquals("filename42:0", asset.tensor_binding.tensor_name)
      ignored_asset_path = os.path.join(export_path,
                                        constants.VERSION_FORMAT_SPECIFIER %
                                        global_step, constants.ASSETS_DIRECTORY,
                                        "ignored.txt")
      self.assertFalse(gfile.Exists(ignored_asset_path))

      # Validate graph restoration.
      if sharded:
        save.restore(sess,
                     os.path.join(export_path,
                                  constants.VERSION_FORMAT_SPECIFIER %
                                  global_step,
                                  constants.VARIABLES_FILENAME_PATTERN))
      else:
        save.restore(sess,
                     os.path.join(export_path,
                                  constants.VERSION_FORMAT_SPECIFIER %
                                  global_step, constants.VARIABLES_FILENAME))
      self.assertEqual(10, ops.get_collection("v")[0].eval())
      self.assertEqual(20, ops.get_collection("v")[1].eval())
      ops.get_collection(constants.INIT_OP_KEY)[0].run()
      self.assertEqual(30, ops.get_collection("v")[2].eval())
Example #5
    def train(self,
              learn_rate,
              dropout_rate,
              save_step,
              batch_size,
              eval_step,
              training_time,
              rate_step,
              display_step,
              train_data,
              Validation_data,
              init=False):
        self._save_step = save_step
        self._training_time = training_time
        assert isinstance(learn_rate, list), \
            "learn_rate must be a list, e.g. [.001, .0001]"

        self._ground_truth_input = tf.compat.v1.placeholder(
            tf.int64, [None], name='groundtruth_input')

        with tf.compat.v1.name_scope('cross_entropy'):
            self._cross_entropy_mean = tf.compat.v1.losses.sparse_softmax_cross_entropy(
                labels=self._ground_truth_input, logits=self._softmax_layer)

        learning_rate_input = tf.compat.v1.placeholder(
            tf.float32, [], name='learning_rate_input')

        train_step = tf.compat.v1.train.GradientDescentOptimizer(
            learning_rate_input).minimize(self._cross_entropy_mean)

        self._predicted = tf.argmax(input=self._softmax_layer, axis=1)
        correct_prediction = tf.equal(self._predicted,
                                      self._ground_truth_input)

        self._evaluation_step = tf.reduce_mean(
            input_tensor=tf.cast(correct_prediction, tf.float32))

        saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())

        if self._loaded is False and self._start_step == 0:
            self._global_step = tf.compat.v1.train.get_or_create_global_step()
            tf.compat.v1.global_variables_initializer().run()
            #self._loaded = True

        increment_global_step = tf.compat.v1.assign(self._global_step,
                                                    self._global_step + 1)
        if init is False:
            tf.io.write_graph(self._sess.graph_def, self._model_dir,
                              "model" + '.pbtxt')

            with gfile.GFile(
                    os.path.join(self._model_dir, "commands" + '_labels.txt'),
                    'wb') as f:
                f.write('\n'.join(self.commands))

        if training_time <= self._start_step and self._loaded is True:
            print(
                f"The loaded checkpoint has already been trained for {self._start_step} steps;\n"
                f"training would resume from step {self._start_step}. Increase training_time to train the model further."
            )

        if init is False:
            if tf.config.list_physical_devices('GPU'):
                strategy = tf.distribute.MirroredStrategy()
            else:  # use default strategy
                strategy = tf.distribute.get_strategy()

            with strategy.scope():
                history = {
                    "categorical_accuracy": [],
                    "loss": [],
                    "val_categorical_accuracy": [],
                    "val_loss": []
                }
                learning_rate = learn_rate[0]
                for training_step in xrange(self._start_step, training_time):
                    if training_step == int(rate_step):
                        learning_rate = learn_rate[1]

                    x_train, y_train = self.get_next_batch(
                        batch_size, train_data)
                    train_accuracy, cross_entropy_value, _, _ = self._sess.run(
                        [
                            self._evaluation_step,
                            self._cross_entropy_mean,
                            train_step,
                            increment_global_step,
                        ],
                        feed_dict={
                            self._fingerprint_input: x_train,
                            self._ground_truth_input: y_train,
                            learning_rate_input: learning_rate,
                            self._dropout_placeholder: dropout_rate
                        })

                    if training_step % int(display_step) == 0:
                        print(
                            'Step #%d: learning rate %f, accuracy %.1f%%, cross entropy %f'
                            % (training_step, learning_rate,
                               train_accuracy * 100, cross_entropy_value))
                        history["categorical_accuracy"].append(train_accuracy)
                        history["loss"].append(cross_entropy_value)

                    if training_step % int(eval_step) == 0:
                        x_val, y_val = self.get_next_batch(
                            batch_size * 4, Validation_data)
                        validation_accuracy, val_crossentropy_value = self._sess.run(
                            [self._evaluation_step, self._cross_entropy_mean],
                            feed_dict={
                                self._fingerprint_input: x_val,
                                self._ground_truth_input: y_val,
                                self._dropout_placeholder: 0.0
                            })

                        history["val_categorical_accuracy"].append(
                            validation_accuracy)
                        history["val_loss"].append(val_crossentropy_value)

                        print(
                            'Step %d: Validation accuracy = %.1f%% (Val Size=%d), Validation loss = %f'
                            % (training_step, validation_accuracy * 100,
                               batch_size * 4, val_crossentropy_value))

                    if (training_step % int(save_step)
                            == 0) or (training_step == training_time - 1):
                        path_to_save = os.path.join(
                            self._model_dir, "model_checkpoint" + '.ckpt')
                        if (training_step == training_time - 1):
                            training_step = training_time
                        saver.save(self._sess,
                                   path_to_save,
                                   global_step=training_step)
                        self._start_step = self._global_step.eval(
                            session=self._sess)

            return history
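A small companion sketch, assuming the same model directory used by train() above: the commands_labels.txt file it writes can be read back with gfile.GFile in text mode. The helper name is made up.

# Hypothetical reader mirroring the label-file write in train(); model_dir is assumed.
import os
from tensorflow.python.platform import gfile

def load_commands(model_dir):
    labels_path = os.path.join(model_dir, "commands_labels.txt")
    with gfile.GFile(labels_path, "r") as f:
        return [line.strip() for line in f if line.strip()]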
Example #6
                        "Save model after this many steps (default: 500)")
# Misc Parameters
tf.flags.DEFINE_boolean("allow_soft_placement", True,
                        "Allow device soft device placement")
tf.flags.DEFINE_boolean("log_device_placement", False,
                        "Log placement of ops on devices")

FLAGS = tf.flags.FLAGS
FLAGS._parse_flags()

if __name__ == "__main__":
    ids_path, vocab_path = data_utils.prepare_data(FLAGS.data_path,
                                                   FLAGS.vocab_size)

    dataset = []
    with gfile.GFile(ids_path, mode="r") as ids_file:
        ids = ids_file.readline()
        while ids:
            ids = list(map(int, ids.split()))  # list() so len() works on each entry below
            dataset.append(ids)
            ids = ids_file.readline()

    dataset = [one for one in dataset if len(one) <= FLAGS.max_seq_len]
    train_val_split_point = len(dataset) - int(
        len(dataset) * FLAGS.validation_ratio)
    train_set = dataset[0:train_val_split_point]
    validation_set = dataset[train_val_split_point:]

    model = seq2seq_network.Seq2SeqAutoEncoder(
        FLAGS.vocab_size, FLAGS.embedding_dim, FLAGS.num_units,
        FLAGS.num_layers, FLAGS.max_seq_len, FLAGS.max_gradient_norm,
Example #7
def freeze_graph_with_def_protos(input_graph_def,
                                 input_saver_def,
                                 input_checkpoint,
                                 output_node_names,
                                 restore_op_name,
                                 filename_tensor_name,
                                 output_graph,
                                 clear_devices,
                                 initializer_nodes,
                                 variable_names_whitelist="",
                                 variable_names_blacklist="",
                                 input_meta_graph_def=None,
                                 input_saved_model_dir=None,
                                 saved_model_tags=None):
    """Converts all variables in a graph and checkpoint into constants."""
    del restore_op_name, filename_tensor_name  # Unused by updated loading code.

    # 'input_checkpoint' may be a prefix if we're using Saver V2 format
    if (not input_saved_model_dir
            and not saver_lib.checkpoint_exists(input_checkpoint)):
        print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
        return -1

    if not output_node_names:
        print("You need to supply the name of a node to --output_node_names.")
        return -1

    # Remove all the explicit device specifications for this node. This helps to
    # make the graph more portable.
    if clear_devices:
        if input_meta_graph_def:
            for node in input_meta_graph_def.graph_def.node:
                node.device = ""
        elif input_graph_def:
            for node in input_graph_def.node:
                node.device = ""

    if input_graph_def:
        _ = importer.import_graph_def(input_graph_def, name="")
    with session.Session() as sess:
        if input_saver_def:
            saver = saver_lib.Saver(saver_def=input_saver_def)
            saver.restore(sess, input_checkpoint)
        elif input_meta_graph_def:
            restorer = saver_lib.import_meta_graph(input_meta_graph_def,
                                                   clear_devices=True)
            restorer.restore(sess, input_checkpoint)
            if initializer_nodes:
                sess.run(initializer_nodes.split(","))
        elif input_saved_model_dir:
            if saved_model_tags is None:
                saved_model_tags = []
            loader.load(sess, saved_model_tags, input_saved_model_dir)
        else:
            var_list = {}
            reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
            var_to_shape_map = reader.get_variable_to_shape_map()
            for key in var_to_shape_map:
                try:
                    tensor = sess.graph.get_tensor_by_name(key + ":0")
                except KeyError:
                    # This tensor doesn't exist in the graph (for example it's
                    # 'global_step' or a similar housekeeping element) so skip it.
                    continue
                var_list[key] = tensor
            saver = saver_lib.Saver(var_list=var_list)
            saver.restore(sess, input_checkpoint)
            if initializer_nodes:
                sess.run(initializer_nodes.split(","))

        variable_names_whitelist = (variable_names_whitelist.split(",")
                                    if variable_names_whitelist else None)
        variable_names_blacklist = (variable_names_blacklist.split(",")
                                    if variable_names_blacklist else None)

        if input_meta_graph_def:
            output_graph_def = graph_util.convert_variables_to_constants(
                sess,
                input_meta_graph_def.graph_def,
                output_node_names.split(","),
                variable_names_whitelist=variable_names_whitelist,
                variable_names_blacklist=variable_names_blacklist)
        else:
            output_graph_def = graph_util.convert_variables_to_constants(
                sess,
                input_graph_def,
                output_node_names.split(","),
                variable_names_whitelist=variable_names_whitelist,
                variable_names_blacklist=variable_names_blacklist)

    # Write GraphDef to file if output path has been given.
    if output_graph:
        with gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())

    return output_graph_def
Example #8
    def test_train_infer(self):
        """Tests training and inference scripts."""
        # Create dummy data
        sources_train, targets_train = test_utils.create_temp_parallel_data(
            sources=["a a a a", "b b b b", "c c c c", "笑 笑 笑 笑"],
            targets=["b b b b", "a a a a", "c c c c", "泣 泣 泣 泣"])
        sources_dev, targets_dev = test_utils.create_temp_parallel_data(
            sources=["a a", "b b", "c c c", "笑 笑 笑"],
            targets=["b b", "a a", "c c c", "泣 泣 泣"])
        vocab_source = test_utils.create_temporary_vocab_file(
            ["a", "b", "c", "笑"])
        vocab_target = test_utils.create_temporary_vocab_file(
            ["a", "b", "c", "泣"])

        _clear_flags()
        tf.reset_default_graph()
        train_script = imp.load_source("seq2seq.test.train_bin",
                                       os.path.join(BIN_FOLDER, "train.py"))

        # Set training flags
        tf.app.flags.FLAGS.output_dir = self.output_dir
        tf.app.flags.FLAGS.train_source = sources_train.name
        tf.app.flags.FLAGS.train_target = targets_train.name
        tf.app.flags.FLAGS.vocab_source = vocab_source.name
        tf.app.flags.FLAGS.vocab_target = vocab_target.name
        tf.app.flags.FLAGS.model = "AttentionSeq2Seq"
        tf.app.flags.FLAGS.batch_size = 2

        # We pass a few flags via a config file
        config_path = os.path.join(self.output_dir, "train_config.yml")
        with gfile.GFile(config_path, "w") as config_file:
            yaml.dump(
                {
                    "dev_source": sources_dev.name,
                    "dev_target": targets_dev.name,
                    "train_steps": 50,
                    "hparams": {
                        "embedding.dim": 64,
                        "attention.dim": 16,
                        "decoder.rnn_cell.cell_spec": {
                            "class": "GRUCell",
                            "num_units": 32
                        }
                    }
                }, config_file)

        tf.app.flags.FLAGS.config_path = config_path

        # Run training
        tf.logging.set_verbosity(tf.logging.INFO)
        train_script.main([])

        # Make sure a checkpoint was written
        expected_checkpoint = os.path.join(
            self.output_dir, "model.ckpt-50.data-00000-of-00001")
        self.assertTrue(os.path.exists(expected_checkpoint))

        # Reset flags and import inference script
        _clear_flags()
        tf.reset_default_graph()
        infer_script = imp.load_source("seq2seq.test.infer_bin",
                                       os.path.join(BIN_FOLDER, "infer.py"))

        # Set inference flags
        attention_dir = os.path.join(self.output_dir, "att")
        tf.app.flags.FLAGS.model_dir = self.output_dir
        tf.app.flags.FLAGS.source = sources_dev.name
        tf.app.flags.FLAGS.batch_size = 2
        tf.app.flags.FLAGS.checkpoint_path = os.path.join(
            self.output_dir, "model.ckpt-50")
        tf.app.flags.FLAGS.dump_attention_dir = attention_dir

        # Make sure inference runs successfully
        infer_script.main([])

        # Make sure attention scores and visualizations exist
        self.assertTrue(
            os.path.exists(os.path.join(attention_dir,
                                        "attention_scores.npz")))
        self.assertTrue(
            os.path.exists(os.path.join(attention_dir, "00002.png")))

        # Load attention scores and assert shape
        scores = np.load(os.path.join(attention_dir, "attention_scores.npz"))
        self.assertIn("arr_0", scores)
        self.assertEqual(scores["arr_0"].shape[1], 3)
        self.assertIn("arr_1", scores)
        self.assertEqual(scores["arr_1"].shape[1], 3)
        self.assertIn("arr_2", scores)
        self.assertEqual(scores["arr_2"].shape[1], 4)
        self.assertIn("arr_3", scores)
        self.assertEqual(scores["arr_3"].shape[1], 4)
Example #9
def save_int8_frezon_pb(q_model, path):
    from tensorflow.python.platform import gfile
    # Use a context manager so the file handle is flushed and closed.
    with gfile.GFile(path, 'wb') as f:
        f.write(q_model.as_graph_def().SerializeToString())
    print("Save to {}".format(path))
Example #10
def skip_gram_sample_with_text_vocab(input_tensor,
                                     vocab_freq_file,
                                     vocab_token_index=0,
                                     vocab_token_dtype=dtypes.string,
                                     vocab_freq_index=1,
                                     vocab_freq_dtype=dtypes.float64,
                                     vocab_delimiter=",",
                                     vocab_min_count=0,
                                     vocab_subsampling=None,
                                     corpus_size=None,
                                     min_skips=1,
                                     max_skips=5,
                                     start=0,
                                     limit=-1,
                                     emit_self_as_target=False,
                                     batch_size=None,
                                     batch_capacity=None,
                                     seed=None,
                                     name=None):
    """Skip-gram sampling with a text vocabulary file.

  Wrapper around `skip_gram_sample()` for use with a text vocabulary file. The
  vocabulary file is expected to be a plain-text file, with lines of
  `vocab_delimiter`-separated columns. The `vocab_token_index` column should
  contain the vocabulary term, while the `vocab_freq_index` column should
  contain the number of times that term occurs in the corpus. For example, with
  a text vocabulary file of:

    ```
    bonjour,fr,42
    hello,en,777
    hola,es,99
    ```

  You should set `vocab_delimiter=","`, `vocab_token_index=0`, and
  `vocab_freq_index=2`.

  See `skip_gram_sample()` documentation for more details about the skip-gram
  sampling process.

  Args:
    input_tensor: A rank-1 `Tensor` from which to generate skip-gram candidates.
    vocab_freq_file: `string` specifying full file path to the text vocab file.
    vocab_token_index: `int` specifying which column in the text vocab file
      contains the tokens.
    vocab_token_dtype: `DType` specifying the format of the tokens in the text
      vocab file.
    vocab_freq_index: `int` specifying which column in the text vocab file
      contains the frequency counts of the tokens.
    vocab_freq_dtype: `DType` specifying the format of the frequency counts in
      the text vocab file.
    vocab_delimiter: `string` specifying the delimiter used in the text vocab
      file.
    vocab_min_count: `int`, `float`, or scalar `Tensor` specifying
      minimum frequency threshold (from `vocab_freq_file`) for a token to be
      kept in `input_tensor`. This should correspond with `vocab_freq_dtype`.
    vocab_subsampling: (Optional) `float` specifying frequency proportion
      threshold for tokens from `input_tensor`. Tokens that occur more
      frequently will be randomly down-sampled. Reasonable starting values may
      be around 1e-3 or 1e-5. See Eq. 5 in http://arxiv.org/abs/1310.4546 for
      more details.
    corpus_size: (Optional) `int`, `float`, or scalar `Tensor` specifying the
      total number of tokens in the corpus (e.g., sum of all the frequency
      counts of `vocab_freq_file`). Used with `vocab_subsampling` for
      down-sampling frequently occurring tokens. If this is specified,
      `vocab_freq_file` and `vocab_subsampling` must also be specified.
      If `corpus_size` is needed but not supplied, then it will be calculated
      from `vocab_freq_file`. You might want to supply your own value if you
      have already eliminated infrequent tokens from your vocabulary files
      (where frequency < vocab_min_count) to save memory in the internal token
      lookup table. Otherwise, the unused tokens' variables will waste memory.
      The user-supplied `corpus_size` value must be greater than or equal to the
      sum of all the frequency counts of `vocab_freq_file`.
    min_skips: `int` or scalar `Tensor` specifying the minimum window size to
      randomly use for each token. Must be >= 0 and <= `max_skips`. If
      `min_skips` and `max_skips` are both 0, the only label outputted will be
      the token itself.
    max_skips: `int` or scalar `Tensor` specifying the maximum window size to
      randomly use for each token. Must be >= 0.
    start: `int` or scalar `Tensor` specifying the position in `input_tensor`
      from which to start generating skip-gram candidates.
    limit: `int` or scalar `Tensor` specifying the maximum number of elements in
      `input_tensor` to use in generating skip-gram candidates. -1 means to use
      the rest of the `Tensor` after `start`.
    emit_self_as_target: `bool` or scalar `Tensor` specifying whether to emit
      each token as a label for itself.
    batch_size: (Optional) `int` specifying batch size of returned `Tensors`.
    batch_capacity: (Optional) `int` specifying batch capacity for the queue
      used for batching returned `Tensors`. Only has an effect if
      `batch_size` > 0. Defaults to 100 * `batch_size` if not specified.
    seed: (Optional) `int` used to create a random seed for window size and
      subsampling. See
      [`set_random_seed`](../../g3doc/python/constant_op.md#set_random_seed)
      for behavior.
    name: (Optional) A `string` name or a name scope for the operations.

  Returns:
    A `tuple` containing (token, label) `Tensors`. Each output `Tensor` is of
    rank-1 and has the same type as `input_tensor`. The `Tensors` will be of
    length `batch_size`; if `batch_size` is not specified, they will be of
    random length, though they will be in sync with each other as long as they
    are evaluated together.

  Raises:
    ValueError: If `vocab_token_index` or `vocab_freq_index` is less than 0 or
      exceeds the number of columns in `vocab_freq_file`. If `vocab_token_index`
      and `vocab_freq_index` are both set to the same column. If any token in
      `vocab_freq_file` has a negative frequency.
  """

    if vocab_token_index < 0 or vocab_freq_index < 0:
        raise ValueError(
            "vocab_token_index={} and vocab_freq_index={} must both be >= 0.".
            format(vocab_token_index, vocab_freq_index))
    if vocab_token_index == vocab_freq_index:
        raise ValueError(
            "vocab_token_index and vocab_freq_index should be different, but are "
            "both {}.".format(vocab_token_index))

    # Iterates through the vocab file and calculates the number of vocab terms as
    # well as the total corpus size (by summing the frequency counts of all the
    # vocab terms).
    calculated_corpus_size = 0.0
    vocab_size = 0
    with gfile.GFile(vocab_freq_file, mode="r") as f:
        reader = csv.reader(f, delimiter=vocab_delimiter)
        for row in reader:
            if vocab_token_index >= len(row) or vocab_freq_index >= len(row):
                raise ValueError(
                    "Row in vocab file only has {} columns, so vocab_token_index={} or "
                    "vocab_freq_index={} is out of bounds. Row content: {}".
                    format(len(row), vocab_token_index, vocab_freq_index, row))
            vocab_size += 1
            freq = vocab_freq_dtype.as_numpy_dtype(row[vocab_freq_index])
            if freq < 0:
                raise ValueError(
                    "Row in vocab file has negative frequency of {}. Row content: {}"
                    .format(freq, row))
            # Note: tokens whose frequencies are below vocab_min_count will still
            # contribute to the total corpus size used for vocab subsampling.
            calculated_corpus_size += freq

    if not corpus_size:
        corpus_size = calculated_corpus_size
    elif calculated_corpus_size - corpus_size > 1e-6:
        raise ValueError(
            "`corpus_size`={} must be greater than or equal to the sum of all the "
            "frequency counts ({}) of `vocab_freq_file` ({}).".format(
                corpus_size, calculated_corpus_size, vocab_freq_file))

    vocab_freq_table = lookup.HashTable(
        lookup.TextFileInitializer(filename=vocab_freq_file,
                                   key_dtype=vocab_token_dtype,
                                   key_index=vocab_token_index,
                                   value_dtype=vocab_freq_dtype,
                                   value_index=vocab_freq_index,
                                   vocab_size=vocab_size,
                                   delimiter=vocab_delimiter),
        # For vocab terms not in vocab file, use a default value of -1.
        default_value=-1)

    return skip_gram_sample(
        input_tensor,
        min_skips=min_skips,
        max_skips=max_skips,
        start=start,
        limit=limit,
        emit_self_as_target=emit_self_as_target,
        vocab_freq_table=vocab_freq_table,
        vocab_min_count=vocab_min_count,
        vocab_subsampling=vocab_subsampling,
        # corpus_size is not used unless vocab_subsampling is specified.
        corpus_size=None if vocab_subsampling is None else corpus_size,
        batch_size=batch_size,
        batch_capacity=batch_capacity,
        seed=seed,
        name=name)
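A hedged call sketch that follows the docstring's three-column example; the vocab path, token tensor, and parameter values are illustrative, and the call assumes the graph-mode/queue setup that skip_gram_sample() expects.

# Illustrative only: write the docstring's example vocab file with gfile.GFile
# and call the wrapper with column 0 as the token and column 2 as the count.
from tensorflow.python.framework import constant_op
from tensorflow.python.platform import gfile

vocab_path = "/tmp/vocab.csv"  # assumed scratch location
with gfile.GFile(vocab_path, mode="w") as f:
    f.write("bonjour,fr,42\nhello,en,777\nhola,es,99\n")

input_tokens = constant_op.constant(["hello", "hola", "hello", "bonjour"])
tokens, labels = skip_gram_sample_with_text_vocab(
    input_tensor=input_tokens,
    vocab_freq_file=vocab_path,
    vocab_token_index=0,
    vocab_freq_index=2,
    vocab_min_count=1.0,
    min_skips=1,
    max_skips=2)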
Example #11
def _global_report_benchmark(name,
                             iters=None,
                             cpu_time=None,
                             wall_time=None,
                             throughput=None,
                             extras=None):
    """Method for recording a benchmark directly.

  Args:
    name: The BenchmarkEntry name.
    iters: (optional) How many iterations were run
    cpu_time: (optional) Total cpu time in seconds
    wall_time: (optional) Total wall time in seconds
    throughput: (optional) Throughput (in MB/s)
    extras: (optional) Dict mapping string keys to additional benchmark info.

  Raises:
    TypeError: if extras is not a dict.
    IOError: if the benchmark output file already exists.
  """
    if extras is not None:
        if not isinstance(extras, dict):
            raise TypeError("extras must be a dict")

    logging.info(
        "Benchmark [%s] iters: %d, wall_time: %g, cpu_time: %g,"
        "throughput: %g %s", name, iters if iters is not None else -1,
        wall_time if wall_time is not None else -1,
        cpu_time if cpu_time is not None else -1,
        throughput if throughput is not None else -1,
        str(extras) if extras else "")

    entries = test_log_pb2.BenchmarkEntries()
    entry = entries.entry.add()
    entry.name = name
    if iters is not None:
        entry.iters = iters
    if cpu_time is not None:
        entry.cpu_time = cpu_time
    if wall_time is not None:
        entry.wall_time = wall_time
    if throughput is not None:
        entry.throughput = throughput
    if extras is not None:
        for (k, v) in extras.items():
            if isinstance(v, numbers.Number):
                entry.extras[k].double_value = v
            else:
                entry.extras[k].string_value = str(v)

    test_env = os.environ.get(TEST_REPORTER_TEST_ENV, None)
    if test_env is None:
        # Reporting was not requested, just print the proto
        print(str(entries))
        return

    serialized_entry = entries.SerializeToString()

    mangled_name = name.replace("/", "__")
    output_path = "%s%s" % (test_env, mangled_name)
    if gfile.Exists(output_path):
        raise IOError("File already exists: %s" % output_path)
    with gfile.GFile(output_path, "wb") as out:
        out.write(serialized_entry)
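A companion sketch for consuming the file written above; it mirrors read_benchmark_entry from Example #3. The function name and the output_path argument are illustrative.

# Hypothetical reader for the serialized BenchmarkEntries written above;
# output_path is whatever "%s%s" % (test_env, mangled_name) produced.
from tensorflow.core.util import test_log_pb2
from tensorflow.python.platform import gfile

def read_benchmark_entries(output_path):
    serialized = gfile.GFile(output_path, "rb").read()
    return test_log_pb2.BenchmarkEntries.FromString(serialized)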
Example #12
def run_training(sess,
                 trainers,
                 annotator,
                 evaluator,
                 pretrain_steps,
                 train_steps,
                 train_corpus,
                 eval_corpus,
                 eval_gold,
                 batch_size,
                 summary_writer,
                 report_every,
                 saver,
                 checkpoint_filename,
                 checkpoint_stats=None):
    """Runs multi-task DRAGNN training on a single corpus.

  Arguments:
    sess: TF session to use.
    trainers: List of training ops to use.
    annotator: Annotation op.
    evaluator: Function taking two serialized corpora and returning a dict of
      scalar summaries representing evaluation metrics. The 'eval_metric'
      summary will be used for early stopping.
    pretrain_steps: List of the no. of pre-training steps for each train op.
    train_steps: List of the total no. of steps for each train op.
    train_corpus: Training corpus to use.
    eval_corpus: Holdout corpus for early stopping.
    eval_gold: Reference of eval_corpus for computing accuracy.
      eval_corpus and eval_gold are allowed to be the same if eval_corpus
      already contains gold annotation.
      Note that for segmentation, eval_corpus and eval_gold are always different,
      since eval_corpus contains sentences whose tokens are utf-8 characters while
      eval_gold's tokens are gold words.
    batch_size: How many examples to send to the train op at a time.
    summary_writer: TF SummaryWriter to use to write summaries.
    report_every: How often to compute summaries (in steps).
    saver: TF saver op to save variables.
    checkpoint_filename: File to save checkpoints to.
    checkpoint_stats: Stats of checkpoint.
  """
    random.seed(0x31337)

    if not checkpoint_stats:
        checkpoint_stats = [0] * (len(train_steps) + 1)
    tf.logging.info('Determining the training schedule...')
    target_for_step = []
    for target_idx in xrange(len(pretrain_steps)):
        target_for_step += [target_idx] * pretrain_steps[target_idx]
    while sum(train_steps) > 0:
        step = random.randint(0, sum(train_steps) - 1)
        cumulative_steps = 0
        for target_idx in xrange(len(train_steps)):
            cumulative_steps += train_steps[target_idx]
            if step < cumulative_steps:
                break
        assert train_steps[target_idx] > 0
        train_steps[target_idx] -= 1
        target_for_step.append(target_idx)
    tf.logging.info('Training schedule defined!')

    best_eval_metric = -1.0
    tf.logging.info('Starting training...')
    actual_step = sum(checkpoint_stats[1:])
    for step, target_idx in enumerate(target_for_step):
        run_training_step(sess, trainers[target_idx], train_corpus, batch_size)
        checkpoint_stats[target_idx + 1] += 1
        if step % 100 == 0:
            tf.logging.info('training step: %d, actual: %d', step,
                            actual_step + step)
        if step % report_every == 0:
            tf.logging.info('finished step: %d, actual: %d', step,
                            actual_step + step)

            annotated = annotate_dataset(sess, annotator, eval_corpus)
            summaries = evaluator(eval_gold, annotated)
            for label, metric in summaries.iteritems():
                write_summary(summary_writer, label, metric,
                              actual_step + step)
            eval_metric = summaries['eval_metric']
            tf.logging.info('Current eval metric: %.2f', eval_metric)
            if best_eval_metric < eval_metric:
                tf.logging.info(
                    'Updating best eval to %.2f, saving checkpoint.',
                    eval_metric)
                best_eval_metric = eval_metric
                saver.save(sess, checkpoint_filename)

                with gfile.GFile('%s.stats' % checkpoint_filename, 'w') as f:
                    stats_str = ','.join([str(x) for x in checkpoint_stats])
                    f.write(stats_str)
                    tf.logging.info('Writing stats: %s', stats_str)

    tf.logging.info('Finished training!')
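A small sketch, assuming the same checkpoint_filename, for reading back the comma-separated stats file written above; the helper name is made up.

# Hypothetical reader for the '%s.stats' file written above.
from tensorflow.python.platform import gfile

def read_checkpoint_stats(checkpoint_filename):
    with gfile.GFile('%s.stats' % checkpoint_filename, 'r') as f:
        return [int(x) for x in f.read().split(',')]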
Example #13
def train():
    """Train a en->fr translation model using WMT data."""
    from_train = None
    to_train = None
    from_dev = None
    to_dev = None
    if FLAGS.from_train_data and FLAGS.to_train_data:
        from_train_data = FLAGS.from_train_data
        to_train_data = FLAGS.to_train_data
        from_dev_data = from_train_data
        to_dev_data = to_train_data
        if FLAGS.from_dev_data and FLAGS.to_dev_data:
            from_dev_data = FLAGS.from_dev_data
            to_dev_data = FLAGS.to_dev_data
        from_train, to_train, from_dev, to_dev, _, _ = data_utils.prepare_data(
            FLAGS.data_dir, from_train_data, to_train_data, from_dev_data,
            to_dev_data, FLAGS.from_vocab_size, FLAGS.to_vocab_size)
    else:
        # Prepare WMT data.
        print("Preparing WMT data in %s" % FLAGS.data_dir)
        from_train, to_train, from_dev, to_dev, _, _ = data_utils.prepare_wmt_data(
            FLAGS.data_dir, FLAGS.from_vocab_size, FLAGS.to_vocab_size)

    with tf.Session() as sess:
        # Create model.
        print("Creating %d layers of %d units." %
              (FLAGS.num_layers, FLAGS.size))
        model = create_model(sess, False)

        # Read data into buckets and compute their sizes.
        print("Reading development and training data (limit: %d)." %
              FLAGS.max_train_data_size)
        dev_set = read_data(from_dev, to_dev)
        train_set = read_data(from_train, to_train, FLAGS.max_train_data_size)
        train_bucket_sizes = [len(train_set[b]) for b in xrange(len(_buckets))]
        train_total_size = float(sum(train_bucket_sizes))

        # A bucket scale is a list of increasing numbers from 0 to 1 that we'll use
        # to select a bucket. Length of [scale[i], scale[i+1]] is proportional to
        # the size of the i-th training bucket, as used later.
        train_buckets_scale = [
            sum(train_bucket_sizes[:i + 1]) / train_total_size
            for i in xrange(len(train_bucket_sizes))
        ]

        # This is the training loop.
        step_time, loss = 0.0, 0.0
        current_step = 0
        previous_losses = []
        while current_step < FLAGS.num_train_step:
            # Choose a bucket according to data distribution. We pick a random number
            # in [0, 1] and use the corresponding interval in train_buckets_scale.
            random_number_01 = np.random.random_sample()
            bucket_id = min([
                i for i in xrange(len(train_buckets_scale))
                if train_buckets_scale[i] > random_number_01
            ])

            # Get a batch and make a step.
            start_time = time.time()
            encoder_inputs, decoder_inputs, target_weights = model.get_batch(
                train_set, bucket_id)
            _, step_loss, _, enc_init_states, enc_all_outputs = model.step(
                sess,
                encoder_inputs,
                decoder_inputs,  #MK change
                target_weights,
                bucket_id,
                False,
                1)
            step_time += (time.time() -
                          start_time) / FLAGS.steps_per_checkpoint
            loss += step_loss / FLAGS.steps_per_checkpoint
            current_step += 1

            # Once in a while, we save checkpoint, print statistics, and run evals.
            if current_step % FLAGS.steps_per_checkpoint == 0:
                # Print statistics for the previous epoch.
                first_layer = np.array(enc_init_states[0])
                mat_first_layer = np.matrix(first_layer)
                with open('first_layer_states.txt', 'wb') as f:
                    for line in mat_first_layer:
                        np.savetxt(f, line, fmt='%.2f')

                second_layer = np.array(enc_init_states[1])
                mat_second_layer = np.matrix(second_layer)
                with open('second_layer_states.txt', 'wb') as f:
                    for line in mat_second_layer:
                        np.savetxt(f, line, fmt='%.2f')

                perplexity = math.exp(
                    float(loss)) if loss < 300 else float("inf")
                print(
                    "global step %d learning rate %.4f step-time %.5f perplexity "
                    "%.5f" %
                    (model.global_step.eval(), model.learning_rate.eval(),
                     step_time, perplexity))
                # Decrease learning rate if no improvement was seen over last 3 times.
                if len(previous_losses) > 2 and loss > max(
                        previous_losses[-3:]):
                    sess.run(model.learning_rate_decay_op)
                previous_losses.append(loss)
                # Save checkpoint and zero timer and loss.
                checkpoint_path = os.path.join(FLAGS.train_dir,
                                               "translate.ckpt")
                model.saver.save(sess,
                                 checkpoint_path,
                                 global_step=model.global_step)
                step_time, loss = 0.0, 0.0
                # Run evals on development set and print their perplexity.
                for bucket_id in xrange(len(_buckets)):
                    if len(dev_set[bucket_id]) == 0:
                        print("  eval: empty bucket %d" % (bucket_id))
                        continue
                    encoder_inputs, decoder_inputs, target_weights = model.get_batch(
                        dev_set, bucket_id)
                    _, eval_loss, _, _, _ = model.step(sess, encoder_inputs,
                                                       decoder_inputs,
                                                       target_weights,
                                                       bucket_id, False, 0)
                    eval_ppx = math.exp(
                        float(eval_loss)) if eval_loss < 300 else float("inf")
                    print("  eval: bucket %d perplexity %.5f" %
                          (bucket_id, eval_ppx))
                sys.stdout.flush()

        en_vocab_path = os.path.join(FLAGS.data_dir,
                                     "vocab%d.from" % FLAGS.from_vocab_size)
        fr_vocab_path = os.path.join(FLAGS.data_dir,
                                     "vocab%d.to" % FLAGS.to_vocab_size)
        en_vocab, _ = data_utils.initialize_vocabulary(en_vocab_path)
        _, rev_fr_vocab = data_utils.initialize_vocabulary(fr_vocab_path)

        max_iter = 100
        count = 0
        model.batch_size = 1
        with gfile.GFile(FLAGS.from_train_data, mode="rb") as f:
            for sentence in f:
                count = count + 1
                if max_iter < count:
                    break
        #sentence = sys.stdin.readline()
        #while sentence:
                print(sentence)
                # Get token-ids for the input sentence.
                token_ids = data_utils.sentence_to_token_ids(
                    tf.compat.as_bytes(sentence), en_vocab)
                # Which bucket does it belong to?
                bucket_id = len(_buckets) - 1
                for i, bucket in enumerate(_buckets):
                    if bucket[0] >= len(token_ids):
                        bucket_id = i
                        break
                else:
                    logging.warning("Sentence truncated: %s", sentence)

                # Get a 1-element batch to feed the sentence to the model.
                encoder_inputs, decoder_inputs, target_weights = model.get_batch(
                    {bucket_id: [(token_ids, [])]}, bucket_id)
                # Get output logits for the sentence.
                _, _, output_logits, enc_all_state, _ = model.step(
                    sess, encoder_inputs, decoder_inputs, target_weights,
                    bucket_id, False, 0)
                quit()
Example #14
 def testExists(self):
   self.assertFalse(gfile.Exists(self.tmp + "test_exists"))
   with gfile.GFile(self.tmp + "test_exists", "w"):
     pass
   self.assertTrue(gfile.Exists(self.tmp + "test_exists"))
Example #15
  def test_export_savedmodel_assets(self):
    tmpdir = tempfile.mkdtemp()
    est = estimator.Estimator(model_fn=_model_fn_for_export_tests)
    est.train(input_fn=dummy_input_fn, steps=1)
    feature_spec = {'x': parsing_ops.VarLenFeature(dtype=dtypes.int64),
                    'y': parsing_ops.VarLenFeature(dtype=dtypes.int64)}
    serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
        feature_spec)

    # Create a fake asset.
    vocab_file_name = os.path.join(
        compat.as_bytes(tmpdir), compat.as_bytes('my_vocab_file'))
    vocab_file = gfile.GFile(vocab_file_name, mode='w')
    vocab_file.write(_VOCAB_FILE_CONTENT)
    vocab_file.close()

    # hack in an op that uses the asset, in order to test asset export.
    # this is not actually valid, of course.
    def serving_input_receiver_with_asset_fn():
      features, receiver_tensor = serving_input_receiver_fn()
      filename = ops.convert_to_tensor(vocab_file_name,
                                       dtypes.string,
                                       name='asset_filepath')
      ops.add_to_collection(ops.GraphKeys.ASSET_FILEPATHS, filename)
      features['bogus_filename'] = filename

      return export.ServingInputReceiver(features, receiver_tensor)

    # Perform the export.
    export_dir_base = os.path.join(
        compat.as_bytes(tmpdir), compat.as_bytes('export'))
    export_dir = est.export_savedmodel(
        export_dir_base, serving_input_receiver_with_asset_fn)

    # Check that the asset files are in the right places.
    expected_vocab_file_name = os.path.join(
        compat.as_bytes(export_dir), compat.as_bytes('assets/my_vocab_file'))
    self.assertTrue(gfile.Exists(os.path.join(
        compat.as_bytes(export_dir), compat.as_bytes('assets'))))
    self.assertTrue(gfile.Exists(expected_vocab_file_name))
    self.assertEqual(
        compat.as_bytes(_VOCAB_FILE_CONTENT),
        compat.as_bytes(gfile.GFile(expected_vocab_file_name).read()))

    # Restore, to validate that the export was well-formed.
    with ops.Graph().as_default() as graph:
      with session.Session(graph=graph) as sess:
        loader.load(sess, [tag_constants.SERVING], export_dir)
        assets = [
            x.eval()
            for x in graph.get_collection(ops.GraphKeys.ASSET_FILEPATHS)
        ]
        self.assertItemsEqual([vocab_file_name], assets)
        graph_ops = [x.name for x in graph.get_operations()]
        self.assertTrue('input_example_tensor' in graph_ops)
        self.assertTrue('ParseExample/ParseExample' in graph_ops)
        self.assertTrue('asset_filepath' in graph_ops)
        self.assertTrue('weight' in graph_ops)

    # cleanup
    gfile.DeleteRecursively(tmpdir)
Example #16
def wav_to_features(sample_rate, clip_duration_ms, window_size_ms,
                    window_stride_ms, feature_bin_count, quantize, preprocess,
                    input_wav, output_c_file):
    """Converts an audio file into its corresponding feature map.

    Args:
      sample_rate: Expected sample rate of the wavs.
      clip_duration_ms: Expected duration in milliseconds of the wavs.
      window_size_ms: How long each spectrogram timeslice is.
      window_stride_ms: How far to move in time between spectrogram timeslices.
      feature_bin_count: How many bins to use for the feature fingerprint.
      quantize: Whether to train the model for eight-bit deployment.
      preprocess: Spectrogram processing mode; "mfcc", "average" or "micro".
      input_wav: Path to the audio WAV file to read.
      output_c_file: Where to save the generated C source file.
    """

    # Start a new TensorFlow session.
    sess = tf.compat.v1.InteractiveSession()

    model_settings = models.prepare_model_settings(
        0, sample_rate, clip_duration_ms, window_size_ms, window_stride_ms,
        feature_bin_count, preprocess)
    audio_processor = input_data.AudioProcessor(None, None, 0, 0, '', 0, 0,
                                                model_settings, None)

    results = audio_processor.get_features_for_wav(input_wav, model_settings,
                                                   sess)
    features = results[0]

    variable_base = os.path.splitext(os.path.basename(input_wav).lower())[0]

    # Save a C source file containing the feature data as an array.
    with gfile.GFile(output_c_file, 'w') as f:
        f.write('/* File automatically created by\n')
        f.write(
            ' * tensorflow/examples/speech_commands/wav_to_features.py \\\n')
        f.write(' * --sample_rate=%d \\\n' % sample_rate)
        f.write(' * --clip_duration_ms=%d \\\n' % clip_duration_ms)
        f.write(' * --window_size_ms=%d \\\n' % window_size_ms)
        f.write(' * --window_stride_ms=%d \\\n' % window_stride_ms)
        f.write(' * --feature_bin_count=%d \\\n' % feature_bin_count)
        if quantize:
            f.write(' * --quantize=1 \\\n')
        f.write(' * --preprocess="%s" \\\n' % preprocess)
        f.write(' * --input_wav="%s" \\\n' % input_wav)
        f.write(' * --output_c_file="%s" \\\n' % output_c_file)
        f.write(' */\n\n')
        f.write('const int g_%s_width = %d;\n' %
                (variable_base, model_settings['fingerprint_width']))
        f.write('const int g_%s_height = %d;\n' %
                (variable_base, model_settings['spectrogram_length']))
        if quantize:
            features_min, features_max = input_data.get_features_range(
                model_settings)
            f.write('const unsigned char g_%s_data[] = {' % variable_base)
            i = 0
            for value in features.flatten():
                quantized_value = int(
                    round((255 * (value - features_min)) /
                          (features_max - features_min)))
                if quantized_value < 0:
                    quantized_value = 0
                if quantized_value > 255:
                    quantized_value = 255
                if i == 0:
                    f.write('\n  ')
                f.write('%d, ' % (quantized_value))
                i = (i + 1) % 10
        else:
            f.write('const float g_%s_data[] = {\n' % variable_base)
            i = 0
            for value in features.flatten():
                if i == 0:
                    f.write('\n  ')
                f.write('%f, ' % value)
                i = (i + 1) % 10
        f.write('\n};\n')
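For reference, a minimal, hypothetical driver showing how wav_to_features could be invoked; the parameter values and file paths below are illustrative placeholders, not part of the original example.

if __name__ == '__main__':
    # Placeholder settings roughly in line with the speech_commands defaults.
    wav_to_features(
        sample_rate=16000,
        clip_duration_ms=1000,
        window_size_ms=30.0,
        window_stride_ms=20.0,
        feature_bin_count=40,
        quantize=True,              # emit unsigned 8-bit feature data
        preprocess='micro',         # 'mfcc', 'average' or 'micro'
        input_wav='/tmp/yes.wav',               # hypothetical input path
        output_c_file='/tmp/yes_features.cc')   # hypothetical output path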
Ejemplo n.º 17
  def benchmark_sentencepiece_tokenizer(self):
    model = gfile.GFile(_SENTENCEPIECE_MODEL_FILE, "rb").read()
    tokenizer = text_ops.SentencepieceTokenizer(model)
    self._run(tokenizer)
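As a usage note, a hedged sketch of how a SentencePiece tokenizer like the one benchmarked here is typically applied to strings; the model path is a placeholder, and it assumes tensorflow_text provides the same SentencepieceTokenizer that text_ops refers to above.

import tensorflow_text as tf_text
from tensorflow.python.platform import gfile

# Load a serialized SentencePiece model (placeholder path) and tokenize a batch.
sp_model = gfile.GFile('/tmp/example.spm.model', 'rb').read()
tokenizer = tf_text.SentencepieceTokenizer(sp_model)
token_ids = tokenizer.tokenize(['hello world', 'sentencepiece tokenization'])
print(token_ids)  # RaggedTensor of subword ids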
Ejemplo n.º 18
if __name__ == "__main__":
    os.environ['CUDA_VISIBLE_DEVICES'] = "0"  # note the plural form; 'CUDA_VISIBLE_DEVICE' is silently ignored
    output_path = './outputs_replica'

    checkpoint_to_use = os.path.join(output_path, 'model.ckpt')
    output_graph_path = './export/exported_graph_replica_v2.pb'
    output_node_names = 'NetOutput'

    tf.reset_default_graph()

    with tf.Session() as sess:
        saver_restore = tf.train.import_meta_graph(
            os.path.join(output_path, 'model.ckpt.meta'))
        saver_restore.restore(sess, tf.train.latest_checkpoint(output_path))

        saver_export = tf.train.Saver(max_to_keep=10)

        frozen_graph_def = freeze_graph_with_def_protos(
            input_graph_def=tf.get_default_graph().as_graph_def(
                add_shapes=True),
            input_saver_def=saver_export.as_saver_def(),
            input_checkpoint=checkpoint_to_use,
            output_node_names=output_node_names,
            restore_op_name=None,
            filename_tensor_name=None,
            output_graph=output_graph_path,  # required parameter; the graph is also written out below
            clear_devices=True,
            initializer_nodes=None)

        with gfile.GFile(output_graph_path, 'wb') as f:
            f.write(frozen_graph_def.SerializeToString())
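A short, hedged sanity check showing how the frozen .pb written above could be re-imported for inference; it assumes the exported graph really contains a tensor named 'NetOutput:0' (matching output_node_names), while the input tensor name depends on the original model.

import tensorflow as tf
from tensorflow.python.platform import gfile

with gfile.GFile('./export/exported_graph_replica_v2.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def, name='')
    # 'NetOutput:0' mirrors output_node_names above; the feed tensor name
    # is model-specific and left unspecified here.
    output_tensor = graph.get_tensor_by_name('NetOutput:0')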
Ejemplo n.º 19
    def test_writing_canned_variants(self):
        """Tests writing all the variants that are 'canned' in our tfrecord file."""
        # This file is in the TF record format
        tfrecord_file = test_utils.genomics_core_testdata(
            'test_samples.vcf.golden.tfrecord')

        writer_options = variants_pb2.VcfWriterOptions()
        header = variants_pb2.VcfHeader(
            contigs=[
                reference_pb2.ContigInfo(name='chr1', n_bases=248956422),
                reference_pb2.ContigInfo(name='chr2', n_bases=242193529),
                reference_pb2.ContigInfo(name='chr3', n_bases=198295559),
                reference_pb2.ContigInfo(name='chrX', n_bases=156040895)
            ],
            sample_names=['NA12878_18_99'],
            filters=[
                variants_pb2.VcfFilterInfo(id='PASS',
                                           description='All filters passed'),
                variants_pb2.VcfFilterInfo(id='LowQual', description=''),
                variants_pb2.VcfFilterInfo(id='VQSRTrancheINDEL95.00to96.00'),
                variants_pb2.VcfFilterInfo(id='VQSRTrancheINDEL96.00to97.00'),
                variants_pb2.VcfFilterInfo(id='VQSRTrancheINDEL97.00to99.00'),
                variants_pb2.VcfFilterInfo(id='VQSRTrancheINDEL99.00to99.50'),
                variants_pb2.VcfFilterInfo(id='VQSRTrancheINDEL99.50to99.90'),
                variants_pb2.VcfFilterInfo(id='VQSRTrancheINDEL99.90to99.95'),
                variants_pb2.VcfFilterInfo(
                    id='VQSRTrancheINDEL99.95to100.00+'),
                variants_pb2.VcfFilterInfo(id='VQSRTrancheINDEL99.95to100.00'),
                variants_pb2.VcfFilterInfo(id='VQSRTrancheSNP99.50to99.60'),
                variants_pb2.VcfFilterInfo(id='VQSRTrancheSNP99.60to99.80'),
                variants_pb2.VcfFilterInfo(id='VQSRTrancheSNP99.80to99.90'),
                variants_pb2.VcfFilterInfo(id='VQSRTrancheSNP99.90to99.95'),
                variants_pb2.VcfFilterInfo(id='VQSRTrancheSNP99.95to100.00+'),
                variants_pb2.VcfFilterInfo(id='VQSRTrancheSNP99.95to100.00'),
            ],
            infos=[
                variants_pb2.VcfInfo(
                    id='END',
                    number='1',
                    type='Integer',
                    description='Stop position of the interval')
            ],
            formats=[
                variants_pb2.VcfFormatInfo(id='GT',
                                           number='1',
                                           type='String',
                                           description='Genotype'),
                variants_pb2.VcfFormatInfo(id='GQ',
                                           number='1',
                                           type='Integer',
                                           description='Genotype Quality'),
                variants_pb2.VcfFormatInfo(
                    id='DP',
                    number='1',
                    type='Integer',
                    description='Read depth of all passing filters reads.'),
                variants_pb2.VcfFormatInfo(
                    id='MIN_DP',
                    number='1',
                    type='Integer',
                    description='Minimum DP observed within the GVCF block.'),
                variants_pb2.VcfFormatInfo(
                    id='AD',
                    number='R',
                    type='Integer',
                    description=
                    'Read depth of all passing filters reads for each allele.'
                ),
                variants_pb2.VcfFormatInfo(
                    id='VAF',
                    number='A',
                    type='Float',
                    description='Variant allele fractions.'),
                variants_pb2.VcfFormatInfo(
                    id='PL',
                    number='G',
                    type='Integer',
                    description='Genotype likelihoods, Phred encoded'),
            ],
        )
        variant_records = list(
            tfrecord.read_tfrecords(tfrecord_file, proto=variants_pb2.Variant))
        out_fname = test_utils.test_tmpfile('output.vcf')
        with vcf_writer.VcfWriter.to_file(out_fname, header,
                                          writer_options) as writer:
            for record in variant_records[:5]:
                writer.write(record)

        # Check: are the variants written as expected?
        # pylint: disable=line-too-long
        expected_vcf_content = [
            '##fileformat=VCFv4.2\n',
            '##FILTER=<ID=PASS,Description="All filters passed">\n',
            '##FILTER=<ID=LowQual,Description="">\n',
            '##FILTER=<ID=VQSRTrancheINDEL95.00to96.00,Description="">\n',
            '##FILTER=<ID=VQSRTrancheINDEL96.00to97.00,Description="">\n',
            '##FILTER=<ID=VQSRTrancheINDEL97.00to99.00,Description="">\n',
            '##FILTER=<ID=VQSRTrancheINDEL99.00to99.50,Description="">\n',
            '##FILTER=<ID=VQSRTrancheINDEL99.50to99.90,Description="">\n',
            '##FILTER=<ID=VQSRTrancheINDEL99.90to99.95,Description="">\n',
            '##FILTER=<ID=VQSRTrancheINDEL99.95to100.00+,Description="">\n',
            '##FILTER=<ID=VQSRTrancheINDEL99.95to100.00,Description="">\n',
            '##FILTER=<ID=VQSRTrancheSNP99.50to99.60,Description="">\n',
            '##FILTER=<ID=VQSRTrancheSNP99.60to99.80,Description="">\n',
            '##FILTER=<ID=VQSRTrancheSNP99.80to99.90,Description="">\n',
            '##FILTER=<ID=VQSRTrancheSNP99.90to99.95,Description="">\n',
            '##FILTER=<ID=VQSRTrancheSNP99.95to100.00+,Description="">\n',
            '##FILTER=<ID=VQSRTrancheSNP99.95to100.00,Description="">\n',
            '##INFO=<ID=END,Number=1,Type=Integer,Description="Stop position of '
            'the interval">\n',
            '##FORMAT=<ID=GT,Number=1,Type=String,Description="Genotype">\n',
            '##FORMAT=<ID=GQ,Number=1,Type=Integer,Description="Genotype Quality">\n',
            '##FORMAT=<ID=DP,Number=1,Type=Integer,Description="Read depth of all '
            'passing filters reads.">\n',
            '##FORMAT=<ID=MIN_DP,Number=1,Type=Integer,Description="Minimum DP '
            'observed within the GVCF block.">\n',
            '##FORMAT=<ID=AD,Number=R,Type=Integer,Description="Read depth of all '
            'passing filters reads for each allele.">\n',
            '##FORMAT=<ID=VAF,Number=A,Type=Float,Description="Variant allele '
            'fractions.">\n',
            '##FORMAT=<ID=PL,Number=G,Type=Integer,Description="Genotype '
            'likelihoods, Phred encoded">\n',
            '##contig=<ID=chr1,length=248956422>\n',
            '##contig=<ID=chr2,length=242193529>\n',
            '##contig=<ID=chr3,length=198295559>\n',
            '##contig=<ID=chrX,length=156040895>\n',
            '#CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\tNA12878_18_99\n',
            'chr1\t13613\t.\tT\tA\t39.88\tVQSRTrancheSNP99.90to99.95\t.\tGT:GQ:DP:AD:PL\t0/1:16:4:1,3:68,0,16\n',
            'chr1\t13813\t.\tT\tG\t90.28\tPASS\t.\tGT:GQ:DP:AD:PL\t1/1:9:3:0,3:118,9,0\n',
            'chr1\t13838\trs28428499\tC\tT\t62.74\tPASS\t.\tGT:GQ:DP:AD:PL\t1/1:6:2:0,2:90,6,0\n',
            'chr1\t14397\trs756427959\tCTGT\tC\t37.73\tPASS\t.\tGT:GQ:DP:AD:PL\t0/1:75:5:3,2:75,0,152\n',
            'chr1\t14522\t.\tG\tA\t49.77\tVQSRTrancheSNP99.60to99.80\t.\tGT:GQ:DP:AD:PL\t0/1:78:10:6,4:78,0,118\n'
        ]
        # pylint: enable=line-too-long

        with gfile.GFile(out_fname, 'r') as f:
            self.assertEqual(f.readlines(), expected_vcf_content)
Ejemplo n.º 20
    def __init__(self, frozen_file, inputshape, in_nodes, dest_nodes):
        if LooseVersion(tensorflow.__version__) < LooseVersion('1.8.0'):
            raise ImportError(
                'Your TensorFlow version %s is outdated. '
                'MMdnn requires tensorflow>=1.8.0' % tensorflow.__version__)

        super(TensorflowParser2, self).__init__()

        self.weight_loaded = True
        # load model files into TensorFlow graph
        with open(frozen_file, 'rb') as f:
            serialized = f.read()
        tensorflow.reset_default_graph()
        original_gdef = tensorflow.GraphDef()

        original_gdef.ParseFromString(serialized)

        in_type_list = {}
        for n in original_gdef.node:
            if n.name in in_nodes:
                in_type_list[n.name] = n.attr['dtype'].type

        from tensorflow.python.tools import strip_unused_lib
        from tensorflow.python.framework import dtypes
        from tensorflow.python.platform import gfile
        original_gdef = strip_unused_lib.strip_unused(
            input_graph_def=original_gdef,
            input_node_names=in_nodes,
            output_node_names=dest_nodes,
            placeholder_type_enum=dtypes.float32.as_datatype_enum)
        # Save it to an output file
        frozen_model_file = './frozen.pb'
        with gfile.GFile(frozen_model_file, "wb") as f:
            f.write(original_gdef.SerializeToString())
        with open(frozen_model_file, 'rb') as f:
            serialized = f.read()
        tensorflow.reset_default_graph()
        model = tensorflow.GraphDef()
        model.ParseFromString(serialized)

        output_shape_map = dict()
        input_shape_map = dict()
        dtype = tensorflow.float32

        with tensorflow.Graph().as_default() as g:
            input_map = {}
            for i in range(len(inputshape)):
                if in_type_list[in_nodes[i]] == 1 or in_type_list[in_nodes[i]] == 0:
                    dtype = tensorflow.float32
                    x = tensorflow.placeholder(dtype, shape=[None] + inputshape[i])

                elif in_type_list[in_nodes[i]] == 3:
                    dtype = tensorflow.int32
                    x = tensorflow.placeholder(dtype, shape=inputshape[i])

                elif in_type_list[in_nodes[i]] == 10:
                    dtype = tensorflow.bool
                    x = tensorflow.placeholder(dtype)

                input_map[in_nodes[i] + ':0'] = x

            tensorflow.import_graph_def(model, name='', input_map=input_map)

        # graph_options = tensorflow.GraphOptions(
        #     optimizer_options=tensorflow.OptimizerOptions(
        #         opt_level=tensorflow.OptimizerOptions.L0, do_function_inlining=False))

        # config = tensorflow.ConfigProto(graph_options=graph_options)
        # with tensorflow.Session(graph = g, config=config) as sess:
        with tensorflow.Session(graph=g) as sess:
            meta_graph_def = tensorflow.train.export_meta_graph(
                filename='./my-model.meta')
            model = meta_graph_def.graph_def

        self.tf_graph = TensorflowGraph(model)
        self.tf_graph.build()
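The literals 1, 3 and 10 checked in the constructor are TensorFlow DataType enum values (DT_FLOAT, DT_INT32 and DT_BOOL respectively). A small sketch, under that assumption, of resolving the enum through tf.as_dtype instead of hard-coding the numbers:

import tensorflow as tf

def placeholder_for_node(datatype_enum, shape):
    """Create a placeholder whose dtype matches a node's recorded DataType enum."""
    # tf.as_dtype maps the proto enum back to a tf.DType,
    # e.g. 1 -> tf.float32, 3 -> tf.int32, 10 -> tf.bool.
    dtype = tf.as_dtype(datatype_enum)
    return tf.placeholder(dtype, shape=shape)

# Hypothetical usage mirroring the loop above:
# x = placeholder_for_node(in_type_list[in_nodes[i]], [None] + inputshape[i])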
Ejemplo n.º 21
def main(_):
    best_acc = 0
    best_step = 0
    best_acc_istrain = 0
    best_step_istrain = 0
    # We want to see all the logging messages for this tutorial.
    tf.logging.set_verbosity(tf.logging.INFO)

    # Start a new TensorFlow session.
    sess = tf.InteractiveSession()

    # Begin by making sure we have the training data we need. If you already have
    # training data of your own, use `--data_url= ` on the command line to avoid
    # downloading.
    model_settings = models.prepare_model_settings(
        len(
            input_data_filler.prepare_words_list_my(
                FLAGS.wanted_words.split(','))), FLAGS.sample_rate,
        FLAGS.clip_duration_ms, FLAGS.window_size_ms, FLAGS.window_stride_ms,
        FLAGS.dct_coefficient_count)
    audio_processor = input_data_filler.AudioProcessor(
        FLAGS.data_url, FLAGS.data_dir,
        FLAGS.silence_percentage, FLAGS.unknown_percentage,
        FLAGS.wanted_words.split(','), FLAGS.validation_percentage,
        FLAGS.testing_percentage, model_settings)
    fingerprint_size = model_settings['fingerprint_size']
    label_count = model_settings['label_count']
    time_shift_samples = int((FLAGS.time_shift_ms * FLAGS.sample_rate) / 1000)

    training_steps_list = list(
        map(int, FLAGS.how_many_training_steps.split(',')))
    learning_rates_list = list(map(float, FLAGS.learning_rate.split(',')))
    if len(training_steps_list) != len(learning_rates_list):
        raise Exception(
            '--how_many_training_steps and --learning_rate must be equal length '
            'lists, but are %d and %d long instead' %
            (len(training_steps_list), len(learning_rates_list)))
##############################################
############ TensorFlow modules ##########

    fingerprint_input = tf.placeholder(tf.float32, [None, fingerprint_size],
                                       name='fingerprint_input')

    ############ Model creation ##########
    istrain = tf.placeholder(tf.bool, name='istrain')
    logits = models.create_model(fingerprint_input,
                                 model_settings,
                                 FLAGS.model_architecture,
                                 is_training=istrain)
    ############ Model creation ##########
    # logits, dropout_prob= models.create_model(
    #     fingerprint_input,
    #     model_settings,
    #     FLAGS.model_architecture,
    #     is_training=True)
    # Define loss and optimizer

    ############ Ground-truth input ##########
    ground_truth_input = tf.placeholder(tf.float32, [None, label_count],
                                        name='groundtruth_input')

    # Optionally we can add runtime checks to spot when NaNs or other symptoms of
    # numerical errors start occurring during training.
    control_dependencies = []
    if FLAGS.check_nans:
        checks = tf.add_check_numerics_ops()
        control_dependencies = [checks]

    # Create the back propagation and training evaluation machinery in the graph.
    ############ Cross-entropy computation ##########
    # with tf.name_scope('cross_entropy'):
    #   cross_entropy_mean = tf.reduce_mean(
    #       tf.nn.softmax_cross_entropy_with_logits(
    #           labels=ground_truth_input, logits=logits)) + beta*loss_norm
    with tf.name_scope('cross_entropy'):
        cross_entropy_mean = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(labels=ground_truth_input,
                                                    logits=logits))
    tf.summary.scalar('cross_entropy', cross_entropy_mean)

    ############ Learning rate, accuracy, confusion matrix ##########
    # learning_rate_input    learning-rate input (tf.placeholder)
    # train_step             training op (optimizer)
    # predicted_indices      predicted label indices
    # expected_indices       expected (ground-truth) label indices
    # correct_prediction     element-wise correctness of predictions
    # confusion_matrix       confusion matrix
    # evaluation_step        classification accuracy (per evaluation)
    # global_step            global training step
    # increment_global_step  op that increments the global step

    learning_rate_input = tf.placeholder(tf.float32, [],
                                         name='learning_rate_input')
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_step = tf.train.AdamOptimizer(learning_rate_input).minimize(
            cross_entropy_mean)
    # with tf.name_scope('train'), tf.control_dependencies(control_dependencies):
    #   learning_rate_input = tf.placeholder(
    #       tf.float32, [], name='learning_rate_input')
    #  # train_step = tf.train.GradientDescentOptimizer(
    #     #  learning_rate_input).minimize(cross_entropy_mean)
    #   with tf.control_dependencies(update_ops):
    #       train_step = tf.train.AdamOptimizer(
    #           learning_rate_input).minimize(cross_entropy_mean)
    predicted_indices = tf.argmax(logits, 1)
    expected_indices = tf.argmax(ground_truth_input, 1)
    correct_prediction = tf.equal(predicted_indices, expected_indices)
    confusion_matrix = tf.confusion_matrix(expected_indices,
                                           predicted_indices,
                                           num_classes=label_count)
    evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    acc = tf.summary.scalar('accuracy', evaluation_step)

    global_step = tf.train.get_or_create_global_step()
    increment_global_step = tf.assign(global_step, global_step + 1)

    saver = tf.train.Saver(tf.global_variables(),
                           max_to_keep=None)  # keep every checkpoint (the default is 5)

    # Merge all the summaries and write them out to /tmp/retrain_logs (by default)
    merged_summaries = tf.summary.merge_all()
    validation_merged_summaries = tf.summary.merge([
        tf.get_collection(tf.GraphKeys.SUMMARIES, 'accuracy'),
        tf.get_collection(tf.GraphKeys.SUMMARIES, 'cross_entropy')
    ])
    test_summaries = tf.summary.merge([acc])
    test_summaries_istrain = tf.summary.merge([
        tf.get_collection(tf.GraphKeys.SUMMARIES, 'accuracy'),
        tf.get_collection(tf.GraphKeys.SUMMARIES, 'cross_entropy')
    ])

    #test_summaries_istrain = tf.summary.merge([acc])
    train_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/train',
                                         sess.graph)
    # validation_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/validation')
    test_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/test')
    test_istrain_writer = tf.summary.FileWriter(FLAGS.summaries_dir +
                                                '/test_istrain')
    tf.global_variables_initializer().run()

    start_step = 1

    if FLAGS.start_checkpoint:
        models.load_variables_from_checkpoint(sess, FLAGS.start_checkpoint)
        start_step = global_step.eval(session=sess)

    tf.logging.info('Training from step: %d ', start_step)

    # Save graph.pbtxt.
    tf.train.write_graph(sess.graph_def, FLAGS.train_dir,
                         FLAGS.model_architecture + '.pbtxt')

    # Save list of words.
    with gfile.GFile(
            os.path.join(FLAGS.train_dir,
                         FLAGS.model_architecture + '_labels.txt'), 'w') as f:
        f.write('\n'.join(audio_processor.words_list))
###
# model1: fc
# model2: conv (~940k parameters)
# model3: low_latency_conv (roughly comparable to model1)
# model4: ~750k parameters
# Training loop.
#############################################
########            Main loop          ######
#############################################
    training_steps_max = np.sum(training_steps_list)
    for training_step in xrange(start_step, training_steps_max + 1):
        # Figure out what the current learning rate is.
        #######   Automatic learning-rate decay   #######
        if training_step < 12000 + 1:
            learning_rate_value = learning_rates_list[0] * 0.02**(
                training_step / 12000)
        else:
            learning_rate_value = learning_rates_list[0] * 0.02  #0.015 12000
        training_steps_sum = 0
        # for i in range(len(training_steps_list)):
        #   training_steps_sum += training_steps_list[i]
        #   if training_step <= training_steps_sum:
        #     learning_rate_value = learning_rates_list[i]
        #     break

        # Pull the audio samples we'll use for training.
        #######   Load data through the audio processor   #######################
        ##get_data(self, how_many, offset, model_settings, background_frequency,
        ##         background_volume_range, time_shift, mode, sess)
        ########################################################################
        train_fingerprints, train_ground_truth = audio_processor.get_data_my(
            FLAGS.batch_size, 0, model_settings, FLAGS.background_frequency,
            FLAGS.background_volume, time_shift_samples, 'training', sess)
        #mid = np.abs(np.max(train_fingerprints) + np.min(train_fingerprints)) / 2
        #half = np.max(train_fingerprints) - np.min(train_fingerprints)
        # train_fingerprints = ((train_fingerprints + mid) / half * 255).astype(int)
        #train_fingerprints = np_round_and_clip_5bit(train_fingerprints)

        train_summary, train_accuracy, cross_entropy_value, _, _ = sess.run(
            [
                merged_summaries, evaluation_step, cross_entropy_mean,
                train_step, increment_global_step
            ],
            feed_dict={
                fingerprint_input: train_fingerprints,
                ground_truth_input: train_ground_truth,
                learning_rate_input: learning_rate_value,
                istrain: True
            })
        train_writer.add_summary(train_summary, training_step)
        tf.logging.info(
            'Step #%d: rate %f, accuracy %.1f%%, cross entropy %f' %
            (training_step, learning_rate_value, train_accuracy * 100,
             cross_entropy_value))
        is_last_step = (training_step == training_steps_max)
        if (training_step % FLAGS.eval_step_interval) == 0 or is_last_step:

            #############################################
            ####  Recompute accuracy and the confusion matrix on the test set  ####
            set_size = audio_processor.set_size('testing')
            tf.logging.info('set_size=%d', set_size)
            test_fingerprints, test_ground_truth = audio_processor.get_data_my(
                -3, 0, model_settings, 0.0, 0.0, 0, 'testing', sess)
            #mid = np.abs(np.max(test_fingerprints) + np.min(test_fingerprints)) / 2
            #half = np.max(test_fingerprints) - np.min(test_fingerprints)
            # test_fingerprints = np_round_and_clip_5bit(test_fingerprints)
            final_summary, test_accuracy, conf_matrix = sess.run(
                [test_summaries, evaluation_step, confusion_matrix],
                feed_dict={
                    fingerprint_input: test_fingerprints,
                    ground_truth_input: test_ground_truth,
                    istrain: False
                })
            final_summary_istrain, test_accuracy_istrain = sess.run(
                [test_summaries_istrain, evaluation_step],
                feed_dict={
                    fingerprint_input: test_fingerprints,
                    ground_truth_input: test_ground_truth,
                    istrain: True
                })

            if test_accuracy > best_acc:
                best_acc = test_accuracy
                best_step = training_step
            if test_accuracy_istrain > best_acc_istrain:
                best_acc_istrain = test_accuracy_istrain
                best_step_istrain = training_step
            test_writer.add_summary(final_summary, training_step)
            test_istrain_writer.add_summary(final_summary_istrain,
                                            training_step)
            tf.logging.info('Confusion Matrix:\n %s' % (conf_matrix))
            tf.logging.info('test accuracy = %.1f%% (N=%d)' %
                            (test_accuracy * 100, 6882))
            tf.logging.info('test_istrain accuracy = %.1f%% (N=%d)' %
                            (test_accuracy_istrain * 100, 6882))

            tf.logging.info('Best test accuracy before now = %.1f%% (N=%d)' %
                            (best_acc * 100, 6882) + '  at step of ' +
                            str(best_step))
            tf.logging.info(
                'Best test_istrain accuracy before now = %.1f%% (N=%d)' %
                (best_acc_istrain * 100, 6882) + '  at step of ' +
                str(best_step_istrain))
        # Save the model checkpoint periodically.
        if (training_step % FLAGS.save_step_interval == 0
                or training_step == training_steps_max):
            checkpoint_path = os.path.join(
                FLAGS.train_dir + '/' + FLAGS.model_architecture,
                FLAGS.model_architecture + '.ckpt')
            tf.logging.info('Saving to "%s-%d"', checkpoint_path,
                            training_step)
            saver.save(sess, checkpoint_path, global_step=training_step)
        print_line = 'Best test accuracy before now = %.1f%% (N=%d)' % (best_acc * 100,6882) + '  at step of ' + str(best_step) + '\n' + \
                     'Best test_istrain accuracy before now = %.1f%% (N=%d)' % (best_acc_istrain * 100,6882) + '  at step of ' + str(best_step_istrain)
        if training_step == training_steps_max:
            with open(
                    FLAGS.train_dir + '/' + FLAGS.model_architecture +
                    '/details.txt', 'w') as f:
                f.write(print_line)
Ejemplo n.º 22
def build_buckets(buckets, max_vocab, source_languages, target_languages):
    """
    Build bucketed versions of the data, tokenizing it and writing it to a file in the process.

    :param buckets: List of pairs representing (source, target) bucket length.
    :param max_vocab: Dictionary mapping language id to maximum vocabulary size.
    :param source_languages: List of source language ids.
    :param target_languages: List of target language ids.
    """
    if len(os.listdir(os.path.join(FLAGS.bucket_dir, str(0)))) <= 1:
        for target_id in target_languages:
            target_vocab, _ = init_vocab(target_id, max_vocab[target_id])
            for source_id in source_languages:
                source_vocab, _ = init_vocab(source_id, max_vocab[source_id])
                base_fp = "%s-%s." % tuple(sorted([source_id, target_id]))
                eval_fp = "%s-%s-eval." % tuple(sorted([source_id, target_id]))
                print(base_fp, source_id, target_id)
                data_set = {k: [] for k in range(len(buckets))}
                with gfile.GFile(
                        os.path.join(FLAGS.raw_dir, base_fp + source_id),
                        'r') as src:
                    with gfile.GFile(
                            os.path.join(FLAGS.raw_dir, base_fp + target_id),
                            'r') as trg:
                        source, target = src.readline(), trg.readline()
                        while source and target:
                            src_ids = [
                                source_vocab.get(w, UNK_ID)
                                for w in basic_tokenizer(source)
                            ]
                            trg_ids = [
                                target_vocab.get(w, UNK_ID)
                                for w in basic_tokenizer(target)
                            ]
                            trg_ids.append(EOS_ID)

                            # Find a bucket
                            for bucket_id, (source_size,
                                            target_size) in enumerate(buckets):
                                if len(src_ids) < source_size and len(
                                        trg_ids) < target_size:
                                    data_set[bucket_id].append(
                                        [src_ids, trg_ids])
                                    break

                            source, target = src.readline(), trg.readline()
                for k in data_set:
                    counter = 0
                    data_train = data_set[k][:-200]
                    data_eval = data_set[k][-200:]
                    bucket_path = os.path.join(FLAGS.bucket_dir, str(k))
                    with gfile.GFile(
                            os.path.join(bucket_path, base_fp + source_id),
                            'w') as src:
                        with gfile.GFile(
                                os.path.join(bucket_path, base_fp + target_id),
                                'w') as trg:
                            for i in range(len(data_train)):
                                counter += 1
                                s, t = data_train[i]
                                src.write(" ".join(map(str, s)) + "\n")
                                trg.write(" ".join(map(str, t)) + "\n")
                    with gfile.GFile(
                            os.path.join(bucket_path, eval_fp + source_id),
                            'w') as src:
                        with gfile.GFile(
                                os.path.join(bucket_path, eval_fp + target_id),
                                'w') as trg:
                            for i in range(len(data_eval)):
                                s, t = data_eval[i]
                                src.write(" ".join(map(str, s)) + "\n")
                                trg.write(" ".join(map(str, t)) + "\n")
                    print "Bucket", k, "for %s-%s" % (
                        source_id, target_id), "has", counter, "examples!"
Ejemplo n.º 23
def _read_words(filename):
    with gfile.GFile(filename, "r") as f:
        return f.read().replace("\n", "<eos>").split()
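A hedged follow-up showing how the flat, '<eos>'-delimited word list returned by _read_words is commonly turned into a vocabulary and id sequence (as in the PTB reader this helper comes from); the file path is a placeholder.

import collections

words = _read_words('/tmp/ptb.train.txt')  # placeholder path
counter = collections.Counter(words)
# Most frequent words get the smallest ids; ties broken alphabetically.
sorted_words = sorted(counter, key=lambda w: (-counter[w], w))
word_to_id = {w: i for i, w in enumerate(sorted_words)}
train_ids = [word_to_id[w] for w in words]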
Ejemplo n.º 24
    with tf.Session() as sess:
        constant_graph = graph_util.convert_variables_to_constants(
            sess, sess.graph.as_graph_def(), pred_node_names)

    frozen = graph_util.remove_training_nodes(constant_graph)

    output = "cnn.pb"
    graph_io.write_graph(frozen, ".", output, as_text=False)


export_cnn()

tf.reset_default_graph()

model_filename = "cnn.pb"
with gfile.GFile(model_filename, "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

config = tfe.LocalConfig([
    "server0", "server1", "crypto-producer", "prediction-client",
    "weights-provider"
])


def provide_input() -> tf.Tensor:
    return tf.constant(np.random.normal(size=(1, 1, 28, 28)), tf.float32)


def receive_output(tensor: tf.Tensor) -> tf.Tensor:
    # tf.Print passes the tensor through unchanged and logs its value when evaluated,
    # so returning it keeps the annotated return type honest.
    return tf.Print(tensor, [tensor])
Ejemplo n.º 25
def main(_):

    # We want to see all the logging messages for this tutorial.
    tf.logging.set_verbosity(tf.logging.INFO)

    # Start a new TensorFlow session.
    sess = tf.InteractiveSession()

    model_settings = models.prepare_model_settings(
        len(input_data.prepare_words_list(FLAGS.wanted_words.split(','))),
        FLAGS.sample_rate, FLAGS.clip_duration_ms, FLAGS.window_size_ms,
        FLAGS.window_stride_ms, FLAGS.dct_coefficient_count)

    audio_processor = input_data.AudioProcessor(FLAGS.data_url, FLAGS.data_dir,
                                                FLAGS.silence_percentage,
                                                FLAGS.unknown_percentage,
                                                FLAGS.wanted_words.split(','),
                                                FLAGS.validation_percentage,
                                                FLAGS.testing_percentage,
                                                model_settings)

    fingerprint_size = model_settings['fingerprint_size']
    label_count = model_settings['label_count']
    time_shift_samples = int((FLAGS.time_shift_ms * FLAGS.sample_rate) / 1000)

    training_steps = FLAGS.how_many_training_steps
    learning_rate = FLAGS.learning_rate

    # -----------------------------------------------------------------------
    # -----------------------------Placeholder-------------------------------
    # -----------------------------------------------------------------------

    fingerprint_input = tf.placeholder(tf.float32, [None, fingerprint_size],
                                       name='fingerprint_input')

    logits, dropout_prob = models.create_model(fingerprint_input,
                                               model_settings,
                                               FLAGS.model_architecture,
                                               is_training=True)

    # Define loss and optimizer
    ground_truth_input = tf.placeholder(tf.int64, [None],
                                        name='groundtruth_input')

    # Optionally we can add runtime checks to spot when NaNs or other symptoms of
    # numerical errors start occurring during training.
    control_dependencies = []
    if FLAGS.check_nans:
        checks = tf.add_check_numerics_ops()
        control_dependencies = [checks]

    # -----------------------------------------------------------------------
    # -----------------Back propagation and training evaluation--------------
    # -----------------------------------------------------------------------

    # Create the back propagation and training evaluation machinery in the graph.
    with tf.name_scope('cross_entropy'):
        cross_entropy_mean = tf.losses.sparse_softmax_cross_entropy(
            labels=ground_truth_input, logits=logits)

    tf.summary.scalar('cross_entropy', cross_entropy_mean)

    with tf.name_scope('train'), tf.control_dependencies(control_dependencies):
        train_step = tf.train.AdamOptimizer(learning_rate).minimize(
            cross_entropy_mean)

    predicted_indices = tf.argmax(logits, 1)
    correct_prediction = tf.equal(predicted_indices, ground_truth_input)
    confusion_matrix = tf.confusion_matrix(ground_truth_input,
                                           predicted_indices,
                                           num_classes=label_count)
    evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    tf.summary.scalar('accuracy', evaluation_step)

    global_step = tf.train.get_or_create_global_step()
    increment_global_step = tf.assign(global_step, global_step + 1)

    saver = tf.train.Saver(tf.global_variables())

    # Merge all the summaries and write them out to /tmp/retrain_logs (by default)
    merged_summaries = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/train',
                                         sess.graph)
    validation_writer = tf.summary.FileWriter(FLAGS.summaries_dir +
                                              '/validation')

    tf.global_variables_initializer().run()

    start_step = 1

    if FLAGS.start_checkpoint:
        models.load_variables_from_checkpoint(sess, FLAGS.start_checkpoint)
        start_step = global_step.eval(session=sess)

    tf.logging.info('Training from step: %d ', start_step)

    # Save graph.pbtxt.
    tf.train.write_graph(sess.graph_def, FLAGS.train_dir,
                         FLAGS.model_architecture + '.pbtxt')

    # Save list of words.
    with gfile.GFile(
            os.path.join(FLAGS.train_dir,
                         FLAGS.model_architecture + '_labels.txt'), 'w') as f:
        f.write('\n'.join(audio_processor.words_list))

    # -----------------------------------------------------------------------
    # -----------------Training and validation-------------------------------
    # -----------------------------------------------------------------------

    # Training loop.
    training_steps_max = training_steps

    # Print the local time of beginning training
    beg_time = datetime.datetime.now()
    print("Beginning time : " + str(beg_time))

    for training_step in xrange(start_step, training_steps_max + 1):

        # Pull the audio samples we'll use for training.
        train_fingerprints, train_ground_truth = audio_processor.get_data(
            FLAGS.batch_size, 0, model_settings, FLAGS.background_frequency,
            FLAGS.background_volume, time_shift_samples, 'training', sess)

        # Run the graph with this batch of training data.
        train_summary, train_accuracy, cross_entropy_value, _, _ = sess.run(
            [
                merged_summaries, evaluation_step, cross_entropy_mean,
                train_step, increment_global_step
            ],
            feed_dict={
                fingerprint_input: train_fingerprints,
                ground_truth_input: train_ground_truth,
                dropout_prob: 0.5
            })

        train_writer.add_summary(train_summary, training_step)
        tf.logging.info(
            'Step #%d: rate %f, accuracy %.1f%%, cross entropy %f' %
            (training_step, learning_rate, train_accuracy * 100,
             cross_entropy_value))
        is_last_step = (training_step == training_steps_max)

        # Validation
        if (training_step % FLAGS.eval_step_interval) == 0 or is_last_step:

            set_size = audio_processor.set_size('validation')
            total_accuracy = 0
            total_conf_matrix = None

            for i in xrange(0, set_size, FLAGS.batch_size):

                validation_fingerprints, validation_ground_truth = (
                    audio_processor.get_data(FLAGS.batch_size, i,
                                             model_settings, 0.0, 0.0, 0,
                                             'validation', sess))

                # Run a validation step and capture training summaries for TensorBoard
                # with the `merged` op.
                validation_summary, validation_accuracy, conf_matrix = sess.run(
                    [merged_summaries, evaluation_step, confusion_matrix],
                    feed_dict={
                        fingerprint_input: validation_fingerprints,
                        ground_truth_input: validation_ground_truth,
                        dropout_prob: 1.0
                    })

                validation_writer.add_summary(validation_summary,
                                              training_step)
                batch_size = min(FLAGS.batch_size, set_size - i)
                total_accuracy += (validation_accuracy * batch_size) / set_size

                if total_conf_matrix is None:
                    total_conf_matrix = conf_matrix
                else:
                    total_conf_matrix += conf_matrix

            tf.logging.info('Confusion Matrix:\n %s' % (total_conf_matrix))
            tf.logging.info('Step %d: Validation accuracy = %.1f%% (N=%d)' %
                            (training_step, total_accuracy * 100, set_size))

        # Save the model checkpoint periodically.
        if (training_step % FLAGS.save_step_interval == 0
                or training_step == training_steps_max):
            checkpoint_path = os.path.join(FLAGS.train_dir,
                                           FLAGS.model_architecture + '.ckpt')

            tf.logging.info('Saving to "%s-%d"', checkpoint_path,
                            training_step)
            saver.save(sess, checkpoint_path, global_step=training_step)

    # Print the local times at which training began and ended
    print("Beginning time : " + str(beg_time))
    print("Ending time : " + str(datetime.datetime.now()))

    # -----------------------------------------------------------------------
    # ------------------------------Test-------------------------------------
    # -----------------------------------------------------------------------

    set_size = audio_processor.set_size('testing')
    tf.logging.info('set_size=%d', set_size)
    total_accuracy = 0
    total_conf_matrix = None

    for i in xrange(0, set_size, FLAGS.batch_size):

        test_fingerprints, test_ground_truth = audio_processor.get_data(
            FLAGS.batch_size, i, model_settings, 0.0, 0.0, 0, 'testing', sess)

        test_accuracy, conf_matrix = sess.run(
            [evaluation_step, confusion_matrix],
            feed_dict={
                fingerprint_input: test_fingerprints,
                ground_truth_input: test_ground_truth,
                dropout_prob: 1.0
            })

        batch_size = min(FLAGS.batch_size, set_size - i)
        total_accuracy += (test_accuracy * batch_size) / set_size

        if total_conf_matrix is None:
            total_conf_matrix = conf_matrix
        else:
            total_conf_matrix += conf_matrix

    tf.logging.info('Confusion Matrix:\n %s' % (total_conf_matrix))
    tf.logging.info('Final test accuracy = %.1f%% (N=%d)' %
                    (total_accuracy * 100, set_size))
Ejemplo n.º 26
  def test_export_savedmodel(self):
    tmpdir = tempfile.mkdtemp()
    est, serving_input_fn = _build_estimator_for_export_tests(tmpdir)

    extra_file_name = os.path.join(
        compat.as_bytes(tmpdir), compat.as_bytes('my_extra_file'))
    extra_file = gfile.GFile(extra_file_name, mode='w')
    extra_file.write(EXTRA_FILE_CONTENT)
    extra_file.close()
    assets_extra = {'some/sub/directory/my_extra_file': extra_file_name}

    export_dir_base = os.path.join(
        compat.as_bytes(tmpdir), compat.as_bytes('export'))
    export_dir = est.export_savedmodel(
        export_dir_base, serving_input_fn, assets_extra=assets_extra)

    self.assertTrue(gfile.Exists(export_dir_base))
    self.assertTrue(gfile.Exists(export_dir))
    self.assertTrue(
        gfile.Exists(
            os.path.join(
                compat.as_bytes(export_dir), compat.as_bytes(
                    'saved_model.pb'))))
    self.assertTrue(
        gfile.Exists(
            os.path.join(
                compat.as_bytes(export_dir), compat.as_bytes('variables'))))
    self.assertTrue(
        gfile.Exists(
            os.path.join(
                compat.as_bytes(export_dir),
                compat.as_bytes('variables/variables.index'))))
    self.assertTrue(
        gfile.Exists(
            os.path.join(
                compat.as_bytes(export_dir),
                compat.as_bytes('variables/variables.data-00000-of-00001'))))

    self.assertTrue(
        gfile.Exists(
            os.path.join(
                compat.as_bytes(export_dir), compat.as_bytes('assets'))))
    self.assertTrue(
        gfile.Exists(
            os.path.join(
                compat.as_bytes(export_dir),
                compat.as_bytes('assets/my_vocab_file'))))
    self.assertEqual(
        compat.as_bytes(VOCAB_FILE_CONTENT),
        compat.as_bytes(
            gfile.GFile(
                os.path.join(
                    compat.as_bytes(export_dir),
                    compat.as_bytes('assets/my_vocab_file'))).read()))

    expected_extra_path = os.path.join(
        compat.as_bytes(export_dir),
        compat.as_bytes('assets.extra/some/sub/directory/my_extra_file'))
    self.assertTrue(
        gfile.Exists(
            os.path.join(
                compat.as_bytes(export_dir), compat.as_bytes('assets.extra'))))
    self.assertTrue(gfile.Exists(expected_extra_path))
    self.assertEqual(
        compat.as_bytes(EXTRA_FILE_CONTENT),
        compat.as_bytes(gfile.GFile(expected_extra_path).read()))

    expected_vocab_file = os.path.join(
        compat.as_bytes(tmpdir), compat.as_bytes('my_vocab_file'))
    # Restore, to validate that the export was well-formed.
    with ops.Graph().as_default() as graph:
      with session_lib.Session(graph=graph) as sess:
        loader.load(sess, [tag_constants.SERVING], export_dir)
        assets = [
            x.eval()
            for x in graph.get_collection(ops.GraphKeys.ASSET_FILEPATHS)
        ]
        self.assertItemsEqual([expected_vocab_file], assets)
        graph_ops = [x.name for x in graph.get_operations()]
        self.assertTrue('input_example_tensor' in graph_ops)
        self.assertTrue('ParseExample/ParseExample' in graph_ops)
        self.assertTrue('linear/linear/feature/matmul' in graph_ops)

    # cleanup
    gfile.DeleteRecursively(tmpdir)
Ejemplo n.º 27
def main(_):
    # We want to see all the logging messages for this tutorial.
    tf.logging.set_verbosity(tf.logging.INFO)

    # Start a new TensorFlow session.
    sess = tf.InteractiveSession()

    # Begin by making sure we have the training data we need. If you already have
    # training data of your own, use `--data_url= ` on the command line to avoid
    # downloading.
    model_settings = models.prepare_model_settings(
        len(input_data.prepare_words_list(FLAGS.wanted_words.split(','))),
        FLAGS.sample_rate, FLAGS.clip_duration_ms, FLAGS.window_size_ms,
        FLAGS.window_stride_ms, FLAGS.dct_coefficient_count)
    audio_processor = input_data.AudioProcessor(FLAGS.data_url, FLAGS.data_dir,
                                                FLAGS.silence_percentage,
                                                FLAGS.unknown_percentage,
                                                FLAGS.wanted_words.split(','),
                                                FLAGS.validation_percentage,
                                                FLAGS.testing_percentage,
                                                model_settings)
    fingerprint_size = model_settings['fingerprint_size']
    label_count = model_settings['label_count']
    time_shift_samples = int((FLAGS.time_shift_ms * FLAGS.sample_rate) / 1000)
    # Figure out the learning rates for each training phase. Since it's often
    # effective to have high learning rates at the start of training, followed by
    # lower levels towards the end, the number of steps and learning rates can be
    # specified as comma-separated lists to define the rate at each stage. For
    # example --how_many_training_steps=10000,3000 --learning_rate=0.001,0.0001
    # will run 13,000 training loops in total, with a rate of 0.001 for the first
    # 10,000, and 0.0001 for the final 3,000.
    training_steps_list = list(
        map(int, FLAGS.how_many_training_steps.split(',')))
    learning_rates_list = list(map(float, FLAGS.learning_rate.split(',')))
    if len(training_steps_list) != len(learning_rates_list):
        raise Exception(
            '--how_many_training_steps and --learning_rate must be equal length '
            'lists, but are %d and %d long instead' %
            (len(training_steps_list), len(learning_rates_list)))

    fingerprint_input = tf.placeholder(tf.float32, [None, fingerprint_size],
                                       name='fingerprint_input')

    logits, dropout_prob = models.create_model(fingerprint_input,
                                               model_settings,
                                               FLAGS.model_architecture,
                                               is_training=True)

    # Define loss and optimizer
    ground_truth_input = tf.placeholder(tf.int64, [None],
                                        name='groundtruth_input')

    # Optionally we can add runtime checks to spot when NaNs or other symptoms of
    # numerical errors start occurring during training.
    control_dependencies = []
    if FLAGS.check_nans:
        checks = tf.add_check_numerics_ops()
        control_dependencies = [checks]

    # Create the back propagation and training evaluation machinery in the graph.
    with tf.name_scope('cross_entropy'):
        cross_entropy_mean = tf.losses.sparse_softmax_cross_entropy(
            labels=ground_truth_input, logits=logits)
    tf.summary.scalar('cross_entropy', cross_entropy_mean)
    with tf.name_scope('train'), tf.control_dependencies(control_dependencies):
        learning_rate_input = tf.placeholder(tf.float32, [],
                                             name='learning_rate_input')
        train_step = tf.train.GradientDescentOptimizer(
            learning_rate_input).minimize(cross_entropy_mean)
    predicted_indices = tf.argmax(logits, 1)
    correct_prediction = tf.equal(predicted_indices, ground_truth_input)
    confusion_matrix = tf.confusion_matrix(ground_truth_input,
                                           predicted_indices,
                                           num_classes=label_count)
    evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    tf.summary.scalar('accuracy', evaluation_step)

    global_step = tf.train.get_or_create_global_step()
    increment_global_step = tf.assign(global_step, global_step + 1)

    saver = tf.train.Saver(tf.global_variables())

    # Merge all the summaries and write them out to /tmp/retrain_logs (by default)
    merged_summaries = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/train',
                                         sess.graph)
    validation_writer = tf.summary.FileWriter(FLAGS.summaries_dir +
                                              '/validation')

    tf.global_variables_initializer().run()

    start_step = 1

    if FLAGS.start_checkpoint:
        models.load_variables_from_checkpoint(sess, FLAGS.start_checkpoint)
        start_step = global_step.eval(session=sess)

    tf.logging.info('Training from step: %d ', start_step)

    # Save graph.pbtxt.
    tf.train.write_graph(sess.graph_def, FLAGS.train_dir,
                         FLAGS.model_architecture + '.pbtxt')

    # Save list of words.
    with gfile.GFile(
            os.path.join(FLAGS.train_dir,
                         FLAGS.model_architecture + '_labels.txt'), 'w') as f:
        f.write('\n'.join(audio_processor.words_list))

    # Training loop.
    training_steps_max = np.sum(training_steps_list)
    for training_step in xrange(start_step, training_steps_max + 1):
        # Figure out what the current learning rate is.
        training_steps_sum = 0
        for i in range(len(training_steps_list)):
            training_steps_sum += training_steps_list[i]
            if training_step <= training_steps_sum:
                learning_rate_value = learning_rates_list[i]
                break
        # Pull the audio samples we'll use for training.
        train_fingerprints, train_ground_truth = audio_processor.get_data(
            FLAGS.batch_size, 0, model_settings, FLAGS.background_frequency,
            FLAGS.background_volume, time_shift_samples, 'training', sess)
        # Run the graph with this batch of training data.
        train_summary, train_accuracy, cross_entropy_value, _, _ = sess.run(
            [
                merged_summaries, evaluation_step, cross_entropy_mean,
                train_step, increment_global_step
            ],
            feed_dict={
                fingerprint_input: train_fingerprints,
                ground_truth_input: train_ground_truth,
                learning_rate_input: learning_rate_value,
                dropout_prob: 0.5
            })
        train_writer.add_summary(train_summary, training_step)
        tf.logging.info(
            'Step #%d: rate %f, accuracy %.1f%%, cross entropy %f' %
            (training_step, learning_rate_value, train_accuracy * 100,
             cross_entropy_value))
        is_last_step = (training_step == training_steps_max)
        if (training_step % FLAGS.eval_step_interval) == 0 or is_last_step:
            set_size = audio_processor.set_size('validation')
            total_accuracy = 0
            total_conf_matrix = None
            for i in xrange(0, set_size, FLAGS.batch_size):
                validation_fingerprints, validation_ground_truth = (
                    audio_processor.get_data(FLAGS.batch_size, i,
                                             model_settings, 0.0, 0.0, 0,
                                             'validation', sess))
                # Run a validation step and capture training summaries for TensorBoard
                # with the `merged` op.
                validation_summary, validation_accuracy, conf_matrix = sess.run(
                    [merged_summaries, evaluation_step, confusion_matrix],
                    feed_dict={
                        fingerprint_input: validation_fingerprints,
                        ground_truth_input: validation_ground_truth,
                        dropout_prob: 1.0
                    })
                validation_writer.add_summary(validation_summary,
                                              training_step)
                batch_size = min(FLAGS.batch_size, set_size - i)
                total_accuracy += (validation_accuracy * batch_size) / set_size
                if total_conf_matrix is None:
                    total_conf_matrix = conf_matrix
                else:
                    total_conf_matrix += conf_matrix
            tf.logging.info('Confusion Matrix:\n %s' % (total_conf_matrix))
            tf.logging.info('Step %d: Validation accuracy = %.1f%% (N=%d)' %
                            (training_step, total_accuracy * 100, set_size))

        # Save the model checkpoint periodically.
        if (training_step % FLAGS.save_step_interval == 0
                or training_step == training_steps_max):
            checkpoint_path = os.path.join(FLAGS.train_dir,
                                           FLAGS.model_architecture + '.ckpt')
            tf.logging.info('Saving to "%s-%d"', checkpoint_path,
                            training_step)
            saver.save(sess, checkpoint_path, global_step=training_step)

    set_size = audio_processor.set_size('testing')
    tf.logging.info('set_size=%d', set_size)
    total_accuracy = 0
    total_conf_matrix = None
    for i in xrange(0, set_size, FLAGS.batch_size):
        test_fingerprints, test_ground_truth = audio_processor.get_data(
            FLAGS.batch_size, i, model_settings, 0.0, 0.0, 0, 'testing', sess)
        test_accuracy, conf_matrix = sess.run(
            [evaluation_step, confusion_matrix],
            feed_dict={
                fingerprint_input: test_fingerprints,
                ground_truth_input: test_ground_truth,
                dropout_prob: 1.0
            })
        batch_size = min(FLAGS.batch_size, set_size - i)
        total_accuracy += (test_accuracy * batch_size) / set_size
        if total_conf_matrix is None:
            total_conf_matrix = conf_matrix
        else:
            total_conf_matrix += conf_matrix
    tf.logging.info('Confusion Matrix:\n %s' % (total_conf_matrix))
    tf.logging.info('Final test accuracy = %.1f%% (N=%d)' %
                    (total_accuracy * 100, set_size))
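The comma-separated --how_many_training_steps / --learning_rate convention explained in the comments near the top of this example amounts to a piecewise-constant schedule. A self-contained sketch of that lookup, with hypothetical flag values:

# Hypothetical flags: 10,000 steps at 0.001 followed by 3,000 steps at 0.0001.
training_steps_list = [10000, 3000]
learning_rates_list = [0.001, 0.0001]

def learning_rate_for(step):
    """Return the learning rate in effect at a given 1-based training step."""
    steps_so_far = 0
    for steps, rate in zip(training_steps_list, learning_rates_list):
        steps_so_far += steps
        if step <= steps_so_far:
            return rate
    return learning_rates_list[-1]  # past the schedule: keep the final rate

print(learning_rate_for(1))      # 0.001
print(learning_rate_for(10000))  # 0.001
print(learning_rate_for(10001))  # 0.0001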
Ejemplo n.º 28
def freeze_graph_with_def_protos(input_graph_def,
                                 input_saver_def,
                                 input_checkpoint,
                                 output_node_names,
                                 restore_op_name,
                                 filename_tensor_name,
                                 output_graph,
                                 clear_devices,
                                 initializer_nodes,
                                 variable_names_whitelist="",
                                 variable_names_blacklist="",
                                 input_meta_graph_def=None,
                                 input_saved_model_dir=None,
                                 saved_model_tags=None,
                                 checkpoint_version=saver_pb2.SaverDef.V2):
    """Converts all variables in a graph and checkpoint into constants."""
    del restore_op_name, filename_tensor_name  # Unused by updated loading code.

    # 'input_checkpoint' may be a prefix if we're using Saver V2 format
    if (not input_saved_model_dir
            and not saver_lib.checkpoint_exists(input_checkpoint)):
        print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
        return -1

    if not output_node_names:
        print("You need to supply the name of a node to --output_node_names.")
        return -1

    # Remove all the explicit device specifications from the graph's nodes.
    # This helps to make the graph more portable.
    if clear_devices:
        if input_meta_graph_def:
            for node in input_meta_graph_def.graph_def.node:
                node.device = ""
        elif input_graph_def:
            for node in input_graph_def.node:
                node.device = ""
    if input_graph_def:
        _ = importer.import_graph_def(input_graph_def, name="")
    with session.Session() as sess:
        if input_saver_def:
            saver = saver_lib.Saver(saver_def=input_saver_def,
                                    write_version=checkpoint_version)
            saver.restore(sess, input_checkpoint)
        elif input_meta_graph_def:
            restorer = saver_lib.import_meta_graph(input_meta_graph_def,
                                                   clear_devices=True)
            restorer.restore(sess, input_checkpoint)
            if initializer_nodes:
                sess.run(initializer_nodes.replace(" ", "").split(","))
        elif input_saved_model_dir:
            if saved_model_tags is None:
                saved_model_tags = []
            loader.load(sess, saved_model_tags, input_saved_model_dir)
        else:
            var_list = {}
            reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
            # ruiyingy:start
            # checkpointkey_to_node = {}
            # object_graph_string = reader.get_tensor(base.OBJECT_GRAPH_PROTO_KEY)
            # object_graph_proto = trackable_object_graph_pb2.TrackableObjectGraph()
            # object_graph_proto.ParseFromString(object_graph_string)
            # for node in object_graph_proto.nodes:
            #   for attr in node.attributes:
            #     checkpointkey_to_node[attr.checkpoint_key] = attr.full_name
            # ruiyingy:end
            var_to_shape_map = reader.get_variable_to_shape_map()
            for key in var_to_shape_map:
                try:
                    # ruiyingy
                    # key = checkpointkey_to_node[key]
                    tensor = sess.graph.get_tensor_by_name(key + ":0")
                except KeyError:
                    # This tensor doesn't exist in the graph (for example it's
                    # 'global_step' or a similar housekeeping element) so skip it.
                    continue
                var_list[key] = tensor
            saver = saver_lib.Saver(var_list=var_list,
                                    write_version=checkpoint_version)
            saver.restore(sess, input_checkpoint)
            if initializer_nodes:
                sess.run(initializer_nodes.replace(" ", "").split(","))

        variable_names_whitelist = (variable_names_whitelist.replace(
            " ", "").split(",") if variable_names_whitelist else None)
        variable_names_blacklist = (variable_names_blacklist.replace(
            " ", "").split(",") if variable_names_blacklist else None)

        if input_meta_graph_def:
            output_graph_def = graph_util.convert_variables_to_constants(
                sess,
                input_meta_graph_def.graph_def,
                output_node_names.replace(" ", "").split(","),
                variable_names_whitelist=variable_names_whitelist,
                variable_names_blacklist=variable_names_blacklist)
        else:
            output_graph_def = graph_util.convert_variables_to_constants(
                sess,
                input_graph_def,
                output_node_names.replace(" ", "").split(","),
                variable_names_whitelist=variable_names_whitelist,
                variable_names_blacklist=variable_names_blacklist)

    # Write GraphDef to file if output path has been given.
    if output_graph:
        with gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())
        #with gfile.GFile(output_graph, "w") as f:
        #  f.write(text_format.MessageToString(output_graph_def))
        #tensorflow.train.write_graph(output_graph, "./", name, as_text=True)

    return output_graph_def
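For orientation, here is a minimal invocation sketch for the function above. The graph file, checkpoint prefix, and output node name are hypothetical placeholders, and the protobuf imports are assumed to be available alongside the TensorFlow 1.x modules this snippet already uses:

# Hypothetical usage sketch; paths and node names are placeholders.
from google.protobuf import text_format
from tensorflow.core.framework import graph_pb2

graph_def = graph_pb2.GraphDef()
with gfile.GFile("model/graph.pbtxt", "r") as f:       # hypothetical path
    text_format.Merge(f.read(), graph_def)

frozen_graph_def = freeze_graph_with_def_protos(
    input_graph_def=graph_def,
    input_saver_def=None,
    input_checkpoint="model/model.ckpt-1000",           # hypothetical checkpoint prefix
    output_node_names="logits",                         # hypothetical output node
    restore_op_name="",                                 # unused by the loading code
    filename_tensor_name="",                            # unused by the loading code
    output_graph="model/frozen_graph.pb",
    clear_devices=True,
    initializer_nodes="")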
Example No. 29
def create_vocabulary(vocabulary_path,
                      data_path,
                      max_vocabulary_size,
                      tokenizer=None,
                      normalize_digits=True,
                      _DIGIT_RE=re.compile(br"\d"),
                      _START_VOCAB=None):
    r"""Create vocabulary file (if it does not exist yet) from data file.

    Data file is assumed to contain one sentence per line. Each sentence is
    tokenized and digits are normalized (if normalize_digits is set).
    Vocabulary contains the most-frequent tokens up to max_vocabulary_size.
    We write it to vocabulary_path in a one-token-per-line format, so that later
    the token in the first line gets id=0, the token in the second line gets
    id=1, and so on.

    Parameters
    -----------
    vocabulary_path : str
        Path where the vocabulary will be created.
    data_path : str
        Data file that will be used to create vocabulary.
    max_vocabulary_size : int
        Limit on the size of the created vocabulary.
    tokenizer : function
        A function to use to tokenize each data sentence. If None, basic_tokenizer will be used.
    normalize_digits : boolean
        If true, all digits are replaced by `0`.
    _DIGIT_RE : compiled regular expression
        Default is ``re.compile(br"\d")``.
    _START_VOCAB : list of str
        The pad, go, eos and unk tokens; default is ``[b"_PAD", b"_GO", b"_EOS", b"_UNK"]``.

    References
    ----------
    - Code from ``/tensorflow/models/rnn/translation/data_utils.py``

    """
    if _START_VOCAB is None:
        _START_VOCAB = [b"_PAD", b"_GO", b"_EOS", b"_UNK"]
    if not gfile.Exists(vocabulary_path):
        tl.logging.info("Creating vocabulary %s from data %s" %
                        (vocabulary_path, data_path))
        vocab = {}
        with gfile.GFile(data_path, mode="rb") as f:
            counter = 0
            for line in f:
                counter += 1
                if counter % 100000 == 0:
                    tl.logging.info("  processing line %d" % counter)
                tokens = tokenizer(line) if tokenizer else basic_tokenizer(
                    line)
                for w in tokens:
                    word = re.sub(_DIGIT_RE, b"0",
                                  w) if normalize_digits else w
                    if word in vocab:
                        vocab[word] += 1
                    else:
                        vocab[word] = 1
            vocab_list = _START_VOCAB + sorted(
                vocab, key=vocab.get, reverse=True)
            if len(vocab_list) > max_vocabulary_size:
                vocab_list = vocab_list[:max_vocabulary_size]
            with gfile.GFile(vocabulary_path, mode="wb") as vocab_file:
                for w in vocab_list:
                    vocab_file.write(w + b"\n")
    else:
        tl.logging.info("Vocabulary %s from data %s exists" %
                        (vocabulary_path, data_path))
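A minimal usage sketch of the function above; the corpus path, output path, and vocabulary size are hypothetical:

# Hypothetical paths and size; relies on the basic_tokenizer fallback shown above.
create_vocabulary(
    vocabulary_path="data/vocab40000.en",  # output file, one token per line
    data_path="data/train.en",             # input corpus, one sentence per line
    max_vocabulary_size=40000,
    tokenizer=None,                        # falls back to basic_tokenizer
    normalize_digits=True)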
Example No. 30
def freeze_graph_with_def_protos(input_graph_def,
                                 input_saver_def,
                                 input_checkpoint,
                                 output_node_names,
                                 restore_op_name,
                                 filename_tensor_name,
                                 output_graph,
                                 clear_devices,
                                 initializer_nodes,
                                 variable_names_whitelist="",
                                 variable_names_blacklist="",
                                 input_meta_graph_def=None,
                                 input_saved_model_dir=None,
                                 saved_model_tags=None,
                                 checkpoint_version=saver_pb2.SaverDef.V2):
    """Converts all variables in a graph and checkpoint into constants.
  Args:
    input_graph_def: A `GraphDef`.
    input_saver_def: A `SaverDef` (optional).
    input_checkpoint: The prefix of a V1 or V2 checkpoint, with V2 taking
      priority.  Typically the result of `Saver.save()` or that of
      `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or
      V1/V2.
    output_node_names: The name(s) of the output nodes, comma separated.
    restore_op_name: Unused.
    filename_tensor_name: Unused.
    output_graph: String where to write the frozen `GraphDef`.
    clear_devices: A Bool whether to remove device specifications.
    initializer_nodes: Comma separated string of initializer nodes to run before
                       freezing.
    variable_names_whitelist: The set of variable names to convert (optional, by
                              default, all variables are converted).
    variable_names_blacklist: The set of variable names to omit converting
                              to constants (optional).
    input_meta_graph_def: A `MetaGraphDef` (optional),
    input_saved_model_dir: Path to the dir with TensorFlow 'SavedModel' file
                           and variables (optional).
    saved_model_tags: Group of comma separated tag(s) of the MetaGraphDef to
                      load, in string format (optional).
    checkpoint_version: Tensorflow variable file format (saver_pb2.SaverDef.V1
                        or saver_pb2.SaverDef.V2)
  Returns:
    Location of the output_graph_def.
  """
    del restore_op_name, filename_tensor_name  # Unused by updated loading code.

    # 'input_checkpoint' may be a prefix if we're using Saver V2 format
    if (not input_saved_model_dir
            and not checkpoint_management.checkpoint_exists(input_checkpoint)):
        print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
        return -1

    if not output_node_names:
        print("You need to supply the name of a node to --output_node_names.")
        return -1

    # Remove all the explicit device specifications from the graph's nodes.
    # This helps to make the graph more portable.
    if clear_devices:
        if input_meta_graph_def:
            for node in input_meta_graph_def.graph_def.node:
                node.device = ""
        elif input_graph_def:
            for node in input_graph_def.node:
                node.device = ""

    if input_graph_def:
        _ = importer.import_graph_def(input_graph_def, name="")
    with session.Session() as sess:
        if input_saver_def:
            saver = saver_lib.Saver(saver_def=input_saver_def,
                                    write_version=checkpoint_version)
            saver.restore(sess, input_checkpoint)
        elif input_meta_graph_def:
            restorer = saver_lib.import_meta_graph(input_meta_graph_def,
                                                   clear_devices=True)
            restorer.restore(sess, input_checkpoint)
            if initializer_nodes:
                sess.run(initializer_nodes.replace(" ", "").split(","))
        elif input_saved_model_dir:
            if saved_model_tags is None:
                saved_model_tags = []
            loader.load(sess, saved_model_tags, input_saved_model_dir)
        else:
            var_list = {}
            reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
            var_to_shape_map = reader.get_variable_to_shape_map()

            # List of all partition variables. Because the condition is
            # heuristic-based, the list could include false positives.
            all_partition_variable_names = [
                tensor.name.split(":")[0]
                for op in sess.graph.get_operations()
                for tensor in op.values()
                if re.search(r"/part_\d+/", tensor.name)
            ]
            has_partition_var = False

            for key in var_to_shape_map:
                try:
                    tensor = sess.graph.get_tensor_by_name(key + ":0")
                    if any(key in name
                           for name in all_partition_variable_names):
                        has_partition_var = True
                except KeyError:
                    # This tensor doesn't exist in the graph (for example it's
                    # 'global_step' or a similar housekeeping element) so skip it.
                    continue
                var_list[key] = tensor

            try:
                saver = saver_lib.Saver(var_list=var_list,
                                        write_version=checkpoint_version)
            except TypeError as e:
                # `var_list` is required to be a map of variable names to Variable
                # tensors. Partition variables are Identity tensors that cannot be
                # handled by Saver.
                if has_partition_var:
                    print(
                        "Models containing partition variables cannot be converted "
                        "from checkpoint files. Please pass in a SavedModel using "
                        "the flag --input_saved_model_dir.")
                    return -1
                # Models that have been frozen previously do not contain Variables.
                elif _has_no_variables(sess):
                    print(
                        "No variables were found in this model. It is likely the model "
                        "was frozen previously. You cannot freeze a graph twice."
                    )
                    return 0
                else:
                    raise e

            saver.restore(sess, input_checkpoint)
            if initializer_nodes:
                sess.run(initializer_nodes.replace(" ", "").split(","))

        variable_names_whitelist = (variable_names_whitelist.replace(
            " ", "").split(",") if variable_names_whitelist else None)
        variable_names_blacklist = (variable_names_blacklist.replace(
            " ", "").split(",") if variable_names_blacklist else None)

        if input_meta_graph_def:
            output_graph_def = graph_util.convert_variables_to_constants(
                sess,
                input_meta_graph_def.graph_def,
                output_node_names.replace(" ", "").split(","),
                variable_names_whitelist=variable_names_whitelist,
                variable_names_blacklist=variable_names_blacklist)
        else:
            output_graph_def = graph_util.convert_variables_to_constants(
                sess,
                input_graph_def,
                output_node_names.replace(" ", "").split(","),
                variable_names_whitelist=variable_names_whitelist,
                variable_names_blacklist=variable_names_blacklist)

    # Write GraphDef to file if output path has been given.
    if output_graph:
        with gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())

    return output_graph_def
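The error handling above calls `_has_no_variables(sess)`, which is not part of this snippet. In the TensorFlow freeze_graph module it simply scans the graph for Variable ops; a sketch along those lines (an approximation, not necessarily the exact upstream code):

def _has_no_variables(sess):
    """Returns True if the session's graph contains no Variable ops."""
    for op in sess.graph.get_operations():
        # Variable ops show up as e.g. "VariableV2" or "...VariableOp".
        if op.type.startswith("Variable") or op.type.endswith("VariableOp"):
            return False
    return True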