Exemplo n.º 1
0
    def test_primitive_should_evaluate_to_jax_values(self):
        """Evaluating a primitive expression yields the value JAX computes.

        Three cases are checked:
          * the ``Exp`` expression on a constant matches ``jnp.exp``;
          * ``lax.add_p`` applied to two Python constants;
          * ``lax.add_p`` applied to a ``JaxVar`` whose value is taken from
            the environment passed to ``jr.evaluate``.
        """
        exp_of_zero = Exp(0.)
        self.assertEqual(jr.evaluate(exp_of_zero, {}), jnp.exp(0.))

        constant_sum = jr.Primitive(lax.add_p, (1., 2.), jr.Params())
        self.assertEqual(jr.evaluate(constant_sum, {}), 3.)

        # 'a' is unbound in the expression itself; it is supplied via env.
        scalar_a = jr.JaxVar('a', (), jnp.float32)
        variable_sum = jr.Primitive(lax.add_p, (scalar_a, 2.), jr.Params())
        self.assertEqual(jr.evaluate(variable_sum, {'a': 1.}), 3.)
Exemplo n.º 2
0
    def test_primitive_should_infer_shape_dtype_correctly(self):
        """Primitive expressions expose the output ``shape`` and ``dtype``.

        The check is done without evaluating the expression, by reading the
        ``shape`` and ``dtype`` attributes of each constructed expression.
        """
        def assert_shape_and_dtype(expr, expected_shape, expected_dtype):
            # Shape is compared as a tuple; dtype by equality.
            self.assertTupleEqual(expr.shape, expected_shape)
            self.assertEqual(expr.dtype, expected_dtype)

        # A Python float constant produces a float32 scalar.
        assert_shape_and_dtype(Exp(0.), (), jnp.float32)

        # Shape and dtype of the input variable carry through Exp.
        vector_a = jr.JaxVar('a', (5, ), jnp.float32)
        assert_shape_and_dtype(Exp(vector_a), (5, ), jnp.float32)

        # Adding two integer literals yields an integer scalar.
        # NOTE(review): int32 assumes JAX's default 32-bit mode
        # (jax_enable_x64 off); under x64 this would presumably be int64.
        literal_sum = jr.Primitive(lax.add_p, (jr.Literal(1), jr.Literal(2)),
                                   jr.Params())
        assert_shape_and_dtype(literal_sum, (), jnp.int32)
Exemplo n.º 3
0
 def test_evaluate_jaxvar_should_look_up_name_in_environment(self):
     """A JaxVar evaluates to the environment value bound to its name.

     The environment also binds an unrelated name ('b'). This checks that
     the lookup uses the variable's own name.
     """
     environment = {'a': 1., 'b': 2.}
     scalar_var = jr.JaxVar('a', (), jnp.float32)
     looked_up = jr.evaluate(scalar_var, environment)
     self.assertEqual(looked_up, 1.)
Exemplo n.º 4
0
 def test_evaluate_tuple_should_recursively_evaluate_values(self):
     """Evaluating a tuple evaluates each element and returns a tuple.

     Elements may be ``jr.Literal`` nodes, ``JaxVar``s resolved from the
     environment, or plain Python values, which come back unchanged.
     """
     from_literal = jr.evaluate((jr.Literal(1.), 2.), {})
     self.assertTupleEqual(from_literal, (1., 2.))

     var_a = jr.JaxVar('a', (), jnp.float32)
     from_variable = jr.evaluate((var_a, 2.), {'a': 1.})
     self.assertTupleEqual(from_variable, (1., 2.))
Exemplo n.º 5
0
 def test_can_match_jax_var_name_shape_and_dtype(self):
     """A JaxVar pattern of matcher.Vars binds the name, shape and dtype.

     ``matcher.match`` is expected to return a dict mapping each Var's name
     to the corresponding field of the concrete JaxVar.
     """
     name_var, shape_var, dtype_var = (
         matcher.Var(field) for field in ('name', 'shape', 'dtype'))
     pattern = jr.JaxVar(name_var, shape_var, dtype_var)
     concrete = jr.JaxVar('a', (1, 2, 3), jnp.int32)
     bindings = matcher.match(pattern, concrete)
     expected = {'name': 'a', 'shape': (1, 2, 3), 'dtype': jnp.int32}
     self.assertDictEqual(bindings, expected)