Esempio n. 1
0
File: jet.py Progetto: nhanwei/jax
    # NOTE(review): the `def` line for this body is missing from the excerpt;
    # the body is the same as `deriv_prop(prim, deriv, primals_in, series_in)`
    # further down. It pushes a truncated Taylor series through the unary
    # primitive `prim`, using `deriv` (a function for prim's derivative, e.g.
    # the erf derivative registered via def_deriv below).
    x, = primals_in  # exactly one primal input (unary primitive)
    series, = series_in  # the list of series terms for that input
    primal_out = prim.bind(x)  # zeroth-order output: the primitive itself
    # Series-expand the derivative at the same point: c0 is its primal value,
    # cs its series terms.
    c0, cs = jet(deriv, primals_in, series_in)
    c = [c0] + cs
    u = [x] + series
    # Output coefficients: slot 0 is the primal, slots 1.. are filled below.
    v = [primal_out] + [None] * len(series)
    # Term k >= 1 is a weighted convolution of the derivative's coefficients
    # c with the input terms u[1..k] (series form of v' = deriv(u) * u').
    # `fact` / `_scale` are defined elsewhere -- presumably a factorial and a
    # combinatorial weight; TODO confirm.
    for k in range(1, len(v)):
        v[k] = fact(k - 1) * sum(
            _scale(k, j) * c[k - j] * u[j] for j in range(1, k + 1))
    primal_out, *series_out = v
    return primal_out, series_out


# erf'(x) = 2 / sqrt(pi) * exp(-x**2); the constant is materialised with x's
# dtype via lax._const before the multiply.
def_deriv(
    lax.erf_p,
    lambda x: lax.mul(
        lax._const(x, 2. / np.sqrt(np.pi)),
        lax.exp(lax.neg(lax.square(x))),
    ),
)


def def_comp(prim, comp):
    """Define the jet rule for ``prim`` as the jet of the composition ``comp``.

    ``comp`` re-expresses ``prim`` using simpler primitives that already have
    jet rules; both the primal output and the series terms are whatever
    ``jet(comp, ...)`` produces.
    """
    jet_rules[prim] = partial(jet, comp)


def _def_comp_exact_primal(prim, comp):
    """Like ``def_comp``, but take the primal output from ``prim`` itself.

    Compositions such as ``exp(x) - 1`` or ``log(1 + x)`` cancel
    catastrophically for small ``|x|`` (in float32, ``exp(1e-10) - 1`` is
    exactly 0), which is why ``expm1``/``log1p`` exist as primitives. The
    composition is still used for the series terms: those come from the
    derivatives (``exp(x)``, ``1 / (1 + x)``), which are well conditioned
    near 0.
    """
    def rule(primals_in, series_in, **params):
        x, = primals_in
        _, series_out = jet(comp, primals_in, series_in)
        # Jet rules may be invoked with the primitive's params as keyword
        # arguments; only ``prim`` understands them, ``comp`` takes none.
        return prim.bind(x, **params), series_out

    jet_rules[prim] = rule


# expm1/log1p: accurate primal from the primitive, series from the composition.
_def_comp_exact_primal(lax.expm1_p, lambda x: lax.exp(x) - 1)
_def_comp_exact_primal(lax.log1p_p, lambda x: lax.log(1 + x))
def_comp(lax.sqrt_p, lambda x: x**0.5)
def_comp(lax.rsqrt_p, lambda x: x**-0.5)
def_comp(lax.asinh_p, lambda x: lax.log(x + lax.sqrt(lax.square(x) + 1)))
Esempio n. 2
0
def _allgather(x, dim, size, axis_name):
    """All-gather ``x`` over ``axis_name`` by stacking along a new axis ``dim``.

    Each participant writes its ``x`` into the slot ``axis_index(axis_name)``
    of a zero buffer that has an extra axis of length ``size`` at ``dim``;
    ``psum`` over ``axis_name`` then adds the buffers together.
    NOTE(review): assumes ``size`` equals the size of ``axis_name``.
    """
    buf_shape = [*x.shape[:dim], size, *x.shape[dim:]]
    buf = lax.full(buf_shape, lax._const(x, 0))
    own_slot = axis_index(axis_name)
    buf = lax.dynamic_update_index_in_dim(buf, x, own_slot, dim)
    return psum(buf, axis_name)
Esempio n. 3
0
def _expand(dim, size, axis_name, x):
    """Embed ``x`` into a zero array with an extra axis of length ``size``.

    The new axis sits at position ``dim``; ``x`` is written at index
    ``axis_index(axis_name)`` along it and every other slot stays zero.
    """
    new_shape = [*x.shape[:dim], size, *x.shape[dim:]]
    zeros = lax.full(new_shape, lax._const(x, 0))
    slot = axis_index(axis_name)
    return lax.dynamic_update_index_in_dim(zeros, x, slot, dim)
Esempio n. 4
0
 def while_body_fun(loop_carry):
     """Run one loop step on the ``(i, x)`` carry.

     The counter is incremented by a constant 1 built with ``lax._const``
     from ``i``, and the state is advanced with ``body_fun`` (a free name
     resolved from the enclosing scope), called with the pre-increment ``i``.
     """
     counter, state = loop_carry
     one = lax._const(counter, 1)
     next_counter = lax.add(counter, one)
     return next_counter, body_fun(counter, state)
Esempio n. 5
0
File: jet.py Progetto: yangliuy/jax
def deriv_prop(prim, deriv, primals_in, series_in):
  """Propagate a truncated Taylor series through the unary primitive ``prim``.

  ``deriv`` computes the derivative of ``prim``; it is itself pushed through
  ``jet`` at the same point, and its coefficients are convolved with the
  input terms to build each output term (series form of v' = deriv(u) * u').

  Returns ``(primal_out, series_out)``: ``prim`` applied to the primal input,
  and a list holding one output term per input term.
  """
  (x,) = primals_in
  (in_terms,) = series_in
  primal_out = prim.bind(x)
  d0, d_terms = jet(deriv, primals_in, series_in)
  d_coeffs = [d0] + d_terms
  x_coeffs = [x] + in_terms
  series_out = []
  for k in range(1, len(in_terms) + 1):
    series_out.append(fact(k - 1) * sum(
        _scale(k, j) * d_coeffs[k - j] * x_coeffs[j] for j in range(1, k + 1)))
  return primal_out, series_out


# erf'(x) = 2 / sqrt(pi) * exp(-x**2), with the constant in x's dtype.
def_deriv(
    lax.erf_p,
    lambda x: lax.mul(
        lax._const(x, 2. / np.sqrt(np.pi)),
        lax.exp(lax.neg(lax.square(x))),
    ),
)


def def_comp(prim, comp):
  """Define the jet rule for ``prim`` as the jet of the composition ``comp``.

  ``comp`` re-expresses ``prim`` using simpler primitives that already have
  jet rules; both the primal output and the series terms are whatever
  ``jet(comp, ...)`` produces.
  """
  jet_rules[prim] = partial(jet, comp)


def _def_comp_exact_primal(prim, comp):
  """Like ``def_comp``, but take the primal output from ``prim`` itself.

  Compositions such as ``exp(x) - 1`` or ``log(1 + x)`` cancel
  catastrophically for small ``|x|`` (in float32, ``exp(1e-10) - 1`` is
  exactly 0), which is why ``expm1``/``log1p`` exist as primitives. The
  composition is still used for the series terms: those come from the
  derivatives (``exp(x)``, ``1 / (1 + x)``), which are well conditioned
  near 0.
  """
  def rule(primals_in, series_in, **params):
    x, = primals_in
    _, series_out = jet(comp, primals_in, series_in)
    # Jet rules may be invoked with the primitive's params as keyword
    # arguments; only ``prim`` understands them, ``comp`` takes none.
    return prim.bind(x, **params), series_out

  jet_rules[prim] = rule


# expm1/log1p: accurate primal from the primitive, series from the composition.
_def_comp_exact_primal(lax.expm1_p, lambda x: lax.exp(x) - 1)
_def_comp_exact_primal(lax.log1p_p, lambda x: lax.log(1 + x))
def_comp(lax.sqrt_p, lambda x: x ** 0.5)
def_comp(lax.rsqrt_p, lambda x: x ** -0.5)
def_comp(lax.asinh_p, lambda x: lax.log(x + lax.sqrt(lax.square(x) + 1)))
def_comp(lax.acosh_p, lambda x: lax.log(x + lax.sqrt(lax.square(x) - 1)))