def __init__(self, func):
    '''
    Wrap ``func`` and initialise launch geometry and cache state.

    :param func: the Python function this object wraps; its docstring is
        copied onto the instance.
    '''
    self.func = func
    self.__doc__ = func.__doc__

    # Launch geometry starts out unset.
    for geometry_attr in ('global_work_size',
                          'local_work_size',
                          'global_work_offset'):
        setattr(self, geometry_attr, None)

    # Caching state: empty in-memory cache, every switch off, and the
    # no-op file cache backend.
    self._cache = dict()
    self._development_mode = self._no_cache = False
    self._file_cacher = NoFileCache()
    self._use_cache_file = False
class kernel(object):
    '''
    Create an OpenCL kernel from a Python function.

    This class can be used as a decorator::

        @kernel
        def foo(a):
            ...
    '''

    def __init__(self, func):
        '''
        :param func: the Python function to translate into an OpenCL kernel.
        '''
        self.func = func
        self.__doc__ = self.func.__doc__

        # Launch geometry; copied onto every kernel built by compile_or_cly.
        self.global_work_size = None
        self.local_work_size = None
        self.global_work_offset = None

        # In-memory cache: {context: {sorted (argname, type) tuple: kernel}}.
        self._cache = {}
        # When true, translate() re-raises translation errors unchanged.
        self._development_mode = False
        # When true, compile() rebuilds even if the key is already cached.
        self._no_cache = False
        self._file_cacher = NoFileCache()
        # Not read anywhere inside this class.
        self._use_cache_file = False

    def clear_cache(self):
        '''
        Clear the binary cache in memory.
        '''
        self._cache.clear()

    def run_kernel(self, cl_kernel, queue, kernel_args, kwargs):
        '''
        Run a kernel. This method is subclassable for the task class.

        :param cl_kernel: the compiled kernel to invoke.
        :param queue: the queue the kernel is invoked with.
        :param kernel_args: dict of keyword arguments for the kernel.
        :param kwargs: the caller's keyword arguments; only
            ``global_work_size``, ``global_work_offset`` and
            ``local_work_size`` are read, each defaulting to ``None``.
        :return: whatever calling ``cl_kernel`` returns.
        '''
        event = cl_kernel(queue,
                          global_work_size=kwargs.get('global_work_size'),
                          global_work_offset=kwargs.get('global_work_offset'),
                          local_work_size=kwargs.get('local_work_size'),
                          **kernel_args)
        return event

    def _unpack(self, argnames, arglist, kwarg_types):
        '''
        Unpack memobject structure into two arguments.

        Constant arguments (``is_const``) are dropped. An argument whose type
        is a ``cl.contextual_memory`` with ``ndim != 0`` contributes an extra
        ``cly_<name>_info`` entry in addition to the argument itself.

        :param argnames: argument names, in signature order.
        :param arglist: argument values, parallel to ``argnames``.
        :param kwarg_types: mapping of argument name to its type.
        :return: dict of keyword arguments to pass to the compiled kernel.
        '''
        kernel_args = {}
        for name, arg in zip(argnames, arglist):
            if is_const(arg):
                continue

            arg_type = kwarg_types[name]
            if isinstance(arg_type, cl.contextual_memory) and arg_type.ndim != 0:
                kernel_args['cly_%s_info' % name] = arg_type._get_array_info(arg)
            kernel_args[name] = arg

        return kernel_args

    def _bind_arguments(self, ctx, args, kwargs):
        '''
        Resolve call arguments against the wrapped function's signature.

        :param ctx: context handed to ``typeof`` for every argument.
        :param args: positional arguments of the call.
        :param kwargs: keyword arguments of the call.
        :return: tuple of ``(argnames, arglist, kwarg_types)``.
        '''
        # ``__code__`` / ``__defaults__`` are the spellings shared by
        # Python 2.7 and Python 3 (``func_code`` / ``func_defaults`` are
        # Python 2 only).
        code = self.func.__code__
        argnames = code.co_varnames[:code.co_argcount]
        arglist = cl.kernel.parse_args(self.func.__name__, args, kwargs,
                                       argnames, self.func.__defaults__)
        kwarg_types = {name: typeof(ctx, arg)
                       for name, arg in zip(argnames, arglist)}
        return argnames, arglist, kwarg_types

    def __call__(self, queue_or_context, *args, **kwargs):
        '''
        Call this kernel as a function.

        :param queue_or_context: a queue or context. If this is a context
            a queue is created and finish is called before return.
        :param kwargs: ``global_work_size``, ``global_work_offset`` and
            ``local_work_size`` are treated as launch options; every other
            keyword is an argument of the wrapped function.

        :return: an OpenCL event.
        '''
        owns_queue = isinstance(queue_or_context, cl.Context)
        queue = cl.Queue(queue_or_context) if owns_queue else queue_or_context

        # Launch options must not take part in signature binding.
        call_kwargs = kwargs.copy()
        call_kwargs.pop('global_work_size', None)
        call_kwargs.pop('global_work_offset', None)
        call_kwargs.pop('local_work_size', None)

        argnames, arglist, kwarg_types = self._bind_arguments(
            queue.context, args, call_kwargs)

        cl_kernel = self.compile(queue.context, **kwarg_types)
        kernel_args = self._unpack(argnames, arglist, kwarg_types)

        # run_kernel receives the *original* kwargs so it can read the
        # launch options.
        event = self.run_kernel(cl_kernel, queue, kernel_args, kwargs)

        #FIXME: I don't like that this breaks encapsulation
        if isinstance(event, EventRecord):
            event.set_kernel_args(kernel_args)

        if owns_queue:
            queue.finish()

        return event

    def compile(self, ctx, source_only=False, cly_meta=None, **kwargs):
        '''
        Compile a kernel or lookup in cache.

        :param ctx: openCL context
        :param source_only: if true, return the generated source text
            instead of a kernel; the in-memory cache is bypassed.
        :param cly_meta: meta-information for inspecting the cache. (does nothing)
        :param kwargs: All other keyword arguments are used for type information.

        :return: An OpenCL kernel (or the source text when ``source_only``).
        '''
        if source_only:
            # Never store source text in the kernel cache: a later normal
            # call with the same types would get a string back.
            return self.compile_or_cly(ctx, source_only=True,
                                       cly_meta=cly_meta, **kwargs)

        cache = self._cache.setdefault(ctx, {})

        # Sort by argument name so the key does not depend on dict order.
        cache_key = tuple(sorted(kwargs.items(), key=lambda item: item[0]))

        #Check for in memory cache
        if cache_key not in cache or self._no_cache:
            cache[cache_key] = self.compile_or_cly(ctx, source_only=source_only,
                                                   cly_meta=cly_meta, **kwargs)

        return cache[cache_key]

    def source(self, ctx, *args, **kwargs):
        '''
        Get the source that would be compiled for specific argument types.

        .. note::

            This is meant to have a similar signature to the function call.
            i.e::

                print func.source(queue.context, arg1, arg2)
                func(queue, arg1, arg2)

        :param ctx: context handed to ``typeof`` for every argument.
        :return: the generated source text.
        '''
        _, _, kwarg_types = self._bind_arguments(ctx, args, kwargs)
        return self.compile_or_cly(ctx, source_only=True, **kwarg_types)

    @property
    def db_filename(self):
        '''
        get the filename that the binaries can be cached to
        '''
        from os.path import splitext
        base = splitext(self.func.__code__.co_filename)[0]
        return base + '.h5.cly'

    def compile_or_cly(self, ctx, source_only=False, cly_meta=None, **kwargs):
        '''
        internal

        Build the kernel for the argument types in ``kwargs``, consulting
        the file cache first.

        :param ctx: openCL context
        :param source_only: if true, translate only and return the source
            text; nothing is compiled and the file cache is not touched.
        :param cly_meta: passed through to the file cache on a miss.
        :param kwargs: mapping of argument name to type.

        :return: the source text when ``source_only``, otherwise a kernel
            carrying this object's launch geometry plus ``argtypes``,
            ``argnames`` and ``__defaults__``.
        '''
        if source_only:
            # The file cache's get() does not hand back the source text,
            # so it has to be regenerated by translating.
            _, _, _, source = self.translate(ctx, **kwargs)
            return source

        cache_key = self._file_cacher.generate_key(kwargs)

        if (ctx, self.func, cache_key) in self._file_cacher:
            program, kernel_name, args, defaults = self._file_cacher.get(
                ctx, self.func, cache_key)
        else:
            args, defaults, kernel_name, source = self.translate(ctx, **kwargs)
            program = self._compile(ctx, args, defaults, kernel_name, source)
            self._file_cacher.set(ctx, self.func, cache_key,
                                  args, defaults, kernel_name, cly_meta,
                                  source, program.binaries)

        cl_kernel = program.kernel(kernel_name)

        cl_kernel.global_work_size = self.global_work_size
        cl_kernel.local_work_size = self.local_work_size
        cl_kernel.global_work_offset = self.global_work_offset

        # ``args`` is a sequence of (name, type) pairs.
        cl_kernel.argtypes = [arg[1] for arg in args]
        cl_kernel.argnames = [arg[0] for arg in args]
        cl_kernel.__defaults__ = defaults

        return cl_kernel

    def translate(self, ctx, **kwargs):
        '''
        Translate this func into a tuple of (args, defaults, kernel_name, source)

        :param ctx: openCL context (not used by the translation itself).
        :param kwargs: mapping of argument name to type.
        :raises: in development mode a ``cast.CError`` propagates unchanged;
            otherwise ``error.exc(error.msg)`` is raised from a synthetic
            statement located at the offending line of the wrapped
            function's file.
        '''
        try:
            args, defaults, source, kernel_name = create_kernel_source(self.func, kwargs)
        except cast.CError as error:
            if self._development_mode:
                raise

            # Re-raise from code compiled with the user's filename and the
            # failing node's line number, so the traceback points into the
            # user's source rather than into the translator.
            redirect = ast.parse('raise error.exc(error.msg)')
            redirect.body[0].lineno = error.node.lineno
            filename = self.func.__code__.co_filename
            redirect_error_to_function = compile(redirect, filename, 'exec')
            eval(redirect_error_to_function)  # use the @cly.developer function decorator to turn this off and see stack trace ...

        return args, defaults, kernel_name, source

    def _compile(self, ctx, args, defaults, kernel_name, source):
        '''
        Compile a kernel without cache lookup.

        :param ctx: openCL context the program is created in.
        :param args: unused here.
        :param defaults: unused here.
        :param kernel_name: unused here.
        :param source: OpenCL source text to build.
        :return: the built program.
        :raises CLComileError: when the build fails; the message is the
            per-device build logs. The source is written to a temporary
            ``.cl`` file first.
        '''
        # NOTE(review): mktemp only picks a name and is documented as
        # race-prone; the file is created later, and only on build failure.
        # Consider mkstemp if this path can run in a shared temp directory.
        tmpfile = mktemp('.cl', 'clyther_')

        # The #line directive names tmpfile as the origin of the source.
        program = cl.Program(ctx, ('#line 1 "%s"\n' % (tmpfile)) + source)

        try:
            program.build()
        except cl.OpenCLException:
            log_lines = []
            for device, log in program.logs.items():
                log_lines.append(repr(device))
                log_lines.append(log)

            with open(tmpfile, 'w') as fp:
                fp.write(source)

            raise CLComileError('\n'.join(log_lines), program)

        for device, log in program.logs.items():
            if log:
                # Call form prints the same on Python 2 and Python 3.
                print(log)

        return program