def materialize_ast_refactor(self, key=None, args=None, arg_features=None):
    """Compile this kernel via the refactored AST path and cache the result.

    The runtime is materialized first. If ``key`` is already present in
    ``self.compiled_functions`` nothing else happens. Otherwise the kernel
    is created through ``_ti_core.create_kernel`` and the function body
    returned by ``self.get_function_body`` is stored under ``key``.

    Args:
        key: Cache key for ``self.compiled_functions``; its element at
            index 1 becomes part of the kernel name. Defaults to
            ``(self.func, 0)`` when ``None``.
        args: Forwarded to ``_get_tree_and_global_vars``.
        arg_features: Forwarded to ``ASTTransformerTotal``.

    Raises:
        TaichiSyntaxError: Raised by the generated callback when it runs
            while ``self.runtime.inside_kernel`` is already set.
    """
    # NOTE(review): marker local, presumably inspected by Taichi's traceback
    # filtering -- keep the name and keep it in this frame.
    _taichi_skip_traceback = 1

    cache_key = (self.func, 0) if key is None else key

    self.runtime.materialize()
    if cache_key in self.compiled_functions:
        return

    suffix = "_grad" if self.is_grad else ""
    kernel_name = "{}_c{}_{}{}".format(
        self.func.__name__, self.kernel_counter, cache_key[1], suffix)
    ti.trace("Compiling kernel {}...".format(kernel_name))

    tree, global_vars = _get_tree_and_global_vars(self, args)

    if self.is_grad:
        KernelSimplicityASTChecker(self.func).visit(tree)

    visitor = ASTTransformerTotal(
        excluded_parameters=self.template_slot_locations,
        func=self,
        arg_features=arg_features,
        globals=global_vars)

    # Shift node line numbers by the function's starting line in its file.
    first_line = oinspect.getsourcelines(self.func)[1]
    ast.increment_lineno(tree, first_line - 1)

    nested_kernel_msg = (
        "Kernels cannot call other kernels. I.e., nested kernels are not allowed. "
        "Please check if you have direct/indirect invocation of kernels within kernels. "
        "Note that some methods provided by the Taichi standard library may invoke kernels, "
        "and please move their invocations to Python-scope.")

    # Do not change the name of 'taichi_ast_generator'
    # The warning system needs this identifier to remove unnecessary messages
    def taichi_ast_generator():
        _taichi_skip_traceback = 1
        if self.runtime.inside_kernel:
            raise TaichiSyntaxError(nested_kernel_msg)
        self.runtime.inside_kernel = True
        self.runtime.current_kernel = self
        try:
            visitor.visit(tree)
        finally:
            # Always reset the runtime flags, even when the visit fails.
            self.runtime.inside_kernel = False
            self.runtime.current_kernel = None

    taichi_kernel = _ti_core.create_kernel(
        taichi_ast_generator, kernel_name, self.is_grad)
    self.kernel_cpp = taichi_kernel

    assert cache_key not in self.compiled_functions
    self.compiled_functions[cache_key] = self.get_function_body(taichi_kernel)
def materialize(self, key=None, args=None, arg_features=None):
    """Compile this kernel for ``key`` and cache the resulting function body.

    When ``impl.get_runtime().experimental_ast_refactor`` is truthy the work
    is delegated to ``materialize_ast_refactor`` and its result is returned.
    Otherwise the source of ``self.func`` is parsed, transformed by
    ``ASTTransformerTotal``, compiled and executed to obtain the transformed
    function, which is then invoked from the callback handed to
    ``_ti_core.create_kernel``.

    Args:
        key: Cache key for ``self.compiled_functions``; its element at
            index 1 becomes part of the kernel name. Defaults to
            ``(self.func, 0)`` when ``None``.
        args: Indexed at every position in ``self.template_slot_locations``
            to inject those values into the globals used by ``exec``.
        arg_features: Forwarded to ``ASTTransformerTotal``.

    Raises:
        TaichiSyntaxError: Raised by the generated callback when it runs
            while ``self.runtime.inside_kernel`` is already set.
    """
    if impl.get_runtime().experimental_ast_refactor:
        return self.materialize_ast_refactor(
            key=key, args=args, arg_features=arg_features)

    # NOTE(review): marker local, presumably inspected by Taichi's traceback
    # filtering -- keep the name and keep it in this frame.
    _taichi_skip_traceback = 1

    cache_key = (self.func, 0) if key is None else key

    self.runtime.materialize()
    if cache_key in self.compiled_functions:
        return

    suffix = "_grad" if self.is_grad else ""
    kernel_name = "{}_c{}_{}{}".format(
        self.func.__name__, self.kernel_counter, cache_key[1], suffix)
    ti.trace("Compiling kernel {}...".format(kernel_name))

    source_text = textwrap.dedent(oinspect.getsource(self.func))
    tree = ast.parse(source_text)

    func_def = tree.body[0]
    func_def.decorator_list = []

    local_vars = {}
    global_vars = _get_global_vars(self.func)

    # Bind plain-name annotations to the annotation objects already stored
    # on this kernel, so those names resolve when the tree is executed.
    for index, param in enumerate(func_def.args.args):
        annotation = param.annotation
        if isinstance(annotation, ast.Name):
            global_vars[annotation.id] = self.argument_annotations[index]

    return_annotation = func_def.returns
    if isinstance(return_annotation, ast.Name):
        global_vars[return_annotation.id] = self.return_type

    if self.is_grad:
        KernelSimplicityASTChecker(self.func).visit(tree)

    visitor = ASTTransformerTotal(
        excluded_parameters=self.template_slot_locations,
        func=self,
        arg_features=arg_features)
    visitor.visit(tree)

    # Shift node line numbers by the function's starting line in its file.
    first_line = oinspect.getsourcelines(self.func)[1]
    ast.increment_lineno(tree, first_line - 1)

    # inject template parameters into globals
    for slot in self.template_slot_locations:
        global_vars[self.argument_names[slot]] = args[slot]

    code_object = compile(
        tree, filename=oinspect.getsourcefile(self.func), mode='exec')
    exec(code_object, global_vars, local_vars)
    compiled = local_vars[self.func.__name__]

    nested_kernel_msg = (
        "Kernels cannot call other kernels. I.e., nested kernels are not allowed. "
        "Please check if you have direct/indirect invocation of kernels within kernels. "
        "Note that some methods provided by the Taichi standard library may invoke kernels, "
        "and please move their invocations to Python-scope.")

    # Do not change the name of 'taichi_ast_generator'
    # The warning system needs this identifier to remove unnecessary messages
    def taichi_ast_generator():
        _taichi_skip_traceback = 1
        if self.runtime.inside_kernel:
            raise TaichiSyntaxError(nested_kernel_msg)
        self.runtime.inside_kernel = True
        self.runtime.current_kernel = self
        try:
            compiled()
        finally:
            # Always reset the runtime flags, even when the call fails.
            self.runtime.inside_kernel = False
            self.runtime.current_kernel = None

    taichi_kernel = _ti_core.create_kernel(
        taichi_ast_generator, kernel_name, self.is_grad)
    self.kernel_cpp = taichi_kernel

    assert cache_key not in self.compiled_functions
    self.compiled_functions[cache_key] = self.get_function_body(taichi_kernel)