Example #1
 def run_heartbeat_loop(self):
     count = 0
     # Use the configured session timeout, or fall back to 30 seconds per heartbeat.
     if self.session_timeout_in_ms:
         timeout_in_ms = self.session_timeout_in_ms
     else:
         timeout_in_ms = 30 * 1000
     run_options = tf.RunOptions(timeout_in_ms=timeout_in_ms)
     while True:
         try:
             if count % 100 == 0:
                 tf.logging.info('heartbeat: request_%d ...', count)
             t_begin = time.time()
             sess = self.get_session()
             ret = sess.run(self.heartbeat, options=run_options)
             if self.streamz_heartbeat_latency is not None:
                 self.streamz_heartbeat_latency.Record(
                     (time.time() - t_begin) * 1e3)
             if count % 100 == 0:
                 tf.logging.info('heartbeat: done request_%d ... %s',
                                 count, ret)
         except Exception as e:
             tf.logging.fatal(
                 'Exception in heartbeat loop thread: %r %s', e, e)
             raise
         count += 1
         # Once every 10 seconds.
         time.sleep(10)
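For context, a minimal standalone sketch of the pattern in Example #1: build a tf.RunOptions with timeout_in_ms and pass it to Session.run, so a hung heartbeat raises tf.errors.DeadlineExceededError instead of blocking forever. The toy graph and thread setup below are assumptions for illustration, not part of the original class, and it presumes graph-mode tensorflow.compat.v1.

import threading
import time

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

heartbeat = tf.constant(1, name='heartbeat')  # stand-in for self.heartbeat
# Fail fast with DeadlineExceededError instead of hanging indefinitely.
run_options = tf.RunOptions(timeout_in_ms=30 * 1000)

def heartbeat_loop(sess):
    while True:
        sess.run(heartbeat, options=run_options)
        time.sleep(10)

with tf.Session() as sess:
    thread = threading.Thread(target=heartbeat_loop, args=(sess,), daemon=True)
    thread.start()
    time.sleep(1)  # let the loop run briefly for this sketch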
Example #2
  def Run(self,
          fetch_keys,
          validate_fetches=True,
          session_run_options=None,
          run_metadata=None,
          **kwargs):
    """Runs predictor.

    Args:
      fetch_keys: a list of keys in the fetch dictionary to fetch.
      validate_fetches: if True, raises a KeyError if a specified fetch is
        invalid. If False, returns None for invalid fetches instead.
      session_run_options: Optional tf.RunOptions() to use in the session.
      run_metadata: Optional tf.RunMetadata() to use in the session.
      **kwargs: a dict of inputs to feed.

    Returns:
      A list of predictions corresponding to the order of fetch_keys.

    Raises:
      InvalidArgumentError: the number of inputs does not meet requirements.
      KeyError: a feed specified in kwargs is invalid, or a fetch in fetch_keys
        is invalid and validate_fetches is True.
    """
    if validate_fetches:
      for x in fetch_keys:
        if x not in self._fetches:
          raise KeyError(
              "%s is not in the list of available fetches. Available keys: %s" %
              (x, list(self._fetches.keys())))
    valid_fetch_idxs, valid_fetches = zip(*[(i, self._fetches[k])
                                            for i, k in enumerate(fetch_keys)
                                            if k in self._fetches.keys()])

    for k in kwargs:
      if k not in self._feeds:
        raise KeyError(
            "%s is not in the list of available feeds. Available keys: %s" %
            (k, list(self._feeds.keys())))
    feeds = {self._feeds[k]: v for k, v in six.iteritems(kwargs)}

    run_options = tf.RunOptions(report_tensor_allocations_upon_oom=False)
    if session_run_options:
      run_options = session_run_options

    fetched_results = self._RunWithValidSession(
        tf.Session.run,
        valid_fetches,
        feed_dict=feeds,
        options=run_options,
        run_metadata=run_metadata)
    results = [None] * len(fetch_keys)
    for i, fetch in zip(valid_fetch_idxs, fetched_results):
      results[i] = fetch
    return results
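The zip(*[...]) idiom above pairs each valid fetch with its original index so results can be returned in the caller's order, with None holes for invalid keys when validate_fetches is False. A small self-contained illustration of just that bookkeeping (toy data, no TensorFlow involved; all names are illustrative):

available = {'logits': 'logits_tensor', 'probs': 'probs_tensor'}
fetch_keys = ['probs', 'embeddings', 'logits']  # 'embeddings' is not available

valid_idxs, valid_fetches = zip(*[(i, available[k])
                                  for i, k in enumerate(fetch_keys)
                                  if k in available])
fetched = ['ran_%s' % f for f in valid_fetches]  # stand-in for session.run output

results = [None] * len(fetch_keys)
for i, value in zip(valid_idxs, fetched):
    results[i] = value
print(results)  # ['ran_probs_tensor', None, 'ran_logits_tensor']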
Example #3
    def run_init_sequence(self):
        """Runs init sequences before decoding."""

        assert self.init_vars_op is not None
        assert self.compile_op is not None

        sess = self.reset_session(self._tpu)
        if self._heartbeat:
            self._start_heartbeat()

        if self.ckpt:

            def run_restore():
                tf.logging.info('Restoring vars from ckpt: start')
                try:
                    self.saver.restore(sess, self.ckpt)
                except Exception as e:
                    tf.logging.fatal('Restoring vars exception: %r %s', e, e)
                    raise
                tf.logging.info('Restoring vars from ckpt: done')

            init_thread = daemon(run_restore)
        else:

            def run_init():
                tf.logging.info('Init vars randomly: start')
                try:
                    sess.run(self.init_vars_op)
                except Exception as e:
                    tf.logging.fatal('Init vars exception: %r %s', e, e)
                    raise
                tf.logging.info('Init vars randomly: done')

            init_thread = daemon(run_init)

        if hasattr(self.task, 'input'):
            tf.logging.info('Init data')
            self.task.input.Initialize(sess)
            tf.logging.info('Init data done')

        tf.logging.info('Compile: start')
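        # Allow up to 24 hours (86400 s) for the compile op to finish.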
        run_options = tf.RunOptions(timeout_in_ms=86400 * 1000)
        sess.run(self.compile_op, options=run_options)
        tf.logging.info('Compile: done')

        init_thread.join()
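run_init_sequence relies on a daemon() helper that is not shown above; a plausible minimal sketch of such a helper (a hypothetical reconstruction, not the original implementation):

import threading

def daemon(target):
    """Starts `target` on a daemon thread and returns the thread."""
    thread = threading.Thread(target=target, daemon=True)
    thread.start()
    return thread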
Example #4
  def DecodeCheckpoint(self, sess, checkpoint_path):
    """Decodes `samples_per_summary` examples using `checkpoint_path`."""
    p = self._task.params
    ckpt_id_from_file = self.GetCkptIdFromFile(checkpoint_path)
    if ckpt_id_from_file < p.eval.start_decoder_after:
      return False
    samples_per_summary = p.eval.decoder_samples_per_summary
    if samples_per_summary is None:
      samples_per_summary = p.eval.samples_per_summary
    if samples_per_summary == 0:
      assert self._task.params.input.resettable
    self.checkpointer.RestoreFromPath(sess, checkpoint_path)

    global_step = sess.run(py_utils.GetGlobalStep())

    if self._task.params.input.resettable:
      tf.logging.info('Resetting input_generator.')
      self._task.input.Reset(sess)

    dec_metrics = self._task.CreateDecoderMetrics()
    if not dec_metrics:
      tf.logging.info('Empty decoder metrics')
      return False
    buffered_decode_out = []
    num_examples_metric = dec_metrics['num_samples_in_batch']
    start_time = time.time()
    while samples_per_summary == 0 or (num_examples_metric.total_value <
                                       samples_per_summary):
      try:
        tf.logging.info('Fetching dec_output.')
        fetch_start = time.time()
        run_options = tf.RunOptions(report_tensor_allocations_upon_oom=False)
        if self._summary_op is None:
          # No summaries were collected.
          dec_out = sess.run(self._dec_output, options=run_options)
        else:
          dec_out, summary = sess.run([self._dec_output, self._summary_op],
                                      options=run_options)
          self._summary_writer.add_summary(summary, global_step)
        post_process_start = time.time()
        tf.logging.info('Done fetching (%f seconds)' %
                        (post_process_start - fetch_start))
        decode_out = self._task.PostProcessDecodeOut(dec_out, dec_metrics)
        if decode_out:
          buffered_decode_out.extend(decode_out)
        tf.logging.info(
            'Total examples done: %d/%d '
            '(%f seconds decode postprocess)', num_examples_metric.total_value,
            samples_per_summary,
            time.time() - post_process_start)
      except tf.errors.OutOfRangeError:
        if not self._task.params.input.resettable:
          raise
        break
    tf.logging.info('Done decoding ckpt: %s', checkpoint_path)

    summaries = {k: v.Summary(k) for k, v in dec_metrics.items()}
    elapsed_secs = time.time() - start_time
    example_rate = num_examples_metric.total_value / elapsed_secs
    summaries['examples/sec'] = metrics.CreateScalarSummary(
        'examples/sec', example_rate)
    summaries['total_samples'] = metrics.CreateScalarSummary(
        'total_samples', num_examples_metric.total_value)
    self._WriteSummaries(
        self._summary_writer,
        os.path.basename(self._decoder_dir),
        global_step,
        summaries,
        text_filename=os.path.join(self._decoder_dir,
                                   'score-{:08d}.txt'.format(global_step)))
    self._ExportMetrics(
        # Metrics expects python int, but global_step is numpy.int64.
        decode_checkpoint=int(global_step),
        dec_metrics=dec_metrics,
        example_rate=example_rate)
    # global_step and the checkpoint id from the checkpoint file might be
    # different. For consistency of checkpoint filename and decoder_out
    # file, use the checkpoint id as derived from the checkpoint filename.
    checkpoint_id = _GetCheckpointIdForDecodeOut(ckpt_id_from_file, global_step)
    decode_out_path = self.GetDecodeOutPath(self._decoder_dir, checkpoint_id)

    decode_finalize_args = base_model.DecodeFinalizeArgs(
        decode_out_path=decode_out_path, decode_out=buffered_decode_out)
    self._task.DecodeFinalize(decode_finalize_args)

    should_stop = global_step >= self.params.train.max_steps
    if self._should_report_metrics:
      tf.logging.info('Reporting eval measure for step %d.' % global_step)
      trial_should_stop = self._trial.ReportEvalMeasure(global_step,
                                                        dec_metrics,
                                                        checkpoint_path)
      should_stop = should_stop or trial_should_stop
    return should_stop
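The decode loop above relies on tf.errors.OutOfRangeError as the end-of-data signal from a non-repeating input pipeline. A minimal standalone illustration of that pattern, using a toy dataset rather than the task's input generator and assuming graph-mode tensorflow.compat.v1:

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()
dataset = tf.data.Dataset.range(3)
next_batch = tf.data.make_one_shot_iterator(dataset).get_next()

with tf.Session() as sess:
    while True:
        try:
            print(sess.run(next_batch))
        except tf.errors.OutOfRangeError:
            break  # Input exhausted; DecodeCheckpoint stops its loop on the same error.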
Example #5
    def Run(self,
            fetch_keys,
            validate_fetches=True,
            session_run_options=None,
            run_metadata=None,
            time_session_run=False,
            subgraph_name=None,
            **kwargs):  # pylint: disable=invalid-name
        """Runs predictor.

        Args:
          fetch_keys: dict_keys object or a list of keys in the fetch dictionary to
            fetch.
          validate_fetches: if True, raises a KeyError if a specified fetch is
            invalid. If False, returns None for invalid fetches instead.
          session_run_options: Optional tf.RunOptions() to use in the session.
          run_metadata: Optional tf.RunMetadata() to use in the session.
          time_session_run: Optional bool, if True, additionally return the
            execution time of session.run. Defaults to False.
          subgraph_name: Optional string of the subgraph to use.
          **kwargs: a dict of inputs to feed.

        Returns:
          A list of predictions corresponding to the order of fetch_keys and, if
          time_session_run is True, the run time in seconds.

        Raises:
          InvalidArgumentError: the number of inputs does not meet requirements.
          KeyError: a feed specified in kwargs is invalid, or a fetch in fetch_keys
            is invalid and validate_fetches is True.
        """
        subgraph_name = subgraph_name or self._default_subgraph_name

        single_fetch = False
        if not isinstance(fetch_keys, (list, type(dict().keys()))):
            single_fetch = True
            fetch_keys = [fetch_keys]

        valid_fetch_keys = self.subgraph_fetch_keys(subgraph_name)
        if validate_fetches:
            for k in fetch_keys:
                if k not in valid_fetch_keys:
                    raise KeyError(
                        f"{k} is not in the list of available fetches. Available keys: "
                        f"{valid_fetch_keys}.")
        subgraph_fetches = self._get_subgraph_fetches(subgraph_name)
        valid_fetch_idxs, valid_fetches = zip(
            *[(i, subgraph_fetches[k]) for i, k in enumerate(fetch_keys)
              if k in valid_fetch_keys])

        valid_feed_keys = self.subgraph_feed_keys(subgraph_name)
        for k in kwargs:
            if k not in valid_feed_keys:
                raise KeyError(
                    f"{k} is not in the list of available feeds. Available keys: "
                    f"{valid_feed_keys}.")
        subgraph_feeds = self._get_subgraph_feeds(subgraph_name)
        feeds = {subgraph_feeds[k]: v for k, v in kwargs.items()}

        run_options = tf.RunOptions(report_tensor_allocations_upon_oom=False)
        if session_run_options:
            run_options = session_run_options

        start = time.time()
        fetched_results = self._run_with_valid_session(
            tf.Session.run,
            valid_fetches,
            feed_dict=feeds,
            options=run_options,
            run_metadata=run_metadata)
        duration = time.time() - start
        results = [None] * len(fetch_keys)
        for i, fetch in zip(valid_fetch_idxs, fetched_results):
            results[i] = fetch
        if single_fetch:
            results = results[0]
        return (results, duration) if time_session_run else results
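Example #5 also normalizes a single fetch key: a bare key is wrapped in a list up front and the lone result is unwrapped at the end, so callers can use either style. A tiny self-contained sketch of that wrap/unwrap pattern (toy lookup table standing in for session.run; names are illustrative):

def run(fetch_keys, table):
    # Accept either a single key or a list of keys, like Example #5's Run().
    single_fetch = not isinstance(fetch_keys, (list, tuple))
    if single_fetch:
        fetch_keys = [fetch_keys]
    results = [table[k] for k in fetch_keys]
    return results[0] if single_fetch else results

table = {'logits': [0.1, 0.9], 'probs': [0.25, 0.75]}
print(run('logits', table))             # [0.1, 0.9]  (bare value)
print(run(['logits', 'probs'], table))  # [[0.1, 0.9], [0.25, 0.75]]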
Example #6
    def DecodeCheckpoint(self, sess, checkpoint_path):
        """Decodes `samples_per_summary` examples using `checkpoint_path`."""
        p = self._task.params
        ckpt_id_from_file = self.GetCkptIdFromFile(checkpoint_path)
        if ckpt_id_from_file < p.eval.start_decoder_after:
            return

        samples_per_summary = p.eval.decoder_samples_per_summary
        if samples_per_summary is None:
            samples_per_summary = p.eval.samples_per_summary
        if samples_per_summary == 0:
            assert self._task.input.params.resettable
        self.checkpointer.RestoreFromPath(sess, checkpoint_path)

        global_step = sess.run(py_utils.GetGlobalStep())

        if self._task.input.params.resettable:
            tf.logging.info('Resetting input_generator.')
            self._task.input.Reset(sess)

        dec_metrics = self._task.CreateDecoderMetrics()
        if not dec_metrics:
            tf.logging.info('Empty decoder metrics')
            return
        buffered_decode_out = []
        num_examples_metric = dec_metrics['num_samples_in_batch']
        start_time = time.time()
        while samples_per_summary == 0 or (num_examples_metric.total_value <
                                           samples_per_summary):
            try:
                is_first_loop = num_examples_metric.total_value == 0
                tf.logging.info('Fetching dec_output.')
                fetch_start = time.time()
                run_options = tf.RunOptions(
                    report_tensor_allocations_upon_oom=False)

                # NOTE: We intentionally do not generate scalar summaries by
                # default, because the decoder is run multiple times for each
                # checkpoint, and multiple summaries at the same step are often
                # confusing. Instead, models should generate aggregate summaries
                # using PostProcessDecodeOut. Other types of summaries (images,
                # audio, etc.) will be generated for the first eval batch.
                if self._summary_op is not None and is_first_loop:
                    dec_out, summaries = sess.run(
                        [self._dec_output, self._summary_op],
                        options=run_options)
                    summaries = self._RemoveScalarSummaries(summaries)

                    # Add non-scalar summaries only for the first batch of data.
                    self._summary_writer.add_summary(summaries, global_step)
                    self._summary_writer.flush()
                else:
                    dec_out = sess.run(self._dec_output, options=run_options)

                self._RunTF2SummaryOps(sess)
                post_process_start = time.time()
                tf.logging.info('Done fetching (%f seconds)' %
                                (post_process_start - fetch_start))
                decode_out = self._task.PostProcessDecodeOut(
                    dec_out, dec_metrics)
                if decode_out:
                    if isinstance(decode_out, dict):
                        decode_out = decode_out.items()

                    if is_first_loop:
                        # Add summaries only for the first batch of data.
                        for key, value in decode_out:
                            if isinstance(value, tf.Summary):
                                tf.logging.info(
                                    f'Adding summary {key} with tags '
                                    f'{[x.tag for x in value.value]}.')
                                self._summary_writer.add_summary(
                                    value, global_step)
                        self._summary_writer.flush()

                    buffered_decode_out.extend(
                        kv for kv in decode_out
                        if not isinstance(kv[1], tf.Summary))
                tf.logging.info(
                    'Total examples done: %d/%d '
                    '(%f seconds decode postprocess)',
                    num_examples_metric.total_value, samples_per_summary,
                    time.time() - post_process_start)
            except tf.errors.OutOfRangeError:
                if not self._task.input.params.resettable:
                    raise
                break
        tf.logging.info('Done decoding ckpt: %s', checkpoint_path)

        summaries = {k: v.Summary(k) for k, v in dec_metrics.items()}
        elapsed_secs = time.time() - start_time
        example_rate = num_examples_metric.total_value / elapsed_secs
        summaries['examples/sec'] = metrics.CreateScalarSummary(
            'examples/sec', example_rate)
        summaries['total_samples'] = metrics.CreateScalarSummary(
            'total_samples', num_examples_metric.total_value)
        self._WriteSummaries(self._summary_writer,
                             os.path.basename(self._decoder_dir),
                             global_step,
                             summaries,
                             text_filename=os.path.join(
                                 self._decoder_dir,
                                 'score-{:08d}.txt'.format(global_step)))
        self._ExportMetrics(
            # Metrics expects python int, but global_step is numpy.int64.
            decode_checkpoint=int(global_step),
            dec_metrics=dec_metrics,
            example_rate=example_rate)
        # global_step and the checkpoint id from the checkpoint file might be
        # different. For consistency of checkpoint filename and decoder_out
        # file, use the checkpoint id as derived from the checkpoint filename.
        checkpoint_id = _GetCheckpointIdForDecodeOut(ckpt_id_from_file,
                                                     global_step)
        decode_out_path = self.GetDecodeOutPath(self._decoder_dir,
                                                checkpoint_id)

        decode_finalize_args = base_model.DecodeFinalizeArgs(
            decode_out_path=decode_out_path, decode_out=buffered_decode_out)
        self._task.DecodeFinalize(decode_finalize_args)

        if self._should_report_metrics:
            tf.logging.info('Reporting eval measure for step %d.' %
                            global_step)
            self._trial.ReportEvalMeasure(global_step, dec_metrics,
                                          checkpoint_path)
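Both decoder examples build tf.RunOptions(report_tensor_allocations_upon_oom=False) explicitly; setting that field to True is the standard way to have an out-of-memory failure report which tensors were allocated at the time. A minimal sketch with a toy op, illustrative only and assuming graph-mode tensorflow.compat.v1:

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()
activations = tf.random.normal([1024, 1024])
debug_options = tf.RunOptions(report_tensor_allocations_upon_oom=True)

with tf.Session() as sess:
    # If this run hit an OOM, the raised error would include the live tensor
    # allocations, which helps pinpoint the offending op.
    sess.run(activations, options=debug_options)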