예제 #1
0
    def loss_fn(model):
        """Annealed score-matching loss for a single batch.

        Args:
            model: Callable score network, invoked as
                `model(perturbed_data, labels, train=train)` with an extra
                `y=class_labels` keyword when `class_conditional` is set.

        Returns:
            `(loss, new_model_state)`, extended with the per-example `losses`
            as a third element when `loss_per_sigma` is truthy.
        """
        # Only the class-conditional variant receives the class labels.
        cond_kwargs = {'y': class_labels} if class_conditional else {}

        def run_model():
            with nn.stochastic(run_rng):
                return model(perturbed_data,
                             labels,
                             **cond_kwargs,
                             train=train)

        if train:
            with nn.stateful(state.model_state) as new_model_state:
                scores = run_model()
        else:
            # The state is opened with mutable=False here, so the incoming
            # state is handed back unchanged.
            with nn.stateful(state.model_state, mutable=False):
                scores = run_model()
            new_model_state = state.model_state

        # Flatten everything but the leading axis before comparing.
        flat_scores = scores.reshape((scores.shape[0], -1))
        raw_target = -1 / (used_sigmas**2) * noise
        flat_target = raw_target.reshape((raw_target.shape[0], -1))
        squared_error = ((flat_scores - flat_target)**2).sum(axis=-1)
        losses = 1 / 2. * squared_error * used_sigmas.squeeze()**anneal_power
        loss = jnp.mean(losses)

        if not loss_per_sigma:
            return loss, new_model_state
        return loss, new_model_state, losses
  def training_loss_fn(self, flax_model, train_state, batch, dropout_rng,
                       env_ids):
    """Computes the training loss over a multi-environment batch.

    Args:
      flax_model: A flax module.
      train_state: TrainState; its `model_state` and `global_step` are read.
      batch: Batches from different environments.
      dropout_rng: FLAX PRNG key.
      env_ids: list(int); List of environment codes. Not used by this method.

    Returns:
      A pair `(loss, (new_model_state, env_logits, logs))`; `logs` is always
      None.
    """
    del env_ids
    env_inputs = pipeline_utils.get_multi_env_inputs(batch, 'inputs')

    with nn.stochastic(dropout_rng):
      env_logits, per_env_state = pipeline_utils.vmapped_flax_module_train(
          flax_model, train_state.model_state, env_inputs)

    # The vmapped call yields model state (e.g. batch statistics) per
    # environment; collapse it by taking the mean over the leading axis.
    mean_over_envs = functools.partial(jnp.mean, axis=0)
    new_model_state = jax.tree_util.tree_map(mean_over_envs, per_env_state)

    loss = self.task.loss_function(env_logits, batch, flax_model.params,
                                   train_state.global_step)
    return loss, (new_model_state, env_logits, None)
예제 #3
0
  def forward_pass(self, flax_model, train_state, batch, rng,
                   input_layer_key='input', train=True):
    """Runs the model on `batch['inputs']` and collects representations.

    Args:
      flax_model: flax.deprecated.nn.Model; Flax model.
      train_state: TrainState object; its `model_state` is read.
      batch: dict; A batch of examples with an 'inputs' entry.
      rng: Jax random number generator key.
      input_layer_key: str; Which layer the input should be plugged in.
      train: bool; Train flag.

    Returns:
      logits, hidden activations, activations of the key layer (with all axes
      after the second one flattened into one), and the new model state.
    """
    # Bind the rng to the host/device we are on.
    device_rng = pipeline_utils.bind_rng_to_host_device(
        rng, axis_name='batch', bind_to=['host', 'device'])

    model_inputs = batch['inputs']

    with nn.stochastic(device_rng):
      outputs = pipeline_utils.forward_pass_with_reps(
          model_inputs, flax_model, train_state.model_state, input_layer_key,
          train)
    logits, all_reps, selected_reps, new_model_state = outputs

    # Keep the first two axes and merge the rest.
    reps_shape = selected_reps.shape
    flat_reps = selected_reps.reshape((reps_shape[0], reps_shape[1], -1))

    return logits, all_reps, flat_reps, new_model_state
예제 #4
0
 def impl_loss_fn(model_params):
     """Evaluates the summed loss for `model_params` on the current batch.

     Returns:
       `(loss, (logits, new_model_state, stats))`.
     """
     with nn.stochastic(rng):
         with nn.stateful(state.model_state) as new_model_state:
             logits, stats = module.call(model_params, batch["image"])

     # `loss_fn` may be one callable or a list/tuple of them.
     if isinstance(loss_fn, (list, tuple)):
         loss_terms = loss_fn
     else:
         loss_terms = [loss_fn]

     loss = sum(term(logits, batch["label"], stats) for term in loss_terms)
     return loss, (logits, new_model_state, stats)
  def training_loss_fn(self, flax_model, train_state, batch, dropout_rng,
                       env_ids):
    """Runs the forward pass and computes the loss from logits and reps.

    Args:
      flax_model: A flax module.
      train_state: TrainState, the state of training including the current
        global_step, model_state, rng, and optimizer.
      batch: Batches from different environments.
      dropout_rng: FLAX PRNG key.
      env_ids: list[int]; List of env codes.

    Returns:
      A pair `(loss, (new_model_state, env_logits, logs))`; `logs` is always
      None.
    """
    forward_outputs = self.forward_pass(flax_model, train_state, batch,
                                        dropout_rng)
    env_logits, _, selected_env_reps, per_env_state = forward_outputs

    # Model state (e.g. batch statistics) returned by the forward pass is
    # averaged over its leading axis.
    average_leading_axis = functools.partial(jnp.mean, axis=0)
    new_model_state = jax.tree_util.tree_map(average_leading_axis,
                                             per_env_state)

    # The total loss is computed inside nn.stochastic.
    with nn.stochastic(dropout_rng):
      loss = self.task.loss_function(env_logits, selected_env_reps, batch,
                                     env_ids, flax_model.params,
                                     train_state.global_step)

    return loss, (new_model_state, env_logits, None)
예제 #6
0
    def loss_fn(model):
        """Mean squared error between the model output and the noise.

        Args:
            model: Callable invoked as `model(perturbed_data, T, train=train)`.

        Returns:
            `(loss, new_model_state)`.
        """

        def predict():
            with nn.stochastic(run_rng):
                return model(perturbed_data, T, train=train)

        if train:
            with nn.stateful(state.model_state) as new_model_state:
                scores = predict()
        else:
            # The state is opened with mutable=False here, so the incoming
            # state is handed back unchanged.
            with nn.stateful(state.model_state, mutable=False):
                scores = predict()
            new_model_state = state.model_state

        # Compare predictions and noise with all non-leading axes flattened.
        flat_scores = scores.reshape((scores.shape[0], -1))
        flat_noise = noise.reshape((noise.shape[0], -1))
        return jnp.mean((flat_scores - flat_noise)**2), new_model_state
예제 #7
0
def eval_step(state, module, batch, metrics_dict, rng):
  """Compute the metrics for the given model in inference mode.

  The model is applied with `train=False` and an immutable model state. The
  resulting metrics and stats are then all-gathered over the "batch" axis.

  Args:
    state: Replicated model state.
    module: Model function.
    batch: Inputs that should be evaluated; "image" and "label" are read.
    metrics_dict: A dictionary of metrics, mapping names to metric functions.
    rng: Jax pseudo-random number generator key.

  Returns:
    Tuple of gathered metrics, gathered stats output by the model, and a new
    PRNG key.
  """
  eval_rng, next_rng = jax.random.split(rng)

  with nn.stochastic(eval_rng):
    with flax.deprecated.nn.stateful(state.model_state, mutable=False):
      logits, stats = module.call(
          state.model_params, batch["image"], train=False)

  metrics = {}
  for name, metric_fn in metrics_dict.items():
    metrics[name] = metric_fn(logits, batch["label"], stats)

  gathered_metrics = jax.lax.all_gather(metrics, axis_name="batch")
  gathered_stats = jax.lax.all_gather(stats, axis_name="batch")
  return gathered_metrics, gathered_stats, next_rng
예제 #8
0
def create_model(module, input_shape, rng):
  """Instantiates the model.

  Args:
    module: Module definition providing `init`.
    input_shape: Shape of the all-ones float32 array used for initialization.
    rng: Jax PRNG key; split into a key for `nn.stochastic` and one for init.

  Returns:
    Tuple of the `nn.Model`, its initial parameters, and the initial state.
  """
  stochastic_rng, params_rng = jax.random.split(rng)
  with nn.stochastic(stochastic_rng):
    with nn.stateful() as init_state:
      dummy_input = jnp.ones(input_shape, dtype=jnp.float32)
      _, init_params = module.init(params_rng, dummy_input)
  return nn.Model(module, init_params), init_params, init_state
        def get_reps(train_state, flax_module, batch):
            """Returns the activations produced for `batch['inputs']`.

            The module is called with `train=True` and
            `return_activations=True`; only the second of its three outputs is
            kept.
            """
            with nn.stochastic(train_state.rng), nn.stateful(
                    train_state.model_state):
                outputs = flax_module(batch['inputs'],
                                      train=True,
                                      return_activations=True)

            _, reps, _ = outputs
            return reps
예제 #10
0
 def _create_flax_module():
     """Initializes the flax module by shape.

     Returns:
       Tuple of the `nn.Model`, its initial model state, and the value
       returned by `model_utils.log_param_shapes` for that model.
     """
     model_rng, init_rng = jax.random.split(rng)

     # Each device gets an equal share of the global batch size.
     per_device_batch = hparams.batch_size // jax.device_count()
     init_shape = (per_device_batch, *input_shape[1:])

     with nn.stateful() as init_model_state, nn.stochastic(model_rng):
         _, initial_params = flax_module_def.init_by_shape(
             init_rng, [(init_shape, model_input_dtype)])

     flax_module = nn.Model(flax_module_def, initial_params)
     return (flax_module, init_model_state,
             model_utils.log_param_shapes(flax_module))
예제 #11
0
def train_step(
    state,
    batch,
    module,
    loss_fn,
    optimizer,
    metrics_dict,
    rng
):
  """Perform a single training step.

  Args:
    state: Current training state. Updated training state will be returned.
    batch: Training inputs for this step; "image" and "label" are read.
    module: Module function.
    loss_fn: Loss function (or list/tuple of loss functions) called with
      logits, labels and the stats returned by the module.
    optimizer: Optax optimizer to compute updates from gradients.
    metrics_dict: A dictionary of metrics, mapping names to metric functions.
    rng: Jax pseudo-random number generator key.

  Returns:
    Tuple `(new_state, grad, updates, metrics, stats, new_rng)`: the updated
    state, the gradients averaged over the "batch" axis, the optimizer
    updates, the all-gathered metrics, the stats taken from the first
    gathered entry, and a new PRNG key.
  """

  rng, new_rng = jax.random.split(rng)

  def impl_loss_fn(model_params):
    with nn.stochastic(rng), nn.stateful(state.model_state) as new_model_state:
      logits, stats = module.call(model_params, batch["image"])
    # `loss_fn` may be a single callable or a list/tuple of callables.
    losses = loss_fn if isinstance(loss_fn, (list, tuple)) else [loss_fn]
    loss = sum(term(logits, batch["label"], stats) for term in losses)
    return loss, (logits, new_model_state, stats)

  grad_fn = jax.value_and_grad(impl_loss_fn, has_aux=True)
  with nn.stochastic(rng):
    (_, loss_aux), grad = grad_fn(state.model_params)
  logits, new_model_state, stats = loss_aux
  # Compute average gradient across multiple workers.
  grad = jax.lax.pmean(grad, axis_name="batch")
  updates, new_opt_state = optimizer.update(grad, state.optimizer_state,
                                            params=state.model_params)
  new_model_params = optax.apply_updates(state.model_params, updates)
  metrics = {m: fn(logits, batch["label"], stats)
             for (m, fn) in metrics_dict.items()}
  metrics = jax.lax.all_gather(metrics, axis_name="batch")
  stats = jax.lax.all_gather(stats, axis_name="batch")
  # Keep only the first gathered entry of every stats leaf.
  # `jax.tree_util.tree_map` is used because the `jax.tree_map` alias is
  # deprecated in newer JAX releases.
  stats = jax.tree_util.tree_map(lambda x: x[0], stats)
  new_state = state.replace(  # pytype: disable=attribute-error
      step=state.step + 1,
      optimizer_state=new_opt_state,
      model_state=new_model_state,
      model_params=new_model_params)
  return new_state, grad, updates, metrics, stats, new_rng
예제 #12
0
  def training_loss_fn(self, flax_module, train_state, batch, dropout_rng):
    """Runs the forward pass in train mode and computes the loss.

    Args:
      flax_module: A flax module.
      train_state: TrainState; its `model_state` is read.
      batch: Batch of examples with an 'inputs' entry.
      dropout_rng: FLAX PRNG key.

    Returns:
      A pair `(loss, (new_model_state, logits))`.
    """
    model_inputs = batch['inputs']

    with nn.stateful(train_state.model_state) as new_model_state:
      with nn.stochastic(dropout_rng):
        logits = flax_module(model_inputs, train=True)

    aux = (new_model_state, logits)
    return self.task.loss_function(logits, batch, flax_module.params), aux
예제 #13
0
def pseudo_label_generator(batch,
                           train_state,
                           pseudo_labels_transformer_fn=lambda x: (x, None),
                           input_key='inputs',
                           train=True):
    """Pseudo label generator passed to the dataset class.

    This function can be passed to a dataset's initializer for self-supervised
    training or distillation. Note that `batch` is modified in place and also
    returned.

    Args:
      batch: dict; Batch of examples, with an entry under `input_key`.
      train_state: TrainState; Train state of the model which we want to use
        to generate pseudo labels.
      pseudo_labels_transformer_fn: function; A function that applies a
        specific transformation on the logits from the model to generate the
        labels. The most basic function to be used here is a simple softmax or
        argmax to get one-hot labels. This function should return the labels
        and the weights for each example in the batch (for each label) and has
        the following API:
        ``` new_labels, weights = pseudo_labels_transformer(logits) ```
      input_key: str; What key to use to retrieve the input field of the batch.
      train: bool; Train flag passed to the model forward pass.

    Returns:
      The batch with its 'label' entry replaced by the pseudo labels and, when
      the transformer returns weights that are not None, its 'weights' entry
      replaced by them.
    """
    model_inputs = batch[input_key]
    _, dropout_rng = jax.random.split(train_state.rng)
    teacher = train_state.optimizer.target

    with nn.stochastic(dropout_rng), nn.stateful(train_state.model_state):
        # stop_gradient makes sure the parameters of the teacher are not
        # updated.
        logits = jax.lax.stop_gradient(teacher(model_inputs, train=train))

        pseudo_labels, weights = pseudo_labels_transformer_fn(logits)
        batch['label'] = pseudo_labels
        if weights is not None:
            batch['weights'] = weights

    return batch
예제 #14
0
  def loss_fn(model):
    """Loss function used for training.

    Sums the mean classification loss (when `classification_targets` is given)
    and the mean regression loss (when `regression_targets` is given).

    Returns:
      A pair `(mean_loss, (outputs, batch_stats))`.
    """
    # Stateful collection for tracking internal state like activations.
    with nn.stateful() as batch_stats:
      with nn.stochastic(dropout_rng):
        outputs = model(inputs, train=True, cache=None)

      # The model returns either a dict of heads or the bare logits.
      has_heads = isinstance(outputs, dict)
      logits = outputs.get('logits', None) if has_heads else outputs
      regression_predictions = (
          outputs.get('regression', None) if has_heads else None)

    mean_loss = 0.0

    # Classification loss
    if classification_targets is not None:
      xent_sum, xent_weight_sum = utils.compute_weighted_cross_entropy(
          logits,
          classification_targets,
          token_weights=classification_weights,
          example_weights=example_weights)
      # Handle case where nothing is masked out in BERT
      # (Only occurs with very short sequences).
      xent_weight_sum = jnp.maximum(xent_weight_sum, epsilon)
      mean_loss += xent_sum / xent_weight_sum

    # Regression loss
    if regression_targets is not None:
      mse_sum, mse_weight_sum = utils.compute_weighted_mse(
          regression_predictions,
          regression_targets,
          weights=regression_weights)
      mse_weight_sum = jnp.maximum(mse_weight_sum, epsilon)
      mean_regression_loss = mse_sum / mse_weight_sum
      outputs['regression_loss'] = mean_regression_loss

      # TODO(ddohan): Allow weighting each loss separately.
      mean_loss += mean_regression_loss

    return mean_loss, (outputs, batch_stats)
예제 #15
0
  def forward_pass(self, flax_model, train_state, batch, rng,
                   input_layer_key='input', train=True):
    """Runs the vmapped model on multi-environment inputs.

    Args:
      flax_model: Flax model passed to the vmapped forward helper.
      train_state: TrainState object; its `model_state` is read.
      batch: Batches from which the multi-environment 'inputs' are extracted.
      rng: Jax random number generator key.
      input_layer_key: str; Forwarded to the vmapped forward helper.
      train: bool; Train flag.

    Returns:
      Tuple of env logits, all env representations, the selected env
      representations (with all axes after the second one flattened into one),
      and the new model state.
    """
    # Bind the rng to the host/device we are on.
    device_rng = pipeline_utils.bind_rng_to_host_device(
        rng, axis_name='batch', bind_to=['host', 'device'])

    env_inputs = pipeline_utils.get_multi_env_inputs(batch, 'inputs')

    with nn.stochastic(device_rng):
      outputs = pipeline_utils.vmapped_flax_module_with_reps(
          env_inputs, flax_model, train_state.model_state, input_layer_key,
          train)
    env_logits, all_env_reps, selected_env_reps, new_model_state = outputs

    # Keep the first two axes and merge the rest.
    reps_shape = selected_env_reps.shape
    flat_env_reps = selected_env_reps.reshape(
        (reps_shape[0], reps_shape[1], -1))

    return env_logits, all_env_reps, flat_env_reps, new_model_state
  def __init__(self, model_cls, task, hparams, experiment_dir,
               tb_summary_writer, rng):
    """Initializes the parent and sets up the task's state transformers.

    After the parent constructor runs, one training batch is pushed through a
    pmapped `forward_pass`; the size of the last axis of the selected
    representations is then handed to `self.task.setup_transformers`.

    Args:
      model_cls: Forwarded to the parent constructor.
      task: Forwarded to the parent constructor.
      hparams: Forwarded to the parent constructor.
      experiment_dir: Forwarded to the parent constructor.
      tb_summary_writer: Forwarded to the parent constructor.
      rng: Jax PRNG key; split once, one half goes to the parent constructor
        and the other is used for `nn.stochastic` during the initial forward
        pass.
    """
    stochastic_rng, parent_rng = jax.random.split(rng)
    super().__init__(model_cls, task, hparams, experiment_dir,
                     tb_summary_writer, parent_rng)

    # Set up state transformers to compute the representation based
    # auxiliary loss.

    # Get a sample batch from the train iterators.
    # TODO(samiraabnar): Refactor this by implementing a sample_batch for task.
    train_iters_by_name = dict(self.task.dataset.data_iters['train'])
    _, train_iters = zip(*train_iters_by_name.items())
    init_batch = self.get_next_batch(train_iters)

    # Run the forward pass once to get the representations and their dimensions.
    flax_model = self.train_state.optimizer.target
    pmapped_forward = jax.pmap(self.forward_pass, axis_name='batch')
    with nn.stochastic(stochastic_rng):
      _, _, selected_env_reps, _ = pmapped_forward(
          flax_model, self.train_state, init_batch, self.train_state.rng)
      self.task.setup_transformers(hidden_reps_dim=selected_env_reps.shape[-1])
  def get_env_aligned_pairs_idx(self, env_reps, env_batches, env_ids):
    """Computes alignments between all environment pairs.

    The inputs are unsharded, aligned by the task inside `nn.stochastic`, and
    the resulting alignments are sharded again.

    Args:
      env_reps: jnp array; Reps for different environments (sharded).
      env_batches: list of dict; Batches of different environments (sharded).
      env_ids: jnp array; Environment ids.

    Returns:
      alignment between batches of environment pairs (sharded).
    """
    # TODO(riannevdberg, samiraabnar): aligning is done on the total
    #  unsharded batch, but that requires access between local batches
    #  when computing the loss. Unsure why this works! To be compatible
    #  with random alignment and sinkhorn soft alignment we should do
    #  alignment only within local batches.
    unsharded_reps = shard_util.unshard_env_batch(env_reps)
    unsharded_batches = shard_util.unshard(env_batches)

    host_rng = jax_utils.unreplicate(self.train_state.rng)
    with nn.stochastic(host_rng):
      alignments = self.task.get_env_aligned_pairs_idx(
          unsharded_reps, unsharded_batches, env_ids)

    return dataset_utils.shard(alignments)
예제 #18
0
def _save_sample_files(sample, out_dir, stem):
  """Writes `sample` to `out_dir` as `<stem>.np` and as an image `<stem>.png`."""
  # The image grid is built from the sample before it is clipped and cast.
  image_grid = sample.reshape((-1, *sample.shape[2:]))
  nrow = int(np.sqrt(image_grid.shape[0]))
  sample = np.clip(sample * 255, 0, 255).astype(np.uint8)
  with tf.io.gfile.GFile(os.path.join(out_dir, stem + ".np"), "wb") as fout:
    np.save(fout, sample)

  with tf.io.gfile.GFile(os.path.join(out_dir, stem + ".png"), "wb") as fout:
    utils.save_image(image_grid, fout, nrow=nrow, padding=2)


def train(config, workdir):
  """Runs a training loop.

  Creates the log/sample/checkpoint directories under `workdir`, initializes
  the model and optimizer, restores the latest meta checkpoint if one exists,
  and then alternates pmapped training steps with periodic evaluation,
  checkpointing and (optionally) sample generation.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains checkpoint training will be resumed from the latest checkpoint.
  """

  # Create directories for experimental logs
  tf.io.gfile.makedirs(workdir)
  sample_dir = os.path.join(workdir, "samples")
  tf.io.gfile.makedirs(sample_dir)
  rng = jax.random.PRNGKey(config.seed)
  tb_dir = os.path.join(workdir, "tensorboard")
  tf.io.gfile.makedirs(tb_dir)
  # `writer` only exists on host 0; every later use is guarded the same way.
  if jax.host_id() == 0:
    writer = tensorboard.SummaryWriter(tb_dir)

  # Initialize model.
  rng, model_rng = jax.random.split(rng)
  model_name = config.model.name
  ncsn_def = mutils.get_model(model_name).partial(config=config)
  rng, run_rng = jax.random.split(rng)
  # Whether the generative model is conditioned on class labels
  class_conditional = "conditional" in config.training.loss.lower()
  with nn.stateful() as init_model_state:
    with nn.stochastic(run_rng):
      input_shape = (jax.local_device_count(), config.data.image_size,
                     config.data.image_size, 3)
      input_list = [(input_shape, jnp.float32), (input_shape[:1], jnp.int32)]
      if class_conditional:
        # The class labels share the shape/dtype spec of the second input.
        input_list.append(input_list[-1])
      _, initial_params = ncsn_def.init_by_shape(
          model_rng, input_list, train=True)
      ncsn = nn.Model(ncsn_def, initial_params)

  optimizer = losses.get_optimizer(config).create(ncsn)

  state = mutils.State(step=0, optimizer=optimizer, lr=config.optim.lr,
                       model_state=init_model_state,
                       ema_rate=config.model.ema_rate,
                       params_ema=initial_params,
                       rng=rng)  # pytype: disable=wrong-keyword-args

  del ncsn, init_model_state  # Do not keep a copy of the initial model.

  # Create checkpoints directory and the initial checkpoint
  checkpoint_dir = os.path.join(workdir, "checkpoints")
  ckpt = utils.Checkpoint(
      checkpoint_dir,
      max_to_keep=None)
  # The return value is not used here; training resumes from the meta
  # checkpoint restored below.
  ckpt.restore_or_initialize(state)

  # Save intermediate checkpoints to resume training automatically
  checkpoint_meta_dir = os.path.join(workdir, "checkpoints-meta")
  ckpt_meta = utils.Checkpoint(
      checkpoint_meta_dir,
      max_to_keep=1)
  state = ckpt_meta.restore_or_initialize(state)
  initial_step = int(state.step)
  rng = state.rng

  # Build input pipeline.
  rng, ds_rng = jax.random.split(rng)
  train_ds, eval_ds, _ = datasets.get_dataset(ds_rng, config)
  train_iter = iter(train_ds)  # pytype: disable=wrong-arg-types
  eval_iter = iter(eval_ds)  # pytype: disable=wrong-arg-types
  scaler = datasets.get_data_scaler(config)  # data normalizer
  inverse_scaler = datasets.get_data_inverse_scaler(config)

  # Distribute training.
  optimize_fn = losses.optimization_manager(config)
  if config.training.loss.lower() == "ddpm":
    # Use score matching loss with DDPM-type perturbation.
    ddpm_params = mutils.get_ddpm_params()
    train_step = functools.partial(losses.ddpm_loss, ddpm_params=ddpm_params,
                                   train=True, optimize_fn=optimize_fn)
    eval_step = functools.partial(losses.ddpm_loss, ddpm_params=ddpm_params,
                                  train=False)
  else:
    # Use score matching loss with NCSN-type perturbation.
    sigmas = mutils.get_sigmas(config)
    # Whether to use a continuous distribution of noise levels
    continuous = "continuous" in config.training.loss.lower()
    train_step = functools.partial(
        losses.ncsn_loss,
        sigmas=sigmas,
        class_conditional=class_conditional,
        continuous=continuous,
        train=True,
        optimize_fn=optimize_fn,
        anneal_power=config.training.anneal_power)
    eval_step = functools.partial(
        losses.ncsn_loss,
        sigmas=sigmas,
        class_conditional=class_conditional,
        continuous=continuous,
        train=False,
        anneal_power=config.training.anneal_power)

  p_train_step = jax.pmap(train_step, axis_name="batch")
  p_eval_step = jax.pmap(eval_step, axis_name="batch")
  state = flax_utils.replicate(state)

  num_train_steps = config.training.n_iters

  logging.info("Starting training loop at step %d.", initial_step)
  # Give every host its own random stream.
  rng = jax.random.fold_in(rng, jax.host_id())
  for step in range(initial_step, num_train_steps + 1):
    # `step` is a Python integer. `state.step` is JAX integer on the GPU/TPU
    # devices.

    # Convert data to JAX arrays. Use ._numpy() to avoid copy.
    # `jax.tree_util.tree_map` is used because the `jax.tree_map` alias is
    # deprecated in newer JAX releases.
    batch = jax.tree_util.tree_map(lambda x: scaler(x._numpy()), next(train_iter))  # pylint: disable=protected-access

    # One fresh key per local device, plus one carried to the next iteration.
    rng, *next_rng = jax.random.split(rng, num=jax.local_device_count() + 1)
    next_rng = jnp.asarray(next_rng)
    loss, state = p_train_step(next_rng, state, batch)
    loss = flax.jax_utils.unreplicate(loss)

    # Quick indication that training is happening.
    logging.log_first_n(logging.INFO, "Finished training step %d.", 5, step)

    if jax.host_id() == 0 and step % 50 == 0:
      logging.info("step: %d, training_loss: %.5e", step, loss)
      writer.scalar("training_loss", loss, step)

    # Save a temporary checkpoint to resume training after pre-emption.
    if step % config.training.snapshot_freq_for_preemption == 0 and jax.host_id(
    ) == 0:
      saved_state = flax_utils.unreplicate(state)
      saved_state = saved_state.replace(rng=rng)
      ckpt_meta.save(saved_state)

    # Report the loss on an evaluation dataset.
    if step % 100 == 0:
      rng, *next_rng = jax.random.split(rng, num=jax.local_device_count() + 1)
      next_rng = jnp.asarray(next_rng)
      eval_batch = jax.tree_util.tree_map(lambda x: scaler(x._numpy()), next(eval_iter))  # pylint: disable=protected-access
      eval_loss, _ = p_eval_step(next_rng, state, eval_batch)
      eval_loss = flax.jax_utils.unreplicate(eval_loss)
      if jax.host_id() == 0:
        logging.info("step: %d, eval_loss: %.5e", step, eval_loss)
        writer.scalar("eval_loss", eval_loss, step)

    # Save a checkpoint periodically and generate samples.
    if (step +
        1) % config.training.snapshot_freq == 0 or step == num_train_steps:
      # Save the checkpoint.
      if jax.host_id() == 0:
        saved_state = flax_utils.unreplicate(state)
        saved_state = saved_state.replace(rng=rng)
        ckpt.save(saved_state)

      # Generate and save samples
      if config.training.snapshot_sampling:
        rng, sample_rng = jax.random.split(rng)
        init_shape = tuple(train_ds.element_spec["image"].shape)
        samples = sampling.get_samples(sample_rng,
                                       config,
                                       flax_utils.unreplicate(state),
                                       init_shape,
                                       scaler,
                                       inverse_scaler,
                                       class_conditional=class_conditional)
        this_sample_dir = os.path.join(
            sample_dir, "iter_{}_host_{}".format(step, jax.host_id()))
        tf.io.gfile.makedirs(this_sample_dir)

        if config.sampling.final_only:  # Do not save intermediate samples
          _save_sample_files(samples[-1], this_sample_dir, "sample")
        else:  # Save all intermediate samples produced during sampling.
          for i, sample in enumerate(samples):
            _save_sample_files(sample, this_sample_dir,
                               "sample_{}".format(i))
예제 #19
0
def evaluate(config,
             workdir,
             eval_folder="eval"):
  """Evaluate trained models.

  For every checkpoint id from the resume point up to `config.eval.end_ckpt`,
  this waits for the checkpoint file to appear, computes the evaluation loss
  over the evaluation dataset, draws samples in several rounds, stores the
  samples together with their Inception statistics, and (on host 0) writes a
  report containing the losses and the IS/FID/KID values. Progress is recorded
  in per-host `meta_*` checkpoints inside the evaluation directory so that an
  interrupted run resumes at the next unfinished round; those meta files are
  removed once all checkpoints have been processed.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints.
    eval_folder: The subfolder for storing evaluation results. Default to
      "eval".
  """
  # Create eval_dir
  eval_dir = os.path.join(workdir, eval_folder)
  tf.io.gfile.makedirs(eval_dir)

  rng = jax.random.PRNGKey(config.seed + 1)

  # Build input pipeline.
  rng, ds_rng = jax.random.split(rng)
  _, eval_ds, _ = datasets.get_dataset(ds_rng, config, evaluation=True)
  scaler = datasets.get_data_scaler(config)
  inverse_scaler = datasets.get_data_inverse_scaler(config)

  # Initialize model.
  rng, model_rng = jax.random.split(rng)
  model_name = config.model.name
  ncsn_def = mutils.get_model(model_name).partial(config=config)
  rng, run_rng = jax.random.split(rng)
  class_conditional = "conditional" in config.training.loss.lower()
  with nn.stateful() as init_model_state:
    with nn.stochastic(run_rng):
      input_shape = tuple(eval_ds.element_spec["image"].shape[1:])
      input_list = [(input_shape, jnp.float32), (input_shape[:1], jnp.int32)]
      if class_conditional:
        # The class-conditional model takes one more input with the same
        # shape/dtype spec as the second one (presumably the class labels).
        input_list.append(input_list[-1])
      _, initial_params = ncsn_def.init_by_shape(
          model_rng, input_list, train=True)
      ncsn = nn.Model(ncsn_def, initial_params)

  optimizer = losses.get_optimizer(config).create(ncsn)
  state = mutils.State(step=0, optimizer=optimizer, lr=config.optim.lr,
                       model_state=init_model_state,
                       ema_rate=config.model.ema_rate,
                       params_ema=initial_params,
                       rng=rng)  # pytype: disable=wrong-keyword-args

  del ncsn, init_model_state  # Do not keep a copy of the initial model.

  checkpoint_dir = os.path.join(workdir, "checkpoints")
  if config.training.loss.lower() == "ddpm":
    # Use the score matching loss with DDPM-type perturbation.
    ddpm_params = mutils.get_ddpm_params()
    eval_step = functools.partial(
        losses.ddpm_loss, ddpm_params=ddpm_params, train=False)
  else:
    # Use the score matching loss with NCSN-type perturbation.
    sigmas = mutils.get_sigmas(config)
    continuous = "continuous" in config.training.loss.lower()
    eval_step = functools.partial(
        losses.ncsn_loss,
        sigmas=sigmas,
        continuous=continuous,
        class_conditional=class_conditional,
        train=False,
        anneal_power=config.training.anneal_power)

  p_eval_step = jax.pmap(eval_step, axis_name="batch")

  # Give every host its own random stream.
  rng = jax.random.fold_in(rng, jax.host_id())

  # A data class for checkpointing.
  @flax.struct.dataclass
  class EvalMeta:
    ckpt_id: int
    round_id: int
    rng: Any

  # Add one additional round to get the exact number of samples as required.
  num_rounds = config.eval.num_samples // config.eval.batch_size + 1

  eval_meta = EvalMeta(ckpt_id=config.eval.begin_ckpt, round_id=-1, rng=rng)
  eval_meta = checkpoints.restore_checkpoint(
      eval_dir, eval_meta, step=None, prefix=f"meta_{jax.host_id()}_")

  # If the (possibly restored) meta stopped before the last round, resume the
  # same checkpoint at the following round; otherwise start the next
  # checkpoint from round 0.
  if eval_meta.round_id < num_rounds - 1:
    begin_ckpt = eval_meta.ckpt_id
    begin_round = eval_meta.round_id + 1
  else:
    begin_ckpt = eval_meta.ckpt_id + 1
    begin_round = 0

  # Continue from the random state stored in the meta.
  rng = eval_meta.rng
  # Use inceptionV3 for images with higher resolution
  inceptionv3 = config.data.image_size >= 256
  inception_model = evaluation.get_inception_model(inceptionv3=inceptionv3)

  logging.info("begin checkpoint: %d", begin_ckpt)
  for ckpt in range(begin_ckpt, config.eval.end_ckpt + 1):
    ckpt_filename = os.path.join(checkpoint_dir, "ckpt-{}.flax".format(ckpt))

    # Wait if the target checkpoint hasn't been produced yet.
    waiting_message_printed = False
    while not tf.io.gfile.exists(ckpt_filename):
      if not waiting_message_printed and jax.host_id() == 0:
        logging.warning("Waiting for the arrival of ckpt-%d.flax", ckpt)
        waiting_message_printed = True
      time.sleep(10)

    # In case the file was just written and not ready to read from yet, retry
    # after 60 and then 120 seconds. Only `Exception` is caught so that
    # KeyboardInterrupt/SystemExit still stop the job; the final attempt is
    # unguarded so a persistent failure propagates to the caller.
    for retry_delay in (60, 120):
      try:
        state = utils.load_state_dict(ckpt_filename, state)
        break
      except Exception:  # pylint: disable=broad-except
        logging.warning("Could not load %s; retrying in %d seconds.",
                        ckpt_filename, retry_delay)
        time.sleep(retry_delay)
    else:
      state = utils.load_state_dict(ckpt_filename, state)

    pstate = flax.jax_utils.replicate(state)
    eval_iter = iter(eval_ds)  # pytype: disable=wrong-arg-types

    # Compute the loss function on the full evaluation dataset.
    all_losses = []
    for i, batch in enumerate(eval_iter):
      # One fresh key per local device for the pmapped step.
      rng, *next_rng = jax.random.split(rng, num=jax.local_device_count() + 1)
      next_rng = jnp.asarray(next_rng)
      eval_batch = jax.tree_map(lambda x: scaler(x._numpy()), batch)  # pylint: disable=protected-access
      eval_loss, _ = p_eval_step(next_rng, pstate, eval_batch)
      eval_loss = flax.jax_utils.unreplicate(eval_loss)
      all_losses.append(eval_loss)
      if (i + 1) % 1000 == 0 and jax.host_id() == 0:
        logging.info("Finished %dth step loss evaluation", i + 1)

    all_losses = jnp.asarray(all_losses)

    state = jax.device_put(state)
    # Sampling and computing statistics for Inception scores, FIDs, and KIDs.
    # Designed to be pre-emption safe. Automatically resumes when interrupted.
    for r in range(begin_round, num_rounds):
      if jax.host_id() == 0:
        logging.info("sampling -- ckpt: %d, round: %d", ckpt, r)
      rng, sample_rng = jax.random.split(rng)
      init_shape = tuple(eval_ds.element_spec["image"].shape)

      this_sample_dir = os.path.join(
          eval_dir, f"ckpt_{ckpt}_host_{jax.host_id()}")
      tf.io.gfile.makedirs(this_sample_dir)
      samples = sampling.get_samples(sample_rng, config, state, init_shape,
                                     scaler, inverse_scaler,
                                     class_conditional=class_conditional)
      # Only the last element returned by the sampler is evaluated.
      samples = samples[-1]
      samples = np.clip(samples * 255., 0, 255).astype(np.uint8)
      samples = samples.reshape(
          (-1, config.data.image_size, config.data.image_size, 3))
      with tf.io.gfile.GFile(
          os.path.join(this_sample_dir, f"samples_{r}.npz"), "wb") as fout:
        # Serialize into memory first, then write the bytes through GFile.
        io_buffer = io.BytesIO()
        np.savez_compressed(io_buffer, samples=samples)
        fout.write(io_buffer.getvalue())

      gc.collect()
      latents = evaluation.run_inception_distributed(samples, inception_model,
                                                     inceptionv3=inceptionv3)
      gc.collect()
      with tf.io.gfile.GFile(
          os.path.join(this_sample_dir, f"statistics_{r}.npz"), "wb") as fout:
        io_buffer = io.BytesIO()
        np.savez_compressed(
            io_buffer, pool_3=latents["pool_3"], logits=latents["logits"])
        fout.write(io_buffer.getvalue())

      eval_meta = eval_meta.replace(ckpt_id=ckpt, round_id=r, rng=rng)
      # Save an intermediate checkpoint directly if not the last round.
      # Otherwise save eval_meta after computing the Inception scores and FIDs
      if r < num_rounds - 1:
        checkpoints.save_checkpoint(
            eval_dir,
            eval_meta,
            step=ckpt * num_rounds + r,
            keep=1,
            prefix=f"meta_{jax.host_id()}_")

    # Compute inception scores, FIDs and KIDs.
    if jax.host_id() == 0:
      # Load all statistics that have been previously computed and saved.
      all_logits = []
      all_pools = []
      for host in range(jax.host_count()):
        this_sample_dir = os.path.join(eval_dir, f"ckpt_{ckpt}_host_{host}")

        stats = tf.io.gfile.glob(
            os.path.join(this_sample_dir, "statistics_*.npz"))
        wait_message = False
        # Poll until this host has written the statistics of every round.
        while len(stats) < num_rounds:
          if not wait_message:
            logging.warning("Waiting for statistics on host %d", host)
            wait_message = True
          stats = tf.io.gfile.glob(
              os.path.join(this_sample_dir, "statistics_*.npz"))
          time.sleep(1)

        for stat_file in stats:
          with tf.io.gfile.GFile(stat_file, "rb") as fin:
            stat = np.load(fin)
            if not inceptionv3:
              all_logits.append(stat["logits"])
            all_pools.append(stat["pool_3"])

      # Drop the surplus produced by the additional sampling round.
      if not inceptionv3:
        all_logits = np.concatenate(
            all_logits, axis=0)[:config.eval.num_samples]
      all_pools = np.concatenate(all_pools, axis=0)[:config.eval.num_samples]

      # Load pre-computed dataset statistics.
      data_stats = evaluation.load_dataset_stats(config)
      data_pools = data_stats["pool_3"]

      if hasattr(config.eval, "num_partitions"):
        # Divide samples into several partitions and compute FID/KID/IS on them.
        assert not inceptionv3
        fids = []
        kids = []
        inception_scores = []
        partition_size = config.eval.num_samples // config.eval.num_partitions
        tf_data_pools = tf.convert_to_tensor(data_pools)
        for i in range(config.eval.num_partitions):
          this_pools = all_pools[i * partition_size:(i + 1) * partition_size]
          this_logits = all_logits[i * partition_size:(i + 1) * partition_size]
          inception_scores.append(
              tfgan.eval.classifier_score_from_logits(this_logits))
          fids.append(
              tfgan.eval.frechet_classifier_distance_from_activations(
                  data_pools, this_pools))
          this_pools = tf.convert_to_tensor(this_pools)
          kids.append(
              tfgan.eval.kernel_classifier_distance_from_activations(
                  tf_data_pools, this_pools).numpy())

        fids = np.asarray(fids)
        inception_scores = np.asarray(inception_scores)
        kids = np.asarray(kids)
        with tf.io.gfile.GFile(os.path.join(eval_dir, f"report_all_{ckpt}.npz"),
                               "wb") as f:
          io_buffer = io.BytesIO()
          np.savez_compressed(
              io_buffer, all_losses=all_losses, mean_loss=all_losses.mean(),
              ISs=inception_scores, fids=fids, kids=kids)
          f.write(io_buffer.getvalue())

      else:
        # Compute FID/KID/IS on all samples together.
        if not inceptionv3:
          inception_score = tfgan.eval.classifier_score_from_logits(all_logits)
        else:
          # No logits are collected in this mode; -1 marks the score as absent.
          inception_score = -1

        fid = tfgan.eval.frechet_classifier_distance_from_activations(
            data_pools, all_pools)
        # Hack to get tfgan KID work for eager execution.
        tf_data_pools = tf.convert_to_tensor(data_pools)
        tf_all_pools = tf.convert_to_tensor(all_pools)
        kid = tfgan.eval.kernel_classifier_distance_from_activations(
            tf_data_pools, tf_all_pools).numpy()
        del tf_data_pools, tf_all_pools

        logging.info(
            "ckpt-%d --- loss: %.6e, inception_score: %.6e, FID: %.6e, KID: %.6e",
            ckpt, all_losses.mean(), inception_score, fid, kid)

        with tf.io.gfile.GFile(os.path.join(eval_dir, f"report_{ckpt}.npz"),
                               "wb") as f:
          io_buffer = io.BytesIO()
          np.savez_compressed(
              io_buffer, all_losses=all_losses, mean_loss=all_losses.mean(),
              IS=inception_score, fid=fid, kid=kid)
          f.write(io_buffer.getvalue())
    else:
      # For host_id() != 0.
      # Use file existence to emulate synchronization across hosts.
      if hasattr(config.eval, "num_partitions"):
        assert not inceptionv3
        while not tf.io.gfile.exists(
            os.path.join(eval_dir, f"report_all_{ckpt}.npz")):
          time.sleep(1.)

      else:
        while not tf.io.gfile.exists(
            os.path.join(eval_dir, f"report_{ckpt}.npz")):
          time.sleep(1.)

    # Save eval_meta after computing IS/KID/FID to mark the end of evaluation
    # for this checkpoint. `r` still holds the last round index of the
    # sampling loop above.
    checkpoints.save_checkpoint(
        eval_dir,
        eval_meta,
        step=ckpt * num_rounds + r,
        keep=1,
        prefix=f"meta_{jax.host_id()}_")

    # Only the first (possibly resumed) checkpoint may start mid-way.
    begin_round = 0

  # Remove all meta files after finishing evaluation.
  meta_files = tf.io.gfile.glob(
      os.path.join(eval_dir, f"meta_{jax.host_id()}_*"))
  for meta_file in meta_files:
    tf.io.gfile.remove(meta_file)
    def training_loss_fn(self, flax_model, train_state, teacher_train_state,
                         batch, unlabeled_batch, dropout_rng, env_ids,
                         unlabeled_env_ids, sampled_layer):
        """Runs forward pass and computes loss.

        The loss is the scheduled sum of the ground-truth loss on the labeled
        batches and, depending on hparams, the losses on inter-environment,
        intra-environment and labeled-to-unlabeled interpolated states.

        Note: the 'label' entry of every element of `unlabeled_batch` is
        overwritten in place with the argmax of the teacher logits.

        Args:
          flax_model: A flax module.
          train_state: TrainState; The state of training including the current
            global_step, model_state, rng, and optimizer.
          teacher_train_state: TrainState; The state of training for the
            teacher (including the current global_step, model_state, rng, and
            optimizer).
          batch: list(dict); A batch of data for each environment in the
            labeled set.
          unlabeled_batch: list(dict); A batch of data for each environment in
            the unlabeled set.
          dropout_rng: FLAX PRNG key.
          env_ids: list(int); List of labeled training environments ids.
          unlabeled_env_ids: list(int); List of unlabeled environments ids.
          sampled_layer: str; Name of the layer on which mixup is applied.

        Returns:
          A tuple `(loss, (model_state, std_env_logits, logs))` where
          `model_state` is taken from the train state after all forward and
          interpolation passes, `std_env_logits` are the student logits on the
          labeled batches, and `logs` is a dict with the unlabeled mixup
          lambdas/alpha/beta/factor, the ground-truth loss ('train_loss') and
          the unlabeled loss.
        """

        dropout_rng, new_rng = jax.random.split(dropout_rng)
        with nn.stochastic(dropout_rng):
            # Run student forward pass on the labeled envs.
            (all_std_env_reps, std_env_logits, _,
             train_state) = self.stateful_forward_pass(flax_model, train_state,
                                                       batch)

            # Run teacher forward pass on the labeled envs.
            (labeled_tchr_env_logits, _, _) = self.stateless_forward_pass(
                teacher_train_state.optimizer.target, teacher_train_state,
                batch)

            # Run teacher forward pass on the unlabeled envs.
            (unlabeled_tchr_env_logits, all_tchr_unlabeled_env_reps,
             _) = self.stateless_forward_pass(
                 teacher_train_state.optimizer.target, teacher_train_state,
                 unlabeled_batch)

            # Replace labels with predicted labels from the teacher model.
            for ub_id in range(len(unlabeled_env_ids)):
                unlabeled_batch[ub_id]['label'] = jnp.argmax(
                    unlabeled_tchr_env_logits[ub_id], axis=-1)

        # Get sampled layer for interpolations:
        std_sampled_reps = all_std_env_reps[sampled_layer]
        sampled_unlabeled_reps = all_tchr_unlabeled_env_reps[sampled_layer]

        # The second half of this split is not needed any more.
        interpolation_rng, _ = jax.random.split(new_rng)
        with nn.stochastic(interpolation_rng):
            (interpolated_batches, interpolated_logits, _,
             train_state) = self.maybe_inter_env_interpolation(
                 batch, env_ids, flax_model, self.intra_interpolate_fn,
                 sampled_layer, std_sampled_reps, std_sampled_reps,
                 train_state)

            (same_env_interpolated_batches, same_env_interpolated_logits, _,
             train_state) = self.maybe_intra_env_interpolation(
                 batch, env_ids, flax_model, self.intra_interpolate_fn,
                 sampled_layer, std_sampled_reps, train_state)

            (unlabeled_interpolated_batches, unlabeled_interpolated_logits,
             unlabeled_mixup_lambdas, unlabeled_mixup_alpha,
             unlabeled_mixup_beta,
             train_state) = self.maybe_gradual_interpolation(
                 batch, unlabeled_batch, env_ids, unlabeled_env_ids,
                 flax_model, self.interpolate_fn, sampled_layer,
                 std_sampled_reps, sampled_unlabeled_reps, std_sampled_reps,
                 sampled_unlabeled_reps, labeled_tchr_env_logits,
                 unlabeled_tchr_env_logits, train_state, teacher_train_state)

            # Compute the total loss (inside nn.stochastic):
            # env_reps and env_ids are set to None to avoid computing a loss for
            # domain mapping (the mapping model is not trained and not used in
            # computing the loss).
            ground_truth_factor_params = pipeline_utils.get_weight_param(
                self.hparams, 'ground_truth_factor', 1.0)
            ground_truth_factor = pipeline_utils.scheduler(
                train_state.global_step, ground_truth_factor_params)

            ground_truth_loss = self.task.loss_function(
                std_env_logits, None, batch, None, flax_model.params,
                train_state.global_step)
            loss = ground_truth_loss * ground_truth_factor

            # Add the loss for cross environment interpolated states:
            if len(env_ids) > 1 and self.hparams.get('inter_env_interpolation',
                                                     True):
                inter_mixup_factor_params = pipeline_utils.get_weight_param(
                    self.hparams, 'inter_mixup_factor', 1.0)
                inter_mixup_factor = pipeline_utils.scheduler(
                    train_state.global_step, inter_mixup_factor_params)
                loss += self.task.loss_function(
                    interpolated_logits, None, interpolated_batches, None,
                    None, train_state.global_step) * inter_mixup_factor

            # Add the loss for same environment interpolated states:
            if self.hparams.get('intra_env_interpolation', True):
                intra_mixup_factor_params = pipeline_utils.get_weight_param(
                    self.hparams, 'intra_mixup_factor', 1.0)
                intra_mixup_factor = pipeline_utils.scheduler(
                    train_state.global_step, intra_mixup_factor_params)

                loss += self.task.loss_function(
                    same_env_interpolated_logits, None,
                    same_env_interpolated_batches, None, None,
                    train_state.global_step) * intra_mixup_factor

            # Add the loss for gradual environment interpolations toward unlabeled
            # target environment(s). This term is kept inside nn.stochastic
            # like every other term of the total loss.
            unlabeled_mixup_factor = 0
            unlabeled_loss = 0
            if self.hparams.get('unlabeled_interpolation', True):
                unlabeled_mixup_factor_params = pipeline_utils.get_weight_param(
                    self.hparams, 'unlabeled_mixup_factor', 1.0)
                unlabeled_mixup_factor = pipeline_utils.scheduler(
                    train_state.global_step, unlabeled_mixup_factor_params)
                unlabeled_loss = self.task.loss_function(
                    unlabeled_interpolated_logits, None,
                    unlabeled_interpolated_batches, None, None,
                    train_state.global_step)
                loss += unlabeled_loss * unlabeled_mixup_factor

        logs = {}
        logs['unlabeled_mixup_lambda'] = unlabeled_mixup_lambdas
        logs['unlabeled_mixup_alpha'] = unlabeled_mixup_alpha
        logs['unlabeled_mixup_beta'] = unlabeled_mixup_beta
        logs['unlabeled_mixup_factor'] = unlabeled_mixup_factor
        logs['train_loss'] = ground_truth_loss
        logs['unlabeled_loss'] = unlabeled_loss

        return loss, (train_state.model_state, std_env_logits, logs)
    def training_loss_fn(self, flax_module, train_state, batch, dropout_rng,
                         mixup_rng, sampled_layer):
        """Computes the standard loss plus the loss on mixed-up states.

        A first forward pass on `batch['inputs']` yields the logits and the
        intermediate activations. Examples of the mini batch are then matched
        with each other, the activations of `sampled_layer` are interpolated
        and pushed through the module again, and the labels (and weights, if
        present) are interpolated with the same lambdas.

        Args:
          flax_module: A flax module.
          train_state: TrainState, the state of training including the current
            global_step, model_state, rng, and optimizer.
          batch: Batches from different environments.
          dropout_rng: FLAX PRNG key.
          mixup_rng: FLAX PRNG key.
          sampled_layer: str; Name of the layer on which mixup will be applied.

        Returns:
          A tuple `(loss, (new_model_state, logits))` where `logits` are those
          of the first (non-interpolated) forward pass.
        """
        hparams = self.hparams

        # Plain forward pass; the matching is computed under the same
        # stochastic/stateful contexts.
        with nn.stochastic(dropout_rng):
            with nn.stateful(train_state.model_state) as new_model_state:
                logits, reps, _ = flax_module(
                    batch['inputs'], train=True, return_activations=True)

                # Get matching between examples from the mini batch:
                matching_matrix = pipeline_utils.get_self_matching_matrix(
                    batch,
                    reps[sampled_layer],
                    mode=hparams.get('intra_mixup_mode', 'random'),
                    label_cost=hparams.get('intra_mixup_label_cost', 1.0),
                    l2_cost=hparams.get('intra_mixup_l2_cost', 0.001))

        # Schedule parameters fall back to a constant value of 1.0 when the
        # corresponding hparam is missing or empty.
        global_step = train_state.global_step
        beta = pipeline_utils.scheduler(
            global_step,
            hparams.get('beta_schedule_params') or {
                'initial_value': 1.0,
                'mode': 'constant'
            })
        alpha = pipeline_utils.scheduler(
            global_step,
            hparams.get('alpha_schedule_params') or {
                'initial_value': 1.0,
                'mode': 'constant'
            })

        with nn.stochastic(mixup_rng):
            # Second pass on the interpolated activations, continuing from the
            # model state produced by the first pass.
            with nn.stateful(new_model_state) as new_model_state:
                new_logits, sample_lambdas = self.interpolate_and_predict(
                    nn.make_rng(), flax_module, matching_matrix, reps,
                    sampled_layer, alpha, beta)

            # Index of the partner each example is mixed with.
            partner_idx = jnp.argmax(matching_matrix, axis=-1)

            new_batch = copy.deepcopy(batch)

            # Compute labels for the interpolated states:
            labels = batch['label']
            new_batch['label'] = tensor_util.convex_interpolate(
                labels, labels[partner_idx], sample_lambdas)

            # Compute weights for the interpolated states:
            weights = batch.get('weights')
            if weights is not None:
                new_batch['weights'] = tensor_util.convex_interpolate(
                    weights, weights[partner_idx], sample_lambdas)

        # Standard loss:
        loss = self.task.loss_function(logits, batch, flax_module.params)
        # Add the loss from interpolated states:
        loss += self.task.loss_function(new_logits, new_batch)

        return loss, (new_model_state, logits)
  def training_loss_fn(self, flax_model, train_state, batch, dropout_rng,
                       env_ids, sampled_layer):
    """Computes the training loss with inter/intra environment mixup terms.

    Args:
      flax_model: A flax module.
      train_state: TrainState, the state of training including the current
        global_step, model_state, rng, and optimizer.
      batch: Batches from different environments.
      dropout_rng: FLAX PRNG key.
      env_ids: list[int]; List of env codes.
      sampled_layer: str; Name of the layer on which mixup is applied.

    Returns:
      A tuple `(loss, (new_model_state, env_logits, logs))` where `logs` holds
      the lambdas sampled for the inter-environment interpolation.
    """
    # Derive one key per stage (forward pass, interpolation, loss) by chaining
    # splits of the incoming key.
    forward_rng, leftover_rng = jax.random.split(dropout_rng)
    mixup_stage_rng, leftover_rng = jax.random.split(leftover_rng)
    loss_stage_rng, _ = jax.random.split(leftover_rng)

    with nn.stochastic(forward_rng):
      # Run student forward pass:
      (all_env_reps, env_logits, selected_env_reps,
       train_state) = self.stateful_forward_pass(flax_model, train_state, batch)
      # NOTE(review): the returned model state is the one captured here, right
      # after the plain forward pass; `train_state` is rebound by the
      # interpolation calls below -- confirm that ignoring their state is
      # intended.
      new_model_state = train_state.model_state

    layer_reps = all_env_reps[sampled_layer]
    # Map the first four arguments of the interpolation over their leading
    # axis; the remaining four are passed through unmapped.
    batched_interpolate = jax.vmap(
        pipeline_utils.interpolate,
        in_axes=(0, 0, 0, 0, None, None, None, None))

    with nn.stochastic(mixup_stage_rng):
      (interpolated_batches, interpolated_logits, sampled_lambdas,
       train_state) = self.maybe_inter_env_interpolation(
           batch, env_ids, flax_model, batched_interpolate, sampled_layer,
           layer_reps, selected_env_reps, train_state)

      (same_env_interpolated_batches, same_env_interpolated_logits, _,
       train_state) = self.maybe_intra_env_interpolation(
           batch, env_ids, flax_model, batched_interpolate, sampled_layer,
           layer_reps, train_state)

    def interpolated_loss(logits, batches):
      # Loss on interpolated states: no env reps, env ids or params are passed.
      return self.task.loss_function(logits, None, batches, None, None,
                                     train_state.global_step)

    with nn.stochastic(loss_stage_rng):
      # Compute the total loss (inside nn.stochastic):
      loss = self.task.loss_function(env_logits, selected_env_reps, batch,
                                     env_ids, flax_model.params,
                                     train_state.global_step)

      # Add the loss for cross environment interpolated states:
      use_inter = (
          len(env_ids) > 1 and
          self.hparams.get('inter_env_interpolation', True))
      if use_inter:
        inter_mixup_factor = self.hparams.get('inter_mixup_factor', 1.0)
        loss += interpolated_loss(
            interpolated_logits, interpolated_batches) * inter_mixup_factor

      # Add the loss for same environment interpolated states:
      if self.hparams.get('intra_env_interpolation', True):
        intra_mixup_factor = self.hparams.get('intra_mixup_factor', 1.0)
        loss += interpolated_loss(
            same_env_interpolated_logits,
            same_env_interpolated_batches) * intra_mixup_factor

    return loss, (new_model_state, env_logits,
                  {'sampled_lambdas': sampled_lambdas})