async def __get_user_from_remote(self):
    """Pull the user list from the remote API and upsert it locally.

    The list returned by ``self.api.fetch_user_list()`` is handed unchanged to
    ``User.create_or_update_user_from_data_list``.

    :return: None
    """
    remote_users = await self.api.fetch_user_list()
    User.create_or_update_user_from_data_list(remote_users)
async def flush_metrics_to_remote(url):
    """Snapshot pending per-user metrics, reset them, and POST the snapshot.

    Inside an EXCLUSIVE transaction, every user with ``need_sync`` set is read
    (metric columns only). Then that user's ``ip_list`` and traffic counters
    are cleared and ``need_sync`` is turned off. The snapshot is sent to
    ``url`` as ``{"data": [...]}``.

    NOTE(review): the counters are reset before the POST is sent. If the
    request fails, the snapshotted metrics are not restored.
    """
    metric_columns = (
        User.user_id,
        User.ip_list,
        User.tcp_conn_num,
        User.upload_traffic,
        User.download_traffic,
    )
    # peewee expression (builds SQL), not a Python boolean comparison.
    pending = User.need_sync == True  # noqa: E712
    with db.atomic("EXCLUSIVE"):
        # Read and reset in one exclusive transaction so no increment lands
        # between the snapshot and the reset.
        snapshot = list(User.select(*metric_columns).where(pending))
        User.update(
            ip_list=set(),
            upload_traffic=0,
            download_traffic=0,
            need_sync=False,
        ).where(pending).execute()

    payload = [
        {
            "user_id": row.user_id,
            "ip_list": list(row.ip_list),
            "tcp_conn_num": row.tcp_conn_num,
            "upload_traffic": row.upload_traffic,
            "download_traffic": row.download_traffic,
        }
        for row in snapshot
    ]
    async with httpx.AsyncClient() as client:
        await client.post(url, json={"data": payload})
def __get_user_from_json(path):
    """Create or update User rows from a JSON configuration file.

    :param path: path to a JSON file whose top-level object has a ``"users"``
        entry; that entry is passed to
        ``User.create_or_update_user_from_data_list``.
    :return: None
    """
    # JSON text is UTF-8 (RFC 8259). Without an explicit encoding, open() uses
    # the locale's preferred encoding and can mis-decode non-ASCII user data.
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    User.create_or_update_user_from_data_list(data["users"])
async def start_remote_sync_server(self, api_endpoint, sync_time):
    """Sync users from ``api_endpoint``, start their servers, then reschedule.

    Any exception raised during a round is logged as a warning. The next
    round is still scheduled ``sync_time`` seconds later.
    """
    try:
        User.create_or_update_from_remote(api_endpoint)
        # TODO: record per-user traffic
        # UserServer.flush_metrics_to_remote(api_endpoint)
        for user in User.select():
            await self.loop.create_task(self.init_server(user))
    except Exception as e:
        logging.warning(f"sync user error {e}")
    # call_later() just calls its callback. Passing this coroutine function
    # directly would only create a coroutine object that is never awaited, so
    # the next round must be wrapped in a task. The lambda defers creating the
    # coroutine until the timer fires.
    self.loop.call_later(
        sync_time,
        lambda: self.loop.create_task(
            self.start_remote_sync_server(api_endpoint, sync_time)
        ),
    )
async def __report_user_stats(self):
    """Report stats of all non-deleted users and reset their counters.

    The users are loaded into memory first, then ``conn_ip_set`` and the
    traffic counters are reset in the database, and finally the loaded rows
    are passed to ``self.api.report_user_stats``.

    NOTE(review): ``user_data`` is now a list of User rows instead of a lazy
    query. Confirm ``report_user_stats`` only iterates it.

    :return: None
    """
    not_deleted = User.is_deleted == False  # noqa: E712 -- peewee expression
    # Materialize before resetting. A lazy query would be evaluated after the
    # reset and report zeroed counters.
    users = list(User.select().where(not_deleted))
    # peewee update queries do nothing until .execute() is called; without it
    # the counters were never reset.
    User.update(
        conn_ip_set=set(),
        upload_traffic=0,
        download_traffic=0,
        total_traffic=0,
    ).where(not_deleted).execute()
    await self.api.report_user_stats(user_data=users)
async def flush_metrics_to_remote(url):
    """POST pending per-user metrics to ``url``; reset them only on success.

    The payload is ``{"data": [...]}`` with one entry per user returned by
    ``User.get_need_sync_user_metrics()``. Network errors and non-2xx
    responses are logged as warnings. In that case
    ``User.reset_need_sync_user_traffic()`` is not called, so the metrics
    stay pending (presumably picked up again by the next flush).
    """
    data = [
        {
            "user_id": user.user_id,
            "ip_list": list(user.ip_list),
            "tcp_conn_num": user.tcp_conn_num,
            "upload_traffic": user.upload_traffic,
            "download_traffic": user.download_traffic,
        }
        for user in User.get_need_sync_user_metrics()
    ]
    async with httpx.AsyncClient() as client:
        try:
            response = await client.post(url, json={"data": data})
            # httpx does not raise on 4xx/5xx by itself. Without this check a
            # server-side rejection would still reset (and lose) the metrics.
            response.raise_for_status()
        except Exception as e:
            logging.warning(f"flush_metrics_to_remote error: {e}")
        else:
            User.reset_need_sync_user_traffic()
def __init__(
    self,
    user_port=None,
    access_user: User = None,
    ts_protocol=flag.TRANSPORT_TCP,
    peername=None,
):
    """Prepare a cipher manager for one connection.

    :param user_port: port the connection arrived on. Used to look up the
        cipher method when ``access_user`` is not given.
    :param access_user: the already-known user, if any.
    :param ts_protocol: transport protocol flag (defaults to TCP).
    :param peername: remote address of the connection, stored as-is.
    :raises Exception: if no user is bound to ``user_port`` (when
        ``access_user`` is not given), or the method is not in
        ``SUPPORT_METHODS``.
    """
    self.user_port = user_port
    self.access_user = access_user
    self.ts_protocol = ts_protocol
    self.peername = peername
    self.cipher = None
    self._buffer = bytearray()

    if self.access_user:
        self.method = access_user.method
    else:
        # NOTE: all users on a port must share one cipher method, so the
        # first user's method stands for the whole port.
        first_user = User.list_by_port(self.user_port).first()
        if first_user is None:
            # .first() yields None on an empty result; fail with a clear
            # message instead of an AttributeError on None.method.
            raise Exception(f"no user configured on port {self.user_port}")
        self.method = first_user.method

    self.cipher_cls = SUPPORT_METHODS.get(self.method)
    if not self.cipher_cls:
        raise Exception(f"暂时不支持这种加密方式:{self.method}")

    # Only AEAD ciphers over TCP need a fixed-size first chunk buffered
    # before the user can be identified.
    if self.cipher_cls.AEAD_CIPHER and self.ts_protocol == flag.TRANSPORT_TCP:
        self._first_data_len = self.cipher_cls.tcp_first_data_len()
    else:
        self._first_data_len = 0
def get_cipher_by_port(cls, port) -> CipherMan:
    """Build a cipher manager for the users bound to ``port``.

    When exactly one user is bound to the port, that user is passed as
    ``access_user``; otherwise ``access_user`` is None.

    NOTE(review): the user query itself (not the port) is passed as the first
    constructor argument. Confirm this matches the constructor's first
    parameter.
    """
    users = User.list_by_port(port)
    sole_user = users[0] if len(users) == 1 else None
    return cls(users, access_user=sole_user)
def get_cipher_by_port(cls, port, ts_protocol, peername) -> CipherMan:
    """Build a cipher manager for a connection on ``port``.

    The bound user is passed as ``access_user`` only when exactly one user
    is bound to the port; otherwise ``access_user`` is None.
    """
    users_on_port = User.list_by_port(port)
    if users_on_port.count() != 1:
        access_user = None
    else:
        access_user = users_on_port.first()
    return cls(
        port,
        access_user=access_user,
        ts_protocol=ts_protocol,
        peername=peername,
    )
def _find_access_user(self, first_data: bytes):
    """Find the user whose password authenticates ``first_data``.

    First rejects a salt already recorded in ``self.bf``, then records it.
    Then tries every user on ``self.user_port``. The first user whose cipher
    decrypts (TCP) or unpacks (other transports) ``first_data`` without a MAC
    failure becomes ``self.access_user``. If no user matches,
    ``self.access_user`` is left unchanged. ``self.last_access_user`` is set
    to the first user examined if it was not set yet.

    :raises RuntimeError: if the salt was seen before.
    :raises ValueError: any cipher ValueError other than a MAC check failure.
    """
    # Plain slice of the input; no memoryview is needed for this.
    salt = first_data[: self.cipher_cls.SALT_SIZE]
    if salt in self.bf:
        raise RuntimeError("repeated salt founded!")
    self.bf.add(salt)

    t1 = time.time()
    cnt = 0
    for user in User.list_by_port(self.user_port).iterator():
        if not self.last_access_user:
            self.last_access_user = user
        try:
            cnt += 1
            cipher = self.cipher_cls(user.password)
            with memoryview(first_data) as d:
                if self.ts_protocol == flag.TRANSPORT_TCP:
                    cipher.decrypt(d)
                else:
                    cipher.unpack(d)
        except ValueError as e:
            # Only an authentication failure means "wrong user, try the next".
            # Compare args[:1] so an argument-less ValueError is re-raised
            # instead of triggering an IndexError.
            if e.args[:1] != ("MAC check failed",):
                raise
            continue
        self.access_user = user
        break
    logging.info(
        f"用户:{self.access_user} 一共寻找了{ cnt }个user,共花费{(time.time()-t1)*1000}ms"
    )
def __init__(
    self,
    user_port=None,
    access_user: User = None,
    ts_protocol=flag.TRANSPORT_TCP,
):
    """Prepare a cipher manager for one connection.

    :param user_port: port the connection arrived on. Used to look up the
        cipher method when ``access_user`` is not given.
    :param access_user: the already-known user, if any.
    :param ts_protocol: transport protocol flag (defaults to TCP).
    :raises Exception: if no user is bound to ``user_port`` (when
        ``access_user`` is not given), or the method is not in
        ``self.SUPPORT_METHODS``.
    """
    self.user_port = user_port
    self.access_user = access_user
    self.ts_protocol = ts_protocol
    self.cipher = None
    self._buffer = bytearray()
    self.last_access_user = None

    if self.access_user:
        self.method = access_user.method
    else:
        # NOTE: all users on a port must share one cipher method, so the
        # first user's method stands for the whole port.
        first_user = User.list_by_port(self.user_port).first()
        if first_user is None:
            raise Exception(f"no user configured on port {self.user_port}")
        self.method = first_user.method

    self.cipher_cls = self.SUPPORT_METHODS.get(self.method)
    if not self.cipher_cls:
        # Without this check an unknown method crashed below with
        # AttributeError on None.AEAD_CIPHER.
        raise Exception(f"暂时不支持这种加密方式:{self.method}")

    # AEAD ciphers need a protocol-specific first chunk before the user can
    # be identified; other ciphers need none.
    if not self.cipher_cls.AEAD_CIPHER:
        self._first_data_len = 0
    elif self.ts_protocol == flag.TRANSPORT_TCP:
        self._first_data_len = self.cipher_cls.tcp_first_data_len()
    else:
        self._first_data_len = self.cipher_cls.udp_first_data_len()
def get_cipher_by_port(cls, port, ts_protocol) -> CipherMan:
    """Build a cipher manager for ``port`` using transport ``ts_protocol``.

    ``access_user`` is the port's user when exactly one user is bound to it,
    and None otherwise.
    """
    users_on_port = User.list_by_port(port)
    single_user = users_on_port.first() if users_on_port.count() == 1 else None
    return cls(port, access_user=single_user, ts_protocol=ts_protocol)
def decrypt(self, data: bytes):
    """Decrypt an incoming chunk, identifying the user on first contact.

    While no user is known, input is accumulated in ``self._buffer``. If the
    buffered total is still shorter than ``self._first_data_len``, nothing is
    returned (``None``). Once enough data is buffered:

    - the salt is checked against ``self.bf`` for replays;
    - the user is looked up with ``User.find_access_user``;
    - the whole buffer becomes the data to decrypt.

    ``self.cipher`` is created lazily from the user's password, and
    ``self.record_user_traffic(len(data), 0)`` is called before decryption.
    TCP data goes through ``self.cipher.decrypt``. Other transports are
    unpacked with a new cipher instance on every call.

    :raises RuntimeError: on a repeated salt, or when no enabled user matches.
    """
    buffered_total = len(data) + len(self._buffer)
    if self.access_user is None and buffered_total < self._first_data_len:
        self._buffer.extend(data)
        return

    if not self.access_user:
        self._buffer.extend(data)
        is_tcp = self.ts_protocol == flag.TRANSPORT_TCP
        head = self._buffer[: self._first_data_len] if is_tcp else self._buffer

        salt = head[: self.cipher_cls.SALT_SIZE]
        if salt in self.bf:
            raise RuntimeError("repeated salt founded!")
        self.bf.add(salt)

        candidate = User.find_access_user(
            self.user_port, self.method, self.ts_protocol, head
        )
        if not candidate or candidate.enable is False:
            raise RuntimeError(
                f"can not find enable access user: {self.user_port}-{self.ts_protocol}-{self.cipher_cls}"
            )
        self.access_user = candidate
        data = bytes(self._buffer)

    if not self.cipher:
        self.cipher = self.cipher_cls(self.access_user.password)

    self.record_user_traffic(len(data), 0)

    if self.ts_protocol != flag.TRANSPORT_TCP:
        return self.cipher_cls(self.access_user.password).unpack(data)
    return self.cipher.decrypt(data)
async def sync_from_remote(self):
    """Sync users with the remote API and reconcile the running servers.

    Flushes metrics and pulls users from ``self.api_endpoint``. Then it
    starts a server for every enabled user and closes the server of every
    disabled user. Finally it schedules the next round ``self.sync_time``
    seconds later. Failures are logged per step and per user, so one bad
    user can no longer abort the round before it is rescheduled.
    """
    try:
        # NOTE(review): if User.flush_metrics_to_remote is a coroutine
        # function it must be awaited here; confirm against its definition.
        User.flush_metrics_to_remote(self.api_endpoint)
        User.create_or_update_from_remote(self.api_endpoint)
    except Exception as e:
        logging.warning(f"sync user error {e}")

    for user in User.select().where(User.enable == True):  # noqa: E712
        try:
            await self.loop.create_task(self.init_server(user))
        except Exception as e:
            logging.warning(f"init user server error {user}: {e}")
    for user in User.select().where(User.enable == False):  # noqa: E712
        try:
            self.close_user_server(user)
        except Exception as e:
            logging.warning(f"close user server error {user}: {e}")

    # Create the next coroutine only when the timer fires. Creating it now
    # would leave an un-awaited coroutine behind if the loop shuts down first.
    self.loop.call_later(
        self.sync_time,
        lambda: self.loop.create_task(self.sync_from_remote()),
    )
def start_ss_cron_job(sync_time, use_json=False):
    """Run one user-sync round, then reschedule itself ``sync_time`` seconds later.

    With ``use_json`` the users come from ``userconfigs.json``. Otherwise they
    are pulled from the remote and server data is flushed back. Either way
    ``User.init_user_servers()`` runs afterwards. Errors in the round are
    logged as warnings and do not stop the rescheduling.
    """
    from shadowsocks.mdb.models import User, UserServer

    event_loop = asyncio.get_event_loop()
    try:
        if not use_json:
            User.create_or_update_from_remote()
            UserServer.flush_data_to_remote()
        else:
            User.create_or_update_from_json("userconfigs.json")
        User.init_user_servers()
    except Exception as e:
        logging.warning(f"sync user error {e}")
    event_loop.call_later(sync_time, start_ss_cron_job, sync_time, use_json)
async def start_and_check_ss_server(self):
    """Start the ss servers and periodically check whether new ones are needed.

    Syncs users (from JSON or from the remote, per ``self.use_json``), then
    starts a server for every enabled user. If any start fails, the error is
    logged and the event loop is stopped. The next check is scheduled
    ``self.sync_time`` seconds later.

    TODO: shut down servers that are no longer needed.
    """
    sync_users = self.sync_from_json_cron if self.use_json else self.sync_from_remote_cron
    await sync_users()

    for user in User.select().where(User.enable == True):  # noqa: E712
        try:
            await self.init_server(user)
        except Exception as e:
            logging.error(e)
            self.loop.stop()

    self.loop.call_later(
        self.sync_time,
        self.loop.create_task,
        self.start_and_check_ss_server(),
    )
async def start_and_check_ss_server(self):
    """Start the ss servers and periodically check whether new ones are needed.

    Syncs users (from JSON or from the remote, per ``self.use_json``), then
    starts a server for every user not marked deleted. A failed start is
    logged with its traceback and stops the event loop. The next check is
    scheduled ``self.sync_time`` seconds later.

    TODO: shut down servers that are no longer needed.

    :return: None
    """
    # Written as attribute accesses inside the class body so that Python's
    # private-name mangling of ``__sync_*`` still applies.
    sync_users = self.__sync_from_json if self.use_json else self.__sync_from_remote
    await sync_users()

    for user in User.select().where(User.is_deleted == False):  # noqa: E712
        try:
            await self.__init_server(user)
        except Exception as e:
            logger.exception(e)
            self.loop.stop()

    self.loop.call_later(
        self.sync_time,
        self.loop.create_task,
        self.start_and_check_ss_server(),
    )
async def start_ss_json_server(self):
    """Load users from ``userconfigs.json`` and start a server for each enabled one."""
    User.create_or_update_from_json("userconfigs.json")
    enabled_users = User.select().where(User.enable == True)  # noqa: E712
    for user in enabled_users:
        await self.loop.create_task(self.init_server(user))
def get_cipher_by_port(cls, port) -> CipherMan:
    """Build a cipher manager for the single user bound to ``port``.

    :raises ValueError: if no user, or more than one user, is bound to
        ``port``.
    """
    user_list = User.list_by_port(port)
    user_count = len(user_list)
    # The old single check reported "multiple users found on one port" even
    # when the port had no user at all.
    if user_count == 0:
        raise ValueError(f"no user found on port {port}")
    if user_count > 1:
        raise ValueError("单个端口找到了多个用户")
    return cls(user_list[0])
def create_or_update_from_json(path):
    """Create or update User rows from a JSON configuration file.

    :param path: path to a JSON file whose top-level object has a ``"users"``
        entry; that entry is passed to
        ``User.create_or_update_by_user_data_list``.
    """
    # JSON text is UTF-8 (RFC 8259). Don't let the locale's default encoding
    # mis-decode non-ASCII user data.
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    User.create_or_update_by_user_data_list(data["users"])
async def start_ss_server(self):
    """Start servers for enabled users, then close servers of disabled users."""
    enabled_users = User.select().where(User.enable == True)  # noqa: E712
    for user in enabled_users:
        await self.loop.create_task(self.init_server(user))

    disabled_users = User.select().where(User.enable == False)  # noqa: E712
    for user in disabled_users:
        self.close_user_server(user)
def test_find_access_user(app):
    """Look up the first user on port 1025 using chacha20-ietf-poly1305."""
    # Model.select() takes the columns to fetch. The filter conditions were
    # being selected as boolean columns, leaving the query unfiltered. They
    # belong in where(), which ANDs its arguments.
    users = User.select().where(
        User.port == 1025, User.method == "chacha20-ietf-poly1305"
    )
    first_user = users.first()
    print(first_user)
async def get_user_from_remote(url):
    """Fetch the user list from ``url`` and upsert it into the User table.

    :raises httpx.HTTPStatusError: if the server answers with a 4xx/5xx
        status.
    """
    async with httpx.AsyncClient() as client:
        res = await client.get(url)
    # Fail with a clear HTTP error rather than a JSON-decode error or
    # KeyError when the server returns an error page.
    res.raise_for_status()
    User.create_or_update_by_user_data_list(res.json()["users"])
def shutdown():
    """Shut down every user server, then stop the event loop ``loop``."""
    User.shutdown_user_servers()
    loop.stop()
def app():
    """Create and prepare an ``App``, run ``User.sync_from_json_cron(10)``, and return the app."""
    instance = App()
    instance._prepare()
    User.sync_from_json_cron(10)
    return instance
async def sync_from_json_cron(self, sync_time=None):
    """Create or update users from ``userconfigs.json``; log failures as warnings.

    :param sync_time: not used by this method. It is optional so that callers
        passing no argument (``await self.sync_from_json_cron()``) work as well
        as callers passing an interval.
    """
    try:
        User.create_or_update_from_json("userconfigs.json")
    except Exception as e:
        logging.warning(f"sync user from json error {e}")