def run_heartbeat_loop():
  """Polls the heartbeat op forever, logging every 100th request.

  Intended to run in a daemon thread. Any exception from the session is
  logged via tf.logging.fatal and re-raised so the thread dies loudly.
  """
  count = 0
  # Set a timeout of 30 seconds for each heartbeat, unless the session
  # configures its own timeout.
  if self.session_timeout_in_ms:
    timeout_in_ms = self.session_timeout_in_ms
  else:
    timeout_in_ms = 30 * 1000
  # BUG FIX: previously passed self.session_timeout_in_ms directly, which
  # ignored the 30s fallback computed above (and could pass None).
  run_options = tf.RunOptions(timeout_in_ms=timeout_in_ms)
  while True:
    try:
      # Log only every 100th request to keep the log quiet.
      if count % 100 == 0:
        tf.logging.info('heartbeat: request_%d ...', count)
      t_begin = time.time()
      sess = self.get_session()
      ret = sess.run(self.heartbeat, options=run_options)
      if self.streamz_heartbeat_latency is not None:
        # Record latency in milliseconds.
        self.streamz_heartbeat_latency.Record((time.time() - t_begin) * 1e3)
      if count % 100 == 0:
        tf.logging.info('heartbeat: done request_%d ... %s', count, ret)
    except Exception as e:
      tf.logging.fatal('Exception in heartbeat loop thread: %r %s', e, e)
      raise
    count += 1
    # Once every 10 seconds.
    time.sleep(10)
def Run(self, fetch_keys, validate_fetches=True, session_run_options=None,
        run_metadata=None, **kwargs):
  """Runs predictor.

  Args:
    fetch_keys: a list of keys in the fetch dictionary to fetch.
    validate_fetches: if True, raises a KeyError if a specified fetch is
      invalid. If False, returns None for invalid fetches instead.
    session_run_options: Optional tf.RunOptions() to use in the session.
    run_metadata: Optional tf.RunMetadata() to use in the session.
    **kwargs: a dict of inputs to feed.

  Returns:
    A list of predictions corresponding to the order of fetch_keys.

  Raises:
    InvalidArgumentError: the number of inputs does not meet requirements.
    KeyError: a feed specified in kwargs is invalid, or a fetch in
      fetch_keys is invalid and validate_fetches is True.
  """
  if validate_fetches:
    for fetch_name in fetch_keys:
      if fetch_name not in self._fetches:
        raise KeyError(
            "%s is not in the list of available fetches. Available keys: %s" %
            (fetch_name, list(self._fetches.keys())))
  # Pair each recognized fetch with its position so that unknown keys
  # (permitted when validate_fetches is False) stay None in the output.
  known_pairs = [(position, self._fetches[fetch_name])
                 for position, fetch_name in enumerate(fetch_keys)
                 if fetch_name in self._fetches.keys()]
  valid_fetch_idxs, valid_fetches = zip(*known_pairs)

  for feed_name in kwargs:
    if feed_name not in self._feeds:
      raise KeyError(
          "%s is not in the list of available feeds. Available keys: %s" %
          (feed_name, list(self._feeds.keys())))
  feeds = {
      self._feeds[feed_name]: feed_value
      for feed_name, feed_value in six.iteritems(kwargs)
  }

  # Caller-supplied run options take precedence over the default.
  run_options = tf.RunOptions(report_tensor_allocations_upon_oom=False)
  if session_run_options:
    run_options = session_run_options

  fetched_results = self._RunWithValidSession(
      tf.Session.run,
      valid_fetches,
      feed_dict=feeds,
      options=run_options,
      run_metadata=run_metadata)
  # Scatter the fetched values back into fetch_keys order.
  results = [None] * len(fetch_keys)
  for position, fetched in zip(valid_fetch_idxs, fetched_results):
    results[position] = fetched
  return results
def run_init_sequence(self):
  """Runs init sequences before decoding."""
  assert self.init_vars_op is not None
  assert self.compile_op is not None
  sess = self.reset_session(self._tpu)

  if self._heartbeat:
    self._start_heartbeat()

  def _logged_target(start_msg, done_msg, err_msg, fn):
    # Wraps `fn` with start/done logging and fatal-log-then-reraise on error.
    def _target():
      tf.logging.info(start_msg)
      try:
        fn()
      except Exception as e:
        tf.logging.fatal(err_msg, e, e)
        raise
      tf.logging.info(done_msg)

    return _target

  # Variable initialization (restore from ckpt, or random init) runs in a
  # background daemon thread while the data init and compile proceed.
  if self.ckpt:
    init_thread = daemon(
        _logged_target('Restoring vars from ckpt: start',
                       'Restoring vars from ckpt: done',
                       'Restoring vars exception: %r %s',
                       lambda: self.saver.restore(sess, self.ckpt)))
  else:
    init_thread = daemon(
        _logged_target('Init vars randomly: start',
                       'Init vars randomly: done',
                       'Init vars exception: %r %s',
                       lambda: sess.run(self.init_vars_op)))

  if hasattr(self.task, 'input'):
    tf.logging.info('Init data')
    self.task.input.Initialize(sess)
    tf.logging.info('Init data done')

  tf.logging.info('Compile: start')
  # Allow up to one day for compilation.
  run_options = tf.RunOptions(timeout_in_ms=86400 * 1000)
  sess.run(self.compile_op, options=run_options)
  tf.logging.info('Compile: done')

  init_thread.join()
def DecodeCheckpoint(self, sess, checkpoint_path):
  """Decodes `samples_per_summary` examples using `checkpoint_path`.

  Restores the checkpoint, runs the decoder loop until enough examples are
  processed (or the input is exhausted, when resettable), then writes
  summaries, exports metrics, and finalizes the decode output file.

  Args:
    sess: the tf.Session all ops are run in.
    checkpoint_path: path of the checkpoint to decode.

  Returns:
    True if the caller should stop (global_step reached max_steps or the
    tuning trial requested a stop); False otherwise, including when the
    checkpoint precedes `start_decoder_after` or there are no decoder
    metrics.
  """
  p = self._task.params
  ckpt_id_from_file = self.GetCkptIdFromFile(checkpoint_path)
  if ckpt_id_from_file < p.eval.start_decoder_after:
    return False

  samples_per_summary = p.eval.decoder_samples_per_summary
  if samples_per_summary is None:
    samples_per_summary = p.eval.samples_per_summary
  if samples_per_summary == 0:
    # 0 means "decode the entire dataset", which requires a resettable
    # input pipeline so the OutOfRangeError below terminates the loop.
    assert self._task.params.input.resettable
  self.checkpointer.RestoreFromPath(sess, checkpoint_path)

  global_step = sess.run(py_utils.GetGlobalStep())

  if self._task.params.input.resettable:
    tf.logging.info('Resetting input_generator.')
    self._task.input.Reset(sess)

  dec_metrics = self._task.CreateDecoderMetrics()
  if not dec_metrics:
    tf.logging.info('Empty decoder metrics')
    # CONSISTENCY FIX: was a bare `return` (None); the function otherwise
    # always returns a bool should_stop, so return False explicitly.
    return False
  buffered_decode_out = []
  num_examples_metric = dec_metrics['num_samples_in_batch']
  start_time = time.time()
  while samples_per_summary == 0 or (num_examples_metric.total_value <
                                     samples_per_summary):
    try:
      tf.logging.info('Fetching dec_output.')
      fetch_start = time.time()
      run_options = tf.RunOptions(report_tensor_allocations_upon_oom=False)
      if self._summary_op is None:
        # No summaries were collected.
        dec_out = sess.run(self._dec_output, options=run_options)
      else:
        dec_out, summary = sess.run([self._dec_output, self._summary_op],
                                    options=run_options)
        self._summary_writer.add_summary(summary, global_step)
      post_process_start = time.time()
      tf.logging.info('Done fetching (%f seconds)' %
                      (post_process_start - fetch_start))
      decode_out = self._task.PostProcessDecodeOut(dec_out, dec_metrics)
      if decode_out:
        buffered_decode_out.extend(decode_out)
      tf.logging.info(
          'Total examples done: %d/%d '
          '(%f seconds decode postprocess)', num_examples_metric.total_value,
          samples_per_summary,
          time.time() - post_process_start)
    except tf.errors.OutOfRangeError:
      # Input exhausted: only acceptable for resettable inputs.
      if not self._task.params.input.resettable:
        raise
      break
  tf.logging.info('Done decoding ckpt: %s', checkpoint_path)

  summaries = {k: v.Summary(k) for k, v in dec_metrics.items()}
  elapsed_secs = time.time() - start_time
  example_rate = num_examples_metric.total_value / elapsed_secs
  summaries['examples/sec'] = metrics.CreateScalarSummary(
      'examples/sec', example_rate)
  summaries['total_samples'] = metrics.CreateScalarSummary(
      'total_samples', num_examples_metric.total_value)
  self._WriteSummaries(
      self._summary_writer,
      os.path.basename(self._decoder_dir),
      global_step,
      summaries,
      text_filename=os.path.join(self._decoder_dir,
                                 'score-{:08d}.txt'.format(global_step)))
  self._ExportMetrics(
      # Metrics expects python int, but global_step is numpy.int64.
      decode_checkpoint=int(global_step),
      dec_metrics=dec_metrics,
      example_rate=example_rate)
  # global_step and the checkpoint id from the checkpoint file might be
  # different. For consistency of checkpoint filename and decoder_out
  # file, use the checkpoint id as derived from the checkpoint filename.
  checkpoint_id = _GetCheckpointIdForDecodeOut(ckpt_id_from_file, global_step)
  decode_out_path = self.GetDecodeOutPath(self._decoder_dir, checkpoint_id)

  decode_finalize_args = base_model.DecodeFinalizeArgs(
      decode_out_path=decode_out_path, decode_out=buffered_decode_out)
  self._task.DecodeFinalize(decode_finalize_args)

  should_stop = global_step >= self.params.train.max_steps
  if self._should_report_metrics:
    tf.logging.info('Reporting eval measure for step %d.' % global_step)
    trial_should_stop = self._trial.ReportEvalMeasure(global_step, dec_metrics,
                                                      checkpoint_path)
    should_stop = should_stop or trial_should_stop
  return should_stop
def Run(self, fetch_keys, validate_fetches=True, session_run_options=None,
        run_metadata=None, time_session_run=False, subgraph_name=None,
        **kwargs):  # pylint: disable=invalid-name
  """Runs predictor.

  Args:
    fetch_keys: dict_keys object or a list of keys in the fetch dictionary
      to fetch.
    validate_fetches: if True, raises a KeyError if a specified fetch is
      invalid. If False, returns None for invalid fetches instead.
    session_run_options: Optional tf.RunOptions() to use in the session.
    run_metadata: Optional tf.RunMetadata() to use in the session.
    time_session_run: Optional bool, if True, additionally return the
      execution time of session.run. Defaults to False.
    subgraph_name: Optional string of the subgraph to use.
    **kwargs: a dict of inputs to feed.

  Returns:
    A list of predictions corresponding to the order of fetch_keys and, if
    time_session_run is True, the run time in seconds.

  Raises:
    InvalidArgumentError: the number of inputs does not meet requirements.
    KeyError: a feed specified in kwargs is invalid, or a fetch in
      fetch_keys is invalid and validate_fetches is True.
  """
  subgraph_name = subgraph_name or self._default_subgraph_name
  # A bare key (not a list / dict_keys) means a single fetch; wrap it and
  # remember to unwrap the result.
  is_single_fetch = not isinstance(fetch_keys, (list, type(dict().keys())))
  if is_single_fetch:
    fetch_keys = [fetch_keys]

  allowed_fetch_keys = self.subgraph_fetch_keys(subgraph_name)
  if validate_fetches:
    for key in fetch_keys:
      if key not in allowed_fetch_keys:
        raise KeyError(
            f"{key} is not in the list of available fetches. Available keys: "
            f"{allowed_fetch_keys}.")
  subgraph_fetches = self._get_subgraph_fetches(subgraph_name)
  # Keep each recognized fetch together with its position so unknown keys
  # (allowed when validate_fetches is False) come back as None.
  indexed_fetches = [(position, subgraph_fetches[key])
                     for position, key in enumerate(fetch_keys)
                     if key in allowed_fetch_keys]
  valid_fetch_idxs, valid_fetches = zip(*indexed_fetches)

  allowed_feed_keys = self.subgraph_feed_keys(subgraph_name)
  for key in kwargs:
    if key not in allowed_feed_keys:
      raise KeyError(
          f"{key} is not in the list of available feeds. Available keys: "
          f"{allowed_feed_keys}.")
  subgraph_feeds = self._get_subgraph_feeds(subgraph_name)
  feeds = {subgraph_feeds[key]: value for key, value in kwargs.items()}

  # Caller-supplied run options take precedence over the default.
  run_options = tf.RunOptions(report_tensor_allocations_upon_oom=False)
  if session_run_options:
    run_options = session_run_options

  start = time.time()
  fetched_results = self._run_with_valid_session(
      tf.Session.run,
      valid_fetches,
      feed_dict=feeds,
      options=run_options,
      run_metadata=run_metadata)
  duration = time.time() - start

  # Scatter fetched values back into fetch_keys order.
  results = [None] * len(fetch_keys)
  for position, fetched in zip(valid_fetch_idxs, fetched_results):
    results[position] = fetched
  if is_single_fetch:
    results = results[0]
  return (results, duration) if time_session_run else results
def DecodeCheckpoint(self, sess, checkpoint_path):
  """Decodes `samples_per_summary` examples using `checkpoint_path`.

  Restores the checkpoint into `sess`, runs the decoder loop until
  `samples_per_summary` examples are consumed (or the input is exhausted,
  when resettable), writes summaries and the decode output file, and
  reports eval measures to the trial if configured. Returns None.

  Args:
    sess: the tf.Session all ops are run in.
    checkpoint_path: path of the checkpoint to decode.
  """
  p = self._task.params
  ckpt_id_from_file = self.GetCkptIdFromFile(checkpoint_path)
  # Skip checkpoints earlier than the configured starting step.
  if ckpt_id_from_file < p.eval.start_decoder_after:
    return

  samples_per_summary = p.eval.decoder_samples_per_summary
  if samples_per_summary is None:
    samples_per_summary = p.eval.samples_per_summary
  if samples_per_summary == 0:
    # 0 means "decode the entire dataset", which requires a resettable
    # input pipeline so the OutOfRangeError below terminates the loop.
    assert self._task.input.params.resettable
  self.checkpointer.RestoreFromPath(sess, checkpoint_path)

  global_step = sess.run(py_utils.GetGlobalStep())

  if self._task.input.params.resettable:
    tf.logging.info('Resetting input_generator.')
    self._task.input.Reset(sess)

  dec_metrics = self._task.CreateDecoderMetrics()
  if not dec_metrics:
    tf.logging.info('Empty decoder metrics')
    return
  buffered_decode_out = []
  num_examples_metric = dec_metrics['num_samples_in_batch']
  start_time = time.time()
  while samples_per_summary == 0 or (num_examples_metric.total_value <
                                     samples_per_summary):
    try:
      # True only before any example has been processed; gates the
      # once-per-checkpoint summary writes below.
      is_first_loop = num_examples_metric.total_value == 0
      tf.logging.info('Fetching dec_output.')
      fetch_start = time.time()
      run_options = tf.RunOptions(report_tensor_allocations_upon_oom=False)
      # NOTE: We intentionally do not generate scalar summaries by
      # default, because decoder is run multiple times for each
      # checkpoint. Multiple summaries at the same step is often confusing.
      # Instead, models should generate aggregate summaries using
      # PostProcessDecodeOut. Other types of summaries (images, audio etc.)
      # will be generated for the first eval batch.
      if self._summary_op is not None and is_first_loop:
        dec_out, summaries = sess.run([self._dec_output, self._summary_op],
                                      options=run_options)
        summaries = self._RemoveScalarSummaries(summaries)
        # Add non-scalar summaries only for the first batch of data.
        self._summary_writer.add_summary(summaries, global_step)
        self._summary_writer.flush()
      else:
        dec_out = sess.run(self._dec_output, options=run_options)
      self._RunTF2SummaryOps(sess)
      post_process_start = time.time()
      tf.logging.info('Done fetching (%f seconds)' %
                      (post_process_start - fetch_start))
      decode_out = self._task.PostProcessDecodeOut(dec_out, dec_metrics)
      if decode_out:
        # Normalize dict output to (key, value) pairs.
        if isinstance(decode_out, dict):
          decode_out = decode_out.items()
        if is_first_loop:
          # Add summaries only for the first batch of data.
          for key, value in decode_out:
            if isinstance(value, tf.Summary):
              tf.logging.info(f'Adding summary {key} with tags '
                              f'{[x.tag for x in value.value]}.')
              self._summary_writer.add_summary(value, global_step)
          self._summary_writer.flush()
        # Buffer everything except tf.Summary values for DecodeFinalize.
        buffered_decode_out.extend(
            kv for kv in decode_out if not isinstance(kv[1], tf.Summary))
      tf.logging.info(
          'Total examples done: %d/%d '
          '(%f seconds decode postprocess)', num_examples_metric.total_value,
          samples_per_summary,
          time.time() - post_process_start)
    except tf.errors.OutOfRangeError:
      # Input exhausted: only acceptable for resettable inputs.
      if not self._task.input.params.resettable:
        raise
      break
  tf.logging.info('Done decoding ckpt: %s', checkpoint_path)

  summaries = {k: v.Summary(k) for k, v in dec_metrics.items()}
  elapsed_secs = time.time() - start_time
  example_rate = num_examples_metric.total_value / elapsed_secs
  summaries['examples/sec'] = metrics.CreateScalarSummary(
      'examples/sec', example_rate)
  summaries['total_samples'] = metrics.CreateScalarSummary(
      'total_samples', num_examples_metric.total_value)
  self._WriteSummaries(
      self._summary_writer,
      os.path.basename(self._decoder_dir),
      global_step,
      summaries,
      text_filename=os.path.join(self._decoder_dir,
                                 'score-{:08d}.txt'.format(global_step)))
  self._ExportMetrics(
      # Metrics expects python int, but global_step is numpy.int64.
      decode_checkpoint=int(global_step),
      dec_metrics=dec_metrics,
      example_rate=example_rate)
  # global_step and the checkpoint id from the checkpoint file might be
  # different. For consistency of checkpoint filename and decoder_out
  # file, use the checkpoint id as derived from the checkpoint filename.
  checkpoint_id = _GetCheckpointIdForDecodeOut(ckpt_id_from_file, global_step)
  decode_out_path = self.GetDecodeOutPath(self._decoder_dir, checkpoint_id)

  decode_finalize_args = base_model.DecodeFinalizeArgs(
      decode_out_path=decode_out_path, decode_out=buffered_decode_out)
  self._task.DecodeFinalize(decode_finalize_args)

  if self._should_report_metrics:
    tf.logging.info('Reporting eval measure for step %d.' % global_step)
    self._trial.ReportEvalMeasure(global_step, dec_metrics, checkpoint_path)