Exemplo n.º 1
0
 def testVMap(self):
   """vmap over a pjit-ed function keeps the pjit chunking on the inner dim."""
   vec = jnp.arange(4)
   mat = jnp.arange(5 * 4).reshape((5, 4))

   def add_and_echo(a, b):
     return a + b, a

   sharded_fn = pjit(add_and_echo, in_axis_resources=P('x'),
                     out_axis_resources=P('x'))
   # `vec` is shared (unbatched) across every row of `mat`; the echoed first
   # output is unbatched too, so out_axes gives it no batch dimension.
   total, echoed = jax.vmap(sharded_fn, in_axes=(None, 0),
                            out_axes=(0, None))(vec, mat)
   self.assertAllClose(total, vec + mat)
   self.assertAllClose(echoed, vec)
   # The new batch dimension is unsharded; the original dimension stays
   # chunked as Chunked([2]).
   self.assertEqual(total.sharding_spec.sharding,
                    (pxla.NoSharding(), pxla.Chunked([2])))
   self.assertEqual(echoed.sharding_spec.sharding, (pxla.Chunked([2]),))
Exemplo n.º 2
0
    def testNestedMesh(self, mesh, axis_resources):
        """Nested xmaps on two resource assignments, with a psum over both names."""
        outer_resources = dict([axis_resources[0]])
        inner_resources = dict([axis_resources[1]])

        @partial(xmap, in_axes={1: 'a'}, out_axes=({0: 'a'}, {}),
                 axis_resources=outer_resources)
        def outer(x):
            doubled = x * 2

            @partial(xmap, in_axes={0: 'b'}, out_axes=({1: 'b'}, {}),
                     axis_resources=inner_resources)
            def inner(v):
                return jnp.sin(v), lax.psum(v, ('a', 'b'))

            return inner(doubled)

        xshape = (4, 2, 5)
        x = jnp.arange(np.prod(xshape)).reshape(xshape)
        result = outer(x)
        expected_sines = jnp.sin(x * 2).transpose((1, 2, 0))
        expected_total = (x * 2).sum((0, 1))
        self.assertAllClose(result, (expected_sines, expected_total))

        spec = result[0].sharding_spec
        self.assertEqual(
            spec.sharding,
            (pxla.Chunked(2), pxla.NoSharding(), pxla.NoSharding()))
        # Every mesh axis past the first two is expected to be replicated.
        replicated_tail = (pxla.Replicated(2),) * (len(mesh) - 2)
        self.assertEqual(spec.mesh_mapping,
                         (pxla.Replicated(2), pxla.ShardedAxis(0)) +
                         replicated_tail)
Exemplo n.º 3
0
    def testNestedXMapMesh(self):
        """Nested xmaps scheduling 'a' on mesh axis 'y' and 'b' on mesh axis 'x'."""
        @partial(xmap, in_axes=A({'a': 1}), out_axes=A({'a': 0}),
                 schedule=[('a', 'y')])
        def outer(x):
            doubled = x * 2

            @partial(xmap, in_axes=A({'b': 0}), out_axes=A({'b': 1}),
                     schedule=[('b', 'x')])
            def inner(v):
                return jnp.sin(v)

            return inner(doubled)

        shape = (2, 3, 5)
        x = jnp.arange(np.prod(shape)).reshape(shape)
        out = outer(x)
        self.assertAllClose(out, jnp.sin(x * 2).transpose((1, 2, 0)))

        # Verify the computation was really distributed across a 2D mesh.
        spec = out.sharding_spec
        self.assertEqual(spec.sharding, (pxla.Chunked(3), None, None))
        self.assertEqual(spec.mesh_mapping,
                         (pxla.Replicated(2), pxla.ShardedAxis(0)))
Exemplo n.º 4
0
    def testNestedMesh(self):
        """Nested xmaps placing 'a' on mesh axis 'y' and 'b' on mesh axis 'x'."""
        @partial(xmap, in_axes={1: 'a'}, out_axes={0: 'a'},
                 axis_resources={'a': 'y'})
        def outer(x):
            doubled = x * 2

            @partial(xmap, in_axes={0: 'b'}, out_axes={1: 'b'},
                     axis_resources={'b': 'x'})
            def inner(v):
                return jnp.sin(v)

            return inner(doubled)

        shape = (2, 3, 5)
        x = jnp.arange(np.prod(shape)).reshape(shape)
        out = outer(x)
        self.assertAllClose(out, jnp.sin(x * 2).transpose((1, 2, 0)))

        # Verify the computation was really distributed across a 2D mesh.
        spec = out.sharding_spec
        self.assertEqual(spec.sharding, (pxla.Chunked(3), None, None))
        self.assertEqual(spec.mesh_mapping,
                         (pxla.Replicated(2), pxla.ShardedAxis(0)))
Exemplo n.º 5
0
 def testOneLogicalTwoMeshAxesSharding(self):
   """Map one logical axis onto two mesh axes, in both mesh-axis orders."""
   def quadruple(v):
     return v * 4

   def sharded_over(*mesh_axes):
     # Logical axis 'a' (input dim 0, output dim 1) spans every axis given.
     return xmap(quadruple, in_axes=['a', ...], out_axes={1: 'a'},
                 axis_resources={'a': mesh_axes})

   fxy = sharded_over('x', 'y')
   fyx = sharded_over('y', 'x')

   vshape = (4, 5)
   v = jnp.arange(np.prod(vshape)).reshape(vshape)
   # Both orders chunk output dim 1 as Chunked((2, 2)); only the mesh_mapping
   # differs, swapping when the mesh-axis order is reversed.
   chunking = (pxla.NoSharding(), pxla.Chunked((2, 2)))
   self.assertEqual(
       fxy(v).sharding_spec,
       pxla.ShardingSpec(chunking,
                         (pxla.ShardedAxis(0), pxla.ShardedAxis(1))))
   self.assertEqual(
       fyx(v).sharding_spec,
       pxla.ShardingSpec(chunking,
                         (pxla.ShardedAxis(1), pxla.ShardedAxis(0))))
Exemplo n.º 6
0
  def testNestedMesh(self, mesh, axis_resources):
    """Nested xmaps with a psum over both names and a constant multiplier."""
    outer_resources = dict([axis_resources[0]])
    inner_resources = dict([axis_resources[1]])

    @partial(xmap, in_axes={1: 'a'}, out_axes=({0: 'a'}, {}),
             axis_resources=outer_resources)
    def outer(x):
      doubled = x * 2

      @partial(xmap, in_axes={0: 'b'}, out_axes=({1: 'b'}, {}),
               axis_resources=inner_resources)
      def inner(v):
        # The constant-array multiplier is there to better exercise the
        # partial_eval rule.
        return jnp.sin(v) * np.arange(v.size), lax.psum(v, ('a', 'b'))

      return inner(doubled)

    xshape = (4, 2, 5)
    x = jnp.arange(np.prod(xshape)).reshape(xshape)
    result = outer(x)
    scaled_sines = jnp.sin(x * 2) * np.arange(xshape[-1])
    expected = (scaled_sines.transpose((1, 2, 0)), (x * 2).sum((0, 1)))
    self.assertAllClose(result, expected)

    spec = result[0].sharding_spec
    self.assertEqual(
        spec.sharding,
        (pxla.Chunked([2]), pxla.NoSharding(), pxla.NoSharding()))
    # Every mesh axis past the first two is expected to be replicated.
    replicated_tail = (pxla.Replicated(2),) * (len(mesh) - 2)
    self.assertEqual(
        spec.mesh_mapping,
        (pxla.Replicated(2), pxla.ShardedAxis(0)) + replicated_tail)

    if not maps.EXPERIMENTAL_SPMD_LOWERING:
      return
    hlo = jax.xla_computation(outer)(x).as_hlo_text()
    # The HLO must contain at least one fully specified (non-partial)
    # device sharding annotation.
    self.assertRegex(hlo, r"sharding={devices=\[[0-9,]+\][0-9,]+}")