Example #1
    def to_air_checkpoint(self) -> Optional[Checkpoint]:
        from ray.tune.trainable.util import TrainableUtil

        checkpoint_data = self.dir_or_data

        if not checkpoint_data:
            return None

        if isinstance(checkpoint_data, ray.ObjectRef):
            checkpoint_data = ray.get(checkpoint_data)

        if isinstance(checkpoint_data, str):
            checkpoint_dir = TrainableUtil.find_checkpoint_dir(checkpoint_data)
            checkpoint = Checkpoint.from_directory(checkpoint_dir)
        elif isinstance(checkpoint_data, bytes):
            with tempfile.TemporaryDirectory() as tmpdir:
                TrainableUtil.create_from_pickle(checkpoint_data, tmpdir)
                # Double wrap in checkpoint so we hold the data in memory and
                # can remove the temp directory
                checkpoint = Checkpoint.from_dict(
                    Checkpoint.from_directory(tmpdir).to_dict())
        elif isinstance(checkpoint_data, dict):
            checkpoint = Checkpoint.from_dict(checkpoint_data)
        else:
            raise RuntimeError(
                f"Unknown checkpoint data type: {type(checkpoint_data)}")

        return checkpoint
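
A minimal sketch of the dict/directory round trip this example relies on, assuming the AIR `Checkpoint` class from `ray.air.checkpoint` (the payload and temp directory are illustrative):

import tempfile

from ray.air.checkpoint import Checkpoint

with tempfile.TemporaryDirectory() as tmpdir:
    # Write a dict-based checkpoint to disk, then read it back.
    Checkpoint.from_dict({"weights": [1, 2, 3]}).to_directory(tmpdir)
    # "Double wrap": converting back to a dict detaches the checkpoint from
    # the temp directory, so the directory can be deleted safely.
    in_memory = Checkpoint.from_dict(Checkpoint.from_directory(tmpdir).to_dict())

print(in_memory.to_dict()["weights"])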
Example #2
    def save(self, checkpoint_dir: Optional[str] = None) -> str:
        """Saves the current model state to a checkpoint.

        Subclasses should override ``save_checkpoint()`` instead to save state.
        This method dumps additional metadata alongside the saved path.

        If a remote checkpoint dir is given, this will also sync up to remote
        storage.

        Args:
            checkpoint_dir: Optional dir to place the checkpoint.

        Returns:
            str: path that points to xxx.pkl file.

        Note the return path should match up with what is expected of
        `restore()`.
        """
        checkpoint_dir = TrainableUtil.make_checkpoint_dir(
            checkpoint_dir or self.logdir, index=self.iteration)
        checkpoint_dict_or_path = self.save_checkpoint(checkpoint_dir)
        trainable_state = self.get_state()
        checkpoint_path = TrainableUtil.process_checkpoint(
            checkpoint_dict_or_path,
            parent_dir=checkpoint_dir,
            trainable_state=trainable_state,
        )

        # Maybe sync to cloud
        self._maybe_save_to_cloud(checkpoint_dir)

        return checkpoint_path
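
A hedged sketch of how `save()` pairs with `restore()` for a class Trainable; the class name, metric, and dict payload are illustrative, not taken from the source:

from ray import tune

class MyTrainable(tune.Trainable):
    def step(self):
        return {"score": self.iteration}

    def save_checkpoint(self, checkpoint_dir):
        # Returning a dict lets Tune handle serialization into checkpoint_dir.
        return {"iteration": self.iteration}

    def load_checkpoint(self, checkpoint):
        self._restored_iteration = checkpoint["iteration"]

trainable = MyTrainable()
trainable.train()
path = trainable.save()   # creates a checkpoint_<index> dir under the logdir
trainable.restore(path)   # expects exactly what save() returned

The same round trip can also go through an in-memory object instead of a path via `save_to_object()` and `restore_from_object()` (see Examples #9 and #12).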
Example #3
    def testFindCheckpointDir(self):
        checkpoint_path = os.path.join(self.checkpoint_dir, "0/my/nested/chkpt")
        os.makedirs(checkpoint_path)
        found_dir = TrainableUtil.find_checkpoint_dir(checkpoint_path)
        self.assertEqual(self.checkpoint_dir, found_dir)

        with self.assertRaises(FileNotFoundError):
            parent = os.path.dirname(found_dir)
            TrainableUtil.find_checkpoint_dir(parent)
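
The test above depends on the marker file that `make_checkpoint_dir` drops; a minimal sketch of that pairing outside a test class, using the same import path as Example #1 (the nested path is illustrative):

import os
import tempfile

from ray.tune.trainable.util import TrainableUtil

base = tempfile.mkdtemp()
# make_checkpoint_dir() creates a checkpoint_<index> subdirectory under `base`
# and drops a checkpoint marker file inside it.
checkpoint_dir = TrainableUtil.make_checkpoint_dir(base, index=0)
nested = os.path.join(checkpoint_dir, "my", "nested", "chkpt")
os.makedirs(nested)
# find_checkpoint_dir() walks up from any nested path to the marked directory.
assert TrainableUtil.find_checkpoint_dir(nested) == checkpoint_dir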
Example #4
    def __call__(self, checkpoint: _TrackedCheckpoint):
        """Requests checkpoint deletion asynchronously.

        Args:
            checkpoint: Checkpoint to delete.
        """
        if not self.runner:
            return

        if (checkpoint.storage_mode == CheckpointStorage.PERSISTENT
                and checkpoint.dir_or_data):
            checkpoint_path = checkpoint.dir_or_data

            logger.debug("Trial %s: Deleting checkpoint %s", self.trial_id,
                         checkpoint_path)

            # TODO(ujvl): Batch remote deletes.
            # We first delete the remote checkpoint. If it is on the same
            # node as the driver, it will also remove the local copy.
            ray.get(self.runner.delete_checkpoint.remote(checkpoint_path))

            # Delete local copy, if any exists.
            if os.path.exists(checkpoint_path):
                try:
                    checkpoint_dir = TrainableUtil.find_checkpoint_dir(
                        checkpoint_path)
                    shutil.rmtree(checkpoint_dir)
                except FileNotFoundError:
                    logger.debug(
                        "Local checkpoint dir not found during deletion.")
Example #5
    def _create_checkpoint_dir(self,
                               checkpoint_dir: Optional[str] = None
                               ) -> Optional[str]:
        # Create checkpoint_xxxxx directory and drop checkpoint marker
        checkpoint_dir = TrainableUtil.make_checkpoint_dir(
            checkpoint_dir or self.logdir, index=self.iteration)
        return checkpoint_dir
Example #6
    def _maybe_load_from_cloud(self, checkpoint_path: str) -> bool:
        if not self.uses_cloud_checkpointing:
            return False

        rel_checkpoint_dir = TrainableUtil.find_rel_checkpoint_dir(
            self.logdir, checkpoint_path)
        external_uri = os.path.join(self.remote_checkpoint_dir,
                                    rel_checkpoint_dir)
        local_dir = os.path.join(self.logdir, rel_checkpoint_dir)

        if self.custom_syncer:
            # Only keep for backwards compatibility
            self.custom_syncer.sync_down(remote_dir=external_uri,
                                         local_dir=local_dir)
            self.custom_syncer.wait_or_retry()
            return True

        checkpoint = Checkpoint.from_uri(external_uri)
        retry_fn(
            lambda: checkpoint.to_directory(local_dir),
            subprocess.CalledProcessError,
            num_retries=3,
            sleep_time=1,
        )

        return True
Example #7
    def setUp(self):
        self.checkpoint_dir = os.path.join(
            ray._private.utils.get_user_temp_dir(), "tune", "MyTrainable123"
        )
        self.checkpoint_dir = TrainableUtil.make_checkpoint_dir(
            self.checkpoint_dir, "0"
        )
Example #8
            def save_checkpoint(self, tmp_checkpoint_dir: str = ""):
                checkpoint_path = super().save_checkpoint()
                parent_dir = TrainableUtil.find_checkpoint_dir(checkpoint_path)

                preprocessor = self._merged_config.get("preprocessor", None)
                if parent_dir and preprocessor:
                    save_preprocessor_to_dir(preprocessor, parent_dir)
                return checkpoint_path
Example #9
    def restore_from_object(self, obj):
        """Restores training state from a checkpoint object.

        These checkpoints are returned from calls to save_to_object().
        """
        tmpdir = tempfile.mkdtemp("restore_from_object", dir=self.logdir)
        checkpoint_path = TrainableUtil.create_from_pickle(obj, tmpdir)
        self.restore(checkpoint_path)
        shutil.rmtree(tmpdir)
Example #10
    def restore(self, trial: Trial) -> None:
        """Restores training state from a given model checkpoint.

        Args:
            trial: The trial to be restored.

        Raises:
            RuntimeError: This error is raised if no runner is found.
            AbortTrialExecution: This error is raised if the trial is
                ineligible for restoration, given the Tune input arguments.
        """
        checkpoint = trial.checkpoint
        if checkpoint.dir_or_data is None:
            return
        if trial.runner is None:
            raise RuntimeError(
                "Trial {}: Unable to restore - no runner found.".format(trial))
        checkpoint_dir = checkpoint.dir_or_data
        node_ip = checkpoint.node_ip
        if checkpoint.storage_mode == CheckpointStorage.MEMORY:
            logger.debug("Trial %s: Attempting restore from object", trial)
            # Note that we don't store the remote since in-memory checkpoints
            # don't guarantee fault tolerance and don't need to be waited on.
            with self._change_working_directory(trial):
                trial.runner.restore_from_object.remote(checkpoint_dir)
        else:
            logger.debug("Trial %s: Attempting restore from %s", trial,
                         checkpoint_dir)
            if (trial.uses_cloud_checkpointing or not trial.sync_on_checkpoint
                    or not os.path.exists(checkpoint_dir)):
                # If using cloud checkpointing, trial will get cp from cloud.
                # If not syncing to driver, assume it has access to the cp
                # on the local fs.
                with self._change_working_directory(trial):
                    remote = trial.runner.restore.remote(
                        checkpoint_dir, node_ip)
            elif trial.sync_on_checkpoint:
                # This provides FT backwards compatibility in the
                # case where no cloud checkpoints are provided.
                logger.debug("Trial %s: Reading checkpoint into memory", trial)
                checkpoint_path = TrainableUtil.find_checkpoint_dir(
                    checkpoint_dir)
                obj = Checkpoint.from_directory(checkpoint_path).to_bytes()
                with self._change_working_directory(trial):
                    remote = trial.runner.restore_from_object.remote(obj)
            else:
                raise _AbortTrialExecution(
                    "Pass in `sync_on_checkpoint=True` for driver-based trial"
                    "restoration. Pass in an `upload_dir` for remote "
                    "storage-based restoration")

            self._futures[remote] = (_ExecutorEventType.RESTORING_RESULT,
                                     trial)
            trial.restoring_from = checkpoint
Example #11
    def testConvertTempToPermanent(self):
        checkpoint_dir = FuncCheckpointUtil.mk_temp_checkpoint_dir(self.logdir)
        new_checkpoint_dir = FuncCheckpointUtil.create_perm_checkpoint(
            checkpoint_dir, self.logdir, step=4)
        assert new_checkpoint_dir == TrainableUtil.find_checkpoint_dir(
            new_checkpoint_dir)
        assert os.path.exists(new_checkpoint_dir)
        assert not FuncCheckpointUtil.is_temp_checkpoint_dir(
            new_checkpoint_dir)

        tmp_checkpoint_dir = FuncCheckpointUtil.mk_temp_checkpoint_dir(
            self.logdir)
        assert tmp_checkpoint_dir != new_checkpoint_dir
Example #12
    def save_to_object(self):
        """Saves the current model state to a Python object.

        It also saves to disk but does not return the checkpoint path.

        Returns:
            Object holding checkpoint data.
        """
        tmpdir = tempfile.mkdtemp("save_to_object", dir=self.logdir)
        checkpoint_path = self.save(tmpdir)
        # Save all files in subtree and delete the tmpdir.
        obj = TrainableUtil.checkpoint_to_object(checkpoint_path)
        shutil.rmtree(tmpdir)
        return obj
Example #13
        def write_checkpoint(trial: Trial, index: int):
            checkpoint_dir = TrainableUtil.make_checkpoint_dir(trial.logdir,
                                                               index=index)
            result = {"training_iteration": index}
            with open(os.path.join(checkpoint_dir, "cp.json"), "w") as f:
                json.dump(result, f)

            tune_cp = _TrackedCheckpoint(
                dir_or_data=checkpoint_dir,
                storage_mode=CheckpointStorage.PERSISTENT,
                metrics=result,
            )
            trial.saving_to = tune_cp

            return checkpoint_dir
Example #14
def create_checkpoint(preprocessor: Optional[Preprocessor] = None,
                      config: Optional[dict] = None) -> Checkpoint:
    rl_trainer = RLTrainer(
        algorithm=_DummyAlgo,
        config=config or {},
        preprocessor=preprocessor,
    )
    rl_trainable_cls = rl_trainer.as_trainable()
    rl_trainable = rl_trainable_cls()

    with tempfile.TemporaryDirectory() as checkpoint_dir:
        checkpoint_file = rl_trainable.save(checkpoint_dir)
        checkpoint_path = TrainableUtil.find_checkpoint_dir(checkpoint_file)
        checkpoint_data = Checkpoint.from_directory(checkpoint_path).to_dict()

    return Checkpoint.from_dict(checkpoint_data)
Example #15
    def testPickleCheckpoint(self):
        for i in range(5):
            path = os.path.join(self.checkpoint_dir, str(i))
            with open(path, "w") as f:
                f.write(str(i))

        checkpoint_path = os.path.join(self.checkpoint_dir, "0")

        data_dict = TrainableUtil.pickle_checkpoint(checkpoint_path)
        loaded = cloudpickle.loads(data_dict)

        checkpoint_name = os.path.basename(checkpoint_path)
        self.assertEqual(loaded["checkpoint_name"], checkpoint_name)

        for i in range(5):
            path = os.path.join(self.checkpoint_dir, str(i))
            self.assertEqual(loaded["data"][str(i)], open(path, "rb").read())
Example #16
    def delete_checkpoint(self, checkpoint_path: Union[str, Checkpoint]):
        """Deletes local copy of checkpoint.

        Args:
            checkpoint_path: Path to checkpoint.
        """
        # Ensure Checkpoints are converted
        if isinstance(checkpoint_path,
                      Checkpoint) and checkpoint_path._local_path:
            checkpoint_path = checkpoint_path._local_path

        try:
            checkpoint_dir = TrainableUtil.find_checkpoint_dir(checkpoint_path)
        except FileNotFoundError:
            # The checkpoint won't exist locally if the
            # trial was rescheduled to another worker.
            logger.debug(
                f"Local checkpoint not found during garbage collection: "
                f"{self.trial_id} - {checkpoint_path}")
            return
        else:
            if self.uses_cloud_checkpointing:
                if self.custom_syncer:
                    # Keep for backwards compatibility
                    self.custom_syncer.delete(
                        self._storage_path(checkpoint_dir))
                    self.custom_syncer.wait_or_retry()
                else:
                    checkpoint_uri = self._storage_path(checkpoint_dir)
                    retry_fn(
                        lambda: delete_external_checkpoint(checkpoint_uri),
                        subprocess.CalledProcessError,
                        num_retries=3,
                        sleep_time=1,
                    )

        if os.path.exists(checkpoint_dir):
            shutil.rmtree(checkpoint_dir)
Example #17
    def to_air_checkpoint(self) -> Optional[Checkpoint]:
        from ray.tune.trainable.util import TrainableUtil

        checkpoint_data = self.dir_or_data

        if not checkpoint_data:
            return None

        if isinstance(checkpoint_data, ray.ObjectRef):
            checkpoint_data = ray.get(checkpoint_data)

        if isinstance(checkpoint_data, str):
            try:
                checkpoint_dir = TrainableUtil.find_checkpoint_dir(
                    checkpoint_data)
            except FileNotFoundError:
                if log_once("checkpoint_not_available"):
                    logger.error(
                        f"The requested checkpoint is not available on this node, "
                        f"most likely because you are using Ray client or disabled "
                        f"checkpoint synchronization. To avoid this, enable checkpoint "
                        f"synchronization to cloud storage by specifying a "
                        f"`SyncConfig`. The checkpoint may be available on a different "
                        f"node - please check this location on worker nodes: "
                        f"{checkpoint_data}")
                return None
            checkpoint = Checkpoint.from_directory(checkpoint_dir)
        elif isinstance(checkpoint_data, bytes):
            checkpoint = Checkpoint.from_bytes(checkpoint_data)
        elif isinstance(checkpoint_data, dict):
            checkpoint = Checkpoint.from_dict(checkpoint_data)
        else:
            raise RuntimeError(
                f"Unknown checkpoint data type: {type(checkpoint_data)}")

        return checkpoint
Example #18
    def get_trial_checkpoints_paths(
            self,
            trial: Trial,
            metric: Optional[str] = None) -> List[Tuple[str, Number]]:
        """Gets paths and metrics of all persistent checkpoints of a trial.

        Args:
            trial: The log directory of a trial, or a trial instance.
            metric: key for trial info to return, e.g. "mean_accuracy".
                "training_iteration" is used by default if no value was
                passed to ``self.default_metric``.

        Returns:
            List of [path, metric] for all persistent checkpoints of the trial.
        """
        metric = metric or self.default_metric or TRAINING_ITERATION

        if isinstance(trial, str):
            trial_dir = os.path.expanduser(trial)
            # Get checkpoints from logdir.
            chkpt_df = TrainableUtil.get_checkpoints_paths(trial_dir)

            # Join with trial dataframe to get metrics.
            trial_df = self.trial_dataframes[trial_dir]
            path_metric_df = chkpt_df.merge(trial_df,
                                            on="training_iteration",
                                            how="inner")
            return path_metric_df[["chkpt_path", metric]].values.tolist()
        elif isinstance(trial, Trial):
            checkpoints = trial.get_trial_checkpoints()
            # Support metrics given as paths, e.g.
            # "info/learner/default_policy/policy_loss".
            return [(c.dir_or_data, unflattened_lookup(metric, c.metrics))
                    for c in checkpoints]
        else:
            raise ValueError("trial should be a string or a Trial instance.")
Example #19
    def save(self, checkpoint_dir: Optional[str] = None) -> str:
        """Saves the current model state to a checkpoint.

        Subclasses should override ``save_checkpoint()`` instead to save state.
        This method dumps additional metadata alongside the saved path.

        If a remote checkpoint dir is given, this will also sync up to remote
        storage.

        Args:
            checkpoint_dir: Optional dir to place the checkpoint.

        Returns:
            The given or created checkpoint directory.

        Note the return path should match up with what is expected of
        `restore()`.
        """
        checkpoint_dir = self._create_checkpoint_dir(
            checkpoint_dir=checkpoint_dir)

        # User saves checkpoint
        checkpoint_dict_or_path = self.save_checkpoint(checkpoint_dir)

        if checkpoint_dict_or_path is None:
            # checkpoint_dict_or_path can only be None in class trainables.
            # In that case the default is to use the root checkpoint directory.
            assert checkpoint_dir
            checkpoint_dict_or_path = checkpoint_dir
        elif checkpoint_dir is None:
            # checkpoint_dir is only None in function trainables. In that case,
            # checkpoint_dict_or_path points to the already saved checkpoint dir.
            # This will be considered the root dir.
            assert isinstance(checkpoint_dict_or_path, str)
            checkpoint_dir = checkpoint_dict_or_path

        # Get trainable metadata
        metadata = self.get_state()

        if isinstance(checkpoint_dict_or_path, dict):
            metadata["relative_checkpoint_path"] = ""
            metadata["saved_as_dict"] = True
            Checkpoint.from_dict(checkpoint_dict_or_path).to_directory(
                checkpoint_dir)
            # Re-drop marker
            TrainableUtil.mark_as_checkpoint_dir(checkpoint_dir)
        else:
            # Make sure the checkpoint dir is contained
            if not checkpoint_dict_or_path.startswith(checkpoint_dir):
                raise ValueError(
                    f"The returned checkpoint path must be within the given "
                    f"checkpoint dir ({checkpoint_dir}): {checkpoint_dict_or_path}"
                )

            # Get relative path to returned checkpoint
            relative_checkpoint_path = os.path.relpath(checkpoint_dict_or_path,
                                                       checkpoint_dir)
            metadata["relative_checkpoint_path"] = relative_checkpoint_path
            metadata["saved_as_dict"] = False

        TrainableUtil.write_metadata(checkpoint_dir, metadata)

        # Maybe sync to cloud
        self._maybe_save_to_cloud(checkpoint_dir)

        return checkpoint_dir
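
The path branch above records the value returned by `save_checkpoint()` relative to the checkpoint directory, so a class Trainable that returns a path must keep it inside the directory Tune passes in. A hedged sketch of that convention; the class and file names are illustrative:

import os

from ray import tune

class PathCheckpointTrainable(tune.Trainable):
    def step(self):
        return {"score": self.iteration}

    def save_checkpoint(self, checkpoint_dir):
        # The written file must live inside checkpoint_dir; its relative path
        # is stored in the checkpoint metadata.
        path = os.path.join(checkpoint_dir, "model.txt")
        with open(path, "w") as f:
            f.write(str(self.iteration))
        return path

    def load_checkpoint(self, checkpoint_path):
        with open(checkpoint_path) as f:
            self._restored_iteration = int(f.read())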
Example #20
    def restore(
        self,
        checkpoint_path: Union[str, Checkpoint],
        checkpoint_node_ip: Optional[str] = None,
    ):
        """Restores training state from a given model checkpoint.

        These checkpoints are returned from calls to save().

        Subclasses should override ``load_checkpoint()`` instead to
        restore state.
        This method restores additional metadata saved with the checkpoint.

        `checkpoint_path` should match with the return from ``save()``.

        `checkpoint_path` can be
        `~/ray_results/exp/MyTrainable_abc/checkpoint_00000/checkpoint`, or
        `~/ray_results/exp/MyTrainable_abc/checkpoint_00000`.

        `self.logdir` should generally correspond to `checkpoint_path`,
        for example, `~/ray_results/exp/MyTrainable_abc`.

        `self.remote_checkpoint_dir`, in this case, is something like
        `REMOTE_CHECKPOINT_BUCKET/exp/MyTrainable_abc`.

        Args:
            checkpoint_path: Path to restore checkpoint from. If this
                path does not exist on the local node, it will be fetched
                from external (cloud) storage if available, or restored
                from a remote node.
            checkpoint_node_ip: If given, try to restore
                checkpoint from this node if it doesn't exist locally or
                on cloud storage.

        """
        # Ensure Checkpoints are converted
        if isinstance(checkpoint_path, Checkpoint):
            return self._restore_from_checkpoint_obj(checkpoint_path)

        if not self._maybe_load_from_cloud(checkpoint_path) and (
                # If a checkpoint source IP is given
                checkpoint_node_ip
                # And the checkpoint does not currently exist on the local node
                and not os.path.exists(checkpoint_path)
                # And the source IP is different to the current IP
                and checkpoint_node_ip != ray.util.get_node_ip_address()):
            checkpoint = get_checkpoint_from_remote_node(
                checkpoint_path, checkpoint_node_ip)
            if checkpoint:
                checkpoint.to_directory(checkpoint_path)

        if not os.path.exists(checkpoint_path):
            raise ValueError(
                f"Could not recover from checkpoint as it does not exist on local "
                f"disk and was not available on cloud storage or another Ray node. "
                f"Got checkpoint path: {checkpoint_path} and IP {checkpoint_node_ip}"
            )

        checkpoint_dir = TrainableUtil.find_checkpoint_dir(checkpoint_path)
        metadata = TrainableUtil.load_metadata(checkpoint_dir)

        if metadata["saved_as_dict"]:
            # If data was saved as a dict (e.g. from a class trainable),
            # also pass the dict to `load_checkpoint()`.
            checkpoint_dict = Checkpoint.from_directory(
                checkpoint_dir).to_dict()
            # If other files were added to the directory after converting from the
            # original dict (e.g. marker files), clean these up
            checkpoint_dict.pop(_DICT_CHECKPOINT_ADDITIONAL_FILE_KEY, None)
            to_load = checkpoint_dict
        else:
            # Otherwise, pass the relative checkpoint path
            relative_checkpoint_path = metadata["relative_checkpoint_path"]
            to_load = os.path.join(checkpoint_dir, relative_checkpoint_path)

        # Set metadata
        self._experiment_id = metadata["experiment_id"]
        self._iteration = metadata["iteration"]
        self._timesteps_total = metadata["timesteps_total"]
        self._time_total = metadata["time_total"]
        self._episodes_total = metadata["episodes_total"]

        # Actually load checkpoint
        self.load_checkpoint(to_load)

        self._time_since_restore = 0.0
        self._timesteps_since_restore = 0
        self._iterations_since_restore = 0
        self._restored = True

        logger.info("Restored on %s from checkpoint: %s",
                    self.get_current_ip(), checkpoint_dir)
        state = {
            "_iteration": self._iteration,
            "_timesteps_total": self._timesteps_total,
            "_time_total": self._time_total,
            "_episodes_total": self._episodes_total,
        }
        logger.info("Current state after restoring: %s", state)
Example #21
def test_find_rel_checkpoint_dir(checkpoint_path, logdir):
    assert (
        TrainableUtil.find_rel_checkpoint_dir(logdir, checkpoint_path) == "checkpoint0"
    )