コード例 #1
0
    def test_basics(self, rng_ctor):
        """Check argument validation, attribute lookup and reproducible draws.

        Passing an explicit ``rng`` generator raises ``ValueError``; names that
        do not resolve to random variables in the stream's namespace raise
        ``AttributeError``; and two successive calls of a compiled draw must
        match a NumPy generator seeded the way this test expects
        ``RandomStream`` to seed its shared RNG.
        """
        srng = RandomStream(seed=utt.fetch_seed(), rng_ctor=rng_ctor)

        # An explicitly supplied ``rng`` is expected to be rejected.
        with pytest.raises(ValueError):
            srng.uniform(0, 1, size=(2, 2), rng=np.random.default_rng(23))

        # Unknown names must raise instead of returning something.
        with pytest.raises(AttributeError):
            srng.blah

        assert hasattr(srng, "standard_normal")

        # Build the stream outside ``pytest.raises`` so that only the attribute
        # lookup can satisfy the expected ``AttributeError``; an error raised by
        # the constructor would otherwise make this check pass vacuously.
        np_random = RandomStream(namespace=np, rng_ctor=rng_ctor)
        with pytest.raises(AttributeError):
            np_random.ndarray

        fn = function([],
                      srng.uniform(0, 1, size=(2, 2)),
                      updates=srng.updates())

        fn_val0 = fn()
        fn_val1 = fn()

        # Reference generator: one child of the test seed's SeedSequence fed to
        # the stream's generator constructor (the seeding this test expects
        # RandomStream to use for its first shared RNG).
        (child_seed,) = np.random.SeedSequence(utt.fetch_seed()).spawn(1)
        rng = srng.rng_ctor(child_seed)

        numpy_val0 = rng.uniform(0, 1, size=(2, 2))
        numpy_val1 = rng.uniform(0, 1, size=(2, 2))

        # Successive compiled calls must follow the reference sequence.
        assert np.allclose(fn_val0, numpy_val0)
        assert np.allclose(fn_val1, numpy_val1)
コード例 #2
0
    def test_basics(self):
        """Check argument validation, attribute lookup and reproducible draws.

        Passing a ``RandomState`` as ``rng`` raises ``TypeError``; names that do
        not resolve to random variables in the stream's namespace raise
        ``AttributeError``; and two successive calls of a compiled draw must
        match a ``RandomState`` seeded the way this test expects
        ``RandomStream`` to seed its shared RNG.
        """
        srng = RandomStream(seed=utt.fetch_seed())

        # An explicitly supplied ``rng`` is expected to be rejected.
        with pytest.raises(TypeError):
            srng.uniform(0, 1, size=(2, 2), rng=np.random.RandomState(23))

        # Unknown names must raise instead of returning something.
        with pytest.raises(AttributeError):
            srng.blah

        # Build the stream outside ``pytest.raises`` so that only the attribute
        # lookup can satisfy the expected ``AttributeError``; an error raised by
        # the constructor would otherwise make this check pass vacuously.
        np_random = RandomStream(namespace=np)
        with pytest.raises(AttributeError):
            np_random.ndarray

        fn = function([],
                      srng.uniform(0, 1, size=(2, 2)),
                      updates=srng.updates())

        fn_val0 = fn()
        fn_val1 = fn()

        # Reference generator: seeded with one integer drawn from a
        # RandomState built on the test seed.
        rng_seed = np.random.RandomState(utt.fetch_seed()).randint(2**30)
        rng = np.random.RandomState(int(rng_seed))  # int() is for 32bit

        numpy_val0 = rng.uniform(0, 1, size=(2, 2))
        numpy_val1 = rng.uniform(0, 1, size=(2, 2))

        # Successive compiled calls must follow the reference sequence.
        assert np.allclose(fn_val0, numpy_val0)
        assert np.allclose(fn_val1, numpy_val1)
コード例 #3
0
    def test_default_updates(self, rng_ctor):
        """Compare the ways RNG updates can be wired into compiled functions.

        Every stream below is created from the same test seed, so each variant
        is checked against the draws of a reference function that relies on
        the default updates.
        """

        def fresh_stream():
            # Identically seeded stream plus one 2x2 uniform draw from it.
            stream = RandomStream(utt.fetch_seed(), rng_ctor=rng_ctor)
            return stream, stream.uniform(0, 1, size=(2, 2))

        # Reference: with default updates the test expects consecutive calls
        # to produce different draws.
        _, ref_out = fresh_stream()
        ref_fn = function([], ref_out)
        ref_first = ref_fn()
        ref_second = ref_fn()
        assert not np.all(ref_first == ref_second)

        # The same random variable used several times in one graph is expected
        # to take a single value, so this expression should be ~0.
        zero_fn = function([], ref_out + ref_out - 2 * ref_out)
        assert np.all(abs(zero_fn()) < 1e-5)

        # Passing ``updates()`` explicitly must reproduce the reference draws.
        stream_b, out_b = fresh_stream()
        explicit_fn = function([], out_b, updates=stream_b.updates())
        explicit_first = explicit_fn()
        explicit_second = explicit_fn()
        assert np.all(explicit_first == ref_first)
        assert np.all(explicit_second == ref_second)

        # Same check, handing over the ``state_updates`` attribute directly.
        stream_c, out_c = fresh_stream()
        state_fn = function([], out_c, updates=stream_c.state_updates)
        state_first = state_fn()
        state_second = state_fn()
        assert np.all(state_first == ref_first)
        assert np.all(state_second == ref_second)

        # With all default updates disabled, every call must repeat the
        # reference's first draw.
        _, out_d = fresh_stream()
        frozen_fn = function([], out_d, no_default_updates=True)
        frozen_first = frozen_fn()
        frozen_second = frozen_fn()
        assert np.all(frozen_first == ref_first)
        assert np.all(frozen_second == frozen_first)

        # Disabling the default update for just this stream's first RNG
        # (first element of the first ``state_updates`` pair) behaves the same.
        stream_e, out_e = fresh_stream()
        rng_e = stream_e.state_updates[0][0]
        partial_fn = function([], out_e, no_default_updates=[rng_e])
        partial_first = partial_fn()
        partial_second = partial_fn()
        assert np.all(partial_first == ref_first)
        assert np.all(partial_second == partial_first)