示例#1
0
def run_app(app, port, loop):
    """Serve *app* on ``HOST``:*port* and block until the loop stops.

    A KeyboardInterrupt (Ctrl-C) ends serving quietly.  ``runner.cleanup()``
    runs on every exit path, including a failure to start the site (for
    example when *port* is already in use), and the shutdown time is logged.
    """
    runner = AppRunner(app, access_log=None)
    loop.run_until_complete(runner.setup())

    try:
        # The site is started inside the try so that a failed bind still
        # reaches runner.cleanup() in the finally clause below.
        site = TCPSite(runner, HOST, port, shutdown_timeout=0.01)
        loop.run_until_complete(site.start())
        loop.run_forever()
    except KeyboardInterrupt:  # pragma: no branch
        pass
    finally:
        logger.info('shutting down server...')
        start = loop.time()
        # A timeout or a second Ctrl-C during cleanup must not mask the
        # original exit path.
        with contextlib.suppress(asyncio.TimeoutError, KeyboardInterrupt):
            loop.run_until_complete(runner.cleanup())
        logger.debug('shutdown took %0.2fs', loop.time() - start)
示例#2
0
class WebIF(object):
    def __init__(self, loop, config):
        """Build the web application described by *config* and start it.

        Reads the Fernet ``key`` and ``basedir`` from ``[global]``, creates a
        data directory for every ``[dir:<name>]`` section, registers all
        routes and starts the HTTP and/or HTTPS sites on *loop*.

        Raises:
            NoOptionError: if ``[global] key`` is missing (a freshly generated
                key is logged as a suggestion first).
        """
        self.loop = loop
        self.config = config
        if config.has_option("global", "key"):
            self.fernet = Fernet(config.get("global", "key"))
        else:
            log(__name__).error(
                "You need to add a key to section [global], e.g.\n"
                "key: %s",
                Fernet.generate_key().decode("ascii"))
            raise NoOptionError("key", "global")

        # create data dirs, if needed
        basedir = config.get("global", "basedir")
        for section in config.sections():
            if section.startswith("dir:"):
                dir_name = os.path.join(basedir, section[4:])
                os.makedirs(dir_name, exist_ok=True)

        self.app = web.Application()
        router = self.app.router
        router.add_route('GET', "/favicon.ico", self.favicon)
        router.add_route('POST', "/login", self.login)
        router.add_route('GET', "/.well-known/acme-challenge/{token}",
                         self.serve_letsencrypt)
        router.add_route('GET',
                         "/.well-known/acme-challenge/upload/{token}/{thumb}",
                         self.upload_letsencrypt)
        router.add_route('GET', "/privacy/", self.privacy)
        router.add_route('GET', "/{token}/auth/privacy/{path:.*}",
                         self.privacy)
        router.add_route('GET', "/{token}/auth/{path:.*}", self.handle)
        router.add_route('POST', "/{token}/auth/{path:.*}", self.handle)
        # Catch-all routes last; 'handle-get' is the name redirect() uses.
        router.add_route('GET', "/{path:.*}", self.handle, name='handle-get')
        router.add_route('POST', "/{path:.*}", self.handle, name='handle-post')

        # Everything request handlers read must exist before the first site
        # listens: while run_until_complete() below starts the HTTPS site,
        # the already-started HTTP site can serve requests (e.g. ACME
        # challenges answered from letsencrypt_data).
        self.letsencrypt_data = dict()
        self.cert_watcher = None
        self.http_site = None
        self.https_site = None

        self.runner = AppRunner(self.app,
                                access_log_format=config.get(
                                    "logging", "access_log_format"))
        self.loop.run_until_complete(self.runner.setup())

        loop.run_until_complete(self.init_site("http_site"))
        loop.run_until_complete(self.init_site("https_site"))

    async def init_site(self, site_name):
        """(Re)start the listening site stored in attribute *site_name*.

        *site_name* is ``"https_site"`` (port ``[global] https_port``) or any
        other name for plain HTTP (port ``[global] http_port``).  A running
        site of that name is stopped first.  If the port option is absent,
        the site stays stopped.  For HTTPS the certificate chain is loaded
        *before* the old site is stopped, so a certificate that fails to load
        leaves the current site serving.  The certificate watcher task is
        started on first HTTPS use.
        """
        config = self.config
        is_https = site_name == "https_site"
        port_option = "https_port" if is_https else "http_port"

        ssl_context = None
        if is_https and config.has_option("global", port_option):
            # NOTE(review): PROTOCOL_TLSv1_2 is deprecated since Python 3.10
            # and excludes TLS 1.3; consider PROTOCOL_TLS_SERVER with
            # minimum_version = TLSVersion.TLSv1_2 once Python >= 3.7 is
            # guaranteed.
            ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
            ssl_context.load_cert_chain(
                certfile=config.get("global", "certfile"),
                keyfile=config.get("global", "keyfile"))

        old_site = getattr(self, site_name, None)
        if old_site:
            await old_site.stop()
            setattr(self, site_name, None)

        if not config.has_option("global", port_option):
            return
        port = config.getint("global", port_option)

        if is_https and self.cert_watcher is None:
            self.cert_watcher = self.loop.create_task(self.watch_cert())

        log(__name__).info("Starting fresh site %s on port %d", site_name,
                           port)

        site = TCPSite(self.runner,
                       config.get("global", "host"),
                       port,
                       ssl_context=ssl_context,
                       reuse_address=True)
        await site.start()
        # Remember the site only once it is running, so a failed start never
        # leaves a never-started site to be stop()ped on the next reload.
        setattr(self, site_name, site)

    def close(self):
        """Stop the certificate watcher, then clean up the runner and sites.

        Runs the event loop to completion for each step, so it must be called
        while the loop is not running.
        """
        watcher = self.cert_watcher
        if watcher:
            # Cancel first so the watcher cannot restart the HTTPS site while
            # the runner is being cleaned up.  Then let the loop process the
            # cancellation, so the task does not stay pending.
            watcher.cancel()
            try:
                self.loop.run_until_complete(watcher)
            except asyncio.CancelledError:
                pass
            except Exception:
                # The watcher had already died; surface why instead of
                # losing the error.
                log(__name__).warning("Certificate watcher had failed",
                                      exc_info=True)
        self.loop.run_until_complete(self.runner.cleanup())

    async def watch_cert(self):
        """Poll ``[global] certfile`` every 5 s and reload HTTPS when it changes.

        Loops until the task is cancelled.  A reload that fails is logged and
        watching continues, so a later corrected certificate is still picked
        up.
        """
        def get_mtime(fname, old_mtime):
            # Keep the previous timestamp on stat errors so a transient
            # failure does not count as a change.
            try:
                return os.stat(fname).st_mtime
            except OSError:
                log(__name__).warning("Error in stat", exc_info=True)
                return old_mtime

        fname = self.config.get("global", "certfile")
        ts = get_mtime(fname, None)
        while True:
            await asyncio.sleep(5)
            new_ts = get_mtime(fname, ts)
            if new_ts == ts:
                continue
            ts = new_ts
            try:
                await self.init_site("https_site")
            except asyncio.CancelledError:
                # Cancellation (e.g. from close()) must still stop the task.
                raise
            except Exception:
                log(__name__).error("Failed to update ssl context",
                                    exc_info=True)
            else:
                log(__name__).info("Updated ssl context")

    def peername(self, request):
        """Return the client's peer address for log messages.

        Falls back to ``("???", 0)`` when the address is unknown, including
        when the request has no transport any more.
        """
        # request.transport can be None once the connection has been closed;
        # logging a departed client must not raise AttributeError.
        transport = request.transport
        peername = None
        if transport is not None:
            peername = transport.get_extra_info('peername')
        return peername if peername else ("???", 0)

    def redirect(self, request, user, path, msg):
        """Abort the current request with a 302 to *path* as seen by *user*.

        The location is the ``handle-get`` route for *path* (``None`` means
        the root), prefixed with ``build_prefix(user, path)``.  A non-None
        *msg* is attached as the ``msg`` query parameter.  Always raises
        ``web.HTTPFound``; never returns.
        """
        resource = request.app.router['handle-get']
        rel_path = "" if path is None else path
        target = self.build_prefix(user, rel_path) + rel_path
        location = resource.url_for(path=target.lstrip("/"))
        raise web.HTTPFound(
            location if msg is None else location.with_query({"msg": msg}))

    async def favicon(self, request):
        """Serve the icon stored base64-encoded in ``[global] favicon``."""
        encoded_icon = self.config.get("global", "favicon")
        icon_bytes = base64.b64decode(encoded_icon)
        return web.Response(content_type="image/vnd.microsoft.icon",
                            body=icon_bytes)

    async def serve_letsencrypt(self, request):
        """Answer an ACME challenge request for a previously uploaded token.

        Returns the stored ``"<token>.<thumb>"`` text with status 200, or a
        403 "permission denied" response for unknown tokens.
        """
        token = request.match_info["token"]
        log(__name__).info("%s - - cert verification request %s",
                           self.peername(request), token)
        try:
            answer = self.letsencrypt_data[token]
        except KeyError:
            return web.Response(text="permission denied",
                                status=403,
                                reason="permission denied")
        return web.Response(text=answer, status=200)

    async def upload_letsencrypt(self, request):
        """Record the answer for an ACME challenge token.

        Stores ``"<token>.<thumb>"`` under *token*, so that
        serve_letsencrypt() can return it, and replies "ok".
        """
        # NOTE(review): this GET endpoint stores data without any
        # authentication, so any client can register challenge answers (and
        # grow letsencrypt_data without bound).  Consider restricting it,
        # e.g. to loopback peers; confirm how the ACME client calls it.
        params = request.match_info
        token = params["token"]
        thumb = params["thumb"]
        log(__name__).info("%s - - cert verification file %s.%s",
                           self.peername(request), token, thumb)
        self.letsencrypt_data[token] = "%s.%s" % (token, thumb)
        return web.Response(text="ok", status=200)

    async def get_user(self, request):
        """Get username either from token, basic auth header, or form keys.

        Returns:
            ``(user, password, cred)``.  When the supplied password equals
            ``user_perms[user]["creds"]``, *cred* is base64 of
            ``"user:password"``.  Otherwise the result is
            ``(UNAUTH, None, "")``.
        """
        user, password = await self.get_user_from_request(request)

        # NOTE(review): a user without a "creds" entry matches a None
        # password, which is also what get_user_from_request returns for
        # anonymous requests; confirm that this is intended.
        expected = self.config.user_perms.get(user, dict()).get("creds", None)
        if expected == password:
            log(__name__).info("%s - - User %s active", self.peername(request),
                               user)
            # UTF-8, not ASCII: form and Basic-auth credentials can contain
            # non-ASCII characters, and encoding them as ASCII would raise
            # UnicodeEncodeError after a successful match.
            cred = ("%s:%s" % (user, password)).encode("utf-8")
            cred = base64.b64encode(cred).decode("ascii")
            return user, password, cred

        log(__name__).warning("%s - - Anonymous active %s:%s",
                              self.peername(request), user, password)
        return UNAUTH, None, ""

    async def get_user_from_request(self, request):
        """Extract the claimed ``(user, password)`` pair from *request*.

        Sources, in order of precedence:

        1. form fields ``user`` / ``password`` of an urlencoded POST body;
        2. a Fernet ``token`` in the URL, accepted only for the ``path`` it
           was issued for; the password is then looked up in user_perms;
        3. an HTTP Basic ``Authorization`` header.

        Any parsing or decryption failure is logged and treated as anonymous.
        ``request.formdata`` is always set (an empty dict unless an urlencoded
        body was parsed).

        Returns:
            A 2-tuple; ``(UNAUTH, None)`` when nothing usable was supplied.
        """
        formdata = dict()
        try:
            if request.content_type == "application/x-www-form-urlencoded":
                formdata = await request.post()

            # form keys first
            user = formdata.get("user", None)
            password = formdata.get("password", None)
            if user is not None and password is not None:
                return user, password

            if "token" in request.match_info:
                token = request.match_info["token"].encode("ascii")
                user, path = self.fernet.decrypt(token).decode("utf-8").split(
                    ":", 1)
                if path == request.match_info["path"]:
                    return user, self.config.user_perms.get(user, dict()).get(
                        "creds", None)

            # basic-auth last
            auth = request.headers.get("authorization", "")
            if auth.lower().startswith("basic "):
                decoded = base64.b64decode(auth[6:]).decode("utf8")
                # partition() always yields three parts.  A value without ":"
                # is treated as anonymous instead of returning a 1-element
                # list that breaks the caller's 2-tuple unpacking.
                user, sep, password = decoded.partition(":")
                if sep:
                    return user, password

        except Exception:
            log(__name__).info("Error", exc_info=True)
        finally:
            # Set on every path (also when post() failed) because later
            # handlers such as make_dir read request.formdata.
            request.formdata = formdata
        return UNAUTH, None

    def build_prefix(self, user, path):
        """Return the URL prefix under which *user* reaches *path*.

        Anonymous users get ``"/"``.  Everyone else gets ``"/<token>/auth/"``,
        where *token* is the Fernet encryption of ``"user:path"``.
        """
        if user != UNAUTH:
            plaintext = ("%s:%s" % (user, path)).encode("utf-8")
            token = self.fernet.encrypt(plaintext).decode("ascii")
            return "/%s/auth/" % token
        return "/"

    @redirect_on_exception
    async def login(self, request):
        """Evaluate the submitted credentials and redirect to the root.

        A recognised user is redirected with ``"<user> signed in"``.
        Everyone else is redirected as anonymous with ``"signed out"``.
        """
        who, _password, _cred = await self.get_user(request)
        status_msg = "signed out" if who == UNAUTH else "%s signed in" % who
        self.redirect(request, who, "", status_msg)

    @redirect_on_exception
    async def handle(self, request):
        """Serve a listing, a download, or a file-management action.

        The ``?action=`` query parameter selects the operation:

        * directory, no action or ``list``: render the listing;
        * directory, ``upload`` / ``mkdir``: delegate to upload_file/make_dir;
        * file, ``delete``: delegate to delete_file;
        * file, no action: stream the file.

        Each operation is gated by has_access().

        Raises:
            AccessError: 401 when the user lacks the required permission, 400
                when the action is unknown or does not fit the target.
        """
        path = request.match_info.get("path", "")
        fullname, dirname, fname = self.normalize_path(path)

        action = request.query.get("action", None)
        user, password, cred = await self.get_user(request)

        if fname is None and action in [None, "list"]:
            if not (has_access(user, dirname, Access.LIST)
                    or has_access(user, dirname, Access.UPLOAD)
                    or has_access(user, dirname, Access.MKDIR)):
                raise AccessError(HTTPStatus.UNAUTHORIZED, "forbidden")

            subdirs, files = list_dir(fullname)
            filter_file_list(user, dirname, subdirs, files)
            msg = request.query.get("msg", None)
            # NOTE(review): jinja2.Template does not autoescape by default,
            # and `msg` comes straight from the query string.  Confirm that
            # TEMPLATE escapes statusmessage (or enable autoescape) to rule
            # out reflected XSS.
            template = jinja2.Template(TEMPLATE)
            resp = web.Response(
                status=200,
                content_type='text/html',
                text=template.render(
                    dirname=dirname,
                    parentdir=None
                    if not dirname else os.path.split(dirname)[0],
                    files=[{
                        "name": f,
                        "path": os.path.join(dirname, f)
                    } for f in files],
                    subdirs=[{
                        "name": s,
                        "path": os.path.join(dirname, s)
                    } for s in subdirs],
                    statusmessage=msg,
                    prefix="/",
                    allow_upload=has_access(user, dirname, Access.UPLOAD),
                    allow_delete=has_access(user, dirname, Access.DELETE),
                    allow_fetch=has_access(user, dirname, Access.FETCH),
                    allow_mkdir=has_access(user, dirname, Access.MKDIR),
                    debug=self.config.getboolean("global", "debug"),
                    user=user,
                    bp=self.build_prefix))

        elif fname is None and action == "upload":
            if has_access(user, dirname, Access.UPLOAD):
                resp = await self.upload_file(request, user, dirname)
            else:
                raise AccessError(401, "forbidden")

        elif fname and action == "delete":
            if has_access(user, dirname, Access.DELETE):
                resp = await self.delete_file(request, user, dirname, fname)
            else:
                raise AccessError(401, "forbidden")

        elif fname is None and action == "mkdir":
            if has_access(user, dirname, Access.MKDIR):
                resp = await self.make_dir(request, user, dirname)
            else:
                raise AccessError(401, "forbidden")

        elif fname and action is None:
            if has_access(user, dirname, Access.FETCH):
                log(__name__).info("%s - - Download %s by %s",
                                   self.peername(request), fullname, user)
                resp = await self.streamfile(request, fullname)
            else:
                raise AccessError(401, "forbidden")

        else:
            # Unknown action, or one that does not fit the target (e.g.
            # ?action=upload on a file).  Without this branch `resp` would be
            # unbound below.
            raise AccessError(400, "invalid")

        resp.headers["Accept-Ranges"] = "bytes"
        return resp

    async def make_dir(self, request, user, dirname):
        """Create the directory named by form field ``dirname`` in *dirname*.

        The requested name is passed through text.get_valid_filename().  A
        missing or empty result, ``"."`` or ``".."`` is rejected with an
        "invalid directory name" redirect.  Otherwise the directory is
        created below basedir and the user is redirected back to *dirname*.
        Never returns normally: self.redirect() always raises.
        """
        filename = text.get_valid_filename(
            request.formdata.get("dirname", ""))
        # "." / ".." would resolve to an existing directory (for an empty
        # dirname, ".." is basedir's parent) and report a bogus "created".
        if filename in ("", ".", ".."):
            self.redirect(request, user, dirname, "invalid directory name")
        name = os.path.join(dirname, filename)
        fullname = os.path.join(self.config.get("global", "basedir"), name)
        os.makedirs(fullname, exist_ok=True)
        log(__name__).info("%s - - Directory %s created by %s",
                           self.peername(request), fullname, user)
        self.redirect(request, user, dirname, "Directory %s created" % name)

    async def upload_file(self, request, user, directory):
        """Save the multipart field ``file`` into *directory*, then redirect.

        The client-supplied name is reduced to its basename and passed
        through text.get_valid_filename().  Never returns normally: on
        success (and for an empty file name) self.redirect() raises.

        Raises:
            AccessError: 400 for a non-multipart/non-POST request or when the
                first part is not ``file``; 403 if the target already exists.
        """
        if request.method != "POST" or request.content_type != "multipart/form-data":
            raise AccessError(400, "invalid")

        reader = await request.multipart()
        field = await reader.next()
        # next() yields None when the body contains no parts at all.
        if field is None or field.name != "file":
            raise AccessError(400, "invalid")
        fname = text.get_valid_filename(os.path.basename(field.filename or ""))
        if not fname:
            self.redirect(request, user, directory, "invalid filename")
        name = os.path.join(directory, fname)
        fullname = os.path.join(self.config.get("global", "basedir"), name)

        # Exclusive creation ("x") fails atomically if the name is taken,
        # closing the race between an existence check and the open().
        try:
            f = open(fullname, 'xb')
        except FileExistsError:
            raise AccessError(403, "duplicate")
        try:
            with f:
                while True:
                    chunk = await field.read_chunk()  # 8192 bytes by default.
                    if not chunk:
                        break
                    f.write(chunk)
        except BaseException:
            # Remove the partial file so a retry is not rejected as a
            # duplicate.  Removal is best-effort; the original error wins.
            try:
                os.remove(fullname)
            except OSError:
                pass
            raise
        log(__name__).info("%s - - File %s uploaded by %s",
                           self.peername(request), name, user)
        size = os.lstat(fullname).st_size
        self.redirect(request, user, directory,
                      "%d bytes saved as %s" % (size, fname))

    async def delete_file(self, request, user, dirname, filename):
        """Delete *filename* in *dirname* (relative to basedir) and redirect.

        Never returns normally: self.redirect() always raises HTTPFound, so
        no statement after it could run.
        """
        name = os.path.join(dirname, filename)
        fullname = os.path.join(self.config.get("global", "basedir"), name)
        os.remove(fullname)
        log(__name__).info("%s - - File %s deleted by %s",
                           self.peername(request), fullname, user)
        self.redirect(request, user, dirname, "File %s deleted" % name)

    async def streamfile(self, request, fullname):
        """Stream the file *fullname* to the client, honouring a Range header.

        When ``request.http_range.start`` is None, the whole file is sent with
        status 200.  Otherwise the requested byte window is sent with status
        206 and a matching Content-Range.  A range starting at or beyond the
        end of the file gets 416 with ``Content-Range: bytes */<size>``.  A
        client disconnect (ConnectionResetError) ends the transfer quietly.

        Returns:
            The response, already prepared and finished when streamed.
        """
        rng = request.http_range
        size = os.path.getsize(fullname)

        if rng.start is None:
            status, offset, length = 200, 0, size
        else:
            # aiohttp's http_range already turns the inclusive header end
            # into an exclusive slice stop.  Newer versions use a negative
            # start for suffix ranges ("bytes=-N").  slice.indices() resolves
            # both against the file size and clamps the stop to EOF.
            offset, stop, _ = rng.indices(size)
            if offset >= size:
                return web.Response(status=416,
                                    headers={
                                        'Content-Range': "bytes */%d" % size,
                                        'Accept-Ranges': 'bytes'
                                    })
            status, length = 206, stop - offset

        # No explicit reason: aiohttp derives it from the status, so a 206
        # is no longer labelled "OK".
        resp = web.StreamResponse(status=status,
                                  headers={
                                      'Content-Type': get_mime_type(fullname),
                                      'Accept-Ranges': 'bytes'
                                  })
        if status == 206:
            resp.headers["Content-Range"] = "bytes {0}-{1}/{2}".format(
                offset, offset + length - 1, size)
        resp.content_length = length

        try:
            with open(fullname, 'rb') as f:
                f.seek(offset)
                await resp.prepare(request)
                remaining = length
                while remaining > 0:
                    buf = f.read(min(8192, remaining))
                    if not buf:
                        # File shrank while streaming; stop at what exists.
                        break
                    remaining -= len(buf)
                    await resp.write(buf)

                await resp.write_eof()
        except ConnectionResetError:
            # client went away
            pass

        return resp

    async def privacy(self, request):
        """Redirect back to *path* with the ``[global] gdprmsg`` text as message.

        Never returns normally: self.redirect() always raises.
        """
        path = request.match_info.get("path", "")
        # Diagnostic output goes through the logger, not the server's stdout.
        log(__name__).debug("Privacy notice requested for path %r", path)
        user, password, cred = await self.get_user(request)
        self.redirect(request, user, path,
                      self.config.get("global", "gdprmsg"))

    def normalize_path(self, path):
        """Resolve the request *path* against ``[global] basedir``.

        Returns:
            ``(name, relpath, None)`` for a directory, or
            ``(name, dirname, fname)`` for a regular file.  *name* is the
            absolute filesystem path; *relpath* / *dirname* are relative to
            basedir.

        Raises:
            AccessError: 403 if *path* resolves outside basedir, 404 if it
                does not exist, 401 if stat() fails otherwise or the entry is
                neither a directory nor a regular file.
        """
        # Normalise basedir exactly like the requested path.  A relative or
        # slash-terminated basedir then still passes the containment check,
        # and the prefix slice below removes the right number of characters.
        basedir = os.path.abspath(self.config.get("global", "basedir"))
        name = os.path.abspath(os.path.join(basedir, path))

        # join(basedir, "") gives basedir with exactly one trailing
        # separator, also when basedir is the filesystem root.
        if name != basedir and not name.startswith(os.path.join(basedir, "")):
            log(__name__).warning("Outside base: %s", name)
            raise AccessError(403, "forbidden")
        path = name[len(basedir):].strip(os.path.sep)

        try:
            sr = os.stat(name)
        except FileNotFoundError:
            log(__name__).warning("File not found: %s", name)
            raise AccessError(404, "not found")
        except Exception:
            log(__name__).error("Stat error on %s", name)
            raise AccessError(401, "forbidden")

        # Only a warning: files owned by another user are still served.
        if sr.st_uid != os.geteuid():
            log(__name__).warning("Wrong owner found for %s", name)

        if stat.S_ISDIR(sr.st_mode):
            return name, path, None

        if stat.S_ISREG(sr.st_mode):
            dirname, fname = os.path.split(path)
            return name, dirname, fname

        raise AccessError(401, "forbidden")