Example #1
    def visit_Assign(self, node):
        # Handle a Python assignment in the hybrid script: detuple the outputs
        # of an Operation, declare a buffer or constant, or emit a store.
        rhs = self.visit(node.value)
        if isinstance(rhs, Operation):
            rmap = {}
            _internal_assert(len(node.targets) == rhs.num_outputs, \
                             "Unable to detuple the outs to targets")
            for i in range(rhs.num_outputs):
                _internal_assert(isinstance(node.targets[i], ast.Name),
                                 "You should bind a pure name to the tensors")
                self.add_symbol(node.targets[i].id, Symbol.GlobalBuffer, rhs.output(i))
                rmap[rhs.outputs[i].op] = rhs.output(i)
            return util.replace_io(rhs.body, rmap)

        _internal_assert(len(node.targets) == 1, "So far only one-valued assignment is supported!")
        lhs = node.targets[0]
        if isinstance(rhs, _expr.PrimExpr):
            rhs = _ir_pass.Simplify(rhs)
        if isinstance(lhs, ast.Name):
            #TODO: support defined intermediate buffer later
            lhs_ = lhs
            lhs = lhs.id
            if lhs in self.symbols.keys():
                ty, _ = self.symbols[lhs]
                _internal_assert(ty != Symbol.LoopVar, \
                                 "Loop variable cannot be overwritten!")
            decl, _, rw = self.usage[lhs]
            if decl == lhs_:
                _internal_assert(lhs not in self.symbols.keys(),
                                 "This value should not be defined before this point!")
                if isinstance(rhs, tuple):
                    shape, dtype, scope = rhs
                    ph = tvm.te.placeholder(shape, dtype=dtype, name=lhs)
                    self.add_symbol(lhs, getattr(Symbol, scope.title() + "Buffer"), ph)
                    if scope == 'output':
                        self.outputs.append(lhs)
                    return util.make_nop()
                if isinstance(rhs, util.halide_imm_types) and ast.Store not in rw:
                    self.add_symbol(lhs, Symbol.ConstVar, rhs)
                else:
                    _internal_assert(self.device == 0,
                                     "Single variables are not supported on the device side!\n" + \
                                     "If you are using GPU, please allocate a 'local' spad " + \
                                     "outside the bind body")
                    ph = tvm.te.placeholder((1, ), dtype=rhs.dtype, name=lhs)
                    self.add_symbol(lhs, Symbol.BufferVar, ph)
            lhs = self.visit(lhs_)
            if lhs is not None:
                buf, args = lhs
                return tvm.tir.Provide(buf.op, 0, rhs, args)
            return util.make_nop()

        # Subscripted target (e.g. a[i] = ...): lower the store to a Provide node.
        lhs, args = self.visit(lhs)
        _internal_assert(isinstance(lhs, Tensor), \
                         "An array access's LHS is expected to be an expr.Call!")
        res = tvm.tir.Provide(lhs.op, lhs.value_index, rhs, args)
        return res
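
For context, a minimal sketch of the kind of hybrid-script code this handler parses, assuming the tvm.te.hybrid frontend from the same TVM generation; the function and tensor names below are illustrative only:

import tvm
from tvm import te
from tvm.te import hybrid

@hybrid.script
def hybrid_double(a):
    # 'b = output_tensor(...)' takes the tuple branch of visit_Assign and
    # declares an output placeholder; 'b[i] = ...' takes the subscript branch
    # and is lowered to a Provide statement.
    b = output_tensor((16,), "float32")
    for i in range(16):
        b[i] = a[i] * 2.0
    return b

a = te.placeholder((16,), name="a", dtype="float32")
b = hybrid_double(a)
s = te.create_schedule(b.op)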
Example #2
def get_const_int(exp):
    """Verifies expr is integer and get the constant value.

    Parameters
    ----------
    exp : tvm.Expr or int
        The input expression.

    Returns
    -------
    out_value : int
        The output.
    """
    if isinstance(exp, int):
        return exp
    if not isinstance(exp, (expr.IntImm, )):
        exp = ir_pass.Simplify(exp)
    if not isinstance(exp, (expr.IntImm, )):
        raise ValueError("Expect value to be constant int")
    return exp.value
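
A brief usage sketch, assuming tvm.tir.IntImm from the same API generation; the values are illustrative:

import tvm

assert get_const_int(4) == 4
n = tvm.tir.IntImm("int32", 8)
assert get_const_int(n) == 8
# A constant-foldable expression is simplified to an IntImm before extraction.
assert get_const_int(n + 2) == 10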
Example #3
def get_const_tuple(in_tuple):
    """Verifies input tuple is IntImm or Var, returns tuple of int or Var.

    Parameters
    ----------
    in_tuple : tuple of Expr
        The input.

    Returns
    -------
    out_tuple : tuple of int or Var
        The output.
    """
    ret = []
    for elem in in_tuple:
        if isinstance(elem, expr.Var):
            ret.append(elem)
        elif not isinstance(elem, (expr.IntImm, int)):
            elem = ir_pass.Simplify(elem)
            if isinstance(elem, expr.IntImm):
                # The expression folded to a constant; extract its int value.
                ret.append(get_const_int(elem))
            else:
                ret.append(elem)
        else:
            ret.append(get_const_int(elem))
    return tuple(ret)
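
A brief usage sketch, again with illustrative values, assuming tvm.te.var and tvm.tir.IntImm from the same API generation:

import tvm
from tvm import te

n = te.var("n")
shape = (tvm.tir.IntImm("int32", 2), 3, n)
# Constants come back as Python ints; symbolic Vars pass through untouched.
print(get_const_tuple(shape))  # (2, 3, n)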
Example #4
def lower(sch, args, name="default_function", binds=None, simple_mode=False):
    """Lowering step before build into target.

    Parameters
    ----------
    sch : tvm.te.schedule.Schedule
        The schedule to be built

    args : list of Buffer or Tensor or Var
        The argument list of the function.

    name : str, optional
        The name of the result function.

    binds : dict of :any:`Tensor` to :any:`Buffer`, optional
        Dictionary that maps a Tensor to the Buffer that specifies the data
        layout requirement of the function. By default, a new compact buffer
        is created for each tensor in the arguments.

    simple_mode : bool, optional
        Whether to output only a simple and compact statement; this skips
        LoopPartition, API wrapper generation and unrolling.

    Returns
    -------
    m : IRModule or Stmt
       The lowered IRModule if simple_mode=False; otherwise the Stmt produced
       before the packed-function API is built.
    """
    cfg = BuildConfig.current()
    add_lower_pass = cfg.add_lower_pass if cfg.add_lower_pass else []
    if cfg.dump_pass_ir:
        add_lower_pass = BuildConfig._dump_ir.decorate_custompass(
            add_lower_pass)
    lower_phase0 = [x[1] for x in add_lower_pass if x[0] == 0]
    lower_phase1 = [x[1] for x in add_lower_pass if x[0] == 1]
    lower_phase2 = [x[1] for x in add_lower_pass if x[0] == 2]
    lower_phase3 = [x[1] for x in add_lower_pass if x[0] > 2]

    # Phase 0
    if isinstance(sch, schedule.Schedule):
        stmt = form_body(sch)

    for f in lower_phase0:
        stmt = f(stmt)

    compact = ir_pass.VerifyCompactBuffer(stmt)
    binds, arg_list = get_binds(args, compact, binds)

    # Phase 1
    stmt = ir_pass.RewriteForTensorCore(stmt, sch, binds)
    stmt = ir_pass.StorageFlatten(stmt, binds, 64,
                                  cfg.instrument_bound_checkers)
    stmt = ir_pass.NarrowDataType(stmt, 32)
    stmt = ir_pass.CanonicalSimplify(stmt)
    for f in lower_phase1:
        stmt = f(stmt)

    # Phase 2
    if not simple_mode:
        stmt = ir_pass.LoopPartition(stmt, cfg.partition_const_loop)
    if cfg.disable_vectorize:
        stmt = ir_pass.SkipVectorize(stmt)
    else:
        stmt = ir_pass.VectorizeLoop(stmt)
    stmt = ir_pass.InjectVirtualThread(stmt)
    stmt = ir_pass.InjectDoubleBuffer(stmt, cfg.double_buffer_split_loop)
    stmt = ir_pass.StorageRewrite(stmt)
    stmt = ir_pass.UnrollLoop(stmt, cfg.auto_unroll_max_step,
                              cfg.auto_unroll_max_depth,
                              cfg.auto_unroll_max_extent, cfg.unroll_explicit)
    for f in lower_phase2:
        stmt = f(stmt)

    # Phase 3
    stmt = ir_pass.Simplify(stmt)
    stmt = ir_pass.RemoveNoOp(stmt)
    if not cfg.disable_select_rewriting:
        stmt = ir_pass.RewriteUnsafeSelect(stmt)
    for f in lower_phase3:
        stmt = f(stmt)
    # Instrument BoundCheckers
    if cfg.instrument_bound_checkers:
        stmt = ir_pass.InstrumentBoundCheckers(stmt)
    if simple_mode:
        return stmt

    f = tvm.tir.PrimFunc(arg_list, stmt).with_attr("global_symbol",
                                                   tvm.runtime.String(name))
    if cfg.restricted_func:
        f = f.with_attr("tir.noalias", True)
    mod = tvm.IRModule({name: f})
    return tvm.tir.transform.MakePackedAPI()(mod)
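
A brief usage sketch for lower, assuming the te operator API of the same TVM generation; the workload and names are illustrative:

import tvm
from tvm import te

n = te.var("n")
A = te.placeholder((n,), name="A")
B = te.compute((n,), lambda i: A[i] * 2.0, name="B")
s = te.create_schedule(B.op)

# simple_mode=True returns the Stmt before the packed-function API is built,
# which is handy for inspecting the lowered IR.
print(lower(s, [A, B], simple_mode=True))

# simple_mode=False (the default) returns an IRModule ready for tvm.build.
mod = lower(s, [A, B], name="double")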