Example #1
    def visit_For(self, node):
        """Lower a Python ``for`` statement.

        ``self.visit(node.iter)`` yields ``(iter_var, low, ext, for_type)`` and
        three cases are distinguished:

        * ``for_type`` is a tuple: the loop is expanded while parsing. The body
          is visited once per iteration with the loop variable registered as a
          ``Symbol.ConstLoopVar`` holding the Python integer, and the resulting
          blocks are concatenated.
        * ``iter_var`` is None: a fresh variable named after the loop target is
          created and the body is wrapped in a ``tvm.tir.For`` starting at 0
          and running ``ext`` iterations (a non-zero ``low`` is folded into the
          value the loop variable is bound to).
        * otherwise: the loop target is registered as ``Symbol.ThreadBind`` for
          the given ``iter_var`` and the body is returned without a
          surrounding ``For``.

        Parameters
        ----------
        node : ast.For
            The loop node; its target must be a plain ``ast.Name``.

        Returns
        -------
        The lowered statement for the loop.
        """
        iter_var, low, ext, for_type = self.visit(node.iter)
        _internal_assert(isinstance(node.target, ast.Name),
                         "The loop iterator should be a variable!")

        _name = node.target.id

        if isinstance(for_type, tuple):
            low = _ir_pass.CanonicalSimplify(low)
            ext = _ir_pass.CanonicalSimplify(ext)
            _internal_assert(isinstance(low, _expr.ConstExpr) and
                             isinstance(ext, _expr.ConstExpr),
                             "Const range should start from a const " +
                             "and iterate const times")

            low, ext = low.value, ext.value
            if ext > 114514:
                logging.log(logging.CRITICAL,
                            '[Warning] Are you sure to unroll a large loop in Python?')

            bodies = []
            for i in range(low, low + ext):
                # Rebind the loop variable to the concrete value of this iteration.
                self.add_symbol(_name, Symbol.ConstLoopVar, i)
                body = visit_list_to_block(self.visit, node.body)
                body = self.wrap_up_realize(node, body)
                # Collect every expanded iteration; without this the
                # concatenation below would always receive an empty list.
                bodies.append(body)
            return concat_list_to_block(bodies)

        if iter_var is None:
            _internal_assert(for_type is not None,
                             "The loop iterating function parse error!")
            offset = iter_var = tvm.te.var(_name)
            # The emitted For always starts at 0, so a non-zero lower bound is
            # added to the value the loop variable name resolves to.
            if not tvm.tir.analysis.expr_deep_equal(
                    low, tvm.runtime.const(0, 'int32')):
                offset = iter_var + low
            self.add_symbol(_name, Symbol.LoopVar, offset)
            _body = visit_list_to_block(self.visit, node.body)
        else:
            _internal_assert(for_type is None,
                             "The loop bind function parse error!")
            self.add_symbol(_name, Symbol.ThreadBind, iter_var)
            # self.device is raised only while the bound loop body is visited.
            self.device += 1
            _body = visit_list_to_block(self.visit, node.body)
            self.device -= 1

        _body = self.wrap_up_realize(node, _body)

        if for_type is None:
            res = _body
        else:
            _internal_assert(not isinstance(for_type, tuple),
                             "Micro expansion should be handled before!")
            res = tvm.tir.For(iter_var, tvm.runtime.const(0, 'int32'), ext,
                              for_type, 0, _body)

        # NOTE(review): the symbol registered for `_name` is not removed here;
        # confirm whether it should be dropped once the loop has been lowered.
        return res
Example #2
def ana_lower(sch, args, binds=None, simple_mode=True):
    """Do lower while keeping all axes in IR.

    i.e. Do not eliminate loop with extent of 1, do not vectorize, unroll or
    inject virtual threads.

    Parameters
    ----------
    sch
        Schedule to lower; it is normalized before bounds are inferred.
    args
        Forwarded to ``build_module.get_binds`` together with ``binds``.
    binds : optional
        Forwarded to ``build_module.get_binds``; the returned mapping is the
        one handed to ``StorageFlatten``.
    simple_mode : bool, optional
        Must be truthy; the function asserts on it before returning.

    Returns
    -------
    stmt
        The statement produced by ``ScheduleOps`` after ``StorageFlatten`` and
        ``CanonicalSimplify``.
    """
    binds, _ = build_module.get_binds(args, binds)
    sch = sch.normalize()
    # Phase 0
    bounds = schedule.InferBound(sch)
    stmt = schedule.ScheduleOps(sch, bounds, True)
    stmt = ir_pass.StorageFlatten(stmt, binds, 64)
    stmt = ir_pass.CanonicalSimplify(stmt)
    # Only the simple (statement-returning) mode is implemented here.
    assert simple_mode
    return stmt
Example #3
    def visit_If(self, node):
        """Lower a Python ``if`` statement.

        The condition is simplified first. When it folds to an ``IntImm`` the
        branch is resolved while parsing: the taken branch is lowered on its
        own, or a no-op is returned if the condition is false and there is no
        ``else`` part. Otherwise a ``tvm.tir.IfThenElse`` is emitted.

        Parameters
        ----------
        node : ast.If
            The conditional node to lower.

        Returns
        -------
        The lowered statement.
        """
        cond = _ir_pass.CanonicalSimplify(self.visit(node.test))

        # Return no IfThenElse if proven
        if isinstance(cond, _expr.IntImm):
            if cond.value:
                return visit_list_to_block(self.visit, node.body)
            if node.orelse:
                return visit_list_to_block(self.visit, node.orelse)
            return util.make_nop()

        if_body = visit_list_to_block(self.visit, node.body)

        if node.orelse:
            else_body = visit_list_to_block(self.visit, node.orelse)
        else:
            # No else branch in the source: emit IfThenElse without one.
            else_body = None
        return tvm.tir.IfThenElse(cond, if_body, else_body)
Example #4
def lower(sch, args, name="default_function", binds=None, simple_mode=False):
    """Lowering step before build into target.

    Parameters
    ----------
    sch : tvm.te.schedule.Schedule
        The schedule to be built

    args : list of Buffer or Tensor or Var
        The argument lists to the function.

    name : str, optional
        The name of result function.

    binds : dict of :any:`Tensor` to :any:`Buffer`, optional
        Dictionary that maps the Tensor to Buffer which specified the data layout
        requirement of the function. By default, a new compact buffer is created
        for each tensor in the argument.

    simple_mode : bool, optional
        Whether only output simple and compact statement, this will skip
        LoopPartition, api wrapper generation and Unrolling.

    Returns
    -------
    m : IRModule or Stmt
       The result IRModule if simple_mode=False; otherwise the Stmt before
       make api is returned.
    """
    cfg = BuildConfig.current()
    add_lower_pass = cfg.add_lower_pass if cfg.add_lower_pass else []
    if cfg.dump_pass_ir:
        # NOTE(review): the argument was missing from the truncated call; the
        # custom pass list is the only candidate in scope -- confirm.
        add_lower_pass = BuildConfig._dump_ir.decorate_custompass(add_lower_pass)
    # Custom passes are (phase, callable) pairs; bucket the callables by phase.
    lower_phase0 = [x[1] for x in add_lower_pass if x[0] == 0]
    lower_phase1 = [x[1] for x in add_lower_pass if x[0] == 1]
    lower_phase2 = [x[1] for x in add_lower_pass if x[0] == 2]
    lower_phase3 = [x[1] for x in add_lower_pass if x[0] > 2]

    # Phase 0
    # NOTE(review): `stmt` is only bound when `sch` is a Schedule; any other
    # input fails below on the unbound name.
    if isinstance(sch, schedule.Schedule):
        stmt = form_body(sch)

    for f in lower_phase0:
        stmt = f(stmt)

    compact = ir_pass.VerifyCompactBuffer(stmt)
    binds, arg_list = get_binds(args, compact, binds)

    # Phase 1
    stmt = ir_pass.RewriteForTensorCore(stmt, sch, binds)
    # NOTE(review): the fourth argument was missing from the truncated call;
    # restored as the bound-checker flag that is also consulted further down
    # -- confirm against the StorageFlatten signature.
    stmt = ir_pass.StorageFlatten(stmt, binds, 64,
                                  cfg.instrument_bound_checkers)
    stmt = ir_pass.NarrowDataType(stmt, 32)
    stmt = ir_pass.CanonicalSimplify(stmt)
    for f in lower_phase1:
        stmt = f(stmt)

    # Phase 2
    if not simple_mode:
        stmt = ir_pass.LoopPartition(stmt, cfg.partition_const_loop)
    # Exactly one of the two passes must run: either vectorization is skipped
    # or the loops are vectorized.
    if cfg.disable_vectorize:
        stmt = ir_pass.SkipVectorize(stmt)
    else:
        stmt = ir_pass.VectorizeLoop(stmt)
    stmt = ir_pass.InjectVirtualThread(stmt)
    stmt = ir_pass.InjectDoubleBuffer(stmt, cfg.double_buffer_split_loop)
    stmt = ir_pass.StorageRewrite(stmt)
    # NOTE(review): confirm the argument list against the UnrollLoop
    # signature; a max-depth argument may be missing here.
    stmt = ir_pass.UnrollLoop(stmt, cfg.auto_unroll_max_step,
                              cfg.auto_unroll_max_extent, cfg.unroll_explicit)
    for f in lower_phase2:
        stmt = f(stmt)

    # Phase 3
    stmt = ir_pass.Simplify(stmt)
    stmt = ir_pass.RemoveNoOp(stmt)
    if not cfg.disable_select_rewriting:
        stmt = ir_pass.RewriteUnsafeSelect(stmt)
    for f in lower_phase3:
        stmt = f(stmt)
    # Instrument BoundCheckers
    if cfg.instrument_bound_checkers:
        stmt = ir_pass.InstrumentBoundCheckers(stmt)
    if simple_mode:
        return stmt

    # NOTE(review): the attribute value was missing from the truncated call;
    # restored as the function name wrapped in a runtime String -- confirm.
    f = tvm.tir.PrimFunc(arg_list, stmt).with_attr("global_symbol",
                                                   tvm.runtime.String(name))
    if cfg.restricted_func:
        f = f.with_attr("tir.noalias", True)
    mod = tvm.IRModule({name: f})
    return tvm.tir.transform.MakePackedAPI()(mod)