Code example #1
    def save(self, checkpoint_dir=None):
        """Saves the current model state to a checkpoint.

        Subclasses should override ``save_checkpoint()`` instead to save state.
        This method dumps additional metadata alongside the saved path.

        Args:
            checkpoint_dir (str): Optional dir to place the checkpoint.

        Returns:
            str: Checkpoint path or prefix that may be passed to restore().
        """
        checkpoint_dir = TrainableUtil.make_checkpoint_dir(
            checkpoint_dir or self.logdir, index=self.iteration)
        checkpoint = self.save_checkpoint(checkpoint_dir)
        trainable_state = self.get_state()
        checkpoint_path = TrainableUtil.process_checkpoint(
            checkpoint,
            parent_dir=checkpoint_dir,
            trainable_state=trainable_state)
        return checkpoint_path
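
For orientation: `save()` above only prepares the checkpoint directory and writes metadata; the actual state dump comes from the subclass. Below is a minimal sketch of such a subclass, assuming the standard `tune.Trainable` class API (`setup`/`step`/`save_checkpoint`/`load_checkpoint`); the class name and the `state.pkl` file are illustrative, not part of the snippet above.

import os
import pickle

from ray import tune


class MyTrainable(tune.Trainable):
    def setup(self, config):
        self.iter = 0

    def step(self):
        self.iter += 1
        return {"score": self.iter}

    def save_checkpoint(self, checkpoint_dir):
        # Called by save(); dump state into the directory it prepared and
        # return that directory so restore()/load_checkpoint() receive it.
        with open(os.path.join(checkpoint_dir, "state.pkl"), "wb") as f:
            pickle.dump({"iter": self.iter}, f)
        return checkpoint_dir

    def load_checkpoint(self, checkpoint):
        # Receives whatever save_checkpoint() returned (here, the directory).
        with open(os.path.join(checkpoint, "state.pkl"), "rb") as f:
            self.iter = pickle.load(f)["iter"]
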
Code example #2
    def restore(self, trial, checkpoint=None, block=False) -> None:
        """Restores training state from a given model checkpoint.

        Args:
            trial (Trial): The trial to be restored.
            checkpoint (Checkpoint): The checkpoint to restore from. If None,
                the most recent PERSISTENT checkpoint is used. Defaults to
                None.
            block (bool): Whether or not to block on restore before returning.

        Raises:
            RuntimeError: This error is raised if no runner is found.
            AbortTrialExecution: This error is raised if the trial is
                ineligible for restoration, given the Tune input arguments.
        """
        if checkpoint is None or checkpoint.value is None:
            checkpoint = trial.checkpoint
        if checkpoint.value is None:
            return
        if trial.runner is None:
            raise RuntimeError(
                "Trial {}: Unable to restore - no runner found.".format(trial))
        value = checkpoint.value
        if checkpoint.storage == Checkpoint.MEMORY:
            logger.debug("Trial %s: Attempting restore from object", trial)
            # Note that we don't store the remote since in-memory checkpoints
            # don't guarantee fault tolerance and don't need to be waited on.
            with self._change_working_directory(trial):
                trial.runner.restore_from_object.remote(value)
        else:
            logger.debug("Trial %s: Attempting restore from %s", trial, value)
            if trial.uses_cloud_checkpointing or not trial.sync_on_checkpoint:
                with self._change_working_directory(trial):
                    remote = trial.runner.restore.remote(value)
            elif trial.sync_on_checkpoint:
                # This provides FT backwards compatibility in the
                # case where no cloud checkpoints are provided.
                logger.debug("Trial %s: Reading checkpoint into memory", trial)
                obj = TrainableUtil.checkpoint_to_object(value)
                with self._change_working_directory(trial):
                    remote = trial.runner.restore_from_object.remote(obj)
            else:
                raise AbortTrialExecution(
                    "Pass in `sync_on_checkpoint=True` for driver-based trial"
                    "restoration. Pass in an `upload_dir` for remote "
                    "storage-based restoration")

            if block:
                ray.get(remote)
            else:
                self._running[remote] = trial
                trial.restoring_from = checkpoint
Code example #3
    def restore(self, trial: Trial) -> None:
        """Restores training state from a given model checkpoint.

        Args:
            trial: The trial to be restored.

        Raises:
            RuntimeError: This error is raised if no runner is found.
            AbortTrialExecution: This error is raised if the trial is
                ineligible for restoration, given the Tune input arguments.
        """
        checkpoint = trial.checkpoint
        if checkpoint.dir_or_data is None:
            return
        if trial.runner is None:
            raise RuntimeError(
                "Trial {}: Unable to restore - no runner found.".format(trial))
        checkpoint_dir = checkpoint.dir_or_data
        node_ip = checkpoint.node_ip
        if checkpoint.storage_mode == CheckpointStorage.MEMORY:
            logger.debug("Trial %s: Attempting restore from object", trial)
            # Note that we don't store the remote since in-memory checkpoints
            # don't guarantee fault tolerance and don't need to be waited on.
            with self._change_working_directory(trial):
                trial.runner.restore_from_object.remote(checkpoint_dir)
        else:
            logger.debug("Trial %s: Attempting restore from %s", trial,
                         checkpoint_dir)
            if (trial.uses_cloud_checkpointing or not trial.sync_on_checkpoint
                    or not os.path.exists(checkpoint_dir)):
                # If using cloud checkpointing, trial will get cp from cloud.
                # If not syncing to driver, assume it has access to the cp
                # on the local fs.
                with self._change_working_directory(trial):
                    remote = trial.runner.restore.remote(
                        checkpoint_dir, node_ip)
            elif trial.sync_on_checkpoint:
                # This provides FT backwards compatibility in the
                # case where no cloud checkpoints are provided.
                logger.debug("Trial %s: Reading checkpoint into memory", trial)
                obj = TrainableUtil.checkpoint_to_object(checkpoint_dir)
                with self._change_working_directory(trial):
                    remote = trial.runner.restore_from_object.remote(obj)
            else:
                raise _AbortTrialExecution(
                    "Pass in `sync_on_checkpoint=True` for driver-based trial"
                    "restoration. Pass in an `upload_dir` for remote "
                    "storage-based restoration")

            self._futures[remote] = (_ExecutorEventType.RESTORING_RESULT,
                                     trial)
            trial.restoring_from = checkpoint
Code example #4
    def _trial_to_result(self, trial: Trial) -> Result:
        if trial.checkpoint.value:
            checkpoint_dir = TrainableUtil.find_checkpoint_dir(trial.checkpoint.value)
            checkpoint = Checkpoint.from_directory(checkpoint_dir)
        else:
            checkpoint = None

        result = Result(
            checkpoint=checkpoint,
            metrics=trial.last_result.copy(),
            error=self._populate_exception(trial),
        )
        return result
Code example #5
        def write_checkpoint(trial: Trial, index: int):
            checkpoint_dir = TrainableUtil.make_checkpoint_dir(trial.logdir,
                                                               index=index)
            result = {"training_iteration": index}
            with open(os.path.join(checkpoint_dir, "cp.json"), "w") as f:
                json.dump(result, f)

            tune_cp = _TuneCheckpoint(_TuneCheckpoint.PERSISTENT,
                                      checkpoint_dir, result)
            trial.saving_to = tune_cp
            trial.on_checkpoint(tune_cp)

            return checkpoint_dir
Code example #6
    def testConvertTempToPermanent(self):
        checkpoint_dir = FuncCheckpointUtil.mk_temp_checkpoint_dir(self.logdir)
        new_checkpoint_dir = FuncCheckpointUtil.create_perm_checkpoint(
            checkpoint_dir, self.logdir, step=4)
        assert new_checkpoint_dir == TrainableUtil.find_checkpoint_dir(
            new_checkpoint_dir)
        assert os.path.exists(new_checkpoint_dir)
        assert not FuncCheckpointUtil.is_temp_checkpoint_dir(
            new_checkpoint_dir)

        tmp_checkpoint_dir = FuncCheckpointUtil.mk_temp_checkpoint_dir(
            self.logdir)
        assert tmp_checkpoint_dir != new_checkpoint_dir
Code example #7
    def save_to_object(self):
        """Saves the current model state to a Python object.

        It also saves to disk but does not return the checkpoint path.

        Returns:
            Object holding checkpoint data.
        """
        tmpdir = tempfile.mkdtemp("save_to_object", dir=self.logdir)
        checkpoint_path = self.save(tmpdir)
        # Save all files in subtree and delete the tmpdir.
        obj = TrainableUtil.checkpoint_to_object(checkpoint_path)
        shutil.rmtree(tmpdir)
        return obj
Code example #8
    def delete_checkpoint(self, checkpoint_path):
        """Deletes local copy of checkpoint.

        Args:
            checkpoint_path (str): Path to checkpoint.
        """
        try:
            checkpoint_dir = TrainableUtil.find_checkpoint_dir(checkpoint_path)
        except FileNotFoundError:
            # The checkpoint won't exist locally if the
            # trial was rescheduled to another worker.
            logger.debug("Checkpoint not found during garbage collection.")
            return
        if os.path.exists(checkpoint_dir):
            shutil.rmtree(checkpoint_dir)
Code example #9
def create_checkpoint(preprocessor: Optional[Preprocessor] = None,
                      config: Optional[dict] = None) -> Checkpoint:
    rl_trainer = RLTrainer(
        algorithm=_DummyAlgo,
        config=config or {},
        preprocessor=preprocessor,
    )
    rl_trainable_cls = rl_trainer.as_trainable()
    rl_trainable = rl_trainable_cls()

    with tempfile.TemporaryDirectory() as checkpoint_dir:
        checkpoint_file = rl_trainable.save(checkpoint_dir)
        checkpoint_path = TrainableUtil.find_checkpoint_dir(checkpoint_file)
        checkpoint_data = Checkpoint.from_directory(checkpoint_path).to_dict()

    return Checkpoint.from_dict(checkpoint_data)
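
The dict round trip above is what keeps the checkpoint usable after the `TemporaryDirectory` is cleaned up. A hedged consumer sketch (the config values and the restore directory handling are illustrative only):

import os
import tempfile

checkpoint = create_checkpoint(config={"lr": 1e-3})

# The checkpoint data now lives in memory, so it can be materialized back to
# disk whenever a directory-based consumer needs it.
with tempfile.TemporaryDirectory() as restore_dir:
    local_path = checkpoint.to_directory(restore_dir)
    print(sorted(os.listdir(local_path)))
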
Code example #10
        def write_checkpoint(trial: Trial, index: int):
            checkpoint_dir = TrainableUtil.make_checkpoint_dir(
                trial.logdir, index=index
            )
            result = {"training_iteration": index}
            with open(os.path.join(checkpoint_dir, "cp.json"), "w") as f:
                json.dump(result, f)

            tune_cp = _TrackedCheckpoint(
                dir_or_data=checkpoint_dir,
                storage_mode=CheckpointStorage.PERSISTENT,
                metrics=result,
            )
            trial.saving_to = tune_cp

            return checkpoint_dir
Code example #11
File: test_trainable_util.py Project: qyou/ray
    def testPickleCheckpoint(self):
        for i in range(5):
            path = os.path.join(self.checkpoint_dir, str(i))
            with open(path, "w") as f:
                f.write(str(i))

        checkpoint_path = os.path.join(self.checkpoint_dir, "0")

        data_dict = TrainableUtil.pickle_checkpoint(checkpoint_path)
        loaded = cloudpickle.loads(data_dict)

        checkpoint_name = os.path.basename(checkpoint_path)
        self.assertEqual(loaded["checkpoint_name"], checkpoint_name)

        for i in range(5):
            path = os.path.join(self.checkpoint_dir, str(i))
            self.assertEqual(loaded["data"][str(i)], open(path, "rb").read())
Code example #12
    def delete(checkpoint):
        """Requests checkpoint deletion asynchronously.

        Args:
            checkpoint (Checkpoint): Checkpoint to delete.
        """
        if checkpoint.storage == Checkpoint.PERSISTENT and checkpoint.value:
            logger.debug("Trial %s: Deleting checkpoint %s", trial_id,
                         checkpoint.value)
            checkpoint_path = checkpoint.value
            # Delete local copy, if any exists.
            if os.path.exists(checkpoint_path):
                try:
                    checkpoint_dir = TrainableUtil.find_checkpoint_dir(
                        checkpoint_path)
                    shutil.rmtree(checkpoint_dir)
                except FileNotFoundError:
                    logger.warning("Checkpoint dir not found during deletion.")

            # TODO(ujvl): Batch remote deletes.
            runner.delete_checkpoint.remote(checkpoint.value)
Code example #13
    def delete_checkpoint(self, checkpoint_path):
        """Deletes local copy of checkpoint.

        Args:
            checkpoint_path (str): Path to checkpoint.
        """
        try:
            checkpoint_dir = TrainableUtil.find_checkpoint_dir(checkpoint_path)
        except FileNotFoundError:
            # The checkpoint won't exist locally if the
            # trial was rescheduled to another worker.
            logger.debug(
                f"Local checkpoint not found during garbage collection: "
                f"{self.trial_id} - {checkpoint_path}")
            return
        else:
            if self.uses_cloud_checkpointing:
                self.storage_client.delete(self._storage_path(checkpoint_dir))

        if os.path.exists(checkpoint_dir):
            shutil.rmtree(checkpoint_dir)
Code example #14
    def delete_checkpoint(self, checkpoint_path: str):
        """Deletes local copy of checkpoint.

        Args:
            checkpoint_path: Path to checkpoint.
        """
        # Ensure TrialCheckpoints are converted
        if isinstance(checkpoint_path, TrialCheckpoint):
            checkpoint_path = checkpoint_path.local_path

        try:
            checkpoint_dir = TrainableUtil.find_checkpoint_dir(checkpoint_path)
        except FileNotFoundError:
            # The checkpoint won't exist locally if the
            # trial was rescheduled to another worker.
            logger.debug(
                f"Local checkpoint not found during garbage collection: "
                f"{self.trial_id} - {checkpoint_path}"
            )
            return
        else:
            if self.uses_cloud_checkpointing:
                if self.custom_syncer:
                    # Keep for backwards compatibility
                    self.custom_syncer.delete(self._storage_path(checkpoint_dir))
                    self.custom_syncer.wait_or_retry()
                else:
                    checkpoint_uri = self._storage_path(checkpoint_dir)
                    retry_fn(
                        lambda: delete_external_checkpoint(checkpoint_uri),
                        subprocess.CalledProcessError,
                        num_retries=3,
                        sleep_time=1,
                    )

        if os.path.exists(checkpoint_dir):
            shutil.rmtree(checkpoint_dir)
Code example #15
File: experiment_analysis.py Project: alipay/ray
    def get_trial_checkpoints_paths(
            self,
            trial: Trial,
            metric: Optional[str] = None) -> List[Tuple[str, Number]]:
        """Gets paths and metrics of all persistent checkpoints of a trial.

        Args:
            trial: The log directory of a trial, or a trial instance.
            metric: key for trial info to return, e.g. "mean_accuracy".
                "training_iteration" is used by default if no value was
                passed to ``self.default_metric``.

        Returns:
            List of [path, metric] for all persistent checkpoints of the trial.
        """
        metric = metric or self.default_metric or TRAINING_ITERATION

        if isinstance(trial, str):
            trial_dir = os.path.expanduser(trial)
            # Get checkpoints from logdir.
            chkpt_df = TrainableUtil.get_checkpoints_paths(trial_dir)

            # Join with trial dataframe to get metrics.
            trial_df = self.trial_dataframes[trial_dir]
            path_metric_df = chkpt_df.merge(trial_df,
                                            on="training_iteration",
                                            how="inner")
            return path_metric_df[["chkpt_path", metric]].values.tolist()
        elif isinstance(trial, Trial):
            checkpoints = trial.checkpoint_manager.best_checkpoints()
            # Support metrics given as paths, e.g.
            # "info/learner/default_policy/policy_loss".
            return [(c.value, unflattened_lookup(metric, c.result))
                    for c in checkpoints]
        else:
            raise ValueError("trial should be a string or a Trial instance.")
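
A hedged call-site sketch for the method above, assuming an `ExperimentAnalysis` loaded from an experiment directory; the path and metric name are placeholders:

from ray.tune import ExperimentAnalysis

analysis = ExperimentAnalysis("~/ray_results/my_experiment")
best_trial = analysis.get_best_trial(metric="mean_accuracy", mode="max")
for path, value in analysis.get_trial_checkpoints_paths(
        best_trial, metric="mean_accuracy"):
    print(path, value)
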
Code example #16
File: horovod.py Project: patrickstuedi/ray
 def save_checkpoint(self, checkpoint_dir: str) -> str:
     # TODO: optimize if colocated
     save_obj = self.executor.execute_single(lambda w: w.save_to_object())
     checkpoint_path = TrainableUtil.create_from_pickle(
         save_obj, checkpoint_dir)
     return checkpoint_path
Code example #17
 def load_checkpoint(self, checkpoint_dir: str):
     checkpoint_obj = TrainableUtil.checkpoint_to_object(checkpoint_dir)
     return ray.get([
         w.restore_from_object.remote(checkpoint_obj) for w in self.workers
     ])
Code example #18
File: trainable.py Project: tchordia/ray
    def restore(self, checkpoint_path: str, checkpoint_node_ip: Optional[str] = None):
        """Restores training state from a given model checkpoint.

        These checkpoints are returned from calls to save().

        Subclasses should override ``load_checkpoint()`` instead to
        restore state.
        This method restores additional metadata saved with the checkpoint.

        `checkpoint_path` should match with the return from ``save()``.

        `checkpoint_path` can be
        `~/ray_results/exp/MyTrainable_abc/checkpoint_00000/checkpoint` or
        `~/ray_results/exp/MyTrainable_abc/checkpoint_00000`.

        `self.logdir` should generally correspond to `checkpoint_path`,
        for example, `~/ray_results/exp/MyTrainable_abc`.

        `self.remote_checkpoint_dir`, in this case, is something like
        `REMOTE_CHECKPOINT_BUCKET/exp/MyTrainable_abc`.

        Args:
            checkpoint_path: Path to restore checkpoint from. If this
                path does not exist on the local node, it will be fetched
                from external (cloud) storage if available, or restored
                from a remote node.
            checkpoint_node_ip: If given, try to restore
                checkpoint from this node if it doesn't exist locally or
                on cloud storage.

        """
        # Ensure TrialCheckpoints are converted
        if isinstance(checkpoint_path, TrialCheckpoint):
            checkpoint_path = checkpoint_path.local_path

        if self.uses_cloud_checkpointing:
            rel_checkpoint_dir = TrainableUtil.find_rel_checkpoint_dir(
                self.logdir, checkpoint_path
            )
            external_uri = os.path.join(self.remote_checkpoint_dir, rel_checkpoint_dir)
            local_dir = os.path.join(self.logdir, rel_checkpoint_dir)

            if self.storage_client:
                # Only keep for backwards compatibility
                self.storage_client.sync_down(external_uri, local_dir)
                self.storage_client.wait_or_retry()
            else:
                checkpoint = Checkpoint.from_uri(external_uri)
                retry_fn(
                    lambda: checkpoint.to_directory(local_dir),
                    subprocess.CalledProcessError,
                    num_retries=3,
                    sleep_time=1,
                )
        elif (
            # If a checkpoint source IP is given
            checkpoint_node_ip
            # And the checkpoint does not currently exist on the local node
            and not os.path.exists(checkpoint_path)
            # And the source IP is different to the current IP
            and checkpoint_node_ip != ray.util.get_node_ip_address()
        ):
            checkpoint = get_checkpoint_from_remote_node(
                checkpoint_path, checkpoint_node_ip
            )
            if checkpoint:
                checkpoint.to_directory(checkpoint_path)

        with open(checkpoint_path + ".tune_metadata", "rb") as f:
            metadata = pickle.load(f)
        self._experiment_id = metadata["experiment_id"]
        self._iteration = metadata["iteration"]
        self._timesteps_total = metadata["timesteps_total"]
        self._time_total = metadata["time_total"]
        self._episodes_total = metadata["episodes_total"]
        saved_as_dict = metadata["saved_as_dict"]
        if saved_as_dict:
            with open(checkpoint_path, "rb") as loaded_state:
                checkpoint_dict = pickle.load(loaded_state)
            checkpoint_dict.update(tune_checkpoint_path=checkpoint_path)
            self.load_checkpoint(checkpoint_dict)
        else:
            self.load_checkpoint(checkpoint_path)
        self._time_since_restore = 0.0
        self._timesteps_since_restore = 0
        self._iterations_since_restore = 0
        self._restored = True
        logger.info(
            "Restored on %s from checkpoint: %s", self.get_current_ip(), checkpoint_path
        )
        state = {
            "_iteration": self._iteration,
            "_timesteps_total": self._timesteps_total,
            "_time_total": self._time_total,
            "_episodes_total": self._episodes_total,
        }
        logger.info("Current state after restoring: %s", state)
Code example #19
 def save_checkpoint(self, checkpoint_dir: str) -> str:
     # TODO: optimize if colocated
     save_obj = ray.get(self.workers[0].save_to_object.remote())
     checkpoint_path = TrainableUtil.create_from_pickle(
         save_obj, checkpoint_dir)
     return checkpoint_path
Code example #20
def test_find_rel_checkpoint_dir(checkpoint_path, logdir):
    assert (TrainableUtil.find_rel_checkpoint_dir(
        logdir, checkpoint_path) == "checkpoint0/")
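
The `checkpoint_path` and `logdir` arguments above come from pytest fixtures that the excerpt does not show. A hedged, self-contained variant with illustrative values that would satisfy the same assertion (these are not the project's actual fixtures; `TrainableUtil` is assumed to be imported as in the surrounding test module):

import pytest


@pytest.mark.parametrize(
    "checkpoint_path,logdir",
    [("/tmp/exp/trial/checkpoint0/checkpoint", "/tmp/exp/trial")],
)
def test_find_rel_checkpoint_dir_sketch(checkpoint_path, logdir):
    # The helper strips the logdir prefix and keeps only the first remaining
    # path component, i.e. the "checkpoint0/" directory.
    assert (TrainableUtil.find_rel_checkpoint_dir(
        logdir, checkpoint_path) == "checkpoint0/")
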
Code example #21
File: horovod.py Project: patrickstuedi/ray
 def load_checkpoint(self, checkpoint_dir: str):
     checkpoint_obj = TrainableUtil.checkpoint_to_object(checkpoint_dir)
     x_id = ray.put(checkpoint_obj)
     return self.executor.execute(
         lambda w: w.restore_from_object(ray.get(x_id)))
Code example #22
File: trainable.py Project: jon-chuang/ray
    def restore(self, checkpoint_path):
        """Restores training state from a given model checkpoint.

        These checkpoints are returned from calls to save().

        Subclasses should override ``load_checkpoint()`` instead to
        restore state.
        This method restores additional metadata saved with the checkpoint.

        `checkpoint_path` should match with the return from ``save()``.

        `checkpoint_path` can be
        `~/ray_results/exp/MyTrainable_abc/checkpoint_00000/checkpoint` or
        `~/ray_results/exp/MyTrainable_abc/checkpoint_00000`.

        `self.logdir` should generally correspond to `checkpoint_path`,
        for example, `~/ray_results/exp/MyTrainable_abc`.

        `self.remote_checkpoint_dir`, in this case, is something like
        `REMOTE_CHECKPOINT_BUCKET/exp/MyTrainable_abc`.
        """
        if self.uses_cloud_checkpointing:
            rel_checkpoint_dir = TrainableUtil.find_rel_checkpoint_dir(
                self.logdir, checkpoint_path)
            self.storage_client.sync_down(
                os.path.join(self.remote_checkpoint_dir, rel_checkpoint_dir),
                os.path.join(self.logdir, rel_checkpoint_dir),
            )
            self.storage_client.wait_or_retry()

        # Ensure TrialCheckpoints are converted
        if isinstance(checkpoint_path, TrialCheckpoint):
            checkpoint_path = checkpoint_path.local_path

        with open(checkpoint_path + ".tune_metadata", "rb") as f:
            metadata = pickle.load(f)
        self._experiment_id = metadata["experiment_id"]
        self._iteration = metadata["iteration"]
        self._timesteps_total = metadata["timesteps_total"]
        self._time_total = metadata["time_total"]
        self._episodes_total = metadata["episodes_total"]
        saved_as_dict = metadata["saved_as_dict"]
        if saved_as_dict:
            with open(checkpoint_path, "rb") as loaded_state:
                checkpoint_dict = pickle.load(loaded_state)
            checkpoint_dict.update(tune_checkpoint_path=checkpoint_path)
            self.load_checkpoint(checkpoint_dict)
        else:
            self.load_checkpoint(checkpoint_path)
        self._time_since_restore = 0.0
        self._timesteps_since_restore = 0
        self._iterations_since_restore = 0
        self._restored = True
        logger.info("Restored on %s from checkpoint: %s",
                    self.get_current_ip(), checkpoint_path)
        state = {
            "_iteration": self._iteration,
            "_timesteps_total": self._timesteps_total,
            "_time_total": self._time_total,
            "_episodes_total": self._episodes_total,
        }
        logger.info("Current state after restoring: %s", state)
Code example #23
File: test_trainable_util.py Project: qyou/ray
 def setUp(self):
     self.checkpoint_dir = os.path.join(
         ray._private.utils.get_user_temp_dir(), "tune", "MyTrainable123")
     self.checkpoint_dir = TrainableUtil.make_checkpoint_dir(
         self.checkpoint_dir, "0")