def testJVP(self):
  """Check second-order forward-mode derivatives through nested pjit calls."""
  # The inner pjit closes over a constant so the nested call is less trivial.
  captured = jnp.arange(4)
  inner = pjit(lambda x: x.sum() + captured.sum(),
               in_axis_resources=P('x', 'y'),
               out_axis_resources=None)
  outer = pjit(lambda x: inner(x + 2),
               in_axis_resources=P('x', None),
               out_axis_resources=None)
  primal = jnp.arange(16, dtype=jnp.float32).reshape((4, 4))
  jtu.check_grads(outer, (primal,), order=2, modes=["fwd"], eps=1)
def testUndefinedResourcesOuts(self, mesh, resources):
  """An output spec naming an undefined resource axis must raise ValueError."""
  operand = jnp.ones((2, 2))
  spec = P(resources,)
  expected_msg = (r"One of pjit outputs.*" + spec_regex(spec) + r", "
                  r"but resource axis x is undefined.")
  # Build and call pjit inside the context so that an error raised at either
  # point is caught.
  with self.assertRaisesRegex(ValueError, expected_msg):
    pjit(lambda x: x,
         in_axis_resources=None,
         out_axis_resources=spec)(operand)
def testRankTooLowOuts(self):
  """A rank-2 output spec applied to a scalar output must raise ValueError."""
  operand = jnp.arange(2)
  spec = P('x', 'y')
  expected_msg = (r"One of pjit outputs.*" + spec_regex(spec) +
                  r", which implies "
                  r"that it has a rank of at least 2, but it is 0")
  with self.assertRaisesRegex(ValueError, expected_msg):
    # x.sum() reduces the input to a rank-0 value.
    pjit(lambda x: x.sum(),
         in_axis_resources=None,
         out_axis_resources=spec)(operand)
def testEmptyMesh(self):
  """Calling pjit here is expected to fail with the empty-mesh RuntimeError."""
  expected_msg = (
      r"pjit requires a non-empty mesh! Are you sure that it's defined "
      r"at the call site?")
  with self.assertRaisesRegex(RuntimeError, expected_msg):
    pjit(lambda x: x,
         in_axis_resources=None,
         out_axis_resources=None)(jnp.arange(4))
def testWithCustomPRNGKey(self):
  """Smoke test: pjit must accept a custom PRNG key without crashing."""
  if not config.jax_enable_custom_prng:
    raise unittest.SkipTest("test requires jax_enable_custom_prng")
  rbg_key = jax.prng.seed_with_impl(jax.prng.rbg_prng_impl, 87)
  identity = pjit(lambda x: x,
                  in_axis_resources=None,
                  out_axis_resources=None)
  # Nothing is asserted about the result; the call only has to complete.
  identity(rbg_key)
def testRepeatedOutResources(self):
  """An output spec that uses mesh axis `x` twice must raise ValueError."""
  operand = jnp.arange(2)
  duplicated_specs = (P('x', 'x'), P('x', ('y', 'x')))
  for spec in duplicated_specs:
    expected_msg = (
        r"A single out_axis_resources specification can map every mesh "
        r"axis to at most one positional dimension, but " +
        spec_regex(spec) + " has duplicate entries for `x`")
    # pjit is constructed inside the context so that an error raised at
    # construction time is caught as well.
    with self.assertRaisesRegex(ValueError, expected_msg):
      pjit(lambda x: x,
           in_axis_resources=None,
           out_axis_resources=spec)(operand)
def testUndefinedResourcesConstraint(self, mesh, resources):
  """A sharding constraint naming an undefined resource axis must raise."""
  operand = jnp.ones((2, 2))
  spec = P(resources,)
  expected_msg = (r"One of with_sharding_constraint arguments"
                  r".*" + spec_regex(spec) + r", but resource axis "
                  r"x is undefined.")
  with self.assertRaisesRegex(ValueError, expected_msg):
    pjit(lambda x: with_sharding_constraint(x, spec),
         in_axis_resources=None,
         out_axis_resources=None)(operand)
def testNested(self):
  """A pjit that calls another pjit computes the expected value."""
  # The inner pjit closes over a constant so the nested call is less trivial.
  captured = jnp.arange(4)
  inner = pjit(lambda x: x.sum() + captured.sum(),
               in_axis_resources=P('x', 'y'),
               out_axis_resources=None)
  outer = pjit(lambda x: inner(jnp.sin(x)),
               in_axis_resources=P('x', None),
               out_axis_resources=None)
  operand = jnp.arange(16).reshape((4, 4))
  result = outer(operand)
  expected = jnp.sin(operand).sum() + captured.sum()
  self.assertAllClose(result, expected)
  # The result should still carry sharding information.
  self.assertTrue(hasattr(result, "sharding_spec"))
def testRankTooLowConstraint(self):
  """A rank-2 sharding constraint on a rank-1 value must raise ValueError."""
  operand = jnp.arange(2)
  spec = P('x', 'y')
  expected_msg = (r"One of with_sharding_constraint arguments " +
                  r"was given.*" + spec_regex(spec) + r", which implies "
                  r"that it has a rank of at least 2, but it is 1")
  with self.assertRaisesRegex(ValueError, expected_msg):
    pjit(lambda x: with_sharding_constraint(x, spec),
         in_axis_resources=None,
         out_axis_resources=None)(operand)
def testNonDivisibleOuts(self, mesh, resources):
  """An output dimension not divisible by the mesh size must raise."""
  operand = jnp.ones((3, 2))
  spec = P(resources, None)
  # Product of the second entry of every mesh item, rendered for the regex.
  axis_sizes = [entry[1] for entry in mesh]
  mesh_size = str(np.prod(axis_sizes, dtype=np.int64))
  expected_msg = (r"One of pjit outputs.*" + spec_regex(spec) + r".*"
                  r"implies that the size of its dimension 0 should be "
                  r"divisible by " + mesh_size + r", but it is equal to 3")
  with self.assertRaisesRegex(ValueError, expected_msg):
    pjit(lambda x: x,
         in_axis_resources=None,
         out_axis_resources=spec)(operand)
def testAutodiff(self, mesh, resources):
  """Check gradients up to second order through nested, sharded pjit calls."""
  # Only meshes with exactly two entries are exercised; others pass trivially.
  if len(mesh) != 2:
    return
  assert resources == ('x', 'y')
  # The inner pjit closes over a constant so the nested call is less trivial.
  captured = jnp.arange(4)
  inner = pjit(lambda x: x.sum(1) * captured.sum(),
               in_axis_resources=P('x', 'y'),
               out_axis_resources=P(('x', 'y')))
  outer = pjit(lambda x: inner(jnp.sin(x * 4 + 2)),
               in_axis_resources=P('x', None),
               out_axis_resources=P(('x', 'y')))
  primal = jnp.arange(16, dtype=jnp.float32).reshape((4, 4)) / 100
  jtu.check_grads(outer, (primal,), order=2)
def testNonDivisibleConstraint(self, mesh, resources):
  """A constrained dimension not divisible by the mesh size must raise."""
  operand = jnp.ones((3, 2))
  spec = P(resources,)
  # Product of the second entry of every mesh item, rendered for the regex.
  axis_sizes = [entry[1] for entry in mesh]
  mesh_size = str(np.prod(axis_sizes, dtype=np.int64))
  expected_msg = (r"One of with_sharding_constraint arguments"
                  r".*" + spec_regex(spec) + r".*implies that the size of "
                  r"its dimension 0 should be divisible by " + mesh_size +
                  r", but it is equal to 3")
  with self.assertRaisesRegex(ValueError, expected_msg):
    pjit(lambda x: with_sharding_constraint(x, spec),
         in_axis_resources=None,
         out_axis_resources=None)(operand)
def test_from_gda_duplicates(self):
  """A second FROM_GDA singleton instance is accepted as in_axis_resources."""
  global_mesh = create_global_mesh((1, 2), ('x', 'y'))
  input_gda = create_gda((8, 2), global_mesh, ['x', 'y'])

  # It's occasionally possible to end up with two FROM_GDA singletons (e.g. if
  # pickling in_axis_resources and sending to other processes). Make sure
  # this doesn't cause an error to avoid user confusion.
  duplicate_sentinel = pjit_lib._FromGsdaSingleton()
  with mesh(global_mesh.devices, global_mesh.axis_names):
    identity = pjit(lambda x: x,
                    in_axis_resources=duplicate_sentinel,
                    out_axis_resources=None)
    identity(input_gda)
def testGradOfConstraint(self):
  """Gradients can be computed through a sharding constraint."""

  def constrained_sin_sum(x):
    return jnp.sin(with_sharding_constraint(x, P('x'))).sum()

  grad_fn = pjit(lambda x: jax.grad(constrained_sin_sum)(x),
                 in_axis_resources=None,
                 out_axis_resources=None)
  operand = jnp.arange(8, dtype=jnp.float32)
  # d/dx sum(sin(x)) == cos(x).
  self.assertAllClose(grad_fn(operand), jnp.cos(operand))
def testShardingInXMap(self):
  """Inspect the in_axis_resources the pjit translation rule sees under xmap."""
  inner = pjit(lambda x: x, in_axis_resources=P('x'), out_axis_resources=None)
  mapped = xmap(lambda x: inner(x * 2),
                in_axes=['i', ...],
                out_axes=['i', ...],
                axis_resources={'i': 'y'})
  operand = jnp.arange(16).reshape((4, 4))

  self.assertIn(pjit_p, xla.call_translations)
  original_rule = xla.call_translations[pjit_p]
  rule_was_called = False

  def spying_rule(*args, **kwargs):
    # Record the call, check the resources, then defer to the real rule.
    nonlocal rule_was_called
    rule_was_called = True
    resources = kwargs['in_axis_resources']
    self.assertEqual(len(resources), 1)
    self.assertIn(('y', ), resources[0].partitions)
    return original_rule(*args, **kwargs)

  try:
    xla.call_translations[pjit_p] = spying_rule
    mapped(operand)
    self.assertTrue(rule_was_called)
  finally:
    # Always restore the real rule so other tests are unaffected.
    xla.call_translations[pjit_p] = original_rule
def testCatchesInnerXMapErrors(self):
  """Calling a pjit-wrapped xmap with this axis mapping raises JAXTypeError."""
  # Both named axes are assigned to the same resource 'x'.
  inner = xmap(lambda x, y: x,
               in_axes=(['i'], ['j']),
               out_axes=['i', 'j'],
               axis_resources={'i': 'x', 'j': 'x'})
  wrapped = pjit(inner, in_axis_resources=None, out_axis_resources=None)
  operand = jnp.arange(4)
  with self.assertRaises(JAXTypeError):
    wrapped(operand, operand)
def testNoopPartitionSpecs(self):
  """Specs made only of empty/None entries leave the computed values intact."""
  noop_specs = (P(), P(None), P(()), P((), None), P(None, None, ()))
  operand = jnp.arange(8).reshape((2, 2, 2))
  for spec in noop_specs:
    doubled = pjit(lambda x: x * 2,
                   in_axis_resources=spec,
                   out_axis_resources=spec)(operand)
    self.assertAllClose(doubled, operand * 2)
def testVMap(self):
  """vmap over a pjit yields correct values and the expected shardings."""
  add_and_passthrough = pjit(lambda x, y: (x + y, x),
                             in_axis_resources=P('x'),
                             out_axis_resources=P('x'))
  x = jnp.arange(4)
  y = jnp.arange(5*4).reshape((5, 4))
  batched = jax.vmap(add_and_passthrough, in_axes=(None, 0), out_axes=(0, None))
  z, w = batched(x, y)
  self.assertAllClose(z, x + y)
  self.assertAllClose(w, x)
  # z gained an unsharded batch dimension in front; w stayed unbatched.
  expected_z_sharding = (pxla.NoSharding(), pxla.Chunked([2]))
  expected_w_sharding = (pxla.Chunked([2]),)
  self.assertEqual(z.sharding_spec.sharding, expected_z_sharding)
  self.assertEqual(w.sharding_spec.sharding, expected_w_sharding)
def _matrix_inverse_pth_root_pjit(xs, ps):
  """Run `_matrix_inverse_pth_root_vmap` on inputs partitioned over the mesh.

  Args:
    xs: First input, forwarded to `_matrix_inverse_pth_root_vmap` after
      partitioning.
    ps: Second input, forwarded the same way.

  Returns:
    A `(preconditioners, errors)` pair, recombined after the sharded run.
  """
  # A single spec that splits the leading dimension over every mesh axis.
  all_axes = tuple(mesh_axis_names)
  sharded_spec = pjit.PartitionSpec(all_axes,)

  # Partition the concatenated statistics matrix across all cores.
  scatter = pjit.pjit(
      lambda x, y: (x, y),
      in_axis_resources=None,
      out_axis_resources=sharded_spec)
  sharded_xs, sharded_ps = scatter(xs, ps)

  # Run matrix inverse pth root on each shard.
  sharded_preconditioners, sharded_errors = _matrix_inverse_pth_root_vmap(
      sharded_xs, sharded_ps)

  # Recombine the outputs at each core.
  gather = pjit.pjit(
      lambda x, y: (x, y),
      in_axis_resources=(sharded_spec, sharded_spec),
      out_axis_resources=(None, None))
  preconditioners, errors = gather(sharded_preconditioners, sharded_errors)
  return preconditioners, errors
def testVMapShardingConstraint(self):
  """vmap adjusts the axis resources of a sharding constraint inside pjit."""
  constrained = pjit(lambda x: with_sharding_constraint(x, P('x')),
                     in_axis_resources=P(),
                     out_axis_resources=P('x'))
  operand = jnp.arange(5*4).reshape((5, 4))
  jaxpr = jax.make_jaxpr(jax.vmap(constrained))(operand)
  # Exactly one pjit equation outside and one constraint equation inside.
  [pjit_eqn] = jaxpr.eqns
  [constraint_eqn] = pjit_eqn.params['jaxpr'].eqns
  resources = constraint_eqn.params['axis_resources']
  self.assertEqual(resources.partitions, ((), ('x',)))
  self.assertEqual(resources.sync, SpecSync.DIM_PERMUTE)
def testEvalJaxpr(self):
  """A jaxpr traced from a pjit function evaluates to the expected value."""
  x = jnp.arange(4)
  y = jnp.arange(5)
  sharded_fn = pjit(lambda x, y: x.sum() + jnp.sin(y),
                    in_axis_resources=(P('x'), P('y')),
                    out_axis_resources=P('y'))
  closed_jaxpr = jax.make_jaxpr(sharded_fn)(x, y)
  evaluate = jax.core.jaxpr_as_fun(closed_jaxpr)
  # The evaluated jaxpr returns a sequence holding exactly one output.
  [result] = evaluate(x, y)
  self.assertAllClose(result, x.sum() + jnp.sin(y))
def testLowerDonateArgnumsAvailable(self):
  """donate_argnums is reported by both the lowered and compiled pjit."""
  x = jax.ShapeDtypeStruct((2, 2), jnp.float32)

  def f(*args):
    x, *_ = args
    return x

  f_low = pjit(f, donate_argnums=(0,),
               in_axis_resources=P('x'),
               out_axis_resources=P('x')).lower(x)
  f_com = f_low.compile()
  # Assert explicitly: a bare `a == b == c` expression statement is evaluated
  # and then discarded, so it would never fail the test.
  self.assertEqual(f_low.donate_argnums, (0,))
  self.assertEqual(f_com.donate_argnums, (0,))
def testNonHashableAxisResources(self):
  """pjit accepts dict-valued (non-hashable) axis resource pytrees."""
  x = jnp.arange(4)
  in_specs = ({'a': P('x')},)
  out_specs = {'b': P('x')}
  add_two = pjit(lambda x: {'b': x['a'] + 2},
                 in_axis_resources=in_specs,
                 out_axis_resources=out_specs)
  result = add_two({'a': x})
  self.assertAllClose(result, {'b': x + 2})
def testOutputShardsXMapAxis(self):
  """A pjit output spec reusing a mesh axis that xmap already uses must raise."""
  spec = P('x')
  inner = pjit(lambda x: x + 2, in_axis_resources=None, out_axis_resources=spec)
  mapped = xmap(inner,
                in_axes=['i', ...],
                out_axes=['i', ...],
                axis_resources={'i': 'x'})
  operand = jnp.arange(4).reshape((2, 2))
  expected_msg = (
      r"pjit output has an axis resources specification of " +
      spec_regex(spec) + r" that uses one or more mesh axes already used by "
      r"xmap to partition a named axis appearing in its named_shape \(both "
      r"use mesh axes `x`\)")
  with self.assertRaisesRegex(JAXTypeError, expected_msg):
    mapped(operand)
def test_pjit_inherits_effects(self):
  """Tracing a pjit whose body binds effects raises NotImplementedError."""
  if jax.default_backend() not in {'gpu', 'tpu'}:
    raise unittest.SkipTest("pjit only supports GPU and TPU backends")

  def effectful(x):
    effect_p.bind(effect='foo')
    effect_p.bind(effect='bar')
    return x

  spec = pjit.PartitionSpec('x')
  wrapped = pjit.pjit(effectful, in_axis_resources=spec, out_axis_resources=spec)
  # The mesh is created and entered inside the assertRaisesRegex context.
  with self.assertRaisesRegex(NotImplementedError, 'Effects not supported'), \
       maps.Mesh(np.array(jax.devices()), ['x']):
    jax.make_jaxpr(wrapped)(jnp.arange(jax.local_device_count()))
def testAxisResourcesMismatch(self):
  """Axis resources that are not a tree prefix of the values raise ValueError."""
  scalar = jnp.ones([])
  triple = [None, None, None]

  # OK: the input spec is wrapped in a tuple representing the argument list.
  pjit(lambda x: x, (triple,), triple)([scalar, scalar, scalar])

  in_error = re.escape(
      r"pjit in_axis_resources specification must be a tree prefix of the "
      r"corresponding value, got specification (None, None, None) for value "
      r"tree PyTreeDef((*, *)). Note that pjit in_axis_resources that are "
      r"non-trivial pytrees should always be wrapped in a tuple representing "
      r"the argument list.")
  with self.assertRaisesRegex(ValueError, in_error):
    # Error, but make sure we hint at tupling
    pjit(lambda x, y: x, triple, triple)(scalar, scalar)

  # TODO(apaszke): Disable implicit list casts and enable this
  # error = re.escape(
  # r"pjit in_axis_resources specification must be a tree prefix of the "
  # r"corresponding value, got specification (None, None, None) for value "
  # r"tree PyTreeDef(([*, *, *],)). Note that pjit in_axis_resources that "
  # r"are non-trivial pytrees should always be wrapped in a tuple representing "
  # r"the argument list. In particular, you're passing in a single argument "
  # r"which means that pjit in_axis_resources might need to be wrapped in a "
  # r"singleton tuple.")
  # with self.assertRaisesRegex(ValueError, error):
  # pjit(lambda x: x, p, p)([x, x, x]) # Error, but make sure we hint at singleton tuple

  out_error = re.escape(
      r"pjit out_axis_resources specification must be a tree prefix of the "
      r"corresponding value, got specification [[None, None, None], None] for "
      r"value tree PyTreeDef([*, *, *]).")
  with self.assertRaisesRegex(ValueError, out_error):
    # Error, we raise a generic tree mismatch message
    pjit(lambda x: x, (triple,), [triple, None])([scalar, scalar, scalar])
def test_pjit_inherits_effects(self):
  """Tracing a pjit whose body binds effects raises NotImplementedError."""

  def effectful(x):
    effect_p.bind(effect='foo')
    effect_p.bind(effect='bar')
    return x

  spec = pjit.PartitionSpec('x')
  wrapped = pjit.pjit(effectful, in_axis_resources=spec, out_axis_resources=spec)
  # The mesh is created and entered inside the assertRaisesRegex context.
  with self.assertRaisesRegex(NotImplementedError, 'Effects not supported'), \
       maps.Mesh(np.array(jax.devices()), ['x']):
    jax.make_jaxpr(wrapped)(jnp.arange(jax.local_device_count()))
def testPjit(self):
  """Evaluating 0. / x at x == 0 under pjit must raise FloatingPointError."""
  if jax.device_count() < 2:
    raise SkipTest("test requires >=2 devices")

  spec = jax.experimental.PartitionSpec('x')
  zero_over = pjit.pjit(lambda x: 0. / x,
                        in_axis_resources=spec,
                        out_axis_resources=spec)
  two_devices = np.array(jax.local_devices()[:2])

  with jax.experimental.maps.mesh(two_devices, ('x',)):
    with self.assertRaises(FloatingPointError):
      ans = zero_over(jnp.array([0., 1.]))
      # Wait for the result inside the context as well.
      ans.block_until_ready()
def testVmapModifiesAxisResources(self):
  """vmap marks the axis resources of batched pjit operands as DIM_PERMUTE."""
  fn = pjit(lambda x, y: (x + y, x, y),
            in_axis_resources=P('x'),
            out_axis_resources=None)
  x = jnp.arange(4)
  y = jnp.arange(5*4).reshape((5, 4))
  batched = jax.vmap(fn, in_axes=(None, 0))
  first_eqn = jax.make_jaxpr(batched)(x, y).jaxpr.eqns[0]
  self.assertIs(first_eqn.primitive, pjit_p)

  # Inputs: x is not batched, y is.
  in_x_sync, in_y_sync = [
      spec.sync for spec in first_eqn.params['in_axis_resources']]
  self.assertEqual(in_x_sync, SpecSync.IN_SYNC)
  self.assertEqual(in_y_sync, SpecSync.DIM_PERMUTE)

  # Outputs, in order: x + y, x, y.
  out_sum_sync, out_x_sync, out_y_sync = [
      spec.sync for spec in first_eqn.params['out_axis_resources']]
  self.assertEqual(out_sum_sync, SpecSync.DIM_PERMUTE)
  self.assertEqual(out_x_sync, SpecSync.IN_SYNC)
  self.assertEqual(out_y_sync, SpecSync.DIM_PERMUTE)
def testBufferDonation(self):
  """Only the argument listed in donate_argnums is deleted after the call."""

  def add(x, y):
    return x + y

  donating_add = pjit(add,
                      in_axis_resources=P('x'),
                      out_axis_resources=P('x'),
                      donate_argnums=0)
  to_sharded = pjit(lambda x: x,
                    in_axis_resources=P('x'),
                    out_axis_resources=P('x'))
  x = to_sharded(jnp.ones((2, 5)) * 4)
  y = to_sharded(jnp.ones((2, 5)) * 2)
  # Compute the reference before the donating call, while x is still needed.
  expected = x + y
  self.assertAllClose(donating_add(x, y), expected)
  self.assertNotDeleted(y)
  self.assertDeleted(x)