def main(_):
    master = jax.host_id() == 0
    # Make sure TF does not allocate GPU memory.
    tf.config.experimental.set_visible_devices([], 'GPU')

    # The pool is used to perform misc operations, such as logging, asynchronously.
    pool = multiprocessing.pool.ThreadPool()

    # Load hyperparameters from the config flag.
    hparams = FLAGS.config
    logging.info('=========== Hyperparameters ============')
    logging.info(hparams)

    if hparams.get('debug'):
        logging.warning('DEBUG MODE IS ENABLED!')

    # Set the TensorFlow random seed, offset by host id so each host differs.
    tf.random.set_seed(jax.host_id() + hparams.rng_seed)
    experiment_dir = FLAGS.experiment_dir
    logging.info('Experiment directory: %s', experiment_dir)
    summary_writer = None

    if master and hparams.write_summary:
        tensorboard_dir = os.path.join(experiment_dir, 'tb_summaries')
        gfile.makedirs(tensorboard_dir)
        summary_writer = tensorboard.SummaryWriter(tensorboard_dir)

    run(hparams, experiment_dir, summary_writer)

    pool.close()
    pool.join()
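A minimal launcher sketch for the main() above. The flag definitions are assumptions inferred from how the function reads FLAGS.config and FLAGS.experiment_dir; ml_collections' config_flags is a natural fit for the config flag.

from absl import app
from absl import flags
from ml_collections import config_flags

FLAGS = flags.FLAGS
flags.DEFINE_string('experiment_dir', '/tmp/exp', 'Where to write summaries.')
config_flags.DEFINE_config_file('config', None, 'Path to a config file.')

if __name__ == '__main__':
    app.run(main)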
Example #2
def create_split(dataset_builder: tfds.core.DatasetBuilder,
                 batch_size: int,
                 train: bool,
                 dtype: tf.DType = tf.float32,
                 image_size: int = IMAGE_SIZE,
                 cache: bool = False):
    """Creates a split from the ImageNet dataset using TensorFlow Datasets.

  Args:
    dataset_builder: TFDS dataset builder for ImageNet.
    batch_size: the batch size returned by the data pipeline.
    train: Whether to load the train or evaluation split.
    dtype: data type of the image (default: float32).
    image_size: The target size of the images (default: 224).
    cache: Whether to cache the dataset (default: False).
  Returns:
    A `tf.data.Dataset`.
  """
    if train:
        train_size = dataset_builder.info.splits['train'].num_examples
        split_size = train_size // jax.host_count()
        start = jax.host_id() * split_size
        split = 'train[{}:{}]'.format(start, start + split_size)
    else:
        validation_size = dataset_builder.info.splits[
            'validation'].num_examples
        split_size = validation_size // jax.host_count()
        start = jax.host_id() * split_size
        split = 'validation[{}:{}]'.format(start, start + split_size)

    def _decode_example(example):
        if train:
            image = preprocess_for_train(example['image'], dtype, image_size)
        else:
            image = preprocess_for_eval(example['image'], dtype, image_size)
        return {'image': image, 'label': example['label']}

    ds = dataset_builder.as_dataset(
        split=split, decoders={'image': tfds.decode.SkipDecoding()})
    # Mutating the object returned by ds.options() has no effect; build an
    # Options object and apply it with with_options() instead.
    options = tf.data.Options()
    options.experimental_threading.private_threadpool_size = 48
    options.experimental_threading.max_intra_op_parallelism = 1
    ds = ds.with_options(options)

    if cache:
        ds = ds.cache()

    if train:
        ds = ds.repeat()
        ds = ds.shuffle(16 * batch_size, seed=0)

    ds = ds.map(_decode_example,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)

    ds = ds.batch(batch_size, drop_remainder=True)

    if not train:
        ds = ds.repeat()

    ds = ds.prefetch(10)

    return ds
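A hypothetical usage sketch for create_split(); the builder must already be downloaded and prepared, and the batch size is illustrative. Each host reads only its own contiguous slice of the split.

builder = tfds.builder('imagenet2012')  # Assumes the data is already prepared.
train_ds = create_split(builder, batch_size=128, train=True)
batch = next(iter(tfds.as_numpy(train_ds)))
print(batch['image'].shape, batch['label'].shape)  # (128, 224, 224, 3) (128,)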
Example #3
def process_iterator(tag: str, item_ids: Sequence[str], iterator,
                     rng: types.PRNGKey, state: model_utils.TrainState,
                     step: int, render_fn: Any,
                     summary_writer: tensorboard.SummaryWriter,
                     save_dir: Optional[gpath.GPath],
                     datasource: datasets.DataSource):
    """Process a dataset iterator and compute metrics."""
    save_dir = save_dir / f'{step:08d}' / tag if save_dir else None
    meters = collections.defaultdict(utils.ValueMeter)
    for i, (item_id, batch) in enumerate(zip(item_ids, iterator)):
        logging.info('[%s:%d/%d] Processing %s', tag, i + 1, len(item_ids),
                     item_id)
        if tag == 'test':
            test_rng = random.PRNGKey(step)
            shape = batch['origins'][..., :1].shape
            metadata = {}
            if datasource.use_appearance_id:
                appearance_id = random.choice(
                    test_rng, jnp.asarray(datasource.appearance_ids))
                logging.info('\tUsing appearance_id = %d', appearance_id)
                metadata['appearance'] = jnp.full(shape,
                                                  fill_value=appearance_id,
                                                  dtype=jnp.uint32)
            if datasource.use_warp_id:
                warp_id = random.choice(test_rng,
                                        jnp.asarray(datasource.warp_ids))
                logging.info('\tUsing warp_id = %d', warp_id)
                metadata['warp'] = jnp.full(shape,
                                            fill_value=warp_id,
                                            dtype=jnp.uint32)
            if datasource.use_camera_id:
                camera_id = random.choice(test_rng,
                                          jnp.asarray(datasource.camera_ids))
                logging.info('\tUsing camera_id = %d', camera_id)
                metadata['camera'] = jnp.full(shape,
                                              fill_value=camera_id,
                                              dtype=jnp.uint32)
            batch['metadata'] = metadata

        stats = process_batch(batch=batch,
                              rng=rng,
                              state=state,
                              tag=tag,
                              item_id=item_id,
                              step=step,
                              render_fn=render_fn,
                              summary_writer=summary_writer,
                              save_dir=save_dir,
                              datasource=datasource)
        if jax.host_id() == 0:
            for k, v in stats.items():
                meters[k].update(v)

    if jax.host_id() == 0:
        for meter_name, meter in meters.items():
            summary_writer.scalar(tag=f'metrics-eval/{meter_name}/{tag}',
                                  value=meter.reduce('mean'),
                                  step=step)
Example #4
def load_split(batch_size,
               train,
               dtype=tf.float32,
               image_size=IMAGE_SIZE,
               cache=False):
    """Creates a split from the ImageNet dataset using TensorFlow Datasets.

  Args:
    batch_size: the batch size returned by the data pipeline.
    train: Whether to load the train or evaluation split.
    dtype: data type of the image.
    image_size: The target size of the images.
    cache: Whether to cache the dataset.
  Returns:
    A `tf.data.Dataset`.
  """
    if train:
        split_size = TRAIN_IMAGES // jax.host_count()
        start = jax.host_id() * split_size
        split = 'train[{}:{}]'.format(start, start + split_size)
    else:
        split_size = EVAL_IMAGES // jax.host_count()
        start = jax.host_id() * split_size
        split = 'validation[{}:{}]'.format(start, start + split_size)

    def decode_example(example):
        if train:
            image = preprocess_for_train(example['image'], dtype, image_size)
        else:
            image = preprocess_for_eval(example['image'], dtype, image_size)
        return {'image': image, 'label': example['label']}

    ds = tfds.load('imagenet2012:5.*.*',
                   split=split,
                   decoders={
                       'image': tfds.decode.SkipDecoding(),
                   })
    options = tf.data.Options()
    options.experimental_threading.private_threadpool_size = 48
    ds = ds.with_options(options)

    if cache:
        ds = ds.cache()

    if train:
        ds = ds.repeat()
        ds = ds.shuffle(16 * batch_size, seed=0)

    ds = ds.map(decode_example,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)
    ds = ds.batch(batch_size, drop_remainder=True)

    if not train:
        ds = ds.repeat()

    ds = ds.prefetch(10)

    return ds
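For intuition, a small sketch of the per-host split arithmetic used above; the constant matches the ImageNet train split, and the host count is an illustrative assumption.

TRAIN_IMAGES = 1281167
num_hosts = 8
split_size = TRAIN_IMAGES // num_hosts  # 160145 examples per host.
splits = ['train[{}:{}]'.format(h * split_size, (h + 1) * split_size)
          for h in range(num_hosts)]
print(splits[3])  # train[480435:640580]; the 7 leftover examples are dropped.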
Example #5
def eval_once(run_configuration, checkpoint_path, optimizer=None):
    """Evaluates a single checkpoint on a single epoch of data."""
    config = run_configuration.config
    run_dir = run_configuration.run_dir
    adapter = run_configuration.adapter
    optimizer = optimizer or adapter.create_optimizer(run_configuration)
    dataset = run_configuration.dataset_info.dataset
    info = run_configuration.dataset_info.info

    eval_name = config.eval_name or 'eval'
    log_dir = os.path.join(run_dir, eval_name)

    if jax.host_id() == 0:
        summary_writer = tensorboard.SummaryWriter(log_dir)

    # Restore checkpoint
    optimizer = checkpoint_utils.restore_checkpoint(checkpoint_path, optimizer)
    step = int(optimizer.state.step)

    # Replicate optimizer.
    optimizer = flax.jax_utils.replicate(optimizer)
    eval_step = adapter.make_eval_step()
    eval_step_parallel = jax.pmap(eval_step, axis_name='batch')

    # Perform evaluation
    tick = time.time()
    metrics_all = []

    example = None
    dataset_iter_raw = iter(dataset)
    dataset_iter = adapter.preprocess(dataset_iter_raw)
    for unused_eval_step, example in zip(range(config.eval_steps),
                                         dataset_iter):
        train_inputs = adapter.get_train_inputs(example)
        metrics, logits, state = eval_step_parallel(optimizer.target,
                                                    train_inputs)
        metrics_all.append(metrics)

    # Write results.
    metrics_all = common_utils.get_metrics(metrics_all)
    metrics_sums = jax.tree_map(jnp.sum, metrics_all)
    denominator = metrics_sums.pop('denominator')
    summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
    summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4)
    logging.info('eval @ train step: %d, loss: %.4f', step, summary['loss'])
    if jax.host_id() == 0:
        tock = time.time()
        steps_per_sec = len(metrics_all) / (tock - tick)
        examples_per_sec = denominator / (tock - tick)
        summary_writer.scalar('per-second/steps', steps_per_sec, step)
        summary_writer.scalar('per-second/examples', examples_per_sec, step)
        for key, val in summary.items():
            summary_writer.scalar(key, val, step)

        adapter.write_summaries(example, logits, summary_writer, info, step,
                                state)
        summary_writer.flush()
Example #6
def main(executable_dict, argv):
    del argv

    work_unit = platform.work_unit()
    tf.enable_v2_behavior()
    # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
    # it unavailable to JAX.
    tf.config.experimental.set_visible_devices([], 'GPU')

    logging.info('JAX host: %d / %d', jax.host_id(), jax.host_count())
    logging.info('JAX devices: %r', jax.devices())

    work_unit.set_task_status(
        f'host_id: {jax.host_id()}, host_count: {jax.host_count()}')

    # Read configuration
    if FLAGS.config_json:
        logging.info('Reading config from JSON: %s', FLAGS.config_json)
        with tf.io.gfile.GFile(FLAGS.config_json, 'r') as f:
            config = ml_collections.ConfigDict(json.loads(f.read()))
    else:
        config = FLAGS.config
    logging.info('config=%s',
                 config.to_json_best_effort(indent=4, sort_keys=True))

    # Make output directories
    if FLAGS.experiment_dir:
        work_unit.create_artifact(platform.ArtifactType.DIRECTORY,
                                  FLAGS.experiment_dir, 'experiment_dir')
    if FLAGS.work_unit_dir:
        work_unit.create_artifact(platform.ArtifactType.DIRECTORY,
                                  FLAGS.work_unit_dir, 'work_unit_dir')
    logging.info('experiment_dir=%s work_unit_dir=%s', FLAGS.experiment_dir,
                 FLAGS.work_unit_dir)

    # Seeding
    random.seed(config.seed * jax.host_count() + jax.host_id())
    onp.random.seed(config.seed * jax.host_count() + jax.host_id())
    rng = utils.RngGen(
        jax.random.fold_in(jax.random.PRNGKey(config.seed), jax.host_id()))

    # Run the main function
    logging.info('Running executable: %s', FLAGS.executable_name)

    extra_args = {}
    if FLAGS.extra_args_json_str:
        extra_args = json.loads(FLAGS.extra_args_json_str)
        logging.info('Extra args passed in: %r', extra_args)

    executable_dict[FLAGS.executable_name](config=config,
                                           experiment_dir=FLAGS.experiment_dir,
                                           work_unit_dir=FLAGS.work_unit_dir,
                                           rng=rng,
                                           **extra_args)

    utils.barrier()
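A hedged sketch of how executable_dict might be wired to this main(); the 'train' entry and its signature simply mirror the keyword arguments passed above.

import functools
from absl import app

def train_executable(config, experiment_dir, work_unit_dir, rng, **extra_args):
    del extra_args  # Unused in this sketch.
    logging.info('Training with seed %d in %s', config.seed, work_unit_dir)

executable_dict = {'train': train_executable}
# --executable_name=train would select the entry above on the command line.
if __name__ == '__main__':
    app.run(functools.partial(main, executable_dict))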
Example #7
def meta_init(loss_fn,
              model,
              hps,
              input_shape,
              output_shape,
              rng_key,
              metrics_logger=None,
              log_every=10):
    """Implements MetaInit initializer.

  Args:
    loss_fn: Loss function.
    model: Flax Model class.
    hps: HParam object. Required hparams are meta_learning_rate,
      meta_batch_size, meta_steps, and epsilon.
    input_shape: Must agree with batch[0].shape[1:].
    output_shape: Must agree with batch[1].shape[1:].
    rng_key: jax.PRNGKey, used to seed all randomness.
    metrics_logger: Instance of utils.MetricsLogger
    log_every: Print meta loss every k steps.

  Returns:
    A Flax model with the learned initialization.
  """
    # Pretty print the preinitialized norms with the variable shapes.
    if jax.host_id() == 0:
        logging.info('Preinitialized norms:')
        _log_shape_and_norms(model.params, metrics_logger, key='init_norms')
        logging.info('Running meta init')
    # First grab the norms of all weights, then rescale params to have norm 1.
    norms = jax.tree_map(lambda node: jnp.linalg.norm(node.reshape(-1)),
                         model.params)

    normalized_params = jax.tree_map(normalize, model.params)

    learned_norms, _ = meta_optimize_scales(loss_fn,
                                            model.module.call,
                                            normalized_params,
                                            norms,
                                            hps,
                                            input_shape,
                                            output_shape,
                                            rng_key,
                                            metrics_logger=metrics_logger,
                                            log_every=log_every)
    new_init = scale_params(normalized_params, learned_norms)

    if jax.host_id() == 0:
        # Pretty print the meta init norms with the variable shapes.
        logging.info('Learned norms from meta_init:')
        _log_shape_and_norms(new_init, metrics_logger, key='meta_init_norms')

    return nn.Model(model.module, new_init)
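Plausible sketches of the normalize and scale_params helpers referenced above; the actual definitions in the source repository may differ.

def normalize(node):
    # Rescale one weight tensor to unit Frobenius norm.
    return node / jnp.linalg.norm(node.reshape(-1))

def scale_params(params, norms):
    # Multiply each unit-norm tensor by its (learned) norm.
    return jax.tree_map(lambda p, n: p * n, params, norms)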
Example #8
  def compute_interpolations(self, model, gdirs, udirs, hvex, cvex, step):
    """Compute the linear interpolation along directions of gdirs or udirs."""
    row = {'step': step}
    if not self.eval_config['compute_interps']:
      return row
    lower = self.eval_config['lower_thresh']
    upper = self.eval_config['upper_thresh']
    num_points = self.eval_config['num_points']
    etas = np.linspace(lower, upper, num=num_points, endpoint=True)
    row['step_size'] = etas  # Keep the 'step' entry set above.
    for i, u_dir in enumerate(gdirs):
      u_dir = _tree_normalize(u_dir)
      loss_values = np.zeros(shape=(num_points,))
      for j in range(num_points):
        eta = etas[j]
        loss_values[j] = self._full_batch_eval(model, u_dir, eta)
      row['loss%d' % (i,)] = np.copy(loss_values)
    if jax.host_id() == 0:
      logging.info('Loss interpolation along gradients finished.')

    for i, u_dir in enumerate(udirs):
      u_dir = _tree_normalize(u_dir)
      loss_values = np.zeros(shape=(num_points,))
      for j in range(num_points):
        eta = etas[j]
        loss_values[j] = self._full_batch_eval(model, u_dir, eta)
      row['loss_u%d' % (i,)] = np.copy(loss_values)
    if jax.host_id() == 0:
      logging.info('Loss interpolation along optimizer directions finished.')

    _, unflatten = ravel_pytree(gdirs[0])
    for i, u_dir in enumerate(hvex):
      loss_values = np.zeros(shape=(num_points,))
      u_dir = unflatten(u_dir)
      for j in range(num_points):
        eta = etas[j]
        loss_values[j] = self._full_batch_eval(model, u_dir, eta)
      row['loss_hvec%d' % (i,)] = np.copy(loss_values)

    for i, u_dir in enumerate(cvex):
      loss_values = np.zeros(shape=(num_points,))
      u_dir = unflatten(u_dir)
      for j in range(num_points):
        eta = etas[j]
        loss_values[j] = self._full_batch_eval(model, u_dir, eta)
      row['loss_cvec%d' % (i,)] = np.copy(loss_values)

    if jax.host_id() == 0:
      logging.info('Loss interpolations finished. Statistics captured:')
      logging.info(row.keys())
    return row
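A plausible sketch of the _tree_normalize helper used above, normalizing a whole pytree to unit global l2 norm; the real helper may differ.

def _tree_normalize(tree):
  flat, unflatten = ravel_pytree(tree)
  return unflatten(flat / np.linalg.norm(flat))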
Example #9
def run_train_single_device(run_configuration):
    """Runs the training workflow without pmap or jit."""
    config = run_configuration.config
    run_dir = run_configuration.run_dir
    adapter = run_configuration.adapter
    checkpoint_path = run_configuration.original_checkpoint_path
    dataset = run_configuration.dataset_info.dataset

    random_seed = 0
    rng = jax.random.PRNGKey(random_seed)
    rng = jax.random.fold_in(rng, jax.host_id())
    dropout_rng, init_rng = jax.random.split(rng)

    # Set up optimizer.
    optimizer = adapter.create_optimizer(run_configuration, rng=init_rng)

    # Set up train step.
    train_step = adapter.make_train_step(single_device=True)

    # Set up checkpointing.
    # TODO(dbieber): Set up phoenix.
    checkpoint_dir = checkpoint_utils.build_checkpoint_dir(run_dir)
    if checkpoint_path is None:
        checkpoint_path = checkpoint_utils.latest_checkpoint(checkpoint_dir)
    optimizer = checkpoint_utils.handle_restart_behavior(
        checkpoint_path, optimizer, config)

    start_step = int(optimizer.state.step)
    num_train_steps = config.train.total_steps

    # Begin training loop.
    dataset_iter_raw = iter(dataset)
    dataset_iter = adapter.preprocess(dataset_iter_raw, single_device=True)

    for step, example in zip(range(start_step, num_train_steps), dataset_iter):
        print(f'Step #{step}')
        train_inputs = adapter.get_train_inputs(example)
        optimizer, metrics, dropout_rng, logits, state = train_step(
            optimizer, train_inputs, dropout_rng)
        del metrics, logits, state  # Unused.

        # Save a checkpoint. Guard on save_freq first so a value of 0 disables
        # checkpointing instead of raising a modulo-by-zero error.
        save_freq = config.logging.save_freq
        if save_freq and ((step % save_freq == 0 and step > 0)
                          or step == num_train_steps - 1):
            if jax.host_id() == 0:
                # Save unreplicated optimizer + model state.
                checkpoint_utils.save_checkpoint(checkpoint_dir, optimizer,
                                                 step)
Example #10
    def run_eval(self, flax_module, batch_stats, optimizer_state, global_step):
        """Computes the loss hessian and returns the max eigenvalue.

    Note: the full Lanczos tridiagonal matrix is saved via the logger to
    train_dir/checkpoints/config['name'].

    Args:
      flax_module: Replicated flax module.
      batch_stats: Replicated batch_stats from the trainer.
      optimizer_state: Replicated optimizer state from the trainer.
      global_step: Current training step.

    Returns:
      Max eigenvalue of the loss (full tridiag is saved to disk).
    """
        del batch_stats
        if self.callback_config.get('precondition'):
            precondition_config = self.callback_config.get(
                'precondition_config', default=FrozenConfigDict())
            diag_preconditioner = precondition.make_diag_preconditioner(
                self.hps.optimizer, self.hps.opt_hparams,
                jax_utils.unreplicate(optimizer_state), precondition_config)
        else:
            diag_preconditioner = None
        hessian_metrics, _, _ = self.hessian_evaluator.evaluate_spectrum(
            flax_module, global_step, diag_preconditioner=diag_preconditioner)
        if jax.host_id() == 0:
            self.logger.append_pytree(hessian_metrics)

        max_eig_key = self.name + '/max_eig'
        return {max_eig_key: hessian_metrics['max_eig_hess']}
Example #11
  def save_checkpoint(
      self,
      experiment_state: Mapping[str, jnp.ndarray],
      opt_state: Mapping[str, jnp.ndarray],
      step: int,
      extra_checkpoint_info: Optional[Mapping[str, Any]] = None) -> None:
    """Save checkpoint with experiment state and step information.

    Args:
     experiment_state: Experiment params to be stored.
     opt_state: Optimizer state to be stored.
     step: Training iteration step.
     extra_checkpoint_info: Extra information to be stored.
    """
    if jax.host_id() != 0:
      return

    checkpoint_data = dict(
        experiment_state=jax.tree_map(jax.device_get, experiment_state),
        opt_state=jax.tree_map(jax.device_get, opt_state),
        step=step)
    if extra_checkpoint_info is not None:
      for key in extra_checkpoint_info:
        checkpoint_data[key] = extra_checkpoint_info[key]

    with open(self._checkpoint_path, 'wb') as checkpoint_file:
      dill.dump(checkpoint_data, checkpoint_file, protocol=2)
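A matching restore sketch (an assumption; a restore method is not shown in the source) for the dill checkpoint written above:

  def restore_checkpoint(self):
    # Load the checkpoint written by save_checkpoint() and unpack its fields.
    with open(self._checkpoint_path, 'rb') as checkpoint_file:
      checkpoint_data = dill.load(checkpoint_file)
    return (checkpoint_data['experiment_state'],
            checkpoint_data['opt_state'],
            checkpoint_data['step'])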
Example #12
    def __init__(self,
                 multihost_base_directory: str,
                 tf_state: Optional[Dict[str, Any]] = None,
                 *,
                 host_id: Optional[int] = None,
                 max_to_keep: int = 5,
                 checkpoint_name: str = "ckpt"):
        """Initializes a MultihostCheckpoint with a dict of TensorFlow Trackables.

    Args:
      multihost_base_directory: Directory that will be used to construct a
        host-specific `base_directory` under which the checkpoints will be
        stored.
      tf_state: A dictionary of TensorFlow `Trackable` to be serialized, for
        example a dataset iterator.
      host_id: Host ID used to construct the `base_directory`. Taken from
        `jax.host_id()` if not specified.
      max_to_keep: Number of checkpoints to keep in the directory. If there are
        more checkpoints than specified by this number, then the oldest
        checkpoints are removed.
      checkpoint_name: Prefix of the checkpoint files (before `-{number}`).
    """
        if max_to_keep < 2:
            raise ValueError("Requires multiple checkpoints (max_to_keep>=2).")
        multihost_base_directory = multihost_base_directory.rstrip("/")
        self.multihost_base_directory = multihost_base_directory
        if host_id is None:
            host_id = jax.host_id()
        base_directory = f"{multihost_base_directory}-{host_id}"
        super().__init__(base_directory,
                         tf_state,
                         max_to_keep=max_to_keep,
                         checkpoint_name=checkpoint_name)
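Hypothetical construction of the class above; each host derives its own directory from the shared base path.

ckpt = MultihostCheckpoint('/tmp/ckpts', max_to_keep=5)
# Host 0 writes under '/tmp/ckpts-0', host 1 under '/tmp/ckpts-1', and so on.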
Example #13
def main(argv):
    del argv
    print('JAX host: %d / %d' % (jax.host_id(), jax.host_count()))
    print('JAX devices:\n%s' % '\n'.join(str(d) for d in jax.devices()),
          flush=True)
    experiment = Experiment()
    experiment.train_and_eval()
Example #14
  def __init__(self, dataset, tokenizer):
    self.tokenizer = tokenizer

    # Shard the train split here already to avoid unnecessary tokenization.
    # Sharding happens inside the dict check so a bare dataset also works.
    if isinstance(dataset, dict):
      dataset['train'] = dataset['train'].shard(jax.host_count(),
                                                jax.host_id())
      single_split = dataset['train']
    else:
      single_split = dataset

    name_a, *names_other = [
      name for name, feature in single_split.features.items()
      if feature.dtype=='string']
    assert len(names_other) <= 1, (
      'Only single sentences and sentence pairs allowed.')
    if names_other:
      name_b = names_other[0]
      tokenize = lambda example: self.tokenizer(
        example[name_a], example[name_b], truncation=True)
    else:
      tokenize = lambda example: self.tokenizer(
        example[name_a], truncation=True)
    
    mapped_dataset = dataset.map(tokenize, batched=True)
    mapped_dataset.set_format('numpy', columns=[
      'idx', 'input_ids', 'token_type_ids', 'attention_mask', 'label'])
    super().__init__(mapped_dataset)
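A hypothetical setup for the __init__ above, using HuggingFace datasets and transformers; the enclosing class is unnamed in the snippet, so the construction line is only indicative.

from datasets import load_dataset
from transformers import AutoTokenizer

dataset = load_dataset('glue', 'sst2')  # Features: 'idx', 'sentence', 'label'.
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
# The (unnamed) dataset class above would then be built as:
#   wrapped = SomeGlueDataset(dataset, tokenizer)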
Example #15
  def maybe_save_checkpoint(self,
                            experiment_state: Mapping[Text, jnp.ndarray],
                            step: int, rng: jnp.ndarray, is_final: bool):
    """Saves a checkpoint if enough time has passed since the previous one."""
    current_time = time.time()
    # Only checkpoint from the first worker, and only when checkpointing is
    # enabled and either this is the final step or enough time has elapsed.
    if (not self._checkpoint_enabled or jax.host_id() != 0 or
        (not is_final and
         current_time - self._last_checkpoint_time < self._checkpoint_every)):
      return
    checkpoint_data = dict(
        experiment_state=jax.tree_map(lambda x: jax.device_get(x[0]),
                                      experiment_state),
        step=step,
        rng=rng)
    # Write to a temporary file first so a crash never leaves a truncated
    # checkpoint, then swap it into place.
    with open(self._checkpoint_path + '_tmp', 'wb') as checkpoint_file:
      dill.dump(checkpoint_data, checkpoint_file, protocol=2)
    try:
      os.rename(self._checkpoint_path, self._checkpoint_path + '_old')
      remove_old = True
    except FileNotFoundError:
      remove_old = False  # No previous checkpoint to remove.
    os.rename(self._checkpoint_path + '_tmp', self._checkpoint_path)
    if remove_old:
      os.remove(self._checkpoint_path + '_old')
    self._last_checkpoint_time = current_time
Example #16
def prepare_batches_gen(dataset, eval_config):
  """Returns a data iterator.

  The API for the data iterator will be
    for b in batches_gen():
      pass
  We yield the same "epoch" every time the data iterator is called.

  Args:
    dataset: An init2winit.dataset_lib.Dataset object. This is ignored if
      eval_config['use_training_gen'] == False.
    eval_config: A dict specifying the parameters for the hessian eval.

  Returns:
    A data generator.
  """
  train_iter = itertools.islice(dataset.train_iterator_fn(), 0,
                                eval_config['num_batches'])
  batches = list(train_iter)
  init_rng = jax.random.PRNGKey(eval_config['rng_key'])
  init_rng = jax.random.fold_in(init_rng, jax.host_id())
  def training_batches_gen():
    for counter, batch in enumerate(batches):
      batch = data_utils.shard(batch)
      rng = jax.random.fold_in(init_rng, counter)
      rng = jax_utils.replicate(rng)
      yield (batch, rng)
  return training_batches_gen
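Hypothetical usage; the eval_config keys follow the ones read above, and the generator replays the same batches, with deterministic per-batch RNGs, on every call.

batches_gen = prepare_batches_gen(
    dataset, eval_config={'num_batches': 10, 'rng_key': 0})
for batch, rng in batches_gen():
  pass  # batch is sharded across local devices; rng is replicated to match.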
Example #17
def specialize_rng_host_device(rng,
                               axis_name,
                               mode="unique_host_unique_device"):
    """Specializes a rng to the host/device we are on.

  Must be called from within a pmapped function.

  Args:
    rng: a jax.random.PRNGKey.
    axis_name: the axis of the devices we are specializing across.
    mode: str mode. Must be one of "unique_host_unique_device",
      "unique_host_same_device", "same_host_unique_device",
      "same_host_same_device".
  Returns:
    jax.random.PRNGKey specialized to host/device.
  """
    # Will throw an error if mode is not a valid enumeration.
    enum_mode = DistributedRNGMode(mode)
    if enum_mode in [
            DistributedRNGMode.UNIQUE_HOST_UNIQUE_DEVICE,
            DistributedRNGMode.UNIQUE_HOST_SAME_DEVICE
    ]:
        rng = jax.random.fold_in(rng, jax.host_id())
    if enum_mode in [
            DistributedRNGMode.UNIQUE_HOST_UNIQUE_DEVICE,
            DistributedRNGMode.SAME_HOST_UNIQUE_DEVICE
    ]:
        rng = jax.random.fold_in(rng, jax.lax.axis_index(axis_name))
    return rng
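A sketch of calling specialize_rng_host_device() from inside jax.pmap; the axis_name must match the name given to pmap.

import functools

@functools.partial(jax.pmap, axis_name='batch')
def per_device_noise(rng):
    rng = specialize_rng_host_device(rng, 'batch',
                                     mode='unique_host_unique_device')
    return jax.random.normal(rng, ())

rngs = jnp.broadcast_to(jax.random.PRNGKey(0), (jax.local_device_count(), 2))
print(per_device_noise(rngs))  # Same key in, a different sample per device.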
Example #18
def _get_tfds_dataset(
        dataset: str,
        rng: np.ndarray) -> Tuple[tf.data.Dataset, tf.data.Dataset, int]:
    """Loads a TFDS dataset."""

    dataset_builder = tfds.builder(dataset)
    num_classes = 0
    if "label" in dataset_builder.info.features:
        num_classes = dataset_builder.info.features["label"].num_classes

    # Make sure each host uses a different RNG for the training data.
    rng, data_rng = jax.random.split(rng)
    data_rng = jax.random.fold_in(data_rng, jax.host_id())
    data_rng, shuffle_rng = jax.random.split(data_rng)
    train_split = deterministic_data.get_read_instruction_for_host(
        "train", dataset_builder.info.splits["train"].num_examples)
    train_read_config = tfds.ReadConfig(shuffle_seed=shuffle_rng[0])
    train_ds = dataset_builder.as_dataset(split=train_split,
                                          shuffle_files=True,
                                          read_config=train_read_config)

    eval_split_name = {
        "cifar10": "test",
        "imagenet2012": "validation"
    }.get(dataset, "test")

    eval_split_size = dataset_builder.info.splits[eval_split_name].num_examples
    eval_split = deterministic_data.get_read_instruction_for_host(
        eval_split_name, eval_split_size)
    eval_read_config = tfds.ReadConfig(shuffle_seed=shuffle_rng[1])
    eval_ds = dataset_builder.as_dataset(split=eval_split,
                                         shuffle_files=False,
                                         read_config=eval_read_config)
    return train_ds, eval_ds, num_classes
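Hypothetical usage; clu's deterministic_data hands each host a disjoint read instruction, so both the train and eval data are sharded across hosts.

train_ds, eval_ds, num_classes = _get_tfds_dataset(
    'cifar10', jax.random.PRNGKey(42))
print(num_classes)  # 10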
Example #19
def load_extra(batch_sizes: Sequence[int],
               path_npz: str,
               is_training: bool = True,
               drop_remainder: bool = True) -> tf.data.Dataset:
    """Loads extra data from a given path."""
    if not tf.io.gfile.exists(path_npz):
        if path_npz in _ALLOWED_FILES:
            path_npz = tf.keras.utils.get_file(path_npz, _DATA_URL + path_npz)
        else:
            raise ValueError(
                f'Extra data not found ({path_npz}). See {_WEBPAGE} for '
                'more details.')
    with tf.io.gfile.GFile(path_npz, 'rb') as fp:
        npzfile = np.load(fp)
        data = {'image': npzfile['image'], 'label': npzfile['label']}
        with tf.device('/device:cpu:0'):  # Keep allocation off the GPU.
            ds = tf.data.Dataset.from_tensor_slices(data)
    ds = ds.cache()
    if is_training:
        ds = ds.repeat()
        ds = ds.shuffle(buffer_size=50_000, seed=jax.host_id())
    ds = ds.map(cifar10_preprocess('train' if is_training else 'test'),
                num_parallel_calls=tf.data.AUTOTUNE)
    for batch_size in reversed(batch_sizes):
        ds = ds.batch(batch_size, drop_remainder=drop_remainder)
    return ds.prefetch(tf.data.AUTOTUNE)
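Hypothetical usage; the file name is illustrative, and the nested batch sizes yield arrays shaped (local_devices, per_device_batch, ...) ready for pmap.

ds = load_extra([jax.local_device_count(), 32], 'cifar10_extra.npz',
                is_training=True)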
Example #20
def parallel_write_images(image_write_fn, img_and_path_list):
    """Parallelizes image writing over JAX hosts and CPU cores.

  Args:
    image_write_fn: A function that takes a single (image, path) tuple and
      writes the image to disk.
    img_and_path_list: A list of (image, path) tuples containing all the images
      that should be written.
  """
    num_hosts = jax.host_count()
    host_id = jax.host_id()
    num_images = len(img_and_path_list)
    num_images_per_batch = math.ceil(num_images / num_hosts)

    # First shard the images onto each host.
    per_host_images_and_paths = []
    for i in range(num_images_per_batch):
        base_index = i * num_hosts
        global_index = base_index + host_id
        if global_index < num_images:
            per_host_images_and_paths.append(img_and_path_list[global_index])

    # Now within each JAX host, use multi-processing to save the sharded images.
    with multiprocessing.pool.ThreadPool() as pool:
        pool.map(image_write_fn, per_host_images_and_paths)
        pool.close()
        pool.join()
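A sketch of a writer callback matching the (image, path) tuples above; mediapy is an assumed image-writing helper, and the paths are illustrative.

import numpy as np
import mediapy as media

def write_image_fn(image_and_path):
    image, path = image_and_path
    media.write_image(path, image)

images = [np.zeros((8, 8, 3), np.uint8) for _ in range(4)]
parallel_write_images(
    write_image_fn,
    [(im, f'/tmp/im_{i}.png') for i, im in enumerate(images)])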
Example #21
def load_split(train: bool,
               cache: bool) -> tf.data.Dataset:
  """Creates a split from the ImageNet dataset using TensorFlow Datasets.

  Args:
    train: Whether to load the train or evaluation split.
    cache: Whether to cache the dataset.
  Returns:
    A `tf.data.Dataset`.
  """
  if train:
    split_size = TRAIN_IMAGES // jax.host_count()
    start = jax.host_id() * split_size
    split = 'train[{}:{}]'.format(start, start + split_size)
  else:
    # For validation, we load up the dataset on each host. This will have the
    # effect of evaluating on the whole dataset num_host times, but will
    # prevent size issues. This makes the performance slightly worse when
    # evaluating often, but spares us the need to pad the datasets and mask the
    # loss accordingly.
    split = 'validation'

  ds = tfds.load('imagenet2012:5.*.*', split=split, decoders={
      'image': tfds.decode.SkipDecoding(),
  })
  # Mutating the object returned by ds.options() has no effect; apply an
  # Options object with with_options() instead.
  options = tf.data.Options()
  options.experimental_threading.private_threadpool_size = 48
  options.experimental_threading.max_intra_op_parallelism = 1
  ds = ds.with_options(options)

  if cache:
    ds = ds.cache()

  return ds
Example #22
def main():
    args = parser.parse_args()
    logging.set_verbosity(logging.ERROR)
    print('JAX host: %d / %d' % (jax.host_id(), jax.host_count()))
    print('JAX devices:\n%s' % '\n'.join(str(d) for d in jax.devices()), flush=True)

    if get_model_cfg(args.model) is not None:
        validate(args)
    else:
        models = list_models(pretrained=True)
        if args.model != 'all':
            models = fnmatch.filter(models, args.model)
        if not models:
            print(f'ERROR: No models found to validate with pattern ({args.model}).')
            exit(1)

        print('Validating:', ', '.join(models))
        results = []
        for m in models:
            args.model = m
            res = validate(args)
            res.update(dict(model=m))
            results.append(res)
        print('Results:')
        for r in results:
            print(f"Model: {r['model']}, Top1: {r['top1']}, Top5: {r['top5']}")
Example #23
def test(optimizer, state, p_eval_step, step, test_ds, summary_writer):
    """Test the flax module in optimizer on test_ds.

  Args:
    optimizer: flax optimizer (contains flax module).
    state: model state, e.g. batch statistics.
    p_eval_step: fn; Pmapped evaluation step function.
    step: int; Number of training steps passed so far.
    test_ds: tf.data.Dataset; Test dataset.
    summary_writer: tensorflow summary writer.
  """
    # Test Metrics
    test_metrics = []
    test_iter = iter(test_ds)
    for test_batch in test_iter:
        # pylint: disable=protected-access
        test_batch = common_utils.shard(
            jax.tree_map(lambda x: x._numpy(), test_batch))
        # pylint: enable=protected-access
        metrics = p_eval_step(optimizer.target, state, test_batch)
        test_metrics.append(metrics)
    test_metrics = common_utils.get_metrics(test_metrics)
    test_metrics_sums = jax.tree_map(jnp.sum, test_metrics)
    test_denominator = test_metrics_sums.pop('denominator')
    test_summary = jax.tree_map(
        lambda x: x / test_denominator,  # pylint: disable=cell-var-from-loop
        test_metrics_sums)
    logging.info('test in step: %d, loss: %.4f, acc: %.4f', step,
                 test_summary['loss'], test_summary['accuracy'])
    if jax.host_id() == 0:
        for key, val in test_summary.items():
            summary_writer.scalar(f'test_{key}', val, step)
        summary_writer.flush()
Example #24
    def train(self):
        """Training loop."""

        master = jax.host_id() == 0
        train_metrics = []
        train_summary, eval_summary = None, None

        tick = time.time()

        # Main train loop.
        for step in range(self.start_step + 1, self.total_steps + 1):
            train_batch = self.get_next_batch(
                self.task.dataset.data_iters.train)
            self.train_state, t_metrics = self.pmapped_train_step(
                self.train_state, train_batch)
            train_metrics.append(t_metrics)

            eval_summary, train_metrics, train_summary, tick = self.maybe_eval_and_log(
                eval_summary, master, step, tick, train_metrics, train_summary)

            # Sync and save.
            self.train_state = self.checkpoint(self.train_state, step)

        # Wait until computations are done before exiting (for timing!).
        jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready()

        # Return the train and eval summaries after the last step for
        # regression testing.
        return train_summary, eval_summary
Example #25
    def _init_host_and_devices(self, n_devices=None, random_seed=None):
        """Initializes host and device attributes for this trainer.

    Args:
      n_devices: Number of devices this trainer will use. If `None`, get the
          number from the backend.
      random_seed: Random seed as the starting point for all random numbers used
          by the trainer. If `None`, calculate one from system time and host id.

    Returns:
      is_chief: True if this trainer has special chief responsibilities.
      n_devices: The passed in value of n_devices or a computed default.
      random_seed: The passed in value of random_seed or a computed default.
    """
        if math.backend_name() == 'jax':
            host_id = jax.host_id()
            host_count = jax.host_count()
        else:
            host_id = 0
            host_count = 1
        is_chief = (host_id == 0)

        device_count = math.device_count()
        n_devices = n_devices or device_count
        # TODO(lukaszkaiser): remove this restriction when possible.
        if n_devices != device_count and math.backend_name() == 'jax':
            raise ValueError(
                'JAX cannot work yet with n_devices != all devices: '
                '%d != %d' % (n_devices, device_count))

        if random_seed is None and host_count > 1:
            random_seed = int(1e6 * (host_id + time.time())) % 2**32
        return is_chief, n_devices, init_random_number_generators(random_seed)
Example #26
    def checkpoint(self, train_state, step):
        """Saves checkpoint.

    Syncs the model state across replicas if needed.

    Args:
      train_state: TrainState; A flax struct that keeps model state and
        optimizer state.
      step: int; Number of steps passed so far during training.

    Returns:
      train_state
    """
        checkpoint_flag = False
        if self.hparams.get('ckpnt_steps', None) and self.hparams.checkpoint:
            if step in self.hparams.get('ckpnt_steps'):
                checkpoint_flag = True
        elif ((step % self.checkpoint_frequency == 0) or
              (step == self.total_steps)) and self.hparams.checkpoint:
            checkpoint_flag = True

        if checkpoint_flag:
            # Sync model state across replicas.
            train_state = pipeline_utils.sync_model_state_across_replicas(
                train_state)
            if jax.host_id() == 0:
                pipeline_utils.save_checkpoint(self.experiment_dir,
                                               train_state,
                                               keep=self.hparams.keep_ckpts)

        return train_state
Example #27
def save_checkpoint(optimizer: flax.optim.Optimizer,
                    model_state: Any,
                    directory: str,
                    epoch: int):
  """Saves a model and its state.

  Removes a checkpoint if it already exists for a given epoch. For multi-host
  training, only the first host will save the checkpoint.

  Args:
    optimizer: The optimizer containing the model that we are training.
    model_state: Current state associated with the model.
    directory: Directory where the checkpoints should be saved.
    epoch: Number of epochs the model has been trained for.
  """
  if jax.host_id() != 0:
    return
  # Sync across replicas before saving.
  optimizer = jax.tree_map(lambda x: x[0], optimizer)
  model_state = jax.tree_map(lambda x: jnp.mean(x, axis=0), model_state)
  train_state = dict(optimizer=optimizer,
                     model_state=model_state,
                     epoch=epoch)
  if gfile.exists(os.path.join(directory, 'checkpoint_' + str(epoch))):
    gfile.remove(os.path.join(directory, 'checkpoint_' + str(epoch)))
  checkpoints.save_checkpoint(directory, train_state, epoch, keep=2)
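A matching restore sketch (assumed; not shown in the source), using flax.training checkpoints to load the latest saved train_state:

def restore_checkpoint(optimizer, model_state, directory):
  """Restores the latest checkpoint, if any, into the given structures."""
  train_state = dict(optimizer=optimizer, model_state=model_state, epoch=0)
  train_state = checkpoints.restore_checkpoint(directory, train_state)
  return (train_state['optimizer'], train_state['model_state'],
          train_state['epoch'])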
Example #28
def eval_loop(experiment_class, config):
    """The main evaluation loop.

  This loop periodically loads a checkpoint and evaluates its performance on
  the test set by calling experiment.evaluate.

  Args:
    experiment_class: the constructor for the experiment (either
      byol_experiment or eval_experiment).
    config: the experiment config.
  """
    experiment = experiment_class(**config)
    last_evaluated_step = -1
    while True:
        checkpoint_data = experiment.load_checkpoint()
        if checkpoint_data is None:
            logging.info('No checkpoint found. Waiting for 10s.')
            time.sleep(10)
            continue
        step, _ = checkpoint_data
        if step <= last_evaluated_step:
            logging.info('Checkpoint at step %d already evaluated, waiting.',
                         step)
            time.sleep(10)
            continue
        host_id = jax.host_id()
        local_device_count = jax.local_device_count()
        step_device = np.broadcast_to(step, [local_device_count])
        scalars = experiment.evaluate(global_step=step_device)
        if host_id == 0:  # Only perform logging in one host.
            logging.info('Evaluation at step %d: %s', step, scalars)
        last_evaluated_step = step
        if last_evaluated_step >= config['max_steps']:
            return
Example #29
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  if not gfile.IsDirectory(FLAGS.model_dir):
    gfile.MakeDirs(FLAGS.model_dir)

  logging.info('Number of recognized devices: %d', jax.local_device_count())
  logging.info('Import pretrained weights: %s', FLAGS.load_tf_weights)
  jax_squad_model = get_squad_model()

  with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
    input_meta_data = json.loads(reader.read().decode('utf-8'))

  step = 0
  optimizer = create_optimizer(jax_squad_model, FLAGS.learning_rate)
  if FLAGS.load_checkpoint:
    optimizer = checkpoints.restore_checkpoint(FLAGS.model_dir, optimizer)
    step = optimizer.state[0].step[0]

  if FLAGS.mode in ('train', 'train_and_predict'):
    optimizer = train_squad(optimizer, input_meta_data, step)

  if FLAGS.mode in ('predict', 'train_and_predict'):
    if not FLAGS.use_eval_sharding:
      optimizer = optimizer.unreplicate()
    if jax.host_id() == 0:
      predict_squad(optimizer, input_meta_data)

  global_barrier()
Example #30
    def train(self):
        """Training loop."""

        master = jax.host_id() == 0

        train_metrics = []
        train_summary, eval_summary = None, None
        tick = time.time()
        eval_env_ids = list(
            map(int, self.task.dataset.data_iters.validation.keys()))
        train_env_ids, train_iters = list(
            zip(*dict(self.task.dataset.data_iters['train']).items()))
        train_env_ids = list(map(int, train_env_ids))

        for step in range(self.start_step + 1, self.total_steps + 1):
            train_batches = self.get_next_batch(train_iters)
            self.train_state, t_metrics = self.pmapped_train_step(
                self.train_state, train_batches)

            t_metrics = jax.tree_map(lambda x: x[0], t_metrics)
            train_metrics.append(t_metrics)

            eval_summary, train_metrics, train_summary, tick = self.maybe_eval_and_log(
                eval_env_ids, eval_summary, master, step, tick, train_metrics,
                train_summary)

            # Sync and save
            self.train_state = self.checkpoint(self.train_state, step)

        # Wait until computations are done before exiting (for timing!).
        jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready()

        # Return the train and eval summaries after the last step for
        # regression testing.
        return train_summary, eval_summary