Example #1
def _get_tfds_dataset(
        dataset: str,
        rng: np.ndarray) -> Tuple[tf.data.Dataset, tf.data.Dataset, int]:
    """Loads a TFDS dataset."""

    dataset_builder = tfds.builder(dataset)
    num_classes = 0
    if "label" in dataset_builder.info.features:
        num_classes = dataset_builder.info.features["label"].num_classes

    # Make sure each host uses a different RNG for the training data.
    rng, data_rng = jax.random.split(rng)
    data_rng = jax.random.fold_in(data_rng, jax.host_id())
    data_rng, shuffle_rng = jax.random.split(data_rng)
    train_split = deterministic_data.get_read_instruction_for_host(
        "train", dataset_builder.info.splits["train"].num_examples)
    train_read_config = tfds.ReadConfig(shuffle_seed=shuffle_rng[0])
    train_ds = dataset_builder.as_dataset(split=train_split,
                                          shuffle_files=True,
                                          read_config=train_read_config)

    eval_split_name = {
        "cifar10": "test",
        "imagenet2012": "validation"
    }.get(dataset, "test")

    eval_split_size = dataset_builder.info.splits[eval_split_name].num_examples
    eval_split = deterministic_data.get_read_instruction_for_host(
        eval_split_name, eval_split_size)
    eval_read_config = tfds.ReadConfig(shuffle_seed=shuffle_rng[1])
    eval_ds = dataset_builder.as_dataset(split=eval_split,
                                         shuffle_files=False,
                                         read_config=eval_read_config)
    return train_ds, eval_ds, num_classes
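A minimal usage sketch (hedged: the dataset name, batching, and printed shape are illustrative assumptions, and the TFDS data must already be available locally):

rng = jax.random.PRNGKey(0)
train_ds, eval_ds, num_classes = _get_tfds_dataset("cifar10", rng)

# Each host reads its own disjoint slice of the train split, shuffled with a
# host-specific seed derived from `rng`.
first_batch = next(iter(train_ds.batch(8).as_numpy_iterator()))
print(first_batch["image"].shape, num_classes)  # e.g. (8, 32, 32, 3) 10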
Example #2
def create_train_dataset(
    task,
    batch_size,
    substeps,
    data_rng):
  """Create datasets for training."""
  # Compute batch size per device from global batch size.
  if batch_size % jax.device_count() != 0:
    raise ValueError(f"Batch size ({batch_size}) must be divisible by "
                     f"the number of devices ({jax.device_count()}).")
  per_device_batch_size = batch_size // jax.device_count()

  dataset_builder = tfds.builder(task)

  train_split = deterministic_data.get_read_instruction_for_host(
      "train", dataset_builder.info.splits["train"].num_examples)
  batch_dims = [jax.local_device_count(), substeps, per_device_batch_size]
  train_ds = deterministic_data.create_dataset(
      dataset_builder,
      split=train_split,
      num_epochs=None,
      shuffle=True,
      batch_dims=batch_dims,
      preprocess_fn=_preprocess_cifar10,
      rng=data_rng)

  return dataset_builder.info, train_ds
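The `batch_dims` list means every element of the returned dataset carries three leading batch axes: one per local device, one per substep, and one for the per-device batch. A hedged consumption sketch (hypothetical numbers; assumes a single host and that `_preprocess_cifar10` keeps the 'image' feature):

info, train_ds = create_train_dataset(
    task="cifar10", batch_size=256, substeps=4, data_rng=jax.random.PRNGKey(0))
batch = next(train_ds.as_numpy_iterator())
# Roughly (jax.local_device_count(), 4, 256 // jax.device_count(), 32, 32, 3):
# ready for a pmapped training step that loops over the substeps axis.
print(batch["image"].shape)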
Example #3
def _create_eval_dataset(config, dataset_builder, split):
    """Create evaluation dataset (validation or test sets)."""
    # This ensures the correct number of elements in the validation sets.
    num_validation_examples = (dataset_builder.info.splits[split].num_examples)
    eval_split = deterministic_data.get_read_instruction_for_host(
        split, dataset_info=dataset_builder.info, drop_remainder=False)

    eval_num_batches = None
    if config.eval_pad_last_batch:
        # This is doing some extra work to get exactly all examples in the
        # validation split. Without this the dataset would first be split between
        # the different hosts and then into batches (both times dropping the
        # remainder). If you don't mind dropping a few extra examples you can omit
        # the `pad_up_to_batches` argument.
        eval_batch_size = (
            jax.local_device_count() * config.per_device_batch_size)
        eval_num_batches = int(
            np.ceil(num_validation_examples / eval_batch_size /
                    jax.host_count()))
    return deterministic_data.create_dataset(
        dataset_builder,
        split=eval_split,
        # Only cache dataset in distributed setup to avoid consuming a lot of
        # memory in Colab and unit tests.
        cache=jax.host_count() > 1,
        batch_dims=[jax.local_device_count(), config.per_device_batch_size],
        num_epochs=1,
        shuffle=False,
        preprocess_fn=_preprocess_spherical_mnist,
        pad_up_to_batches=eval_num_batches,
    )
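When `pad_up_to_batches` is set, `deterministic_data.create_dataset` pads the final batches so every host sees the same number of batches and, as far as I recall the CLU behavior, adds a boolean 'mask' feature that is True for real examples and False for padding. A hedged evaluation sketch using that mask (the split name, `eval_step`, `params`, and the feature names produced by the preprocess_fn are hypothetical):

eval_ds = _create_eval_dataset(config, dataset_builder, "test")
num_correct = num_total = 0
for batch in eval_ds.as_numpy_iterator():
    logits = eval_step(params, batch)  # hypothetical (p)mapped eval function
    correct = (logits.argmax(-1) == batch["label"]) & batch["mask"]
    num_correct += correct.sum()
    num_total += batch["mask"].sum()
accuracy = num_correct / num_total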
Example #4
def get_raw_dataset(dataset_builder: tfds.core.DatasetBuilder,
                    split: str,
                    *,
                    reverse_translation: bool = False) -> tf.data.Dataset:
    """Loads a raw WMT dataset and normalizes feature keys.

  Args:
    dataset_builder: TFDS dataset builder that can build `split`.
    split: Split to use. This must be the full split. We shard the split across
      multiple hosts and currently don't support sharding subsplits.
    reverse_translation: Whether to reverse the translation direction, e.g. for
      'de-en' this translates from English to German.

  Returns:
    Dataset with source and target language features mapped to 'inputs' and
    'targets'.
  """
    num_examples = dataset_builder.info.splits[split].num_examples
    per_host_split = deterministic_data.get_read_instruction_for_host(
        split, num_examples, drop_remainder=False)
    ds = dataset_builder.as_dataset(split=per_host_split, shuffle_files=False)
    ds = ds.map(NormalizeFeatureNamesOp(
        dataset_builder.info, reverse_translation=reverse_translation),
                num_parallel_calls=AUTOTUNE)
    return ds
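`NormalizeFeatureNamesOp` itself is not shown here; below is a hypothetical sketch of what such an op might look like for a WMT translation builder (an assumption based on the docstring, not the original implementation):

class NormalizeFeatureNamesOp:
    """Maps the source/target language features to 'inputs' and 'targets'."""

    def __init__(self, ds_info: tfds.core.DatasetInfo,
                 reverse_translation: bool = False):
        self.ds_info = ds_info
        self.reverse_translation = reverse_translation

    def __call__(self, features):
        # For e.g. 'wmt14_translate/de-en' the supervised keys are assumed to
        # be the two language codes, e.g. ('de', 'en').
        src, tgt = self.ds_info.supervised_keys
        if self.reverse_translation:
            src, tgt = tgt, src
        return {'inputs': features[src], 'targets': features[tgt]}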
Example #5
def create_train_dataset(
    config,
    dataset_builder,
    split,
    data_rng,
    preprocess_fn = None,
):
  """Create train dataset."""
  # This ensures determinism in distributed setting.
  train_split = deterministic_data.get_read_instruction_for_host(
      split, dataset_info=dataset_builder.info)
  train_dataset = deterministic_data.create_dataset(
      dataset_builder,
      split=train_split,
      rng=data_rng,
      preprocess_fn=preprocess_fn,
      shuffle_buffer_size=config.shuffle_buffer_size,
      batch_dims=[jax.local_device_count(), config.per_device_batch_size],
      num_epochs=config.num_epochs,
      shuffle=True,
  )
  options = tf.data.Options()
  options.experimental_external_state_policy = (
      tf.data.experimental.ExternalStatePolicy.WARN)
  train_dataset = train_dataset.with_options(options)

  return train_dataset
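Example #6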
def get_split(rng,
              builder,
              split,
              batch_size,
              num_epochs=None,
              shuffle_buffer_size=None,
              repeat_after=False,
              cache=False):
    """Loads a audio dataset and shifts audio values to be positive.

  Args:
    rng: JAX PRNGKey random number generator state.
    builder: TFDS dataset builder instance.
    split: TFDS split to load.
    batch_size: Global batch size.
    num_epochs: Number of epochs. None to repeat forever.
    shuffle_buffer_size: Size of the shuffle buffer. If None, data is not
      shuffled.
    repeat_after: If True, the dataset is repeated infinitely *after* the CLU
      `create_dataset` pipeline (i.e. after batching and padding).
    cache: If True, the dataset is cached prior to pre-processing.

  Returns:
    Audio datasets with `inputs` and `label` features. The former is shifted to
    be non-negative.
  """
    host_count = jax.process_count()
    if batch_size % host_count != 0:
        raise ValueError(
            f'Batch size ({batch_size}) must be divisible by the host'
            f' count ({host_count}).')
    batch_size = batch_size // host_count
    device_count = jax.local_device_count()
    if batch_size % device_count != 0:
        raise ValueError(
            f'Local batch size ({batch_size}) must be divisible by the'
            f' local device count ({device_count}).')
    batch_dims = [device_count, batch_size // device_count]

    host_split = data.get_read_instruction_for_host(
        split,
        dataset_info=builder.info,
        remainder_options=data.RemainderOptions.BALANCE_ON_PROCESSES)
    ds = data.create_dataset(builder,
                             split=host_split,
                             preprocess_fn=PrepareAudio(),
                             cache=cache,
                             batch_dims=batch_dims,
                             rng=rng,
                             num_epochs=num_epochs,
                             pad_up_to_batches='auto',
                             shuffle=bool(shuffle_buffer_size),
                             shuffle_buffer_size=shuffle_buffer_size or 0)
    if repeat_after:
        ds = ds.repeat()
    return ds
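`PrepareAudio` is also not shown; here is a hypothetical sketch consistent with the docstring (it assumes 16-bit audio stored under an 'audio' feature, e.g. as in speech_commands, and is not the original implementation):

class PrepareAudio:
    """Maps 'audio' to 'inputs' and shifts int16 samples to be non-negative."""

    def __call__(self, features):
        audio = tf.cast(features['audio'], tf.int32) + 2**15
        return {'inputs': audio, 'label': features['label']}

Example #7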
  def test_get_read_instruction_for_host(self, host_id: int, host_count: int,
                                         drop_remainder: bool, spec: str,
                                         expected_spec_for_host: str):

    actual_spec_for_host = deterministic_data.get_read_instruction_for_host(
        spec,
        dataset_info=FakeDatasetInfo(),
        host_id=host_id,
        host_count=host_count,
        drop_remainder=drop_remainder)
    expected_spec_for_host = tfds.core.ReadInstruction.from_spec(
        expected_spec_for_host)
    self.assertEqual(str(actual_spec_for_host), str(expected_spec_for_host))
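Example #8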
 def test_get_read_instruction_for_host(self, num_examples: int,
                                        host_id: int, host_count: int,
                                        drop_remainder: bool,
                                        expected_spec: str):
     expected = tfds.core.ReadInstruction.from_spec(expected_spec)
     actual = deterministic_data.get_read_instruction_for_host(
         "test",
         num_examples,
         host_id=host_id,
         host_count=host_count,
         drop_remainder=drop_remainder)
     name2len = {"test": 9}
     self.assertEqual(expected.to_absolute(name2len),
                      actual.to_absolute(name2len))
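Example #9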
 def test_same_cardinality_on_all_hosts(self, num_examples: int,
                                        host_count: int):
   builder = MyDatasetBuilder({"train": num_examples})
   cardinalities = []
   for host_id in range(host_count):
     split = deterministic_data.get_read_instruction_for_host(
         split="train",
         num_examples=num_examples,
         host_id=host_id,
         host_count=host_count,
         drop_remainder=True)
     ds = deterministic_data.create_dataset(
         builder, split=split, batch_dims=[2], shuffle=False, num_epochs=1)
     cardinalities.append(ds.cardinality().numpy().item())
   self.assertLen(set(cardinalities), 1)
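Example #10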
 def test_get_read_instruction_balance_remainder(self, host_id: int,
                                                 host_count: int,
                                                 balance_remainder: bool,
                                                 spec: str,
                                                 expected_spec_for_host: str):
   actual_spec_for_host = deterministic_data.get_read_instruction_for_host(
       spec,
       dataset_info=FakeDatasetInfo(test_size=10),
       host_id=host_id,
       host_count=host_count,
       remainder_options=deterministic_data.RemainderOptions
       .BALANCE_ON_PROCESSES if balance_remainder else
       deterministic_data.RemainderOptions.ON_FIRST_PROCESS)
   expected_spec_for_host = tfds.core.ReadInstruction.from_spec(
       expected_spec_for_host)
   self.assertEqual(str(actual_spec_for_host), str(expected_spec_for_host))
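For intuition, a small sketch of the sharding these tests exercise (hedged: the exact spec strings depend on the clu version):

for host_id in range(3):
    print(deterministic_data.get_read_instruction_for_host(
        'test', num_examples=10, host_id=host_id, host_count=3,
        drop_remainder=True))
# With drop_remainder=True each of the 3 hosts reads 10 // 3 = 3 examples,
# roughly 'test[0:3]', 'test[3:6]', 'test[6:9]', and the leftover example is
# dropped. With drop_remainder=False (or RemainderOptions.BALANCE_ON_PROCESSES)
# the remainder is instead assigned to one host or spread across hosts.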
Example #11
def get_raw_dataset(dataset_builder: tfds.core.DatasetBuilder,
                    split: str) -> tf.data.Dataset:
    """Loads a raw text dataset and normalizes feature keys.

  Args:
    dataset_builder: TFDS dataset builder that can build `split`.
    split: Split to use. This must be the full split. We shard the split across
      multiple hosts and currently don't support sharding subsplits.

  Returns:
    Dataset with source and target language features mapped to 'inputs' and
    'targets'.
  """
    per_host_split = deterministic_data.get_read_instruction_for_host(
        split, dataset_info=dataset_builder.info, drop_remainder=False)
    ds = dataset_builder.as_dataset(split=per_host_split, shuffle_files=False)
    ds = ds.map(NormalizeFeatureNamesOp(dataset_builder.info),
                num_parallel_calls=AUTOTUNE)
    return ds
Example #12
def create_eval_dataset(task, batch_size, subset):
    """Create datasets for evaluation."""
    if batch_size % jax.device_count() != 0:
        raise ValueError(f"Batch size ({batch_size}) must be divisible by "
                         f"the number of devices ({jax.device_count()}).")
    per_device_batch_size = batch_size // jax.device_count()

    dataset_builder = tfds.builder(task)
    dataset_builder.download_and_prepare()

    eval_split = deterministic_data.get_read_instruction_for_host(
        subset, dataset_builder.info.splits[subset].num_examples)
    eval_ds = deterministic_data.create_dataset(
        dataset_builder,
        split=eval_split,
        num_epochs=1,
        shuffle=False,
        batch_dims=[jax.local_device_count(), per_device_batch_size],
        preprocess_fn=_preprocess_cifar10)

    return dataset_builder.info, eval_ds
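Example #13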
 def test_get_read_instruction_for_host_deprecated(self, num_examples: int,
                                                   host_id: int,
                                                   host_count: int,
                                                   drop_remainder: bool,
                                                   expected_spec: str):
   expected = tfds.core.ReadInstruction.from_spec(expected_spec)
   actual = deterministic_data.get_read_instruction_for_host(
       "test",
       num_examples,
       host_id=host_id,
       host_count=host_count,
       drop_remainder=drop_remainder)
   if _use_split_info:
     split_infos = {
         "test": tfds.core.SplitInfo(
             name="test",
             shard_lengths=[9],
             num_bytes=0,
         )}
   else:
     split_infos = {"test": 9}
   self.assertEqual(
       expected.to_absolute(split_infos), actual.to_absolute(split_infos))
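Example #14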
 def test_same_cardinality_on_all_hosts_with_pad(self, num_examples: int,
                                                 host_count: int):
   builder = MyDatasetBuilder({"train": num_examples})
   # All hosts should have the same number of batches.
   batch_size = 2
   pad_up_to_batches = int(math.ceil(num_examples / (batch_size * host_count)))
   assert pad_up_to_batches * batch_size * host_count >= num_examples
   cardinalities = []
   for host_id in range(host_count):
     split = deterministic_data.get_read_instruction_for_host(
         split="train",
         num_examples=num_examples,
         host_id=host_id,
         host_count=host_count,
         drop_remainder=False)
     ds = deterministic_data.create_dataset(
         builder,
         split=split,
         batch_dims=[batch_size],
         shuffle=False,
         num_epochs=1,
         pad_up_to_batches=pad_up_to_batches)
     cardinalities.append(ds.cardinality().numpy().item())
   self.assertLen(set(cardinalities), 1)
Example #15
def create_datasets(config, data_rng):
    """Create datasets for training and evaluation."""
    # Compute batch size per device from global batch size.
    if config.batch_size % jax.device_count() != 0:
        raise ValueError(
            f'Batch size ({config.batch_size}) must be divisible by '
            f'the number of devices ({jax.device_count()}).')
    per_device_batch_size = config.batch_size // jax.device_count()

    dataset_builder = tfds.builder(config.dataset)

    def cast_int32(batch):
        img = tf.cast(batch['image'], tf.int32)
        out = batch.copy()
        out['image'] = img
        return out

    def drop_info(batch):
        """Removes unwanted keys from batch."""
        if 'id' in batch:
            batch.pop('id')
        if 'rng' in batch:
            batch.pop('rng')
        return batch

    if config.data_augmentation:
        should_augment = True
        should_randflip = True
        should_rotate = True
    else:
        should_augment = False
        should_randflip = False
        should_rotate = False

    def augment(batch):
        img = tf.cast(batch['image'], tf.float32)
        aug = None
        if should_augment:
            if should_randflip:
                img_flipped = tf.image.flip_left_right(img)
                aug = tf.random.uniform(shape=[]) > 0.5
                img = tf.where(aug, img_flipped, img)
            if should_rotate:
                u = tf.random.uniform(shape=[])
                k = tf.cast(tf.floor(4. * u), tf.int32)
                img = tf.image.rot90(img, k=k)
                # Guard against `aug` still being None when random flips are
                # disabled but rotations are enabled.
                if aug is None:
                    aug = tf.convert_to_tensor(False, dtype=tf.bool)
                aug = aug | (k > 0)
        if aug is None:
            aug = tf.convert_to_tensor(False, dtype=tf.bool)

        out = batch.copy()
        out['image'] = img
        return out

    def preprocess_train(batch):
        return cast_int32(augment(drop_info(batch)))

    def preprocess_eval(batch):
        return cast_int32(drop_info(batch))

    # Read instructions to shard the dataset!
    print('train', dataset_builder.info.splits['train'].num_examples)
    # TODO(emielh) use dataset_info instead of num_examples.
    train_split = deterministic_data.get_read_instruction_for_host(
        'train',
        num_examples=dataset_builder.info.splits['train'].num_examples)
    train_ds = deterministic_data.create_dataset(
        dataset_builder,
        split=train_split,
        num_epochs=1,
        shuffle=True,
        batch_dims=[jax.local_device_count(), per_device_batch_size],
        preprocess_fn=preprocess_train,
        rng=data_rng,
        prefetch_size=tf.data.AUTOTUNE,
        drop_remainder=True)

    # TODO(emielh) check if this is necessary?

    # Test batches are _not_ sharded across hosts. In the worst case this simply
    # leads to some duplicated information. In our case, since the ELBO estimate
    # is stochastic, we effectively get multiple passes over the test data.
    if config.test_batch_size % jax.local_device_count() != 0:
        raise ValueError(
            f'Test batch size ({config.test_batch_size}) must be divisible by '
            f'the number of local devices ({jax.local_device_count()}).')
    test_device_batch_size = config.test_batch_size // jax.local_device_count()

    eval_ds = deterministic_data.create_dataset(
        dataset_builder,
        split='test',
        # Repeated epochs for lower variance ELBO estimate.
        num_epochs=config.num_eval_passes,
        shuffle=False,
        batch_dims=[jax.local_device_count(), test_device_batch_size],
        preprocess_fn=preprocess_eval,
        # TODO(emielh) Fix this with batch padding instead of dropping.
        prefetch_size=tf.data.AUTOTUNE,
        drop_remainder=False)

    return dataset_builder.info, train_ds, eval_ds
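Example #16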
 def test_get_read_instruction_for_host_fails(self, host_id: int,
                                              host_count: int):
   with self.assertRaises(ValueError):
     deterministic_data.get_read_instruction_for_host(
         "test", 11, host_id=host_id, host_count=host_count)