Exemple #1
0
    def materialize_ast_refactor(self, key=None, args=None, arg_features=None):
        """Build and cache the kernel for ``key`` using the AST-transformer path.

        The runtime is materialized first. If ``key`` already has an entry in
        ``self.compiled_functions`` nothing else happens. Otherwise the kernel
        AST is obtained, a core kernel is created around a generator callback
        that visits that AST, and the resulting function body is cached.

        Args:
            key: Cache key into ``self.compiled_functions``. Defaults to
                ``(self.func, 0)``. ``key[1]`` is embedded in the kernel name.
            args: Forwarded to ``_get_tree_and_global_vars``.
            arg_features: Forwarded to ``ASTTransformerTotal``.

        Raises:
            TaichiSyntaxError: From the generator callback, when it runs while
                ``self.runtime.inside_kernel`` is already set.
        """
        # NOTE(review): presumably looked up by name when filtering tracebacks;
        # keep this local exactly as spelled.
        _taichi_skip_traceback = 1

        key = (self.func, 0) if key is None else key
        self.runtime.materialize()
        if key in self.compiled_functions:
            # Already compiled for this key.
            return

        grad_suffix = "_grad" if self.is_grad else ""
        name_fields = (self.func.__name__, self.kernel_counter, key[1],
                       grad_suffix)
        kernel_name = "{}_c{}_{}{}".format(*name_fields)
        trace_message = "Compiling kernel {}...".format(kernel_name)
        ti.trace(trace_message)

        tree, global_vars = _get_tree_and_global_vars(self, args)

        # The simplicity check is only applied to gradient kernels.
        if self.is_grad:
            simplicity_checker = KernelSimplicityASTChecker(self.func)
            simplicity_checker.visit(tree)

        transformer = ASTTransformerTotal(
            excluded_parameters=self.template_slot_locations,
            func=self,
            arg_features=arg_features,
            globals=global_vars)

        # Shift node line numbers by the function's starting line in its file.
        first_source_line = oinspect.getsourcelines(self.func)[1]
        ast.increment_lineno(tree, first_source_line - 1)

        # Do not change the name of 'taichi_ast_generator'
        # The warning system needs this identifier to remove unnecessary messages
        def taichi_ast_generator():
            _taichi_skip_traceback = 1
            if self.runtime.inside_kernel:
                raise TaichiSyntaxError(
                    "Kernels cannot call other kernels. I.e., nested kernels are not allowed. Please check if you have direct/indirect invocation of kernels within kernels. Note that some methods provided by the Taichi standard library may invoke kernels, and please move their invocations to Python-scope."
                )
            # Flag the runtime as being inside this kernel for the duration of
            # the AST visit; the flags are always reset afterwards.
            self.runtime.inside_kernel = True
            self.runtime.current_kernel = self
            try:
                transformer.visit(tree)
            finally:
                self.runtime.inside_kernel = False
                self.runtime.current_kernel = None

        core_kernel = _ti_core.create_kernel(taichi_ast_generator,
                                             kernel_name, self.is_grad)
        self.kernel_cpp = core_kernel

        assert key not in self.compiled_functions
        function_body = self.get_function_body(core_kernel)
        self.compiled_functions[key] = function_body
Exemple #2
0
    def materialize(self, key=None, args=None, arg_features=None):
        """Build and cache the kernel for ``key``.

        When the runtime's ``experimental_ast_refactor`` flag is set, this
        delegates to ``materialize_ast_refactor``. Otherwise the function
        source is parsed, transformed by ``ASTTransformerTotal``, compiled and
        executed to obtain a Python callable, and a core kernel is created
        around a generator callback that invokes that callable. The resulting
        function body is stored in ``self.compiled_functions[key]``.

        Args:
            key: Cache key into ``self.compiled_functions``. Defaults to
                ``(self.func, 0)``. ``key[1]`` is embedded in the kernel name.
            args: Indexed by each entry of ``self.template_slot_locations`` to
                inject template arguments into the globals of the compiled
                function.
            arg_features: Forwarded to ``ASTTransformerTotal``.

        Returns:
            Whatever ``materialize_ast_refactor`` returns when delegating;
            otherwise ``None``.

        Raises:
            TaichiSyntaxError: From the generator callback, when it runs while
                ``self.runtime.inside_kernel`` is already set.
        """
        runtime_config = impl.get_runtime()
        if runtime_config.experimental_ast_refactor:
            return self.materialize_ast_refactor(key=key,
                                                 args=args,
                                                 arg_features=arg_features)

        # NOTE(review): presumably looked up by name when filtering tracebacks;
        # keep this local exactly as spelled.
        _taichi_skip_traceback = 1

        key = (self.func, 0) if key is None else key
        self.runtime.materialize()
        if key in self.compiled_functions:
            # Already compiled for this key.
            return

        grad_suffix = "_grad" if self.is_grad else ""
        name_fields = (self.func.__name__, self.kernel_counter, key[1],
                       grad_suffix)
        kernel_name = "{}_c{}_{}{}".format(*name_fields)
        trace_message = "Compiling kernel {}...".format(kernel_name)
        ti.trace(trace_message)

        source_text = textwrap.dedent(oinspect.getsource(self.func))
        tree = ast.parse(source_text)

        # Drop the decorators from the parsed function definition.
        func_def = tree.body[0]
        func_def.decorator_list = []

        local_vars = {}
        global_vars = _get_global_vars(self.func)

        # Bind plain-name annotations to the annotation objects recorded on
        # this kernel, so the names resolve when the code is executed.
        for position, param in enumerate(func_def.args.args):
            annotation = param.annotation
            if not isinstance(annotation, ast.Name):
                continue
            global_vars[annotation.id] = self.argument_annotations[position]

        return_annotation = func_def.returns
        if isinstance(return_annotation, ast.Name):
            global_vars[return_annotation.id] = self.return_type

        # The simplicity check is only applied to gradient kernels.
        if self.is_grad:
            simplicity_checker = KernelSimplicityASTChecker(self.func)
            simplicity_checker.visit(tree)

        transformer = ASTTransformerTotal(
            excluded_parameters=self.template_slot_locations,
            func=self,
            arg_features=arg_features)
        transformer.visit(tree)

        # Shift node line numbers by the function's starting line in its file.
        first_source_line = oinspect.getsourcelines(self.func)[1]
        ast.increment_lineno(tree, first_source_line - 1)

        # inject template parameters into globals
        for slot in self.template_slot_locations:
            template_name = self.argument_names[slot]
            global_vars[template_name] = args[slot]

        code_object = compile(tree,
                              filename=oinspect.getsourcefile(self.func),
                              mode='exec')
        exec(code_object, global_vars, local_vars)
        # The executed module defines the function under its original name.
        transformed_func = local_vars[self.func.__name__]

        # Do not change the name of 'taichi_ast_generator'
        # The warning system needs this identifier to remove unnecessary messages
        def taichi_ast_generator():
            _taichi_skip_traceback = 1
            if self.runtime.inside_kernel:
                raise TaichiSyntaxError(
                    "Kernels cannot call other kernels. I.e., nested kernels are not allowed. Please check if you have direct/indirect invocation of kernels within kernels. Note that some methods provided by the Taichi standard library may invoke kernels, and please move their invocations to Python-scope."
                )
            # Flag the runtime as being inside this kernel while the
            # transformed function runs; the flags are always reset afterwards.
            self.runtime.inside_kernel = True
            self.runtime.current_kernel = self
            try:
                transformed_func()
            finally:
                self.runtime.inside_kernel = False
                self.runtime.current_kernel = None

        core_kernel = _ti_core.create_kernel(taichi_ast_generator,
                                             kernel_name, self.is_grad)
        self.kernel_cpp = core_kernel

        assert key not in self.compiled_functions
        function_body = self.get_function_body(core_kernel)
        self.compiled_functions[key] = function_body