def test_segment_matches_tuple_slices(self):
    """A lone Segment captures a whole tuple; literal suffixes narrow it."""
    whole = (matcher.Segment('a'),)
    self.assertDictEqual(matcher.match(whole, (1, 2, 3)), {'a': (1, 2, 3)})

    with_suffix = (matcher.Segment('a'), 2, 3)
    self.assertDictEqual(matcher.match(with_suffix, (1, 2, 3)), {'a': (1,)})
    # (1, 2) does not end with the literal suffix (2, 3), so matching fails.
    with self.assertRaises(matcher.MatchError):
        matcher.match(with_suffix, (1, 2))
def test_segment_must_be_the_same_when_given_same_name(self):
    """Segments that share a name must bind to equal slices of the input."""
    pattern = (matcher.Segment('a'), matcher.Segment('a'))
    self.assertDictEqual(matcher.match(pattern, ()), {'a': ()})
    self.assertDictEqual(matcher.match(pattern, (1, 1)), {'a': (1,)})
    # An odd-length input cannot be split into two equal slices.
    with self.assertRaises(matcher.MatchError):
        matcher.match(pattern, (1, 1, 1))

    pattern = (matcher.Segment('x'), matcher.Segment('y'), matcher.Segment('x'))
    matches = list(matcher.match_all(pattern, (1,) * 10))
    # x appears twice, so it can take 0..5 elements; y takes the remainder.
    # The assertions below expect match_all to yield them by increasing len(x).
    self.assertLen(matches, 6)
    for i, match in enumerate(matches):
        self.assertDictEqual(match, dict(x=(1,) * i, y=(1,) * (10 - 2 * i)))
def test_can_match_addn_components(self):
    """A Segment inside AddN's operand tuple captures every operand."""
    x = JaxVar('x', (5,), jnp.float32)
    target = AddN((x, x))
    segment_pattern = AddN((matcher.Segment('args'),))
    bindings = matcher.match(segment_pattern, target)
    self.assertDictEqual(bindings, {'args': (x, x)})
def test_can_match_primitive_inside_of_pattern(self):
    """Var/Segment placeholders can stand in for every field of a Primitive."""
    pattern = jr.Primitive(
        matcher.Var('prim'),
        (matcher.Segment('args'),),
        matcher.Var('params'),
    )
    expr = Exp(jr.Literal(1.))
    expected = dict(
        prim=lax.exp_p,
        args=(jr.Literal(1.),),
        params=jr.Params(),
    )
    self.assertDictEqual(matcher.match(pattern, expr), expected)
def test_can_match_einsum_components(self):
    """An Einsum pattern binds both its formula and its operand segment."""
    x = JaxVar('x', (5,), jnp.float32)
    target = Einsum('a,a->', (x, x))
    einsum_pattern = Einsum(Var('formula'), (matcher.Segment('args'),))
    bindings = matcher.match(einsum_pattern, target)
    self.assertDictEqual(bindings, {'formula': 'a,a->', 'args': (x, x)})
def test_can_replace_addn_operands(self):
    """A rule built from an AddN pattern rewrites the matched expression."""
    x = JaxVar('x', (5,), jnp.float32)
    y = JaxVar('y', (5,), jnp.float32)
    # z gets its own name so the replacement operand is distinguishable from
    # the original operand y (it was previously also named 'y').
    z = JaxVar('z', (5,), jnp.float32)
    op = AddN((x, y))
    pattern = AddN((matcher.Segment('args'),))

    def replace_with_z(args):
        # The matched operands are ignored; the result is always AddN((z, z)).
        del args
        return AddN((z, z))

    replace_rule = rules.make_rule(pattern, replace_with_z)
    replaced_op = replace_rule(op)
    self.assertEqual(replaced_op, AddN((z, z)))
def test_can_replace_einsum_operands(self):
    """A rule built from an Einsum pattern keeps the formula, swaps operands."""
    x = JaxVar('x', (5,), jnp.float32)
    y = JaxVar('y', (5,), jnp.float32)
    # z gets its own name so the replacement operand is distinguishable from
    # the original operand y (it was previously also named 'y').
    z = JaxVar('z', (5,), jnp.float32)
    op = Einsum('a,a->', (x, y))
    pattern = Einsum(Var('formula'), (matcher.Segment('args'),))

    def replace_with_z(formula, args):
        # The matched operands are ignored; only the bound formula is reused.
        del args
        return Einsum(formula, (z, z))

    replace_rule = rules.make_rule(pattern, replace_with_z)
    replaced_op = replace_rule(op)
    self.assertEqual(replaced_op, Einsum('a,a->', (z, z)))
def test_can_use_star_patterns_in_string_patterns(self):
    """Segment and Star patterns can be matched against plain strings."""
    prefixed = ['a', 'b', matcher.Segment('rest')]
    self.assertDictEqual(matcher.match(prefixed, 'abcd'), {'rest': 'cd'})
    # 'acd' lacks the literal 'b' after 'a'.
    with self.assertRaises(matcher.MatchError):
        matcher.match(prefixed, 'acd')

    # Star('c') accepts any run of 'c' (including none) and binds nothing.
    repeated_c = ['a', 'b', matcher.Star('c'), 'd']
    for text in ('abccccd', 'abd'):
        self.assertDictEqual(matcher.match(repeated_c, text), {})
    # Missing the trailing 'd'.
    with self.assertRaises(matcher.MatchError):
        matcher.match(repeated_c, 'abccc')
def test_segment_matches_multiple_tuple_slices(self):
    """Two adjacent Segments yield one match per split point of the input."""
    pattern = (matcher.Segment('a'), matcher.Segment('b'))
    cases = (
        ((), 1),
        ((1,), 2),
        ((1, 1), 3),
    )
    for seq, expected_count in cases:
        self.assertLen(list(matcher.match_all(pattern, seq)), expected_count)
def test_can_match_part_op_and_index(self):
    """A Part pattern binds both its operand segment and its index."""
    expr = jr.Part((jr.Literal(0.), jr.Literal(1.)), 1)
    part_pattern = jr.Part((matcher.Segment('args'),), matcher.Var('i'))
    expected = dict(args=(jr.Literal(0.), jr.Literal(1.)), i=1)
    self.assertDictEqual(matcher.match(part_pattern, expr), expected)
def test_can_match_value_inside_params(self):
    """A Var nested in a Params pattern binds the corresponding param value."""
    expr = jr.Primitive(lax.iota_p, (), jr.Params(foo='bar'))
    params_pattern = jr.Params({'foo': matcher.Var('foo')})
    # Dot and the unnamed Segment contribute no bindings: only 'foo' is expected.
    pattern = jr.Primitive(
        matcher.Dot, (matcher.Segment(name=None),), params_pattern)
    self.assertDictEqual(matcher.match(pattern, expr), {'foo': 'bar'})