Ejemplo n.º 1
0
    def after_run(self, run_context, run_values):
        """Write training-metric summaries for the step that just ran.

        Pulls ``predictions``, ``labels`` and ``step`` out of
        ``run_values.results``, scores the predictions with Hit@1, PERR and
        GAP, and adds one scalar summary per metric to ``self.writer`` at
        that step. ``run_context`` is not used.
        """
        results = run_values.results
        preds = results['predictions']
        truth = results['labels']
        current_step = results['step']

        # Compute every metric first (Hit@1, PERR, GAP), then emit them.
        tagged_metrics = (
            ("model/Training_Hit@1",
             eval_util.calculate_hit_at_one(preds, truth)),
            ("model/Training_Perr",
             eval_util.calculate_precision_at_equal_recall_rate(preds, truth)),
            ("model/Training_GAP",
             eval_util.calculate_gap(preds, truth)),
        )
        for tag, value in tagged_metrics:
            self.writer.add_summary(utils.MakeSummary(tag, value),
                                    current_step)
Ejemplo n.º 2
0
def supervised_tasks(self, sv, res):
  """Computes training metrics for one step and optionally writes summaries.

  Args:
    self: The trainer. Reads ``self.is_chief`` and ``self.config.train_dir``
      and is forwarded to ``eval_util.transform_preds``.
    sv: Supervisor whose ``summary_writer`` receives the scalar summaries.
    res: Mapping of fetched values; must contain "global_step",
      "predictions", "dense_labels", "loss", "global_norm" and
      "examples_per_second".

  Returns:
    A dict mapping display names ("Training step", "Hit@1", "PERR", "GAP",
    "Loss", "Global norm", "Exps/sec") to this step's values.
  """
  global_step = res["global_step"]
  predictions = res["predictions"]
  # A list of prediction arrays is merged into a single array first
  # (presumably one entry per tower/sub-model -- verify against
  # eval_util.transform_preds).
  if isinstance(predictions, list):
    predictions = eval_util.transform_preds(self, predictions)
  dense_labels = res["dense_labels"]

  hit_at_one = eval_util.calculate_hit_at_one(predictions, dense_labels)
  perr = eval_util.calculate_precision_at_equal_recall_rate(predictions,
                                                            dense_labels)
  gap = eval_util.calculate_gap(predictions, dense_labels)

  log_info = {
      "Training step": global_step,
      "Hit@1": hit_at_one,
      "PERR": perr,
      "GAP": gap,
      "Loss": res["loss"],
      "Global norm": res["global_norm"],
      "Exps/sec": res["examples_per_second"],
  }

  # Only the chief writes summaries, every 10th step, and only when a
  # training directory is configured.
  if self.is_chief and global_step % 10 == 0 and self.config.train_dir:
    summaries = (
        ("model/Training_Hit@1", hit_at_one),
        ("model/Training_Perr", perr),
        ("model/Training_GAP", gap),
        ("global_step/Examples/Second", res["examples_per_second"]),
    )
    for tag, value in summaries:
      sv.summary_writer.add_summary(utils.MakeSummary(tag, value),
                                    global_step)
    sv.summary_writer.flush()
  return log_info
Ejemplo n.º 3
0
  def run(self, start_new_model=False):
    """Performs training on the currently defined Tensorflow graph.

    Builds the training graph, or recovers it from the checkpoint found in
    ``self.train_dir``. It then runs ``train_op`` under a
    ``tf.train.Supervisor`` until one of these happens:

    * the supervisor asks to stop,
    * ``self.max_steps`` is reached, or
    * the input raises ``tf.errors.OutOfRangeError``.

    Every ``FLAGS.validate_every_n_training_steps`` steps, the master does
    the following:

    * computes Hit@1/PERR/GAP against the ``original_labels`` collection;
    * logs those metrics and writes them as summaries;
    * writes an empty marker file whose name records the step and GAP;
    * exports the model once ``self.export_model_steps`` steps have passed
      since the last export.

    Args:
      start_new_model: If True, the master removes ``self.train_dir`` before
        training so no previous checkpoint is reused.

    Returns:
      None; metrics are only logged and written as summaries.

    Raises:
      SystemExit: If ``model_flags.json`` in the training directory records
        different model flags than the current run.
    """
    if self.is_master and start_new_model:
      self.remove_training_directory(self.train_dir)

    if not os.path.exists(self.train_dir):
      os.makedirs(self.train_dir)

    model_flags_dict = {
        "model": FLAGS.model,
        "feature_sizes": FLAGS.feature_sizes,
        "feature_names": FLAGS.feature_names,
        "frame_features": FLAGS.frame_features,
        "label_loss": FLAGS.label_loss,
    }
    # Keep the flags file in self.train_dir, the directory ensured above.
    # FLAGS.train_dir need not exist if the trainer was given another dir.
    flags_json_path = os.path.join(self.train_dir, "model_flags.json")
    if os.path.exists(flags_json_path):
      with open(flags_json_path) as fin:
        existing_flags = json.load(fin)
      if existing_flags != model_flags_dict:
        logging.error("Model flags do not match existing file %s. Please "
                      "delete the file, change --train_dir, or pass flag "
                      "--start_new_model",
                      flags_json_path)
        logging.error("Ran model with flags: %s", str(model_flags_dict))
        logging.error("Previously ran with flags: %s", str(existing_flags))
        # Same effect as the site-provided exit(1), without relying on site.
        raise SystemExit(1)
    else:
      # First run in this directory: record the model flags.
      with open(flags_json_path, "w") as fout:
        fout.write(json.dumps(model_flags_dict))

    target, device_fn = self.start_server_if_distributed()

    meta_filename = self.get_meta_filename(start_new_model, self.train_dir)

    with tf.Graph().as_default() as graph:
      if meta_filename:
        saver = self.recover_model(meta_filename)

      with tf.device(device_fn):
        if not meta_filename:
          saver = self.build_model(self.model, self.reader)

        global_step = tf.get_collection("global_step")[0]
        loss = tf.get_collection("loss")[0]
        predictions = tf.get_collection("predictions")[0]
        labels = tf.get_collection("labels")[0]
        # Presumably the labels before label sampling; the training metrics
        # below are computed against these.
        labels_origin = tf.get_collection("original_labels")[0]
        train_op = tf.get_collection("train_op")[0]
        init_op = tf.global_variables_initializer()

    # Checkpoints and summaries share one period, given in hours by flag.
    save_secs = int(FLAGS.save_checkpoint_every_n_hour * 3600)
    sv = tf.train.Supervisor(
        graph,
        logdir=self.train_dir,
        init_op=init_op,
        is_chief=self.is_master,
        global_step=global_step,
        save_model_secs=save_secs,
        save_summaries_secs=save_secs,
        saver=saver)
    logging.info("%s: Starting managed session.", task_as_string(self.task))
    with sv.managed_session(target, config=self.config) as sess:
      try:
        logging.info("%s: Entering training loop.", task_as_string(self.task))
        while (not sv.should_stop()) and (not self.max_steps_reached):
          batch_start_time = time.time()
          (_, global_step_val, loss_val, predictions_val, labels_val,
           labels_origin_val) = sess.run(
               [train_op, global_step, loss, predictions, labels,
                labels_origin])
          seconds_per_batch = time.time() - batch_start_time
          examples_per_second = labels_val.shape[0] / seconds_per_batch

          if self.max_steps and self.max_steps <= global_step_val:
            self.max_steps_reached = True

          # Only the master validates and exports, and only every
          # FLAGS.validate_every_n_training_steps steps.
          if not (self.is_master and
                  global_step_val % FLAGS.validate_every_n_training_steps == 0
                  and self.train_dir):
            continue

          hit_at_one = eval_util.calculate_hit_at_one(predictions_val,
                                                      labels_origin_val)
          perr = eval_util.calculate_precision_at_equal_recall_rate(
              predictions_val, labels_origin_val)
          gap = eval_util.calculate_gap(predictions_val, labels_origin_val)

          logging.info(
              "training step %s | Loss: %.2f Examples/sec: %.2f | Hit@1: %.2f"
              " PERR: %.2f GAP: %.2f", global_step_val, loss_val,
              examples_per_second, hit_at_one, perr, gap)

          for tag, value in (
              ("model/Training_Hit@1", hit_at_one),
              ("model/Training_Perr", perr),
              ("model/Training_GAP", gap),
              ("global_step/Examples/Second", examples_per_second)):
            sv.summary_writer.add_summary(utils.MakeSummary(tag, value),
                                          global_step_val)
          sv.summary_writer.flush()

          # Empty marker file whose name records the step and training GAP.
          # The braces are literal characters in the file name.
          marker_name = "global_step_{%d}_training_GAP_{%.6f}.txt" % (
              global_step_val, gap)
          with open(os.path.join(self.train_dir, marker_name), "w") as f:
            f.write("\n")

          # Export the model every self.export_model_steps steps (and on the
          # first validation step).
          time_to_export = (self.last_model_export_step == 0 or
                            global_step_val - self.last_model_export_step
                            >= self.export_model_steps)
          if time_to_export:
            self.export_model(global_step_val, sv.saver, sv.save_path, sess)
            self.last_model_export_step = global_step_val
      except tf.errors.OutOfRangeError:
        logging.info("%s: Done training -- epoch limit reached.",
                     task_as_string(self.task))

    logging.info("%s: Exited training loop.", task_as_string(self.task))
    sv.Stop()
Ejemplo n.º 4
0
    def run(self, start_new_model=False):
        """Performs training on the currently defined Tensorflow graph.

        Builds the training graph, or recovers it from the checkpoint found
        in ``self.train_dir``. It can also override the restart learning
        rate and the per-layer dropout keep probabilities. It then runs
        ``train_op`` under a ``tf.train.Supervisor`` until one of these
        happens:

        * the supervisor asks to stop,
        * ``self.max_steps`` is reached, or
        * the input raises ``tf.errors.OutOfRangeError``.

        Every 50 steps, the master does the following:

        * logs the training Hit@1/PERR/GAP and writes them as summaries;
        * runs the ``e_loss``/``e_predictions``/``e_labels`` tensors once and
          writes the corresponding ``Eval_*`` summaries;
        * exports the model once ``self.export_model_steps`` steps have
          passed since the last export.

        Args:
            start_new_model: If True, the master removes ``self.train_dir``
                before training.

        Returns:
            None; metrics are only logged and written as summaries.

        Raises:
            ValueError: If ``FLAGS.layers_keep_probs`` lists more than 10
                values.
        """
        # Length of the keep-probability vector in the graph ("max is 10").
        max_keep_probs = 10

        if self.is_master and start_new_model:
            self.remove_training_directory(self.train_dir)

        target, device_fn = self.start_server_if_distributed()

        meta_filename = self.get_meta_filename(start_new_model, self.train_dir)

        with tf.Graph().as_default() as graph:

            if meta_filename:
                saver = self.recover_model(meta_filename)

            with tf.device(device_fn):
                if not meta_filename:
                    saver = self.build_model(self.model, self.reader)

                global_step = tf.get_collection("global_step")[0]
                restart_learning_rate = tf.get_collection(
                    "restart_learning_rate")[0]
                layers_keep_probs = tf.get_collection("layers_keep_probs")[0]
                loss = tf.get_collection("loss")[0]
                predictions = tf.get_collection("predictions")[0]
                labels = tf.get_collection("labels")[0]
                train_op = tf.get_collection("train_op")[0]
                if FLAGS.use_ema:
                    # NOTE(review): ema_op is looked up but never run in this
                    # method -- confirm train_op already applies the EMA.
                    ema_op = tf.get_collection("ema_op")[0]

                # Tensors evaluated once per reporting step for Eval_* metrics.
                e_loss = tf.get_collection("e_loss")[0]
                e_labels = tf.get_collection("e_labels")[0]
                e_predictions = tf.get_collection("e_predictions")[0]

                init_op = tf.global_variables_initializer()
                restart_op = tf.assign(restart_learning_rate,
                                       FLAGS.restart_learning_rate)

                # Parse the comma-separated keep probabilities and pad them
                # with 1.0 (no dropout) to the fixed vector length.
                keep_probs = []
                if FLAGS.layers_keep_probs is not None:
                    keep_probs = [
                        float(x) for x in FLAGS.layers_keep_probs.replace(
                            ' ', '').split(',')
                    ]
                if len(keep_probs) > max_keep_probs:
                    raise ValueError(
                        "--layers_keep_probs accepts at most %d values, got %d"
                        % (max_keep_probs, len(keep_probs)))
                keep_probs_padded = keep_probs + [1.0] * (
                    max_keep_probs - len(keep_probs))
                with tf.variable_scope("tower", reuse=True):
                    keep_op = tf.assign(layers_keep_probs, keep_probs_padded)

        sv = tf.train.Supervisor(graph,
                                 logdir=self.train_dir,
                                 init_op=init_op,
                                 is_chief=self.is_master,
                                 global_step=global_step,
                                 save_model_secs=FLAGS.save_model_minutes * 60,
                                 save_summaries_secs=120,
                                 saver=saver)

        logging.info("%s: Starting managed session.",
                     task_as_string(self.task))
        with sv.managed_session(target, config=self.config) as sess:
            try:
                if FLAGS.restart_learning_rate > 0.0:
                    sess.run(restart_op)
                    logging.info("restart learning rate: %f\n",
                                 FLAGS.restart_learning_rate)
                # Any flag value other than this exact all-ones string
                # triggers an explicit assignment of the keep probabilities.
                if FLAGS.layers_keep_probs != "1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0":
                    logging.info("============")
                    sess.run(keep_op)
                    logging.info("layers keep probabilites: %s",
                                 FLAGS.layers_keep_probs)
                logging.info("%s: Entering training loop.",
                             task_as_string(self.task))
                while (not sv.should_stop()) and (not self.max_steps_reached):
                    batch_start_time = time.time()
                    _, global_step_val, loss_val, predictions_val, labels_val = sess.run(
                        [train_op, global_step, loss, predictions, labels])
                    seconds_per_batch = time.time() - batch_start_time
                    examples_per_second = labels_val.shape[
                        0] / seconds_per_batch

                    if self.max_steps and self.max_steps <= global_step_val:
                        self.max_steps_reached = True

                    if self.is_master and global_step_val % 50 == 0 and self.train_dir:
                        hit_at_one = eval_util.calculate_hit_at_one(
                            predictions_val, labels_val)
                        perr = eval_util.calculate_precision_at_equal_recall_rate(
                            predictions_val, labels_val)
                        gap = eval_util.calculate_gap(predictions_val,
                                                      labels_val)

                        logging.info(
                            "training step %s |  Loss: %.2f | Hit@1: %.2f"
                            "  PERR: %.2f  GAP: %.4f", global_step_val,
                            loss_val, hit_at_one, perr, gap)

                        for tag, value in (
                                ("model/Training_Hit@1", hit_at_one),
                                ("model/Training_Perr", perr),
                                ("model/Training_GAP", gap),
                                ("global_step/Examples/Second",
                                 examples_per_second)):
                            sv.summary_writer.add_summary(
                                utils.MakeSummary(tag, value), global_step_val)

                        # Also evaluate one batch from the e_* tensors.
                        e_loss_val, e_predictions_val, e_labels_val = sess.run(
                            [e_loss, e_predictions, e_labels])
                        e_hit_at_one = eval_util.calculate_hit_at_one(
                            e_predictions_val, e_labels_val)
                        e_perr = eval_util.calculate_precision_at_equal_recall_rate(
                            e_predictions_val, e_labels_val)
                        e_gap = eval_util.calculate_gap(
                            e_predictions_val, e_labels_val)
                        logging.info(
                            "training step %s | eLoss: %.2f |eHit@1: %.2f"
                            " ePERR: %.2f eGAP: %.4f", global_step_val,
                            e_loss_val, e_hit_at_one, e_perr, e_gap)

                        for tag, value in (
                                ("model/Eval_Hit@1", e_hit_at_one),
                                ("model/Eval_Perr", e_perr),
                                ("model/Eval_GAP", e_gap)):
                            sv.summary_writer.add_summary(
                                utils.MakeSummary(tag, value), global_step_val)

                        sv.summary_writer.flush()

                        # Export every self.export_model_steps steps (and on
                        # the first reporting step). We are already inside the
                        # master-only branch.
                        time_to_export = (
                            (self.last_model_export_step == 0)
                            or (global_step_val - self.last_model_export_step
                                >= self.export_model_steps))

                        if time_to_export:
                            self.export_model(global_step_val, sv.saver,
                                              sv.save_path, sess)
                            self.last_model_export_step = global_step_val
                    else:
                        logging.info(
                            "training step %s | Loss: %.2f Examples/sec: %.2f",
                            global_step_val, loss_val, examples_per_second)
            except tf.errors.OutOfRangeError:
                logging.info("%s: Done training -- epoch limit reached.",
                             task_as_string(self.task))

        logging.info("%s: Exited training loop.", task_as_string(self.task))
        sv.Stop()
Ejemplo n.º 5
0
  def run(self, start_new_model=False):
    """Performs training on the currently defined Tensorflow graph.

    Builds the training graph, or recovers it from the checkpoint found in
    ``self.train_dir``. When ``FLAGS.ema_source`` is set, it first
    initializes each variable in the ``updatable_vars`` collection, and its
    paired ``ema_vars`` entry, from that checkpoint. After each training
    step it also runs ``ema_op``. Training runs under a
    ``tf.train.Supervisor`` until the supervisor asks to stop or
    ``self.max_steps`` is reached.

    Every 10 steps, the master logs the training Hit@1/PERR/GAP and writes
    them as summaries. When the input raises ``tf.errors.OutOfRangeError``,
    a final checkpoint is saved to ``<train_dir>/model.ckpt``.

    Args:
      start_new_model: If True, the master removes ``self.train_dir`` before
        training so no previous checkpoint is reused.

    Returns:
      None; metrics are only logged and written as summaries.

    Raises:
      SystemExit: If ``model_flags.json`` in the training directory records
        different model flags than the current run.
    """
    if self.is_master and start_new_model:
      self.remove_training_directory(self.train_dir)

    if not os.path.exists(self.train_dir):
      os.makedirs(self.train_dir)

    model_flags_dict = {
        "model": FLAGS.model,
        "feature_sizes": FLAGS.feature_sizes,
        "feature_names": FLAGS.feature_names,
        "frame_features": FLAGS.frame_features,
        "label_loss": FLAGS.label_loss,
    }
    # Keep the flags file in self.train_dir, the directory ensured above.
    flags_json_path = os.path.join(self.train_dir, "model_flags.json")
    if os.path.exists(flags_json_path):
      with open(flags_json_path) as fin:
        existing_flags = json.load(fin)
      if existing_flags != model_flags_dict:
        logging.error("Model flags do not match existing file %s. Please "
                      "delete the file, change --train_dir, or pass flag "
                      "--start_new_model",
                      flags_json_path)
        logging.error("Ran model with flags: %s", str(model_flags_dict))
        logging.error("Previously ran with flags: %s", str(existing_flags))
        # Same effect as the site-provided exit(1), without relying on site.
        raise SystemExit(1)
    else:
      # First run in this directory: record the model flags.
      with open(flags_json_path, "w") as fout:
        fout.write(json.dumps(model_flags_dict))

    target, device_fn = self.start_server_if_distributed()

    meta_filename = self.get_meta_filename(start_new_model, self.train_dir)

    with tf.Graph().as_default() as graph:
      if meta_filename:
        saver = self.recover_model(meta_filename)

      with tf.device(device_fn):
        if not meta_filename:
          saver = self.build_model(self.model, self.reader)

        global_step = tf.get_collection("global_step")[0]
        loss = tf.get_collection("loss")[0]
        predictions = tf.get_collection("predictions")[0]
        labels = tf.get_collection("labels")[0]
        train_op = tf.get_collection("train_op")[0]
        init_op = tf.global_variables_initializer()
        if FLAGS.ema_source:
          # Look up the EMA op and the paired (variable, shadow) lists while
          # this graph is still the default graph.
          ema_op = tf.get_collection("ema_op")[0]
          def_vars = tf.get_collection("updatable_vars")
          ema_vars = tf.get_collection("ema_vars")

    sv = tf.train.Supervisor(
        graph,
        logdir=self.train_dir,
        init_op=init_op,
        is_chief=self.is_master,
        global_step=global_step,
        save_model_secs=15 * 60,
        save_summaries_secs=120,
        saver=saver)

    logging.info("%s: Starting managed session.", task_as_string(self.task))
    with sv.managed_session(target, config=self.config) as sess:

      if FLAGS.ema_source:
        logging.info("%s: Loading initial weights from %s.",
                     task_as_string(self.task), FLAGS.ema_source)
        # Re-open the graph the Supervisor finalized (private TF API).
        sess.graph._unsafe_unfinalize()
        ckpt_reader = pywrap_tensorflow.NewCheckpointReader(FLAGS.ema_source)
        for xtensor, ematensor in zip(def_vars, ema_vars):
          # Checkpoint keys are the variable names without the ":N" suffix.
          src_tensor = ckpt_reader.get_tensor(xtensor.name.split(":")[0])
          # Variable.load does not add new ops to the graph.
          xtensor.load(src_tensor, session=sess)
          ematensor.load(src_tensor, session=sess)
      try:
        logging.info("%s: Entering training loop.", task_as_string(self.task))
        while (not sv.should_stop()) and (not self.max_steps_reached):
          batch_start_time = time.time()
          _, global_step_val, loss_val, predictions_val, labels_val = sess.run(
              [train_op, global_step, loss, predictions, labels])
          if FLAGS.ema_source:  # Update the EMA shadows after each step.
            sess.run(ema_op)

          seconds_per_batch = time.time() - batch_start_time
          examples_per_second = labels_val.shape[0] / seconds_per_batch

          if self.max_steps and self.max_steps <= global_step_val:
            self.max_steps_reached = True

          if self.is_master and global_step_val % 10 == 0 and self.train_dir:
            hit_at_one = eval_util.calculate_hit_at_one(predictions_val,
                                                        labels_val)
            perr = eval_util.calculate_precision_at_equal_recall_rate(
                predictions_val, labels_val)
            gap = eval_util.calculate_gap(predictions_val, labels_val)

            logging.info(
                "training step %s | Loss: %.2f Examples/sec: %.2f | Hit@1: %.2f"
                " PERR: %.2f GAP: %.2f", global_step_val, loss_val,
                examples_per_second, hit_at_one, perr, gap)

            for tag, value in (
                ("model/Training_Hit@1", hit_at_one),
                ("model/Training_Perr", perr),
                ("model/Training_GAP", gap),
                ("global_step/Examples/Second", examples_per_second)):
              sv.summary_writer.add_summary(utils.MakeSummary(tag, value),
                                            global_step_val)
            sv.summary_writer.flush()

            # Export bookkeeping; the export itself is currently disabled.
            time_to_export = (self.last_model_export_step == 0 or
                              global_step_val - self.last_model_export_step
                              >= self.export_model_steps)
            if time_to_export:
              # self.export_model(global_step_val, sv.saver, sv.save_path, sess)
              self.last_model_export_step = global_step_val
          else:
            logging.info("training step %s | Loss: %.2f Examples/sec: %.2f",
                         global_step_val, loss_val, examples_per_second)
      except tf.errors.OutOfRangeError:
        # Save a final checkpoint once the input is exhausted. Passing the
        # global_step tensor, not the last fetched global_step_val, also
        # works when the very first sess.run raised. In that case
        # global_step_val was never assigned.
        save_name = os.path.join(self.train_dir, "model.ckpt")
        saver.save(sess, save_name, global_step=global_step)
        logging.info("Final model export.")
        logging.info("%s: Done training -- epoch limit reached.",
                     task_as_string(self.task))

    logging.info("%s: Exited training loop.", task_as_string(self.task))
    sv.Stop()
Ejemplo n.º 6
0
    def run(self, start_new_model=False):
        """Performs training on the currently defined Tensorflow graph.

        Either recovers the saved graph, or rebuilds the model from code.
        Rebuilding happens when there is no checkpoint or when
        ``FLAGS.change_file`` is set. It then restores ``latest_checkpoint``
        when one exists. Training runs under a ``tf.train.Supervisor`` until
        one of these happens:

        * the supervisor asks to stop,
        * ``self.max_steps`` is reached, or
        * the input raises ``tf.errors.OutOfRangeError``.

        Every ``FLAGS.eval_loop`` steps, the master logs the training
        Hit@1/PERR/GAP and writes them as summaries. It exports the model
        once ``self.export_model_steps`` steps have passed since the last
        export.

        Args:
            start_new_model: If True, the master removes ``self.train_dir``
                before training.

        Returns:
            None; metrics are only logged and written as summaries.
        """
        if self.is_master and start_new_model:
            self.remove_training_directory(self.train_dir)

        target, device_fn = self.start_server_if_distributed()

        latest_checkpoint, meta_filename = self.get_meta_filename(
            start_new_model, self.train_dir)

        # Reuse the saved graph only when it exists and FLAGS.change_file is
        # not asking for the model to be rebuilt from code.
        recover_graph = bool(meta_filename) and not FLAGS.change_file

        with tf.Graph().as_default() as graph:

            if recover_graph:
                saver = self.recover_model(meta_filename)

            with tf.device(device_fn):
                # Build the model exactly once when the graph is not
                # recovered (building twice would duplicate the collections).
                if not recover_graph:
                    saver = self.build_model(self.model, self.reader)

                global_step = tf.get_collection("global_step")[0]
                loss = tf.get_collection("loss")[0]
                predictions = tf.get_collection("predictions")[0]
                labels = tf.get_collection("labels")[0]
                train_op = tf.get_collection("train_op")[0]
                init_op = tf.global_variables_initializer()

        # FLAGS.time_to_save_model is in minutes.
        sv = tf.train.Supervisor(graph,
                                 logdir=self.train_dir,
                                 init_op=init_op,
                                 is_chief=self.is_master,
                                 global_step=global_step,
                                 save_model_secs=60 * FLAGS.time_to_save_model,
                                 save_summaries_secs=120,
                                 saver=saver)

        logging.info("%s: Starting managed session.",
                     task_as_string(self.task))
        with sv.managed_session(target, config=self.config) as sess:
            try:
                logging.info("%s: Entering training loop.",
                             task_as_string(self.task))

                # Saver.restore raises ValueError for a None path, so only
                # restore when a checkpoint was found.
                # NOTE(review): assumes get_meta_filename returns a falsy
                # latest_checkpoint when there is no checkpoint -- confirm.
                if latest_checkpoint:
                    logging.info("TANG:restoring")
                    saver.restore(sess, latest_checkpoint)

                while (not sv.should_stop()) and (not self.max_steps_reached):
                    batch_start_time = time.time()
                    _, global_step_val, loss_val, predictions_val, labels_val = sess.run(
                        [train_op, global_step, loss, predictions, labels])
                    seconds_per_batch = time.time() - batch_start_time
                    examples_per_second = labels_val.shape[
                        0] / seconds_per_batch

                    if self.max_steps and self.max_steps <= global_step_val:
                        self.max_steps_reached = True

                    if self.is_master and global_step_val % FLAGS.eval_loop == 0 and self.train_dir:
                        hit_at_one = eval_util.calculate_hit_at_one(
                            predictions_val, labels_val)
                        perr = eval_util.calculate_precision_at_equal_recall_rate(
                            predictions_val, labels_val)
                        gap = eval_util.calculate_gap(predictions_val,
                                                      labels_val)

                        logging.info(
                            "training step %s | Loss: %.2f Examples/sec: %.2f"
                            " | Hit@1: %.2f PERR: %.2f GAP: %.2f",
                            global_step_val, loss_val, examples_per_second,
                            hit_at_one, perr, gap)

                        for tag, value in (
                                ("model/Training_Hit@1", hit_at_one),
                                ("model/Training_Perr", perr),
                                ("model/Training_GAP", gap),
                                ("global_step/Examples/Second",
                                 examples_per_second)):
                            sv.summary_writer.add_summary(
                                utils.MakeSummary(tag, value), global_step_val)
                        sv.summary_writer.flush()

                        # Export every self.export_model_steps steps (and on
                        # the first reporting step). We are already inside the
                        # master-only branch.
                        time_to_export = (
                            (self.last_model_export_step == 0)
                            or (global_step_val - self.last_model_export_step
                                >= self.export_model_steps))

                        if time_to_export:
                            self.export_model(global_step_val, sv.saver,
                                              sv.save_path, sess)
                            self.last_model_export_step = global_step_val
                    else:
                        logging.info(
                            "training step %s | Loss: %.2f Examples/sec: %.2f",
                            global_step_val, loss_val, examples_per_second)
            except tf.errors.OutOfRangeError:
                logging.info("%s: Done training -- epoch limit reached.",
                             task_as_string(self.task))

        logging.info("%s: Exited training loop.", task_as_string(self.task))
        sv.Stop()
Ejemplo n.º 7
0
def train():
    """Train the video classifier on the CPU with Adam and periodic checkpoints.

    Builds the input pipeline with ``readers.input(True)`` and the model with
    ``inference(...)``, whose return value is used as a list of
    ``(gradient, variable)`` pairs; the loss and prediction tensors are read
    back from the ``"loss"`` and ``"predict"`` graph collections. Each step
    applies the gradients with Adam (exponentially decayed learning rate) and
    updates an exponential moving average of the trainable variables.

    Every 10 steps the batch loss, Hit@1, PERR and GAP are printed; every 100
    steps the merged summaries and those metrics are written to
    ``FLAGS.train_dir``; every 1000 steps (and on the final step) a checkpoint
    is saved with prefix ``FLAGS.train_dir/model.ckpt``.

    Raises:
      AssertionError: if the training loss becomes NaN.
      ValueError: if ``FLAGS.new_model`` is false but ``FLAGS.train_dir``
        contains no checkpoint to resume from.
    """
    with tf.Graph().as_default(), tf.device('/cpu:0'):
        global_step = tf.Variable(0, trainable=False)
        video_id, labels, rgb, audio, num_frames = readers.input(True)
        coord = tf.train.Coordinator()

        lr = tf.train.exponential_decay(FLAGS.lr,
                                        global_step,
                                        FLAGS.decay_steps,
                                        FLAGS.learning_decay_rate,
                                        staircase=True)
        tf.summary.scalar('learning_rate', lr)
        opt = tf.train.AdamOptimizer(lr,
                                     beta1=FLAGS.beta1,
                                     beta2=FLAGS.beta2,
                                     epsilon=1e-08,
                                     use_locking=False,
                                     name='Adam')
        grads = inference(rgb, audio, num_frames, label=labels, train=True)
        loss = tf.get_collection("loss")[0]
        predict = tf.get_collection("predict")[0]

        for var in tf.trainable_variables():
            tf.summary.histogram(var.op.name, var)

        # Dense gradients get a histogram; sparse (IndexedSlices) and missing
        # gradients are only reported on stdout.
        for grad, var in grads:
            print(var.op.name)
            if grad is None:
                print("There is a None gradient")
            elif isinstance(grad, tf.IndexedSlices):
                print("This is a indexslice gradient")
                print(grad.dense_shape)
            else:
                tf.summary.histogram(var.op.name + '/gradients', grad)

        apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)
        variable_averages = tf.train.ExponentialMovingAverage(
            FLAGS.moving_average_decay, global_step)
        variables_averages_op = variable_averages.apply(
            tf.trainable_variables())
        train_op = tf.group(apply_gradient_op, variables_averages_op)

        saver = tf.train.Saver(tf.global_variables())
        summary_op = tf.summary.merge_all()
        init = tf.group(tf.global_variables_initializer(),
                        tf.local_variables_initializer())

        config = tf.ConfigProto(log_device_placement=False,
                                allow_soft_placement=True)
        config.intra_op_parallelism_threads = 10
        config.inter_op_parallelism_threads = 16
        sess = tf.Session(config=config)
        sess.run(init)
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)

        # Everything below runs inside try/finally so the queue-runner
        # threads, the summary writer and the session are always released.
        try:
            if not FLAGS.new_model:
                ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)
                if not (ckpt and ckpt.model_checkpoint_path):
                    raise ValueError(
                        "No checkpoint found in %s to resume from; set "
                        "new_model to train from scratch." % FLAGS.train_dir)
                saver.restore(sess, ckpt.model_checkpoint_path)

            checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
            for step in range(FLAGS.max_steps):
                start_time = time.time()
                _, loss_value, predict_value, labels_value = sess.run(
                    [train_op, loss, predict, labels])
                duration = time.time() - start_time

                assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

                if step % 10 == 0:
                    examples_per_sec = FLAGS.batch_size / duration
                    sec_per_batch = float(duration)

                    hit_at_one = eval_util.calculate_hit_at_one(
                        predict_value, labels_value)
                    perr = eval_util.calculate_precision_at_equal_recall_rate(
                        predict_value, labels_value)
                    gap = eval_util.calculate_gap(predict_value, labels_value)

                    format_str = (
                        '%s: step %d, loss = %.2f, hit@one = %.2f, perr = %.2f, gap = %.2f, (%.1f examples/sec; %.3f '
                        'sec/batch)')
                    print(format_str %
                          (datetime.now(), step, loss_value, hit_at_one, perr,
                           gap, examples_per_sec, sec_per_batch))

                # Every multiple of 100 is also a multiple of 10, so the
                # metrics computed above belong to this very step.
                if step % 100 == 0:
                    summary_str = sess.run(summary_op)
                    summary_writer.add_summary(summary_str, step)
                    summary_writer.add_summary(
                        utils.MakeSummary("Hit@1", hit_at_one), step)
                    summary_writer.add_summary(
                        utils.MakeSummary("Perr", perr), step)
                    summary_writer.add_summary(
                        utils.MakeSummary("Gap", gap), step)
                    summary_writer.add_summary(
                        utils.MakeSummary("example per second",
                                          examples_per_sec), step)

                if ((step % 1000 == 0 and step != 0)
                        or step + 1 == FLAGS.max_steps):
                    saver.save(sess, checkpoint_path, global_step=step)
        finally:
            coord.request_stop()
            coord.join(threads, stop_grace_period_secs=10)
            summary_writer.close()
            sess.close()
Ejemplo n.º 8
0
    def run(self, start_new_model=False):
        """Run the training loop on the currently defined TensorFlow graph.

        The graph is recovered from the latest meta-graph in
        ``self.train_dir`` when one exists; otherwise it is built with
        ``self.build_model()``. Training then runs inside a
        ``tf.train.Supervisor`` managed session until the supervisor asks to
        stop or the input raises ``OutOfRangeError``. On the master, each step
        logs Hit@1, PERR, GAP and loss, and writes those metrics (plus
        examples/second) as summaries.

        Args:
          start_new_model: when true on the master, ``self.train_dir`` is
            removed first so that training starts from scratch.

        Returns:
          None; metrics are only logged and written as summaries.
        """
        if self.is_master and start_new_model:
            self.remove_training_directory(self.train_dir)

        target, device_fn = self.start_server_if_distributed()
        meta_filename = self.get_meta_filename(start_new_model, self.train_dir)

        with tf.Graph().as_default() as graph:
            # An existing meta-graph is imported outside the device scope;
            # a fresh graph is built inside it.
            if meta_filename:
                saver = self.recover_model(meta_filename)

            with tf.device(device_fn):
                if not meta_filename:
                    saver = self.build_model()

                global_step, loss, predictions, labels, train_op = [
                    tf.get_collection(key)[0]
                    for key in ("global_step", "loss", "predictions",
                                "labels", "train_op")
                ]
                init_op = tf.global_variables_initializer()

        sv = tf.train.Supervisor(graph,
                                 logdir=self.train_dir,
                                 init_op=init_op,
                                 is_chief=self.is_master,
                                 global_step=global_step,
                                 save_model_secs=15 * 60,
                                 save_summaries_secs=120,
                                 saver=saver)

        task_name = task_as_string(self.task)
        logging.info("%s: Starting managed session.", task_name)
        with sv.managed_session(target, config=self.config) as sess:
            try:
                logging.info("%s: Entering training loop.", task_name)
                while not sv.should_stop():
                    tic = time.time()
                    _, step_val, loss_val, predictions_val, labels_val = sess.run(
                        [train_op, global_step, loss, predictions, labels])
                    elapsed = time.time() - tic

                    # Only the master computes and reports metrics.
                    if not self.is_master:
                        continue

                    examples_per_second = labels_val.shape[0] / elapsed
                    hit_at_one = eval_util.calculate_hit_at_one(
                        predictions_val, labels_val)
                    perr = eval_util.calculate_precision_at_equal_recall_rate(
                        predictions_val, labels_val)
                    gap = eval_util.calculate_gap(predictions_val, labels_val)

                    scores = "Hit@1: %.2f PERR: %.2f GAP: %.2f" % (
                        hit_at_one, perr, gap)
                    logging.info(
                        "%s: training step " + str(step_val) + "| " + scores +
                        " Loss: " + str(loss_val), task_name)

                    writer = sv.summary_writer
                    for tag, value in (
                            ("model/Training_Hit@1", hit_at_one),
                            ("model/Training_Perr", perr),
                            ("model/Training_GAP", gap),
                            ("global_step/Examples/Second",
                             examples_per_second)):
                        writer.add_summary(utils.MakeSummary(tag, value),
                                           step_val)
                    writer.flush()
            except tf.errors.OutOfRangeError:
                logging.info("%s: Done training -- epoch limit reached.",
                             task_name)

        logging.info("%s: Exited training loop.", task_name)
        sv.Stop()
Ejemplo n.º 9
0
    def run(self, start_new_model=False):
        """Train a mixture-of-experts head on top of two frozen, pre-trained models.

        ``FLAGS.train_dir`` must be a comma-separated list of (at least) two
        directories holding the checkpoints of ``self.model[0]`` and
        ``self.model[1]``. Both models are built on the same L2-normalized
        input with ``is_training=False`` and wrapped in ``tf.stop_gradient``.
        Their L2-normalized outputs are concatenated and fed to a 4-mixture
        ``MoeModel``, so only that head receives gradient updates. Variables
        created before the first variable whose name contains ``'rnn'`` are
        restored from the first directory; the remaining pre-trained
        variables are restored from the second.

        Every ``self.disp_batches`` steps the master logs the batch metrics.
        It also writes them, plus the merged graph summaries, to
        ``FLAGS.ensemble_dir``. When the step is also a multiple of
        ``FLAGS.export_model_steps``, a checkpoint is saved using
        ``FLAGS.ensemble_dir`` as its path prefix. Training runs until the
        input is exhausted or the coordinator is asked to stop.

        Args:
          start_new_model: if true on the master, ``self.train_dir`` is
            removed before training.

        Returns:
          None; progress is reported through logs and summaries.

        Raises:
          ValueError: if ``FLAGS.train_dir`` lists fewer than two directories,
            or a checkpoint has to be restored but no ``'rnn'`` variable was
            found to split the two models' variables.
        """
        train_dirs = FLAGS.train_dir.split(',')
        if len(train_dirs) < 2:
            raise ValueError(
                "train_dir must list two comma-separated checkpoint "
                "directories, got: %s" % FLAGS.train_dir)

        if self.is_master and start_new_model:
            # NOTE(review): self.train_dir names the pre-trained models'
            # directories here; confirm that removing it is really intended.
            self.remove_training_directory(self.train_dir)

        # target/device_fn and the meta filenames are not used further: the
        # ensemble graph is built from scratch below.
        target, device_fn = self.start_server_if_distributed()
        meta_filename = []
        for filename in self.train_dir.split(','):
            logging.info("filename:%s", str(filename))
            meta_filename.append(
                self.get_meta_filename(start_new_model, filename))

        label_loss_fn = find_class_by_name(FLAGS.label_loss, [losses])()
        optimizer_class = find_class_by_name(FLAGS.optimizer, [tf.train])

        local_device_protos = device_lib.list_local_devices()
        gpus = [x.name for x in local_device_protos if x.device_type == 'GPU']
        if gpus:
            logging.info("Using the following GPUs to train: " + str(gpus))
            num_towers = len(gpus)
        else:
            logging.info("No GPUs found. Training on CPU.")
            num_towers = 1

        global_step = tf.Variable(0, trainable=False, name="global_step")
        learning_rate = tf.train.exponential_decay(
            FLAGS.base_learning_rate,
            global_step * FLAGS.batch_size * num_towers,
            FLAGS.learning_rate_decay_examples,
            FLAGS.learning_rate_decay,
            staircase=True)
        tf.summary.scalar('learning_rate', learning_rate)
        video_id_batch, model_input_raw, labels_batch, num_frames = get_input_data_tensors(
            # pylint: disable=g-line-too-long
            self.reader,
            FLAGS.train_data_pattern,
            batch_size=FLAGS.batch_size,
            num_readers=FLAGS.num_readers)
        tf.summary.histogram("model_input_raw", model_input_raw)

        feature_dim = len(model_input_raw.get_shape()) - 1

        # Normalize input features.
        model_input = tf.nn.l2_normalize(model_input_raw, feature_dim)
        with tf.variable_scope("tower"):
            # Both pre-trained models are frozen: no gradient flows into them.
            result1 = tf.stop_gradient(self.model[0].create_model(
                model_input,
                num_frames=num_frames,
                vocab_size=self.reader.num_classes,
                is_training=False))
            result2 = tf.stop_gradient(self.model[1].create_model(
                model_input,
                num_frames=num_frames,
                vocab_size=self.reader.num_classes,
                labels=labels_batch,
                is_training=False))

            # Variables created before the first 'rnn' variable are taken to
            # belong to model 0, the rest to model 1. The MoE head is created
            # afterwards and is therefore in neither list. Both stay None if
            # no 'rnn' variable exists.
            all_vars = tf.global_variables()
            vars1 = vars2 = None
            for i, v in enumerate(all_vars):
                logging.info(str(v.name))
                if 'rnn' in v.name:
                    vars1 = all_vars[:i]
                    vars2 = all_vars[i:]
                    break

            result1 = tf.nn.l2_normalize(result1, dim=1)
            result2 = tf.nn.l2_normalize(result2, dim=1)
            embeddings = tf.concat([result1, result2], axis=1)
            model_concat = find_class_by_name('MoeModel',
                                              [video_level_models])()
            result = model_concat.create_model(
                embeddings, vocab_size=self.reader.num_classes, num_mixtures=4)
            predictions = result["predictions"]
            tf.summary.histogram("model_activations", predictions)

            label_loss = label_loss_fn.calculate_loss(predictions,
                                                      labels_batch)
            tf.summary.scalar("label_loss", label_loss)
            if "regularization_loss" in result:
                reg_loss = result["regularization_loss"]
            else:
                reg_loss = tf.constant(0.0)
            final_loss = FLAGS.regularization_penalty * reg_loss + label_loss

            optimizer = optimizer_class(learning_rate)
            gradients = optimizer.compute_gradients(
                final_loss, colocate_gradients_with_ops=False)

            with tf.name_scope('clip_grads'):
                merged_gradients = utils.clip_gradient_norms(gradients, 1.0)
            train_op = optimizer.apply_gradients(merged_gradients,
                                                 global_step=global_step)

            # Publish the tensors for other tools. The tensors are used
            # directly below, because collections of the default graph might
            # already hold earlier entries.
            tf.add_to_collection("global_step", global_step)
            tf.add_to_collection("loss", label_loss)
            tf.add_to_collection("predictions", predictions)
            tf.add_to_collection("input_batch", model_input)
            tf.add_to_collection("video_id_batch", video_id_batch)
            tf.add_to_collection("num_frames", num_frames)
            labels_float = tf.cast(labels_batch, tf.float32)
            tf.add_to_collection("labels", labels_float)
            summary_op = tf.summary.merge_all()
            tf.add_to_collection("summary_op", summary_op)
            tf.add_to_collection("train_op", train_op)

            summary_writer = tf.summary.FileWriter(
                FLAGS.ensemble_dir, graph=tf.get_default_graph())

            config = tf.ConfigProto(allow_soft_placement=True,
                                    log_device_placement=False)
            config.gpu_options.allow_growth = True

            with tf.Session(config=config) as sess:
                latest_checkpoint0 = tf.train.latest_checkpoint(train_dirs[0])
                latest_checkpoint1 = tf.train.latest_checkpoint(train_dirs[1])
                sess.run(tf.global_variables_initializer())

                if (latest_checkpoint0 or latest_checkpoint1) and vars1 is None:
                    raise ValueError(
                        "No variable with 'rnn' in its name was found, so the "
                        "variables of the two pre-trained models cannot be "
                        "told apart for restoring.")

                if latest_checkpoint0:
                    logging.info("Loading checkpoint for eval: " +
                                 latest_checkpoint0)
                    tf.train.Saver(vars1).restore(sess, latest_checkpoint0)

                if latest_checkpoint1:
                    logging.info("Loading checkpoint for eval: " +
                                 latest_checkpoint1)
                    tf.train.Saver(vars2).restore(sess, latest_checkpoint1)

                saver = tf.train.Saver()
                fetches = [
                    learning_rate, global_step, train_op, video_id_batch,
                    predictions, labels_float, label_loss, summary_op
                ]

                coord = tf.train.Coordinator()
                threads = []
                for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
                    threads.extend(
                        qr.create_threads(sess,
                                          coord=coord,
                                          daemon=True,
                                          start=True))

                # try/finally guarantees the input threads are stopped and
                # joined, even when the input is exhausted or a step fails.
                try:
                    while not coord.should_stop():
                        (learning_rate_val, global_step_val, _, _,
                         predictions_val, labels_val, loss_val,
                         summary_val) = sess.run(fetches)

                        if not (self.is_master and
                                global_step_val % self.disp_batches == 0 and
                                self.train_dir):
                            continue

                        hit_at_one = eval_util.calculate_hit_at_one(
                            predictions_val, labels_val)
                        perr = eval_util.calculate_precision_at_equal_recall_rate(
                            predictions_val, labels_val)
                        gap = eval_util.calculate_gap(predictions_val,
                                                      labels_val)
                        logging.info(
                            "training step %s| learning rate: %.4f | Loss: "
                            "%.2f | Hit@1: %.4f PERR: %.4f GAP: %.4f",
                            global_step_val, learning_rate_val, loss_val,
                            hit_at_one, perr, gap)

                        # The merged graph summaries are fetched on every step
                        # anyway; write them out on display steps.
                        summary_writer.add_summary(summary_val,
                                                   global_step_val)
                        for tag, value in (
                                ("model/Training_Hit@1", hit_at_one),
                                ("model/Training_Perr", perr),
                                ("model/Training_GAP", gap),
                                ("model/loss", loss_val),
                                ("model/lr", learning_rate_val)):
                            summary_writer.add_summary(
                                utils.MakeSummary(tag, value), global_step_val)
                        summary_writer.flush()

                        if global_step_val % FLAGS.export_model_steps == 0:
                            # NOTE(review): FLAGS.ensemble_dir is used as the
                            # checkpoint *prefix*, so files are named
                            # "<ensemble_dir>-<step>.*"; confirm this layout.
                            saver.save(sess,
                                       FLAGS.ensemble_dir,
                                       global_step=global_step_val)
                except tf.errors.OutOfRangeError:
                    logging.info("%s: Done training -- epoch limit reached.",
                                 task_as_string(self.task))
                finally:
                    coord.request_stop()
                    coord.join(threads, stop_grace_period_secs=10)
                    summary_writer.close()
Ejemplo n.º 10
0
    def run(self, start_new_model=False):
        """Performs training on the currently defined Tensorflow graph.

        The graph is recovered from the latest meta-graph in
        ``self.train_dir``, or built with ``self.build_model``. It is then
        trained under a ``tf.train.Supervisor`` until one of these happens:
        the supervisor requests a stop, ``self.max_steps`` is reached, or the
        input raises ``OutOfRangeError``.

        On the master, every step writes Hit@1/PERR/GAP/examples-per-second
        summaries. For ``FLAGS.model == "EmbeddingModel"`` it also logs those
        metrics together with an embedding hit@10 and writes that as a
        summary. The model is exported periodically, and once more after
        training ends if at least one step ran.

        Args:
          start_new_model: if true on the master, ``self.train_dir`` is
            removed before training.

        Returns:
          None; progress is reported through logs and summaries.
        """
        if self.is_master and start_new_model:
            self.remove_training_directory(self.train_dir)

        target, device_fn = self.start_server_if_distributed()

        meta_filename = self.get_meta_filename(start_new_model, self.train_dir)
        is_embedding_model = FLAGS.model == "EmbeddingModel"

        with tf.Graph().as_default() as graph:

            if meta_filename:
                saver = self.recover_model(meta_filename)

            with tf.device(device_fn):

                if not meta_filename:
                    saver = self.build_model(self.model, self.reader)

                global_step = tf.get_collection("global_step")[0]
                loss = tf.get_collection("loss")[0]
                predictions = tf.get_collection("predictions")[0]
                labels = tf.get_collection("labels")[0]
                train_op = tf.get_collection("train_op")[0]
                input_batch_raw = tf.get_collection("input_batch_raw")[0]
                is_neg = tf.get_collection("is_negative")[0]
                init_op = tf.global_variables_initializer()

                # The first six fetches are always present; the embedding
                # model additionally fetches its activations and is_negative.
                fetches = [train_op, global_step, loss, predictions, labels,
                           input_batch_raw]
                if is_embedding_model:
                    fetches += [
                        tf.get_collection("hidden_layer_activations")[0],
                        is_neg
                    ]

        sv = tf.train.Supervisor(
            graph,
            logdir=self.train_dir,
            init_op=init_op,
            is_chief=self.is_master,
            global_step=global_step,
            save_model_secs=15 * 60,
            save_summaries_secs=120,
            saver=saver)

        task_name = task_as_string(self.task)
        logging.info("%s: Starting managed session.", task_name)
        with sv.managed_session(target, config=self.config) as sess:
            # Stays None when no training step runs, so the final export
            # below does not reference an unbound step value.
            global_step_val = None
            try:
                logging.info("%s: Entering training loop.", task_name)
                while (not sv.should_stop()) and (not self.max_steps_reached):
                    batch_start_time = time.time()
                    fetched = sess.run(fetches)
                    seconds_per_batch = time.time() - batch_start_time

                    (_, global_step_val, loss_val, predictions_val,
                     labels_val) = fetched[:5]
                    if is_embedding_model:
                        embeddings, is_neg_val = fetched[6:]

                    if self.max_steps and self.max_steps <= global_step_val:
                        self.max_steps_reached = True

                    if not self.is_master:
                        continue

                    examples_per_second = labels_val.shape[0] / seconds_per_batch
                    # NOTE(review): keeps the first 4716 columns, presumably
                    # the class scores (YouTube-8M vocabulary size); confirm
                    # against self.reader.num_classes.
                    predictions_val = predictions_val[:, 0:4716]
                    hit_at_one = eval_util.calculate_hit_at_one(
                        predictions_val, labels_val)
                    perr = eval_util.calculate_precision_at_equal_recall_rate(
                        predictions_val, labels_val)
                    gap = eval_util.calculate_gap(predictions_val, labels_val)

                    if is_embedding_model:
                        k = 10
                        emb_size = FLAGS.embedding_size
                        logging.info(is_neg_val[1])
                        hit_emb = eval_util.calculate_hit_at_k_embedding(
                            embeddings, k)
                        # Dot product of the two halves of the second row.
                        logging.info(numpy.sum(numpy.multiply(
                            embeddings[1, 0:emb_size],
                            embeddings[1, emb_size:2 * emb_size])))
                        logging.info(
                            "%s Training step %s| Hit@1: %.2f HitEmbedding@"
                            "%.0f: %.2f GAP: %.2f Loss: %s",
                            task_name, global_step_val, hit_at_one, k,
                            hit_emb, gap, loss_val)
                        sv.summary_writer.add_summary(
                            utils.MakeSummary(
                                "model/Training_HitEmbedding@%d" % k, hit_emb),
                            global_step_val)

                    for tag, value in (
                            ("model/Training_Hit@1", hit_at_one),
                            ("model/Training_Perr", perr),
                            ("model/Training_GAP", gap),
                            ("global_step/Examples/Second",
                             examples_per_second)):
                        sv.summary_writer.add_summary(
                            utils.MakeSummary(tag, value), global_step_val)
                    sv.summary_writer.flush()

                    # Exporting the model every x steps
                    time_to_export = (
                        self.last_model_export_step == 0 or
                        global_step_val - self.last_model_export_step
                        >= self.export_model_steps)
                    if time_to_export:
                        self.export_model(global_step_val, sv.saver,
                                          sv.save_path, sess)
                        self.last_model_export_step = global_step_val

            except tf.errors.OutOfRangeError:
                logging.info("%s: Done training -- epoch limit reached.",
                             task_name)

            # Export the final model, also when training stopped because the
            # input was exhausted.
            if self.is_master and global_step_val is not None:
                self.export_model(global_step_val, sv.saver, sv.save_path, sess)

        logging.info("%s: Exited training loop.", task_name)
        sv.Stop()
        # "Hem acabat" is Catalan for "we are done".
        print("Hem acabat")
Ejemplo n.º 11
0
    def run(self, start_new_model=False):
        """Performs training on the currently defined Tensorflow graph.

        Returns:
          A tuple of the training Hit@1 and the training PERR.
        """
        if self.is_master and start_new_model:
            self.remove_training_directory(self.train_dir)

        if not os.path.exists(self.train_dir):
            os.makedirs(self.train_dir)

        model_flags_dict = {
            "model": FLAGS.model,
            "feature_sizes": FLAGS.feature_sizes,
            "feature_names": FLAGS.feature_names,
            "frame_features": FLAGS.frame_features,
            "label_loss": FLAGS.label_loss,
        }
        flags_json_path = os.path.join(FLAGS.train_dir, "model_flags.json")
        if file_io.file_exists(flags_json_path):
            existing_flags = json.load(
                file_io.FileIO(flags_json_path, mode="r"))
            if existing_flags != model_flags_dict:
                logging.error(
                    "Model flags do not match existing file %s. Please "
                    "delete the file, change --train_dir, or pass flag "
                    "--start_new_model", flags_json_path)
                logging.error("Ran model with flags: %s",
                              str(model_flags_dict))
                logging.error("Previously ran with flags: %s",
                              str(existing_flags))
                exit(1)
        else:
            # Write the file.
            with file_io.FileIO(flags_json_path, mode="w") as fout:
                fout.write(json.dumps(model_flags_dict))

        target, device_fn = self.start_server_if_distributed()

        meta_filename = self.get_meta_filename(start_new_model, self.train_dir)

        with tf.Graph().as_default() as graph:

            if meta_filename:
                saver = self.recover_model(meta_filename)

            with tf.device(device_fn):

                if not meta_filename:
                    saver = self.build_model()

                global_step = tf.get_collection("global_step")[0]
                loss = tf.get_collection("loss")[0]
                predictions = tf.get_collection("predictions")[0]
                labels = tf.get_collection("labels")[0]
                train_op = tf.get_collection("train_op")[0]
                init_op = tf.global_variables_initializer()

                if FLAGS.dropout:
                    keep_prob_tensor = tf.get_collection("keep_prob")[0]
                if FLAGS.noise_level > 0:
                    noise_level_tensor = tf.get_collection("noise_level")[0]
                if FLAGS.reweight:
                    weights_input, weights_assignment = None, None
                    if len(tf.get_collection("weights_input")) > 0:
                        weights_input = tf.get_collection("weights_input")[0]
                        weights_assignment = tf.get_collection(
                            "weights_assignment")[0]

        sv = tf.train.Supervisor(
            graph,
            logdir=self.train_dir,
            init_op=init_op,
            is_chief=self.is_master,
            global_step=global_step,
            save_model_secs=FLAGS.keep_checkpoint_interval * 60,
            save_summaries_secs=120,
            saver=saver)

        logging.info("%s: Starting managed session.",
                     task_as_string(self.task))
        with sv.managed_session(target, config=self.config) as sess:

            # re-assign weights
            if FLAGS.reweight:
                optional_assign_weights(sess, weights_input,
                                        weights_assignment)

            steps = 0
            try:
                logging.info("%s: Entering training loop.",
                             task_as_string(self.task))
                while not sv.should_stop():

                    steps += 1
                    batch_start_time = time.time()
                    custom_feed = {}
                    if FLAGS.dropout:
                        custom_feed[keep_prob_tensor] = FLAGS.keep_prob
                    if FLAGS.noise_level > 0:
                        custom_feed[noise_level_tensor] = FLAGS.noise_level

                    _, global_step_val, loss_val, predictions_val, labels_val = sess.run(
                        [train_op, global_step, loss, predictions, labels],
                        feed_dict=custom_feed)
                    seconds_per_batch = time.time() - batch_start_time

                    if self.is_master:
                        examples_per_second = labels_val.shape[
                            0] / seconds_per_batch
                        hit_at_one = eval_util.calculate_hit_at_one(
                            predictions_val, labels_val)
                        perr = eval_util.calculate_precision_at_equal_recall_rate(
                            predictions_val, labels_val)
                        recall = "N/A"
                        if False:
                            recall = eval_util.calculate_recall_at_n(
                                predictions_val, labels_val, FLAGS.recall_at_n)
                            sv.summary_writer.add_summary(
                                utils.MakeSummary(
                                    "model/Training_Recall@%d" %
                                    FLAGS.recall_at_n, recall),
                                global_step_val)
                            recall = "%.2f" % recall
                        gap = eval_util.calculate_gap(predictions_val,
                                                      labels_val)

                        logging.info(
                            "%s: training step " + str(global_step_val) +
                            "| Hit@1: " + ("%.2f" % hit_at_one) + " PERR: " +
                            ("%.2f" % perr) + " GAP: " + ("%.2f" % gap) +
                            " Recall@%d: " % FLAGS.recall_at_n +
                            recall + " Loss: " + str(loss_val),
                            task_as_string(self.task))

                        sv.summary_writer.add_summary(
                            utils.MakeSummary("model/Training_Hit@1",
                                              hit_at_one), global_step_val)
                        sv.summary_writer.add_summary(
                            utils.MakeSummary("model/Training_Perr", perr),
                            global_step_val)
                        sv.summary_writer.add_summary(
                            utils.MakeSummary("model/Training_GAP", gap),
                            global_step_val)
                        sv.summary_writer.add_summary(
                            utils.MakeSummary("global_step/Examples/Second",
                                              examples_per_second),
                            global_step_val)
                        sv.summary_writer.flush()

                    if FLAGS.max_steps is not None and steps > FLAGS.max_steps:
                        logging.info(
                            "%s: Done training -- max_steps limit reached.",
                            task_as_string(self.task))
                        break

            except tf.errors.OutOfRangeError:
                logging.info("%s: Done training -- epoch limit reached.",
                             task_as_string(self.task))

        logging.info("%s: Exited training loop.", task_as_string(self.task))
        sv.Stop()
Ejemplo n.º 12
0
    def run(self, start_new_model=False):
        """Performs training on the currently defined Tensorflow graph.

        Recovers the training graph when ``get_meta_filename`` finds an
        existing checkpoint, otherwise builds a new model, then runs
        ``train_op`` under a ``tf.train.Supervisor`` until the supervisor asks
        to stop, the input is exhausted (``tf.errors.OutOfRangeError``) or
        ``self.max_steps`` is reached.

        On the master task every step is logged, Hit@1 / PERR / GAP /
        examples-per-second summaries are written, and the model is exported
        every ``self.export_model_steps`` steps, plus once more when training
        ends (unless that final step was just exported).

        Args:
          start_new_model: If True and this task is the master, the training
            directory is removed before training starts.

        Returns:
          None. Metrics are only reported through logging and summaries.
        """
        if self.is_master and start_new_model:
            # Training progress is only recorded by the master node. The
            # invoked helper handles a non-existing directory.
            self.remove_training_directory(self.train_dir)

        target, device_fn = self.start_server_if_distributed()
        # Path used to recover the latest checkpoint, or None if no checkpoint
        # was found or start_new_model is set.
        meta_filename = self.get_meta_filename(start_new_model, self.train_dir)

        # Recover the graph or build a new one.
        with tf.Graph().as_default() as graph:
            if meta_filename:
                saver = self.recover_model(meta_filename)

            with tf.device(device_fn):
                if not meta_filename:
                    saver = self.build_model(self.model, self.reader)

                # tf.get_collection reads from the current default graph,
                # which is ``graph`` inside this context.
                global_step = tf.get_collection("global_step")[0]
                loss = tf.get_collection("loss")[0]
                predictions = tf.get_collection("predictions")[0]
                labels = tf.get_collection("labels")[0]
                train_op = tf.get_collection("train_op")[0]
                init_op = tf.global_variables_initializer()
        # Leaving the ``as_default`` context does not finalize the graph, so
        # ops (constants, variables, ...) can still be added to it.

        # A training helper that checkpoints models and computes summaries.
        # Supervisor is a small wrapper around a Coordinator, a Saver and a
        # SessionManager that takes care of common needs of TensorFlow
        # training programs.
        # https://www.tensorflow.org/programmers_guide/supervisor
        sv = tf.train.Supervisor(graph,
                                 logdir=self.train_dir,
                                 init_op=init_op,
                                 is_chief=self.is_master,
                                 global_step=global_step,
                                 save_model_secs=15 * 60,
                                 save_summaries_secs=120,
                                 saver=saver)

        task_str = task_as_string(self.task)
        logging.info("{}: Starting managed session.".format(task_str))

        # Get a TensorFlow session managed by the supervisor.
        with sv.managed_session(target, config=self.config) as sess:
            # Stays None when the loop body never runs (e.g. the supervisor
            # requests a stop right away); the final export is skipped then
            # instead of failing on an unbound name.
            global_step_val = None
            try:
                logging.info("{}: Entering training loop.".format(task_str))
                while (not sv.should_stop()) and (not self.max_steps_reached):

                    batch_start_time = time.time()
                    # The value fetched for train_op is None; only its side
                    # effect (the update) matters.
                    _, global_step_val, loss_val, predictions_val, labels_val = sess.run(
                        [train_op, global_step, loss, predictions, labels])
                    seconds_per_batch = time.time() - batch_start_time

                    # Early stopping once the step budget has been used up.
                    if self.max_steps and self.max_steps <= global_step_val:
                        self.max_steps_reached = True

                    if self.is_master:
                        examples_per_second = labels_val.shape[
                            0] / seconds_per_batch
                        hit_at_one = eval_util.calculate_hit_at_one(
                            predictions_val, labels_val)
                        perr = eval_util.calculate_precision_at_equal_recall_rate(
                            predictions_val, labels_val)
                        gap = eval_util.calculate_gap(predictions_val,
                                                      labels_val)

                        logging.info(
                            "training step {0} | Hit@1: {1} | PERR: {2} | GAP: {3} | Loss: {4}"
                            .format(global_step_val, hit_at_one, perr, gap,
                                    loss_val))

                        sv.summary_writer.add_summary(
                            utils.MakeSummary("model/Training_Hit@1",
                                              hit_at_one), global_step_val)
                        sv.summary_writer.add_summary(
                            utils.MakeSummary("model/Training_Perr", perr),
                            global_step_val)
                        sv.summary_writer.add_summary(
                            utils.MakeSummary("model/Training_GAP", gap),
                            global_step_val)
                        sv.summary_writer.add_summary(
                            utils.MakeSummary("global_step/Examples/Second",
                                              examples_per_second),
                            global_step_val)
                        sv.summary_writer.flush()

                        # Export when nothing has been exported yet
                        # (last_model_export_step == 0) or every
                        # export_model_steps steps thereafter.
                        time_to_export = (
                            (self.last_model_export_step == 0)
                            or (global_step_val - self.last_model_export_step
                                >= self.export_model_steps))

                        if time_to_export:
                            self.export_model(global_step_val, sv.saver,
                                              sv.save_path, sess)
                            self.last_model_export_step = global_step_val

            except tf.errors.OutOfRangeError:
                # The input queue has no more examples, i.e. the maximal
                # number of epochs has been reached.
                logging.info(
                    "{}: Done training -- epoch limit reached.".format(
                        task_str))

            # Export the final model however the loop ended (stop request,
            # max_steps or epoch limit). Skip it if the last step was already
            # exported periodically, to avoid exporting the same step twice.
            if (self.is_master and global_step_val is not None
                    and global_step_val != self.last_model_export_step):
                self.export_model(global_step_val, sv.saver, sv.save_path,
                                  sess)
                self.last_model_export_step = global_step_val

        logging.info("{}: Exited training loop.".format(task_str))
        # Stop the supervisor (its threads and summary writer).
        sv.stop()
Ejemplo n.º 13
0
    def run(self, start_new_model=False):
        """Performs training on the currently defined Tensorflow graph.

        Logs the run configuration and checks (or records) the model-defining
        flags in ``model_flags.json`` inside the training directory. It then
        builds the model and trains it under a ``tf.train.Supervisor``.
        Training continues until the supervisor asks to stop, the input is
        exhausted or ``self.max_steps`` is reached. When ``get_meta_filename``
        returns a path, variables are restored via ``init_fn`` instead of
        being initialized by ``init_op``.

        Every 10th global step, the master task computes Hit@1 / PERR / GAP
        and writes summaries. It also exports the model every
        ``self.export_model_steps`` steps. All other steps only log loss and
        throughput.

        Args:
          start_new_model: If True and this task is the master, the training
            directory is removed before training starts.

        Returns:
          None. Exits the process with status 1 if the model flags differ from
          the ones recorded by a previous run in the same training directory.
        """
        if self.is_master and start_new_model:
            self.remove_training_directory(self.train_dir)

        if not os.path.exists(self.train_dir):
            os.makedirs(self.train_dir)

        logging.info(
            '############## PARAMETERS ##############################')
        # Each line is logged as "<flag name> <flag value>".
        for flag_name in ("feature_names", "feature_sizes", "frame_features",
                          "model", "start_new_model", "num_gpu",
                          "batch_size", "label_loss", "regularization_penalty",
                          "base_learning_rate", "learning_rate_decay",
                          "learning_rate_decay_examples", "num_epochs",
                          "max_steps", "export_model_steps", "num_readers",
                          "optimizer", "clip_gradient_norm"):
            logging.info("{} {}".format(flag_name, getattr(FLAGS, flag_name)))
        logging.info(
            '########################################################')
        logging.info(' '.join(sys.argv))

        # Flags that define the model; resuming with different values would
        # be incompatible with the checkpoints already in train_dir.
        model_flags_dict = {
            "model": FLAGS.model,
            "feature_sizes": FLAGS.feature_sizes,
            "feature_names": FLAGS.feature_names,
            "frame_features": FLAGS.frame_features,
            "label_loss": FLAGS.label_loss,
        }
        # Keep the flags file in the directory created above (the Supervisor
        # logdir), so it always describes the checkpoints stored there.
        flags_json_path = os.path.join(self.train_dir, "model_flags.json")
        if os.path.exists(flags_json_path):
            with open(flags_json_path) as fin:
                existing_flags = json.load(fin)
            if existing_flags != model_flags_dict:
                logging.error(
                    "Model flags do not match existing file %s. Please "
                    "delete the file, change --train_dir, or pass flag "
                    "--start_new_model", flags_json_path)
                logging.error("Ran model with flags: %s",
                              str(model_flags_dict))
                logging.error("Previously ran with flags: %s",
                              str(existing_flags))
                sys.exit(1)
        else:
            # Write the file.
            with open(flags_json_path, "w") as fout:
                fout.write(json.dumps(model_flags_dict))

        target, device_fn = self.start_server_if_distributed()

        meta_filename = self.get_meta_filename(start_new_model, self.train_dir)

        with tf.Graph().as_default() as graph:

            with tf.device(device_fn):

                saver = self.build_model(self.model, self.reader)

                global_step = tf.get_collection("global_step")[0]
                loss = tf.get_collection("loss")[0]
                predictions = tf.get_collection("predictions")[0]
                labels = tf.get_collection("labels")[0]
                train_op = tf.get_collection("train_op")[0]

                init_op, init_fn = None, None
                if meta_filename:
                    saver = tf.train.Saver(tf.global_variables(),
                                           max_to_keep=0,
                                           keep_checkpoint_every_n_hours=0.25)

                    # NOTE(review): Saver.restore expects a checkpoint path
                    # prefix. If get_meta_filename returns the ".meta" file
                    # path, this restore may fail when init_fn is actually
                    # invoked -- confirm against get_meta_filename.
                    def init_fn(sess):
                        return saver.restore(sess, meta_filename)
                else:
                    init_op = tf.global_variables_initializer()

        sv = tf.train.Supervisor(graph,
                                 logdir=self.train_dir,
                                 init_op=init_op,
                                 init_fn=init_fn,
                                 is_chief=self.is_master,
                                 global_step=global_step,
                                 save_model_secs=40 * 60,
                                 save_summaries_secs=120,
                                 saver=saver)

        logging.info("%s: Starting managed session.",
                     task_as_string(self.task))
        with sv.managed_session(target, config=self.config) as sess:
            try:
                logging.info("%s: Entering training loop.",
                             task_as_string(self.task))
                while (not sv.should_stop()) and (not self.max_steps_reached):
                    batch_start_time = time.time()
                    _, global_step_val, loss_val, predictions_val, labels_val = sess.run(
                        [train_op, global_step, loss, predictions, labels])
                    seconds_per_batch = time.time() - batch_start_time
                    examples_per_second = labels_val.shape[
                        0] / seconds_per_batch

                    if self.max_steps and self.max_steps <= global_step_val:
                        self.max_steps_reached = True

                    # Common prefix of both progress log lines below.
                    progress = ("training step " + str(global_step_val) +
                                " | Loss: " + ("%.2f" % loss_val) +
                                " Examples/sec: " +
                                ("%.2f" % examples_per_second))

                    if self.is_master and global_step_val % 10 == 0 and self.train_dir:
                        hit_at_one = eval_util.calculate_hit_at_one(
                            predictions_val, labels_val)
                        perr = eval_util.calculate_precision_at_equal_recall_rate(
                            predictions_val, labels_val)
                        gap = eval_util.calculate_gap(predictions_val,
                                                      labels_val)

                        logging.info(progress + " | Hit@1: " +
                                     ("%.2f" % hit_at_one) + " PERR: " +
                                     ("%.2f" % perr) + " GAP: " +
                                     ("%.2f" % gap))

                        sv.summary_writer.add_summary(
                            utils.MakeSummary("model/Training_Hit@1",
                                              hit_at_one), global_step_val)
                        sv.summary_writer.add_summary(
                            utils.MakeSummary("model/Training_Perr", perr),
                            global_step_val)
                        sv.summary_writer.add_summary(
                            utils.MakeSummary("model/Training_GAP", gap),
                            global_step_val)
                        sv.summary_writer.add_summary(
                            utils.MakeSummary("global_step/Examples/Second",
                                              examples_per_second),
                            global_step_val)
                        sv.summary_writer.flush()

                        # Export when nothing has been exported yet
                        # (last_model_export_step == 0) or every
                        # export_model_steps steps thereafter.
                        time_to_export = (
                            (self.last_model_export_step == 0)
                            or (global_step_val - self.last_model_export_step
                                >= self.export_model_steps))

                        if time_to_export:
                            self.export_model(global_step_val, sv.saver,
                                              sv.save_path, sess)
                            self.last_model_export_step = global_step_val
                    else:
                        logging.info(progress)
            except tf.errors.OutOfRangeError:
                logging.info("%s: Done training -- epoch limit reached.",
                             task_as_string(self.task))

        logging.info("%s: Exited training loop.", task_as_string(self.task))
        sv.stop()
Ejemplo n.º 14
0
    def run(self, start_new_model=False):
        """Performs training on the currently defined Tensorflow graph.

        Each step first fetches a raw frame batch and its frame counts. It then
        condenses every video into ``FLAGS.max_scene`` mean-pooled scenes on
        the host (see ``_pool_frames_into_scenes``). Finally, it runs
        ``train_op`` with the pooled features fed through the
        ``model_input_raw_ph`` / ``num_frames_ph`` placeholders. Training
        continues until the supervisor asks to stop, the input is exhausted or
        ``self.max_steps`` is reached.

        Every 10th global step, the master task computes Hit@1 / PERR / GAP,
        writes summaries and periodically exports the model. All other steps
        only log loss and throughput.

        Args:
          start_new_model: If True and this task is the master, the training
            directory is removed before training starts.

        Returns:
          None. Metrics are only reported through logging and summaries.
        """
        if self.is_master and start_new_model:
            self.remove_training_directory(self.train_dir)

        target, device_fn = self.start_server_if_distributed()

        meta_filename = self.get_meta_filename(start_new_model, self.train_dir)

        with tf.Graph().as_default() as graph:

            if meta_filename:
                saver = self.recover_model(meta_filename)

            with tf.device(device_fn):
                if not meta_filename:
                    saver = self.build_model(self.model, self.reader)

                global_step = tf.get_collection("global_step")[0]
                loss = tf.get_collection("loss")[0]
                input_batch_raw = tf.get_collection("input_batch_raw")[0]
                model_input_raw_ph = tf.get_collection("model_input_raw_ph")[0]
                predictions = tf.get_collection("predictions")[0]
                labels = tf.get_collection("labels")[0]
                train_op = tf.get_collection("train_op")[0]
                num_frames = tf.get_collection("num_frames")[0]
                num_frames_ph = tf.get_collection("num_frames_ph")[0]

                init_op = tf.global_variables_initializer()

        sv = tf.train.Supervisor(graph,
                                 logdir=self.train_dir,
                                 init_op=init_op,
                                 is_chief=self.is_master,
                                 global_step=global_step,
                                 save_model_secs=15 * 60,
                                 save_summaries_secs=120,
                                 saver=saver)

        logging.info("%s: Starting managed session.",
                     task_as_string(self.task))
        with sv.managed_session(target, config=self.config) as sess:
            try:
                logging.info("%s: Entering training loop.",
                             task_as_string(self.task))
                while (not sv.should_stop()) and (not self.max_steps_reached):
                    batch_start_time = time.time()

                    model_input_raw_val, num_frames_val = sess.run(
                        [input_batch_raw, num_frames])
                    pr_feature, pr_num = self._pool_frames_into_scenes(
                        model_input_raw_val, num_frames_val, FLAGS.max_scene)

                    # NOTE(review): this is a second, separate graph
                    # execution. If `labels` is produced by the same input
                    # queue as `input_batch_raw`, it dequeues a new batch here
                    # and the labels would not match the fed features --
                    # confirm the graph wiring.
                    _, global_step_val, loss_val, predictions_val, labels_val = sess.run(
                        [train_op, global_step, loss, predictions, labels],
                        feed_dict={
                            model_input_raw_ph: pr_feature,
                            num_frames_ph: pr_num
                        })

                    seconds_per_batch = time.time() - batch_start_time
                    examples_per_second = labels_val.shape[
                        0] / seconds_per_batch

                    if self.max_steps and self.max_steps <= global_step_val:
                        self.max_steps_reached = True

                    # Common prefix of both progress log lines below.
                    progress = ("training step " + str(global_step_val) +
                                " | Loss: " + ("%.2f" % loss_val) +
                                " Examples/sec: " +
                                ("%.2f" % examples_per_second))

                    if self.is_master and global_step_val % 10 == 0 and self.train_dir:
                        hit_at_one = eval_util.calculate_hit_at_one(
                            predictions_val, labels_val)
                        perr = eval_util.calculate_precision_at_equal_recall_rate(
                            predictions_val, labels_val)
                        gap = eval_util.calculate_gap(predictions_val,
                                                      labels_val)

                        logging.info(progress + " | Hit@1: " +
                                     ("%.2f" % hit_at_one) + " PERR: " +
                                     ("%.2f" % perr) + " GAP: " +
                                     ("%.2f" % gap))

                        sv.summary_writer.add_summary(
                            utils.MakeSummary("model/Training_Hit@1",
                                              hit_at_one), global_step_val)
                        sv.summary_writer.add_summary(
                            utils.MakeSummary("model/Training_Perr", perr),
                            global_step_val)
                        sv.summary_writer.add_summary(
                            utils.MakeSummary("model/Training_GAP", gap),
                            global_step_val)
                        sv.summary_writer.add_summary(
                            utils.MakeSummary("global_step/Examples/Second",
                                              examples_per_second),
                            global_step_val)
                        sv.summary_writer.flush()

                        # Export when nothing has been exported yet
                        # (last_model_export_step == 0) or every
                        # export_model_steps steps thereafter.
                        time_to_export = (
                            (self.last_model_export_step == 0)
                            or (global_step_val - self.last_model_export_step
                                >= self.export_model_steps))

                        if time_to_export:
                            self.export_model(global_step_val, sv.saver,
                                              sv.save_path, sess)
                            self.last_model_export_step = global_step_val
                    else:
                        logging.info(progress)
            except tf.errors.OutOfRangeError:
                logging.info("%s: Done training -- epoch limit reached.",
                             task_as_string(self.task))

        logging.info("%s: Exited training loop.", task_as_string(self.task))
        sv.stop()

    @staticmethod
    def _pool_frames_into_scenes(model_input_raw_val, num_frames_val,
                                 max_scene):
        """Summarizes each example of a batch as ``max_scene`` averaged scenes.

        For every example, the first ``num_frames_val[i]`` frames are split at
        the ``max_scene - 1`` positions where consecutive frames have the
        lowest cosine similarity. Each resulting segment is then averaged over
        its frames. Examples with few frames relative to ``max_scene`` (per
        the ratio test below) first have every frame repeated, so there are
        enough frames to split.

        Args:
          model_input_raw_val: Frame features indexed as ``[example][frame]``,
            with one feature vector per frame (presumably shaped
            (batch, max_frames, feature_dim) -- TODO confirm).
          num_frames_val: Number of valid frames for each example. Assumed to
            be >= 1; a zero count would divide by zero below.
          max_scene: Number of scenes produced per example.

        Returns:
          A ``(pr_feature, pr_num)`` tuple: the stacked per-example scene
          means and an array holding ``max_scene`` once per example.
        """
        pr_feature = []
        pr_num = []
        for i in range(model_input_raw_val.shape[0]):
            frame_count = num_frames_val[i]
            if frame_count / max_scene <= 2:
                # Repeat every frame so the sequence can be cut into max_scene
                # non-empty segments. np.repeat needs an integer repeat count,
                # so the np.ceil result (a float) is cast to int.
                repeats = int(np.ceil(1 + max_scene / frame_count))
                frames = np.repeat(model_input_raw_val[i][:frame_count],
                                   repeats, 0)
            else:
                frames = model_input_raw_val[i][:frame_count]

            # Cosine similarity between each pair of consecutive frames.
            numvec = (frames[1:] * frames[:-1]).sum(axis=1) / (
                np.sqrt((frames[1:]**2).sum(1)) *
                np.sqrt((frames[:-1]**2).sum(1)))
            # Cut right after the max_scene - 1 least similar pairs, in order.
            idx = np.sort(
                numvec.argpartition(max_scene - 1)[:max_scene - 1] + 1)

            segments = np.split(frames, idx, 0)
            pr_feature.append(
                np.stack([np.mean(segment, 0) for segment in segments], 0))
            pr_num.append(max_scene)
        return np.stack(pr_feature, 0), np.stack(pr_num, 0)
Ejemplo n.º 15
0
    def run(self, start_new_model=False):
        """Performs training on the currently defined Tensorflow graph.

        Recovers the graph when ``get_meta_filename`` finds an existing
        checkpoint, otherwise builds a new model. It then runs ``train_op``
        under a ``tf.train.Supervisor`` until the supervisor asks to stop or
        the input is exhausted (``tf.errors.OutOfRangeError``). Depending on
        flags, each step also feeds:

        * a keep probability (``--dropout``);
        * a noise level (``--noise_level > 0``).

        With ``--reweight``, the weights are re-assigned once when the session
        starts.

        The master task logs Hit@1 / PERR / GAP / loss for every step and
        writes the matching summaries.

        Args:
          start_new_model: If True and this task is the master, the training
            directory is removed before training starts.

        Returns:
          None. Metrics are only reported through logging and summaries.
        """
        if self.is_master and start_new_model:
            self.remove_training_directory(self.train_dir)

        target, device_fn = self.start_server_if_distributed()

        meta_filename = self.get_meta_filename(start_new_model, self.train_dir)

        with tf.Graph().as_default() as graph:

            if meta_filename:
                saver = self.recover_model(meta_filename)

            with tf.device(device_fn):

                if not meta_filename:
                    saver = self.build_model()

                global_step = tf.get_collection("global_step")[0]
                loss = tf.get_collection("loss")[0]
                predictions = tf.get_collection("predictions")[0]
                labels = tf.get_collection("labels")[0]
                train_op = tf.get_collection("train_op")[0]
                init_op = tf.global_variables_initializer()

                # Optional tensors, only looked up when the matching flag is
                # set; they are only used under the same flags below.
                if FLAGS.dropout:
                    keep_prob_tensor = tf.get_collection("keep_prob")[0]
                if FLAGS.noise_level > 0:
                    noise_level_tensor = tf.get_collection("noise_level")[0]
                if FLAGS.reweight:
                    weights_input, weights_assignment = None, None
                    if len(tf.get_collection("weights_input")) > 0:
                        weights_input = tf.get_collection("weights_input")[0]
                        weights_assignment = tf.get_collection(
                            "weights_assignment")[0]

        sv = tf.train.Supervisor(graph,
                                 logdir=self.train_dir,
                                 init_op=init_op,
                                 is_chief=self.is_master,
                                 global_step=global_step,
                                 save_model_secs=6 * 60,
                                 save_summaries_secs=120,
                                 saver=saver)

        logging.info("%s: Starting managed session.",
                     task_as_string(self.task))
        with sv.managed_session(target, config=self.config) as sess:

            # re-assign weights
            if FLAGS.reweight:
                optional_assign_weights(sess, weights_input,
                                        weights_assignment)

            try:
                logging.info("%s: Entering training loop.",
                             task_as_string(self.task))
                while not sv.should_stop():

                    batch_start_time = time.time()
                    custom_feed = {}
                    if FLAGS.dropout:
                        custom_feed[keep_prob_tensor] = FLAGS.keep_prob
                    if FLAGS.noise_level > 0:
                        custom_feed[noise_level_tensor] = FLAGS.noise_level

                    _, global_step_val, loss_val, predictions_val, labels_val = sess.run(
                        [train_op, global_step, loss, predictions, labels],
                        feed_dict=custom_feed)
                    seconds_per_batch = time.time() - batch_start_time

                    if self.is_master:
                        examples_per_second = labels_val.shape[
                            0] / seconds_per_batch
                        hit_at_one = eval_util.calculate_hit_at_one(
                            predictions_val, labels_val)
                        perr = eval_util.calculate_precision_at_equal_recall_rate(
                            predictions_val, labels_val)
                        # Recall@N is not computed during training (the former
                        # computation sat in an `if False:` branch); the log
                        # line reports it as N/A.
                        recall = "N/A"
                        gap = eval_util.calculate_gap(predictions_val,
                                                      labels_val)

                        logging.info(
                            "%s: training step " + str(global_step_val) +
                            "| Hit@1: " + ("%.2f" % hit_at_one) + " PERR: " +
                            ("%.2f" % perr) + " GAP: " + ("%.2f" % gap) +
                            " Recall@%d: " % FLAGS.recall_at_n +
                            recall + " Loss: " + str(loss_val),
                            task_as_string(self.task))

                        sv.summary_writer.add_summary(
                            utils.MakeSummary("model/Training_Hit@1",
                                              hit_at_one), global_step_val)
                        sv.summary_writer.add_summary(
                            utils.MakeSummary("model/Training_Perr", perr),
                            global_step_val)
                        sv.summary_writer.add_summary(
                            utils.MakeSummary("model/Training_GAP", gap),
                            global_step_val)
                        sv.summary_writer.add_summary(
                            utils.MakeSummary("global_step/Examples/Second",
                                              examples_per_second),
                            global_step_val)
                        sv.summary_writer.flush()

            except tf.errors.OutOfRangeError:
                logging.info("%s: Done training -- epoch limit reached.",
                             task_as_string(self.task))

        logging.info("%s: Exited training loop.", task_as_string(self.task))
        sv.stop()
Ejemplo n.º 16
0
  def run(self, start_new_model=False):
    """Performs staged refiner/discriminator training on the current graph.

    Training proceeds in phases selected by the local step counter `steps`,
    which starts at the restored global step and is incremented before every
    update:

      * steps < 10000: train the first refiner (`refiner_loss`).
      * the next 500 steps: train the discriminator (`discriminator_loss`).
      * afterwards: cycle through blocks of 40 steps, 4 blocks training the
        second refiner (`refiner2_loss`) followed by 1 block training the
        discriminator.

    With FLAGS.accumulate_gradients, each update zeroes the accumulation
    buffers, runs the accumulate ops for FLAGS.apply_every_n_batches batches
    (averaging the reported losses and IOU over them) and then runs the apply
    op once. Otherwise the selected train op runs once per batch.

    The master logs IOU/loss and writes IOU and Examples/Second summaries after
    every update. Training ends when the supervisor requests a stop, when
    `steps` exceeds FLAGS.max_steps, or when the input raises OutOfRangeError.

    Args:
      start_new_model: If True, start from scratch: the master removes the
        training directory, and the flag is forwarded to `get_meta_filename`.

    Returns:
      None. Progress is reported only through logging and summaries.
    """
    if self.is_master and start_new_model:
      self.remove_training_directory(self.train_dir)

    target, device_fn = self.start_server_if_distributed()

    meta_filename = self.get_meta_filename(start_new_model, self.train_dir)

    with tf.Graph().as_default() as graph:

      if meta_filename:
        saver = self.recover_model(meta_filename)

      with tf.device(device_fn):

        if not meta_filename:
          saver = self.build_model()

        global_step = tf.get_collection("global_step")[0]

        loss = tf.get_collection("loss")[0]
        refiner_loss = tf.get_collection("refiner_loss")[0]
        refiner2_loss = tf.get_collection("refiner2_loss")[0]
        discriminator_loss = tf.get_collection("discriminator_loss")[0]
        similarity_loss = tf.get_collection("similarity_loss")[0]

        predictions = tf.get_collection("predictions")[0]
        labels = tf.get_collection("labels")[0]
        float_labels = tf.get_collection("float_labels")[0]
        num_examples = tf.get_collection("num_examples")[0]

        if FLAGS.accumulate_gradients:
          refiner_init_ops = tf.get_collection("refiner_train/init_ops")
          refiner_accum_ops = tf.get_collection("refiner_train/accum_ops")
          refiner_apply_op = tf.get_collection("refiner_train/apply_op")[0]
          discriminator_init_ops = tf.get_collection("discriminator_train/init_ops")
          discriminator_accum_ops = tf.get_collection("discriminator_train/accum_ops")
          discriminator_apply_op = tf.get_collection("discriminator_train/apply_op")[0]
          refiner2_init_ops = tf.get_collection("refiner2_train/init_ops")
          refiner2_accum_ops = tf.get_collection("refiner2_train/accum_ops")
          refiner2_apply_op = tf.get_collection("refiner2_train/apply_op")[0]
        else:
          refiner_train_op = tf.get_collection("refiner_train/train_op")[0]
          discriminator_train_op = tf.get_collection("discriminator_train/train_op")[0]
          refiner2_train_op = tf.get_collection("refiner2_train/train_op")[0]

        mean_iou = tf.get_collection("mean_iou")[0]
        init_op = tf.global_variables_initializer()

    sv = tf.train.Supervisor(
        graph,
        logdir=self.train_dir,
        init_op=init_op,
        is_chief=self.is_master,
        global_step=global_step,
        save_model_secs=FLAGS.keep_checkpoint_interval * 60,
        save_summaries_secs=120,
        saver=saver)

    mean = lambda x: sum(x) / len(x)

    # Training schedule, measured in local steps (loop-invariant).
    refiner_stage = 10000        # steps spent pre-training the first refiner
    discriminator_stage = 500    # steps spent pre-training the discriminator
    interleave_stage = 40        # length of one interleaved training block
    refiner_ratio = 4            # refiner2 blocks per interleave cycle
    discriminator_ratio = 1      # discriminator blocks per interleave cycle

    logging.info("%s: Starting managed session.", task_as_string(self.task))
    with sv.managed_session(target, config=self.config) as sess:

      steps = sess.run(global_step)
      try:
        logging.info("%s: Entering training loop.", task_as_string(self.task))
        while not sv.should_stop():

          steps += 1
          batch_start_time = time.time()

          num_examples_processed = 0

          training_flag = ""

          if steps < refiner_stage:
            training_flag = " refiner_init"
            sub_loss = refiner_loss
            if FLAGS.accumulate_gradients:
              init_ops = refiner_init_ops
              accum_ops = refiner_accum_ops
              apply_op = refiner_apply_op
            else:
              train_op = refiner_train_op
          elif refiner_stage <= steps < refiner_stage + discriminator_stage:
            training_flag = " discriminator_init"
            sub_loss = discriminator_loss
            if FLAGS.accumulate_gradients:
              init_ops = discriminator_init_ops
              accum_ops = discriminator_accum_ops
              apply_op = discriminator_apply_op
            else:
              train_op = discriminator_train_op
          else:
            # Index of the current interleave block since pre-training ended;
            # the first `refiner_ratio` blocks of every cycle train refiner2.
            block_index = ((steps - refiner_stage - discriminator_stage) //
                           interleave_stage)
            if block_index % (refiner_ratio + discriminator_ratio) < refiner_ratio:
              training_flag = " refiner"
              sub_loss = refiner2_loss
              if FLAGS.accumulate_gradients:
                init_ops = refiner2_init_ops
                accum_ops = refiner2_accum_ops
                apply_op = refiner2_apply_op
              else:
                train_op = refiner2_train_op
            else:
              training_flag = " discriminator"
              sub_loss = discriminator_loss
              if FLAGS.accumulate_gradients:
                init_ops = discriminator_init_ops
                accum_ops = discriminator_accum_ops
                apply_op = discriminator_apply_op
              else:
                train_op = discriminator_train_op

          if FLAGS.accumulate_gradients:
            # Zero the gradient accumulation buffers.
            sess.run(init_ops)
            # Accumulate gradients over several batches, recording per-batch
            # metrics so they can be averaged for reporting.
            loss_val, sub_loss_val, mean_iou_val = [], [], []
            for _ in range(FLAGS.apply_every_n_batches):
              ret_list = sess.run(
                  [num_examples, loss, sub_loss, mean_iou] + accum_ops)
              num_examples_processed += ret_list[0]
              loss_val.append(ret_list[1])
              sub_loss_val.append(ret_list[2])
              mean_iou_val.append(ret_list[3])
            loss_val, sub_loss_val, mean_iou_val = map(
                mean, [loss_val, sub_loss_val, mean_iou_val])
            # Apply the accumulated gradients as a single update.
            _, global_step_val = sess.run([apply_op, global_step])

          else:
            # The original apply-every-batch scheme.
            (_, global_step_val, loss_val, sub_loss_val, predictions_val,
             labels_val, mean_iou_val, num_examples_val) = sess.run(
                 [train_op, global_step, loss, sub_loss, predictions, labels,
                  mean_iou, num_examples])
            num_examples_processed += num_examples_val

          seconds_per_batch = time.time() - batch_start_time

          if self.is_master:
            examples_per_second = num_examples_processed / seconds_per_batch

            logging.info("%s: training step " + str(global_step_val) + 
                         "| IOU: " + ("%.5f" % mean_iou_val) + 
                         " Loss: " + str(loss_val) +
                         " SubLoss: " + str(sub_loss_val) +
                         " " + training_flag, task_as_string(self.task))

            sv.summary_writer.add_summary(
                utils.MakeSummary(
                    "model/Training_IOU", mean_iou_val), global_step_val)
            sv.summary_writer.add_summary(
                utils.MakeSummary(
                    "global_step/Examples/Second", examples_per_second), global_step_val)
            sv.summary_writer.flush()

          if FLAGS.max_steps is not None and steps > FLAGS.max_steps:
            logging.info("%s: Done training -- max_steps limit reached.",
                         task_as_string(self.task))
            break

      except tf.errors.OutOfRangeError:
        logging.info("%s: Done training -- epoch limit reached.",
                     task_as_string(self.task))

    logging.info("%s: Exited training loop.", task_as_string(self.task))
    sv.Stop()
Ejemplo n.º 17
0
def train_loop(train_dir=None,
               saver=None,
               is_chief=True,
               master="",
               start_supervisor_services=True):
    """Performs training on the currently defined tensorflow graph.

    Args:
      train_dir: Where to save the model checkpoints. Training summaries are
        only written when this is set.
      saver: The class to use for serializing the graph variables.
      is_chief: Whether this worker is the primary worker (which is responsible
        for writing checkpoints and summaries), or an anonymous member of the
        flock.
      master: Which Tensorflow master to listen to.
      start_supervisor_services: Whether to start threads for writing summaries
        and checkpoints.

    Returns:
      A tuple of the training Hit@1 and the training PERR of the last
      completed batch, or (None, None) if no batch completed (for example when
      the input is exhausted before the first step).
    """
    global_step = tf.get_collection("global_step")[0]
    loss = tf.get_collection("loss")[0]
    predictions = tf.get_collection("predictions")[0]
    labels = tf.get_collection("labels")[0]
    train_op = tf.get_collection("train_op")[0]

    sv = tf.train.Supervisor(logdir=train_dir,
                             is_chief=is_chief,
                             global_step=global_step,
                             save_model_secs=60,
                             save_summaries_secs=60,
                             saver=saver)
    sess = sv.prepare_or_wait_for_session(
        master,
        start_standard_services=start_supervisor_services,
        config=tf.ConfigProto(log_device_placement=False))

    logging.info("prepared session")
    sv.start_queue_runners(sess)
    logging.info("started queue runners")

    # Initialized up front so the final return cannot raise UnboundLocalError
    # when the loop exits before completing a single batch.
    hit_at_one = None
    perr = None

    try:
        logging.info("entering training loop")
        while not sv.should_stop():
            batch_start_time = time.time()
            _, global_step_val, loss_val, predictions_val, labels_val = sess.run(
                [train_op, global_step, loss, predictions, labels])
            seconds_per_batch = time.time() - batch_start_time
            examples_per_second = labels_val.shape[0] / seconds_per_batch

            hit_at_one = eval_util.calculate_hit_at_one(
                predictions_val, labels_val)
            perr = eval_util.calculate_precision_at_equal_recall_rate(
                predictions_val, labels_val)
            gap = eval_util.calculate_gap(predictions_val, labels_val)

            logging.info("training step " + str(global_step_val) +
                         "| Hit@1: " + ("%.2f" % hit_at_one) + " PERR: " +
                         ("%.2f" % perr) + " GAP: " + ("%.2f" % gap) +
                         " Loss: " + str(loss_val))
            # Only the chief writes summaries, every 10 global steps.
            if is_chief and global_step_val % 10 == 0 and train_dir:
                sv.summary_writer.add_summary(
                    utils.MakeSummary("model/Training_Hit@1", hit_at_one),
                    global_step_val)
                sv.summary_writer.add_summary(
                    utils.MakeSummary("model/Training_Perr", perr),
                    global_step_val)
                sv.summary_writer.add_summary(
                    utils.MakeSummary("model/Training_GAP", gap),
                    global_step_val)
                sv.summary_writer.add_summary(
                    utils.MakeSummary("global_step/Examples/Second",
                                      examples_per_second), global_step_val)
                sv.summary_writer.flush()
    except tf.errors.OutOfRangeError:
        logging.info("Done training -- epoch limit reached")
    logging.info("exited training loop")
    sv.Stop()
    return hit_at_one, perr
Ejemplo n.º 18
0
    def run(self, start_new_model=False):
        """Trains the GAN built from the generator and discriminator models.

        Each loop iteration draws a fresh noise batch from
        `random_noise_generator` and runs the discriminator and generator
        train ops together in a single `sess.run`. Training stops when the
        supervisor requests it, when `self.max_steps` is reached, or when the
        input pipeline raises `OutOfRangeError`.

        Every 10 global steps the master logs discriminator accuracy on fake
        and real samples and writes an Examples/Second summary. When
        `self.export_model_steps` steps have elapsed since the last export
        (or nothing has been exported yet), it also exports the model and
        computes the supervisor's summary op, fed with the current noise
        batch. When `self.export_generated_images` is set, the first 16
        generated images are plotted to
        `<self.image_dir><step // self.export_image_steps>.png` every
        `self.export_image_steps` steps.

        Args:
          start_new_model: If True, start from scratch: the master removes the
            training directory, and the flag is forwarded to
            `get_meta_filename`.

        Returns:
          None. Progress is reported through logging, summaries and exports.
        """
        if self.is_master and start_new_model:
            self.remove_training_directory(self.train_dir)

        target, device_fn = self.start_server_if_distributed()

        meta_filename = self.get_meta_filename(start_new_model, self.train_dir)

        num_towers = max(len(get_gpus()), 1)
        total_batch_size = FLAGS.batch_size * num_towers
        image_width, image_height = self.reader.get_image_size()

        with tf.Graph().as_default() as graph:
            if meta_filename:
                saver = self.recover_model(meta_filename)

            with tf.device(device_fn):
                if not meta_filename:
                    saver = self.build_model(self.generator_model,
                                             self.discriminator_model,
                                             self.reader)

                global_step = tf.get_collection("global_step")[0]
                D_loss = tf.get_collection("D_loss")[0]
                G_loss = tf.get_collection("G_loss")[0]
                p_for_fake = tf.get_collection("p_for_fake")[0]
                p_for_real = tf.get_collection("p_for_data")[0]
                generated_images = tf.get_collection("generated_images")[0]

                D_train_op = tf.get_collection("D_train_op")[0]
                G_train_op = tf.get_collection("G_train_op")[0]
                noise_input = tf.get_collection("noise_input_placeholder")[0]
                init_op = tf.global_variables_initializer()

        # NOTE: save_summaries_secs=0 disables the Supervisor's own summary
        # thread, because it cannot feed the noise placeholder that summary_op
        # depends on. The summary op is instead run manually (with a feed) in
        # the loop below.
        sv = tf.train.Supervisor(graph,
                                 logdir=self.train_dir,
                                 init_op=init_op,
                                 is_chief=self.is_master,
                                 global_step=global_step,
                                 save_model_secs=15 * 60,
                                 save_summaries_secs=0,
                                 saver=saver)

        with sv.managed_session(target, config=self.config) as sess:
            try:
                logging.info("%s: Entering training loop.",
                             task_as_string(self.task))
                while (not sv.should_stop()) and (not self.max_steps_reached):
                    batch_start_time = time.time()

                    noise_input_batch = random_noise_generator.generate_noise(
                        total_batch_size)
                    (_, _, global_step_val, D_loss_val, G_loss_val, p_fake_val,
                     p_real_val, generated_images_val) = sess.run(
                         [
                             D_train_op, G_train_op, global_step, D_loss,
                             G_loss, p_for_fake, p_for_real, generated_images
                         ],
                         feed_dict={noise_input: noise_input_batch})
                    seconds_per_batch = time.time() - batch_start_time
                    examples_per_second = p_real_val.shape[
                        0] / seconds_per_batch

                    if self.max_steps and self.max_steps <= global_step_val:
                        self.max_steps_reached = True

                    if self.is_master and global_step_val % 10 == 0 and self.train_dir:
                        accuracy_on_fake = eval_util.calculate_accuracy_on_fake(
                            p_fake_val)
                        accuracy_on_real = eval_util.calculate_accuracy_on_real(
                            p_real_val)

                        logging.info("training step " + str(global_step_val) +
                                     " | G Loss: " + ("%.4f" % G_loss_val) +
                                     " | D loss: " + ("%.4f" % D_loss_val) +
                                     " | Examples/sec: " +
                                     ("%.2f" % examples_per_second) +
                                     " | D accuracy on G: " +
                                     ("%.2f" % accuracy_on_fake) +
                                     " | D accuracy on real: " +
                                     ("%.2f" % accuracy_on_real))

                        sv.summary_writer.add_summary(
                            utils.MakeSummary("global_step/Examples/Second",
                                              examples_per_second),
                            global_step_val)
                        sv.summary_writer.flush()

                        # Export the model and gather summaries every
                        # `export_model_steps` steps (already master-only here).
                        time_to_export = (
                            (self.last_model_export_step == 0)
                            or (global_step_val - self.last_model_export_step
                                >= self.export_model_steps))

                        if time_to_export:
                            self.export_model(global_step_val, sv.saver,
                                              sv.save_path, sess)
                            self.last_model_export_step = global_step_val
                            sv.summary_computed(
                                sess,
                                sess.run(
                                    sv.summary_op,
                                    feed_dict={noise_input:
                                               noise_input_batch}))

                    else:
                        logging.info("training step " + str(global_step_val) +
                                     " | G Loss: " + ("%.4f" % G_loss_val) +
                                     " | D loss: " + ("%.4f" % D_loss_val) +
                                     " | Examples/sec: " +
                                     ("%.2f" % examples_per_second))

                    # Save some generated image samples in a png file.
                    if (self.is_master and self.export_generated_images
                            and (global_step_val % self.export_image_steps) == 0
                            and self.image_dir):
                        fig = plot(generated_images_val[:16, :], image_width,
                                   image_height)
                        # The step is a multiple of export_image_steps here,
                        # so floor division gives the exact image index. True
                        # division would yield e.g. "1.0" under Python 3 and
                        # break the zero-padded file numbering.
                        filename = (self.image_dir + '{}.png').format(
                            str(global_step_val //
                                self.export_image_steps).zfill(3))
                        plt.savefig(filename, bbox_inches='tight')
                        plt.close(fig)
                        logging.info("Exported image - " + filename)

            except tf.errors.OutOfRangeError:
                logging.info("%s: Done training -- epoch limit reached.",
                             task_as_string(self.task))

        logging.info("%s: Exited training loop.", task_as_string(self.task))
        sv.Stop()
Ejemplo n.º 19
0
    def run(self, start_new_model=False):
        """Performs training on the currently defined Tensorflow graph.

        Before building the graph, the model-defining flags are compared with
        a `model_flags.json` file in the training directory. The file is
        written if it is missing. If it exists with different contents, the
        process exits with status 1 so that a checkpoint is never resumed with
        an incompatible model.

        Every 10 global steps the master computes Hit@1, PERR and GAP on the
        current batch, logs them, writes them as summaries together with
        Examples/Second, and exports the model when `self.export_model_steps`
        steps have elapsed since the last export (or nothing has been exported
        yet). Training stops when the supervisor requests it, when
        `self.max_steps` is reached, or when the input raises
        `OutOfRangeError`.

        Args:
          start_new_model: If True, start from scratch: the master removes the
            training directory, and the flag is forwarded to
            `get_meta_filename`.

        Returns:
          None. Progress is reported through logging, summaries and exports.

        Raises:
          SystemExit: If an existing model_flags.json does not match the
            current model flags.
        """
        if self.is_master and start_new_model:
            self.remove_training_directory(self.train_dir)

        if not os.path.exists(self.train_dir):
            os.makedirs(self.train_dir)

        model_flags_dict = {
            "model": FLAGS.model,
            "feature_sizes": FLAGS.feature_sizes,
            "feature_names": FLAGS.feature_names,
            "frame_features": FLAGS.frame_features,
            "label_loss": FLAGS.label_loss,
            "netvlad_cluster_size": FLAGS.netvlad_cluster_size,
            "netvlad_hidden_size": FLAGS.netvlad_hidden_size,
            "moe_l2": FLAGS.moe_l2,
            "iterations": FLAGS.iterations,
            "netvlad_relu": FLAGS.netvlad_relu,
            "gating": FLAGS.gating,
            "moe_num_mixtures": FLAGS.moe_num_mixtures,
            "moe_prob_gating": FLAGS.moe_prob_gating,
        }
        # Use the same directory that was created above and that the
        # Supervisor writes to, so the flags file always lives beside the
        # checkpoints it describes.
        flags_json_path = os.path.join(self.train_dir, "model_flags.json")
        if os.path.exists(flags_json_path):
            with open(flags_json_path) as fin:
                existing_flags = json.load(fin)
            if existing_flags != model_flags_dict:
                logging.error(
                    "Model flags do not match existing file %s. Please "
                    "delete the file, change --train_dir, or pass flag "
                    "--start_new_model", flags_json_path)
                logging.error("Ran model with flags: %s",
                              str(model_flags_dict))
                logging.error("Previously ran with flags: %s",
                              str(existing_flags))
                # Same exit status as the site-provided exit(1), but does not
                # depend on the `site` module being loaded.
                raise SystemExit(1)
        else:
            # Write the file.
            with open(flags_json_path, "w") as fout:
                fout.write(json.dumps(model_flags_dict))

        target, device_fn = self.start_server_if_distributed()

        meta_filename = self.get_meta_filename(start_new_model, self.train_dir)

        with tf.Graph().as_default() as graph:
            if meta_filename:
                saver = self.recover_model(meta_filename)

            with tf.device(device_fn):
                if not meta_filename:
                    saver = self.build_model(self.model, self.reader)

                global_step = tf.get_collection("global_step")[0]
                loss = tf.get_collection("loss")[0]
                predictions = tf.get_collection("predictions")[0]
                labels = tf.get_collection("labels")[0]
                train_op = tf.get_collection("train_op")[0]
                init_op = tf.global_variables_initializer()

        sv = tf.train.Supervisor(graph,
                                 logdir=self.train_dir,
                                 init_op=init_op,
                                 is_chief=self.is_master,
                                 global_step=global_step,
                                 save_model_secs=15 * 60,
                                 save_summaries_secs=120,
                                 saver=saver)

        logging.info("%s: Starting managed session.",
                     task_as_string(self.task))
        with sv.managed_session(target, config=self.config) as sess:
            try:
                logging.info("%s: Entering training loop.",
                             task_as_string(self.task))
                while (not sv.should_stop()) and (not self.max_steps_reached):
                    batch_start_time = time.time()
                    _, global_step_val, loss_val, predictions_val, labels_val = sess.run(
                        [train_op, global_step, loss, predictions, labels])
                    seconds_per_batch = time.time() - batch_start_time
                    examples_per_second = labels_val.shape[
                        0] / seconds_per_batch

                    if self.max_steps and self.max_steps <= global_step_val:
                        self.max_steps_reached = True

                    if self.is_master and global_step_val % 10 == 0 and self.train_dir:
                        hit_at_one = eval_util.calculate_hit_at_one(
                            predictions_val, labels_val)
                        perr = eval_util.calculate_precision_at_equal_recall_rate(
                            predictions_val, labels_val)
                        gap = eval_util.calculate_gap(predictions_val,
                                                      labels_val)

                        logging.info("training step " + str(global_step_val) +
                                     " | Loss: " + ("%.2f" % loss_val) +
                                     " Examples/sec: " +
                                     ("%.2f" % examples_per_second) +
                                     " | Hit@1: " + ("%.2f" % hit_at_one) +
                                     " PERR: " + ("%.2f" % perr) + " GAP: " +
                                     ("%.2f" % gap))

                        sv.summary_writer.add_summary(
                            utils.MakeSummary("model/Training_Hit@1",
                                              hit_at_one), global_step_val)
                        sv.summary_writer.add_summary(
                            utils.MakeSummary("model/Training_Perr", perr),
                            global_step_val)
                        sv.summary_writer.add_summary(
                            utils.MakeSummary("model/Training_GAP", gap),
                            global_step_val)
                        sv.summary_writer.add_summary(
                            utils.MakeSummary("global_step/Examples/Second",
                                              examples_per_second),
                            global_step_val)
                        sv.summary_writer.flush()

                        # Export the model every `export_model_steps` steps
                        # (this branch is already master-only).
                        time_to_export = (
                            (self.last_model_export_step == 0)
                            or (global_step_val - self.last_model_export_step
                                >= self.export_model_steps))

                        if time_to_export:
                            self.export_model(global_step_val, sv.saver,
                                              sv.save_path, sess)
                            self.last_model_export_step = global_step_val
                    else:
                        logging.info("training step " + str(global_step_val) +
                                     " | Loss: " + ("%.2f" % loss_val) +
                                     " Examples/sec: " +
                                     ("%.2f" % examples_per_second))
            except tf.errors.OutOfRangeError:
                logging.info("%s: Done training -- epoch limit reached.",
                             task_as_string(self.task))

        logging.info("%s: Exited training loop.", task_as_string(self.task))
        sv.Stop()
Ejemplo n.º 20
0
  def run(self, start_new_model=False):
    """Performs training on the currently defined Tensorflow graph.

    Returns:
      A tuple of the training Hit@1 and the training PERR.
    """
    if self.is_master and start_new_model:
      self.remove_training_directory(self.train_dir)

    target, device_fn = self.start_server_if_distributed()

    meta_filename = self.get_meta_filename(start_new_model, self.train_dir)

    with tf.Graph().as_default() as graph:

      if meta_filename:
        saver = self.recover_model(meta_filename)

      with tf.device(device_fn):

        if not meta_filename:
          saver = self.build_model(self.model, self.reader)

        global_step = tf.get_collection("global_step")[0]
        loss = tf.get_collection("loss")[0]
        predictions = tf.get_collection("predictions")[0]
        labels = tf.get_collection("labels")[0]
        train_batch = tf.get_collection("train_batch")[0]
        train_op = tf.get_collection("train_op")[0]
        ilh = [];ilc = []
        reset_state_stackb = {}
        if FLAGS.model != 'MDLSTMCTCModel': 
            for i in range(FLAGS.layers):        
                reset_state_stackb['h{}'.format(i)]=tf.get_collection("reset_state_stackb_{}_h".format(i))[0]
                reset_state_stackb['c{}'.format(i)]=tf.get_collection("reset_state_stackb_{}_c".format(i))[0]
                #ilh.append(reset_state_stackb['h{}'.format(i)])
                #ilc.append(reset_state_stackb['c{}'.format(i)])
            reset_state_stackf = {}
            for i in range(FLAGS.layers):        
                reset_state_stackf['h{}'.format(i)]=tf.get_collection("reset_state_stackf_{}_h".format(i))[0]
                reset_state_stackf['c{}'.format(i)]=tf.get_collection("reset_state_stackf_{}_c".format(i))[0]
            final_state_stackf = {}
            flh = [];flc = []
            for i in range(FLAGS.layers):        
                final_state_stackf['h{}'.format(i)]=tf.get_collection("final_state_stackf_{}_h".format(i))[0]
                final_state_stackf['c{}'.format(i)]=tf.get_collection("final_state_stackf_{}_c".format(i))[0]
            final_state_stackb = {}
            for i in range(FLAGS.layers):        
                final_state_stackb['h{}'.format(i)]=tf.get_collection("final_state_stackb_{}_h".format(i))[0]
                final_state_stackb['c{}'.format(i)]=tf.get_collection("final_state_stackb_{}_c".format(i))[0]

        #reset_state_stackb = tf.get_collection("reset_state_stackb")[0]
        #reset_state_stackf = tf.get_collection("reset_state_stackf")[0]
        #final_state_stackb = tf.get_collection("final_state_stackb")[0]
        #final_state_stackf = tf.get_collection("final_state_stackf")[0]
        decodedPrediction = []
        for i in range(FLAGS.beam_size):
            decodedPrediction.append(tf.get_collection("decodedPrediction{}".format(i))[0])
        ler = tf.get_collection("ler")[0]
        init_op = tf.global_variables_initializer()

    sv = tf.train.Supervisor(
        graph,
        logdir=self.train_dir,
        init_op=init_op,
        is_chief=self.is_master,
        global_step=global_step,
        save_model_secs=15 * 60,
        save_summaries_secs=120,
        saver=saver)
    
    vocabulary = eval_util.read_vocab(FLAGS.vocab_path)
    vocabulary = sorted(vocabulary, key=lambda word: len(word))
    caracters = eval_util.get_characters()
    trie = eval_util.get_trie(vocabulary)
    on, bi, tr = eval_util.get_n_gram(vocabulary,29)
    def tranz(x):
        return eval_util.bi_gram_model(x, tr+0.01, bi+0.01, on)
    
    logging.info("%s: Starting managed session.", task_as_string(self.task))
    with sv.managed_session(target, config=self.config) as sess:

      try:
        
        #state_stackf = sess.run(reset_state_stackf)
        #state_stackb = sess.run(reset_state_stackb)
        logging.info("%s: Entering training loop.", task_as_string(self.task))
        while (not sv.should_stop()) and (not self.max_steps_reached):

          batch_start_time = time.time()
          _ ,global_step_val = sess.run([train_op ,global_step])
          seconds_per_batch = time.time() - batch_start_time
        
          #print(decodedPr,'decoder pr');print(labels_val,'val label')#;print(decV1,'edit dis')
          #todo: add test/evaluation here--add placeholder
        
          feed = {}
          if False:
              for i in range(FLAGS.layers):
                        feed[reset_state_stackb['h{}'.format(i)]] = state_stackb['h{}'.format(i)]
                        feed[reset_state_stackb['h{}'.format(i)]] = state_stackb['h{}'.format(i)]
              for i in range(FLAGS.layers):
                        feed[reset_state_stackf['h{}'.format(i)]] = state_stackf['h{}'.format(i)]
                        feed[reset_state_stackf['h{}'.format(i)]] = state_stackf['h{}'.format(i)]
                    
          if self.max_steps and self.max_steps <= global_step_val:
            self.max_steps_reached = True

          if self.is_master and global_step_val%FLAGS.display_step==0:
            global_step_val, loss_val, predictions_val, labels_val, labelRateError, decodedPr = sess.run(
              [ global_step, loss, predictions, labels, ler, decodedPrediction
               ],feed)
            
            feed[train_batch]=False
            global_step_val_te, loss_val_te, predictions_val_te, labels_val_te, labelRateError_te, decodedPr_te = sess.run(
              [ global_step, loss, predictions, labels, ler, decodedPrediction],feed)
            
            examples_per_second = len(labels_val) / seconds_per_batch
            
            
            if global_step_val % FLAGS.display_step_lme == 0:
                lme = 0
                #lme, newGuess = eval_util.calculate_models_error_withLanguageModel(decodedPr, 
                #                                                                   labels_val,
                #                                                                   vocabulary, 
                #                                                                   FLAGS.beam_size)
                #lme_te, newGuess_te = eval_util.calculate_models_error_withLanguageModel(decodedPr_te,
                #                                                                         labels_val_te,
                #                                                                         vocabulary, 
                #                                                                         FLAGS.beam_size)
                if False:
                    lmd_pred = eval_util.beam_search_dict(predictions_val_te, tranz,bk=30)
                    model_pred, lme = eval_util.dict_model(lmd_pred, lambda x: eval_util.trie_exist(trie,x),
                                                           labels_val_te, vocabulary=None,bk=30)
                    model_pred, err = eval_util.dict_model(eval_util.mkP(decodedPr_te), lambda x: eval_util.trie_exist(trie,x),
                                                           labels_val_te, vocabulary=None,bk=30)
                #print(predictions_val_te.shape, np.sum(predictions_val_te[0][0]))
                #for llk in range(1):
                #    print('custom beam',[eval_util.getIndex(j,caracters) for j in lmd_pred[llk][0][2] if j])
                #    #print('lme',err,[eval_util.getIndex(j,caracters) for j in model_pred[llk] if j])
                eval_util.show_prediction(decodedPr_te, labels_val_te,None,top_k=3)
                #lme_te = err
            else:
                lme,  lme_te = -1., -1.
            if False:
                eval_util.show_prediction(decodedPr, labels_val)

            logging.info(
                "%s: training step " + str(global_step_val) 
               # + "| LME: " +  ("%.2f" % lme) + "| LME-te: " +  ("%.2f" % lme_te) 
                + " ler: " +   ("%.2f" % labelRateError) + " ler-te: " +   ("%.2f" % labelRateError_te) 
                + " Loss: " +  ("%.2f" % loss_val) + " Loss-te: " + str(loss_val_te),
                task_as_string(self.task))

            sv.summary_writer.add_summary(
                utils.MakeSummary("model/labelRateError_train", labelRateError),
                global_step_val)
            sv.summary_writer.add_summary(
                utils.MakeSummary("model/labelRateError_test", labelRateError_te),
                global_step_val)
            sv.summary_writer.add_summary(
                utils.MakeSummary("model/lme_train", lme), global_step_val)
            sv.summary_writer.add_summary(
                utils.MakeSummary("model/lme_test", lme_te), global_step_val)
            sv.summary_writer.add_summary(
                utils.MakeSummary("model/loss_train", loss_val), global_step_val)
            sv.summary_writer.add_summary(
                utils.MakeSummary("model/loss_test", loss_val_te), global_step_val)
            sv.summary_writer.add_summary(
                utils.MakeSummary("global_step/Examples/Second",
                                  examples_per_second), global_step_val)
            sv.summary_writer.flush()

            # Exporting the model every x steps
            time_to_export = ((self.last_model_export_step == 0) or 
                (global_step_val - self.last_model_export_step 
                 >= self.export_model_steps))

            if self.is_master and time_to_export:
              eval_util.show_prediction(decodedPr, labels_val)
              self.export_model(global_step_val, sv.saver, sv.save_path, sess)
              self.last_model_export_step = global_step_val

        # Exporting the final model
        if self.is_master:
          eval_util.show_prediction(decodedPr, labels_val)
          self.export_model(global_step_val, sv.saver, sv.save_path, sess)

      except tf.errors.OutOfRangeError:
        logging.info("%s: Done training -- epoch limit reached.",
                     task_as_string(self.task))

    logging.info("%s: Exited training loop.", task_as_string(self.task))
    sv.Stop()