def build_dataset(config, is_training=False, fake_data=False):
    """Builds a tf.data.Dataset yielding <image, label> pairs.

    Args:
      config: DataConfig instance.
      is_training: Whether to build a dataset for training
          (with shuffling and image distortions).
      fake_data: If true, use randomly generated data.

    Returns:
      tf.data.Dataset
    """
    if fake_data:
        logging.info('Generating fake data for config: %s', config)
        return image_data_utils.make_fake_data(CIFAR_SHAPE)

    logging.info('Building dataset for config:\n%s', attr.asdict(config))
    # TODO(yovadia): Split off a validation set from the training set.
    if config.corruption_type and config.corruption_static:
        return image_data_utils.make_static_dataset(config,
                                                    _get_static_cifar_c)

    if config.alt_dataset_name:
        images, labels = _download_alt_dataset(config)
    else:
        images, labels = _download_cifar_split(config.split, is_training)

    if config.corruption_type:
        # Exactly one of corruption_value / corruption_level may be set.
        assert (config.corruption_value
                is not None) != (config.corruption_level > 0)
        # NOTE: dhtd corruptions expect to be applied before float32 conversion.
        corrupt_fn = functools.partial(
            robustness_dhtd.corrupt,
            severity=config.corruption_level,
            severity_value=config.corruption_value,
            dim=32,
            corruption_name=config.corruption_type,
            dataset_name='cifar')
        images = np.stack([corrupt_fn(im) for im in images])

    def _preprocess(image, label):
        """Rolls pixels, applies train-time augmentation, converts to float."""
        if config.roll_pixels:
            image = tf.roll(image, config.roll_pixels, -2)
        if is_training:
            image = tf.image.random_flip_left_right(image)
            image = tf.pad(image, [[4, 4], [4, 4], [0, 0]])
            image = tf.image.random_crop(image, CIFAR_SHAPE)
        image = tf.image.convert_image_dtype(image, tf.float32)
        return image, label

    return tf.data.Dataset.from_tensor_slices(
        (images, labels)).map(_preprocess)
# Example #2: ImageNet variant of build_dataset.
def build_dataset(config,
                  batch_size,
                  is_training=False,
                  fake_data=False,
                  use_bfloat16=False):
    """Returns a tf.data.Dataset with <image, label> pairs.

  Args:
    config: DataConfig instance.
    batch_size: Dataset batch size.
    is_training: Whether to build a dataset for training
        (with shuffling and image distortions).
    fake_data: If True, use randomly generated data.
    use_bfloat16: If True, use bfloat16. If False, use float32.
  Returns:
    tf.data.Dataset
  """
    # Fast path: randomly generated images of IMAGENET_SHAPE, batched here
    # since none of the real-data batching below applies.
    if fake_data:
        logging.info('Generating fake data for config: %s', config)
        return image_data_utils.make_fake_data(IMAGENET_SHAPE).batch(
            batch_size)

    # Alternative dataset path: examples come as dicts with an 'image' key.
    if config.alt_dataset_name:
        dataset = _download_alt_dataset(config, shuffle_files=is_training)

        def prep_fn(image_input):
            # Convert to float32 in [0, 1], crop, and resize to 224x224.
            # NOTE(review): the (20, 0, 178, 178) bounding box looks like the
            # standard CelebA face crop — confirm against the alt dataset used.
            image = tf.image.convert_image_dtype(image_input['image'],
                                                 tf.float32)
            image = tf.image.crop_to_bounding_box(image, 20, 0, 178, 178)
            image = tf.image.resize(image, (224, 224))

            # omit CelebA labels
            return image, -1

        return dataset.map(prep_fn).batch(batch_size)

    logging.info('Building dataset for config:\n%s', attr.asdict(config))
    # Pre-generated (static) corrupted images: load them instead of applying
    # corruptions on the fly.
    if config.corruption_type and config.corruption_static:
        return image_data_utils.make_static_dataset(
            config, _get_static_imagenet_c).batch(batch_size)

    # Standard ImageNet pipeline; batching is handled by the builder itself.
    dataset_builder = imagenet_input.ImageNetInput(is_training=is_training,
                                                   data_dir=FLAGS.imagenet_dir,
                                                   batch_size=batch_size,
                                                   dataset_split=config.split,
                                                   use_bfloat16=use_bfloat16)

    dataset = dataset_builder.input_fn()

    if config.corruption_type:
        # Exactly one of corruption_value / corruption_level may be set
        # (XOR: a non-None value excludes a positive level, and vice versa).
        assert (config.corruption_value
                is not None) != (config.corruption_level > 0)

        # NOTE: dhtd corruptions expect to be applied before float32 conversion.
        def apply_corruption(image, label):
            """Apply the corruption function to the image."""
            # Back to uint8 pixel values, as the corruption fns expect.
            image = tf.image.convert_image_dtype(image, tf.uint8)
            corruption_fn = functools.partial(
                robustness_dhtd.corrupt,
                severity=config.corruption_level,
                severity_value=config.corruption_value,
                dim=224,
                corruption_name=config.corruption_type,
                dataset_name='imagenet')

            def apply_to_batch(ims):
                # Runs eagerly via tf.py_function: corrupt each image of the
                # batch in place on the numpy copy.
                ims_numpy = ims.numpy()
                for i in range(ims_numpy.shape[0]):
                    ims_numpy[i] = corruption_fn(ims_numpy[i])
                return ims_numpy

            # NOTE(review): ims_numpy is uint8 here while Tout is float32 —
            # presumably tf casts the returned array; confirm the corruption
            # fns return values in [0, 255].
            image = tf.py_function(func=apply_to_batch,
                                   inp=[image],
                                   Tout=tf.float32)
            # Rescale back to [0, 1] floats after corruption.
            image = tf.clip_by_value(image, 0., 255.) / 255.
            return image, label

        dataset = dataset.map(apply_corruption)

    if config.roll_pixels:

        def roll_fn(image, label):
            """Function to roll pixels."""
            # Shift pixels along the width axis (-2) by config.roll_pixels.
            image = tf.roll(image, config.roll_pixels, -2)
            return image, label

        dataset = dataset.map(roll_fn)

    return dataset