def wrap_in_error_handler(body, arg_names):
    """Wrap generated argument-handling code in a ``try``/``except`` block.

    The generated handlers catch ``TypeError`` and ``_cl.LogicError``. If
    ``current_arg`` is set in the generated code, they raise a new
    ``_cl.LogicError`` naming the (1-based) argument being processed;
    otherwise the original exception is re-raised.

    :arg body: code generator whose content becomes the ``try`` suite.
    :arg arg_names: names of the argument variables in the generated code.
    :returns: a new ``PythonCodeGenerator`` containing the wrapped code.
    """
    from pytools.py_codegen import PythonCodeGenerator, Indentation

    wrapped = PythonCodeGenerator()

    def emit_handler_suite():
        # The backslash-newline inside this (non-raw) literal is consumed by
        # the literal itself, so both halves of the 'advice' string end up on
        # a single generated line.
        wrapped("""
            if current_arg is not None:
                args = [{args}]
                advice = ""
                from pyopencl.array import Array
                if isinstance(args[current_arg], Array):
                    advice = " (perhaps you meant to pass 'array.data' " \
                        "instead of the array itself?)"

                raise _cl.LogicError(
                        "when processing argument #%d (1-based): %s%s"
                        % (current_arg+1, str(e), advice))
            else:
                raise
            """
            .format(args=", ".join(arg_names)))

    wrapped("")
    wrapped("try:")
    with Indentation(wrapped):
        wrapped.extend(body)

    # Both exception types get the exact same handler suite.
    for except_clause in (
            "except TypeError as e:",
            "except _cl.LogicError as e:"):
        wrapped(except_clause)
        with Indentation(wrapped):
            emit_handler_suite()

    return wrapped
def wrap_in_error_handler(body, arg_names):
    """Return a code generator that runs *body* inside an error handler.

    The emitted code has the shape ``try: <body> except TypeError as e: ...
    except _cl.LogicError as e: ...``. Each handler raises a
    ``_cl.LogicError`` that mentions the 1-based index of the argument being
    processed when ``current_arg`` is not *None* in the generated code, and
    re-raises the caught exception otherwise.

    :arg body: code generator spliced in as the ``try`` suite.
    :arg arg_names: names of the generated function's argument variables.
    """
    from pytools.py_codegen import PythonCodeGenerator, Indentation

    out = PythonCodeGenerator()

    def emit_indented(emit):
        # Run *emit* with the generator indented by one level.
        with Indentation(out):
            emit()

    def emit_handler():
        # The backslash-newline inside this (non-raw) literal is swallowed by
        # the literal, which joins the two 'advice' fragments onto one
        # generated line.
        out("""
            if current_arg is not None:
                args = [{args}]
                advice = ""
                from pyopencl.array import Array
                if isinstance(args[current_arg], Array):
                    advice = " (perhaps you meant to pass 'array.data' " \
                        "instead of the array itself?)"

                raise _cl.LogicError(
                        "when processing argument #%d (1-based): %s%s"
                        % (current_arg+1, str(e), advice))
            else:
                raise
            """.format(args=", ".join(arg_names)))

    out("")
    out("try:")
    emit_indented(lambda: out.extend(body))

    out("except TypeError as e:")
    emit_indented(emit_handler)

    out("except _cl.LogicError as e:")
    emit_indented(emit_handler)

    return out
def _generate_enqueue_and_set_args_module(function_name,
        num_passed_args, num_cl_args,
        scalar_arg_dtypes,
        work_around_arg_count_bug, warn_about_arg_count_bug):
    """Generate the module holding a kernel's enqueue and ``set_args`` code.

    The generated module defines ``enqueue_knl_<function_name>`` and
    ``set_args``. Both begin with the same error-wrapped argument-handling
    code; the enqueue function then returns the result of
    ``_cl.enqueue_nd_range_kernel``.

    :arg num_passed_args: number of ``argN`` parameters of the generated
        functions.
    :arg scalar_arg_dtypes: if *None*, generic argument handling is
        generated; otherwise the value is forwarded, together with the
        remaining parameters, to ``generate_specific_arg_handling_body``.
    :returns: a tuple ``(module, enqueue_name)`` where *module* comes from
        the code generator's ``get_picklable_module()``.
    """
    from pytools.py_codegen import PythonCodeGenerator, Indentation

    arg_names = ["arg%d" % i for i in range(num_passed_args)]

    if scalar_arg_dtypes is not None:
        body = generate_specific_arg_handling_body(
                function_name, num_cl_args, scalar_arg_dtypes,
                warn_about_arg_count_bug=warn_about_arg_count_bug,
                work_around_arg_count_bug=work_around_arg_count_bug)
    else:
        body = generate_generic_arg_handling_body(num_passed_args)

    err_handler = wrap_in_error_handler(body, arg_names)

    gen = PythonCodeGenerator()

    for preamble_line in (
            "from struct import pack",
            "from pyopencl import status_code",
            ""):
        gen(preamble_line)

    # {{{ generate _enqueue

    enqueue_name = "enqueue_knl_%s" % function_name

    enqueue_params = ["self", "queue", "global_size", "local_size"]
    enqueue_params = enqueue_params + arg_names
    enqueue_params = enqueue_params + [
            "global_offset=None", "g_times_l=None", "wait_for=None"]

    gen("def %s(%s):" % (enqueue_name, ", ".join(enqueue_params)))

    with Indentation(gen):
        add_local_imports(gen)
        gen.extend(err_handler)

        gen("""
            return _cl.enqueue_nd_range_kernel(queue, self, global_size, local_size,
                    global_offset, wait_for, g_times_l=g_times_l)
            """)

    # }}}

    # {{{ generate set_args

    set_args_params = ["self"] + arg_names

    gen("")
    gen("def set_args(%s):" % (", ".join(set_args_params)))

    with Indentation(gen):
        add_local_imports(gen)
        gen.extend(err_handler)

    # }}}

    return gen.get_picklable_module(), enqueue_name
def _generate_enqueue_and_set_args_module(function_name,
        num_passed_args, num_cl_args,
        scalar_arg_dtypes,
        work_around_arg_count_bug, warn_about_arg_count_bug):
    """Build the generated module with a kernel's enqueue/``set_args`` pair.

    Two functions are emitted: ``enqueue_knl_<function_name>``, which sets
    the kernel arguments and then returns
    ``_cl.enqueue_nd_range_kernel(...)``, and ``set_args``, which only sets
    the arguments. The argument-setting code is shared and wrapped by
    ``wrap_in_error_handler``.

    :arg num_passed_args: how many ``argN`` parameters the generated
        functions take.
    :arg scalar_arg_dtypes: *None* selects generic argument handling; any
        other value is handed to ``generate_specific_arg_handling_body``.
    :returns: ``(module, enqueue_name)``.
    """
    from pytools.py_codegen import PythonCodeGenerator, Indentation

    arg_names = ["arg%d" % i for i in range(num_passed_args)]

    if scalar_arg_dtypes is None:
        body = generate_generic_arg_handling_body(num_passed_args)
    else:
        body = generate_specific_arg_handling_body(
                function_name, num_cl_args, scalar_arg_dtypes,
                warn_about_arg_count_bug=warn_about_arg_count_bug,
                work_around_arg_count_bug=work_around_arg_count_bug)

    err_handler = wrap_in_error_handler(body, arg_names)

    gen = PythonCodeGenerator()

    def emit_arg_setting():
        # Shared prologue of both generated functions.
        add_local_imports(gen)
        gen.extend(err_handler)

    gen("from struct import pack")
    gen("from pyopencl import status_code")
    gen("")

    # {{{ generate _enqueue

    enqueue_name = "enqueue_knl_%s" % function_name
    leading_params = ["self", "queue", "global_size", "local_size"]
    trailing_params = ["global_offset=None", "g_times_l=None", "wait_for=None"]
    enqueue_signature = ", ".join(leading_params + arg_names + trailing_params)

    gen("def %s(%s):" % (enqueue_name, enqueue_signature))

    with Indentation(gen):
        emit_arg_setting()

        gen("""
            return _cl.enqueue_nd_range_kernel(queue, self, global_size, local_size,
                    global_offset, wait_for, g_times_l=g_times_l)
            """)

    # }}}

    # {{{ generate set_args

    gen("")
    gen("def set_args(%s):" % (", ".join(["self"] + arg_names)))

    with Indentation(gen):
        emit_arg_setting()

    # }}}

    module = gen.get_picklable_module()
    return module, enqueue_name
def generate_generic_arg_handling_body(num_args):
    """Generate code that sets *num_args* kernel arguments generically.

    For each argument index ``i`` the generated code records
    ``current_arg = i`` and then contains whatever
    ``generate_generic_arg_handler`` emits for variable ``arg<i>``. With
    zero arguments a lone ``pass`` is emitted.

    :returns: the ``PythonCodeGenerator`` holding the generated code.
    """
    from pytools.py_codegen import PythonCodeGenerator

    code = PythonCodeGenerator()

    if num_args == 0:
        # Keep the enclosing suite syntactically valid.
        code("pass")

    for arg_idx in range(num_args):
        arg_var = "arg%d" % arg_idx

        code("# process argument {arg_idx}".format(arg_idx=arg_idx))
        code("")
        code("current_arg = {arg_idx}".format(arg_idx=arg_idx))
        generate_generic_arg_handler(code, arg_idx, arg_var)
        code("")

    return code
def generate_generic_arg_handling_body(num_args):
    """Generate code setting all kernel arguments with one call.

    Emits a single ``self._set_arg_multi((0, arg0, 1, arg1, ...,), )`` call
    whose tuple alternates argument indices and argument variable names, or
    ``pass`` when *num_args* is zero.

    :returns: the ``PythonCodeGenerator`` holding the generated code.
    """
    code = PythonCodeGenerator()

    if num_args == 0:
        code("pass")
        return code

    # Flattened (index, variable name) pairs, already rendered as text.
    flat_entries = [
            str(entry)
            for arg_idx in range(num_args)
            for entry in (arg_idx, f"arg{arg_idx}")]
    joined_entries = ", ".join(flat_entries)

    code(f"self._set_arg_multi("
            f"({joined_entries},), "
            ")")

    return code
def generate_specific_arg_handling_body(function_name,
        num_cl_args, scalar_arg_dtypes,
        work_around_arg_count_bug, warn_about_arg_count_bug):
    """Generate dtype-aware code that sets a kernel's arguments.

    For every entry of *scalar_arg_dtypes* the generated code records
    ``current_arg`` and then sets the argument:

    * *None* entries and void (``"V"``) dtypes go through
      ``generate_generic_arg_handler``;
    * complex dtypes are packed as two floats (or, with the ``"pocl"``
      work-around, passed as two separate arguments);
    * all other dtypes are packed with :func:`struct.pack` using the dtype's
      type character (after translation through ``_type_char_map``).

    :arg function_name: kernel name, used in the warning message only.
    :arg num_cl_args: number of arguments the kernel is expected to have.
    :arg scalar_arg_dtypes: sequence of dtypes (or *None*), one per passed
        argument.
    :arg work_around_arg_count_bug: must not be *None*; the values
        ``"pocl"`` and ``"apple"`` select special handling of
        ``complex128`` arguments.
    :arg warn_about_arg_count_bug: must not be *None*; if true, a warning is
        issued whenever a complex argument is encountered.
    :returns: the ``PythonCodeGenerator`` holding the generated code.
    :raises TypeError: for an unexpected complex dtype, or if the number of
        generated argument slots differs from *num_cl_args*.
    :raises NotImplementedError: for ``complex128`` arguments under the
        ``"apple"`` work-around.
    """
    assert work_around_arg_count_bug is not None
    assert warn_about_arg_count_bug is not None

    fp_arg_count = 0
    cl_arg_idx = 0

    from pytools.py_codegen import PythonCodeGenerator
    gen = PythonCodeGenerator()

    if not scalar_arg_dtypes:
        gen("pass")

    for arg_idx, arg_dtype in enumerate(scalar_arg_dtypes):
        gen("# process argument {arg_idx}".format(arg_idx=arg_idx))
        gen("")
        gen("current_arg = {arg_idx}".format(arg_idx=arg_idx))
        arg_var = "arg%d" % arg_idx

        if arg_dtype is None:
            generate_generic_arg_handler(gen, cl_arg_idx, arg_var)
            cl_arg_idx += 1
            gen("")
            continue

        arg_dtype = np.dtype(arg_dtype)

        if arg_dtype.char == "V":
            generate_generic_arg_handler(gen, cl_arg_idx, arg_var)
            cl_arg_idx += 1

        elif arg_dtype.kind == "c":
            if warn_about_arg_count_bug:
                # stacklevel belongs to warn(), not to str.format(), which
                # silently ignores unused keyword arguments.
                warn("{knl_name}: arguments include complex numbers, and "
                        "some (but not all) of the target devices mishandle "
                        "struct kernel arguments (hence the workaround is "
                        "disabled".format(knl_name=function_name),
                        stacklevel=2)

            if arg_dtype == np.complex64:
                arg_char = "f"
            elif arg_dtype == np.complex128:
                arg_char = "d"
            else:
                raise TypeError("unexpected complex type: %s" % arg_dtype)

            if (work_around_arg_count_bug == "pocl"
                    and arg_dtype == np.complex128
                    and fp_arg_count + 2 <= 8):
                # Pass real and imaginary parts as two separate arguments.
                gen("buf = pack('{arg_char}', {arg_var}.real)".format(
                    arg_char=arg_char, arg_var=arg_var))
                generate_bytes_arg_setter(gen, cl_arg_idx, "buf")
                cl_arg_idx += 1
                gen("current_arg = current_arg + 1000")
                gen("buf = pack('{arg_char}', {arg_var}.imag)".format(
                    arg_char=arg_char, arg_var=arg_var))
                generate_bytes_arg_setter(gen, cl_arg_idx, "buf")
                cl_arg_idx += 1

            elif (work_around_arg_count_bug == "apple"
                    and arg_dtype == np.complex128
                    and fp_arg_count + 2 <= 8):
                raise NotImplementedError(
                        "No work-around to "
                        "Apple's broken structs-as-kernel arg "
                        "handling has been found. "
                        "Cannot pass complex numbers to kernels.")

            else:
                gen("buf = pack('{arg_char}{arg_char}', "
                        "{arg_var}.real, {arg_var}.imag)".format(
                            arg_char=arg_char, arg_var=arg_var))
                generate_bytes_arg_setter(gen, cl_arg_idx, "buf")
                cl_arg_idx += 1

            fp_arg_count += 2

        elif arg_dtype.char in "IL" and _CPY26:
            # Prevent SystemError: ../Objects/longobject.c:336: bad
            # argument to internal function
            gen("buf = pack('{arg_char}', long({arg_var}))".format(
                arg_char=arg_dtype.char, arg_var=arg_var))
            generate_bytes_arg_setter(gen, cl_arg_idx, "buf")
            cl_arg_idx += 1

        else:
            if arg_dtype.kind == "f":
                fp_arg_count += 1

            arg_char = arg_dtype.char
            arg_char = _type_char_map.get(arg_char, arg_char)
            gen("buf = pack('{arg_char}', {arg_var})".format(
                arg_char=arg_char, arg_var=arg_var))
            generate_bytes_arg_setter(gen, cl_arg_idx, "buf")
            cl_arg_idx += 1

        gen("")

    if cl_arg_idx != num_cl_args:
        raise TypeError("length of argument list (%d) and "
                "CL-generated number of arguments (%d) do not agree"
                % (cl_arg_idx, num_cl_args))

    return gen
def as_python(mesh, function_name="make_mesh"):
    """Return Python source code (as a string) that rebuilds *mesh*.

    The returned code defines a function named *function_name* which
    recreates the vertices, element groups, facial adjacency groups, nodal
    adjacency and boundary tags of *mesh* and returns the resulting
    ``Mesh``.
    """
    from pytools.py_codegen import PythonCodeGenerator, Indentation

    code = PythonCodeGenerator()
    code("""
        # generated by meshmode.mesh.as_python

        import numpy as np
        from meshmode.mesh import (
            Mesh, MeshElementGroup, FacialAdjacencyGroup,
            BTAG_ALL, BTAG_REALLY_ALL)

        """)

    code("def %s():" % function_name)
    with Indentation(code):
        code("vertices = " + _numpy_array_as_python(mesh.vertices))
        code("")
        code("groups = []")
        code("")

        for grp in mesh.groups:
            grp_module = type(grp).__module__
            grp_class_name = type(grp).__name__

            code("import %s" % grp_module)
            code("groups.append(%s.%s(" % (grp_module, grp_class_name))
            code(" order=%s," % grp.order)
            code(" vertex_indices=%s,"
                    % _numpy_array_as_python(grp.vertex_indices))
            code(" nodes=%s,"
                    % _numpy_array_as_python(grp.nodes))
            code(" unit_nodes=%s))"
                    % _numpy_array_as_python(grp.unit_nodes))

        # {{{ facial adjacency groups

        def fagrp_params_str(fagrp):
            # Render the constructor keyword arguments of one adjacency group.
            params = {
                    "igroup": fagrp.igroup,
                    "ineighbor_group": repr(fagrp.ineighbor_group),
                    "elements": _numpy_array_as_python(fagrp.elements),
                    "element_faces": _numpy_array_as_python(
                        fagrp.element_faces),
                    "neighbors": _numpy_array_as_python(fagrp.neighbors),
                    "neighbor_faces": _numpy_array_as_python(
                        fagrp.neighbor_faces),
                    }
            rendered = (
                    "%s=%s" % (key, value)
                    for key, value in six.iteritems(params))
            return ",\n ".join(rendered)

        if not mesh._facial_adjacency_groups:
            code("facial_adjacency_groups = %r"
                    % mesh._facial_adjacency_groups)
        else:
            code("facial_adjacency_groups = []")

            for fagrp_map in mesh.facial_adjacency_groups:
                map_entries = ",\n ".join(
                        "%r: FacialAdjacencyGroup(%s)" % (
                            inb_grp, fagrp_params_str(fagrp))
                        for inb_grp, fagrp in six.iteritems(fagrp_map))
                code("facial_adjacency_groups.append({%s})" % map_entries)

        # }}}

        # {{{ boundary tags

        def strify_boundary_tag(btag):
            # Tag classes are emitted by name, everything else via repr().
            return btag.__name__ if isinstance(btag, type) else repr(btag)

        btags_str = ", ".join(
                strify_boundary_tag(btag) for btag in mesh.boundary_tags)

        # }}}

        code("return Mesh(vertices, groups, skip_tests=True,")
        code(" vertex_id_dtype=np.%s," % mesh.vertex_id_dtype.name)
        code(" element_id_dtype=np.%s," % mesh.element_id_dtype.name)

        nodal_adj = mesh._nodal_adjacency
        if isinstance(nodal_adj, NodalAdjacency):
            el_con_str = "(%s, %s)" % (
                    _numpy_array_as_python(nodal_adj.neighbors_starts),
                    _numpy_array_as_python(nodal_adj.neighbors),
                    )
        else:
            el_con_str = repr(nodal_adj)

        code(" nodal_adjacency=%s," % el_con_str)
        code(" facial_adjacency_groups=facial_adjacency_groups,")
        code(" boundary_tags=[%s])" % btags_str)

        # FIXME: Handle facial adjacency, boundary tags

    return code.get()
def capture_kernel_call(kernel, filename, queue, g_size, l_size, *args, **kwargs):
    """Write a standalone Python script to *filename* that repeats a kernel call.

    The script embeds the kernel source, the launch sizes and the arguments.
    Buffer contents (and other buffer-protocol arguments) are stored
    zlib-compressed and base64-encoded.

    :arg kernel: the kernel being called; its ``_source`` must be available.
    :arg filename: path of the script to write.
    :arg queue: command queue used to copy buffer contents to the host.
    :arg g_size: global size of the call.
    :arg l_size: local size of the call.
    :arg args: the kernel arguments: ``cl.Buffer`` objects, Python or numpy
        scalars, or objects supporting the buffer protocol.
    :arg kwargs: only ``g_times_l`` and ``global_offset`` are consulted.
    :raises RuntimeError: if the kernel source is unavailable or an argument
        cannot be captured.
    """
    try:
        source = kernel._source
    except AttributeError:
        raise RuntimeError("cannot capture call, kernel source not available")

    if source is None:
        raise RuntimeError("cannot capture call, kernel source not available")

    cg = PythonCodeGenerator()

    cg("# generated by pyopencl.capture_call")
    cg("")
    cg("import numpy as np")
    cg("import pyopencl as cl")
    cg("from base64 import b64decode")
    cg("from zlib import decompress")
    cg("mf = cl.mem_flags")
    cg("")

    cg('CODE = r"""//CL//')
    for line in source.split("\n"):
        cg(line)
    cg('"""')

    # {{{ invocation

    arg_data = []

    cg("")
    cg("")
    cg("def main():")
    with Indentation(cg):
        cg("ctx = cl.create_some_context()")
        cg("queue = cl.CommandQueue(ctx)")
        cg("")

        kernel_args = []

        for i, arg in enumerate(args):
            if isinstance(arg, cl.Buffer):
                # Snapshot the device buffer so the script can recreate it.
                buf = bytearray(arg.size)
                cl.enqueue_copy(queue, buf, arg)
                arg_data.append(("arg%d_data" % i, buf))
                cg("arg%d = cl.Buffer(ctx, "
                        "mf.READ_WRITE | cl.mem_flags.COPY_HOST_PTR," % i)
                cg(" hostbuf=decompress(b64decode(arg%d_data)))" % i)
                kernel_args.append("arg%d" % i)
            elif isinstance(arg, (int, float)):
                kernel_args.append(repr(arg))
            elif isinstance(arg, np.integer):
                kernel_args.append("np.%s(%s)" % (
                    arg.dtype.type.__name__, repr(int(arg))))
            elif isinstance(arg, np.floating):
                kernel_args.append("np.%s(%s)" % (
                    arg.dtype.type.__name__, repr(float(arg))))
            elif isinstance(arg, np.complexfloating):
                kernel_args.append("np.%s(%s)" % (
                    arg.dtype.type.__name__, repr(complex(arg))))
            else:
                # memoryview replaces the Python-2-only buffer(); catching
                # Exception (not everything) lets KeyboardInterrupt and
                # SystemExit propagate.
                try:
                    arg_buf = memoryview(arg)
                except Exception:
                    raise RuntimeError("cannot capture: "
                            "unsupported arg nr %d (0-based)" % i)

                arg_data.append(("arg%d_data" % i, arg_buf))
                kernel_args.append("decompress(b64decode(arg%d_data))" % i)

        cg("")

        g_times_l = kwargs.get("g_times_l", False)
        if g_times_l:
            # Pad both sizes to a common dimension and fold the local size
            # into the global size, so the script can pass plain sizes.
            dim = max(len(g_size), len(l_size))
            l_size = l_size + (1,) * (dim - len(l_size))
            g_size = g_size + (1,) * (dim - len(g_size))
            g_size = tuple(gs * ls for gs, ls in zip(g_size, l_size))

        global_offset = kwargs.get("global_offset", None)
        if global_offset is not None:
            kernel_args.append("global_offset=%s" % repr(global_offset))

        cg("prg = cl.Program(ctx, CODE).build()")
        cg("knl = prg.%s" % kernel.function_name)
        if hasattr(kernel, "_arg_type_chars"):
            cg("knl._arg_type_chars = %s" % repr(kernel._arg_type_chars))
        cg("knl(queue, %s, %s," % (repr(g_size), repr(l_size)))
        cg(" %s)" % ", ".join(kernel_args))
        cg("")
        cg("queue.finish()")

    # }}}

    # {{{ data

    from zlib import compress
    from base64 import b64encode

    cg("")
    line_len = 70

    for name, val in arg_data:
        cg("%s = (" % name)
        with Indentation(cg):
            # b64encode() returns bytes. Decode to text instead of calling
            # str(), whose "b'...'" repr would end up inside the emitted
            # literals and corrupt the base64 payload.
            encoded = b64encode(compress(memoryview(val))).decode("ascii")
            for start in range(0, len(encoded), line_len):
                cg(repr(encoded[start:start + line_len]))

            cg(")")

    # }}}

    # {{{ file trailer

    cg("")
    cg('if __name__ == "__main__":')
    with Indentation(cg):
        cg("main()")
    cg("")

    cg("# vim: filetype=pyopencl")

    # }}}

    with open(filename, "w") as outf:
        outf.write(cg.get())
def capture_kernel_call(kernel, filename, queue, g_size, l_size, *args, **kwargs):
    """Write a standalone Python script to *filename* that repeats a kernel call.

    The script embeds the kernel source, the launch sizes, the kernel's
    scalar argument dtypes (if it has ``_scalar_arg_dtypes``) and the
    arguments. Buffer contents (and other buffer-protocol arguments) are
    stored zlib-compressed and base64-encoded.

    :arg kernel: the kernel being called; its ``_source`` must be available.
    :arg filename: path of the script to write.
    :arg queue: command queue used to copy buffer contents to the host.
    :arg g_size: global size of the call.
    :arg l_size: local size of the call.
    :arg args: the kernel arguments: ``cl.Buffer`` objects, Python or numpy
        scalars, or objects supporting the buffer protocol.
    :arg kwargs: only ``g_times_l`` and ``global_offset`` are consulted.
    :raises RuntimeError: if the kernel source is unavailable or an argument
        cannot be captured.
    """
    try:
        source = kernel._source
    except AttributeError:
        raise RuntimeError("cannot capture call, kernel source not available")

    if source is None:
        raise RuntimeError("cannot capture call, kernel source not available")

    cg = PythonCodeGenerator()

    cg("# generated by pyopencl.capture_call")
    cg("")
    cg("import numpy as np")
    cg("import pyopencl as cl")
    cg("from base64 import b64decode")
    cg("from zlib import decompress")
    cg("mf = cl.mem_flags")
    cg("")

    cg('CODE = r"""//CL//')
    for line in source.split("\n"):
        cg(line)
    cg('"""')

    # {{{ invocation

    arg_data = []

    cg("")
    cg("")
    cg("def main():")
    with Indentation(cg):
        cg("ctx = cl.create_some_context()")
        cg("queue = cl.CommandQueue(ctx)")
        cg("")

        kernel_args = []

        for i, arg in enumerate(args):
            if isinstance(arg, cl.Buffer):
                # Snapshot the device buffer so the script can recreate it.
                buf = bytearray(arg.size)
                cl.enqueue_copy(queue, buf, arg)
                arg_data.append(("arg%d_data" % i, buf))
                cg("arg%d = cl.Buffer(ctx, "
                        "mf.READ_WRITE | cl.mem_flags.COPY_HOST_PTR," % i)
                cg(" hostbuf=decompress(b64decode(arg%d_data)))" % i)
                kernel_args.append("arg%d" % i)
            elif isinstance(arg, (int, float)):
                kernel_args.append(repr(arg))
            elif isinstance(arg, np.integer):
                kernel_args.append("np.%s(%s)" % (
                    arg.dtype.type.__name__, repr(int(arg))))
            elif isinstance(arg, np.floating):
                kernel_args.append("np.%s(%s)" % (
                    arg.dtype.type.__name__, repr(float(arg))))
            elif isinstance(arg, np.complexfloating):
                kernel_args.append(
                        "np.%s(%s)" % (arg.dtype.type.__name__,
                                       repr(complex(arg))))
            else:
                try:
                    arg_buf = memoryview(arg)
                except Exception:
                    raise RuntimeError("cannot capture: "
                            "unsupported arg nr %d (0-based)" % i)

                arg_data.append(("arg%d_data" % i, arg_buf))
                kernel_args.append("decompress(b64decode(arg%d_data))" % i)

        cg("")

        g_times_l = kwargs.get("g_times_l", False)
        if g_times_l:
            # Pad both sizes to a common dimension and fold the local size
            # into the global size, so the script can pass plain sizes.
            dim = max(len(g_size), len(l_size))
            l_size = l_size + (1, ) * (dim - len(l_size))
            g_size = g_size + (1, ) * (dim - len(g_size))
            g_size = tuple(gs * ls for gs, ls in zip(g_size, l_size))

        global_offset = kwargs.get("global_offset", None)
        if global_offset is not None:
            kernel_args.append("global_offset=%s" % repr(global_offset))

        cg("prg = cl.Program(ctx, CODE).build()")
        cg("knl = prg.%s" % kernel.function_name)
        if hasattr(kernel, "_scalar_arg_dtypes"):

            def strify_dtype(d):
                # Render a dtype (or None) as an expression valid in the script.
                if d is None:
                    return "None"

                d = np.dtype(d)
                s = repr(d)
                if s.startswith("dtype"):
                    s = "np." + s

                return s

            cg("knl.set_scalar_arg_dtypes((%s,))"
                    % ", ".join(
                        strify_dtype(dt) for dt in kernel._scalar_arg_dtypes))

        cg("knl(queue, %s, %s," % (repr(g_size), repr(l_size)))
        cg(" %s)" % ", ".join(kernel_args))
        cg("")
        cg("queue.finish()")

    # }}}

    # {{{ data

    from zlib import compress
    from base64 import b64encode

    cg("")
    line_len = 70

    for name, val in arg_data:
        cg("%s = (" % name)
        with Indentation(cg):
            # b64encode() returns bytes. Decode to text instead of calling
            # str(), whose "b'...'" repr would end up inside the emitted
            # literals and corrupt the base64 payload.
            encoded = b64encode(compress(memoryview(val))).decode("ascii")
            for start in range(0, len(encoded), line_len):
                cg(repr(encoded[start:start + line_len]))

            cg(")")

    # }}}

    # {{{ file trailer

    cg("")
    cg("if __name__ == \"__main__\":")
    with Indentation(cg):
        cg("main()")
    cg("")

    cg("# vim: filetype=pyopencl")

    # }}}

    with open(filename, "w") as outf:
        outf.write(cg.get())
def generate_specific_arg_handling_body(function_name,
        num_cl_args, arg_types,
        work_around_arg_count_bug, warn_about_arg_count_bug,
        in_enqueue):
    """Generate type-aware code that sets a kernel's arguments.

    Arguments are sorted into three lists which are finally emitted as (at
    most) one call each to ``self._set_arg_multi``,
    ``self._set_arg_buf_multi`` and ``self._set_arg_buf_pack_multi``:

    * *None* types and ``VectorArg`` base data are passed through as-is;
    * void (``"V"``) dtypes and pre-packed values go to the ``_buf`` call;
    * type characters found in ``BUF_PACK_TYPECHARS`` go to the
      ``_buf_pack`` call together with their (encoded) type character.

    For ``VectorArg`` arguments a contiguity check is generated as well and,
    when *in_enqueue* is true, a queue-consistency assertion plus code
    extending ``wait_for`` with the array's events.

    :arg function_name: kernel name, used in the warning message only.
    :arg num_cl_args: number of arguments the kernel is expected to have.
    :arg arg_types: sequence with one entry per passed argument: *None*, a
        ``VectorArg`` or something accepted by ``np.dtype``.
    :arg work_around_arg_count_bug: must not be *None*; ``"pocl"`` and
        ``"apple"`` select special handling of ``complex128`` arguments.
    :arg warn_about_arg_count_bug: must not be *None*; if true, a warning is
        issued whenever a complex argument is encountered.
    :arg in_enqueue: whether the code is generated for the enqueue function
        (where ``queue`` and ``wait_for`` exist) rather than ``set_args``.
    :returns: the ``PythonCodeGenerator`` holding the generated code.
    :raises TypeError: for an unexpected complex dtype, or if the number of
        generated argument slots differs from *num_cl_args*.
    :raises NotImplementedError: for ``complex128`` arguments under the
        ``"apple"`` work-around.
    """
    assert work_around_arg_count_bug is not None
    assert warn_about_arg_count_bug is not None

    fp_arg_count = 0
    cl_arg_idx = 0

    gen = PythonCodeGenerator()

    if not arg_types:
        gen("pass")

    # Flat lists of (index, expression) pairs resp. (index, typechar,
    # expression) triples, one list per ``_set_arg*_multi`` flavor.
    plain_entries = []
    buf_entries = []
    buf_pack_entries = []

    def add_buf_arg(arg_idx, typechar, expr_str):
        if typechar not in BUF_PACK_TYPECHARS:
            buf_entries.extend((
                arg_idx,
                f"pack('{typechar}', {expr_str})"))
            return

        buf_pack_entries.extend((
            arg_idx,
            repr(typechar.encode()),
            expr_str))

    if in_enqueue and arg_types is not None and \
            any(isinstance(arg_type, VectorArg) for arg_type in arg_types):
        # We're about to modify wait_for, make sure it's a copy.
        gen("""
            if wait_for is None:
                wait_for = []
            else:
                wait_for = list(wait_for)
            """)
        gen("")

    for arg_idx, arg_type in enumerate(arg_types):
        arg_var = "arg%d" % arg_idx

        if arg_type is None:
            plain_entries.extend((cl_arg_idx, arg_var))
            cl_arg_idx += 1
            gen("")
            continue

        if isinstance(arg_type, VectorArg):
            gen(f"if not {arg_var}.flags.forc:")
            with Indentation(gen):
                gen("raise RuntimeError('only contiguous arrays may '")
                gen(" 'be used as arguments to this operation')")
                gen("")

            if in_enqueue:
                gen(f"assert {arg_var}.queue is None or {arg_var}.queue == queue, "
                    "'queues for all arrays must match the queue supplied "
                    "to enqueue'")

            plain_entries.extend((cl_arg_idx, f"{arg_var}.base_data"))
            cl_arg_idx += 1

            if arg_type.with_offset:
                # The offset occupies a kernel argument slot of its own.
                add_buf_arg(cl_arg_idx, np.dtype(np.int64).char,
                        f"{arg_var}.offset")
                cl_arg_idx += 1

            if in_enqueue:
                gen(f"wait_for.extend({arg_var}.events)")

            continue

        arg_dtype = np.dtype(arg_type)

        if arg_dtype.char == "V":
            buf_entries.extend((cl_arg_idx, arg_var))
            cl_arg_idx += 1

        elif arg_dtype.kind == "c":
            if warn_about_arg_count_bug:
                warn("{knl_name}: arguments include complex numbers, and "
                        "some (but not all) of the target devices mishandle "
                        "struct kernel arguments (hence the workaround is "
                        "disabled".format(
                            knl_name=function_name), stacklevel=2)

            if arg_dtype == np.complex64:
                arg_char = "f"
            elif arg_dtype == np.complex128:
                arg_char = "d"
            else:
                raise TypeError("unexpected complex type: %s" % arg_dtype)

            if (work_around_arg_count_bug == "pocl"
                    and arg_dtype == np.complex128
                    and fp_arg_count + 2 <= 8):
                # Pass real and imaginary parts as two separate arguments.
                for component in ("real", "imag"):
                    add_buf_arg(cl_arg_idx, arg_char, f"{arg_var}.{component}")
                    cl_arg_idx += 1

            elif (work_around_arg_count_bug == "apple"
                    and arg_dtype == np.complex128
                    and fp_arg_count + 2 <= 8):
                raise NotImplementedError(
                        "No work-around to "
                        "Apple's broken structs-as-kernel arg "
                        "handling has been found. "
                        "Cannot pass complex numbers to kernels.")

            else:
                buf_entries.extend((
                    cl_arg_idx,
                    f"pack('{arg_char}{arg_char}', {arg_var}.real, {arg_var}.imag)",
                    ))
                cl_arg_idx += 1

            fp_arg_count += 2

        else:
            if arg_dtype.kind == "f":
                fp_arg_count += 1

            arg_char = _type_char_map.get(arg_dtype.char, arg_dtype.char)
            add_buf_arg(cl_arg_idx, arg_char, arg_var)
            cl_arg_idx += 1

        gen("")

    for arg_kind, args_and_indices, entry_length in [
            ("", plain_entries, 2),
            ("_buf", buf_entries, 2),
            ("_buf_pack", buf_pack_entries, 3),
            ]:
        assert len(args_and_indices) % entry_length == 0
        if not args_and_indices:
            continue

        joined_entries = ", ".join(str(i) for i in args_and_indices)
        gen(f"self._set_arg{arg_kind}_multi("
                f"({joined_entries},), "
                ")")

    if cl_arg_idx != num_cl_args:
        raise TypeError("length of argument list (%d) and "
                "CL-generated number of arguments (%d) do not agree"
                % (cl_arg_idx, num_cl_args))

    return gen
def _generate_enqueue_and_set_args_module(function_name,
        num_passed_args, num_cl_args,
        arg_types,
        work_around_arg_count_bug, warn_about_arg_count_bug):
    """Generate the module holding a kernel's enqueue and ``set_args`` code.

    The generated module defines ``enqueue_knl_<function_name>``, which sets
    the kernel arguments and returns ``_cl.enqueue_nd_range_kernel(...)``,
    and ``set_args``, which only sets the arguments.

    :arg num_passed_args: number of ``argN`` parameters of the generated
        functions.
    :arg arg_types: if *None*, generic argument handling is generated;
        otherwise the value is forwarded, together with the remaining
        parameters, to ``generate_specific_arg_handling_body``.
    :returns: a tuple ``(module, enqueue_name)`` where *module* comes from
        the code generator's ``get_picklable_module()``.
    """
    arg_names = ["arg%d" % i for i in range(num_passed_args)]

    def gen_arg_setting(in_enqueue):
        if arg_types is not None:
            return generate_specific_arg_handling_body(
                    function_name, num_cl_args, arg_types,
                    warn_about_arg_count_bug=warn_about_arg_count_bug,
                    work_around_arg_count_bug=work_around_arg_count_bug,
                    in_enqueue=in_enqueue)

        return generate_generic_arg_handling_body(num_passed_args)

    gen = PythonCodeGenerator()

    for preamble_line in (
            "from struct import pack",
            "from pyopencl import status_code",
            "import numpy as np",
            "import pyopencl._cl as _cl",
            ""):
        gen(preamble_line)

    # {{{ generate _enqueue

    enqueue_name = "enqueue_knl_%s" % function_name

    enqueue_params = ["self", "queue", "global_size", "local_size"]
    enqueue_params = enqueue_params + arg_names
    enqueue_params = enqueue_params + [
            "global_offset=None",
            "g_times_l=None",
            "allow_empty_ndrange=False",
            "wait_for=None"
            ]

    gen("def %s(%s):" % (enqueue_name, ", ".join(enqueue_params)))

    with Indentation(gen):
        gen.extend(gen_arg_setting(in_enqueue=True))

        # Using positional args here because pybind is slow with keyword args
        gen("""
            return _cl.enqueue_nd_range_kernel(queue, self, global_size, local_size,
                    global_offset, wait_for, g_times_l,
                    allow_empty_ndrange)
            """)

    # }}}

    # {{{ generate set_args

    set_args_params = ["self"] + arg_names

    gen("")
    gen("def set_args(%s):" % (", ".join(set_args_params)))

    with Indentation(gen):
        gen.extend(gen_arg_setting(in_enqueue=False))

    # }}}

    return gen.get_picklable_module(), enqueue_name
def _generate_enqueue_and_set_args_module(function_name,
        num_passed_args, num_cl_args,
        arg_types, include_debug_code,
        work_around_arg_count_bug, warn_about_arg_count_bug):
    """Generate the module holding a kernel's enqueue and ``set_args`` code.

    The generated module defines ``enqueue_knl_<function_name>``, which sets
    the kernel arguments and returns ``_cl.enqueue_nd_range_kernel(...)``,
    and ``set_args``, which only sets the arguments. If argument setting for
    the enqueue function reports extra ``wait_for`` parts, each is unpacked
    into the ``wait_for`` list handed to ``enqueue_nd_range_kernel``.

    :arg num_passed_args: number of ``argN`` parameters of the generated
        functions.
    :arg arg_types: if *None*, generic argument handling is generated;
        otherwise the value is forwarded, together with the remaining
        parameters, to ``generate_specific_arg_handling_body``.
    :returns: a tuple ``(module, enqueue_name)`` where *module* comes from
        the code generator's ``get_picklable_module()``.
    """
    arg_names = ["arg%d" % i for i in range(num_passed_args)]

    def gen_arg_setting(in_enqueue):
        # For in_enqueue=True the caller unpacks the result as
        # (code generator, wait_for_parts).
        if arg_types is not None:
            return generate_specific_arg_handling_body(
                    function_name, num_cl_args, arg_types,
                    warn_about_arg_count_bug=warn_about_arg_count_bug,
                    work_around_arg_count_bug=work_around_arg_count_bug,
                    in_enqueue=in_enqueue,
                    include_debug_code=include_debug_code)

        result = generate_generic_arg_handling_body(num_passed_args)
        return (result, []) if in_enqueue else result

    gen = PythonCodeGenerator()

    for preamble_line in (
            "from struct import pack",
            "from pyopencl import status_code",
            "import numpy as np",
            "import pyopencl._cl as _cl",
            ""):
        gen(preamble_line)

    # {{{ generate _enqueue

    enqueue_name = "enqueue_knl_%s" % function_name

    enqueue_params = ["self", "queue", "global_size", "local_size"]
    enqueue_params = enqueue_params + arg_names
    enqueue_params = enqueue_params + [
            "global_offset=None",
            "g_times_l=None",
            "allow_empty_ndrange=False",
            "wait_for=None"
            ]

    gen("def %s(%s):" % (enqueue_name, ", ".join(enqueue_params)))

    with Indentation(gen):
        subgen, wait_for_parts = gen_arg_setting(in_enqueue=True)
        gen.extend(subgen)

        if not wait_for_parts:
            wait_for_expr = "wait_for"
        else:
            unpacked_parts = ", ".join("*" + wfp for wfp in wait_for_parts)
            wait_for_expr = (
                    "[*(() if wait_for is None else wait_for), "
                    + unpacked_parts
                    + "]")

        # Using positional args here because pybind is slow with keyword args
        gen(f"""
            return _cl.enqueue_nd_range_kernel(queue, self, global_size, local_size,
                    global_offset, {wait_for_expr}, g_times_l,
                    allow_empty_ndrange)
            """)

    # }}}

    # {{{ generate set_args

    set_args_params = ["self"] + arg_names

    gen("")
    gen("def set_args(%s):" % (", ".join(set_args_params)))

    with Indentation(gen):
        gen.extend(gen_arg_setting(in_enqueue=False))

    # }}}

    module = gen.get_picklable_module(
            name=f"<pyopencl invoker for '{function_name}'>")
    return (module, enqueue_name)