Example #1
0
    def testShardingInXMap(self):
        """A pjit nested in xmap receives xmap's mesh axis in its in_axis_resources.

        The pjit translation rule is temporarily replaced by a wrapper that
        records that it ran and inspects the ``in_axis_resources`` it is given,
        then delegates to the original rule. The original rule is always
        restored afterwards.
        """
        inner = pjit(lambda x: x,
                     in_axis_resources=P('x'),
                     out_axis_resources=None)
        outer = xmap(lambda x: inner(x * 2),
                     in_axes=['i', ...],
                     out_axes=['i', ...],
                     axis_resources={'i': 'y'})
        arr = jnp.arange(16).reshape((4, 4))
        self.assertIn(pjit_p, xla.call_translations)
        original_rule = xla.call_translations[pjit_p]
        calls = []

        def _checking_rule(*args, **kwargs):
            calls.append(True)
            resources = kwargs['in_axis_resources']
            self.assertEqual(len(resources), 1)
            # xmap's axis 'i' is assigned to mesh axis 'y', which must show up
            # in the partitions seen by the nested pjit.
            self.assertIn(('y', ), resources[0].partitions)
            return original_rule(*args, **kwargs)

        try:
            xla.call_translations[pjit_p] = _checking_rule
            outer(arr)
            self.assertTrue(calls)
        finally:
            xla.call_translations[pjit_p] = original_rule
Example #2
0
    def testNestedMap(self, xmap_in_axes, xmap_out_axes, vmap_in_axes,
                      vmap_out_axes, vmap_as_xmap):
        """Test various vmap(xmap) and xmap(xmap) combinations.

        The outer map always introduces a single dimension, the inner map
        introduces one or two.

        The mapped function is an ``'nk,km->nm'`` einsum. The expected value
        comes from one reference einsum. Its subscripts are the base
        subscripts with every mapped dimension inserted at the same position
        the maps insert it.

        Args:
          xmap_in_axes: pair of ``{position: axis_name}`` dicts, one per input
            of the inner xmap. Axis names index ``xmap_sizes`` ('x' or 'y').
          xmap_out_axes: ``{position: axis_name}`` dict for the inner xmap
            output.
          vmap_in_axes: pair of positions for the outer map's two inputs.
            ``None`` means the input is not mapped.
          vmap_out_axes: position of the outer map dimension in the output.
          vmap_as_xmap: if true, the outer map is an xmap over axis ``'v'``.
            Otherwise it is ``vmap``.
        """
        (xin_x, xin_y) = xmap_in_axes
        (vin_x, vin_y) = vmap_in_axes
        vmap_size = 7
        xmap_sizes = {'x': 11, 'y': 13}

        # Shapes and einsum subscripts of the unmapped operands and result.
        xshape = [2, 3]
        yshape = [3, 5]
        zshape = [2, 5]
        xind = ['n', 'k']
        yind = ['k', 'm']
        zind = ['n', 'm']
        f = partial(jnp.einsum, 'nk,km->nm')

        # Add the inner xmap dimensions. Inserting in ascending position order
        # leaves each one at its requested index in the array the inner xmap
        # sees.
        for pos, name in sorted(xin_x.items()):
            xshape.insert(pos, xmap_sizes[name])
            xind.insert(pos, name)
        for pos, name in sorted(xin_y.items()):
            yshape.insert(pos, xmap_sizes[name])
            yind.insert(pos, name)
        for pos, name in sorted(xmap_out_axes.items()):
            zshape.insert(pos, xmap_sizes[name])
            zind.insert(pos, name)

        # Add the outer map dimension last. Its positions are relative to
        # arrays that already include the inner xmap dimensions.
        if vin_x is not None:
            xshape.insert(vin_x, vmap_size)
            xind.insert(vin_x, 'v')
        if vin_y is not None:
            yshape.insert(vin_y, vmap_size)
            yind.insert(vin_y, 'v')
        zshape.insert(vmap_out_axes, vmap_size)
        zind.insert(vmap_out_axes, 'v')

        if vmap_as_xmap:
            # Express the outer map as an xmap over axis 'v'. An input whose
            # vmap position is None gets an empty (unmapped) in_axes dict.
            do_vmap = partial(xmap,
                              in_axes=({
                                  vin_x: 'v'
                              } if vin_x is not None else {}, {
                                  vin_y: 'v'
                              } if vin_y is not None else {}),
                              out_axes={vmap_out_axes: 'v'})
        else:
            do_vmap = partial(vmap,
                              in_axes=vmap_in_axes,
                              out_axes=vmap_out_axes)

        fm = do_vmap(xmap(f, in_axes=xmap_in_axes, out_axes=xmap_out_axes))
        # Every label ('n', 'k', 'm', 'v', 'x', 'y') is a single character, so
        # joining the lists gives a valid einsum spec over the fully mapped
        # arrays.
        fref = partial(jnp.einsum,
                       f"{''.join(xind)},{''.join(yind)}->{''.join(zind)}")

        rng = np.random.RandomState(0)
        x = rng.randn(*xshape)
        y = rng.randn(*yshape)
        self.assertAllClose(fm(x, y), fref(x, y))
Example #3
0
 def check(spec):
   """Compare the named-axis einsum ``spec`` under xmap with a NumPy matmul.

   ``x``, ``y`` and ``self`` come from the enclosing scope. On TPU the
   tolerance is loosened to 1e-1.
   """
   mapped = xmap(partial(jnp.einsum, spec),
                 in_axes=(['i', 'j'], ['j', 'k']),
                 out_axes=['i', 'k'])
   result = mapped(x, y)
   reference = np.einsum('ij,jk->ik', x, y)
   on_tpu = jtu.device_under_test() == "tpu"
   tol = 1e-1 if on_tpu else None
   self.assertAllClose(result, reference, check_dtypes=True,
                       atol=tol, rtol=tol)
Example #4
0
 def testCollectivePermute2D(self):
   """pshuffle over the joint ('i', 'j') axes applied to arange(4) yields ``perm``."""
   perm = np.array([3, 1, 2, 0])
   grid = jnp.arange(4).reshape((2, 2))

   def shuffle(v):
     return lax.pshuffle(v, ('i', 'j'), perm)

   mapped = xmap(shuffle,
                 in_axes=['i', 'j', ...],
                 out_axes=['i', 'j', ...],
                 axis_resources={'i': 'x', 'j': 'y'})
   flat = mapped(grid).reshape((-1,))
   self.assertAllClose(flat, perm)
Example #5
0
 def test_xeinsum_vector_dot(self):
   """A '{i},{i}->' named einsum contracts over i like a vector dot product."""
   rng = np.random.RandomState(0)
   x, y = rng.randn(3), rng.randn(3)
   dot = xmap(partial(jnp.einsum, '{i},{i}->'),
              in_axes=(['i'], ['i']),
              out_axes=[])
   reference = np.einsum('i,i->', x, y)
   self.assertAllClose(dot(x, y), reference, check_dtypes=False)
Example #6
0
 def test_xeinsum_outer_product(self):
   """A '{i},{j}->{i,j}' named einsum builds the outer product of two vectors."""
   rng = np.random.RandomState(0)
   x, y = rng.randn(3), rng.randn(3)
   outer = xmap(partial(jnp.einsum, '{i},{j}->{i,j}'),
                in_axes=(['i'], ['j']),
                out_axes=['i', 'j'])
   reference = np.einsum('i,j->ij', x, y)
   self.assertAllClose(outer(x, y), reference, check_dtypes=True)
Example #7
0
 def testReturnExtraMappedAxes(self):
   """Returning a value mapped along an axis missing from out_axes raises TypeError."""
   x = np.arange(12).reshape((4, 3))
   y = np.arange(6).reshape((2, 3))
   expected_msg = (r"One of xmap results has an out_axes specification of \['a', ...\], but "
                   r"is actually mapped along more axes defined by this xmap call: b")
   # x + y is mapped along both 'a' and 'b', but only 'a' is listed in out_axes.
   fm = xmap(lambda x, y: x + y,
             in_axes=(['a', ...], ['b', ...]),
             out_axes=['a', ...])
   with self.assertRaisesRegex(TypeError, expected_msg):
     fm(x, y)
Example #8
0
  def test_xmap_inherits_effects(self):
    """Tracing an xmap whose body binds effects raises NotImplementedError."""
    def body(x):
      for name in ('foo', 'bar'):
        effect_p.bind(effect=name)
      return x

    mapped = maps.xmap(body, in_axes=['a'], out_axes=['a'])
    with self.assertRaisesRegex(NotImplementedError, 'Effects not supported'):
      arg = jnp.arange(jax.local_device_count())
      jax.make_jaxpr(mapped)(arg)
Example #9
0
    def test_xmap_inherits_effects(self):
        """Effects bound inside an xmapped function appear on the traced jaxpr."""
        def body(x):
            for name in ('foo', 'bar'):
                effect_p.bind(effect=name)
            return x

        mapped = maps.xmap(body, in_axes=['a'], out_axes=['a'])
        arg = jnp.arange(jax.local_device_count())
        traced = jax.make_jaxpr(mapped)(arg)
        self.assertSetEqual(traced.effects, {"foo", "bar"})
Example #10
0
 def run_test():
   """Call one xmapped function repeatedly and check every result.

   ``f``, ``axis_resources`` and ``self`` come from the enclosing scope.
   The reference is the einsum ``'imk,jnk->ijmn'`` over the same input.
   """
   f_mapped = xmap(f,
                   in_axes=(['i', ...], ['j', ...]),
                   out_axes=['i', 'j', ...],
                   axis_resources=dict(axis_resources))
   x = jnp.arange(30).reshape(2, 3, 5)
   expected = jnp.einsum('imk,jnk->ijmn', x, x)
   # Repeated calls check that results stay correct after the first call.
   for _ in range(10):
     self.assertAllClose(f_mapped(x, x), expected)
Example #11
0
 def test_xmap(self):
   """Each new xmap compilation adds one file to the compilation cache dir.

   The same function is compiled for integer inputs and then float32 inputs.
   The cache directory should hold one file after the first call and two
   after the second.
   """
   devices = np.array(jax.local_devices()[:2])
   # Skip before calling cc.initialize_cache. Otherwise a skipped run would
   # still initialize the cache, on a temporary directory that is deleted as
   # soon as the `with` block exits.
   if devices.size < 2:
     raise SkipTest("Test requires 2 devices")

   def f(x):
     return x * 2

   with tempfile.TemporaryDirectory() as tmpdir:
     cc.initialize_cache(tmpdir)
     for dtype, expected_files in ((np.int64, 1), (np.float32, 2)):
       x = np.arange(8, dtype=dtype).reshape((2, 2, 2))
       xmap(f, in_axes=['a', ...], out_axes=['a', ...],
            axis_resources={'a': 'x'})(x)
       self.assertEqual(len(os.listdir(tmpdir)), expected_files)
Example #12
0
 def testResourceConflictArgsLoop(self):
   """Two named axes of one input may not share the loop resource `l`."""
   x = np.arange(16).reshape(4, 4)
   error = (r"Axes `a` and `b` are both mapped to the resource `l`, but they "
            r"coincide in the named_shape of an input to an xmapped function "
            r"<lambda>")
   # A lambda is used on purpose: the expected message names it as `<lambda>`.
   fm = xmap(lambda x: x,
             in_axes=['a', 'b'],
             out_axes=['a', 'b'],
             axis_resources={'a': 'l', 'b': 'l'})
   with self.assertRaisesRegex(JAXTypeError, error):
     fm(x)
Example #13
0
 def testLoopCollectives(self):
   """A collective over an axis assigned to a loop resource raises RuntimeError."""
   def reduce_over_i(x):
     return lax.psum(x, 'i')

   fm = xmap(reduce_over_i, in_axes=['i'], out_axes=[],
             axis_resources={'i': 'l'})
   x = np.arange(16)
   error = (r"Named axes with loop resources assigned to them cannot be "
            r"referenced inside the xmapped computation \(e.g. in "
            r"collectives\), but `i` violates that rule")
   with self.assertRaisesRegex(RuntimeError, error):
     fm(x)
Example #14
0
 def testNamedShape(self, mesh, axis_resources):
   """Both xmap outputs come back with an empty named_shape."""
   f = xmap(lambda x, y: (x + y, y * lax.axis_index('i')),
            in_axes=(['i', ...], {}),
            out_axes=(['i', ...], ['i', ...]),
            axis_resources=dict(axis_resources))
   z, w = f(np.arange(4), 2)
   for out in (z, w):
     self.assertEqual(out.aval.named_shape, {})
Example #15
0
 def testOneLogicalTwoMeshAxesSharding(self):
   """The result's mesh mapping follows the order of the mesh axes assigned to 'a'."""
   def f(v):
     return v * 4

   vshape = (4, 5)
   v = jnp.arange(np.prod(vshape)).reshape(vshape)
   cases = (
       (('x', 'y'), (pxla.ShardedAxis(0), pxla.ShardedAxis(1))),
       (('y', 'x'), (pxla.ShardedAxis(1), pxla.ShardedAxis(0))),
   )
   for resources, mesh_mapping in cases:
     fm = xmap(f, in_axes=['a', ...], out_axes={1: 'a'},
               axis_resources={'a': resources})
     self.assertEqual(
         fm(v).sharding_spec,
         pxla.ShardingSpec((pxla.NoSharding(), pxla.Chunked((2, 2))),
                           mesh_mapping))
Example #16
0
 def sample(axis_resources):
   """Draw from ``distr_sample`` (fixed key 0) under an xmap.

   Named axes i (size 4) and j (size 6) become the leading output dimensions.
   """
   def draw():
     return distr_sample(jax.random.PRNGKey(0),
                         shape=NamedShape(3, i=4, j=6))

   mapped = xmap(draw,
                 in_axes=(),
                 out_axes=['i', 'j', ...],
                 axis_sizes={'i': 4, 'j': 6},
                 axis_resources=axis_resources)
   return mapped()
Example #17
0
 def testOneLogicalTwoMeshAxesBasic(self):
   """psum and transposed outputs work when 'a' is split over two mesh axes."""
   def f(v):
     return lax.psum(v * 2, 'a'), v * 4

   mapped = xmap(f, in_axes=['a', ...], out_axes=[{}, {1: 'a'}],
                 axis_resources={'a': ('x', 'y')})
   vshape = (4, 5)
   v = jnp.arange(np.prod(vshape)).reshape(vshape)
   reduced, scaled = mapped(v)
   self.assertAllClose(reduced, (v * 2).sum(0))
   # out_axes {1: 'a'} puts the mapped axis second, i.e. a transpose.
   self.assertAllClose(scaled, v.T * 4)
Example #18
0
 def testResourceConflictInner(self):
   """Adding values mapped along a and b, both on mesh axis x, is rejected."""
   x = np.arange(12).reshape(4, 3)
   y = np.arange(6).reshape(2, 3)
   error = (r"Axes `a` and `b` are both mapped to the resource `x`, but they "
            r"coincide in the named_shape.*primitive add created at")
   fm = xmap(lambda x, y: x + y,
             in_axes=(['a', ...], ['b', ...]),
             out_axes=['a', 'b', ...],
             axis_resources={'a': 'x', 'b': 'x'})
   with self.assertRaisesRegex(JAXTypeError, error):
     fm(x, y)
Example #19
0
 def testResourceConflictOut(self):
   """Broadcasting an 'a'-mapped output along 'b' conflicts when both use mesh axis x."""
   x = np.arange(12).reshape(4, 3)
   y = np.arange(6).reshape(2, 3)
   error = (r"One of xmapped function \(<lambda>\) outputs is broadcast along axis "
            r"`b` which is assigned to resources `x`, but the output is already "
            r"partitioned along `x`, because its named shape contains `a`")
   # A lambda is used on purpose: the expected message names it as `<lambda>`.
   fm = xmap(lambda x, y: x,
             in_axes=(['a', ...], ['b', ...]),
             out_axes=['a', 'b', ...],
             axis_resources={'a': 'x', 'b': 'x'})
   with self.assertRaisesRegex(JAXTypeError, error):
     fm(x, y)
Example #20
0
 def testGather(self, mesh, axis_resources):
   """pgather along named axis j matches NumPy-style indexing on axis 1."""
   if axis_resources and not jax.experimental.maps.EXPERIMENTAL_SPMD_LOWERING:
     raise SkipTest("pgather over mesh axes without SPMD lowering not implemented")
   src = jnp.arange(12, dtype=np.float32).reshape((4, 3))
   idx = jnp.arange(35).reshape((5, 7)) % 3

   def gather_j(s, i):
     return pgather(s, i, 'j')

   def reference(s, i):
     return s[:, i.reshape((-1,))].reshape((4, 5, 7))

   f = xmap(gather_j,
            in_axes=(['i', 'j'], ['k', 'm']),
            out_axes=['i', 'k', 'm'],
            axis_resources=dict(axis_resources))
   self.assertAllClose(f(src, idx), reference(src, idx))
Example #21
0
    def testRepeatedAxisResource(self):
        """Mapping one named axis to the same mesh axis twice raises ValueError."""
        def f(v):
            return v * 4

        # The error is raised while the xmap is constructed, so the result is
        # never used and is not bound to a name.
        with self.assertRaisesRegex(
                ValueError,
                r"distinct resources.*specified \('x', 'x'\) for axis a"):
            xmap(f,
                 in_axes=['a', ...],
                 out_axes=['a', ...],
                 axis_resources={'a': ('x', 'x')})
Example #22
0
 def testConstraintShardsXMapAxis(self):
   """with_sharding_constraint may not reuse a mesh axis that xmap already uses."""
   spec = P('x')

   def constrain(v):
     return with_sharding_constraint(v, axis_resources=spec)

   f = xmap(constrain, in_axes=['i', ...], out_axes=['i', ...],
            axis_resources={'i': 'x'})
   x = jnp.arange(4).reshape((2, 2))
   error = (r"with_sharding_constraint input has an axis resources specification of " +
            spec_regex(spec) + r" that uses one or more mesh axes already used by "
            r"xmap to partition a named axis appearing in its named_shape \(both "
            r"use mesh axes `x`\)")
   with self.assertRaisesRegex(JAXTypeError, error):
     f(x)
  def test_unordered_print_with_xmap(self):
    """Unordered debug_print inside xmap emits one line per element of arange(40)."""
    def f(x):
      debug_print("{}", x, ordered=False)

    mapped = maps.xmap(f, in_axes=['a'], out_axes=None, backend='cpu',
                       axis_resources={'a': 'dev'})
    cpu_mesh = np.array(jax.devices(backend='cpu'))
    with maps.Mesh(cpu_mesh, ['dev']):
      with capture_stdout() as output:
        mapped(jnp.arange(40))
        jax.effects_barrier()
      expected = "".join(f"{i}\n" for i in range(40))
      self._assertLinesEqual(output(), expected)
Example #24
0
  def test_xeinsum_matmul(self):
    """Named-axis einsum specs under xmap reproduce a plain matrix product.

    The reference result and the TPU tolerance are the same for every spec,
    so they are computed once.
    """
    rng = np.random.RandomState(0)
    x = rng.randn(3, 4)
    y = rng.randn(4, 5)
    expected = np.einsum('ij,jk->ik', x, y)
    tol = 1e-1 if jtu.device_under_test() == "tpu" else None

    # The order of named axes in the spec doesn't matter: both specs describe
    # the same contraction over j.
    for spec in ('{i,j},{j,k}->{i,k}', '{i,j},{k,j}->{k,i}'):
      out = xmap(partial(jnp.einsum, spec),
                 in_axes=(['i', 'j'], ['j', 'k']),
                 out_axes=['i', 'k'])(x, y)
      self.assertAllClose(out, expected, check_dtypes=True,
                          atol=tol, rtol=tol)
Example #25
0
 def testCompilationCache(self):
   """A second call with identical inputs must not re-run the Python function."""
   def f(x):
     assert python_should_be_executing
     return x * 2

   mapped = xmap(f, in_axes=['a', ...], out_axes=['a', ...],
                 axis_resources={'a': 'x'})
   arg = np.arange(8).reshape((2, 2, 2))
   # The first call traces `f`. The flag is then cleared, so the assertion in
   # `f` fails if the second call traces it again.
   python_should_be_executing = True
   mapped(arg)
   python_should_be_executing = False
   mapped(arg)
Example #26
0
 def testCollectiveReduce(self):
   """psum over mesh-mapped axis 'a' alongside an independently mapped input."""
   def f(a, b):
     return lax.psum(a * 2, 'a'), b * 4

   fm = xmap(f,
             in_axes=[['a', 'b', ...], {0: 'c'}],
             out_axes=[['b', ...], {0: 'c'}],
             axis_resources={'a': 'x', 'b': 'y', 'c': 'x'})
   ashape, bshape = (16, 8, 5), (2, 7)
   a = jnp.arange(np.prod(ashape)).reshape(ashape)
   b = jnp.arange(np.prod(bshape)).reshape(bshape)
   summed, scaled = fm(a, b)
   self.assertAllClose(summed, (a * 2).sum(0))
   self.assertAllClose(scaled, b * 4)
Example #27
0
 def testCatchesInnerXMapErrors(self):
   """An xmap mapping i and j to mesh axis x raises JAXTypeError inside pjit too."""
   inner = xmap(lambda x, y: x,
                in_axes=(['i'], ['j']),
                out_axes=['i', 'j'],
                axis_resources={'i': 'x', 'j': 'x'})
   f = pjit(inner, in_axis_resources=None, out_axis_resources=None)
   x = jnp.arange(4)
   with self.assertRaises(JAXTypeError):
     f(x, x)
Example #28
0
  def testMultipleCalls(self, mesh, axis_resources):
    """Calling one xmapped tensordot repeatedly keeps returning the right result.

    The reference is the einsum ``'imk,jnk->ijmn'`` over the same input.
    """
    def f(x, y):
      # Inside xmap each operand is one (3, 5) slice of the (2, 3, 5) input.
      assert x.shape == y.shape == (3, 5)
      return jnp.tensordot(x, y, axes=([1], [1]))

    f_mapped = xmap(f,
                    in_axes=(['i', ...], ['j', ...]),
                    out_axes=['i', 'j', ...],
                    axis_resources=dict(axis_resources))
    x = jnp.arange(30).reshape(2, 3, 5)
    expected = jnp.einsum('imk,jnk->ijmn', x, x)
    for _ in range(10):
      self.assertAllClose(f_mapped(x, x), expected)
Example #29
0
    def testCompilationCache(self):
        """A second call with identical inputs must not re-run the Python function."""
        def f(x):
            assert python_should_be_executing
            return x * 2

        mapped = xmap(f, in_axes=A({'a': 0}), out_axes=A({'a': 0}),
                      schedule=[('a', 'x'), ('a', 'vectorize')])
        arg = np.arange(8).reshape((2, 2, 2))
        # The first call traces `f`. The flag is then cleared, so the assertion
        # in `f` fails if the second call traces it again.
        python_should_be_executing = True
        mapped(arg)
        python_should_be_executing = False
        mapped(arg)
Example #30
0
  def testReductions(self, reduction, axes, mapped_axis):
    """Reducing inside xmap matches the same reduction on the full array.

    ``axes`` may mix positional axes of the per-element view with the named
    axis ``'i'``, which is the xmapped dimension at ``mapped_axis`` of the
    full input.
    """
    axes_t = axes if isinstance(axes, tuple) else (axes,)
    reduces_i = 'i' in axes_t
    # Positional axis `a` of the per-element view is axis `a` of the full
    # array, shifted by one at or past the mapped dimension.
    ref_red = partial(reduction,
                      axis=tuple(mapped_axis if a == 'i' else a + (a >= mapped_axis)
                                 for a in axes_t))
    # Where the mapped axis lands once the positional axes before it have
    # been reduced away.
    mapped_axis_after_red = mapped_axis - sum(axis < mapped_axis if axis != 'i' else 0
                                              for axis in axes_t)
    xmap_red = xmap(lambda x: reduction(x, axes),
                    in_axes={mapped_axis: 'i'},
                    out_axes=({} if reduces_i else {mapped_axis_after_red: 'i'}))

    rng = np.random.RandomState(0)
    x = rng.randn(2, 5, 6)
    self.assertAllClose(ref_red(x), xmap_red(x))