Example #1
0
File: jet.py Project: 0x0is1/jax
def _erf_inv_rule(primals_in, series_in):
    x, = primals_in
    series, = series_in

    u = [x] + series
    primal_out = lax.erf_inv(x)
    v = [primal_out] + [None] * len(series)

    # derivative on co-domain for caching purposes
    deriv_const = np.sqrt(np.pi) / 2.
    deriv_y = lambda y: lax.mul(deriv_const, lax.exp(lax.square(y)))

    # manually propagate through deriv_y since we don't have lazy evaluation of sensitivities

    c = [deriv_y(primal_out)] + [None] * (len(series) - 1)
    tmp_sq = [lax.square(v[0])] + [None] * (len(series) - 1)
    tmp_exp = [lax.exp(tmp_sq[0])] + [None] * (len(series) - 1)
    for k in range(1, len(series)):
        # we know c[:k], we compute c[k]

        # propagate c to get v
        v[k] = fact(k - 1) * sum(
            _scale(k, j) * c[k - j] * u[j] for j in range(1, k + 1))

        # propagate v to get next c

        # square
        tmp_sq[k] = fact(k) * sum(
            _scale2(k, j) * v[k - j] * v[j] for j in range(k + 1))

        # exp
        tmp_exp[k] = fact(k - 1) * sum(
            _scale(k, j) * tmp_exp[k - j] * tmp_sq[j] for j in range(1, k + 1))

        # const
        c[k] = deriv_const * tmp_exp[k]

    # we can't, and don't need, to compute c[k+1], just need to get the last v[k]
    k = len(series)
    v[k] = fact(k - 1) * sum(
        _scale(k, j) * c[k - j] * u[j] for j in range(1, k + 1))

    primal_out, *series_out = v
    return primal_out, series_out
Example #2
0
def erfinv(x):
    """Inverse error function of ``x``, computed with ``lax.erf_inv``.

    The argument is first passed through ``_promote_args_inexact`` (defined
    outside this block; presumably it promotes to an inexact/floating dtype).
    """
    (promoted,) = _promote_args_inexact("erfinv", x)
    return lax.erf_inv(promoted)