Example #1
0
  def DecodeCheckpoint(self, sess, checkpoint_path):
    """Decodes `samples_per_summary` examples using `checkpoint_path`.

    Restores the checkpoint into `sess`, repeatedly runs the decode op until
    enough examples have been processed (or the input is exhausted when
    decoding the full dataset), writes summaries, exports metrics, and
    finalizes the decode output file.

    Args:
      sess: The tf.Session used to run the restore and decode ops.
      checkpoint_path: Filesystem path of the checkpoint to decode.

    Returns:
      bool: True if training should stop (global_step reached
      params.train.max_steps, or the trial reports it should stop); False
      otherwise, including when decoding is skipped because the checkpoint
      predates `start_decoder_after` or decoder metrics are empty.
    """
    p = self._task.params
    ckpt_id_from_file = self.GetCkptIdFromFile(checkpoint_path)
    # Skip checkpoints earlier than the configured starting step.
    if ckpt_id_from_file < p.eval.start_decoder_after:
      return False
    samples_per_summary = p.eval.decoder_samples_per_summary
    if samples_per_summary is None:
      samples_per_summary = p.eval.samples_per_summary
    # 0 means "decode the whole dataset", which requires a resettable input
    # generator so each checkpoint sees the full data exactly once.
    if samples_per_summary == 0:
      assert self._task.params.input.resettable
    self.checkpointer.RestoreFromPath(sess, checkpoint_path)

    global_step = sess.run(py_utils.GetGlobalStep())

    if self._task.params.input.resettable:
      tf.logging.info('Resetting input_generator.')
      self._task.input.Reset(sess)

    dec_metrics = self._task.CreateDecoderMetrics()
    if not dec_metrics:
      tf.logging.info('Empty decoder metrics')
      # Return False (not bare None) so every exit path yields a bool,
      # consistent with the early-exit above and `should_stop` below.
      return False
    buffered_decode_out = []
    num_examples_metric = dec_metrics['num_samples_in_batch']
    start_time = time.time()
    # Decode until enough examples are processed; when samples_per_summary
    # is 0 (full dataset), stop on the input's OutOfRangeError instead.
    while samples_per_summary == 0 or (num_examples_metric.total_value <
                                       samples_per_summary):
      try:
        tf.logging.info('Fetching dec_output.')
        fetch_start = time.time()
        run_options = tf.RunOptions(report_tensor_allocations_upon_oom=False)
        if self._summary_op is None:
          # No summaries were collected.
          dec_out = sess.run(self._dec_output, options=run_options)
        else:
          dec_out, summary = sess.run([self._dec_output, self._summary_op],
                                      options=run_options)
          self._summary_writer.add_summary(summary, global_step)
        post_process_start = time.time()
        tf.logging.info('Done fetching (%f seconds)' %
                        (post_process_start - fetch_start))
        decode_out = self._task.PostProcessDecodeOut(dec_out, dec_metrics)
        if decode_out:
          buffered_decode_out.extend(decode_out)
        tf.logging.info(
            'Total examples done: %d/%d '
            '(%f seconds decode postprocess)', num_examples_metric.total_value,
            samples_per_summary,
            time.time() - post_process_start)
      except tf.errors.OutOfRangeError:
        # Exhausting the input is only expected for resettable inputs.
        if not self._task.params.input.resettable:
          raise
        break
    tf.logging.info('Done decoding ckpt: %s', checkpoint_path)

    summaries = {k: v.Summary(k) for k, v in dec_metrics.items()}
    elapsed_secs = time.time() - start_time
    example_rate = num_examples_metric.total_value / elapsed_secs
    summaries['examples/sec'] = metrics.CreateScalarSummary(
        'examples/sec', example_rate)
    summaries['total_samples'] = metrics.CreateScalarSummary(
        'total_samples', num_examples_metric.total_value)
    self._WriteSummaries(
        self._summary_writer,
        os.path.basename(self._decoder_dir),
        global_step,
        summaries,
        text_filename=os.path.join(self._decoder_dir,
                                   'score-{:08d}.txt'.format(global_step)))
    self._ExportMetrics(
        # Metrics expects python int, but global_step is numpy.int64.
        decode_checkpoint=int(global_step),
        dec_metrics=dec_metrics,
        example_rate=example_rate)
    # global_step and the checkpoint id from the checkpoint file might be
    # different. For consistency of checkpoint filename and decoder_out
    # file, use the checkpoint id as derived from the checkpoint filename.
    checkpoint_id = _GetCheckpointIdForDecodeOut(ckpt_id_from_file, global_step)
    decode_out_path = self.GetDecodeOutPath(self._decoder_dir, checkpoint_id)

    decode_finalize_args = base_model.DecodeFinalizeArgs(
        decode_out_path=decode_out_path, decode_out=buffered_decode_out)
    self._task.DecodeFinalize(decode_finalize_args)

    should_stop = global_step >= self.params.train.max_steps
    if self._should_report_metrics:
      tf.logging.info('Reporting eval measure for step %d.' % global_step)
      trial_should_stop = self._trial.ReportEvalMeasure(global_step,
                                                        dec_metrics,
                                                        checkpoint_path)
      should_stop = should_stop or trial_should_stop
    return should_stop
Example #2
0
 def _SummarizeValue(self, steps, tag, value):
     """Writes `value` as a scalar summary named `tag` at step `steps`."""
     summary = metrics.CreateScalarSummary(tag, value)
     self._summary_writer.add_summary(summary, steps)
Example #3
0
File: trainer.py  Project: j-luo93/lingvo
 def _SummarizeValue(self, steps, tag, value, writer):
   if writer:
     writer.add_summary(metrics.CreateScalarSummary(tag, value), steps)
Example #4
0
 def _SummarizeValue(self, steps, tag, value, writer=None):
     if writer:
         writer.add_summary(metrics.CreateScalarSummary(tag, value), steps)
     elif self._summary_writer:
         self._summary_writer.add_summary(
             metrics.CreateScalarSummary(tag, value), steps)
Example #5
0
    def DecodeCheckpoint(self, sess, checkpoint_path):
        """Decodes `samples_per_summary` examples using `checkpoint_path`.

        Restores the checkpoint into `sess`, repeatedly runs the decode op
        until enough examples have been processed (or the input is exhausted
        when decoding the full dataset), then writes aggregate summaries,
        exports metrics, and finalizes the decode output file.

        Args:
          sess: The tf.Session used to run the restore and decode ops.
          checkpoint_path: Filesystem path of the checkpoint to decode.
        """
        p = self._task.params
        ckpt_id_from_file = self.GetCkptIdFromFile(checkpoint_path)
        # Skip checkpoints earlier than the configured starting step.
        if ckpt_id_from_file < p.eval.start_decoder_after:
            return

        samples_per_summary = p.eval.decoder_samples_per_summary
        if samples_per_summary is None:
            samples_per_summary = p.eval.samples_per_summary
        # 0 means "decode the whole dataset", which requires a resettable
        # input generator so each checkpoint sees the full data exactly once.
        if samples_per_summary == 0:
            assert self._task.input.params.resettable
        self.checkpointer.RestoreFromPath(sess, checkpoint_path)

        global_step = sess.run(py_utils.GetGlobalStep())

        if self._task.input.params.resettable:
            tf.logging.info('Resetting input_generator.')
            self._task.input.Reset(sess)

        dec_metrics = self._task.CreateDecoderMetrics()
        if not dec_metrics:
            tf.logging.info('Empty decoder metrics')
            return
        buffered_decode_out = []
        num_examples_metric = dec_metrics['num_samples_in_batch']
        start_time = time.time()
        # Decode until enough examples are processed; when
        # samples_per_summary is 0 (full dataset), stop on OutOfRangeError.
        while samples_per_summary == 0 or (num_examples_metric.total_value <
                                           samples_per_summary):
            try:
                # True only before any example has been counted, i.e. for the
                # first decode batch of this checkpoint.
                is_first_loop = num_examples_metric.total_value == 0
                tf.logging.info('Fetching dec_output.')
                fetch_start = time.time()
                run_options = tf.RunOptions(
                    report_tensor_allocations_upon_oom=False)

                # NOTE: We intentionally do not generate scalar summaries by
                # default, because decoder is run  multiple times for each
                # checkpoint. Multiple summaries at the same step is often confusing.
                # Instead, models should generate aggregate summaries using
                # PostProcessDecodeOut. Other types of summaries (images, audio etc.)
                # will be generated for the first eval batch.
                if self._summary_op is not None and is_first_loop:
                    dec_out, summaries = sess.run(
                        [self._dec_output, self._summary_op],
                        options=run_options)
                    summaries = self._RemoveScalarSummaries(summaries)

                    # Add non-scalar summaries only for the first batch of data.
                    self._summary_writer.add_summary(summaries, global_step)
                    self._summary_writer.flush()
                else:
                    dec_out = sess.run(self._dec_output, options=run_options)

                self._RunTF2SummaryOps(sess)
                post_process_start = time.time()
                tf.logging.info('Done fetching (%f seconds)' %
                                (post_process_start - fetch_start))
                decode_out = self._task.PostProcessDecodeOut(
                    dec_out, dec_metrics)
                if decode_out:
                    # Normalize dict results into (key, value) pairs so the
                    # summary filtering below works uniformly.
                    if isinstance(decode_out, dict):
                        decode_out = decode_out.items()

                    if is_first_loop:
                        # Add summaries only for the first batch of data.
                        for key, value in decode_out:
                            if isinstance(value, tf.Summary):
                                tf.logging.info(
                                    f'Adding summary {key} with tags '
                                    f'{[x.tag for x in value.value]}.')
                                self._summary_writer.add_summary(
                                    value, global_step)
                        self._summary_writer.flush()

                    # Buffer everything except tf.Summary values, which were
                    # already written above and do not belong in decode output.
                    buffered_decode_out.extend(
                        kv for kv in decode_out
                        if not isinstance(kv[1], tf.Summary))
                tf.logging.info(
                    'Total examples done: %d/%d '
                    '(%f seconds decode postprocess)',
                    num_examples_metric.total_value, samples_per_summary,
                    time.time() - post_process_start)
            except tf.errors.OutOfRangeError:
                # Exhausting the input is only expected for resettable inputs.
                if not self._task.input.params.resettable:
                    raise
                break
        tf.logging.info('Done decoding ckpt: %s', checkpoint_path)

        summaries = {k: v.Summary(k) for k, v in dec_metrics.items()}
        elapsed_secs = time.time() - start_time
        example_rate = num_examples_metric.total_value / elapsed_secs
        summaries['examples/sec'] = metrics.CreateScalarSummary(
            'examples/sec', example_rate)
        summaries['total_samples'] = metrics.CreateScalarSummary(
            'total_samples', num_examples_metric.total_value)
        self._WriteSummaries(self._summary_writer,
                             os.path.basename(self._decoder_dir),
                             global_step,
                             summaries,
                             text_filename=os.path.join(
                                 self._decoder_dir,
                                 'score-{:08d}.txt'.format(global_step)))
        self._ExportMetrics(
            # Metrics expects python int, but global_step is numpy.int64.
            decode_checkpoint=int(global_step),
            dec_metrics=dec_metrics,
            example_rate=example_rate)
        # global_step and the checkpoint id from the checkpoint file might be
        # different. For consistency of checkpoint filename and decoder_out
        # file, use the checkpoint id as derived from the checkpoint filename.
        checkpoint_id = _GetCheckpointIdForDecodeOut(ckpt_id_from_file,
                                                     global_step)
        decode_out_path = self.GetDecodeOutPath(self._decoder_dir,
                                                checkpoint_id)

        decode_finalize_args = base_model.DecodeFinalizeArgs(
            decode_out_path=decode_out_path, decode_out=buffered_decode_out)
        self._task.DecodeFinalize(decode_finalize_args)

        if self._should_report_metrics:
            tf.logging.info('Reporting eval measure for step %d.' %
                            global_step)
            self._trial.ReportEvalMeasure(global_step, dec_metrics,
                                          checkpoint_path)