示例#1
0
    def do_compile(self, key, args):
        """Re-compile ``self.func`` through the Taichi AST transformer.

        The function's source is dedented and parsed, its decorators are
        removed, and the tree is rewritten by ``ASTTransformerTotal`` in
        non-kernel mode. The rewritten module is executed against the mapping
        returned by ``_get_global_vars`` so the new function object can be
        fetched by name.

        Args:
            key: Instance key. Only consulted when the runtime's
                ``experimental_real_function`` flag is set; its
                ``instance_id`` indexes ``self.compiled`` and
                ``self.taichi_functions``, and the key itself is passed to
                ``_ti_core.create_function``.
            args: Call arguments. Only used in real-function mode, where the
                entries at ``self.template_slot_locations`` are bound as
                globals under the matching ``self.argument_names``.
        """
        py_func = self.func
        syntax_tree = ast.parse(textwrap.dedent(oinspect.getsource(py_func)))

        # With the decorators removed, exec() defines the bare function
        # instead of re-applying the decorator to it.
        syntax_tree.body[0].decorator_list = []

        ASTTransformerTotal(is_kernel=False, func=self).visit(syntax_tree)

        # Re-base node line numbers onto the function's position in its file.
        first_line = oinspect.getsourcelines(py_func)[1]
        ast.increment_lineno(syntax_tree, first_line - 1)

        scope = {}
        namespace = _get_global_vars(py_func)

        if impl.get_runtime().experimental_real_function:
            # Template arguments become globals visible to the compiled body.
            for slot in self.template_slot_locations:
                namespace[self.argument_names[slot]] = args[slot]

        code = compile(syntax_tree,
                       filename=oinspect.getsourcefile(py_func),
                       mode='exec')
        exec(code, namespace, scope)
        compiled_fn = scope[py_func.__name__]

        if not impl.get_runtime().experimental_real_function:
            self.compiled = compiled_fn
            return

        # Real-function mode: cache per instance and attach the Python body
        # to a freshly created core function object.
        instance_id = key.instance_id
        self.compiled[instance_id] = compiled_fn
        cpp_function = _ti_core.create_function(key)
        self.taichi_functions[instance_id] = cpp_function
        cpp_function.set_function_body(compiled_fn)
示例#2
0
def _get_tree_and_ctx(self,
                      excluded_parameters=(),
                      is_kernel=True,
                      arg_features=None,
                      args=None,
                      ast_builder=None,
                      is_real_function=False):
    """Parse ``self.func`` and build the ``ASTTransformerContext`` for it.

    The function's source lines are normalised (tabs expanded to 4-column
    stops, trailing whitespace removed), re-joined, dedented and parsed.
    Decorators are stripped from the parsed definition so they are not
    re-applied.

    Names used as bare-identifier annotations are rebound in the globals
    mapping: parameter annotations to ``self.argument_annotations[i]`` (by
    parameter position) and the return annotation to ``self.return_type``.
    When ``is_kernel`` or ``is_real_function`` is true, template parameters
    are injected as globals too.

    Args:
        excluded_parameters: Forwarded to ``ASTTransformerContext``.
        is_kernel: Forwarded to ``ASTTransformerContext``; also enables
            template-parameter injection.
        arg_features: Forwarded to ``ASTTransformerContext``.
        args: Actual call arguments, forwarded as ``argument_data``. When
            template parameters are injected, it must be indexable at every
            position in ``self.template_slot_locations``.
        ast_builder: Forwarded to ``ASTTransformerContext``.
        is_real_function: Forwarded to ``ASTTransformerContext``; also
            enables template-parameter injection.

    Returns:
        A ``(tree, ctx)`` pair: the parsed ``ast.Module`` and the
        ``ASTTransformerContext`` built from the normalised source lines.
    """
    source_file = oinspect.getsourcefile(self.func)
    src, start_lineno = oinspect.getsourcelines(self.func)
    # Expand tabs and drop each line's trailing newline/whitespace so the
    # lines can be re-joined with "\n" without shifting line numbers.
    # Unlike textwrap.fill(line, tabsize=4, width=N), this never re-wraps a
    # long source line (which would corrupt the code being parsed).
    src = [line.expandtabs(4).rstrip() for line in src]
    tree = ast.parse(textwrap.dedent("\n".join(src)))

    func_body = tree.body[0]
    # Strip decorators so compiling the tree does not re-apply them.
    func_body.decorator_list = []

    global_vars = _get_global_vars(self.func)

    for i, arg in enumerate(func_body.args.args):
        anno = arg.annotation
        if isinstance(anno, ast.Name):
            global_vars[anno.id] = self.argument_annotations[i]

    if isinstance(func_body.returns, ast.Name):
        global_vars[func_body.returns.id] = self.return_type

    if is_kernel or is_real_function:
        # inject template parameters into globals
        for i in self.template_slot_locations:
            template_var_name = self.argument_names[i]
            global_vars[template_var_name] = args[i]

    return tree, ASTTransformerContext(excluded_parameters=excluded_parameters,
                                       is_kernel=is_kernel,
                                       func=self,
                                       arg_features=arg_features,
                                       global_vars=global_vars,
                                       argument_data=args,
                                       src=src,
                                       start_lineno=start_lineno,
                                       file=source_file,
                                       ast_builder=ast_builder,
                                       is_real_function=is_real_function)
示例#3
0
    def do_compile(self):
        """Transform ``self.func`` and store the resulting callable.

        The source is unindented with ``_remove_indent`` and parsed, its
        decorators are dropped, and ``ASTTransformer`` rewrites the tree in
        non-kernel mode. The tree is then compiled and executed against the
        mapping returned by ``_get_global_vars``. The function object defined
        by that execution is stored in ``self.compiled``.
        """
        target = self.func
        module_tree = ast.parse(_remove_indent(oinspect.getsource(target)))

        # Without decorators, exec() yields the undecorated function object.
        module_tree.body[0].decorator_list = []

        transformer = ASTTransformer(is_kernel=False, func=self)
        transformer.visit(module_tree)

        # Re-base line numbers onto the function's location in its file.
        start = oinspect.getsourcelines(target)[1]
        ast.increment_lineno(module_tree, start - 1)

        defined = {}
        globals_ns = _get_global_vars(target)
        code = compile(module_tree,
                       filename=oinspect.getsourcefile(target),
                       mode='exec')
        exec(code, globals_ns, defined)
        self.compiled = defined[target.__name__]
示例#4
0
    def materialize(self, key=None, args=None, arg_features=None):
        """Compile this kernel for ``key`` and register it in ``compiled_functions``.

        If the runtime's ``experimental_ast_refactor`` flag is set, the work
        is delegated to ``materialize_ast_refactor`` and its result returned.
        Otherwise the kernel's Python source is re-parsed, checked (for grad
        kernels), transformed by ``ASTTransformerTotal``, compiled, and
        wrapped in a generator passed to ``_ti_core.create_kernel``.

        Args:
            key: Cache key for this instantiation; defaults to
                ``(self.func, 0)``. ``key[1]`` is embedded in the kernel name.
                If the key is already in ``self.compiled_functions``, nothing
                is compiled.
            args: Call arguments. Entries at ``self.template_slot_locations``
                are injected as globals, so ``args`` must be indexable there
                whenever template slots exist.
            arg_features: Forwarded to ``ASTTransformerTotal``.

        Raises:
            TaichiSyntaxError: When the generated kernel body is run while
                ``self.runtime.inside_kernel`` is already set (nested kernels).
        """
        if impl.get_runtime().experimental_ast_refactor:
            return self.materialize_ast_refactor(key=key,
                                                 args=args,
                                                 arg_features=arg_features)
        # NOTE(review): this unused local looks like a marker read via frame
        # inspection by Taichi's traceback filtering; do not rename or remove.
        _taichi_skip_traceback = 1
        if key is None:
            key = (self.func, 0)
        self.runtime.materialize()
        # Already compiled for this key: nothing to do.
        if key in self.compiled_functions:
            return
        grad_suffix = ""
        if self.is_grad:
            grad_suffix = "_grad"
        # Kernel name: <func name>_c<kernel_counter>_<key[1]>[_grad]
        kernel_name = "{}_c{}_{}{}".format(self.func.__name__,
                                           self.kernel_counter, key[1],
                                           grad_suffix)
        ti.trace("Compiling kernel {}...".format(kernel_name))

        src = textwrap.dedent(oinspect.getsource(self.func))
        tree = ast.parse(src)

        func_body = tree.body[0]
        # Strip decorators so exec() below defines the bare function.
        func_body.decorator_list = []

        local_vars = {}
        global_vars = _get_global_vars(self.func)

        # Rebind bare-name parameter annotations to the resolved annotation
        # objects (matched by parameter position).
        for i, arg in enumerate(func_body.args.args):
            anno = arg.annotation
            if isinstance(anno, ast.Name):
                global_vars[anno.id] = self.argument_annotations[i]

        if isinstance(func_body.returns, ast.Name):
            global_vars[func_body.returns.id] = self.return_type

        # Grad kernels must pass the simplicity check before transformation.
        if self.is_grad:
            KernelSimplicityASTChecker(self.func).visit(tree)

        visitor = ASTTransformerTotal(
            excluded_parameters=self.template_slot_locations,
            func=self,
            arg_features=arg_features)

        visitor.visit(tree)

        # Re-base line numbers onto the function's location in its file.
        ast.increment_lineno(tree, oinspect.getsourcelines(self.func)[1] - 1)

        # inject template parameters into globals
        for i in self.template_slot_locations:
            template_var_name = self.argument_names[i]
            global_vars[template_var_name] = args[i]

        exec(
            compile(tree,
                    filename=oinspect.getsourcefile(self.func),
                    mode='exec'), global_vars, local_vars)
        compiled = local_vars[self.func.__name__]

        # Do not change the name of 'taichi_ast_generator'
        # The warning system needs this identifier to remove unnecessary messages
        def taichi_ast_generator():
            """Run the compiled kernel body with the runtime's kernel state set.

            Raises TaichiSyntaxError if another kernel is already being built;
            ``inside_kernel``/``current_kernel`` are reset even on failure.
            """
            _taichi_skip_traceback = 1
            if self.runtime.inside_kernel:
                raise TaichiSyntaxError(
                    "Kernels cannot call other kernels. I.e., nested kernels are not allowed. Please check if you have direct/indirect invocation of kernels within kernels. Note that some methods provided by the Taichi standard library may invoke kernels, and please move their invocations to Python-scope."
                )
            self.runtime.inside_kernel = True
            self.runtime.current_kernel = self
            try:
                compiled()
            finally:
                self.runtime.inside_kernel = False
                self.runtime.current_kernel = None

        taichi_kernel = _ti_core.create_kernel(taichi_ast_generator,
                                               kernel_name, self.is_grad)

        self.kernel_cpp = taichi_kernel

        assert key not in self.compiled_functions
        self.compiled_functions[key] = self.get_function_body(taichi_kernel)
示例#5
0
 def __init__(self, func):
     """Initialise the base class and record where ``func`` is defined.

     Stores ``func``'s source file, the line number of its first source line
     and its ``__name__``, and starts ``_scope_guards`` as an empty list.
     """
     super().__init__()
     self._func_file = oinspect.getsourcefile(func)
     self._func_lineno = oinspect.getsourcelines(func)[1]
     self._func_name = func.__name__
     # NOTE(review): presumably a stack of active scope guards pushed/popped
     # while visiting the AST — confirm against the rest of the class.
     self._scope_guards = []