def _build_plan(self, plan_factory, device_params, alpha, beta, seed):
    """
    Build the plan that runs the ``generate_input_state`` kernel.

    The system's ``squeezing`` and ``decoherence`` arrays are stored in the
    plan as persistent arrays and handed to the kernel together with
    ``alpha``, ``beta`` and ``seed``.  One work item is requested per
    element of ``alpha`` (``global_size=alpha.shape``).

    ``device_params`` is accepted but not used by this method.
    Returns the populated plan.
    """
    plan = plan_factory()

    rng_bijection = philox(64, 2)
    # Keeping the kernel the same so it can be cached.
    # The seed will be passed as the computation parameter instead.
    key_generator = KeyGenerator.create(rng_bijection, seed=numpy.int32(0))
    normal_sampler = normal_bm(rng_bijection, numpy.float64)

    squeezing = plan.persistent_array(self._system.squeezing)
    decoherence = plan.persistent_array(self._system.decoherence)

    kernel_def = TEMPLATE.get_def("generate_input_state")
    render_kwds = {
        "system": self._system,
        "representation": self._representation,
        "Representation": Representation,
        "bijection": rng_bijection,
        "keygen": key_generator,
        "sampler": normal_sampler,
        # NOTE(review): `ordering` is not bound inside this method;
        # presumably it is a module-level name -- confirm.
        "ordering": ordering,
        "exp": functions.exp(numpy.float64),
        "mul_cr": functions.mul(numpy.complex128, numpy.float64),
        "add_cc": functions.add(numpy.complex128, numpy.complex128),
    }

    plan.kernel_call(
        kernel_def,
        [alpha, beta, squeezing, decoherence, seed],
        kernel_name="generate",
        global_size=alpha.shape,
        render_kwds=render_kwds,
    )

    return plan
def __init__(self, randoms_arr, generators_dim, sampler, seed=None):
    """
    Set up the computation with two parameters: ``counters`` (in/out) and
    ``randoms`` (output).

    :param randoms_arr: array-like metadata describing the output;
        its dtype must equal ``sampler.dtype``.
    :param generators_dim: number of trailing dimensions of ``randoms_arr``
        whose shape is used for the ``counters`` array.
    :param sampler: sampler object; its ``bijection`` supplies the counter
        dtype and is used to create the key generator.
    :param seed: forwarded to ``KeyGenerator.create``.
    """
    self._sampler = sampler
    self._keygen = KeyGenerator.create(
        sampler.bijection, seed=seed, reserve_id_space=True)

    assert sampler.dtype == randoms_arr.dtype

    # The counters array takes the shape of the last `generators_dim` axes.
    # NOTE(review): with generators_dim == 0 the slice `[-0:]` selects the
    # whole shape rather than an empty one -- confirm this is never passed.
    trailing_shape = randoms_arr.shape[-generators_dim:]

    self._generators_dim = generators_dim
    self._counters_t = Type(
        sampler.bijection.counter_dtype, shape=trailing_shape)

    parameters = [
        Parameter('counters', Annotation(self._counters_t, 'io')),
        Parameter('randoms', Annotation(randoms_arr, 'o')),
    ]
    Computation.__init__(self, parameters)
def check_kernel_sampler(thr, sampler, extent=None, mean=None, std=None):
    """
    Draw samples from ``sampler`` inside a static kernel and validate them.

    Each of ``size`` work items builds its own key from its global id,
    calls the sampler ``batch`` times and stores every returned value into
    ``dest``, which is laid out as ``(batch, randoms_per_call, size)``.
    The collected array is then passed to ``check_distribution`` together
    with ``extent``, ``mean`` and ``std``.

    :param thr: object providing ``compile_static`` and ``array``.
    :param sampler: sampler under test (provides ``bijection``, ``module``,
        ``dtype`` and ``randoms_per_call``).
    """
    size = 10000
    batch = 100
    seed = 456

    bijection = sampler.bijection
    keygen = KeyGenerator.create(bijection, seed=seed)

    # The `%for` / `%endfor` lines are Mako control lines: they are only
    # recognised when `%` is the first non-blank character of a line,
    # so the template has to keep them on lines of their own.
    kernel_template = """
    KERNEL void test(GLOBAL_MEM ${ctype} *dest, int ctr_start)
    {
        VIRTUAL_SKIP_THREADS;
        const VSIZE_T idx = virtual_global_id(0);

        ${bijection.module}Key key = ${keygen.module}key_from_int(idx);
        ${bijection.module}Counter ctr = ${bijection.module}make_counter_from_int(ctr_start);
        ${bijection.module}State st = ${bijection.module}make_state(key, ctr);

        ${sampler.module}Result res;

        for(int j = 0; j < ${batch}; j++)
        {
            res = ${sampler.module}sample(&st);

            %for i in range(sampler.randoms_per_call):
            dest[j * ${size * sampler.randoms_per_call} + ${size * i} + idx] = res.v[${i}];
            %endfor
        }

        ${bijection.module}Counter next_ctr = ${bijection.module}get_next_unused_counter(st);
    }
    """

    rng_kernel = thr.compile_static(
        kernel_template,
        'test',
        size,
        render_kwds=dict(
            size=size,
            batch=batch,
            ctype=dtypes.ctype(sampler.dtype),
            bijection=bijection,
            keygen=keygen,
            sampler=sampler))

    dest = thr.array((batch, sampler.randoms_per_call, size), sampler.dtype)
    rng_kernel(dest, numpy.int32(0))
    dest = dest.get()

    check_distribution(dest, extent=extent, mean=mean, std=std)
def check_kernel_sampler(thr, sampler, extent=None, mean=None, std=None):
    """
    Run ``sampler`` in a statically compiled kernel and check its output.

    ``size`` work items each create a key from their global id, sample
    ``batch`` times, and write all values of every sample into an array of
    shape ``(batch, randoms_per_call, size)``.  That array is fetched and
    handed to ``check_distribution`` along with ``extent``, ``mean`` and
    ``std``.

    :param thr: object providing ``compile_static`` and ``array``.
    :param sampler: sampler under test (provides ``bijection``, ``module``,
        ``dtype`` and ``randoms_per_call``).
    """
    size = 10000
    batch = 100
    seed = 456

    bijection = sampler.bijection
    keygen = KeyGenerator.create(bijection, seed=seed)

    render_kwds = dict(
        size=size,
        batch=batch,
        ctype=dtypes.ctype(sampler.dtype),
        bijection=bijection,
        keygen=keygen,
        sampler=sampler)

    # NOTE: `%for` and `%endfor` must each start their own line -- Mako
    # treats `%` as a control-line marker only at the beginning of a line.
    rng_kernel = thr.compile_static(
        """
        KERNEL void test(GLOBAL_MEM ${ctype} *dest, int ctr_start)
        {
            VIRTUAL_SKIP_THREADS;
            const VSIZE_T idx = virtual_global_id(0);

            ${bijection.module}Key key = ${keygen.module}key_from_int(idx);
            ${bijection.module}Counter ctr = ${bijection.module}make_counter_from_int(ctr_start);
            ${bijection.module}State st = ${bijection.module}make_state(key, ctr);

            ${sampler.module}Result res;

            for(int j = 0; j < ${batch}; j++)
            {
                res = ${sampler.module}sample(&st);

                %for i in range(sampler.randoms_per_call):
                dest[j * ${size * sampler.randoms_per_call} + ${size * i} + idx] = res.v[${i}];
                %endfor
            }

            ${bijection.module}Counter next_ctr = ${bijection.module}get_next_unused_counter(st);
        }
        """,
        'test',
        size,
        render_kwds=render_kwds)

    samples_dev = thr.array(
        (batch, sampler.randoms_per_call, size), sampler.dtype)
    rng_kernel(samples_dev, numpy.int32(0))
    samples = samples_dev.get()

    check_distribution(samples, extent=extent, mean=mean, std=std)
def test_kernel_bijection(thr, test_bijection):
    """
    Compare the kernel-side bijection against ``test_bijection.reference``.

    A static kernel applies the bijection once per work item, using a key
    derived from the work item's global id and a counter built from the
    integer kernel argument.  The result is checked against the reference
    for the counter arguments 0 and 1.
    """
    size = 1000
    seed = 123

    bijection = test_bijection.bijection
    keygen = KeyGenerator.create(bijection, seed=seed, reserve_id_space=False)

    counters_ref = numpy.zeros(size, bijection.counter_dtype)

    kernel_source = """
    KERNEL void test(GLOBAL_MEM ${bijection.module}Counter *dest, int ctr)
    {
        VIRTUAL_SKIP_THREADS;
        const VSIZE_T idx = virtual_global_id(0);

        ${bijection.module}Key key = ${keygen.module}key_from_int(idx);
        ${bijection.module}Counter counter = ${bijection.module}make_counter_from_int(ctr);
        ${bijection.module}Counter result = ${bijection.module}bijection(key, counter);

        dest[idx] = result;
    }
    """
    rng_kernel = thr.compile_static(
        kernel_source, 'test', size,
        render_kwds=dict(bijection=bijection, keygen=keygen))

    dest = thr.array(size, bijection.counter_dtype)

    for ctr_value in (0, 1):
        rng_kernel(dest, numpy.int32(ctr_value))
        # Mirror the kernel argument in the last word of every reference
        # counter (writing 0 on the first pass leaves the zeros untouched).
        counters_ref['v'][:, -1] = ctr_value
        expected = test_bijection.reference(counters_ref, keygen.reference)
        assert (dest.get() == expected).all()
def test_kernel_bijection(thr, test_bijection):
    """
    Check that the bijection evaluated in a kernel matches the reference
    implementation supplied by ``test_bijection``.

    The kernel is run twice, with the counter argument set to 0 and then
    to 1; for the second run the last word of each reference counter is
    set to 1 before calling ``test_bijection.reference``.
    """
    size = 1000
    seed = 123

    bijection = test_bijection.bijection
    keygen = KeyGenerator.create(bijection, seed=seed, reserve_id_space=False)
    counters_ref = numpy.zeros(size, bijection.counter_dtype)

    rng_kernel = thr.compile_static(
        """
        KERNEL void test(GLOBAL_MEM ${bijection.module}Counter *dest, int ctr)
        {
            VIRTUAL_SKIP_THREADS;
            const VSIZE_T idx = virtual_global_id(0);

            ${bijection.module}Key key = ${keygen.module}key_from_int(idx);
            ${bijection.module}Counter counter = ${bijection.module}make_counter_from_int(ctr);
            ${bijection.module}Counter result = ${bijection.module}bijection(key, counter);

            dest[idx] = result;
        }
        """,
        'test',
        size,
        render_kwds=dict(bijection=bijection, keygen=keygen))

    dest = thr.array(size, bijection.counter_dtype)

    def kernel_matches_reference(ctr_value):
        # Run the kernel for one counter value and compare element-wise
        # with the reference computed from the current `counters_ref`.
        rng_kernel(dest, numpy.int32(ctr_value))
        expected = test_bijection.reference(counters_ref, keygen.reference)
        return (dest.get() == expected).all()

    assert kernel_matches_reference(0)

    counters_ref['v'][:, -1] = 1
    assert kernel_matches_reference(1)
def _build_plan(self, plan_factory, device_params, alpha, beta, alpha_i, beta_i, seed):
    """
    Build the plan that produces ``alpha`` and ``beta`` from ``alpha_i``
    and ``beta_i`` using the system's ``unitary`` matrix.

    Two paths exist:

    * No noise matrix needed (always the case for the positive-P
      representation): ``alpha`` and ``beta`` are each obtained with a
      matrix multiplication; the one for ``beta`` has the transformation
      from ``_make_trf_conj()`` attached to its ``matrix_b`` parameter.
    * Noise matrix needed: the product for ``alpha`` goes to a temporary,
      the ``generate_apply_matrix_noise`` kernel fills ``w`` from ``seed``,
      ``w`` is multiplied by the noise matrix, and the ``add_noise`` kernel
      writes both ``alpha`` and ``beta``.

    ``device_params`` is accepted but not used by this method.
    Returns the populated plan.
    """
    plan = plan_factory()

    system = self._system
    representation = self._representation

    unitary = plan.persistent_array(system.unitary)

    # Only ask the system about the noise matrix outside of positive-P.
    if representation != Representation.POSITIVE_P:
        needs_noise_matrix = system.needs_noise_matrix()
    else:
        needs_noise_matrix = False

    mmul = MatrixMul(alpha, unitary, transposed_b=True)

    if not needs_noise_matrix:
        # TODO: this could be sped up for repr != POSITIVE_P,
        # since in that case alpha == conj(beta),
        # and we don't need to do two multiplications.
        mmul_beta = MatrixMul(beta, unitary, transposed_b=True)
        trf_conj = self._make_trf_conj()
        mmul_beta.parameter.matrix_b.connect(
            trf_conj, trf_conj.output, matrix_b_p=trf_conj.input)

        plan.computation_call(mmul, alpha, alpha_i, unitary)
        plan.computation_call(mmul_beta, beta, beta_i, unitary)
        return plan

    noise_matrix_dev = plan.persistent_array(system.noise_matrix())

    # If we're here, it's not positive-P, and alpha == conj(beta).
    # This means we can just calculate alpha, and then build beta from it.
    w = plan.temp_array_like(alpha)
    temp_alpha = plan.temp_array_like(alpha)

    plan.computation_call(mmul, temp_alpha, alpha_i, unitary)

    rng_bijection = philox(64, 2)
    # Keeping the kernel the same so it can be cached.
    # The seed will be passed as the computation parameter instead.
    key_generator = KeyGenerator.create(rng_bijection, seed=numpy.int32(0))
    normal_sampler = normal_bm(rng_bijection, numpy.float64)

    plan.kernel_call(
        TEMPLATE.get_def("generate_apply_matrix_noise"),
        [w, seed],
        kernel_name="generate_apply_matrix_noise",
        global_size=alpha.shape,
        render_kwds={
            "bijection": rng_bijection,
            "keygen": key_generator,
            "sampler": normal_sampler,
            "mul_cr": functions.mul(numpy.complex128, numpy.float64),
            "add_cc": functions.add(numpy.complex128, numpy.complex128),
        },
    )

    noise = plan.temp_array_like(alpha)
    plan.computation_call(mmul, noise, w, noise_matrix_dev)

    plan.kernel_call(
        TEMPLATE.get_def("add_noise"),
        [alpha, beta, temp_alpha, noise],
        kernel_name="add_noise",
        global_size=alpha.shape,
        render_kwds={
            "add": functions.add(numpy.complex128, numpy.complex128),
            "conj": functions.conj(numpy.complex128),
        },
    )

    return plan