Exemplo n.º 1
0
    def test_star_with_accumulate_collects_values(self):
        """Star with `accumulate=['y']` collects every `y` into a tuple.

        `x` is not accumulated, so all iterations are expected to agree on a
        single `x`; inputs with differing `x` values must raise `MatchError`.
        """
        element = (matcher.Var('x'), matcher.Var('y'))
        star = matcher.Star(element, accumulate=['y'])
        pattern = (star,)

        bindings = matcher.match(pattern, ((1, 2), (1, 3)))
        self.assertDictEqual(bindings, {'x': 1, 'y': (2, 3)})

        mismatched_x = ((1, 2), (2, 3))
        with self.assertRaises(matcher.MatchError):
            matcher.match(pattern, mismatched_x)
Exemplo n.º 2
0
 def test_can_match_primitive_inside_of_pattern(self):
     """A Primitive pattern binds the primitive, its operands and its params.

     Matching `Exp(jr.Literal(1.))` is expected to bind `prim` to
     `lax.exp_p`, the `args` segment to the single literal operand and
     `params` to an empty `jr.Params()`.
     """
     prim_var = matcher.Var('prim')
     args_pattern = (matcher.Segment('args'),)
     params_var = matcher.Var('params')
     pattern = jr.Primitive(prim_var, args_pattern, params_var)

     expr = Exp(jr.Literal(1.))
     bindings = matcher.match(pattern, expr)
     self.assertDictEqual(
         bindings,
         {
             'prim': lax.exp_p,
             'args': (jr.Literal(1.),),
             'params': jr.Params(),
         })
Exemplo n.º 3
0
 def test_can_match_call_primitive_parts(self):
     """A CallPrimitive pattern binds each of its five fields to a Var.

     The input literals use the same float values as the expected
     `expression` binding, so the equality check does not depend on how
     `jr.Literal` compares an int `1` with a float `1.`.
     """
     pattern = jr.CallPrimitive(matcher.Var('prim'), matcher.Var('args'),
                                matcher.Var('expression'),
                                matcher.Var('params'), matcher.Var('names'))
     # Previously `jr.Literal(1)` (int) while the expectation below used
     # `jr.Literal(1.)`; keep input and expectation identical.
     expr = jr.CallPrimitive(jax.core.call_p, (),
                             (jr.Literal(0.), jr.Literal(1.)), jr.Params(),
                             [])
     self.assertDictEqual(
         matcher.match(pattern, expr),
         dict(prim=jax.core.call_p,
              args=(),
              expression=(jr.Literal(0.), jr.Literal(1.)),
              params=jr.Params(),
              names=[]))
Exemplo n.º 4
0
    def test_star_match_binds_name_to_environment(self):
        """A Star over a single Var binds that Var once, shared by all items.

        An empty input matches with no bindings. An input whose items differ
        is expected to raise `matcher.MatchError`.
        """
        pattern = (matcher.Star(matcher.Var('x')),)
        expectations = [
            ((), {}),
            ((1,), dict(x=1)),
            ((1, 1), dict(x=1)),
        ]
        for expr, expected in expectations:
            self.assertDictEqual(matcher.match(pattern, expr), expected)

        with self.assertRaises(matcher.MatchError):
            matcher.match(pattern, (1, 2))
Exemplo n.º 5
0
    def test_can_rewrite_jax_functions_with_jit(self):
        """Rewriting reaches an `exp` wrapped in `jax.jit`.

        With exp replaced by log, `log(1.) + 1.` is expected to equal `1.`.
        """
        exp_to_log = rules.make_rule(Exp(matcher.Var('x')), Log)
        rule = rules.term_rewriter(exp_to_log)

        def f(x):
            return jax.jit(jnp.exp)(x) + 1.

        rewritten = jr.rewrite(f, rule)
        self.assertEqual(rewritten(1.), 1.)
Exemplo n.º 6
0
 def test_var_correctly_applies_restrictions_when_matching(self):
     """A restricted Var only matches values that pass every restriction.

     `2` is positive and even, so it matches. `-2` fails `is_positive` and
     `3` fails `is_even`, so both are expected to raise `MatchError`.
     """

     def is_positive(a):
         return a > 0

     def is_even(a):
         return a % 2 == 0

     x = matcher.Var('x', restrictions=[is_positive, is_even])
     # Match the int `2` so the bound value is exactly the expected `2`
     # (previously matched `2.` and relied on `2. == 2`).
     self.assertDictEqual(matcher.match(x, 2), dict(x=2))
     with self.assertRaises(matcher.MatchError):
         matcher.match(x, -2)
     with self.assertRaises(matcher.MatchError):
         matcher.match(x, 3)
Exemplo n.º 7
0
    def test_can_rewrite_jax_functions_with_constants(self):
        """Rewriting works in a function that also builds a constant array.

        With exp replaced by log, `log(ones) + ones` is expected to equal
        `ones` elementwise.
        """
        rule = rules.term_rewriter(
            rules.make_rule(Exp(matcher.Var('x')), Log))

        def f(x):
            return jnp.exp(x) + jnp.ones(x.shape)

        rewritten = jr.rewrite(f, rule)
        inputs = jnp.ones(5)
        self.assertTrue((rewritten(inputs) == jnp.ones(5)).all())
Exemplo n.º 8
0
 def test_choice_will_backtrack_and_try_other_options(self):
     """Choice falls back to a later option when the first one fails.

     Taking `x` for the leading 3 conflicts with the following 2, so the
     match is expected to succeed via `y` instead: x=2, y=3.
     """
     x = matcher.Var('x')
     y = matcher.Var('y')
     choice = matcher.Choice(x, y)
     result = matcher.match((choice, x, y), (3, 2, 3))
     self.assertDictEqual(result, {'x': 2, 'y': 3})
Exemplo n.º 9
0
 def test_rule_should_pass_bindings_into_rewrite(self):
     """The value bound to Var `a` is expected to reach the rewrite's `a`."""

     def increment(a):
         return a + 1

     add_one = rules.make_rule(matcher.Var('a'), increment)
     self.assertEqual(add_one(1.), 2.)
Exemplo n.º 10
0
 def test_var_pattern_matches_any_expression(self):
     """A bare Var binds both a float and a string under its name."""
     for value in (1., 'hello'):
         self.assertDictEqual(matcher.match(matcher.Var('x'), value),
                              {'x': value})
Exemplo n.º 11
0
 def test_can_match_part_op_and_index(self):
     """A Part pattern binds its operand segment and its integer index."""
     args_segment = matcher.Segment('args')
     pattern = jr.Part((args_segment,), matcher.Var('i'))
     expr = jr.Part((jr.Literal(0.), jr.Literal(1.)), 1)
     expected = {'args': (jr.Literal(0.), jr.Literal(1.)), 'i': 1}
     self.assertDictEqual(matcher.match(pattern, expr), expected)
Exemplo n.º 12
0
 def test_can_match_jax_var_name_shape_and_dtype(self):
     """A JaxVar pattern binds the variable's name, shape and dtype."""
     fields = ('name', 'shape', 'dtype')
     pattern = jr.JaxVar(*(matcher.Var(field) for field in fields))
     expr = jr.JaxVar('a', (1, 2, 3), jnp.int32)
     expected = dict(zip(fields, ('a', (1, 2, 3), jnp.int32)))
     self.assertDictEqual(matcher.match(pattern, expr), expected)
Exemplo n.º 13
0
 def test_can_match_literal_value(self):
     """A Literal pattern binds the value wrapped by the Literal."""
     bindings = matcher.match(jr.Literal(matcher.Var('x')), jr.Literal(0.))
     self.assertDictEqual(bindings, {'x': 0.})
Exemplo n.º 14
0
 def test_can_match_value_inside_params(self):
     """A Var nested in a `jr.Params` pattern binds the matching param value.

     `matcher.Dot` and an unnamed Segment are expected to match the
     primitive and its operands without adding bindings, so `foo` is the
     only entry in the result.
     """
     operands = (matcher.Segment(name=None),)
     params_pattern = jr.Params({'foo': matcher.Var('foo')})
     pattern = jr.Primitive(matcher.Dot, operands, params_pattern)
     expr = jr.Primitive(lax.iota_p, (), jr.Params(foo='bar'))
     self.assertDictEqual(matcher.match(pattern, expr), {'foo': 'bar'})
Exemplo n.º 15
0
 def test_can_match_input_to_primitive(self):
     """`match_all` of `Exp(x)` over a nested Exp yields exactly one match.

     The inner `Exp(jr.Literal(0.))` is expected to be bound to `x`.
     """
     pattern = Exp(matcher.Var('x'))
     expr = Exp(Exp(jr.Literal(0.)))
     matches = [*matcher.match_all(pattern, expr)]
     self.assertListEqual(matches, [{'x': Exp(jr.Literal(0.))}])
Exemplo n.º 16
0
 def test_star_with_name_binds_result(self):
     """A named Star binds the whole matched run, as a tuple, under its name.

     Vars inside the Star are bound as well, alongside the Star's name.
     """
     literal_star = (matcher.Star(1, name='x'),)
     self.assertDictEqual(matcher.match(literal_star, (1, 1)),
                          {'x': (1, 1)})
     var_star = (matcher.Star(matcher.Var('y'), name='x'),)
     self.assertDictEqual(matcher.match(var_star, (1, 1)),
                          {'y': 1, 'x': (1, 1)})
Exemplo n.º 17
0
 def test_can_match_var_with_length_one_string(self):
     """Matching a Var against a one-character string binds that string."""
     self.assertDictEqual(matcher.match(matcher.Var('x'), 'a'), dict(x='a'))
Exemplo n.º 18
0
    def test_can_rewrite_simple_jax_expressions(self):
        """A rule from pattern `Exp(x)` to `Log` rewrites `Exp(1.)`.

        The result is expected to equal `Log(1.)`.
        """
        rule = rules.make_rule(Exp(matcher.Var('x')), Log)
        rewritten = rule(Exp(1.))
        self.assertEqual(rewritten, Log(1.))
Exemplo n.º 19
0
 def test_var_pattern_matches_when_bound_value_matches(self):
     """Two occurrences of the same Var match when they see equal values."""
     x = matcher.Var('x')
     bindings = matcher.match((x, x), (1, 1))
     self.assertDictEqual(bindings, {'x': 1})
Exemplo n.º 20
0
    def test_can_rewrite_nested_jax_expressions(self):
        """`term_rewriter` is expected to rewrite nested matches too.

        Both Exp nodes in `Exp(Exp(2.))` should become Log nodes.
        """
        base_rule = rules.make_rule(Exp(matcher.Var('x')), Log)
        rule = rules.term_rewriter(base_rule)
        self.assertEqual(rule(Exp(Exp(2.))), Log(Log(2.)))
Exemplo n.º 21
0
 def test_var_pattern_errors_when_bound_value_doesnt_match(self):
     """Two occurrences of the same Var fail on differing values."""
     x = matcher.Var('x')
     pattern, expr = (x, x), (1, 2)
     with self.assertRaises(matcher.MatchError):
         matcher.match(pattern, expr)
Exemplo n.º 22
0
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tests for tensorflow_probability.spinoffs.oryx.experimental.matching.rules."""

from absl.testing import absltest
from oryx.experimental.matching import matcher
from oryx.experimental.matching import rules
from oryx.internal import test_util

def is_number(x):
    """Returns True if `x` is an `int` or `float` instance."""
    # NOTE(review): `bool` subclasses `int`, so True/False also pass here.
    return isinstance(x, (int, float))


def is_positive(x):
    """Returns True if `x > 0`."""
    return x > 0


def is_tuple(x):
    """Returns True if `x` is a `tuple` instance."""
    return isinstance(x, tuple)


def Number(name):
    """Returns a `matcher.Var` called `name` restricted to numbers."""
    return matcher.Var(name, restrictions=[is_number])


def Positive(name):
    """Returns a `matcher.Var` called `name` restricted to positive numbers."""
    # `is_number` runs first, so `is_positive` never sees a non-number.
    return matcher.Var(name, restrictions=[is_number, is_positive])


def Tuple(name):
    """Returns a `matcher.Var` called `name` restricted to tuples."""
    return matcher.Var(name, restrictions=[is_tuple])


class RulesTest(test_util.TestCase):
    def test_rule_should_rewrite_matching_expression(self):
        """A rule whose pattern is the literal `2.` rewrites `2.` into `1.`."""

        def rewrite_to_one():
            return 1.

        one_rule = rules.make_rule(2., rewrite_to_one)
        result = one_rule(2.)
        self.assertEqual(result, 1.)

    def test_rule_doesnt_rewrite_nonmatching_expression(self):
        """A rule is expected to return a non-matching input unchanged."""
        one_rule = rules.make_rule(2., lambda: 1.)
        unchanged = one_rule(0.)
        self.assertEqual(unchanged, 0.)

    def test_rule_should_pass_bindings_into_rewrite(self):