Example #1
0
    def test_primitive_should_evaluate_to_jax_values(self):
        """Primitive expressions evaluate to the value JAX would compute.

        Checks three cases:
          * an `Exp` expression over a constant, compared with `jnp.exp(0.)`;
          * `lax.add_p` applied to two constants;
          * `lax.add_p` applied to a `JaxVar` (bound via the environment)
            and a constant.
        """
        # Each case is (expression, environment, expected value); they are
        # checked in the same order as before, stopping at the first failure.
        cases = (
            (Exp(0.), {}, jnp.exp(0.)),
            (jr.Primitive(lax.add_p, (1., 2.), jr.Params()), {}, 3.),
            (jr.Primitive(lax.add_p,
                          (jr.JaxVar('a', (), jnp.float32), 2.),
                          jr.Params()),
             {'a': 1.},
             3.),
        )
        for expression, environment, expected in cases:
            self.assertEqual(jr.evaluate(expression, environment), expected)
Example #2
0
 def test_part_infers_correct_shape_dtype(self):
     """`Part` reports the shape, dtype and value of one call output.

     A `CallPrimitive` whose operands are a float literal and an int literal
     is indexed with `Part`. Each part must have a scalar shape `()`, the
     dtype asserted below (float32 / int32), and evaluate to its literal.
     """
     call = jr.CallPrimitive(jax.core.call_p, (),
                             (jr.Literal(0.), jr.Literal(1)),
                             jr.Params(), [])
     # (output index, expected dtype, expected evaluated value)
     expectations = ((0, jnp.float32, 0.), (1, jnp.int32, 1))
     parts = [jr.Part(call, index) for index, _, _ in expectations]

     # Same assertion order as before: all shapes, then dtypes, then values.
     for part in parts:
         self.assertTupleEqual(part.shape, ())
     for part, (_, dtype, _) in zip(parts, expectations):
         self.assertEqual(part.dtype, dtype)
     for part, (_, _, value) in zip(parts, expectations):
         self.assertEqual(jr.evaluate(part, {}), value)
Example #3
0
 def evaluate(self, env: Dict[str, Any]) -> Any:
     """Evaluates an `Einsum` in an environment.

     The operand expressions are first evaluated against `env` via
     `jr.evaluate`, then contracted with `jnp.einsum` using `self.formula`.

     Args:
       env: mapping of variable names to values, forwarded to `jr.evaluate`.

     Returns:
       `jnp.einsum(self.formula, *evaluated_operands)`.
     """
     return jnp.einsum(self.formula, *jr.evaluate(self.operands, env))
Example #4
0
 def evaluate(self, env: Dict[str, Any]) -> Any:
     """Evaluates an `AddN` in an environment.

     Evaluates every operand against `env`, then adds the results together
     left to right with `+`.

     Args:
       env: mapping of variable names to values, forwarded to `jr.evaluate`.

     Returns:
       The sum of the evaluated operands.

     Raises:
       TypeError: from `functools.reduce` if there are no operands.
     """
     evaluated_values = jr.evaluate(self.operands, env)
     # Left fold with no initial value: the first operand seeds the sum.
     total = functools.reduce(operator.add, evaluated_values)
     return total
Example #5
0
 def test_call_primitive_should_include_call_in_trace(self):
     """Tracing the evaluation of a `CallPrimitive` records a `call_p`.

     The expression is evaluated inside `jax.make_jaxpr`, and the first
     equation of the resulting jaxpr must use `jax.core.call_p`.
     """
     call_expr = jr.CallPrimitive(jax.core.call_p, (),
                                  (Exp(jr.Literal(0.)),),
                                  jr.Params(), [])

     def evaluate_call():
         return jr.evaluate(call_expr, {})

     traced = jax.make_jaxpr(evaluate_call)()
     first_eqn = traced.jaxpr.eqns[0]
     self.assertEqual(first_eqn.primitive, jax.core.call_p)
Example #6
0
 def test_evaluate_tuple_should_recursively_evaluate_values(self):
     """`jr.evaluate` on a tuple evaluates each element in place.

     Covers a tuple mixing a `Literal` with a plain float, and a tuple mixing
     a `JaxVar` (bound in the environment) with a plain float.
     """
     literal_pair = (jr.Literal(1.), 2.)
     self.assertTupleEqual(jr.evaluate(literal_pair, {}), (1., 2.))

     var_pair = (jr.JaxVar('a', (), jnp.float32), 2.)
     bindings = {'a': 1.}
     self.assertTupleEqual(jr.evaluate(var_pair, bindings), (1., 2.))
Example #7
0
 def test_evaluate_jaxvar_should_look_up_name_in_environment(self):
     """A `JaxVar` evaluates to the environment value bound to its name.

     The environment holds an extra, unrelated binding (`'b'`) to check that
     only the variable's own name is used for the lookup.
     """
     bindings = {'a': 1., 'b': 2.}
     variable = jr.JaxVar('a', (), jnp.float32)
     self.assertEqual(jr.evaluate(variable, bindings), 1.)
Example #8
0
 def test_evaluate_literal_should_evaluate_to_value(self):
     """A `Literal` evaluates to the value it wraps, scalar or array.

     The array case uses `jnp.array_equal`, which also requires matching
     shapes. The previous `(result == jnp.ones(5)).all()` broadcasts, so a
     wrong result such as the scalar `1.` would have passed.
     """
     self.assertEqual(jr.evaluate(jr.Literal(1.), {}), 1.)
     self.assertTrue(
         jnp.array_equal(jr.evaluate(jr.Literal(jnp.ones(5)), {}),
                         jnp.ones(5)))
Example #9
0
 def test_evaluate_value_should_return_value(self):
     """Non-expression values pass through `jr.evaluate` unchanged.

     The array case uses `jnp.array_equal`, which also requires matching
     shapes. The previous `(result == jnp.ones(5)).all()` broadcasts, so a
     wrong result such as the scalar `1.` would have passed.
     """
     self.assertEqual(jr.evaluate(1., {}), 1.)
     self.assertTrue(jnp.array_equal(jr.evaluate(jnp.ones(5), {}), jnp.ones(5)))