def _test_run_benchmark(self, params):
   """Tests that run_benchmark() runs successfully with the params."""
   logs = []
   with test_util.monkey_patch(all_reduce_benchmark,
                               log_fn=test_util.print_and_add_to_list(logs)):
     bench_cnn = benchmark_cnn.BenchmarkCNN(params)
     all_reduce_benchmark.run_benchmark(bench_cnn, num_iters=5)
     self.assertRegex(logs[-1], '^Average time per step: [0-9.]+$')
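A minimal usage sketch for the helper above, written as a test method on the same test class. The flag values and the use of benchmark_cnn.make_params (which builds a Params object from flag-style keyword arguments) are assumptions for illustration, not part of the original snippet.

def test_run_benchmark(self):
  # Assumed flag values; 'replicated' variable updates exercise the all-reduce
  # code path, and num_gpus may need to match the machine's available GPUs.
  params = benchmark_cnn.make_params(num_batches=10,
                                     variable_update='replicated',
                                     num_gpus=2)
  self._test_run_benchmark(params)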
Example #2
 def _run_benchmark_cnn_with_fake_images(self, params, images, labels):
   """Runs BenchmarkCNN on the given fake images and labels; returns logs."""
   logs = []
   benchmark_cnn.log_fn = test_util.print_and_add_to_list(logs)
   bench = benchmark_cnn.BenchmarkCNN(params)
   bench.image_preprocessor = preprocessing.TestImagePreprocessor(
       227, 227, params.batch_size * params.num_gpus, params.num_gpus,
       benchmark_cnn.get_data_type(params))
   bench.dataset._queue_runner_required = True
   bench.image_preprocessor.set_fake_data(images, labels)
   bench.image_preprocessor.expected_subset = ('validation'
                                               if params.eval else 'train')
   bench.run()
   return logs
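A hedged sketch of driving the helper above with fake data, assuming numpy is imported as np at module level. The image shape (227x227, matching the TestImagePreprocessor size used above) and the flag values are assumptions only.

def test_fake_images(self):
  batch_size = 2
  # Fake inputs sized to the preprocessor's 227x227 target; one label per image.
  images = np.zeros((batch_size, 227, 227, 3), dtype=np.float32)
  labels = np.array([0, 1], dtype=np.int32)
  params = benchmark_cnn.make_params(batch_size=batch_size, num_gpus=1,
                                     num_batches=1)
  logs = self._run_benchmark_cnn_with_fake_images(params, images, labels)
  # The captured log lines can then be inspected for the expected output.
  self.assertTrue(logs)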
Example #3
  def _get_benchmark_cnn_losses(self, inputs, params):
    """Returns the losses of BenchmarkCNN on the given inputs and params."""
    logs = []
    model = test_util.TestCNNModel()
    with test_util.monkey_patch(benchmark_cnn,
                                log_fn=test_util.print_and_add_to_list(logs),
                                LOSS_AND_ACCURACY_DIGITS_TO_SHOW=15):
      bench = benchmark_cnn.BenchmarkCNN(
          params, dataset=test_util.TestDataSet(), model=model)
      # The test model does not use labels when computing loss, so the label
      # values do not matter as long as they have the right shape.
      labels = np.array([1] * inputs.shape[0])
      bench.image_preprocessor.set_fake_data(inputs, labels)
      bench.run()

    outputs = test_util.get_training_outputs_from_logs(
        logs, params.print_training_accuracy)
    return [x.loss for x in outputs]
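A sketch of how the returned losses might be checked, again as a hypothetical test method; the input shape for test_util.TestCNNModel and the flag values are guesses and may need adjusting.

def test_losses_are_finite(self):
  # Assumed input shape; use whatever TestCNNModel actually expects.
  inputs = np.random.rand(4, 2, 2, 3).astype(np.float32)
  params = benchmark_cnn.make_params(batch_size=2, num_batches=2, num_gpus=1)
  losses = self._get_benchmark_cnn_losses(inputs, params)
  self.assertTrue(losses)
  for loss in losses:
    # Each parsed per-step loss should be a finite float.
    self.assertTrue(np.isfinite(loss))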
Example #4
 def _run_benchmark_cnn(self, params):
   """Runs BenchmarkCNN with the given params and returns the captured logs."""
   logs = []
   benchmark_cnn.log_fn = test_util.print_and_add_to_list(logs)
   benchmark_cnn.BenchmarkCNN(params).run()
   return logs
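As in example #3, the captured logs can be parsed with test_util.get_training_outputs_from_logs. A short sketch, with assumed flag values, of calling the helper above:

def test_short_run(self):
  # Assumed flags: a very short run.
  params = benchmark_cnn.make_params(num_batches=5, batch_size=2, num_gpus=1)
  logs = self._run_benchmark_cnn(params)
  outputs = test_util.get_training_outputs_from_logs(
      logs, params.print_training_accuracy)
  self.assertTrue(outputs)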