Example #1
class HTTPStreamFile(AbstractBufferedFile):
    def __init__(self, fs, url, mode="rb", loop=None, session=None, **kwargs):
        self.asynchronous = kwargs.pop("asynchronous", False)
        self.url = url
        self.loop = loop
        self.session = session
        if mode != "rb":
            raise ValueError("Mode must be 'rb' for streaming HTTP files")
        self.details = {"name": url, "size": None}
        super().__init__(fs=fs,
                         path=url,
                         mode=mode,
                         cache_type="none",
                         **kwargs)
        self.r = sync(self.loop, self.session.get, url, **kwargs)

    def seek(self, *args, **kwargs):
        raise ValueError("Cannot seek streaming HTTP file")

    async def _read(self, num=-1):
        out = await self.r.content.read(num)
        self.loc += len(out)
        return out

    read = sync_wrapper(_read)

    async def _close(self):
        self.r.close()

    def close(self):
        asyncio.run_coroutine_threadsafe(self._close(), self.loop)

    def __reduce__(self):
        return reopen, (self.fs, self.url, self.mode, self.blocksize,
                        self.cache.name)
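
A minimal usage sketch (hedged: it assumes an fsspec-style filesystem instance `fs` exposing a running event loop and an aiohttp session, as the constructor above expects; the URL is a placeholder):

f = HTTPStreamFile(fs, "https://example.com/data.bin",
                   session=fs.session, loop=fs.loop)
try:
    chunk = f.read(2**20)   # reads are sequential only; seek() raises ValueError
finally:
    f.close()
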
Example #2


async def _file_size(url, session=None, size_policy="head", **kwargs):
    """Call HEAD on the server to get file size

    Default operation is to explicitly allow redirects and use encoding
    'identity' (no compression) to get the true size of the target.
    """
    kwargs = kwargs.copy()
    ar = kwargs.pop("allow_redirects", True)
    head = kwargs.get("headers", {}).copy()
    head["Accept-Encoding"] = "identity"
    kwargs["headers"] = head
    session = session or await get_client()
    if size_policy == "head":
        r = await session.head(url, allow_redirects=ar, **kwargs)
    elif size_policy == "get":
        r = await session.get(url, allow_redirects=ar, **kwargs)
    else:
        raise TypeError('size_policy must be "head" or "get", got %s'
                        % size_policy)
    async with r:
        if "Content-Length" in r.headers:
            return int(r.headers["Content-Length"])
        elif "Content-Range" in r.headers:
            return int(r.headers["Content-Range"].split("/")[1])


file_size = sync_wrapper(_file_size)
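
A hedged async usage sketch of _file_size (it relies on get_client from this module to build the aiohttp session; the URL is a placeholder):

import asyncio

async def _demo_file_size():
    size = await _file_size("https://example.com/data.bin", size_policy="head")
    # size is None when the server reports neither Content-Length nor Content-Range
    print(size)

asyncio.run(_demo_file_size())
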
Example #3
class HTTPFile(AbstractBufferedFile):
    """
    A file-like object pointing to a remote HTTP(S) resource

    Supports only reading, with read-ahead of a predetermined block-size.

    In the case that the server does not supply the filesize, only reading of
    the complete file in one go is supported.

    Parameters
    ----------
    url: str
        Full URL of the remote resource, including the protocol
    session: requests.Session or None
        All calls will be made within this session, to avoid restarting
        connections where the server allows this
    block_size: int or None
        The amount of read-ahead to do, in bytes. Default is 5MB, or the value
        configured for the FileSystem creating this file
    size: None or int
        If given, this is the size of the file in bytes, and we don't attempt
        to call the server to find the value.
    kwargs: all other key-values are passed to requests calls.
    """
    def __init__(self,
                 fs,
                 url,
                 session=None,
                 block_size=None,
                 mode="rb",
                 cache_type="bytes",
                 cache_options=None,
                 size=None,
                 loop=None,
                 asynchronous=False,
                 **kwargs):
        if mode != "rb":
            raise NotImplementedError("File mode not supported")
        self.asynchronous = asynchronous
        self.url = url
        self.session = session
        self.details = {"name": url, "size": size, "type": "file"}
        super().__init__(fs=fs,
                         path=url,
                         mode=mode,
                         block_size=block_size,
                         cache_type=cache_type,
                         cache_options=cache_options,
                         **kwargs)
        self.loop = loop

    def read(self, length=-1):
        """Read bytes from file

        Parameters
        ----------
        length: int
            Read up to this many bytes. If negative, read all content to end of
            file. If the server has not supplied the filesize, attempting to
            read only part of the data will raise a ValueError.
        """
        if (
            (length < 0 and self.loc == 0)  # explicit read all
            or (length > (self.size or length))  # read more than there is
            or (self.size and self.size < self.blocksize)  # all fits in one block anyway
        ):
            self._fetch_all()
        if self.size is None:
            if length < 0:
                self._fetch_all()
        else:
            length = min(self.size - self.loc, length)
        return super().read(length)

    async def async_fetch_all(self):
        """Read whole file in one shot, without caching

        This is only called when position is still at zero,
        and read() is called without a byte-count.
        """
        if not isinstance(self.cache, AllBytes):
            r = await self.session.get(self.url, **self.kwargs)
            async with r:
                r.raise_for_status()
                out = await r.read()
                self.cache = AllBytes(size=len(out),
                                      fetcher=None,
                                      blocksize=None,
                                      data=out)
                self.size = len(out)

    _fetch_all = sync_wrapper(async_fetch_all)

    async def async_fetch_range(self, start, end):
        """Download a block of data

        The expectation is that the server returns only the requested bytes,
        with HTTP code 206. If this is not the case, we first check the headers,
        and then stream the output - if the data size is bigger than we
        requested, an exception is raised.
        """
        kwargs = self.kwargs.copy()
        headers = kwargs.pop("headers", {}).copy()
        headers["Range"] = "bytes=%i-%i" % (start, end - 1)
        logger.debug(self.url + " : " + headers["Range"])
        r = await self.session.get(self.url, headers=headers, **kwargs)
        async with r:
            if r.status == 416:
                # range request outside file
                return b""
            r.raise_for_status()
            if r.status == 206:
                # partial content, as expected
                out = await r.read()
            elif "Content-Length" in r.headers:
                cl = int(r.headers["Content-Length"])
                if cl <= end - start:
                    # data size OK
                    out = await r.read()
                else:
                    raise ValueError(
                        "Got more bytes (%i) than requested (%i)" %
                        (cl, end - start))
            else:
                cl = 0
                out = []
                while True:
                    chunk = await r.content.read(2**20)
                    # data size unknown, let's see if it goes too big
                    if chunk:
                        out.append(chunk)
                        cl += len(chunk)
                        if cl > end - start:
                            raise ValueError(
                                "Got more bytes so far (>%i) than requested (%i)"
                                % (cl, end - start))
                    else:
                        break
                out = b"".join(out)
            return out

    _fetch_range = sync_wrapper(async_fetch_range)

    def close(self):
        pass

    def __reduce__(self):
        return reopen, (
            self.fs,
            self.url,
            self.mode,
            self.blocksize,
            self.cache.name,
            self.size,
        )
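
For reference, the Range header built by async_fetch_range above addresses the half-open byte interval [start, end) with an inclusive last byte, hence the end - 1:

start, end = 0, 5 * 2**20
range_header = "bytes=%i-%i" % (start, end - 1)   # -> "bytes=0-5242879"
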
class dCacheFileSystem(AsyncFileSystem):
    """

    """

    def __init__(
        self,
        api_url=None,
        webdav_url=None,
        username=None,
        password=None,
        token=None,
        block_size=None,
        asynchronous=False,
        loop=None,
        client_kwargs=None,
        **storage_options
    ):
        """
        NB: if constructed with ``asynchronous=True``, you must await ``set_session`` before use

        Parameters
        ----------
        block_size: int
            Bytes per read-ahead block; if 0, a streaming file-like object is
            returned instead of a block-cached file instance
        client_kwargs: dict
            Passed to aiohttp.ClientSession, see
            https://docs.aiohttp.org/en/stable/client_reference.html
            For example, ``{'auth': aiohttp.BasicAuth('user', 'pass')}``
        storage_options: key-value
            Any other parameters passed on to requests
        """
        super().__init__(
            self,
            asynchronous=asynchronous,
            loop=loop,
            **storage_options
        )
        self.api_url = api_url
        self.webdav_url = webdav_url
        self.client_kwargs = client_kwargs or {}
        if (username is None) ^ (password is None):
            raise ValueError('Provide both username and password, or neither')
        if (username is not None) and (password is not None):
            self.client_kwargs.update(
                auth=aiohttp.BasicAuth(username, password)
            )
        if token is not None:
            if password is not None:
                raise ValueError('Provide either token or username/password')
            headers = self.client_kwargs.get('headers', {})
            headers.update(Authorization=f'Bearer {token}')
            self.client_kwargs.update(headers=headers)
        block_size = DEFAULT_BLOCK_SIZE if block_size is None else block_size
        self.block_size = block_size
        self.kwargs = storage_options
        if not asynchronous:
            self._session = sync(self.loop, get_client, **self.client_kwargs)
            weakref.finalize(self, sync, self.loop, self.session.close)
        else:
            self._session = None

    @property
    def session(self):
        if self._session is None:
            raise RuntimeError(
                "please await ``.set_session`` before anything else"
            )
        return self._session

    @property
    def api_url(self):
        if self._api_url is None:
            raise ValueError('dCache API URL not set!')
        return self._api_url

    @api_url.setter
    def api_url(self, api_url):
        self._api_url = api_url

    @property
    def webdav_url(self):
        if self._webdav_url is None:
            raise ValueError('WebDAV door not set!')
        return self._webdav_url

    @webdav_url.setter
    def webdav_url(self, webdav_url):
        self._webdav_url = webdav_url

    async def set_session(self):
        self._session = await get_client(**self.client_kwargs)

    @classmethod
    def _strip_protocol(cls, path):
        """
        Turn path from fully-qualified to file-system-specific

        :param path: (str or list)
        :return (str)
        """
        if isinstance(path, list):
            return [cls._strip_protocol(p) for p in path]
        return URL(path).path

    @classmethod
    def _get_kwargs_from_urls(cls, path):
        """
        Extract kwargs encoded in the path
        :param path: (str)
        :return (dict)
        """
        return {'webdav_url': cls._get_webdav_url(path)}

    @classmethod
    def _get_webdav_url(cls, path):
        """
        Extract the WebDAV door URL encoded in the path(s)

        :param path: (str or list) if list, extract URL from the first element
        :return (str or None)
        """
        if isinstance(path, list):
            return cls._get_webdav_url(path[0])
        return URL(path).drive or None

    async def _get_info(self, path, children=False, limit=None, **kwargs):
        """
        Request file or directory metadata from the API

        :param path: (str)
        :param children: (bool) if True, return metadata of the children paths
            as well
        :param limit: (int) if provided and children is True, set limit to the
            number of children returned
        :param kwargs: (dict) optional arguments passed on to requests
        :return (dict) path metadata
        """
        url = URL(self.api_url) / 'namespace' / _encode(path)
        url = url.with_query(children=children)
        if limit is not None and children:
            url = url.add_query(limit=f'{limit}')
        url = url.as_uri()
        kw = self.kwargs.copy()
        kw.update(kwargs)
        async with self.session.get(url, **kw) as r:
            if r.status == 404:
                raise FileNotFoundError(url)
            r.raise_for_status()
            return await r.json()

    async def _ls(self, path, detail=True, limit=None, **kwargs):
        """
        List path content.

        :param path: (str)
        :param detail: (bool) if True, return a list of dictionaries with the
            (children) path(s) info. If False, return a list of paths
        :param limit: (int) set the maximum number of children paths returned
            to this value
        :param kwargs: (dict) optional arguments passed on to requests
        :return list of dictionaries or list of str
        """
        path = self._strip_protocol(path)

        info = await self._get_info(path, children=True, limit=limit,
                                    **kwargs)
        details = _get_details(path, info)
        if details['type'] == 'directory':
            elements = info.get('children') or []
            details = [_get_details(path, el) for el in elements]
        else:
            details = [details]

        if detail:
            return details
        else:
            return [d.get('name') for d in details]

    ls = sync_wrapper(_ls)

    async def _cat_file(self, url, start=None, end=None, **kwargs):
        self.webdav_url = self._get_webdav_url(url) or self.webdav_url

        path = self._strip_protocol(url)
        url = URL(self.webdav_url) / path
        url = url.as_uri()
        kw = self.kwargs.copy()
        kw.update(kwargs)
        if (start is None) ^ (end is None):
            raise ValueError("Give start and end or neither")
        if start is not None:
            headers = kw.pop("headers", {}).copy()
            headers["Range"] = "bytes=%i-%i" % (start, end - 1)
            kw["headers"] = headers
        async with self.session.get(url, **kw) as r:
            if r.status == 404:
                raise FileNotFoundError(url)
            r.raise_for_status()
            out = await r.read()
        return out

    async def _get_file(self, rpath, lpath, chunk_size=5 * 2 ** 20, **kwargs):
        self.webdav_url = self._get_webdav_url(rpath) or self.webdav_url

        path = self._strip_protocol(rpath)
        url = URL(self.webdav_url) / path
        url = url.as_uri()
        kw = self.kwargs.copy()
        kw.update(kwargs)
        async with self.session.get(url, **kw) as r:
            if r.status == 404:
                raise FileNotFoundError(rpath)
            r.raise_for_status()
            with open(lpath, "wb") as fd:
                chunk = True
                while chunk:
                    chunk = await r.content.read(chunk_size)
                    fd.write(chunk)

    async def _put_file(self, lpath, rpath, **kwargs):
        self.webdav_url = self._get_webdav_url(rpath) or self.webdav_url

        path = self._strip_protocol(rpath)
        url = URL(self.webdav_url) / path
        url = url.as_uri()
        kw = self.kwargs.copy()
        kw.update(kwargs)
        with open(lpath, "rb") as fd:
            r = await self.session.put(url, data=fd, **kw)
            r.raise_for_status()

    async def _cp_file(self, path1, path2, **kwargs):
        raise NotImplementedError

    async def _pipe_file(self, path1, path2, **kwargs):
        raise NotImplementedError

    async def _mv(self, path1, path2, **kwargs):
        """
        Rename path1 to path2

        :param path1: (str) source path
        :param path2: (str) destination path
        :param kwargs: (dict) optional arguments passed on to requests
        """
        path1 = self._strip_protocol(path1)
        path2 = self._strip_protocol(path2)

        url = URL(self.api_url) / 'namespace' / _encode(path1)
        url = url.as_uri()
        data = dict(action='mv', destination=path2)
        kw = self.kwargs.copy()
        kw.update(kwargs)
        async with self.session.post(url, json=data, **kw) as r:
            if r.status == 404:
                raise FileNotFoundError(url)
            r.raise_for_status()
            return await r.json()

    mv = sync_wrapper(_mv)

    async def _rm_file(self, path, **kwargs):
        """
        Remove file or directory (must be empty)

        :param path: (str)
        """
        url = URL(self.api_url) / 'namespace' / _encode(path)
        url = url.as_uri()
        kw = self.kwargs.copy()
        kw.update(kwargs)
        async with self.session.delete(url, **kw) as r:
            if r.status == 404:
                raise FileNotFoundError(url)
            r.raise_for_status()

    async def _rm(self, path, recursive=False, **kwargs):
        """
        Asynchronous remove method. Elements are deleted from the leaves
        towards the root, awaiting each deletion before proceeding.
        """
        path = await self._expand_path(path, recursive=recursive)
        for p in reversed(path):
            await asyncio.gather(self._rm_file(p, **kwargs))

    rm = sync_wrapper(_rm)

    async def _info(self, path, **kwargs):
        """
        Give details about a file or a directory

        :param path: (str)
        :param kwargs: (dict) optional arguments passed on to requests
        :return (dict)
        """
        path = self._strip_protocol(path)
        info = await self._get_info(path, **kwargs)
        return _get_details(path, info)

    info = sync_wrapper(_info)

    def created(self, path):
        """
        Date and time in which the path was created

        :param path: (str)
        :return (datetime.datetime object)
        """
        return self.info(path).get('created')

    def modified(self, path):
        """
        Date and time in which the path was last modified

        :param path: (str)
        :return (datetime.datetime object)
        """
        return self.info(path).get('modified')

    def _open(
        self,
        path,
        mode="rb",
        block_size=None,
        cache_type="readahead",
        cache_options=None,
        **kwargs
    ):
        """Make a file-like object

        Parameters
        ----------
        path: str
            Full URL with protocol
        mode: string
            must be "rb"
        block_size: int or None
            Bytes to download in one request; use instance value if None. If
            zero, will return a streaming Requests file-like instance.
        kwargs: key-value
            Any other parameters, passed to requests calls
        """
        if mode not in {"rb", "wb"}:
            raise NotImplementedError
        kw = self.kwargs.copy()
        kw.update(kwargs)
        if block_size:
            return dCacheFile(
                self,
                path,
                mode=mode,
                block_size=block_size,
                cache_type=cache_type,
                cache_options=cache_options,
                asynchronous=self.asynchronous,
                session=self.session,
                loop=self.loop,
                **kw
            )
        else:
            return dCacheStreamFile(
                self,
                path,
                mode=mode,
                asynchronous=self.asynchronous,
                session=self.session,
                loop=self.loop,
                **kw
            )

    def open(
        self,
        path,
        mode="rb",
        block_size=None,
        cache_options=None,
        **kwargs
    ):
        """

        """
        self.webdav_url = self._get_webdav_url(path) or self.webdav_url
        block_size = self.block_size if block_size is None else block_size
        return super().open(
            path=path,
            mode=mode,
            block_size=block_size,
            cache_options=cache_options,
            **kwargs
        )
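
A hedged construction sketch for the filesystem above (all endpoints and credentials are placeholders; a real dCache deployment exposes its own REST API and WebDAV doors):

fs = dCacheFileSystem(
    api_url="https://dcache.example.org:3880/api/v1",  # REST API endpoint (placeholder)
    webdav_url="https://dcache.example.org:2880",      # WebDAV door (placeholder)
    token="MY_BEARER_TOKEN",                           # or username=/password= for basic auth
)
print(fs.ls("/data/example", detail=False))
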
class dCacheFile(HTTPFile):
    """
    A file-like object pointing to a remote HTTP(S) resource

    Supports reading, with read-ahead of a predetermined block-size, and
    buffered writing (the buffer is uploaded with a single PUT on close).

    In the case that the server does not supply the filesize, only reading of
    the complete file in one go is supported.

    Parameters
    ----------
    url: str
        Full URL of the remote resource, including the protocol
    session: requests.Session or None
        All calls will be made within this session, to avoid restarting
        connections where the server allows this
    block_size: int or None
        The amount of read-ahead to do, in bytes. Default is 5MB, or the value
        configured for the FileSystem creating this file
    size: None or int
        If given, this is the size of the file in bytes, and we don't attempt
        to call the server to find the value.
    kwargs: all other key-values are passed to requests calls.
    """

    def __init__(
        self,
        fs,
        url,
        mode="rb",
        block_size=None,
        cache_type="bytes",
        cache_options=None,
        asynchronous=False,
        session=None,
        loop=None,
        **kwargs
    ):
        path = fs._strip_protocol(url)
        url = URL(fs.webdav_url) / path
        self.url = url.as_uri()
        self.asynchronous = asynchronous
        self.session = session
        self.loop = loop
        if mode not in {"rb", "wb"}:
            raise ValueError("Mode must be 'rb' or 'wb'")
        super(HTTPFile, self).__init__(
            fs=fs,
            path=path,
            mode=mode,
            block_size=block_size,
            cache_type=cache_type,
            cache_options=cache_options,
            **kwargs
        )

    def flush(self, force=False):
        if self.closed:
            raise ValueError("Flush on closed file")
        if force and self.forced:
            raise ValueError("Force flush cannot be called more than once")
        if force:
            self.write_chunked()
            self.forced = True

    async def _write_chunked(self):
        self.buffer.seek(0)
        r = await self.session.put(self.url, data=self.buffer, **self.kwargs)
        r.raise_for_status()
        return False

    write_chunked = sync_wrapper(_write_chunked)

    def close(self):
        super(HTTPFile, self).close()
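
A hedged write sketch, continuing from the fs instance sketched earlier: in "wb" mode the buffer is uploaded with a single PUT when the file is closed (close() triggers the forced flush, which calls write_chunked):

with fs.open("/data/example/new_file.dat", mode="wb") as f:
    f.write(b"hello dCache")
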
Example #6
class HTTPFileSystem(AsyncFileSystem):
    """
    Simple File-System for fetching data via HTTP(S)

    ``ls()`` is implemented by loading the parent page and doing a regex
    match on the result. If simple_links=True, anything that looks like a URL
    ("http(s)://server.com/stuff?thing=other") is also considered a link;
    otherwise only links within HTML href tags are used.
    """

    sep = "/"

    def __init__(
        self,
        simple_links=True,
        block_size=None,
        same_scheme=True,
        size_policy=None,
        cache_type="bytes",
        cache_options=None,
        asynchronous=False,
        loop=None,
        client_kwargs=None,
        **storage_options,
    ):
        """
        NB: if constructed with ``asynchronous=True``, you must await ``set_session`` before use

        Parameters
        ----------
        block_size: int
            Bytes per read-ahead block; if 0, a streaming file-like object is
            returned instead of an HTTPFile instance
        simple_links: bool
            If True, will consider both HTML <a> tags and anything that looks
            like a URL; if False, will consider only the former.
        same_scheme: bool
            When doing ls/glob, if this is True, only consider paths that have
            http/https matching the input URLs.
        size_policy: this argument is deprecated
        client_kwargs: dict
            Passed to aiohttp.ClientSession, see
            https://docs.aiohttp.org/en/stable/client_reference.html
            For example, ``{'auth': aiohttp.BasicAuth('user', 'pass')}``
        storage_options: key-value
            Any other parameters passed on to requests
        cache_type, cache_options: defaults used in open
        """
        super().__init__(self,
                         asynchronous=asynchronous,
                         loop=loop,
                         **storage_options)
        self.block_size = block_size if block_size is not None else DEFAULT_BLOCK_SIZE
        self.simple_links = simple_links
        self.same_schema = same_scheme
        self.cache_type = cache_type
        self.cache_options = cache_options
        self.client_kwargs = client_kwargs or {}
        self.kwargs = storage_options
        self._session = None
        if not asynchronous:
            sync(self.loop, self.set_session)

    @staticmethod
    def close_session(loop, session):
        if loop is not None and loop.is_running():
            sync(loop, session.close)
        elif session._connector is not None:
            # close after loop is dead
            session._connector._close()

    async def set_session(self):
        if self._session is None:
            self._session = await get_client(loop=self.loop,
                                             **self.client_kwargs)
            if not self.asynchronous:
                weakref.finalize(self, self.close_session, self.loop,
                                 self._session)
        return self._session

    @classmethod
    def _strip_protocol(cls, path):
        """For HTTP, we always want to keep the full URL"""
        return path

    @classmethod
    def _parent(cls, path):
        # override, since _strip_protocol is different for URLs
        par = super()._parent(path)
        if len(par) > 7:  # "http://..."
            return par
        return ""

    async def _ls(self, url, detail=True, **kwargs):
        # ignoring URL-encoded arguments
        kw = self.kwargs.copy()
        kw.update(kwargs)
        logger.debug(url)
        session = await self.set_session()
        async with session.get(url, **kw) as r:
            r.raise_for_status()
            text = await r.text()
        if self.simple_links:
            links = ex2.findall(text) + [u[2] for u in ex.findall(text)]
        else:
            links = [u[2] for u in ex.findall(text)]
        out = set()
        parts = urlparse(url)
        for l in links:
            if isinstance(l, tuple):
                l = l[1]
            if l.startswith("/") and len(l) > 1:
                # absolute URL on this server
                l = parts.scheme + "://" + parts.netloc + l
            if l.startswith("http"):
                if self.same_schema and l.startswith(url.rstrip("/") + "/"):
                    out.add(l)
                elif l.replace("https", "http").startswith(
                        url.replace("https", "http").rstrip("/") + "/"):
                    # allowed to cross http <-> https
                    out.add(l)
            else:
                if l not in ["..", "../"]:
                    # Ignore FTP-like "parent"
                    out.add("/".join([url.rstrip("/"), l.lstrip("/")]))
        if not out and url.endswith("/"):
            out = await self._ls(url.rstrip("/"), detail=False)
        if detail:
            return [{
                "name": u,
                "size": None,
                "type": "directory" if u.endswith("/") else "file",
            } for u in out]
        else:
            return list(sorted(out))

    ls = sync_wrapper(_ls)

    async def _cat_file(self, url, start=None, end=None, **kwargs):
        kw = self.kwargs.copy()
        kw.update(kwargs)
        logger.debug(url)
        if (start is None) ^ (end is None):
            raise ValueError("Give start and end or neither")
        if start is not None:
            headers = kw.pop("headers", {}).copy()
            headers["Range"] = "bytes=%i-%i" % (start, end - 1)
            kw["headers"] = headers
        session = await self.set_session()
        async with session.get(url, **kw) as r:
            if r.status == 404:
                raise FileNotFoundError(url)
            r.raise_for_status()
            out = await r.read()
        return out

    async def _get_file(self, rpath, lpath, chunk_size=5 * 2**20, **kwargs):
        kw = self.kwargs.copy()
        kw.update(kwargs)
        logger.debug(rpath)
        session = await self.set_session()
        async with session.get(rpath, **kw) as r:
            if r.status == 404:
                raise FileNotFoundError(rpath)
            r.raise_for_status()
            with open(lpath, "wb") as fd:
                chunk = True
                while chunk:
                    chunk = await r.content.read(chunk_size)
                    fd.write(chunk)

    async def _exists(self, path, **kwargs):
        kw = self.kwargs.copy()
        kw.update(kwargs)
        try:
            logger.debug(path)
            session = await self.set_session()
            r = await session.get(path, **kw)
            async with r:
                return r.status < 400
        except (requests.HTTPError, aiohttp.ClientError):
            return False

    async def _isfile(self, path, **kwargs):
        return await self._exists(path, **kwargs)

    def _open(
        self,
        path,
        mode="rb",
        block_size=None,
        autocommit=None,  # XXX: This differs from the base class.
        cache_type=None,
        cache_options=None,
        size=None,
        **kwargs,
    ):
        """Make a file-like object

        Parameters
        ----------
        path: str
            Full URL with protocol
        mode: string
            must be "rb"
        block_size: int or None
            Bytes to download in one request; use instance value if None. If
            zero, will return a streaming Requests file-like instance.
        kwargs: key-value
            Any other parameters, passed to requests calls
        """
        if mode != "rb":
            raise NotImplementedError
        block_size = block_size if block_size is not None else self.block_size
        kw = self.kwargs.copy()
        kw["asynchronous"] = self.asynchronous
        kw.update(kwargs)
        size = size or self.size(path)
        session = sync(self.loop, self.set_session)
        if block_size and size:
            return HTTPFile(
                self,
                path,
                session=session,
                block_size=block_size,
                mode=mode,
                size=size,
                cache_type=cache_type or self.cache_type,
                cache_options=cache_options or self.cache_options,
                loop=self.loop,
                **kw,
            )
        else:
            return HTTPStreamFile(self,
                                  path,
                                  mode=mode,
                                  loop=self.loop,
                                  session=session,
                                  **kw)

    def ukey(self, url):
        """Unique identifier; assume HTTP files are static, unchanging"""
        return tokenize(url, self.kwargs, self.protocol)

    async def _info(self, url, **kwargs):
        """Get info of URL

        Tries to access location via HEAD, and then GET methods, but does
        not fetch the data.

        It is possible that the server does not supply any size information, in
        which case size will be given as None (and certain operations on the
        corresponding file will not work).
        """
        size = False
        for policy in ["head", "get"]:
            try:
                session = await self.set_session()
                size = await _file_size(url,
                                        size_policy=policy,
                                        session=session,
                                        **self.kwargs)
                if size:
                    break
            except Exception:
                pass
        else:
            # get failed, so conclude URL does not exist
            if size is False:
                raise FileNotFoundError(url)
        return {"name": url, "size": size or None, "type": "file"}

    async def _glob(self, path, **kwargs):
        """
        Find files by glob-matching.

        This implementation is identical to the one in AbstractFileSystem,
        but "?" is not considered as a character for globbing, because it is
        so common in URLs, often identifying the "query" part.
        """
        import re

        ends = path.endswith("/")
        path = self._strip_protocol(path)
        indstar = path.find("*") if path.find("*") >= 0 else len(path)
        indbrace = path.find("[") if path.find("[") >= 0 else len(path)

        ind = min(indstar, indbrace)

        detail = kwargs.pop("detail", False)

        if not has_magic(path):
            root = path
            depth = 1
            if ends:
                path += "/*"
            elif await self._exists(path):
                if not detail:
                    return [path]
                else:
                    return {path: await self._info(path)}
            else:
                if not detail:
                    return []  # glob of non-existent returns empty
                else:
                    return {}
        elif "/" in path[:ind]:
            ind2 = path[:ind].rindex("/")
            root = path[:ind2 + 1]
            depth = None if "**" in path else path[ind2 + 1:].count("/") + 1
        else:
            root = ""
            depth = None if "**" in path else path[ind + 1:].count("/") + 1

        allpaths = await self._find(root,
                                    maxdepth=depth,
                                    withdirs=True,
                                    detail=True,
                                    **kwargs)
        # Escape characters special to python regex, leaving our supported
        # special characters in place.
        # See https://www.gnu.org/software/bash/manual/html_node/Pattern-Matching.html
        # for shell globbing details.
        pattern = ("^" + (path.replace("\\", r"\\").replace(
            ".", r"\.").replace("+", r"\+").replace("//", "/").replace(
                "(", r"\(").replace(")", r"\)").replace("|", r"\|").replace(
                    "^", r"\^").replace("$", r"\$").replace(
                        "{", r"\{").replace("}", r"\}").rstrip("/")) + "$")
        pattern = re.sub("[*]{2}", "=PLACEHOLDER=", pattern)
        pattern = re.sub("[*]", "[^/]*", pattern)
        pattern = re.compile(pattern.replace("=PLACEHOLDER=", ".*"))
        out = {
            p: allpaths[p]
            for p in sorted(allpaths)
            if pattern.match(p.replace("//", "/").rstrip("/"))
        }
        if detail:
            return out
        else:
            return list(out)

    async def _isdir(self, path):
        # override, since all URLs are (also) files
        return bool(await self._ls(path))
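
A hedged usage sketch of HTTPFileSystem (URLs are placeholders; a non-zero block_size together with a known file size yields an HTTPFile doing ranged reads, otherwise a streaming file is returned):

fs = HTTPFileSystem()
with fs.open("https://example.com/large.bin", block_size=2**20) as f:
    header_bytes = f.read(1024)
file_list = fs.ls("https://example.com/files/", detail=False)
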
Example #7
class GCSFileSystem(AsyncFileSystem):
    r"""
    Connect to Google Cloud Storage.

    The following modes of authentication are supported:

    - ``token=None``, GCSFS will attempt to guess your credentials in the
      following order: gcloud CLI default, gcsfs cached token, google compute
      metadata service, anonymous.
    - ``token='google_default'``, your default gcloud credentials will be used,
      which are typically established by doing ``gcloud login`` in a terminal.
    - ``token='cache'``, credentials from previously successful gcsfs
      authentication will be used (use this after "browser" auth succeeded)
    - ``token='anon'``, no authentication is performed, and you can only
      access data which is accessible to allUsers (in this case, the project and
      access level parameters are meaningless)
    - ``token='browser'``, you get an access code with which you can
      authenticate via a specially provided URL
    - if ``token='cloud'``, we assume we are running within google compute
      or google container engine, and query the internal metadata directly for
      a token.
    - you may supply a token generated by the
      [gcloud](https://cloud.google.com/sdk/docs/)
      utility; this is either a python dictionary, the name of a file
      containing the JSON returned by logging in with the gcloud CLI tool,
      or a Credentials object. gcloud typically stores its tokens in locations
      such as
      ``~/.config/gcloud/application_default_credentials.json``,
      `` ~/.config/gcloud/credentials``, or
      ``~\AppData\Roaming\gcloud\credentials``, etc.

    Specific methods (e.g. `ls`, `info`, ...) may return object details from GCS.
    These detailed listings include the
    [object resource](https://cloud.google.com/storage/docs/json_api/v1/objects#resource)

    GCS *does not* include  "directory" objects but instead generates
    directories by splitting
    [object names](https://cloud.google.com/storage/docs/key-terms).
    This means that, for example,
    a directory does not need to exist for an object to be created within it.
    Creating an object implicitly creates its parent directories, and removing
    all objects from a directory implicitly deletes the empty directory.

    `GCSFileSystem` generates listing entries for these implied directories in
    listing APIs with the object properties:

        - "name" : string
            The "{bucket}/{name}" path of the dir, used in calls to
            GCSFileSystem or GCSFile.
        - "bucket" : string
            The name of the bucket containing this object.
        - "kind" : 'storage#object'
        - "size" : 0
        - "storageClass" : 'DIRECTORY'
        - type: 'directory' (fsspec compat)

    GCSFileSystem maintains a per-implied-directory cache of object listings and
    fulfills all object information and listing requests from cache. This implies,
    for example, that objects created via other processes *will not* be visible to
    the GCSFileSystem until the cache is refreshed. Calls to GCSFileSystem.open and
    calls to GCSFile are not affected by this cache.

    In the default case the cache is never expired. This may be controlled via the `cache_timeout`
    GCSFileSystem parameter or via explicit calls to `GCSFileSystem.invalidate_cache`.

    Parameters
    ----------
    project : string
        project_id to work under. Note that this is not the same as, but often
        very similar to, the project name.
        This is required in order
        to list all the buckets you have access to within a project and to
        create/delete buckets, or update their access policies.
        If ``token='google_default'``, the value is overridden by the default,
        if ``token='anon'``, the value is ignored.
    access : one of {'read_only', 'read_write', 'full_control'}
        Full control implies read/write as well as modifying metadata,
        e.g., access control.
    token: None, dict or string
        (see description of authentication methods, above)
    consistency: 'none', 'size', 'md5'
        Check method when writing files. Can be overridden in open().
    cache_timeout: float, seconds
        Cache expiration time in seconds for object metadata cache.
        Set cache_timeout <= 0 for no caching, None for no cache expiration.
    secure_serialize: bool (deprecated)
    check_connection: bool
        When token=None, gcsfs will attempt various methods of establishing
        credentials, falling back to anon. It is possible for a method to
        find credentials in the system that turn out not to be valid. Setting
        this parameter to True will ensure that an actual operation is
        attempted before deciding that credentials are valid.
    requester_pays : bool, or str default False
        Whether to use requester-pays requests. This will include your
        project ID `project` in requests as the `userProject`, and you'll be
        billed for accessing data from requester-pays buckets. Optionally,
        pass a project-id here as a string to use that as the `userProject`.
    session_kwargs: dict
        passed on to aiohttp.ClientSession; can contain, for example,
        proxy settings.
    """

    scopes = {"read_only", "read_write", "full_control"}
    retries = 6  # number of retries on http failure
    base = "https://storage.googleapis.com/storage/v1/"
    default_block_size = DEFAULT_BLOCK_SIZE
    protocol = "gcs", "gs"
    async_impl = True

    def __init__(
        self,
        project=DEFAULT_PROJECT,
        access="full_control",
        token=None,
        block_size=None,
        consistency="none",
        cache_timeout=None,
        secure_serialize=True,
        check_connection=False,
        requests_timeout=None,
        requester_pays=False,
        asynchronous=False,
        session_kwargs=None,
        loop=None,
        timeout=None,
        **kwargs,
    ):
        super().__init__(
            self,
            listings_expiry_time=cache_timeout,
            asynchronous=asynchronous,
            loop=loop,
            **kwargs,
        )
        if access not in self.scopes:
            raise ValueError("access must be one of {}", self.scopes)
        if project is None:
            warnings.warn(
                "GCS project not set - cannot list or create buckets")
        if block_size is not None:
            self.default_block_size = block_size
        self.requester_pays = requester_pays
        self.consistency = consistency
        self.cache_timeout = cache_timeout or kwargs.pop(
            "listings_expiry_time", None)
        self.requests_timeout = requests_timeout
        self.timeout = timeout
        self._session = None
        self.session_kwargs = session_kwargs or {}

        self.credentials = GoogleCredentials(project, access, token,
                                             check_connection)

        if not self.asynchronous:
            self._session = sync(self.loop,
                                 get_client,
                                 timeout=self.timeout,
                                 **self.session_kwargs)
            weakref.finalize(self, self.close_session, self.loop,
                             self._session)

    @property
    def project(self):
        return self.credentials.project

    @staticmethod
    def close_session(loop, session):
        if loop is not None and session is not None:
            if loop.is_running():
                try:
                    sync(loop, session.close, timeout=0.1)
                except fsspec.FSTimeoutError:
                    pass
            else:
                pass

    async def _set_session(self):
        if self._session is None:
            self._session = await get_client(**self.session_kwargs)
        return self._session

    @property
    def session(self):
        if self.asynchronous and self._session is None:
            raise RuntimeError("Please await _connect* before anything else")
        return self._session

    @classmethod
    def _strip_protocol(cls, path):
        if isinstance(path, list):
            return [cls._strip_protocol(p) for p in path]
        path = stringify_path(path)
        protos = (cls.protocol, ) if isinstance(cls.protocol,
                                                str) else cls.protocol
        for protocol in protos:
            if path.startswith(protocol + "://"):
                path = path[len(protocol) + 3:]
            elif path.startswith(protocol + "::"):
                path = path[len(protocol) + 2:]
        # use of root_marker to make minimum required path, e.g., "/"
        return path or cls.root_marker

    def _get_params(self, kwargs):
        params = {k: v for k, v in kwargs.items() if v is not None}
        # needed for requester pays buckets
        if self.requester_pays:
            if isinstance(self.requester_pays, str):
                user_project = self.requester_pays
            else:
                user_project = self.project
            params["userProject"] = user_project
        return params

    def _get_headers(self, headers):
        out = {}
        if headers is not None:
            out.update(headers)
        if "User-Agent" not in out:
            out["User-Agent"] = "python-gcsfs/" + version
        self.credentials.apply(out)
        return out

    def _format_path(self, path, args):
        if not path.startswith("http"):
            path = self.base + path

        if args:
            path = path.format(*[quote_plus(p) for p in args])
        return path

    @retry_request(retries=retries)
    async def _request(self,
                       method,
                       path,
                       *args,
                       headers=None,
                       json=None,
                       data=None,
                       **kwargs):
        await self._set_session()
        async with self.session.request(
                method=method,
                url=self._format_path(path, args),
                params=self._get_params(kwargs),
                json=json,
                headers=self._get_headers(headers),
                data=data,
                timeout=self.requests_timeout,
        ) as r:

            status = r.status
            headers = r.headers
            info = r.request_info  # for debug only
            contents = await r.read()

            validate_response(status, contents, path)
            return status, headers, info, contents

    async def _call(self,
                    method,
                    path,
                    *args,
                    json_out=False,
                    info_out=False,
                    **kwargs):
        logger.debug(
            f"{method.upper()}: {path}, {args}, {kwargs.get('headers')}")

        status, headers, info, contents = await self._request(
            method, path, *args, **kwargs)
        if json_out:
            return json.loads(contents)
        elif info_out:
            return info
        else:
            return headers, contents

    call = sync_wrapper(_call)

    @property
    def buckets(self):
        """Return list of available project buckets."""
        return [
            b["name"]
            for b in sync(self.loop, self._list_buckets, timeout=self.timeout)
        ]

    @staticmethod
    def _process_object(bucket, object_metadata):
        """Process object resource into gcsfs object information format.

        Process GCS object resource via type casting and attribute updates to
        the cache-able gcsfs object information format. Returns an updated copy
        of the object resource.

        (See https://cloud.google.com/storage/docs/json_api/v1/objects#resource)
        """
        result = dict(object_metadata)
        result["size"] = int(object_metadata.get("size", 0))
        result["name"] = posixpath.join(bucket, object_metadata["name"])
        result["type"] = "file"

        return result

    async def _get_object(self, path):
        """Return object information at the given path."""
        bucket, key = self.split_path(path)

        # Check if parent dir is in listing cache
        listing = self._ls_from_cache(path)
        if listing:
            for file_details in listing:
                if file_details["type"] == "file" and file_details[
                        "name"] == path:
                    return file_details
            else:
                raise FileNotFoundError(path)

        if not key:
            # Attempt to "get" the bucket root, return error instead of
            # listing.
            raise FileNotFoundError(path)

        res = None
        # Work around various permission settings. Prefer an object get (storage.objects.get), but
        # fall back to a bucket list + filter to object name (storage.objects.list).
        try:
            res = await self._call("GET",
                                   "b/{}/o/{}",
                                   bucket,
                                   key,
                                   json_out=True)
        except OSError as e:
            if not str(e).startswith("Forbidden"):
                raise
            resp = await self._call("GET",
                                    "b/{}/o/",
                                    bucket,
                                    json_out=True,
                                    prefix=key,
                                    maxResults=1)
            for item in resp.get("items", []):
                if item["name"] == key:
                    res = item
                    break
            if res is None:
                raise FileNotFoundError(path)
        return self._process_object(bucket, res)

    async def _list_objects(self, path, prefix=""):
        bucket, key = self.split_path(path)
        path = path.rstrip("/")

        try:
            clisting = self._ls_from_cache(path)
            hassubdirs = clisting and any(
                c["name"].rstrip("/") == path and c["type"] == "directory"
                for c in clisting)
            if clisting and not hassubdirs:
                return clisting
        except FileNotFoundError:
            # not finding a bucket in list of "my" buckets is OK
            if key:
                raise

        items, prefixes = await self._do_list_objects(path, prefix=prefix)

        pseudodirs = [{
            "bucket": bucket,
            "name": bucket + "/" + prefix.strip("/"),
            "size": 0,
            "storageClass": "DIRECTORY",
            "type": "directory",
        } for prefix in prefixes]
        if not (items + pseudodirs):
            if key:
                return [await self._get_object(path)]
            else:
                return []
        out = items + pseudodirs
        # Don't cache prefixed/partial listings
        if not prefix:
            self.dircache[path] = out
        return out

    async def _do_list_objects(self,
                               path,
                               max_results=None,
                               delimiter="/",
                               prefix=""):
        """Object listing for the given {bucket}/{prefix}/ path."""
        bucket, _path = self.split_path(path)
        _path = "" if not _path else _path.rstrip("/") + "/"
        prefix = f"{_path}{prefix}" or None

        prefixes = []
        items = []
        page = await self._call(
            "GET",
            "b/{}/o/",
            bucket,
            delimiter=delimiter,
            prefix=prefix,
            maxResults=max_results,
            json_out=True,
        )

        prefixes.extend(page.get("prefixes", []))
        items.extend(page.get("items", []))
        next_page_token = page.get("nextPageToken", None)

        while next_page_token is not None:
            page = await self._call(
                "GET",
                "b/{}/o/",
                bucket,
                delimiter=delimiter,
                prefix=prefix,
                maxResults=max_results,
                pageToken=next_page_token,
                json_out=True,
            )

            assert page["kind"] == "storage#objects"
            prefixes.extend(page.get("prefixes", []))
            items.extend(page.get("items", []))
            next_page_token = page.get("nextPageToken", None)

        items = [self._process_object(bucket, i) for i in items]
        return items, prefixes

    async def _list_buckets(self):
        """Return list of all buckets under the current project."""
        if "" not in self.dircache:
            items = []
            page = await self._call("GET",
                                    "b/",
                                    project=self.project,
                                    json_out=True)

            assert page["kind"] == "storage#buckets"
            items.extend(page.get("items", []))
            next_page_token = page.get("nextPageToken", None)

            while next_page_token is not None:
                page = await self._call(
                    "GET",
                    "b/",
                    project=self.project,
                    pageToken=next_page_token,
                    json_out=True,
                )

                assert page["kind"] == "storage#buckets"
                items.extend(page.get("items", []))
                next_page_token = page.get("nextPageToken", None)

            self.dircache[""] = [{
                "name": i["name"] + "/",
                "size": 0,
                "type": "directory"
            } for i in items]
        return self.dircache[""]

    def invalidate_cache(self, path=None):
        """
        Invalidate listing cache for given path, it is reloaded on next use.

        Parameters
        ----------
        path: string or None
            If None, clear all listings cached else listings at or under given
            path.
        """
        if path is None:
            logger.debug("invalidate_cache clearing cache")
            self.dircache.clear()
        else:
            path = self._strip_protocol(path).rstrip("/")

            while path:
                self.dircache.pop(path, None)
                path = self._parent(path)

    async def _mkdir(self,
                     bucket,
                     acl="projectPrivate",
                     default_acl="bucketOwnerFullControl"):
        """
        New bucket

        Parameters
        ----------
        bucket: str
            bucket name. If contains '/' (i.e., looks like subdir), will
            have no effect because GCS doesn't have real directories.
        acl: string, one of bACLs
            access for the bucket itself
        default_acl: str, one of ACLs
            default ACL for objects created in this bucket
        """
        if bucket in ["", "/"]:
            raise ValueError("Cannot create root bucket")
        if "/" in bucket:
            return
        await self._call(
            method="POST",
            path="b/",
            predefinedAcl=acl,
            project=self.project,
            predefinedDefaultObjectAcl=default_acl,
            json={"name": bucket},
            json_out=True,
        )
        self.invalidate_cache(bucket)

    mkdir = sync_wrapper(_mkdir)

    async def _rmdir(self, bucket):
        """Delete an empty bucket

        Parameters
        ----------
        bucket: str
            bucket name. If contains '/' (i.e., looks like subdir), will
            have no effect because GCS doesn't have real directories.
        """
        bucket = bucket.rstrip("/")
        if "/" in bucket:
            return
        await self._call("DELETE", "b/" + bucket, json_out=False)
        self.invalidate_cache(bucket)
        self.invalidate_cache("")

    rmdir = sync_wrapper(_rmdir)

    async def _info(self, path, **kwargs):
        """File information about this path."""
        path = self._strip_protocol(path).rstrip("/")
        # Check directory cache for parent dir
        parent_path = self._parent(path)
        parent_cache = self._ls_from_cache(parent_path)
        bucket, key = self.split_path(path)
        if parent_cache:
            for o in parent_cache:
                if o["name"].rstrip("/") == path:
                    return o
        if self._ls_from_cache(path):
            # this is a directory
            return {
                "bucket": bucket,
                "name": path.rstrip("/"),
                "size": 0,
                "storageClass": "DIRECTORY",
                "type": "directory",
            }
        # Check exact file path
        try:
            return await self._get_object(path)
        except FileNotFoundError:
            pass
        kwargs["detail"] = True  # Force to true for info
        out = await self._ls(path, **kwargs)
        out0 = [o for o in out if o["name"].rstrip("/") == path]
        if out0:
            # exact hit
            return out0[0]
        elif out:
            # other stuff - must be a directory
            return {
                "bucket": bucket,
                "name": path.rstrip("/"),
                "size": 0,
                "storageClass": "DIRECTORY",
                "type": "directory",
            }
        else:
            raise FileNotFoundError(path)

    async def _glob(self, path, prefix="", **kwargs):
        if not prefix:
            # Identify pattern prefixes. Ripped from fsspec.spec.AbstractFileSystem.glob and matches
            # the glob.has_magic patterns.
            indstar = path.find("*") if path.find("*") >= 0 else len(path)
            indques = path.find("?") if path.find("?") >= 0 else len(path)
            indbrace = path.find("[") if path.find("[") >= 0 else len(path)

            ind = min(indstar, indques, indbrace)
            prefix = path[:ind].split("/")[-1]
        return await super()._glob(path, prefix=prefix, **kwargs)

    async def _ls(self, path, detail=False, prefix="", **kwargs):
        """List objects under the given '/{bucket}/{prefix} path."""
        path = self._strip_protocol(path).rstrip("/")

        if path in ["/", ""]:
            out = await self._list_buckets()
        else:
            out = await self._list_objects(path, prefix=prefix)

        if detail:
            return out
        else:
            return sorted([o["name"] for o in out])

    @classmethod
    def url(cls, path):
        """ Get HTTP URL of the given path """
        u = "https://storage.googleapis.com/download/storage/v1/b/{}/o/{}?alt=media"
        bucket, obj = cls.split_path(path)
        obj = quote_plus(obj)
        return u.format(bucket, obj)
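
    # For illustration (hypothetical path): url("mybucket/path/to/obj") gives
    # "https://storage.googleapis.com/download/storage/v1/b/mybucket/o/path%2Fto%2Fobj?alt=media",
    # since quote_plus() percent-encodes the '/' characters of the object key.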

    async def _cat_file(self, path, start=None, end=None):
        """ Simple one-shot get of file data """
        u2 = self.url(path)
        if start or end:
            head = {"Range": await self._process_limits(path, start, end)}
        else:
            head = {}
        headers, out = await self._call("GET", u2, headers=head)
        return out

    async def _getxattr(self, path, attr):
        """Get user-defined metadata attribute"""
        meta = (await self._info(path)).get("metadata", {})
        return meta[attr]

    getxattr = sync_wrapper(_getxattr)

    async def _setxattrs(self,
                         path,
                         content_type=None,
                         content_encoding=None,
                         **kwargs):
        """Set/delete/add writable metadata attributes

        Parameters
        ----------
        content_type: str
            If not None, set the content-type to this value
        content_encoding: str
            If not None, set the content-encoding.
            See https://cloud.google.com/storage/docs/transcoding
        kwargs: key-value pairs like field="value" or field=None
            value must be a string to add or modify, or None to delete

        Returns
        -------
        Entire metadata after update (even if only path is passed)
        """
        i_json = {"metadata": kwargs}
        if content_type is not None:
            i_json["contentType"] = content_type
        if content_encoding is not None:
            i_json["contentEncoding"] = content_encoding

        bucket, key = self.split_path(path)
        o_json = await self._call(
            "PATCH",
            "b/{}/o/{}",
            bucket,
            key,
            fields="metadata",
            json=i_json,
            json_out=True,
        )
        (await self._info(path))["metadata"] = o_json.get("metadata", {})
        return o_json.get("metadata", {})

    setxattrs = sync_wrapper(_setxattrs)
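
    # Usage sketch (path and fields are hypothetical):
    #     fs.setxattrs("bucket/key", owner="alice")   # add/modify user metadata
    #     fs.getxattr("bucket/key", "owner")          # -> "alice"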

    async def _merge(self, path, paths, acl=None):
        """Concatenate objects within a single bucket"""
        bucket, key = self.split_path(path)
        source = [{"name": self.split_path(p)[1]} for p in paths]
        await self._call(
            "POST",
            "b/{}/o/{}/compose",
            bucket,
            key,
            destinationPredefinedAcl=acl,
            headers={"Content-Type": "application/json"},
            json={
                "sourceObjects": source,
                "kind": "storage#composeRequest",
                "destination": {
                    "name": key,
                    "bucket": bucket
                },
            },
        )

    merge = sync_wrapper(_merge)
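
    # Usage sketch (hypothetical paths): compose several objects of the same
    # bucket into one destination object:
    #     fs.merge("bucket/combined.bin", ["bucket/part-0", "bucket/part-1"])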

    async def _cp_file(self, path1, path2, acl=None, **kwargs):
        """Duplicate remote file"""
        b1, k1 = self.split_path(path1)
        b2, k2 = self.split_path(path2)
        out = await self._call(
            "POST",
            "b/{}/o/{}/rewriteTo/b/{}/o/{}",
            b1,
            k1,
            b2,
            k2,
            headers={"Content-Type": "application/json"},
            destinationPredefinedAcl=acl,
            json_out=True,
        )
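        # A single rewrite call may not finish for large objects; resubmit with
        # the returned rewriteToken until the response reports done=True.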
        while out["done"] is not True:
            out = await self._call(
                "POST",
                "b/{}/o/{}/rewriteTo/b/{}/o/{}",
                b1,
                k1,
                b2,
                k2,
                headers={"Content-Type": "application/json"},
                rewriteToken=out["rewriteToken"],
                destinationPredefinedAcl=acl,
                json_out=True,
            )

    async def _rm_file(self, path):
        bucket, key = self.split_path(path)
        if key:
            await self._call("DELETE", "b/{}/o/{}", bucket, key)
            self.invalidate_cache(posixpath.dirname(
                self._strip_protocol(path)))
            return True
        else:
            await self._rmdir(path)

    async def _rm_files(self, paths):
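        # Build one multipart/mixed batch request in which each part embeds a
        # single HTTP DELETE for one object (GCS JSON API batch endpoint).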
        template = ("\n--===============7330845974216740156==\n"
                    "Content-Type: application/http\n"
                    "Content-Transfer-Encoding: binary\n"
                    "Content-ID: <b29c5de2-0db4-490b-b421-6a51b598bd11+{i}>"
                    "\n\nDELETE /storage/v1/b/{bucket}/o/{key} HTTP/1.1\n"
                    "Content-Type: application/json\n"
                    "accept: application/json\ncontent-length: 0\n")
        body = "".join([
            template.format(
                i=i + 1,
                bucket=p.split("/", 1)[0],
                key=quote_plus(p.split("/", 1)[1]),
            ) for i, p in enumerate(paths)
        ])
        headers, content = await self._call(
            "POST",
            "https://storage.googleapis.com/batch/storage/v1",
            headers={
                "Content-Type":
                'multipart/mixed; boundary="=========='
                '=====7330845974216740156=="'
            },
            data=body + "\n--===============7330845974216740156==--",
        )

        boundary = headers["Content-Type"].split("=", 1)[1]
        parents = [self._parent(p) for p in paths]
        for parent in parents + list(paths):
            self.invalidate_cache(parent)
        txt = content.decode()
        if any(not ("200 OK" in c or "204 No Content" in c)
               for c in txt.split(boundary)[1:-1]):
            pattern = '"message": "([^"]+)"'
            out = set(re.findall(pattern, txt))
            raise OSError(out)

    async def _rm(self, path, recursive=False, maxdepth=None, batchsize=20):
        paths = await self._expand_path(path,
                                        recursive=recursive,
                                        maxdepth=maxdepth)
        files = [p for p in paths if self.split_path(p)[1]]
        dirs = [p for p in paths if not self.split_path(p)[1]]
        exs = await asyncio.gather(
            *([
                self._rm_files(files[i:i + batchsize])
                for i in range(0, len(files), batchsize)
            ]),
            return_exceptions=True,
        )
        exs = [
            ex for ex in exs
            if ex is not None and "No such object" not in str(ex)
        ]
        if exs:
            raise exs[0]
        await asyncio.gather(*[self._rmdir(d) for d in dirs])

    rm = sync_wrapper(_rm)

    async def _pipe_file(
        self,
        path,
        data,
        metadata=None,
        consistency=None,
        content_type="application/octet-stream",
        chunksize=50 * 2**20,
    ):
        # chunksize should be a multiple of 2**18 (256 KiB), as required for GCS resumable uploads
        consistency = consistency or self.consistency
        bucket, key = self.split_path(path)
        size = len(data)
        out = None
        if size < 5 * 2**20:
            return await simple_upload(self, bucket, key, data, metadata,
                                       consistency, content_type)
        else:
            location = await initiate_upload(self, bucket, key, content_type,
                                             metadata)
            for offset in range(0, len(data), chunksize):
                bit = data[offset:offset + chunksize]
                out = await upload_chunk(self, location, bit, offset, size,
                                         content_type)

        checker = get_consistency_checker(consistency)
        checker.update(data)
        checker.validate_json_response(out)
        self.invalidate_cache(self._parent(path))
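
    # Usage sketch (hypothetical path): fs.pipe_file("bucket/key.bin", payload)
    # does a single "simple" upload for payloads under 5 MB and otherwise
    # starts a resumable session, sending the data in `chunksize` pieces.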

    async def _put_file(
        self,
        lpath,
        rpath,
        metadata=None,
        consistency=None,
        content_type="application/octet-stream",
        chunksize=50 * 2**20,
        **kwargs,
    ):
        # chunksize should be a multiple of 2**18 (256 KiB), as required for GCS resumable uploads
        if os.path.isdir(lpath):
            return
        consistency = consistency or self.consistency
        checker = get_consistency_checker(consistency)
        bucket, key = self.split_path(rpath)
        with open(lpath, "rb") as f0:
            size = f0.seek(0, 2)
            f0.seek(0)
            if size < 5 * 2**20:
                return await simple_upload(
                    self,
                    bucket,
                    key,
                    f0.read(),
                    consistency=consistency,
                    metadatain=metadata,
                    content_type=content_type,
                )
            else:
                location = await initiate_upload(self, bucket, key,
                                                 content_type, metadata)
                offset = 0
                while True:
                    bit = f0.read(chunksize)
                    if not bit:
                        break
                    out = await upload_chunk(self, location, bit, offset, size,
                                             content_type)
                    offset += len(bit)
                    checker.update(bit)

            checker.validate_json_response(out)
            self.invalidate_cache(self._parent(rpath))

    async def _isdir(self, path):
        try:
            return (await self._info(path))["type"] == "directory"
        except IOError:
            return False

    async def _find(self,
                    path,
                    withdirs=False,
                    detail=False,
                    prefix="",
                    **kwargs):
        path = self._strip_protocol(path)
        bucket, key = self.split_path(path)
        out, _ = await self._do_list_objects(
            path,
            delimiter=None,
            prefix=prefix,
        )
        if not prefix and not out and key:
            try:
                out = [await self._get_object(path)]
            except FileNotFoundError:
                out = []
        dirs = []
        sdirs = set()
        cache_entries = {}
        for o in out:
            par = o["name"]
            while par:
                par = self._parent(par)
                if par not in sdirs:
                    if len(par) < len(path):
                        break
                    sdirs.add(par)
                    dirs.append({
                        "Key": self.split_path(par)[1],
                        "Size": 0,
                        "name": par,
                        "StorageClass": "DIRECTORY",
                        "type": "directory",
                        "size": 0,
                    })
                # Don't cache "folder-like" objects (ex: "Create Folder" in GCS console) to prevent
                # masking subfiles in subsequent requests.
                if not o["name"].endswith("/"):
                    cache_entries.setdefault(par, []).append(o)
        self.dircache.update(cache_entries)

        if withdirs:
            out = sorted(out + dirs, key=lambda x: x["name"])

        if detail:
            return {o["name"]: o for o in out}
        return [o["name"] for o in out]

    @retry_request(retries=retries)
    async def _get_file_request(self,
                                rpath,
                                lpath,
                                *args,
                                headers=None,
                                **kwargs):
        consistency = kwargs.pop("consistency", self.consistency)

        async with self.session.get(
                url=rpath,
                params=self._get_params(kwargs),
                headers=self._get_headers(headers),
                timeout=self.requests_timeout,
        ) as r:
            r.raise_for_status()
            checker = get_consistency_checker(consistency)

            os.makedirs(os.path.dirname(lpath), exist_ok=True)
            with open(lpath, "wb") as f2:
                while True:
                    data = await r.content.read(4096 * 32)
                    if not data:
                        break
                    f2.write(data)
                    checker.update(data)

            validate_response(r.status, data, rpath)  # validate http request
            checker.validate_http_response(r)  # validate file consistency
            return r.status, r.headers, r.request_info, data

    async def _get_file(self, rpath, lpath, **kwargs):
        if await self._isdir(rpath):
            return
        u2 = self.url(rpath)
        await self._get_file_request(u2, lpath, **kwargs)

    def _open(
        self,
        path,
        mode="rb",
        block_size=None,
        cache_options=None,
        acl=None,
        consistency=None,
        metadata=None,
        autocommit=True,
        **kwargs,
    ):
        """
        See ``GCSFile``.

        consistency: None or str
            If None, use default for this instance
        """
        if block_size is None:
            block_size = self.default_block_size
        const = consistency or self.consistency
        return GCSFile(
            self,
            path,
            mode,
            block_size,
            cache_options=cache_options,
            consistency=const,
            metadata=metadata,
            acl=acl,
            autocommit=autocommit,
            **kwargs,
        )

    @classmethod
    def split_path(cls, path):
        """
        Normalise GCS path string into bucket and key.

        Parameters
        ----------
        path : string
            Input path, like `gcs://mybucket/path/to/file`.
            Path is of the form: '[gs|gcs://]bucket[/key]'

        Returns
        -------
            (bucket, key) tuple
        """
        path = cls._strip_protocol(path).lstrip("/")
        if "/" not in path:
            return path, ""
        else:
            return path.split("/", 1)
Example #8
0
class HTTPFileSystem(AsyncFileSystem):
    """
    Simple File-System for fetching data via HTTP(S)

    ``ls()`` is implemented by loading the parent page and doing a regex
    match on the result. If simple_links=True, anything that looks like a
    URL (e.g., "http(s)://server.com/stuff?thing=other") is also considered;
    otherwise only links within HTML href tags will be used.
    """

    sep = "/"

    def __init__(
        self,
        simple_links=True,
        block_size=None,
        same_scheme=True,
        size_policy=None,
        cache_type="bytes",
        cache_options=None,
        asynchronous=False,
        loop=None,
        client_kwargs=None,
        get_client=get_client,
        **storage_options,
    ):
        """
        NB: if this is called async, you must await set_session

        Parameters
        ----------
        block_size: int
            Block size, in bytes, for read-ahead; if 0, a streaming file-like
            object will be returned instead of an HTTPFile instance
        simple_links: bool
            If True, will consider both HTML <a> tags and anything that looks
            like a URL; if False, will consider only the former.
        same_scheme: bool
            When doing ls/glob, if this is True, only consider paths that have
            http/https matching the input URLs.
        size_policy: this argument is deprecated
        client_kwargs: dict
            Passed to aiohttp.ClientSession, see
            https://docs.aiohttp.org/en/stable/client_reference.html
            For example, ``{'auth': aiohttp.BasicAuth('user', 'pass')}``
        get_client: Callable[..., aiohttp.ClientSession]
            A callable which takes keyword arguments and constructs
            an aiohttp.ClientSession. Its state will be managed by
            the HTTPFileSystem class.
        storage_options: key-value
            Any other parameters passed on to requests
        cache_type, cache_options: defaults used in open
        """
        super().__init__(self,
                         asynchronous=asynchronous,
                         loop=loop,
                         **storage_options)
        self.block_size = block_size if block_size is not None else DEFAULT_BLOCK_SIZE
        self.simple_links = simple_links
        self.same_schema = same_scheme
        self.cache_type = cache_type
        self.cache_options = cache_options
        self.client_kwargs = client_kwargs or {}
        self.get_client = get_client
        self.kwargs = storage_options
        self._session = None

        # Clean caching-related parameters from `storage_options`
        # before propagating them as `request_options` through `self.kwargs`.
        # TODO: Maybe rename `self.kwargs` to `self.request_options` to make
        #       it clearer.
        request_options = copy(storage_options)
        self.use_listings_cache = request_options.pop("use_listings_cache",
                                                      False)
        request_options.pop("listings_expiry_time", None)
        request_options.pop("max_paths", None)
        request_options.pop("skip_instance_cache", None)
        self.kwargs = request_options

        if not asynchronous:
            sync(self.loop, self.set_session)

    @staticmethod
    def close_session(loop, session):
        if loop is not None and loop.is_running():
            try:
                sync(loop, session.close, timeout=0.1)
                return
            except (TimeoutError, FSTimeoutError):
                pass
        connector = getattr(session, "_connector", None)
        if connector is not None:
            # close after loop is dead
            connector._close()

    async def set_session(self):
        if self._session is None:
            self._session = await self.get_client(loop=self.loop,
                                                  **self.client_kwargs)
            if not self.asynchronous:
                weakref.finalize(self, self.close_session, self.loop,
                                 self._session)
        return self._session

    @classmethod
    def _strip_protocol(cls, path):
        """For HTTP, we always want to keep the full URL"""
        return path

    @classmethod
    def _parent(cls, path):
        # override, since _strip_protocol is different for URLs
        par = super()._parent(path)
        if len(par) > 7:  # "http://..."
            return par
        return ""

    async def _ls_real(self, url, detail=True, **kwargs):
        # ignoring URL-encoded arguments
        kw = self.kwargs.copy()
        kw.update(kwargs)
        logger.debug(url)
        session = await self.set_session()
        async with session.get(url, **kw) as r:
            self._raise_not_found_for_status(r, url)
            text = await r.text()
        if self.simple_links:
            links = ex2.findall(text) + [u[2] for u in ex.findall(text)]
        else:
            links = [u[2] for u in ex.findall(text)]
        out = set()
        parts = urlparse(url)
        for l in links:
            if isinstance(l, tuple):
                l = l[1]
            if l.startswith("/") and len(l) > 1:
                # absolute URL on this server
                l = parts.scheme + "://" + parts.netloc + l
            if l.startswith("http"):
                if self.same_schema and l.startswith(url.rstrip("/") + "/"):
                    out.add(l)
                elif l.replace("https", "http").startswith(
                        url.replace("https", "http").rstrip("/") + "/"):
                    # allowed to cross http <-> https
                    out.add(l)
            else:
                if l not in ["..", "../"]:
                    # Ignore FTP-like "parent"
                    out.add("/".join([url.rstrip("/"), l.lstrip("/")]))
        if not out and url.endswith("/"):
            out = await self._ls_real(url.rstrip("/"), detail=False)
        if detail:
            return [{
                "name": u,
                "size": None,
                "type": "directory" if u.endswith("/") else "file",
            } for u in out]
        else:
            return list(sorted(out))

    async def _ls(self, url, detail=True, **kwargs):

        if self.use_listings_cache and url in self.dircache:
            out = self.dircache[url]
        else:
            out = await self._ls_real(url, detail=detail, **kwargs)
            self.dircache[url] = out
        return out

    ls = sync_wrapper(_ls)

    def _raise_not_found_for_status(self, response, url):
        """
        Raises FileNotFoundError for 404s, otherwise uses raise_for_status.
        """
        if response.status == 404:
            raise FileNotFoundError(url)
        response.raise_for_status()

    async def _cat_file(self, url, start=None, end=None, **kwargs):
        kw = self.kwargs.copy()
        kw.update(kwargs)
        logger.debug(url)

        # TODO: extract into testable utility function?
        if start is not None or end is not None:
            if start == end:
                return b""
            headers = kw.pop("headers", {}).copy()

            headers["Range"] = await self._process_limits(url, start, end)
            kw["headers"] = headers
        session = await self.set_session()
        async with session.get(url, **kw) as r:
            out = await r.read()
            self._raise_not_found_for_status(r, url)
        return out

    async def _get_file(self,
                        rpath,
                        lpath,
                        chunk_size=5 * 2**20,
                        callback=_DEFAULT_CALLBACK,
                        **kwargs):
        kw = self.kwargs.copy()
        kw.update(kwargs)
        logger.debug(rpath)
        session = await self.set_session()
        async with session.get(rpath, **kw) as r:
            try:
                size = int(r.headers["content-length"])
            except (ValueError, KeyError):
                size = None

            callback.set_size(size)
            self._raise_not_found_for_status(r, rpath)
            with open(lpath, "wb") as fd:
                chunk = True
                while chunk:
                    chunk = await r.content.read(chunk_size)
                    fd.write(chunk)
                    callback.relative_update(len(chunk))

    async def _put_file(
        self,
        lpath,
        rpath,
        chunk_size=5 * 2**20,
        callback=_DEFAULT_CALLBACK,
        method="post",
        **kwargs,
    ):
        async def gen_chunks():
            # Support passing arbitrary file-like objects
            # and use them instead of streams.
            if isinstance(lpath, io.IOBase):
                context = nullcontext(lpath)
                use_seek = False  # might not support seeking
            else:
                context = open(lpath, "rb")
                use_seek = True

            with context as f:
                if use_seek:
                    callback.set_size(f.seek(0, 2))
                    f.seek(0)
                else:
                    callback.set_size(getattr(f, "size", None))

                chunk = f.read(64 * 1024)
                while chunk:
                    yield chunk
                    callback.relative_update(len(chunk))
                    chunk = f.read(64 * 1024)

        kw = self.kwargs.copy()
        kw.update(kwargs)
        session = await self.set_session()

        method = method.lower()
        if method not in ("post", "put"):
            raise ValueError(
                f"method has to be either 'post' or 'put', not: {method!r}")

        meth = getattr(session, method)
        async with meth(rpath, data=gen_chunks(), **kw) as resp:
            self._raise_not_found_for_status(resp, rpath)

    async def _exists(self, path, **kwargs):
        kw = self.kwargs.copy()
        kw.update(kwargs)
        try:
            logger.debug(path)
            session = await self.set_session()
            r = await session.get(path, **kw)
            async with r:
                return r.status < 400
        except (requests.HTTPError, aiohttp.ClientError):
            return False

    async def _isfile(self, path, **kwargs):
        return await self._exists(path, **kwargs)

    def _open(
        self,
        path,
        mode="rb",
        block_size=None,
        autocommit=None,  # XXX: This differs from the base class.
        cache_type=None,
        cache_options=None,
        size=None,
        **kwargs,
    ):
        """Make a file-like object

        Parameters
        ----------
        path: str
            Full URL with protocol
        mode: string
            must be "rb"
        block_size: int or None
            Bytes to download in one request; use instance value if None. If
            zero, will return a streaming file-like instance (HTTPStreamFile).
        kwargs: key-value
            Any other parameters, passed to requests calls
        """
        if mode != "rb":
            raise NotImplementedError
        block_size = block_size if block_size is not None else self.block_size
        kw = self.kwargs.copy()
        kw["asynchronous"] = self.asynchronous
        kw.update(kwargs)
        size = size or self.info(path, **kwargs)["size"]
        session = sync(self.loop, self.set_session)
        if block_size and size:
            return HTTPFile(
                self,
                path,
                session=session,
                block_size=block_size,
                mode=mode,
                size=size,
                cache_type=cache_type or self.cache_type,
                cache_options=cache_options or self.cache_options,
                loop=self.loop,
                **kw,
            )
        else:
            return HTTPStreamFile(self,
                                  path,
                                  mode=mode,
                                  loop=self.loop,
                                  session=session,
                                  **kw)

    def ukey(self, url):
        """Unique identifier; assume HTTP files are static, unchanging"""
        return tokenize(url, self.kwargs, self.protocol)

    async def _info(self, url, **kwargs):
        """Get info of URL

        Tries to access location via HEAD, and then GET methods, but does
        not fetch the data.

        It is possible that the server does not supply any size information, in
        which case size will be given as None (and certain operations on the
        corresponding file will not work).
        """
        info = {}
        session = await self.set_session()

        for policy in ["head", "get"]:
            try:
                info.update(await _file_info(
                    url,
                    size_policy=policy,
                    session=session,
                    **self.kwargs,
                    **kwargs,
                ))
                if info.get("size") is not None:
                    break
            except Exception as exc:
                if policy == "get":
                    # If get failed, then raise a FileNotFoundError
                    raise FileNotFoundError(url) from exc
                logger.debug(str(exc))

        return {"name": url, "size": None, **info, "type": "file"}

    async def _glob(self, path, **kwargs):
        """
        Find files by glob-matching.

        This implementation is identical to the one in AbstractFileSystem,
        but "?" is not considered as a character for globbing, because it is
        so common in URLs, often identifying the "query" part.
        """
        import re

        ends = path.endswith("/")
        path = self._strip_protocol(path)
        indstar = path.find("*") if path.find("*") >= 0 else len(path)
        indbrace = path.find("[") if path.find("[") >= 0 else len(path)

        ind = min(indstar, indbrace)

        detail = kwargs.pop("detail", False)

        if not has_magic(path):
            root = path
            depth = 1
            if ends:
                path += "/*"
            elif await self._exists(path):
                if not detail:
                    return [path]
                else:
                    return {path: await self._info(path)}
            else:
                if not detail:
                    return []  # glob of non-existent returns empty
                else:
                    return {}
        elif "/" in path[:ind]:
            ind2 = path[:ind].rindex("/")
            root = path[:ind2 + 1]
            depth = None if "**" in path else path[ind2 + 1:].count("/") + 1
        else:
            root = ""
            depth = None if "**" in path else path[ind + 1:].count("/") + 1

        allpaths = await self._find(root,
                                    maxdepth=depth,
                                    withdirs=True,
                                    detail=True,
                                    **kwargs)
        # Escape characters special to python regex, leaving our supported
        # special characters in place.
        # See https://www.gnu.org/software/bash/manual/html_node/Pattern-Matching.html
        # for shell globbing details.
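        # Informally: "*" becomes the regex "[^/]*", "**" becomes ".*", and
        # double slashes are collapsed on both the pattern and the candidate
        # paths before matching.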
        pattern = ("^" + (path.replace("\\", r"\\").replace(
            ".", r"\.").replace("+", r"\+").replace("//", "/").replace(
                "(", r"\(").replace(")", r"\)").replace("|", r"\|").replace(
                    "^", r"\^").replace("$", r"\$").replace(
                        "{", r"\{").replace("}", r"\}").rstrip("/")) + "$")
        pattern = re.sub("[*]{2}", "=PLACEHOLDER=", pattern)
        pattern = re.sub("[*]", "[^/]*", pattern)
        pattern = re.compile(pattern.replace("=PLACEHOLDER=", ".*"))
        out = {
            p: allpaths[p]
            for p in sorted(allpaths)
            if pattern.match(p.replace("//", "/").rstrip("/"))
        }
        if detail:
            return out
        else:
            return list(out)

    async def _isdir(self, path):
        # override, since all URLs are (also) files
        try:
            return bool(await self._ls(path))
        except (FileNotFoundError, ValueError):
            return False
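
A short usage sketch for the HTTP filesystem above. It is illustrative only:
it assumes the class is registered under the "http"/"https" protocols as in
fsspec, and that the URLs are placeholders.

import fsspec

fs = fsspec.filesystem("http")                              # HTTPFileSystem
data = fs.cat_file("https://example.com/data.bin")          # one-shot GET
meta = fs.info("https://example.com/data.bin")              # HEAD, falling back to GET
with fs.open("https://example.com/data.bin", "rb") as f:    # HTTPFile with read-ahead
    header = f.read(1024)
listing = fs.ls("https://example.com/files/")               # links scraped from the page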