Beispiel #1
0
    def test_star_can_nest_to_match_nested_patterns(self):
        """A Star over a tuple that itself contains a Star matches nested tuples."""
        inner = (matcher.Star(1), )
        nested = (matcher.Star(inner), )
        # Each subject is a (possibly empty) sequence of tuples made of 1s.
        for subject in [(), ((), ), ((1, ), (1, 1, 1))]:
            self.assertDictEqual(matcher.match(nested, subject), {})

        # A non-tuple element must not satisfy the nested pattern.
        with self.assertRaises(matcher.MatchError):
            matcher.match(nested, (1, ))
Beispiel #2
0
    def test_star_with_accumulate_collects_values(self):
        """Names listed in `accumulate` gather one value per Star repetition."""
        repeated = (matcher.Var('x'), matcher.Var('y'))
        pattern = (matcher.Star(repeated, accumulate=['y']), )
        result = matcher.match(pattern, ((1, 2), (1, 3)))
        # `x` is not accumulated, so it must agree across repetitions;
        # `y` is collected into a tuple.
        self.assertDictEqual(result, {'x': 1, 'y': (2, 3)})

        # A differing value for the non-accumulated `x` is a mismatch.
        with self.assertRaises(matcher.MatchError):
            matcher.match(pattern, ((1, 2), (2, 3)))
Beispiel #3
0
    def test_star_match_binds_name_to_environment(self):
        """A Var inside a Star binds once and must agree on every repetition."""
        pattern = (matcher.Star(matcher.Var('x')), )
        cases = [
            ((), {}),           # zero repetitions: nothing is bound
            ((1, ), {'x': 1}),
            ((1, 1), {'x': 1}),
        ]
        for subject, expected in cases:
            self.assertDictEqual(matcher.match(pattern, subject), expected)

        with self.assertRaises(matcher.MatchError):
            matcher.match(pattern, (1, 2))
Beispiel #4
0
 def test_var_correctly_applies_restrictions_when_matching(self):
     """A Var only matches values that satisfy every restriction predicate."""

     def is_positive(a):
         return a > 0

     def is_even(a):
         return a % 2 == 0

     x = matcher.Var('x', restrictions=[is_positive, is_even])
     # The expected binding mirrors the input literal (2.) so the assertion
     # does not silently rely on 2 == 2. equality.
     self.assertDictEqual(matcher.match(x, 2.), dict(x=2.))
     # -2 violates `is_positive`.
     with self.assertRaises(matcher.MatchError):
         matcher.match(x, -2)
     # 3 violates `is_even`.
     with self.assertRaises(matcher.MatchError):
         matcher.match(x, 3)
Beispiel #5
0
    def test_star_greedily_matches_when_flag_is_set(self):
        """A greedy Star consumes as many elements as it can."""
        pattern = (matcher.Star(1, name='x', greedy=True),
                   matcher.Star(1, name='y', greedy=True))
        # x will be the largest possible match, leaving y empty.
        self.assertDictEqual(matcher.match(pattern, (1, 1)),
                             dict(x=(1, 1), y=()))

        # The greedy first Star still takes everything when the second Star
        # is non-greedy.
        pattern = (matcher.Star(1, name='x', greedy=True),
                   matcher.Star(1, name='y', greedy=False))
        self.assertDictEqual(matcher.match(pattern, (1, 1)),
                             dict(x=(1, 1), y=()))
Beispiel #6
0
    def test_segment_matches_tuple_slices(self):
        """A Segment binds a contiguous slice of the matched tuple."""
        subject = (1, 2, 3)

        # On its own, the segment swallows the whole tuple.
        whole = (matcher.Segment('a'), )
        self.assertDictEqual(matcher.match(whole, subject), {'a': (1, 2, 3)})

        # With trailing literals, the segment keeps only the leading part.
        prefix = (matcher.Segment('a'), 2, 3)
        self.assertDictEqual(matcher.match(prefix, subject), {'a': (1, )})

        # Too few elements remain to satisfy the trailing literals.
        too_short = (matcher.Segment('a'), 2, 3)
        with self.assertRaises(matcher.MatchError):
            matcher.match(too_short, (1, 2))
Beispiel #7
0
    def test_segment_must_be_the_same_when_given_same_name(self):
        """Segments sharing a name must bind to equal slices."""
        pattern = (matcher.Segment('a'), matcher.Segment('a'))
        self.assertDictEqual(matcher.match(pattern, ()), {'a': ()})
        self.assertDictEqual(matcher.match(pattern, (1, 1)), {'a': (1, )})
        # Three elements cannot be split into two equal halves.
        with self.assertRaises(matcher.MatchError):
            matcher.match(pattern, (1, 1, 1))

        pattern = (matcher.Segment('x'), matcher.Segment('y'),
                   matcher.Segment('x'))
        size = 10
        matches = list(matcher.match_all(pattern, (1, ) * size))
        # x occurs twice, so len(x) ranges over 0..size // 2 with y taking the
        # remainder; match_all is expected to yield them by increasing len(x).
        self.assertLen(matches, size // 2 + 1)
        for i, found in enumerate(matches):
            self.assertDictEqual(
                found, dict(x=(1, ) * i, y=(1, ) * (size - 2 * i)))
Beispiel #8
0
 def test_can_match_addn_components(self):
     """A Segment inside an AddN pattern captures all of the AddN operands."""
     x = JaxVar('x', (5,), jnp.float32)
     expr = AddN((x, x))
     pattern = AddN((matcher.Segment('args'),))
     result = matcher.match(pattern, expr)
     self.assertDictEqual(result, dict(args=(x, x)))
Beispiel #9
0
 def test_can_match_primitive_inside_of_pattern(self):
     """Var and Segment patterns can stand in for the parts of a Primitive."""
     prim_slot = matcher.Var('prim')
     args_slot = (matcher.Segment('args'), )
     params_slot = matcher.Var('params')
     pattern = jr.Primitive(prim_slot, args_slot, params_slot)
     expr = Exp(jr.Literal(1.))
     result = matcher.match(pattern, expr)
     # NOTE(review): presumes Exp builds a Primitive around lax.exp_p with
     # empty params, as the expected bindings below state.
     self.assertDictEqual(
         result,
         {'prim': lax.exp_p, 'args': (jr.Literal(1.), ), 'params': jr.Params()})
Beispiel #10
0
 def test_can_match_einsum_components(self):
     """An Einsum pattern can bind its formula and capture its operands."""
     x = JaxVar('x', (5,), jnp.float32)
     expr = Einsum('a,a->', (x, x))
     pattern = Einsum(Var('formula'), (matcher.Segment('args'),))
     result = matcher.match(pattern, expr)
     self.assertDictEqual(result, dict(formula='a,a->', args=(x, x)))
Beispiel #11
0
 def test_can_match_call_primitive_parts(self):
     """Every field of a CallPrimitive can be bound by a Var pattern."""
     pattern = jr.CallPrimitive(matcher.Var('prim'), matcher.Var('args'),
                                matcher.Var('expression'),
                                matcher.Var('params'), matcher.Var('names'))
     # Use the same literals in the subject and in the expected bindings so
     # the test does not implicitly depend on Literal(1) == Literal(1.).
     expr = jr.CallPrimitive(jax.core.call_p, (),
                             (jr.Literal(0.), jr.Literal(1.)), jr.Params(),
                             [])
     self.assertDictEqual(
         matcher.match(pattern, expr),
         dict(prim=jax.core.call_p,
              args=(),
              expression=(jr.Literal(0.), jr.Literal(1.)),
              params=jr.Params(),
              names=[]))
Beispiel #12
0
    def test_star_match_correctly_matches_sequence_of_patterns(self):
        """An unnamed Star matches zero or more repetitions and binds nothing."""
        pattern = (matcher.Star(1), )
        # Tuples of 0, 1, 2 and 3 ones.
        for count in range(4):
            self.assertDictEqual(matcher.match(pattern, (1, ) * count), {})

        # Every element must match the starred pattern.
        with self.assertRaises(matcher.MatchError):
            matcher.match(pattern, (1, 2))
Beispiel #13
0
    def test_can_use_star_patterns_in_string_patterns(self):
        """Segment and Star work when a list pattern is matched against a str."""
        rest_pattern = ['a', 'b', matcher.Segment('rest')]
        self.assertDictEqual(matcher.match(rest_pattern, 'abcd'),
                             {'rest': 'cd'})
        # The second character is not 'b'.
        with self.assertRaises(matcher.MatchError):
            matcher.match(rest_pattern, 'acd')

        star_pattern = ['a', 'b', matcher.Star('c'), 'd']
        for text in ('abccccd', 'abd'):
            self.assertDictEqual(matcher.match(star_pattern, text), {})
        # The trailing 'd' is missing.
        with self.assertRaises(matcher.MatchError):
            matcher.match(star_pattern, 'abccc')
Beispiel #14
0
 def test_dict_patterns_match_equal_dicts(self):
     """Dict patterns match equal dicts regardless of key insertion order."""
     for subject in ({'a': 1, 'b': 2}, {'b': 2, 'a': 1}):
         self.assertDictEqual(matcher.match({'a': 1, 'b': 2}, subject), {})
     self.assertDictEqual(matcher.match(dict(a=1, b=2), dict(b=2, a=1)), {})
Beispiel #15
0
 def test_tuple_patterns_error_on_nonequal_tuples(self):
     """Tuples that differ in any position fail to match."""
     pattern, subject = (1, 2, 4), (1, 2, 3)
     with self.assertRaises(matcher.MatchError):
         matcher.match(pattern, subject)
Beispiel #16
0
 def test_tuple_patterns_match_equal_tuples(self):
     """Literal tuple patterns, including nested ones, match equal tuples."""
     cases = [
         ((1, 2, 3), (1, 2, 3)),
         (((1, 2), 2, 3), ((1, 2), 2, 3)),
     ]
     for pattern, subject in cases:
         self.assertDictEqual(matcher.match(pattern, subject), {})
Beispiel #17
0
 def test_list_patterns_match_equal_lists(self):
     """Literal list patterns, including ones holding tuples, match equal lists."""
     cases = [
         ([1, 2, 3], [1, 2, 3]),
         ([(1, 2), 2, 3], [(1, 2), 2, 3]),
     ]
     for pattern, subject in cases:
         self.assertDictEqual(matcher.match(pattern, subject), {})
Beispiel #18
0
 def test_match_errors_with_star_pattern(self):
     """Using a Star as the top-level pattern raises ValueError."""
     with self.assertRaises(ValueError):
         top_level = matcher.Star(1)
         matcher.match(top_level, 1.)
Beispiel #19
0
 def test_star_with_plus_matches_nonempty_tuple(self):
     """Plus matches a tuple holding repetitions of its element."""
     result = matcher.match((matcher.Plus(1), ), (1, 1))
     self.assertDictEqual(result, {})
Beispiel #20
0
 def test_var_pattern_matches_when_bound_value_matches(self):
     """Reusing one Var succeeds when both occurrences see the same value."""
     var = matcher.Var('x')
     pattern = (var, var)
     self.assertDictEqual(matcher.match(pattern, (1, 1)), {'x': 1})
Beispiel #21
0
 def test_var_pattern_errors_when_bound_value_doesnt_match(self):
     """Reusing one Var fails when the two occurrences see different values."""
     var = matcher.Var('x')
     pattern = (var, var)
     with self.assertRaises(matcher.MatchError):
         matcher.match(pattern, (1, 2))
Beispiel #22
0
 def test_default_matcher_correctly_matches_equal_values(self):
     """Plain values match anything that compares equal (e.g. 1. and 1)."""
     for pattern, subject in [(1., 1.), (1., 1), ('hello', 'hello')]:
         self.assertDictEqual(matcher.match(pattern, subject), {})
Beispiel #23
0
 def test_dict_patterns_error_on_nonequal_dicts(self):
     """Dicts with a differing value or an extra key fail to match."""
     mismatches = [
         {'a': 2, 'b': 2},          # value for 'a' differs
         {'a': 1, 'b': 2, 'c': 3},  # extra key 'c'
     ]
     for subject in mismatches:
         with self.assertRaises(matcher.MatchError):
             matcher.match({'a': 1, 'b': 2}, subject)
Beispiel #24
0
 def test_default_matcher_errors_on_nonequal_values(self):
     """Plain values that compare unequal fail to match."""
     for pattern, subject in [(0., 1.), ('a', 'b')]:
         with self.assertRaises(matcher.MatchError):
             matcher.match(pattern, subject)
Beispiel #25
0
 def test_var_pattern_matches_any_expression(self):
     """A fresh Var matches a value of any type and binds it under its name."""
     for value in (1., 'hello'):
         self.assertDictEqual(matcher.match(matcher.Var('x'), value),
                              dict(x=value))
Beispiel #26
0
 def test_not_pattern_correctly_matches_nonequal_values(self):
     """Not(v) matches a value that differs from v, without binding anything."""
     for excluded in (0., 'a'):
         self.assertDictEqual(matcher.match(matcher.Not(excluded), 1.), {})
Beispiel #27
0
 def test_can_match_string_literals(self):
     """A string pattern matches the identical string and rejects others."""
     literal = 'abcd'
     self.assertDictEqual(matcher.match(literal, 'abcd'), {})
     # Same characters in a different order must not match.
     with self.assertRaises(matcher.MatchError):
         matcher.match(literal, 'dcba')
Beispiel #28
0
 def test_plus_errors_on_empty_tuple(self):
     """Plus needs at least one element, so an empty tuple fails to match."""
     one_or_more = (matcher.Plus(1), )
     empty = ()
     with self.assertRaises(matcher.MatchError):
         matcher.match(one_or_more, empty)
Beispiel #29
0
 def test_dot_pattern_matches_without_creating_binding(self):
     """Dot acts as a wildcard: it matches any value and binds nothing."""
     dot = matcher.Dot
     self.assertDictEqual(matcher.match(dot, 1), {})
     # Two wildcards may match different values.
     self.assertDictEqual(matcher.match((dot, dot), (1, 2)), {})
Beispiel #30
0
 def test_can_match_var_with_length_one_string(self):
     """A Var binds a one-character string under its name."""
     result = matcher.match(matcher.Var('x'), 'a')
     self.assertDictEqual(result, dict(x='a'))