def test_tuple_get_time():
    """EliminateCommonSubexpr should fold two identical ``BN[0]`` projections into one."""

    def make_inputs():
        # Fresh free variables for the batch_norm data tensor and its four vectors.
        data = relay.var("x", shape=(1, 16, 1, 1))
        moving_var = relay.var("var", shape=(16, ))
        moving_mean = relay.var("mean", shape=(16, ))
        bias = relay.var("beta", shape=(16, ))
        scale = relay.var("gamma", shape=(16, ))
        return data, moving_var, moving_mean, bias, scale

    def before():
        data, moving_var, moving_mean, bias, scale = make_inputs()
        normed = relay.op.nn.batch_norm(data, scale, bias, moving_mean, moving_var, epsilon=1e-5)
        # Two separate TupleGetItem nodes on the same field: the redundancy under test.
        first = normed[0]
        second = normed[0]
        total = first + second
        return relay.Function([data, moving_var, moving_mean, bias, scale], total)

    def expected():
        data, moving_var, moving_mean, bias, scale = make_inputs()
        normed = relay.op.nn.batch_norm(data, scale, bias, moving_mean, moving_var, epsilon=1e-5)
        # After CSE a single projection feeds both operands of the add.
        item = normed[0]
        total = item + item
        func = relay.Function([data, moving_var, moving_mean, bias, scale], total)
        return run_opt_pass(func, transform.InferType())

    optimized = run_opt_pass(before(), transform.EliminateCommonSubexpr())
    assert tvm.ir.structural_equal(optimized, expected())
def test_callback():
    """EliminateCommonSubexpr must leave nodes rejected by ``fskip`` untouched.

    The duplicated ``relu`` calls are merged. The ``add`` calls are skipped by
    the callback, so they stay duplicated in the expected graph.
    """

    def plus_one(expr):
        # Each call builds its own constant node, exactly as the inline form did.
        return relay.add(expr, relay.const(1.0, "float32"))

    def before():
        inp = relay.var("x", shape=(1, 16))
        left = relay.nn.relu(inp)
        right = relay.nn.relu(inp)
        left = plus_one(left)
        right = plus_one(right)
        return relay.Function([inp], relay.add(left, right))

    def expected():
        inp = relay.var("x", shape=(1, 16))
        shared = relay.nn.relu(inp)
        left = plus_one(shared)
        right = plus_one(shared)
        func = relay.Function([inp], relay.add(left, right))
        return run_opt_pass(func, transform.InferType())

    def fskip(expr):
        # Returning True tells the pass not to eliminate this node.
        return bool(isinstance(expr, relay.expr.Call) and expr.op.name == "add")

    optimized = run_opt_pass(before(), transform.EliminateCommonSubexpr(fskip))
    assert tvm.ir.structural_equal(optimized, expected())
def test_simple():
    """EliminateCommonSubexpr should collapse duplicated relu and add chains."""

    def plus_one(expr):
        # A fresh constant node per call, matching the original inline construction.
        return relay.add(expr, relay.const(1.0, "float32"))

    def before():
        inp = relay.var("x", shape=(1, 16))
        left = relay.nn.relu(inp)
        right = relay.nn.relu(inp)
        left = plus_one(left)
        right = plus_one(right)
        return relay.Function([inp], relay.add(left, right))

    def expected():
        inp = relay.var("x", shape=(1, 16))
        # One relu -> add chain, reused for both operands of the final add.
        shared = plus_one(relay.nn.relu(inp))
        func = relay.Function([inp], relay.add(shared, shared))
        return run_opt_pass(func, transform.InferType())

    optimized = run_opt_pass(before(), transform.EliminateCommonSubexpr())
    assert tvm.ir.structural_equal(optimized, expected())
def match_pass_name(name):
    """Construct the Relay pass registered under ``name``.

    Parameters
    ----------
    name : str
        Pass name. ``'CombineParallelConv2d'`` is kept for backward
        compatibility. ``'CombineParallelConv2D'``, which matches the
        transform's own spelling, is also accepted.

    Returns
    -------
    A newly constructed pass object from ``transform``. ``'FuseOps'`` is
    built with argument ``3``.

    Raises
    ------
    ValueError
        If ``name`` does not match any known pass. ``ValueError`` subclasses
        ``Exception``, so existing callers that catch ``Exception`` still work.
    """
    # Lambdas keep construction lazy: only the requested pass is built, and a
    # new object is returned on every call, as with the original if-chain.
    factories = {
        'FoldScaleAxis': lambda: transform.FoldScaleAxis(),
        'BackwardFoldScaleAxis': lambda: transform.BackwardFoldScaleAxis(),
        'ForwardFoldScaleAxis': lambda: transform.ForwardFoldScaleAxis(),
        'FuseOps': lambda: transform.FuseOps(3),
        'FoldConstant': lambda: transform.FoldConstant(),
        'CombineParallelConv2d': lambda: transform.CombineParallelConv2D(),
        'CombineParallelConv2D': lambda: transform.CombineParallelConv2D(),
        'AlterOpLayout': lambda: transform.AlterOpLayout(),
        'EliminateCommonSubexpr': lambda: transform.EliminateCommonSubexpr(),
        'PartialEvaluate': lambda: transform.PartialEvaluate(),
        'CanonicalizeCast': lambda: transform.CanonicalizeCast(),
        'CanonicalizeOps': lambda: transform.CanonicalizeOps(),
    }
    # Non-string (possibly unhashable) names fall through to the same error
    # instead of surfacing a TypeError from the dict lookup.
    factory = factories.get(name) if isinstance(name, str) else None
    if factory is None:
        raise ValueError('Name {} does not match any pass'.format(name))
    return factory()