def test_ethosu_pooling_type_inference(
    ifm_shape,
    ifm_layout,
    ofm_shape,
    ofm_layout,
):
    dtype = "int8"
    ifm = relay.var("ifm", shape=ifm_shape, dtype=dtype)
    pooling_type = "AVG"
    pool_shape = (3, 2)
    ofm_channels = 55
    strides = (1, 2)
    padding = (0, 1, 2, 3)
    pooling = make_ethosu_pooling(
        ifm,
        pooling_type,
        pool_shape,
        ofm_channels,
        strides,
        padding,
        ifm_layout=ifm_layout,
        ofm_layout=ofm_layout,
    )
    func = relay.Function([ifm], pooling)
    func = run_opt_pass(func, relay.transform.InferType())
    assert tuple(func.body.checked_type.shape) == ofm_shape
    assert func.body.checked_type.dtype == dtype

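# NOTE: run_opt_pass is used throughout these snippets but is not defined here; the
# tests import it from tvm.relay.testing. A minimal sketch of such a helper follows
# (an assumption about its behaviour: wrap the expression in an IRModule, optionally
# import the prelude, run the pass, then unwrap "main" again).
import tvm
from tvm import relay
from tvm.relay.prelude import Prelude


def run_opt_pass(expr, opt_pass, import_prelude=False):
    """Apply a Relay pass (or Sequential of passes) to an expression."""
    mod = tvm.IRModule.from_expr(expr)
    if import_prelude:
        Prelude(mod)
    mod = relay.transform.InferType()(mod)
    mod = opt_pass(mod)
    entry = mod["main"]
    # Hand back a Function when a Function was passed in, otherwise the body expression.
    return entry if isinstance(expr, relay.Function) else entry.body
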
def test_pooling_single(
    ifm_shape,
    ofm_channels,
    ifm_layout,
    ofm_layout,
    pooling_type,
    activation,
    rounding_mode,
    upscale,
):
    pool_shape = (3, 2)
    strides = (1, 2)

    # When strides are not (1, 1) it is possible to create invalid padding
    # configurations. A pooling operation can still be constructed with such
    # padding, but the compiler adjusts the padding, leading to a mismatch
    # between the expected and actual result. Therefore, hardcoded padding
    # values are used for each case.
    padding = (1, 1, 1, 0) if upscale == "NONE" else (0, 0, 0, 0)

    ifm = relay.var("ifm", shape=ifm_shape, dtype="int8")
    pooling = make_ethosu_pooling(
        ifm,
        pooling_type,
        pool_shape,
        ofm_channels,
        strides,
        padding,
        activation,
        ifm_layout,
        ofm_layout,
        rounding_mode,
        upscale,
    )
    func = relay.Function(relay.analysis.free_vars(pooling), pooling)
    func = run_opt_pass(func, relay.transform.InferType())
    mod, _ = _lower_to_tir(func)
    data = []

    def _visit(stmt):
        if isinstance(stmt, tvm.tir.Call):
            data.append(get_pooling_args(stmt))

    tvm.tir.stmt_functor.post_order_visit(mod["main"].body, _visit)

    serial_pooling = _create_serial_pooling(
        ifm_shape,
        ofm_channels,
        ifm_layout,
        ofm_layout,
        pool_shape,
        pooling_type,
        strides,
        padding,
        activation,
        rounding_mode,
        upscale,
    )
    assert data[0] == ["ethosu_pooling"] + list(serial_pooling)

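# get_pooling_args, _lower_to_tir and _create_serial_pooling above come from the
# Ethos-U test infrastructure and are not shown in these snippets. A rough sketch of
# what get_pooling_args could look like (an assumption: it flattens the arguments of
# the ethosu_pooling TIR call into plain Python values for comparison):
def get_pooling_args(call):
    args = []
    for arg in call.args:
        if isinstance(arg, (tvm.tir.StringImm, tvm.tir.IntImm, tvm.tir.FloatImm)):
            # Unwrap immediates so they compare directly against Python values.
            args.append(arg.value)
        else:
            args.append(arg)
    return args
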
def test_fuse_softmax():
    """Test if softmax can be fused with following ops."""
    channel_size = 16

    def before():
        x = relay.var("x", shape=(16, channel_size))
        softmax = relay.nn.softmax(x)
        out = relay.cast(softmax, "float16")
        return relay.Function([x], out)

    def expected():
        p0 = relay.var("p0", shape=(16, channel_size))
        softmax = relay.nn.softmax(p0)
        out = relay.cast(softmax, "float16")

        x = relay.var("x", shape=(16, channel_size))
        f0 = relay.Function([p0], out)
        f0 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
        y = relay.Call(f0, [x])
        return relay.Function([x], y)

    orig = before()
    m = fuse2(tvm.IRModule.from_expr(orig))
    after = run_opt_pass(expected(), transform.InferType())
    assert tvm.ir.structural_equal(m["main"], after)

    inp = np.random.randn(16, channel_size).astype("float32")
    ref = tvm.topi.testing.softmax_python(inp).astype("float16")

    for tgt, dev in tvm.testing.enabled_targets():
        ex = relay.create_executor("graph", mod=m, device=dev, target=tgt)
        result = ex.evaluate()(inp).numpy()
        tvm.testing.assert_allclose(result, ref, rtol=1e-4, atol=1e-4)

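# fuse0 and fuse2 are small helpers used by the fusion tests in this file; their
# definitions are not included in these snippets. A plausible sketch (assumption:
# they simply run FuseOps at opt levels 0 and 2 on a type-checked module):
def fuse0(mod):
    # fuse_opt_level=0 disables fusion, which is useful as a sanity baseline.
    mod = relay.transform.InferType()(mod)
    return relay.transform.FuseOps(fuse_opt_level=0)(mod)


def fuse2(mod):
    # fuse_opt_level=2 is the standard fusion level exercised by most tests.
    mod = relay.transform.InferType()(mod)
    return relay.transform.FuseOps(fuse_opt_level=2)(mod)
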
def test_fuse_strided_slice():
    """Test fusion case involving concat and strided_slice"""

    def before():
        shape = (tvm.tir.const(10, "int64"), tvm.tir.const(1, "int64"))
        x = relay.var("x", shape=shape)
        concat = relay.concatenate([x, x], axis=-1)
        out = relay.strided_slice(concat, begin=[np.int64(0)], end=[np.int64(3)])
        return relay.Function(relay.analysis.free_vars(out), out)

    def expected():
        shape = (tvm.tir.const(10, "int64"), tvm.tir.const(1, "int64"))
        x = relay.var("x", shape=shape)
        p0 = relay.var("p0", shape=shape)
        concat = relay.concatenate([p0, p0], axis=-1)
        out = relay.strided_slice(concat, begin=[np.int64(0)], end=[np.int64(3)])
        f0 = relay.Function([p0], out)
        f0 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
        y = relay.Call(f0, [x])
        return relay.Function([x], y)

    orig = before()
    fuse0(tvm.IRModule.from_expr(orig))
    m = fuse2(tvm.IRModule.from_expr(orig))
    relay.build(m, "llvm")
    after = run_opt_pass(expected(), transform.InferType())
    assert tvm.ir.structural_equal(m["main"], after)

def test_ethosu_binary_elementwise_type_inference(
    ifm_shape,
    ifm_layout,
    ofm_shape,
    ofm_layout,
):
    dtype = "int8"
    ifm = relay.var("ifm", shape=ifm_shape, dtype=dtype)
    ifm2 = relay.var("ifm2", shape=ifm_shape, dtype=dtype)
    operator_type = "ADD"
    ifm_channels, ifm2_channels = 33, 33
    binary_elementwise = make_ethosu_binary_elementwise(
        ifm,
        ifm2,
        ifm_channels,
        ifm2_channels,
        operator_type,
        dtype,
        ifm_layout=ifm_layout,
        ifm2_layout=ifm_layout,
        ofm_layout=ofm_layout,
    )
    func = relay.Function([ifm, ifm2], binary_elementwise)
    func = run_opt_pass(func, relay.transform.InferType())
    assert tuple(func.body.checked_type.shape) == ofm_shape
    assert func.body.checked_type.dtype == dtype

def test_fuse_bcast_reduce_scalar():
    """Test fusion case with broadcast and reduction involving scalar"""

    def before():
        x = relay.var("x", shape=(), dtype="int32")
        less = relay.less(x, relay.const(10, dtype="int32"))
        z = relay.min(less)
        return relay.Function([x], z)

    def expected():
        p0 = relay.var("p0", shape=(), dtype="int32")
        less = relay.less(p0, relay.const(10, dtype="int32"))
        z0 = relay.min(less)
        f0 = relay.Function([p0], z0)
        f0 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        x = relay.var("x", shape=(), dtype="int32")
        f = relay.Call(f0, [x])
        return relay.Function([x], f)

    orig = before()
    m = fuse2(tvm.IRModule.from_expr(orig))
    for tgt, dev in tvm.testing.enabled_targets():
        relay.build(m, tgt)

    after = run_opt_pass(expected(), transform.InferType())
    assert tvm.ir.structural_equal(m["main"], after)

def test_ethosu_binary_elementwise_shift_invalid_data_type(invalid_dtype, operator_type):
    ifm_shape = [1, 4, 5, 33]
    ifm = relay.var("ifm", shape=ifm_shape, dtype=invalid_dtype)
    ifm2 = relay.var("ifm2", shape=ifm_shape, dtype=invalid_dtype)
    ifm_channels, ifm2_channels = 33, 33
    binary_elementwise = make_ethosu_binary_elementwise(
        ifm,
        ifm2,
        ifm_channels,
        ifm2_channels,
        operator_type,
        invalid_dtype,
    )
    func = relay.Function([ifm, ifm2], binary_elementwise)
    with pytest.raises(TVMError):
        run_opt_pass(func, relay.transform.InferType())

def _get_func():
    ifm_a = relay.var("ifm_a", shape=(1, 8, 8, 8), dtype="int8")
    ifm_b = relay.var("ifm_b", shape=(1, 8, 8, 8), dtype="int8")
    conv1 = make_ethosu_conv2d(
        ifm=ifm_a,
        ifm_channels=8,
        ofm_channels=8,
        kernel_shape=(1, 1),
        padding=(0, 0, 0, 0),
        strides=(1, 1),
        dilation=(1, 1),
    )
    conv2 = make_ethosu_conv2d(
        ifm=ifm_b,
        ifm_channels=8,
        ofm_channels=8,
        kernel_shape=(1, 1),
        padding=(0, 0, 0, 0),
        strides=(1, 1),
        dilation=(1, 1),
    )
    add1 = make_ethosu_binary_elementwise(
        ifm=conv1,
        ifm2=conv2,
        ifm_channels=8,
        ifm2_channels=8,
        operator_type="ADD",
        ofm_dtype="int8",
    )
    func = relay.Function(relay.analysis.free_vars(add1), add1)
    func = run_opt_pass(func, relay.transform.InferType())
    return func

def _get_func():
    ifm = relay.var("ifm", shape=(1, 12, 12, 8), dtype="int8")
    conv1 = make_ethosu_conv2d(
        ifm=ifm,
        ifm_channels=8,
        ofm_channels=32,
        kernel_shape=(1, 1),
        padding=(0, 0),
        strides=(1, 1),
        dilation=(1, 1),
        activation="NONE",
        ifm_layout="NHWC",
        ofm_layout="NHCWB16",
    )
    conv2 = make_ethosu_conv2d(
        ifm=conv1,
        ifm_channels=32,
        ofm_channels=16,
        kernel_shape=(3, 3),
        padding=(1, 1),
        strides=(1, 1),
        dilation=(1, 1),
        activation="NONE",
        ifm_layout="NHCWB16",
        ofm_layout="NHWC",
    )
    func = relay.Function(relay.analysis.free_vars(conv2), conv2)
    func = run_opt_pass(func, relay.transform.InferType())
    return func

def test_ethosu_conv2d_type_inference(
    ifm_shape,
    ifm_layout,
    ofm_shape,
    ofm_layout,
):
    ifm_channels = 55
    ofm_channels = 122
    kernel_shape = (3, 2)
    padding = (0, 1, 2, 3)
    strides = (1, 2)
    dilation = (2, 1)
    ifm = relay.var("ifm", shape=ifm_shape, dtype="int8")
    conv2d = make_ethosu_conv2d(
        ifm,
        ifm_channels,
        ofm_channels,
        kernel_shape,
        padding,
        strides,
        dilation,
        ifm_layout=ifm_layout,
        ofm_layout=ofm_layout,
    )
    f = relay.Function([ifm], conv2d)
    f = run_opt_pass(f, relay.transform.InferType())
    assert tuple(f.body.checked_type.shape) == ofm_shape

def _get_func(
    ifm_shape,
    channels,
    kernel_shape,
    padding,
    strides,
    dilation,
    activation,
    ifm_layout,
    ofm_layout,
):
    ifm = relay.var("ifm", shape=ifm_shape, dtype="int8")
    depthwise = make_ethosu_depthwise_conv2d(
        ifm,
        channels,
        kernel_shape,
        padding,
        strides,
        dilation,
        activation,
        ifm_layout,
        ofm_layout,
    )
    func = relay.Function(relay.analysis.free_vars(depthwise), depthwise)
    func = run_opt_pass(func, relay.transform.InferType())
    return func

def _get_func():
    ifm_shape = (1, 2, 2, 8)
    ifm = relay.var("ifm", shape=ifm_shape, dtype="int8")
    conv = make_ethosu_conv2d(
        ifm, ifm_shape[3], 16, (1, 1), (7, 7), (1, 1), (1, 1), "NHWC"
    )
    func = relay.Function(relay.analysis.free_vars(conv), conv)
    func = run_opt_pass(func, relay.transform.InferType())
    return func

def test_fuse_gather_nd():
    """Test fusion case involving concat and gather_nd"""

    def before():
        shape = (tvm.tir.const(10, "int64"), tvm.tir.const(1, "int64"))
        x = relay.var("x", shape=shape)
        concat = relay.concatenate([x, x], axis=-1)
        out = relay.gather_nd(concat, indices=relay.expr.const([[0, 1], [1, 0]], dtype="int64"))
        return relay.Function(relay.analysis.free_vars(out), out)

    def expected():
        shape1 = (tvm.tir.const(10, "int64"), tvm.tir.const(1, "int64"))
        shape2 = (tvm.tir.const(2, "int64"), tvm.tir.const(2, "int64"))
        x = relay.var("x", shape=shape1)
        p0 = relay.var("p0", shape=shape1)
        p1 = relay.var("p1", shape=shape2, dtype="int64")
        c = relay.const([[0, 1], [1, 0]], dtype="int64")
        concat = relay.concatenate([p0, p0], axis=-1)
        out = relay.gather_nd(concat, indices=p1)
        f0 = relay.Function([p0, p1], out)
        f0 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
        y = relay.Call(f0, [x, c])
        return relay.Function([x], y)

    orig = before()
    m = fuse2(tvm.IRModule.from_expr(orig))
    relay.build(m, "llvm")
    after = run_opt_pass(expected(), transform.InferType())
    assert tvm.ir.structural_equal(m["main"], after)

def verify_conv2d_backward_weight(dy_shape, x_shape, kernel_size, stride, padding):
    dtype = "float32"
    dy = relay.var("dy", shape=dy_shape, dtype=dtype)
    x = relay.var("x", shape=x_shape, dtype=dtype)
    dw_func = relay.Function(
        [dy, x],
        relay.nn.conv2d_backward_weight(
            dy, x, strides=stride, padding=padding, kernel_size=kernel_size
        ),
    )
    dw_func_legalized = run_opt_pass(dw_func, relay.transform.Legalize())

    for dw, target in [(dw_func_legalized, "llvm"), (dw_func, "cuda -libs=cudnn")]:
        if "cudnn" in target and not tvm.contrib.cudnn.exists():
            continue

        dev = tvm.device(target, 0)
        dy_np = np.random.randn(*dy_shape).astype(dtype)
        x_np = np.random.randn(*x_shape).astype(dtype)

        dw_np = relay.create_executor(device=dev, target=target).evaluate(dw)(dy_np, x_np).numpy()
        ref_dw_np = tvm.topi.testing.conv2d_backward_weight_python(
            dy_np, x_np, kernel_size, stride, padding
        )

        np.testing.assert_allclose(dw_np, ref_dw_np, rtol=1e-4, atol=1e-4)

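# Illustrative invocation of the helper above (the shapes are hypothetical, not taken
# from the original test): dy is the output gradient of a conv2d over x, so with a
# 3x3 kernel, stride 1 and padding 1 the spatial dimensions are preserved.
verify_conv2d_backward_weight((1, 8, 32, 32), (1, 4, 32, 32), (3, 3), (1, 1), (1, 1))
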
def test_tuple_intermediate():
    def before(x):
        inj = relay.squeeze(x)
        y1 = relay.add(inj, relay.const(1, "float32"))
        tmp = relay.squeeze(inj)
        tmp = relay.add(tmp, relay.const(1, "float32"))
        y2 = relay.add(tmp, relay.const(1, "float32"))
        y3 = relay.add(inj, relay.const(1, "float32"))
        concat = relay.concatenate((y1, y2, y3), axis=1)
        out_inj = relay.squeeze(concat)
        out = relay.add(out_inj, relay.const(1, "float32"))
        return relay.Function(relay.analysis.free_vars(out), out)

    def expected(p0):
        f0 = before(p0)
        f1 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
        x = relay.var("x", shape=dshape)
        y = relay.Call(f1, [x])
        return relay.Function([x], y)

    dshape = (1, 16, 64, 64)
    x = relay.var("x", shape=dshape)
    orig = before(x)
    fuse0(tvm.IRModule.from_expr(orig))
    m = fuse2(tvm.IRModule.from_expr(orig))
    relay.build(m, "llvm")
    after = run_opt_pass(expected(x), transform.InferType())
    assert tvm.ir.structural_equal(m["main"], after)

def _get_func(
    ifm_shape,
    ifm_channels,
    ofm_channels,
    kernel_shape,
    padding,
    strides,
    dilation,
    activation,
    ifm_layout,
    ofm_layout,
    rounding_mode,
    upscale,
):
    ifm = relay.var("ifm", shape=ifm_shape, dtype="int8")
    conv = make_ethosu_conv2d(
        ifm,
        ifm_channels,
        ofm_channels,
        kernel_shape,
        padding,
        strides,
        dilation,
        activation=activation,
        ifm_layout=ifm_layout,
        ofm_layout=ofm_layout,
        rounding_mode=rounding_mode,
        upscale=upscale,
    )
    func = relay.Function(relay.analysis.free_vars(conv), conv)
    func = run_opt_pass(func, relay.transform.InferType())
    return func

def _expected():
    var_input1 = relay.var("data", shape=(10, 10), dtype="uint8")
    var_input2 = relay.var("p1", shape=(10, 10), dtype="uint8")
    out = relay.add(var_input1, var_input2)
    func = relay.Function(relay.analysis.free_vars(out), out)
    func = run_opt_pass(func, relay.transform.InferType())
    return func

def test_parallel_merge():
    """Tests that parallel patterns relying on the same inputs are correctly merged.

    The test graph is difficult to draw out as ascii art. It is essentially two
    parallel add-sub-mul units which both consume input_1 and input_2, with their
    results being multiplied to give the output. We expect both parallel branches
    to be merged, and both should still consume the same input variables,
    input_1 and input_2."""

    def before():
        input_1 = relay.var("input_1", shape=(10, 10))
        input_2 = relay.var("input_2", shape=(10, 10))
        branch_1_add = relay.add(input_1, input_2)
        branch_1_sub = relay.subtract(input_1, input_2)
        branch_1 = relay.multiply(branch_1_add, branch_1_sub)
        branch_2_add = relay.add(input_1, input_2)
        branch_2_sub = relay.subtract(input_1, input_2)
        branch_2 = relay.multiply(branch_2_add, branch_2_sub)
        out = relay.multiply(branch_1, branch_2)
        return relay.Function([input_1, input_2], out)

    def after():
        input_1 = relay.var("input_1", shape=(10, 10))
        input_2 = relay.var("input_2", shape=(10, 10))

        x = relay.var("x")
        y = relay.var("y")
        branch_1 = relay.multiply(relay.add(x, y), relay.subtract(x, y))
        func_1 = relay.Function([x, y], branch_1)
        func_1 = func_1.with_attr("Primitive", tir.IntImm("int32", 1))
        func_1 = func_1.with_attr("Composite", tir.StringImm("add_sub_mul"))
        call_1 = relay.Call(func_1, [input_1, input_2])

        x1 = relay.var("x1")
        y1 = relay.var("y1")
        branch_2 = relay.multiply(relay.add(x1, y1), relay.subtract(x1, y1))
        func_2 = relay.Function([x1, y1], branch_2)
        func_2 = func_2.with_attr("Primitive", tir.IntImm("int32", 1))
        func_2 = func_2.with_attr("Composite", tir.StringImm("add_sub_mul"))
        call_2 = relay.Call(func_2, [input_1, input_2])

        out = relay.multiply(call_1, call_2)
        return relay.Function([input_1, input_2], out)

    pattern_table = [("add_sub_mul", make_add_sub_mul_pattern())]
    result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
    assert not relay.analysis.free_vars(result)
    expected = run_opt_pass(after(), relay.transform.InferType())
    assert tvm.ir.structural_equal(result, expected)

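# make_add_sub_mul_pattern is not defined in these snippets. One way to express it,
# sketched with the Relay dataflow pattern language (an assumption; the original
# helper may build the pattern differently):
from tvm.relay.dataflow_pattern import is_op, wildcard


def make_add_sub_mul_pattern():
    # Match multiply(add(a, b), subtract(a, b)) for arbitrary inputs a and b.
    a, b = wildcard(), wildcard()
    return is_op("multiply")(is_op("add")(a, b), is_op("subtract")(a, b))
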
def check_result(pattern_table, expected_graph, import_prelude=False):
    """Utility function to check inline composites results."""
    result = run_opt_pass(
        expected_graph,
        relay.transform.MergeComposite(pattern_table),
        import_prelude=import_prelude,
    )
    assert check_success_composite_pass(
        result
    ), "MergeComposite pass did not produce a partitioned function from the pattern"

    result = run_opt_pass(
        result, relay.transform.InlineComposites(target=""), import_prelude=import_prelude
    )
    assert not relay.analysis.free_vars(
        result
    ), "Found free vars in the result graph: {0}".format(str(result))

    expected = run_opt_pass(expected_graph, relay.transform.InferType())
    assert tvm.ir.structural_equal(
        result, expected, map_free_vars=True
    ), "Graph mismatch: output vs. expected\n{0}\n=====\n{1}".format(str(result), str(expected))

def _get_func():
    var_input = relay.var("data", shape=(10, 10), dtype="uint8")
    const_data = np.random.uniform(0, 255, (10, 10)).astype("uint8")
    const_input = relay.const(const_data, dtype="uint8")
    out = relay.add(var_input, const_input)
    func = relay.Function(relay.analysis.free_vars(out), out)
    func = run_opt_pass(func, relay.transform.InferType())
    return func, const_input

def test_fuse_max():
    """Test the constraint of number of nodes in op fusion."""

    def before(n):
        x = relay.var("x", shape=(10, 20))
        y = x
        for i in range(n):
            y = relay.exp(y)
        return relay.Function([x], y)

    def expected(n, max_fused_ops):
        x = relay.var("p", shape=(10, 20))
        y = x
        for i in range(max_fused_ops):
            y = relay.exp(y)
        f1 = relay.Function([x], y)
        f1 = f1.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
        x = relay.var("x", shape=(10, 20))
        z = relay.Call(f1, [x])
        xx = relay.var("pp", shape=(10, 20))
        yy = xx
        # it is assumed that there are two fused functions
        for i in range(n - max_fused_ops):
            yy = relay.exp(yy)
        f2 = relay.Function([xx], yy)
        f2 = f2.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
        zz = relay.Call(f2, [z])
        return relay.Function([x], zz)

    max_fused_ops = 256
    n = 300
    z = before(n)
    zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
    zz = run_opt_pass(z, transform.FuseOps())
    after = run_opt_pass(expected(n, max_fused_ops), transform.InferType())
    assert tvm.ir.structural_equal(zz, after)

    max_fused_ops = 10
    n = 20
    z = before(n)
    after = run_opt_pass(expected(n, max_fused_ops), transform.InferType())

    with tvm.transform.PassContext(config={"relay.FuseOps.max_depth": max_fused_ops}):
        zz = run_opt_pass(z, transform.FuseOps())

    assert tvm.ir.structural_equal(zz, after)

def destroy_ref(x):
    x = run_infer_type(x)
    x = to_cps(x)
    x = run_infer_type(x)
    y = un_cps(x)
    y = run_infer_type(y)
    x = run_opt_pass(
        x,
        transform.Sequential(
            [transform.PartialEvaluate(), transform.DeadCodeElimination(inline_once=True)]
        ),
    )
    assert Feature.fRefCreate not in detect_feature(x)

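# run_infer_type (used here and in the concretize tests below) is another helper from
# tvm.relay.testing; a minimal sketch, assuming it simply runs type inference through
# run_opt_pass:
def run_infer_type(expr):
    return run_opt_pass(expr, relay.transform.InferType())
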
def test_concretize_ones_like():
    dtype = "int32"
    shape_like = relay.var("shape_like", shape=(3, 4, 5), dtype=dtype)
    expr = relay.ones_like(shape_like)

    expected = run_infer_type(relay.ones((3, 4, 5), dtype))
    actual = run_opt_pass(expr, relay.transform.SimplifyExpr())
    assert tvm.ir.structural_equal(actual, expected)

def test_ethosu_identity_type_inference(shape):
    dtype = "int8"
    ifm = relay.var("ifm", shape=shape, dtype=dtype)
    identity = make_ethosu_identity(ifm)
    func = relay.Function([ifm], identity)
    func = run_opt_pass(func, relay.transform.InferType())
    assert tuple(func.body.checked_type.shape) == shape
    assert func.body.checked_type.dtype == dtype

def test_tuple_root():
    """Test fusion case where Tuple node is the root in its group"""

    def before(dshape):
        x = relay.var("x", shape=dshape)
        pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
        upsampled = relay.nn.upsampling(pooled, scale_h=2, scale_w=2, layout="NCHW")
        out = relay.Tuple((upsampled, x))
        return relay.Function(relay.analysis.free_vars(out), out)

    def expected(dshape):
        x = relay.var("x", shape=dshape)
        pooled = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
        f0 = relay.Function([x], pooled)
        f0 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        p0 = relay.var("p0", shape=(dshape[0], dshape[1], dshape[2] // 2, dshape[3] // 2))
        upsampled = relay.nn.upsampling(p0, scale_h=2, scale_w=2, layout="NCHW")
        f1 = relay.Function([p0], upsampled)
        f1 = f1.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

        x = relay.var("x", shape=dshape)
        y = relay.Call(f0, [x])
        z = relay.Call(f1, [y])
        tup = relay.Tuple((z, x))
        return relay.Function([x], tup)

    dshape = (1, 16, 64, 64)
    z = before(dshape)
    zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=0))
    assert not relay.analysis.free_vars(zz)
    zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
    assert not relay.analysis.free_vars(zz)
    after = run_opt_pass(expected(dshape), transform.InferType())
    assert tvm.ir.structural_equal(zz, after)

def test_concretize_reshape_like_attrs():
    data = relay.var("data", shape=(2, 3, 4), dtype="float32")
    shape_like = relay.var("shape_like", shape=(6, 2, 2), dtype="float32")
    expr = relay.reshape_like(data, shape_like, lhs_begin=2, rhs_begin=1)

    expected = run_infer_type(relay.reshape(data, (2, 3, 2, 2)))
    actual = run_opt_pass(expr, relay.transform.SimplifyExpr())
    assert tvm.ir.structural_equal(actual, expected)

def test_concretize_broadcast_to_like():
    data = relay.var("data", shape=(3,), dtype="float32")
    shape_like = relay.var("shape_like", shape=(3, 3, 3), dtype="float32")
    expr = relay.broadcast_to_like(data, shape_like)

    expected = run_infer_type(relay.broadcast_to(data, (3, 3, 3)))
    actual = run_opt_pass(expr, relay.transform.SimplifyExpr())
    assert tvm.ir.structural_equal(actual, expected)

def validate(shape, value, dtype):
    def before_left(x, elem_op, full):
        return elem_op(full, x)

    def after_left(x, elem_op, value):
        return elem_op(relay.const(value, dtype), x)

    def before_right(x, elem_op, full):
        return elem_op(x, full)

    def after_right(x, elem_op, value):
        return elem_op(x, relay.const(value, dtype))

    x = relay.var("x", shape=shape, dtype=dtype)
    elem_ops = [relay.add, relay.multiply, relay.subtract, relay.divide]

    full_ops = []
    if value == 0:
        full_ops.append(relay.zeros(shape, dtype))
        full_ops.append(relay.zeros_like(x))
    if value == 1:
        full_ops.append(relay.ones(shape, dtype))
        full_ops.append(relay.ones_like(x))
    else:
        full_ops.append(relay.full(relay.const(value, dtype), shape))
        full_ops.append(relay.full_like(x, relay.const(value, dtype)))

    for op in elem_ops:
        for full in full_ops:
            z = before_left(x, op, full)
            zz = run_opt_pass(z, transform.SimplifyExpr())
            after = run_opt_pass(after_left(x, op, value), transform.InferType())
            assert tvm.ir.structural_equal(zz, after)

            z = before_right(x, op, full)
            zz = run_opt_pass(z, transform.SimplifyExpr())
            after = run_opt_pass(after_right(x, op, value), transform.InferType())
            assert tvm.ir.structural_equal(zz, after)

    # Test the case in which x is broadcast to full's shape
    full_ops = []
    if value == 0:
        full_ops.append(relay.zeros(shape * 2, dtype))
    if value == 1:
        full_ops.append(relay.ones(shape * 2, dtype))
    else:
        full_ops.append(relay.full(relay.const(value, dtype), shape * 2))

    for op in elem_ops:
        for full in full_ops:
            z = before_left(x, op, full)
            zz = run_opt_pass(z, transform.SimplifyExpr())
            after = run_opt_pass(before_left(x, op, full), transform.InferType())
            assert tvm.ir.structural_equal(zz, after)

            z = before_right(x, op, full)
            zz = run_opt_pass(z, transform.SimplifyExpr())
            after = run_opt_pass(before_right(x, op, full), transform.InferType())
            assert tvm.ir.structural_equal(zz, after)

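# Hypothetical driver for the validate() helper above; the shapes, dtypes and fill
# values are illustrative, not taken from the original test suite.
def test_simplify_full_elementwise():
    for shape in [[10], [10, 10]]:
        for dtype in ["float32", "int32"]:
            for value in [0, 1, 2]:
                validate(shape, value, dtype)
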
def test_simple_merge():
    r"""Test composite function is correctly produced from simple graph.

    We could expect the pattern `make_add_relu_pattern` to be merged
    into a single op `add_relu`.

        a  b
        \ /               a  b
        add    ====>      \ /
         |             add_relu
        relu
    """
    pattern_table = [("add_relu", make_add_relu_pattern())]

    def before():
        a = relay.var("a", shape=(10, 10))
        b = relay.var("b", shape=(10, 10))
        add_node = relay.add(a, b)
        r = relay.nn.relu(add_node)
        return relay.Function([a, b], r)

    def expected():
        a = relay.var("a", shape=(10, 10))
        b = relay.var("b", shape=(10, 10))

        # add_relu function
        in_1 = relay.var("in_1", shape=(10, 10))
        in_2 = relay.var("in_2", shape=(10, 10))
        add_node = relay.add(in_1, in_2)
        relu_node = relay.nn.relu(add_node)
        add_relu = relay.Function([in_1, in_2], relu_node)
        add_relu = add_relu.with_attr("Primitive", tir.IntImm("int32", 1))
        add_relu = add_relu.with_attr("Composite", tir.StringImm("add_relu"))

        # merged function
        r = relay.Call(add_relu, [a, b])
        return relay.Function([a, b], r)

    result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table))
    assert not relay.analysis.free_vars(result)
    expected = run_opt_pass(expected(), relay.transform.InferType())
    assert tvm.ir.structural_equal(result, expected)

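# As with the add_sub_mul case above, make_add_relu_pattern is not shown. A minimal
# sketch using the Relay dataflow pattern language (an assumption about how the
# pattern is expressed):
from tvm.relay.dataflow_pattern import is_op, wildcard


def make_add_relu_pattern():
    # Match nn.relu(add(a, b)) for arbitrary inputs a and b.
    return is_op("nn.relu")(is_op("add")(wildcard(), wildcard()))
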