def parse_for(self, node, parent):
    """Translate a for-node into a zero-based ``tir.For`` wrapped in a let.

    The loop pieces come from ``self._for_loop_vars(node)``. The result is::

        let extent_var = floordiv(upper - lower, step) in
          for iter_var (start 0, extent extent_var, kind for_type):
            let c_var = iter_var * step + lower in
              <parsed node.body()>

    Args:
        node: The for-node being translated; its ``body()`` is parsed
            recursively with ``node`` as the parent.
        parent: Enclosing node. Not used by this method.

    Returns:
        The outer ``tir.LetStmt`` that binds the extent and contains the loop.
    """
    with self._for_loop_vars(node) as loop_parts:
        (iter_var, c_var, extent_var,
         lower, upper, step, for_type) = loop_parts

        # NOTE(review): floor division gives the exact trip count only if
        # `upper` has already been adjusted for inclusive/strict bounds by
        # whatever feeds _for_loop_vars -- confirm for strict conditions
        # whose span is not a multiple of `step`.
        span = tir.Sub(upper, lower)
        extent = tir.FloorDiv(span, step)

        loop_start = tir.IntImm('int32', 0)

        # Recover the original iterator value from the normalized counter.
        scaled = tir.Mul(iter_var, step)
        c_value = tir.Add(scaled, lower)

        parsed_body = self.parse(node.body(), node)
        inner_let = tir.LetStmt(c_var, c_value, parsed_body)
        loop = tir.For(iter_var, loop_start, extent_var, for_type, inner_let)
        return tir.LetStmt(extent_var, extent, loop)
def _for_bounds(self, node):
    """Extract the iterator name and parsed bounds of a for-node.

    Args:
        node: A for-node exposing ``iterator()``, ``init()``, ``cond()`` and
            ``inc()``. The first argument of its condition must print (via
            ``to_C_str()``) identically to its iterator.

    Returns:
        A tuple ``(var_name, lower, upper, step)`` where ``var_name`` is the
        iterator's C string and the other three are the results of
        ``self.expr_parser.parse`` on the init, the condition's second
        argument, and the increment. When the condition is an
        ``isl.ast_expr_op_le`` or ``isl.ast_expr_op_ge``, ``upper`` is
        replaced by ``upper + step``.
    """
    cond = node.cond()
    # Internal invariant: the condition must compare the loop iterator itself.
    assert cond.arg(0).to_C_str() == node.iterator().to_C_str()

    var_name = node.iterator().to_C_str()
    parse_expr = self.expr_parser.parse
    lower = parse_expr(node.init())
    upper = parse_expr(cond.arg(1))
    step = parse_expr(node.inc())

    # Inclusive comparisons (<= / >=) get one extra step added to the bound.
    inclusive_ops = (isl.ast_expr_op_le, isl.ast_expr_op_ge)
    if isinstance(cond, inclusive_ops):
        upper = tir.Add(upper, step)

    return var_name, lower, upper, step
def parse_op_add(self, expr, parent):
    """Translate an addition expression into a ``tir.Add`` node.

    Args:
        expr: The addition expression. With one argument it is emitted as
            ``0 + arg`` using an ``'int32'`` zero; otherwise its first two
            arguments are added.
        parent: Enclosing expression. Not used by this method.

    Returns:
        A ``tir.Add`` over the two operands, each parsed with ``expr`` as
        its parent.
    """
    if expr.n_arg() == 1:
        # Single-operand form: pad with a literal zero on the left.
        lhs = tir.IntImm('int32', 0)
        rhs = self.parse(expr.arg(0), expr)
    else:
        lhs = self.parse(expr.arg(0), expr)
        rhs = self.parse(expr.arg(1), expr)
    return tir.Add(lhs, rhs)