Example #1
def test_config_benchmark_file_logger(self):
  # Set benchmark_log_dir first, since the validator for benchmark_logger_type
  # needs that value to already be set when it runs.
  with flagsaver.flagsaver(benchmark_log_dir='/tmp'):
    with flagsaver.flagsaver(benchmark_logger_type='BenchmarkFileLogger'):
      logger.config_benchmark_logger()
      self.assertIsInstance(logger.get_benchmark_logger(),
                            logger.BenchmarkFileLogger)
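Outside of a test, the same configuration is normally driven by the parsed command-line flags rather than by flagsaver. A minimal sketch, assuming the project's benchmark flags are already defined and parsed; only the flag names are taken from the test above, everything else is illustrative:

from absl import flags

FLAGS = flags.FLAGS

# Assumes --benchmark_log_dir and --benchmark_logger_type are defined and the
# flags have been parsed by the surrounding program.
FLAGS.benchmark_log_dir = '/tmp'
FLAGS.benchmark_logger_type = 'BenchmarkFileLogger'
logger.config_benchmark_logger()
benchmark_logger = logger.get_benchmark_logger()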
Example #2
def get_logging_metric_hook(tensors_to_log=None, every_n_secs=600, **kwargs):  # pylint: disable=unused-argument
  """Function to get LoggingMetricHook.

  Args:
    tensors_to_log: List of tensor names or dictionary mapping labels to tensor
      names. If not set, log _TENSORS_TO_LOG by default.
    every_n_secs: `int`, the frequency for logging the metric. Defaults to every
      10 minutes.
    **kwargs: additional keyword arguments; unused here.

  Returns:
    A LoggingMetricHook that saves tensor values in a JSON format.
  """
  if tensors_to_log is None:
    tensors_to_log = _TENSORS_TO_LOG
  return metric_hook.LoggingMetricHook(
      tensors=tensors_to_log,
      metric_logger=logger.get_benchmark_logger(),
      every_n_secs=every_n_secs)
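The hook writes through logger.get_benchmark_logger(), so the benchmark logger should be configured (as in Example #1) before the hook is created, so that metrics go to the intended destination. A minimal ordering sketch; the label-to-tensor mapping is an assumed example, not from the source:

# Sketch: configure the benchmark logger first, then build the hook.
logger.config_benchmark_logger()
hook = get_logging_metric_hook(tensors_to_log={'loss': 'loss'},  # assumed names
                               every_n_secs=300)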
Example #3
def get_examples_per_second_hook(every_n_steps=100,
                                 batch_size=128,
                                 warm_steps=5,
                                 **kwargs):  # pylint: disable=unused-argument
  """Function to get ExamplesPerSecondHook.

  Args:
    every_n_steps: `int`, print current and average examples per second every
      N steps.
    batch_size: `int`, total batch size used to calculate examples/second from
      global time.
    warm_steps: `int`, number of initial steps to skip before logging and
      computing the running average.
    **kwargs: additional keyword arguments; unused here.

  Returns:
    An ExamplesPerSecondHook that logs the current and average number of
    examples processed per second.
  """
  return hooks.ExamplesPerSecondHook(
      batch_size=batch_size,
      every_n_steps=every_n_steps,
      warm_steps=warm_steps,
      metric_logger=logger.get_benchmark_logger())
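Both getters return training hooks that can be passed straight to an Estimator's train call; Example #4 below obtains them indirectly via hooks_helper.get_train_hooks. A minimal direct-use sketch, where estimator and train_input_fn are hypothetical placeholders, not objects from this listing:

# Hypothetical estimator and input_fn; only the hook wiring is the point here.
train_hooks = [
    get_logging_metric_hook(every_n_secs=600),
    get_examples_per_second_hook(batch_size=128, warm_steps=5),
]
estimator.train(input_fn=train_input_fn, hooks=train_hooks)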
Example #4
def run_loop(name,
             train_input_fn,
             eval_input_fn,
             model_column_fn,
             build_estimator_fn,
             flags_obj,
             tensors_to_log,
             early_stop=False):
    """Define training loop."""
    model_helpers.apply_clean(flags.FLAGS)
    with tf.contrib.tfprof.ProfileContext(
            '/tmp/census_training_profile') as pctx:
        model = build_estimator_fn(
            model_dir=flags_obj.model_dir,
            model_type=flags_obj.model_type,
            model_column_fn=model_column_fn,
            inter_op=flags_obj.inter_op_parallelism_threads,
            intra_op=flags_obj.intra_op_parallelism_threads)

        run_params = {
            'batch_size': flags_obj.batch_size,
            'train_epochs': flags_obj.train_epochs,
            'model_type': flags_obj.model_type,
        }

        benchmark_logger = logger.get_benchmark_logger()
        benchmark_logger.log_run_info('wide_deep',
                                      name,
                                      run_params,
                                      test_id=flags_obj.benchmark_test_id)

        loss_prefix = LOSS_PREFIX.get(flags_obj.model_type, '')
        tensors_to_log = {
            k: v.format(loss_prefix=loss_prefix)
            for k, v in tensors_to_log.items()
        }
        train_hooks = hooks_helper.get_train_hooks(
            flags_obj.hooks,
            model_dir=flags_obj.model_dir,
            batch_size=flags_obj.batch_size,
            tensors_to_log=tensors_to_log)

        # Train and evaluate the model every `flags_obj.epochs_between_evals` epochs.
        for n in range(flags_obj.train_epochs //
                       flags_obj.epochs_between_evals):
            model.train(input_fn=train_input_fn, hooks=train_hooks)

            results = model.evaluate(input_fn=eval_input_fn)

            # Display evaluation metrics
            tf.logging.info('Results at epoch %d / %d',
                            (n + 1) * flags_obj.epochs_between_evals,
                            flags_obj.train_epochs)
            tf.logging.info('-' * 60)

            for key in sorted(results):
                tf.logging.info('%s: %s' % (key, results[key]))

            benchmark_logger.log_evaluation_result(results)

            if early_stop and model_helpers.past_stop_threshold(
                    flags_obj.stop_threshold, results['accuracy']):
                break

        # Export the model
        if flags_obj.export_dir is not None:
            export_model(model, flags_obj.model_type, flags_obj.export_dir,
                         model_column_fn)
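A minimal sketch of wiring run_loop. Every name below other than run_loop itself (the input functions, the column and estimator builders, and the tensor name) is a hypothetical placeholder standing in for the model-specific pieces; only flags.FLAGS and the {loss_prefix} formatting convention come from the code above:

# Sketch only: train_input_fn, eval_input_fn, build_model_columns,
# build_estimator and the tensor name below are hypothetical placeholders.
run_loop(
    name='census',
    train_input_fn=train_input_fn,
    eval_input_fn=eval_input_fn,
    model_column_fn=build_model_columns,
    build_estimator_fn=build_estimator,
    flags_obj=flags.FLAGS,
    tensors_to_log={'loss': '{loss_prefix}head/weighted_loss/Sum'},
    early_stop=True)

With, say, --train_epochs=40 and --epochs_between_evals=2 (illustrative values), the train/evaluate loop runs 20 times and the progress message reports epochs 2, 4, ..., 40.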
Example #5
def test_config_benchmark_bigquery_logger(self, mock_bigquery_client):
  with flagsaver.flagsaver(benchmark_logger_type='BenchmarkBigQueryLogger'):
    logger.config_benchmark_logger()
    self.assertIsInstance(logger.get_benchmark_logger(),
                          logger.BenchmarkBigQueryLogger)
Example #6
def test_config_base_benchmark_logger(self):
  with flagsaver.flagsaver(benchmark_logger_type='BaseBenchmarkLogger'):
    logger.config_benchmark_logger()
    self.assertIsInstance(logger.get_benchmark_logger(),
                          logger.BaseBenchmarkLogger)
Example #7
def test_get_default_benchmark_logger(self):
  with flagsaver.flagsaver(benchmark_logger_type='foo'):
    self.assertIsInstance(logger.get_benchmark_logger(),
                          logger.BaseBenchmarkLogger)