Esempio n. 1
0
    def test_staging_primitive_applications(self):
        """Stage mul -> sin -> reduce_sum over two dynamically-sized vectors.

        The traced function takes two float32 vectors whose length is bound to
        the (dropped) int32 input at position 0. The resulting jaxpr should have
        an implicit axis-size variable, three equations and one scalar output.
        """
        int32 = jnp.dtype('int32')
        float32 = jnp.dtype('float32')
        size_aval = core.DShapedArray((), int32, weak_type=False)
        lhs_aval = core.DShapedArray((pe.DBIdx(0), ), float32,
                                     weak_type=False)
        rhs_aval = core.DShapedArray((pe.DBIdx(0), ), float32,
                                     weak_type=False)

        @lu.wrap_init
        def f(x, y):
            product = lax.mul(x, y)
            reduced = lax_internal._reduce_sum(lax.sin(product), [0])
            return (reduced, )

        jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
            f, [size_aval, lhs_aval, rhs_aval],
            keep_inputs=[False, True, True])

        # Expected inputs: one axis-size variable plus the two vectors.
        self.assertLen(jaxpr.invars, 1 + 2)
        # Expected one equation per primitive applied in `f`.
        self.assertLen(jaxpr.eqns, 3)
        first_eqn = jaxpr.eqns[0]
        self.assertLen(first_eqn.outvars, 1)
        # The first equation's output should share the vectors' dynamic shape.
        self.assertEqual(first_eqn.outvars[0].aval.shape,
                         jaxpr.invars[1].aval.shape)

        # Reducing over axis 0 should leave a single scalar result.
        self.assertLen(jaxpr.outvars, 1)
        self.assertEqual(jaxpr.outvars[0].aval.shape, ())
Esempio n. 2
0
    def test_staging_nested_including_shape_arg(self):
        """Stage a jitted call that also receives a vector length explicitly.

        The inner function `g` takes `x.shape[0]` as an extra leading argument
        in addition to four dynamically-sized vectors. The staged call's inner
        jaxpr therefore has: one implicit axis-size variable, the explicit
        int32 shape argument, and four vectors sized by the axis-size variable.
        """
        n = core.DShapedArray((), jnp.dtype('int32'), weak_type=False)
        a = core.DShapedArray((pe.DBIdx(0), ),
                              jnp.dtype('float32'),
                              weak_type=False)
        b = core.DShapedArray((pe.DBIdx(0), ),
                              jnp.dtype('float32'),
                              weak_type=False)

        @lu.wrap_init
        def f(x, y):
            @jax.jit
            def g(_, x, y, z, w):
                return (x, w)

            # The vector length is passed explicitly as g's first argument.
            return g(x.shape[0], x, y, x, y)

        jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
            f, [n, a, b], keep_inputs=[False, True, True])

        # Expected staged program:
        # { lambda ; a:i32[] b:f32[a] c:f32[a]. let
        #     d:f32[a] e:f32[a] = xla_call[
        #       call_jaxpr={ lambda ; f:i32[] g:i32[] h:f32[f] i:f32[f] j:f32[f] k:f32[f]. let
        #         in (h, k) }
        #       name=g
        #     ] a a b c b c
        #   in (d, e) }

        # Outer program: one axis size var, two other inputs.
        self.assertLen(jaxpr.invars, 1 + 2)

        self.assertLen(jaxpr.eqns, 1)
        eqn = jaxpr.eqns[0]
        self.assertIsInstance(eqn.primitive, core.CallPrimitive)
        inner_jaxpr = eqn.params['call_jaxpr']
        self.assertIsInstance(inner_jaxpr, core.Jaxpr)

        # Inner program: one axis size var (`f` above), the explicit shape
        # argument (`g`), then the four vector arguments (`h`..`k`).
        self.assertLen(inner_jaxpr.invars, 1 + 1 + 4)
        axis_size_var = inner_jaxpr.invars[0]
        # The explicit shape argument is a scalar, not a sized vector.
        self.assertEqual(inner_jaxpr.invars[1].aval.shape, ())
        for array_var in inner_jaxpr.invars[2:]:
            self.assertEqual((axis_size_var, ), array_var.aval.shape)
Esempio n. 3
0
    def test_staging_nested(self):
        """Stage a jitted inner call over dynamically-sized vectors.

        Both the outer jaxpr and the inner call's jaxpr should carry a leading
        axis-size variable that determines the shape of every vector input
        (and, for the outer program, every output).
        """
        float32 = jnp.dtype('float32')
        size_aval = core.DShapedArray((), jnp.dtype('int32'), weak_type=False)
        x_aval = core.DShapedArray((pe.DBIdx(0), ), float32, weak_type=False)
        y_aval = core.DShapedArray((pe.DBIdx(0), ), float32, weak_type=False)

        @lu.wrap_init
        def f(x, y):
            @jax.jit
            def g(x, y, z, w):
                return (x, w)

            return g(x, y, x, y)

        jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
            f, [size_aval, x_aval, y_aval], keep_inputs=[False, True, True])

        # Outer program: one axis-size var followed by the two vectors.
        self.assertLen(jaxpr.invars, 1 + 2)
        outer_size = jaxpr.invars[0]
        for vector_var in jaxpr.invars[1:3]:
            self.assertEqual((outer_size, ), vector_var.aval.shape)

        self.assertLen(jaxpr.outvars, 2)
        for result_var in jaxpr.outvars[:2]:
            self.assertEqual((outer_size, ), result_var.aval.shape)

        # The whole body should be a single call equation wrapping g.
        self.assertLen(jaxpr.eqns, 1)
        (call_eqn, ) = jaxpr.eqns
        self.assertIsInstance(call_eqn.primitive, core.CallPrimitive)
        inner = call_eqn.params['call_jaxpr']
        self.assertIsInstance(inner, core.Jaxpr)

        # Inner program: one axis-size var followed by g's four vectors.
        self.assertLen(inner.invars, 1 + 4)
        inner_size, *inner_vectors = inner.invars
        for vector_var in inner_vectors:
            self.assertEqual((inner_size, ), vector_var.aval.shape)
Esempio n. 4
0
    def test_staging_basic(self):
        """Stage an identity function over two dynamically-sized vectors.

        The axis-size input is given as a plain ShapedArray scalar. The staged
        jaxpr should expose an axis-size variable as its first input, and both
        vector inputs and both outputs should be shaped by that variable.
        """
        size_aval = core.ShapedArray((), jnp.dtype('int32'), weak_type=False)
        vector_avals = [
            core.DShapedArray((pe.DBIdx(0), ),
                              jnp.dtype('float32'),
                              weak_type=False)
            for _ in range(2)
        ]

        @lu.wrap_init
        def f(x, y):
            return x, y

        jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
            f, [size_aval, *vector_avals], keep_inputs=[False, True, True])

        self.assertLen(jaxpr.invars, 3)
        expected_shape = (jaxpr.invars[0], )
        for input_var in jaxpr.invars[1:]:
            self.assertEqual(expected_shape, input_var.aval.shape)

        self.assertLen(jaxpr.outvars, 2)
        for output_var in jaxpr.outvars:
            self.assertEqual(expected_shape, output_var.aval.shape)