def visit_list_to_block(visit, lst):
    """Visit and concatenate a list of Python IR nodes to HalideIR Block

    Every statement in ``lst`` that is not a docstring is lowered with
    ``visit``.  Lowered results that compare equal (via ``_ir_pass.Equal``)
    to ``util.make_nop()`` are discarded.  If nothing survives, a no-op is
    returned; otherwise the survivors are joined by ``concat_list_to_block``.
    """
    # First pass: lower every non-docstring statement, in source order.
    lowered = []
    for node in lst:
        if util.is_docstring(node):
            continue
        lowered.append(visit(node))
    # Second pass: drop statements that lowered to a no-op.
    kept = [item for item in lowered if not _ir_pass.Equal(item, util.make_nop())]
    if kept:
        return concat_list_to_block(kept)
    return util.make_nop()
def visit_For(self, node):
    """Lower a Python ``for`` statement to HalideIR.

    ``self.visit(node.iter)`` must produce ``(iter_var, low, ext, for_type)``.
    Three shapes are handled:

    * ``for_type`` is a tuple: ``low``/``ext`` must simplify to constants and
      the loop is unrolled in Python.  The body is visited once per
      iteration, with the loop name bound as ``Symbol.ConstLoopVar``.  An
      empty trip count yields a no-op.
    * ``iter_var`` is None: a fresh variable is created and the body is
      wrapped in a ``tvm.tir.For`` starting at 0 with extent ``ext``.  A
      non-zero ``low`` is re-applied by binding the name to
      ``iter_var + low``.
    * otherwise ``iter_var`` is a thread binding (``for_type`` must be None).
      The body is visited with ``self.device`` incremented and returned
      without a ``For`` wrapper.

    The loop name is removed from ``self.symbols`` before returning.
    """
    iter_var, low, ext, for_type = self.visit(node.iter)
    _internal_assert(isinstance(node.target, ast.Name),
                     "The loop iterator should be a variable!")
    _name = node.target.id
    if isinstance(for_type, tuple):
        low = _ir_pass.CanonicalSimplify(low)
        ext = _ir_pass.CanonicalSimplify(ext)
        _internal_assert(isinstance(low, _expr.ConstExpr) and
                         isinstance(ext, _expr.ConstExpr),
                         "Const range should start from a const " +
                         "and iterate const times")
        low, ext = low.value, ext.value
        if ext > 114514:
            logging.log(logging.CRITICAL,
                        '[Warning] Are you sure to unroll a large loop in Python?')
        bodies = []
        for i in range(low, low + ext):
            # Each unrolled copy sees the loop name as the Python int ``i``.
            self.add_symbol(_name, Symbol.ConstLoopVar, i)
            body = visit_list_to_block(self.visit, node.body)
            body = self.wrap_up_realize(node, body)
            bodies.append(body)
            self.symbols.pop(_name)
        # A zero (or negative) trip count unrolls to nothing.  Mirror
        # visit_list_to_block and emit a no-op rather than handing an empty
        # list to concat_list_to_block.
        if not bodies:
            return util.make_nop()
        return concat_list_to_block(bodies)

    if iter_var is None:
        _internal_assert(for_type is not None, "The loop iterating function parse error!")
        offset = iter_var = _api.var(_name)
        # The emitted For always starts at 0, so a non-zero lower bound is
        # folded into what the Python name refers to.
        if not _ir_pass.Equal(low, _api.const(0, 'int32')):
            offset = iter_var + low
        self.add_symbol(_name, Symbol.LoopVar, offset)
        _body = visit_list_to_block(self.visit, node.body)
    else:
        # Thread binding; presumably iter_var is an IterVar (add_symbol reads
        # .var.name and .dom.extent) -- confirm against the bind intrinsic.
        _internal_assert(for_type is None, "The loop bind function parse error!")
        self.add_symbol(_name, Symbol.ThreadBind, iter_var)
        self.device += 1
        _body = visit_list_to_block(self.visit, node.body)
        self.device -= 1
    _body = self.wrap_up_realize(node, _body)
    if for_type is None:
        res = _body
    else:
        _internal_assert(not isinstance(for_type, tuple),
                         "Micro expansion should be handled before!")
        res = tvm.tir.For(iter_var, _api.const(0, 'int32'), ext, for_type, 0, _body)
    self.symbols.pop(_name)
    return res
def _range(annotation, args):
    """Handling TVM loop types

    Turn the arguments of a loop intrinsic into
    ``(iter_var, low, ext, for_type)``.  With one argument the range starts at
    a 32-bit constant 0 and ``args[0]`` is the extent.  With two arguments the
    first is the start, and the extent is the second argument minus the first
    (unless the start already equals 0).  ``for_type`` is taken from
    ``LOOP_INTRIN[annotation]``; ``iter_var`` is always None.
    """
    count = len(args)
    if count == 1:
        low = _api.const(0, dtype='int32')
        ext = args[0]
    else:
        _internal_assert(count == 2, "A loop intrinsic should only have 1 or 2 arguments!")
        low = args[0]
        ext = args[1]
    starts_at_zero = ir_pass.Equal(low, _api.const(0, dtype='int32'))
    if not starts_at_zero:
        ext = ext - low
    return None, low, ext, LOOP_INTRIN[annotation]
def add_symbol(self, key, ty, val): #pylint: disable=invalid-name
    """Add value to the symbol table context

    Stores ``(ty, val)`` under ``key`` in ``self.symbols``.  If ``key`` is
    already present, ``_internal_assert`` is called with a name-conflict
    message.  For ``Symbol.ThreadBind`` entries, the first binding of a given
    ``val.var.name`` is recorded in ``self.binds``.  Later bindings of the
    same name must have an extent equal to the recorded one, and the recorded
    binding is stored in the table in place of ``val``.
    """
    if key in self.symbols.keys():
        conflict = "Name conflict in symbol table! [%s] %s -> %s" % (
            key, str(self.symbols[key]), str((ty, val)))
        _internal_assert(False, conflict)
    self.symbols[key] = ty, val
    if not ty == Symbol.ThreadBind:
        return
    thread_name = val.var.name
    if thread_name not in self.binds.keys():
        # First time this thread variable is bound: remember it.
        self.binds[thread_name] = val
        return
    known = self.binds[thread_name]
    _internal_assert(_ir_pass.Equal(known.dom.extent, val.dom.extent),
                     "Thread extents should be uniform!")
    # Reuse the first-seen binding so every use refers to the same object.
    self.symbols[key] = ty, known