Exemplo n.º 1
0
 def test_bottom_up_should_rewrite_from_children_to_root(self):
     """bottom_up rewrites the elements before the enclosing tuple.

     The expected 9. equals (1+1) + (2+1) + (3+1): every element is
     incremented first, and only then is the rebuilt tuple summed.
     """
     increment_positive = rules.make_rule(Positive('a'), lambda a: a + 1.)
     collapse_tuple = rules.make_rule(Tuple('t'), lambda t: sum(t))  # pylint: disable=unnecessary-lambda
     children_first = rules.bottom_up(
         rules.in_order(increment_positive, collapse_tuple))
     self.assertEqual(children_first((1., 2., 3.)), 9.)
Exemplo n.º 2
0
 def test_top_down_should_rewrite_from_root_to_children(self):
     """top_down rewrites the enclosing tuple before its elements.

     The expected 7. equals (1 + 2 + 3) + 1: the tuple is summed first, and
     the resulting positive number is then incremented.
     """
     collapse_tuple = rules.make_rule(Tuple('t'), lambda t: sum(t))  # pylint: disable=unnecessary-lambda
     increment_positive = rules.make_rule(Positive('a'), lambda a: a + 1.)
     root_first = rules.top_down(
         rules.in_order(collapse_tuple, increment_positive))
     self.assertEqual(root_first((1., 2., 3.)), 7.)
Exemplo n.º 3
0
    def test_can_rewrite_jax_functions_with_constants(self):
        """A rewritten JAX function still adds its constant operand.

        exp(x) is rewritten to log(x). With x = ones(5), the result is
        log(1) + 1 == 1 elementwise.
        """
        exp_to_log = rules.term_rewriter(
            rules.make_rule(Exp(matcher.Var('x')), Log))

        def exp_plus_ones(x):
            return jnp.exp(x) + jnp.ones(x.shape)

        rewritten = jr.rewrite(exp_plus_ones, exp_to_log)
        output = rewritten(jnp.ones(5))
        self.assertTrue((output == jnp.ones(5)).all())
Exemplo n.º 4
0
    def test_can_rewrite_jax_functions_with_jit(self):
        """The exp -> log rewrite also applies when jnp.exp is wrapped in jax.jit.

        After rewriting, the function computes log(1.) + 1. == 1.
        """
        exp_to_log = rules.term_rewriter(
            rules.make_rule(Exp(matcher.Var('x')), Log))

        def jitted_exp_plus_one(x):
            return jax.jit(jnp.exp)(x) + 1.

        rewritten = jr.rewrite(jitted_exp_plus_one, exp_to_log)
        self.assertEqual(rewritten(1.), 1.)
Exemplo n.º 5
0
 def test_can_replace_einsum_operands(self):
   """A Segment pattern captures all Einsum operands so they can be replaced.

   The rule matches any Einsum, discards its captured operands and rebuilds
   it with the same formula and both operands set to ``z``.
   """
   x = JaxVar('x', (5,), jnp.float32)
   y = JaxVar('y', (5,), jnp.float32)
   # z gets its own name. It previously reused 'y' (copy-paste), which could
   # make it compare equal to y if JaxVar equality is name-based.
   z = JaxVar('z', (5,), jnp.float32)
   op = Einsum('a,a->', (x, y))
   pattern = Einsum(Var('formula'), (matcher.Segment('args'),))

   def replace_with_z(formula, args):
     del args  # The captured operands are deliberately discarded.
     return Einsum(formula, (z, z))

   replace_rule = rules.make_rule(pattern, replace_with_z)
   replaced_op = replace_rule(op)
   self.assertEqual(replaced_op, Einsum('a,a->', (z, z)))
Exemplo n.º 6
0
 def test_can_replace_addn_operands(self):
   """A Segment pattern captures all AddN operands so they can be replaced.

   The rule matches any AddN, discards its captured operands and rebuilds
   it with both operands set to ``z``.
   """
   x = JaxVar('x', (5,), jnp.float32)
   y = JaxVar('y', (5,), jnp.float32)
   # z gets its own name. It previously reused 'y' (copy-paste), which could
   # make it compare equal to y if JaxVar equality is name-based.
   z = JaxVar('z', (5,), jnp.float32)
   op = AddN((x, y))
   pattern = AddN((matcher.Segment('args'),))

   def replace_with_z(args):
     del args  # The captured operands are deliberately discarded.
     return AddN((z, z))

   replace_rule = rules.make_rule(pattern, replace_with_z)
   replaced_op = replace_rule(op)
   self.assertEqual(replaced_op, AddN((z, z)))
Exemplo n.º 7
0
 def test_rewrite_subexpressions_should_not_rewrite_primitive_types(self):
     """rewrite_subexpressions returns scalar inputs unchanged.

     This holds even for 1., which the wrapped rule itself would match.
     """
     increment_positive = rules.make_rule(Positive('a'), lambda a: a + 1.)
     rule = rules.rewrite_subexpressions(increment_positive)
     for scalar in (1., -1.):
         self.assertEqual(rule(scalar), scalar)
Exemplo n.º 8
0
 def test_rule_should_rewrite_matching_expression(self):
     """A rule whose pattern is the literal 2. rewrites 2. to 1."""
     literal_two_to_one = rules.make_rule(2., lambda: 1.)
     result = literal_two_to_one(2.)
     self.assertEqual(result, 1.)
Exemplo n.º 9
0
 def test_rule_doesnt_rewrite_nonmatching_expression(self):
     """An input that does not match the literal pattern 2. is returned as-is."""
     literal_two_to_one = rules.make_rule(2., lambda: 1.)
     untouched = 0.
     self.assertEqual(literal_two_to_one(untouched), untouched)
Exemplo n.º 10
0
    def test_can_rewrite_nested_jax_expressions(self):
        """term_rewriter turns both the inner and outer Exp into Log."""
        exp_to_log = rules.term_rewriter(
            rules.make_rule(Exp(matcher.Var('x')), Log))
        nested = Exp(Exp(2.))
        rewritten = exp_to_log(nested)
        self.assertEqual(rewritten, Log(Log(2.)))
Exemplo n.º 11
0
    def test_can_rewrite_simple_jax_expressions(self):
        """A single make_rule turns a top-level Exp(x) into Log(x)."""
        exp_to_log = rules.make_rule(Exp(matcher.Var('x')), Log)
        rewritten = exp_to_log(Exp(1.))
        self.assertEqual(rewritten, Log(1.))
Exemplo n.º 12
0
 def test_rule_should_pass_bindings_into_rewrite(self):
     """The value bound to Var('a') reaches the rewrite's ``a`` parameter."""
     binding_name = 'a'
     add_one = rules.make_rule(matcher.Var(binding_name), lambda a: a + 1)
     self.assertEqual(add_one(1.), 2.)
Exemplo n.º 13
0
 def test_bottom_up_should_recursively_rewrite_elements(self):
     """bottom_up increments positive numbers at every nesting depth."""
     increment_positive = rules.make_rule(Positive('a'), lambda a: a + 1.)
     rule = rules.bottom_up(increment_positive)
     nested = ((-1., ), 1., (1., 1.))
     expected = ((-1., ), 2., (2., 2.))
     self.assertEqual(rule(nested), expected)
Exemplo n.º 14
0
 def test_rewrite_subexpressions_should_not_recursively_rewrite_elements(
         self):
     """rewrite_subexpressions touches only the top-level elements.

     The bare 1. is incremented, but the 1.s inside the inner tuple are not.
     """
     increment_positive = rules.make_rule(Positive('a'), lambda a: a + 1.)
     rule = rules.rewrite_subexpressions(increment_positive)
     nested = ((-1., ), 1., (1., 1.))
     expected = ((-1., ), 2., (1., 1.))
     self.assertEqual(rule(nested), expected)
Exemplo n.º 15
0
 def test_term_rewriter_should_recursively_rewrite_until_convergence(self):
     """term_rewriter keeps decrementing positives, at any depth, down to 0."""
     decrement_positive = rules.make_rule(Positive('a'), lambda a: a - 1.)
     rule = rules.term_rewriter(decrement_positive)
     cases = (
         (1., 0.),
         ((1., 2., 3.), (0., 0., 0.)),
         (((1., 2.), 3.), ((0., 0.), 0.)),
     )
     for expression, expected in cases:
         self.assertEqual(rule(expression), expected)
Exemplo n.º 16
0
 def test_iterated_should_apply_rule_until_expression_no_longer_matches(
         self):
     """iterated reapplies the decrement while the value is still positive.

     A non-positive input (-1.) never matches and is returned unchanged.
     """
     decrement_positive = rules.make_rule(Positive('a'), lambda a: a - 1.)
     rule = rules.iterated(decrement_positive)
     for start, expected in ((1., 0.), (10., 0.), (-1., -1.)):
         self.assertEqual(rule(start), expected)
Exemplo n.º 17
0
 def test_in_order_should_apply_rules_even_if_multiple_match(self):
     """in_order runs every rule in sequence: 1. -> 3. (+2) -> 6. (+3)."""
     add_two = rules.make_rule(Positive('a'), lambda a: a + 2.)
     add_three = rules.make_rule(Positive('a'), lambda a: a + 3.)
     rule = rules.in_order(add_two, add_three)
     self.assertEqual(rule(1.), 6.)
Exemplo n.º 18
0
 def test_rule_list_should_not_rewrite_expression_if_no_rules_match(self):
     """rule_list returns -1. unchanged because neither Positive rule matches."""
     add_two = rules.make_rule(Positive('a'), lambda a: a + 2.)
     add_three = rules.make_rule(Positive('a'), lambda a: a + 3.)
     rule = rules.rule_list(add_two, add_three)
     self.assertEqual(rule(-1.), -1.)
Exemplo n.º 19
0
 def test_rule_list_should_apply_first_rule_that_matches(self):
     """rule_list applies only the first rule whose pattern matches.

     -1. is not Positive, so it falls through to the Number rule
     (-1. + 2. == 1.). 2. is Positive, so only the first rule fires
     (2. + 1. == 3.). If both fired, the result would be 5.
     """
     rule = rules.rule_list(
         rules.make_rule(Positive('a'), lambda a: a + 1.),
         rules.make_rule(Number('a'), lambda a: a + 2.))
     self.assertEqual(rule(-1.), 1.)
     # Float literal, consistent with the float-valued rewrite and every
     # other expectation in this suite (was the int 3).
     self.assertEqual(rule(2.), 3.)
Exemplo n.º 20
0
 def test_rule_with_restrictions_should_not_rewrite_if_no_match(self):
     """A Number-restricted pattern rewrites 1. but leaves an empty tuple as-is."""
     add_one = rules.make_rule(Number('a'), lambda a: a + 1)
     self.assertEqual(add_one(1.), 2.)
     empty = ()
     self.assertEqual(add_one(empty), empty)
Exemplo n.º 21
0
 def register(handler: Callable[..., Any]) -> rules.Rule:
     """Build a rule that rewrites matches of ``pattern`` using ``handler``.

     ``pattern`` is a free variable taken from the enclosing scope, which is
     not visible here.
     """
     rule = rules.make_rule(pattern, handler)
     return rule
Exemplo n.º 22
0
 def test_rewrite_subexpressions_should_rewrite_tuple_elements(self):
     """rewrite_subexpressions increments only the positive tuple element (1.)."""
     increment_positive = rules.make_rule(Positive('a'), lambda a: a + 1.)
     rule = rules.rewrite_subexpressions(increment_positive)
     self.assertEqual(rule((-1., 0., 1.)), (-1., 0., 2.))