예제 #1
0
def run_app(app, port, loop):
    """
    Serve *app* on ``HOST:port`` with *loop* until the loop stops.

    Sets up an ``AppRunner`` (access logging disabled), starts a TCP site
    with a 0.01 s shutdown timeout and runs the loop forever.  Ctrl+C is
    swallowed.  Whatever ends the loop, the runner is cleaned up afterwards.
    A timeout or a second Ctrl+C during that cleanup is ignored.  The time
    the shutdown took is logged at debug level.
    """
    runner = AppRunner(app, access_log=None)
    loop.run_until_complete(runner.setup())
    loop.run_until_complete(
        TCPSite(runner, HOST, port, shutdown_timeout=0.01).start())

    try:
        with contextlib.suppress(KeyboardInterrupt):  # pragma: no branch
            loop.run_forever()
    finally:
        logger.info('shutting down server...')
        began = loop.time()
        try:
            loop.run_until_complete(runner.cleanup())
        except (asyncio.TimeoutError, KeyboardInterrupt):
            # Best effort: an impatient second Ctrl+C or a slow
            # connection must not turn shutdown into a crash.
            pass
        logger.debug('shutdown took %0.2fs', loop.time() - began)
예제 #2
0
파일: application.py 프로젝트: witlox/dcron
def main():
    """
    Entry point of the distributed cron scheduler.

    Steps:
      1. Parse the command line.
      2. Refuse to start when the local clock is more than 60 seconds away
         from the NTP server.
      3. Configure logging.
      4. Build the storage and the cron ``Processor``.
      5. Inside the ``StatusProtocolServer`` context, start:

         * a worker thread that broadcasts node status and known jobs every
           ``--broadcast-interval`` seconds;
         * a worker thread that checks the cluster state every ~23 seconds
           and re-broadcasts all jobs when re-balancing is needed;
         * an auto-save task (only with ``--storage-path``);
         * the web UI on ``--web-port``.

    The event loop then runs until it is interrupted.  After that, the web
    server is cleaned up, the worker loops are told to stop, a final save is
    scheduled (with ``--storage-path``), and all pending tasks are awaited
    before the thread pool is shut down.
    """
    import sys  # sys.exit is always available; the builtin exit() is not

    parser = argparse.ArgumentParser(
        description='Distributed Cronlike Scheduler')

    parser.add_argument('-l',
                        '--log-file',
                        default=None,
                        help='path to store logfile')
    parser.add_argument('-p',
                        '--storage-path',
                        default=None,
                        help='directory where to store cache')
    parser.add_argument('-u',
                        '--udp-communication-port',
                        type=int,
                        default=12345,
                        help='communication port (default: 12345)')
    parser.add_argument('-i',
                        '--broadcast-interval',
                        type=int,
                        default=5,
                        help='interval for broadcasting data over UDP')
    parser.add_argument(
        '-c',
        '--cron',
        default=None,
        help=
        'crontab to use (default: /etc/crontab, use `memory` to not save to file)'
    )
    parser.add_argument('-d',
                        '--cron-user',
                        default=None,
                        help='user for storing cron entries')
    parser.add_argument('-w',
                        '--web-port',
                        type=int,
                        default=8080,
                        help='web hosting port (default: 8080)')
    parser.add_argument(
        '-n',
        '--ntp-server',
        default='pool.ntp.org',
        help='NTP server to detect clock skew (default: pool.ntp.org)')
    parser.add_argument(
        '-s',
        '--node-staleness',
        type=int,
        default=180,
        help=
        'Time in seconds of non-communication for a node to be marked as stale (default: 180s)'
    )
    parser.add_argument(
        '-x',
        '--hash-key',
        default='abracadabra',
        help="String to use for verifying UDP traffic (to disable use '')")
    parser.add_argument('-v',
                        '--verbose',
                        action='store_true',
                        default=False,
                        help='verbose logging')

    args = parser.parse_args()

    # A clock that is ahead yields a negative offset, so compare the
    # magnitude.  abs() is harmless if the helper already returns one.
    if abs(get_ntp_offset(args.ntp_server)) > 60:
        sys.exit("your clock is not in sync (check system NTP settings)")

    root_logger = logging.getLogger()
    if args.log_file:
        file_handler = logging.FileHandler(args.log_file)
        file_handler.setFormatter(logging.Formatter(log_format))
        root_logger.addHandler(file_handler)
    if args.verbose:
        root_logger.setLevel(logging.DEBUG)
    else:
        root_logger.setLevel(logging.INFO)
        logging.getLogger('aiohttp').setLevel(logging.WARNING)

    pool = ThreadPoolExecutor(4)

    storage = Storage(args.storage_path)
    if args.cron:
        if args.cron == 'memory':
            processor = Processor(args.udp_communication_port,
                                  storage,
                                  cron=CronTab(tab="""* * * * * command"""))
        elif args.cron_user:
            processor = Processor(args.udp_communication_port,
                                  storage,
                                  cron=CronTab(tabfile=args.cron,
                                               user=args.cron_user),
                                  user=args.cron_user)
        else:
            processor = Processor(args.udp_communication_port,
                                  storage,
                                  cron=CronTab(tabfile=args.cron, user='******'),
                                  user='******')
    else:
        processor = Processor(args.udp_communication_port,
                              storage,
                              user='******')

    # An empty --hash-key disables signing of UDP traffic.
    hash_key = args.hash_key if args.hash_key != '' else None

    with StatusProtocolServer(processor, args.udp_communication_port) as loop:

        # Read by the worker closures below; flipped to False on shutdown.
        running = True

        scheduler = Scheduler(storage, args.node_staleness)

        def timed_broadcast():
            """
            periodically broadcast system status and known jobs
            """
            while running:
                broadcast(
                    args.udp_communication_port,
                    UdpSerializer.dump(Status(get_ip(), get_load()), hash_key))
                for job in storage.cluster_jobs:
                    if job.assigned_to == get_ip():
                        job.pid = check_process(job.command)
                    for packet in UdpSerializer.dump(job, hash_key):
                        client(args.udp_communication_port, packet)
                time.sleep(args.broadcast_interval)

        def timed_schedule():
            """
            periodically check if cluster needs re-balancing
            """
            while running:
                time.sleep(23)
                if not scheduler.check_cluster_state():
                    logger.info("re-balancing cluster")
                    jobs = storage.cluster_jobs.copy()
                    for packet in UdpSerializer.dump(
                            ReBalance(timestamp=datetime.now()), hash_key):
                        client(args.udp_communication_port, packet)
                    time.sleep(5)
                    for job in jobs:
                        for packet in UdpSerializer.dump(job, hash_key):
                            client(args.udp_communication_port, packet)

        async def scheduled_broadcast():
            await loop.run_in_executor(pool, timed_broadcast)

        async def scheduled_rebalance():
            await loop.run_in_executor(pool, timed_schedule)

        async def save_schedule():
            """
            auto save every 100 seconds
            """
            while running:
                await asyncio.sleep(100)
                await storage.save()

        logger.info("setting broadcast interval to {0} seconds".format(
            args.broadcast_interval))
        loop.create_task(scheduled_broadcast())
        loop.create_task(scheduled_rebalance())
        if args.storage_path:
            loop.create_task(save_schedule())

        logger.info(
            "starting web application server on http://{0}:{1}/".format(
                get_ip(), args.web_port))

        site_kwargs = dict(cron=processor.cron, hash_key=hash_key)
        if args.cron_user:
            site_kwargs['user'] = args.cron_user
        s = Site(scheduler, storage, args.udp_communication_port,
                 **site_kwargs)

        runner = AppRunner(s.app)
        loop.run_until_complete(runner.setup())
        site_instance = TCPSite(runner, port=args.web_port)
        loop.run_until_complete(site_instance.start())

        try:
            loop.run_forever()
        except (KeyboardInterrupt, SystemExit):
            logger.info("interrupt received")
        except Exception:
            # Don't disguise a crash as an interrupt: keep the traceback,
            # but still fall through to the orderly shutdown below.
            logger.exception("event loop stopped unexpectedly")

        logger.info("stopping web application")
        # cleanup() stops every site of the runner and shuts the application
        # down; site.stop() alone would leave the runner set up.
        loop.run_until_complete(runner.cleanup())

        running = False

        if args.storage_path:
            loop.create_task(storage.save())

        # NOTE: this can take a while: the worker threads only notice
        # `running` after their current sleep (up to ~28 s for re-balance),
        # and the auto-save task after its 100 s sleep.
        logger.debug("waiting for background tasks to finish")
        # asyncio.Task.all_tasks() was removed in Python 3.9.
        pending_tasks = [
            task for task in asyncio.all_tasks(loop) if not task.done()
        ]
        loop.run_until_complete(asyncio.gather(*pending_tasks))
        # All executor jobs have returned by now; release the idle threads.
        pool.shutdown(wait=True)

    logger.info("elvis has left the building")
예제 #3
0
class WebIF(object):
    """
    aiohttp web front end for browsing, uploading, downloading and managing
    files.

    All files live below ``[global] basedir``.  Authenticated URLs carry a
    Fernet-encrypted ``user:path`` token (see ``build_prefix``).  Requests
    without a valid login are served as the anonymous ``UNAUTH`` user.
    Depending on ``[global] http_port`` / ``https_port``, an HTTP and/or HTTPS
    site is started on ``[global] host``.  For HTTPS, the certificate file is
    polled every 5 seconds and the site is restarted when it changes.
    """

    def __init__(self, loop, config):
        """
        Create the data directories and routes and start the configured sites.

        Args:
            loop: event loop the sites run on; setup is driven with
                ``loop.run_until_complete``.
            config: ConfigParser-like object that also exposes
                ``user_perms`` (read by ``get_user``).

        Raises:
            NoOptionError: ``[global] key`` is missing.  A freshly generated
                Fernet key is logged as a suggestion.
        """
        self.loop = loop
        self.config = config
        if config.has_option("global", "key"):
            self.fernet = Fernet(config.get("global", "key"))
        else:
            log(__name__).error(
                "You need to add a key to section [global], e.g.\n"
                "key: %s",
                Fernet.generate_key().decode("ascii"))
            raise NoOptionError("key", "global")

        # Must exist before the first site listens: requests can already be
        # served while the second site is being started.
        self.letsencrypt_data = dict()
        self.cert_watcher = None

        # create data dirs ("dir:<name>" sections), if needed
        basedir = config.get("global", "basedir")
        for section in config.sections():
            if section.startswith("dir:"):
                dir_name = os.path.join(basedir, section[4:])
                os.makedirs(dir_name, exist_ok=True)

        self.app = web.Application()
        router = self.app.router
        router.add_route('GET', "/favicon.ico", self.favicon)
        router.add_route('POST', "/login", self.login)
        router.add_route('GET', "/.well-known/acme-challenge/{token}",
                         self.serve_letsencrypt)
        router.add_route('GET',
                         "/.well-known/acme-challenge/upload/{token}/{thumb}",
                         self.upload_letsencrypt)
        router.add_route('GET', "/privacy/", self.privacy)
        router.add_route('GET', "/{token}/auth/privacy/{path:.*}",
                         self.privacy)
        router.add_route('GET', "/{token}/auth/{path:.*}", self.handle)
        router.add_route('POST', "/{token}/auth/{path:.*}", self.handle)
        router.add_route('GET', "/{path:.*}", self.handle, name='handle-get')
        router.add_route('POST', "/{path:.*}", self.handle, name='handle-post')

        self.runner = AppRunner(self.app,
                                access_log_format=config.get(
                                    "logging", "access_log_format"))
        self.loop.run_until_complete(self.runner.setup())

        self.loop.run_until_complete(self.init_site("http_site"))
        self.loop.run_until_complete(self.init_site("https_site"))

    def _make_ssl_context(self):
        """Return a server TLS context (TLS 1.2+) loaded with the configured cert."""
        # PROTOCOL_TLSv1_2 is deprecated and pins exactly TLS 1.2.  Keep 1.2
        # as the floor, but allow newer protocol versions.
        ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2
        ssl_context.load_cert_chain(
            certfile=self.config.get("global", "certfile"),
            keyfile=self.config.get("global", "keyfile"))
        return ssl_context

    async def init_site(self, site_name):
        """
        (Re)start the site stored in attribute *site_name*.

        *site_name* is ``"https_site"``; any other name is treated as the
        HTTP site.  A running site of that name is stopped.  A new site is
        started only if the matching ``[global] https_port`` / ``http_port``
        option exists.

        For HTTPS, the certificate is loaded *before* the running site is
        stopped, so a broken or half-written certificate leaves the previous
        site serving.  The certificate watcher is started on first use.
        """
        port, ssl_context = None, None
        if site_name == "https_site":
            if self.config.has_option("global", "https_port"):
                port = self.config.getint("global", "https_port")
                ssl_context = self._make_ssl_context()
        elif self.config.has_option("global", "http_port"):
            port = self.config.getint("global", "http_port")

        old_site = getattr(self, site_name, None)
        if old_site:
            await old_site.stop()

        if port is None:
            return

        if ssl_context is not None and self.cert_watcher is None:
            self.cert_watcher = self.loop.create_task(self.watch_cert())

        log(__name__).info("Starting fresh site %s on port %d", site_name,
                           port)

        site = TCPSite(self.runner,
                       self.config.get("global", "host"),
                       port,
                       ssl_context=ssl_context,
                       reuse_address=True)

        setattr(self, site_name, site)
        await site.start()

    def close(self):
        """Stop the certificate watcher, then shut down all sites and the app."""
        # Cancel first, so the watcher cannot restart a site while the
        # runner is being cleaned up.
        if self.cert_watcher:
            self.cert_watcher.cancel()
        self.loop.run_until_complete(self.runner.cleanup())

    async def watch_cert(self):
        """
        Poll ``[global] certfile`` every 5 s and restart the HTTPS site on change.

        A failed restart is logged and the watcher keeps running.  The next
        modification of the file triggers another attempt.
        """
        def get_mtime(fname, old_mtime):
            try:
                return os.stat(fname).st_mtime
            except OSError:
                log(__name__).warning("Error in stat", exc_info=True)
                return old_mtime

        fname = self.config.get("global", "certfile")
        ts = get_mtime(fname, None)
        while True:
            await asyncio.sleep(5)
            new_ts = get_mtime(fname, ts)
            if new_ts != ts:
                ts = new_ts
                try:
                    await self.init_site("https_site")
                except asyncio.CancelledError:
                    raise
                except Exception:
                    log(__name__).error("Updating ssl context failed",
                                        exc_info=True)
                else:
                    log(__name__).info("Updated ssl context")

    def peername(self, request):
        """Return the client's socket address, or ``("???", 0)`` if unknown."""
        # The transport is None once the connection has been closed.
        transport = request.transport
        peername = (transport.get_extra_info('peername')
                    if transport is not None else None)
        return peername if peername else ("???", 0)

    def redirect(self, request, user, path, msg):
        """
        Raise ``HTTPFound`` to the listing URL of *path* as seen by *user*.

        The URL is prefixed with the user's auth token (``build_prefix``).
        *msg*, if not None, is added as the ``msg`` query parameter.  This
        method never returns normally.
        """
        resource = request.app.router['handle-get']
        if path is None:
            path = ""
        path = (self.build_prefix(user, path) + path).lstrip("/")
        url = resource.url_for(path=path)
        if msg is not None:
            url = url.with_query({"msg": msg})
        raise web.HTTPFound(url)

    async def favicon(self, request):
        """Serve the base64-encoded icon from ``[global] favicon``."""
        favicon = self.config.get("global", "favicon")
        return web.Response(body=base64.b64decode(favicon),
                            content_type="image/vnd.microsoft.icon")

    async def serve_letsencrypt(self, request):
        """Answer an ACME http-01 challenge registered via ``upload_letsencrypt``."""
        token = request.match_info["token"]
        log(__name__).info("%s - - cert verification request %s",
                           self.peername(request), token)
        if token in self.letsencrypt_data:
            return web.Response(text=self.letsencrypt_data[token], status=200)
        else:
            return web.Response(text="permission denied",
                                status=403,
                                reason="permission denied")

    async def upload_letsencrypt(self, request):
        """Register the key authorization ``<token>.<thumb>`` for *token*."""
        # NOTE(review): this route needs no authentication and the dict is
        # never pruned -- any client can add entries.  Confirm that is
        # intended (e.g. the route is only reachable from localhost).
        token = request.match_info["token"]
        thumb = request.match_info["thumb"]
        log(__name__).info("%s - - cert verification file %s.%s",
                           self.peername(request), token, thumb)
        self.letsencrypt_data[token] = "%s.%s" % (token, thumb)
        return web.Response(text="ok", status=200)

    @staticmethod
    def _password_matches(expected, given):
        """True if both passwords are set and equal (constant time for strings)."""
        import hmac

        # A user without configured "creds" (None) must never be matched by
        # a request that supplied no password (also None).
        if expected is None or given is None:
            return False
        if isinstance(expected, str) and isinstance(given, str):
            return hmac.compare_digest(expected.encode("utf-8"),
                                       given.encode("utf-8"))
        return expected == given

    async def get_user(self, request):
        """
        Get username either from token, basic auth header, or form keys.

        Returns:
            ``(user, password, cred)``, where *cred* is the base64 basic-auth
            string.  Returns ``(UNAUTH, None, "")`` when no valid login is
            present.
        """
        user, password = await self.get_user_from_request(request)

        expected = self.config.user_perms.get(user, dict()).get("creds", None)
        if self._password_matches(expected, password):
            log(__name__).info("%s - - User %s active", self.peername(request),
                               user)
            cred = ("%s:%s" % (user, password)).encode("ascii")
            cred = base64.b64encode(cred).decode("ascii")
            return user, password, cred

        # Never log the rejected password itself.
        log(__name__).warning("%s - - Anonymous active %s",
                              self.peername(request), user)
        return UNAUTH, None, ""

    async def get_user_from_request(self, request):
        """
        Extract ``(user, password)`` from the request.

        Sources, in order:
          1. the ``user``/``password`` form fields;
          2. a Fernet path token (valid only for the path it was issued for);
          3. an HTTP basic-auth header.

        As a side effect, the parsed urlencoded form (or an empty dict) is
        stored on ``request.formdata``.  Any error yields ``(UNAUTH, None)``.
        """
        try:
            if request.content_type == "application/x-www-form-urlencoded":
                formdata = await request.post()
                request.formdata = formdata
            else:
                request.formdata = dict()

            # form keys first
            user = request.formdata.get("user", None)
            password = request.formdata.get("password", None)
            if user is not None and password is not None:
                return user, password

            if "token" in request.match_info:
                token = request.match_info["token"].encode("ascii")
                user, path = self.fernet.decrypt(token).decode("utf-8").split(
                    ":", 1)
                if path == request.match_info["path"]:
                    return user, self.config.user_perms.get(user, dict()).get(
                        "creds", None)

            # basic-auth second
            auth = request.headers.get("authorization", "")
            if auth.lower().startswith("basic "):
                decoded = base64.b64decode(auth[6:]).decode("utf8")
                # A value without ':' used to crash the caller on unpacking.
                user, sep, password = decoded.partition(":")
                if sep:
                    return user, password

        except Exception:
            log(__name__).info("Error", exc_info=True)
        return UNAUTH, None

    def build_prefix(self, user, path):
        """Return the URL prefix for *user*: ``/`` or ``/<token>/auth/``."""
        if user == UNAUTH:
            return "/"
        else:
            token = self.fernet.encrypt(
                ("%s:%s" % (user, path)).encode("utf-8"))
            return "/%s/auth/" % token.decode("ascii")

    @redirect_on_exception
    async def login(self, request):
        """Check the posted credentials and redirect to the root listing."""
        user, password, cred = await self.get_user(request)
        if user == UNAUTH:
            msg = "signed out"
        else:
            msg = "%s signed in" % user
        self.redirect(request, user, "", msg)

    @redirect_on_exception
    async def handle(self, request):
        """
        Dispatch a file request based on the ``action`` query parameter.

        Directory path: ``list`` (default), ``upload`` or ``mkdir``.
        File path: download (no action) or ``delete``.  Each action is
        checked against ``has_access``.  Unknown combinations are rejected
        with 400.
        """
        path = request.match_info.get("path", "")
        fullname, dirname, fname = self.normalize_path(path)

        action = request.query.get("action", None)
        user, password, cred = await self.get_user(request)

        if fname is None and action in [None, "list"]:
            if not (has_access(user, dirname, Access.LIST)
                    or has_access(user, dirname, Access.UPLOAD)
                    or has_access(user, dirname, Access.MKDIR)):
                raise AccessError(HTTPStatus.UNAUTHORIZED, "forbidden")

            subdirs, files = list_dir(fullname)
            filter_file_list(user, dirname, subdirs, files)
            msg = request.query.get("msg", None)
            # NOTE(review): `msg` comes straight from the query string and
            # jinja2.Template does not autoescape by default.  Verify that
            # TEMPLATE escapes statusmessage, or this is reflected XSS.
            template = jinja2.Template(TEMPLATE)
            resp = web.Response(
                status=200,
                content_type='text/html',
                text=template.render(
                    dirname=dirname,
                    parentdir=None
                    if not dirname else os.path.split(dirname)[0],
                    files=[{
                        "name": f,
                        "path": os.path.join(dirname, f)
                    } for f in files],
                    subdirs=[{
                        "name": s,
                        "path": os.path.join(dirname, s)
                    } for s in subdirs],
                    statusmessage=msg,
                    prefix="/",
                    allow_upload=has_access(user, dirname, Access.UPLOAD),
                    allow_delete=has_access(user, dirname, Access.DELETE),
                    allow_fetch=has_access(user, dirname, Access.FETCH),
                    allow_mkdir=has_access(user, dirname, Access.MKDIR),
                    debug=self.config.getboolean("global", "debug"),
                    user=user,
                    bp=self.build_prefix))

        elif fname is None and action == "upload":
            if has_access(user, dirname, Access.UPLOAD):
                resp = await self.upload_file(request, user, dirname)
            else:
                raise AccessError(401, "forbidden")

        elif fname and action == "delete":
            if has_access(user, dirname, Access.DELETE):
                resp = await self.delete_file(request, user, dirname, fname)
            else:
                raise AccessError(401, "forbidden")

        elif fname is None and action == "mkdir":
            if has_access(user, dirname, Access.MKDIR):
                resp = await self.make_dir(request, user, dirname)
            else:
                raise AccessError(401, "forbidden")

        elif fname and action is None:
            if has_access(user, dirname, Access.FETCH):
                log(__name__).info("%s - - Download %s by %s",
                                   self.peername(request), fullname, user)
                resp = await self.streamfile(request, fullname)
            else:
                raise AccessError(401, "forbidden")

        else:
            # Previously fell through with `resp` unbound (UnboundLocalError).
            raise AccessError(400, "invalid")

        resp.headers["Accept-Ranges"] = "bytes"
        return resp

    async def make_dir(self, request, user, dirname):
        """Create the posted ``dirname`` below *dirname*, then redirect."""
        filename = text.get_valid_filename(request.formdata["dirname"])
        name = os.path.join(dirname, filename)
        fullname = os.path.join(self.config.get("global", "basedir"), name)
        os.makedirs(fullname, exist_ok=True)
        log(__name__).info("%s - - Directory %s created by %s",
                           self.peername(request), fullname, user)
        self.redirect(request, user, dirname, "Directory %s created" % name)

    async def upload_file(self, request, user, directory):
        """
        Store the multipart ``file`` field in *directory*, then redirect.

        Raises:
            AccessError: 400 for a malformed request, 403 if a file with
                that name already exists.
        """
        if (request.method != "POST"
                or request.content_type != "multipart/form-data"):
            # The error was previously constructed but never raised.
            raise AccessError(400, "invalid")

        reader = await request.multipart()
        field = await reader.next()
        if field is None or field.name != "file" or field.filename is None:
            raise AccessError(400, "invalid")
        fname = text.get_valid_filename(os.path.basename(field.filename))
        if not fname:
            self.redirect(request, user, directory, "invalid filename")
        name = os.path.join(directory, fname)
        fullname = os.path.join(self.config.get("global", "basedir"), name)
        try:
            # Exclusive create: never overwrite an existing file, even when
            # two uploads with the same name race each other.
            f = open(fullname, 'xb')
        except FileExistsError:
            raise AccessError(403, "duplicate")
        try:
            with f:
                while True:
                    chunk = await field.read_chunk()  # 8192 bytes by default.
                    if not chunk:
                        break
                    f.write(chunk)
        except BaseException:
            # Don't leave a truncated file behind: it would make every
            # retry fail as "duplicate".
            try:
                os.remove(fullname)
            except OSError:
                pass
            raise
        log(__name__).info("%s - - File %s uploaded by %s",
                           self.peername(request), name, user)
        size = os.lstat(fullname).st_size
        self.redirect(request, user, directory,
                      "%d bytes saved as %s" % (size, fname))

    async def delete_file(self, request, user, dirname, filename):
        """Delete *filename* in *dirname*, then redirect (always raises)."""
        name = os.path.join(dirname, filename)
        fullname = os.path.join(self.config.get("global", "basedir"), name)
        os.remove(fullname)
        log(__name__).info("%s - - File %s deleted by %s",
                           self.peername(request), fullname, user)
        self.redirect(request, user, dirname, "File %s deleted" % name)

    @staticmethod
    def _byte_range(start, stop, size):
        """
        Resolve a Range slice against a file of *size* bytes.

        *start*/*stop* follow aiohttp's ``http_range`` slice: *stop* is
        exclusive, and a suffix range (``bytes=-N``) has a negative *start*.

        Returns:
            ``(offset, length, partial)``.  *partial* is False when no range
            was requested.  Returns None when the range is unsatisfiable.
        """
        if start is None:
            return 0, size, False
        if start < 0:
            start = max(size + start, 0)
        if stop is None or stop > size:
            stop = size
        if start >= stop:
            return None
        return start, stop - start, True

    async def streamfile(self, request, fullname):
        """
        Stream *fullname* to the client, honouring a single ``Range`` header.

        Responds 200 for the whole file, or 206 with ``Content-Range`` for a
        satisfiable range.  Raises 416 otherwise.  A client disconnect while
        streaming is ignored.
        """
        rng = request.http_range
        size = os.path.getsize(fullname)

        resolved = self._byte_range(rng.start, rng.stop, size)
        if resolved is None:
            raise web.HTTPRequestRangeNotSatisfiable(
                headers={"Content-Range": "bytes */%d" % size})
        offset, length, partial = resolved

        headers = {
            'Content-Type': get_mime_type(fullname),
            'Accept-Ranges': 'bytes'
        }
        if partial:
            headers["Content-Range"] = "bytes {0}-{1}/{2}".format(
                offset, offset + length - 1, size)
        resp = web.StreamResponse(status=206 if partial else 200,
                                  headers=headers)

        try:
            with open(fullname, 'rb') as f:
                f.seek(offset)
                await resp.prepare(request)
                while length > 0:
                    buf = f.read(min(8192, length))
                    if not buf:
                        break
                    length -= len(buf)
                    await resp.write(buf)

                await resp.write_eof()
        except ConnectionResetError:
            # client went away
            pass

        return resp

    async def privacy(self, request):
        """Redirect to the listing of *path*, showing ``[global] gdprmsg``."""
        path = request.match_info.get("path", "")
        log(__name__).debug("Path=%s", path)
        user, password, cred = await self.get_user(request)
        self.redirect(request, user, path,
                      self.config.get("global", "gdprmsg"))

    def normalize_path(self, path):
        """
        Map a URL path to ``(fullname, relative_dir, filename_or_None)``.

        Raises:
            AccessError: 403 outside basedir, 404 missing, 401 for other
                stat errors or for entries that are neither a directory nor
                a regular file.
        """
        # abspath() on basedir too, so a relative basedir or one with a
        # trailing separator no longer rejects every request.
        basedir = os.path.abspath(self.config.get("global", "basedir"))
        name = os.path.abspath(os.path.join(basedir, path))

        # NOTE(review): symlinks are not resolved, so a link inside basedir
        # can expose files outside it.  Use realpath() if that is unwanted.
        if os.path.commonpath([basedir, name]) != basedir:
            log(__name__).warning("Outside base: %s", name)
            raise AccessError(403, "forbidden")
        path = name[len(basedir):].strip(os.path.sep)

        try:
            sr = os.stat(name)
        except FileNotFoundError:
            log(__name__).warning("File not found: %s", name)
            raise AccessError(404, "not found")
        except Exception:
            log(__name__).error("Stat error on %s", name)
            raise AccessError(401, "forbidden")

        if sr.st_uid != os.geteuid():
            log(__name__).warning("Wrong owner found for %s", name)

        if stat.S_ISDIR(sr.st_mode):
            return name, path, None

        if stat.S_ISREG(sr.st_mode):
            dirname, fname = os.path.split(path)
            return name, dirname, fname

        raise AccessError(401, "forbidden")
예제 #4
0
def prepare_app(app,
                *,
                host=None,
                port=None,
                path=None,
                sock=None,
                shutdown_timeout=60.0,
                ssl_context=None,
                backlog=128,
                access_log_class=helpers.AccessLogger,
                access_log_format=helpers.AccessLogger.LOG_FORMAT,
                access_log=access_logger,
                handle_signals=True,
                reuse_address=None,
                reuse_port=None):
    """
    Set up an ``AppRunner`` for *app* and build, but do not start, its sites.

    This is a variant of ``aiohttp.web.run_app`` that stops before serving,
    so the caller can start several applications on one event loop.  *app*
    may also be a coroutine that produces the application.

    Sites are created in this order:
      * one TCP site per *host*, if given.  Otherwise, one TCP site on
        *port* when neither *path* nor *sock* is given, or when *port* is
        set.
      * one Unix-socket site per *path*;
      * one site per pre-bound *sock*.

    Returns:
        The set-up runner, with the unstarted sites in
        ``runner.prepared_sites``.
    """
    loop = asyncio.get_event_loop()

    if asyncio.iscoroutine(app):
        app = loop.run_until_complete(app)

    runner = AppRunner(app,
                       handle_signals=handle_signals,
                       access_log_class=access_log_class,
                       access_log_format=access_log_format,
                       access_log=access_log)
    loop.run_until_complete(runner.setup())

    # Options accepted by every site type ...
    common = dict(shutdown_timeout=shutdown_timeout,
                  ssl_context=ssl_context,
                  backlog=backlog)
    # ... plus the socket-reuse flags that only TCP sites take.
    tcp_opts = dict(common, reuse_address=reuse_address, reuse_port=reuse_port)
    scalar_types = (str, bytes, bytearray, memoryview)

    prepared = []

    if host is not None:
        hosts = [host] if isinstance(host, scalar_types) else host
        prepared.extend(TCPSite(runner, h, port, **tcp_opts) for h in hosts)
    elif (path is None and sock is None) or port is not None:
        prepared.append(TCPSite(runner, port=port, **tcp_opts))

    if path is not None:
        paths = [path] if isinstance(path, scalar_types) else path
        prepared.extend(UnixSite(runner, p, **common) for p in paths)

    if sock is not None:
        socks = sock if isinstance(sock, Iterable) else [sock]
        prepared.extend(SockSite(runner, s, **common) for s in socks)

    runner.prepared_sites = prepared
    return runner