Пример #1
0
Файл: jet.py Проект: nhanwei/jax
def _log_taylor(primals_in, series_in):
    """Taylor (jet) rule for ``log``.

    With ``v = log(u)`` we have ``u * v' = u'``.  Differentiating that identity
    ``k - 1`` times and solving for the highest derivative gives

        v_k = (u_k - sum_{j=1}^{k-1} C(k-1, j-1) * v_j * u_{k-j}) / u_0

    where the binomial weight is formed as ``fact(k - 1) * _scale(k, j)``
    (assumes ``_scale(k, j) == 1 / (fact(k - j) * fact(j - 1))`` -- confirm
    against its definition).

    Args:
      primals_in: one-element sequence holding the primal input ``x``.
      series_in: one-element sequence holding the list of series terms of ``x``.

    Returns:
      ``(log(x), terms_out)`` where ``terms_out`` is a list with one output
      term per input term.
    """
    (x,) = primals_in
    (terms_in,) = series_in
    u = [x] + terms_in

    v = [lax.log(x)]
    for k in range(1, len(terms_in) + 1):
        # Lower-order contributions; empty (== 0) for the first term.
        conv = sum(_scale(k, j) * v[j] * u[k - j] for j in range(1, k))
        v.append((u[k] - fact(k - 1) * conv) / u[0])

    return v[0], v[1:]
Пример #2
0
def _pow_taylor(primals_in, series_in):
  """Taylor (jet) rule for ``base ** exponent``.

  Writes ``base ** exponent = exp(w)`` with ``w = exponent * log(base)``.  The
  series of ``w`` comes from running ``jet`` on ``y * log(x)``; the series of
  ``v = exp(w)`` then follows from ``v' = w' * v``:

      v_k = sum_{j=1}^{k} C(k-1, j-1) * v_{k-j} * w_j

  (presumably ``fact(k - 1) * _scale(k, j)`` equals ``C(k-1, j-1)``; confirm
  against the definition of ``_scale``).

  Args:
    primals_in: ``(base, exponent)`` primal inputs.
    series_in: series terms for ``base`` and ``exponent``.

  Returns:
    ``(base ** exponent, terms_out)``.
  """
  base, exponent = primals_in

  # Jet of w = exponent * log(base); w0 is its primal, w_terms its series.
  w0, w_terms = jet(lambda x, y: lax.mul(y, lax.log(x)), primals_in, series_in)
  w = [w0] + w_terms

  # The primal output is computed directly rather than as exp(w0).
  v = [base ** exponent]
  for k in range(1, len(w_terms) + 1):
    acc = sum(_scale(k, j) * v[k - j] * w[j] for j in range(1, k + 1))
    v.append(fact(k - 1) * acc)

  return v[0], v[1:]
Пример #3
0
Файл: jet.py Проект: nhanwei/jax

def _erf_derivative(x):
    """Return d/dx erf(x) = 2 / sqrt(pi) * exp(-x**2), built from lax ops."""
    scale = lax._const(x, 2. / np.sqrt(np.pi))
    gaussian = lax.exp(lax.neg(lax.square(x)))
    return lax.mul(scale, gaussian)


def_deriv(lax.erf_p, _erf_derivative)


def def_comp(prim, comp):
    """Register a jet rule for ``prim`` expressed as a composition.

    ``comp`` should compute the same function as ``prim`` using primitives
    that already have jet rules.  The registered rule is ``jet`` with ``comp``
    bound as its first argument, so calling the rule with
    ``(primals_in, series_in)`` propagates the series through ``comp``.

    Args:
      prim: the primitive whose rule is stored in ``jet_rules``.
      comp: callable taking the primitive's operands positionally.
    """
    rule = partial(jet, comp)
    jet_rules[prim] = rule


# Primitives whose jet rules are obtained by running ``jet`` on an equivalent
# composition of simpler primitives (see ``def_comp``).
def_comp(lax.expm1_p, lambda x: lax.exp(x) - 1)
def_comp(lax.log1p_p, lambda x: lax.log(1 + x))
def_comp(lax.sqrt_p, lambda x: x**0.5)
def_comp(lax.rsqrt_p, lambda x: x**-0.5)
def_comp(lax.asinh_p, lambda x: lax.log(x + lax.sqrt(lax.square(x) + 1)))
def_comp(lax.acosh_p, lambda x: lax.log(x + lax.sqrt(lax.square(x) - 1)))
def_comp(lax.atanh_p, lambda x: 0.5 * lax.log(lax.div(1 + x, 1 - x)))
def_comp(lax.erfc_p, lambda x: 1 - lax.erf(x))


def _rem_comp(x, y):
    """Compose ``lax.rem(x, y)`` as ``x - y * trunc(x / y)``.

    ``lax.rem`` follows C ``fmod`` semantics: the quotient is truncated toward
    zero, so the result takes the sign of the dividend ``x``.  Truncation is
    written as ``sign(q) * floor(|q|)``.  Using ``floor(q)`` alone would give
    Python-style modulo instead.  For example, ``rem(-7, 3)`` would come out
    as ``2`` rather than ``-1``, and the derivative with respect to ``y``
    (``-trunc(x / y)``) would be off by one whenever the signs differ.
    """
    q = x / y  # computed once so its series is propagated only once
    return x - y * (lax.sign(q) * lax.floor(lax.abs(q)))


def_comp(lax.rem_p, _rem_comp)
def_comp(lax.clamp_p, lambda a, x, b: lax.min(lax.max(a, x), b))


def _erf_inv_rule(primals_in, series_in):
    """Jet rule for ``erf_inv`` (presumably -- inferred from the name only).

    NOTE(review): this excerpt appears truncated.  As shown, the body only
    unpacks the single primal ``x`` and its series terms and builds
    ``u = [x] + series``.  It has no ``return`` statement, so as written it
    returns ``None``.  Confirm against the full source before relying on it.
    """
    x, = primals_in
    series, = series_in

    u = [x] + series
Пример #4
0
def _atanh_taylor(primals_in, series_in):
    return jet(lambda x: 0.5 * lax.log(lax.div(1 + x, 1 - x)), primals_in,
               series_in)
Пример #5
0
def _acosh_taylor(primals_in, series_in):
    return jet(lambda x: lax.log(x + lax.sqrt(lax.square(x) - 1)), primals_in,
               series_in)