def test_truncated_normal(limit, stddev, seed):
    """Check ``dists.TruncatedNormal`` sample shape, moments and bounds.

    ``limit``, ``stddev`` and ``seed`` are supplied by pytest
    (parametrization / fixtures defined outside this function).
    The expected variance comes from the ``tnorm_var`` helper, presumably
    the analytic variance of a normal truncated to ``(-limit, limit)``.
    """
    rng = np.random.RandomState(seed)
    dist = dists.TruncatedNormal(mean=0, stddev=stddev, limit=limit)
    if limit is None:
        # The test assumes the distribution's default limit is 2 * stddev.
        limit = 2 * stddev

    samples = dist.sample(1000, 2000, rng=rng)
    assert samples.shape == (1000, 2000)
    # assert_allclose prints actual vs. expected values on failure, unlike a
    # bare ``assert np.allclose(...)`` which only reports ``False``.
    np.testing.assert_allclose(np.mean(samples), 0.0, atol=5e-3)
    np.testing.assert_allclose(
        np.var(samples), tnorm_var(stddev, limit), rtol=5e-3
    )
    assert np.all(samples < limit)
    assert np.all(samples > -limit)

    # test with default rng: the samples are unseeded, so the moments are not
    # checked, but truncation must hold regardless of which rng is used
    samples = dist.sample(1000, 2000)
    assert samples.shape == (1000, 2000)
    assert np.all(samples < limit)
    assert np.all(samples > -limit)
@pytest.mark.parametrize(
    "dist",
    [
        dists.TruncatedNormal(),
        dists.VarianceScaling(),
        dists.Glorot(),
        dists.He(),
    ],
)
def test_seeding(dist, seed):
    """Sampling with two RandomStates built from the same seed must agree.

    ``seed`` is supplied by a pytest fixture defined outside this function.
    """
    first = dist.sample(100, rng=np.random.RandomState(seed))
    second = dist.sample(100, rng=np.random.RandomState(seed))
    assert np.allclose(first, second)