import numpy as np
import pytest

import tvm
from tvm import relay
from tvm.relay import transform


def test_fuse_simple():
    """Simple testcase."""

    def before():
        x = relay.var("x", shape=(10, 20))
        y = relay.add(x, relay.const(1, "float32"))
        z = relay.exp(y)
        w = relay.squeeze(z)
        return relay.Function([x], w)

    def expected():
        x = relay.var("p", shape=(10, 20))
        y = relay.add(x, relay.const(1, "float32"))
        z = relay.exp(y)
        w = relay.squeeze(z)
        f1 = relay.Function([x], w)
        f1 = f1.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
        x = relay.var("x", shape=(10, 20))
        y = relay.Call(f1, [x])
        return relay.Function([x], y)

    z = before()
    zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
    zz = run_opt_pass(z, transform.FuseOps())
    after = run_opt_pass(expected(), transform.InferType())
    assert tvm.ir.structural_equal(zz, after)


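# ``run_opt_pass`` is used throughout these tests but not defined in this
# excerpt.  The sketch below follows the helper used in TVM's Relay pass
# tests (an assumption; the upstream utility may differ in detail): wrap the
# expression in an IRModule, apply the pass (or list of passes), and return
# the resulting "main" function, or its body for non-Function expressions.
def run_opt_pass(expr, opt_pass):
    passes = opt_pass if isinstance(opt_pass, list) else [opt_pass]
    mod = tvm.IRModule.from_expr(expr)
    for p in passes:
        assert isinstance(p, tvm.transform.Pass)
        mod = p(mod)
    entry = mod["main"]
    return entry if isinstance(expr, relay.Function) else entry.body

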
def test_tuple_root():
    """Test fusion case where Tuple node is the root in its group"""

    def before(dshape):
        x = relay.var("x", shape=dshape)
        pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
        upsampled = relay.nn.upsampling(pooled, scale_h=2, scale_w=2, layout="NCHW")
        out = relay.Tuple((upsampled, x))
        return relay.Function(relay.analysis.free_vars(out), out)

    def expected(dshape):
        x = relay.var("x", shape=dshape)
        pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
        f0 = relay.Function([x], pooled)
        f0 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        p0 = relay.var("p0", shape=(dshape[0], dshape[1], dshape[2] // 2, dshape[3] // 2))
        upsampled = relay.nn.upsampling(p0, scale_h=2, scale_w=2, layout="NCHW")
        f1 = relay.Function([p0], upsampled)
        f1 = f1.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        x = relay.var("x", shape=dshape)
        y = relay.Call(f0, [x])
        z = relay.Call(f1, [y])
        tup = relay.Tuple((z, x))
        return relay.Function([x], tup)

    dshape = (1, 16, 64, 64)
    z = before(dshape)
    zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=0))
    assert not relay.analysis.free_vars(zz)
    zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
    assert not relay.analysis.free_vars(zz)
    after = run_opt_pass(expected(dshape), transform.InferType())
    assert tvm.ir.structural_equal(zz, after)


def test_fuse_tuple_get_elemwise():
    def before(dim):
        X = relay.var("X", shape=(1, dim))
        W = relay.var("W", shape=(3 * dim, dim))
        matmul = relay.nn.dense(X, W)
        splitted = relay.split(matmul, indices_or_sections=3, axis=1)
        out = relay.sigmoid(splitted[0]) + relay.tanh(splitted[1]) * relay.exp(splitted[2])
        return relay.Function([X, W], out)

    def expected(dim):
        p0 = relay.var("p0", shape=(1, dim))
        p1 = relay.var("p1", shape=(3 * dim, dim))
        matmul = relay.nn.dense(p0, p1)
        f0 = relay.Function([p0, p1], matmul)
        f0 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        p01 = relay.var("p01", shape=(1, 3 * dim))
        splitted = relay.split(p01, indices_or_sections=3, axis=1)
        out = relay.sigmoid(splitted[0]) + relay.tanh(splitted[1]) * relay.exp(splitted[2])
        f1 = relay.Function([p01], out)
        f1 = f1.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        X = relay.var("X", shape=(1, dim))
        W = relay.var("W", shape=(3 * dim, dim))
        y = relay.Call(f0, [X, W])
        z = relay.Call(f1, [y])
        return relay.Function([X, W], z)

    dim = 10
    z = before(dim)
    zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=0))
    assert not relay.analysis.free_vars(zz)
    zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
    assert not relay.analysis.free_vars(zz)
    after = run_opt_pass(expected(dim), transform.InferType())
    # structural_equal replaces the removed relay.analysis.alpha_equal; the
    # expected functions are marked Primitive to match what FuseOps emits.
    assert tvm.ir.structural_equal(zz, after)


def test_tuple_get_root():
    def before(dim):
        X = relay.var("X", shape=(1, 3 * dim))
        W = relay.var("W", shape=(dim, dim))
        splitted = relay.split(X, indices_or_sections=3, axis=1)
        out = relay.nn.dense(splitted[0], W)
        return relay.Function([X, W], out)

    def expected(dim):
        p0 = relay.var("p0", shape=(1, 3 * dim))
        splitted = relay.split(p0, indices_or_sections=3, axis=1)
        out = splitted[0]
        f0 = relay.Function([p0], out)
        f0 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        p01 = relay.var("p01", shape=(1, dim))
        p1 = relay.var("p1", shape=(dim, dim))
        out = relay.nn.dense(p01, p1)
        f1 = relay.Function([p01, p1], out)
        f1 = f1.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        X = relay.var("X", shape=(1, 3 * dim))
        W = relay.var("W", shape=(dim, dim))
        y = relay.Call(f0, [X])
        z = relay.Call(f1, [y, W])
        return relay.Function([X, W], z)

    dim = 10
    z = before(dim)
    zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=0))
    assert not relay.analysis.free_vars(zz)
    zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
    assert not relay.analysis.free_vars(zz)
    after = run_opt_pass(expected(dim), transform.InferType())
    assert tvm.ir.structural_equal(zz, after)


def test_fuse_parallel_injective():
    """Test fusing parallel injective ops to an elemwise op."""

    def before():
        x = relay.var("x", shape=(10, 20))
        y = relay.add(x, relay.const(1, "float32"))
        z = relay.squeeze(y)
        u = relay.transpose(y, axes=[0, 1])
        w = relay.left_shift(z, u)
        return relay.Function([x], w)

    def expected():
        x = relay.var("p", shape=(10, 20))
        y = relay.add(x, relay.const(1, "float32"))
        z = relay.squeeze(y)
        u = relay.transpose(y, axes=[0, 1])
        w = relay.left_shift(z, u)
        f1 = relay.Function([x], w)
        f1 = f1.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
        x = relay.var("x", shape=(10, 20))
        y = relay.Call(f1, [x])
        return relay.Function([x], y)

    z = before()
    zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=0))
    assert not relay.analysis.free_vars(zz)
    zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
    assert not relay.analysis.free_vars(zz)
    after = run_opt_pass(expected(), transform.InferType())
    assert tvm.ir.structural_equal(zz, after)


def test_concatenate():
    """Test fusion case involving concat op and Tuple node"""

    def before(dshape):
        x = relay.var("x", shape=dshape)
        pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
        upsampled = relay.nn.upsampling(pooled, scale_h=2, scale_w=2, layout="NCHW")
        concat = relay.concatenate((upsampled, x), axis=1)
        out = relay.add(concat, relay.const(1, "float32"))
        return relay.Function(relay.analysis.free_vars(out), out)

    def expected(dshape):
        x = relay.var("x", shape=dshape)
        pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
        f0 = relay.Function([x], pooled)
        f0 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        p0 = relay.var("p0", shape=(dshape[0], dshape[1], dshape[2] // 2, dshape[3] // 2))
        p1 = relay.var("p1", shape=dshape)
        upsampled = relay.nn.upsampling(p0, scale_h=2, scale_w=2, layout="NCHW")
        concat = relay.concatenate((upsampled, p1), axis=1)
        out = relay.add(concat, relay.const(1, "float32"))
        f1 = relay.Function([p0, p1], out)
        f1 = f1.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        x = relay.var("x", shape=dshape)
        y = relay.Call(f0, [x])
        z = relay.Call(f1, [y, x])
        return relay.Function([x], z)

    dshape = (1, 16, 64, 64)
    z = before(dshape)
    zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=0))
    assert not relay.analysis.free_vars(zz)
    zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
    assert not relay.analysis.free_vars(zz)
    after = run_opt_pass(expected(dshape), transform.InferType())
    assert tvm.ir.structural_equal(zz, after)


def test_immutable():
    """Verify the fusion pass won't change original module."""

    def before():
        x = relay.var("x", shape=(10, 20))
        y = relay.add(x, relay.const(1, "float32"))
        z = relay.exp(y)
        w = relay.squeeze(z)
        mod = tvm.IRModule()
        mod["main"] = relay.Function([x], w)
        return mod

    def expected():
        x = relay.var("p", shape=(10, 20))
        y = relay.add(x, relay.const(1, "float32"))
        z = relay.exp(y)
        w = relay.squeeze(z)
        f1 = relay.Function([x], w)
        f1 = f1.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
        x = relay.var("x", shape=(10, 20))
        y = relay.Call(f1, [x])
        mod = tvm.IRModule()
        mod["main"] = relay.Function([x], y)
        return mod

    mod = before()
    new_mod = transform.FuseOps(fuse_opt_level=2)(mod)
    assert tvm.ir.structural_equal(mod, before())
    assert tvm.ir.structural_equal(new_mod, expected())


# The test takes ``link_params`` as an argument, so it needs to be
# parametrized for pytest to supply both settings.
@pytest.mark.parametrize("link_params", [False, True])
def test_fuse_gather_nd(link_params):
    """Test fusion case involving concat and gather_nd"""

    def before():
        shape = (tvm.tir.const(10, "int64"), tvm.tir.const(1, "int64"))
        x = relay.var("x", shape=shape)
        concat = relay.concatenate([x, x], axis=-1)
        out = relay.gather_nd(concat, indices=relay.expr.const([[0, 1], [1, 0]], dtype="int64"))
        return relay.Function(relay.analysis.free_vars(out), out)

    def expected(link_params):
        shape1 = (tvm.tir.const(10, "int64"), tvm.tir.const(1, "int64"))
        shape2 = (tvm.tir.const(2, "int64"), tvm.tir.const(2, "int64"))
        x = relay.var("x", shape=shape1)
        p0 = relay.var("p0", shape=shape1)
        p1 = relay.var("p1", shape=shape2, dtype="int64")
        c = relay.const([[0, 1], [1, 0]], dtype="int64")
        concat = relay.concatenate([p0, p0], axis=-1)
        out = relay.gather_nd(concat, indices=c if link_params else p1)
        f0 = relay.Function([p0] if link_params else [p0, p1], out)
        f0 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
        y = relay.Call(f0, [x] if link_params else [x, c])
        return relay.Function([x], y)

    after = run_opt_pass(expected(link_params), transform.InferType())
    with tvm.transform.PassContext(opt_level=2, config={"relay.FuseOps.link_params": link_params}):
        m = run_opt_pass(before(), transform.InferType())
        m = run_opt_pass(m, transform.FuseOps())
        assert tvm.ir.structural_equal(m, after)
    relay.build(m, "llvm")


def test_stop_fusion():
    def before(dshape):
        x = relay.var("x", shape=dshape)
        y = relay.add(x, relay.const(1, "float32"))
        y = relay.annotation.stop_fusion(y)
        z = relay.exp(y)
        return relay.Function([x], z)

    def expected(dshape):
        x = relay.var("p0", shape=dshape)
        y = relay.add(x, relay.const(1, "float32"))
        f1 = relay.Function([x], y)
        f1 = f1.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        x = relay.var("p01", shape=dshape)
        y = relay.exp(x)
        f2 = relay.Function([x], y)
        f2 = f2.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        x = relay.var("x", shape=dshape)
        y = relay.Call(f1, [x])
        z = relay.Call(f2, [y])
        return relay.Function([x], z)

    dshape = (10, 20)
    z = before(dshape)
    zz = run_opt_pass(z, transform.FuseOps())
    after = run_opt_pass(expected(dshape), transform.InferType())
    assert tvm.ir.structural_equal(zz, after)


def test_fuse_myia_regression():
    def before(dshape, dtype):
        x = relay.var('x', shape=dshape, dtype=dtype)
        y = relay.var('y', shape=dshape, dtype=dtype)
        sb = relay.ScopeBuilder()
        with sb.if_scope(relay.op.greater(x, y)):
            sb.ret(relay.Function([], x))
        with sb.else_scope():
            sb.ret(relay.Function([], y))
        return relay.Function([x, y], relay.Call(sb.get(), []))

    def expected(dshape, dtype):
        x = relay.var('x', shape=dshape, dtype=dtype)
        y = relay.var('y', shape=dshape, dtype=dtype)
        sb = relay.ScopeBuilder()
        p1 = relay.var('p1', shape=dshape, dtype=dtype)
        p2 = relay.var('p2', shape=dshape, dtype=dtype)
        fused_gt = relay.Function([p1, p2], relay.op.greater(p1, p2))
        fused_gt = fused_gt.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
        with sb.if_scope(fused_gt(x, y)):
            sb.ret(relay.Function([], x))
        with sb.else_scope():
            sb.ret(relay.Function([], y))
        return relay.Function([x, y], relay.Call(sb.get(), []))

    dshape = ()
    dtype = 'int64'
    f = before(dshape, dtype)
    zz = run_opt_pass(f, transform.FuseOps())
    after = run_opt_pass(expected(dshape, dtype), transform.InferType())
    assert tvm.ir.structural_equal(zz, after)


def test_inception_like():
    def conv(data):
        y = relay.nn.conv2d(data, relay.var("w"), kernel_size=(3, 3), padding=(1, 1), channels=16)
        return relay.nn.relu(data=y)

    def inception_like(data):
        c0 = conv(data)
        c1 = conv(data)
        return relay.concatenate((c0, c1), axis=1)

    def before(dshape):
        x = relay.var("x", shape=dshape)
        in1 = inception_like(x)
        in2 = inception_like(in1)
        return relay.Function(relay.analysis.free_vars(in2), in2)

    dshape = (1, 16, 64, 64)
    x = before(dshape)
    x = run_opt_pass(x, transform.InferType())
    fused = run_opt_pass(x, transform.FuseOps())
    defused = run_opt_pass(fused, transform.DefuseOps())
    assert tvm.ir.structural_equal(x, defused)


def test_fuse_max():
    """Test the constraint of number of nodes in op fusion."""

    def before(n):
        x = relay.var("x", shape=(10, 20))
        y = x
        for i in range(n):
            y = relay.exp(y)
        return relay.Function([x], y)

    def expected(n, max_fused_ops):
        x = relay.var("p", shape=(10, 20))
        y = x
        for i in range(max_fused_ops):
            y = relay.exp(y)
        f1 = relay.Function([x], y)
        f1 = f1.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
        x = relay.var("x", shape=(10, 20))
        z = relay.Call(f1, [x])

        xx = relay.var("pp", shape=(10, 20))
        yy = xx
        # it is assumed that there are two fused functions
        for i in range(n - max_fused_ops):
            yy = relay.exp(yy)
        f2 = relay.Function([xx], yy)
        f2 = f2.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
        zz = relay.Call(f2, [z])
        return relay.Function([x], zz)

    max_fused_ops = 256
    n = 300
    z = before(n)
    zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
    zz = run_opt_pass(z, transform.FuseOps())
    after = run_opt_pass(expected(n, max_fused_ops), transform.InferType())
    assert tvm.ir.structural_equal(zz, after)

    max_fused_ops = 10
    n = 20
    z = before(n)
    after = run_opt_pass(expected(n, max_fused_ops), transform.InferType())

    with tvm.transform.PassContext(config={"relay.FuseOps.max_depth": max_fused_ops}):
        zz = run_opt_pass(z, transform.FuseOps())

    assert tvm.ir.structural_equal(zz, after)


def test_split():
    """Test that the result is well formed."""
    x = relay.var("x", shape=(6, 9))
    y = relay.split(x, 3).astuple()
    a = relay.TupleGetItem(y, 0)
    b = relay.TupleGetItem(y, 1)
    c = relay.TupleGetItem(y, 2)
    # tvm.IRModule replaces the removed relay.module.Module API.
    mod = tvm.IRModule()
    mod["main"] = relay.Function([x], a + relay.RefRead(relay.RefCreate(b)) + c)
    mod = transform.FuseOps()(mod)


def test_extern_ccompiler_default_ops():
    def expected():
        mod = tvm.IRModule()
        x = relay.var("x", shape=(8, 8))
        y = relay.var("y", shape=(8, 8))
        x0 = relay.var("x0", shape=(8, 8))
        y0 = relay.var("y0", shape=(8, 8))
        add = x0 + y0
        # Function that uses the external C compiler.
        func = relay.Function([x0, y0], add)
        func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
        func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1))
        func = func.with_attr("Compiler", tvm.tir.StringImm("ccompiler"))
        func = func.with_attr("ExternalSymbol", tvm.tir.StringImm("ccompiler_0"))
        glb_0 = relay.GlobalVar("ccompiler_0")
        mod[glb_0] = func
        add_call = relay.Call(glb_0, [x, y])
        # Function that uses the default compiler. Ops are fused in this function.
        p0 = relay.var("p0", shape=(8, 8))
        log = relay.log(p0)
        exp = relay.exp(p0)
        concat = relay.concatenate([log, exp], axis=0)
        fused_func = relay.Function([p0], concat)
        fused_func = fused_func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
        fused_call = relay.Call(fused_func, [add_call])
        main = relay.Function([x, y], fused_call)
        mod["main"] = main
        return mod

    x = relay.var("x", shape=(8, 8))
    y = relay.var("y", shape=(8, 8))
    add = x + y
    log = relay.log(add)
    exp = relay.exp(add)
    concat = relay.concatenate([log, exp], axis=0)
    f = relay.Function([x, y], concat)
    mod = tvm.IRModule()
    mod["main"] = f
    # WhiteListAnnotator and check_result come from the partition-graph test
    # utilities; see the sketch of check_result below.
    mod = WhiteListAnnotator(["add", "subtract", "multiply"], "ccompiler")(mod)
    mod = transform.PartitionGraph()(mod)

    fused_mod = transform.FuseOps(2)(mod)
    expected_mod = expected()
    # structural_equal replaces the removed relay.alpha_equal.
    assert tvm.ir.structural_equal(fused_mod, expected_mod)

    x_data = np.random.rand(8, 8).astype("float32")
    y_data = np.random.rand(8, 8).astype("float32")
    np_add = x_data + y_data
    res = np.concatenate([np.log(np_add), np.exp(np_add)])
    check_result(mod, {"x": x_data, "y": y_data}, (16, 8), res)


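# ``WhiteListAnnotator`` (a compiler_begin/compiler_end annotator) and
# ``check_result`` are not reproduced in full here.  Below is a minimal sketch
# of ``check_result``, assuming a modern TVM build with the graph executor
# API; the real helper in the partition-graph tests also covers multiple
# targets and the VM executor.
from tvm.contrib import graph_executor


def check_result(mod, inputs, out_shape, expected, rtol=1e-5):
    # Compile for CPU, run once, and compare against the NumPy reference.
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target="llvm")
    rt_mod = graph_executor.GraphModule(lib["default"](tvm.cpu()))
    rt_mod.set_input(**inputs)
    rt_mod.run()
    out = rt_mod.get_output(0).numpy()
    assert out.shape == tuple(out_shape)
    np.testing.assert_allclose(out, expected, rtol=rtol)

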
def test_defuse_simple():
    """Simple testcase."""

    def before():
        x = relay.var("x", shape=(10, 20))
        y = relay.add(x, relay.const(1, "float32"))
        z = relay.exp(y)
        w = relay.squeeze(z)
        return relay.Function([x], w)

    x = before()
    x = run_opt_pass(x, transform.InferType())
    fused = run_opt_pass(x, transform.FuseOps())
    defused = run_opt_pass(fused, transform.DefuseOps())
    assert tvm.ir.structural_equal(x, defused)


def check_storage_and_device_types():
    # ``annotated`` is assumed to be provided by the enclosing test case.
    func = annotated()
    func = run_opt_pass(func, [transform.RewriteAnnotatedOps(3), transform.FuseOps(2)])
    smap = relay.backend._backend.GraphPlanMemory(func)
    storage_ids = []
    device_types = []
    for _, storage_dev_type in smap.items():
        assert len(storage_dev_type) == 2
        for sid in storage_dev_type[0]:
            storage_ids.append(sid.value)
        for did in storage_dev_type[1]:
            device_types.append(did.value)
    assert len(storage_ids) == 10
    assert len(set(storage_ids)) == 8
    assert len(set(device_types)) == 2
    assert set(device_types) == {1, 2}


def test_fuse_max_diamond():
    def create_diamond(x, branch_len):
        x1 = x
        x2 = x
        for _ in range(branch_len):
            x1 = relay.exp(x1)
            x2 = relay.exp(x2)
        return relay.add(x1, x2)

    def before(branch_len, num_diamond):
        x = relay.var("x", shape=(10, 20))
        out = x
        for _ in range(num_diamond):
            out = create_diamond(out, branch_len)
        return relay.Function([x], out)

    def after(branch_len, num_diamond):
        def create_diamond_func(inp):
            inp_var = relay.var("p", shape=(10, 20))
            d = create_diamond(inp_var, branch_len)
            f = relay.Function([inp_var], d)
            f = f.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
            return relay.Call(f, [inp])

        inp = relay.var("x", shape=(10, 20))
        out = inp
        for _ in range(num_diamond):
            out = create_diamond_func(out)
        return relay.Function([inp], out)

    branch_len = 5
    max_fused_ops = branch_len * 2 + 1  # the number of ops in one diamond
    num_diamond = 3

    with tvm.transform.PassContext(config={"relay.FuseOps.max_depth": max_fused_ops}):
        fused = run_opt_pass(before(branch_len, num_diamond), transform.FuseOps())

    expected = run_opt_pass(after(branch_len, num_diamond), transform.InferType())
    assert tvm.ir.structural_equal(fused, expected)


def match_pass_name(name):
    if name == 'FoldScaleAxis':
        return transform.FoldScaleAxis()
    if name == 'BackwardFoldScaleAxis':
        return transform.BackwardFoldScaleAxis()
    if name == 'ForwardFoldScaleAxis':
        return transform.ForwardFoldScaleAxis()
    if name == 'FuseOps':
        return transform.FuseOps(3)
    if name == 'FoldConstant':
        return transform.FoldConstant()
    if name == 'CombineParallelConv2d':
        return transform.CombineParallelConv2D()
    if name == 'AlterOpLayout':
        return transform.AlterOpLayout()
    if name == 'EliminateCommonSubexpr':
        return transform.EliminateCommonSubexpr()
    if name == 'PartialEvaluate':
        return transform.PartialEvaluate()
    if name == 'CanonicalizeCast':
        return transform.CanonicalizeCast()
    if name == 'CanonicalizeOps':
        return transform.CanonicalizeOps()
    raise Exception('Name {} does not match any pass'.format(name))


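# Hypothetical usage of the helper above: assemble a Sequential pass from a
# list of pass names (the names and ordering here are illustrative only).
fuse_then_fold = tvm.transform.Sequential(
    [match_pass_name(n) for n in ['FoldConstant', 'FuseOps']]
)

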
def test_conv2d_fuse():
    """Test fusion case of conv2d"""

    def before(dshape):
        x = relay.var("x", shape=dshape)
        x = relay.add(x, relay.const(1, "float32"))
        y = relay.nn.conv2d(x, relay.var("w1"), kernel_size=(3, 3), padding=(1, 1), channels=16)
        # this is the next dominator.
        y1 = relay.add(relay.const(1, "float32"), y)
        y = relay.add(y, y1)
        # second path
        z2 = relay.nn.conv2d(y, relay.var("w2"), kernel_size=(1, 1), padding=(0, 0), channels=16)
        z3 = relay.nn.conv2d(y, relay.var("w3"), kernel_size=(3, 3), padding=(1, 1), channels=16)
        # the add can only be fused into one of the parallel conv groups
        # (here z2's, the 1x1 conv; see segment 3 below)
        z = relay.add(z2, z3)
        return relay.Function(relay.analysis.free_vars(z), z)

    def expected(dshape):
        # segment 0
        x = relay.var("p0", shape=dshape)
        y = relay.add(x, relay.const(1, "float32"))
        f0 = relay.Function([x], y)
        f0 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        # segment 1
        x = relay.var("p0", shape=dshape)
        w = relay.var("p1")
        y = relay.nn.conv2d(x, w, kernel_size=(3, 3), padding=(1, 1), channels=16)
        y1 = relay.add(relay.const(1, "float32"), y)
        y = relay.add(y, y1)
        f1 = relay.Function([x, w], y)
        f1 = f1.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        # segment 2
        x = relay.var("p0", shape=dshape)
        w = relay.var("p1")
        z2 = relay.nn.conv2d(x, w, kernel_size=(3, 3), padding=(1, 1), channels=16)
        f2 = relay.Function([x, w], z2)
        f2 = f2.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        # segment 3
        x = relay.var("p0", shape=dshape)
        w = relay.var("p1")
        offset = relay.var("p2", shape=dshape)
        z3 = relay.nn.conv2d(x, w, kernel_size=(1, 1), padding=(0, 0), channels=16)
        z3 = relay.add(z3, offset)
        f3 = relay.Function([x, w, offset], z3)
        f3 = f3.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        # compose
        x = relay.var("x", shape=dshape)
        y = relay.Call(f0, [x])
        y = relay.Call(f1, [y, relay.var("w1")])
        z2 = relay.Call(f2, [y, relay.var("w3")])
        z3 = relay.Call(f3, [y, relay.var("w2"), z2])
        z = z3
        return relay.Function(relay.analysis.free_vars(z), z)

    dshape = (1, 16, 64, 64)
    z = before(dshape)
    zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
    after = run_opt_pass(expected(dshape), transform.InferType())
    assert tvm.ir.structural_equal(zz, after)


"""Workaround for type checker and mutually recursive functions.""" for gv in funcs: func = funcs[gv] body = _placeholder_body(func.ret_type) mod[gv] = relay.Function(func.params, body, func.ret_type) for gv in funcs: mod[gv] = funcs[gv] pass_set = transform.Sequential( passes=[ transform.SimplifyInference(), transform.CanonicalizeOps(), transform.CanonicalizeCast(), transform.FuseOps(3), # transform.CombineParallelConv2d(), transform.AlterOpLayout(), # transform.RewriteAnnotatedOps(???), ], opt_level=0) def optimize(mod): """Optimize all the functions in a module. Modules are the only mutable piece of Relay. We write an optimization pass over the module which destructively updates each function while optimizing. """ return pass_set(mod)