Exemple #1
0
class App:
    """Bootstrap and run the shadowsocks services on a single asyncio loop.

    Settings are read from ``SS_*`` environment variables during
    ``_prepare()``. ``run()`` schedules the proxy server plus the optional
    gRPC and metrics servers, then blocks in ``loop.run_forever()`` until
    SIGTERM or Ctrl-C triggers ``shutdown()``.
    """

    def __init__(self, debug=False):
        """Create the application.

        Args:
            debug: When true, uvloop is not installed and logging is forced
                to DEBUG level.
        """
        self.debug = debug
        if not self.debug:
            uvloop.install()

        self.loop = asyncio.get_event_loop()
        self.prepared = False
        self.proxyman = ProxyMan()
        # Assigned by start_grpc_server() / start_metrics_server(). They stay
        # None until those coroutines have run, so shutdown() can be invoked
        # at any moment without tripping over a missing attribute.
        self.grpc_server = None
        self.metrics_server = None

    def _init_config(self):
        """Read the ``SS_*`` environment variables and mirror them on self."""
        self.config = {
            "GRPC_HOST": os.getenv("SS_GRPC_HOST", "127.0.0.1"),
            "GRPC_PORT": os.getenv("SS_GRPC_PORT", "5000"),
            "SENTRY_DSN": os.getenv("SS_SENTRY_DSN"),
            "API_ENDPOINT": os.getenv("SS_API_ENDPOINT"),
            "LOG_LEVEL": os.getenv("SS_LOG_LEVEL", "info"),
            "SYNC_TIME": int(os.getenv("SS_SYNC_TIME", 60)),
            "STREAM_DNS_SERVER": os.getenv("SS_STREAM_DNS_SERVER"),
            "METRICS_PORT": os.getenv("SS_METRICS_PORT"),
            "TIME_OUT_LIMIT": int(os.getenv("SS_TIME_OUT_LIMIT", 60)),
            "USER_TCP_CONN_LIMIT": int(os.getenv("SS_TCP_CONN_LIMIT", 60)),
        }

        self.grpc_host = self.config["GRPC_HOST"]
        self.grpc_port = self.config["GRPC_PORT"]
        self.log_level = self.config["LOG_LEVEL"]
        self.sync_time = self.config["SYNC_TIME"]
        self.sentry_dsn = self.config["SENTRY_DSN"]
        self.api_endpoint = self.config["API_ENDPOINT"]
        self.timeout_limit = self.config["TIME_OUT_LIMIT"]
        self.stream_dns_server = self.config["STREAM_DNS_SERVER"]
        self.user_tcp_conn_limit = self.config["USER_TCP_CONN_LIMIT"]

        self.metrics_port = self.config["METRICS_PORT"]
        self.use_sentry = bool(self.sentry_dsn)
        # No API endpoint configured -> run() starts the JSON-backed server
        # instead of the remote sync server.
        self.use_json = not self.api_endpoint
        self.use_grpc = bool(self.grpc_host and self.grpc_port)

    def _init_logger(self):
        """Configure root logging.

        ``debug`` forces DEBUG; otherwise ``SS_LOG_LEVEL`` is used, and an
        unrecognised level name falls back to DEBUG.
        """
        log_levels = {
            "CRITICAL": logging.CRITICAL,
            "ERROR": logging.ERROR,
            "WARNING": logging.WARNING,
            "INFO": logging.INFO,
            "DEBUG": logging.DEBUG,
        }
        if self.debug:
            level = logging.DEBUG
        else:
            level = log_levels.get(self.log_level.upper(), logging.DEBUG)
        logging.basicConfig(
            format="[%(levelname)s]%(asctime)s - %(filename)s - %(funcName)s "
            "line:%(lineno)d: - %(message)s",
            level=level,
        )

    def _init_memory_db(self):
        """Create a table for every concrete ``BaseModel`` subclass in ``models``."""
        from shadowsocks.mdb import BaseModel, models

        for _, model in inspect.getmembers(models, inspect.isclass):
            # BaseModel itself is only the shared parent, not a real table.
            if issubclass(model, BaseModel) and model != BaseModel:
                model.create_table()
                logging.info(f"正在创建{model}内存数据库")

    # def _init_sentry(self):
    # TODO: upgrade to the latest sentry-sdk
    # def sentry_exception_handler(loop, context):
    #     try:
    #         raise context["exception"]
    #     except TimeoutError:
    #         logging.error(f"socket timeout msg: {context['message']}")
    #     except Exception:
    #         logging.error(f"unhandled error msg: {context['message']}")
    #         self.sentry_client.captureException(**context)

    # if not self.use_sentry:
    #     return
    # self.sentry_client = raven.Client(self.sentry_dsn, transport=AioHttpTransport)
    # self.loop.set_exception_handler(sentry_exception_handler)
    # logging.info("Init Sentry Client...")

    def _prepare(self):
        """Run the one-time initialisation steps; later calls are no-ops."""
        if self.prepared:
            return
        self._init_config()
        self._init_logger()
        self._init_memory_db()
        # self._init_sentry()
        self.loop.add_signal_handler(signal.SIGTERM, self.shutdown)
        self.prepared = True

    async def start_grpc_server(self):
        """Start the gRPC server on ``grpc_host:grpc_port``."""
        from shadowsocks.services import AioShadowsocksServicer

        self.grpc_server = Server([AioShadowsocksServicer()], loop=self.loop)
        await self.grpc_server.start(self.grpc_host, self.grpc_port)
        logging.info(f"Start Grpc Server on {self.grpc_host}:{self.grpc_port}")

    async def start_metrics_server(self):
        """Serve ``/metrics`` over HTTP on ``0.0.0.0:metrics_port``."""
        app = web.Application()
        app.router.add_get("/metrics", aio.web.server_stats)
        runner = web.AppRunner(app)
        await runner.setup()
        self.metrics_server = web.TCPSite(runner, "0.0.0.0", self.metrics_port)
        await self.metrics_server.start()
        logging.info(
            f"Start Metrics Server At: http://0.0.0.0:{self.metrics_port}/metrics"
        )

    def run(self):
        """Prepare the app, schedule all servers and block until shutdown."""
        self._prepare()

        if self.use_json:
            self.loop.create_task(self.proxyman.start_ss_json_server())
        else:
            self.loop.create_task(
                self.proxyman.start_remote_sync_server(
                    self.api_endpoint, self.sync_time
                )
            )

        if self.use_grpc:
            self.loop.create_task(self.start_grpc_server())

        if self.metrics_port:
            self.loop.create_task(self.start_metrics_server())

        try:
            self.loop.run_forever()
        except KeyboardInterrupt:
            logging.info("正在关闭所有ss server")
            self.shutdown()

    def shutdown(self):
        """Stop every server and then the event loop.

        Works both from inside the running loop (the SIGTERM handler) and
        after ``run_forever()`` has already returned (the Ctrl-C path).
        """
        self.proxyman.close_server()
        if self.grpc_server is not None:
            self.grpc_server.close()
            logging.info("grpc server closed!")

        if self.loop.is_running():
            # run_until_complete() raises RuntimeError on a running loop, and
            # work scheduled right before loop.stop() is never executed, so
            # hand the rest of the cleanup to a task that stops the loop
            # itself once it is done.
            self.loop.create_task(self._stop_metrics_then_loop())
            return

        if self.metrics_server is not None:
            self.loop.create_task(self.metrics_server.stop())
            logging.info("metrics server closed!")
        # NOTE(review): this waits for every outstanding task; presumably
        # proxyman.close_server() lets the long-running ones finish - confirm.
        pending = asyncio.all_tasks(self.loop)
        if pending:
            self.loop.run_until_complete(asyncio.gather(*pending))
        self.loop.stop()

    async def _stop_metrics_then_loop(self):
        """Finish a shutdown that was requested from inside the running loop."""
        try:
            if self.metrics_server is not None:
                await self.metrics_server.stop()
                logging.info("metrics server closed!")
        finally:
            # Always stop the loop, even if stopping the metrics site failed.
            self.loop.stop()
Exemple #2
0
class App:
    """Entry point that wires the proxy, gRPC and metrics servers together.

    Configuration and logging are set up in the constructor from ``SS_*``
    environment variables; the event loop, in-memory tables, sentry and the
    ``ProxyMan`` are created lazily by ``_prepare()``.
    """

    def __init__(self) -> None:
        self._init_config()
        self._init_logger()
        self._prepared = False

    def _init_config(self):
        """Read the ``SS_*`` environment variables and mirror them on self."""
        self.config = {
            "LISTEN_HOST": os.getenv("SS_LISTEN_HOST", "0.0.0.0"),
            "GRPC_HOST": os.getenv("SS_GRPC_HOST", "127.0.0.1"),
            "GRPC_PORT": os.getenv("SS_GRPC_PORT", "5000"),
            "SENTRY_DSN": os.getenv("SS_SENTRY_DSN"),
            "API_ENDPOINT": os.getenv("SS_API_ENDPOINT"),
            "LOG_LEVEL": os.getenv("SS_LOG_LEVEL", "info"),
            "SYNC_TIME": int(os.getenv("SS_SYNC_TIME", 60)),
            "STREAM_DNS_SERVER": os.getenv("SS_STREAM_DNS_SERVER"),
            "METRICS_PORT": os.getenv("SS_METRICS_PORT"),
            "TIME_OUT_LIMIT": int(os.getenv("SS_TIME_OUT_LIMIT", 60)),
            "USER_TCP_CONN_LIMIT": int(os.getenv("SS_TCP_CONN_LIMIT", 60)),
        }

        self.grpc_host = self.config["GRPC_HOST"]
        self.grpc_port = self.config["GRPC_PORT"]
        self.log_level = self.config["LOG_LEVEL"]
        self.sync_time = self.config["SYNC_TIME"]
        self.sentry_dsn = self.config["SENTRY_DSN"]
        self.listen_host = self.config["LISTEN_HOST"]
        self.api_endpoint = self.config["API_ENDPOINT"]
        self.timeout_limit = self.config["TIME_OUT_LIMIT"]
        self.stream_dns_server = self.config["STREAM_DNS_SERVER"]
        self.user_tcp_conn_limit = self.config["USER_TCP_CONN_LIMIT"]
        self.metrics_port = self.config["METRICS_PORT"]

        self.use_sentry = bool(self.sentry_dsn)
        # use_json is simply the absence of an API endpoint; it is passed on
        # to ProxyMan in _prepare().
        self.use_json = not self.api_endpoint
        # Set by the _start_* coroutines once the servers are up.
        self.metrics_server = None
        self.grpc_server = None

    def _init_logger(self):
        """Configure root logging from ``SS_LOG_LEVEL``.

        An unrecognised level name falls back to INFO (the configured
        default) and is reported, instead of raising ``KeyError`` from the
        constructor.
        """
        log_levels = {
            "CRITICAL": logging.CRITICAL,
            "ERROR": logging.ERROR,
            "WARNING": logging.WARNING,
            "INFO": logging.INFO,
            "DEBUG": logging.DEBUG,
        }
        requested = self.log_level.upper()
        logging.basicConfig(
            format="[%(levelname)s]%(asctime)s %(funcName)s line:%(lineno)d %(message)s",
            level=log_levels.get(requested, logging.INFO),
        )
        if requested not in log_levels:
            logging.warning(
                "Unknown SS_LOG_LEVEL %r, falling back to INFO", self.log_level
            )

    def _init_memory_db(self):
        """Create a table for every concrete ``BaseModel`` subclass in ``models``."""
        for _, model in inspect.getmembers(models, inspect.isclass):
            # BaseModel itself is only the shared parent, not a real table.
            if issubclass(model, BaseModel) and model != BaseModel:
                model.create_table()
                logging.info(f"正在创建{model}内存数据库")

    def _init_sentry(self):
        """Initialise sentry-sdk when a DSN is configured; otherwise do nothing."""
        if not self.use_sentry:
            return
        sentry_sdk.init(dsn=self.sentry_dsn, integrations=[AioHttpIntegration()])
        logging.info("Init Sentry Client...")

    def _prepare(self):
        """Run the one-time runtime setup; later calls are no-ops."""
        if self._prepared:
            return
        self.loop = asyncio.get_event_loop()
        self._init_memory_db()
        self._init_sentry()
        self.loop.add_signal_handler(signal.SIGTERM, self._shutdown)
        self.proxyman = ProxyMan(
            self.use_json, self.sync_time, self.listen_host, self.api_endpoint
        )
        self._prepared = True

    def _shutdown(self):
        """Stop every server and then the event loop.

        Works both from inside the running loop (the SIGTERM handler) and
        after ``run_forever()`` has returned (the Ctrl-C path).
        """
        logging.info("正在关闭所有ss server")
        self.proxyman.close_server()
        if self.grpc_server is not None:
            self.grpc_server.close()
            logging.info("grpc server closed!")

        if self.loop.is_running():
            # A task created right before loop.stop() never gets to run, so
            # let the task itself stop the loop after the metrics site is
            # down.
            self.loop.create_task(self._stop_metrics_then_loop())
            return

        if self.metrics_server is not None:
            # The loop is idle here, so drive the stop coroutine to the end.
            self.loop.run_until_complete(self.metrics_server.stop())
            logging.info("metrics server closed!")
        self.loop.stop()

    async def _stop_metrics_then_loop(self):
        """Finish a shutdown that was requested from inside the running loop."""
        try:
            if self.metrics_server is not None:
                await self.metrics_server.stop()
                logging.info("metrics server closed!")
        finally:
            # Always stop the loop, even if stopping the metrics site failed.
            self.loop.stop()

    def _run_loop(self):
        """Block in the event loop until it is stopped or Ctrl-C is pressed."""
        try:
            self.loop.run_forever()
        except KeyboardInterrupt:
            self._shutdown()

    async def _start_grpc_server(self):
        """Start the gRPC server with request logging attached."""
        self.grpc_server = Server([AioShadowsocksServicer()], loop=self.loop)
        listen(self.grpc_server, RecvRequest, logging_grpc_request)
        await self.grpc_server.start(self.grpc_host, self.grpc_port)
        logging.info(f"Start grpc Server on {self.grpc_host}:{self.grpc_port}")

    async def _start_metrics_server(self):
        """Serve ``/metrics`` over HTTP on ``0.0.0.0:metrics_port``."""
        app = web.Application()
        app.router.add_get("/metrics", aio.web.server_stats)
        runner = web.AppRunner(app)
        await runner.setup()
        self.metrics_server = web.TCPSite(runner, "0.0.0.0", self.metrics_port)
        await self.metrics_server.start()
        logging.info(
            f"Start Metrics Server At: http://0.0.0.0:{self.metrics_port}/metrics"
        )

    def run_ss_server(self):
        """Run the proxy server plus the optional metrics and gRPC servers."""
        self._prepare()
        self.loop.create_task(self.proxyman.start_and_check_ss_server())
        if self.metrics_port:
            self.loop.create_task(self._start_metrics_server())

        if self.grpc_host and self.grpc_port:
            self.loop.create_task(self._start_grpc_server())

        self._run_loop()

    def run_grpc_server(self):
        """Run only the gRPC server.

        Raises:
            RuntimeError: If the gRPC host or port is not configured.
        """
        self._prepare()

        if not (self.grpc_host and self.grpc_port):
            raise RuntimeError("grpc server not config")
        self.loop.create_task(self._start_grpc_server())

        self._run_loop()

    def get_user(self, user_id):
        """Fetch ``user_id`` through an ``SSClient`` bound to the gRPC address."""
        client = SSClient(f"{self.grpc_host}:{self.grpc_port}")
        return client.get_user(user_id)
Exemple #3
0
class Server(object):
    """Standalone shadowsocks server runner.

    ``start()`` lazily loads the configuration from environment variables,
    configures logging and the event loop, creates the in-memory tables and
    then runs ``ProxyMan`` until interrupted.
    """

    def __init__(self) -> None:
        self.__prepared = False

    def __init_config(self) -> None:
        """Read the environment variables and mirror them on self."""
        self.config = {
            "LISTEN_HOST": os.getenv("SS_LISTEN_HOST", "0.0.0.0"),
            "SENTRY_DSN": os.getenv("SS_SENTRY_DSN"),
            "API_ENDPOINT": os.getenv("SS_API_ENDPOINT"),
            "LOG_LEVEL": os.getenv("SS_LOG_LEVEL", "INFO"),
            "SYNC_TIME": int(os.getenv("SS_SYNC_TIME", 60)),
            "STREAM_DNS_SERVER": os.getenv("SS_STREAM_DNS_SERVER"),
            "TIME_OUT_LIMIT": int(os.getenv("SS_TIME_OUT_LIMIT", 60)),
            "USER_TCP_CONN_LIMIT": int(os.getenv("SS_TCP_CONN_LIMIT", 60)),
            # NOTE(review): unlike every other setting this one has no SS_
            # prefix - confirm that is intentional.
            "PANEL_TYPE": os.getenv("PANEL_TYPE", None),
        }
        self.log_level = self.config["LOG_LEVEL"]
        self.sync_time = self.config["SYNC_TIME"]
        self.sentry_dsn = self.config["SENTRY_DSN"]
        self.listen_host = self.config["LISTEN_HOST"]
        self.api_endpoint = self.config["API_ENDPOINT"]
        self.panel_type = self.config["PANEL_TYPE"]
        self.timeout_limit = self.config["TIME_OUT_LIMIT"]
        self.stream_dns_server = self.config["STREAM_DNS_SERVER"]
        self.user_tcp_conn_limit = self.config["USER_TCP_CONN_LIMIT"]

        self.use_sentry = bool(self.sentry_dsn)
        # use_json is true only when neither an API endpoint nor a panel
        # type is configured.
        self.use_json = not (self.api_endpoint or self.panel_type)

    def __prepare_logger(self) -> None:
        """Configure root logging from ``SS_LOG_LEVEL``.

        An unrecognised level name falls back to INFO (the configured
        default) and is reported, instead of raising ``KeyError``.
        """
        log_levels = {
            "CRITICAL": logging.CRITICAL,
            "ERROR": logging.ERROR,
            "WARNING": logging.WARNING,
            "INFO": logging.INFO,
            "DEBUG": logging.DEBUG,
        }
        requested = self.log_level.upper()
        logging.basicConfig(
            format="%(asctime)s | %(levelname)-8s | %(name)s:%(funcName)s:%(lineno)s - %(message)s",
            level=log_levels.get(requested, logging.INFO),
        )
        if requested not in log_levels:
            logging.warning(
                "Unknown SS_LOG_LEVEL %r, falling back to INFO", self.log_level
            )

    @staticmethod
    def __init_memory_db() -> None:
        """Create a table for every peewee model class found in ``models``."""
        for _, model in inspect.getmembers(models, inspect.isclass):
            # Skip peewee.Model itself in case ``models`` imports it by name:
            # it is the abstract base, not a table of its own.
            if issubclass(model, peewee.Model) and model is not peewee.Model:
                model.create_table()
                logger.info(f"正在创建 {model} 内存数据库")

    def __init_loop(self) -> None:
        """Install uvloop everywhere except Windows, then grab the event loop."""
        is_win = platform.system().lower() == "windows"
        if not is_win:
            # Imported lazily so the module still loads on Windows, where the
            # uvloop branch is skipped.
            import uvloop

            logger.info("使用 uvloop 加速")
            uvloop.install()
        else:
            logger.info("使用原生 asyncio")

        self.loop = asyncio.get_event_loop()

    def __prepare(self) -> None:
        """Run the one-time setup; later calls are no-ops."""
        if self.__prepared:
            return

        self.__init_config()
        self.__prepare_logger()
        self.__init_loop()
        self.__init_memory_db()

        self.proxy_man = ProxyMan(
            use_json=self.use_json,
            sync_time=self.sync_time,
            listen_host=self.listen_host,
            panel_type=self.panel_type,
            endpoint=self.api_endpoint,
        )

        self.__prepared = True

    def __shutdown(self) -> None:
        """Close the proxy servers and stop the event loop."""
        logger.info("正在关闭所有 Shadowsocks 服务")
        self.proxy_man.close_server()
        self.loop.stop()

    def __run_loop(self) -> None:
        """Block in the event loop; Ctrl-C triggers a shutdown."""
        try:
            self.loop.run_forever()
        except KeyboardInterrupt:
            self.__shutdown()

    def start(self) -> None:
        """Prepare the server, schedule the proxy task and run the loop."""
        self.__prepare()
        self.loop.create_task(self.proxy_man.start_and_check_ss_server())
        self.__run_loop()