def check(x_shape, repeat):
    """Assert CombineParallelConv2D turns ``before`` into ``expected``.

    An input variable ``x`` of shape ``x_shape`` and a 1x1 weight ``w`` are
    created; ``w`` has ``x_shape[1] // 2`` output channels and ``x_shape[1]``
    input channels. ``before(x, w, repeat)`` is run through
    CombineParallelConv2D (``min_num_branches=2``) and compared, with free
    variables allowed to map onto each other, against the type-inferred
    ``expected(x, w, out_c, repeat)`` graph. ``before``, ``expected`` and
    ``run_opt_pass`` are helpers defined elsewhere in the test module.
    """
    data = relay.var("x", shape=x_shape)
    channels_in = x_shape[1]
    channels_out = channels_in // 2
    weight = relay.var("w", shape=(channels_out, channels_in, 1, 1))

    original_graph = before(data, weight, repeat)
    combine_pass = transform.CombineParallelConv2D(min_num_branches=2)
    combined_graph = run_opt_pass(original_graph, combine_pass)

    reference_graph = run_opt_pass(
        expected(data, weight, channels_out, repeat),
        transform.InferType(),
    )
    assert tvm.ir.structural_equal(
        combined_graph, reference_graph, map_free_vars=True
    )
Exemplo n.º 2
0
 def check(x_shape, repeat):
     """Assert CombineParallelConv2D turns ``before`` into ``expected``.

     Args:
         x_shape: Shape of the input variable ``x``; ``x_shape[1]`` is used
             as the input channel count of the 1x1 weight ``w``.
         repeat: Forwarded unchanged to the ``before`` and ``expected``
             graph builders (defined elsewhere in the test module).
     """
     x = relay.var("x", shape=x_shape)
     in_c = x_shape[1]
     out_c = in_c // 2
     w = relay.var("w", shape=(out_c, in_c, 1, 1))
     y_before = before(x, w, repeat)
     y = run_opt_pass(y_before,
                      transform.CombineParallelConv2D(min_num_branches=2))
     y_expected = expected(x, w, out_c, repeat)
     y_expected = run_opt_pass(y_expected, transform.InferType())
     # relay.analysis.alpha_equal is superseded by tvm.ir.structural_equal;
     # map_free_vars=True allows free variables of the two expressions to be
     # mapped onto each other during the comparison.
     assert tvm.ir.structural_equal(y, y_expected, map_free_vars=True)
 def check(x_shape, channels1, channels2):
     """Assert CombineParallelConv2D on two 1x1 branches with scalar scales.

     Two 1x1 weights (``channels1`` / ``channels2`` output channels, both
     reading ``x_shape[1]`` input channels) and two shape-``(1,)`` scale
     variables are passed to ``before``; the pass result (with
     ``min_num_branches=2``) must be structurally equal, free variables
     mappable, to the type-inferred ``expected`` graph.
     """
     data = relay.var("x", shape=x_shape)
     channels_in = x_shape[1]
     weight_a = relay.var("w1", shape=(channels1, channels_in, 1, 1))
     weight_b = relay.var("w2", shape=(channels2, channels_in, 1, 1))
     scale_a = relay.var("scale1", shape=(1,))
     scale_b = relay.var("scale2", shape=(1,))

     original_graph = before(data, weight_a, weight_b, scale_a, scale_b)
     combine_pass = transform.CombineParallelConv2D(min_num_branches=2)
     combined_graph = run_opt_pass(original_graph, combine_pass)

     reference_graph = run_opt_pass(
         expected(data, weight_a, weight_b, scale_a, scale_b,
                  channels1, channels2),
         transform.InferType(),
     )
     assert tvm.ir.structural_equal(
         combined_graph, reference_graph, map_free_vars=True
     )
Exemplo n.º 4
0
    def check(x_shape, channels1, channels2, channels3, channels4):
        """Assert CombineParallelConv2D on four branches of mixed kernel size.

        ``w1``, ``w2`` and ``w4`` are 1x1 kernels and ``w3`` is a 3x3 kernel;
        all read ``x_shape[1]`` input channels. ``before(...)`` is run through
        the pass with ``min_num_branches=2`` and must be structurally equal to
        the type-inferred ``expected(...)`` graph (both builders are defined
        elsewhere in the test module).
        """
        x = relay.var("x", shape=x_shape)
        in_c = x_shape[1]
        w1 = relay.var("w1", shape=(channels1, in_c, 1, 1))
        w2 = relay.var("w2", shape=(channels2, in_c, 1, 1))
        w3 = relay.var("w3", shape=(channels3, in_c, 3, 3))
        w4 = relay.var("w4", shape=(channels4, in_c, 1, 1))

        y_before = before(x, w1, w2, w3, w4)
        y = run_opt_pass(y_before,
                         transform.CombineParallelConv2D(min_num_branches=2))
        y_expected = expected(x, w1, w2, w3, w4,
                              channels1, channels2, channels3, channels4)
        y_expected = run_opt_pass(y_expected, transform.InferType())
        # relay.analysis.alpha_equal is superseded by tvm.ir.structural_equal;
        # map_free_vars=True allows free variables of the two expressions to
        # be mapped onto each other during the comparison.
        assert tvm.ir.structural_equal(y, y_expected, map_free_vars=True)
Exemplo n.º 5
0
 def check(x_shape, channels1, channels2):
     """Assert CombineParallelConv2D on two 1x1 branches with scales and bias.

     ``scale1`` / ``scale2`` have shapes ``(channels1, 1, 1)`` and
     ``(channels2, 1, 1)``; ``bias`` has shape ``(channels2, 1, 1)``. The
     pass result on ``before(...)`` (``min_num_branches=2``) must be
     structurally equal to the type-inferred ``expected(...)`` graph (both
     builders are defined elsewhere in the test module).
     """
     x = relay.var("x", shape=x_shape)
     in_c = x_shape[1]
     w1 = relay.var("w1", shape=(channels1, in_c, 1, 1))
     w2 = relay.var("w2", shape=(channels2, in_c, 1, 1))
     scale1 = relay.var("scale1", shape=(channels1, 1, 1))
     scale2 = relay.var("scale2", shape=(channels2, 1, 1))
     bias = relay.var("bias", shape=(channels2, 1, 1))
     y_before = before(x, w1, w2, scale1, scale2, bias)
     y = run_opt_pass(y_before,
                      transform.CombineParallelConv2D(min_num_branches=2))
     y_expected = expected(x, w1, w2, scale1, scale2, bias, channels1, channels2)
     y_expected = run_opt_pass(y_expected, transform.InferType())
     # relay.analysis.alpha_equal is superseded by tvm.ir.structural_equal;
     # map_free_vars=True allows free variables of the two expressions to be
     # mapped onto each other during the comparison.
     assert tvm.ir.structural_equal(y, y_expected, map_free_vars=True)
Exemplo n.º 6
0
 def match_pass_name(name):
     """Return a newly constructed Relay pass for the given pass name.

     Each supported name maps to a zero-argument factory, so only the
     requested pass object is built (attributes of ``transform`` that are not
     asked for are never touched).

     Args:
         name: Exact, case-sensitive pass name. Note the key for the
             parallel-conv combiner is spelled ``'CombineParallelConv2d'``
             (lowercase ``d``) even though the pass is
             ``transform.CombineParallelConv2D``.

     Returns:
         The pass object returned by the matching ``transform`` factory.

     Raises:
         ValueError: If ``name`` is not one of the supported names (including
             unhashable values). ``ValueError`` subclasses ``Exception``, so
             callers that caught the previous bare ``Exception`` still work.
     """
     factories = {
         'FoldScaleAxis': lambda: transform.FoldScaleAxis(),
         'BackwardFoldScaleAxis': lambda: transform.BackwardFoldScaleAxis(),
         'ForwardFoldScaleAxis': lambda: transform.ForwardFoldScaleAxis(),
         # FuseOps receives the positional argument 3 — presumably the fuse
         # optimisation level; confirm against the installed TVM signature.
         'FuseOps': lambda: transform.FuseOps(3),
         'FoldConstant': lambda: transform.FoldConstant(),
         'CombineParallelConv2d': lambda: transform.CombineParallelConv2D(),
         'AlterOpLayout': lambda: transform.AlterOpLayout(),
         'EliminateCommonSubexpr': lambda: transform.EliminateCommonSubexpr(),
         'PartialEvaluate': lambda: transform.PartialEvaluate(),
         'CanonicalizeCast': lambda: transform.CanonicalizeCast(),
         'CanonicalizeOps': lambda: transform.CanonicalizeOps(),
     }
     try:
         factory = factories[name]
     except (KeyError, TypeError):
         # TypeError covers unhashable names, which the dict lookup rejects.
         raise ValueError(
             'Name {} does not match any pass'.format(name)) from None
     return factory()
Exemplo n.º 7
0
def run_combine_parallel(expr, min_num_branches=3):
    """Run CombineParallelConv2D over ``expr`` and return the new ``main``.

    ``expr`` is wrapped in a fresh ``tvm.IRModule``; the pass is built with
    ``min_num_branches`` as its only argument, applied to that module, and
    the module's ``"main"`` function is returned.
    """
    module = tvm.IRModule.from_expr(expr)
    combine_pass = transform.CombineParallelConv2D(min_num_branches)
    rewritten = combine_pass(module)
    return rewritten["main"]
def run_combine_parallel(expr, min_num_branches=3):
    """Run CombineParallelConv2D over ``expr`` and return the new ``main``.

    Args:
        expr: Relay expression to wrap in a fresh module.
        min_num_branches: Forwarded as the only argument of
            ``transform.CombineParallelConv2D``.

    Returns:
        The ``"main"`` function of the transformed module.
    """
    # relay.Module / mod.entry_func are the legacy API; tvm.IRModule and
    # indexing by "main" are their replacements (as used by the other
    # run_combine_parallel definition).
    mod = tvm.IRModule.from_expr(expr)
    mod = transform.CombineParallelConv2D(min_num_branches)(mod)
    return mod["main"]