예제 #1
0
  def test_checkpointing(self):
    """Round-trips three GDAs through tensorstore and checks restored shards.

    The first two arrays are built on a 4x2 mesh with P('x', 'y'); the second
    one is read back with a different spec, P('x'). The third is an empty
    array on a 1-D mesh of 8 devices with P(None).
    """
    mesh_2d = jtu.create_global_mesh((4, 2), ('x', 'y'))
    shape = (8, 2)
    spec_xy = P('x', 'y')
    size = util.prod(shape)

    def _slicer(source):
      # Callback for from_callback: returns the slice of `source` selected by
      # the index it is handed.
      return lambda idx: source[idx]

    # First GDA: values 0 .. size-1.
    data_a = np.arange(size).reshape(shape)
    gda_a = GlobalDeviceArray.from_callback(shape, mesh_2d, spec_xy,
                                            _slicer(data_a))
    dir_a = pathlib.Path(self.create_tempdir('first').full_path)

    # Second GDA: values size .. 2*size-1.
    data_b = np.arange(size, 2 * size).reshape(shape)
    gda_b = GlobalDeviceArray.from_callback(shape, mesh_2d, spec_xy,
                                            _slicer(data_b))
    dir_b = pathlib.Path(self.create_tempdir('second').full_path)

    # Third GDA: an empty array on a 1-D mesh; the callback ignores the index.
    mesh_1d = jtu.create_global_mesh((8,), ('x',))
    gda_c = GlobalDeviceArray.from_callback(
        (0,), mesh_1d, P(None), lambda idx: np.array([]))
    dir_c = pathlib.Path(self.create_tempdir('third').full_path)

    specs = [
        serialization.get_tensorstore_spec(str(d))
        for d in (dir_a, dir_b, dir_c)
    ]

    serialization.run_serialization([gda_a, gda_b, gda_c], specs)

    restored_a, restored_b, restored_c = serialization.run_deserialization(
        [mesh_2d, mesh_2d, mesh_1d],
        [spec_xy, P('x'), P(None)],
        specs)

    # Read back with the original spec: the first local shard is (2, 1).
    self.assertArraysEqual(restored_a.local_shards[0].data.to_py(),
                           np.array([[0], [2]]))
    self.assertArraysEqual(restored_a.local_shards[1].data.to_py(),
                           np.array([[1], [3]]))
    self.assertEqual(restored_a.local_shards[0].data.shape, (2, 1))
    self.assertEqual(restored_a.dtype, np.int32)

    # Read back with P('x') only: the first two local shards hold the same
    # full-width rows.
    self.assertArraysEqual(restored_b.local_shards[0].data.to_py(),
                           np.array([[16, 17], [18, 19]]))
    self.assertArraysEqual(restored_b.local_shards[1].data.to_py(),
                           np.array([[16, 17], [18, 19]]))
    self.assertEqual(restored_b.local_shards[0].data.shape, (2, 2))
    self.assertEqual(restored_b.dtype, np.int32)

    # Empty array: every local shard covers the whole (empty) index range and
    # carries its position as replica id.
    for replica, shard in enumerate(restored_c.local_shards):
      self.assertEqual(shard.index, (slice(None),))
      self.assertEqual(shard.replica_id, replica)
      self.assertArraysEqual(shard.data.to_py(), np.array([]))
    self.assertEqual(restored_c.dtype, np.float32)
예제 #2
0
    def test_checkpointing_with_bigger_shape(self):
        """Restores a saved (8, 2) GDA as a (12, 2) array on a larger mesh.

        The array is written from a 2x2 mesh and read back on a 4x2 mesh with
        an explicit global shape of (12, 2). The trailing zeros in the
        expectations below are the positions outside the saved data.
        """
        source_mesh = create_global_mesh((2, 2), ('x', 'y'))
        saved_shape = (8, 2)
        values = np.arange(util.prod(saved_shape)).reshape(saved_shape)

        # First GDA; the callback returns the slice of `values` it is asked for.
        gda = GlobalDeviceArray.from_callback(
            saved_shape, source_mesh, P('x', 'y'), lambda idx: values[idx])
        ckpt_dir = pathlib.Path(self.create_tempdir('first').full_path)

        tspecs = [serialization.get_tensorstore_spec(str(ckpt_dir))]

        serialization.run_serialization([gda], tspecs)

        (restored,) = serialization.run_deserialization(
            [create_global_mesh((4, 2), ('x', 'y'))],
            [P('x', 'y')],
            tspecs,
            [(12, 2)],
        )

        # Expected local shard contents, keyed by device id.
        expected_by_device = {
            0: np.array([[0], [2], [4]]),
            1: np.array([[1], [3], [5]]),
            2: np.array([[6], [8], [10]]),
            3: np.array([[7], [9], [11]]),
            4: np.array([[12], [14], [0]]),
            5: np.array([[13], [15], [0]]),
            6: np.array([[0], [0], [0]]),
            7: np.array([[0], [0], [0]]),
        }

        for shard in restored.local_shards:
            self.assertArraysEqual(shard.data.to_py(),
                                   expected_by_device[shard.device.id])
예제 #3
0
def _restore_checkpoint_gda(
    train_state: Optional[train_states.TrainState],
    checkpoint_dir: str,
    global_mesh: Optional[maps.Mesh],
    state_specs: Optional[train_states.TrainState],
    step: Optional[int] = None) -> train_states.TrainState:
  """Restores a checkpoint using JAX GDA deserialization mechanism.

  Args:
    train_state: Optional template. When given, its leaves supply the global
      shapes to restore into and its tree structure shapes the result. When
      None, the structure is taken from `state_specs` and no global shapes
      are passed to deserialization.
    checkpoint_dir: Directory holding per-step checkpoint subdirectories.
    global_mesh: Mesh used to deserialize every leaf.
    state_specs: Pytree of partition specs, flattened to one spec per leaf.
    step: Step to restore. When None, the latest complete checkpoint found in
      `checkpoint_dir` is used.

  Returns:
    The restored state, unflattened with the tree structure of `train_state`
    (or of `state_specs` when `train_state` is None). If `checkpoint_dir` is
    missing or empty, `train_state` is given and `step` is None, the passed-in
    `train_state` is returned unchanged.

  Raises:
    FileNotFoundError: If no checkpoint to restore can be found.
  """
  # List the directory once; the result is reused below to pick the latest
  # step.
  checkpoint_dirnames = (
      tf.io.gfile.listdir(checkpoint_dir)
      if tf.io.gfile.exists(checkpoint_dir) else [])
  if not checkpoint_dirnames:
    if train_state is not None and step is None:
      logging.info(
          'GDA checkpoint restore did not find checkpoint_dir %s; '
          'Return train_state passed in', checkpoint_dir)
      return train_state
    raise FileNotFoundError(
        f'No checkpoint found for restore in {checkpoint_dir}')

  if step is None:
    tmp_checkpoint_dirnames = [
        x for x in checkpoint_dirnames if _is_tmp_checkpoint_asset(x)
    ]
    if tmp_checkpoint_dirnames:
      # `warn` is a deprecated alias; `warning` is the supported name.
      logging.warning('Found incompletely saved checkpoints %s; skipping them',
                      tmp_checkpoint_dirnames)
    complete_dirnames = [
        x for x in checkpoint_dirnames if _is_checkpoint_asset(x)
    ]
    if not complete_dirnames:
      raise FileNotFoundError(
          f'No checkpoint found for restore in {checkpoint_dir}')
    # Choose the latest checkpoint by its numeric step, not by string order.
    # A lexicographic sort only matches step order when the step numbers in
    # the directory names are zero-padded to the same width.
    step = max(get_step_from_checkpoint_asset(x) for x in complete_dirnames)
    checkpoint_step_dir = _make_checkpoint_step_dir(checkpoint_dir, step)
    logging.info('Found latest checkpoint: %s', checkpoint_step_dir)
  else:
    checkpoint_step_dir = _make_checkpoint_step_dir(checkpoint_dir, step)
    if not tf.io.gfile.exists(checkpoint_step_dir) or not tf.io.gfile.listdir(
        checkpoint_step_dir):
      raise FileNotFoundError(
          f'No checkpoint found for restore in {checkpoint_step_dir}')

  logging.info('GDA checkpoint restore started...')
  if train_state is not None:
    leaves, treedef = jax.tree_util.tree_flatten(train_state)
    partition_spec_leaves, _ = jax.tree_util.tree_flatten(state_specs)
    nested_names = _extract_nested_prefix_names(train_state)
    # `leaves` is already a flat list, so a comprehension is enough here.
    global_shapes = [x.shape for x in leaves]
  else:
    partition_spec_leaves, treedef = jax.tree_util.tree_flatten(state_specs)
    nested_names = _extract_nested_prefix_names(state_specs)
    global_shapes = None

  flattened_nested_names, _ = jax.tree_util.tree_flatten(nested_names)

  # One checkpoint path (and tensorstore spec) per flattened leaf.
  ckpt_paths = [
      os.path.join(checkpoint_step_dir, x).rstrip('/')
      for x in flattened_nested_names
  ]
  tspecs = [gda_serialization.get_tensorstore_spec(p) for p in ckpt_paths]

  train_state_gda = gda_serialization.run_deserialization(
      [global_mesh] * len(tspecs),
      partition_spec_leaves,
      tspecs,
      global_shapes=global_shapes)

  restored_train_state = jax.tree_util.tree_unflatten(treedef, train_state_gda)
  # Barrier across all processes to ensure all restore finish.
  py_utils.sync_global_devices('Wait for checkpoint restore from '
                               f'{checkpoint_step_dir} to finish.')
  logging.info('Successfully restored GDA checkpoint at %s!',
               checkpoint_step_dir)
  return restored_train_state