예제 #1
0
    def do_compile(self, key, args):
        """Rebuild ``self.func`` from its transformed AST and store the result.

        The source of ``self.func`` is parsed, run through ``ASTTransformer``
        (with ``is_kernel=False``), compiled and executed; the function object
        that the executed code binds under ``self.func.__name__`` is then
        stored on ``self``.

        Args:
            key: Object exposing ``instance_id``; only used when the runtime's
                ``experimental_real_function`` flag is set, where it indexes
                ``self.compiled`` / ``self.taichi_functions`` and is passed to
                ``_ti_core.create_function``.
            args: Indexable call arguments; ``args[i]`` is read for every
                ``i`` in ``self.template_slot_locations`` when the flag is set.
        """
        source_text = _remove_indent(oinspect.getsource(self.func))
        module_ast = ast.parse(source_text)

        # Clear the decorators of the first top-level statement so that
        # executing the definition below yields the undecorated function.
        module_ast.body[0].decorator_list = []

        ASTTransformer(is_kernel=False, func=self).visit(module_ast)

        # Shift line numbers by (first source line - 1) so they refer to
        # positions in the original file rather than in the extracted snippet.
        first_line = oinspect.getsourcelines(self.func)[1]
        ast.increment_lineno(module_ast, first_line - 1)

        exec_locals = {}
        exec_globals = _get_global_vars(self.func)

        if impl.get_runtime().experimental_real_function:
            # inject template parameters into globals
            for slot in self.template_slot_locations:
                exec_globals[self.argument_names[slot]] = args[slot]

        code_object = compile(module_ast,
                              filename=oinspect.getsourcefile(self.func),
                              mode='exec')
        exec(code_object, exec_globals, exec_locals)

        if not impl.get_runtime().experimental_real_function:
            # Single-instance mode: ``self.compiled`` is the function itself.
            self.compiled = exec_locals[self.func.__name__]
            return

        # Real-function mode: results are stored per ``key.instance_id``.
        self.compiled[key.instance_id] = exec_locals[self.func.__name__]
        self.taichi_functions[key.instance_id] = _ti_core.create_function(key)
        registered = self.taichi_functions[key.instance_id]
        registered.set_function_body(self.compiled[key.instance_id])
예제 #2
0
    def do_compile(self):
        """Rebuild ``self.func`` from its transformed AST and cache it.

        The source of ``self.func`` is parsed, run through ``ASTTransformer``
        (with ``is_kernel=False``), compiled and executed. The function object
        that the executed code binds under ``self.func.__name__`` is stored in
        ``self.compiled``.
        """
        source_text = _remove_indent(oinspect.getsource(self.func))
        module_ast = ast.parse(source_text)

        # Clear the decorators of the first top-level statement so that
        # executing the definition below yields the undecorated function.
        module_ast.body[0].decorator_list = []

        ASTTransformer(is_kernel=False, func=self).visit(module_ast)

        # Shift line numbers by (first source line - 1) so they refer to
        # positions in the original file rather than in the extracted snippet.
        first_line = oinspect.getsourcelines(self.func)[1]
        ast.increment_lineno(module_ast, first_line - 1)

        exec_locals = {}
        exec_globals = _get_global_vars(self.func)
        code_object = compile(module_ast,
                              filename=oinspect.getsourcefile(self.func),
                              mode='exec')
        exec(code_object, exec_globals, exec_locals)

        self.compiled = exec_locals[self.func.__name__]
예제 #3
0
    def materialize(self, key=None, args=None, arg_features=None):
        """Compile this kernel for ``key`` and cache the result.

        Returns immediately when ``key`` already has an entry in
        ``self.compiled_functions``. Otherwise the source of ``self.func`` is
        parsed, transformed by ``ASTTransformer``, executed to obtain the
        rewritten Python function, and wrapped in a kernel created through
        ``_ti_core.create_kernel``. The value returned by
        ``self.get_function_body`` is stored under ``key``.

        Args:
            key: Cache key. Defaults to ``(self.func, 0)``; ``key[1]`` is
                embedded in the generated kernel name.
            args: Indexable call arguments; ``args[i]`` is read for every
                ``i`` in ``self.template_slot_locations``.
            arg_features: Forwarded unchanged to ``ASTTransformer``.

        Raises:
            TaichiSyntaxError: Raised from the AST generator callback when it
                runs while ``self.runtime.inside_kernel`` is already set.
        """
        # Marker local kept from the original code; presumably inspected by
        # Taichi's traceback filtering -- do not remove or rename.
        _taichi_skip_traceback = 1
        if key is None:
            key = (self.func, 0)
        if not self.runtime.materialized:
            self.runtime.materialize()
        if key in self.compiled_functions:
            return
        grad_suffix = "_grad" if self.is_grad else ""
        kernel_name = "{}_c{}_{}{}".format(self.func.__name__,
                                           self.kernel_counter, key[1],
                                           grad_suffix)
        ti.trace("Compiling kernel {}...".format(kernel_name))

        src = _remove_indent(oinspect.getsource(self.func))
        tree = ast.parse(src)

        # Clear the decorators so that executing the definition below yields
        # the undecorated function.
        func_body = tree.body[0]
        func_body.decorator_list = []

        local_vars = {}
        global_vars = _get_global_vars(self.func)

        # Make every plain-name annotation resolvable when the code is exec'd
        # by binding that name to the annotation object recorded on ``self``.
        for i, arg in enumerate(func_body.args.args):
            anno = arg.annotation
            if isinstance(anno, ast.Name):
                global_vars[anno.id] = self.argument_annotations[i]

        if isinstance(func_body.returns, ast.Name):
            global_vars[func_body.returns.id] = self.return_type

        if self.is_grad:
            KernelSimplicityASTChecker(self.func).visit(tree)

        visitor = ASTTransformer(
            excluded_paremeters=self.template_slot_locations,
            func=self,
            arg_features=arg_features)

        visitor.visit(tree)

        # Shift line numbers by (first source line - 1) so they refer to
        # positions in the original file rather than in the extracted snippet.
        ast.increment_lineno(tree, oinspect.getsourcelines(self.func)[1] - 1)

        # inject template parameters into globals
        for i in self.template_slot_locations:
            template_var_name = self.argument_names[i]
            global_vars[template_var_name] = args[i]

        exec(
            compile(tree,
                    filename=oinspect.getsourcefile(self.func),
                    mode='exec'), global_vars, local_vars)
        compiled = local_vars[self.func.__name__]

        taichi_kernel = _ti_core.create_kernel(kernel_name, self.is_grad)

        # Do not change the name of 'taichi_ast_generator'
        # The warning system needs this identifier to remove unnecessary messages
        def taichi_ast_generator():
            _taichi_skip_traceback = 1
            if self.runtime.inside_kernel:
                # Raised before the flags are touched: the enclosing kernel
                # is still running and owns them.
                raise TaichiSyntaxError(
                    "Kernels cannot call other kernels. I.e., nested kernels are not allowed. Please check if you have direct/indirect invocation of kernels within kernels. Note that some methods provided by the Taichi standard library may invoke kernels, and please move their invocations to Python-scope."
                )
            self.runtime.inside_kernel = True
            self.runtime.current_kernel = self
            try:
                compiled()
            finally:
                # Always restore the runtime state, even when the kernel body
                # raises; otherwise every later kernel would be rejected as
                # "nested".
                self.runtime.inside_kernel = False
                self.runtime.current_kernel = None

        taichi_kernel = taichi_kernel.define(taichi_ast_generator)

        assert key not in self.compiled_functions
        self.compiled_functions[key] = self.get_function_body(taichi_kernel)