Esempio n. 1
0
def build_dataset(config, is_training=False, fake_data=False):
    """Returns a tf.data.Dataset with <image, label> pairs.

    No batching is applied by this function.

    Args:
      config: DataConfig instance.
      is_training: Whether to build a dataset for training. Enables random
          left-right flips and pad-then-random-crop augmentation, and is
          forwarded to `_download_cifar_split`.
      fake_data: If true, use randomly generated data.

    Returns:
      tf.data.Dataset

    Raises:
      ValueError: If `config.corruption_type` is set (non-static) and not
          exactly one of `config.corruption_value` (not None) and
          `config.corruption_level` (> 0) is specified.
    """
    if fake_data:
        logging.info('Generating fake data for config: %s', config)
        return image_data_utils.make_fake_data(CIFAR_SHAPE)

    logging.info('Building dataset for config:\n%s', attr.asdict(config))
    # TODO(yovadia): Split off a validation set from the training set.
    if config.corruption_type and config.corruption_static:
        # Static corruptions are produced entirely by make_static_dataset;
        # none of the preprocessing below (including prep_fn) is applied.
        return image_data_utils.make_static_dataset(config,
                                                    _get_static_cifar_c)

    if config.alt_dataset_name:
        all_images, all_labels = _download_alt_dataset(config)
    else:
        all_images, all_labels = _download_cifar_split(config.split,
                                                       is_training)

    if config.corruption_type:
        # Exactly one of corruption_value / corruption_level must be given.
        # Raise explicitly: an `assert` would be stripped under `python -O`.
        value_given = config.corruption_value is not None
        level_given = config.corruption_level > 0
        if value_given == level_given:
            raise ValueError(
                'Exactly one of corruption_value and corruption_level must be '
                'specified; got corruption_value=%r, corruption_level=%r.' %
                (config.corruption_value, config.corruption_level))
        # NOTE: dhtd corruptions expect to be applied before float32 conversion.
        apply_corruption = functools.partial(
            robustness_dhtd.corrupt,
            severity=config.corruption_level,
            severity_value=config.corruption_value,
            dim=32,
            corruption_name=config.corruption_type,
            dataset_name='cifar')
        # Corrupt every image eagerly (outside the tf.data graph).
        all_images = np.stack([apply_corruption(im) for im in all_images])

    dataset = tf.data.Dataset.from_tensor_slices((all_images, all_labels))

    def prep_fn(image, label):
        """Image preprocessing function."""
        if config.roll_pixels:
            # Cyclically shift pixels by `roll_pixels` along axis -2.
            image = tf.roll(image, config.roll_pixels, -2)
        if is_training:
            # Random left-right flip, then zero-pad the first two (presumably
            # spatial) axes by 4 on each side and randomly crop back to
            # CIFAR_SHAPE.
            image = tf.image.random_flip_left_right(image)
            image = tf.pad(image, [[4, 4], [4, 4], [0, 0]])
            image = tf.image.random_crop(image, CIFAR_SHAPE)

        image = tf.image.convert_image_dtype(image, tf.float32)
        return image, label

    return dataset.map(prep_fn)
    def input_fn(self):
        """Input function which provides a single batch for train or eval.

        Lists the TFRecord shards for `self.dataset_split`, reads them with an
        interleave, parses and batches records with `self.dataset_parser`
        (dropping any final partial batch) and prefetches. When
        `self.is_training` is set, file names are shuffled, the file list is
        repeated indefinitely, records are shuffled with a 1024-element buffer
        and non-deterministic element ordering is allowed.

        Returns:
          A `tf.data.Dataset` object.

        Raises:
          ValueError: If `self.dataset_split` is not 'train', 'valid' or 'test'.
        """
        if self.fake_data:
            return image_data_utils.make_fake_data(IMAGENET_SHAPE)

        train_path_tmpl = os.path.join(self.data_dir, 'train-{0:05d}*')
        if self.dataset_split == 'train':
            # Train shards excluding the first IMAGENET_VALID_SHARDS, which are
            # used for the 'valid' split.
            file_pattern = [
                train_path_tmpl.format(i) for i in range(
                    IMAGENET_VALID_SHARDS, IMAGENET_TRAIN_AND_VALID_SHARDS)
            ]
        elif self.dataset_split == 'valid':
            # The first IMAGENET_VALID_SHARDS train shards.
            file_pattern = [
                train_path_tmpl.format(i) for i in range(IMAGENET_VALID_SHARDS)
            ]
        elif self.dataset_split == 'test':
            # The 'validation-*' shards serve as the test split.
            file_pattern = os.path.join(self.data_dir, 'validation-*')
        else:
            raise ValueError(
                "Dataset_split must be 'train', 'valid', or 'test', was %s" %
                self.dataset_split)

        # Shuffle the filenames to ensure better randomization.
        dataset = tf.data.Dataset.list_files(file_pattern,
                                             shuffle=self.is_training)

        if self.is_training:
            dataset = dataset.repeat()

        def fetch_dataset(filename):
            buffer_size = 8 * 1024 * 1024  # 8 MiB per file
            return tf.data.TFRecordDataset(filename, buffer_size=buffer_size)

        # Read the data from disk in parallel. Without `num_parallel_calls`
        # interleave fetches sequentially. With deterministic ordering (the
        # default, used for eval) the output order is the same as the
        # sequential version. Training relaxes it via the options below.
        dataset = dataset.interleave(
            fetch_dataset,
            cycle_length=16,
            num_parallel_calls=tf.data.experimental.AUTOTUNE)

        if self.is_training:
            dataset = dataset.shuffle(1024)

        # Parse, pre-process, and batch the data in parallel (for speed, it's
        # necessary to apply batching here rather than using dataset.batch later)
        dataset = dataset.apply(
            tf.data.experimental.map_and_batch(self.dataset_parser,
                                               batch_size=self.batch_size,
                                               num_parallel_batches=2,
                                               drop_remainder=True))

        # Prefetch overlaps in-feed with training
        dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

        if self.is_training:
            # Use a private thread pool and limit intra-op parallelism. Enable
            # non-determinism only for training.
            options = tf.data.Options()
            options.experimental_threading.max_intra_op_parallelism = 1
            options.experimental_threading.private_threadpool_size = 16
            options.experimental_deterministic = False
            dataset = dataset.with_options(options)

        return dataset
Esempio n. 3
0
def build_dataset(config,
                  batch_size,
                  is_training=False,
                  fake_data=False,
                  use_bfloat16=False):
    """Returns a batched tf.data.Dataset with <image, label> pairs.

    Args:
      config: DataConfig instance.
      batch_size: Dataset batch size.
      is_training: Whether to build a dataset for training. Forwarded to
          `imagenet_input.ImageNetInput` and used as `shuffle_files` for the
          alternative dataset.
      fake_data: If True, use randomly generated data.
      use_bfloat16: If True, use bfloat16. If False, use float32. Forwarded to
          `ImageNetInput`. Images that go through the (non-static) corruption
          path are float32 regardless.

    Returns:
      tf.data.Dataset

    Raises:
      ValueError: If `config.corruption_type` is set (non-static) and not
          exactly one of `config.corruption_value` (not None) and
          `config.corruption_level` (> 0) is specified.
    """
    if fake_data:
        logging.info('Generating fake data for config: %s', config)
        return image_data_utils.make_fake_data(IMAGENET_SHAPE).batch(
            batch_size)

    if config.alt_dataset_name:
        dataset = _download_alt_dataset(config, shuffle_files=is_training)

        def prep_fn(image_input):
            # Crop a 178x178 window starting at row 20, column 0, then resize
            # to 224x224.
            image = tf.image.convert_image_dtype(image_input['image'],
                                                 tf.float32)
            image = tf.image.crop_to_bounding_box(image, 20, 0, 178, 178)
            image = tf.image.resize(image, (224, 224))

            # omit CelebA labels
            return image, -1

        return dataset.map(prep_fn).batch(batch_size)

    logging.info('Building dataset for config:\n%s', attr.asdict(config))
    if config.corruption_type and config.corruption_static:
        return image_data_utils.make_static_dataset(
            config, _get_static_imagenet_c).batch(batch_size)

    dataset_builder = imagenet_input.ImageNetInput(is_training=is_training,
                                                   data_dir=FLAGS.imagenet_dir,
                                                   batch_size=batch_size,
                                                   dataset_split=config.split,
                                                   use_bfloat16=use_bfloat16)

    dataset = dataset_builder.input_fn()

    if config.corruption_type:
        # Exactly one of corruption_value / corruption_level must be given.
        # Raise explicitly: an `assert` would be stripped under `python -O`.
        value_given = config.corruption_value is not None
        level_given = config.corruption_level > 0
        if value_given == level_given:
            raise ValueError(
                'Exactly one of corruption_value and corruption_level must be '
                'specified; got corruption_value=%r, corruption_level=%r.' %
                (config.corruption_value, config.corruption_level))

        corruption_fn = functools.partial(
            robustness_dhtd.corrupt,
            severity=config.corruption_level,
            severity_value=config.corruption_value,
            dim=224,
            corruption_name=config.corruption_type,
            dataset_name='imagenet')

        def apply_corruption(image, label):
            """Apply the corruption function to every image in the batch."""
            # NOTE: dhtd corruptions expect to be applied before float32
            # conversion, so they receive uint8 pixel arrays.
            image = tf.image.convert_image_dtype(image, tf.uint8)

            def apply_to_batch(ims):
                import numpy as np  # pylint: disable=g-import-not-at-top
                ims_numpy = ims.numpy()
                # Collect results in a fresh float32 buffer. Writing them back
                # into `ims_numpy` (uint8, and possibly sharing memory with the
                # input tensor) would truncate/overflow corrupted values
                # before the clip below, and would not match Tout=tf.float32.
                corrupted = np.empty(ims_numpy.shape, dtype=np.float32)
                for i, im in enumerate(ims_numpy):
                    corrupted[i] = corruption_fn(im)
                return corrupted

            corrupted = tf.py_function(func=apply_to_batch,
                                       inp=[image],
                                       Tout=tf.float32)
            # tf.py_function drops static shape information. The output buffer
            # always has the input batch's shape, so restore it.
            corrupted.set_shape(image.shape)
            corrupted = tf.clip_by_value(corrupted, 0., 255.) / 255.
            return corrupted, label

        dataset = dataset.map(apply_corruption)

    if config.roll_pixels:

        def roll_fn(image, label):
            """Function to roll pixels."""
            # Cyclically shift pixels by `roll_pixels` along axis -2.
            image = tf.roll(image, config.roll_pixels, -2)
            return image, label

        dataset = dataset.map(roll_fn)

    return dataset