Example #1
def save_all_states():
    """
    Invokes `save_state` on all `State` objects for which `State.skip` is False.
    This function can be used to trigger a global checkpoint and save every
    `State` in the current job.
    """
    if from_ray():
        from ray.tune.trainable import TrainableUtil
        checkpoint_dir = TrainableUtil.make_checkpoint_dir("/tmp",
                                                           index=None,
                                                           override=True)
    else:
        checkpoint_dir = checkpoint_path()
    for state in _STATES_TO_NAMES:
        save_state(state, checkpoint_dir)

    # Prevent corrupting original state files in case the process got killed
    # during state file writing.
    if replica_rank() == 0 and checkpoint_dir is not None:
        tmp_ckpt_dir = _get_tmp_ckpt_dir(checkpoint_dir)
        ckpt_dir = os.path.join(checkpoint_dir,
                                f"{CKPT_DIR_PREFIX}{num_restarts()}")
        os.rename(tmp_ckpt_dir, ckpt_dir)  # atomic, rename(src, dst)
        for dir_name in os.listdir(checkpoint_dir):
            dir_path = os.path.join(checkpoint_dir, dir_name)
            if dir_name.startswith(CKPT_DIR_PREFIX) and dir_path != ckpt_dir:
                shutil.rmtree(dir_path)
        return checkpoint_dir
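
The tmp-directory-plus-`os.rename` step above is what keeps a half-written checkpoint from corrupting the previous one: the rename is atomic within a single filesystem, so readers only ever see a complete checkpoint directory. A minimal standalone sketch of the same pattern (the function, file, and directory names below are hypothetical, not taken from adaptdl or Ray):

import os
import shutil
import tempfile

def write_checkpoint_atomically(checkpoint_root, step, payload):
    """Write checkpoint data into a temp dir, then rename it into place."""
    os.makedirs(checkpoint_root, exist_ok=True)
    # Stage everything in a hidden temp dir on the same filesystem.
    tmp_dir = tempfile.mkdtemp(dir=checkpoint_root, prefix=".tmp-ckpt-")
    with open(os.path.join(tmp_dir, "state.bin"), "wb") as f:
        f.write(payload)
    # os.rename(src, dst) is atomic on one filesystem, so a crash before
    # this line leaves the previous checkpoint untouched.
    final_dir = os.path.join(checkpoint_root, f"checkpoint-{step}")
    if os.path.exists(final_dir):
        shutil.rmtree(final_dir)
    os.rename(tmp_dir, final_dir)
    return final_dir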
Example #2
def mk_null_checkpoint_dir(logdir):
    """Indicate that the given checkpoint doesn't have state."""
    checkpoint_dir = TrainableUtil.make_checkpoint_dir(logdir,
                                                       index=-1,
                                                       override=True)
    open(os.path.join(checkpoint_dir, NULL_MARKER), "a").close()
    return checkpoint_dir
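
The empty marker file written above is how downstream code can tell a state-less ("null") checkpoint apart from a real one: it only has to check whether the marker exists. A minimal sketch of that consuming side (the helper name is made up for illustration; NULL_MARKER itself is a constant of the surrounding Ray module):

import os

def has_marker(checkpoint_dir, marker_name):
    """Return True if the given marker file exists inside checkpoint_dir."""
    return os.path.exists(os.path.join(checkpoint_dir, marker_name))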
Example #3
def mk_temp_checkpoint_dir(logdir):
    """Indicate that the checkpoint is only for restoration."""
    temporary_checkpoint_dir = TrainableUtil.make_checkpoint_dir(
        logdir, index="tmp" + uuid.uuid4().hex[:6], override=True
    )
    open(os.path.join(temporary_checkpoint_dir, TEMP_MARKER), "a").close()
    return temporary_checkpoint_dir
Example #4
    def create_perm_checkpoint(checkpoint_dir, logdir, step):
        """Copies temporary checkpoint to a permanent checkpoint directory."""
        checkpoint_dir = os.path.abspath(checkpoint_dir)
        temporary_marker = os.path.join(checkpoint_dir, TEMP_MARKER)
        assert os.path.exists(temporary_marker), (
            "Should not be calling this method on a permanent checkpoint.")
        os.remove(temporary_marker)
        perm_checkpoint_dir = TrainableUtil.make_checkpoint_dir(
            logdir, index=step, override=True)
        shutil.rmtree(perm_checkpoint_dir)

        shutil.copytree(checkpoint_dir, perm_checkpoint_dir)
        assert not os.path.exists(
            os.path.join(perm_checkpoint_dir, TEMP_MARKER))
        return perm_checkpoint_dir
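
Examples #3 and #4 are two halves of one flow: during restoration a throwaway, marker-tagged checkpoint directory is created, and on the next save it is promoted to a permanent, step-indexed one. A rough usage sketch, assuming an older Ray Tune release where these helpers are static methods of FuncCheckpointUtil in ray.tune.function_runner (the import path, logdir, and step value are assumptions):

import os

from ray.tune.function_runner import FuncCheckpointUtil  # import path assumed

logdir = "/tmp/tune/MyTrainable123"
os.makedirs(logdir, exist_ok=True)

# Restore path: unpack state into a temporary, TEMP_MARKER-tagged directory.
tmp_dir = FuncCheckpointUtil.mk_temp_checkpoint_dir(logdir)
assert FuncCheckpointUtil.is_temp_checkpoint_dir(tmp_dir)

# Next save: copy the temporary directory to a permanent checkpoint dir
# keyed by the training step; the TEMP_MARKER does not survive the copy.
perm_dir = FuncCheckpointUtil.create_perm_checkpoint(
    checkpoint_dir=tmp_dir, logdir=logdir, step=3)
assert not FuncCheckpointUtil.is_temp_checkpoint_dir(perm_dir)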
Example #5
    def save(self, checkpoint_path=None) -> str:
        if checkpoint_path:
            raise ValueError("Checkpoint path should not be used with function API.")

        checkpoint = self._status_reporter.get_checkpoint()
        state = self.get_state()

        if not checkpoint:
            state.update(iteration=0, timesteps_total=0, episodes_total=0)
            # We drop a marker here to indicate that the checkpoint is empty
            checkpoint = FuncCheckpointUtil.mk_null_checkpoint_dir(self.logdir)
            parent_dir = checkpoint
        elif isinstance(checkpoint, dict):
            parent_dir = TrainableUtil.make_checkpoint_dir(
                self.logdir, index=self.training_iteration
            )
        elif isinstance(checkpoint, str):
            parent_dir = TrainableUtil.find_checkpoint_dir(checkpoint)
            # When the trainable is restored, a temporary checkpoint
            # is created. However, when saved, it should become permanent.
            # Ideally, there are no save calls upon a temporary
            # checkpoint, but certain schedulers might.
            if FuncCheckpointUtil.is_temp_checkpoint_dir(parent_dir):
                relative_path = os.path.relpath(checkpoint, parent_dir)
                parent_dir = FuncCheckpointUtil.create_perm_checkpoint(
                    checkpoint_dir=parent_dir,
                    logdir=self.logdir,
                    step=self.training_iteration,
                )
                checkpoint = os.path.abspath(os.path.join(parent_dir, relative_path))
        else:
            raise ValueError(
                "Provided checkpoint was expected to have "
                "type (str, dict). Got {}.".format(type(checkpoint))
            )

        checkpoint_path = TrainableUtil.process_checkpoint(
            checkpoint, parent_dir, state
        )

        self._postprocess_checkpoint(checkpoint_path)

        self._maybe_save_to_cloud(parent_dir)

        return checkpoint_path
Example #6
    def save(self, checkpoint_path=None):
        if checkpoint_path:
            raise ValueError(
                "Checkpoint path should not be used with function API.")

        checkpoint = self._status_reporter.get_checkpoint()
        state = self.get_state()

        if not checkpoint:
            state.update(iteration=0, timesteps_total=0, episodes_total=0)
            parent_dir = self.create_default_checkpoint_dir()
        elif isinstance(checkpoint, dict):
            parent_dir = TrainableUtil.make_checkpoint_dir(
                self.logdir, index=self.training_iteration)
        else:
            parent_dir = TrainableUtil.find_checkpoint_dir(checkpoint)
        checkpoint_path = TrainableUtil.process_checkpoint(
            checkpoint, parent_dir, state)
        return checkpoint_path
Example #7
def make_checkpoint_dir(self, step):
    checkpoint_dir = TrainableUtil.make_checkpoint_dir(self.logdir,
                                                       index=step)
    logger.debug("Making checkpoint dir at %s", checkpoint_dir)
    return checkpoint_dir
Example #8
def make_checkpoint_dir(self, step=None):
    checkpoint_dir = TrainableUtil.make_checkpoint_dir(self.logdir,
                                                       index=step)
    return checkpoint_dir
Example #9
def create_default_checkpoint_dir(self):
    self.default_checkpoint_dir = TrainableUtil.make_checkpoint_dir(
        self.logdir, index="default")
    return self.default_checkpoint_dir
Example #10
def setUp(self):
    self.checkpoint_dir = "/tmp/tune/MyTrainable123"
    TrainableUtil.make_checkpoint_dir(self.checkpoint_dir)
Example #11
def setUp(self):
    self.checkpoint_dir = os.path.join(ray.utils.get_user_temp_dir(),
                                       "tune", "MyTrainable123")
    TrainableUtil.make_checkpoint_dir(self.checkpoint_dir)
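
Every example above ultimately calls the same utility. For reference, a minimal direct call, assuming an older Ray Tune release where the import path used in Example #1 (from ray.tune.trainable import TrainableUtil) is still valid; the logdir is just a scratch directory:

import os
import tempfile

from ray.tune.trainable import TrainableUtil  # import path as in Example #1

logdir = tempfile.mkdtemp(prefix="my_trainable_")

# Create (or recreate, because override=True) a checkpoint directory for
# step 7 inside the trial logdir and return its path.
ckpt_dir = TrainableUtil.make_checkpoint_dir(logdir, index=7, override=True)

assert os.path.isdir(ckpt_dir)
print(ckpt_dir)  # typically a "checkpoint_<index>" subdirectory of logdir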