Code Example #1
File: train.py Project: lvaleriu/Youtube-8M-WILLOW
  def run(self):
    """Starts the parameter server."""

    logging.info("%s: Starting parameter server within cluster %s.",
                 task_as_string(self.task), self.cluster.as_dict())
    server = start_server(self.cluster, self.task)
    server.join()
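The start_server helper is defined elsewhere in train.py and not shown on this page. A minimal sketch of what it plausibly looks like, assuming a TF 1.x gRPC tf.train.Server:

def start_server(cluster, task):
  """Creates and starts a gRPC server for the given cluster and task."""
  if not task.type:
    raise ValueError("The task type must be specified.")
  if task.index is None:
    raise ValueError("The task index must be specified.")
  # tf.train.Server accepts a ClusterSpec directly.
  return tf.train.Server(
      cluster,
      protocol="grpc",
      job_name=task.type,
      task_index=task.index)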
Code Example #2
File: inference.py Project: vijayky88/youtube-8m
def get_input_data_tensors(reader, data_pattern, batch_size, num_readers=1):
  """Creates the section of the graph which reads the input data.

  Args:
    reader: A class which parses the input data.
    data_pattern: A 'glob' style path to the data files.
    batch_size: How many examples to process at a time.
    num_readers: How many I/O threads to use.

  Returns:
    A tuple containing the features tensor, labels tensor, and optionally a
    tensor containing the number of frames per video. The exact dimensions
    depend on the reader being used.

  Raises:
    IOError: If no files matching the given pattern were found.
  """
  with tf.name_scope("input"):
    files = gfile.Glob(data_pattern)
    if not files:
      raise IOError("Unable to find input files. data_pattern='" +
                    data_pattern + "'")
    logging.info("number of input files: " + str(len(files)))
    filename_queue = tf.train.string_input_producer(
        files, num_epochs=1, shuffle=False)
    examples_and_labels = [reader.prepare_reader(filename_queue)
                           for _ in range(num_readers)]

    video_id_batch, video_batch, unused_labels, num_frames_batch = (
        tf.train.batch_join(examples_and_labels,
                            batch_size=batch_size,
                            allow_smaller_final_batch=True,
                            enqueue_many=True))
    return video_id_batch, video_batch, num_frames_batch
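For context, the tensors returned above are queue-backed, so a consumer must initialize local variables (the num_epochs counter) and drive queue runners inside a session. A hedged usage sketch; the reader object and the data pattern are assumptions:

with tf.Graph().as_default(), tf.Session() as sess:
  video_id_batch, video_batch, num_frames_batch = get_input_data_tensors(
      reader, "features/eval*.tfrecord", batch_size=128)
  sess.run(tf.local_variables_initializer())
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(sess=sess, coord=coord)
  try:
    while not coord.should_stop():
      ids, features, frames = sess.run(
          [video_id_batch, video_batch, num_frames_batch])
  except tf.errors.OutOfRangeError:
    pass  # one full epoch has been consumed
  finally:
    coord.request_stop()
    coord.join(threads)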
Code Example #3
def main(unused_argv):
  logging.set_verbosity(tf.logging.INFO)

  if not FLAGS.json_prediction_files_pattern:
    raise ValueError(
        "The flag --json_prediction_files_pattern must be specified.")

  if not FLAGS.csv_output_file:
    raise ValueError("The flag --csv_output_file must be specified.")

  logging.info("Looking for prediction files with pattern: %s", 
               FLAGS.json_prediction_files_pattern)

  file_paths = gfile.Glob(FLAGS.json_prediction_files_pattern)  
  logging.info("Found files: %s", file_paths)

  logging.info("Writing submission file to: %s", FLAGS.csv_output_file)
  with gfile.Open(FLAGS.csv_output_file, "w+") as output_file:
    output_file.write(get_csv_header())

    for file_path in file_paths:
      logging.info("processing file: %s", file_path)

      with gfile.Open(file_path) as input_file:

        for line in input_file: 
          json_data = json.loads(line)
          output_file.write(to_csv_row(json_data))

    output_file.flush()
  logging.info("done")
Code Example #4
File: eval.py Project: vijayky88/youtube-8m
def evaluate():
  tf.set_random_seed(0)  # for reproducibility

  # Write json of flags
  model_flags_path = os.path.join(FLAGS.train_dir, "model_flags.json")
  if not file_io.file_exists(model_flags_path):
    raise IOError(("Cannot find file %s. Did you run train.py on the same "
                   "--train_dir?") % model_flags_path)
  flags_dict = json.loads(file_io.FileIO(model_flags_path, mode="r").read())

  with tf.Graph().as_default():
    # convert feature_names and feature_sizes to lists of values
    feature_names, feature_sizes = utils.GetListOfFeatureNamesAndSizes(
        flags_dict["feature_names"], flags_dict["feature_sizes"])

    if flags_dict["frame_features"]:
      reader = readers.YT8MFrameFeatureReader(feature_names=feature_names,
                                              feature_sizes=feature_sizes)
    else:
      reader = readers.YT8MAggregatedFeatureReader(feature_names=feature_names,
                                                   feature_sizes=feature_sizes)

    model = find_class_by_name(flags_dict["model"],
        [frame_level_models, video_level_models])()
    label_loss_fn = find_class_by_name(flags_dict["label_loss"], [losses])()

    if not FLAGS.eval_data_pattern:
      raise IOError("'eval_data_pattern' was not specified. "
                    "Nothing to evaluate.")

    build_graph(
        reader=reader,
        model=model,
        eval_data_pattern=FLAGS.eval_data_pattern,
        label_loss_fn=label_loss_fn,
        num_readers=FLAGS.num_readers,
        batch_size=FLAGS.batch_size)
    logging.info("built evaluation graph")
    video_id_batch = tf.get_collection("video_id_batch")[0]
    prediction_batch = tf.get_collection("predictions")[0]
    label_batch = tf.get_collection("labels")[0]
    loss = tf.get_collection("loss")[0]
    summary_op = tf.get_collection("summary_op")[0]

    saver = tf.train.Saver(tf.global_variables())
    summary_writer = tf.summary.FileWriter(
        FLAGS.train_dir, graph=tf.get_default_graph())

    evl_metrics = eval_util.EvaluationMetrics(reader.num_classes, FLAGS.top_k)

    last_global_step_val = -1
    while True:
      last_global_step_val = evaluation_loop(video_id_batch, prediction_batch,
                                             label_batch, loss, summary_op,
                                             saver, summary_writer, evl_metrics,
                                             last_global_step_val)
      if FLAGS.run_once:
        break
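find_class_by_name resolves a class name such as "LogisticModel" against a list of modules. A minimal sketch of its likely behavior:

def find_class_by_name(name, modules):
  """Searches the provided modules for the named class and returns it."""
  candidates = [getattr(module, name, None) for module in modules]
  return next(c for c in candidates if c is not None)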
Code Example #5
File: train.py Project: lvaleriu/Youtube-8M-WILLOW
 def remove_training_directory(self, train_dir):
   """Removes the training directory."""
   try:
     logging.info(
         "%s: Removing existing train directory.",
         task_as_string(self.task))
     gfile.DeleteRecursively(train_dir)
   except Exception:
     logging.error(
         "%s: Failed to delete directory %s when starting a new model. "
         "Please delete it manually and try again.",
         task_as_string(self.task), train_dir)
Code Example #6
def inference(reader, train_dir, data_pattern, out_file_location, batch_size, top_k):
  with tf.Session() as sess, gfile.Open(out_file_location, "w+") as out_file:
    video_id_batch, video_batch, num_frames_batch = get_input_data_tensors(reader, data_pattern, batch_size)
    latest_checkpoint = tf.train.latest_checkpoint(train_dir)
    if latest_checkpoint is None:
      raise Exception("unable to find a checkpoint at location: %s" % train_dir)
    else:
      if FLAGS.check_point < 0:
        meta_graph_location = latest_checkpoint + ".meta"
      else:
        meta_graph_location = FLAGS.train_dir + "/model.ckpt-" + str(FLAGS.check_point) + ".meta"
        latest_checkpoint = FLAGS.train_dir + "/model.ckpt-" + str(FLAGS.check_point)
      logging.info("loading meta-graph: " + meta_graph_location)
    saver = tf.train.import_meta_graph(meta_graph_location, clear_devices=True)
    logging.info("restoring variables from " + latest_checkpoint)
    saver.restore(sess, latest_checkpoint)
    input_tensor = tf.get_collection("input_batch_raw")[0]
    num_frames_tensor = tf.get_collection("num_frames")[0]
    predictions_tensor = tf.get_collection("predictions")[0]

    # Workaround for num_epochs issue: initializing every local variable would
    # also reset the input pipeline's epoch counters. Variables from the
    # imported "train_input" scope are therefore set to 1 instead of being
    # re-initialized, and only the remaining locals are initialized.
    def set_up_init_ops(variables):
      init_op_list = []
      for variable in list(variables):
        if "train_input" in variable.name:
          init_op_list.append(tf.assign(variable, 1))
          variables.remove(variable)
      init_op_list.append(tf.variables_initializer(variables))
      return init_op_list

    sess.run(set_up_init_ops(tf.get_collection_ref(
        tf.GraphKeys.LOCAL_VARIABLES)))

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    num_examples_processed = 0
    start_time = time.time()
    out_file.write("VideoId,LabelConfidencePairs\n")

    try:
      while not coord.should_stop():
        video_id_batch_val, video_batch_val, num_frames_batch_val = sess.run(
            [video_id_batch, video_batch, num_frames_batch])
        predictions_val, = sess.run(
            [predictions_tensor],
            feed_dict={input_tensor: video_batch_val,
                       num_frames_tensor: num_frames_batch_val})
        now = time.time()
        num_examples_processed += len(video_batch_val)
        logging.info("num examples processed: %d elapsed seconds: %.2f",
                     num_examples_processed, now - start_time)
        for line in format_lines(video_id_batch_val, predictions_val, top_k):
          out_file.write(line)
        out_file.flush()

    except tf.errors.OutOfRangeError:
        logging.info('Done with inference. The output file was written to ' + out_file_location)
    finally:
        coord.request_stop()

    coord.join(threads)
    sess.close()
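format_lines is not shown on this page; it turns a batch of predictions into Kaggle-style "VideoId,label score label score ..." rows. A sketch of a typical top-k implementation (the exact formatting is an assumption):

import numpy as np

def format_lines(video_ids, predictions, top_k):
  """Yields one CSV line of the top-k (label, score) pairs per example."""
  for i in range(len(video_ids)):
    top_indices = np.argpartition(predictions[i], -top_k)[-top_k:]
    pairs = sorted(((int(idx), predictions[i][idx]) for idx in top_indices),
                   key=lambda p: -p[1])
    yield (video_ids[i].decode("utf-8") + "," +
           " ".join("%i %g" % pair for pair in pairs) + "\n")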
Code Example #7
File: train.py Project: lvaleriu/Youtube-8M-WILLOW
  def export_model(self, global_step_val, saver, save_path, session):

    # If the model has already been exported at this step, return.
    if global_step_val == self.last_model_export_step:
      return

    last_checkpoint = saver.save(session, save_path, global_step_val)

    model_dir = "{0}/export/step_{1}".format(self.train_dir, global_step_val)
    logging.info("%s: Exporting the model at step %s to %s.",
                 task_as_string(self.task), global_step_val, model_dir)

    self.model_exporter.export_model(
        model_dir=model_dir, 
        global_step_val=global_step_val,
        last_checkpoint=last_checkpoint)
Code Example #8
File: train.py Project: lvaleriu/Youtube-8M-WILLOW
  def start_server_if_distributed(self):
    """Starts a server if the execution is distributed."""

    if self.cluster:
      logging.info("%s: Starting trainer within cluster %s.",
                   task_as_string(self.task), self.cluster.as_dict())
      server = start_server(self.cluster, self.task)
      target = server.target
      device_fn = tf.train.replica_device_setter(
          ps_device="/job:ps",
          worker_device="/job:%s/task:%d" % (self.task.type, self.task.index),
          cluster=self.cluster)
    else:
      target = ""
      device_fn = ""
    return (target, device_fn)
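With tf.train.replica_device_setter, variables built under tf.device(device_fn) land on the parameter servers while ops stay on the worker (see how run() consumes the pair in example #13). A small self-contained illustration of the placement behavior; the cluster addresses are made up:

cluster = tf.train.ClusterSpec({"ps": ["ps0:2222"],
                                "worker": ["worker0:2222"]})
device_fn = tf.train.replica_device_setter(
    ps_device="/job:ps", worker_device="/job:worker/task:0", cluster=cluster)
with tf.Graph().as_default(), tf.device(device_fn):
  v = tf.Variable(tf.zeros([10]))  # placed on /job:ps/task:0
  op = v + 1                       # placed on /job:worker/task:0
  print(v.device, op.device)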
Code Example #9
File: train.py Project: lvaleriu/Youtube-8M-WILLOW
def get_input_data_tensors(reader,
                           data_pattern,
                           batch_size=1000,
                           num_epochs=None,
                           num_readers=1):
  """Creates the section of the graph which reads the training data.

  Args:
    reader: A class which parses the training data.
    data_pattern: A 'glob' style path to the data files.
    batch_size: How many examples to process at a time.
    num_epochs: How many passes to make over the training data. Set to 'None'
                to run indefinitely.
    num_readers: How many I/O threads to use.

  Returns:
    A tuple containing the features tensor, labels tensor, and optionally a
    tensor containing the number of frames per video. The exact dimensions
    depend on the reader being used.

  Raises:
    IOError: If no files matching the given pattern were found.
  """
  logging.info("Using batch size of " + str(batch_size) + " for training.")
  with tf.name_scope("train_input"):
    files = gfile.Glob(data_pattern)
    if not files:
      raise IOError("Unable to find training files. data_pattern='" +
                    data_pattern + "'.")
    logging.info("Number of training files: %s.", str(len(files)))
    filename_queue = tf.train.string_input_producer(
        files, num_epochs=num_epochs, shuffle=True)
    training_data = [
        reader.prepare_reader(filename_queue) for _ in range(num_readers)
    ]

    # Note: capacity and min_after_dequeue are tied to FLAGS.batch_size rather
    # than the batch_size argument, so the two are expected to agree.
    return tf.train.shuffle_batch_join(
        training_data,
        batch_size=batch_size,
        capacity=FLAGS.batch_size * 5,
        min_after_dequeue=FLAGS.batch_size,
        allow_smaller_final_batch=True,
        enqueue_many=True)
Code Example #10
File: train.py Project: lvaleriu/Youtube-8M-WILLOW
def main(unused_argv):
  # Load the environment.
  env = json.loads(os.environ.get("TF_CONFIG", "{}"))

  # Load the cluster data from the environment.
  cluster_data = env.get("cluster", None)
  cluster = tf.train.ClusterSpec(cluster_data) if cluster_data else None

  # Load the task data from the environment.
  task_data = env.get("task", None) or {"type": "master", "index": 0}
  task = type("TaskSpec", (object,), task_data)

  # Logging the version.
  logging.set_verbosity(tf.logging.INFO)
  logging.info("%s: Tensorflow version: %s.",
               task_as_string(task), tf.__version__)

  # Dispatch to a master, a worker, or a parameter server.
  if not cluster or task.type == "master" or task.type == "worker":
    model = find_class_by_name(FLAGS.model,
                               [frame_level_models, video_level_models])()
    reader = get_reader()
    model_exporter = export_model.ModelExporter(
        frame_features=FLAGS.frame_features,
        model=model,
        reader=reader)
    Trainer(cluster, task, FLAGS.train_dir, model, reader, model_exporter,
            FLAGS.log_device_placement, FLAGS.max_steps,
            FLAGS.export_model_steps).run(start_new_model=FLAGS.start_new_model)
  elif task.type == "ps":
    ParameterServer(cluster, task).run()
  else:
    raise ValueError("%s: Invalid task_type: %s." %
                     (task_as_string(task), task.type))
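The dispatch above is driven entirely by the TF_CONFIG environment variable. A hypothetical value for a small three-process cluster (addresses made up):

import json
import os

os.environ["TF_CONFIG"] = json.dumps({
    "cluster": {
        "master": ["10.0.0.1:2222"],
        "worker": ["10.0.0.2:2222"],
        "ps": ["10.0.0.3:2222"],
    },
    "task": {"type": "ps", "index": 0},  # this process runs a parameter server
})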
Code Example #11
File: eval.py Project: vijayky88/youtube-8m
def get_input_evaluation_tensors(reader,
                                 data_pattern,
                                 batch_size=1024,
                                 num_readers=1):
  """Creates the section of the graph which reads the evaluation data.

  Args:
    reader: A class which parses the training data.
    data_pattern: A 'glob' style path to the data files.
    batch_size: How many examples to process at a time.
    num_readers: How many I/O threads to use.

  Returns:
    A tuple containing the features tensor, labels tensor, and optionally a
    tensor containing the number of frames per video. The exact dimensions
    depend on the reader being used.

  Raises:
    IOError: If no files matching the given pattern were found.
  """
  logging.info("Using batch size of " + str(batch_size) + " for evaluation.")
  with tf.name_scope("eval_input"):
    files = gfile.Glob(data_pattern)
    if not files:
      raise IOError("Unable to find the evaluation files.")
    logging.info("number of evaluation files: " + str(len(files)))
    filename_queue = tf.train.string_input_producer(
        files, shuffle=False, num_epochs=1)
    eval_data = [
        reader.prepare_reader(filename_queue) for _ in range(num_readers)
    ]
    return tf.train.batch_join(
        eval_data,
        batch_size=batch_size,
        capacity=3 * batch_size,
        allow_smaller_final_batch=True,
        enqueue_many=True)
Code Example #12
File: train.py Project: lvaleriu/Youtube-8M-WILLOW
 def get_meta_filename(self, start_new_model, train_dir):
   if start_new_model:
     logging.info("%s: Flag 'start_new_model' is set. Building a new model.",
                  task_as_string(self.task))
     return None
   
   latest_checkpoint = tf.train.latest_checkpoint(train_dir)
   if not latest_checkpoint: 
     logging.info("%s: No checkpoint file found. Building a new model.",
                  task_as_string(self.task))
     return None
   
   meta_filename = latest_checkpoint + ".meta"
   if not gfile.Exists(meta_filename):
     logging.info("%s: No meta graph file found. Building a new model.",
                    task_as_string(self.task))
     return None
   else:
     return meta_filename
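get_meta_filename pairs with recover_model in run() (see example #13). recover_model is not shown on this page; plausibly it is a thin wrapper around tf.train.import_meta_graph, roughly:

 def recover_model(self, meta_filename):
   logging.info("%s: Restoring from meta graph file %s",
                task_as_string(self.task), meta_filename)
   return tf.train.import_meta_graph(meta_filename)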
Code Example #13
File: train.py Project: lvaleriu/Youtube-8M-WILLOW
  def run(self, start_new_model=False):
    """Performs training on the currently defined Tensorflow graph.

    Returns:
      A tuple of the training Hit@1 and the training PERR.
    """
    if self.is_master and start_new_model:
      self.remove_training_directory(self.train_dir)

    target, device_fn = self.start_server_if_distributed()

    meta_filename = self.get_meta_filename(start_new_model, self.train_dir)

    with tf.Graph().as_default() as graph:

      if meta_filename:
        saver = self.recover_model(meta_filename)

      with tf.device(device_fn):

        if not meta_filename:
          saver = self.build_model(self.model, self.reader)

        global_step = tf.get_collection("global_step")[0]
        loss = tf.get_collection("loss")[0]
        predictions = tf.get_collection("predictions")[0]
        labels = tf.get_collection("labels")[0]
        train_op = tf.get_collection("train_op")[0]
        init_op = tf.global_variables_initializer()

    sv = tf.train.Supervisor(
        graph,
        logdir=self.train_dir,
        init_op=init_op,
        is_chief=self.is_master,
        global_step=global_step,
        save_model_secs=15 * 60,
        save_summaries_secs=120,
        saver=saver)

    logging.info("%s: Starting managed session.", task_as_string(self.task))
    with sv.managed_session(target, config=self.config) as sess:

      try:
        logging.info("%s: Entering training loop.", task_as_string(self.task))
        while (not sv.should_stop()) and (not self.max_steps_reached):

          batch_start_time = time.time()
          _, global_step_val, loss_val, predictions_val, labels_val = sess.run(
              [train_op, global_step, loss, predictions, labels])
          seconds_per_batch = time.time() - batch_start_time

          if self.max_steps and self.max_steps <= global_step_val:
            self.max_steps_reached = True

          if self.is_master:
            examples_per_second = labels_val.shape[0] / seconds_per_batch
            hit_at_one = eval_util.calculate_hit_at_one(predictions_val,
                                                        labels_val)
            perr = eval_util.calculate_precision_at_equal_recall_rate(
                predictions_val, labels_val)
            gap = eval_util.calculate_gap(predictions_val, labels_val)

            logging.info(
                "%s: training step " + str(global_step_val) + "| Hit@1: " +
                ("%.2f" % hit_at_one) + " PERR: " + ("%.2f" % perr) + " GAP: " +
                ("%.2f" % gap) + " Loss: " + str(loss_val),
                task_as_string(self.task))

            sv.summary_writer.add_summary(
                utils.MakeSummary("model/Training_Hit@1", hit_at_one),
                global_step_val)
            sv.summary_writer.add_summary(
                utils.MakeSummary("model/Training_Perr", perr), global_step_val)
            sv.summary_writer.add_summary(
                utils.MakeSummary("model/Training_GAP", gap), global_step_val)
            sv.summary_writer.add_summary(
                utils.MakeSummary("global_step/Examples/Second",
                                  examples_per_second), global_step_val)
            sv.summary_writer.flush()

            # Exporting the model every x steps
            time_to_export = ((self.last_model_export_step == 0) or 
                (global_step_val - self.last_model_export_step 
                 >= self.export_model_steps))

            if self.is_master and time_to_export:
              self.export_model(global_step_val, sv.saver, sv.save_path, sess)
              self.last_model_export_step = global_step_val

        # Exporting the final model
        if self.is_master:
          self.export_model(global_step_val, sv.saver, sv.save_path, sess)

      except tf.errors.OutOfRangeError:
        logging.info("%s: Done training -- epoch limit reached.",
                     task_as_string(self.task))

    logging.info("%s: Exited training loop.", task_as_string(self.task))
    sv.Stop()
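utils.MakeSummary, used throughout the loop above, wraps a scalar into a tf.Summary proto so it can be written without running a summary op. A minimal sketch:

def MakeSummary(name, value):
  """Creates a tf.Summary proto with a single named scalar value."""
  summary = tf.Summary()
  val = summary.value.add()
  val.tag = str(name)
  val.simple_value = float(value)
  return summary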
Code Example #14
    def run(self, start_new_model=False):
        """Performs training on the currently defined Tensorflow graph.

    Returns:
      A tuple of the training Hit@1 and the training PERR.
    """
        if self.is_master and start_new_model:
            self.remove_training_directory(self.train_dir)

        target, device_fn = self.start_server_if_distributed()

        checkpoint = self.get_latest_checkpoint(start_new_model,
                                                self.train_dir)

        with tf.Graph().as_default() as graph:

            with tf.device(device_fn):

                saver = self.build_model()

                global_step = tf.get_collection("global_step")[0]
                loss = tf.get_collection("loss")[0]
                predictions = tf.get_collection("predictions")[0]
                labels = tf.get_collection("labels")[0]
                train_op = tf.get_collection("train_op")[0]
                init_op = tf.global_variables_initializer()

                if FLAGS.dropout:
                    keep_prob_tensor = tf.get_collection("keep_prob")[0]
                if FLAGS.noise_level > 0:
                    noise_level_tensor = tf.get_collection("noise_level")[0]
                if FLAGS.reweight:
                    weights_input, weights_assignment = None, None
                    if len(tf.get_collection("weights_input")) > 0:
                        weights_input = tf.get_collection("weights_input")[0]
                        weights_assignment = tf.get_collection(
                            "weights_assignment")[0]

        sv = tf.train.Supervisor(graph,
                                 logdir=self.train_dir,
                                 init_op=init_op,
                                 is_chief=self.is_master,
                                 global_step=global_step,
                                 save_model_secs=15 * 60,
                                 save_summaries_secs=120,
                                 saver=saver)

        logging.info("%s: Starting managed session.",
                     task_as_string(self.task))
        with sv.managed_session(target, config=self.config) as sess:

            if checkpoint is not None:
                saver.restore(sess, checkpoint)

            # re-assign weights
            if FLAGS.reweight:
                optional_assign_weights(sess, weights_input,
                                        weights_assignment)

            steps = 0
            try:
                logging.info("%s: Entering training loop.",
                             task_as_string(self.task))
                while not sv.should_stop():

                    steps += 1
                    batch_start_time = time.time()
                    custom_feed = {}
                    if FLAGS.dropout:
                        custom_feed[keep_prob_tensor] = FLAGS.keep_prob
                    if FLAGS.noise_level > 0:
                        custom_feed[noise_level_tensor] = FLAGS.noise_level

                    _, global_step_val, loss_val, predictions_val, labels_val = sess.run(
                        [train_op, global_step, loss, predictions, labels],
                        feed_dict=custom_feed)
                    seconds_per_batch = time.time() - batch_start_time

                    if self.is_master:
                        examples_per_second = labels_val.shape[
                            0] / seconds_per_batch
                        hit_at_one = eval_util.calculate_hit_at_one(
                            predictions_val, labels_val)
                        perr = eval_util.calculate_precision_at_equal_recall_rate(
                            predictions_val, labels_val)
                        recall = "N/A"
                        if False:  # Recall@N computation is disabled.
                            recall = eval_util.calculate_recall_at_n(
                                predictions_val, labels_val, FLAGS.recall_at_n)
                            sv.summary_writer.add_summary(
                                utils.MakeSummary(
                                    "model/Training_Recall@%d" %
                                    FLAGS.recall_at_n, recall),
                                global_step_val)
                            recall = "%.2f" % recall
                        gap = eval_util.calculate_gap(predictions_val,
                                                      labels_val)

                        logging.info(
                            "%s: training step " + str(global_step_val) +
                            "| Hit@1: " + ("%.2f" % hit_at_one) + " PERR: " +
                            ("%.2f" % perr) + " GAP: " + ("%.2f" % gap) +
                            " Recall@%d: " % FLAGS.recall_at_n +
                            recall + " Loss: " + str(loss_val),
                            task_as_string(self.task))

                        sv.summary_writer.add_summary(
                            utils.MakeSummary("model/Training_Hit@1",
                                              hit_at_one), global_step_val)
                        sv.summary_writer.add_summary(
                            utils.MakeSummary("model/Training_Perr", perr),
                            global_step_val)
                        sv.summary_writer.add_summary(
                            utils.MakeSummary("model/Training_GAP", gap),
                            global_step_val)
                        sv.summary_writer.add_summary(
                            utils.MakeSummary("global_step/Examples/Second",
                                              examples_per_second),
                            global_step_val)
                        sv.summary_writer.flush()

                    if FLAGS.max_steps is not None and steps > FLAGS.max_steps:
                        logging.info(
                            "%s: Done training -- max_steps limit reached.",
                            task_as_string(self.task))
                        break

            except tf.errors.OutOfRangeError:
                logging.info("%s: Done training -- epoch limit reached.",
                             task_as_string(self.task))

        logging.info("%s: Exited training loop.", task_as_string(self.task))
        sv.Stop()
Code Example #15
File: train.py Project: vijayky88/youtube-8m
  def run(self, start_new_model=False):
    """Performs training on the currently defined Tensorflow graph.

    Returns:
      A tuple of the training Hit@1 and the training PERR.
    """
    if self.is_master and start_new_model:
      self.remove_training_directory(self.train_dir)

    if not os.path.exists(self.train_dir):
      os.makedirs(self.train_dir)

    model_flags_dict = {
        "model": FLAGS.model,
        "feature_sizes": FLAGS.feature_sizes,
        "feature_names": FLAGS.feature_names,
        "frame_features": FLAGS.frame_features,
        "label_loss": FLAGS.label_loss,
    }
    flags_json_path = os.path.join(FLAGS.train_dir, "model_flags.json")
    if os.path.exists(flags_json_path):
      existing_flags = json.load(open(flags_json_path))
      if existing_flags != model_flags_dict:
        logging.error("Model flags do not match existing file %s. Please "
                      "delete the file, change --train_dir, or pass flag "
                      "--start_new_model",
                      flags_json_path)
        logging.error("Ran model with flags: %s", str(model_flags_dict))
        logging.error("Previously ran with flags: %s", str(existing_flags))
        exit(1)
    else:
      # Write the file.
      with open(flags_json_path, "w") as fout:
        fout.write(json.dumps(model_flags_dict))

    target, device_fn = self.start_server_if_distributed()

    meta_filename = self.get_meta_filename(start_new_model, self.train_dir)

    with tf.Graph().as_default() as graph:
      if meta_filename:
        saver = self.recover_model(meta_filename)

      with tf.device(device_fn):
        if not meta_filename:
          saver = self.build_model(self.model, self.reader)

        global_step = tf.get_collection("global_step")[0]
        loss = tf.get_collection("loss")[0]
        predictions = tf.get_collection("predictions")[0]
        labels = tf.get_collection("labels")[0]
        train_op = tf.get_collection("train_op")[0]
        init_op = tf.global_variables_initializer()

    sv = tf.train.Supervisor(
        graph,
        logdir=self.train_dir,
        init_op=init_op,
        is_chief=self.is_master,
        global_step=global_step,
        save_model_secs=15 * 60,
        save_summaries_secs=120,
        saver=saver)

    logging.info("%s: Starting managed session.", task_as_string(self.task))
    with sv.managed_session(target, config=self.config) as sess:
      try:
        logging.info("%s: Entering training loop.", task_as_string(self.task))
        while (not sv.should_stop()) and (not self.max_steps_reached):
          batch_start_time = time.time()
          _, global_step_val, loss_val, predictions_val, labels_val = sess.run(
              [train_op, global_step, loss, predictions, labels])
          seconds_per_batch = time.time() - batch_start_time
          examples_per_second = labels_val.shape[0] / seconds_per_batch

          if self.max_steps and self.max_steps <= global_step_val:
            self.max_steps_reached = True

          if self.is_master and global_step_val % 10 == 0 and self.train_dir:
            eval_start_time = time.time()
            hit_at_one = eval_util.calculate_hit_at_one(predictions_val, labels_val)
            perr = eval_util.calculate_precision_at_equal_recall_rate(predictions_val,
                                                                      labels_val)
            gap = eval_util.calculate_gap(predictions_val, labels_val)
            eval_end_time = time.time()
            eval_time = eval_end_time - eval_start_time

            logging.info("training step " + str(global_step_val) + " | Loss: " + ("%.2f" % loss_val) +
              " Examples/sec: " + ("%.2f" % examples_per_second) + " | Hit@1: " +
              ("%.2f" % hit_at_one) + " PERR: " + ("%.2f" % perr) +
              " GAP: " + ("%.2f" % gap))

            sv.summary_writer.add_summary(
                utils.MakeSummary("model/Training_Hit@1", hit_at_one),
                global_step_val)
            sv.summary_writer.add_summary(
                utils.MakeSummary("model/Training_Perr", perr), global_step_val)
            sv.summary_writer.add_summary(
                utils.MakeSummary("model/Training_GAP", gap), global_step_val)
            sv.summary_writer.add_summary(
                utils.MakeSummary("global_step/Examples/Second",
                                  examples_per_second), global_step_val)
            sv.summary_writer.flush()

            # Exporting the model every x steps
            time_to_export = ((self.last_model_export_step == 0) or
                (global_step_val - self.last_model_export_step
                 >= self.export_model_steps))

            if self.is_master and time_to_export:
              self.export_model(global_step_val, sv.saver, sv.save_path, sess)
              self.last_model_export_step = global_step_val
          else:
            logging.info("training step " + str(global_step_val) + " | Loss: " +
              ("%.2f" % loss_val) + " Examples/sec: " + ("%.2f" % examples_per_second))
      except tf.errors.OutOfRangeError:
        logging.info("%s: Done training -- epoch limit reached.",
                     task_as_string(self.task))

    logging.info("%s: Exited training loop.", task_as_string(self.task))
    sv.Stop()
Code Example #16
def inference(video_batch_val, num_frames_batch_val, checkpoint_file,
              train_dir, out_file_location, batch_size=1, top_k=2):
  with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess, \
       gfile.Open(out_file_location, "w+") as out_file:
    
    if checkpoint_file:
      if not gfile.Exists(checkpoint_file + ".meta"):
        logging.fatal("Unable to find checkpoint file at provided location '%s'" % checkpoint_file)
      latest_checkpoint = checkpoint_file
    else:
      latest_checkpoint = tf.train.latest_checkpoint(train_dir)
    if latest_checkpoint is None:
      raise Exception("unable to find a checkpoint at location: %s" % train_dir)
    else:
      meta_graph_location = latest_checkpoint + ".meta"
      logging.info("loading meta-graph: " + meta_graph_location)
    saver = tf.train.import_meta_graph(meta_graph_location, clear_devices=True)
    logging.info("restoring variables from " + latest_checkpoint)
    saver.restore(sess, latest_checkpoint)
    input_tensor = tf.get_collection("input_batch_raw")[0]
    num_frames_tensor = tf.get_collection("num_frames")[0]
    predictions_tensor = tf.get_collection("predictions")[0]

    # Workaround for num_epochs issue.
    def set_up_init_ops(variables):
      init_op_list = []
      for variable in list(variables):
        if "train_input" in variable.name:
          init_op_list.append(tf.assign(variable, 1))
          variables.remove(variable)
      init_op_list.append(tf.variables_initializer(variables))
      return init_op_list

    sess.run(set_up_init_ops(tf.get_collection_ref(
        tf.GraphKeys.LOCAL_VARIABLES)))

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    print("Number of threads -----------------------------------"+str(len(threads))+"------------------")
    num_examples_processed = 0
    start_time = time.time()
    #out_file.write("VideoId,LabelConfidencePairs\n")

    try:
      predictions_val, = sess.run(
          [predictions_tensor],
          feed_dict={input_tensor: video_batch_val,
                     num_frames_tensor: num_frames_batch_val})
      now = time.time()
      num_examples_processed += len(video_batch_val)
      logging.info("num examples processed: %d elapsed seconds: %.2f",
                   num_examples_processed, now - start_time)
      # Dummy video id; only a single sample's predictions are expected here.
      video_id_batch_val = np.array(['1'], dtype=bytes)
      # format_lines yields one formatted prediction line per example.
      ite = format_lines(video_id_batch_val, predictions_val, top_k)
      classes = [line for line in ite]
      # Return the prediction of the first sample, ignoring any others.
      return classes[0]

    except tf.errors.OutOfRangeError:
        logging.info('Done with inference. The output file was written to ' + out_file_location)
    finally:
        coord.request_stop()

    coord.join(threads)
    sess.close()
Code Example #17
File: train.py Project: wang1ang/youtube-8m
  def run(self, start_new_model=False):
    """Performs training on the currently defined Tensorflow graph.

    Returns:
      A tuple of the training Hit@1 and the training PERR.
    """
    if self.is_master and start_new_model:
      self.remove_training_directory(self.train_dir)

    if not os.path.exists(self.train_dir):
      os.makedirs(self.train_dir)

    model_flags_dict = {
        "model": FLAGS.model,
        "feature_sizes": FLAGS.feature_sizes,
        "feature_names": FLAGS.feature_names,
        "frame_features": FLAGS.frame_features,
        "label_loss": FLAGS.label_loss,
    }
    flags_json_path = os.path.join(FLAGS.train_dir, "model_flags.json")
    if file_io.file_exists(flags_json_path):
      existing_flags = json.load(file_io.FileIO(flags_json_path, mode="r"))
      if existing_flags != model_flags_dict:
        logging.error("Model flags do not match existing file %s. Please "
                      "delete the file, change --train_dir, or pass flag "
                      "--start_new_model",
                      flags_json_path)
        logging.error("Ran model with flags: %s", str(model_flags_dict))
        logging.error("Previously ran with flags: %s", str(existing_flags))
        exit(1)
    else:
      # Write the file.
      with file_io.FileIO(flags_json_path, mode="w") as fout:
        fout.write(json.dumps(model_flags_dict))

    target, device_fn = self.start_server_if_distributed()

    meta_filename = self.get_meta_filename(start_new_model, self.train_dir)

    with tf.Graph().as_default() as graph:
      if meta_filename:
        saver = self.recover_model(meta_filename)

      with tf.device(device_fn):
        if not meta_filename:
          saver = self.build_model(self.model, self.reader)

        global_step = tf.get_collection("global_step")[0]
        loss = tf.get_collection("loss")[0]
        predictions = tf.get_collection("predictions")[0]
        labels = tf.get_collection("labels")[0]
        train_op = tf.get_collection("train_op")[0]
        init_op = tf.global_variables_initializer()

    sv = tf.train.Supervisor(
        graph,
        logdir=self.train_dir,
        init_op=init_op,
        is_chief=self.is_master,
        global_step=global_step,
        save_model_secs=120 * 60,
        save_summaries_secs=120,
        saver=saver)

    logging.info("%s: Starting managed session.", task_as_string(self.task))
    with sv.managed_session(target, config=self.config) as sess:
      try:
        logging.info("%s: Entering training loop.", task_as_string(self.task))
        while (not sv.should_stop()) and (not self.max_steps_reached):
          batch_start_time = time.time()
          _, global_step_val, loss_val, predictions_val, labels_val = sess.run(
              [train_op, global_step, loss, predictions, labels])
          seconds_per_batch = time.time() - batch_start_time
          examples_per_second = labels_val.shape[0] / seconds_per_batch

          if self.max_steps and self.max_steps <= global_step_val:
            self.max_steps_reached = True

          if self.is_master and global_step_val % 10 == 0 and self.train_dir:
            eval_start_time = time.time()
            hit_at_one = eval_util.calculate_hit_at_one(predictions_val, labels_val)
            perr = eval_util.calculate_precision_at_equal_recall_rate(predictions_val,
                                                                      labels_val)
            gap = eval_util.calculate_gap(predictions_val, labels_val)
            eval_end_time = time.time()
            eval_time = eval_end_time - eval_start_time

            logging.info("training step " + str(global_step_val) + " | Loss: " + ("%.2f" % loss_val) +
              " Examples/sec: " + ("%.2f" % examples_per_second) + " | Hit@1: " +
              ("%.2f" % hit_at_one) + " PERR: " + ("%.2f" % perr) +
              " GAP: " + ("%.2f" % gap))

            sv.summary_writer.add_summary(
                utils.MakeSummary("model/Training_Hit@1", hit_at_one),
                global_step_val)
            sv.summary_writer.add_summary(
                utils.MakeSummary("model/Training_Perr", perr), global_step_val)
            sv.summary_writer.add_summary(
                utils.MakeSummary("model/Training_GAP", gap), global_step_val)
            sv.summary_writer.add_summary(
                utils.MakeSummary("global_step/Examples/Second",
                                  examples_per_second), global_step_val)
            sv.summary_writer.flush()

            # Exporting the model every x steps
            time_to_export = ((self.last_model_export_step == 0) or
                (global_step_val - self.last_model_export_step
                 >= self.export_model_steps))

            if self.is_master and time_to_export:
              self.export_model(global_step_val, sv.saver, sv.save_path, sess)
              self.last_model_export_step = global_step_val
          else:
            logging.info("training step " + str(global_step_val) + " | Loss: " +
              ("%.2f" % loss_val) + " Examples/sec: " + ("%.2f" % examples_per_second))
      except tf.errors.OutOfRangeError:
        logging.info("%s: Done training -- epoch limit reached.",
                     task_as_string(self.task))
      self.export_model(global_step_val, sv.saver, sv.save_path, sess)
      
    logging.info("%s: Exited training loop.", task_as_string(self.task))
    sv.Stop()
Code Example #18
def evaluation_loop(video_id_batch, prediction_batch, label_batch, loss,
                    summary_op, saver, summary_writer, evl_metrics,
                    last_global_step_val):
    """Run the evaluation loop once.

  Args:
    video_id_batch: a tensor of video ids mini-batch.
    prediction_batch: a tensor of predictions mini-batch.
    label_batch: a tensor of label_batch mini-batch.
    loss: a tensor of loss for the examples in the mini-batch.
    summary_op: a tensor which runs the tensorboard summary operations.
    saver: a tensorflow saver to restore the model.
    summary_writer: a tensorflow summary_writer
    evl_metrics: an EvaluationMetrics object.
    last_global_step_val: the global step used in the previous evaluation.

  Returns:
    The global_step used in the latest model.
  """

    global_step_val = -1
    with tf.Session() as sess:
        latest_checkpoint = tf.train.latest_checkpoint(FLAGS.train_dir)
        if latest_checkpoint:
            logging.info("Loading checkpoint for eval: " + latest_checkpoint)
            # Restores from checkpoint
            saver.restore(sess, latest_checkpoint)
            # Assuming model_checkpoint_path looks something like:
            # /my-favorite-path/yt8m_train/model.ckpt-0, extract global_step from it.
            global_step_val = latest_checkpoint.split("/")[-1].split("-")[-1]
        else:
            logging.info("No checkpoint file found.")
            return global_step_val

        if global_step_val == last_global_step_val:
            logging.info(
                "skip this checkpoint global_step_val=%s "
                "(same as the previous one).", global_step_val)
            return global_step_val

        sess.run([tf.local_variables_initializer()])

        # Start the queue runners.
        fetches = [
            video_id_batch, prediction_batch, label_batch, loss, summary_op
        ]
        coord = tf.train.Coordinator()
        try:
            threads = []
            for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
                threads.extend(
                    qr.create_threads(sess,
                                      coord=coord,
                                      daemon=True,
                                      start=True))
            logging.info("enter eval_once loop global_step_val = %s. ",
                         global_step_val)

            evl_metrics.clear()

            examples_processed = 0
            while not coord.should_stop():
                batch_start_time = time.time()
                _, predictions_val, labels_val, loss_val, summary_val = sess.run(
                    fetches)
                seconds_per_batch = time.time() - batch_start_time
                example_per_second = labels_val.shape[0] / seconds_per_batch
                examples_processed += labels_val.shape[0]

                # Write this batch's predictions to a spreadsheet. Note that
                # these extra sess.run calls fetch fresh batches, so the values
                # do not correspond to predictions_val fetched above.
                res_pred = sess.run(prediction_batch)
                video_id_val = sess.run(video_id_batch)
                label_val = sess.run(label_batch)

                name = str(video_id_val) + "-" + str(
                    examples_processed) + "-" + str(label_val)

                workbook = xlsxwriter.Workbook('./' + name + '.xlsx')
                worksheet = workbook.add_worksheet()
                for col, data in enumerate(res_pred):
                    worksheet.write_column(0, col, data)
                workbook.close()

                iteration_info_dict = evl_metrics.accumulate(
                    predictions_val, labels_val, loss_val)
                iteration_info_dict["examples_per_second"] = example_per_second

                iterinfo = utils.AddGlobalStepSummary(summary_writer,
                                                      global_step_val,
                                                      iteration_info_dict,
                                                      summary_scope="Eval")
                logging.info("examples_processed: %d | %s", examples_processed,
                             iterinfo)

        except tf.errors.OutOfRangeError as e:
            logging.info(
                "Done with batched inference. Now calculating global performance "
                "metrics.")
            # calculate the metrics for the entire epoch
            epoch_info_dict = evl_metrics.get()
            epoch_info_dict["epoch_id"] = global_step_val

            summary_writer.add_summary(summary_val, global_step_val)
            epochinfo = utils.AddEpochSummary(summary_writer,
                                              global_step_val,
                                              epoch_info_dict,
                                              summary_scope="Eval")
            logging.info(epochinfo)
            evl_metrics.clear()
        except Exception as e:  # pylint: disable=broad-except
            logging.info("Unexpected exception: " + str(e))
            coord.request_stop(e)

        coord.request_stop()
        coord.join(threads, stop_grace_period_secs=10)

        return global_step_val
Code Example #19
def evaluation_loop(video_id_batch,
                    prediction_batch,
                    label_batch,
                    loss,
                    summary_op,
                    saver,
                    summary_writer,
                    evl_metrics,
                    last_global_step_val,
                    total_ens_times=1):
    """Run the evaluation loop once.

  Args:
    video_id_batch: a tensor of video ids mini-batch.
    prediction_batch: a tensor of predictions mini-batch.
    label_batch: a tensor of label_batch mini-batch.
    loss: a tensor of loss for the examples in the mini-batch.
    summary_op: a tensor which runs the tensorboard summary operations.
    saver: a tensorflow saver to restore the model.
    summary_writer: a tensorflow summary_writer
    evl_metrics: an EvaluationMetrics object.
    last_global_step_val: the global step used in the previous evaluation.

  Returns:
    The global_step used in the latest model.
  """
    global_step_val = -1
    final_preds = []

    final_df = []

    for e in range(total_ens_times):
        print("\n\nEnsemble round ", e, "\n\n")

        this_round_preds = []
        this_round_labels = []
        this_round_loss = []
        this_round_video_ids = []

        with tf.Session() as sess:
            if FLAGS.load_this_checkpoint == "":
                latest_checkpoint = get_latest_checkpoint()
            else:
                latest_checkpoint = FLAGS.train_dir + '/model.ckpt-' + FLAGS.load_this_checkpoint
            if latest_checkpoint:
                logging.info("Loading checkpoint for eval: " +
                             latest_checkpoint)
                # Restores from checkpoint
                saver.restore(sess, latest_checkpoint)
                # Assuming model_checkpoint_path looks something like:
                # /my-favorite-path/yt8m_train/model.ckpt-0, extract global_step from it.
                global_step_val = os.path.basename(latest_checkpoint).split(
                    "-")[-1]

                # Save model
                saver.save(sess,
                           os.path.join(FLAGS.train_dir, "inference_model"))
            else:
                logging.info("No checkpoint file found.")
                return global_step_val

            if global_step_val == last_global_step_val:
                logging.info(
                    "skip this checkpoint global_step_val=%s "
                    "(same as the previous one).", global_step_val)
                return global_step_val

            sess.run([tf.local_variables_initializer()])

            # Start the queue runners.
            fetches = [
                video_id_batch, prediction_batch, label_batch, loss, summary_op
            ]
            coord = tf.train.Coordinator()
            try:
                threads = []
                for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
                    threads.extend(
                        qr.create_threads(sess,
                                          coord=coord,
                                          daemon=True,
                                          start=True))
                logging.info("enter eval_once loop global_step_val = %s. ",
                             global_step_val)

                #evl_metrics.clear()

                examples_processed = 0
                while not coord.should_stop():
                    batch_start_time = time.time()
                    video_id_batch_val, predictions_val, labels_val, loss_val, summary_val = sess.run(
                        fetches)

                    this_round_preds.append(predictions_val)
                    this_round_labels.append(labels_val)
                    this_round_loss.append(loss_val)
                    this_round_video_ids.append(video_id_batch_val)

                    seconds_per_batch = time.time() - batch_start_time
                    example_per_second = labels_val.shape[0] / seconds_per_batch
                    examples_processed += labels_val.shape[0]

                    #iteration_info_dict = evl_metrics.accumulate(predictions_val,
                    #                                             labels_val, loss_val)
                    #iteration_info_dict["examples_per_second"] = example_per_second

                    #iterinfo = utils.AddGlobalStepSummary(
                    #    summary_writer,
                    #    global_step_val,
                    #    iteration_info_dict,
                    #    summary_scope="Eval")
                    #logging.info("examples_processed: %d ", examples_processed)

            except tf.errors.OutOfRangeError as e:
                logging.info(
                    "Done with batched inference for this ensemble round. ")

            except Exception as e:  # pylint: disable=broad-except
                logging.info("Unexpected exception: " + str(e))
                coord.request_stop(e)

            coord.request_stop()
            coord.join(threads, stop_grace_period_secs=10)

        temp_preds = pd.DataFrame(np.concatenate(this_round_preds, axis=0))
        temp_preds.columns = ['preds'] * 3862
        temp_labels = pd.DataFrame(np.concatenate(this_round_labels, axis=0))
        temp_labels.columns = ['labels'] * 3862
        temp_df = pd.concat([temp_labels, temp_preds], axis=1)
        temp_df.index = np.concatenate(this_round_video_ids, axis=0)

        if FLAGS.save_preds:
            temp_df.to_csv(os.path.join(FLAGS.train_dir,
                                        'eval_predictions.csv'),
                           index=True)

        final_df.append(temp_df)

        final_index = final_df[0].index
        final_labels = final_df[0]['labels'].values
        if len(final_df) == 1:
            final_preds = final_df[0]['preds'].values
        else:
            final_preds = np.concatenate([
                np.expand_dims(temp_df.loc[final_index]['preds'].values, 2)
                for temp_df in final_df
            ], axis=2)
            # Combine ensemble rounds with a harmonic mean (scipy.stats.hmean).
            final_preds = hmean(final_preds, axis=2)

        fake_loss_zeros = np.zeros((final_df[0].shape[0], 1))

        evl_metrics.clear()
        iteration_info_dict = evl_metrics.accumulate(final_preds, final_labels,
                                                     fake_loss_zeros)
        # calculate the metrics for the entire epoch
        epoch_info_dict = evl_metrics.get()
        epoch_info_dict["epoch_id"] = global_step_val

        summary_writer.add_summary(summary_val, global_step_val)
        epochinfo = utils.AddEpochSummary(summary_writer,
                                          global_step_val,
                                          epoch_info_dict,
                                          summary_scope="Eval")
        logging.info(epochinfo)
        evl_metrics.clear()

    ## end of for loop of ensembling

    ## take the mean of all preds, sorted by index of first df

    return global_step_val
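The ensemble combination above depends on scipy.stats.hmean across the stacked ensemble axis. A tiny self-contained illustration:

import numpy as np
from scipy.stats import hmean

# Two ensemble rounds of predictions for 2 videos x 3 classes.
round_a = np.array([[0.9, 0.1, 0.5], [0.2, 0.8, 0.4]])
round_b = np.array([[0.7, 0.3, 0.5], [0.4, 0.6, 0.4]])
stacked = np.stack([round_a, round_b], axis=2)  # shape (2, 3, 2)
combined = hmean(stacked, axis=2)               # harmonic mean per class
print(combined)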
Code Example #20
def get_input_evaluation_tensors(reader,
                                 data_pattern,
                                 batch_size=1024,
                                 num_readers=1):
    """Creates the section of the graph which reads the evaluation data.

  Args:
    reader: A class which parses the training data.
    data_pattern: A 'glob' style path to the data files.
    batch_size: How many examples to process at a time.
    num_readers: How many I/O threads to use.

  Returns:
    A tuple containing the features tensor, labels tensor, and optionally a
    tensor containing the number of frames per video. The exact dimensions
    depend on the reader being used.

  Raises:
    IOError: If no files matching the given pattern were found.
  """
    logging.info("Using batch size of " + str(batch_size) + " for evaluation.")
    with tf.name_scope("eval_input"):
        random.seed(888)

        # randomly chosen 60 validate files
        # note that validate file names are different on gcloud and locally, due to `curl` download command

        # gcloud
        validate_file_nums = [
            '0t', '1Y', '2J', '45', '5K', '5u', '63', '6f', '8F', '8f', '9y',
            'Ap', 'BN', 'CH', 'CI', 'Dz', 'Er', 'GY', 'I6', 'JP', 'JV', 'K0',
            'MJ', 'Mv', 'Og', 'Om', 'PL', 'QK', 'Qh', 'Ql', 'T4', 'UF', 'Uy',
            'Vo', 'X6', 'XX', 'Zq', 'aR', 'cU', 'fr', 'hw', 'k3', 'lw', 'nX',
            'nl', 'o6', 'p7', 'pL', 'pg', 'rx', 'sZ', 'sd', 'uS', 'uf', 'y1',
            'y5', 'yK', 'yU', 'z8', 'zE'
        ]

        # local
        #validate_file_nums = [
        #  '0855', '2284', '3096', '0170', '2846', '0936', '2486', '0817', '0967', '1877',
        #  '2876', '3336', '3178', '0675', '3243', '2640', '1167', '3601', '1245', '3570',
        #  '2492', '0456', '0926', '1077', '1284', '3554', '0989', '1627', '1524', '3383',
        #  '2611', '2166', '2377', '3529', '0043', '2211', '1541', '1119', '3725', '1770',
        #  '3806', '2615', '3087', '1545', '2928', '3651', '1610', '2883', '0704', '1713',
        #  '2217', '1534', '2579', '1580', '2034', '3751', '1823', '2391', '1769', '0327']

        validate_file_list_60 = [
            FLAGS.eval_data_pattern.split('*')[0] + x + '.tfrecord'
            for x in validate_file_nums
        ]
        files = validate_file_list_60

        if not files:
            raise IOError("Unable to find the evaluation files.")
        logging.info("number of evaluation files: " + str(len(files)))
        filename_queue = tf.train.string_input_producer(files,
                                                        shuffle=False,
                                                        num_epochs=1)
        eval_data = [
            reader.prepare_reader(filename_queue) for _ in range(num_readers)
        ]
        return tf.train.batch_join(eval_data,
                                   batch_size=batch_size,
                                   capacity=3 * batch_size,
                                   allow_smaller_final_batch=True,
                                   enqueue_many=True)
Code Example #21
def evaluation_loop(video_id_batch, prediction_batch, label_batch, loss,
                    summary_op, saver, summary_writer, evl_metrics,
                    last_global_step_val):
    """Run the evaluation loop once.

  Args:
    video_id_batch: a tensor of video ids mini-batch.
    prediction_batch: a tensor of predictions mini-batch.
    label_batch: a tensor of label_batch mini-batch.
    loss: a tensor of loss for the examples in the mini-batch.
    summary_op: a tensor which runs the tensorboard summary operations.
    saver: a tensorflow saver to restore the model.
    summary_writer: a tensorflow summary_writer
    evl_metrics: an EvaluationMetrics object.
    last_global_step_val: the global step used in the previous evaluation.

  Returns:
    The global_step used in the latest model.
  """

    global_step_val = -1
    config = tf.ConfigProto(device_count={'GPU': 0})

    with tf.Session(config=config) as sess:
        latest_checkpoint = get_latest_checkpoint()
        if latest_checkpoint:
            logging.info("Loading checkpoint for eval: " + latest_checkpoint)
            # Restores from checkpoint
            saver.restore(sess, latest_checkpoint)
            # Assuming model_checkpoint_path looks something like:
            # /my-favorite-path/yt8m_train/model.ckpt-0, extract global_step from it.
            global_step_val = os.path.basename(latest_checkpoint).split(
                "-")[-1].split('_')[-1]

            # Save model
            if FLAGS.force_output_model_name:
                saver.save(sess,
                           os.path.join(FLAGS.train_dir, "inference_model"),
                           write_meta_graph=False)

                selected_collections = [
                    'global_step', 'input_batch', 'input_batch_raw', 'labels',
                    'local_variables', 'loss', 'model_variables', 'num_frames',
                    'predictions', 'regularization_losses', 'summaries',
                    'summary_op', 'trainable_variables', 'variables'
                ]
                tf.train.export_meta_graph(
                    filename=os.path.join(FLAGS.train_dir,
                                          "inference_model.meta"),
                    collection_list=selected_collections)

            elif "inference_model" in FLAGS.checkpoint_file:
                if "ensemble" in FLAGS.checkpoint_file:
                    saver.save(
                        sess,
                        os.path.join(
                            FLAGS.train_dir,
                            FLAGS.checkpoint_file.replace(
                                'ensemble',
                                'ensemble_' + str(FLAGS.ensemble_wts).replace(
                                    ',', '').replace(' ', '_').replace(
                                        '.', '').replace('[', '').replace(
                                            ']', ''))))
            else:
                if "avg" not in FLAGS.checkpoint_file:
                    saver.save(
                        sess,
                        os.path.join(
                            FLAGS.train_dir, "inference_model_" +
                            latest_checkpoint.split('-')[-1]))
                else:
                    saver.save(
                        sess,
                        os.path.join(
                            FLAGS.train_dir,
                            "inference_model_" + FLAGS.checkpoint_file))
        else:
            logging.info("No checkpoint file found.")
            return global_step_val

        if global_step_val == last_global_step_val:
            logging.info(
                "skip this checkpoint global_step_val=%s "
                "(same as the previous one).", global_step_val)
            return global_step_val

        if FLAGS.create_meta_only:
            return global_step_val

        sess.run([tf.local_variables_initializer()])

        # Start the queue runners.
        fetches = [
            video_id_batch, prediction_batch, label_batch, loss, summary_op
        ]
        coord = tf.train.Coordinator()
        try:
            threads = []
            for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
                threads.extend(
                    qr.create_threads(sess,
                                      coord=coord,
                                      daemon=True,
                                      start=True))
            logging.info("enter eval_once loop global_step_val = %s. ",
                         global_step_val)

            evl_metrics.clear()

            examples_processed = 0
            while not coord.should_stop():
                batch_start_time = time.time()
                _, predictions_val, labels_val, loss_val, summary_val = sess.run(
                    fetches)
                seconds_per_batch = time.time() - batch_start_time
                example_per_second = labels_val.shape[0] / seconds_per_batch
                examples_processed += labels_val.shape[0]

                iteration_info_dict = evl_metrics.accumulate(
                    predictions_val, labels_val, loss_val)
                iteration_info_dict["examples_per_second"] = example_per_second

                iterinfo = utils.AddGlobalStepSummary(summary_writer,
                                                      global_step_val,
                                                      iteration_info_dict,
                                                      summary_scope="Eval")
                logging.info("examples_processed: %d | %s", examples_processed,
                             iterinfo)

        except tf.errors.OutOfRangeError as e:
            logging.info(
                "Done with batched inference. Now calculating global performance "
                "metrics.")
            # calculate the metrics for the entire epoch
            epoch_info_dict = evl_metrics.get()
            epoch_info_dict["epoch_id"] = global_step_val

            summary_writer.add_summary(summary_val, global_step_val)
            epochinfo = utils.AddEpochSummary(summary_writer,
                                              global_step_val,
                                              epoch_info_dict,
                                              summary_scope="Eval")
            logging.info(epochinfo)
            evl_metrics.clear()
        except Exception as e:  # pylint: disable=broad-except
            logging.info("Unexpected exception: " + str(e))
            coord.request_stop(e)

        coord.request_stop()
        coord.join(threads, stop_grace_period_secs=10)

        return global_step_val
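
A minimal sketch of the checkpoint-name parsing used above: the global step is
the suffix after the last '-' in the checkpoint basename, with a further split
on '_' to also handle names such as "inference_model_3141".

import os

def parse_global_step(checkpoint_path):
    # "/my-favorite-path/yt8m_train/model.ckpt-3141" -> "3141"
    return os.path.basename(checkpoint_path).split("-")[-1].split('_')[-1]

assert parse_global_step("/path/yt8m_train/model.ckpt-3141") == "3141"
assert parse_global_step("/path/yt8m_train/inference_model_3141") == "3141"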
Code example #22
def inference(reader, train_dir, data_pattern, out_file_location, batch_size,
              top_k):
    with tf.Session(config=tf.ConfigProto(
            allow_soft_placement=True)) as sess, gfile.Open(
                out_file_location, "w+") as out_file:
        video_id_batch, video_batch, num_frames_batch = get_input_data_tensors(
            reader, data_pattern, batch_size)
        checkpoint_file = os.path.join(FLAGS.train_dir, "inference_model")
        if not gfile.Exists(checkpoint_file + ".meta"):
            raise IOError("Cannot find %s. Did you run eval.py?" %
                          checkpoint_file)
        meta_graph_location = checkpoint_file + ".meta"
        logging.info("loading meta-graph: " + meta_graph_location)

        if FLAGS.output_model_tgz:
            with tarfile.open(FLAGS.output_model_tgz, "w:gz") as tar:
                for model_file in glob.glob(checkpoint_file + '.*'):
                    tar.add(model_file, arcname=os.path.basename(model_file))
                tar.add(os.path.join(FLAGS.train_dir, "model_flags.json"),
                        arcname="model_flags.json")
            print('Tarred model onto ' + FLAGS.output_model_tgz)
        with tf.device("/cpu:0"):
            saver = tf.train.import_meta_graph(meta_graph_location,
                                               clear_devices=True)
        logging.info("restoring variables from " + checkpoint_file)
        saver.restore(sess, checkpoint_file)
        input_tensor = tf.get_collection("input_batch_raw")[0]
        num_frames_tensor = tf.get_collection("num_frames")[0]
        predictions_tensor = tf.get_collection("predictions")[0]

        # Workaround for num_epochs issue.
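        # (Assumed rationale: string_input_producer(num_epochs=...) keeps its
        # epoch counter in a local variable. The imported meta graph still
        # contains the training pipeline's counters, so they are set to 1 to
        # keep the training queues idle, while the remaining local variables,
        # including this eval pipeline's epoch counter, are re-initialized.)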
        def set_up_init_ops(variables):
            init_op_list = []
            for variable in list(variables):
                if "train_input" in variable.name:
                    init_op_list.append(tf.assign(variable, 1))
                    variables.remove(variable)
            init_op_list.append(tf.variables_initializer(variables))
            return init_op_list

        sess.run(
            set_up_init_ops(tf.get_collection_ref(
                tf.GraphKeys.LOCAL_VARIABLES)))

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        num_examples_processed = 0
        start_time = time.time()
        out_file.write("VideoId,LabelConfidencePairs\n")

        try:
            while not coord.should_stop():
                video_id_batch_val, video_batch_val, num_frames_batch_val = sess.run(
                    [video_id_batch, video_batch, num_frames_batch])
                predictions_val, = sess.run(
                    [predictions_tensor],
                    feed_dict={
                        input_tensor: video_batch_val,
                        num_frames_tensor: num_frames_batch_val
                    })
                now = time.time()
                num_examples_processed += len(video_batch_val)
                num_classes = predictions_val.shape[1]
                logging.info("num examples processed: " +
                             str(num_examples_processed) +
                             " elapsed seconds: " +
                             "{0:.2f}".format(now - start_time))
                for line in format_lines(video_id_batch_val, predictions_val,
                                         top_k):
                    out_file.write(line)
                out_file.flush()

        except tf.errors.OutOfRangeError:
            logging.info(
                'Done with inference. The output file was written to ' +
                out_file_location)
        finally:
            coord.request_stop()

        coord.join(threads)
        sess.close()
Code example #23
File: inference.py Project: xhae/tutorial_mnist
def inference(reader, train_dir, data_pattern, out_file_location, batch_size):
    with tf.Session(config=tf.ConfigProto(
            allow_soft_placement=True)) as sess, gfile.Open(
                out_file_location, "w+") as out_file:
        image_batch = get_input_data_tensors(reader, data_pattern, batch_size)
        latest_checkpoint = tf.train.latest_checkpoint(train_dir)
        if latest_checkpoint is None:
            raise Exception("unable to find a checkpoint at location: %s" %
                            train_dir)
        else:
            meta_graph_location = latest_checkpoint + ".meta"
            logging.info("loading meta-graph: " + meta_graph_location)
        saver = tf.train.import_meta_graph(meta_graph_location,
                                           clear_devices=True)
        logging.info("restoring variables from " + latest_checkpoint)
        saver.restore(sess, latest_checkpoint)
        input_tensor = tf.get_collection("input_batch_raw")[0]
        predictions_tensor = tf.get_collection("predictions")[0]

        # Workaround for num_epochs issue.
        def set_up_init_ops(variables):
            init_op_list = []
            for variable in list(variables):
                if "train_input" in variable.name:
                    init_op_list.append(tf.assign(variable, 1))
                    variables.remove(variable)
            init_op_list.append(tf.variables_initializer(variables))
            return init_op_list

        sess.run(
            set_up_init_ops(tf.get_collection_ref(
                tf.GraphKeys.LOCAL_VARIABLES)))

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        num_examples_processed = 0
        start_time = time.time()
        out_file.write("Id,Category\n")

        try:
            line_id = 1
            while not coord.should_stop():
                image_batch_val = sess.run(image_batch)
                predictions_val = sess.run(
                    predictions_tensor,
                    feed_dict={input_tensor: image_batch_val})
                now = time.time()
                num_examples_processed += len(image_batch_val)
                num_classes = predictions_val.shape[1]
                logging.info("num examples processed: " +
                             str(num_examples_processed) +
                             " elapsed seconds: " +
                             "{0:.2f}".format(now - start_time))
                for line in format_lines(predictions_val):
                    out_file.write("%d,%s" % (line_id, line))
                    line_id += 1
                out_file.flush()

        except tf.errors.OutOfRangeError:
            logging.info(
                'Done with inference. The output file was written to ' +
                out_file_location)
        finally:
            coord.request_stop()

        coord.join(threads)
        sess.close()
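
format_lines is defined elsewhere in this project; a hypothetical version
consistent with the "Id,Category" CSV written above would emit one predicted
class per row:

import numpy as np

def format_lines(predictions):
    # predictions: [batch_size, num_classes] scores
    for row in predictions:
        yield "%d\n" % int(np.argmax(row))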
Code example #24
File: bert_finetune.py Project: charles9304/FastNN
 def save(self, saver_directory, sess, step=None):
     logging.info("Save to %s." % saver_directory)
     if step is not None:
         self.saver.save(sess, saver_directory, global_step=step)
     else:
         self.saver.save(sess, saver_directory)
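
Hypothetical usage of the helper above inside a training loop (trainer stands
in for an assumed instance of the surrounding class): passing step writes
step-suffixed checkpoints, while omitting it overwrites a single file.

trainer.save("/tmp/fastnn/model.ckpt", sess, step=1000)  # -> model.ckpt-1000
trainer.save("/tmp/fastnn/model.ckpt", sess)             # -> model.ckpt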
Code example #25
def inference(reader, model_checkpoint_path, data_pattern, out_file_location, batch_size, top_k):
  with tf.Session() as sess:
    video_id_batch, video_batch, video_label_batch, num_frames_batch = get_input_data_tensors(reader, data_pattern, batch_size)

    if model_checkpoint_path:
      meta_graph_location = model_checkpoint_path + ".meta"
      logging.info("loading meta-graph: " + meta_graph_location)
    else:
      raise Exception("unable to find a checkpoint at location: %s" % model_checkpoint_path)
    saver = tf.train.import_meta_graph(meta_graph_location, clear_devices=True)
    logging.info("restoring variables from " + model_checkpoint_path)
    saver.restore(sess, model_checkpoint_path)

    input_tensor = tf.get_collection("input_batch_raw")[0]
    num_frames_tensor = tf.get_collection("num_frames")[0]
    predictions_tensor = tf.get_collection("predictions")[0]

    # Workaround for num_epochs issue.
    def set_up_init_ops(variables):
      init_op_list = []
      for variable in list(variables):
        if "train_input" in variable.name:
          init_op_list.append(tf.assign(variable, 1))
          variables.remove(variable)
      init_op_list.append(tf.variables_initializer(variables))
      return init_op_list

    sess.run(set_up_init_ops(tf.get_collection_ref(
        tf.GraphKeys.LOCAL_VARIABLES)))

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    num_examples_processed = 0
    start_time = time.time()

    video_id = []
    video_label = []
    video_inputs = []
    video_features = []
    filenum = 0

    directory = FLAGS.output_dir
    if not os.path.exists(directory):
        os.makedirs(directory)
    else:
        raise IOError("Output path exists! path='" + directory + "'")

    try:
      while not coord.should_stop():
          video_id_batch_val, video_batch_val, video_label_batch_val, num_frames_batch_val = sess.run([video_id_batch, video_batch, video_label_batch, num_frames_batch])
          predictions = sess.run(predictions_tensor, feed_dict={input_tensor: video_batch_val, num_frames_tensor: num_frames_batch_val})
          now = time.time()
          num_examples_processed += len(video_batch_val)

          video_id.append(video_id_batch_val)
          video_label.append(video_label_batch_val)
          video_features.append(predictions)
          video_inputs.append(video_batch_val)

          if num_examples_processed>=FLAGS.file_size:
            assert num_examples_processed==FLAGS.file_size, "num_examples_processed should be equal to file_size"
            video_id = np.concatenate(video_id,axis=0)
            video_label = np.concatenate(video_label,axis=0)
            video_inputs = np.concatenate(video_inputs,axis=0)
            video_features = np.concatenate(video_features,axis=0)
            write_to_record(video_id, video_label, video_inputs, video_features, filenum, num_examples_processed)
            filenum += 1
            video_id = []
            video_label = []
            video_inputs = []
            video_features = []
            num_examples_processed = 0

          logging.info("num examples processed: " + str(num_examples_processed) + " elapsed seconds: " + "{0:.2f}".format(now-start_time))


    except tf.errors.OutOfRangeError:
        logging.info('Done with inference. Output records were written to ' + FLAGS.output_dir)
    finally:
        coord.request_stop()
        if num_examples_processed<FLAGS.file_size:
            video_id = np.concatenate(video_id,axis=0)
            video_label = np.concatenate(video_label,axis=0)
            video_inputs = np.concatenate(video_inputs,axis=0)
            video_features = np.concatenate(video_features,axis=0)
            write_to_record(video_id, video_label, video_inputs, video_features, filenum,num_examples_processed)

    coord.join(threads)
    sess.close()
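
write_to_record is not shown in this snippet. A minimal hypothetical
implementation consistent with the call sites above (TF1-style API; the
feature keys are assumptions) might look like:

import tensorflow as tf

def write_to_record(video_id, video_label, video_inputs, video_features,
                    filenum, num_examples):
    path = "%s/predictions-%04d.tfrecord" % (FLAGS.output_dir, filenum)
    with tf.python_io.TFRecordWriter(path) as writer:
        for i in range(num_examples):
            example = tf.train.Example(features=tf.train.Features(feature={
                'video_id': tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[video_id[i]])),
                'labels': tf.train.Feature(float_list=tf.train.FloatList(
                    value=video_label[i].astype('float32'))),
                'inputs': tf.train.Feature(float_list=tf.train.FloatList(
                    value=video_inputs[i].reshape(-1).astype('float32'))),
                'predictions': tf.train.Feature(float_list=tf.train.FloatList(
                    value=video_features[i].astype('float32'))),
            }))
            writer.write(example.SerializeToString())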
Code example #26
 def recover_model(self, meta_filename):
     logging.info("{}: Restoring from meta graph file {}".format(
         task_as_string(self.task), meta_filename))
     return tf.train.import_meta_graph(meta_filename,
                                       clear_devices=FLAGS.clear_devices,
                                       import_scope="tower/network_defense")
def inference(all_readers, train_dir, all_data_patterns, out_file_location,
              batch_size, top_k):
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:

        checkpoint_file = os.path.join(FLAGS.train_dir, "inference_model")
        if not gfile.Exists(checkpoint_file + ".meta"):
            raise IOError("Cannot find %s. Did you run eval.py?" %
                          checkpoint_file)
        meta_graph_location = checkpoint_file + ".meta"
        logging.info("loading meta-graph: " + meta_graph_location)

        if FLAGS.output_model_tgz:
            with tarfile.open(FLAGS.output_model_tgz, "w:gz") as tar:
                for model_file in glob.glob(checkpoint_file + '.*'):
                    tar.add(model_file, arcname=os.path.basename(model_file))
                tar.add(os.path.join(FLAGS.train_dir, "model_flags.json"),
                        arcname="model_flags.json")
            print('Tarred model onto ' + FLAGS.output_model_tgz)
        # with tf.device("/cpu:0"):
        saver = tf.train.import_meta_graph(meta_graph_location,
                                           clear_devices=True)
        logging.info("restoring variables from " + checkpoint_file)
        saver.restore(sess, checkpoint_file)

        # print('loading tfrecords')
        # video_id_batch, video_batch, labels_batch, num_frames_batch = get_input_data_tensors(reader, data_pattern, batch_size)
        # print('loaded tfrecords')

        model_input_raw_tensors = []
        labels_batch_tensor = None
        num_frames_batch_tensor = None
        video_id_batch = None

        for reader, data_pattern in zip(all_readers, all_data_patterns):
            unused_video_id, model_input_raw, labels_batch, num_frames_batch = (
                get_input_data_tensors(reader,
                                       data_pattern,
                                       batch_size=batch_size))
            if labels_batch_tensor is None:
                labels_batch_tensor = labels_batch
            if num_frames_batch_tensor is None:
                num_frames_batch_tensor = num_frames_batch
            # Keep the video id batch from the first reader only.
            if video_id_batch is None:
                video_id_batch = unused_video_id
            model_input_raw_tensors.append(
                tf.expand_dims(model_input_raw, axis=2))

        video_batch = tf.concat(model_input_raw_tensors, axis=2)
        labels_batch = labels_batch_tensor
        num_frames_batch = num_frames_batch_tensor

        input_tensor = tf.get_collection("input_batch_raw")[0]
        num_frames_tensor = tf.get_collection("num_frames")[0]
        predictions_tensor = tf.get_collection("predictions")[0]

        # Workaround for num_epochs issue.
        def set_up_init_ops(variables):
            init_op_list = []
            for variable in list(variables):
                if "train_input" in variable.name:
                    init_op_list.append(tf.assign(variable, 1))
                    variables.remove(variable)
            init_op_list.append(tf.variables_initializer(variables))
            return init_op_list

        sess.run(
            set_up_init_ops(tf.get_collection_ref(
                tf.GraphKeys.LOCAL_VARIABLES)))

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        start_time = time.time()
        video_ids = []
        video_labels = []
        video_features = []
        filenum = 0
        num_examples_processed = 0
        total_num_examples_processed = 0

        directory = FLAGS.output_dir
        if not os.path.exists(directory):
            os.makedirs(directory)
        else:
            pass
            # raise IOError("Output path exists! path='" + directory + "'")
        start_time = time.time()
        # out_file.write("VideoId,LabelConfidencePairs\n")

        try:
            while not coord.should_stop():
                video_id_batch_val, video_batch_val, labels_val, num_frames_batch_val = sess.run(
                    [
                        video_id_batch, video_batch, labels_batch,
                        num_frames_batch
                    ])
                predictions_val, = sess.run(
                    [predictions_tensor],
                    feed_dict={
                        input_tensor: video_batch_val,
                        num_frames_tensor: num_frames_batch_val
                    })
                now = time.time()
                num_classes = predictions_val.shape[1]
                logging.info("num examples processed: " +
                             str(num_examples_processed) +
                             " elapsed seconds: " +
                             "{0:.2f}".format(now - start_time))
                # for line in format_lines(video_id_batch_val, predictions_val, top_k):
                #   out_file.write(line)
                # out_file.flush()

                video_ids.append(video_id_batch_val)
                video_labels.append(labels_val)
                video_features.append(predictions_val)
                num_examples_processed += len(video_id_batch_val)

                if num_examples_processed >= FLAGS.file_size:
                    assert num_examples_processed == FLAGS.file_size, "num_examples_processed should be equal to %d" % FLAGS.file_size
                    video_ids = np.concatenate(video_ids, axis=0)
                    video_labels = np.concatenate(video_labels, axis=0)
                    video_features = np.concatenate(video_features, axis=0)
                    write_to_record(video_ids, video_labels, video_features,
                                    filenum, num_examples_processed)

                    video_ids = []
                    video_labels = []
                    video_features = []
                    filenum += 1
                    total_num_examples_processed += num_examples_processed

                    now = time.time()
                    logging.info("num examples processed: " +
                                 str(num_examples_processed) +
                                 " elapsed seconds: " +
                                 "{0:.2f}".format(now - start_time))
                    num_examples_processed = 0

        except tf.errors.OutOfRangeError:
            # Flush any remaining examples that did not fill a complete file.

            if 0 < num_examples_processed <= FLAGS.file_size:
                video_ids = np.concatenate(video_ids, axis=0)
                video_labels = np.concatenate(video_labels, axis=0)
                video_features = np.concatenate(video_features, axis=0)
                write_to_record(video_ids, video_labels, video_features,
                                filenum, num_examples_processed)
                total_num_examples_processed += num_examples_processed

                now = time.time()
                logging.info("num examples processed: " +
                             str(num_examples_processed) +
                             " elapsed seconds: " +
                             "{0:.2f}".format(now - start_time))
                num_examples_processed = 0

            logging.info("Done with inference. %d samples was written to %s" %
                         (total_num_examples_processed, FLAGS.output_dir))
        finally:
            coord.request_stop()

        coord.join(threads)
        sess.close()
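
A small numpy sketch of the feature stacking performed above: each reader's
batch gains a new "source" axis via expand_dims, and the sources are
concatenated along it (shapes here are hypothetical).

import numpy as np

a = np.random.rand(8, 300, 1024)  # reader 1: [batch, frames, features]
b = np.random.rand(8, 300, 1024)  # reader 2
stacked = np.concatenate(
    [np.expand_dims(a, 2), np.expand_dims(b, 2)], axis=2)
assert stacked.shape == (8, 300, 2, 1024)  # [batch, frames, sources, features]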
Code example #28
def log_batch(articles, abstracts):
    for i in range(len(articles)):
        article = articles[i]
        abstract = abstracts[i]
        log.info('i={}\n\narticle={}\n\nabstract={}'.format(i, repr(article), repr(abstract)))
Code example #29
def inference(reader, train_dir, data_pattern, out_file_location, batch_size,
              top_k):
    model_names = FLAGS.model
    # This graph ensembles two models, so --model must hold at least two
    # comma-separated names; build the list unconditionally so that `model`
    # is always defined below.
    model = []
    for name in model_names.split(','):
        modules = find_class_by_name(
            name, [embedding_models, video_level_models])()
        model.append(modules)
    with tf.Session(config=tf.ConfigProto(
            allow_soft_placement=True)) as sess, gfile.Open(
                out_file_location, "w+") as out_file:
        video_id_batch, model_input_raw, num_frames = get_input_data_tensors(
            reader, data_pattern, batch_size)

        feature_dim = len(model_input_raw.get_shape()) - 1

        # Normalize input features.
        model_input = tf.nn.l2_normalize(model_input_raw, feature_dim)
        # with tf.variable_scope("net1"):
        with tf.variable_scope("tower"):

            result1 = model[0].create_model(model_input,
                                            num_frames=num_frames,
                                            vocab_size=4716,
                                            is_training=False)
            #####

            result1 = tf.stop_gradient(result1)
            result2 = model[1].create_model(model_input,
                                            num_frames=num_frames,
                                            vocab_size=4716,
                                            is_training=False)
            result2 = tf.stop_gradient(result2)
            all_vars = tf.global_variables()
            # for v in all_vars:
            #   print v.name
            # for i in v_vars:
            #   logging.info(str(i))
            for i, v in enumerate(all_vars):
                logging.info(str(v.name))
                if 'rnn' in v.name:
                    vars1 = all_vars[:i]
                    vars2 = all_vars[i:]
                    break

            result1 = tf.nn.l2_normalize(result1, dim=1)
            result2 = tf.nn.l2_normalize(result2, dim=1)
            embeddings = tf.concat([result1, result2], axis=1)
            model_concat = find_class_by_name('MoeModel',
                                              [video_level_models])()
            result = model_concat.create_model(embeddings,
                                               vocab_size=4716,
                                               num_mixtures=4)
            predictions = result["predictions"]
            # predictions=(result1["predictions"]+result2["predictions"])/2
            tf.summary.histogram("model_activations", predictions)
            # if "loss" in result.keys():
            #   label_loss = result["loss"]
            # else:
            # label_loss = losses.CrossEntropyLoss().calculate_loss(predictions, labels_batch)
            # tf.summary.scalar("label_loss", label_loss)
            # if "regularization_loss" in result.keys():
            #   reg_loss = result["regularization_loss"]
            # reg_losses = tf.losses.get_regularization_losses()
            # if "regularization_loss" in result.keys():
            #   reg_loss = result["regularization_loss"]
            # else:
            #   reg_loss = tf.constant(0.0)
            # final_loss = FLAGS.regularization_penalty * reg_loss + label_loss
            #
            # optimizer = optimizer_class(learning_rate)
            # gradients = optimizer.compute_gradients(final_loss,
            #                                         colocate_gradients_with_ops=False)
            #
            # with tf.name_scope('clip_grads'):
            #   merged_gradients = utils.clip_gradient_norms(gradients, 1.0)
            # train_op = optimizer.apply_gradients(merged_gradients, global_step=global_step)
            #
            # tf.add_to_collection("global_step", global_step)
            # tf.add_to_collection("loss", label_loss)
            tf.add_to_collection("input_batch_raw", model_input_raw)
            tf.add_to_collection("predictions", predictions)
            tf.add_to_collection("video_id_batch", video_id_batch)
            tf.add_to_collection("num_frames", num_frames)
            # tf.add_to_collection("labels", tf.cast(labels_batch, tf.float32))
            tf.add_to_collection("summary_op", tf.summary.merge_all())
            # tf.add_to_collection("train_op", train_op)

            video_id_batch = tf.get_collection("video_id_batch")[0]
            # prediction_batch = tf.get_collection("predictions")[0]
            # label_batch = tf.get_collection("labels")[0]
            # loss = tf.get_collection("loss")[0]

        saver = tf.train.Saver()
        latest_checkpoint = tf.train.latest_checkpoint(train_dir)
        if latest_checkpoint is None:
            raise Exception("unable to find a checkpoint at location: %s" %
                            train_dir)
        else:
            meta_graph_location = latest_checkpoint + ".meta"
            logging.info("loading meta-graph: " + meta_graph_location)
        # saver = tf.train.import_meta_graph(meta_graph_location, clear_devices=True)
        logging.info("restoring variables from " + latest_checkpoint)
        saver.restore(sess, latest_checkpoint)
        input_tensor = tf.get_collection("input_batch_raw")[0]
        num_frames_tensor = tf.get_collection("num_frames")[0]
        predictions_tensor = tf.get_collection("predictions")[0]

        # Workaround for num_epochs issue.
        def set_up_init_ops(variables):
            init_op_list = []
            for variable in list(variables):
                logging.info("train_input:")
                logging.info(str(variable.name))
                if "train_input" in variable.name:

                    init_op_list.append(tf.assign(variable, 1))
                    variables.remove(variable)
            init_op_list.append(tf.variables_initializer(variables))
            return init_op_list

        sess.run(
            set_up_init_ops(tf.get_collection_ref(
                tf.GraphKeys.LOCAL_VARIABLES)))

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        num_examples_processed = 0
        start_time = time.time()
        out_file.write("VideoId,LabelConfidencePairs\n")

        try:
            while not coord.should_stop():
                video_id_batch_val, video_batch_val, num_frames_batch_val, predictions_val, = sess.run(
                    [
                        video_id_batch, model_input_raw, num_frames,
                        predictions_tensor
                    ])
                now = time.time()
                num_examples_processed += len(video_batch_val)
                num_classes = predictions_val.shape[1]
                logging.info("num examples processed: " +
                             str(num_examples_processed) +
                             " elapsed seconds: " +
                             "{0:.2f}".format(now - start_time))
                for line in format_lines(video_id_batch_val, predictions_val,
                                         top_k):
                    out_file.write(line)
                out_file.flush()

        except tf.errors.OutOfRangeError:
            logging.info(
                'Done with inference. The output file was written to ' +
                out_file_location)
        finally:
            coord.request_stop()

        coord.join(threads)
        sess.close()
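
A numpy sketch of the embedding fusion above: each model's output is
L2-normalized per example and the two embeddings are concatenated before the
mixture-of-experts classifier (dimensions are hypothetical).

import numpy as np

def l2norm(x, eps=1e-12):
    return x / np.maximum(np.linalg.norm(x, axis=1, keepdims=True), eps)

r1 = np.random.rand(8, 512)   # model 1 embedding
r2 = np.random.rand(8, 256)   # model 2 embedding
embeddings = np.concatenate([l2norm(r1), l2norm(r2)], axis=1)  # [8, 768]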
Code example #30
File: eval.py Project: alexsyrom/youtube-8m
def evaluate():
    tf.set_random_seed(0)  # for reproducibility

    # Read the json of flags written by train.py.
    model_flags_path = os.path.join(FLAGS.train_dir, "model_flags.json")
    if not file_io.file_exists(model_flags_path):
        raise IOError(("Cannot find file %s. Did you run train.py on the same "
                       "--train_dir?") % model_flags_path)
    flags_dict = json.loads(file_io.FileIO(model_flags_path, mode="r").read())

    with tf.Graph().as_default():
        # convert feature_names and feature_sizes to lists of values
        feature_names, feature_sizes = utils.GetListOfFeatureNamesAndSizes(
            flags_dict["feature_names"], flags_dict["feature_sizes"])

        if flags_dict["frame_features"]:
            reader = readers.YT8MFrameFeatureReader(
                feature_names=feature_names, feature_sizes=feature_sizes)
        else:
            reader = readers.YT8MAggregatedFeatureReader(
                feature_names=feature_names, feature_sizes=feature_sizes)

        model = find_class_by_name(flags_dict["model"],
                                   [frame_level_models, video_level_models])()
        label_loss_fn = find_class_by_name(flags_dict["label_loss"],
                                           [losses])()

        if not FLAGS.eval_data_pattern:
            raise IOError("'eval_data_pattern' was not specified. " +
                          "Nothing to evaluate.")

        build_graph(reader=reader,
                    model=model,
                    eval_data_pattern=FLAGS.eval_data_pattern,
                    label_loss_fn=label_loss_fn,
                    num_readers=FLAGS.num_readers,
                    batch_size=FLAGS.batch_size)
        logging.info("built evaluation graph")
        video_id_batch = tf.get_collection("video_id_batch")[0]
        prediction_batch = tf.get_collection("predictions")[0]
        label_batch = tf.get_collection("labels")[0]
        loss = tf.get_collection("loss")[0]
        summary_op = tf.get_collection("summary_op")[0]

        saver = tf.train.Saver(tf.global_variables())
        summary_writer = tf.summary.FileWriter(FLAGS.train_dir,
                                               graph=tf.get_default_graph())

        evl_metrics = eval_util.EvaluationMetrics(reader.num_classes,
                                                  FLAGS.top_k)

        last_global_step_val = -1
        while True:
            last_global_step_val = evaluation_loop(video_id_batch,
                                                   prediction_batch,
                                                   label_batch, loss,
                                                   summary_op, saver,
                                                   summary_writer, evl_metrics,
                                                   last_global_step_val)
            if FLAGS.run_once:
                break
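
find_class_by_name is defined elsewhere in this repository; a sketch of the
usual youtube-8m implementation (an assumption, not the verbatim source):

def find_class_by_name(name, modules):
    """Searches the provided modules for the named class and returns it."""
    candidates = [getattr(module, name, None) for module in modules]
    return next(c for c in candidates if c)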
Code example #31
File: train.py Project: wang1ang/youtube-8m
def build_graph(reader,
                model,
                train_data_pattern,
                label_loss_fn=losses.CrossEntropyLoss(),
                batch_size=1000,
                base_learning_rate=0.01,
                learning_rate_decay_examples=1000000,
                learning_rate_decay=0.95,
                optimizer_class=tf.train.AdamOptimizer,
                clip_gradient_norm=1.0,
                regularization_penalty=1,
                num_readers=1,
                num_epochs=None):
  """Creates the Tensorflow graph.

  This will only be called once in the life of
  a training model, because after the graph is created the model will be
  restored from a meta graph file rather than being recreated.

  Args:
    reader: The data file reader. It should inherit from BaseReader.
    model: The core model (e.g. logistic or neural net). It should inherit
           from BaseModel.
    train_data_pattern: glob path to the training data files.
    label_loss_fn: What kind of loss to apply to the model. It should inherit
                from BaseLoss.
    batch_size: How many examples to process at a time.
    base_learning_rate: What learning rate to initialize the optimizer with.
    optimizer_class: Which optimization algorithm to use.
    clip_gradient_norm: Magnitude of the gradient to clip to.
    regularization_penalty: How much weight to give the regularization loss
                            compared to the label loss.
    num_readers: How many threads to use for I/O operations.
    num_epochs: How many passes to make over the data. 'None' means an
                unlimited number of passes.
  """

  global_step = tf.Variable(0, trainable=False, name="global_step")

  local_device_protos = device_lib.list_local_devices()
  gpus = [x.name for x in local_device_protos if x.device_type == 'GPU']
  gpus = gpus[:FLAGS.num_gpu]
  num_gpus = len(gpus)

  if num_gpus > 0:
    logging.info("Using the following GPUs to train: " + str(gpus))
    num_towers = num_gpus
    device_string = '/gpu:%d'
  else:
    logging.info("No GPUs found. Training on CPU.")
    num_towers = 1
    device_string = '/cpu:%d'
  '''
  learning_rate = tf.train.exponential_decay(
      base_learning_rate,
      global_step * batch_size * num_towers,
      learning_rate_decay_examples,
      learning_rate_decay,
      staircase=True)
  tf.summary.scalar('learning_rate', learning_rate)
  '''
  learning_rate = tf.train.cosine_decay_restarts(
    base_learning_rate,
    global_step * batch_size * num_towers, # decay step, measured in examples seen
    first_decay_steps=FLAGS.begin_rate * FLAGS.max_steps * batch_size * num_towers,
    t_mul=FLAGS.t_mul, # period multiplier applied after each restart (e.g. 2.0)
    m_mul=FLAGS.m_mul, # learning-rate multiplier applied after each restart (e.g. 1.0)
    alpha=0.00001,
    name=None
  )
  optimizer = optimizer_class(learning_rate)
  unused_video_id, model_input_raw, labels_batch, num_frames = (
      get_input_data_tensors(
          reader,
          train_data_pattern,
          batch_size=batch_size * num_towers,
          num_readers=num_readers,
          num_epochs=num_epochs))
  tf.summary.histogram("model/input_raw", model_input_raw)
  feature_dim = len(model_input_raw.get_shape()) - 1

  offset = np.array([4./512] * 1024)  # + [0] * (feature_dim - 1024)
  offset = tf.constant(offset, dtype=tf.float32)

  eigen_val = tf.constant(np.sqrt(np.load("yt8m_pca/eigenvals.npy")[:1024, 0]), dtype=tf.float32)

  model_input = tf.multiply(model_input_raw - offset,  eigen_val + 1e-4) #tf.pad(eigen_val + 1e-4, [[0, 128]], constant_values=1.))
  #model_input = tf.nn.l2_normalize(model_input_raw, feature_dim)


  tower_inputs = tf.split(model_input, num_towers)
  tower_labels = tf.split(labels_batch, num_towers)
  tower_num_frames = tf.split(num_frames, num_towers)
  tower_gradients = []
  tower_predictions = []
  tower_label_losses = []
  tower_reg_losses = []
  for i in range(num_towers):
    # For some reason these 'with' statements can't be combined onto the same
    # line. They have to be nested.
    with tf.device(device_string % i):
      with (tf.variable_scope(("tower"), reuse=True if i > 0 else None)):
        with (slim.arg_scope([slim.model_variable, slim.variable], device="/cpu:0" if num_gpus!=1 else "/gpu:0")):
          result = model.create_model(
            tower_inputs[i],
            num_frames=tower_num_frames[i],
            vocab_size=reader.num_classes,
            labels=tower_labels[i])
          for variable in slim.get_model_variables():
            tf.summary.histogram(variable.op.name, variable)

          predictions = result["predictions"]
          tower_predictions.append(predictions)

          if "loss" in result.keys():
            label_loss = result["loss"]
          else:
            label_loss = label_loss_fn.calculate_loss(predictions, tower_labels[i])
            if "aux_predictions" in result.keys():
              for pred in result["aux_predictions"]:
                label_loss += label_loss_fn.calculate_loss(pred, tower_labels[i])

          if "regularization_loss" in result.keys():
            reg_loss = result["regularization_loss"]
          else:
            reg_loss = tf.constant(0.0)

          reg_losses = tf.losses.get_regularization_losses()
          if reg_losses:
            reg_loss += tf.add_n(reg_losses)

          tower_reg_losses.append(reg_loss)

          # Adds update_ops (e.g., moving average updates in batch normalization) as
          # a dependency to the train_op.
          update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
          if "update_ops" in result.keys():
            update_ops += result["update_ops"]
          if update_ops:
            with tf.control_dependencies(update_ops):
              barrier = tf.no_op(name="gradient_barrier")
              with tf.control_dependencies([barrier]):
                label_loss = tf.identity(label_loss)

          tower_label_losses.append(label_loss)

          # Incorporate the L2 weight penalties etc.
          final_loss = regularization_penalty * reg_loss + label_loss
          gradients = optimizer.compute_gradients(final_loss,
              colocate_gradients_with_ops=False)
          tower_gradients.append(gradients)
  label_loss = tf.reduce_mean(tf.stack(tower_label_losses))
  tf.summary.scalar("label_loss", label_loss)
  if regularization_penalty != 0:
    reg_loss = tf.reduce_mean(tf.stack(tower_reg_losses))
    tf.summary.scalar("reg_loss", reg_loss)
  merged_gradients = utils.combine_gradients(tower_gradients)

  if clip_gradient_norm > 0:
    with tf.name_scope('clip_grads'):
      merged_gradients = utils.clip_gradient_norms(merged_gradients, clip_gradient_norm)

  train_op = optimizer.apply_gradients(merged_gradients, global_step=global_step)

  tf.add_to_collection("global_step", global_step)
  tf.add_to_collection("loss", label_loss)
  tf.add_to_collection("predictions", tf.concat(tower_predictions, 0))
  tf.add_to_collection("input_batch_raw", model_input_raw)
  tf.add_to_collection("input_batch", model_input)
  tf.add_to_collection("num_frames", num_frames)
  tf.add_to_collection("labels", tf.cast(labels_batch, tf.float32))
  tf.add_to_collection("train_op", train_op)
Code example #32
    def run(self, start_new_model=False):
        """Performs training on the currently defined Tensorflow graph.

    Returns:
      A tuple of the training Hit@1 and the training PERR.
    """
        if self.is_master and start_new_model:
            self.remove_training_directory(self.train_dir)

        target, device_fn = self.start_server_if_distributed()

        meta_filename = self.get_meta_filename(start_new_model, self.train_dir)

        with tf.Graph().as_default() as graph:

            if meta_filename:
                saver = self.recover_model(meta_filename)

            with tf.device(device_fn):
                if not meta_filename:
                    saver = self.build_model(self.model, self.reader)

                global_step = tf.get_collection("global_step")[0]
                loss = tf.get_collection("loss")[0]
                input_batch_raw = tf.get_collection("input_batch_raw")[0]
                input_batch = tf.get_collection("input_batch")[0]
                model_input_raw_ph = tf.get_collection("model_input_raw_ph")[0]
                predictions = tf.get_collection("predictions")[0]
                labels = tf.get_collection("labels")[0]
                train_op = tf.get_collection("train_op")[0]
                num_frames = tf.get_collection("num_frames")[0]
                num_frames_ph = tf.get_collection("num_frames_ph")[0]
                learning_rate = tf.get_collection("learning_rate")[0]

                init_op = tf.global_variables_initializer()

        sv = tf.train.Supervisor(graph,
                                 logdir=self.train_dir,
                                 init_op=init_op,
                                 is_chief=self.is_master,
                                 global_step=global_step,
                                 save_model_secs=15 * 60,
                                 save_summaries_secs=120,
                                 saver=saver)

        logging.info("%s: Starting managed session.",
                     task_as_string(self.task))
        with sv.managed_session(target, config=self.config) as sess:
            try:
                logging.info("%s: Entering training loop.",
                             task_as_string(self.task))
                while (not sv.should_stop()) and (not self.max_steps_reached):
                    batch_start_time = time.time()

                    model_input_raw_val, num_frames_val, learning_rate_val = sess.run(
                        [input_batch_raw, num_frames, learning_rate])

                    pr_feature = []
                    pr_num = []

                    for i in range(model_input_raw_val.shape[0]):
                        if num_frames_val[i] / FLAGS.max_scene <= 2:
                            num_tmp = num_frames_val[i] * np.ceil(
                                1 + FLAGS.max_scene / num_frames_val[i])
                            input_tmp = model_input_raw_val[
                                i][:num_frames_val[i]]
                            input_tmp = np.repeat(
                                input_tmp,
                                np.ceil(1 +
                                        FLAGS.max_scene / num_frames_val[i]),
                                0)
                        else:
                            num_tmp = num_frames_val[i]
                            input_tmp = model_input_raw_val[i][:num_tmp]

                        numvec = (input_tmp[1:] *
                                  input_tmp[:-1]).sum(axis=1) / (np.sqrt(
                                      (input_tmp[1:]**2).sum(1)) * (np.sqrt(
                                          (input_tmp[:-1]**2).sum(1))))
                        idx = np.sort(
                            numvec.argpartition(FLAGS.max_scene -
                                                1)[:FLAGS.max_scene - 1] + 1)

                        example_splits = np.split(input_tmp, idx, 0)

                        example_splits_mean = [
                            np.mean(example_split, 0)
                            for example_split in example_splits
                        ]
                        example_splits_mean = np.stack(example_splits_mean, 0)
                        pr_num.append(FLAGS.max_scene)
                        pr_feature.append(example_splits_mean)
                    pr_feature = np.stack(pr_feature, 0)
                    pr_num = np.stack(pr_num, 0)

                    _, global_step_val, loss_val, predictions_val, labels_val = sess.run(
                        [train_op, global_step, loss, predictions, labels],
                        feed_dict={
                            model_input_raw_ph: pr_feature,
                            num_frames_ph: pr_num
                        })

                    seconds_per_batch = time.time() - batch_start_time
                    examples_per_second = labels_val.shape[
                        0] / seconds_per_batch

                    if self.max_steps and self.max_steps <= global_step_val:
                        self.max_steps_reached = True

                    if self.is_master and global_step_val % 10 == 0 and self.train_dir:
                        eval_start_time = time.time()
                        hit_at_one = eval_util.calculate_hit_at_one(
                            predictions_val, labels_val)
                        perr = eval_util.calculate_precision_at_equal_recall_rate(
                            predictions_val, labels_val)
                        gap = eval_util.calculate_gap(predictions_val,
                                                      labels_val)
                        eval_end_time = time.time()
                        eval_time = eval_end_time - eval_start_time

                        logging.info("training step " + str(global_step_val) +
                                     " | Loss: " + ("%.2f" % loss_val) +
                                     " Examples/sec: " +
                                     ("%.2f" % examples_per_second) +
                                     " | Hit@1: " + ("%.2f" % hit_at_one) +
                                     " PERR: " + ("%.2f" % perr) + " GAP: " +
                                     ("%.2f" % gap))

                        sv.summary_writer.add_summary(
                            utils.MakeSummary("model/Training_Hit@1",
                                              hit_at_one), global_step_val)
                        sv.summary_writer.add_summary(
                            utils.MakeSummary("model/Training_Perr", perr),
                            global_step_val)
                        sv.summary_writer.add_summary(
                            utils.MakeSummary("model/Training_GAP", gap),
                            global_step_val)
                        sv.summary_writer.add_summary(
                            utils.MakeSummary("global_step/Examples/Second",
                                              examples_per_second),
                            global_step_val)
                        sv.summary_writer.flush()

                        # Exporting the model every x steps
                        time_to_export = (
                            (self.last_model_export_step == 0)
                            or (global_step_val - self.last_model_export_step
                                >= self.export_model_steps))

                        if self.is_master and time_to_export:
                            self.export_model(global_step_val, sv.saver,
                                              sv.save_path, sess)
                            self.last_model_export_step = global_step_val
                    else:
                        logging.info("training step " + str(global_step_val) +
                                     " | Loss: " + ("%.2f" % loss_val) +
                                     " Examples/sec: " +
                                     ("%.2f" % examples_per_second))
            except tf.errors.OutOfRangeError:
                logging.info("%s: Done training -- epoch limit reached.",
                             task_as_string(self.task))

        logging.info("%s: Exited training loop.", task_as_string(self.task))
        sv.stop()
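
A standalone numpy sketch of the scene-splitting step above: cosine similarity
between consecutive frames is computed, the (max_scene - 1) least-similar
boundaries are kept, and frames are mean-pooled within each resulting segment
(the shapes and max_scene value below are hypothetical).

import numpy as np

def split_scenes(frames, max_scene):
    # frames: [num_frames, feature_dim]; requires num_frames > max_scene
    numvec = (frames[1:] * frames[:-1]).sum(axis=1) / (
        np.sqrt((frames[1:] ** 2).sum(1)) * np.sqrt((frames[:-1] ** 2).sum(1)))
    # indices of the (max_scene - 1) lowest-similarity frame boundaries
    idx = np.sort(numvec.argpartition(max_scene - 1)[:max_scene - 1] + 1)
    segments = np.split(frames, idx, 0)
    return np.stack([seg.mean(0) for seg in segments], 0)

pooled = split_scenes(np.random.rand(300, 1024), 25)
assert pooled.shape == (25, 1024)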
Code example #33
File: train.py Project: wang1ang/youtube-8m
 def recover_model(self, meta_filename):
   logging.info("%s: Restoring from meta graph file %s",
                task_as_string(self.task), meta_filename)
   return tf.train.import_meta_graph(meta_filename)
def evaluation_loop(video_id_batch, prediction_batch, label_batch, loss,
                    summary_op, saver, summary_writer, evl_metrics,
                    last_global_step_val):
  """Run the evaluation loop once.

  Args:
    video_id_batch: a tensor of video ids mini-batch.
    prediction_batch: a tensor of predictions mini-batch.
    label_batch: a tensor of labels mini-batch.
    loss: a tensor of loss for the examples in the mini-batch.
    summary_op: a tensor which runs the tensorboard summary operations.
    saver: a tensorflow saver to restore the model.
    summary_writer: a tensorflow summary_writer.
    evl_metrics: an EvaluationMetrics object.
    last_global_step_val: the global step used in the previous evaluation.

  Returns:
    The global_step used in the latest model.
  """

  global_step_val = -1
  with tf.Session(config=tf.ConfigProto(log_device_placement=False)) as sess:
#     latest_checkpoint = get_latest_checkpoint()
#     if latest_checkpoint:
#       logging.info("Loading checkpoint for eval: " + latest_checkpoint)
#       # Restores from checkpoint
#       saver.restore(sess, latest_checkpoint)
#       # Assuming model_checkpoint_path looks something like:
#       # /my-favorite-path/yt8m_train/model.ckpt-0, extract global_step from it.
#       global_step_val = os.path.basename(latest_checkpoint).split("-")[-1]

#       # Save model
#       saver.save(sess, os.path.join(FLAGS.train_dir, "inference_model"))
#     else:
#       logging.info("No checkpoint file found.")
#       return global_step_val

#     if global_step_val == last_global_step_val:
#       logging.info("skip this checkpoint global_step_val=%s "
#                    "(same as the previous one).", global_step_val)
#       return global_step_val

    sess.run([tf.local_variables_initializer()])

    # Start the queue runners.
    fetches = [video_id_batch, prediction_batch, label_batch, loss, summary_op]
    coord = tf.train.Coordinator()
    
    
    # output results
    start_time = time.time()
    video_ids = []
    video_labels = []
    video_features = []
    filenum = 0
    num_examples_processed = 0
    total_num_examples_processed = 0
    
    # output prediction dir
    directory = FLAGS.output_dir
    if directory != '':
      if not os.path.exists(directory):
          os.makedirs(directory)
      else:
          raise IOError("Output path exists! path='" + directory + "'")
    
    try:
      threads = []
      for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
        threads.extend(qr.create_threads(
            sess, coord=coord, daemon=True,
            start=True))
      logging.info("enter eval_once loop global_step_val = %s. ",
                   global_step_val)

      evl_metrics.clear()

      examples_processed = 0
      while not coord.should_stop():
        batch_start_time = time.time()
        ids_val, predictions_val, labels_val, loss_val, summary_val = sess.run(
            fetches)
        seconds_per_batch = time.time() - batch_start_time
        example_per_second = labels_val.shape[0] / seconds_per_batch
        examples_processed += labels_val.shape[0]

        iteration_info_dict = evl_metrics.accumulate(predictions_val,
                                                     labels_val, loss_val)
        iteration_info_dict["examples_per_second"] = example_per_second

        iterinfo = utils.AddGlobalStepSummary(
            summary_writer,
            global_step_val,
            iteration_info_dict,
            summary_scope="Eval")
        logging.info("examples_processed: %d | %s", examples_processed,
                     iterinfo)
        
        # save predictions
        if directory != '':
          video_ids.append(ids_val)
          video_labels.append(labels_val)
          video_features.append(predictions_val)
          num_examples_processed += len(ids_val)

          if num_examples_processed >= FLAGS.file_size:
            assert num_examples_processed==FLAGS.file_size, "num_examples_processed should be equal to %d"%FLAGS.file_size
            video_ids = np.concatenate(video_ids, axis=0)
            video_labels = np.concatenate(video_labels, axis=0)
            video_features = np.concatenate(video_features, axis=0)
            write_to_record(video_ids, video_labels, video_features, filenum, num_examples_processed)

            video_ids = []
            video_labels = []
            video_features = []
            filenum += 1
            total_num_examples_processed += num_examples_processed

            now = time.time()
            logging.info("num examples processed: " + str(num_examples_processed) + " elapsed seconds: " + "{0:.2f}".format(now-start_time))
            num_examples_processed = 0

    except tf.errors.OutOfRangeError as e:
      logging.info(
          "Done with batched inference. Now calculating global performance "
          "metrics.")
      # calculate the metrics for the entire epoch
      epoch_info_dict = evl_metrics.get()
      epoch_info_dict["epoch_id"] = global_step_val

      summary_writer.add_summary(summary_val, global_step_val)
      epochinfo = utils.AddEpochSummary(
          summary_writer,
          global_step_val,
          epoch_info_dict,
          summary_scope="Eval")
      logging.info(epochinfo)
      evl_metrics.clear()
      
      # save prediction
      if directory != '':
        # if ids_val is not None:
        #   video_ids.append(ids_val)
        #   video_labels.append(labels_val)
        #   video_features.append(predictions_val)
        #   num_examples_processed += len(ids_val)

        if 0 < num_examples_processed <= FLAGS.file_size:
          video_ids = np.concatenate(video_ids, axis=0)
          video_labels = np.concatenate(video_labels, axis=0)
          video_features = np.concatenate(video_features, axis=0)
          write_to_record(video_ids, video_labels, video_features, filenum, num_examples_processed)
          total_num_examples_processed += num_examples_processed

          now = time.time()
          logging.info("num examples processed: " + str(num_examples_processed) + " elapsed seconds: " + "{0:.2f}".format(now-start_time))
          num_examples_processed = 0

        logging.info("Done with inference. %d samples was written to %s" % (total_num_examples_processed, FLAGS.output_dir))
    except Exception as e:  # pylint: disable=broad-except
      logging.info("Unexpected exception: " + str(e))
      coord.request_stop(e)

    coord.request_stop()
    coord.join(threads, stop_grace_period_secs=10)

    return global_step_val
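
The loop above buffers per-batch predictions and flushes them in fixed-size shards via a write_to_record helper that the excerpt omits. A self-contained sketch of that buffering logic (flush_shard and buffered_write are hypothetical stand-ins):

import numpy as np

def flush_shard(ids, labels, preds, filenum):
    # Stand-in for the write_to_record helper, which this excerpt omits.
    np.savez("shard-%04d.npz" % filenum, ids=ids, labels=labels, preds=preds)

def buffered_write(batches, file_size):
    """Accumulate (ids, labels, preds) batches; flush every file_size rows."""
    ids_buf, labels_buf, preds_buf = [], [], []
    filenum, count = 0, 0
    for ids, labels, preds in batches:
        ids_buf.append(ids)
        labels_buf.append(labels)
        preds_buf.append(preds)
        count += len(ids)
        if count >= file_size:  # the loop above asserts count == file_size
            flush_shard(np.concatenate(ids_buf), np.concatenate(labels_buf),
                        np.concatenate(preds_buf), filenum)
            ids_buf, labels_buf, preds_buf = [], [], []
            filenum += 1
            count = 0
    if count:  # partial final shard, mirroring the OutOfRangeError branch
        flush_shard(np.concatenate(ids_buf), np.concatenate(labels_buf),
                    np.concatenate(preds_buf), filenum)
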
Code example #35
File: train.py Project: vijayky88/youtube-8m
def build_graph(reader,
                model,
                train_data_pattern,
                label_loss_fn=losses.CrossEntropyLoss(),
                batch_size=1000,
                base_learning_rate=0.01,
                learning_rate_decay_examples=1000000,
                learning_rate_decay=0.95,
                optimizer_class=tf.train.AdamOptimizer,
                clip_gradient_norm=1.0,
                regularization_penalty=1,
                num_readers=1,
                num_epochs=None):
  """Creates the Tensorflow graph.

  This will only be called once in the life of
  a training model, because after the graph is created the model will be
  restored from a meta graph file rather than being recreated.

  Args:
    reader: The data file reader. It should inherit from BaseReader.
    model: The core model (e.g. logistic or neural net). It should inherit
           from BaseModel.
    train_data_pattern: glob path to the training data files.
    label_loss_fn: What kind of loss to apply to the model. It should inherit
                from BaseLoss.
    batch_size: How many examples to process at a time.
    base_learning_rate: What learning rate to initialize the optimizer with.
    learning_rate_decay_examples: Decay the learning rate after this many
                                  examples have been seen.
    learning_rate_decay: Multiplicative decay factor for the learning rate.
    optimizer_class: Which optimization algorithm to use.
    clip_gradient_norm: Magnitude of the gradient to clip to.
    regularization_penalty: How much weight to give the regularization loss
                            compared to the label loss.
    num_readers: How many threads to use for I/O operations.
    num_epochs: How many passes to make over the data. 'None' means an
                unlimited number of passes.
  """

  global_step = tf.Variable(0, trainable=False, name="global_step")

  local_device_protos = device_lib.list_local_devices()
  gpus = [x.name for x in local_device_protos if x.device_type == 'GPU']
  gpus = gpus[:FLAGS.num_gpu]
  num_gpus = len(gpus)

  if num_gpus > 0:
    logging.info("Using the following GPUs to train: " + str(gpus))
    num_towers = num_gpus
    device_string = '/gpu:%d'
  else:
    logging.info("No GPUs found. Training on CPU.")
    num_towers = 1
    device_string = '/cpu:%d'

  learning_rate = tf.train.exponential_decay(
      base_learning_rate,
      global_step * batch_size * num_towers,
      learning_rate_decay_examples,
      learning_rate_decay,
      staircase=True)
  tf.summary.scalar('learning_rate', learning_rate)

  optimizer = optimizer_class(learning_rate)
  unused_video_id, model_input_raw, labels_batch, num_frames = (
      get_input_data_tensors(
          reader,
          train_data_pattern,
          batch_size=batch_size * num_towers,
          num_readers=num_readers,
          num_epochs=num_epochs))
  tf.summary.histogram("model/input_raw", model_input_raw)

  feature_dim = len(model_input_raw.get_shape()) - 1

  model_input = tf.nn.l2_normalize(model_input_raw, feature_dim)

  tower_inputs = tf.split(model_input, num_towers)
  tower_labels = tf.split(labels_batch, num_towers)
  tower_num_frames = tf.split(num_frames, num_towers)
  tower_gradients = []
  tower_predictions = []
  tower_label_losses = []
  tower_reg_losses = []
  for i in range(num_towers):
    # For some reason these 'with' statements can't be combined onto the same
    # line. They have to be nested.
    with tf.device(device_string % i):
      with tf.variable_scope("tower", reuse=True if i > 0 else None):
        with slim.arg_scope([slim.model_variable, slim.variable],
                            device="/cpu:0" if num_gpus != 1 else "/gpu:0"):
          result = model.create_model(
            tower_inputs[i],
            num_frames=tower_num_frames[i],
            vocab_size=reader.num_classes,
            labels=tower_labels[i])
          for variable in slim.get_model_variables():
            tf.summary.histogram(variable.op.name, variable)

          predictions = result["predictions"]
          tower_predictions.append(predictions)

          if "loss" in result.keys():
            label_loss = result["loss"]
          else:
            label_loss = label_loss_fn.calculate_loss(predictions, tower_labels[i])

          if "regularization_loss" in result.keys():
            reg_loss = result["regularization_loss"]
          else:
            reg_loss = tf.constant(0.0)

          reg_losses = tf.losses.get_regularization_losses()
          if reg_losses:
            reg_loss += tf.add_n(reg_losses)

          tower_reg_losses.append(reg_loss)

          # Adds update_ops (e.g., moving average updates in batch normalization) as
          # a dependency to the train_op.
          update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
          if "update_ops" in result.keys():
            update_ops += result["update_ops"]
          if update_ops:
            with tf.control_dependencies(update_ops):
              barrier = tf.no_op(name="gradient_barrier")
              with tf.control_dependencies([barrier]):
                label_loss = tf.identity(label_loss)

          tower_label_losses.append(label_loss)

          # Incorporate the L2 weight penalties etc.
          final_loss = regularization_penalty * reg_loss + label_loss
          gradients = optimizer.compute_gradients(final_loss,
              colocate_gradients_with_ops=False)
          tower_gradients.append(gradients)
  label_loss = tf.reduce_mean(tf.stack(tower_label_losses))
  tf.summary.scalar("label_loss", label_loss)
  if regularization_penalty != 0:
    reg_loss = tf.reduce_mean(tf.stack(tower_reg_losses))
    tf.summary.scalar("reg_loss", reg_loss)
  merged_gradients = utils.combine_gradients(tower_gradients)

  if clip_gradient_norm > 0:
    with tf.name_scope('clip_grads'):
      merged_gradients = utils.clip_gradient_norms(merged_gradients, clip_gradient_norm)

  train_op = optimizer.apply_gradients(merged_gradients, global_step=global_step)

  tf.add_to_collection("global_step", global_step)
  tf.add_to_collection("loss", label_loss)
  tf.add_to_collection("predictions", tf.concat(tower_predictions, 0))
  tf.add_to_collection("input_batch_raw", model_input_raw)
  tf.add_to_collection("input_batch", model_input)
  tf.add_to_collection("num_frames", num_frames)
  tf.add_to_collection("labels", tf.cast(labels_batch, tf.float32))
  tf.add_to_collection("train_op", train_op)
Code example #36
File: eval.py Project: vijayky88/youtube-8m
def evaluation_loop(video_id_batch, prediction_batch, label_batch, loss,
                    summary_op, saver, summary_writer, evl_metrics,
                    last_global_step_val):
  """Run the evaluation loop once.

  Args:
    video_id_batch: a tensor of video ids mini-batch.
    prediction_batch: a tensor of predictions mini-batch.
    label_batch: a tensor of label_batch mini-batch.
    loss: a tensor of loss for the examples in the mini-batch.
    summary_op: a tensor which runs the tensorboard summary operations.
    saver: a tensorflow saver to restore the model.
    summary_writer: a tensorflow summary_writer
    evl_metrics: an EvaluationMetrics object.
    last_global_step_val: the global step used in the previous evaluation.

  Returns:
    The global_step used in the latest model.
  """

  global_step_val = -1
  with tf.Session() as sess:
    latest_checkpoint = get_latest_checkpoint()
    if latest_checkpoint:
      logging.info("Loading checkpoint for eval: " + latest_checkpoint)
      # Restores from checkpoint
      saver.restore(sess, latest_checkpoint)
      # Assuming model_checkpoint_path looks something like:
      # /my-favorite-path/yt8m_train/model.ckpt-0, extract global_step from it.
      global_step_val = os.path.basename(latest_checkpoint).split("-")[-1]

      # Save model
      saver.save(sess, os.path.join(FLAGS.train_dir, "inference_model"))
    else:
      logging.info("No checkpoint file found.")
      return global_step_val

    if global_step_val == last_global_step_val:
      logging.info("skip this checkpoint global_step_val=%s "
                   "(same as the previous one).", global_step_val)
      return global_step_val

    sess.run([tf.local_variables_initializer()])

    # Start the queue runners.
    fetches = [video_id_batch, prediction_batch, label_batch, loss, summary_op]
    coord = tf.train.Coordinator()
    try:
      threads = []
      for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
        threads.extend(qr.create_threads(
            sess, coord=coord, daemon=True,
            start=True))
      logging.info("enter eval_once loop global_step_val = %s. ",
                   global_step_val)

      evl_metrics.clear()

      examples_processed = 0
      while not coord.should_stop():
        batch_start_time = time.time()
        _, predictions_val, labels_val, loss_val, summary_val = sess.run(
            fetches)
        seconds_per_batch = time.time() - batch_start_time
        example_per_second = labels_val.shape[0] / seconds_per_batch
        examples_processed += labels_val.shape[0]

        iteration_info_dict = evl_metrics.accumulate(predictions_val,
                                                     labels_val, loss_val)
        iteration_info_dict["examples_per_second"] = example_per_second

        iterinfo = utils.AddGlobalStepSummary(
            summary_writer,
            global_step_val,
            iteration_info_dict,
            summary_scope="Eval")
        logging.info("examples_processed: %d | %s", examples_processed,
                     iterinfo)

    except tf.errors.OutOfRangeError as e:
      logging.info(
          "Done with batched inference. Now calculating global performance "
          "metrics.")
      # calculate the metrics for the entire epoch
      epoch_info_dict = evl_metrics.get()
      epoch_info_dict["epoch_id"] = global_step_val

      summary_writer.add_summary(summary_val, global_step_val)
      epochinfo = utils.AddEpochSummary(
          summary_writer,
          global_step_val,
          epoch_info_dict,
          summary_scope="Eval")
      logging.info(epochinfo)
      evl_metrics.clear()
    except Exception as e:  # pylint: disable=broad-except
      logging.info("Unexpected exception: " + str(e))
      coord.request_stop(e)

    coord.request_stop()
    coord.join(threads, stop_grace_period_secs=10)

    return global_step_val
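
In this loop the global step is parsed out of the checkpoint filename, so global_step_val is a string; the guard against last_global_step_val therefore compares strings. A small illustration of the parsing (the path is an example):

import os

checkpoint_path = "/my-favorite-path/yt8m_train/model.ckpt-3575"  # example path
step = os.path.basename(checkpoint_path).split("-")[-1]
print(step)         # '3575' -- a string, not an int
print(step == -1)   # False: never equal to the initial last_global_step_val
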
Code example #37
File: inference.py Project: vijayky88/youtube-8m
def inference(reader, train_dir, data_pattern, out_file_location, batch_size, top_k):
  with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess, gfile.Open(out_file_location, "w+") as out_file:
    video_id_batch, video_batch, num_frames_batch = get_input_data_tensors(reader, data_pattern, batch_size)
    checkpoint_file = os.path.join(FLAGS.train_dir, "inference_model")
    if not gfile.Exists(checkpoint_file + ".meta"):
      raise IOError("Cannot find %s. Did you run eval.py?" % checkpoint_file)
    meta_graph_location = checkpoint_file + ".meta"
    logging.info("loading meta-graph: " + meta_graph_location)

    if FLAGS.output_model_tgz:
      with tarfile.open(FLAGS.output_model_tgz, "w:gz") as tar:
        for model_file in glob.glob(checkpoint_file + '.*'):
          tar.add(model_file, arcname=os.path.basename(model_file))
        tar.add(os.path.join(FLAGS.train_dir, "model_flags.json"),
                arcname="model_flags.json")
      print('Tarred model onto ' + FLAGS.output_model_tgz)
    with tf.device("/cpu:0"):
      saver = tf.train.import_meta_graph(meta_graph_location, clear_devices=True)
    logging.info("restoring variables from " + checkpoint_file)
    saver.restore(sess, checkpoint_file)
    input_tensor = tf.get_collection("input_batch_raw")[0]
    num_frames_tensor = tf.get_collection("num_frames")[0]
    predictions_tensor = tf.get_collection("predictions")[0]

    # Workaround for num_epochs issue.
    def set_up_init_ops(variables):
      init_op_list = []
      for variable in list(variables):
        if "train_input" in variable.name:
          init_op_list.append(tf.assign(variable, 1))
          variables.remove(variable)
      init_op_list.append(tf.variables_initializer(variables))
      return init_op_list

    sess.run(set_up_init_ops(tf.get_collection_ref(
        tf.GraphKeys.LOCAL_VARIABLES)))

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    num_examples_processed = 0
    start_time = time.time()
    out_file.write("VideoId,LabelConfidencePairs\n")

    try:
      while not coord.should_stop():
        video_id_batch_val, video_batch_val, num_frames_batch_val = sess.run(
            [video_id_batch, video_batch, num_frames_batch])
        predictions_val, = sess.run(
            [predictions_tensor],
            feed_dict={input_tensor: video_batch_val,
                       num_frames_tensor: num_frames_batch_val})
        now = time.time()
        num_examples_processed += len(video_batch_val)
        num_classes = predictions_val.shape[1]
        logging.info("num examples processed: " + str(num_examples_processed) +
                     " elapsed seconds: " + "{0:.2f}".format(now - start_time))
        for line in format_lines(video_id_batch_val, predictions_val, top_k):
          out_file.write(line)
        out_file.flush()


    except tf.errors.OutOfRangeError:
        logging.info('Done with inference. The output file was written to ' + out_file_location)
    finally:
        coord.request_stop()

    coord.join(threads)
    sess.close()
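
format_lines is referenced above but not defined in this excerpt. A plausible implementation (an assumption, based on the "VideoId,LabelConfidencePairs" header) emits the top-k classes per row:

import numpy as np

def format_lines(video_ids, predictions, top_k):
    """Yield one CSV row per video: id, then space-separated 'label score' pairs."""
    for video_id, preds in zip(video_ids, predictions):
        top_indices = np.argsort(preds)[-top_k:][::-1]  # best scores first
        pairs = " ".join("%d %f" % (i, preds[i]) for i in top_indices)
        # Video ids come back from TF as bytes, hence the decode.
        yield "%s,%s\n" % (video_id.decode("utf-8"), pairs)

ids = [b"vid1"]
preds = np.array([[0.1, 0.7, 0.2]])
for row in format_lines(ids, preds, top_k=2):
    print(row.strip())  # vid1,1 0.700000 2 0.200000
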
Code example #38
File: testYT8mRR.py Project: bullud/testTFRecord
def main():
    env = json.loads(os.environ.get("TF_CONFIG", "{}"))

    task_data = env.get("task", None) or {"type": "master", "index": 0}
    task = type("TaskSpec", (object, ), task_data)

    logging.set_verbosity(tf.logging.INFO)
    logging.info("%s: Tensorflow version: %s.", task_as_string(task),
                 tf.__version__)

    video_ids, video_features, video_labels, video_frames = gen_input(
        data_pattern,
        reader_batch_size=reader_batch_size,
        num_classes=num_classes,
        num_readers=num_readers,
        mini_batch_size=mini_batch_size)

    result = gen_model(model_input=video_features,
                       vocab_size=num_classes,
                       labels=video_labels,
                       num_frames=video_frames)

    predictions = result["predictions"]

    global_step = tf.Variable(0, trainable=False, name="global_step")

    label_loss = label_loss_fn.calculate_loss(predictions, video_labels)

    if "regularization_loss" in result.keys():
        reg_loss = result["regularization_loss"]
    else:
        reg_loss = tf.constant(0.0)

    reg_losses = tf.losses.get_regularization_losses()
    if reg_losses:
        reg_loss += tf.add_n(reg_losses)

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    if "update_ops" in result.keys():
        update_ops += result["update_ops"]

    if update_ops:
        with tf.control_dependencies(update_ops):
            barrier = tf.no_op(name="gradient_barrier")
            with tf.control_dependencies([barrier]):
                label_loss = tf.identity(label_loss)

    final_loss = regularization_penalty * reg_loss + label_loss

    learning_rate = tf.train.exponential_decay(base_learning_rate,
                                               global_step * mini_batch_size *
                                               num_towers,
                                               learning_rate_decay_examples,
                                               learning_rate_decay,
                                               staircase=True)

    tf.summary.scalar('learning_rate', learning_rate)

    optimizer = optimizer_class(learning_rate)

    gradients = optimizer.compute_gradients(final_loss,
                                            colocate_gradients_with_ops=False)

    tf.summary.scalar("label_loss", label_loss)

    tf.summary.scalar("reg_loss", reg_loss)

    if clip_gradient_norm > 0:
        with tf.name_scope('clip_grads'):
            gradients = utils.clip_gradient_norms(gradients,
                                                  clip_gradient_norm)

    train_op = optimizer.apply_gradients(gradients, global_step=global_step)

    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        tf.local_variables_initializer().run()

        #init_local_op = tf.local_variables_initializer()
        #sess.run(init_local_op)

        coord = tf.train.Coordinator()

        threads = tf.train.start_queue_runners(coord=coord)

        total_step = 0

        # Build the cast op once; calling tf.cast inside the training loop
        # would add a new node to the graph on every iteration.
        labels_float = tf.cast(video_labels, tf.float32)

        try:
            while total_step < 100000:
                batch_start_time = time.time()

                # v_ids, v_features, v_labels, v_frames = sess.run([video_ids, video_features, video_labels, video_frames])

                _, global_step_val, loss_val, predictions_val, labels_val = sess.run(
                    [
                        train_op, global_step, label_loss, predictions,
                        labels_float
                    ])

                seconds_per_batch = time.time() - batch_start_time
                examples_per_second = labels_val.shape[0] / seconds_per_batch

                # if max_steps <= global_step_val:
                #    max_steps_reached = True
                # print(v_features.shape)
                # print(v_ids)

                if total_step % 10 == 0:
                    eval_start_time = time.time()
                    hit_at_one = eval_util.calculate_hit_at_one(
                        predictions_val, labels_val)
                    perr = eval_util.calculate_precision_at_equal_recall_rate(
                        predictions_val, labels_val)
                    gap = eval_util.calculate_gap(predictions_val, labels_val)
                    eval_end_time = time.time()
                    eval_time = eval_end_time - eval_start_time

                    logging.info("training step " + str(global_step_val) +
                                 " | Loss: " + ("%.2f" % loss_val) +
                                 " Examples/sec: " +
                                 ("%.2f" % examples_per_second) +
                                 " | Hit@1: " + ("%.2f" % hit_at_one) +
                                 " PERR: " + ("%.2f" % perr) + " GAP: " +
                                 ("%.2f" % gap))

                else:
                    logging.info("training step " + str(global_step_val) +
                                 " | Loss: " + ("%.2f" % loss_val) +
                                 " Examples/sec: " +
                                 ("%.2f" % examples_per_second))

                total_step = total_step + 1

        except tf.errors.OutOfRangeError:
            logging.info("%s: Done training -- epoch limit reached.",
                         task_as_string(task))

        coord.request_stop()

        coord.join(threads)
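
The schedule above decays the learning rate by learning_rate_decay once per learning_rate_decay_examples examples seen (global_step * mini_batch_size * num_towers). The same staircase rule in plain Python, for intuition (the numbers are illustrative):

def staircase_lr(base_lr, step, batch_size, num_towers,
                 decay_examples, decay_rate):
    examples_seen = step * batch_size * num_towers
    return base_lr * decay_rate ** (examples_seen // decay_examples)

# e.g. base_lr=0.01, batch=1024, one tower, decay 0.95 every 1M examples:
for step in (0, 1000, 2000):
    print(step, staircase_lr(0.01, step, 1024, 1, 1000000, 0.95))
# 0 0.01 / 1000 0.0095 / 2000 0.009025
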
Code example #39
    def run(self, start_new_model=False):
        """Performs training on the currently defined Tensorflow graph.

        Returns:
          A tuple of the training Hit@1 and the training PERR.
        """

        meta_filename = self.get_meta_filename(start_new_model, self.train_dir)

        with tf.Graph().as_default() as graph:

            if meta_filename:
                saver = self.recover_model(meta_filename)
            else:
                saver = self.build_model(self.model, self.reader)

            global_step = tf.get_collection("global_step")[0]
            train_top_loss = tf.get_collection("train_loss")[0]
            test_top_loss = tf.get_collection("test_top_loss")[0]
            train_predictions = tf.get_collection("train_predictions")[0]
            test_predictions = tf.get_collection("test_predictions")[0]
            train_top_labels = tf.get_collection("train_top_labels")[0]
            test_top_labels = tf.get_collection("test_top_labels")[0]
            train_op = tf.get_collection("train_op")[0]
            init_op = tf.global_variables_initializer()

        logging.info("%s: Starting session.", self.task)
        with tf.Session(config=self.config) as sess:

            try:
                logging.info("Entering training loop.")
                while not self.max_steps_reached:

                    batch_start_time = time.time()
                    _, global_step_val, train_top_loss_val, train_predictions_val, train_top_labels_val = sess.run(
                        [train_op, global_step, train_top_loss, train_predictions, train_top_labels])
                    seconds_per_batch = time.time() - batch_start_time
                    if self.max_steps and self.max_steps <= global_step_val:
                        self.max_steps_reached = True

                    if global_step_val % 10 == 0:
                        train_metric = utils.CalMetric(train_predictions_val, train_top_labels_val)
                        logging.info("%s: training step " + str(global_step_val) + " top_precision: " + (
                                    "%.4f" % train_metric['top_precision']) + " top_recall: " + (
                                                 "%.4f" % train_metric['top_recall']) + " top_f1_score: " + (
                                                 "%.4f" % train_metric['top_f1_score']) + " top_loss: " + str(
                            train_top_loss_val) + " cost: " + str(
                            seconds_per_batch), self.task)

                        if global_step_val % FLAGS.test_interval == 0:
                            batch_start_time = time.time()
                            global_step_val, test_top_loss_val, test_predictions_val, test_top_labels_val = sess.run(
                                [global_step, test_top_loss, test_predictions, test_top_labels])
                            seconds_per_batch = time.time() - batch_start_time
                            test_metric = utils.CalMetric(test_predictions_val, test_top_labels_val)

                            logging.info("%s: training step " + str(global_step_val) + \
                                         " test_top_precision: " + ("%.4f" % test_metric['top_precision']) + \
                                         " test_top_recall: " + ("%.4f" % test_metric['top_recall']) + \
                                         "test_top_f1_score: " + ("%.4f" % test_metric['top_f1_score']) + \
                                         " top_loss: " + str(train_top_loss_val) + " cost: " + str(seconds_per_batch),
                                         self.task)

                            time_to_export = (
                                (test_metric['top_f1_score'] - self.last_model_eval_precision) /
                                (self.last_model_eval_precision + 0.0000001) > 0.001
                                and np.abs(test_metric['top_f1_score'] - self.last_model_eval_precision) > 0.1
                                and np.abs(test_metric['top_f1_score']) > 0.65)

                            if time_to_export:
                                saver.save(sess, self.train_dir, global_step=global_step_val)
                                self.last_model_export_step = global_step_val

            except tf.errors.OutOfRangeError:
                logging.info("%s: Done training -- epoch limit reached.",
                             self.task)

        logging.info("%s: Exited training loop.", self.task)
Code example #40
File: train.py Project: slai11/youtube-8m
    def run(self, start_new_model=False):
        """Performs training on the currently defined Tensorflow graph.

    Returns:
      A tuple of the training Hit@1 and the training PERR.
    """
        if self.is_master and start_new_model:
            self.remove_training_directory(self.train_dir)

        target, device_fn = self.start_server_if_distributed()

        meta_filename = self.get_meta_filename(start_new_model, self.train_dir)

        with tf.Graph().as_default() as graph:

            if meta_filename:
                saver = self.recover_model(meta_filename)

            with tf.device(device_fn):
                if not meta_filename:
                    saver = self.build_model(self.model, self.reader)

                global_step = tf.get_collection("global_step")[0]
                loss = tf.get_collection("loss")[0]
                predictions = tf.get_collection("predictions")[0]
                labels = tf.get_collection("labels")[0]
                train_op = tf.get_collection("train_op")[0]
                init_op = tf.global_variables_initializer()

        sv = tf.train.Supervisor(graph,
                                 logdir=self.train_dir,
                                 init_op=init_op,
                                 is_chief=self.is_master,
                                 global_step=global_step,
                                 save_model_secs=15 * 60,
                                 save_summaries_secs=120,
                                 saver=saver)

        logging.info("%s: Starting managed session.",
                     task_as_string(self.task))
        with sv.managed_session(target, config=self.config) as sess:
            #writer = tf.summary.FileWriter('model/', sess.graph)
            try:
                logging.info("%s: Entering training loop.",
                             task_as_string(self.task))
                while (not sv.should_stop()) and (not self.max_steps_reached):
                    batch_start_time = time.time()
                    _, global_step_val, loss_val, predictions_val, labels_val = sess.run(
                        [train_op, global_step, loss, predictions, labels])

                    seconds_per_batch = time.time() - batch_start_time
                    examples_per_second = labels_val.shape[
                        0] / seconds_per_batch

                    if self.max_steps and self.max_steps <= global_step_val:
                        logging("MAX STEP REACHED")
                        self.max_steps_reached = True

                    if self.is_master and global_step_val % 5 == 0 and self.train_dir:
                        eval_start_time = time.time()
                        hit_at_one = eval_util.calculate_hit_at_one(
                            predictions_val, labels_val)
                        perr = eval_util.calculate_precision_at_equal_recall_rate(
                            predictions_val, labels_val)
                        gap = eval_util.calculate_gap(predictions_val,
                                                      labels_val)
                        eval_end_time = time.time()
                        eval_time = eval_end_time - eval_start_time

                        logging.info("training step " + str(global_step_val) +
                                     " | Loss: " + ("%.2f" % loss_val) +
                                     " Examples/sec: " +
                                     ("%.2f" % examples_per_second) +
                                     " | Hit@1: " + ("%.2f" % hit_at_one) +
                                     " PERR: " + ("%.2f" % perr) + " GAP: " +
                                     ("%.2f" % gap))

                        sv.summary_writer.add_summary(
                            utils.MakeSummary("model/Training_Hit@1",
                                              hit_at_one), global_step_val)
                        sv.summary_writer.add_summary(
                            utils.MakeSummary("model/Training_Perr", perr),
                            global_step_val)
                        sv.summary_writer.add_summary(
                            utils.MakeSummary("model/Training_GAP", gap),
                            global_step_val)
                        sv.summary_writer.add_summary(
                            utils.MakeSummary("global_step/Examples/Second",
                                              examples_per_second),
                            global_step_val)
                        sv.summary_writer.flush()

                        # Exporting the model every x steps
                        time_to_export = (
                            (self.last_model_export_step == 0)
                            or (global_step_val - self.last_model_export_step
                                >= self.export_model_steps))

                        if self.is_master and time_to_export:
                            logging.info("Saving a file now at {}".format(
                                global_step_val))
                            self.export_model(global_step_val, sv.saver,
                                              sv.save_path, sess)
                            self.last_model_export_step = global_step_val

                            #TODO, validate
                            #evaluate()
                    else:
                        logging.info("training step " + str(global_step_val) +
                                     " | Loss: " + ("%.2f" % loss_val) +
                                     " Examples/sec: " +
                                     ("%.2f" % examples_per_second))
            except tf.errors.OutOfRangeError as e:
                #print(e)
                #pdb.set_trace()
                logging.info("%s: Done training -- epoch limit reached.",
                             task_as_string(self.task))

        logging.info("%s: Exited training loop.", task_as_string(self.task))
        sv.stop()
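
utils.MakeSummary, used for the scalar training summaries above, is not shown in the excerpt; the conventional implementation wraps a single value in a tf.Summary proto. A sketch of that helper:

import tensorflow as tf

def MakeSummary(name, value):
    """Build a one-value scalar summary for summary_writer.add_summary."""
    summary = tf.Summary()
    summary.value.add(tag=name, simple_value=value)
    return summary
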
Code example #41
def build_graph(reader, model, loss_fn, batch_size, regularization_penalty):
    """Creates the Tensorflow graph.

  This will only be called once in the life of
  a training model, because after the graph is created the model will be
  restored from a meta graph file rather than being recreated.

  Args:
    reader: the input class.
    model: The core model.
    loss_fn: What kind of loss to apply to the model. It should inherit
                from BaseLoss.
    batch_size: How many examples to process at a time.
    regularization_penalty: How much weight to give the regularization loss
                            compared to the label loss.
  """
    global_step = tf.train.get_or_create_global_step()

    local_device_protos = device_lib.list_local_devices()
    gpus = [x.name for x in local_device_protos if x.device_type == 'GPU']
    gpus = gpus[:FLAGS.train_num_gpu]
    num_gpus = len(gpus)

    if num_gpus > 0:
        logging.info("Using the following GPUs to train: " + str(gpus))
        num_towers = num_gpus
        device_string = '/gpu:{}'
        logging.info("Using total batch size of {} for training "
                     "over {} GPUs: batch size of {} per GPUs.".format(
                         batch_size, num_towers, batch_size // num_towers))
    else:
        logging.info("No GPUs found. Training on CPU.")
        num_towers = 1
        device_string = '/cpu:{}'
        logging.info(
            "Using total batch size of {} for training. ".format(batch_size))

    learning_rate = LearningRate(global_step, batch_size).get_learning_rate()
    opt_img = Optimizer(learning_rate).get_optimizer()
    opt_adv = Optimizer(learning_rate).get_optimizer()

    with tf.name_scope("input"):
        images_batch, labels_batch = reader.input_fn()
    tf.summary.histogram("model/input_raw", images_batch)

    gradients_cls = ComputeAndProcessGradients()

    tower_inputs = tf.split(images_batch, num_towers)
    tower_labels = tf.split(labels_batch, num_towers)
    tower_gradients_img = []
    tower_gradients_adv = []
    tower_logits_img, tower_losses_img = [], []
    tower_losses_adv = []
    for i in range(num_towers):
        reuse = i > 0
        with tf.device(device_string.format(i)):
            with tf.variable_scope("tower", reuse=reuse):

                logits_img, logits_adv = model.create_model(
                    tower_inputs[i],
                    n_classes=reader.n_classes,
                    is_training=True,
                    labels=tower_labels[i],
                    loss_fn=loss_fn.calculate_loss)
                tower_logits_img.append(logits_img)

                loss_img = loss_fn.calculate_loss(logits=logits_adv,
                                                  labels=tower_labels[i])
                loss_img = add_reg_to_loss(loss_img,
                                           regularization_penalty,
                                           scope="tower/network_defense")

                if FLAGS.train_attack['learn_noise_attack']:
                    loss_adv = -add_reg_to_loss(loss_img,
                                                regularization_penalty,
                                                scope="tower/network_attack")

                # Adds update_ops (e.g., moving average updates in batch norm) as
                # a dependency to the train_op.
                update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
                if update_ops:
                    with tf.control_dependencies(update_ops):
                        barrier = tf.no_op(name="gradient_barrier")
                        with tf.control_dependencies([barrier]):
                            loss_img = tf.identity(loss_img)

                var_list_img = tf.global_variables(
                    scope='tower/network_defense')
                gradients_img = gradients_cls.get_gradients(
                    opt_img, loss_img, var_list=var_list_img)
                tower_gradients_img.append(gradients_img)
                tower_losses_img.append(loss_img)

                if FLAGS.train_attack['learn_noise_attack']:
                    var_list_adv = tf.global_variables(
                        scope='tower/network_attack')
                    gradients_adv = gradients_cls.get_gradients(
                        opt_adv, loss_adv, var_list=var_list_adv)
                    tower_gradients_adv.append(gradients_adv)
                    tower_losses_adv.append(loss_adv)

    for variable in tf.trainable_variables():
        tf.summary.histogram(variable.op.name, variable)

    total_loss_img = tf.reduce_mean(tower_losses_img)
    full_gradients_img = combine_gradients(tower_gradients_img)
    tf.summary.scalar("img_loss", total_loss_img)
    train_op_cls_img = UpdateOps(opt_img)

    if FLAGS.train_attack['learn_noise_attack']:
        total_loss_adv = tf.reduce_mean(tower_losses_adv)
        full_gradients_adv = combine_gradients(tower_gradients_adv)
        tf.summary.scalar("adv_loss", total_loss_adv)
        train_op_cls_adv = UpdateOps(opt_adv, with_update=False)

    summary_op = tf.summary.merge_all()

    train_ops = []
    if FLAGS.train_attack['learn_noise_attack']:
        train_op_adv = train_op_cls_adv.make_update(full_gradients_adv)
        train_ops.append(train_op_adv)
    train_op_img = train_op_cls_img.make_update(full_gradients_img,
                                                global_step)
    train_ops.append(train_op_img)

    train_ops = tf.group(*train_ops)
    with tf.control_dependencies([train_ops]):
        train_op = tf.no_op(name='train_op')

    tf.add_to_collection("loss_img", total_loss_img)
    if FLAGS.train_attack['learn_noise_attack']:
        tf.add_to_collection("loss_adv", total_loss_adv)
    else:
        tf.add_to_collection("loss_adv", tf.constant(0.))

    tf.add_to_collection("logits", tf.concat(tower_logits_img, 0))
    tf.add_to_collection("labels", labels_batch)
    tf.add_to_collection("learning_rate", learning_rate)
    tf.add_to_collection("summary_op", summary_op)
    tf.add_to_collection("training_model", train_op)
Code example #42
def inference_loop(video_ids_batch, labels_batch, inputs_batch,
                   predictions_batch, video_ids_equal, labels_equal,
                   output_dir, batch_size):

    with tf.Session() as sess:

        sess.run([tf.local_variables_initializer()])

        # Start the queue runners.
        fetches = [
            video_ids_batch, labels_batch, inputs_batch, predictions_batch,
            video_ids_equal, labels_equal
        ]
        coord = tf.train.Coordinator()
        start_time = time.time()

        video_ids = []
        video_labels = []
        video_inputs = []
        video_predictions = []
        filenum = 0
        num_examples_processed = 0
        total_num_examples_processed = 0

        directory = FLAGS.output_dir
        if not os.path.exists(directory):
            os.makedirs(directory)
        else:
            raise IOError("Output path exists! path='" + directory + "'")

        try:
            threads = []
            for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
                threads.extend(
                    qr.create_threads(sess,
                                      coord=coord,
                                      daemon=True,
                                      start=True))

            while not coord.should_stop():
                ids_val = None
                ids_val, labels_val, inputs_val, predictions_val, ids_equal_val, labels_equal_val = sess.run(
                    fetches)

                print "ids equal = %f" % (ids_equal_val)
                print "labels equal = %f" % (labels_equal_val)

                video_ids.append(ids_val)
                video_labels.append(labels_val)
                video_inputs.append(inputs_val)
                video_predictions.append(predictions_val)
                num_examples_processed += len(ids_val)

                ids_shape = ids_val.shape[0]
                inputs_shape = inputs_val.shape[0]
                predictions_shape = predictions_val.shape[0]
                assert ids_shape == inputs_shape == predictions_shape, "tensor ids(%d), inputs(%d) and predictions(%d) should have equal rows" % (
                    ids_shape, inputs_shape, predictions_shape)

                ids_val = None

                if num_examples_processed >= FLAGS.file_size:
                    assert num_examples_processed == FLAGS.file_size, "num_examples_processed should be equal to %d" % FLAGS.file_size
                    video_ids = np.concatenate(video_ids, axis=0)
                    video_labels = np.concatenate(video_labels, axis=0)
                    video_inputs = np.concatenate(video_inputs, axis=0)
                    video_predictions = np.concatenate(video_predictions,
                                                       axis=0)
                    write_to_record(video_ids, video_labels, video_inputs,
                                    video_predictions, filenum,
                                    num_examples_processed)

                    video_ids = []
                    video_labels = []
                    video_inputs = []
                    video_predictions = []
                    filenum += 1
                    total_num_examples_processed += num_examples_processed

                    now = time.time()
                    logging.info("num examples processed: " +
                                 str(num_examples_processed) +
                                 " elapsed seconds: " +
                                 "{0:.2f}".format(now - start_time))
                    num_examples_processed = 0

        except tf.errors.OutOfRangeError as e:
            if ids_val is not None:
                video_ids.append(ids_val)
                video_labels.append(labels_val)
                video_inputs.append(inputs_val)
                video_predictions.append(predictions_val)
                num_examples_processed += len(ids_val)

            if 0 < num_examples_processed <= FLAGS.file_size:
                video_ids = np.concatenate(video_ids, axis=0)
                video_labels = np.concatenate(video_labels, axis=0)
                video_inputs = np.concatenate(video_inputs, axis=0)
                video_predictions = np.concatenate(video_predictions, axis=0)
                write_to_record(video_ids, video_labels, video_inputs,
                                video_predictions, filenum,
                                num_examples_processed)
                total_num_examples_processed += num_examples_processed

                now = time.time()
                logging.info("num examples processed: " +
                             str(num_examples_processed) +
                             " elapsed seconds: " +
                             "{0:.2f}".format(now - start_time))
                num_examples_processed = 0

            logging.info("Done with inference. %d samples was written to %s" %
                         (total_num_examples_processed, FLAGS.output_dir))
        except Exception as e:  # pylint: disable=broad-except
            logging.info("Unexpected exception: " + str(e))
        finally:
            coord.request_stop()

        coord.join(threads, stop_grace_period_secs=10)
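
write_to_record, called from both flush paths above, is not included in the excerpt. A plausible TFRecord writer for it (an assumption; the feature names are illustrative, and FLAGS.output_dir comes from the surrounding module):

import tensorflow as tf

def write_to_record(video_ids, labels, inputs, predictions,
                    filenum, num_examples):
    """Write one shard of ids/labels/inputs/predictions as a TFRecord file."""
    path = "%s/predictions-%04d.tfrecord" % (FLAGS.output_dir, filenum)
    with tf.python_io.TFRecordWriter(path) as writer:
        for i in range(num_examples):
            example = tf.train.Example(features=tf.train.Features(feature={
                # video_ids[i] is expected to be bytes, as returned by TF.
                "video_id": tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[video_ids[i]])),
                "labels": tf.train.Feature(
                    float_list=tf.train.FloatList(value=labels[i])),
                "inputs": tf.train.Feature(
                    float_list=tf.train.FloatList(value=inputs[i])),
                "predictions": tf.train.Feature(
                    float_list=tf.train.FloatList(value=predictions[i])),
            }))
            writer.write(example.SerializeToString())
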
Code example #43
    def run(self, start_new_model=False):
        """Performs training on the currently defined Tensorflow graph.

    Returns:
      A tuple of the training Hit@1 and the training PERR.
    """
        if self.is_master and start_new_model and exists(self.train_dir):
            self.remove_training_directory(self.train_dir)

        pp = pprint.PrettyPrinter(indent=2, compact=True)
        logging.info(pp.pformat(FLAGS.values()))

        model_flags_dict = FLAGS.to_json()
        log_folder = '{}_logs'.format(self.train_dir)
        flags_json_path = join(log_folder, "model_flags.json")
        if not exists(flags_json_path):
            # Write the file.
            with open(flags_json_path, "w") as fout:
                fout.write(model_flags_dict)

        target, device_fn = self.start_server_if_distributed()
        meta_filename = self.get_meta_filename(start_new_model, self.train_dir)

        with tf.Graph().as_default() as graph:
            if meta_filename:
                saver = self.recover_model(meta_filename)

            with tf.device(device_fn):
                if not meta_filename:
                    saver = self.build_model(self.model, self.reader)

                loss_img = tf.get_collection("loss_img")[0]
                loss_adv = tf.get_collection("loss_adv")[0]

                global_step = tf.train.get_global_step()
                logits = tf.get_collection("logits")[0]
                labels = tf.get_collection("labels")[0]
                learning_rate = tf.get_collection("learning_rate")[0]
                train_op = tf.get_collection("training_model")[0]
                summary_op = tf.get_collection("summary_op")[0]
                init_op = tf.global_variables_initializer()

                gradients_norm = tf.get_collection("gradients_norm")[0]

            scaffold = tf.train.Scaffold(
                saver=saver,
                init_op=init_op,
                summary_op=summary_op,
            )

            hooks = [
                tf.train.NanTensorHook(loss_img + loss_adv),
                tf.train.StopAtStepHook(num_steps=self.max_steps),
            ]

            session_args = dict(
                is_chief=self.is_master,
                scaffold=scaffold,
                checkpoint_dir=FLAGS.train_dir,
                hooks=hooks,
                save_checkpoint_steps=FLAGS.save_checkpoint_steps,
                save_summaries_steps=FLAGS.save_summaries_steps,
                save_summaries_secs=None,
                log_step_count_steps=0,
                config=self.config,
            )

            logging.info("Start training")
            with tf.train.MonitoredTrainingSession(**session_args) as sess:

                summary_writer = tf.summary.FileWriterCache.get(
                    FLAGS.train_dir)

                if FLAGS.profiler:
                    profiler = tf.profiler.Profiler(sess.graph)

                global_step_val = 0
                while not sess.should_stop():

                    make_profile = False
                    profile_args = {}

                    if global_step_val % 1000 == 0 and FLAGS.profiler:
                        make_profile = True
                        run_meta = tf.RunMetadata()
                        profile_args = {
                            'options':
                            tf.RunOptions(
                                trace_level=tf.RunOptions.FULL_TRACE),
                            'run_metadata':
                            run_meta
                        }

                    fetches = OrderedDict(train_op=train_op,
                                          global_step=global_step,
                                          loss_img=loss_img,
                                          loss_adv=loss_adv,
                                          learning_rate=learning_rate,
                                          logits=logits,
                                          labels=labels)

                    if gradients_norm != 0:
                        fetches['gradients_norm'] = gradients_norm
                    else:
                        grad_norm_val = 0

                    batch_start_time = time.time()
                    values = sess.run(list(fetches.values()), **profile_args)
                    fetches_values = OrderedDict(zip(fetches.keys(), values))
                    seconds_per_batch = time.time() - batch_start_time
                    examples_per_second = self.batch_size / seconds_per_batch

                    global_step_val = fetches_values['global_step']
                    loss_img_val = fetches_values['loss_img']
                    loss_adv_val = fetches_values['loss_adv']
                    learning_rate_val = fetches_values['learning_rate']
                    predictions_val = fetches_values['logits']
                    labels_val = fetches_values['labels']

                    if gradients_norm != 0:
                        grad_norm_val = fetches_values['gradients_norm']

                    if FLAGS.gradients['compute_hessian'] and global_step_val != 0 and \
                       global_step_val % FLAGS.gradients['hessian_every_n_step'] == 0:
                        compute_hessian_and_summary(sess, summary_writer,
                                                    global_step_val)

                    if make_profile and FLAGS.profiler:
                        logging.info('dump make profile')
                        profiler.add_step(global_step_val, run_meta)

                        # Profile the parameters of your model.
                        profiler.profile_name_scope(
                            options=(tf.profiler.ProfileOptionBuilder.
                                     trainable_variables_parameter()))

                        # Or profile the timing of your model operations.
                        opts = tf.profiler.ProfileOptionBuilder.time_and_memory(
                        )
                        profiler.profile_operations(options=opts)

                        # Or you can generate a timeline:
                        opts = (tf.profiler.ProfileOptionBuilder(
                            tf.profiler.ProfileOptionBuilder.time_and_memory(
                            )).with_step(global_step_val).with_timeline_output(
                                '~/profile.logs').build())
                        profiler.profile_graph(options=opts)

                    to_print = global_step_val % FLAGS.frequency_log_steps == 0
                    if (self.is_master and to_print) or global_step_val == 1:
                        epoch = ((global_step_val * self.batch_size) /
                                 self.reader.n_train_files)

                        message = MessageBuilder()
                        message.add("epoch", epoch, format="4.2f")
                        message.add("step",
                                    global_step_val,
                                    width=5,
                                    format=".0f")
                        message.add("lr", learning_rate_val, format=".6f")
                        message.add("img loss", loss_img_val, format=".4f")
                        message.add("adv loss", -loss_adv_val, format=".4f")
                        if "YT8M" in self.reader.__class__.__name__:
                            gap = eval_util.calculate_gap(
                                predictions_val, labels_val)
                            message.add("gap", gap, format=".3f")
                        message.add("imgs/sec",
                                    examples_per_second,
                                    width=5,
                                    format=".0f")
                        if FLAGS.gradients['perturbed_gradients']:
                            message.add("grad norm",
                                        grad_norm_val,
                                        format=".4f")
                        logging.info(message.get_message())

                # End training
                logging.info(
                    "{}: Done training -- epoch limit reached.".format(
                        task_as_string(self.task)))
                if FLAGS.profiler:
                    profiler.advise()
        logging.info("{}: Exited training loop.".format(
            task_as_string(self.task)))
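
MessageBuilder, used to assemble the per-step log line above, is external to this excerpt. A minimal compatible sketch (an assumption matching the add/get_message calls above):

class MessageBuilder:
    """Accumulate 'name: value' fields and render them as one log line."""

    def __init__(self):
        self.fields = []

    def add(self, name, value, format=".2f", width=None):
        # Build a format spec like '{:5.0f}' from the width and format args.
        fmt = "{:" + (str(width) if width else "") + format + "}"
        self.fields.append("%s: %s" % (name, fmt.format(value)))

    def get_message(self):
        return " | ".join(self.fields)
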
Code example #44
def train_loop(train_dir=None,
               saver=None,
               is_chief=True,
               master="",
               start_supervisor_services=True):
    """Performs training on the currently defined tensorflow graph.

  Args:
    train_dir: Where to save the model checkpoints.
    saver: The class to use for serializing the graph variables.
    is_chief: Whether this worker is the primary worker (which is responsible
    for writing checkpoints and summaries), or an anonymous member of the flock.
    master: Which Tensorflow master to listen to.
    start_supervisor_services: Whether to start threads for writing summaries
      and checkpoints.

  Returns:
  A tuple of the training Hit@1 and the training PERR.
  """
    global_step = tf.get_collection("global_step")[0]
    loss = tf.get_collection("loss")[0]
    predictions = tf.get_collection("predictions")[0]
    labels = tf.get_collection("labels")[0]
    train_op = tf.get_collection("train_op")[0]

    sv = tf.train.Supervisor(logdir=train_dir,
                             is_chief=is_chief,
                             global_step=global_step,
                             save_model_secs=60,
                             save_summaries_secs=60,
                             saver=saver)
    sess = sv.prepare_or_wait_for_session(
        master,
        start_standard_services=start_supervisor_services,
        config=tf.ConfigProto(log_device_placement=False))

    logging.info("prepared session")
    sv.start_queue_runners(sess)
    logging.info("started queue runners")

    try:
        logging.info("entering training loop")
        while not sv.should_stop():
            batch_start_time = time.time()
            _, global_step_val, loss_val, predictions_val, labels_val = sess.run(
                [train_op, global_step, loss, predictions, labels])
            seconds_per_batch = time.time() - batch_start_time
            examples_per_second = labels_val.shape[0] / seconds_per_batch

            hit_at_one = eval_util.calculate_hit_at_one(
                predictions_val, labels_val)
            perr = eval_util.calculate_precision_at_equal_recall_rate(
                predictions_val, labels_val)
            gap = eval_util.calculate_gap(predictions_val, labels_val)

            logging.info("training step " + str(global_step_val) +
                         "| Hit@1: " + ("%.2f" % hit_at_one) + " PERR: " +
                         ("%.2f" % perr) + " GAP: " + ("%.2f" % gap) +
                         " Loss: " + str(loss_val))
            if is_chief and global_step_val % 10 == 0 and train_dir:
                sv.summary_writer.add_summary(
                    utils.MakeSummary("model/Training_Hit@1", hit_at_one),
                    global_step_val)
                sv.summary_writer.add_summary(
                    utils.MakeSummary("model/Training_Perr", perr),
                    global_step_val)
                sv.summary_writer.add_summary(
                    utils.MakeSummary("model/Training_GAP", gap),
                    global_step_val)
                sv.summary_writer.add_summary(
                    utils.MakeSummary("global_step/Examples/Second",
                                      examples_per_second), global_step_val)
                sv.summary_writer.flush()
    except tf.errors.OutOfRangeError:
        logging.info("Done training -- epoch limit reached")
    logging.info("exited training loop")
    sv.stop()
    return hit_at_one, perr
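
eval_util.calculate_hit_at_one reports the fraction of examples whose top-scoring class is among the true labels. A compact NumPy version for reference (a sketch, not necessarily the library's exact code):

import numpy as np

def calculate_hit_at_one(predictions, actuals):
    """predictions, actuals: [batch, num_classes]; actuals is multi-hot."""
    top_prediction = np.argmax(predictions, axis=1)
    hits = actuals[np.arange(actuals.shape[0]), top_prediction]
    return float(np.mean(hits))

preds = np.array([[0.1, 0.9], [0.8, 0.2]])
labels = np.array([[0, 1], [0, 1]])
print(calculate_hit_at_one(preds, labels))  # 0.5
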
Code example #45
File: train.py Project: lvaleriu/Youtube-8M-WILLOW
  def recover_model(self, meta_filename):
    logging.info("%s: Restoring from meta graph file %s",
                 task_as_string(self.task), meta_filename)
    return tf.train.import_meta_graph(meta_filename)
Code example #46
def main(unused_argv):
    logging.set_verbosity(tf.logging.INFO)
    print("tensorflow version: %s" % tf.__version__)
    is_chief = (FLAGS.task == 0)

    # Recover session
    saver = None
    latest_checkpoint = tf.train.latest_checkpoint(FLAGS.train_dir)
    if FLAGS.start_new_model:
        logging.info(
            "'start_new_model' flag is set. Removing existing train dir.")
        try:
            gfile.DeleteRecursively(FLAGS.train_dir)
        except Exception:
            logging.error(
                "Failed to delete directory " + FLAGS.train_dir +
                " when starting a new model. Please delete it manually and" +
                " try again.")
    elif not latest_checkpoint:
        logging.info("No checkpoint file found. Building a new model.")
    else:
        meta_filename = latest_checkpoint + ".meta"
        if not gfile.Exists(meta_filename):
            logging.info("No meta graph file found. Building a new model.")
        else:
            logging.info("Restoring from meta graph file %s", meta_filename)
            saver = tf.train.import_meta_graph(meta_filename)

    if not saver:
        # convert feature_names and feature_sizes to lists of values
        feature_names, feature_sizes = utils.GetListOfFeatureNamesAndSizes(
            FLAGS.feature_names, FLAGS.feature_sizes)

        if FLAGS.frame_features:
            reader = readers.YT8MFrameFeatureReader(
                feature_names=feature_names, feature_sizes=feature_sizes)
        else:
            reader = readers.YT8MAggregatedFeatureReader(
                feature_names=feature_names, feature_sizes=feature_sizes)

        model = find_class_by_name(FLAGS.model,
                                   [frame_level_models, video_level_models])()
        label_loss_fn = find_class_by_name(FLAGS.label_loss, [losses])()
        optimizer_class = find_class_by_name(FLAGS.optimizer, [tf.train])
        build_graph(reader=reader,
                    model=model,
                    optimizer_class=optimizer_class,
                    train_data_pattern=FLAGS.train_data_pattern,
                    label_loss_fn=label_loss_fn,
                    base_learning_rate=FLAGS.base_learning_rate,
                    regularization_penalty=FLAGS.regularization_penalty,
                    num_readers=FLAGS.num_readers,
                    batch_size=FLAGS.batch_size,
                    num_epochs=FLAGS.epochs)
        logging.info("built graph")
        saver = tf.train.Saver()

    train_loop(is_chief=is_chief,
               train_dir=FLAGS.train_dir,
               saver=saver,
               master=FLAGS.master)
Code example #47
def evaluate():
  tf.set_random_seed(0)  # for reproducibility

  # Write json of flags
  # model_flags_path = os.path.join(FLAGS.train_dir, "model_flags.json")
  # if not file_io.file_exists(model_flags_path):
  #   raise IOError(("Cannot find file %s. Did you run train.py on the same "
  #                  "--train_dir?") % model_flags_path)
  # flags_dict = json.loads(file_io.FileIO(model_flags_path, mode="r").read())
  all_eval_data_patterns = []
  with open(FLAGS.eval_data_config) as f:
    all_eval_data_patterns = f.read().splitlines()

  with tf.Graph().as_default():
    # convert feature_names and feature_sizes to lists of values
    # feature_names, feature_sizes = utils.GetListOfFeatureNamesAndSizes(
    #     flags_dict["feature_names"], flags_dict["feature_sizes"])

    # prepare a reader for each single model prediction result
    all_readers = []

    for _ in all_eval_data_patterns:
      reader = readers.EnsembleReader(
          feature_names=[FLAGS.feature_names],
          feature_sizes=[FLAGS.feature_sizes])
      all_readers.append(reader)

    model = ensemble_model.MeanModel()
    label_loss_fn = find_class_by_name("CrossEntropyLoss", [losses])()

    build_graph(
        all_readers=all_readers,
        all_eval_data_patterns=all_eval_data_patterns,
        model=model,
        label_loss_fn=label_loss_fn,
        num_readers=FLAGS.num_readers,
        batch_size=FLAGS.batch_size)

    logging.info("built evaluation graph")
    video_id_batch = tf.get_collection("video_id_batch")[0]
    prediction_batch = tf.get_collection("predictions")[0]
    label_batch = tf.get_collection("labels")[0]
    loss = tf.get_collection("loss")[0]
    summary_op = tf.get_collection("summary_op")[0]

    saver = tf.train.Saver(tf.global_variables())
    summary_writer = tf.summary.FileWriter(
        FLAGS.train_dir, graph=tf.get_default_graph())

    evl_metrics = eval_util.EvaluationMetrics(all_readers[0].num_classes,
                                              FLAGS.top_k)

    last_global_step_val = -1
    while True:
      last_global_step_val = evaluation_loop(video_id_batch, prediction_batch,
                                             label_batch, loss, summary_op,
                                             saver, summary_writer, evl_metrics,
                                             last_global_step_val)
      if FLAGS.run_once:
        break
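
The ensemble model here simply averages the predictions of the individual models. A minimal sketch of what such a MeanModel could look like, assuming each entry of model_input_list is a (batch_size, num_classes) prediction tensor; the actual ensemble_model.MeanModel in this project may differ:

import tensorflow as tf

class MeanModel(object):
  """Hypothetical sketch: element-wise mean of per-model predictions."""

  def create_model(self, model_input_list, **unused_params):
    # Sum the per-model prediction tensors and divide by the model count.
    predictions = tf.add_n(model_input_list) / float(len(model_input_list))
    return {"predictions": predictions}
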
Code example #48
File: train_hdfs.py  Project: KangHsi/youtube-8m
    def run(self, start_new_model=False):
        """Performs training on the currently defined Tensorflow graph.

        Args:
          start_new_model: If True, removes the training directory and
            builds a new model instead of restoring from a checkpoint.
        """
        if self.is_master and start_new_model:
            self.remove_training_directory(self.train_dir)

        target, device_fn = self.start_server_if_distributed()

        meta_filename = self.get_meta_filename(start_new_model, self.train_dir)

        with tf.Graph().as_default() as graph:

            if meta_filename:
                saver = self.recover_model(meta_filename)

            self.downloader = create_hdfs_downloader(
                FLAGS.train_data_pattern,
                FLAGS.data_tmp_dir,
                num_downloader=FLAGS.num_downloaders,
                enqueue_size=FLAGS.enqueue_size)

            with tf.device(device_fn):
                if not meta_filename:
                    saver = self.build_model(self.model, self.reader,
                                             self.downloader)

                global_step = tf.get_collection("global_step")[0]
                loss = tf.get_collection("loss")[0]
                predictions = tf.get_collection("predictions")[0]
                labels = tf.get_collection("labels")[0]
                train_op = tf.get_collection("train_op")[0]
                init_op = tf.global_variables_initializer()

        sv = tf.train.Supervisor(graph,
                                 logdir=self.train_dir,
                                 init_op=init_op,
                                 is_chief=self.is_master,
                                 global_step=global_step,
                                 save_model_secs=0,
                                 save_summaries_secs=360,
                                 saver=saver)

        logging.info("%s: Starting managed session.",
                     task_as_string(self.task))
        with sv.managed_session(target, config=self.config) as sess:
            try:
                # create a thread for enqueuing
                enqueue_thread = threading.Thread(
                    target=self.downloader.enqueuing,
                    args=(sess, ),
                    kwargs={'num_epochs': FLAGS.num_epochs})
                enqueue_thread.daemon = True
                enqueue_thread.start()

                logging.info("%s: Entering training loop.",
                             task_as_string(self.task))
                while (not sv.should_stop()) and (not self.max_steps_reached):
                    batch_start_time = time.time()
                    _, global_step_val, loss_val, predictions_val, labels_val = sess.run(
                        [train_op, global_step, loss, predictions, labels])
                    seconds_per_batch = time.time() - batch_start_time
                    examples_per_second = labels_val.shape[
                        0] / seconds_per_batch

                    if self.max_steps and self.max_steps <= global_step_val:
                        self.max_steps_reached = True

                    if self.is_master and global_step_val % self.disp_batches == 0 and self.train_dir:
                        eval_start_time = time.time()
                        hit_at_one = eval_util.calculate_hit_at_one(
                            predictions_val, labels_val)
                        perr = eval_util.calculate_precision_at_equal_recall_rate(
                            predictions_val, labels_val)
                        gap = eval_util.calculate_gap(predictions_val,
                                                      labels_val)
                        eval_end_time = time.time()
                        eval_time = eval_end_time - eval_start_time

                        logging.info("training step " + str(global_step_val) +
                                     " | Loss: " + ("%.2f" % loss_val) +
                                     " Examples/sec: " +
                                     ("%.2f" % examples_per_second) +
                                     " | Hit@1: " + ("%.4f" % hit_at_one) +
                                     " PERR: " + ("%.4f" % perr) + " GAP: " +
                                     ("%.4f" % gap))

                        sv.summary_writer.add_summary(
                            utils.MakeSummary("model/Training_Hit@1",
                                              hit_at_one), global_step_val)
                        sv.summary_writer.add_summary(
                            utils.MakeSummary("model/Training_Perr", perr),
                            global_step_val)
                        sv.summary_writer.add_summary(
                            utils.MakeSummary("model/Training_GAP", gap),
                            global_step_val)
                        sv.summary_writer.add_summary(
                            utils.MakeSummary("global_step/Examples/Second",
                                              examples_per_second),
                            global_step_val)
                        sv.summary_writer.flush()

                        # Export the model on the first display step and
                        # then every export_model_steps steps after that.
                        time_to_export = (
                            (self.last_model_export_step == 0)
                            or (global_step_val - self.last_model_export_step
                                >= self.export_model_steps))

                        if self.is_master and time_to_export:
                            self.export_model(global_step_val, sv.saver,
                                              sv.save_path, sess)
                            self.last_model_export_step = global_step_val
                    else:
                        logging.info("training step " + str(global_step_val) +
                                     " | Loss: " + ("%.4f" % loss_val) +
                                     " Examples/sec: " +
                                     ("%.4f" % examples_per_second))
            except tf.errors.OutOfRangeError:
                logging.info("%s: Done training -- epoch limit reached.",
                             task_as_string(self.task))
            except KeyboardInterrupt:
                pass
            finally:
                self.downloader.stop()
                logging.info("Downloader stopped")

        logging.info("%s: Exited training loop.", task_as_string(self.task))
        sv.Stop()
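
The utils.MakeSummary calls in the training loop above wrap scalar metrics into summary protos so they can be written with sv.summary_writer.add_summary() without running a summary op in the session. A minimal sketch of such a helper, assuming TF 1.x; the project's actual utils.MakeSummary may differ:

import tensorflow as tf

def MakeSummary(name, value):
    """Hypothetical sketch: builds a tf.Summary proto with a single scalar."""
    summary = tf.Summary()
    val = summary.value.add()
    val.tag = str(name)
    val.simple_value = float(value)
    return summary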