Example No. 1
0
    def test_primitive_should_evaluate_to_jax_values(self):
        """Primitives evaluate to concrete values; JaxVars are read from env."""
        # A unary primitive applied to a plain Python float.
        self.assertEqual(jr.evaluate(Exp(0.), {}), jnp.exp(0.))

        # A binary primitive whose operands are both plain Python floats.
        add_constants = jr.Primitive(lax.add_p, (1., 2.), jr.Params())
        self.assertEqual(jr.evaluate(add_constants, {}), 3.)

        # A JaxVar operand is resolved by name from the environment mapping.
        var_a = jr.JaxVar('a', (), jnp.float32)
        add_with_var = jr.Primitive(lax.add_p, (var_a, 2.), jr.Params())
        self.assertEqual(jr.evaluate(add_with_var, {'a': 1.}), 3.)
Example No. 2
0
 def test_can_match_primitive_inside_of_pattern(self):
     """A Primitive pattern binds its primitive, operand segment and params."""
     prim_pattern = jr.Primitive(
         matcher.Var('prim'),
         (matcher.Segment('args'),),
         matcher.Var('params'),
     )
     target = Exp(jr.Literal(1.))
     bindings = matcher.match(prim_pattern, target)
     # Built with a fresh Literal so the comparison relies on equality,
     # not on object identity with the operand inside `target`.
     expected = {
         'prim': lax.exp_p,
         'args': (jr.Literal(1.),),
         'params': jr.Params(),
     }
     self.assertDictEqual(bindings, expected)
Example No. 3
0
    def test_primitive_should_infer_shape_dtype_correctly(self):
        """Primitive expressions expose the shape and dtype of their output."""
        int_sum = jr.Primitive(lax.add_p, (jr.Literal(1), jr.Literal(2)),
                               jr.Params())
        cases = (
            # exp of a scalar Python float.
            (Exp(0.), (), jnp.float32),
            # exp of a (5,)-shaped float32 variable keeps the variable's shape.
            (Exp(jr.JaxVar('a', (5,), jnp.float32)), (5,), jnp.float32),
            # Adding two integer literals is expected to give a scalar int32.
            (int_sum, (), jnp.int32),
        )
        for expr, want_shape, want_dtype in cases:
            self.assertTupleEqual(expr.shape, want_shape)
            self.assertEqual(expr.dtype, want_dtype)
Example No. 4
0
 def test_can_match_value_inside_params(self):
     """A Var inside a Params pattern binds that param's value.

     The primitive (matched by ``matcher.Dot``) and the operands (matched by
     an unnamed ``Segment``) are not expected to appear in the bindings.
     """
     any_operands = matcher.Segment(name=None)
     params_pattern = jr.Params({'foo': matcher.Var('foo')})
     pattern = jr.Primitive(matcher.Dot, (any_operands,), params_pattern)
     iota_expr = jr.Primitive(lax.iota_p, (), jr.Params(foo='bar'))
     self.assertDictEqual(matcher.match(pattern, iota_expr), {'foo': 'bar'})
Example No. 5
0
# limitations under the License.
# ============================================================================
"""Tests for tensorflow_probability.spinoffs.oryx.experimental.matching.jax_rewrite."""

from absl.testing import absltest

import jax
from jax import lax
import jax.numpy as jnp

from oryx.experimental.matching import jax_rewrite as jr
from oryx.experimental.matching import matcher
from oryx.experimental.matching import rules
from oryx.internal import test_util

def Exp(x):  # pylint: disable=invalid-name
  """Returns a `jr.Primitive` applying `lax.exp_p` to `x` with empty params."""
  return jr.Primitive(lax.exp_p, (x,), jr.Params())


def Log(x):  # pylint: disable=invalid-name
  """Returns a `jr.Primitive` applying `lax.log_p` to `x` with empty params."""
  return jr.Primitive(lax.log_p, (x,), jr.Params())


class JaxExpressionTest(test_util.TestCase):
    def test_evaluate_value_should_return_value(self):
        """Plain values (a Python float, a JAX array) evaluate to themselves."""
        scalar = jr.evaluate(1., {})
        self.assertEqual(scalar, 1.)
        vector = jr.evaluate(jnp.ones(5), {})
        # Elementwise comparison, reduced to a single truth value.
        self.assertTrue(jnp.all(vector == jnp.ones(5)))

    def test_evaluate_literal_should_evaluate_to_value(self):
        """A Literal evaluates to the value it wraps (scalar or array)."""
        scalar = jr.evaluate(jr.Literal(1.), {})
        self.assertEqual(scalar, 1.)
        vector = jr.evaluate(jr.Literal(jnp.ones(5)), {})
        # Elementwise comparison, reduced to a single truth value.
        self.assertTrue(jnp.all(vector == jnp.ones(5)))

    def test_evaluate_jaxvar_should_look_up_name_in_environment(self):
        self.assertEqual(