def test_normal_vector(self):
    """Check RandomStreams.normal with vector ``avg``/``std`` against numpy.

    The symbolic output is 1-d; its shape follows the arguments (length 3,
    then length 2). A second variable built with an explicit ``size=(3,)``
    must match numpy and raise ValueError when given length-2 arguments.
    """
    streams = RandomStreams(utt.fetch_seed())
    mean = tensor.dvector()
    stdev = tensor.dvector()
    sampled = streams.normal(avg=mean, std=stdev)
    assert sampled.ndim == 1
    draw = function([mean, stdev], sampled)

    means = [1, 2, 3]
    stdevs = [0.1, 0.2, 0.3]
    # One reference seed is drawn per random variable from a generator seeded
    # with the same seed as the stream (presumably mirroring how RandomStreams
    # seeds each new variable -- confirm against its implementation).
    seeder = np.random.RandomState(utt.fetch_seed())
    reference = np.random.RandomState(int(seeder.randint(2**30)))

    # Both calls advance the same stream, so one reference rng is consumed
    # sequentially: first with length-3 arguments, then with length-2.
    for length in (3, 2):
        got = draw(means[:length], stdevs[:length])
        want = reference.normal(loc=means[:length], scale=stdevs[:length])
        assert np.allclose(got, want)

    # A second random variable with the output size fixed explicitly.
    draw_sized = function(
        [mean, stdev], streams.normal(avg=mean, std=stdev, size=(3, )))
    got = draw_sized(means, stdevs)
    reference = np.random.RandomState(int(seeder.randint(2**30)))
    want = reference.normal(loc=means, scale=stdevs, size=(3, ))
    assert np.allclose(got, want)

    # Length-2 arguments are incompatible with size=(3,).
    with pytest.raises(ValueError):
        draw_sized(means[:2], stdevs[:2])
def test_tutorial(self):
    """Exercise the basic RandomStreams tutorial behaviour.

    - A uniform sample function returns different values on successive calls.
    - A normal sample function compiled with ``no_default_updates=True``
      returns identical values on successive calls.
    - ``u + u - 2 * u`` over one uniform variable evaluates to (nearly) zero.
    - The variable's ``rng`` holds a ``np.random.RandomState``.
    """
    streams = RandomStreams(seed=234)
    uniform_rv = streams.uniform((2, 2))
    normal_rv = streams.normal((2, 2))

    draw_uniform = function([], uniform_rv)
    # Not updating normal_rv.rng, so every call reuses the same state.
    draw_normal = function([], normal_rv, no_default_updates=True)
    near_zero = function([], uniform_rv + uniform_rv - 2 * uniform_rv)

    assert np.all(draw_uniform() != draw_uniform())
    assert np.all(draw_normal() == draw_normal())
    assert np.all(abs(near_zero()) < 1e-5)

    rng_state = uniform_rv.rng.get_value(borrow=True)
    assert isinstance(rng_state, np.random.RandomState)
def test_normal(self):
    """Check that RandomStreams.normal matches numpy over two calls.

    Two consecutive calls are compared so that a random state that is not
    advanced between calls would be caught.
    """
    streams = RandomStreams(utt.fetch_seed())
    sample = function([], streams.normal((2, 2), -1, 2))
    got = [sample(), sample()]

    seeder = np.random.RandomState(utt.fetch_seed())
    # int() is for 32-bit platforms.
    reference = np.random.RandomState(int(seeder.randint(2**30)))
    expected = [reference.normal(-1, 2, size=(2, 2)) for _ in range(2)]

    for actual, wanted in zip(got, expected):
        assert np.allclose(actual, wanted)