Example 1
 def cell_fn(carry_0, xs_t):
     # theta is implicitly captured.
     y = theta.delta + xs_t.x + carry_0.y
     z = y + 1
     carry_1 = NestedMap(y=y)
     base_layer.add_summary('test_summary', z)
     return carry_1, NestedMap(z=z)
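
For context, a hypothetical, framework-free sketch of the same recurrence: jax.lax.scan plays the role of the recurrent helper, a local `delta` constant stands in for the implicitly captured theta.delta, and the summary call is dropped. None of these names or values come from the example above.

import jax
import jax.numpy as jnp

delta = 0.5  # Stand-in for theta.delta (assumed value).

def cell_fn(carry_y, x_t):
  # Same recurrence: y_t = delta + x_t + y_{t-1}; z_t = y_t + 1.
  y = delta + x_t + carry_y
  z = y + 1
  return y, z  # (new carry, per-step output)

xs = jnp.arange(4, dtype=jnp.float32)
final_y, zs = jax.lax.scan(cell_fn, jnp.array(0.0, dtype=jnp.float32), xs)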
Example 2
 def _forward_summary(self, summaries):
   """Forwards summary from the inner JaxContext to the outer context."""
   p = self.params
   for summary_key, summary_value in summaries.items():
     logging.info((summary_key, summary_value))
     summary_type = base_layer.get_summary_type_from_key(summary_key)
     assert summary_value.shape[0] == p.num_stages
     if p.unpack_summaries:
       # unstack summary_value
       unstacked_values = jnp.split(summary_value, p.num_stages)
       for i, v in enumerate(unstacked_values):
         base_layer.add_summary(f'{summary_key}/{i}', v, summary_type)
     else:
       base_layer.add_summary(summary_key, summary_value, summary_type)
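
As a standalone illustration of the unpack branch above (toy shapes, no JaxContext): jnp.split along the leading stage axis yields num_stages slices, each keeping a leading dimension of 1.

import jax.numpy as jnp

num_stages = 4
# A stacked per-stage summary of shape [num_stages, 3].
stacked = jnp.arange(num_stages * 3, dtype=jnp.float32).reshape(num_stages, 3)
for i, v in enumerate(jnp.split(stacked, num_stages)):
  print(f'summary/{i}', v.shape)  # Each slice has shape (1, 3).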
Example 3
  def apply_gradient(
      self,
      old_vars: NestedJTensor,
      transformed_grads: NestedJTensor,
      var_is_learnable: NestedBool,
  ) -> NestedJTensor:
    """Applies grads to model_variables.

    Note, in a flax model learnable variables are often referred to as 'params'.
    But since 'params' in Lingvo often refers to a hyperparams.Params, we
    refer to learnable weights of a network as 'variables'.

    Args:
      old_vars: a nested structure of model variables.
      transformed_grads: gradients of the loss w.r.t. old_vars. Must be of the
        same structure as old_vars. 'transformed_grads' have already gone
        through various gradient transformations.
      var_is_learnable: a nested structure of boolean values indicating whether
        a var is trainable. Must be of the same structure as old_vars.
        'non-trainable' vars include batch norm stats, various other counts,
        etc. Only learnable variables are updated.

    Returns:
      updated variables. Only learnable variables are updated.
    """
    p = self.params
    tf.nest.assert_same_structure(old_vars, transformed_grads)
    tf.nest.assert_same_structure(old_vars, var_is_learnable)

    assert p.skip_zero_gradients is None

    # Add a summary of total var norm.
    var_squared = jax.tree_map(lambda x: jnp.sum(x * x), old_vars)
    var_squared, _ = jax.tree_flatten(var_squared)
    var_squared = jnp.concatenate([x[jnp.newaxis] for x in var_squared])
    var_norm = jnp.sqrt(jnp.sum(var_squared))
    base_layer.add_summary('var_norm', var_norm)

    # TODO(yonghui): implement skip_zero_gradients.
    # TODO(yonghui): implement numerical checks.

    def _adjust_var(old_var, transformed_grad, is_learnable):
      if is_learnable:
        return old_var + transformed_grad
      else:
        return old_var

    return tf.nest.map_structure(_adjust_var, old_vars, transformed_grads,
                                 var_is_learnable)
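
A toy, framework-free sketch of the same update rule and the variable-norm summary, using plain dicts and jax.tree_util (the variable names and values here are made up):

import jax
import jax.numpy as jnp

old_vars = {'w': jnp.ones((2, 2)), 'bn_count': jnp.array(3.0)}
grads = {'w': jnp.full((2, 2), 0.1), 'bn_count': jnp.array(0.0)}
learnable = {'w': True, 'bn_count': False}

# Global variable norm: sqrt of the sum of per-variable squared sums.
sq = jax.tree_util.tree_map(lambda x: jnp.sum(x * x), old_vars)
var_norm = jnp.sqrt(sum(jax.tree_util.tree_leaves(sq)))

# Learnable vars get old + transformed_grad; others pass through unchanged.
new_vars = jax.tree_util.tree_map(
    lambda v, g, is_learnable: v + g if is_learnable else v,
    old_vars, grads, learnable)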
Example 4
  def compute_loss(self, predictions: NestedMap,
                   input_batch: NestedMap) -> Tuple[Metrics, Dict[str, Any]]:
    """Computes the loss and other metrics for the given predictions.

    Args:
      predictions: The output of `compute_predictions`.
      input_batch: A `.NestedMap` object containing input tensors to this tower.

    Returns:
      - A dict or NestedMap containing str keys and (metric, weight) pairs as
        values, where one of the entries is expected to correspond to the loss.
      - A dict containing arbitrary tensors describing something about each
        training example, where the first dimension of each tensor is the batch
        index. The base class just returns an empty dict.
    """
    avg_xent = predictions.softmax_output.avg_xent
    total_weight = predictions.softmax_output.total_weight
    metrics = NestedMap(
        avg_xent=(avg_xent, total_weight),
        num_predictions=(total_weight, jnp.array(1.0, total_weight.dtype)))
    # Compute top-1 and top-5 accuracy and add summary.
    acc1 = metric_utils.top_k_accuracy(
        1,
        predictions.softmax_output.logits,
        label_probs=input_batch.label_probs,
        weights=predictions.example_weights)
    acc5 = metric_utils.top_k_accuracy(
        5,
        predictions.softmax_output.logits,
        label_probs=input_batch.label_probs,
        weights=predictions.example_weights)
    metrics.update(
        accuracy=(acc1, predictions.softmax_output.total_weight),
        acc5=(acc5, predictions.softmax_output.total_weight),
        error=(1.0 - acc1, predictions.softmax_output.total_weight),
        error5=(1.0 - acc5, predictions.softmax_output.total_weight))
    # Add top-1 and top-5 accuracies to summaries.
    base_layer.add_summary('acc1', acc1)
    base_layer.add_summary('acc5', acc5)
    return metrics, {}
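
metric_utils.top_k_accuracy is not shown in this example; the sketch below only assumes its apparent semantics (the weighted fraction of examples whose true class is among the top-k logits) and is not the library implementation.

import jax.numpy as jnp

def top_k_accuracy(k, logits, label_probs, weights):
  # True class per example, taken as the argmax of the label distribution.
  true_ids = jnp.argmax(label_probs, axis=-1)
  # Indices of the k largest logits per example.
  top_k_ids = jnp.argsort(-logits, axis=-1)[..., :k]
  hit = jnp.any(top_k_ids == true_ids[..., None], axis=-1).astype(jnp.float32)
  return jnp.sum(hit * weights) / jnp.maximum(jnp.sum(weights), 1e-8)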
Example 5
    def fprop(self,
              inputs: JTensor,
              class_weights: JTensor,
              class_ids: Optional[JTensor] = None,
              class_probabilities: Optional[JTensor] = None) -> NestedMap:
        """Computes logits, cross entropy etc.

    Args:
      inputs: a single JTensor with shape [..., input_dim].
      class_weights: a JTensor with shape [..., 1] containing the weights for
        each target word.
      class_ids: a JTensor with shape [..., 1] of int32 dtype containing the
        target class labels.
      class_probabilities: a JTensor with shape [..., num_classes] of float
        values indicating class-membership probabilities.

    Returns:
      A `.NestedMap` containing the following fields

      - logits: with shape [..., num_classes]. Unnormalized softmax logits.
      - per_example_argmax: with shape [...]. argmax of the i-th example.
      - per_example_xent: with shape [...]. Cross entropy between the i-th
        example's prediction and its label.
      - per_example_weight: with shape [...]. class_weights cast to this
        layer's dtype.
      - total_xent: A scalar. The sum of per_example_weight * per_example_xent.
      - total_weight: A scalar. The sum of per_example_weight.
      - avg_xent: A scalar. total_xent divided by the loss denominator
        (total_weight by default).
    """
        p = self.params
        # Check that at least one of class_ids or class_probabilities is given.
        if class_ids is None and class_probabilities is None:
            raise ValueError(
                'One of class_ids or class_probabilities must be given.')

        # Compute logits
        inputs_dtype = inputs.dtype
        logits = self.get_logits(inputs)
        # We perform softmax in float32 to improve stability.
        logits = logits.astype(jnp.float32)
        log_probs = jax.nn.log_softmax(logits)

        if class_probabilities is None:
            class_probabilities = jax.nn.one_hot(
                jnp.squeeze(class_ids, axis=-1), p.num_classes)
            class_probabilities = jax.lax.stop_gradient(class_probabilities)

        per_example_xent = -jnp.sum(log_probs * class_probabilities, axis=-1)

        per_example_argmax = jax.lax.stop_gradient(jnp.argmax(logits, axis=-1))

        # Compute the total cross entropy over the entire sequence.
        total_xent = jnp.sum(
            jnp.expand_dims(per_example_xent, axis=-1) * class_weights)

        total_weight = jnp.sum(class_weights)

        if p.use_tgt_labels_size_as_loss_denominator:
            loss_denominator = jnp.sum(jnp.ones_like(class_weights))
        else:
            loss_denominator = total_weight
        avg_xent = (total_xent / loss_denominator).astype(inputs_dtype)
        z_loss = (jnp.sum(self.compute_z_loss(logits) * class_weights) /
                  loss_denominator)
        z_loss *= p.z_loss_weight
        base_layer.add_summary('aux_z_loss', z_loss)
        aux_loss_ctx = py_utils.AuxLossContext.Current()

        if aux_loss_ctx is not None:
            aux_loss_ctx.AddLoss(z_loss)

        output_nmap = NestedMap(
            logits=logits.astype(inputs_dtype),
            log_probs=log_probs.astype(inputs_dtype),
            per_example_argmax=per_example_argmax.astype(inputs_dtype),
            per_example_xent=per_example_xent.astype(inputs_dtype),
            total_xent=total_xent.astype(inputs_dtype),
            # base_model.py's _compute_xent_loss_helper uses avg_xent_weight if
            # set. This helper is currently used by LanguageModel only; if we
            # have an EncoderDecoder model we will have to adjust the weighting
            # as well.
            avg_xent_weight=loss_denominator,
            avg_xent=avg_xent,
            total_weight=total_weight)

        return output_nmap
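
A small numerical sketch of the cross-entropy path above, outside the layer (toy logits and labels; the z-loss and the layer's params are omitted):

import jax
import jax.numpy as jnp

num_classes = 5
logits = jnp.array([[2.0, 0.5, -1.0, 0.0, 1.0]], dtype=jnp.float32)
class_ids = jnp.array([[0]])          # shape [..., 1]
class_weights = jnp.array([[1.0]])    # shape [..., 1]

log_probs = jax.nn.log_softmax(logits)
one_hot = jax.nn.one_hot(jnp.squeeze(class_ids, axis=-1), num_classes)
per_example_xent = -jnp.sum(log_probs * one_hot, axis=-1)           # [...]
total_xent = jnp.sum(per_example_xent[..., None] * class_weights)   # scalar
avg_xent = total_xent / jnp.sum(class_weights)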
Example 6
def train_step_single_learner(
    jax_task: base_task.SingleTask,
    states: TrainState,
    prng_key: JTensor,
    inputs: Union[JTensor, NestedMap],
    data_parallel_axis_name: Optional[str] = 'batch',
    fprop_dtype: jnp.dtype = jnp.float32
) -> Tuple[TrainState, Any, Any, Any, SummaryDict]:
  """Trains a model for a single step.

  This function works for both pmap-ed and pjit-ed models. When this function
  is called from a pmap-ed trainer, data_parallel_axis_name has to be set.
  Otherwise, data_parallel_axis_name should be either an empty string or None.

  TODO(yonghui): Maybe refactor pmap and pjit into two functions.

  This utility is specialized for the single-learner case.

  Args:
    jax_task: An instance of base_task.SingleTask.
    states: An instance of model.TrainState.
    prng_key: A PRNGKey, of shape [2], of type np.uint32.
    inputs: Inputs to the mdl.fprop() function.
    data_parallel_axis_name: a string, the device axis to aggregate gradients
      over.
    fprop_dtype: fprop datatype, can be either jnp.float32 or jnp.bfloat16.

  Returns:
    A tuple of the following elements.
    updated_states - updated states.
    loss - loss as computed by mdl.fprop.
    mean_metrics - a dict of metrics. Each element of the dict is a pair
      (metric, weight).
    per_example_out - auxiliary per-example output as computed in mdl.fprop.
    summary_tensors - A dict or nested map of summary tensors computed in the
      forward as well as the backward pass.
  """
  assert len(jax_task.learners) == 1
  learner = jax_task.learners[0]
  model = jax_task.model

  context_p = base_layer.JaxContext.Params().Set(do_eval=False)
  # Fold in global_step as part of the random seed key, so that random
  # numbers depend on the global step.
  prng_key = jax.random.fold_in(prng_key, states.step)

  if data_parallel_axis_name:
    in_pmap = True
  else:
    in_pmap = False

  prng_key, subkey = jax.random.split(prng_key)

  def _loss_fn(
      mdl_vars: NestedJTensor, inputs: NestedMap
  ) -> Tuple[JTensor, Tuple[Any, NestedMap, SummaryDict, SummaryDict]]:
    """Computes loss as well as other auxiliary outputs."""
    if fprop_dtype == jnp.float32:
      pass
    elif fprop_dtype == jnp.bfloat16:
      mdl_vars = jax.tree_map(_maybe_to_bfloat16, mdl_vars)
      inputs = jax.tree_map(_maybe_to_bfloat16, inputs)
    else:
      raise NotImplementedError(f'fprop_dtype {fprop_dtype} not supported.')

    with base_layer.JaxContext.new_context(
        params=context_p, prng_key=subkey,
        global_step=states.step) as jax_context:
      jax_context.bind(model, model.vars_to_flax_vars(mdl_vars),
                       [base_layer.SCOPE_VARS, base_layer.SCOPE_AUX_LOSS])

      metrics, per_example_output = model.fprop(inputs)
      loss_name = learner.loss_name
      assert loss_name in metrics
      loss, loss_weight = metrics[loss_name]
      assert loss.ndim == 0, 'loss has to be a scalar.'
      assert loss_weight.ndim == 0, 'loss_weight has to be a scalar'
      loss_weight = jax.lax.stop_gradient(loss_weight)
      if in_pmap:
        # Renormalize loss weight by the total weight across all replicas.
        # This also takes care of dividing loss by num of data parallel
        # replicas.
        loss_weight /= jax.lax.psum(
            loss_weight, axis_name=data_parallel_axis_name)
      else:
        # loss_weight == 1 in spmd.
        loss_weight /= jnp.sum(loss_weight)
      weighted_loss = loss * loss_weight
      # Fetch forward-updated vars, which often include batch norm vars, other
      # misc stats, etc.
      forward_updated_vars = model.updated_vars
      # Finally, fetch all the summary tensors.
      summary_tensors = base_layer.all_summaries()
      if in_pmap:
        summary_tensors = summary_utils.aggregate_per_replica_summaries(
            summary_tensors, data_parallel_axis_name)
    if fprop_dtype == jnp.bfloat16 and weighted_loss.dtype == fprop_dtype:
      weighted_loss = weighted_loss.astype(jnp.float32)
    return weighted_loss, (metrics, forward_updated_vars, summary_tensors,
                           per_example_output)

  grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)

  (weighted_loss, (metrics, fwd_updated_vars, fwd_summary_tensors,
                   per_example_out)), grads = grad_fn(states.mdl_vars, inputs)

  if in_pmap:
    # Scale weighted_loss back after gradient computation.
    # This is the average loss across all replicas.
    weighted_loss = jax.lax.psum(
        weighted_loss, axis_name=data_parallel_axis_name)
  else:
    # No sum of loss over all the replicas.
    pass

  if in_pmap:
    mean_metrics = type(metrics)()
    for key in metrics:
      value, weight = metrics[key]
      sum_value = jax.lax.psum(
          value * weight, axis_name=data_parallel_axis_name)
      sum_weight = jax.lax.psum(weight, axis_name=data_parallel_axis_name)
      mean_metrics[key] = (sum_value / (sum_weight + 1e-8), sum_weight)
  else:
    # No aggregation is needed.
    mean_metrics = metrics

  var_weight_params = model.vars
  tf.nest.assert_same_structure(states.mdl_vars, var_weight_params)
  tf.nest.assert_same_structure(states.mdl_vars, grads)
  tf.nest.assert_same_structure(states.mdl_vars, fwd_updated_vars)
  var_is_learnable = tf.nest.map_structure(
      lambda x: not base_layer.var_not_trainable(x), var_weight_params)
  tf.nest.assert_same_structure(states.mdl_vars, var_is_learnable)

  def _maybe_zero_out_grad_fn(var_grad, var, var_learnable):
    if var_learnable:
      return var_grad
    else:
      return jnp.zeros_like(var)

  # Zero-out gradient for non-learnable vars.
  grads = tf.nest.map_structure(_maybe_zero_out_grad_fn, grads, states.mdl_vars,
                                var_is_learnable)

  if in_pmap:
    # Aggregate grads across different model replicas.
    grads = jax.lax.psum(grads, axis_name=data_parallel_axis_name)
  else:
    # No gradient aggregation is needed.
    pass

  # Carry out backward computation under a JaxContext.
  prng_key, subkey = jax.random.split(prng_key)
  with base_layer.JaxContext.new_context(
      params=context_p, prng_key=subkey,
      global_step=states.step) as jax_context:
    # Nothing is allowed to change, except for summaries.
    jax_context.bind(model, model.vars_to_flax_vars(states.mdl_vars),
                     [base_layer.SCOPE_AUX_LOSS])

    # Add a summary for learning rate
    learning_rate = learner.optimizer.get_learning_rate(states.step)
    base_layer.add_summary('learning_rate', learning_rate)

    # Apply gradient transformations.
    transformed_grads, new_opt_states = learner.update_states(
        grads, states.opt_states[0], states.mdl_vars, var_weight_params)

    # Gradient descent on learnable vars.
    mdl_vars = learner.apply_gradient(states.mdl_vars, transformed_grads,
                                      var_is_learnable)

    def _synchronize_vars_using_mean(new_var: NestedMap,
                                     old_var: NestedMap) -> NestedMap:
      """Synchronize variables across a replica by averaging."""
      delta = new_var - old_var
      delta_mean = jax.lax.pmean(delta, axis_name=data_parallel_axis_name)
      updated_var = old_var + delta_mean
      return updated_var

    def _update_non_learnable_var(old_var: NestedMap, new_var: NestedMap,
                                  var_params: ParamsT) -> NestedMap:
      """Update non-trainable variables, using cross-replica synchronization.

      Args:
        old_var: Nested map of old variables.
        new_var: Nested map of new variables.
        var_params: Nested map of var param attributes such as whether a
          variable is trainable or requires synchronization across replicas.

      Returns:
        Updated variables.

      Raises:
        ValueError if no synchronization method is provided for non-trainable
        variables.
      """
      if not base_layer.var_not_trainable(var_params):
        assert new_var is None
        return old_var
      elif not in_pmap:
        # No aggregation is needed.
        assert new_var is not None
        return new_var
      elif base_layer.var_requires_mean_sync(var_params):
        assert new_var is not None
        return _synchronize_vars_using_mean(new_var, old_var)
      else:
        raise ValueError('Non-trainable variables must have a cross-replica '
                         'synchronization method specified.')

    var_weight_params = model.vars
    tf.nest.assert_same_structure(mdl_vars, var_weight_params)
    mdl_vars = tf.nest.map_structure(_update_non_learnable_var, mdl_vars,
                                     fwd_updated_vars, var_weight_params)

    new_states = states.new_state(
        mdl_vars=mdl_vars, opt_states=[new_opt_states])
    # Finally fetch all backward summary tensors. We do not aggregate the scalar
    # summaries with pmean because the grads are already psum-ed.
    bwd_summary_tensors = base_layer.all_summaries()

  summary_tensors = NestedMap(
      fwd_summary_tensors=fwd_summary_tensors,
      bwd_summary_tensors=bwd_summary_tensors)

  return (new_states, weighted_loss, mean_metrics, per_example_out,
          summary_tensors)
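
As a toy illustration of the per-replica metric aggregation above, with the replica axis made explicit as axis 0 of ordinary arrays (jnp.sum over that axis plays the role of jax.lax.psum inside pmap; the values are made up):

import jax.numpy as jnp

values = jnp.array([2.0, 4.0])    # metric value on each of 2 replicas
weights = jnp.array([1.0, 3.0])   # corresponding weights

sum_value = jnp.sum(values * weights)          # psum(value * weight)
sum_weight = jnp.sum(weights)                  # psum(weight)
mean_metric = sum_value / (sum_weight + 1e-8)  # weighted mean across replicas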
Example 7
    def fprop(self,
              inputs: JTensor,
              paddings: Optional[JTensor] = None) -> Tuple[JTensor, JTensor]:
        """Computes distances of the given input 'x' to all centroids.

    Args:
      inputs: Input tensor of shape [B, L, N, H] or [B, L, D].
      paddings: If not None, a tensor of shape [B, L]. The padding tensor is
        supplied when we want certain tokens to not affect the centroids.

    Returns:
      dists: "distances" of the given input 'x' to all centroids.
             Shape [B, L, N, K].
      nearest_centroid: The inputs with the input embeddings replaced by the
             centroid embeddings, it has the same shape as the inputs i.e.,
             [B, L, N, H].
    """
        p = self.params
        theta = self.local_theta()
        inputs = self._cast_to_fprop_dtype(inputs)
        inputs_shape = inputs.shape
        if len(inputs_shape) == 3:
            inputs = jnp.reshape(inputs, [
                inputs_shape[0], inputs_shape[1], p.num_heads, p.dim_per_head
            ])

        if paddings is not None:
            # Shape [B, L, 1, 1]
            paddings_4d = paddings[:, :, jnp.newaxis, jnp.newaxis]

        dists = -2 * jnp.einsum('BLNH, NKH -> BLNK', inputs, theta.means)
        # [B, L, N, 1]
        inputs_norm_sq = jnp.sum(jnp.square(inputs), axis=-1, keepdims=True)
        # [N, K]
        means_norm_sq = jnp.sum(jnp.square(theta.means),
                                axis=-1,
                                keepdims=False)
        # [1, 1, N, K]
        means_norm_sq = means_norm_sq[jnp.newaxis, jnp.newaxis, :, :]
        dists += inputs_norm_sq + means_norm_sq

        # Shape [B, L, N, K], the same as 'dists' above.
        nearest_one_hot = jax.nn.one_hot(jnp.argmin(dists, axis=-1),
                                         p.num_clusters,
                                         dtype=theta.means.dtype)

        # Apply paddings.
        if paddings is not None:
            nearest_one_hot *= (1 - paddings_4d)

        # Same shape as the input [B, L, N, H].
        nearest_centroid = jnp.einsum('BLNK, NKH -> BLNH', nearest_one_hot,
                                      theta.means)

        means_norm = jnp.linalg.norm(theta.means, ord=2, axis=-1)
        base_layer.add_summary('k_means/centroid/l2_norm_avg',
                               jnp.mean(means_norm))
        base_layer.add_summary('k_means/centroid/l2_norm_min',
                               jnp.min(means_norm))
        base_layer.add_summary('k_means/centroid/l2_norm_max',
                               jnp.max(means_norm))

        if not self.do_eval:
            # To update the centroids (self.vars.means), we apply gradient descent on
            # the mini-batch of input, which yields the following:
            #   new_centroid = centroid + (1 - decay) * (x_mean - centroid)
            # where x_mean is the average over all the input vectors closest to this
            # centroid.

            # Sum away batch and sequence length dimensions to get per cluster count.
            # Shape: [N, K]
            per_cluster_count = jnp.sum(nearest_one_hot, axis=[0, 1])
            base_layer.add_summary('k_means/centroid/avg_cluster_count',
                                   jnp.mean(per_cluster_count))

            # Sum of the input per each closest centroid.
            sum_x = jnp.einsum('BLNK, BLNH -> NKH', nearest_one_hot, inputs)

            # Sum up cluster counts across replicas.

            # If per_cluster_count for a cluster is 0, then 'nearest_one_hot' in that
            # cluster's position will always be 0, hence 'sum_x' in that dimension
            # will be 0.
            new_means = sum_x / (p.epsilon +
                                 jnp.expand_dims(per_cluster_count, axis=-1))
            updated_means = (1.0 - p.decay) * new_means + p.decay * theta.means
            updated_means = jnp.array(updated_means, self.vars.means.dtype)
            self.update_var('means', updated_means)
        return dists, nearest_centroid
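
For reference, a framework-free sketch of the distance computation above: for inputs x of shape [B, L, N, H] and centroids of shape [N, K, H], the code expands ||x - c||^2 = ||x||^2 - 2 x.c + ||c||^2, which the check below verifies on toy shapes (all values here are made up):

import jax.numpy as jnp

B, L, N, H, K = 2, 3, 2, 4, 5
x = jnp.ones((B, L, N, H))
means = jnp.full((N, K, H), 0.5)

dists = -2 * jnp.einsum('BLNH,NKH->BLNK', x, means)
dists += jnp.sum(jnp.square(x), axis=-1, keepdims=True)    # [B, L, N, 1]
dists += jnp.sum(jnp.square(means), axis=-1)[None, None]   # [1, 1, N, K]

# Check against the direct definition of squared distance.
direct = jnp.sum(jnp.square(x[:, :, :, None, :] - means[None, None]), axis=-1)
assert jnp.allclose(dists, direct)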
Example 8
 def add_grad_norm_summary(key, value):
   base_layer.add_summary(f'{learner_name}/var_grad_norm/{key}', value)
Example 9
  def scale_gradients(self, raw_grads: NestedMap) -> Tuple[NestedMap, JTensor]:
    """Scales the gradient.

    Args:
      raw_grads: A nested structure of gradient values.

    Returns:
     A nested structure with the rescaled gradient values.
     A predicate tensor indicating whether the step is valid, i.e., no anomaly
       was detected (e.g. NaN or Inf values, or an excessively large gradient
       norm) and the step should not be skipped.
    """
    p = self.params
    learner_name = self.params.name
    # Compute gradient norm.
    grad_squared = jax.tree_map(lambda x: jnp.sum(x * x), raw_grads)

    if p.grad_norm_individual_vars:
      grad_norms = jax.tree_map(jnp.sqrt, grad_squared)
      var_keys = py_utils.extract_prefixed_keys_from_nested_map(grad_norms)

      def add_grad_norm_summary(key, value):
        base_layer.add_summary(f'{learner_name}/var_grad_norm/{key}', value)

      jax.tree_multimap(add_grad_norm_summary, var_keys, grad_norms)

    grad_squared, _ = jax.tree_flatten(grad_squared)
    grad_squared = jnp.concatenate([x[jnp.newaxis] for x in grad_squared])
    raw_grad_norm = jnp.sqrt(jnp.sum(grad_squared))
    base_layer.add_summary(f'{learner_name}/grad_norm', raw_grad_norm)

    def keep_step(grad_norm):
      keep_threshold = p.skip_step_gradient_norm_value
      if keep_threshold:
        return jnp.logical_and(
            jnp.all(jnp.isfinite(grad_norm)),
            jnp.all(jnp.less(grad_norm, keep_threshold)))
      else:
        return jnp.all(jnp.isfinite(grad_norm))

    def clip_grads(grads, grad_norm):
      if p.optimizer.clip_gradient_norm_to_value:
        assert p.optimizer.clip_gradient_single_norm_to_value == 0.
        grad_scale = jnp.minimum(
            jnp.array(1, grad_norm.dtype),
            jnp.array(p.optimizer.clip_gradient_norm_to_value, grad_norm.dtype)
            / grad_norm)
        grads = jax.tree_map(lambda g: g * grad_scale, grads)
      elif p.optimizer.clip_gradient_single_norm_to_value:
        assert p.optimizer.clip_gradient_norm_to_value == 0.
        grad_single_norm = jax.tree_map(lambda x: jnp.sqrt(jnp.sum(x * x)),
                                        grads)

        def scale_gradient(grad, norm):
          return grad * jnp.minimum(
              jnp.array(1, norm.dtype),
              jnp.array(p.optimizer.clip_gradient_single_norm_to_value,
                        norm.dtype) / norm)

        grads = jax.tree_map(scale_gradient, grads, grad_single_norm)
        grad_scale = jnp.array(1.0)
      else:
        # no clipping is needed.
        grad_scale = jnp.array(1.0)
      return grads, grad_scale

    # Mark the step as invalid if any gradient anomaly is detected (e.g. NaN or
    # Inf values, or an excessively large gradient norm).
    valid_step = keep_step(raw_grad_norm)
    grads, grad_scale = clip_grads(raw_grads, raw_grad_norm)
    base_layer.add_summary('grad_scale', grad_scale)

    return grads, valid_step
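
A minimal sketch of the global-norm clipping branch above: every gradient is scaled by min(1, clip_value / global_norm), so the clipped global norm never exceeds clip_value (the dict and values are made up):

import jax
import jax.numpy as jnp

clip_value = 1.0
grads = {'w': jnp.array([3.0, 4.0]), 'b': jnp.array([0.0])}

leaf_sq = jax.tree_util.tree_map(lambda g: jnp.sum(g * g), grads)
global_norm = jnp.sqrt(sum(jax.tree_util.tree_leaves(leaf_sq)))  # 5.0 here
grad_scale = jnp.minimum(1.0, clip_value / global_norm)          # 0.2
clipped = jax.tree_util.tree_map(lambda g: g * grad_scale, grads)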
Example 10
    def compute_and_update_moments(
            self, inputs: JTensor,
            paddings: JTensor) -> Tuple[JTensor, JTensor, JTensor, JTensor]:
        """Computes moments and updates state.

    Args:
      inputs: The inputs JTensor. Shaped [..., dim].
      paddings: The paddings JTensor. Shaped [..., 1], with the same rank as the
        input JTensor.

    Returns:
      Tuple of (mean, variance, beta, gamma).
    """
        p = self.params
        theta = self.local_theta()
        if self.do_eval:
            # The mean and variance used for normalization.
            norm_mean, norm_variance = theta.moving_mean, theta.moving_variance
            base_layer.add_summary('moving_mean', theta.moving_mean)
            base_layer.add_summary('moving_variance', theta.moving_variance)
        else:
            rank = inputs.ndim
            reduce_over_dims = list(range(0, rank - 1))
            mean, variance = compute_moments(
                inputs,
                paddings,
                reduce_over_dims,
                enable_cross_replica_sum_on_tpu=p.enable_cross_replica_sum_on_tpu,
                keepdims=True)

            new_moving_mean = (
                theta.moving_mean * p.decay + mean * (1.0 - p.decay))
            self.update_var('moving_mean', new_moving_mean)
            new_moving_variance = (
                theta.moving_variance * p.decay + variance * (1.0 - p.decay))
            self.update_var('moving_variance', new_moving_variance)

            # Add some summaries for visualization.
            base_layer.add_summary('mean', mean)
            base_layer.add_summary('variance', variance)
            base_layer.add_summary('moving_mean', theta.moving_mean)
            base_layer.add_summary('moving_variance', theta.moving_variance)
            if p.use_moving_avg_in_training:
                # Use the global statistics for normalization.
                norm_mean = theta.moving_mean
                norm_variance = theta.moving_variance
            else:
                # Use the batch statistics for normalization.
                norm_mean = mean
                norm_variance = variance

        beta, gamma = self._get_beta_gamma()
        return norm_mean, norm_variance, beta, gamma
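
compute_moments itself is not shown in this example; the standalone sketch below assumes masked-moment semantics (mean and variance over the non-padded positions) and then applies the same exponential-moving-average update, with toy values throughout:

import jax.numpy as jnp

decay = 0.99
inputs = jnp.array([[[1.0], [2.0], [3.0]]])    # [B=1, L=3, dim=1]
paddings = jnp.array([[[0.0], [0.0], [1.0]]])  # last position is padded

mask = 1.0 - paddings
count = jnp.sum(mask, axis=(0, 1), keepdims=True)
mean = jnp.sum(inputs * mask, axis=(0, 1), keepdims=True) / count
variance = jnp.sum(
    jnp.square(inputs - mean) * mask, axis=(0, 1), keepdims=True) / count

# Exponential moving average, as in the training branch above.
moving_mean = jnp.zeros_like(mean)
moving_variance = jnp.ones_like(variance)
new_moving_mean = moving_mean * decay + mean * (1.0 - decay)
new_moving_variance = moving_variance * decay + variance * (1.0 - decay)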