Example 1
  def _ProcessSingleInput(self, source_id, src, tgt):
    """Performs strings-to-ids on the given input pair via p.tokenizer_dict."""
    _, src_labels, src_paddings = self.StringsToIds(
        tf.reshape(src, [1]), is_source=True, key=self._src_tokenizer_key)
    tgt_ids, tgt_labels, tgt_paddings = self.StringsToIds(
        tf.reshape(tgt, [1]), is_source=False, key=self._tgt_tokenizer_key)
    # Mask positions to 0 where padding is 1, for consistency. We do this because
    # the tokenizer implementation may use the EOS token to pad.
    src_labels = py_utils.ApplyPadding(src_paddings, src_labels)
    tgt_ids = py_utils.ApplyPadding(tgt_paddings, tgt_ids)
    tgt_labels = py_utils.ApplyPadding(tgt_paddings, tgt_labels)

    features = py_utils.NestedMap()
    features.src = py_utils.NestedMap()
    features.src.ids = src_labels
    # ids_indicator is 1 if and only if the output from the tokenizer has a
    # non-padded id. Unlike weights, it will not be mutated and can be used,
    # for example, to determine the actual sequence length.
    features.src.ids_indicator = 1 - src_paddings
    features.tgt = py_utils.NestedMap()
    features.tgt.ids = tgt_ids
    features.tgt.labels = tgt_labels
    features.tgt.ids_indicator = 1 - tgt_paddings

    src_task_id, tgt_task_id = self._GetTaskIds(source_id)
    # task_ids are padded with zeros.
    features.src.task_ids = tf.cast(
        features.src.ids_indicator, dtype=tf.int32) * src_task_id
    features.tgt.task_ids = tf.cast(
        features.tgt.ids_indicator, dtype=tf.int32) * tgt_task_id

    if not py_utils.use_tpu():
      features.src.strs = src
      features.tgt.strs = tgt
    return features.Transform(tf.squeeze)
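The snippet relies on ApplyPadding writing 0 at every position whose padding is 1, so EOS ids used by the tokenizer as padding do not leak into the labels, and on ids_indicator marking the real tokens. A minimal plain-TensorFlow sketch of that masking (illustrative only, TF2 eager assumed; not lingvo's ApplyPadding implementation):

# Illustrative only: the masking Example 1 relies on, in plain TensorFlow.
import tensorflow as tf

tgt_labels = tf.constant([[7, 9, 2, 2]], dtype=tf.int32)   # assume 2 == EOS, also used as pad
tgt_paddings = tf.constant([[0., 0., 0., 1.]])             # 1.0 marks padded slots

# Keep values where padding == 0, write 0 where padding == 1.
masked_labels = tf.where(tgt_paddings > 0.5, tf.zeros_like(tgt_labels), tgt_labels)
ids_indicator = 1. - tgt_paddings                          # 1.0 exactly at real tokens

print(masked_labels.numpy())  # [[7 9 2 0]]
print(ids_indicator.numpy())  # [[1. 1. 1. 0.]]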
Example 2
  def testApplyPaddingToConstWithBroadcast(self):
    with self.session():
      y = py_utils.ApplyPadding(
          tf.convert_to_tensor([[0.0], [1.0], [0.0]]),
          tf.convert_to_tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]),
          tf.convert_to_tensor([[1.0, 2.0], [9.0, 10.0], [5.0, 6.0]])).eval()
      self.assertAllClose(y, [[1.0, 2.0], [9.0, 10.0], [5.0, 6.0]])
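The test shows ApplyPadding broadcasting a [3, 1] padding over [3, 2] values and substituting the third argument at padded rows. A plain-TensorFlow sketch of the same selection (illustrative only, TF2 eager assumed):

# Illustrative only: the select-with-broadcast behavior exercised by the test above.
import tensorflow as tf

padding = tf.constant([[0.], [1.], [0.]])              # [3, 1], broadcast over columns
x = tf.constant([[1., 2.], [3., 4.], [5., 6.]])        # kept where padding == 0
padded = tf.constant([[1., 2.], [9., 10.], [5., 6.]])  # used where padding == 1

y = tf.where(tf.broadcast_to(padding, tf.shape(x)) > 0.5, padded, x)
print(y.numpy())  # [[1. 2.] [9. 10.] [5. 6.]], matching the assertion above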
Example 3
def AddMultiCurveSubplot(fig,
                         tensors,
                         paddings,
                         labels,
                         xlabels=None,
                         **kwargs):
    """Adds a multi curve subplot to Matplotlib figure.

  Plots one line for each entry in tensors and assigns a plot label legend.

  Args:
    fig: The Matplotlib figure.
    tensors: List of tensors of shape [batch, length]
    paddings: Paddings for 'tensors' with shape [batch, length] with 0. in valid
      positions and 1. in invalid.
    labels: A list of tensor names (strings) of the same length as 'tensors'.
    xlabels: A string tensor of shape [batch] with an xlabel per batch.
    **kwargs: With optional, title, xlabel, ylabel, fontsize.
  """
    data = []
    row_labels = []
    for t, l in zip(tensors, labels):
        if t is not None:
            data.append(py_utils.ApplyPadding(paddings, t))
            row_labels.append(l)
    shape = py_utils.GetShape(data[0], 2)
    data = tf.reshape(tf.concat(data, -1), [shape[0], len(data), shape[1]])

    args = [data, py_utils.LengthsFromPaddings(paddings)]
    if xlabels is not None:
        args.append(xlabels)
    fig.AddSubplot(args,
                   plot_func=_AddMultiCurveRowPlots,
                   row_labels=row_labels,
                   **kwargs)
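The subplot is given per-row sequence lengths derived from the paddings (py_utils.LengthsFromPaddings). A minimal sketch of that derivation, assuming valid positions form a contiguous prefix (illustrative only, not lingvo's implementation):

# Illustrative only: lengths from a 0/1 padding matrix with contiguous valid prefixes.
import tensorflow as tf

paddings = tf.constant([[0., 0., 0., 1., 1.],
                        [0., 0., 0., 0., 0.],
                        [0., 1., 1., 1., 1.]])
lengths = tf.cast(tf.reduce_sum(1. - paddings, axis=1), tf.int32)
print(lengths.numpy())  # [3 5 1]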
Example 4
 def _ComputeBN(self, inputs, paddings, gamma, beta, norm_mean,
                norm_variance):
     p = self.params
     with tf.control_dependencies([
             py_utils.assert_greater_equal(norm_variance,
                                           tf.zeros_like(norm_variance)),
             py_utils.assert_shape_match([tf.shape(inputs)[-1]],
                                         tf.shape(norm_mean)),
             py_utils.assert_shape_match([tf.shape(inputs)[-1]],
                                         tf.shape(norm_variance)),
     ]):
         if p.use_fused_batch_norm_for_eval and (self.do_eval
                                                 or p.freeze_bn_stats):
             bn_output, _, _ = nn.fused_batch_norm(inputs,
                                                   gamma,
                                                   beta,
                                                   norm_mean,
                                                   norm_variance,
                                                   self._epsilon,
                                                   is_training=False)
         else:
             bn_output = tf.nn.batch_normalization(inputs, norm_mean,
                                                   norm_variance, beta,
                                                   gamma, self._epsilon)
         if p.set_padded_output_to_zero:
             bn_output = py_utils.ApplyPadding(paddings, bn_output)
     return bn_output
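For reference, a minimal plain-TensorFlow sketch of the set_padded_output_to_zero effect: normalize, then zero out padded frames (illustrative only, TF2 eager assumed; toy statistics, not the layer's moment bookkeeping):

# Illustrative only: zeroing normalized outputs at padded positions.
import tensorflow as tf

inputs = tf.random.normal([2, 4, 8])             # [batch, time, features]
paddings = tf.constant([[0., 0., 1., 1.],
                        [0., 0., 0., 1.]])       # [batch, time], 1.0 == padded
# Toy statistics: moments over the whole batch, ignoring the padding.
mean, variance = tf.nn.moments(inputs, axes=[0, 1])
bn_output = tf.nn.batch_normalization(inputs, mean, variance,
                                      offset=None, scale=None,
                                      variance_epsilon=1e-3)
bn_output *= 1. - paddings[:, :, tf.newaxis]     # zero out padded frames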
Example 5
  def testCausalConv2DLayerStridedWithPaddingFPropV2(self, seq_len):
    """Check strided convs get the same values for different length dim."""
    with self.session(use_gpu=True):
      batch_size = 5
      expected_seq_len = 3

      params = conv_layers.CausalConv2DLayerWithPadding.Params()
      params.v2_padding = True
      params.weight_norm = False
      params.filter_stride = [2, 2]
      params.name = 'conv'
      params.filter_shape = [3, 1, 1, 1]
      params.params_init = py_utils.WeightInit.Constant(1.0)
      conv_layer = params.Instantiate()

      # Set up the padding for the sequence length. (starting at 5).
      in_padding = tf.constant([
          [0, 0, 0, 0, 0],
          [0, 0, 0, 0, 1],
          [0, 0, 0, 1, 1],
          [0, 0, 1, 1, 1],
          [0, 1, 1, 1, 1],
      ], tf.float32)
      in_padding = tf.pad(
          in_padding, [[0, 0], [0, seq_len - 5]], constant_values=1.0)

      inputs = 1.0 + tf.tile(
          tf.reshape(tf.range(seq_len, dtype=tf.float32), [1, seq_len, 1, 1]),
          [batch_size, 1, 3, 1])
      inputs = py_utils.ApplyPadding(
          tf.reshape(in_padding, [batch_size, seq_len, 1, 1]), inputs)

      inputs = py_utils.Debug(inputs)

      output, out_padding = conv_layer.FPropDefaultTheta(inputs, in_padding)

      output = py_utils.Debug(output)
      out_padding = py_utils.Debug(out_padding)

      self.evaluate(tf.global_variables_initializer())
      output, out_padding = self.evaluate([output, out_padding])

      self.assertEqual((batch_size, expected_seq_len, 2, 1), output.shape)
      self.assertAllClose([
          [0, 0, 0],
          [0, 0, 1],
          [0, 0, 1],
          [0, 1, 1],
          [0, 1, 1],
      ], out_padding)

      self.assertAllClose(
          [
              [[[1], [1]], [[6], [6]], [[12], [12]]],
              [[[1], [1]], [[6], [6]], [[7], [7]]],
              [[[1], [1]], [[6], [6]], [[3], [3]]],  # NOTE: not padded.
              [[[1], [1]], [[3], [3]], [[0], [0]]],
              [[[1], [1]], [[1], [1]], [[0], [0]]],
          ],
          output)
Example 6
  def testApplyPaddingToZeroWithoutBroadcastArithmetic(self):
    with self.session():
      y = py_utils.ApplyPadding(
          tf.convert_to_tensor([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]),
          tf.convert_to_tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]),
          use_select=False).eval()
      self.assertAllClose(y, [[1.0, 2.0], [0.0, 4.0], [5.0, 0.0]])
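With use_select=False the padding is applied arithmetically, i.e. the result equals x * (1 - padding). A plain-TensorFlow sketch of that equivalence (illustrative only, TF2 eager assumed):

# Illustrative only: the arithmetic masking exercised by the test above.
import tensorflow as tf

padding = tf.constant([[0., 0.], [1., 0.], [0., 1.]])
x = tf.constant([[1., 2.], [3., 4.], [5., 6.]])
y = x * (1. - padding)
print(y.numpy())  # [[1. 2.] [0. 4.] [5. 0.]], matching the assertion above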
Example 7
    def StreamStep(self, theta, inputs, paddings, state0):
        """Apply a single step of convolution to input_tensor.

    Only supports 1d causal convolution. Doesn't support dilation.

    Args:
      theta: A NestedMap of layer params.
      inputs: A Tensor of shape [b, t, 1, c]
      paddings: A 0/1 valued tensor of shape [b, t].
      state0: A NestedMap of tensors of the same struct as returned by
        zero_state().

    Returns:
      outputs: A Tensor of shape [b, t, 1, c * channel_multiplier]
      padding: the same as input paddings.
      state1: A NestedMap of the same struct as input state
    """
        p = self.params
        assert p.filter_shape[1] == 1, (
            'StreamStep only supports 1d causal convolution.')
        assert p.filter_stride[0] == 1, (
            'StreamStep doesn\'t support striding')
        assert p.dilation_rate == (1,
                                   1), ('StreamStep doesn\'t support dilation')

        with tf.name_scope(p.name):
            inputs = py_utils.HasShape(inputs, [-1, -1, 1, p.filter_shape[2]])
            paddings = py_utils.HasShape(paddings,
                                         py_utils.GetShape(inputs)[:2])
            q = py_utils.GetShape(paddings)[1]

            padded_inputs = py_utils.ApplyPadding(
                py_utils.AppendDims(paddings, 2), inputs)

            concat_inputs = tf.concat([state0.context, padded_inputs], axis=1)
            outputs = tf.nn.depthwise_conv2d(concat_inputs,
                                             self._GetWeight(theta),
                                             strides=(1, 1, 1, 1),
                                             dilations=(1, 1),
                                             data_format='NHWC',
                                             padding='VALID')
            if p.bias:
                outputs = tf.nn.bias_add(outputs, theta.b)
            new_context = concat_inputs[:, q:]
            return outputs, paddings, py_utils.NestedMap(context=new_context)
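A minimal sketch of the streaming-state bookkeeping in StreamStep, under the assumption that state0.context caches the last filter_shape[0] - 1 frames (the sizes below are hypothetical, and this is not the layer's API):

# Illustrative only: keep the last (filter_width - 1) frames as the next context.
import tensorflow as tf

filter_width = 3
b, t, c = 2, 4, 8
context0 = tf.zeros([b, filter_width - 1, 1, c])   # analogue of state0.context
frames = tf.random.normal([b, t, 1, c])            # current chunk, already masked

concat = tf.concat([context0, frames], axis=1)     # [b, t + filter_width - 1, 1, c]
# A VALID depthwise conv over `concat` with a width-3 filter yields t output frames;
# the new state keeps the last (filter_width - 1) frames, i.e. concat[:, t:].
context1 = concat[:, t:]
print(context1.shape)  # (2, 2, 1, 8)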
Example 8
  def FProp(self, theta, prepared_inputs, step_inputs, padding, state0):
    """Produces a context vector from the attention algorithm.

    The context vector is a summary of the inputs from external_inputs
    which the attention algorithm has determined would be useful for decoding
    the next output.

    Args:
      theta: A NestedMap containing weights' values of this layer and its
        children layers.
      prepared_inputs: A set of encoded tensors that have been pre-processed by
        PrepareExternalInputs.
      step_inputs: A NestedMap containing an 'inputs' tensor with the query
        vector to use.
      padding: A [batch, 1] 0/1 float tensor, where 1.0 means that this batch
        slot is not used.
      state0: A NestedMap of state, either produced by ZeroState or a previous
        invocation of this graph.

    Returns:
      output, state1, defined as follows:
      - output: a NestedMap containing the attention context tensor and probs,
        the attention probabilities for each input vector.
      - state1: a NestedMap of state to be used in subsequent invocations of
        this graph.
    """
    (new_atten_context, new_atten_probs,
     new_atten_states) = self.atten.ComputeContextVectorWithSource(
         theta.atten,
         prepared_inputs.packed_src,
         tf.concat(step_inputs.inputs, axis=1),
         attention_state=state0.atten_state)
    new_atten_probs = py_utils.ApplyPadding(padding, new_atten_probs)
    output = py_utils.NestedMap(
        context=new_atten_context, probs=new_atten_probs)
    state1 = py_utils.NestedMap(
        atten_context=new_atten_context, atten_state=new_atten_states)
    return output, state1
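A minimal sketch of how the [batch, 1] padding zeroes the attention probabilities of unused batch slots (illustrative arithmetic form; the snippet itself goes through py_utils.ApplyPadding):

# Illustrative only: broadcast a [batch, 1] padding over [batch, src_len] probabilities.
import tensorflow as tf

padding = tf.constant([[0.], [1.]])                 # batch slot 1 is unused
atten_probs = tf.constant([[0.2, 0.8], [0.5, 0.5]])
atten_probs = atten_probs * (1. - padding)          # same effect as the masking above
print(atten_probs.numpy())  # [[0.2 0.8] [0.  0. ]]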
Example 9
  def FProp(self, theta, x, paddings=None, update=False):
    """Computes distances of the given input 'x' to all centroids.

    This implementation applies layer normalization on 'x' internally first,
    and the returned 'dists' is computed using the normalized 'x'.

    Args:
      theta: A `.NestedMap` of weights' values of this layer.
      x: A tensor of shape [B, L, N, H].
      paddings: If not None, a tensor of shape [B, L].
      update: bool, whether to update centroids using x.

    Returns:
      dists: "distances" of the given input 'x' to all centroids.
             Shape [B, L, N, K].
      k_means_loss: the average squared Euclidean distances to the closest
                    centroid, a scalar.
    """
    p = self.params
    if paddings is None:
      paddings = tf.zeros_like(x[:, :, 0, 0])
    # Shape [B, L, 1, 1]
    paddings_4d = paddings[:, :, None, None]

    if p.apply_layer_norm:
      x = KMeansClusteringForAtten.LayerNorm(x, p.epsilon)

    # Since 'x' is normalized (but theta.means is not), we use the negative dot
    # product to approximate the Euclidean distance here.
    dists = -tf.einsum('BLNH, NKH -> BLNK', x, theta.means)

    # For padded positions we update the distances to very large numbers.
    very_large_dists = tf.ones_like(dists) * tf.constant(
        0.1, dtype=dists.dtype) * dists.dtype.max
    paddings_tiled = tf.tile(paddings_4d, [1, 1, p.num_heads, p.num_clusters])
    dists = tf.where(paddings_tiled > 0.0, very_large_dists, dists)

    # Shape [B, L, N, K], the same as 'dists' above.
    nearest_one_hot = tf.one_hot(
        tf.math.argmin(dists, axis=-1),
        p.num_clusters,
        dtype=py_utils.FPropDtype(p))
    # Same shape as the input 'x'.
    nearest_centroid = tf.einsum('BLNK, NKH -> BLNH', nearest_one_hot,
                                 theta.means)
    diff = tf.math.squared_difference(x, tf.stop_gradient(nearest_centroid))
    diff = py_utils.ApplyPadding(paddings_4d, diff)
    diff = tf.math.reduce_mean(diff, axis=2)

    # The commitment loss, which, when backpropagated against, encourages the
    # 'x' values to commit to their chosen centroids.
    k_means_loss = tf.math.reduce_sum(diff) / tf.math.reduce_sum(1.0 - paddings)
    summary_utils.scalar('k_means/squared_distance_loss', k_means_loss)

    # TODO(zhouwk): investigate normalizing theta.means after each update.
    means_norm = tf.norm(theta.means)
    summary_utils.scalar('k_means/centroid_l2_norm/min',
                         tf.math.reduce_min(means_norm))
    summary_utils.scalar('k_means/centroid_l2_norm/mean',
                         tf.math.reduce_mean(means_norm))

    if not update:
      return dists, k_means_loss

    # To update the centroids (self.vars.means), we apply gradient descent on
    # the mini-batch of input 'x', which yields the following:
    #   new_centroid = centroid + (1 - decay) * (x_mean - centroid)
    # where x_mean is the average over all the input vectors closest to this
    # centroid.
    #
    # Note that this approach is equivalent to backprop via
    #    loss = tf.math.reduce_mean(
    #        tf.math.squared_difference(tf.stop_gradient(x), nearest_centroid)),
    # except that here the learning rate is independently set via 'decay'.

    # Ensure that the padded positions are not used to update the centroids.
    nearest_one_hot = py_utils.ApplyPadding(paddings_4d, nearest_one_hot)

    # Sum away batch and sequence length dimensions to get per cluster count.
    # Shape: [N, K]
    per_cluster_count = tf.reduce_sum(nearest_one_hot, axis=[0, 1])
    summary_utils.histogram('k_means/per_cluster_vec_count', per_cluster_count)

    # Sum of the input 'x' per each closest centroid.
    sum_x = tf.einsum('BLNK, BLNH -> NKH', nearest_one_hot, x)

    if py_utils.use_tpu():
      per_cluster_count = tf.tpu.cross_replica_sum(per_cluster_count)
      sum_x = tf.tpu.cross_replica_sum(sum_x)

    # If per_cluster_count for a cluster is 0, then 'nearest_one_hot' in that
    # cluster's position will always be 0, hence 'sum_x' in that dimension will
    # be 0.
    new_means = sum_x / tf.maximum(
        tf.constant(1.0, dtype=per_cluster_count.dtype),
        tf.expand_dims(per_cluster_count, axis=-1))

    # We use an exponential moving average. TODO(zhouwk): investigate smoothing
    # this over an exponentially moving averaged per-cluster count.
    #
    # Note that we intentionally do not normalize the means after this update
    # as empirically this works better.
    update_means_diff = tf.cast((1.0 - p.decay) * (new_means - theta.means),
                                self.vars.means.dtype)
    return py_utils.with_dependencies(
        [tf.assign_add(self.vars.means, update_means_diff)],
        dists), k_means_loss
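A minimal sketch of the exponential-moving-average centroid update described in the comments above, with hypothetical [K, H] shapes and ignoring padding and cross-replica sums:

# Illustrative only: new_centroid = centroid + (1 - decay) * (x_mean - centroid).
import tensorflow as tf

decay = 0.999
means = tf.constant([[0., 0.], [1., 1.]])       # [K, H] current centroids
x_mean = tf.constant([[0.2, 0.4], [1., 3.]])    # per-cluster mean of assigned vectors
new_means = means + (1. - decay) * (x_mean - means)
# Equivalent to the snippet's assign_add of (1 - decay) * (new_means - theta.means).
print(new_means.numpy())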
Example 10
    def _StreamMoments(self, inputs, paddings, cached_sum, cached_count,
                       cached_var):
        """Computes mean and variance over the valid data points in inputs.

    Args:
      inputs: [B, T, F, N, G] or [B, T, N, G]
      paddings: [B, T, 1, 1, 1] or [B, T, 1, 1]
      cached_sum: [B, 1, 1, N, 1] or [B, 1, N, 1]
      cached_count: same shape as cached_sum.
      cached_var: same shape as cached_sum.

    Returns:
      mean: [B, T, 1, N, 1] or [B, T, N, 1]
      variance: same shape as mean.
      new_cached_sum: same shape as cached_sum.
      new_cached_count: same shape as cached_count.
    """
        tf.logging.vlog(1, 'inputs: %r', inputs)
        tf.logging.vlog(1, 'paddings: %r', paddings)
        tf.logging.vlog(1, 'cached_sum: %r', cached_sum)
        tf.logging.vlog(1, 'cached_count: %r', cached_count)

        inputs = py_utils.ApplyPadding(paddings, inputs, use_select=False)

        input_rank = py_utils.GetRank(inputs)
        assert input_rank is not None, (f'inputs rank must be static for '
                                        f'{repr(inputs)}')
        reduce_over_dims = list(range(input_rank))
        # Skip B, T, and N. Reduce {F,G} or just G.
        reduce_over_dims = reduce_over_dims[2:-2] + reduce_over_dims[-1:]
        tf.logging.vlog(1, 'reduce_over_dims: %s', reduce_over_dims)

        # [B, T, 1, N, 1] or [B, T, N, 1]
        sum_v = tf.reduce_sum(inputs, reduce_over_dims, keepdims=True)
        sum_v = tf.math.cumsum(sum_v, axis=1)
        sum_v += cached_sum

        # [B, T, 1, 1, 1] or [B, T, 1, 1]
        mask = tf.cast(1.0 - paddings, inputs.dtype)
        count_v = tf.reduce_sum(mask, reduce_over_dims, keepdims=True)
        count_v = tf.math.cumsum(count_v, axis=1)
        input_shape = py_utils.GetShape(inputs)
        if input_rank == 4:
            # F * G
            multiplier = input_shape[-1] * input_shape[-3]
        else:
            # G
            multiplier = input_shape[-1]
        count_v *= multiplier
        count_v += cached_count

        tf.logging.vlog(1, 'sum_v: %r', sum_v)
        tf.logging.vlog(1, 'count_v: %r', count_v)

        mean = sum_v / tf.maximum(count_v, 1.0)

        sum_vv = tf.reduce_sum(py_utils.ApplyPadding(
            paddings,
            tf.math.squared_difference(inputs, mean),
            use_select=False),
                               reduce_over_dims,
                               keepdims=True)
        sum_vv = tf.math.cumsum(sum_vv, axis=1)
        sum_vv += cached_var

        cached_sum = sum_v[:, -1:]
        cached_count = count_v[:, -1:]
        cached_var = sum_vv[:, -1:]

        variance = py_utils.with_dependencies([
            py_utils.assert_greater_equal(sum_vv, tf.cast(0, sum_vv.dtype)),
        ], sum_vv / tf.maximum(count_v, 1.0))
        return mean, variance, cached_sum, cached_count, cached_var
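A minimal sketch of the streaming (cumulative) mean computed above, reduced to a [B, T] signal so the masking, running sum, and running count are easy to follow (illustrative only, TF2 eager assumed):

# Illustrative only: cumulative mean over valid positions of a [B, T] signal.
import tensorflow as tf

x = tf.constant([[1., 2., 3., 4.]])
paddings = tf.constant([[0., 0., 1., 0.]])        # position 2 is padded
masked = x * (1. - paddings)

sum_v = tf.math.cumsum(masked, axis=1)            # running sum of valid values
count_v = tf.math.cumsum(1. - paddings, axis=1)   # running count of valid values
mean = sum_v / tf.maximum(count_v, 1.0)
print(mean.numpy())  # [[1.  1.5  1.5  2.3333]]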
Example 11
  def testConv2DLayerStridedWithPaddingFProp(self, seq_len):
    """Check strided convs get the same values for different length dim."""
    # TODO(isaace): THIS TEST SHOWS THAT THERE IS A BUG IN THE CODE.
    with self.session(use_gpu=True):
      batch_size = 3
      expected_seq_len = 3

      params = conv_layers.Conv2DLayerWithPadding.Params()
      params.weight_norm = False
      params.filter_stride = [2, 2]
      params.name = 'conv'
      params.filter_shape = [3, 3, 1, 1]
      params.params_init = py_utils.WeightInit.Constant(1.0)
      conv_layer = params.Instantiate()

      # Set up the padding for the sequence length. (starting at 5).
      in_padding = tf.constant([
          [0, 0, 0, 0, 0],
          [0, 0, 0, 0, 1],
          [0, 0, 0, 1, 1],
      ], tf.float32)
      in_padding = tf.pad(
          in_padding, [[0, 0], [0, seq_len - 5]], constant_values=1.0)

      inputs = 1.0 + tf.tile(
          tf.reshape(tf.range(seq_len, dtype=tf.float32), [1, seq_len, 1, 1]),
          [batch_size, 1, 3, 1])
      inputs = py_utils.ApplyPadding(
          tf.reshape(in_padding, [batch_size, seq_len, 1, 1]), inputs)

      inputs = py_utils.Debug(inputs)

      output, out_padding = conv_layer.FPropDefaultTheta(inputs, in_padding)

      output = py_utils.Debug(output)
      out_padding = py_utils.Debug(out_padding)

      self.evaluate(tf.global_variables_initializer())
      output, out_padding = self.evaluate([output, out_padding])

      self.assertEqual((batch_size, expected_seq_len, 2, 1), output.shape)
      self.assertAllClose([
          [0, 0, 1],
          [0, 0, 1],
          [0, 1, 1],
      ], out_padding)

      # This shows a bug in the implementation; the output should be the same
      # for both lengths. There are also bugs where the output does not carry
      # the correct padding.
      if seq_len == 5:
        self.assertAllClose([
            [[[6], [6]], [[18], [18]], [[18], [18]]],
            [[[6], [6]], [[18], [18]], [[8], [8]]],
            [[[6], [6]], [[10], [10]], [[0], [0]]],
        ], output)
      elif seq_len == 6:
        self.assertAllClose([
            [[[12], [12]], [[24], [24]], [[10], [10]]],
            [[[12], [12]], [[14], [14]], [[0], [0]]],
            [[[12], [12]], [[6], [6]], [[0], [0]]],
        ], output)
      else:
        raise ValueError(f'Test does not handle length {seq_len}')
Example 12
  def _StreamMoments(self, inputs, paddings, cached_sum, cached_count,
                     cached_var):
    """Computes mean and variance over the valid data points in inputs.

    Args:
      inputs: [B, T, F, N, G] or [B, T, N, G]
      paddings: [B, T, 1, 1, 1] or [B, T, 1, 1] (same rank as inputs)
      cached_sum: [B, N]
      cached_count: [B, 1]
      cached_var: [B, N]

    Returns:
      mean: [B, T, 1, N, 1] or [B, T, N, 1] (same rank as inputs)
      variance: same shape as mean.
      new_cached_sum: same shape as cached_sum.
      new_cached_count: same shape as cached_count.
      new_cached_var: same shape as cached_var.
    """
    tf.logging.vlog(1, 'inputs: %r', inputs)
    tf.logging.vlog(1, 'paddings: %r', paddings)
    tf.logging.vlog(1, 'cached_sum: %r', cached_sum)
    tf.logging.vlog(1, 'cached_count: %r', cached_count)
    tf.logging.vlog(1, 'cached_var: %r', cached_var)

    input_rank = py_utils.GetRank(inputs)
    paddings = py_utils.HasRank(paddings, input_rank)
    cached_sum = py_utils.HasRank(cached_sum, 2)
    cached_count = py_utils.HasRank(cached_count, 2)
    cached_var = py_utils.HasRank(cached_var, 2)

    input_shape = py_utils.GetShape(inputs)
    output_shape = input_shape[:]
    if input_rank == 4:
      # Skip {B,T,N}. Reduce just G.
      reduce_over_dims = [3]
      multiplier = input_shape[3]
      output_shape[3] = 1
    else:
      assert input_rank == 5
      # Skip {B,T,N}. Reduce {F,G}.
      reduce_over_dims = [2, 4]
      multiplier = input_shape[2] * input_shape[4]
      output_shape[2] = 1
      output_shape[4] = 1

    # [B, T, N]
    sum_v = tf.reduce_sum(
        py_utils.ApplyPadding(paddings, inputs),
        reduce_over_dims,
        keepdims=False)
    sum_v = tf.math.cumsum(sum_v, axis=1)
    sum_v += cached_sum[:, tf.newaxis, :]

    # [B, T, 1]
    count_v = tf.reduce_sum(
        py_utils.ApplyPadding(
            paddings, tf.cast(multiplier, inputs.dtype), ensure_shape=False),
        reduce_over_dims,
        keepdims=False)
    count_v = tf.math.cumsum(count_v, axis=1)
    count_v += cached_count[:, tf.newaxis, :]

    # [B, T, 1, N, 1] or [B, T, N, 1]
    mean = tf.reshape(sum_v / tf.maximum(count_v, 1.0), output_shape)

    # [B, T, N]
    sum_vv = tf.reduce_sum(
        py_utils.ApplyPadding(paddings,
                              tf.math.squared_difference(inputs, mean)),
        reduce_over_dims,
        keepdims=False)
    sum_vv = tf.math.cumsum(sum_vv, axis=1)
    sum_vv += cached_var[:, tf.newaxis, :]

    # [B, N]
    cached_sum = sum_v[:, -1]
    # [B, 1]
    cached_count = count_v[:, -1]
    # [B, N]
    cached_var = sum_vv[:, -1]

    # [B, T, 1, N, 1] or [B, T, N, 1]
    variance = tf.reshape(sum_vv / tf.maximum(count_v, 1.0), output_shape)

    tf.logging.vlog(1, 'sum_v: %r', sum_v)
    tf.logging.vlog(1, 'count_v: %r', count_v)
    tf.logging.vlog(1, 'sum_vv: %r', sum_vv)

    return mean, variance, cached_sum, cached_count, cached_var