Esempio n. 1
0
    def __init__(
        self,
        cfg: ServerConfig,
        handler: ServerHandler,
        *,
        ssl_context: Optional[SSLContext] = None,
    ) -> None:
        """Attach to *handler* and prepare the aiohttp app and runner.

        Listening parameters are copied from *cfg* onto the instance. No
        sites are created here; ``self.sites`` starts out empty.

        Args:
            cfg: Server configuration supplying host, port, shutdown timeout,
                backlog and the address/port reuse flags.
            handler: Request handler; it is given this server first.
            ssl_context: Optional TLS context, stored as ``self.ssl_context``.
        """
        # The handler learns about this server before any attribute is set.
        handler._set_server(self)
        self.cfg = cfg
        self.handler = handler

        # Network settings mirrored from the configuration object.
        self.host, self.port = cfg.host, cfg.port
        self.shutdown_timeout = cfg.shutdown_timeout
        self.ssl_context = ssl_context
        self.backlog = cfg.backlog
        self.reuse_address, self.reuse_port = cfg.reuse_address, cfg.reuse_port

        self.sites: List[BaseSite] = []

        app = web.Application()
        self.web_app = app
        self.runner = AppRunner(
            app,
            handle_signals=True,
            access_log=access_logger,
            access_log_format=AccessLogger.LOG_FORMAT,
            access_log_class=AccessLogger,
        )
        # Every request passes through the instance's wrapper middleware.
        app.middlewares.append(self._req_wrapper)
Esempio n. 2
0
async def async_main(loop: asyncio.AbstractEventLoop) -> None:
    """Walk through the Jira OAuth token exchange on the console.

    Starts a local callback server (``/`` handled by
    ``jira_oauth.process_oauth_result``), obtains a request token, prints
    the authorization URL, awaits ``generate_access_token()``, prints both
    token pairs and finally checks the access token. The callback server is
    always shut down before returning, including on errors.

    Args:
        loop: Event loop handed to the ``JiraOAuth`` instance.
    """
    callback_port = 8080
    jira_oauth = JiraOAuth.from_file()

    jira_oauth.app = Application()
    jira_oauth.app.add_routes([web.get('/', jira_oauth.process_oauth_result)])
    runner = AppRunner(app=jira_oauth.app)
    await runner.setup()
    try:
        # No host given: listen on all interfaces. The redirect URL below is
        # derived from the same port so the two cannot drift apart.
        site = TCPSite(runner=runner, port=callback_port)
        await site.start()

        jira_oauth.redirect_url = f'http://localhost:{callback_port}/'
        jira_oauth.loop = loop

        await jira_oauth.generate_request_token_and_auth_url()
        jira_oauth_console = JiraOAuthConsole(jira_oauth=jira_oauth)
        print(
            f"Request Token: oauth_token={jira_oauth.request_token['oauth_token']}, "
            f"oauth_token_secret={jira_oauth.request_token['oauth_token_secret']}")
        print()
        jira_oauth_console.print_url()
        await jira_oauth.generate_access_token()
        print()
        print(
            f"Access Token: oauth_token={jira_oauth.access_token['oauth_token']}, "
            f"oauth_token_secret={jira_oauth.access_token['oauth_token_secret']}")
        print(
            "You may now access protected resources using the access tokens above."
        )
        print()
        await jira_oauth_console.check_access_token()
    finally:
        # Release the callback listener; previously it stayed bound after the
        # flow finished or failed.
        await runner.cleanup()
Esempio n. 3
0
 async def run_server(self):
     """Serve ``self.serve_image`` at ``/`` on 0.0.0.0:8080.

     Returns as soon as the site is listening. No handle to the runner is
     kept, so nothing in this method stops the server later.
     """
     image_route = web.get('/', self.serve_image)
     app = web.Application()
     app.add_routes([image_route])
     runner = AppRunner(app)
     await runner.setup()
     await TCPSite(runner, "0.0.0.0", 8080).start()
Esempio n. 4
0
    async def start(self):
        """Serve the built files in ``self.config.dist_dir`` on ``HOST``.

        The runner is kept as ``self.runner``. Access logging is disabled and
        the site gets a near-zero shutdown timeout.
        """
        static_root = str(self.config.dist_dir)
        app, *_rest = serve_static(static_path=static_root, port=self.port)

        runner = AppRunner(app, access_log=None)
        self.runner = runner
        await runner.setup()

        await TCPSite(runner, HOST, self.port, shutdown_timeout=0.01).start()
Esempio n. 5
0
    def __init__(self, loop, config):
        """Validate config, create data dirs, register routes and start sites.

        Args:
            loop: Event loop used to run the asynchronous setup synchronously.
            config: Parsed configuration. ``[global] key`` is required;
                ``[global] basedir`` and ``[logging] access_log_format`` are
                also read here.

        Raises:
            NoOptionError: If ``[global] key`` is missing. A freshly generated
                key is logged as a suggestion first.
        """
        self.loop = loop
        self.config = config
        if not config.has_option("global", "key"):
            log(__name__).error(
                "You need to add a key to section [global], e.g.\n"
                "key: %s",
                Fernet.generate_key().decode("ascii"))
            raise NoOptionError("key", "global")
        self.fernet = Fernet(config.get("global", "key"))

        # Ensure every "[dir:NAME]" section has a directory below basedir.
        basedir = config.get("global", "basedir")
        dir_sections = (s for s in config.sections() if s.startswith("dir:"))
        for section in dir_sections:
            os.makedirs(os.path.join(basedir, section[4:]), exist_ok=True)

        self.app = web.Application()
        # (method, pattern, handler, route name). Routes are matched in
        # registration order, so the catch-all patterns stay last; only they
        # carry names.
        route_table = (
            ('GET', "/favicon.ico", self.favicon, None),
            ('POST', "/login", self.login, None),
            ('GET', "/.well-known/acme-challenge/{token}",
             self.serve_letsencrypt, None),
            ('GET', "/.well-known/acme-challenge/upload/{token}/{thumb}",
             self.upload_letsencrypt, None),
            ('GET', "/privacy/", self.privacy, None),
            ('GET', "/{token}/auth/privacy/{path:.*}", self.privacy, None),
            ('GET', "/{token}/auth/{path:.*}", self.handle, None),
            ('POST', "/{token}/auth/{path:.*}", self.handle, None),
            ('GET', "/{path:.*}", self.handle, 'handle-get'),
            ('POST', "/{path:.*}", self.handle, 'handle-post'),
        )
        for method, pattern, route_handler, route_name in route_table:
            self.app.router.add_route(method, pattern, route_handler,
                                      name=route_name)

        self.runner = AppRunner(
            self.app,
            access_log_format=config.get("logging", "access_log_format"))
        self.loop.run_until_complete(self.runner.setup())
        self.cert_watcher = None

        for site_name in ("http_site", "https_site"):
            loop.run_until_complete(self.init_site(site_name))

        self.letsencrypt_data = {}
Esempio n. 6
0
def run_app(app, port, loop):
    """Serve *app* on ``HOST:port`` and block until the loop is interrupted.

    A ``KeyboardInterrupt`` ends the loop quietly. Afterwards the runner is
    cleaned up (timeouts and a second Ctrl-C are ignored) and the shutdown
    duration is logged.
    """
    runner = AppRunner(app, access_log=None)
    loop.run_until_complete(runner.setup())
    loop.run_until_complete(
        TCPSite(runner, HOST, port, shutdown_timeout=0.01).start())

    try:
        with contextlib.suppress(KeyboardInterrupt):  # pragma: no branch
            loop.run_forever()
    finally:
        logger.info('shutting down server...')
        began = loop.time()
        with contextlib.suppress(asyncio.TimeoutError, KeyboardInterrupt):
            loop.run_until_complete(runner.cleanup())
        logger.debug('shutdown took %0.2fs', loop.time() - began)
Esempio n. 7
0
async def init(loop):
    """Set up an app runner and prepare, without awaiting, a raw server.

    Returns:
        ``(server_coro, handler_svc, app)``. ``server_coro`` is the
        not-yet-awaited ``loop.create_server`` call bound to
        ``conf['SITE_HOST']:conf['SITE_PORT']``; ``handler_svc`` is the
        runner's low-level server, used as the protocol factory.
    """
    # NOTE(review): ``loop=`` is deprecated for aiohttp Application; consider
    # dropping it.
    app = web.Application(loop=loop)
    runner: web.AppRunner = AppRunner(app)
    await runner.setup()

    handler_svc: web.Server = runner.server
    bind_host, bind_port = conf['SITE_HOST'], conf['SITE_PORT']
    return loop.create_server(handler_svc, bind_host, bind_port), handler_svc, app
Esempio n. 8
0
    def init(loop):
        """Build the app and start a raw asyncio server on localhost:8000.

        Generator-based coroutine (uses ``yield from``). It presumably carries
        an ``@asyncio.coroutine`` decorator where it is defined; that line is
        not visible here, so confirm.

        Returns:
            The server object produced by ``loop.create_server``.
        """
        app = web.Application(loop=loop,
                              middlewares=[logger_factory, response_factory])
        add_routes(app, 'handlers')

        # Use the public runner API. setup() runs the app's startup phase and
        # creates ``runner.server``, a ready protocol factory. The former
        # ``AppRunner(app)._make_server()`` called a private method that is a
        # coroutine in aiohttp 3.x, so its result was not a usable factory.
        runner = AppRunner(app)
        yield from runner.setup()
        srv = yield from loop.create_server(runner.server, 'localhost', 8000)
        logging.info('server started at http://127.0.0.1:8000...')
        return srv
Esempio n. 9
0
async def websocket_server(loop, free_port):
    """Serve ``websocket_application`` on ``localhost:free_port`` once.

    Async generator (presumably used as a pytest fixture; the decorator is
    not visible here). It yields the started site's name and tears the
    server down when resumed or closed, even if startup fails.

    Args:
        loop: Unused here; kept for the caller's signature.
        free_port: Port to bind.
    """
    host = 'localhost'
    runner = AppRunner(websocket_application)
    await runner.setup()
    try:
        tcpsite = TCPSite(runner, host, free_port, shutdown_timeout=2)
        await tcpsite.start()

        yield tcpsite.name
    finally:
        # AppRunner.cleanup() stops the sites and runs the application's
        # shutdown itself. The previous explicit runner.shutdown() beforehand
        # made the app's shutdown hooks run twice.
        await runner.cleanup()
Esempio n. 10
0
 def __init__(self, config, handler):
     """Prepare the HTTPS endpoint and schedule its asynchronous setup.

     Must be called with an event loop running: it uses
     ``asyncio.get_running_loop()`` and ``asyncio.create_task()``. The
     ``doh_*`` option names suggest DNS-over-HTTPS; confirm.

     Args:
         config: Config with ``[local]`` ``doh_port``, ``doh_path``,
             ``doh_cert`` and ``doh_key``.
         handler: Stored as ``self.protocol_handler``.
     """
     self.loop = asyncio.get_running_loop()
     self.port = config.getint("local", "doh_port")
     self.path = config.get("local", "doh_path")
     # PROTOCOL_TLSv1_2 is deprecated since Python 3.10. A server context
     # with a TLS 1.2 floor keeps 1.2 clients working and also allows 1.3.
     self.ssl = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
     self.ssl.minimum_version = ssl.TLSVersion.TLSv1_2
     self.ssl.load_cert_chain(config.get("local", "doh_cert"),
                              config.get("local", "doh_key"))
     self.protocol_handler = handler
     app = web.Application()
     app.router.add_get(self.path, self.handle_get)
     app.router.add_post(self.path, self.handle_post)
     self.runner = AppRunner(app)
     # Keep a strong reference: the event loop holds only weak references to
     # tasks, so an unreferenced setup task can vanish mid-execution.
     self._setup_task = asyncio.create_task(self.setup())
Esempio n. 11
0
    async def _start_static_server(self):
        """Serve ``self._app`` on a newly chosen free port.

        The port, runner and started site are kept on the instance as
        ``_port``, ``_runner`` and ``_site``. Access lines go to this
        module's logger in ``AccessLogger``'s format.
        """
        self._port = find_free_port()
        access_logging = dict(
            access_log_class=AccessLogger,
            access_log_format=AccessLogger.LOG_FORMAT,
            access_log=logging.getLogger(__name__),
        )
        self._runner = AppRunner(self._app, handle_signals=True,
                                 **access_logging)
        await self._runner.setup()

        # No host is passed, so aiohttp uses its default bind address.
        self._site = TCPSite(self._runner, port=self._port)
        await self._site.start()
Esempio n. 12
0
    async def run_app(self, app, host, port):
        """Set *app* up on a new runner and start one TCP site on host:port.

        The runner is kept as ``self.runner``. The site enables address and
        port reuse and uses the module's ``SHUTDOWN_TIMEOUT`` and ``BACKLOG``.
        """
        self.runner = AppRunner(app)
        await self.runner.setup()

        log(__name__).info("Created HTTP endpoint %s:%d", host, port)
        endpoint = TCPSite(
            self.runner,
            host,
            port,
            shutdown_timeout=SHUTDOWN_TIMEOUT,
            backlog=BACKLOG,
            reuse_address=True,
            reuse_port=True,
        )
        await endpoint.start()
Esempio n. 13
0
async def start_server() -> Iterable[None]:
    """Run the JSON-RPC transport server around a single ``yield``.

    Async generator: it builds an ``AIOHttpTransportServer`` app, serves it
    with TLS on ``server_addr:PORT``, yields once, and then cleans the runner
    up. Cleanup also runs if starting the site fails.

    NOTE(review): as an async generator the precise return annotation is
    ``AsyncIterator[None]``. It is left unchanged because switching it needs
    a module-level import outside this block.
    """
    handler = functools.partial(handle_rpc, serializer=JsonSerializer(), bstream=SimpleBlockStream())
    server = AIOHttpTransportServer(process_request=handler,
                                    ip=server_addr,
                                    port=PORT,
                                    ssl_cert=SSL_CERT,
                                    ssl_key=SSL_KEY,
                                    api_key_enc=API_KEY_ENC,
                                    settings={"serializer": "json", "bstream": "simple"})

    app = server.make_app()
    runner = AppRunner(app)
    await runner.setup()
    try:
        # Site creation and start now happen inside the try, so a bind or
        # certificate error no longer leaks the set-up runner.
        site = TCPSite(runner, server_addr, PORT, ssl_context=server.make_ssl_context())
        await site.start()
        yield
    finally:
        await runner.cleanup()
Esempio n. 14
0
    async def run_app(self,
                      app: Application,
                      host: str = "0.0.0.0",
                      port: int = 7070) -> None:
        """Runs application and blocks until stopped.

        Serves *app* until ``self._stopped.wait()`` returns, then closes
        ``app["scheduler"]``. The runner is cleaned up in every case, even if
        starting the site fails, the scheduler close raises, or this
        coroutine is cancelled.

        Args:
            app: Application instance.
            host: TCP/IP hostname to serve on.
            port: TCP/IP port to serve on.
        """
        _LOG.info("Starting app", extra={"host": host, "port": port})
        runner = AppRunner(app=app)
        await runner.setup()
        try:
            site = TCPSite(runner=runner, host=host, port=port)
            await site.start()

            await self._stopped.wait()
            _LOG.debug("Stopping app", extra={"host": host, "port": port})
            await app["scheduler"].close()
        finally:
            # Previously skipped on any error or cancellation, which left the
            # listening socket open.
            await runner.cleanup()
Esempio n. 15
0
async def run_app(
    context: BeanContext,
    app: Union[Application, Awaitable[Application]],
    *,
    host: Optional[str] = None,
    port: Optional[int] = None,
    path: Optional[str] = None,
    sock: Optional[socket.socket] = None,
    shutdown_timeout: float = 60.0,
    ssl_context: Optional[SSLContext] = None,
    print: Callable[..., None] = print,
    backlog: int = 128,
    access_log_class: Type[AbstractAccessLogger] = AccessLogger,
    access_log_format: str = AccessLogger.LOG_FORMAT,
    access_log: Optional[logging.Logger] = access_logger,
    handle_signals: bool = True,
    reuse_address: Optional[bool] = None,
    reuse_port: Optional[bool] = None
) -> None:
    """Serve *app* until SIGINT or SIGTERM is delivered through *context*.

    Site selection is structured like ``aiohttp.web.run_app``:

    * one TCP site per *host*, or a single default TCP site when neither
      *path* nor *sock* is given or a *port* is given;
    * one Unix site per *path*;
    * one socket site per pre-bound *sock*.

    *host*, *path* and *sock* may each be a single value or an iterable. Once
    the runner is set up it is always cleaned up on exit.

    Args:
        context: Provides ``add_signal_handler``, used to request a stop, and
            is attached to the app as ``aio_pod_context``.
        app: The application, or a coroutine producing it.
        print: Called once with the listening addresses; pass a falsy value
            to suppress the banner.
        The remaining arguments are forwarded to ``AppRunner`` or the sites.
    """
    # Created inside this coroutine, so the Event binds to the running loop.
    # Passing ``loop=`` raises TypeError on Python 3.10+.
    stop_event = asyncio.Event()

    def stop(**kwargs):
        # Signal callback: wake up the wait at the end.
        stop_event.set()

    for sig in ("SIGINT", "SIGTERM"):
        await context.add_signal_handler(sig, stop)

    # Accept an application factory coroutine as well as an application.
    if asyncio.iscoroutine(app):
        app = await app  # type: ignore

    app = cast(Application, app)

    runner = AppRunner(
        app,
        handle_signals=handle_signals,
        access_log_class=access_log_class,
        access_log_format=access_log_format,
        access_log=access_log,
    )

    await runner.setup()

    # One constructor per site kind, so the shared options are spelled once.
    def make_tcp_site(bind_host=None):
        return TCPSite(
            runner,
            bind_host,
            port,
            shutdown_timeout=shutdown_timeout,
            ssl_context=ssl_context,
            backlog=backlog,
            reuse_address=reuse_address,
            reuse_port=reuse_port,
        )

    def make_unix_site(bind_path):
        return UnixSite(
            runner,
            bind_path,
            shutdown_timeout=shutdown_timeout,
            ssl_context=ssl_context,
            backlog=backlog,
        )

    def make_sock_site(bound_sock):
        return SockSite(
            runner,
            bound_sock,
            shutdown_timeout=shutdown_timeout,
            ssl_context=ssl_context,
            backlog=backlog,
        )

    # str-like values count as one address, not as an iterable of characters.
    scalar_types = (str, bytes, bytearray, memoryview)
    sites = []  # type: List[BaseSite]

    try:
        if host is not None:
            hosts = [host] if isinstance(host, scalar_types) else host
            sites.extend(make_tcp_site(h) for h in hosts)
        elif path is None and sock is None or port is not None:
            sites.append(make_tcp_site())

        if path is not None:
            paths = [path] if isinstance(path, scalar_types) else path
            sites.extend(make_unix_site(p) for p in paths)

        if sock is not None:
            socks = sock if isinstance(sock, Iterable) else [sock]
            sites.extend(make_sock_site(s) for s in socks)

        for site in sites:
            await site.start()

        app.aio_pod_context = context

        if print:  # pragma: no branch
            names = sorted(str(s.name) for s in runner.sites)
            print("======== Running on {} ========\n".format(", ".join(names)))
        await stop_event.wait()
    finally:
        await runner.cleanup()
Esempio n. 16
0
 async def run(self):
     """Start ``self.app`` on ``self.port`` and return once it listens.

     No host is passed, so aiohttp's default bind address is used. The
     runner handles signals, and the site gets a 60-second shutdown timeout.
     """
     runner = AppRunner(self.app, handle_signals=True)
     await runner.setup()
     await TCPSite(runner, port=self.port, shutdown_timeout=60.0).start()
Esempio n. 17
0
def prepare_app(app,
                *,
                host=None,
                port=None,
                path=None,
                sock=None,
                shutdown_timeout=60.0,
                ssl_context=None,
                backlog=128,
                access_log_class=helpers.AccessLogger,
                access_log_format=helpers.AccessLogger.LOG_FORMAT,
                access_log=access_logger,
                handle_signals=True,
                reuse_address=None,
                reuse_port=None):
    """Set *app* up on a runner and build, but do not start, its sites.

    Follows the site-selection rules of ``aiohttp.web.run_app``. The caller
    starts ``runner.prepared_sites`` itself, which lets several apps share
    one event loop.

    Returns:
        The set-up ``AppRunner``. Its ``prepared_sites`` attribute lists the
        unstarted sites.
    """
    loop = asyncio.get_event_loop()

    if asyncio.iscoroutine(app):
        app = loop.run_until_complete(app)

    runner = AppRunner(app,
                       handle_signals=handle_signals,
                       access_log_class=access_log_class,
                       access_log_format=access_log_format,
                       access_log=access_log)

    loop.run_until_complete(runner.setup())

    def tcp_site(bind_host=None):
        return TCPSite(runner,
                       bind_host,
                       port,
                       shutdown_timeout=shutdown_timeout,
                       ssl_context=ssl_context,
                       backlog=backlog,
                       reuse_address=reuse_address,
                       reuse_port=reuse_port)

    def unix_site(bind_path):
        return UnixSite(runner,
                        bind_path,
                        shutdown_timeout=shutdown_timeout,
                        ssl_context=ssl_context,
                        backlog=backlog)

    def sock_site(bound_sock):
        return SockSite(runner,
                        bound_sock,
                        shutdown_timeout=shutdown_timeout,
                        ssl_context=ssl_context,
                        backlog=backlog)

    # str-like values are one address, not an iterable of characters.
    scalar_types = (str, bytes, bytearray, memoryview)
    sites = []

    if host is not None:
        hosts = [host] if isinstance(host, scalar_types) else host
        sites.extend(tcp_site(h) for h in hosts)
    elif path is None and sock is None or port is not None:
        sites.append(tcp_site())

    if path is not None:
        paths = [path] if isinstance(path, scalar_types) else path
        sites.extend(unix_site(p) for p in paths)

    if sock is not None:
        socks = sock if isinstance(sock, Iterable) else [sock]
        sites.extend(sock_site(s) for s in socks)

    runner.prepared_sites = sites
    return runner
Esempio n. 18
0
async def run_server():  # pragma: no cover
    """Build the app with ``get_app()`` and serve it on 0.0.0.0:8080."""
    application = await get_app()
    runner = AppRunner(application)
    await runner.setup()
    await web.TCPSite(runner, "0.0.0.0", 8080).start()
Esempio n. 19
0
class WebIF(object):
    """Web front end that serves a per-user file tree over HTTP and HTTPS.

    Files live below ``[global] basedir``, and each request is checked with
    ``has_access``. The class also answers ACME challenge requests under
    ``/.well-known/acme-challenge/`` and restarts the HTTPS site whenever the
    certificate file's modification time changes.
    """

    def __init__(self, loop, config):
        """Validate config, create data dirs, register routes, start sites.

        Args:
            loop: Event loop used to run the asynchronous setup synchronously.
            config: Parsed configuration. ``[global] key`` is required;
                ``[global] basedir`` and ``[logging] access_log_format`` are
                also read here.

        Raises:
            NoOptionError: If ``[global] key`` is missing. A freshly generated
                key is logged as a suggestion first.
        """
        self.loop = loop
        self.config = config
        if config.has_option("global", "key"):
            self.fernet = Fernet(config.get("global", "key"))
        else:
            log(__name__).error(
                "You need to add a key to section [global], e.g.\n"
                "key: %s",
                Fernet.generate_key().decode("ascii"))
            raise NoOptionError("key", "global")

        # Create one data directory per "[dir:NAME]" section, if needed.
        basedir = config.get("global", "basedir")
        for section in config.sections():
            if section.startswith("dir:"):
                dir_name = os.path.join(basedir, section[4:])
                os.makedirs(dir_name, exist_ok=True)

        self.app = web.Application()
        router = self.app.router
        # Routes are matched in registration order, so the catch-all
        # "/{path:.*}" routes must stay last.
        router.add_route('GET', "/favicon.ico", self.favicon)
        router.add_route('POST', "/login", self.login)
        router.add_route('GET', "/.well-known/acme-challenge/{token}",
                         self.serve_letsencrypt)
        router.add_route('GET',
                         "/.well-known/acme-challenge/upload/{token}/{thumb}",
                         self.upload_letsencrypt)
        router.add_route('GET', "/privacy/", self.privacy)
        router.add_route('GET', "/{token}/auth/privacy/{path:.*}",
                         self.privacy)
        router.add_route('GET', "/{token}/auth/{path:.*}", self.handle)
        router.add_route('POST', "/{token}/auth/{path:.*}", self.handle)
        router.add_route('GET', "/{path:.*}", self.handle, name='handle-get')
        router.add_route('POST', "/{path:.*}", self.handle, name='handle-post')

        # ACME challenge token -> response text. Initialised before any site
        # starts, so an early challenge request always finds it.
        self.letsencrypt_data = dict()

        self.runner = AppRunner(self.app,
                                access_log_format=config.get(
                                    "logging", "access_log_format"))
        self.loop.run_until_complete(self.runner.setup())
        # Read by init_site() when it starts the HTTPS site.
        self.cert_watcher = None

        self.loop.run_until_complete(self.init_site("http_site"))
        self.loop.run_until_complete(self.init_site("https_site"))

    def _make_ssl_context(self):
        """Build a server TLS context from ``[global] certfile``/``keyfile``."""
        # PROTOCOL_TLSv1_2 is deprecated since Python 3.10. A server context
        # with a TLS 1.2 floor keeps 1.2 clients working and also allows 1.3.
        ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2
        ssl_context.load_cert_chain(
            certfile=self.config.get("global", "certfile"),
            keyfile=self.config.get("global", "keyfile"))
        return ssl_context

    async def init_site(self, site_name):
        """(Re)start the site stored under the attribute *site_name*.

        ``"https_site"`` listens on ``[global] https_port`` with TLS; any other
        name listens on ``[global] http_port`` without TLS. A running site of
        that name is stopped first. If the port option is absent, no new site
        is started. The TLS context is loaded before the old site is stopped,
        so an unreadable certificate raises while the old site still serves.

        Args:
            site_name: ``"http_site"`` or ``"https_site"``.
        """
        port = None
        ssl_context = None
        if site_name == "https_site":
            if self.config.has_option("global", "https_port"):
                port = self.config.getint("global", "https_port")
                ssl_context = self._make_ssl_context()
        elif self.config.has_option("global", "http_port"):
            port = self.config.getint("global", "http_port")

        old_site = getattr(self, site_name, None)
        if old_site:
            await old_site.stop()
            # Forget the stopped site so a later call does not stop it again.
            setattr(self, site_name, None)

        if port is None:
            return

        if site_name == "https_site" and self.cert_watcher is None:
            self.cert_watcher = self.loop.create_task(self.watch_cert())

        log(__name__).info("Starting fresh site %s on port %d", site_name,
                           port)

        site = TCPSite(self.runner,
                       self.config.get("global", "host"),
                       port,
                       ssl_context=ssl_context,
                       reuse_address=True)

        setattr(self, site_name, site)
        await site.start()

    def close(self):
        """Stop the certificate watcher, then shut the runner down."""
        # Cancel first so the watcher cannot restart a site during cleanup.
        if self.cert_watcher:
            self.cert_watcher.cancel()
        self.loop.run_until_complete(self.runner.cleanup())

    async def watch_cert(self):
        """Poll the certificate file every 5 s and restart HTTPS on change."""
        def get_mtime(fname, old_mtime):
            # On stat errors keep the previous value, so no reload is triggered.
            try:
                return os.stat(fname).st_mtime
            except OSError:
                log(__name__).warning("Error in stat", exc_info=True)
                return old_mtime

        fname = self.config.get("global", "certfile")
        ts = get_mtime(fname, None)
        while True:
            await asyncio.sleep(5)
            new_ts = get_mtime(fname, ts)
            if new_ts != ts:
                ts = new_ts
                try:
                    await self.init_site("https_site")
                except asyncio.CancelledError:
                    raise
                except Exception:
                    # A half-written or invalid certificate must not kill the
                    # watcher; the next change of the file triggers a retry.
                    log(__name__).error("Failed to reload ssl context",
                                        exc_info=True)
                else:
                    log(__name__).info("Updated ssl context")

    def peername(self, request):
        """Return the client address, or ``("???", 0)`` if it is unknown."""
        transport = request.transport
        # The transport is None once the connection has been closed.
        peername = transport.get_extra_info('peername') if transport else None
        return peername if peername else ("???", 0)

    def redirect(self, request, user, path, msg):
        """Redirect to the listing of *path* by raising ``HTTPFound``.

        The URL is built from the ``handle-get`` route, prefixed with the
        user's auth token (see ``build_prefix``). *msg*, when given, becomes
        the ``msg`` query parameter.

        Raises:
            web.HTTPFound: Always.
        """
        resource = request.app.router['handle-get']
        if path is None:
            path = ""
        path = (self.build_prefix(user, path) + path).lstrip("/")
        url = resource.url_for(path=path)
        if msg is not None:
            url = url.with_query({"msg": msg})
        raise web.HTTPFound(url)

    async def favicon(self, request):
        """Serve the base64-encoded ``[global] favicon`` as an icon."""
        favicon = self.config.get("global", "favicon")
        return web.Response(body=base64.b64decode(favicon),
                            content_type="image/vnd.microsoft.icon")

    async def serve_letsencrypt(self, request):
        """Return the stored challenge response for *token*, else 403."""
        token = request.match_info["token"]
        log(__name__).info("%s - - cert verification request %s",
                           self.peername(request), token)
        if token in self.letsencrypt_data:
            return web.Response(text=self.letsencrypt_data[token], status=200)
        else:
            return web.Response(text="permission denied",
                                status=403,
                                reason="permission denied")

    async def upload_letsencrypt(self, request):
        """Store ``"<token>.<thumb>"`` as the challenge response for *token*.

        SECURITY(review): this GET endpoint is unauthenticated. Anyone who can
        reach it can plant a challenge response for this host, and the dict
        grows without bound. Consider restricting it, e.g. to loopback or to
        authenticated users.
        """
        token = request.match_info["token"]
        thumb = request.match_info["thumb"]
        log(__name__).info("%s - - cert verification file %s.%s",
                           self.peername(request), token, thumb)
        self.letsencrypt_data[token] = "%s.%s" % (token, thumb)
        return web.Response(text="ok", status=200)

    async def get_user(self, request):
        """Get username either from token, basic auth header, or form keys.

        Returns:
            ``(user, password, cred)`` when the stored ``creds`` of *user* in
            ``config.user_perms`` equals the supplied password; *cred* is the
            base64 of ``"user:password"``. Otherwise ``(UNAUTH, None, "")``.
        """
        user, password = await self.get_user_from_request(request)

        expected = self.config.user_perms.get(user, dict()).get("creds", None)
        if expected == password:
            log(__name__).info("%s - - User %s active", self.peername(request),
                               user)
            # UTF-8, matching how Basic credentials are decoded; "ascii"
            # crashed on non-ASCII user names or passwords.
            cred = ("%s:%s" % (user, password)).encode("utf-8")
            cred = base64.b64encode(cred).decode("ascii")
            return user, password, cred

        # The rejected password is intentionally not logged.
        log(__name__).warning("%s - - Anonymous active %s",
                              self.peername(request), user)
        return UNAUTH, None, ""

    async def get_user_from_request(self, request):
        """Extract ``(user, password)`` from the request.

        Sources, in order:

        1. urlencoded form fields ``user`` and ``password``;
        2. the encrypted ``{token}`` URL part, valid only for the path it was
           issued for;
        3. an HTTP Basic ``Authorization`` header.

        Sets ``request.formdata`` as a side effect. Parsing errors are logged
        and treated as anonymous.

        Returns:
            ``(user, password)``, or ``(UNAUTH, None)`` if nothing is usable.
        """
        try:
            if request.content_type == "application/x-www-form-urlencoded":
                formdata = await request.post()
                request.formdata = formdata
            else:
                request.formdata = dict()

            # form keys first
            user = request.formdata.get("user", None)
            password = request.formdata.get("password", None)
            if user is not None and password is not None:
                return user, password

            if "token" in request.match_info:
                token = request.match_info["token"].encode("ascii")
                user, path = self.fernet.decrypt(token).decode("utf-8").split(
                    ":", 1)
                if path == request.match_info["path"]:
                    return user, self.config.user_perms.get(user, dict()).get(
                        "creds", None)

            # basic-auth second
            auth = request.headers.get("authorization", "")
            if auth.lower().startswith("basic "):
                decoded = base64.b64decode(auth[6:]).decode("utf8")
                basic_user, sep, basic_password = decoded.partition(":")
                # Credentials without ":" carry no password. This used to
                # return a one-element list and crash the caller's unpacking.
                if sep:
                    return basic_user, basic_password

        except Exception:
            log(__name__).info("Error", exc_info=True)
        return UNAUTH, None

    def build_prefix(self, user, path):
        """Return ``"/"`` for anonymous users, else ``"/<token>/auth/"``.

        The token is ``"user:path"`` encrypted with the instance's Fernet key.
        """
        if user == UNAUTH:
            return "/"
        else:
            token = self.fernet.encrypt(
                ("%s:%s" % (user, path)).encode("utf-8"))
            return "/%s/auth/" % token.decode("ascii")

    @redirect_on_exception
    async def login(self, request):
        """Authenticate and redirect to the root with a sign-in message."""
        user, password, cred = await self.get_user(request)
        if user == UNAUTH:
            msg = "signed out"
        else:
            msg = "%s signed in" % user
        self.redirect(request, user, "", msg)

    @redirect_on_exception
    async def handle(self, request):
        """Dispatch on the path type and the ``action`` query parameter.

        * directory, no action or ``list``: render the listing (needs LIST,
          UPLOAD or MKDIR access);
        * directory, ``upload``/``mkdir``: ``upload_file``/``make_dir``;
        * file, ``delete``: ``delete_file``;
        * file, no action: stream the file (FETCH access).

        Any other combination raises ``AccessError(400, "invalid")``.
        """
        path = request.match_info.get("path", "")
        fullname, dirname, fname = self.normalize_path(path)

        action = request.query.get("action", None)
        user, password, cred = await self.get_user(request)

        if fname is None and action in [None, "list"]:
            if not (has_access(user, dirname, Access.LIST)
                    or has_access(user, dirname, Access.UPLOAD)
                    or has_access(user, dirname, Access.MKDIR)):
                raise AccessError(HTTPStatus.UNAUTHORIZED, "forbidden")

            subdirs, files = list_dir(fullname)
            filter_file_list(user, dirname, subdirs, files)
            msg = request.query.get("msg", None)
            template = jinja2.Template(TEMPLATE)
            resp = web.Response(
                status=200,
                content_type='text/html',
                text=template.render(
                    dirname=dirname,
                    parentdir=None
                    if not dirname else os.path.split(dirname)[0],
                    files=[{
                        "name": f,
                        "path": os.path.join(dirname, f)
                    } for f in files],
                    subdirs=[{
                        "name": s,
                        "path": os.path.join(dirname, s)
                    } for s in subdirs],
                    statusmessage=msg,
                    prefix="/",
                    allow_upload=has_access(user, dirname, Access.UPLOAD),
                    allow_delete=has_access(user, dirname, Access.DELETE),
                    allow_fetch=has_access(user, dirname, Access.FETCH),
                    allow_mkdir=has_access(user, dirname, Access.MKDIR),
                    debug=self.config.getboolean("global", "debug"),
                    user=user,
                    bp=self.build_prefix))

        elif fname is None and action == "upload":
            if has_access(user, dirname, Access.UPLOAD):
                resp = await self.upload_file(request, user, dirname)
            else:
                raise AccessError(401, "forbidden")

        elif fname and action == "delete":
            if has_access(user, dirname, Access.DELETE):
                resp = await self.delete_file(request, user, dirname, fname)
            else:
                raise AccessError(401, "forbidden")

        elif fname is None and action == "mkdir":
            if has_access(user, dirname, Access.MKDIR):
                resp = await self.make_dir(request, user, dirname)
            else:
                raise AccessError(401, "forbidden")

        elif fname and action is None:
            if has_access(user, dirname, Access.FETCH):
                log(__name__).info("%s - - Download %s by %s",
                                   self.peername(request), fullname, user)
                resp = await self.streamfile(request, fullname)
            else:
                raise AccessError(401, "forbidden")

        else:
            # Previously fell through with `resp` unbound (UnboundLocalError).
            raise AccessError(400, "invalid")

        resp.headers["Accept-Ranges"] = "bytes"
        return resp

    async def make_dir(self, request, user, dirname):
        """Create the directory named in form field ``dirname``, then redirect."""
        filename = text.get_valid_filename(request.formdata["dirname"])
        name = os.path.join(dirname, filename)
        fullname = os.path.join(self.config.get("global", "basedir"), name)
        os.makedirs(fullname, exist_ok=True)
        log(__name__).info("%s - - Directory %s created by %s",
                           self.peername(request), fullname, user)
        self.redirect(request, user, dirname, "Directory %s created" % name)

    async def upload_file(self, request, user, directory):
        """Save the multipart ``file`` field into *directory*, then redirect.

        Raises:
            AccessError: 400 for a non-POST/non-multipart request or a missing
                ``file`` field; 403 if the target file already exists.
            web.HTTPFound: On success or on an unusable file name.
        """
        if (request.method != "POST"
                or request.content_type != "multipart/form-data"):
            # This error used to be constructed but never raised.
            raise AccessError(400, "invalid")

        reader = await request.multipart()
        field = await reader.next()
        if field is None or field.name != "file":
            raise AccessError(400, "invalid")
        fname = text.get_valid_filename(os.path.basename(field.filename or ""))
        if not fname:
            self.redirect(request, user, directory, "invalid filename")
        name = os.path.join(directory, fname)
        fullname = os.path.join(self.config.get("global", "basedir"), name)

        try:
            # Exclusive create: refuses existing paths (including unreadable
            # files and symlinks) atomically, unlike check-then-open.
            out = open(fullname, 'xb')
        except FileExistsError:
            raise AccessError(403, "duplicate") from None
        try:
            with out:
                while True:
                    chunk = await field.read_chunk()  # 8192 bytes by default.
                    if not chunk:
                        break
                    out.write(chunk)
        except BaseException:
            # Don't leave a truncated file behind; it would block a retry as
            # a "duplicate". BaseException also covers task cancellation.
            try:
                os.remove(fullname)
            except OSError:
                pass
            raise
        log(__name__).info("%s - - File %s uploaded by %s",
                           self.peername(request), name, user)
        size = os.lstat(fullname).st_size
        self.redirect(request, user, directory,
                      "%d bytes saved as %s" % (size, fname))

    async def delete_file(self, request, user, dirname, filename):
        """Delete *filename* in *dirname*, then redirect with a message."""
        name = os.path.join(dirname, filename)
        fullname = os.path.join(self.config.get("global", "basedir"), name)
        os.remove(fullname)
        log(__name__).info("%s - - File %s deleted by %s",
                           self.peername(request), fullname, user)
        # redirect() always raises HTTPFound; the AccessError that used to
        # follow it was unreachable.
        self.redirect(request, user, dirname, "File %s deleted" % name)

    async def streamfile(self, request, fullname):
        """Stream *fullname* to the client, honouring one byte range.

        ``request.http_range`` is a slice. Its ``stop`` is exclusive (last
        byte + 1), and its ``start`` is negative for suffix ranges
        (``bytes=-N``).

        Raises:
            web.HTTPRequestRangeNotSatisfiable: If the range starts at or
                past the end of the file.
        """
        rng = request.http_range
        size = length = os.path.getsize(fullname)

        start = rng.start
        if start is not None and start < 0:
            # Suffix range: the last -start bytes (whole file if larger).
            start = max(size + start, 0)
        if start is not None and start >= size:
            raise web.HTTPRequestRangeNotSatisfiable(
                headers={"Content-Range": "bytes */%d" % size})

        # No explicit reason: aiohttp derives it from the status, so a 206
        # no longer goes out as "206 OK".
        resp = web.StreamResponse(status=200 if start is None else 206,
                                  headers={
                                      'Content-Type': get_mime_type(fullname),
                                      'Accept-Ranges': 'bytes'
                                  })

        try:

            with open(fullname, 'rb') as f:
                if start is not None:
                    length = size - start
                    if rng.stop is not None:
                        # stop is exclusive; the former "+ 1" sent one extra byte.
                        length = min(rng.stop - start, length)
                    cr = "bytes {0}-{1}/{2}".format(start,
                                                    start + length - 1,
                                                    size)
                    resp.headers["Content-Range"] = cr
                    f.seek(start)
                await resp.prepare(request)
                while length > 0:
                    buf = f.read(min(8192, length))
                    if not buf:
                        break
                    length -= len(buf)
                    await resp.write(buf)

                await resp.write_eof()
        except ConnectionResetError:
            # client went away
            pass

        return resp

    async def privacy(self, request):
        """Redirect to *path* with the ``[global] gdprmsg`` message."""
        path = request.match_info.get("path", "")
        # Was a stray print() to stdout.
        log(__name__).debug("Path=%s", path)
        user, password, cred = await self.get_user(request)
        self.redirect(request, user, path,
                      self.config.get("global", "gdprmsg"))

    def normalize_path(self, path):
        """Resolve URL *path* below basedir and classify the target.

        Returns:
            ``(fullname, relpath, None)`` for a directory, or
            ``(fullname, reldir, filename)`` for a regular file.

        Raises:
            AccessError: 403 if the path escapes basedir, 404 if it does not
                exist, 401 for other stat errors or non-regular files.
        """
        # Normalise basedir as well. A trailing slash or a relative basedir
        # used to make every request fail the containment check.
        basedir = os.path.abspath(self.config.get("global", "basedir"))
        name = os.path.abspath(os.path.join(basedir, path))

        # NOTE(review): abspath() does not resolve symlinks, so a symlink
        # inside basedir can still lead outside it; consider realpath().
        if os.path.commonpath([basedir, name]) != basedir:
            log(__name__).warning("Outside base: %s", name)
            raise AccessError(403, "forbidden")
        path = name[len(basedir):].strip(os.path.sep)

        try:
            sr = os.stat(name)
        except FileNotFoundError:
            log(__name__).warning("File not found: %s", name)
            raise AccessError(404, "not found")
        except Exception:
            log(__name__).error("Stat error on %s", name)
            raise AccessError(401, "forbidden")

        if sr.st_uid != os.geteuid():
            log(__name__).warning("Wrong owner found for %s", name)

        if stat.S_ISDIR(sr.st_mode):
            return name, path, None

        if stat.S_ISREG(sr.st_mode):
            dirname, fname = os.path.split(path)
            return name, dirname, fname

        raise AccessError(401, "forbidden")
Esempio n. 20
0
def _build_arg_parser():
    """Return the command-line argument parser used by :func:`main`."""
    parser = argparse.ArgumentParser(
        description='Distributed Cronlike Scheduler')

    parser.add_argument('-l',
                        '--log-file',
                        default=None,
                        help='path to store logfile')
    parser.add_argument('-p',
                        '--storage-path',
                        default=None,
                        help='directory where to store cache')
    parser.add_argument('-u',
                        '--udp-communication-port',
                        type=int,
                        default=12345,
                        help='communication port (default: 12345)')
    parser.add_argument('-i',
                        '--broadcast-interval',
                        type=int,
                        default=5,
                        help='interval for broadcasting data over UDP '
                        '(default: 5)')
    parser.add_argument(
        '-c',
        '--cron',
        default=None,
        help='crontab to use (default: /etc/crontab, use `memory` to not '
        'save to file)')
    parser.add_argument('-d',
                        '--cron-user',
                        default=None,
                        help='user for storing cron entries')
    parser.add_argument('-w',
                        '--web-port',
                        type=int,
                        default=8080,
                        help='web hosting port (default: 8080)')
    parser.add_argument(
        '-n',
        '--ntp-server',
        default='pool.ntp.org',
        help='NTP server to detect clock skew (default: pool.ntp.org)')
    parser.add_argument(
        '-s',
        '--node-staleness',
        type=int,
        default=180,
        help='Time in seconds of non-communication for a node to be marked '
        'as stale (default: 180s)')
    parser.add_argument(
        '-x',
        '--hash-key',
        default='abracadabra',
        help="String to use for verifying UDP traffic (to disable use '')")
    parser.add_argument('-v',
                        '--verbose',
                        action='store_true',
                        default=False,
                        help='verbose logging')
    return parser


def _build_processor(args, storage):
    """Create the Processor for the crontab selected on the command line."""
    if not args.cron:
        return Processor(args.udp_communication_port,
                         storage,
                         user='******')
    if args.cron == 'memory':
        # In-memory crontab: nothing is written back to a file.
        return Processor(args.udp_communication_port,
                         storage,
                         cron=CronTab(tab="""* * * * * command"""))
    cron_user = args.cron_user or '******'
    return Processor(args.udp_communication_port,
                     storage,
                     cron=CronTab(tabfile=args.cron, user=cron_user),
                     user=cron_user)


def main():
    """
    entry point

    Parse the command line and configure logging. Then start the UDP status
    server, the background broadcast/re-balance/autosave jobs and the web
    application. Run the event loop until it is interrupted, then shut
    everything down in order.
    """
    import sys
    import threading

    args = _build_arg_parser().parse_args()

    # A clock running ahead is as out of sync as one lagging behind.
    # NOTE(review): assumes get_ntp_offset returns a signed offset in seconds.
    if abs(get_ntp_offset(args.ntp_server)) > 60:
        sys.exit("your clock is not in sync (check system NTP settings)")

    root_logger = logging.getLogger()
    if args.log_file:
        file_handler = logging.FileHandler(args.log_file)
        file_handler.setFormatter(logging.Formatter(log_format))
        root_logger.addHandler(file_handler)
    if args.verbose:
        root_logger.setLevel(logging.DEBUG)
    else:
        root_logger.setLevel(logging.INFO)
        logging.getLogger('aiohttp').setLevel(logging.WARNING)

    pool = ThreadPoolExecutor(4)

    storage = Storage(args.storage_path)
    processor = _build_processor(args, storage)

    # An empty --hash-key disables verification of UDP traffic.
    hash_key = args.hash_key if args.hash_key != '' else None

    with StatusProtocolServer(processor, args.udp_communication_port) as loop:

        # Set on shutdown. It also wakes the worker threads from their waits,
        # so shutdown does not sit out a full interval.
        stopping = threading.Event()

        scheduler = Scheduler(storage, args.node_staleness)

        def timed_broadcast():
            """
            periodically broadcast system status and known jobs
            """
            while not stopping.is_set():
                broadcast(
                    args.udp_communication_port,
                    UdpSerializer.dump(Status(get_ip(), get_load()), hash_key))
                for job in storage.cluster_jobs:
                    if job.assigned_to == get_ip():
                        job.pid = check_process(job.command)
                    for packet in UdpSerializer.dump(job, hash_key):
                        client(args.udp_communication_port, packet)
                stopping.wait(args.broadcast_interval)

        def timed_schedule():
            """
            periodically check if cluster needs re-balancing
            """
            # wait() returns True once shutdown is signalled.
            while not stopping.wait(23):
                if not scheduler.check_cluster_state():
                    logger.info("re-balancing cluster")
                    jobs = storage.cluster_jobs.copy()
                    for packet in UdpSerializer.dump(
                            ReBalance(timestamp=datetime.now()), hash_key):
                        client(args.udp_communication_port, packet)
                    # Give nodes time to process the re-balance before the
                    # jobs are re-announced.
                    time.sleep(5)
                    for job in jobs:
                        for packet in UdpSerializer.dump(job, hash_key):
                            client(args.udp_communication_port, packet)

        async def scheduled_broadcast():
            await loop.run_in_executor(pool, timed_broadcast)

        async def scheduled_rebalance():
            await loop.run_in_executor(pool, timed_schedule)

        async def save_schedule():
            """
            auto save every 100 seconds
            """
            while not stopping.is_set():
                await asyncio.sleep(100)
                await storage.save()

        logger.info("setting broadcast interval to {0} seconds".format(
            args.broadcast_interval))
        loop.create_task(scheduled_broadcast())
        loop.create_task(scheduled_rebalance())
        autosave_task = None
        if args.storage_path:
            autosave_task = loop.create_task(save_schedule())

        logger.info(
            "starting web application server on http://{0}:{1}/".format(
                get_ip(), args.web_port))

        site_kwargs = {'user': args.cron_user} if args.cron_user else {}
        s = Site(scheduler,
                 storage,
                 args.udp_communication_port,
                 cron=processor.cron,
                 hash_key=hash_key,
                 **site_kwargs)
        runner = AppRunner(s.app)
        loop.run_until_complete(runner.setup())
        site_instance = TCPSite(runner, port=args.web_port)
        loop.run_until_complete(site_instance.start())

        try:
            loop.run_forever()
        except KeyboardInterrupt:
            logger.info("interrupt received")
        except Exception:
            # Still fall through to the orderly shutdown below.
            logger.exception("event loop stopped unexpectedly")

        logger.info("stopping web application")
        loop.run_until_complete(site_instance.stop())
        # Release the resources held by the application runner.
        loop.run_until_complete(runner.cleanup())

        stopping.set()
        if autosave_task is not None:
            # The autosave loop may be parked in a 100 s sleep. Cancel it
            # instead of waiting; the final save below persists the state.
            autosave_task.cancel()

        if args.storage_path:
            loop.create_task(storage.save())

        logger.debug("waiting for background tasks to finish")
        # asyncio.Task.all_tasks() was removed in Python 3.9.
        all_tasks = getattr(asyncio, 'all_tasks', None)
        if all_tasks is None:  # Python < 3.7
            all_tasks = asyncio.Task.all_tasks
        pending_tasks = [
            task for task in all_tasks(loop) if not task.done()
        ]
        # return_exceptions: one failing task must not abort the shutdown.
        results = loop.run_until_complete(
            asyncio.gather(*pending_tasks, return_exceptions=True))
        for result in results:
            if (isinstance(result, Exception)
                    and not isinstance(result, asyncio.CancelledError)):
                logger.warning("background task failed: %r", result)
        pool.shutdown()

    logger.info("elvis has left the building")