def cifar10(mode, params):
    """Builds the CIFAR10 input_fn result for the given estimator mode.

    Args:
        mode: A tf.estimator mode enum value.
        params: Dict with 'hparams', 'use_tpu', 'data_dir', and batching
            parameters consumed by the underlying input_fn.

    Returns:
        Whatever `TfdsInput.input_fn` produces for `params` (a Dataset in
        train/eval, a serving receiver in inference).
    """
    # Images are naturally 32x32.
    hparams = params['hparams']
    model_mode = utils.estimator_mode_to_model_mode(mode)
    is_training = model_mode == enums.ModelMode.TRAIN

    # Per-channel normalization statistics, shaped (1, 1, 3) to broadcast
    # over HWC images.
    channel_mean = np.array([[[-0.0172, -0.0356, -0.107]]])
    channel_std = np.array([[[0.4046, 0.3988, 0.402]]])

    view_preprocessor = preprocessing.ImageToMultiViewedImagePreprocessor(
        is_training=is_training,
        preprocessing_options=hparams.input_data.preprocessing,
        dataset_options=preprocessing.DatasetOptions(
            decode_input=False,
            image_mean_std=(channel_mean, channel_std)),
        bfloat16_supported=params['use_tpu'])

    # CIFAR10 is small enough to cache in memory and shuffle fully
    # (50000 = full training-set size).
    dataset = TfdsInput(
        dataset_name='cifar10:3.*.*',
        split='train' if is_training else 'test',
        mode=model_mode,
        preprocessor=view_preprocessor,
        shard_per_host=hparams.input_data.shard_per_host,
        cache=is_training,
        shuffle_buffer=50000,
        num_parallel_calls=64,
        max_samples=hparams.input_data.max_samples,
        label_noise_prob=hparams.input_data.label_noise_prob,
        num_classes=get_num_classes(hparams),
        data_dir=params['data_dir'],
    )
    return dataset.input_fn(params)
def imagenet(mode, params):
    """An input_fn for ImageNet (ILSVRC 2012) data.

    Args:
        mode: A tf.estimator mode enum value.
        params: Dict with 'hparams', 'use_tpu', 'data_dir', and batching
            parameters consumed by the underlying input_fn.

    Returns:
        Whatever `TfdsInput.input_fn` produces for `params` (a Dataset in
        train/eval, a serving receiver in inference).
    """
    hparams = params['hparams']
    model_mode = utils.estimator_mode_to_model_mode(mode)
    is_training = model_mode == enums.ModelMode.TRAIN

    view_preprocessor = preprocessing.ImageToMultiViewedImagePreprocessor(
        is_training=is_training,
        preprocessing_options=hparams.input_data.preprocessing,
        # No image_mean_std here, unlike cifar10; inputs are not pre-decoded.
        dataset_options=preprocessing.DatasetOptions(decode_input=False),
        bfloat16_supported=params['use_tpu'])

    dataset = TfdsInput(
        dataset_name='imagenet2012:5.*.*',
        split='train' if is_training else 'validation',
        mode=model_mode,
        preprocessor=view_preprocessor,
        shuffle_buffer=1024,
        shard_per_host=hparams.input_data.shard_per_host,
        cache=is_training,
        num_parallel_calls=64,
        max_samples=hparams.input_data.max_samples,
        label_noise_prob=hparams.input_data.label_noise_prob,
        num_classes=get_num_classes(hparams),
        data_dir=params['data_dir'],
    )
    return dataset.input_fn(params)
def test_preprocess_image(self, decode_image, image_dtype, image_size,
                          is_training, augmentation_type, warp_prob,
                          augmentation_magnitude, eval_crop_method,
                          image_mean_std):
    """Checks preprocess_image output dtype/shape across parameter combos."""
    # tf.random.uniform() doesn't allow generating random values that are
    # uint8, so sample floats in [0, 255) and cast.
    random_floats = tf.random.uniform(
        shape=(300, 400, 3), minval=0, maxval=255, dtype=tf.float32)
    image = tf.cast(random_floats, dtype=tf.uint8)
    if decode_image:
        image = tf.image.encode_jpeg(image)
    else:
        image = tf.cast(image, image_dtype)

    # Undecoded inputs must be uint8; anything else should trip an assertion.
    expect_error = not decode_image and image_dtype != tf.uint8
    context_manager = (
        self.assertRaises(AssertionError) if expect_error else nullcontext())

    with context_manager:
        output = preprocessing.preprocess_image(
            image,
            is_training=is_training,
            bfloat16_supported=False,
            preprocessing_options=hparams.ImagePreprocessing(
                image_size=image_size,
                augmentation_type=augmentation_type,
                warp_probability=warp_prob,
                augmentation_magnitude=augmentation_magnitude,
                eval_crop_method=eval_crop_method,
            ),
            dataset_options=preprocessing.DatasetOptions(
                image_mean_std=image_mean_std, decode_input=decode_image))
        self.assertEqual(output.dtype, tf.float32)
        self.assertEqual([image_size, image_size, 3], output.shape.as_list())
def test_input_class(self, input_class, model_mode, image_size, max_samples):
    """End-to-end check of an input class over mocked TFDS data.

    Verifies output structure (dataset vs. serving receiver), batch/image
    shapes per mode, and that EVAL exhausts after the expected batch count.
    """
    split = 'train' if model_mode == enums.ModelMode.TRAIN else 'test'
    batch_size = 2
    dataset_size = 10
    expected_num_batches = dataset_size // batch_size
    # max_samples truncates the data only in TRAIN mode.
    if max_samples is not None and model_mode == enums.ModelMode.TRAIN:
        expected_num_batches = max_samples // batch_size
    params = {'batch_size': batch_size}
    if input_class == 'TfdsInput':
        # mock_data serves synthetic examples, so no real cifar10 download.
        with tfds.testing.mock_data(num_examples=dataset_size):
            data = inputs.TfdsInput(
                'cifar10',
                split,
                mode=model_mode,
                preprocessor=preprocessing.ImageToMultiViewedImagePreprocessor(
                    is_training=model_mode == enums.ModelMode.TRAIN,
                    preprocessing_options=hparams.ImagePreprocessing(
                        image_size=image_size, num_views=2),
                    dataset_options=preprocessing.DatasetOptions(
                        decode_input=False)),
                max_samples=max_samples,
                num_classes=10).input_fn(params)
    else:
        raise ValueError(f'Unknown input class {input_class}')
    # 2 views of 3 channels are stacked along the channel axis outside
    # inference; inference keeps a single 3-channel view.
    expected_num_channels = 3 if model_mode == enums.ModelMode.INFERENCE else 6
    # Inference receivers have an unspecified (None) batch dimension.
    expected_batch_size = (
        None if model_mode == enums.ModelMode.INFERENCE else batch_size)
    if model_mode == enums.ModelMode.INFERENCE:
        # Inference produces a serving-input receiver, not a Dataset.
        self.assertIsInstance(
            data, tf.estimator.export.TensorServingInputReceiver)
        image_shape = data.features.shape.as_list()
    else:
        self.assertIsInstance(data, tf.data.Dataset)
        shapes = tf.data.get_output_shapes(data)
        image_shape = shapes[0].as_list()
        label_shape = shapes[1].as_list()
        self.assertEqual([expected_batch_size], label_shape)
    self.assertEqual([
        expected_batch_size, image_size, image_size, expected_num_channels
    ], image_shape)
    if model_mode == enums.ModelMode.INFERENCE:
        return
    # Now extract the Tensors
    data = tf.data.make_one_shot_iterator(data).get_next()[0]
    with self.cached_session() as sess:
        # Run one batch past the expected count: EVAL must raise
        # OutOfRangeError exactly there; TRAIN (presumably repeating data)
        # must keep yielding batches without error.
        for i in range(expected_num_batches + 1):
            if i == expected_num_batches and model_mode == enums.ModelMode.EVAL:
                with self.assertRaises(tf.errors.OutOfRangeError):
                    sess.run(data)
                break
            else:
                sess.run(data)
def test_image_to_multi_viewed_image_preprocessor(self, is_training,
                                                  decode_image, num_views):
    """Checks the multi-view preprocessor's output dtype and channel count."""
    # tf.random.uniform() doesn't allow generating random values that are
    # uint8, so sample floats in [0, 255) and cast.
    random_floats = tf.random.uniform(
        shape=(300, 400, 3), minval=0, maxval=255, dtype=tf.float32)
    image = tf.cast(random_floats, dtype=tf.uint8)
    if decode_image:
        image = tf.image.encode_jpeg(image)

    image_size = 32
    preprocessor = preprocessing.ImageToMultiViewedImagePreprocessor(
        is_training=is_training,
        preprocessing_options=hparams.ImagePreprocessing(
            image_size=image_size, num_views=num_views),
        dataset_options=preprocessing.DatasetOptions(
            decode_input=decode_image))
    output = preprocessor.preprocess(image)

    self.assertEqual(output.dtype, tf.float32)
    # Views are concatenated along the channel axis: 3 channels per view.
    self.assertEqual([image_size, image_size, 3 * num_views],
                     output.shape.as_list())