def test_star_can_nest_to_match_nested_patterns(self):
    """Checks a `Star` whose element pattern is itself a tuple of `Star(1)`."""
    nested = (matcher.Star((matcher.Star(1),)),)
    # The pattern contains no names, so every successful match is empty.
    for value in ((), ((),), ((1,), (1, 1, 1))):
        self.assertDictEqual(matcher.match(nested, value), {})
    # Here the element is the bare value `1` rather than a tuple of ones.
    self.assertRaises(matcher.MatchError, matcher.match, nested, (1,))
def test_star_with_accumulate_collects_values(self):
    """Checks that a name listed in `accumulate` collects a tuple of values."""
    pair = (matcher.Var('x'), matcher.Var('y'))
    pattern = (matcher.Star(pair, accumulate=['y']),)
    bindings = matcher.match(pattern, ((1, 2), (1, 3)))
    self.assertDictEqual(bindings, {'x': 1, 'y': (2, 3)})
    # `x` is not accumulated and takes two different values (1, then 2) here.
    self.assertRaises(
        matcher.MatchError, matcher.match, pattern, ((1, 2), (2, 3)))
def test_star_match_binds_name_to_environment(self):
    """Checks the bindings produced by a `Var` repeated under a `Star`."""
    pattern = (matcher.Star(matcher.Var('x')),)
    cases = (
        ((), {}),
        ((1,), {'x': 1}),
        ((1, 1), {'x': 1}),
    )
    for value, expected in cases:
        self.assertDictEqual(matcher.match(pattern, value), expected)
    # The two elements differ, so a single binding for `x` cannot cover both.
    self.assertRaises(matcher.MatchError, matcher.match, pattern, (1, 2))
def test_var_correctly_applies_restrictions_when_matching(self):
    """Checks that a `Var` with restrictions rejects values failing either."""

    def is_positive(a):
        return a > 0

    def is_even(a):
        return a % 2 == 0

    x = matcher.Var('x', restrictions=[is_positive, is_even])
    # 2.0 satisfies both predicates; the expected 2 compares equal to 2.0.
    self.assertDictEqual(matcher.match(x, 2.), dict(x=2))
    # -2 is even but not positive.
    with self.assertRaises(matcher.MatchError):
        matcher.match(x, -2)
    # 3 is positive but not even.
    with self.assertRaises(matcher.MatchError):
        matcher.match(x, 3)
def test_star_greedily_matches_when_flag_is_set(self):
    """Checks that a greedy leading `Star` takes the largest possible match."""
    expected = {'x': (1, 1), 'y': ()}
    # `x` is expected to take both elements whether or not `y` is greedy.
    for trailing_greedy in (True, False):
        pattern = (
            matcher.Star(1, name='x', greedy=True),
            matcher.Star(1, name='y', greedy=trailing_greedy),
        )
        self.assertDictEqual(matcher.match(pattern, (1, 1)), expected)
def test_segment_matches_tuple_slices(self):
    """Checks that a `Segment` binds its name to a slice of the tuple."""
    whole = (matcher.Segment('a'),)
    self.assertDictEqual(matcher.match(whole, (1, 2, 3)), {'a': (1, 2, 3)})

    leading = (matcher.Segment('a'), 2, 3)
    self.assertDictEqual(matcher.match(leading, (1, 2, 3)), {'a': (1,)})

    # The input does not end in `2, 3`, so no slice makes the pattern fit.
    too_short = (matcher.Segment('a'), 2, 3)
    self.assertRaises(matcher.MatchError, matcher.match, too_short, (1, 2))
def test_segment_must_be_the_same_when_given_same_name(self):
    """Checks that segments sharing a name must bind to equal slices."""
    pattern = (matcher.Segment('a'), matcher.Segment('a'))
    self.assertDictEqual(matcher.match(pattern, ()), {'a': ()})
    self.assertDictEqual(matcher.match(pattern, (1, 1)), {'a': (1,)})
    # Three elements cannot be split into two equal slices.
    with self.assertRaises(matcher.MatchError):
        matcher.match(pattern, (1, 1, 1))

    pattern = (matcher.Segment('x'), matcher.Segment('y'),
               matcher.Segment('x'))
    matches = list(matcher.match_all(pattern, (1,) * 10))
    self.assertLen(matches, 6)
    # The i-th result is expected to give `x` length i and `y` the remaining
    # 10 - 2 * i elements.
    for i, bindings in enumerate(matches):
        self.assertDictEqual(
            bindings, dict(x=(1,) * i, y=(1,) * (10 - 2 * i)))
def test_can_match_addn_components(self):
    """Checks that a `Segment` inside an `AddN` pattern binds the operands."""
    x = JaxVar('x', (5,), jnp.float32)
    expression = AddN((x, x))
    pattern = AddN((matcher.Segment('args'),))
    bindings = matcher.match(pattern, expression)
    self.assertDictEqual(bindings, {'args': (x, x)})
def test_can_match_primitive_inside_of_pattern(self):
    """Checks that a `jr.Primitive` pattern binds the parts of `Exp(...)`."""
    pattern = jr.Primitive(
        matcher.Var('prim'),
        (matcher.Segment('args'),),
        matcher.Var('params'),
    )
    expr = Exp(jr.Literal(1.))
    bindings = matcher.match(pattern, expr)
    expected = {
        'prim': lax.exp_p,
        'args': (jr.Literal(1.),),
        'params': jr.Params(),
    }
    self.assertDictEqual(bindings, expected)
def test_can_match_einsum_components(self):
    """Checks that an `Einsum` pattern binds the formula and the operands."""
    x = JaxVar('x', (5,), jnp.float32)
    expression = Einsum('a,a->', (x, x))
    pattern = Einsum(Var('formula'), (matcher.Segment('args'),))
    bindings = matcher.match(pattern, expression)
    self.assertDictEqual(bindings, dict(formula='a,a->', args=(x, x)))
def test_can_match_call_primitive_parts(self):
    """Checks that each field of a `jr.CallPrimitive` binds to its own `Var`."""
    pattern = jr.CallPrimitive(
        matcher.Var('prim'),
        matcher.Var('args'),
        matcher.Var('expression'),
        matcher.Var('params'),
        matcher.Var('names'),
    )
    expr = jr.CallPrimitive(
        jax.core.call_p, (), (jr.Literal(0.), jr.Literal(1)), jr.Params(), [])
    # The expectation is built from the same literals as `expr` (including
    # the integer `1`), so the check does not rely on int/float equality.
    expected = dict(
        prim=jax.core.call_p,
        args=(),
        expression=(jr.Literal(0.), jr.Literal(1)),
        params=jr.Params(),
        names=[],
    )
    self.assertDictEqual(matcher.match(pattern, expr), expected)
def test_star_match_correctly_matches_sequence_of_patterns(self):
    """Checks that `Star(1)` matches tuples of zero to three ones."""
    pattern = (matcher.Star(1),)
    for count in range(4):
        # (1,) * count yields (), (1,), (1, 1) and (1, 1, 1) in turn.
        self.assertDictEqual(matcher.match(pattern, (1,) * count), {})
    # The trailing `2` is not matched by the repeated `1`.
    self.assertRaises(matcher.MatchError, matcher.match, pattern, (1, 2))
def test_can_use_star_patterns_in_string_patterns(self):
    """Checks `Segment` and `Star` in list patterns matched against strings."""
    with_rest = ['a', 'b', matcher.Segment('rest')]
    self.assertDictEqual(matcher.match(with_rest, 'abcd'), {'rest': 'cd'})
    # The second character is 'c' where the pattern requires 'b'.
    self.assertRaises(matcher.MatchError, matcher.match, with_rest, 'acd')

    repeated_c = ['a', 'b', matcher.Star('c'), 'd']
    for text in ('abccccd', 'abd'):
        self.assertDictEqual(matcher.match(repeated_c, text), {})
    # The input lacks the trailing 'd'.
    self.assertRaises(matcher.MatchError, matcher.match, repeated_c, 'abccc')
def test_dict_patterns_match_equal_dicts(self):
    """Checks that a dict pattern matches an equal dict in either key order."""
    same_order = matcher.match({'a': 1, 'b': 2}, {'a': 1, 'b': 2})
    self.assertDictEqual(same_order, {})
    swapped_order = matcher.match({'a': 1, 'b': 2}, {'b': 2, 'a': 1})
    self.assertDictEqual(swapped_order, {})
def test_tuple_patterns_error_on_nonequal_tuples(self):
    """Checks that tuples differing in their last element do not match."""
    pattern = (1, 2, 4)
    value = (1, 2, 3)
    self.assertRaises(matcher.MatchError, matcher.match, pattern, value)
def test_tuple_patterns_match_equal_tuples(self):
    """Checks that flat and nested tuple patterns match equal tuples."""
    flat = matcher.match((1, 2, 3), (1, 2, 3))
    self.assertDictEqual(flat, {})
    nested = matcher.match(((1, 2), 2, 3), ((1, 2), 2, 3))
    self.assertDictEqual(nested, {})
def test_list_patterns_match_equal_lists(self):
    """Checks that list patterns match equal lists, with or without a tuple."""
    flat = matcher.match([1, 2, 3], [1, 2, 3])
    self.assertDictEqual(flat, {})
    with_tuple = matcher.match([(1, 2), 2, 3], [(1, 2), 2, 3])
    self.assertDictEqual(with_tuple, {})
def test_match_errors_with_star_pattern(self):
    """Checks that matching a top-level `Star` raises `ValueError`."""
    # The lambda defers building the pattern, so it is constructed inside the
    # assertion just as in a `with self.assertRaises(...)` block.
    self.assertRaises(
        ValueError, lambda: matcher.match(matcher.Star(1), 1.))
def test_star_with_plus_matches_nonempty_tuple(self):
    """Checks that `Plus(1)` matches a tuple of two ones with no bindings."""
    one_or_more = (matcher.Plus(1),)
    bindings = matcher.match(one_or_more, (1, 1))
    self.assertDictEqual(bindings, {})
def test_var_pattern_matches_when_bound_value_matches(self):
    """Checks that a `Var` used twice matches two equal values."""
    x = matcher.Var('x')
    bindings = matcher.match((x, x), (1, 1))
    self.assertDictEqual(bindings, {'x': 1})
def test_var_pattern_errors_when_bound_value_doesnt_match(self):
    """Checks that a `Var` used twice rejects two different values."""
    x = matcher.Var('x')
    repeated = (x, x)
    self.assertRaises(matcher.MatchError, matcher.match, repeated, (1, 2))
def test_default_matcher_correctly_matches_equal_values(self):
    """Checks that plain values match values they compare equal to."""
    equal_pairs = (
        (1., 1.),
        (1., 1),  # float pattern against an int value
        ('hello', 'hello'),
    )
    for pattern, value in equal_pairs:
        self.assertDictEqual(matcher.match(pattern, value), {})
def test_dict_patterns_error_on_nonequal_dicts(self):
    """Checks that a dict pattern rejects a changed value or an extra key."""
    mismatches = (
        {'a': 2, 'b': 2},  # value under 'a' differs
        {'a': 1, 'b': 2, 'c': 3},  # extra key 'c'
    )
    for value in mismatches:
        with self.assertRaises(matcher.MatchError):
            matcher.match({'a': 1, 'b': 2}, value)
def test_default_matcher_errors_on_nonequal_values(self):
    """Checks that plain values do not match values they differ from."""
    for pattern, value in ((0., 1.), ('a', 'b')):
        self.assertRaises(matcher.MatchError, matcher.match, pattern, value)
def test_var_pattern_matches_any_expression(self):
    """Checks that a `Var` binds its name to a float and to a string."""
    float_bindings = matcher.match(matcher.Var('x'), 1.)
    self.assertDictEqual(float_bindings, dict(x=1.))
    string_bindings = matcher.match(matcher.Var('x'), 'hello')
    self.assertDictEqual(string_bindings, dict(x='hello'))
def test_not_pattern_correctly_matches_nonequal_values(self):
    """Checks that `Not(p)` matches a value that differs from `p`."""
    for excluded in (0., 'a'):
        negated = matcher.Not(excluded)
        self.assertDictEqual(matcher.match(negated, 1.), {})
def test_can_match_string_literals(self):
    """Checks that a string pattern matches only the identical string."""
    literal = 'abcd'
    bindings = matcher.match(literal, 'abcd')
    self.assertDictEqual(bindings, {})
    # Same characters in reverse order.
    self.assertRaises(matcher.MatchError, matcher.match, literal, 'dcba')
def test_plus_errors_on_empty_tuple(self):
    """Checks that `Plus(1)` does not match the empty tuple."""
    one_or_more = (matcher.Plus(1),)
    self.assertRaises(matcher.MatchError, matcher.match, one_or_more, ())
def test_dot_pattern_matches_without_creating_binding(self):
    """Checks that `Dot` matches values and leaves the bindings empty."""
    single = matcher.match(matcher.Dot, 1)
    self.assertDictEqual(single, {})
    pair = matcher.match((matcher.Dot, matcher.Dot), (1, 2))
    self.assertDictEqual(pair, {})
def test_can_match_var_with_length_one_string(self):
    """Checks that a `Var` binds to a one-character string."""
    bindings = matcher.match(matcher.Var('x'), 'a')
    self.assertDictEqual(bindings, dict(x='a'))