Example #1
def _get_bn_correction(conv_type, kernel, bias, mu_bt, var_bt, mu_mv, var_mv,
                       gamma, epsilon):
  """Get batchnorm correction params.

     Before freeze:
       corr_scale = sigma_bt / sigma_mv
       corr_recip = 1 / corr_scale
       corr_offset = 0
     After freeze:
       corr_scale = sigma_bt / sigma_mv
       corr_recip = 1
       corr_offset = gamma * ( (mu_bt - bias)/sigma_bt - (mu_mv - bias)/sigma_mv)
  """
  sigma_bt = math_ops.rsqrt(var_bt + epsilon)
  sigma_bt_recip = math_ops.reciprocal(sigma_bt)
  sigma_mv = math_ops.rsqrt(var_mv + epsilon)
  sigma_mv_recip = math_ops.reciprocal(sigma_mv)

  corr_scale = math_ops.divide(sigma_bt, sigma_mv, name='corr_scale')
  corr_recip = math_ops.reciprocal(corr_scale)
  corr_offset = array_ops.zeros(mu_bt.shape)

  if conv_type == 'DepthwiseConv2D':
    new_shape = [kernel.shape[2], kernel.shape[3]]
    corr_scale = array_ops.reshape(corr_scale, new_shape)

  return corr_scale, corr_recip, corr_offset
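
To make the before/after-freeze quantities from the docstring concrete, here is a minimal NumPy sketch (not part of the original code; the scalar statistics are made up, and sigma follows the docstring's convention of a standard deviation):

import numpy as np

# Illustrative batch vs. moving-average statistics for one channel.
mu_bt, var_bt = 0.10, 1.50     # batch mean / variance
mu_mv, var_mv = 0.05, 1.20     # moving-average mean / variance
gamma, bias, epsilon = 0.9, 0.2, 1e-3

sigma_bt = np.sqrt(var_bt + epsilon)
sigma_mv = np.sqrt(var_mv + epsilon)

# Before freeze: scale the weights by sigma_bt / sigma_mv and undo it on the output.
corr_scale = sigma_bt / sigma_mv
corr_recip = 1.0 / corr_scale
corr_offset = 0.0

# After freeze: keep the same weight scale, drop the compensation, shift the output instead.
frozen_recip = 1.0
frozen_offset = gamma * ((mu_bt - bias) / sigma_bt - (mu_mv - bias) / sigma_mv)

print(corr_scale, corr_recip, corr_offset, frozen_recip, frozen_offset)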
Example #2
 def _grad_and_hess_for_logloss(logits, labels):
     # TODO(youngheek): add weights handling.
     predictions = math_ops.reciprocal(math_ops.exp(-logits) + 1.0)
     normalizer = math_ops.reciprocal(
         math_ops.cast(array_ops.size(predictions), dtypes.float32))
     gradients = (predictions - labels) * normalizer
     hessians = predictions * (1.0 - predictions) * normalizer
     return gradients, hessians
Example #3
 def _grad_and_hess_for_logloss(logits, labels):
   # TODO(youngheek): add weights handling.
   predictions = math_ops.reciprocal(math_ops.exp(-logits) + 1.0)
   normalizer = math_ops.reciprocal(
       math_ops.cast(array_ops.size(predictions), dtypes.float32))
   gradients = (predictions - labels) * normalizer
   hessians = predictions * (1.0 - predictions) * normalizer
   return gradients, hessians
Example #4
 def _grad_and_hess_for_logloss(logits, labels):
   """A closed form gradient and hessian for logistic loss."""
   # TODO(youngheek): add weights handling.
   predictions = math_ops.reciprocal(math_ops.exp(-logits) + 1.0)
   normalizer = math_ops.reciprocal(
       math_ops.cast(array_ops.size(predictions), dtypes.float32))
   labels = math_ops.cast(labels, dtypes.float32)
   labels = head_lib._check_dense_labels_match_logits_and_reshape(  # pylint: disable=protected-access
       labels, logits, head.logits_dimension)
   gradients = (predictions - labels) * normalizer
   hessians = predictions * (1.0 - predictions) * normalizer
   return gradients, hessians
Example #5
 def _grad_and_hess_for_logloss(logits, labels):
   """A closed form gradient and hessian for logistic loss."""
   # TODO(youngheek): add weights handling.
   predictions = math_ops.reciprocal(math_ops.exp(-logits) + 1.0)
   normalizer = math_ops.reciprocal(
       math_ops.cast(array_ops.size(predictions), dtypes.float32))
   labels = math_ops.cast(labels, dtypes.float32)
   labels = head_lib._check_dense_labels_match_logits_and_reshape(  # pylint: disable=protected-access
       labels, logits, head.logits_dimension)
   gradients = (predictions - labels) * normalizer
   hessians = predictions * (1.0 - predictions) * normalizer
   return gradients, hessians
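
The closed form above says that for the mean logistic loss the per-example gradient is (p - y) / N and the diagonal Hessian is p(1 - p) / N, with p = sigmoid(logits). A standalone NumPy finite-difference check (illustrative only; it ignores the label-reshaping helper used in some of the variants):

import numpy as np

def mean_logloss(logits, labels):
    # Mean binary cross-entropy over the batch, written through the sigmoid.
    p = 1.0 / (1.0 + np.exp(-logits))
    return np.mean(-labels * np.log(p) - (1.0 - labels) * np.log(1.0 - p))

logits = np.array([-1.5, 0.2, 2.0])
labels = np.array([0.0, 1.0, 1.0])
n = logits.size

p = 1.0 / (1.0 + np.exp(-logits))
gradients = (p - labels) / n          # closed-form gradient
hessians = p * (1.0 - p) / n          # closed-form (diagonal) Hessian

eps = 1e-4
for i in range(n):
    d = np.zeros(n)
    d[i] = eps
    f_plus = mean_logloss(logits + d, labels)
    f_minus = mean_logloss(logits - d, labels)
    f_mid = mean_logloss(logits, labels)
    assert np.isclose((f_plus - f_minus) / (2 * eps), gradients[i], atol=1e-6)
    assert np.isclose((f_plus - 2 * f_mid + f_minus) / eps ** 2, hessians[i], atol=1e-5)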
Example #6
    def _grad(op, grad):
        """A gradient function for IRFFT with the provided `rank` and `rfft_fn`."""
        # Generate a simple mask like [1.0, 2.0, ..., 2.0, 1.0] for even-length FFTs
        # and [1.0, 2.0, ..., 2.0] for odd-length FFTs. To reduce extra ops in the
        # graph we special-case the situation where the FFT length and last
        # dimension of the input are known at graph construction time.
        fft_length = op.inputs[1]
        fft_length_static = _tensor_util.constant_value(fft_length)
        if fft_length_static is not None:
            fft_length = fft_length_static
        real_dtype = grad.dtype
        if real_dtype == _dtypes.float32:
            complex_dtype = _dtypes.complex64
        elif real_dtype == _dtypes.float64:
            complex_dtype = _dtypes.complex128
        is_odd = _math_ops.mod(fft_length[-1], 2)
        input_last_dimension = _array_ops.shape(op.inputs[0])[-1]
        mask = _array_ops.concat(
            [[1.0], 2.0 *
             _array_ops.ones([input_last_dimension - 2 + is_odd], real_dtype),
             _array_ops.ones([1 - is_odd], real_dtype)], 0)

        rsize = _math_ops.reciprocal(
            _math_ops.cast(_fft_size_for_grad(grad, rank), real_dtype))

        # The gradient of IRFFT is the RFFT of the incoming gradient times a scaling
        # factor and a mask. The mask scales the gradient for the Hermitian
        # symmetric components of the RFFT by a factor of two, since these
        # components are de-duplicated in the RFFT.
        the_rfft = rfft_fn(grad, fft_length)
        return the_rfft * _math_ops.cast(rsize * mask, complex_dtype), None
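
The [1.0, 2.0, ..., 2.0, 1.0] / [1.0, 2.0, ..., 2.0] mask exists because the RFFT keeps only one bin of each Hermitian-symmetric pair. A small NumPy sketch (illustrative only) shows the double counting directly by comparing full-FFT and mask-weighted RFFT energies for an even and an odd length:

import numpy as np

for n in (8, 9):                               # one even and one odd FFT length
    x = np.random.randn(n)
    full_energy = np.sum(np.abs(np.fft.fft(x)) ** 2)

    rfft_energy = np.abs(np.fft.rfft(x)) ** 2  # n // 2 + 1 bins
    if n % 2 == 0:
        # Even length: DC and Nyquist bins appear once, the rest twice.
        mask = np.array([1.0] + [2.0] * (n // 2 - 1) + [1.0])
    else:
        # Odd length: only the DC bin appears once.
        mask = np.array([1.0] + [2.0] * (n // 2))
    assert np.isclose(full_energy, np.sum(mask * rfft_energy))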
Example #7
def normalize_moments(counts, mean_ss, variance_ss, shift, name=None):
  """Calculate the mean and variance of based on the sufficient statistics.

  Args:
    counts: A `Tensor` containing the total count of the data (one value).
    mean_ss: A `Tensor` containing the mean sufficient statistics: the (possibly
      shifted) sum of the elements to average over.
    variance_ss: A `Tensor` containing the variance sufficient statistics: the
      (possibly shifted) squared sum of the data to compute the variance over.
    shift: A `Tensor` containing the value by which the data is shifted for
      numerical stability, or `None` if no shift was performed.
    name: Name used to scope the operations that compute the moments.

  Returns:
    Two `Tensor` objects: `mean` and `variance`.
  """
  with ops.name_scope(name, "normalize", [counts, mean_ss, variance_ss, shift]):
    divisor = math_ops.reciprocal(counts, name="divisor")
    if shift is not None:
      shifted_mean = math_ops.multiply(mean_ss, divisor, name="shifted_mean")
      mean = math_ops.add(shifted_mean, shift, name="mean")
    else:  # no shift.
      shifted_mean = math_ops.multiply(mean_ss, divisor, name="mean")
      mean = shifted_mean
    variance = math_ops.subtract(math_ops.multiply(variance_ss, divisor),
                                 math_ops.square(shifted_mean),
                                 name="variance")
  return (mean, variance)
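
With shift=None the function reduces to mean = mean_ss / counts and variance = variance_ss / counts - mean**2. A short NumPy sketch (illustrative, outside TensorFlow) checks that against np.mean / np.var:

import numpy as np

data = np.array([1.0, 2.0, 4.0, 7.0])
counts = float(data.size)            # total count of the data (one value)
mean_ss = np.sum(data)               # unshifted sum of the elements
variance_ss = np.sum(data ** 2)      # unshifted sum of squares

divisor = 1.0 / counts               # plays the role of math_ops.reciprocal(counts)
mean = mean_ss * divisor
variance = variance_ss * divisor - mean ** 2

assert np.isclose(mean, np.mean(data))
assert np.isclose(variance, np.var(data))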
Example #8
def normalize_moments(counts, mean_ss, variance_ss, shift, name=None):
  """Calculate the mean and variance of based on the sufficient statistics.

  Args:
    counts: A `Tensor` containing the total count of the data (one value).
    mean_ss: A `Tensor` containing the mean sufficient statistics: the (possibly
      shifted) sum of the elements to average over.
    variance_ss: A `Tensor` containing the variance sufficient statistics: the
      (possibly shifted) squared sum of the data to compute the variance over.
    shift: A `Tensor` containing the value by which the data is shifted for
      numerical stability, or `None` if no shift was performed.
    name: Name used to scope the operations that compute the moments.

  Returns:
    Two `Tensor` objects: `mean` and `variance`.
  """
  with ops.name_scope(name, "normalize", [counts, mean_ss, variance_ss, shift]):
    divisor = math_ops.reciprocal(counts, name="divisor")
    if shift is not None:
      shifted_mean = math_ops.mul(mean_ss, divisor, name="shifted_mean")
      mean = math_ops.add(shifted_mean, shift, name="mean")
    else:  # no shift.
      shifted_mean = math_ops.mul(mean_ss, divisor, name="mean")
      mean = shifted_mean
    variance = math_ops.sub(math_ops.mul(variance_ss, divisor),
                            math_ops.square(shifted_mean),
                            name="variance")
  return (mean, variance)
Example #9
def _TanGrad(op, grad):
    """Returns grad * 1/sec^2(x)."""
    x = op.inputs[0]
    with ops.control_dependencies([grad]):
        x = math_ops.conj(x)
        secx = math_ops.reciprocal(math_ops.cos(x))
        secx2 = math_ops.square(secx)
        return grad * secx2
Example #10
def _TanGrad(op, grad):
  """Returns grad * 1/sec^2(x)."""
  x = op.inputs[0]
  with ops.control_dependencies([grad]):
    x = math_ops.conj(x)
    secx = math_ops.reciprocal(math_ops.cos(x))
    secx2 = math_ops.square(secx)
    return grad * secx2
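
The closed form grad * 1/cos^2(x) can be compared against automatic differentiation. A minimal sketch, assuming TensorFlow 2.x in eager mode:

import tensorflow as tf

x = tf.constant([0.3, 1.0, -0.7])
with tf.GradientTape() as tape:
    tape.watch(x)
    y = tf.math.tan(x)
autodiff = tape.gradient(y, x)                              # elementwise d tan(x)/dx
sec2 = tf.math.square(tf.math.reciprocal(tf.math.cos(x)))   # 1 / cos^2(x)
tf.debugging.assert_near(autodiff, sec2)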
Example #11
def _AtanGrad(op, grad):
  """Returns grad * 1/ (1 + x^2)."""
  x = op.inputs[0]
  with ops.control_dependencies([grad]):
    x = math_ops.conj(x)
    x2 = math_ops.square(x)
    one = constant_op.constant(1, dtype=grad.dtype)
    inv = math_ops.reciprocal(math_ops.add(one, x2))
    return grad * inv
Example #12
def _AtanGrad(op, grad):
    """Returns grad * 1/ (1 + x^2)."""
    x = op.inputs[0]
    with ops.control_dependencies([grad]):
        x = math_ops.conj(x)
        x2 = math_ops.square(x)
        one = constant_op.constant(1, dtype=grad.dtype)
        inv = math_ops.reciprocal(math_ops.add(one, x2))
        return grad * inv
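
Likewise, d/dx atan(x) = 1/(1 + x^2) can be spot-checked with a central finite difference; this NumPy sketch is illustrative only:

import numpy as np

x = np.linspace(-2.0, 2.0, 9)
eps = 1e-6
numeric = (np.arctan(x + eps) - np.arctan(x - eps)) / (2 * eps)
closed_form = 1.0 / (1.0 + x ** 2)
assert np.allclose(numeric, closed_form, atol=1e-8)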
Example #13
def qbatch_normalization(x,
                         mean,
                         variance,
                         offset,
                         scale,
                         variance_epsilon,
                         quantizer,
                         name=None):
    r"""Batch normalization.
  As described in http://arxiv.org/abs/1502.03167.
  Normalizes a tensor by `mean` and `variance`, and applies (optionally) a
  `scale` \\(\gamma\\) to it, as well as an `offset` \\(\beta\\):
  \\(\frac{\gamma(x-\mu)}{\sigma}+\beta\\)
  `mean`, `variance`, `offset` and `scale` are all expected to be of one of two
  shapes:
    * In all generality, they can have the same number of dimensions as the
      input `x`, with identical sizes as `x` for the dimensions that are not
      normalized over (the 'depth' dimension(s)), and dimension 1 for the
      others which are being normalized over.
      `mean` and `variance` in this case would typically be the outputs of
      `tf.nn.moments(..., keep_dims=True)` during training, or running averages
      thereof during inference.
    * In the common case where the 'depth' dimension is the last dimension in
      the input tensor `x`, they may be one dimensional tensors of the same
      size as the 'depth' dimension.
      This is the case for example for the common `[batch, depth]` layout of
      fully-connected layers, and `[batch, height, width, depth]` for
      convolutions.
      `mean` and `variance` in this case would typically be the outputs of
      `tf.nn.moments(..., keep_dims=False)` during training, or running averages
      thereof during inference.
  Args:
    x: Input `Tensor` of arbitrary dimensionality.
    mean: A mean `Tensor`.
    variance: A variance `Tensor`.
    offset: An offset `Tensor`, often denoted \\(\beta\\) in equations, or
      None. If present, will be added to the normalized tensor.
    scale: A scale `Tensor`, often denoted \\(\gamma\\) in equations, or
      `None`. If present, the scale is applied to the normalized tensor.
    variance_epsilon: A small float number to avoid dividing by 0.
    name: A name for this operation (optional).
  Returns:
    the normalized, scaled, offset tensor.
  """
    with ops.name_scope(name, "batchnorm", [x, mean, variance, scale, offset]):
        #internal calculation of sqrt is NOT quantized!
        inv = quantizer.quantize(math_ops.sqrt(variance + variance_epsilon))
        inv = quantizer.quantize(math_ops.reciprocal(inv))
        if scale is not None:
            inv = quantizer.quantize(inv * scale)
        #return x * inv + (offset - mean * inv if offset is not None else -mean * inv)
        if offset is not None:
            rest = quantizer.quantize(offset - quantizer.quantize(mean * inv))
        else:
            rest = quantizer.quantize(-mean * inv)
        result = quantizer.quantize(quantizer.quantize(x * inv) + rest)
        return result
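
With a pass-through quantizer the quantized path should reduce to ordinary batch normalization. A rough sanity-check sketch, assuming TensorFlow 2.x; IdentityQuantizer is a hypothetical stand-in for the quantizer argument, not an API from the example's library:

import tensorflow as tf

class IdentityQuantizer:
    # Hypothetical no-op quantizer: quantize() returns its input unchanged.
    def quantize(self, t):
        return t

q = IdentityQuantizer()
x = tf.random.normal([4, 3])
mean = tf.constant([0.1, -0.2, 0.0])
variance = tf.constant([1.0, 0.5, 2.0])
offset = tf.constant([0.0, 1.0, -1.0])
scale = tf.constant([1.5, 1.0, 0.5])
eps = 1e-3

# Same arithmetic as the example body, with quantization a no-op.
inv = q.quantize(tf.math.reciprocal(q.quantize(tf.sqrt(variance + eps))))
inv = q.quantize(inv * scale)
rest = q.quantize(offset - q.quantize(mean * inv))
result = q.quantize(q.quantize(x * inv) + rest)

reference = tf.nn.batch_normalization(x, mean, variance, offset, scale, eps)
tf.debugging.assert_near(result, reference, atol=1e-5, rtol=1e-5)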
Example #14
def _AcosGrad(op, grad):
    """Returns grad * -1/sqrt(1-x^2)."""
    x = op.inputs[0]
    with ops.control_dependencies([grad]):
        x = math_ops.conj(x)
        x2 = math_ops.square(x)
        one = constant_op.constant(1, dtype=grad.dtype)
        den = math_ops.sqrt(math_ops.subtract(one, x2))
        inv = math_ops.reciprocal(den)
        return -grad * inv
Example #15
def _AcosGrad(op, grad):
  """Returns grad * -1/sqrt(1-x^2)."""
  x = op.inputs[0]
  with ops.control_dependencies([grad]):
    x = math_ops.conj(x)
    x2 = math_ops.square(x)
    one = constant_op.constant(1, dtype=grad.dtype)
    den = math_ops.sqrt(math_ops.subtract(one, x2))
    inv = math_ops.reciprocal(den)
    return -grad * inv
Example #16
def _AngleGrad(op, grad):
    """Returns -grad / (Im(x) + iRe(x))"""
    x = op.inputs[0]
    with ops.control_dependencies([grad]):
        re = math_ops.real(x)
        im = math_ops.imag(x)
        z = math_ops.reciprocal(math_ops.complex(im, re))
        zero = constant_op.constant(0, dtype=grad.dtype)
        complex_grad = math_ops.complex(grad, zero)
        return -complex_grad * z
Example #17
def _AngleGrad(op, grad):
  """Returns -grad / (Im(x) + iRe(x))"""
  x = op.inputs[0]
  with ops.control_dependencies([grad]):
    re = math_ops.real(x)
    im = math_ops.imag(x)
    z = math_ops.reciprocal(math_ops.complex(im, re))
    zero = constant_op.constant(0, dtype=grad.dtype)
    complex_grad = math_ops.complex(grad, zero)
    return -complex_grad * z
Example #18
    def _get_matches_hook(self, y_pred_click_ranks):
        """
        Return reciprocal click ranks for MRR

        Parameters
        ----------
        y_pred_click_ranks: Tensor object
            Tensor object containing the ranks of the clicked records for each query

        Returns
        -------
        Tensor object
            Reciprocal ranks tensor
        """
        return math_ops.reciprocal(tf.cast(y_pred_click_ranks, tf.float32))
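
For context, with clicked-record ranks of 1, 4, and 2 the reciprocal ranks are 1.0, 0.25, and 0.5, and MRR is their mean. A tiny illustrative sketch using the public tf.math API (equivalent to math_ops.reciprocal here):

import tensorflow as tf

y_pred_click_ranks = tf.constant([1, 4, 2])
reciprocal_ranks = tf.math.reciprocal(tf.cast(y_pred_click_ranks, tf.float32))
mrr = tf.reduce_mean(reciprocal_ranks)   # (1.0 + 0.25 + 0.5) / 3 ≈ 0.583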
Example #19
def _BesselI1eGrad(op, grad):
  """Compute gradient of bessel_i1e(x) with respect to its argument."""
  x = op.inputs[0]
  y = op.outputs[0]
  with ops.control_dependencies([grad]):
    # For x = 0, the correct gradient is 0.5.
    # However, the main branch gives NaN because of the division by x, so
    # we impute the gradient manually.
    # An alternative solution is to express the gradient via bessel_i0e and
    # bessel_i2e, but the latter is not yet implemented in Eigen.
    eps = np.finfo(x.dtype.as_numpy_dtype).eps
    zeros = array_ops.zeros_like(x)
    x_is_not_tiny = math_ops.abs(x) > eps
    safe_x = array_ops.where(x_is_not_tiny, x, eps + zeros)
    dy_dx = math_ops.bessel_i0e(safe_x) - y * (
        math_ops.sign(safe_x) + math_ops.reciprocal(safe_x))
    return grad * array_ops.where(x_is_not_tiny, dy_dx, 0.5 + zeros)
Example #20
def _BesselI1eGrad(op, grad):
  """Compute gradient of bessel_i1e(x) with respect to its argument."""
  x = op.inputs[0]
  y = op.outputs[0]
  with ops.control_dependencies([grad]):
    # For x = 0, the correct gradient is 0.5.
    # However, the main branch gives NaN because of the division by x, so
    # we impute the gradient manually.
    # An alternative solution is to express the gradient via bessel_i0e and
    # bessel_i2e, but the latter is not yet implemented in Eigen.
    eps = np.finfo(x.dtype.as_numpy_dtype).eps
    zeros = array_ops.zeros_like(x)
    x_is_not_tiny = math_ops.abs(x) > eps
    safe_x = array_ops.where(x_is_not_tiny, x, eps + zeros)
    dy_dx = math_ops.bessel_i0e(safe_x) - y * (
        math_ops.sign(safe_x) + math_ops.reciprocal(safe_x))
    return grad * array_ops.where(x_is_not_tiny, dy_dx, 0.5 + zeros)
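
Away from zero the main branch computes d/dx bessel_i1e(x) = bessel_i0e(x) - bessel_i1e(x) * (sign(x) + 1/x). A finite-difference spot check, assuming SciPy is available (scipy.special.i0e / i1e are the same exponentially scaled Bessel functions):

import numpy as np
from scipy.special import i0e, i1e

x = np.array([-2.0, -0.5, 0.3, 1.7])   # away from x = 0, where the main branch applies
eps = 1e-6
numeric = (i1e(x + eps) - i1e(x - eps)) / (2 * eps)
closed_form = i0e(x) - i1e(x) * (np.sign(x) + 1.0 / x)
assert np.allclose(numeric, closed_form, atol=1e-7)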
Example #21
def _BatchNormGrad(grad_y, x, scale, epsilon, data_format):
    """Returns the gradients for the 3 inputs of BatchNorm.

  Args:
    grad_y: A `Tensor` of 4 dimensions for gradient for y.
    x: A `Tensor` of 4 dimensions for x.
    scale: A `Tensor` of 1 dimension for scaling.
    epsilon: A small float number added to the variance of x.
    data_format: The data format for input. Either b"NHWC" or b"NCHW".

  Returns:
    A tuple (grad_x, grad_scale, grad_offset), where grad_x is the gradient
    for x, grad_scale the gradient for scale, and grad_offset the gradient
    for offset.
  """
    if data_format == b"NHWC":
        keep_dims = False
        reduce_axis = [0, 1, 2]
    else:
        keep_dims = True
        reduce_axis = [0, 2, 3]
        shape = [1, array_ops.size(scale), 1, 1]
        scale = array_ops.reshape(scale, shape)
    mean_grad_y = math_ops.reduce_mean(grad_y,
                                       reduce_axis,
                                       keep_dims=keep_dims)
    mean_x = math_ops.reduce_mean(x, reduce_axis, keep_dims=keep_dims)
    var_x = math_ops.reduce_mean(math_ops.squared_difference(
        x, array_ops.stop_gradient(mean_x)),
                                 reduce_axis,
                                 keep_dims=keep_dims)
    grad_y_offset = grad_y - mean_grad_y
    x_offset = x - mean_x
    mean = math_ops.reduce_mean(grad_y * x_offset,
                                axis=reduce_axis,
                                keep_dims=keep_dims)
    grad_x = scale * math_ops.rsqrt(var_x + epsilon) * (
        grad_y_offset - math_ops.reciprocal(var_x + epsilon) * mean * x_offset)
    grad_scale = math_ops.rsqrt(var_x + epsilon) * math_ops.reduce_sum(
        grad_y * x_offset, axis=reduce_axis, keep_dims=keep_dims)
    if data_format == b"NCHW":
        grad_scale = array_ops.squeeze(grad_scale)
    grad_offset = math_ops.reduce_sum(grad_y, axis=reduce_axis)
    return grad_x, grad_scale, grad_offset
Example #22
def _BatchNormGrad(grad_y, x, scale, epsilon, data_format):
  """Returns the gradients for the 3 inputs of BatchNorm.

  Args:
    grad_y: A `Tensor` of 4 dimensions for gradient for y.
    x: A `Tensor` of 4 dimensions for x.
    scale: A `Tensor` of 1 dimension for scaling.
    epsilon: A small float number added to the variance of x.
    data_format: The data format for input. Either b"NHWC" or b"NCHW".

  Returns:
    A tuple (grad_x, grad_scale, grad_offset), where grad_x is the gradient
    for x, grad_scale the gradient for scale, and grad_offset the gradient
    for offset.
  """
  if data_format == b"NHWC":
    keep_dims = False
    reduce_axis = [0, 1, 2]
  else:
    keep_dims = True
    reduce_axis = [0, 2, 3]
    shape = [1, array_ops.size(scale), 1, 1]
    scale = array_ops.reshape(scale, shape)
  mean_grad_y = math_ops.reduce_mean(grad_y, reduce_axis, keep_dims=keep_dims)
  mean_x = math_ops.reduce_mean(x, reduce_axis, keep_dims=keep_dims)
  var_x = math_ops.reduce_mean(
      math_ops.squared_difference(x, array_ops.stop_gradient(mean_x)),
      reduce_axis,
      keep_dims=keep_dims)
  grad_y_offset = grad_y - mean_grad_y
  x_offset = x - mean_x
  mean = math_ops.reduce_mean(
      grad_y * x_offset, axis=reduce_axis, keep_dims=keep_dims)
  grad_x = scale * math_ops.rsqrt(var_x + epsilon) * (
      grad_y_offset - math_ops.reciprocal(var_x + epsilon) * mean * x_offset)
  grad_scale = math_ops.rsqrt(var_x + epsilon) * math_ops.reduce_sum(
      grad_y * x_offset, axis=reduce_axis, keep_dims=keep_dims)
  if data_format == b"NCHW":
    grad_scale = array_ops.squeeze(grad_scale)
  grad_offset = math_ops.reduce_sum(grad_y, axis=reduce_axis)
  return grad_x, grad_scale, grad_offset
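
The grad_x expression in the training branch is the full analytic gradient of batch normalization with respect to x (differentiating through the batch mean and variance). A sketch of a numerical cross-check against autodiff, assuming TensorFlow 2.x eager mode and NHWC data:

import tensorflow as tf

x = tf.random.normal([2, 4, 4, 3])
scale = tf.constant([1.2, 0.8, 1.0])
offset = tf.zeros([3])
epsilon = 1e-3
grad_y = tf.random.normal(x.shape)

with tf.GradientTape() as tape:
    tape.watch(x)
    mean, var = tf.nn.moments(x, axes=[0, 1, 2])
    y = tf.nn.batch_normalization(x, mean, var, offset, scale, epsilon)
autodiff_grad_x = tape.gradient(y, x, output_gradients=grad_y)

# Closed form from the example (NHWC / training mode), written with public ops.
mean_grad_y = tf.reduce_mean(grad_y, [0, 1, 2])
mean_x = tf.reduce_mean(x, [0, 1, 2])
var_x = tf.reduce_mean(
    tf.math.squared_difference(x, tf.stop_gradient(mean_x)), [0, 1, 2])
x_offset = x - mean_x
mean_prod = tf.reduce_mean(grad_y * x_offset, axis=[0, 1, 2])
closed_form = scale * tf.math.rsqrt(var_x + epsilon) * (
    grad_y - mean_grad_y - tf.math.reciprocal(var_x + epsilon) * mean_prod * x_offset)

tf.debugging.assert_near(autodiff_grad_x, closed_form, atol=1e-4, rtol=1e-4)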
Example #23
  def _Grad(op, grad):
    """A gradient function for IRFFT with the provided `rank` and `rfft_fn`."""
    # Generate a simple mask like [1.0, 2.0, ..., 2.0, 1.0] for even-length FFTs
    # and [1.0, 2.0, ..., 2.0] for odd-length FFTs. To reduce extra ops in the
    # graph we special-case the situation where the FFT length and last
    # dimension of the input are known at graph construction time.
    fft_length = op.inputs[1]
    is_odd = math_ops.mod(fft_length[-1], 2)
    input_last_dimension = array_ops.shape(op.inputs[0])[-1]
    mask = array_ops.concat(
        [[1.0], 2.0 * array_ops.ones([input_last_dimension - 2 + is_odd]),
         array_ops.ones([1 - is_odd])], 0)

    rsize = math_ops.reciprocal(math_ops.to_float(_FFTSizeForGrad(grad, rank)))

    # The gradient of IRFFT is the RFFT of the incoming gradient times a scaling
    # factor and a mask. The mask scales the gradient for the Hermitian
    # symmetric components of the RFFT by a factor of two, since these
    # components are de-duplicated in the RFFT.
    rfft = rfft_fn(grad, fft_length)
    return rfft * math_ops.cast(rsize * mask, dtypes.complex64), None
Example #24
def _SelfAdjointEigV2Grad(op, grad_e, grad_v):
  """Gradient for SelfAdjointEigV2."""
  e = op.outputs[0]
  compute_v = op.get_attr("compute_v")
  # a = op.inputs[0], which satisfies
  # a[...,:,:] * v[...,:,i] = e[...,i] * v[...,i]
  with ops.control_dependencies([grad_e, grad_v]):
    if compute_v:
      v = op.outputs[1]
      # Construct the matrix f(i,j) = (i != j ? 1 / (e_i - e_j) : 0).
      # Notice that because of the term involving f, the gradient becomes
      # infinite (or NaN in practice) when eigenvalues are not unique.
      # Mathematically this should not be surprising, since for (k-fold)
      # degenerate eigenvalues, the corresponding eigenvectors are only defined
      # up to arbitrary rotation in a (k-dimensional) subspace.
      f = array_ops.matrix_set_diag(
          math_ops.reciprocal(
              array_ops.expand_dims(e, -2) - array_ops.expand_dims(e, -1)),
          array_ops.zeros_like(e))
      grad_a = math_ops.matmul(
          v,
          math_ops.matmul(
              array_ops.matrix_diag(grad_e) +
              f * math_ops.matmul(v, grad_v, adjoint_a=True),
              v,
              adjoint_b=True))
    else:
      _, v = linalg_ops.self_adjoint_eig(op.inputs[0])
      grad_a = math_ops.matmul(v,
                               math_ops.matmul(
                                   array_ops.matrix_diag(grad_e),
                                   v,
                                   adjoint_b=True))
    # The forward op only depends on the lower triangular part of a, so here we
    # symmetrize and take the lower triangle
    grad_a = array_ops.matrix_band_part(
        grad_a + math_ops.conj(array_ops.matrix_transpose(grad_a)), -1, 0)
    grad_a = array_ops.matrix_set_diag(grad_a,
                                       0.5 * array_ops.matrix_diag_part(grad_a))
    return grad_a
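
The matrix_set_diag(reciprocal(...)) pattern builds the pairwise matrix f of reciprocal eigenvalue differences with the diagonal zeroed out. A tiny NumPy rendering of that construction, illustrative only:

import numpy as np

e = np.array([1.0, 2.0, 4.0])                 # distinct eigenvalues
# Same broadcast as expand_dims(e, -2) - expand_dims(e, -1).
diff = e[np.newaxis, :] - e[:, np.newaxis]
with np.errstate(divide="ignore"):
    f = np.where(np.eye(e.size, dtype=bool), 0.0, 1.0 / diff)
print(f)
# As two eigenvalues approach each other, the corresponding off-diagonal entry of f
# blows up, which is the NaN / instability the comment warns about.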
Example #25
def _SelfAdjointEigV2Grad(op, grad_e, grad_v):
    """Gradient for SelfAdjointEigV2."""
    e = op.outputs[0]
    compute_v = op.get_attr("compute_v")
    # a = op.inputs[0], which satisfies
    # a[...,:,:] * v[...,:,i] = e[...,i] * v[...,i]
    with ops.control_dependencies([grad_e, grad_v]):
        if compute_v:
            v = op.outputs[1]
            # Construct the matrix f(i,j) = (i != j ? 1 / (e_i - e_j) : 0).
            # Notice that because of the term involving f, the gradient becomes
            # infinite (or NaN in practice) when eigenvalues are not unique.
            # Mathematically this should not be surprising, since for (k-fold)
            # degenerate eigenvalues, the corresponding eigenvectors are only defined
            # up to arbitrary rotation in a (k-dimensional) subspace.
            f = array_ops.matrix_set_diag(
                math_ops.reciprocal(
                    array_ops.expand_dims(e, -2) -
                    array_ops.expand_dims(e, -1)), array_ops.zeros_like(e))
            grad_a = math_ops.matmul(
                v,
                math_ops.matmul(array_ops.matrix_diag(grad_e) +
                                f * math_ops.matmul(v, grad_v, adjoint_a=True),
                                v,
                                adjoint_b=True))
        else:
            _, v = linalg_ops.self_adjoint_eig(op.inputs[0])
            grad_a = math_ops.matmul(
                v,
                math_ops.matmul(array_ops.matrix_diag(grad_e),
                                v,
                                adjoint_b=True))
        # The forward op only depends on the lower triangular part of a, so here we
        # symmetrize and take the lower triangle
        grad_a = array_ops.matrix_band_part(grad_a + _linalg.adjoint(grad_a),
                                            -1, 0)
        grad_a = array_ops.matrix_set_diag(
            grad_a, 0.5 * array_ops.matrix_diag_part(grad_a))
        return grad_a
Example #26
    def _Grad(op, grad):
        """A gradient function for IRFFT with the provided `rank` and `rfft_fn`."""
        # Generate a simple mask like [1.0, 2.0, ..., 2.0, 1.0] for even-length FFTs
        # and [1.0, 2.0, ..., 2.0] for odd-length FFTs. To reduce extra ops in the
        # graph we special-case the situation where the FFT length and last
        # dimension of the input are known at graph construction time.
        fft_length = op.inputs[1]
        is_odd = math_ops.mod(fft_length[-1], 2)
        input_last_dimension = array_ops.shape(op.inputs[0])[-1]
        mask = array_ops.concat(
            [[1.0], 2.0 * array_ops.ones([input_last_dimension - 2 + is_odd]),
             array_ops.ones([1 - is_odd])], 0)

        rsize = math_ops.reciprocal(
            math_ops.to_float(_FFTSizeForGrad(grad, rank)))

        # The gradient of IRFFT is the RFFT of the incoming gradient times a scaling
        # factor and a mask. The mask scales the gradient for the Hermitian
        # symmetric components of the RFFT by a factor of two, since these
        # components are de-duplicated in the RFFT.
        rfft = rfft_fn(grad, fft_length)
        return rfft * math_ops.cast(rsize * mask, dtypes.complex64), None
Example #27
def QuantizedBatchNormalizationCore(inputs,
                                    mean,
                                    variance,
                                    beta,
                                    gamma,
                                    variance_epsilon,
                                    Q_info,
                                    name=None):
    """ Intrinsic quantization of BatchNormalization layer.

    Parameters
    ----------
    | inputs : Tensor.
    | mean : tf.Variable
    | variance : tf.Variable
    | beta : tf.Variable
    | gamma : tf.Variable
    | variance_epsilon : Float

    Returns
    -------
    output : Tensor

    """
    with ops.name_scope(name, "batchnorm", [inputs, mean, variance, gamma, beta]):
        coef = Q_info.quantize( math_ops.sqrt(variance + variance_epsilon))
        coef = Q_info.quantize( math_ops.reciprocal(coef))
        if gamma is not None:
          coef = Q_info.quantize(coef*gamma)
        
        if beta is not None:
            const = Q_info.quantize( beta - Q_info.quantize(mean * coef))
        else:
            const = Q_info.quantize(-mean * coef)
        output = Q_info.quantize( Q_info.quantize(inputs * coef) + const)
        return output
Example #28
def _SvdGrad(op, grad_s, grad_u, grad_v):
    """Gradient for the singular value decomposition."""

    # The derivation for the compute_uv=False case, and most of
    # the derivation for the full_matrices=True case, are in
    # Giles' paper (see reference at top of file).  A derivation for
    # the full_matrices=False case is available at
    # https://j-towns.github.io/papers/svd-derivative.pdf
    a = op.inputs[0]
    a_shape = a.get_shape().with_rank_at_least(2)
    grad_s_mat = array_ops.matrix_diag(grad_s)

    if not op.get_attr("compute_uv"):
        s, u, v = linalg_ops.svd(a, compute_uv=True)
        grad_a = math_ops.matmul(
            u, math_ops.matmul(grad_s_mat, v, adjoint_b=True))
        grad_a.set_shape(a_shape)
        return grad_a

    full_matrices = op.get_attr("full_matrices")

    # TODO(rmlarsen): Make this work with complex types.
    if a.dtype.is_complex:
        raise NotImplementedError(
            "SVD gradient is not implemented for complex types and "
            "compute_uv=True.")
    grad_u_shape = grad_u.get_shape().with_rank_at_least(2)
    grad_v_shape = grad_v.get_shape().with_rank_at_least(2)
    m = a_shape[-2].merge_with(grad_u_shape[-2])
    n = a_shape[-1].merge_with(grad_v_shape[-2])
    batch_shape = a_shape[:-2].merge_with(grad_u_shape[:-2]).merge_with(
        grad_v_shape[:-2])
    a_shape = batch_shape.concatenate([m, n])

    m = a_shape[-2].value
    n = a_shape[-1].value
    # TODO(rmlarsen): Make this work with placeholders.
    if m is None or n is None:
        raise NotImplementedError(
            "SVD gradient has not been implemented for input with unknown "
            "inner matrix shape.")

    s = op.outputs[0]
    u = op.outputs[1]
    v = op.outputs[2]

    use_adjoint = False
    if m > n:
        # Compute the gradient for A^H = V * S^T * U^H, and (implicitly) take the
        # Hermitian transpose of the gradient at the end.
        use_adjoint = True
        m, n = n, m
        u, v = v, u
        grad_u, grad_v = grad_v, grad_u

    with ops.control_dependencies([grad_s, grad_u, grad_v]):
        if full_matrices and abs(m - n) > 1:
            raise NotImplementedError(
                "svd gradient is not implemented for abs(m - n) > 1 "
                "when full_matrices is True")
        s_mat = array_ops.matrix_diag(s)
        s2 = math_ops.square(s)

        # NOTICE: Because of the term involving f, the gradient becomes
        # infinite (or NaN in practice) when singular values are not unique.
        # Mathematically this should not be surprising, since for (k-fold)
        # degenerate singular values, the corresponding singular vectors are
        # only defined up a (k-dimensional) subspace. In practice, this can
        # lead to numerical instability when singular values are close but not
        # exactly equal.
        f = array_ops.matrix_set_diag(
            math_ops.reciprocal(
                array_ops.expand_dims(s2, -2) - array_ops.expand_dims(s2, -1)),
            array_ops.zeros_like(s))
        s_inv_mat = array_ops.matrix_diag(math_ops.reciprocal(s))

        v1 = v[..., :, :m]
        grad_v1 = grad_v[..., :, :m]

        u_gu = math_ops.matmul(u, grad_u, adjoint_a=True)
        v_gv = math_ops.matmul(v1, grad_v1, adjoint_a=True)

        f_u = f * u_gu
        f_v = f * v_gv

        term1_nouv = (grad_s_mat +
                      math_ops.matmul(f_u + _linalg.adjoint(f_u), s_mat) +
                      math_ops.matmul(s_mat, f_v + _linalg.adjoint(f_v)))

        term1 = math_ops.matmul(
            u, math_ops.matmul(term1_nouv, v1, adjoint_b=True))

        if m == n:
            grad_a_before_transpose = term1
        else:
            gv1t = array_ops.matrix_transpose(grad_v1)
            gv1t_v1 = math_ops.matmul(gv1t, v1)
            term2_nous = gv1t - math_ops.matmul(gv1t_v1, v1, adjoint_b=True)

            if full_matrices:
                v2 = v[..., :, m:n]
                grad_v2 = grad_v[..., :, m:n]

                v1t_gv2 = math_ops.matmul(v1, grad_v2, adjoint_a=True)
                term2_nous -= math_ops.matmul(v1t_gv2, v2, adjoint_b=True)

            u_s_inv = math_ops.matmul(u, s_inv_mat)
            term2 = math_ops.matmul(u_s_inv, term2_nous)

            grad_a_before_transpose = term1 + term2

        if use_adjoint:
            grad_a = array_ops.matrix_transpose(grad_a_before_transpose)
        else:
            grad_a = grad_a_before_transpose

        grad_a.set_shape(a_shape)
        return grad_a
Example #29
def _BatchNormGrad(grad_y,
                   x,
                   scale,
                   pop_mean,
                   pop_var,
                   epsilon,
                   data_format,
                   is_training=True):
    """Returns the gradients for the 3 inputs of BatchNorm.

  Args:
    grad_y: A `Tensor` of 4 dimensions for gradient for y.
    x: A `Tensor` of 4 dimensions for x.
    scale: A `Tensor` of 1 dimension for scaling.
    pop_mean: A `Tensor` of 1 dimension for the population mean. Only used when
      is_training=False.
    pop_var: A `Tensor` of 1 dimension for the population variance. Only used
      when is_training=False.
    epsilon: A small float number added to the variance of x.
    data_format: The data format for input. Either b"NHWC" or b"NCHW".
    is_training: A bool value to indicate the operation is for training
      (default) or inference.

  Returns:
    A tuple (grad_x, grad_scale, grad_offset), where grad_x is the gradient
    for x, grad_scale the gradient for scale, and grad_offset the gradient
    for offset.
  """
    x_dtype = x.dtype.base_dtype
    if x_dtype == dtypes.float16:
        # float16 math is too imprecise, so we do the batch norm gradient
        # computations in float32.
        x = math_ops.cast(x, dtypes.float32)
        grad_y = math_ops.cast(grad_y, dtypes.float32)
    if is_training:
        if data_format == b"NHWC":
            keepdims = False
            reduce_axis = [0, 1, 2]
        else:
            keepdims = True
            reduce_axis = [0, 2, 3]
            shape = [1, array_ops.size(scale), 1, 1]
            scale = array_ops.reshape(scale, shape)
        mean_grad_y = math_ops.reduce_mean(grad_y,
                                           reduce_axis,
                                           keepdims=keepdims)
        mean_x = math_ops.reduce_mean(x, reduce_axis, keepdims=keepdims)
        var_x = math_ops.reduce_mean(math_ops.squared_difference(
            x, array_ops.stop_gradient(mean_x)),
                                     reduce_axis,
                                     keepdims=keepdims)
        grad_y_offset = grad_y - mean_grad_y
        x_offset = x - mean_x
        mean = math_ops.reduce_mean(grad_y * x_offset,
                                    axis=reduce_axis,
                                    keepdims=keepdims)
        grad_x = scale * math_ops.rsqrt(var_x + epsilon) * (
            grad_y_offset -
            math_ops.reciprocal(var_x + epsilon) * mean * x_offset)
        grad_scale = math_ops.rsqrt(var_x + epsilon) * math_ops.reduce_sum(
            grad_y * x_offset, axis=reduce_axis, keepdims=keepdims)
        if data_format == b"NCHW":
            grad_scale = array_ops.squeeze(grad_scale)
        grad_offset = math_ops.reduce_sum(grad_y, axis=reduce_axis)
        return math_ops.cast(grad_x, x_dtype), grad_scale, grad_offset
    else:
        if data_format == b"NHWC":
            reduce_axis = [0, 1, 2]
        else:
            reduce_axis = [0, 2, 3]
            shape = [1, array_ops.size(pop_mean), 1, 1]
            pop_mean = array_ops.reshape(pop_mean, shape)
            pop_var = array_ops.reshape(pop_var, shape)
            scale = array_ops.reshape(scale, shape)

        grad_offset = math_ops.reduce_sum(grad_y, axis=reduce_axis)
        var_rsqrt = math_ops.rsqrt(pop_var + epsilon)
        grad_scale = math_ops.reduce_sum(grad_y * (x - pop_mean) * var_rsqrt,
                                         axis=reduce_axis)
        grad_x = grad_y * scale * var_rsqrt
        return math_ops.cast(grad_x, x_dtype), grad_scale, grad_offset
Example #30
def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay,
                                 fused_batch_norm):
  """Computes batch norm correction params.

     Before batch normalization is frozen:
     We use batch statistics for batch norm.
       correction_scale = sigma_b/sigma_mv
       correction_recip = 1/correction_scale
       correction_offset = 0

     After batch normalization is frozen:
      correction_scale = sigma_b/sigma_mv
      correction_recip = 1
      correction_offset =  gamma*(mu_b/sigma_b-mu_mv/sigma_mv).

     Batch norm is frozen if global_step > bn_freeze_delay.
     The corrections ensure that:
     a) The weights are quantized after scaling by gamma/sigma_mv. This enables
     smoother training as the scaling on the weights changes slowly, rather than
     jump across mini-batches
     b) Changing the values of the corrections allows for one to switch between
     using batch statistics to using moving mean and average, without requiring
     changes to batch_norm


  Args:
    context: The scope under which we look for batch norm params
    match: Object containing required batch norm tensors for correction
      computation.
    freeze_batch_norm_delay: Delay in steps at which computation switches
      from regular batch norm to frozen mean and variance.
    fused_batch_norm: Bool, true if fused batch norm is used.

  Returns:
    A tuple of correction_scale, correction_recip, correction_offset
  """

  g = ops.get_default_graph()
  prefix = '' if not context else context + '/'
  with g.name_scope(prefix + 'batch_norm_correction'):
    recip_sigma_mv = math_ops.rsqrt(
        match.moving_variance_tensor + match.batch_epsilon)
    recip_sigma = math_ops.rsqrt(match.variance_tensor + match.batch_epsilon)
    correction_scale = math_ops.divide(
        recip_sigma_mv, recip_sigma, name='scale_compute')
    correction_scale = array_ops.identity(
        correction_scale, name='correction_scale')
    correction_recip = math_ops.reciprocal(
        correction_scale, name='reciprocal_compute')
    correction_offset = math_ops.multiply(
        match.gamma_tensor,
        match.mean_tensor * recip_sigma -
        match.moving_mean_tensor * recip_sigma_mv,
        name='offset_compute')

    if freeze_batch_norm_delay is not None:
      use_mv_avg = math_ops.greater_equal(
          common.CreateOrGetQuantizationStep(),
          freeze_batch_norm_delay,
          name='use_moving_average')
    else:
      use_mv_avg = False

    bn_decay_zero = 0.0
    bn_decay_mean_consumers = list(match.bn_decay_mean_tensor.consumers())
    bn_decay_var_consumers = list(match.bn_decay_mean_tensor.consumers())

    bn_decay_mean_out = utils.smart_cond(
        use_mv_avg,
        lambda: bn_decay_zero,
        lambda: match.bn_decay_mean_tensor,
        name='freeze_moving_mean')
    graph_editor.reroute_ts(
        [bn_decay_mean_out], [match.bn_decay_mean_tensor],
        can_modify=bn_decay_mean_consumers)

    if fused_batch_norm is False:
      bn_decay_var_consumers = list(match.bn_decay_var_tensor.consumers())
      bn_decay_var_out = utils.smart_cond(
          use_mv_avg,
          lambda: bn_decay_zero,
          lambda: match.bn_decay_var_tensor,
          name='freeze_moving_var')
      graph_editor.reroute_ts(
          [bn_decay_var_out], [match.bn_decay_var_tensor],
          can_modify=bn_decay_var_consumers)

    correction_recip = utils.smart_cond(
        use_mv_avg,
        lambda: array_ops.ones(correction_scale.shape),
        lambda: correction_recip,
        name='correction_recip')

    correction_offset = utils.smart_cond(
        use_mv_avg,
        lambda: correction_offset,
        lambda: array_ops.zeros(correction_offset.shape),
        name='correction_offset')
  return correction_scale, correction_recip, correction_offset
Example #31
def _BatchNormGrad(grad_y,
                   x,
                   scale,
                   pop_mean,
                   pop_var,
                   epsilon,
                   data_format,
                   is_training=True):
  """Returns the gradients for the 3 inputs of BatchNorm.

  Args:
    grad_y: A `Tensor` of 4 dimensions for gradient for y.
    x: A `Tensor` of 4 dimensions for x.
    scale: A `Tensor` of 1 dimension for scaling.
    pop_mean: A `Tensor` of 1 dimension for the population mean. Only used when
      is_training=False.
    pop_var: A `Tensor` of 1 dimension for the population variance. Only used
      when is_training=False.
    epsilon: A small float number added to the variance of x.
    data_format: The data format for input. Either b"NHWC" or b"NCHW".
    is_training: A bool value to indicate the operation is for training
      (default) or inference.

  Returns:
    A tuple (grad_x, grad_scale, grad_offset), where grad_x is the gradient
    for x, grad_scale the gradient for scale, and grad_offset the gradient
    for offset.
  """
  x_dtype = x.dtype.base_dtype
  if x_dtype == dtypes.float16:
    # float16 math is too imprecise, so we do the batch norm gradient
    # computations in float32.
    x = math_ops.cast(x, dtypes.float32)
    grad_y = math_ops.cast(grad_y, dtypes.float32)
  if is_training:
    if data_format == b"NHWC":
      keepdims = False
      reduce_axis = [0, 1, 2]
    else:
      keepdims = True
      reduce_axis = [0, 2, 3]
      shape = [1, array_ops.size(scale), 1, 1]
      scale = array_ops.reshape(scale, shape)
    mean_grad_y = math_ops.reduce_mean(grad_y, reduce_axis, keepdims=keepdims)
    mean_x = math_ops.reduce_mean(x, reduce_axis, keepdims=keepdims)
    var_x = math_ops.reduce_mean(
        math_ops.squared_difference(x, array_ops.stop_gradient(mean_x)),
        reduce_axis,
        keepdims=keepdims)
    grad_y_offset = grad_y - mean_grad_y
    x_offset = x - mean_x
    mean = math_ops.reduce_mean(
        grad_y * x_offset, axis=reduce_axis, keepdims=keepdims)
    grad_x = scale * math_ops.rsqrt(var_x + epsilon) * (
        grad_y_offset - math_ops.reciprocal(var_x + epsilon) * mean * x_offset)
    grad_scale = math_ops.rsqrt(var_x + epsilon) * math_ops.reduce_sum(
        grad_y * x_offset, axis=reduce_axis, keepdims=keepdims)
    if data_format == b"NCHW":
      grad_scale = array_ops.squeeze(grad_scale)
    grad_offset = math_ops.reduce_sum(grad_y, axis=reduce_axis)
    return math_ops.cast(grad_x, x_dtype), grad_scale, grad_offset
  else:
    if data_format == b"NHWC":
      reduce_axis = [0, 1, 2]
    else:
      reduce_axis = [0, 2, 3]
      shape = [1, array_ops.size(pop_mean), 1, 1]
      pop_mean = array_ops.reshape(pop_mean, shape)
      pop_var = array_ops.reshape(pop_var, shape)
      scale = array_ops.reshape(scale, shape)

    grad_offset = math_ops.reduce_sum(grad_y, axis=reduce_axis)
    var_rsqrt = math_ops.rsqrt(pop_var + epsilon)
    grad_scale = math_ops.reduce_sum(
        grad_y * (x - pop_mean) * var_rsqrt, axis=reduce_axis)
    grad_x = grad_y * scale * var_rsqrt
    return math_ops.cast(grad_x, x_dtype), grad_scale, grad_offset
Example #32
def _SvdGrad(op, grad_s, grad_u, grad_v):
  """Gradient for the singular value decomposition."""

  # The derivation for the compute_uv=False case, and most of
  # the derivation for the full_matrices=True case, are in
  # Giles' paper (see reference at top of file).  A derivation for
  # the full_matrices=False case is available at
  # https://j-towns.github.io/papers/svd-derivative.pdf
  a = op.inputs[0]
  a_shape = a.get_shape().with_rank_at_least(2)
  grad_s_mat = array_ops.matrix_diag(grad_s)

  if not op.get_attr("compute_uv"):
    s, u, v = linalg_ops.svd(a, compute_uv=True)
    grad_a = math_ops.matmul(u, math_ops.matmul(grad_s_mat, v, adjoint_b=True))
    grad_a.set_shape(a_shape)
    return grad_a

  full_matrices = op.get_attr("full_matrices")

  # TODO(rmlarsen): Make this work with complex types.
  if a.dtype.is_complex:
    raise NotImplementedError(
        "SVD gradient is not implemented for complex types and "
        "compute_uv=True.")
  grad_u_shape = grad_u.get_shape().with_rank_at_least(2)
  grad_v_shape = grad_v.get_shape().with_rank_at_least(2)
  m = a_shape.dims[-2].merge_with(grad_u_shape[-2])
  n = a_shape.dims[-1].merge_with(grad_v_shape[-2])
  batch_shape = a_shape[:-2].merge_with(grad_u_shape[:-2]).merge_with(
      grad_v_shape[:-2])
  a_shape = batch_shape.concatenate([m, n])

  m = a_shape.dims[-2].value
  n = a_shape.dims[-1].value
  # TODO(rmlarsen): Make this work with placeholders.
  if m is None or n is None:
    raise NotImplementedError(
        "SVD gradient has not been implemented for input with unknown "
        "inner matrix shape.")

  s = op.outputs[0]
  u = op.outputs[1]
  v = op.outputs[2]

  use_adjoint = False
  if m > n:
    # Compute the gradient for A^H = V * S^T * U^H, and (implicitly) take the
    # Hermitian transpose of the gradient at the end.
    use_adjoint = True
    m, n = n, m
    u, v = v, u
    grad_u, grad_v = grad_v, grad_u

  with ops.control_dependencies([grad_s, grad_u, grad_v]):
    if full_matrices and abs(m - n) > 1:
      raise NotImplementedError(
          "svd gradient is not implemented for abs(m - n) > 1 "
          "when full_matrices is True")
    s_mat = array_ops.matrix_diag(s)
    s2 = math_ops.square(s)

    # NOTICE: Because of the term involving f, the gradient becomes
    # infinite (or NaN in practice) when singular values are not unique.
    # Mathematically this should not be surprising, since for (k-fold)
    # degenerate singular values, the corresponding singular vectors are
    # only defined up a (k-dimensional) subspace. In practice, this can
    # lead to numerical instability when singular values are close but not
    # exactly equal.
    f = array_ops.matrix_set_diag(
        math_ops.reciprocal(
            array_ops.expand_dims(s2, -2) - array_ops.expand_dims(s2, -1)),
        array_ops.zeros_like(s))
    s_inv_mat = array_ops.matrix_diag(math_ops.reciprocal(s))

    v1 = v[..., :, :m]
    grad_v1 = grad_v[..., :, :m]

    u_gu = math_ops.matmul(u, grad_u, adjoint_a=True)
    v_gv = math_ops.matmul(v1, grad_v1, adjoint_a=True)

    f_u = f * u_gu
    f_v = f * v_gv

    term1_nouv = (
        grad_s_mat + math_ops.matmul(f_u + _linalg.adjoint(f_u), s_mat) +
        math_ops.matmul(s_mat, f_v + _linalg.adjoint(f_v)))

    term1 = math_ops.matmul(u, math_ops.matmul(term1_nouv, v1, adjoint_b=True))

    if m == n:
      grad_a_before_transpose = term1
    else:
      gv1t = array_ops.matrix_transpose(grad_v1)
      gv1t_v1 = math_ops.matmul(gv1t, v1)
      term2_nous = gv1t - math_ops.matmul(gv1t_v1, v1, adjoint_b=True)

      if full_matrices:
        v2 = v[..., :, m:n]
        grad_v2 = grad_v[..., :, m:n]

        v1t_gv2 = math_ops.matmul(v1, grad_v2, adjoint_a=True)
        term2_nous -= math_ops.matmul(v1t_gv2, v2, adjoint_b=True)

      u_s_inv = math_ops.matmul(u, s_inv_mat)
      term2 = math_ops.matmul(u_s_inv, term2_nous)

      grad_a_before_transpose = term1 + term2

    if use_adjoint:
      grad_a = array_ops.matrix_transpose(grad_a_before_transpose)
    else:
      grad_a = grad_a_before_transpose

    grad_a.set_shape(a_shape)
    return grad_a
Example #33
 def _get_matches_hook(self, y_pred_click_ranks):
     """Return reciprocal click ranks for MRR"""
     return math_ops.reciprocal(tf.cast(y_pred_click_ranks, tf.float32))
Example #34
def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay):
    """Computes batch norm correction params.

     Before batch normalization is frozen:
     We use batch statistics for batch norm.
       correction_scale = sigma_b/sigma_mv
       correction_recip = 1/correction_scale
       correction_offset = 0

     After batch normalization is frozen:
      correction_scale = sigma_b/sigma_mv
      correction_recip = 1
      correction_offset =  gamma*(mu_b/sigma_b-mu_mv/sigma_mv).

     Batch norm is frozen if global_step > bn_freeze_delay.
     The corrections ensure that:
     a) The weights are quantized after scaling by gamma/sigma_mv. This enables
     smoother training as the scaling on the weights changes slowly, rather than
     jump across mini-batches
     b) Changing the values of the corrections allows for one to switch between
     using batch statistics to using moving mean and average, without requiring
     changes to batch_norm


  Args:
    context: The scope under which we look for batch norm params
    match: Object containing required batch norm tensors for correction
      computation.
    freeze_batch_norm_delay: Delay in steps at which computation switches
      from regular batch norm to frozen mean and variance.


  Returns:
    A tuple of correction_scale, correction_recip, correction_offset
  """

    g = ops.get_default_graph()
    prefix = '' if not context else context
    with g.name_scope(prefix + 'batch_norm_correction'):
        recip_sigma_mv = math_ops.rsqrt(match.moving_variance_tensor +
                                        match.batch_epsilon)
        recip_sigma = math_ops.rsqrt(match.variance_tensor +
                                     match.batch_epsilon)
        correction_scale = math_ops.divide(recip_sigma_mv,
                                           recip_sigma,
                                           name='scale_compute')
        correction_scale = array_ops.identity(correction_scale,
                                              name='correction_scale')
        correction_recip = math_ops.reciprocal(correction_scale,
                                               name='reciprocal_compute')
        mv = match.moving_mean_tensor  #if match.moving_mean_tensor is not None else 0
        correction_offset = math_ops.multiply(match.gamma_tensor,
                                              match.mean_tensor * recip_sigma -
                                              mv,
                                              name='offset_compute')

        if freeze_batch_norm_delay is not None:
            use_mv_avg = math_ops.greater_equal(
                common.CreateOrGetQuantizationStep(),
                freeze_batch_norm_delay,
                name='use_moving_average')
        else:
            use_mv_avg = False

        bn_decay_zero = 0.0
        bn_decay_mean_consumers = list(match.bn_decay_mean_tensor.consumers())
        bn_decay_var_consumers = list(match.bn_decay_mean_tensor.consumers())

        bn_decay_mean_out = utils.smart_cond(
            use_mv_avg,
            lambda: bn_decay_zero,
            lambda: match.bn_decay_mean_tensor,
            name='freeze_moving_mean')

        common.RerouteTensor(bn_decay_mean_out,
                             match.bn_decay_mean_tensor,
                             can_modify=bn_decay_mean_consumers)

        bn_decay_var_consumers = list(match.bn_decay_var_tensor.consumers())
        bn_decay_var_out = utils.smart_cond(use_mv_avg,
                                            lambda: bn_decay_zero,
                                            lambda: match.bn_decay_var_tensor,
                                            name='freeze_moving_var')
        common.RerouteTensor(bn_decay_var_out,
                             match.bn_decay_var_tensor,
                             can_modify=bn_decay_var_consumers)

        correction_recip = utils.smart_cond(
            use_mv_avg,
            lambda: array_ops.ones(correction_scale.shape),
            lambda: correction_recip,
            name='correction_recip')

        correction_offset = utils.smart_cond(
            use_mv_avg,
            lambda: correction_offset,
            lambda: array_ops.zeros(correction_offset.shape),
            name='correction_offset')
    return correction_scale, correction_recip, correction_offset
Example #35
def _Log1pGrad(op, grad):
    """Returns grad * (1/(1 + x))."""
    x = op.inputs[0]
    with ops.control_dependencies([grad]):
        x = math_ops.conj(x)
        return grad * math_ops.reciprocal(1 + x)
Example #36
def _SvdGrad(op, grad_s, grad_u, grad_v):
  """Gradient for Svd based on Giles' algorithm. Reference at top of file."""

  if op.get_attr("compute_uv") and not op.get_attr("full_matrices"):
    raise NotImplementedError(
        "SVD gradient is not implemented for compute_uv=True and "
        "full_matrices=False.")

  a = op.inputs[0]
  a_shape = a.get_shape().with_rank_at_least(2)

  if op.get_attr("compute_uv"):
    # TODO(rmlarsen): Make this work with complex types.
    if a.dtype.is_complex:
      raise NotImplementedError(
          "SVD gradient is not implemented for complex types and "
          "compute_uv=True.")
    grad_u_shape = grad_u.get_shape().with_rank_at_least(2)
    grad_v_shape = grad_v.get_shape().with_rank_at_least(2)
    m = a_shape[-2].merge_with(grad_u_shape[-2])
    n = a_shape[-1].merge_with(grad_v_shape[-2])
    batch_shape = a_shape[:-2].merge_with(grad_u_shape[:-2]).merge_with(
        grad_v_shape[:-2])
    a_shape = batch_shape.concatenate([m, n])

  m = a_shape[-2].value
  n = a_shape[-1].value
  # TODO(rmlarsen): Make this work with placeholders.
  if m is None or n is None:
    raise NotImplementedError(
        "SVD gradient has not been implemented for input with unknown "
        "inner matrix shape.")

  if not op.get_attr("full_matrices") or not op.get_attr("compute_uv"):
    s, u, v = linalg_ops.svd(a, compute_uv=True, full_matrices=True)
  else:
    s = op.outputs[0]
    u = op.outputs[1]
    v = op.outputs[2]

  use_adjoint = False
  if m > n:
    # Compute the gradient for A^H = V * S^T * U^H, and (implicitly) take the
    # Hermitian transpose of the gradient at the end.
    use_adjoint = True
    m, n = n, m
    u, v = v, u
    grad_u, grad_v = grad_v, grad_u

  with ops.control_dependencies([grad_s, grad_u, grad_v]):
    grad_s_mat = array_ops.matrix_diag(grad_s)
    if not op.get_attr("compute_uv"):
      if use_adjoint:
        grad_a = math_ops.matmul(
            v[..., :, :m], math_ops.matmul(u, grad_s_mat), adjoint_b=True)
      else:
        grad_a = math_ops.matmul(u,
                                 math_ops.matmul(
                                     grad_s_mat, v[..., :, :m], adjoint_b=True))
      grad_a.set_shape(a_shape)
      return grad_a

    # TODO(rmlarsen): Define a gradient that is numerically stable for
    # abs(m-n) > 1. Currently this does not work because there are effectively
    # multiple singular values with value zero. I am not sure if this is a true
    # instability or if it simply throws off the finite difference gradient
    # checker.
    if abs(m - n) > 1:
      raise NotImplementedError(
          "svd gradient is not implemented for abs(m - n) > 1")
    s_mat = array_ops.matrix_diag(s)
    s2 = math_ops.square(s)

    # NOTICE: Because of the term involving f, the gradient becomes
    # infinite (or NaN in practice) when singular values are not unique.
    # Mathematically this should not be surprising, since for (k-fold)
    # degenerate singular values, the corresponding singular vectors are
    # only defined up a (k-dimensional) subspace. In practice, this can
    # lead to numerical instability when singular values are close but not
    # exactly equal.
    f = array_ops.matrix_set_diag(
        math_ops.reciprocal(
            array_ops.expand_dims(s2, -2) - array_ops.expand_dims(s2, -1)),
        array_ops.zeros_like(s))
    s_inv_mat = array_ops.matrix_diag(math_ops.reciprocal(s))
    u_gu = math_ops.matmul(u, grad_u, adjoint_a=True)
    v_gv = math_ops.matmul(v, grad_v, adjoint_a=True)

    if m == n:
      f_u = f * u_gu
      f_v = f * v_gv
    else:
      dv2 = array_ops.matrix_transpose(v_gv[..., m:n, :m]) - v_gv[..., :m, m:n]
      f_u = f * u_gu
      f_v = f * v_gv[..., :m, :m]

    grad_a_nouv = (
        grad_s_mat + math_ops.matmul(f_u + _linalg.adjoint(f_u), s_mat) +
        math_ops.matmul(s_mat, f_v + _linalg.adjoint(f_v)))

    if m != n:
      grad_a_nouv = array_ops.concat(
          [grad_a_nouv, math_ops.matmul(s_inv_mat, dv2)], -1)

    if use_adjoint:
      # Use (U X V^H)^H = V (U X)^H.
      grad_a = math_ops.matmul(
          v, math_ops.matmul(u, grad_a_nouv), adjoint_b=True)
    else:
      grad_a = math_ops.matmul(u,
                               math_ops.matmul(grad_a_nouv, v, adjoint_b=True))

    grad_a.set_shape(a_shape)
    return grad_a
Example #37
    def weighted_moments(x,
                         axes,
                         frequency_weights,
                         tower_config,
                         name=None,
                         keep_dims=False):
        """Returns the frequency-weighted mean and variance of `x`.
        Args:
          x: A tensor.
          axes: 1-d tensor of int32 values; these are the axes along which
            to compute mean and variance.
          frequency_weights: A tensor of positive weights which can be
            broadcast with x.
          name: Name used to scope the operation.
          keep_dims: Produce moments with the same dimensionality as the input.
        Returns:
          Two tensors: `weighted_mean` and `weighted_variance`.
        """
        with ops.name_scope(name, "weighted_moments",
                            [x, frequency_weights, axes]):
            x = ops.convert_to_tensor(x, name="x")
            frequency_weights = ops.convert_to_tensor(frequency_weights,
                                                      name="frequency_weights")

            # Unlike moments(), this just uses a simpler two-pass method.

            # See comment in moments() WRT precision; it applies here too.
            needs_cast = x.dtype == dtypes.float16
            if needs_cast:
                x = math_ops.cast(x, dtypes.float32)

            if frequency_weights.dtype != x.dtype:
                frequency_weights = math_ops.cast(frequency_weights, x.dtype)

            # Note that we use keep_dims=True for our reductions regardless of the arg;
            # this is so that the results remain broadcast-compatible with the inputs.

            # Original Code: weighted_input_sum = math_ops.reduce_sum(
            #     frequency_weights * x, axes, name="weighted_input_sum", keepdims=True)

            nccl_name = "NCCL" if not tower_config.is_test else "NCCL_TEST"
            shared_name = tf.get_variable_scope().name. \
                replace(tower_config.name, tower_config.prefix.format(nccl_name))
            device_weighted_input_sum = math_ops.reduce_sum(
                frequency_weights * x,
                axes,
                name="weighted_input_sum",
                keepdims=True)
            weighted_input_sum = gen_nccl_ops.nccl_all_reduce(
                input=device_weighted_input_sum,
                reduction="sum",
                num_devices=tower_config.num_devices,
                shared_name=shared_name) / (1.0 * tower_config.num_devices)

            # The shape of the weights isn't necessarily the same as x's
            # shape, just broadcast-compatible with it -- so this expression
            # performs broadcasting to give a per-item weight, with the same
            # shape as (frequency_weights * x). This avoids having to reason
            # through all the broadcast logic to compute a correct
            # sum_of_weights.
            broadcasted_weights = frequency_weights + array_ops.zeros_like(x)

            sum_of_weights = math_ops.reduce_sum(broadcasted_weights,
                                                 axes,
                                                 name="sum_of_weights",
                                                 keepdims=True)

            divisor = math_ops.reciprocal(sum_of_weights,
                                          name="inv_weight_sum")

            weighted_mean = math_ops.multiply(weighted_input_sum, divisor)

            # Have the weighted mean; now on to variance:
            # weighted_distsq = math_ops.reduce_sum(
            #     frequency_weights * math_ops.squared_difference(x, weighted_mean),
            #     axes,
            #     name="weighted_distsq",
            #     keepdims=True)

            nccl_name = "NCCL" if not tower_config.is_test else "NCCL_TEST"
            shared_name = tf.get_variable_scope().name. \
                replace(tower_config.name, tower_config.prefix.format(nccl_name))
            device_weighted_distsq = math_ops.reduce_sum(
                frequency_weights *
                math_ops.squared_difference(x, weighted_mean),
                axes,
                name="weighted_distsq",
                keepdims=True)
            weighted_distsq = gen_nccl_ops.nccl_all_reduce(
                input=device_weighted_distsq,
                reduction="sum",
                num_devices=tower_config.num_devices,
                shared_name=shared_name) / (1.0 * tower_config.num_devices)

            weighted_variance = math_ops.multiply(weighted_distsq, divisor)

            if not keep_dims:
                weighted_mean = array_ops.squeeze(weighted_mean, axis=axes)
                weighted_variance = array_ops.squeeze(weighted_variance,
                                                      axis=axes)

            if needs_cast:
                weighted_mean = math_ops.cast(weighted_mean, dtypes.float16)
                weighted_variance = math_ops.cast(weighted_variance,
                                                  dtypes.float16)

            return weighted_mean, weighted_variance
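Setting the NCCL all-reduce aside, the math in this multi-tower variant is still the plain two-pass weighted moments: mean = sum(w * x) / sum(w) and variance = sum(w * (x - mean)**2) / sum(w), each averaged over devices. A NumPy reference of just that math is useful for sanity-checking a single-tower run; the sketch below is illustrative only (the inputs are made up and the multi-device path is not exercised):

import numpy as np

def weighted_moments_ref(x, w, axes):
  # Two-pass reference: weighted mean first, then weighted variance around it.
  w = np.broadcast_to(w, x.shape).astype(x.dtype)
  sum_w = np.sum(w, axis=axes, keepdims=True)
  mean = np.sum(w * x, axis=axes, keepdims=True) / sum_w
  var = np.sum(w * np.square(x - mean), axis=axes, keepdims=True) / sum_w
  return np.squeeze(mean, axis=axes), np.squeeze(var, axis=axes)

x = np.random.randn(4, 3).astype(np.float32)
w = np.array([[1.0], [2.0], [1.0], [3.0]], dtype=np.float32)  # broadcasts across columns
mean, var = weighted_moments_ref(x, w, axes=(0,))
print(mean.shape, var.shape)  # (3,) (3,)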
Example #38
def _Log1pGrad(op, grad):
  """Returns grad * (1/(1 + x))."""
  x = op.inputs[0]
  with ops.control_dependencies([grad]):
    x = math_ops.conj(x)
    return grad * math_ops.reciprocal(1 + x)
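As a quick sanity check (not part of the snippet), the 1 / (1 + x) factor agrees with a central finite difference of log1p; the point x0 and step eps below are arbitrary:

import numpy as np

x0, eps = 0.37, 1e-6
analytic = 1.0 / (1.0 + x0)                                      # factor used by _Log1pGrad
numeric = (np.log1p(x0 + eps) - np.log1p(x0 - eps)) / (2 * eps)  # central difference
print(abs(analytic - numeric))  # on the order of 1e-10 or smaller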
Example #39
def _SvdGrad(op, grad_s, grad_u, grad_v):
    """Gradient for Svd based on Giles' algorithm. Reference at top of file."""

    if op.get_attr("compute_uv") and not op.get_attr("full_matrices"):
        raise NotImplementedError(
            "SVD gradient is not implemented for compute_uv=True and "
            "full_matrices=False.")

    a = op.inputs[0]
    a_shape = a.get_shape().with_rank_at_least(2)

    if op.get_attr("compute_uv"):
        # TODO(rmlarsen): Make this work with complex types.
        if a.dtype.is_complex:
            raise NotImplementedError(
                "SVD gradient is not implemented for complex types and "
                "compute_uv=True.")
        grad_u_shape = grad_u.get_shape().with_rank_at_least(2)
        grad_v_shape = grad_v.get_shape().with_rank_at_least(2)
        m = a_shape[-2].merge_with(grad_u_shape[-2])
        n = a_shape[-1].merge_with(grad_v_shape[-2])
        batch_shape = a_shape[:-2].merge_with(grad_u_shape[:-2]).merge_with(
            grad_v_shape[:-2])
        a_shape = batch_shape.concatenate([m, n])

    m = a_shape[-2].value
    n = a_shape[-1].value
    # TODO(rmlarsen): Make this work with placeholders.
    if m is None or n is None:
        raise NotImplementedError(
            "SVD gradient has not been implemented for input with unknown "
            "inner matrix shape.")

    if not op.get_attr("full_matrices") or not op.get_attr("compute_uv"):
        s, u, v = linalg_ops.svd(a, compute_uv=True, full_matrices=True)
    else:
        s = op.outputs[0]
        u = op.outputs[1]
        v = op.outputs[2]

    use_adjoint = False
    if m > n:
        # Compute the gradient for A^H = V * S^T * U^H, and (implicitly) take the
        # Hermitian transpose of the gradient at the end.
        use_adjoint = True
        m, n = n, m
        u, v = v, u
        grad_u, grad_v = grad_v, grad_u

    with ops.control_dependencies([grad_s, grad_u, grad_v]):
        grad_s_mat = array_ops.matrix_diag(grad_s)
        if not op.get_attr("compute_uv"):
            if use_adjoint:
                grad_a = math_ops.matmul(v[..., :, :m],
                                         math_ops.matmul(u, grad_s_mat),
                                         adjoint_b=True)
            else:
                grad_a = math_ops.matmul(
                    u,
                    math_ops.matmul(grad_s_mat, v[..., :, :m], adjoint_b=True))
            grad_a.set_shape(a_shape)
            return grad_a

        # TODO(rmlarsen): Define a gradient that is numerically stable for
        # abs(m-n) > 1. Currently this does not work because there are effectively
        # multiple singular values with value zero. I am not sure if this is a true
        # instability or if it simply throws off the finite difference gradient
        # checker.
        if abs(m - n) > 1:
            raise NotImplementedError(
                "svd gradient is not implemented for abs(m - n) > 1")
        s_mat = array_ops.matrix_diag(s)
        s2 = math_ops.square(s)

        # NOTICE: Because of the term involving f, the gradient becomes
        # infinite (or NaN in practice) when singular values are not unique.
        # Mathematically this should not be surprising, since for (k-fold)
        # degenerate singular values, the corresponding singular vectors are
        # only defined up to a (k-dimensional) subspace. In practice, this can
        # lead to numerical instability when singular values are close but not
        # exactly equal.
        f = array_ops.matrix_set_diag(
            math_ops.reciprocal(
                array_ops.expand_dims(s2, -2) - array_ops.expand_dims(s2, -1)),
            array_ops.zeros_like(s))
        s_inv_mat = array_ops.matrix_diag(math_ops.reciprocal(s))
        u_gu = math_ops.matmul(u, grad_u, adjoint_a=True)
        v_gv = math_ops.matmul(v, grad_v, adjoint_a=True)

        if m == n:
            f_u = f * u_gu
            f_v = f * v_gv
        else:
            dv2 = array_ops.matrix_transpose(
                v_gv[..., m:n, :m]) - v_gv[..., :m, m:n]
            f_u = f * u_gu
            f_v = f * v_gv[..., :m, :m]

        grad_a_nouv = (grad_s_mat +
                       math_ops.matmul(f_u + _linalg.adjoint(f_u), s_mat) +
                       math_ops.matmul(s_mat, f_v + _linalg.adjoint(f_v)))

        if m != n:
            grad_a_nouv = array_ops.concat(
                [grad_a_nouv, math_ops.matmul(s_inv_mat, dv2)], -1)

        if use_adjoint:
            # Use (U X V^H)^H = V (U X)^H.
            grad_a = math_ops.matmul(v,
                                     math_ops.matmul(u, grad_a_nouv),
                                     adjoint_b=True)
        else:
            grad_a = math_ops.matmul(
                u, math_ops.matmul(grad_a_nouv, v, adjoint_b=True))

        grad_a.set_shape(a_shape)
        return grad_a
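The use_adjoint branch leans on the identity A = U S V^H <=> A^H = V S^T U^H: a tall input is handled by swapping (m, n), (u, v), and (grad_u, grad_v) up front and transposing the result at the end. A short NumPy check of that identity on a random real matrix (illustrative only):

import numpy as np

a = np.random.randn(5, 3)  # tall: m > n
u, s, vh = np.linalg.svd(a, full_matrices=True)
ut, st, vht = np.linalg.svd(a.T, full_matrices=True)

s_mat = np.zeros((5, 3))
np.fill_diagonal(s_mat, s)
st_mat = np.zeros((3, 5))
np.fill_diagonal(st_mat, st)

print(np.allclose(u @ s_mat @ vh, a))         # A = U S V^H
print(np.allclose((ut @ st_mat @ vht).T, a))  # transposing the SVD of A^H recovers A too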
Example #40
def weighted_moments(x, axes, frequency_weights, name=None, keep_dims=False):
  """Returns the frequency-weighted mean and variance of `x`.

  Args:
    x: A tensor.
    axes: 1-d tensor of int32 values; these are the axes along which
      to compute mean and variance.
    frequency_weights: A tensor of positive weights which can be
      broadcast with x.
    name: Name used to scope the operation.
    keep_dims: Produce moments with the same dimensionality as the input.

  Returns:
    Two tensors: `weighted_mean` and `weighted_variance`.
  """
  with ops.name_scope(name, "weighted_moments", [x, frequency_weights, axes]):
    x = ops.convert_to_tensor(x, name="x")
    frequency_weights = ops.convert_to_tensor(
        frequency_weights, name="frequency_weights")

    # Unlike moments(), this just uses a simpler two-pass method.

    # See comment in moments() WRT precision; it applies here too.
    needs_cast = x.dtype == dtypes.float16
    if needs_cast:
      x = math_ops.cast(x, dtypes.float32)

    if frequency_weights.dtype != x.dtype:
      frequency_weights = math_ops.cast(frequency_weights, x.dtype)

    # Note that we use keep_dims=True for our reductions regardless of the arg;
    # this is so that the results remain broadcast-compatible with the inputs.
    weighted_input_sum = math_ops.reduce_sum(
        frequency_weights * x, axes, name="weighted_input_sum", keep_dims=True)

    # The shape of the weights isn't necessarily the same as x's
    # shape, just broadcast-compatible with it -- so this expression
    # performs broadcasting to give a per-item weight, with the same
    # shape as (frequency_weights * x). This avoids having to reason
    # through all the broadcast logic to compute a correct
    # sum_of_weights.
    broadcasted_weights = frequency_weights + array_ops.zeros_like(x)

    sum_of_weights = math_ops.reduce_sum(
        broadcasted_weights, axes, name="sum_of_weights", keep_dims=True)

    divisor = math_ops.reciprocal(sum_of_weights, name="inv_weight_sum")

    weighted_mean = math_ops.mul(weighted_input_sum, divisor)

    # Have the weighted mean; now on to variance:
    weighted_distsq = math_ops.reduce_sum(
        frequency_weights * math_ops.squared_difference(x, weighted_mean),
        axes,
        name="weighted_distsq",
        keep_dims=True)

    weighted_variance = math_ops.mul(weighted_distsq, divisor)

    if not keep_dims:
      weighted_mean = array_ops.squeeze(weighted_mean, squeeze_dims=axes)
      weighted_variance = array_ops.squeeze(
          weighted_variance, squeeze_dims=axes)

    if needs_cast:
      weighted_mean = math_ops.cast(weighted_mean, dtypes.float16)
      weighted_variance = math_ops.cast(weighted_variance, dtypes.float16)

    return weighted_mean, weighted_variance
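The comment about keep_dims=True is easiest to see with shapes: keeping the reduced axes as size-1 dimensions lets `weighted_mean` broadcast back against `x` when the squared differences are computed, and the squeeze only happens at the very end. A small NumPy shape sketch (illustrative only):

import numpy as np

x = np.random.randn(4, 3)
w = np.ones_like(x)

mean_keep = np.sum(w * x, axis=1, keepdims=True) / np.sum(w, axis=1, keepdims=True)
print(mean_keep.shape)        # (4, 1): broadcasts cleanly against x
print((x - mean_keep).shape)  # (4, 3)

mean_flat = np.sum(w * x, axis=1) / np.sum(w, axis=1)
print(mean_flat.shape)        # (4,): `x - mean_flat` would raise a broadcast error,
                              # which is why the reductions above keep dims and only
                              # squeeze once the variance has been computed.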
Example #41
def _SafeReciprocal(x, epsilon=1E-20):
    return x * math_ops.reciprocal(x * x + epsilon)
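_SafeReciprocal computes x / (x**2 + epsilon), which is approximately 1/x once |x| is well above sqrt(epsilon) and smoothly returns 0 at x = 0 instead of inf. A quick illustrative check (the sample points are arbitrary):

import numpy as np

def safe_reciprocal(x, epsilon=1e-20):
  return x / (x * x + epsilon)

for x in [1.0, 1e-3, 1e-9, 0.0]:
  print(x, safe_reciprocal(x), np.inf if x == 0 else 1.0 / x)
# 1.0   -> 1.0       (matches 1/x)
# 1e-3  -> 1000.0    (matches 1/x)
# 1e-9  -> ~9.9e8    (starts to deviate as x**2 approaches epsilon)
# 0.0   -> 0.0       (instead of inf)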
Example #42
  def _mini_batch_training_op(self, inputs, cluster_idx_list, cluster_centers,
                              cluster_centers_var, total_counts):
    """Creates an op for training for mini batch case.

    Args:
      inputs: list of input Tensors.
      cluster_idx_list: A vector (or list of vectors). Each element in the
        vector corresponds to an input row in 'inp' and specifies the cluster id
        corresponding to the input.
      cluster_centers: Tensor of cluster centers, possibly normalized.
      cluster_centers_var: Tensor Ref of cluster centers.
      total_counts: Tensor Ref of cluster counts.

    Returns:
      An op for doing an update of mini-batch k-means.
    """
    update_ops = []
    for inp, cluster_idx in zip(inputs, cluster_idx_list):
      with ops.colocate_with(inp):
        assert total_counts is not None
        cluster_idx = array_ops.reshape(cluster_idx, [-1])
        # Dedupe the unique ids of cluster_centers being updated so that updates
        # can be locally aggregated.
        unique_ids, unique_idx = array_ops.unique(cluster_idx)
        num_unique_cluster_idx = array_ops.size(unique_ids)
        # Fetch the old values of counts and cluster_centers.
        with ops.colocate_with(total_counts):
          old_counts = array_ops.gather(total_counts, unique_ids)
        with ops.colocate_with(cluster_centers):
          old_cluster_centers = array_ops.gather(cluster_centers, unique_ids)
        # Locally aggregate the increment to counts.
        count_updates = math_ops.unsorted_segment_sum(
            array_ops.ones_like(
                unique_idx, dtype=total_counts.dtype),
            unique_idx,
            num_unique_cluster_idx)
        # Locally compute the sum of inputs mapped to each id.
        # For a cluster with old cluster value x, old count n, and with data
        # d_1,...d_k newly assigned to it, we recompute the new value as
        # x += (sum_i(d_i) - k * x) / (n + k).
        # Compute sum_i(d_i), see comment above.
        cluster_center_updates = math_ops.unsorted_segment_sum(
            inp, unique_idx, num_unique_cluster_idx)
        # Shape to enable broadcasting count_updates and learning_rate to inp.
        # It extends the shape with 1's to match the rank of inp.
        broadcast_shape = array_ops.concat(
            [
                array_ops.reshape(num_unique_cluster_idx, [1]), array_ops.ones(
                    array_ops.reshape(array_ops.rank(inp) - 1, [1]),
                    dtype=dtypes.int32)
            ],
            0)
        # Subtract k * x, see comment above.
        cluster_center_updates -= math_ops.cast(
            array_ops.reshape(count_updates, broadcast_shape),
            inp.dtype) * old_cluster_centers
        learning_rate = math_ops.reciprocal(
            math_ops.cast(old_counts + count_updates, inp.dtype))
        learning_rate = array_ops.reshape(learning_rate, broadcast_shape)
        # scale by 1 / (n + k), see comment above.
        cluster_center_updates *= learning_rate
        # Apply the updates.
      update_counts = state_ops.scatter_add(total_counts, unique_ids,
                                            count_updates)
      update_cluster_centers = state_ops.scatter_add(cluster_centers_var,
                                                     unique_ids,
                                                     cluster_center_updates)
      update_ops.extend([update_counts, update_cluster_centers])
    return control_flow_ops.group(*update_ops)
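The update rule in the comments, x += (sum_i(d_i) - k * x) / (n + k), is exactly the incremental form of the running mean: if x was the mean of the first n points assigned to a cluster, the update yields the mean of all n + k. A tiny NumPy check (illustrative only; the points are random):

import numpy as np

rng = np.random.default_rng(0)
old_points = rng.normal(size=(10, 2))  # n = 10 points already assigned to the cluster
new_points = rng.normal(size=(3, 2))   # k = 3 newly assigned points

n, k = len(old_points), len(new_points)
x = old_points.mean(axis=0)                                 # old cluster center
x_updated = x + (new_points.sum(axis=0) - k * x) / (n + k)  # the rule above

print(np.allclose(x_updated, np.vstack([old_points, new_points]).mean(axis=0)))  # True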
Example #43
    def _mini_batch_training_op(self, inputs, cluster_idx_list,
                                cluster_centers, total_counts):
        """Creates an op for training for mini batch case.

    Args:
      inputs: list of input Tensors.
      cluster_idx_list: A vector (or list of vectors). Each element in the
        vector corresponds to an input row in 'inp' and specifies the cluster id
        corresponding to the input.
      cluster_centers: Tensor Ref of cluster centers.
      total_counts: Tensor Ref of cluster counts.

    Returns:
      An op for doing an update of mini-batch k-means.
    """
        update_ops = []
        for inp, cluster_idx in zip(inputs, cluster_idx_list):
            with ops.colocate_with(inp):
                assert total_counts is not None
                cluster_idx = array_ops.reshape(cluster_idx, [-1])
                # Dedupe the unique ids of cluster_centers being updated so that updates
                # can be locally aggregated.
                unique_ids, unique_idx = array_ops.unique(cluster_idx)
                num_unique_cluster_idx = array_ops.size(unique_ids)
                # Fetch the old values of counts and cluster_centers.
                with ops.colocate_with(total_counts, ignore_existing=True):
                    old_counts = array_ops.gather(total_counts, unique_ids)
                # TODO(agarwal): This colocation seems to run into problems. Fix it.
                with ops.colocate_with(cluster_centers, ignore_existing=True):
                    old_cluster_centers = array_ops.gather(
                        cluster_centers, unique_ids)
                # Locally aggregate the increment to counts.
                count_updates = math_ops.unsorted_segment_sum(
                    array_ops.ones_like(unique_idx, dtype=total_counts.dtype),
                    unique_idx, num_unique_cluster_idx)
                # Locally compute the sum of inputs mapped to each id.
                # For a cluster with old cluster value x, old count n, and with data
                # d_1,...d_k newly assigned to it, we recompute the new value as
                # x += (sum_i(d_i) - k * x) / (n + k).
                # Compute sum_i(d_i), see comment above.
                cluster_center_updates = math_ops.unsorted_segment_sum(
                    inp, unique_idx, num_unique_cluster_idx)
                # Shape to enable broadcasting count_updates and learning_rate to inp.
                # It extends the shape with 1's to match the rank of inp.
                broadcast_shape = array_ops.concat([
                    array_ops.reshape(num_unique_cluster_idx, [1]),
                    array_ops.ones(array_ops.reshape(
                        array_ops.rank(inp) - 1, [1]),
                                   dtype=dtypes.int32)
                ], 0)
                # Subtract k * x, see comment above.
                cluster_center_updates -= math_ops.cast(
                    array_ops.reshape(count_updates, broadcast_shape),
                    inp.dtype) * old_cluster_centers
                learning_rate = math_ops.reciprocal(
                    math_ops.cast(old_counts + count_updates, inp.dtype))
                learning_rate = array_ops.reshape(learning_rate,
                                                  broadcast_shape)
                # scale by 1 / (n + k), see comment above.
                cluster_center_updates *= learning_rate
                # Apply the updates.
            update_counts = state_ops.scatter_add(total_counts, unique_ids,
                                                  count_updates)
            update_cluster_centers = state_ops.scatter_add(
                cluster_centers, unique_ids, cluster_center_updates)
            update_ops.extend([update_counts, update_cluster_centers])
        return control_flow_ops.group(*update_ops)