def check(shape):
    """Assert that CanonicalizeCast turns ``before(...)`` into ``expected(...)``.

    Both graphs are built from the same four Relay vars:
    - an int8 ``data`` input of the given ``shape``;
    - a ``weight`` var with no type annotation;
    - two int32 ``(16, 1, 1)`` biases.

    The ``before`` graph is wrapped in a module and run through
    InferType -> CanonicalizeCast -> InferType under ``opt_level=3``. The
    ``expected`` graph is then added to that same module under the global
    name ``"expected"`` and type-inferred. Finally the two functions are
    compared with ``alpha_equal``.
    """
    data = relay.var("data", shape=shape, dtype="int8")
    conv_weight = relay.var("weight")
    bias1, bias2 = (relay.var(bias_name, shape=(16, 1, 1), dtype="int32")
                    for bias_name in ("bias1", "bias2"))
    graph_args = (data, conv_weight, bias1, bias2)

    module = tvm.IRModule.from_expr(before(*graph_args))
    pipeline = _transform.Sequential([
        _transform.InferType(),
        _transform.CanonicalizeCast(),
        _transform.InferType(),
    ])
    with _transform.PassContext(opt_level=3):
        module = pipeline(module)
    actual = module["main"]

    # Put the reference into the same module so InferType gives it types
    # before the structural comparison.
    module[relay.GlobalVar("expected")] = expected(*graph_args)
    module = _transform.InferType()(module)
    reference = module["expected"]

    assert relay.analysis.alpha_equal(actual, reference)
def check(shape):
    """Assert that CanonicalizeCast turns ``before(...)`` into ``expected(...)``.

    This variant targets the older module API.

    Inputs: an int8 ``data`` var of the given ``shape``, an unannotated
    ``weight`` var, and two int32 ``(16, 1, 1)`` bias vars.

    The ``before`` graph is wrapped in a ``_module.Module`` and run through
    InferType -> CanonicalizeCast -> InferType at ``opt_level=3``. Its entry
    function is then compared, via ``relay.ir_pass.alpha_equal``, against
    the type-inferred ``expected`` graph.
    """
    def _int32_bias(var_name):
        # Both biases share the same shape and dtype; only the name differs.
        return relay.var(var_name, shape=(16, 1, 1), dtype="int32")

    data = relay.var("data", shape=shape, dtype="int8")
    conv_weight = relay.var("weight")
    bias1 = _int32_bias("bias1")
    bias2 = _int32_bias("bias2")
    inputs = [data, conv_weight, bias1, bias2]

    module = _module.Module.from_expr(before(*inputs))
    pipeline = _transform.Sequential([
        _transform.InferType(),
        _transform.CanonicalizeCast(),
        _transform.InferType(),
    ])
    with _transform.PassContext(opt_level=3):
        module = pipeline(module)
    entry_name = module.entry_func.name_hint
    actual = module[entry_name]

    reference = relay.ir_pass.infer_type(expected(*inputs))
    assert relay.ir_pass.alpha_equal(actual, reference)
def match_pass_name(name):
    """Construct and return the Relay pass registered under ``name``.

    Parameters
    ----------
    name : str
        One of: ``'FoldScaleAxis'``, ``'BackwardFoldScaleAxis'``,
        ``'ForwardFoldScaleAxis'``, ``'FuseOps'``, ``'FoldConstant'``,
        ``'CombineParallelConv2d'``, ``'AlterOpLayout'``,
        ``'EliminateCommonSubexpr'``, ``'PartialEvaluate'``,
        ``'CanonicalizeCast'``, ``'CanonicalizeOps'``.

    Returns
    -------
    A newly constructed pass object from the ``transform`` module. Each call
    builds a fresh pass.

    Raises
    ------
    ValueError
        If ``name`` is not one of the names above. ``ValueError`` is a
        subclass of ``Exception``, so callers catching ``Exception`` are
        unaffected.
    """
    # Map names to factories rather than instances: only the requested pass
    # is built, and ``transform`` is touched only when a name matches.
    factories = {
        'FoldScaleAxis': lambda: transform.FoldScaleAxis(),
        'BackwardFoldScaleAxis': lambda: transform.BackwardFoldScaleAxis(),
        'ForwardFoldScaleAxis': lambda: transform.ForwardFoldScaleAxis(),
        'FuseOps': lambda: transform.FuseOps(3),
        'FoldConstant': lambda: transform.FoldConstant(),
        # Key keeps its historical lowercase 'd' for existing callers even
        # though the constructor is spelled CombineParallelConv2D.
        'CombineParallelConv2d': lambda: transform.CombineParallelConv2D(),
        'AlterOpLayout': lambda: transform.AlterOpLayout(),
        'EliminateCommonSubexpr': lambda: transform.EliminateCommonSubexpr(),
        'PartialEvaluate': lambda: transform.PartialEvaluate(),
        'CanonicalizeCast': lambda: transform.CanonicalizeCast(),
        'CanonicalizeOps': lambda: transform.CanonicalizeOps(),
    }
    try:
        factory = factories[name]
    except (KeyError, TypeError):
        # TypeError covers unhashable names, which the original equality
        # chain also rejected as "does not match".
        raise ValueError(
            'Name {} does not match any pass'.format(name)) from None
    # Called outside the try so errors raised by a pass constructor
    # propagate unchanged.
    return factory()
def add_functions(mod, funcs):
    """Install ``funcs`` into ``mod`` in two phases.

    Workaround for type checker and mutually recursive functions.

    Phase one binds every global var in ``funcs`` to a stand-in
    ``relay.Function``. Each stand-in has the real function's parameters and
    return type, but its body comes from ``_placeholder_body``. Phase two
    then overwrites each stand-in with the real definition.

    NOTE(review): presumably this ensures that, when each real definition
    is stored, every function it may reference already has a typed entry
    in ``mod``. Confirm against the module's type-checking behavior.
    """
    # Phase one: typed stand-ins for every function.
    for global_var in funcs:
        real_fn = funcs[global_var]
        stand_in_body = _placeholder_body(real_fn.ret_type)
        mod[global_var] = relay.Function(
            real_fn.params, stand_in_body, real_fn.ret_type)
    # Phase two: replace each stand-in with the actual definition.
    for global_var in funcs:
        mod[global_var] = funcs[global_var]


# Module-level pass pipeline, constructed once at import time.
# Not currently included: CombineParallelConv2D, and RewriteAnnotatedOps
# (its argument was never decided). TODO(review): confirm whether either
# should be enabled.
pass_set = transform.Sequential(
    passes=[
        transform.SimplifyInference(),
        transform.CanonicalizeOps(),
        transform.CanonicalizeCast(),
        transform.FuseOps(3),
        transform.AlterOpLayout(),
    ],
    opt_level=0,
)


def optimize(mod):
    """Optimize all the functions in a module.

    Modules are the only mutable piece of Relay. We write an
    optimization pass over the module which destructively updates each
    function while optimizing.
    """