def test_legalize_multi_input():
    """Test directly replacing an operator with a new one"""

    def before():
        x = relay.var("x", shape=(1, 64, 56, 56))
        y = relay.var("y", shape=(1, 64, 56, 20))
        z = relay.var("z", shape=(1, 64, 56, 10))
        func = relay.concatenate([x, y, z], axis=3)
        func = relay.Function([x, y, z], func)
        return func

    @register_legalize("concatenate", level=100)
    def legalize_concatenate(attrs, inputs, arg_types):
        # Check that the correct multi-input case is handled.
        assert len(inputs) == 1
        assert isinstance(inputs[0], tvm.relay.expr.Tuple)
        assert len(arg_types) == 1
        assert isinstance(arg_types[0], tvm.relay.ty.TupleType)
        return None

    def expected():
        x = relay.var("x", shape=(1, 64, 56, 56))
        y = relay.var("y", shape=(1, 64, 56, 20))
        z = relay.var("z", shape=(1, 64, 56, 10))
        func = relay.concatenate([x, y, z], axis=3)
        func = relay.Function([x, y, z], func)
        return func

    a = before()
    a = run_opt_pass(a, transform.Legalize())
    b = run_opt_pass(expected(), transform.InferType())

    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
def test_legalize_multi_input():
    """Test directly replacing an operator with a new one"""

    def before():
        x = relay.var("x", shape=(1, 64, 56, 56))
        y = relay.var("y", shape=(1, 64, 56, 20))
        z = relay.var("z", shape=(1, 64, 56, 10))
        func = relay.concatenate([x, y, z], axis=3)
        func = relay.Function([x, y, z], func)
        return func

    def legalize_concatenate(attrs, inputs, types):
        # Check that the correct multi-input case is handled.
        assert len(inputs) == 1
        assert isinstance(inputs[0], tvm.relay.expr.Tuple)
        assert len(types) == 2
        assert isinstance(types[0], tvm.relay.ty.TupleType)
        assert isinstance(types[1], tvm.relay.ty.TensorType)
        return None

    def expected():
        x = relay.var("x", shape=(1, 64, 56, 56))
        y = relay.var("y", shape=(1, 64, 56, 20))
        z = relay.var("z", shape=(1, 64, 56, 10))
        func = relay.concatenate([x, y, z], axis=3)
        func = relay.Function([x, y, z], func)
        return func

    with TempOpAttr("concatenate", "FTVMLegalize", legalize_concatenate):
        a = before()
        a = run_opt_pass(a, transform.Legalize())
        b = run_opt_pass(expected(), transform.InferType())

    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
def test_legalize():
    """Test directly replacing an operator with a new one"""

    def before():
        x = relay.var("x", shape=(1, 64, 56, 56))
        weight = relay.var('weight', shape=(64, 64, 3, 3))
        y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1))
        y = relay.nn.relu(y)
        y = relay.Function([x, weight], y)
        return y

    def legalize_conv2d(attrs, inputs, types):
        data, weight = inputs
        weight = relay.multiply(weight, relay.const(2.0, "float32"))
        return relay.nn.conv2d(data, weight, **attrs)

    def expected():
        x = relay.var("x", shape=(1, 64, 56, 56))
        weight = relay.var('weight', shape=(64, 64, 3, 3))
        y = relay.nn.conv2d(
            x,
            relay.multiply(weight, relay.const(2.0, "float32")),
            channels=64,
            kernel_size=(3, 3),
            padding=(1, 1),
        )
        y = relay.nn.relu(y)
        y = relay.Function([x, weight], y)
        return y

    with TempOpAttr("nn.conv2d", "FTVMLegalize", legalize_conv2d):
        a = before()
        a = run_opt_pass(a, transform.Legalize())
        b = run_opt_pass(expected(), transform.InferType())

    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
def _test_legalize_batch_matmul(data_shape, kernel_shape, pad_shape, dtype, do_pad=True):
    """Test legalizing batch_matmul to enable tensorcore"""
    B, M, _ = data_shape
    _, N, _ = kernel_shape
    out_shape = (B, M, N)
    dm, dk, dn = pad_shape

    def before():
        x = relay.var("x", shape=data_shape, dtype=dtype)
        weight = relay.var("weight", shape=kernel_shape, dtype=dtype)
        y = relay.nn.batch_matmul(x, weight)
        y = relay.Function([x, weight], y)
        return y

    def legalize_batch_matmul(attrs, inputs, types):
        with tvm.target.Target("cuda"):
            return topi.nn.batch_matmul_legalize(attrs, inputs, types)

    def expected():
        if not do_pad:
            return before()
        x = relay.var("x", shape=data_shape, dtype=dtype)
        if dm or dk:
            x_pad = relay.nn.pad(x, pad_width=((0, 0), (0, dm), (0, dk)))
        else:
            x_pad = x
        weight = relay.var("weight", shape=kernel_shape, dtype=dtype)
        if dn or dk:
            weight_pad = relay.nn.pad(weight, pad_width=((0, 0), (0, dn), (0, dk)))
        else:
            weight_pad = weight
        y_pad = relay.nn.batch_matmul(x_pad, weight_pad)
        if dm or dn:
            y = relay.strided_slice(y_pad, begin=[0, 0, 0], end=out_shape)
        else:
            y = y_pad
        y = relay.Function([x, weight], y)
        return y

    with TempOpAttr("nn.batch_matmul", "FTVMLegalize", legalize_batch_matmul):
        a = before()
        a = run_opt_pass(a, transform.Legalize())
        b = run_opt_pass(expected(), transform.InferType())

    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "Expected = \n" + str(b)
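# Illustrative usage sketch, not part of the original file: the helper above is meant
# to be driven with concrete shapes. The shapes below are assumptions chosen so that
# B, M, K and N are already tensorcore-friendly multiples; for padded cases the
# pad_shape argument must mirror exactly what topi.nn.batch_matmul_legalize produces
# for the given shapes.
def test_legalize_batch_matmul_example():
    # M, K and N already aligned for float16 tensorcore: no padding is expected,
    # so the legalized graph should equal the original (do_pad=False).
    _test_legalize_batch_matmul((16, 16, 64), (16, 64, 64), (0, 0, 0), "float16", do_pad=False)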
def test_legalize_arm_layout_functional():
    """Test if the legalized conversion yields same result as original"""

    def get_output(func, data_val, parameters):
        with relay.build_config(opt_level=0):
            graph, lib, params = relay.build(func, target='llvm', params=parameters)
        m = graph_runtime.create(graph, lib, tvm.cpu())
        m.set_input("data", data_val)
        m.set_input(**params)
        m.run()
        out = m.get_output(0, tvm.nd.empty((1, 224, 224, 32), 'float32')).asnumpy()
        return out

    def before():
        n, ic, ih, iw, oc, kh, kw = 1, 16, 224, 224, 32, 3, 3
        data = relay.var("data", relay.TensorType((n, ih, iw, ic), 'float32'))
        kernel = relay.var("kernel", relay.TensorType((kh, kw, ic, oc), 'float32'))
        y = relay.nn.conv2d(data, kernel,
                            kernel_size=(kh, kw),
                            channels=oc,
                            padding=(1, 1),
                            dilation=(1, 1),
                            data_layout='NHWC',
                            kernel_layout='HWIO',
                            out_dtype='float32')
        func = relay.Function([data, kernel], y)
        return func

    @register_legalize("nn.conv2d", level=101)
    def legalize_conv2d(attrs, inputs, types):
        from topi.arm_cpu.conv2d import _conv2d_legalize
        return _conv2d_legalize(attrs, inputs, types)

    a = before()
    b = run_opt_pass(a, transform.Legalize())
    assert b.astext().count('transpose') == 3

    wdata = np.random.rand(3, 3, 16, 32) * 10
    parameters = {"kernel": tvm.nd.array(wdata.astype('float32'))}
    data_val = np.random.rand(1, 224, 224, 16).astype('float32')
    ref_out = get_output(a, data_val, parameters)
    legalized_out = get_output(b, data_val, parameters)

    np.testing.assert_allclose(ref_out, legalized_out, rtol=0.01)
def test_legalize_none():
    """Test doing nothing by returning 'None' """

    def before():
        x = relay.var("x", shape=(1, 64, 56, 56))
        y = relay.nn.global_max_pool2d(x)
        y = relay.Function([x], y)
        return y

    called = [False]

    def legalize_conv2d(attrs, inputs, types):
        called[0] = True
        return None

    with TempOpAttr("nn.global_max_pool2d", "FTVMLegalize", legalize_conv2d):
        a = before()
        a = run_opt_pass(a, transform.Legalize())
        b = run_opt_pass(before(), transform.InferType())

    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
    assert called[0]
def test_legalize_none():
    """Test doing nothing by returning 'None' """

    def before():
        x = relay.var("x", shape=(1, 64, 56, 56))
        y = relay.nn.global_max_pool2d(x)
        y = relay.Function([x], y)
        return y

    called = [False]

    @register_legalize("nn.global_max_pool2d", level=101)
    def legalize_conv2d(attrs, inputs, types):
        called[0] = True
        return None

    a = before()
    a = run_opt_pass(a, transform.Legalize())

    b = before()
    b = run_opt_pass(b, transform.InferType())

    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
    assert called[0]
def _test_legalize_conv2d(data_shape, kernel_shape, pad_shape, dtype, do_pad=True):
    """Test legalizing NHWC/HWIO conv2d to enable tensorcore"""
    out_channel = kernel_shape[3]
    out_shape = list(data_shape)
    out_shape[3] = out_channel
    db, di, do = pad_shape

    def before():
        x = relay.var("x", shape=data_shape, dtype=dtype)
        weight = relay.var("weight", shape=kernel_shape, dtype=dtype)
        y = relay.nn.conv2d(
            x,
            weight,
            channels=out_channel,
            kernel_size=(3, 3),
            padding=(1, 1),
            data_layout="NHWC",
            kernel_layout="HWIO",
        )
        y = relay.Function([x, weight], y)
        return y

    def legalize_conv2d(attrs, inputs, types):
        with tvm.target.Target("cuda"):
            return topi.nn.conv2d_legalize(attrs, inputs, types)

    def expected():
        if not do_pad:
            return before()
        x = relay.var("x", shape=data_shape, dtype=dtype)
        if db or di:
            x_pad = relay.nn.pad(x, pad_width=((0, db), (0, 0), (0, 0), (0, di)))
        else:
            x_pad = x
        weight = relay.var("weight", shape=kernel_shape, dtype=dtype)
        if di or do:
            weight_pad = relay.nn.pad(weight, pad_width=((0, 0), (0, 0), (0, di), (0, do)))
        else:
            weight_pad = weight
        y_pad = relay.nn.conv2d(
            x_pad,
            weight=weight_pad,
            channels=out_channel + do,
            kernel_size=(3, 3),
            padding=(1, 1),
            data_layout="NHWC",
            kernel_layout="HWIO",
        )
        if db or do:
            y = relay.strided_slice(y_pad, begin=[0, 0, 0, 0], end=out_shape)
        else:
            y = y_pad
        y = relay.Function([x, weight], y)
        return y

    with TempOpAttr("nn.conv2d", "FTVMLegalize", legalize_conv2d):
        a = before()
        a = run_opt_pass(a, transform.Legalize())
        b = run_opt_pass(expected(), transform.InferType())

    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "Expected = \n" + str(b)
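# Illustrative usage sketch, not part of the original file: the shapes below are
# assumptions. The pad_shape tuple (db, di, do) must match the batch / in-channel /
# out-channel padding that topi.nn.conv2d_legalize actually inserts for the given
# shapes, so this sketch only exercises an already-aligned case with do_pad=False.
def test_legalize_conv2d_example():
    # NHWC data with batch, in-channel and out-channel already tensorcore-aligned:
    # legalization is expected to leave the graph unchanged.
    _test_legalize_conv2d((16, 16, 16, 64), (3, 3, 64, 64), (0, 0, 0), "float16", do_pad=False)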
def partition_for_dnnl(mod, params=None, alter_layout=True):
    """Partition the graph greedily offloading supported operators to DNNL.

    Parameters
    ----------
    mod : Module
        The module to run passes on.
    params : Optional[Dict[str, NDArray]]
        Constant input parameters.
    alter_layout : bool
        Whether to run AlterOpLayout to convert convolutions to DNNL-friendly layouts.

    Returns
    -------
    mod : Module
        Annotated and partitioned module.
    """
    if params:
        mod["main"] = bind_params_by_name(mod["main"], params)

    with TempOpAttr("nn.conv2d", "FTVMLegalize", dnnl.legalize_group_conv):
        with TempOpAttr("nn.conv2d_transpose", "FTVMLegalize", dnnl.legalize_group_conv):
            seq = tvm.transform.Sequential(
                [
                    transform.CanonicalizeOps(),
                    transform.InferType(),
                    transform.SimplifyInference(),
                    transform.FoldConstant(),
                    transform.FoldScaleAxis(),
                    # fold consecutive add ops to simplify pattern `conv2d-bias_add-bn-relu`
                    transform.SimplifyExpr(),
                    transform.FoldConstant(),
                    # alter group conv / conv_transpose layout to `GOIHW` / `GIOHW`
                    transform.Legalize(),
                    transform.FoldConstant(),
                ]
            )
            with tvm.transform.PassContext(opt_level=3):
                mod = seq(mod)

    if alter_layout:
        with TempOpAttr("nn.conv1d", "FTVMAlterOpLayout", dnnl.alter_conv):
            with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", dnnl.alter_conv):
                with TempOpAttr("nn.conv3d", "FTVMAlterOpLayout", dnnl.alter_conv):
                    with TempOpAttr(
                        "nn.conv2d_transpose", "FTVMAlterOpLayout", dnnl.alter_conv_transpose
                    ):
                        with TempOpAttr(
                            "nn.conv3d_transpose", "FTVMAlterOpLayout", dnnl.alter_conv_transpose
                        ):
                            alter_layout_seq = tvm.transform.Sequential(
                                [
                                    transform.AlterOpLayout(),
                                    transform.FoldConstant(),
                                ]
                            )
                            with tvm.transform.PassContext(opt_level=3):
                                mod = alter_layout_seq(mod)

    byoc_seq = tvm.transform.Sequential(
        [
            transform.MergeComposite(dnnl.pattern_table()),
            transform.AnnotateTarget("dnnl"),
            transform.MergeCompilerRegions(),
            transform.PartitionGraph(),
        ]
    )
    with tvm.transform.PassContext(opt_level=3):
        mod = byoc_seq(mod)
    return mod
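# Illustrative usage sketch, not part of the original file: module shape, parameter
# name and values are assumptions. It builds a small conv2d + relu graph and runs
# partition_for_dnnl over it; actually compiling the partitioned module additionally
# requires a TVM build with the DNNL runtime enabled.
def _partition_for_dnnl_example():
    data = relay.var("data", shape=(1, 3, 224, 224), dtype="float32")
    weight = relay.var("weight", shape=(16, 3, 3, 3), dtype="float32")
    out = relay.nn.relu(relay.nn.conv2d(data, weight, kernel_size=(3, 3), padding=(1, 1)))
    mod = tvm.IRModule.from_expr(relay.Function([data, weight], out))
    params = {"weight": tvm.nd.array(np.random.uniform(size=(16, 3, 3, 3)).astype("float32"))}

    # Bind the weight as a constant, then greedily offload supported ops to DNNL regions.
    mod = partition_for_dnnl(mod, params, alter_layout=True)
    print(mod)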