예제 #1
0
파일: gmap_test.py 프로젝트: Henrys-Lab/jax
    def testXMapMeshCollectives(self):
        # Checks a psum over the named axis 'a' when xmap runs inside a 2x2
        # device mesh whose axes are named 'x' and 'y'.
        available = list(jax.local_devices())
        if len(available) < 4:
            raise SkipTest("Test requires at least 4 local devices")

        def double_then_psum(a, b):
            reduced = lax.psum(a * 2, 'a')
            return reduced, b * 4

        # The first four local devices, arranged as a 2x2 grid.
        device_grid = np.array(available[:4]).reshape((2, 2))
        with mesh(device_grid, ('x', 'y')):
            # First operand: 'a' is dimension 0 and 'b' is dimension 1.
            # Second operand: 'c' is dimension 0.
            input_axes = [A({'a': 0, 'b': 1}), A({'c': 0})]
            # 'a' is reduced away by the psum, so only 'b' remains on the
            # first result.
            output_axes = [A({'b': 0}), A({'c': 0})]
            axis_schedule = [
                ('a', 'x'),
                ('b', 'y'),
                ('c', 'x'),
                ('a', 'vectorize'),
                ('b', 'vectorize'),
            ]
            mapped = xmap(double_then_psum,
                          in_axes=input_axes,
                          out_axes=output_axes,
                          schedule=axis_schedule)

            a_shape = (16, 8, 5)
            b_shape = (2, 7)
            a = jnp.arange(np.prod(a_shape)).reshape(a_shape)
            b = jnp.arange(np.prod(b_shape)).reshape(b_shape)

            reduced_out, scaled_out = mapped(a, b)
            # Summing over 'a' corresponds to summing dimension 0 of `a`.
            self.assertAllClose(reduced_out, (a * 2).sum(0))
            self.assertAllClose(scaled_out, b * 4)
예제 #2
0
파일: gmap_test.py 프로젝트: Henrys-Lab/jax
    def testXMap(self):
        # Elementwise function under xmap; the only visible effect should be
        # the dimension reordering requested through out_axes.
        def shift_and_scale(a, b):
            return a + 2, b * 4

        # First operand: 'x' is dimension 0 and 'z' is dimension 1.
        # Second operand: 'y' is dimension 1.
        input_axes = [A({'x': 0, 'z': 1}), A({'y': 1})]
        # First result swaps the positions of 'x' and 'z'; the second result
        # moves 'y' to the front.
        output_axes = [A({'x': 1, 'z': 0}), A({'y': 0})]
        axis_schedule = [
            ('x', 'r1'),
            ('x', 'r2'),
            ('y', 'r1'),
            ('z', 'r3'),
            ('x', 'vectorize'),
            ('y', 'vectorize'),
        ]
        mapped = xmap(shift_and_scale,
                      in_axes=input_axes,
                      out_axes=output_axes,
                      schedule=axis_schedule)

        a = jnp.arange(16 * 5 * 2).reshape((16, 5, 2))
        b = jnp.arange(6 * 16).reshape((6, 16))
        shifted, scaled = mapped(a, b)

        # Swapping 'x' and 'z' is a swap of the first two dimensions.
        self.assertAllClose(shifted, (a + 2).transpose((1, 0, 2)))
        self.assertAllClose(scaled, (b * 4).T)
예제 #3
0
 def testXMapCompilationCache(self):
   # The wrapped Python function asserts on a flag that is read at call time.
   # The flag is cleared before the second call, so that call only passes if
   # the Python body is not executed again.
   def doubled(x):
     assert python_may_run
     return x * 2

   mapped = xmap(doubled,
                 in_axes=A({'a': 0}),
                 out_axes=A({'a': 0}),
                 schedule=[('a', 'x'), ('a', 'vectorize')])
   data = np.arange(8).reshape((2, 2, 2))

   # First call: the Python body is allowed to run.
   python_may_run = True
   mapped(data)

   # Second call with the same argument: the body must not run again.
   python_may_run = False
   mapped(data)
예제 #4
0
  def testPdotBasic(self):
    # lax.pdot over the named axis 'i' should agree with a plain matrix
    # product when 'i' is the shared (inner) dimension of both operands.
    def contract_over_i(x, y):
      return lax.pdot(x, y, 'i')

    # 'i' is dimension 1 of the first operand and dimension 0 of the second;
    # the result carries no named axes.
    mapped = xmap(
        contract_over_i,
        in_axes=[A({'i': 1}), A({'i': 0})],
        out_axes=A(),
        schedule=[('i', 'r1'), ('i', 'vectorize')],
    )

    # Fixed seed keeps the inputs reproducible.
    rng = np.random.RandomState(0)
    lhs = rng.randn(3, 8)
    rhs = rng.randn(8, 5)

    product = mapped(lhs, rhs)
    self.assertAllClose(product, jnp.dot(lhs, rhs))
예제 #5
0
  def testPdotBatching(self):
    # lax.pdot over 'i' with an extra named axis 'j' mapped over dimension 0
    # of both operands; compared against a batched einsum.
    def contract_over_i(x, y):
      return lax.pdot(x, y, 'i')

    # Fixed seed keeps the inputs reproducible.
    rng = np.random.RandomState(0)
    lhs = rng.randn(2, 3, 8)
    rhs = rng.randn(2, 8, 5)

    # 'j' is the leading dimension of both operands and of the result;
    # 'i' is the last dimension of lhs and the middle dimension of rhs.
    lhs_axes = A({'i': 2, 'j': 0})
    rhs_axes = A({'i': 1, 'j': 0})
    mapped = xmap(
        contract_over_i,
        in_axes=[lhs_axes, rhs_axes],
        out_axes=A({'j': 0}),
        schedule=[('j', 'vectorize'), ('i', 'r1'), ('i', 'vectorize')],
    )

    product = mapped(lhs, rhs)
    self.assertAllClose(product, jnp.einsum('nij,njk->nik', lhs, rhs))
예제 #6
0
파일: gmap_test.py 프로젝트: xeransis/jax
    def testXMapCollectives(self):
        # Checks a psum over the named axis 'x' under xmap, with the
        # resources r1, r2 and r3 supplied by fake_resources.
        def shift_psum_and_scale(a, b):
            reduced = lax.psum(a + 2, 'x')
            return reduced, b * 4

        with fake_resources(r1=4, r2=2, r3=5):
            # First operand: 'x' is dimension 0 and 'z' is dimension 1.
            # Second operand: 'y' is dimension 1.
            input_axes = [A({'x': 0, 'z': 1}), A({'y': 1})]
            # 'x' is reduced away by the psum, so only 'z' remains on the
            # first result; 'y' moves to the front of the second.
            output_axes = [A({'z': 0}), A({'y': 0})]
            axis_schedule = [
                ('x', 'r1'),
                ('x', 'r2'),
                ('y', 'r1'),
                ('z', 'r3'),
                ('x', 'vectorize'),
                ('y', 'vectorize'),
            ]
            mapped = xmap(shift_psum_and_scale,
                          in_axes=input_axes,
                          out_axes=output_axes,
                          schedule=axis_schedule)

            a = jnp.arange(16 * 5 * 2).reshape((16, 5, 2))
            b = jnp.arange(6 * 16).reshape((6, 16))
            reduced_out, scaled_out = mapped(a, b)

            # Summing over 'x' corresponds to summing dimension 0 of `a`.
            self.assertAllClose(reduced_out, (a + 2).sum(0))
            self.assertAllClose(scaled_out, (b * 4).T)