def get_imagenet(shuffle_rng, batch_size, eval_batch_size, hps):
    """Data generators for imagenet."""
    per_host_batch_size = batch_size // jax.process_count()
    per_host_eval_batch_size = eval_batch_size // jax.process_count()

    image_size = hps.input_shape[0]

    # TODO(gilmer) Currently the training data is not deterministic.
    logging.info('Loading train split')
    train_ds = load_split(per_host_batch_size,
                          'train',
                          hps=hps,
                          image_size=image_size,
                          shuffle_rng=shuffle_rng)
    train_ds = tfds.as_numpy(train_ds)
    logging.info('Loading eval_train split')
    eval_train_ds = load_split(per_host_eval_batch_size,
                               'eval_train',
                               hps=hps,
                               image_size=image_size)
    eval_train_ds = tfds.as_numpy(eval_train_ds)
    logging.info('Loading eval split')
    eval_ds = load_split(per_host_eval_batch_size,
                         'valid',
                         hps=hps,
                         image_size=image_size)
    eval_ds = tfds.as_numpy(eval_ds)

    def train_iterator_fn():
        return train_ds

    def eval_train_epoch(num_batches=None):
        # Note: this uses per_host_eval_batch_size, not per_host_batch_size.
        for batch in itertools.islice(eval_train_ds, num_batches):
            yield data_utils.maybe_pad_batch(batch, per_host_eval_batch_size)

    def valid_epoch(num_batches=None):
        for batch in itertools.islice(eval_ds, num_batches):
            yield data_utils.maybe_pad_batch(batch, per_host_eval_batch_size)

    # pylint: disable=unreachable
    def test_epoch(*args, **kwargs):
        del args
        del kwargs
        return
        yield  # This yield is needed to make this a valid (null) iterator.

    # pylint: enable=unreachable

    return data_utils.Dataset(train_iterator_fn, eval_train_epoch, valid_epoch,
                              test_epoch)
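

# A minimal usage sketch, not part of the original listing: it assumes the
# returned data_utils.Dataset exposes the four iterator callables as
# attributes with these names, and that `train_step`/`eval_step` are
# hypothetical per-batch functions supplied by the caller.
import itertools


def _example_usage(shuffle_rng, batch_size, eval_batch_size, hps,
                   train_step, eval_step, num_train_steps=100):
    dataset = get_imagenet(shuffle_rng, batch_size, eval_batch_size, hps)
    # Training: take a fixed number of batches from the training stream.
    for batch in itertools.islice(dataset.train_iterator_fn(),
                                  num_train_steps):
        train_step(batch)
    # Evaluation: one full pass over the validation split.
    for batch in dataset.valid_epoch(num_batches=None):
        eval_step(batch)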
Example #2
def get_fastmri(shuffle_rng, batch_size, eval_batch_size, hps):
    """FastMRI dataset.

  Args:
    shuffle_rng: rng for shuffling.
    batch_size: batch size.
    eval_batch_size: batch size for eval.
    hps: hyperparameters.

  Returns:
    An init2winit Dataset.
  """
    per_host_batch_size = batch_size // jax.process_count()
    per_host_eval_batch_size = eval_batch_size // jax.process_count()

    train_ds = load_split(per_host_batch_size, 'train', hps, shuffle_rng)
    train_ds = tfds.as_numpy(train_ds)

    # NOTE(dsuo): fastMRI has fixed randomness for eval.
    eval_train_ds = load_split(per_host_eval_batch_size, 'eval_train', hps,
                               shuffle_rng)
    eval_train_ds = tfds.as_numpy(eval_train_ds)
    eval_ds = load_split(per_host_eval_batch_size, 'val', hps, shuffle_rng)
    eval_ds = tfds.as_numpy(eval_ds)

    def train_iterator_fn():
        return train_ds

    def eval_train_epoch(num_batches=None):
        for batch in itertools.islice(eval_train_ds, num_batches):
            yield data_utils.maybe_pad_batch(batch, per_host_eval_batch_size)

    def valid_epoch(num_batches=None):
        for batch in itertools.islice(eval_ds, num_batches):
            yield data_utils.maybe_pad_batch(batch, per_host_eval_batch_size)

    # pylint: disable=unreachable
    def test_epoch(*args, **kwargs):
        del args
        del kwargs
        return
        yield  # This yield is needed to make this a valid (null) iterator.

    # pylint: enable=unreachable

    return data_utils.Dataset(train_iterator_fn, eval_train_epoch, valid_epoch,
                              test_epoch)
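

# Illustrative only: a standalone sketch of the padding that
# data_utils.maybe_pad_batch is used for above, i.e. bringing a smaller final
# eval batch up to `desired_batch_size` and marking the padded rows with zero
# weights so they do not contribute to metrics. This is not the library
# implementation, just the idea.
import numpy as np


def _pad_batch_sketch(batch, desired_batch_size):
    actual_batch_size = next(iter(batch.values())).shape[0]
    pad_amount = desired_batch_size - actual_batch_size
    if pad_amount <= 0:
        return batch
    padded = {
        k: np.concatenate(
            [v, np.zeros((pad_amount,) + v.shape[1:], dtype=v.dtype)])
        for k, v in batch.items()
    }
    # Real examples get weight 1, padded examples weight 0.
    padded['weights'] = np.concatenate(
        [np.ones(actual_batch_size), np.zeros(pad_amount)]).astype(np.float32)
    return padded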
Example #3
def get_fake(shuffle_rng, batch_size, eval_batch_size, hps=None):
    """Data generators for imagenet."""
    del shuffle_rng
    per_host_batch_size = batch_size // jax.process_count()
    per_host_eval_batch_size = eval_batch_size // jax.process_count()

    fake_train_batch = get_fake_batch(per_host_batch_size, hps.input_shape,
                                      hps.output_shape[0])
    fake_test_batch = get_fake_batch(per_host_eval_batch_size, hps.input_shape,
                                     hps.output_shape[0])

    def train_iterator_fn():
        while True:
            yield fake_train_batch

    def valid_epoch(epoch, num_batches=None):
        del num_batches
        del epoch
        # Note that we use integer division (//) because we do not support
        # partial batching for the fake dataset.
        for _ in range(hps.valid_size // eval_batch_size):
            yield fake_test_batch

    # pylint: disable=unreachable
    def eval_train_epoch(*args, **kwargs):
        del args
        del kwargs
        return
        yield  # This yield is needed to make this a valid (null) iterator.

    # pylint: enable=unreachable
    # pylint: disable=unreachable

    def test_epoch(*args, **kwargs):
        del args
        del kwargs
        return
        yield  # This yield is needed to make this a valid (null) iterator.

    # pylint: enable=unreachable

    return data_utils.Dataset(train_iterator_fn, eval_train_epoch, valid_epoch,
                              test_epoch)
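

# get_fake_batch is not shown in this listing. A plausible sketch, assuming it
# produces a constant batch in the same 'inputs'/'targets'/'weights' dict
# format that the other pipelines yield (this is an assumption, not the
# library's definition):
import numpy as np


def _fake_batch_sketch(batch_size, input_shape, num_outputs):
    return {
        'inputs': np.ones((batch_size,) + tuple(input_shape),
                          dtype=np.float32),
        'targets': np.ones((batch_size, num_outputs), dtype=np.float32),
        'weights': np.ones((batch_size,), dtype=np.float32),
    }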
Example #4
def get_mlperf_imagenet(rng,
                        batch_size,
                        eval_batch_size,
                        hps=None):
  """Data generators for imagenet.

  Args:
    rng: RNG seed that is split into a shuffle seed and a seed that is folded
      into a per-example seed.
    batch_size: the *global* batch size used for training.
    eval_batch_size: the *global* batch size used for evaluation.
    hps: the hparams for the experiment, only required field is valid_size.

  Returns:
    A data_utils.Dataset for the MLPerf version of ImageNet.
  """
  if batch_size % jax.device_count() != 0:
    raise ValueError(
        'Require batch_size % jax.device_count() == 0, received '
        'batch_size={}, device_count={}.'.format(
            batch_size, jax.device_count()))
  if eval_batch_size % jax.device_count() != 0:
    raise ValueError(
        'Require eval_batch_size % jax.device_count() == 0, received '
        'eval_batch_size={}, device_count={}.'.format(
            eval_batch_size, jax.device_count()))
  host_batch_size = batch_size // jax.process_count()
  eval_host_batch_size = eval_batch_size // jax.process_count()

  max_eval_steps = hps.valid_size // eval_batch_size + 1

  input_dtype = tf.bfloat16
  shuffle_buffer_size = 16384

  train_ds = mlperf_input_pipeline.load_split(
      host_batch_size,
      dtype=input_dtype,
      split='train',
      rng=rng,
      shuffle_size=shuffle_buffer_size,
      preprocess_fn=_preprocess_fn)

  eval_train_ds = mlperf_input_pipeline.load_split(
      host_batch_size,
      dtype=input_dtype,
      split='eval_train',
      rng=rng,
      shuffle_size=shuffle_buffer_size,
      preprocess_fn=_preprocess_fn)

  eval_ds = mlperf_input_pipeline.load_split(
      eval_host_batch_size,
      dtype=input_dtype,
      split='validation',
      rng=rng,
      shuffle_size=shuffle_buffer_size,
      preprocess_fn=_preprocess_fn)

  # We cannot use tfds.as_numpy because it calls tensor.numpy(), which makes an
  # additional copy of the tensor; instead we call tensor._numpy() below.
  def train_iterator_fn():
    return data_utils.iterator_as_numpy(iter(train_ds))

  def eval_train_epoch(num_batches=None):
    if num_batches is None:
      num_batches = 0
    eval_train_iter = iter(eval_train_ds)
    np_iter = data_utils.iterator_as_numpy(
        itertools.islice(eval_train_iter, num_batches))
    for batch in np_iter:
      yield data_utils.maybe_pad_batch(batch, eval_host_batch_size)

  def valid_epoch(num_batches=None):
    if num_batches is None:
      num_batches = max_eval_steps
    valid_iter = iter(eval_ds)
    np_iter = data_utils.iterator_as_numpy(
        itertools.islice(valid_iter, num_batches))
    for batch in np_iter:
      yield data_utils.maybe_pad_batch(batch, eval_host_batch_size)

  # pylint: disable=unreachable
  def test_epoch(*args, **kwargs):
    del args
    del kwargs
    return
    yield  # This yield is needed to make this a valid (null) iterator.
  # pylint: enable=unreachable

  return data_utils.Dataset(
      train_iterator_fn,
      eval_train_epoch,
      valid_epoch,
      test_epoch)
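

# A rough sketch of the zero-copy conversion that the comment above motivates
# (not the actual data_utils.iterator_as_numpy): instead of tfds.as_numpy,
# which goes through Tensor.numpy() and copies, each eager tensor in the batch
# is converted via its protected _numpy() accessor.
import jax


def _iterator_as_numpy_sketch(tf_iterator):
  for batch in tf_iterator:
    # pylint: disable=protected-access
    yield jax.tree_util.tree_map(lambda t: t._numpy(), batch)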
Example #5
def get_imagenet(shuffle_rng,
                 batch_size,
                 eval_batch_size,
                 hps):
  """Data generators for imagenet."""
  per_host_batch_size = batch_size
  per_host_eval_batch_size = eval_batch_size

  image_size = hps.input_shape[0]
  num_classes = 1000

  # TODO(gilmer) Currently the training data is not deterministic.
  logging.info('Loading train split')
  train_ds = load_split(
      per_host_batch_size,
      'train',
      hps=hps,
      image_size=image_size,
      shuffle_rng=shuffle_rng)
  train_ds = tfds.as_numpy(train_ds)
  logging.info('Loading eval_train split')
  eval_train_ds = load_split(
      per_host_eval_batch_size, 'eval_train', hps=hps, image_size=image_size)
  eval_train_ds = tfds.as_numpy(eval_train_ds)
  logging.info('Loading eval split')
  eval_ds = load_split(
      per_host_eval_batch_size, 'valid', hps=hps, image_size=image_size)
  eval_ds = tfds.as_numpy(eval_ds)

  def train_iterator_fn():
    if hps.use_mixup:
      # NOTE(dsuo): using `fold_in` so as not to disturb shuffle_rng.
      mixup_rng = jax.random.fold_in(shuffle_rng, jax.process_index())

      # NOTE(dsuo): synchronize weights across hosts. More generally, we
      # should consider passing a global key into `get_dataset`.
      # NOTE(dsuo): in order to preserve the dirichlet distribution, we set
      # mixup_rng to be non-zero for only the 0th device on the 0th host and
      # psum so each host gets the same value.
      mixup_rng = jnp.tile(mixup_rng, (jax.local_device_count(), 1))
      if jax.process_index() == 0:
        one_hot = jax.nn.one_hot(
            jnp.array([0]), jax.local_device_count(), dtype=mixup_rng.dtype)
        mixup_rng *= one_hot.T
      else:
        mixup_rng = jnp.zeros_like(mixup_rng)
      mixup_rng = jax.pmap(lambda x: jax.lax.psum(x, 'devices'), 'devices')(
          mixup_rng)
      mixup_rng = mixup_rng.at[0].get()

    for batch in iter(train_ds):
      image = batch['image']
      targets = np.eye(num_classes)[batch['label']]
      if hps.use_mixup:
        mixup_rng = jax.random.fold_in(mixup_rng, 0)
        (image, targets), _ = image_preprocessing.mixup_general(
            mixup_rng,
            image,
            targets,
            alpha=hps.mixup.alpha,
            n=2)
      yield {
          'inputs': image,
          'targets': targets,
          'weights': np.ones(per_host_batch_size, dtype=image.dtype)
      }

  def eval_train_epoch(num_batches=None):
    # Note: this uses per_host_eval_batch_size, not per_host_batch_size.
    eval_train_iter = iter(eval_train_ds)
    for batch in itertools.islice(eval_train_iter, num_batches):
      batch_dict = {
          'inputs': batch['image'],
          'targets': np.eye(num_classes)[batch['label']],
      }
      if hps.get('include_example_keys'):
        batch_dict['example_key'] = batch['example_key']
      yield data_utils.maybe_pad_batch(batch_dict, per_host_eval_batch_size)

  def valid_epoch(num_batches=None):
    valid_iter = iter(eval_ds)
    for batch in itertools.islice(valid_iter, num_batches):
      batch_dict = {
          'inputs': batch['image'],
          'targets': np.eye(num_classes)[batch['label']],
      }
      if hps.get('include_example_keys'):
        batch_dict['example_key'] = batch['example_key']
      yield data_utils.maybe_pad_batch(batch_dict, per_host_eval_batch_size)

  # pylint: disable=unreachable
  def test_epoch(*args, **kwargs):
    del args
    del kwargs
    return
    yield  # This yield is needed to make this a valid (null) iterator.
  # pylint: enable=unreachable

  return data_utils.Dataset(
      train_iterator_fn, eval_train_epoch, valid_epoch, test_epoch)
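

# Illustrative sketch of two-way mixup, not the library's
# image_preprocessing.mixup_general: a single Beta(alpha, alpha) coefficient
# blends each example (and its one-hot targets) with a randomly chosen partner
# from the same batch.
import jax
import jax.numpy as jnp


def _mixup_sketch(rng, images, targets, alpha=0.2):
  coef_rng, perm_rng = jax.random.split(rng)
  lam = jax.random.beta(coef_rng, alpha, alpha)
  perm = jax.random.permutation(perm_rng, images.shape[0])
  mixed_images = lam * images + (1.0 - lam) * images[perm]
  mixed_targets = lam * targets + (1.0 - lam) * targets[perm]
  return mixed_images, mixed_targets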
Example #6
def get_imagenet(shuffle_rng, batch_size, eval_batch_size, hps):
    """Data generators for imagenet."""
    per_host_batch_size = batch_size // jax.process_count()
    per_host_eval_batch_size = eval_batch_size // jax.process_count()

    image_size = hps.input_shape[0]
    num_classes = 1000

    # TODO(gilmer) Currently the training data is not deterministic.
    logging.info('Loading train split')
    train_ds = load_split(per_host_batch_size,
                          'train',
                          hps=hps,
                          image_size=image_size,
                          shuffle_rng=shuffle_rng)
    train_ds = tfds.as_numpy(train_ds)
    logging.info('Loading eval_train split')
    eval_train_ds = load_split(per_host_eval_batch_size,
                               'eval_train',
                               hps=hps,
                               image_size=image_size)
    eval_train_ds = tfds.as_numpy(eval_train_ds)
    logging.info('Loading eval split')
    eval_ds = load_split(per_host_eval_batch_size,
                         'valid',
                         hps=hps,
                         image_size=image_size)
    eval_ds = tfds.as_numpy(eval_ds)

    def train_iterator_fn():
        for batch in iter(train_ds):
            image = batch['image']
            yield {
                'inputs': image,
                'targets': np.eye(num_classes)[batch['label']],
                'weights': np.ones(per_host_batch_size, dtype=image.dtype)
            }

    def eval_train_epoch(num_batches=None):
        # Note: this uses per_host_eval_batch_size, not per_host_batch_size.
        eval_train_iter = iter(eval_train_ds)
        for batch in itertools.islice(eval_train_iter, num_batches):
            batch_dict = {
                'inputs': batch['image'],
                'targets': np.eye(num_classes)[batch['label']],
            }
            if hps.get('include_example_keys'):
                batch_dict['example_key'] = batch['example_key']
            yield data_utils.maybe_pad_batch(batch_dict,
                                             per_host_eval_batch_size)

    def valid_epoch(num_batches=None):
        valid_iter = iter(eval_ds)
        for batch in itertools.islice(valid_iter, num_batches):
            batch_dict = {
                'inputs': batch['image'],
                'targets': np.eye(num_classes)[batch['label']],
            }
            if hps.get('include_example_keys'):
                batch_dict['example_key'] = batch['example_key']
            yield data_utils.maybe_pad_batch(batch_dict,
                                             per_host_eval_batch_size)

    # pylint: disable=unreachable
    def test_epoch(*args, **kwargs):
        del args
        del kwargs
        return
        yield  # This yield is needed to make this a valid (null) iterator.

    # pylint: enable=unreachable

    return data_utils.Dataset(train_iterator_fn, eval_train_epoch, valid_epoch,
                              test_epoch)
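

# The `return` followed by an unreachable `yield` in the test_epoch functions
# above is the standard way to write a generator that yields nothing: the bare
# yield makes Python treat the function as a generator function, and the early
# return ends iteration immediately. A tiny self-contained check:
def _empty_generator():
    return
    yield  # pylint: disable=unreachable


assert list(_empty_generator()) == []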