async def forward_data_task(
    self,
    dir_str: str,
    reader: StreamReader,
    writer: StreamWriter,
    exit_q: asyncio.Queue,
    exit_status: int,
):
    """Copy bytes from *reader* to *writer* until EOF or a failure.

    Args:
        dir_str: direction label used in the printed messages.
        reader: source stream, read in chunks of up to 8192 bytes.
        writer: destination stream; drained after every chunk.
        exit_q: queue that receives the shutdown status.
        exit_status: value put on *exit_q* when *reader* reaches EOF.

    On any unexpected exception the traceback is printed and ``1`` is put on
    *exit_q*. Cancellation is re-raised untouched.
    """
    try:
        while True:
            chunk = await reader.read(8192)
            if not chunk:
                break
            if self.verbose >= 4:
                # Very verbose mode: dump each chunk both as bytes and as hex.
                print(f"{PKG_NAME}: {self.name} {dir_str}: {chunk} ## {chunk.hex()}")
            writer.write(chunk)
            await writer.drain()
        print(
            f"{PKG_NAME}: {self.name} {dir_str}: read EOF; shutdown with exit_status={exit_status}"
        )
        await exit_q.put(exit_status)
    except asyncio.CancelledError:  # pylint: disable=try-except-raise
        raise
    except Exception as err:  # pylint: disable=broad-except
        print(
            f"{PKG_NAME}: {self.name} {dir_str} got exception {err}; {traceback.format_exc(-1)}"
        )
        await exit_q.put(1)
async def send_state(writer: StreamWriter):
    """Write the current state as one CSV line.

    The line holds ``step_seconds``, then every value of ``pulses_state``,
    then six fields that are always ``0``. It is written without draining.
    """
    pulses_csv = ",".join(map(str, pulses_state))
    line = f"{step_seconds},{pulses_csv}" + ",0" * 6 + "\n"
    writer.write(line.encode('utf8'))
async def handle_connection(self, reader: StreamReader, writer: StreamWriter):
    """Authenticate a world-server client, then relay packets until it leaves.

    Every read and every ``world_packet_manager.process`` call is bounded by a
    1 second timeout; a timeout simply restarts the loop. The loop ends when
    the client closes the connection (EOF) or on any other exception, which is
    logged. The writer is always closed on the way out.
    """
    temp_ref = TempRef()
    world_packet_manager = WorldPacketManager(temp_ref=temp_ref, reader=reader, writer=writer)

    peername = writer.get_extra_info('peername')
    Logger.debug('[World Server]: Accept connection from {}'.format(peername))

    Logger.info('[World Server]: trying to process auth session')
    auth = AuthManager(reader, writer, temp_ref=temp_ref, world_packet_manager=world_packet_manager)
    await auth.process(step=AuthStep.SECOND)

    self._register_tasks()

    try:
        while True:
            try:
                request = await asyncio.wait_for(reader.read(4096), timeout=1.0)
                if not request:
                    # b'' means EOF. read() keeps returning b'' immediately from
                    # now on, so continuing would spin this loop forever.
                    break
                response = await asyncio.wait_for(world_packet_manager.process(request), timeout=1.0)
                if response:
                    for packet in response:
                        writer.write(packet)
                        await writer.drain()
            except (asyncio.TimeoutError, TimeoutError):
                # Before Python 3.11, wait_for raises asyncio.TimeoutError, which
                # is NOT the builtin TimeoutError. Catch both so an idle client is
                # not treated as a fatal error.
                continue
            except Exception as e:
                Logger.error('[World Server]: exception, {}'.format(e))
                traceback.print_exc()
                break
    finally:
        writer.close()
def write_response(
    writer: StreamWriter,
    status: HTTPStatus,
    headers: Iterable[Header],
    body: str,
) -> None:
    """Write a complete HTTP response: the head, then *body* encoded as UTF-8."""
    write_http_head(writer, status, headers)
    payload = body.encode("utf-8")
    writer.write(payload)
async def callback(reader: StreamReader, writer: StreamWriter):
    """Send a malformed JSON body and record the server's framed reply.

    The reply is read as one ``...: N`` header line, one separator line (kept
    on ``self.empty_line``) and exactly N body bytes, which are closed over
    and parsed into ``self.response``.
    """
    body = b'{ fail to parse'
    framed = f'content-length: {len(body)}\r\n\r\n'.encode() + body
    writer.write(framed)

    header_line = await reader.readline()
    # The length is whatever follows the last ':' on the header line.
    reply_length = int(header_line.decode().rpartition(':')[2].strip())
    self.empty_line = await reader.readline()
    reply_bytes = await reader.readexactly(reply_length)
    writer.close()
    self.response = json.loads(reply_bytes.decode())
async def callback(reader: StreamReader, writer: StreamWriter):
    """Send a valid JSON message and record the server's framed reply.

    The reply is read as one ``...: N`` header line, one separator line (kept
    on ``self.empty_line``) and exactly N body bytes, parsed into
    ``self.response``.
    """
    content = json.dumps(
        {'message': 'Hello, World!'}).encode()
    writer.write(
        f'content-length: {len(content)}\r\n\r\n'.encode() + content)
    raw_content_length = await reader.readline()
    content_length = int(
        raw_content_length.decode().split(':')[-1].strip())
    self.empty_line = await reader.readline()
    raw_response = await reader.readexactly(content_length)
    # Close before parsing (as the sibling callbacks do) so a JSON decode
    # error cannot leave the connection open.
    writer.close()
    self.response = json.loads(raw_response.decode())
async def callback(reader: StreamReader, writer: StreamWriter):
    """Call a method name the server does not know and record the framed reply.

    Mutates ``self.stub['method']``, sends the stub as JSON, then reads the
    header line, the separator line (kept on ``self.empty_line``) and the body,
    which is parsed into ``self.response``.
    """
    self.stub['method'] = 'this method does not exist'
    body = json.dumps(self.stub).encode()
    header = f'content-length: {len(body)}\r\n\r\n'.encode()
    writer.write(header + body)

    header_line = await reader.readline()
    reply_length = int(header_line.decode().rpartition(':')[2].strip())
    self.empty_line = await reader.readline()
    reply_bytes = await reader.readexactly(reply_length)
    writer.close()
    self.response = json.loads(reply_bytes.decode())
async def client_connected(reader: StreamReader, writer: StreamWriter, jobs_table: JobsTable):
    """Serve exactly one command on a freshly accepted connection.

    Reads a line with ``read_long_line``, deserializes it into a command, runs
    it against *jobs_table*, writes the serialized reply (if any) and closes
    the connection.

    The connection is closed even if reading, deserializing or handling the
    command raises; the exception still propagates.
    """
    try:
        line = await read_long_line(reader)
        command = deserialize_command(line)
        reply = await handle_command(command, jobs_table)
        if reply is not None:
            reply_json = reply.serialize()
            writer.write(reply_json.encode())
    finally:
        writer.close()
        await writer.wait_closed()
async def handle_connection(self, reader: StreamReader, writer: StreamWriter):
    """Authenticate a world-server client and relay packets until it leaves.

    If ``auth.process`` reports success, requests are read and processed with
    0.01 s timeouts; a timeout just restarts the loop. The loop ends on:
    - EOF (the client closed the connection);
    - a broken pipe;
    - any other exception, which is logged.

    The writer is always closed, including when authentication itself raises.
    """
    self._register_tasks()
    temp_ref = TempRef()
    world_packet_manager = WorldPacketManager(temp_ref=temp_ref, reader=reader, writer=writer)

    Logger.info('[World Server]: trying to process auth session')
    auth = AuthManager(reader, writer, temp_ref=temp_ref, world_packet_manager=world_packet_manager,
                       session_keys=self.session_keys)
    try:
        is_authenticated = await auth.process(step=AuthStep.SECOND)
        if is_authenticated:
            peer_name = writer.get_extra_info('peername')
            Logger.success(
                '[World Server]: Accept connection from {}'.format(peer_name))

            while True:
                try:
                    request = await asyncio.wait_for(reader.read(4096), timeout=0.01)
                    if not request:
                        # EOF: read() would keep returning b'' forever.
                        break
                    response = await asyncio.wait_for(
                        world_packet_manager.process(request), timeout=0.01)
                    if response:
                        for packet in response:
                            writer.write(packet)
                            await writer.drain()
                except (asyncio.TimeoutError, TimeoutError):
                    # Before Python 3.11 wait_for raises asyncio.TimeoutError,
                    # which the builtin TimeoutError does not catch.
                    pass
                except BrokenPipeError:
                    # The peer is gone; retrying on a dead socket would loop forever.
                    break
                except Exception as e:
                    Logger.error('[World Server]: exception, {}'.format(e))
                    traceback.print_exc()
                    break
                finally:
                    await asyncio.sleep(0.01)
    finally:
        writer.close()
async def callback(reader: StreamReader, writer: StreamWriter):
    """Send an ``initialize`` request without ``rootUri`` and record the reply.

    Reads the header line, the separator line (kept on ``self.empty_line``)
    and the framed body, which is parsed into ``self.response``.
    """
    self.stub['method'] = 'initialize'
    # If initialize is called without a `rootUri` param, it will raise, causing an InternalError
    del self.stub['params']['rootUri']
    body = json.dumps(self.stub).encode()
    header = f'content-length: {len(body)}\r\n\r\n'.encode()
    writer.write(header + body)

    header_line = await reader.readline()
    reply_length = int(header_line.decode().rpartition(':')[2].strip())
    self.empty_line = await reader.readline()
    reply_bytes = await reader.readexactly(reply_length)
    writer.close()
    self.response = json.loads(reply_bytes.decode())
async def process_request(reader: StreamReader, writer: StreamWriter, world_packet_mgr: WorldPacketManager):
    """Read one request, dispatch it, and write back every reply packet.

    The first byte of the request is passed as the opcode and the remainder as
    the data; a request with no data after the opcode is ignored. The read and
    the processing are each bounded by ``Config.Realm.Settings.min_timeout``;
    a timeout propagates to the caller.
    """
    request = await asyncio.wait_for(
        reader.read(4096), timeout=Config.Realm.Settings.min_timeout)
    if not request:
        return

    opcode = request[:1]
    data = request[1:]
    if not data:
        return

    response = await asyncio.wait_for(
        world_packet_mgr.process(opcode=opcode, data=data),
        timeout=Config.Realm.Settings.min_timeout)
    for packet in response or ():
        writer.write(packet)
        await writer.drain()
def write_chunk(writer: StreamWriter, data: bytes) -> None:
    """Write *data* to *writer* as one chunked-transfer chunk.

    The chunk is the payload length in lowercase hex, CRLF, the payload, then
    CRLF. A debug log line shows the payload decoded as UTF-8 (undecodable
    bytes dropped) with CR and LF escaped.
    """
    size_field = format(len(data), "x").encode("ascii")
    for piece in (size_field, b"\r\n", data, b"\r\n"):
        writer.write(piece)
    escapes = str.maketrans({"\r": "\\r", "\n": "\\n"})
    encoded = data.decode("utf-8", errors="ignore").translate(escapes)
    logging.debug(f"wrote chunk to listener: {encoded}")
async def process_request(reader: StreamReader, writer: StreamWriter, world_packet_mgr: WorldPacketManager):
    """Read one request, dispatch it, and write back every reply packet.

    The request is sliced into its first 2 bytes (``size``), the next 4 bytes
    (``opcode``) and the rest (``data``) before being passed to
    ``world_packet_mgr.process``. The read and the processing are each
    bounded by ``Config.Realm.Settings.min_timeout``; a timeout propagates.
    """
    request: bytes = await wait_for(reader.read(4096), timeout=Config.Realm.Settings.min_timeout)
    if not request:
        return

    size_field = request[:2]
    opcode_field = request[2:6]
    payload = request[6:]
    response: List[bytes] = await wait_for(
        world_packet_mgr.process(size=size_field, opcode=opcode_field, data=payload),
        timeout=Config.Realm.Settings.min_timeout,
    )
    for packet in response or ():
        writer.write(packet)
        await writer.drain()
async def callback(reader: StreamReader, writer: StreamWriter) -> None:
    """Answer simple text commands on one connection.

    Commands (one per read of up to 100 bytes, whitespace-stripped):
    - ``calendar``: reply with the current local date and time;
    - messages starting with ``ECHO_PREFIX``: reply with the rest of the message;
    - ``stop``: end the session;
    - anything else: reply with ``HELP_TEXT``.

    Every reply is terminated with ``NEWLINE``. The session also ends when the
    client closes the connection. The writer is always closed.
    """
    try:
        while True:
            data = await reader.read(100)
            if not data:
                # EOF: read() keeps returning b'' forever, which previously made
                # this loop spin, answering HELP_TEXT to a closed socket.
                break
            message = data.decode().strip()
            if message == 'calendar':
                answer = datetime.now().strftime("%d.%m.%Y %H:%M")
            elif message.startswith(ECHO_PREFIX):
                answer = message[len(ECHO_PREFIX):]
            elif message == 'stop':
                break
            else:
                answer = HELP_TEXT
            answer += NEWLINE
            writer.write(answer.encode())
            await writer.drain()
    finally:
        writer.close()
def write_http_head(writer: StreamWriter, code: HTTPStatus, headers: Iterable[Header]) -> None:
    """Write an HTTP/1.1 status line, one line per header, and the blank separator.

    Everything is ASCII-encoded, so a non-ASCII header name or value raises
    ``UnicodeEncodeError``. Each header is a ``(name, value)`` pair.
    """
    writer.write(f"HTTP/1.1 {code.value} {code.phrase}\r\n".encode("ascii"))
    for name, value in headers:
        writer.write(b"".join((name.encode("ascii"), b": ", value.encode("ascii"), b"\r\n")))
    writer.write(b"\r\n")
async def ee4_srv(r: StreamReader, w: StreamWriter):
    """Serve one client: run each received line through ``cmd_handler`` and reply.

    A ``SubscriptionContext`` is created per connection and ended when the
    session finishes. The session finishes when:
    - the client closes the connection;
    - the reader has a pending exception;
    - a decoding, attribute or connection error occurs.
    """
    logger.info("connection established")
    # create new context for client
    ctx = SubscriptionContext(r, w)
    try:
        while not r.exception():
            msg_raw = await r.readline()
            if not msg_raw:
                # EOF: readline() keeps returning b'' and r.exception() stays
                # None, so without this check the loop would spin forever.
                break
            msg = msg_raw.decode().strip()
            logger.info(f"< {msg}")
            out = cmd_handler(msg, ctx)
            logger.info(f"> {out}")
            w.write(f"{out}\r\n".encode())
            await w.drain()
    except UnicodeDecodeError:
        logger.warning("invalid char encountered")
    except AttributeError:
        logger.warning("attribute error encountered")
    except ConnectionError:
        logger.warning("connection error encountered")
    finally:
        logger.info("cleaning up resources")
        ctx.end()
        # Closing an already-closed writer is harmless.
        w.close()
        logger.info("resources purged. connection terminated.")
async def handle_connection(self, reader: StreamReader, writer: StreamWriter):
    """Send the auth challenge, then process requests until the client disconnects.

    A random 4-byte seed is sent in an ``SMSG_AUTH_CHALLENGE`` packet and
    stored on the connection. Requests are then handled by
    ``WorldServer.process_request`` with a ``min_timeout`` pause after each
    attempt. Read/processing timeouts are ignored. Any other exception
    propagates. The writer is always closed on exit.
    """
    import asyncio  # local import: this block may live where only `sleep` is imported

    self._register_tasks()
    connection = Connection(reader=reader, writer=writer, session_keys=self.session_keys)
    world_packet_mgr = WorldPacketManager(connection=connection)

    # send auth challenge
    auth_seed: bytes = urandom(4)
    writer.write(
        world_packet_mgr.generate_packet(
            WorldOpCode.SMSG_AUTH_CHALLENGE,
            auth_seed
        )
    )
    connection.auth_seed = auth_seed

    try:
        # process_request returns silently on EOF, so check the reader state
        # here; otherwise this loop never ends after the client disconnects.
        while not reader.at_eof():
            try:
                await WorldServer.process_request(reader, writer, world_packet_mgr)
            except (asyncio.TimeoutError, TimeoutError):
                # Before Python 3.11 wait_for raises asyncio.TimeoutError, which
                # the builtin TimeoutError does not catch.
                continue
            finally:
                await sleep(Config.Realm.Settings.min_timeout)
    finally:
        writer.close()
async def write_line_to_chat(writer: StreamWriter, message: str) -> None:
    """Sanitize *message*, end it with a newline, send it and wait for the drain."""
    payload = b''.join((sanitize_message(message).encode(encoding='utf-8'), b'\n'))
    writer.write(payload)
    await writer.drain()
    logging.debug(f'Send a message: {payload}')
async def write_message(self, writer: StreamWriter, data: bytearray):
    """Send *data* framed as <encoded length><data><NUL byte> and drain the writer.

    The length prefix comes from ``self.__encode_data_length__(data)``; the
    returned buffer is extended in place to build the frame.
    """
    frame = self.__encode_data_length__(data)
    frame += data
    frame += b'\x00'
    writer.write(frame)
    await writer.drain()
class NeoProtocol(StreamReaderProtocol):
    def __init__(self, *args, quality_check=False, **kwargs):
        """
        Args:
            *args: accepted for signature compatibility; not used.
            quality_check (bool): there are times when we only establish a connection to check the quality of the node/address
            **kwargs: must contain ``nodemanager``, which is handed to ``NeoNode``.
        """
        stream_reader = StreamReader()
        # Keep our own strong reference. Newer CPython versions expose
        # StreamReaderProtocol._stream_reader as a read-only, weakref-backed
        # property, so assigning it raises AttributeError. That property also
        # returns None once the base connection_lost() has run.
        self._stream_reader_orig = stream_reader
        self._stream_writer = None
        nodemanager = kwargs.pop('nodemanager')
        self.client = NeoNode(self, nodemanager, quality_check)
        self._loop = events.get_event_loop()
        super().__init__(stream_reader)

    def connection_made(self, transport: asyncio.transports.BaseTransport) -> None:
        """Create the stream writer and notify the client node."""
        super().connection_made(transport)
        self._stream_writer = StreamWriter(transport, self, self._stream_reader_orig, self._loop)
        if self.client:
            asyncio.create_task(self.client.connection_made(transport))

    def connection_lost(self, exc: Optional[Exception] = None) -> None:
        """Notify the client node first, then run the base-class teardown."""
        if self.client:
            task = asyncio.create_task(self.client.connection_lost(exc))
            task.add_done_callback(
                lambda args: super(NeoProtocol, self).connection_lost(exc))
        else:
            super().connection_lost(exc)

    def eof_received(self) -> bool:
        """Signal EOF to the reader and tear the connection down explicitly."""
        self._stream_reader_orig.feed_eof()
        self.connection_lost()
        # Returning True keeps the transport from closing itself;
        # connection_lost() has already been invoked explicitly above.
        return True

    async def send_message(self, message: Message) -> None:
        """Serialize and send *message*; on failure, tear the connection down.

        Connection errors and cancellation are passed to ``connection_lost``
        as the caught exception instance. Any other error is treated as a
        plain disconnect (``exc=None``).
        """
        try:
            self._stream_writer.write(message.to_array())
            await self._stream_writer.drain()
        except (ConnectionError, asyncio.CancelledError) as err:
            # Pass the instance (not the exception class) so the reader and
            # NeoNode see the real error.
            self.connection_lost(err)
        except Exception:
            self.connection_lost()

    async def read_message(self, timeout: Optional[int] = 30) -> Optional[Message]:
        """Read one framed message, or return None on error, timeout or bad checksum.

        The frame is a 24-byte header unpacked as ``'I 12s I I'`` (magic,
        NUL-padded command, payload length, checksum), followed by the payload.
        A read failure also sets ``self.client.disconnecting`` so the node's
        run loop stops.

        Args:
            timeout: seconds to wait; ``0`` means wait indefinitely.
        """
        if timeout == 0:
            # avoid memleak. See: https://bugs.python.org/issue37042
            timeout = None

        async def _read():
            reader = self._stream_reader_orig
            try:
                message_header = await reader.readexactly(24)
                # uint32, 12byte-string, uint32, uint32
                magic, command, payload_length, checksum = struct.unpack(
                    'I 12s I I', message_header)
                # readexactly already returns exactly payload_length bytes.
                payload = await reader.readexactly(payload_length)
            except Exception:
                # ensures we break out of the main run() loop of Node, which triggers a disconnect callback to clean up
                self.client.disconnecting = True
                return None

            m = Message(magic, command.rstrip(b'\x00').decode('utf-8'), payload)
            if checksum != m.get_checksum(payload):
                logger.debug("Message checksum incorrect")
                return None
            return m

        try:
            return await asyncio.wait_for(_read(), timeout)
        except Exception:
            return None

    def disconnect(self) -> None:
        """Close the stream writer, if one was created."""
        if self._stream_writer:
            self._stream_writer.close()
async def on_client_connected(self, reader: StreamReader, writer: StreamWriter):
    """Serve newline-delimited commands from one client until it disconnects.

    Each line is parsed with ``SymlServiceCommand.parse`` and dispatched to the
    ``cmd_<name>`` coroutine method on ``self``. If that method has an
    annotated ``cmd`` parameter, the first type argument of the annotation is
    used to build ``command.args`` from the received mapping.

    The handler's response is written with ``response.jsonb()`` plus a newline.
    If the handler raised, an error response carrying a trimmed traceback is
    written instead. The writer is closed when the client disconnects.
    """
    self.logger.debug('client connected')
    try:
        while True:  # TODO: while active instead
            raw_command = await reader.readline()
            if raw_command == b'':
                # readline() returns b'' only at EOF and keeps doing so; the old
                # sleep-and-retry polling kept this handler alive forever after
                # the client went away.
                self.logger.debug('client disconnected')
                break
            try:
                command = SymlServiceCommand.parse(raw_command)
                logging.debug("received command %s", command)
                callable_command = getattr(self, f'cmd_{command.name}')
                cmd_arg = get_type_hints(callable_command).get('cmd')
                if cmd_arg:
                    args_type = get_args(cmd_arg)[0]
                    command.args = args_type(**command.args)
                # TODO: handle generators
                try:
                    response: SymlServiceResponse
                    if cmd_arg:
                        response = await callable_command(command)
                    else:
                        response = await callable_command()
                except Exception as e:
                    # Named exc_tb so the `traceback` module is not shadowed.
                    exc_type, exc_value, exc_tb = sys.exc_info()
                    tb = Traceback()
                    trace = tb.extract(exc_type, exc_value, exc_tb, show_locals=False)
                    trace = [dataclasses.asdict(s) for s in trace.stacks]
                    # Drop the first frame of every stack (presumably this
                    # dispatcher's own frame).
                    trace = [{
                        **t, 'frames': t.get('frames')[1:]
                    } for t in trace]
                    response = SymlServiceResponse(
                        data=dict(),
                        errors=[
                            dict(message="unhandled exception while "
                                 "processing command (${exception})",
                                 exception=e,
                                 trace=trace)
                        ])
                response.command = command
                writer.write(response.jsonb())
                writer.write('\n'.encode())
                logging.debug("sending response %s", response)
                await writer.drain()
            except Exception as e:
                # The old call passed `e` with no %s placeholder, which makes
                # logging report a formatting error instead of this message.
                self.logger.exception("oh no: %s", e)
    finally:
        writer.close()
class YateAsync(yate.YateBase):
    MODE_STDIO = 1
    MODE_TCP = 2
    MODE_UNIX = 3

    def __init__(self, host=None, port=None, sockpath=None):
        """Create the connector and its private selector event loop.

        The connection mode is TCP when *host* is given, a unix socket when
        *sockpath* is given, and stdin/stdout otherwise.
        """
        super().__init__()
        self.event_loop = asyncio.SelectorEventLoop()
        asyncio.set_event_loop(self.event_loop)
        self.reader = None
        self.writer = None
        self.main_task = None
        self._automatic_bufsize = False

        if host is not None:
            self.mode = self.MODE_TCP
            self.host = host
            self.port = port
        elif sockpath is not None:
            self.mode = self.MODE_UNIX
            self.sockpath = sockpath
        else:
            self.mode = self.MODE_STDIO

    def run(self, application_main):
        """Run ``application_main(self)`` on the private loop, then close the loop."""
        self.main_task = self.event_loop.create_task(self._amain(application_main))
        self.event_loop.run_until_complete(self.main_task)
        self.event_loop.close()

    async def _amain(self, application_main):
        """Connect, start message processing, run the application, then clean up."""
        if self.mode == self.MODE_STDIO:
            await self.setup_for_stdio()
        elif self.mode == self.MODE_TCP:
            await self.setup_for_tcp(self.host, self.port)
        elif self.mode == self.MODE_UNIX:
            await self.setup_for_unix(self.sockpath)
        else:
            raise NotImplementedError("Unknown mode of operation found")

        # now start event processing for yate messages
        message_loop_task = self.event_loop.create_task(self.message_processing_loop())

        # then let the main program run
        await self._amain_ready()
        try:
            await application_main(self)
        except asyncio.CancelledError:
            pass
        finally:
            # Clean up whether the main task finished, was cancelled or failed.
            self.writer.close()
            message_loop_task.cancel()

    async def _amain_ready(self):
        """Hook awaited after setup, just before the application starts; no-op here."""
        pass

    async def setup_for_stdio(self):
        """Use stdin as the message source and stdout as the sink."""
        self.reader = asyncio.StreamReader()
        reader_protocol = asyncio.StreamReaderProtocol(self.reader)
        await self.event_loop.connect_read_pipe(lambda: reader_protocol, sys.stdin)
        writer_transport, writer_protocol = await self.event_loop.connect_write_pipe(FlowControlMixin, sys.stdout)
        self.writer = StreamWriter(writer_transport, writer_protocol, None, self.event_loop)

    async def setup_for_tcp(self, host, port):
        """Connect over TCP and send the connect message."""
        # No `loop=` argument: it was removed in Python 3.10. This coroutine
        # runs on self.event_loop, which is the loop open_connection picks up.
        self.reader, self.writer = await asyncio.open_connection(host, port)
        self.send_connect()

    async def setup_for_unix(self, sockpath):
        """Connect over a unix socket and send the connect message."""
        self.reader, self.writer = await asyncio.open_unix_connection(sockpath)
        self.send_connect()

    async def message_processing_loop(self):
        """Feed each received line to ``_recv_message_raw``; stop the loop at EOF."""
        try:
            while True:
                raw_message = await self.reader.readline()
                if raw_message == b"":
                    break  # we only receive empty bytes if this is EOF, then terminate
                raw_message = raw_message.strip()
                self._recv_message_raw(raw_message)
            # once message processing ends, the whole application should terminate
        except asyncio.CancelledError:
            pass
        self.event_loop.stop()

    def _send_message_raw(self, msg):
        """Write one newline-terminated message.

        With automatic bufsize enabled, a message longer than the current
        ``bufsize`` first requests a larger bufsize and is written once that
        request completes.
        """
        if self._automatic_bufsize:
            yate_buf_required = len(msg) + 2  # plus \n and \0 terminator in yate
            if yate_buf_required > int(self.get_local("bufsize")):
                def deferred_msg_write(_param, _value, _success):
                    # defer writing the message that is too long until the bufsize was adapted
                    self.writer.write(msg + b"\n")
                # round to next kb
                requested_bufsize = ((yate_buf_required // 1024) + 1) * 1024
                self.set_local("bufsize", str(requested_bufsize), done_callback=deferred_msg_write)
                return
        self.writer.write(msg + b"\n")

    async def drain(self):
        """Wait until the writer's buffer has been flushed."""
        await self.writer.drain()

    async def register_message_handler_async(self, message, callback, priority=100, filter_attribute=None,
                                             filter_value=None):
        """Register a message handler and return the success flag reported back."""
        future = self.event_loop.create_future()

        def _done_callback(success):
            future.set_result(success)

        self.register_message_handler(message, callback, priority, filter_attribute, filter_value,
                                      done_callback=_done_callback)
        return await future

    async def register_watch_handler_async(self, message, callback):
        """Register a watch handler and return the success flag reported back."""
        future = self.event_loop.create_future()

        def _done_callback(success):
            future.set_result(success)

        self.register_watch_handler(message, callback, _done_callback)
        return await future

    async def send_message_async(self, msg: MessageRequest) -> Message:
        """Send *msg* and return the result message passed to the done callback."""
        future = self.event_loop.create_future()

        def _done_callback(old_msg, result_msg):
            future.set_result(result_msg)

        self.send_message(msg, _done_callback)
        return await future

    async def set_local_async(self, param, value):
        """Set a local parameter and return the success flag reported back."""
        future = self.event_loop.create_future()

        def done_callback(_param, _value, success):
            future.set_result(success)

        self.set_local(param, value, done_callback=done_callback)
        return await future

    async def get_local_async(self, param):
        """Return a local parameter, using ``self._local_params`` as a cache."""
        if param in self._local_params:
            return self._local_params[param]

        future = self.event_loop.create_future()

        def done_callback(_param, value, _success):
            future.set_result(value)

        # NOTE(review): an empty value presumably makes yate report the current
        # value instead of changing it — confirm against YateBase.set_local.
        self.set_local(param, "", done_callback=done_callback)
        return await future

    async def activate_automatic_bufsize(self):
        """Fetch the current bufsize, then enable automatic bufsize growth."""
        await self.get_local_async("bufsize")
        self._automatic_bufsize = True
def send_initialize(writer: StreamWriter):
    """Write ``client_initialization_message`` as JSON behind a content-length header."""
    body = json.dumps(client_initialization_message).encode()
    header = f'content-length: {len(body)}\r\n\r\n'.encode()
    writer.write(header + body)
def send_initialized(writer: StreamWriter):
    """Write a JSON-RPC ``initialized`` notification behind a content-length header."""
    notification = {
        'jsonrpc': '2.0',
        'method': 'initialized',
        'params': {},
    }
    body = json.dumps(notification).encode()
    writer.write(b''.join([f'content-length: {len(body)}\r\n\r\n'.encode(), body]))
class NeoProtocol(StreamReaderProtocol):
    def __init__(self, *args, **kwargs):
        """
        Args:
            *args: accepted for signature compatibility; not used.
            **kwargs: accepted for signature compatibility; not used.
        """
        sr = StreamReader()
        # Strong reference to the reader: the base class may only hold it
        # through a weakref in newer CPython versions.
        self._stream_reader_orig = sr
        self._stream_reader_wr = weakref.ref(sr)
        self._stream_writer = None
        self.client = node.NeoNode(self)
        self._loop = events.get_event_loop()
        super().__init__(sr)

    def connection_made(self, transport: asyncio.transports.BaseTransport) -> None:
        """Create the stream writer and notify the client node."""
        super().connection_made(transport)
        self._stream_writer = StreamWriter(transport, self, self._stream_reader_orig, self._loop)

        if self.client:
            asyncio.create_task(self.client.connection_made(transport))

    def connection_lost(self, exc: Optional[Exception] = None) -> None:
        """Notify the client node first, then run the base-class teardown."""
        if self.client:
            task = asyncio.create_task(self.client.connection_lost(exc))
            task.add_done_callback(
                lambda args: super(NeoProtocol, self).connection_lost(exc))
        else:
            super().connection_lost(exc)

    async def send_message(self, message: Message) -> None:
        """Serialize and send *message*; on any failure tear the connection down.

        The caught exception is passed on to ``connection_lost``.
        """
        try:
            self._stream_writer.write(message.to_array())
            await self._stream_writer.drain()
        except ConnectionResetError as err:
            print("connection reset")
            self.connection_lost(err)
        except ConnectionError as err:
            print("connection error")
            self.connection_lost(err)
        except asyncio.CancelledError as err:
            print("task cancelled, closing connection")
            # mypy can't seem to deduce that CancelledError still derives from Exception
            self.connection_lost(err)  # type: ignore
        except Exception as err:
            print(f"***** woah what happened here?! {traceback.format_exc()}")
            self.connection_lost(err)

    async def read_message(self, timeout: Optional[int] = 30) -> Optional[Message]:
        """Read and deserialize one message; return None on any failure or timeout.

        Args:
            timeout: seconds to wait; ``0`` means wait indefinitely.
        """
        if timeout == 0:
            # avoid memleak. See: https://bugs.python.org/issue37042
            timeout = None

        async def _read():
            try:
                # readexactly can throw ConnectionResetError
                message_header = await self._stream_reader_orig.readexactly(3)
                # The 3rd header byte starts the payload length: 0xFD, 0xFE and
                # 0xFF announce a 2-, 4- or 8-byte little-endian length that
                # follows. Any other value is the length itself.
                payload_length = message_header[2]

                if payload_length == 0xFD:
                    len_bytes = await self._stream_reader_orig.readexactly(2)
                    payload_length, = struct.unpack("<H", len_bytes)
                elif payload_length == 0xFE:
                    len_bytes = await self._stream_reader_orig.readexactly(4)
                    payload_length, = struct.unpack("<I", len_bytes)
                elif payload_length == 0xFF:
                    # Was a duplicated `== 0xFE` test, which made this branch
                    # unreachable: 0xFF was taken as a literal 255-byte length
                    # and the stream desynchronised.
                    len_bytes = await self._stream_reader_orig.readexactly(8)
                    payload_length, = struct.unpack("<Q", len_bytes)
                else:
                    len_bytes = b''

                if payload_length > Message.PAYLOAD_MAX_SIZE:
                    raise ValueError("Invalid format")

                payload_data = await self._stream_reader_orig.readexactly(
                    payload_length)
                raw = message_header + len_bytes + payload_data
                with serialization.BinaryReader(raw) as br:
                    m = Message()
                    try:
                        m.deserialize(br)
                        return m
                    except Exception:
                        logger.debug(
                            f"Failed to deserialize message: {traceback.format_exc()}"
                        )
                        return None

            except (ConnectionResetError, ValueError):
                # ensures we break out of the main run() loop of Node, which triggers a disconnect callback to clean up
                self.client.disconnecting = True
                logger.debug(
                    f"Failed to read message data for reason: {traceback.format_exc()}"
                )
                return None
            except (asyncio.CancelledError, asyncio.IncompleteReadError):
                return None
            except Exception:
                logger.debug(f"error read message 1 {traceback.format_exc()}")
                return None

        try:
            return await asyncio.wait_for(_read(), timeout)
        except (asyncio.TimeoutError, asyncio.CancelledError):
            return None
        except Exception:
            logger.debug("error read message 2")
            traceback.print_exc()
            return None

    def disconnect(self) -> None:
        """Close the stream writer, if one was created."""
        if self._stream_writer:
            self._stream_writer.close()
class AsyncIOOWNConnection(object):
    """Client connection to a gateway speaking ``*...##`` framed packets.

    After the login handshake (plain password or HMAC-SHA2), every received
    packet is put on *queue* as ``(msg, mode)``. *mode* indexes
    ``SOCKET_MODES``, ``MODE_COLORS`` and ``MODE_NAMES``.
    """

    SOCKET_MODES = ('*99*0##', '*99*1##', '*99*9##')
    MODE_COLORS = (COLOR_LT_BLUE, COLOR_LT_GREEN, COLOR_LT_CYAN)
    MODE_NAMES = ('CMD', 'MON', 'SCMD')
    ACK = '*#*1##'

    def __init__(self, host, port, passwd, queue, mode, log=None, loop=None):
        import sys

        self.host = host
        self.port = port
        self.passwd = passwd
        self.queue = queue
        self.mode = mode
        self.loop = loop
        self._auto_restart = True
        self._run = False
        if self.loop is None:
            self.loop = asyncio.get_event_loop()
        if sys.version_info >= (3, 10):
            # asyncio.Event lost its `loop` argument in Python 3.10; the event
            # now binds to the running loop on first use.
            self.is_ready = asyncio.Event()
        else:
            self.is_ready = asyncio.Event(loop=self.loop)
        if log is None:
            hdr = '[%s:%d %s]' % \
                (self.host, self.port, self.MODE_NAMES[self.mode])
            self.log = \
                get_logger(header=hdr, color=self.MODE_COLORS[self.mode])
        else:
            self.log = log
        self.reader = None
        self.writer = None
        self.protocol = None
        self.transport = None
        # Current state handler; (re)set to state_start on each connection attempt.
        self.msg_handler = None
        # Expected server response, set during the HMAC-SHA2 handshake.
        self.hmac_server = None

    @property
    def auto_restart(self):
        """Whether ``run`` reconnects after the connection is lost."""
        return self._auto_restart

    @auto_restart.setter
    def auto_restart(self, value):
        self._auto_restart = value

    async def run(self):
        """Connect, log in and dispatch packets until ``stop`` is called.

        The loop rebuilds the reader/protocol after a disconnect. If
        auto_restart is off, it gives up after the first connection attempt.
        """
        self._run = True
        started = False
        # start with the login procedure
        # TODO: this is broken, it fails to restart properly
        ctr = 0
        while self._run:
            if self.protocol and not self.protocol.is_connected:
                # reset everything
                self.log('we have been disconnected - reset everything')
                self.reader = None
                self.writer = None
                self.protocol = None
                self.transport = None
            if not self.reader:
                self.log('create a StreamReader')
                self.reader = StreamReader(loop=self.loop)
            if not self.protocol:
                self.log('create OWNProtocol')
                self.protocol = OWNProtocol(self.reader, log=self.log)
            if not self.protocol.is_connected:
                if not self._auto_restart:
                    if started:
                        self.log("no autorestart... bailing out")
                        break
                started = True
                self.log('attempt connection')
                self.is_ready.clear()
                self.log('resetting the msg handler')
                self.msg_handler = self.state_start
                try:
                    self.log('create a new transport connection')
                    self.transport, _ = await self.loop.create_connection(
                        lambda: self.protocol, self.host, self.port)
                except gaierror:
                    self.log('sleep about 5 seconds')
                    await asyncio.sleep(5)
                else:
                    self.log('create a new StreamWriter')
                    self.writer = StreamWriter(
                        self.transport, self.protocol, self.reader, self.loop)
                    self.log('connection is up and running')
            try:
                pkt = await asyncio.wait_for(self.reader.readuntil(b'##'), timeout=1)
            except asyncio.TimeoutError:
                pass
            else:
                # from now on, packets are strings
                msg = pkt.decode('ascii')
                self.log('<= %s' % (msg))
                await self.msg_handler(msg)
                ctr += 1
        # The transport is None after a disconnect reset or if no connection
        # was ever made; calling close() on it would raise AttributeError.
        if self.transport is not None:
            self.transport.close()
        self.log('AsyncIOOWNConnection.run : %s the end' % (str(self)))

    def stop(self):
        """Ask ``run`` to leave its loop after the current iteration."""
        self.log('stop requested', LOG_DEBUG)
        self._run = False

    async def send_packet(self, msg):
        """Send *msg*; str is ASCII-encoded, anything else is written as-is."""
        self.log('=> %s' % (msg))
        # Previously non-str input left `pkt` unbound (UnboundLocalError).
        pkt = msg.encode('ascii') if isinstance(msg, str) else msg
        self.writer.write(pkt)
        await self.writer.drain()

    async def state_start(self, msg):
        """Expect the initial ACK, then request the configured socket mode."""
        if msg == self.ACK:
            self.msg_handler = self.state_login
            cmd_msg = self.SOCKET_MODES[self.mode]
            await self.send_packet(cmd_msg)
        else:
            self.log('we didn\'t get ACK')

    async def state_login(self, msg):
        """Handle the server's authentication request (HMAC-SHA2 or nonce password)."""
        if msg == '*98*2##':
            # this is a call for sha2 hmac authentication
            self.msg_handler = self.state_hmac_sha2
            await self.send_packet(self.ACK)
        else:
            # attempt matching the old password system
            ops_m = re.match(r'^\*#(\d+)##$', msg)
            if ops_m is not None:
                nonce = ops_m.groups()[0]
                # calculate the password
                passwd = ownCalcPass(self.passwd, nonce)
                passwd_msg = ('*#%s##' % (passwd))
                self.msg_handler = self.state_auth
                await self.send_packet(passwd_msg)
            else:
                self.log("unable to parse the openpassword nonce request")
                await asyncio.sleep(0)

    async def state_hmac_sha2(self, msg):
        """Answer the server's 128-digit challenge with our nonce and HMAC."""
        ra_m = re.match(r'^\*#(\d{128})##$', msg)
        if ra_m is not None:
            ra = ra_m.group(1)
            # Named hmac_values so the stdlib `hmac` module name is not shadowed.
            hmac_values = ownCalcHmacSha2(self.passwd, ra)
            rb, hmac_client, self.hmac_server = hmac_values
            self.msg_handler = self.state_hmac_sha2_check_response
            hmac_packet = ('*#%s*%s##' % (rb, hmac_client))
            await self.send_packet(hmac_packet)
        else:
            self.log("unable to parse the hmac_sha2 request")
            await asyncio.sleep(0)

    async def state_hmac_sha2_check_response(self, msg):
        """Verify the server's HMAC; on success mark ready and switch to dispatching."""
        c_resp_m = re.match(r'^\*#(\d{128})##$', msg)
        if c_resp_m is not None:
            c_resp = c_resp_m.group(1)
            if c_resp == self.hmac_server:
                self.is_ready.set()
                self.msg_handler = self.state_dispatch
                await self.send_packet(self.ACK)
            else:
                # Format here: the second positional argument of self.log is a
                # level elsewhere in this class (see stop()), not a format arg.
                self.log("wrong response from server, expected %s" % (self.hmac_server,))
                await asyncio.sleep(0)
        else:
            self.log("unable to parse the hmac_sha2 request")
            await asyncio.sleep(0)

    async def state_auth(self, msg):
        """Wait for the ACK that confirms the plain-password login."""
        if msg == self.ACK:
            self.is_ready.set()
            self.msg_handler = self.state_dispatch

    async def state_dispatch(self, msg):
        """Forward a received packet to the queue, tagged with this connection's mode."""
        await self.queue.put((msg, self.mode, ))
async def initial_connection_check(self, reader: StreamReader, writer: StreamWriter):
    """Run the server side of the handshake for a newly accepted peer connection.

    Steps (numbered as in the body):
    1. send b'hello';
    2. read the peer's JSON header (5 s timeout);
    3. build a User (rejecting our own server name);
    4-5. exchange public keys;
    6. send our AES key and server header, encrypted with the shared key;
    7. expect an encrypted b'accept';
    8. start the receive loop and check the peer's reachability.

    On failure an error message is built. If a User was already created it is
    handed to ``remove_connection``. Otherwise the message is sent to the peer
    on a best-effort basis and the socket is closed.
    """
    host_port = writer.get_extra_info('peername')
    # Stays None until step 3; the failure path uses it to choose between
    # remove_connection() and closing the raw socket.
    new_user: Optional[User] = None
    try:
        # 1. send plain message
        writer.write(b'hello')
        await writer.drain()

        # 2. receive other's header
        try:
            received = await asyncio.wait_for(reader.read(BUFFER_SIZE), 5.0)
            if len(received) == 0:
                raise PeerToPeerError('empty msg receive')
            header = json.loads(received.decode())
        except asyncio.TimeoutError:
            raise PeerToPeerError('timeout on other\'s header receive')
        except json.JSONDecodeError:
            raise PeerToPeerError(
                'json decode error on other\'s header receive')

        # 3. generate new user
        # The decoded header is passed as keyword arguments to UserHeader.
        user_header = UserHeader(**header)
        new_user = User(user_header, self.number, reader, writer, host_port, AESCipher.create_key(), SERVER_SIDE)
        # Incremented for every User created, even if the handshake fails later.
        self.number += 1
        if new_user.header.name == V.SERVER_NAME:
            raise ConnectionAbortedError('Same origin connection')

        # 4. send my public key
        my_sec, my_pub = generate_keypair()
        send = json.dumps({'public-key': my_pub}).encode()
        await new_user.send(send)
        self.traffic.put_traffic_up(send)

        # 5. receive public key
        # NOTE(review): asyncio.TimeoutError is handled here, so new_user.recv()
        # presumably applies its own timeout — confirm in User.recv.
        try:
            receive = await new_user.recv()
            self.traffic.put_traffic_down(receive)
            if len(receive) == 0:
                raise ConnectionAbortedError('received msg is zero.')
            data = json.loads(receive.decode())
        except asyncio.TimeoutError:
            raise PeerToPeerError('timeout on public key receive')
        except json.JSONDecodeError:
            raise PeerToPeerError(
                'json decode error on public key receive')

        # 6. encrypt and send AES key and header
        send = json.dumps({
            'aes-key': new_user.aeskey,
            'header': self.get_server_header(),
        })
        # Shared key from our secret key and the peer's public key.
        key = generate_shared_key(my_sec, data['public-key'])
        encrypted = AESCipher.encrypt(key, send.encode())
        await new_user.send(encrypted)
        self.traffic.put_traffic_up(encrypted)

        # 7. receive accept signal
        try:
            encrypted = await new_user.recv()
            self.traffic.put_traffic_down(encrypted)
        except asyncio.TimeoutError:
            raise PeerToPeerError('timeout on accept signal receive')
        # From here on the peer talks using the AES key we sent in step 6.
        receive = AESCipher.decrypt(new_user.aeskey, encrypted)
        if receive != b'accept':
            raise PeerToPeerError(f"Not accept signal! {receive}")

        # 8. accept connection
        log.info(
            f"established connection as server from {new_user.header.name} {new_user.get_host_port()}"
        )
        # NOTE(review): the task returned by ensure_future is not kept; the
        # asyncio docs advise holding a reference so it cannot be GC'd mid-run.
        asyncio.ensure_future(self.receive_loop(new_user))
        # server port's reachable check
        await asyncio.sleep(1.0)
        await self.check_reachable(new_user)
        return
    except (ConnectionAbortedError, ConnectionResetError) as e:
        msg = f"disconnect error {host_port} {e}"
    except PeerToPeerError as e:
        msg = f"peer2peer error {host_port} {e}"
    except Exception as e:
        msg = "InitialConnCheck: {}".format(e)
        log.error(msg, exc_info=True)

    # EXCEPTION!
    if new_user:
        # remove user
        self.remove_connection(new_user, msg)
    else:
        # close socket
        log.debug(msg)
        # Best effort: tell the peer why it is being dropped, ignoring any
        # write or close failure.
        try:
            writer.write(msg.encode())
            await writer.drain()
        except Exception:
            pass
        try:
            writer.close()
        except Exception:
            pass
def write_to(self, writer: StreamWriter) -> None:
    """Write this object to *writer*: the UTF-8 encoded header, then the raw body."""
    head_bytes = self.header.encode('utf-8')
    writer.write(head_bytes)
    writer.write(self.body)