Example #1
def _get_uniref(per_host_batch_size, per_host_eval_batch_size, hps, data_rng):
    """Data generators for Uniref50 clustered protein dataset."""
    # TODO(gilmer) Currently uniref drops the last partial batch on eval.
    logging.warning(
        'Currently the Protein dataset drops the last partial batch on eval')
    if jax.process_count() > 1:
        raise NotImplementedError(
            'Proteins does not support multihost training')

    n_devices = jax.local_device_count()
    if per_host_batch_size % n_devices != 0:
        raise ValueError(
            'n_devices={} must divide per_host_batch_size={}.'.format(
                n_devices, per_host_batch_size))
    if per_host_eval_batch_size % n_devices != 0:
        raise ValueError(
            'n_devices={} must divide per_host_eval_batch_size={}.'.format(
                n_devices, per_host_eval_batch_size))

    train_ds, eval_ds, vocab = load_dataset(
        hps.data_name,
        batch_size=per_host_batch_size,
        eval_batch_size=per_host_eval_batch_size,
        length=hps.max_target_length)

    masker = BertMasker(vocab=vocab)

    def train_iterator_fn():
        for batch_index, batch in enumerate(iter(train_ds)):
            batch_rng = jax.random.fold_in(data_rng, batch_index)
            yield _batch_to_dict(batch, masker, 'train', batch_rng)

    def eval_train_epoch(num_batches=None):
        eval_train_iter = iter(train_ds)
        for batch_index, batch in enumerate(
                itertools.islice(eval_train_iter, num_batches)):
            batch_rng = jax.random.fold_in(data_rng, batch_index)
            yield _batch_to_dict(batch, masker, 'eval', batch_rng)

    def valid_epoch(num_batches=None):
        valid_iter = iter(eval_ds)
        for batch_index, batch in enumerate(
                itertools.islice(valid_iter, num_batches)):
            batch_rng = jax.random.fold_in(data_rng, batch_index)
            yield _batch_to_dict(batch, masker, 'eval', batch_rng)

    # pylint: disable=unreachable
    def test_epoch(*args, **kwargs):
        del args
        del kwargs
        return
        yield  # This yield is needed to make this a valid (null) iterator.

    # pylint: enable=unreachable

    return Dataset(train_iterator_fn, eval_train_epoch, valid_epoch,
                   test_epoch)
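Every builder in these examples returns the same four-field Dataset container. Its definition is not shown here; below is a minimal sketch of the assumed container plus an illustrative consumer (the namedtuple definition and the peek helper are assumptions for illustration, not the codebase's actual API).

import collections

# Assumed container; the field order matches the positional arguments used in
# the examples above.
Dataset = collections.namedtuple(
    'Dataset',
    ['train_iterator_fn', 'eval_train_epoch', 'valid_epoch', 'test_epoch'])


def peek(dataset, num_batches=2):
    """Illustrative consumer: pull a few batches from each evaluation split."""
    for epoch_fn in (dataset.eval_train_epoch, dataset.valid_epoch,
                     dataset.test_epoch):
        for batch in epoch_fn(num_batches=num_batches):
            # Each batch is a dict of arrays; print the per-key shapes.
            print({k: v.shape for k, v in batch.items()})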
Example #2
def _get_translate_wmt(per_host_batch_size,
                       per_host_eval_batch_size,
                       hps,
                       shuffle_rng):
  """Data generators for wmt translate task."""
  n_devices = jax.local_device_count()
  if per_host_batch_size % n_devices != 0:
    raise ValueError('n_devices={} must divide per_host_batch_size={}.'.format(
        n_devices, per_host_batch_size))
  if per_host_eval_batch_size % n_devices != 0:
    raise ValueError(
        'n_devices={} must divide per_host_eval_batch_size={}.'.format(
            n_devices, per_host_eval_batch_size))

  vocab_path = hps.vocab_path
  train_ds, eval_ds, predict_ds = mt_pipeline.get_wmt_datasets(
      hps,
      shuffle_seed=shuffle_rng[0],
      sample_seed=shuffle_rng[1],
      n_devices=jax.local_device_count(),
      per_host_batch_size=per_host_batch_size,
      per_host_eval_batch_size=per_host_eval_batch_size,
      vocab_path=vocab_path)

  def train_iterator_fn():
    for batch in iter(train_ds):
      yield mt_pipeline.maybe_pad_batch(
          data_utils.tf_to_numpy(batch),
          per_host_batch_size,
          mask_key='targets')

  def eval_train_epoch(num_batches=None):
    eval_train_iter = iter(train_ds)
    for batch in itertools.islice(eval_train_iter, num_batches):
      yield mt_pipeline.maybe_pad_batch(
          data_utils.tf_to_numpy(batch),
          per_host_batch_size,
          mask_key='targets')

  def valid_epoch(num_batches=None):
    valid_iter = iter(eval_ds)
    for batch in itertools.islice(valid_iter, num_batches):
      yield mt_pipeline.maybe_pad_batch(
          data_utils.tf_to_numpy(batch),
          per_host_eval_batch_size,
          mask_key='targets')

  def test_epoch(num_batches=None):
    predict_iter = iter(predict_ds)
    for batch in itertools.islice(predict_iter, num_batches):
      yield mt_pipeline.maybe_pad_batch(
          data_utils.tf_to_numpy(batch),
          per_host_eval_batch_size,
          mask_key='targets')

  return Dataset(train_iterator_fn, eval_train_epoch, valid_epoch, test_epoch)
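mt_pipeline.maybe_pad_batch itself is not shown in these examples. The sketch below is a guess at what such a padding helper plausibly does, assuming it pads a trailing partial batch up to the per-host batch size and records which rows are real via a mask; the real mt_pipeline implementation may differ.

import numpy as np


def pad_batch_sketch(batch, desired_batch_size, mask_key='targets'):
  """Hypothetical stand-in for maybe_pad_batch.

  Pads every array in `batch` along axis 0 up to `desired_batch_size` and adds
  a 'weights' array that is 1.0 for real examples and 0.0 for padding rows.
  """
  actual_batch_size = batch[mask_key].shape[0]
  if actual_batch_size == desired_batch_size:
    return batch
  pad_amount = desired_batch_size - actual_batch_size
  padded = {
      k: np.concatenate(
          [v, np.zeros((pad_amount,) + v.shape[1:], dtype=v.dtype)], axis=0)
      for k, v in batch.items()
  }
  weights = np.zeros(desired_batch_size, dtype=np.float32)
  weights[:actual_batch_size] = 1.0
  padded['weights'] = weights
  return padded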
Example #3
def _get_translate_wmt(per_host_batch_size, per_host_eval_batch_size, hps,
                       shuffle_rng):
    """Data generators for wmt translate task."""
    n_devices = jax.local_device_count()
    if per_host_batch_size % n_devices != 0:
        raise ValueError(
            'n_devices={} must divide per_host_batch_size={}.'.format(
                n_devices, per_host_batch_size))
    if per_host_eval_batch_size % n_devices != 0:
        raise ValueError(
            'n_devices={} must divide per_host_eval_batch_size={}.'.format(
                n_devices, per_host_eval_batch_size))

    vocab_path = hps.vocab_path
    train_ds, eval_ds, _ = mt_pipeline.get_wmt_datasets(
        hps,
        shuffle_seed=shuffle_rng[0],
        n_devices=jax.local_device_count(),
        per_host_batch_size=per_host_batch_size,
        per_host_eval_batch_size=per_host_eval_batch_size,
        vocab_path=vocab_path)

    def train_iterator_fn():
        for batch in iter(train_ds):
            yield data_utils.maybe_pad_batch(data_utils.tf_to_numpy(batch),
                                             per_host_batch_size,
                                             data_format=None,
                                             mask_key='targets')

    def eval_train_epoch(num_batches=None):
        eval_train_iter = iter(train_ds)
        for batch in itertools.islice(eval_train_iter, num_batches):
            yield data_utils.maybe_pad_batch(data_utils.tf_to_numpy(batch),
                                             per_host_batch_size,
                                             data_format=None,
                                             mask_key='targets')

    def valid_epoch(num_batches=None):
        valid_iter = iter(eval_ds)
        for batch in itertools.islice(valid_iter, num_batches):
            yield data_utils.maybe_pad_batch(data_utils.tf_to_numpy(batch),
                                             per_host_eval_batch_size,
                                             data_format=None,
                                             mask_key='targets')

    # pylint: disable=unreachable
    def test_epoch(*args, **kwargs):
        del args
        del kwargs
        return
        yield  # This yield is needed to make this a valid (null) iterator.

    # pylint: enable=unreachable

    return Dataset(train_iterator_fn, eval_train_epoch, valid_epoch,
                   test_epoch)
Example #4
def get_criteo1tb(unused_shuffle_rng, batch_size, eval_batch_size, hps):
    """Get the Criteo 1TB train and eval iterators."""
    process_count = jax.process_count()
    if batch_size % process_count != 0:
        raise ValueError('process_count={} must divide batch_size={}.'.format(
            process_count, batch_size))
    if eval_batch_size is None:
        eval_batch_size = batch_size
    if eval_batch_size % process_count != 0:
        raise ValueError(
            'process_count={} must divide eval_batch_size={}.'.format(
                process_count, eval_batch_size))
    per_host_eval_batch_size = eval_batch_size // process_count
    per_host_batch_size = batch_size // process_count
    train_dataset = _criteo_tsv_reader(
        file_path=hps.train_file_path,
        num_dense_features=hps.num_dense_features,
        vocab_sizes=hps.vocab_sizes,
        batch_size=per_host_batch_size,
        is_training=True)
    train_iterator_fn = lambda: tfds.as_numpy(train_dataset)
    eval_train_dataset = _criteo_tsv_reader(
        file_path=hps.train_file_path,
        num_dense_features=hps.num_dense_features,
        vocab_sizes=hps.vocab_sizes,
        batch_size=per_host_eval_batch_size,
        is_training=False)
    eval_train_epoch = functools.partial(convert_to_numpy_iterator_fn,
                                         tf_dataset=eval_train_dataset)
    eval_dataset = _criteo_tsv_reader(
        file_path=hps.eval_file_path,
        num_dense_features=hps.num_dense_features,
        vocab_sizes=hps.vocab_sizes,
        batch_size=per_host_eval_batch_size,
        is_training=False)
    eval_iterator_fn = functools.partial(convert_to_numpy_iterator_fn,
                                         tf_dataset=eval_dataset)

    # pylint: disable=unreachable
    def test_epoch(*args, **kwargs):
        del args
        del kwargs
        return
        yield  # This yield is needed to make this a valid (null) iterator.

    # pylint: enable=unreachable
    return Dataset(train_iterator_fn, eval_train_epoch, eval_iterator_fn,
                   test_epoch)
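convert_to_numpy_iterator_fn is only referenced here through functools.partial, which binds tf_dataset and leaves the first argument free; since the other *_epoch callables in this file take num_batches, a plausible implementation looks like the sketch below (an assumption, not the actual helper).

import itertools

import tensorflow_datasets as tfds


def convert_to_numpy_iterator_fn_sketch(num_batches=None, tf_dataset=None):
    """Hypothetical version of convert_to_numpy_iterator_fn.

    Yields up to `num_batches` numpy batches from a tf.data.Dataset; with
    num_batches=None the dataset is iterated once in full.
    """
    return itertools.islice(iter(tfds.as_numpy(tf_dataset)), num_batches)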
Example #5
def get_ogbg_molpcba(shuffle_rng, batch_size, eval_batch_size, hps=None):
    """Data generators for ogbg-molpcba."""
    shuffle_buffer_size = 2**15
    shuffle_rng_train, shuffle_rng_eval_train = jax.random.split(shuffle_rng)
    train_ds = _load_dataset('train',
                             should_shuffle=True,
                             shuffle_seed=shuffle_rng_train,
                             shuffle_buffer_size=shuffle_buffer_size)
    eval_train_ds = _load_dataset('train',
                                  should_shuffle=True,
                                  shuffle_seed=shuffle_rng_eval_train,
                                  shuffle_buffer_size=shuffle_buffer_size)
    valid_ds = _load_dataset('validation')
    test_ds = _load_dataset('test')

    iterator_from_ds = functools.partial(
        _get_batch_iterator,
        nodes_per_graph=hps.max_nodes_multiplier,
        edges_per_graph=hps.max_edges_multiplier,
        add_bidirectional_edges=hps.add_bidirectional_edges,
        add_virtual_node=hps.add_virtual_node,
        add_self_loops=hps.add_self_loops)

    def train_iterator_fn():
        return iterator_from_ds(dataset_iter=iter(train_ds),
                                batch_size=batch_size)

    def eval_train_epoch(num_batches=None):
        return itertools.islice(
            iterator_from_ds(dataset_iter=iter(eval_train_ds),
                             batch_size=eval_batch_size), num_batches)

    def valid_epoch(num_batches=None):
        return itertools.islice(
            iterator_from_ds(dataset_iter=iter(valid_ds),
                             batch_size=eval_batch_size), num_batches)

    def test_epoch(num_batches=None):
        return itertools.islice(
            iterator_from_ds(dataset_iter=iter(test_ds),
                             batch_size=eval_batch_size), num_batches)

    return Dataset(train_iterator_fn, eval_train_epoch, valid_epoch,
                   test_epoch)
Example #6
def _get_lm1b(hps, per_host_batch_size, per_host_eval_batch_size, shuffle_rng):
  """Data generators for lm1b."""
  n_devices = jax.local_device_count()
  if per_host_batch_size % n_devices != 0:
    raise ValueError('n_devices={} must divide per_host_batch_size={}.'.format(
        n_devices, per_host_batch_size))

  if per_host_eval_batch_size % n_devices != 0:
    raise ValueError(
        'n_devices={} must divide per_host_eval_batch_size={}.'.format(
            n_devices, per_host_eval_batch_size))

  train_ds, eval_ds = lm1b_input_pipeline_v2.get_lm1b_datasets(
      hps, per_host_batch_size, per_host_eval_batch_size, shuffle_rng)

  def train_iterator_fn():
    for batch in iter(train_ds):
      yield _batch_to_dict(batch)

  def eval_train_epoch(num_batches=None):
    eval_train_iter = iter(train_ds)
    for batch in itertools.islice(eval_train_iter, num_batches):
      yield _batch_to_dict(batch)

  def valid_epoch(num_batches=None):
    valid_iter = iter(eval_ds)
    for batch in itertools.islice(valid_iter, num_batches):
      yield _batch_to_dict(batch)

  # pylint: disable=unreachable
  def test_epoch(*args, **kwargs):
    del args
    del kwargs
    return
    yield  # This yield is needed to make this a valid (null) iterator.

  # pylint: enable=unreachable
  return Dataset(train_iterator_fn, eval_train_epoch, valid_epoch, test_epoch)
Example #7
def _prepare_small_image_datasets(data_train,
                                  data_valid,
                                  data_test,
                                  per_host_batch_size,
                                  per_host_eval_batch_size,
                                  train_size,
                                  rescale,
                                  input_shape,
                                  output_shape,
                                  shuffle_rng,
                                  augment_fn,
                                  is_one_hot=True,
                                  autoencoder=False,
                                  include_example_keys=False):
    """Prepare Dataset using tf.data.Datasets of the different splits."""
    if autoencoder and is_one_hot:
        raise ValueError(
            'One hot encoding cannot be applied to autoencoder datasets.')

    eval_image_iterator = functools.partial(data_utils.image_iterator,
                                            rescale=rescale,
                                            output_shape=output_shape,
                                            is_one_hot=is_one_hot,
                                            autoencoder=autoencoder)

    # Setup the eval_train split as a copy of the training data, in the form of
    # the first `num_train_batches` batches of the data as an np.array.
    eval_train_iterator = eval_image_iterator(data_train)

    num_train_batches = train_size // per_host_batch_size
    eval_train_data = list(
        itertools.islice(eval_train_iterator, 0, num_train_batches))

    eval_train_inputs = jnp.array(
        [batch['inputs'] for batch in eval_train_data])
    eval_train_inputs_shape = (num_train_batches * per_host_batch_size,
                               *input_shape)
    eval_train_inputs = np.reshape(eval_train_inputs, eval_train_inputs_shape)
    eval_train_targets = jnp.array(
        [batch['targets'] for batch in eval_train_data])
    eval_train_output_shape = (num_train_batches * per_host_batch_size,
                               *output_shape)
    eval_train_targets = np.reshape(eval_train_targets,
                                    eval_train_output_shape)

    valid_inputs = jnp.array([])
    valid_targets = jnp.array([])
    valid_example_keys = jnp.array([])
    if data_valid:
        valid_data = next(
            eval_image_iterator(data_valid,
                                include_example_keys=include_example_keys))
        valid_inputs = valid_data['inputs']
        valid_targets = valid_data['targets']
        if include_example_keys:
            valid_example_keys = valid_data['tfds_id']

    test_data = next(eval_image_iterator(data_test))
    test_inputs = jnp.array(test_data['inputs'].astype(np.float32))
    test_targets = test_data['targets']

    # NOTE(dsuo): each host should see the entire dataset, so we remove sharding
    # by host.

    # TODO(gilmer): The simplest way to do this would be to directly yield from
    # tfds.as_numpy(). However, we currently do not know how to properly handle
    # restarts with tfds. We'd like the train shuffle to depend on the epoch
    # number, so that every time we generate epoch 10 it yields the same
    # pseudorandom order (a fold_in-based sketch of this appears after this
    # function).

    train_iterator_fn = functools.partial(data_utils.image_iterator,
                                          data_train,
                                          rescale=rescale,
                                          output_shape=output_shape,
                                          is_one_hot=is_one_hot,
                                          autoencoder=autoencoder,
                                          shuffle_rng=shuffle_rng,
                                          augment_fn=augment_fn)

    eval_train_epoch = functools.partial(_eval_batches, eval_train_inputs,
                                         eval_train_targets,
                                         per_host_eval_batch_size)
    valid_epoch = functools.partial(_eval_batches,
                                    valid_inputs,
                                    valid_targets,
                                    per_host_eval_batch_size,
                                    valid_example_keys=valid_example_keys)
    test_epoch = functools.partial(_eval_batches, test_inputs, test_targets,
                                   per_host_eval_batch_size)

    return Dataset(train_iterator_fn, eval_train_epoch, valid_epoch,
                   test_epoch)
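The TODO above asks for a train shuffle that depends on the epoch number, so that regenerating a given epoch reproduces the same pseudorandom order. A minimal sketch of one way to get that with jax.random.fold_in follows; the helper name is made up for illustration, and feeding the resulting permutation back into tf.data is left out.

import jax


def epoch_shuffle_indices(shuffle_rng, epoch, num_examples):
    """Deterministic per-epoch shuffle: the same epoch gives the same order.

    Folding the epoch number into the base rng yields an epoch-specific key,
    so a restart that replays epoch 10 sees exactly the order it saw before.
    """
    epoch_rng = jax.random.fold_in(shuffle_rng, epoch)
    return jax.random.permutation(epoch_rng, num_examples)


# Illustrative check: both calls produce the identical permutation.
# order_a = epoch_shuffle_indices(rng, epoch=10, num_examples=50000)
# order_b = epoch_shuffle_indices(rng, epoch=10, num_examples=50000)
# bool((order_a == order_b).all())  # True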
Example #8
def get_nqm_noise(shuffle_rng, batch_size, eval_batch_size, hps=None):
    """Returns the noise seed for the nqm model.

  NOTE: This dataset is only meant to be used with the nqm model.
  This just generates isotropic Gaussian noise of the desired dimension.
  The nqm model will then multiple this noise by a matrix D, with the properly
  that D^T D = C. This yields noise with gradient covariance C.

  Args:
    shuffle_rng: Not used.
    batch_size: The global train batch size, used to determine the batch size
      yielded from train_epoch().
    eval_batch_size: Not used.
    hps: Hparams object. We only refer to hps.input_shape to determine the
      dimension of the noise.
  Returns:
    train_epoch, eval_train_epoch, valid_epoch, test_epoch: three generators.
      Only train_epoch is used.
  """
    del eval_batch_size

    per_host_batch_size = batch_size // jax.process_count()
    # We only use the first part of the seed, which may result in slightly more
    # rng collisions than normal.
    train_rng = np.random.RandomState(seed=shuffle_rng[0])
    eval_rng = np.random.RandomState(seed=shuffle_rng[0])

    def train_iterator_fn():
        while True:
            yield {
                'inputs':
                train_rng.normal(size=(per_host_batch_size, *hps.input_shape))
            }

    def eval_train_epoch(num_batches):
        for _ in range(num_batches):
            yield {
                'inputs':
                eval_rng.normal(size=(per_host_batch_size, *hps.input_shape))
            }

    # pylint: disable=unreachable
    def valid_epoch(*args, **kwargs):
        del args
        del kwargs
        return
        yield  # This yield is needed to make this a valid (null) iterator.

    # pylint: enable=unreachable

    # pylint: disable=unreachable
    def test_epoch(*args, **kwargs):
        del args
        del kwargs
        return
        yield  # This yield is needed to make this a valid (null) iterator.

    # pylint: enable=unreachable

    return Dataset(train_iterator_fn, eval_train_epoch, valid_epoch,
                   test_epoch)
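For the construction hinted at in the docstring: if z is a row of isotropic Gaussian noise (covariance I), then z @ D has covariance D^T D, so any D with D^T D = C gives noise with gradient covariance C; the transpose of a Cholesky factor of C is one such D. A small numpy sketch of that construction follows (illustrative only, not the nqm model's actual code).

import numpy as np


def noise_with_covariance(rng, c_matrix, batch_size):
    """Illustrative: draw noise whose per-row covariance is c_matrix.

    np.linalg.cholesky returns lower-triangular L with L @ L.T == C, so
    D = L.T satisfies D.T @ D == C, and rows of z @ D have covariance C.
    """
    dim = c_matrix.shape[0]
    d_matrix = np.linalg.cholesky(c_matrix).T
    z = rng.normal(size=(batch_size, dim))  # isotropic noise, covariance I
    return z @ d_matrix


# Usage sketch, matching the RandomState style used above:
# rng = np.random.RandomState(seed=0)
# batch = noise_with_covariance(rng, c_matrix, batch_size=256)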