def build_grouped_ndrange_for(ctx, node):
    """Lower ``for I in ti.grouped(ti.ndrange(...))`` into one flat range-for.

    The generated code opens a single frontend range-for from 0 to
    ``__ndrange.acc_dimensions[0]``. Inside it, the linear loop index is
    decomposed back into one component per dimension. The result is written
    into the user's loop variable, and each component is offset by that
    dimension's lower bound ``__ndrange.bounds[i][0]``.
    (Presumably ``acc_dimensions[i]`` is the product of the extents of
    dimensions ``i..end`` — TODO confirm against the ndrange class.)

    Args:
        ctx: Builder context forwarded to ``build_stmts``/``build_exprs``.
        node: The ``ast.For`` node being lowered. ``node.target`` is read via
            ``.id`` (so it is expected to be an ``ast.Name``), and
            ``node.iter.args[0]`` is expected to be the ``ti.ndrange(...)`` call.

    Returns:
        The ``if ti.static(1): ...`` statement produced from the template,
        with the user's lowered loop body spliced in. It carries ``node``'s
        source location.
    """
    # The user's body is lowered first, before the template is instantiated.
    node.body = build_stmts(ctx, node.body)
    loop_var = node.target.id

    # Both ``{}`` slots receive the user's loop-variable name.
    source = '''
if ti.static(1):
    __ndrange = 0
    ___begin = ti.Expr(0)
    ___end = __ndrange.acc_dimensions[0]
    ___begin = ti.cast(___begin, ti.i32)
    ___end = ti.cast(___end, ti.i32)
    __ndrange_I = ti.Expr(ti.core.make_id_expr(''))
    ti.core.begin_frontend_range_for(__ndrange_I.ptr, ___begin.ptr, ___end.ptr)
    {} = ti.expr_init(ti.Vector([0] * len(__ndrange.dimensions), dt=ti.i32))
    __I = __ndrange_I
    for __grouped_I in range(len(__ndrange.dimensions)):
        __grouped_I_tmp = 0
        if __grouped_I + 1 < len(__ndrange.dimensions):
            __grouped_I_tmp = __I // __ndrange.acc_dimensions[__grouped_I + 1]
        else:
            __grouped_I_tmp = __I
        ti.subscript({}, __grouped_I).assign(__grouped_I_tmp + __ndrange.bounds[__grouped_I][0])
        if __grouped_I + 1 < len(__ndrange.dimensions):
            __I = __I - __grouped_I_tmp * __ndrange.acc_dimensions[__grouped_I + 1]
    ti.core.end_frontend_range_for()
'''.format(loop_var, loop_var)
    wrapper = ast.parse(source).body[0]

    ndrange_call = node.iter.args[0]
    ndrange_call.args = build_exprs(ctx, ndrange_call.args)
    # Swap the ``__ndrange = 0`` placeholder value for the real ndrange call.
    wrapper.body[0].value = ndrange_call

    # The user's body goes after the index decomposition and immediately
    # before the closing ``ti.core.end_frontend_range_for()`` statement.
    *prologue, epilogue = wrapper.body
    wrapper.body = prologue + node.body + [epilogue]
    return ast.copy_location(wrapper, node)
def build_Assign(ctx, node):
    """Lower a (possibly chained) assignment into Taichi frontend statements.

    The right-hand side and every target are built first. If the value is a
    call that resolves to ``ti.static``, the assignment is returned unchanged.
    Otherwise one statement is produced per target, in target order:

    * tuple targets go through ``StmtBuilder.build_assign_unpack``;
    * all other targets go through ``StmtBuilder.build_assign_basic``.

    The per-target statements are then merged with
    ``StmtBuilder.make_single_statement``. Handling every target, not just the
    first, is what makes chained assignments such as ``a = b = v`` work
    (https://github.com/taichi-dev/taichi/issues/2659).

    Args:
        ctx: Builder context passed through to the expression and assignment
            builders.
        node: The ``ast.Assign`` node being lowered.

    Returns:
        ``node`` itself for static assignments, otherwise the merged
        statement.
    """
    node.value = build_expr(ctx, node.value)
    node.targets = build_exprs(ctx, node.targets)

    value_is_call = isinstance(node.value, ast.Call)
    if value_is_call and ASTResolver.resolve_to(node.value.func, ti.static,
                                                globals()):
        # ``x = ti.static(...)`` stays a plain Python assignment.
        return node

    def lower_target(target):
        if isinstance(target, ast.Tuple):
            return StmtBuilder.build_assign_unpack(ctx, node, target)
        # NOTE(review): the same ``node.value`` node is handed to the builder
        # for every target. If the builder embeds it directly, a chained
        # right-hand side is evaluated once per target — confirm in
        # build_assign_basic.
        return StmtBuilder.build_assign_basic(ctx, node, target, node.value)

    lowered = [lower_target(target) for target in node.targets]
    return StmtBuilder.make_single_statement(lowered)
def build_Assign(ctx, node):
    """Lower an assignment statement into Taichi frontend calls.

    Supported forms are single (``a = v``), tuple-unpacking (``a, b = v``) and
    chained (``a = b = v``, ``a = (b, c) = v``) assignments. Previously an
    ``assert`` rejected chained assignments. That assert was also stripped
    under ``python -O``, which let malformed code through silently.

    Each target is handled as follows:

    * A plain name that ``ctx.is_creation`` reports as new becomes
      ``name = ti.expr_init(value)`` and is registered with
      ``ctx.create_variable``.
    * Any other non-tuple target becomes ``target.assign(value)``.
    * A tuple target first stores the value via
      ``__tmp_tuple = ti.expr_init_list(value, n)``. Each element is then
      assigned from ``__tmp_tuple[i]`` using the two rules above, and finally
      ``del __tmp_tuple`` is emitted.

    Assignments whose right-hand side is a call resolving to ``ti.static``
    are returned unchanged.

    Args:
        ctx: Builder context, used to build sub-expressions and to track
            variable creation.
        node: The ``ast.Assign`` node to lower.

    Returns:
        ``node`` for static assignments. For one non-tuple target, the
        generated statement itself. Otherwise, a single statement built by
        ``StmtBuilder.make_single_statement`` that wraps the per-target
        statements.
    """
    node.value = build_expr(ctx, node.value)
    node.targets = build_exprs(ctx, node.targets)

    is_static_assign = isinstance(
        node.value, ast.Call) and ASTResolver.resolve_to(
            node.value.func, ti.static, globals())
    if is_static_assign:
        return node

    def _assign_basic(target, value):
        """Return ``target = ti.expr_init(value)`` or ``target.assign(value)``."""
        if isinstance(target, ast.Name) and ctx.is_creation(target.id):
            # Create, no AST resolution needed.
            target.ctx = ast.Store()
            init = ast.Attribute(value=ast.Name(id='ti', ctx=ast.Load()),
                                 attr='expr_init',
                                 ctx=ast.Load())
            rhs = ast.Call(func=init, args=[value], keywords=[])
            ctx.create_variable(target.id)
            stmt = ast.Assign(targets=[target], value=rhs, type_comment=None)
        else:
            # Existing variable or non-name target (subscript, attribute...):
            # the target is read and its ``assign`` method is called.
            target.ctx = ast.Load()
            func = ast.Attribute(value=target, attr='assign', ctx=ast.Load())
            call = ast.Call(func=func, args=[value], keywords=[])
            stmt = ast.Expr(value=call)
        return ast.copy_location(stmt, node)

    def _assign_unpack(target_tuple):
        """Return one statement unpacking ``node.value`` into ``target_tuple``."""
        elts = target_tuple.elts
        # The value is stored once in ``__tmp_tuple``; each element is then
        # read back by index, so the right-hand side appears only in this
        # holder statement.
        holder = parse_stmt('__tmp_tuple = ti.expr_init_list(0, '
                            f'{len(elts)})')
        holder.value.args[0] = node.value
        stmts = [ast.copy_location(holder, node)]

        def tuple_indexed(i):
            indexing = parse_stmt('__tmp_tuple[0]')
            StmtBuilder.set_subscript_index(indexing.value, parse_expr(str(i)))
            return indexing.value

        for i, elt in enumerate(elts):
            stmts.append(_assign_basic(elt, tuple_indexed(i)))
        stmts.append(parse_stmt('del __tmp_tuple'))
        return StmtBuilder.make_single_statement(stmts)

    stmts = [
        _assign_unpack(target) if isinstance(target, ast.Tuple) else
        _assign_basic(target, node.value) for target in node.targets
    ]
    if len(stmts) == 1:
        # Single target: keep the previous output shape (no extra wrapper).
        return stmts[0]
    # NOTE(review): every per-target statement embeds the same ``node.value``
    # expression, so for chained targets the right-hand side is evaluated once
    # per target. A side-effecting RHS would therefore run repeatedly, unlike
    # plain Python semantics. Binding it to a temporary first would fix this,
    # but ti.expr_init's handling of all value kinds needs confirming first.
    return StmtBuilder.make_single_statement(stmts)