Example #1
0
    def test_every_steps(self, mock_time, mock_profiler):
        """Checks the steps at which the Profile hook starts and stops.

        The mocked profiler records the current loop step whenever its
        `start` / `stop` is invoked, and the mocked clock is driven by the
        step counter so the recorded steps can be compared to fixed lists.
        """
        started_at, stopped_at = [], []
        # Rebound by the loop below; the recorders read it at call time.
        step = 0

        def record_start(logdir):
            # `logdir` is accepted to match the mocked call but is not used.
            started_at.append(step)

        def record_stop():
            stopped_at.append(step)

        mock_profiler.start.side_effect = record_start
        mock_profiler.stop.side_effect = record_stop

        profile_kwargs = dict(
            logdir=tempfile.mkdtemp(),
            num_profile_steps=2,
            profile_duration_ms=2_000,
            first_profile=3,
            every_steps=7,
        )
        hook = periodic_actions.Profile(**profile_kwargs)

        for step in range(1, 18):
            # The fake clock equals the step number, except at step 9 where
            # it reads half a unit early.
            if step == 9:
                mock_time.return_value = step - 0.5
            else:
                mock_time.return_value = step
            hook(step)

        self.assertAllEqual([3, 7, 14], started_at)
        # Note: profiling 7..10 instead of 7..9 because 7..9 took only 1.5
        # seconds.
        self.assertAllEqual([5, 10, 16], stopped_at)
Example #2
0
    def test_every_steps(self, mock_profiler):
        """Checks the steps at which the Profile hook starts and stops.

        The mocked profiler records the current loop step whenever its
        `start` / `stop` is invoked; the recorded steps are then compared
        against fixed expectations.
        """
        started_at, stopped_at = [], []
        # Rebound by the loop below; the recorders read it at call time.
        step = 0

        def record_start():
            started_at.append(step)

        def record_stop():
            stopped_at.append(step)

        mock_profiler.start.side_effect = record_start
        mock_profiler.stop.side_effect = record_stop

        profile_kwargs = dict(num_profile_steps=2,
                              first_profile=3,
                              every_steps=7)
        hook = periodic_actions.Profile(**profile_kwargs)

        for step in range(1, 18):
            hook(step)

        self.assertAllEqual([3, 7, 14], started_at)
        self.assertAllEqual([5, 9, 16], stopped_at)
Example #3
0
def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):
  """Runs training interleaved with evaluation.

  Loads pretrained parameters from `config.pretrained_dir`, restores the
  latest checkpoint found in `workdir` (if any) and then runs the training
  loop, periodically reporting progress, evaluating on the test split and
  storing checkpoints.

  Args:
    config: Experiment configuration. Fields read here include `dataset`,
      `model`, `pretrained_dir`, `total_steps`, `base_lr`, `decay_type`,
      `warmup_steps`, `accum_steps`, `optim_dtype`, `grad_norm_clip`,
      `prefetch`, `batch`, `batch_eval`, `progress_every`, `eval_every` and
      `checkpoint_every`, plus the optional `model_type` and
      `model_or_filename`.
    workdir: Directory used for checkpoints, metric summaries and profiles.

  Returns:
    The unreplicated optimizer after the last training step.

  Raises:
    ValueError: If the pretrained checkpoint file does not exist.
  """

  # Setup input pipeline
  dataset_info = input_pipeline.get_dataset_info(config.dataset, 'train')

  ds_train, ds_test = input_pipeline.get_datasets(config)
  # One batch is pulled only to learn the input shape/dtype for model init.
  batch = next(iter(ds_train))
  logging.info(ds_train)
  logging.info(ds_test)

  # Build VisionTransformer architecture
  model_cls = {'ViT': models.VisionTransformer,
               'Mixer': models.MlpMixer}[config.get('model_type', 'ViT')]
  model = model_cls(num_classes=dataset_info['num_classes'], **config.model)

  def init_model():
    return model.init(
        jax.random.PRNGKey(0),
        # Discard the "num_local_devices" dimension for initialization.
        jnp.ones(batch['image'].shape[1:], batch['image'].dtype.name),
        train=False)

  # Use JIT to make sure params reside in CPU memory.
  variables = jax.jit(init_model, backend='cpu')()

  model_or_filename = config.get('model_or_filename')
  if model_or_filename:
    # Loading model from repo published with  "How to train your ViT? Data,
    # Augmentation, and Regularization in Vision Transformers" paper.
    # https://arxiv.org/abs/2106.10270
    if '-' in model_or_filename:
      # A value containing '-' is treated as a full checkpoint filename.
      filename = model_or_filename
    else:
      # Select best checkpoint from i21k pretraining by final upstream
      # validation accuracy.
      df = checkpoint.get_augreg_df(directory=config.pretrained_dir)
      sel = df.filename.apply(
          lambda filename: filename.split('-')[0] == model_or_filename)
      best = df.loc[sel].query('ds=="i21k"').sort_values('final_val').iloc[-1]
      filename = best.filename
      logging.info('Selected fillename="%s" for "%s" with final_val=%.3f',
                   filename, model_or_filename, best.final_val)
  else:
    # ViT / Mixer papers
    filename = config.model.name

  # The path is derived from `filename` for every branch above.
  pretrained_path = os.path.join(config.pretrained_dir, f'{filename}.npz')
  if not tf.io.gfile.exists(pretrained_path):
    raise ValueError(
        f'Could not find "{pretrained_path}" - you can download models from '
        '"gs://vit_models/imagenet21k" or directly set '
        '--config.pretrained_dir="gs://vit_models/imagenet21k".')
  params = checkpoint.load_pretrained(
      pretrained_path=pretrained_path,
      init_params=variables['params'],
      model_config=config.model)

  total_steps = config.total_steps
  lr_fn = utils.create_learning_rate_schedule(total_steps, config.base_lr,
                                              config.decay_type,
                                              config.warmup_steps)

  update_fn_repl = make_update_fn(
      apply_fn=model.apply, accum_steps=config.accum_steps, lr_fn=lr_fn)
  infer_fn_repl = jax.pmap(functools.partial(model.apply, train=False))

  # Create optimizer and replicate it over all TPUs/GPUs
  opt = momentum_clip.Optimizer(
      dtype=config.optim_dtype,
      grad_norm_clip=config.grad_norm_clip).create(params)

  # `(opt, initial_step)` is the restore target; the values set here are the
  # starting point when `workdir` holds no checkpoint.
  initial_step = 1
  opt, initial_step = flax_checkpoints.restore_checkpoint(
      workdir, (opt, initial_step))
  logging.info('Will start/continue training at initial_step=%d', initial_step)

  opt_repl = flax.jax_utils.replicate(opt)

  # Delete references to the objects that are not needed anymore
  del opt
  del params

  # Replicate the RNG key that is threaded through the update function.
  update_rng_repl = flax.jax_utils.replicate(jax.random.PRNGKey(0))

  # Setup metric writer & hooks.
  writer = metric_writers.create_default_writer(workdir, asynchronous=False)
  writer.write_hparams(config.to_dict())
  hooks = [
      periodic_actions.Profile(logdir=workdir),
      periodic_actions.ReportProgress(
          num_train_steps=total_steps, writer=writer),
  ]

  # Run training loop
  logging.info('Starting training loop; initial compile can take a while...')
  # `t0` anchors the ETA estimate; `lt0`/`lstep` anchor the throughput
  # measurement since the last report.
  t0 = lt0 = time.time()
  lstep = initial_step
  for step, batch in zip(
      range(initial_step, total_steps + 1),
      input_pipeline.prefetch(ds_train, config.prefetch)):

    # `StepTraceAnnotation` is the current name of the deprecated
    # `StepTraceContext`.
    with jax.profiler.StepTraceAnnotation('train', step_num=step):
      opt_repl, loss_repl, update_rng_repl = update_fn_repl(
          opt_repl, flax.jax_utils.replicate(step), batch, update_rng_repl)

    for hook in hooks:
      hook(step)

    if step == initial_step:
      # Restart the timers after the first step so that its duration is
      # excluded from the throughput and ETA numbers.
      logging.info('First step took %.1f seconds.', time.time() - t0)
      t0 = time.time()
      lt0, lstep = time.time(), step

    # Report training metrics
    if config.progress_every and step % config.progress_every == 0:
      img_sec_core_train = (config.batch * (step - lstep) /
                            (time.time() - lt0)) / jax.device_count()
      lt0, lstep = time.time(), step
      writer.write_scalars(
          step,
          dict(
              train_loss=float(flax.jax_utils.unreplicate(loss_repl)),
              img_sec_core_train=img_sec_core_train))
      done = step / total_steps
      logging.info(f'Step: {step}/{total_steps} {100*done:.1f}%, '  # pylint: disable=logging-format-interpolation
                   f'img/sec/core: {img_sec_core_train:.1f}, '
                   f'ETA: {(time.time()-t0)/done*(1-done)/3600:.2f}h')

    # Run evaluation
    if ((config.eval_every and step % config.eval_every == 0) or
        (step == total_steps)):

      accuracies = []
      lt0 = time.time()
      for test_batch in input_pipeline.prefetch(ds_test, config.prefetch):
        logits = infer_fn_repl(
            dict(params=opt_repl.target), test_batch['image'])
        accuracies.append(
            (np.argmax(logits,
                       axis=-1) == np.argmax(test_batch['label'],
                                             axis=-1)).mean())
      accuracy_test = np.mean(accuracies)
      img_sec_core_test = (
          config.batch_eval * ds_test.cardinality().numpy() /
          (time.time() - lt0) / jax.device_count())
      lt0 = time.time()

      lr = float(lr_fn(step))
      logging.info(f'Step: {step} '  # pylint: disable=logging-format-interpolation
                   f'Learning rate: {lr:.7f}, '
                   f'Test accuracy: {accuracy_test:0.5f}, '
                   f'img/sec/core: {img_sec_core_test:.1f}')
      writer.write_scalars(
          step,
          dict(
              accuracy_test=accuracy_test,
              lr=lr,
              img_sec_core_test=img_sec_core_test))

    # Store checkpoint. The period is `checkpoint_every`; the final step is
    # always stored.
    if ((config.checkpoint_every and
         step % config.checkpoint_every == 0) or step == total_steps):
      checkpoint_path = flax_checkpoints.save_checkpoint(
          workdir, (flax.jax_utils.unreplicate(opt_repl), step), step)
      logging.info('Stored checkpoint at step %d to "%s"', step,
                   checkpoint_path)

  return flax.jax_utils.unreplicate(opt_repl)
Example #4
0
def train_and_evaluate(config, workdir):
    """Runs a training and evaluation loop.

    The configuration is dumped to `<workdir>/config.json` by process 0, and
    checkpoints are kept under `<workdir>/checkpoints`. Metrics are written
    through a CLU metric writer rooted at `workdir`.

    Args:
      config: Configuration to use.
      workdir: Working directory for checkpoints and TF summaries. If this
        contains checkpoint training will be resumed from the latest
        checkpoint.
    """
    logging.info('Starting training at %s', workdir)
    tf.io.gfile.makedirs(workdir)
    if jax.process_index() == 0:
        with tf.io.gfile.GFile(os.path.join(workdir, 'config.json'), 'w') as f:
            json.dump(config.to_dict(), f, indent=2)
    rng = jax.random.PRNGKey(config.seed)

    # Build input pipeline.
    rng, data_rng = jax.random.split(rng)
    # Fold in the process index so every process gets its own data RNG.
    data_rng = jax.random.fold_in(data_rng, jax.process_index())
    train_ds, eval_ds = input_pipeline.create_datasets(config.dataset,
                                                       data_rng)
    train_iter = iter(train_ds)

    # NOTE(review): `test_ds` is built here but not used again in this
    # function - confirm whether it is still needed.
    test_ds = []
    for split in config.dataset.test_splits:
        ds = input_pipeline.create_val_dataset(
            config.dataset, split, config.dataset.test_per_device_batch_size,
            config.dataset.test_pad_last_batch)
        test_ds.append(ds)

    # Learning rate schedule.
    num_train_steps = config.num_train_steps
    if num_train_steps == -1:
        # -1 means "train for the full length of the training dataset".
        num_train_steps = train_ds.cardinality().numpy()
    steps_per_epoch = num_train_steps // config.dataset.num_epochs
    logging.info('num_train_steps=%d, steps_per_epoch=%d', num_train_steps,
                 steps_per_epoch)
    learning_rate_fn = functools.partial(
        train_utils.get_learning_rate,
        base_learning_rate=config.learning_rate,
        num_train_steps=num_train_steps,
        schedule_type=config.learning_rate_schedule,
        warmup_proportion=config.warmup_proportion,
        step_boundaries=config.learning_rate_step_boundaries)

    # Initialize model.
    inputs = train_utils.get_init_inputs(train_ds)
    rng, model_rng = jax.random.split(rng)
    eval_config = models.TransformerConfig(**config.model.to_dict())
    train_config = eval_config.replace(deterministic=False)
    model = models.Model(eval_config)
    state = train_utils.create_train_state(model,
                                           config,
                                           model_rng,
                                           inputs=inputs)

    # Set up checkpointing of the model and the input pipeline.
    checkpoint_dir = os.path.join(workdir, 'checkpoints')
    ckpt = checkpoint.MultihostCheckpoint(checkpoint_dir, max_to_keep=3)
    state = ckpt.restore_or_initialize(state)
    # Training resumes at the step after the one stored in the state.
    initial_step = int(state.step) + 1

    # Distribute training.
    state = flax_utils.replicate(state)
    p_train_step = jax.pmap(functools.partial(
        train_step,
        config=train_config,
        learning_rate_fn=learning_rate_fn,
        grad_clip=config.grad_clip),
                            axis_name='batch',
                            donate_argnums=(0, ))
    p_eval_step = jax.pmap(functools.partial(eval_step, config=eval_config),
                           axis_name='batch')

    writer = metric_writers.create_default_writer(
        workdir, just_logging=jax.process_index() > 0)
    if initial_step == 1:
        # Hyper-parameters are only written for a fresh (non-resumed) run.
        writer.write_hparams(train_utils.flatten_config(config))

    logging.info('Starting training loop at step %d.', initial_step)
    hooks = []
    # `report_progress` is created on every process because its `timed`
    # context is used below, but it only runs as a hook on process 0.
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=num_train_steps, writer=writer)
    if jax.process_index() == 0:
        hooks += [
            report_progress,
            periodic_actions.Profile(
                num_profile_steps=config.num_profile_steps, logdir=workdir)
        ]

    rng, train_rngs = jax.random.split(rng)
    train_rngs = jax.random.fold_in(train_rngs, jax.process_index())
    train_rngs = jax.random.split(train_rngs, jax.local_device_count())

    train_metrics = []
    with metric_writers.ensure_flushes(writer):
        for step in range(initial_step, num_train_steps + 1):
            is_last_step = step == num_train_steps
            # `StepTraceAnnotation` is the current name of the deprecated
            # `StepTraceContext`.
            with jax.profiler.StepTraceAnnotation('train', step_num=step):
                # `jax.tree_util.tree_map` replaces the deprecated top-level
                # `jax.tree_map` alias.
                batch = jax.tree_util.tree_map(np.asarray, next(train_iter))
                state, metrics = p_train_step(batch=batch,
                                              rng=train_rngs,
                                              state=state)
                train_metrics.append(metrics)

            # Quick indication that training is happening.
            logging.log_first_n(logging.INFO, 'Finished training step %d.', 5,
                                step)
            for h in hooks:
                h(step)

            if config.log_loss_every_steps > 0 and (
                    step % config.log_loss_every_steps == 0 or is_last_step):
                train_metrics = common_utils.get_metrics(train_metrics)
                lr = train_metrics.pop('learning_rate').mean()
                train_summary = train_utils.metrics_summary(
                    train_metrics, 'train')
                train_summary['learning_rate'] = lr
                writer.write_scalars(step, train_summary)
                # Start accumulating metrics for the next logging window.
                train_metrics = []

            if config.eval_every_steps > 0 and (step % config.eval_every_steps
                                                == 0 or is_last_step):
                with report_progress.timed('eval'):
                    eval_summary = evaluate(p_eval_step, state, eval_ds,
                                            config.num_eval_steps)
                writer.write_scalars(step, eval_summary)

            if config.checkpoint_every_steps > 0 and (
                    step % config.checkpoint_every_steps == 0 or is_last_step):
                with report_progress.timed('checkpoint'):
                    ckpt.save(flax_utils.unreplicate(state))
                logging.info('Checkpoint saved to %s', checkpoint_dir)

    logging.info('Finishing training at step %d', num_train_steps)
Example #5
0
def main(_):
    """Trains the coarse (and optionally fine) model and renders outputs.

    Reads everything from `FLAGS`: loads the datasets, initializes the
    model(s) and optimizer, restores a checkpoint from `FLAGS.model_dir` if
    one exists, then runs the training loop. Inside the loop it periodically
    writes training metrics, evaluates a validation image, renders videos,
    renders the test set and saves checkpoints, at the periods given by
    `FLAGS.config.i_print`, `i_img`, `i_video`, `i_testset` and `i_weights`.

    Raises:
      ValueError: If `precrop_iters > 0` is combined with `batching`.
      AssertionError: If `down_factor` or `render_factor` is not positive.
    """
    if FLAGS.config.precrop_iters > 0 and FLAGS.config.batching:
        raise ValueError(
            "'precrop_iters has no effect when 'batching' the dataset")
    assert FLAGS.config.down_factor > 0 and FLAGS.config.render_factor > 0

    # `jax.process_count` replaces the deprecated `jax.host_count`.
    logging.info("JAX host: %d / %d", jax.process_index(),
                 jax.process_count())
    logging.info("JAX local devices: %r", jax.local_devices())

    platform.work_unit().set_task_status(
        f"host_id: {jax.process_index()}, host_count: {jax.process_count()}")
    platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                         FLAGS.model_dir, "model_dir")

    os.makedirs(FLAGS.model_dir, exist_ok=True)
    rng = jax.random.PRNGKey(FLAGS.seed)
    rng, rng_coarse, rng_fine, data_rng, step_rng = jax.random.split(rng, 5)
    rngs = common_utils.shard_prng_key(step_rng)

    ### Load dataset and data values
    datasets, counts, optics, render_datasets = get_dataset(
        FLAGS.data_dir,
        FLAGS.config,
        rng=data_rng,
        num_poses=FLAGS.config.num_poses)
    train_ds, val_ds, test_ds = datasets
    *_, test_items = counts
    hwf, r_hwf, near, far = optics
    render_ds, render_vdirs_ds, num_poses = render_datasets
    # NOTE: the bounded `zip(range(n), ds)` iterators over the render/test
    # datasets are built at their use sites inside the loop. A zip object is
    # single-use, so building them once here would leave every render after
    # the first with nothing to iterate.
    img_h, img_w, _ = hwf

    logging.info("Num poses: %d", num_poses)
    logging.info("Splits: train - %d, val - %d, test - %d", *counts)
    logging.info("Images: height %d, width %d, focal %.5f", *hwf)
    logging.info("Render: height %d, width %d, focal %.5f", *r_hwf)

    ### Init model parameters and optimizer
    initialized_ = functools.partial(initialized,
                                     model_config=FLAGS.config.model)
    pts_shape = (FLAGS.config.num_rand, FLAGS.config.num_samples, 3)
    views_shape = (FLAGS.config.num_rand, 3)
    model_coarse, params_coarse = initialized_(rng_coarse, pts_shape,
                                               views_shape)

    schedule_fn = optax.exponential_decay(
        init_value=FLAGS.config.learning_rate,
        transition_steps=FLAGS.config.lr_decay * 1000,
        decay_rate=FLAGS.config.decay_factor,
    )
    tx = optax.adam(learning_rate=schedule_fn)
    # Coarse-only state; replaced below when a fine model is requested.
    state = train_state.TrainState.create(apply_fn=(model_coarse.apply, None),
                                          params={"coarse": params_coarse},
                                          tx=tx)

    if FLAGS.config.num_importance > 0:
        pts_shape = (
            FLAGS.config.num_rand,
            FLAGS.config.num_importance + FLAGS.config.num_samples,
            3,
        )
        model_fine, params_fine = initialized_(rng_fine, pts_shape,
                                               views_shape)
        state = train_state.TrainState.create(
            apply_fn=(model_coarse.apply, model_fine.apply),
            params={
                "coarse": params_coarse,
                "fine": params_fine
            },
            tx=tx,
        )

    state = checkpoints.restore_checkpoint(FLAGS.model_dir, state)
    start_step = int(state.step)

    # cycle already seen examples if resuming from checkpoint
    # (only useful for ensuring deterministic dataset, slow for large start_step)
    # if start_step != 0:
    #     for _ in range(start_step):
    #         _ = next(train_ds)

    # parameter_overview.log_parameter_overview(state.optimizer_coarse.target)
    # if FLAGS.config.num_importance > 0:
    #     parameter_overview.log_parameter_overview(state.optimizer_fine.target)

    state = jax.device_put_replicated(state, jax.local_devices())

    ### Build "pmapped" functions for distributed training
    train_fn = functools.partial(train_step, near, far, FLAGS.config,
                                 schedule_fn)
    p_train_step = jax.pmap(
        train_fn,
        axis_name="batch",
        # `coords` (third argument) is broadcast rather than sharded.
        in_axes=(0, 0, None, 0),
        # donate_argnums=(0, 1, 2),
    )

    def render_fn(state, rays):
        step_fn = functools.partial(eval_step, FLAGS.config, near, far, state)
        return lax.map(step_fn, rays)

    p_eval_step = jax.pmap(
        render_fn,
        axis_name="batch",
        # in_axes=(0, 0, None),
        # donate_argnums=(0, 1))
    )

    # TODO: add hparams
    writer = metric_writers.create_default_writer(
        FLAGS.model_dir, just_logging=jax.process_index() > 0)
    logging.info("Starting training loop.")

    hooks = []
    profiler = periodic_actions.Profile(num_profile_steps=5,
                                        logdir=FLAGS.model_dir)
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=FLAGS.config.num_steps, writer=writer)
    if jax.process_index() == 0:
        hooks += [profiler, report_progress]
    train_metrics = []
    gen_video_ = functools.partial(gen_video, FLAGS.model_dir)

    for step in range(start_step, FLAGS.config.num_steps + 1):
        is_last_step = step == FLAGS.config.num_steps

        batch = next(train_ds)
        coords = None
        if not FLAGS.config.batching:
            coords = jnp.meshgrid(jnp.arange(img_h),
                                  jnp.arange(img_w),
                                  indexing="ij")
            if step < FLAGS.config.precrop_iters:
                # During the first `precrop_iters` steps only a centered
                # window, scaled by `precrop_frac`, is used.
                dH = int(img_h // 2 * FLAGS.config.precrop_frac)
                dW = int(img_w // 2 * FLAGS.config.precrop_frac)
                coords = jnp.meshgrid(
                    jnp.arange(img_h // 2 - dH, img_h // 2 + dH),
                    jnp.arange(img_w // 2 - dW, img_w // 2 + dW),
                    indexing="ij",
                )
            coords = jnp.stack(coords, axis=-1).reshape([-1, 2])

        with jax.profiler.StepTraceAnnotation("train", step_num=step):
            state, metrics = p_train_step(batch, state, coords, rngs)
        train_metrics.append(metrics)

        logging.log_first_n(logging.INFO, "Finished training step %d.", 5,
                            step)
        for h in hooks:
            h(step)

        ### Write train summaries to TB
        if step % FLAGS.config.i_print == 0 or is_last_step:
            with report_progress.timed("training_metrics"):
                train_metrics = common_utils.get_metrics(train_metrics)
                train_summary = jax.tree_util.tree_map(lambda x: x.mean(),
                                                       train_metrics)
                summary = {f"train/{k}": v for k, v in train_summary.items()}
                writer.write_scalars(step, summary)
            train_metrics = []

        ### Eval a random validation image and plot it to TB
        if (step % FLAGS.config.i_img == 0 and step > 0) or is_last_step:
            with report_progress.timed("validation"):
                inputs = next(val_ds)
                # Use the public `.numpy()` like the other render paths.
                rays, padding = prepare_render_data(inputs["rays"].numpy())
                outputs = p_eval_step(state, rays)
                preds, preds_c, z_std = jax.tree_util.tree_map(
                    lambda x: to_np(x, hwf, padding), outputs)
                loss = np.mean((preds["rgb"] - inputs["image"])**2)
                summary = {"val/loss": loss, "val/psnr": psnr_fn(loss)}
                writer.write_scalars(step, summary)

                summary = {
                    "val/rgb": to_rgb(preds["rgb"]),
                    "val/target": to_np(inputs["image"], hwf, padding),
                    "val/disp": disp_post(preds["disp"], FLAGS.config),
                    "val/acc": preds["acc"],
                }
                if FLAGS.config.num_importance > 0:
                    summary["val/rgb_c"] = to_rgb(preds_c["rgb"])
                    summary["val/disp_c"] = disp_post(preds_c["disp"],
                                                      FLAGS.config)
                    summary["val/z_std"] = z_std
                writer.write_images(step, summary)

        ### Render a video with test poses
        if step % FLAGS.config.i_video == 0 and step > 0:
            with report_progress.timed("video_render"):
                logging.info("Rendering video at step %d", step)
                rgb_list = []
                disp_list = []
                # Fresh bounded iterator for this render pass.
                iter_render_ds = zip(range(num_poses), render_ds)
                for idx, inputs in tqdm(iter_render_ds, desc="Rays render"):
                    rays, padding = prepare_render_data(inputs["rays"].numpy())
                    preds, *_ = p_eval_step(state, rays)
                    preds = jax.tree_util.tree_map(
                        lambda x: to_np(x, r_hwf, padding), preds)
                    rgb_list.append(preds["rgb"])
                    disp_list.append(preds["disp"])

                gen_video_(np.stack(rgb_list), "rgb", r_hwf, step)
                disp = np.stack(disp_list)
                gen_video_(disp_post(disp, FLAGS.config),
                           "disp",
                           r_hwf,
                           step,
                           ch=1)

                if FLAGS.config.use_viewdirs:
                    rgb_list = []
                    iter_vdirs_ds = zip(range(num_poses), render_vdirs_ds)
                    for idx, inputs in tqdm(iter_vdirs_ds,
                                            desc="Viewdirs render"):
                        rays, padding = prepare_render_data(
                            inputs["rays"].numpy())
                        preds, *_ = p_eval_step(state, rays)
                        rgb_list.append(to_np(preds["rgb"], r_hwf, padding))
                    gen_video_(np.stack(rgb_list), "rgb_still", r_hwf, step)

        ### Save images in the test set
        if step % FLAGS.config.i_testset == 0 and step > 0:
            with report_progress.timed("test_render"):
                logging.info("Rendering test set at step %d", step)
                test_losses = []
                iter_test_ds = zip(range(test_items), test_ds)
                for idx, inputs in tqdm(iter_test_ds, desc="Test render"):
                    rays, padding = prepare_render_data(inputs["rays"].numpy())
                    preds, *_ = p_eval_step(state, rays)
                    save_test_imgs(FLAGS.model_dir, preds["rgb"], r_hwf, step,
                                   idx)

                    # NOTE(review): the assert at the top of this function
                    # requires render_factor > 0, so these branches only run
                    # when assertions are disabled - confirm intent.
                    if FLAGS.config.render_factor == 0:
                        loss = np.mean((preds["rgb"] - inputs["image"])**2.0)
                        test_losses.append(loss)
                if FLAGS.config.render_factor == 0:
                    loss = np.mean(test_losses)
                    summary = {"test/loss": loss, "test/psnr": psnr_fn(loss)}
                    writer.write_scalars(step, summary)
        writer.flush()

        ### Save ckpt
        if step % FLAGS.config.i_weights == 0 or is_last_step:
            with report_progress.timed("checkpoint"):
                save_checkpoint(state, FLAGS.model_dir)
Example #6
0
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    # Make sure tf does not allocate gpu memory.
    tf.config.experimental.set_visible_devices([], 'GPU')

    if FLAGS.jax_backend_target:
        jax.config.FLAGS.jax_xla_backend = 'tpu_driver'
        jax.config.FLAGS.jax_backend_target = FLAGS.jax_backend_target

    # Number of local devices for this host.
    n_devices = jax.local_device_count()

    if jax.process_index() == 0:
        tf.io.gfile.makedirs(FLAGS.model_dir)

    if FLAGS.batch_size % n_devices:
        raise ValueError(
            'Batch size must be divisible by the number of devices')

    vocab_path = FLAGS.vocab_path
    if vocab_path is None:
        vocab_path = os.path.join(FLAGS.model_dir, 'sentencepiece_model')
    tf.io.gfile.makedirs(os.path.split(vocab_path)[0])

    # Load Dataset
    # ---------------------------------------------------------------------------
    logging.info('Initializing dataset.')
    train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets(
        dataset_name=FLAGS.dataset_name,
        eval_dataset_name=FLAGS.eval_dataset_name,
        shard_idx=jax.process_index(),
        shard_count=jax.process_count(),
        data_dir=FLAGS.data_dir,
        vocab_path=vocab_path,
        target_vocab_size=FLAGS.vocab_size,
        batch_size=FLAGS.batch_size,
        max_length=FLAGS.max_target_length,
        max_eval_length=FLAGS.max_eval_target_length,
        paracrawl_size=FLAGS.paracrawl_size,
        is_scores_path=FLAGS.is_scores_path,
        num_to_keep=FLAGS.data_selection_size,
        pseudo_path=FLAGS.pseudo_path,
        repeat_count=FLAGS.repeat_count,
        newscommentary_size=FLAGS.newscommentary_size,
        split_tokenizer=FLAGS.split_tokenizer)

    if FLAGS.aux_eval_dataset:
        aux_datasets = []
        aux_names = FLAGS.aux_eval_dataset.split(',')
        for name in aux_names:
            _, aux_eval_ds, _, _ = input_pipeline.get_wmt_datasets(
                dataset_name=name,
                eval_dataset_name=None,
                shard_idx=jax.process_index(),
                shard_count=jax.process_count(),
                data_dir=FLAGS.data_dir,
                vocab_path=vocab_path,
                target_vocab_size=FLAGS.vocab_size,
                batch_size=FLAGS.batch_size,
                max_length=FLAGS.max_target_length,
                max_eval_length=FLAGS.max_eval_target_length,
                paracrawl_size=FLAGS.paracrawl_size,
                is_scores_path=FLAGS.is_scores_path,
                num_to_keep=FLAGS.data_selection_size,
                pseudo_path=FLAGS.pseudo_path,
                repeat_count=FLAGS.repeat_count,
                newscommentary_size=FLAGS.newscommentary_size)
            aux_datasets.append(aux_eval_ds)

    train_iter = iter(train_ds)
    vocab_size = int(encoder.vocab_size())
    eos_id = decode.EOS_ID  # Default Sentencepiece EOS token.

    def decode_tokens(toks):
        valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32)
        return encoder.detokenize(valid_toks).numpy().decode('utf-8')

    logging.info('Initializing model, optimizer, and step functions.')

    # Build Model and Optimizer
    # ---------------------------------------------------------------------------
    train_config = models.TransformerConfig(
        vocab_size=vocab_size,
        output_vocab_size=vocab_size,
        share_embeddings=FLAGS.share_embeddings,
        logits_via_embedding=FLAGS.logits_via_embedding,
        dtype=jnp.bfloat16 if FLAGS.use_bfloat16 else jnp.float32,
        emb_dim=FLAGS.emb_dim,
        num_heads=FLAGS.num_heads,
        num_layers=FLAGS.num_layers,
        qkv_dim=FLAGS.qkv_dim,
        mlp_dim=FLAGS.mlp_dim,
        max_len=max(FLAGS.max_target_length, FLAGS.max_eval_target_length),
        dropout_rate=FLAGS.dropout_rate,
        attention_dropout_rate=FLAGS.attention_dropout_rate,
        deterministic=False,
        decode=False,
        kernel_init=nn.initializers.xavier_uniform(),
        bias_init=nn.initializers.normal(stddev=1e-6))
    eval_config = train_config.replace(deterministic=True)
    predict_config = train_config.replace(deterministic=True, decode=True)

    start_step = 0
    rng = jax.random.PRNGKey(FLAGS.random_seed)
    rng, init_rng = jax.random.split(rng)
    # It's possible that is supposed to be per device batch size
    input_shape = (FLAGS.batch_size, FLAGS.max_target_length)
    target_shape = (FLAGS.batch_size, FLAGS.max_target_length)

    m = models.Transformer(eval_config)
    initial_variables = jax.jit(m.init)(init_rng,
                                        jnp.ones(input_shape, jnp.float32),
                                        jnp.ones(target_shape, jnp.float32))

    # apply an optimizer to this tree
    optimizer_def = optim.Adam(FLAGS.learning_rate,
                               beta1=0.9,
                               beta2=0.98,
                               eps=1e-9,
                               weight_decay=FLAGS.weight_decay)
    optimizer = optimizer_def.create(initial_variables['params'])

    # We access model params only from optimizer below via optimizer.target.
    del initial_variables

    if FLAGS.restore_checkpoints:
        logging.info('Restoring checkpoint.')
        # If we have a pretrained model, use that. Else, just continue where leftoff
        model_path = FLAGS.pretrained_model_dir if FLAGS.pretrained_model_dir else FLAGS.model_dir
        optimizer = checkpoints.restore_checkpoint(model_path, optimizer)
        # Grab last step.
        start_step = int(optimizer.state.step)

    writer = metric_writers.create_default_writer(
        FLAGS.model_dir, just_logging=jax.process_index() > 0)

    flag_key = [
        k for k in FLAGS.flags_by_module_dict().keys() if 'wmt.par' in k
    ]
    if flag_key:
        flag_key = flag_key[0]
        local_flags = {
            f.name: f.value
            for f in FLAGS.flags_by_module_dict()[flag_key]
        }
        writer.write_hparams(local_flags)

    # Replicate optimizer.
    optimizer = jax_utils.replicate(optimizer)

    learning_rate_fn = common.create_learning_rate_scheduler(
        base_learning_rate=FLAGS.learning_rate,
        warmup_steps=FLAGS.warmup_steps,
        steps_per_cycle=FLAGS.steps_per_cycle,
        init_step=start_step,
        finetune_lr=FLAGS.finetune_lr)

    # compile multidevice versions of train/eval/predict step and cache init fn.
    p_train_step = jax.pmap(functools.partial(
        train_util.train_step,
        config=train_config,
        learning_rate_fn=learning_rate_fn,
        label_smoothing=FLAGS.label_smoothing),
                            axis_name='batch',
                            donate_argnums=(0, ))  # pytype: disable=wrong-arg-types
    p_eval_step = jax.pmap(functools.partial(train_util.eval_step,
                                             config=eval_config),
                           axis_name='batch')
    p_init_cache = jax.pmap(functools.partial(
        train_util.initialize_cache,
        max_decode_len=FLAGS.max_predict_length,
        config=predict_config),
                            axis_name='batch')
    p_pred_step = jax.pmap(
        functools.partial(train_util.predict_step,
                          config=predict_config,
                          beam_size=FLAGS.beam_size),
        axis_name='batch',
        static_broadcasted_argnums=(3,
                                    4))  # eos token, max_length are constant

    # Main Train Loop
    # ---------------------------------------------------------------------------

    # We init the first set of dropout PRNG keys, but update it afterwards inside
    # the main pmap'd training update for performance.
    dropout_rngs = jax.random.split(rng, jax.local_device_count())
    del rng

    logging.info('Starting training loop.')
    hooks = []
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=FLAGS.num_train_steps, writer=writer)
    if jax.process_index() == 0:
        hooks += [
            report_progress,
            periodic_actions.Profile(logdir=FLAGS.model_dir,
                                     num_profile_steps=5)
        ]
    train_metrics = []
    total_steps = start_step + FLAGS.num_train_steps
    if FLAGS.eval_only:
        total_steps = start_step + 1
    best_eval_loss = 1000
    curr_eval_loss = 1000
    eval_loss_history = []
    last_eval_step = 0
    do_resample_data = False
    gradual_selection_size = FLAGS.data_selection_size
    dynamic_eval_freq = FLAGS.eval_frequency
    with metric_writers.ensure_flushes(writer):
        for step in range(start_step, total_steps):
            is_last_step = step == total_steps - 1

            # Resample training data for gradual FT
            if do_resample_data:
                # resample data
                do_resample_data = False
                gradual_selection_size *= .7
                dynamic_eval_freq = int(gradual_selection_size / 1000 / 4)

                train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets(
                    dataset_name=FLAGS.dataset_name,
                    eval_dataset_name=FLAGS.eval_dataset_name,
                    shard_idx=jax.process_index(),
                    shard_count=jax.process_count(),
                    data_dir=FLAGS.data_dir,
                    vocab_path=vocab_path,
                    target_vocab_size=FLAGS.vocab_size,
                    batch_size=FLAGS.batch_size,
                    max_length=FLAGS.max_target_length,
                    max_eval_length=FLAGS.max_eval_target_length,
                    paracrawl_size=FLAGS.paracrawl_size,
                    is_scores_path=FLAGS.is_scores_path,
                    num_to_keep=int(gradual_selection_size),
                    pseudo_path=FLAGS.pseudo_path,
                    repeat_count=FLAGS.repeat_count,
                    newscommentary_size=FLAGS.newscommentary_size,
                    split_tokenizer=FLAGS.split_tokenizer)
                train_iter = iter(train_ds)

            # Shard data to devices and do a training step.
            if not FLAGS.eval_only:
                logging.info('Doing Training.')
                with jax.profiler.StepTraceAnnotation('train', step_num=step):
                    try:
                        batch = common_utils.shard(
                            jax.tree_map(np.asarray, next(train_iter)))
                        optimizer, metrics = p_train_step(
                            optimizer, batch, dropout_rng=dropout_rngs)
                        train_metrics.append(metrics)
                    except StopIteration:
                        is_last_step = True

            # Quick indication that training is happening.
            logging.log_first_n(logging.INFO, 'Finished training step %d.', 5,
                                step)
            for h in hooks:
                h(step)

            # Periodic metric handling.
            if (step - start_step) % dynamic_eval_freq == 0 or is_last_step:
                if not FLAGS.eval_only:
                    with report_progress.timed('training_metrics'):
                        logging.info('Gathering training metrics.')
                        train_metrics = common_utils.get_metrics(train_metrics)
                        lr = train_metrics.pop('learning_rate').mean()
                        metrics_sums = jax.tree_map(jnp.sum, train_metrics)
                        denominator = metrics_sums.pop('denominator')
                        summary = jax.tree_map(lambda x: x / denominator,
                                               metrics_sums)  # pylint: disable=cell-var-from-loop
                        summary['learning_rate'] = lr
                        summary = {'train_' + k: v for k, v in summary.items()}
                        writer.write_scalars(step, summary)
                        train_metrics = []

                if FLAGS.eval_only:
                    p_eval_per_pos_step = jax.pmap(functools.partial(
                        train_util.eval_per_pos_step, config=eval_config),
                                                   axis_name='batch')
                    # Get per example loss
                    loss_filename = FLAGS.model_dir + '/test_losses.csv'
                    train_util.write_per_example_losses(
                        p_eval_step=p_eval_per_pos_step,
                        target=optimizer.target,
                        eval_ds=eval_ds,
                        num_eval_steps=FLAGS.num_eval_steps,
                        loss_filename=loss_filename)
                else:
                    with report_progress.timed('eval'):
                        eval_results = train_util.evaluate(
                            p_eval_step=p_eval_step,
                            target=optimizer.target,
                            eval_ds=eval_ds,
                            num_eval_steps=FLAGS.num_eval_steps)
                        curr_eval_loss = eval_results['loss']
                        eval_loss_history.append(curr_eval_loss)
                        if len(eval_loss_history) > 1:
                            improvement_rate = 0.000004
                            orig_loss = eval_loss_history[-2]
                            true_improvement = orig_loss - curr_eval_loss
                            expected_improvement = (
                                step - last_eval_step) * improvement_rate
                            # percent_change = (orig_loss - curr_eval_loss) / orig_loss
                            # percent_change *= 100
                            if true_improvement < expected_improvement:  # percent_change<.1:
                                do_resample_data = True
                        last_eval_step = step
                        writer.write_scalars(
                            step,
                            {'eval_' + k: v
                             for k, v in eval_results.items()})

                if FLAGS.aux_eval_dataset:
                    for aux_i, aux_eval_ds in enumerate(aux_datasets):
                        with report_progress.timed('aux_eval'):
                            eval_results = train_util.evaluate(
                                p_eval_step=p_eval_step,
                                target=optimizer.target,
                                eval_ds=aux_eval_ds,
                                num_eval_steps=FLAGS.num_eval_steps)
                            writer.write_scalars(
                                step, {
                                    'aux' + str(aux_i) + '_eval_' + k: v
                                    for k, v in eval_results.items()
                                })

                if FLAGS.compute_bleu:
                    with report_progress.timed('translate_and_bleu'):
                        decode_file = FLAGS.model_dir + '/decodes.csv'
                        exemplars, bleu_score = train_util.translate_and_calculate_bleu(
                            p_pred_step=p_pred_step,
                            p_init_cache=p_init_cache,
                            target=optimizer.target,
                            predict_ds=predict_ds,
                            decode_tokens=decode_tokens,
                            max_predict_length=FLAGS.max_predict_length,
                            num_eval_steps=FLAGS.num_eval_steps,
                            decode_file=decode_file if FLAGS.eval_only else '')
                        writer.write_scalars(step, {'bleu': bleu_score})
                        writer.write_texts(step, {'samples': exemplars})

            # Save a checkpoint on one host after every checkpoint_freq steps.
            save_checkpoint = ((step - start_step) % FLAGS.checkpoint_freq == 0
                               or is_last_step)
            if FLAGS.save_checkpoints and save_checkpoint and jax.process_index(
            ) == 0:
                if curr_eval_loss < best_eval_loss:  # only save better checkpoints
                    best_eval_loss = curr_eval_loss
                    with report_progress.timed('checkpoint'):
                        checkpoints.save_checkpoint(
                            FLAGS.model_dir,
                            jax_utils.unreplicate(optimizer),
                            step,
                            keep=FLAGS.chkpts_to_keep,
                            overwrite=True)

            if is_last_step:
                break
Example #7
0
def train_and_evaluate(config, workdir):
    """Execute model training and evaluation loop.

    Loads the dataset, builds (or restores from a checkpoint) the train state,
    logs metrics once before training, and then runs the training loop up to
    `config.num_training_steps`, evaluating and checkpointing periodically.
    On evaluation steps the privacy accountant is consulted, and the loop
    exits early once the reported epsilon reaches the configured maximum.

    Args:
      config: Hyperparameter configuration for training and evaluation.
      workdir: Directory where the tensorboard summaries are written to.
        Checkpoints go to its `checkpoints` subdirectory and it is also used
        as the profiler log directory.

    Returns:
      The train state (which includes the `.params`).
    """
    # Seed for reproducibility.
    rng = jax.random.PRNGKey(config.rng_seed)

    # Set up logging.
    summary_writer = metric_writers.create_default_writer(workdir)
    summary_writer.write_hparams(dict(config))

    # Get datasets.
    rng, dataset_rng = jax.random.split(rng)
    dataset = input_pipeline.get_dataset(config, dataset_rng)
    graph, labels, masks = jax.tree_map(jnp.asarray, dataset)
    # One-hot encode the labels over `config.num_classes` classes.
    labels = jax.nn.one_hot(labels, config.num_classes)
    train_mask = masks['train']
    # Positions at which the training mask is set.
    train_indices = jnp.where(train_mask)[0]
    train_labels = labels[train_indices]
    num_training_nodes = len(train_indices)

    # Get subgraphs.
    if config.differentially_private_training:
        # The graph is converted to NumPy arrays for subgraph extraction and
        # back to JAX arrays afterwards.
        # NOTE(review): presumably get_subgraphs expects host arrays - confirm.
        graph = jax.tree_map(np.asarray, graph)
        subgraphs = get_subgraphs(graph, pad_to=config.pad_subgraphs_to)
        graph = jax.tree_map(jnp.asarray, graph)

        # We only need the subgraphs for training nodes.
        train_subgraphs = subgraphs[train_indices]
        del subgraphs
    else:
        train_subgraphs = None

    # Initialize privacy accountant.
    training_privacy_accountant = privacy_accountants.get_training_privacy_accountant(
        config, num_training_nodes, compute_max_terms_per_node(config))

    # Construct and initialize model.
    rng, init_rng = jax.random.split(rng)
    estimation_indices = get_estimation_indices(train_indices, config)
    state = create_train_state(init_rng, config, graph, train_labels,
                               train_subgraphs, estimation_indices)

    # Set up checkpointing of the model.
    checkpoint_dir = os.path.join(workdir, 'checkpoints')
    ckpt = checkpoint.Checkpoint(checkpoint_dir, max_to_keep=2)
    state = ckpt.restore_or_initialize(state)
    # Training resumes at the step following the one recorded in the
    # (possibly restored) state.
    initial_step = int(state.step) + 1

    # Log overview of parameters.
    parameter_overview.log_parameter_overview(state.params)

    # Log metrics after initialization.
    logits = compute_logits(state, graph)
    metrics_after_init = compute_metrics(logits, labels, masks)
    # Epsilon is reported as 0 here, before any training step has run in this
    # call.
    metrics_after_init['epsilon'] = 0
    log_metrics(0, metrics_after_init, summary_writer, postfix='_after_init')

    # Train model.
    rng, train_rng = jax.random.split(rng)
    max_training_epsilon = get_max_training_epsilon(config)

    # Hooks called periodically during training.
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=config.num_training_steps, writer=summary_writer)
    profiler = periodic_actions.Profile(num_profile_steps=5, logdir=workdir)
    hooks = [report_progress, profiler]

    for step in range(initial_step, config.num_training_steps):

        # Perform one step of training.
        with jax.profiler.StepTraceAnnotation('train', step_num=step):
            # Sample batch.
            # The key is derived from `train_rng` and the step number, so the
            # sampled batch for a given step does not depend on how many
            # steps were executed before it in this process.
            step_rng = jax.random.fold_in(train_rng, step)
            indices = jax.random.choice(step_rng, num_training_nodes,
                                        (config.batch_size, ))

            # Compute gradients.
            if config.differentially_private_training:
                grads = compute_updates_for_dp(state, graph, train_labels,
                                               train_subgraphs, indices,
                                               config.adjacency_normalization)
            else:
                grads = compute_updates(state, graph, train_labels, indices)

            # Update parameters.
            state = update_model(state, grads)

        # Quick indication that training is happening.
        logging.log_first_n(logging.INFO, 'Finished training step %d.', 10,
                            step)
        # Every hook is invoked with the current step after each update.
        for hook in hooks:
            hook(step)

        # Evaluate, if required.
        is_last_step = (step == config.num_training_steps - 1)
        if step % config.evaluate_every_steps == 0 or is_last_step:
            with report_progress.timed('eval'):
                # Check if privacy budget exhausted.
                # NOTE(review): `step + 1` is presumably the number of
                # training steps taken so far - confirm against the accountant.
                training_epsilon = training_privacy_accountant(step + 1)
                if max_training_epsilon is not None and training_epsilon >= max_training_epsilon:
                    # Leaves the training loop immediately: metrics for this
                    # step are not logged and no checkpoint is written below.
                    break

                # Compute metrics.
                logits = compute_logits(state, graph)
                metrics_during_training = compute_metrics(
                    logits, labels, masks)
                metrics_during_training['epsilon'] = training_epsilon
                log_metrics(step, metrics_during_training, summary_writer)

        # Checkpoint, if required.
        if step % config.checkpoint_every_steps == 0 or is_last_step:
            with report_progress.timed('checkpoint'):
                ckpt.save(state)

    return state
Example #8
0
def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):
  """Runs a training and evaluation loop.

  Builds the WMT input pipeline and a Transformer model, optionally restores
  the optimizer from `workdir`, and then trains up to
  `config.num_train_steps`. Every `config.eval_every_steps` steps (and on the
  last step) it writes training metrics, evaluation metrics, a BLEU score and
  sample translations. Checkpoints are written every
  `config.checkpoint_every_steps` steps (and on the last step) when
  `config.save_checkpoints` is set.

  Args:
    config: Configuration to use. `config.vocab_path` is filled in (mutated)
      when it is None.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains checkpoint training will be resumed from the latest checkpoint.
  """
  tf.io.gfile.makedirs(workdir)

  # Default the vocabulary location to a file inside `workdir` and record the
  # choice back on the config.
  vocab_path = config.vocab_path
  if vocab_path is None:
    vocab_path = os.path.join(workdir, "sentencepiece_model")
    config.vocab_path = vocab_path
  # Make sure the directory that will hold the vocabulary file exists.
  tf.io.gfile.makedirs(os.path.split(vocab_path)[0])

  # Load Dataset
  # ---------------------------------------------------------------------------
  logging.info("Initializing dataset.")
  train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets(
      n_devices=jax.local_device_count(),
      config=config,
      reverse_translation=config.reverse_translation,
      vocab_path=vocab_path)

  train_iter = iter(train_ds)
  vocab_size = int(encoder.vocab_size())
  eos_id = decode.EOS_ID  # Default Sentencepiece EOS token.

  def decode_tokens(toks):
    """Detokenizes `toks` up to and including the first EOS token.

    NOTE(review): assumes `toks` is a NumPy array. If it contains no EOS
    token, `np.argmax` over the all-False mask is 0 and only the first token
    is kept - confirm that this is the intended fallback.
    """
    valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32)
    return encoder.detokenize(valid_toks).numpy().decode("utf-8")

  # Optionally cap the prediction dataset at `num_predict_steps` elements.
  if config.num_predict_steps > 0:
    predict_ds = predict_ds.take(config.num_predict_steps)

  logging.info("Initializing model, optimizer, and step functions.")

  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  train_config = models.TransformerConfig(
      vocab_size=vocab_size,
      output_vocab_size=vocab_size,
      share_embeddings=config.share_embeddings,
      logits_via_embedding=config.logits_via_embedding,
      dtype=jnp.bfloat16 if config.use_bfloat16 else jnp.float32,
      emb_dim=config.emb_dim,
      num_heads=config.num_heads,
      num_layers=config.num_layers,
      qkv_dim=config.qkv_dim,
      mlp_dim=config.mlp_dim,
      max_len=max(config.max_target_length, config.max_eval_target_length),
      dropout_rate=config.dropout_rate,
      attention_dropout_rate=config.attention_dropout_rate,
      deterministic=False,
      decode=False,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6))
  # Evaluation and prediction reuse the training configuration with
  # `deterministic=True`; prediction additionally sets `decode=True`.
  eval_config = train_config.replace(deterministic=True)
  predict_config = train_config.replace(deterministic=True, decode=True)

  start_step = 0
  rng = jax.random.PRNGKey(config.seed)
  rng, init_rng = jax.random.split(rng)
  input_shape = (config.per_device_batch_size, config.max_target_length)
  target_shape = (config.per_device_batch_size, config.max_target_length)

  # Parameters are initialized through the deterministic (eval) configuration
  # using all-ones dummy inputs of the per-device batch shape.
  m = models.Transformer(eval_config)
  initial_variables = jax.jit(m.init)(init_rng,
                                      jnp.ones(input_shape, jnp.float32),
                                      jnp.ones(target_shape, jnp.float32))

  # apply an optimizer to this tree
  optimizer_def = optim.Adam(
      config.learning_rate,
      beta1=0.9,
      beta2=0.98,
      eps=1e-9,
      weight_decay=config.weight_decay)
  optimizer = optimizer_def.create(initial_variables["params"])

  # We access model params only from optimizer below via optimizer.target.
  del initial_variables

  if config.restore_checkpoints:
    # Restore unreplicated optimizer + model state from last checkpoint.
    optimizer = checkpoints.restore_checkpoint(workdir, optimizer)
    # Grab last step.
    start_step = int(optimizer.state.step)

  # NOTE(review): `jax.host_id()` is deprecated in newer JAX releases in
  # favour of `jax.process_index()`; confirm the targeted JAX version before
  # switching.
  writer = metric_writers.create_default_writer(
      workdir, just_logging=jax.host_id() > 0)
  # Hyperparameters are only written when starting from step 0, i.e. not
  # when resuming from a checkpoint with a non-zero step.
  if start_step == 0:
    writer.write_hparams(dict(config))

  # Replicate optimizer.
  optimizer = jax_utils.replicate(optimizer)

  learning_rate_fn = create_learning_rate_scheduler(
      base_learning_rate=config.learning_rate, warmup_steps=config.warmup_steps)

  # compile multidevice versions of train/eval/predict step and cache init fn.
  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          config=train_config,
          learning_rate_fn=learning_rate_fn,
          label_smoothing=config.label_smoothing),
      axis_name="batch",
      donate_argnums=(0,))  # pytype: disable=wrong-arg-types
  p_eval_step = jax.pmap(
      functools.partial(
          eval_step, config=eval_config),
      axis_name="batch")
  p_init_cache = jax.pmap(
      functools.partial(
          initialize_cache,
          max_decode_len=config.max_predict_length,
          config=predict_config),
      axis_name="batch")
  p_pred_step = jax.pmap(
      functools.partial(
          predict_step, config=predict_config, beam_size=config.beam_size),
      axis_name="batch",
      static_broadcasted_argnums=(3, 4))  # eos token, max_length are constant

  # Main Train Loop
  # ---------------------------------------------------------------------------

  # We init the first set of dropout PRNG keys, but update it afterwards inside
  # the main pmap'd training update for performance.
  dropout_rngs = jax.random.split(rng, jax.local_device_count())
  del rng

  logging.info("Starting training loop.")
  # Progress reporting and profiling hooks are only registered on host 0;
  # on other hosts `hooks` stays empty.
  hooks = []
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=config.num_train_steps, writer=writer)
  if jax.host_id() == 0:
    hooks += [
        report_progress,
        periodic_actions.Profile(logdir=workdir, num_profile_steps=5)
    ]
  # Per-step metrics accumulated since the last periodic summary.
  train_metrics = []
  with metric_writers.ensure_flushes(writer):
    for step in range(start_step, config.num_train_steps):
      is_last_step = step == config.num_train_steps - 1

      # Shard data to devices and do a training step.
      with jax.profiler.StepTraceAnnotation("train", step_num=step):
        batch = common_utils.shard(jax.tree_map(np.asarray, next(train_iter)))
        optimizer, metrics = p_train_step(
            optimizer, batch, dropout_rng=dropout_rngs)
        train_metrics.append(metrics)

      # Quick indication that training is happening.
      logging.log_first_n(logging.INFO, "Finished training step %d.", 5, step)
      for h in hooks:
        h(step)

      # Periodic metric handling.
      if step % config.eval_every_steps == 0 or is_last_step:
        with report_progress.timed("training_metrics"):
          logging.info("Gathering training metrics.")
          train_metrics = common_utils.get_metrics(train_metrics)
          # The learning rate is averaged; every other metric is summed and
          # then divided by the summed "denominator" entry.
          lr = train_metrics.pop("learning_rate").mean()
          metrics_sums = jax.tree_map(jnp.sum, train_metrics)
          denominator = metrics_sums.pop("denominator")
          summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
          summary["learning_rate"] = lr
          summary = {"train_" + k: v for k, v in summary.items()}
          writer.write_scalars(step, summary)
          # Start a fresh accumulation window.
          train_metrics = []

        with report_progress.timed("eval"):
          eval_results = evaluate(
              p_eval_step=p_eval_step,
              target=optimizer.target,
              eval_ds=eval_ds,
              num_eval_steps=config.num_eval_steps)
          writer.write_scalars(
              step, {"eval_" + k: v for k, v in eval_results.items()})

        with report_progress.timed("translate_and_bleu"):
          exemplars, bleu_score = translate_and_calculate_bleu(
              p_pred_step=p_pred_step,
              p_init_cache=p_init_cache,
              target=optimizer.target,
              predict_ds=predict_ds,
              decode_tokens=decode_tokens,
              max_predict_length=config.max_predict_length)
          writer.write_scalars(step, {"bleu": bleu_score})
          writer.write_texts(step, {"samples": exemplars})

      # Save a checkpoint on one host every `config.checkpoint_every_steps`
      # steps and on the last step.
      save_checkpoint = (step % config.checkpoint_every_steps == 0 or
                         is_last_step)
      if config.save_checkpoints and save_checkpoint and jax.host_id() == 0:
        with report_progress.timed("checkpoint"):
          # The replicated optimizer is unreplicated before it is saved.
          checkpoints.save_checkpoint(workdir, jax_utils.unreplicate(optimizer),
                                      step)
Example #9
0
def train(base_dir, config):
    """Train the configured module and return the final train state.

    The module class and the train-step function are looked up by name among
    this module's globals, based on `config.train`.

    Args:
        base_dir: Base output directory. `base_dir / 'train'` is used for
            checkpoints and `str(base_dir)` as the profiler log directory.
        config: Experiment configuration. `config.train` must provide
            `method` (selects the global function `train_step_<method>`) and
            `module` (a mapping whose `name` selects a global class and whose
            remaining entries are passed to that class as keyword
            arguments). All other entries of `config.train` are bound as
            keyword arguments of the train-step function.

    Returns:
        The train state after the last executed training step.

    Raises:
        ValueError: If the configured module name or train method does not
            correspond to a name defined at module level.
    """
    print(config)
    chkpt_manager = checkpoint.Checkpoint(str(base_dir / 'train'))

    writer = create_default_writer()

    # Initialize dataset
    key = jax.random.PRNGKey(config.seed)
    key, subkey = jax.random.split(key)
    ds = dataset.get_dataset(config, subkey, num_tasks=config.num_tasks)
    ds_iter = iter(ds)

    key, subkey = jax.random.split(key)
    encoder = MLPEncoder(**config.encoder)

    # Work on a plain-dict copy so the `pop` calls below do not modify
    # `config` itself.
    train_config = config.train.to_dict()
    train_method = train_config.pop('method')

    module_config = train_config.pop('module')
    module_class = module_config.pop('name')

    # Resolve the module class by name. Fail with a descriptive error instead
    # of the opaque "'NoneType' object is not callable" a bare
    # `globals().get(...)()` would produce for a misspelled name.
    module_cls = globals().get(module_class)
    if module_cls is None:
        raise ValueError(
            f'Unknown module {module_class!r} in config.train.module.name.')
    module = module_cls(encoder, **module_config)

    # Resolve the train-step function the same way.
    train_step_name = f'train_step_{train_method}'
    train_step_fn = globals().get(train_step_name)
    if train_step_fn is None:
        raise ValueError(
            f'Unknown train method {train_method!r}: no function named '
            f'{train_step_name!r} is defined.')
    # Whatever is left in `train_config` becomes fixed keyword arguments.
    train_step = functools.partial(train_step_fn, **train_config)

    # Note: this consumes one batch from `ds_iter` to initialize parameters.
    params = module.init(subkey, next(ds_iter)[0])
    lr = optax.cosine_decay_schedule(config.learning_rate,
                                     config.num_train_steps)
    optim = optax.chain(optax.adam(lr),
                        # optax.adaptive_grad_clip(0.15)
                        )

    state = TrainState.create(apply_fn=module.apply, params=params, tx=optim)
    state = chkpt_manager.restore_or_initialize(state)

    # Hooks
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=config.num_train_steps, writer=writer)
    hooks = [
        report_progress,
        periodic_actions.Profile(num_profile_steps=5, logdir=str(base_dir))
    ]

    def handle_preemption(signal_number, _):
        """Save a checkpoint of the current `state` when a signal arrives.

        `state` is read from the enclosing scope when the handler runs, so
        the most recently assigned train state is saved. The handler returns
        normally afterwards; it does not stop the training loop itself.
        """
        logging.info('Received signal %d, saving checkpoint.', signal_number)
        with report_progress.timed('checkpointing'):
            chkpt_manager.save(state)
        logging.info('Finished saving checkpoint.')

    signal.signal(signal.SIGTERM, handle_preemption)

    metrics = TrainMetrics.empty()
    with metric_writers.ensure_flushes(writer):
        # Start from the (possibly restored) step counter of the state.
        for step in tqdm.tqdm(range(state.step, config.num_train_steps)):
            with jax.profiler.StepTraceAnnotation('train', step_num=step):
                states, targets = next(ds_iter)
                state, metrics = train_step(state, metrics, states, targets)

            logging.log_first_n(logging.INFO, 'Finished training step %d', 5,
                                step)

            # Write the accumulated metrics and start a new accumulation
            # window.
            if step % config.log_metrics_every == 0:
                writer.write_scalars(step, metrics.compute())
                metrics = TrainMetrics.empty()

            # if step % config.log_eval_metrics_every == 0 and isinstance(
            #     ds, dataset.MDPDataset):
            #   eval_metrics = evaluate_mdp(state, ds.aux_task_matrix, config)
            #   writer.write_scalars(step, eval_metrics.compute())

            for hook in hooks:
                hook(step)

    # Always persist the final state once the loop has finished.
    chkpt_manager.save(state)
    return state
Example #10
0
def main(config, output_dir):
    # Note: switch to ProfileAllHosts() if you need to profile all hosts.
    # (Xprof data become much larger and take longer to load for analysis)
    profiler = periodic_actions.Profile(
        # Create profile after every restart to analyze pre-emption related
        # problems and assure we get similar performance in every run.
        logdir=output_dir,
        first_profile=10)

    logging.info(config)

    acquisition_method = config.get('acquisition_method')

    # Create an asynchronous multi-metric writer.
    writer = metric_writers.create_default_writer(
        output_dir, just_logging=jax.process_index() > 0)
    writer.write_hparams(dict(config))

    # The pool is used to perform misc operations such as logging in async way.
    pool = multiprocessing.pool.ThreadPool()

    def write_note(note):
        if jax.process_index() == 0:
            logging.info('NOTE: %s', note)

    write_note(f'Initializing for {acquisition_method}')

    # Download dataset
    data_builder = tfds.builder(config.dataset)
    data_builder.download_and_prepare()

    seed = config.get('seed', 0)
    rng = jax.random.PRNGKey(seed)

    batch_size = config.batch_size
    batch_size_eval = config.get('batch_size_eval', batch_size)

    local_batch_size = batch_size // jax.process_count()
    local_batch_size_eval = batch_size_eval // jax.process_count()

    val_ds = input_utils.get_data(
        dataset=config.dataset,
        split=config.val_split,
        rng=None,
        process_batch_size=local_batch_size_eval,
        preprocess_fn=preprocess_spec.parse(
            spec=config.pp_eval, available_ops=preprocess_utils.all_ops()),
        shuffle=False,
        prefetch_size=config.get('prefetch_to_host', 2),
        num_epochs=1,  # Only repeat once.
    )

    test_ds = input_utils.get_data(
        dataset=config.dataset,
        split=config.test_split,
        rng=None,
        process_batch_size=local_batch_size_eval,
        preprocess_fn=preprocess_spec.parse(
            spec=config.pp_eval, available_ops=preprocess_utils.all_ops()),
        shuffle=False,
        prefetch_size=config.get('prefetch_to_host', 2),
        num_epochs=1,  # Only repeat once.
    )

    # Init model
    if config.model_type == 'deterministic':
        model_utils = deterministic_utils
        reinit_params = config.get('model_reinit_params',
                                   ('head/kernel', 'head/bias'))
        model = ub.models.vision_transformer(num_classes=config.num_classes,
                                             **config.get('model', {}))
    elif config.model_type == 'batchensemble':
        model_utils = batchensemble_utils
        reinit_params = ('batchensemble_head/bias',
                         'batchensemble_head/kernel',
                         'batchensemble_head/fast_weight_alpha',
                         'batchensemble_head/fast_weight_gamma')
        model = ub.models.PatchTransformerBE(num_classes=config.num_classes,
                                             **config.model)
    else:
        raise ValueError('Expect config.model_type to be "deterministic" or'
                         f'"batchensemble", but received {config.model_type}.')

    init = model_utils.create_init(model, config, test_ds)

    rng, rng_init = jax.random.split(rng)
    params_cpu = init(rng_init)

    if jax.process_index() == 0:
        num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0])
        parameter_overview.log_parameter_overview(params_cpu)
        writer.write_scalars(step=0, scalars={'num_params': num_params})

    # Load the optimizer from flax.
    opt_name = config.get('optim_name')
    opt_def = getattr(flax.optim, opt_name)(**config.get('optim', {}))

    # We jit this, such that the arrays that are created on the same
    # device as the input is, in this case the CPU. Else they'd be on device[0].
    opt_cpu = jax.jit(opt_def.create)(params_cpu)

    loaded_params = checkpoint_utils.load_checkpoint(tree=None,
                                                     path=config.model_init)
    loaded = checkpoint_utils.restore_from_pretrained_params(
        params_cpu,
        loaded_params,
        config.model.representation_size,
        config.model.classifier,
        reinit_params,
    )

    opt_cpu = opt_cpu.replace(target=loaded)

    # TODO(joost,andreas): This shouldn't be needed but opt_cpu is being
    # donated otherwise. Ensure opt_cpu is really on the cpu this way.
    opt_cpu = jax.device_get(opt_cpu)

    update_fn = model_utils.create_update_fn(model, config)
    evaluation_fn = model_utils.create_evaluation_fn(model, config)

    # NOTE: We need this because we need an Id field of type int.
    # TODO(andreas): Rename to IdSubsetDatasetBuilder?
    pool_subset_data_builder = al_utils.SubsetDatasetBuilder(data_builder,
                                                             subset_ids=None)

    rng, pool_ds_rng = jax.random.split(rng)

    # NOTE: below line is necessary on multi host setup
    # pool_ds_rng = jax.random.fold_in(pool_ds_rng, jax.process_index())

    pool_train_ds = input_utils.get_data(
        dataset=pool_subset_data_builder,
        split=config.train_split,
        rng=pool_ds_rng,
        process_batch_size=local_batch_size,
        preprocess_fn=preprocess_spec.parse(
            spec=config.pp_eval, available_ops=preprocess_utils.all_ops()),
        shuffle=False,
        drop_remainder=False,
        prefetch_size=config.get('prefetch_to_host', 2),
        num_epochs=1,  # Don't repeat
    )

    # Potentially acquire an initial training set.
    initial_training_set_size = config.get('initial_training_set_size', 10)

    if initial_training_set_size > 0:
        current_opt_repl = flax_utils.replicate(opt_cpu)
        pool_ids, _, _, pool_masks = get_ids_logits_masks(
            model=model,
            opt_repl=current_opt_repl,
            ds=pool_train_ds,
            config=config)

        rng, initial_uniform_rng = jax.random.split(rng)
        pool_scores = get_uniform_scores(pool_masks, initial_uniform_rng)

        initial_training_set_batch_ids, _ = select_acquisition_batch_indices(
            acquisition_batch_size=initial_training_set_size,
            scores=pool_scores,
            ids=pool_ids,
            ignored_ids=set(),
        )
    else:
        initial_training_set_batch_ids = []

    # NOTE: if we could `enumerate` before `filter` in `create_dataset` of CLU
    # then this dataset creation could be simplified.
    # https://github.com/google/CommonLoopUtils/blob/main/clu/deterministic_data.py#L340
    # CLU is explicitly not accepting outside contributions at the moment.
    train_subset_data_builder = al_utils.SubsetDatasetBuilder(
        data_builder, subset_ids=set(initial_training_set_batch_ids))

    test_accuracies = []
    training_sizes = []

    rng, rng_loop = jax.random.split(rng)
    rngs_loop = flax_utils.replicate(rng_loop)
    if config.model_type == 'batchensemble':
        rngs_loop = {'dropout': rngs_loop}

    # TODO(joost,andreas): double check if below is still necessary
    # (train_split is independent of this)
    # NOTE: train_ds_rng is re-used for all train_ds creations
    rng, train_ds_rng = jax.random.split(rng)

    measurements = {}
    accumulated_steps = 0
    while True:
        current_train_ds_length = len(train_subset_data_builder.subset_ids)
        if current_train_ds_length >= config.get('max_training_set_size', 150):
            break
        write_note(f'Training set size: {current_train_ds_length}')

        current_opt_repl = flax_utils.replicate(opt_cpu)

        # Only fine-tune if there is anything to fine-tune with.
        if current_train_ds_length > 0:
            # Repeat dataset to have oversampled epochs and bootstrap more batches
            number_of_batches = current_train_ds_length / config.batch_size
            num_repeats = math.ceil(config.total_steps / number_of_batches)
            write_note(f'Repeating dataset {num_repeats} times')

            # We repeat the dataset several times, such that we can obtain batches
            # of size batch_size, even at start of training. These batches will be
            # effectively 'bootstrap' sampled, meaning they are sampled with
            # replacement from the original training set.
            repeated_train_ds = input_utils.get_data(
                dataset=train_subset_data_builder,
                split=config.train_split,
                rng=train_ds_rng,
                process_batch_size=local_batch_size,
                preprocess_fn=preprocess_spec.parse(
                    spec=config.pp_train,
                    available_ops=preprocess_utils.all_ops()),
                shuffle_buffer_size=config.shuffle_buffer_size,
                prefetch_size=config.get('prefetch_to_host', 2),
                # TODO(joost,andreas): double check if below leads to bootstrap
                # sampling.
                num_epochs=num_repeats,
            )

            # We use this dataset to evaluate how well we perform on the training set.
            # We need this to evaluate if we fit well within max_steps budget.
            train_eval_ds = input_utils.get_data(
                dataset=train_subset_data_builder,
                split=config.train_split,
                rng=train_ds_rng,
                process_batch_size=local_batch_size,
                preprocess_fn=preprocess_spec.parse(
                    spec=config.pp_eval,
                    available_ops=preprocess_utils.all_ops()),
                shuffle=False,
                drop_remainder=False,
                prefetch_size=config.get('prefetch_to_host', 2),
                num_epochs=1,
            )

            # NOTE: warmup and decay are not a good fit for the small training set
            # lr_fn = train_utils.create_learning_rate_schedule(config.total_steps,
            #                                                   **config.get('lr', {})
            #                                                   )
            lr_fn = lambda x: config.lr.base

            early_stopping_patience = config.get('early_stopping_patience', 15)
            current_opt_repl, rngs_loop, measurements = finetune(
                update_fn=update_fn,
                opt_repl=current_opt_repl,
                lr_fn=lr_fn,
                ds=repeated_train_ds,
                rngs_loop=rngs_loop,
                total_steps=config.total_steps,
                train_eval_ds=train_eval_ds,
                val_ds=val_ds,
                evaluation_fn=evaluation_fn,
                early_stopping_patience=early_stopping_patience,
                profiler=profiler)
            train_val_accuracies = measurements.pop('train_val_accuracies')
            current_steps = 0
            for step, train_acc, val_acc in train_val_accuracies:
                writer.write_scalars(accumulated_steps + step, {
                    'train_accuracy': train_acc,
                    'val_accuracy': val_acc
                })
                current_steps = step
            accumulated_steps += current_steps + 10

        test_accuracy = get_accuracy(evaluation_fn=evaluation_fn,
                                     opt_repl=current_opt_repl,
                                     ds=test_ds)

        write_note(f'Accuracy at {current_train_ds_length}: {test_accuracy}')

        test_accuracies.append(test_accuracy)
        training_sizes.append(current_train_ds_length)

        pool_ids, pool_outputs, _, pool_masks = get_ids_logits_masks(
            model=model,
            opt_repl=current_opt_repl,
            ds=pool_train_ds,
            use_pre_logits=acquisition_method == 'density',
            config=config)

        if acquisition_method == 'uniform':
            rng_loop, rng_acq = jax.random.split(rng_loop, 2)
            pool_scores = get_uniform_scores(pool_masks, rng_acq)
        elif acquisition_method == 'entropy':
            pool_scores = get_entropy_scores(pool_outputs, pool_masks)
        elif acquisition_method == 'margin':
            pool_scores = get_margin_scores(pool_outputs, pool_masks)
        elif acquisition_method == 'density':
            if current_train_ds_length > 0:
                pool_scores = get_density_scores(model=model,
                                                 opt_repl=current_opt_repl,
                                                 train_ds=train_eval_ds,
                                                 pool_pre_logits=pool_outputs,
                                                 pool_masks=pool_masks,
                                                 config=config)
            else:
                rng_loop, rng_acq = jax.random.split(rng_loop, 2)
                pool_scores = get_uniform_scores(pool_masks, rng_acq)
        else:
            raise ValueError('Acquisition method not found.')

        acquisition_batch_ids, _ = select_acquisition_batch_indices(
            acquisition_batch_size=config.get('acquisition_batch_size', 10),
            scores=pool_scores,
            ids=pool_ids,
            ignored_ids=train_subset_data_builder.subset_ids)

        train_subset_data_builder.subset_ids.update(acquisition_batch_ids)

        measurements.update({'test_accuracy': test_accuracy})
        writer.write_scalars(current_train_ds_length, measurements)

    write_note(f'Final acquired training ids: '
               f'{train_subset_data_builder.subset_ids}'
               f'Accuracies: {test_accuracies}')

    pool.close()
    pool.join()
    writer.close()
    # TODO(joost,andreas): save the final checkpoint
    return (train_subset_data_builder.subset_ids, test_accuracies)
Example #11
0
def evaluate(base_dir, config, *, train_state):
    """Run the evaluation loop for a `LinearModule` head and log a test loss.

    A `LinearModule` with `config.eval.num_tasks` outputs is initialised and
    wrapped in a `TrainState` using Adam with a cosine-decay schedule. That
    state is passed through `Checkpoint.restore_or_initialize` and the loop
    then calls `evaluate_step` once per batch, from the resulting `state.step`
    up to `config.num_eval_steps`. Accumulated metrics are written (and reset)
    whenever `step % config.log_metrics_every == 0`.

    After the loop, the head is applied to the `phi` outputs of
    `train_state.apply_fn` on a batch from `dataset.EvalDataset`, and the mean
    L2 loss against the targets is written as `test_loss`.

    A SIGTERM handler is installed that saves the current head state.

    Args:
        base_dir: Path-like object; `base_dir / 'eval'` is the checkpoint
            directory and `str(base_dir)` the profiler log directory.
        config: Configuration object; reads `eval.seed`, `eval.num_tasks`,
            `eval.learning_rate`, `encoder.embedding_dim`, `num_eval_steps`
            and `log_metrics_every`.
        train_state: Keyword-only. Object exposing `apply_fn` and `params`;
            it is also forwarded unchanged to `evaluate_step`.
    """
    ckpt = checkpoint.Checkpoint(str(base_dir / 'eval'))
    metric_writer = create_default_writer()

    # One key initialises the head, the other drives both datasets.
    init_rng, data_rng = jax.random.split(
        jax.random.PRNGKey(config.eval.seed))

    head = LinearModule(config.eval.num_tasks)
    head_params = head.init(
        init_rng, jnp.zeros((config.encoder.embedding_dim, )))
    schedule = optax.cosine_decay_schedule(config.eval.learning_rate,
                                           config.num_eval_steps)

    batches = iter(
        dataset.get_dataset(config,
                            data_rng,
                            num_tasks=config.eval.num_tasks))

    head_state = ckpt.restore_or_initialize(
        TrainState.create(apply_fn=head.apply,
                          params=head_params,
                          tx=optax.adam(schedule)))

    progress = periodic_actions.ReportProgress(
        num_train_steps=config.num_eval_steps, writer=metric_writer)
    profile = periodic_actions.Profile(num_profile_steps=5,
                                       logdir=str(base_dir))
    per_step_hooks = (progress, profile)

    def _checkpoint_on_signal(signal_number, _):
        # Reads `head_state` through the closure, so whatever state the loop
        # has most recently bound is the one that gets saved.
        logging.info('Received signal %d, saving checkpoint.', signal_number)
        with progress.timed('checkpointing'):
            ckpt.save(head_state)
        logging.info('Finished saving checkpoint.')

    signal.signal(signal.SIGTERM, _checkpoint_on_signal)

    running = EvalMetrics.empty()
    with metric_writers.ensure_flushes(metric_writer):
        step_range = range(head_state.step, config.num_eval_steps)
        for step in tqdm.tqdm(step_range):
            with jax.profiler.StepTraceAnnotation('eval', step_num=step):
                inputs, labels = next(batches)
                head_state, running = evaluate_step(train_state, head_state,
                                                    running, inputs, labels)

            if step % config.log_metrics_every == 0:
                metric_writer.write_scalars(step, running.compute())
                running = EvalMetrics.empty()

            for hook in per_step_hooks:
                hook(step)

        # Finally, evaluate on the true(ish) test aux task matrix.
        test_states, test_targets = dataset.EvalDataset(
            config, data_rng).get_batch()

        def loss_fn():
            phis = train_state.apply_fn(train_state.params, test_states).phi
            # Map the head over the leading axis of `phis`, sharing params.
            apply_head = jax.vmap(head_state.apply_fn, in_axes=(None, 0))
            predictions = apply_head(head_state.params, phis)
            return jnp.mean(optax.l2_loss(predictions, test_targets))

        metric_writer.write_scalars(config.num_eval_steps + 1,
                                    {'test_loss': jax.jit(loss_fn)()})
Example #12
0
def train(config: ml_collections.ConfigDict):
  """Run training.

  Builds the registered task's model, an AdamW optimizer and the train/eval
  data pipelines from `config`, restores the latest checkpoint found in
  `config.model_dir`, and then runs the pmapped training loop from the
  restored step up to `config.num_train_steps`.

  Every `config.eval_every_steps` steps (and on the last step) the accumulated
  training metrics are written and the model is evaluated; if
  `save_samples_every_steps` is configured, post-processed eval samples are
  dumped to JSON as well. Checkpoints and encoder weights are written from
  process 0 according to `checkpoint_every_steps` / `save_every_steps`.

  Args:
    config: Experiment configuration. Read both via attribute access (e.g.
      `task_name`, `seed`, `model_config`, `model_dir`, `num_train_steps`)
      and via `config.get(...)` for optional entries (e.g. `warmup_steps`,
      `ignore_k_nans`, `pad_eval`, `save_samples_every_steps`,
      `save_first_batch_only`, `keep_checkpoints`).

  Raises:
    ValueError: If `save_samples_every_steps` is set but is not a multiple of
      `eval_every_steps`.
  """
  import zlib  # pylint: disable=g-import-not-at-top

  # Establish host information
  local_device_count = jax.local_device_count()
  host_count = jax.process_count()
  host_id = jax.process_index()

  task = task_registry.get_registered_task(config.task_name)

  rng = jax.random.PRNGKey(config.seed)

  model_config = ml_collections.FrozenConfigDict(config.model_config)
  model = task.build_model(model_config)

  # Initialization needs to be pmapped because models use collective ops.
  # Create dummy input
  dummy_input = {
      key: jnp.tile(value, (local_device_count,) + (1,) * value.ndim)
      for key, value in task.dummy_input(config).items()
  }

  rng, init_rng = jax.random.split(rng)
  init_rng = jax.random.split(init_rng, local_device_count)

  logging.info('Initializing model.')
  initial_variables = jax.pmap(
      model.init, 'batch', static_broadcasted_argnums=2)(init_rng, dummy_input,
                                                         True)
  logging.info('Finished initializing model.')
  initial_variables = initial_variables.unfreeze()

  if config.load_weights is not None:
    logging.info('Loading model weights from file')
    loaded_variables = task.load_weights(config)
    unexpected, missing = checkpoint_utils.merge_nested_dicts(
        initial_variables, loaded_variables)
    logging.info('*** Unexpected features: ***')
    for feature_name in unexpected:
      logging.info('\t%s', feature_name)
    logging.info('*** Missing features: ***')
    for feature_name in missing:
      logging.info('\t%s', feature_name)

  # Everything except 'params' is passed to the train/eval steps separately
  # from the train state.
  model_vars = {
      key: value for key, value in initial_variables.items() if key != 'params'
  }

  learning_rate_fn = optim_utils.create_learning_rate_scheduler(
      learning_rate=config.learning_rate,
      warmup=config.warmup,
      warmup_steps=config.get('warmup_steps', None),
      linear_decay=config.linear_decay,
      max_steps=config.num_train_steps,
      decay_minimum_factor=config.get('decay_minimum_factor', None),
  )

  if config.weight_decay_exclude is not None:
    decay_mask = optim_utils.create_dict_mask(initial_variables['params'],
                                              config.weight_decay_exclude)
  else:
    decay_mask = None
  tx = optax.adamw(
      learning_rate=learning_rate_fn,
      weight_decay=config.weight_decay,
      b1=0.9,
      b2=0.999,
      eps=1e-6,
      mask=decay_mask)
  if config.grad_clip is not None:
    tx = optax.chain(tx, optax.clip_by_global_norm(config.grad_clip))

  ignore_k_nans = config.get('ignore_k_nans')
  if ignore_k_nans is not None:
    tx = optax.apply_if_finite(tx, max_consecutive_errors=ignore_k_nans)

  loss_fn = task.make_loss_fn(config)
  train_state = ts.TrainState.create(
      apply_fn=loss_fn,
      params=jax_utils.unreplicate(initial_variables['params']),
      tx=tx,
  )

  # We access model params only from train state.
  del initial_variables

  # Restore unreplicated train state from last checkpoint
  train_state = checkpoints.restore_checkpoint(config.model_dir, train_state)
  # Grab last step.
  start_step = int(train_state.step)

  writer = metric_writers.create_default_writer(
      config.model_dir, just_logging=jax.process_index() > 0)
  if start_step == 0:
    writer.write_hparams(config.to_dict())

  dropout_rngs = jax.random.split(rng, local_device_count)

  del rng

  # Load datasets
  logging.info('Loading dataset.')

  # Make sure we don't re-use same data if we load weights or checkpoint.
  # Only the training pipeline gets the shifted seed; the eval pipeline below
  # keeps `config.seed`.
  train_seed = config.seed + start_step
  if config.load_weights:
    # The built-in `hash()` of a string is randomized per Python process, so
    # use a stable checksum to obtain the same seed on every host and on every
    # restart. The modulo keeps the result in the non-negative int32 range.
    weights_digest = zlib.crc32(str(config.load_weights).encode('utf-8'))
    train_seed = (train_seed + weights_digest) % 2**31

  name_to_features = task.get_name_to_features(config)
  preprocess_fn = task.make_preprocess_fn(config)
  collater_fn = task.make_collater_fn(config)

  train_data = data_utils.load_multi_dataset(
      datasets_config=config.train_data,
      name_to_features=name_to_features,
      preprocess_fn=preprocess_fn,
      collater_fn=collater_fn,
      is_training=True,
      per_device_batch_size=config.per_device_batch_size,
      local_device_count=local_device_count,
      host_count=host_count,
      host_id=host_id,
      seed=train_seed,
  )
  train_iter = iter(train_data)

  pad_eval = config.get('pad_eval', False)
  if pad_eval:
    logging.info('Eval data is padded such that none of samples are dropped.')
  else:
    logging.warning(
        'Eval data is NOT padded -- some samples might be dropped.')

  eval_data = data_utils.load_multi_dataset(
      datasets_config=config.eval_data,
      name_to_features=name_to_features,
      preprocess_fn=preprocess_fn,
      collater_fn=collater_fn,
      is_training=False,
      per_device_batch_size=config.per_device_batch_size,
      local_device_count=local_device_count,
      host_count=host_count,
      host_id=host_id,
      seed=config.seed,
      pad_eval=pad_eval,
  )
  # Materialize once so the same eval batches are reused at every evaluation.
  eval_data = list(eval_data)
  logging.info('Loaded %d samples for evaluation.', len(eval_data))

  # Setup postprocessing_fn for saving samples occasionally.
  if config.get('save_samples_every_steps') is not None:
    if config.get('save_samples_every_steps') % config.eval_every_steps != 0:
      raise ValueError(
          '`eval_every_steps` must divide `save_samples_every_steps`.')
    postprocessing_fn = task.make_output_postprocess_fn(config)

  # Training loop
  logging.info('Starting training.')

  # Replicate train state.
  train_state = jax_utils.replicate(train_state)

  # compile multidevice versions of train/eval/predict step
  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          model_config=model_config,
      ),
      axis_name='batch',
      donate_argnums=(0,),
  )  # pytype: disable=wrong-arg-types
  p_eval_step = jax.pmap(
      functools.partial(
          eval_step,
          model_config=model_config,
      ),
      axis_name='batch')

  hooks = []
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=config.num_train_steps, writer=writer)

  # Progress reporting and profiling hooks only run on the first process;
  # `report_progress.timed(...)` below is still used on every process.
  if jax.process_index() == 0:
    hooks += [
        report_progress,
        periodic_actions.Profile(logdir=config.model_dir, num_profile_steps=5)
    ]
  train_metrics = []
  with metric_writers.ensure_flushes(writer):
    for step in range(start_step, config.num_train_steps):
      is_last_step = step == config.num_train_steps - 1

      # Shard data to devices and perform a training step.
      with jax.profiler.StepTraceAnnotation('train', step_num=step):
        batch = jax.tree_util.tree_map(jnp.asarray, train_iter.get_next())
        train_state, metrics = p_train_step(
            train_state,
            model_vars,
            batch,
            dropout_rngs,
        )
        train_metrics.append(metrics)

      # Quick indication that training is happening.
      logging.log_first_n(logging.INFO, 'Finished training step %d', 5, step)
      for h in hooks:
        h(step)

      # Periodic metric handling.
      if step % config.eval_every_steps == 0 or is_last_step:
        with report_progress.timed('training_metrics'):
          logging.info('Gathering training metrics.')
          train_metrics = common_utils.get_metrics(train_metrics)
          metrics_sums = jax.tree_util.tree_map(jnp.sum, train_metrics)
          summary = metric_utils.process_metrics(metrics_sums, prefix='train')
          summary['learning_rate'] = learning_rate_fn(step)

          writer.write_scalars(step, summary)
          train_metrics = []

          with report_progress.timed('eval'):
            eval_results, eval_auxiliary = evaluate(
                eval_step_fn=p_eval_step,
                train_state=train_state,
                model_vars=model_vars,
                eval_data=eval_data,
            )
            writer.write_scalars(step, eval_results)

            # NOTE(review): samples are written on every evaluation, not only
            # every `save_samples_every_steps` steps -- confirm this is
            # intended given the divisibility check above.
            if config.get('save_samples_every_steps') is not None:
              with report_progress.timed('save_samples'):
                # Honour `save_first_batch_only`: post-process either just the
                # first eval batch or all of them.
                if config.get('save_first_batch_only', True):
                  postprocessing_input = [eval_auxiliary[0]]
                else:
                  postprocessing_input = eval_auxiliary
                eval_processed = [
                    postprocessing_fn(batch, auxiliary_output)
                    for batch, auxiliary_output in postprocessing_input
                ]
                data_utils.save_samples_to_json(eval_processed, config, step)

      # Save a checkpoint on one host after every checkpoint_freq steps.
      save_checkpoint = (
          step % config.checkpoint_every_steps == 0 or is_last_step)
      if (config.save_checkpoints and save_checkpoint and
          jax.process_index() == 0):
        with report_progress.timed('checkpoint'):
          logging.info('Saving checkpoint at step %s', step)
          checkpoints.save_checkpoint(
              config.model_dir,
              jax_utils.unreplicate(train_state),
              step,
              keep=config.get('keep_checkpoints', 1),
              keep_every_n_steps=config.get('keep_checkpoint_every_steps'),
          )

      # Weight export is skipped at step 0 and when `save_every_steps` is
      # falsy.
      save_model = (
          config.save_every_steps and
          (step % config.save_every_steps == 0 or is_last_step) and step != 0)
      if (save_model and jax.process_index() == 0):
        with report_progress.timed('checkpoint'):
          logging.info('Saving weights at step %s', step)
          save_path = os.path.join(config.model_dir, 'weights',
                                   'step' + str(step))
          # By default, save only encoder weights
          weights = jax_utils.unreplicate(train_state).params['encoder']
          checkpoint_utils.save_weights(save_path, weights)
def main(_):
    config = FLAGS.config

    # Unpack total and warmup steps
    total_steps = config.total_and_warmup_steps[0]
    warmup_steps = config.total_and_warmup_steps[1]
    del config.total_and_warmup_steps
    config.total_steps = total_steps
    config.lr.warmup_steps = warmup_steps

    # Wandb and Checkpointing Setup
    output_dir = FLAGS.output_dir
    wandb_run, output_dir = vit_utils.maybe_setup_wandb(config)
    tf.io.gfile.makedirs(output_dir)
    logging.info('Saving checkpoints at %s', output_dir)

    # Dataset Split Flags
    dist_shift = config.distribution_shift
    print(f'Distribution Shift: {dist_shift}.')
    dataset_names, split_names = vit_utils.get_dataset_and_split_names(
        dist_shift)

    # LR / Optimization Flags
    print('wandb hyperparameters:')
    print({
        'batch_size': config.batch_size,
        'grad_clip_norm': config.grad_clip_norm,
        'weight_decay': config.weight_decay,
        'total_steps': config.total_steps,
        'lr': config.lr,
        'fast_weight_lr_multiplier': config.fast_weight_lr_multiplier
    })

    # Reweighting loss for class imbalance
    # class_reweight_mode = config.class_reweight_mode
    # if class_reweight_mode == 'constant':
    #   class_weights = utils.get_diabetic_retinopathy_class_balance_weights()
    # else:
    #   class_weights = None

    # Shows the number of available devices.
    # In a CPU/GPU runtime this will be a single device.
    # In a TPU runtime this will be 8 cores.
    print('Number of Jax local devices:', jax.local_devices())

    # TODO(nband): fix sigmoid loss issues.
    assert config.get('loss', None) == 'softmax_xent'

    seed = config.get('seed', 0)
    rng = jax.random.PRNGKey(seed)
    tf.random.set_seed(seed)

    if config.get('data_dir'):
        logging.info('data_dir=%s', config.data_dir)
    logging.info('Output dir: %s', output_dir)
    tf.io.gfile.makedirs(output_dir)

    save_checkpoint_path = None
    if config.get('checkpoint_steps'):
        save_checkpoint_path = os.path.join(output_dir, 'checkpoint.npz')

    # Create an asynchronous multi-metric writer.
    writer = metric_writers.create_default_writer(
        output_dir, just_logging=jax.process_index() > 0)

    # The pool is used to perform misc operations such as logging in async way.
    pool = multiprocessing.pool.ThreadPool()

    def write_note(note):
        if jax.process_index() == 0:
            logging.info('NOTE: %s', note)

    write_note('Initializing...')

    # Verify settings to make sure no checkpoints are accidentally missed.
    if config.get('keep_checkpoint_steps'):
        assert config.get('checkpoint_steps'), 'Specify `checkpoint_steps`.'
        assert config.keep_checkpoint_steps % config.checkpoint_steps == 0, (
            f'`keep_checkpoint_steps` ({config.checkpoint_steps}) should be'
            f'divisible by `checkpoint_steps ({config.checkpoint_steps}).`')

    batch_size = config.batch_size
    batch_size_eval = config.get('batch_size_eval', batch_size)
    if (batch_size % jax.device_count() != 0
            or batch_size_eval % jax.device_count() != 0):
        raise ValueError(
            f'Batch sizes ({batch_size} and {batch_size_eval}) must '
            f'be divisible by device number ({jax.device_count()})')

    local_batch_size = batch_size // jax.process_count()
    local_batch_size_eval = batch_size_eval // jax.process_count()
    logging.info(
        'Global batch size %d on %d hosts results in %d local batch size. '
        'With %d devices per host (%d devices total), that\'s a %d per-device '
        'batch size.', batch_size, jax.process_count(), local_batch_size,
        jax.local_device_count(), jax.device_count(),
        local_batch_size // jax.local_device_count())

    write_note('Initializing preprocessing function...')
    # Same preprocessing function for training and evaluation
    preproc_fn = preprocess_spec.parse(
        spec=config.pp_train, available_ops=preprocess_utils.all_ops())

    write_note('Initializing train dataset...')
    rng, train_ds_rng = jax.random.split(rng)
    train_ds_rng = jax.random.fold_in(train_ds_rng, jax.process_index())
    train_base_dataset = ub.datasets.get(dataset_names['in_domain_dataset'],
                                         split=split_names['train_split'],
                                         data_dir=config.get('data_dir'))
    train_dataset_builder = train_base_dataset._dataset_builder  # pylint:disable=protected-access
    train_ds = input_utils.get_data(
        dataset=train_dataset_builder,
        split=split_names['train_split'],
        rng=train_ds_rng,
        process_batch_size=local_batch_size,
        preprocess_fn=preproc_fn,
        shuffle_buffer_size=config.shuffle_buffer_size,
        prefetch_size=config.get('prefetch_to_host', 2),
        data_dir=config.get('data_dir'))

    # Start prefetching already.
    train_iter = input_utils.start_input_pipeline(
        train_ds, config.get('prefetch_to_device', 1))

    write_note('Initializing val dataset(s)...')

    # Load in-domain and OOD validation and/or test datasets.
    # Please specify the desired shift (Country Shift or Severity Shift)
    # in the config.
    eval_iter_splits = vit_utils.init_evaluation_datasets(
        use_validation=config.use_validation,
        use_test=config.use_test,
        dataset_names=dataset_names,
        split_names=split_names,
        config=config,
        preproc_fn=preproc_fn,
        batch_size_eval=batch_size_eval,
        local_batch_size_eval=local_batch_size_eval)

    ntrain_img = input_utils.get_num_examples(
        train_dataset_builder,
        split=split_names['train_split'],
        process_batch_size=local_batch_size,
        data_dir=config.get('data_dir'))
    steps_per_epoch = ntrain_img // batch_size

    if config.get('num_epochs'):
        total_steps = int(config.num_epochs * steps_per_epoch)
        assert not config.get(
            'total_steps'), 'Set either num_epochs or total_steps'
    else:
        total_steps = config.total_steps

    logging.info('Total train data points: %d', ntrain_img)
    logging.info(
        'Running for %d steps, that means %f epochs and %d steps per epoch',
        total_steps, total_steps * batch_size / ntrain_img, steps_per_epoch)

    write_note('Initializing model...')
    model_dict = vit_utils.initialize_model('batchensemble', config)
    model = model_dict['model']
    ens_size = model_dict['ens_size']

    # We want all parameters to be created in host RAM, not on any device, they'll
    # be sent there later as needed, otherwise we already encountered two
    # situations where we allocate them twice.
    @functools.partial(jax.jit, backend='cpu')
    def init(rng):
        image_size = tuple(train_ds.element_spec['image'].shape[2:])
        logging.info('image_size = %s', image_size)
        dummy_input = jnp.zeros((local_batch_size, ) + image_size, jnp.float32)
        params = flax.core.unfreeze(model.init(rng, dummy_input,
                                               train=False))['params']

        # Set bias in the head to a low value, such that loss is small initially.
        params['batchensemble_head']['bias'] = jnp.full_like(
            params['batchensemble_head']['bias'],
            config.get('init_head_bias', 0))

        # init head kernel to all zeros for fine-tuning
        if config.get('model_init'):
            params['batchensemble_head']['kernel'] = jnp.full_like(
                params['batchensemble_head']['kernel'], 0)

        return params

    rng, rng_init = jax.random.split(rng)
    params_cpu = init(rng_init)

    if jax.process_index() == 0:
        num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0])
        parameter_overview.log_parameter_overview(params_cpu)
        writer.write_scalars(step=0, scalars={'num_params': num_params})

    @functools.partial(jax.pmap, axis_name='batch')
    def evaluation_fn(params, images, labels):
        tiled_logits, out = model.apply({'params': flax.core.freeze(params)},
                                        images,
                                        train=False)

        loss_name = config.get('loss', 'sigmoid_xent')
        # TODO(dusenberrymw,zmariet): Clean up and generalize this.
        if loss_name == 'sigmoid_xent':
            ens_logits = batchensemble_utils.log_average_sigmoid_probs(
                jnp.asarray(jnp.split(tiled_logits, ens_size)))
            pre_logits = batchensemble_utils.log_average_sigmoid_probs(
                jnp.asarray(jnp.split(out['pre_logits'], ens_size)))
        else:  # softmax
            ens_logits = batchensemble_utils.log_average_softmax_probs(
                jnp.asarray(jnp.split(tiled_logits, ens_size)))
            pre_logits = batchensemble_utils.log_average_softmax_probs(
                jnp.asarray(jnp.split(out['pre_logits'], ens_size)))

        losses = getattr(train_utils,
                         loss_name)(logits=ens_logits,
                                    labels=labels[:, :config.num_classes],
                                    reduction=False)
        loss = jax.lax.psum(losses, axis_name='batch')

        top1_idx = jnp.argmax(ens_logits, axis=1)
        top1_correct = jnp.take_along_axis(labels, top1_idx[:, None],
                                           axis=1)[:, 0]
        ncorrect = jax.lax.psum(top1_correct, axis_name='batch')
        n = batch_size_eval

        metric_args = jax.lax.all_gather([ens_logits, labels, pre_logits],
                                         axis_name='batch')
        return ncorrect, loss, n, metric_args

    # Load the optimizer from flax.
    opt_name = config.get('optim_name')
    write_note(f'Initializing {opt_name} optimizer...')
    opt_def = getattr(flax.optim, opt_name)(**config.get('optim', {}))

    # We jit this, such that the arrays that are created are created on the same
    # device as the input is, in this case the CPU. Else they'd be on device[0].
    opt_cpu = jax.jit(opt_def.create)(params_cpu)

    weight_decay_rules = config.get('weight_decay', []) or []
    rescale_value = config.lr.base if config.get(
        'weight_decay_decouple') else 1.
    weight_decay_fn = train_utils.get_weight_decay_fn(
        weight_decay_rules=weight_decay_rules, rescale_value=rescale_value)

    def batch_loss_fn(params, images, labels, rngs):
        logits, _ = model.apply({'params': flax.core.freeze(params)},
                                images,
                                train=True,
                                rngs=rngs)
        labels = jnp.tile(labels, (ens_size, 1))
        loss_fn = getattr(train_utils, config.get('loss', 'sigmoid_xent'))
        loss = jnp.mean(loss_fn(logits=logits, labels=labels))
        return loss, dict()

    @functools.partial(jax.pmap, axis_name='batch', donate_argnums=(0, 1))
    def update_fn(opt, rngs, lr, images, labels):
        return batchensemble_utils.update_fn_be(
            opt=opt,
            rngs=rngs,
            lr=lr,
            images=images,
            labels=labels,
            batch_loss_fn=batch_loss_fn,
            weight_decay_fn=weight_decay_fn,
            max_grad_norm_global=config.get('grad_clip_norm', None),
            fast_weight_lr_multiplier=config.get('fast_weight_lr_multiplier',
                                                 None))

    # Set config checkpoint resume path, if provided in args.
    if config.resume_checkpoint_path is not None:
        config.resume = config.resume_checkpoint_path

    reint_params = ('batchensemble_head/bias', 'batchensemble_head/kernel',
                    'batchensemble_head/fast_weight_alpha',
                    'batchensemble_head/fast_weight_gamma')
    if config.get('only_eval', False) or not config.get('reint_head', True):
        reint_params = []
    checkpoint_data = checkpoint_utils.maybe_load_checkpoint(
        train_loop_rngs=rng,
        save_checkpoint_path=save_checkpoint_path,
        init_optimizer=opt_cpu,
        init_params=params_cpu,
        init_fixed_model_states=None,
        default_reinit_params=reint_params,
        config=config)
    train_loop_rngs = {'dropout': checkpoint_data.train_loop_rngs}
    opt_cpu = checkpoint_data.optimizer
    accumulated_train_time = checkpoint_data.accumulated_train_time

    write_note('Adapting the checkpoint model...')
    adapted_params = checkpoint_utils.adapt_upstream_architecture(
        init_params=params_cpu, loaded_params=opt_cpu.target)
    opt_cpu = opt_cpu.replace(target=adapted_params)

    write_note('Kicking off misc stuff...')
    first_step = int(opt_cpu.state.step)  # Might be a DeviceArray type.
    if first_step == 0 and jax.process_index() == 0:
        writer.write_hparams(dict(config))
    chrono = train_utils.Chrono(first_step, total_steps, batch_size,
                                accumulated_train_time)
    # Note: switch to ProfileAllHosts() if you need to profile all hosts.
    # (Xprof data become much larger and take longer to load for analysis)
    profiler = periodic_actions.Profile(
        # Create profile after every restart to analyze pre-emption related
        # problems and assure we get similar performance in every run.
        logdir=output_dir,
        first_profile=first_step + 10)

    # Prepare the learning-rate and pre-fetch it to device to avoid delays.
    lr_fn = train_utils.create_learning_rate_schedule(total_steps,
                                                      **config.get('lr', {}))

    # TODO(dusenberrymw): According to flax docs, prefetching shouldn't be
    # necessary for TPUs.
    lr_iter = train_utils.prefetch_scalar(map(lr_fn, range(total_steps)),
                                          config.get('prefetch_to_device', 1))

    write_note(f'Replicating...\n{chrono.note}')
    opt_repl = flax.jax_utils.replicate(opt_cpu)

    checkpoint_writer = None

    # Note: we return the train loss, val loss, and fewshot best l2s for use in
    # reproducibility unit tests.
    # train_loss = -jnp.inf
    # eval_loss = {
    #     eval_name: -jnp.inf for eval_name, _ in eval_iter_splits.items()}
    # fewshot_results = {'dummy': {(0, 1): -jnp.inf}}

    write_note(f'First step compilations...\n{chrono.note}')
    logging.info('first_step = %s', first_step)
    # Advance the iterators if we are restarting from an earlier checkpoint.
    # TODO(dusenberrymw): Look into checkpointing dataset state instead.
    if first_step > 0:
        write_note('Advancing iterators after resuming from a checkpoint...')
        lr_iter = itertools.islice(lr_iter, first_step, None)
        # TODO(zmariet): Find better way to cut down iteration advancement cost.
        if not config.get('disable_preemption_reproducibility', False):
            train_iter = itertools.islice(train_iter, first_step, None)

    # Using a python integer for step here, because opt.state.step is allocated
    # on TPU during replication.
    for step, train_batch, lr_repl in zip(
            range(first_step + 1, total_steps + 1), train_iter, lr_iter):

        with jax.profiler.TraceAnnotation('train_step', step_num=step, _r=1):
            if not config.get('only_eval', False):
                opt_repl, train_loop_rngs, extra_measurements = update_fn(
                    opt_repl, train_loop_rngs, lr_repl, train_batch['image'],
                    train_batch['labels'])

        if jax.process_index() == 0:
            profiler(step)

        # Checkpoint saving
        if not config.get('only_eval', False) and train_utils.itstime(
                step, config.get('checkpoint_steps'), total_steps, process=0):
            write_note('Checkpointing...')
            chrono.pause()
            train_utils.checkpointing_timeout(
                checkpoint_writer, config.get('checkpoint_timeout', 1))
            accumulated_train_time = chrono.accum_train_time
            # We need to transfer the weights over now or else we risk keeping them
            # alive while they'll be updated in a future step, creating hard to debug
            # memory errors (see b/160593526). Also, takes device 0's params only.
            opt_cpu = jax.tree_util.tree_map(lambda x: np.array(x[0]),
                                             opt_repl)

            # Check whether we want to keep a copy of the current checkpoint.
            copy_step = None
            if train_utils.itstime(step, config.get('keep_checkpoint_steps'),
                                   total_steps):
                write_note('Keeping a checkpoint copy...')
                copy_step = step

            # Checkpoint should be a nested dictionary or FLAX datataclasses from
            # `flax.struct`. Both can be present in a checkpoint.
            checkpoint_data = checkpoint_utils.CheckpointData(
                train_loop_rngs=train_loop_rngs,
                optimizer=opt_cpu,
                accumulated_train_time=accumulated_train_time)

            checkpoint_writer = pool.apply_async(
                checkpoint_utils.checkpoint_trained_model,
                (checkpoint_data, save_checkpoint_path, copy_step))
            chrono.resume()

        # Report training progress
        if not config.get('only_eval', False) and train_utils.itstime(
                step, config.log_training_steps, total_steps, process=0):
            write_note('Reporting training progress...')
            timing_measurements, note = chrono.tick(step)
            write_note(note)
            train_measurements = {}
            train_measurements.update(
                flax.jax_utils.unreplicate(extra_measurements))
            train_measurements.update(timing_measurements)
            writer.write_scalars(step, train_measurements)
            # Keep to return for reproducibility tests.
            # train_loss = train_measurements['training_loss']

        # Report validation performance
        if config.get('only_eval', False) or train_utils.itstime(
                step, config.log_eval_steps, total_steps):
            write_note('Evaluating on the validation sets...')
            chrono.pause()

            all_eval_results = {}

            for eval_name, (eval_iter, eval_steps) in eval_iter_splits.items():
                start_time = time.time()

                # Runs evaluation loop.
                results_arrs = {
                    'y_true': [],
                    'y_pred': [],
                    'y_pred_entropy': []
                }

                for _, batch in zip(range(eval_steps), eval_iter):
                    batch_ncorrect, batch_losses, batch_n, batch_metric_args = (  # pylint: disable=unused-variable
                        evaluation_fn(opt_repl.target, batch['image'],
                                      batch['labels']))

                    # All results are a replicated array shaped as follows:
                    # (local_devices, per_device_batch_size, elem_shape...)
                    # with each local device's entry being identical as they got psum'd.
                    # So let's just take the first one to the host as numpy.

                    # Here we parse batch_metric_args to compute uncertainty metrics.
                    logits, labels, _ = batch_metric_args
                    logits = np.array(logits[0])
                    probs = jax.nn.softmax(logits)

                    # From one-hot to integer labels.
                    int_labels = np.argmax(np.array(labels[0]), axis=-1)

                    probs = np.reshape(probs,
                                       (probs.shape[0] * probs.shape[1], -1))
                    int_labels = int_labels.flatten()
                    y_pred = probs[:, 1]
                    results_arrs['y_true'].append(int_labels)
                    results_arrs['y_pred'].append(y_pred)

                    # Entropy is computed at the per-epoch level (see below).
                    results_arrs['y_pred_entropy'].append(probs)

                results_arrs['y_true'] = np.concatenate(results_arrs['y_true'],
                                                        axis=0)
                results_arrs['y_pred'] = np.concatenate(
                    results_arrs['y_pred'], axis=0).astype('float64')
                results_arrs['y_pred_entropy'] = vit_utils.entropy(
                    np.concatenate(results_arrs['y_pred_entropy'], axis=0),
                    axis=-1)

                time_elapsed = time.time() - start_time
                results_arrs['total_ms_elapsed'] = time_elapsed * 1e3
                results_arrs['dataset_size'] = eval_steps * batch_size_eval

                all_eval_results[eval_name] = results_arrs

            per_pred_results, metrics_results = vit_utils.evaluate_vit_predictions(  # pylint: disable=unused-variable
                dataset_split_to_containers=all_eval_results,
                is_deterministic=True,
                num_bins=15,
                return_per_pred_results=True)

            # `metrics_results` is a dict of {str: jnp.ndarray} dicts, one for each
            # dataset. Flatten this dict so we can pass to the writer and remove empty
            # entries.
            flattened_metric_results = {}
            for dic in metrics_results.values():
                for key, value in dic.items():
                    if value is not None:
                        flattened_metric_results[key] = value
            writer.write_scalars(step, flattened_metric_results)

            # Optionally log to wandb
            if config.use_wandb:
                wandb.log(metrics_results, step=step)

            # Save per-prediction metrics
            results_storage_utils.save_per_prediction_results(output_dir,
                                                              step,
                                                              per_pred_results,
                                                              verbose=False)
            chrono.resume()

        # End of step.
        if config.get('testing_failure_step'):
            # Break early to simulate infra failures in test cases.
            if config.testing_failure_step == step:
                break

        if config.get('only_eval', False):
            break

    write_note(f'Done!\n{chrono.note}')
    pool.close()
    pool.join()
    writer.close()
Example #14
0
    def train_and_evaluate(self, workdir):
        """Runs a training and evaluation loop.

        The train state is restored from (or initialized into)
        `<workdir>/checkpoints`, replicated, and then updated until
        `config.training.num_train_steps` is reached, with periodic loss
        logging, full evaluation and checkpointing along the way.

        Args:
          workdir: Working directory for checkpoints and TF summaries. If this
            contains a checkpoint, training will be resumed from the latest
            checkpoint.
        """
        tf.io.gfile.makedirs(workdir)
        config = self.config
        substeps = config.training.substeps

        num_train_steps = config.training.num_train_steps
        logging.info('num_train_steps=%d', num_train_steps)

        def avg_over_substeps(x):
            # Every logged leaf is expected to carry one entry per substep on
            # its leading axis; collapse it to a single Python float.
            assert x.shape[0] == substeps
            return float(x.mean(axis=0))

        # Get train state
        state = self._train_state

        # Set up checkpointing and restore the latest state if one exists.
        checkpoint_dir = os.path.join(workdir, 'checkpoints')
        ckpt = checkpoint.MultihostCheckpoint(checkpoint_dir, max_to_keep=5)
        state = ckpt.restore_or_initialize(state)
        initial_step = int(state.step)

        # Distribute training.
        state = flax_utils.replicate(state)

        writer = metric_writers.create_default_writer(
            workdir, just_logging=jax.process_index() > 0)
        # Hyperparameters are only written for a fresh (non-resumed) run.
        if initial_step == 0:
            writer.write_hparams(dict(config))

        logging.info('Starting training loop at step %d.', initial_step)
        hooks = []
        report_progress = periodic_actions.ReportProgress(
            num_train_steps=num_train_steps, writer=writer)
        # Progress reporting and profiling hooks only run on the first process.
        if jax.process_index() == 0:
            hooks += [
                report_progress,
                periodic_actions.Profile(num_profile_steps=5, logdir=workdir)
            ]
        step = initial_step
        with metric_writers.ensure_flushes(writer):
            while step < num_train_steps:
                # `step` is a Python integer. `state.step` is JAX integer on the GPU/TPU
                # devices.
                is_last_step = step + substeps >= num_train_steps

                with jax.profiler.StepTraceAnnotation('train', step_num=step):
                    inputs = jax.tree_util.tree_map(np.asarray,
                                                    next(self._train_iter))
                    state, outputs = self._update_func(state, inputs)

                # Quick indication that training is happening.
                logging.log_first_n(logging.INFO, 'Finished training step %d.',
                                    5, step)
                for h in hooks:
                    h(step)

                # One call to the update function must advance the step counter
                # by exactly `substeps`.
                new_step = int(state.step[0])
                assert new_step == step + substeps
                step = new_step

                is_eval = step % config.logs.eval_full_every_steps == 0 or is_last_step
                if step % config.logs.log_loss_every_steps == 0 and not is_eval:
                    # Extract scalars, averaged over the substeps.
                    outputs = flax_utils.unreplicate(outputs)
                    outputs = jax.tree_util.tree_map(avg_over_substeps,
                                                     outputs)
                    scalars = outputs['scalars']
                    writer.write_scalars(step, scalars)

                if is_eval:
                    with report_progress.timed('eval_full'):
                        outputs = self._eval_epoch(params=state.ema_params)
                        outputs = flax_utils.unreplicate(outputs)
                        scalars = outputs['scalars']
                        writer.write_scalars(step, scalars)

                if step % config.logs.checkpoint_every_steps == 0 or is_last_step:
                    with report_progress.timed('checkpoint'):
                        ckpt.save(flax_utils.unreplicate(state))

        logging.info('Finishing training at step %d', num_train_steps)
Example #15
0
def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):
    """Runs a training and evaluation loop.

    Builds the WMT input pipeline and the Transformer model, optionally
    restores the optimizer from `workdir`, and then trains while periodically
    writing training metrics, evaluation metrics, translation samples and a
    BLEU score.

    Args:
      config: Configuration to use.
      workdir: Working directory for checkpoints and TF summaries. If
        `config.restore_checkpoints` is set and this contains a checkpoint,
        training will be resumed from the latest checkpoint.

    Raises:
      ValueError: If `config.batch_size` is not divisible by the number of
        local devices.
    """
    tf.io.gfile.makedirs(workdir)

    # Number of local devices for this host.
    n_devices = jax.local_device_count()

    if config.batch_size % n_devices:
        raise ValueError(
            "Batch size must be divisible by the number of devices")

    vocab_path = config.vocab_path
    if vocab_path is None:
        vocab_path = os.path.join(workdir, "sentencepiece_model")
    tf.io.gfile.makedirs(os.path.split(vocab_path)[0])

    # Load Dataset
    # ---------------------------------------------------------------------------
    logging.info("Initializing dataset.")
    train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets(
        n_devices=n_devices,
        dataset_name=config.dataset_name,
        eval_dataset_name=config.eval_dataset_name,
        shard_idx=jax.process_index(),
        shard_count=jax.process_count(),
        vocab_path=vocab_path,
        target_vocab_size=config.vocab_size,
        batch_size=config.batch_size,
        max_corpus_chars=config.max_corpus_chars,
        max_length=config.max_target_length,
        max_eval_length=config.max_eval_target_length)
    train_iter = iter(train_ds)
    vocab_size = int(encoder.vocab_size())
    eos_id = decode.EOS_ID  # Default Sentencepiece EOS token.

    def decode_tokens(toks):
        # Keep tokens up to and including the first EOS, then detokenize.
        valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32)
        return encoder.detokenize(valid_toks).numpy().decode("utf-8")

    if config.num_predict_steps > 0:
        predict_ds = predict_ds.take(config.num_predict_steps)

    logging.info("Initializing model, optimizer, and step functions.")

    # Build Model and Optimizer
    # ---------------------------------------------------------------------------
    train_config = models.TransformerConfig(
        vocab_size=vocab_size,
        output_vocab_size=vocab_size,
        share_embeddings=config.share_embeddings,
        logits_via_embedding=config.logits_via_embedding,
        dtype=jnp.bfloat16 if config.use_bfloat16 else jnp.float32,
        emb_dim=config.emb_dim,
        num_heads=config.num_heads,
        num_layers=config.num_layers,
        qkv_dim=config.qkv_dim,
        mlp_dim=config.mlp_dim,
        max_len=max(config.max_target_length, config.max_eval_target_length),
        dropout_rate=config.dropout_rate,
        attention_dropout_rate=config.attention_dropout_rate,
        deterministic=False,
        decode=False,
        kernel_init=nn.initializers.xavier_uniform(),
        bias_init=nn.initializers.normal(stddev=1e-6))
    eval_config = train_config.replace(deterministic=True)
    predict_config = train_config.replace(deterministic=True, decode=True)

    start_step = 0
    rng = random.PRNGKey(config.seed)
    rng, init_rng = random.split(rng)
    input_shape = (config.batch_size, config.max_target_length)
    target_shape = (config.batch_size, config.max_target_length)

    m = models.Transformer(eval_config)
    initial_variables = jax.jit(m.init)(init_rng,
                                        jnp.ones(input_shape, jnp.float32),
                                        jnp.ones(target_shape, jnp.float32))

    # apply an optimizer to this tree
    optimizer_def = optim.Adam(config.learning_rate,
                               beta1=0.9,
                               beta2=0.98,
                               eps=1e-9,
                               weight_decay=config.weight_decay)
    optimizer = optimizer_def.create(initial_variables["params"])

    # We access model params only from optimizer below via optimizer.target.
    del initial_variables

    if config.restore_checkpoints:
        # Restore unreplicated optimizer + model state from last checkpoint.
        optimizer = checkpoints.restore_checkpoint(workdir, optimizer)
        # Grab last step.
        start_step = int(optimizer.state.step)

    writer = metric_writers.create_default_writer(
        workdir, just_logging=jax.process_index() > 0)
    # A fresh run starts at step 0, so that is when the hyperparameters are
    # written; resumed runs skip this.
    if start_step == 0:
        writer.write_hparams(dict(config))

    # Replicate optimizer.
    optimizer = jax_utils.replicate(optimizer)

    learning_rate_fn = create_learning_rate_scheduler(
        base_learning_rate=config.learning_rate,
        warmup_steps=config.warmup_steps)

    # compile multidevice versions of train/eval/predict step and cache init fn.
    p_train_step = jax.pmap(functools.partial(
        train_step,
        config=train_config,
        learning_rate_fn=learning_rate_fn,
        label_smoothing=config.label_smoothing),
                            axis_name="batch",
                            donate_argnums=(0, ))  # pytype: disable=wrong-arg-types
    p_eval_step = jax.pmap(functools.partial(
        eval_step, config=eval_config, label_smoothing=config.label_smoothing),
                           axis_name="batch")
    p_init_cache = jax.pmap(functools.partial(
        initialize_cache,
        max_decode_len=config.max_predict_length,
        config=predict_config),
                            axis_name="batch")
    p_pred_step = jax.pmap(
        functools.partial(predict_step,
                          config=predict_config,
                          beam_size=config.beam_size),
        axis_name="batch",
        static_broadcasted_argnums=(3,
                                    4))  # eos token, max_length are constant

    # Main Train Loop
    # ---------------------------------------------------------------------------

    # We init the first set of dropout PRNG keys, but update it afterwards inside
    # the main pmap'd training update for performance.
    dropout_rngs = random.split(rng, n_devices)

    logging.info("Starting training loop.")
    hooks = []
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=config.num_train_steps, writer=writer)
    # Progress reporting and profiling hooks only run on the first process.
    if jax.process_index() == 0:
        hooks += [
            report_progress,
            periodic_actions.Profile(num_profile_steps=5)
        ]
    metrics_all = []
    with metric_writers.ensure_flushes(writer):
        for step, batch in zip(range(start_step, config.num_train_steps),
                               train_iter):
            # Shard data to devices and do a training step.
            batch = common_utils.shard(
                jax.tree_util.tree_map(lambda x: x._numpy(), batch))  # pylint: disable=protected-access
            optimizer, metrics, dropout_rngs = p_train_step(
                optimizer, batch, dropout_rng=dropout_rngs)
            metrics_all.append(metrics)

            # Quick indication that training is happening.
            logging.log_first_n(logging.INFO, "Finished training step %d.", 5,
                                step)
            for h in hooks:
                h(step)

            # Save a checkpoint on one host after every checkpoint_freq steps.
            if (config.save_checkpoints and step % config.checkpoint_freq == 0
                    and step > 0 and jax.process_index() == 0):
                checkpoints.save_checkpoint(workdir,
                                            jax_utils.unreplicate(optimizer),
                                            step)

            # Periodic metric handling: everything below only runs on step 0
            # and on multiples of `eval_frequency`.
            if step % config.eval_frequency != 0 and step > 0:
                continue

            # Training Metrics
            logging.info("Gathering training metrics.")
            metrics_all = common_utils.get_metrics(metrics_all)
            lr = metrics_all.pop("learning_rate").mean()
            metrics_sums = jax.tree_util.tree_map(jnp.sum, metrics_all)
            denominator = metrics_sums.pop("denominator")
            summary = jax.tree_util.tree_map(
                lambda x: x / denominator,  # pylint: disable=cell-var-from-loop
                metrics_sums)
            summary["learning_rate"] = lr
            summary = {"train_" + k: v for k, v in summary.items()}
            writer.write_scalars(step, summary)
            metrics_all = []

            # Eval Metrics
            logging.info("Gathering evaluation metrics.")
            eval_metrics = []
            eval_iter = iter(eval_ds)
            for _, eval_batch in zip(range(config.num_eval_steps), eval_iter):
                eval_batch = jax.tree_util.tree_map(lambda x: x._numpy(),  # pylint: disable=protected-access
                                                    eval_batch)
                eval_batch = common_utils.shard(eval_batch)
                metrics = p_eval_step(optimizer.target, eval_batch)
                eval_metrics.append(metrics)
            eval_metrics = common_utils.get_metrics(eval_metrics)
            eval_metrics_sums = jax.tree_util.tree_map(jnp.sum, eval_metrics)
            eval_denominator = eval_metrics_sums.pop("denominator")
            eval_summary = jax.tree_util.tree_map(
                lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
                eval_metrics_sums)
            eval_summary = {"eval_" + k: v for k, v in eval_summary.items()}
            writer.write_scalars(step, eval_summary)

            # Translation and BLEU Score.
            logging.info("Translating evaluation dataset.")
            t_inference_start = time.time()
            sources, references, predictions = [], [], []
            for pred_batch in predict_ds:
                pred_batch = jax.tree_util.tree_map(lambda x: x._numpy(),  # pylint: disable=protected-access
                                                    pred_batch)
                # Handle final odd-sized batch by padding instead of dropping it.
                cur_pred_batch_size = pred_batch["inputs"].shape[0]
                if cur_pred_batch_size % n_devices:
                    padded_size = int(
                        np.ceil(cur_pred_batch_size / n_devices) * n_devices)
                    pred_batch = jax.tree_util.tree_map(
                        lambda x: pad_examples(x, padded_size),  # pylint: disable=cell-var-from-loop
                        pred_batch)
                pred_batch = common_utils.shard(pred_batch)
                cache = p_init_cache(pred_batch["inputs"])
                predicted = p_pred_step(pred_batch["inputs"], optimizer.target,
                                        cache, eos_id,
                                        config.max_predict_length)
                predicted = tohost(predicted)
                inputs = tohost(pred_batch["inputs"])
                targets = tohost(pred_batch["targets"])
                # Iterate through non-padding examples of batch.
                for i, s in enumerate(predicted[:cur_pred_batch_size]):
                    sources.append(decode_tokens(inputs[i]))
                    references.append(decode_tokens(targets[i]))
                    predictions.append(decode_tokens(s))
            logging.info(
                "Translation: %d predictions %d references %d sources.",
                len(predictions), len(references), len(sources))
            logging.info("Translation time: %.4f s step %d.",
                         time.time() - t_inference_start, step)

            # Calculate BLEU score for translated eval corpus against reference.
            bleu_matches = bleu.bleu_partial(references, predictions)
            all_bleu_matches = per_host_sum_pmap(bleu_matches)
            bleu_score = bleu.complete_bleu(*all_bleu_matches)
            # Save translation samples for tensorboard: 8 randomly drawn
            # (source, reference, prediction) triples.
            sample_ids = np.random.choice(np.arange(len(predictions)), 8)
            exemplars = "".join(
                f"{sources[n]}\n\n{references[n]}\n\n{predictions[n]}\n\n"
                for n in sample_ids)
            writer.write_scalars(step, {"bleu": bleu_score})
            writer.write_texts(step, {"samples": exemplars})
Example #16
0
def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):
    """Execute model training and evaluation loop.

    Trains on 'imagenet2012', running an evaluation pass at the end of every
    epoch and saving a checkpoint every 10 epochs and at the final step.

    Args:
      config: Hyperparameter configuration for training and evaluation.
      workdir: Directory where the tensorboard summaries are written to. It is
        also passed to `restore_checkpoint` / `save_checkpoint`.

    Raises:
      ValueError: If `config.batch_size` is not divisible by the total number
        of devices.
    """

    # Only the first process owns a summary writer; every later use of
    # `summary_writer` is guarded by the same check.
    if jax.process_index() == 0:
        summary_writer = tensorboard.SummaryWriter(workdir)
        summary_writer.hparams(dict(config))

    rng = random.PRNGKey(0)

    image_size = 224

    if config.batch_size % jax.device_count() > 0:
        raise ValueError(
            'Batch size must be divisible by the number of devices')
    local_batch_size = config.batch_size // jax.process_count()

    platform = jax.local_devices()[0].platform

    # Half precision means bfloat16 on TPU and float16 elsewhere.
    if config.half_precision:
        if platform == 'tpu':
            input_dtype = tf.bfloat16
        else:
            input_dtype = tf.float16
    else:
        input_dtype = tf.float32

    dataset_builder = tfds.builder('imagenet2012:5.*.*')
    train_iter = create_input_iter(dataset_builder,
                                   local_batch_size,
                                   image_size,
                                   input_dtype,
                                   train=True,
                                   cache=config.cache)
    eval_iter = create_input_iter(dataset_builder,
                                  local_batch_size,
                                  image_size,
                                  input_dtype,
                                  train=False,
                                  cache=config.cache)

    steps_per_epoch = (dataset_builder.info.splits['train'].num_examples //
                       config.batch_size)

    # A value of -1 means "derive the step count from the dataset size".
    if config.num_train_steps == -1:
        num_steps = steps_per_epoch * config.num_epochs
    else:
        num_steps = config.num_train_steps

    if config.steps_per_eval == -1:
        num_validation_examples = dataset_builder.info.splits[
            'validation'].num_examples
        steps_per_eval = num_validation_examples // config.batch_size
    else:
        steps_per_eval = config.steps_per_eval

    steps_per_checkpoint = steps_per_epoch * 10

    # Scale the learning rate linearly with the batch size (relative to 256).
    base_learning_rate = config.learning_rate * config.batch_size / 256.

    model_cls = getattr(models, config.model)
    model = create_model(model_cls=model_cls,
                         half_precision=config.half_precision)

    state = create_train_state(rng, config, model, image_size)
    state = restore_checkpoint(state, workdir)
    # step_offset > 0 if restarting from checkpoint
    step_offset = int(state.step)
    state = jax_utils.replicate(state)

    learning_rate_fn = create_learning_rate_fn(base_learning_rate,
                                               steps_per_epoch,
                                               config.num_epochs)

    p_train_step = jax.pmap(functools.partial(
        train_step, model.apply, learning_rate_fn=learning_rate_fn),
                            axis_name='batch')
    p_eval_step = jax.pmap(functools.partial(eval_step, model.apply),
                           axis_name='batch')

    epoch_metrics = []
    hooks = []
    # Profiling only runs on the first process.
    if jax.process_index() == 0:
        hooks += [
            periodic_actions.Profile(logdir=workdir, num_profile_steps=5)
        ]
    t_loop_start = time.time()
    logging.info('Initial compilation, this might take some minutes...')
    for step, batch in zip(range(step_offset, num_steps), train_iter):
        state, metrics = p_train_step(state, batch)
        for h in hooks:
            h(step)
        epoch_metrics.append(metrics)
        # End of an epoch: report training metrics, then evaluate.
        if (step + 1) % steps_per_epoch == 0:
            epoch = step // steps_per_epoch
            epoch_metrics = common_utils.get_metrics(epoch_metrics)
            summary = jax.tree_util.tree_map(lambda x: x.mean(),
                                             epoch_metrics)
            logging.info('train epoch: %d, loss: %.4f, accuracy: %.2f', epoch,
                         summary['loss'], summary['accuracy'] * 100)
            steps_per_sec = steps_per_epoch / (time.time() - t_loop_start)
            t_loop_start = time.time()
            if jax.process_index() == 0:
                for key, vals in epoch_metrics.items():
                    tag = 'train_%s' % key
                    # Write one scalar per training step, back-dated so the
                    # last value lands on the current step.
                    for i, val in enumerate(vals):
                        summary_writer.scalar(tag, val,
                                              step - len(vals) + i + 1)
                summary_writer.scalar('steps per second', steps_per_sec, step)

            epoch_metrics = []
            eval_metrics = []

            # sync batch statistics across replicas
            state = sync_batch_stats(state)
            for _ in range(steps_per_eval):
                eval_batch = next(eval_iter)
                metrics = p_eval_step(state, eval_batch)
                eval_metrics.append(metrics)
            eval_metrics = common_utils.get_metrics(eval_metrics)
            summary = jax.tree_util.tree_map(lambda x: x.mean(), eval_metrics)
            logging.info('eval epoch: %d, loss: %.4f, accuracy: %.2f', epoch,
                         summary['loss'], summary['accuracy'] * 100)
            if jax.process_index() == 0:
                for key, val in eval_metrics.items():
                    tag = 'eval_%s' % key
                    summary_writer.scalar(tag, val.mean(), step)
                summary_writer.flush()
        if (step + 1) % steps_per_checkpoint == 0 or step + 1 == num_steps:
            state = sync_batch_stats(state)
            save_checkpoint(state, workdir)

    # Wait until computations are done before exiting
    jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready()
Example #17
0
def main(argv):
  """Trains the 'sngp' model from `FLAGS.config` with periodic evaluation.

  The function reads everything it needs from `FLAGS.config`: it builds the
  training and evaluation input pipelines, initializes the model parameters on
  the host CPU, optionally restores a checkpoint, and then runs the pmapped
  training loop. During the loop it periodically writes checkpoints, reports
  training measurements, evaluates on every split in `eval_iter_splits`, and
  stores per-prediction results under the output directory.

  Args:
    argv: Command-line arguments; unused.

  Raises:
    AssertionError: If the configured loss is not 'softmax_xent', if
      `keep_checkpoint_steps` is set without a compatible `checkpoint_steps`,
      or if both `num_epochs` and `total_steps` are set.
    ValueError: If the train or eval batch size is not divisible by the total
      number of devices.
  """
  del argv  # unused arg

  config = FLAGS.config

  # Unpack total and warmup steps
  # TODO(nband): revert this to separate arguments.
  total_steps = config.total_and_warmup_steps[0]
  warmup_steps = config.total_and_warmup_steps[1]
  del config.total_and_warmup_steps
  config.total_steps = total_steps
  config.lr.warmup_steps = warmup_steps

  # Wandb and Checkpointing Setup
  # NOTE(review): the flag value is immediately replaced by the directory
  # returned from `maybe_setup_wandb`; confirm that this is intended.
  output_dir = FLAGS.output_dir
  wandb_run, output_dir = vit_utils.maybe_setup_wandb(config)
  tf.io.gfile.makedirs(output_dir)
  logging.info('Saving checkpoints at %s', output_dir)

  # Dataset Split Flags
  dist_shift = config.distribution_shift
  print(f'Distribution Shift: {dist_shift}.')
  dataset_names, split_names = vit_utils.get_dataset_and_split_names(dist_shift)

  # LR / Optimization Flags
  batch_size = config.batch_size
  grad_clip_norm = config.grad_clip_norm
  weight_decay = config.weight_decay
  print('Standard wandb hyperparameters:')
  print({
      'batch_size': batch_size,
      'grad_clip_norm': grad_clip_norm,
      'weight_decay': weight_decay,
      'total_steps': config.total_steps,
      'lr': config.lr
  })
  print('SNGP Params:', config.gp_layer)

  # Reweighting loss for class imbalance
  # class_reweight_mode = config.class_reweight_mode
  # if class_reweight_mode == 'constant':
  #   class_weights = utils.get_diabetic_retinopathy_class_balance_weights()
  # else:
  #   class_weights = None

  # Shows the number of available devices.
  # In a CPU/GPU runtime this will be a single device.
  # In a TPU runtime this will be 8 cores.
  # Print the count (not the list of device objects) to match the label.
  print('Number of Jax local devices:', jax.local_device_count())

  # TODO(nband): fix sigmoid loss issues.
  assert config.get('loss', None) == 'softmax_xent'

  seed = config.seed
  rng = jax.random.PRNGKey(seed)
  tf.random.set_seed(seed)

  if config.get('data_dir'):
    logging.info('data_dir=%s', config.data_dir)
  logging.info('Output dir: %s', output_dir)

  save_checkpoint_path = None
  if config.get('checkpoint_steps'):
    tf.io.gfile.makedirs(output_dir)
    save_checkpoint_path = os.path.join(output_dir, 'checkpoint.npz')

  # Create an asynchronous multi-metric writer.
  writer = metric_writers.create_default_writer(
      output_dir, just_logging=jax.process_index() > 0)

  # The pool is used to perform misc operations such as logging in async way.
  pool = multiprocessing.pool.ThreadPool()

  def write_note(note):
    """Logs `note` on the first process only."""
    if jax.process_index() == 0:
      logging.info('NOTE: %s', note)

  write_note('Initializing...')

  # Verify settings to make sure no checkpoints are accidentally missed.
  if config.get('keep_checkpoint_steps'):
    assert config.get('checkpoint_steps'), 'Specify `checkpoint_steps`.'
    assert config.keep_checkpoint_steps % config.checkpoint_steps == 0, (
        f'`keep_checkpoint_steps` ({config.keep_checkpoint_steps}) should be '
        f'divisible by `checkpoint_steps` ({config.checkpoint_steps}).')

  batch_size_eval = config.get('batch_size_eval', batch_size)
  if (batch_size % jax.device_count() != 0 or
      batch_size_eval % jax.device_count() != 0):
    raise ValueError(f'Batch sizes ({batch_size} and {batch_size_eval}) must '
                     f'be divisible by device number ({jax.device_count()})')

  local_batch_size = batch_size // jax.process_count()
  local_batch_size_eval = batch_size_eval // jax.process_count()
  logging.info(
      'Global batch size %d on %d hosts results in %d local batch size. '
      'With %d dev per host (%d dev total), that is a %d per-device batch size.',
      batch_size,
      jax.process_count(), local_batch_size, jax.local_device_count(),
      jax.device_count(), local_batch_size // jax.local_device_count())

  write_note('Initializing preprocessing function...')
  # Same preprocessing function for training and evaluation
  preproc_fn = preprocess_spec.parse(
      spec=config.pp_train, available_ops=preprocess_utils.all_ops())

  write_note('Initializing train dataset...')
  rng, train_ds_rng = jax.random.split(rng)
  # Fold in the process index so that each host gets its own data rng.
  train_ds_rng = jax.random.fold_in(train_ds_rng, jax.process_index())
  train_base_dataset = ub.datasets.get(
      dataset_names['in_domain_dataset'],
      split=split_names['train_split'],
      data_dir=config.get('data_dir'))
  train_dataset_builder = train_base_dataset._dataset_builder  # pylint: disable=protected-access
  train_ds = input_utils.get_data(
      dataset=train_dataset_builder,
      split=split_names['train_split'],
      rng=train_ds_rng,
      process_batch_size=local_batch_size,
      preprocess_fn=preproc_fn,
      shuffle_buffer_size=config.shuffle_buffer_size,
      prefetch_size=config.get('prefetch_to_host', 2),
      data_dir=config.get('data_dir'))
  logging.info('image_size = %s', train_ds.element_spec['image'].shape[1:])

  # Start prefetching already.
  train_iter = input_utils.start_input_pipeline(
      train_ds, config.get('prefetch_to_device', 1))

  write_note('Initializing val dataset(s)...')

  # Load in-domain and OOD validation and/or test datasets.
  # Please specify the desired shift (Country Shift or Severity Shift)
  # in the config.
  eval_iter_splits = vit_utils.init_evaluation_datasets(
      use_validation=config.use_validation,
      use_test=config.use_test,
      dataset_names=dataset_names,
      split_names=split_names,
      config=config,
      preproc_fn=preproc_fn,
      batch_size_eval=batch_size_eval,
      local_batch_size_eval=local_batch_size_eval)

  ntrain_img = input_utils.get_num_examples(
      train_dataset_builder,
      split=split_names['train_split'],
      process_batch_size=local_batch_size,
      data_dir=config.get('data_dir'))
  # Keep this an integer: it is used below as a modulo period for both the
  # precision-matrix reset schedule and the evaluation schedule, and a
  # fractional period would never line up with an integer step.
  steps_per_epoch = int(ntrain_img / batch_size)

  if config.get('num_epochs'):
    total_steps = int(config.num_epochs * steps_per_epoch)
    assert not config.get('total_steps'), 'Set either num_epochs or total_steps'
  else:
    total_steps = config.total_steps

  logging.info('Total train data points: %d', ntrain_img)
  logging.info(
      'Running for %d steps, that means %f epochs and %d steps per epoch',
      total_steps, total_steps * batch_size / ntrain_img, steps_per_epoch)

  write_note('Initializing model...')
  logging.info('config.model = %s', config.get('model'))

  # Specify Gaussian process layer configs.
  gp_config = config.get('gp_layer', {})
  model_dict = vit_utils.initialize_model('sngp', config)
  model, use_gp_layer = model_dict['model'], model_dict['use_gp_layer']

  # We want all parameters to be created in host RAM, not on any device, they'll
  # be sent there later as needed, otherwise we already encountered two
  # situations where we allocate them twice.
  @functools.partial(jax.jit, backend='cpu')
  def init(rng):
    """Initializes the model and returns `(params, states)`."""
    image_size = tuple(train_ds.element_spec['image'].shape[2:])
    logging.info('image_size = %s', image_size)
    dummy_input = jnp.zeros((local_batch_size,) + image_size, jnp.float32)
    variables = model.init(rng, dummy_input, train=False)
    # Split model parameters into trainable and untrainable collections.
    states, params = variables.pop('params')
    del variables

    # Set bias in the head to a low value, such that loss is small initially.
    params = flax.core.unfreeze(params)
    if use_gp_layer:
      # Modify the head parameter in the GP head.
      params['head']['output_layer']['bias'] = jnp.full_like(
          params['head']['output_layer']['bias'],
          config.get('init_head_bias', 0))
    else:
      params['head']['bias'] = jnp.full_like(
          params['head']['bias'], config.get('init_head_bias', 0))

    return params, states

  rng, rng_init = jax.random.split(rng)
  params_cpu, states_cpu = init(rng_init)

  if jax.process_index() == 0:
    num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0])
    parameter_overview.log_parameter_overview(params_cpu)
    writer.write_scalars(step=0, scalars={'num_params': num_params})

  @functools.partial(jax.pmap, axis_name='batch')
  def evaluation_fn(params, states, images, labels):
    """Runs the model in eval mode and gathers outputs across devices.

    Returns:
      A tuple `(ncorrect, loss, n, metric_args)` where `ncorrect` and `loss`
      are summed over the 'batch' axis, `n` is `batch_size_eval`, and
      `metric_args` is the all-gathered `[logits, labels, pre_logits]`.
    """
    variable_dict = {'params': flax.core.freeze(params), **states}
    logits, out = model.apply(
        variable_dict,
        images,
        train=False,
        mean_field_factor=gp_config.get('mean_field_factor', -1.))
    losses = getattr(train_utils, config.get('loss', 'softmax_xent'))(
        logits=logits, labels=labels, reduction=False)
    loss = jax.lax.psum(losses, axis_name='batch')
    top1_idx = jnp.argmax(logits, axis=1)

    # Extracts the label at the highest logit index for each image.
    top1_correct = jnp.take_along_axis(labels, top1_idx[:, None], axis=1)[:, 0]

    ncorrect = jax.lax.psum(top1_correct, axis_name='batch')
    n = batch_size_eval
    metric_args = jax.lax.all_gather([
        logits, labels, out['pre_logits']], axis_name='batch')
    return ncorrect, loss, n, metric_args

  # Load the optimizer from flax.
  opt_name = config.get('optim_name')
  write_note(f'Initializing {opt_name} optimizer...')
  opt_def = getattr(flax.optim, opt_name)(**config.get('optim', {}))

  # We jit this, such that the arrays that are created are created on the same
  # device as the input is, in this case the CPU. Else they'd be on device[0].
  opt_cpu = jax.jit(opt_def.create)(params_cpu)

  weight_decay_rules = config.get('weight_decay', []) or []
  rescale_value = config.lr.base if config.get('weight_decay_decouple') else 1.
  weight_decay_fn = train_utils.get_weight_decay_fn(
      weight_decay_rules=weight_decay_rules, rescale_value=rescale_value)

  @functools.partial(jax.pmap, axis_name='batch', donate_argnums=(0,))
  def update_fn(opt, states, lr, reset_covmat, images, labels, rng):
    """Update step.

    Returns:
      A tuple `(opt, states, loss, rng, measurements)` with the updated
      optimizer, the updated untrainable states, the loss averaged over the
      'batch' axis, the next rng, and a dict of scalar measurements.
    """
    measurements = {}

    # Get device-specific loss rng.
    rng, rng_model = jax.random.split(rng, 2)
    rng_model_local = jax.random.fold_in(rng_model, jax.lax.axis_index('batch'))

    def loss_fn(params, states, images, labels):
      # Specify mutable collection to update untrainable GP parameters.
      variable_dict = {'params': flax.core.freeze(params), **states}
      model_results, updated_states = model.apply(
          variable_dict,
          images,
          train=True,
          rngs={'dropout': rng_model_local},
          mutable=list(states.keys()),
          mean_field_factor=gp_config.get('mean_field_factor', -1.))

      logits, _ = model_results
      # Use the same fallback loss as `evaluation_fn` so that training and
      # evaluation can never silently disagree.
      loss = getattr(train_utils, config.get('loss', 'softmax_xent'))(
          logits=logits, labels=labels)
      return loss, updated_states

    # Performs exact covariance update (i.e., reset precision matrix resetting
    # at beginning of new epoch) if covmat_momentum is a null value.
    if use_gp_layer and gp_config.get('covmat_momentum', -1.) < 0:
      # Resets precision matrix to Identity * ridge_penalty if at the beginning
      # of a new epoch. This should be done before accumulate gradient.
      ridge_penalty = gp_config.get('ridge_penalty', 1.)
      prec_mat_old = states['laplace_covariance']['head']['covmat_layer'][
          'precision_matrix']
      # `reset_covmat` is 0. or 1., so this selects either the old matrix or
      # the freshly reset one without a data-dependent branch.
      prec_mat_new = (
          (1. - reset_covmat) * prec_mat_old +
          reset_covmat * jnp.eye(prec_mat_old.shape[0]) * ridge_penalty)

      states = flax.core.unfreeze(states)
      states['laplace_covariance']['head']['covmat_layer'][
          'precision_matrix'] = prec_mat_new
      states = flax.core.freeze(states)

    # Implementation considerations compared and summarized at
    # https://docs.google.com/document/d/1g3kMEvqu1DOawaflKNyUsIoQ4yIVEoyE5ZlIPkIl4Lc/edit?hl=en#
    (l, s), g = vit_utils.accumulate_gradient_with_states(
        jax.value_and_grad(loss_fn, has_aux=True), opt.target, states, images,
        labels, config.get('grad_accum_steps'))
    l, g = jax.lax.pmean((l, g), axis_name='batch')

    # Log the gradient norm only if we need to compute it anyways (clipping)
    # or if we don't use grad_accum_steps, as they interact badly.
    if config.get('grad_accum_steps', 1) == 1 or grad_clip_norm is not None:
      grads, _ = jax.tree_flatten(g)
      l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads]))
      measurements['l2_grads'] = l2_g

    # Optionally resize the global gradient to a maximum norm. We found this
    # useful in some cases across optimizers, hence it's in the main loop.
    if grad_clip_norm is not None:
      g_factor = jnp.minimum(1.0, grad_clip_norm / l2_g)
      g = jax.tree_map(lambda p: g_factor * p, g)
    opt = opt.apply_gradient(g, learning_rate=lr)
    opt = opt.replace(target=weight_decay_fn(opt.target, lr))

    params, _ = jax.tree_flatten(opt.target)
    measurements['l2_params'] = jnp.sqrt(sum([jnp.vdot(p, p) for p in params]))
    measurements['reset_covmat'] = reset_covmat

    return opt, s, l, rng, measurements

  # Set config checkpoint resume path, if provided in args.
  if config.resume_checkpoint_path is not None:
    config.resume = config.resume_checkpoint_path

  default_reinit_params = ('head/output_layer/kernel', 'head/output_layer/bias',
                           'head/kernel', 'head/bias')
  rng, train_loop_rngs = jax.random.split(rng)
  checkpoint_data = checkpoint_utils.maybe_load_checkpoint(
      train_loop_rngs=train_loop_rngs,
      save_checkpoint_path=save_checkpoint_path,
      init_optimizer=opt_cpu,
      init_params=params_cpu,
      init_fixed_model_states=states_cpu,
      default_reinit_params=default_reinit_params,
      config=config)
  train_loop_rngs = checkpoint_data.train_loop_rngs
  opt_cpu = checkpoint_data.optimizer
  states_cpu = checkpoint_data.fixed_model_states
  accumulated_train_time = checkpoint_data.accumulated_train_time

  write_note('Adapting the checkpoint model...')
  adapted_params = checkpoint_utils.adapt_upstream_architecture(
      init_params=params_cpu,
      loaded_params=opt_cpu.target)
  opt_cpu = opt_cpu.replace(target=adapted_params)

  write_note('Kicking off misc stuff...')
  first_step = int(opt_cpu.state.step)  # Might be a DeviceArray type.
  if first_step == 0 and jax.process_index() == 0:
    writer.write_hparams(dict(config))
  chrono = train_utils.Chrono(first_step, total_steps, batch_size,
                              accumulated_train_time)
  # Note: switch to ProfileAllHosts() if you need to profile all hosts.
  # (Xprof data become much larger and take longer to load for analysis)
  profiler = periodic_actions.Profile(
      # Create profile after every restart to analyze pre-emption related
      # problems and assure we get similar performance in every run.
      logdir=output_dir, first_profile=first_step + 10)

  # Prepare the learning-rate and pre-fetch it to device to avoid delays.
  lr_fn = train_utils.create_learning_rate_schedule(total_steps,
                                                    **config.get('lr', {}))

  # TODO(dusenberrymw): According to flax docs, prefetching shouldn't be
  # necessary for TPUs.
  lr_iter = train_utils.prefetch_scalar(
      map(lr_fn, range(total_steps)), config.get('prefetch_to_device', 1))

  # Prepare the precision matrix resetting schedule, and pre-fetch it to device.
  # Unlike `lr_iter`, this schedule already starts at `first_step`, so it is
  # not advanced again when resuming from a checkpoint below.
  reset_covmat_fn = lambda step: float(step % steps_per_epoch == 0)
  reset_covmat_iter = train_utils.prefetch_scalar(
      map(reset_covmat_fn, range(first_step, total_steps)),
      nprefetch=config.get('prefetch_to_device', 1))

  write_note(f'Replicating...\n{chrono.note}')
  opt_repl = flax.jax_utils.replicate(opt_cpu)
  states_repl = flax.jax_utils.replicate(states_cpu)

  checkpoint_writer = None

  # Note: we return the train loss, val loss, and fewshot best l2s for use in
  # reproducibility unit tests.
  # train_loss = -jnp.inf
  # val_loss = -jnp.inf
  # results = {'dummy': {(0, 1): -jnp.inf}}

  write_note(f'First step compilations...\n{chrono.note}')
  logging.info('first_step = %s', first_step)
  # Advance the iterators if we are restarting from an earlier checkpoint.
  # TODO(dusenberrymw): Look into checkpointing dataset state instead.

  # Makes sure log_eval_steps is same as steps_per_epoch. This is because
  # the precision matrix needs to be updated fully (at the end of each epoch)
  # when eval takes place.
  log_eval_steps = steps_per_epoch
  if first_step > 0:
    write_note('Advancing iterators after resuming from a checkpoint...')
    lr_iter = itertools.islice(lr_iter, first_step, None)
    train_iter = itertools.islice(train_iter, first_step, None)

  # Using a python integer for step here, because opt.state.step is allocated
  # on TPU during replication.
  for step, train_batch, lr_repl, reset_covmat_repl in zip(
      range(first_step + 1, total_steps + 1), train_iter, lr_iter,
      reset_covmat_iter):

    with jax.profiler.TraceAnnotation('train_step', step_num=step, _r=1):
      # TODO(jereliu): Expand to allow precision matrix resetting.
      (opt_repl, states_repl, loss_value, train_loop_rngs,
       extra_measurements) = update_fn(
           opt_repl,
           states_repl,
           lr_repl,
           reset_covmat_repl,
           train_batch['image'],
           train_batch['labels'],
           rng=train_loop_rngs)

    if jax.process_index() == 0:
      profiler(step)

    # Checkpoint saving
    if train_utils.itstime(
        step, config.get('checkpoint_steps'), total_steps, process=0):
      write_note('Checkpointing...')
      chrono.pause()
      train_utils.checkpointing_timeout(checkpoint_writer,
                                        config.get('checkpoint_timeout', 1))
      accumulated_train_time = chrono.accum_train_time
      # We need to transfer the weights over now or else we risk keeping them
      # alive while they'll be updated in a future step, creating hard to debug
      # memory errors (see b/160593526). Also, takes device 0's params only.
      # For GP layer, we will also do the same for untrainable parameters
      # (`states`). This is ok since `random features` are frozen throughout
      # pre-training, and `precision matrix` is a finetuning-specific parameters
      # that will be re-learned in the finetuning task.
      opt_cpu = jax.tree_map(lambda x: np.array(x[0]), opt_repl)
      states_cpu = jax.tree_map(lambda x: np.array(x[0]), states_repl)

      # Check whether we want to keep a copy of the current checkpoint.
      copy_step = None
      if train_utils.itstime(step, config.get('keep_checkpoint_steps'),
                             total_steps):
        write_note('Keeping a checkpoint copy...')
        copy_step = step

      # Checkpoint should be a nested dictionary or FLAX dataclasses from
      # `flax.struct`. Both can be present in a checkpoint.
      checkpoint_data = checkpoint_utils.CheckpointData(
          optimizer=opt_cpu,
          fixed_model_states=states_cpu,
          train_loop_rngs=train_loop_rngs,
          accumulated_train_time=accumulated_train_time)
      # Write asynchronously; the next checkpoint waits on this result via
      # `checkpointing_timeout` above.
      checkpoint_writer = pool.apply_async(
          checkpoint_utils.checkpoint_trained_model,
          (checkpoint_data, save_checkpoint_path, copy_step))
      chrono.resume()

    # Report training progress
    if train_utils.itstime(
        step, config.log_training_steps, total_steps, process=0):
      write_note('Reporting training progress...')
      train_loss = loss_value[0]  # Keep to return for reproducibility tests.
      timing_measurements, note = chrono.tick(step)
      write_note(note)
      train_measurements = {}
      train_measurements.update({
          'learning_rate': lr_repl[0],
          'training_loss': train_loss,
      })
      train_measurements.update(flax.jax_utils.unreplicate(extra_measurements))
      train_measurements.update(timing_measurements)
      writer.write_scalars(step, train_measurements)

    # Report validation performance
    if train_utils.itstime(step, log_eval_steps, total_steps):
      write_note('Evaluating on the validation set...')
      chrono.pause()

      all_eval_results = {}

      for eval_name, (eval_iter, eval_steps) in eval_iter_splits.items():
        start_time = time.time()

        # Runs evaluation loop.
        results_arrs = {
            'y_true': [],
            'y_pred': [],
            'y_pred_entropy': []
        }

        for _, batch in zip(range(eval_steps), eval_iter):
          batch_ncorrect, batch_losses, batch_n, batch_metric_args = (  # pylint: disable=unused-variable
              evaluation_fn(
                  opt_repl.target, states_repl, batch['image'],
                  batch['labels']))

          # All results are a replicated array shaped as follows:
          # (local_devices, per_device_batch_size, elem_shape...)
          # with each local device's entry being identical as they got psum'd.
          # So let's just take the first one to the host as numpy.

          # Here we parse batch_metric_args to compute uncertainty metrics.
          logits, labels, _ = batch_metric_args
          logits = np.array(logits[0])
          probs = jax.nn.softmax(logits)

          # From one-hot to integer labels.
          int_labels = np.argmax(np.array(labels[0]), axis=-1)

          # Merge the two leading axes so that each row is one example.
          probs = np.reshape(probs, (probs.shape[0] * probs.shape[1], -1))
          int_labels = int_labels.flatten()
          # The prediction score is the probability assigned to class 1.
          y_pred = probs[:, 1]
          results_arrs['y_true'].append(int_labels)
          results_arrs['y_pred'].append(y_pred)

          # Entropy is computed at the per-epoch level (see below).
          results_arrs['y_pred_entropy'].append(probs)

        results_arrs['y_true'] = np.concatenate(results_arrs['y_true'],
                                                axis=0)
        results_arrs['y_pred'] = np.concatenate(
            results_arrs['y_pred'], axis=0).astype('float64')
        results_arrs['y_pred_entropy'] = vit_utils.entropy(
            np.concatenate(results_arrs['y_pred_entropy'], axis=0), axis=-1)

        time_elapsed = time.time() - start_time
        results_arrs['total_ms_elapsed'] = time_elapsed * 1e3
        results_arrs['dataset_size'] = eval_steps * batch_size_eval

        all_eval_results[eval_name] = results_arrs

      per_pred_results, metrics_results = vit_utils.evaluate_vit_predictions(  # pylint: disable=unused-variable
          dataset_split_to_containers=all_eval_results,
          is_deterministic=True,
          num_bins=15,
          return_per_pred_results=True
      )

      # `metrics_results` is a dict of {str: jnp.ndarray} dicts, one for each
      # dataset. Flatten this dict so we can pass to the writer and remove empty
      # entries.
      flattened_metric_results = {}
      for dic in metrics_results.values():
        for key, value in dic.items():
          if value is not None:
            flattened_metric_results[key] = value
      writer.write_scalars(step, flattened_metric_results)

      # Optionally log to wandb
      if config.use_wandb:
        wandb.log(metrics_results, step=step)

      # Save per-prediction metrics
      results_storage_utils.save_per_prediction_results(
          output_dir, step, per_pred_results, verbose=False)

      chrono.resume()

    # End of step.
    if config.get('testing_failure_step'):
      # Break early to simulate infra failures in test cases.
      if config.testing_failure_step == step:
        break

  write_note(f'Done!\n{chrono.note}')
  # Wait for any pending asynchronous checkpoint write before shutting down.
  pool.close()
  pool.join()
  writer.close()

  if wandb_run is not None:
    wandb_run.finish()
def main(config, output_dir):

    seed = config.get('seed', 0)
    rng = jax.random.PRNGKey(seed)
    tf.random.set_seed(seed)

    if config.get('data_dir'):
        logging.info('data_dir=%s', config.data_dir)
    logging.info('Output dir: %s', output_dir)

    save_checkpoint_path = None
    if config.get('checkpoint_steps'):
        gfile.makedirs(output_dir)
        save_checkpoint_path = os.path.join(output_dir, 'checkpoint.npz')

    # Create an asynchronous multi-metric writer.
    writer = metric_writers.create_default_writer(
        output_dir, just_logging=jax.process_index() > 0)

    # The pool is used to perform misc operations such as logging in async way.
    pool = multiprocessing.pool.ThreadPool()

    def write_note(note):
        if jax.process_index() == 0:
            logging.info('NOTE: %s', note)

    write_note('Initializing...')

    # Verify settings to make sure no checkpoints are accidentally missed.
    if config.get('keep_checkpoint_steps'):
        assert config.get('checkpoint_steps'), 'Specify `checkpoint_steps`.'
        assert config.keep_checkpoint_steps % config.checkpoint_steps == 0, (
            f'`keep_checkpoint_steps` ({config.checkpoint_steps}) should be'
            f'divisible by `checkpoint_steps ({config.checkpoint_steps}).`')

    batch_size = config.batch_size
    batch_size_eval = config.get('batch_size_eval', batch_size)
    if (batch_size % jax.device_count() != 0
            or batch_size_eval % jax.device_count() != 0):
        raise ValueError(
            f'Batch sizes ({batch_size} and {batch_size_eval}) must '
            f'be divisible by device number ({jax.device_count()})')

    local_batch_size = batch_size // jax.process_count()
    local_batch_size_eval = batch_size_eval // jax.process_count()
    logging.info(
        'Global batch size %d on %d hosts results in %d local batch size. '
        'With %d devices per host (%d devices total), that\'s a %d per-device '
        'batch size.', batch_size, jax.process_count(), local_batch_size,
        jax.local_device_count(), jax.device_count(),
        local_batch_size // jax.local_device_count())

    write_note('Initializing train dataset...')
    rng, train_ds_rng = jax.random.split(rng)
    train_ds_rng = jax.random.fold_in(train_ds_rng, jax.process_index())
    train_ds = input_utils.get_data(
        dataset=config.dataset,
        split=config.train_split,
        rng=train_ds_rng,
        process_batch_size=local_batch_size,
        preprocess_fn=preprocess_spec.parse(
            spec=config.pp_train, available_ops=preprocess_utils.all_ops()),
        shuffle_buffer_size=config.shuffle_buffer_size,
        prefetch_size=config.get('prefetch_to_host', 2),
        data_dir=config.get('data_dir'))

    # Start prefetching already.
    train_iter = input_utils.start_input_pipeline(
        train_ds, config.get('prefetch_to_device', 1))

    write_note('Initializing val dataset(s)...')

    def _get_val_split(dataset, split, pp_eval, data_dir=None):
        # We do ceil rounding such that we include the last incomplete batch.
        nval_img = input_utils.get_num_examples(
            dataset,
            split=split,
            process_batch_size=local_batch_size_eval,
            drop_remainder=False,
            data_dir=data_dir)
        val_steps = int(np.ceil(nval_img / batch_size_eval))
        logging.info('Running validation for %d steps for %s, %s', val_steps,
                     dataset, split)

        if isinstance(pp_eval, str):
            pp_eval = preprocess_spec.parse(
                spec=pp_eval, available_ops=preprocess_utils.all_ops())

        val_ds = input_utils.get_data(dataset=dataset,
                                      split=split,
                                      rng=None,
                                      process_batch_size=local_batch_size_eval,
                                      preprocess_fn=pp_eval,
                                      cache=config.get('val_cache', 'batched'),
                                      repeat_after_batching=True,
                                      shuffle=False,
                                      prefetch_size=config.get(
                                          'prefetch_to_host', 2),
                                      drop_remainder=False,
                                      data_dir=data_dir)
        val_iter = input_utils.start_input_pipeline(
            val_ds, config.get('prefetch_to_device', 1))

        return (val_iter, val_steps)

    val_iter_splits = {
        'val':
        _get_val_split(config.dataset,
                       split=config.val_split,
                       pp_eval=config.pp_eval,
                       data_dir=config.get('data_dir'))
    }

    if config.get('eval_on_cifar_10h'):
        cifar10_to_cifar10h_fn = data_uncertainty_utils.create_cifar10_to_cifar10h_fn(
            config.get('data_dir', None))
        preprocess_fn = preprocess_spec.parse(
            spec=config.pp_eval_cifar_10h,
            available_ops=preprocess_utils.all_ops())
        pp_eval = lambda ex: preprocess_fn(cifar10_to_cifar10h_fn(ex))
        val_iter_splits['cifar_10h'] = _get_val_split(
            'cifar10',
            split=config.get('cifar_10h_split') or 'test',
            pp_eval=pp_eval,
            data_dir=config.get('data_dir'))
    elif config.get('eval_on_imagenet_real'):
        imagenet_to_real_fn = data_uncertainty_utils.create_imagenet_to_real_fn(
        )
        preprocess_fn = preprocess_spec.parse(
            spec=config.pp_eval_imagenet_real,
            available_ops=preprocess_utils.all_ops())
        pp_eval = lambda ex: preprocess_fn(imagenet_to_real_fn(ex))
        val_iter_imagenet_real, val_steps = _get_val_split(
            'imagenet2012_real',
            split=config.get('imagenet_real_split') or 'validation',
            pp_eval=pp_eval,
            data_dir=config.get('data_dir'))
        val_iter_splits['imagenet_real'] = (val_iter_imagenet_real, val_steps)

    ood_ds = {}
    if config.get('ood_datasets') and config.get('ood_methods'):
        if config.get(
                'ood_methods'):  #  config.ood_methods is not a empty list
            logging.info('loading OOD dataset = %s', config.get('ood_dataset'))
            ood_ds, ood_ds_names = ood_utils.load_ood_datasets(
                config.dataset,
                config.ood_datasets,
                config.ood_split,
                config.pp_eval,
                config.pp_eval_ood,
                config.ood_methods,
                config.train_split,
                config.get('data_dir'),
                _get_val_split,
            )

    ntrain_img = input_utils.get_num_examples(
        config.dataset,
        split=config.train_split,
        process_batch_size=local_batch_size,
        data_dir=config.get('data_dir'))
    steps_per_epoch = int(ntrain_img / batch_size)

    if config.get('num_epochs'):
        total_steps = int(config.num_epochs * steps_per_epoch)
        assert not config.get(
            'total_steps'), 'Set either num_epochs or total_steps'
    else:
        total_steps = config.total_steps

    logging.info('Total train data points: %d', ntrain_img)
    logging.info(
        'Running for %d steps, that means %f epochs and %d steps per epoch',
        total_steps, total_steps * batch_size / ntrain_img, steps_per_epoch)

    write_note('Initializing model...')
    logging.info('config.model = %s', config.get('model'))
    model = ub.models.vision_transformer(num_classes=config.num_classes,
                                         **config.get('model', {}))

    # We want all parameters to be created in host RAM, not on any device, they'll
    # be sent there later as needed, otherwise we already encountered two
    # situations where we allocate them twice.
    @partial(jax.jit, backend='cpu')
    def init(rng):
        image_size = tuple(train_ds.element_spec['image'].shape[2:])
        logging.info('image_size = %s', image_size)
        dummy_input = jnp.zeros((local_batch_size, ) + image_size, jnp.float32)
        params = flax.core.unfreeze(model.init(rng, dummy_input,
                                               train=False))['params']

        # Set bias in the head to a low value, such that loss is small initially.
        params['head']['bias'] = jnp.full_like(params['head']['bias'],
                                               config.get('init_head_bias', 0))

        # init head kernel to all zeros for fine-tuning
        if config.get('model_init'):
            params['head']['kernel'] = jnp.full_like(params['head']['kernel'],
                                                     0)

        return params

    rng, rng_init = jax.random.split(rng)
    params_cpu = init(rng_init)

    if jax.process_index() == 0:
        num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0])
        parameter_overview.log_parameter_overview(params_cpu)
        writer.write_scalars(step=0, scalars={'num_params': num_params})

    @partial(jax.pmap, axis_name='batch')
    def evaluation_fn(params, images, labels, mask):
        """Evaluates one batch, pmapped over the 'batch' axis.

        Args:
          params: Model parameters; frozen before being passed to `model.apply`.
          images: Input images.
          labels: Label array, indexed as [example, class].
          mask: Per-example weights; examples whose labels are all zero are
            additionally masked out.

        Returns:
          A tuple `(ncorrect, loss, n, metric_args)` where the first three are
          psum'd over the 'batch' axis and `metric_args` is the all_gather of
          `[logits, labels, pre_logits, mask]`.
        """
        # Ignore the entries with all zero labels for evaluation.
        mask = mask * labels.max(axis=1)

        variables = {'params': flax.core.freeze(params)}
        logits, out = model.apply(variables, images, train=False)

        label_indices = config.get('label_indices')
        logging.info('!!! mask %s, label_indices %s', mask, label_indices)
        if label_indices:
            logits = logits[:, label_indices]

        # Logits and labels are usually shaped [batch, num_classes]. For OOD
        # data with more OOD classes than in-distribution classes, the labels
        # are truncated to the logits' width purely to avoid a shape mismatch;
        # the resulting losses carry no meaning for OOD data, because OOD
        # examples do not belong to any in-distribution class.
        num_label_cols = (len(label_indices)
                          if label_indices else config.num_classes)
        loss_impl = getattr(train_utils, config.get('loss', 'sigmoid_xent'))
        per_example_loss = loss_impl(logits=logits,
                                     labels=labels[:, :num_label_cols],
                                     reduction=False)
        loss = jax.lax.psum(per_example_loss * mask, axis_name='batch')

        # Look up each example's label value at its highest-logit index.
        predicted = jnp.argmax(logits, axis=1)
        hits = jnp.take_along_axis(labels, predicted[:, None], axis=1)[:, 0]
        ncorrect = jax.lax.psum(hits * mask, axis_name='batch')
        n = jax.lax.psum(mask, axis_name='batch')

        gathered_inputs = [logits, labels, out['pre_logits'], mask]
        metric_args = jax.lax.all_gather(gathered_inputs, axis_name='batch')
        return ncorrect, loss, n, metric_args

    @partial(jax.pmap, axis_name='batch')
    def cifar_10h_evaluation_fn(params, images, labels, mask):
        """Evaluates one batch against argmax-hardened labels, pmapped.

        Unlike `mask` in the returned `metric_args`, the loss, correct count
        and `n` computed here are not weighted by `mask`.

        Args:
          params: Model parameters; frozen before being passed to `model.apply`.
          images: Input images.
          labels: Label array, indexed as [example, class]; passed as-is to
            the loss and argmax-ed into one-hot rows of width 10 for accuracy.
          mask: Only forwarded inside `metric_args`.

        Returns:
          A tuple `(ncorrect, loss, n, metric_args)`; `n` is the psum of the
          one-hot label matrix and `metric_args` is the all_gather of
          `[logits, labels, pre_logits, mask]`.
        """
        variables = {'params': flax.core.freeze(params)}
        logits, out = model.apply(variables, images, train=False)

        label_indices = config.get('label_indices')
        if label_indices:
            logits = logits[:, label_indices]

        loss_impl = getattr(train_utils, config.get('loss', 'softmax_xent'))
        per_example_loss = loss_impl(logits=logits,
                                     labels=labels,
                                     reduction=False)
        loss = jax.lax.psum(per_example_loss, axis_name='batch')

        predicted = jnp.argmax(logits, axis=1)
        # Harden the labels: one-hot row (10 classes) at each label argmax.
        hard_labels = jnp.eye(10)[jnp.argmax(labels, axis=1)]

        # Look up each example's hardened label at its highest-logit index.
        hits = jnp.take_along_axis(hard_labels, predicted[:, None],
                                   axis=1)[:, 0]
        ncorrect = jax.lax.psum(hits, axis_name='batch')
        n = jax.lax.psum(hard_labels, axis_name='batch')

        gathered_inputs = [logits, labels, out['pre_logits'], mask]
        metric_args = jax.lax.all_gather(gathered_inputs, axis_name='batch')
        return ncorrect, loss, n, metric_args

    # Setup function for computing representation.
    @partial(jax.pmap, axis_name='batch')
    def representation_fn(params, images, labels, mask):
        """Computes the few-shot representation for a batch, pmapped.

        Args:
          params: Model parameters; frozen before being passed to `model.apply`.
          images: Input images.
          labels: Labels, gathered and returned unchanged otherwise.
          mask: Mask, gathered and returned unchanged otherwise.

        Returns:
          A tuple `(representation, labels, mask)`, each all_gather'ed over the
          'batch' axis. The representation is the model output selected by
          `config.fewshot.representation_layer`.
        """
        variables = {'params': flax.core.freeze(params)}
        _, outputs = model.apply(variables, images, train=False)
        layer_name = config.fewshot.representation_layer
        per_device = (outputs[layer_name], labels, mask)
        return tuple(jax.lax.all_gather(x, 'batch') for x in per_device)

    # Load the optimizer from flax.
    opt_name = config.get('optim_name')
    write_note(f'Initializing {opt_name} optimizer...')
    opt_def = getattr(flax.optim, opt_name)(**config.get('optim', {}))

    # We jit this, such that the arrays that are created are created on the same
    # device as the input is, in this case the CPU. Else they'd be on device[0].
    opt_cpu = jax.jit(opt_def.create)(params_cpu)

    weight_decay_rules = config.get('weight_decay', []) or []
    rescale_value = config.lr.base if config.get(
        'weight_decay_decouple') else 1.
    weight_decay_fn = train_utils.get_weight_decay_fn(
        weight_decay_rules=weight_decay_rules, rescale_value=rescale_value)

    @partial(jax.pmap, axis_name='batch', donate_argnums=(0, ))
    def update_fn(opt, lr, images, labels, rng):
        """Runs one optimizer update, pmapped over the 'batch' axis.

        Args:
          opt: Optimizer whose `target` holds the parameters. This argument is
            donated to the computation (`donate_argnums=(0,)`).
          lr: Learning rate passed to `opt.apply_gradient` and to the
            weight-decay function.
          images: Input images for this step.
          labels: Labels matching `images`.
          rng: PRNG key; it is split once per call and one half is returned.

        Returns:
          A tuple `(opt, l, rng, measurements)`: the updated optimizer, the
          loss averaged over the 'batch' axis, the carried-over PRNG key, and
          a dict holding 'l2_params' plus 'l2_grads' when it was computed.
        """

        measurements = {}

        # Get device-specific loss rng.
        rng, rng_model = jax.random.split(rng, 2)
        rng_model_local = jax.random.fold_in(rng_model,
                                             jax.lax.axis_index('batch'))

        def loss_fn(params, images, labels):
            logits, _ = model.apply({'params': flax.core.freeze(params)},
                                    images,
                                    train=True,
                                    rngs={'dropout': rng_model_local})
            label_indices = config.get('label_indices')
            if label_indices:
                logits = logits[:, label_indices]
            return getattr(train_utils,
                           config.get('loss', 'sigmoid_xent'))(logits=logits,
                                                               labels=labels)

        # Implementation considerations compared and summarized at
        # https://docs.google.com/document/d/1g3kMEvqu1DOawaflKNyUsIoQ4yIVEoyE5ZlIPkIl4Lc/edit?hl=en#
        l, g = train_utils.accumulate_gradient(jax.value_and_grad(loss_fn),
                                               opt.target, images, labels,
                                               config.get('grad_accum_steps'))
        l, g = jax.lax.pmean((l, g), axis_name='batch')

        # Log the gradient norm only if we need to compute it anyways (clipping)
        # or if we don't use grad_accum_steps, as they interact badly.
        if config.get('grad_accum_steps',
                      1) == 1 or config.get('grad_clip_norm'):
            # `jax.tree_util` is used instead of the deprecated top-level
            # `jax.tree_*` aliases, matching the tree_map call below.
            grad_leaves = jax.tree_util.tree_leaves(g)
            l2_g = jnp.sqrt(sum(jnp.vdot(p, p) for p in grad_leaves))
            measurements['l2_grads'] = l2_g

        # Optionally resize the global gradient to a maximum norm. We found this
        # useful in some cases across optimizers, hence it's in the main loop.
        # `l2_g` is always defined here: the branch above runs whenever
        # `grad_clip_norm` is set.
        if config.get('grad_clip_norm'):
            g_factor = jnp.minimum(1.0, config.grad_clip_norm / l2_g)
            g = jax.tree_util.tree_map(lambda p: g_factor * p, g)
        opt = opt.apply_gradient(g, learning_rate=lr)

        opt = opt.replace(target=weight_decay_fn(opt.target, lr))

        param_leaves = jax.tree_util.tree_leaves(opt.target)
        measurements['l2_params'] = jnp.sqrt(
            sum(jnp.vdot(p, p) for p in param_leaves))

        return opt, l, rng, measurements

    rng, train_loop_rngs = jax.random.split(rng)
    reint_params = ('head/kernel', 'head/bias')
    if config.get('only_eval', False) or not config.get('reint_head', True):
        reint_params = []
    checkpoint_data = checkpoint_utils.maybe_load_checkpoint(
        train_loop_rngs=train_loop_rngs,
        save_checkpoint_path=save_checkpoint_path,
        init_optimizer=opt_cpu,
        init_params=params_cpu,
        init_fixed_model_states=None,
        default_reinit_params=reint_params,
        config=config,
    )
    train_loop_rngs = checkpoint_data.train_loop_rngs
    opt_cpu = checkpoint_data.optimizer
    accumulated_train_time = checkpoint_data.accumulated_train_time

    write_note('Adapting the checkpoint model...')
    adapted_params = checkpoint_utils.adapt_upstream_architecture(
        init_params=params_cpu, loaded_params=opt_cpu.target)
    opt_cpu = opt_cpu.replace(target=adapted_params)

    write_note('Kicking off misc stuff...')
    first_step = int(opt_cpu.state.step)  # Might be a DeviceArray type.
    if first_step == 0 and jax.process_index() == 0:
        writer.write_hparams(dict(config))
    chrono = train_utils.Chrono(first_step, total_steps, batch_size,
                                accumulated_train_time)
    # Note: switch to ProfileAllHosts() if you need to profile all hosts.
    # (Xprof data become much larger and take longer to load for analysis)
    profiler = periodic_actions.Profile(
        # Create profile after every restart to analyze pre-emption related
        # problems and assure we get similar performance in every run.
        logdir=output_dir,
        first_profile=first_step + 10)

    # Prepare the learning-rate and pre-fetch it to device to avoid delays.
    lr_fn = train_utils.create_learning_rate_schedule(total_steps,
                                                      **config.get('lr', {}))
    # TODO(dusenberrymw): According to flax docs, prefetching shouldn't be
    # necessary for TPUs.
    lr_iter = train_utils.prefetch_scalar(map(lr_fn, range(total_steps)),
                                          config.get('prefetch_to_device', 1))

    write_note(f'Replicating...\n{chrono.note}')
    opt_repl = flax_utils.replicate(opt_cpu)

    write_note(f'Initializing few-shotters...\n{chrono.note}')
    fewshotter = None
    if 'fewshot' in config and fewshot is not None:
        fewshotter = fewshot.FewShotEvaluator(
            representation_fn, config.fewshot,
            config.fewshot.get('batch_size') or batch_size_eval)

    checkpoint_writer = None

    # Note: we return the train loss, val loss, and fewshot best l2s for use in
    # reproducibility unit tests.
    train_loss = -jnp.inf
    val_loss = {val_name: -jnp.inf for val_name, _ in val_iter_splits.items()}
    fewshot_results = {'dummy': {(0, 1): -jnp.inf}}

    write_note(f'First step compilations...\n{chrono.note}')
    logging.info('first_step = %s', first_step)
    # Advance the iterators if we are restarting from an earlier checkpoint.
    # TODO(dusenberrymw): Look into checkpointing dataset state instead.
    if first_step > 0:
        write_note('Advancing iterators after resuming from a checkpoint...')
        lr_iter = itertools.islice(lr_iter, first_step, None)
        train_iter = itertools.islice(train_iter, first_step, None)
        # NOTE: Validation eval is only run on certain steps, so determine how many
        # times it was run previously.
        num_val_runs = sum(
            map(
                lambda i: train_utils.itstime(i, config.log_eval_steps,
                                              total_steps),
                range(1, first_step + 1)))
        for val_name, (val_iter, val_steps) in val_iter_splits.items():
            val_iter = itertools.islice(val_iter, num_val_runs * val_steps,
                                        None)
            val_iter_splits[val_name] = (val_iter, val_steps)

    # Using a python integer for step here, because opt.state.step is allocated
    # on TPU during replication.
    for step, train_batch, lr_repl in zip(
            range(first_step + 1, total_steps + 1), train_iter, lr_iter):

        with jax.profiler.TraceAnnotation('train_step', step_num=step, _r=1):
            if not config.get('only_eval', False):
                opt_repl, loss_value, train_loop_rngs, extra_measurements = update_fn(
                    opt_repl,
                    lr_repl,
                    train_batch['image'],
                    train_batch['labels'],
                    rng=train_loop_rngs)

        if jax.process_index() == 0:
            profiler(step)

        # Checkpoint saving
        if not config.get('only_eval', False) and train_utils.itstime(
                step, config.get('checkpoint_steps'), total_steps, process=0):
            write_note('Checkpointing...')
            chrono.pause()
            train_utils.checkpointing_timeout(
                checkpoint_writer, config.get('checkpoint_timeout', 1))
            accumulated_train_time = chrono.accum_train_time
            # We need to transfer the weights over now or else we risk keeping them
            # alive while they'll be updated in a future step, creating hard to debug
            # memory errors (see b/160593526). Also, takes device 0's params only.
            opt_cpu = jax.tree_util.tree_map(lambda x: np.array(x[0]),
                                             opt_repl)

            # Check whether we want to keep a copy of the current checkpoint.
            copy_step = None
            if train_utils.itstime(step, config.get('keep_checkpoint_steps'),
                                   total_steps):
                write_note('Keeping a checkpoint copy...')
                copy_step = step

            # Checkpoint should be a nested dictionary or FLAX dataclasses from
            # `flax.struct`. Both can be present in a checkpoint.
            checkpoint_data = checkpoint_utils.CheckpointData(
                train_loop_rngs=train_loop_rngs,
                optimizer=opt_cpu,
                accumulated_train_time=accumulated_train_time)

            checkpoint_writer = pool.apply_async(
                checkpoint_utils.checkpoint_trained_model,
                (checkpoint_data, save_checkpoint_path, copy_step))
            chrono.resume()

        # Report training progress
        if not config.get('only_eval', False) and train_utils.itstime(
                step, config.log_training_steps, total_steps, process=0):
            write_note('Reporting training progress...')
            train_loss = loss_value[
                0]  # Keep to return for reproducibility tests.
            timing_measurements, note = chrono.tick(step)
            write_note(note)
            train_measurements = {}
            train_measurements.update({
                'learning_rate': lr_repl[0],
                'training_loss': train_loss,
            })
            train_measurements.update(
                flax.jax_utils.unreplicate(extra_measurements))
            train_measurements.update(timing_measurements)
            writer.write_scalars(step, train_measurements)

        # Report validation performance
        if train_utils.itstime(step, config.log_eval_steps, total_steps):
            write_note('Evaluating on the validation set...')
            chrono.pause()
            for val_name, (val_iter, val_steps) in val_iter_splits.items():
                # Sets up evaluation metrics.
                ece_num_bins = config.get('ece_num_bins', 15)
                auc_num_bins = config.get('auc_num_bins', 1000)
                ece = rm.metrics.ExpectedCalibrationError(
                    num_bins=ece_num_bins)
                calib_auc = rm.metrics.CalibrationAUC(
                    correct_pred_as_pos_label=False)
                oc_auc_0_5 = rm.metrics.OracleCollaborativeAUC(
                    oracle_fraction=0.005, num_bins=auc_num_bins)
                oc_auc_1 = rm.metrics.OracleCollaborativeAUC(
                    oracle_fraction=0.01, num_bins=auc_num_bins)
                oc_auc_2 = rm.metrics.OracleCollaborativeAUC(
                    oracle_fraction=0.02, num_bins=auc_num_bins)
                oc_auc_5 = rm.metrics.OracleCollaborativeAUC(
                    oracle_fraction=0.05, num_bins=auc_num_bins)
                label_diversity = tf.keras.metrics.Mean()
                sample_diversity = tf.keras.metrics.Mean()
                ged = tf.keras.metrics.Mean()

                # Runs evaluation loop.
                ncorrect, loss, nseen = 0, 0, 0
                for _, batch in zip(range(val_steps), val_iter):
                    if val_name == 'cifar_10h':
                        batch_ncorrect, batch_losses, batch_n, batch_metric_args = (
                            cifar_10h_evaluation_fn(opt_repl.target,
                                                    batch['image'],
                                                    batch['labels'],
                                                    batch['mask']))
                    else:
                        batch_ncorrect, batch_losses, batch_n, batch_metric_args = (
                            evaluation_fn(opt_repl.target, batch['image'],
                                          batch['labels'], batch['mask']))
                    # All results are a replicated array shaped as follows:
                    # (local_devices, per_device_batch_size, elem_shape...)
                    # with each local device's entry being identical as they got psum'd.
                    # So let's just take the first one to the host as numpy.
                    ncorrect += np.sum(np.array(batch_ncorrect[0]))
                    loss += np.sum(np.array(batch_losses[0]))
                    nseen += np.sum(np.array(batch_n[0]))
                    if config.get('loss', 'sigmoid_xent') != 'sigmoid_xent':
                        # Here we parse batch_metric_args to compute uncertainty metrics.
                        # (e.g., ECE or Calibration AUC).
                        logits, labels, _, masks = batch_metric_args
                        masks = np.array(masks[0], dtype=np.bool)
                        logits = np.array(logits[0])
                        probs = jax.nn.softmax(logits)
                        # From one-hot to integer labels, as required by ECE.
                        int_labels = np.argmax(np.array(labels[0]), axis=-1)
                        int_preds = np.argmax(logits, axis=-1)
                        confidence = np.max(probs, axis=-1)
                        for p, c, l, d, m, label in zip(
                                probs, confidence, int_labels, int_preds,
                                masks, labels[0]):
                            ece.add_batch(p[m, :], label=l[m])
                            calib_auc.add_batch(d[m],
                                                label=l[m],
                                                confidence=c[m])
                            # TODO(jereliu): Extend to support soft multi-class probabilities.
                            oc_auc_0_5.add_batch(d[m],
                                                 label=l[m],
                                                 custom_binning_score=c[m])
                            oc_auc_1.add_batch(d[m],
                                               label=l[m],
                                               custom_binning_score=c[m])
                            oc_auc_2.add_batch(d[m],
                                               label=l[m],
                                               custom_binning_score=c[m])
                            oc_auc_5.add_batch(d[m],
                                               label=l[m],
                                               custom_binning_score=c[m])

                            if val_name == 'cifar_10h' or val_name == 'imagenet_real':
                                batch_label_diversity, batch_sample_diversity, batch_ged = data_uncertainty_utils.generalized_energy_distance(
                                    label[m], p[m, :], config.num_classes)
                                label_diversity.update_state(
                                    batch_label_diversity)
                                sample_diversity.update_state(
                                    batch_sample_diversity)
                                ged.update_state(batch_ged)

                val_loss[
                    val_name] = loss / nseen  # Keep for reproducibility tests.
                val_measurements = {
                    f'{val_name}_prec@1': ncorrect / nseen,
                    f'{val_name}_loss': val_loss[val_name],
                }
                if config.get('loss', 'sigmoid_xent') != 'sigmoid_xent':
                    val_measurements[f'{val_name}_ece'] = ece.result()['ece']
                    val_measurements[
                        f'{val_name}_calib_auc'] = calib_auc.result(
                        )['calibration_auc']
                    val_measurements[
                        f'{val_name}_oc_auc_0.5%'] = oc_auc_0_5.result(
                        )['collaborative_auc']
                    val_measurements[
                        f'{val_name}_oc_auc_1%'] = oc_auc_1.result(
                        )['collaborative_auc']
                    val_measurements[
                        f'{val_name}_oc_auc_2%'] = oc_auc_2.result(
                        )['collaborative_auc']
                    val_measurements[
                        f'{val_name}_oc_auc_5%'] = oc_auc_5.result(
                        )['collaborative_auc']
                writer.write_scalars(step, val_measurements)

                if val_name == 'cifar_10h' or val_name == 'imagenet_real':
                    cifar_10h_measurements = {
                        f'{val_name}_label_diversity':
                        label_diversity.result(),
                        f'{val_name}_sample_diversity':
                        sample_diversity.result(),
                        f'{val_name}_ged': ged.result(),
                    }
                    writer.write_scalars(step, cifar_10h_measurements)

            # OOD eval
            # Entries in the ood_ds dict include:
            # (ind_dataset, ood_dataset1, ood_dataset2, ...).
            # OOD metrics are computed using ind_dataset paired with each of the
            # ood_dataset. When Mahalanobis distance method is applied, train_ind_ds
            # is also included in the ood_ds.
            if ood_ds and config.ood_methods:
                ood_measurements = ood_utils.eval_ood_metrics(
                    ood_ds, ood_ds_names, config.ood_methods, evaluation_fn,
                    opt_repl)
                writer.write_scalars(step, ood_measurements)
            chrono.resume()

        if 'fewshot' in config and fewshotter is not None:
            # Compute few-shot on-the-fly evaluation.
            if train_utils.itstime(step, config.fewshot.log_steps,
                                   total_steps):
                chrono.pause()
                write_note(f'Few-shot evaluation...\n{chrono.note}')
                # Keep `results` to return for reproducibility tests.
                fewshot_results, best_l2 = fewshotter.run_all(
                    opt_repl.target, config.fewshot.datasets)

                # TODO(dusenberrymw): Remove this once fewshot.py is updated.
                def make_writer_measure_fn(step):
                    def writer_measure(name, value):
                        writer.write_scalars(step, {name: value})

                    return writer_measure

                fewshotter.walk_results(make_writer_measure_fn(step),
                                        fewshot_results, best_l2)
                chrono.resume()

        # End of step.
        if config.get('testing_failure_step'):
            # Break early to simulate infra failures in test cases.
            if config.testing_failure_step == step:
                break

    write_note(f'Done!\n{chrono.note}')
    pool.close()
    pool.join()
    writer.close()

    # Return final training loss, validation loss, and fewshot results for
    # reproducibility test cases.
    return train_loss, val_loss, fewshot_results
Example #19
0
def train_and_evaluate(config, workdir):
    """Trains the model while periodically logging, evaluating and saving.

    Args:
      config: Configuration to use.
      workdir: Working directory for checkpoints and TF summaries. If this
        contains checkpoint training will be resumed from the latest
        checkpoint.
    """
    tf.io.gfile.makedirs(workdir)

    rng = jax.random.PRNGKey(config.seed)

    # Input pipeline; the data key is made host-specific through fold_in.
    rng, data_key = jax.random.split(rng)
    data_key = jax.random.fold_in(data_key, jax.host_id())
    splits = input_pipeline.create_datasets(config, data_key)
    num_classes = splits.info.features["label"].num_classes
    train_iter = iter(splits.train)  # pytype: disable=wrong-arg-types

    # Step budget: -1 means "use the cardinality of the training split".
    num_train_steps = config.num_train_steps
    if num_train_steps == -1:
        num_train_steps = splits.train.cardinality().numpy()
    steps_per_epoch = num_train_steps // config.num_epochs
    logging.info("num_train_steps=%d, steps_per_epoch=%d", num_train_steps,
                 steps_per_epoch)

    # The configured learning rate is meant for batch size 32; it is scaled
    # linearly to the global batch size actually in use.
    global_batch_size = config.per_device_batch_size * jax.device_count()
    schedule_kwargs = dict(
        base_learning_rate=config.learning_rate * global_batch_size / 32.0,
        steps_per_epoch=steps_per_epoch,
        num_epochs=config.num_epochs,
        warmup_epochs=config.warmup_epochs,
    )
    learning_rate_fn = functools.partial(get_learning_rate, **schedule_kwargs)

    # Model and initial training state.
    rng, model_key = jax.random.split(rng)
    model, state = create_train_state(
        config,
        model_key,
        input_shape=splits.train.element_spec["input"].shape[1:],
        num_classes=num_classes,
    )

    # The checkpoint covers both the model state and the input iterator.
    ckpt = checkpoint.MultihostCheckpoint(
        os.path.join(workdir, "checkpoints"),
        {"train_iter": train_iter},
        max_to_keep=2,
    )
    state = ckpt.restore_or_initialize(state)
    initial_step = int(state.step) + 1

    # Parameters are counted before replication so replicas are not counted
    # multiple times.
    param_count = sum(
        leaf.size for leaf in jax.tree_leaves(state.optimizer.target))

    # Spread the state over the local devices.
    state = flax_utils.replicate(state)

    bound_train_step = functools.partial(
        train_step,
        model=model,
        learning_rate_fn=learning_rate_fn,
        weight_decay=config.weight_decay,
    )
    p_train_step = jax.pmap(bound_train_step, axis_name=_PMAP_AXIS_NAME)

    writer = metric_writers.create_default_writer(
        workdir, just_logging=jax.host_id() > 0)
    if initial_step == 1:
        writer.write_hparams(dict(config))
        # Log the number of trainable params.
        writer.write_scalars(initial_step, {"param_count": param_count})

    logging.info("Starting training loop at step %d.", initial_step)
    hooks = []
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=num_train_steps, writer=writer)
    # Progress reporting and profiling hooks only run on host 0.
    if jax.host_id() == 0:
        hooks.extend([
            report_progress,
            periodic_actions.Profile(num_profile_steps=5, logdir=workdir),
        ])

    train_metrics = None
    with metric_writers.ensure_flushes(writer):
        for step in range(initial_step, num_train_steps + 1):
            # `step` is a plain Python int, whereas `state.step` is a JAX
            # integer living on the accelerator devices.
            is_last_step = step == num_train_steps

            with jax.profiler.StepTraceContext("train", step_num=step):
                batch = jax.tree_map(np.asarray, next(train_iter))
                state, metrics_update = p_train_step(state=state, batch=batch)
                metric_update = flax_utils.unreplicate(metrics_update)
                if train_metrics is None:
                    train_metrics = metric_update
                else:
                    train_metrics = train_metrics.merge(metric_update)

            # Quick indication that training is happening.
            logging.log_first_n(logging.INFO, "Finished training step %d.", 5,
                                step)
            for hook in hooks:
                hook(step)

            # Flush the accumulated training metrics and start a new window.
            if step % config.log_loss_every_steps == 0 or is_last_step:
                writer.write_scalars(step, train_metrics.compute())
                train_metrics = None

            # No validation during training when train and validation data
            # are combined; the test split is evaluated once at the end.
            if step % config.eval_every_steps == 0 or is_last_step:
                if not config.combine_train_val_and_eval_on_test:
                    with report_progress.timed("eval"):
                        eval_metrics = evaluate(model, state,
                                                splits.validation,
                                                config.num_eval_steps)
                    writer.write_scalars(step, eval_metrics.compute())

            if step % config.checkpoint_every_steps == 0 or is_last_step:
                with report_progress.timed("checkpoint"):
                    ckpt.save(flax_utils.unreplicate(state))

            if is_last_step and config.combine_train_val_and_eval_on_test:
                # Evaluate a single time on the test set when requested.
                with report_progress.timed("test"):
                    test_metrics = evaluate(model, state, splits.test,
                                            config.num_eval_steps)
                writer.write_scalars(step, test_metrics.compute())

    logging.info("Finishing training at step %d", num_train_steps)
Example #20
0
def train_and_evaluate(config: ml_collections.ConfigDict,
                       workdir: str) -> TrainState:
    """Execute model training and evaluation loop.

    Args:
      config: Hyperparameter configuration for training and evaluation.
      workdir: Directory where the tensorboard summaries are written to. It
        is also the directory passed to `restore_checkpoint` and
        `save_checkpoint`.

    Returns:
      Final TrainState.

    Raises:
      ValueError: If `config.batch_size` is not divisible by the number of
        devices.
    """

    # Only process 0 writes real summaries; the others just log.
    # `jax.process_index()` replaces the deprecated `jax.host_id()` alias and
    # matches the process_index/process_count calls further down.
    writer = metric_writers.create_default_writer(
        logdir=workdir, just_logging=jax.process_index() != 0)

    rng = random.PRNGKey(0)

    image_size = 224

    if config.batch_size % jax.device_count() > 0:
        raise ValueError(
            'Batch size must be divisible by the number of devices')
    local_batch_size = config.batch_size // jax.process_count()

    platform = jax.local_devices()[0].platform

    # Half precision means bfloat16 on TPU and float16 everywhere else.
    if config.half_precision:
        if platform == 'tpu':
            input_dtype = tf.bfloat16
        else:
            input_dtype = tf.float16
    else:
        input_dtype = tf.float32

    dataset_builder = tfds.builder(config.dataset)
    train_iter = create_input_iter(dataset_builder,
                                   local_batch_size,
                                   image_size,
                                   input_dtype,
                                   train=True,
                                   cache=config.cache)
    eval_iter = create_input_iter(dataset_builder,
                                  local_batch_size,
                                  image_size,
                                  input_dtype,
                                  train=False,
                                  cache=config.cache)

    steps_per_epoch = (dataset_builder.info.splits['train'].num_examples //
                       config.batch_size)

    # -1 means "derive the value from the dataset size".
    if config.num_train_steps == -1:
        num_steps = int(steps_per_epoch * config.num_epochs)
    else:
        num_steps = config.num_train_steps

    if config.steps_per_eval == -1:
        num_validation_examples = dataset_builder.info.splits[
            'validation'].num_examples
        steps_per_eval = num_validation_examples // config.batch_size
    else:
        steps_per_eval = config.steps_per_eval

    steps_per_checkpoint = steps_per_epoch * 10

    # The configured learning rate is scaled linearly relative to a batch
    # size of 256.
    base_learning_rate = config.learning_rate * config.batch_size / 256.

    model_cls = getattr(models, config.model)
    model = create_model(model_cls=model_cls,
                         half_precision=config.half_precision)

    learning_rate_fn = create_learning_rate_fn(config, base_learning_rate,
                                               steps_per_epoch)

    state = create_train_state(rng, config, model, image_size,
                               learning_rate_fn)
    state = restore_checkpoint(state, workdir)
    # step_offset > 0 if restarting from checkpoint
    step_offset = int(state.step)
    state = jax_utils.replicate(state)

    p_train_step = jax.pmap(functools.partial(
        train_step, learning_rate_fn=learning_rate_fn),
                            axis_name='batch')
    p_eval_step = jax.pmap(eval_step, axis_name='batch')

    train_metrics = []
    hooks = []
    # The profiling hook only runs on process 0.
    if jax.process_index() == 0:
        hooks += [
            periodic_actions.Profile(num_profile_steps=5, logdir=workdir)
        ]
    train_metrics_last_t = time.time()
    logging.info('Initial compilation, this might take some minutes...')
    for step, batch in zip(range(step_offset, num_steps), train_iter):
        state, metrics = p_train_step(state, batch)
        for h in hooks:
            h(step)
        if step == step_offset:
            logging.info('Initial compilation completed.')

        # Accumulate per-step metrics and flush their mean every
        # `log_every_steps` steps.
        if config.get('log_every_steps'):
            train_metrics.append(metrics)
            if (step + 1) % config.log_every_steps == 0:
                train_metrics = common_utils.get_metrics(train_metrics)
                summary = {
                    f'train_{k}': v
                    for k, v in jax.tree_util.tree_map(
                        lambda x: x.mean(), train_metrics).items()
                }
                summary['steps_per_second'] = config.log_every_steps / (
                    time.time() - train_metrics_last_t)
                writer.write_scalars(step + 1, summary)
                train_metrics = []
                train_metrics_last_t = time.time()

        # Evaluate at the end of every epoch.
        if (step + 1) % steps_per_epoch == 0:
            epoch = step // steps_per_epoch
            eval_metrics = []

            # sync batch statistics across replicas
            state = sync_batch_stats(state)
            for _ in range(steps_per_eval):
                eval_batch = next(eval_iter)
                metrics = p_eval_step(state, eval_batch)
                eval_metrics.append(metrics)
            eval_metrics = common_utils.get_metrics(eval_metrics)
            summary = jax.tree_util.tree_map(lambda x: x.mean(),
                                             eval_metrics)
            logging.info('eval epoch: %d, loss: %.4f, accuracy: %.2f', epoch,
                         summary['loss'], summary['accuracy'] * 100)
            writer.write_scalars(
                step + 1, {f'eval_{key}': val
                           for key, val in summary.items()})
            writer.flush()
        # Checkpoint every `steps_per_checkpoint` steps and on the final step.
        if (step + 1) % steps_per_checkpoint == 0 or step + 1 == num_steps:
            state = sync_batch_stats(state)
            save_checkpoint(state, workdir)

    # Wait until computations are done before exiting
    jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready()

    return state
Example #21
0
def train_and_evaluate(config, workdir):
  """Runs a training and evaluation loop.

  Builds the datasets, model and optimizer, optionally restores a checkpoint
  from `workdir`, and then trains from the restored step up to
  `config.num_train_steps`. Every `config.eval_every_steps` steps (and on the
  last step) it writes averaged training metrics, evaluates `state.ema_params`
  on the eval and test datasets and calls `generate_prediction`. When
  `config.save_checkpoints` is set, the first process writes a checkpoint every
  `config.checkpoint_every_steps` steps and on the last step.

  Note that `config` is mutated: `seq_length`, `num_classes`, `data_shape` and
  `per_device_batch_size` are overwritten here.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If
      `config.restore_checkpoints` is set and this contains a checkpoint,
      training will be resumed from the latest checkpoint.

  Returns:
    None. Results are emitted through the metric writer and files in
    `workdir`.
  """
  is_first_process = jax.process_index() == 0
  tf.io.gfile.makedirs(workdir)

  # Load Dataset
  # ---------------------------------------------------------------------------
  logging.info('Initializing dataset.')
  train_ds, eval_ds, test_ds, encoder = input_pipeline.get_datasets(
      config)
  # The sequence length is hard-coded here and overrides whatever the config
  # held before; the encoder's vocabulary size becomes the number of classes.
  config.seq_length = 250
  vocab_size = int(encoder.vocab_size())
  config.num_classes = vocab_size
  config.data_shape = (config.seq_length, 1)

  logging.info('Training with vocab size %d', vocab_size)

  def decode_tokens(toks):
    """Detokenizes `toks` with the dataset encoder."""
    return encoder.detokenize(toks)

  start_step = 0
  rng = jax.random.PRNGKey(config.seed)
  rng, init_rng = jax.random.split(rng)
  # NOTE(review): despite the name, this divides `config.batch_size` by the
  # number of processes, not by the number of devices -- confirm that
  # consumers of `per_device_batch_size` expect a per-process value.
  config.per_device_batch_size = config.batch_size // jax.process_count()

  logging.info('Initializing model, optimizer, and step functions.')
  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  model, initial_variables = model_setup(init_rng, config)

  # Instead of passing the optimizer fns directly, we use a fn that returns
  # the optimizer given a learning rate.
  def tx_fn(lr):
    """Returns an AdamW optimizer configured for learning rate `lr`."""
    return optax.adamw(
        lr, b1=0.9, b2=0.99, eps=1e-08, eps_root=0.0,
        weight_decay=config.weight_decay)

  state = language_train_state.TrainState.create(
      params=initial_variables['params'], tx_fn=tx_fn)

  # We access model params only from state below via state.params.
  del initial_variables

  if config.restore_checkpoints:
    # Restore unreplicated model state from last checkpoint.
    state = checkpoints.restore_checkpoint(workdir, state)
    # Grab last step.
    start_step = int(state.step)

  # Only the first process gets a real writer; the others just log.
  writer = metric_writers.create_default_writer(
      workdir, just_logging=not is_first_process)
  # Hyper-parameters are recorded only for a fresh run (start_step == 0), not
  # when resuming from a checkpoint.
  if start_step == 0:
    config_dict = dict(config)
    writer.write_hparams(config_dict)

  if is_first_process and start_step == 0:
    # Dump config file to work dir for easy model loading.
    config_path = os.path.join(workdir, 'config')
    with tf.io.gfile.GFile(config_path, 'wb') as fp:
      pickle.dump(config, fp)

  print('Using state', type(state))
  # Replicate state.
  state = jax_utils.replicate(state)

  learning_rate_fn = create_learning_rate_scheduler(
      factors=config.lr_factors,
      base_learning_rate=config.learning_rate, warmup_steps=config.warmup_steps)

  # Compile multidevice versions of train/eval/predict step fn.
  # For the train step, the two positional arguments (state, batch) are mapped
  # over their leading axis and the state buffers (argument 0) are donated.
  # `ema_momentum` falls back to 0.999 when the config does not define it.
  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          model=model,
          learning_rate_fn=learning_rate_fn,
          clip_grad=config.clip_grad,
          ema_momentum=config.get('ema_momentum', 0.999)),
      axis_name='batch',
      in_axes=(0, 0),
      donate_argnums=(0,))
  p_eval_step = jax.pmap(
      functools.partial(
          eval_step, model=model),
      axis_name='batch')

  # Main Train Loop
  # ---------------------------------------------------------------------------

  # We init the first set of train PRNG keys, but update it afterwards inside
  # the main pmap'd training update for performance.
  # Folding in the process index makes the keys below differ between processes.
  rng = jax.random.fold_in(rng, jax.process_index())
  rng1, rng2, rng3, extensive_eval_rngs, sample_rng = jax.random.split(rng, 5)
  # One key per local device for each of train / eval / test.
  train_rngs = jax.random.split(rng1, jax.local_device_count())
  eval_rngs = jax.random.split(rng2, jax.local_device_count())
  test_rngs = jax.random.split(rng3, jax.local_device_count())
  del rng, rng1, rng2, rng3

  logging.info('Starting training loop.')
  hooks = []
  # `report_progress` is created on every process because its `timed` context
  # manager is used below, but it is only registered as a per-step hook (along
  # with the profiler) on the first process.
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=config.num_train_steps, writer=writer)
  if is_first_process:
    hooks += [
        report_progress,
        periodic_actions.Profile(logdir=workdir, num_profile_steps=5)
    ]
  train_metrics = []

  # Iterator that does epoch-wise indefinite iteration.
  def iterate_train(train_ds):
    """Yields batches from `train_ds` forever, logging each epoch start."""
    epoch = 1
    while True:
      msg = f'Starting epoch {epoch}'
      logging.info(msg)
      for batch in train_ds:
        yield batch
      epoch += 1

  train_iter = iterate_train(train_ds)

  kl_tracker_train = util_fns.KLTracker(num_steps=model.num_steps)
  kl_history = []

  with metric_writers.ensure_flushes(writer):
    # Binds `step` before the loop; it stays at `start_step` if the range
    # below is empty.
    step = start_step
    for step in range(start_step, config.num_train_steps):
      is_last_step = step == config.num_train_steps - 1

      # Shard data to devices and do a training step.
      # NOTE(review): `train_rngs` is never reassigned in this loop, so the
      # same keys are passed on every step; presumably `train_step` derives
      # per-step randomness from the state -- confirm.
      with jax.profiler.StepTraceAnnotation('train', step_num=step):
        batch = common_utils.shard(jax.tree_map(np.asarray, next(train_iter)))
        state, metrics = p_train_step(
            state, batch, rng=train_rngs)
        train_metrics.append(metrics)

      # Quick indication that training is happening.
      logging.log_first_n(logging.INFO, 'Finished training step %d.', 5, step)
      for h in hooks:
        h(step)

      # Periodic metric handling.
      # Skipped on step 0; otherwise runs every `eval_every_steps` steps and on
      # the final step.
      if step > 0 and (step % config.eval_every_steps == 0 or is_last_step):
        with report_progress.timed('training_metrics'):
          logging.info('Gathering training metrics.')
          train_metrics = common_utils.get_metrics(train_metrics)

          # First handle loss terms per step.
          # These two entries are removed from the metrics so that only the
          # remaining ones are averaged below.
          t_batch = train_metrics.pop('t_batch')
          nelbo_per_t_batch = train_metrics.pop('nelbo_per_t_batch')
          kl_tracker_train.update(
              t_batch.reshape(-1), nelbo_per_t_batch.reshape(-1))
          kl_values = kl_tracker_train.get_kl_per_t()
          kl_history.append(np.array(kl_values))
          kl_history = kl_history[-100:]  # Keep last 100 items only.

          # Handle remaining `standard` metrics
          summary = jax.tree_map(jnp.mean, train_metrics)
          summary = {'train_' + k: v for k, v in summary.items()}
          writer.write_scalars(step, summary)
          # Start accumulating afresh for the next reporting window.
          train_metrics = []

        # Both evaluations use the EMA parameters and thread their own rng
        # state through `evaluate`.
        with report_progress.timed('eval'):
          eval_results, eval_rngs = evaluate(
              p_eval_step=p_eval_step,
              params=state.ema_params,
              eval_ds=eval_ds,
              rng=eval_rngs)
          writer.write_scalars(
              step, {'eval_' + k: v for k, v in eval_results.items()})

          test_results, test_rngs = evaluate(
              p_eval_step=p_eval_step,
              params=state.ema_params,
              eval_ds=test_ds,
              rng=test_rngs)
          writer.write_scalars(
              step, {'test_' + k: v for k, v in test_results.items()})

        # This check is nested inside the periodic-evaluation branch, so the
        # detailed evaluation can only run on steps that also triggered the
        # regular evaluation above.
        if step == 1000 or (step > 0 and
                            step % config.detailed_eval_every_steps == 0):
          if is_first_process:
            # Persist the most recent per-t KL values.
            loss_components_path = os.path.join(workdir, 'loss_components')
            with tf.io.gfile.GFile(loss_components_path, 'wb') as fp:
              pickle.dump(kl_history[-1], fp)

          extensive_eval_rngs = extensive_eval(
              config, extensive_eval_rngs, writer, workdir,
              model, state, kl_history, test_ds, step,
              decode_tokens)

        # `sample_rng` is never re-split here; the same key is passed on every
        # call together with the current step.
        with report_progress.timed('generate_text'):
          generate_prediction(sample_rng, config, model, state, writer,
                              decode_tokens, step)

      # Save a checkpoint on one host after every checkpoint_freq steps.
      # Never on step 0; always on the last step.
      save_checkpoint = (
          step > 0 and
          (step % config.checkpoint_every_steps == 0 or is_last_step))
      if config.save_checkpoints and save_checkpoint and is_first_process:
        with report_progress.timed('checkpoint'):
          checkpoints.save_checkpoint(workdir, jax_utils.unreplicate(state),
                                      step, overwrite=True)
Example #22
0
def train_and_evaluate(config, work_dir, try_checkpoint=True):
    """Execute model training and evaluation loop.

    Builds the dataset iterators, model and optimizer, optionally restores a
    checkpoint from `work_dir`, and trains from the restored step up to
    `config.num_train_steps`. Training metrics, evaluation, sampling and
    checkpointing each run on their own `config.*_every_steps` schedule and
    additionally on the last step.

    Note that `config` is mutated: `data_shape`, `num_classes` and
    `sample_rate` are filled in from the dataset metadata.

    Args:
        config: Hyperparameter configuration for training and evaluation.
        work_dir: Directory where the tensorboard summaries are written to.
            Checkpoints are restored from and saved to this directory too.
        try_checkpoint: Should try to load checkpoint (usually enabled,
            practical for debugging purposes to disable).

    Returns:
        The train state (which includes the `.params`). It is returned in its
        device-replicated form; it is never unreplicated before returning.

    Raises:
        ValueError: If `config.dataset.name` does not end with
            'speech_commands09'.
    """
    # Init rng key.
    rng = jax.random.PRNGKey(config.seed)
    data_rng, rng = jax.random.split(rng)
    is_first_host = jax.process_index() == 0

    if config.dataset.name.endswith('speech_commands09'):
        ds, ds_metadata = input_pipeline_sc09.get_dataset(data_rng, config)
    else:
        raise ValueError(f'Unknown dataset {config.dataset.name}.')

    # Immediately create infinite iterators.
    it = jax.tree_map(util_fns.get_iterator, ds)

    # TODO(agritsenko): Can we fix the ugly nested dicts?
    # The first two dimensions of the input shape are dropped.
    # NOTE(review): presumably those are the device and batch axes -- confirm.
    config.data_shape = ds_metadata['train']['shape']['inputs'][2:]
    config.num_classes = ds_metadata['train']['num_classes']
    config.sample_rate = ds_metadata['train']['sample_rate']

    # Only process 0 gets a real writer; the others just log.
    writer = metric_writers.create_default_writer(
        work_dir, just_logging=jax.process_index() > 0)
    rng, init_rng = jax.random.split(rng)

    model, variables = model_setup(init_rng, config)

    # From now on we want different rng across hosts:
    rng = jax.random.fold_in(rng, jax.process_index())

    def tx_fn(lr):
        """Returns an AdamW optimizer configured for learning rate `lr`."""
        return optax.adamw(lr,
                           b1=0.9,
                           b2=config.beta2,
                           eps=1e-08,
                           eps_root=0.0,
                           weight_decay=config.weight_decay)

    state = language_train_state.TrainState.create(params=variables['params'],
                                                   tx_fn=tx_fn)

    # Any falsy restored step (including `None` when restoring is disabled)
    # is treated as "start from step 0".
    start_step = None
    if try_checkpoint:
        state, start_step = checkpoint.restore_from_path(work_dir, state)
    start_step = start_step or 0

    # Use different rngs for train & eval.
    rng_train, rng_eval, rng_sample = jax.random.split(rng, 3)

    kl_tracker = util_fns.KLTracker(num_steps=model.num_steps)
    kl_history = []

    learning_rate_fn = train_utils.create_learning_rate_scheduler(
        **config.learning_rate)
    # Positional arguments are (rng, batch, state): the rng is broadcast, the
    # batch and state are mapped over their leading axis, and the state
    # buffers (argument 2) are donated. Outputs are (state, metrics, rng) with
    # the rng returned un-mapped.
    p_train_step = jax.pmap(functools.partial(
        train_step,
        config=config,
        learning_rate_fn=learning_rate_fn,
        model=model),
                            axis_name='batch',
                            in_axes=(None, 0, 0),
                            out_axes=(0, 0, None),
                            donate_argnums=(2, ))

    # The only axes that are broadcasted are the in- and output rng key ones. The
    # rng is the first arg, and the last return value.
    p_eval_step = jax.pmap(functools.partial(eval_step, model=model),
                           axis_name='batch',
                           in_axes=(None, 0, 0),
                           out_axes=(0, 0, None))

    # Training length.
    logging.info('Training will start from step %d', start_step)

    # Replicate state.
    state = flax.jax_utils.replicate(state)

    # Setup hooks.
    # `report_progress` is created on every host because its `timed` context
    # manager is used below, but it is only registered as a per-step hook
    # (along with the profiler) on the first host.
    hooks = []
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=config.num_train_steps, writer=writer)
    if is_first_host:
        hooks += [
            report_progress,
            periodic_actions.Profile(logdir=work_dir, num_profile_steps=5)
        ]

    with metric_writers.ensure_flushes(writer):
        # Per-step metrics accumulated since the last time they were logged.
        batch_metrics = []
        for step in range(start_step, config.num_train_steps):
            logging.log_first_n(logging.INFO, f'Train step: {step}', 5)
            with jax.profiler.StepTraceAnnotation('train', step_num=step):
                state, metrics, rng_train = p_train_step(
                    rng_train, next(it['train']), state)
            batch_metrics.append(metrics)

            # Cycle through hooks.
            for h in hooks:
                h(step)

            is_last_step = step == config.num_train_steps - 1

            # All `step % N == 0` schedules below also fire on step 0.
            # NOTE(review): unlike the other schedules, `log_every_steps` has
            # no truthiness guard, so a value of 0 would raise
            # ZeroDivisionError here -- confirm it is always positive.
            if (step % config.log_every_steps == 0) or is_last_step:
                with report_progress.timed('training_metrics'):
                    ################### Process batch metrics ############################
                    batch_metrics = jax.device_get(
                        flax.jax_utils.unreplicate(batch_metrics))

                    # `metrics` here is still the output of the most recent
                    # train step; the comprehensions below use `metrics` as
                    # their own loop variable and do not rebind it.
                    if 't_batch' in metrics:
                        # TODO(agritsenko): Factor out into a separate function.
                        # This processes the loss per t, although two nested for-loops
                        # (counting the one inside kl_tracker), it actually does not hurt
                        # timing performance meaningfully.
                        batch_t = [
                            metrics['t_batch'].reshape(-1)
                            for metrics in batch_metrics
                        ]
                        batch_nelbo_per_t = [
                            metrics['nelbo_per_t_batch'].reshape(-1)
                            for metrics in batch_metrics
                        ]
                        for t, nelbo_per_t in zip(batch_t, batch_nelbo_per_t):
                            kl_tracker.update(t, nelbo_per_t)

                    ################### Average scalar metrics ###########################
                    # Average every metric over the accumulated steps, skipping
                    # any key that contains 'batch' (handled above).
                    metrics = {
                        key:
                        np.mean([metrics[key] for metrics in batch_metrics])
                        for key in batch_metrics[0] if 'batch' not in key
                    }

                    # Metric logging.
                    if is_first_host:
                        log_standard_metrics(writer,
                                             step,
                                             train_metrics=metrics)
                    # Start accumulating afresh for the next logging window.
                    batch_metrics = []

            if config.eval_every_steps and (
                (step % config.eval_every_steps == 0) or is_last_step):
                with report_progress.timed('eval'):
                    ####################### Run evaluation ###############################
                    # The number of eval batches is the dataset's batch count
                    # times `num_eval_passes` (1 when not configured).
                    metrics, rng_eval = eval_model(
                        p_eval_step, rng_eval, state, it['eval'],
                        (ds_metadata['eval']['num_batches'] *
                         config.get('num_eval_passes', 1)))

                    # Metric logging.
                    if is_first_host:
                        log_standard_metrics(writer,
                                             step,
                                             eval_metrics=metrics)

                # Track KL (unrelated to the eval, but nice to not do every step).
                # Only the 50 most recent snapshots are kept.
                kl_values = kl_tracker.get_kl_per_t()
                kl_history.append(np.array(kl_values))
                kl_history = kl_history[-50:]

            if config.sample_every_steps and (
                (step % config.sample_every_steps == 0) or is_last_step):
                with report_progress.timed('sample'):
                    ######################### Run sampling ###############################
                    # Sampling uses the EMA parameters and a key derived from
                    # `rng_sample` and the current step.
                    chain = model.sample(jax.random.fold_in(rng_sample, step),
                                         state.ema_params,
                                         config.sample_batch_size,
                                         chain_out_size=config.get(
                                             'chain_out_size',
                                             model.num_stages))

                    if is_first_host:
                        chain = jax.device_get(chain)
                        # Flatten the last element of the chain into a single
                        # sequence and rescale it with 2 * x / num_classes - 1
                        # before writing it as audio.
                        long_sample = np.reshape(chain[-1],
                                                 (1, -1, 1)).astype(np.float32)
                        long_sample = (2. *
                                       long_sample) / config.num_classes - 1.
                        writer.write_audios(step, {'samples': long_sample},
                                            sample_rate=config.sample_rate)

            ######################### Checkpointing #################################
            # Only the first host writes checkpoints.
            if is_first_host and config.checkpoint_every_steps and (
                (step % config.checkpoint_every_steps == 0) or is_last_step):
                logging.info('Saving checkpoint: step %d', step)
                with report_progress.timed('checkpoint'):
                    checkpoint.save_checkpoint(
                        work_dir,
                        state=flax.jax_utils.unreplicate(state),
                        step=step)
                logging.info('Finished saving checkpoint: step %d', step)

        return state
def main(argv):
  """Trains a model with periodic checkpointing, logging and evaluation.

  Reads the experiment configuration from `FLAGS.config` and the output
  directory from `FLAGS.workdir`. The flow is: build the train and validation
  input pipelines, create the model and initialise its parameters on CPU, build
  the optimizer, optionally resume from a checkpoint or load initial weights,
  replicate the optimizer across local devices and run the training loop.
  Inside the loop it periodically saves checkpoints (asynchronously, through a
  thread pool), reports training measurements, runs validation and, when the
  config has a "fewshot" entry, few-shot evaluation.

  Args:
    argv: Command-line arguments; unused.

  Raises:
    AssertionError: If `keep_checkpoint_steps` is set without
      `checkpoint_steps` or is not a multiple of it, or if both `num_epochs`
      and `total_steps` are set.
    ValueError: If either batch size is not divisible by the global device
      count.
  """
  del argv

  config = FLAGS.config
  workdir = FLAGS.workdir

  logging.info("Workdir: %s", workdir)

  # Checkpoints are only written when `checkpoint_steps` is configured.
  save_checkpoint_path = None
  if config.get("checkpoint_steps"):
    tf.io.gfile.makedirs(workdir)
    save_checkpoint_path = os.path.join(workdir, "checkpoint.npz")

  # The pool is used to perform misc operations such as logging in async way.
  pool = multiprocessing.pool.ThreadPool()

  # This seed makes the Jax part of things (like model init) deterministic.
  # However, full training still won't be deterministic, for example due to the
  # tf.data pipeline not being deterministic even if we would set TF seed.
  rng = jax.random.PRNGKey(config.get("seed", 0))

  def write_note(note):
    """Logs `note`, but only on host 0."""
    if jax.host_id() == 0:
      logging.info("NOTE: %s", note)
  write_note("Initializing...")

  # Verify settings to make sure no checkpoints are accidentally missed.
  if config.get("keep_checkpoint_steps"):
    assert config.get("checkpoint_steps"), "Specify `checkpoint_steps`."
    assert config.keep_checkpoint_steps % config.checkpoint_steps == 0, (
        f"`keep_checkpoint_steps` ({config.keep_checkpoint_steps}) should be "
        f"divisible by `checkpoint_steps` ({config.checkpoint_steps}).")

  batch_size = config.batch_size
  batch_size_eval = config.get("batch_size_eval", batch_size)
  if (batch_size % jax.device_count() != 0 or
      batch_size_eval % jax.device_count() != 0):
    raise ValueError(f"Batch sizes ({batch_size} and {batch_size_eval}) must "
                     f"be divisible by device number ({jax.device_count()})")

  # Each host feeds an equal share of the global batch.
  local_batch_size = batch_size // jax.host_count()
  local_batch_size_eval = batch_size_eval // jax.host_count()
  logging.info(
      "Global batch size %d on %d hosts results in %d local batch size. "
      "With %d dev per host (%d dev total), that's a %d per-device batch size.",
      batch_size, jax.host_count(), local_batch_size,
      jax.local_device_count(), jax.device_count(),
      local_batch_size // jax.local_device_count())

  write_note("Initializing train dataset...")
  train_ds = input_pipeline.get_data(
      dataset=config.dataset,
      split=config.train_split,
      data_dir=fillin(config.get("dataset_dir")),
      batch_size=local_batch_size,
      preprocess_fn=pp_builder.get_preprocess_fn(config.pp_train),
      shuffle_buffer_size=config.shuffle_buffer_size,
      prefetch=config.get("prefetch_to_host", 2),
      cache=False)

  # Start prefetching already.
  train_iter = u.start_input_pipeline(
      train_ds, config.get("prefetch_to_device", 1), pad=local_batch_size)
  # We always pad to local_batch_size_eval even when less would be enough in
  # order to minimize memory fragmentation.

  write_note("Initializing val dataset(s)...")
  def _get_val_split(dataset, split, pp_eval, data_dir=None):
    """Builds the input pipeline for one validation split.

    Returns:
      A tuple `(val_it, val_steps)`, where `val_steps` is the number of batches
      of size `batch_size_eval` needed to cover the split, rounded up.
    """
    # We do ceil rounding such that we include the last incomplete batch.
    nval_img = input_pipeline.get_num_examples(
        dataset, split, data_dir=fillin(data_dir))
    val_steps = int(np.ceil(nval_img / batch_size_eval))
    logging.info("Running validation for %d steps for %s, %s", val_steps,
                 dataset, split)

    val_it = input_pipeline.get_data(
        dataset=dataset,
        split=split,
        data_dir=fillin(data_dir),
        batch_size=local_batch_size_eval,
        preprocess_fn=pp_builder.get_preprocess_fn(pp_eval),
        cache=config.get("val_cache", "batched"),
        repeat_after_batching=True,
        prefetch=0,  # Save memory since we cache.
        drop_remainder=False,
        shuffle_files=False)
    val_it = u.start_input_pipeline(
        val_it, config.get("prefetch_to_device", 1), pad=local_batch_size_eval)

    return (val_it, val_steps)

  # `val_split` is either a single split name (evaluated under the name "val")
  # or a sequence of tuples `(name, dataset, split, pp_eval[, data_dir])`.
  if isinstance(config.val_split, str):
    val_ds = {"val": _get_val_split(config.dataset, config.val_split,
                                    config.pp_eval, config.get("dataset_dir"))}
  else:
    val_ds = {t[0]: _get_val_split(*t[1:]) for t in config.val_split}

  ntrain_img = input_pipeline.get_num_examples(
      config.dataset, config.train_split,
      data_dir=fillin(config.get("dataset_dir")))
  # True division: this is a float and may be fractional.
  steps_per_epoch = ntrain_img / batch_size

  # Exactly one of `num_epochs` / `total_steps` determines the training length.
  if config.get("num_epochs"):
    total_steps = int(config.num_epochs * steps_per_epoch)
    assert not config.get("total_steps"), "Set either num_epochs or total_steps"
  else:
    total_steps = config.total_steps

  logging.info(
      "Running for %d steps, that means %f epochs and %f steps per epoch",
      total_steps, total_steps * batch_size / ntrain_img, steps_per_epoch)
  # NOTE(review): `xm_xp` and `xm_wu` are not defined in this function; they
  # must be module-level names -- confirm they are set before `main` runs.
  mw = u.BigVisionMetricWriter(xm_xp.id, xm_wu.id, steps_per_epoch)

  write_note(f"Initializing {config.model_name} model...")
  model_mod = importlib.import_module(f"{BASEDIR}.models.{config.model_name}")
  model = model_mod.Model(
      num_classes=config.num_classes, **config.get("model", {}))

  # We want all parameters to be created in host RAM, not on any device, they'll
  # be sent there later as needed, otherwise we already encountered two
  # situations where we allocate them twice.
  @partial(jax.jit, backend="cpu")
  def init(rng):
    """Initialises the model parameters from a dummy all-zeros batch."""
    image_size = tuple(train_ds.element_spec["image"].shape[1:])
    dummy_input = jnp.zeros((local_batch_size,) + image_size, jnp.float32)
    params = flax.core.unfreeze(model.init(rng, dummy_input))["params"]

    # Set bias in the head to a low value, such that loss is small initially.
    params["head"]["bias"] = jnp.full_like(
        params["head"]["bias"], config.get("init_head_bias", 0))

    return params

  rng, rng_init = jax.random.split(rng)
  params_cpu = init(rng_init)

  if jax.host_id() == 0:
    num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0])
    parameter_overview.log_parameter_overview(params_cpu)
    mw.measure("num_params", num_params)

  @partial(jax.pmap, axis_name="batch")
  def evaluation_fn(params, images, labels, mask):
    """Returns `(ncorrect, loss, n)`, each summed over the "batch" axis."""
    # Ignore the entries with all zero labels for evaluation.
    mask *= labels.max(axis=1)
    logits, _ = model.apply({"params": flax.core.freeze(params)}, images)

    losses = getattr(u, config.get("loss", "sigmoid_xent"))(
        logits=logits, labels=labels, reduction=False)
    loss = jax.lax.psum(losses * mask, axis_name="batch")

    top1_idx = jnp.argmax(logits, axis=1)
    # Extracts the label at the highest logit index for each image.
    top1_correct = jnp.take_along_axis(labels, top1_idx[:, None], axis=1)[:, 0]
    ncorrect = jax.lax.psum(top1_correct * mask, axis_name="batch")
    n = jax.lax.psum(mask, axis_name="batch")
    return ncorrect, loss, n

  # Setup function for computing representation.
  @partial(jax.pmap, axis_name="batch")
  def representation_fn(params, images, labels, mask):
    """Returns the configured representation layer, labels and mask.

    All three outputs are gathered across the "batch" axis.
    """
    _, outputs = model.apply({"params": flax.core.freeze(params)}, images)
    representation = outputs[config.fewshot.representation_layer]
    representation = jax.lax.all_gather(representation, "batch")
    labels = jax.lax.all_gather(labels, "batch")
    mask = jax.lax.all_gather(mask, "batch")
    return representation, labels, mask

  # Load the optimizer either from our folder or from flax.
  opt_name = config.get("optim_name", "momentum_hp")
  write_note(f"Initializing {opt_name} optimizer...")
  try:
    opt_mod = importlib.import_module(f"{BASEDIR}.optims.{opt_name}")
    opt_def = opt_mod.Optimizer(**config.get("optim", {}))
  except ModuleNotFoundError:
    opt_def = getattr(flax.optim, opt_name)(**config.get("optim", {}))

  # We jit this, such that the arrays that are created are created on the same
  # device as the input is, in this case the CPU. Else they'd be on device[0].
  opt_cpu = jax.jit(opt_def.create)(params_cpu)

  # The optimizer buffers (argument 0) are donated to the update.
  @partial(jax.pmap, axis_name="batch", donate_argnums=(0,))
  def update_fn(opt, lr, images, labels, rng):
    """Update step.

    Optionally applies mixup, computes the (possibly accumulated) loss and
    gradients, averages them across the "batch" axis, optionally clips the
    global gradient norm, applies the gradients and then weight decay.

    Returns:
      A tuple `(opt, l, rng, measurements)`: the updated optimizer, the loss,
      the advanced rng and a dict of extra scalar measurements.
    """

    measurements = {}

    if config.get("mixup") and config.mixup.p:
      rng, (images, labels), _ = u.mixup(rng, images, labels, **config.mixup)

    # Get device-specific loss rng.
    rng, rng_model = jax.random.split(rng, 2)
    rng_model_local = jax.random.fold_in(rng_model, jax.lax.axis_index("batch"))

    def loss_fn(params, images, labels):
      """Returns the configured loss for a training-mode forward pass."""
      logits, _ = model.apply(
          {"params": flax.core.freeze(params)}, images,
          train=True, rngs={"dropout": rng_model_local})
      return getattr(u, config.get("loss", "sigmoid_xent"))(
          logits=logits, labels=labels)

    # Implementation considerations compared and summarized at
    # https://docs.google.com/document/d/1g3kMEvqu1DOawaflKNyUsIoQ4yIVEoyE5ZlIPkIl4Lc/edit?hl=en#
    l, g = u.accumulate_gradient(jax.value_and_grad(loss_fn), opt.target,
                                 images, labels,
                                 config.get("grad_accum_steps"))
    l, g = jax.lax.pmean((l, g), axis_name="batch")

    # Log the gradient norm only if we need to compute it anyways (clipping)
    # or if we don't use grad_accum_steps, as they interact badly.
    if config.get("grad_accum_steps", 1) == 1 or config.get("grad_clip_norm"):
      grads, _ = jax.tree_flatten(g)
      l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads]))
      measurements["l2_grads"] = l2_g

    # Optionally resize the global gradient to a maximum norm. We found this
    # useful in some cases across optimizers, hence it's in the main loop.
    if config.get("grad_clip_norm"):
      g_factor = jnp.minimum(1.0, config.grad_clip_norm / l2_g)
      g = jax.tree_map(lambda p: g_factor * p, g)
    opt = opt.apply_gradient(g, learning_rate=lr)

    # `weight_decay` is either a list of (regex, value) rules or a single
    # number, which is applied to every parameter matching ".*kernel.*".
    decay_rules = config.get("weight_decay", []) or []
    if isinstance(decay_rules, numbers.Number):
      decay_rules = [(".*kernel.*", decay_rules)]
    # With `weight_decay_decouple` the decay is scaled by the learning rate
    # relative to its base value rather than by the learning rate itself.
    sched_m = lr/config.lr.base if config.get("weight_decay_decouple") else lr
    def decay_fn(v, wd):
      """Shrinks `v` by the scheduled weight-decay factor."""
      return (1.0 - sched_m * wd) * v
    opt = opt.replace(target=u.tree_map_with_regex(
        decay_fn, opt.target, decay_rules, name="weight decay"))

    params, _ = jax.tree_flatten(opt.target)
    measurements["l2_params"] = jnp.sqrt(sum([jnp.vdot(p, p) for p in params]))

    return opt, l, rng, measurements

  # Other things besides optimizer state to be stored.
  checkpoint_extra = dict(accum_train_time=0.0)

  # Decide how to initialize training. The order is important.
  # 1. Always resumes from the existing checkpoint, e.g. resumes a finetune job.
  # 2. Resume from a previous checkpoint, e.g. start a cooldown training job.
  # 3. Initialize model from something, e.g. start a fine-tuning job.
  # 4. Train from scratch.
  resume_checkpoint_path = None
  if save_checkpoint_path and tf.io.gfile.exists(save_checkpoint_path):
    resume_checkpoint_path = save_checkpoint_path
  elif config.get("resume"):
    resume_checkpoint_path = fillin(config.resume)
  if resume_checkpoint_path:
    write_note("Resume training from checkpoint...")
    checkpoint = {"opt": opt_cpu, "extra": checkpoint_extra}
    _, checkpoint_tree = jax.tree_flatten(checkpoint)
    loaded = u.load_checkpoint(checkpoint_tree, resume_checkpoint_path)
    # bfloat16 type gets lost when data is saved to disk, so we recover it.
    checkpoint = jax.tree_map(u.recover_dtype, loaded)
    opt_cpu, checkpoint_extra = checkpoint["opt"], checkpoint["extra"]
  elif config.get("model_init"):
    write_note(f"Initialize model from {config.model_init}...")
    loaded = model_mod.load(params_cpu, config.model_init, config.get("model"))
    opt_cpu = opt_cpu.replace(target=loaded)
    if jax.host_id() == 0:
      logging.info("Restored parameter overview:")
      parameter_overview.log_parameter_overview(loaded)

  write_note("Kicking off misc stuff...")
  first_step = int(opt_cpu.state.step)  # Might be a DeviceArray type.
  chrono = u.Chrono(first_step, total_steps, batch_size,
                    checkpoint_extra["accum_train_time"])
  # Note: switch to ProfileAllHosts() if you need to profile all hosts.
  # (Xprof data become much larger and take longer to load for analysis)
  profiler = periodic_actions.Profile(
      # Create profile after every restart to analyze pre-emption related
      # problems and assure we get similar performance in every run.
      logdir=workdir, first_profile=first_step + 10)

  # Prepare the learning-rate and pre-fetch it to device to avoid delays.
  lr_fn = u.create_learning_rate_schedule(
      batch_size, total_steps, steps_per_epoch, **config.get("lr", {}))
  lr_iter = u.prefetch_scalar(map(lr_fn, range(first_step, total_steps)),
                              config.get("prefetch_to_device", 1))

  write_note(f"Replicating...\n{chrono.note}")
  opt_repl = flax_utils.replicate(opt_cpu)

  write_note(f"Initializing few-shotters...\n{chrono.note}")
  # `fewshotter` only exists when the config has a "fewshot" entry; its use in
  # the loop below is guarded by the same check.
  if "fewshot" in config:
    fewshotter = fewshot.FewShotEvaluator(
        representation_fn, config.fewshot,
        config.fewshot.get("batch_size") or batch_size_eval)

  rng, rng_loop = jax.random.split(rng, 2)
  rngs_loop = flax_utils.replicate(rng_loop)
  # Handle of the most recent asynchronous checkpoint save, if any.
  checkpoint_writer = None

  write_note(f"First step compilations...\n{chrono.note}")
  # Using a python integer for step here, because opt.state.step is allocated
  # on TPU during replication.
  # Steps are 1-based: the first executed step is `first_step + 1`.
  for step, train_batch, lr_repl in zip(
      range(first_step + 1, total_steps + 1), train_iter, lr_iter):
    mw.step_start(step)

    with jax.profiler.TraceContext("train_step", step_num=step, _r=1):
      opt_repl, loss_value, rngs_loop, extra_measurements = update_fn(
          opt_repl,
          lr_repl,
          train_batch["image"],
          train_batch["labels"],
          rng=rngs_loop)

    if jax.host_id() == 0:
      profiler(step)

    # Checkpoint saving
    if u.itstime(step, config.get("checkpoint_steps"), total_steps, host=0):
      chrono.pause()
      u.checkpointing_timeout(checkpoint_writer,
                              config.get("checkpoint_timeout", 1))
      checkpoint_extra["accum_train_time"] = chrono.accum_train_time
      # We need to transfer the weights over now or else we risk keeping them
      # alive while they'll be updated in a future step, creating hard to debug
      # memory errors (see b/160593526). Also, takes device 0's params only.
      opt_cpu = jax.tree_map(lambda x: np.array(x[0]), opt_repl)

      # Check whether we want to keep a copy of the current checkpoint.
      copy_step = None
      if u.itstime(step, config.get("keep_checkpoint_steps"), total_steps):
        copy_step = step

      # Checkpoint should be a nested dictionary or FLAX dataclasses from
      # `flax.struct`. Both can be present in a checkpoint.
      checkpoint = {"opt": opt_cpu, "extra": checkpoint_extra}
      # The save runs on the thread pool so that training can continue.
      checkpoint_writer = pool.apply_async(
          u.save_checkpoint, (checkpoint, save_checkpoint_path, copy_step))
      chrono.resume()

    # Report training progress
    if u.itstime(step, config.log_training_steps, total_steps, host=0):
      # Values are replicated across devices; report the first device's copy.
      mw.measure("learning_rate", lr_repl[0])
      mw.measure("training_loss", loss_value[0])
      for name, value in extra_measurements.items():
        mw.measure(name, value[0])
      chrono.tick(step, mw.measure, write_note)

    # Report validation performance
    if u.itstime(step, config.log_eval_steps, total_steps):
      chrono.pause()
      for val_name, (val_iter, val_steps) in val_ds.items():
        ncorrect, loss, nseen = 0, 0, 0
        # `zip` with the range limits the iterator to `val_steps` batches.
        for _, batch in zip(range(val_steps), val_iter):
          batch_ncorrect, batch_losses, batch_n = evaluation_fn(
              opt_repl.target, batch["image"], batch["labels"], batch["mask"])
          # All results are a replicated array shaped as follows:
          # (local_devices, per_device_batch_size, elem_shape...)
          # with each local device's entry being identical as they got psum'd.
          # So let's just take the first one to the host as numpy.
          ncorrect += np.sum(np.array(batch_ncorrect[0]))
          loss += np.sum(np.array(batch_losses[0]))
          nseen += np.sum(np.array(batch_n[0]))
        mw.measure(f"{val_name}_prec@1", ncorrect / nseen)
        mw.measure(f"{val_name}_loss", loss / nseen)
      chrono.resume()

    if "fewshot" in config:
      # Compute few-shot on-the-fly evaluation.
      if u.itstime(step, config.fewshot.log_steps, total_steps):
        chrono.pause()
        write_note(f"Few-shot evaluation...\n{chrono.note}")
        r = fewshotter.run_all(opt_repl.target, config.fewshot.datasets)
        fewshotter.walk_results(mw.measure, *r)
        chrono.resume()
    mw.step_end()

  write_note(f"Done!\n{chrono.note}")
  # Closing and joining the pool waits for any in-flight checkpoint save.
  pool.close()
  pool.join()
  mw.close()