Beispiel #1
0
def test_ast_resolver_alias():
    """Resolve ``kernel`` through the full module name and through an alias.

    Both imports are function-local so that the names being resolved are
    present in ``locals()``, which is the scope passed to the resolver.
    """
    import taichi
    by_module = ast.parse('taichi.kernel', mode='eval')
    assert ASTResolver.resolve_to(by_module.body, taichi.kernel, locals())

    import taichi as tc
    by_alias = ast.parse('tc.kernel', mode='eval')
    assert ASTResolver.resolve_to(by_alias.body, tc.kernel, locals())
Beispiel #2
0
 def visit_Call(self, node):
     """Retarget selected Python calls to their ``ti.ti_*`` equivalents.

     The node's children are visited first, except when the callee resolves
     to ``ti.static``. After that, ``<expr>.format(...)`` is rewritten to
     ``ti.ti_format(<expr>, ...)`` and bare calls to ``print``, ``min``,
     ``max``, ``int``, ``float``, ``any`` and ``all`` have their callee
     replaced by the matching ``ti.ti_<name>`` expression.

     Returns the same (mutated) ``node``.
     """
     calls_static = ASTResolver.resolve_to(node.func, ti.static, globals())
     if not calls_static:
         # Children of a ti.static(...) call are not run through the
         # generic visitor.
         self.generic_visit(node)

     callee = node.func
     if isinstance(callee, ast.Attribute) and callee.attr == 'format':
         # The receiver of .format() becomes the first positional argument.
         node.args.insert(0, callee.value)
         node.func = self.parse_expr('ti.ti_format')

     # Re-read the callee: it may have been replaced just above.
     callee = node.func
     if isinstance(callee, ast.Name):
         builtin_targets = {
             'print': 'ti.ti_print',
             'min': 'ti.ti_min',
             'max': 'ti.ti_max',
             'int': 'ti.ti_int',
             'float': 'ti.ti_float',
             'any': 'ti.ti_any',
             'all': 'ti.ti_all',
         }
         target = builtin_targets.get(callee.id)
         if target is not None:
             node.func = self.parse_expr(target)
     return node
Beispiel #3
0
def test_ast_resolver_wrong_ti():
    """A local ``ti`` that is not Taichi must not resolve to ``taichi.kernel``."""
    import taichi
    taichi.init()
    # Bind the local name ``ti`` to an unrelated object that merely has a
    # ``kernel`` field; the resolver is given locals(), so it sees this one.
    ti = namedtuple('FakeTi', ['kernel'])(kernel='fake')
    parsed = ast.parse('ti.kernel', mode='eval')
    assert not ASTResolver.resolve_to(parsed.body, taichi.kernel, locals())
Beispiel #4
0
 def build_Call(ctx, node):
     """Build a call node, redirecting some Python calls to ``ti.ti_*``.

     A call whose callee resolves to ``ti.static`` is returned untouched.
     Otherwise the callee and positional arguments are rebuilt with
     ``build_expr`` / ``build_exprs``; then ``<expr>.format(...)`` becomes
     ``ti.ti_format(<expr>, ...)`` and bare calls to ``print``, ``min``,
     ``max``, ``int``, ``float``, ``any`` and ``all`` get ``ti.ti_<name>``
     as their callee.

     Returns the same (mutated) ``node``.
     """
     if ASTResolver.resolve_to(node.func, ti.static, globals()):
         # ti.static(...) expressions are passed through unmodified.
         return node

     node.func = build_expr(ctx, node.func)
     node.args = build_exprs(ctx, node.args)

     callee = node.func
     if isinstance(callee, ast.Attribute) and callee.attr == 'format':
         # The receiver of .format() becomes the first positional argument.
         node.args.insert(0, callee.value)
         node.func = parse_expr('ti.ti_format')

     # Re-read the callee: it may have been replaced just above.
     callee = node.func
     if isinstance(callee, ast.Name):
         redirected = ('print', 'min', 'max', 'int', 'float', 'any', 'all')
         if callee.id in redirected:
             node.func = parse_expr(f'ti.ti_{callee.id}')
     return node
Beispiel #5
0
 def get_decorator(node):
     """Name the Taichi helper that ``node`` calls, if any.

     Returns ``'static'``, ``'grouped'`` or ``'ndrange'`` when ``node`` is an
     ``ast.Call`` whose callee resolves to ``ti.static``, ``ti.grouped`` or
     ``ti.ndrange`` respectively (checked in that order). Returns ``''`` for
     anything else, including nodes that are not calls.
     """
     if isinstance(node, ast.Call):
         candidates = (
             ('static', ti.static),
             ('grouped', ti.grouped),
             ('ndrange', ti.ndrange),
         )
         for label, target in candidates:
             if ASTResolver.resolve_to(node.func, target, globals()):
                 return label
     return ''
Beispiel #6
0
    def build_Call(ctx, node):
        """Build a call node, redirecting some Python calls to ``ti.ti_*``.

        A call whose callee resolves to ``ti.static`` is returned untouched.
        Otherwise the callee, positional arguments and keyword values are
        rebuilt; ``<expr>.format(...)`` becomes ``ti.ti_format(<expr>, ...)``
        and bare calls to ``print``, ``min``, ``max``, ``int``, ``float``,
        ``any`` and ``all`` get ``ti.ti_<name>`` as their callee.

        Finally the resulting callee is inspected via its ``__module__``,
        ``__class__`` and ``__wrapped__`` attributes.

        Raises:
            TaichiSyntaxError: if the callee's ``__module__`` is
                ``'_sitebuiltins'`` and its class is named ``'Quitter'``.

        Warns:
            UserWarning: if the callee's ``__module__`` is ``'__main__'`` and
                it has no truthy ``__wrapped__`` attribute.

        Returns the same (mutated) ``node``.
        """
        if ASTResolver.resolve_to(node.func, ti.static, globals()):
            # ti.static(...) expressions are passed through unmodified.
            return node

        node.func = build_expr(ctx, node.func)
        node.args = build_exprs(ctx, node.args)
        for keyword in node.keywords:
            keyword.value = build_expr(ctx, keyword.value)

        callee = node.func
        if isinstance(callee, ast.Attribute) and callee.attr == 'format':
            # The receiver of .format() becomes the first positional argument.
            node.args.insert(0, callee.value)
            node.func = parse_expr('ti.ti_format')

        # Re-read the callee: it may have been replaced just above.
        callee = node.func
        if isinstance(callee, ast.Name):
            builtin_targets = {
                'print': 'ti.ti_print',
                'min': 'ti.ti_min',
                'max': 'ti.ti_max',
                'int': 'ti.ti_int',
                'float': 'ti.ti_float',
                'any': 'ti.ti_any',
                'all': 'ti.ti_all',
            }
            target = builtin_targets.get(callee.id)
            if target is not None:
                node.func = parse_expr(target)

        # NOTE(review): this local is never read in this function; presumably
        # it is looked up by name from the frame elsewhere -- keep the name
        # and position, and confirm before touching it.
        _taichi_skip_traceback = 1
        ti_func = node.func
        func_module = getattr(ti_func, '__module__', '')

        if func_module == '_sitebuiltins':
            func_class = getattr(ti_func, '__class__', '')
            if getattr(func_class, '__name__', '') == 'Quitter':
                raise TaichiSyntaxError(
                    'exit or quit not supported in Taichi-scope')

        if func_module == '__main__' and not getattr(ti_func, '__wrapped__',
                                                     ''):
            message = (
                f'Calling into non-Taichi function {ti_func.__name__}.'
                ' This means that scope inside that function will not be processed'
                ' by the Taichi transformer. Proceed with caution! '
                ' Maybe you want to decorate it with @ti.func?')
            warnings.warn(message, UserWarning, stacklevel=2)

        return node
Beispiel #7
0
    def visit_FunctionDef(self, node):
        """Rewrite a function definition's arguments into body statements.

        The function must not use ``*args``, keyword-only arguments or
        ``**kwargs`` (enforced with assertions).

        * When ``self.is_kernel`` is true, or when the runtime's
          ``experimental_real_function`` flag is set, the arguments are
          replaced by ``ti.lang.kernel_arguments.decl_*`` declarations and
          the original argument list is emptied.
        * Otherwise every non-template parameter is renamed to
          ``<name>_by_value__`` and a statement
          ``<name> = ti.expr_init_func(<name>_by_value__)`` is generated.

        The body is then visited inside ``self.variable_scope()``, and the
        generated declarations plus ``import taichi as ti`` are prepended to
        ``node.body``.

        Raises:
            TaichiSyntaxError: if any decorator of ``node`` resolves to
                ``ti.func``.

        Returns the same (mutated) ``node``.
        """
        args = node.args
        assert args.vararg is None
        assert args.kwonlyargs == []
        assert args.kw_defaults == []
        assert args.kwarg is None

        # Statements to prepend to the function body; filled in below.
        arg_decls = []

        def transform_as_kernel():
            # Treat return type: turn the return annotation into a
            # decl_scalar_ret(<annotation>) statement, remember it on self,
            # and strip it from the definition.
            if node.returns is not None:
                ret_init = self.parse_stmt(
                    'ti.lang.kernel_arguments.decl_scalar_ret(0)')
                ret_init.value.args[0] = node.returns
                self.returns = node.returns
                arg_decls.append(ret_init)
                node.returns = None

            for i, arg in enumerate(args.args):
                # Directly pass in template arguments,
                # such as class instances ("self"), fields, SNodes, etc.
                if isinstance(self.func.argument_annotations[i], template):
                    continue
                if isinstance(self.func.argument_annotations[i], ext_arr):
                    # <name> = decl_ext_arr_arg(<dtype>, <dim>); the 0
                    # placeholders in the parsed template are replaced below.
                    arg_init = self.parse_stmt(
                        'x = ti.lang.kernel_arguments.decl_ext_arr_arg(0, 0)')
                    arg_init.targets[0].id = arg.arg
                    self.create_variable(arg.arg)
                    # arg_features[i] supplies (dtype, dim) as items 0 and 1.
                    array_dt = self.arg_features[i][0]
                    array_dim = self.arg_features[i][1]
                    array_dt = to_taichi_type(array_dt)
                    dt_expr = 'ti.' + ti.core.data_type_name(array_dt)
                    dt = self.parse_expr(dt_expr)
                    arg_init.value.args[0] = dt
                    arg_init.value.args[1] = self.parse_expr(
                        "{}".format(array_dim))
                    arg_decls.append(arg_init)
                else:
                    # <name> = decl_scalar_arg(<annotation>)
                    # NOTE(review): unlike the ext_arr branch, this branch
                    # does not call self.create_variable(arg.arg) -- confirm
                    # that this is intentional.
                    arg_init = self.parse_stmt(
                        'x = ti.lang.kernel_arguments.decl_scalar_arg(0)')
                    arg_init.targets[0].id = arg.arg
                    dt = arg.annotation
                    arg_init.value.args[0] = dt
                    arg_decls.append(arg_init)
            # remove original args
            node.args.args = []

        if self.is_kernel:  # ti.kernel
            # Reject a definition decorated with ti.func.
            for decorator in node.decorator_list:
                if ASTResolver.resolve_to(decorator, ti.func, globals()):
                    raise TaichiSyntaxError(
                        "Function definition not allowed in 'ti.kernel'.")
            transform_as_kernel()

        else:  # ti.func
            # Same check as above; only the error message differs.
            for decorator in node.decorator_list:
                if ASTResolver.resolve_to(decorator, ti.func, globals()):
                    raise TaichiSyntaxError(
                        "Function definition not allowed in 'ti.func'.")
            if impl.get_runtime().experimental_real_function:
                transform_as_kernel()
            else:
                # Transform as func (all parameters passed by value)
                # arg_decls is still empty here; this rebinding is a reset.
                arg_decls = []
                for i, arg in enumerate(args.args):
                    # Directly pass in template arguments,
                    # such as class instances ("self"), fields, SNodes, etc.
                    if isinstance(self.func.argument_annotations[i], template):
                        continue
                    # Create a copy for non-template arguments,
                    # so that they are passed by value.
                    arg_init = self.parse_stmt('x = ti.expr_init_func(0)')
                    arg_init.targets[0].id = arg.arg
                    self.create_variable(arg.arg)
                    arg_init.value.args[0] = self.parse_expr(arg.arg +
                                                             '_by_value__')
                    # Rename the parameter last: the lines above still need
                    # the original name in arg.arg.
                    args.args[i].arg += '_by_value__'
                    arg_decls.append(arg_init)

        # Visit the body in a fresh variable scope.
        with self.variable_scope():
            self.generic_visit(node)

        # Final body order: `import taichi as ti`, then the generated
        # argument declarations, then the original (visited) statements.
        node.body = arg_decls + node.body
        node.body = [self.parse_stmt('import taichi as ti')] + node.body
        return node
Beispiel #8
0
    def visit_Assign(self, node):
        """Rewrite an assignment into Taichi expression calls.

        Only a single assignment target is supported (asserted). After
        visiting the children:

        * ``x = ti.static(...)`` is returned unchanged.
        * For a non-tuple target, a local name for which
          ``self.is_creation`` is true becomes ``x = ti.expr_init(value)``;
          any other target becomes the expression ``target.assign(value)``.
        * For a tuple target, the value is first stored with
          ``__tmp_tuple = ti.expr_init_list(value, n)``; each element is then
          created or assigned from ``__tmp_tuple[i]`` as above, followed by
          ``del __tmp_tuple``. The statements are combined with
          ``self.make_single_statement``.
        """
        assert (len(node.targets) == 1)
        self.generic_visit(node)

        assigned = node.value
        if isinstance(assigned, ast.Call) and ASTResolver.resolve_to(
                assigned.func, ti.static, globals()):
            # Static assignments are passed through untouched.
            return node

        lhs = node.targets[0]

        if not isinstance(lhs, ast.Tuple):
            if isinstance(lhs, ast.Name) and self.is_creation(lhs.id):
                # Create, no AST resolution needed
                new_name = lhs.id
                init_call = ast.Call(
                    func=ast.Attribute(value=ast.Name(id='ti',
                                                      ctx=ast.Load()),
                                       attr='expr_init',
                                       ctx=ast.Load()),
                    args=[node.value],
                    keywords=[],
                )
                self.create_variable(new_name)
                created = ast.Assign(targets=node.targets, value=init_call)
                return ast.copy_location(created, node)

            # Assign: the target is now read, so it gets a Load context.
            lhs.ctx = ast.Load()
            assign_call = ast.Call(
                func=ast.Attribute(value=lhs, attr='assign', ctx=ast.Load()),
                args=[node.value],
                keywords=[],
            )
            return ast.copy_location(ast.Expr(value=assign_call), node)

        # Tuple target: unpack through a temporary holder.
        elements = lhs.elts

        holder = self.parse_stmt(
            f'__tmp_tuple = ti.expr_init_list(0, {len(elements)})')
        holder.value.args[0] = node.value
        expanded = [holder]

        def element_at(index):
            # Build a fresh ``__tmp_tuple[index]`` expression.
            subscript = self.parse_stmt('__tmp_tuple[0]').value
            self.set_subscript_index(subscript, self.parse_expr(f"{index}"))
            return subscript

        for index, element in enumerate(elements):
            if isinstance(element, ast.Name) and self.is_creation(
                    element.id):
                # Create, no AST resolution needed
                new_name = element.id
                element.ctx = ast.Store()
                init_call = ast.Call(
                    func=ast.Attribute(value=ast.Name(id='ti',
                                                      ctx=ast.Load()),
                                       attr='expr_init',
                                       ctx=ast.Load()),
                    args=[element_at(index)],
                    keywords=[],
                )
                self.create_variable(new_name)
                expanded.append(
                    ast.Assign(targets=[element], value=init_call))
            else:
                # Assign
                element.ctx = ast.Load()
                assign_call = ast.Call(
                    func=ast.Attribute(value=element,
                                       attr='assign',
                                       ctx=ast.Load()),
                    args=[element_at(index)],
                    keywords=[],
                )
                expanded.append(ast.Expr(value=assign_call))

        # Location is copied onto everything built so far; the trailing
        # ``del`` is appended afterwards and keeps its parsed location.
        for stmt in expanded:
            ast.copy_location(stmt, node)
        expanded.append(self.parse_stmt('del __tmp_tuple'))
        return self.make_single_statement(expanded)
Beispiel #9
0
def test_ast_resolver_basic():
    """``ti.kernel`` resolves when ``ti`` is a local alias for taichi."""
    # Imported inside the function so the alias is found in locals() and
    # the module's global scope is not polluted.
    import taichi as ti
    parsed = ast.parse('ti.kernel', mode='eval')
    assert ASTResolver.resolve_to(parsed.body, ti.kernel, locals())
Beispiel #10
0
def test_ast_resolver_chain():
    """The dotted path ``ti.lang.ops.atomic_add`` resolves to ``ti.atomic_add``."""
    import taichi as ti
    dotted = ast.parse('ti.lang.ops.atomic_add', mode='eval')
    assert ASTResolver.resolve_to(dotted.body, ti.atomic_add, locals())
Beispiel #11
0
def test_ast_resolver_direct_import():
    """A bare ``kernel`` name resolves when imported directly from taichi."""
    from taichi import kernel
    bare_name = ast.parse('kernel', mode='eval')
    assert ASTResolver.resolve_to(bare_name.body, kernel, locals())