Esempio n. 1
0
 def _delete_checkpoint(
         self, checkpoint: checkpoint_pb2.CheckpointMetadata) -> None:
     """Removes the on-disk files that belong to `checkpoint`.

     The location and naming of the files depend on the checkpoint type
     recorded in `self._checkpoint_history`:

     * CHECKPOINT_FLAX: files live directly under `self._root_dir`; the step
       id is not zero-padded. Only process 0 performs the deletion.
     * CHECKPOINT_MULTI_HOST_FLAX: every process deletes from its own
       sub-directory, named after its zero-padded (3 digit) process index.
     * CHECKPOINT_PERSISTENCE / CHECKPOINT_GDA: files live directly under
       `self._root_dir`; the step id is zero-padded to 8 digits. Only
       process 0 performs the deletion.

     Any other checkpoint type is silently ignored.

     Args:
       checkpoint: Metadata of the checkpoint to delete; `timestamp_sec` and
         `global_step_id` are read from it.
     """
     logging.info('Deleting checkpoint: %s %s', checkpoint.timestamp_sec,
                  self._root_dir)
     ckpt_type = self._checkpoint_history.checkpoint_type

     if ckpt_type == CheckpointType.CHECKPOINT_MULTI_HOST_FLAX:
         # Each process owns (and cleans up) its own sub-directory.
         per_process_dir = os.path.join(self._root_dir,
                                        f'{jax.process_index():03d}')
         self._delete_pattern_if_exists(
             per_process_dir,
             f'{CHECKPOINT_PREFIX}{checkpoint.global_step_id}')
         return

     if ckpt_type == CheckpointType.CHECKPOINT_FLAX:
         zero_pad_step = False
     elif ckpt_type in {
             CheckpointType.CHECKPOINT_PERSISTENCE,
             CheckpointType.CHECKPOINT_GDA,
     }:
         zero_pad_step = True
     else:
         return

     # The remaining types share one directory, so a single process deletes.
     if jax.process_index() != 0:
         return
     if zero_pad_step:
         pattern = f'{CHECKPOINT_PREFIX}{checkpoint.global_step_id:08d}'
     else:
         pattern = f'{CHECKPOINT_PREFIX}{checkpoint.global_step_id}'
     self._delete_pattern_if_exists(self._root_dir, pattern)
Esempio n. 2
0
def process_iterator(tag: str,
                     item_ids: Sequence[str],
                     iterator,
                     rng: types.PRNGKey,
                     state: model_utils.TrainState,
                     step: int,
                     render_fn: Any,
                     summary_writer: tensorboard.SummaryWriter,
                     save_dir: Optional[gpath.GPath],
                     datasource: datasets.DataSource):
  """Process a dataset iterator and compute metrics.

  Every batch is passed to `process_batch`; on process 0 the returned stats
  are accumulated and their means are written to `summary_writer` once the
  iterator is exhausted.

  For the 'test' tag the batch metadata is synthesized: for each kind of
  metadata the datasource uses (appearance / warp / camera / time) a single
  random value is drawn and broadcast over the batch.

  Args:
    tag: Name of the split being processed; 'test' enables the metadata
      synthesis described above.
    item_ids: Ids of the items, zipped with `iterator`.
    iterator: Iterable of batches. Batches must contain 'origins' when
      `tag == 'test'`.
    rng: PRNG key forwarded to `process_batch`.
    state: Train state forwarded to `process_batch`.
    step: Current step; also seeds the test-time metadata sampling and names
      the output sub-directory.
    render_fn: Render function forwarded to `process_batch`.
    summary_writer: Writer that receives the mean of each metric.
    save_dir: Optional root directory; results go to `save_dir/<step>/<tag>`.
    datasource: Datasource describing which metadata is in use.
  """
  save_dir = save_dir / f'{step:08d}' / tag if save_dir else None
  meters = collections.defaultdict(utils.ValueMeter)
  for i, (item_id, batch) in enumerate(zip(item_ids, iterator)):
    logging.info('[%s:%d/%d] Processing %s ', tag, i+1, len(item_ids), item_id)
    if tag == 'test':
      # Seeding with `step` makes the sampled metadata reproducible per step.
      # NOTE(review): the same key is reused for every draw below, so the
      # draws are not independent of each other.
      test_rng = random.PRNGKey(step)
      shape = batch['origins'][..., :1].shape
      metadata = {}
      if datasource.use_appearance_id:
        appearance_id = random.choice(
            test_rng, jnp.asarray(datasource.appearance_ids))
        logging.info('\tUsing appearance_id = %d', appearance_id)
        metadata['appearance'] = jnp.full(shape, fill_value=appearance_id,
                                          dtype=jnp.uint32)
      if datasource.use_warp_id:
        warp_id = random.choice(test_rng, jnp.asarray(datasource.warp_ids))
        logging.info('\tUsing warp_id = %d', warp_id)
        metadata['warp'] = jnp.full(shape, fill_value=warp_id, dtype=jnp.uint32)
      if datasource.use_camera_id:
        camera_id = random.choice(test_rng, jnp.asarray(datasource.camera_ids))
        logging.info('\tUsing camera_id = %d', camera_id)
        metadata['camera'] = jnp.full(shape, fill_value=camera_id,
                                      dtype=jnp.uint32)
      if datasource.use_time:
        timestamp = random.uniform(test_rng, minval=0.0, maxval=1.0)
        # The timestamp is a float in [0, 1): it must be logged and stored as
        # a float. An integer format/dtype would truncate it to 0 every time.
        logging.info('\tUsing time = %f', timestamp)
        metadata['time'] = jnp.full(
            shape, fill_value=timestamp, dtype=jnp.float32)
      batch['metadata'] = metadata

    stats = process_batch(batch=batch,
                          rng=rng,
                          state=state,
                          tag=tag,
                          item_id=item_id,
                          step=step,
                          render_fn=render_fn,
                          summary_writer=summary_writer,
                          save_dir=save_dir,
                          datasource=datasource)
    # Only process 0 aggregates and reports metrics.
    if jax.process_index() == 0:
      for k, v in stats.items():
        meters[k].update(v)

  if jax.process_index() == 0:
    for meter_name, meter in meters.items():
      summary_writer.scalar(tag=f'metrics-eval/{meter_name}/{tag}',
                            value=meter.reduce('mean'),
                            step=step)
Esempio n. 3
0
def meta_init(loss_fn,
              flax_module,
              params,
              hps,
              input_shape,
              output_shape,
              rng_key,
              metrics_logger=None,
              log_every=10):
  """Implements MetaInit initializer.

  The norm of every parameter leaf is recorded, the parameters are passed
  through `normalize`, new norms are learned with `meta_optimize_scales`, and
  the normalized parameters are rescaled with the learned norms.

  Args:
    loss_fn: Loss function.
    flax_module: Flax nn.Module class; only its `apply` is used.
    params: The dict of model parameters.
    hps: HParam object. Required hparams are meta_learning_rate,
      meta_batch_size, meta_steps, and epsilon.
    input_shape: Must agree with batch[0].shape[1:].
    output_shape: Must agree with batch[1].shape[1:].
    rng_key: jax.PRNGKey, used to seed all randomness.
    metrics_logger: Instance of utils.MetricsLogger
    log_every: Print meta loss every k steps.

  Returns:
    The model parameters (same tree structure as `params`) rescaled to the
    learned norms.
  """
  # Pretty print the preinitialized norms with the variable shapes.
  if jax.process_index() == 0:
    logging.info('Preinitialized norms:')
    _log_shape_and_norms(params, metrics_logger, key='init_norms')
    logging.info('Running meta init')

  # First grab the norms of all weights and rescale params to have norm 1.
  # `jax.tree_util.tree_map` is used because the `jax.tree_map` alias is
  # deprecated and absent from newer JAX releases.
  norms = jax.tree_util.tree_map(
      lambda node: jnp.linalg.norm(node.reshape(-1)), params)
  normalized_params = jax.tree_util.tree_map(normalize, params)

  learned_norms, _ = meta_optimize_scales(
      loss_fn,
      flax_module.apply,
      normalized_params,
      norms,
      hps,
      input_shape,
      output_shape,
      rng_key,
      metrics_logger=metrics_logger,
      log_every=log_every)
  new_params = scale_params(normalized_params, learned_norms)

  if jax.process_index() == 0:
    # Pretty print the meta init norms with the variable shapes.
    logging.info('Learned norms from meta_init:')
    _log_shape_and_norms(new_params, metrics_logger, key='meta_init_norms')

  return new_params
Esempio n. 4
0
def main(unused_argv):
    """Runs BLEU evaluation over the checkpoints in FLAGS.checkpoint_dir.

    The dataset, model and initializer names come from the experiment config
    file when FLAGS.experiment_config_filename is set, and from FLAGS.dataset /
    FLAGS.model otherwise.

    Raises:
      NotImplementedError: If more than one JAX process is running.
    """
    # Necessary to use the tfds loader.
    tf.enable_v2_behavior()

    if jax.process_count() > 1:
        # TODO(ankugarg): Add support for multihost inference.
        raise NotImplementedError(
            'BLEU eval does not support multihost inference.')

    rng = jax.random.PRNGKey(FLAGS.seed)

    mt_eval_config = json.loads(FLAGS.mt_eval_config)

    if FLAGS.experiment_config_filename:
        with tf.io.gfile.GFile(FLAGS.experiment_config_filename) as f:
            experiment_config = json.load(f)
        if jax.process_index() == 0:
            logging.info('experiment_config: %r', experiment_config)
        dataset_name = experiment_config['dataset']
        model_name = experiment_config['model']
        initializer_name = experiment_config['initializer']
    else:
        assert FLAGS.dataset and FLAGS.model
        dataset_name = FLAGS.dataset
        model_name = FLAGS.model
        # Without an experiment config there is no initializer to read; this
        # path used to crash with a NameError on `experiment_config`.
        # NOTE(review): 'noop' is assumed to be an acceptable initializer name
        # for evaluation -- confirm against hyperparameters.build_hparams.
        initializer_name = 'noop'

    if jax.process_index() == 0:
        logging.info('argv:\n%s', ' '.join(sys.argv))
        logging.info('device_count: %d', jax.device_count())
        # jax.host_count() / jax.host_id() are deprecated aliases of
        # jax.process_count() / jax.process_index().
        logging.info('num_hosts : %d', jax.process_count())
        logging.info('host_id : %d', jax.process_index())

    model_class = models.get_model(model_name)
    dataset_builder = datasets.get_dataset(dataset_name)
    dataset_meta_data = datasets.get_dataset_meta_data(dataset_name)

    # The flag may hold a JSON string; any other non-empty value is passed
    # through unchanged instead of being silently dropped.
    hparam_overrides = FLAGS.hparam_overrides or None
    if isinstance(hparam_overrides, str):
        hparam_overrides = json.loads(hparam_overrides)

    merged_hps = hyperparameters.build_hparams(
        model_name=model_name,
        initializer_name=initializer_name,
        dataset_name=dataset_name,
        hparam_file=FLAGS.trial_hparams_filename,
        hparam_overrides=hparam_overrides)

    if jax.process_index() == 0:
        logging.info('Merged hps are: %s', json.dumps(merged_hps.to_json()))

    evaluator = bleu_evaluator.BLEUEvaluator(FLAGS.checkpoint_dir, merged_hps,
                                             rng, model_class, dataset_builder,
                                             dataset_meta_data, mt_eval_config)
    evaluator.translate_and_calculate_bleu()
Esempio n. 5
0
    def compute_interpolations(self, params, gdirs, udirs, hvex, cvex, step):
        """Compute the linear interpolation along directions of gdirs or udirs.

        For every direction the loss returned by `self._full_batch_eval` is
        recorded at `num_points` evenly spaced step sizes between
        `lower_thresh` and `upper_thresh` (all read from `self.eval_config`).

        Args:
          params: Model parameters forwarded to `self._full_batch_eval`.
          gdirs: Sequence of pytree directions; each is normalized with
            `_tree_normalize` before use. `gdirs[0]` also provides the tree
            structure used to unflatten `hvex` / `cvex`, so it must be
            non-empty when interpolations are computed.
          udirs: Sequence of pytree directions; normalized like `gdirs`.
          hvex: Sequence of flat vectors, unflattened into pytrees (not
            normalized here).
          cvex: Sequence of flat vectors, handled like `hvex`.
          step: Step recorded in the returned row.

        Returns:
          A dict holding 'step' and, when `eval_config['compute_interps']` is
          set, 'step_size' (the step sizes) plus one array of loss values per
          direction under 'loss<i>', 'loss_u<i>', 'loss_hvec<i>' and
          'loss_cvec<i>'.
        """
        row = {'step': step}
        if not self.eval_config['compute_interps']:
            return row
        lower = self.eval_config['lower_thresh']
        upper = self.eval_config['upper_thresh']
        num_points = self.eval_config['num_points']
        etas = np.linspace(lower, upper, num=num_points, endpoint=True)
        # Extend the row rather than rebinding it, so 'step' is not lost.
        row['step_size'] = etas

        def _losses_along(direction):
            """Returns the loss at every step size along `direction`."""
            loss_values = np.zeros(shape=(num_points, ))
            for j, eta in enumerate(etas):
                loss_values[j] = self._full_batch_eval(params, direction, eta)
            return loss_values

        for i, u_dir in enumerate(gdirs):
            row['loss%d' % (i, )] = _losses_along(_tree_normalize(u_dir))
        if jax.process_index() == 0:
            logging.info('Loss interpolation along gradients finished.')

        for i, u_dir in enumerate(udirs):
            row['loss_u%d' % (i, )] = _losses_along(_tree_normalize(u_dir))
        if jax.process_index() == 0:
            logging.info(
                'Loss interpolation along optimizer directions finished.')

        # hvex / cvex hold flat vectors; rebuild pytrees shaped like gdirs[0].
        _, unflatten = ravel_pytree(gdirs[0])
        for i, u_dir in enumerate(hvex):
            row['loss_hvec%d' % (i, )] = _losses_along(unflatten(u_dir))

        for i, u_dir in enumerate(cvex):
            row['loss_cvec%d' % (i, )] = _losses_along(unflatten(u_dir))

        if jax.process_index() == 0:
            logging.info('Loss interpolations finished. Statistics captured:')
            logging.info(row.keys())
        return row
Esempio n. 6
0
def main(unused_argv):
  """Runs the Hessian (Lanczos) evaluation over a directory of checkpoints.

  The dataset and model names come from the experiment config file when
  FLAGS.experiment_config_filename is set, and from FLAGS.dataset /
  FLAGS.model otherwise. Hyperparameters are loaded from
  FLAGS.trial_hparams_filename and then patched with FLAGS.hparam_overrides.
  """
  # Necessary to use the tfds imagenet loader.
  tf.enable_v2_behavior()

  rng = jax.random.PRNGKey(FLAGS.seed)

  if FLAGS.hessian_eval_config:
    hessian_eval_config = json.loads(FLAGS.hessian_eval_config)
  else:
    hessian_eval_config = hessian_eval.DEFAULT_EVAL_CONFIG

  if FLAGS.experiment_config_filename:
    with tf.io.gfile.GFile(FLAGS.experiment_config_filename, 'r') as f:
      experiment_config = json.load(f)
    if jax.process_index() == 0:
      logging.info('experiment_config: %r', experiment_config)
    dataset_name = experiment_config['dataset']
    model_name = experiment_config['model']
  else:
    assert FLAGS.dataset and FLAGS.model
    dataset_name = FLAGS.dataset
    model_name = FLAGS.model

  if jax.process_index() == 0:
    logging.info('argv:\n%s', ' '.join(sys.argv))
    logging.info('device_count: %d', jax.device_count())
    logging.info('num_hosts : %d', jax.process_count())
    logging.info('host_id : %d', jax.process_index())

  model = models.get_model(model_name)
  dataset_builder = datasets.get_dataset(dataset_name)
  dataset_meta_data = datasets.get_dataset_meta_data(dataset_name)

  with tf.io.gfile.GFile(FLAGS.trial_hparams_filename, 'r') as f:
    hps = config_dict.ConfigDict(json.load(f))

  if FLAGS.hparam_overrides:
    # The flag may be a JSON string or an already-parsed mapping. Binding the
    # name on both paths avoids a NameError when it is not a string.
    hparam_overrides = FLAGS.hparam_overrides
    if isinstance(hparam_overrides, str):
      hparam_overrides = json.loads(hparam_overrides)
    hps.update_from_flattened_dict(hparam_overrides)

  run_lanczos.eval_checkpoints(
      FLAGS.checkpoint_dir,
      hps,
      rng,
      FLAGS.eval_num_batches,
      model,
      dataset_builder,
      dataset_meta_data,
      hessian_eval_config,
      FLAGS.min_global_step,
      FLAGS.max_global_step)
Esempio n. 7
0
def evaluate(
    model_name: str,
    job_log_dir: Optional[str],
    multi_host_checkpointing: Optional[bool],
    maybe_use_persistence_checkpointing: bool,
) -> None:
  """Evaluates a registered model on all of its non-training datasets.

  Models that define a device mesh are evaluated through the SPMD path; all
  other models go through the pmap path.

  Args:
    model_name: The name of the model from the registry to evaluate.
    job_log_dir: The directory for the job logs.
    multi_host_checkpointing: Whether to use multi-host checkpointing.
    maybe_use_persistence_checkpointing: If set, it will try to use
      persistence-based checkpointing if suitable.
  """
  experiment = model_utils.get_model(model_name)()
  task_p = experiment.task()

  # Keep only the eval datasets and tell each one which host is feeding it.
  eval_input_p = []
  for dataset_p in experiment.datasets():
    if not dataset_p.is_training:
      eval_input_p.append(dataset_p)
  for dataset_p in eval_input_p:
    dataset_p.num_infeed_hosts = jax.process_count()
    dataset_p.infeed_host_index = jax.process_index()

  if task_p.model.device_mesh is None:
    evaluate_pmap_model(task_p, eval_input_p, job_log_dir)
    return

  checkpoint_type = checkpoints.retrieve_checkpoint_type(
      multi_host_checkpointing, maybe_use_persistence_checkpointing, task_p)
  evaluate_spmd_model(task_p, eval_input_p, job_log_dir, checkpoint_type)
Esempio n. 8
0
    def _thread_func(self, temp_checkpoint_dir, final_checkpoint_dir):
        """Waits for the commits, syncs all processes, then finalizes.

        After every commit future has resolved, all processes meet at a
        barrier keyed on `self._final_ckpt_dir`. Process 0 then renames the
        temporary directory to the final one and records success in the
        key-value store. Any exception is stored in `self._exception` instead
        of propagating.

        Args:
          temp_checkpoint_dir: Directory the checkpoint was written to.
          final_checkpoint_dir: Directory name the checkpoint is renamed to.
        """
        try:
            for future_group in self._commit_futures:
                for pending in future_group:
                    pending.result()

            process_idx = jax.process_index()
            logging.info(
                'Commit to storage layer has completed by process: %s',
                process_idx)

            # Every process blocks here until all of them have arrived; if
            # that does not happen in time the wait times out.
            self._client.wait_at_barrier(self._final_ckpt_dir,
                                         self._timeout_in_ms)
            logging.info('Finished waiting at barrier for process %s',
                         process_idx)

            # Only process 0 finalizes the checkpoint.
            if process_idx != 0:
                return

            logging.info('Renaming %s to %s', temp_checkpoint_dir,
                         final_checkpoint_dir)
            epath.Path(temp_checkpoint_dir).rename(final_checkpoint_dir)
            logging.info('Finished saving GDA checkpoint to `%s`.',
                         final_checkpoint_dir)
            success_key = _get_key(self._final_ckpt_dir)
            self._client.key_value_set(success_key, _CHECKPOINT_SUCCESS)
        except Exception as err:
            # Stash the failure on the instance rather than raising from here.
            self._exception = err
Esempio n. 9
0
def main(argv):
    """Entry point: picks the training library for the mode and runs it.

    Raises:
      ValueError: If the configured mode is neither pretraining nor
        classification.
    """
    del argv

    # Hide any GPUs from TensorFlow. Otherwise, TF might reserve memory and
    # make it unavailable to JAX.
    tf.config.experimental.set_visible_devices([], "GPU")

    logging.info("JAX host: %d / %d", jax.process_index(), jax.process_count())
    logging.info("JAX devices: %r", jax.devices())

    # Add a note so that we can tell which task is which JAX host. (Task 0 is
    # not guaranteed to be host 0)
    task_status = (
        f"host_id: {jax.process_index()}, host_count: {jax.process_count()}")
    platform.work_unit().set_task_status(task_status)
    platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                         _WORKDIR.value, "workdir")

    train_mode = _CONFIG.value.mode
    is_pretraining = train_mode == TrainingMode.PRETRAINING
    if not is_pretraining and not train_mode == TrainingMode.CLASSIFICATION:
        raise ValueError("Unknown mode: %s" % train_mode)
    train_lib = run_pretraining if is_pretraining else run_classifier

    train_lib.train_and_evaluate(_CONFIG.value, _WORKDIR.value,
                                 _VOCAB_FILEPATH.value)
Esempio n. 10
0
def main(argv):
    """Entry point: trains when FLAGS.is_train is set, evaluates otherwise."""
    del argv

    # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and
    # make it unavailable to JAX.
    tf.config.experimental.set_visible_devices([], "GPU")

    backend_target = FLAGS.jax_backend_target
    if backend_target:
        logging.info("Using JAX backend target %s", backend_target)
        xla_backend = FLAGS.jax_xla_backend
        if xla_backend is None:
            xla_backend = "None"
        logging.info("Using JAX XLA backend %s", xla_backend)

    logging.info("JAX process: %d / %d", jax.process_index(),
                 jax.process_count())
    logging.info("JAX devices: %r", jax.devices())

    if not FLAGS.is_train:
        eval_lib.evaluate(FLAGS.ml_config, FLAGS.workdir)
        return

    # Add a note so that we can tell which task is which JAX host.
    # (Depending on platform task 0 is not guaranteed to be host 0)
    task_status = (f"process_index: {jax.process_index()}, "
                   f"process_count: {jax.process_count()}")
    platform.work_unit().set_task_status(task_status)
    platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                         FLAGS.workdir, "workdir")

    train_lib.train_and_evaluate(FLAGS.ml_config, FLAGS.workdir)
Esempio n. 11
0
def restore_checkpoint(work_dir):
    """Given a valid dir, restores a checkpoint and returns ALSState object.

    The state is first restored on the host and every array leaf is then
    sharded across the local devices.

    Args:
      work_dir: Root work directory; the checkpoint is read from this
        process's own host sub-directory below it.

    Returns:
      The restored state with array leaves placed on the local devices.
      Non-array leaves are returned unchanged.
    """
    print(f"Attempting restore_checkpoint from dir: {work_dir}.")

    # Each host stores state in a separate subdir.
    host_dir = multihost_utils.get_host_dir(work_dir,
                                            host_id=jax.process_index())

    # First restore to host, then device_put the sharded arrays.
    state = checkpoints.restore_checkpoint(host_dir, target=None)

    local_devices = jax.local_devices()

    def device_put_sharded(x):
        if not isinstance(x, (jnp.ndarray, np.ndarray)):
            return x

        # device_put_sharded takes a sequence of tensors, one per local
        # device, so expose the device dimension as axis 0.
        # NOTE(review): `x.shape[2]` assumes every array leaf has at least
        # three dimensions -- confirm against the saved state.
        x = np.reshape(x, [len(local_devices), -1, x.shape[2]])

        # Iterating over axis 0 yields one shard per device, which is what
        # splitting on axis 0 and squeezing the dummy dimension produced.
        shards = list(x)

        # Send the shards to the devices.
        return jax.device_put_sharded(shards, local_devices)

    # `jax.tree_map` is deprecated and absent from newer JAX releases.
    return jax.tree_util.tree_map(device_put_sharded, state)
def train_step(train_state, batch, label_smoothing):
    """Runs one optimization step and returns the updated train state.

    Gradients are averaged over the 'batch' axis with `jax.lax.pmean`, so this
    function has to run in a context that binds that axis name.

    Args:
      train_state: State exposing `apply_fn`, `params`, `step` and
        `apply_gradients`.
      batch: Dict with 'images' (laid out 'H W C N') and 'labels'; optionally
        'mix_labels' and 'ratio' for label mixing.
      label_smoothing: Smoothing factor passed to `optax.smooth_labels`.

    Returns:
      The train state after applying the averaged gradients.
    """
    def loss_fn(params):
        images = rearrange(batch['images'], 'H W C N -> N H W C')
        images = images.astype(jnp.bfloat16)
        logits = train_state.apply_fn(params, images, is_training=True)
        y = one_hot(batch['labels'])
        if 'mix_labels' in batch:
            # Blend the two one-hot targets with the per-example ratio.
            y1 = one_hot(batch['mix_labels'])
            ratio = batch['ratio'][:, None]
            y = ratio * y + (1. - ratio) * y1
        y = optax.smooth_labels(y, label_smoothing)
        logits = logits.astype(jnp.float32)
        loss = jnp.mean(optax.softmax_cross_entropy(logits, y))
        # NOTE(review): the loss is divided by the device count and the
        # gradients are additionally averaged with pmean below -- confirm the
        # double scaling is intended.
        scaled_loss = loss / jax.device_count()
        return scaled_loss, logits

    # jax.value_and_grad accepts no `axis_name` argument (passing one raises
    # TypeError); the cross-device reduction is done by the pmean below.
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(train_state.params)
    grads = jax.lax.pmean(grads, axis_name='batch')
    top_k_acc = utils.topk_correct(logits, batch['labels'], prefix='train_')
    # `jax.tree_map` is deprecated and absent from newer JAX releases.
    top_k_acc = jax.tree_util.tree_map(jnp.mean, top_k_acc)
    new_train_state = train_state.apply_gradients(grads=grads)

    # NOTE(review): `float(loss)` needs a concrete value; if this function is
    # traced (jit/pmap) the conversion will fail -- consider returning the
    # metrics and logging them from the caller.
    if jax.process_index() == 0:
        wandb.log(
            {
                'train/loss': float(loss),
                'train/top-1-acc': top_k_acc['train_top_1_acc']
            }, train_state.step)

    return new_train_state
Esempio n. 13
0
def main(argv):
    """Entry point: trains in 'train' mode, otherwise predicts and evaluates.

    Raises:
      app.UsageError: If positional command-line arguments are given.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and
    # make it unavailable to JAX.
    tf.config.experimental.set_visible_devices([], 'GPU')

    logging.info('JAX process: %d / %d', jax.process_index(),
                 jax.process_count())
    logging.info('JAX local devices: %r', jax.local_devices())

    # Add a note so that we can tell which task is which JAX host.
    # (Depending on the platform task 0 is not guaranteed to be host 0)
    task_status = (f'process_index: {jax.process_index()}, '
                   f'process_count: {jax.process_count()}')
    platform.work_unit().set_task_status(task_status)
    platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                         FLAGS.workdir, 'workdir')

    if FLAGS.mode == 'train':
        train.train_and_evaluate(FLAGS.config, FLAGS.workdir)
        return
    predict.predict_and_evaluate(FLAGS.config, FLAGS.workdir, FLAGS.ckpt_path)
Esempio n. 14
0
def write_per_example_losses(*, p_eval_step, target, eval_ds, num_eval_steps,
                             loss_filename):
    """Evaluates the target and writes per-example losses and lengths to CSV.

    Nothing is returned. On process 0 the losses are written to
    `loss_filename` and the lengths to the same name with '.csv' replaced by
    '_length.csv'.

    Args:
      p_eval_step: Callable taking `(target, sharded_batch)` and returning a
        `(loss, length)` pair.
      target: Model/optimizer target forwarded to `p_eval_step`.
      eval_ds: Iterable evaluation dataset.
      num_eval_steps: Maximum number of batches to evaluate.
      loss_filename: Destination CSV path for the losses.
        NOTE(review): if the name contains no '.csv', the lengths file path
        equals `loss_filename` and the losses would be overwritten.
    """
    logging.info('Gathering evaluation metrics.')
    losses = []
    lengths = []
    eval_iter = iter(eval_ds)  # pytype: disable=wrong-arg-types
    for _, eval_batch in zip(range(num_eval_steps), eval_iter):
        # `jax.tree_map` is deprecated and absent from newer JAX releases.
        eval_batch = jax.tree_util.tree_map(lambda x: x._numpy(), eval_batch)  # pylint: disable=protected-access
        eval_batch = common_utils.shard(eval_batch)
        loss, length = p_eval_step(target, eval_batch)
        losses.append(common.tohost(loss))
        lengths.append(common.tohost(length))
    # Write losses and lengths (process 0 only).
    if jax.process_index() == 0:
        with tf.io.gfile.GFile(loss_filename, 'w') as f:
            writer = csv.writer(f)
            for pos_losses in losses:
                for val in pos_losses:
                    writer.writerow(list(val))
        with tf.io.gfile.GFile(loss_filename.replace('.csv', '_length.csv'),
                               'w') as f:
            writer = csv.writer(f)
            for val in lengths:
                writer.writerow([int(v) for v in list(val)])
Esempio n. 15
0
def save_samples_to_json(features: List[Dict[str, Any]],
                         config: ml_collections.ConfigDict, step: int):
    """Appends `features` as JSON lines to a per-step, per-process file.

    Samples are saved only when both conditions hold:
      * `config['save_samples_every_steps']` is set (non-zero) and divides
        `step`;
      * the current process is selected by `config['save_samples_process_ids']`
        (default 0): a list selects its members, -1 selects every process, and
        any other value selects the process with exactly that index.

    Args:
      features: Batches to write, one JSON document per line.
      config: Configuration; `model_dir` is the root of the output path.
      step: Current step.
    """
    every_n_steps = config.get('save_samples_every_steps')
    step_selected = every_n_steps and (
        step % config.get('save_samples_every_steps') == 0)

    process_index = jax.process_index()
    accepted_processes = config.get('save_samples_process_ids', 0)
    if isinstance(accepted_processes, list):
        process_selected = process_index in accepted_processes
    else:
        process_selected = (accepted_processes == -1
                            or process_index == accepted_processes)

    if not (step_selected and process_selected):
        return

    logging.info('Saving samples at step %d, process %d', step,
                 process_index)
    filename = 'step_%d.process_%d.json' % (step, process_index)
    path = os.path.join(config.model_dir, 'samples', filename)
    tf.io.gfile.makedirs(os.path.dirname(path))
    with tf.io.gfile.GFile(path, 'ab') as fp:
        for batch in features:
            json.dump(batch, fp)
            fp.write('\n')
Esempio n. 16
0
def setup_jax(globally_use_hardware_rng: bool, jax_use_gda: bool,
              jax_backend_target: Optional[str],
              jax_xla_backend: Optional[str]) -> None:
    """Configures JAX for this job and logs the resulting environment.

    Args:
      globally_use_hardware_rng: If set, switches the PRNG implementation via
        `py_utils.set_globally_use_rbg_prng_key`.
      jax_use_gda: Only affects logging in this function.
      jax_backend_target: Backend target; when set, it and the XLA backend
        are logged.
      jax_xla_backend: XLA backend name, logged as 'None' when unset.
    """
    # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
    # it unavailable to JAX.
    tf.config.experimental.set_visible_devices([], 'GPU')
    if globally_use_hardware_rng:
        py_utils.set_globally_use_rbg_prng_key()

    # xmap is only used with SPMD, and the manual partitioning lowering of
    # xmap is enabled to avoid vectorization.
    for xmap_option in ('experimental_xmap_spmd_lowering',
                        'experimental_xmap_spmd_lowering_manual'):
        jax.config.update(xmap_option, True)

    if jax_use_gda:
        logging.info('Using JAX GSDA for pjit and checkpointing')

    if jax_backend_target:
        logging.info('Using JAX backend target %s', jax_backend_target)
        backend_name = jax_xla_backend
        if backend_name is None:
            backend_name = 'None'
        logging.info('Using JAX XLA backend %s', backend_name)

    logging.info('JAX process: %d / %d', jax.process_index(),
                 jax.process_count())
    logging.info('JAX devices: %r', jax.devices())
    logging.info('jax.device_count(): %d', jax.device_count())
    logging.info('jax.local_device_count(): %d', jax.local_device_count())
    logging.info('jax.process_count(): %d', jax.process_count())
Esempio n. 17
0
def prepare_batches_gen(dataset, eval_config):
    """Returns a generator function over a fixed set of training batches.

    The first `eval_config['num_batches']` batches of the training iterator
    are read once, up front. Every call of the returned function then yields
    the same batches again, each paired with a PRNG key:

        for batch, rng in batches_gen():
            pass

    Args:
      dataset: Object exposing `train_iterator_fn()`.
      eval_config: A dict with 'num_batches' (number of batches to keep) and
        'rng_key' (seed of the base PRNG key).

    Returns:
      A zero-argument generator function yielding `(sharded_batch,
      replicated_rng)` tuples.
    """
    cached_batches = list(
        itertools.islice(dataset.train_iterator_fn(), 0,
                         eval_config['num_batches']))

    # Fold in the process index so each process gets its own key stream.
    base_rng = jax.random.fold_in(
        jax.random.PRNGKey(eval_config['rng_key']), jax.process_index())

    def training_batches_gen():
        for batch_idx, raw_batch in enumerate(cached_batches):
            sharded_batch = data_utils.shard(raw_batch)
            # The key depends only on the batch position, so each pass over
            # the generator reproduces the same keys.
            batch_rng = jax.random.fold_in(base_rng, batch_idx)
            yield (sharded_batch, jax_utils.replicate(batch_rng))

    return training_batches_gen
Esempio n. 18
0
def main(argv):
    """Entry point: runs the trainer selected by `FLAGS.config.trainer`.

    Raises:
      app.UsageError: If positional command-line arguments are given or the
        configured trainer is unknown.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    workdir = _WORKDIR.value
    utils.add_gfile_logger(workdir)

    # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and
    # make it unavailable to JAX.
    tf.config.experimental.set_visible_devices([], 'GPU')

    jax.config.update('jax_log_compiles', True)

    logging.info('JAX process: %d / %d', jax.process_index(),
                 jax.process_count())
    logging.info('JAX local devices: %r', jax.local_devices())
    xla_backend = FLAGS.jax_xla_backend
    if xla_backend is None:
        xla_backend = 'None'
    logging.info('Using JAX XLA backend %s', xla_backend)

    logging.info('Config: %s', FLAGS.config)

    # Add a note so that we can tell which task is which JAX host.
    # (Depending on the platform task 0 is not guaranteed to be host 0)
    task_status = (f'process_index: {jax.process_index()}, '
                   f'process_count: {jax.process_count()}')
    platform.work_unit().set_task_status(task_status)
    platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                         _WORKDIR.value, 'workdir')

    trainer = FLAGS.config.trainer
    if trainer == 'train':
        train.train_and_evaluate(FLAGS.config, _WORKDIR.value)
        return
    if trainer == 'inference_time':
        inference_time.inference_time(FLAGS.config, _WORKDIR.value)
        return
    raise app.UsageError(f'Unknown trainer: {FLAGS.config.trainer}')
Esempio n. 19
0
    def _thread_func(self, temp_checkpoint_dir, final_checkpoint_dir):
        """Waits for pending commits, then finalizes the checkpoint directory.

        Processes coordinate through one lockfile per process under
        `<temp_checkpoint_dir>/lockfiles`. Each process waits for its own
        lockfile to exist and then removes it. Process 0 additionally waits
        until no lockfile is left, deletes the lockfiles directory and renames
        the temporary checkpoint directory to its final name.

        Any exception (including the timeouts below) is stored in
        `self._exception` instead of propagating.

        Args:
          temp_checkpoint_dir: Directory the checkpoint was written to.
          final_checkpoint_dir: Directory name the checkpoint is renamed to.
        """
        try:
            # Block until every pending commit has resolved; an error raised
            # by `result()` is caught by the handler at the bottom.
            for future in self._commit_futures:
                for f in future:
                    f.result()
            logging.info('Commit to storage layer has completed.')

            current_process = jax.process_index()
            lockfiles_dir = os.path.join(temp_checkpoint_dir, 'lockfiles')
            # One lockfile path per process, indexed by process id.
            all_lockfile_paths = [
                os.path.join(lockfiles_dir, f'lockfile_{p}')
                for p in range(jax.process_count())
            ]
            current_process_lockfile = os.path.join(
                lockfiles_dir, f'lockfile_{current_process}')

            # Poll (once a minute) until this process's lockfile exists.
            # NOTE(review): the lockfiles are created outside this method --
            # presumably before serialization starts; confirm with the caller.
            with _RetryWithTimeout(self._timeout_secs) as t:
                while not tf.io.gfile.exists(current_process_lockfile):
                    if t.timed_out:
                        raise RuntimeError(
                            'Terminating after waiting for '
                            f'{self._timeout_secs} secs for lockfile to appear'
                        )
                    logging.info(
                        'Waiting for current process %s lockfile to appear.',
                        current_process)
                    time.sleep(60)

            # Removing the lockfile signals that this process is done.
            tf.io.gfile.remove(current_process_lockfile)
            logging.info('Lockfile removed for process %s', current_process)

            # This while loop will not trigger until all commits have finished.
            if current_process == 0:
                with _RetryWithTimeout(self._timeout_secs) as t:
                    while True:
                        if t.timed_out:
                            raise RuntimeError(
                                'Terminating after waiting for '
                                f'{self._timeout_secs} secs for '
                                'finishing the serialization.')
                        # Mark as done when no lockfiles exist.
                        if no_lockfiles_exists(all_lockfile_paths):
                            tf.io.gfile.rmtree(lockfiles_dir)
                            logging.info('Lockfiles directory removed.')

                            # The rename is what makes the checkpoint visible
                            # under its final name.
                            logging.info('Renaming %s to %s',
                                         temp_checkpoint_dir,
                                         final_checkpoint_dir)
                            tf.io.gfile.rename(temp_checkpoint_dir,
                                               final_checkpoint_dir)
                            logging.info(
                                'Finished saving GDA checkpoint to `%s`.',
                                final_checkpoint_dir)
                            break
                        else:
                            # Some process still holds a lockfile; poll again
                            # in a minute.
                            logging.info('Thread sleeping for 60 seconds.')
                            time.sleep(60)

        except Exception as e:
            # Stash the failure on the instance rather than raising from here.
            self._exception = e
Esempio n. 20
0
def main(argv):
    """Entry point: picks the training library for the mode and runs it.

    Raises:
      ValueError: If `FLAGS.config.mode` is neither pretraining nor
        classification.
    """
    del argv

    # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and
    # make it unavailable to JAX.
    tf.config.experimental.set_visible_devices([], "GPU")

    logging.info("JAX process: %d / %d", jax.process_index(),
                 jax.process_count())
    logging.info("JAX devices: %r", jax.devices())

    # Add a note so that we can tell which task is which JAX process.
    task_status = (
        f"process_index: {jax.process_index()}, process_count: {jax.process_count()}"
    )
    platform.work_unit().set_task_status(task_status)
    platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                         FLAGS.workdir, "workdir")

    train_mode = FLAGS.config.mode
    is_pretraining = train_mode == TrainingMode.PRETRAINING
    if not is_pretraining and not train_mode == TrainingMode.CLASSIFICATION:
        raise ValueError("Unknown training mode: %s" % train_mode)
    train_lib = run_pretraining if is_pretraining else run_classifier

    train_lib.train_and_evaluate(FLAGS.config, FLAGS.workdir,
                                 FLAGS.vocab_filepath)
Esempio n. 21
0
def get_data_iterator(config):
    """Builds the (non-training) dataset for the task and returns an iterator.

    Args:
      config: Configuration providing `data_patterns`, `samples_per_example`
        and `per_device_batch_size`; it is also handed to the task's feature,
        preprocessing and collating factories.

    Returns:
      An iterator over the loaded dataset.
    """
    task = memory_generation_task.MemoryGenerationTask

    # Describe this host's place in the job.
    num_local_devices = jax.local_device_count()
    num_hosts = jax.process_count()
    this_host = jax.process_index()

    logging.info('Loading dataset.')
    name_to_features = task.get_name_to_features(config)
    decode_fn = data_utils.make_decode_fn(
        name_to_features=name_to_features,
        samples_per_example=config.samples_per_example,
    )

    load_kwargs = dict(
        patterns=config.data_patterns,
        decode_fn=decode_fn,
        preprocess_fn=task.make_preprocess_fn(config),
        collater_fn=task.make_collater_fn(config),
        is_training=False,
        per_device_batch_size=config.per_device_batch_size,
        local_device_count=num_local_devices,
        host_count=num_hosts,
        host_id=this_host,
        seed=0,
    )
    return iter(data_utils.load_dataset(**load_kwargs))
Esempio n. 22
0
def _mkdir_path(name: str, tmp_dir: str) -> str:
  """Return `tmp_dir/name` with trailing slashes removed; create it on process 0.

  Args:
    name: final path component.
    tmp_dir: parent directory.

  Returns:
    The joined path, without any trailing '/'.
  """
  # Tensorstore rejects directory names that end in '/'.
  dir_path = os.path.join(tmp_dir, name).rstrip('/')
  if jax.process_index() != 0:
    # Other processes only need the path; process 0 creates the directory.
    return dir_path
  # Plain `mkdir` on purpose: parent directories are not created recursively.
  tf.io.gfile.mkdir(dir_path)
  return dir_path
Esempio n. 23
0
 def _f(path_tuple, v):
     """Apply `f` to `v` using the first rule that matches its path.

     The path components are joined with '/' and tested against each
     `(pattern, arg)` pair of the enclosing `regex_rules` in order, using
     `re.match`. On the first hit the result of `f(v, arg)` is returned
     (process 0 also logs the update); without a hit `v` is returned as is.
     """
     vname = "/".join(path_tuple)
     # Lazily pick the first matching rule; later rules are never evaluated.
     hit = next(
         ((pattern, arg)
          for pattern, arg in regex_rules
          if re.match(pattern, vname)),
         None,
     )
     if hit is None:
         return v
     pattern, arg = hit
     if jax.process_index() == 0:
         logging.info("Updating %s with %s due to `%s`", vname, arg,
                      pattern)
     return f(v, arg)
Esempio n. 24
0
def main(unused_argv):
    """Log a few training and validation batches of the selected dataset.

    All work happens on JAX process 0; every other process returns at once.
    Reads `FLAGS.batch_size`, `FLAGS.num_batches`, `FLAGS.dataset` and
    `FLAGS.model`.

    Args:
      unused_argv: positional command-line arguments; ignored.

    Raises:
      ValueError: if `FLAGS.batch_size` is unset or not positive, or if
        `FLAGS.dataset` is unset.
    """
    if jax.process_index() != 0:
        return

    logging.info('argv:\n%s', ' '.join(sys.argv))
    logging.info('device_count: %d', jax.device_count())
    logging.info('num_hosts : %d', jax.process_count())
    logging.info('host_id : %d', jax.process_index())

    batch_size = FLAGS.batch_size
    if batch_size is None or batch_size <= 0:
        raise ValueError("""FLAGS.batch_size value is invalid,
          expected a positive non-zero integer.""")

    dataset_name = FLAGS.dataset
    if dataset_name is None:
        raise ValueError("""FLAGS.dataset value is invalid,
          expected a non-empty string describing dataset name.""")

    num_batches = FLAGS.num_batches

    hps = hyperparameters.build_hparams(
        model_name=FLAGS.model,
        initializer_name='noop',
        dataset_name=dataset_name,
        hparam_file=None,
        hparam_overrides={'batch_size': batch_size},
    )

    # Only the data half of the split key is used below.
    _, data_rng = jax.random.split(jax.random.PRNGKey(0))

    # The same batch size is passed for both batch-size positional arguments.
    dataset_builder = datasets.get_dataset(dataset_name)
    dataset = dataset_builder(data_rng, batch_size, batch_size, hps)

    train_iter = dataset.train_iterator_fn()
    for batch_num in range(num_batches):
        logging.info('train batch_num = %d, batch = %r', batch_num,
                     next(train_iter))

    for valid_batch in dataset.valid_epoch(num_batches):
        logging.info('validation batch = %r', valid_batch)
Esempio n. 25
0
def main(argv: Sequence[str]) -> None:
  """Assemble the run configuration from flags and start training.

  The configuration comes from the JSON file `FLAGS.config_file` when that
  flag is set, otherwise from the JSON string `FLAGS.config`; in the latter
  case process 0 also writes it to `<model_dir>/config.json`. Selected flags
  then override individual entries before `train` is called.

  Args:
    argv: command-line arguments left over after flag parsing.

  Raises:
    app.UsageError: if any positional argument is supplied.
  """
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")

  # Keep TensorFlow away from the GPUs so it cannot reserve memory that JAX
  # needs.
  tf.config.experimental.set_visible_devices([], "GPU")

  process_index = jax.process_index()
  process_count = jax.process_count()
  logging.info("JAX process: %d / %d", process_index, process_count)
  logging.info("JAX local devices: %r", jax.local_devices())

  # Record which task maps to which JAX host; depending on the platform,
  # task 0 is not guaranteed to be host 0.
  platform.work_unit().set_task_status(
      f"process_index: {process_index}, process_count: {process_count}")
  platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                       FLAGS.model_dir, "model_dir")

  tf.io.gfile.makedirs(FLAGS.model_dir)

  if FLAGS.config_file:
    with tf.io.gfile.GFile(FLAGS.config_file, "r") as reader:
      config = json.load(reader)
  else:
    config = json.loads(FLAGS.config)
    # Only process 0 persists the inline config next to the model.
    if process_index == 0:
      config_path = os.path.join(FLAGS.model_dir, "config.json")
      with tf.io.gfile.GFile(config_path, "w") as writer:
        writer.write(json.dumps(config, indent=4))

  config["model_dir"] = FLAGS.model_dir

  # A flag that is explicitly set wins over the value from the config.
  for key in ("learning_rate", "per_device_batch_size", "num_train_steps",
              "warmup_steps"):
    flag_value = getattr(FLAGS, key)
    if flag_value is not None:
      config[key] = flag_value

  train(ml_collections.ConfigDict(config))
Esempio n. 26
0
def get_summary_writer(summary_dir: str) -> SummaryWriter:
    """Yield a summary writer and close it once the caller is done.

    JAX process 0 gets a real file writer for `summary_dir`. Every other
    process gets a no-op writer, which exposes the same interface but does
    not write events to disk.

    NOTE(review): this is a generator and no `contextlib.contextmanager`
    decorator is visible on the definition; confirm it is applied, since the
    `SummaryWriter` return annotation and `with`-style use depend on it.

    Args:
      summary_dir: directory the real writer emits event files to.

    Yields:
      The writer created for this process.
    """
    is_primary = jax.process_index() == 0
    if is_primary:
        logging.info('Opening SummaryWriter `%s`...', summary_dir)
        writer = tf_summary.create_file_writer(summary_dir)
    else:
        logging.info('Opening a mock-like SummaryWriter.')
        writer = tf_summary.create_noop_writer()

    try:
        yield writer
    finally:
        # Runs on normal exit and when the caller's block raises.
        writer.close()
        if is_primary:
            logging.info('Closed SummaryWriter `%s`.', summary_dir)
        else:
            logging.info('Closed a mock-like SummaryWriter.')
def _train_sentencepiece(dataset: tf.data.Dataset,
                         *,
                         vocab_size: int,
                         maxchars: int = int(1e7),
                         model_path: str,
                         model_type: str = 'unigram',
                         character_coverage: float = 1.0,
                         data_keys=('inputs', 'targets')):
    """Train a SentencePiece model on text dumped from a tf dataset.

    Every process runs the trainer locally. Process 0 then copies the
    resulting `.model` file to the target path, while all other processes
    poll until that path exists.

    Args:
      dataset: tf.dataset to draw training text from.
      vocab_size: int: size of vocab tokens to train.
      maxchars: int: number of characters to use for sentencepiece training.
      model_path: str: path of model file to save vocab model to. Paths
        starting with 'gs://' are used verbatim; anything else is
        user-expanded and made absolute.
      model_type: str: type of sentencepiece vocab to train.
      character_coverage: amount of characters covered by the model, good
        defaults are 0.9995 for languages with rich character set like
        Japanese or Chinese and 1.0 for other languages with small character
        set.
      data_keys: Tuple[str]: keys of dataset to use for training.

    Returns:
      path to the trained sentencepiece vocabulary model.
    """
    abs_model_path = (model_path if model_path.startswith('gs://') else
                      os.path.abspath(os.path.expanduser(model_path)))

    text_file, _ = _dump_chars_to_textfile(dataset,
                                           maxchars=maxchars,
                                           data_keys=data_keys)

    # The temp file itself is not used; it only reserves a unique name that
    # serves as the trainer's output prefix.
    with tempfile.NamedTemporaryFile(delete=False,
                                     prefix='/tmp/sp_tmp') as model_fp:
        model_prefix = model_fp.name

    trainer_flags = [
        f'--input={text_file}',
        f'--vocab_size={vocab_size}',
        f'--character_coverage={character_coverage}',
        f'--model_prefix={model_prefix}',
        f'--model_type={model_type}',
    ]
    SentencePieceTrainer.Train(' '.join(trainer_flags))

    if jax.process_index() != 0:
        # Wait for process 0 to publish the model, then give it one more
        # second before returning.
        while not tf.io.gfile.exists(abs_model_path):
            time.sleep(1)
        time.sleep(1)
        return abs_model_path

    # Copy to an intermediate name and rename it to the target, so create and
    # fill delays never expose a partially written file at the final path.
    trained_model = model_prefix + '.model'
    staging_path = abs_model_path + '.rntmp'
    tf.io.gfile.copy(trained_model, staging_path, overwrite=True)
    tf.io.gfile.rename(staging_path, abs_model_path, overwrite=True)
    logging.info('copied %s to %s', trained_model, abs_model_path)
    return abs_model_path
Esempio n. 28
0
def train_sentencepiece(dataset,
                        vocab_size,
                        maxchars=1e9,
                        character_coverage=1.0,
                        model_path="model",
                        model_type="unigram",
                        data_keys=("inputs", "targets")):
    """Train SentencePiece tokenizer from subset of tf dataset.

    Every process runs the trainer locally. Process 0 then publishes the
    `.vocab` and `.model` files next to `model_path`, while all other
    processes poll until `model_path` exists.

    Args:
      dataset: tf.dataset
      vocab_size: int: size of vocab tokens to train.
      maxchars: int: number of characters to use for sentencepiece training.
      character_coverage: amount of characters covered by the model, good
        defaults are 0.9995 for languages with rich character set like
        Japanese or Chinese and 1.0 for other languages with small character
        set.
      model_path: str: path of model file to save vocab model to. The
        vocabulary listing is written to `model_path + ".vocab"`.
      model_type: str: type of sentencepiece vocab to train.
      data_keys: Tuple[str]: keys of dataset to use for training.

    Returns:
      path to the trained sentencepiece vocabulary model.
    """
    fname, _ = dump_chars_to_textfile(dataset,
                                      maxchars=maxchars,
                                      data_keys=data_keys)
    with tempfile.NamedTemporaryFile(delete=False,
                                     prefix="/tmp/sp_tmp") as model_fp:
        pass  # we just want a prefix'd tmp-filename
    argstr = " ".join([
        f"--input={fname}", f"--vocab_size={vocab_size}",
        f"--character_coverage={character_coverage}",
        f"--model_prefix={model_fp.name}", f"--model_type={model_type}"
    ])
    SentencePieceTrainer.Train(argstr)
    if jax.process_index() == 0:
        # Use an intermediate filename that is renamed to the target name to
        # address create and fill delays.
        copy_rename_path = model_path + ".rntmp"
        # Publish the .vocab file first. The other processes are released by
        # the existence of `model_path`, so the model must be the last file
        # to appear; otherwise they could continue before the vocab exists.
        tf.io.gfile.copy(model_fp.name + ".vocab",
                         copy_rename_path + ".vocab",
                         overwrite=True)
        tf.io.gfile.rename(copy_rename_path + ".vocab",
                           model_path + ".vocab",
                           overwrite=True)
        tf.io.gfile.copy(model_fp.name + ".model",
                         copy_rename_path,
                         overwrite=True)
        tf.io.gfile.rename(copy_rename_path, model_path, overwrite=True)
        logging.info("copied %s to %s", model_fp.name + ".model", model_path)
    else:
        # Wait for process 0 to publish the model, then give it one more
        # second before returning.
        while not tf.io.gfile.exists(model_path):
            time.sleep(1)
        time.sleep(1)
    return model_path
Esempio n. 29
0
def main(argv: Sequence[str]) -> None:
    """Assemble the run configuration from flags and start generation.

    The configuration comes from the JSON file `FLAGS.config_file` when that
    flag is set, otherwise from the JSON string `FLAGS.config`; in the latter
    case process 0 also writes it to `<output_dir>/config.json`. A missing
    `num_total_memories` entry is filled in before `generate` is called.

    Args:
      argv: command-line arguments left over after flag parsing.

    Raises:
      app.UsageError: if any positional argument is supplied.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    # Keep TensorFlow away from the GPUs so it cannot reserve memory that JAX
    # needs.
    tf.config.experimental.set_visible_devices([], 'GPU')

    process_index = jax.process_index()
    process_count = jax.process_count()
    logging.info('JAX process: %d / %d', process_index, process_count)
    logging.info('JAX local devices: %r', jax.local_devices())
    logging.info('JAX total devices: %r', jax.device_count())

    # Record which task maps to which JAX host; depending on the platform,
    # task 0 is not guaranteed to be host 0.
    platform.work_unit().set_task_status(
        f'process_index: {process_index}, process_count: {process_count}')
    platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                         FLAGS.output_dir, 'output_dir')

    tf.io.gfile.makedirs(FLAGS.output_dir)

    if FLAGS.config_file:
        with tf.io.gfile.GFile(FLAGS.config_file, 'r') as reader:
            config = json.load(reader)
    else:
        config = json.loads(FLAGS.config)
        # Only process 0 persists the inline config to the output directory.
        if process_index == 0:
            config_path = os.path.join(FLAGS.output_dir, 'config.json')
            with tf.io.gfile.GFile(config_path, 'w') as writer:
                writer.write(json.dumps(config, indent=4))

    config['output_dir'] = FLAGS.output_dir

    # Derive the memory count only when the config does not pin it already.
    if 'num_total_memories' not in config:
        partial_config = ml_collections.ConfigDict(config)
        config['num_total_memories'] = get_num_total_memories(partial_config)

    generate(ml_collections.ConfigDict(config))
Esempio n. 30
0
def itstime(step,
            every_n_steps,
            total_steps,
            process=None,
            last=True,
            first=True):
    """Determines whether or not it is time to trigger an action.

    Args:
      step: current step.
      every_n_steps: period of the action. A falsy value (0 or None) disables
        the action entirely, including the `first` and `last` triggers.
      total_steps: step that counts as the last one.
      process: if not None, only the JAX process with this index triggers.
      last: also trigger when `step == total_steps`.
      first: also trigger when `step == 1`.

    Returns:
      True if the action should run at `step` on this process, else False.
      The result is always a `bool`, never a leaked `0` or `None`.
    """
    # A disabled schedule never fires; this also avoids a modulo by zero.
    if not every_n_steps:
        return False
    if process is not None and jax.process_index() != process:
        return False
    on_schedule = step % every_n_steps == 0
    at_last = last and step == total_steps
    at_first = first and step == 1
    return bool(on_schedule or at_last or at_first)