def testShardingInXMap(self):
  h = pjit(lambda x: x, in_axis_resources=P('x'), out_axis_resources=None)
  f = xmap(lambda x: h(x * 2), in_axes=['i', ...], out_axes=['i', ...],
           axis_resources={'i': 'y'})
  x = jnp.arange(16).reshape((4, 4))
  self.assertIn(pjit_p, xla.call_translations)
  rule = xla.call_translations[pjit_p]
  test_rule_called = False
  # Temporarily swap the pjit translation rule so we can observe the
  # in_axis_resources that the inner pjit receives when lowered under xmap.
  def _test_rule(*args, **kwargs):
    nonlocal test_rule_called
    test_rule_called = True
    in_axis_resources = kwargs['in_axis_resources']
    self.assertEqual(len(in_axis_resources), 1)
    self.assertIn(('y',), in_axis_resources[0].partitions)
    return rule(*args, **kwargs)
  try:
    xla.call_translations[pjit_p] = _test_rule
    f(x)
    self.assertTrue(test_rule_called)
  finally:
    xla.call_translations[pjit_p] = rule
def testNestedMap(self, xmap_in_axes, xmap_out_axes, vmap_in_axes,
                  vmap_out_axes, vmap_as_xmap):
  """Test various vmap(xmap) and xmap(xmap) combinations.

  The outer map always introduces a single dimension, the inner map
  introduces one or two.
  """
  (xin_x, xin_y) = xmap_in_axes
  (vin_x, vin_y) = vmap_in_axes
  vmap_size = 7
  xmap_sizes = {'x': 11, 'y': 13}

  xshape = [2, 3]
  yshape = [3, 5]
  zshape = [2, 5]
  xind = ['n', 'k']
  yind = ['k', 'm']
  zind = ['n', 'm']
  f = partial(jnp.einsum, 'nk,km->nm')

  for pos, name in sorted(xin_x.items()):
    xshape.insert(pos, xmap_sizes[name])
    xind.insert(pos, name)
  for pos, name in sorted(xin_y.items()):
    yshape.insert(pos, xmap_sizes[name])
    yind.insert(pos, name)
  for pos, name in sorted(xmap_out_axes.items()):
    zshape.insert(pos, xmap_sizes[name])
    zind.insert(pos, name)

  if vin_x is not None:
    xshape.insert(vin_x, vmap_size)
    xind.insert(vin_x, 'v')
  if vin_y is not None:
    yshape.insert(vin_y, vmap_size)
    yind.insert(vin_y, 'v')
  zshape.insert(vmap_out_axes, vmap_size)
  zind.insert(vmap_out_axes, 'v')

  if vmap_as_xmap:
    do_vmap = partial(xmap,
                      in_axes=({vin_x: 'v'} if vin_x is not None else {},
                               {vin_y: 'v'} if vin_y is not None else {}),
                      out_axes={vmap_out_axes: 'v'})
  else:
    do_vmap = partial(vmap, in_axes=vmap_in_axes, out_axes=vmap_out_axes)

  fm = do_vmap(xmap(f, in_axes=xmap_in_axes, out_axes=xmap_out_axes))
  fref = partial(jnp.einsum, f"{''.join(xind)},{''.join(yind)}->{''.join(zind)}")

  rng = np.random.RandomState(0)
  x = rng.randn(*xshape)
  y = rng.randn(*yshape)
  self.assertAllClose(fm(x, y), fref(x, y))
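# A minimal standalone instance of the vmap-over-xmap nesting exercised above,
# using hypothetical sizes rather than one of the parameterized cases:
#
#   f = xmap(partial(jnp.einsum, 'nk,km->nm'),
#            in_axes=({0: 'x'}, {}), out_axes={0: 'x'})
#   fv = vmap(f, in_axes=(0, None), out_axes=0)
#   x = np.random.randn(7, 11, 2, 3)  # 7 is the vmapped dim, 11 the xmapped one
#   y = np.random.randn(3, 5)
#   assert fv(x, y).shape == (7, 11, 2, 5)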
def check(spec):
  out = xmap(partial(jnp.einsum, spec),
             in_axes=(['i', 'j'], ['j', 'k']),
             out_axes=['i', 'k'])(x, y)
  expected = np.einsum('ij,jk->ik', x, y)
  tol = 1e-1 if jtu.device_under_test() == "tpu" else None
  self.assertAllClose(out, expected, check_dtypes=True, atol=tol, rtol=tol)
def testCollectivePermute2D(self):
  perm = np.array([3, 1, 2, 0])
  x = jnp.arange(4).reshape((2, 2))
  result = xmap(lambda x: lax.pshuffle(x, ('i', 'j'), perm),
                in_axes=['i', 'j', ...], out_axes=['i', 'j', ...],
                axis_resources={'i': 'x', 'j': 'y'})(x).reshape((-1,))
  # x is arange(4), so each shuffled entry equals its source index and the
  # flattened result reproduces perm itself.
  self.assertAllClose(result, perm)
def test_xeinsum_vector_dot(self):
  rng = np.random.RandomState(0)
  x = rng.randn(3)
  y = rng.randn(3)
  out = xmap(partial(jnp.einsum, '{i},{i}->'),
             in_axes=(['i'], ['i']), out_axes=[])(x, y)
  expected = np.einsum('i,i->', x, y)
  self.assertAllClose(out, expected, check_dtypes=False)
def test_xeinsum_outer_product(self):
  rng = np.random.RandomState(0)
  x = rng.randn(3)
  y = rng.randn(3)
  out = xmap(partial(jnp.einsum, '{i},{j}->{i,j}'),
             in_axes=(['i'], ['j']), out_axes=['i', 'j'])(x, y)
  expected = np.einsum('i,j->ij', x, y)
  self.assertAllClose(out, expected, check_dtypes=True)
def testReturnExtraMappedAxes(self):
  fm = xmap(lambda x, y: x + y,
            in_axes=(['a', ...], ['b', ...]), out_axes=['a', ...])
  x = np.arange(12).reshape((4, 3))
  y = np.arange(6).reshape((2, 3))
  error = (r"One of xmap results has an out_axes specification of \['a', ...\], but "
           r"is actually mapped along more axes defined by this xmap call: b")
  with self.assertRaisesRegex(TypeError, error):
    fm(x, y)
def test_xmap_inherits_effects(self):
  def f(x):
    effect_p.bind(effect='foo')
    effect_p.bind(effect='bar')
    return x
  f = maps.xmap(f, in_axes=['a'], out_axes=['a'])
  with self.assertRaisesRegex(NotImplementedError, 'Effects not supported'):
    jax.make_jaxpr(f)(jnp.arange(jax.local_device_count()))
def test_xmap_inherits_effects(self):
  # Note: shares its name with the test above (the pre-effects-support
  # variant); if both live in the same test class, this definition shadows
  # the earlier one.
  def f(x):
    effect_p.bind(effect='foo')
    effect_p.bind(effect='bar')
    return x
  f = maps.xmap(f, in_axes=['a'], out_axes=['a'])
  jaxpr = jax.make_jaxpr(f)(jnp.arange(jax.local_device_count()))
  self.assertSetEqual(jaxpr.effects, {"foo", "bar"})
def run_test():
  f_mapped = xmap(f,
                  in_axes=(['i', ...], ['j', ...]),
                  out_axes=['i', 'j', ...],
                  axis_resources=dict(axis_resources))
  x = jnp.arange(30).reshape(2, 3, 5)
  expected = jnp.einsum('imk,jnk->ijmn', x, x)
  for _ in range(10):
    self.assertAllClose(f_mapped(x, x), expected)
def test_xmap(self):
  with tempfile.TemporaryDirectory() as tmpdir:
    cc.initialize_cache(tmpdir)
    def f(x):
      return x * 2
    devices = np.array(jax.local_devices()[:2])
    if devices.size < 2:
      raise SkipTest("Test requires 2 devices")
    x = np.arange(8, dtype=np.int64).reshape((2, 2, 2))
    xmap(f, in_axes=['a', ...], out_axes=['a', ...],
         axis_resources={'a': 'x'})(x)
    files_in_directory = len(os.listdir(tmpdir))
    self.assertEqual(files_in_directory, 1)
    # A different input dtype forces recompilation, which should write a
    # second entry into the persistent cache.
    x = np.arange(8, dtype=np.float32).reshape((2, 2, 2))
    xmap(f, in_axes=['a', ...], out_axes=['a', ...],
         axis_resources={'a': 'x'})(x)
    files_in_directory = len(os.listdir(tmpdir))
    self.assertEqual(files_in_directory, 2)
def testResourceConflictArgsLoop(self):
  fm = xmap(lambda x: x,
            in_axes=['a', 'b'], out_axes=['a', 'b'],
            axis_resources={'a': 'l', 'b': 'l'})
  x = np.arange(16).reshape(4, 4)
  error = (r"Axes `a` and `b` are both mapped to the resource `l`, but they "
           r"coincide in the named_shape of an input to an xmapped function "
           r"<lambda>")
  with self.assertRaisesRegex(JAXTypeError, error):
    fm(x)
def testLoopCollectives(self):
  fm = xmap(lambda x: lax.psum(x, 'i'),
            in_axes=['i'], out_axes=[],
            axis_resources={'i': 'l'})
  x = np.arange(16)
  error = (r"Named axes with loop resources assigned to them cannot be "
           r"referenced inside the xmapped computation \(e.g. in "
           r"collectives\), but `i` violates that rule")
  with self.assertRaisesRegex(RuntimeError, error):
    fm(x)
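# Background for the error above: a loop resource makes xmap execute the
# instances of the named axis sequentially, so a collective such as psum has
# no way to communicate across instances that don't exist at the same time.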
def testNamedShape(self, mesh, axis_resources):
  x = np.arange(4)
  y = 2
  f = xmap(lambda x, y: (x + y, y * lax.axis_index('i')),
           in_axes=(['i', ...], {}),
           out_axes=(['i', ...], ['i', ...]),
           axis_resources=dict(axis_resources))
  z, w = f(x, y)
  self.assertEqual(z.aval.named_shape, {})
  self.assertEqual(w.aval.named_shape, {})
def testOneLogicalTwoMeshAxesSharding(self):
  def f(v):
    return v * 4
  fxy = xmap(f, in_axes=['a', ...], out_axes={1: 'a'},
             axis_resources={'a': ('x', 'y')})
  fyx = xmap(f, in_axes=['a', ...], out_axes={1: 'a'},
             axis_resources={'a': ('y', 'x')})
  vshape = (4, 5)
  v = jnp.arange(np.prod(vshape)).reshape(vshape)
  zxy = fxy(v)
  self.assertEqual(
      zxy.sharding_spec,
      pxla.ShardingSpec((pxla.NoSharding(), pxla.Chunked((2, 2))),
                        (pxla.ShardedAxis(0), pxla.ShardedAxis(1))))
  zyx = fyx(v)
  self.assertEqual(
      zyx.sharding_spec,
      pxla.ShardingSpec((pxla.NoSharding(), pxla.Chunked((2, 2))),
                        (pxla.ShardedAxis(1), pxla.ShardedAxis(0))))
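# The assertions above presume a 2x2 ('x', 'y') mesh: Chunked((2, 2)) records
# that the logical axis 'a' is split into two factors of 2, one per mesh axis.
# Swapping the resource order from ('x', 'y') to ('y', 'x') swaps which mesh
# axis owns the major factor, which is exactly the difference between the
# ShardedAxis orders checked for fxy and fyx.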
def sample(axis_resources):
  return xmap(lambda: distr_sample(jax.random.PRNGKey(0),
                                   shape=NamedShape(3, i=4, j=6)),
              in_axes=(), out_axes=['i', 'j', ...],
              axis_sizes={'i': 4, 'j': 6},
              axis_resources=axis_resources)()
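# A likely usage of this helper (an assumption, not shown here): comparing
# sample({}) against e.g. sample({'i': 'x'}) for equality, since
# axis_resources should only change how the (4, 6, 3) result is partitioned
# across the mesh, not which values are drawn.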
def testOneLogicalTwoMeshAxesBasic(self):
  def f(v):
    return lax.psum(v * 2, 'a'), v * 4
  fm = xmap(f, in_axes=['a', ...], out_axes=[{}, {1: 'a'}],
            axis_resources={'a': ('x', 'y')})
  vshape = (4, 5)
  v = jnp.arange(np.prod(vshape)).reshape(vshape)
  ans, ans2 = fm(v)
  self.assertAllClose(ans, (v * 2).sum(0))
  self.assertAllClose(ans2, v.T * 4)
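# Why these references hold: 'a' maps axis 0 of v, so psum(v * 2, 'a') sums
# the doubled rows, i.e. (v * 2).sum(0). The second output keeps the named
# axis but out_axes {1: 'a'} re-inserts it at position 1, transposing the
# (4, 5) input into a (5, 4) result, i.e. v.T * 4.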
def testResourceConflictInner(self):
  fm = xmap(lambda x, y: x + y,
            in_axes=(['a', ...], ['b', ...]),
            out_axes=['a', 'b', ...],
            axis_resources={'a': 'x', 'b': 'x'})
  x = np.arange(12).reshape(4, 3)
  y = np.arange(6).reshape(2, 3)
  error = (r"Axes `a` and `b` are both mapped to the resource `x`, but they "
           r"coincide in the named_shape.*primitive add created at")
  with self.assertRaisesRegex(JAXTypeError, error):
    fm(x, y)
def testResourceConflictOut(self):
  fm = xmap(lambda x, y: x,
            in_axes=(['a', ...], ['b', ...]),
            out_axes=['a', 'b', ...],
            axis_resources={'a': 'x', 'b': 'x'})
  x = np.arange(12).reshape(4, 3)
  y = np.arange(6).reshape(2, 3)
  error = (r"One of xmapped function \(<lambda>\) outputs is broadcast along axis "
           r"`b` which is assigned to resources `x`, but the output is already "
           r"partitioned along `x`, because its named shape contains `a`")
  with self.assertRaisesRegex(JAXTypeError, error):
    fm(x, y)
def testGather(self, mesh, axis_resources):
  if axis_resources and not jax.experimental.maps.EXPERIMENTAL_SPMD_LOWERING:
    raise SkipTest("pgather over mesh axes without SPMD lowering not implemented")
  x = jnp.arange(12, dtype=np.float32).reshape((4, 3))
  y = jnp.arange(35).reshape((5, 7)) % 3
  f = xmap(lambda src, idx: pgather(src, idx, 'j'),
           in_axes=(['i', 'j'], ['k', 'm']),
           out_axes=['i', 'k', 'm'],
           axis_resources=dict(axis_resources))
  # pgather indexes src along the named axis 'j' with idx, so the reference
  # gathers columns of x and restores the (i, k, m) = (4, 5, 7) layout.
  f_ref = lambda x, y: x[:, y.reshape((-1,))].reshape((4, 5, 7))
  self.assertAllClose(f(x, y), f_ref(x, y))
def testRepeatedAxisResource(self):
  def f(v):
    return v * 4
  with self.assertRaisesRegex(
      ValueError, r"distinct resources.*specified \('x', 'x'\) for axis a"):
    fxy = xmap(f, in_axes=['a', ...], out_axes=['a', ...],
               axis_resources={'a': ('x', 'x')})
def testConstraintShardsXMapAxis(self):
  spec = P('x')
  f = xmap(lambda x: with_sharding_constraint(x, axis_resources=spec),
           in_axes=['i', ...], out_axes=['i', ...],
           axis_resources={'i': 'x'})
  x = jnp.arange(4).reshape((2, 2))
  error = (r"with_sharding_constraint input has an axis resources specification of " +
           spec_regex(spec) +
           r" that uses one or more mesh axes already used by "
           r"xmap to partition a named axis appearing in its named_shape \(both "
           r"use mesh axes `x`\)")
  with self.assertRaisesRegex(JAXTypeError, error):
    f(x)
def test_unordered_print_with_xmap(self):
  def f(x):
    debug_print("{}", x, ordered=False)
  f = maps.xmap(f, in_axes=['a'], out_axes=None, backend='cpu',
                axis_resources={'a': 'dev'})
  with maps.Mesh(np.array(jax.devices(backend='cpu')), ['dev']):
    with capture_stdout() as output:
      f(jnp.arange(40))
      jax.effects_barrier()
    lines = [f"{i}\n" for i in range(40)]
    self._assertLinesEqual(output(), "".join(lines))
def test_xeinsum_matmul(self):
  rng = np.random.RandomState(0)
  x = rng.randn(3, 4)
  y = rng.randn(4, 5)
  out = xmap(partial(jnp.einsum, '{i,j},{j,k}->{i,k}'),
             in_axes=(['i', 'j'], ['j', 'k']),
             out_axes=['i', 'k'])(x, y)
  expected = np.einsum('ij,jk->ik', x, y)
  tol = 1e-1 if jtu.device_under_test() == "tpu" else None
  self.assertAllClose(out, expected, check_dtypes=True, atol=tol, rtol=tol)
  # order of named axes in the spec doesn't matter!
  out = xmap(partial(jnp.einsum, '{i,j},{k,j}->{k,i}'),
             in_axes=(['i', 'j'], ['j', 'k']),
             out_axes=['i', 'k'])(x, y)
  expected = np.einsum('ij,jk->ik', x, y)
  tol = 1e-1 if jtu.device_under_test() == "tpu" else None
  self.assertAllClose(out, expected, check_dtypes=True, atol=tol, rtol=tol)
def testCompilationCache(self):
  def f(x):
    assert python_should_be_executing
    return x * 2
  fm = xmap(f, in_axes=['a', ...], out_axes=['a', ...],
            axis_resources={'a': 'x'})
  x = np.arange(8).reshape((2, 2, 2))
  python_should_be_executing = True
  fm(x)
  # The second call should hit the compilation cache, so the traced Python
  # body must not run again.
  python_should_be_executing = False
  fm(x)
def testCollectiveReduce(self):
  fm = xmap(lambda a, b: (lax.psum(a * 2, 'a'), b * 4),
            in_axes=[['a', 'b', ...], {0: 'c'}],
            out_axes=[['b', ...], {0: 'c'}],
            axis_resources={'a': 'x', 'b': 'y', 'c': 'x'})
  ashape = (16, 8, 5)
  a = jnp.arange(np.prod(ashape)).reshape(ashape)
  bshape = (2, 7)
  b = jnp.arange(np.prod(bshape)).reshape(bshape)
  c, d = fm(a, b)
  self.assertAllClose(c, (a * 2).sum(0))
  self.assertAllClose(d, b * 4)
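# Shape check for the references above: psum reduces over 'a' (axis 0 of a)
# and out_axes ['b', ...] re-inserts only 'b', so c is the (8, 5) sum
# (a * 2).sum(0); b is mapped over 'c' on axis 0 and returned unchanged apart
# from the * 4.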
def testCatchesInnerXMapErrors(self):
  f = pjit(xmap(lambda x, y: x,
                in_axes=(['i'], ['j']), out_axes=['i', 'j'],
                axis_resources={'i': 'x', 'j': 'x'}),
           in_axis_resources=None, out_axis_resources=None)
  x = jnp.arange(4)
  with self.assertRaises(JAXTypeError):
    f(x, x)
def testMultipleCalls(self, mesh, axis_resources):
  def f(x, y):
    assert x.shape == y.shape == (3, 5)
    return jnp.tensordot(x, y, axes=([1], [1]))
  f_mapped = xmap(f,
                  in_axes=(['i', ...], ['j', ...]),
                  out_axes=['i', 'j', ...],
                  axis_resources=dict(axis_resources))
  x = jnp.arange(30).reshape(2, 3, 5)
  expected = jnp.einsum('imk,jnk->ijmn', x, x)
  for _ in range(10):
    self.assertAllClose(f_mapped(x, x), expected)
def testCompilationCache(self):
  # Note: shares its name with the testCompilationCache above; this variant
  # exercises the schedule-based API and would shadow the earlier definition
  # if both lived in the same test class.
  def f(x):
    assert python_should_be_executing
    return x * 2
  fm = xmap(f, in_axes=A({'a': 0}), out_axes=A({'a': 0}),
            schedule=[('a', 'x'), ('a', 'vectorize')])
  x = np.arange(8).reshape((2, 2, 2))
  python_should_be_executing = True
  fm(x)
  python_should_be_executing = False
  fm(x)
def testReductions(self, reduction, axes, mapped_axis):
  axes_t = axes if isinstance(axes, tuple) else (axes,)
  reduces_i = 'i' in axes_t
  ref_red = partial(reduction,
                    axis=tuple(mapped_axis if a == 'i' else a + (a >= mapped_axis)
                               for a in axes_t))
  mapped_axis_after_red = mapped_axis - sum(axis < mapped_axis if axis != 'i' else 0
                                            for axis in axes_t)
  xmap_red = xmap(lambda x: reduction(x, axes),
                  in_axes={mapped_axis: 'i'},
                  out_axes=({} if reduces_i else {mapped_axis_after_red: 'i'}))
  rng = np.random.RandomState(0)
  x = rng.randn(2, 5, 6)
  self.assertAllClose(ref_red(x), xmap_red(x))
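# Worked instance of the index bookkeeping above (hypothetical parameters):
# with mapped_axis=1 and axes=(0, 'i'), ref_red reduces over axis=(0, 1) of
# the full array ('i' maps back to positional axis 1, and axis 0 stays at 0
# since 0 >= 1 is False); because 'i' is reduced away, out_axes is {} and no
# mapped axis needs to be re-inserted into the result.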