def test_bottom_up_should_rewrite_from_children_to_root(self):
    """bottom_up rewrites leaves first, so the tuple sums incremented values."""
    increment = rules.make_rule(Positive('a'), lambda a: a + 1.)

    def total(t):
        return sum(t)

    collapse = rules.make_rule(Tuple('t'), total)
    rule = rules.bottom_up(rules.in_order(increment, collapse))
    # (1. + 1.) + (2. + 1.) + (3. + 1.) == 9.
    self.assertEqual(rule((1., 2., 3.)), 9.)
def test_top_down_should_rewrite_from_root_to_children(self):
    """top_down rewrites the root first, so the sum precedes the increment."""
    def total(t):
        return sum(t)

    collapse = rules.make_rule(Tuple('t'), total)
    increment = rules.make_rule(Positive('a'), lambda a: a + 1.)
    strategy = rules.in_order(collapse, increment)
    rule = rules.top_down(strategy)
    # (1. + 2. + 3.) + 1. == 7.
    self.assertEqual(rule((1., 2., 3.)), 7.)
def test_can_rewrite_jax_functions_with_constants(self):
    """Rewriting exp -> log inside a traced function keeps the constant term."""
    exp_to_log = rules.make_rule(Exp(matcher.Var('x')), Log)
    rule = rules.term_rewriter(exp_to_log)

    def exp_plus_ones(x):
        return jnp.exp(x) + jnp.ones(x.shape)

    rewritten = jr.rewrite(exp_plus_ones, rule)
    ones = jnp.ones(5)
    # log(1.) + 1. == 1. elementwise.
    self.assertTrue((rewritten(ones) == ones).all())
def test_can_rewrite_jax_functions_with_jit(self):
    """Rewriting reaches an exp call that is wrapped in jax.jit."""
    exp_to_log = rules.make_rule(Exp(matcher.Var('x')), Log)
    rule = rules.term_rewriter(exp_to_log)

    def jitted_exp_plus_one(x):
        return jax.jit(jnp.exp)(x) + 1.

    rewritten = jr.rewrite(jitted_exp_plus_one, rule)
    # log(1.) + 1. == 1.
    self.assertEqual(rewritten(1.), 1.)
def test_can_replace_einsum_operands(self):
    """A Segment inside an Einsum pattern lets a rule replace every operand."""
    x = JaxVar('x', (5,), jnp.float32)
    y = JaxVar('y', (5,), jnp.float32)
    # `z` needs its own name. If JaxVar equality takes the name into account
    # (assumed -- confirm), a `z` named 'y' could compare equal to `y`, and a
    # rewrite that produced (y, y) would then pass this test unnoticed.
    z = JaxVar('z', (5,), jnp.float32)
    op = Einsum('a,a->', (x, y))
    pattern = Einsum(Var('formula'), (matcher.Segment('args'),))

    def replace_with_z(formula, args):
        del args  # The matched operands are discarded; only the formula is kept.
        return Einsum(formula, (z, z))

    replace_rule = rules.make_rule(pattern, replace_with_z)
    replaced_op = replace_rule(op)
    self.assertEqual(replaced_op, Einsum('a,a->', (z, z)))
def test_can_replace_addn_operands(self):
    """A Segment inside an AddN pattern lets a rule replace every operand."""
    x = JaxVar('x', (5,), jnp.float32)
    y = JaxVar('y', (5,), jnp.float32)
    # `z` needs its own name. If JaxVar equality takes the name into account
    # (assumed -- confirm), a `z` named 'y' could compare equal to `y`, and a
    # rewrite that produced (y, y) would then pass this test unnoticed.
    z = JaxVar('z', (5,), jnp.float32)
    op = AddN((x, y))
    pattern = AddN((matcher.Segment('args'),))

    def replace_with_z(args):
        del args  # The matched operands are discarded; the replacement is fixed.
        return AddN((z, z))

    replace_rule = rules.make_rule(pattern, replace_with_z)
    replaced_op = replace_rule(op)
    self.assertEqual(replaced_op, AddN((z, z)))
def test_rewrite_subexpressions_should_not_rewrite_primitive_types(self):
    """A bare float has no subexpressions, so it is returned unchanged."""
    increment = rules.make_rule(Positive('a'), lambda a: a + 1.)
    rule = rules.rewrite_subexpressions(increment)
    for value in (1., -1.):
        self.assertEqual(rule(value), value)
def test_rule_should_rewrite_matching_expression(self):
    """A literal pattern rewrites an equal input into the replacement value."""
    def one():
        return 1.

    one_rule = rules.make_rule(2., one)
    self.assertEqual(one_rule(2.), 1.)
def test_rule_doesnt_rewrite_nonmatching_expression(self):
    """Input that differs from the literal pattern is returned as-is."""
    one_rule = rules.make_rule(2., lambda: 1.)
    unmatched = 0.
    self.assertEqual(one_rule(unmatched), unmatched)
def test_can_rewrite_nested_jax_expressions(self):
    """term_rewriter rewrites the inner Exp as well as the outer one."""
    exp_to_log = rules.make_rule(Exp(matcher.Var('x')), Log)
    rule = rules.term_rewriter(exp_to_log)
    nested = Exp(Exp(2.))
    rewritten = rule(nested)
    self.assertEqual(rewritten, Log(Log(2.)))
def test_can_rewrite_simple_jax_expressions(self):
    """A single make_rule turns a top-level Exp into a Log."""
    rule = rules.make_rule(Exp(matcher.Var('x')), Log)
    rewritten = rule(Exp(1.))
    self.assertEqual(rewritten, Log(1.))
def test_rule_should_pass_bindings_into_rewrite(self):
    """The value bound to Var('a') reaches the rewrite as its `a` argument."""
    def add_one_to(a):
        return a + 1

    add_one = rules.make_rule(matcher.Var('a'), add_one_to)
    self.assertEqual(add_one(1.), 2.)
def test_bottom_up_should_recursively_rewrite_elements(self):
    """bottom_up increments positive leaves at every nesting depth."""
    increment = rules.make_rule(Positive('a'), lambda a: a + 1.)
    rule = rules.bottom_up(increment)
    nested = ((-1.,), 1., (1., 1.))
    expected = ((-1.,), 2., (2., 2.))
    self.assertEqual(rule(nested), expected)
def test_rewrite_subexpressions_should_not_recursively_rewrite_elements(
        self):
    """Only direct children are rewritten; nested tuple elements are not."""
    increment = rules.make_rule(Positive('a'), lambda a: a + 1.)
    rule = rules.rewrite_subexpressions(increment)
    nested = ((-1.,), 1., (1., 1.))
    expected = ((-1.,), 2., (1., 1.))
    self.assertEqual(rule(nested), expected)
def test_term_rewriter_should_recursively_rewrite_until_convergence(self):
    """term_rewriter keeps decrementing positive leaves until none match."""
    decrement = rules.make_rule(Positive('a'), lambda a: a - 1.)
    rule = rules.term_rewriter(decrement)
    cases = [
        (1., 0.),
        ((1., 2., 3.), (0., 0., 0.)),
        (((1., 2.), 3.), ((0., 0.), 0.)),
    ]
    for expression, expected in cases:
        self.assertEqual(rule(expression), expected)
def test_iterated_should_apply_rule_until_expression_no_longer_matches(
        self):
    """iterated re-applies the rule to its own output until it stops matching."""
    decrement = rules.make_rule(Positive('a'), lambda a: a - 1.)
    rule = rules.iterated(decrement)
    cases = ((1., 0.), (10., 0.), (-1., -1.))
    for start, expected in cases:
        self.assertEqual(rule(start), expected)
def test_in_order_should_apply_rules_even_if_multiple_match(self):
    """in_order applies every rule in sequence: 1. + 2. + 3. == 6."""
    add_two = rules.make_rule(Positive('a'), lambda a: a + 2.)
    add_three = rules.make_rule(Positive('a'), lambda a: a + 3.)
    rule = rules.in_order(add_two, add_three)
    self.assertEqual(rule(1.), 6.)
def test_rule_list_should_not_rewrite_expression_if_no_rules_match(self):
    """When no pattern matches, rule_list returns the input unchanged."""
    add_two = rules.make_rule(Positive('a'), lambda a: a + 2.)
    add_three = rules.make_rule(Positive('a'), lambda a: a + 3.)
    rule = rules.rule_list(add_two, add_three)
    negative = -1.
    self.assertEqual(rule(negative), negative)
def test_rule_list_should_apply_first_rule_that_matches(self):
    """rule_list applies only the first rule whose pattern matches.

    -1. is not Positive, so it falls through to the Number rule (+2.). 2. is
    Positive, so the earlier +1. rule wins and the Number rule is not applied.
    """
    rule = rules.rule_list(
        rules.make_rule(Positive('a'), lambda a: a + 1.),
        rules.make_rule(Number('a'), lambda a: a + 2.))
    self.assertEqual(rule(-1.), 1.)
    # Float literal, consistent with every other expectation in this suite.
    self.assertEqual(rule(2.), 3.)
def test_rule_with_restrictions_should_not_rewrite_if_no_match(self):
    """A Number-restricted variable rewrites numbers but not an empty tuple."""
    def plus_one(a):
        return a + 1

    add_one = rules.make_rule(Number('a'), plus_one)
    self.assertEqual(add_one(1.), 2.)
    empty = ()
    self.assertEqual(add_one(empty), empty)
def register(handler: Callable[..., Any]) -> rules.Rule:
    """Build a rule that rewrites matches of `pattern` using `handler`.

    NOTE(review): `pattern` is not a parameter here; it appears to be captured
    from an enclosing scope that is not visible in this excerpt -- confirm.
    """
    new_rule = rules.make_rule(pattern, handler)
    return new_rule
def test_rewrite_subexpressions_should_rewrite_tuple_elements(self):
    """Each direct tuple element is rewritten independently."""
    increment = rules.make_rule(Positive('a'), lambda a: a + 1.)
    rule = rules.rewrite_subexpressions(increment)
    # Only 1. is positive, so it alone is incremented.
    self.assertEqual(rule((-1., 0., 1.)), (-1., 0., 2.))