Example #1
        def sgd_step(
            state: TrainingState, sample: reverb.ReplaySample
        ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:
            """Computes an SGD step, returning new state and metrics for logging."""

            # Compute gradients.
            grad_fn = jax.value_and_grad(loss_fn)
            loss_value, gradients = grad_fn(state.params, sample)

            # Average gradients over pmap replicas before optimizer update.
            gradients = jax.lax.pmean(gradients, _PMAP_AXIS_NAME)

            # Apply updates.
            updates, new_opt_state = optimizer.update(gradients,
                                                      state.opt_state)
            new_params = optax.apply_updates(state.params, updates)

            metrics = {
                'loss': loss_value,
                'param_norm': optax.global_norm(new_params),
                'param_updates_norm': optax.global_norm(updates),
            }

            new_state = TrainingState(params=new_params,
                                      opt_state=new_opt_state)

            return new_state, metrics
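To make the pattern above concrete, here is a minimal self-contained sketch (toy regression loss and hypothetical names such as toy_loss; not part of the original code) of the same value_and_grad -> optimizer.update -> optax.apply_updates flow, including the norm metrics:

import jax
import jax.numpy as jnp
import optax

def toy_loss(params, batch):
    # Simple squared-error loss so the example stays self-contained.
    preds = batch['x'] @ params['w']
    return jnp.mean((preds - batch['y']) ** 2)

optimizer = optax.adam(1e-3)
params = {'w': jnp.zeros((3,))}
opt_state = optimizer.init(params)
batch = {'x': jnp.ones((8, 3)), 'y': jnp.ones((8,))}

@jax.jit
def sgd_step(params, opt_state, batch):
    # Compute loss and gradients, apply the optimizer update, log norms.
    loss_value, gradients = jax.value_and_grad(toy_loss)(params, batch)
    updates, opt_state = optimizer.update(gradients, opt_state)
    params = optax.apply_updates(params, updates)
    metrics = {
        'loss': loss_value,
        'param_norm': optax.global_norm(params),
        'param_updates_norm': optax.global_norm(updates),
    }
    return params, opt_state, metrics

params, opt_state, metrics = sgd_step(params, opt_state, batch)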
Example #2
            def model_update_minibatch(
                carry: Tuple[networks_lib.Params, optax.OptState],
                minibatch: Batch,
            ) -> Tuple[Tuple[networks_lib.Params, optax.OptState], Dict[
                    str, jnp.ndarray]]:
                """Performs model update for a single minibatch."""
                params, opt_state = carry
                # Normalize advantages at the minibatch level before using them.
                advantages = ((minibatch.advantages -
                               jnp.mean(minibatch.advantages, axis=0)) /
                              (jnp.std(minibatch.advantages, axis=0) + 1e-8))
                gradients, metrics = grad_fn(params, minibatch.observations,
                                             minibatch.actions,
                                             minibatch.behavior_log_probs,
                                             minibatch.target_values,
                                             advantages,
                                             minibatch.behavior_values)

                # Apply updates
                updates, opt_state = optimizer.update(gradients, opt_state)
                params = optax.apply_updates(params, updates)

                metrics['norm_grad'] = optax.global_norm(gradients)
                metrics['norm_updates'] = optax.global_norm(updates)
                return (params, opt_state), metrics
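The (carry, minibatch) -> (carry, metrics) signature above is exactly the one expected by jax.lax.scan. A minimal sketch (toy loss and data, hypothetical names) of driving such a per-minibatch update with scan:

import jax
import jax.numpy as jnp
import optax

optimizer = optax.sgd(1e-2)

def loss_fn(params, x, y):
    return jnp.mean((x @ params - y) ** 2)

def minibatch_step(carry, minibatch):
    # Carry (params, opt_state) across minibatches; emit per-minibatch metrics.
    params, opt_state = carry
    x, y = minibatch
    gradients = jax.grad(loss_fn)(params, x, y)
    updates, opt_state = optimizer.update(gradients, opt_state)
    params = optax.apply_updates(params, updates)
    metrics = {'norm_grad': optax.global_norm(gradients),
               'norm_updates': optax.global_norm(updates)}
    return (params, opt_state), metrics

params = jnp.zeros((4,))
opt_state = optimizer.init(params)
# Three minibatches of eight examples each, stacked along the scan axis.
xs = jnp.ones((3, 8, 4))
ys = jnp.ones((3, 8))
(params, opt_state), metrics = jax.lax.scan(
    minibatch_step, (params, opt_state), (xs, ys))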
Example #3
    def _update_func(
        self,
        params: hk.Params,
        ema_params: hk.Params,
        network_state: hk.State,
        ema_network_state: hk.State,
        opt_state: optax.OptState,
        global_step: jnp.ndarray,
        rng: jnp.ndarray,
        batch: datasets.Batch,
    ) -> Tuple[hk.Params, hk.Params, hk.State, hk.State, optax.OptState,
               losses.LogsDict]:
        """Updates parameters."""

        grad_fn = jax.value_and_grad(self._loss, has_aux=True)
        (_, (stats, new_network_state)), grads = grad_fn(
            params, ema_params, network_state, ema_network_state, rng, batch)
        learning_rate = self._get_learning_rate(global_step)
        _, opt_apply = self._optimizer(learning_rate)
        grad = jax.lax.pmean(grads, axis_name='i')
        updates, opt_state = opt_apply(grad, opt_state, params)
        params = optax.apply_updates(params, updates)

        # Stats and logging.
        param_norm = optax.global_norm(params)
        grad_norm = optax.global_norm(grad)
        ema_rate = schedules.ema_decay_schedule(
            step=global_step, **self.config.eval.ema_annealing_schedule)
        num_non_padded_nodes = (
            batch.graph.n_node.sum() -
            jraph.get_number_of_padding_with_graphs_nodes(batch.graph))
        num_non_padded_edges = (
            batch.graph.n_edge.sum() -
            jraph.get_number_of_padding_with_graphs_edges(batch.graph))
        num_non_padded_graphs = (
            batch.graph.n_node.shape[0] -
            jraph.get_number_of_padding_with_graphs_graphs(batch.graph))
        avg_num_nodes = num_non_padded_nodes / num_non_padded_graphs
        avg_num_edges = num_non_padded_edges / num_non_padded_graphs
        stats.update(
            dict(
                global_step=global_step,
                grad_norm=grad_norm,
                param_norm=param_norm,
                learning_rate=learning_rate,
                ema_rate=ema_rate,
                avg_num_nodes=avg_num_nodes,
                avg_num_edges=avg_num_edges,
            ))
        ema_fn = (
            lambda x, y:  # pylint:disable=g-long-lambda
            schedules.apply_ema_decay(x, y, ema_rate))
        ema_params = jax.tree_multimap(ema_fn, ema_params, params)
        ema_network_state = jax.tree_multimap(
            ema_fn,
            ema_network_state,
            network_state,
        )
        return (params, ema_params, new_network_state, ema_network_state,
                opt_state, stats)
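The EMA bookkeeping at the end (jax.tree_multimap is the older alias of the multi-tree jax.tree_map) follows the usual exponential-moving-average update. A small sketch of that update, assuming the standard formula; the snippet's schedules.apply_ema_decay and its annealing schedule live in the original codebase and may differ in detail:

import jax
import jax.numpy as jnp

def apply_ema(ema_params, params, ema_rate=0.999):
    # Standard EMA: keep ema_rate of the old average, blend in the new params.
    return jax.tree_map(
        lambda ema, p: ema_rate * ema + (1.0 - ema_rate) * p,
        ema_params, params)

params = {'w': jnp.ones((2, 2)), 'b': jnp.zeros((2,))}
ema_params = jax.tree_map(jnp.zeros_like, params)
ema_params = apply_ema(ema_params, params)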
Example #4
    def _step(self, particles, optimizer_state, params):
        """
        Updates particles in the direction given by self.gradient.

        Arguments:
            particles: jnp.ndarray of shape (n, d)
            optimizer_state: optax optimizer state for the particle updates
            params: can be anything, e.g. inducing particles in the case of SVGD,
                deep NN params for a learned f, or None.

        Returns:
            particles (updated)
            optimizer_state (updated)
            grad_aux: dict containing auxiliary data
        """
        grads, grad_aux = self.gradient(params, particles, aux=True)
        updated_grads, optimizer_state = self.opt.update(
            grads, optimizer_state, particles)
        particles = optax.apply_updates(particles, updated_grads)
        grad_aux.update({
            "global_grad_norm": optax.global_norm(grads),
            "global_grad_norm_post_update": optax.global_norm(updated_grads),
        })
        return particles, optimizer_state, grad_aux
Example #5
    def loss_fn(self,
                params,
                dlogp: np.ndarray,
                key: np.ndarray,
                particles: np.ndarray,
                dropout: bool = False):
        """
        Arguments:
            params: neural net parameters
            dlogp: gradient grad(log p)(x), shaped (n, d)
            key: random PRNGKey
            particles: array of shape (n, d)
            dropout: whether to use dropout in the gradient network
        """
        n, d = particles.shape
        v = self.get_field(particles, params, dropout=dropout)
        if dropout:
            f = utils.negative(v)
        else:

            def f(x, dummy_key):
                return -v(x)

        # stein discrepancy
        def h(x, dlogp_x, key):
            zkey, fkey = random.split(key)
            z = random.normal(zkey, (d, ))
            zdf = grad(lambda _x: np.vdot(z, f(_x, fkey)))
            div_f = np.vdot(zdf(x), z)
            #div_f = np.trace(jacfwd(f)(x, fkey))
            sd = np.vdot(f(x, fkey), dlogp_x) + div_f
            l2 = np.vdot(f(x, fkey), f(x, fkey))
            aux = {
                "sd": sd,
                "l2": l2,
            }
            return -sd + l2 * self.lambda_reg, aux

        keys = random.split(key, n)
        loss, aux = vmap(h)(particles, dlogp, keys)
        loss = loss.mean()
        aux = {k: v.mean() for k, v in aux.items()}
        fnorm = optax.global_norm(jnp.mean(vmap(f)(particles, keys), axis=0))
        pnorm = optax.global_norm(jnp.mean(dlogp, axis=0))
        aux.update({
            "loss": loss,
            "l1_diff": fnorm - pnorm,
            "l1_ratio": fnorm / pnorm
        })
        #        #  add L1 term
        #        if self.l1_weight:
        #            loss = loss + self.l1_weight * np.abs(jnp.mean(vmap(f)(particles) - dlogp))
        return loss, aux
Example #6
def _create_loss_metrics(
    loss_has_aux: bool,
    loss_result: Union[jnp.ndarray, Tuple[jnp.ndarray, loggers.LoggingData]],
    gradients: jnp.ndarray,
):
    """Creates loss metrics for logging."""
    # Validate input.
    if loss_has_aux and not (len(loss_result) == 2 and isinstance(
            loss_result[0], jnp.ndarray) and isinstance(loss_result[1], dict)):
        raise ValueError(
            'Could not parse loss value and metrics from loss_fn\'s '
            'output. Since loss_has_aux is enabled, loss_fn must '
            'return loss_value and auxiliary metrics.')

    if not loss_has_aux and not isinstance(loss_result, jnp.ndarray):
        raise ValueError(
            f'Loss returns type {type(loss_result)}. However, it should '
            'return a jnp.ndarray, given that loss_has_aux = False.')

    # Maybe unpack loss result.
    if loss_has_aux:
        loss, metrics = loss_result
    else:
        loss = loss_result
        metrics = {}

    # Complete metrics dict and return it.
    metrics['loss'] = loss
    metrics['gradient_norm'] = optax.global_norm(gradients)
    return metrics
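A short usage sketch (hypothetical toy loss functions) of calling a helper like this after jax.value_and_grad, with and without auxiliary metrics, assuming the definition above and its imports are in scope:

import jax
import jax.numpy as jnp

params = jnp.ones((3,))

def loss_no_aux(p):
    return jnp.sum(p ** 2)

def loss_with_aux(p):
    return jnp.sum(p ** 2), {'mean_param': jnp.mean(p)}

# loss_has_aux=False: loss_fn returns a bare scalar.
loss_result, gradients = jax.value_and_grad(loss_no_aux)(params)
metrics = _create_loss_metrics(False, loss_result, gradients)

# loss_has_aux=True: loss_fn returns (loss, aux_dict).
loss_result, gradients = jax.value_and_grad(loss_with_aux, has_aux=True)(params)
metrics = _create_loss_metrics(True, loss_result, gradients)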
Example #7
    def sgd_step(
        state: TrainingState,
        transitions: types.Transition,
    ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:

      loss_and_grad = jax.value_and_grad(loss_fn, argnums=1)

      # Compute losses and their gradients.
      key, key_input = jax.random.split(state.key)
      loss_value, gradients = loss_and_grad(network.apply, state.policy_params,
                                            key_input, transitions)

      policy_update, optimizer_state = optimizer.update(gradients,
                                                        state.optimizer_state)
      policy_params = optax.apply_updates(state.policy_params, policy_update)

      new_state = TrainingState(
          optimizer_state=optimizer_state,
          policy_params=policy_params,
          key=key,
          steps=state.steps + 1,
      )
      metrics = {
          'loss': loss_value,
          'gradient_norm': optax.global_norm(gradients)
      }

      return new_state, metrics
Example #8
    def _step(self, key, params, optimizer_state, dlogp, val_dlogp, particles,
              val_particles):
        """
        Update parameters and compute the validation loss.

        Args:
            dlogp: array of shape (n_train, d)
            val_dlogp: array of shape (n_validation, d)
        """
        [loss,
         loss_aux], grads = value_and_grad(self.loss_fn,
                                           has_aux=True)(params,
                                                         dlogp,
                                                         key,
                                                         particles,
                                                         dropout=self.dropout)
        grads, optimizer_state = self.opt.update(grads, optimizer_state,
                                                 params)
        params = optax.apply_updates(params, grads)

        _, val_loss_aux = self.loss_fn(params,
                                       val_dlogp,
                                       key,
                                       val_particles,
                                       dropout=False)
        auxdata = {k: v for k, v in loss_aux.items()}
        auxdata.update({"val_" + k: v for k, v in val_loss_aux.items()})
        auxdata.update({
            "global_gradient_norm": optax.global_norm(grads),
        })
        return params, optimizer_state, auxdata
Example #9
    def sgd_step(
        state: TrainingState, sample: reverb.ReplaySample
    ) -> Tuple[TrainingState, Dict[str, jnp.DeviceArray]]:
      """Do a step of SGD."""
      grad_fn = jax.value_and_grad(loss)
      loss_value, gradients = grad_fn(state.params, sample)
      updates, new_opt_state = optimizer.update(gradients, state.opt_state)
      new_params = optax.apply_updates(state.params, updates)

      steps = state.steps + 1

      new_state = TrainingState(
          params=new_params, opt_state=new_opt_state, steps=steps)

      # Compute the global norm of the gradients for logging.
      global_gradient_norm = optax.global_norm(gradients)
      fetches = {'loss': loss_value, 'gradient_norm': global_gradient_norm}

      return new_state, fetches
Example #10
    def _update_func(
        self,
        base_rng,
        state,
        inputs,
    ):
        """Computes loss and updates model parameters."""
        step = state.step
        rng = jax.random.fold_in(base_rng, jax.lax.axis_index('batch'))
        rng = jax.random.fold_in(rng, step)
        grad_loss_fn = jax.value_and_grad(self._loss_fn, has_aux=True)
        (_, loss_dict), scaled_grads = grad_loss_fn(state.params, inputs, rng)
        grads = jax.lax.psum(scaled_grads, axis_name='batch')
        grad_norm = optax.global_norm(grads)
        loss_dict['scalars']['grad_norm'] = grad_norm

        # Compute and apply updates via our optimizer.
        learning_rate = self.learning_rate(state.step)
        loss_dict['scalars']['learning_rate'] = learning_rate
        _, opt_apply = self.optimizer(learning_rate)
        updates, new_opt_state = opt_apply(grads, state.opt_state,
                                           state.params)
        new_params = optax.apply_updates(state.params, updates)

        # Update ema params
        ema_rate = self.config.evaluation.ema_rate
        new_ema_params = jax.tree_multimap(
            lambda x, y: x + (1 - ema_rate) * (y - x),
            state.ema_params,
            new_params,
        )
        new_state = state.replace(step=step + 1,
                                  params=new_params,
                                  ema_params=new_ema_params,
                                  opt_state=new_opt_state)

        # Rescale loss dict and return
        loss_dict['scalars'] = jax.tree_map(
            lambda x: jax.lax.psum(x, axis_name='batch') / jax.device_count(),
            loss_dict['scalars'],
        )
        return new_state, loss_dict
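A minimal sketch (toy loss, hypothetical names; runs on however many local devices are present) of the cross-device gradient aggregation used above, i.e. jax.lax.psum inside a pmapped step:

import functools
import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()

def loss(params, x):
    return jnp.mean((x * params) ** 2)

@functools.partial(jax.pmap, axis_name='batch')
def grad_step(params, x):
    # Per-device gradients are summed across the 'batch' axis.
    grads = jax.grad(loss)(params, x)
    return jax.lax.psum(grads, axis_name='batch')

params = jnp.ones((n_dev,))  # replicated: one copy of the (scalar) param per device
x = jnp.arange(n_dev * 4, dtype=jnp.float32).reshape(n_dev, 4)
summed_grads = grad_step(params, x)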
Example #11
    def _update_func(
        self,
        params: hk.Params,
        state: hk.State,
        opt_state: OptState,
        inputs: dataset.Batch,
        rng: jnp.ndarray,
        global_step: int,
    ) -> Tuple[hk.Params, hk.State, OptState, Scalars]:
        """Applies an update to parameters and returns new state."""
        # This function computes the gradient of the first output of loss_fn and
        # passes through the other arguments unchanged.
        grad_loss_fn = jax.grad(self._loss_fn, has_aux=True)
        scaled_grads, (loss_scalars,
                       state) = grad_loss_fn(params, state, inputs, rng)
        grads = jax.lax.psum(scaled_grads, axis_name='i')

        # Grab the learning rate to log before performing the step.
        learning_rate = self._lr_schedule(global_step)

        # Compute and apply updates via our optimizer.
        updates, opt_state = self._optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)

        n_params = 0
        for k in params.keys():
            for l in params[k]:
                n_params = n_params + np.prod(params[k][l].shape)

        # Scalars to log (note: we log the mean across all hosts/devices).
        scalars = {
            'learning_rate': learning_rate,
            'n_params (M)': float(n_params / 1e6),
            'global_gradient_norm': optax.global_norm(grads)
        }
        loss_scalars = {f'train_{k}': v for k, v in loss_scalars.items()}
        scalars.update(loss_scalars)
        scalars = jax.lax.pmean(scalars, axis_name='i')

        return params, state, opt_state, scalars
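The nested loop that counts parameters can also be written over the flattened pytree; a small sketch (toy params) of the equivalent:

import numpy as np
import jax
import jax.numpy as jnp

params = {'layer1': {'w': jnp.ones((4, 8)), 'b': jnp.ones((8,))}}
# Sum of element counts over all leaves of the parameter pytree.
n_params = sum(np.prod(leaf.shape) for leaf in jax.tree_util.tree_leaves(params))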
Example #12
    def update(updates, state, params=None):
        inner_state = state.inner_state
        # Compute gradient norm and clip gradient if necessary
        gradient_norm = optax.global_norm(updates)
        flat_updates = jax.tree_flatten(updates)[0]
        isfinite = jnp.all(
            jnp.array([jnp.all(jnp.isfinite(p)) for p in flat_updates]))
        islowerthan = gradient_norm < gradient_norm_skip_threshold

        def do_update(_):
            return inner.update(updates, inner_state, params)

        def reject_update(_):
            return (jax.tree_map(jnp.zeros_like, updates), inner_state)

        updates, new_inner_state = jax.lax.cond(jnp.logical_and(
            isfinite, islowerthan),
                                                do_update,
                                                reject_update,
                                                operand=None)

        return updates, MaybeSkipGradientUpdateState(
            inner_state=new_inner_state)
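For context, a self-contained sketch (hypothetical names, simplified state handling relative to the snippet's MaybeSkipGradientUpdateState) of wrapping an inner optax transformation so that non-finite or oversized updates are rejected:

import jax
import jax.numpy as jnp
import optax

def maybe_skip_update(inner: optax.GradientTransformation,
                      norm_threshold: float) -> optax.GradientTransformation:

    def init_fn(params):
        return inner.init(params)

    def update_fn(updates, state, params=None):
        norm = optax.global_norm(updates)
        is_finite = jnp.all(jnp.array(
            [jnp.all(jnp.isfinite(u)) for u in jax.tree_util.tree_leaves(updates)]))
        ok = jnp.logical_and(is_finite, norm < norm_threshold)

        def do_update(_):
            return inner.update(updates, state, params)

        def reject_update(_):
            # Emit zero updates and keep the inner state unchanged.
            return jax.tree_map(jnp.zeros_like, updates), state

        return jax.lax.cond(ok, do_update, reject_update, operand=None)

    return optax.GradientTransformation(init_fn, update_fn)

# Usage: skip steps whose global gradient norm exceeds 100 or is non-finite.
opt = maybe_skip_update(optax.adam(1e-3), norm_threshold=100.0)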
Example #13
def solve_dual_train(
    env: Dict[int, DualOp],
    dual_state: ConfigDict,
    opt: optax.GradientTransformation,
    inner_opt: InnerMaxStrategy,
    dual_params: Params,
    spec_type: verify_utils.SpecType,
    dual_params_types: ParamsTypes,
    logger: Callable[[int, Mapping[str, Any]], None],
    key: jnp.array,
    num_steps: int,
    affine_before_relu: bool,
    device_type=None,
    merge_problems: Optional[Dict[int, int]] = None,
    block_to_time: bool = False,
) -> ConfigDict:
    """Compute verified upper bound via functional lagrangian relaxation.

  Args:
    env: Lagrangian computations for each contributing graph node.
    dual_state: state of the dual problem.
    opt: an optimizer for the outer Lagrangian parameters.
    inner_opt: inner optimization strategy for training.
    dual_params: dual parameters to be minimized via gradient-based
      optimization.
    spec_type: Specification type, adversarial or uncertainty specification.
    dual_params_types: types of inequality encoded by the corresponding
      dual_params.
    logger: logging function.
    key: jax.random.PRNGKey.
    num_steps: total number of outer optimization steps.
    affine_before_relu: whether layer ordering uses the affine layer before the
      ReLU.
    device_type: string, used to clamp to a particular hardware device. Default
      None uses JAX default device placement.
    merge_problems: the key of the dictionary corresponds to the index of the
      layer to begin the merge, and the associated value corresponds to the
      number of consecutive layers to be merged with it.
      For example, `{0: 2, 2: 3}` will merge together layer 0 and 1, as well as
        layers 2, 3 and 4.
    block_to_time: whether to block computations at the end of each iteration to
      account for asynchronicity dispatch when timing.

  Returns:
    dual_state: new state of the dual problem.
    info: various information for logging / debugging.
  """
    assert device_type in (None, 'cpu', 'gpu'), 'invalid device_type'

    # create dual functions
    loss_func = dual_build.build_dual_fun(
        env=env,
        lagrangian_form=dual_params_types.lagrangian_form,
        inner_opt=inner_opt,
        merge_problems=merge_problems,
        affine_before_relu=affine_before_relu,
        spec_type=spec_type)

    value_and_grad = jax.value_and_grad(loss_func, has_aux=True)

    def grad_step(params, opt_state, key, step):
        (loss_val, stats), g = value_and_grad(params, key, step)
        updates, new_opt_state = opt.update(g, opt_state)
        new_params = optax.apply_updates(params, updates)
        return new_params, new_opt_state, loss_val, stats

    # Some solvers (e.g. MIP) cannot be jitted and run on CPU only
    if inner_opt.jittable:
        grad_step = jax.jit(grad_step, backend=device_type)

    dual_state.step = 0
    dual_state.key = key
    dual_state.opt_state = opt.init(dual_params)
    dual_state.dual_params = dual_params
    dual_state.loss = 0.0

    dual_state.best_loss = jnp.inf
    dual_state.best_dual_params = dual_params

    # optimize the dual (Lagrange) parameters with a gradient-based optimizer
    while dual_state.step < num_steps:
        key_step, dual_state.key = jax.random.split(dual_state.key)
        start_time = time.time()
        dual_params, dual_state.opt_state, dual_state.loss, stats = grad_step(
            dual_state.dual_params, dual_state.opt_state, key_step,
            dual_state.step)
        dual_params = dual_build.project_dual(dual_params, dual_params_types)
        if dual_state.loss <= dual_state.best_loss:
            dual_state.best_loss = dual_state.loss
            # store value from previous iteration as loss corresponds to those params
            dual_state.best_dual_params = dual_state.dual_params
        dual_state.dual_params = dual_params  # projected dual params
        if block_to_time:
            dual_state.loss.block_until_ready()  # asynchronous dispatch
        stats['time_per_iteration'] = time.time() - start_time
        stats['best_loss'] = dual_state.best_loss
        stats['dual_params_norm'] = optax.global_norm(dual_state.dual_params)

        logger(dual_state.step, stats)

        dual_state.step += 1

    return dual_state
Example #14
def _parallel_train_step(
    optimizer,
    batched_examples,
    static_batch_metadata,
    loss_fn,
    max_global_norm=None,
    **optimizer_hyper_params,
):
    """Train the model for one step in parallel across devices.

  Args:
    optimizer: Optimizer that tracks the model and parameter state. Should be
      replicated to each device, i.e. should contain ShardedDeviceArrays with a
      leading axis (num_devices, ...) but with the same content on each device.
    batched_examples: A structure of NDArrays representing a batch of examples.
      Should have two leading batch dimensions: (num_devices,
        batch_size_per_device, ...)
    static_batch_metadata: Metadata about this batch, which will be shared
      across all batched examples. Each value of this results in a separate
      XLA-compiled module.
    loss_fn: Task-specific non-batched loss function to apply. Should take the
      current model (optimizer.target) and an example from batched_examples, and
      return a tuple of the current loss (as a scalar) and a dictionary from
      string names to metric values (also scalars, or RatioMetrics).
    max_global_norm: Maximum global norm to clip gradients to. Should be a
      scalar, which will be broadcast automatically.
    **optimizer_hyper_params: Hyperparameters to pass to the optimizer's
      `apply_gradient` function, which will be broadcast across devices
      automatically.

  Returns:
    Tuple (updated_optimizer, grads_ok, metrics, agg_grads). Metrics will be as
    returned by loss_fn, with an extra element "loss". All metrics will be
    averaged across all elements of the batch. Both optimizer and metrics will
    contain ShardedDeviceArrays that are identical across devices. grads_ok will
    be a replicated bool ndarray that is True if the gradients were finite.
    agg_grads are the aggregated (and possibly clipped) gradients.
  """
    def batched_loss_fn(model):
        """Apply loss function across a batch of examples."""
        loss, metrics = jax.vmap(loss_fn,
                                 (None, 0, None))(model, batched_examples,
                                                  static_batch_metadata)
        return jnp.mean(loss), metrics

    # Compute gradients of loss, along with metrics.
    (loss, metrics), grads = jax.value_and_grad(batched_loss_fn,
                                                has_aux=True)(optimizer.target)
    metrics["loss"] = loss

    # Exchange average gradients and metrics across devices.
    agg_grads = jax.lax.pmean(grads, "devices")
    agg_metrics = {}
    for k, v in metrics.items():
        if isinstance(v, RatioMetric):
            num = jax.lax.psum(jnp.sum(v.numerator), "devices")
            denom = jax.lax.psum(jnp.sum(v.denominator), "devices")
            new_value = num / denom
        else:
            # Use nanmean to aggregate bare floats.
            new_value = jnp.nanmean(jax.lax.all_gather(v, "devices"))
        agg_metrics[k] = new_value

    # Compute global norm and possibly clip.
    global_norm = optax.global_norm(agg_grads)
    agg_metrics["gradient_global_norm"] = global_norm
    if max_global_norm is not None:
        should_clip = global_norm > max_global_norm
        agg_grads = jax.tree_map(
            lambda g: jnp.where(should_clip, g * max_global_norm / global_norm,
                                g), agg_grads)
        agg_metrics["gradient_was_clipped"] = should_clip.astype("float32")

    # Check for non-finite gradients.
    grads_ok = jnp.all(
        jnp.stack(
            [jnp.all(jnp.isfinite(x)) for x in jax.tree_leaves(agg_grads)]))

    # Apply updates.
    updated_optimizer = optimizer.apply_gradient(agg_grads,
                                                 **optimizer_hyper_params)

    return updated_optimizer, grads_ok, agg_metrics, agg_grads
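As an aside, the manual clip-to-max-global-norm logic above closely matches what optax.clip_by_global_norm provides as a gradient transformation; a tiny sketch (toy gradient pytree):

import jax.numpy as jnp
import optax

grads = {'w': jnp.full((2, 2), 10.0), 'b': jnp.full((2,), 10.0)}
clipper = optax.clip_by_global_norm(1.0)
state = clipper.init(grads)
clipped, state = clipper.update(grads, state)
# The rescaled gradients now have global norm <= 1.0.
print(optax.global_norm(grads), optax.global_norm(clipped))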
Example #15
        def sgd_step(
            state: TrainingState, sample: reverb.ReplaySample
        ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:
            """Performs a minibatch SGD step, returning new state and metrics."""

            # Extract the data.
            data = sample.data
            # TODO(sinopalnikov): replace it with namedtuple unpacking
            observations, actions, rewards, termination, extra = (
                data.observation, data.action, data.reward, data.discount,
                data.extras)
            discounts = termination * discount
            behavior_log_probs = extra['log_prob']

            def get_behavior_values(
                    params: networks_lib.Params,
                    observations: types.NestedArray) -> jnp.ndarray:
                o = jax.tree_map(
                    lambda x: jnp.reshape(x, [-1] + list(x.shape[2:])),
                    observations)
                _, behavior_values = ppo_networks.network.apply(params, o)
                behavior_values = jnp.reshape(behavior_values,
                                              rewards.shape[0:2])
                return behavior_values

            behavior_values = get_behavior_values(state.params, observations)

            # Vmap over batch dimension
            batch_gae_advantages = jax.vmap(gae_advantages, in_axes=0)
            advantages, target_values = batch_gae_advantages(
                rewards, discounts, behavior_values)
            # Exclude the last step - it was only used for bootstrapping.
            # The shape is [num_sequences, num_steps, ..]
            observations, actions, behavior_log_probs, behavior_values = jax.tree_map(
                lambda x: x[:, :-1],
                (observations, actions, behavior_log_probs, behavior_values))
            trajectories = Batch(observations=observations,
                                 actions=actions,
                                 advantages=advantages,
                                 behavior_log_probs=behavior_log_probs,
                                 target_values=target_values,
                                 behavior_values=behavior_values)

            # Concatenate all trajectories. Reshape from [num_sequences, num_steps,..]
            # to [num_sequences * num_steps,..]
            assert len(target_values.shape) > 1
            num_sequences = target_values.shape[0]
            num_steps = target_values.shape[1]
            batch_size = num_sequences * num_steps
            assert batch_size % num_minibatches == 0, (
                'Num minibatches must divide batch size. Got batch_size={}'
                ' num_minibatches={}.').format(batch_size, num_minibatches)
            batch = jax.tree_map(
                lambda x: x.reshape((batch_size, ) + x.shape[2:]),
                trajectories)

            # Compute gradients.
            grad_fn = jax.grad(loss, has_aux=True)

            def model_update_minibatch(
                carry: Tuple[networks_lib.Params, optax.OptState],
                minibatch: Batch,
            ) -> Tuple[Tuple[networks_lib.Params, optax.OptState], Dict[
                    str, jnp.ndarray]]:
                """Performs model update for a single minibatch."""
                params, opt_state = carry
                # Normalize advantages at the minibatch level before using them.
                advantages = ((minibatch.advantages -
                               jnp.mean(minibatch.advantages, axis=0)) /
                              (jnp.std(minibatch.advantages, axis=0) + 1e-8))
                gradients, metrics = grad_fn(params, minibatch.observations,
                                             minibatch.actions,
                                             minibatch.behavior_log_probs,
                                             minibatch.target_values,
                                             advantages,
                                             minibatch.behavior_values)

                # Apply updates
                updates, opt_state = optimizer.update(gradients, opt_state)
                params = optax.apply_updates(params, updates)

                metrics['norm_grad'] = optax.global_norm(gradients)
                metrics['norm_updates'] = optax.global_norm(updates)
                return (params, opt_state), metrics

            def model_update_epoch(
                carry: Tuple[jnp.ndarray, networks_lib.Params, optax.OptState,
                             Batch], unused_t: Tuple[()]
            ) -> Tuple[Tuple[jnp.ndarray, networks_lib.Params, optax.OptState,
                             Batch], Dict[str, jnp.ndarray]]:
                """Performs model updates based on one epoch of data."""
                key, params, opt_state, batch = carry
                key, subkey = jax.random.split(key)
                permutation = jax.random.permutation(subkey, batch_size)
                shuffled_batch = jax.tree_map(
                    lambda x: jnp.take(x, permutation, axis=0), batch)
                minibatches = jax.tree_map(
                    lambda x: jnp.reshape(x, [num_minibatches, -1] + list(
                        x.shape[1:])), shuffled_batch)

                (params,
                 opt_state), metrics = jax.lax.scan(model_update_minibatch,
                                                    (params, opt_state),
                                                    minibatches,
                                                    length=num_minibatches)

                return (key, params, opt_state, batch), metrics

            params = state.params
            opt_state = state.opt_state
            # Repeat training for the given number of epochs, taking a random
            # permutation for every epoch.
            (key, params, opt_state, _), metrics = jax.lax.scan(
                model_update_epoch,
                (state.random_key, params, opt_state, batch), (),
                length=num_epochs)

            metrics = jax.tree_map(jnp.mean, metrics)
            metrics['norm_params'] = optax.global_norm(params)
            metrics['observations_mean'] = jnp.mean(
                utils.batch_concat(jax.tree_map(
                    lambda x: jnp.abs(jnp.mean(x, axis=(0, 1))), observations),
                                   num_batch_dims=0))
            metrics['observations_std'] = jnp.mean(
                utils.batch_concat(jax.tree_map(
                    lambda x: jnp.std(x, axis=(0, 1)), observations),
                                   num_batch_dims=0))
            metrics['rewards_mean'] = jnp.mean(
                jnp.abs(jnp.mean(rewards, axis=(0, 1))))
            metrics['rewards_std'] = jnp.std(rewards, axis=(0, 1))
            new_state = TrainingState(params=params,
                                      opt_state=opt_state,
                                      random_key=key)
            return new_state, metrics