def _parse_function_object(func):
    """Parse a callable into an ``ast.FunctionDef`` node plus its source.

    Args:
        func (callable): A named function or a lambda expression.

    Returns:
        tuple: ``(node, source)``.

        * For a named function, ``source`` is the (dedented) source text of
          the function and ``node`` is the single top-level statement
          obtained by parsing it.
        * For a lambda, ``node`` is a synthesized ``ast.FunctionDef`` named
          ``_lambda_kernel`` whose body returns the lambda expression, and
          ``source`` is the text in which the lambda was searched for (the
          output of ``jit._getsource_func`` when it is set, otherwise the
          lines reported by ``inspect.getsourcelines``).

    Raises:
        ValueError: If ``func`` is not callable, if the file defining a
            lambda cannot be located, or if the lambda cannot be identified
            unambiguously (none or several found in the candidate lines).
    """
    # Imported locally so this function does not depend on a file-level
    # import that may not exist.
    import tokenize

    if not callable(func):
        raise ValueError('`func` must be a callable object.')

    if func.__name__ != '<lambda>':
        if jit._getsource_func is None:
            lines = inspect.getsource(func).split('\n')
            first = lines[0]
            # Strip the first line's exact leading whitespace, and only when
            # a line really starts with it. This also handles tab indentation
            # and never touches spaces in the middle of a line.
            indent = first[:len(first) - len(first.lstrip())]
            source = '\n'.join(
                line[len(indent):] if line.startswith(indent) else line
                for line in lines)
        else:
            source = jit._getsource_func(func)
        tree = ast.parse(source)
        assert isinstance(tree, ast.Module)
        assert len(tree.body) == 1
        return tree.body[0], source

    # Lambda: it has no standalone statement, so parse the enclosing source
    # and look for ``ast.Lambda`` nodes in the relevant line range.
    if jit._getsource_func is not None:
        full_source = jit._getsource_func(func)
        # The custom getter gives no line information; accept any line.
        start_line, end_line = 0, math.inf
        source = full_source
    else:
        try:
            filename = inspect.getsourcefile(func)
        except TypeError:
            filename = None
        if filename is None:
            raise ValueError(
                f'JIT needs access to Python source for {func}'
                ' but could not be located')
        # tokenize.open honours the source file's declared encoding
        # (PEP 263) instead of relying on the locale default.
        with tokenize.open(filename) as f:
            full_source = f.read()
        source_lines, start_line = inspect.getsourcelines(func)
        end_line = start_line + len(source_lines)
        source = ''.join(source_lines)

    tree = ast.parse(full_source)

    nodes = [
        node for node in ast.walk(tree)
        if isinstance(node, ast.Lambda)
        and start_line <= node.lineno < end_line
    ]
    if not nodes:
        raise ValueError(
            f'JIT could not find the lambda expression for {func}'
            ' in its source code.')
    if len(nodes) > 1:
        raise ValueError('Multiple callables are found near the'
                         f' definition of {func}, and JIT could not'
                         ' identify the source code for it.')
    node = nodes[0]
    # Wrap the lambda as ``def _lambda_kernel(<args>): return <body>``.
    return ast.FunctionDef(
        name='_lambda_kernel',
        args=node.args,
        body=[ast.Return(node.body)],
        decorator_list=[],
        returns=None,
        type_comment=None,
    ), source
def transpile(func, attributes, mode, in_types, ret_type):
    """Transpile the target function.

    Args:
        func (function): Target function.
        attributes (list of str): Attributes of the generated CUDA function.
        mode ('numpy' or 'cuda'): The rule for typecast.
        in_types (list of _types.TypeBase): Types of the arguments.
        ret_type (_types.TypeBase or None): Type of the return value.

    Returns:
        Result: Holds the function name (``func.__name__``), the generated
        code with the environment's preambles prepended, and the return
        type taken from the transpilation environment.

    Raises:
        ValueError: If ``func`` is not callable.
        NotImplementedError: If ``func`` is a lambda expression.
    """
    if not callable(func):
        raise ValueError('`func` must be a callable object.')

    if func.__name__ == '<lambda>':
        raise NotImplementedError('Lambda function is not supported.')

    attributes = ' '.join(attributes)

    # ``jit._getsource_func`` may be unset (None); fall back to the standard
    # ``inspect.getsource``, as ``_parse_function_object`` does.
    getsource = jit._getsource_func
    if getsource is None:
        getsource = inspect.getsource
    source = getsource(func)

    lines = source.split('\n')
    first = lines[0]
    # Strip the first line's exact leading whitespace, and only when a line
    # really starts with it. This also handles tab indentation and never
    # touches spaces in the middle of a line.
    indent = first[:len(first) - len(first.lstrip())]
    source = '\n'.join(
        line[len(indent):] if line.startswith(indent) else line
        for line in lines)

    cvars = inspect.getclosurevars(func)
    # Merge the name scopes without failing on duplicates: ``dict(**a, **b)``
    # raises TypeError when a key is repeated. Nonlocals shadow globals, and
    # builtins only fill in names that are still missing.
    consts = {**cvars.globals, **cvars.nonlocals}
    for name, value in cvars.builtins.items():
        consts.setdefault(name, value)

    tree = ast.parse(source)
    assert isinstance(tree, ast.Module)
    assert len(tree.body) == 1
    cuda_code, env = _transpile_function(
        tree.body[0], attributes, mode, consts, in_types, ret_type,
        source=source)
    # Preambles go before the function body, one per line.
    cuda_code = ''.join([code + '\n' for code in env.preambles]) + cuda_code
    return Result(
        func_name=func.__name__,
        code=cuda_code,
        return_type=env.ret_type,
    )