def append_pytree(self, pytree):
  """Append and record a serializable pytree to disk.

  The pytree will be saved to disk as a list of pytree objects. Every time
  this function is called, it loads the previously saved state, appends the
  next pytree to the list, then saves the appended list.

  Args:
    pytree: Any serializable pytree.
  """
  # Load the single existing checkpoint (if any) so the new pytree can be
  # appended before the whole list is written back out.
  previous = flax_checkpoints.restore_checkpoint(
      self._pytree_path, target=None)
  # With target=None, flax checkpointing hands back a raw state dict, so
  # 'pytree' arrives as a dict keyed by '0', '1', ... rather than a list;
  # rebuild the list in index order.
  if previous:
    raw = previous['pytree']
    pytrees = [raw[str(idx)] for idx in range(len(raw))]
  else:
    pytrees = []
  pytrees.append(pytree)
  new_state = checkpoint.CheckpointState(pytrees)
  checkpoint.save_checkpoint(
      self._pytree_path,
      'training_metrics',
      new_state,
      max_to_keep=None,
      use_deprecated_checkpointing=self._use_deprecated_checkpointing)
def test_save_checkpoint_background_reraises_error(self):
  """Test that an error while saving a checkpoint is re-raised later."""
  # The background save surfaces failures lazily: the error is raised on the
  # *next* attempt to write a checkpoint, not at the moment it occurs.
  baz = ['a', 'b', 'ccc']
  state = checkpoint.CheckpointState(
      self.flax_module.params, global_step=5, completed_epochs=4, baz=baz)
  # First write targets an unwritable path; the failure is deferred.
  checkpoint.save_checkpoint_background(
      '/forbidden_directory/', 'checkpoint', state)
  # The deferred permission error is re-raised by the subsequent write.
  with self.assertRaisesRegex(BaseException, r'Permission\sdenied'):
    checkpoint.save_checkpoint_background(
        self.test_dir, 'checkpoint', state)
def test_save_load_roundtrip(self, use_deprecated_checkpointing):
  """Test that saving and loading produces the original state."""
  baz = ['a', 'b', 'ccc']
  state = checkpoint.CheckpointState(self.flax_module.params,
                                     global_step=5,
                                     completed_epochs=4,
                                     baz=baz)
  checkpoint.save_checkpoint(
      self.test_dir,
      'checkpoint',
      state,
      use_deprecated_checkpointing=use_deprecated_checkpointing)
  latest = checkpoint.load_latest_checkpoint(
      self.test_dir,
      target=state,
      use_deprecated_checkpointing=use_deprecated_checkpointing)
  # Verify the restored pystate and pytree match what was saved.
  self.assertEqual(latest.pystate['baz'], baz)
  # Use self.assertTrue rather than a bare `assert`: bare asserts are
  # stripped under `python -O` and are inconsistent with the unittest-style
  # assertions used in the rest of this test.
  self.assertTrue(pytree_equal(latest.pytree, self.flax_module.params))
  self.assertEqual(latest.pystate['global_step'], 5)
  self.assertEqual(latest.pystate['completed_epochs'], 4)
def _maybe_restore_latest_checkpoint(unreplicated_optimizer,
                                     unreplicated_batch_stats,
                                     unreplicated_training_metrics_grabber,
                                     train_dir,
                                     use_deprecated_checkpointing):
  """Restore from the latest checkpoint, if it exists."""
  # Build an unreplicated template state so the loader knows the target
  # structure of the pytree and the extra scalar fields.
  target_state = checkpoint.CheckpointState(
      {
          'optimizer': unreplicated_optimizer,
          'batch_stats': unreplicated_batch_stats,
          'training_metrics_grabber': unreplicated_training_metrics_grabber,
      },
      global_step=0,
      preemption_count=0,
      sum_train_cost=0.0)
  latest = checkpoint.load_latest_checkpoint(
      train_dir,
      target=target_state,
      recents_filename='latest',
      use_deprecated_checkpointing=use_deprecated_checkpointing)

  # Replicate across local devices up front; these are returned as-is when no
  # checkpoint is found, or overwritten from the checkpoint below.
  optimizer = jax_utils.replicate(unreplicated_optimizer)
  batch_stats = jax_utils.replicate(unreplicated_batch_stats)
  training_metrics_grabber = jax_utils.replicate(
      unreplicated_training_metrics_grabber)

  if latest is None:
    # Nothing on disk: fresh replicated state, step/cost/preemptions zeroed,
    # final False indicates no checkpoint was restored.
    return optimizer, batch_stats, training_metrics_grabber, 0, 0.0, 0, False

  pytree_dict, extra_state = restore_checkpoint(
      latest,
      replicated_pytree={
          'optimizer': optimizer,
          'batch_stats': batch_stats,
          'training_metrics_grabber': training_metrics_grabber,
      },
      use_deprecated_checkpointing=use_deprecated_checkpointing)
  return (pytree_dict['optimizer'],
          pytree_dict['batch_stats'],
          pytree_dict['training_metrics_grabber'],
          extra_state['global_step'],
          extra_state['sum_train_cost'],
          extra_state['preemption_count'],
          True)
def save_checkpoint(train_dir,
                    pytree,
                    global_step,
                    preemption_count,
                    sum_train_cost,
                    max_to_keep=1,
                    use_deprecated_checkpointing=True):
  """Saves the pytree to train_dir.

  Args:
    train_dir: Directory to write the checkpoint into.
    pytree: The (device-replicated) pytree of state to save.
    global_step: Current training step; used to name the checkpoint.
    preemption_count: Number of preemptions so far, stored alongside the state.
    sum_train_cost: Accumulated training cost, stored alongside the state.
    max_to_keep: Maximum number of checkpoints to retain.
    use_deprecated_checkpointing: Whether to use the deprecated checkpoint
      format.
  """
  checkpoint_name = 'ckpt_{}'.format(global_step)
  logging.info('Saving checkpoint to %s', checkpoint_name)
  # The leaves are replicated across devices (each has a leading device axis);
  # keep only device 0's copy of each leaf and pull it to host memory.
  # NOTE: jax.tree_util.tree_leaves replaces the deprecated top-level alias
  # jax.tree_leaves, which has been removed in newer JAX releases.
  unstructured_state = jax.device_get(
      [x[0] for x in jax.tree_util.tree_leaves(pytree)])
  state = checkpoint.CheckpointState(pytree=unstructured_state,
                                     global_step=global_step,
                                     preemption_count=preemption_count,
                                     sum_train_cost=sum_train_cost)
  checkpoint.save_checkpoint_background(
      train_dir,
      checkpoint_name,
      state,
      max_to_keep=max_to_keep,
      use_deprecated_checkpointing=use_deprecated_checkpointing)
  logging.info('Done saving checkpoint.')