Example #1
    def testFindCheckpointDir(self):
        checkpoint_path = os.path.join(self.checkpoint_dir, "my/nested/chkpt")
        os.makedirs(checkpoint_path)
        found_dir = TrainableUtil.find_checkpoint_dir(checkpoint_path)
        self.assertEqual(self.checkpoint_dir, found_dir)

        with self.assertRaises(FileNotFoundError):
            parent = os.path.dirname(found_dir)
            TrainableUtil.find_checkpoint_dir(parent)
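The test above presumes that self.checkpoint_dir was itself created (in setup code not shown) by TrainableUtil.make_checkpoint_dir, since that call is what marks a directory as a checkpoint root (see Examples #7, #8, and #27). A minimal round-trip sketch under that assumption; the log directory and file name below are placeholders:

import os
from ray.tune.trainable import TrainableUtil  # import path as used in Example #7

logdir = "/tmp/tune_logdir"  # placeholder log directory
checkpoint_dir = TrainableUtil.make_checkpoint_dir(logdir, index=0, override=True)

# Any existing path nested inside the checkpoint directory resolves back to it.
nested = os.path.join(checkpoint_dir, "model", "weights.bin")
os.makedirs(os.path.dirname(nested), exist_ok=True)
open(nested, "wb").close()
assert TrainableUtil.find_checkpoint_dir(nested) == checkpoint_dir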
Example #2
 def save_checkpoint(self, checkpoint):
     if isinstance(checkpoint, str):
         try:
             TrainableUtil.find_checkpoint_dir(checkpoint)
         except FileNotFoundError:
             logger.error("Checkpoint must be created with path given from "
                          "make_checkpoint_dir.")
             raise
     self._last_checkpoint = checkpoint
     self._fresh_checkpoint = True
Example #3
def save_all_states_remote(self, trial_state):
    """ Save all of AdaptDL's job state and return it as an in-memory
    object."""
    checkpoint = save_all_states()
    parent_dir = TrainableUtil.find_checkpoint_dir(checkpoint)
    checkpoint_path = TrainableUtil.process_checkpoint(checkpoint,
                                                       parent_dir,
                                                       trial_state)
    checkpoint_obj = TrainableUtil.checkpoint_to_object(checkpoint_path)
    # Done with the directory, remove
    shutil.rmtree(checkpoint_path)
    return checkpoint_obj
Example #4
    def get_trial_checkpoints_paths(self, trial, metric=TRAINING_ITERATION):
        """Gets paths and metrics of all persistent checkpoints of a trial.

        Args:
            trial (Trial): The log directory of a trial, or a trial instance.
            metric (str): key for trial info to return, e.g. "mean_accuracy".
                "training_iteration" is used by default.

        Returns:
            List of [path, metric] for all persistent checkpoints of the trial.
        """
        if isinstance(trial, str):
            trial_dir = os.path.expanduser(trial)
            # Get checkpoints from logdir.
            chkpt_df = TrainableUtil.get_checkpoints_paths(trial_dir)

            # Join with trial dataframe to get metrics.
            trial_df = self.trial_dataframes[trial_dir]
            path_metric_df = chkpt_df.merge(
                trial_df, on="training_iteration", how="inner")
            return path_metric_df[["chkpt_path", metric]].values.tolist()
        elif isinstance(trial, Trial):
            checkpoints = trial.checkpoint_manager.best_checkpoints()
            return [[c.value, c.result[metric]] for c in checkpoints]
        else:
            raise ValueError("trial should be a string or a Trial instance.")
Example #5
    def get_trial_checkpoints_paths(self, trial, metric=TRAINING_ITERATION):
        """Returns a list of [path, metric] lists for all disk checkpoints of
         a trial.

        Arguments:
            trial (Trial): The log directory of a trial, or a trial instance.
            metric (str): key for trial info to return, e.g. "mean_accuracy".
                "training_iteration" is used by default.
        """

        if isinstance(trial, str):
            trial_dir = os.path.expanduser(trial)

            # get checkpoints from logdir
            chkpt_df = TrainableUtil.get_checkpoints_paths(trial_dir)

            # join with trial dataframe to get metrics
            trial_df = self.trial_dataframes[trial_dir]
            path_metric_df = chkpt_df.merge(
                trial_df, on="training_iteration", how="inner")
            return path_metric_df[["chkpt_path", metric]].values.tolist()
        elif isinstance(trial, Trial):
            checkpoints = trial.checkpoint_manager.best_checkpoints()
            # TODO(ujvl): Remove condition once the checkpoint manager is
            #  modified to only track PERSISTENT checkpoints.
            return [[c.value, c.result[metric]] for c in checkpoints
                    if c.storage == Checkpoint.PERSISTENT]
        else:
            raise ValueError("trial should be a string or a Trial instance.")
Example #6
    def save_checkpoint(self, checkpoint_dir: str = ""):
        if checkpoint_dir:
            raise ValueError(
                "Checkpoint dir should not be used with function API.")

        checkpoint = self._status_reporter.get_checkpoint()

        if not checkpoint:
            # We drop a marker here to indicate that the checkpoint is empty
            checkpoint = FuncCheckpointUtil.mk_null_checkpoint_dir(self.logdir)
            parent_dir = checkpoint
        elif isinstance(checkpoint, dict):
            return checkpoint
        elif isinstance(checkpoint, str):
            parent_dir = TrainableUtil.find_checkpoint_dir(checkpoint)
            # When the trainable is restored, a temporary checkpoint
            # is created. However, when saved, it should become permanent.
            # Ideally, there are no save calls upon a temporary
            # checkpoint, but certain schedulers might.
            if FuncCheckpointUtil.is_temp_checkpoint_dir(parent_dir):
                parent_dir = FuncCheckpointUtil.create_perm_checkpoint(
                    checkpoint_dir=parent_dir,
                    logdir=self.logdir,
                    step=self.training_iteration,
                )
        else:
            raise ValueError("Provided checkpoint was expected to have "
                             "type (str, dict). Got {}.".format(
                                 type(checkpoint)))

        return parent_dir
Example #7
def save_all_states():
    """
    Invokes `save_state` on all `State` objects for which `State.skip` is
    False. This function can be used to trigger a global checkpoint and save
    every `State` in the current job.
    """
    if from_ray():
        from ray.tune.trainable import TrainableUtil
        checkpoint_dir = TrainableUtil.make_checkpoint_dir("/tmp",
                                                           index=None,
                                                           override=True)
    else:
        checkpoint_dir = checkpoint_path()
    for state in _STATES_TO_NAMES:
        save_state(state, checkpoint_dir)

    # Prevent corrupting original state files in case the process got killed
    # during state file writing.
    if replica_rank() == 0 and checkpoint_dir is not None:
        tmp_ckpt_dir = _get_tmp_ckpt_dir(checkpoint_dir)
        ckpt_dir = os.path.join(checkpoint_dir,
                                f"{CKPT_DIR_PREFIX}{num_restarts()}")
        os.rename(tmp_ckpt_dir, ckpt_dir)  # atomic, rename(src, dst)
        for dir_name in os.listdir(checkpoint_dir):
            dir_path = os.path.join(checkpoint_dir, dir_name)
            if dir_name.startswith(CKPT_DIR_PREFIX) and dir_path != ckpt_dir:
                shutil.rmtree(dir_path)
        return checkpoint_dir
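The final block above is an instance of the write-to-a-temporary-location-then-rename pattern: the new checkpoint only appears under its final name once it is complete, so a crash mid-write cannot corrupt the previous one. A minimal, library-free sketch of the same idea at the single-file level (function and file names are placeholders):

import json
import os
import tempfile

def atomic_write_json(data, dst_path):
    # Write to a temporary file in the destination directory, then rename it
    # into place. os.replace is atomic when src and dst share a filesystem.
    fd, tmp_path = tempfile.mkstemp(dir=os.path.dirname(dst_path) or ".")
    try:
        with os.fdopen(fd, "w") as f:
            json.dump(data, f)
        os.replace(tmp_path, dst_path)
    except BaseException:
        os.unlink(tmp_path)
        raise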
Example #8
 def mk_temp_checkpoint_dir(logdir):
     """Indicate that the checkpoint is only for restoration."""
     temporary_checkpoint_dir = TrainableUtil.make_checkpoint_dir(
         logdir, index="tmp" + uuid.uuid4().hex[:6], override=True
     )
     open(os.path.join(temporary_checkpoint_dir, TEMP_MARKER), "a").close()
     return temporary_checkpoint_dir
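The companion check FuncCheckpointUtil.is_temp_checkpoint_dir (used in Examples #6 and #15) is not reproduced in this collection. Given that Examples #8, #11, and #21 identify special checkpoint directories purely through marker files, it plausibly reduces to a marker test; the sketch below is that assumption, not the library's confirmed implementation:

import os

def is_temp_checkpoint_dir(checkpoint_dir):
    # Assumption: a temporary checkpoint directory is identified solely by the
    # TEMP_MARKER file dropped into it by mk_temp_checkpoint_dir (Example #8).
    # TEMP_MARKER is the module-level constant used in Examples #8 and #21.
    return os.path.exists(os.path.join(checkpoint_dir, TEMP_MARKER))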
Example #9
    def pause(self, trial_runner):
        """ Pause the AdaptDLTrial with a checkpoint. We try to remove the PG
        attached to this trial"""
        assert self.runner is not None
        checkpoint_obj = ray.get(
            self.runner.save_all_states.remote(self.runner.get_state.remote()))
        # Serialize to disk
        temp_checkpoint_dir = (FuncCheckpointUtil.mk_temp_checkpoint_dir(
            self.logdir))
        checkpoint_path = TrainableUtil.create_from_pickle(
            checkpoint_obj, temp_checkpoint_dir)

        # Trial will be restored from the checkpoint_path when it's resumed
        self.restore_path = checkpoint_path

        # Clear the allocation. This is a hack to clear the PG associated with
        # the trial. We assign a temporary PG which will get replaced with a
        # real PG once we resume the trial. This is needed because Tune likes
        # to keep the PGs around even for PAUSED trials.
        self.placement_group_factory = PlacementGroupFactory([{"CPU": 0.001}])
        # This forces Tune to garbage-collect unneeded PGs which can then be
        # reused
        trial_runner.trial_executor._pg_manager.\
            reconcile_placement_groups([self])
        logger.debug(f"PAUSING {self} w/ checkpoint at {checkpoint_path}")
Example #10
    def set_checkpoint(self, checkpoint, is_new=True):
        """Sets the checkpoint to be returned upon get_checkpoint.

        If this is a "new" checkpoint, it will notify Tune
        (via has_new_checkpoint). Otherwise, it will NOT notify Tune.
        """
        if isinstance(checkpoint, str):
            try:
                TrainableUtil.find_checkpoint_dir(checkpoint)
            except FileNotFoundError:
                logger.error("Checkpoint must be created with path given from "
                             "make_checkpoint_dir.")
                raise
        self._last_checkpoint = checkpoint
        if is_new:
            self._fresh_checkpoint = True
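The guard above encodes the intended call sequence: user code first asks the reporter for a checkpoint directory (a matching make_checkpoint_dir helper appears in Example #27), writes its files there, and only then hands the path back. A hedged sketch of that sequence from the user's side; reporter, step, and serialized_state are placeholders:

checkpoint_dir = reporter.make_checkpoint_dir(step)   # see Example #27
path = os.path.join(checkpoint_dir, "checkpoint")     # placeholder file name
with open(path, "wb") as f:
    f.write(serialized_state)                         # placeholder payload bytes
reporter.set_checkpoint(path)                         # passes the find_checkpoint_dir guard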
Example #11
 def mk_null_checkpoint_dir(logdir):
     """Indicate that the given checkpoint doesn't have state."""
     checkpoint_dir = TrainableUtil.make_checkpoint_dir(logdir,
                                                        index=-1,
                                                        override=True)
     open(os.path.join(checkpoint_dir, NULL_MARKER), "a").close()
     return checkpoint_dir
Example #12
    def _restore(self,
                 trial,
                 checkpoint=None,
                 block=False) -> Optional[RunningJob]:
        """Restores training state from a given model checkpoint.

        Args:
            trial (Trial): The trial to be restored.
            checkpoint (Checkpoint): The checkpoint to restore from. If None,
                the most recent PERSISTENT checkpoint is used. Defaults to
                None.
            block (bool): Whether or not to block on restore before returning.

        Raises:
            RuntimeError: This error is raised if no runner is found.
            AbortTrialExecution: This error is raised if the trial is
                ineligible for restoration, given the Tune input arguments.
        """
        if checkpoint is None or checkpoint.value is None:
            checkpoint = trial.checkpoint
        if checkpoint.value is None:
            return
        if trial.runner is None:
            raise RuntimeError(
                "Trial {}: Unable to restore - no runner found.".format(trial))
        value = checkpoint.value
        if checkpoint.storage == Checkpoint.MEMORY:
            logger.debug("Trial %s: Attempting restore from object", trial)
            # Note that we don't store the remote since in-memory checkpoints
            # don't guarantee fault tolerance and don't need to be waited on.
            with _change_working_directory(trial):
                trial.runner.restore_from_object.remote(value)
        else:
            logger.debug("Trial %s: Attempting restore from %s", trial, value)
            if issubclass(trial.get_trainable_cls(), DurableTrainable):
                with _change_working_directory(trial):
                    remote = trial.runner.restore.remote(value)
            elif trial.sync_on_checkpoint:
                # This provides FT backwards compatibility in the
                # case where a DurableTrainable is not provided.
                logger.warning("Trial %s: Reading checkpoint into memory.",
                               trial)
                data_dict = TrainableUtil.pickle_checkpoint(value)
                with _change_working_directory(trial):
                    remote = trial.runner.restore_from_object.remote(data_dict)
            else:
                raise AbortTrialExecution(
                    "Pass in `sync_on_checkpoint=True` for driver-based trial"
                    "restoration. Pass in an `upload_dir` and a Trainable "
                    "extending `DurableTrainable` for remote storage-based "
                    "restoration")

            if block:
                ray.get(remote)
            else:
                trial.restoring_from = checkpoint
                running_job = RunningJob(trial, remote)
                self.jobs_running[remote] = running_job
                return running_job
Example #13
 def save_to_object(self):
     checkpoint_path = self.save()
     data_dict = TrainableUtil.pickle_checkpoint(checkpoint_path)
     out = io.BytesIO()
     if len(data_dict) > 10e6:  # getting pretty large
         logger.info("Checkpoint size is {} bytes".format(len(data_dict)))
     out.write(data_dict)
     return out.getvalue()
Example #14
 def restore_from_object(self, obj):
     self.temp_checkpoint_dir = FuncCheckpointUtil.mk_temp_checkpoint_dir(
         self.logdir
     )
     checkpoint_path = TrainableUtil.create_from_pickle(
         obj, self.temp_checkpoint_dir
     )
     self.restore(checkpoint_path)
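Examples #13 and #14 form a pair: the byte string returned by save_to_object is exactly what restore_from_object consumes, with a temporary checkpoint directory as the intermediate on the restoring side. A hedged round-trip sketch, assuming source and target are two instances of the trainable these methods belong to:

obj = source.save_to_object()     # Example #13: checkpoint directory -> bytes
target.restore_from_object(obj)   # Example #14: bytes -> temp directory -> restore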
Example #15
    def save(self, checkpoint_path=None) -> str:
        if checkpoint_path:
            raise ValueError("Checkpoint path should not be used with function API.")

        checkpoint = self._status_reporter.get_checkpoint()
        state = self.get_state()

        if not checkpoint:
            state.update(iteration=0, timesteps_total=0, episodes_total=0)
            # We drop a marker here to indicate that the checkpoint is empty
            checkpoint = FuncCheckpointUtil.mk_null_checkpoint_dir(self.logdir)
            parent_dir = checkpoint
        elif isinstance(checkpoint, dict):
            parent_dir = TrainableUtil.make_checkpoint_dir(
                self.logdir, index=self.training_iteration
            )
        elif isinstance(checkpoint, str):
            parent_dir = TrainableUtil.find_checkpoint_dir(checkpoint)
            # When the trainable is restored, a temporary checkpoint
            # is created. However, when saved, it should become permanent.
            # Ideally, there are no save calls upon a temporary
            # checkpoint, but certain schedulers might.
            if FuncCheckpointUtil.is_temp_checkpoint_dir(parent_dir):
                relative_path = os.path.relpath(checkpoint, parent_dir)
                parent_dir = FuncCheckpointUtil.create_perm_checkpoint(
                    checkpoint_dir=parent_dir,
                    logdir=self.logdir,
                    step=self.training_iteration,
                )
                checkpoint = os.path.abspath(os.path.join(parent_dir, relative_path))
        else:
            raise ValueError(
                "Provided checkpoint was expected to have "
                "type (str, dict). Got {}.".format(type(checkpoint))
            )

        checkpoint_path = TrainableUtil.process_checkpoint(
            checkpoint, parent_dir, state
        )

        self._postprocess_checkpoint(checkpoint_path)

        self._maybe_save_to_cloud(parent_dir)

        return checkpoint_path
Example #16
    def delete_checkpoint(self, checkpoint_path):
        """Deletes checkpoint from both local and remote storage.

        Args:
            checkpoint_path (str): Local path to checkpoint.
        """
        super(DurableTrainable, self).delete_checkpoint(checkpoint_path)
        local_dirpath = TrainableUtil.find_checkpoint_dir(checkpoint_path)
        self.storage_client.delete(self._storage_path(local_dirpath))
Example #17
    def save(self, checkpoint_path=None) -> str:
        if checkpoint_path:
            raise ValueError("Checkpoint path should not be used with function API.")

        checkpoint_path = self.save_checkpoint()

        parent_dir = TrainableUtil.find_checkpoint_dir(checkpoint_path)
        self._maybe_save_to_cloud(parent_dir)

        return checkpoint_path
Example #18
    def save(self, checkpoint_path=None):
        if checkpoint_path:
            raise ValueError(
                "Checkpoint path should not be used with function API.")

        checkpoint = self._status_reporter.get_checkpoint()
        state = self.get_state()

        if not checkpoint:
            state.update(iteration=0, timesteps_total=0, episodes_total=0)
            parent_dir = self.create_default_checkpoint_dir()
        elif isinstance(checkpoint, dict):
            parent_dir = TrainableUtil.make_checkpoint_dir(
                self.logdir, index=self.training_iteration)
        else:
            parent_dir = TrainableUtil.find_checkpoint_dir(checkpoint)
        checkpoint_path = TrainableUtil.process_checkpoint(
            checkpoint, parent_dir, state)
        return checkpoint_path
Example #19
    def restore_from_object(self, obj):
        if self.default_checkpoint_dir is not None and os.path.exists(
                self.default_checkpoint_dir):
            shutil.rmtree(self.default_checkpoint_dir)
            logger.debug("Clearing default checkpoint: %s",
                         self.default_checkpoint_dir)

        checkpoint_dir = self.create_default_checkpoint_dir()
        checkpoint_path = TrainableUtil.create_from_pickle(obj, checkpoint_dir)
        self.restore(checkpoint_path)
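create_default_checkpoint_dir is not shown in this collection. From its use here and in Example #18, it plausibly allocates a checkpoint directory under the trial's logdir and caches it on the instance so the cleanup above can find it later. A sketch under that assumption, not the confirmed implementation; the index value is a placeholder:

    def create_default_checkpoint_dir(self):
        # Assumption: allocate a directory via make_checkpoint_dir (as the
        # other branches of Example #18 do) and remember it for later cleanup.
        self.default_checkpoint_dir = TrainableUtil.make_checkpoint_dir(
            self.logdir, index="default")
        return self.default_checkpoint_dir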
Example #20
    def testConvertTempToPermanent(self):
        checkpoint_dir = FuncCheckpointUtil.mk_temp_checkpoint_dir(self.logdir)
        new_checkpoint_dir = FuncCheckpointUtil.create_perm_checkpoint(
            checkpoint_dir, self.logdir, step=4)
        assert new_checkpoint_dir == TrainableUtil.find_checkpoint_dir(
            new_checkpoint_dir)
        assert os.path.exists(new_checkpoint_dir)
        assert not FuncCheckpointUtil.is_temp_checkpoint_dir(
            new_checkpoint_dir)

        tmp_checkpoint_dir = FuncCheckpointUtil.mk_temp_checkpoint_dir(
            self.logdir)
        assert tmp_checkpoint_dir != new_checkpoint_dir
Example #21
    def create_perm_checkpoint(checkpoint_dir, logdir, step):
        """Copies temporary checkpoint to a permanent checkpoint directory."""
        checkpoint_dir = os.path.abspath(checkpoint_dir)
        temporary_marker = os.path.join(checkpoint_dir, TEMP_MARKER)
        assert os.path.exists(temporary_marker), (
            "Should not be calling this method on a permanent checkpoint.")
        os.remove(temporary_marker)
        perm_checkpoint_dir = TrainableUtil.make_checkpoint_dir(
            logdir, index=step, override=True)
        shutil.rmtree(perm_checkpoint_dir)

        shutil.copytree(checkpoint_dir, perm_checkpoint_dir)
        assert not os.path.exists(
            os.path.join(perm_checkpoint_dir, TEMP_MARKER))
        return perm_checkpoint_dir
Example #22
    def delete_checkpoint(self, checkpoint_path):
        """Deletes checkpoint from both local and remote storage.

        Args:
            checkpoint_path (str): Local path to checkpoint.
        """
        try:
            local_dirpath = TrainableUtil.find_checkpoint_dir(checkpoint_path)
        except FileNotFoundError:
            logger.warning(
                "Trial %s: checkpoint path not found during "
                "garbage collection. See issue #6697.", self.trial_id)
        else:
            self.storage_client.delete(self._storage_path(local_dirpath))
        super(DurableTrainable, self).delete_checkpoint(checkpoint_path)
Example #23
    def testPickleCheckpoint(self):
        for i in range(5):
            path = os.path.join(self.checkpoint_dir, str(i))
            with open(path, "w") as f:
                f.write(str(i))

        checkpoint_path = os.path.join(self.checkpoint_dir, "0")

        data_dict = TrainableUtil.pickle_checkpoint(checkpoint_path)
        loaded = pickle.loads(data_dict)

        checkpoint_name = os.path.basename(checkpoint_path)
        self.assertEqual(loaded["checkpoint_name"], checkpoint_name)

        for i in range(5):
            path = os.path.join(self.checkpoint_dir, str(i))
            self.assertEquals(loaded["data"][str(i)], open(path, "rb").read())
Example #24
File: trial.py  Project: zommiommy/ray
    def delete(checkpoint):
        """Requests checkpoint deletion asynchronously.

        Args:
            checkpoint (Checkpoint): Checkpoint to delete.
        """
        if checkpoint.storage == Checkpoint.PERSISTENT and checkpoint.value:
            logger.debug("Trial %s: Deleting checkpoint %s", trial_id,
                         checkpoint.value)
            checkpoint_path = checkpoint.value
            # Delete local copy, if any exists.
            if os.path.exists(checkpoint_path):
                try:
                    checkpoint_dir = TrainableUtil.find_checkpoint_dir(
                        checkpoint_path)
                    shutil.rmtree(checkpoint_dir)
                except FileNotFoundError:
                    logger.warning("Checkpoint dir not found during deletion.")

            # TODO(ujvl): Batch remote deletes.
            runner.delete_checkpoint.remote(checkpoint.value)
Example #25
    def restore(self, trial, checkpoint=None):
        """Restores training state from a given model checkpoint.

        Raises:
            RuntimeError: This error is raised if no runner is found.
            AbortTrialExecution: This error is raised if the trial is
                ineligible for restoration, given the Tune input arguments.
        """
        if checkpoint is None or checkpoint.value is None:
            checkpoint = trial.checkpoint
        if checkpoint.value is None:
            return
        if trial.runner is None:
            raise RuntimeError(
                "Trial {}: Unable to restore - no runner found.".format(trial))
        value = checkpoint.value
        if checkpoint.storage == Checkpoint.MEMORY:
            logger.debug("Trial %s: Attempting restore from object", trial)
            # Note that we don't store the remote since in-memory checkpoints
            # don't guarantee fault tolerance and don't need to be waited on.
            trial.runner.restore_from_object.remote(value)
        else:
            logger.debug("Trial %s: Attempting restore from %s", trial, value)
            if issubclass(trial.get_trainable_cls(), DurableTrainable):
                remote = trial.runner.restore.remote(value)
            elif trial.sync_on_checkpoint:
                # This provides FT backwards compatibility in the
                # case where a DurableTrainable is not provided.
                logger.warning("Trial %s: Reading checkpoint into memory.",
                               trial)
                data_dict = TrainableUtil.pickle_checkpoint(value)
                remote = trial.runner.restore_from_object.remote(data_dict)
            else:
                raise AbortTrialExecution(
                    "Pass in `sync_on_checkpoint=True` for driver-based trial"
                    "restoration. Pass in an `upload_dir` and a Trainable "
                    "extending `DurableTrainable` for remote storage-based "
                    "restoration")
            self._running[remote] = trial
            trial.restoring_from = checkpoint
Example #26
 def save_to_object(self):
     checkpoint_path = self.save()
     obj = TrainableUtil.checkpoint_to_object(checkpoint_path)
     return obj
Example #27
 def make_checkpoint_dir(self, step):
     checkpoint_dir = TrainableUtil.make_checkpoint_dir(self.logdir,
                                                        index=step)
     logger.debug("Making checkpoint dir at %s", checkpoint_dir)
     return checkpoint_dir
Example #28
File: horovod.py  Project: zzmcdc/ray
 def load_checkpoint(self, checkpoint_dir: str):
     checkpoint_obj = TrainableUtil.checkpoint_to_object(checkpoint_dir)
     x_id = ray.put(checkpoint_obj)
     return self.executor.execute(lambda w: w.restore_from_object(x_id))
Example #29
File: horovod.py  Project: zzmcdc/ray
 def save_checkpoint(self, checkpoint_dir: str) -> str:
     # TODO: optimize if colocated
     save_obj = self.executor.execute_single(lambda w: w.save_to_object())
     checkpoint_path = TrainableUtil.create_from_pickle(
         save_obj, checkpoint_dir)
     return checkpoint_path
Example #30
File: torch.py  Project: wangziyuruc/ray
 def load_checkpoint(self, checkpoint_dir):
     checkpoint_obj = TrainableUtil.checkpoint_to_object(checkpoint_dir)
     return ray.get(
         [w.restore_from_object.remote(checkpoint_obj) for w in self.workers])