示例#1
0
def ComputeMoments(inputs,
                   padding,
                   reduce_over_dims,
                   cumulative_axis=None,
                   enable_cross_replica_sum_on_tpu=False,
                   keepdims=False):
    """Returns the mean and variance of `inputs`, weighted by `1.0 - padding`.

    Args:
      inputs: Tensor holding the data points.
      padding: Tensor of the same rank as `inputs` (asserted). The weight of
        each position is `1.0 - padding`, which is asserted to be >= 0.
      reduce_over_dims: Dimensions that are summed over; also used to gather
        the sizes of `inputs` and of the weights along those dimensions.
      cumulative_axis: If not None, the reduced sums and counts are replaced
        by their cumulative sums along this axis.
      enable_cross_replica_sum_on_tpu: If True and `py_utils.use_tpu()` is
        truthy, sums and counts go through `tf.tpu.cross_replica_sum`.
      keepdims: Forwarded to every `tf.reduce_sum` call.

    Returns:
      A `(mean, variance)` tuple. Both are obtained by dividing by the
      weighted count, which is clamped to be at least 1.0.
    """
    valid = 1.0 - padding
    ranks_agree = py_utils.assert_equal(tf.rank(inputs), tf.rank(valid))
    weights_non_negative = py_utils.assert_greater_equal(
        valid, tf.zeros_like(valid))
    inputs = py_utils.with_dependencies([ranks_agree, weights_non_negative],
                                        inputs)
    accumulate = cumulative_axis is not None

    weighted_total = tf.reduce_sum(
        inputs * tf.cast(valid, inputs.dtype),
        reduce_over_dims,
        keepdims=keepdims)
    valid_count = tf.reduce_sum(valid, reduce_over_dims, keepdims=keepdims)
    if accumulate:
        weighted_total = tf.math.cumsum(weighted_total, axis=cumulative_axis)
        valid_count = tf.math.cumsum(valid_count, axis=cumulative_axis)

    # The weights may have been broadcast against `inputs` in the product
    # above; scale the count by how many input elements each weight covers
    # along the reduced dimensions.
    inputs_extent = tf.reduce_prod(
        tf.gather(tf.shape(inputs), reduce_over_dims))
    weights_extent = tf.reduce_prod(
        tf.gather(tf.shape(valid), reduce_over_dims))
    broadcast_factor = tf.math.truediv(inputs_extent, weights_extent)
    valid_count = valid_count * tf.cast(broadcast_factor, valid_count.dtype)

    if py_utils.use_tpu() and enable_cross_replica_sum_on_tpu:
        weighted_total = tf.tpu.cross_replica_sum(weighted_total)
        valid_count = tf.tpu.cross_replica_sum(valid_count)

    # Clamp so that fully padded reductions divide by 1 rather than 0.
    valid_count = tf.maximum(valid_count, 1.0)
    mean = weighted_total / valid_count

    centered = inputs - mean
    squared_dev_total = tf.reduce_sum(
        centered * centered * valid, reduce_over_dims, keepdims=keepdims)
    if accumulate:
        squared_dev_total = tf.math.cumsum(
            squared_dev_total, axis=cumulative_axis)

    if py_utils.use_tpu() and enable_cross_replica_sum_on_tpu:
        squared_dev_total = tf.tpu.cross_replica_sum(squared_dev_total)

    deviations_non_negative = py_utils.assert_greater_equal(
        squared_dev_total, tf.zeros_like(squared_dev_total))
    variance = py_utils.with_dependencies([deviations_non_negative],
                                          squared_dev_total / valid_count)
    return mean, variance
示例#2
0
    def _Moments(inputs, mask, enable_cross_replica_sum_on_tpu=False):
        """Returns the `mask`-weighted mean and variance of `inputs`.

        Every axis except the last one is reduced.

        Args:
          inputs: Tensor holding the data points.
          mask: Tensor of the same rank as `inputs` (asserted) whose entries
            are asserted to be >= 0; used as per-position weights.
          enable_cross_replica_sum_on_tpu: If True and `py_utils.use_tpu()`
            is truthy, sums and counts go through
            `tf.tpu.cross_replica_sum`.

        Returns:
          A `(mean, variance)` tuple. Both are obtained by dividing by the
          weighted count, which is clamped to be at least 1.0.
        """
        ranks_agree = py_utils.assert_equal(tf.rank(inputs), tf.rank(mask))
        mask_non_negative = py_utils.assert_greater_equal(
            mask, tf.zeros_like(mask))
        inputs = py_utils.with_dependencies([ranks_agree, mask_non_negative],
                                            inputs)

        # All axes but the last are reduced.
        leading_axes = tf.range(0, tf.rank(mask) - 1)
        weighted_total = tf.reduce_sum(inputs * tf.cast(mask, inputs.dtype),
                                       leading_axes)
        valid_count = tf.reduce_sum(mask, leading_axes)

        # The mask may have been broadcast against `inputs` in the product
        # above; scale the count by how many input elements each mask entry
        # covers along the reduced axes.
        per_axis_factor = tf.shape(inputs)[:-1] // tf.shape(mask)[:-1]
        broadcast_factor = tf.reduce_prod(per_axis_factor)
        valid_count = valid_count * tf.cast(broadcast_factor,
                                            valid_count.dtype)

        if py_utils.use_tpu() and enable_cross_replica_sum_on_tpu:
            weighted_total = tf.tpu.cross_replica_sum(weighted_total)
            valid_count = tf.tpu.cross_replica_sum(valid_count)

        # Clamp so that a fully masked reduction divides by 1 rather than 0.
        valid_count = tf.maximum(valid_count, 1.0)
        mean = weighted_total / valid_count

        centered = inputs - mean
        squared_dev_total = tf.reduce_sum(centered * centered * mask,
                                          leading_axes)

        if py_utils.use_tpu() and enable_cross_replica_sum_on_tpu:
            squared_dev_total = tf.tpu.cross_replica_sum(squared_dev_total)

        deviations_non_negative = py_utils.assert_greater_equal(
            squared_dev_total, tf.zeros_like(squared_dev_total))
        variance = py_utils.with_dependencies(
            [deviations_non_negative], squared_dev_total / valid_count)
        return mean, variance
示例#3
0
  def FProp(self, theta, inputs):
    """Apply projection to inputs.

    Records a FLOPs estimate for the projection, then multiplies the last
    dimension of `inputs` by `theta.w`. On TPU with a statically known rank
    below 26 the product is computed with matmul/einsum directly; otherwise
    `inputs` is flattened to 2-D, multiplied, and reshaped back.

    Args:
      theta: A NestedMap object containing weights' values of this layer and its
        children layers.
      inputs: The inputs tensor.  Shaped [..., input_dims].

    Returns:
      Projected inputs.
    """
    p = self.params
    with tf.name_scope(p.name):
      # FLOPs estimate: 2 * prod(leading dims) * input_dims * output_dims.
      # tf.cast is used in place of the deprecated tf.to_int64.
      computation_cost.Add(
          self, 'flops',
          tf.reduce_prod(tf.cast(tf.shape(inputs)[:-1], tf.int64)) * tf.cast(
              symbolic.EvalExpr(symbolic.TENSOR_VALUES,
                                p.input_dims * p.output_dims), tf.int64) * 2)
      use_tpu = py_utils.use_tpu()
      # The static rank can be None when it is unknown at graph-construction
      # time; comparing None with an int raises TypeError on Python 3, so only
      # take the einsum path when the rank is actually known.
      static_rank = None
      if use_tpu and inputs.shape is not None:
        static_rank = inputs.shape.rank
      if static_rank is not None and static_rank < 26:
        # Avoids reshape if feasible and uses Einsum.
        if static_rank == 2:
          return tf.matmul(inputs, theta.w)
        # One letter per leading dim; 'y' and 'z' are reserved for the
        # contracted and the output dims, hence the rank < 26 limit.
        letters = ''.join(chr(x) for x in range(97, 123))  # abc...xyz
        return tf.einsum('{0}y,yz->{0}z'.format(letters[:static_rank - 1]),
                         inputs, theta.w)

      # Generic path: flatten to [-1, input_dim], multiply, restore the
      # leading dims with the new last dim.
      input_dim = py_utils.GetShape(inputs)[-1]
      act = tf.matmul(tf.reshape(inputs, [-1, input_dim]), theta.w)
      output_dim = tf.shape(theta.w)[-1]
      act = tf.reshape(act,
                       tf.concat([tf.shape(inputs)[:-1], [output_dim]], axis=0))
      return act
示例#4
0
 def _InputBatch(self):
   """Builds a batch of consecutive counting values shaped `self.shape`.

   The batch start is derived from a 'CountingInputGenerator' stats counter
   that is incremented by the number of elements in the batch.
   NOTE(review): presumably `IncBy` returns the counter value after the
   increment, so subtracting the element count yields the value before it;
   confirm against `summary_utils.StatsCounter`.

   Returns:
     A NestedMap with `src_ids` (float32 values reshaped to `self.shape`)
     and `tgt_ids` (the sum of `src_ids` over axis 0).
   """
   num_elements = tf.reduce_prod(self.shape)
   stats_counter = summary_utils.StatsCounter('CountingInputGenerator')
   incremented = stats_counter.IncBy(num_elements)
   # The start value is bookkeeping only; keep gradients from flowing into it.
   start = tf.stop_gradient(
       tf.cast(incremented, dtype=tf.int32) - num_elements)
   flat_ids = tf.cast(start + tf.range(num_elements), dtype=tf.float32)
   src_ids = tf.reshape(flat_ids, self.shape)
   tgt_ids = tf.reduce_sum(src_ids, axis=0)
   return py_utils.NestedMap(src_ids=src_ids, tgt_ids=tgt_ids)
示例#5
0
  def FProp(self, theta, inputs):
    """Projects `inputs` with this layer's weight `theta.w`.

    Registers a FLOPs estimate of
    2 * prod(leading dims of inputs) * input_dims * output_dims with
    `computation_cost`, then delegates the projection to
    `py_utils.ProjectLastDim`.

    Args:
      theta: A NestedMap object containing weights' values of this layer and its
        children layers.
      inputs: The inputs tensor.  Shaped [..., input_dims].

    Returns:
      Projected inputs.
    """
    params = self.params
    with tf.name_scope(params.name):
      # Number of rows being projected: the product of all leading dims.
      num_rows = tf.reduce_prod(tf.cast(tf.shape(inputs)[:-1], tf.int64))
      weight_size = tf.cast(
          symbolic.ToTensor(params.input_dims * params.output_dims), tf.int64)
      computation_cost.Add(self, 'flops', num_rows * weight_size * 2)
      projected = py_utils.ProjectLastDim(inputs, theta.w, params.input_dims,
                                          params.output_dims)
    return projected