def _get_tfds_dataset(
    dataset: str,
    rng: np.ndarray) -> Tuple[tf.data.Dataset, tf.data.Dataset, int]:
  """Loads a TFDS dataset.

  Builds the per-host train and eval `tf.data.Dataset`s for `dataset`. The
  train files are shuffled with a seed derived from `rng` folded with this
  host's id, so each host uses a different RNG for its training data.

  Args:
    dataset: TFDS dataset name passed to `tfds.builder`.
    rng: JAX PRNG key from which the file-shuffling seeds are derived.

  Returns:
    A tuple `(train_ds, eval_ds, num_classes)`. `num_classes` is taken from the
    "label" feature and is 0 when the dataset has no such feature.
  """
  builder = tfds.builder(dataset)
  features = builder.info.features
  num_classes = features["label"].num_classes if "label" in features else 0

  # Make sure each host uses a different RNG for the training data.
  _, host_rng = jax.random.split(rng)
  host_rng = jax.random.fold_in(host_rng, jax.host_id())
  _, seed_key = jax.random.split(host_rng)

  splits = builder.info.splits

  train_instruction = deterministic_data.get_read_instruction_for_host(
      "train", splits["train"].num_examples)
  train_ds = builder.as_dataset(
      split=train_instruction,
      shuffle_files=True,
      read_config=tfds.ReadConfig(shuffle_seed=seed_key[0]))

  # Datasets without an explicit mapping fall back to the "test" split.
  eval_name = {
      "cifar10": "test",
      "imagenet2012": "validation",
  }.get(dataset, "test")
  eval_instruction = deterministic_data.get_read_instruction_for_host(
      eval_name, splits[eval_name].num_examples)
  eval_ds = builder.as_dataset(
      split=eval_instruction,
      shuffle_files=False,
      read_config=tfds.ReadConfig(shuffle_seed=seed_key[1]))

  return train_ds, eval_ds, num_classes
def create_train_dataset(task, batch_size, substeps, data_rng):
  """Creates the per-host training dataset for `task`.

  Args:
    task: TFDS dataset name.
    batch_size: Global batch size; must be divisible by `jax.device_count()`.
    substeps: Size of the second batch dimension.
    data_rng: RNG forwarded to `deterministic_data.create_dataset`.

  Returns:
    A tuple `(dataset_info, train_ds)`. Batches of `train_ds` have leading
    dimensions `[jax.local_device_count(), substeps, per_device_batch_size]`.

  Raises:
    ValueError: If `batch_size` is not divisible by the global device count.
  """
  num_devices = jax.device_count()
  # Compute batch size per device from global batch size.
  if batch_size % num_devices != 0:
    raise ValueError(f"Batch size ({batch_size}) must be divisible by "
                     f"the number of devices ({num_devices}).")
  device_batch = batch_size // num_devices

  builder = tfds.builder(task)
  host_split = deterministic_data.get_read_instruction_for_host(
      "train", builder.info.splits["train"].num_examples)

  train_ds = deterministic_data.create_dataset(
      builder,
      split=host_split,
      num_epochs=None,
      shuffle=True,
      batch_dims=[jax.local_device_count(), substeps, device_batch],
      preprocess_fn=_preprocess_cifar10,
      rng=data_rng)
  return builder.info, train_ds
def _create_eval_dataset(config, dataset_builder, split):
  """Creates an evaluation dataset (validation or test sets).

  Args:
    config: Config providing `eval_pad_last_batch` and
      `per_device_batch_size`.
    dataset_builder: TFDS builder for the dataset.
    split: Name of the split to evaluate on.

  Returns:
    A single-epoch, unshuffled dataset with batch dimensions
    `[jax.local_device_count(), config.per_device_batch_size]`. When
    `config.eval_pad_last_batch` is set, it is padded up to a fixed number of
    batches.
  """
  info = dataset_builder.info
  split_size = info.splits[split].num_examples
  host_split = deterministic_data.get_read_instruction_for_host(
      split, dataset_info=info, drop_remainder=False)

  pad_to = None
  if config.eval_pad_last_batch:
    # Without padding, the split would first be divided between hosts and then
    # into batches, and the remainder would be dropped both times. Padding up
    # to this many batches keeps every example of the split. If losing a few
    # examples is acceptable, `pad_up_to_batches` can be omitted.
    host_batch = jax.local_device_count() * config.per_device_batch_size
    pad_to = int(np.ceil(split_size / host_batch / jax.host_count()))

  return deterministic_data.create_dataset(
      dataset_builder,
      split=host_split,
      # Cache only in a distributed setup, so that Colab and unit tests do
      # not consume a lot of memory.
      cache=jax.host_count() > 1,
      batch_dims=[jax.local_device_count(), config.per_device_batch_size],
      num_epochs=1,
      shuffle=False,
      preprocess_fn=_preprocess_spherical_mnist,
      pad_up_to_batches=pad_to,
  )
def get_raw_dataset(dataset_builder: tfds.core.DatasetBuilder,
                    split: str,
                    *,
                    reverse_translation: bool = False) -> tf.data.Dataset:
  """Loads this host's shard of a raw WMT dataset and normalizes feature keys.

  Args:
    dataset_builder: TFDS dataset builder that can build `split`.
    split: Split to use. This must be the full split: it is sharded across
      hosts here, and sharding of subsplits is not supported.
    reverse_translation: Whether to reverse the translation direction, e.g.
      for 'de-en' translate from English to German instead.

  Returns:
    Dataset whose source and target language features are mapped to 'inputs'
    and 'targets'.
  """
  split_size = dataset_builder.info.splits[split].num_examples
  host_split = deterministic_data.get_read_instruction_for_host(
      split, split_size, drop_remainder=False)
  raw_ds = dataset_builder.as_dataset(split=host_split, shuffle_files=False)
  normalize_op = NormalizeFetaureNamesOp(
      dataset_builder.info, reverse_translation=reverse_translation)
  return raw_ds.map(normalize_op, num_parallel_calls=AUTOTUNE)
def create_train_dataset(
    config,
    dataset_builder,
    split,
    data_rng,
    preprocess_fn=None,
):
  """Creates the sharded, shuffled training dataset for this host.

  Args:
    config: Config providing `shuffle_buffer_size`, `per_device_batch_size`
      and `num_epochs`.
    dataset_builder: TFDS builder for the dataset.
    split: Name of the split to train on.
    data_rng: RNG forwarded to `deterministic_data.create_dataset`.
    preprocess_fn: Optional per-example preprocessing function.

  Returns:
    The training dataset, with batch dimensions
    `[jax.local_device_count(), config.per_device_batch_size]`, and the
    tf.data external-state policy set to WARN.
  """
  # Per-host read instruction; this ensures determinism in a distributed
  # setting.
  host_split = deterministic_data.get_read_instruction_for_host(
      split, dataset_info=dataset_builder.info)
  ds = deterministic_data.create_dataset(
      dataset_builder,
      split=host_split,
      rng=data_rng,
      preprocess_fn=preprocess_fn,
      shuffle_buffer_size=config.shuffle_buffer_size,
      batch_dims=[jax.local_device_count(), config.per_device_batch_size],
      num_epochs=config.num_epochs,
      shuffle=True,
  )

  # NOTE(review): presumably WARN is used so that preprocessing with external
  # state does not fail tf.data serialization; confirm.
  opts = tf.data.Options()
  opts.experimental_external_state_policy = (
      tf.data.experimental.ExternalStatePolicy.WARN)
  return ds.with_options(opts)
def get_split(rng,
              builder,
              split,
              batch_size,
              num_epochs=None,
              shuffle_buffer_size=None,
              repeat_after=False,
              cache=False):
  """Loads an audio dataset and shifts audio values to be positive.

  Args:
    rng: JAX PRNGKey random number generator state.
    builder: TFDS dataset builder instance.
    split: TFDS split to load.
    batch_size: Global batch size.
    num_epochs: Number of epochs. None to repeat forever.
    shuffle_buffer_size: Size of the shuffle buffer. If None (or not
      positive), data is not shuffled.
    repeat_after: If True, the dataset is repeated infinitely *after* CLU.
    cache: If True, the dataset is cached prior to pre-processing.

  Returns:
    Audio datasets with `inputs` and `label` features. The former is shifted to
    be non-negative.

  Raises:
    ValueError: If `batch_size` is not divisible by the process count, or the
      resulting per-host batch size is not divisible by the local device
      count.
  """
  host_count = jax.process_count()
  if batch_size % host_count != 0:
    raise ValueError(
        f'Batch size ({batch_size}) must be divisible by the host'
        f' count ({host_count}).')
  batch_size = batch_size // host_count  # Now the per-host batch size.

  device_count = jax.local_device_count()
  if batch_size % device_count != 0:
    raise ValueError(
        f'Local batch size ({batch_size}) must be divisible by the'
        f' local device count ({device_count}).')
  batch_dims = [device_count, batch_size // device_count]

  host_split = data.get_read_instruction_for_host(
      split,
      dataset_info=builder.info,
      remainder_options=data.RemainderOptions.BALANCE_ON_PROCESSES)

  # Normalize the optional buffer size. `shuffle` is passed as a real bool
  # rather than `None`/`0` leaking through an `and` expression.
  buffer_size = shuffle_buffer_size or 0
  shuffle = buffer_size > 0

  ds = data.create_dataset(
      builder,
      split=host_split,
      preprocess_fn=PrepareAudio(),
      cache=cache,
      batch_dims=batch_dims,
      rng=rng,
      num_epochs=num_epochs,
      pad_up_to_batches='auto',
      shuffle=shuffle,
      shuffle_buffer_size=buffer_size)
  if repeat_after:
    ds = ds.repeat()
  return ds
def test_get_read_instruction_for_host(self, host_id: int, host_count: int,
                                       drop_remainder: bool, spec: str,
                                       expected_spec_for_host: str):
  """Compares the `dataset_info`-based read instruction with a spec string."""
  actual = deterministic_data.get_read_instruction_for_host(
      spec,
      dataset_info=FakeDatasetInfo(),
      host_id=host_id,
      host_count=host_count,
      drop_remainder=drop_remainder)
  expected = tfds.core.ReadInstruction.from_spec(expected_spec_for_host)
  # Instructions are compared through their string form.
  self.assertEqual(str(actual), str(expected))
def test_get_read_instruction_for_host(self, num_examples: int, host_id: int,
                                       host_count: int, drop_remainder: bool,
                                       expected_spec: str):
  """Compares the `num_examples`-based read instruction with `expected_spec`."""
  want = tfds.core.ReadInstruction.from_spec(expected_spec)
  got = deterministic_data.get_read_instruction_for_host(
      "test",
      num_examples,
      host_id=host_id,
      host_count=host_count,
      drop_remainder=drop_remainder)
  # Both instructions are resolved against a 9-example "test" split.
  split_lengths = {"test": 9}
  self.assertEqual(
      want.to_absolute(split_lengths), got.to_absolute(split_lengths))
def test_same_cardinality_on_all_hosts(self, num_examples: int,
                                       host_count: int):
  """With `drop_remainder=True`, every host yields the same number of batches."""
  builder = MyDatasetBuilder({"train": num_examples})

  def cardinality_for_host(host_id):
    host_split = deterministic_data.get_read_instruction_for_host(
        split="train",
        num_examples=num_examples,
        host_id=host_id,
        host_count=host_count,
        drop_remainder=True)
    ds = deterministic_data.create_dataset(
        builder, split=host_split, batch_dims=[2], shuffle=False, num_epochs=1)
    return ds.cardinality().numpy().item()

  distinct = {cardinality_for_host(host_id) for host_id in range(host_count)}
  self.assertLen(distinct, 1)
def test_get_read_instruction_balance_remainder(self, host_id: int,
                                                host_count: int,
                                                balance_remainder: bool,
                                                spec: str,
                                                expected_spec_for_host: str):
  """Checks per-host instructions for both remainder-handling options."""
  info = FakeDatasetInfo(test_size=10)
  options = deterministic_data.RemainderOptions
  if balance_remainder:
    remainder = options.BALANCE_ON_PROCESSES
  else:
    remainder = options.ON_FIRST_PROCESS
  actual = deterministic_data.get_read_instruction_for_host(
      spec,
      dataset_info=info,
      host_id=host_id,
      host_count=host_count,
      remainder_options=remainder)
  expected = tfds.core.ReadInstruction.from_spec(expected_spec_for_host)
  self.assertEqual(str(actual), str(expected))
def get_raw_dataset(dataset_builder: tfds.core.DatasetBuilder,
                    split: str) -> tf.data.Dataset:
  """Loads this host's shard of a raw text dataset and normalizes feature keys.

  Args:
    dataset_builder: TFDS dataset builder that can build `split`.
    split: Split to use. This must be the full split: it is sharded across
      hosts here, and sharding of subsplits is not supported.

  Returns:
    Dataset whose source and target language features are mapped to 'inputs'
    and 'targets'.
  """
  host_split = deterministic_data.get_read_instruction_for_host(
      split, dataset_info=dataset_builder.info, drop_remainder=False)
  raw_ds = dataset_builder.as_dataset(split=host_split, shuffle_files=False)
  normalize_op = NormalizeFeatureNamesOp(dataset_builder.info)
  return raw_ds.map(normalize_op, num_parallel_calls=AUTOTUNE)
def create_eval_dataset(task, batch_size, subset):
  """Creates the per-host evaluation dataset for `task`.

  Args:
    task: TFDS dataset name; the dataset is downloaded and prepared if needed.
    batch_size: Global batch size; must be divisible by `jax.device_count()`.
    subset: Name of the split to evaluate on.

  Returns:
    A tuple `(dataset_info, eval_ds)`. `eval_ds` is a single unshuffled epoch
    with batch dimensions `[jax.local_device_count(), per_device_batch_size]`.

  Raises:
    ValueError: If `batch_size` is not divisible by the global device count.
  """
  num_devices = jax.device_count()
  if batch_size % num_devices != 0:
    raise ValueError(f"Batch size ({batch_size}) must be divisible by "
                     f"the number of devices ({num_devices}).")
  device_batch = batch_size // num_devices

  builder = tfds.builder(task)
  builder.download_and_prepare()

  host_split = deterministic_data.get_read_instruction_for_host(
      subset, builder.info.splits[subset].num_examples)
  eval_ds = deterministic_data.create_dataset(
      builder,
      split=host_split,
      num_epochs=1,
      shuffle=False,
      batch_dims=[jax.local_device_count(), device_batch],
      preprocess_fn=_preprocess_cifar10)
  return builder.info, eval_ds
def test_get_read_instruction_for_host_deprecated(self, num_examples: int,
                                                  host_id: int,
                                                  host_count: int,
                                                  drop_remainder: bool,
                                                  expected_spec: str):
  """Checks the deprecated `num_examples`-based read-instruction path."""
  want = tfds.core.ReadInstruction.from_spec(expected_spec)
  got = deterministic_data.get_read_instruction_for_host(
      "test",
      num_examples,
      host_id=host_id,
      host_count=host_count,
      drop_remainder=drop_remainder)
  # `_use_split_info` selects whether `to_absolute` is given a `SplitInfo`
  # object or a plain length for the 9-example "test" split.
  test_split = (
      tfds.core.SplitInfo(name="test", shard_lengths=[9], num_bytes=0)
      if _use_split_info else 9)
  split_infos = {"test": test_split}
  self.assertEqual(want.to_absolute(split_infos), got.to_absolute(split_infos))
def test_same_cardinality_on_all_hosts_with_pad(self, num_examples: int,
                                                host_count: int):
  """With padding, every host yields the same number of batches."""
  builder = MyDatasetBuilder({"train": num_examples})
  batch_size = 2
  # Enough batches per host to cover every example across all hosts.
  num_batches = int(math.ceil(num_examples / (batch_size * host_count)))
  assert num_batches * batch_size * host_count >= num_examples

  def cardinality_for_host(host_id):
    host_split = deterministic_data.get_read_instruction_for_host(
        split="train",
        num_examples=num_examples,
        host_id=host_id,
        host_count=host_count,
        drop_remainder=False)
    ds = deterministic_data.create_dataset(
        builder,
        split=host_split,
        batch_dims=[batch_size],
        shuffle=False,
        num_epochs=1,
        pad_up_to_batches=num_batches)
    return ds.cardinality().numpy().item()

  # All hosts should have the same number of batches.
  distinct = {cardinality_for_host(host_id) for host_id in range(host_count)}
  self.assertLen(distinct, 1)
def create_datasets(config, data_rng):
  """Create datasets for training and evaluation.

  The train split is sharded across hosts and optionally augmented. The test
  split is not sharded: every host reads all of it, `config.num_eval_passes`
  times.

  Args:
    config: Config providing `batch_size`, `test_batch_size`, `dataset`,
      `data_augmentation` and `num_eval_passes`.
    data_rng: RNG forwarded to `deterministic_data.create_dataset` for the
      train split.

  Returns:
    A tuple `(dataset_info, train_ds, eval_ds)`.

  Raises:
    ValueError: If `config.batch_size` is not divisible by
      `jax.device_count()`, or `config.test_batch_size` is not divisible by
      `jax.local_device_count()`.
  """
  import logging  # pylint: disable=g-import-not-at-top

  # Compute batch size per device from global batch size.
  if config.batch_size % jax.device_count() != 0:
    raise ValueError(
        f'Batch size ({config.batch_size}) must be divisible by '
        f'the number of devices ({jax.device_count()}).')
  per_device_batch_size = config.batch_size // jax.device_count()

  dataset_builder = tfds.builder(config.dataset)

  def cast_int32(batch):
    """Returns a shallow copy of `batch` with 'image' cast to int32."""
    img = tf.cast(batch['image'], tf.int32)
    out = batch.copy()
    out['image'] = img
    return out

  def drop_info(batch):
    """Removes unwanted keys ('id', 'rng') from `batch` in place."""
    batch.pop('id', None)
    batch.pop('rng', None)
    return batch

  augment_enabled = bool(config.data_augmentation)

  def augment(batch):
    """Randomly flips and rotates (by k*90 degrees) 'image' if enabled."""
    img = tf.cast(batch['image'], tf.float32)
    if augment_enabled:
      # Op creation order is kept unchanged on purpose: when a global TF seed
      # is set, op-level seeds of the random ops can depend on it.
      img_flipped = tf.image.flip_left_right(img)
      do_flip = tf.random.uniform(shape=[]) > 0.5
      img = tf.where(do_flip, img_flipped, img)
      u = tf.random.uniform(shape=[])
      k = tf.cast(tf.floor(4. * u), tf.int32)
      img = tf.image.rot90(img, k=k)
    out = batch.copy()
    out['image'] = img
    return out

  def preprocess_train(batch):
    return cast_int32(augment(drop_info(batch)))

  def preprocess_eval(batch):
    return cast_int32(drop_info(batch))

  # Read instructions to shard the train split across hosts.
  logging.info('Number of train examples: %d',
               dataset_builder.info.splits['train'].num_examples)
  train_split = deterministic_data.get_read_instruction_for_host(
      'train', dataset_info=dataset_builder.info)
  train_ds = deterministic_data.create_dataset(
      dataset_builder,
      split=train_split,
      num_epochs=1,
      shuffle=True,
      batch_dims=[jax.local_device_count(), per_device_batch_size],
      preprocess_fn=preprocess_train,
      rng=data_rng,
      prefetch_size=tf.data.AUTOTUNE,
      drop_remainder=True)

  # TODO(emielh) check if this is necessary?
  # Test batches are _not_ sharded. In the worst case, this simply leads to some
  # duplicated information. In our case, since the elbo is stochastic we get
  # multiple passes over the test data.
  if config.test_batch_size % jax.local_device_count() != 0:
    raise ValueError(
        f'Test batch size ({config.test_batch_size}) must be divisible by '
        f'the number of local devices ({jax.local_device_count()}).')
  test_device_batch_size = config.test_batch_size // jax.local_device_count()
  eval_ds = deterministic_data.create_dataset(
      dataset_builder,
      split='test',
      # Repeated epochs for lower variance ELBO estimate.
      num_epochs=config.num_eval_passes,
      shuffle=False,
      batch_dims=[jax.local_device_count(), test_device_batch_size],
      preprocess_fn=preprocess_eval,
      # TODO(emielh) Fix this with batch padding instead of dropping.
      prefetch_size=tf.data.AUTOTUNE,
      drop_remainder=False)

  return dataset_builder.info, train_ds, eval_ds
def test_get_read_instruction_for_host_fails(self, host_id: int,
                                             host_count: int):
  """Expects ValueError for the given host id/count on an 11-example split."""
  with self.assertRaises(ValueError):
    deterministic_data.get_read_instruction_for_host(
        split="test",
        num_examples=11,
        host_id=host_id,
        host_count=host_count)