示例#1
0
def lower_tvm_stmt(stmt, args, binds=None, name='main'):
    """Wrap an already-built TIR statement in a PrimFunc and lower it.

    Parameters
    ----------
    stmt
        The statement used as the function body (presumably a
        ``tvm.tir.Stmt`` -- confirm against callers).

    args : list
        The arguments of the generated function; forwarded to
        ``build_module.get_binds`` and to ``build_module.lower``.

    binds : dict, optional
        Explicit tensor-to-buffer bindings forwarded to
        ``build_module.get_binds``.

    name : str
        Requested function name. It is passed through ``slugify`` and the
        result is used as the ``global_symbol`` attribute, the module key,
        and the name given to ``build_module.lower``.

    Returns
    -------
    The module produced by ``build_module.lower``.
    """
    symbol = slugify(name)

    # Whether the statement's buffers are compact decides how get_binds
    # declares the argument buffers.
    is_compact = schedule.VerifyCompactBuffer(stmt)
    buffer_binds, buffer_args = build_module.get_binds(
        args, compact=is_compact, binds=binds)

    prim_func = schedule.SchedulePostProcToPrimFunc(
        buffer_args, stmt, buffer_binds)
    prim_func = prim_func.with_attr('global_symbol', symbol)

    # The no-alias attribute is on unless the active pass context
    # explicitly disables it.
    noalias_enabled = tvm.ir.transform.PassContext.current().config.get(
        'tir.noalias', True)
    if noalias_enabled:
        prim_func = prim_func.with_attr('tir.noalias', True)

    return build_module.lower(
        tvm.IRModule({symbol: prim_func}), args, name=symbol)
示例#2
0
def form_irmodule(sch, args, name, binds=None):
    """According to the given schedule, form an IRModule holding one PrimFunc.

    The schedule is normalized, its bounds are inferred, and the scheduled
    operations are turned into a TIR statement. That statement is wrapped in
    a PrimFunc whose parameters come from ``args`` (and ``binds``), and the
    PrimFunc is placed in a new IRModule under ``name``.

    Parameters
    ----------
    sch : tvm.te.schedule.Schedule
        The given scheduler to form the raw body

    args : list of Buffer or Tensor or Var
        The argument lists to the function.

    name : str
        The name of result function. Used both as the PrimFunc's
        ``global_symbol`` attribute and as its key in the returned module.

    binds : dict of :any:`Tensor` to :any:`Buffer`, optional
        The binds information, forwarded to ``get_binds``. Defaults to
        ``None`` (no explicit binds).

    Returns
    -------
    mod : tvm.IRModule
        A module containing the single PrimFunc formed from the schedule.
    """
    # Capture the active pass configuration; it decides tir.noalias below.
    pass_ctx = PassContext.current()

    # Normalize the schedule first, then infer bounds and build the body.
    sch = sch.normalize()
    bounds = schedule.InferBound(sch)
    stmt = schedule.ScheduleOps(sch, bounds)

    # Buffer compactness affects how get_binds declares argument buffers.
    compact = schedule.VerifyCompactBuffer(stmt)
    binds, arg_list = get_binds(args, compact, binds)

    # Post-process the statement (tensor-core rewrite pass) before
    # converting it into a PrimFunc over the bound argument list.
    stmt = schedule.SchedulePostProcRewriteForTensorCore(stmt, sch, binds)
    func = schedule.SchedulePostProcToPrimFunc(arg_list, stmt, binds)

    func = func.with_attr("global_symbol", name)

    # tir.noalias is applied unless the pass context explicitly disables it.
    if pass_ctx.config.get("tir.noalias", True):
        func = func.with_attr("tir.noalias", True)
    return tvm.IRModule({name: func})