Example #1
def restore_model_and_put_to_devices(
    config: ml_collections.ConfigDict,
    workdir: str,
    partition_specs: Sequence[PartitionSpec],
    model: flax.nn.Module,
    optimizer: flax.optim.Optimizer,
    train_iter: Iterable[Any],
    rngs: Mapping[str, jnp.ndarray],
    thread_pool: multiprocessing.pool.ThreadPool,
) -> Tuple[flax.optim.Optimizer, Iterable[Any], jnp.ndarray, Mapping[str, Any]]:
  """Restores from latest available checkpoint and puts model to devices."""
  (optimizer, train_iter, rng_state_tf, rngs,
   global_state) = train.restore_checkpoints(
       workdir=workdir,
       step=None,
       partition_specs=partition_specs,
       optimizer=optimizer,
       train_iter=train_iter,
       rng_state_tf=tf.random.get_global_generator().state.numpy(),
       rng_state_jax=rngs,
       global_state={},
       thread_pool=thread_pool)
  if global_state:
    # 1. If a checkpoint is present in the current work dir, continue training.
    logging.info("Continuing training from step %d", global_state["step"])
    # Shard parameters and optim state and put to the corresponding device.
    optimizer = core.tree_shard(optimizer)
  elif config.get("model_init_prefix"):
    # 2. Alternatively, initialize from the given model_init_prefix checkpoint.
    logging.info("Fine-tuning model from %r...", config.model_init_prefix)
    if not hasattr(model, "load"):
      # Note: Likely due to use of .partial, model may end up being e.g.
      # a flax.nn.Base.PatchTransformer instead of experts_nn.PatchTransformer.
      # This causes explicit checks for class equivalence to fail, and also
      # causes static type checking to fail. Checking for the .load attribute
      # circumvents both issues.
      raise ValueError((f"Loaded model {model} has no load method. Are you sure"
                        " it is one of 'PatchTransformer' or 'Resformer'?"))
    restored_params = model.load(
        prefix=config.model_init_prefix,
        init_params=optimizer.target,
        model_params=config.model,
        keep_head=config.get("keep_head", False),
        partition_specs=partition_specs)
    # Shard restored parameters and replicate original optimizer state.
    optimizer = optimizer.replace(
        target=core.tree_shard(restored_params),
        state=flax.jax_utils.replicate(optimizer.state))
    global_state = {"step": 0, "accum_train_time": 0.0}
  else:
    # 3. Use model initialized from scratch.
    logging.info("Initializing training from scratch...")
    optimizer = flax.jax_utils.replicate(optimizer)
    global_state = {"step": 0, "accum_train_time": 0.0}
  # Set TF's global RNG generator and JAX's per-device RNG keys.
  train.rng_tf_set_global_generator(rng_state_tf)
  rngs_per_device = jax.tree_map(train.rng_jax_fold_host_if_needed_and_shard,
                                 rngs)
  return optimizer, train_iter, rngs_per_device, global_state
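
# Usage sketch (comment-only, taken from the fuller example below): the caller
# is expected to have already built `config`, `workdir`, `partition_specs`,
# `model`, the optimizer `opt`, the training dataset `train_ds`, the JAX RNGs
# `rngs` and a thread `pool`:
#
#   opt, train_iter, rngs_per_device, global_state = (
#       restore_model_and_put_to_devices(
#           config, workdir, partition_specs, model, opt, iter(train_ds),
#           rngs, pool))
#   first_step = global_state["step"]
#   accum_train_time = global_state["accum_train_time"]
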
  def init(rng):
    image_size = tuple(train_ds.element_spec["image"].shape[1:])
    dummy_input = jnp.zeros((local_batch_size,) + image_size, jnp.float32)
    params = flax.core.unfreeze(model.init(rng, dummy_input))["params"]

    # Set bias in the head to a low value, such that loss is small initially.
    params["head"]["bias"] = jnp.full_like(
        params["head"]["bias"], config.get("init_head_bias", 0))

    return params
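
# Self-contained illustration (a sketch, not part of the original example) of
# the head-bias trick used in `init` above: filling the classifier bias with a
# low value keeps the initial loss small. The shapes below are made up.
#
#   import jax.numpy as jnp
#
#   params = {"head": {"kernel": jnp.ones((8, 1000)),
#                      "bias": jnp.zeros((1000,))}}
#   params["head"]["bias"] = jnp.full_like(params["head"]["bias"], -10.0)
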
  def update_fn(opt, lr, images, labels, rng):
    """Update step."""

    measurements = {}

    if config.get("mixup") and config.mixup.p:
      rng, (images, labels), _ = u.mixup(rng, images, labels, **config.mixup)

    # Get device-specific loss rng.
    rng, rng_model = jax.random.split(rng, 2)
    rng_model_local = jax.random.fold_in(rng_model, jax.lax.axis_index("batch"))

    def loss_fn(params, images, labels):
      logits, _ = model.apply(
          {"params": flax.core.freeze(params)}, images,
          train=True, rngs={"dropout": rng_model_local})
      return getattr(u, config.get("loss", "sigmoid_xent"))(
          logits=logits, labels=labels)

    # Implementation considerations compared and summarized at
    # https://docs.google.com/document/d/1g3kMEvqu1DOawaflKNyUsIoQ4yIVEoyE5ZlIPkIl4Lc/edit?hl=en#
    l, g = u.accumulate_gradient(jax.value_and_grad(loss_fn), opt.target,
                                 images, labels,
                                 config.get("grad_accum_steps"))
    l, g = jax.lax.pmean((l, g), axis_name="batch")

    # Log the gradient norm only if we need to compute it anyway (clipping)
    # or if we don't use grad_accum_steps, as they interact badly.
    if config.get("grad_accum_steps", 1) == 1 or config.get("grad_clip_norm"):
      grads, _ = jax.tree_flatten(g)
      l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads]))
      measurements["l2_grads"] = l2_g

    # Optionally resize the global gradient to a maximum norm. We found this
    # useful in some cases across optimizers, hence it's in the main loop.
    if config.get("grad_clip_norm"):
      g_factor = jnp.minimum(1.0, config.grad_clip_norm / l2_g)
      g = jax.tree_map(lambda p: g_factor * p, g)
    opt = opt.apply_gradient(g, learning_rate=lr)

    decay_rules = config.get("weight_decay", []) or []
    if isinstance(decay_rules, numbers.Number):
      decay_rules = [(".*kernel.*", decay_rules)]
    sched_m = lr/config.lr.base if config.get("weight_decay_decouple") else lr
    def decay_fn(v, wd):
      return (1.0 - sched_m * wd) * v
    opt = opt.replace(target=u.tree_map_with_regex(
        decay_fn, opt.target, decay_rules, name="weight decay"))

    params, _ = jax.tree_flatten(opt.target)
    measurements["l2_params"] = jnp.sqrt(sum([jnp.vdot(p, p) for p in params]))

    return opt, l, rng, measurements
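
# Usage sketch (comment-only, mirroring the fuller example below): `update_fn`
# is wrapped in jax.pmap over the "batch" axis with the optimizer buffers
# donated, and called once per step on the replicated optimizer, learning
# rate, batch and per-device RNGs:
#
#   update_fn = jax.pmap(update_fn, axis_name="batch", donate_argnums=(0,))
#   opt_repl, loss_value, rngs_loop, measurements = update_fn(
#       opt_repl, lr_repl, train_batch["image"], train_batch["labels"],
#       rng=rngs_loop)
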
  def _get_val_split(dataset, split, pp_eval, data_dir=None):
    # We do ceil rounding such that we include the last incomplete batch.
    nval_img = input_pipeline.get_num_examples(
        dataset, split, data_dir=fillin(data_dir))
    val_steps = int(np.ceil(nval_img / batch_size_eval))
    logging.info("Running validation for %d steps for %s, %s", val_steps,
                 dataset, split)

    val_it = input_pipeline.get_data(
        dataset=dataset,
        split=split,
        data_dir=fillin(data_dir),
        batch_size=local_batch_size_eval,
        preprocess_fn=pp_builder.get_preprocess_fn(pp_eval),
        cache=config.get("val_cache", "batched"),
        repeat_after_batching=True,
        prefetch=0,  # Save memory since we cache.
        drop_remainder=False,
        shuffle_files=False)
    val_it = u.start_input_pipeline(
        val_it, config.get("prefetch_to_device", 1), pad=local_batch_size_eval)

    return (val_it, val_steps)
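
# Usage sketch (comment-only, taken from the fuller example below):
# `config.val_split` is either a single split name or a list of
# (name, dataset, split, pp_eval, data_dir) tuples, and each entry is turned
# into a (val_iter, val_steps) pair:
#
#   if isinstance(config.val_split, str):
#     val_ds = {"val": _get_val_split(config.dataset, config.val_split,
#                                     config.pp_eval, config.get("dataset_dir"))}
#   else:
#     val_ds = {t[0]: _get_val_split(*t[1:]) for t in config.val_split}
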
  def evaluation_fn(params, images, labels, mask):
    # Ignore the entries with all zero labels for evaluation.
    mask *= labels.max(axis=1)
    logits, _ = model.apply({"params": flax.core.freeze(params)}, images)

    losses = getattr(u, config.get("loss", "sigmoid_xent"))(
        logits=logits, labels=labels, reduction=False)
    loss = jax.lax.psum(losses * mask, axis_name="batch")

    top1_idx = jnp.argmax(logits, axis=1)
    # Extracts the label at the highest logit index for each image.
    top1_correct = jnp.take_along_axis(labels, top1_idx[:, None], axis=1)[:, 0]
    ncorrect = jax.lax.psum(top1_correct * mask, axis_name="batch")
    n = jax.lax.psum(mask, axis_name="batch")
    return ncorrect, loss, n
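
# Usage sketch (comment-only, mirroring the fuller example below): after
# pmapping `evaluation_fn` over the "batch" axis, the psum'd results are
# identical on every local device, so entry [0] is pulled to the host and
# accumulated over the validation iterator:
#
#   ncorrect, loss, nseen = 0, 0, 0
#   for _, batch in zip(range(val_steps), val_iter):
#     batch_ncorrect, batch_losses, batch_n = evaluation_fn(
#         opt_repl.target, batch["image"], batch["labels"], batch["mask"])
#     ncorrect += np.sum(np.array(batch_ncorrect[0]))
#     loss += np.sum(np.array(batch_losses[0]))
#     nseen += np.sum(np.array(batch_n[0]))
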
Example #6
def main(_):
  config = flags.FLAGS.config
  workdir = flags.FLAGS.workdir
  tf.io.gfile.makedirs(workdir)

  partition_specs = []

  # Loss to apply.
  loss_to_apply = getattr(core, config.get("loss_to_apply", "softmax_xent"))
  compute_ece = config.get("compute_ece", False)
  is_sigmoid = config.get("loss_to_apply", "softmax_xent") == "sigmoid_xent"
  if compute_ece and is_sigmoid:
    error_msg = "Inconsistent config: ECE can only be used with 'softmax_xent'."
    raise ValueError(error_msg)

  ens_size = config.get("model.transformer.ens_size", 1)
  # TODO(ghassen): enable sigmoid for ensemble.
  if ens_size > 1 and is_sigmoid:
    error_msg = "Inconsistent config: Ensemble only works with 'softmax_xent'."
    raise ValueError(error_msg)

  # The pool is used to perform misc operations, such as async logging.
  pool = multiprocessing.pool.ThreadPool()

  # Ideally, this should make code deterministic, but for many reasons we are
  # not there yet. For instance, tf.data.map is not deterministic.
  rng_generator = tf.random.Generator.from_seed(config.get("seed", 0))
  tf.random.set_global_generator(
      rng_generator.split(jax.host_count())[jax.host_id()])

  logging.info("Number of devices: %s  (host_id: %s)", jax.device_count(),
               jax.host_id())
  logging.info("Config:\n%s", str(config))

  if (config.batch_size % jax.device_count() != 0 or
      config.batch_size_eval % jax.device_count() != 0):
    raise ValueError(f"Batch sizes ({config.batch_size} and "
                     f"{config.batch_size_eval}) must be divisible by "
                     f"the number of devices ({jax.device_count()})")

  batch_size_per_host = config.batch_size // jax.host_count()
  batch_size_per_core = config.batch_size // jax.device_count()
  batch_size_per_host_eval = config.batch_size_eval // jax.host_count()
  # TODO(basilm): Remove when JFT2.6B is properly submitted.
  if config.dataset in jft_latest_pipeline.DATA_INFO:
    input_pipeline = jft_latest_pipeline
    cache = "loaded"
  else:
    input_pipeline = default_input_pipeline
    cache = "batched"

  train_ds = input_pipeline.get_data(
      dataset=config.dataset,
      data_dir=config.get("dataset_dir"),
      split=config.train_split,
      batch_size=batch_size_per_host,
      preprocess_fn=pp_builder.get_preprocess_fn(config.pp_train),
      shuffle_buffer_size=config.shuffle_buffer_size,
      cache=False)
  steps_per_epoch = input_pipeline.get_num_examples(
      config.dataset, config.train_split,
      data_dir=config.get("dataset_dir")) / config.batch_size
  total_steps = train.get_total_steps_from_config(config, steps_per_epoch)
  logging.info("Running for %d steps per epoch (%d steps total)",
               steps_per_epoch, total_steps)

  opt_def = train.get_optimizer_from_config(config, f"{BIG_VISION_DIR}.optims")
  eval_config = copy.deepcopy(config)
  if config.get("eval_overrides"):
    with eval_config.unlocked():
      eval_config.update(config.eval_overrides)
  model = getattr(ub.models, config.model_name)
  model_train = model(
      num_classes=config.num_classes, train=True, **config.model)
  model_eval = model(
      num_classes=config.num_classes, train=False, **eval_config.model)

  image_size = tuple(train_ds.element_spec["image"].shape[1:])
  logging.info("Model initialization: Starting.")
  opt, rngs = train.model_and_optim_init(
      model_train.init, opt_def, (batch_size_per_core * ens_size,) + image_size,
      config.get("init_head_bias"), config.get("seed", 0),
      config.get("extra_rngs", ["dropout", "gating"]))
  logging.info("Model initialization: Done.")
  # TODO(jpuigcerver): Support logging parameter count with new sharding.

  if config.get("plot_grad_norm_patterns"):
    plot_grad_norm_name_fn = experts_utils.make_match_fn_from_prefixes(
        config.plot_grad_norm_patterns)
  else:
    plot_grad_norm_name_fn = None

  weight_decay_fn = train.get_weight_decay_function_from_config(config)
  batch_loss_fn = train.wrap_module_with_auxiliary_loss_fn(
      module=model_train,
      loss_fn=loss_to_apply,
      auxiliary_loss_weight=config.get("auxiliary_loss_weight", 0.0),
      ens_size=ens_size)
  if ens_size == 1:
    evaluation_fn = functools.partial(
        train.evaluation_fn,
        apply_fn=model_eval.apply,
        loss_fn=loss_to_apply,
        correct_fn=train.correct_multilabel,
        return_metric_args=compute_ece)
  else:
    evaluation_fn = functools.partial(
        ensemble.evaluation_fn,
        apply_fn=model_eval.apply,
        return_metric_args=compute_ece,
        ens_size=ens_size)
  pmap_evaluation_fn = core.pmap_sorted(evaluation_fn, axis_name="batch")

  update_fn = functools.partial(
      batchensemble_utils.update_fn_be,
      weight_decay_fn=weight_decay_fn,
      plot_grad_norm_name_fn=plot_grad_norm_name_fn,
      plot_grads_nan_inf=config.get("plot_grads_nan_inf", True),
      max_grad_norm_global=config.get("clip_grad_norm", None),
      frozen_vars_patterns=config.get("frozen_var_patterns", None),
      fast_weight_lr_multiplier=config.get("fast_weight_lr_multiplier", None))
  pmap_update_fn = core.pmap_sorted(
      update_fn, axis_name="batch", donate_argnums=(0, 1),
      static_broadcasted_argnums=(5,))

  # Restore parameters from checkpoints (if possible) and put to TPU devices.
  opt, train_iter, rngs_per_device, global_state = restore_model_and_put_to_devices(
      config, workdir, partition_specs, model, opt, iter(train_ds), rngs, pool)
  del rngs
  first_step = global_state["step"]
  accum_train_time = global_state["accum_train_time"]
  start_time = time.time()
  logging.info("Initial step for training = %d.", first_step)

  local_devices = sorted(jax.local_devices(), key=lambda device: device.id)
  if config.get("ema", {}):
    ema_updater = ema.ExponentialMovingAverage(
        target=partitioning.tree_unreplicate_using_partition_specs(
            jax.tree_map(np.zeros_like, opt.target),
            partition_specs=partition_specs,
            local_devices=local_devices),
        num_updates=0,
        **config.ema)
  else:
    ema_updater = None
  if first_step != 0 and ema_updater is not None:
    ema_updater = train.restore_ema_checkpoints(
        workdir,
        first_step,
        partition_specs,
        ema_updater,
        local_devices=local_devices,
        thread_pool=pool)

  train_iter = u.start_input_pipeline(train_iter, config.prefetch_to_device)
  eval_iters = train.get_dataset_eval_iters_from_config(
      config, batch_size_per_host_eval, cache, input_pipeline)
  lr_fn = u.create_learning_rate_schedule(
      config.batch_size, total_steps, steps_per_epoch, **config.lr)
  lr_iter = u.prefetch_scalar(map(lr_fn, range(first_step, total_steps)),
                              config.get("prefetch_to_device", 1))

  writer = metric_writers.create_default_writer(
      workdir, just_logging=jax.host_id() > 0, asynchronous_workers=128,
      summary_writer=config.get("write_tf_summaries", False))

  checkpoint_async_results = []
  log_training_first_n_steps = config.get("log_training_first_n_steps", -1)
  xm_work_unit.set_notes("First step compilations...")
  with metric_writers.ensure_flushes(writer):
    if jax.host_id() == 0:
      callback_fn = xprof.XmUrlCallbackFn(
          description=f"Xprof [{first_step + 171}...{first_step + 190}]",
          work_unit=xm_work_unit)
    else:
      callback_fn = lambda x: x  # Do nothing.
    xprof_session = xprof.MultiStepXprofSession(
        profile_steps=20,    # For how many steps to profile after warmup.
        warmup_steps=170,    # For how many steps to wait before profiling.
        stop_callback_fn=callback_fn)
    for step, lr_repl in zip(range(first_step + 1, total_steps + 1), lr_iter):
      train_batch = next(train_iter)
      with xprof_session:
        with jax.profiler.StepTraceAnnotation(name="train", step_num=step):
          opt, rngs_per_device, loss_value, aux_info = pmap_update_fn(
              opt,
              rngs_per_device,
              lr_repl,
              train_batch["image"],
              train_batch["labels"],
              batch_loss_fn)

      if (ema_updater is not None and
          step % config.get("ema", {}).get("period", 10) == 0):
        ema_updater = ema_updater.update(
            partitioning.tree_unreplicate_using_partition_specs(
                tree=opt.target,
                partition_specs=partition_specs,
                local_devices=local_devices))

      # Checkpoint saving.
      backup_checkpoints_every_n_steps = config.get("backup_checkpoint_steps")
      if (step % config.write_checkpoint_every_n_steps == 0 or
          (backup_checkpoints_every_n_steps is not None and
           step % backup_checkpoints_every_n_steps == 0) or
          step == total_steps):
        # Before writing new checkpoints, make sure that all the previous
        # checkpoint shards have been completely written (hosts are synced).
        train.wait_async_results(
            checkpoint_async_results,
            timeout_secs=config.checkpoint_write_timeout_secs)
        train.sync_all_hosts()
        # Now host 0 can remove all the checkpoints older than the previous
        # checkpointed step. The pool is used to remove files in parallel.
        if jax.host_id() == 0:
          train.remove_old_checkpoints(
              workdir,
              keep_steps_from=step - config.write_checkpoint_every_n_steps,
              keep_steps_multiple_of=backup_checkpoints_every_n_steps,
              thread_pool=pool)
        # Save checkpoint for the current step, asynchronously.
        # Note: Parameters on TPU are sliced and copied to CPU before scheduling
        # the asynchronous copy, to prevent any extra TPU memory usage.
        time_since_last_start = float(time.time() - start_time)
        checkpoint_async_results = train.save_checkpoints(
            workdir=workdir,
            step=step,
            partition_specs=partition_specs,
            optimizer=opt,
            # TODO(jpuigcerver): start_input_pipeline() does not return a
            # serializable iterator. Also, serialization of a "memory heavy"
            # tf.data.Dataset iterator may cause OOM (e.g. big shuffle buffer).
            train_iter=None,
            rng_state_tf=tf.random.get_global_generator().state.numpy(),
            rng_state_jax=rngs_per_device,
            global_state={
                # Note: "step" is automatically added to this dictionary.
                "accum_train_time": accum_train_time + time_since_last_start,
            },
            thread_pool=pool)
        if ema_updater is not None:
          checkpoint_async_results.append(train.save_ema_checkpoints(
              workdir=workdir,
              step=step,
              partition_specs=partition_specs,
              ema_updater=ema_updater,
              local_devices=local_devices,
              thread_pool=pool))

      # Report training progress
      if (jax.host_id() == 0 and config.log_training_every_n_steps > 0 and
          (step % config.log_training_every_n_steps == 0 or
           step == total_steps or step < log_training_first_n_steps)):
        time_elapsed = time.time() - start_time + accum_train_time
        img_sec_core = (
            config.batch_size * step / time_elapsed / jax.device_count())
        writer.write_scalars(step, {"learning_rate": lr_repl[0],
                                    "training_loss": np.mean(loss_value),
                                    "img/sec/core": img_sec_core,
                                    "epoch": step / steps_per_epoch})
        if aux_info:
          # Per-block info has to be handled separately.
          if "per_block_info" in aux_info:
            scalar_metrics_to_aggregate = config.get(
                "scalar_metrics_to_aggregate", ())
            metrics.write_info_to_metric_writer(
                metric_writer=writer,
                step=step,
                gating_info_dict=jax.tree_map(lambda x: np.mean(x, axis=0),
                                              aux_info["per_block_info"]),
                scalar_metrics_to_aggregate=scalar_metrics_to_aggregate,
                write_matrices=True)
            del aux_info["per_block_info"]
          # Plot rest of metrics as scalars.
          writer.write_scalars(
              step, {key: np.mean(value) for key, value in aux_info.items()})

        def progress(start_time, step, num_steps, batch_size):
          """Generates progress note."""
          time_elapsed = time.time() - start_time + accum_train_time
          steps_per_sec = step / time_elapsed
          eta_seconds = (num_steps - step) / (steps_per_sec + 1e-8)
          note = ("Steps:{:d}/{:d}  [{:.1f}%]\n"
                  "Images per second:{:.1f},\n"
                  "ETA:{}, Total time:{}".format(
                      step, num_steps, 100 * step / num_steps,
                      steps_per_sec * batch_size,
                      experts_utils.htime(eta_seconds),
                      experts_utils.htime((eta_seconds + time_elapsed))))
          writer.write_scalars(
              step, {"images_per_second": steps_per_sec * batch_size})
          return note

        progress_note = progress(
            start_time,
            step,
            num_steps=total_steps,
            batch_size=config.batch_size)
        pool.apply_async(
            lambda note=progress_note: xm_work_unit.set_notes(note))

      # Run checks to detect if the model partitioning is unhealthy.
      # Global health metrics will be written to XM, and in case of problems a
      # WARNING or ERROR message will be logged.
      train.monitor_partitioning_health(
          optimizer=opt,
          partition_specs=partition_specs,
          metric_writer=writer,
          step=step,
          first_step=first_step + 1,
          every_n_steps=config.get("check_partitioning_health_every_n_steps",
                                   total_steps // 20))
      # Evaluate model on validation, test, ...
      rngs_per_device = train.run_evaluation_on_multiple_splits(
          pmap_evaluation_fn, opt.target, eval_iters, rngs_per_device,
          step / steps_per_epoch, step, total_steps,
          config.run_evaluation_every_n_steps, writer, compute_ece,
          config.get("ece_num_bins", 15), suffix="")
      if ema_updater and config.run_evaluation_every_n_steps > 0 and (
          step == first_step + 1 or
          step % config.run_evaluation_every_n_steps == 0 or
          step == total_steps):
        logging.info("Evaluation with EMA weights at step %d: started.", step)
        # Copy current parameters to CPU. Only one replica of each local
        # partition is copied to prevent redundant data transfers (e.g.
        # non-expert parameters).
        curr_params = partitioning.tree_unreplicate_using_partition_specs(
            tree=opt.target,
            partition_specs=partition_specs,
            local_devices=local_devices)
        # Block curr_params until TPU->CPU copy has finished to prevent multiple
        # copies of the TPU parameters.
        curr_params = core.tree_block_until_ready(curr_params)
        # Allow TPU parameters to be freed.
        opt = opt.replace(target=None)
        # Copy EMA parameters to TPU and run evaluation.
        rngs_per_device = train.run_evaluation_on_multiple_splits(
            pmap_evaluation_fn,
            partitioning.tree_replicate_from_partitioned_tree(
                ema_updater.get(),
                partition_specs=partition_specs,
                local_devices=local_devices),
            eval_iters, rngs_per_device, step / steps_per_epoch, step,
            total_steps, config.run_evaluation_every_n_steps, writer,
            compute_ece, config.get("ece_num_bins", 15), suffix="_ema")
        rngs_per_device = core.tree_block_until_ready(rngs_per_device)
        # Copy current parameters back to the TPU.
        opt = opt.replace(
            target=partitioning.tree_replicate_from_partitioned_tree(
                curr_params,
                partition_specs=partition_specs,
                local_devices=local_devices))
        logging.info("Evaluation with EMA weights at step %d: finished.", step)
        del curr_params

  pool.close()
  pool.join()
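
# Launch sketch (comment-only; an assumption rather than part of the original
# example): training scripts of this form are typically started via absl.app
# with an ml_collections config flag, roughly:
#
#   from absl import app, flags
#   from ml_collections import config_flags
#
#   config_flags.DEFINE_config_file("config")
#   flags.DEFINE_string("workdir", None, "Work unit directory.")
#
#   if __name__ == "__main__":
#     app.run(main)
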
def main(argv):
  del argv

  config = FLAGS.config
  workdir = FLAGS.workdir

  logging.info("Workdir: %s", workdir)

  save_checkpoint_path = None
  if config.get("checkpoint_steps"):
    tf.io.gfile.makedirs(workdir)
    save_checkpoint_path = os.path.join(workdir, "checkpoint.npz")

  # The pool is used to perform misc operations, such as async logging.
  pool = multiprocessing.pool.ThreadPool()

  # This seed makes the Jax part of things (like model init) deterministic.
  # However, full training still won't be deterministic, for example because
  # the tf.data pipeline is not deterministic even if we set the TF seed.
  rng = jax.random.PRNGKey(config.get("seed", 0))

  def write_note(note):
    if jax.host_id() == 0:
      logging.info("NOTE: %s", note)
  write_note("Initializing...")

  # Verify settings to make sure no checkpoints are accidentally missed.
  if config.get("keep_checkpoint_steps"):
    assert config.get("checkpoint_steps"), "Specify `checkpoint_steps`."
    assert config.keep_checkpoint_steps % config.checkpoint_steps == 0, (
        f"`keep_checkpoint_steps` ({config.checkpoint_steps}) should be"
        f"divisible by `checkpoint_steps ({config.checkpoint_steps}).`")

  batch_size = config.batch_size
  batch_size_eval = config.get("batch_size_eval", batch_size)
  if (batch_size % jax.device_count() != 0 or
      batch_size_eval % jax.device_count() != 0):
    raise ValueError(f"Batch sizes ({batch_size} and {batch_size_eval}) must "
                     f"be divisible by device number ({jax.device_count()})")

  local_batch_size = batch_size // jax.host_count()
  local_batch_size_eval = batch_size_eval // jax.host_count()
  logging.info(
      "Global batch size %d on %d hosts results in %d local batch size. "
      "With %d dev per host (%d dev total), that's a %d per-device batch size.",
      batch_size, jax.host_count(), local_batch_size,
      jax.local_device_count(), jax.device_count(),
      local_batch_size // jax.local_device_count())

  write_note("Initializing train dataset...")
  train_ds = input_pipeline.get_data(
      dataset=config.dataset,
      split=config.train_split,
      data_dir=fillin(config.get("dataset_dir")),
      batch_size=local_batch_size,
      preprocess_fn=pp_builder.get_preprocess_fn(config.pp_train),
      shuffle_buffer_size=config.shuffle_buffer_size,
      prefetch=config.get("prefetch_to_host", 2),
      cache=False)

  # Start prefetching already.
  train_iter = u.start_input_pipeline(
      train_ds, config.get("prefetch_to_device", 1), pad=local_batch_size)
  # We always pad to local_batch_size_eval even when less would be enough in
  # order to minimize memory fragmentation.

  write_note("Initializing val dataset(s)...")
  def _get_val_split(dataset, split, pp_eval, data_dir=None):
    # We do ceil rounding such that we include the last incomplete batch.
    nval_img = input_pipeline.get_num_examples(
        dataset, split, data_dir=fillin(data_dir))
    val_steps = int(np.ceil(nval_img / batch_size_eval))
    logging.info("Running validation for %d steps for %s, %s", val_steps,
                 dataset, split)

    val_it = input_pipeline.get_data(
        dataset=dataset,
        split=split,
        data_dir=fillin(data_dir),
        batch_size=local_batch_size_eval,
        preprocess_fn=pp_builder.get_preprocess_fn(pp_eval),
        cache=config.get("val_cache", "batched"),
        repeat_after_batching=True,
        prefetch=0,  # Save memory since we cache.
        drop_remainder=False,
        shuffle_files=False)
    val_it = u.start_input_pipeline(
        val_it, config.get("prefetch_to_device", 1), pad=local_batch_size_eval)

    return (val_it, val_steps)

  if isinstance(config.val_split, str):
    val_ds = {"val": _get_val_split(config.dataset, config.val_split,
                                    config.pp_eval, config.get("dataset_dir"))}
  else:
    val_ds = {t[0]: _get_val_split(*t[1:]) for t in config.val_split}

  ntrain_img = input_pipeline.get_num_examples(
      config.dataset, config.train_split,
      data_dir=fillin(config.get("dataset_dir")))
  steps_per_epoch = ntrain_img / batch_size

  if config.get("num_epochs"):
    total_steps = int(config.num_epochs * steps_per_epoch)
    assert not config.get("total_steps"), "Set either num_epochs or total_steps"
  else:
    total_steps = config.total_steps

  logging.info(
      "Running for %d steps, that means %f epochs and %f steps per epoch",
      total_steps, total_steps * batch_size / ntrain_img, steps_per_epoch)
  mw = u.BigVisionMetricWriter(xm_xp.id, xm_wu.id, steps_per_epoch)

  write_note(f"Initializing {config.model_name} model...")
  model_mod = importlib.import_module(f"{BASEDIR}.models.{config.model_name}")
  model = model_mod.Model(
      num_classes=config.num_classes, **config.get("model", {}))

  # We want all parameters to be created in host RAM, not on any device; they
  # will be sent to the devices later as needed. Otherwise we risk allocating
  # them twice (we have already run into two such situations).
  @partial(jax.jit, backend="cpu")
  def init(rng):
    image_size = tuple(train_ds.element_spec["image"].shape[1:])
    dummy_input = jnp.zeros((local_batch_size,) + image_size, jnp.float32)
    params = flax.core.unfreeze(model.init(rng, dummy_input))["params"]

    # Set bias in the head to a low value, such that loss is small initially.
    params["head"]["bias"] = jnp.full_like(
        params["head"]["bias"], config.get("init_head_bias", 0))

    return params

  rng, rng_init = jax.random.split(rng)
  params_cpu = init(rng_init)

  if jax.host_id() == 0:
    num_params = sum(p.size for p in jax.tree_flatten(params_cpu)[0])
    parameter_overview.log_parameter_overview(params_cpu)
    mw.measure("num_params", num_params)

  @partial(jax.pmap, axis_name="batch")
  def evaluation_fn(params, images, labels, mask):
    # Ignore the entries with all zero labels for evaluation.
    mask *= labels.max(axis=1)
    logits, _ = model.apply({"params": flax.core.freeze(params)}, images)

    losses = getattr(u, config.get("loss", "sigmoid_xent"))(
        logits=logits, labels=labels, reduction=False)
    loss = jax.lax.psum(losses * mask, axis_name="batch")

    top1_idx = jnp.argmax(logits, axis=1)
    # Extracts the label at the highest logit index for each image.
    top1_correct = jnp.take_along_axis(labels, top1_idx[:, None], axis=1)[:, 0]
    ncorrect = jax.lax.psum(top1_correct * mask, axis_name="batch")
    n = jax.lax.psum(mask, axis_name="batch")
    return ncorrect, loss, n

  # Setup function for computing representation.
  @partial(jax.pmap, axis_name="batch")
  def representation_fn(params, images, labels, mask):
    _, outputs = model.apply({"params": flax.core.freeze(params)}, images)
    representation = outputs[config.fewshot.representation_layer]
    representation = jax.lax.all_gather(representation, "batch")
    labels = jax.lax.all_gather(labels, "batch")
    mask = jax.lax.all_gather(mask, "batch")
    return representation, labels, mask
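
  # Usage sketch (comment-only, see further below in this example):
  # `representation_fn` is handed to the few-shot evaluator, which runs it on
  # eval batches and fits linear probes on the gathered representations:
  #
  #   fewshotter = fewshot.FewShotEvaluator(
  #       representation_fn, config.fewshot,
  #       config.fewshot.get("batch_size") or batch_size_eval)
  #   r = fewshotter.run_all(opt_repl.target, config.fewshot.datasets)
  #   fewshotter.walk_results(mw.measure, *r)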

  # Load the optimizer either from our folder or from flax.
  opt_name = config.get("optim_name", "momentum_hp")
  write_note(f"Initializing {opt_name} optimizer...")
  try:
    opt_mod = importlib.import_module(f"{BASEDIR}.optims.{opt_name}")
    opt_def = opt_mod.Optimizer(**config.get("optim", {}))
  except ModuleNotFoundError:
    opt_def = getattr(flax.optim, opt_name)(**config.get("optim", {}))

  # We jit this so that the arrays are created on the same device as the input,
  # in this case the CPU. Otherwise they would end up on device[0].
  opt_cpu = jax.jit(opt_def.create)(params_cpu)

  @partial(jax.pmap, axis_name="batch", donate_argnums=(0,))
  def update_fn(opt, lr, images, labels, rng):
    """Update step."""

    measurements = {}

    if config.get("mixup") and config.mixup.p:
      rng, (images, labels), _ = u.mixup(rng, images, labels, **config.mixup)

    # Get device-specific loss rng.
    rng, rng_model = jax.random.split(rng, 2)
    rng_model_local = jax.random.fold_in(rng_model, jax.lax.axis_index("batch"))

    def loss_fn(params, images, labels):
      logits, _ = model.apply(
          {"params": flax.core.freeze(params)}, images,
          train=True, rngs={"dropout": rng_model_local})
      return getattr(u, config.get("loss", "sigmoid_xent"))(
          logits=logits, labels=labels)

    # Implementation considerations compared and summarized at
    # https://docs.google.com/document/d/1g3kMEvqu1DOawaflKNyUsIoQ4yIVEoyE5ZlIPkIl4Lc/edit?hl=en#
    l, g = u.accumulate_gradient(jax.value_and_grad(loss_fn), opt.target,
                                 images, labels,
                                 config.get("grad_accum_steps"))
    l, g = jax.lax.pmean((l, g), axis_name="batch")

    # Log the gradient norm only if we need to compute it anyway (clipping)
    # or if we don't use grad_accum_steps, as they interact badly.
    if config.get("grad_accum_steps", 1) == 1 or config.get("grad_clip_norm"):
      grads, _ = jax.tree_flatten(g)
      l2_g = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads]))
      measurements["l2_grads"] = l2_g

    # Optionally resize the global gradient to a maximum norm. We found this
    # useful in some cases across optimizers, hence it's in the main loop.
    if config.get("grad_clip_norm"):
      g_factor = jnp.minimum(1.0, config.grad_clip_norm / l2_g)
      g = jax.tree_map(lambda p: g_factor * p, g)
    opt = opt.apply_gradient(g, learning_rate=lr)

    decay_rules = config.get("weight_decay", []) or []
    if isinstance(decay_rules, numbers.Number):
      decay_rules = [(".*kernel.*", decay_rules)]
    sched_m = lr/config.lr.base if config.get("weight_decay_decouple") else lr
    def decay_fn(v, wd):
      return (1.0 - sched_m * wd) * v
    opt = opt.replace(target=u.tree_map_with_regex(
        decay_fn, opt.target, decay_rules, name="weight decay"))

    params, _ = jax.tree_flatten(opt.target)
    measurements["l2_params"] = jnp.sqrt(sum([jnp.vdot(p, p) for p in params]))

    return opt, l, rng, measurements

  # Other things besides optimizer state to be stored.
  checkpoint_extra = dict(accum_train_time=0.0)

  # Decide how to initialize training. The order is important.
  # 1. Always resumes from the existing checkpoint, e.g. resumes a finetune job.
  # 2. Resume from a previous checkpoint, e.g. start a cooldown training job.
  # 3. Initialize model from something, e.g., to start a fine-tuning job.
  # 4. Train from scratch.
  resume_checkpoint_path = None
  if save_checkpoint_path and tf.io.gfile.exists(save_checkpoint_path):
    resume_checkpoint_path = save_checkpoint_path
  elif config.get("resume"):
    resume_checkpoint_path = fillin(config.resume)
  if resume_checkpoint_path:
    write_note("Resume training from checkpoint...")
    checkpoint = {"opt": opt_cpu, "extra": checkpoint_extra}
    _, checkpoint_tree = jax.tree_flatten(checkpoint)
    loaded = u.load_checkpoint(checkpoint_tree, resume_checkpoint_path)
    # bfloat16 type gets lost when data is saved to disk, so we recover it.
    checkpoint = jax.tree_map(u.recover_dtype, loaded)
    opt_cpu, checkpoint_extra = checkpoint["opt"], checkpoint["extra"]
  elif config.get("model_init"):
    write_note(f"Initialize model from {config.model_init}...")
    loaded = model_mod.load(params_cpu, config.model_init, config.get("model"))
    opt_cpu = opt_cpu.replace(target=loaded)
    if jax.host_id() == 0:
      logging.info("Restored parameter overview:")
      parameter_overview.log_parameter_overview(loaded)

  write_note("Kicking off misc stuff...")
  first_step = int(opt_cpu.state.step)  # Might be a DeviceArray type.
  chrono = u.Chrono(first_step, total_steps, batch_size,
                    checkpoint_extra["accum_train_time"])
  # Note: switch to ProfileAllHosts() if you need to profile all hosts.
  # (Xprof data become much larger and take longer to load for analysis)
  profiler = periodic_actions.Profile(
      # Create profile after every restart to analyze pre-emption related
      # problems and ensure we get similar performance in every run.
      logdir=workdir, first_profile=first_step + 10)

  # Prepare the learning-rate and pre-fetch it to device to avoid delays.
  lr_fn = u.create_learning_rate_schedule(
      batch_size, total_steps, steps_per_epoch, **config.get("lr", {}))
  lr_iter = u.prefetch_scalar(map(lr_fn, range(first_step, total_steps)),
                              config.get("prefetch_to_device", 1))

  write_note(f"Replicating...\n{chrono.note}")
  opt_repl = flax_utils.replicate(opt_cpu)

  write_note(f"Initializing few-shotters...\n{chrono.note}")
  if "fewshot" in config:
    fewshotter = fewshot.FewShotEvaluator(
        representation_fn, config.fewshot,
        config.fewshot.get("batch_size") or batch_size_eval)

  rng, rng_loop = jax.random.split(rng, 2)
  rngs_loop = flax_utils.replicate(rng_loop)
  checkpoint_writer = None

  write_note(f"First step compilations...\n{chrono.note}")
  # Using a python integer for step here, because opt.state.step is allocated
  # on TPU during replication.
  for step, train_batch, lr_repl in zip(
      range(first_step + 1, total_steps + 1), train_iter, lr_iter):
    mw.step_start(step)

    with jax.profiler.TraceContext("train_step", step_num=step, _r=1):
      opt_repl, loss_value, rngs_loop, extra_measurements = update_fn(
          opt_repl,
          lr_repl,
          train_batch["image"],
          train_batch["labels"],
          rng=rngs_loop)

    if jax.host_id() == 0:
      profiler(step)

    # Checkpoint saving
    if u.itstime(step, config.get("checkpoint_steps"), total_steps, host=0):
      chrono.pause()
      u.checkpointing_timeout(checkpoint_writer,
                              config.get("checkpoint_timeout", 1))
      checkpoint_extra["accum_train_time"] = chrono.accum_train_time
      # We need to transfer the weights over now or else we risk keeping them
      # alive while they'll be updated in a future step, creating hard-to-debug
      # memory errors (see b/160593526). Also, takes device 0's params only.
      opt_cpu = jax.tree_map(lambda x: np.array(x[0]), opt_repl)

      # Check whether we want to keep a copy of the current checkpoint.
      copy_step = None
      if u.itstime(step, config.get("keep_checkpoint_steps"), total_steps):
        copy_step = step

      # Checkpoint should be a nested dictionary or FLAX dataclasses from
      # `flax.struct`. Both can be present in a checkpoint.
      checkpoint = {"opt": opt_cpu, "extra": checkpoint_extra}
      checkpoint_writer = pool.apply_async(
          u.save_checkpoint, (checkpoint, save_checkpoint_path, copy_step))
      chrono.resume()

    # Report training progress
    if u.itstime(step, config.log_training_steps, total_steps, host=0):
      mw.measure("learning_rate", lr_repl[0])
      mw.measure("training_loss", loss_value[0])
      for name, value in extra_measurements.items():
        mw.measure(name, value[0])
      chrono.tick(step, mw.measure, write_note)

    # Report validation performance
    if u.itstime(step, config.log_eval_steps, total_steps):
      chrono.pause()
      for val_name, (val_iter, val_steps) in val_ds.items():
        ncorrect, loss, nseen = 0, 0, 0
        for _, batch in zip(range(val_steps), val_iter):
          batch_ncorrect, batch_losses, batch_n = evaluation_fn(
              opt_repl.target, batch["image"], batch["labels"], batch["mask"])
          # All results are a replicated array shaped as follows:
          # (local_devices, per_device_batch_size, elem_shape...)
          # with each local device's entry being identical as they got psum'd.
          # So let's just take the first one to the host as numpy.
          ncorrect += np.sum(np.array(batch_ncorrect[0]))
          loss += np.sum(np.array(batch_losses[0]))
          nseen += np.sum(np.array(batch_n[0]))
        mw.measure(f"{val_name}_prec@1", ncorrect / nseen)
        mw.measure(f"{val_name}_loss", loss / nseen)
      chrono.resume()

    if "fewshot" in config:
      # Compute few-shot on-the-fly evaluation.
      if u.itstime(step, config.fewshot.log_steps, total_steps):
        chrono.pause()
        write_note(f"Few-shot evaluation...\n{chrono.note}")
        r = fewshotter.run_all(opt_repl.target, config.fewshot.datasets)
        fewshotter.walk_results(mw.measure, *r)
        chrono.resume()
    mw.step_end()

  write_note(f"Done!\n{chrono.note}")
  pool.close()
  pool.join()
  mw.close()
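

# Self-contained illustration (a sketch, not part of the original examples) of
# the global-norm gradient clipping used in `update_fn` above: the gradient
# tree is rescaled by a single factor so that its global L2 norm does not
# exceed `grad_clip_norm`. The gradient values here are made up.
import jax
import jax.numpy as jnp

grad_clip_norm = 1.0
grads = {"kernel": jnp.ones((3, 3)), "bias": jnp.full((3,), 2.0)}
leaves = jax.tree_util.tree_leaves(grads)
l2_g = jnp.sqrt(sum(jnp.vdot(p, p) for p in leaves))
g_factor = jnp.minimum(1.0, grad_clip_norm / l2_g)
clipped_grads = jax.tree_util.tree_map(lambda p: g_factor * p, grads)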