Exemple #1
0
    def testFindCheckpointDir(self):
        # A path nested several levels below ``self.checkpoint_dir`` must be
        # resolved back to ``self.checkpoint_dir`` itself.
        nested_path = os.path.join(self.checkpoint_dir, "0/my/nested/chkpt")
        os.makedirs(nested_path)
        resolved = TrainableUtil.find_checkpoint_dir(nested_path)
        self.assertEqual(self.checkpoint_dir, resolved)

        # Starting the search from the parent of the resolved checkpoint
        # directory must fail with FileNotFoundError.
        with self.assertRaises(FileNotFoundError):
            TrainableUtil.find_checkpoint_dir(os.path.dirname(resolved))
Exemple #2
0
    def to_air_checkpoint(self) -> Optional[Checkpoint]:
        """Convert ``self.dir_or_data`` into an AIR ``Checkpoint``.

        The payload may be a ``ray.ObjectRef`` (resolved with ``ray.get``),
        a checkpoint path (``str``), serialized bytes (unpacked via
        ``TrainableUtil.create_from_pickle``), or a ``dict``.

        Returns:
            The converted ``Checkpoint``, or ``None`` when ``dir_or_data``
            is falsy.

        Raises:
            RuntimeError: If the (resolved) payload has an unsupported type.
        """
        from ray.tune.trainable.util import TrainableUtil

        data = self.dir_or_data
        if not data:
            return None

        if isinstance(data, ray.ObjectRef):
            data = ray.get(data)

        if isinstance(data, str):
            return Checkpoint.from_directory(
                TrainableUtil.find_checkpoint_dir(data))

        if isinstance(data, bytes):
            with tempfile.TemporaryDirectory() as scratch_dir:
                TrainableUtil.create_from_pickle(data, scratch_dir)
                # Copy the directory contents into a dict-backed checkpoint
                # so nothing references ``scratch_dir`` once it is deleted.
                on_disk = Checkpoint.from_directory(scratch_dir)
                return Checkpoint.from_dict(on_disk.to_dict())

        if isinstance(data, dict):
            return Checkpoint.from_dict(data)

        raise RuntimeError(
            f"Unknown checkpoint data type: {type(data)}")
Exemple #3
0
    def __call__(self, checkpoint: _TrackedCheckpoint):
        """Delete ``checkpoint`` through the trial runner.

        Nothing happens when there is no runner, or when the checkpoint is
        not a persistent one with a non-empty ``dir_or_data``. Otherwise the
        runner's ``delete_checkpoint`` is invoked remotely and waited on with
        ``ray.get``. Any copy still present on the local filesystem is then
        removed.

        Args:
            checkpoint: Checkpoint to delete.
        """
        if not self.runner:
            return

        is_persistent = checkpoint.storage_mode == CheckpointStorage.PERSISTENT
        if not (is_persistent and checkpoint.dir_or_data):
            return

        path = checkpoint.dir_or_data
        logger.debug("Trial %s: Deleting checkpoint %s", self.trial_id, path)

        # TODO(ujvl): Batch remote deletes.
        # Remote delete first; when the runner lives on the driver's node this
        # also removes the local copy.
        ray.get(self.runner.delete_checkpoint.remote(path))

        # Clean up whatever local copy may remain.
        if not os.path.exists(path):
            return
        try:
            shutil.rmtree(TrainableUtil.find_checkpoint_dir(path))
        except FileNotFoundError:
            logger.debug("Local checkpoint dir not found during deletion.")
Exemple #4
0
            def save_checkpoint(self, tmp_checkpoint_dir: str = ""):
                """Save via the parent class, then add the preprocessor.

                ``tmp_checkpoint_dir`` is not used by this override. If a
                ``"preprocessor"`` entry is present in ``self._merged_config``
                it is written into the resolved checkpoint directory.

                Returns:
                    The path returned by the parent ``save_checkpoint()``.
                """
                saved_path = super().save_checkpoint()
                target_dir = TrainableUtil.find_checkpoint_dir(saved_path)

                preprocessor = self._merged_config.get("preprocessor", None)
                if not (target_dir and preprocessor):
                    return saved_path

                save_preprocessor_to_dir(preprocessor, target_dir)
                return saved_path
Exemple #5
0
    def restore(self, trial: Trial) -> None:
        """Restores training state from a given model checkpoint.

        In-memory checkpoints are sent to the trial's runner as an object
        (fire-and-forget). Persistent checkpoints are either restored by
        path on the runner (cloud checkpointing, no driver sync, or no
        local copy on the driver) or read into memory on the driver and
        sent to the runner as bytes. For persistent checkpoints, the
        resulting future is registered as a ``RESTORING_RESULT`` event and
        ``trial.restoring_from`` is set.

        Args:
            trial: The trial to be restored.

        Raises:
            RuntimeError: This error is raised if no runner is found.
        """
        checkpoint = trial.checkpoint
        if checkpoint.dir_or_data is None:
            return
        if trial.runner is None:
            raise RuntimeError(
                "Trial {}: Unable to restore - no runner found.".format(trial))
        checkpoint_dir = checkpoint.dir_or_data
        node_ip = checkpoint.node_ip
        if checkpoint.storage_mode == CheckpointStorage.MEMORY:
            logger.debug("Trial %s: Attempting restore from object", trial)
            # Note that we don't store the remote since in-memory checkpoints
            # don't guarantee fault tolerance and don't need to be waited on.
            with self._change_working_directory(trial):
                trial.runner.restore_from_object.remote(checkpoint_dir)
            return

        logger.debug("Trial %s: Attempting restore from %s", trial,
                     checkpoint_dir)
        if (trial.uses_cloud_checkpointing or not trial.sync_on_checkpoint
                or not os.path.exists(checkpoint_dir)):
            # If using cloud checkpointing, trial will get cp from cloud.
            # If not syncing to driver, assume it has access to the cp
            # on the local fs.
            with self._change_working_directory(trial):
                remote = trial.runner.restore.remote(checkpoint_dir, node_ip)
        else:
            # Here ``trial.sync_on_checkpoint`` is necessarily true and the
            # checkpoint exists on the driver, because the opposite cases are
            # handled above. No separate "abort" branch is needed.
            # This provides FT backwards compatibility in the
            # case where no cloud checkpoints are provided.
            logger.debug("Trial %s: Reading checkpoint into memory", trial)
            checkpoint_path = TrainableUtil.find_checkpoint_dir(checkpoint_dir)
            obj = Checkpoint.from_directory(checkpoint_path).to_bytes()
            with self._change_working_directory(trial):
                remote = trial.runner.restore_from_object.remote(obj)

        self._futures[remote] = (_ExecutorEventType.RESTORING_RESULT, trial)
        trial.restoring_from = checkpoint
Exemple #6
0
    def testConvertTempToPermanent(self):
        # Promote a temporary checkpoint dir to a permanent one at step 4.
        temp_dir = FuncCheckpointUtil.mk_temp_checkpoint_dir(self.logdir)
        perm_dir = FuncCheckpointUtil.create_perm_checkpoint(
            temp_dir, self.logdir, step=4)

        # The permanent dir resolves to itself, exists on disk, and is no
        # longer recognized as a temporary checkpoint dir.
        assert TrainableUtil.find_checkpoint_dir(perm_dir) == perm_dir
        assert os.path.exists(perm_dir)
        assert not FuncCheckpointUtil.is_temp_checkpoint_dir(perm_dir)

        # A new temporary dir must not collide with the permanent one.
        next_temp_dir = FuncCheckpointUtil.mk_temp_checkpoint_dir(self.logdir)
        assert next_temp_dir != perm_dir
Exemple #7
0
def create_checkpoint(preprocessor: Optional[Preprocessor] = None,
                      config: Optional[dict] = None) -> Checkpoint:
    """Save a freshly built ``RLTrainer`` trainable into an in-memory checkpoint.

    Args:
        preprocessor: Optional preprocessor handed to the ``RLTrainer``.
        config: Algorithm config; an empty dict is used when falsy.

    Returns:
        A dict-backed ``Checkpoint`` built after the temporary save
        directory has been removed.
    """
    trainable = RLTrainer(
        algorithm=_DummyAlgo,
        config=config or {},
        preprocessor=preprocessor,
    ).as_trainable()()

    with tempfile.TemporaryDirectory() as tmp_dir:
        saved_file = trainable.save(tmp_dir)
        # Read the checkpoint contents into a dict while the directory
        # still exists.
        contents = Checkpoint.from_directory(
            TrainableUtil.find_checkpoint_dir(saved_file)).to_dict()

    return Checkpoint.from_dict(contents)
Exemple #8
0
    def delete_checkpoint(self, checkpoint_path: Union[str, Checkpoint]):
        """Delete the checkpoint directory resolved from ``checkpoint_path``.

        When cloud checkpointing is enabled, the remote copy is deleted
        first. This goes through ``self.custom_syncer`` if one is set;
        otherwise ``delete_external_checkpoint`` is wrapped in ``retry_fn``
        with ``num_retries=3``. The local directory is then removed if it
        exists.

        Args:
            checkpoint_path: Path to the checkpoint, or a ``Checkpoint``
                whose ``_local_path`` is used when set.
        """
        # Unwrap Checkpoint objects that carry a local path.
        has_local_path = (isinstance(checkpoint_path, Checkpoint)
                          and checkpoint_path._local_path)
        if has_local_path:
            checkpoint_path = checkpoint_path._local_path

        try:
            checkpoint_dir = TrainableUtil.find_checkpoint_dir(checkpoint_path)
        except FileNotFoundError:
            # Nothing to delete locally: the trial may have been rescheduled
            # to another worker.
            logger.debug(
                f"Local checkpoint not found during garbage collection: "
                f"{self.trial_id} - {checkpoint_path}")
            return

        if self.uses_cloud_checkpointing:
            if self.custom_syncer:
                # Keep for backwards compatibility
                self.custom_syncer.delete(self._storage_path(checkpoint_dir))
                self.custom_syncer.wait_or_retry()
            else:
                remote_uri = self._storage_path(checkpoint_dir)
                retry_fn(
                    lambda: delete_external_checkpoint(remote_uri),
                    subprocess.CalledProcessError,
                    num_retries=3,
                    sleep_time=1,
                )

        if os.path.exists(checkpoint_dir):
            shutil.rmtree(checkpoint_dir)
Exemple #9
0
    def to_air_checkpoint(self) -> Optional[Checkpoint]:
        """Convert ``self.dir_or_data`` into an AIR ``Checkpoint``.

        The payload may be a ``ray.ObjectRef`` (resolved with ``ray.get``),
        a checkpoint path (``str``), serialized ``bytes`` or a ``dict``.

        Returns:
            The converted ``Checkpoint``. Returns ``None`` when
            ``dir_or_data`` is falsy, or when a path payload cannot be
            resolved on this node; the latter case logs an error once.

        Raises:
            RuntimeError: If the (resolved) payload has an unsupported type.
        """
        from ray.tune.trainable.util import TrainableUtil

        payload = self.dir_or_data
        if not payload:
            return None

        if isinstance(payload, ray.ObjectRef):
            payload = ray.get(payload)

        if isinstance(payload, str):
            try:
                local_dir = TrainableUtil.find_checkpoint_dir(payload)
            except FileNotFoundError:
                if log_once("checkpoint_not_available"):
                    logger.error(
                        f"The requested checkpoint is not available on this node, "
                        f"most likely because you are using Ray client or disabled "
                        f"checkpoint synchronization. To avoid this, enable checkpoint "
                        f"synchronization to cloud storage by specifying a "
                        f"`SyncConfig`. The checkpoint may be available on a different "
                        f"node - please check this location on worker nodes: "
                        f"{payload}")
                return None
            return Checkpoint.from_directory(local_dir)

        if isinstance(payload, bytes):
            return Checkpoint.from_bytes(payload)

        if isinstance(payload, dict):
            return Checkpoint.from_dict(payload)

        raise RuntimeError(
            f"Unknown checkpoint data type: {type(payload)}")
Exemple #10
0
    def restore(
        self,
        checkpoint_path: Union[str, Checkpoint],
        checkpoint_node_ip: Optional[str] = None,
    ):
        """Restores training state from a given model checkpoint.

        These checkpoints are returned from calls to save().

        Subclasses should override ``load_checkpoint()`` instead to
        restore state.
        This method restores additional metadata saved with the checkpoint.

        `checkpoint_path` should match with the return from ``save()``.

        `checkpoint_path` can be
        `~/ray_results/exp/MyTrainable_abc/
        checkpoint_00000/checkpoint`. Or,
        `~/ray_results/exp/MyTrainable_abc/checkpoint_00000`.

        `self.logdir` should generally be corresponding to `checkpoint_path`,
        for example, `~/ray_results/exp/MyTrainable_abc`.

        `self.remote_checkpoint_dir` in this case, is something like,
        `REMOTE_CHECKPOINT_BUCKET/exp/MyTrainable_abc`

        Args:
            checkpoint_path: Path to restore checkpoint from. If this
                path does not exist on the local node, it will be fetched
                from external (cloud) storage if available, or restored
                from a remote node.
            checkpoint_node_ip: If given, try to restore
                checkpoint from this node if it doesn't exist locally or
                on cloud storage.

        Raises:
            ValueError: If, after the cloud and remote-node attempts,
                ``checkpoint_path`` still does not exist locally.
        """
        # Ensure Checkpoints are converted
        if isinstance(checkpoint_path, Checkpoint):
            return self._restore_from_checkpoint_obj(checkpoint_path)

        # Fall back to fetching from another node only when
        # ``_maybe_load_from_cloud`` returned a falsy value.
        if not self._maybe_load_from_cloud(checkpoint_path) and (
                # If a checkpoint source IP is given
                checkpoint_node_ip
                # And the checkpoint does not currently exist on the local node
                # (check the checkpoint *path*, not the IP string)
                and not os.path.exists(checkpoint_path)
                # And the source IP is different to the current IP
                and checkpoint_node_ip != ray.util.get_node_ip_address()):
            checkpoint = get_checkpoint_from_remote_node(
                checkpoint_path, checkpoint_node_ip)
            if checkpoint:
                checkpoint.to_directory(checkpoint_path)

        if not os.path.exists(checkpoint_path):
            raise ValueError(
                f"Could not recover from checkpoint as it does not exist on local "
                f"disk and was not available on cloud storage or another Ray node. "
                f"Got checkpoint path: {checkpoint_path} and IP {checkpoint_node_ip}"
            )

        checkpoint_dir = TrainableUtil.find_checkpoint_dir(checkpoint_path)
        metadata = TrainableUtil.load_metadata(checkpoint_dir)

        if metadata["saved_as_dict"]:
            # If data was saved as a dict (e.g. from a class trainable),
            # also pass the dict to `load_checkpoint()`.
            checkpoint_dict = Checkpoint.from_directory(
                checkpoint_dir).to_dict()
            # If other files were added to the directory after converting from the
            # original dict (e.g. marker files), clean these up
            checkpoint_dict.pop(_DICT_CHECKPOINT_ADDITIONAL_FILE_KEY, None)
            to_load = checkpoint_dict
        else:
            # Otherwise, pass the relative checkpoint path
            relative_checkpoint_path = metadata["relative_checkpoint_path"]
            to_load = os.path.join(checkpoint_dir, relative_checkpoint_path)

        # Set metadata
        self._experiment_id = metadata["experiment_id"]
        self._iteration = metadata["iteration"]
        self._timesteps_total = metadata["timesteps_total"]
        self._time_total = metadata["time_total"]
        self._episodes_total = metadata["episodes_total"]

        # Actually load checkpoint
        self.load_checkpoint(to_load)

        # Reset the "since restore" counters now that restoration succeeded.
        self._time_since_restore = 0.0
        self._timesteps_since_restore = 0
        self._iterations_since_restore = 0
        self._restored = True

        logger.info("Restored on %s from checkpoint: %s",
                    self.get_current_ip(), checkpoint_dir)
        state = {
            "_iteration": self._iteration,
            "_timesteps_total": self._timesteps_total,
            "_time_total": self._time_total,
            "_episodes_total": self._episodes_total,
        }
        logger.info("Current state after restoring: %s", state)