from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import math_ops


def _DigammaGrad(op, grad):
  """Compute gradient of the digamma function with respect to its argument."""
  x = op.inputs[0]
  with ops.control_dependencies([grad]):
    x = math_ops.conj(x)
    # d/dx digamma(x) = polygamma(1, x), i.e. the trigamma function.
    return grad * math_ops.polygamma(array_ops.constant(1, dtype=x.dtype), x)
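
# A minimal sanity-check sketch, not part of the original source: it assumes
# TF 2.x eager execution and uses only public API calls (tf.math.digamma,
# tf.math.polygamma, tf.GradientTape) to confirm the rule _DigammaGrad
# implements, d/dx digamma(x) = polygamma(1, x).
def _check_digamma_grad():
  import tensorflow as tf  # Local import; assumes TF 2.x is installed.
  x = tf.constant(2.5)
  with tf.GradientTape() as tape:
    tape.watch(x)
    y = tf.math.digamma(x)
  # The analytic gradient is the trigamma function, psi(1, x).
  tf.debugging.assert_near(tape.gradient(y, x),
                           tf.math.polygamma(tf.constant(1.0), x))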
def _PolygammaGrad(op, grad):
  """Returns gradient of psi(n, x) with respect to n and x."""
  # TODO(tillahoffmann): Add derivative with respect to n
  n = op.inputs[0]
  x = op.inputs[1]
  # Broadcast gradients
  sn = array_ops.shape(n)
  sx = array_ops.shape(x)
  # pylint: disable=protected-access
  unused_rn, rx = gen_array_ops._broadcast_gradient_args(sn, sx)
  # pylint: enable=protected-access
  # Evaluate gradient
  with ops.control_dependencies([grad]):
    n = math_ops.conj(n)
    x = math_ops.conj(x)
    partial_x = math_ops.polygamma(n + 1, x)
    # TODO(b/36815900): Mark None return values as NotImplemented
    return (None,
            array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx))
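
# A hypothetical check (assumption, not original source) of the rule the
# gradient above implements: d/dx psi(n, x) = psi(n + 1, x), with None
# (no gradient) returned for the order n. Assumes TF 2.x eager execution.
def _check_polygamma_grad():
  import tensorflow as tf
  n = tf.constant(1.0)
  x = tf.constant(3.0)
  with tf.GradientTape() as tape:
    tape.watch(x)
    y = tf.math.polygamma(n, x)
  tf.debugging.assert_near(tape.gradient(y, x),
                           tf.math.polygamma(n + 1.0, x))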
def _polygamma(n, x):
  return math_ops.polygamma(n, x)
def safe_polygamma(x, y):
  # Round and clip the order y to [1, 10] and shift the argument to
  # x * x + 1 >= 1, keeping both inputs inside polygamma's valid domain.
  return math_ops.polygamma(
      math_ops.round(clip_ops.clip_by_value(y, 1, 10)), x * x + 1)
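
# Hypothetical usage check (assumption, not original source): an out-of-range
# order of 100 is clipped to 10 inside safe_polygamma, so the call matches an
# explicit polygamma(10, x * x + 1). Assumes TF 2.x eager execution.
def _check_safe_polygamma():
  import tensorflow as tf
  x = tf.constant(-4.0)
  y = tf.constant(100.0)  # Clipped to 10 inside safe_polygamma.
  tf.debugging.assert_near(
      safe_polygamma(x, y),
      tf.math.polygamma(tf.constant(10.0), x * x + 1.0))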