コード例 #1
0
    def test_random_integers_vector(self):
        """Check random_integers with vector ``low``/``high`` against numpy.

        Covers a size inferred from the argument length (3 elements, then
        2 elements) and an explicit ``size=(3,)``, which must reject
        length-2 arguments with a ValueError.
        """
        rng_R = random_state_type()
        low = tensor.lvector()
        high = tensor.lvector()
        post_r, out = random_integers(rng_R, low=low, high=high)
        assert out.ndim == 1
        f = compile.function([rng_R, low, high], [post_r, out], accept_inplace=True)

        low_val = [100, 200, 300]
        high_val = [110, 220, 330]
        rng = np.random.RandomState(utt.fetch_seed())
        numpy_rng = np.random.RandomState(utt.fetch_seed())

        def numpy_expected(lows, highs):
            # random_integers includes ``high``; numpy's randint excludes it,
            # hence ``hv + 1``. Draws advance ``numpy_rng`` in lockstep with
            # the Theano random state.
            return np.asarray(
                [numpy_rng.randint(low=lv, high=hv + 1) for lv, hv in zip(lows, highs)]
            )

        # Arguments of size (3,).
        # assert_array_equal also fails on shape mismatches that a bare
        # ``np.all(a == b)`` could silently broadcast away.
        rng0, val0 = f(rng, low_val, high_val)
        np.testing.assert_array_equal(val0, numpy_expected(low_val, high_val))

        # Arguments of size (2,).
        rng1, val1 = f(rng0, low_val[:-1], high_val[:-1])
        np.testing.assert_array_equal(
            val1, numpy_expected(low_val[:-1], high_val[:-1])
        )

        # Specifying the size explicitly.
        g = compile.function(
            [rng_R, low, high],
            random_integers(rng_R, low=low, high=high, size=(3,)),
            accept_inplace=True,
        )
        rng2, val2 = g(rng1, low_val, high_val)
        np.testing.assert_array_equal(val2, numpy_expected(low_val, high_val))
        # An explicit size of (3,) is incompatible with length-2 arguments.
        with pytest.raises(ValueError):
            g(rng2, low_val[:-1], high_val[:-1])
コード例 #2
0
    def test_random_integers(self):
        """Check raw_random.random_integers against numpy over two calls.

        numpy's randint() is used as the reference because
        random_integers() is deprecated there; randint's upper bound is
        exclusive, so Theano's inclusive ``high=16`` corresponds to 17.
        Two successive calls verify that the shared random state is updated.
        """
        rng_R = random_state_type()
        # Use non-default parameters, and larger dimensions because of
        # the integer nature of the result.
        post_r, out = random_integers(rng_R, (11, 8), -3, 16)

        f = compile.function(
            [
                compile.In(
                    rng_R,
                    value=np.random.RandomState(utt.fetch_seed()),
                    update=post_r,
                    mutable=True,
                )
            ],
            [out],
            accept_inplace=True,
        )

        numpy_rng = np.random.RandomState(utt.fetch_seed())
        # ``f`` was compiled with a list of outputs, so each call returns a
        # one-element sequence. Unpack it rather than letting a (1, 11, 8)
        # list-wrapped result broadcast against the (11, 8) reference.
        (val0,) = f()
        (val1,) = f()
        numpy_val0 = numpy_rng.randint(-3, 17, size=(11, 8))
        numpy_val1 = numpy_rng.randint(-3, 17, size=(11, 8))
        # Integer samples: compare exactly, including shape.
        np.testing.assert_array_equal(val0, numpy_val0)
        np.testing.assert_array_equal(val1, numpy_val1)
コード例 #3
0
    def test_pkl(self):
        """Round-trip compiled RandomFunction graphs through pickle.

        ``binomial`` was created by calling RandomFunction on a string,
        ``random_integers`` by calling it on a function. Both construction
        paths are pickled *and* unpickled so that either reconstruction path
        breaking is caught.
        """
        rng_r = random_state_type()
        mode = None
        # NOTE(review): DebugMode is swapped for FAST_COMPILE here; presumably
        # DebugMode-compiled functions are unsuitable for pickling — confirm.
        if theano.config.mode in ["DEBUG_MODE", "DebugMode"]:
            mode = "FAST_COMPILE"
        post_bin_r, bin_sample = binomial(rng_r, (3, 5), 1, 0.3)
        f = theano.function([rng_r], [post_bin_r, bin_sample], mode=mode)
        pickle.loads(pickle.dumps(f))

        post_int_r, int_sample = random_integers(rng_r, (3, 5), -1, 8)
        g = theano.function([rng_r], [post_int_r, int_sample], mode=mode)
        pickle.loads(pickle.dumps(g))
コード例 #4
0
    def test_dtype(self):
        """random_integers honours an explicit ``dtype="int8"`` argument."""
        rng_R = random_state_type()
        low = tensor.lscalar()
        high = tensor.lscalar()
        post_r, out = random_integers(
            rng_R, low=low, high=high, size=(20,), dtype="int8"
        )
        assert out.dtype == "int8"
        f = compile.function([rng_R, low, high], [post_r, out])

        rng = np.random.RandomState(utt.fetch_seed())
        rng0, val0 = f(rng, 0, 9)
        assert val0.dtype == "int8"
        # The bounds are inclusive and fit in int8, so every sample must lie
        # in [0, 9] unchanged by the cast.
        assert np.all((val0 >= 0) & (val0 <= 9))

        # 255..257 do not fit in int8; the cast is expected to wrap them to
        # -1..1, which is what the magnitude check below asserts.
        rng1, val1 = f(rng0, 255, 257)
        assert val1.dtype == "int8"
        assert np.all(abs(val1) <= 1)