def _get_bn_correction(conv_type, kernel, bias, mu_bt, var_bt, mu_mv, var_mv,
                       gamma, epsilon):
  """Get batchnorm correction params.

  Before freeze:
    corr_scale = sigma_bt / sigma_mv
    corr_recip = 1 / corr_scale
    corr_offset = 0

  After freeze:
    corr_scale = sigma_bt / sigma_mv
    corr_recip = 1
    corr_offset = gamma * ((mu_bt - bias) / sigma_bt - (mu_mv - bias) / sigma_mv)
  """
  sigma_bt = math_ops.rsqrt(var_bt + epsilon)
  sigma_bt_recip = math_ops.reciprocal(sigma_bt)
  sigma_mv = math_ops.rsqrt(var_mv + epsilon)
  sigma_mv_recip = math_ops.reciprocal(sigma_mv)

  corr_scale = math_ops.divide(sigma_bt, sigma_mv, name='corr_scale')
  corr_recip = math_ops.reciprocal(corr_scale)
  corr_offset = array_ops.zeros(mu_bt.shape)

  if conv_type == 'DepthwiseConv2D':
    new_shape = [kernel.shape[2], kernel.shape[3]]
    corr_scale = array_ops.reshape(corr_scale, new_shape)

  return corr_scale, corr_recip, corr_offset
def _grad_and_hess_for_logloss(logits, labels):
  # TODO(youngheek): add weights handling.
  predictions = math_ops.reciprocal(math_ops.exp(-logits) + 1.0)
  normalizer = math_ops.reciprocal(
      math_ops.cast(array_ops.size(predictions), dtypes.float32))
  gradients = (predictions - labels) * normalizer
  hessians = predictions * (1.0 - predictions) * normalizer
  return gradients, hessians
def _grad_and_hess_for_logloss(logits, labels): """A closed form gradient and hessian for logistic loss.""" # TODO(youngheek): add weights handling. predictions = math_ops.reciprocal(math_ops.exp(-logits) + 1.0) normalizer = math_ops.reciprocal( math_ops.cast(array_ops.size(predictions), dtypes.float32)) labels = math_ops.cast(labels, dtypes.float32) labels = head_lib._check_dense_labels_match_logits_and_reshape( # pylint: disable=protected-access labels, logits, head.logits_dimension) gradients = (predictions - labels) * normalizer hessians = predictions * (1.0 - predictions) * normalizer return gradients, hessians
def _grad(op, grad):
  """A gradient function for IRFFT with the provided `rank` and `rfft_fn`."""
  # Generate a simple mask like [1.0, 2.0, ..., 2.0, 1.0] for even-length FFTs
  # and [1.0, 2.0, ..., 2.0] for odd-length FFTs. To reduce extra ops in the
  # graph we special-case the situation where the FFT length and last
  # dimension of the input are known at graph construction time.
  fft_length = op.inputs[1]
  fft_length_static = _tensor_util.constant_value(fft_length)
  if fft_length_static is not None:
    fft_length = fft_length_static
  real_dtype = grad.dtype
  if real_dtype == _dtypes.float32:
    complex_dtype = _dtypes.complex64
  elif real_dtype == _dtypes.float64:
    complex_dtype = _dtypes.complex128
  is_odd = _math_ops.mod(fft_length[-1], 2)
  input_last_dimension = _array_ops.shape(op.inputs[0])[-1]
  mask = _array_ops.concat(
      [[1.0], 2.0 * _array_ops.ones([input_last_dimension - 2 + is_odd],
                                    real_dtype),
       _array_ops.ones([1 - is_odd], real_dtype)], 0)

  rsize = _math_ops.reciprocal(
      _math_ops.cast(_fft_size_for_grad(grad, rank), real_dtype))

  # The gradient of IRFFT is the RFFT of the incoming gradient times a scaling
  # factor and a mask. The mask scales the gradient for the Hermitian
  # symmetric components of the RFFT by a factor of two, since these
  # components are de-duplicated in the RFFT.
  the_rfft = rfft_fn(grad, fft_length)
  return the_rfft * _math_ops.cast(rsize * mask, complex_dtype), None
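# Minimal NumPy sketch of the Hermitian mask built above; `irfft_grad_mask` is
# an illustrative helper, not part of the library code.
import numpy as np

def irfft_grad_mask(fft_length, input_last_dim):
  is_odd = fft_length % 2
  return np.concatenate([[1.0],
                         2.0 * np.ones(input_last_dim - 2 + is_odd),
                         np.ones(1 - is_odd)])

print(irfft_grad_mask(8, 5))  # [1. 2. 2. 2. 1.]  even fft_length
print(irfft_grad_mask(7, 4))  # [1. 2. 2. 2.]     odd fft_length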
def normalize_moments(counts, mean_ss, variance_ss, shift, name=None):
  """Calculate the mean and variance based on the sufficient statistics.

  Args:
    counts: A `Tensor` containing the total count of the data (one value).
    mean_ss: A `Tensor` containing the mean sufficient statistics: the
      (possibly shifted) sum of the elements to average over.
    variance_ss: A `Tensor` containing the variance sufficient statistics: the
      (possibly shifted) squared sum of the data to compute the variance over.
    shift: A `Tensor` containing the value by which the data is shifted for
      numerical stability, or `None` if no shift was performed.
    name: Name used to scope the operations that compute the moments.

  Returns:
    Two `Tensor` objects: `mean` and `variance`.
  """
  with ops.name_scope(name, "normalize",
                      [counts, mean_ss, variance_ss, shift]):
    divisor = math_ops.reciprocal(counts, name="divisor")
    if shift is not None:
      shifted_mean = math_ops.multiply(mean_ss, divisor, name="shifted_mean")
      mean = math_ops.add(shifted_mean, shift, name="mean")
    else:  # no shift.
      shifted_mean = math_ops.multiply(mean_ss, divisor, name="mean")
      mean = shifted_mean
    variance = math_ops.subtract(
        math_ops.multiply(variance_ss, divisor),
        math_ops.square(shifted_mean),
        name="variance")
  return (mean, variance)
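# NumPy sketch of the same computation on concrete data, assuming the shifted
# sufficient statistics were accumulated in the usual way (sum and squared sum
# of the shifted data).
import numpy as np

rng = np.random.default_rng(1)
x = rng.normal(loc=5.0, scale=2.0, size=1000)

shift = x[0]                              # any value near the data works
counts = x.size
mean_ss = np.sum(x - shift)               # shifted sum
variance_ss = np.sum((x - shift) ** 2)    # shifted squared sum

divisor = 1.0 / counts
shifted_mean = mean_ss * divisor
mean = shifted_mean + shift
variance = variance_ss * divisor - shifted_mean ** 2

print(np.allclose(mean, x.mean()), np.allclose(variance, x.var()))  # True True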
def normalize_moments(counts, mean_ss, variance_ss, shift, name=None):
  """Calculate the mean and variance based on the sufficient statistics.

  Args:
    counts: A `Tensor` containing the total count of the data (one value).
    mean_ss: A `Tensor` containing the mean sufficient statistics: the
      (possibly shifted) sum of the elements to average over.
    variance_ss: A `Tensor` containing the variance sufficient statistics: the
      (possibly shifted) squared sum of the data to compute the variance over.
    shift: A `Tensor` containing the value by which the data is shifted for
      numerical stability, or `None` if no shift was performed.
    name: Name used to scope the operations that compute the moments.

  Returns:
    Two `Tensor` objects: `mean` and `variance`.
  """
  with ops.name_scope(name, "normalize",
                      [counts, mean_ss, variance_ss, shift]):
    divisor = math_ops.reciprocal(counts, name="divisor")
    if shift is not None:
      shifted_mean = math_ops.mul(mean_ss, divisor, name="shifted_mean")
      mean = math_ops.add(shifted_mean, shift, name="mean")
    else:  # no shift.
      shifted_mean = math_ops.mul(mean_ss, divisor, name="mean")
      mean = shifted_mean
    variance = math_ops.sub(
        math_ops.mul(variance_ss, divisor),
        math_ops.square(shifted_mean),
        name="variance")
  return (mean, variance)
def _TanGrad(op, grad):
  """Returns grad * 1/cos^2(x)."""
  x = op.inputs[0]
  with ops.control_dependencies([grad]):
    x = math_ops.conj(x)
    secx = math_ops.reciprocal(math_ops.cos(x))
    secx2 = math_ops.square(secx)
    return grad * secx2
def _AtanGrad(op, grad):
  """Returns grad * 1 / (1 + x^2)."""
  x = op.inputs[0]
  with ops.control_dependencies([grad]):
    x = math_ops.conj(x)
    x2 = math_ops.square(x)
    one = constant_op.constant(1, dtype=grad.dtype)
    inv = math_ops.reciprocal(math_ops.add(one, x2))
    return grad * inv
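# Sanity check of the closed form above with NumPy: d/dx arctan(x) = 1/(1+x^2).
import numpy as np

x = np.linspace(-2.0, 2.0, 9)
eps = 1e-6
analytic = 1.0 / (1.0 + x ** 2)
numeric = (np.arctan(x + eps) - np.arctan(x - eps)) / (2 * eps)
print(np.allclose(analytic, numeric, atol=1e-8))  # True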
def qbatch_normalization(x, mean, variance, offset, scale, variance_epsilon,
                         quantizer, name=None):
  r"""Batch normalization.

  As described in http://arxiv.org/abs/1502.03167.
  Normalizes a tensor by `mean` and `variance`, and applies (optionally) a
  `scale` \\(\gamma\\) to it, as well as an `offset` \\(\beta\\):

  \\(\frac{\gamma(x-\mu)}{\sigma}+\beta\\)

  `mean`, `variance`, `offset` and `scale` are all expected to be of one of two
  shapes:

    * In all generality, they can have the same number of dimensions as the
      input `x`, with identical sizes as `x` for the dimensions that are not
      normalized over (the 'depth' dimension(s)), and dimension 1 for the
      others which are being normalized over.
      `mean` and `variance` in this case would typically be the outputs of
      `tf.nn.moments(..., keep_dims=True)` during training, or running averages
      thereof during inference.
    * In the common case where the 'depth' dimension is the last dimension in
      the input tensor `x`, they may be one dimensional tensors of the same
      size as the 'depth' dimension.
      This is the case for example for the common `[batch, depth]` layout of
      fully-connected layers, and `[batch, height, width, depth]` for
      convolutions.
      `mean` and `variance` in this case would typically be the outputs of
      `tf.nn.moments(..., keep_dims=False)` during training, or running
      averages thereof during inference.

  Args:
    x: Input `Tensor` of arbitrary dimensionality.
    mean: A mean `Tensor`.
    variance: A variance `Tensor`.
    offset: An offset `Tensor`, often denoted \\(\beta\\) in equations, or
      None. If present, will be added to the normalized tensor.
    scale: A scale `Tensor`, often denoted \\(\gamma\\) in equations, or
      `None`. If present, the scale is applied to the normalized tensor.
    variance_epsilon: A small float number to avoid dividing by 0.
    quantizer: Quantizer applied to the intermediate results.
    name: A name for this operation (optional).

  Returns:
    The normalized, scaled, offset tensor.
  """
  with ops.name_scope(name, "batchnorm", [x, mean, variance, scale, offset]):
    # Internal calculation of sqrt is NOT quantized!
    inv = quantizer.quantize(math_ops.sqrt(variance + variance_epsilon))
    inv = quantizer.quantize(math_ops.reciprocal(inv))
    if scale is not None:
      inv = quantizer.quantize(inv * scale)
    # return x * inv + (offset - mean * inv
    #                   if offset is not None else -mean * inv)
    if offset is not None:
      rest = quantizer.quantize(offset - quantizer.quantize(mean * inv))
    else:
      rest = quantizer.quantize(-mean * inv)
    result = quantizer.quantize(quantizer.quantize(x * inv) + rest)
    return result
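# NumPy sketch showing that the folded form used above, x*inv + (offset -
# mean*inv), matches the textbook formula gamma*(x - mean)/sqrt(var + eps) +
# beta when the quantizer is taken to be the identity.
import numpy as np

rng = np.random.default_rng(2)
x = rng.normal(size=(4, 3))
mean, var = x.mean(axis=0), x.var(axis=0)
gamma, beta, eps = 1.5, 0.25, 1e-3

inv = gamma / np.sqrt(var + eps)
folded = x * inv + (beta - mean * inv)
reference = gamma * (x - mean) / np.sqrt(var + eps) + beta
print(np.allclose(folded, reference))  # True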
def _AcosGrad(op, grad):
  """Returns grad * -1/sqrt(1-x^2)."""
  x = op.inputs[0]
  with ops.control_dependencies([grad]):
    x = math_ops.conj(x)
    x2 = math_ops.square(x)
    one = constant_op.constant(1, dtype=grad.dtype)
    den = math_ops.sqrt(math_ops.subtract(one, x2))
    inv = math_ops.reciprocal(den)
    return -grad * inv
def _AngleGrad(op, grad):
  """Returns -grad / (Im(x) + iRe(x))"""
  x = op.inputs[0]
  with ops.control_dependencies([grad]):
    re = math_ops.real(x)
    im = math_ops.imag(x)
    z = math_ops.reciprocal(math_ops.complex(im, re))
    zero = constant_op.constant(0, dtype=grad.dtype)
    complex_grad = math_ops.complex(grad, zero)
    return -complex_grad * z
def _get_matches_hook(self, y_pred_click_ranks):
    """
    Return reciprocal click ranks for MRR

    Parameters
    ----------
    y_pred_click_ranks: Tensor object
        Tensor object containing the ranks of the clicked records for each query

    Returns
    -------
    Tensor object
        Reciprocal ranks tensor
    """
    return math_ops.reciprocal(tf.cast(y_pred_click_ranks, tf.float32))
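# Tiny NumPy example of what the reciprocal ranks feed into: MRR is the mean
# of 1/rank over queries (ranks are 1-based).
import numpy as np

click_ranks = np.array([1.0, 3.0, 2.0])
print(np.mean(1.0 / click_ranks))  # 0.611... = (1 + 1/3 + 1/2) / 3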
def _BesselI1eGrad(op, grad):
  """Compute gradient of bessel_i1e(x) with respect to its argument."""
  x = op.inputs[0]
  y = op.outputs[0]
  with ops.control_dependencies([grad]):
    # For x = 0, the correct gradient is 0.5.
    # However, the main branch gives NaN because of the division by x, so
    # we impute the gradient manually.
    # An alternative solution is to express the gradient via bessel_i0e and
    # bessel_i2e, but the latter is not yet implemented in Eigen.
    eps = np.finfo(x.dtype.as_numpy_dtype).eps
    zeros = array_ops.zeros_like(x)
    x_is_not_tiny = math_ops.abs(x) > eps
    safe_x = array_ops.where(x_is_not_tiny, x, eps + zeros)
    dy_dx = math_ops.bessel_i0e(safe_x) - y * (
        math_ops.sign(safe_x) + math_ops.reciprocal(safe_x))
    return grad * array_ops.where(x_is_not_tiny, dy_dx, 0.5 + zeros)
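# Finite-difference check with SciPy of the identity used above,
# d/dx i1e(x) = i0e(x) - i1e(x) * (sign(x) + 1/x).
import numpy as np
from scipy.special import i0e, i1e

x = np.array([-2.0, -0.5, 0.7, 3.0])
eps = 1e-6
analytic = i0e(x) - i1e(x) * (np.sign(x) + 1.0 / x)
numeric = (i1e(x + eps) - i1e(x - eps)) / (2 * eps)
print(np.allclose(analytic, numeric, atol=1e-6))  # True

# At x == 0 the expression is 0/0; the imputed value above is 0.5, which
# matches the symmetric difference quotient:
print((i1e(eps) - i1e(-eps)) / (2 * eps))  # ~0.5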
def _BatchNormGrad(grad_y, x, scale, epsilon, data_format):
  """Returns the gradients for the 3 inputs of BatchNorm.

  Args:
    grad_y: A `Tensor` of 4 dimensions for gradient for y.
    x: A `Tensor` of 4 dimensions for x.
    scale: A `Tensor` of 1 dimension for scaling.
    epsilon: A small float number added to the variance of x.
    data_format: The data format for input. Either b"NHWC" or b"NCHW".

  Returns:
    A tuple (grad_x, grad_scale, grad_offset), where grad_x is the gradient
    for x, grad_scale the gradient for scale, and grad_offset the gradient
    for offset.
  """
  if data_format == b"NHWC":
    keep_dims = False
    reduce_axis = [0, 1, 2]
  else:
    keep_dims = True
    reduce_axis = [0, 2, 3]
    shape = [1, array_ops.size(scale), 1, 1]
    scale = array_ops.reshape(scale, shape)
  mean_grad_y = math_ops.reduce_mean(grad_y, reduce_axis, keep_dims=keep_dims)
  mean_x = math_ops.reduce_mean(x, reduce_axis, keep_dims=keep_dims)
  var_x = math_ops.reduce_mean(
      math_ops.squared_difference(x, array_ops.stop_gradient(mean_x)),
      reduce_axis,
      keep_dims=keep_dims)
  grad_y_offset = grad_y - mean_grad_y
  x_offset = x - mean_x
  mean = math_ops.reduce_mean(
      grad_y * x_offset, axis=reduce_axis, keep_dims=keep_dims)
  grad_x = scale * math_ops.rsqrt(var_x + epsilon) * (
      grad_y_offset - math_ops.reciprocal(var_x + epsilon) * mean * x_offset)
  grad_scale = math_ops.rsqrt(var_x + epsilon) * math_ops.reduce_sum(
      grad_y * x_offset, axis=reduce_axis, keep_dims=keep_dims)
  if data_format == b"NCHW":
    grad_scale = array_ops.squeeze(grad_scale)
  grad_offset = math_ops.reduce_sum(grad_y, axis=reduce_axis)
  return grad_x, grad_scale, grad_offset
def _Grad(op, grad):
  """A gradient function for IRFFT with the provided `rank` and `rfft_fn`."""
  # Generate a simple mask like [1.0, 2.0, ..., 2.0, 1.0] for even-length FFTs
  # and [1.0, 2.0, ..., 2.0] for odd-length FFTs. To reduce extra ops in the
  # graph we special-case the situation where the FFT length and last
  # dimension of the input are known at graph construction time.
  fft_length = op.inputs[1]
  is_odd = math_ops.mod(fft_length[-1], 2)
  input_last_dimension = array_ops.shape(op.inputs[0])[-1]
  mask = array_ops.concat(
      [[1.0], 2.0 * array_ops.ones([input_last_dimension - 2 + is_odd]),
       array_ops.ones([1 - is_odd])], 0)

  rsize = math_ops.reciprocal(math_ops.to_float(_FFTSizeForGrad(grad, rank)))

  # The gradient of IRFFT is the RFFT of the incoming gradient times a scaling
  # factor and a mask. The mask scales the gradient for the Hermitian
  # symmetric components of the RFFT by a factor of two, since these
  # components are de-duplicated in the RFFT.
  rfft = rfft_fn(grad, fft_length)
  return rfft * math_ops.cast(rsize * mask, dtypes.complex64), None
def _SelfAdjointEigV2Grad(op, grad_e, grad_v):
  """Gradient for SelfAdjointEigV2."""
  e = op.outputs[0]
  compute_v = op.get_attr("compute_v")
  # a = op.inputs[0], which satisfies
  # a[...,:,:] * v[...,:,i] = e[...,i] * v[...,i]
  with ops.control_dependencies([grad_e, grad_v]):
    if compute_v:
      v = op.outputs[1]
      # Construct the matrix f(i,j) = (i != j ? 1 / (e_i - e_j) : 0).
      # Notice that because of the term involving f, the gradient becomes
      # infinite (or NaN in practice) when eigenvalues are not unique.
      # Mathematically this should not be surprising, since for (k-fold)
      # degenerate eigenvalues, the corresponding eigenvectors are only defined
      # up to arbitrary rotation in a (k-dimensional) subspace.
      f = array_ops.matrix_set_diag(
          math_ops.reciprocal(
              array_ops.expand_dims(e, -2) - array_ops.expand_dims(e, -1)),
          array_ops.zeros_like(e))
      grad_a = math_ops.matmul(
          v,
          math_ops.matmul(
              array_ops.matrix_diag(grad_e) +
              f * math_ops.matmul(v, grad_v, adjoint_a=True),
              v,
              adjoint_b=True))
    else:
      _, v = linalg_ops.self_adjoint_eig(op.inputs[0])
      grad_a = math_ops.matmul(
          v, math_ops.matmul(array_ops.matrix_diag(grad_e), v, adjoint_b=True))
    # The forward op only depends on the lower triangular part of a, so here we
    # symmetrize and take the lower triangle
    grad_a = array_ops.matrix_band_part(
        grad_a + math_ops.conj(array_ops.matrix_transpose(grad_a)), -1, 0)
    grad_a = array_ops.matrix_set_diag(
        grad_a, 0.5 * array_ops.matrix_diag_part(grad_a))
    return grad_a
def _SelfAdjointEigV2Grad(op, grad_e, grad_v):
  """Gradient for SelfAdjointEigV2."""
  e = op.outputs[0]
  compute_v = op.get_attr("compute_v")
  # a = op.inputs[0], which satisfies
  # a[...,:,:] * v[...,:,i] = e[...,i] * v[...,i]
  with ops.control_dependencies([grad_e, grad_v]):
    if compute_v:
      v = op.outputs[1]
      # Construct the matrix f(i,j) = (i != j ? 1 / (e_i - e_j) : 0).
      # Notice that because of the term involving f, the gradient becomes
      # infinite (or NaN in practice) when eigenvalues are not unique.
      # Mathematically this should not be surprising, since for (k-fold)
      # degenerate eigenvalues, the corresponding eigenvectors are only defined
      # up to arbitrary rotation in a (k-dimensional) subspace.
      f = array_ops.matrix_set_diag(
          math_ops.reciprocal(
              array_ops.expand_dims(e, -2) - array_ops.expand_dims(e, -1)),
          array_ops.zeros_like(e))
      grad_a = math_ops.matmul(
          v,
          math_ops.matmul(
              array_ops.matrix_diag(grad_e) +
              f * math_ops.matmul(v, grad_v, adjoint_a=True),
              v,
              adjoint_b=True))
    else:
      _, v = linalg_ops.self_adjoint_eig(op.inputs[0])
      grad_a = math_ops.matmul(
          v, math_ops.matmul(array_ops.matrix_diag(grad_e), v, adjoint_b=True))
    # The forward op only depends on the lower triangular part of a, so here we
    # symmetrize and take the lower triangle
    grad_a = array_ops.matrix_band_part(grad_a + _linalg.adjoint(grad_a), -1, 0)
    grad_a = array_ops.matrix_set_diag(
        grad_a, 0.5 * array_ops.matrix_diag_part(grad_a))
    return grad_a
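# NumPy sketch of the f-matrix construction shared by the eigendecomposition
# and SVD gradients above: pairwise reciprocal differences with the
# (undefined) diagonal zeroed out. With nearly equal eigenvalues the
# off-diagonal entries blow up, which is the instability the comments describe.
import numpy as np

e = np.array([1.0, 2.0, 4.0])
diff = e[None, :] - e[:, None]   # expand_dims(e, -2) - expand_dims(e, -1)
np.fill_diagonal(diff, 1.0)      # placeholder to avoid 0/0 on the diagonal
f = 1.0 / diff
np.fill_diagonal(f, 0.0)         # matrix_set_diag(..., zeros_like(e))
print(f)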
def QuantizedBatchNormalizationCore(inputs,
                                    mean,
                                    variance,
                                    beta,
                                    gamma,
                                    variance_epsilon,
                                    Q_info,
                                    name=None):
    """
    Intrinsic quantization of BatchNormalization layer.

    Parameters
    ----------
    | inputs : Tensor.
    | mean : tf.Variable
    | variance : tf.Variable
    | beta : tf.Variable
    | gamma : tf.Variable
    | variance_epsilon : Float
    | Q_info : quantizer used for the intrinsic quantization (assumed to
      provide a `quantize` method).

    Returns
    -------
    output : Tensor
    """
    with ops.name_scope(name, "batchnorm",
                        [inputs, mean, variance, gamma, beta]):
        coef = Q_info.quantize(math_ops.sqrt(variance + variance_epsilon))
        coef = Q_info.quantize(math_ops.reciprocal(coef))

        if gamma is not None:
            coef = Q_info.quantize(coef * gamma)

        if beta is not None:
            const = Q_info.quantize(beta - Q_info.quantize(mean * coef))
        else:
            const = Q_info.quantize(-mean * coef)
        output = Q_info.quantize(Q_info.quantize(inputs * coef) + const)
    return output
def _SvdGrad(op, grad_s, grad_u, grad_v):
  """Gradient for the singular value decomposition."""

  # The derivation for the compute_uv=False case, and most of
  # the derivation for the full_matrices=True case, are in
  # Giles' paper (see reference at top of file). A derivation for
  # the full_matrices=False case is available at
  # https://j-towns.github.io/papers/svd-derivative.pdf
  a = op.inputs[0]
  a_shape = a.get_shape().with_rank_at_least(2)
  grad_s_mat = array_ops.matrix_diag(grad_s)

  if not op.get_attr("compute_uv"):
    s, u, v = linalg_ops.svd(a, compute_uv=True)
    grad_a = math_ops.matmul(u, math_ops.matmul(grad_s_mat, v, adjoint_b=True))
    grad_a.set_shape(a_shape)
    return grad_a

  full_matrices = op.get_attr("full_matrices")

  # TODO(rmlarsen): Make this work with complex types.
  if a.dtype.is_complex:
    raise NotImplementedError(
        "SVD gradient is not implemented for complex types and "
        "compute_uv=True.")
  grad_u_shape = grad_u.get_shape().with_rank_at_least(2)
  grad_v_shape = grad_v.get_shape().with_rank_at_least(2)
  m = a_shape[-2].merge_with(grad_u_shape[-2])
  n = a_shape[-1].merge_with(grad_v_shape[-2])
  batch_shape = a_shape[:-2].merge_with(grad_u_shape[:-2]).merge_with(
      grad_v_shape[:-2])
  a_shape = batch_shape.concatenate([m, n])

  m = a_shape[-2].value
  n = a_shape[-1].value
  # TODO(rmlarsen): Make this work with placeholders.
  if m is None or n is None:
    raise NotImplementedError(
        "SVD gradient has not been implemented for input with unknown "
        "inner matrix shape.")

  s = op.outputs[0]
  u = op.outputs[1]
  v = op.outputs[2]

  use_adjoint = False
  if m > n:
    # Compute the gradient for A^H = V * S^T * U^H, and (implicitly) take the
    # Hermitian transpose of the gradient at the end.
    use_adjoint = True
    m, n = n, m
    u, v = v, u
    grad_u, grad_v = grad_v, grad_u

  with ops.control_dependencies([grad_s, grad_u, grad_v]):
    if full_matrices and abs(m - n) > 1:
      raise NotImplementedError(
          "svd gradient is not implemented for abs(m - n) > 1 "
          "when full_matrices is True")
    s_mat = array_ops.matrix_diag(s)
    s2 = math_ops.square(s)

    # NOTICE: Because of the term involving f, the gradient becomes
    # infinite (or NaN in practice) when singular values are not unique.
    # Mathematically this should not be surprising, since for (k-fold)
    # degenerate singular values, the corresponding singular vectors are
    # only defined up to a (k-dimensional) subspace. In practice, this can
    # lead to numerical instability when singular values are close but not
    # exactly equal.
    f = array_ops.matrix_set_diag(
        math_ops.reciprocal(
            array_ops.expand_dims(s2, -2) - array_ops.expand_dims(s2, -1)),
        array_ops.zeros_like(s))
    s_inv_mat = array_ops.matrix_diag(math_ops.reciprocal(s))

    v1 = v[..., :, :m]
    grad_v1 = grad_v[..., :, :m]

    u_gu = math_ops.matmul(u, grad_u, adjoint_a=True)
    v_gv = math_ops.matmul(v1, grad_v1, adjoint_a=True)

    f_u = f * u_gu
    f_v = f * v_gv

    term1_nouv = (
        grad_s_mat + math_ops.matmul(f_u + _linalg.adjoint(f_u), s_mat) +
        math_ops.matmul(s_mat, f_v + _linalg.adjoint(f_v)))

    term1 = math_ops.matmul(u, math_ops.matmul(term1_nouv, v1, adjoint_b=True))

    if m == n:
      grad_a_before_transpose = term1
    else:
      gv1t = array_ops.matrix_transpose(grad_v1)
      gv1t_v1 = math_ops.matmul(gv1t, v1)
      term2_nous = gv1t - math_ops.matmul(gv1t_v1, v1, adjoint_b=True)

      if full_matrices:
        v2 = v[..., :, m:n]
        grad_v2 = grad_v[..., :, m:n]

        v1t_gv2 = math_ops.matmul(v1, grad_v2, adjoint_a=True)
        term2_nous -= math_ops.matmul(v1t_gv2, v2, adjoint_b=True)

      u_s_inv = math_ops.matmul(u, s_inv_mat)
      term2 = math_ops.matmul(u_s_inv, term2_nous)

      grad_a_before_transpose = term1 + term2

    if use_adjoint:
      grad_a = array_ops.matrix_transpose(grad_a_before_transpose)
    else:
      grad_a = grad_a_before_transpose

    grad_a.set_shape(a_shape)
    return grad_a
def _BatchNormGrad(grad_y, x, scale, pop_mean, pop_var, epsilon, data_format,
                   is_training=True):
  """Returns the gradients for the 3 inputs of BatchNorm.

  Args:
    grad_y: A `Tensor` of 4 dimensions for gradient for y.
    x: A `Tensor` of 4 dimensions for x.
    scale: A `Tensor` of 1 dimension for scaling.
    pop_mean: A `Tensor` of 1 dimension for the population mean. Only used when
      is_training=False.
    pop_var: A `Tensor` of 1 dimension for the population variance. Only used
      when is_training=False.
    epsilon: A small float number added to the variance of x.
    data_format: The data format for input. Either b"NHWC" or b"NCHW".
    is_training: A bool value to indicate the operation is for training
      (default) or inference.

  Returns:
    A tuple (grad_x, grad_scale, grad_offset), where grad_x is the gradient
    for x, grad_scale the gradient for scale, and grad_offset the gradient
    for offset.
  """
  x_dtype = x.dtype.base_dtype
  if x_dtype == dtypes.float16:
    # float16 math is too imprecise, so we do the batch norm gradient
    # computations in float32.
    x = math_ops.cast(x, dtypes.float32)
    grad_y = math_ops.cast(grad_y, dtypes.float32)
  if is_training:
    if data_format == b"NHWC":
      keepdims = False
      reduce_axis = [0, 1, 2]
    else:
      keepdims = True
      reduce_axis = [0, 2, 3]
      shape = [1, array_ops.size(scale), 1, 1]
      scale = array_ops.reshape(scale, shape)
    mean_grad_y = math_ops.reduce_mean(grad_y, reduce_axis, keepdims=keepdims)
    mean_x = math_ops.reduce_mean(x, reduce_axis, keepdims=keepdims)
    var_x = math_ops.reduce_mean(
        math_ops.squared_difference(x, array_ops.stop_gradient(mean_x)),
        reduce_axis,
        keepdims=keepdims)
    grad_y_offset = grad_y - mean_grad_y
    x_offset = x - mean_x
    mean = math_ops.reduce_mean(
        grad_y * x_offset, axis=reduce_axis, keepdims=keepdims)
    grad_x = scale * math_ops.rsqrt(var_x + epsilon) * (
        grad_y_offset - math_ops.reciprocal(var_x + epsilon) * mean * x_offset)
    grad_scale = math_ops.rsqrt(var_x + epsilon) * math_ops.reduce_sum(
        grad_y * x_offset, axis=reduce_axis, keepdims=keepdims)
    if data_format == b"NCHW":
      grad_scale = array_ops.squeeze(grad_scale)
    grad_offset = math_ops.reduce_sum(grad_y, axis=reduce_axis)
    return math_ops.cast(grad_x, x_dtype), grad_scale, grad_offset
  else:
    if data_format == b"NHWC":
      reduce_axis = [0, 1, 2]
    else:
      reduce_axis = [0, 2, 3]
      shape = [1, array_ops.size(pop_mean), 1, 1]
      pop_mean = array_ops.reshape(pop_mean, shape)
      pop_var = array_ops.reshape(pop_var, shape)
      scale = array_ops.reshape(scale, shape)

    grad_offset = math_ops.reduce_sum(grad_y, axis=reduce_axis)
    var_rsqrt = math_ops.rsqrt(pop_var + epsilon)
    grad_scale = math_ops.reduce_sum(
        grad_y * (x - pop_mean) * var_rsqrt, axis=reduce_axis)
    grad_x = grad_y * scale * var_rsqrt
    return math_ops.cast(grad_x, x_dtype), grad_scale, grad_offset
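# NumPy sketch checking the is_training=False branch above: with fixed
# population statistics, grad_x reduces to grad_y * scale * rsqrt(pop_var+eps).
# `bn_inference` is an illustrative helper, not library code.
import numpy as np

rng = np.random.default_rng(3)
x = rng.normal(size=(2, 5))
grad_y = rng.normal(size=x.shape)
pop_mean = rng.normal(size=5)
pop_var = rng.uniform(0.5, 2.0, size=5)
scale, offset, eps = rng.normal(size=5), rng.normal(size=5), 1e-3

def bn_inference(x):
  return scale * (x - pop_mean) / np.sqrt(pop_var + eps) + offset

grad_x = grad_y * scale / np.sqrt(pop_var + eps)   # closed form above

h = 1e-6
dx = np.zeros_like(x)
dx[0, 2] = h
numeric = (np.sum(grad_y * bn_inference(x + dx)) -
           np.sum(grad_y * bn_inference(x - dx))) / (2 * h)
print(np.isclose(grad_x[0, 2], numeric))  # True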
def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay,
                                 fused_batch_norm):
  """Computes batch norm correction params.

  Before batch normalization is frozen:
    We use batch statistics for batch norm.
      correction_scale = sigma_b/sigma_mv
      correction_recip = 1/correction_scale
      correction_offset = 0

  After batch normalization is frozen:
      correction_scale = sigma_b/sigma_mv
      correction_recip = 1
      correction_offset = gamma * (mu_b/sigma_b - mu_mv/sigma_mv).

  Batch norm is frozen if global_step > bn_freeze_delay.
  The corrections ensure that:
  a) The weights are quantized after scaling by gamma/sigma_mv. This enables
     smoother training as the scaling on the weights changes slowly, rather
     than jumping across mini-batches.
  b) Changing the values of the corrections allows one to switch between using
     batch statistics and using the moving mean and variance, without requiring
     changes to batch_norm.

  Args:
    context: The scope under which we look for batch norm params.
    match: Object containing required batch norm tensors for correction
      computation.
    freeze_batch_norm_delay: Delay in steps at which computation switches
      from regular batch norm to frozen mean and variance.
    fused_batch_norm: Bool, true if fused batch norm is used.

  Returns:
    A tuple of correction_scale, correction_recip, correction_offset.
  """
  g = ops.get_default_graph()
  prefix = '' if not context else context + '/'
  with g.name_scope(prefix + 'batch_norm_correction'):
    recip_sigma_mv = math_ops.rsqrt(
        match.moving_variance_tensor + match.batch_epsilon)
    recip_sigma = math_ops.rsqrt(match.variance_tensor + match.batch_epsilon)
    correction_scale = math_ops.divide(
        recip_sigma_mv, recip_sigma, name='scale_compute')
    correction_scale = array_ops.identity(
        correction_scale, name='correction_scale')
    correction_recip = math_ops.reciprocal(
        correction_scale, name='reciprocal_compute')
    correction_offset = math_ops.multiply(
        match.gamma_tensor,
        match.mean_tensor * recip_sigma -
        match.moving_mean_tensor * recip_sigma_mv,
        name='offset_compute')

    if freeze_batch_norm_delay is not None:
      use_mv_avg = math_ops.greater_equal(
          common.CreateOrGetQuantizationStep(),
          freeze_batch_norm_delay,
          name='use_moving_average')
    else:
      use_mv_avg = False

    bn_decay_zero = 0.0
    bn_decay_mean_consumers = list(match.bn_decay_mean_tensor.consumers())
    bn_decay_var_consumers = list(match.bn_decay_mean_tensor.consumers())

    bn_decay_mean_out = utils.smart_cond(
        use_mv_avg,
        lambda: bn_decay_zero,
        lambda: match.bn_decay_mean_tensor,
        name='freeze_moving_mean')
    graph_editor.reroute_ts(
        [bn_decay_mean_out], [match.bn_decay_mean_tensor],
        can_modify=bn_decay_mean_consumers)

    if fused_batch_norm is False:
      bn_decay_var_consumers = list(match.bn_decay_var_tensor.consumers())
      bn_decay_var_out = utils.smart_cond(
          use_mv_avg,
          lambda: bn_decay_zero,
          lambda: match.bn_decay_var_tensor,
          name='freeze_moving_var')
      graph_editor.reroute_ts(
          [bn_decay_var_out], [match.bn_decay_var_tensor],
          can_modify=bn_decay_var_consumers)

    correction_recip = utils.smart_cond(
        use_mv_avg,
        lambda: array_ops.ones(correction_scale.shape),
        lambda: correction_recip,
        name='correction_recip')
    correction_offset = utils.smart_cond(
        use_mv_avg,
        lambda: correction_offset,
        lambda: array_ops.zeros(correction_offset.shape),
        name='correction_offset')
  return correction_scale, correction_recip, correction_offset
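# Scalar NumPy sketch of the correction_scale identity used above:
# (1/sigma_mv) / (1/sigma_b) == sigma_b / sigma_mv.
import numpy as np

var_b, var_mv, eps = 0.9, 1.2, 1e-3
recip_sigma_mv = 1.0 / np.sqrt(var_mv + eps)
recip_sigma_b = 1.0 / np.sqrt(var_b + eps)
correction_scale = recip_sigma_mv / recip_sigma_b
print(np.isclose(correction_scale,
                 np.sqrt(var_b + eps) / np.sqrt(var_mv + eps)))  # True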
def _SvdGrad(op, grad_s, grad_u, grad_v):
  """Gradient for the singular value decomposition."""

  # The derivation for the compute_uv=False case, and most of
  # the derivation for the full_matrices=True case, are in
  # Giles' paper (see reference at top of file). A derivation for
  # the full_matrices=False case is available at
  # https://j-towns.github.io/papers/svd-derivative.pdf
  a = op.inputs[0]
  a_shape = a.get_shape().with_rank_at_least(2)
  grad_s_mat = array_ops.matrix_diag(grad_s)

  if not op.get_attr("compute_uv"):
    s, u, v = linalg_ops.svd(a, compute_uv=True)
    grad_a = math_ops.matmul(u, math_ops.matmul(grad_s_mat, v, adjoint_b=True))
    grad_a.set_shape(a_shape)
    return grad_a

  full_matrices = op.get_attr("full_matrices")

  # TODO(rmlarsen): Make this work with complex types.
  if a.dtype.is_complex:
    raise NotImplementedError(
        "SVD gradient is not implemented for complex types and "
        "compute_uv=True.")
  grad_u_shape = grad_u.get_shape().with_rank_at_least(2)
  grad_v_shape = grad_v.get_shape().with_rank_at_least(2)
  m = a_shape.dims[-2].merge_with(grad_u_shape[-2])
  n = a_shape.dims[-1].merge_with(grad_v_shape[-2])
  batch_shape = a_shape[:-2].merge_with(grad_u_shape[:-2]).merge_with(
      grad_v_shape[:-2])
  a_shape = batch_shape.concatenate([m, n])

  m = a_shape.dims[-2].value
  n = a_shape.dims[-1].value
  # TODO(rmlarsen): Make this work with placeholders.
  if m is None or n is None:
    raise NotImplementedError(
        "SVD gradient has not been implemented for input with unknown "
        "inner matrix shape.")

  s = op.outputs[0]
  u = op.outputs[1]
  v = op.outputs[2]

  use_adjoint = False
  if m > n:
    # Compute the gradient for A^H = V * S^T * U^H, and (implicitly) take the
    # Hermitian transpose of the gradient at the end.
    use_adjoint = True
    m, n = n, m
    u, v = v, u
    grad_u, grad_v = grad_v, grad_u

  with ops.control_dependencies([grad_s, grad_u, grad_v]):
    if full_matrices and abs(m - n) > 1:
      raise NotImplementedError(
          "svd gradient is not implemented for abs(m - n) > 1 "
          "when full_matrices is True")
    s_mat = array_ops.matrix_diag(s)
    s2 = math_ops.square(s)

    # NOTICE: Because of the term involving f, the gradient becomes
    # infinite (or NaN in practice) when singular values are not unique.
    # Mathematically this should not be surprising, since for (k-fold)
    # degenerate singular values, the corresponding singular vectors are
    # only defined up to a (k-dimensional) subspace. In practice, this can
    # lead to numerical instability when singular values are close but not
    # exactly equal.
    f = array_ops.matrix_set_diag(
        math_ops.reciprocal(
            array_ops.expand_dims(s2, -2) - array_ops.expand_dims(s2, -1)),
        array_ops.zeros_like(s))
    s_inv_mat = array_ops.matrix_diag(math_ops.reciprocal(s))

    v1 = v[..., :, :m]
    grad_v1 = grad_v[..., :, :m]

    u_gu = math_ops.matmul(u, grad_u, adjoint_a=True)
    v_gv = math_ops.matmul(v1, grad_v1, adjoint_a=True)

    f_u = f * u_gu
    f_v = f * v_gv

    term1_nouv = (
        grad_s_mat + math_ops.matmul(f_u + _linalg.adjoint(f_u), s_mat) +
        math_ops.matmul(s_mat, f_v + _linalg.adjoint(f_v)))

    term1 = math_ops.matmul(u, math_ops.matmul(term1_nouv, v1, adjoint_b=True))

    if m == n:
      grad_a_before_transpose = term1
    else:
      gv1t = array_ops.matrix_transpose(grad_v1)
      gv1t_v1 = math_ops.matmul(gv1t, v1)
      term2_nous = gv1t - math_ops.matmul(gv1t_v1, v1, adjoint_b=True)

      if full_matrices:
        v2 = v[..., :, m:n]
        grad_v2 = grad_v[..., :, m:n]

        v1t_gv2 = math_ops.matmul(v1, grad_v2, adjoint_a=True)
        term2_nous -= math_ops.matmul(v1t_gv2, v2, adjoint_b=True)

      u_s_inv = math_ops.matmul(u, s_inv_mat)
      term2 = math_ops.matmul(u_s_inv, term2_nous)

      grad_a_before_transpose = term1 + term2

    if use_adjoint:
      grad_a = array_ops.matrix_transpose(grad_a_before_transpose)
    else:
      grad_a = grad_a_before_transpose

    grad_a.set_shape(a_shape)
    return grad_a
def _get_matches_hook(self, y_pred_click_ranks):
    """Return reciprocal click ranks for MRR"""
    return math_ops.reciprocal(tf.cast(y_pred_click_ranks, tf.float32))
def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay):
  """Computes batch norm correction params.

  Before batch normalization is frozen:
    We use batch statistics for batch norm.
      correction_scale = sigma_b/sigma_mv
      correction_recip = 1/correction_scale
      correction_offset = 0

  After batch normalization is frozen:
      correction_scale = sigma_b/sigma_mv
      correction_recip = 1
      correction_offset = gamma * (mu_b/sigma_b - mu_mv/sigma_mv).

  Batch norm is frozen if global_step > bn_freeze_delay.
  The corrections ensure that:
  a) The weights are quantized after scaling by gamma/sigma_mv. This enables
     smoother training as the scaling on the weights changes slowly, rather
     than jumping across mini-batches.
  b) Changing the values of the corrections allows one to switch between using
     batch statistics and using the moving mean and variance, without requiring
     changes to batch_norm.

  Args:
    context: The scope under which we look for batch norm params.
    match: Object containing required batch norm tensors for correction
      computation.
    freeze_batch_norm_delay: Delay in steps at which computation switches
      from regular batch norm to frozen mean and variance.

  Returns:
    A tuple of correction_scale, correction_recip, correction_offset.
  """
  g = ops.get_default_graph()
  prefix = '' if not context else context
  with g.name_scope(prefix + 'batch_norm_correction'):
    recip_sigma_mv = math_ops.rsqrt(match.moving_variance_tensor +
                                    match.batch_epsilon)
    recip_sigma = math_ops.rsqrt(match.variance_tensor + match.batch_epsilon)
    correction_scale = math_ops.divide(
        recip_sigma_mv, recip_sigma, name='scale_compute')
    correction_scale = array_ops.identity(
        correction_scale, name='correction_scale')
    correction_recip = math_ops.reciprocal(
        correction_scale, name='reciprocal_compute')
    mv = match.moving_mean_tensor  # if match.moving_mean_tensor is not None else 0
    correction_offset = math_ops.multiply(
        match.gamma_tensor,
        match.mean_tensor * recip_sigma - mv,
        name='offset_compute')

    if freeze_batch_norm_delay is not None:
      use_mv_avg = math_ops.greater_equal(
          common.CreateOrGetQuantizationStep(),
          freeze_batch_norm_delay,
          name='use_moving_average')
    else:
      use_mv_avg = False

    bn_decay_zero = 0.0
    bn_decay_mean_consumers = list(match.bn_decay_mean_tensor.consumers())
    bn_decay_var_consumers = list(match.bn_decay_mean_tensor.consumers())

    bn_decay_mean_out = utils.smart_cond(
        use_mv_avg,
        lambda: bn_decay_zero,
        lambda: match.bn_decay_mean_tensor,
        name='freeze_moving_mean')
    common.RerouteTensor(
        bn_decay_mean_out,
        match.bn_decay_mean_tensor,
        can_modify=bn_decay_mean_consumers)

    bn_decay_var_consumers = list(match.bn_decay_var_tensor.consumers())
    bn_decay_var_out = utils.smart_cond(
        use_mv_avg,
        lambda: bn_decay_zero,
        lambda: match.bn_decay_var_tensor,
        name='freeze_moving_var')
    common.RerouteTensor(
        bn_decay_var_out,
        match.bn_decay_var_tensor,
        can_modify=bn_decay_var_consumers)

    correction_recip = utils.smart_cond(
        use_mv_avg,
        lambda: array_ops.ones(correction_scale.shape),
        lambda: correction_recip,
        name='correction_recip')
    correction_offset = utils.smart_cond(
        use_mv_avg,
        lambda: correction_offset,
        lambda: array_ops.zeros(correction_offset.shape),
        name='correction_offset')
  return correction_scale, correction_recip, correction_offset
def _Log1pGrad(op, grad):
  """Returns grad * (1/(1 + x))."""
  x = op.inputs[0]
  with ops.control_dependencies([grad]):
    x = math_ops.conj(x)
    return grad * math_ops.reciprocal(1 + x)
def _SvdGrad(op, grad_s, grad_u, grad_v):
  """Gradient for Svd based on Giles' algorithm. Reference at top of file."""

  if op.get_attr("compute_uv") and not op.get_attr("full_matrices"):
    raise NotImplementedError(
        "SVD gradient is not implemented for compute_uv=True and "
        "full_matrices=False.")

  a = op.inputs[0]
  a_shape = a.get_shape().with_rank_at_least(2)

  if op.get_attr("compute_uv"):
    # TODO(rmlarsen): Make this work with complex types.
    if a.dtype.is_complex:
      raise NotImplementedError(
          "SVD gradient is not implemented for complex types and "
          "compute_uv=True.")
    grad_u_shape = grad_u.get_shape().with_rank_at_least(2)
    grad_v_shape = grad_v.get_shape().with_rank_at_least(2)
    m = a_shape[-2].merge_with(grad_u_shape[-2])
    n = a_shape[-1].merge_with(grad_v_shape[-2])
    batch_shape = a_shape[:-2].merge_with(grad_u_shape[:-2]).merge_with(
        grad_v_shape[:-2])
    a_shape = batch_shape.concatenate([m, n])

  m = a_shape[-2].value
  n = a_shape[-1].value
  # TODO(rmlarsen): Make this work with placeholders.
  if m is None or n is None:
    raise NotImplementedError(
        "SVD gradient has not been implemented for input with unknown "
        "inner matrix shape.")

  if not op.get_attr("full_matrices") or not op.get_attr("compute_uv"):
    s, u, v = linalg_ops.svd(a, compute_uv=True, full_matrices=True)
  else:
    s = op.outputs[0]
    u = op.outputs[1]
    v = op.outputs[2]

  use_adjoint = False
  if m > n:
    # Compute the gradient for A^H = V * S^T * U^H, and (implicitly) take the
    # Hermitian transpose of the gradient at the end.
    use_adjoint = True
    m, n = n, m
    u, v = v, u
    grad_u, grad_v = grad_v, grad_u

  with ops.control_dependencies([grad_s, grad_u, grad_v]):
    grad_s_mat = array_ops.matrix_diag(grad_s)
    if not op.get_attr("compute_uv"):
      if use_adjoint:
        grad_a = math_ops.matmul(
            v[..., :, :m], math_ops.matmul(u, grad_s_mat), adjoint_b=True)
      else:
        grad_a = math_ops.matmul(
            u, math_ops.matmul(grad_s_mat, v[..., :, :m], adjoint_b=True))
      grad_a.set_shape(a_shape)
      return grad_a

    # TODO(rmlarsen): Define a gradient that is numerically stable for
    # abs(m-n) > 1. Currently this does not work because there are effectively
    # multiple singular values with value zero. I am not sure if this is a true
    # instability or if it simply throws off the finite difference gradient
    # checker.
    if abs(m - n) > 1:
      raise NotImplementedError(
          "svd gradient is not implemented for abs(m - n) > 1")
    s_mat = array_ops.matrix_diag(s)
    s2 = math_ops.square(s)

    # NOTICE: Because of the term involving f, the gradient becomes
    # infinite (or NaN in practice) when singular values are not unique.
    # Mathematically this should not be surprising, since for (k-fold)
    # degenerate singular values, the corresponding singular vectors are
    # only defined up to a (k-dimensional) subspace. In practice, this can
    # lead to numerical instability when singular values are close but not
    # exactly equal.
    f = array_ops.matrix_set_diag(
        math_ops.reciprocal(
            array_ops.expand_dims(s2, -2) - array_ops.expand_dims(s2, -1)),
        array_ops.zeros_like(s))
    s_inv_mat = array_ops.matrix_diag(math_ops.reciprocal(s))
    u_gu = math_ops.matmul(u, grad_u, adjoint_a=True)
    v_gv = math_ops.matmul(v, grad_v, adjoint_a=True)

    if m == n:
      f_u = f * u_gu
      f_v = f * v_gv
    else:
      dv2 = array_ops.matrix_transpose(v_gv[..., m:n, :m]) - v_gv[..., :m, m:n]
      f_u = f * u_gu
      f_v = f * v_gv[..., :m, :m]

    grad_a_nouv = (
        grad_s_mat + math_ops.matmul(f_u + _linalg.adjoint(f_u), s_mat) +
        math_ops.matmul(s_mat, f_v + _linalg.adjoint(f_v)))

    if m != n:
      grad_a_nouv = array_ops.concat(
          [grad_a_nouv, math_ops.matmul(s_inv_mat, dv2)], -1)

    if use_adjoint:
      # Use (U X V^H)^H = V (U X)^H.
      grad_a = math_ops.matmul(
          v, math_ops.matmul(u, grad_a_nouv), adjoint_b=True)
    else:
      grad_a = math_ops.matmul(
          u, math_ops.matmul(grad_a_nouv, v, adjoint_b=True))

    grad_a.set_shape(a_shape)
    return grad_a
def weighted_moments(x, axes, frequency_weights, tower_config, name=None,
                     keep_dims=False):
  """Returns the frequency-weighted mean and variance of `x`.

  Args:
    x: A tensor.
    axes: 1-d tensor of int32 values; these are the axes along which
      to compute mean and variance.
    frequency_weights: A tensor of positive weights which can be
      broadcast with x.
    tower_config: Tower configuration object; provides `name`, `prefix`,
      `num_devices` and `is_test` for the NCCL all-reduce below.
    name: Name used to scope the operation.
    keep_dims: Produce moments with the same dimensionality as the input.

  Returns:
    Two tensors: `weighted_mean` and `weighted_variance`.
  """
  with ops.name_scope(name, "weighted_moments", [x, frequency_weights, axes]):
    x = ops.convert_to_tensor(x, name="x")
    frequency_weights = ops.convert_to_tensor(
        frequency_weights, name="frequency_weights")

    # Unlike moments(), this just uses a simpler two-pass method.

    # See comment in moments() WRT precision; it applies here too.
    needs_cast = x.dtype == dtypes.float16
    if needs_cast:
      x = math_ops.cast(x, dtypes.float32)

    if frequency_weights.dtype != x.dtype:
      frequency_weights = math_ops.cast(frequency_weights, x.dtype)

    # Note that we use keep_dims=True for our reductions regardless of the arg;
    # this is so that the results remain broadcast-compatible with the inputs.
    # Original Code:
    # weighted_input_sum = math_ops.reduce_sum(
    #     frequency_weights * x, axes, name="weighted_input_sum", keepdims=True)
    nccl_name = "NCCL" if not tower_config.is_test else "NCCL_TEST"
    shared_name = tf.get_variable_scope().name.replace(
        tower_config.name, tower_config.prefix.format(nccl_name))
    device_weighted_input_sum = math_ops.reduce_sum(
        frequency_weights * x, axes, name="weighted_input_sum", keepdims=True)
    weighted_input_sum = gen_nccl_ops.nccl_all_reduce(
        input=device_weighted_input_sum,
        reduction="sum",
        num_devices=tower_config.num_devices,
        shared_name=shared_name) / (1.0 * tower_config.num_devices)

    # The shape of the weights isn't necessarily the same as x's
    # shape, just broadcast-compatible with it -- so this expression
    # performs broadcasting to give a per-item weight, with the same
    # shape as (frequency_weights * x). This avoids having to reason
    # through all the broadcast logic to compute a correct
    # sum_of_weights.
    broadcasted_weights = frequency_weights + array_ops.zeros_like(x)

    sum_of_weights = math_ops.reduce_sum(
        broadcasted_weights, axes, name="sum_of_weights", keepdims=True)

    divisor = math_ops.reciprocal(sum_of_weights, name="inv_weight_sum")

    weighted_mean = math_ops.multiply(weighted_input_sum, divisor)

    # Have the weighted mean; now on to variance:
    # weighted_distsq = math_ops.reduce_sum(
    #     frequency_weights * math_ops.squared_difference(x, weighted_mean),
    #     axes,
    #     name="weighted_distsq",
    #     keepdims=True)
    nccl_name = "NCCL" if not tower_config.is_test else "NCCL_TEST"
    shared_name = tf.get_variable_scope().name.replace(
        tower_config.name, tower_config.prefix.format(nccl_name))
    device_weighted_distsq = math_ops.reduce_sum(
        frequency_weights * math_ops.squared_difference(x, weighted_mean),
        axes,
        name="weighted_distsq",
        keepdims=True)
    weighted_distsq = gen_nccl_ops.nccl_all_reduce(
        input=device_weighted_distsq,
        reduction="sum",
        num_devices=tower_config.num_devices,
        shared_name=shared_name) / (1.0 * tower_config.num_devices)

    weighted_variance = math_ops.multiply(weighted_distsq, divisor)

    if not keep_dims:
      weighted_mean = array_ops.squeeze(weighted_mean, axis=axes)
      weighted_variance = array_ops.squeeze(weighted_variance, axis=axes)

    if needs_cast:
      weighted_mean = math_ops.cast(weighted_mean, dtypes.float16)
      weighted_variance = math_ops.cast(weighted_variance, dtypes.float16)

    return weighted_mean, weighted_variance
def weighted_moments(x, axes, frequency_weights, name=None, keep_dims=False):
  """Returns the frequency-weighted mean and variance of `x`.

  Args:
    x: A tensor.
    axes: 1-d tensor of int32 values; these are the axes along which
      to compute mean and variance.
    frequency_weights: A tensor of positive weights which can be
      broadcast with x.
    name: Name used to scope the operation.
    keep_dims: Produce moments with the same dimensionality as the input.

  Returns:
    Two tensors: `weighted_mean` and `weighted_variance`.
  """
  with ops.name_scope(name, "weighted_moments", [x, frequency_weights, axes]):
    x = ops.convert_to_tensor(x, name="x")
    frequency_weights = ops.convert_to_tensor(
        frequency_weights, name="frequency_weights")

    # Unlike moments(), this just uses a simpler two-pass method.

    # See comment in moments() WRT precision; it applies here too.
    needs_cast = x.dtype == dtypes.float16
    if needs_cast:
      x = math_ops.cast(x, dtypes.float32)

    if frequency_weights.dtype != x.dtype:
      frequency_weights = math_ops.cast(frequency_weights, x.dtype)

    # Note that we use keep_dims=True for our reductions regardless of the arg;
    # this is so that the results remain broadcast-compatible with the inputs.
    weighted_input_sum = math_ops.reduce_sum(
        frequency_weights * x, axes, name="weighted_input_sum", keep_dims=True)

    # The shape of the weights isn't necessarily the same as x's
    # shape, just broadcast-compatible with it -- so this expression
    # performs broadcasting to give a per-item weight, with the same
    # shape as (frequency_weights * x). This avoids having to reason
    # through all the broadcast logic to compute a correct
    # sum_of_weights.
    broadcasted_weights = frequency_weights + array_ops.zeros_like(x)

    sum_of_weights = math_ops.reduce_sum(
        broadcasted_weights, axes, name="sum_of_weights", keep_dims=True)

    divisor = math_ops.reciprocal(sum_of_weights, name="inv_weight_sum")

    weighted_mean = math_ops.mul(weighted_input_sum, divisor)

    # Have the weighted mean; now on to variance:
    weighted_distsq = math_ops.reduce_sum(
        frequency_weights * math_ops.squared_difference(x, weighted_mean),
        axes,
        name="weighted_distsq",
        keep_dims=True)

    weighted_variance = math_ops.mul(weighted_distsq, divisor)

    if not keep_dims:
      weighted_mean = array_ops.squeeze(weighted_mean, squeeze_dims=axes)
      weighted_variance = array_ops.squeeze(
          weighted_variance, squeeze_dims=axes)

    if needs_cast:
      weighted_mean = math_ops.cast(weighted_mean, dtypes.float16)
      weighted_variance = math_ops.cast(weighted_variance, dtypes.float16)

    return weighted_mean, weighted_variance
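# NumPy sketch of the frequency-weighted moments computed above, checked
# against np.average.
import numpy as np

rng = np.random.default_rng(4)
x = rng.normal(size=100)
w = rng.uniform(0.1, 2.0, size=100)

weighted_mean = np.sum(w * x) / np.sum(w)
weighted_variance = np.sum(w * (x - weighted_mean) ** 2) / np.sum(w)

print(np.allclose(weighted_mean, np.average(x, weights=w)))            # True
print(np.allclose(weighted_variance,
                  np.average((x - weighted_mean) ** 2, weights=w)))    # True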
def _SafeReciprocal(x, epsilon=1E-20):
  return x * math_ops.reciprocal(x * x + epsilon)
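# NumPy sketch of the behaviour of the guarded reciprocal above: it returns 0
# at x == 0 instead of inf, and approaches 1/x whenever x*x dominates epsilon.
# `safe_reciprocal` is an illustrative re-statement, not the library function.
import numpy as np

def safe_reciprocal(x, epsilon=1e-20):
  return x / (x * x + epsilon)

x = np.array([0.0, 1e-3, 2.0])
print(safe_reciprocal(x))  # [0.  ~1000.  0.5]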
def _mini_batch_training_op(self, inputs, cluster_idx_list, cluster_centers,
                            cluster_centers_var, total_counts):
  """Creates an op for training for mini batch case.

  Args:
    inputs: list of input Tensors.
    cluster_idx_list: A vector (or list of vectors). Each element in the
      vector corresponds to an input row in 'inp' and specifies the cluster id
      corresponding to the input.
    cluster_centers: Tensor of cluster centers, possibly normalized.
    cluster_centers_var: Tensor Ref of cluster centers.
    total_counts: Tensor Ref of cluster counts.

  Returns:
    An op for doing an update of mini-batch k-means.
  """
  update_ops = []
  for inp, cluster_idx in zip(inputs, cluster_idx_list):
    with ops.colocate_with(inp):
      assert total_counts is not None
      cluster_idx = array_ops.reshape(cluster_idx, [-1])
      # Dedupe the unique ids of cluster_centers being updated so that updates
      # can be locally aggregated.
      unique_ids, unique_idx = array_ops.unique(cluster_idx)
      num_unique_cluster_idx = array_ops.size(unique_ids)
      # Fetch the old values of counts and cluster_centers.
      with ops.colocate_with(total_counts):
        old_counts = array_ops.gather(total_counts, unique_ids)
      with ops.colocate_with(cluster_centers):
        old_cluster_centers = array_ops.gather(cluster_centers, unique_ids)
      # Locally aggregate the increment to counts.
      count_updates = math_ops.unsorted_segment_sum(
          array_ops.ones_like(unique_idx, dtype=total_counts.dtype),
          unique_idx, num_unique_cluster_idx)
      # Locally compute the sum of inputs mapped to each id.
      # For a cluster with old cluster value x, old count n, and with data
      # d_1,...d_k newly assigned to it, we recompute the new value as
      # x += (sum_i(d_i) - k * x) / (n + k).
      # Compute sum_i(d_i), see comment above.
      cluster_center_updates = math_ops.unsorted_segment_sum(
          inp, unique_idx, num_unique_cluster_idx)
      # Shape to enable broadcasting count_updates and learning_rate to inp.
      # It extends the shape with 1's to match the rank of inp.
      broadcast_shape = array_ops.concat(
          [
              array_ops.reshape(num_unique_cluster_idx, [1]),
              array_ops.ones(
                  array_ops.reshape(array_ops.rank(inp) - 1, [1]),
                  dtype=dtypes.int32)
          ], 0)
      # Subtract k * x, see comment above.
      cluster_center_updates -= math_ops.cast(
          array_ops.reshape(count_updates, broadcast_shape),
          inp.dtype) * old_cluster_centers
      learning_rate = math_ops.reciprocal(
          math_ops.cast(old_counts + count_updates, inp.dtype))
      learning_rate = array_ops.reshape(learning_rate, broadcast_shape)
      # scale by 1 / (n + k), see comment above.
      cluster_center_updates *= learning_rate
      # Apply the updates.
      update_counts = state_ops.scatter_add(total_counts, unique_ids,
                                            count_updates)
      update_cluster_centers = state_ops.scatter_add(cluster_centers_var,
                                                     unique_ids,
                                                     cluster_center_updates)
      update_ops.extend([update_counts, update_cluster_centers])
  return control_flow_ops.group(*update_ops)
def _mini_batch_training_op(self, inputs, cluster_idx_list, cluster_centers,
                            total_counts):
  """Creates an op for training for mini batch case.

  Args:
    inputs: list of input Tensors.
    cluster_idx_list: A vector (or list of vectors). Each element in the
      vector corresponds to an input row in 'inp' and specifies the cluster id
      corresponding to the input.
    cluster_centers: Tensor Ref of cluster centers.
    total_counts: Tensor Ref of cluster counts.

  Returns:
    An op for doing an update of mini-batch k-means.
  """
  update_ops = []
  for inp, cluster_idx in zip(inputs, cluster_idx_list):
    with ops.colocate_with(inp):
      assert total_counts is not None
      cluster_idx = array_ops.reshape(cluster_idx, [-1])
      # Dedupe the unique ids of cluster_centers being updated so that updates
      # can be locally aggregated.
      unique_ids, unique_idx = array_ops.unique(cluster_idx)
      num_unique_cluster_idx = array_ops.size(unique_ids)
      # Fetch the old values of counts and cluster_centers.
      with ops.colocate_with(total_counts, ignore_existing=True):
        old_counts = array_ops.gather(total_counts, unique_ids)
      # TODO(agarwal): This colocation seems to run into problems. Fix it.
      with ops.colocate_with(cluster_centers, ignore_existing=True):
        old_cluster_centers = array_ops.gather(cluster_centers, unique_ids)
      # Locally aggregate the increment to counts.
      count_updates = math_ops.unsorted_segment_sum(
          array_ops.ones_like(unique_idx, dtype=total_counts.dtype),
          unique_idx, num_unique_cluster_idx)
      # Locally compute the sum of inputs mapped to each id.
      # For a cluster with old cluster value x, old count n, and with data
      # d_1,...d_k newly assigned to it, we recompute the new value as
      # x += (sum_i(d_i) - k * x) / (n + k).
      # Compute sum_i(d_i), see comment above.
      cluster_center_updates = math_ops.unsorted_segment_sum(
          inp, unique_idx, num_unique_cluster_idx)
      # Shape to enable broadcasting count_updates and learning_rate to inp.
      # It extends the shape with 1's to match the rank of inp.
      broadcast_shape = array_ops.concat([
          array_ops.reshape(num_unique_cluster_idx, [1]),
          array_ops.ones(
              array_ops.reshape(array_ops.rank(inp) - 1, [1]),
              dtype=dtypes.int32)
      ], 0)
      # Subtract k * x, see comment above.
      cluster_center_updates -= math_ops.cast(
          array_ops.reshape(count_updates, broadcast_shape),
          inp.dtype) * old_cluster_centers
      learning_rate = math_ops.reciprocal(
          math_ops.cast(old_counts + count_updates, inp.dtype))
      learning_rate = array_ops.reshape(learning_rate, broadcast_shape)
      # scale by 1 / (n + k), see comment above.
      cluster_center_updates *= learning_rate
      # Apply the updates.
      update_counts = state_ops.scatter_add(total_counts, unique_ids,
                                            count_updates)
      update_cluster_centers = state_ops.scatter_add(
          cluster_centers, unique_ids, cluster_center_updates)
      update_ops.extend([update_counts, update_cluster_centers])
  return control_flow_ops.group(*update_ops)
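# NumPy sketch of the incremental update documented above: for a cluster with
# old center x, old count n, and k newly assigned points d_i, the update
# x += (sum_i(d_i) - k * x) / (n + k) is exactly the running weighted mean.
import numpy as np

x, n = np.array([1.0, 2.0]), 10
d = np.array([[2.0, 2.0], [4.0, 0.0], [0.0, 4.0]])
k = len(d)

incremental = x + (d.sum(axis=0) - k * x) / (n + k)
reference = (n * x + d.sum(axis=0)) / (n + k)
print(np.allclose(incremental, reference))  # True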