def _zeros_and_ones_batch(batch_size, height, width):
  """Builds a fake image batch whose first half is all 0s and rest all 1s.

  Args:
    batch_size: Number of images in the batch.
    height: Image height in pixels.
    width: Image width in pixels.
  Returns:
    A tuple (images, labels). `images` is a float32 array of shape
    (batch_size, height, width, 3); `labels` is an int32 array of length
    batch_size holding 0 for each all-zeros image and 1 for each all-ones
    image.
  """
  half_batch_size = batch_size // 2
  images = np.zeros((batch_size, height, width, 3), dtype=np.float32)
  images[half_batch_size:] = 1
  # The "ones" half is batch_size - half_batch_size long (not half_batch_size)
  # so labels stay aligned with images when batch_size is odd; the extra
  # image in that case is an all-ones image and gets label 1.
  labels = np.array([0] * half_batch_size +
                    [1] * (batch_size - half_batch_size),
                    dtype=np.int32)
  return images, labels


def get_test_image_preprocessor(batch_size, params):
  """Returns the preprocessing.TestImagePreprocessor that should be injected.

  Returns None if no preprocessor should be injected. With
  --fake_input=zeros_and_ones, the preprocessor serves a 227x227 batch whose
  first batch_size // 2 images are all zeros (label 0) and whose remaining
  images are all ones (label 1).

  Args:
    batch_size: The batch size across all GPUs.
    params: BenchmarkCNN's parameters.
  Returns:
    Returns the preprocessing.TestImagePreprocessor that should be injected,
    or None when --fake_input is 'none'.
  Raises:
    ValueError: Flag --fake_input is an invalid value.
  """
  if FLAGS.fake_input == 'none':
    return None
  if FLAGS.fake_input != 'zeros_and_ones':
    raise ValueError('Invalid --fake_input: %s' % FLAGS.fake_input)

  image_size = 227
  images, labels = _zeros_and_ones_batch(batch_size, image_size, image_size)
  preprocessor = preprocessing.TestImagePreprocessor(
      image_size, image_size, batch_size, params.num_gpus,
      benchmark_cnn.get_data_type(params))
  preprocessor.set_fake_data(images, labels)
  # The injected preprocessor is told which subset it should be asked for.
  preprocessor.expected_subset = 'validation' if params.eval else 'train'
  return preprocessor
예제 #2
0
 def _run_benchmark_cnn_with_fake_images(self, params, images, labels):
     """Runs a BenchmarkCNN fed by the given fake images and returns its logs.

     Args:
       params: BenchmarkCNN's parameters.
       images: Fake image data handed to the injected preprocessor.
       labels: Fake labels handed to the injected preprocessor.
     Returns:
       The list that benchmark_cnn.log_fn was redirected into for this run.
     """
     captured = []
     # Redirect benchmark_cnn logging before the benchmark is constructed so
     # that anything logged during construction is captured as well.
     # NOTE(review): log_fn is not restored afterwards; it keeps pointing at
     # this call's list once the method returns.
     benchmark_cnn.log_fn = _print_and_add_to_list(captured)
     benchmark = benchmark_cnn.BenchmarkCNN(params)

     total_batch_size = params.batch_size * params.num_gpus
     benchmark.image_preprocessor = preprocessing.TestImagePreprocessor(
         227, 227, total_batch_size, params.num_gpus,
         benchmark_cnn.get_data_type(params))
     benchmark.dataset._queue_runner_required = True

     injected = benchmark.image_preprocessor
     injected.set_fake_data(images, labels)
     if params.eval:
       injected.expected_subset = 'validation'
     else:
       injected.expected_subset = 'train'

     benchmark.run()
     return captured