def build_dataset(config, is_training=False, fake_data=False):
    """Builds a tf.data.Dataset of <image, label> pairs.

    Args:
      config: DataConfig instance.
      is_training: Whether to build a dataset for training. Forwarded to the
        CIFAR split loader and enables the random flip / pad / crop
        distortions applied to each image.
      fake_data: If true, use randomly generated data.

    Returns:
      tf.data.Dataset
    """
    if fake_data:
        logging.info('Generating fake data for config: %s', config)
        return image_data_utils.make_fake_data(CIFAR_SHAPE)

    logging.info('Building dataset for config:\n%s', attr.asdict(config))
    # TODO(yovadia): Split off a validation set from the training set.
    if config.corruption_type and config.corruption_static:
        return image_data_utils.make_static_dataset(
            config, _get_static_cifar_c)

    images, labels = (
        _download_alt_dataset(config)
        if config.alt_dataset_name
        else _download_cifar_split(config.split, is_training)
    )

    if config.corruption_type:
        # Exactly one of corruption_value / corruption_level must be given.
        assert (config.corruption_value is not None) != (
            config.corruption_level > 0)
        # NOTE: dhtd corruptions expect to be applied before float32
        # conversion.
        corrupt_one = functools.partial(
            robustness_dhtd.corrupt,
            severity=config.corruption_level,
            severity_value=config.corruption_value,
            dim=32,
            corruption_name=config.corruption_type,
            dataset_name='cifar')
        images = np.stack([corrupt_one(raw) for raw in images])

    pad_spec = [[4, 4], [4, 4], [0, 0]]

    def _preprocess(image, label):
        """Image preprocessing function."""
        if config.roll_pixels:
            image = tf.roll(image, config.roll_pixels, -2)
        if is_training:
            flipped = tf.image.random_flip_left_right(image)
            padded = tf.pad(flipped, pad_spec)
            image = tf.image.random_crop(padded, CIFAR_SHAPE)
        return tf.image.convert_image_dtype(image, tf.float32), label

    examples = tf.data.Dataset.from_tensor_slices((images, labels))
    return examples.map(_preprocess)
def input_fn(self):
    """Input function which provides batches for train or eval.

    Reads `self.fake_data`, `self.data_dir`, `self.dataset_split`,
    `self.is_training`, `self.batch_size` and `self.dataset_parser`.

    Split to file mapping:
      'train': 'train-*' shards IMAGENET_VALID_SHARDS up to (excluding)
        IMAGENET_TRAIN_AND_VALID_SHARDS.
      'valid': the first IMAGENET_VALID_SHARDS 'train-*' shards.
      'test': all 'validation-*' files.

    Returns:
      A `tf.data.Dataset` object. Unless fake data is requested, records are
      parsed with `self.dataset_parser` and grouped into batches of
      `self.batch_size`, dropping any incomplete final batch.

    Raises:
      ValueError: If `self.dataset_split` is not 'train', 'valid' or 'test'.
    """
    if self.fake_data:
        return image_data_utils.make_fake_data(IMAGENET_SHAPE)

    train_path_tmpl = os.path.join(self.data_dir, 'train-{0:05d}*')
    if self.dataset_split == 'train':
        file_pattern = [
            train_path_tmpl.format(i)
            for i in range(IMAGENET_VALID_SHARDS,
                           IMAGENET_TRAIN_AND_VALID_SHARDS)
        ]
    elif self.dataset_split == 'valid':
        file_pattern = [
            train_path_tmpl.format(i) for i in range(IMAGENET_VALID_SHARDS)
        ]
    elif self.dataset_split == 'test':
        file_pattern = os.path.join(self.data_dir, 'validation-*')
    else:
        raise ValueError(
            "Dataset_split must be 'train', 'valid', or 'test', was %s" %
            self.dataset_split)

    # Shuffle the filenames to ensure better randomization.
    dataset = tf.data.Dataset.list_files(file_pattern,
                                         shuffle=self.is_training)

    if self.is_training:
        dataset = dataset.repeat()

    def fetch_dataset(filename):
        buffer_size = 8 * 1024 * 1024  # 8 MiB per file
        return tf.data.TFRecordDataset(filename, buffer_size=buffer_size)

    # Read the data from disk in parallel. Without `num_parallel_calls` the
    # interleave pulls from its input files sequentially; with it, element
    # order is still preserved unless determinism is disabled via the options
    # set below (training only).
    dataset = dataset.interleave(
        fetch_dataset,
        cycle_length=16,
        num_parallel_calls=tf.data.experimental.AUTOTUNE)

    if self.is_training:
        dataset = dataset.shuffle(1024)

    # Parse, pre-process, and batch the data in parallel (for speed, it's
    # necessary to apply batching here rather than using dataset.batch later)
    dataset = dataset.apply(
        tf.data.experimental.map_and_batch(self.dataset_parser,
                                           batch_size=self.batch_size,
                                           num_parallel_batches=2,
                                           drop_remainder=True))

    # Prefetch overlaps in-feed with training
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

    if self.is_training:
        # Use a private thread pool and limit intra-op parallelism. Enable
        # non-determinism only for training.
        options = tf.data.Options()
        options.experimental_threading.max_intra_op_parallelism = 1
        options.experimental_threading.private_threadpool_size = 16
        options.experimental_deterministic = False
        dataset = dataset.with_options(options)

    return dataset
def build_dataset(config, batch_size, is_training=False, fake_data=False,
                  use_bfloat16=False):
    """Returns a tf.data.Dataset with <image, label> pairs.

    Args:
      config: DataConfig instance.
      batch_size: Dataset batch size.
      is_training: Whether to build a dataset for training (with shuffling and
        image distortions).
      fake_data: If True, use randomly generated data.
      use_bfloat16: If True, use bfloat16. If False, use float32.

    Returns:
      tf.data.Dataset

    Raises:
      AssertionError: If a non-static corruption is requested and
        `config.corruption_value` / `config.corruption_level` are not
        specified exclusively (exactly one of "value is not None" and
        "level > 0" must hold).
    """
    if fake_data:
        logging.info('Generating fake data for config: %s', config)
        return image_data_utils.make_fake_data(IMAGENET_SHAPE).batch(
            batch_size)

    if config.alt_dataset_name:
        dataset = _download_alt_dataset(config, shuffle_files=is_training)

        def prep_fn(image_input):
            image = tf.image.convert_image_dtype(image_input['image'],
                                                 tf.float32)
            image = tf.image.crop_to_bounding_box(image, 20, 0, 178, 178)
            image = tf.image.resize(image, (224, 224))
            # omit CelebA labels
            return image, -1

        return dataset.map(prep_fn).batch(batch_size)

    logging.info('Building dataset for config:\n%s', attr.asdict(config))
    if config.corruption_type and config.corruption_static:
        return image_data_utils.make_static_dataset(
            config, _get_static_imagenet_c).batch(batch_size)

    # `batch_size` is handed to the input pipeline builder, so no explicit
    # `.batch()` call is made on this path.
    dataset_builder = imagenet_input.ImageNetInput(
        is_training=is_training,
        data_dir=FLAGS.imagenet_dir,
        batch_size=batch_size,
        dataset_split=config.split,
        use_bfloat16=use_bfloat16)
    dataset = dataset_builder.input_fn()

    if config.corruption_type:
        assert (config.corruption_value is not None) != (
            config.corruption_level > 0)

        # Loop-invariant: build the configured corruption callable once.
        corruption_fn = functools.partial(
            robustness_dhtd.corrupt,
            severity=config.corruption_level,
            severity_value=config.corruption_value,
            dim=224,
            corruption_name=config.corruption_type,
            dataset_name='imagenet')

        def apply_to_batch(ims):
            """Corrupts every image along the leading axis of `ims`."""
            ims_numpy = ims.numpy()
            # NOTE(review): results are written back into the uint8 batch
            # array, so each corrupted image is cast to uint8 on assignment.
            # Confirm this is intended given the float32 output and the clip
            # applied afterwards.
            for i in range(ims_numpy.shape[0]):
                ims_numpy[i] = corruption_fn(ims_numpy[i])
            return ims_numpy

        # NOTE: dhtd corruptions expect to be applied before float32
        # conversion.
        def apply_corruption(image, label):
            """Apply the corruption function to the image."""
            image = tf.image.convert_image_dtype(image, tf.uint8)
            corrupted = tf.py_function(
                func=apply_to_batch, inp=[image], Tout=tf.float32)
            # tf.py_function drops static shape information. The Python
            # function returns an array with the same shape as its input, so
            # restore the shape for downstream consumers.
            corrupted.set_shape(image.shape)
            corrupted = tf.clip_by_value(corrupted, 0., 255.) / 255.
            return corrupted, label

        dataset = dataset.map(apply_corruption)

    if config.roll_pixels:

        def roll_fn(image, label):
            """Function to roll pixels."""
            image = tf.roll(image, config.roll_pixels, -2)
            return image, label

        dataset = dataset.map(roll_fn)

    return dataset