Example #1
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops


def _BetaincGrad(op, grad):
    """Returns gradient of betainc(a, b, x) with respect to x."""
    # TODO(ebrevdo): Perhaps add the derivative w.r.t. a, b
    a, b, x = op.inputs

    # Two cases: x is a scalar and a/b are same-shaped tensors, or vice
    # versa; so it's sufficient to check against shape(a).
    sa = array_ops.shape(a)
    sx = array_ops.shape(x)
    # pylint: disable=protected-access
    _, rx = gen_array_ops._broadcast_gradient_args(sa, sx)
    # pylint: enable=protected-access

    # Perform operations in log space before summing, because terms
    # can grow large.
    log_beta = (gen_math_ops.lgamma(a) + gen_math_ops.lgamma(b) -
                gen_math_ops.lgamma(a + b))
    partial_x = math_ops.exp((b - 1) * math_ops.log(1 - x) +
                             (a - 1) * math_ops.log(x) - log_beta)

    # TODO(b/36815900): Mark None return values as NotImplemented
    return (
        None,  # da
        None,  # db
        # Sum away the broadcasted dimensions (rx) to recover the shape of x.
        array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx))
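
The partial derivative computed above is the integrand of the regularized incomplete beta function: d/dx betainc(a, b, x) = x^(a-1) * (1 - x)^(b-1) / B(a, b), which is why only the x slot receives a gradient. A minimal sketch checking this closed form against autodiff through the public TensorFlow 2.x API (tf.math.betainc and tf.GradientTape are the public counterparts of the internal ops above; the constants are arbitrary):

import tensorflow as tf

a = tf.constant(2.0)
b = tf.constant(3.0)
x = tf.constant(0.25)

with tf.GradientTape() as tape:
    tape.watch(x)
    y = tf.math.betainc(a, b, x)
grad_x = tape.gradient(y, x)

# Closed form: x^(a-1) * (1 - x)^(b-1) / B(a, b), computed in log space
# exactly as in _BetaincGrad.
log_beta = tf.math.lgamma(a) + tf.math.lgamma(b) - tf.math.lgamma(a + b)
expected = tf.exp((b - 1.0) * tf.math.log(1.0 - x) +
                  (a - 1.0) * tf.math.log(x) - log_beta)

# Both should print 1.6875, i.e. 12 * 0.25 * 0.75**2 with B(2, 3) = 1/12.
print(float(grad_x), float(expected))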
Example #2
def __log_surface_area(self):
    # Log surface area of the unit hypersphere S^dim:
    # log 2 + ((dim + 1) / 2) * log(pi) - lgamma((dim + 1) / 2).
    return (math.log(2) + ((self._dim + 1) / 2) * math.log(math.pi) -
            gen_math_ops.lgamma(
                math_ops.cast((self._dim + 1) / 2, dtype=self.dtype)))
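
This expression is log(2 * pi^((d + 1) / 2) / Gamma((d + 1) / 2)), the log surface area of the unit d-sphere. A minimal pure-Python sanity check of the same formula (assuming self._dim is the sphere's dimension d; math.lgamma stands in for the TensorFlow op):

import math

def log_surface_area(dim):
    # Same formula as above; `dim` stands in for self._dim.
    return (math.log(2) + ((dim + 1) / 2) * math.log(math.pi)
            - math.lgamma((dim + 1) / 2))

# S^1 (unit circle): log circumference = log(2 * pi)
print(math.isclose(log_surface_area(1), math.log(2 * math.pi)))  # True
# S^2 (ordinary sphere): log area = log(4 * pi)
print(math.isclose(log_surface_area(2), math.log(4 * math.pi)))  # True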