def pad_and_batch_with_rng(
    it, num_devices,
    padding_and_batch_sizes,
    base_rng):
  """Pad and batch according to a collection of sizes.

  Args:
    it: Iterable over individual examples. Each element must be a dataclass
      instance (it is rebuilt with `dataclasses.replace`) exposing `example`
      and `example_id` attributes.
    num_devices: Number of devices; determines constant leading batch dimension.
    padding_and_batch_sizes: Sequence of pairs of padding config and per-device
      batch size. Padding configs will be tried in order until the example fits
      in one. It is copied into a tuple up front, so any iterable (including a
      one-shot generator) is accepted.
    base_rng: PRNGKey to use to generate RNG seeds.

  Yields:
    Batched tuples of padded examples and RNG keys. Each batch will contain
    examples of approximately the same shape, and the `static_metadata` field
    for each will be the padding config used. RNG keys are deterministically
    based on `base_rng` and the order of examples in `it` (i.e. the nth
    example from `it` will get a specific RNG value, regardless of padding and
    batch sizes). Examples that fit no padding config are logged and dropped,
    but they still count toward that numbering. The final partial batch of
    each bucket is zero-padded (`BatchRemainderBehavior.PAD_ZERO`).
  """
  # The configs are read in three places: the per-example search, the batch
  # size table, and the metadata lookup. The batch size table is built before
  # the lazy example generator below ever runs, so a one-shot iterator would
  # otherwise be exhausted and every example would be dropped.
  padding_and_batch_sizes = tuple(padding_and_batch_sizes)

  def _find_buckets_and_pad():
    """Yields `(bucket_index, ex)` with `ex.example` = `(padded, rng)`."""
    for example_number, ex in enumerate(it):
      # The first config that accepts the example wins; its index is the bucket.
      for bucket, (padding_config, _) in enumerate(padding_and_batch_sizes):
        padded_example = example_definition.pad_example(
            ex.example, padding_config, allow_failure=True)
        # NOTE(review): assumes pad_example reports failure by returning None
        # when allow_failure=True -- confirm. A plain truthiness test would
        # also reject a valid result that happens to be falsy, and raises if
        # the result is a multi-element array.
        if padded_example is not None:
          break
      else:
        # No config fit (or no configs were given at all).
        logging.info('Dropping example %d (exceeded padding config)',
                     ex.example_id)
        continue

      # Keyed on the position in `it` (dropped examples included), so the RNG
      # for a given example does not depend on the padding/batch settings.
      example_rng = jax.random.fold_in(base_rng, example_number)
      yield (bucket,
             dataclasses.replace(ex, example=(padded_example, example_rng)))

  # Batch within each bucket; bucket i uses shape (num_devices, batch size i).
  batched = data_loading.batch_bucketed(
      _find_buckets_and_pad(),
      batch_dim_sizes={
          i: (num_devices, device_batch_size)
          for i, (_, device_batch_size) in enumerate(padding_and_batch_sizes)
      },
      remainder_behavior=data_loading.BatchRemainderBehavior.PAD_ZERO)

  # Move the bucket's padding config into the batch metadata.
  for bucket, ex in batched:
    yield dataclasses.replace(
        ex, static_metadata=padding_and_batch_sizes[bucket][0])
# Example #2
0
 def test_batch_bucketed(self):
     """Checks per-key batching and zero-padding in `batch_bucketed`.

     Keys "a", "b", "c" receive values round-robin with batch sizes 2, 3 and 5.
     The expected output lists full batches in the order they fill up. After
     them come the leftover partial batches ("b" and "c"), padded with zeros
     to full size.
     """
     values = [("a", (1, )), ("b", (2, )), ("c", (3, )), ("a", (4, )),
               ("b", (5, )), ("c", (6, )), ("a", (7, )), ("b", (8, )),
               ("c", (9, )), ("a", (10, )), ("b", (11, )), ("c", (12, ))]
     batch_sizes = {"a": (2, ), "b": (3, ), "c": (5, )}
     batched = list(
         data_loading.batch_bucketed(
             values,
             batch_sizes,
             remainder_behavior=data_loading.BatchRemainderBehavior.PAD_ZERO))
     expected = [
         ("a", (np.array([1, 4]), )),
         ("b", (np.array([2, 5, 8]), )),
         ("a", (np.array([7, 10]), )),
         ("b", (np.array([11, 0, 0]), )),
         ("c", (np.array([3, 6, 9, 12, 0]), )),
     ]
     self.assertEqual(len(batched), len(expected))
     for i, ((bk, bv), (ek, ev)) in enumerate(zip(batched, expected)):
         self.assertEqual(bk, ek, msg="bucket key mismatch at batch %d" % i)
         self.assertEqual(len(bv), len(ev))
         # Compare leaf by leaf with NumPy: jax.test_util.check_eq is no
         # longer part of JAX's public test_util API.
         for batch_leaf, expected_leaf in zip(bv, ev):
             np.testing.assert_array_equal(
                 batch_leaf, expected_leaf, err_msg="batch %d" % i)