def test_star_with_accumulate_collects_values(self):
  """Star with `accumulate` collects `y` per repetition while `x` must agree."""
  pairs_pattern = (
      matcher.Star((matcher.Var('x'), matcher.Var('y')), accumulate=['y']),)
  # `x` is shared by both pairs (1 and 1); the `y` values are gathered.
  bindings = matcher.match(pairs_pattern, ((1, 2), (1, 3)))
  self.assertDictEqual(bindings, dict(x=1, y=(2, 3)))
  # The pairs disagree on `x` (1 vs 2), so no consistent match exists.
  with self.assertRaises(matcher.MatchError):
    matcher.match(pairs_pattern, ((1, 2), (2, 3)))
def test_can_match_primitive_inside_of_pattern(self):
  """A `jr.Primitive` pattern binds the primitive, its operands and params."""
  prim_pattern = jr.Primitive(
      matcher.Var('prim'), (matcher.Segment('args'),), matcher.Var('params'))
  exp_of_one = Exp(jr.Literal(1.))
  bindings = matcher.match(prim_pattern, exp_of_one)
  self.assertDictEqual(
      bindings,
      dict(prim=lax.exp_p, args=(jr.Literal(1.),), params=jr.Params()))
def test_can_match_call_primitive_parts(self):
  """Every field of a `jr.CallPrimitive` can be bound by a pattern variable."""
  pattern = jr.CallPrimitive(
      matcher.Var('prim'), matcher.Var('args'), matcher.Var('expression'),
      matcher.Var('params'), matcher.Var('names'))
  # Float literals on both sides, so the expected `expression` binding is
  # exactly the input instead of relying on `1 == 1.` comparing equal.
  expr = jr.CallPrimitive(jax.core.call_p, (),
                          (jr.Literal(0.), jr.Literal(1.)), jr.Params(), [])
  self.assertDictEqual(
      matcher.match(pattern, expr),
      dict(
          prim=jax.core.call_p,
          args=(),
          expression=(jr.Literal(0.), jr.Literal(1.)),
          params=jr.Params(),
          names=[]))
def test_star_match_binds_name_to_environment(self):
  """A starred Var binds once; every repetition must equal that binding."""
  star_x = (matcher.Star(matcher.Var('x')),)
  # Zero repetitions bind nothing; one or more equal values bind `x`.
  cases = (
      ((), {}),
      ((1,), dict(x=1)),
      ((1, 1), dict(x=1)),
  )
  for target, expected in cases:
    self.assertDictEqual(matcher.match(star_x, target), expected)
  # Differing values cannot share the single binding of `x`.
  with self.assertRaises(matcher.MatchError):
    matcher.match(star_x, (1, 2))
def test_can_rewrite_jax_functions_with_jit(self):
  """exp -> log rewriting also applies to a `jax.jit`-wrapped exp call."""
  exp_to_log = rules.term_rewriter(
      rules.make_rule(Exp(matcher.Var('x')), Log))

  def jitted_exp_plus_one(x):
    return jax.jit(jnp.exp)(x) + 1.

  rewritten = jr.rewrite(jitted_exp_plus_one, exp_to_log)
  # After rewriting: log(1.) + 1. == 1.
  self.assertEqual(rewritten(1.), 1.)
def test_var_correctly_applies_restrictions_when_matching(self):
  """A Var with restrictions matches only values that satisfy all of them."""
  is_positive = lambda a: a > 0
  is_even = lambda a: a % 2 == 0
  x = matcher.Var('x', restrictions=[is_positive, is_even])
  # 2. is positive and even; expect the exact float value that was matched.
  self.assertDictEqual(matcher.match(x, 2.), dict(x=2.))
  # -2 is even but not positive.
  with self.assertRaises(matcher.MatchError):
    matcher.match(x, -2)
  # 3 is positive but not even.
  with self.assertRaises(matcher.MatchError):
    matcher.match(x, 3)
def test_can_rewrite_jax_functions_with_constants(self):
  """Rewriting works when the function also builds a constant array."""
  exp_to_log = rules.term_rewriter(
      rules.make_rule(Exp(matcher.Var('x')), Log))

  def exp_plus_ones(x):
    return jnp.exp(x) + jnp.ones(x.shape)

  rewritten = jr.rewrite(exp_plus_ones, exp_to_log)
  inputs = jnp.ones(5)
  # Elementwise log(1) + 1 == 1.
  self.assertTrue((rewritten(inputs) == jnp.ones(5)).all())
def test_choice_will_backtrack_and_try_other_options(self):
  """Choice retries its alternatives when a later element fails to match."""
  x = matcher.Var('x')
  y = matcher.Var('y')
  # Binding x=3 through the Choice conflicts with the second element (2);
  # the only consistent assignment binds y=3 through the Choice instead.
  pattern = (matcher.Choice(x, y), x, y)
  self.assertDictEqual(matcher.match(pattern, (3, 2, 3)), dict(x=2, y=3))
def test_rule_should_pass_bindings_into_rewrite(self):
  """The value bound to `a` is passed to the rewrite function as `a`."""
  increment = rules.make_rule(matcher.Var('a'), lambda a: a + 1)
  result = increment(1.)
  self.assertEqual(result, 2.)
def test_var_pattern_matches_any_expression(self):
  """An unrestricted Var binds whatever value it is matched against."""
  for value in (1., 'hello'):
    self.assertDictEqual(matcher.match(matcher.Var('x'), value), {'x': value})
def test_can_match_part_op_and_index(self):
  """A `jr.Part` pattern binds both its operand tuple and its index."""
  part_pattern = jr.Part((matcher.Segment('args'),), matcher.Var('i'))
  part_expr = jr.Part((jr.Literal(0.), jr.Literal(1.)), 1)
  bindings = matcher.match(part_pattern, part_expr)
  self.assertDictEqual(
      bindings, dict(args=(jr.Literal(0.), jr.Literal(1.)), i=1))
def test_can_match_jax_var_name_shape_and_dtype(self):
  """A JaxVar's name, shape and dtype can each be bound separately."""
  var_pattern = jr.JaxVar(
      matcher.Var('name'), matcher.Var('shape'), matcher.Var('dtype'))
  var_expr = jr.JaxVar('a', (1, 2, 3), jnp.int32)
  bindings = matcher.match(var_pattern, var_expr)
  self.assertDictEqual(
      bindings, dict(name='a', shape=(1, 2, 3), dtype=jnp.int32))
def test_can_match_literal_value(self):
  """A Var nested inside `jr.Literal` binds the literal's value."""
  bindings = matcher.match(jr.Literal(matcher.Var('x')), jr.Literal(0.))
  self.assertDictEqual(bindings, dict(x=0.))
def test_can_match_value_inside_params(self):
  """A Var placed inside `jr.Params` binds the matching params entry."""
  # `matcher.Dot` and the unnamed Segment match the primitive and its (empty)
  # operands without contributing bindings; only `foo` ends up bound.
  params_pattern = jr.Primitive(
      matcher.Dot, (matcher.Segment(name=None),),
      jr.Params({'foo': matcher.Var('foo')}))
  iota_expr = jr.Primitive(lax.iota_p, (), jr.Params(foo='bar'))
  self.assertDictEqual(
      matcher.match(params_pattern, iota_expr), dict(foo='bar'))
def test_can_match_input_to_primitive(self):
  """match_all on Exp(Exp(0.)) yields exactly one binding: the outer operand."""
  all_bindings = list(
      matcher.match_all(Exp(matcher.Var('x')), Exp(Exp(jr.Literal(0.)))))
  self.assertListEqual(all_bindings, [dict(x=Exp(jr.Literal(0.)))])
def test_star_with_name_binds_result(self):
  """A named Star binds the whole run of matched elements to its name."""
  literal_star = (matcher.Star(1, name='x'),)
  self.assertDictEqual(matcher.match(literal_star, (1, 1)), dict(x=(1, 1)))
  # Variables inside the Star are bound alongside the Star's own name.
  var_star = (matcher.Star(matcher.Var('y'), name='x'),)
  self.assertDictEqual(matcher.match(var_star, (1, 1)), dict(y=1, x=(1, 1)))
def test_can_match_var_with_length_one_string(self):
  """A Var binds a one-character string as a plain value."""
  bindings = matcher.match(matcher.Var('x'), 'a')
  self.assertDictEqual(bindings, {'x': 'a'})
def test_can_rewrite_simple_jax_expressions(self):
  """A rule Exp(x) -> Log(x) rewrites a top-level Exp expression."""
  exp_to_log = rules.make_rule(Exp(matcher.Var('x')), Log)
  self.assertEqual(exp_to_log(Exp(1.)), Log(1.))
def test_var_pattern_matches_when_bound_value_matches(self):
  """Reusing one Var succeeds when every occurrence sees the same value."""
  shared = matcher.Var('x')
  bindings = matcher.match((shared, shared), (1, 1))
  self.assertDictEqual(bindings, dict(x=1))
def test_can_rewrite_nested_jax_expressions(self):
  """`term_rewriter` applies the rule to nested subterms as well."""
  rewrite_all = rules.term_rewriter(
      rules.make_rule(Exp(matcher.Var('x')), Log))
  self.assertEqual(rewrite_all(Exp(Exp(2.))), Log(Log(2.)))
def test_var_pattern_errors_when_bound_value_doesnt_match(self):
  """Reusing one Var fails when its occurrences see different values."""
  shared = matcher.Var('x')
  conflicting = (1, 2)
  with self.assertRaises(matcher.MatchError):
    matcher.match((shared, shared), conflicting)
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ """Tests for tensorflow_probability.spinoffs.oryx.experimental.matching.rules.""" from absl.testing import absltest from oryx.experimental.matching import matcher from oryx.experimental.matching import rules from oryx.internal import test_util is_number = lambda x: isinstance(x, (int, float)) is_positive = lambda x: x > 0 is_tuple = lambda x: isinstance(x, tuple) Number = lambda name: matcher.Var(name, restrictions=[is_number]) Positive = lambda name: matcher.Var(name, restrictions=[is_number, is_positive]) Tuple = lambda name: matcher.Var(name, restrictions=[is_tuple]) class RulesTest(test_util.TestCase): def test_rule_should_rewrite_matching_expression(self): one_rule = rules.make_rule(2., lambda: 1.) self.assertEqual(one_rule(2.), 1.) def test_rule_doesnt_rewrite_nonmatching_expression(self): one_rule = rules.make_rule(2., lambda: 1.) self.assertEqual(one_rule(0.), 0.) def test_rule_should_pass_bindings_into_rewrite(self):