Example #1
0
    def testNestedXMapDifferentResources(self):
        """A nested xmap under a newly entered mesh must raise RuntimeError.

        The outer xmap schedules named axis 'a' onto resource 'x' (presumably
        a mesh axis provided by the test fixture -- not visible here). Inside
        the traced body, a new mesh is entered and a nested xmap is called.
        The test asserts that this raises a RuntimeError whose message matches
        "Changing the resource environment.*".
        """
        @partial(xmap,
                 in_axes=A({'a': 0}),
                 out_axes=A({'a': 0}),
                 schedule=[('a', 'x')])
        def f(x):
            # `np.object` was only a deprecated alias for the builtin `object`
            # and was removed in NumPy 1.24; touching it raises AttributeError,
            # which would mask the RuntimeError this test is checking for.
            # The mesh gets a 0-d object array and no axis names.
            with mesh(np.empty((), dtype=object), ()):

                @partial(xmap,
                         in_axes=A({'b': 0}),
                         out_axes=A({'b': 0}),
                         schedule=[('b', 'vectorize')])
                def h(x):
                    return x

                return h(x)

        xshape = (2, 5, 6)
        x = jnp.arange(np.prod(xshape)).reshape(xshape)
        with self.assertRaisesRegex(RuntimeError,
                                    "Changing the resource environment.*"):
            f(x)
Example #2
0
    def testPdotBatching(self):
        """xmap of `lax.pdot` must match a batched einsum reference.

        Operand `x` has shape (2, 3, 8) and `y` has shape (2, 8, 5). Dim 0 of
        each operand (size 2) becomes named axis 'j', and the shared size-8
        dimension (dim 2 of `x`, dim 1 of `y`) becomes named axis 'i', which
        `lax.pdot` contracts. Emitting 'j' at position 0 should reproduce
        einsum('nij,njk->nik').
        """
        def contract(lhs, rhs):
            return lax.pdot(lhs, rhs, 'i')

        rng = np.random.RandomState(0)
        x = rng.randn(2, 3, 8)
        y = rng.randn(2, 8, 5)

        # Named-axis layout for each operand and for the single output.
        x_spec = A({'i': 2, 'j': 0})
        y_spec = A({'i': 1, 'j': 0})
        result_spec = A({'j': 0})
        plan = [('j', 'vectorize'), ('i', 'r1'), ('i', 'vectorize')]

        batched = xmap(contract,
                       in_axes=[x_spec, y_spec],
                       out_axes=result_spec,
                       schedule=plan)

        actual = batched(x, y)
        self.assertAllClose(actual, jnp.einsum('nij,njk->nik', x, y))
Example #3
0
    def testXMapCollectives(self):
        """`lax.psum` over a named axis must agree with a plain array sum.

        The first operand has shape (16, 5, 2). Its dim 0 becomes named axis
        'x' and dim 1 becomes 'z'. Summing `+2` over 'x' and laying 'z' out
        at position 0 should equal `(a + 2).sum(0)`. The second operand has
        shape (6, 16), and its dim 1 becomes 'y'. Scaling by 4 and emitting
        'y' at position 0 should equal the transpose of `b * 4`.
        """
        def body(lhs, rhs):
            reduced = lax.psum(lhs + 2, 'x')
            scaled = rhs * 4
            return reduced, scaled

        plan = [
            ('x', 'r1'),
            ('x', 'r2'),
            ('y', 'r1'),
            ('z', 'r3'),
            ('x', 'vectorize'),
            ('y', 'vectorize'),
        ]
        mapped = xmap(body,
                      in_axes=[A({'x': 0, 'z': 1}), A({'y': 1})],
                      out_axes=[A({'z': 0}), A({'y': 0})],
                      schedule=plan)

        lhs_in = jnp.arange(16 * 5 * 2).reshape((16, 5, 2))
        rhs_in = jnp.arange(6 * 16).reshape((6, 16))
        summed, scaled_t = mapped(lhs_in, rhs_in)
        self.assertAllClose(summed, (lhs_in + 2).sum(0))
        self.assertAllClose(scaled_t, (rhs_in * 4).T)