Ejemplo n.º 1
0
def _get_local_path(path: Optional[str]) -> Optional[str]:
    """Return ``path`` as an existing local filesystem path, else ``None``.

    ``None`` and any value accepted by ``is_non_local_path_uri`` are
    rejected. A leading ``file://`` scheme is stripped before the path is
    checked for existence on the local filesystem.
    """
    if path is None:
        return None
    if is_non_local_path_uri(path):
        return None

    file_scheme = "file://"
    candidate = path[len(file_scheme):] if path.startswith(file_scheme) else path
    return candidate if os.path.exists(candidate) else None
Ejemplo n.º 2
0
    def to_uri(self, uri: str) -> str:
        """Persist this checkpoint's data at the given URI.

        A ``file://`` URI is treated as a local directory target: the scheme
        is stripped and the remainder is handed to ``to_directory``. Any
        other URI must be accepted by ``is_non_local_path_uri``; the
        checkpoint is then exposed as a local directory via ``as_directory``
        and uploaded from there with ``upload_to_uri``.

        Args:
            uri (str): Target location URI to write data to.

        Returns:
            str: The result of ``to_directory`` for ``file://`` URIs,
                otherwise ``uri`` unchanged.

        Raises:
            RuntimeError: If ``uri`` is neither a ``file://`` URI nor
                accepted by ``is_non_local_path_uri``.
        """
        file_scheme = "file://"
        if uri.startswith(file_scheme):
            return self.to_directory(uri[len(file_scheme):])

        if is_non_local_path_uri(uri):
            with self.as_directory() as checkpoint_dir:
                upload_to_uri(local_path=checkpoint_dir, uri=uri)
            return uri

        # Neither a local file URI nor a recognized remote URI.
        raise RuntimeError(
            f"Cannot upload checkpoint to URI: Provided URI "
            f"does not belong to a registered storage provider: `{uri}`. "
            f"Hint: {fs_hint(uri)}")
Ejemplo n.º 3
0
Archivo: cloud.py Proyecto: alipay/ray
    def save(self, path: Optional[str] = None, force_download: bool = False):
        """Save trial checkpoint to directory or cloud storage.

        If the ``path`` is a local target and the checkpoint already exists
        on local storage, the local directory is copied. Else, the checkpoint
        is downloaded from cloud storage.

        If the ``path`` is a cloud target and the checkpoint does not already
        exist on local storage, it is downloaded from cloud storage before.
        That way checkpoints can be transferred across cloud storage providers.

        Args:
            path: Path to save checkpoint at. If empty,
                the default cloud storage path is saved to the default
                local directory.
            force_download: If ``True``, forces (re-)download of
                the checkpoint. Defaults to ``False``.

        Returns:
            For a cloud target, the value returned by ``self.upload()``.
            For a local target: ``self.local_path`` if ``path`` already
            refers to the existing local copy, ``path`` after the local copy
            was copied there, or otherwise the value returned by
            ``self.download()``.

        Raises:
            RuntimeError: If no usable source or target path can be
                determined, if no checkpoint data is available locally
                before uploading to a cloud target, or if downloading to a
                local target fails (the original error is chained).
        """
        temp_dirs = set()
        # Per default, save cloud checkpoint
        if not path:
            if self.cloud_path and self.local_path:
                path = self.local_path
            elif not self.cloud_path:
                raise RuntimeError(
                    "Cannot save trial checkpoint: No cloud path "
                    "found. If the checkpoint is already on the node, "
                    "you can pass a `path` argument to save it at another "
                    "location."
                )
            else:
                # No self.local_path
                raise RuntimeError(
                    "Cannot save trial checkpoint: No target path "
                    "specified and no default local directory available. "
                    "Please pass a `path` argument to `save()`."
                )
        elif not self.local_path and not self.cloud_path:
            raise RuntimeError(
                f"Cannot save trial checkpoint to cloud target "
                f"`{path}`: No existing local or cloud path was "
                f"found. This indicates an error when loading "
                f"the checkpoints. Please report this issue."
            )

        if is_non_local_path_uri(path):
            # Storing on cloud. Everything that may create or use a temporary
            # directory runs inside try/finally so the directory is removed
            # even when downloading, validating or uploading fails.
            try:
                if not self.local_path:
                    # No local copy, yet. Download to temp dir
                    local_path = tempfile.mkdtemp(prefix="tune_checkpoint_")
                    temp_dirs.add(local_path)
                else:
                    local_path = self.local_path

                if self.cloud_path:
                    # Do not update local path as it might be a temp file
                    local_path = self.download(
                        local_path=local_path, overwrite=force_download
                    )

                    # Remove pointer to a temporary directory
                    if self.local_path in temp_dirs:
                        self.local_path = None

                # We should now have a checkpoint available locally
                if not os.path.exists(local_path) or not os.listdir(local_path):
                    raise RuntimeError(
                        f"No checkpoint found in directory `{local_path}` after "
                        f"download - maybe the bucket is empty or downloading "
                        f"failed?"
                    )

                # Only update cloud path if it wasn't set before
                return self.upload(
                    cloud_path=path, local_path=local_path, clean_before=True
                )
            finally:
                # Never keep a pointer to a directory that is about to be
                # deleted. NOTE(review): presumably `download()`/`upload()`
                # can assign `self.local_path`; confirm against those methods.
                if self.local_path in temp_dirs:
                    self.local_path = None

                # Best-effort cleanup: `ignore_errors` keeps a cleanup problem
                # from masking an exception raised in the `try` body.
                for temp_dir in temp_dirs:
                    shutil.rmtree(temp_dir, ignore_errors=True)

        # Else: path is a local target
        has_local_copy = bool(
            self.local_path
            and os.path.exists(self.local_path)
            and os.listdir(self.local_path)
        )

        if has_local_copy and not force_download:
            # If we have a local copy, use it

            # Compare resolved paths rather than raw strings: a different
            # spelling of the same directory (trailing slash, relative path,
            # symlink) must not reach the `rmtree` below, which would delete
            # the only local copy before it can be copied.
            if os.path.realpath(path) == os.path.realpath(self.local_path):
                # Nothing to do
                return self.local_path

            # Both local, just copy tree
            if os.path.exists(path):
                shutil.rmtree(path)

            shutil.copytree(self.local_path, path)
            return path

        # Else: Download
        try:
            return self.download(local_path=path, overwrite=force_download)
        except Exception as e:
            raise RuntimeError(
                "Cannot save trial checkpoint to local target as downloading "
                "from cloud failed. Did you pass the correct `SyncConfig`?"
            ) from e
Ejemplo n.º 4
0
def _get_external_path(path: Optional[str]) -> Optional[str]:
    """Return ``path`` if it is a non-local URI string, else ``None``.

    Non-string values (including ``None``) are rejected before
    ``is_non_local_path_uri`` is consulted.
    """
    if isinstance(path, str) and is_non_local_path_uri(path):
        return path
    return None