class Proto(asyncio.Protocol):
    """Minimal benchmark protocol: parses an HTTP request with httptools,
    expects a JSON body ``{"a": ..., "b": ...}`` and answers with
    ``{"result": a + b}``.
    """

    def connection_made(self, transport):
        self.transport = transport
        self.parser = HttpRequestParser(self)
        self.body = None

    def connection_lost(self, exc):
        self.transport = None

    def data_received(self, data):
        self.parser.feed_data(data)

    def on_body(self, body):
        self.body = body

    def on_message_complete(self):
        request = ujson.loads(self.body)
        response = {'result': request['a'] + request['b']}
        self.send(ujson.dumps(response).encode('utf8'))
        self.body = None

    def send(self, data):
        # Bug fix: HTTP/1.1 requires CRLF ("\r\n") line terminators
        # (RFC 7230); the original used bare "\n", which strict clients
        # and proxies reject.
        response = (b'HTTP/1.1 200 OK\r\nContent-Length: ' +
                    str(len(data)).encode() + b'\r\n\r\n' + data)
        self.transport.write(response)
        if not self.parser.should_keep_alive():
            # Client did not ask for keep-alive: close and drop the
            # transport reference.
            self.transport.close()
            self.transport = None
class HttpProtocol(asyncio.Protocol):
    """Asyncio protocol that parses HTTP requests with httptools and hands
    each complete request to ``request_handler``.

    Also enforces a per-connection request timeout and an optional maximum
    request size, and supports HTTP keep-alive.
    """
    # Bug fix: the slots list was out of sync with the attributes the class
    # actually sets ('_last_communication_time' was never used, while
    # 'error_handler', '_last_request_time' and '_request_handler_task'
    # were missing).
    __slots__ = (
        # event loop, connection
        'loop', 'transport', 'connections', 'signal',
        # request params
        'parser', 'request', 'url', 'headers',
        # request config
        'request_handler', 'error_handler', 'request_timeout',
        'request_max_size',
        # connection management
        '_total_request_size', '_timeout_handler', '_last_request_time',
        '_request_handler_task')

    def __init__(self, *, loop, request_handler, error_handler,
                 signal=None, connections=None, request_timeout=60,
                 request_max_size=None):
        self.loop = loop
        self.transport = None
        self.request = None
        self.parser = None
        self.url = None
        self.headers = None
        # Bug fix: the previous defaults (signal=Signal(), connections={})
        # were evaluated once at definition time and shared by every
        # instance; build a fresh object per connection instead.
        self.signal = signal if signal is not None else Signal()
        self.connections = connections if connections is not None else {}
        self.request_handler = request_handler
        self.error_handler = error_handler
        self.request_timeout = request_timeout
        self.request_max_size = request_max_size  # None means "no limit"
        self._total_request_size = 0
        self._timeout_handler = None
        self._last_request_time = None
        self._request_handler_task = None

    # -------------------------------------------- #
    # Connection
    # -------------------------------------------- #

    def connection_made(self, transport):
        self.connections[self] = True
        self._timeout_handler = self.loop.call_later(
            self.request_timeout, self.connection_timeout)
        self.transport = transport
        self._last_request_time = current_time

    def connection_lost(self, exc):
        del self.connections[self]
        self._timeout_handler.cancel()
        self.cleanup()

    def connection_timeout(self):
        # Check if the elapsed time since the last request is still below
        # the timeout; if so, re-arm the timer for the remaining interval
        # instead of killing an active connection.
        time_elapsed = current_time - self._last_request_time
        if time_elapsed < self.request_timeout:
            time_left = self.request_timeout - time_elapsed
            self._timeout_handler = \
                self.loop.call_later(time_left, self.connection_timeout)
        else:
            if self._request_handler_task:
                self._request_handler_task.cancel()
            response = self.error_handler.response(
                self.request, RequestTimeout('Request Timeout'))
            self.write_response(response)

    # -------------------------------------------- #
    # Parsing
    # -------------------------------------------- #

    def data_received(self, data):
        # Check for the request itself getting too large and exceeding
        # memory limits.  Bug fix: request_max_size=None previously raised
        # TypeError on the comparison; treat None as "unlimited".
        self._total_request_size += len(data)
        if (self.request_max_size is not None
                and self._total_request_size > self.request_max_size):
            return self.bail_out(
                "Request too large ({}), connection closed".format(
                    self._total_request_size))

        # Create parser if this is the first time we're receiving data
        if self.parser is None:
            assert self.request is None
            self.headers = []
            self.parser = HttpRequestParser(self)

        # Parse request chunk or close connection
        try:
            self.parser.feed_data(data)
        except HttpParserError as e:
            self.bail_out(
                "Invalid request data, connection closed ({})".format(e))

    def on_url(self, url):
        self.url = url

    def on_header(self, name, value):
        # Reject oversized bodies early via Content-Length (None = no limit).
        if (name == b'Content-Length' and self.request_max_size is not None
                and int(value) > self.request_max_size):
            return self.bail_out(
                "Request body too large ({}), connection closed".format(
                    value))
        self.headers.append((name.decode(), value.decode('utf-8')))

    def on_headers_complete(self):
        # Record the peer address as a pseudo-header for the handler.
        remote_addr = self.transport.get_extra_info('peername')
        if remote_addr:
            self.headers.append(('Remote-Addr', '%s:%s' % remote_addr))

        self.request = Request(url_bytes=self.url,
                               headers=CIMultiDict(self.headers),
                               version=self.parser.get_http_version(),
                               method=self.parser.get_method().decode())

    def on_body(self, body):
        if self.request.body:
            self.request.body += body
        else:
            self.request.body = body

    def on_message_complete(self):
        self._request_handler_task = self.loop.create_task(
            self.request_handler(self.request, self.write_response))

    # -------------------------------------------- #
    # Responding
    # -------------------------------------------- #

    def write_response(self, response):
        try:
            keep_alive = self.parser.should_keep_alive() \
                and not self.signal.stopped
            self.transport.write(
                response.output(self.request.version, keep_alive,
                                self.request_timeout))
            if not keep_alive:
                self.transport.close()
            else:
                # Record that we received data and reset for the next
                # request on this keep-alive connection.
                self._last_request_time = current_time
                self.cleanup()
        except Exception as e:
            self.bail_out(
                "Writing response failed, connection closed {}".format(e))

    def bail_out(self, message):
        # Log at debug level and drop the connection.
        log.debug(message)
        self.transport.close()

    def cleanup(self):
        # Reset per-request state so the connection can serve another
        # request.
        self.parser = None
        self.request = None
        self.url = None
        self.headers = None
        self._request_handler_task = None
        self._total_request_size = 0

    def close_if_idle(self):
        """
        Close the connection if a request is not being sent or received
        :return: boolean - True if closed, false if staying open
        """
        if not self.parser:
            self.transport.close()
            return True
        return False
class HttpProtocol(asyncio.Protocol):
    """Asyncio protocol that parses HTTP requests with httptools and hands
    each complete request to ``request_handler``.

    Supports both synchronous (``write_response``) and streamed
    (``stream_response``) responses, keep-alive, a per-connection timeout
    and an optional maximum request size.
    """
    # Bug fix: the slots list was out of sync with the attributes actually
    # set ('_last_communication_time' was never used; 'error_handler',
    # '_last_request_time' and '_request_handler_task' were missing).
    __slots__ = (
        # event loop, connection
        'loop', 'transport', 'connections', 'signal',
        # request params
        'parser', 'request', 'url', 'headers',
        # request config
        'request_handler', 'error_handler', 'request_timeout',
        'request_max_size',
        # connection management
        '_total_request_size', '_timeout_handler', '_last_request_time',
        '_request_handler_task')

    def __init__(self, *, loop, request_handler, error_handler,
                 signal=None, connections=None, request_timeout=60,
                 request_max_size=None):
        self.loop = loop
        self.transport = None
        self.request = None
        self.parser = None
        self.url = None
        self.headers = None
        # Bug fix: the previous defaults (signal=Signal(),
        # connections=set()) were evaluated once at definition time and
        # shared by every instance; build a fresh object per connection.
        self.signal = signal if signal is not None else Signal()
        self.connections = connections if connections is not None else set()
        self.request_handler = request_handler
        self.error_handler = error_handler
        self.request_timeout = request_timeout
        self.request_max_size = request_max_size  # None means "no limit"
        self._total_request_size = 0
        self._timeout_handler = None
        self._last_request_time = None
        self._request_handler_task = None

    # -------------------------------------------- #
    # Connection
    # -------------------------------------------- #

    def connection_made(self, transport):
        self.connections.add(self)
        self._timeout_handler = self.loop.call_later(
            self.request_timeout, self.connection_timeout)
        self.transport = transport
        self._last_request_time = current_time

    def connection_lost(self, exc):
        self.connections.discard(self)
        self._timeout_handler.cancel()

    def connection_timeout(self):
        # Check if the elapsed time since the last request is still below
        # the timeout; if so, re-arm the timer for the remaining interval.
        time_elapsed = current_time - self._last_request_time
        if time_elapsed < self.request_timeout:
            time_left = self.request_timeout - time_elapsed
            self._timeout_handler = (
                self.loop.call_later(time_left, self.connection_timeout))
        else:
            if self._request_handler_task:
                self._request_handler_task.cancel()
            exception = RequestTimeout('Request Timeout')
            self.write_error(exception)

    # -------------------------------------------- #
    # Parsing
    # -------------------------------------------- #

    def data_received(self, data):
        # Check for the request itself getting too large and exceeding
        # memory limits.  Bug fixes: request_max_size=None previously
        # raised TypeError, and processing continued after the error was
        # written; return immediately instead.
        self._total_request_size += len(data)
        if (self.request_max_size is not None
                and self._total_request_size > self.request_max_size):
            exception = PayloadTooLarge('Payload Too Large')
            return self.write_error(exception)

        # Create parser if this is the first time we're receiving data
        if self.parser is None:
            assert self.request is None
            self.headers = []
            self.parser = HttpRequestParser(self)

        # Parse request chunk or close connection
        try:
            self.parser.feed_data(data)
        except HttpParserError:
            exception = InvalidUsage('Bad Request')
            self.write_error(exception)

    def on_url(self, url):
        self.url = url

    def on_header(self, name, value):
        # Bug fix: return after rejecting an oversized body so the header
        # is not recorded for a request we are refusing (None = no limit).
        if (name == b'Content-Length' and self.request_max_size is not None
                and int(value) > self.request_max_size):
            exception = PayloadTooLarge('Payload Too Large')
            return self.write_error(exception)
        # Header names are case-insensitive; normalize with casefold.
        self.headers.append((name.decode().casefold(), value.decode()))

    def on_headers_complete(self):
        self.request = Request(
            url_bytes=self.url,
            headers=CIDict(self.headers),
            version=self.parser.get_http_version(),
            method=self.parser.get_method().decode(),
            transport=self.transport
        )

    def on_body(self, body):
        # Body chunks are accumulated in a list and joined once complete.
        self.request.body.append(body)

    def on_message_complete(self):
        if self.request.body:
            self.request.body = b''.join(self.request.body)
        self._request_handler_task = self.loop.create_task(
            self.request_handler(
                self.request, self.write_response, self.stream_response))

    # -------------------------------------------- #
    # Responding
    # -------------------------------------------- #

    def write_response(self, response):
        """
        Writes response content synchronously to the transport.
        """
        # Bug fix: keep_alive was only bound inside the try block, so a
        # failure in should_keep_alive() raised NameError in the finally
        # clause.  Default to closing the connection.
        keep_alive = False
        try:
            keep_alive = (
                self.parser.should_keep_alive() and not self.signal.stopped)
            self.transport.write(
                response.output(
                    self.request.version, keep_alive, self.request_timeout))
        except AttributeError:
            log.error(
                ('Invalid response object for url {}, '
                 'Expected Type: HTTPResponse, Actual Type: {}').format(
                    self.url, type(response)))
            self.write_error(ServerError('Invalid response type'))
        except RuntimeError:
            log.error(
                'Connection lost before response written @ {}'.format(
                    self.request.ip))
        except Exception as e:
            self.bail_out(
                "Writing response failed, connection closed {}".format(
                    repr(e)))
        finally:
            if not keep_alive:
                self.transport.close()
            else:
                # Record that we received data and reset for the next
                # request on this keep-alive connection.
                self._last_request_time = current_time
                self.cleanup()

    async def stream_response(self, response):
        """
        Streams a response to the client asynchronously. Attaches the
        transport to the response so the response consumer can write to
        the response as needed.
        """
        # Same unbound-name fix as in write_response.
        keep_alive = False
        try:
            keep_alive = (
                self.parser.should_keep_alive() and not self.signal.stopped)
            response.transport = self.transport
            await response.stream(
                self.request.version, keep_alive, self.request_timeout)
        except AttributeError:
            log.error(
                ('Invalid response object for url {}, '
                 'Expected Type: HTTPResponse, Actual Type: {}').format(
                    self.url, type(response)))
            self.write_error(ServerError('Invalid response type'))
        except RuntimeError:
            log.error(
                'Connection lost before response written @ {}'.format(
                    self.request.ip))
        except Exception as e:
            self.bail_out(
                "Writing response failed, connection closed {}".format(
                    repr(e)))
        finally:
            if not keep_alive:
                self.transport.close()
            else:
                self._last_request_time = current_time
                self.cleanup()

    def write_error(self, exception):
        # Render the exception through the configured error handler and
        # always close the connection afterwards.
        try:
            response = self.error_handler.response(self.request, exception)
            version = self.request.version if self.request else '1.1'
            self.transport.write(response.output(version))
        except RuntimeError:
            log.error(
                'Connection lost before error written @ {}'.format(
                    self.request.ip if self.request else 'Unknown'))
        except Exception as e:
            self.bail_out(
                "Writing error failed, connection closed {}".format(repr(e)),
                from_error=True)
        finally:
            self.transport.close()

    def bail_out(self, message, from_error=False):
        # from_error=True means we are already inside error handling; do
        # not attempt to write another error response (avoids recursion).
        if from_error or self.transport.is_closing():
            log.error(
                ("Transport closed @ {} and exception "
                 "experienced during error handling").format(
                    self.transport.get_extra_info('peername')))
            log.debug(
                'Exception:\n{}'.format(traceback.format_exc()))
        else:
            exception = ServerError(message)
            self.write_error(exception)
            log.error(message)

    def cleanup(self):
        # Reset per-request state so the connection can serve another
        # request.
        self.parser = None
        self.request = None
        self.url = None
        self.headers = None
        self._request_handler_task = None
        self._total_request_size = 0

    def close_if_idle(self):
        """Close the connection if a request is not being sent or received

        :return: boolean - True if closed, false if staying open
        """
        if not self.parser:
            self.transport.close()
            return True
        return False
class HttpProtocol(asyncio.Protocol):
    """Asyncio protocol that parses HTTP requests with httptools and hands
    each complete request to ``request_handler``.

    Tracks all live connections, enforces a per-connection timeout and an
    optional maximum request size.
    """
    __slots__ = (
        # event loop, connection
        'loop', 'transport', 'connections', 'signal',
        # request params
        'parser', 'request', 'url', 'headers',
        # request config
        'request_handler', 'error_handler', 'request_timeout',
        'request_max_size',
        # connection management
        '_total_request_size', '_timeout_handler', '_last_request_time',
        '_request_handler_task')

    def __init__(self, *, loop, request_handler, error_handler,
                 signal=None, connections=None, request_timeout=60,
                 request_max_size=None):
        self.loop = loop
        self.transport = None
        self.request = None              # current request
        self.parser = None
        self.url = None
        self.headers = None              # request headers
        # Bug fix: the previous default connections={} was a dict, but the
        # code calls .add()/.discard() on it (set methods) -- the default
        # raised AttributeError.  Also avoid mutable defaults shared by
        # every instance (signal=Signal(), connections evaluated once at
        # definition time).
        self.signal = signal if signal is not None else Signal()
        self.connections = connections if connections is not None else set()
        self.request_handler = request_handler   # request handler
        self.error_handler = error_handler       # error handler
        self.request_timeout = request_timeout
        self.request_max_size = request_max_size  # None means "no limit"
        self._total_request_size = 0
        self._timeout_handler = None
        self._last_request_time = None
        self._request_handler_task = None

    # -------------------------------------------- #
    # Connection
    # -------------------------------------------- #

    def connection_made(self, transport):
        self.connections.add(self)
        self._timeout_handler = self.loop.call_later(self.request_timeout,
                                                     self.connection_timeout)
        self.transport = transport
        self._last_request_time = current_time

    def connection_lost(self, exc):
        self.connections.discard(self)
        self._timeout_handler.cancel()
        self.cleanup()

    def connection_timeout(self):
        # Check the interval since the last request; re-arm the timer if
        # the connection is not idle long enough yet.
        time_elapsed = current_time - self._last_request_time
        if time_elapsed < self.request_timeout:
            time_left = self.request_timeout - time_elapsed
            self._timeout_handler = \
                self.loop.call_later(time_left, self.connection_timeout)
        else:
            # Timed out: cancel the in-flight handler and report the error.
            if self._request_handler_task:
                self._request_handler_task.cancel()
            exception = RequestTimeout('Request Timeout')
            self.write_error(exception)

    # -------------------------------------------- #
    # Parsing
    # -------------------- ------------------------ #

    def data_received(self, data):
        # Reject requests that grow beyond the configured maximum size
        # (None disables the check; the previous code raised TypeError
        # comparing against None and kept parsing after the error).
        self._total_request_size += len(data)
        if (self.request_max_size is not None
                and self._total_request_size > self.request_max_size):
            exception = PayloadTooLarge('Payload Too Large')
            return self.write_error(exception)

        # Create the parser the first time we receive data.
        if self.parser is None:
            assert self.request is None
            self.headers = []
            self.parser = HttpRequestParser(self)

        # Parse the request chunk or close the connection.
        try:
            self.parser.feed_data(data)
        except HttpParserError:
            exception = InvalidUsage('Bad Request')
            self.write_error(exception)

    def on_url(self, url):
        self.url = url

    # Collect one request header; reject oversized bodies early via
    # Content-Length.
    def on_header(self, name, value):
        if (name == b'Content-Length' and self.request_max_size is not None
                and int(value) > self.request_max_size):
            exception = PayloadTooLarge('Payload Too Large')
            return self.write_error(exception)
        self.headers.append((name.decode(), value.decode('utf-8')))

    # All headers received: build the Request object.
    def on_headers_complete(self):
        remote_addr = self.transport.get_extra_info('peername')
        if remote_addr:
            self.headers.append(('Remote-Addr', '%s:%s' % remote_addr))

        self.request = Request(url_bytes=self.url,
                               headers=CIMultiDict(self.headers),
                               version=self.parser.get_http_version(),
                               method=self.parser.get_method().decode())

    # Append a body chunk to the request.
    def on_body(self, body):
        if self.request.body:
            self.request.body += body
        else:
            self.request.body = body

    def on_message_complete(self):
        # Dispatch the complete request to the handler as a task.
        self._request_handler_task = self.loop.create_task(
            self.request_handler(self.request, self.write_response))

    # -------------------------------------------- #
    # Responding
    # -------------------------------------------- #

    def write_response(self, response):
        """Write a normal HTTP response; keep the connection open and
        refresh the timeout on keep-alive."""
        try:
            keep_alive = self.parser.should_keep_alive(
            ) and not self.signal.stopped
            self.transport.write(
                response.output(
                    self.request.version, keep_alive, self.request_timeout))
            if not keep_alive:
                # Not keep-alive: close the connection.
                self.transport.close()
            else:
                # Record the activity time and reset for the next request.
                self._last_request_time = current_time
                self.cleanup()
        except Exception as e:
            self.bail_out(
                "Writing response failed, connection closed {}".format(e))

    def write_error(self, exception):
        """Write an error response built by the error handler, then close."""
        try:
            response = self.error_handler.response(self.request, exception)
            version = self.request.version if self.request else '1.1'
            self.transport.write(response.output(version))
            self.transport.close()
        except Exception as e:
            self.bail_out(
                "Writing error failed, connection closed {}".format(e))

    def bail_out(self, message):
        # Last-resort error path: report a ServerError and log it.
        exception = ServerError(message)
        self.write_error(exception)
        log.error(message)

    def cleanup(self):
        # Reset per-request fields so the connection can serve another
        # request.
        self.parser = None
        self.request = None
        self.url = None
        self.headers = None
        self._request_handler_task = None
        self._total_request_size = 0

    def close_if_idle(self):
        """
        Close the connection if a request is not being sent or received
        :return: boolean - True if closed, false if staying open
        """
        if not self.parser:
            self.transport.close()
            return True
        return False
class Protocol(asyncio.Protocol):
    """Responsible of parsing the request and writing the response.

    You can subclass it to set your own `Query`, `Request` or `Response`
    classes.
    """

    # Bug fix: the slots list named 'req'/'resp' while the code sets
    # 'request'/'response'; keep it in sync with the real attributes.
    __slots__ = ('app', 'request', 'parser', 'response', 'writer')
    Query = Query
    Request = Request
    Response = Response

    def __init__(self, app):
        self.app = app
        self.parser = HttpRequestParser(self)

    def data_received(self, data: bytes):
        try:
            self.parser.feed_data(data)
        except HttpParserError:
            # If the parsing failed before on_message_begin, we don't have a
            # response yet, so build one to carry the error.  Bug fix: use
            # self.Response so subclasses overriding Response are honoured.
            self.response = self.Response()
            self.response.status = HTTPStatus.BAD_REQUEST
            self.response.body = b'Unparsable request'
            self.write()

    def connection_made(self, transport):
        self.writer = transport

    # All on_xxx methods are in use by httptools parser.
    # See https://github.com/MagicStack/httptools#apis

    def on_header(self, name: bytes, value: bytes):
        self.request.headers[name.decode()] = value.decode()

    def on_body(self, body: bytes):
        self.request.body += body

    def on_url(self, url: bytes):
        self.request.url = url
        parsed = parse_url(url)
        self.request.path = parsed.path.decode()
        self.request.query_string = (parsed.query or b'').decode()
        parsed_qs = parse_qs(self.request.query_string,
                             keep_blank_values=True)
        self.request.query = self.Query(parsed_qs)

    def on_message_begin(self):
        self.request = self.Request()
        self.response = self.Response()

    def on_message_complete(self):
        self.request.method = self.parser.get_method().decode().upper()
        task = self.app.loop.create_task(self.app(self.request,
                                                  self.response))
        task.add_done_callback(self.write)

    # May or may not have "future" as arg.
    def write(self, *args):
        # Appends bytes for performances.
        payload = b'HTTP/1.1 %a %b\r\n' % (
            self.response.status.value, self.response.status.phrase.encode())
        if not isinstance(self.response.body, bytes):
            self.response.body = self.response.body.encode()
        if 'Content-Length' not in self.response.headers:
            length = len(self.response.body)
            self.response.headers['Content-Length'] = str(length)
        for key, value in self.response.headers.items():
            payload += b'%b: %b\r\n' % (key.encode(), str(value).encode())
        payload += b'\r\n%b' % self.response.body
        self.writer.write(payload)
        if not self.parser.should_keep_alive():
            self.writer.close()
class Channel:
    """Bridges a raw socket to an httptools request parser.

    Received bytes are fed to HttpRequestParser; complete requests are
    yielded one by one through ``async for request in channel`` (see
    ``__aiter__``).
    """

    __slots__ = (
        'parser', 'request', 'complete', 'headers_complete', 'socket',
        'reader',
    )

    def __init__(self, socket):
        # Per-request parser state flags.
        self.complete = False
        self.headers_complete = False
        self.parser = HttpRequestParser(self)
        self.request = None
        self.socket = socket
        # Async generator yielding raw chunks until the request completes.
        self.reader = self._reader()

    def data_received(self, data: bytes):
        # Feed raw bytes to the parser; the parser calls back into the
        # on_* methods below.
        try:
            self.parser.feed_data(data)
        except HttpParserUpgrade:
            # Protocol upgrade requested (e.g. websocket); flag it on the
            # request instead of treating it as an error.
            self.request.upgrade = True
        except (HttpParserError, HttpParserInvalidMethodError) as exc:
            # We should log the exc.
            raise HTTPError(
                HTTPStatus.BAD_REQUEST, 'Unparsable request.')

    async def read(self, parse: bool=True) -> bytes:
        # Read one chunk from the socket; optionally feed it to the
        # parser.  Returns the chunk (falsy on EOF).
        data = await self.socket.recv(1024)
        if data:
            if parse:
                self.data_received(data)
        return data

    async def _reader(self) -> bytes:
        # Yield parsed chunks until the current request is complete or the
        # peer closes the connection.
        while not self.complete:
            data = await self.read()
            if not data:
                break
            yield data

    async def _drainer(self) -> bytes:
        # Consume (without parsing) whatever remains of an abandoned
        # request body.
        while True:
            data = await self.read(parse=False)
            if not data:
                break
            yield data

    def on_header(self, name: bytes, value: bytes):
        value = value.decode()
        if value:
            name = name.decode().title()
            # Repeated headers are folded into a comma-separated value.
            if name in self.request.headers:
                self.request.headers[name] += ', {}'.format(value)
            else:
                self.request.headers[name] = value

    def on_body(self, data: bytes):
        self.request.body += data

    def on_message_begin(self):
        # New request starting: reset completion flag and allocate the
        # Request bound to this socket and chunk reader.
        self.complete = False
        self.request = Request(self.socket, self.reader)

    def on_message_complete(self):
        self.complete = True

    def on_url(self, url: bytes):
        self.request.url = url
        parsed = parse_url(url)
        self.request.path = unquote(parsed.path.decode())
        self.request.query_string = (parsed.query or b'').decode()

    def on_headers_complete(self):
        self.request.keep_alive = self.parser.should_keep_alive()
        self.request.method = self.parser.get_method().decode().upper()
        self.headers_complete = True

    async def __aiter__(self):
        # Yield one Request per HTTP message for as long as the client
        # keeps the connection alive.
        keep_alive = True
        while keep_alive:
            data = await self.read()
            if data is None:
                break
            if self.headers_complete:
                yield self.request
                keep_alive = self.request.keep_alive
                if keep_alive:
                    if not self.complete:
                        await self.reader.aclose()
                        # We drain if there's an uncomplete request.
                        async for _ in self._drainer():
                            pass
                    # Reset per-request state for the next message on this
                    # keep-alive connection.
                    self.request = None
                    self.complete = False
                    self.headers_complete = False
                    self.reader = self._reader()
class HttpProtocol(asyncio.Protocol):
    """Asyncio protocol that parses HTTP requests with httptools and hands
    each complete request to ``request_handler``.

    Tracks live connections, enforces a per-connection timeout and an
    optional maximum request size, and supports keep-alive.
    """
    # Bug fix: the slots list was out of sync with the attributes actually
    # set ('_last_communication_time' was never used; 'error_handler',
    # '_last_request_time' and '_request_handler_task' were missing).
    __slots__ = (
        # event loop, connection
        'loop', 'transport', 'connections', 'signal',
        # request params
        'parser', 'request', 'url', 'headers',
        # request config
        'request_handler', 'error_handler', 'request_timeout',
        'request_max_size',
        # connection management
        '_total_request_size', '_timeout_handler', '_last_request_time',
        '_request_handler_task')

    def __init__(self, *, loop, request_handler, error_handler,
                 signal=None, connections=None, request_timeout=60,
                 request_max_size=None):
        self.loop = loop
        self.transport = None
        self.request = None
        self.parser = None
        self.url = None
        self.headers = None
        # Bug fix: the previous defaults (signal=Signal(),
        # connections=set()) were evaluated once at definition time and
        # shared by every instance; build a fresh object per connection.
        self.signal = signal if signal is not None else Signal()
        self.connections = connections if connections is not None else set()
        self.request_handler = request_handler
        self.error_handler = error_handler
        self.request_timeout = request_timeout
        self.request_max_size = request_max_size  # None means "no limit"
        self._total_request_size = 0
        self._timeout_handler = None
        self._last_request_time = None
        self._request_handler_task = None

    # -------------------------------------------- #
    # Connection
    # -------------------------------------------- #

    def connection_made(self, transport):
        self.connections.add(self)
        self._timeout_handler = self.loop.call_later(
            self.request_timeout, self.connection_timeout)
        self.transport = transport
        self._last_request_time = current_time

    def connection_lost(self, exc):
        self.connections.discard(self)
        self._timeout_handler.cancel()

    def connection_timeout(self):
        # Check the interval since the last request; re-arm the timer if
        # the connection is not idle long enough yet.
        time_elapsed = current_time - self._last_request_time
        if time_elapsed < self.request_timeout:
            time_left = self.request_timeout - time_elapsed
            self._timeout_handler = (
                self.loop.call_later(time_left, self.connection_timeout))
        else:
            if self._request_handler_task:
                self._request_handler_task.cancel()
            exception = RequestTimeout('Request Timeout')
            self.write_error(exception)

    # -------------------------------------------- #
    # Parsing
    # -------------------------------------------- #

    def data_received(self, data):
        # Check for the request itself getting too large and exceeding
        # memory limits.  Bug fixes: request_max_size=None previously
        # raised TypeError, and parsing continued after the error was
        # written; return immediately instead.
        self._total_request_size += len(data)
        if (self.request_max_size is not None
                and self._total_request_size > self.request_max_size):
            exception = PayloadTooLarge('Payload Too Large')
            return self.write_error(exception)

        # Create parser if this is the first time we're receiving data
        if self.parser is None:
            assert self.request is None
            self.headers = []
            self.parser = HttpRequestParser(self)

        # Parse request chunk or close connection
        try:
            self.parser.feed_data(data)
        except HttpParserError:
            exception = InvalidUsage('Bad Request')
            self.write_error(exception)

    def on_url(self, url):
        self.url = url

    def on_header(self, name, value):
        # Bug fix: return after rejecting an oversized body so the header
        # is not recorded for a request we are refusing (None = no limit).
        if (name == b'Content-Length' and self.request_max_size is not None
                and int(value) > self.request_max_size):
            exception = PayloadTooLarge('Payload Too Large')
            return self.write_error(exception)
        # Header names are case-insensitive; normalize with casefold.
        self.headers.append((name.decode().casefold(), value.decode()))

    def on_headers_complete(self):
        self.request = Request(
            url_bytes=self.url,
            headers=CIDict(self.headers),
            version=self.parser.get_http_version(),
            method=self.parser.get_method().decode(),
            transport=self.transport
        )

    def on_body(self, body):
        # Body chunks are accumulated in a list and joined once complete.
        self.request.body.append(body)

    def on_message_complete(self):
        if self.request.body:
            self.request.body = b''.join(self.request.body)
        self._request_handler_task = self.loop.create_task(
            self.request_handler(self.request, self.write_response))

    # -------------------------------------------- #
    # Responding
    # -------------------------------------------- #

    def write_response(self, response):
        # Bug fix: keep_alive was only bound inside the try block, so a
        # failure in should_keep_alive() raised NameError in the finally
        # clause.  Default to closing the connection.
        keep_alive = False
        try:
            keep_alive = (
                self.parser.should_keep_alive() and not self.signal.stopped)
            self.transport.write(
                response.output(
                    self.request.version, keep_alive, self.request_timeout))
        except RuntimeError:
            log.error(
                'Connection lost before response written @ {}'.format(
                    self.request.ip))
        except Exception as e:
            self.bail_out(
                "Writing response failed, connection closed {}".format(e))
        finally:
            if not keep_alive:
                self.transport.close()
            else:
                # Record that we received data and reset for the next
                # request on this keep-alive connection.
                self._last_request_time = current_time
                self.cleanup()

    def write_error(self, exception):
        # Render the exception through the configured error handler and
        # always close the connection afterwards.
        try:
            response = self.error_handler.response(self.request, exception)
            version = self.request.version if self.request else '1.1'
            self.transport.write(response.output(version))
        except RuntimeError:
            # Bug fix: guard self.request -- it may be None before headers
            # complete (matches the guard on the version line above).
            log.error(
                'Connection lost before error written @ {}'.format(
                    self.request.ip if self.request else 'Unknown'))
        except Exception as e:
            self.bail_out(
                "Writing error failed, connection closed {}".format(e),
                from_error=True)
        finally:
            self.transport.close()

    def bail_out(self, message, from_error=False):
        # Bug fix: use `or`, not `and`.  With `and`, a failure inside
        # write_error (from_error=True) on a still-open transport fell
        # through to write_error again, recursing indefinitely.
        if from_error or self.transport.is_closing():
            log.error(
                ("Transport closed @ {} and exception "
                 "experienced during error handling").format(
                    self.transport.get_extra_info('peername')))
            log.debug(
                'Exception:\n{}'.format(traceback.format_exc()))
        else:
            exception = ServerError(message)
            self.write_error(exception)
            log.error(message)

    def cleanup(self):
        # Reset per-request state so the connection can serve another
        # request.
        self.parser = None
        self.request = None
        self.url = None
        self.headers = None
        self._request_handler_task = None
        self._total_request_size = 0

    def close_if_idle(self):
        """
        Close the connection if a request is not being sent or received
        :return: boolean - True if closed, false if staying open
        """
        if not self.parser:
            self.transport.close()
            return True
        return False
class BaseServer(asyncio.Protocol):
    """Minimal HTTP server protocol.

    Parses incoming data with httptools' HttpRequestParser and dispatches
    each complete request to ``requesthandle`` as a task; supports
    keep-alive (disabled when ``toggle[0]`` is truthy).
    """

    def __init__(self, loop, requesthandle, toggle, connections=None,
                 request_timeout=10):
        self.loop = loop
        self.requesthandle = requesthandle
        self.toggle = toggle
        # Bug fix: the previous default (connections=set()) was evaluated
        # once at definition time and shared by every instance.
        self.connections = connections if connections is not None else set()
        self.request_timeout = request_timeout
        self.timehandle = None
        self.requesthandletask = None
        # Default per-request state (reset again by clean()).
        self.ip = None
        self.parse = None
        self.url = None
        self.headers = {}
        self.body = None
        self.httpversion = None
        self.method = None
        self.request = None
        self.contentlength = 0
        self.keep_alive = False

    #######################
    # Connection lifecycle

    def connection_made(self, transport):
        # Register this connection in the shared connection set.
        self.connections.add(self)
        self.transport = transport
        self.ip = transport.get_extra_info("peername")
        # self.timehandle = self.loop.call_later(self.request_timeout,
        #                                        self.teardown)

    def connection_lost(self, exc):
        self.connections.remove(self)
        # self.timehandle.cancel()
        self.clean()

    ########################
    # Parsing

    def data_received(self, data):
        self.contentlength += len(data)
        if not self.parse:
            # HttpRequestParser calls back into this object via:
            # - on_url(url: bytes)
            # - on_header(name: bytes, value: bytes)
            # - on_headers_complete()
            # - on_body(body: bytes)
            # - on_message_complete()
            # and exposes get_http_version() / should_keep_alive().
            self.parse = HttpRequestParser(self)
        try:
            self.parse.feed_data(data)
        except HttpParserError:
            # Bug fix: the original caught HttpRequestParser -- the parser
            # *class*, not an exception type -- so parse errors were never
            # handled.  Close the connection on unparsable input.
            self.transport.close()

    def on_url(self, url):
        self.url = url

    def on_header(self, name, value):
        self.headers[name] = value

    def on_headers_complete(self):
        self.method = self.parse.get_method()
        self.httpversion = self.parse.get_http_version()
        self.keep_alive = self.parse.should_keep_alive()
        self.request = Request(self.ip, self.url, self.headers, self.method,
                               self.httpversion,
                               self.request_timeout if self.keep_alive
                               else None)

    def on_body(self, body):
        self.body = body
        self.request.setbody(body)

    def on_message_complete(self):
        # Dispatch the complete request to the handler as a task.
        self.requesthandletask = self.loop.create_task(
            self.requesthandle(self.request, self.write))

    #########################
    # Responding

    def write(self, response):
        self.transport.write(response.make_response())
        # toggle[0] truthy means the server is shutting down: no keep-alive.
        keep_alive = self.keep_alive and not self.toggle[0]
        if not keep_alive:
            self.transport.close()
        else:
            self.clean()

    ########################
    # Cleanup / shutdown

    def clean(self):
        """Reset per-request state so a keep-alive connection can serve
        the next request."""
        # Bug fix: the original reset misspelled attributes ('requsettask',
        # 'methods'), so the real requesthandletask/method stayed stale.
        self.requesthandletask = None
        self.parse = None
        self.url = None
        self.headers = {}
        self.body = None
        self.httpversion = None
        self.method = None
        self.request = None
        self.contentlength = 0
        self.keep_alive = False

    def teardown(self):
        # Close the connection when idle; returns True if closed.
        if not self.parse:
            self.transport.close()
            return True
        return False
class HttpProtocol(asyncio.Protocol):
    """HTTP protocol driven by httptools.

    Configuration (loop, handlers, timeouts, connection set) arrives
    bundled in a ``panic_datatypes.ServerParams`` instance.
    """
    # http://book.pythontips.com/en/latest/__slots__magic.html

    def __init__(self, params: panic_datatypes.ServerParams):
        self.params = params
        self.loop = params.loop
        self.transport = None
        self.request = None
        self.parser = None
        self.url = None
        self.headers = None
        self.signal = params.signal
        self.connections = params.connections
        self.request_handler = params.request_handler
        self.request_timeout = params.request_timeout
        self._total_request_size = 0
        self._timeout_handler = None
        self._last_request_time = None
        self._request_handler_task = None
        self._identity = uuid.uuid4()

    # -------------------------------------------- #
    # Connection
    # -------------------------------------------- #

    def connection_made(self, transport):
        self.connections.add(self)
        self._timeout_handler = self.loop.call_later(
            self.request_timeout, self.connection_timeout)
        self.transport = transport
        self._last_request_time = datetime.datetime.utcnow()

    def connection_lost(self, exc):
        self.connections.discard(self)
        self._timeout_handler.cancel()
        self.cleanup()

    def connection_timeout(self):
        # Bug fix: the original put the timeout branch in a try/else, so
        # the `else` ran after every *successful* reschedule -- cancelling
        # the handler task and writing a RequestTimeout even when the
        # connection was still active.  The two branches are alternatives
        # of the idle check, not of an exception.
        time_elapsed = datetime.datetime.utcnow() - self._last_request_time
        if time_elapsed.seconds < self.request_timeout:
            # Not idle long enough yet: re-arm the timer for the remainder.
            time_left = self.request_timeout - time_elapsed.seconds
            self._timeout_handler = self.loop.call_later(
                time_left, self.connection_timeout)
        else:
            if self._request_handler_task:
                self._request_handler_task.cancel()
            exception = panic_exceptions.RequestTimeout('Request Timeout')
            self.write_error(exception)

    # -------------------------------------------- #
    # Parsing
    # -------------------------------------------- #

    def data_received(self, data):
        # Check for the request itself getting too large and exceeding
        # memory limits
        # TODO: enforce a maximum request size here.
        self._total_request_size += len(data)

        # Create parser if this is the first time we're receiving data
        if self.parser is None:
            assert self.request is None
            self.headers = panic_datatypes.HTTPHeaders()
            self.parser = HttpRequestParser(self)

        # Parse request chunk or close connection.  (Debug residue removed:
        # the original dropped into ipdb here.)
        try:
            self.parser.feed_data(data)
        except HttpParserError:
            exception = panic_exceptions.InvalidUsage('Bad Request')
            self.write_error(exception)

    def on_url(self, url):
        self.url = url

    def on_header(self, name, value):
        # TODO: reject oversized bodies via Content-Length once a maximum
        # request size is configured.
        self.headers.append(name.decode(), value.decode('utf-8'))

    def on_headers_complete(self):
        # Record the peer address alongside the parsed headers.
        remote_addr = self.transport.get_extra_info('peername')
        if remote_addr:
            self.headers.append(remote_addr[0], str(remote_addr[1]))

        self.request = panic_request.Request(
            url=self.url,
            headers=self.headers,
            version=self.parser.get_http_version(),
            method=panic_datatypes.HTTPMethod.Match(
                self.parser.get_method().decode())
        )

    def on_body(self, body):
        self.request.body.append(body)

    def on_message_complete(self):
        self._request_handler_task = self.loop.create_task(
            self.request_handler(self.request, self.write_response))

    # -------------------------------------------- #
    # Responding
    # -------------------------------------------- #

    def write_response(self, response):
        if self.parser:
            keep_alive = (self.parser.should_keep_alive()
                          and not self.signal.stopped)
        else:
            keep_alive = False
        try:
            self.transport.write(
                response.output(getattr(self.request, 'version', '1.1')))
        except RuntimeError as err:
            # Transport already closed by the peer.
            logger.error(err)
        except Exception as err:
            # Debug residue removed (ipdb); log instead of trapping.
            logger.error(
                'Writing response failed: {}'.format(repr(err)))
        if keep_alive:
            self._last_request_time = datetime.datetime.utcnow()
            self.cleanup()
        else:
            self.transport.close()

    def write_error(self, exception):
        try:
            response = self.params.error_handler(self.request, exception)
            version = self.request.version if self.request else '1.1'
            self.transport.write(response.output(float(version)))
            self.transport.close()
        except panic_exceptions.RequestTimeout:
            # The error handler itself timed out: report a 408 instead.
            exception = panic_exceptions.ServerError('RT')
            exception.status = 408
            response = self.params.error_handler(self.request, exception)
            version = self.request.version if self.request else '1.1'
            self.transport.write(response.output(float(version)))
            self.transport.close()
        except Exception as err:
            # Bug fixes: the original referenced an undefined name `e`,
            # and left ipdb/sys.exit debug residue here.  Do not recurse
            # into bail_out (which calls write_error again) -- log and
            # close instead.
            logger.error(
                "Writing error failed, connection closed {}".format(
                    repr(err)))
            self.transport.close()

    def bail_out(self, message):
        # Bug fix: ServerError was referenced unqualified; it lives in
        # panic_exceptions like the other error types used here.
        exception = panic_exceptions.ServerError(message)
        self.write_error(exception)
        logger.error(message)

    def cleanup(self):
        # Reset per-request state so the connection can serve another
        # request.
        self.parser = None
        self.request = None
        self.url = None
        self.headers = None
        self._request_handler_task = None

    def close_if_idle(self):
        """
        Close the connection if a request is not being sent or received
        :return: boolean - True if closed, false if staying open
        """
        if not self.parser:
            self.transport.close()
            return True
        return False
class HttpProtocol(asyncio.Protocol):
    """HTTP protocol: one instance per client connection.

    Drives httptools' ``HttpRequestParser`` through the ``on_*`` callbacks,
    assembles a :class:`Request`, dispatches it to ``request_handler`` and
    writes the response back over the transport.
    """

    # Fixed to match what __init__ actually assigns: added 'error_handler',
    # '_last_request_time' and '_request_handler_task'; dropped the unused
    # '_last_communication_time'.
    __slots__ = (
        # event loop, connection
        'loop', 'transport', 'connections', 'signal',
        # request params
        'parser', 'request', 'url', 'headers',
        # request config
        'request_handler', 'error_handler', 'request_timeout',
        'request_max_size',
        # connection management
        '_total_request_size', '_timeout_handler', '_last_request_time',
        '_request_handler_task')

    def __init__(self, *, loop, request_handler, error_handler,
                 signal=Signal(), connections=None, request_timeout=60,
                 request_max_size=None):
        self.loop = loop                          # event loop
        self.transport = None
        self.request = None                      # current Request
        self.parser = None                       # lazily-created parser
        self.url = None                          # raw request target
        self.headers = None                      # accumulated header pairs
        self.signal = signal                     # server shutdown flag
        # BUG FIX: the old default was a shared mutable dict ``{}``, but this
        # class uses the *set* API (add/discard) on it -- the default path
        # crashed with AttributeError.  None now means "private fresh set".
        self.connections = connections if connections is not None else set()
        self.request_handler = request_handler   # coroutine handling requests
        self.error_handler = error_handler       # renders error responses
        self.request_timeout = request_timeout   # idle timeout, seconds
        self.request_max_size = request_max_size  # None means unlimited
        self._total_request_size = 0
        self._timeout_handler = None
        self._last_request_time = None
        self._request_handler_task = None

    # -------------------------------------------- #
    # Connection
    # -------------------------------------------- #

    def connection_made(self, transport):
        """Register the connection and arm the idle-timeout timer."""
        self.connections.add(self)
        self._timeout_handler = self.loop.call_later(
            self.request_timeout, self.connection_timeout)
        self.transport = transport
        self._last_request_time = current_time

    def connection_lost(self, exc):
        """Unregister the connection and cancel its pending timer."""
        self.connections.discard(self)
        self._timeout_handler.cancel()
        self.cleanup()

    def connection_timeout(self):
        """Timer callback: re-arm if there was recent activity, otherwise
        cancel the in-flight handler and answer with 408."""
        # Time since the last completed request on this connection.
        time_elapsed = current_time - self._last_request_time
        if time_elapsed < self.request_timeout:
            # Activity happened after the timer was armed: sleep the rest.
            time_left = self.request_timeout - time_elapsed
            self._timeout_handler = \
                self.loop.call_later(time_left, self.connection_timeout)
        else:
            if self._request_handler_task:
                self._request_handler_task.cancel()
            exception = RequestTimeout('Request Timeout')
            self.write_error(exception)

    # -------------------------------------------- #
    # Parsing
    # -------------------------------------------- #

    def data_received(self, data):
        """Feed one raw chunk from the socket into the HTTP parser."""
        self._total_request_size += len(data)
        # BUG FIX: request_max_size defaults to None, and ``int > None``
        # raises TypeError -- treat None as "no limit".
        if (self.request_max_size is not None
                and self._total_request_size > self.request_max_size):
            exception = PayloadTooLarge('Payload Too Large')
            self.write_error(exception)

        # First chunk of a new request: create the parser lazily.
        if self.parser is None:
            assert self.request is None
            self.headers = []
            self.parser = HttpRequestParser(self)

        # Parse the chunk; malformed input is answered with a 400.
        try:
            self.parser.feed_data(data)
        except HttpParserError:
            exception = InvalidUsage('Bad Request')
            self.write_error(exception)

    def on_url(self, url):
        """httptools callback: capture the raw URL bytes."""
        self.url = url

    def on_header(self, name, value):
        """httptools callback: record one header, enforcing the declared
        Content-Length against the configured maximum."""
        # Same None-guard as data_received (None means no limit configured).
        if (name == b'Content-Length' and self.request_max_size is not None
                and int(value) > self.request_max_size):
            exception = PayloadTooLarge('Payload Too Large')
            self.write_error(exception)
        self.headers.append((name.decode(), value.decode('utf-8')))

    def on_headers_complete(self):
        """httptools callback: all headers parsed -- build the Request."""
        # Expose the peer address as a pseudo-header.
        remote_addr = self.transport.get_extra_info('peername')
        if remote_addr:
            self.headers.append(('Remote-Addr', '%s:%s' % remote_addr))

        self.request = Request(
            url_bytes=self.url,
            headers=CIMultiDict(self.headers),
            version=self.parser.get_http_version(),
            method=self.parser.get_method().decode()
        )

    def on_body(self, body):
        """httptools callback: append one request-body chunk."""
        if self.request.body:
            self.request.body += body
        else:
            self.request.body = body

    def on_message_complete(self):
        """httptools callback: request fully parsed -- dispatch the handler
        coroutine with ``write_response`` as its completion callback."""
        self._request_handler_task = self.loop.create_task(
            self.request_handler(self.request, self.write_response))

    # -------------------------------------------- #
    # Responding
    # -------------------------------------------- #

    def write_response(self, response):
        """Serialize *response*; keep the connection alive or close it."""
        try:
            keep_alive = self.parser.should_keep_alive() \
                and not self.signal.stopped
            self.transport.write(
                response.output(
                    self.request.version, keep_alive, self.request_timeout))
            if not keep_alive:
                self.transport.close()
            else:
                # Record activity so connection_timeout re-arms the timer.
                self._last_request_time = current_time
                self.cleanup()
        except Exception as e:
            self.bail_out(
                "Writing response failed, connection closed {}".format(e))

    def write_error(self, exception):
        """Render *exception* through the error handler and close."""
        try:
            response = self.error_handler.response(self.request, exception)
            version = self.request.version if self.request else '1.1'
            self.transport.write(response.output(version))
            self.transport.close()
        except Exception as e:
            self.bail_out(
                "Writing error failed, connection closed {}".format(e))

    def bail_out(self, message):
        """Last-resort failure path: emit a 500 and log the cause."""
        exception = ServerError(message)
        self.write_error(exception)
        log.error(message)

    def cleanup(self):
        """Reset all per-request state so the connection can be reused."""
        self.parser = None
        self.request = None
        self.url = None
        self.headers = None
        self._request_handler_task = None
        self._total_request_size = 0

    def close_if_idle(self):
        """Close the connection if a request is not being sent or received.

        :return: boolean - True if closed, false if staying open
        """
        if not self.parser:
            self.transport.close()
            return True
        return False