Пример #1
0
    def build_grouped_ndrange_for(ctx, node):
        # for I in ti.grouped(ti.ndrange(n, m))
        node.body = build_stmts(ctx, node.body)
        target = node.target.id
        template = '''
if ti.static(1):
    __ndrange = 0
    ___begin = ti.Expr(0)
    ___end = __ndrange.acc_dimensions[0]
    ___begin = ti.cast(___begin, ti.i32)
    ___end = ti.cast(___end, ti.i32)
    __ndrange_I = ti.Expr(ti.core.make_id_expr(''))
    ti.core.begin_frontend_range_for(__ndrange_I.ptr, ___begin.ptr, ___end.ptr)
    {} = ti.expr_init(ti.Vector([0] * len(__ndrange.dimensions), dt=ti.i32))
    __I = __ndrange_I
    for __grouped_I in range(len(__ndrange.dimensions)):
        __grouped_I_tmp = 0
        if __grouped_I + 1 < len(__ndrange.dimensions):
            __grouped_I_tmp = __I // __ndrange.acc_dimensions[__grouped_I + 1]
        else:
            __grouped_I_tmp = __I
        ti.subscript({}, __grouped_I).assign(__grouped_I_tmp + __ndrange.bounds[__grouped_I][0])
        if __grouped_I + 1 < len(__ndrange.dimensions):
            __I = __I - __grouped_I_tmp * __ndrange.acc_dimensions[__grouped_I + 1]
    ti.core.end_frontend_range_for()
        '''.format(target, target)
        t = ast.parse(template).body[0]
        node.iter.args[0].args = build_exprs(ctx, node.iter.args[0].args)
        t.body[0].value = node.iter.args[0]
        cut = len(t.body) - 1
        t.body = t.body[:cut] + node.body + t.body[cut:]
        return ast.copy_location(t, node)
Пример #2
0
    def build_Assign(ctx, node):
        """Transform an assignment, including chained ones (``a = b = v``).

        The value and the targets are transformed first. When the transformed
        value is a call that resolves to ``ti.static``, the node is returned
        as-is. Otherwise one statement is generated per target and the
        statements are merged with ``StmtBuilder.make_single_statement``.

        Args:
            ctx: Builder context forwarded to the expression/statement builders.
            node: The ``ast.Assign`` node to transform (modified in place).

        Returns:
            ``node`` for a ``ti.static`` assignment, otherwise whatever
            ``StmtBuilder.make_single_statement`` returns for the generated
            statements.
        """
        node.value = build_expr(ctx, node.value)
        node.targets = build_exprs(ctx, node.targets)

        value = node.value
        if isinstance(value, ast.Call) and ASTResolver.resolve_to(
                value.func, ti.static, globals()):
            return node

        def lower_target(target):
            # Tuple targets are unpacked; everything else is a plain assign.
            if isinstance(target, ast.Tuple):
                return StmtBuilder.build_assign_unpack(ctx, node, target)
            return StmtBuilder.build_assign_basic(ctx, node, target,
                                                  node.value)

        # One generated statement per target, so chained assignments are
        # supported. Ref https://github.com/taichi-dev/taichi/issues/2659.
        return StmtBuilder.make_single_statement(
            [lower_target(target) for target in node.targets])
Пример #3
0
    def build_Assign(ctx, node):
        """Transform a single-target assignment statement.

        Exactly one target is accepted (asserted). After the value and target
        are transformed:

        * a value that is a call resolving to ``ti.static`` leaves the node
          untouched;
        * a tuple target is expanded into a ``__tmp_tuple`` holder statement,
          one statement per element and a trailing ``del __tmp_tuple``;
        * any other target yields a single statement.

        For each destination, a name for which ``ctx.is_creation`` is true
        becomes ``name = ti.expr_init(value)`` and is registered through
        ``ctx.create_variable``; every other destination becomes the
        expression statement ``dest.assign(value)``.

        Args:
            ctx: Builder context (queried with ``is_creation`` and updated with
                ``create_variable``).
            node: The ``ast.Assign`` node to transform (modified in place).

        Returns:
            The replacement AST statement.
        """
        assert (len(node.targets) == 1)
        node.value = build_expr(ctx, node.value)
        node.targets = build_exprs(ctx, node.targets)

        rhs_node = node.value
        if isinstance(rhs_node, ast.Call) and ASTResolver.resolve_to(
                rhs_node.func, ti.static, globals()):
            return node

        def creates_variable(dest):
            # Only plain names can introduce a new variable.
            return isinstance(dest, ast.Name) and ctx.is_creation(dest.id)

        def expr_init_call(value_expr):
            # Builds ``ti.expr_init(<value_expr>)``; no AST resolution needed.
            callee = ast.Attribute(value=ast.Name(id='ti', ctx=ast.Load()),
                                   attr='expr_init',
                                   ctx=ast.Load())
            return ast.Call(
                func=callee,
                args=[value_expr],
                keywords=[],
            )

        def assign_call_stmt(dest, value_expr):
            # Builds the expression statement ``<dest>.assign(<value_expr>)``.
            method = ast.Attribute(value=dest, attr='assign', ctx=ast.Load())
            invocation = ast.Call(func=method,
                                  args=[value_expr],
                                  keywords=[])
            return ast.Expr(value=invocation)

        lhs = node.targets[0]

        if not isinstance(lhs, ast.Tuple):
            if creates_variable(lhs):
                new_name = lhs.id
                init_value = expr_init_call(node.value)
                ctx.create_variable(new_name)
                creation = ast.Assign(targets=node.targets,
                                      value=init_value,
                                      type_comment=None)
                return ast.copy_location(creation, node)
            # The target is read (to call .assign on it), not stored to.
            lhs.ctx = ast.Load()
            return ast.copy_location(assign_call_stmt(lhs, node.value), node)

        # Tuple unpacking: evaluate the right-hand side once into a temporary
        # list, then bind each element of the target tuple from it.
        targets = lhs.elts

        holder = parse_stmt('__tmp_tuple = ti.expr_init_list(0, '
                            f'{len(targets)})')
        holder.value.args[0] = node.value
        stmts = [holder]

        def tuple_indexed(i):
            # Builds the expression ``__tmp_tuple[i]``.
            indexing = parse_stmt('__tmp_tuple[0]')
            StmtBuilder.set_subscript_index(indexing.value,
                                            parse_expr("{}".format(i)))
            return indexing.value

        for i, target in enumerate(targets):
            if creates_variable(target):
                var_name = target.id
                target.ctx = ast.Store()
                rhs = expr_init_call(tuple_indexed(i))
                ctx.create_variable(var_name)
                stmts.append(
                    ast.Assign(targets=[target],
                               value=rhs,
                               type_comment=None))
            else:
                target.ctx = ast.Load()
                stmts.append(assign_call_stmt(target, tuple_indexed(i)))

        # Locations are copied before the ``del`` statement is appended, so
        # the ``del`` keeps whatever location ``parse_stmt`` gave it.
        for stmt in stmts:
            ast.copy_location(stmt, node)
        stmts.append(parse_stmt('del __tmp_tuple'))
        return StmtBuilder.make_single_statement(stmts)