def _get_uniref(per_host_batch_size, per_host_eval_batch_size, hps, data_rng):
  """Data generators for Uniref50 clustered protein dataset."""
  # TODO(gilmer) Currently uniref drops the last partial batch on eval.
  logging.warning(
      'Currently the Protein dataset drops the last partial batch on eval')
  if jax.process_count() > 1:
    raise NotImplementedError('Proteins does not support multihost training')

  n_devices = jax.local_device_count()
  # Per-host batches are later split across local devices, so both batch sizes
  # must be divisible by the local device count.
  if per_host_batch_size % n_devices != 0:
    raise ValueError('n_devices={} must divide per_host_batch_size={}.'.format(
        n_devices, per_host_batch_size))
  if per_host_eval_batch_size % n_devices != 0:
    raise ValueError(
        'n_devices={} must divide per_host_eval_batch_size={}.'.format(
            n_devices, per_host_eval_batch_size))

  train_ds, eval_ds, vocab = load_dataset(
      hps.data_name,
      batch_size=per_host_batch_size,
      eval_batch_size=per_host_eval_batch_size,
      length=hps.max_target_length)
  bert_masker = BertMasker(vocab=vocab)

  def _masked_batches(ds, mode, num_batches=None):
    # Key the masking rng on the batch index via fold_in so a given batch
    # position always receives the same pseudorandom mask.
    for idx, batch in enumerate(itertools.islice(iter(ds), num_batches)):
      yield _batch_to_dict(batch, bert_masker, mode,
                           jax.random.fold_in(data_rng, idx))

  def train_iterator_fn():
    return _masked_batches(train_ds, 'train')

  def eval_train_epoch(num_batches=None):
    return _masked_batches(train_ds, 'eval', num_batches)

  def valid_epoch(num_batches=None):
    return _masked_batches(eval_ds, 'eval', num_batches)

  def test_epoch(*args, **kwargs):
    """Empty (null) iterator: no test epoch is defined for this dataset."""
    del args, kwargs
    yield from ()

  return Dataset(train_iterator_fn, eval_train_epoch, valid_epoch, test_epoch)
def _get_translate_wmt(per_host_batch_size, per_host_eval_batch_size, hps,
                       shuffle_rng):
  """Data generators for wmt translate task."""
  n_devices = jax.local_device_count()
  # Batches are sharded across local devices downstream, so both batch sizes
  # must be divisible by the local device count.
  if per_host_batch_size % n_devices != 0:
    raise ValueError('n_devices={} must divide per_host_batch_size={}.'.format(
        n_devices, per_host_batch_size))
  if per_host_eval_batch_size % n_devices != 0:
    raise ValueError(
        'n_devices={} must divide per_host_eval_batch_size={}.'.format(
            n_devices, per_host_eval_batch_size))

  train_ds, eval_ds, predict_ds = mt_pipeline.get_wmt_datasets(
      hps,
      shuffle_seed=shuffle_rng[0],
      sample_seed=shuffle_rng[1],
      n_devices=jax.local_device_count(),
      per_host_batch_size=per_host_batch_size,
      per_host_eval_batch_size=per_host_eval_batch_size,
      vocab_path=hps.vocab_path)

  def _padded_batches(ds, pad_size, num_batches=None):
    # Convert each tf batch to numpy and pad a trailing partial batch up to
    # `pad_size`, recording the padding in a mask keyed on 'targets'.
    for batch in itertools.islice(iter(ds), num_batches):
      yield mt_pipeline.maybe_pad_batch(
          data_utils.tf_to_numpy(batch), pad_size, mask_key='targets')

  def train_iterator_fn():
    return _padded_batches(train_ds, per_host_batch_size)

  def eval_train_epoch(num_batches=None):
    return _padded_batches(train_ds, per_host_batch_size, num_batches)

  def valid_epoch(num_batches=None):
    return _padded_batches(eval_ds, per_host_eval_batch_size, num_batches)

  def test_epoch(num_batches=None):
    return _padded_batches(predict_ds, per_host_eval_batch_size, num_batches)

  return Dataset(train_iterator_fn, eval_train_epoch, valid_epoch, test_epoch)
def _get_translate_wmt(per_host_batch_size, per_host_eval_batch_size, hps,
                       shuffle_rng):
  """Data generators for wmt translate task."""
  n_devices = jax.local_device_count()
  # Both batch sizes must divide evenly across the local devices.
  if per_host_batch_size % n_devices != 0:
    raise ValueError(
        'n_devices={} must divide per_host_batch_size={}.'.format(
            n_devices, per_host_batch_size))
  if per_host_eval_batch_size % n_devices != 0:
    raise ValueError(
        'n_devices={} must divide per_host_eval_batch_size={}.'.format(
            n_devices, per_host_eval_batch_size))

  # The prediction split is unused here; only train and eval are consumed.
  train_ds, eval_ds, _ = mt_pipeline.get_wmt_datasets(
      hps,
      shuffle_seed=shuffle_rng[0],
      n_devices=jax.local_device_count(),
      per_host_batch_size=per_host_batch_size,
      per_host_eval_batch_size=per_host_eval_batch_size,
      vocab_path=hps.vocab_path)

  def _to_padded_numpy(batch, pad_size):
    # Convert a tf batch to numpy and pad a trailing partial batch to
    # `pad_size`, adding a weight mask keyed on 'targets'.
    return data_utils.maybe_pad_batch(
        data_utils.tf_to_numpy(batch),
        pad_size,
        data_format=None,
        mask_key='targets')

  def train_iterator_fn():
    for batch in iter(train_ds):
      yield _to_padded_numpy(batch, per_host_batch_size)

  def eval_train_epoch(num_batches=None):
    for batch in itertools.islice(iter(train_ds), num_batches):
      yield _to_padded_numpy(batch, per_host_batch_size)

  def valid_epoch(num_batches=None):
    for batch in itertools.islice(iter(eval_ds), num_batches):
      yield _to_padded_numpy(batch, per_host_eval_batch_size)

  def test_epoch(*args, **kwargs):
    """Empty (null) iterator: no test epoch is defined here."""
    del args, kwargs
    yield from ()

  return Dataset(train_iterator_fn, eval_train_epoch, valid_epoch, test_epoch)
def get_criteo1tb(unused_shuffle_rng, batch_size, eval_batch_size, hps):
  """Get the Criteo 1TB train and eval iterators."""
  process_count = jax.process_count()
  # Global batch sizes are split evenly across participating hosts.
  if batch_size % process_count != 0:
    raise ValueError('process_count={} must divide batch_size={}.'.format(
        process_count, batch_size))
  if eval_batch_size is None:
    eval_batch_size = batch_size
  if eval_batch_size % process_count != 0:
    raise ValueError(
        'process_count={} must divide eval_batch_size={}.'.format(
            process_count, eval_batch_size))
  per_host_batch_size = batch_size // process_count
  per_host_eval_batch_size = eval_batch_size // process_count

  def _make_reader(file_path, per_host_size, is_training):
    # All readers share the dense-feature / vocab configuration from hps.
    return _criteo_tsv_reader(
        file_path=file_path,
        num_dense_features=hps.num_dense_features,
        vocab_sizes=hps.vocab_sizes,
        batch_size=per_host_size,
        is_training=is_training)

  train_dataset = _make_reader(hps.train_file_path, per_host_batch_size, True)
  train_iterator_fn = lambda: tfds.as_numpy(train_dataset)

  # eval_train reads the same training files, but with eval-sized batches and
  # is_training=False.
  eval_train_epoch = functools.partial(
      convert_to_numpy_iterator_fn,
      tf_dataset=_make_reader(hps.train_file_path, per_host_eval_batch_size,
                              False))
  eval_iterator_fn = functools.partial(
      convert_to_numpy_iterator_fn,
      tf_dataset=_make_reader(hps.eval_file_path, per_host_eval_batch_size,
                              False))

  def test_epoch(*args, **kwargs):
    """Empty (null) iterator: no test data is provided."""
    del args, kwargs
    yield from ()

  return Dataset(train_iterator_fn, eval_train_epoch, eval_iterator_fn,
                 test_epoch)
def get_ogbg_molpcba(shuffle_rng, batch_size, eval_batch_size, hps=None):
  """Data generators for ogbg-molpcba."""
  shuffle_buffer_size = 2**15
  # Independent shuffle streams for the train and eval_train passes.
  train_shuffle_rng, eval_train_shuffle_rng = jax.random.split(shuffle_rng)
  train_ds = _load_dataset(
      'train',
      should_shuffle=True,
      shuffle_seed=train_shuffle_rng,
      shuffle_buffer_size=shuffle_buffer_size)
  eval_train_ds = _load_dataset(
      'train',
      should_shuffle=True,
      shuffle_seed=eval_train_shuffle_rng,
      shuffle_buffer_size=shuffle_buffer_size)
  valid_ds = _load_dataset('validation')
  test_ds = _load_dataset('test')

  # All splits share the same graph-batching configuration from hps.
  make_batches = functools.partial(
      _get_batch_iterator,
      nodes_per_graph=hps.max_nodes_multiplier,
      edges_per_graph=hps.max_edges_multiplier,
      add_bidirectional_edges=hps.add_bidirectional_edges,
      add_virtual_node=hps.add_virtual_node,
      add_self_loops=hps.add_self_loops)

  def train_iterator_fn():
    return make_batches(dataset_iter=iter(train_ds), batch_size=batch_size)

  def _eval_epoch_fn(ds):
    # Each epoch call re-creates the dataset iterator so it starts fresh, and
    # truncates to num_batches when given.
    def epoch_fn(num_batches=None):
      return itertools.islice(
          make_batches(dataset_iter=iter(ds), batch_size=eval_batch_size),
          num_batches)
    return epoch_fn

  return Dataset(train_iterator_fn, _eval_epoch_fn(eval_train_ds),
                 _eval_epoch_fn(valid_ds), _eval_epoch_fn(test_ds))
def _get_lm1b(hps, per_host_batch_size, per_host_eval_batch_size, shuffle_rng):
  """Data generators for lm1b."""
  n_devices = jax.local_device_count()
  # Both batch sizes must divide evenly across the local devices.
  if per_host_batch_size % n_devices != 0:
    raise ValueError('n_devices={} must divide per_host_batch_size={}.'.format(
        n_devices, per_host_batch_size))
  if per_host_eval_batch_size % n_devices != 0:
    raise ValueError(
        'n_devices={} must divide per_host_eval_batch_size={}.'.format(
            n_devices, per_host_eval_batch_size))

  train_ds, eval_ds = lm1b_input_pipeline_v2.get_lm1b_datasets(
      hps, per_host_batch_size, per_host_eval_batch_size, shuffle_rng)

  def _batches(ds, num_batches=None):
    # Convert each raw batch into the dict format the model consumes.
    for batch in itertools.islice(iter(ds), num_batches):
      yield _batch_to_dict(batch)

  def train_iterator_fn():
    return _batches(train_ds)

  def eval_train_epoch(num_batches=None):
    return _batches(train_ds, num_batches)

  def valid_epoch(num_batches=None):
    return _batches(eval_ds, num_batches)

  def test_epoch(*args, **kwargs):
    """Empty (null) iterator: lm1b defines no test epoch."""
    del args, kwargs
    yield from ()

  return Dataset(train_iterator_fn, eval_train_epoch, valid_epoch, test_epoch)
def _prepare_small_image_datasets(data_train,
                                  data_valid,
                                  data_test,
                                  per_host_batch_size,
                                  per_host_eval_batch_size,
                                  train_size,
                                  rescale,
                                  input_shape,
                                  output_shape,
                                  shuffle_rng,
                                  augment_fn,
                                  is_one_hot=True,
                                  autoencoder=False,
                                  include_example_keys=False):
  """Prepare Dataset using tf.data.Datasets of the different splits.

  Args:
    data_train: training split, fed to `data_utils.image_iterator`
      (presumably a tf.data.Dataset, per the summary line — confirm).
    data_valid: validation split, or a falsy value when there is none.
    data_test: test split.
    per_host_batch_size: per-host train batch size; determines how many full
      batches of the training data are cached for eval_train.
    per_host_eval_batch_size: per-host batch size passed to `_eval_batches`
      for all eval epochs.
    train_size: number of training examples; only the first
      `train_size // per_host_batch_size` full batches are cached.
    rescale: forwarded to `data_utils.image_iterator`.
    input_shape: per-example input shape, used to reshape the cached
      eval_train inputs.
    output_shape: per-example target shape, used to reshape the cached
      eval_train targets.
    shuffle_rng: rng forwarded to the train iterator's shuffle.
    augment_fn: augmentation callback forwarded to the train iterator.
    is_one_hot: forwarded to `data_utils.image_iterator`; incompatible with
      `autoencoder`.
    autoencoder: forwarded to `data_utils.image_iterator`; incompatible with
      `is_one_hot`.
    include_example_keys: if True, validation batches also carry 'tfds_id'
      example keys.

  Returns:
    A `Dataset` of (train_iterator_fn, eval_train_epoch, valid_epoch,
    test_epoch).

  Raises:
    ValueError: if both `autoencoder` and `is_one_hot` are True.
  """
  if autoencoder and is_one_hot:
    raise ValueError(
        'One hot encoding cannot be applied to autoencoder datasets.')
  # Iterator factory for evaluation passes: unlike the train iterator below,
  # no shuffle_rng or augment_fn is supplied.
  eval_image_iterator = functools.partial(data_utils.image_iterator,
                                          rescale=rescale,
                                          output_shape=output_shape,
                                          is_one_hot=is_one_hot,
                                          autoencoder=autoencoder)

  # Setup the eval_train split as a copy of the training data, in the form of
  # the first `num_train_batches` batches of the data as an np.array.
  eval_train_iterator = eval_image_iterator(data_train)
  num_train_batches = train_size // per_host_batch_size
  eval_train_data = list(
      itertools.islice(eval_train_iterator, 0, num_train_batches))
  # Stack the cached batches, then flatten the leading (num_batches, batch)
  # dims into a single example axis.
  eval_train_inputs = jnp.array(
      [batch['inputs'] for batch in eval_train_data])
  eval_train_inputs_shape = (num_train_batches * per_host_batch_size,
                             *input_shape)
  eval_train_inputs = np.reshape(eval_train_inputs, eval_train_inputs_shape)
  eval_train_targets = jnp.array(
      [batch['targets'] for batch in eval_train_data])
  eval_train_output_shape = (num_train_batches * per_host_batch_size,
                             *output_shape)
  eval_train_targets = np.reshape(eval_train_targets, eval_train_output_shape)

  # Default to empty arrays when there is no validation split.
  valid_inputs = jnp.array([])
  valid_targets = jnp.array([])
  valid_example_keys = jnp.array([])
  if data_valid:
    # Only the first yielded batch is used — presumably the eval iterator
    # yields the whole split in one batch; confirm image_iterator's batching.
    valid_data = next(
        eval_image_iterator(data_valid,
                            include_example_keys=include_example_keys))
    valid_inputs = valid_data['inputs']
    valid_targets = valid_data['targets']
    if include_example_keys:
      valid_example_keys = valid_data['tfds_id']
  test_data = next(eval_image_iterator(data_test))
  test_inputs = jnp.array(test_data['inputs'].astype(np.float32))
  test_targets = test_data['targets']
  # NOTE(dsuo): each host should see the entire dataset, so we remove sharding
  # by host.
  # TODO(gilmer): The simplest way to do this would be to directly yield from
  # tfds.as_numpy(). However we currently do not know how to properly handle
  # restarts with tfds. We'd like the train shuffle to depend on the epoch
  # number, so everytime we generate epoch 10 it yields the same pseudorandom
  # order.
  train_iterator_fn = functools.partial(data_utils.image_iterator,
                                        data_train,
                                        rescale=rescale,
                                        output_shape=output_shape,
                                        is_one_hot=is_one_hot,
                                        autoencoder=autoencoder,
                                        shuffle_rng=shuffle_rng,
                                        augment_fn=augment_fn)
  # Eval epochs serve the materialized arrays in per_host_eval_batch_size
  # chunks via `_eval_batches`.
  eval_train_epoch = functools.partial(_eval_batches,
                                       eval_train_inputs,
                                       eval_train_targets,
                                       per_host_eval_batch_size)
  valid_epoch = functools.partial(_eval_batches,
                                  valid_inputs,
                                  valid_targets,
                                  per_host_eval_batch_size,
                                  valid_example_keys=valid_example_keys)
  test_epoch = functools.partial(_eval_batches,
                                 test_inputs,
                                 test_targets,
                                 per_host_eval_batch_size)
  return Dataset(train_iterator_fn, eval_train_epoch, valid_epoch, test_epoch)
def get_nqm_noise(shuffle_rng, batch_size, eval_batch_size, hps=None):
  """Returns the noise generators for the nqm model.

  NOTE: This dataset is only meant to be used with the nqm model. It draws
  isotropic Gaussian noise of the desired dimension; the nqm model then
  multiplies the noise by a matrix D with the property that D^T D = C,
  yielding noise with gradient covariance C.

  Args:
    shuffle_rng: rng key; only its first component seeds the numpy streams.
    batch_size: The global train batch size, used to determine the batch size
      yielded from train_epoch().
    eval_batch_size: Not used.
    hps: Hparams object. We only refer to hps.input_shape to determine the
      dimension of the noise.

  Returns:
    train_epoch, eval_train_epoch, valid_epoch, test_epoch: generators
    packed in a Dataset. Only train_epoch is used.
  """
  del eval_batch_size
  per_host_batch_size = batch_size // jax.process_count()
  # We only use the first part of the seed, which may result in slightly more
  # rng collisions than normal. Train and eval_train get identically seeded
  # but independent streams.
  train_rng = np.random.RandomState(seed=shuffle_rng[0])
  eval_rng = np.random.RandomState(seed=shuffle_rng[0])
  noise_shape = (per_host_batch_size, *hps.input_shape)

  def train_iterator_fn():
    """Endless stream of standard-normal noise batches."""
    while True:
      yield {'inputs': train_rng.normal(size=noise_shape)}

  def eval_train_epoch(num_batches):
    """Yields exactly `num_batches` standard-normal noise batches."""
    for _ in range(num_batches):
      yield {'inputs': eval_rng.normal(size=noise_shape)}

  def valid_epoch(*args, **kwargs):
    """Empty (null) iterator: no validation data for nqm."""
    del args, kwargs
    yield from ()

  def test_epoch(*args, **kwargs):
    """Empty (null) iterator: no test data for nqm."""
    del args, kwargs
    yield from ()

  return Dataset(train_iterator_fn, eval_train_epoch, valid_epoch, test_epoch)