Beispiel #1
0
 def testJVP(self):
   # Forward-mode gradient check (up to second order) through nested pjits.
   # The inner pjit closes over a constant to make things more complicated.
   captured = jnp.arange(4)
   inner = pjit(lambda a: a.sum() + captured.sum(),
                in_axis_resources=P('x', 'y'), out_axis_resources=None)
   outer = pjit(lambda a: inner(a + 2),
                in_axis_resources=P('x', None), out_axis_resources=None)
   primal = jnp.arange(16, dtype=jnp.float32).reshape((4, 4))
   jtu.check_grads(outer, (primal,), order=2, modes=["fwd"], eps=1)
Beispiel #2
0
 def testUndefinedResourcesOuts(self, mesh, resources):
   # An output spec built from `resources` is expected to produce a ValueError
   # reporting that resource axis x is undefined.
   operand = jnp.ones((2, 2))
   spec = P(resources,)
   expected_msg = (r"One of pjit outputs.*" + spec_regex(spec) + r", "
                   r"but resource axis x is undefined.")
   with self.assertRaisesRegex(ValueError, expected_msg):
     # Build the pjit inside the context so an error at wrap time also counts.
     identity = pjit(lambda a: a, in_axis_resources=None,
                     out_axis_resources=spec)
     identity(operand)
Beispiel #3
0
 def testRankTooLowOuts(self):
   # A two-entry output spec applied to the rank-0 result of `sum()` is
   # expected to raise a ValueError that mentions the required rank.
   vec = jnp.arange(2)
   spec = P('x', 'y')
   expected_msg = (r"One of pjit outputs.*" + spec_regex(spec) +
                   r", which implies that it has a rank of at least 2, "
                   r"but it is 0")
   with self.assertRaisesRegex(ValueError, expected_msg):
     reduce_all = pjit(lambda a: a.sum(), in_axis_resources=None,
                       out_axis_resources=spec)
     reduce_all(vec)
Beispiel #4
0
 def testEmptyMesh(self):
     # No mesh is set up in this method; calling a pjit-ed function is
     # expected to raise a RuntimeError asking for a non-empty mesh.
     expected_msg = (
         r"pjit requires a non-empty mesh! Are you sure that it's defined "
         r"at the call site?")
     with self.assertRaisesRegex(RuntimeError, expected_msg):
         identity = pjit(lambda a: a,
                         in_axis_resources=None,
                         out_axis_resources=None)
         identity(jnp.arange(4))
Beispiel #5
0
 def testWithCustomPRNGKey(self):
     # Smoke test: pass a key created with an explicit PRNG implementation
     # through an identity pjit.
     if not config.jax_enable_custom_prng:
         raise unittest.SkipTest("test requires jax_enable_custom_prng")
     rbg_key = jax.prng.seed_with_impl(jax.prng.rbg_prng_impl, 87)
     identity = pjit(lambda k: k,
                     in_axis_resources=None,
                     out_axis_resources=None)
     # Make sure this doesn't crash; the result is intentionally unused.
     identity(rbg_key)
Beispiel #6
0
 def testRepeatedOutResources(self):
   # Output specs that mention mesh axis `x` more than once are expected to
   # raise a ValueError about duplicate entries.
   operand = jnp.arange(2)
   duplicated_specs = (P('x', 'x'), P('x', ('y', 'x')))
   for spec in duplicated_specs:
     expected_msg = (
         r"A single out_axis_resources specification can map every mesh "
         r"axis to at most one positional dimension, but " +
         spec_regex(spec) + " has duplicate entries for `x`")
     with self.assertRaisesRegex(ValueError, expected_msg):
       # Construction stays inside the context: the error may be raised
       # either when wrapping or when calling.
       identity = pjit(lambda a: a, in_axis_resources=None,
                       out_axis_resources=spec)
       identity(operand)
Beispiel #7
0
 def testUndefinedResourcesConstraint(self, mesh, resources):
   # A with_sharding_constraint spec built from `resources` is expected to
   # produce a ValueError reporting that resource axis x is undefined.
   operand = jnp.ones((2, 2))
   spec = P(resources,)
   expected_msg = (r"One of with_sharding_constraint arguments"
                   r".*" + spec_regex(spec) + r", but resource axis "
                   r"x is undefined.")
   with self.assertRaisesRegex(ValueError, expected_msg):
     constrained = pjit(lambda a: with_sharding_constraint(a, spec),
                        in_axis_resources=None, out_axis_resources=None)
     constrained(operand)
Beispiel #8
0
 def testNested(self):
   # Value check for a pjit called from inside another pjit. The inner pjit
   # closes over a constant to make things more complicated.
   captured = jnp.arange(4)
   inner = pjit(lambda a: a.sum() + captured.sum(),
                in_axis_resources=P('x', 'y'), out_axis_resources=None)
   outer = pjit(lambda a: inner(jnp.sin(a)),
                in_axis_resources=P('x', None), out_axis_resources=None)
   operand = jnp.arange(16).reshape((4, 4))
   result = outer(operand)
   expected = jnp.sin(operand).sum() + captured.sum()
   self.assertAllClose(result, expected)
   # The result must expose a `sharding_spec` attribute.
   self.assertTrue(hasattr(result, "sharding_spec"))
Beispiel #9
0
 def testRankTooLowConstraint(self):
   # A two-entry constraint spec applied to a rank-1 array is expected to
   # raise a ValueError that mentions the required rank.
   vec = jnp.arange(2)
   spec = P('x', 'y')
   expected_msg = (r"One of with_sharding_constraint arguments " +
                   r"was given.*" + spec_regex(spec) +
                   r", which implies that it has a rank of at least 2, "
                   r"but it is 1")
   with self.assertRaisesRegex(ValueError, expected_msg):
     constrained = pjit(lambda a: with_sharding_constraint(a, spec),
                        in_axis_resources=None, out_axis_resources=None)
     constrained(vec)
Beispiel #10
0
 def testNonDivisibleOuts(self, mesh, resources):
   # Dimension 0 of the (3, 2) output is sharded over `resources`; a
   # ValueError stating that 3 is not divisible by the mesh size is expected.
   operand = jnp.ones((3, 2))
   spec = P(resources, None)
   # The expected divisor is the product of the second item of every mesh entry.
   axis_sizes = [axis[1] for axis in mesh]
   mesh_size = str(np.prod(axis_sizes, dtype=np.int64))
   expected_msg = (r"One of pjit outputs.*" + spec_regex(spec) + r".*"
                   r"implies that the size of its dimension 0 should be "
                   r"divisible by " + mesh_size + r", but it is equal to 3")
   with self.assertRaisesRegex(ValueError, expected_msg):
     identity = pjit(lambda a: a, in_axis_resources=None,
                     out_axis_resources=P(resources, None))
     identity(operand)
Beispiel #11
0
 def testAutodiff(self, mesh, resources):
   # Gradient check (up to second order) through nested pjits. Only runs for
   # two-axis meshes; any other parameterization returns immediately.
   if len(mesh) != 2:
     return
   assert resources == ('x', 'y')
   # The inner pjit closes over a constant to make things more complicated.
   captured = jnp.arange(4)
   inner = pjit(lambda a: a.sum(1) * captured.sum(),
                in_axis_resources=P('x', 'y'),
                out_axis_resources=P(('x', 'y')))
   outer = pjit(lambda a: inner(jnp.sin(a * 4 + 2)),
                in_axis_resources=P('x', None),
                out_axis_resources=P(('x', 'y')))
   primal = jnp.arange(16, dtype=jnp.float32).reshape((4, 4)) / 100
   jtu.check_grads(outer, (primal,), order=2)
Beispiel #12
0
 def testNonDivisibleConstraint(self, mesh, resources):
   # Dimension 0 of the (3, 2) operand is constrained over `resources`; a
   # ValueError stating that 3 is not divisible by the mesh size is expected.
   operand = jnp.ones((3, 2))
   spec = P(resources,)
   # The expected divisor is the product of the second item of every mesh entry.
   axis_sizes = [axis[1] for axis in mesh]
   mesh_size = str(np.prod(axis_sizes, dtype=np.int64))
   expected_msg = (r"One of with_sharding_constraint arguments"
                   r".*" + spec_regex(spec) + r".*implies that the size of "
                   r"its dimension 0 should be divisible by " + mesh_size +
                   r", but it is equal to 3")
   with self.assertRaisesRegex(ValueError, expected_msg):
     constrained = pjit(lambda a: with_sharding_constraint(a, spec),
                        in_axis_resources=None, out_axis_resources=None)
     constrained(operand)
Beispiel #13
0
  def test_from_gda_duplicates(self):
    # Smoke test: a second instance of the FROM_GDA marker class must be
    # accepted as in_axis_resources.
    axes = ['x', 'y']
    shape = (8, 2)
    gda_mesh = create_global_mesh((1, 2), ('x', 'y'))
    gda_input = create_gda(shape, gda_mesh, axes)

    # It's occasionally possible to end up with two FROM_GDA singletons (e.g.
    # if pickling in_axis_resources and sending to other processes). Make sure
    # this doesn't cause an error to avoid user confusion.
    duplicate_marker = pjit_lib._FromGsdaSingleton()
    with mesh(gda_mesh.devices, gda_mesh.axis_names):
      identity = pjit(lambda a: a, in_axis_resources=duplicate_marker,
                      out_axis_resources=None)
      identity(gda_input)
Beispiel #14
0
 def testGradOfConstraint(self):
   # Make sure that we can compute grads through sharding constraints:
   # d/dx sum(sin(x)) must come out as cos(x).
   def constrained_sin_sum(a):
     return jnp.sin(with_sharding_constraint(a, P('x'))).sum()

   grad_fn = pjit(lambda a: jax.grad(constrained_sin_sum)(a),
                  in_axis_resources=None, out_axis_resources=None)
   operand = jnp.arange(8, dtype=jnp.float32)
   self.assertAllClose(grad_fn(operand), jnp.cos(operand))
Beispiel #15
0
    def testShardingInXMap(self):
        # Temporarily wraps the registered pjit translation rule with a spy
        # that inspects the in_axis_resources it receives when a pjit is
        # called from inside an xmap.
        inner = pjit(lambda a: a,
                     in_axis_resources=P('x'),
                     out_axis_resources=None)
        mapped = xmap(lambda a: inner(a * 2),
                      in_axes=['i', ...],
                      out_axes=['i', ...],
                      axis_resources={'i': 'y'})
        operand = jnp.arange(16).reshape((4, 4))
        self.assertIn(pjit_p, xla.call_translations)
        original_rule = xla.call_translations[pjit_p]
        rule_was_called = False

        def _spy_rule(*args, **kwargs):
            nonlocal rule_was_called
            rule_was_called = True
            resources = kwargs['in_axis_resources']
            self.assertEqual(len(resources), 1)
            self.assertIn(('y', ), resources[0].partitions)
            # Delegate to the real rule so the computation still runs.
            return original_rule(*args, **kwargs)

        try:
            xla.call_translations[pjit_p] = _spy_rule
            mapped(operand)
            self.assertTrue(rule_was_called)
        finally:
            # Always put the real rule back, even if an assertion failed.
            xla.call_translations[pjit_p] = original_rule
Beispiel #16
0
 def testCatchesInnerXMapErrors(self):
   # The inner xmap maps both named axes 'i' and 'j' to mesh axis 'x';
   # calling the pjit-wrapped function is expected to raise JAXTypeError.
   mapped = xmap(lambda a, b: a, in_axes=(['i'], ['j']), out_axes=['i', 'j'],
                 axis_resources={'i': 'x', 'j': 'x'})
   wrapped = pjit(mapped, in_axis_resources=None, out_axis_resources=None)
   vec = jnp.arange(4)
   with self.assertRaises(JAXTypeError):
     wrapped(vec, vec)
Beispiel #17
0
 def testNoopPartitionSpecs(self):
     # Every spec below is treated as a no-op by this test: the pjit-ed
     # result must equal the plain computation.
     noop_specs = [P(), P(None), P(()), P((), None), P(None, None, ())]
     operand = jnp.arange(8).reshape((2, 2, 2))
     for spec in noop_specs:
         double = pjit(lambda a: a * 2,
                       in_axis_resources=spec,
                       out_axis_resources=spec)
         self.assertAllClose(double(operand), operand * 2)
Beispiel #18
0
 def testVMap(self):
   # vmap over a pjit-ed function: check both the values and the sharding
   # specs of the batched and unbatched outputs.
   add_and_pass = pjit(lambda a, b: (a + b, a),
                       in_axis_resources=P('x'), out_axis_resources=P('x'))
   vec = jnp.arange(4)
   batch = jnp.arange(5*4).reshape((5, 4))
   batched = jax.vmap(add_and_pass, in_axes=(None, 0), out_axes=(0, None))
   total, passthrough = batched(vec, batch)
   self.assertAllClose(total, vec + batch)
   self.assertAllClose(passthrough, vec)
   self.assertEqual(total.sharding_spec.sharding,
                    (pxla.NoSharding(), pxla.Chunked([2])))
   self.assertEqual(passthrough.sharding_spec.sharding,
                    (pxla.Chunked([2]),))
 def _matrix_inverse_pth_root_pjit(xs, ps):
   """Run the vmapped matrix inverse pth root on pjit-partitioned inputs.

   Returns the `(preconditioners, errors)` pair produced by the final pjit.
   """
   # A single spec whose first entry is the tuple of all names in
   # `mesh_axis_names`; reused for scattering inputs and gathering outputs.
   shard_spec = pjit.PartitionSpec(tuple(mesh_axis_names),)

   # Partition the concatenated statistics matrix across all cores.
   scatter = pjit.pjit(
       lambda x, y: (x, y),
       in_axis_resources=None,
       out_axis_resources=shard_spec)
   sharded_xs, sharded_ps = scatter(xs, ps)

   # Run matrix inverse pth root on each shard.
   sharded_preconditioners, sharded_errors = _matrix_inverse_pth_root_vmap(
       sharded_xs, sharded_ps)

   # Recombine the outputs at each core.
   gather = pjit.pjit(
       lambda x, y: (x, y),
       in_axis_resources=(shard_spec, shard_spec),
       out_axis_resources=(None, None))
   preconditioners, errors = gather(sharded_preconditioners, sharded_errors)
   return preconditioners, errors
Beispiel #20
0
 def testVMapShardingConstraint(self):
   # Trace vmap(pjit(with_sharding_constraint)) and inspect the axis
   # resources recorded on the constraint equation inside the pjit.
   constrained = pjit(lambda a: with_sharding_constraint(a, P('x')),
                      in_axis_resources=P(), out_axis_resources=P('x'))
   batch = jnp.arange(5*4).reshape((5, 4))
   traced = jax.make_jaxpr(jax.vmap(constrained))(batch)
   # Each unpacking also asserts that exactly one equation is present.
   [pjit_eqn] = traced.eqns
   [constraint_eqn] = pjit_eqn.params['jaxpr'].eqns
   constraint_spec = constraint_eqn.params['axis_resources']
   self.assertEqual(constraint_spec.partitions, ((), ('x',)))
   self.assertEqual(constraint_spec.sync, SpecSync.DIM_PERMUTE)
Beispiel #21
0
 def testEvalJaxpr(self):
     # A jaxpr traced from a pjit-ed function must be evaluable as a plain
     # function and give the same value as the direct computation.
     lhs = jnp.arange(4)
     rhs = jnp.arange(5)
     sharded = pjit(lambda a, b: a.sum() + jnp.sin(b),
                    in_axis_resources=(P('x'), P('y')),
                    out_axis_resources=P('y'))
     closed_jaxpr = jax.make_jaxpr(sharded)(lhs, rhs)
     evaluate = jax.core.jaxpr_as_fun(closed_jaxpr)
     # The evaluated jaxpr yields a one-element sequence of outputs.
     [result] = evaluate(lhs, rhs)
     self.assertAllClose(result, lhs.sum() + jnp.sin(rhs))
Beispiel #22
0
 def testLowerDonateArgnumsAvailable(self):
   # The donate_argnums passed to pjit should be readable from both the
   # lowered and the compiled stage objects.
   x = jax.ShapeDtypeStruct((2, 2), jnp.float32)
   def f(*args):
     x, *_ = args
     return x
   f_low = pjit(f, donate_argnums=(0,),
                in_axis_resources=P('x'), out_axis_resources=P('x')).lower(x)
   f_com = f_low.compile()
   # Previously this was a bare chained `==` expression whose result was
   # discarded, so the test could never fail; assert each value explicitly.
   self.assertEqual(f_low.donate_argnums, (0,))
   self.assertEqual(f_com.donate_argnums, (0,))
Beispiel #23
0
 def testNonHashableAxisResources(self):
     # Axis resources supplied as dicts (which are not hashable) must be
     # accepted for both inputs and outputs.
     vec = jnp.arange(4)
     in_specs = ({'a': P('x')}, )
     out_specs = {'b': P('x')}
     add_two = pjit(lambda tree: {'b': tree['a'] + 2},
                    in_axis_resources=in_specs,
                    out_axis_resources=out_specs)
     result = add_two({'a': vec})
     self.assertAllClose(result, {'b': vec + 2})
Beispiel #24
0
 def testOutputShardsXMapAxis(self):
   # The pjit output spec and the enclosing xmap both use mesh axis 'x';
   # calling the function is expected to raise JAXTypeError.
   spec = P('x')
   inner = pjit(lambda a: a + 2, in_axis_resources=None,
                out_axis_resources=spec)
   mapped = xmap(inner, in_axes=['i', ...], out_axes=['i', ...],
                 axis_resources={'i': 'x'})
   operand = jnp.arange(4).reshape((2, 2))
   expected_msg = (
       r"pjit output has an axis resources specification of " +
       spec_regex(spec) + r" that uses one or more mesh axes already used by "
       r"xmap to partition a named axis appearing in its named_shape \(both "
       r"use mesh axes `x`\)")
   with self.assertRaisesRegex(JAXTypeError, expected_msg):
     mapped(operand)
Beispiel #25
0
 def test_pjit_inherits_effects(self):
   # Tracing a pjit-ed function that binds effects is expected to raise
   # NotImplementedError. Skipped on backends other than GPU and TPU.
   if jax.default_backend() not in {'gpu', 'tpu'}:
     raise unittest.SkipTest("pjit only supports GPU and TPU backends")

   def effectful(x):
     effect_p.bind(effect='foo')
     effect_p.bind(effect='bar')
     return x

   sharded = pjit.pjit(effectful,
                       in_axis_resources=pjit.PartitionSpec('x'),
                       out_axis_resources=pjit.PartitionSpec('x'))
   with self.assertRaisesRegex(NotImplementedError, 'Effects not supported'):
     with maps.Mesh(np.array(jax.devices()), ['x']):
       operand = jnp.arange(jax.local_device_count())
       jax.make_jaxpr(sharded)(operand)
Beispiel #26
0
 def testAxisResourcesMismatch(self):
   # Checks the error messages produced when in/out axis resources are not a
   # tree prefix of the corresponding argument or result.
   scalar = jnp.ones([])
   none_triple = [None, None, None]
   # Spec wrapped in a tuple for the single list argument: OK.
   pjit(lambda a: a, (none_triple,), none_triple)([scalar, scalar, scalar])

   in_tree_msg = re.escape(
       r"pjit in_axis_resources specification must be a tree prefix of the "
       r"corresponding value, got specification (None, None, None) for value "
       r"tree PyTreeDef((*, *)). Note that pjit in_axis_resources that are "
       r"non-trivial pytrees should always be wrapped in a tuple representing "
       r"the argument list.")
   with self.assertRaisesRegex(ValueError, in_tree_msg):
     # Error, but make sure we hint at tupling
     pjit(lambda a, b: a, none_triple, none_triple)(scalar, scalar)

   # TODO(apaszke): Disable implicit list casts and enable this
   # in_tree_msg = re.escape(
   # r"pjit in_axis_resources specification must be a tree prefix of the "
   # r"corresponding value, got specification (None, None, None) for value "
   # r"tree PyTreeDef(([*, *, *],)). Note that pjit in_axis_resources that "
   # r"are non-trivial pytrees should always be wrapped in a tuple representing "
   # r"the argument list. In particular, you're passing in a single argument "
   # r"which means that pjit in_axis_resources might need to be wrapped in a "
   # r"singleton tuple.")
   # with self.assertRaisesRegex(ValueError, in_tree_msg):
   # # Error, but make sure we hint at singleton tuple
   # pjit(lambda a: a, none_triple, none_triple)([scalar, scalar, scalar])

   out_tree_msg = re.escape(
       r"pjit out_axis_resources specification must be a tree prefix of the "
       r"corresponding value, got specification [[None, None, None], None] for "
       r"value tree PyTreeDef([*, *, *]).")
   with self.assertRaisesRegex(ValueError, out_tree_msg):
     # Error, we raise a generic tree mismatch message
     pjit(lambda a: a, (none_triple,),
          [none_triple, None])([scalar, scalar, scalar])
Beispiel #27
0
    def test_pjit_inherits_effects(self):
        # Tracing a pjit-ed function that binds effects is expected to raise
        # NotImplementedError.
        def effectful(x):
            effect_p.bind(effect='foo')
            effect_p.bind(effect='bar')
            return x

        sharded = pjit.pjit(effectful,
                            in_axis_resources=pjit.PartitionSpec('x'),
                            out_axis_resources=pjit.PartitionSpec('x'))
        with self.assertRaisesRegex(NotImplementedError,
                                    'Effects not supported'):
            with maps.Mesh(np.array(jax.devices()), ['x']):
                operand = jnp.arange(jax.local_device_count())
                jax.make_jaxpr(sharded)(operand)
Beispiel #28
0
  def testPjit(self):
    # Computing 0. / x for an x that contains 0. under pjit is expected to
    # raise FloatingPointError (presumably NaN checking is enabled by the
    # surrounding test setup, which is not visible here).
    if jax.device_count() < 2:
      raise SkipTest("test requires >=2 devices")

    spec = jax.experimental.PartitionSpec('x')
    zero_over = pjit.pjit(lambda a: 0. / a,
                          in_axis_resources=spec,
                          out_axis_resources=spec)

    two_devices = np.array(jax.local_devices()[:2])
    with jax.experimental.maps.mesh(two_devices, ('x',)):
      with self.assertRaises(FloatingPointError):
        result = zero_over(jnp.array([0., 1.]))
        # Wait for the computation so a deferred error surfaces here.
        result.block_until_ready()
Beispiel #29
0
 def testVmapModifiesAxisResources(self):
   # Trace vmap(pjit) and check the `sync` marker recorded on every input
   # and output spec of the resulting pjit equation.
   sharded = pjit(lambda a, b: (a + b, a, b),
                  in_axis_resources=P('x'), out_axis_resources=None)
   vec = jnp.arange(4)
   batch = jnp.arange(5*4).reshape((5, 4))
   traced = jax.make_jaxpr(jax.vmap(sharded, in_axes=(None, 0)))(vec, batch)
   first_eqn = traced.jaxpr.eqns[0]
   self.assertIs(first_eqn.primitive, pjit_p)

   # Inputs: the unbatched argument first, then the batched one.
   in_syncs = [spec.sync for spec in first_eqn.params['in_axis_resources']]
   unbatched_in, batched_in = in_syncs
   self.assertEqual(unbatched_in, SpecSync.IN_SYNC)
   self.assertEqual(batched_in, SpecSync.DIM_PERMUTE)

   # Outputs, in the order returned by the wrapped function.
   out_syncs = [spec.sync for spec in first_eqn.params['out_axis_resources']]
   first_out, second_out, third_out = out_syncs
   self.assertEqual(first_out, SpecSync.DIM_PERMUTE)
   self.assertEqual(second_out, SpecSync.IN_SYNC)
   self.assertEqual(third_out, SpecSync.DIM_PERMUTE)
Beispiel #30
0
  def testBufferDonation(self):
    # Argument 0 is donated: after the call it must be deleted, while the
    # non-donated argument must remain alive.
    def add(a, b):
      return a + b

    donating_add = pjit(add,
                        in_axis_resources=P('x'),
                        out_axis_resources=P('x'),
                        donate_argnums=0)
    shard = pjit(lambda a: a, in_axis_resources=P('x'),
                 out_axis_resources=P('x'))
    donated = shard(jnp.ones((2, 5)) * 4)
    kept = shard(jnp.ones((2, 5)) * 2)
    # Compute the reference value before the donating call consumes `donated`.
    expected = donated + kept
    self.assertAllClose(donating_add(donated, kept), expected)
    self.assertNotDeleted(kept)
    self.assertDeleted(donated)