def check(shape):
    """Run the CanonicalizeCast pipeline on ``before`` and assert the
    result is alpha-equal to ``expected``.

    Parameters
    ----------
    shape : tuple
        Shape of the int8 ``data`` input variable.
    """
    data = relay.var("data", shape=shape, dtype="int8")
    conv_weight = relay.var("weight")
    bias1 = relay.var("bias1", shape=(16, 1, 1), dtype="int32")
    bias2 = relay.var("bias2", shape=(16, 1, 1), dtype="int32")

    # Build the input program and push it through the pass pipeline.
    mod = tvm.IRModule.from_expr(before(data, conv_weight, bias1, bias2))
    pipeline = _transform.Sequential(
        [_transform.InferType(), _transform.CanonicalizeCast(), _transform.InferType()]
    )
    with _transform.PassContext(opt_level=3):
        mod = pipeline(mod)
    actual = mod["main"]

    # Type-infer the reference program through the same module so both
    # sides are fully annotated before the structural comparison.
    ref_gv = relay.GlobalVar("expected")
    mod[ref_gv] = expected(data, conv_weight, bias1, bias2)
    mod = _transform.InferType()(mod)
    reference = mod["expected"]

    assert relay.analysis.alpha_equal(actual, reference)
# Example #2
 def check(shape):
     data = relay.var("data", shape=shape, dtype="int8")
     conv_weight = relay.var("weight")
     bias1 = relay.var("bias1", shape=(16, 1, 1), dtype="int32")
     bias2 = relay.var("bias2", shape=(16, 1, 1), dtype="int32")
     y = before(data, conv_weight, bias1, bias2)
     mod = _module.Module.from_expr(y)
     seq = _transform.Sequential([
         _transform.InferType(),
         _transform.CanonicalizeCast(),
         _transform.InferType()
     ])
     with _transform.PassContext(opt_level=3):
         mod = seq(mod)
     y = mod[mod.entry_func.name_hint]
     y_expected = expected(data, conv_weight, bias1, bias2)
     y_expected = relay.ir_pass.infer_type(y_expected)
     assert relay.ir_pass.alpha_equal(y, y_expected)
# Example #3
 def match_pass_name(name):
     if name == 'FoldScaleAxis':
         return transform.FoldScaleAxis()
     if name == 'BackwardFoldScaleAxis':
         return transform.BackwardFoldScaleAxis()
     if name == 'ForwardFoldScaleAxis':
         return transform.ForwardFoldScaleAxis()
     if name == 'FuseOps':
         return transform.FuseOps(3)
     if name == 'FoldConstant':
         return transform.FoldConstant()
     if name == 'CombineParallelConv2d':
         return transform.CombineParallelConv2D()
     if name == 'AlterOpLayout':
         return transform.AlterOpLayout()
     if name == 'EliminateCommonSubexpr':
         return transform.EliminateCommonSubexpr()
     if name == 'PartialEvaluate':
         return transform.PartialEvaluate()
     if name == 'CanonicalizeCast':
         return transform.CanonicalizeCast()
     if name == 'CanonicalizeOps':
         return transform.CanonicalizeOps()
     raise Exception('Name {} does not match any pass'.format(name))
# Example #4
def add_functions(mod, funcs):
    """Workaround for type checker and mutually recursive functions."""
    # First install placeholder bodies so every global name exists in the
    # module, then overwrite each placeholder with its real definition.
    for gv, func in funcs.items():
        stub_body = _placeholder_body(func.ret_type)
        mod[gv] = relay.Function(func.params, stub_body, func.ret_type)

    for gv, func in funcs.items():
        mod[gv] = func

    for gv in funcs:
        mod[gv] = funcs[gv]


# Fixed pipeline of Relay passes applied to every function in a module.
# The ordering is deliberate: inference simplification and op/cast
# canonicalization run before FuseOps(3), which runs before layout
# alteration.  NOTE(review): opt_level=0 here presumably lets the listed
# passes run regardless of the surrounding context's optimization level —
# confirm against the Sequential pass documentation.
pass_set = transform.Sequential(
    passes=[
        transform.SimplifyInference(),
        transform.CanonicalizeOps(),
        transform.CanonicalizeCast(),
        transform.FuseOps(3),
        # transform.CombineParallelConv2d(),
        transform.AlterOpLayout(),
        # transform.RewriteAnnotatedOps(???),
    ],
    opt_level=0)


def optimize(mod):
    """Optimize all the functions in a module.

    Modules are the only mutable piece of Relay.  We write an
    optimization pass over the module which destructively updates each
    function while optimizing.
    """