Ejemplo n.º 1
0
    def append_pytree(self, pytree):
        """Add one more pytree to the list of pytrees recorded on disk.

        The checkpoint stored under `self._pytree_path` holds a list of
        pytrees. Every call restores that checkpoint, appends `pytree` to the
        restored list and writes the whole list back.

        Args:
          pytree: Any serializable pytree.
        """
        # There is only ever one checkpoint at this path; restore it so the new
        # entry can be appended before everything is written back.
        restored = flax_checkpoints.restore_checkpoint(self._pytree_path,
                                                       target=None)
        history = []
        if restored:
            # With target=None the raw state dict is returned, so 'pytree' is a
            # dict keyed by the stringified indices '0', '1', ... rather than a
            # list. Rebuild the list in index order.
            raw_entries = restored['pytree']
            for index in range(len(raw_entries)):
                history.append(raw_entries[str(index)])
        history.append(pytree)
        new_state = checkpoint.CheckpointState(history)
        checkpoint.save_checkpoint(
            self._pytree_path,
            'training_metrics',
            new_state,
            max_to_keep=None,
            use_deprecated_checkpointing=self._use_deprecated_checkpointing)
Ejemplo n.º 2
0
 def test_save_checkpoint_background_reraises_error(self):
     """Test that an error while saving a checkpoint is re-raised later."""
     # A failing background save does not surface at the point where it
     # happens; it is raised on the next attempt to write a checkpoint.
     extra_list = ['a', 'b', 'ccc']
     ckpt_state = checkpoint.CheckpointState(
         self.flax_module.params,
         global_step=5,
         completed_epochs=4,
         baz=extra_list)
     # First save: aimed at a directory the test expects to be unwritable.
     checkpoint.save_checkpoint_background(
         '/forbidden_directory/', 'checkpoint', ckpt_state)
     # Second save: targets the test directory, and is where the earlier
     # failure is expected to be reported.
     expected_message = r'Permission\sdenied'
     with self.assertRaisesRegex(BaseException, expected_message):
         checkpoint.save_checkpoint_background(
             self.test_dir, 'checkpoint', ckpt_state)
Ejemplo n.º 3
0
    def test_save_load_roundtrip(self, use_deprecated_checkpointing):
        """Test that saving and loading produces the original state.

        Args:
          use_deprecated_checkpointing: Forwarded unchanged to both the save
            call and the load call.
        """
        baz = ['a', 'b', 'ccc']
        state = checkpoint.CheckpointState(self.flax_module.params,
                                           global_step=5,
                                           completed_epochs=4,
                                           baz=baz)
        checkpoint.save_checkpoint(
            self.test_dir,
            'checkpoint',
            state,
            use_deprecated_checkpointing=use_deprecated_checkpointing)
        # The saved state doubles as the restore target.
        latest = checkpoint.load_latest_checkpoint(
            self.test_dir,
            target=state,
            use_deprecated_checkpointing=use_deprecated_checkpointing)

        self.assertEqual(latest.pystate['baz'], baz)
        # Use a unittest assertion instead of a bare `assert`: bare asserts are
        # stripped when Python runs with -O, which would silently skip the
        # pytree comparison.
        self.assertTrue(pytree_equal(latest.pytree, self.flax_module.params))
        self.assertEqual(latest.pystate['global_step'], 5)
        self.assertEqual(latest.pystate['completed_epochs'], 4)
Ejemplo n.º 4
0
def _maybe_restore_latest_checkpoint(unreplicated_optimizer,
                                     unreplicated_batch_stats,
                                     unreplicated_training_metrics_grabber,
                                     train_dir, use_deprecated_checkpointing):
    """Load the newest checkpoint found in train_dir, when there is one.

    Args:
      unreplicated_optimizer: Optimizer used as part of the restore target.
      unreplicated_batch_stats: Batch stats used as part of the restore target.
      unreplicated_training_metrics_grabber: Metrics grabber used as part of
        the restore target.
      train_dir: Directory searched for the latest checkpoint.
      use_deprecated_checkpointing: Forwarded to the checkpoint load and
        restore calls.

    Returns:
      A 7-tuple (optimizer, batch_stats, training_metrics_grabber, global_step,
      sum_train_cost, preemption_count, is_restored). When no checkpoint is
      found, the first three entries are the replicated inputs, the counters
      are 0, 0.0 and 0, and is_restored is False.
    """
    target_state = checkpoint.CheckpointState(
        dict(
            optimizer=unreplicated_optimizer,
            batch_stats=unreplicated_batch_stats,
            training_metrics_grabber=unreplicated_training_metrics_grabber),
        global_step=0,
        preemption_count=0,
        sum_train_cost=0.0)
    latest = checkpoint.load_latest_checkpoint(
        train_dir,
        target=target_state,
        recents_filename='latest',
        use_deprecated_checkpointing=use_deprecated_checkpointing)

    # Both outcomes below need the replicated values, so build them up front.
    replicated = {
        'optimizer': jax_utils.replicate(unreplicated_optimizer),
        'batch_stats': jax_utils.replicate(unreplicated_batch_stats),
        'training_metrics_grabber': jax_utils.replicate(
            unreplicated_training_metrics_grabber),
    }

    if latest is None:
        # Nothing to restore: return the replicated inputs with zeroed counters.
        return (replicated['optimizer'], replicated['batch_stats'],
                replicated['training_metrics_grabber'], 0, 0.0, 0, False)

    restored_pytree, restored_extras = restore_checkpoint(
        latest,
        replicated_pytree=replicated,
        use_deprecated_checkpointing=use_deprecated_checkpointing)
    return (restored_pytree['optimizer'],
            restored_pytree['batch_stats'],
            restored_pytree['training_metrics_grabber'],
            restored_extras['global_step'],
            restored_extras['sum_train_cost'],
            restored_extras['preemption_count'],
            True)
Ejemplo n.º 5
0
def save_checkpoint(train_dir,
                    pytree,
                    global_step,
                    preemption_count,
                    sum_train_cost,
                    max_to_keep=1,
                    use_deprecated_checkpointing=True):
    """Saves the pytree to train_dir.

    The checkpoint is named 'ckpt_<global_step>' and is written through
    `checkpoint.save_checkpoint_background`.

    Args:
      pytree: Pytree whose leaves are saved. Only index 0 of each leaf is kept.
        NOTE(review): presumably the leading axis is the per-device replica
        axis, so this unreplicates the state; confirm against callers.
      train_dir: Directory passed to the checkpoint writer.
      global_step: Step number; used in the checkpoint name and stored in the
        checkpoint state.
      preemption_count: Stored in the checkpoint state.
      sum_train_cost: Stored in the checkpoint state.
      max_to_keep: Forwarded to the checkpoint writer.
      use_deprecated_checkpointing: Forwarded to the checkpoint writer.
    """
    checkpoint_name = 'ckpt_{}'.format(global_step)
    logging.info('Saving checkpoint to %s', checkpoint_name)
    # Use jax.tree_util.tree_leaves: the top-level jax.tree_leaves alias is
    # deprecated and absent from newer JAX releases, while the tree_util
    # spelling is available in both old and new versions.
    unstructured_state = jax.device_get(
        [x[0] for x in jax.tree_util.tree_leaves(pytree)])
    state = checkpoint.CheckpointState(pytree=unstructured_state,
                                       global_step=global_step,
                                       preemption_count=preemption_count,
                                       sum_train_cost=sum_train_cost)
    checkpoint.save_checkpoint_background(
        train_dir,
        checkpoint_name,
        state,
        max_to_keep=max_to_keep,
        use_deprecated_checkpointing=use_deprecated_checkpointing)
    # NOTE(review): the save call above is the "background" variant, so this
    # message may be logged before the write has actually finished; confirm.
    logging.info('Done saving checkpoint.')