def test_natural_normal():
    """Check that `NaturalNormal.from_normal` agrees with the source `Normal`.

    A random `Normal` with a 2x1 mean and a regularised 2x2 variance is
    converted with `NaturalNormal.from_normal`. The two objects are then
    compared on their properties, on the empirical moments of samples drawn
    from the converted one, on KL divergences, and on log-pdfs.
    """
    num_samples = 1_000_000

    # Reference distribution and its converted counterpart.
    factor = B.randn(2, 2)
    dist = Normal(B.randn(2, 1), B.reg(factor @ factor.T, diag=1e-1))
    nat = NaturalNormal.from_normal(dist)

    # Properties must coincide.
    assert dist.dtype == nat.dtype
    approx(dist.dim, nat.dim)
    approx(dist.mean, nat.mean)
    approx(dist.var, nat.var)
    approx(dist.m2, nat.m2)

    # Empirical moments of the samples must match the reference moments.
    state = B.create_random_state(dist.dtype, seed=0)
    state, sample = nat.sample(state, num=num_samples)
    emp_mean = B.mean(B.dense(sample), axis=1, squeeze=False)
    centred = sample - emp_mean
    emp_var = centred @ centred.T / num_samples
    approx(dist.mean, emp_mean, rtol=5e-2)
    approx(dist.var, emp_var, rtol=5e-2)

    # KL divergences to a second random distribution must coincide.
    factor = B.randn(2, 2)
    other_dist = Normal(B.randn(2, 1), B.reg(factor @ factor.T, diag=1e-2))
    other_nat = NaturalNormal.from_normal(other_dist)
    approx(dist.kl(other_dist), nat.kl(other_nat))

    # Log-pdfs at a random point must coincide.
    point = B.randn(2, 1)
    approx(dist.logpdf(point), nat.logpdf(point))
def test_random_generators(f, t, dtype_transform, just_single_arg, check_lazy_shapes):
    """Check data types and shapes produced by the random generator `f`.

    The generator is called in every supported way: without a data type,
    with the data type `t`, with a random state plus `t`, with a reference
    tensor, and with a random state plus a reference tensor.

    Args:
        f: Random generator under test.
        t: Data type to request explicitly.
        dtype_transform: Maps the requested data type to the data type the
            output is expected to have.
        just_single_arg: If true, `f` is only called with exactly one size
            argument, and all reference-based calls are skipped.
        check_lazy_shapes: Not used in the body; presumably a pytest fixture.
    """
    # Size arguments to try. The expected output shape equals the size tuple.
    if just_single_arg:
        sizes = [(2,)]
    else:
        sizes = [(), (2,), (2, 3)]

    # Test without specifying data type.
    for size in sizes:
        assert B.dtype(f(*size)) is dtype_transform(B.default_dtype)
        assert B.shape(f(*size)) == size

    # Test with specifying data type.
    state = B.create_random_state(t, 0)

    # Test direct specification.
    for size in sizes:
        assert B.dtype(f(t, *size)) is dtype_transform(t)
        assert B.shape(f(t, *size)) == size

    # Test state specification.
    for size in sizes:
        assert isinstance(f(state, t, *size)[0], B.RandomState)
        assert B.dtype(f(state, t, *size)[1]) is dtype_transform(t)
        assert B.shape(f(state, t, *size)[1]) == size

    if not just_single_arg:
        # Test reference specification. The reference is always drawn with
        # data type `t`, so the data type and shape checks exercise the same
        # call path, including for the scalar reference.
        for size in sizes:
            assert B.dtype(f(f(t, *size))) is dtype_transform(t)
            assert B.shape(f(f(t, *size))) == size

        # Test state and reference specification.
        for size in sizes:
            assert isinstance(f(state, f(t, *size))[0], B.RandomState)
            assert B.dtype(f(state, f(t, *size))[1]) is dtype_transform(t)
            assert B.shape(f(state, f(t, *size))[1]) == size
def test_create_random_state(dtype):
    """Check that `B.create_random_state` gives reproducible, advancing states.

    Two states created with the same seed must produce the same sequence of
    draws, and consecutive draws from one state must differ.
    """

    def two_draws():
        # Draw twice from a fresh state with seed zero, threading the state.
        state = B.create_random_state(dtype, seed=0)
        state, first = B.rand(state, dtype)
        state, second = B.rand(state, dtype)
        return to_np(first), to_np(second)

    # Test specification without argument.
    B.create_random_state(dtype)

    # Check that it does the right thing.
    x1, x2 = two_draws()
    y1, y2 = two_draws()

    assert x1 != x2
    assert x1 == y1
    assert x2 == y2
def test_normal_sampling():
    """Check the sample statistics of `Normal.sample`.

    For means zero and one, samples are drawn with and without extra noise.
    Their empirical mean and variance are compared against the expected
    values. Sampling with two identically seeded random states must return
    a random state and identical samples.
    """
    num_samples = 2000
    var = 3
    noise = 2
    tol = 5e-2

    for mean in (0, 1):
        dist = Normal(mean, var * B.eye(np.int32, 200))

        # Sample without noise.
        samples = dist.sample(num_samples)
        approx(B.mean(samples), mean, atol=tol)
        approx(B.std(samples) ** 2, var, atol=tol)

        # Sample with noise. The test expects the variances to add up.
        samples = dist.sample(num_samples, noise=noise)
        approx(B.mean(samples), mean, atol=tol)
        approx(B.std(samples) ** 2, var + noise, atol=tol)

        # Sample with an explicit random state, twice with the same seed.
        state, sample1 = dist.sample(B.create_random_state(B.dtype(dist), seed=0))
        state, sample2 = dist.sample(B.create_random_state(B.dtype(dist), seed=0))
        assert isinstance(state, B.RandomState)
        approx(sample1, sample2)
def test_set_seed_set_global_random_state(dtype, f_plain, check_lazy_shapes):
    """Check that resetting the global randomness makes draws reproducible.

    Both `B.set_random_seed` and `B.set_global_random_state` are exercised.
    After each reset, a draw from `B.rand(dtype)` and a draw from `f_plain()`
    are compared against the draws obtained after an identical reset.
    """

    def draw_pair(reset):
        # Reset the global randomness, then draw from `B.rand` and `f_plain`.
        reset()
        return to_np(B.rand(dtype)), to_np(f_plain())

    def reseed():
        B.set_random_seed(0)

    def reset_global_state():
        B.set_global_random_state(B.create_random_state(dtype, seed=0))

    # Resetting via the seed.
    x1, x2 = draw_pair(reseed)
    y1, y2 = draw_pair(reseed)
    assert x1 == y1
    assert x2 == y2

    # Resetting via the global random state.
    x1, x2 = draw_pair(reset_global_state)
    y1, y2 = draw_pair(reset_global_state)
    assert x1 == y1
    # TODO: Make this work with TF!
    if not isinstance(dtype, B.TFDType):
        assert x2 == y2
def test_choice(x, p, check_lazy_shapes):
    """Check the output shapes and the coverage of `B.choice`.

    Args:
        x: Tensor to choose from.
        p: Optional weights. When not `None`, they are cast to the data type
            of `x` and passed as the keyword argument `p`.
        check_lazy_shapes: Not used in the body; presumably a pytest fixture.
    """
    dtype = B.dtype(x)
    state = B.create_random_state(dtype)

    # Pass the weights through a dictionary so that they can be left out.
    # When given, they are first cast to the right framework.
    kwargs = {} if p is None else {"p": B.cast(dtype, p)}

    # All dimensions of `x` after the first one are retained in the output.
    tail = B.shape(x)[1:]
    sample_counts = [(5,), (5, 5)]

    # Check shape without a random state.
    assert B.shape(B.choice(x, **kwargs)) == tail
    for counts in sample_counts:
        assert B.shape(B.choice(x, *counts, **kwargs)) == counts + tail

    # Check shape with a random state.
    assert isinstance(B.choice(state, x, **kwargs)[0], B.RandomState)
    assert B.shape(B.choice(state, x, **kwargs)[1]) == tail
    for counts in sample_counts:
        assert B.shape(B.choice(state, x, *counts, **kwargs)[1]) == counts + tail

    # Check correctness: many draws should hit every available value.
    drawn = set(to_np(B.choice(B.range(dtype, 5), 1000)))
    assert drawn == set(to_np(B.range(dtype, 5)))
def test_framework_jax(t, check_lazy_shapes):
    """Check that JAX objects are recognised as instances of `t`.

    Covers an array, a data type, a random state created for a JAX data
    type, and the result of `B.device` for a JAX array.
    """
    array = jnp.asarray(1)
    assert isinstance(array, t)

    assert isinstance(jnp.float32, t)

    state = B.create_random_state(jnp.float32)
    assert isinstance(state, t)

    device = B.device(jnp.asarray(1))
    assert isinstance(device, t)
def test_framework_torch(t, check_lazy_shapes):
    """Check that PyTorch objects are recognised as instances of `t`.

    Covers a tensor, a data type, a random state created for a PyTorch data
    type, and the result of `B.device` for a PyTorch tensor.
    """
    tensor = torch.tensor(1)
    assert isinstance(tensor, t)

    assert isinstance(torch.float32, t)

    state = B.create_random_state(torch.float32)
    assert isinstance(state, t)

    device = B.device(torch.tensor(1))
    assert isinstance(device, t)
def test_framework_tf(t, check_lazy_shapes):
    """Check that TensorFlow objects are recognised as instances of `t`.

    Covers a constant tensor, a data type, a random state created for a
    TensorFlow data type, and the result of `B.device` for a constant tensor.
    """
    tensor = tf.constant(1)
    assert isinstance(tensor, t)

    assert isinstance(tf.float32, t)

    state = B.create_random_state(tf.float32)
    assert isinstance(state, t)

    device = B.device(tf.constant(1))
    assert isinstance(device, t)
def test_framework_ag(t, check_lazy_shapes):
    """Check that autograd-related objects are recognised as instances of `t`.

    Covers the value returned by `autograd_box` for a NumPy array, a NumPy
    data type, and a random state created for a NumPy data type.
    """
    boxed = autograd_box(np.array(1))
    assert isinstance(boxed, t)

    assert isinstance(np.float32, t)

    state = B.create_random_state(np.float32)
    assert isinstance(state, t)
def test_framework_np(t, check_lazy_shapes):
    """Check that NumPy objects are recognised as instances of `t`.

    Covers an array, a data type, and a random state created for a NumPy
    data type.
    """
    array = np.array(1)
    assert isinstance(array, t)

    assert isinstance(np.float32, t)

    state = B.create_random_state(np.float32)
    assert isinstance(state, t)
def test_random_state(t, FWRandomState, check_lazy_shapes):
    """Check that a random state created for `t` is an `FWRandomState`."""
    state = B.create_random_state(t)
    assert isinstance(state, FWRandomState)