Beispiel #1
0
 def test_can_match_primitive_inside_of_pattern(self):
     """A Primitive pattern binds the primitive, its operands and its params."""
     # The Segment variable binds the operand sequence, so `args` is a tuple
     # (here holding the single literal operand of `exp`).
     primitive_pattern = jr.Primitive(
         matcher.Var('prim'), (matcher.Segment('args'),), matcher.Var('params'))
     exp_of_one = Exp(jr.Literal(1.))
     bindings = matcher.match(primitive_pattern, exp_of_one)
     expected = {
         'prim': lax.exp_p,
         'args': (jr.Literal(1.),),
         'params': jr.Params(),
     }
     self.assertDictEqual(bindings, expected)
Beispiel #2
0
 def test_part_infers_correct_shape_dtype(self):
     """Each Part of a call expression reports the shape/dtype of its output."""
     call_expr = jr.CallPrimitive(
         jax.core.call_p, (), (jr.Literal(0.), jr.Literal(1)), jr.Params(), [])
     float_part, int_part = (jr.Part(call_expr, index) for index in (0, 1))
     # Both selected outputs are scalars.
     for part in (float_part, int_part):
         self.assertTupleEqual(part.shape, ())
     # The float literal yields float32, the int literal yields int32.
     self.assertEqual(float_part.dtype, jnp.float32)
     self.assertEqual(int_part.dtype, jnp.int32)
     # Evaluating with an empty environment returns the literal values.
     self.assertEqual(jr.evaluate(float_part, {}), 0.)
     self.assertEqual(jr.evaluate(int_part, {}), 1)
Beispiel #3
0
    def test_primitive_should_infer_shape_dtype_correctly(self):
        """Primitive expressions infer output shape and dtype from operands."""
        # A bare Python float operand gives a scalar float32 result.
        scalar_exp = Exp(0.)
        self.assertTupleEqual(scalar_exp.shape, ())
        self.assertEqual(scalar_exp.dtype, jnp.float32)

        # A variable's declared shape and dtype carry through `exp`.
        vector_exp = Exp(jr.JaxVar('a', (5, ), jnp.float32))
        self.assertTupleEqual(vector_exp.shape, (5, ))
        self.assertEqual(vector_exp.dtype, jnp.float32)

        # Adding two integer literals gives a scalar int32 result.
        int_sum = jr.Primitive(lax.add_p, (jr.Literal(1), jr.Literal(2)),
                               jr.Params())
        self.assertTupleEqual(int_sum.shape, ())
        self.assertEqual(int_sum.dtype, jnp.int32)
Beispiel #4
0
 def test_can_match_call_primitive_parts(self):
     """Each field of a CallPrimitive can be bound by a pattern variable."""
     pattern = jr.CallPrimitive(matcher.Var('prim'), matcher.Var('args'),
                                matcher.Var('expression'),
                                matcher.Var('params'), matcher.Var('names'))
     expr = jr.CallPrimitive(jax.core.call_p, (),
                             (jr.Literal(0.), jr.Literal(1)), jr.Params(),
                             [])
     self.assertDictEqual(
         matcher.match(pattern, expr),
         dict(prim=jax.core.call_p,
              args=(),
              # `expression` is bound verbatim from `expr`, so the second
              # literal is the int `1`. Expecting `jr.Literal(1.)` misstated
              # the input and presumably only passed because 1 == 1.0.
              expression=(jr.Literal(0.), jr.Literal(1)),
              params=jr.Params(),
              names=[]))
Beispiel #5
0
 def test_call_primitive_shape_and_dtype_are_multi_part(self):
     """A CallPrimitive reports shape and dtype as per-output tuples."""
     call_expr = jr.CallPrimitive(
         jax.core.call_p,
         (),
         (Exp(jr.Literal(0.)),),
         jr.Params(),
         [],
     )
     # One scalar float32 output expression gives one-element tuples.
     self.assertTupleEqual(call_expr.shape, ((),))
     self.assertEqual(call_expr.dtype, (jnp.float32,))
Beispiel #6
0
 def test_call_primitive_should_include_call_in_trace(self):
     """Tracing the evaluation of a CallPrimitive emits the call primitive."""
     call_expr = jr.CallPrimitive(jax.core.call_p, (),
                                  (Exp(jr.Literal(0.)), ), jr.Params(), [])

     def evaluate_call():
         return jr.evaluate(call_expr, {})

     closed_jaxpr = jax.make_jaxpr(evaluate_call)()
     first_eqn = closed_jaxpr.jaxpr.eqns[0]
     # The call itself is expected to be the first equation of the trace.
     self.assertEqual(first_eqn.primitive, jax.core.call_p)
Beispiel #7
0
 def test_evaluate_tuple_should_recursively_evaluate_values(self):
     """`jr.evaluate` evaluates each element of a tuple."""
     # Literals evaluate to their value; plain floats come back unchanged.
     literal_tuple = (jr.Literal(1.), 2.)
     self.assertTupleEqual(jr.evaluate(literal_tuple, {}), (1., 2.))
     # Variables are resolved from the supplied environment.
     env = {'a': 1.}
     var_tuple = (jr.JaxVar('a', (), jnp.float32), 2.)
     self.assertTupleEqual(jr.evaluate(var_tuple, env), (1., 2.))
Beispiel #8
0
 def test_evaluate_literal_should_evaluate_to_value(self):
     """Evaluating a Literal returns its wrapped value, scalar or array."""
     self.assertEqual(jr.evaluate(jr.Literal(1.), {}), 1.)
     evaluated = jr.evaluate(jr.Literal(jnp.ones(5)), {})
     # Compare element-wise, then reduce to a single truth value.
     self.assertTrue(jnp.all(evaluated == jnp.ones(5)))
Beispiel #9
0
 def test_can_match_part_op_and_index(self):
     """A Part pattern binds both the operand tuple and the selected index."""
     part_pattern = jr.Part((matcher.Segment('args'), ), matcher.Var('i'))
     part_expr = jr.Part((jr.Literal(0.), jr.Literal(1.)), 1)
     bindings = matcher.match(part_pattern, part_expr)
     expected = {'args': (jr.Literal(0.), jr.Literal(1.)), 'i': 1}
     self.assertDictEqual(bindings, expected)
Beispiel #10
0
 def test_can_match_literal_value(self):
     """A pattern variable inside a Literal binds the literal's value."""
     bindings = matcher.match(jr.Literal(matcher.Var('x')), jr.Literal(0.))
     self.assertDictEqual(bindings, {'x': 0.})
Beispiel #11
0
 def test_can_match_input_to_primitive(self):
     """`match_all` binds the operand of a primitive to a pattern variable."""
     pattern = Exp(matcher.Var('x'))
     subject = Exp(Exp(jr.Literal(0.)))
     matches = list(matcher.match_all(pattern, subject))
     # Exactly one match is expected: `x` binds the outer Exp's whole operand.
     self.assertListEqual(matches, [{'x': Exp(jr.Literal(0.))}])