def testGradOfConstraint(self):
    # Gradients must be computable through with_sharding_constraint: the
    # function under test is sum(sin(constrained x)), so its gradient is
    # expected to be close to cos(x).

    def constrained_sin_sum(v):
        return jnp.sin(with_sharding_constraint(v, P('x'))).sum()

    def grad_of_constrained(v):
        return jax.grad(constrained_sin_sum)(v)

    grad_fn = pjit(grad_of_constrained,
                   in_axis_resources=None, out_axis_resources=None)
    inputs = jnp.arange(8, dtype=jnp.float32)
    expected = jnp.cos(inputs)
    self.assertAllClose(grad_fn(inputs), expected)
def testRankTooLowConstraint(self):
    # A 1-D operand is constrained with a two-entry spec; the expected
    # ValueError message says the spec implies a rank of at least 2 while
    # the operand's rank is 1.
    spec = P('x', 'y')
    operand = jnp.arange(2)
    expected_message = (r"One of with_sharding_constraint arguments " +
                        r"was given.*" + spec_regex(spec) +
                        r", which implies "
                        r"that it has a rank of at least 2, but it is 1")

    def constrain(value):
        return with_sharding_constraint(value, spec)

    with self.assertRaisesRegex(ValueError, expected_message):
        pjit(constrain,
             in_axis_resources=None, out_axis_resources=None)(operand)
def testUndefinedResourcesConstraint(self, mesh, resources):
    # Constraining with a spec built from `resources` is expected to raise a
    # ValueError whose message reports that resource axis x is undefined.
    # `mesh` is not referenced in this body.
    # NOTE(review): presumably `mesh` is consumed by a decorator that is not
    # visible here -- confirm.
    spec = P(resources,)
    operand = jnp.ones((2, 2))
    expected_message = (r"One of with_sharding_constraint arguments"
                        r".*" + spec_regex(spec) + r", but resource axis "
                        r"x is undefined.")

    def constrain(value):
        return with_sharding_constraint(value, spec)

    with self.assertRaisesRegex(ValueError, expected_message):
        pjit(constrain,
             in_axis_resources=None, out_axis_resources=None)(operand)
def testVMapShardingConstraint(self):
    # Trace vmap of a pjit-wrapped function that applies a sharding
    # constraint, then inspect the constraint equation inside the traced
    # pjit equation.
    constrained_fn = pjit(lambda v: with_sharding_constraint(v, P('x')),
                          in_axis_resources=P(), out_axis_resources=P('x'))
    batch = jnp.arange(5 * 4).reshape((5, 4))
    traced = jax.make_jaxpr(jax.vmap(constrained_fn))(batch)

    # Single-element unpacking doubles as a check that exactly one equation
    # exists at each level.
    [outer_eqn] = traced.eqns
    [inner_eqn] = outer_eqn.params['jaxpr'].eqns

    # After vmap the spec is expected to gain an empty leading entry ahead of
    # the original ('x',) and to be marked DIM_PERMUTE.
    # NOTE(review): presumably the empty entry is the vmapped batch
    # dimension -- confirm.
    resources = inner_eqn.params['axis_resources']
    self.assertEqual(resources.partitions, ((), ('x',)))
    self.assertEqual(resources.sync, SpecSync.DIM_PERMUTE)
def testConstraintShardsXMapAxis(self):
    # xmap assigns the named axis 'i' to mesh axis 'x', and the mapped body
    # also constrains its input with P('x'). Calling the result is expected
    # to raise a JAXTypeError about the mesh axis being used by both.
    spec = P('x')

    def constrain(value):
        return with_sharding_constraint(value, axis_resources=spec)

    mapped = xmap(constrain,
                  in_axes=['i', ...], out_axes=['i', ...],
                  axis_resources={'i': 'x'})
    operand = jnp.arange(4).reshape((2, 2))
    expected_message = (
        r"with_sharding_constraint input has an axis resources specification of " +
        spec_regex(spec) +
        r" that uses one or more mesh axes already used by "
        r"xmap to partition a named axis appearing in its named_shape \(both "
        r"use mesh axes `x`\)")
    with self.assertRaisesRegex(JAXTypeError, expected_message):
        mapped(operand)
def testNonDivisibleConstraint(self, mesh, resources):
    # An operand whose dimension 0 has size 3 is constrained with a spec
    # built from `resources`; the expected ValueError message says that
    # dimension 0 should be divisible by the product computed from `mesh`.
    operand = jnp.ones((3, 2))
    spec = P(resources,)

    # Multiply element 1 of every mesh entry.
    # NOTE(review): presumably each entry is an (axis_name, size) pair --
    # confirm against the caller.
    axis_sizes = [entry[1] for entry in mesh]
    mesh_size = str(np.prod(axis_sizes, dtype=np.int64))

    expected_message = (r"One of with_sharding_constraint arguments"
                        r".*" + spec_regex(spec) +
                        r".*implies that the size of "
                        r"its dimension 0 should be divisible by " +
                        mesh_size + r", but it is equal to 3")

    def constrain(value):
        return with_sharding_constraint(value, spec)

    with self.assertRaisesRegex(ValueError, expected_message):
        pjit(constrain,
             in_axis_resources=None, out_axis_resources=None)(operand)
def f(x):
    """Constrain ``x`` with two specs, copy the result, and double one entry.

    The constraint is given the list ``[P('x', 'y'), P('y', 'x')]``. The
    constrained value is copied with ``.copy()``, the item at ``[0]["a"]``
    of that copy is multiplied by 2, and the copy is returned.
    """
    specs = [P('x', 'y'), P('y', 'x')]
    constrained = with_sharding_constraint(x, specs)
    # NOTE(review): presumably a list whose first element is a dict with key
    # "a", in which case .copy() is a shallow copy -- confirm.
    result = constrained.copy()
    result[0]["a"] *= 2
    return result
def f(x):
    """Return ``(x + 1)``, constrained with ``P('x', 'y')``, times 2."""
    shifted = x + 1
    constrained = with_sharding_constraint(shifted, P('x', 'y'))
    return constrained * 2
def jax_func(x):
    """Tile ``x``, constrain the result, and return its leading quarter.

    ``x`` is tiled twice along axis 0, constrained with ``P("x", "y")``, and
    the first ``shape[0] // 4`` rows of the constrained array are returned.
    For the f32[12, 8] input this function was annotated with, the tiled
    array is f32[24, 8] and the result is f32[6, 8].
    """
    tiled = jnp.tile(x, (2, 1))
    sharded = pjit.with_sharding_constraint(tiled, P("x", "y"))
    quarter = sharded.shape[0] // 4
    return sharded[0:quarter]
def maybe_shard(x, resource):
    """Apply a sharding constraint to ``x``, falling back to ``x`` itself.

    Returns the result of ``with_sharding_constraint(x, resource)``. If that
    call raises ``ValueError``, the exception is printed and ``x`` is
    returned unchanged; other exceptions propagate.
    """
    try:
        constrained = with_sharding_constraint(x, resource)
    except ValueError as err:
        # Best-effort: report the problem and hand back the input as-is.
        print(err)
        return x
    return constrained