Example #1
  def test_checkpointing(self):
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    mesh_axes = P('x', 'y')
    num = util.prod(global_input_shape)

    # First GDA
    global_input_data1 = np.arange(num).reshape(global_input_shape)
    def cb1(index):
      return global_input_data1[index]
    gda1 = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
                                           mesh_axes, cb1)
    ckpt_dir1 = pathlib.Path(self.create_tempdir('first').full_path)

    # Second GDA
    global_input_data2 = np.arange(num, num + num).reshape(global_input_shape)
    def cb2(index):
      return global_input_data2[index]
    gda2 = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
                                           mesh_axes, cb2)
    ckpt_dir2 = pathlib.Path(self.create_tempdir('second').full_path)

    # Third GDA
    def cb3(index):
      return np.array([])
    global_mesh1d = jtu.create_global_mesh((8,), ('x',))
    gda3 = GlobalDeviceArray.from_callback((0,), global_mesh1d, P(None), cb3)
    ckpt_dir3 = pathlib.Path(self.create_tempdir('third').full_path)

    ckpt_paths = [str(ckpt_dir1), str(ckpt_dir2), str(ckpt_dir3)]
    tspecs = jax.tree_map(serialization.get_tensorstore_spec, ckpt_paths)

    serialization.run_serialization([gda1, gda2, gda3], tspecs)

    m1, m2, m3 = serialization.run_deserialization(
        [global_mesh, global_mesh, global_mesh1d],
        [mesh_axes, P('x'), P(None)],
        tspecs)

    self.assertArraysEqual(m1.local_shards[0].data.to_py(),
                           np.array([[0], [2]]))
    self.assertArraysEqual(m1.local_shards[1].data.to_py(),
                           np.array([[1], [3]]))
    self.assertEqual(m1.local_shards[0].data.shape, (2, 1))
    self.assertEqual(m1.dtype, np.int32)

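    # m2 was restored with pspec P('x'), so it is replicated along 'y'; the two
    # local shards below therefore hold identical (2, 2) row blocks.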
    self.assertArraysEqual(m2.local_shards[0].data.to_py(),
                           np.array([[16, 17], [18, 19]]))
    self.assertArraysEqual(m2.local_shards[1].data.to_py(),
                           np.array([[16, 17], [18, 19]]))
    self.assertEqual(m2.local_shards[0].data.shape, (2, 2))
    self.assertEqual(m2.dtype, np.int32)

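    # m3 is a zero-length array restored with P(None), i.e. fully replicated,
    # so every device holds an empty shard with a distinct replica_id.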
    for i, s in enumerate(m3.local_shards):
      self.assertEqual(s.index, (slice(None),))
      self.assertEqual(s.replica_id, i)
      self.assertArraysEqual(s.data.to_py(), np.array([]))
    self.assertEqual(m3.dtype, np.float32)
Example #2
    def test_checkpointing_with_bigger_shape(self):
        global_mesh = create_global_mesh((2, 2), ('x', 'y'))
        global_input_shape = (8, 2)
        num = util.prod(global_input_shape)

        # First GDA
        global_input_data1 = np.arange(num).reshape(global_input_shape)

        def cb1(index):
            return global_input_data1[index]

        gda1 = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
                                               P('x', 'y'), cb1)
        ckpt_dir1 = pathlib.Path(self.create_tempdir('first').full_path)

        ckpt_paths = [str(ckpt_dir1)]
        tspecs = jax.tree_map(serialization.get_tensorstore_spec, ckpt_paths)

        serialization.run_serialization([gda1], tspecs)

        m1, = serialization.run_deserialization(
            [create_global_mesh((4, 2), ('x', 'y'))],
            [P('x', 'y')],
            tspecs,
            [(12, 2)],
        )

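        # The checkpoint holds an (8, 2) array but is restored into a (12, 2)
        # global shape on a larger mesh; regions beyond the saved data are
        # filled with zeros.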
        expected_data = {
            0: np.array([[0], [2], [4]]),
            1: np.array([[1], [3], [5]]),
            2: np.array([[6], [8], [10]]),
            3: np.array([[7], [9], [11]]),
            4: np.array([[12], [14], [0]]),
            5: np.array([[13], [15], [0]]),
            6: np.array([[0], [0], [0]]),
            7: np.array([[0], [0], [0]]),
        }

        for l in m1.local_shards:
            self.assertArraysEqual(l.data.to_py(), expected_data[l.device.id])
Example #3
def _save_checkpoint_gda(train_state: train_states.TrainState,
                         checkpoint_dir: str, overwrite: bool,
                         step: int) -> None:
  """Saves a checkpoint using JAX GDA serialization mechanism.

  Note that all JAX processes must call _save_checkpoint_gda in sync because
  each process may only have a slice of the global data.

  Args:
    train_state: A partitioned train_state that is a Pytree of
      GlobalDeviceArray.
    checkpoint_dir: Full path to the parent checkpoint directory.
    overwrite: Whether to allow overwriting an existing target directory.
    step: Step number to save the checkpoint for.
  """
  if not overwrite:
    # listdir returns only base names, not full directory paths.
    checkpoint_dirnames = tf.io.gfile.listdir(checkpoint_dir)
    # Delete tmp directories if any.
    if jax.process_index() == 0:
      tmp_checkpoint_dirnames = [
          x for x in checkpoint_dirnames if _is_tmp_checkpoint_asset(x)
      ]
      if tmp_checkpoint_dirnames:
        logging.warning(
            'Found incompletely saved checkpoints %s; deleting them.',
            tmp_checkpoint_dirnames)
        for x in tmp_checkpoint_dirnames:
          tf.io.gfile.rmtree(os.path.join(checkpoint_dir, x))
    # Note we must barrier across all processes after the tmp directory delete.
    py_utils.sync_global_devices('Wait for checkpoint tmp dir deletions to '
                                 'finish.')

    sorted_dirnames = sorted(
        [x for x in checkpoint_dirnames if _is_checkpoint_asset(x)])
    if sorted_dirnames:
      latest_checkpoint_dirname = sorted_dirnames[-1]
      previous_step = get_step_from_checkpoint_asset(latest_checkpoint_dirname)
      if previous_step >= step:
        logging.warning(
            'A more recent checkpoint `%d` has already been saved compared '
            'to the current step `%d`. Skipping checkpoint save.',
            previous_step, step)
        return

  checkpoint_step_dir = _make_checkpoint_step_dir(checkpoint_dir, step)
  checkpoint_step_tmp_dir = _make_tmp_checkpoint_dir(
      checkpoint_dir, step, sync_timestamp=True)
  logging.info('Saving to a tmp checkpoint dir %s', checkpoint_step_tmp_dir)

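  # Derive a name for every leaf of the train_state pytree; each leaf is
  # later written to its own tensorstore subdirectory.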
  nested_names = _extract_nested_prefix_names(train_state)
  flattened_nested_names, _ = jax.tree_util.tree_flatten(nested_names)

  if jax.process_index() == 0:
    # Create the tmp parent dir.
    tf.io.gfile.makedirs(checkpoint_step_tmp_dir)

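  # Create one subdirectory per leaf under the tmp checkpoint dir, using a
  # thread pool to parallelize directory creation.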
  with futures.ThreadPoolExecutor() as executor:
    ckpt_paths = list(
        executor.map(_mkdir_path, flattened_nested_names,
                     [checkpoint_step_tmp_dir] * len(flattened_nested_names)))
  py_utils.sync_global_devices('Wait for checkpoint tmp dir and subdirs '
                               f'creation {checkpoint_step_tmp_dir} to finish.')

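  # Build a tensorstore spec for each per-leaf path and flatten the
  # train_state into its GDA leaves for writing.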
  tspecs = jax.tree_map(gda_serialization.get_tensorstore_spec, ckpt_paths)
  leaves, _ = jax.tree_util.tree_flatten(train_state)

  gda_serialization.run_serialization(leaves, tspecs)

  # Note we must barrier across all processes before the directory rename.
  py_utils.sync_global_devices('Wait for checkpoint chunk writes to '
                               f'{checkpoint_step_tmp_dir} to finish.')

  if jax.process_index() == 0:
    # Rename temporary checkpoint directory to its final location.
    logging.info('Renaming %s to %s', checkpoint_step_tmp_dir,
                 checkpoint_step_dir)
    tf.io.gfile.rename(checkpoint_step_tmp_dir, checkpoint_step_dir)

  logging.info('Finished saving GDA checkpoint for step `%s` to `%s`.', step,
               checkpoint_step_dir)
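
A minimal call-site sketch for the function above; the train state pytree, checkpoint path, and step value are hypothetical placeholders, and in a multi-host setup every JAX process would need to issue the same call.

# Hypothetical call site (not from the source above): `partitioned_train_state`
# is assumed to be a pytree of GlobalDeviceArray leaves built by the trainer.
_save_checkpoint_gda(
    train_state=partitioned_train_state,
    checkpoint_dir='/tmp/experiment/checkpoints',  # placeholder path
    overwrite=False,
    step=1000)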