def restore_checkpoint(ckpt_dir, target, step=None, prefix='checkpoint_'):
  """Restore last/best checkpoint from checkpoints in path.

  Sorts the checkpoint files naturally, returning the highest-valued
  file, e.g.:
    ckpt_1, ckpt_2, ckpt_3 --> ckpt_3
    ckpt_0.01, ckpt_0.1, ckpt_0.001 --> ckpt_0.1
    ckpt_-1.0, ckpt_1.0, ckpt_1e5 --> ckpt_1e5

  Args:
    ckpt_dir: str: directory of checkpoints to restore from.
    target: matching object to rebuild via deserialized state-dict.
    step: int: step number to load or None to load latest.
    prefix: str: name prefix of checkpoint files.

  Returns:
    Restored `target` updated from checkpoint file, or if no step specified
    and no checkpoint files present, returns the passed-in `target` unchanged.

  Raises:
    ValueError: if `step` is given but no matching checkpoint file exists.
  """
  # BUGFIX: was `if step:`, which treated step=0 as "not specified" and
  # silently loaded the latest checkpoint instead of checkpoint 0.
  if step is not None:
    ckpt_path = _checkpoint_path(ckpt_dir, step, prefix)
    if not gfile.exists(ckpt_path):
      raise ValueError(f'Matching checkpoint not found: {ckpt_path}')
  else:
    glob_path = os.path.join(ckpt_dir, f'{prefix}*')
    checkpoint_files = natural_sort(gfile.glob(glob_path))
    # A leftover temporary file from an interrupted save must not be
    # mistaken for a real checkpoint.
    ckpt_tmp_path = _checkpoint_path(ckpt_dir, 'tmp', prefix)
    checkpoint_files = [f for f in checkpoint_files if f != ckpt_tmp_path]
    if not checkpoint_files:
      # Nothing saved yet: hand back the caller's object untouched.
      return target
    ckpt_path = checkpoint_files[-1]

  logging.info('Restoring checkpoint from %s', ckpt_path)
  with gfile.GFile(ckpt_path, 'rb') as fp:
    return serialization.from_bytes(target, fp.read())
def test_save_restore_checkpoints_w_float_steps(self):
  """Float step values must natural-sort; the highest value is 'latest'."""
  tmp_dir = self.create_tempdir().full_path

  def make_state(a_values, b_value):
    # Small fixture factory: 'a' varies, 'b' is a constant triple.
    return {
        'a': np.array(a_values, np.int32),
        'b': np.array([b_value] * 3, np.int32),
    }

  test_object0 = make_state([0, 0, 0], 0)
  test_object1 = make_state([1, 2, 3], 1)
  test_object2 = make_state([4, 5, 6], 2)

  # Create leftover temporary checkpoint, which should be ignored.
  gfile.GFile(os.path.join(tmp_dir, 'test_tmp'), 'w')

  checkpoints.save_checkpoint(
      tmp_dir, test_object1, 0.0, prefix='test_', keep=1)
  self.assertIn('test_0.0', os.listdir(tmp_dir))
  restored = checkpoints.restore_checkpoint(
      tmp_dir, test_object0, prefix='test_')
  jtu.check_eq(restored, test_object1)

  # Saving a lower step (1.0) after a higher one (2.0): the higher-valued
  # file remains the "latest" and keep=1 prunes the rest.
  checkpoints.save_checkpoint(
      tmp_dir, test_object1, 2.0, prefix='test_', keep=1)
  checkpoints.save_checkpoint(
      tmp_dir, test_object2, 1.0, prefix='test_', keep=1)
  restored = checkpoints.restore_checkpoint(
      tmp_dir, test_object0, prefix='test_')
  jtu.check_eq(restored, test_object1)

  # Negative steps sort below positive ones, so 3.0 stays the latest.
  checkpoints.save_checkpoint(
      tmp_dir, test_object2, 3.0, prefix='test_', keep=2)
  checkpoints.save_checkpoint(
      tmp_dir, test_object1, -1.0, prefix='test_', keep=2)
  restored = checkpoints.restore_checkpoint(
      tmp_dir, test_object0, prefix='test_')
  self.assertIn('test_3.0', os.listdir(tmp_dir))
  self.assertIn('test_2.0', os.listdir(tmp_dir))
  jtu.check_eq(restored, test_object2)
def save_checkpoint(ckpt_dir, target, step, prefix='checkpoint_', keep=1,
                    overwrite=False):
  """Save a checkpoint of the model.

  Attempts to be pre-emption safe by writing to temporary before
  a final rename and cleanup of past files.

  Args:
    ckpt_dir: str: path to store checkpoint files in.
    target: serializable flax object, usually a flax optimizer.
    step: int or float: training step number or other metric number.
    prefix: str: checkpoint file name prefix.
    keep: number of past checkpoint files to keep.
    overwrite: bool: allow overwriting when writing a checkpoint.

  Returns:
    Filename of saved checkpoint.
  """
  logging.info('Saving checkpoint at step: %s', step)
  final_path = _checkpoint_path(ckpt_dir, step, prefix)
  staging_path = _checkpoint_path(ckpt_dir, 'tmp', prefix)
  gfile.makedirs(os.path.dirname(final_path))

  # Serialize into a staging file first so a pre-emption mid-write never
  # leaves a truncated file under the real checkpoint name.
  logging.info('Writing to temporary checkpoint location: %s', staging_path)
  with gfile.GFile(staging_path, 'wb') as staging_file:
    staging_file.write(serialization.to_bytes(target))
  gfile.rename(staging_path, final_path, overwrite=overwrite)
  logging.info('Saved checkpoint at %s', final_path)

  # Prune everything but the `keep` highest-valued checkpoints.
  all_ckpts = natural_sort(gfile.glob(os.path.join(ckpt_dir, prefix) + '*'))
  excess = len(all_ckpts) - keep
  if excess > 0:
    for stale_path in all_ckpts[:excess]:
      logging.info('Removing checkpoint at %s', stale_path)
      gfile.remove(stale_path)

  return final_path
def test_save_restore_checkpoints(self):
  """Integer-step save/restore round-trips, pruning, and error paths."""
  tmp_dir = self.create_tempdir().full_path

  def make_state(a_values, b_value):
    # Small fixture factory: 'a' varies, 'b' is a constant triple.
    return {
        'a': np.array(a_values, np.int32),
        'b': np.array([b_value] * 3, np.int32),
    }

  test_object0 = make_state([0, 0, 0], 0)
  test_object1 = make_state([1, 2, 3], 1)
  test_object2 = make_state([4, 5, 6], 2)

  # With no checkpoint on disk, restore hands back the target unchanged.
  restored = checkpoints.restore_checkpoint(
      tmp_dir, test_object0, prefix='test_')
  jtu.check_eq(restored, test_object0)

  # Create leftover temporary checkpoint, which should be ignored.
  gfile.GFile(os.path.join(tmp_dir, 'test_tmp'), 'w')

  checkpoints.save_checkpoint(
      tmp_dir, test_object1, 0, prefix='test_', keep=1)
  self.assertIn('test_0', os.listdir(tmp_dir))
  restored = checkpoints.restore_checkpoint(
      tmp_dir, test_object0, prefix='test_')
  jtu.check_eq(restored, test_object1)

  # keep=1 retains only the newest checkpoint.
  checkpoints.save_checkpoint(
      tmp_dir, test_object1, 1, prefix='test_', keep=1)
  checkpoints.save_checkpoint(
      tmp_dir, test_object2, 2, prefix='test_', keep=1)
  restored = checkpoints.restore_checkpoint(
      tmp_dir, test_object0, prefix='test_')
  jtu.check_eq(restored, test_object2)

  # keep=2 retains both; an explicit step loads that exact file.
  checkpoints.save_checkpoint(
      tmp_dir, test_object2, 3, prefix='test_', keep=2)
  checkpoints.save_checkpoint(
      tmp_dir, test_object1, 4, prefix='test_', keep=2)
  restored = checkpoints.restore_checkpoint(
      tmp_dir, test_object0, prefix='test_')
  jtu.check_eq(restored, test_object1)
  restored = checkpoints.restore_checkpoint(
      tmp_dir, test_object0, step=3, prefix='test_')
  jtu.check_eq(restored, test_object2)

  # Asking for a step that was never saved must raise.
  with self.assertRaises(ValueError):
    checkpoints.restore_checkpoint(
        tmp_dir, test_object0, step=5, prefix='test_')
def restore_from_path(ckpt_path, target):
  """Rebuild `target` from the serialized checkpoint at `ckpt_path`.

  Args:
    ckpt_path: path (possibly a GCS path that needs conversion) to a
      checkpoint file written with `serialization.to_bytes`.
    target: matching object to rebuild via deserialized state-dict.

  Returns:
    `target` updated from the bytes stored at `ckpt_path`.
  """
  resolved_path = check_and_convert_gcs_filepath(ckpt_path)
  logging.info('Restoring checkpoint from %s', resolved_path)
  with gfile.GFile(resolved_path, 'rb') as ckpt_file:
    raw_bytes = ckpt_file.read()
  return serialization.from_bytes(target, raw_bytes)
def save_model(filename: str, model: nn.Module) -> None:
    """Serialize `model` to bytes and write them to `filename`.

    Creates the parent directory first if it does not exist.
    """
    gfile.makedirs(os.path.dirname(filename))
    payload = serialization.to_bytes(model)
    with gfile.GFile(filename, "wb") as out_file:
        out_file.write(payload)
def load_model(filename: str, model: nn.Module) -> nn.Module:
    """Read `filename` and rebuild `model` from its serialized bytes."""
    with gfile.GFile(filename, "rb") as in_file:
        raw_bytes = in_file.read()
    return serialization.from_bytes(model, raw_bytes)