def test_checkpointing(self):
  """Serializes three GDAs to tensorstore and deserializes them back.

  Covers a fully sharded 2-D array restored with its original spec, the
  same data restored under a different partitioning, and a zero-size
  array replicated over a 1-D mesh.
  """
  mesh_2d = jtu.create_global_mesh((4, 2), ('x', 'y'))
  shape = (8, 2)
  pspec = P('x', 'y')
  size = util.prod(shape)

  # GDA #1: values [0, size) sharded across both mesh axes.
  data1 = np.arange(size).reshape(shape)
  gda1 = GlobalDeviceArray.from_callback(
      shape, mesh_2d, pspec, lambda index: data1[index])
  dir1 = pathlib.Path(self.create_tempdir('first').full_path)

  # GDA #2: values [size, 2*size), same sharding as GDA #1.
  data2 = np.arange(size, size + size).reshape(shape)
  gda2 = GlobalDeviceArray.from_callback(
      shape, mesh_2d, pspec, lambda index: data2[index])
  dir2 = pathlib.Path(self.create_tempdir('second').full_path)

  # GDA #3: an empty (zero-element) array replicated over a 1-D mesh.
  mesh_1d = jtu.create_global_mesh((8,), ('x',))
  gda3 = GlobalDeviceArray.from_callback(
      (0,), mesh_1d, P(None), lambda index: np.array([]))
  dir3 = pathlib.Path(self.create_tempdir('third').full_path)

  ckpt_paths = [str(dir1), str(dir2), str(dir3)]
  tspecs = jax.tree_map(serialization.get_tensorstore_spec, ckpt_paths)
  serialization.run_serialization([gda1, gda2, gda3], tspecs)

  # Restore: gda1 with its original spec, gda2 re-partitioned along 'x'
  # only, gda3 replicated on the 1-D mesh.
  m1, m2, m3 = serialization.run_deserialization(
      [mesh_2d, mesh_2d, mesh_1d],
      [pspec, P('x'), P(None)],
      tspecs)

  self.assertArraysEqual(m1.local_shards[0].data.to_py(),
                         np.array([[0], [2]]))
  self.assertArraysEqual(m1.local_shards[1].data.to_py(),
                         np.array([[1], [3]]))
  self.assertEqual(m1.local_shards[0].data.shape, (2, 1))
  self.assertEqual(m1.dtype, np.int32)

  self.assertArraysEqual(m2.local_shards[0].data.to_py(),
                         np.array([[16, 17], [18, 19]]))
  self.assertArraysEqual(m2.local_shards[1].data.to_py(),
                         np.array([[16, 17], [18, 19]]))
  self.assertEqual(m2.local_shards[0].data.shape, (2, 2))
  self.assertEqual(m2.dtype, np.int32)

  # The empty array is replicated: each shard covers the whole (empty)
  # value and carries a distinct replica id.
  for i, shard in enumerate(m3.local_shards):
    self.assertEqual(shard.index, (slice(None),))
    self.assertEqual(shard.replica_id, i)
    self.assertArraysEqual(shard.data.to_py(), np.array([]))
  self.assertEqual(m3.dtype, np.float32)
def test_checkpointing_with_bigger_shape(self):
  """Restores a checkpoint into a larger global shape on a bigger mesh.

  An (8, 2) array saved from a (2, 2) mesh is deserialized as (12, 2) on
  a (4, 2) mesh; positions beyond the saved extent come back as zeros.
  """
  src_mesh = create_global_mesh((2, 2), ('x', 'y'))
  shape = (8, 2)
  values = np.arange(util.prod(shape)).reshape(shape)

  gda = GlobalDeviceArray.from_callback(
      shape, src_mesh, P('x', 'y'), lambda index: values[index])
  ckpt_dir = pathlib.Path(self.create_tempdir('first').full_path)

  ckpt_paths = [str(ckpt_dir)]
  tspecs = jax.tree_map(serialization.get_tensorstore_spec, ckpt_paths)
  serialization.run_serialization([gda], tspecs)

  m1, = serialization.run_deserialization(
      [create_global_mesh((4, 2), ('x', 'y'))],
      [P('x', 'y')],
      tspecs,
      [(12, 2)],
  )

  # Per-device expectation: devices 0-3 hold only saved rows, devices 4-5
  # straddle the padded region, devices 6-7 lie entirely inside it.
  expected_data = {
      0: np.array([[0], [2], [4]]),
      1: np.array([[1], [3], [5]]),
      2: np.array([[6], [8], [10]]),
      3: np.array([[7], [9], [11]]),
      4: np.array([[12], [14], [0]]),
      5: np.array([[13], [15], [0]]),
      6: np.array([[0], [0], [0]]),
      7: np.array([[0], [0], [0]]),
  }
  for shard in m1.local_shards:
    self.assertArraysEqual(shard.data.to_py(),
                           expected_data[shard.device.id])
def _restore_checkpoint_gda(
    train_state: Optional[train_states.TrainState],
    checkpoint_dir: str,
    global_mesh: Optional[maps.Mesh],
    state_specs: Optional[train_states.TrainState],
    step: Optional[int] = None) -> train_states.TrainState:
  """Restores a checkpoint using JAX GDA deserialization mechanism.

  Args:
    train_state: Template TrainState providing the tree structure and the
      per-leaf global shapes used for the restore. If None, `state_specs`
      supplies the structure and shapes are left to the deserializer.
    checkpoint_dir: Base directory holding per-step checkpoint assets.
    global_mesh: Mesh applied to every restored GDA leaf.
    state_specs: TrainState-shaped tree of partition specs, one per leaf.
    step: Step number to restore. If None, the newest complete checkpoint
      under `checkpoint_dir` is chosen.

  Returns:
    The restored TrainState, or `train_state` unchanged when the directory
    has no checkpoints and no explicit `step` was requested.

  Raises:
    FileNotFoundError: If no usable checkpoint exists (and no fallback
      `train_state` applies), or the requested step's directory is missing
      or empty.
  """
  if not tf.io.gfile.exists(checkpoint_dir) or not tf.io.gfile.listdir(
      checkpoint_dir):
    # Best-effort fallback: with a template state and no pinned step, an
    # empty checkpoint dir means "start from the state passed in".
    if train_state is not None and step is None:
      logging.info(
          'GDA checkpoint restore did not find checkpoint_dir %s; '
          'Return train_state passed in', checkpoint_dir)
      return train_state
    raise FileNotFoundError(
        f'No checkpoint found for restore in {checkpoint_dir}')
  if step is None:
    checkpoint_dirnames = tf.io.gfile.listdir(checkpoint_dir)
    # Partially written checkpoints (still under a tmp asset name) must
    # not be considered as restore candidates.
    tmp_checkpoint_dirnames = [
        x for x in checkpoint_dirnames if _is_tmp_checkpoint_asset(x)
    ]
    if tmp_checkpoint_dirnames:
      # logging.warn is a deprecated alias; use warning().
      logging.warning('Found incompletely saved checkpoints %s; skipping them',
                      tmp_checkpoint_dirnames)
    sorted_dirnames = sorted(
        [x for x in checkpoint_dirnames if _is_checkpoint_asset(x)])
    if not sorted_dirnames:
      raise FileNotFoundError(
          f'No checkpoint found for restore in {checkpoint_dir}')
    # Lexicographic max of the asset names is taken as the latest step.
    latest_checkpoint_dirname = sorted_dirnames[-1]
    step = get_step_from_checkpoint_asset(latest_checkpoint_dirname)
    checkpoint_step_dir = _make_checkpoint_step_dir(checkpoint_dir, step)
    logging.info('Found latest checkpoint: %s', checkpoint_step_dir)
  else:
    checkpoint_step_dir = _make_checkpoint_step_dir(checkpoint_dir, step)
    if not tf.io.gfile.exists(checkpoint_step_dir) or not tf.io.gfile.listdir(
        checkpoint_step_dir):
      raise FileNotFoundError(
          f'No checkpoint found for restore in {checkpoint_step_dir}')

  logging.info('GDA checkpoint restore started...')
  if train_state is not None:
    leaves, treedef = jax.tree_util.tree_flatten(train_state)
    partition_spec_leaves, _ = jax.tree_util.tree_flatten(state_specs)
    nested_names = _extract_nested_prefix_names(train_state)
    # Global shapes come from the template's leaves so deserialization
    # targets the expected per-leaf shapes.
    global_shapes = jax.tree_map(lambda x: x.shape, leaves)
  else:
    partition_spec_leaves, treedef = jax.tree_util.tree_flatten(state_specs)
    nested_names = _extract_nested_prefix_names(state_specs)
    global_shapes = None

  flattened_nested_names, _ = jax.tree_util.tree_flatten(nested_names)
  # NOTE(review): rstrip('/') presumably normalizes trailing separators in
  # the per-leaf names before building tensorstore specs — confirm.
  ckpt_paths = [
      os.path.join(checkpoint_step_dir, x).rstrip('/')
      for x in flattened_nested_names
  ]
  tspecs = jax.tree_map(gda_serialization.get_tensorstore_spec, ckpt_paths)
  train_state_gda = gda_serialization.run_deserialization(
      [global_mesh] * len(tspecs),
      partition_spec_leaves,
      tspecs,
      global_shapes=global_shapes)
  restored_train_state = jax.tree_util.tree_unflatten(treedef, train_state_gda)
  # Barrier across all processes to ensure all restore finish.
  py_utils.sync_global_devices('Wait for checkpoint restore from '
                               f'{checkpoint_step_dir} to finish.')
  logging.info('Successfully restored GDA checkpoint at %s!',
               checkpoint_step_dir)
  return restored_train_state