def test_computation_performance(thr_and_double, fast_math, test_sampler_float):
    """Time repeated runs of a compiled CBRNG computation.

    Builds a ``CBRNG`` over a ``(batch, size)`` array type using the sampler
    returned by ``test_sampler_float.get_sampler(double)``, compiles it for
    ``thr`` with the requested ``fast_math`` setting, and runs it
    ``attempts`` times.

    Returns:
        A ``(seconds, byte_size)`` tuple: the shortest measured duration of a
        single run, and ``size * batch * itemsize`` of the sampler dtype.
    """
    thr, double = thr_and_double

    size = 2**15
    batch = 2**6

    sampler = test_sampler_float.get_sampler(double)
    rng = CBRNG(Type(sampler.dtype, shape=(batch, size)), 1, sampler)

    dest_dev = thr.empty_like(rng.parameter.randoms)
    counters = rng.create_counters()
    counters_dev = thr.to_device(counters)

    rngc = rng.compile(thr, fast_math=fast_math)

    attempts = 10
    times = []
    for _ in range(attempts):
        # perf_counter() is the clock intended for measuring short intervals:
        # unlike time.time() it is not affected by system clock adjustments
        # and uses the highest resolution timer available.
        t1 = time.perf_counter()
        rngc(counters_dev, dest_dev)
        # Presumably blocks until the queued device work has finished, so the
        # measured interval covers the whole run -- TODO confirm.
        thr.synchronize()
        times.append(time.perf_counter() - t1)

    byte_size = size * batch * sampler.dtype.itemsize
    # The minimum over all attempts is reported rather than the mean.
    return min(times), byte_size
def test_computation_general(thr_and_double):
    """Run ``check_computation`` on a CBRNG that uses the ``normal_bm`` sampler.

    The sampler is built over a ``philox(64, 4)`` bijection with mean -2 and
    std 10. Its dtype is ``numpy.float64`` when ``double`` is truthy and
    ``numpy.float32`` otherwise. The array type handed to ``CBRNG`` has shape
    ``(101, 10000)``.
    """
    # NOTE(review): a second function with this exact name is defined after
    # this one. If both live at module level, the later definition replaces
    # this one and this test is never collected -- confirm which is intended.
    thr, double = thr_and_double

    shape = (101, 10000)  # (batch, size)
    mean = -2
    std = 10

    if double:
        dtype = numpy.float64
    else:
        dtype = numpy.float32

    sampler = normal_bm(philox(64, 4), dtype, mean=mean, std=std)
    rng = CBRNG(Type(dtype, shape=shape), 1, sampler)

    # The same mean/std used to build the sampler are passed on to the check.
    check_computation(thr, rng, mean=mean, std=std)
def test_computation_general(thr_and_double):
    """Run ``check_computation`` on a CBRNG driven by a ``NormalBMHelper`` sampler.

    A ``NormalBMHelper(mean=-2, std=10)`` supplies the sampler (built over a
    ``philox(64, 4)`` bijection) and is also passed to ``check_computation``
    as the reference. The array type handed to ``CBRNG`` has shape
    ``(101, 10000)`` and the sampler's dtype.
    """
    thr, double = thr_and_double
    batch, size = 101, 10000

    philox_bijection = philox(64, 4)
    helper = NormalBMHelper(mean=-2, std=10)
    helper_sampler = helper.get_sampler(philox_bijection, double)

    randoms_type = Type(helper_sampler.dtype, shape=(batch, size))
    generator = CBRNG(randoms_type, 1, helper_sampler)

    check_computation(thr, generator, helper)