def alpha_equal(x, y):
    """Check alpha equality and confirm that the structural hashes agree.

    Wraps ``analysis.alpha_equal`` so that a pair reported as alpha equal is
    also required to have equal ``analysis.structural_hash`` values. In other
    words, the hash function must respect equality for that pair.

    Parameters
    ----------
    x, y :
        The two objects to compare. They can be anything that
        ``analysis.alpha_equal`` and ``analysis.structural_hash`` accept.

    Returns
    -------
    If the inputs are not alpha equal, the (falsy) result of
    ``analysis.alpha_equal``. Otherwise, whether their structural hashes are
    equal.
    """
    equal = analysis.alpha_equal(x, y)
    if not equal:
        # Short-circuit: hashes are only compared for alpha-equal inputs.
        return equal
    return analysis.structural_hash(x) == analysis.structural_hash(y)
def test_tuple_match():
    """Two independently built tuple-pattern matches are alpha equal and hash alike."""

    def build_match():
        # Fresh variables are created on every call, so the two matches share
        # no Var objects.
        lhs = relay.Var("a")
        rhs = relay.Var("b")
        pattern = relay.PatternTuple([relay.PatternVar(lhs), relay.PatternVar(rhs)])
        clause = relay.Clause(pattern, lhs + rhs)
        scrutinee = relay.Tuple([relay.const(1), relay.const(1)])
        return relay.Match(scrutinee, [clause])

    first = build_match()
    second = build_match()
    assert analysis.alpha_equal(first, second)
    assert analysis.structural_hash(first) == analysis.structural_hash(second)
def test_hash_unequal():
    """Structural hashes match for same-shaped add functions and differ when the shape changes."""

    def make_add(lhs_name, rhs_name, shape):
        # Build a two-parameter function returning relay.add(lhs, rhs), with
        # float32 parameters of the given shape.
        lhs = relay.var(lhs_name, shape=shape, dtype="float32")
        rhs = relay.var(rhs_name, shape=shape, dtype="float32")
        return relay.Function([lhs, rhs], relay.add(lhs, rhs))

    func1 = make_add("x1", "y1", (10, 10))
    # func2 has exactly the same structure, shapes and dtypes as func1; only
    # the variable names differ.
    func2 = make_add("x2", "y2", (10, 10))
    assert analysis.structural_hash(func1) == analysis.structural_hash(func2)

    # func3 mirrors func1 but uses a different parameter shape.
    func3 = make_add("x3", "y3", (20, 10))
    assert not analysis.structural_hash(func1) == analysis.structural_hash(func3)