예제 #1
0
def logpdf(x, loc=0, scale=1):
    """Log-density of a normal distribution with mean ``loc`` and std ``scale``.

    Computes ``-(log(2 * pi * scale**2) + (x - loc)**2 / scale**2) / 2``
    elementwise after promoting the arguments to a common inexact dtype.
    """
    x, loc, scale = _promote_args_inexact("norm.logpdf", x, loc, scale)
    two = _constant_like(x, 2)
    two_pi = _constant_like(x, 2 * np.pi)

    variance = lax.pow(scale, two)
    centered_sq = lax.pow(lax.sub(x, loc), two)

    # log(2 * pi * sigma^2): the normalizing-constant term.
    log_norm_term = lax.log(lax.mul(two_pi, variance))
    # (x - mu)^2 / sigma^2: the squared standardized residual.
    z_sq = lax.div(centered_sq, variance)

    total = lax.add(log_norm_term, z_sq)
    return lax.div(lax.neg(total), two)
예제 #2
0
def pow_right(y, z, ildj_):
    """Invert ``x ** y = z`` for ``x`` when the exponent ``y`` is known.

    Args:
      y: the (known) exponent of the forward power.
      z: the output of the forward power.
      ildj_: the inverse log-det-Jacobian accumulated so far; the log of the
        absolute derivative of this inverse is added to it.

    Returns:
      A pair ``(x, ildj)`` with ``x = z ** (1 / y)`` and
      ``ildj = ildj_ + log|d x / d z|``.
    """
    # x ** y = z
    # x = f^-1(z) = z ** (1 / y)
    # grad(f^-1)(z) = 1 / y * z ** (1 / y - 1)
    # log|grad(f^-1)(z)| = (1 / y - 1) log(z) - log|y|
    # The absolute value matters for negative exponents: e.g. y = -1 gives
    # grad(f^-1)(z) = -1 / z**2, and a plain log(y) would be NaN there.
    y_inv = np.reciprocal(y)
    return lax.pow(z, y_inv), ildj_ + (y_inv - 1.) * np.log(z) - np.log(np.abs(y))
예제 #3
0
def integer_pow_inverse(z, *, y):
    """Inverse for `integer_pow_p` primitive.

    Given ``z = x ** y`` for a nonzero integer ``y``, recover ``x``.

    Odd exponents are bijections on the reals, so a negative ``z`` maps back
    to a negative root. Even exponents are not injective; the principal
    (non-negative) root is returned, consistent with the ``y == 2`` case.

    Raises:
      ValueError: if ``y == 0`` (``x ** 0`` is constant, hence not invertible).
    """
    if y == 0:
        raise ValueError('Cannot invert raising to a value to the 0-th power.')
    elif y == 1:
        return z
    elif y == -1:
        return np.reciprocal(z)
    elif y == 2:
        return np.sqrt(z)
    if y % 2:
        # lax.pow(z, 1/y) is NaN for z < 0 even though x ** y is invertible
        # there for odd y: take the root of |z| and restore the sign of z.
        return np.copysign(lax.pow(np.abs(z), 1. / y), z)
    return lax.pow(z, 1. / y)
예제 #4
0
def _power(x1, x2):
  """Elementwise ``x1 ** x2`` after promoting both operands to one dtype.

  Non-integer dtypes are delegated to ``lax.pow``. Integer dtypes use
  square-and-multiply, consuming one exponent bit per step for 6 steps, so
  only the low 6 bits of the exponent are ever examined.
  NOTE(review): negative integer exponents are not rejected; their bit
  pattern is walked with ``shift_right_logical`` like any other — confirm
  callers do not depend on that case.
  """
  base, exponent = _promote_args("power", x1, x2)
  if not dtypes.issubdtype(dtypes.dtype(base), np.integer):
    return lax.pow(base, exponent)

  # Integer power => binary exponentiation.
  # TODO(phawkins): add integer pow support to XLA.
  zero = _constant_like(exponent, 0)
  one = _constant_like(exponent, 1)

  # 0 ** 0 is 1 but 0 ** k is 0 for k != 0; bake that into the start value.
  zero_base_nonzero_exp = lax.bitwise_and(lax.eq(base, zero),
                                          lax.ne(exponent, zero))
  result = _where(zero_base_nonzero_exp, zero, one)

  num_bits = 6  # Anything more would overflow for any base > 1
  for _ in range(num_bits):
    # Fold in the current power of the base wherever the low bit is set.
    result = _where(lax.bitwise_and(exponent, one),
                    lax.mul(result, base), result)
    base = lax.mul(base, base)
    exponent = lax.shift_right_logical(exponent, one)
  return result
예제 #5
0
파일: lapax.py 프로젝트: zudehuang/jax
def sqrt(x):
    """Elementwise square root of a ``LapaxMatrix``.

    Uses the dedicated ``lax.sqrt`` operation instead of ``lax.pow`` against a
    freshly materialized ``full_like(..., 0.5)`` array, which allocated an
    extra array the size of the input just to hold the constant exponent.

    Args:
      x: a ``LapaxMatrix``; its ``ndarray`` and ``bs`` attributes are read.

    Returns:
      A new ``LapaxMatrix`` wrapping the elementwise square root of
      ``x.ndarray``, with ``x.bs`` passed through unchanged.
    """
    return LapaxMatrix(lax.sqrt(x.ndarray), x.bs)
예제 #6
0
 def f(x, y):
     """Elementwise power ``x ** y`` computed with ``lax.pow``."""
     powered = lax.pow(x, y)
     return powered