def _atan2_taylor(primals_in, series_in):
    """Taylor-mode rule for ``lax.atan2``.

    The primal output is ``lax.atan2(x, y)``. The higher-order terms are
    obtained by first propagating the quotient ``x / y`` through ``jet`` and
    then combining its coefficients with those of ``1 / (1 + r**2)``
    evaluated along the same quotient.

    Args:
      primals_in: pair ``(x, y)`` of operands.
      series_in: pair of coefficient lists, one per operand.

    Returns:
      ``(primal_out, series_out)`` where ``series_out`` is a list with one
      entry per term of the propagated quotient series.
    """
    numer, denom = primals_in
    primal_out = lax.atan2(numer, denom)

    # Taylor data of the quotient r = x / y.
    ratio, ratio_terms = jet(lax.div, primals_in, series_in)
    # Taylor data of 1 / (1 + r**2) along that quotient.
    deriv0, deriv_terms = jet(
        lambda r: lax.div(1, 1 + lax.square(r)), (ratio, ), (ratio_terms, ))

    deriv_coeffs = [deriv0] + deriv_terms
    ratio_coeffs = [ratio] + ratio_terms

    series_out = []
    for k in range(1, len(ratio_terms) + 1):
        # Combine the two coefficient lists over j = 1..k with _scale weights.
        weighted = sum(
            _scale(k, j) * deriv_coeffs[k - j] * ratio_coeffs[j]
            for j in range(1, k + 1))
        series_out.append(fact(k - 1) * weighted)
    return primal_out, series_out
def _reduce_chooser_taylor_rule(g):
    """Reduce one series term ``g`` over the indicated locations.

    ``g`` is multiplied elementwise by ``location_indicators``, summed along
    ``axes`` and divided by ``counts``. Those three names are not parameters;
    they are resolved from the surrounding scope.
    """
    masked = lax.mul(g, location_indicators)
    summed = lax._reduce_sum(masked, axes)
    return lax.div(summed, counts)
def def_comp(prim, comp):
    """
    Define the jet rule for a primitive in terms of a composition of simpler
    primitives.

    The rule stored in ``jet_rules[prim]`` is ``partial(jet, comp)``, so
    applying it runs ``jet`` on ``comp`` with whatever arguments the rule is
    called with.

    Args:
      prim: key under which the rule is registered in ``jet_rules``.
      comp: callable expressing the primitive through other operations.
    """
    jet_rules[prim] = partial(jet, comp)


# Rules registered by rewriting each primitive as a composition.
# expm1(x) = exp(x) - 1 and log1p(x) = log(1 + x).
def_comp(lax.expm1_p, lambda x: lax.exp(x) - 1)
def_comp(lax.log1p_p, lambda x: lax.log(1 + x))
# Square root and reciprocal square root written as powers.
def_comp(lax.sqrt_p, lambda x: x**0.5)
def_comp(lax.rsqrt_p, lambda x: x**-0.5)
# Inverse hyperbolic functions in logarithmic form.
def_comp(lax.asinh_p, lambda x: lax.log(x + lax.sqrt(lax.square(x) + 1)))
def_comp(lax.acosh_p, lambda x: lax.log(x + lax.sqrt(lax.square(x) - 1)))
def_comp(lax.atanh_p, lambda x: 0.5 * lax.log(lax.div(1 + x, 1 - x)))
# erfc(x) = 1 - erf(x).
def_comp(lax.erfc_p, lambda x: 1 - lax.erf(x))
# NOTE(review): this composition uses floor(x / y); confirm it agrees with
# lax.rem's sign convention when x and y have opposite signs.
def_comp(lax.rem_p, lambda x, y: x - y * lax.floor(x / y))
# clamp(a, x, b) = min(max(a, x), b).
def_comp(lax.clamp_p, lambda a, x, b: lax.min(lax.max(a, x), b))


def _erf_inv_rule(primals_in, series_in):
    # Exactly one operand and one coefficient list are expected.
    x, = primals_in
    series, = series_in

    # Input coefficients with the primal value placed at index 0.
    u = [x] + series
    primal_out = lax.erf_inv(x)
    # Output coefficients: index 0 is the primal result, the remaining
    # len(series) slots start out as None placeholders.
    v = [primal_out] + [None] * len(series)

    # derivative on co-domain for caching purposes
    # sqrt(pi) / 2 -- presumably the constant factor of
    # d/dx erf_inv(x) = sqrt(pi) / 2 * exp(erf_inv(x)**2); confirm at its use.
    deriv_const = np.sqrt(np.pi) / 2.
def _atanh_taylor(primals_in, series_in):
    """Taylor-mode rule for atanh via its logarithmic form.

    Runs ``jet`` on ``0.5 * log((1 + x) / (1 - x))`` with the given primals
    and series, and returns whatever ``jet`` returns.
    """
    def _atanh_as_log(t):
        quotient = lax.div(1 + t, 1 - t)
        return 0.5 * lax.log(quotient)

    return jet(_atanh_as_log, primals_in, series_in)