예제 #1
0
def _generate_kernel(inputs, input_schema, output_schema):
    """
    Generate the ``stokes_convert`` CUDA kernel for a schema conversion.

    Parameters
    ----------
    inputs : array
        Input data. It is handed to ``stokes_convert_setup`` and its
        ``dtype`` (or ``real.dtype``) determines the output dtype.
    input_schema, output_schema
        Schemas forwarded unchanged to ``stokes_convert_setup``.

    Returns
    -------
    tuple
        ``(kernel, block, in_shape, out_shape, out_dtype)`` where
        ``kernel`` is a ``cp.RawKernel``, ``block`` is ``(512, 1, 1)``,
        the shapes are those reported by ``stokes_convert_setup`` and
        ``out_dtype`` is the inferred output dtype.

    Raises
    ------
    ValueError
        If the input and output shapes hold a different number of
        elements, or if the setup reports a dtype other than
        ``"real"`` or ``"complex"``.
    """
    setup = stokes_convert_setup(inputs, input_schema, output_schema)
    mapping, in_shape, out_shape, out_dtype = setup

    # The generated code indexes both sides as flat arrays, so the
    # element counts of the two shapes have to agree.
    n_in = reduce(mul, in_shape, 1)
    n_out = reduce(mul, out_shape, 1)

    if n_in != n_out:
        raise ValueError("Number of input_schema elements %s "
                         "and output schema elements %s "
                         "must match for CUDA kernel." % (in_shape, out_shape))

    # Resolve the symbolic "real"/"complex" request into a concrete dtype
    if out_dtype == "real":
        out_dtype = (inputs.real.dtype if np.iscomplexobj(inputs)
                     else inputs.dtype)
    elif out_dtype == "complex":
        out_dtype = (inputs.dtype if np.iscomplexobj(inputs)
                     else np.result_type(inputs.dtype, np.complex64))
    else:
        raise ValueError("Invalid setup dtype %s" % out_dtype)

    out_ctype = cuda_type(out_dtype)

    def _input_ref(index):
        # Expression reading the flattened input element at ``index``
        return "in[%d]" % np.ravel_multi_index(index, in_shape)

    # One assignment statement per mapping entry. Each entry names two
    # input elements, an output index and a template string.
    statements = []

    for (name1, index1), (name2, index2), out_index, template_src in mapping:
        target = np.ravel_multi_index(out_index, out_shape)
        template = jinja_env.from_string(template_src)
        expression = template.render(**{
            name1: _input_ref(index1),
            name2: _input_ref(index2),
            "out_type": out_ctype
        })
        statements.append("out[%d] = %s;" % (target, expression))

    # Render the enclosing kernel source around the assignments
    kernel_template = jinja_env.get_template(_TEMPLATE_PATH)
    name = "stokes_convert"
    source = kernel_template.render(kernel_name=name,
                                    input_type=cuda_type(inputs.dtype),
                                    output_type=cuda_type(out_dtype),
                                    assign_exprs=statements,
                                    elements=n_in)
    code = source.encode("utf-8")

    # CUDA block: only the x dimension is used
    block_x = 512
    block = (block_x, 1, 1)

    return (cp.RawKernel(code, name), block, in_shape, out_shape, out_dtype)
예제 #2
0
def _generate_kernel(time_index, antenna1, antenna2, dde1_jones, source_coh,
                     dde2_jones, die1_jones, base_vis, die2_jones, corrs,
                     out_ndim):
    """
    Generate the ``predict_vis`` CUDA kernel for the supplied inputs.

    Parameters
    ----------
    time_index, antenna1, antenna2 : array
        Index arrays. Each must have dtype ``np.int32``.
    dde1_jones, source_coh, dde2_jones, die1_jones, base_vis, die2_jones
        Terms of the prediction. ``predict_checks`` reports which of
        them are present; absent terms are rendered with the
        placeholder type ``"int"`` and a dimensionality of 1.
    corrs : sequence of int
        Correlation dimensions. Their product is passed to the template.
    out_ndim : int
        Forwarded unchanged to the template as ``out_ndim``.

    Returns
    -------
    tuple
        ``(kernel, block, out_dtype)`` where ``kernel`` is a
        ``cp.RawKernel``, ``block`` holds the three block dimensions
        and ``out_dtype`` is the result type of the six terms.

    Raises
    ------
    TypeError
        If any of the index arrays is not of dtype ``np.int32``.
    """
    (have_ddes1, have_coh, have_ddes2,
     have_dies1, have_bvis, have_dies2) = predict_checks(
        time_index, antenna1, antenna2, dde1_jones, source_coh,
        dde2_jones, die1_jones, base_vis, die2_jones)

    # Each index array is validated in turn; the first failure raises
    index_checks = (
        (time_index, "time_index.dtype != np.int32 '%s'"),
        (antenna1, "antenna1.dtype != np.int32 '%s'"),
        (antenna2, "antenna2.dtype != np.int32 '%s'"),
    )

    for array, message in index_checks:
        if array.dtype != np.int32:
            raise TypeError(message % array.dtype)

    # Load the kernel template
    template = jinja_env.get_template(_TEMPLATE_PATH)
    name = "predict_vis"

    # Output type is the common type of all six terms
    out_dtype = np.result_type(dde1_jones, source_coh, dde2_jones, die1_jones,
                               base_vis, die2_jones)

    ncorrs = reduce(mul, corrs, 1)

    # Block layout: x covers corrs x channels, y covers rows.
    # The y dimension is smaller for complex128 output.
    blockdimx = 32

    if out_dtype == np.complex128:
        blockdimy = 24
    else:
        blockdimy = 32

    block = (blockdimx, blockdimy, 1)

    def _type_and_ndim(term, present):
        # Absent terms still need placeholder values in the template
        return (cuda_type(term), term.ndim) if present else ("int", 1)

    dde1_type, dde1_ndim = _type_and_ndim(dde1_jones, have_ddes1)
    dde2_type, dde2_ndim = _type_and_ndim(dde2_jones, have_ddes2)
    coh_type, coh_ndim = _type_and_ndim(source_coh, have_coh)
    die1_type, die1_ndim = _type_and_ndim(die1_jones, have_dies1)
    base_vis_type, base_vis_ndim = _type_and_ndim(base_vis, have_bvis)
    die2_type, die2_ndim = _type_and_ndim(die2_jones, have_dies2)

    source = template.render(kernel_name=name,
                             blockdimx=blockdimx,
                             blockdimy=blockdimy,
                             have_dde1=have_ddes1,
                             dde1_type=dde1_type,
                             dde1_ndim=dde1_ndim,
                             have_dde2=have_ddes2,
                             dde2_type=dde2_type,
                             dde2_ndim=dde2_ndim,
                             have_coh=have_coh,
                             coh_type=coh_type,
                             coh_ndim=coh_ndim,
                             have_die1=have_dies1,
                             die1_type=die1_type,
                             die1_ndim=die1_ndim,
                             have_base_vis=have_bvis,
                             base_vis_type=base_vis_type,
                             base_vis_ndim=base_vis_ndim,
                             have_die2=have_dies2,
                             die2_type=die2_type,
                             die2_ndim=die2_ndim,
                             out_type=cuda_type(out_dtype),
                             corrs=ncorrs,
                             out_ndim=out_ndim,
                             warp_size=32)
    code = source.encode('utf-8')

    return cp.RawKernel(code, name), block, out_dtype