class RunAfterCheckpointHook(session_run_hook.SessionRunHook):
    """ Runs a certain callback function right after a checkpoint has been saved. 
      We use this to generate some text at regular intervals during the training to show the progress. 
      Note that it restores the model from a checkpoint, which is why it needs to happen with the same 
      interval as checkpoint saving. """
    def __init__(self, run_config, callback):
        self._timer = SecondOrStepTimer(
            every_secs=run_config.save_checkpoints_secs,
            every_steps=run_config.save_checkpoints_steps)
        self.callback = callback
        self.is_first_run = True

    def begin(self):
        self._global_step_tensor = training_util._get_or_create_global_step_read(
        )  # pylint: disable=protected-access

    def after_run(self, run_context, run_values):
        global_step = run_context.session.run(self._global_step_tensor)

        if self._timer.should_trigger_for_step(global_step):
            self._timer.update_last_triggered_step(global_step)

            # The timer triggers on the very first run, which does not make sense here.
            if not self.is_first_run:
                self.callback()
            else:
                self.is_first_run = False
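A minimal usage sketch for this hook, assuming a hypothetical model_fn, input_fn and text-generation callback; the only real requirement is that the RunConfig carries the checkpoint interval:

# Usage sketch: the hook's interval is tied to the checkpoint interval.
run_config = tf.estimator.RunConfig(save_checkpoints_steps=1000)
estimator = tf.estimator.Estimator(model_fn=my_model_fn,  # hypothetical
                                   config=run_config)

def generate_sample_text():
    # Hypothetical callback: restore the latest checkpoint and decode a
    # few samples to show training progress.
    pass

estimator.train(input_fn=my_input_fn,  # hypothetical
                hooks=[RunAfterCheckpointHook(run_config, generate_sample_text)])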
Example #2
class BlastHook(session_run_hook.SessionRunHook):
    """Hook that counts steps per second."""
    def __init__(self,
                 summary,
                 config,
                 id_to_enzyme_class,
                 every_n_steps=1200,
                 every_n_secs=None,
                 output_dir=None,
                 summary_writer=None,
                 n_examples=2,
                 running_mode="train"):

        self._timer = SecondOrStepTimer(every_steps=every_n_steps,
                                        every_secs=every_n_secs)
        self.summary = summary
        self.config = config
        self.summary_writer = summary_writer
        self.output_dir = output_dir
        self.last_global_step = None
        self.id_to_enzyme_class = id_to_enzyme_class
        self.global_step_check_count = 0
        self.steps_per_run = 1
        self.n_examples = n_examples
        self.running_mode = running_mode

    def _set_steps_per_run(self, steps_per_run):
        self.steps_per_run = steps_per_run

    def begin(self):
        if self.summary_writer is None and self.output_dir:
            self.summary_writer = SummaryWriterCache.get(self.output_dir)
        graph = ops.get_default_graph()
        self.fake_seq = graph.get_tensor_by_name("model/" + FAKE_PROTEINS +
                                                 ":0")
        self.labels = graph.get_tensor_by_name("model/" + LABELS + ":0")
        self.d_score = graph.get_tensor_by_name("model/d_score:0")
        self.global_step_tensor = training_util._get_or_create_global_step_read(
        )
        if self.global_step_tensor is None:
            raise RuntimeError("Could not global step tensor")
        if self.fake_seq is None:
            raise RuntimeError("Could not get fake seq tensor")

    def before_run(self, run_context):  # pylint: disable=unused-argument
        return SessionRunArgs([
            self.global_step_tensor, self.fake_seq, self.labels, self.d_score
        ])

    def after_run(self, run_context, run_values):
        global_step, fake_seq, labels, d_score = run_values.results
        if self._timer.should_trigger_for_step(global_step):
            # fake_seq, real_seq, labels = run_context.session.run([self._fake_seq, self._real_seq, self._labels])
            self.summary(self.config, self.summary_writer, global_step,
                         fake_seq, labels, self.id_to_enzyme_class,
                         self.n_examples, self.running_mode,
                         d_score).start()
            self._timer.update_last_triggered_step(global_step)
Example #3
class IntervalHook(tf.train.SessionRunHook):
    """
    A hook that runs every N iterations. Useful for subclassing.
    """
    def __init__(self, interval):
        """
        Construct the hook.

        :param interval: The step interval at which to trigger, or None to
            never trigger.
        """
        self.global_step = None
        self.interval = interval

        if interval is not None:
            self.timer = SecondOrStepTimer(every_steps=interval)
        else:
            self.timer = None

    def begin(self):
        self.global_step = tf.train.get_or_create_global_step()

    def before_run(self, run_context):
        return tf.train.SessionRunArgs(
            [self.global_step, *self.session_run_args(run_context)])

    # noinspection PyMethodMayBeStatic, PyUnusedLocal
    def session_run_args(self, run_context):  # pylint: disable=unused-argument
        """
        Create the session run arguments.

        :param run_context: The run context.
        :return: The list of arguments to run.
        """
        return list()

    def after_run(self, run_context, run_values):
        if self.interval is None:
            return

        global_step = run_values.results[0]
        if self.timer.should_trigger_for_step(global_step):
            self.timer.update_last_triggered_step(global_step)
            self.run_interval_operations(run_context, run_values.results[1:],
                                         global_step)

    @abc.abstractmethod
    def run_interval_operations(self, run_context, results, global_step):
        """
        The method to override.

        :param run_context: The run context.
        :param results: The results of running the given arguments.
        :param global_step: The evaluated global step tensor.
        """
        pass
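A minimal subclass sketch; the loss_tensor argument is a hypothetical stand-in for any tensor worth reporting:

class LossPrinterHook(IntervalHook):
    """Hypothetical subclass: logs a loss tensor every `interval` steps."""

    def __init__(self, loss_tensor, interval=100):
        super(LossPrinterHook, self).__init__(interval)
        self.loss_tensor = loss_tensor

    def session_run_args(self, run_context):
        # Fetched alongside the global step by the base class.
        return [self.loss_tensor]

    def run_interval_operations(self, run_context, results, global_step):
        # `results` holds the extra fetches from session_run_args, in order.
        tf.logging.info("step %d: loss = %g", global_step, results[0])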
Example #4
class BestSaverHook(tf.train.CheckpointSaverHook):

    def __init__(self, checkpoint_dir, save_secs=None, save_steps=None, saver=None,
                 checkpoint_basename="model.ckpt", scaffold=None, listeners=None):

        if not listeners:
            raise ValueError("BestSaverHook requires at least one listener "
                             "that implements should_stop().")
        self.saver_listener = listeners[0]
        super(BestSaverHook, self).__init__(checkpoint_dir, save_secs, save_steps, saver,
                                            checkpoint_basename, scaffold, listeners)

        logging.info("Create CheckpointSaverHook.")
        if saver is not None and scaffold is not None:
            raise ValueError("You cannot provide both saver and scaffold.")
        self._saver = saver
        self._checkpoint_dir = checkpoint_dir
        self._save_path = os.path.join(checkpoint_dir, checkpoint_basename)
        self._scaffold = scaffold
        self._timer = SecondOrStepTimer(every_secs=save_secs,
                                        every_steps=save_steps)
        self._listeners = listeners or []

        print('__init__ listeners:{}, {}'.format(len(listeners), len(self._listeners)))

    # def after_run(self, run_context, run_values):
    #     print('EarlyStoppingHook:{}'.format(run_values.results))
    #     super(EarlyStoppingHook, self).after_run(run_context, run_values)
    #     if self.saver_listener.should_stop():
    #         run_context.request_stop()

    def after_run(self, run_context, run_values):
        # print('EarlyStoppingHook:{}'.format(run_values.results))
        stale_global_step = run_values.results
        if self._timer.should_trigger_for_step(stale_global_step+1):
            global_step = run_context.session.run(self._global_step_tensor)
            if self._timer.should_trigger_for_step(global_step):
                self._timer.update_last_triggered_step(global_step)
                self._save(run_context.session, global_step)

        if self.saver_listener.should_stop(run_context.session):
            print('early stop')
            run_context.request_stop()

    def _save(self, session, step):
        """Saves the latest checkpoint."""
        self.saver_listener.before_save(session, step)
        self._get_saver().save(session, self._save_path, global_step=step)
        self._summary_writer.add_session_log(SessionLog(status=SessionLog.CHECKPOINT, checkpoint_path=self._save_path), step)
        self.saver_listener.after_save(session, step)
        logging.info("Saving checkpoints for %d into %s.", step, self._save_path)
Example #5
class MetadataHook(SessionRunHook):
    def __init__(self, save_steps=None, save_secs=None, output_dir=""):
        self._output_tag = "step-{}"
        self._output_dir = output_dir
        self._save_steps = save_steps
        self._timer = SecondOrStepTimer(every_secs=save_secs,
                                        every_steps=save_steps)

    def begin(self):
        self._next_step = None
        self._global_step_tensor = training_util.get_global_step()
        self._writer = tf.summary.FileWriter(self._output_dir,
                                             tf.get_default_graph())

        if self._global_step_tensor is None:
            raise RuntimeError(
                "Global step should be created to use ProfilerHook.")

    def before_run(self, run_context):
        self._request_summary = (self._next_step is None
                                 or self._timer.should_trigger_for_step(
                                     self._next_step))
        requests = {"global_step": self._global_step_tensor}
        opts = (tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
                if self._request_summary else None)
        return SessionRunArgs(requests, options=opts)

    def after_run(self, run_context, run_values):
        stale_global_step = run_values.results['global_step']
        if self._next_step is None:
            self._timer.update_last_triggered_step(stale_global_step)

        global_step = stale_global_step + 1
        if self._request_summary:
            global_step = run_context.session.run(self._global_step_tensor)
            self._timer.update_last_triggered_step(global_step)
            self._writer.add_run_metadata(run_values.run_metadata,
                                          self._output_tag.format(global_step))
            self._writer.flush()
        self._next_step = global_step + 1

    def end(self, session):
        # Close the writer unconditionally so the last events are flushed.
        self._writer.close()
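A usage sketch, assuming a train_op built elsewhere; the collected RunMetadata can then be inspected in TensorBoard's graph view:

# Usage sketch: trace a full RunMetadata snapshot every 500 steps.
hook = MetadataHook(save_steps=500, output_dir="/tmp/run_metadata")
with tf.train.MonitoredTrainingSession(hooks=[hook]) as sess:
    while not sess.should_stop():
        sess.run(train_op)  # train_op is assumed to exist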
Example #6
class TrainSampleHook(TrainingHook):
  """Occasionally samples predictions from the training run and prints them.

  Params:
    every_n_secs: Sample predictions every N seconds.
      If set, `every_n_steps` must be None.
    every_n_steps: Sample predictions every N steps.
      If set, `every_n_secs` must be None.
    sample_dir: Optional, a directory to write samples to.
    delimiter: Join tokens on this delimiter. Defaults to space.
  """

  #pylint: disable=missing-docstring

  def __init__(self, params, model_dir, run_config):
    super(TrainSampleHook, self).__init__(params, model_dir, run_config)
    self._sample_dir = os.path.join(self.model_dir, "samples")
    self._timer = SecondOrStepTimer(
        every_secs=self.params["every_n_secs"],
        every_steps=self.params["every_n_steps"])
    self._pred_dict = {}
    self._should_trigger = False
    self._iter_count = 0
    self._global_step = None
    self._source_delimiter = self.params["source_delimiter"]
    self._target_delimiter = self.params["target_delimiter"]

  @staticmethod
  def default_params():
    return {
        "every_n_secs": None,
        "every_n_steps": 1000,
        "source_delimiter": " ",
        "target_delimiter": " "
    }

  def begin(self):
    self._iter_count = 0
    self._global_step = tf.train.get_global_step()
    self._pred_dict = graph_utils.get_dict_from_collection("predictions")
    # Create the sample directory
    if self._sample_dir is not None:
      gfile.MakeDirs(self._sample_dir)

  def before_run(self, _run_context):
    self._should_trigger = self._timer.should_trigger_for_step(self._iter_count)
    if self._should_trigger:
      fetches = {
          "predicted_tokens": self._pred_dict["predicted_tokens"],
          "target_words": self._pred_dict["labels.target_tokens"],
          "target_len": self._pred_dict["labels.target_len"]
      }
      return tf.train.SessionRunArgs([fetches, self._global_step])
    return tf.train.SessionRunArgs([{}, self._global_step])

  def after_run(self, _run_context, run_values):
    result_dict, step = run_values.results
    self._iter_count = step

    if not self._should_trigger:
      return None

    # Convert dict of lists to list of dicts
    result_dicts = [
        dict(zip(result_dict, t)) for t in zip(*result_dict.values())
    ]

    # Print results
    result_str = ""
    result_str += "Prediction followed by Target @ Step {}\n".format(step)
    result_str += ("=" * 100) + "\n"
    for result in result_dicts:
      target_len = result["target_len"]
      predicted_slice = result["predicted_tokens"][:target_len - 1]
      target_slice = result["target_words"][1:target_len]
      result_str += self._target_delimiter.encode("utf-8").join(
          predicted_slice).decode("utf-8") + "\n"
      result_str += self._target_delimiter.encode("utf-8").join(
          target_slice).decode("utf-8") + "\n\n"
    result_str += ("=" * 100) + "\n\n"
    tf.logging.info(result_str)
    if self._sample_dir:
      filepath = os.path.join(self._sample_dir,
                              "samples_{:06d}.txt".format(step))
      with gfile.GFile(filepath, "w") as file:
        file.write(result_str)
    self._timer.update_last_triggered_step(self._iter_count - 1)
Example #7
class TrainSampleHook(TrainingHook):
  """Occasionally samples predictions from the training run and prints them.

  Params:
    every_n_secs: Sample predictions every N seconds.
      If set, `every_n_steps` must be None.
    every_n_steps: Sample predictions every N steps.
      If set, `every_n_secs` must be None.
    sample_dir: Optional, a directory to write samples to.
    delimiter: Join tokens on this delimiter. Defaults to space.
  """

  #pylint: disable=missing-docstring

  def __init__(self, params, model_dir, run_config):
    super(TrainSampleHook, self).__init__(params, model_dir, run_config)
    self._sample_dir = os.path.join(self.model_dir, "samples")
    self._timer = SecondOrStepTimer(
        every_secs=self.params["every_n_secs"],
        every_steps=self.params["every_n_steps"])
    self._pred_dict = {}
    self._should_trigger = False
    self._iter_count = 0
    self._global_step = None
    self._source_delimiter = self.params["source_delimiter"]
    self._target_delimiter = self.params["target_delimiter"]

  @staticmethod
  def default_params():
    return {
        "every_n_secs": None,
        "every_n_steps": 1000,
        "source_delimiter": " ",
        "target_delimiter": " "
    }

  def begin(self):
    self._iter_count = 0
    self._global_step = tf.train.get_global_step()
    self._pred_dict = graph_utils.get_dict_from_collection("predictions")
    # Create the sample directory
    if self._sample_dir is not None:
      gfile.MakeDirs(self._sample_dir)

  def before_run(self, _run_context):
    self._should_trigger = self._timer.should_trigger_for_step(self._iter_count)
    if self._should_trigger:
      fetches = {
          "predicted_tokens": self._pred_dict["predicted_tokens"],
          "target_words": self._pred_dict["labels.target_tokens"],
          "target_len": self._pred_dict["labels.target_len"]
      }
      return tf.train.SessionRunArgs([fetches, self._global_step])
    return tf.train.SessionRunArgs([{}, self._global_step])

  def after_create_session(self, session, coord):
    print("Session created. Finalizing graph.")
    session.graph.finalize()

  def after_run(self, _run_context, run_values):
    result_dict, step = run_values.results
    self._iter_count = step

    if not self._should_trigger:
      return None

    # Convert dict of lists to list of dicts
    result_dicts = [
        dict(zip(result_dict, t)) for t in zip(*result_dict.values())
    ]

    # Print results
    result_str = ""
    result_str += "Prediction followed by Target @ Step {}\n".format(step)
    result_str += ("=" * 100) + "\n"
    for result in result_dicts:
      target_len = result["target_len"]
      predicted_slice = result["predicted_tokens"][:target_len - 1]
      target_slice = result["target_words"][1:target_len]
      result_str += self._target_delimiter.encode("utf-8").join(
          predicted_slice).decode("utf-8") + "\n"
      result_str += self._target_delimiter.encode("utf-8").join(
          target_slice).decode("utf-8") + "\n\n"
    result_str += ("=" * 100) + "\n\n"
    tf.logging.info(result_str)
    if self._sample_dir:
      filepath = os.path.join(self._sample_dir,
                              "samples_{:06d}.txt".format(step))
      with gfile.GFile(filepath, "w") as file:
        file.write(result_str)
    self._timer.update_last_triggered_step(self._iter_count - 1)
Example #8
class ProfilerHook(session_run_hook.SessionRunHook):
  """Captures CPU/GPU profiling information every N steps or seconds.

  This produces files called "timeline-<step>.json", which are in Chrome
  Trace format.

  For more information see:
  https://github.com/catapult-project/catapult/blob/master/tracing/README.md"""

  def __init__(self,
               save_steps=None,
               save_secs=None,
               output_dir="",
               show_dataflow=True,
               show_memory=False):
    """Initializes a hook that takes periodic profiling snapshots.

    Args:
      save_steps: `int`, save profile traces every N steps. Exactly one of
          `save_secs` and `save_steps` should be set.
      save_secs: `int`, save profile traces every N seconds.
      output_dir: `string`, the directory to save the profile traces to.
          Defaults to the current directory.
      show_dataflow: `bool`, if True, add flow events to the trace connecting
          producers and consumers of tensors.
      show_memory: `bool`, if True, add object snapshot events to the trace
          showing the sizes and lifetimes of tensors.
    """
    self._output_file = os.path.join(output_dir, "timeline-{}.json")
    self._show_dataflow = show_dataflow
    self._show_memory = show_memory
    self._timer = SecondOrStepTimer(every_secs=save_secs,
                                    every_steps=save_steps)

  def begin(self):
    self._next_step = None
    self._global_step_tensor = training_util.get_global_step()
    if self._global_step_tensor is None:
      raise RuntimeError(
          "Global step should be created to use ProfilerHook.")

  def before_run(self, run_context):
    self._request_summary = (
        self._next_step is None or
        self._timer.should_trigger_for_step(self._next_step))
    requests = {"global_step": self._global_step_tensor}
    opts = (config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE)
            if self._request_summary else None)

    return SessionRunArgs(requests, options=opts)

  def after_run(self, run_context, run_values):
    global_step = run_values.results["global_step"]

    if self._request_summary:
      self._timer.update_last_triggered_step(global_step)
      self._save(global_step,
                 self._output_file.format(global_step),
                 run_values.run_metadata.step_stats)

    self._next_step = global_step + 1

  def _save(self, step, save_path, step_stats):
    logging.info("Saving timeline for %d into '%s'.", step, save_path)
    with gfile.Open(save_path, "w") as f:
      trace = timeline.Timeline(step_stats)
      f.write(trace.generate_chrome_trace_format(
          show_dataflow=self._show_dataflow,
          show_memory=self._show_memory))
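Usage follows the same pattern; the emitted timeline-<step>.json files can be opened in Chrome at chrome://tracing (a sketch, with train_op assumed to exist):

# Usage sketch: capture a Chrome trace every 1000 steps.
profiler_hook = ProfilerHook(save_steps=1000, output_dir="/tmp/timelines")
with tf.train.MonitoredTrainingSession(hooks=[profiler_hook]) as sess:
    while not sess.should_stop():
        sess.run(train_op)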
Example #9
class EvalHook(SessionRunHook):
    def __init__(self, estimator, eval_features, max_seq_length, eval_steps,
                 save_model_dir, th, output_dir):
        self.estimator = estimator
        self.eval_features = eval_features
        self.max_seq_length = max_seq_length
        self.eval_steps = eval_steps
        self.save_model_dir = save_model_dir
        self.th = th
        self.output_dir = output_dir

        if not os.path.exists(self.save_model_dir):
            os.mkdir(self.save_model_dir)
        self._timer = SecondOrStepTimer(every_steps=eval_steps)
        self._steps_per_run = 1
        self._global_step_tensor = None

    def _set_steps_per_run(self, steps_per_run):
        self._steps_per_run = steps_per_run

    def begin(self):
        # self._summary_writer = SummaryWriterCache.get(self._checkpoint_dir)
        self._global_step_tensor = get_global_step()  # pylint: disable=protected-access
        if self._global_step_tensor is None:
            raise RuntimeError(
                "Global step should be created to use CheckpointSaverHook.")

    def before_run(self, run_context):  # pylint: disable=unused-argument
        return SessionRunArgs(self._global_step_tensor)

    def after_run(self, run_context, run_values):
        # print(run_values.results)
        stale_global_step = run_values.results
        if self._timer.should_trigger_for_step(stale_global_step +
                                               self._steps_per_run):
            # get the real value after train op.
            global_step = run_context.session.run(self._global_step_tensor)
            if self._timer.should_trigger_for_step(global_step):
                self._timer.update_last_triggered_step(global_step)
                metrics = self.evaluation_v2(global_step)
                # print("================", MAP, MRR, self.th, type(MAP), type(MRR), type(self.th))
                if metrics["sum_acc"] * 100 > self.th:
                    # print("================", MAP, MRR)
                    self._save(run_context.session, global_step, metrics)

    def end(self, session):
        last_step = session.run(self._global_step_tensor)
        if last_step != self._timer.last_triggered_step():
            metrics = self.evaluation_v2(last_step)
            if metrics["sum_acc"] * 100 > self.th:
                self._save(session, last_step, metrics)

    def evaluation(self, global_step):
        dev_input_fn = input_fn_builder(input_features=self.eval_features,
                                        seq_length=self.max_seq_length,
                                        is_training=False,
                                        drop_remainder=False)
        predictions = self.estimator.predict(dev_input_fn,
                                             yield_single_examples=True)

        results = []
        logits = []
        for item in tqdm(predictions):
            pos_logit = item["pos_logit"]
            neg_logit = item["neg_logit"]
            logits.append((pos_logit, neg_logit))
            results.append(1 if pos_logit > neg_logit else 0)

        acc = sum(results) / len(results)
        print(f"global_step: {global_step}, acc: {acc}")
        return {"acc": acc}

    def evaluation_v2(self, global_step):
        dev_input_fn = input_fn_builder(input_features=self.eval_features,
                                        seq_length=self.max_seq_length,
                                        is_training=False,
                                        drop_remainder=False)
        predictions = self.estimator.predict(dev_input_fn,
                                             yield_single_examples=True)

        examples_pred = {}
        for item in tqdm(predictions):
            pos_logit = item["pos_logit"]
            neg_logit = item["neg_logit"]
            example_id = item["example_id"]
            if str(example_id) in examples_pred:
                examples_pred[str(example_id)].append([pos_logit, neg_logit])
            else:
                examples_pred[str(example_id)] = [[pos_logit, neg_logit]]

        results_sum = []
        results_more = []
        for example_id, logits in examples_pred.items():
            # method 1: sum
            pos_logits = [a[0] for a in logits]
            neg_logits = [a[1] for a in logits]
            pos_logit = sum(pos_logits)
            neg_logit = sum(neg_logits)

            results_sum.append(1 if pos_logit > neg_logit else 0)

            # method 2: more
            num_more_than = sum([a[0] > a[1]
                                 for a in logits])  # the number of pos > neg
            results_more.append(1 if num_more_than >= len(logits) / 2 else 0)

        sum_acc = sum(results_sum) / len(results_sum)
        more_acc = sum(results_more) / len(results_more)
        print(
            f"global_step: {global_step}, sum_acc: {sum_acc}, more_acc: {more_acc}"
        )
        return {"sum_acc": sum_acc, "more_acc": more_acc}

    def _save(self, session, step, metrics):
        save_file = os.path.join(
            self.save_model_dir, "step{}_sumacc{:5.4f}_moreacc{:5.4f}".format(
                step, metrics["sum_acc"], metrics["more_acc"]))
        list_name = os.listdir(self.output_dir)
        for name in list_name:
            if "model.ckpt-{}".format(step - 1) in name:
                org_name = os.path.join(self.output_dir, name)
                tag_name = save_file + "." + name.split(".")[-1]
                print("save {} to {}".format(org_name, tag_name))
                with open(org_name, "rb") as fr, open(tag_name, 'wb') as fw:
                    fw.write(fr.read())
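Several of the evaluation/saver hooks in this listing share the same "stale step" idiom, borrowed from CheckpointSaverHook: the global step fetched in before_run predates the train op, so after_run first runs a cheap check on the stale value and only pays for an extra session.run() when a trigger looks imminent. Distilled as a sketch (_trigger stands in for the hook-specific work):

def after_run(self, run_context, run_values):
    # The step fetched in before_run is one train op behind ("stale").
    stale_global_step = run_values.results
    if self._timer.should_trigger_for_step(stale_global_step + self._steps_per_run):
        # Only now re-read the post-train-op value, re-checking before firing.
        global_step = run_context.session.run(self._global_step_tensor)
        if self._timer.should_trigger_for_step(global_step):
            self._timer.update_last_triggered_step(global_step)
            self._trigger(run_context.session, global_step)  # hook-specific work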
Example #10
class SampleImageHook(tf.train.SessionRunHook):
    def __init__(self,
                 model,
                 sample_img,
                 img_path,
                 every_n_iter=None,
                 every_n_secs=None):
        '''
        Args:
          model: The model object; `model.epoch_id` and `model.batch_id` are
                used for naming the output files.
          sample_img: `Tensor`, sample images to save.
          img_path: `string`, path containing the directory and filename prefix.
          every_n_iter: `int`, save the sample images every N local steps.
          every_n_secs: `int` or `float`, save sample images every N seconds.
                Exactly one of `every_n_iter` and `every_n_secs` should be provided.
        '''
        self.model = model
        self.sample_img = sample_img
        self.img_path = img_path
        # Calculate an appropriate grid size automatically
        # (rows * cols >= number of sample images).
        n = sample_img.get_shape().as_list()[0]
        w = int(math.ceil(math.sqrt(n)))
        h = int(math.ceil(n / float(w)))
        self.grid_size = (h, w)
        self._timer = SecondOrStepTimer(every_secs=every_n_secs,
                                        every_steps=every_n_iter)

    def begin(self):
        # Make the dir if not exist
        img_dir = os.path.dirname(self.img_path)
        if not os.path.exists(img_dir):
            os.makedirs(img_dir)
        # Counter for run iterations
        self._iter_count = 0

    def before_run(self, run_context):
        self._should_trigger = self._timer.should_trigger_for_step(
            self._iter_count)
        if self._should_trigger:
            requests = {
                'sample_img': self.sample_img,
                'g_out': self.model.output,
            }
            # 'gt_img': self.model.input['im_gt']}
            return tf.train.SessionRunArgs(requests)
        else:
            return None

    def after_run(self, run_context, run_values):
        _ = run_context
        if self._should_trigger:
            self._timer.update_last_triggered_step(self._iter_count)
            # Save sample images, visualizing the current training results
            save_images(
                self.img_path + '_%02d_%04d.jpg' %
                (self.model.epoch_id, self.model.batch_id),
                run_values.results['sample_img'], self.grid_size)
            # save_images(self.img_path+'_%02d_%04d_out.jpg' % (self.model.epoch_id, self.model.batch_id),
            #             run_values.results['g_out'],
            #             self.grid_size)
            ## For checking. Save groundtruth (natural) training images.
            # save_images(self.img_path+'_%02d_%04d_gt.jpg' % (self.model.epoch_id, self.model.batch_id),
            #             run_values.results['gt_img'][:64],
            #             self.grid_size)
        self._iter_count += 1
Example #11
class TrainSampleHook(TrainingHook):
    """Occasionally samples predictions from the training run and prints them.

  Params:
    every_n_secs: Sample predictions every N seconds.
      If set, `every_n_steps` must be None.
    every_n_steps: Sample predictions every N steps.
      If set, `every_n_secs` must be None.
    sample_dir: Optional, a directory to write samples to.
    delimiter: Join tokens on this delimiter. Defaults to space.
  """

    #pylint: disable=missing-docstring

    def __init__(self, params, model_dir, run_config):
        super(TrainSampleHook, self).__init__(params, model_dir, run_config)
        self._sample_dir = os.path.join(self.model_dir, "samples")
        self._timer = SecondOrStepTimer(
            every_secs=self.params["every_n_secs"],
            every_steps=self.params["every_n_steps"])
        self._pred_dict = {}
        self._should_trigger = False
        self._iter_count = 0
        self._global_step = None
        self._source_delimiter = self.params["source_delimiter"]
        self._target_delimiter = self.params["target_delimiter"]

    @staticmethod
    def default_params():
        return {
            "every_n_secs": None,
            "every_n_steps": 1000,
            "source_delimiter": " ",
            "target_delimiter": " "
        }

    def begin(self):
        self._iter_count = 0
        self._global_step = tf.train.get_global_step()
        self._pred_dict = graph_utils.get_dict_from_collection("predictions")
        ##self._logits_infer = graph_utils.get_dict_from_collection("logits_infer")
        self._logits_softmax = graph_utils.get_dict_from_collection(
            "logits_softmax")
        # Create the sample directory
        if self._sample_dir is not None:
            gfile.MakeDirs(self._sample_dir)

    def before_run(self, _run_context):
        self._should_trigger = self._timer.should_trigger_for_step(
            self._iter_count)
        if self._should_trigger:
            fetches = {
                "predicted_tokens": self._pred_dict["predicted_tokens"],
                "target_words": self._pred_dict["labels.target_tokens"],
                "target_len": self._pred_dict["labels.target_len"]
            }
            return tf.train.SessionRunArgs([fetches, self._global_step])
        return tf.train.SessionRunArgs([{}, self._global_step])

    def after_run(self, run_context, run_values):
        result_dict, step = run_values.results
        self._iter_count = step

        source_emb_logits_fetches = [
            self._logits_softmax["logits_softmax_output"],
            self._logits_softmax["logits_exp_sum"],
            self._logits_softmax["logits_message_exp_nan"],
            self._logits_softmax["logits_topic_exp_nan"],
            self._logits_softmax["logits_message_exp"],
            self._logits_softmax["logits_topic_exp"]
        ]

        (logits_softmax_output, logits_exp_sum, logits_message_exp_nan,
         logits_topic_exp_nan, logits_message_exp,
         logits_topic_exp) = run_context.session.run(source_emb_logits_fetches)
        ###source_message_emb, source_topic_emb, logits_message, logits_topic, logits_output, logits_message_nan,logits_topic_nan,topic_words_id_tensor, topic_word_location,losses, loss = self._session.run(source_emb_logits_fetches)
        ###tf.logging.info("source_message_emb:{}".format(source_message_emb))
        ###tf.logging.info("source_topic_emb:{}".format(source_topic_emb))  ###ok

        with open("log", "a") as f:
            f.write("step:{}".format(step))

            f.write("logits_exp_sum:{}".format(logits_exp_sum))
            f.write("logits_exp_sum max:{}".format(np.amax(logits_exp_sum)))
            f.write("logits_exp_sum min:{}".format(np.amin(logits_exp_sum)))
            f.write("logits_message_exp:{}".format(logits_message_exp))
            f.write("logits_topic_exp:{}".format(logits_topic_exp))
            f.write("logits_message_exp max:{}".format(
                np.amax(logits_message_exp)))
            f.write("logits_topic_exp max:{}".format(
                np.amax(logits_topic_exp)))
            ###f.write("topic_words_mask:{}".format(topic_words_mask))
            f.write("logits_message_exp_nan:{}".format(logits_message_exp_nan))
            f.write("logits_topic_exp_nan:{}".format(logits_topic_exp_nan))

        if step % 100 == 0:
            tf.logging.info("step:{}".format(step))

            tf.logging.info("logits_exp_sum:{}".format(logits_exp_sum))
            ###tf.logging.info("topic_words_mask:{}".format(topic_words_mask))
            tf.logging.info(
                "logits_message_exp_nan:{}".format(logits_message_exp_nan))
            tf.logging.info(
                "logits_topic_exp_nan:{}".format(logits_topic_exp_nan))

        if not self._should_trigger:
            return None

        # Convert dict of lists to list of dicts
        result_dicts = [
            dict(zip(result_dict, t)) for t in zip(*result_dict.values())
        ]

        # Print results
        result_str = ""
        result_str += "Prediction followed by Target @ Step {}\n".format(step)
        result_str += ("=" * 100) + "\n"
        for result in result_dicts:
            target_len = result["target_len"]
            predicted_slice = result["predicted_tokens"][:target_len - 1]
            target_slice = result["target_words"][1:target_len]
            result_str += self._target_delimiter.encode("utf-8").join(
                predicted_slice).decode("utf-8") + "\n"
            result_str += self._target_delimiter.encode("utf-8").join(
                target_slice).decode("utf-8") + "\n\n"
        result_str += ("=" * 100) + "\n\n"
        tf.logging.info(result_str)
        if self._sample_dir:
            filepath = os.path.join(self._sample_dir,
                                    "samples_{:06d}.txt".format(step))
            with gfile.GFile(filepath, "w") as file:
                file.write(result_str)
        self._timer.update_last_triggered_step(self._iter_count - 1)
Example #12
class EvalHook(SessionRunHook):
    def __init__(self,
                 estimator,
                 dev_features,
                 dev_label,
                 dev_cid,
                 max_seq_length,
                 th=82.0,
                 eval_steps=None,
                 checkpoint_dir=None,
                 model_name=None,
                 _input_fn_builder=None,
                 tail_num=0,
                 type_word=''):
        logging.info("Create EvalHook.")
        self.estimator = estimator
        self.dev_features = dev_features
        self.dev_label = dev_label
        self.dev_cid = dev_cid
        self.max_seq_length = max_seq_length
        self.th = th
        self._checkpoint_dir = checkpoint_dir
        if not os.path.exists('./EVAL_LOG'):
            os.mkdir('./EVAL_LOG')
        self.model_name = model_name
        self.tail_num = tail_num
        self.org_dir = "CQA_" + type_word + self.model_name + "_{}".format(
            self.tail_num)

        self._log_save_path = os.path.join(
            './EVAL_LOG', model_name + '_' + type_word + '_log')
        self._save_path = checkpoint_dir
        if not os.path.exists(self._save_path):
            os.mkdir(self._save_path)
        self._timer = SecondOrStepTimer(every_steps=eval_steps)
        self._steps_per_run = 1
        self._global_step_tensor = None
        self._saver = None

        if _input_fn_builder is not None:
            self.input_fn_builder = _input_fn_builder
        else:
            self.input_fn_builder = input_fn_builder

    def _set_steps_per_run(self, steps_per_run):
        self._steps_per_run = steps_per_run

    def begin(self):
        # self._summary_writer = SummaryWriterCache.get(self._checkpoint_dir)
        self._global_step_tensor = get_global_step()  # pylint: disable=protected-access
        if self._global_step_tensor is None:
            raise RuntimeError(
                "Global step should be created to use CheckpointSaverHook.")

    def before_run(self, run_context):  # pylint: disable=unused-argument
        return SessionRunArgs(self._global_step_tensor)

    def after_run(self, run_context, run_values):
        stale_global_step = run_values.results
        if self._timer.should_trigger_for_step(stale_global_step +
                                               self._steps_per_run):
            # get the real value after train op.
            global_step = run_context.session.run(self._global_step_tensor)
            if self._timer.should_trigger_for_step(global_step):
                self._timer.update_last_triggered_step(global_step)
                MAP, MRR = self.evaluation(global_step)
                # print("================", MAP, MRR, self.th, type(MAP), type(MRR), type(self.th))
                if MAP > self.th:
                    # print("================", MAP, MRR)
                    self._save(run_context.session, global_step, MAP, MRR)

    def end(self, session):
        last_step = session.run(self._global_step_tensor)
        if last_step != self._timer.last_triggered_step():
            MAP, MRR = self.evaluation(last_step)
            if MAP > self.th:
                self._save(session, last_step, MAP, MRR)

    def evaluation(self, global_step):
        eval_input_fn = self.input_fn_builder(features=self.dev_features,
                                              seq_length=self.max_seq_length,
                                              is_training=False,
                                              drop_remainder=False)

        predictions = self.estimator.predict(eval_input_fn,
                                             yield_single_examples=False)
        res = np.concatenate([a["prob"] for a in predictions], axis=0)

        metrics = PRF(np.array(self.dev_label), res.argmax(axis=-1))

        print('\n Global step is : ', global_step)
        MAP, AvgRec, MRR = eval_reranker(self.dev_cid, self.dev_label,
                                         res[:, 0])

        metrics['MAP'] = MAP
        metrics['AvgRec'] = AvgRec
        metrics['MRR'] = MRR

        metrics['global_step'] = global_step

        print_metrics(metrics, 'dev', save_dir=self._log_save_path)

        return MAP * 100, MRR

    def _save(self, session, step, map_=None, mrr=None):
        """Copies the latest checkpoint files under a metric-tagged name."""
        save_path = os.path.join(
            self._save_path,
            "step{}_MAP{:5.4f}_MRR{:5.4f}".format(step, map_, mrr))

        list_name = os.listdir(self.org_dir)
        for name in list_name:
            if "model.ckpt-{}".format(step - 1) in name:
                org_name = os.path.join(self.org_dir, name)
                tag_name = save_path + "." + name.split(".")[-1]
                print("save {} to {}".format(org_name, tag_name))
                with open(org_name, "rb") as fr, open(tag_name, 'wb') as fw:
                    fw.write(fr.read())
Example #13
class StepCounterHook(session_run_hook.SessionRunHook):
    """Hook that counts steps per second."""
    def __init__(self,
                 scale=1,
                 every_n_steps=100,
                 every_n_secs=None,
                 output_dir=None,
                 summary_writer=None,
                 summary_train_op=None,
                 summary_test_op=None,
                 summary_evaluator=None,
                 test_every_n_steps=None,
                 local_step_tensor=None):

        if (every_n_steps is None) == (every_n_secs is None):
            raise ValueError(
                "exactly one of every_n_steps and every_n_secs should be provided."
            )
        self._timer = SecondOrStepTimer(every_steps=every_n_steps,
                                        every_secs=every_n_secs)

        self._summary_writer = summary_writer
        self._output_dir = output_dir
        self._last_global_step = 0
        self._last_local_step = None
        self._scale = scale
        self._summary_train_op = summary_train_op
        self._summary_test_op = summary_test_op
        self._summary_evaluator = summary_evaluator
        self._test_every_n_steps = test_every_n_steps
        self._local_step_tensor = local_step_tensor
        self._exec_count = 0

    def begin(self):
        if self._summary_writer is None and self._output_dir:
            self._summary_writer = SummaryWriterCache.get(self._output_dir)
        self._global_step_tensor = training_util._get_or_create_global_step_read(
        )  # pylint: disable=protected-access
        if self._global_step_tensor is None:
            raise RuntimeError(
                "Global step should be created to use StepCounterHook.")
        self._summary_tag = "absolute_" + training_util.get_global_step(
        ).op.name + "/sec"

    def before_run(self, run_context):  # pylint: disable=unused-argument
        return SessionRunArgs(self._local_step_tensor)

    def after_run(self, run_context, run_values):
        _ = run_context

        stale_local_step = run_values.results
        if stale_local_step > 0:
            if self._timer.should_trigger_for_step(stale_local_step + 1):
                # get the real value after train op.
                global_step, local_step = run_context.session.run(
                    [self._global_step_tensor, self._local_step_tensor])
                if self._timer.should_trigger_for_step(local_step):
                    elapsed_time, _ = self._timer.update_last_triggered_step(
                        local_step)
                    if elapsed_time is not None:
                        steps_per_sec = (global_step - self._last_global_step
                                         ) * self._scale / elapsed_time
                        logging.info("Speech %s: %g", self._summary_tag,
                                     steps_per_sec)
                        if self._summary_writer is not None:
                            aggregated_summary = run_context.session.run(
                                self._summary_train_op)
                            self._summary_writer.add_summary(
                                aggregated_summary, global_step)
                            summary = Summary(value=[
                                Summary.Value(tag=self._summary_tag,
                                              simple_value=steps_per_sec)
                            ])
                            self._summary_writer.add_summary(
                                summary, global_step)
                            self._exec_count += 1
                            if (self._test_every_n_steps is not None) and (
                                    self._exec_count %
                                    self._test_every_n_steps) == 0:
                                logging.info("Evaluate model start")
                                self._summary_evaluator(run_context.session)
                                aggregated_summary = run_context.session.run(
                                    self._summary_test_op)
                                self._summary_writer.add_summary(
                                    aggregated_summary, global_step)
                                logging.info("Evaluate model end")
                    self._last_global_step = global_step

            self._last_local_step = stale_local_step
Example #14
class LoggingTensorHook(session_run_hook.SessionRunHook):
  """Prints the given tensors once every N local steps or once every N seconds.
  The tensors will be printed to the log, with `INFO` severity.
  """

  def __init__(self, tensors, every_n_iter=None, every_n_secs=None,
               formatter=None):
    """Initializes a LoggingHook monitor.
    Args:
      tensors: `dict` that maps string-valued tags to tensors/tensor names,
          or `iterable` of tensors/tensor names.
      every_n_iter: `int`, print the values of `tensors` once every N local
          steps taken on the current worker.
      every_n_secs: `int` or `float`, print the values of `tensors` once every N
          seconds. Exactly one of `every_n_iter` and `every_n_secs` should be
          provided.
      formatter: function, takes dict of `tag`->`Tensor` and returns a string.
          If `None` uses default printing all tensors.
    Raises:
      ValueError: if `every_n_iter` is non-positive.
    """
    if (every_n_iter is None) == (every_n_secs is None):
      raise ValueError(
          "exactly one of every_n_iter and every_n_secs must be provided.")
    if every_n_iter is not None and every_n_iter <= 0:
      raise ValueError("invalid every_n_iter=%s." % every_n_iter)
    if not isinstance(tensors, dict):
      self._tag_order = tensors
      tensors = {item: item for item in tensors}
    else:
      self._tag_order = tensors.keys()
    self._tensors = tensors
    self._formatter = formatter
    self._timer = SecondOrStepTimer(every_secs=every_n_secs,
                                    every_steps=every_n_iter)

  def begin(self):
    self._iter_count = 0
    # Convert names to tensors if given
    self._current_tensors = {tag: _as_graph_element(tensor)
                             for (tag, tensor) in self._tensors.items()}

  def before_run(self, run_context):  # pylint: disable=unused-argument
    self._should_trigger = self._timer.should_trigger_for_step(self._iter_count)
    if self._should_trigger:
      return SessionRunArgs(self._current_tensors)
    else:
      return None

  def after_run(self, run_context, run_values):
    _ = run_context
    if self._should_trigger:
      original = np.get_printoptions()
      np.set_printoptions(suppress=True)
      elapsed_secs, _ = self._timer.update_last_triggered_step(self._iter_count)
      if self._formatter:
        logging.info(self._formatter(run_values.results))
      else:
        stats = []
        for tag in self._tag_order:
          stats.append("%s = %s" % (tag, run_values.results[tag]))
        if elapsed_secs is not None:
          logging.info("%s (%.3f sec)", ", ".join(stats), elapsed_secs)
        else:
          logging.info("%s", ", ".join(stats))
      np.set_printoptions(**original)
    self._iter_count += 1
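Typical usage of this hook; the tensor name and estimator below are hypothetical placeholders:

# Usage sketch: print tagged tensors every 100 local steps.
logging_hook = tf.train.LoggingTensorHook(
    tensors={"loss": "loss_op"},  # assumed name of a tensor in the graph
    every_n_iter=100)
estimator.train(input_fn=my_input_fn, hooks=[logging_hook])  # hypothetical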
    
Example #15
class EvalHook(SessionRunHook):
    def __init__(self,
                 estimator,
                 dev_file,
                 org_dev_file,
                 eval_features,
                 eval_steps=100,
                 max_seq_length=300,
                 max_answer_length=15,
                 checkpoint_dir=None,
                 input_fn_builder=None,
                 th=86,
                 model_name=None):
        self.estimator = estimator
        self.max_seq_length = max_seq_length
        self.max_answer_length = max_answer_length
        self.dev_file = dev_file
        self.org_dev_file = org_dev_file
        self.eval_features = eval_features
        self.th = th
        self.checkpoint_dir = checkpoint_dir
        self.org_dir = model_name
        if os.path.exists("./EVAL_LOG") is False:
            os.mkdir("./EVAL_LOG")

        if not os.path.exists(self.checkpoint_dir):
            os.mkdir(self.checkpoint_dir)
        self._log_save_path = os.path.join("./EVAL_LOG", model_name)
        self.save_path = os.path.join(self.checkpoint_dir, model_name)
        if not os.path.exists(self.save_path):
            os.mkdir(self.save_path)

        self._timer = SecondOrStepTimer(every_steps=eval_steps)
        self._steps_per_run = 1
        self._global_step_tensor = None

        self.input_fn_builder = input_fn_builder

    def _set_steps_per_run(self, steps_per_run):
        self._steps_per_run = steps_per_run

    def begin(self):
        # self._summary_writer = SummaryWriterCache.get(self._checkpoint_dir)
        self._global_step_tensor = get_global_step()  # pylint: disable=protected-access
        if self._global_step_tensor is None:
            raise RuntimeError(
                "Global step should be created to use CheckpointSaverHook.")

    def before_run(self, run_context):  # pylint: disable=unused-argument
        return SessionRunArgs(self._global_step_tensor)

    def after_run(self, run_context, run_values):
        # print(run_values.results)
        stale_global_step = run_values.results
        if self._timer.should_trigger_for_step(stale_global_step +
                                               self._steps_per_run):
            # get the real value after train op.
            global_step = run_context.session.run(self._global_step_tensor)
            if self._timer.should_trigger_for_step(global_step):
                self._timer.update_last_triggered_step(global_step)
                metrics = self.evaluation(global_step)
                # print("================", MAP, MRR, self.th, type(MAP), type(MRR), type(self.th))
                if metrics["acc"] * 100 > self.th:
                    # print("================", MAP, MRR)
                    self._save(run_context.session, global_step, metrics)

    def end(self, session):
        last_step = session.run(self._global_step_tensor)
        if last_step != self._timer.last_triggered_step():
            metrics = self.evaluation(last_step)
            if metrics["acc"] * 100 > self.th:
                self._save(session, last_step, metrics)

    def evaluation(self, global_step):
        print("======================================================")
        print("EVAL STARTING !!!!\n")

        dev_input_fn = self.input_fn_builder(input_file=self.dev_file,
                                             seq_length=self.max_seq_length,
                                             is_training=False,
                                             drop_remainder=False)

        predictions = self.estimator.predict(dev_input_fn,
                                             yield_single_examples=True)

        #             predictions = {
        #                 "unique_ids": unique_ids,
        #                 "start_logits": start_logits,
        #                 "end_logits": end_logits,
        #             }

        with open("./SAVE_MODEL/temp_results.csv", "w",
                  encoding="utf-8") as fw:
            for i, item in enumerate(predictions):
                unique_ids = item["unique_ids"]
                qa_id = self.eval_features[i].unique_id
                # print(unique_ids, type(unique_ids))
                # print(qa_id, type(qa_id))
                assert qa_id == unique_ids

                start_logits = item["start_logits"]
                end_logits = item["end_logits"]
                # yp1 = item["yp1"]
                # yp2 = item["yp2"]
                #
                # y1 = self.eval_features[i].start_position
                # y2 = self.eval_features[i].end_position

                n_best_items = write_prediction(
                    self.eval_features[i],
                    start_logits,
                    end_logits,
                    n_best_size=20,
                    max_answer_length=self.max_answer_length)
                best_list = [a["text"] for a in n_best_items[:3]]

                while len(best_list) < 3:
                    best_list.append("empty")

                fw.write("\"{}\",\"{}\",\"{}\",\"{}\"\n".format(
                    qa_id, *best_list))
                # instances.append((qa_id, yp1, yp2, y1, y2))

        dev_data = pd.read_csv(self.org_dev_file,
                               header=None,
                               names=["id", "sent", "entity", "label"])
        results_data = pd.read_csv("./SAVE_MODEL/temp_results.csv",
                                   header=None,
                                   names=["id", "s1", "s2", "s3"])
        results_data["sent"] = dev_data["sent"]
        results_data["entity"] = dev_data["entity"]
        results_data["label"] = dev_data["label"]
        results_data["final"] = results_data.apply(process, axis=1)
        final_results = results_data[["id", "final", "label"]]
        final_results["EM"] = final_results.apply(is_equal, axis=1)

        EM = final_results["EM"].to_numpy(dtype=np.int)
        acc = np.sum(EM) / EM.shape[0]
        metrics = {'global_step': global_step, "acc": acc}
        print(f"golbal_step: {global_step}, acc: {acc}")
        return metrics

    def _save(self, session, step, metrics=None):
        """Copies the latest checkpoint files under a metric-tagged name."""
        save_path = os.path.join(
            self.save_path, "step{}_acc{:5.4f}".format(step, metrics["acc"]))

        list_name = os.listdir(self.org_dir)
        for name in list_name:
            if "model.ckpt-{}".format(step - 1) in name:
                org_name = os.path.join(self.org_dir, name)
                tag_name = save_path + "." + name.split(".")[-1]
                print("save {} to {}".format(org_name, tag_name))
                with open(org_name, "rb") as fr, open(tag_name, 'wb') as fw:
                    fw.write(fr.read())
Example #16
class EvalHook(SessionRunHook):
    def __init__(self,
                 estimator,
                 dev_features,
                 dev_label,
                 dev_cid,
                 max_seq_length,
                 eval_steps=None,
                 checkpoint_dir=None,
                 model_name=None,
                 _input_fn_builder=None,
                 checkpoint_basename="eval.log"):

        logging.info("Create EvalHook.")
        self.estimator = estimator
        self.dev_features = dev_features
        self.dev_label = dev_label
        self.dev_cid = dev_cid
        self.max_seq_length = max_seq_length
        self._checkpoint_dir = checkpoint_dir
        if not os.path.exists('./EVAL_LOG'):
            os.mkdir('./EVAL_LOG')
        self._save_path = os.path.join('./EVAL_LOG', model_name+'_log')
        self._timer = SecondOrStepTimer(every_steps=eval_steps)
        self._steps_per_run = 1
        self._global_step_tensor = None

        if _input_fn_builder is not None:
            self.input_fn_builder = _input_fn_builder
        else:
            self.input_fn_builder = input_fn_builder

    def _set_steps_per_run(self, steps_per_run):
        self._steps_per_run = steps_per_run

    def begin(self):
        self._global_step_tensor = get_global_step()  # pylint: disable=protected-access
        if self._global_step_tensor is None:
            raise RuntimeError(
                "Global step should be created to use EvalHook.")

    def before_run(self, run_context):  # pylint: disable=unused-argument
        return SessionRunArgs(self._global_step_tensor)

    def after_run(self, run_context, run_values):
        stale_global_step = run_values.results
        if self._timer.should_trigger_for_step(
                stale_global_step + self._steps_per_run):
            # get the real value after train op.
            global_step = run_context.session.run(self._global_step_tensor)
            if self._timer.should_trigger_for_step(global_step):
                self._timer.update_last_triggered_step(global_step)
                self.evaluation(global_step)

    def end(self, session):
        last_step = session.run(self._global_step_tensor)
        if last_step != self._timer.last_triggered_step():
            self.evaluation(last_step)

    def evaluation(self, global_step):
        eval_input_fn = self.input_fn_builder(
            features=self.dev_features,
            seq_length=self.max_seq_length,
            is_training=False,
            drop_remainder=False)

        predictions = self.estimator.predict(eval_input_fn, yield_single_examples=False)
        res = np.concatenate([a for a in predictions], axis=0)

        metrics = PRF(np.array(self.dev_label), res.argmax(axis=-1))

        print('\n Global step is : ', global_step)
        MAP, AvgRec, MRR = eval_reranker(self.dev_cid, self.dev_label, res[:, 0])

        metrics['MAP'] = MAP
        metrics['AvgRec'] = AvgRec
        metrics['MRR'] = MRR

        metrics['global_step'] = global_step

        print_metrics(metrics, 'dev', save_dir=self._save_path)
Example #17
class TokensPerSecondCounter(TrainingHook):
    """A hooks that counts tokens/sec, where the number of tokens is
    defines as `len(source) + len(target)`.
  """
    def __init__(self, params, model_dir, summary_writer=None):
        super(TokensPerSecondCounter, self).__init__(params, model_dir)

        self._summary_tag = "tokens/sec"
        self._timer = SecondOrStepTimer(
            every_steps=self.params["every_n_steps"],
            every_secs=self.params["every_n_secs"])

        self._summary_writer = summary_writer
        if summary_writer is None and self.model_dir:
            self._summary_writer = SummaryWriterCache.get(self.model_dir)

        self._tokens_last_step = 0

    @staticmethod
    def default_params():
        return {"every_n_steps": 100, "every_n_secs": None}

    def begin(self):
        #pylint: disable=W0201
        features = graph_utils.get_dict_from_collection("features")
        labels = graph_utils.get_dict_from_collection("labels")

        self._num_tokens_tensor = tf.constant(0)
        if "source_len" in features:
            self._num_tokens_tensor += tf.reduce_sum(features["source_len"])
        if "target_len" in labels:
            self._num_tokens_tensor += tf.reduce_sum(labels["target_len"])

        self._tokens_last_step = 0
        self._global_step_tensor = tf.train.get_global_step()

        # Create a variable that stores how many tokens have been processed
        # Should be global for distributed training
        with tf.variable_scope("tokens_counter"):
            self._tokens_processed_var = tf.get_variable(
                name="count",
                shape=[],
                dtype=tf.int32,
                initializer=tf.constant_initializer(0, dtype=tf.int32))
            self._tokens_processed_add = tf.assign_add(
                self._tokens_processed_var, self._num_tokens_tensor)

    def before_run(self, run_context):
        return tf.train.SessionRunArgs(
            [self._global_step_tensor, self._tokens_processed_add])

    def after_run(self, _run_context, run_values):
        global_step, num_tokens = run_values.results
        tokens_processed = num_tokens - self._tokens_last_step

        if self._timer.should_trigger_for_step(global_step):
            elapsed_time, _ = self._timer.update_last_triggered_step(
                global_step)
            if elapsed_time is not None:
                tokens_per_sec = tokens_processed / elapsed_time
                if self._summary_writer is not None:
                    summary = tf.Summary(value=[
                        tf.Summary.Value(tag=self._summary_tag,
                                         simple_value=tokens_per_sec)
                    ])
                    self._summary_writer.add_summary(summary, global_step)
                tf.logging.info("%s: %g", self._summary_tag, tokens_per_sec)
            self._tokens_last_step = num_tokens
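Every example above delegates its scheduling to SecondOrStepTimer, whose contract is easy to see in isolation (a sketch, assuming the TF 1.x import path):

from tensorflow.python.training.basic_session_run_hooks import SecondOrStepTimer

timer = SecondOrStepTimer(every_steps=100)
assert timer.should_trigger_for_step(1)       # always fires before first update
timer.update_last_triggered_step(1)
assert not timer.should_trigger_for_step(50)  # fewer than 100 steps elapsed
assert timer.should_trigger_for_step(101)     # >= 100 steps since last trigger
elapsed_secs, elapsed_steps = timer.update_last_triggered_step(101)
print(elapsed_steps)  # 100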