Exemple #1
0
def test_overflow_gpu_new_backend():
    """Run the ``rng_mrg_overflow`` checks against ``GPUA_mrg_uniform.new``.

    A stack of 7 generator states is built: the first is six copies of the
    seed as int32, and each following one is ``rng_mrg.ff_2p72`` applied to
    the previous state.  The stack is wrapped with
    ``gpuarray_shared_constructor`` and bound as the first argument of
    ``GPUA_mrg_uniform.new`` (with ``ndim=None`` and ``dtype='float32'``).

    ``rng_mrg_overflow`` is then called on three groups of sizes:

    * shapes whose total element count is 2**31 or more, expected to raise;
    * small shapes (at most 2**15 elements), expected not to raise;
    * shapes given as ``np.int32`` scalars, expected not to raise.

    ``mode`` is not defined in this function; it is read from the
    surrounding module scope.
    """
    seed = 12345
    n_substreams = 7

    # First substream state: the seed repeated six times, stored as int32.
    initial_state = np.array([seed] * 6, dtype='int32')
    states = [initial_state.copy()]
    # Every further state is derived from the most recently added one.
    while len(states) < n_substreams:
        states.append(rng_mrg.ff_2p72(states[-1]))

    shared_states = gpuarray_shared_constructor(np.asarray(states))
    fct = functools.partial(
        GPUA_mrg_uniform.new, shared_states, ndim=None, dtype='float32')

    # Each entry pairs a list of sizes with whether an error is expected.
    cases = (
        # should raise error as the size overflows
        ([(2**31,), (2**32,), (2**15, 2**16), (2, 2**15, 2**15)], True),
        # should not raise error
        ([(2**5,), (2**5, 2**5), (2**5, 2**5, 2**5)], False),
        # should support int32 sizes
        ([(np.int32(2**10),),
          (np.int32(2), np.int32(2**10), np.int32(2**10))], False),
    )
    for sizes, expect_error in cases:
        rng_mrg_overflow(sizes, fct, mode, should_raise_error=expect_error)
Exemple #2
0
def test_overflow_gpu_new_backend():
    """Exercise ``GPUA_mrg_uniform.new`` through ``rng_mrg_overflow``.

    Seven int32 state vectors of length 6 are prepared (the seed repeated,
    then successive applications of ``rng_mrg.ff_2p72``), converted to one
    array, and handed to ``gpuarray_shared_constructor``.  The resulting
    object is pre-bound to ``GPUA_mrg_uniform.new`` together with
    ``ndim=None`` and ``dtype='float32'``.

    The bound callable is passed to ``rng_mrg_overflow`` with oversized
    shapes (an error is expected), small shapes (no error expected), and
    shapes made of ``np.int32`` scalars (no error expected).  ``mode`` comes
    from the enclosing module scope.
    """
    seed = 12345
    n_substreams = 7

    base_state = np.array([seed] * 6, dtype='int32')
    state_rows = [base_state.copy()]
    # One row already exists, so n_substreams - 1 more are appended, each
    # computed from the last row in the list.
    for _ in range(n_substreams - 1):
        previous_row = state_rows[-1]
        state_rows.append(rng_mrg.ff_2p72(previous_row))

    state_matrix = np.asarray(state_rows)
    gpu_state = gpuarray_shared_constructor(state_matrix)
    fct = functools.partial(GPUA_mrg_uniform.new,
                            gpu_state,
                            ndim=None,
                            dtype='float32')

    def _check(size_list, must_raise):
        # Thin wrapper so each scenario below is a single readable call.
        rng_mrg_overflow(size_list, fct, mode, should_raise_error=must_raise)

    # should raise error as the size overflows
    # (every shape here multiplies out to at least 2**31 elements)
    overflowing_sizes = [
        (2**31,),
        (2**32,),
        (2**15, 2**16),
        (2, 2**15, 2**15),
    ]
    _check(overflowing_sizes, True)

    # should not raise error
    small_sizes = [
        (2**5,),
        (2**5, 2**5),
        (2**5, 2**5, 2**5),
    ]
    _check(small_sizes, False)

    # should support int32 sizes
    int32_sizes = [
        (np.int32(2**10),),
        (np.int32(2), np.int32(2**10), np.int32(2**10)),
    ]
    _check(int32_sizes, False)