Example #1
0
 def testGradOfConstraint(self):
   """Gradients must flow through with_sharding_constraint under pjit."""
   def loss(v):
     # Constrain to mesh axis 'x', then reduce to a scalar so grad applies.
     return jnp.sin(with_sharding_constraint(v, P('x'))).sum()

   def grad_of_loss(v):
     return jax.grad(loss)(v)

   grad_fn = pjit(grad_of_loss,
                  in_axis_resources=None, out_axis_resources=None)
   inputs = jnp.arange(8, dtype=jnp.float32)
   # d/dv sum(sin(v)) is cos(v) elementwise.
   self.assertAllClose(grad_fn(inputs), jnp.cos(inputs))
Example #2
0
 def testRankTooLowConstraint(self):
   """A rank-2 sharding spec applied to a rank-1 array must be rejected."""
   spec = P('x', 'y')
   operand = jnp.arange(2)  # rank 1, while spec names two dimensions
   expected = "".join([
       r"One of with_sharding_constraint arguments was given.*",
       spec_regex(spec),
       r", which implies that it has a rank of at least 2, but it is 1",
   ])
   with self.assertRaisesRegex(ValueError, expected):
     constrained = pjit(lambda v: with_sharding_constraint(v, spec),
                        in_axis_resources=None, out_axis_resources=None)
     constrained(operand)
Example #3
0
 def testUndefinedResourcesConstraint(self, mesh, resources):
   """A spec naming a mesh axis that is not defined must raise ValueError."""
   # NOTE(review): `mesh` is not read here; presumably it is consumed by a
   # parameterizing decorator outside this view — confirm.
   spec = P(resources)
   operand = jnp.ones((2, 2))
   # NOTE(review): the expected message hard-codes axis `x` regardless of
   # `resources`, and the trailing '.' is an unescaped regex metacharacter.
   pattern = (r"One of with_sharding_constraint arguments.*"
              + spec_regex(spec)
              + r", but resource axis x is undefined.")
   with self.assertRaisesRegex(ValueError, pattern):
     pjit(lambda v: with_sharding_constraint(v, spec),
          in_axis_resources=None, out_axis_resources=None)(operand)
Example #4
0
 def testVMapShardingConstraint(self):
   """vmap of a pjit'd constraint leaves the new leading dim unpartitioned."""
   constrained = pjit(lambda v: with_sharding_constraint(v, P('x')),
                      in_axis_resources=P(), out_axis_resources=P('x'))
   batch = jnp.arange(20).reshape((5, 4))
   closed_jaxpr = jax.make_jaxpr(jax.vmap(constrained))(batch)
   # Exactly one pjit equation, which itself wraps exactly one constraint.
   [pjit_eqn] = closed_jaxpr.eqns
   [constraint_eqn] = pjit_eqn.params['jaxpr'].eqns
   axis_resources = constraint_eqn.params['axis_resources']
   # Batch dim gets no mesh axes; the original dim keeps 'x'.
   self.assertEqual(axis_resources.partitions, ((), ('x',)))
   self.assertEqual(axis_resources.sync, SpecSync.DIM_PERMUTE)
Example #5
0
 def testConstraintShardsXMapAxis(self):
   """A constraint may not reuse a mesh axis xmap already uses for a named axis."""
   spec = P('x')
   mapped = xmap(lambda v: with_sharding_constraint(v, axis_resources=spec),
                 in_axes=['i', ...], out_axes=['i', ...],
                 axis_resources={'i': 'x'})
   operand = jnp.arange(4).reshape((2, 2))
   error = "".join([
       r"with_sharding_constraint input has an axis resources specification of ",
       spec_regex(spec),
       r" that uses one or more mesh axes already used by "
       r"xmap to partition a named axis appearing in its named_shape \(both "
       r"use mesh axes `x`\)",
   ])
   with self.assertRaisesRegex(JAXTypeError, error):
     mapped(operand)
Example #6
0
 def testNonDivisibleConstraint(self, mesh, resources):
   """A dim of size 3 that the mesh size does not divide must be rejected."""
   operand = jnp.ones((3, 2))
   spec = P(resources)
   # Product of entry[1] over every entry of `mesh`; presumably each entry is
   # a (name, size) pair — TODO confirm against the parameterizing decorator.
   axis_sizes = [entry[1] for entry in mesh]
   mesh_size = str(np.prod(axis_sizes, dtype=np.int64))
   pattern = (r"One of with_sharding_constraint arguments.*"
              + spec_regex(spec)
              + r".*implies that the size of its dimension 0 should be divisible by "
              + mesh_size
              + r", but it is equal to 3")
   with self.assertRaisesRegex(ValueError, pattern):
     pjit(lambda v: with_sharding_constraint(v, spec),
          in_axis_resources=None, out_axis_resources=None)(operand)
Example #7
0
 def f(x):
     """Constrain `x` with two specs, then double the "a" entry of element 0.

     Presumably `x` is a two-element list whose first element is a dict with
     key "a" (inferred from the spec list and indexing) — confirm with caller.
     """
     specs = [P('x', 'y'), P('y', 'x')]
     constrained = with_sharding_constraint(x, specs)
     result = constrained.copy()  # shallow: only the outer list is copied
     result[0]["a"] *= 2
     return result
Example #8
0
 def f(x):
     """Add one, constrain to P('x', 'y'), and scale the result by two."""
     shifted = x + 1
     constrained = with_sharding_constraint(shifted, P('x', 'y'))
     return constrained * 2
Example #9
0
 def jax_func(x):  # x: f32[12, 8]
     """Tile rows twice, constrain to P("x", "y"), keep the first quarter of rows."""
     tiled = jnp.tile(x, (2, 1))  # f32[24, 8]
     tiled = pjit.with_sharding_constraint(tiled, P("x", "y"))
     quarter = tiled.shape[0] // 4
     return tiled[:quarter]  # res: f32[6, 8]
Example #10
0
def maybe_shard(x, resource):
    """Apply a sharding constraint on a best-effort basis.

    Calls ``with_sharding_constraint(x, resource)``. If that raises
    ``ValueError``, the failure is reported through ``logging`` and ``x`` is
    returned unchanged.

    Args:
        x: The value to constrain.
        resource: The sharding spec forwarded to ``with_sharding_constraint``.

    Returns:
        The constrained value, or ``x`` itself when the constraint was rejected.
    """
    import logging

    try:
        return with_sharding_constraint(x, resource)
    except ValueError as err:
        # Deliberately best-effort: keep going unconstrained, but report why
        # via logging (levelled, filterable) rather than print() to stdout.
        logging.getLogger(__name__).warning(
            "Skipping sharding constraint %r: %s", resource, err)
        return x