Example #1
    def testInfeed(self):
        devices = np.array(jax.local_devices())
        nr_devices = len(devices)
        shape = (nr_devices * 3, nr_devices * 5)

        def f_for_jit(x):
            token = lax.create_token(x)
            (y, ), token = lax.infeed(token,
                                      shape=(jax.ShapedArray(
                                          x.shape, np.float32), ))
            (z, ), token = lax.infeed(token,
                                      shape=(jax.ShapedArray(
                                          x.shape, np.float32), ))
            (w, ), token = lax.infeed(token,
                                      shape=(jax.ShapedArray(
                                          x.shape, np.float32), ))

            return x + y + z + w

        x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
        y = x * 2.
        z = x * 3.
        w = x * 4.

        # Transfer data to infeed before executing the function. For GPUs, the
        # execution of the compiled function is blocking, so transferring data
        # to infeed before executing ensures that the execution does not deadlock
        # waiting for the infeed data.
        logging.info('Transferring to infeed for the jit call')
        d = devices[0]
        d.transfer_to_infeed((y, ))
        d.transfer_to_infeed((z, ))
        d.transfer_to_infeed((w, ))

        # JIT
        logging.info('Making jit call')
        res0 = jax.jit(f_for_jit)(x)
        self.assertAllClose(res0, x + y + z + w, check_dtypes=True)

        # PJIT
        def f_for_pjit(x):
            token = lax.create_token(x)
            # A replicated infeed
            (y, ), token = lax.infeed(token,
                                      shape=(jax.ShapedArray(
                                          x.shape, np.float32), ),
                                      partitions=(None, ))
            # An infeed sharded along the first axis
            (z, ), token = lax.infeed(token,
                                      shape=(jax.ShapedArray(
                                          x.shape, np.float32), ),
                                      partitions=(P(nr_devices, 1), ))
            # An infeed sharded along the second axis
            (w, ), token = lax.infeed(token,
                                      shape=(jax.ShapedArray(
                                          x.shape, np.float32), ),
                                      partitions=(P(1, nr_devices), ))
            return x + y + z + w

        logging.info('Transferring to infeed for the pjit call')
        for didx, d in enumerate(devices):
            # Transfer the whole array to every device for the replicated infeed.
            d.transfer_to_infeed((y, ))
            # For sharded infeed, transfer only the needed slices to each device.
            d.transfer_to_infeed((z[3 * didx:3 * didx + 3, :], ))
            d.transfer_to_infeed((w[:, 5 * didx:5 * didx + 5], ))

        with mesh(devices, ['d']):
            logging.info('Making pjit call')
            res = pjit(f_for_pjit,
                       in_axis_resources=(P('d'), ),
                       out_axis_resources=P('d'))(x)

        self.assertAllClose(res0, res, check_dtypes=True)
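A minimal, self-contained sketch of the same infeed pattern, reduced to one device and one infeed (the imports are assumptions; `jax.ShapedArray` follows the older spelling used above):

import numpy as np
import jax
from jax import lax

def consume(x):
    token = lax.create_token(x)
    # The shapes declared here must match, element for element, the tuple
    # later enqueued with device.transfer_to_infeed.
    (y, ), token = lax.infeed(
        token, shape=(jax.ShapedArray(x.shape, np.float32), ))
    return x + y

x = np.ones((2, 3), dtype=np.float32)
# Enqueue before executing so the computation never waits on an empty queue.
jax.local_devices()[0].transfer_to_infeed((2. * x, ))
print(jax.jit(consume)(x))  # == 3 * x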
Example #2
class PmapOfShardedJitTest(jtu.JaxTestCase):

  def setUp(self):
    super().setUp()
    if jtu.device_under_test() == "gpu":
      os.environ["NCCL_LAUNCH_MODE"] = "PARALLEL"

  # TODO(skye): make a similar version for ShardedJitTest and run the same tests
  def _runTest(self, f, in_partitions, out_partitions, dtype=np.float32):
    """Compares pmap(sharded_jit(f, ...)) to pmap(f)"""
    shape = (2, 4, 4)
    num_shards = shape[0] * np.prod(in_partitions[0])
    if num_shards > jax.local_device_count():
      raise SkipTest("requires %d devices" % num_shards)

    x = np.arange(prod(shape)).reshape(shape)
    y = x + 1
    result = pmap(
        sharded_jit(f, in_parts=in_partitions, out_parts=out_partitions))(x, y)
    expected = pmap(f)(x, y)
    self.assertAllClose(result, expected, check_dtypes=False)

    flat_result = tree_util.tree_flatten(result)[0]
    for r in flat_result:
      self.assertTrue(isinstance(r, pxla.ShardedDeviceArray))
      self.assertEqual(len(r.device_buffers), num_shards)


  @parameterized.named_parameters({
      "testcase_name":
          "_in_parts={}_out_parts={}".format(in_partitions,
                                             out_partitions).replace(" ", ""),
      "in_partitions":
          in_partitions,
      "out_partitions":
          out_partitions
  } for in_partitions in [
      (P(2, 1), P(2, 1)),
      (P(2, 1), P(1, 2)),
      (P(2, 2), P(2, 2)),
      (P(4, 1), P(2, 2)),
  ] for out_partitions in [in_partitions[0], None])
  def testBasic(self, in_partitions, out_partitions):

    def f(x, y):
      return lax.dot(x, y)

    self._runTest(f, in_partitions, out_partitions)

  @parameterized.named_parameters({
      "testcase_name":
          "_in_parts={}_out_parts={}".format(in_partitions,
                                             out_partitions).replace(" ", ""),
      "in_partitions":
          in_partitions,
      "out_partitions":
          out_partitions
  } for in_partitions in [
      (P(2, 1), P(2, 1)),
      (P(2, 1), P(1, 2)),
      (P(4, 1), P(2, 2))
  ] for out_partitions in [(in_partitions[1], in_partitions[0], None),
                           (None, None, None)])
  def testMultipleOutputs(self, in_partitions, out_partitions):

    def f(x, y):
      a = lax.dot(x, y)
      # TODO(skye): use these more interesting outputs once returning constants
      # works
      # return a, a + 1, 3
      return a, a + x, x + y

    self._runTest(f, in_partitions, out_partitions)

  @parameterized.named_parameters({
      "testcase_name":
          "_in_parts={}_out_parts={}".format(in_partitions,
                                             out_partitions).replace(" ", ""),
      "in_partitions":
          in_partitions,
      "out_partitions":
          out_partitions
  } for in_partitions in [
      (P(2, 1), P(2, 1)),
      (P(2, 1), P(1, 2)),
      (P(4, 1), P(2, 2))
  ] for out_partitions in [in_partitions[0], None])
  def testArrayConstants(self, in_partitions, out_partitions):

    def f(x, y):
      a = lax.dot(x, y)
      b = a + jnp.ones(a.shape)
      c = b + jnp.ones(a.shape[0])[jnp.newaxis]
      return c

    self._runTest(f, in_partitions, out_partitions)

  def testPyTreeArgs(self):
    if jax.local_device_count() < 4:
      raise SkipTest("requires 4 devices")

    def f(a, b, c):
      a1, a2 = a
      c1, (c2, c3) = c
      return a1 + a2 + b + c1 + c2 + c3

    def _make_arg(*shape):
      return np.arange(prod(shape)).reshape(shape)

    a = (_make_arg(2, 4, 4), _make_arg(2))
    b = _make_arg(2, 4, 4)
    c = (_make_arg(2), (_make_arg(2, 4, 4), _make_arg(2, 4, 4)))

    in_parts = (None, P(2, 1), (None, P(2, 1)))
    out_parts = P(2, 1)

    result = pmap(sharded_jit(f, in_parts=in_parts, out_parts=out_parts))(a, b, c)
    expected = pmap(f)(a, b, c)

    self.assertAllClose(result, expected, check_dtypes=False)
    self.assertTrue(isinstance(result, pxla.ShardedDeviceArray))
    self.assertEqual(len(result.device_buffers), 4)

  def testPyTreeOutputs(self):
    if jax.local_device_count() < 4:
      raise SkipTest("requires 4 devices")

    def f(x):
      return x + 1, ((x + 2, x + 3), x + 4)

    shape = (2, 4, 4)
    x = np.arange(prod(shape)).reshape(shape)
    in_parts = (P(2, 1),)
    out_parts = (P(2, 1), ((P(1, 2), None), P(2, 1)))

    result = pmap(sharded_jit(f, in_parts=in_parts, out_parts=out_parts))(x)
    expected = pmap(f)(x)

    self.assertAllClose(result, expected, check_dtypes=False)

  def testManyArgs(self):
    if jax.local_device_count() < 4:
      raise SkipTest("requires 4 devices")

    num_args = 200

    def f(*args):
      return jnp.asarray(args).sum()

    shape = (2, 4, 4)
    args = [np.arange(prod(shape)).reshape(shape)] * num_args
    in_partitions = (P(2, 1),) * num_args
    out_partitions = None
    result = pmap(sharded_jit(
        f, in_parts=in_partitions, out_parts=out_partitions))(*args)
    expected = pmap(f)(*args)

    self.assertAllClose(result, expected, check_dtypes=False)
    self.assertTrue(isinstance(result, pxla.ShardedDeviceArray))
    self.assertEqual(len(result.device_buffers), 4)

  def testShardingConstraint(self):
    if jax.local_device_count() < 4:
      raise SkipTest("requires 4 devices")

    @partial(sharded_jit, in_parts=None, out_parts=None)
    def f(x):
      y = jnp.dot(x, x)
      y = with_sharding_constraint(y, P(2,1))
      return y * 2

    def expected_f(x):
      return jnp.dot(x, x) * 2

    shape = (2, 8, 8)
    x = np.arange(prod(shape)).reshape(shape)
    result = pmap(f)(x)
    expected = pmap(expected_f)(x)

    self.assertAllClose(result, expected, check_dtypes=False)
    self.assertIsInstance(result, pxla.ShardedDeviceArray)
    self.assertLen(result.device_buffers, 4)

  def testInAxesNone(self):
    shape = (4, 4)
    replicas = 2
    in_partitions = (P(2, 1), None, None)
    out_partitions = P(2, 1)
    in_axes = (None, None, 0)
    x = y = np.arange(prod(shape), dtype=np.float32).reshape(shape)
    dummy = np.arange(replicas, dtype=np.float32) + 1
    num_shards = replicas * np.prod(in_partitions[0])
    if num_shards > jax.local_device_count():
      raise SkipTest("requires %d devices" % num_shards)

    def f(x, y, _):
      return x @ y

    result = pmap(
        sharded_jit(f, in_parts=in_partitions, out_parts=out_partitions),
        in_axes=in_axes)(x, y, dummy)
    expected = pmap(f, in_axes=in_axes)(x, y, dummy)
    self.assertAllClose(result, expected, check_dtypes=True)
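For reference, a bare `sharded_jit` call without the outer `pmap` that `_runTest` adds, as a sketch (assumes at least two local devices; the import path matches the older JAX versions these tests target):

import numpy as np
from jax import lax
from jax.interpreters.sharded_jit import sharded_jit, PartitionSpec as P

# P(2, 1) splits the first axis of each (4, 4) operand across two devices.
f = sharded_jit(lambda x, y: lax.dot(x, y),
                in_parts=(P(2, 1), P(2, 1)), out_parts=P(2, 1))
x = np.arange(16, dtype=np.float32).reshape((4, 4))
print(f(x, x + 1))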
Example #3
 def f(x):
   y = jnp.dot(x, x)
   y = with_sharding_constraint(y, P(2,1))
   return y * 2
Example #4
 def embedding(x):
     x = maybe_shard(x, P("dp", None))
     return EmbeddingShardV2(config)(x)
Example #5
 def f(x):
   y = x + 1
   y = with_sharding_constraint(y, P(2,1))
   return y * 2
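A sketch of how such a fragment is typically exercised, mirroring `testShardingConstraint` above (assumes two local devices and the older `sharded_jit` API):

import numpy as np
from jax.interpreters.sharded_jit import (
    sharded_jit, with_sharding_constraint, PartitionSpec as P)

def f(x):
    y = x + 1
    # Ask for the intermediate to be sharded 2-ways along the first axis;
    # the constraint takes effect when f runs under sharded_jit/pjit.
    y = with_sharding_constraint(y, P(2, 1))
    return y * 2

x = np.arange(8, dtype=np.float32).reshape((4, 2))
print(sharded_jit(f, in_parts=None, out_parts=None)(x))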
Example #6
class GDATest(jtu.JaxTestCase):

  @parameterized.named_parameters(
      ("mesh_x_y", P("x", "y"),
       # There are more slices but, for convenience, we check only two.
       # The indices + shard_shape + replica_id should be unique enough.
       ((slice(0, 2), slice(0, 1)), (slice(0, 2), slice(1, 2))),
       (2, 1),
       [0, 0, 0, 0, 0, 0, 0, 0], False),
      ("mesh_x", P("x"),
       ((slice(0, 2), slice(None)), (slice(0, 2), slice(None))),
       (2, 2),
       [0, 1, 0, 1, 0, 1, 0, 1], False),
      ("mesh_y", P("y"),
       ((slice(0, 4), slice(None)), (slice(4, 8), slice(None))),
       (4, 2),
       [0, 0, 1, 1, 2, 2, 3, 3], False),
      ("mesh_none_y", P(None, "y"),
       ((slice(None), slice(0, 1)), (slice(None), slice(1, 2))),
       (8, 1),
       [0, 0, 1, 1, 2, 2, 3, 3], False),
      ("mesh_xy", P(("x", "y")),
       ((slice(0, 1), slice(None)), (slice(1, 2), slice(None))),
       (1, 2),
       [0, 0, 0, 0, 0, 0, 0, 0], False),
      ("mesh_fully_replicated", P(),
       ((slice(None), slice(None)), (slice(None), slice(None))),
       (8, 2),
       [0, 1, 2, 3, 4, 5, 6, 7], True),
  )
  def test_gda_2d_shard(self, mesh_axes, expected_index, expected_shard_shape,
                        expected_replica_ids, expected_is_fully_replicated):
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    global_input_data = np.arange(
        prod(global_input_shape)).reshape(global_input_shape)
    def cb(index):
      return global_input_data[index]

    gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
                                          mesh_axes, cb)
    self.assertEqual(gda.ndim, 2)
    self.assertEqual(gda.size, 16)
    self.assertEqual(gda.mesh_axes, mesh_axes)
    self.assertEqual(gda.local_shards[0].index, expected_index[0])
    self.assertArraysEqual(gda.local_data(0),
                           global_input_data[expected_index[0]])
    self.assertEqual(gda.local_shards[1].index, expected_index[1])
    self.assertArraysEqual(gda.local_data(1),
                           global_input_data[expected_index[1]])
    self.assertEqual(gda.local_data(0).shape, expected_shard_shape)
    replica_ids = [i.replica_id for i in gda.local_shards]
    self.assertListEqual(replica_ids, expected_replica_ids)
    self.assertListEqual([i.device.id for i in gda.local_shards],
                         [0, 1, 2, 3, 4, 5, 6, 7])
    self.assertEqual(gda.is_fully_replicated, expected_is_fully_replicated)
    for s in gda.local_shards:
      self.assertEqual(s.data.aval,
                       core.ShapedArray(expected_shard_shape, s.data.dtype))
    for g, l in safe_zip(gda.global_shards, gda.local_shards):
      self.assertEqual(g.device, l.device)
      self.assertEqual(g.index, l.index)
      self.assertEqual(g.replica_id, l.replica_id)
      self.assertEqual(g.data.aval, l.data.aval)
      self.assertArraysEqual(g.data, l.data)


  @parameterized.named_parameters(
      ("mesh_x_y_z", P("x", "y", "z"),
       # There are more slices but, for convenience, we check only two.
       # The indices + shard_shape + replica_id should be unique enough.
       ((slice(0, 4), slice(0, 2), slice(0, 1)), (slice(0, 4), slice(0, 2), slice(1, 2))),
       (4, 2, 1),
       [0, 0, 0, 0, 0, 0, 0, 0]),
      ("mesh_xy_z", P(("x", "y"), "z"),
       ((slice(0, 2), slice(0, 2), slice(None)), (slice(0, 2), slice(2, 4), slice(None))),
       (2, 2, 2),
       [0, 0, 0, 0, 0, 0, 0, 0]),
      ("mesh_z", P("z"),
       ((slice(0, 4), slice(None), slice(None)), (slice(4, 8), slice(None), slice(None))),
       (4, 4, 2),
       [0, 0, 1, 1, 2, 2, 3, 3]),
  )
  def test_gda_3d_shard(self, mesh_axes, expected_index, expected_shard_shape,
                         expected_replica_ids):
    global_mesh = jtu.create_global_mesh((2, 2, 2), ('x', 'y', 'z'))
    global_input_shape = (8, 4, 2)
    global_input_data = np.arange(
        prod(global_input_shape)).reshape(global_input_shape)
    def cb(index):
      return global_input_data[index]

    gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
                                          mesh_axes, cb)
    self.assertEqual(gda.ndim, 3)
    self.assertEqual(gda.size, 64)
    self.assertEqual(gda.local_shards[0].index, expected_index[0])
    self.assertArraysEqual(gda.local_data(0),
                           global_input_data[expected_index[0]])
    self.assertEqual(gda.local_shards[1].index, expected_index[1])
    self.assertArraysEqual(gda.local_data(1),
                           global_input_data[expected_index[1]])
    self.assertEqual(gda.local_data(0).shape, expected_shard_shape)

    replica_ids = [i.replica_id for i in gda.local_shards]
    self.assertListEqual(replica_ids, expected_replica_ids)

  @parameterized.named_parameters(
      ("mesh_x", P("x"),
       # There are more slices but, for convenience, we check only two.
       # The indices + shard_shape + replica_id should be unique enough.
       ((slice(0, 2),), (slice(2, 4),)),
       (2,),
       [0, 0, 0, 0, 0, 0, 0, 0]),
      ("mesh_none", P(),
       ((slice(None),), (slice(None),)),
       (16,),
       [0, 1, 2, 3, 4, 5, 6, 7]),
  )
  def test_gda_1d_shard(self, mesh_axes, expected_index, expected_shard_shape,
                         expected_replica_ids):
    global_mesh = jtu.create_global_mesh((8,), ('x',))
    global_input_shape = (16,)
    global_input_data = np.arange(prod(global_input_shape)).reshape(-1)
    def cb(index):
      return global_input_data[index]

    gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
                                          mesh_axes, cb)
    self.assertEqual(gda.ndim, 1)
    self.assertEqual(gda.size, 16)
    self.assertEqual(gda.local_shards[0].index, expected_index[0])
    self.assertArraysEqual(gda.local_data(0),
                           global_input_data[expected_index[0]])
    self.assertEqual(gda.local_shards[1].index, expected_index[1])
    self.assertArraysEqual(gda.local_data(1),
                           global_input_data[expected_index[1]])
    self.assertEqual(gda.local_data(0).shape, expected_shard_shape)
    replica_ids = [i.replica_id for i in gda.local_shards]
    self.assertListEqual(replica_ids, expected_replica_ids)

  def test_gda_shape_0_1d_mesh(self):
    global_mesh = jtu.create_global_mesh((8,), ('x',))
    global_input_shape = (0,)
    mesh_axes = P(None)
    def cb(index):
      return np.array([])
    gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
                                          mesh_axes, cb)
    self.assertEqual(gda.ndim, 1)
    self.assertEqual(gda.size, 0)
    for i, s in enumerate(gda.local_shards):
      self.assertEqual(s.index, (slice(None),))
      self.assertEqual(s.replica_id, i)
      self.assertArraysEqual(s.data.to_py(), np.array([]))
    self.assertEqual(gda.dtype, np.float32)
    self.assertEqual(
        gda_lib.get_shard_shape(global_input_shape, global_mesh, mesh_axes),
        (0,))


  @parameterized.named_parameters(
      ("mesh_x_y", P("x", "y"),
       # There are more slices but, for convenience, we check only two.
       # The indices + shard_shape + replica_id should be unique enough.
       ((slice(0, 4), slice(0, 1)), (slice(0, 4), slice(1, 2))),
       (4, 1),
       [0, 0, 0, 0]),
  )
  def test_gda_subset_devices(self, mesh_axes, expected_index,
                               expected_shard_shape, expected_replica_ids):
    global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    global_input_data = np.arange(
        prod(global_input_shape)).reshape(global_input_shape)
    def cb(index):
      return global_input_data[index]

    gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
                                          mesh_axes, cb)
    self.assertEqual(gda.local_shards[0].index, expected_index[0])
    self.assertArraysEqual(gda.local_data(0),
                           global_input_data[expected_index[0]])
    self.assertEqual(gda.local_shards[1].index, expected_index[1])
    self.assertArraysEqual(gda.local_data(1),
                           global_input_data[expected_index[1]])
    self.assertEqual(gda.local_data(0).shape, expected_shard_shape)
    replica_ids = [i.replica_id for i in gda.local_shards]
    self.assertListEqual(replica_ids, expected_replica_ids)
    for g, l in safe_zip(gda.global_shards, gda.local_shards):
      self.assertEqual(g.device, l.device)
      self.assertEqual(g.index, l.index)
      self.assertEqual(g.replica_id, l.replica_id)
      self.assertArraysEqual(g.data, l.data)

  def test_gda_batched_callback(self):
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    mesh_axes = P(('x', 'y'))
    global_input_data = np.arange(
        prod(global_input_shape)).reshape(global_input_shape)

    def cb(indices):
      self.assertEqual(len(indices), len(global_mesh.local_devices))
      return [global_input_data[index] for index in indices]

    gda = GlobalDeviceArray.from_batched_callback(
        global_input_shape, global_mesh, mesh_axes, cb)
    expected_first_shard_value = np.array([[0, 1]])
    self.assertArraysEqual(gda.local_data(0).to_py(),
                           expected_first_shard_value)
    expected_second_shard_value = np.array([[2, 3]])
    self.assertArraysEqual(gda.local_data(1).to_py(),
                           expected_second_shard_value)

  def test_gda_batched_callback_with_devices(self):
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    mesh_axes = P('x')
    global_input_data = np.arange(
        prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)

    def cb(cb_inp):
      self.assertLen(cb_inp, 4)
      dbs = []
      for inp in cb_inp:
        index, devices = inp
        self.assertLen(devices, 2)
        array = global_input_data[index]
        dbs.extend([jax.device_put(array, device) for device in devices])
      return dbs

    gda = GlobalDeviceArray.from_batched_callback_with_devices(
        global_input_shape, global_mesh, mesh_axes, cb)
    expected_first_shard_value = np.array([[0, 1], [2, 3]], dtype=np.float32)
    self.assertArraysEqual(gda.local_data(0).to_py(),
                           expected_first_shard_value)
    expected_second_shard_value = np.array([[0, 1], [2, 3]], dtype=np.float32)
    self.assertArraysEqual(gda.local_data(1).to_py(),
                           expected_second_shard_value)

  def test_gda_str_repr(self):
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    mesh_axes = P(('x', 'y'))
    global_input_data = np.arange(
        prod(global_input_shape)).reshape(global_input_shape)
    def cb(index):
      return global_input_data[index]
    gda = GlobalDeviceArray.from_callback(
        global_input_shape, global_mesh, mesh_axes, cb)
    self.assertEqual(str(gda),
                     'GlobalDeviceArray(shape=(8, 2), dtype=int32)')
    self.assertEqual(
        repr(gda), ('GlobalDeviceArray(shape=(8, 2), dtype=int32, '
                    "global_mesh_shape={'x': 4, 'y': 2}, "
                    "mesh_axes=PartitionSpec(('x', 'y'),))"))

  def test_gda_equality_raises_not_implemented(self):
    global_mesh = jtu.create_global_mesh((1, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    mesh_axes = P(None,)
    global_input_data = np.arange(
        prod(global_input_shape)).reshape(global_input_shape)
    def cb(index):
      return global_input_data[index]
    input_gda = GlobalDeviceArray.from_callback(
        global_input_shape, global_mesh, mesh_axes, cb)
    same_input_gda = GlobalDeviceArray.from_callback(
        global_input_shape, global_mesh, mesh_axes, cb)
    with self.assertRaisesRegex(NotImplementedError,
        'GlobalDeviceArray equality is intentionally unimplemented.'):
      input_gda == same_input_gda

  def test_mesh_hash(self):
    global_mesh1 = jtu.create_global_mesh((4, 2), ('x', 'y'))

    global_mesh2 = jtu.create_global_mesh((2, 4), ('x', 'y'))

    global_mesh3 = jtu.create_global_mesh((4, 2), ('x', 'y'))

    self.assertNotEqual(hash(global_mesh1), hash(global_mesh2))
    self.assertEqual(hash(global_mesh1), hash(global_mesh3))

  def test_device_mismatch(self):
    devices = jax.devices()
    if len(devices) < 8:
      raise unittest.SkipTest("Test requires 8 global devices.")
    mesh_devices = np.array([[devices[0], devices[2]],
                             [devices[3], devices[1]],
                             [devices[4], devices[6]],
                             [devices[7], devices[5]]])
    global_mesh = Mesh(mesh_devices, ('x', 'y'))
    global_input_shape = (8, 2)
    mesh_axes = P('x', 'y')
    global_input_data = np.arange(
        prod(global_input_shape)).reshape(global_input_shape)
    indices = get_shard_indices(global_input_shape, global_mesh, mesh_axes)

    dbs = [
        jax.device_put(global_input_data[indices[d]], d)
        for d in jax.local_devices()
    ]

    with self.assertRaisesRegex(
        ValueError,
        'The `global_mesh.local_devices` and `device_buffers` device order'):
      GlobalDeviceArray(global_input_shape, global_mesh, mesh_axes, dbs)

  def test_gda_block_until_ready(self):
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    mesh_axes = P(('x', 'y'))
    global_input_data = np.arange(
        prod(global_input_shape)).reshape(global_input_shape)

    def cb(index):
      return global_input_data[index]

    gda = GlobalDeviceArray.from_callback(
        global_input_shape, global_mesh, mesh_axes, cb)

    self.assertTrue(gda.block_until_ready() is gda)
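The construction that all of these tests repeat, reduced to a sketch (assumes 8 devices; the `Mesh` and `PartitionSpec` import paths are from the experimental-GDA era and may differ between versions):

import numpy as np
import jax
from jax.experimental import PartitionSpec as P
from jax.experimental.global_device_array import GlobalDeviceArray
from jax.experimental.maps import Mesh

global_mesh = Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y'))
data = np.arange(16).reshape(8, 2)
# The callback receives each shard's index (a tuple of slices) and returns
# the corresponding piece of the global array.
gda = GlobalDeviceArray.from_callback(data.shape, global_mesh, P('x', 'y'),
                                      lambda index: data[index])
print(gda.local_data(0).shape)  # (2, 1): rows split 4 ways, columns 2 ways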
Example #7
 def f(x):
   x = with_sharding_constraint(x, [P('x', 'y'), P('y', 'x')])
   x = x.copy()
   x[0]["a"] *= 2
   return x
Example #8
 def dispatch():
   with mesh(devices, ['d']):
     logging.info('Making pjit call')
     pjit(f, in_axis_resources=(P('d'),), out_axis_resources=P('d'))(x)
Example #9
  def test_pjit_gda_multi_input_multi_output(self):
    global_mesh = create_global_mesh((4, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    input_data = np.arange(
        prod(global_input_shape)).reshape(global_input_shape)
    def cb(index):
      return input_data[index]

    mesh_axes1 = P('x', 'y')
    gda1 = global_device_array.GlobalDeviceArray.from_callback(
        global_input_shape, global_mesh, mesh_axes1, cb)
    mesh_axes2 = P('x')
    gda2 = global_device_array.GlobalDeviceArray.from_callback(
        global_input_shape, global_mesh, mesh_axes2, cb)
    mesh_axes3 = P(('x', 'y'))
    gda3 = global_device_array.GlobalDeviceArray.from_callback(
        global_input_shape, global_mesh, mesh_axes3, cb)
    mesh_axes4 = P(None)
    gda4 = global_device_array.GlobalDeviceArray.from_callback(
        global_input_shape, global_mesh, mesh_axes4, cb)

    with jax._src.config.parallel_functions_output_gda(True):
      @partial(
          pjit,
          # `FROM_GDA` will be replicated for all the inputs.
          in_axis_resources=FROM_GDA,
          out_axis_resources=(mesh_axes1, mesh_axes4, mesh_axes2, mesh_axes3))
      def f(x, y, z, a):
        return x @ x.T, y, z, a
      out1, out2, out3, out4 = f(gda1, gda2, gda3, gda4)

      self.assertIsInstance(out1, global_device_array.GlobalDeviceArray)
      self.assertEqual(out1.shape, (8, 8))
      self.assertEqual(out1.local_shards[0].data.shape, (2, 4))
      self.assertEqual(out1.local_shards[0].index, (slice(0, 2), slice(0, 4)))
      self.assertEqual(out1.local_shards[1].index, (slice(0, 2), slice(4, 8)))
      self.assertListEqual([s.replica_id for s in out1.local_shards],
                           [0, 0, 0, 0, 0, 0, 0, 0])
      expected_matrix_mul = input_data @ input_data.T
      for s in out1.local_shards:
        self.assertArraysEqual(s.data, expected_matrix_mul[s.index])

      self.assertIsInstance(out2, global_device_array.GlobalDeviceArray)
      self.assertEqual(out2.shape, (8, 2))
      self.assertEqual(out2.local_shards[0].data.shape, (8, 2))
      self.assertEqual(out2.local_shards[0].index, (slice(None), slice(None)))
      self.assertEqual(out2.local_shards[1].index, (slice(None), slice(None)))
      self.assertListEqual([s.replica_id for s in out2.local_shards],
                           [0, 1, 2, 3, 4, 5, 6, 7])
      for s in out2.local_shards:
        self.assertArraysEqual(s.data, input_data)

      self.assertIsInstance(out3, global_device_array.GlobalDeviceArray)
      self.assertEqual(out3.shape, (8, 2))
      self.assertEqual(out3.local_shards[0].data.shape, (2, 2))
      self.assertEqual(out3.local_shards[0].index, (slice(0, 2), slice(None)))
      self.assertEqual(out3.local_shards[1].index, (slice(0, 2), slice(None)))
      self.assertListEqual([s.replica_id for s in out3.local_shards],
                           [0, 1, 0, 1, 0, 1, 0, 1])
      for s in out3.local_shards:
        self.assertArraysEqual(s.data, input_data[s.index])

      self.assertIsInstance(out4, global_device_array.GlobalDeviceArray)
      self.assertEqual(out4.shape, (8, 2))
      self.assertEqual(out4.local_shards[0].data.shape, (1, 2))
      self.assertEqual(out4.local_shards[0].index, (slice(0, 1), slice(None)))
      self.assertEqual(out4.local_shards[1].index, (slice(1, 2), slice(None)))
      self.assertListEqual([s.replica_id for s in out4.local_shards],
                           [0, 0, 0, 0, 0, 0, 0, 0])
      for s in out4.local_shards:
        self.assertArraysEqual(s.data, input_data[s.index])
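The same input/output plumbing as a minimal sketch (assumes 8 devices; whether the call needs an enclosing mesh context depends on the JAX version, so one is shown here as an assumption):

import numpy as np
import jax
from jax.experimental import PartitionSpec as P
from jax.experimental.global_device_array import GlobalDeviceArray
from jax.experimental.maps import Mesh
from jax.experimental.pjit import pjit, FROM_GDA

global_mesh = Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y'))
data = np.arange(16).reshape(8, 2)
gda = GlobalDeviceArray.from_callback(data.shape, global_mesh, P('x', 'y'),
                                      lambda index: data[index])

with jax._src.config.parallel_functions_output_gda(True):
    with global_mesh:
        # FROM_GDA: reuse the sharding each input GDA already carries.
        out = pjit(lambda x: x + 1, in_axis_resources=FROM_GDA,
                   out_axis_resources=P('x', 'y'))(gda)
assert isinstance(out, GlobalDeviceArray)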
Example #10
 def testLowerWithDuckTyping(self):
   x = jax.ShapeDtypeStruct((2, 2), jnp.float32)
   # Make sure this doesn't crash
   pjit(lambda x: x + 4,
        in_axis_resources=P('x'), out_axis_resources=P('x')).lower(x)
Example #11
 def f(x):
   token = lax.create_token(x)
   token = lax.outfeed(token, x, partitions=(None,))
   token = lax.outfeed(token, x, partitions=(P(nr_devices, 1),))
   token = lax.outfeed(token, x, partitions=(P(1, nr_devices),))
   return x
Example #12
 def f(x):
   with mesh(np.array([jax.local_devices()[0]]), ('x',)):
     @partial(pjit, in_axis_resources=P('x'), out_axis_resources=None)
     def h(x):
       return x
     return h(x)
Example #13
class JaxArrayTest(jtu.JaxTestCase):
    @parameterized.named_parameters(
        ("mesh_x_y", P("x", "y")),
        ("mesh_x", P("x")),
        ("mesh_y", P("y")),
        ("mesh_none_y", P(None, "y")),
        ("mesh_xy", P(("x", "y"))),
        ("mesh_fully_replicated", P()),
    )
    def test_jax_array_value(self, mesh_axes):
        with jax._src.config.jax_array(True):
            global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
            input_shape = (8, 2)
            arr, global_data = create_array(
                input_shape,
                sharding.MeshPspecSharding(global_mesh, mesh_axes))
            for s in arr.addressable_shards:
                self.assertLen(s.data._arrays, 1)
                self.assertArraysEqual(s.data._arrays[0], global_data[s.index])
            self.assertArraysEqual(arr._value, global_data)
            self.assertArraysEqual(arr._npy_value, global_data)

    def test_array_delete(self):
        with jax._src.config.jax_array(True):
            global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
            input_shape = (8, 2)
            arr, _ = create_array(
                input_shape,
                sharding.MeshPspecSharding(global_mesh, P('x', 'y')))
            arr.delete()
            with self.assertRaisesRegex(ValueError, 'Array has been deleted.'):
                arr._check_if_deleted()
            self.assertIsNone(arr._npy_value)
            self.assertIsNone(arr._arrays)

    def test_device_put(self):
        with jax._src.config.jax_array(True):
            numpy_array = np.array([1, 2, 3])
            arr = jax.device_put(numpy_array, jax.devices()[0])
            self.assertIsInstance(arr.sharding, sharding.SingleDeviceSharding)
            self.assertArraysEqual(arr, numpy_array)
            self.assertEqual(arr._committed, True)
            for i in arr.addressable_shards:
                self.assertArraysEqual(i.data, numpy_array)
                self.assertEqual(i.device, jax.devices()[0])
                self.assertEqual(i.index, (slice(None), ))

    def test_device_put_array_delete(self):
        with jax._src.config.jax_array(True):
            arr = jax.device_put(np.array([1, 2, 3]), jax.devices()[0])
            arr.delete()
            with self.assertRaisesRegex(ValueError, 'Array has been deleted.'):
                arr._check_if_deleted()
            self.assertIsNone(arr._npy_value)
            self.assertIsNone(arr._arrays)

    def test_array_device_get(self):
        with jax._src.config.jax_array(True):
            global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
            input_shape = (8, 2)
            arr, input_data = create_array(
                input_shape,
                sharding.MeshPspecSharding(global_mesh, P('x', 'y')))
            self.assertArraysEqual(jax.device_get(arr), input_data)

    def test_repr(self):
        with jax._src.config.jax_array(True):
            global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
            input_shape = (8, 2)
            arr, _ = create_array(
                input_shape,
                sharding.MeshPspecSharding(global_mesh, P('x', 'y')))
            repr(arr)  # doesn't crash

    def test_jnp_array(self):
        with jax._src.config.jax_array(True):
            arr = jnp.array([1, 2, 3])
            self.assertIsInstance(arr, array.Array)
            self.assertIsInstance(arr.sharding, sharding.SingleDeviceSharding)
            self.assertEqual(arr._committed, False)

    def test_jnp_array_jit_add(self):
        with jax._src.config.jax_array(True):
            a = jnp.array([1, 2, 3])
            b = jnp.array([4, 5, 6])
            arr = jax.jit(lambda x, y: x + y)(a, b)
            self.assertIsInstance(arr, array.Array)
            self.assertArraysEqual(arr, np.array([5, 7, 9]))
            self.assertIsInstance(arr.sharding, sharding.SingleDeviceSharding)

    def test_jnp_array_jnp_add(self):
        with jax._src.config.jax_array(True):
            arr = jnp.add(jnp.array([1, 2, 3]), jnp.array([4, 5, 6]))
            self.assertIsInstance(arr, array.Array)
            self.assertArraysEqual(arr, np.array([5, 7, 9]))
            self.assertIsInstance(arr.sharding, sharding.SingleDeviceSharding)

    def test_jnp_array_normal_add(self):
        with jax._src.config.jax_array(True):
            a = jnp.array([1, 2, 3])
            b = jnp.array([4, 5, 6])
            arr = a + b
            self.assertIsInstance(arr, array.Array)
            self.assertArraysEqual(arr, np.array([5, 7, 9]))
            self.assertIsInstance(arr.sharding, sharding.SingleDeviceSharding)

    def test_array_sharded_astype(self):
        with jax._src.config.jax_array(True):
            global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
            input_shape = (8, 2)
            arr, input_data = create_array(
                input_shape,
                sharding.MeshPspecSharding(global_mesh, P('x', 'y')))
            arr_float32 = arr.astype(jnp.float32)
            self.assertEqual(arr_float32.dtype, np.float32)
            self.assertArraysEqual(arr_float32, input_data.astype(np.float32))

    def test_jnp_array_astype(self):
        with jax._src.config.jax_array(True):
            arr = jnp.array([1, 2, 3])
            arr_float32 = arr.astype(jnp.float32)
            self.assertEqual(arr_float32.dtype, np.float32)
            self.assertArraysEqual(arr_float32, arr.astype(np.float32))
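All of the tests above share one pattern, sketched here (the `jax._src.config.jax_array` flag and the `array.Array` type are as used in the tests; internal paths like these change between versions):

import jax
import jax.numpy as jnp

with jax._src.config.jax_array(True):
    a = jnp.array([1, 2, 3])
    # Under the flag, jnp/jit produce jax Array values carrying a sharding.
    print(type(a), a.sharding)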
Example #14
 def residual(x, mask):
     out = x + TransformerLayerShardV2(
         config, init_scale=2. / config["layers"])(x, mask)
     return maybe_shard(out, P("dp", None, "mp"))
Example #15
class GDATest(jtu.JaxTestCase):

  @parameterized.named_parameters(
      ("mesh_x_y", ["x", "y"],
       # There are more slices but, for convenience, we check only two.
       # The indices + shard_shape + replica_id should be unique enough.
       ((slice(0, 2), slice(0, 1)), (slice(0, 2), slice(1, 2))),
       (2, 1),
       [0, 0, 0, 0, 0, 0, 0, 0], False),
      ("mesh_x_y_pspec", P("x", "y"),
       ((slice(0, 2), slice(0, 1)), (slice(0, 2), slice(1, 2))),
       (2, 1),
       [0, 0, 0, 0, 0, 0, 0, 0], False),
      ("mesh_x", ["x"],
       ((slice(0, 2), slice(None)), (slice(0, 2), slice(None))),
       (2, 2),
       [0, 1, 0, 1, 0, 1, 0, 1], False),
      ("mesh_y", ["y"],
       ((slice(0, 4), slice(None)), (slice(4, 8), slice(None))),
       (4, 2),
       [0, 0, 1, 1, 2, 2, 3, 3], False),
      ("mesh_none_y", [None, "y"],
       ((slice(None), slice(0, 1)), (slice(None), slice(1, 2))),
       (8, 1),
       [0, 0, 1, 1, 2, 2, 3, 3], False),
      ("mesh_xy", [("x", "y")],
       ((slice(0, 1), slice(None)), (slice(1, 2), slice(None))),
       (1, 2),
       [0, 0, 0, 0, 0, 0, 0, 0], False),
      ("mesh_fully_replicated", [],
       ((slice(None), slice(None)), (slice(None), slice(None))),
       (8, 2),
       [0, 1, 2, 3, 4, 5, 6, 7], True),
  )
  def test_gda_2d_shard(self, mesh_axes, expected_index, expected_shard_shape,
                         expected_replica_ids, expected_is_fully_replicated):
    global_mesh = create_global_mesh((4, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    global_input_data = np.arange(
        prod(global_input_shape)).reshape(global_input_shape)
    def cb(index):
      return global_input_data[index]
    gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
                                          mesh_axes, cb)
    self.assertEqual(gda.local_shards[0].index, expected_index[0])
    self.assertArraysEqual(gda.local_data(0),
                           global_input_data[expected_index[0]])
    self.assertEqual(gda.local_shards[1].index, expected_index[1])
    self.assertArraysEqual(gda.local_data(1),
                           global_input_data[expected_index[1]])
    self.assertEqual(gda.local_data(0).shape, expected_shard_shape)
    replica_ids = [i.replica_id for i in gda.local_shards]
    self.assertListEqual(replica_ids, expected_replica_ids)
    self.assertListEqual([i.device.id for i in gda.local_shards],
                         [0, 1, 2, 3, 4, 5, 6, 7])
    self.assertEqual(gda.is_fully_replicated, expected_is_fully_replicated)
    for s in gda.local_shards:
      self.assertEqual(s.data.aval,
                       core.ShapedArray(expected_shard_shape, s.data.dtype))
    for g, l in safe_zip(gda.global_shards, gda.local_shards):
      self.assertEqual(g.device, l.device)
      self.assertEqual(g.index, l.index)
      self.assertEqual(g.replica_id, l.replica_id)
      self.assertEqual(g.data.aval, l.data.aval)
      self.assertArraysEqual(g.data, l.data)


  @parameterized.named_parameters(
      ("mesh_x_y_z", ["x", "y", "z"],
       # There are more slices but, for convenience, we check only two.
       # The indices + shard_shape + replica_id should be unique enough.
       ((slice(0, 4), slice(0, 2), slice(0, 1)), (slice(0, 4), slice(0, 2), slice(1, 2))),
       (4, 2, 1),
       [0, 0, 0, 0, 0, 0, 0, 0]),
      ("mesh_xy_z", [("x", "y"), "z"],
       ((slice(0, 2), slice(0, 2), slice(None)), (slice(0, 2), slice(2, 4), slice(None))),
       (2, 2, 2),
       [0, 0, 0, 0, 0, 0, 0, 0]),
      ("mesh_z", ["z"],
       ((slice(0, 4), slice(None), slice(None)), (slice(4, 8), slice(None), slice(None))),
       (4, 4, 2),
       [0, 0, 1, 1, 2, 2, 3, 3]),
  )
  def test_gda_3d_shard(self, mesh_axes, expected_index, expected_shard_shape,
                         expected_replica_ids):
    global_mesh = create_global_mesh((2, 2, 2), ('x', 'y', 'z'))
    global_input_shape = (8, 4, 2)
    global_input_data = np.arange(
        prod(global_input_shape)).reshape(global_input_shape)
    def cb(index):
      return global_input_data[index]
    gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
                                          mesh_axes, cb)
    self.assertEqual(gda.local_shards[0].index, expected_index[0])
    self.assertArraysEqual(gda.local_data(0),
                           global_input_data[expected_index[0]])
    self.assertEqual(gda.local_shards[1].index, expected_index[1])
    self.assertArraysEqual(gda.local_data(1),
                           global_input_data[expected_index[1]])
    self.assertEqual(gda.local_data(0).shape, expected_shard_shape)

    replica_ids = [i.replica_id for i in gda.local_shards]
    self.assertListEqual(replica_ids, expected_replica_ids)

  @parameterized.named_parameters(
      ("mesh_x", ["x"],
       # There are more slices but, for convenience, we check only two.
       # The indices + shard_shape + replica_id should be unique enough.
       ((slice(0, 2),), (slice(2, 4),)),
       (2,),
       [0, 0, 0, 0, 0, 0, 0, 0]),
      ("mesh_none", [],
       ((slice(None),), (slice(None),)),
       (16,),
       [0, 1, 2, 3, 4, 5, 6, 7]),
  )
  def test_gda_1d_shard(self, mesh_axes, expected_index, expected_shard_shape,
                         expected_replica_ids):
    global_mesh = create_global_mesh((8,), ('x',))
    global_input_shape = (16,)
    global_input_data = np.arange(prod(global_input_shape)).reshape(-1)
    def cb(index):
      return global_input_data[index]
    gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
                                          mesh_axes, cb)
    self.assertEqual(gda.local_shards[0].index, expected_index[0])
    self.assertArraysEqual(gda.local_data(0),
                           global_input_data[expected_index[0]])
    self.assertEqual(gda.local_shards[1].index, expected_index[1])
    self.assertArraysEqual(gda.local_data(1),
                           global_input_data[expected_index[1]])
    self.assertEqual(gda.local_data(0).shape, expected_shard_shape)
    replica_ids = [i.replica_id for i in gda.local_shards]
    self.assertListEqual(replica_ids, expected_replica_ids)

  @parameterized.named_parameters(
      ("mesh_x_y", ["x", "y"],
       # There are more slices but, for convenience, we check only two.
       # The indices + shard_shape + replica_id should be unique enough.
       ((slice(0, 4), slice(0, 1)), (slice(0, 4), slice(1, 2))),
       (4, 1),
       [0, 0, 0, 0]),
  )
  def test_gda_subset_devices(self, mesh_axes, expected_index,
                               expected_shard_shape, expected_replica_ids):
    global_mesh = create_global_mesh((2, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    global_input_data = np.arange(
        prod(global_input_shape)).reshape(global_input_shape)
    def cb(index):
      return global_input_data[index]
    gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
                                          mesh_axes, cb)
    self.assertEqual(gda.local_shards[0].index, expected_index[0])
    self.assertArraysEqual(gda.local_data(0),
                           global_input_data[expected_index[0]])
    self.assertEqual(gda.local_shards[1].index, expected_index[1])
    self.assertArraysEqual(gda.local_data(1),
                           global_input_data[expected_index[1]])
    self.assertEqual(gda.local_data(0).shape, expected_shard_shape)
    replica_ids = [i.replica_id for i in gda.local_shards]
    self.assertListEqual(replica_ids, expected_replica_ids)
    for g, l in safe_zip(gda.global_shards, gda.local_shards):
      self.assertEqual(g.device, l.device)
      self.assertEqual(g.index, l.index)
      self.assertEqual(g.replica_id, l.replica_id)
      self.assertArraysEqual(g.data, l.data)

  def test_gda_batched_callback(self):
    global_mesh = create_global_mesh((4, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    mesh_axes = [('x', 'y')]
    global_input_data = np.arange(
        prod(global_input_shape)).reshape(global_input_shape)

    def cb(indices):
      self.assertEqual(len(indices), len(global_mesh.local_devices))
      return [global_input_data[index] for index in indices]

    gda = GlobalDeviceArray.from_batched_callback(
        global_input_shape, global_mesh, mesh_axes, cb)
    expected_first_shard_value = np.array([[0, 1]])
    self.assertArraysEqual(gda.local_data(0).to_py(),
                           expected_first_shard_value)
    expected_second_shard_value = np.array([[2, 3]])
    self.assertArraysEqual(gda.local_data(1).to_py(),
                           expected_second_shard_value)

  def test_gda_batched_callback_with_devices(self):
    global_mesh = create_global_mesh((4, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    mesh_axes = ['x']
    global_input_data = np.arange(
        prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)

    def cb(cb_inp):
      self.assertLen(cb_inp, 4)
      dbs = []
      for inp in cb_inp:
        index, devices = inp
        self.assertLen(devices, 2)
        array = global_input_data[index]
        dbs.extend([jax.device_put(array, device) for device in devices])
      return dbs

    gda = GlobalDeviceArray.from_batched_callback_with_devices(
        global_input_shape, global_mesh, mesh_axes, cb)
    expected_first_shard_value = np.array([[0, 1], [2, 3]], dtype=np.float32)
    self.assertArraysEqual(gda.local_data(0).to_py(),
                           expected_first_shard_value)
    expected_second_shard_value = np.array([[0, 1], [2, 3]], dtype=np.float32)
    self.assertArraysEqual(gda.local_data(1).to_py(),
                           expected_second_shard_value)

  def test_gda_str_repr(self):
    global_mesh = create_global_mesh((4, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    mesh_axes = [('x', 'y')]
    global_input_data = np.arange(
        prod(global_input_shape)).reshape(global_input_shape)
    def cb(index):
      return global_input_data[index]
    gda = GlobalDeviceArray.from_callback(
        global_input_shape, global_mesh, mesh_axes, cb)
    self.assertEqual(str(gda),
                     'GlobalDeviceArray(shape=(8, 2), dtype=int32)')
    self.assertEqual(
        repr(gda),
        ("GlobalDeviceArray(shape=(8, 2), dtype=int32, "
        "global_mesh_shape={'x': 4, 'y': 2}, mesh_axes=[('x', 'y')])"))
Example #16
 def testNoopPartitionSpecs(self):
   noops = [P(), P(None), P(()), P((), None), P(None, None, ())]
   x = jnp.arange(8).reshape((2, 2, 2))
   for spec in noops:
     y = pjit(lambda x: x * 2, in_axis_resources=spec, out_axis_resources=spec)(x)
     self.assertAllClose(y, x * 2)
Example #17
    def test_checkpointing(self):
        global_mesh = create_global_mesh((4, 2), ('x', 'y'))
        global_input_shape = (8, 2)
        mesh_axes = P('x', 'y')
        num = util.prod(global_input_shape)

        # First GDA
        global_input_data1 = np.arange(num).reshape(global_input_shape)

        def cb1(index):
            return global_input_data1[index]

        gda1 = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
                                               mesh_axes, cb1)
        ckpt_dir1 = pathlib.Path(self.create_tempdir('first').full_path)

        # Second GDA
        global_input_data2 = np.arange(num,
                                       num + num).reshape(global_input_shape)

        def cb2(index):
            return global_input_data2[index]

        gda2 = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
                                               mesh_axes, cb2)
        ckpt_dir2 = pathlib.Path(self.create_tempdir('second').full_path)

        # Third GDA
        def cb3(index):
            return np.array([])

        global_mesh1d = create_global_mesh((8, ), ('x', ))
        gda3 = GlobalDeviceArray.from_callback((0, ), global_mesh1d, P(None),
                                               cb3)
        ckpt_dir3 = pathlib.Path(self.create_tempdir('third').full_path)

        ckpt_paths = [str(ckpt_dir1), str(ckpt_dir2), str(ckpt_dir3)]
        tspecs = jax.tree_map(serialization.get_tensorstore_spec, ckpt_paths)

        serialization.run_serialization([gda1, gda2, gda3], tspecs)

        m1, m2, m3 = serialization.run_deserialization(
            [global_mesh, global_mesh, global_mesh1d],
            [mesh_axes, P('x'), P(None)], tspecs)

        self.assertArraysEqual(m1.local_shards[0].data.to_py(),
                               np.array([[0], [2]]))
        self.assertArraysEqual(m1.local_shards[1].data.to_py(),
                               np.array([[1], [3]]))
        self.assertEqual(m1.local_shards[0].data.shape, (2, 1))
        self.assertEqual(m1.dtype, np.int32)

        self.assertArraysEqual(m2.local_shards[0].data.to_py(),
                               np.array([[16, 17], [18, 19]]))
        self.assertArraysEqual(m2.local_shards[1].data.to_py(),
                               np.array([[16, 17], [18, 19]]))
        self.assertEqual(m2.local_shards[0].data.shape, (2, 2))
        self.assertEqual(m2.dtype, np.int32)

        for i, s in enumerate(m3.local_shards):
            self.assertEqual(s.index, (slice(None), ))
            self.assertEqual(s.replica_id, i)
            self.assertArraysEqual(s.data.to_py(), np.array([]))
        self.assertEqual(m3.dtype, np.float32)
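The round-trip at the heart of the test, as a sketch (the checkpoint path is hypothetical, the `serialization` import path is an assumption from the GDA era, and `gda1`, `global_mesh`, `mesh_axes` are as constructed above):

from jax.experimental.gda_serialization import serialization

ckpt_path = '/tmp/gda_ckpt'  # hypothetical; the test uses create_tempdir()
tspec = serialization.get_tensorstore_spec(ckpt_path)
serialization.run_serialization([gda1], [tspec])
restored, = serialization.run_deserialization([global_mesh], [mesh_axes],
                                              [tspec])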
Example #18
 def f(x):
   return lax.while_loop(lambda i: i[0,0] < 10.,
                         lambda i: with_sharding_constraint(i + 1., P(2, 1)),
                         x)
Example #19
 def f(x):
   y = x + 1
   y = with_sharding_constraint(y, P('x', 'y'))
   return y * 2
Example #20
 def f(x):
   y = x + 1
   p, vjp_f = vjp(lambda z: jnp.sin(with_sharding_constraint(z, P(2,2))), y)
   return vjp_f(p)
Example #21
 def testNonHashableAxisResources(self):
   x = jnp.arange(4)
   y = pjit(lambda x: {'b': x['a'] + 2},
            in_axis_resources=({'a': P('x')},),
            out_axis_resources={'b': P('x')})({'a': x})
   self.assertAllClose(y, {'b': x + 2})
Example #22
    def __init__(self, config):
        self.config = config
        optimizer = config["optimizer"]

        bf16_optimizer = config.get("bf16_optimizer", False)
        early_cast = config.get("early_cast", False)
        early_collect = config.get("early_collect", True)

        def embedding(x):
            x = maybe_shard(x, P("dp", None))
            return EmbeddingShardV2(config)(x)

        def residual(x, mask):
            out = x + TransformerLayerShardV2(
                config, init_scale=2. / config["layers"])(x, mask)
            return maybe_shard(out, P("dp", None, "mp"))

        def transformer(x, mask):
            return hk.remat(residual)(x, mask)

        def projection(x):
            return Projection(config)(x)

        def init_fns():
            embed_init_fn = hk.transform(
                hk.experimental.optimize_rng_use(embedding)).init
            transformer_init_fn = hk.transform(
                hk.experimental.optimize_rng_use(transformer)).init
            projection_init_fn = hk.transform(
                hk.experimental.optimize_rng_use(projection)).init

            return embed_init_fn, transformer_init_fn, projection_init_fn

        def shard_strategy(shape_dtype, parallel):
            if shape_dtype.ndim <= 1:
                return P()
            # embedding/projection layers
            elif shape_dtype.shape == (config["n_vocab"], config["d_model"]):
                return P(parallel, None)
            elif shape_dtype.shape == (config["d_model"], config["n_vocab"]):
                return P(None, parallel)

            # a transformer layer
            elif shape_dtype.shape[0] == config["layers"]:
                if shape_dtype.ndim == 2:
                    # a channel wise variable (e.g. layernorm parameters)
                    # replicate it for speed
                    return P(None)
                elif shape_dtype.ndim == 3:
                    # a weight matrix
                    matrix_size = shape_dtype.shape[1:]

                    assert matrix_size[0] != matrix_size[1]  # this case is ambiguous

                    if matrix_size[0] == config["d_model"]:
                        # shard along the axis which is _not_ the model dimension
                        return P(None, None, parallel)
                    elif matrix_size[1] == config["d_model"]:
                        return P(None, parallel, None)
                else:
                    raise NotImplementedError("borked")

            else:
                raise NotImplementedError("borked")

        def init(key, x):
            embed_init_fn, transformer_init_fn, projection_init_fn = init_fns()

            def init_scan_fn(key, x):
                new_key, key = jax.random.split(key)

                return new_key, transformer_init_fn(key, x, 0)

            e_key, t_key, p_key = jax.random.split(key, 3)

            input_shape = (config["layers"], ) + x.shape + (
                config["d_model"], )

            params = {
                "embed":
                embed_init_fn(e_key, x),
                "transformer":
                jax.lax.scan(init_scan_fn,
                             t_key,
                             xs=jax.random.uniform(t_key,
                                                   input_shape,
                                                   dtype=jnp.float32))[1],
                "proj":
                projection_init_fn(
                    p_key,
                    jax.random.uniform(t_key,
                                       input_shape[1:],
                                       dtype=jnp.float32)),
            }

            return {
                "params": (to_bf16 if early_cast else to_f32)(params),
                "step":
                np.array(0),
                "opt_state":
                optimizer.init((to_bf16 if bf16_optimizer else to_f32)(params))
            }

        assert thread_resources.env.shape['mp'] == config["cores_per_replica"]

        dp = thread_resources.env.shape['dp']
        mp = thread_resources.env.shape['mp']

        key = hk.PRNGSequence(42)
        x = jax.random.uniform(next(key), (mp * dp, 16), minval=0,
                               maxval=1).astype(jnp.uint32)  # batch, seq

        head_print("starting shape evaluation")

        param_shapes = jax.eval_shape(init, jax.random.PRNGKey(42), x)

        state_shard = {
            "step":
            P(),

            # zero level 1: shard optimizer states over both MP and DP
            "opt_state":
            jax.tree_map(partial(shard_strategy, parallel=["mp", "dp"]),
                         param_shapes["opt_state"]),

            # fp32 params are also sharded (so this is like a weird mix between zero-1 and zero-3...)
            "params":
            jax.tree_map(partial(shard_strategy, parallel=["mp", "dp"]),
                         param_shapes["params"]),
        }

        head_print("sharding strategy:")
        jax.tree_multimap(head_print, state_shard, param_shapes)

        self.init_pjit = pjit(init,
                              in_axis_resources=(None, P("dp")),
                              out_axis_resources=state_shard)

        def apply_fns():
            embed_apply_fn = hk.without_apply_rng(
                hk.transform(embedding)).apply
            transformer_apply_fn = hk.without_apply_rng(
                hk.transform(transformer)).apply

            return embed_apply_fn, transformer_apply_fn

        def train_apply_fn(params, x, y):
            embed_apply_fn, transformer_apply_fn = apply_fns()

            def train_loss(x, y):
                loss, _ = Projection(config).loss(x, y, z_loss=1.0)
                return loss.mean(), loss[:, -1].mean()

            projection_apply_fn = hk.without_apply_rng(
                hk.transform(train_loss)).apply

            x = embed_apply_fn(params["embed"], x)
            x = to_bf16(x)

            def apply_scan_fn(x, layer_state):
                return to_bf16(transformer_apply_fn(layer_state, x, 0)), None

            x = jax.lax.scan(apply_scan_fn, x, xs=params["transformer"])[0]

            return projection_apply_fn(params["proj"], x, y)

        mp_shard_strategy = jax.tree_map(
            partial(shard_strategy, parallel=["mp"]), param_shapes["params"])

        def train(state, ctx, tgt):
            if early_collect:
                bf16_params = maybe_shard(to_bf16(state["params"]),
                                          mp_shard_strategy)
            else:
                bf16_params = to_bf16(state["params"])

            def microbatch(old_grad, batch):
                ctx, tgt = batch

                val_grad_fn = jax.value_and_grad(train_apply_fn,
                                                 has_aux=True,
                                                 allow_int=True)
                (loss, last_loss), grad = val_grad_fn(bf16_params, ctx, tgt)

                new_grad = jax.tree_multimap(lambda a, b: a + b, old_grad,
                                             grad)
                return new_grad, (loss, last_loss)

            if ctx.shape[0] == 1:
                val_grad_fn = jax.value_and_grad(train_apply_fn,
                                                 has_aux=True,
                                                 allow_int=True)
                (loss, last_loss), grad = val_grad_fn(bf16_params, ctx[0],
                                                      tgt[0])
            else:
                grad, (loss, last_loss) = jax.lax.scan(
                    microbatch,
                    jax.tree_map(
                        lambda x: jnp.zeros_like(x).astype(jnp.bfloat16),
                        bf16_params), (ctx, tgt))

            updates, new_opt_state = optimizer.update(grad, state["opt_state"],
                                                      state["params"])

            return to_f32(loss), to_f32(last_loss), {
                "params": optax.apply_updates(state["params"],
                                              to_f32(updates)),
                "step": state["step"] + 1,
                "opt_state": new_opt_state,
            }

        self.train_pjit = pjit(train,
                               in_axis_resources=(state_shard, P(None, "dp"),
                                                  P(None, "dp")),
                               out_axis_resources=(None, None, state_shard),
                               donate_argnums=(0, ))

        def eval_apply_fn(params, x, y, mask):
            embed_apply_fn, transformer_apply_fn = apply_fns()

            if early_collect:
                bf16_params = maybe_shard(to_bf16(params), mp_shard_strategy)
            else:
                bf16_params = to_bf16(params)

            def eval_loss(x, y):
                loss, correct = Projection(config).loss(x, y)
                return {
                    "loss": loss.mean(axis=-1),
                    "last_loss": loss[:, -1],
                    "all_loss": loss,
                    "correct": correct
                }

            projection_apply_fn = hk.without_apply_rng(
                hk.transform(eval_loss)).apply

            x = embed_apply_fn(bf16_params["embed"], x)

            def apply_scan_fn(layer_in, layer_state):
                x, mask = layer_in
                return (to_bf16(transformer_apply_fn(layer_state, x,
                                                     mask)), mask), None

            x = jax.lax.scan(apply_scan_fn, (to_bf16(x), mask),
                             xs=bf16_params["transformer"])[0][0]

            return projection_apply_fn(bf16_params["proj"], x, y)

        def eval(params, ctx, tgt, ctx_length):
            mask = (jnp.arange(0, ctx.shape[1])[None, :] >
                    ctx_length[:, None]) * -1e10

            # head_print("mask.shape", mask.shape)
            # head_print("ctx.shape", ctx.shape)
            # head_print("ctx_length.shape", ctx_length.shape)

            return eval_apply_fn(params, ctx, tgt, mask[:, None, None, :])

        self.eval_pjit = pjit(
            eval,
            in_axis_resources=(mp_shard_strategy
                               if early_collect else state_shard["params"],
                               P("dp"), P("dp"), P("dp")),
            out_axis_resources=P("dp"))

        self.move_weights_pjit = pjit(
            lambda x: to_bf16(x),
            in_axis_resources=(state_shard["params"], ),
            out_axis_resources=mp_shard_strategy
            if early_collect else state_shard["params"])

        seq = config["seq"]
        vocab = config["n_vocab"]

        example_shape = (
            max(dp // jax.host_count(), 1),
            seq,
        )
        x = jax.random.uniform(next(key),
                               example_shape,
                               minval=0,
                               maxval=vocab).astype(jnp.uint32)  # batch, len

        head_print("in shape", x.shape)

        head_print("dp", dp)
        head_print("mp", mp)

        self.state = self.init_pjit(next(key), x)
        self.state_shard = state_shard
        self.eval_weights = None

        param_count = hk.data_structures.tree_size(self.state['params'])
        head_print(f"Total parameters: {param_count * dp}")