def visit_For(self, node: ast.For):
    """
    Lower a Python ``for`` statement into an ``ir.ForLoop`` appended to ``self.body``.

    The loop header is only sanity-checked here, not fully lowered.
    Lowering injects extra index arithmetic, and not every variable type
    is necessarily known yet.

    Raises:
        CompilerError: if the loop has an ``else`` clause, if the
            target/iterable pair cannot be unpacked safely, or if any name
            appears in both the target and the iterable sequences.
    """
    if node.orelse:
        raise CompilerError(
            "else clause not supported for for statements")
    iter_node = self.visit(node.iter)
    target_node = self.visit(node.target)
    assert iter_node is not None
    assert target_node is not None
    pos = extract_positional_info(node)
    targets = set()
    iterables = set()
    # Do initial checks for weird issues that may arise here.
    # We don't lower it fully at this point, because it injects
    # additional arithmetic and not all variable types may be fully known
    # at this point.
    try:
        for target, iterable in unpack_iterated(
                target_node, iter_node, include_enumerate_indices=True):
            targets.add(target)
            iterables.add(iterable)
    except ValueError as err:
        # Generator will throw an error on bad unpacking
        msg = f"Cannot safely unpack for loop expression, line: {pos.line_begin}"
        raise CompilerError(msg) from err
    conflicts = targets.intersection(iterables)
    if conflicts:
        # Conflicts may be IR nodes rather than plain strings, so convert
        # each one with str(). Sort them so the message is deterministic.
        conflict_names = ", ".join(sorted(str(c) for c in conflicts))
        msg = f"{conflict_names} appear in both the target and iterable sequences of a for loop, " \
              f"line {pos.line_begin}. This is not supported."
        raise CompilerError(msg)
    with self.loop_region(node):
        for stmt in node.body:
            self.visit(stmt)
        # NOTE(review): presumably loop_region points self.body at a fresh
        # list for the loop body while active, and restores the enclosing
        # body on exit. Confirm against loop_region.
        loop = ir.ForLoop(target_node, iter_node, self.body, pos)
    self.body.append(loop)
def _(self, node: ir.ForLoop):
    """
    Rewrite ``node`` into a single-counter loop and visit the rewritten body.

    The original node's position is kept on the returned loop.
    """
    single = make_single_index_loop(node, self.symbols)
    visited_body = self.visit(single.body)
    return ir.ForLoop(single.target, single.iterable, visited_body, node.pos)
def make_single_index_loop(header: ir.ForLoop, symbols):
    """
    Make loop interval of the form (start, stop, step).

    This tries to find a safe method of calculation.

    This assumes (with runtime verification if necessary)
    that 'stop - start' will not overflow.

    The returned loop iterates a single fresh integer counter over the
    interval shared by all iterables of ``header``. For each
    (target, iterable) pair, the body begins with an assignment that binds
    the target. For an ``ir.AffineSeq`` iterable, the target gets the
    computed index itself. For any other iterable, it gets a subscript of
    that iterable at the computed index. The original loop body follows.

    Args:
        header: the loop to rewrite.
        symbols: symbol table used to create the unique loop counter name.

    Returns:
        A new ``ir.ForLoop`` over ``ir.AffineSeq(start, stop, step)``.

    References:
        LIVINSKII et. al, Random Testing for C and C++ Compilers with YARPGen
        Dietz et. al, Understanding Integer Overflow in C/C++
        Bachmann et. al, Chains of Recurrences - a method to expedite the evaluation of closed-form functions
        https://gcc.gnu.org/onlinedocs/gcc/Integer-Overflow-Builtins.html
        https://developercommunity.visualstudio.com/t/please-implement-integer-overflow-detection/409051
        https://numpy.org/doc/stable/user/building.html
    """
    by_iterable = {}
    intervals = set()
    interval_from_iterable = IntervalBuilder()
    for _, iterable in unpack_iterated(header.target, header.iterable):
        interval = interval_from_iterable(iterable)
        by_iterable[iterable] = interval
        intervals.add(interval)

    loop_start, loop_stop, loop_step = _find_shared_interval(intervals)
    loop_expr = ir.AffineSeq(loop_start, loop_stop, loop_step)
    # Todo: this needs a default setting to avoid excessive casts
    loop_counter = symbols.make_unique_name_like("i", type_=tr.Int32)
    body = []
    pos = header.pos

    for target, iterable in unpack_iterated(header.target, header.iterable):
        start, _, step = by_iterable[iterable]
        if step == loop_step:
            if start == loop_start:
                # The counter already takes exactly this iterable's index values.
                index = loop_counter
            else:
                # Same step, different start: this only works when the shared
                # interval starts at zero, so the counter can be offset.
                assert loop_start == ir.Zero
                index = ir.BinOp(loop_counter, start, "+")
        else:
            # loop counter must be normalized (0, 1, 2, ...) so that each
            # iterable can be addressed as start + step * counter.
            assert loop_start == ir.Zero
            assert loop_step == ir.One
            index = ir.BinOp(step, loop_counter, "*")
            if start != ir.Zero:
                index = ir.BinOp(start, index, "+")
        value = index if isinstance(iterable, ir.AffineSeq) else ir.Subscript(iterable, index)
        body.append(ir.Assign(target, value, pos))
    # Todo: this doesn't hoist initial setup
    body.extend(header.body)
    return ir.ForLoop(loop_counter, loop_expr, body, pos)
def _(self, node: ir.ForLoop):
    """
    Visit the loop body, then strip trailing ``continue`` statements from it.

    The original node is returned unchanged if the resulting body compares
    equal to the existing one. Otherwise a new ``ir.ForLoop`` is built with
    the same target, iterable and position.
    """
    rewritten = remove_trailing_continues(self.visit(node.body))
    if rewritten != node.body:
        return ir.ForLoop(node.target, node.iterable, rewritten, node.pos)
    return node
def _(self, node: ir.ForLoop):
    """
    Visit the loop body and rebuild the loop only if the body changed.

    If the visited body compares equal to the original, the same node
    object is returned.
    """
    visited = self.visit(node.body)
    if visited != node.body:
        return ir.ForLoop(node.target, node.iterable, visited, node.pos)
    return node