def get_imagenet(shuffle_rng, batch_size, eval_batch_size, hps):
  """Data generators for ImageNet."""
  per_host_batch_size = batch_size // jax.process_count()
  per_host_eval_batch_size = eval_batch_size // jax.process_count()

  image_size = hps.input_shape[0]

  # TODO(gilmer): Currently the training data is not deterministic.
  logging.info('Loading train split')
  train_ds = load_split(
      per_host_batch_size,
      'train',
      hps=hps,
      image_size=image_size,
      shuffle_rng=shuffle_rng)
  train_ds = tfds.as_numpy(train_ds)

  logging.info('Loading eval_train split')
  eval_train_ds = load_split(
      per_host_eval_batch_size, 'eval_train', hps=hps, image_size=image_size)
  eval_train_ds = tfds.as_numpy(eval_train_ds)

  logging.info('Loading eval split')
  eval_ds = load_split(
      per_host_eval_batch_size, 'valid', hps=hps, image_size=image_size)
  eval_ds = tfds.as_numpy(eval_ds)

  def train_iterator_fn():
    return train_ds

  def eval_train_epoch(num_batches=None):
    for batch in itertools.islice(eval_train_ds, num_batches):
      yield data_utils.maybe_pad_batch(batch, per_host_eval_batch_size)

  def valid_epoch(num_batches=None):
    for batch in itertools.islice(eval_ds, num_batches):
      yield data_utils.maybe_pad_batch(batch, per_host_eval_batch_size)

  # pylint: disable=unreachable
  def test_epoch(*args, **kwargs):
    del args
    del kwargs
    return
    yield  # This yield is needed to make this a valid (null) iterator.
  # pylint: enable=unreachable

  return data_utils.Dataset(
      train_iterator_fn, eval_train_epoch, valid_epoch, test_epoch)
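

# The epoch generators above rely on `data_utils.maybe_pad_batch` to round the
# final partial batch up to a full per-host batch. The helper below is an
# illustrative, self-contained sketch of that padding behavior (pad each array
# along the leading axis and mask padded rows via a 'weights' key); it is an
# assumption about what the real helper does, not the init2winit
# implementation, and the function name is hypothetical.
import numpy as np  # Duplicates the usual module-level import; harmless.


def _illustrative_pad_batch(batch, desired_batch_size):
  """Pads each array in `batch` along axis 0 (sketch only)."""
  actual_batch_size = next(iter(batch.values())).shape[0]
  pad_amount = desired_batch_size - actual_batch_size
  padded = {
      key: np.pad(value, [(0, pad_amount)] + [(0, 0)] * (value.ndim - 1))
      for key, value in batch.items()
  }
  # Real examples get weight 1.0; padded rows get weight 0.0 so they are
  # ignored by losses and metrics.
  padded['weights'] = np.concatenate(
      [np.ones(actual_batch_size), np.zeros(pad_amount)]).astype(np.float32)
  return padded
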
def get_fastmri(shuffle_rng, batch_size, eval_batch_size, hps):
  """FastMRI dataset.

  Args:
    shuffle_rng: rng for shuffling.
    batch_size: batch size.
    eval_batch_size: batch size for eval.
    hps: hyperparameters.

  Returns:
    An init2winit Dataset.
  """
  per_host_batch_size = batch_size // jax.process_count()
  per_host_eval_batch_size = eval_batch_size // jax.process_count()

  train_ds = load_split(per_host_batch_size, 'train', hps, shuffle_rng)
  train_ds = tfds.as_numpy(train_ds)

  # NOTE(dsuo): fastMRI has fixed randomness for eval.
  eval_train_ds = load_split(
      per_host_eval_batch_size, 'eval_train', hps, shuffle_rng)
  eval_train_ds = tfds.as_numpy(eval_train_ds)

  eval_ds = load_split(per_host_eval_batch_size, 'val', hps, shuffle_rng)
  eval_ds = tfds.as_numpy(eval_ds)

  def train_iterator_fn():
    return train_ds

  def eval_train_epoch(num_batches=None):
    for batch in itertools.islice(eval_train_ds, num_batches):
      yield data_utils.maybe_pad_batch(batch, per_host_eval_batch_size)

  def valid_epoch(num_batches=None):
    for batch in itertools.islice(eval_ds, num_batches):
      yield data_utils.maybe_pad_batch(batch, per_host_eval_batch_size)

  # pylint: disable=unreachable
  def test_epoch(*args, **kwargs):
    del args
    del kwargs
    return
    yield  # This yield is needed to make this a valid (null) iterator.
  # pylint: enable=unreachable

  return data_utils.Dataset(
      train_iterator_fn, eval_train_epoch, valid_epoch, test_epoch)
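

# Why the `return` / unreachable `yield` pattern in `test_epoch` works: the
# mere presence of `yield` anywhere in the body makes Python compile the
# function as a generator, and hitting `return` first means the generator
# stops immediately without producing anything. A minimal standalone
# demonstration (the helper name is hypothetical, not part of the module):
def _null_iterator():
  return
  yield  # pylint: disable=unreachable
# Usage: list(_null_iterator()) == []  -- iterating it yields nothing.
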
def get_fake(shuffle_rng, batch_size, eval_batch_size, hps=None):
  """Data generators for a fake, fixed-batch dataset."""
  del shuffle_rng
  per_host_batch_size = batch_size // jax.process_count()
  per_host_eval_batch_size = eval_batch_size // jax.process_count()

  fake_train_batch = get_fake_batch(
      per_host_batch_size, hps.input_shape, hps.output_shape[0])
  fake_test_batch = get_fake_batch(
      per_host_eval_batch_size, hps.input_shape, hps.output_shape[0])

  def train_iterator_fn():
    while True:
      yield fake_train_batch

  def valid_epoch(num_batches=None):
    del num_batches
    # Note that we use // because we do not support partial batches for the
    # fake dataset.
    for _ in range(hps.valid_size // eval_batch_size):
      yield fake_test_batch

  # pylint: disable=unreachable
  def eval_train_epoch(*args, **kwargs):
    del args
    del kwargs
    return
    yield  # This yield is needed to make this a valid (null) iterator.

  def test_epoch(*args, **kwargs):
    del args
    del kwargs
    return
    yield  # This yield is needed to make this a valid (null) iterator.
  # pylint: enable=unreachable

  return data_utils.Dataset(
      train_iterator_fn, eval_train_epoch, valid_epoch, test_epoch)
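

# A hedged usage sketch for `get_fake`, e.g. as a pipeline smoke test. The
# hparam fields and values below (input_shape, output_shape, valid_size, the
# batch sizes) are assumptions for illustration, the SimpleNamespace is a
# stand-in for the real experiment config object, and it assumes
# data_utils.Dataset exposes its four constructor arguments as attributes
# (e.g. a namedtuple). The function name is hypothetical.
def _example_fake_dataset_smoke_test():
  from types import SimpleNamespace  # Stand-in for the real hparams object.
  hps = SimpleNamespace(
      input_shape=(224, 224, 3), output_shape=(1000,), valid_size=64)
  dataset = get_fake(
      shuffle_rng=None, batch_size=32, eval_batch_size=32, hps=hps)
  first_train_batch = next(dataset.train_iterator_fn())
  # valid_size // eval_batch_size == 2 full batches, no partial batches.
  valid_batches = list(dataset.valid_epoch())
  return first_train_batch, valid_batches
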
def get_mlperf_imagenet(rng, batch_size, eval_batch_size, hps=None):
  """Data generators for the MLPerf version of ImageNet.

  Args:
    rng: RNG seed that is split into a shuffle seed and a seed that is folded
      into a per-example seed.
    batch_size: the *global* batch size used for training.
    eval_batch_size: the *global* batch size used for evaluation.
    hps: the hparams for the experiment; the only required field is
      valid_size.

  Returns:
    A data_utils.Dataset for the MLPerf version of ImageNet.
  """
  if batch_size % jax.device_count() != 0:
    raise ValueError(
        'Require batch_size % jax.device_count() == 0, received '
        'batch_size={}, device_count={}.'.format(
            batch_size, jax.device_count()))
  if eval_batch_size % jax.device_count() != 0:
    raise ValueError(
        'Require eval_batch_size % jax.device_count() == 0, received '
        'eval_batch_size={}, device_count={}.'.format(
            eval_batch_size, jax.device_count()))

  host_batch_size = batch_size // jax.host_count()
  eval_host_batch_size = eval_batch_size // jax.host_count()

  max_eval_steps = hps.valid_size // eval_batch_size + 1

  input_dtype = tf.bfloat16
  shuffle_buffer_size = 16384

  train_ds = mlperf_input_pipeline.load_split(
      host_batch_size,
      dtype=input_dtype,
      split='train',
      rng=rng,
      shuffle_size=shuffle_buffer_size,
      preprocess_fn=_preprocess_fn)
  eval_train_ds = mlperf_input_pipeline.load_split(
      host_batch_size,
      dtype=input_dtype,
      split='eval_train',
      rng=rng,
      shuffle_size=shuffle_buffer_size,
      preprocess_fn=_preprocess_fn)
  eval_ds = mlperf_input_pipeline.load_split(
      eval_host_batch_size,
      dtype=input_dtype,
      split='validation',
      rng=rng,
      shuffle_size=shuffle_buffer_size,
      preprocess_fn=_preprocess_fn)

  # We cannot use tfds.as_numpy because it calls tensor.numpy(), which makes
  # an additional copy of the tensor; instead we call tensor._numpy() below.
  def train_iterator_fn():
    return data_utils.iterator_as_numpy(iter(train_ds))

  def eval_train_epoch(num_batches=None):
    if num_batches is None:
      num_batches = 0
    eval_train_iter = iter(eval_train_ds)
    np_iter = data_utils.iterator_as_numpy(
        itertools.islice(eval_train_iter, num_batches))
    for batch in np_iter:
      yield data_utils.maybe_pad_batch(batch, eval_host_batch_size)

  def valid_epoch(num_batches=None):
    if num_batches is None:
      num_batches = max_eval_steps
    valid_iter = iter(eval_ds)
    np_iter = data_utils.iterator_as_numpy(
        itertools.islice(valid_iter, num_batches))
    for batch in np_iter:
      yield data_utils.maybe_pad_batch(batch, eval_host_batch_size)

  # pylint: disable=unreachable
  def test_epoch(*args, **kwargs):
    del args
    del kwargs
    return
    yield  # This yield is needed to make this a valid (null) iterator.
  # pylint: enable=unreachable

  return data_utils.Dataset(
      train_iterator_fn, eval_train_epoch, valid_epoch, test_epoch)
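

# The comment above motivates `data_utils.iterator_as_numpy`. The generator
# below is only a sketch of that idea, assuming the real helper walks each
# batch structure and calls the private EagerTensor._numpy() accessor (which
# avoids the extra copy made by .numpy()); the actual helper may differ, and
# the function name is hypothetical. It relies on the module's existing `jax`
# import.
def _iterator_as_numpy_sketch(tensor_iterator):
  """Yields each TF batch converted to NumPy without extra copies (sketch)."""
  for batch in tensor_iterator:
    # pylint: disable=protected-access
    yield jax.tree_util.tree_map(lambda t: t._numpy(), batch)
    # pylint: enable=protected-access
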
def get_imagenet(shuffle_rng, batch_size, eval_batch_size, hps):
  """Data generators for ImageNet."""
  per_host_batch_size = batch_size
  per_host_eval_batch_size = eval_batch_size

  image_size = hps.input_shape[0]
  num_classes = 1000

  # TODO(gilmer): Currently the training data is not deterministic.
  logging.info('Loading train split')
  train_ds = load_split(
      per_host_batch_size,
      'train',
      hps=hps,
      image_size=image_size,
      shuffle_rng=shuffle_rng)
  train_ds = tfds.as_numpy(train_ds)

  logging.info('Loading eval_train split')
  eval_train_ds = load_split(
      per_host_eval_batch_size, 'eval_train', hps=hps, image_size=image_size)
  eval_train_ds = tfds.as_numpy(eval_train_ds)

  logging.info('Loading eval split')
  eval_ds = load_split(
      per_host_eval_batch_size, 'valid', hps=hps, image_size=image_size)
  eval_ds = tfds.as_numpy(eval_ds)

  def train_iterator_fn():
    if hps.use_mixup:
      # NOTE(dsuo): using `fold_in` so as not to disturb shuffle_rng.
      mixup_rng = jax.random.fold_in(shuffle_rng, jax.process_index())

      # NOTE(dsuo): synchronize weights across hosts. More generally, we
      # should consider passing a global key into `get_dataset`.
      # NOTE(dsuo): in order to preserve the Dirichlet distribution, we set
      # mixup_rng to be non-zero only for the 0th device on the 0th host and
      # psum so each host gets the same value.
      mixup_rng = jnp.tile(mixup_rng, (jax.local_device_count(), 1))
      if jax.process_index() == 0:
        one_hot = jax.nn.one_hot(
            jnp.array([0]), jax.local_device_count(), dtype=mixup_rng.dtype)
        mixup_rng *= one_hot.T
      else:
        mixup_rng = jnp.zeros_like(mixup_rng)
      mixup_rng = jax.pmap(lambda x: jax.lax.psum(x, 'devices'), 'devices')(
          mixup_rng)
      mixup_rng = mixup_rng.at[0].get()

    for batch in iter(train_ds):
      image = batch['image']
      targets = np.eye(num_classes)[batch['label']]

      if hps.use_mixup:
        mixup_rng = jax.random.fold_in(mixup_rng, 0)
        (image, targets), _ = image_preprocessing.mixup_general(
            mixup_rng, image, targets, alpha=hps.mixup.alpha, n=2)

      yield {
          'inputs': image,
          'targets': targets,
          'weights': np.ones(per_host_batch_size, dtype=image.dtype),
      }

  def eval_train_epoch(num_batches=None):
    eval_train_iter = iter(eval_train_ds)
    for batch in itertools.islice(eval_train_iter, num_batches):
      batch_dict = {
          'inputs': batch['image'],
          'targets': np.eye(num_classes)[batch['label']],
      }
      if hps.get('include_example_keys'):
        batch_dict['example_key'] = batch['example_key']
      yield data_utils.maybe_pad_batch(batch_dict, per_host_eval_batch_size)

  def valid_epoch(num_batches=None):
    valid_iter = iter(eval_ds)
    for batch in itertools.islice(valid_iter, num_batches):
      batch_dict = {
          'inputs': batch['image'],
          'targets': np.eye(num_classes)[batch['label']],
      }
      if hps.get('include_example_keys'):
        batch_dict['example_key'] = batch['example_key']
      yield data_utils.maybe_pad_batch(batch_dict, per_host_eval_batch_size)

  # pylint: disable=unreachable
  def test_epoch(*args, **kwargs):
    del args
    del kwargs
    return
    yield  # This yield is needed to make this a valid (null) iterator.
  # pylint: enable=unreachable

  return data_utils.Dataset(
      train_iterator_fn, eval_train_epoch, valid_epoch, test_epoch)
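

# The mixup path above synchronizes a PRNG key across hosts by zeroing it
# everywhere except local device 0 on process 0 and then psum-ing it over all
# devices. The helper below is a distilled sketch of that trick under the same
# assumption (a multi-host pmap program where every process runs the same
# code); the function name is hypothetical and it is not part of the module.
# It uses the module's existing `jax` and `jnp` imports.
def _broadcast_from_process_zero(value):
  """Returns process 0's 1-D `value` on every process (illustrative sketch)."""
  # Replicate across local devices so the value can be fed to pmap.
  tiled = jnp.tile(value, (jax.local_device_count(), 1))
  if jax.process_index() == 0:
    # Keep the value only on local device 0 so the global psum counts it once.
    mask = jax.nn.one_hot(
        jnp.array([0]), jax.local_device_count(), dtype=tiled.dtype)
    tiled = tiled * mask.T
  else:
    tiled = jnp.zeros_like(tiled)
  summed = jax.pmap(lambda x: jax.lax.psum(x, 'devices'), 'devices')(tiled)
  return summed.at[0].get()
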
def get_imagenet(shuffle_rng, batch_size, eval_batch_size, hps):
  """Data generators for ImageNet."""
  per_host_batch_size = batch_size // jax.host_count()
  per_host_eval_batch_size = eval_batch_size // jax.host_count()

  image_size = hps.input_shape[0]
  num_classes = 1000

  # TODO(gilmer): Currently the training data is not deterministic.
  logging.info('Loading train split')
  train_ds = load_split(
      per_host_batch_size,
      'train',
      hps=hps,
      image_size=image_size,
      shuffle_rng=shuffle_rng)
  train_ds = tfds.as_numpy(train_ds)

  logging.info('Loading eval_train split')
  eval_train_ds = load_split(
      per_host_eval_batch_size, 'eval_train', hps=hps, image_size=image_size)
  eval_train_ds = tfds.as_numpy(eval_train_ds)

  logging.info('Loading eval split')
  eval_ds = load_split(
      per_host_eval_batch_size, 'valid', hps=hps, image_size=image_size)
  eval_ds = tfds.as_numpy(eval_ds)

  def train_iterator_fn():
    for batch in iter(train_ds):
      image = batch['image']
      yield {
          'inputs': image,
          'targets': np.eye(num_classes)[batch['label']],
          'weights': np.ones(per_host_batch_size, dtype=image.dtype),
      }

  def eval_train_epoch(num_batches=None):
    eval_train_iter = iter(eval_train_ds)
    for batch in itertools.islice(eval_train_iter, num_batches):
      batch_dict = {
          'inputs': batch['image'],
          'targets': np.eye(num_classes)[batch['label']],
      }
      if hps.get('include_example_keys'):
        batch_dict['example_key'] = batch['example_key']
      yield data_utils.maybe_pad_batch(batch_dict, per_host_eval_batch_size)

  def valid_epoch(num_batches=None):
    valid_iter = iter(eval_ds)
    for batch in itertools.islice(valid_iter, num_batches):
      batch_dict = {
          'inputs': batch['image'],
          'targets': np.eye(num_classes)[batch['label']],
      }
      if hps.get('include_example_keys'):
        batch_dict['example_key'] = batch['example_key']
      yield data_utils.maybe_pad_batch(batch_dict, per_host_eval_batch_size)

  # pylint: disable=unreachable
  def test_epoch(*args, **kwargs):
    del args
    del kwargs
    return
    yield  # This yield is needed to make this a valid (null) iterator.
  # pylint: enable=unreachable

  return data_utils.Dataset(
      train_iterator_fn, eval_train_epoch, valid_epoch, test_epoch)
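

# A hedged end-to-end usage sketch for `get_imagenet`: the batch sizes are
# made-up examples, the real `hps` comes from the experiment config (it must
# provide at least `input_shape` and a `.get` method), `load_split` needs the
# ImageNet TFDS data to be available, and it assumes data_utils.Dataset
# exposes its four constructor arguments as attributes (e.g. a namedtuple).
# The function name is hypothetical.
def _example_imagenet_usage(hps, rng):
  dataset = get_imagenet(
      shuffle_rng=rng, batch_size=256, eval_batch_size=256, hps=hps)
  # Each training batch is a dict with 'inputs' (images), one-hot 'targets',
  # and per-example 'weights'.
  batch = next(dataset.train_iterator_fn())
  return {key: value.shape for key, value in batch.items()}
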