def pad_and_batch_with_rng(it, num_devices, padding_and_batch_sizes, base_rng):
  """Buckets examples by padding config, batches each bucket, and adds RNG keys.

  Args:
    it: Iterable over individual examples. Each item's `example` attribute is
      what gets padded, and its `example_id` attribute is used only when
      logging a dropped example.
    num_devices: Number of devices; used as the constant leading batch
      dimension for every bucket.
    padding_and_batch_sizes: Sequence of `(padding_config, per_device_batch)`
      pairs. Padding configs are attempted in order and the example is placed
      in the first bucket whose config it fits.
    base_rng: PRNGKey from which each example's key is derived.

  Yields:
    Batches produced by `data_loading.batch_bucketed` (zero-padded when a
    bucket has a partial batch left over), one bucket at a time, so each batch
    holds examples of approximately the same shape. The `static_metadata` of
    each yielded batch is replaced by the padding config of its bucket.

    Every example's payload becomes `(padded_example, rng_key)` with
    `rng_key = jax.random.fold_in(base_rng, n)`, where `n` is the example's
    position in `it`. Examples that fit no padding config are logged and
    skipped but still consume a position. The key of the n-th example
    therefore does not depend on the padding or batch sizes.
  """

  def _assign_buckets():
    # Yields (bucket_index, example-with-(padded, key)) for each example that
    # fits some padding config; examples that fit none are logged and skipped.
    for position, item in enumerate(it):
      for bucket_index, (config, _) in enumerate(padding_and_batch_sizes):
        padded = example_definition.pad_example(
            item.example, config, allow_failure=True)
        if padded:
          break
      else:
        # Loop ran to completion: no padding config accepted this example.
        # NOTE(review): '%d' assumes example_id is an integer — confirm.
        logging.info('Dropping example %d (exceeded padding config)',
                     item.example_id)
        continue
      key = jax.random.fold_in(base_rng, position)
      yield bucket_index, dataclasses.replace(item, example=(padded, key))

  # Every bucket uses a (num_devices, per-device batch size) leading shape.
  dim_sizes_per_bucket = {
      index: (num_devices, per_device)
      for index, (_, per_device) in enumerate(padding_and_batch_sizes)
  }
  bucketed_batches = data_loading.batch_bucketed(
      _assign_buckets(),
      batch_dim_sizes=dim_sizes_per_bucket,
      remainder_behavior=data_loading.BatchRemainderBehavior.PAD_ZERO)

  # Record which padding config produced each batch.
  for bucket_index, batch in bucketed_batches:
    yield dataclasses.replace(
        batch, static_metadata=padding_and_batch_sizes[bucket_index][0])
def test_batch_bucketed(self):
  """Checks bucketed batching, including zero-padding of leftover buckets.

  Keys cycle through "a", "b" and "c" with batch sizes 2, 3 and 5. Full
  batches are expected in the order they fill up. The partially filled "b"
  and "c" buckets are expected afterwards, zero-padded to their batch size.
  """
  values = [("a", (1,)), ("b", (2,)), ("c", (3,)),
            ("a", (4,)), ("b", (5,)), ("c", (6,)),
            ("a", (7,)), ("b", (8,)), ("c", (9,)),
            ("a", (10,)), ("b", (11,)), ("c", (12,))]
  batched = list(
      data_loading.batch_bucketed(
          values, {
              "a": (2,),
              "b": (3,),
              "c": (5,)
          },
          remainder_behavior=data_loading.BatchRemainderBehavior.PAD_ZERO))
  expected = [
      ("a", (np.array([1, 4]),)),
      ("b", (np.array([2, 5, 8]),)),
      ("a", (np.array([7, 10]),)),
      ("b", (np.array([11, 0, 0]),)),
      ("c", (np.array([3, 6, 9, 12, 0]),)),
  ]
  self.assertEqual(len(batched), len(expected))
  for (bk, bv), (ek, ev) in zip(batched, expected):
    self.assertEqual(bk, ek)
    # Compare leaf-by-leaf via the public tree API (the old
    # `jax.test_util.check_eq` helper is not public JAX API). tree_map also
    # fails if the batch's structure differs from the expected tuple.
    jax.tree_util.tree_map(np.testing.assert_array_equal, bv, ev)