Example #1
def _read_tfds(tfds_builder: tfds.core.DatasetBuilder,
               tfds_split: Text,
               tfds_skip_decoding_feature: Text,
               tfds_as_supervised: bool,
               input_context: Optional[tf.distribute.InputContext] = None,
               seed: Optional[Union[int, tf.Tensor]] = None,
               is_training: bool = False,
               cache: bool = False,
               cycle_length: Optional[int] = None,
               block_length: Optional[int] = None) -> tf.data.Dataset:
    """Reads a dataset from tfds."""
    # No-op if the dataset has already been downloaded and prepared.
    tfds_builder.download_and_prepare()

    read_config = tfds.ReadConfig(interleave_cycle_length=cycle_length,
                                  interleave_block_length=block_length,
                                  input_context=input_context,
                                  shuffle_seed=seed)
    decoders = {}
    if tfds_skip_decoding_feature:
        for skip_feature in tfds_skip_decoding_feature.split(','):
            decoders[skip_feature.strip()] = tfds.decode.SkipDecoding()
    dataset = tfds_builder.as_dataset(split=tfds_split,
                                      shuffle_files=is_training,
                                      as_supervised=tfds_as_supervised,
                                      decoders=decoders,
                                      read_config=read_config)

    if is_training and not cache:
        dataset = dataset.repeat()
    return dataset
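A minimal usage sketch (hedged: the 'mnist' builder and the argument values are illustrative, not taken from the example above; the module-level imports of the original file are assumed):

import tensorflow_datasets as tfds

# Build the source dataset once, then hand it to the reader helper, which
# calls download_and_prepare() itself.
builder = tfds.builder('mnist')
dataset = _read_tfds(tfds_builder=builder,
                     tfds_split='train',
                     tfds_skip_decoding_feature='',  # comma-separated features to skip decoding
                     tfds_as_supervised=False,
                     is_training=True)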
Example #2
def _read_tfds(tfds_builder: tfds.core.DatasetBuilder,
               tfds_split: Text,
               tfds_skip_decoding_feature: Text,
               tfds_as_supervised: bool,
               input_context: Optional[tf.distribute.InputContext] = None,
               seed: Optional[Union[int, tf.Tensor]] = None,
               is_training: bool = False,
               cache: bool = False,
               cycle_length: Optional[int] = None,
               block_length: Optional[int] = None) -> tf.data.Dataset:
    """Reads a dataset from tfds."""
    # No-op if the dataset has already been downloaded and prepared.
    tfds_builder.download_and_prepare()
    decoders = {}
    if tfds_skip_decoding_feature:
        for skip_feature in tfds_skip_decoding_feature.split(','):
            decoders[skip_feature.strip()] = tfds.decode.SkipDecoding()
    if tfds_builder.info.splits:
        num_shards = len(
            tfds_builder.info.splits[tfds_split].file_instructions)
    else:
        # The tfds mock path often does not provide splits.
        num_shards = 1
    if input_context and num_shards < input_context.num_input_pipelines:
        # The number of files in the dataset split is smaller than the number
        # of input pipelines. We read the entire dataset first and then shard
        # it in host memory.
        read_config = tfds.ReadConfig(interleave_cycle_length=cycle_length,
                                      interleave_block_length=block_length,
                                      input_context=None,
                                      shuffle_seed=seed)
        dataset = tfds_builder.as_dataset(split=tfds_split,
                                          shuffle_files=is_training,
                                          as_supervised=tfds_as_supervised,
                                          decoders=decoders,
                                          read_config=read_config)
        dataset = dataset.shard(input_context.num_input_pipelines,
                                input_context.input_pipeline_id)
    else:
        read_config = tfds.ReadConfig(interleave_cycle_length=cycle_length,
                                      interleave_block_length=block_length,
                                      input_context=input_context,
                                      shuffle_seed=seed)
        dataset = tfds_builder.as_dataset(split=tfds_split,
                                          shuffle_files=is_training,
                                          as_supervised=tfds_as_supervised,
                                          decoders=decoders,
                                          read_config=read_config)

    if is_training and not cache:
        dataset = dataset.repeat()
    return dataset
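A sketch of the sharding path (hedged: the builder name and pipeline counts are assumptions; in practice `tf.distribute.InputContext` is supplied by the distribution strategy rather than built by hand):

import tensorflow as tf
import tensorflow_datasets as tfds

# Pretend this host is pipeline 0 of 2; when the split has fewer files than
# pipelines, the helper reads everything and shards in host memory instead.
input_context = tf.distribute.InputContext(num_input_pipelines=2,
                                           input_pipeline_id=0)
dataset = _read_tfds(tfds_builder=tfds.builder('mnist'),
                     tfds_split='train',
                     tfds_skip_decoding_feature='',
                     tfds_as_supervised=True,
                     input_context=input_context,
                     is_training=True)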
Example #3
def create_split(dataset_builder: tfds.core.DatasetBuilder,
                 batch_size: int,
                 train: bool,
                 dtype: tf.DType = tf.float32,
                 image_size: int = IMAGE_SIZE,
                 cache: bool = False):
    """Creates a split from the ImageNet dataset using TensorFlow Datasets.

  Args:
    dataset_builder: TFDS dataset builder for ImageNet.
    batch_size: the batch size returned by the data pipeline.
    train: Whether to load the train or evaluation split.
    dtype: data type of the image (default: float32).
    image_size: The target size of the images (default: 224).
    cache: Whether to cache the dataset (default: False).
  Returns:
    A `tf.data.Dataset`.
  """
    if train:
        train_size = dataset_builder.info.splits['train'].num_examples
        split_size = train_size // jax.host_count()
        start = jax.host_id() * split_size
        split = 'train[{}:{}]'.format(start, start + split_size)
    else:
        validation_size = dataset_builder.info.splits[
            'validation'].num_examples
        split_size = validation_size // jax.host_count()
        start = jax.host_id() * split_size
        split = 'validation[{}:{}]'.format(start, start + split_size)

    def _decode_example(example):
        if train:
            image = preprocess_for_train(example['image'], dtype, image_size)
        else:
            image = preprocess_for_eval(example['image'], dtype, image_size)
        return {'image': image, 'label': example['label']}

    ds = dataset_builder.as_dataset(
        split=split, decoders={'image': tfds.decode.SkipDecoding()})
    # Apply threading options via with_options(); mutating the object returned
    # by ds.options() does not update the dataset.
    options = tf.data.Options()
    options.experimental_threading.private_threadpool_size = 48
    options.experimental_threading.max_intra_op_parallelism = 1
    ds = ds.with_options(options)

    if cache:
        ds = ds.cache()

    if train:
        ds = ds.repeat()
        ds = ds.shuffle(16 * batch_size, seed=0)

    ds = ds.map(_decode_example,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)

    ds = ds.batch(batch_size, drop_remainder=True)

    if not train:
        ds = ds.repeat()

    ds = ds.prefetch(10)

    return ds
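A usage sketch (hedged: assumes IMAGE_SIZE, preprocess_for_train and preprocess_for_eval are defined in the surrounding module, and that the ImageNet data has already been downloaded manually as TFDS requires):

import tensorflow_datasets as tfds

builder = tfds.builder('imagenet2012')
train_ds = create_split(builder, batch_size=128, train=True)
eval_ds = create_split(builder, batch_size=128, train=False)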
Example #4
def get_raw_dataset(dataset_builder: tfds.core.DatasetBuilder,
                    split: str,
                    *,
                    reverse_translation: bool = False) -> tf.data.Dataset:
    """Loads a raw WMT dataset and normalizes feature keys.

  Args:
    dataset_builder: TFDS dataset builder that can build `split`.
    split: Split to use. This must be the full split. We shard the split across
      multiple hosts and currently don't support sharding subsplits.
    reverse_translation: Whether to reverse the translation direction;
      e.g. for 'de-en' this translates from English to German.

  Returns:
    Dataset with source and target language features mapped to 'inputs' and
    'targets'.
  """
    num_examples = dataset_builder.info.splits[split].num_examples
    per_host_split = deterministic_data.get_read_instruction_for_host(
        split, num_examples, drop_remainder=False)
    ds = dataset_builder.as_dataset(split=per_host_split, shuffle_files=False)
    ds = ds.map(NormalizeFeatureNamesOp(
        dataset_builder.info, reverse_translation=reverse_translation),
                num_parallel_calls=AUTOTUNE)
    return ds
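A usage sketch (hedged: 'wmt14_translate/de-en' is one possible builder name; NormalizeFeatureNamesOp, deterministic_data and AUTOTUNE come from the surrounding module):

import tensorflow_datasets as tfds

builder = tfds.builder('wmt14_translate/de-en')
builder.download_and_prepare()
# Reversing the direction yields an en->de dataset for a 'de-en' config.
train_ds = get_raw_dataset(builder, 'train', reverse_translation=True)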
Example #5
    def _build_autocached_info(self, builder: tfds.core.DatasetBuilder):
        """Returns the auto-cache information string."""
        always_cached = {}
        never_cached = {}
        unshuffle_cached = {}
        for split_name in sorted(builder.info.splits.keys()):
            split_name = str(split_name)
            cache_shuffled = builder._should_cache_ds(  # pylint: disable=protected-access
                split_name,
                shuffle_files=True,
                read_config=tfds.ReadConfig())
            cache_unshuffled = builder._should_cache_ds(  # pylint: disable=protected-access
                split_name,
                shuffle_files=False,
                read_config=tfds.ReadConfig())

            if all((cache_shuffled, cache_unshuffled)):
                always_cached[split_name] = None
            elif not any((cache_shuffled, cache_unshuffled)):
                never_cached[split_name] = None
            else:  # Dataset is only cached when shuffle_files is False.
                assert not cache_shuffled and cache_unshuffled
                unshuffle_cached[split_name] = None

        if not len(builder.info.splits) or not builder.info.dataset_size:  # pylint: disable=g-explicit-length-test
            autocached_info = 'Unknown'
        elif len(always_cached) == len(builder.info.splits.keys()):
            autocached_info = 'Yes'  # All splits are auto-cached.
        elif len(never_cached) == len(builder.info.splits.keys()):
            autocached_info = 'No'  # Splits never auto-cached.
        else:  # Some splits cached, some not.
            autocached_info_parts = []
            if always_cached:
                split_names_str = ', '.join(always_cached)
                autocached_info_parts.append(
                    'Yes ({})'.format(split_names_str))
            if never_cached:
                split_names_str = ', '.join(never_cached)
                autocached_info_parts.append('No ({})'.format(split_names_str))
            if unshuffle_cached:
                split_names_str = ', '.join(unshuffle_cached)
                autocached_info_parts.append(
                    'Only when `shuffle_files=False` ({})'.format(
                        split_names_str))
            autocached_info = ', '.join(autocached_info_parts)
        return autocached_info
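A standalone sketch of the same per-split check (hedged: it relies on the protected DatasetBuilder._should_cache_ds helper that the method above already uses, and on a prepared 'mnist' builder):

import tensorflow_datasets as tfds

builder = tfds.builder('mnist')
builder.download_and_prepare()
for split_name in sorted(builder.info.splits.keys()):
    cached = builder._should_cache_ds(  # pylint: disable=protected-access
        split_name, shuffle_files=False, read_config=tfds.ReadConfig())
    print(f'{split_name}: auto-cached when shuffle_files=False -> {cached}')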
Example #6
  def __init__(self, dataset_builder: tfds.core.DatasetBuilder):
    """A NitroML dataset from a TFDS DatasetBuilder.

    Args:
      dataset_builder: A `tfds.DatasetBuilder` instance which defines the
        TFDS dataset to use. Example: `dataset =
          TFDSTask(tfds.builder('titanic'))`
    """

    # TODO(b/159086401): Download and prepare the dataset in a component
    # instead of at construction time, so that this step happens lazily during
    # pipeline execution.
    logging.info('Preparing dataset...')
    dataset_builder.download_and_prepare()
    logging.info(dataset_builder.info)

    self._dataset_builder = dataset_builder
    self._example_gen = self._make_example_gen()
Example #7
def _download_and_prepare(
    args: argparse.Namespace,
    builder: tfds.core.DatasetBuilder,
) -> None:
  """Generate a single builder."""
  logging.info(f'download_and_prepare for dataset {builder.info.full_name}...')

  dl_config = _make_download_config(args)
  if args.add_name_to_manual_dir:
    dl_config.manual_dir = os.path.join(dl_config.manual_dir, builder.name)

  builder.download_and_prepare(
      download_dir=args.download_dir,
      download_config=dl_config,
  )

  # Dataset generated successfully
  logging.info('Dataset generation complete...')
  termcolor.cprint(str(builder.info.as_proto), attrs=['bold'])
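An illustrative call (hedged: the Namespace fields shown are only the ones this function reads directly; _make_download_config comes from the surrounding CLI module and may consume additional flags from the real parser):

import argparse
import tensorflow_datasets as tfds

args = argparse.Namespace(download_dir='/tmp/tfds_downloads',
                          add_name_to_manual_dir=False)
_download_and_prepare(args, tfds.builder('mnist'))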
Example #8
def get_raw_dataset(dataset_builder: tfds.core.DatasetBuilder,
                    split: str) -> tf.data.Dataset:
    """Loads a raw text dataset and normalizes feature keys.

  Args:
    dataset_builder: TFDS dataset builder that can build `split`.
    split: Split to use. This must be the full split. We shard the split across
      multiple hosts and currently don't support sharding subsplits.

  Returns:
    Dataset with source and target language features mapped to 'inputs' and
    'targets'.
  """
    per_host_split = deterministic_data.get_read_instruction_for_host(
        split, dataset_info=dataset_builder.info, drop_remainder=False)
    ds = dataset_builder.as_dataset(split=per_host_split, shuffle_files=False)
    ds = ds.map(NormalizeFeatureNamesOp(dataset_builder.info),
                num_parallel_calls=AUTOTUNE)
    return ds
Example #9
def create_split(
    dataset_builder: tfds.core.DatasetBuilder,
    batch_size: int,
    train: bool = True,
    half_precision: bool = False,
    image_size: int = IMAGE_SIZE,
    mean: Optional[Tuple[float]] = None,
    std: Optional[Tuple[float]] = None,
    interpolation: str = 'bicubic',
    augment_name: Optional[str] = None,
    randaug_num_layers: Optional[int] = None,
    randaug_magnitude: Optional[int] = None,
    cache: bool = False,
    no_repeat: bool = False,
):
    """Creates a split from the ImageNet dataset using TensorFlow Datasets.

    Args:
      dataset_builder: TFDS dataset builder for ImageNet.
      batch_size: the batch size returned by the data pipeline.
      train: Whether to load the train or evaluation split.
      half_precision: Whether to convert the image dtype to half precision
        (bfloat16 on TPU, float16 otherwise).
      image_size: The target size of the images (default: 224).
      mean: Per-channel mean used for normalization (defaults to MEAN_RGB).
      std: Per-channel standard deviation used for normalization (defaults to
        STDDEV_RGB).
      interpolation: Interpolation method used for image resizing (default:
        'bicubic').
      augment_name: Optional name of the augmentation scheme applied during
        training.
      randaug_num_layers: Optional number of RandAugment layers to apply.
      randaug_magnitude: Optional RandAugment magnitude.
      cache: Whether to cache the dataset (default: False).
      no_repeat: Whether to disable repeating the evaluation split.
    Returns:
      A `tf.data.Dataset`.
    """
    mean = mean or MEAN_RGB
    std = std or STDDEV_RGB
    interpolation = (tf.image.ResizeMethod.BICUBIC if interpolation == 'bicubic'
                     else tf.image.ResizeMethod.BILINEAR)
    platform = jax.local_devices()[0].platform
    if half_precision:
        if platform == 'tpu':
            input_dtype = tf.bfloat16
        else:
            input_dtype = tf.float16
    else:
        input_dtype = tf.float32

    if train:
        data_size = dataset_builder.info.splits['train'].num_examples
        split = 'train'
    else:
        data_size = dataset_builder.info.splits['validation'].num_examples
        split = 'validation'
    split_size = data_size // jax.host_count()
    start = jax.host_id() * split_size
    split = split + '[{}:{}]'.format(start, start + split_size)

    def _decode_example(example):
        if train:
            image = preprocess_for_train(example['image'],
                                         input_dtype,
                                         image_size,
                                         mean,
                                         std,
                                         interpolation,
                                         augment_name=augment_name,
                                         randaug_num_layers=randaug_num_layers,
                                         randaug_magnitude=randaug_magnitude)
        else:
            image = preprocess_for_eval(example['image'], input_dtype,
                                        image_size, mean, std, interpolation)
        return {'image': image, 'label': example['label']}

    ds = dataset_builder.as_dataset(
        split=split, decoders={'image': tfds.decode.SkipDecoding()})
    # Apply threading options via with_options(); mutating the object returned
    # by ds.options() does not update the dataset.
    options = tf.data.Options()
    options.experimental_threading.private_threadpool_size = 16
    options.experimental_threading.max_intra_op_parallelism = 1
    ds = ds.with_options(options)

    if cache:
        ds = ds.cache()

    if train:
        ds = ds.repeat()
        ds = ds.shuffle(16 * batch_size, seed=0)

    ds = ds.map(_decode_example,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)
    ds = ds.batch(batch_size, drop_remainder=True)

    if not train and not no_repeat:
        ds = ds.repeat()

    ds = ds.prefetch(10)

    return ds
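A usage sketch (hedged: assumes MEAN_RGB, STDDEV_RGB, IMAGE_SIZE and the preprocess_* helpers from the surrounding module; the RandAugment settings are example values and 'randaugment' is an assumed augment_name):

import tensorflow_datasets as tfds

builder = tfds.builder('imagenet2012')
train_ds = create_split(builder,
                        batch_size=128,
                        train=True,
                        half_precision=True,
                        augment_name='randaugment',
                        randaug_num_layers=2,
                        randaug_magnitude=9)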