def _generate_kernel(inputs, input_schema, output_schema):
    """Render and compile the CUDA kernel converting ``inputs`` between schemas.

    ``stokes_convert_setup`` supplies the element mapping, the input and
    output schema shapes, and the kind of output type requested
    (``"real"`` or ``"complex"``). Each mapping entry is rendered into one
    ``out[i] = <expr>;`` statement over the flattened schemas. Those
    statements are then spliced into the kernel template at
    ``_TEMPLATE_PATH``.

    Parameters
    ----------
    inputs : array
        Input data. Only its dtype, and whether it is complex, are used here.
    input_schema, output_schema
        Schemas forwarded unchanged to ``stokes_convert_setup``.

    Returns
    -------
    tuple
        ``(kernel, block, in_shape, out_shape, out_dtype)``:

        - ``kernel`` is a ``cp.RawKernel`` named ``"stokes_convert"``.
        - ``block`` is the CUDA block shape ``(512, 1, 1)``.
        - ``out_dtype`` is the inferred output dtype.

    Raises
    ------
    ValueError
        If the flattened input and output schemas hold different numbers of
        elements, or if the setup reports an output kind other than
        ``"real"`` or ``"complex"``.
    """
    mapping, in_shape, out_shape, dtype_kind = stokes_convert_setup(
        inputs, input_schema, output_schema)

    # Both schemas are addressed as flat arrays inside the kernel, so they
    # must contain the same number of elements.
    n_in = reduce(mul, in_shape, 1)
    n_out = reduce(mul, out_shape, 1)

    if n_in != n_out:
        raise ValueError("Number of input_schema elements %s "
                         "and output schema elements %s "
                         "must match for CUDA kernel." %
                         (in_shape, out_shape))

    if dtype_kind not in ("real", "complex"):
        raise ValueError("Invalid setup dtype %s" % dtype_kind)

    complex_inputs = np.iscomplexobj(inputs)

    if dtype_kind == "real":
        # Real output: use the real component type of complex inputs.
        out_dtype = inputs.real.dtype if complex_inputs else inputs.dtype
    else:
        # Complex output: promote non-complex inputs together with complex64.
        out_dtype = (inputs.dtype if complex_inputs
                     else np.result_type(inputs.dtype, np.complex64))

    cuda_out_type = cuda_type(out_dtype)

    # One flat-index assignment statement per mapping entry. Each entry
    # carries a Jinja expression template whose variables are the two
    # correlation names plus "out_type".
    assign_exprs = []

    for (corr_a, idx_a), (corr_b, idx_b), out_idx, template_src in mapping:
        flat_out = np.ravel_multi_index(out_idx, out_shape)
        expr_template = jinja_env.from_string(template_src)
        template_vars = {
            corr_a: "in[%d]" % np.ravel_multi_index(idx_a, in_shape),
            corr_b: "in[%d]" % np.ravel_multi_index(idx_b, in_shape),
            "out_type": cuda_out_type,
        }
        rhs = expr_template.render(**template_vars)
        assign_exprs.append("out[%d] = %s;" % (flat_out, rhs))

    kernel_name = "stokes_convert"
    kernel_template = jinja_env.get_template(_TEMPLATE_PATH)
    source = kernel_template.render(kernel_name=kernel_name,
                                    input_type=cuda_type(inputs.dtype),
                                    output_type=cuda_out_type,
                                    assign_exprs=assign_exprs,
                                    elements=n_in)
    code = source.encode("utf-8")

    # CUDA block: non-schema dimensions are flattened into a single x dim.
    block = (512, 1, 1)

    return (cp.RawKernel(code, kernel_name),
            block, in_shape, out_shape, out_dtype)
def _generate_kernel(time_index, antenna1, antenna2,
                     dde1_jones, source_coh, dde2_jones,
                     die1_jones, base_vis, die2_jones,
                     corrs, out_ndim):
    """Render and compile the ``predict_vis`` CUDA kernel.

    ``predict_checks`` classifies each optional term (``dde1_jones``,
    ``source_coh``, ``dde2_jones``, ``die1_jones``, ``base_vis``,
    ``die2_jones``) as present or absent. Absent terms are rendered with an
    ``"int"`` placeholder type and ``ndim`` 1, and are switched off through
    the corresponding ``have_*`` template flag.

    Parameters
    ----------
    time_index, antenna1, antenna2 : array
        Index arrays. Each must have dtype ``np.int32``.
    dde1_jones, source_coh, dde2_jones, die1_jones, base_vis, die2_jones
        Optional terms. Only those reported present by ``predict_checks``
        take part in determining the output dtype.
    corrs : iterable of int
        Multiplied together and passed to the template as ``corrs``
        (presumably the correlation shape -- TODO confirm with callers).
    out_ndim : int
        Number of output dimensions, forwarded to the template.

    Returns
    -------
    tuple
        ``(kernel, block, out_dtype)``:

        - ``kernel`` is a ``cp.RawKernel`` named ``"predict_vis"``.
        - ``block`` is the CUDA block shape.
        - ``out_dtype`` is the promoted dtype of the present terms.

    Raises
    ------
    TypeError
        If ``time_index``, ``antenna1`` or ``antenna2`` is not ``np.int32``.
    ValueError
        If none of the optional terms is present. ``predict_checks`` may
        raise further errors of its own.
    """
    (have_ddes1, have_coh, have_ddes2,
     have_dies1, have_bvis, have_dies2) = predict_checks(
        time_index, antenna1, antenna2,
        dde1_jones, source_coh, dde2_jones,
        die1_jones, base_vis, die2_jones)

    # Index arrays must be int32
    for arg_name, arg in (("time_index", time_index),
                          ("antenna1", antenna1),
                          ("antenna2", antenna2)):
        if arg.dtype != np.int32:
            raise TypeError("%s.dtype != np.int32 '%s'"
                            % (arg_name, arg.dtype))

    # Complex output type: promote only over the terms actually present.
    # Absent terms must be excluded, because np.result_type converts a
    # ``None`` argument to the default dtype (float64). That would silently
    # promote complex64 inputs to complex128 whenever an optional term is
    # omitted.
    terms = ((have_ddes1, dde1_jones), (have_coh, source_coh),
             (have_ddes2, dde2_jones), (have_dies1, die1_jones),
             (have_bvis, base_vis), (have_dies2, die2_jones))
    present_dtypes = [term.dtype for present, term in terms if present]

    if not present_dtypes:
        raise ValueError("At least one of dde1_jones, source_coh, "
                         "dde2_jones, die1_jones, base_vis or die2_jones "
                         "must be supplied")

    out_dtype = np.result_type(*present_dtypes)
    ncorrs = reduce(mul, corrs, 1)

    # corrs x channels, rows
    blockdimx = 32
    # Smaller y extent for double-precision output; presumably to limit
    # per-block resource usage -- NOTE(review): confirm the rationale.
    blockdimy = 24 if out_dtype == np.complex128 else 32
    block = (blockdimx, blockdimy, 1)

    # Create template
    render = jinja_env.get_template(_TEMPLATE_PATH).render
    name = "predict_vis"

    code = render(kernel_name=name,
                  blockdimx=blockdimx,
                  blockdimy=blockdimy,
                  have_dde1=have_ddes1,
                  dde1_type=cuda_type(dde1_jones) if have_ddes1 else "int",
                  dde1_ndim=dde1_jones.ndim if have_ddes1 else 1,
                  have_dde2=have_ddes2,
                  dde2_type=cuda_type(dde2_jones) if have_ddes2 else "int",
                  dde2_ndim=dde2_jones.ndim if have_ddes2 else 1,
                  have_coh=have_coh,
                  coh_type=cuda_type(source_coh) if have_coh else "int",
                  coh_ndim=source_coh.ndim if have_coh else 1,
                  have_die1=have_dies1,
                  die1_type=cuda_type(die1_jones) if have_dies1 else "int",
                  die1_ndim=die1_jones.ndim if have_dies1 else 1,
                  have_base_vis=have_bvis,
                  base_vis_type=cuda_type(base_vis) if have_bvis else "int",
                  base_vis_ndim=base_vis.ndim if have_bvis else 1,
                  have_die2=have_dies2,
                  die2_type=cuda_type(die2_jones) if have_dies2 else "int",
                  die2_ndim=die2_jones.ndim if have_dies2 else 1,
                  out_type=cuda_type(out_dtype),
                  corrs=ncorrs,
                  out_ndim=out_ndim,
                  warp_size=32).encode('utf-8')

    return cp.RawKernel(code, name), block, out_dtype