Example 1
0
File: jet.py Project: nhanwei/jax
def _erf_inv_rule(primals_in, series_in):
    x, = primals_in
    series, = series_in

    u = [x] + series
    primal_out = lax.erf_inv(x)
    v = [primal_out] + [None] * len(series)

    # derivative on co-domain for caching purposes
    deriv_const = np.sqrt(np.pi) / 2.
    deriv_y = lambda y: lax.mul(deriv_const, lax.exp(lax.square(y)))

    # manually propagate through deriv_y since we don't have lazy evaluation of sensitivities

    c = [deriv_y(primal_out)] + [None] * (len(series) - 1)
    tmp_sq = [lax.square(v[0])] + [None] * (len(series) - 1)
    tmp_exp = [lax.exp(tmp_sq[0])] + [None] * (len(series) - 1)
    for k in range(1, len(series)):
        # we know c[:k], we compute c[k]

        # propagate c to get v
        v[k] = fact(k - 1) * sum(
            _scale(k, j) * c[k - j] * u[j] for j in range(1, k + 1))

        # propagate v to get next c

        # square
        tmp_sq[k] = fact(k) * sum(
            _scale2(k, j) * v[k - j] * v[j] for j in range(k + 1))

        # exp
        tmp_exp[k] = fact(k - 1) * sum(
            _scale(k, j) * tmp_exp[k - j] * tmp_sq[j] for j in range(1, k + 1))

        # const
        c[k] = deriv_const * tmp_exp[k]

    # we can't, and don't need, to compute c[k+1], just need to get the last v[k]
    k = len(series)
    v[k] = fact(k - 1) * sum(
        _scale(k, j) * c[k - j] * u[j] for j in range(1, k + 1))

    primal_out, *series_out = v
    return primal_out, series_out
Example 2
0
File: jet.py Project: yangliuy/jax
def _atan2_taylor(primals_in, series_in):
  """Jet (Taylor-mode) rule for ``lax.atan2``.

  The primal is computed directly with ``lax.atan2``. The higher-order terms
  come from the quotient q = num / den: the derivative of atan2 along the
  input path matches that of atan(q), i.e. 1 / (1 + q**2) times q's
  derivative (assuming den != 0 along the path).

  Returns:
    ``(primal_out, series_out)`` with one output term per input term.
  """
  num, den = primals_in
  result = lax.atan2(num, den)

  # Taylor-propagate the quotient q = num / den.
  quotient, q_terms = jet(lax.div, primals_in, series_in)
  # Series of d/dq atan(q) = 1 / (1 + q**2), evaluated along q's series.
  d0, d_terms = jet(lambda q: lax.div(1, 1 + lax.square(q)),
                    (quotient, ), (q_terms, ))
  deriv = [d0] + d_terms
  inner = [quotient] + q_terms

  out = [result]
  for k in range(1, len(q_terms) + 1):
    # chain rule, expanded to order k
    out.append(fact(k - 1) * sum(
        _scale(k, j) * deriv[k - j] * inner[j] for j in range(1, k + 1)))
  return out[0], out[1:]
Example 3
0
File: jet.py Project: nhanwei/jax
    series, = series_in
    primal_out = prim.bind(x)
    c0, cs = jet(deriv, primals_in, series_in)
    c = [c0] + cs
    u = [x] + series
    v = [primal_out] + [None] * len(series)
    for k in range(1, len(v)):
        v[k] = fact(k - 1) * sum(
            _scale(k, j) * c[k - j] * u[j] for j in range(1, k + 1))
    primal_out, *series_out = v
    return primal_out, series_out


# d/dx erf(x) = 2 / sqrt(pi) * exp(-x**2)
def_deriv(
    lax.erf_p,
    lambda t: lax.mul(
        lax._const(t, 2. / np.sqrt(np.pi)),
        lax.exp(lax.neg(lax.square(t))),
    ),
)


def def_comp(prim, comp):
    """Register a jet rule for ``prim`` that expands it into ``comp``.

    ``comp`` should be an equivalent composition of simpler primitives. The
    registered rule calls ``jet(comp, primals_in, series_in)``, so both the
    primal and the series come from evaluating ``comp``, not ``prim`` itself.
    """
    rule = partial(jet, comp)
    jet_rules[prim] = rule


# expm1, log1p and asinh are registered by their derivatives rather than as
# composites. A composite rule's primal output is the composite expression
# itself, and:
#   - exp(x) - 1 and log(1 + x) lose most of their precision for |x| near 0,
#     which is exactly the regime expm1/log1p exist for;
#   - x + sqrt(x**2 + 1) cancels for large negative x (it can round to 0,
#     giving log(0) = -inf in float32).
# The derivative-based rules presumably take the primal from prim.bind(x),
# as the deriv_prop-style body above does.
def_deriv(lax.expm1_p, lax.exp)  # d/dx expm1(x) = exp(x)
def_deriv(lax.log1p_p, lambda x: 1 / (1 + x))  # d/dx log1p(x) = 1 / (1 + x)
def_comp(lax.sqrt_p, lambda x: x**0.5)
def_comp(lax.rsqrt_p, lambda x: x**-0.5)
# d/dx asinh(x) = 1 / sqrt(x**2 + 1)
def_deriv(lax.asinh_p, lambda x: lax.rsqrt(lax.square(x) + 1))
def_comp(lax.acosh_p, lambda x: lax.log(x + lax.sqrt(lax.square(x) - 1)))
Example 4
0
def _acosh_taylor(primals_in, series_in):
    """Jet rule for acosh via the identity acosh(x) = log(x + sqrt(x**2 - 1)).

    The composite is pushed through ``jet``, so both the primal and the
    series are those of the composite expression.
    """
    def composite(t):
        return lax.log(t + lax.sqrt(lax.square(t) - 1))

    return jet(composite, primals_in, series_in)
Example 5
0
    series, = series_in
    primal_out = prim.bind(x)
    c0, cs = jet(deriv, primals_in, series_in)
    c = [c0] + cs
    u = [x] + series
    v = [primal_out] + [None] * len(series)
    for k in range(1, len(v)):
        v[k] = fact(k - 1) * sum(
            _scale(k, j) * c[k - j] * u[j] for j in range(1, k + 1))
    primal_out, *series_out = v
    return primal_out, series_out


# d/dx erf(x) = 2 / sqrt(pi) * exp(-x**2)
def_deriv(
    lax.erf_p,
    lambda t: lax.mul(
        lax._const(t, 2. / np.sqrt(np.pi)),
        lax.exp(lax.neg(lax.square(t))),
    ),
)


def def_comp(prim, comp):
    """Register a jet rule for ``prim`` that expands it into ``comp``.

    ``comp`` should be an equivalent composition of simpler primitives. The
    registered rule calls ``jet(comp, primals_in, series_in)``, so both the
    primal and the series come from evaluating ``comp``, not ``prim`` itself.
    """
    rule = partial(jet, comp)
    jet_rules[prim] = rule


# erfc is registered by its derivative, not as the composite 1 - erf(x): a
# composite rule's primal output is the composite itself, and 1 - erf(x)
# cancels to 0 for moderately large x, where erfc(x) is tiny but nonzero.
# d/dx erfc(x) = -2 / sqrt(pi) * exp(-x**2)
def_deriv(
    lax.erfc_p, lambda x: lax.mul(lax._const(x, -2. / np.sqrt(np.pi)),
                                  lax.exp(lax.neg(lax.square(x)))))

### More complicated rules


def fact(n):
Example 6
0
File: jet.py Project: yangliuy/jax
def deriv_prop(prim, deriv, primals_in, series_in):
  """Jet rule for a unary primitive, given a function for its first derivative.

  The primal output is ``prim.bind(x)``. ``deriv`` is itself pushed through
  ``jet`` to get the series of the derivative along the input. Each output
  term then follows from the chain rule d/dt f(x(t)) = f'(x(t)) * x'(t),
  expanded order by order with the ``fact`` / ``_scale`` weights.

  Args:
    prim: the primitive whose rule is being evaluated.
    deriv: callable computing the first derivative of ``prim``.
    primals_in: one-element tuple ``(x,)``.
    series_in: one-element tuple ``(series,)`` of higher-order input terms.

  Returns:
    ``(primal_out, series_out)`` with one output term per input term.
  """
  (x,) = primals_in
  (terms,) = series_in
  out = [prim.bind(x)]

  d0, d_terms = jet(deriv, primals_in, series_in)
  deriv_series = [d0] + d_terms
  in_series = [x] + terms

  for k in range(1, len(terms) + 1):
    out.append(fact(k - 1) * sum(
        _scale(k, j) * deriv_series[k - j] * in_series[j]
        for j in range(1, k + 1)))
  return out[0], out[1:]


# d/dx erf(x) = 2 / sqrt(pi) * exp(-x**2)
def_deriv(
    lax.erf_p,
    lambda t: lax.mul(
        lax._const(t, 2. / np.sqrt(np.pi)),
        lax.exp(lax.neg(lax.square(t))),
    ),
)


def def_comp(prim, comp):
  """Register a jet rule for ``prim`` that expands it into ``comp``.

  ``comp`` should be an equivalent composition of simpler primitives. The
  registered rule calls ``jet(comp, primals_in, series_in)``, so both the
  primal and the series come from evaluating ``comp``, not ``prim`` itself.
  """
  rule = partial(jet, comp)
  jet_rules[prim] = rule


# expm1, log1p and asinh are registered by their derivatives rather than as
# composites. A composite rule's primal output is the composite expression
# itself, and:
#   - exp(x) - 1 and log(1 + x) lose most of their precision for |x| near 0,
#     which is exactly the regime expm1/log1p exist for;
#   - x + sqrt(x**2 + 1) cancels for large negative x (it can round to 0,
#     giving log(0) = -inf in float32).
# deriv_prop (used by def_deriv, presumably) takes the primal from
# prim.bind(x), so these primitives keep their accurate implementations.
def_deriv(lax.expm1_p, lax.exp)  # d/dx expm1(x) = exp(x)
def_deriv(lax.log1p_p, lambda x: 1 / (1 + x))  # d/dx log1p(x) = 1 / (1 + x)
def_comp(lax.sqrt_p, lambda x: x ** 0.5)
def_comp(lax.rsqrt_p, lambda x: x ** -0.5)
# d/dx asinh(x) = 1 / sqrt(x**2 + 1)
def_deriv(lax.asinh_p, lambda x: lax.rsqrt(lax.square(x) + 1))
def_comp(lax.acosh_p, lambda x: lax.log(x + lax.sqrt(lax.square(x) - 1)))