コード例 #1
0
def test_fixedpoint_vmap():
  """Check that vmapping a fixed-point solve matches solving each input alone."""

  def elem(y):
    """Return the fixed point of the map below for one scalar offset ``y``."""
    def fixed_point(x):
      # First component is pinned to 1.0; the remaining three are twice the
      # first three entries of (x + y).
      head = np.array([1.])
      tail = 2 * lax.slice(x + y, [0], [3])
      return np.concatenate([head, tail])
    return lazy_eval_fixed_point(fixed_point, np.zeros(4))[:]

  ys = np.array([1, 2, 3, 4])
  looped = np.array(list(map(elem, ys)))
  batched = vmap(elem)(ys)
  jtu.check_close(looped, batched)
コード例 #2
0
ファイル: test_transpose.py プロジェクト: duvenaud/jaxde
def test_odeint_linearize_fwrap():
    """Check that ``jvp`` and ``linearize`` agree on a wrapped ``odeint`` call.

    Both transforms are evaluated at ``(y0, ts, fargs)`` and reuse that same
    tuple as the tangent, so the two resulting output tangents must match.
    """
    def odeint_fwrap(y0, ts, fargs):
        return odeint(y0, ts, func=f, fargs=fargs)

    primals = (y0, ts, fargs)

    # Original author's note ("when break this is why"): this jvp call is
    # presumably where a regression first shows up -- confirm.
    _, jvp_tangent = jvp(odeint_fwrap, primals, primals)

    _, linear_fn = linearize(odeint_fwrap, *primals)
    linearized_tangent = linear_fn(*primals)

    check_close(jvp_tangent, linearized_tangent)
コード例 #3
0
def test_jitted_update_fn():
    """Check that one jitted ``svi.update`` step matches the eager step."""
    # Eight 1.0 observations followed by two 0.0 observations.
    data = jnp.array([1.0] * 8 + [0.0] * 2)

    def model(data):
        # Beta(1, 1) prior on the Bernoulli success probability.
        prob = numpyro.sample("beta", dist.Beta(1.0, 1.0))
        numpyro.sample("obs", dist.Bernoulli(prob), obs=data)

    def guide(data):
        # Beta variational family with positive-constrained parameters.
        positive = constraints.positive
        alpha_q = numpyro.param("alpha_q", 1.0, constraint=positive)
        beta_q = numpyro.param("beta_q", 1.0, constraint=positive)
        numpyro.sample("beta", dist.Beta(alpha_q, beta_q))

    svi = SVI(model, guide, optim.Adam(0.05), Trace_ELBO())
    svi_state = svi.init(random.PRNGKey(1), data)

    # Element [0] of update's result is the new SVI state.
    eager_params = svi.get_params(svi.update(svi_state, data)[0])

    jitted_update = jit(svi.update)
    jitted_params = svi.get_params(jitted_update(svi_state, data=data)[0])

    check_close(jitted_params, eager_params, atol=1e-5)
コード例 #4
0
def test_mcmc_progbar():
    """Check that disabling the progress bar does not change MCMC results.

    A reference sampler runs an explicit warmup (key 2) followed by sampling
    (key 3). A second sampler with ``progress_bar=False`` first runs with a
    different key schedule and is expected to differ. It then repeats the
    reference warmup/run sequence and must reproduce the reference samples
    and warmup state.
    """
    true_mean, true_std = 1., 2.
    num_warmup, num_samples = 10, 10

    def model(data):
        # NOTE(review): mean/std are declared with numpyro.param, not
        # numpyro.sample; this relies on the installed MCMC backend treating
        # param sites as latent variables, which is version-dependent.
        mean = numpyro.param('mean', 0.)
        std = numpyro.param('std', 1., constraint=constraints.positive)
        return numpyro.sample('obs', dist.Normal(mean, std), obs=data)

    data = dist.Normal(true_mean, true_std).sample(random.PRNGKey(1), (2000, ))
    kernel = NUTS(model=model)
    # Pass sample counts by keyword: newer NumPyro makes them keyword-only,
    # and the keyword form also works with the older positional signature.
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
    mcmc.warmup(random.PRNGKey(2), data)
    mcmc.run(random.PRNGKey(3), data)
    mcmc1 = MCMC(kernel,
                 num_warmup=num_warmup,
                 num_samples=num_samples,
                 progress_bar=False)
    # run() alone with key 2 follows a different key schedule than the
    # reference warmup(2) + run(3), so the samples should not match.
    mcmc1.run(random.PRNGKey(2), data)

    with pytest.raises(AssertionError):
        check_close(mcmc1.get_samples(),
                    mcmc.get_samples(),
                    atol=1e-4,
                    rtol=1e-4)
    # Repeat the reference sequence with the same keys; results must match.
    mcmc1.warmup(random.PRNGKey(2), data)
    mcmc1.run(random.PRNGKey(3), data)
    check_close(mcmc1.get_samples(), mcmc.get_samples(), atol=1e-4, rtol=1e-4)
    # NOTE(review): _warmup_state is private; the newer variant of this test
    # uses the public post_warmup_state -- confirm for the installed version.
    check_close(mcmc1._warmup_state, mcmc._warmup_state, atol=1e-4, rtol=1e-4)
コード例 #5
0
def test_mcmc_progbar():
    """Check that turning off the progress bar leaves MCMC output unchanged.

    The reference sampler does warmup with key 2 and sampling with key 3.
    A quiet sampler (``progress_bar=False``) that only calls ``run`` with
    key 2 must differ. Replaying the reference key schedule must reproduce
    both the samples and the post-warmup state.
    """
    true_mean, true_std = 1.0, 2.0
    num_warmup, num_samples = 10, 10
    tol = dict(atol=1e-4, rtol=1e-4)

    def model(data):
        # Both priors are wrapped in .mask(False), so their log-density
        # terms are excluded from the joint.
        mu = numpyro.sample("mean", dist.Normal(0, 1).mask(False))
        sigma = numpyro.sample("std", dist.LogNormal(0, 1).mask(False))
        return numpyro.sample("obs", dist.Normal(mu, sigma), obs=data)

    data = dist.Normal(true_mean, true_std).sample(random.PRNGKey(1), (2000, ))
    kernel = NUTS(model=model)
    settings = dict(num_warmup=num_warmup, num_samples=num_samples)

    reference = MCMC(kernel, **settings)
    reference.warmup(random.PRNGKey(2), data)
    reference.run(random.PRNGKey(3), data)

    quiet = MCMC(kernel, progress_bar=False, **settings)
    quiet.run(random.PRNGKey(2), data)

    with pytest.raises(AssertionError):
        check_close(quiet.get_samples(), reference.get_samples(), **tol)

    quiet.warmup(random.PRNGKey(2), data)
    quiet.run(random.PRNGKey(3), data)
    check_close(quiet.get_samples(), reference.get_samples(), **tol)
    check_close(quiet.post_warmup_state, reference.post_warmup_state, **tol)
コード例 #6
0
    def test_richardson_solve(self):
        """Check that richardson_solve produces correct outputs and derivatives.

        The iterative solution, its JVP and its VJP are each compared against
        ``jax.scipy.linalg.solve`` on the same random ``n x n`` system.
        """
        n = 50

        # Ensure we converge to the fixed point: I - matrix equals
        # 0.9 * M / colsum(|M|), so every absolute column sum of I - matrix
        # is 0.9 (< 1).
        matrix = jax.random.normal(jax.random.PRNGKey(0), (n, n))
        matrix = jnp.eye(n) - 0.9 * matrix / jnp.sum(
            jnp.abs(matrix), axis=0, keepdims=True)
        b = jax.random.normal(jax.random.PRNGKey(1), (n, ))

        def iter_solve(matrix, b):
            return linear_solvers.richardson_solve(lambda x: matrix @ x,
                                                   b,
                                                   iterations=100)

        # Correct output
        jtu.check_close(iter_solve(matrix, b),
                        jax.scipy.linalg.solve(matrix, b),
                        rtol=1e-4)

        # Correct jvp
        dmatrix = jax.random.normal(jax.random.PRNGKey(2), (n, n))
        db = jax.random.normal(jax.random.PRNGKey(3), (n, ))
        jtu.check_close(jax.jvp(iter_solve, (matrix, b), (dmatrix, db)),
                        jax.jvp(jax.scipy.linalg.solve, (matrix, b),
                                (dmatrix, db)),
                        rtol=1e-4)

        # Correct vjp. Use a fresh key: PRNGKey(3) already produced `db` with
        # the same shape, so reusing it would make co_x identical to db.
        co_x = jax.random.normal(jax.random.PRNGKey(4), (n, ))
        jtu.check_close(jax.vjp(iter_solve, matrix, b)[1](co_x),
                        jax.vjp(jax.scipy.linalg.solve, matrix, b)[1](co_x),
                        rtol=1e-4)
コード例 #7
0
    def test_richardson_solve_structured(self):
        """Check that richardson_solve handles pytree-structured systems.

        The operator is block diagonal over the pytree ``(x1, {"foo", "bar"})``,
        with one 50x50 matrix per leaf. The iterative solution is compared
        with solving each block directly.
        """
        def structured_matvec(m1, m2, m3, xs):
            x1, x2 = xs
            return m1 @ x1, {"foo": m2 @ x2["foo"], "bar": m3 @ x2["bar"]}

        def structured_direct_solve(m1, m2, m3, b1, b2):
            solve = jax.scipy.linalg.solve
            return solve(m1, b1), {
                "foo": solve(m2, b2["foo"]),
                "bar": solve(m3, b2["bar"]),
            }

        def structured_iter_solve(m1, m2, m3, b1, b2):
            matvec = functools.partial(structured_matvec, m1, m2, m3)
            return linear_solvers.richardson_solve(matvec, (b1, b2),
                                                   iterations=200)

        def mk_mat(key):
            # I - 0.9 * M / colsum(|M|): absolute column sums of I - result
            # are 0.9, keeping the iteration contractive.
            raw = jax.random.normal(jax.random.PRNGKey(key), (50, 50))
            col_norms = jnp.sum(jnp.abs(raw), axis=0, keepdims=True)
            return jnp.eye(50) - 0.9 * raw / col_norms

        m1, m2, m3 = (mk_mat(k) for k in range(3))
        b1 = jax.random.normal(jax.random.PRNGKey(3), (50, ))
        b2 = {
            "foo": jax.random.normal(jax.random.PRNGKey(4), (50, )),
            "bar": jax.random.normal(jax.random.PRNGKey(5), (50, )),
        }
        jtu.check_close(structured_iter_solve(m1, m2, m3, b1, b2),
                        structured_direct_solve(m1, m2, m3, b1, b2))
コード例 #8
0
ファイル: core_test.py プロジェクト: tomhennigan/jax
 def test_jit(self, f, args):
   """Check that jit-compiling ``f`` does not change its result on ``args``."""
   compiled = jit(f)
   jtu.check_close(compiled(*args), f(*args))
コード例 #9
0
ファイル: test_transpose.py プロジェクト: duvenaud/jaxde
    # Wrap odeint with tight tolerances so jvp and linearize can be compared.
    def odeint2(y0, ts, fargs):
        return odeint(f, y0, ts, fargs, atol=1e-8, rtol=1e-8)

    odeint2_prim = custom_transforms(odeint2).primitive

    # Tuple parameters (``def g((a, b), (c, d))``) were removed in Python 3
    # (PEP 3113) and are a SyntaxError there. Take the primal and tangent
    # tuples as plain parameters and unpack them explicitly instead; this
    # form is valid on both Python 2 and 3.
    def odeint2_jvp(primals, tangents):
        y0, ts, fargs = primals
        tan_y, tan_ts, tan_fargs = tangents
        return jvp_odeint(f, (y0, ts, fargs), (tan_y, tan_ts, tan_fargs))

    # NOTE(review): jax's ad.defjvp registers one rule per positional input;
    # a single (primals, tangents) rule looks like the ad.primitive_jvps
    # contract instead -- confirm against the pinned JAX version.
    ad.defjvp(odeint2_prim, odeint2_jvp)

    # Original author's note ("when break this is why"): this jvp call is
    # presumably where a regression first shows up -- confirm.
    _, out_tangent = jvp(odeint2, (y0, ts, fargs),
                         (y0, ts, fargs))
    _, f_jvp = linearize(odeint2, *(y0, ts, fargs))
    out_tangent_2 = f_jvp(*(y0, ts, fargs))

    check_close(out_tangent, out_tangent_2)


def test_odeint_linearize_fwrap():
    """Check that ``jvp`` and ``linearize`` agree on a wrapped ``odeint`` call.

    Both transforms are evaluated at ``(y0, ts, fargs)`` and reuse that same
    tuple as the tangent, so the two resulting output tangents must match.
    """
    def odeint_fwrap(y0, ts, fargs):
        return odeint(y0, ts, func=f, fargs=fargs)

    primals = (y0, ts, fargs)

    # Original author's note ("when break this is why"): this jvp call is
    # presumably where a regression first shows up -- confirm.
    _, jvp_tangent = jvp(odeint_fwrap, primals, primals)

    _, linear_fn = linearize(odeint_fwrap, *primals)
    linearized_tangent = linear_fn(*primals)

    check_close(jvp_tangent, linearized_tangent)
コード例 #10
0
def test_pickle_hmcecs():
    """Check that a fitted HMCECS-wrapped MCMC survives a pickle round trip."""
    sampler = HMCECS(NUTS(logistic_regression))
    mcmc = MCMC(sampler, num_warmup=10, num_samples=10)
    mcmc.run(random.PRNGKey(0))
    # The payload is produced in-process just above, so loading it is safe.
    restored = pickle.loads(pickle.dumps(mcmc))
    test_util.check_close(mcmc.get_samples(), restored.get_samples())
コード例 #11
0
def test_pickle_discrete_hmc(kernel):
    """Check that an MCMC using ``kernel`` around HMC survives pickling.

    Args:
        kernel: callable that wraps an HMC inner kernel (supplied by the
            test's parametrization, not visible here).
    """
    sampler = kernel(HMC(bernoulli_model))
    mcmc = MCMC(sampler, num_warmup=10, num_samples=10)
    mcmc.run(random.PRNGKey(0))
    # The payload is produced in-process just above, so loading it is safe.
    restored = pickle.loads(pickle.dumps(mcmc))
    test_util.check_close(mcmc.get_samples(), restored.get_samples())
コード例 #12
0
def test_pickle_hmc(kernel):
    """Check that an MCMC built from ``kernel(normal_model)`` survives pickling.

    Args:
        kernel: kernel class/factory applied to ``normal_model`` (supplied by
            the test's parametrization, not visible here).
    """
    # Pass sample counts by keyword, like the sibling pickle tests: newer
    # NumPyro makes num_warmup/num_samples keyword-only, so positional
    # ``MCMC(..., 10, 10)`` raises TypeError there.
    mcmc = MCMC(kernel(normal_model), num_warmup=10, num_samples=10)
    mcmc.run(random.PRNGKey(0))
    # The payload is produced in-process just above, so loading it is safe.
    pickled_mcmc = pickle.loads(pickle.dumps(mcmc))
    test_util.check_close(mcmc.get_samples(), pickled_mcmc.get_samples())