def test_basics(self, rng_ctor):
    """Check RandomStream argument validation, attribute lookup and seeding.

    Draws from a compiled ``uniform(0, 1, size=(2, 2))`` function (with the
    stream's updates applied) must match, call by call, the draws of a
    generator built with ``rng_ctor`` from the first child spawned off
    ``np.random.SeedSequence(utt.fetch_seed())``.
    """
    random = RandomStream(seed=utt.fetch_seed(), rng_ctor=rng_ctor)

    # An explicit ``rng`` argument on a stream-managed draw must be rejected.
    with pytest.raises(ValueError):
        random.uniform(0, 1, size=(2, 2), rng=np.random.default_rng(23))

    # Unknown names raise; ``standard_normal`` is exposed.
    with pytest.raises(AttributeError):
        random.blah
    assert hasattr(random, "standard_normal")

    # Build the stream outside ``pytest.raises`` so that only the attribute
    # lookup -- not the constructor -- can satisfy the expected error.
    np_random = RandomStream(namespace=np, rng_ctor=rng_ctor)
    with pytest.raises(AttributeError):
        np_random.ndarray

    fn = function([], random.uniform(0, 1, size=(2, 2)), updates=random.updates())
    fn_val0 = fn()
    fn_val1 = fn()

    # Reference draws: a generator seeded with the first child of the root
    # seed sequence, mirroring how the stream is expected to seed its first
    # random variable.
    (child_seed,) = np.random.SeedSequence(utt.fetch_seed()).spawn(1)
    rng = random.rng_ctor(child_seed)
    numpy_val0 = rng.uniform(0, 1, size=(2, 2))
    numpy_val1 = rng.uniform(0, 1, size=(2, 2))

    assert np.allclose(fn_val0, numpy_val0)
    assert np.allclose(fn_val1, numpy_val1)
def test_basics(self):
    """Check RandomStream argument validation, attribute lookup and seeding.

    Legacy ``RandomState`` variant: draws from a compiled
    ``uniform(0, 1, size=(2, 2))`` function (with the stream's updates
    applied) must match, call by call, a ``RandomState`` seeded with the
    first ``randint(2**30)`` draw of a ``RandomState(utt.fetch_seed())``.
    """
    random = RandomStream(seed=utt.fetch_seed())

    # An explicit ``rng`` argument on a stream-managed draw must be rejected.
    with pytest.raises(TypeError):
        random.uniform(0, 1, size=(2, 2), rng=np.random.RandomState(23))

    with pytest.raises(AttributeError):
        random.blah

    # Build the stream outside ``pytest.raises`` so that only the attribute
    # lookup -- not the constructor -- can satisfy the expected error.
    np_random = RandomStream(namespace=np)
    with pytest.raises(AttributeError):
        np_random.ndarray

    fn = function([], random.uniform(0, 1, size=(2, 2)), updates=random.updates())
    fn_val0 = fn()
    fn_val1 = fn()

    # Reference draws, mirroring how the stream is expected to derive the
    # seed of its first random variable from the root seed.
    rng_seed = np.random.RandomState(utt.fetch_seed()).randint(2**30)
    rng = np.random.RandomState(int(rng_seed))  # int() is for 32bit
    numpy_val0 = rng.uniform(0, 1, size=(2, 2))
    numpy_val1 = rng.uniform(0, 1, size=(2, 2))

    assert np.allclose(fn_val0, numpy_val0)
    assert np.allclose(fn_val1, numpy_val1)
def test_default_updates(self, rng_ctor):
    """Check how default and explicit RNG updates affect repeated draws.

    Every stream below is seeded identically, so the first call of each
    compiled function must reproduce stream ``a``'s first draw. Whether the
    second call advances depends on which updates the function is given.
    """

    def fresh_uniform():
        # New, identically seeded stream plus a (2, 2) uniform draw from it.
        stream = RandomStream(utt.fetch_seed(), rng_ctor=rng_ctor)
        return stream, stream.uniform(0, 1, size=(2, 2))

    def call_twice(compiled):
        first = compiled()
        second = compiled()
        return first, second

    # Basic case: only the default updates, so the state advances per call.
    _, out_a = fresh_uniform()
    a0, a1 = call_twice(function([], out_a))
    assert not np.all(a0 == a1)
    # Every use of ``out_a`` within one call sees the same values.
    nearly_zeros = function([], out_a + out_a - 2 * out_a)
    assert np.all(abs(nearly_zeros()) < 1e-5)

    # Explicit updates, given via ``updates()`` and via ``state_updates``:
    # the sequence must match the default-update sequence exactly.
    for updates_of in (lambda s: s.updates(), lambda s: s.state_updates):
        stream, out = fresh_uniform()
        got0, got1 = call_twice(function([], out, updates=updates_of(stream)))
        assert np.all(got0 == a0)
        assert np.all(got1 == a1)

    # Default updates disabled entirely, then only for the stream's first
    # state variable: the draw must repeat instead of advancing.
    for frozen_of in (lambda s: True, lambda s: [s.state_updates[0][0]]):
        stream, out = fresh_uniform()
        got0, got1 = call_twice(
            function([], out, no_default_updates=frozen_of(stream))
        )
        assert np.all(got0 == a0)
        assert np.all(got1 == got0)