Ejemplo n.º 1
0
def test_ast_resolver_alias():
    """ASTResolver resolves both a plain module name and an ``as`` alias."""
    import taichi
    taichi.init()

    # Unaliased reference: ``taichi.kernel``.
    plain_expr = ast.parse('taichi.kernel', mode='eval')
    assert ASTResolver.resolve_to(plain_expr.body, taichi.kernel, locals())

    # Same module bound under an alias: ``tc.kernel``.
    import taichi as tc
    aliased_expr = ast.parse('tc.kernel', mode='eval')
    assert ASTResolver.resolve_to(aliased_expr.body, tc.kernel, locals())
Ejemplo n.º 2
0
def test_ast_resolver_wrong_ti():
    """A local ``ti`` that is not taichi must not resolve to taichi.kernel."""
    import taichi
    taichi.init()

    # Shadow ``ti`` with an impostor whose ``kernel`` is just a string.
    FakeTi = namedtuple('FakeTi', ['kernel'])
    ti = FakeTi(kernel='fake')

    expr = ast.parse('ti.kernel', mode='eval')
    assert not ASTResolver.resolve_to(expr.body, taichi.kernel, locals())
Ejemplo n.º 3
0
 def build_NamedExpr(ctx, node):
     """Build a walrus expression ``target := value``.

     The value and then the target are built. The assignment is emitted via
     ``IRBuilder.build_assign_basic`` and flagged as static when the value is
     a call that resolves to ``ti.static``. The result is stored on
     ``node.ptr`` and the node is returned.
     """
     node.value = build_stmt(ctx, node.value)
     node.target = build_stmt(ctx, node.target)

     value = node.value
     static_assign = False
     if isinstance(value, ast.Call):
         # NOTE(review): resolution uses this module's globals(), not the
         # kernel's own namespace, so a differently-aliased ``ti`` in user
         # code may not be recognised -- confirm this is intended.
         static_assign = ASTResolver.resolve_to(value.func, ti.static,
                                                globals())

     node.ptr = IRBuilder.build_assign_basic(ctx, node.target, value.ptr,
                                             static_assign)
     return node
Ejemplo n.º 4
0
 def get_decorator(ctx, node):
     """Return the name of the recognised wrapper call ``node`` represents.

     Args:
         ctx: builder context; names are resolved against ``ctx.global_vars``.
         node: an AST node.

     Returns:
         ``'static'``, ``'grouped'`` or ``'ndrange'`` when ``node`` is a call
         whose callee resolves to ``impl.static``, ``impl.grouped`` or
         ``ndrange`` respectively (checked in that order); ``''`` otherwise.
     """
     if isinstance(node, ast.Call):
         candidates = (
             (impl.static, 'static'),
             (impl.grouped, 'grouped'),
             (ndrange, 'ndrange'),
         )
         for target, label in candidates:
             if ASTResolver.resolve_to(node.func, target, ctx.global_vars):
                 return label
     return ''
Ejemplo n.º 5
0
 def get_decorator(node):
     """Return the name of the recognised wrapper call ``node`` represents.

     Returns ``'static'``, ``'grouped'`` or ``'ndrange'`` when ``node`` is a
     call whose callee resolves (against this module's globals) to
     ``ti.static``, ``ti.grouped`` or ``ti.ndrange``, checked in that order.
     Returns ``''`` for anything else.
     """
     if not isinstance(node, ast.Call):
         return ''
     known = (
         (ti.static, 'static'),
         (ti.grouped, 'grouped'),
         (ti.ndrange, 'ndrange'),
     )
     return next((label for fn, label in known
                  if ASTResolver.resolve_to(node.func, fn, globals())), '')
Ejemplo n.º 6
0
    def build_AnnAssign(ctx, node):
        """Build an annotated assignment ``target: annotation = value``.

        The value, target and annotation are built in that order. Each
        built result is read back through the child's ``ptr``. The assignment
        is emitted with ``ASTTransformer.build_assign_annotated``, flagged
        as static when the value is a call resolving to ``ti.static``.

        Returns:
            The result of ``build_assign_annotated`` (also stored on
            ``node.ptr``).
        """
        build_stmt(ctx, node.value)
        build_stmt(ctx, node.target)
        build_stmt(ctx, node.annotation)

        rhs = node.value
        static_assign = False
        if isinstance(rhs, ast.Call):
            static_assign = ASTResolver.resolve_to(rhs.func, ti.static,
                                                   globals())

        node.ptr = ASTTransformer.build_assign_annotated(
            ctx, node.target, rhs.ptr, static_assign, node.annotation.ptr)
        return node.ptr
Ejemplo n.º 7
0
    def build_Call(ctx, node):
        """Build a function-call expression node.

        A call that resolves to ``ti.static`` is returned untouched.
        Otherwise the callee, the positional arguments and the keyword
        values are built. Then:

        * ``<expr>.format(...)`` is rewritten to
          ``ti.ti_format(<expr>, ...)``;
        * bare ``print``/``min``/``max``/``int``/``float``/``any``/``all``
          are redirected to ``ti.ti_print``, ``ti.ti_min``, and so on.

        Raises:
            TaichiSyntaxError: if the callee looks like the ``exit``/``quit``
                helper (module ``_sitebuiltins``, class ``Quitter``).

        Warns:
            UserWarning: if the callee comes from ``__main__`` and has no
                ``__wrapped__`` attribute.
        """
        if ASTResolver.resolve_to(node.func, ti.static, globals()):
            # ti.static(...) calls are left exactly as written.
            return node

        node.func = build_expr(ctx, node.func)
        node.args = build_exprs(ctx, node.args)
        for keyword in node.keywords:
            keyword.value = build_expr(ctx, keyword.value)

        if isinstance(node.func, ast.Attribute) and node.func.attr == 'format':
            # `fmt.format(a, b)` -> `ti.ti_format(fmt, a, b)`
            node.args.insert(0, node.func.value)
            node.func = parse_expr('ti.ti_format')

        builtin_replacements = {
            'print': 'ti.ti_print',
            'min': 'ti.ti_min',
            'max': 'ti.ti_max',
            'int': 'ti.ti_int',
            'float': 'ti.ti_float',
            'any': 'ti.ti_any',
            'all': 'ti.ti_all',
        }
        if isinstance(node.func, ast.Name):
            replacement = builtin_replacements.get(node.func.id)
            if replacement is not None:
                node.func = parse_expr(replacement)

        # Local name is looked up by Taichi's traceback handling; keep it.
        _taichi_skip_traceback = 1
        ti_func = node.func
        # NOTE(review): node.func appears to be an AST node at this point (it
        # is tested against ast.Attribute / ast.Name above), so the
        # ``__module__`` checks below may never match -- confirm.
        module_name = getattr(ti_func, '__module__', '')
        if '_sitebuiltins' == module_name and getattr(
                getattr(ti_func, '__class__', ''), '__name__',
                '') == 'Quitter':
            raise TaichiSyntaxError(
                'exit or quit not supported in Taichi-scope')
        if module_name == '__main__' and not getattr(ti_func, '__wrapped__',
                                                     ''):
            warnings.warn(
                f'Calling into non-Taichi function {ti_func.__name__}.'
                ' This means that scope inside that function will not be processed'
                ' by the Taichi transformer. Proceed with caution! '
                ' Maybe you want to decorate it with @ti.func?',
                UserWarning,
                stacklevel=2)

        return node
Ejemplo n.º 8
0
    def build_Assign(ctx, node):
        """Build a (possibly chained) assignment ``a = b = value``.

        The right-hand side and every target are built first. Then one
        ``IRBuilder.build_assign_unpack`` call is made per target, all
        reusing the pointer of the single built right-hand side (chained
        assignments, see https://github.com/taichi-dev/taichi/issues/2659).
        The assignment is flagged as static when the value is a call that
        resolves to ``ti.static``.

        Returns:
            The mutated ``node``.
        """
        node.value = build_stmt(ctx, node.value)
        node.targets = build_stmts(ctx, node.targets)

        rhs = node.value
        static_assign = False
        if isinstance(rhs, ast.Call):
            static_assign = ASTResolver.resolve_to(rhs.func, ti.static,
                                                   globals())

        for target in node.targets:
            IRBuilder.build_assign_unpack(ctx, target, rhs.ptr, static_assign)
        return node
Ejemplo n.º 9
0
    def build_Assign(ctx, node):
        """Build a (possibly chained) assignment ``a = b = value``.

        If the value is a call resolving to ``ti.static``, the node is
        returned as-is after its children are built. Otherwise one statement
        is generated per target: ``build_assign_unpack`` for tuple targets,
        ``build_assign_basic`` for the rest. The statements are then fused
        with ``StmtBuilder.make_single_statement``
        (ref https://github.com/taichi-dev/taichi/issues/2659).
        """
        node.value = build_expr(ctx, node.value)
        node.targets = build_exprs(ctx, node.targets)

        if isinstance(node.value, ast.Call) and ASTResolver.resolve_to(
                node.value.func, ti.static, globals()):
            return node

        assign_stmts = [
            StmtBuilder.build_assign_unpack(ctx, node, target)
            if isinstance(target, ast.Tuple) else
            StmtBuilder.build_assign_basic(ctx, node, target, node.value)
            for target in node.targets
        ]
        return StmtBuilder.make_single_statement(assign_stmts)
Ejemplo n.º 10
0
    def build_FunctionDef(ctx, node):
        """Rewrite a ``ti.kernel`` / ``ti.func`` definition node in place.

        Parameters are turned into explicit declaration statements. These
        are prepended to the body, after an injected ``import taichi as ti``.

        * Kernel, or ``ti.func`` when ``experimental_real_function`` is on:
          a return annotation becomes a ``decl_scalar_ret`` statement and is
          recorded on ``ctx.returns``. Template parameters are skipped. Other
          parameters become ``decl_sparse_matrix`` / ``decl_any_arr_arg`` /
          ``decl_scalar_arg`` assignments, and the Python parameter list is
          cleared.
        * Force-inlined ``ti.func``: annotations are dropped and template
          parameters are registered by name. Every other parameter ``x`` is
          renamed to ``x_by_value__`` and copied back into ``x`` via
          ``ti.expr_init_func``.

        Args:
            ctx: builder context. Reads ``is_kernel``, ``func`` and
                ``arg_features``; writes ``returns``.
            node: the function-definition node. Its ``args``, ``returns``,
                ``decorator_list`` and ``body`` are read and rewritten.

        Returns:
            The mutated ``node``.

        Raises:
            TaichiSyntaxError: if a decorator on this definition resolves to
                ``ti.func``.
        """
        args = node.args
        # *args, keyword-only parameters and **kwargs are not supported.
        assert args.vararg is None
        assert args.kwonlyargs == []
        assert args.kw_defaults == []
        assert args.kwarg is None

        # Declaration statements to prepend to the function body.
        arg_decls = []

        def transform_as_kernel():
            # Treat return type: emit `decl_scalar_ret(<annotation>)` and
            # strip the annotation from the definition itself.
            if node.returns is not None:
                ret_init = parse_stmt(
                    'ti.lang.kernel_arguments.decl_scalar_ret(0)')
                ret_init.value.args[0] = node.returns
                ctx.returns = node.returns
                arg_decls.append(ret_init)
                node.returns = None

            for i, arg in enumerate(args.args):
                # Directly pass in template arguments,
                # such as class instances ("self"), fields, SNodes, etc.
                if isinstance(ctx.func.argument_annotations[i], ti.template):
                    continue
                if isinstance(ctx.func.argument_annotations[i],
                              ti.linalg.sparse_matrix_builder):
                    # `<arg> = decl_sparse_matrix()`
                    arg_init = parse_stmt(
                        'x = ti.lang.kernel_arguments.decl_sparse_matrix()')
                    arg_init.targets[0].id = arg.arg
                    ctx.create_variable(arg.arg)
                    arg_decls.append(arg_init)
                elif isinstance(ctx.func.argument_annotations[i], ti.any_arr):
                    # `<arg> = decl_any_arr_arg(ti.<dtype>, dim,
                    #                           element_shape, ti.<layout>)`
                    # with the four values taken from ctx.arg_features[i].
                    arg_init = parse_stmt(
                        'x = ti.lang.kernel_arguments.decl_any_arr_arg(0, 0, 0, 0)'
                    )
                    arg_init.targets[0].id = arg.arg
                    ctx.create_variable(arg.arg)
                    array_dt = ctx.arg_features[i][0]
                    array_dim = ctx.arg_features[i][1]
                    array_element_shape = ctx.arg_features[i][2]
                    array_layout = ctx.arg_features[i][3]
                    array_dt = to_taichi_type(array_dt)
                    dt_expr = 'ti.' + ti.core.data_type_name(array_dt)
                    dt = parse_expr(dt_expr)
                    arg_init.value.args[0] = dt
                    arg_init.value.args[1] = parse_expr("{}".format(array_dim))
                    arg_init.value.args[2] = parse_expr(
                        "{}".format(array_element_shape))
                    arg_init.value.args[3] = parse_expr(
                        "ti.{}".format(array_layout))
                    arg_decls.append(arg_init)
                else:
                    # `<arg> = decl_scalar_arg(<annotation>)`
                    # NOTE(review): unlike the two branches above, this one
                    # does not call ctx.create_variable(arg.arg) -- confirm
                    # whether that is intentional.
                    arg_init = parse_stmt(
                        'x = ti.lang.kernel_arguments.decl_scalar_arg(0)')
                    arg_init.targets[0].id = arg.arg
                    dt = arg.annotation
                    arg_init.value.args[0] = dt
                    arg_decls.append(arg_init)
            # remove original args
            node.args.args = []

        if ctx.is_kernel:  # ti.kernel
            # Reject a ti.func decorator on a definition built in kernel mode.
            for decorator in node.decorator_list:
                if ASTResolver.resolve_to(decorator, ti.func, globals()):
                    raise TaichiSyntaxError(
                        "Function definition not allowed in 'ti.kernel'.")
            transform_as_kernel()

        else:  # ti.func
            for decorator in node.decorator_list:
                if ASTResolver.resolve_to(decorator, ti.func, globals()):
                    raise TaichiSyntaxError(
                        "Function definition not allowed in 'ti.func'.")
            if impl.get_runtime().experimental_real_function:
                transform_as_kernel()
            else:
                # Transform as force-inlined func
                # (Rebinding arg_decls is harmless: transform_as_kernel is
                # not called on this path.)
                arg_decls = []
                for i, arg in enumerate(args.args):
                    # Remove annotations because they are not used.
                    args.args[i].annotation = None
                    # Template arguments are passed by reference.
                    if isinstance(ctx.func.argument_annotations[i],
                                  ti.template):
                        ctx.create_variable(ctx.func.argument_names[i])
                        continue
                    # Create a copy for non-template arguments,
                    # so that they are passed by value: the parameter is
                    # renamed to `<name>_by_value__` and the body sees
                    # `<name> = ti.expr_init_func(<name>_by_value__)`.
                    arg_init = parse_stmt('x = ti.expr_init_func(0)')
                    arg_init.targets[0].id = arg.arg
                    ctx.create_variable(arg.arg)
                    arg_init.value.args[0] = parse_expr(arg.arg +
                                                        '_by_value__')
                    args.args[i].arg += '_by_value__'
                    arg_decls.append(arg_init)

        with ctx.variable_scope_guard():
            node.body = build_stmts(ctx, node.body)

        # Final body: `import taichi as ti`, the declarations, then the
        # original (built) statements. The import is presumably there so the
        # injected `ti.` references resolve.
        node.body = arg_decls + node.body
        node.body = [parse_stmt('import taichi as ti')] + node.body
        return node
Ejemplo n.º 11
0
    def build_FunctionDef(ctx, node):
        """Build a ``ti.kernel`` / ``ti.func`` definition, binding its params.

        * Kernel, or ``ti.func`` when ``experimental_real_function`` is on:
          if a return annotation exists, ``decl_scalar_ret`` is called with
          ``ctx.func.return_type``. Each non-template parameter is bound to
          the result of ``decl_sparse_matrix`` / ``decl_any_arr_arg`` /
          ``decl_scalar_arg``. The Python parameter list is then cleared.
        * Force-inlined ``ti.func``: the provided values in
          ``ctx.argument_data`` are extended with any needed default values.
          Template parameters are bound to the provided value directly;
          others are bound to ``ti.expr_init_func(value)``.

        The body is then built inside a fresh variable scope.

        Args:
            ctx: builder context. Reads ``is_kernel``, ``func``,
                ``arg_features``, ``argument_data`` and ``global_vars``; may
                replace ``argument_data`` with an extended list.
            node: the function-definition node.

        Returns:
            The (mutated) ``node``.

        Raises:
            TaichiSyntaxError: if a decorator on this definition resolves to
                ``ti.func``, or if an inlined ``ti.func`` receives too few or
                too many arguments.
        """
        args = node.args
        # *args, keyword-only parameters and **kwargs are not supported.
        assert args.vararg is None
        assert args.kwonlyargs == []
        assert args.kw_defaults == []
        assert args.kwarg is None

        def transform_as_kernel():
            # Treat return type
            if node.returns is not None:
                ti.lang.kernel_arguments.decl_scalar_ret(ctx.func.return_type)

            for i, arg in enumerate(args.args):
                if isinstance(ctx.func.argument_annotations[i], ti.template):
                    continue
                elif isinstance(ctx.func.argument_annotations[i],
                                ti.linalg.sparse_matrix_builder):
                    ctx.create_variable(
                        arg.arg, ti.lang.kernel_arguments.decl_sparse_matrix())
                elif isinstance(ctx.func.argument_annotations[i], ti.any_arr):
                    # arg_features[i] supplies (dtype, dim, element_shape,
                    # layout) for the array argument.
                    ctx.create_variable(
                        arg.arg,
                        ti.lang.kernel_arguments.decl_any_arr_arg(
                            to_taichi_type(ctx.arg_features[i][0]),
                            ctx.arg_features[i][1], ctx.arg_features[i][2],
                            ctx.arg_features[i][3]))
                else:
                    # NOTE(review): scalar args are written into
                    # ctx.global_vars rather than via ctx.create_variable like
                    # the branches above -- confirm this is intentional.
                    ctx.global_vars[
                        arg.arg] = ti.lang.kernel_arguments.decl_scalar_arg(
                            ctx.func.argument_annotations[i])
            # remove original args
            node.args.args = []

        if ctx.is_kernel:  # ti.kernel
            # Reject a ti.func decorator on a definition built in kernel mode.
            for decorator in node.decorator_list:
                if ASTResolver.resolve_to(decorator, ti.func, globals()):
                    raise TaichiSyntaxError(
                        "Function definition not allowed in 'ti.kernel'.")
            transform_as_kernel()

        else:  # ti.func
            for decorator in node.decorator_list:
                if ASTResolver.resolve_to(decorator, ti.func, globals()):
                    raise TaichiSyntaxError(
                        "Function definition not allowed in 'ti.func'.")
            if impl.get_runtime().experimental_real_function:
                transform_as_kernel()
            else:
                len_args = len(args.args)
                len_default = len(args.defaults)
                len_provided = len(ctx.argument_data)
                # Parameters without a default must all be supplied.
                len_minimum = len_args - len_default
                # Too many arguments, or fewer than len_minimum.
                if len_args < len_provided or len_args - len_default > len_provided:
                    if len(args.defaults):
                        raise TaichiSyntaxError(
                            f"Function receives {len_minimum} to {len_args} argument(s) and {len_provided} provided."
                        )
                    else:
                        raise TaichiSyntaxError(
                            f"Function receives {len_args} argument(s) and {len_provided} provided."
                        )
                # Transform as force-inlined func
                # In the AST, args.defaults line up with the LAST len_default
                # parameters. The first len_provided parameters are supplied,
                # so the defaults still needed start at index
                # len_provided - len_minimum within args.defaults.
                default_start = len_provided - len_minimum
                ctx.argument_data = list(ctx.argument_data)
                for arg in args.defaults[default_start:]:
                    ctx.argument_data.append(build_stmt(ctx, arg).ptr)
                assert len(args.args) == len(ctx.argument_data)
                for i, (arg,
                        data) in enumerate(zip(args.args, ctx.argument_data)):
                    # Remove annotations because they are not used.
                    args.args[i].annotation = None
                    # Template arguments are passed by reference.
                    if isinstance(ctx.func.argument_annotations[i],
                                  ti.template):
                        ctx.create_variable(ctx.func.argument_names[i], data)
                        continue
                    # Create a copy for non-template arguments,
                    # so that they are passed by value.
                    ctx.create_variable(arg.arg, ti.expr_init_func(data))

        with ctx.variable_scope_guard():
            build_stmts(ctx, node.body)

        return node
Ejemplo n.º 12
0
    def build_FunctionDef(ctx, node):
        """Build a ``ti.kernel`` / ``ti.func`` definition, binding its params.

        * Kernel: an optional return annotation is built, declared with
          ``decl_scalar_ret`` and recorded on ``ctx.returns``. Every
          non-template parameter has its annotation built and is registered
          as a variable bound to ``decl_scalar_arg``. The Python parameter
          list is then cleared.
        * ``ti.func`` (always force-inlined here): the parameter count must
          equal ``len(ctx.argument_data)``. Template parameters are bound to
          the provided value directly; others are bound to
          ``ti.expr_init_func(value)`` so they behave as by-value copies.

        The body is then built inside a fresh variable scope.

        Args:
            ctx: builder context (reads ``is_kernel``, ``func``,
                ``argument_data``; writes ``returns``).
            node: the function-definition node.

        Returns:
            The (mutated) ``node``.

        Raises:
            TaichiSyntaxError: if a decorator on this definition resolves to
                ``ti.func``, or if a ``ti.func`` receives the wrong number of
                arguments.
        """
        args = node.args
        # *args, keyword-only parameters and **kwargs are not supported.
        assert args.vararg is None
        assert args.kwonlyargs == []
        assert args.kw_defaults == []
        assert args.kwarg is None

        def transform_as_kernel():
            # Declare the return type, if annotated.
            if node.returns is not None:
                node.returns = build_stmt(ctx, node.returns)
                ti.lang.kernel_arguments.decl_scalar_ret(node.returns.ptr)
                ctx.returns = node.returns.ptr

            for i, arg in enumerate(args.args):
                # Template arguments (class instances such as "self", fields,
                # SNodes, ...) are passed in directly.
                if isinstance(ctx.func.argument_annotations[i], ti.template):
                    continue
                # TODO: sparse_matrix_builder and any_arr parameters are not
                # special-cased yet; they fall through to decl_scalar_arg
                # like every other non-template annotation.
                arg.annotation = build_stmt(ctx, arg.annotation)
                ctx.create_variable(
                    arg.arg,
                    ti.lang.kernel_arguments.decl_scalar_arg(
                        arg.annotation.ptr))
            # The parameters are now declared explicitly; drop the originals.
            node.args.args = []

        if ctx.is_kernel:  # ti.kernel
            # Reject a ti.func decorator on a definition built in kernel mode.
            for decorator in node.decorator_list:
                if ASTResolver.resolve_to(decorator, ti.func, globals()):
                    raise TaichiSyntaxError(
                        "Function definition not allowed in 'ti.kernel'.")
            transform_as_kernel()

        else:  # ti.func
            for decorator in node.decorator_list:
                if ASTResolver.resolve_to(decorator, ti.func, globals()):
                    raise TaichiSyntaxError(
                        "Function definition not allowed in 'ti.func'.")
            # TODO: honour impl.get_runtime().experimental_real_function by
            # calling transform_as_kernel(); for now every ti.func is inlined.
            len_args = len(args.args)
            len_provided = len(ctx.argument_data)
            if len_args != len_provided:
                raise TaichiSyntaxError(
                    f"Function receives {len_args} argument(s) and {len_provided} provided."
                )
            # Transform as force-inlined func
            for i, (arg,
                    data) in enumerate(zip(args.args, ctx.argument_data)):
                # Remove annotations because they are not used.
                args.args[i].annotation = None
                # Template arguments are passed by reference.
                if isinstance(ctx.func.argument_annotations[i], ti.template):
                    ctx.create_variable(ctx.func.argument_names[i], data)
                    continue
                # Create a copy for non-template arguments,
                # so that they are passed by value.
                ctx.create_variable(arg.arg, ti.expr_init_func(data))

        with ctx.variable_scope_guard():
            build_stmts(ctx, node.body)

        return node
Ejemplo n.º 13
0
def test_ast_resolver_basic():
    """``ti.kernel`` resolves to taichi.kernel when taichi is imported as ti."""
    # Import locally so the module-level namespace stays clean.
    import taichi as ti
    ti.init()
    expr = ast.parse('ti.kernel', mode='eval')
    resolved = ASTResolver.resolve_to(expr.body, ti.kernel, locals())
    assert resolved
Ejemplo n.º 14
0
def test_ast_resolver_chain():
    """The dotted path ``ti.lang.ops.atomic_add`` resolves to ti.atomic_add."""
    import taichi as ti
    ti.init()
    chain = ast.parse('ti.lang.ops.atomic_add', mode='eval').body
    resolved = ASTResolver.resolve_to(chain, ti.atomic_add, locals())
    assert resolved
Ejemplo n.º 15
0
def test_ast_resolver_direct_import():
    """A bare name from ``from taichi import kernel`` resolves as well."""
    import taichi as ti
    ti.init()
    from taichi import kernel
    bare_name = ast.parse('kernel', mode='eval').body
    resolved = ASTResolver.resolve_to(bare_name, kernel, locals())
    assert resolved