예제 #1
0
파일: jet.py 프로젝트: yangliuy/jax
def _pow_taylor(primals_in, series_in):
  """Jet rule for the power ``u ** r``.

  Args:
    primals_in: pair ``(u, r)`` holding the base and exponent primals.
    series_in: Taylor coefficients for the inputs, passed to ``jet`` together
      with ``primals_in``.

  Returns:
    ``(primal_out, series_out)`` where ``primal_out`` is ``u ** r`` and
    ``series_out`` is a list with one entry per coefficient that ``jet``
    returned for ``r * log(u)``.
  """
  base, exponent = primals_in

  # Expand the exponent of the result, r * log(u), with jet first.
  log_primal, log_terms = jet(
      lambda x, y: lax.mul(y, lax.log(x)), primals_in, series_in)
  log_coeffs = [log_primal] + log_terms

  # out_coeffs[0] is the primal; each later entry is built from the entries
  # before it and the coefficients of r * log(u).
  out_coeffs = [base ** exponent]
  for order in range(1, len(log_coeffs)):
    weighted = [_scale(order, j) * out_coeffs[order - j] * log_coeffs[j]
                for j in range(1, order + 1)]
    out_coeffs.append(fact(order - 1) * sum(weighted))

  return out_coeffs[0], out_coeffs[1:]
예제 #2
0
파일: jet.py 프로젝트: nhanwei/jax
def _erf_inv_rule(primals_in, series_in):
    x, = primals_in
    series, = series_in

    u = [x] + series
    primal_out = lax.erf_inv(x)
    v = [primal_out] + [None] * len(series)

    # derivative on co-domain for caching purposes
    deriv_const = np.sqrt(np.pi) / 2.
    deriv_y = lambda y: lax.mul(deriv_const, lax.exp(lax.square(y)))

    # manually propagate through deriv_y since we don't have lazy evaluation of sensitivities

    c = [deriv_y(primal_out)] + [None] * (len(series) - 1)
    tmp_sq = [lax.square(v[0])] + [None] * (len(series) - 1)
    tmp_exp = [lax.exp(tmp_sq[0])] + [None] * (len(series) - 1)
    for k in range(1, len(series)):
        # we know c[:k], we compute c[k]

        # propagate c to get v
        v[k] = fact(k - 1) * sum(
            _scale(k, j) * c[k - j] * u[j] for j in range(1, k + 1))

        # propagate v to get next c

        # square
        tmp_sq[k] = fact(k) * sum(
            _scale2(k, j) * v[k - j] * v[j] for j in range(k + 1))

        # exp
        tmp_exp[k] = fact(k - 1) * sum(
            _scale(k, j) * tmp_exp[k - j] * tmp_sq[j] for j in range(1, k + 1))

        # const
        c[k] = deriv_const * tmp_exp[k]

    # we can't, and don't need, to compute c[k+1], just need to get the last v[k]
    k = len(series)
    v[k] = fact(k - 1) * sum(
        _scale(k, j) * c[k - j] * u[j] for j in range(1, k + 1))

    primal_out, *series_out = v
    return primal_out, series_out
예제 #3
0
파일: jet.py 프로젝트: nhanwei/jax
 def _reduce_chooser_taylor_rule(g):
     # Mask g with location_indicators, sum the masked values over `axes`,
     # then divide by `counts`. All three come from the enclosing scope
     # (outside this excerpt); presumably they mark and count the positions
     # of the selected elements -- confirm against the enclosing function.
     masked = lax.mul(g, location_indicators)
     summed = lax._reduce_sum(masked, axes)
     return lax.div(summed, counts)
예제 #4
0
파일: jet.py 프로젝트: nhanwei/jax
    x, = primals_in
    series, = series_in
    primal_out = prim.bind(x)
    c0, cs = jet(deriv, primals_in, series_in)
    c = [c0] + cs
    u = [x] + series
    v = [primal_out] + [None] * len(series)
    for k in range(1, len(v)):
        v[k] = fact(k - 1) * sum(
            _scale(k, j) * c[k - j] * u[j] for j in range(1, k + 1))
    primal_out, *series_out = v
    return primal_out, series_out


# Hand lax.erf_p and its derivative, 2 / sqrt(pi) * exp(-x ** 2), to def_deriv.
# NOTE(review): def_deriv is defined outside this excerpt; presumably it
# registers a derivative-based jet rule for the primitive -- confirm.
def_deriv(
    lax.erf_p, lambda x: lax.mul(lax._const(x, 2. / np.sqrt(np.pi)),
                                 lax.exp(lax.neg(lax.square(x)))))


def def_comp(prim, comp):
    """Register a jet rule for ``prim`` expressed through ``comp``.

    The rule stored in ``jet_rules`` is ``jet`` with ``comp`` bound as its
    first argument, so ``prim`` is handled by running jet on a composition of
    simpler operations.
    """
    composed_rule = partial(jet, comp)
    jet_rules[prim] = composed_rule


# Primitives whose jet rules are defined by rewriting them as a composition
# of other operations (see def_comp).
def_comp(lax.expm1_p, lambda x: lax.exp(x) - 1)  # expm1(x) = exp(x) - 1
def_comp(lax.log1p_p, lambda x: lax.log(1 + x))  # log1p(x) = log(1 + x)
def_comp(lax.sqrt_p, lambda x: x**0.5)  # sqrt(x) = x ** 0.5
def_comp(lax.rsqrt_p, lambda x: x**-0.5)  # rsqrt(x) = x ** -0.5
# asinh(x) = log(x + sqrt(x ** 2 + 1))
def_comp(lax.asinh_p, lambda x: lax.log(x + lax.sqrt(lax.square(x) + 1)))
# acosh(x) = log(x + sqrt(x ** 2 - 1))
def_comp(lax.acosh_p, lambda x: lax.log(x + lax.sqrt(lax.square(x) - 1)))