def get_test_image_preprocessor(batch_size, params):
  """Returns the preprocessing.TestImagePreprocessor that should be injected.

  Returns None if no preprocessor should be injected (--fake_input=none).

  With --fake_input=zeros_and_ones the fake batch holds `batch_size` float32
  images of shape 227x227x3. The first `batch_size // 2` images are all zeros
  and carry label 0; every remaining image is all ones and carries label 1.
  For an odd `batch_size` the ones-half is therefore one image larger than the
  zeros-half, and the labels array always has exactly `batch_size` entries.

  Args:
    batch_size: The batch size across all GPUs.
    params: BenchmarkCNN's parameters.
  Returns:
    The preprocessing.TestImagePreprocessor that should be injected, or None.
  Raises:
    ValueError: Flag --fake_input is an invalid value.
  """
  if FLAGS.fake_input == 'none':
    return None
  elif FLAGS.fake_input == 'zeros_and_ones':
    half_batch_size = batch_size // 2
    # Images from index half_batch_size onward are set to ones, so the number
    # of 1-labels must be the size of that slice rather than half_batch_size;
    # otherwise an odd batch_size yields one label fewer than images.
    num_ones = batch_size - half_batch_size
    images = np.zeros((batch_size, 227, 227, 3), dtype=np.float32)
    images[half_batch_size:, :, :, :] = 1
    labels = np.array([0] * half_batch_size + [1] * num_ones, dtype=np.int32)
    preprocessor = preprocessing.TestImagePreprocessor(
        227, 227, batch_size, params.num_gpus,
        benchmark_cnn.get_data_type(params))
    preprocessor.set_fake_data(images, labels)
    # The subset the preprocessor expects depends on whether this is an eval
    # run or a training run.
    preprocessor.expected_subset = 'validation' if params.eval else 'train'
    return preprocessor
  else:
    raise ValueError('Invalid --fake_input: %s' % FLAGS.fake_input)
def _run_benchmark_cnn_with_fake_images(self, params, images, labels):
  """Runs BenchmarkCNN on the given fake images and returns the log list.

  A TestImagePreprocessor sized for `params.batch_size * params.num_gpus`
  images is installed on the benchmark and loaded with the fake data before
  the benchmark is run.

  Note: `benchmark_cnn.log_fn` is replaced module-wide and is not restored
  when this method returns.

  Args:
    params: BenchmarkCNN's parameters.
    images: Fake images handed to the preprocessor via set_fake_data.
    labels: Fake labels handed to the preprocessor via set_fake_data.
  Returns:
    The list that was passed to `_print_and_add_to_list` when the log
    function was installed.
  """
  captured_logs = []
  # Install the log function before constructing BenchmarkCNN, as the original
  # ordering does.
  benchmark_cnn.log_fn = _print_and_add_to_list(captured_logs)
  bench = benchmark_cnn.BenchmarkCNN(params)

  total_batch_size = params.batch_size * params.num_gpus
  data_type = benchmark_cnn.get_data_type(params)
  fake_preprocessor = preprocessing.TestImagePreprocessor(
      227, 227, total_batch_size, params.num_gpus, data_type)
  bench.image_preprocessor = fake_preprocessor
  bench.dataset._queue_runner_required = True

  fake_preprocessor.set_fake_data(images, labels)
  if params.eval:
    fake_preprocessor.expected_subset = 'validation'
  else:
    fake_preprocessor.expected_subset = 'train'

  bench.run()
  return captured_logs