Ejemplo n.º 1
0
def test_natural_normal():
    """Check that `NaturalNormal` agrees with the `Normal` it was built from.

    Compares properties, empirical sample moments, KL divergences and
    log-densities between the two parametrisations.
    """
    num_samples = 1_000_000

    factor = B.randn(2, 2)
    dist = Normal(B.randn(2, 1), B.reg(factor @ factor.T, diag=1e-1))
    nat = NaturalNormal.from_normal(dist)

    # Both parametrisations must expose the same properties.
    assert dist.dtype == nat.dtype
    for attr in ("dim", "mean", "var", "m2"):
        approx(getattr(dist, attr), getattr(nat, attr))

    # Empirical moments of samples drawn from the natural form must match
    # the moments of the original distribution.
    state = B.create_random_state(dist.dtype, seed=0)
    state, samples = nat.sample(state, num=num_samples)
    sample_mean = B.mean(B.dense(samples), axis=1, squeeze=False)
    centred = samples - sample_mean
    sample_var = centred @ centred.T / num_samples
    approx(dist.mean, sample_mean, rtol=5e-2)
    approx(dist.var, sample_var, rtol=5e-2)

    # KL divergence against a second random distribution.
    factor = B.randn(2, 2)
    other_dist = Normal(B.randn(2, 1), B.reg(factor @ factor.T, diag=1e-2))
    other_nat = NaturalNormal.from_normal(other_dist)
    approx(dist.kl(other_dist), nat.kl(other_nat))

    # Log-density at a random point.
    point = B.randn(2, 1)
    approx(dist.logpdf(point), nat.logpdf(point))
Ejemplo n.º 2
0
def test_random_generators(f, t, dtype_transform, just_single_arg,
                           check_lazy_shapes):
    """Check the dtype and shape of samples produced by the generator `f`.

    `f` is called in each of these forms:

    - with only a size;
    - with an explicit dtype `t`;
    - with a random state and dtype;
    - with a reference tensor;
    - with a random state and a reference tensor.

    If `just_single_arg` is true, only the one-dimensional size `(2,)` is
    exercised and the reference-tensor forms are skipped.
    """
    # Sizes to pass to `f`, in the order the checks are run.
    sizes = [(2,)] if just_single_arg else [(), (2,), (2, 3)]

    def check_sample(sample, dtype, shape):
        # The sample must carry the transformed dtype and the requested shape.
        assert B.dtype(sample) is dtype_transform(dtype)
        assert B.shape(sample) == shape

    def check_stateful(result, dtype, shape):
        # Stateful calls return the new state first and the sample second.
        assert isinstance(result[0], B.RandomState)
        check_sample(result[1], dtype, shape)

    # Test without specifying data type.
    for size in sizes:
        check_sample(f(*size), B.default_dtype, size)

    # Test with specifying data type.
    state = B.create_random_state(t, 0)

    # Test direct specification.
    for size in sizes:
        check_sample(f(t, *size), t, size)

    # Test state specification.
    for size in sizes:
        check_stateful(f(state, t, *size), t, size)

    if not just_single_arg:
        # Test reference specification. The reference is always built with
        # dtype `t`, including the scalar case (previously `f(f())` was used
        # for the scalar shape check, unlike every other reference check).
        for size in sizes:
            check_sample(f(f(t, *size)), t, size)

        # Test state and reference specification.
        for size in sizes:
            check_stateful(f(state, f(t, *size)), t, size)
Ejemplo n.º 3
0
def test_create_random_state(dtype):
    """Check that seeded random states give reproducible streams."""
    # Creating a state without passing a seed must succeed.
    B.create_random_state(dtype)

    def draw_two():
        # Draw two consecutive values from a freshly created, seed-0 state.
        rng = B.create_random_state(dtype, seed=0)
        rng, first = B.rand(rng, dtype)
        rng, second = B.rand(rng, dtype)
        return to_np(first), to_np(second)

    x1, x2 = draw_two()
    y1, y2 = draw_two()

    # Consecutive draws differ, but the same seed replays the same stream.
    assert x1 != x2
    assert x1 == y1
    assert x2 == y2
Ejemplo n.º 4
0
def test_normal_sampling():
    """Check sample moments and seeded reproducibility of `Normal.sample`."""
    for mean in [0, 1]:
        dist = Normal(mean, 3 * B.eye(np.int32, 200))

        # Expected sample variance is 3 without noise and 5 with `noise=2`.
        for extra_kwargs, expected_var in [({}, 3), ({"noise": 2}, 5)]:
            samples = dist.sample(2000, **extra_kwargs)
            approx(B.mean(samples), mean, atol=5e-2)
            approx(B.std(samples) ** 2, expected_var, atol=5e-2)

        # Sampling from identically seeded states must give identical samples.
        dtype = B.dtype(dist)
        _, sample1 = dist.sample(B.create_random_state(dtype, seed=0))
        state, sample2 = dist.sample(B.create_random_state(dtype, seed=0))
        assert isinstance(state, B.RandomState)
        approx(sample1, sample2)
Ejemplo n.º 5
0
def test_set_seed_set_global_random_state(dtype, f_plain, check_lazy_shapes):
    """Check that reseeding or resetting the global state replays draws."""

    def draw():
        # Draw one value with `B.rand` and then one with `f_plain`. Neither
        # call is given a state, so both presumably use the global state.
        return to_np(B.rand(dtype)), to_np(f_plain())

    # Reset through `set_random_seed`.
    B.set_random_seed(0)
    x1, x2 = draw()
    B.set_random_seed(0)
    y1, y2 = draw()
    assert x1 == y1
    assert x2 == y2

    # Reset through `set_global_random_state`.
    B.set_global_random_state(B.create_random_state(dtype, seed=0))
    x1, x2 = draw()
    B.set_global_random_state(B.create_random_state(dtype, seed=0))
    y1, y2 = draw()
    assert x1 == y1
    # TODO: Make this work with TF!
    if not isinstance(dtype, B.TFDType):
        assert x2 == y2
Ejemplo n.º 6
0
def test_choice(x, p, check_lazy_shapes):
    """Check output shapes and sampled support of `B.choice`."""
    dtype = B.dtype(x)
    state = B.create_random_state(dtype)

    # Weights are optional. When given, they are cast to `x`'s dtype and
    # passed as the `p` keyword; otherwise no keyword is passed at all.
    weight_kwargs = {} if p is None else {"p": B.cast(dtype, p)}
    trailing = B.shape(x)[1:]

    # Shapes without an explicit state.
    assert B.shape(B.choice(x, **weight_kwargs)) == trailing
    assert B.shape(B.choice(x, 5, **weight_kwargs)) == (5, ) + trailing
    assert B.shape(B.choice(x, 5, 5, **weight_kwargs)) == (5, 5) + trailing

    # Shapes with an explicit state; the state is returned first.
    assert isinstance(B.choice(state, x, **weight_kwargs)[0], B.RandomState)
    assert B.shape(B.choice(state, x, **weight_kwargs)[1]) == trailing
    assert (B.shape(B.choice(state, x, 5, **weight_kwargs)[1])
            == (5, ) + trailing)
    assert (B.shape(B.choice(state, x, 5, 5, **weight_kwargs)[1])
            == (5, 5) + trailing)

    # With 1000 draws from five values, every value is expected to appear.
    values = B.range(dtype, 5)
    assert set(to_np(B.choice(values, 1000))) == set(to_np(values))
Ejemplo n.º 7
0
def test_framework_jax(t, check_lazy_shapes):
    """Check that JAX arrays, dtypes, random states and devices match `t`."""
    array = jnp.asarray(1)
    assert isinstance(array, t)
    assert isinstance(jnp.float32, t)
    rng = B.create_random_state(jnp.float32)
    assert isinstance(rng, t)
    assert isinstance(B.device(array), t)
Ejemplo n.º 8
0
def test_framework_torch(t, check_lazy_shapes):
    """Check that PyTorch tensors, dtypes, random states and devices match `t`."""
    tensor = torch.tensor(1)
    assert isinstance(tensor, t)
    assert isinstance(torch.float32, t)
    rng = B.create_random_state(torch.float32)
    assert isinstance(rng, t)
    assert isinstance(B.device(tensor), t)
Ejemplo n.º 9
0
def test_framework_tf(t, check_lazy_shapes):
    """Check that TensorFlow tensors, dtypes, random states and devices match `t`."""
    tensor = tf.constant(1)
    assert isinstance(tensor, t)
    assert isinstance(tf.float32, t)
    rng = B.create_random_state(tf.float32)
    assert isinstance(rng, t)
    assert isinstance(B.device(tensor), t)
Ejemplo n.º 10
0
def test_framework_ag(t, check_lazy_shapes):
    """Check that boxed arrays, NumPy dtypes and random states match `t`."""
    boxed = autograd_box(np.array(1))
    assert isinstance(boxed, t)
    assert isinstance(np.float32, t)
    rng = B.create_random_state(np.float32)
    assert isinstance(rng, t)
Ejemplo n.º 11
0
def test_framework_np(t, check_lazy_shapes):
    """Check that NumPy arrays, dtypes and random states match `t`."""
    array = np.array(1)
    assert isinstance(array, t)
    assert isinstance(np.float32, t)
    rng = B.create_random_state(np.float32)
    assert isinstance(rng, t)
Ejemplo n.º 12
0
def test_random_state(t, FWRandomState, check_lazy_shapes):
    """Check that a random state created for `t` is a `FWRandomState`."""
    rng = B.create_random_state(t)
    assert isinstance(rng, FWRandomState)