def _PowGrad(op, grad):
  """Returns grad * (y*x^(y-1), z*log(x))."""
  x = op.inputs[0]
  y = op.inputs[1]
  z = op.outputs[0]
  sx = array_ops.shape(x)
  sy = array_ops.shape(y)
  rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
  x = math_ops.conj(x)
  y = math_ops.conj(y)
  z = math_ops.conj(z)
  gx = array_ops.reshape(
      math_ops.reduce_sum(grad * y * math_ops.pow(x, y - 1), rx), sx)
  # Avoid false singularity at x = 0
  if x.dtype.is_complex:
    # real(x) < 0 is fine for the complex case
    mask = math_ops.not_equal(x, 0)
  else:
    # There's no sensible real value to return if x < 0, so return 0
    mask = x > 0
  safe_x = array_ops.where(mask, x, array_ops.ones_like(x))
  log_x = array_ops.where(mask, math_ops.log(safe_x), array_ops.zeros_like(x))
  gy = array_ops.reshape(math_ops.reduce_sum(grad * z * log_x, ry), sy)
  return gx, gy
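# A minimal sketch (not part of any library; names are illustrative) of the
# reduction pattern that _PowGrad and the other gradients below rely on:
# broadcast_gradient_args returns, for each input shape, the axes over which
# the upstream gradient must be sum-reduced before it is reshaped back to that
# input's shape. Assumes TF 2.x eager execution.
import tensorflow as tf
from tensorflow.python.ops import gen_array_ops


def demo_broadcast_reduction():
  sx = tf.constant([2, 3])  # shape of x
  sy = tf.constant([1, 3])  # shape of y, broadcast along axis 0
  rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
  # rx is empty (x already has the broadcast shape); ry is [0], the axis
  # y was broadcast over and therefore the axis its gradient is summed over.
  print(rx.numpy(), ry.numpy())  # [] [0]


demo_broadcast_reduction()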
def _TensorStridedSliceUpdateGrad(op, grad):  # pylint:disable=missing-function-docstring
  begin = op.inputs[1]
  end = op.inputs[2]
  strides = op.inputs[3]
  begin_mask = op.get_attr("begin_mask")
  end_mask = op.get_attr("end_mask")
  ellipsis_mask = op.get_attr("ellipsis_mask")
  new_axis_mask = op.get_attr("new_axis_mask")
  shrink_axis_mask = op.get_attr("shrink_axis_mask")

  def Apply(f, *args):
    return f(*args,
             begin_mask=begin_mask,
             end_mask=end_mask,
             shrink_axis_mask=shrink_axis_mask,
             new_axis_mask=new_axis_mask,
             ellipsis_mask=ellipsis_mask)

  dy = Apply(array_ops.strided_slice, grad, begin, end, strides)
  dx = Apply(array_ops.tensor_strided_slice_update, grad, begin, end, strides,
             array_ops.zeros_like(dy))

  # The value is potentially broadcast to the shape of the strided slice, so we
  # may need to adjust dy.
  slice_shape = array_ops.shape(dy, out_type=begin.dtype)
  value_shape = array_ops.shape(op.inputs[4], out_type=slice_shape.dtype)
  _, reduction_axes = gen_array_ops.broadcast_gradient_args(
      slice_shape, value_shape)
  dy_reshaped = math_ops.reduce_sum(dy, axis=reduction_axes, keepdims=True)
  dy = array_ops.reshape(dy_reshaped, value_shape)
  return dx, None, None, None, dy
def SecureAddGrad(op, grad):
  """Gradient for Secure Add."""
  y = op.inputs[1]
  skip_input_indices = None
  try:
    skip_input_indices = op.skip_input_indices
    if skip_input_indices is not None and 1 in skip_input_indices and _IsScalar(
        y):
      return grad, None
  except AttributeError:
    # No gradient skipping, so do the full gradient computation
    pass

  x = op.inputs[0]
  if (isinstance(grad, ops.Tensor) and
      math_grad._ShapesFullySpecifiedAndEqual(x, y, grad)):
    return grad, grad

  sx = array_ops.shape(x)
  sy = array_ops.shape(y)
  rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
  if skip_input_indices is not None and 0 in skip_input_indices:
    gx = None
  else:
    gx = array_ops.reshape(SecureSum(grad, rx), sx)
  if skip_input_indices is not None and 1 in skip_input_indices:
    gy = None
  else:
    gy = array_ops.reshape(SecureSum(grad, ry), sy)
  return (gx, gy)
def _ComplexGrad(op, grad):
  """Returns the real and imaginary components of 'grad', respectively."""
  x = op.inputs[0]
  y = op.inputs[1]
  sx = array_ops.shape(x)
  sy = array_ops.shape(y)
  rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
  return (array_ops.reshape(math_ops.reduce_sum(math_ops.real(grad), rx), sx),
          array_ops.reshape(math_ops.reduce_sum(math_ops.imag(grad), ry), sy))
def _SecureDivideGrad(op, grad):
  x = op.inputs[0]
  y = op.inputs[1]
  sx = array_ops.shape(x)
  sy = array_ops.shape(y)
  rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
  dX = array_ops.reshape(SecureSum(SecureTruediv(grad, y), rx), sx)
  temp = SecureTruediv(SecureTruediv(SecureNeg(x), y), y)
  dY = array_ops.reshape(SecureSum(SecureMul(grad, temp), ry), sy)
  return (dX, dY)
def _BroadcastToGrad(op, grad):
  input_value = op.inputs[0]
  broadcast_shape = op.inputs[1]
  input_value_shape = array_ops.shape(input_value)
  _, reduction_axes = gen_array_ops.broadcast_gradient_args(
      broadcast_shape, input_value_shape)
  updates_grad_reshaped = math_ops.reduce_sum(
      grad, axis=reduction_axes, keepdims=True)
  updates_grad = array_ops.reshape(updates_grad_reshaped, input_value_shape)
  return [updates_grad, None]
def _SubGrad(op, grad):
  """Gradient for Sub."""
  x = op.inputs[0]
  y = op.inputs[1]
  if (isinstance(grad, ops.Tensor) and
      _ShapesFullySpecifiedAndEqual(x, y, grad)):
    return grad, -grad
  sx = array_ops.shape(x)
  sy = array_ops.shape(y)
  rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
  return (array_ops.reshape(math_ops.reduce_sum(grad, rx), sx),
          array_ops.reshape(-math_ops.reduce_sum(grad, ry), sy))
def _MpcDivideGrad(op, grad):
  x = op.inputs[0]
  y = op.inputs[1]
  sx = array_ops.shape(x)
  sy = array_ops.shape(y)
  rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
  x = math_ops.conj(x)
  y = math_ops.conj(y)
  dX = array_ops.reshape(math_ops.reduce_sum(MpcTruediv(grad, y), rx), sx)
  temp = MpcTruediv(MpcTruediv(-x, y), y)
  dY = array_ops.reshape(math_ops.reduce_sum(MpcMul(grad, temp), ry), sy)
  return (dX, dY)
def _clip_by_value_grad(op, grad):
  """Returns grad of clip_by_value."""
  x = op.inputs[0]
  y = op.inputs[1]
  z = op.inputs[2]
  gdtype = grad.dtype
  sx = array_ops.shape(x)
  sy = array_ops.shape(y)
  sz = array_ops.shape(z)
  gradshape = array_ops.shape(grad)
  zeros = array_ops.zeros(gradshape, gdtype)
  xymask = math_ops.less(x, y)
  xzmask = math_ops.greater(x, z)
  rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
  rx, rz = gen_array_ops.broadcast_gradient_args(sx, sz)
  xgrad = array_ops.where(math_ops.logical_or(xymask, xzmask), zeros, grad)
  ygrad = array_ops.where(xymask, grad, zeros)
  zgrad = array_ops.where(xzmask, grad, zeros)
  gx = array_ops.reshape(math_ops.reduce_sum(xgrad, rx), sx)
  gy = array_ops.reshape(math_ops.reduce_sum(ygrad, ry), sy)
  gz = array_ops.reshape(math_ops.reduce_sum(zgrad, rz), sz)
  return (gx, gy, gz)
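# A hedged sanity check (illustrative only) of the routing that
# _clip_by_value_grad implements: the upstream gradient flows to x wherever x
# lies inside [y, z], and to the active bound otherwise. Computed here with
# GradientTape rather than by invoking the registered gradient directly, so
# the exact op decomposition used by tf.clip_by_value is not assumed.
import tensorflow as tf

x = tf.constant([-2.0, 0.5, 3.0])
y = tf.constant(0.0)  # lower bound
z = tf.constant(1.0)  # upper bound
with tf.GradientTape() as tape:
  tape.watch([x, y, z])
  out = tf.clip_by_value(x, y, z)
gx, gy, gz = tape.gradient(out, [x, y, z])
print(gx.numpy())  # [0. 1. 0.] -> only the in-range element passes gradient
print(gy.numpy())  # 1.0        -> one element clipped at the lower bound
print(gz.numpy())  # 1.0        -> one element clipped at the upper bound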
def _DivGrad(op, grad):
  """The gradient for the Div operator."""
  x = op.inputs[0]
  y = op.inputs[1]
  sx = array_ops.shape(x)
  sy = array_ops.shape(y)
  rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
  x = math_ops.conj(x)
  y = math_ops.conj(y)
  return (array_ops.reshape(
      math_ops.reduce_sum(math_ops.div(grad, y), rx), sx),
          array_ops.reshape(
              math_ops.reduce_sum(
                  grad * math_ops.div(math_ops.div(-x, y), y), ry), sy))
def _FloorModGrad(op, grad):
  """Returns grad * (1, -floor(x/y))."""
  x = math_ops.conj(op.inputs[0])
  y = math_ops.conj(op.inputs[1])
  sx = array_ops.shape(x)
  sy = array_ops.shape(y)
  rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
  floor_xy = math_ops.floor_div(x, y)
  gx = array_ops.reshape(math_ops.reduce_sum(grad, rx), sx)
  gy = array_ops.reshape(
      math_ops.reduce_sum(grad * math_ops.negative(floor_xy), ry), sy)
  return gx, gy
def _SquaredDifferenceGrad(op, grad):
  """Returns the gradient for (x-y)^2."""
  x = op.inputs[0]
  y = op.inputs[1]
  sx = array_ops.shape(x)
  sy = array_ops.shape(y)
  rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
  with ops.control_dependencies([grad]):
    # The parens ensure that if grad is IndexedSlices, it'll get multiplied by
    # Tensor (not a number like 2.0) which causes it to convert to Tensor.
    x_grad = math_ops.scalar_mul(2.0, grad) * (x - y)
  return (array_ops.reshape(math_ops.reduce_sum(x_grad, rx), sx),
          -array_ops.reshape(math_ops.reduce_sum(x_grad, ry), sy))
def SecureSubGrad(op, grad):
  """Gradient for Secure Sub."""
  x = op.inputs[0]
  y = op.inputs[1]
  if (isinstance(grad, ops.Tensor) and
      math_grad._ShapesFullySpecifiedAndEqual(x, y, grad)):
    return grad, SecureNeg(grad)
  sx = array_ops.shape(x)
  sy = array_ops.shape(y)
  rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
  return (array_ops.reshape(SecureSum(grad, rx), sx),
          array_ops.reshape(SecureNeg(SecureSum(grad, ry)), sy))
def _NextAfterGrad(op, grad):
  """Returns gradient of nextafter(x1, x2) with respect to x1 and x2."""
  x1 = op.inputs[0]
  x2 = op.inputs[1]
  s_x1 = array_ops.shape(x1)
  s_x2 = array_ops.shape(x2)
  r_x1, r_x2 = gen_array_ops.broadcast_gradient_args(s_x1, s_x2)
  with ops.control_dependencies([grad]):
    partial_x1 = array_ops.ones(s_x1, dtype=x1.dtype)
    partial_x2 = array_ops.zeros(s_x2, dtype=x2.dtype)
    return (array_ops.reshape(
        math_ops.reduce_sum(partial_x1 * grad, r_x1), s_x1),
            array_ops.reshape(
                math_ops.reduce_sum(partial_x2 * grad, r_x2), s_x2))
def _XDivyGrad(op, grad):
  """Returns gradient of xdivy(x, y) with respect to x and y."""
  x = op.inputs[0]
  y = op.inputs[1]
  sx = array_ops.shape(x)
  sy = array_ops.shape(y)
  rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
  with ops.control_dependencies([grad]):
    not_zero_x = math_ops.cast(
        math_ops.not_equal(x, math_ops.cast(0., dtype=x.dtype)), dtype=x.dtype)
    partial_x = gen_math_ops.xdivy(not_zero_x, y)
    partial_y = gen_math_ops.xdivy(math_ops.negative(x), y**2)
    return (array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx),
            array_ops.reshape(math_ops.reduce_sum(partial_y * grad, ry), sy))
def _DivNoNanGrad(op, grad):
  """DivNoNan op gradient."""
  x = op.inputs[0]
  y = op.inputs[1]
  sx = array_ops.shape(x)
  sy = array_ops.shape(y)
  rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
  x = math_ops.conj(x)
  y = math_ops.conj(y)
  return (array_ops.reshape(
      math_ops.reduce_sum(math_ops.div_no_nan(grad, y), rx), sx),
          array_ops.reshape(
              math_ops.reduce_sum(
                  grad * math_ops.div_no_nan(math_ops.div_no_nan(-x, y), y),
                  ry), sy))
def _IgammaGrad(op, grad):
  """Returns gradient of igamma(a, x) with respect to x."""
  # TODO(ebrevdo): Perhaps add the derivative w.r.t. a
  a = op.inputs[0]
  x = op.inputs[1]
  sa = array_ops.shape(a)
  sx = array_ops.shape(x)
  unused_ra, rx = gen_array_ops.broadcast_gradient_args(sa, sx)

  # Perform operations in log space before summing, because Gamma(a)
  # and Gamma'(a) can grow large.
  partial_x = math_ops.exp(-x + (a - 1) * math_ops.log(x) - math_ops.lgamma(a))
  # TODO(b/36815900): Mark None return values as NotImplemented
  return (None,
          array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx))
def _MaximumMinimumGrad(op, grad, selector_op):
  """Factor out the code for the gradient of Maximum or Minimum."""
  x = op.inputs[0]
  y = op.inputs[1]
  gdtype = grad.dtype
  sx = array_ops.shape(x)
  sy = array_ops.shape(y)
  gradshape = array_ops.shape(grad)
  zeros = array_ops.zeros(gradshape, gdtype)
  xmask = selector_op(x, y)
  rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
  xgrad = array_ops.where(xmask, grad, zeros)
  ygrad = array_ops.where(xmask, zeros, grad)
  gx = array_ops.reshape(math_ops.reduce_sum(xgrad, rx), sx)
  gy = array_ops.reshape(math_ops.reduce_sum(ygrad, ry), sy)
  return (gx, gy)
def _IgammaGrad(op, grad):
  """Returns gradient of igamma(a, x) with respect to a and x."""
  a = op.inputs[0]
  x = op.inputs[1]
  sa = array_ops.shape(a)
  sx = array_ops.shape(x)
  ra, rx = gen_array_ops.broadcast_gradient_args(sa, sx)

  with ops.control_dependencies([grad]):
    partial_a = gen_math_ops.igamma_grad_a(a, x)
    # Perform operations in log space before summing, because Gamma(a)
    # and Gamma'(a) can grow large.
    partial_x = math_ops.exp(-x + (a - 1) * math_ops.log(x) -
                             math_ops.lgamma(a))
    return (array_ops.reshape(math_ops.reduce_sum(partial_a * grad, ra), sa),
            array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx))
def _BroadcastToGrad(op, grad):
  input_value = op.inputs[0]
  broadcast_shape = op.inputs[1]
  input_value_shape = array_ops.shape(input_value)
  if not isinstance(broadcast_shape, ops.EagerTensor):
    broadcast_shape_static = tensor_shape.TensorShape(
        pywrap_tf_session.TF_TryEvaluateConstant_wrapper(
            broadcast_shape.graph._c_graph,
            broadcast_shape._as_tf_output()))  # pylint: disable=protected-access
    if broadcast_shape_static.is_fully_defined():
      broadcast_shape = constant_op.constant(
          broadcast_shape_static.as_list(), dtype=dtypes.int32)
  _, reduction_axes = gen_array_ops.broadcast_gradient_args(
      broadcast_shape, input_value_shape)
  updates_grad_reshaped = math_ops.reduce_sum(
      grad, axis=reduction_axes, keepdims=True)
  updates_grad = array_ops.reshape(updates_grad_reshaped, input_value_shape)
  return [updates_grad, None]
def _ZetaGrad(op, grad):
  """Returns gradient of zeta(x, q) with respect to x and q."""
  # TODO(tillahoffmann): Add derivative with respect to x
  x = op.inputs[0]
  q = op.inputs[1]
  # Broadcast gradients
  sx = array_ops.shape(x)
  sq = array_ops.shape(q)
  unused_rx, rq = gen_array_ops.broadcast_gradient_args(sx, sq)
  # Evaluate gradient
  with ops.control_dependencies([grad]):
    x = math_ops.conj(x)
    q = math_ops.conj(q)
    partial_q = -x * math_ops.zeta(x + 1, q)
    # TODO(b/36815900): Mark None return values as NotImplemented
    return (None,
            array_ops.reshape(math_ops.reduce_sum(partial_q * grad, rq), sq))
def _PolygammaGrad(op, grad):
  """Returns gradient of psi(n, x) with respect to n and x."""
  # TODO(tillahoffmann): Add derivative with respect to n
  n = op.inputs[0]
  x = op.inputs[1]
  # Broadcast gradients
  sn = array_ops.shape(n)
  sx = array_ops.shape(x)
  unused_rn, rx = gen_array_ops.broadcast_gradient_args(sn, sx)
  # Evaluate gradient
  with ops.control_dependencies([grad]):
    n = math_ops.conj(n)
    x = math_ops.conj(x)
    partial_x = math_ops.polygamma(n + 1, x)
    # TODO(b/36815900): Mark None return values as NotImplemented
    return (None,
            array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx))
def SecureMulGrad(op, grad):
  """The gradient of Secure multiplication."""
  x = op.inputs[0]
  y = op.inputs[1]
  if (isinstance(grad, ops.Tensor) and
      math_grad._ShapesFullySpecifiedAndEqual(x, y, grad) and
      grad.dtype in (dtypes.int32, dtypes.float32, dtypes.float64,
                     dtypes.string)):
    return (SecureMul(grad, y), SecureMul(grad, x))
  assert x.dtype.base_dtype == y.dtype.base_dtype, (x.dtype, " Secure_vs. ",
                                                    y.dtype)
  sx = array_ops.shape(x)
  sy = array_ops.shape(y)
  rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
  return (array_ops.reshape(SecureSum(SecureMul(grad, y), rx), sx),
          array_ops.reshape(SecureSum(SecureMul(x, grad), ry), sy))
def _MulGrad(op, grad):
  """The gradient of scalar multiplication."""
  x = op.inputs[0]
  y = op.inputs[1]
  if (isinstance(grad, ops.Tensor) and
      _ShapesFullySpecifiedAndEqual(x, y, grad) and
      grad.dtype in (dtypes.int32, dtypes.float32)):
    return gen_math_ops.mul(grad, y), gen_math_ops.mul(grad, x)
  assert x.dtype.base_dtype == y.dtype.base_dtype, (x.dtype, " vs. ", y.dtype)
  sx = array_ops.shape(x)
  sy = array_ops.shape(y)
  rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
  x = math_ops.conj(x)
  y = math_ops.conj(y)
  return (array_ops.reshape(
      math_ops.reduce_sum(gen_math_ops.mul(grad, y), rx), sx),
          array_ops.reshape(
              math_ops.reduce_sum(gen_math_ops.mul(x, grad), ry), sy))
def grad_fn(grad):
  grad_lhs_name = name + 'GradA'
  grad_rhs_name = name + 'GradB'

  # Reduce along the broadcasted batch dimensions, if broadcasting is
  # required.
  lhs_shape = lhs.shape.as_list()
  rhs_shape = rhs.shape.as_list()
  lhs_reduction = None
  rhs_reduction = None
  if lhs_shape != rhs_shape:
    lhs_reduction, rhs_reduction = gen_array_ops.broadcast_gradient_args(
        lhs_shape[:-2], rhs_shape[:-2])

  if not transpose_a and not transpose_b:
    grad_lhs = grad_a_fn(grad, rhs, False, True, grad_lhs_name, lhs_shape,
                         lhs_reduction, backend_config_pb2.TRAINING_BWD)
    grad_rhs = grad_b_fn(lhs, grad, True, False, grad_rhs_name, rhs_shape,
                         rhs_reduction, backend_config_pb2.TRAINING_WU)
  elif not transpose_a and transpose_b:
    grad_lhs = grad_a_fn(grad, rhs, False, False, grad_lhs_name, lhs_shape,
                         lhs_reduction, backend_config_pb2.TRAINING_BWD)
    grad_rhs = grad_b_fn(grad, lhs, True, False, grad_rhs_name, rhs_shape,
                         rhs_reduction, backend_config_pb2.TRAINING_WU)
  elif transpose_a and not transpose_b:
    grad_lhs = grad_a_fn(rhs, grad, False, True, grad_lhs_name, lhs_shape,
                         lhs_reduction, backend_config_pb2.TRAINING_BWD)
    grad_rhs = grad_b_fn(lhs, grad, False, False, grad_rhs_name, rhs_shape,
                         rhs_reduction, backend_config_pb2.TRAINING_WU)
  elif transpose_a and transpose_b:
    grad_lhs = grad_a_fn(rhs, grad, True, True, grad_lhs_name, lhs_shape,
                         lhs_reduction, backend_config_pb2.TRAINING_BWD)
    grad_rhs = grad_b_fn(grad, lhs, True, True, grad_rhs_name, rhs_shape,
                         rhs_reduction, backend_config_pb2.TRAINING_WU)

  return [grad_lhs, grad_rhs]
def _BroadcastToGrad(op, grad):
  input_value = op.inputs[0]
  broadcast_shape = op.inputs[1]
  shape_dtype = dtypes.int32
  if isinstance(broadcast_shape, ops.Tensor):
    shape_dtype = broadcast_shape.dtype
  input_value_shape = array_ops.shape(input_value, out_type=shape_dtype)
  if not isinstance(broadcast_shape, ops.EagerTensor):
    broadcast_shape_static = tensor_shape.TensorShape(
        tensor_util.try_evaluate_constant(broadcast_shape))
    if broadcast_shape_static.is_fully_defined():
      broadcast_shape = constant_op.constant(
          broadcast_shape_static.as_list(), dtype=shape_dtype)
  _, reduction_axes = gen_array_ops.broadcast_gradient_args(
      broadcast_shape, input_value_shape)
  updates_grad_reshaped = math_ops.reduce_sum(
      grad, axis=reduction_axes, keepdims=True)
  updates_grad = array_ops.reshape(updates_grad_reshaped, input_value_shape)
  return [updates_grad, None]
def _PowGrad(op, grad):
  """Returns grad * (y*x^(y-1), z*log(x))."""
  x = op.inputs[0]
  y = op.inputs[1]
  z = op.outputs[0]
  sx = array_ops.shape(x)
  sy = array_ops.shape(y)
  rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
  x = math_ops.conj(x)
  y = math_ops.conj(y)
  z = math_ops.conj(z)
  gx = array_ops.reshape(
      math_ops.reduce_sum(grad * y * math_ops.pow(x, y - 1), rx), sx)
  # Avoid false singularity at x = 0
  if x.dtype.is_complex:
    # real(x) < 0 is fine for the complex case
    log_x = array_ops.where(
        math_ops.not_equal(x, 0), math_ops.log(x), array_ops.zeros_like(x))
  else:
    # There's no sensible real value to return if x < 0, so return 0
    log_x = array_ops.where(x > 0, math_ops.log(x), array_ops.zeros_like(x))
  gy = array_ops.reshape(math_ops.reduce_sum(grad * z * log_x, ry), sy)
  return gx, gy
def _BetaincGrad(op, grad):
  """Returns gradient of betainc(a, b, x) with respect to x."""
  # TODO(ebrevdo): Perhaps add the derivative w.r.t. a, b
  a, b, x = op.inputs

  # two cases: x is a scalar and a/b are same-shaped tensors, or vice
  # versa; so it's sufficient to check against shape(a).
  sa = array_ops.shape(a)
  sx = array_ops.shape(x)
  _, rx = gen_array_ops.broadcast_gradient_args(sa, sx)

  # Perform operations in log space before summing, because terms
  # can grow large.
  log_beta = (
      gen_math_ops.lgamma(a) + gen_math_ops.lgamma(b) -
      gen_math_ops.lgamma(a + b))
  partial_x = math_ops.exp((b - 1) * math_ops.log(1 - x) +
                           (a - 1) * math_ops.log(x) - log_beta)

  # TODO(b/36815900): Mark None return values as NotImplemented
  return (
      None,  # da
      None,  # db
      array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx))
def _GetGradientArgs(self, xs, ys):
  return self.evaluate(broadcast_gradient_args(xs, ys))
def _reduce_and_reshape_grad(g, t):
  """Returns the gradient, sum-reduced and reshaped to `t`'s shape."""
  shape = array_ops.shape(t)
  g_shape = array_ops.shape(g)
  bcast_dims, _ = gen_array_ops.broadcast_gradient_args(shape, g_shape)
  return array_ops.reshape(math_ops.reduce_sum(g, bcast_dims), shape)
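# A hedged usage sketch for _reduce_and_reshape_grad above, assuming eager
# execution and that the helper's own imports (array_ops, math_ops,
# gen_array_ops) are already in scope as in the surrounding file.
import tensorflow as tf

t = tf.ones([1, 3])     # the tensor the gradient must match
g = tf.ones([4, 2, 3])  # upstream gradient carrying extra broadcast axes
reduced = _reduce_and_reshape_grad(g, t)
print(reduced.shape)    # (1, 3)
print(reduced.numpy())  # every entry is 8.0: the 4 * 2 broadcast copies summed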
def _GetGradientArgs(self, xs, ys):
  with self.cached_session() as sess:
    return sess.run(broadcast_gradient_args(xs, ys))
def _StatelessParameterizedTruncatedNormalGrad(op, grad):  # pylint: disable=invalid-name
  """Returns the gradient of a TruncatedNormal sample w.r.t. parameters.

  The gradient is computed using implicit differentiation
  (Figurnov et al., 2018).

  Args:
    op: A `StatelessParameterizedTruncatedNormal` operation. We assume that the
      inputs to the operation are `shape`, `seed`, `mean`, `stddev`, `minval`,
      and `maxval` tensors, and the output is the `sample` tensor.
    grad: The incoming gradient `dloss / dsample` of the same shape as
      `op.outputs[0]`.

  Returns:
    A list of `Tensor` with derivatives with respect to each parameter.

  References:
    Implicit Reparameterization Gradients:
      [Figurnov et al., 2018]
      (http://papers.nips.cc/paper/7326-implicit-reparameterization-gradients)
      ([pdf]
      (http://papers.nips.cc/paper/7326-implicit-reparameterization-gradients.pdf))
  """
  shape = op.inputs[0]
  mean = op.inputs[2]
  stddev = op.inputs[3]
  minval = op.inputs[4]
  maxval = op.inputs[5]
  sample = op.outputs[0]

  with ops.control_dependencies([grad]):
    minval_std = (minval - mean) / stddev
    maxval_std = (maxval - mean) / stddev
    sample_std = (sample - mean) / stddev

    cdf_sample = (_Ndtr(sample_std) - _Ndtr(minval_std)) / (
        _Ndtr(maxval_std) - _Ndtr(minval_std))

    # Clip to avoid zero argument for log_cdf expression
    tiny = np.finfo(mean.dtype.as_numpy_dtype).tiny
    eps = np.finfo(mean.dtype.as_numpy_dtype).eps
    cdf_sample = clip_ops.clip_by_value(cdf_sample, tiny, 1 - eps)

    dmaxval = math_ops.exp(0.5 * (sample_std ** 2 - maxval_std ** 2) +
                           math_ops.log(cdf_sample))
    dminval = math_ops.exp(0.5 * (sample_std ** 2 - minval_std ** 2) +
                           math_ops.log1p(-cdf_sample))
    dmean = array_ops.ones_like(sample_std)
    dstddev = sample_std

    # Reduce over extra dimensions caused by `shape`. We need to get the
    # difference in rank from shape vs. the broadcasted rank.
    mean_shape = array_ops.shape(mean)
    stddev_shape = array_ops.shape(stddev)
    minval_shape = array_ops.shape(minval)
    maxval_shape = array_ops.shape(maxval)

    broadcast_shape = array_ops.broadcast_dynamic_shape(
        mean_shape, stddev_shape)
    broadcast_shape = array_ops.broadcast_dynamic_shape(
        minval_shape, broadcast_shape)
    broadcast_shape = array_ops.broadcast_dynamic_shape(
        maxval_shape, broadcast_shape)
    extra_dims = math_ops.range(
        array_ops.size(shape) - array_ops.size(broadcast_shape))

    grad_mean = math_ops.reduce_sum(grad * dmean, axis=extra_dims)
    grad_stddev = math_ops.reduce_sum(grad * dstddev, axis=extra_dims)
    grad_minval = math_ops.reduce_sum(grad * dminval, axis=extra_dims)
    grad_maxval = math_ops.reduce_sum(grad * dmaxval, axis=extra_dims)

    _, rmean = gen_array_ops.broadcast_gradient_args(
        broadcast_shape, mean_shape)
    _, rstddev = gen_array_ops.broadcast_gradient_args(
        broadcast_shape, stddev_shape)
    _, rminval = gen_array_ops.broadcast_gradient_args(
        broadcast_shape, minval_shape)
    _, rmaxval = gen_array_ops.broadcast_gradient_args(
        broadcast_shape, maxval_shape)

    grad_mean = array_ops.reshape(
        math_ops.reduce_sum(grad_mean, axis=rmean, keepdims=True), mean_shape)
    grad_stddev = array_ops.reshape(
        math_ops.reduce_sum(grad_stddev, axis=rstddev, keepdims=True),
        stddev_shape)
    grad_minval = array_ops.reshape(
        math_ops.reduce_sum(grad_minval, axis=rminval, keepdims=True),
        minval_shape)
    grad_maxval = array_ops.reshape(
        math_ops.reduce_sum(grad_maxval, axis=rmaxval, keepdims=True),
        maxval_shape)

    # The first two inputs are shape.
    return (None, None, grad_mean, grad_stddev, grad_minval, grad_maxval)
def _GetGradientArgs(self, xs, ys):
  with self.test_session(use_gpu=True) as sess:
    return sess.run(broadcast_gradient_args(xs, ys))