class LocalHandler(TimeoutHandler):
    """Relay one client connection (TCP or UDP) to its target address.

    The handler is driven by a small state machine:

    STAGE_INIT     initial state; the first decrypted data carries the target
                   address header (socks5-style handshake)
    STAGE_CONNECT  the connection to the target is being established
                   (address taken from the client, DNS resolution)
    STAGE_STREAM   the pipe is up; data is relayed in both directions
    STAGE_DESTROY  the connection has ended
    STAGE_ERROR    an unexpected error occurred
    """

    STAGE_INIT = 0
    STAGE_CONNECT = 1
    STAGE_STREAM = 2
    STAGE_DESTROY = -1
    STAGE_ERROR = 255

    # Seconds to keep client payload buffered while the remote TCP connection
    # is still being established before the connection is dropped.
    CONNECT_TIMEOUT = 8.5

    def __init__(self, user):
        TimeoutHandler.__init__(self)
        self.user = user
        self._stage = self.STAGE_DESTROY
        self._peername = None
        self._remote = None
        self._cryptor = None
        self._transport = None
        self._transport_protocol = None
        # Payload received during STAGE_CONNECT, kept in arrival order.
        self._connect_buffer = bytearray()
        # Task that drops the connection if STAGE_CONNECT lasts too long.
        self._connect_watchdog = None

    def _init_transport_and_cryptor(self, transport, peername, protocol):
        """Bind the local transport and build the cryptor for the user's cipher.

        Closes the connection when the configured cipher is not supported.
        """
        self._stage = self.STAGE_INIT
        self._transport = transport
        self._peername = peername
        self._transport_protocol = protocol
        try:
            self._cryptor = Cryptor(
                self.user.method, self.user.password, self._transport_protocol
            )
            logging.debug("tcp connection made")
        except NotImplementedError:
            self.close()
            logging.warning("not support cipher")

    def close(self):
        """Close the local transport and the remote side (TCP only).

        Raises:
            NotImplementedError: for an unknown transport protocol.
        """
        if self._transport_protocol == flag.TRANSPORT_TCP:
            if self._transport:
                self._transport.close()
            if self._remote:
                self._remote.close()
        elif self._transport_protocol == flag.TRANSPORT_UDP:
            pass
        else:
            raise NotImplementedError

    def write(self, data):
        """Send ``data`` back to the client and record it as download traffic.

        If the local transport is missing or already closing, it is aborted
        and the data is dropped.
        """
        if not self._transport or self._transport.is_closing():
            if self._transport:
                self._transport.abort()
            return
        self.user.server.record_traffic(used_u=0, used_d=len(data))
        if self._transport_protocol == flag.TRANSPORT_TCP:
            self._transport.write(data)
        elif self._transport_protocol == flag.TRANSPORT_UDP:
            # UDP replies are addressed to the client's peername
            self._transport.sendto(data, self._peername)
        else:
            raise NotImplementedError

    def handle_tcp_connection_made(self, transport, peername):
        self._init_transport_and_cryptor(transport, peername, flag.TRANSPORT_TCP)
        self.check_conn_timeout()
        self.user.server.record_ip(peername)

    def handle_udp_connection_made(self, transport, peername):
        self._init_transport_and_cryptor(transport, peername, flag.TRANSPORT_UDP)
        self.user.server.record_ip(peername)

    def handle_eof_received(self):
        self.close()
        logging.debug("eof received")

    def handle_connection_lost(self, exc):
        self.close()
        logging.debug(f"lost exc={exc}")

    def handle_data_received(self, data):
        """Decrypt client data, record upload traffic and dispatch by stage."""
        try:
            data = self._cryptor.decrypt(data)
        except Exception as e:
            self.close()
            logging.warning(f"decrypt data error {e}")
            return
        self.user.server.record_traffic(used_u=len(data), used_d=0)

        if self._stage == self.STAGE_INIT:
            asyncio.create_task(self._handle_stage_init(data))
        elif self._stage == self.STAGE_CONNECT:
            # Buffered synchronously (not in a task) so chunks keep their
            # arrival order relative to each other and to later stream data.
            self._handle_stage_connect(data)
        elif self._stage == self.STAGE_STREAM:
            self._handle_stage_stream(data)
        elif self._stage == self.STAGE_ERROR:
            self._handle_stage_error()
        else:
            logging.warning(f"unknown stage:{self._stage}")

    async def _handle_stage_init(self, data):
        """Parse the target address header and connect to the target."""
        if not data:
            # Nothing decrypted yet (presumably only part of a cipher chunk
            # has arrived); stay in STAGE_INIT and wait for more data.
            return
        try:
            addr_type, dst_addr, dst_port, header_length = parse_header(data)
        except Exception as e:
            self.close()
            logging.warning(f"parse header error: {str(e)}")
            return
        if not dst_addr:
            self.close()
            logging.warning(
                "can't parse addr_type: {} user: {} CMD: {}".format(
                    addr_type, self.user, self._transport_protocol
                )
            )
            return
        payload = data[header_length:]
        loop = asyncio.get_event_loop()

        if self._transport_protocol == flag.TRANSPORT_TCP:
            self._stage = self.STAGE_CONNECT
            # On success create_connection returns (transport, protocol)
            tcp_coro = loop.create_connection(
                lambda: RemoteTCP(dst_addr, dst_port, payload, self),
                dst_addr,
                dst_port,
            )
            try:
                remote_transport, remote_tcp = await tcp_coro
            except OSError as e:
                self.close()
                self._stage = self.STAGE_DESTROY
                logging.debug(f"connection failed , {type(e)} e: {e}")
            except Exception as e:
                self._stage = self.STAGE_ERROR
                self.close()
                logging.warning(f"connection failed, {type(e)} e: {e}")
            else:
                if self._transport.is_closing():
                    # The client went away while we were connecting: close the
                    # new remote connection instead of leaking it.
                    remote_transport.close()
                    self._stage = self.STAGE_DESTROY
                    return
                self._remote = remote_tcp
                # Switch stage and flush in the same step so no chunk can be
                # relayed ahead of the buffered ones.
                self._stage = self.STAGE_STREAM
                self._flush_connect_buffer()
                logging.debug(f"connection established,remote {remote_tcp}")
        elif self._transport_protocol == flag.TRANSPORT_UDP:
            self._stage = self.STAGE_INIT
            udp_coro = loop.create_datagram_endpoint(
                lambda: RemoteUDP(dst_addr, dst_port, payload, self),
                remote_addr=(dst_addr, dst_port),
            )
            try:
                await udp_coro
            except OSError as e:
                self.close()
                self._stage = self.STAGE_DESTROY
                logging.debug(f"connection failed , {type(e)} e: {e}")
            except Exception as e:
                self._stage = self.STAGE_ERROR
                self.close()
                logging.warning(f"connection failed, {type(e)} e: {e}")
        else:
            raise NotImplementedError

    def _handle_stage_connect(self, data):
        """Buffer payload that arrives before the remote connection is ready.

        The ss client does not wait for the server-side connect to finish, so
        data can arrive during STAGE_CONNECT. It is flushed by
        _handle_stage_init once the remote is up; a watchdog drops the
        connection if that takes longer than CONNECT_TIMEOUT.
        """
        self._connect_buffer.extend(data)
        if self._connect_watchdog is None:
            self._connect_watchdog = asyncio.create_task(
                self._watch_connect_timeout()
            )

    async def _watch_connect_timeout(self):
        """Close the connection if it is still in STAGE_CONNECT after the timeout."""
        await asyncio.sleep(self.CONNECT_TIMEOUT)
        self._connect_watchdog = None
        if self._stage != self.STAGE_CONNECT or self._transport.is_closing():
            return
        self.close()
        logging.warning(
            f"timeout to connect remote user: {self.user} peername: {self._peername}"
        )

    def _flush_connect_buffer(self):
        """Stop the connect watchdog and forward any buffered payload."""
        if self._connect_watchdog is not None:
            self._connect_watchdog.cancel()
            self._connect_watchdog = None
        if self._connect_buffer:
            buffered = bytes(self._connect_buffer)
            self._connect_buffer.clear()
            self._remote.write(buffered)

    def _handle_stage_stream(self, data):
        self.keep_alive_active()
        self._remote.write(data)
        logging.debug(f"relay data length {len(data)}")

    def _handle_stage_error(self):
        self.close()
class LocalHandler(TimeoutHandler):
    '''
    Relay one client connection (TCP or UDP) to its target address.

    The event loop drives five stages:
    STAGE_INIT     initial state; the first decrypted data carries the
                   target address header (socks5-style handshake)
    STAGE_CONNECT  the connection to the target is being established
                   (address taken from the client, DNS resolution)
    STAGE_STREAM   the pipe is up; data is relayed in both directions
    STAGE_DESTROY  the connection has ended
    STAGE_ERROR    an unexpected error occurred
    '''

    STAGE_INIT = 0
    STAGE_CONNECT = 1
    STAGE_STREAM = 2
    STAGE_DESTROY = -1
    STAGE_ERROR = 255

    # Seconds to keep client payload buffered while the remote TCP connection
    # is still being established before the connection is dropped.
    CONNECT_TIMEOUT = 5

    def __init__(self, method, password, user):
        TimeoutHandler.__init__(self)
        self.pool = ServerPool()
        self.user = user
        self._key = password
        self._method = method
        self._remote = None
        self._cryptor = None
        self._peername = None
        self._transport = None
        self._transport_protocol = None
        self._stage = self.STAGE_DESTROY
        # User whose tcp_count this handler incremented (None if not counted).
        self._counted_user = None
        # Payload received during STAGE_CONNECT, kept in arrival order.
        self._connect_buffer = bytearray()
        # Task that drops the connection if STAGE_CONNECT lasts too long.
        self._connect_watchdog = None

    def close(self):
        '''
        Close the connection (TCP and UDP handled separately) and detach the user.

        The user's tcp_count is decremented only if this handler incremented
        it, so repeated or early closes cannot skew the count.
        '''
        if self._transport_protocol == flag.TRANSPORT_TCP:
            if self._transport is not None:
                self._transport.close()
            counted_user, self._counted_user = self._counted_user, None
            if counted_user is not None and counted_user.tcp_count > 0:
                counted_user.tcp_count -= 1
        elif self._transport_protocol == flag.TRANSPORT_UDP:
            pass
        else:
            raise NotImplementedError
        if self.user:
            self.user = None

    def write(self, data):
        '''
        Send data back to the client (TCP/UDP) and record download traffic.
        '''
        if self._transport_protocol == flag.TRANSPORT_TCP:
            self._transport.write(data)
        elif self._transport_protocol == flag.TRANSPORT_UDP:
            self._transport.sendto(data, self._peername)
        else:
            raise NotImplementedError
        if self.user:
            # record download traffic
            self.user.once_used_d += len(data)

    def handle_tcp_connection_made(self, transport):
        '''
        Set up a new TCP connection.

        get_extra_info asyncio Transports api doc:
        https://docs.python.org/3/library/asyncio-protocol.html
        '''
        self.keep_alive_open()
        self._stage = self.STAGE_INIT
        self._transport = transport
        self._transport_protocol = flag.TRANSPORT_TCP
        # get the remote address to which the socket is connected
        self._peername = self._transport.get_extra_info('peername')
        try:
            self._cryptor = Cryptor(self._method, self._key)
            logging.debug('tcp connection made')
        except NotImplementedError:
            logging.warning('not support cipher')
            self.close()

    def handle_udp_connection_made(self, transport, peername):
        '''
        Set up a new UDP "connection" for the given peer.
        '''
        self._stage = self.STAGE_INIT
        self._transport = transport
        self._transport_protocol = flag.TRANSPORT_UDP
        self._peername = peername
        try:
            self._cryptor = Cryptor(self._method, self._key)
            logging.debug('udp connection made')
        except NotImplementedError:
            logging.warning('not support cipher')
            self.close()

    def handle_data_received(self, data):
        '''
        Record upload traffic, decrypt client data and dispatch by stage.
        '''
        if self.user is None:
            # close() already ran (it detaches the user); drop the data
            # instead of relaying it for a closed handler.
            self.close()
            return
        # accumulate the user's upload traffic
        self.user.once_used_u += len(data)
        try:
            data = self._cryptor.decrypt(data)
        except Exception as e:
            logging.warning('decrypt data error {}'.format(e))
            self.close()
            return

        if self._stage == self.STAGE_INIT:
            asyncio.ensure_future(self._handle_stage_init(data))
        elif self._stage == self.STAGE_CONNECT:
            # Buffered synchronously (not in a task) so chunks keep their
            # arrival order relative to each other and to later stream data.
            self._handle_stage_connect(data)
        elif self._stage == self.STAGE_STREAM:
            self._handle_stage_stream(data)
        elif self._stage == self.STAGE_ERROR:
            self._handle_stage_error()
        else:
            logging.warning('unknown stage:{}'.format(self._stage))

    def handle_eof_received(self):
        logging.debug('eof received')
        self.close()

    def handle_connection_lost(self, exc):
        logging.debug('lost exc={exc}'.format(exc=exc))
        if self._remote is not None:
            self._remote.close()
        if self._transport_protocol == flag.TRANSPORT_TCP:
            # connection_lost can arrive without a prior EOF (e.g. a reset);
            # close() releases the user's tcp_count slot in that case too.
            self.close()

    async def _handle_stage_init(self, data):
        '''
        Parse the target address header and connect to the target.

        doc: https://docs.python.org/3/library/asyncio-eventloop.html
        '''
        from shadowsocks.tcpreply import RemoteTCP  # noqa
        from shadowsocks.udpreply import RemoteUDP  # noqa

        if not data:
            # Nothing decrypted yet (presumably only part of a cipher chunk
            # has arrived); stay in STAGE_INIT and wait for more data.
            return
        # A truncated header makes the indexing/unpacking below raise; treat
        # it as a malformed header instead of crashing the task.
        try:
            atype = data[0]
            if atype == flag.ATYPE_IPV4:
                dst_addr = socket.inet_ntop(socket.AF_INET, data[1:5])
                dst_port = struct.unpack('!H', data[5:7])[0]
                payload = data[7:]
            elif atype == flag.ATYPE_IPV6:
                dst_addr = socket.inet_ntop(socket.AF_INET6, data[1:17])
                dst_port = struct.unpack('!H', data[17:19])[0]
                payload = data[19:]
            elif atype == flag.ATYPE_DOMAINNAME:
                domain_length = data[1]
                domain_index = 2 + domain_length
                dst_addr = data[2:domain_index]
                dst_port = struct.unpack(
                    '!H', data[domain_index:domain_index + 2])[0]
                payload = data[domain_index + 2:]
            else:
                logging.warning('unknown atype: {}'.format(atype))
                self.close()
                return
        except (IndexError, ValueError, struct.error) as e:
            logging.warning('parse header error: {}'.format(e))
            self.close()
            return

        loop = asyncio.get_event_loop()
        if self._transport_protocol == flag.TRANSPORT_TCP:
            self._stage = self.STAGE_CONNECT
            if self.user and self.user.tcp_count > MAX_TCP_CONNECT:
                self.close()
                return
            # On success create_connection returns (transport, protocol)
            tcp_coro = loop.create_connection(
                lambda: RemoteTCP(dst_addr, dst_port, payload, self._method,
                                  self._key, self), dst_addr, dst_port)
            try:
                remote_transport, remote_instance = await tcp_coro
            except OSError as e:
                logging.debug('connection faild , {} e: {}'.format(type(e), e))
                self.close()
                self._stage = self.STAGE_DESTROY
            except Exception as e:
                logging.warning('connection failed, {} e: {}'.format(
                    type(e), e))
                self.close()
                self._stage = self.STAGE_ERROR
            else:
                if self.user is None:
                    # close() ran while we were connecting: close the new
                    # remote connection instead of leaking it.
                    remote_transport.close()
                    self._stage = self.STAGE_DESTROY
                    return
                # count the user's tcp connections; close() undoes this
                self.user.tcp_count += 1
                self._counted_user = self.user
                logging.debug(
                    'connection established,remote {}'.format(remote_instance))
                self._remote = remote_instance
                # Switch stage and flush in the same step so no chunk can be
                # relayed ahead of the buffered ones.
                self._stage = self.STAGE_STREAM
                self._flush_connect_buffer()
        elif self._transport_protocol == flag.TRANSPORT_UDP:
            self._stage = self.STAGE_INIT
            # set up the UDP endpoint asynchronously
            udp_coro = loop.create_datagram_endpoint(
                lambda: RemoteUDP(dst_addr, dst_port, payload, self._method,
                                  self._key, self),
                remote_addr=(dst_addr, dst_port))
            asyncio.ensure_future(udp_coro)
        else:
            raise NotImplementedError

    def _handle_stage_connect(self, data):
        '''
        Buffer payload that arrives before the remote connection is ready.

        The ss client does not wait for the server-side connect to finish, so
        data can arrive during STAGE_CONNECT. It is flushed by
        _handle_stage_init once the remote is up; a watchdog closes the
        connection if that takes longer than CONNECT_TIMEOUT.
        '''
        logging.debug('wait until the connection established')
        self._connect_buffer.extend(data)
        if self._connect_watchdog is None:
            self._connect_watchdog = asyncio.ensure_future(
                self._watch_connect_timeout())

    async def _watch_connect_timeout(self):
        '''Close the connection if it is still in STAGE_CONNECT after the timeout.'''
        await asyncio.sleep(self.CONNECT_TIMEOUT)
        self._connect_watchdog = None
        if self._stage != self.STAGE_CONNECT or self._transport.is_closing():
            return
        # Buffered data can no longer be delivered in order, so close instead
        # of silently dropping it and keeping the connection open.
        logging.warning('time out to connect remote stage {}'.format(
            self._stage))
        self.close()

    def _flush_connect_buffer(self):
        '''Stop the connect watchdog and forward any buffered payload.'''
        if self._connect_watchdog is not None:
            self._connect_watchdog.cancel()
            self._connect_watchdog = None
        if self._connect_buffer:
            buffered = bytes(self._connect_buffer)
            self._connect_buffer.clear()
            self._remote.write(buffered)

    def _handle_stage_stream(self, data):
        logging.debug('realy data length {}'.format(len(data)))
        self.keep_alive_active()
        self._remote.write(data)

    def _handle_stage_error(self):
        self.close()
class LocalHandler(TimeoutHandler):
    '''
    Relay one client connection (TCP or UDP) to its target address.

    The event loop drives five stages:
    STAGE_INIT     initial state; the first decrypted data carries the
                   target address header (socks5-style handshake)
    STAGE_CONNECT  the connection to the target is being established
                   (address taken from the client, DNS resolution)
    STAGE_STREAM   the pipe is up; data is relayed in both directions
    STAGE_DESTROY  the connection has ended
    STAGE_ERROR    an unexpected error occurred
    '''

    STAGE_INIT = 0
    STAGE_CONNECT = 1
    STAGE_STREAM = 2
    STAGE_DESTROY = -1
    STAGE_ERROR = 255

    # Seconds to keep client payload buffered while the remote TCP connection
    # is still being established before the connection is dropped.
    CONNECT_TIMEOUT = 5

    def __init__(self, method, password, user):
        TimeoutHandler.__init__(self)
        self.user = user
        self._key = password
        self._method = method
        self._remote = None
        self._cryptor = None
        self._peername = None
        self._transport = None
        self._transport_protocol = None
        self._stage = self.STAGE_DESTROY
        # User whose tcp_count this handler incremented (None if not counted).
        self._counted_user = None
        # Payload received during STAGE_CONNECT, kept in arrival order.
        self._connect_buffer = bytearray()
        # Task that drops the connection if STAGE_CONNECT lasts too long.
        self._connect_watchdog = None

    def destroy(self):
        '''Drop references held by the handler (attempt to reduce memory leaks).'''
        self._stage = self.STAGE_DESTROY
        self._key = None
        self._method = None
        self._cryptor = None
        self._peername = None
        self._connect_buffer.clear()

    def traffic_filter(self):
        '''Return False when traffic for this connection must be dropped.'''
        if pool.filter_user(self.user) is False:
            return False
        elif self._transport is None:
            return False
        elif self._transport._sock is None:
            # cpython selector_events _SelectorTransport
            return False
        return True

    def close(self, clean=False):
        '''
        Close the connection; with ``clean=True`` also drop internal references.

        The user's tcp_count is decremented at most once, and only if this
        handler incremented it. close() is reached from both
        handle_eof_received and handle_connection_lost, so an unconditional
        decrement would undercount.
        '''
        if self._transport_protocol == flag.TRANSPORT_TCP:
            if self._transport:
                self._transport.close()
            counted_user, self._counted_user = self._counted_user, None
            if counted_user is not None and counted_user.tcp_count > 0:
                counted_user.tcp_count -= 1
        elif self._transport_protocol == flag.TRANSPORT_UDP:
            pass
        else:
            raise NotImplementedError
        if clean:
            self.destroy()

    def write(self, data):
        '''
        Send data back to the client (TCP/UDP) and record download traffic.
        '''
        # filter traffic
        if self.traffic_filter() is False:
            self.close(clean=True)
            return
        if self._transport_protocol == flag.TRANSPORT_TCP:
            try:
                self._transport.write(data)
                # record download traffic
                self.user.once_used_d += len(data)
            except MemoryError:
                logging.warning('memory boom user_id: {}'.format(
                    self.user.user_id))
                pool.add_user_to_jail(self.user.user_id)
                self.close(clean=True)
        elif self._transport_protocol == flag.TRANSPORT_UDP:
            self._transport.sendto(data, self._peername)
        else:
            raise NotImplementedError

    def handle_tcp_connection_made(self, transport):
        '''
        Set up a new TCP connection.

        get_extra_info asyncio Transports api doc:
        https://docs.python.org/3/library/asyncio-protocol.html
        '''
        self._stage = self.STAGE_INIT
        self._transport_protocol = flag.TRANSPORT_TCP
        # filter tcp connection
        if not pool.filter_user(self.user):
            transport.close()
            self.close(clean=True)
            return
        self._transport = transport
        # get the remote address to which the socket is connected
        self._peername = self._transport.get_extra_info('peername')
        self.keep_alive_open()
        try:
            self._cryptor = Cryptor(self._method, self._key,
                                    self._transport_protocol)
            logging.debug('tcp connection made')
        except NotImplementedError:
            logging.warning('not support cipher')
            transport.close()
            self.close(clean=True)

    def handle_udp_connection_made(self, transport, peername):
        '''
        Set up a new UDP "connection" for the given peer.
        '''
        self._stage = self.STAGE_INIT
        self._transport = transport
        self._transport_protocol = flag.TRANSPORT_UDP
        self._peername = peername
        try:
            self._cryptor = Cryptor(self._method, self._key,
                                    self._transport_protocol)
            logging.debug('udp connection made')
        except NotImplementedError:
            logging.warning('not support cipher')
            transport.close()
            self.close(clean=True)

    def handle_data_received(self, data):
        '''Record upload traffic, decrypt client data and dispatch by stage.'''
        # accumulate the user's upload traffic
        self.user.once_used_u += len(data)
        try:
            data = self._cryptor.decrypt(data)
        except RuntimeError as e:
            logging.warning('decrypt data error {}'.format(e))
            self.close(clean=True)
            return

        if self._stage == self.STAGE_INIT:
            asyncio.ensure_future(self._handle_stage_init(data))
        elif self._stage == self.STAGE_CONNECT:
            # Buffered synchronously (not in a task) so chunks keep their
            # arrival order relative to each other and to later stream data.
            self._handle_stage_connect(data)
        elif self._stage == self.STAGE_STREAM:
            self._handle_stage_stream(data)
        elif self._stage == self.STAGE_ERROR:
            self._handle_stage_error()
        else:
            logging.warning('unknown stage:{}'.format(self._stage))

    def handle_eof_received(self):
        logging.debug('eof received')
        self.close()

    def handle_connection_lost(self, exc):
        logging.debug('lost exc={exc}'.format(exc=exc))
        self.close()

    async def _handle_stage_init(self, data):
        '''
        Parse the target address header and connect to the target.

        doc: https://docs.python.org/3/library/asyncio-eventloop.html
        '''
        from shadowsocks.tcpreply import RemoteTCP
        from shadowsocks.udpreply import RemoteUDP

        if not data:
            # Nothing decrypted yet (presumably only part of a cipher chunk
            # has arrived); stay in STAGE_INIT and wait for more data.
            return
        try:
            atype, dst_addr, dst_port, header_length = parse_header(data)
        except Exception as e:
            logging.warning('parse header error: {}'.format(e))
            self.close(clean=True)
            return
        if not dst_addr:
            logging.warning('not valid data atype:{} user: {}'.format(
                atype, self.user))
            self.close(clean=True)
            return
        payload = data[header_length:]

        loop = asyncio.get_event_loop()
        if self._transport_protocol == flag.TRANSPORT_TCP:
            self._stage = self.STAGE_CONNECT
            # On success create_connection returns (transport, protocol)
            tcp_coro = loop.create_connection(
                lambda: RemoteTCP(dst_addr, dst_port, payload, self._method,
                                  self._key, self), dst_addr, dst_port)
            try:
                remote_transport, remote_instance = await tcp_coro
            except OSError as e:
                logging.debug('connection faild , {} e: {}'.format(type(e), e))
                self.close()
                self._stage = self.STAGE_DESTROY
            except Exception as e:
                logging.warning('connection failed, {} e: {}'.format(
                    type(e), e))
                self._stage = self.STAGE_ERROR
                self.close()
            else:
                if self._transport is None or self._transport.is_closing():
                    # The client went away while we were connecting: close
                    # the new remote connection and do not count it.
                    remote_transport.close()
                    self._stage = self.STAGE_DESTROY
                    return
                # count the user's tcp connections; close() undoes this once
                self.user.tcp_count += 1
                self._counted_user = self.user
                logging.debug(
                    'connection established,remote {}'.format(remote_instance))
                self._remote = remote_instance
                # Switch stage and flush in the same step so no chunk can be
                # relayed ahead of the buffered ones.
                self._stage = self.STAGE_STREAM
                self._flush_connect_buffer()
        elif self._transport_protocol == flag.TRANSPORT_UDP:
            self._stage = self.STAGE_INIT
            # set up the UDP endpoint asynchronously
            udp_coro = loop.create_datagram_endpoint(
                lambda: RemoteUDP(dst_addr, dst_port, payload, self._method,
                                  self._key, self),
                remote_addr=(dst_addr, dst_port))
            asyncio.ensure_future(udp_coro)
        else:
            raise NotImplementedError

    def _handle_stage_connect(self, data):
        '''
        Buffer payload that arrives before the remote connection is ready.

        The ss client does not wait for the server-side connect to finish, so
        data can arrive during STAGE_CONNECT. It is flushed by
        _handle_stage_init once the remote is up; a watchdog closes the
        connection if that takes longer than CONNECT_TIMEOUT.
        '''
        logging.debug('wait until the connection established')
        self._connect_buffer.extend(data)
        if self._connect_watchdog is None:
            self._connect_watchdog = asyncio.ensure_future(
                self._watch_connect_timeout())

    async def _watch_connect_timeout(self):
        '''Close the connection if it is still in STAGE_CONNECT after the timeout.'''
        await asyncio.sleep(self.CONNECT_TIMEOUT)
        self._connect_watchdog = None
        if self._stage != self.STAGE_CONNECT:
            return
        if self._transport is None or self._transport.is_closing():
            return
        logging.warning('time out to connect remote stage {}'.format(
            self._stage))
        self.close()

    def _flush_connect_buffer(self):
        '''Stop the connect watchdog and forward any buffered payload.'''
        if self._connect_watchdog is not None:
            self._connect_watchdog.cancel()
            self._connect_watchdog = None
        if self._connect_buffer:
            buffered = bytes(self._connect_buffer)
            self._connect_buffer.clear()
            self._remote.write(buffered)

    def _handle_stage_stream(self, data):
        logging.debug('realy data length {}'.format(len(data)))
        self.keep_alive_active()
        self._remote.write(data)

    def _handle_stage_error(self):
        self.close()
class LocalHandler(TimeoutMixin):
    """Relay one client connection (TCP or UDP) to its target address.

    The event loop drives five stages:

    STAGE_INIT     initial state; the first decrypted data carries the target
                   address header (socks5-style handshake)
    STAGE_CONNECT  the connection to the target is being established
                   (address taken from the client, DNS resolution)
    STAGE_STREAM   the pipe is up; data is relayed in both directions
    STAGE_DESTROY  the connection has ended
    STAGE_ERROR    an unexpected error occurred
    """

    STAGE_INIT = 0
    STAGE_CONNECT = 1
    STAGE_STREAM = 2
    STAGE_DESTROY = -1
    STAGE_ERROR = 255

    def __init__(self, user):
        super().__init__()
        self.user = user
        self.server = user.server

        self._stage = None
        self._peername = None
        self._remote = None
        self._cryptor = None
        self._transport = None
        self._transport_protocol = None
        self._is_closing = False
        # Payload received during STAGE_CONNECT, kept in arrival order.
        self._connect_buffer = bytearray()

    def _init_transport(self, transport, peername, protocol):
        self._stage = self.STAGE_INIT
        self._transport = transport
        self._peername = peername
        self._transport_protocol = protocol

    def _init_cryptor(self):
        """Build the cryptor for the user's cipher; close if it is unsupported."""
        try:
            self._cryptor = Cryptor(
                self.user.method, self.user.password, self._transport_protocol
            )
        except NotImplementedError:
            self.close()
            logging.warning("not support cipher")

    def close(self):
        """Close the connection exactly once and update the connection metrics.

        Raises:
            NotImplementedError: for an unknown transport protocol.
        """
        if self._is_closing:
            return
        self._is_closing = True

        if self._transport_protocol == flag.TRANSPORT_TCP:
            self.server.incr_tcp_conn_num(-1)
            if self._transport:
                self._transport.close()
            if self._remote:
                self._remote.close()
                # NOTE for circular reference
                self._remote = None
        elif self._transport_protocol == flag.TRANSPORT_UDP:
            pass
        else:
            raise NotImplementedError
        self._stage = self.STAGE_DESTROY
        ACTIVE_CONNECTION_COUNT.dec()

    def write(self, data):
        """Send ``data`` back to the client.

        If the local transport is missing or already closing, it is aborted
        and the data is dropped.
        """
        if not self._transport or self._transport.is_closing():
            if self._transport:
                self._transport.abort()
            return
        if self._transport_protocol == flag.TRANSPORT_TCP:
            self._transport.write(data)
        elif self._transport_protocol == flag.TRANSPORT_UDP:
            # UDP replies are addressed to the client's peername
            self._transport.sendto(data, self._peername)
        else:
            raise NotImplementedError

    def handle_connection_made(self, transport_type, transport, peername):
        """Set up a new connection, rejecting TCP when the server is limited."""
        self._init_transport(transport, peername, transport_type)
        # Count the connection before any close() so the gauge that close()
        # decrements never dips below the true number of open connections.
        CONNECTION_MADE_COUNT.inc()
        ACTIVE_CONNECTION_COUNT.inc()
        if transport_type == flag.TRANSPORT_TCP and self.server.limited:
            self.server.log_limited_msg()
            self.close()
            # Rejected connection: do not build a cryptor for it.
            return
        self._init_cryptor()

    def handle_eof_received(self):
        self.close()
        logging.debug("eof received")

    def handle_connection_lost(self, exc):
        self.close()
        logging.debug(f"lost exc={exc}")

    def handle_data_received(self, data):
        """Decrypt client data and dispatch it according to the current stage."""
        try:
            data = self._cryptor.decrypt(data)
        except Exception as e:
            self.close()
            logging.warning(f"decrypt data error {e}")
            return

        if self._stage == self.STAGE_INIT:
            asyncio.create_task(self._handle_stage_init(data))
        elif self._stage == self.STAGE_CONNECT:
            # Buffered synchronously: a task scheduled here could run after
            # _handle_stage_init already flushed the buffer, losing the data.
            self._handle_stage_connect(data)
        elif self._stage == self.STAGE_STREAM:
            self._handle_stage_stream(data)
        elif self._stage == self.STAGE_ERROR:
            self._handle_stage_error()
        elif self._stage == self.STAGE_DESTROY:
            self.close()
        else:
            logging.warning(f"unknown stage:{self._stage}")

    async def _handle_stage_init(self, data):
        """Parse the target address header and connect to the target."""
        if not data:
            return
        try:
            addr_type, dst_addr, dst_port, header_length = parse_header(data)
        except Exception as e:
            logging.warning(f"parse header error: {e}")
            self.close()
            return
        if not all([addr_type, dst_addr, dst_port, header_length]):
            logging.warning(f"parse error addr_type: {addr_type} user: {self.user}")
            self.close()
            return
        payload = data[header_length:]

        logging.debug(
            f"[HEADER:] {addr_type} {dst_addr}:{dst_port} - {self._transport_protocol}"
        )

        if self._transport_protocol == flag.TRANSPORT_TCP:
            self._stage = self.STAGE_CONNECT
            tcp_coro = self.loop.create_connection(
                lambda: RemoteTCP(dst_addr, dst_port, payload, self), dst_addr, dst_port
            )
            try:
                remote_transport, remote_tcp = await tcp_coro
            except OSError as e:
                self.close()
                self._stage = self.STAGE_DESTROY
                logging.debug(f"connection failed , {type(e)} e: {e}")
            except Exception as e:
                self._stage = self.STAGE_ERROR
                self.close()
                logging.warning(f"connection failed, {type(e)} e: {e}")
            else:
                if self._is_closing:
                    # close() already ran while we were connecting and will not
                    # run again: close the new remote connection here.
                    remote_transport.close()
                    return
                self._remote = remote_tcp
                # Switch stage and flush in the same step so no chunk can be
                # relayed ahead of the buffered ones. Copy the buffer so the
                # remote never holds a reference to our mutable bytearray.
                self._stage = self.STAGE_STREAM
                buffered = bytes(self._connect_buffer)
                self._connect_buffer.clear()
                self._remote.write(buffered)
                logging.debug(f"connection ok buffer lens:{len(buffered)}")
        elif self._transport_protocol == flag.TRANSPORT_UDP:
            udp_coro = self.loop.create_datagram_endpoint(
                lambda: RemoteUDP(dst_addr, dst_port, payload, self),
                remote_addr=(dst_addr, dst_port),
            )
            try:
                await udp_coro
            except OSError as e:
                self.close()
                self._stage = self.STAGE_DESTROY
                logging.debug(f"connection failed , {type(e)} e: {e}")
            except Exception as e:
                self._stage = self.STAGE_ERROR
                self.close()
                logging.warning(f"connection failed, {type(e)} e: {e}")
        else:
            raise NotImplementedError

    def _handle_stage_connect(self, data):
        """Buffer payload that arrives before the remote connection is ready.

        The ss client does not wait for the server-side connect to finish, so
        data can arrive during STAGE_CONNECT; _handle_stage_init forwards the
        buffer once the remote is connected.
        """
        self._connect_buffer.extend(data)

    def _handle_stage_stream(self, data):
        self.keep_alive()
        self._remote.write(data)
        logging.debug(f"relay data length {len(data)}")

    def _handle_stage_error(self):
        self.close()
class LocalHandler(TimeoutHandler):
    """Relay one client connection (TCP or UDP) to its target address.

    The event loop drives five stages:

    STAGE_INIT     initial state; the first decrypted data carries the target
                   address header (socks5-style handshake)
    STAGE_CONNECT  the connection to the target is being established
                   (address taken from the client, DNS resolution)
    STAGE_STREAM   the pipe is up; data is relayed in both directions
    STAGE_DESTROY  the connection has ended
    STAGE_ERROR    an unexpected error occurred
    """

    STAGE_INIT = 0
    STAGE_CONNECT = 1
    STAGE_STREAM = 2
    STAGE_DESTROY = -1
    STAGE_ERROR = 255

    # Seconds to keep client payload buffered while the remote TCP connection
    # is still being established before the connection is dropped.
    CONNECT_TIMEOUT = 5

    def __init__(self, method, password, user):
        TimeoutHandler.__init__(self)
        self.user = user
        self.node_type = user.node_type
        self._key = password
        self._method = method

        self.obfs = None
        self._remote = None
        self._cryptor = None
        self._peername = None
        self._transport = None
        self._transport_protocol = None
        self._stage = self.STAGE_DESTROY
        # User whose tcp_count this handler incremented (None if not counted).
        # Kept separately because the one-port obfs path can swap self.user.
        self._counted_user = None
        # Payload received during STAGE_CONNECT, kept in arrival order.
        self._connect_buffer = bytearray()
        # Task that drops the connection if STAGE_CONNECT lasts too long.
        self._connect_watchdog = None

        if self.user.obfs:
            self.obfs = Obfs(self.user.obfs)

    def close(self):
        """Close the connection.

        The tcp_count of the user that was counted is decremented at most
        once. close() is reached from both handle_eof_received and
        handle_connection_lost, so an unconditional decrement would undercount.
        """
        if self._transport_protocol == flag.TRANSPORT_TCP:
            if self._transport:
                self._transport.close()
            counted_user, self._counted_user = self._counted_user, None
            if counted_user is not None and counted_user.tcp_count > 0:
                counted_user.tcp_count -= 1
        elif self._transport_protocol == flag.TRANSPORT_UDP:
            pass
        else:
            raise NotImplementedError

    @UserRateLimitDecorator(calls=150, period=1)
    def write(self, raw_data):
        """
        Send data back to the client (TCP/UDP), obfs-encoding it if configured.

        ratelimit: 150calls/1s/user
        """
        if self.obfs:
            data = self.obfs.server_encode(raw_data)
        else:
            data = raw_data

        if self._transport_protocol == flag.TRANSPORT_TCP:
            self._transport.write(data)
            # record download traffic
            self.user.once_used_d += len(data)
        elif self._transport_protocol == flag.TRANSPORT_UDP:
            self._transport.sendto(data, self._peername)
        else:
            raise NotImplementedError

    def handle_tcp_connection_made(self, transport):
        """
        Set up a new TCP connection.

        get_extra_info asyncio Transports api doc:
        https://docs.python.org/3/library/asyncio-protocol.html
        """
        self._stage = self.STAGE_INIT
        self._transport_protocol = flag.TRANSPORT_TCP
        self._transport = transport
        # get the remote address to which the socket is connected
        self._peername = self._transport.get_extra_info("peername")
        self.keep_alive_open()

        try:
            self._cryptor = Cryptor(self._method, self._key, self._transport_protocol)
            logging.debug("tcp connection made")
        except NotImplementedError:
            logging.warning("not support cipher")
            transport.close()
            self.close()

    def handle_udp_connection_made(self, transport, peername):
        """
        Set up a new UDP "connection" for the given peer.
        """
        self._stage = self.STAGE_INIT
        self._transport = transport
        self._transport_protocol = flag.TRANSPORT_UDP
        self._peername = peername

        try:
            self._cryptor = Cryptor(self._method, self._key, self._transport_protocol)
            logging.debug("udp connection made")
        except NotImplementedError:
            logging.warning(f"not support cipher:{self._method}")
            transport.close()
            self.close()

    def handle_data_received(self, raw_data):
        """Decode obfs (switching user in one-port mode), decrypt and dispatch."""
        if self.obfs and self.node_type == self.user.NODE_TYPE_ONE_PORT:
            data, header = self.obfs.server_decode(raw_data)
            switch_user = pool.user_pool.get_by_token(header.token)
            if not switch_user:
                logging.warning(
                    "header not valid, path: {} peername: {}".format(
                        header.path, self._transport.get_extra_info("peername")
                    )
                )
                self.close()
                return
            # NOTE(review): self._cryptor was built from the method/password
            # passed to __init__, not from switch_user; confirm that all users
            # on a one-port node share them.
            self.user = switch_user
            logging.debug(f"server:{self} switch user to {self.user}")
        else:
            data = raw_data

        self.user.once_used_u += len(data)

        try:
            data = self._cryptor.decrypt(data)
        except Exception as e:
            logging.warning(f"decrypt data error {e}")
            self.close()
            return

        if self._stage == self.STAGE_INIT:
            asyncio.create_task(self._handle_stage_init(data))
        elif self._stage == self.STAGE_CONNECT:
            # Buffered synchronously (not in a task) so chunks keep their
            # arrival order relative to each other and to later stream data.
            self._handle_stage_connect(data)
        elif self._stage == self.STAGE_STREAM:
            self._handle_stage_stream(data)
        elif self._stage == self.STAGE_ERROR:
            self._handle_stage_error()
        else:
            logging.warning(f"unknown stage:{self._stage}")

    def handle_eof_received(self):
        logging.debug("eof received")
        self.close()

    def handle_connection_lost(self, exc):
        logging.debug(f"lost exc={exc}")
        self.close()

    async def _handle_stage_init(self, data):
        """
        Parse the target address header and connect to the target.

        doc: https://docs.python.org/3/library/asyncio-eventloop.html
        """
        from shadowsocks.tcpreply import RemoteTCP
        from shadowsocks.udpreply import RemoteUDP

        if not data:
            # Nothing decrypted yet (presumably only part of a cipher chunk
            # has arrived); stay in STAGE_INIT and wait for more data.
            return
        try:
            atype, dst_addr, dst_port, header_length = parse_header(data)
        except Exception as e:
            logging.warning(f"parse header error: {e}")
            self.close()
            return
        if not dst_addr:
            logging.warning(f"not valid data atype:{atype} user: {self.user}")
            self.close()
            return
        payload = data[header_length:]

        loop = asyncio.get_event_loop()
        if self._transport_protocol == flag.TRANSPORT_TCP:
            self._stage = self.STAGE_CONNECT

            # On success create_connection returns (transport, protocol)
            tcp_coro = loop.create_connection(
                lambda: RemoteTCP(
                    dst_addr, dst_port, payload, self._method, self._key, self
                ),
                dst_addr,
                dst_port,
            )
            try:
                remote_transport, remote_instance = await tcp_coro
            except OSError as e:
                logging.debug(f"connection failed , {type(e)} e: {e}")
                self.close()
                self._stage = self.STAGE_DESTROY
            except Exception as e:
                logging.warning(f"connection failed, {type(e)} e: {e}")
                self._stage = self.STAGE_ERROR
                self.close()
            else:
                if self._transport is None or self._transport.is_closing():
                    # The client went away while we were connecting: close
                    # the new remote connection and do not count it.
                    remote_transport.close()
                    self._stage = self.STAGE_DESTROY
                    return
                # count the user's tcp connections; close() undoes this once
                self.user.tcp_count += 1
                self._counted_user = self.user
                logging.debug(f"connection established,remote {remote_instance}")
                self._remote = remote_instance
                # Switch stage and flush in the same step so no chunk can be
                # relayed ahead of the buffered ones.
                self._stage = self.STAGE_STREAM
                self._flush_connect_buffer()
        elif self._transport_protocol == flag.TRANSPORT_UDP:
            self._stage = self.STAGE_INIT
            # set up the UDP endpoint asynchronously
            udp_coro = loop.create_datagram_endpoint(
                lambda: RemoteUDP(
                    dst_addr, dst_port, payload, self._method, self._key, self
                ),
                remote_addr=(dst_addr, dst_port),
            )
            asyncio.create_task(udp_coro)
        else:
            raise NotImplementedError

    def _handle_stage_connect(self, data):
        """Buffer payload that arrives before the remote connection is ready.

        The ss client does not wait for the server-side connect to finish, so
        data can arrive during STAGE_CONNECT. It is flushed by
        _handle_stage_init once the remote is up; a watchdog closes the
        connection if that takes longer than CONNECT_TIMEOUT.
        """
        logging.debug("wait until the connection established")
        self._connect_buffer.extend(data)
        if self._connect_watchdog is None:
            self._connect_watchdog = asyncio.create_task(
                self._watch_connect_timeout()
            )

    async def _watch_connect_timeout(self):
        """Close the connection if it is still in STAGE_CONNECT after the timeout."""
        await asyncio.sleep(self.CONNECT_TIMEOUT)
        self._connect_watchdog = None
        if self._stage != self.STAGE_CONNECT:
            return
        if self._transport is None or self._transport.is_closing():
            return
        logging.warning(f"time out to connect remote stage {self._stage}")
        self.close()

    def _flush_connect_buffer(self):
        """Stop the connect watchdog and forward any buffered payload."""
        if self._connect_watchdog is not None:
            self._connect_watchdog.cancel()
            self._connect_watchdog = None
        if self._connect_buffer:
            buffered = bytes(self._connect_buffer)
            self._connect_buffer.clear()
            self._remote.write(buffered)

    def _handle_stage_stream(self, data):
        logging.debug(f"relay data length {len(data)}")
        self.keep_alive_active()
        self._remote.write(data)

    def _handle_stage_error(self):
        self.close()