示例#1
0
    def to_directory(self, path: Optional[str] = None) -> str:
        """Write checkpoint data to a directory.

        The data is taken from whichever backing store this checkpoint has:
        in-memory data (dict or object ref), a local directory, or an
        external (e.g. cloud) location.

        Args:
            path: Target directory to write the checkpoint data to. If
                ``None``, the directory returned by
                ``_temporary_checkpoint_dir()`` is used. When the checkpoint
                lives in a *different* local directory, any existing contents
                of ``path`` are removed before the copy.

        Returns:
            str: Directory containing checkpoint data (``path`` as passed in,
            or the generated directory).

        Raises:
            RuntimeError: If the checkpoint has neither in-memory data, a
                local path, nor an external path.
        """
        path = path if path is not None else _temporary_checkpoint_dir()
        marker_path = os.path.join(path, ".is_checkpoint")

        os.makedirs(path, exist_ok=True)
        # Drop marker
        with open(marker_path, "a"):
            pass

        if self._data_dict or self._obj_ref:
            # This is an object ref or dict
            data_dict = self.to_dict()

            if _FS_CHECKPOINT_KEY in data_dict:
                # This used to be a true fs checkpoint, so restore
                _unpack(data_dict[_FS_CHECKPOINT_KEY], path)
            else:
                # This is a dict checkpoint. Dump data into the dict
                # checkpoint file inside the target directory.
                checkpoint_data_path = os.path.join(
                    path, _DICT_CHECKPOINT_FILE_NAME)
                with open(checkpoint_data_path, "wb") as f:
                    pickle.dump(data_dict, f)
            return path

        # This is either a local fs, remote node fs, or external fs
        local_path = self._local_path
        if local_path:
            # Compare resolved paths rather than raw strings: the same
            # directory spelled differently (trailing separator, relative
            # path, symlink) must not be treated as a separate target,
            # otherwise the rmtree below would delete the checkpoint source.
            if os.path.realpath(local_path) != os.path.realpath(path):
                # If this exists on the local path, just copy over
                if path and os.path.exists(path):
                    shutil.rmtree(path)
                shutil.copytree(local_path, path)
                # The rmtree above removed the marker dropped earlier and
                # the source directory is not guaranteed to contain one,
                # so drop it again.
                with open(marker_path, "a"):
                    pass
            return path

        # Only resolve the external location when there is no local copy.
        external_path = _get_external_path(self._uri)
        if external_path:
            # If this exists on external storage (e.g. cloud), download
            download_from_bucket(bucket=external_path, local_path=path)
            return path

        raise RuntimeError(
            f"No valid location found for checkpoint {self}: {self._uri}"
        )
示例#2
0
    def download(
        self,
        cloud_path: Optional[str] = None,
        local_path: Optional[str] = None,
        overwrite: bool = False,
    ) -> str:
        """Fetch the checkpoint directory from cloud storage.

        The checkpoint found at the cloud location is saved into the local
        target directory. If ``self.local_path`` is unset when this is
        called, it is set to the resolved local target.

        Args:
            cloud_path: Cloud path to load the checkpoint from.
                Falls back to ``self.cloud_path``.
            local_path: Local path to save the checkpoint at.
                Falls back to ``self.local_path``.
            overwrite: If True, an existing local checkpoint directory is
                deleted and downloaded again. If False and the local target
                already exists and contains entries, nothing is downloaded.

        Returns:
            The local path the checkpoint is stored at.

        Raises:
            RuntimeError: If no cloud path or no local path can be resolved.
        """
        source = cloud_path or self.cloud_path
        if not source:
            raise RuntimeError(
                "Could not download trial checkpoint: No cloud "
                "path is set. Fix this by either passing a "
                "`cloud_path` to your call to `download()` or by "
                "passing a `cloud_path` into the constructor. The latter "
                "should automatically be done if you pass the correct "
                "`tune.SyncConfig`.")

        target = local_path or self.local_path
        if not target:
            raise RuntimeError(
                "Could not download trial checkpoint: No local "
                "path is set. Fix this by either passing a "
                "`local_path` to your call to `download()` or by "
                "passing a `local_path` into the constructor.")

        # Remember the target on the instance, but never replace a path
        # that was already set.
        if not self.local_path:
            self.local_path = target

        if not overwrite:
            # The filesystem is only inspected when overwriting is disabled.
            if os.path.exists(target) and os.listdir(target):
                # A populated local copy is already there; keep it.
                return target

        # Start from a clean directory: drop whatever is there (best
        # effort) and create it again before downloading.
        shutil.rmtree(target, ignore_errors=True)
        os.makedirs(target, mode=0o755, exist_ok=True)

        # Trigger the actual download into the fresh directory.
        download_from_bucket(source, target)

        return target