Пример #1
0
    def to_uri(self, uri: str) -> str:
        """Write checkpoint data to location URI (e.g. cloud storage).

        ``file://`` URIs are handled by stripping the scheme and delegating
        to ``self.to_directory``. Any other URI must be a cloud target and
        is uploaded with ``upload_to_bucket``.

        Args:
            uri (str): Target location URI to write data to.

        Returns:
            str: For cloud targets, ``uri`` itself. For ``file://`` URIs,
            whatever ``self.to_directory`` returns for the stripped path.

        Raises:
            AssertionError: If ``uri`` is neither a ``file://`` URI nor a
                cloud target according to ``is_cloud_target``.
        """
        if uri.startswith("file://"):
            local_path = uri[len("file://"):]
            return self.to_directory(local_path)

        assert is_cloud_target(uri), (
            f"Expected a `file://` URI or a cloud target, got: {uri}")

        local_path = self._local_path
        # Without an existing local copy, materialize one via
        # ``to_directory()``; that directory is ours to remove afterwards.
        cleanup = not local_path
        if cleanup:
            local_path = self.to_directory()

        try:
            upload_to_bucket(bucket=uri, local_path=local_path)
        finally:
            # Remove the temporary copy even when the upload fails, so
            # failed uploads do not leak directories on disk.
            if cleanup:
                shutil.rmtree(local_path)

        return uri
Пример #2
0
def _get_local_path(path: Optional[str]) -> Optional[str]:
    """Check if path is a local path. Otherwise return None."""
    if path is None or is_cloud_target(path):
        return None
    if path.startswith("file://"):
        path = path[7:]
    if os.path.exists(path):
        return path
    return None
Пример #3
0
def _get_external_path(path: Optional[str]) -> Optional[str]:
    """Check if path is an external path. Otherwise return None."""
    if not isinstance(path, str) or not is_cloud_target(path):
        return None
    return path
Пример #4
0
    def save(self, path: Optional[str] = None, force_download: bool = False):
        """Save trial checkpoint to directory or cloud storage.

        If the ``path`` is a local target and the checkpoint already exists
        on local storage, the local directory is copied. Else, the checkpoint
        is downloaded from cloud storage.

        If the ``path`` is a cloud target and the checkpoint does not already
        exist on local storage, it is downloaded from cloud storage before.
        That way checkpoints can be transferred across cloud storage providers.

        Args:
            path: Path to save checkpoint at. If empty,
                the default cloud storage path is saved to the default
                local directory.
            force_download: If ``True``, forces (re-)download of
                the checkpoint. Defaults to ``False``.

        Returns:
            The location the checkpoint was saved to: the result of
            ``self.upload`` for cloud targets, otherwise the local path
            (``self.local_path``, ``path``, or the result of
            ``self.download``).

        Raises:
            RuntimeError: If no usable source or target location is
                available, if no checkpoint data is found locally before
                uploading, or if downloading to a local target fails.
        """
        # Per default, save cloud checkpoint
        if not path:
            if self.cloud_path and self.local_path:
                path = self.local_path
            elif not self.cloud_path:
                raise RuntimeError(
                    "Cannot save trial checkpoint: No cloud path "
                    "found. If the checkpoint is already on the node, "
                    "you can pass a `path` argument to save it at another "
                    "location.")
            else:
                # No self.local_path
                raise RuntimeError(
                    "Cannot save trial checkpoint: No target path "
                    "specified and no default local directory available. "
                    "Please pass a `path` argument to `save()`.")
        elif not self.local_path and not self.cloud_path:
            raise RuntimeError(
                f"Cannot save trial checkpoint to cloud target "
                f"`{path}`: No existing local or cloud path was "
                f"found. This indicates an error when loading "
                f"the checkpoints. Please report this issue.")

        if is_cloud_target(path):
            # Storing on cloud
            temp_dirs = set()
            try:
                if not self.local_path:
                    # No local copy, yet. Download to temp dir
                    local_path = tempfile.mkdtemp(prefix="tune_checkpoint_")
                    temp_dirs.add(local_path)
                else:
                    local_path = self.local_path

                if self.cloud_path:
                    # Do not update local path as it might be a temp file
                    local_path = self.download(local_path=local_path,
                                               overwrite=force_download)

                    # Remove pointer to a temporary directory
                    if self.local_path in temp_dirs:
                        self.local_path = None

                # We should now have a checkpoint available locally
                if not os.path.exists(local_path) or not os.listdir(
                        local_path):
                    raise RuntimeError(
                        f"No checkpoint found in directory `{local_path}` "
                        f"after "
                        f"download - maybe the bucket is empty or downloading "
                        f"failed?")

                return self.upload(cloud_path=path,
                                   local_path=local_path,
                                   clean_before=True)
            finally:
                # Clean up temporary directories on success *and* failure,
                # so a failed download/upload does not leak them on disk.
                # Also make sure we never keep pointing at a removed dir.
                if self.local_path in temp_dirs:
                    self.local_path = None
                for temp_dir in temp_dirs:
                    shutil.rmtree(temp_dir)

        # Else: path is a local target
        has_local_copy = (self.local_path
                          and os.path.exists(self.local_path)
                          and len(os.listdir(self.local_path)) > 0)

        if has_local_copy and not force_download:
            # If we have a local copy, use it

            if path == self.local_path:
                # Nothing to do
                return self.local_path

            # Both local, just copy tree. ``copytree`` needs a fresh
            # target, so an existing one is removed first.
            if os.path.exists(path):
                shutil.rmtree(path)

            shutil.copytree(self.local_path, path)
            return path

        # Else: Download
        try:
            return self.download(local_path=path, overwrite=force_download)
        except Exception as e:
            raise RuntimeError(
                "Cannot save trial checkpoint to local target as downloading "
                "from cloud failed. Did you pass the correct `SyncConfig`?"
            ) from e