Exemple #1
0
def restore_checkpoint(ckpt_dir, target, step=None, prefix='checkpoint_'):
    """Restore last/best checkpoint from checkpoints in path.

    Sorts the checkpoint files naturally, returning the highest-valued
    file, e.g.:
      ckpt_1, ckpt_2, ckpt_3 --> ckpt_3
      ckpt_0.01, ckpt_0.1, ckpt_0.001 --> ckpt_0.1
      ckpt_-1.0, ckpt_1.0, ckpt_1e5 --> ckpt_1e5

    Args:
      ckpt_dir: str: directory of checkpoints to restore from.
      target: matching object to rebuild via deserialized state-dict.
      step: int or float: step number to load, or None to load the latest.
        A step of 0 (or 0.0) is a valid explicit step and is loaded as such.
      prefix: str: name prefix of checkpoint files.

    Returns:
      Restored `target` updated from checkpoint file, or if no step specified and
      no checkpoint files present, returns the passed-in `target` unchanged.

    Raises:
      ValueError: if `step` is given but no matching checkpoint file exists.
    """
    if step is not None:
        # Compare against None explicitly: step 0 / 0.0 is falsy but is a real
        # checkpoint step and must not fall through to "restore latest".
        ckpt_path = _checkpoint_path(ckpt_dir, step, prefix)
        if not gfile.exists(ckpt_path):
            raise ValueError(f'Matching checkpoint not found: {ckpt_path}')
    else:
        glob_path = os.path.join(ckpt_dir, f'{prefix}*')
        checkpoint_files = natural_sort(gfile.glob(glob_path))
        # The temporary file used while saving also matches the prefix glob;
        # a leftover one (e.g. from an interrupted save) must not be restored.
        ckpt_tmp_path = _checkpoint_path(ckpt_dir, 'tmp', prefix)
        checkpoint_files = [f for f in checkpoint_files if f != ckpt_tmp_path]
        if not checkpoint_files:
            return target
        # Natural sort orders by numeric value, so the highest step is last.
        ckpt_path = checkpoint_files[-1]

    logging.info('Restoring checkpoint from %s', ckpt_path)
    with gfile.GFile(ckpt_path, 'rb') as fp:
        return serialization.from_bytes(target, fp.read())
Exemple #2
0
 def test_save_restore_checkpoints_w_float_steps(self):
     """Round-trips checkpoints whose steps are floats, incl. negative ones."""
     tmp_dir = self.create_tempdir().full_path
     test_object0 = {
         'a': np.array([0, 0, 0], np.int32),
         'b': np.array([0, 0, 0], np.int32)
     }
     test_object1 = {
         'a': np.array([1, 2, 3], np.int32),
         'b': np.array([1, 1, 1], np.int32)
     }
     test_object2 = {
         'a': np.array([4, 5, 6], np.int32),
         'b': np.array([2, 2, 2], np.int32)
     }
     # Create leftover temporary checkpoint, which should be ignored.
     # Write to it and close it via `with`: a bare `gfile.GFile(path, 'w')`
     # leaks the handle, and an implementation that only opens the file on
     # first write would never create it, silently skipping this scenario.
     with gfile.GFile(os.path.join(tmp_dir, 'test_tmp'), 'w') as fp:
         fp.write('')
     self.assertIn('test_tmp', os.listdir(tmp_dir))
     checkpoints.save_checkpoint(tmp_dir,
                                 test_object1,
                                 0.0,
                                 prefix='test_',
                                 keep=1)
     self.assertIn('test_0.0', os.listdir(tmp_dir))
     # Saving writes through the temporary path and renames it away.
     self.assertNotIn('test_tmp', os.listdir(tmp_dir))
     new_object = checkpoints.restore_checkpoint(tmp_dir,
                                                 test_object0,
                                                 prefix='test_')
     jtu.check_eq(new_object, test_object1)
     checkpoints.save_checkpoint(tmp_dir,
                                 test_object1,
                                 2.0,
                                 prefix='test_',
                                 keep=1)
     # Pruning keeps the highest-valued files by natural sort, so with keep=1
     # the just-saved 1.0 is removed immediately because 2.0 sorts higher.
     checkpoints.save_checkpoint(tmp_dir,
                                 test_object2,
                                 1.0,
                                 prefix='test_',
                                 keep=1)
     new_object = checkpoints.restore_checkpoint(tmp_dir,
                                                 test_object0,
                                                 prefix='test_')
     jtu.check_eq(new_object, test_object1)
     checkpoints.save_checkpoint(tmp_dir,
                                 test_object2,
                                 3.0,
                                 prefix='test_',
                                 keep=2)
     # -1.0 sorts lowest, so keep=2 prunes it and leaves 2.0 and 3.0.
     checkpoints.save_checkpoint(tmp_dir,
                                 test_object1,
                                 -1.0,
                                 prefix='test_',
                                 keep=2)
     new_object = checkpoints.restore_checkpoint(tmp_dir,
                                                 test_object0,
                                                 prefix='test_')
     self.assertIn('test_3.0', os.listdir(tmp_dir))
     self.assertIn('test_2.0', os.listdir(tmp_dir))
     jtu.check_eq(new_object, test_object2)
Exemple #3
0
def save_checkpoint(ckpt_dir,
                    target,
                    step,
                    prefix='checkpoint_',
                    keep=1,
                    overwrite=False):
    """Serialize `target` to a checkpoint file, then prune older checkpoints.

    To be pre-emption safe, the bytes are written to a temporary file first
    and only renamed to the final checkpoint name once fully written; surplus
    checkpoint files are removed afterwards.

    Args:
      ckpt_dir: str: directory in which checkpoint files are stored.
      target: serializable flax object, usually a flax optimizer.
      step: int or float: training step or other metric used in the filename.
      prefix: str: file name prefix; every file in `ckpt_dir` matching
        `prefix*` takes part in pruning.
      keep: int: number of the highest-valued (by natural sort, not by
        recency) matching files to retain. Note that `keep=0` removes nothing,
        since `files[:-0]` is empty.
      overwrite: bool: whether the final rename may replace an existing file.

    Returns:
      Filename of saved checkpoint.
    """
    logging.info('Saving checkpoint at step: %s', step)
    staging_path = _checkpoint_path(ckpt_dir, 'tmp', prefix)
    final_path = _checkpoint_path(ckpt_dir, step, prefix)
    gfile.makedirs(os.path.dirname(final_path))

    # Stage the serialized state in the temporary location first.
    logging.info('Writing to temporary checkpoint location: %s', staging_path)
    with gfile.GFile(staging_path, 'wb') as out:
        out.write(serialization.to_bytes(target))

    # Promote the staged file only once writing has fully completed.
    gfile.rename(staging_path, final_path, overwrite=overwrite)
    logging.info('Saved checkpoint at %s', final_path)

    # Prune: everything except the `keep` highest-sorted files is deleted.
    pattern = os.path.join(ckpt_dir, f'{prefix}') + '*'
    existing = natural_sort(gfile.glob(pattern))
    if len(existing) <= keep:
        return final_path
    for stale in existing[:-keep]:
        logging.info('Removing checkpoint at %s', stale)
        gfile.remove(stale)
    return final_path
Exemple #4
0
 def test_save_restore_checkpoints(self):
     """Round-trips integer-step checkpoints, pruning, and explicit steps."""
     tmp_dir = self.create_tempdir().full_path
     test_object0 = {
         'a': np.array([0, 0, 0], np.int32),
         'b': np.array([0, 0, 0], np.int32)
     }
     test_object1 = {
         'a': np.array([1, 2, 3], np.int32),
         'b': np.array([1, 1, 1], np.int32)
     }
     test_object2 = {
         'a': np.array([4, 5, 6], np.int32),
         'b': np.array([2, 2, 2], np.int32)
     }
     # With no checkpoint files present, the target comes back unchanged.
     new_object = checkpoints.restore_checkpoint(tmp_dir,
                                                 test_object0,
                                                 prefix='test_')
     jtu.check_eq(new_object, test_object0)
     # Create leftover temporary checkpoint, which should be ignored.
     # Write to it and close it via `with`: a bare `gfile.GFile(path, 'w')`
     # leaks the handle, and an implementation that only opens the file on
     # first write would never create it, silently skipping this scenario.
     with gfile.GFile(os.path.join(tmp_dir, 'test_tmp'), 'w') as fp:
         fp.write('')
     self.assertIn('test_tmp', os.listdir(tmp_dir))
     checkpoints.save_checkpoint(tmp_dir,
                                 test_object1,
                                 0,
                                 prefix='test_',
                                 keep=1)
     self.assertIn('test_0', os.listdir(tmp_dir))
     # Saving writes through the temporary path and renames it away.
     self.assertNotIn('test_tmp', os.listdir(tmp_dir))
     new_object = checkpoints.restore_checkpoint(tmp_dir,
                                                 test_object0,
                                                 prefix='test_')
     jtu.check_eq(new_object, test_object1)
     checkpoints.save_checkpoint(tmp_dir,
                                 test_object1,
                                 1,
                                 prefix='test_',
                                 keep=1)
     checkpoints.save_checkpoint(tmp_dir,
                                 test_object2,
                                 2,
                                 prefix='test_',
                                 keep=1)
     new_object = checkpoints.restore_checkpoint(tmp_dir,
                                                 test_object0,
                                                 prefix='test_')
     jtu.check_eq(new_object, test_object2)
     checkpoints.save_checkpoint(tmp_dir,
                                 test_object2,
                                 3,
                                 prefix='test_',
                                 keep=2)
     # keep=2 prunes step 2, leaving steps 3 and 4 on disk.
     checkpoints.save_checkpoint(tmp_dir,
                                 test_object1,
                                 4,
                                 prefix='test_',
                                 keep=2)
     new_object = checkpoints.restore_checkpoint(tmp_dir,
                                                 test_object0,
                                                 prefix='test_')
     jtu.check_eq(new_object, test_object1)
     # An explicit step restores that specific file rather than the latest.
     new_object = checkpoints.restore_checkpoint(tmp_dir,
                                                 test_object0,
                                                 step=3,
                                                 prefix='test_')
     jtu.check_eq(new_object, test_object2)
     # Requesting a step with no file on disk is an error.
     with self.assertRaises(ValueError):
         checkpoints.restore_checkpoint(tmp_dir,
                                        test_object0,
                                        step=5,
                                        prefix='test_')
Exemple #5
0
def restore_from_path(ckpt_path, target):
    """Deserialize the checkpoint file at an explicit path into `target`.

    Args:
      ckpt_path: str: path of the checkpoint file; it is first passed through
        `check_and_convert_gcs_filepath` before being opened.
      target: matching object to rebuild via the deserialized state-dict.

    Returns:
      `target` updated from the bytes stored in the checkpoint file.
    """
    resolved_path = check_and_convert_gcs_filepath(ckpt_path)
    logging.info('Restoring checkpoint from %s', resolved_path)
    with gfile.GFile(resolved_path, 'rb') as src:
        raw_bytes = src.read()
    return serialization.from_bytes(target, raw_bytes)
Exemple #6
0
def save_model(filename: str, model: nn.Module) -> None:
    """Serialize `model` with flax serialization and write it to `filename`.

    Args:
      filename: destination path; any missing parent directories are created.
      model: object to serialize via `serialization.to_bytes`.
    """
    parent_dir = os.path.dirname(filename)
    # A bare filename such as "model.bin" has no directory component
    # (dirname == ''). There is nothing to create in that case, and
    # os.makedirs-style implementations reject an empty path.
    if parent_dir:
        gfile.makedirs(parent_dir)
    with gfile.GFile(filename, "wb") as fp:
        fp.write(serialization.to_bytes(model))
Exemple #7
0
def load_model(filename: str, model: nn.Module) -> nn.Module:
    """Read serialized bytes from `filename` and restore them into `model`.

    Args:
      filename: path of a serialized model file (e.g. one written by
        `save_model`).
      model: template object whose structure the stored bytes are loaded into.

    Returns:
      The result of `serialization.from_bytes(model, <file contents>)`.
    """
    with gfile.GFile(filename, "rb") as src:
        payload = src.read()
    return serialization.from_bytes(model, payload)