def create_val_dataset(config, split, batch_size, pad_last_batch):
    """Create validataion dataset.

  Args:
    config: ml_collections.ConfigDict to use.
    split: The validation split.
    batch_size: The batch size.
    pad_last_batch: Bool to indicate whether to pad last patch or not.

  Returns:
    The validation dataset.
  """
    dataset_builder = gscan_dataset.GSCANDataset(**config)
    num_batches = None
    cardinality = None
    if pad_last_batch:
        num_examples = dataset_builder.num_examples[split]
        val_batch_size = jax.local_device_count() * batch_size
        num_batches = int(
            np.ceil(num_examples / val_batch_size / jax.process_count()))
        cardinality = int(np.ceil(num_examples / jax.process_count()))
    ds = deterministic_data.create_dataset(
        dataset_builder,
        split=split,
        preprocess_fn=dataset_builder.preprocess,
        cache=jax.process_count() > 1,
        batch_dims=[jax.local_device_count(), batch_size],
        num_epochs=1,
        shuffle=False,
        pad_up_to_batches=num_batches,
        cardinality=cardinality)
    return ds
Example #2
def setup_jax(globally_use_hardware_rng: bool, jax_use_gda: bool,
              jax_backend_target: Optional[str],
              jax_xla_backend: Optional[str]) -> None:
    """Setups JAX and logs information about this job."""
    # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
    # it unavailable to JAX.
    tf.config.experimental.set_visible_devices([], 'GPU')
    if globally_use_hardware_rng:
        py_utils.set_globally_use_rbg_prng_key()

    # We use xmap only with SPMD.
    jax.config.update('experimental_xmap_spmd_lowering', True)
    # Use the manual partitioning lowering of xmap to avoid vectorization.
    jax.config.update('experimental_xmap_spmd_lowering_manual', True)

    if jax_use_gda:
        logging.info('Using JAX GSDA for pjit and checkpointing')

    if jax_backend_target:
        logging.info('Using JAX backend target %s', jax_backend_target)
        jax_xla_backend = 'None' if jax_xla_backend is None else jax_xla_backend
        logging.info('Using JAX XLA backend %s', jax_xla_backend)

    logging.info('JAX process: %d / %d', jax.process_index(),
                 jax.process_count())
    logging.info('JAX devices: %r', jax.devices())
    logging.info('jax.device_count(): %d', jax.device_count())
    logging.info('jax.local_device_count(): %d', jax.local_device_count())
    logging.info('jax.process_count(): %d', jax.process_count())
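The setup function above is normally called once at startup, before any other JAX work. A minimal, hypothetical call site (the flag values are illustrative assumptions, not defaults from the original project):

# Hypothetical call site for the setup_jax() example above; the flag values
# are illustrative assumptions, not defaults from the original project.
setup_jax(
    globally_use_hardware_rng=True,  # switch to the RBG PRNG key globally
    jax_use_gda=True,                # use GDA-based pjit and checkpointing
    jax_backend_target=None,         # e.g. a gRPC target when set
    jax_xla_backend=None)            # fall back to the local XLA backend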
Example #3
def get_translate_wmt(shuffle_rng, batch_size, eval_batch_size=None, hps=None):
  """Wrapper to conform to the general dataset API."""

  per_host_batch_size = batch_size // jax.process_count()
  per_host_eval_batch_size = eval_batch_size // jax.process_count()
  return _get_translate_wmt(per_host_batch_size,
                            per_host_eval_batch_size,
                            hps,
                            shuffle_rng)
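Many of these examples share the same pattern: divide the global batch size by jax.process_count() to get a per-host batch size. A small, hypothetical helper (not part of any project above) that makes the implicit divisibility requirement explicit:

import jax


def per_host_batch_size(global_batch_size: int) -> int:
  # Hypothetical helper: split a global batch size evenly across JAX processes
  # and fail loudly if it does not divide, instead of silently truncating.
  process_count = jax.process_count()
  if global_batch_size % process_count != 0:
    raise ValueError(
        f'global_batch_size={global_batch_size} must be divisible by '
        f'process_count={process_count}.')
  return global_batch_size // process_count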
Example #4
  def _write_lock_values(self):
    write_count = 0
    for p in range(jax.process_count()):
      # TODO(yashkatariya): Make the key value store writes safe if checkpoint
      # managers are created concurrently.
      self._client.key_value_set(
          _get_key(str(p)), f'Lock value for process {str(p)}')
      write_count += 1
    if write_count != jax.process_count():
      raise ValueError("Process 0 couldn't write all the lock values.")
    logging.info('Lock values for all processes have been written by process 0.')
Example #5
def get_imagenet(shuffle_rng, batch_size, eval_batch_size, hps):
    """Data generators for imagenet."""
    per_host_batch_size = batch_size // jax.process_count()
    per_host_eval_batch_size = eval_batch_size // jax.process_count()

    image_size = hps.input_shape[0]

    # TODO(gilmer) Currently the training data is not deterministic.
    logging.info('Loading train split')
    train_ds = load_split(per_host_batch_size,
                          'train',
                          hps=hps,
                          image_size=image_size,
                          shuffle_rng=shuffle_rng)
    train_ds = tfds.as_numpy(train_ds)
    logging.info('Loading eval_train split')
    eval_train_ds = load_split(per_host_eval_batch_size,
                               'eval_train',
                               hps=hps,
                               image_size=image_size)
    eval_train_ds = tfds.as_numpy(eval_train_ds)
    logging.info('Loading eval split')
    eval_ds = load_split(per_host_eval_batch_size,
                         'valid',
                         hps=hps,
                         image_size=image_size)
    eval_ds = tfds.as_numpy(eval_ds)

    def train_iterator_fn():
        return train_ds

    def eval_train_epoch(num_batches=None):
        # Note: eval_train batches are padded to per_host_eval_batch_size.
        for batch in itertools.islice(eval_train_ds, num_batches):
            yield data_utils.maybe_pad_batch(batch, per_host_eval_batch_size)

    def valid_epoch(num_batches=None):
        for batch in itertools.islice(eval_ds, num_batches):
            yield data_utils.maybe_pad_batch(batch, per_host_eval_batch_size)

    # pylint: disable=unreachable
    def test_epoch(*args, **kwargs):
        del args
        del kwargs
        return
        yield  # This yield is needed to make this a valid (null) iterator.

    # pylint: enable=unreachable

    return data_utils.Dataset(train_iterator_fn, eval_train_epoch, valid_epoch,
                              test_epoch)
Example #6
def get_total_steps(config):
  """Get total_steps of training.

  Args:
    config: The config of the experiment.

  Returns:
    Total_steps of training.
  """
  local_batch_size = config.batch_size // jax.process_count()
  ntrain_img = input_utils.get_num_examples(
      config.dataset,
      split=config.train_split,
      process_batch_size=local_batch_size,
      data_dir=config.get('data_dir'))
  steps_per_epoch = ntrain_img // config.batch_size

  if config.get('num_epochs'):
    total_steps = int(config.num_epochs * steps_per_epoch)
    assert not config.get('total_steps'), 'Set either num_epochs or total_steps'
  else:
    total_steps = config.total_steps

  logging.info('Total train data points: %d', ntrain_img)
  logging.info(
      'Running for %d steps, that means %f epochs and %d steps per epoch',
      total_steps, total_steps * config.batch_size / ntrain_img,
      steps_per_epoch)
  return total_steps
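To make the arithmetic in get_total_steps concrete, here is a short worked sketch; the numbers are assumptions chosen for illustration, not values from a real experiment:

# Worked sketch of the arithmetic in get_total_steps(), with assumed numbers.
batch_size = 4096        # global batch size (assumption)
ntrain_img = 1_281_167   # ImageNet-sized training set (assumption)
num_epochs = 90          # assumption

steps_per_epoch = ntrain_img // batch_size       # 312
total_steps = int(num_epochs * steps_per_epoch)  # 28080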
Example #7
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
    # it unavailable to JAX.
    tf.config.experimental.set_visible_devices([], 'GPU')

    logging.info('JAX process: %d / %d', jax.process_index(),
                 jax.process_count())
    logging.info('JAX local devices: %r', jax.local_devices())

    # Add a note so that we can tell which task is which JAX host.
    # (Depending on the platform task 0 is not guaranteed to be host 0)
    platform.work_unit().set_task_status(
        f'process_index: {jax.process_index()}, '
        f'process_count: {jax.process_count()}')
    platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                         FLAGS.workdir, 'workdir')

    if FLAGS.mode == 'train':
        train.train_and_evaluate(FLAGS.config, FLAGS.workdir)
    else:
        predict.predict_and_evaluate(FLAGS.config, FLAGS.workdir,
                                     FLAGS.ckpt_path)
Example #8
def main(argv):
    del argv

    # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
    # it unavailable to JAX.
    tf.config.experimental.set_visible_devices([], "GPU")

    logging.info("JAX process: %d / %d", jax.process_index(),
                 jax.process_count())
    logging.info("JAX devices: %r", jax.devices())

    # Add a note so that we can tell which task is which JAX process.
    platform.work_unit().set_task_status(
        f"process_index: {jax.process_index()}, process_count: {jax.process_count()}"
    )
    platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                         FLAGS.workdir, "workdir")

    train_mode = FLAGS.config.mode
    if train_mode == TrainingMode.PRETRAINING:
        train_lib = run_pretraining
    elif train_mode == TrainingMode.CLASSIFICATION:
        train_lib = run_classifier
    else:
        raise ValueError("Unknown training mode: %s" % train_mode)

    train_lib.train_and_evaluate(FLAGS.config, FLAGS.workdir,
                                 FLAGS.vocab_filepath)
Example #9
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    utils.add_gfile_logger(_WORKDIR.value)

    # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
    # it unavailable to JAX.
    tf.config.experimental.set_visible_devices([], 'GPU')

    jax.config.update('jax_log_compiles', True)

    logging.info('JAX process: %d / %d', jax.process_index(),
                 jax.process_count())
    logging.info('JAX local devices: %r', jax.local_devices())
    jax_xla_backend = ('None' if FLAGS.jax_xla_backend is None else
                       FLAGS.jax_xla_backend)
    logging.info('Using JAX XLA backend %s', jax_xla_backend)

    logging.info('Config: %s', FLAGS.config)

    # Add a note so that we can tell which task is which JAX host.
    # (Depending on the platform task 0 is not guaranteed to be host 0)
    platform.work_unit().set_task_status(
        f'process_index: {jax.process_index()}, '
        f'process_count: {jax.process_count()}')
    platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                         _WORKDIR.value, 'workdir')

    if FLAGS.config.trainer == 'train':
        train.train_and_evaluate(FLAGS.config, _WORKDIR.value)
    elif FLAGS.config.trainer == 'inference_time':
        inference_time.inference_time(FLAGS.config, _WORKDIR.value)
    else:
        raise app.UsageError(f'Unknown trainer: {FLAGS.config.trainer}')
Example #10
def evaluate(
    model_name: str,
    job_log_dir: Optional[str],
    multi_host_checkpointing: Optional[bool],
    maybe_use_persistence_checkpointing: bool,
) -> None:
  """Runs the evaluation loop on the entire eval data set.

  Args:
    model_name: The name of the model from the registry to evaluate.
    job_log_dir: The directory for the job logs.
    multi_host_checkpointing: Whether to use multi-host checkpointing.
    maybe_use_persistence_checkpointing: If set, it will try to use
      persistence-based checkpointing if suitable.
  """
  model_config = model_utils.get_model(model_name)()
  task_p = model_config.task()
  model_p = task_p.model
  eval_input_p = [v for v in model_config.datasets() if not v.is_training]
  for inp in eval_input_p:
    inp.num_infeed_hosts = jax.process_count()
    inp.infeed_host_index = jax.process_index()

  if model_p.device_mesh is not None:
    checkpoint_type = checkpoints.retrieve_checkpoint_type(
        multi_host_checkpointing, maybe_use_persistence_checkpointing, task_p)
    evaluate_spmd_model(task_p, eval_input_p, job_log_dir, checkpoint_type)
  else:
    evaluate_pmap_model(task_p, eval_input_p, job_log_dir)
Example #11
def create_init(model, config, train_ds):
  """Create the initialization function for model parameters.

  Args:
    model: The model to be used in updates.
    config: The config of the experiment.
    train_ds: tf.data.Dataset.

  Returns:
    Function that returns initialized model parameters.
  """
  local_batch_size = config.batch_size // jax.process_count()
  # We want all parameters to be created in host RAM, not on any device; they
  # will be sent to the devices later as needed. Otherwise we risk allocating
  # them twice, which we have already run into in two situations.
  @functools.partial(jax.jit, backend='cpu')
  def init(rng):
    image_size = tuple(train_ds.element_spec['image'].shape[2:])
    logging.info('image_size = %s', image_size)
    dummy_input = jnp.zeros((local_batch_size,) + image_size, jnp.float32)
    params = flax.core.unfreeze(model.init(rng, dummy_input,
                                           train=False))['params']

    # Set bias in the head to a low value, such that loss is small initially.
    params['batchensemble_head']['bias'] = jnp.full_like(
        params['batchensemble_head']['bias'], config.get('init_head_bias', 0))

    # init head kernel to all zeros for fine-tuning
    if config.get('model_init'):
      params['batchensemble_head']['kernel'] = jnp.full_like(
          params['batchensemble_head']['kernel'], 0)

    return params

  return init
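A hedged usage sketch for create_init(); `model`, `config`, and `train_ds` are placeholders assumed to be built elsewhere in the project, and only the calling pattern is shown:

# Hypothetical usage of create_init(); `model`, `config`, and `train_ds` are
# assumed to be constructed elsewhere, as in the surrounding project.
init_fn = create_init(model, config, train_ds)
params = init_fn(jax.random.PRNGKey(0))  # parameters created in host RAM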
Example #12
def main(argv):
    del argv

    # Hide any GPUs from TensorFlow. Otherwise, TF might reserve memory and make
    # it unavailable to JAX.
    tf.config.experimental.set_visible_devices([], "GPU")

    logging.info("JAX host: %d / %d", jax.process_index(), jax.process_count())
    logging.info("JAX devices: %r", jax.devices())

    # Add a note so that we can tell which task is which JAX host. (Task 0 is not
    # guaranteed to be host 0)
    platform.work_unit().set_task_status(
        f"host_id: {jax.process_index()}, host_count: {jax.process_count()}")
    platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                         _WORKDIR.value, "workdir")

    train_mode = _CONFIG.value.mode
    if train_mode == TrainingMode.PRETRAINING:
        train_lib = run_pretraining
    elif train_mode == TrainingMode.CLASSIFICATION:
        train_lib = run_classifier
    else:
        raise ValueError("Unknown mode: %s" % train_mode)

    train_lib.train_and_evaluate(_CONFIG.value, _WORKDIR.value,
                                 _VOCAB_FILEPATH.value)
Example #13
def get_data_iterator(config):
    """Get iterator over the dataset."""
    task = memory_generation_task.MemoryGenerationTask

    # Establish host information
    local_device_count = jax.local_device_count()
    host_count = jax.process_count()
    host_id = jax.process_index()

    # Load datasets
    logging.info('Loading dataset.')
    decode_fn = data_utils.make_decode_fn(
        name_to_features=task.get_name_to_features(config),
        samples_per_example=config.samples_per_example,
    )

    preprocess_fn = task.make_preprocess_fn(config)
    collater_fn = task.make_collater_fn(config)

    data = data_utils.load_dataset(
        patterns=config.data_patterns,
        decode_fn=decode_fn,
        preprocess_fn=preprocess_fn,
        collater_fn=collater_fn,
        is_training=False,
        per_device_batch_size=config.per_device_batch_size,
        local_device_count=local_device_count,
        host_count=host_count,
        host_id=host_id,
        seed=0,
    )
    return iter(data)
Example #14
def main(argv):
    del argv

    # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
    # it unavailable to JAX.
    tf.config.experimental.set_visible_devices([], "GPU")

    if FLAGS.jax_backend_target:
        logging.info("Using JAX backend target %s", FLAGS.jax_backend_target)
        jax_xla_backend = ("None" if FLAGS.jax_xla_backend is None else
                           FLAGS.jax_xla_backend)
        logging.info("Using JAX XLA backend %s", jax_xla_backend)

    logging.info("JAX process: %d / %d", jax.process_index(),
                 jax.process_count())
    logging.info("JAX devices: %r", jax.devices())

    if FLAGS.is_train:
        # Add a note so that we can tell which task is which JAX host.
        # (Depending on platform task 0 is not guaranteed to be host 0)
        platform.work_unit().set_task_status(
            f"process_index: {jax.process_index()}, "
            f"process_count: {jax.process_count()}")
        platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                             FLAGS.workdir, "workdir")

        train_lib.train_and_evaluate(FLAGS.ml_config, FLAGS.workdir)
    else:
        eval_lib.evaluate(FLAGS.ml_config, FLAGS.workdir)
Example #15
  def update_fn(updates, state, params):
    del params

    key, subkey = jax.random.split(state.key)

    # Compute updates based on inner optimizer
    updates, inner_state = optimizer.update(updates, state.inner_state)

    prob = state.expert_weights.sum(axis=0) / state.expert_weights.sum()

    # NOTE(dsuo): we rely on jax determinism for each host to behave the same.
    current_expert = jax.random.choice(subkey, jnp.arange(prob.size), p=prob)

    # Synchronize train_losses across hosts.
    # NOTE(dsuo): since we are already inside a pmap, we can't use
    # jax.experimental.multihost_utils.
    # NOTE(dsuo): train_losses is of shape (jax.process_count(),).
    train_losses = jax.lax.all_gather(train_loss, 'batch').reshape(
        jax.process_count(), jax.local_device_count())[:, 0]

    # Compute loss regret and update expert weights.
    loss_regret = train_losses.at[current_expert].get() - train_losses
    expert_weights = state.expert_weights * jnp.exp(mw_etas * loss_regret)

    state = SamuelState(
        inner_state=inner_state,
        expert_weights=expert_weights,
        key=key,
        current_expert=current_expert,
        step=state.step + 1,
    )
    return updates, state
Example #16
    def _thread_func(self, temp_checkpoint_dir, final_checkpoint_dir):
        try:
            for future in self._commit_futures:
                for f in future:
                    f.result()
            logging.info('Commit to storage layer has completed.')

            current_process = jax.process_index()
            lockfiles_dir = os.path.join(temp_checkpoint_dir, 'lockfiles')
            all_lockfile_paths = [
                os.path.join(lockfiles_dir, f'lockfile_{p}')
                for p in range(jax.process_count())
            ]
            current_process_lockfile = os.path.join(
                lockfiles_dir, f'lockfile_{current_process}')

            with _RetryWithTimeout(self._timeout_secs) as t:
                while not tf.io.gfile.exists(current_process_lockfile):
                    if t.timed_out:
                        raise RuntimeError(
                            'Terminating after waiting for '
                            f'{self._timeout_secs} secs for lockfile to appear'
                        )
                    logging.info(
                        'Waiting for current process %s lockfile to appear.',
                        current_process)
                    time.sleep(60)

            tf.io.gfile.remove(current_process_lockfile)
            logging.info('Lockfile removed for process %s', current_process)

            # This while loop will not trigger until all commits have finished.
            if current_process == 0:
                with _RetryWithTimeout(self._timeout_secs) as t:
                    while True:
                        if t.timed_out:
                            raise RuntimeError(
                                'Terminating after waiting for '
                                f'{self._timeout_secs} secs for '
                                'finishing the serialization.')
                        # Mark as done when no lockfiles exist.
                        if no_lockfiles_exists(all_lockfile_paths):
                            tf.io.gfile.rmtree(lockfiles_dir)
                            logging.info('Lockfiles directory removed.')

                            logging.info('Renaming %s to %s',
                                         temp_checkpoint_dir,
                                         final_checkpoint_dir)
                            tf.io.gfile.rename(temp_checkpoint_dir,
                                               final_checkpoint_dir)
                            logging.info(
                                'Finished saving GDA checkpoint to `%s`.',
                                final_checkpoint_dir)
                            break
                        else:
                            logging.info('Thread sleeping for 60 seconds.')
                            time.sleep(60)

        except Exception as e:
            self._exception = e
Example #17
def get_fastmri(shuffle_rng, batch_size, eval_batch_size, hps):
    """FastMRI dataset.

  Args:
    shuffle_rng: rng for shuffling.
    batch_size: batch size.
    eval_batch_size: batch size for eval.
    hps: hyperparameters.

  Returns:
    An init2winit Dataset.
  """
    per_host_batch_size = batch_size // jax.process_count()
    per_host_eval_batch_size = eval_batch_size // jax.process_count()

    train_ds = load_split(per_host_batch_size, 'train', hps, shuffle_rng)
    train_ds = tfds.as_numpy(train_ds)

    # NOTE(dsuo): fastMRI has fixed randomness for eval.
    eval_train_ds = load_split(per_host_eval_batch_size, 'eval_train', hps,
                               shuffle_rng)
    eval_train_ds = tfds.as_numpy(eval_train_ds)
    eval_ds = load_split(per_host_eval_batch_size, 'val', hps, shuffle_rng)
    eval_ds = tfds.as_numpy(eval_ds)

    def train_iterator_fn():
        return train_ds

    def eval_train_epoch(num_batches=None):
        for batch in itertools.islice(eval_train_ds, num_batches):
            yield data_utils.maybe_pad_batch(batch, per_host_eval_batch_size)

    def valid_epoch(num_batches=None):
        for batch in itertools.islice(eval_ds, num_batches):
            yield data_utils.maybe_pad_batch(batch, per_host_eval_batch_size)

    # pylint: disable=unreachable
    def test_epoch(*args, **kwargs):
        del args
        del kwargs
        return
        yield  # This yield is needed to make this a valid (null) iterator.

    # pylint: enable=unreachable

    return data_utils.Dataset(train_iterator_fn, eval_train_epoch, valid_epoch,
                              test_epoch)
Example #18
    def _write_lockfiles(self, temp_checkpoint_dir):
        lockfiles_dir = os.path.join(temp_checkpoint_dir, 'lockfiles')
        tf.io.gfile.mkdir(lockfiles_dir)

        write_count = 0
        for p in range(jax.process_count()):
            with tf.io.gfile.GFile(os.path.join(lockfiles_dir,
                                                f'lockfile_{p}'),
                                   mode='w') as f:
                f.write('File to track if all chunks have been written.')
            write_count += 1

        if write_count != jax.process_count():
            raise ValueError("Process 0 couldn't write all the lockfiles.")

        logging.info(
            'Lock files for all processes have been written by process 0.')
Example #19
def _get_uniref(per_host_batch_size, per_host_eval_batch_size, hps, data_rng):
    """Data generators for Uniref50 clustered protein dataset."""
    # TODO(gilmer) Currently uniref drops the last partial batch on eval.
    logging.warning(
        'Currently the Protein dataset drops the last partial batch on eval')
    if jax.process_count() > 1:
        raise NotImplementedError(
            'Proteins does not support multihost training')

    n_devices = jax.local_device_count()
    if per_host_batch_size % n_devices != 0:
        raise ValueError(
            'n_devices={} must divide per_host_batch_size={}.'.format(
                n_devices, per_host_batch_size))
    if per_host_eval_batch_size % n_devices != 0:
        raise ValueError(
            'n_devices={} must divide per_host_eval_batch_size={}.'.format(
                n_devices, per_host_eval_batch_size))

    train_ds, eval_ds, vocab = load_dataset(
        hps.data_name,
        batch_size=per_host_batch_size,
        eval_batch_size=per_host_eval_batch_size,
        length=hps.max_target_length)

    masker = BertMasker(vocab=vocab)

    def train_iterator_fn():
        for batch_index, batch in enumerate(iter(train_ds)):
            batch_rng = jax.random.fold_in(data_rng, batch_index)
            yield _batch_to_dict(batch, masker, 'train', batch_rng)

    def eval_train_epoch(num_batches=None):
        eval_train_iter = iter(train_ds)
        for batch_index, batch in enumerate(
                itertools.islice(eval_train_iter, num_batches)):
            batch_rng = jax.random.fold_in(data_rng, batch_index)
            yield _batch_to_dict(batch, masker, 'eval', batch_rng)

    def valid_epoch(num_batches=None):
        valid_iter = iter(eval_ds)
        for batch_index, batch in enumerate(
                itertools.islice(valid_iter, num_batches)):
            batch_rng = jax.random.fold_in(data_rng, batch_index)
            yield _batch_to_dict(batch, masker, 'eval', batch_rng)

    # pylint: disable=unreachable
    def test_epoch(*args, **kwargs):
        del args
        del kwargs
        return
        yield  # This yield is needed to make this a valid (null) iterator.

    # pylint: enable=unreachable

    return Dataset(train_iterator_fn, eval_train_epoch, valid_epoch,
                   test_epoch)
Example #20
def get_split(rng,
              builder,
              split,
              batch_size,
              num_epochs=None,
              shuffle_buffer_size=None,
              repeat_after=False,
              cache=False):
    """Loads a audio dataset and shifts audio values to be positive.

  Args:
    rng: JAX PRNGKey random number generator state.
    builder: TFDS dataset builder instance.
    split: TFDS split to load.
    batch_size: Global batch size.
    num_epochs: Number of epochs. None to repeat forever.
    shuffle_buffer_size: Size of the shuffle buffer. If None, data is not
      shuffled.
    repeat_after: If True, the dataset is repeated infinitely *after* CLU.
    cache: If True, the dataset is cached prior to pre-processing.

  Returns:
    Audio datasets with `inputs` and `label` features. The former is shifted to
    be non-negative.
  """
    host_count = jax.process_count()
    if batch_size % host_count != 0:
        raise ValueError(
            f'Batch size ({batch_size}) must be divisible by the host'
            f' count ({host_count}).')
    batch_size = batch_size // host_count
    device_count = jax.local_device_count()
    if batch_size % device_count != 0:
        raise ValueError(
            f'Local batch size ({batch_size}) must be divisible by the'
            f' local device count ({device_count}).')
    batch_dims = [device_count, batch_size // device_count]

    host_split = data.get_read_instruction_for_host(
        split,
        dataset_info=builder.info,
        remainder_options=data.RemainderOptions.BALANCE_ON_PROCESSES)
    ds = data.create_dataset(builder,
                             split=host_split,
                             preprocess_fn=PrepareAudio(),
                             cache=cache,
                             batch_dims=batch_dims,
                             rng=rng,
                             num_epochs=num_epochs,
                             pad_up_to_batches='auto',
                             shuffle=shuffle_buffer_size
                             and (shuffle_buffer_size > 0),
                             shuffle_buffer_size=shuffle_buffer_size or 0)
    if repeat_after:
        ds = ds.repeat()
    return ds
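A hedged usage sketch for get_split(); the dataset name and sizes below are assumptions chosen only to show the calling convention:

# Hypothetical usage of get_split(); the dataset name and sizes are assumptions.
import jax
import tensorflow_datasets as tfds

builder = tfds.builder('speech_commands')
builder.download_and_prepare()
train_ds = get_split(
    jax.random.PRNGKey(0), builder, 'train',
    batch_size=256, num_epochs=None, shuffle_buffer_size=10_000)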
Example #21
def get_fake(shuffle_rng, batch_size, eval_batch_size, hps=None):
    """Data generators for imagenet."""
    del shuffle_rng
    per_host_batch_size = batch_size // jax.process_count()
    per_host_eval_batch_size = eval_batch_size // jax.process_count()

    fake_train_batch = get_fake_batch(per_host_batch_size, hps.input_shape,
                                      hps.output_shape[0])
    fake_test_batch = get_fake_batch(per_host_eval_batch_size, hps.input_shape,
                                     hps.output_shape[0])

    def train_iterator_fn():
        while True:
            yield fake_train_batch

    def valid_epoch(epoch, num_batches=None):
        del num_batches
        del epoch
        # Note that we use // because we do not support partial batching for
        # the fake dataset.
        for _ in range(hps.valid_size // eval_batch_size):
            yield fake_test_batch

    # pylint: disable=unreachable
    def eval_train_epoch(*args, **kwargs):
        del args
        del kwargs
        return
        yield  # This yield is needed to make this a valid (null) iterator.

    # pylint: enable=unreachable
    # pylint: disable=unreachable

    def test_epoch(*args, **kwargs):
        del args
        del kwargs
        return
        yield  # This yield is needed to make this a valid (null) iterator.

    # pylint: enable=unreachable

    return data_utils.Dataset(train_iterator_fn, eval_train_epoch, valid_epoch,
                              test_epoch)
Example #22
def main(unused_argv):
    # Necessary to use the tfds loader.
    tf.enable_v2_behavior()

    if jax.process_count() > 1:
        # TODO(ankugarg): Add support for multihost inference.
        raise NotImplementedError(
            'BLEU eval does not support multihost inference.')

    rng = jax.random.PRNGKey(FLAGS.seed)

    mt_eval_config = json.loads(FLAGS.mt_eval_config)

    if FLAGS.experiment_config_filename:
        with tf.io.gfile.GFile(FLAGS.experiment_config_filename) as f:
            experiment_config = json.load(f)
        if jax.process_index() == 0:
            logging.info('experiment_config: %r', experiment_config)
        dataset_name = experiment_config['dataset']
        model_name = experiment_config['model']
    else:
        assert FLAGS.dataset and FLAGS.model
        dataset_name = FLAGS.dataset
        model_name = FLAGS.model

    if jax.process_index() == 0:
        logging.info('argv:\n%s', ' '.join(sys.argv))
        logging.info('device_count: %d', jax.device_count())
        logging.info('num_hosts : %d', jax.host_count())
        logging.info('host_id : %d', jax.host_id())

    model_class = models.get_model(model_name)
    dataset_builder = datasets.get_dataset(dataset_name)
    dataset_meta_data = datasets.get_dataset_meta_data(dataset_name)

    hparam_overrides = None
    if FLAGS.hparam_overrides:
        if isinstance(FLAGS.hparam_overrides, str):
            hparam_overrides = json.loads(FLAGS.hparam_overrides)

    merged_hps = hyperparameters.build_hparams(
        model_name=model_name,
        initializer_name=experiment_config['initializer'],
        dataset_name=dataset_name,
        hparam_file=FLAGS.trial_hparams_filename,
        hparam_overrides=hparam_overrides)

    if jax.process_index() == 0:
        logging.info('Merged hps are: %s', json.dumps(merged_hps.to_json()))

    evaluator = bleu_evaluator.BLEUEvaluator(FLAGS.checkpoint_dir, merged_hps,
                                             rng, model_class, dataset_builder,
                                             dataset_meta_data, mt_eval_config)
    evaluator.translate_and_calculate_bleu()
Example #23
def main(unused_argv):
  # Necessary to use the tfds imagenet loader.
  tf.enable_v2_behavior()


  rng = jax.random.PRNGKey(FLAGS.seed)

  if FLAGS.hessian_eval_config:
    hessian_eval_config = json.loads(FLAGS.hessian_eval_config)
  else:
    hessian_eval_config = hessian_eval.DEFAULT_EVAL_CONFIG

  if FLAGS.experiment_config_filename:
    with tf.io.gfile.GFile(FLAGS.experiment_config_filename, 'r') as f:
      experiment_config = json.load(f)
    if jax.process_index() == 0:
      logging.info('experiment_config: %r', experiment_config)
    dataset_name = experiment_config['dataset']
    model_name = experiment_config['model']
  else:
    assert FLAGS.dataset and FLAGS.model
    dataset_name = FLAGS.dataset
    model_name = FLAGS.model

  if jax.process_index() == 0:
    logging.info('argv:\n%s', ' '.join(sys.argv))
    logging.info('device_count: %d', jax.device_count())
    logging.info('num_hosts : %d', jax.process_count())
    logging.info('host_id : %d', jax.process_index())

  model = models.get_model(model_name)
  dataset_builder = datasets.get_dataset(dataset_name)
  dataset_meta_data = datasets.get_dataset_meta_data(dataset_name)

  with tf.io.gfile.GFile(FLAGS.trial_hparams_filename, 'r') as f:
    hps = config_dict.ConfigDict(json.load(f))

  if FLAGS.hparam_overrides:
    if isinstance(FLAGS.hparam_overrides, str):
      hparam_overrides = json.loads(FLAGS.hparam_overrides)
    hps.update_from_flattened_dict(hparam_overrides)
  run_lanczos.eval_checkpoints(
      FLAGS.checkpoint_dir,
      hps,
      rng,
      FLAGS.eval_num_batches,
      model,
      dataset_builder,
      dataset_meta_data,
      hessian_eval_config,
      FLAGS.min_global_step,
      FLAGS.max_global_step)
Example #24
  def load_array(self, pattern: str):
    """Load sharded array as if it was loaded from multiple processes."""
    process_count = jax.process_count()
    arrays = []
    for process_index in range(process_count):
      arrays.append(
          data_utils.load_sharded_array(
              pattern, process_count * self.memory_reduction, process_index))
    array = np.stack(arrays, axis=0)
    shape = (-1,) + arrays[0].shape[1:]
    array = array.reshape(shape)
    return array
Example #25
def main(argv):
    if len(argv) > 1:
        raise app.UsageError("Too many command-line arguments.")

    # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
    # it unavailable to JAX.
    tf.config.experimental.set_visible_devices([], "GPU")

    # Log available devices. This code only supports a single host currently.
    logging.info("JAX host: %d / %d", jax.process_index(), jax.process_count())
    logging.info("JAX local devices: %r", jax.local_devices())

    render_and_save()
Example #26
  def _thread_func(self, temp_checkpoint_dir, final_checkpoint_dir):
    try:
      for future in self._commit_futures:
        for f in future:
          f.result()
      logging.info('Commit to storage layer has completed.')

      current_process = jax.process_index()

      # TODO(yashkatariya): Add a method to the distributed system to wait for
      # the key's value to change.
      # Value already exists -- wait until value is NOT _REMOVED_VALUE.
      with _RetryWithTimeout(self._timeout_secs) as t:
        while self._client.blocking_key_value_get(
            _get_key(str(current_process)), self._timeout_in_ms) == _REMOVED_VALUE:
          if t.timed_out:
            raise TimeoutError('Terminating after waiting for '
                               f'{self._timeout_secs} secs for lock value to appear.')
          logging.info('Waiting for current process %s lock value to appear.',
                       current_process)
          time.sleep(60)

      self._client.key_value_set(_get_key(str(current_process)), _REMOVED_VALUE)
      logging.info('Lock value removed for process %s', current_process)

      # This while loop will not trigger until all commits have finished.
      if current_process == 0:
        with _RetryWithTimeout(self._timeout_secs) as t:
          while True:
            if t.timed_out:
              raise TimeoutError('Terminating after waiting for '
                                 f'{self._timeout_secs} secs for '
                                 'finishing the serialization.')
            # Mark as done when no lock values exist.
            if all(
                self._client.blocking_key_value_get(
                    _get_key(str(p)), self._timeout_in_ms) == _REMOVED_VALUE
                for p in range(jax.process_count())):
              logging.info('Renaming %s to %s', temp_checkpoint_dir, final_checkpoint_dir)
              tf.io.gfile.rename(temp_checkpoint_dir, final_checkpoint_dir)
              logging.info('Finished saving GDA checkpoint to `%s`.', final_checkpoint_dir)
              self._client.key_value_set(_get_key(self._final_ckpt_dir), _CHECKPOINT_SUCCESS)
              break
            else:
              logging.info('Thread sleeping for 60 seconds.')
              time.sleep(60)

    except Exception as e:
      self._exception = e
Example #27
def create_eval_dataset(
    config,
    dataset_builder,
    split,
    preprocess_fn = None,
):
  """Create evaluation dataset (validation or test sets)."""
  # This ensures the correct number of elements in the validation sets.
  num_validation_examples = (
      dataset_builder.info.splits[split].num_examples)
  eval_split = deterministic_data.get_read_instruction_for_host(
      split, dataset_info=dataset_builder.info, drop_remainder=False)

  eval_num_batches = None
  if config.eval_pad_last_batch:
    # This is doing some extra work to get exactly all examples in the
    # validation split. Without this the dataset would first be split between
    # the different hosts and then into batches (both times dropping the
    # remainder). If you don't mind dropping a few extra examples you can omit
    # the `pad_up_to_batches` argument.
    eval_batch_size = jax.local_device_count() * config.per_device_batch_size
    eval_num_batches = int(np.ceil(num_validation_examples /
                                   eval_batch_size /
                                   jax.process_count()))
  return deterministic_data.create_dataset(
      dataset_builder,
      split=eval_split,
      # Only cache dataset in distributed setup to avoid consuming a lot of
      # memory in Colab and unit tests.
      cache=jax.process_count() > 1,
      batch_dims=[jax.local_device_count(), config.per_device_batch_size],
      num_epochs=1,
      shuffle=False,
      preprocess_fn=preprocess_fn,
      pad_up_to_batches=eval_num_batches,
  )
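To make the padding arithmetic in create_eval_dataset concrete, a short worked sketch with assumed numbers (none of them come from the original config):

# Worked sketch of the eval padding arithmetic above, with assumed numbers.
import numpy as np

num_validation_examples = 50_000  # assumption
per_device_batch_size = 32        # assumption
local_devices = 8                 # assumed jax.local_device_count() per host
processes = 4                     # assumed jax.process_count()

eval_batch_size = local_devices * per_device_batch_size      # 256
eval_num_batches = int(np.ceil(
    num_validation_examples / eval_batch_size / processes))  # 49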
Example #28
def get_criteo1tb(unused_shuffle_rng, batch_size, eval_batch_size, hps):
    """Get the Criteo 1TB train and eval iterators."""
    process_count = jax.process_count()
    if batch_size % process_count != 0:
        raise ValueError('process_count={} must divide batch_size={}.'.format(
            process_count, batch_size))
    if eval_batch_size is None:
        eval_batch_size = batch_size
    if eval_batch_size % process_count != 0:
        raise ValueError(
            'process_count={} must divide eval_batch_size={}.'.format(
                process_count, eval_batch_size))
    per_host_eval_batch_size = eval_batch_size // process_count
    per_host_batch_size = batch_size // process_count
    train_dataset = _criteo_tsv_reader(
        file_path=hps.train_file_path,
        num_dense_features=hps.num_dense_features,
        vocab_sizes=hps.vocab_sizes,
        batch_size=per_host_batch_size,
        is_training=True)
    train_iterator_fn = lambda: tfds.as_numpy(train_dataset)
    eval_train_dataset = _criteo_tsv_reader(
        file_path=hps.train_file_path,
        num_dense_features=hps.num_dense_features,
        vocab_sizes=hps.vocab_sizes,
        batch_size=per_host_eval_batch_size,
        is_training=False)
    eval_train_epoch = functools.partial(convert_to_numpy_iterator_fn,
                                         tf_dataset=eval_train_dataset)
    eval_dataset = _criteo_tsv_reader(
        file_path=hps.eval_file_path,
        num_dense_features=hps.num_dense_features,
        vocab_sizes=hps.vocab_sizes,
        batch_size=per_host_eval_batch_size,
        is_training=False)
    eval_iterator_fn = functools.partial(convert_to_numpy_iterator_fn,
                                         tf_dataset=eval_dataset)

    # pylint: disable=unreachable
    def test_epoch(*args, **kwargs):
        del args
        del kwargs
        return
        yield  # This yield is needed to make this a valid (null) iterator.

    # pylint: enable=unreachable
    return Dataset(train_iterator_fn, eval_train_epoch, eval_iterator_fn,
                   test_epoch)
Example #29
def decode(
    model_name: str,
    job_log_dir: Optional[str],
    multi_host_checkpointing: Optional[bool],
    maybe_use_persistence_checkpointing: bool,
    restore_checkpoint_dir: Optional[str],
    restore_checkpoint_step: Optional[int],
    continuous_decode: bool,
) -> None:
    """Runs decoding once on the decoder datasets.

  Args:
    model_name: The name of the model from the registry to evaluate.
    job_log_dir: The directory for the job logs.
    multi_host_checkpointing: Whether to use multi-host checkpointing.
    maybe_use_persistence_checkpointing: If set, it will try to use
      persistence-based checkpointing if suitable.
    restore_checkpoint_dir: The directory from which to restore checkpoint.
    restore_checkpoint_step: If set, the checkpoint step to restore. If unset,
      try to restore from the latest checkpoint if any.
    continuous_decode: whether to continuously decode on the latest ckpt.
  """
    logging.info('running decode_once on model %s restored from %s',
                 model_name, restore_checkpoint_dir)
    model_config = model_utils.get_model(model_name)()
    task_p = model_config.task()
    model_p = task_p.model
    decoder_inputs = model_config.decoder_datasets()
    if not decoder_inputs:
        return
    for inp in decoder_inputs:
        inp.num_infeed_hosts = jax.process_count()
        inp.infeed_host_index = jax.process_index()

    if model_p.device_mesh is not None:
        if continuous_decode:
            raise NotImplementedError('http://b/214589358: not supported')
        checkpoint_type = checkpoints.retrieve_checkpoint_type(
            multi_host_checkpointing, maybe_use_persistence_checkpointing,
            task_p)
        decode_once_spmd_model(task_p, decoder_inputs, job_log_dir,
                               checkpoint_type, restore_checkpoint_dir,
                               restore_checkpoint_step)
    else:
        decode_pmap_model(task_p, decoder_inputs, job_log_dir,
                          restore_checkpoint_dir, restore_checkpoint_step,
                          continuous_decode)
Example #30
    def __init__(self, loss_fn, optimizer, devices=None, has_graph=False):
        self._net_init_fn, self._apply_fn = hk.transform_with_state(
            functools.partial(loss_fn, is_training=True))
        _, self._eval_apply_fn = hk.transform_with_state(
            functools.partial(loss_fn, is_training=False))

        if optimizer is None:
            optimizer = optax.identity()
        self._optimizer = optimizer

        self._num_devices = jax.local_device_count()
        if devices is None:
            devices = []
            for host_id in range(jax.process_count()):
                for device_id in jax.local_devices(host_id):
                    devices.append(device_id)
        else:
            self._num_devices = min(self._num_devices, len(devices))

        def _pmap(f, static_broadcasted_argnums=()):
            return jax.pmap(
                f,
                axis_name='i',
                devices=devices,
                static_broadcasted_argnums=static_broadcasted_argnums)

        def handle_graph_size(fn):
            def _fn(*args):
                batch = args[-1].copy()
                max_graph_size = batch['max_graph_size']
                del batch['max_graph_size']
                args = args[:-1] + (batch, max_graph_size)
                return fn(*args)

            return _fn

        # Try to jit.
        if has_graph:
            # If the model contains full graphs, we need to set the max_graph_size
            # as a statically broadcasted argument.
            self._init_fn = handle_graph_size(_pmap(self._init, 4))
            self._update_fn = handle_graph_size(_pmap(self._update, 2))
            self._eval_fn = handle_graph_size(_pmap(self._eval, 2))
        else:
            self._init_fn = _pmap(self._init)
            self._update_fn = _pmap(self._update)
            self._eval_fn = _pmap(self._eval)