def do_compile(self, key, args):
    """Compile ``self.func`` from its source into a Python callable.

    The source is re-parsed and the decorators are stripped from the function
    definition. ``ASTTransformerTotal`` (``is_kernel=False``) then rewrites
    the AST in place, and the rewritten module is executed to obtain the
    function object.

    Args:
        key: Compilation key. Only read when the runtime's
            ``experimental_real_function`` flag is set. In that case
            ``key.instance_id`` indexes ``self.compiled`` and
            ``self.taichi_functions``, and ``key`` itself is passed to
            ``_ti_core.create_function``.
        args: Call arguments. Only read when ``experimental_real_function``
            is set. The entries at ``self.template_slot_locations`` are
            injected as globals named after the matching parameters.

    Side effects:
        With ``experimental_real_function`` set, stores the callable in
        ``self.compiled[key.instance_id]`` and registers it as the body of a
        new ``_ti_core`` function stored in ``self.taichi_functions``.
        Otherwise ``self.compiled`` is replaced by the callable.
    """
    parsed = ast.parse(textwrap.dedent(oinspect.getsource(self.func)))
    # Without decorators, executing the definition yields the bare function.
    parsed.body[0].decorator_list = []
    ASTTransformerTotal(is_kernel=False, func=self).visit(parsed)
    # Offset by the definition's starting line so compiled code reports
    # line numbers relative to the real source file.
    ast.increment_lineno(parsed, oinspect.getsourcelines(self.func)[1] - 1)

    namespace = {}
    scope = _get_global_vars(self.func)
    if impl.get_runtime().experimental_real_function:
        # Template arguments are exposed to the body as globals.
        for slot in self.template_slot_locations:
            scope[self.argument_names[slot]] = args[slot]
    code = compile(parsed,
                   filename=oinspect.getsourcefile(self.func),
                   mode='exec')
    exec(code, scope, namespace)

    if not impl.get_runtime().experimental_real_function:
        self.compiled = namespace[self.func.__name__]
        return
    instance_id = key.instance_id
    self.compiled[instance_id] = namespace[self.func.__name__]
    self.taichi_functions[instance_id] = _ti_core.create_function(key)
    self.taichi_functions[instance_id].set_function_body(
        self.compiled[instance_id])
def materialize_ast_refactor(self, key=None, args=None, arg_features=None):
    """Build the Taichi kernel for ``key`` using the refactored AST pipeline.

    Nothing is compiled if ``key`` is already in ``self.compiled_functions``.
    Otherwise the parsed tree and its globals come from
    ``_get_tree_and_global_vars``, and an ``ASTTransformerTotal`` visitor is
    prepared over them. The visit itself is deferred to the
    ``taichi_ast_generator`` callback handed to ``_ti_core.create_kernel``.

    Args:
        key: Cache key into ``self.compiled_functions``; defaults to
            ``(self.func, 0)``. ``key[1]`` is embedded in the kernel name.
        args: Forwarded to ``_get_tree_and_global_vars``.
        arg_features: Forwarded to ``ASTTransformerTotal``.

    Raises:
        TaichiSyntaxError: from ``taichi_ast_generator`` when it runs while
            ``self.runtime.inside_kernel`` is already set (nested kernels).
    """
    # NOTE(review): presumably inspected by Taichi's traceback handling via
    # frame locals -- keep this name unchanged.
    _taichi_skip_traceback = 1
    if key is None:
        key = (self.func, 0)
    self.runtime.materialize()
    if key in self.compiled_functions:
        return

    grad_suffix = "_grad" if self.is_grad else ""
    kernel_name = "{}_c{}_{}{}".format(self.func.__name__,
                                        self.kernel_counter, key[1],
                                        grad_suffix)
    ti.trace("Compiling kernel {}...".format(kernel_name))

    tree, global_vars = _get_tree_and_global_vars(self, args)
    if self.is_grad:
        KernelSimplicityASTChecker(self.func).visit(tree)
    visitor = ASTTransformerTotal(
        excluded_parameters=self.template_slot_locations,
        func=self,
        arg_features=arg_features,
        globals=global_vars)
    first_line = oinspect.getsourcelines(self.func)[1]
    ast.increment_lineno(tree, first_line - 1)

    # Do not change the name of 'taichi_ast_generator'
    # The warning system needs this identifier to remove unnecessary messages
    def taichi_ast_generator():
        _taichi_skip_traceback = 1
        if self.runtime.inside_kernel:
            raise TaichiSyntaxError(
                "Kernels cannot call other kernels. I.e., nested kernels are not allowed. Please check if you have direct/indirect invocation of kernels within kernels. Note that some methods provided by the Taichi standard library may invoke kernels, and please move their invocations to Python-scope."
            )
        self.runtime.inside_kernel = True
        self.runtime.current_kernel = self
        try:
            visitor.visit(tree)
        finally:
            # Always clear the "building a kernel" state, even on failure.
            self.runtime.inside_kernel = False
            self.runtime.current_kernel = None

    taichi_kernel = _ti_core.create_kernel(taichi_ast_generator, kernel_name,
                                           self.is_grad)
    self.kernel_cpp = taichi_kernel
    assert key not in self.compiled_functions
    self.compiled_functions[key] = self.get_function_body(taichi_kernel)
def _get_tree_and_ctx(self,
                      excluded_parameters=(),
                      is_kernel=True,
                      arg_features=None,
                      args=None,
                      ast_builder=None,
                      is_real_function=False):
    """Parse ``self.func`` and build the context for transforming it.

    Args:
        excluded_parameters: Passed through to ``ASTTransformerContext``.
        is_kernel: Passed through; when true (or when ``is_real_function``
            is true), template arguments are injected into the globals.
        arg_features: Passed through to ``ASTTransformerContext``.
        args: Call arguments. Indexed by ``self.template_slot_locations``
            when template injection happens, and passed through as
            ``argument_data``.
        ast_builder: Passed through to ``ASTTransformerContext``.
        is_real_function: Passed through; also enables template injection.

    Returns:
        ``(tree, ctx)``: ``tree`` is the parsed module with decorators removed
        from its first statement, and ``ctx`` is an ``ASTTransformerContext``
        carrying the augmented globals and the normalized source lines.
    """
    source_file = oinspect.getsourcefile(self.func)
    raw_lines, first_lineno = oinspect.getsourcelines(self.func)
    # textwrap.fill expands tabs to 4 columns and strips trailing whitespace
    # (including the newline) from each line; width=9999 avoids re-wrapping.
    normalized = [
        textwrap.fill(line, tabsize=4, width=9999) for line in raw_lines
    ]
    tree = ast.parse(textwrap.dedent("\n".join(normalized)))
    definition = tree.body[0]
    definition.decorator_list = []

    scope = _get_global_vars(self.func)
    # Bind bare-name annotations to the resolved annotation objects so they
    # resolve when the transformed body is evaluated against these globals.
    for position, parameter in enumerate(definition.args.args):
        annotation = parameter.annotation
        if isinstance(annotation, ast.Name):
            scope[annotation.id] = self.argument_annotations[position]
    return_annotation = definition.returns
    if isinstance(return_annotation, ast.Name):
        scope[return_annotation.id] = self.return_type

    if is_kernel or is_real_function:
        # Template arguments are exposed to the body as globals.
        for slot in self.template_slot_locations:
            scope[self.argument_names[slot]] = args[slot]

    context = ASTTransformerContext(excluded_parameters=excluded_parameters,
                                    is_kernel=is_kernel,
                                    func=self,
                                    arg_features=arg_features,
                                    global_vars=scope,
                                    argument_data=args,
                                    src=normalized,
                                    start_lineno=first_lineno,
                                    file=source_file,
                                    ast_builder=ast_builder,
                                    is_real_function=is_real_function)
    return tree, context
def do_compile(self):
    """Compile ``self.func`` from its source and store it in ``self.compiled``.

    The source is de-indented with ``_remove_indent`` and parsed, and the
    decorators are stripped. ``ASTTransformer`` (``is_kernel=False``)
    rewrites the AST, and the rewritten module is executed with the
    function's globals to obtain the function object.
    """
    parsed = ast.parse(_remove_indent(oinspect.getsource(self.func)))
    # Without decorators, executing the definition yields the bare function.
    parsed.body[0].decorator_list = []
    ASTTransformer(is_kernel=False, func=self).visit(parsed)
    # Offset by the definition's starting line so compiled code reports
    # line numbers relative to the real source file.
    ast.increment_lineno(parsed, oinspect.getsourcelines(self.func)[1] - 1)

    namespace = {}
    scope = _get_global_vars(self.func)
    code = compile(parsed,
                   filename=oinspect.getsourcefile(self.func),
                   mode='exec')
    exec(code, scope, namespace)
    self.compiled = namespace[self.func.__name__]
def do_compile(self, key, args):
    """Register a lazily-visited AST as the body of a Taichi function.

    The source is parsed, decorators are stripped and line numbers are
    shifted. Template arguments are injected into the globals. The
    ``ASTTransformerTotal`` visit is not run here: it is wrapped in a
    zero-argument callable. That callable is stored in
    ``self.compiled[key.instance_id]`` and set as the body of a new
    ``_ti_core`` function stored in ``self.taichi_functions``.

    Args:
        key: Compilation key; ``key.instance_id`` indexes the per-instance
            caches and ``key`` is passed to ``_ti_core.create_function``.
        args: Call arguments; the entries at ``self.template_slot_locations``
            are injected as globals named after the matching parameters.
    """
    parsed = ast.parse(textwrap.dedent(oinspect.getsource(self.func)))
    parsed.body[0].decorator_list = []
    # Offset by the definition's starting line so reported line numbers are
    # relative to the real source file.
    first_line = oinspect.getsourcelines(self.func)[1]
    ast.increment_lineno(parsed, first_line - 1)

    scope = _get_global_vars(self.func)
    # Template arguments are exposed to the body as globals.
    for slot in self.template_slot_locations:
        scope[self.argument_names[slot]] = args[slot]

    transformer = ASTTransformerTotal(is_kernel=False,
                                      func=self,
                                      global_vars=scope)
    instance_id = key.instance_id
    # The visit is deferred until the registered body is invoked.
    self.compiled[instance_id] = lambda: transformer.visit(parsed)
    self.taichi_functions[instance_id] = _ti_core.create_function(key)
    self.taichi_functions[instance_id].set_function_body(
        self.compiled[instance_id])
def materialize(self, key=None, args=None, arg_features=None):
    """Compile ``self.func`` into a Taichi kernel for ``key`` if not cached.

    When the runtime's ``experimental_ast_refactor`` flag is set, this
    delegates to ``materialize_ast_refactor``. Otherwise the steps are:

    1. Parse the function source.
    2. Bind bare-name annotations in the globals.
    3. Rewrite the tree with ``ASTTransformerTotal``.
    4. Inject template arguments as globals.
    5. Execute the rewritten module to obtain a Python callable.

    That callable is invoked from ``taichi_ast_generator``, which is handed
    to ``_ti_core.create_kernel``.

    Args:
        key: Cache key into ``self.compiled_functions``; defaults to
            ``(self.func, 0)``. ``key[1]`` is embedded in the kernel name.
        args: Call arguments; the entries at ``self.template_slot_locations``
            are injected as globals named after the matching parameters.
        arg_features: Forwarded to ``ASTTransformerTotal``.

    Raises:
        TaichiSyntaxError: from ``taichi_ast_generator`` when it runs while
            ``self.runtime.inside_kernel`` is already set (nested kernels).
    """
    if impl.get_runtime().experimental_ast_refactor:
        return self.materialize_ast_refactor(key=key,
                                             args=args,
                                             arg_features=arg_features)
    # NOTE(review): presumably inspected by Taichi's traceback handling via
    # frame locals -- keep this name unchanged.
    _taichi_skip_traceback = 1
    if key is None:
        key = (self.func, 0)
    self.runtime.materialize()
    if key in self.compiled_functions:
        return
    grad_suffix = ""
    if self.is_grad:
        grad_suffix = "_grad"
    kernel_name = "{}_c{}_{}{}".format(self.func.__name__,
                                        self.kernel_counter, key[1],
                                        grad_suffix)
    ti.trace("Compiling kernel {}...".format(kernel_name))

    src = textwrap.dedent(oinspect.getsource(self.func))
    tree = ast.parse(src)

    func_body = tree.body[0]
    func_body.decorator_list = []

    local_vars = {}
    global_vars = _get_global_vars(self.func)

    # Bind bare-name annotations to the resolved annotation objects so they
    # resolve when the compiled definition is executed with these globals.
    for i, arg in enumerate(func_body.args.args):
        anno = arg.annotation
        if isinstance(anno, ast.Name):
            global_vars[anno.id] = self.argument_annotations[i]

    if isinstance(func_body.returns, ast.Name):
        global_vars[func_body.returns.id] = self.return_type

    if self.is_grad:
        KernelSimplicityASTChecker(self.func).visit(tree)

    visitor = ASTTransformerTotal(
        excluded_parameters=self.template_slot_locations,
        func=self,
        arg_features=arg_features)

    visitor.visit(tree)

    # Offset by the definition's starting line so compiled code reports
    # line numbers relative to the real source file.
    ast.increment_lineno(tree, oinspect.getsourcelines(self.func)[1] - 1)

    # inject template parameters into globals
    for i in self.template_slot_locations:
        template_var_name = self.argument_names[i]
        global_vars[template_var_name] = args[i]

    exec(
        compile(tree,
                filename=oinspect.getsourcefile(self.func),
                mode='exec'), global_vars, local_vars)
    compiled = local_vars[self.func.__name__]

    # Do not change the name of 'taichi_ast_generator'
    # The warning system needs this identifier to remove unnecessary messages
    def taichi_ast_generator():
        _taichi_skip_traceback = 1
        if self.runtime.inside_kernel:
            # Adjacent literals keep every source line short without putting
            # a raw newline inside a string (the original literal was split
            # across two physical lines, which is a SyntaxError).
            raise TaichiSyntaxError(
                "Kernels cannot call other kernels. I.e., nested kernels are "
                "not allowed. Please check if you have direct/indirect "
                "invocation of kernels within kernels. Note that some methods "
                "provided by the Taichi standard library may invoke kernels, "
                "and please move their invocations to Python-scope.")
        self.runtime.inside_kernel = True
        self.runtime.current_kernel = self
        try:
            compiled()
        finally:
            # Always clear the "building a kernel" state, even on failure.
            self.runtime.inside_kernel = False
            self.runtime.current_kernel = None

    taichi_kernel = _ti_core.create_kernel(taichi_ast_generator, kernel_name,
                                           self.is_grad)

    self.kernel_cpp = taichi_kernel

    assert key not in self.compiled_functions
    self.compiled_functions[key] = self.get_function_body(taichi_kernel)
def __init__(self, func):
    """Record where ``func`` is defined.

    Args:
        func: The Python function being processed. Its source file, the
            first line of its source (from ``getsourcelines``) and its
            ``__name__`` are stored on the instance.
    """
    super().__init__()
    self._func_file = oinspect.getsourcefile(func)
    first_line = oinspect.getsourcelines(func)[1]
    self._func_lineno = first_line
    self._func_name = func.__name__
    # Starts empty here; it is not otherwise touched by this constructor.
    self._scope_guards = []