class ForwardConnection(object):
    """Replay a captured request to a remote host and relay the reply back.

    On construction an outbound IOStream to ``remote_address`` is opened.
    Once it connects, the stored ``headers`` bytes are written to it.
    Presumably these are the raw request head; confirm with the caller.
    The complete remote response is then buffered with ``read_until_close``
    and written to the client ``stream``, which is closed once that write
    has been flushed.
    """

    def __init__(self, remote_address, stream, address, headers):
        self.remote_address = remote_address
        self.stream = stream
        self.address = address
        self.headers = headers
        self.remote_stream = IOStream(socket.socket())
        self.remote_stream.connect(self.remote_address, self._on_remote_connected)
        self.remote_stream.set_close_callback(self._on_close)

    def _on_remote_connected(self):
        """Log the forward and push the stored request bytes upstream."""
        logging.info('forward %r to %r', self.address, self.remote_address)
        self.remote_stream.write(self.headers, self._on_remote_write_complete)

    def _on_remote_write_complete(self):
        """Request fully sent: buffer the whole remote reply until it closes."""
        logging.info('send request to %s', self.remote_address)
        self.remote_stream.read_until_close(self._on_remote_read_close)

    def _on_remote_read_close(self, data):
        """Hand the buffered reply to the client, closing it after the write."""
        self.stream.write(data, self.stream.close)

    def _on_close(self):
        """Close callback of the outbound stream: log it and close it."""
        logging.info('remote quit %s', self.remote_address)
        self.remote_stream.close()
class ForwardConnection(object):
    """Bidirectional relay between a client stream and a fresh remote connection.

    After the outbound connection to ``remote_address`` is established, each
    side is read with ``read_until_close`` and every chunk is streamed straight
    into the other side's ``write``. When one side reaches EOF, the opposite
    stream is closed. If that stream still has buffered output, it is closed
    only after that output has been flushed.
    """

    def __init__(self, remote_address, stream, address):
        self.remote_address = remote_address
        self.stream = stream
        self.address = address
        outbound = IOStream(socket.socket())
        self.remote_stream = outbound
        outbound.connect(self.remote_address, self._on_remote_connected)

    def _on_remote_connected(self):
        """Start pumping bytes in both directions."""
        logging.info("forward %r to %r", self.address, self.remote_address)
        self.remote_stream.read_until_close(self._on_remote_read_close, self.stream.write)
        self.stream.read_until_close(self._on_read_close, self.remote_stream.write)

    @staticmethod
    def _drain_and_close(target, data):
        # Close immediately when idle; otherwise let the pending write finish first.
        if not target.writing():
            target.close()
            return
        target.write(data, target.close)

    def _on_remote_read_close(self, data):
        """Remote hit EOF: shut down the client side."""
        self._drain_and_close(self.stream, data)

    def _on_read_close(self, data):
        """Client hit EOF: shut down the remote side."""
        self._drain_and_close(self.remote_stream, data)
def test_read_until_close(self):
    """A raw HTTP/1.0 GET read with read_until_close yields the full response."""
    raw_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
    raw_sock.connect(("localhost", self.get_http_port()))
    client = IOStream(raw_sock, io_loop=self.io_loop)
    client.write(b("GET / HTTP/1.0\r\n\r\n"))
    client.read_until_close(self.stop)
    response = self.wait()
    # Status line at the front, handler output at the very end.
    self.assertTrue(response.startswith(b("HTTP/1.0 200")))
    self.assertTrue(response.endswith(b("Hello")))
class RemoteUpstream(Upstream):
    """Upstream that opens an outbound TCP connection to ``self.dest``.

    Most methods are the same as in LocalUpstream, but they are kept separate
    because the two may need to diverge in the future.
    """

    def initialize(self):
        """Create the socket (family ``self._address_type``) and its IOStream."""
        self.socket = socket.socket(self._address_type, socket.SOCK_STREAM)
        self.stream = IOStream(self.socket)
        self.stream.set_close_callback(self.on_close)

    def do_connect(self):
        """Start connecting to ``self.dest``; ``on_connect`` runs on success."""
        self.stream.connect(self.dest, self.on_connect)

    @property
    def address(self):
        """Local address of the outbound socket, as returned by getsockname()."""
        return self.socket.getsockname()

    @property
    def address_type(self):
        return self._address_type

    def on_connect(self):
        """Notify the owner, then stream incoming data until the remote closes."""
        self.connection_callback(self)
        on_finish = functools.partial(self.on_streaming_data, finished=True)
        self.stream.read_until_close(on_finish, self.on_streaming_data)

    def on_close(self):
        """Report the stream's error if it has one, otherwise a clean close."""
        if self.stream.error:
            self.error_callback(self, self.stream.error)
        else:
            self.close_callback(self)

    def on_streaming_data(self, data, finished=False):
        """Forward non-empty chunks to ``streaming_callback``."""
        if data:
            self.streaming_callback(self, data)

    def do_write(self, data):
        """Write ``data``; an IOError while writing closes this upstream."""
        try:
            self.stream.write(data)
        except IOError:
            self.close()

    def do_close(self):
        """Log the local endpoint and close the stream."""
        if self.socket:
            # getsockname() returns (host, port) for AF_INET but a 4-tuple for
            # AF_INET6. Formatting the whole tuple with "%s:%s" raised TypeError
            # for IPv6 and skipped the close, so only host and port are logged.
            logger.info("close upstream: %s:%s", *self.address[:2])
            self.stream.close()
class UnixSocketTest(AsyncTestCase):
    """Serve HTTP over a Unix domain socket and speak to it with a raw IOStream.

    Proxies such as Nginx can forward to backends bound to Unix sockets, and a
    filesystem namespace of sockets can be simpler to manage than a collection
    of TCP port numbers. HTTP client URLs have no way to name a Unix socket,
    so these tests write the request bytes by hand.
    """

    def setUp(self):
        super(UnixSocketTest, self).setUp()
        self.tmpdir = tempfile.mkdtemp()
        self.sockfile = os.path.join(self.tmpdir, "test.sock")
        listener = netutil.bind_unix_socket(self.sockfile)
        self.server = HTTPServer(Application([("/hello", HelloWorldRequestHandler)]))
        self.server.add_socket(listener)
        self.stream = IOStream(socket.socket(socket.AF_UNIX))
        self.stream.connect(self.sockfile, self.stop)
        self.wait()

    def tearDown(self):
        self.stream.close()
        self.io_loop.run_sync(self.server.close_all_connections)
        self.server.stop()
        shutil.rmtree(self.tmpdir)
        super(UnixSocketTest, self).tearDown()

    def test_unix_socket(self):
        """GET /hello over the Unix socket returns 200 and the handler's body."""
        self.stream.write(b"GET /hello HTTP/1.0\r\n\r\n")
        self.stream.read_until(b"\r\n", self.stop)
        status_line = self.wait()
        self.assertEqual(status_line, b"HTTP/1.1 200 OK\r\n")
        self.stream.read_until(b"\r\n\r\n", self.stop)
        raw_headers = self.wait()
        parsed = HTTPHeaders.parse(raw_headers.decode('latin1'))
        self.stream.read_bytes(int(parsed["Content-Length"]), self.stop)
        body = self.wait()
        self.assertEqual(body, b"Hello world")

    def test_unix_socket_bad_request(self):
        """Garbage input gets a bare 400 response and a malformed-message log."""
        # The expected log text stops after "from": a Unix-socket peer has no
        # remote address, so it is reported as an empty string.
        with ExpectLog(gen_log, "Malformed HTTP message from"):
            self.stream.write(b"garbage\r\n\r\n")
            self.stream.read_until_close(self.stop)
            reply = self.wait()
        self.assertEqual(reply, b"HTTP/1.1 400 Bad Request\r\n\r\n")
class ForwardConnection(object):
    """Bidirectional TCP relay between an accepted client stream and a remote.

    The remote endpoint is looked up in ``server.conf`` using the local
    (listening) address the client connected to. Once the outbound connection
    is up, bytes are streamed in both directions with ``read_until_close``.
    When both sides are closed, the callback registered with
    ``set_close_callback`` is invoked with this connection.
    """

    def __init__(self, server, stream, address):
        self._close_callback = None
        self.server = server
        self.stream = stream
        self.reverse_address = address
        self.address = stream.socket.getsockname()
        self.remote_address = server.conf[self.address]
        sock = socket.socket()
        self.remote_stream = IOStream(sock)
        # Until the connect succeeds, nothing else watches the remote stream.
        # Without this hook, a failed connect (or close() before connecting)
        # would leave the client stream open forever and never fire the close
        # callback. It is installed before connect() so synchronous failures
        # are caught too, and _on_remote_connected removes it.
        self.remote_stream.set_close_callback(self._on_remote_connect_failed)
        self.remote_stream.connect(self.remote_address, self._on_remote_connected)

    def close(self):
        """Close the outbound stream; the client side is then shut down as well."""
        self.remote_stream.close()

    def set_close_callback(self, callback):
        """Register ``callback(connection)`` to run once the relay has ended."""
        self._close_callback = callback

    def _on_remote_connected(self):
        # From here on, read_until_close drives shutdown; drop the failure hook.
        self.remote_stream.set_close_callback(None)
        ip_from = self.reverse_address[0]
        fwd_str = get_forwarding_str(self.address[0], self.address[1],
                                     self.remote_address[0], self.remote_address[1])
        logging.info('Connected ip: %s, forward %s', ip_from, fwd_str)
        self.remote_stream.read_until_close(self._on_remote_read_close, self.stream.write)
        self.stream.read_until_close(self._on_read_close, self.remote_stream.write)

    def _on_remote_connect_failed(self):
        """Remote closed before connecting: close the client and finish up."""
        logging.warning('Remote %r closed before connecting, ip: %s',
                        self.remote_address, self.reverse_address[0])
        if not self.stream.closed():
            self.stream.close()
        self._on_closed()

    def _on_remote_read_close(self, data):
        """Remote hit EOF: flush then close the client, or finish if already closed."""
        if self.stream.writing():
            self.stream.write(data, self.stream.close)
        elif self.stream.closed():
            self._on_closed()
        else:
            self.stream.close()

    def _on_read_close(self, data):
        """Client hit EOF: flush then close the remote, or finish if already closed."""
        if self.remote_stream.writing():
            self.remote_stream.write(data, self.remote_stream.close)
        elif self.remote_stream.closed():
            self._on_closed()
        else:
            self.remote_stream.close()

    def _on_closed(self):
        logging.info('Disconnected ip: %s', self.reverse_address[0])
        if self._close_callback:
            self._close_callback(self)
def test_handle_stream_coroutine_logging(self):
    """An exception escaping a coroutine handle_stream ends up in the app log."""
    class TestServer(TCPServer):
        @gen.coroutine
        def handle_stream(self, stream, address):
            yield gen.moment
            stream.close()
            1 / 0  # deliberate failure inside the coroutine's Future

    server, client = None, None
    try:
        listener, port = bind_unused_port()
        with NullContext():
            server = TestServer()
            server.add_socket(listener)
        client = IOStream(socket.socket())
        with ExpectLog(app_log, "Exception in callback"):
            yield client.connect(('localhost', port))
            yield client.read_until_close()
            yield gen.moment
    finally:
        if server is not None:
            server.stop()
        if client is not None:
            client.close()
def test_timeout(self):
    """A PUT whose body never arrives trips the server's body timeout."""
    raw = IOStream(socket.socket())
    try:
        yield raw.connect(("127.0.0.1", self.get_http_port()))
        # AsyncHTTPClient can't read a response before the body is sent,
        # so the request is written by hand over a raw stream.
        request_head = (b"PUT /streaming?body_timeout=0.1 HTTP/1.0\r\n"
                        b"Content-Length: 42\r\n\r\n")
        raw.write(request_head)
        with ExpectLog(gen_log, "Timeout reading body"):
            reply = yield raw.read_until_close()
        self.assertEqual(reply, b"")
    finally:
        raw.close()
def test_timeout(self):
    """A PUT whose body never arrives trips the server's body timeout."""
    raw = IOStream(socket.socket())
    try:
        yield raw.connect(('127.0.0.1', self.get_http_port()))
        # AsyncHTTPClient can't read a response before the body is sent,
        # so the request is written by hand over a raw stream.
        request_head = (b'PUT /streaming?body_timeout=0.1 HTTP/1.0\r\n'
                        b'Content-Length: 42\r\n\r\n')
        raw.write(request_head)
        with ExpectLog(gen_log, 'Timeout reading body'):
            reply = yield raw.read_until_close()
        self.assertEqual(reply, b'')
    finally:
        raw.close()
class Client(Session):
    """Outbound Session that dials ``address`` and can keep retrying.

    Each attempt is bounded by a "connect" timer of ``connect_time`` seconds.
    If that timer fires, or the stream closes while still connecting, a
    "reconnect" timer re-dials after ``reconnect_time`` seconds, provided
    ``auto_reconnect`` is set.
    """

    def __init__(self, protocol, io_loop=None):
        Session.__init__(self, protocol, io_loop)
        self.auto_reconnect = True
        self.reconnect_time = 5
        self.connect_time = 5

    def connect(self, address):
        """Open a fresh TCP IOStream and start connecting to ``address``."""
        self.status = SessionStream.CONNECTING
        tcp_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        self.ios = IOStream(tcp_sock, self.io_loop)
        self.ios.set_close_callback(self._ios_closed)
        self.address = address
        self.clear_buffer()
        self.add_timer(self._connect_timeout, self.connect_time, "connect")
        self.ios.connect(address, self._connected)

    def _connected(self):
        """Mark connected, notify the protocol and start receiving."""
        self.status = SessionStream.CONNECTED
        self.protocol.connected(self)
        self.remove_timer("connect")
        self.ios.read_until_close(self._disconnected, self._receiver)

    def _connect_timeout(self):
        """Abort the attempt; status goes IDLE first so _ios_closed does nothing."""
        self.remove_timer("connect")
        self.status = SessionStream.IDLE
        self.ios.close()
        if self.auto_reconnect:
            self.add_timer(self._do_reconnect, self.reconnect_time, "reconnect")

    def _ios_closed(self):
        """Stream closed: only a close during CONNECTING schedules a retry here."""
        if self.status != SessionStream.CONNECTING:
            return
        self.remove_timer("connect")
        if self.auto_reconnect:
            self.add_timer(self._do_reconnect, self.reconnect_time, "reconnect")

    def _do_reconnect(self):
        self.remove_timer("reconnect")
        self.connect(self.address)
class EchoClient(object):
    """Asynchronous client for EchoServer, built on a single IOStream.

    ``send_message`` connects to ``address``, then streams every received
    chunk to the caller's handler. ``is_closed`` flips to True once the
    server closes the connection.
    """

    def __init__(self, address, family=socket.AF_INET, socktype=socket.SOCK_STREAM):
        raw_socket = socket.socket(family, socktype, 0)
        self.io_stream = IOStream(raw_socket)
        self.address = address
        self.is_closed = False

    def handle_close(self, data):
        """Final read callback: the server closed the connection."""
        self.is_closed = True

    def send_message(self, message, handle_response):
        """Connect, then send ``message`` and stream replies to ``handle_response``."""
        def on_connected():
            self.io_stream.read_until_close(self.handle_close, handle_response)
            self.write(message)

        self.io_stream.connect(self.address, on_connected)

    def write(self, message):
        """Write ``message``; text is UTF-8 encoded, bytes are sent as-is."""
        payload = message if isinstance(message, bytes) else message.encode("UTF-8")
        self.io_stream.write(payload)
class UnixSocketTest(AsyncTestCase):
    """Serve HTTP over a Unix domain socket and speak to it with a raw IOStream.

    Proxies such as Nginx can forward to backends bound to Unix sockets, and a
    filesystem namespace of sockets can be simpler to manage than a collection
    of TCP port numbers. HTTP client URLs have no way to name a Unix socket,
    so these tests write the request bytes by hand.
    """

    def setUp(self):
        super().setUp()
        self.tmpdir = tempfile.mkdtemp()
        self.sockfile = os.path.join(self.tmpdir, "test.sock")
        listener = netutil.bind_unix_socket(self.sockfile)
        self.server = HTTPServer(Application([("/hello", HelloWorldRequestHandler)]))
        self.server.add_socket(listener)
        self.stream = IOStream(socket.socket(socket.AF_UNIX))
        self.io_loop.run_sync(lambda: self.stream.connect(self.sockfile))

    def tearDown(self):
        self.stream.close()
        self.io_loop.run_sync(self.server.close_all_connections)
        self.server.stop()
        shutil.rmtree(self.tmpdir)
        super().tearDown()

    @gen_test
    def test_unix_socket(self):
        """GET /hello over the Unix socket returns 200 and the handler's body."""
        self.stream.write(b"GET /hello HTTP/1.0\r\n\r\n")
        status_line = yield self.stream.read_until(b"\r\n")
        self.assertEqual(status_line, b"HTTP/1.1 200 OK\r\n")
        raw_headers = yield self.stream.read_until(b"\r\n\r\n")
        parsed = HTTPHeaders.parse(raw_headers.decode("latin1"))
        body_len = int(parsed["Content-Length"])
        body = yield self.stream.read_bytes(body_len)
        self.assertEqual(body, b"Hello world")

    @gen_test
    def test_unix_socket_bad_request(self):
        """Garbage input gets a bare 400 response and a malformed-message log."""
        # The expected log text stops after "from": a Unix-socket peer has no
        # remote address, so it is reported as an empty string.
        with ExpectLog(gen_log, "Malformed HTTP message from"):
            self.stream.write(b"garbage\r\n\r\n")
            reply = yield self.stream.read_until_close()
        self.assertEqual(reply, b"HTTP/1.1 400 Bad Request\r\n\r\n")
def test_handle_stream_native_coroutine(self):
    """handle_stream may be a native (async def) coroutine."""
    class TestServer(TCPServer):
        async def handle_stream(self, stream, address):
            stream.write(b"data")
            stream.close()

    # Clean up in finally (as the coroutine-logging test does) so a failed
    # read or assertion doesn't leak the listening socket or client stream.
    server = client = None
    try:
        sock, port = bind_unused_port()
        server = TestServer()
        server.add_socket(sock)
        client = IOStream(socket.socket())
        yield client.connect(("localhost", port))
        result = yield client.read_until_close()
        self.assertEqual(result, b"data")
    finally:
        if server is not None:
            server.stop()
        if client is not None:
            client.close()
def test_body_size_override_reset(self):
    """A per-request max_body_size override must not leak into the next request."""
    conn = IOStream(socket.socket())
    try:
        yield conn.connect(("127.0.0.1", self.get_http_port()))
        # Both requests go over this one raw stream, so they share a connection.
        first_head = (b"PUT /streaming?expected_size=10240 HTTP/1.1\r\n"
                      b"Content-Length: 10240\r\n\r\n")
        conn.write(first_head)
        conn.write(b"a" * 10240)
        _, echoed = yield gen.Task(read_stream_body, conn)
        self.assertEqual(echoed, b"10240")
        # The second request has no ?expected_size, so the default limit applies again.
        second_head = (b"PUT /streaming HTTP/1.1\r\n"
                       b"Content-Length: 10240\r\n\r\n")
        conn.write(second_head)
        with ExpectLog(gen_log, ".*Content-Length too long"):
            leftover = yield conn.read_until_close()
        self.assertEqual(leftover, b"")
    finally:
        conn.close()
def test_handle_stream_native_coroutine(self):
    """handle_stream may be a native (async def) coroutine."""
    class TestServer(TCPServer):
        async def handle_stream(self, stream, address):
            stream.write(b"data")
            stream.close()

    # Clean up in finally so a failed read or assertion doesn't leak sockets.
    server = client = None
    try:
        sock, port = bind_unused_port()
        server = TestServer()
        server.add_socket(sock)
        client = IOStream(socket.socket())
        # bind_unused_port listens on the loopback interface; the previous
        # hard-coded "10.0.0.7" could never reach this server.
        yield client.connect(("localhost", port))
        result = yield client.read_until_close()
        self.assertEqual(result, b"data")
    finally:
        if server is not None:
            server.stop()
        if client is not None:
            client.close()
def test_body_size_override_reset(self):
    """A per-request max_body_size override must not leak into the next request."""
    conn = IOStream(socket.socket())
    try:
        yield conn.connect(('127.0.0.1', self.get_http_port()))
        # Both requests go over this one raw stream, so they share a connection.
        first_head = (b'PUT /streaming?expected_size=10240 HTTP/1.1\r\n'
                      b'Content-Length: 10240\r\n\r\n')
        conn.write(first_head)
        conn.write(b'a' * 10240)
        _, echoed = yield gen.Task(read_stream_body, conn)
        self.assertEqual(echoed, b'10240')
        # The second request has no ?expected_size, so the default limit applies again.
        second_head = (b'PUT /streaming HTTP/1.1\r\n'
                       b'Content-Length: 10240\r\n\r\n')
        conn.write(second_head)
        with ExpectLog(gen_log, '.*Content-Length too long'):
            leftover = yield conn.read_until_close()
        self.assertEqual(leftover, b'')
    finally:
        conn.close()
def test_body_size_override_reset(self):
    """The max_body_size override is reset between requests."""
    stream = IOStream(socket.socket())
    try:
        # The test HTTP server listens on loopback; "10.0.0.7" never reached it.
        yield stream.connect(("127.0.0.1", self.get_http_port()))
        # Use a raw stream so we can make sure it's all on one connection.
        stream.write(b"PUT /streaming?expected_size=10240 HTTP/1.1\r\n"
                     b"Content-Length: 10240\r\n\r\n")
        stream.write(b"a" * 10240)
        start_line, headers, response = yield read_stream_body(stream)
        self.assertEqual(response, b"10240")
        # Without the ?expected_size parameter, we get the old default value
        stream.write(b"PUT /streaming HTTP/1.1\r\n"
                     b"Content-Length: 10240\r\n\r\n")
        with ExpectLog(gen_log, ".*Content-Length too long"):
            data = yield stream.read_until_close()
        self.assertEqual(data, b"HTTP/1.1 400 Bad Request\r\n\r\n")
    finally:
        stream.close()
def test_handle_stream_native_coroutine(self):
    """handle_stream may be a native (async def) coroutine."""
    # The class is compiled via exec_test so this module still parses on
    # interpreters without async/await syntax.
    namespace = exec_test(globals(), locals(), """
    class TestServer(TCPServer):
        async def handle_stream(self, stream, address):
            stream.write(b'data')
            stream.close()
    """)
    # Clean up in finally so a failed read or assertion doesn't leak sockets.
    server = client = None
    try:
        sock, port = bind_unused_port()
        server = namespace['TestServer']()
        server.add_socket(sock)
        client = IOStream(socket.socket())
        yield client.connect(('localhost', port))
        result = yield client.read_until_close()
        self.assertEqual(result, b'data')
    finally:
        if server is not None:
            server.stop()
        if client is not None:
            client.close()
def test_body_size_override_reset(self):
    """A per-request max_body_size override must not leak into the next request."""
    conn = IOStream(socket.socket())
    try:
        yield conn.connect(('127.0.0.1', self.get_http_port()))
        # Both requests go over this one raw stream, so they share a connection.
        first_head = (b'PUT /streaming?expected_size=10240 HTTP/1.1\r\n'
                      b'Content-Length: 10240\r\n\r\n')
        conn.write(first_head)
        conn.write(b'a' * 10240)
        body_done = Future()
        read_stream_body(conn, callback=body_done.set_result)
        _, _, echoed = yield body_done
        self.assertEqual(echoed, b'10240')
        # The second request has no ?expected_size, so the default limit applies again.
        second_head = (b'PUT /streaming HTTP/1.1\r\n'
                       b'Content-Length: 10240\r\n\r\n')
        conn.write(second_head)
        with ExpectLog(gen_log, '.*Content-Length too long'):
            leftover = yield conn.read_until_close()
        self.assertEqual(leftover, b'HTTP/1.1 400 Bad Request\r\n\r\n')
    finally:
        conn.close()
class _HTTPConnection(object):
    """One HTTP/1.1 request/response exchange over an IOStream.

    Parses ``request.url`` and resolves the host through ``resolver``. Opens a
    plain IOStream or an SSLIOStream, writes the request line, headers and
    body, then parses the response. Bodies may be delimited by
    Content-Length, chunked transfer-encoding, or connection close. Redirects
    can be followed by re-issuing the request through ``client.fetch``. The
    finished ``HTTPResponse`` is scheduled onto ``final_callback`` via
    ``io_loop.add_callback``, and ``release_callback`` is called at most once
    (see ``_release``).
    """
    _SUPPORTED_METHODS = set(
        ["GET", "HEAD", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"])

    def __init__(self, io_loop, client, request, release_callback,
                 final_callback, max_buffer_size, resolver):
        """Parse the request URL and start resolving its host.

        Exceptions raised inside the ExceptionStackContext below (and,
        presumably, in callbacks scheduled from it) are routed to
        ``_handle_exception``.

        Raises:
            ValueError: if the URL scheme is not http or https.
        """
        self.start_time = io_loop.time()
        self.io_loop = io_loop
        self.client = client
        self.request = request
        self.release_callback = release_callback
        self.final_callback = final_callback
        self.max_buffer_size = max_buffer_size
        self.resolver = resolver
        self.code = None
        self.headers = None
        self.chunks = None
        self._decompressor = None
        # Timeout handle returned by IOLoop.add_timeout
        self._timeout = None
        with stack_context.ExceptionStackContext(self._handle_exception):
            self.parsed = urlparse.urlsplit(_unicode(self.request.url))
            if self.parsed.scheme not in ("http", "https"):
                raise ValueError("Unsupported url scheme: %s" %
                                 self.request.url)
            # urlsplit results have hostname and port results, but they
            # didn't support ipv6 literals until python 2.7.
            netloc = self.parsed.netloc
            if "@" in netloc:
                # Drop any "user:pass@" prefix; credentials are read from
                # self.parsed in _on_connect.
                userpass, _, netloc = netloc.rpartition("@")
            match = re.match(r'^(.+):(\d+)$', netloc)
            if match:
                host = match.group(1)
                port = int(match.group(2))
            else:
                host = netloc
                port = 443 if self.parsed.scheme == "https" else 80
            if re.match(r'^\[.*\]$', host):
                # raw ipv6 addresses in urls are enclosed in brackets
                host = host[1:-1]
            self.parsed_hostname = host  # save final host for _on_connect

            if request.allow_ipv6:
                af = socket.AF_UNSPEC
            else:
                # We only try the first IP we get from getaddrinfo,
                # so restrict to ipv4 by default.
                af = socket.AF_INET

            self.resolver.resolve(host, port, af, callback=self._on_resolve)

    def _on_resolve(self, addrinfo):
        """Open a (possibly SSL) stream to the first resolved address and connect.

        ``addrinfo[0]`` is unpacked as ``(family, sockaddr)``, so the resolver
        is assumed to return a list of such pairs (verify against the resolver).
        A single timer of min(connect_timeout, request_timeout) covers the
        connect phase.
        """
        af, sockaddr = addrinfo[0]

        if self.parsed.scheme == "https":
            ssl_options = {}
            if self.request.validate_cert:
                ssl_options["cert_reqs"] = ssl.CERT_REQUIRED
            if self.request.ca_certs is not None:
                ssl_options["ca_certs"] = self.request.ca_certs
            else:
                ssl_options["ca_certs"] = _DEFAULT_CA_CERTS
            if self.request.client_key is not None:
                ssl_options["keyfile"] = self.request.client_key
            if self.request.client_cert is not None:
                ssl_options["certfile"] = self.request.client_cert

            # SSL interoperability is tricky.  We want to disable
            # SSLv2 for security reasons; it wasn't disabled by default
            # until openssl 1.0.  The best way to do this is to use
            # the SSL_OP_NO_SSLv2, but that wasn't exposed to python
            # until 3.2.  Python 2.7 adds the ciphers argument, which
            # can also be used to disable SSLv2.  As a last resort
            # on python 2.6, we set ssl_version to SSLv3.  This is
            # more narrow than we'd like since it also breaks
            # compatibility with servers configured for TLSv1 only,
            # but nearly all servers support SSLv3:
            # http://blog.ivanristic.com/2011/09/ssl-survey-protocol-support.html
            if sys.version_info >= (2, 7):
                ssl_options["ciphers"] = "DEFAULT:!SSLv2"
            else:
                # This is really only necessary for pre-1.0 versions
                # of openssl, but python 2.6 doesn't expose version
                # information.
                ssl_options["ssl_version"] = ssl.PROTOCOL_SSLv3

            self.stream = SSLIOStream(socket.socket(af),
                                      io_loop=self.io_loop,
                                      ssl_options=ssl_options,
                                      max_buffer_size=self.max_buffer_size)
        else:
            self.stream = IOStream(socket.socket(af),
                                   io_loop=self.io_loop,
                                   max_buffer_size=self.max_buffer_size)
        timeout = min(self.request.connect_timeout, self.request.request_timeout)
        if timeout:
            self._timeout = self.io_loop.add_timeout(
                self.start_time + timeout,
                stack_context.wrap(self._on_timeout))
        self.stream.set_close_callback(self._on_close)
        # ipv6 addresses are broken (in self.parsed.hostname) until
        # 2.7, here is correctly parsed value calculated in __init__
        self.stream.connect(sockaddr, self._on_connect,
                            server_hostname=self.parsed_hostname)

    def _on_timeout(self):
        """Timer fired: raise a 599 error unless the response was already delivered.

        The raise is presumably routed to ``_handle_exception`` through the
        stack context captured by ``stack_context.wrap``; verify.
        """
        self._timeout = None
        if self.final_callback is not None:
            raise HTTPError(599, "Timeout")

    def _remove_timeout(self):
        """Cancel the pending timeout, if any."""
        if self._timeout is not None:
            self.io_loop.remove_timeout(self._timeout)
            self._timeout = None

    def _on_connect(self):
        """Connected: re-arm the request timeout, then send request and body.

        Fills in default headers (Connection, Host, Authorization, User-Agent,
        Content-Length, Content-Type, Accept-Encoding) on ``self.request`` and
        then starts reading the response header block.

        Raises:
            KeyError: for a non-standard method unless allow_nonstandard_methods.
            NotImplementedError: if a proxy/network_interface option is set.
            ValueError: if a header line would contain a newline.
        """
        self._remove_timeout()
        if self.request.request_timeout:
            self._timeout = self.io_loop.add_timeout(
                self.start_time + self.request.request_timeout,
                stack_context.wrap(self._on_timeout))
        if (self.request.method not in self._SUPPORTED_METHODS and
                not self.request.allow_nonstandard_methods):
            raise KeyError("unknown method %s" % self.request.method)
        for key in ('network_interface',
                    'proxy_host', 'proxy_port',
                    'proxy_username', 'proxy_password'):
            if getattr(self.request, key, None):
                raise NotImplementedError('%s not supported' % key)
        if "Connection" not in self.request.headers:
            self.request.headers["Connection"] = "close"
        if "Host" not in self.request.headers:
            if '@' in self.parsed.netloc:
                # Strip "user:pass@" so credentials never go out in Host.
                self.request.headers["Host"] = self.parsed.netloc.rpartition('@')[-1]
            else:
                self.request.headers["Host"] = self.parsed.netloc
        username, password = None, None
        # Credentials embedded in the URL take precedence over auth_username.
        if self.parsed.username is not None:
            username, password = self.parsed.username, self.parsed.password
        elif self.request.auth_username is not None:
            username = self.request.auth_username
            password = self.request.auth_password or ''
        if username is not None:
            auth = utf8(username) + b":" + utf8(password)
            self.request.headers["Authorization"] = (b"Basic " +
                                                     base64.b64encode(auth))
        if self.request.user_agent:
            self.request.headers["User-Agent"] = self.request.user_agent
        if not self.request.allow_nonstandard_methods:
            if self.request.method in ("POST", "PATCH", "PUT"):
                assert self.request.body is not None
            else:
                assert self.request.body is None
        if self.request.body is not None:
            self.request.headers["Content-Length"] = str(len(
                self.request.body))
        if (self.request.method == "POST" and
                "Content-Type" not in self.request.headers):
            self.request.headers["Content-Type"] = "application/x-www-form-urlencoded"
        if self.request.use_gzip:
            self.request.headers["Accept-Encoding"] = "gzip"
        req_path = ((self.parsed.path or '/') +
                    (('?' + self.parsed.query) if self.parsed.query else ''))
        request_lines = [utf8("%s %s HTTP/1.1" % (self.request.method,
                                                  req_path))]
        for k, v in self.request.headers.get_all():
            line = utf8(k) + b": " + utf8(v)
            # Reject header injection (CR/LF smuggled into a header value).
            if b'\n' in line:
                raise ValueError('Newline in header: ' + repr(line))
            request_lines.append(line)
        self.stream.write(b"\r\n".join(request_lines) + b"\r\n\r\n")
        if self.request.body is not None:
            self.stream.write(self.request.body)
        self.stream.read_until_regex(b"\r?\n\r?\n", self._on_headers)

    def _release(self):
        """Invoke release_callback at most once."""
        if self.release_callback is not None:
            release_callback = self.release_callback
            self.release_callback = None
            release_callback()

    def _run_callback(self, response):
        """Release, then schedule final_callback(response) at most once."""
        self._release()
        if self.final_callback is not None:
            final_callback = self.final_callback
            self.final_callback = None
            self.io_loop.add_callback(final_callback, response)

    def _handle_exception(self, typ, value, tb):
        """Turn an uncaught exception into a 599 HTTPResponse.

        Returns:
            True if the exception was consumed (a response was still pending),
            False to let it propagate.
        """
        if self.final_callback:
            self._remove_timeout()
            gen_log.warning("uncaught exception", exc_info=(typ, value, tb))
            self._run_callback(HTTPResponse(self.request, 599, error=value,
                                            request_time=self.io_loop.time() - self.start_time,
                                            ))

            # The stream only exists once _on_resolve has run.
            if hasattr(self, "stream"):
                self.stream.close()
            return True
        else:
            # If our callback has already been called, we are probably
            # catching an exception that is not caused by us but rather
            # some child of our callback. Rather than drop it on the floor,
            # pass it along.
            return False

    def _on_close(self):
        """Stream closed before a response was delivered: raise a 599 error.

        Uses the stream's recorded error as the message when there is one.
        """
        if self.final_callback is not None:
            message = "Connection closed"
            if self.stream.error:
                message = str(self.stream.error)
            raise HTTPError(599, message)

    def _handle_1xx(self, code):
        """Ignore an informational response and read the next header block."""
        self.stream.read_until_regex(b"\r?\n\r?\n", self._on_headers)

    def _on_headers(self, data):
        """Parse the status line and headers, then choose how to read the body.

        Raises:
            ValueError: for conflicting duplicate Content-Length values, or a
                body on a 1xx/204 response.
        """
        data = native_str(data.decode("latin1"))
        first_line, _, header_data = data.partition("\n")
        match = re.match("HTTP/1.[01] ([0-9]+) ([^\r]*)", first_line)
        assert match
        code = int(match.group(1))
        self.headers = HTTPHeaders.parse(header_data)
        if 100 <= code < 200:
            self._handle_1xx(code)
            return
        else:
            self.code = code
            self.reason = match.group(2)

        if "Content-Length" in self.headers:
            if "," in self.headers["Content-Length"]:
                # Proxies sometimes cause Content-Length headers to get
                # duplicated.  If all the values are identical then we can
                # use them but if they differ it's an error.
                pieces = re.split(r',\s*', self.headers["Content-Length"])
                if any(i != pieces[0] for i in pieces):
                    raise ValueError("Multiple unequal Content-Lengths: %r" %
                                     self.headers["Content-Length"])
                self.headers["Content-Length"] = pieces[0]
            content_length = int(self.headers["Content-Length"])
        else:
            content_length = None

        if self.request.header_callback is not None:
            # re-attach the newline we split on earlier
            self.request.header_callback(first_line + _)
            for k, v in self.headers.get_all():
                self.request.header_callback("%s: %s\r\n" % (k, v))
            self.request.header_callback('\r\n')

        if self.request.method == "HEAD" or self.code == 304:
            # HEAD requests and 304 responses never have content, even
            # though they may have content-length headers
            self._on_body(b"")
            return
        if 100 <= self.code < 200 or self.code == 204:
            # These response codes never have bodies
            # http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.3
            if ("Transfer-Encoding" in self.headers or
                    content_length not in (None, 0)):
                raise ValueError("Response with code %d should not have body" %
                                 self.code)
            self._on_body(b"")
            return

        if (self.request.use_gzip and
                self.headers.get("Content-Encoding") == "gzip"):
            self._decompressor = GzipDecompressor()
        if self.headers.get("Transfer-Encoding") == "chunked":
            self.chunks = []
            self.stream.read_until(b"\r\n", self._on_chunk_length)
        elif content_length is not None:
            self.stream.read_bytes(content_length, self._on_body)
        else:
            self.stream.read_until_close(self._on_body)

    def _on_body(self, data):
        """Body complete: follow a redirect or deliver the final HTTPResponse.

        For a redirect, ``final_callback`` is handed over to a new
        ``client.fetch`` call. Otherwise the (decompressed) body is wrapped in
        a BytesIO buffer. When a streaming_callback is set, the data goes there
        instead and the buffer is empty. The stream is closed in both cases.
        """
        self._remove_timeout()
        original_request = getattr(self.request, "original_request",
                                   self.request)
        if (self.request.follow_redirects and
            self.request.max_redirects > 0 and
                self.code in (301, 302, 303, 307)):
            assert isinstance(self.request, _RequestProxy)
            # NOTE(review): copy.copy is shallow, so new_request.headers may be
            # the very same object as the caller's request headers. The
            # deletions below would then mutate the original request as well;
            # verify.
            new_request = copy.copy(self.request.request)
            new_request.url = urlparse.urljoin(self.request.url,
                                               self.headers["Location"])
            new_request.max_redirects = self.request.max_redirects - 1
            del new_request.headers["Host"]
            # http://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.3.4
            # Client SHOULD make a GET request after a 303.
            # According to the spec, 302 should be followed by the same
            # method as the original request, but in practice browsers
            # treat 302 the same as 303, and many servers use 302 for
            # compatibility with pre-HTTP/1.1 user agents which don't
            # understand the 303 status.
            if self.code in (302, 303):
                new_request.method = "GET"
                new_request.body = None
                # NOTE(review): this deletes from self.request.headers, not
                # new_request.headers. It only affects the new request if the
                # two share a headers object (see note above); confirm that
                # this is intended.
                for h in ["Content-Length", "Content-Type",
                          "Content-Encoding", "Transfer-Encoding"]:
                    try:
                        del self.request.headers[h]
                    except KeyError:
                        pass
            new_request.original_request = original_request
            final_callback = self.final_callback
            self.final_callback = None
            self._release()
            self.client.fetch(new_request, final_callback)
            self.stream.close()
            return
        if self._decompressor:
            data = (self._decompressor.decompress(data) +
                    self._decompressor.flush())
        if self.request.streaming_callback:
            if self.chunks is None:
                # if chunks is not None, we already called streaming_callback
                # in _on_chunk_data
                self.request.streaming_callback(data)
            buffer = BytesIO()
        else:
            buffer = BytesIO(data)  # TODO: don't require one big string?
        response = HTTPResponse(original_request,
                                self.code, reason=self.reason,
                                headers=self.headers,
                                request_time=self.io_loop.time() - self.start_time,
                                buffer=buffer,
                                effective_url=self.request.url)
        self._run_callback(response)
        self.stream.close()

    def _on_chunk_length(self, data):
        """Parse a hex chunk-size line; size 0 ends the chunked body."""
        # TODO: "chunk extensions" http://tools.ietf.org/html/rfc2616#section-3.6.1
        length = int(data.strip(), 16)
        if length == 0:
            if self._decompressor is not None:
                tail = self._decompressor.flush()
                if tail:
                    # I believe the tail will always be empty (i.e.
                    # decompress will return all it can).  The purpose
                    # of the flush call is to detect errors such
                    # as truncated input.  But in case it ever returns
                    # anything, treat it as an extra chunk
                    if self.request.streaming_callback is not None:
                        self.request.streaming_callback(tail)
                    else:
                        self.chunks.append(tail)
                # all the data has been decompressed, so we don't need to
                # decompress again in _on_body
                self._decompressor = None
            self._on_body(b''.join(self.chunks))
        else:
            self.stream.read_bytes(length + 2,  # chunk ends with \r\n
                                   self._on_chunk_data)

    def _on_chunk_data(self, data):
        """Handle one chunk plus its trailing CRLF, then read the next size line."""
        assert data[-2:] == b"\r\n"
        chunk = data[:-2]
        if self._decompressor:
            chunk = self._decompressor.decompress(chunk)
        if self.request.streaming_callback is not None:
            self.request.streaming_callback(chunk)
        else:
            self.chunks.append(chunk)
        self.stream.read_until(b"\r\n", self._on_chunk_length)
class _HTTPConnection(object):
    """One HTTP request/response exchange over a freshly opened stream.

    The connection resolves and connects in ``__init__``, sends the request
    in ``_on_connect``, then parses headers and body (plain, content-length
    or chunked, optionally gzip-decoded) and finally hands an
    ``HTTPResponse`` to ``final_callback``.  Redirects are followed by
    issuing a new ``client.fetch``.  Any uncaught exception raised under
    the ``cleanup`` stack context is turned into a 599 response.
    """
    _SUPPORTED_METHODS = set(
        ["GET", "HEAD", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"])

    def __init__(self, io_loop, client, request, release_callback,
                 final_callback, max_buffer_size):
        """Resolve the request's host and start connecting.

        :param io_loop: IOLoop used for the stream and timeouts.
        :param client: owning client; its ``hostname_mapping`` and
            ``fetch`` (for redirects) are used.
        :param request: the request to perform.
        :param release_callback: called once when this connection stops
            occupying a client slot.
        :param final_callback: called once with the ``HTTPResponse``.
        :param max_buffer_size: passed through to the IOStream.
        """
        self.start_time = time.time()
        self.io_loop = io_loop
        self.client = client
        self.request = request
        self.release_callback = release_callback
        self.final_callback = final_callback
        self.code = None
        self.headers = None
        self.chunks = None
        self._decompressor = None
        # Timeout handle returned by IOLoop.add_timeout
        self._timeout = None
        with stack_context.StackContext(self.cleanup):
            parsed = urlparse.urlsplit(_unicode(self.request.url))
            if ssl is None and parsed.scheme == "https":
                raise ValueError("HTTPS requires either python2.6+ or "
                                 "curl_httpclient")
            if parsed.scheme not in ("http", "https"):
                raise ValueError("Unsupported url scheme: %s" %
                                 self.request.url)
            # urlsplit results have hostname and port results, but they
            # didn't support ipv6 literals until python 2.7.
            netloc = parsed.netloc
            if "@" in netloc:
                # Drop the "user:password@" prefix; credentials are read
                # from parsed.username/password in _on_connect.
                _, _, netloc = netloc.rpartition("@")
            match = re.match(r'^(.+):(\d+)$', netloc)
            if match:
                host = match.group(1)
                port = int(match.group(2))
            else:
                host = netloc
                port = 443 if parsed.scheme == "https" else 80
            if re.match(r'^\[.*\]$', host):
                # raw ipv6 addresses in urls are enclosed in brackets
                host = host[1:-1]
            parsed_hostname = host  # save final parsed host for _on_connect
            if self.client.hostname_mapping is not None:
                host = self.client.hostname_mapping.get(host, host)

            if request.allow_ipv6:
                af = socket.AF_UNSPEC
            else:
                # We only try the first IP we get from getaddrinfo,
                # so restrict to ipv4 by default.
                af = socket.AF_INET

            addrinfo = socket.getaddrinfo(host, port, af, socket.SOCK_STREAM,
                                          0, 0)
            af, socktype, proto, _canonname, sockaddr = addrinfo[0]

            if parsed.scheme == "https":
                ssl_options = {}
                if request.validate_cert:
                    ssl_options["cert_reqs"] = ssl.CERT_REQUIRED
                if request.ca_certs is not None:
                    ssl_options["ca_certs"] = request.ca_certs
                else:
                    ssl_options["ca_certs"] = _DEFAULT_CA_CERTS
                if request.client_key is not None:
                    ssl_options["keyfile"] = request.client_key
                if request.client_cert is not None:
                    ssl_options["certfile"] = request.client_cert

                # SSL interoperability is tricky.  We want to disable
                # SSLv2 for security reasons; it wasn't disabled by default
                # until openssl 1.0.  The best way to do this is to use
                # the SSL_OP_NO_SSLv2, but that wasn't exposed to python
                # until 3.2.  Python 2.7 adds the ciphers argument, which
                # can also be used to disable SSLv2.  As a last resort
                # on python 2.6, we set ssl_version to SSLv3.  This is
                # more narrow than we'd like since it also breaks
                # compatibility with servers configured for TLSv1 only,
                # but nearly all servers support SSLv3:
                # http://blog.ivanristic.com/2011/09/ssl-survey-protocol-support.html
                if sys.version_info >= (2, 7):
                    ssl_options["ciphers"] = "DEFAULT:!SSLv2"
                else:
                    # This is really only necessary for pre-1.0 versions
                    # of openssl, but python 2.6 doesn't expose version
                    # information.
                    ssl_options["ssl_version"] = ssl.PROTOCOL_SSLv3

                self.stream = SSLIOStream(socket.socket(af, socktype, proto),
                                          io_loop=self.io_loop,
                                          ssl_options=ssl_options,
                                          max_buffer_size=max_buffer_size)
            else:
                self.stream = IOStream(socket.socket(af, socktype, proto),
                                       io_loop=self.io_loop,
                                       max_buffer_size=max_buffer_size)
            timeout = min(request.connect_timeout, request.request_timeout)
            if timeout:
                self._timeout = self.io_loop.add_timeout(
                    self.start_time + timeout,
                    self._on_timeout)
            self.stream.set_close_callback(self._on_close)
            self.stream.connect(sockaddr,
                                functools.partial(self._on_connect, parsed,
                                                  parsed_hostname))

    def _on_timeout(self):
        """Fail the request with a 599 "Timeout" and close the stream."""
        self._timeout = None
        self._run_callback(HTTPResponse(self.request, 599,
                                        request_time=time.time() - self.start_time,
                                        error=HTTPError(599, "Timeout")))
        self.stream.close()

    def _on_connect(self, parsed, parsed_hostname):
        """Validate the peer, build the request head and send it.

        :param parsed: ``urlsplit`` result for the request URL.
        :param parsed_hostname: host as parsed in ``__init__`` (brackets
            stripped from IPv6 literals), used for certificate matching.
        """
        # Swap the connect timeout for the overall request timeout.
        if self._timeout is not None:
            self.io_loop.remove_timeout(self._timeout)
            self._timeout = None
        if self.request.request_timeout:
            self._timeout = self.io_loop.add_timeout(
                self.start_time + self.request.request_timeout,
                self._on_timeout)
        if (self.request.validate_cert and
                isinstance(self.stream, SSLIOStream)):
            match_hostname(self.stream.socket.getpeercert(),
                           # ipv6 addresses are broken (in
                           # parsed.hostname) until 2.7, here is
                           # correctly parsed value calculated in
                           # __init__
                           parsed_hostname)
        if (self.request.method not in self._SUPPORTED_METHODS and
                not self.request.allow_nonstandard_methods):
            raise KeyError("unknown method %s" % self.request.method)
        for key in ('network_interface',
                    'proxy_host', 'proxy_port',
                    'proxy_username', 'proxy_password'):
            if getattr(self.request, key, None):
                raise NotImplementedError('%s not supported' % key)
        if "Connection" not in self.request.headers:
            self.request.headers["Connection"] = "close"
        if "Host" not in self.request.headers:
            if '@' in parsed.netloc:
                self.request.headers["Host"] = parsed.netloc.rpartition('@')[-1]
            else:
                self.request.headers["Host"] = parsed.netloc
        username, password = None, None
        if parsed.username is not None:
            # A URL like "http://user@host/" has no password component
            # (parsed.password is None); send an empty password rather than
            # failing to concatenate None, as the auth_username branch does.
            username = parsed.username
            password = parsed.password or ''
        elif self.request.auth_username is not None:
            username = self.request.auth_username
            password = self.request.auth_password or ''
        if username is not None:
            auth = utf8(username) + b(":") + utf8(password)
            self.request.headers["Authorization"] = (b("Basic ") +
                                                     base64.b64encode(auth))
        if self.request.user_agent:
            self.request.headers["User-Agent"] = self.request.user_agent
        if not self.request.allow_nonstandard_methods:
            if self.request.method in ("POST", "PATCH", "PUT"):
                assert self.request.body is not None
            else:
                assert self.request.body is None
        if self.request.body is not None:
            self.request.headers["Content-Length"] = str(len(
                self.request.body))
        if (self.request.method == "POST" and
                "Content-Type" not in self.request.headers):
            self.request.headers["Content-Type"] = "application/x-www-form-urlencoded"
        if self.request.use_gzip:
            self.request.headers["Accept-Encoding"] = "gzip"
        req_path = ((parsed.path or '/') +
                    (('?' + parsed.query) if parsed.query else ''))
        request_lines = [utf8("%s %s HTTP/1.1" % (self.request.method,
                                                  req_path))]
        for k, v in self.request.headers.get_all():
            line = utf8(k) + b(": ") + utf8(v)
            # Refuse header injection via embedded newlines.
            if b('\n') in line:
                raise ValueError('Newline in header: ' + repr(line))
            request_lines.append(line)
        self.stream.write(b("\r\n").join(request_lines) + b("\r\n\r\n"))
        if self.request.body is not None:
            self.stream.write(self.request.body)
        self.stream.read_until_regex(b("\r?\n\r?\n"), self._on_headers)

    def _release(self):
        """Invoke ``release_callback`` at most once."""
        if self.release_callback is not None:
            release_callback = self.release_callback
            self.release_callback = None
            release_callback()

    def _run_callback(self, response):
        """Release the slot and deliver ``response`` at most once."""
        self._release()
        if self.final_callback is not None:
            final_callback = self.final_callback
            self.final_callback = None
            final_callback(response)

    @contextlib.contextmanager
    def cleanup(self):
        """Stack context: turn any uncaught exception into a 599 response."""
        try:
            yield
        except Exception as e:
            logging.warning("uncaught exception", exc_info=True)
            self._run_callback(HTTPResponse(self.request, 599, error=e,
                                            request_time=time.time() - self.start_time,
                                            ))
            if hasattr(self, "stream"):
                self.stream.close()

    def _on_close(self):
        """Stream close callback; a no-op if a response was already sent."""
        self._run_callback(HTTPResponse(
            self.request, 599,
            request_time=time.time() - self.start_time,
            error=HTTPError(599, "Connection closed")))

    def _on_headers(self, data):
        """Parse the status line and headers, then choose how to read the body."""
        data = native_str(data.decode("latin1"))
        first_line, _, header_data = data.partition("\n")
        # Escape the dot so only a literal "HTTP/1.0" or "HTTP/1.1" matches.
        match = re.match(r"HTTP/1\.[01] ([0-9]+)", first_line)
        assert match
        self.code = int(match.group(1))
        self.headers = HTTPHeaders.parse(header_data)

        if "Content-Length" in self.headers:
            if "," in self.headers["Content-Length"]:
                # Proxies sometimes cause Content-Length headers to get
                # duplicated.  If all the values are identical then we can
                # use them but if they differ it's an error.
                pieces = re.split(r',\s*', self.headers["Content-Length"])
                if any(i != pieces[0] for i in pieces):
                    raise ValueError("Multiple unequal Content-Lengths: %r" %
                                     self.headers["Content-Length"])
                self.headers["Content-Length"] = pieces[0]
            content_length = int(self.headers["Content-Length"])
        else:
            content_length = None

        if self.request.header_callback is not None:
            for k, v in self.headers.get_all():
                self.request.header_callback("%s: %s\r\n" % (k, v))

        if self.request.method == "HEAD":
            # HEAD requests never have content, even though they may have
            # content-length headers
            self._on_body(b(""))
            return
        if 100 <= self.code < 200 or self.code in (204, 304):
            # These response codes never have bodies
            # http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.3
            assert "Transfer-Encoding" not in self.headers
            assert content_length in (None, 0)
            self._on_body(b(""))
            return

        if (self.request.use_gzip and
                self.headers.get("Content-Encoding") == "gzip"):
            # Magic parameter makes zlib module understand gzip header
            # http://stackoverflow.com/questions/1838699/how-can-i-decompress-a-gzip-stream-with-zlib
            self._decompressor = zlib.decompressobj(16 + zlib.MAX_WBITS)
        if self.headers.get("Transfer-Encoding") == "chunked":
            self.chunks = []
            self.stream.read_until(b("\r\n"), self._on_chunk_length)
        elif content_length is not None:
            self.stream.read_bytes(content_length, self._on_body)
        else:
            self.stream.read_until_close(self._on_body)

    def _on_body(self, data):
        """Follow a redirect, or build and deliver the final response."""
        if self._timeout is not None:
            self.io_loop.remove_timeout(self._timeout)
            self._timeout = None
        original_request = getattr(self.request, "original_request",
                                   self.request)
        if (self.request.follow_redirects and
                self.request.max_redirects > 0 and
                self.code in (301, 302, 303, 307)):
            new_request = copy.copy(self.request)
            new_request.url = urlparse.urljoin(self.request.url,
                                               self.headers["Location"])
            new_request.max_redirects -= 1
            del new_request.headers["Host"]
            # http://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.3.4
            # client SHOULD make a GET request
            if self.code == 303:
                new_request.method = "GET"
                new_request.body = None
                # Strip entity headers from the request being sent next.
                # (copy.copy is shallow, so this headers object is the
                # same one self.request holds.)
                for h in ["Content-Length", "Content-Type",
                          "Content-Encoding", "Transfer-Encoding"]:
                    try:
                        del new_request.headers[h]
                    except KeyError:
                        pass
            new_request.original_request = original_request
            final_callback = self.final_callback
            self.final_callback = None
            self._release()
            self.client.fetch(new_request, final_callback)
            self.stream.close()
            return
        if self._decompressor is not None:
            # Flush so any output the decompressor still holds is not lost.
            data = (self._decompressor.decompress(data) +
                    self._decompressor.flush())
        if self.request.streaming_callback:
            if self.chunks is None:
                # if chunks is not None, we already called streaming_callback
                # in _on_chunk_data
                self.request.streaming_callback(data)
            buffer = BytesIO()
        else:
            buffer = BytesIO(data)  # TODO: don't require one big string?
        response = HTTPResponse(original_request,
                                self.code, headers=self.headers,
                                request_time=time.time() - self.start_time,
                                buffer=buffer,
                                effective_url=self.request.url)
        self._run_callback(response)
        self.stream.close()

    def _on_chunk_length(self, data):
        """Handle a chunk-size line; a size of 0 ends the body."""
        # TODO: "chunk extensions" http://tools.ietf.org/html/rfc2616#section-3.6.1
        length = int(data.strip(), 16)
        if length == 0:
            if self._decompressor is not None:
                # Flush the decompressor.  Any remaining output is treated
                # as one extra chunk so it is not silently dropped.
                tail = self._decompressor.flush()
                if tail:
                    if self.request.streaming_callback is not None:
                        self.request.streaming_callback(tail)
                    else:
                        self.chunks.append(tail)
            # all the data has been decompressed, so we don't need to
            # decompress again in _on_body
            self._decompressor = None
            self._on_body(b('').join(self.chunks))
        else:
            self.stream.read_bytes(length + 2,  # chunk ends with \r\n
                                   self._on_chunk_data)

    def _on_chunk_data(self, data):
        """Handle one chunk's payload (plus trailing CRLF), then read the next size."""
        assert data[-2:] == b("\r\n")
        chunk = data[:-2]
        if self._decompressor:
            chunk = self._decompressor.decompress(chunk)
        if self.request.streaming_callback is not None:
            self.request.streaming_callback(chunk)
        else:
            self.chunks.append(chunk)
        self.stream.read_until(b("\r\n"), self._on_chunk_length)
class WebSocketProxyHandler(WebSocketHandler):
    """
    Proxy a websocket connection to a service listening on a given (host, port) pair
    """

    def initialize(self, **kwargs):
        """
        Create the (not yet connected) stream to the proxied service.

        Keyword arguments read here: ``address`` (service address),
        ``family`` (default ``socket.AF_INET``), ``type`` (default
        ``socket.SOCK_STREAM``) and ``filters`` (default ``[]``). Each filter
        must provide ``ws_to_socket(data=...)`` and ``socket_to_ws(data=...)``.
        """
        self.remote_address = kwargs.get("address")
        self.io_stream = IOStream(socket.socket(kwargs.get("family", socket.AF_INET),
                                                kwargs.get("type", socket.SOCK_STREAM),
                                                0))
        self.filters = kwargs.get("filters", [])
        # When the service side goes away, tear the WebSocket side down too.
        self.io_stream.set_close_callback(self.on_close)

    def open(self):
        """
        Open the connection to the service when the WebSocket connection has been established
        """
        logger.info("Forwarding connection to server %s",
                    tuple_to_address(self.remote_address))
        self.io_stream.connect(self.remote_address, self.on_connect)

    def on_message(self, message):
        """
        On message received from WebSocket, forward data to the service

        The message is converted to bytes and passed through every filter's
        ``ws_to_socket``. Only a non-empty result is written. Any exception is
        logged and closes the WebSocket.
        """
        try:
            data = None if message is None else bytes(message)
            for filtr in self.filters:
                data = filtr.ws_to_socket(data=data)
            if data:
                self.io_stream.write(data)
        except Exception as e:
            logger.exception(e)
            self.close()

    def on_close(self, *args, **kwargs):
        """
        When web socket gets closed, close the connection to the service too

        This method is also registered as the stream's close callback and as
        the final callback of ``read_until_close``. In those cases any
        positional arguments carry remaining data from the service, and that
        data is forwarded to the client before closing.
        """
        logger.info("Closing connection with peer at %s",
                    tuple_to_address(self.remote_address))
        logger.debug("Received args %s and %s", args, kwargs)
        for message in args:
            self.on_peer_message(message)
        if not self.io_stream.closed():
            self.io_stream.close()
        self.close()

    def on_connect(self):
        """
        Callback invoked on connection with mapped service
        """
        logger.info("Connection established with peer at %s",
                    tuple_to_address(self.remote_address))
        # Stream service data to the client as it arrives; on_close receives
        # the final read result.
        self.io_stream.read_until_close(self.on_close, self.on_peer_message)

    def on_peer_message(self, message):
        """
        On message received from peer service, send back to client through WebSocket

        The data is passed through every filter's ``socket_to_ws`` and sent as
        a binary frame if the result is non-empty. A ``FilterException`` is
        logged and shuts down both sides.
        """
        try:
            data = None if message is None else bytes(message)
            for filtr in self.filters:
                data = filtr.socket_to_ws(data=data)
            if data:
                self.write_message(data, binary=True)
        except FilterException as e:
            logger.exception(e)
            self.on_close()
class BaseTornadoClient(AsyncModbusClientMixin):
    """
    Base Tornado client
    """
    stream = None
    io_loop = None

    def __init__(self, *args, **kwargs):
        """
        Initializes BaseTornadoClient.
        ioloop to be passed as part of kwargs ('ioloop')
        :param args:
        :param kwargs:
        """
        self.io_loop = kwargs.pop("ioloop", None)
        super(BaseTornadoClient, self).__init__(*args, **kwargs)

    @abc.abstractmethod
    def get_socket(self):
        """
        return instance of the socket to connect to
        """

    @gen.coroutine
    def connect(self):
        """
        Connect to the socket identified by host and port

        The returned future resolves to ``self`` only after the TCP
        connection is established. A failed connection makes the future
        raise instead of leaving the client marked as connected.

        :returns: Future
        :rtype: tornado.concurrent.Future
        """
        conn = self.get_socket()
        self.stream = IOStream(conn, io_loop=self.io_loop or IOLoop.current())
        # Wait for the connection before reading: Tornado documents reads
        # issued before the socket is connected as non-portable.
        yield self.stream.connect((self.host, self.port))
        self.stream.read_until_close(None,
                                     streaming_callback=self.on_receive)
        self._connected = True
        LOGGER.debug("Client connected")

        raise gen.Return(self)

    def on_receive(self, *args):
        """
        On data receive call back
        :param args: data received
        :return:
        """
        data = args[0] if len(args) > 0 else None

        if not data:
            return
        LOGGER.debug("recv: " + hexlify_packets(data))
        unit = self.framer.decode_data(data).get("unit", 0)
        self.framer.processIncomingPacket(data, self._handle_response, unit=unit)

    def execute(self, request=None):
        """
        Executes a transaction

        If the client is not connected, the returned future fails with
        ConnectionException and nothing is written.

        :param request:
        :return:
        """
        request.transaction_id = self.transaction.getNextTID()
        if self.stream is None or not self._connected:
            # Resolve with the "not connected" error instead of raising
            # AttributeError on a missing stream.
            return self._build_response(request.transaction_id)
        packet = self.framer.buildPacket(request)
        LOGGER.debug("send: " + hexlify_packets(packet))
        self.stream.write(packet)
        return self._build_response(request.transaction_id)

    def _handle_response(self, reply, **kwargs):
        """
        Handle response received
        :param reply:
        :param kwargs:
        :return:
        """
        if reply is not None:
            tid = reply.transaction_id
            future = self.transaction.getTransaction(tid)
            if future:
                future.set_result(reply)
            else:
                LOGGER.debug("Unrequested message: {}".format(reply))

    def _build_response(self, tid):
        """
        Builds a future response
        :param tid:
        :return:
        """
        f = Future()

        if not self._connected:
            f.set_exception(ConnectionException("Client is not connected"))
            return f

        self.transaction.addTransaction(f, tid)
        return f

    def close(self):
        """
        Closes the underlying IOStream
        """
        LOGGER.debug("Client disconnected")
        if self.stream:
            # IOStream.close() unregisters from the IOLoop, finishes pending
            # reads and then closes the fd. close_fd() is documented as
            # internal to IOStream.
            self.stream.close()

        self.stream = None
        self._connected = False
class HTTPServerRawTest(AsyncHTTPTestCase):
    """Drive HTTPServer with hand-written bytes over a raw IOStream."""

    def get_app(self):
        return Application([
            ("/echo", EchoHandler),
        ])

    def setUp(self):
        super(HTTPServerRawTest, self).setUp()
        self.stream = IOStream(socket.socket())
        # The AsyncHTTPTestCase server listens on the loopback interface,
        # so connect there rather than to an arbitrary LAN address.
        self.io_loop.run_sync(lambda: self.stream.connect(
            ("127.0.0.1", self.get_http_port())))

    def tearDown(self):
        self.stream.close()
        super(HTTPServerRawTest, self).tearDown()

    def test_empty_request(self):
        self.stream.close()
        self.io_loop.add_timeout(datetime.timedelta(seconds=0.001), self.stop)
        self.wait()

    def test_malformed_first_line_response(self):
        with ExpectLog(gen_log, ".*Malformed HTTP request line"):
            self.stream.write(b"asdf\r\n\r\n")
            start_line, headers, response = self.io_loop.run_sync(
                lambda: read_stream_body(self.stream))
            self.assertEqual("HTTP/1.1", start_line.version)
            self.assertEqual(400, start_line.code)
            self.assertEqual("Bad Request", start_line.reason)

    def test_malformed_first_line_log(self):
        with ExpectLog(gen_log, ".*Malformed HTTP request line"):
            self.stream.write(b"asdf\r\n\r\n")
            # TODO: need an async version of ExpectLog so we don't need
            # hard-coded timeouts here.
            self.io_loop.add_timeout(datetime.timedelta(seconds=0.05),
                                     self.stop)
            self.wait()

    def test_malformed_headers(self):
        with ExpectLog(gen_log, ".*Malformed HTTP message.*no colon in header line"):
            self.stream.write(b"GET / HTTP/1.0\r\nasdf\r\n\r\n")
            self.io_loop.add_timeout(datetime.timedelta(seconds=0.05),
                                     self.stop)
            self.wait()

    def test_chunked_request_body(self):
        # Chunked requests are not widely supported and we don't have a way
        # to generate them in AsyncHTTPClient, but HTTPServer will read them.
        self.stream.write(b"""\
POST /echo HTTP/1.1
Transfer-Encoding: chunked
Content-Type: application/x-www-form-urlencoded

4
foo=
3
bar
0

""".replace(b"\n", b"\r\n"))
        start_line, headers, response = self.io_loop.run_sync(
            lambda: read_stream_body(self.stream))
        self.assertEqual(json_decode(response), {u"foo": [u"bar"]})

    def test_chunked_request_uppercase(self):
        # As per RFC 2616 section 3.6, "Transfer-Encoding" header's value is
        # case-insensitive.
        self.stream.write(b"""\
POST /echo HTTP/1.1
Transfer-Encoding: Chunked
Content-Type: application/x-www-form-urlencoded

4
foo=
3
bar
0

""".replace(b"\n", b"\r\n"))
        start_line, headers, response = self.io_loop.run_sync(
            lambda: read_stream_body(self.stream))
        self.assertEqual(json_decode(response), {u"foo": [u"bar"]})

    @gen_test
    def test_invalid_content_length(self):
        with ExpectLog(gen_log, ".*Only integer Content-Length is allowed"):
            self.stream.write(b"""\
POST /echo HTTP/1.1
Content-Length: foo

bar

""".replace(b"\n", b"\r\n"))
            yield self.stream.read_until_close()
class _RedisConnection(object):
    """One asynchronous redis connection.

    Commands written before the TCP connection is ready are accumulated in
    an init buffer, prefixed by the optional AUTH and the SELECT command,
    and sent in a single write once the connection completes.  Replies are
    matched, in order, against a queue of command contexts.
    """

    def __init__(self, io_loop, init_buf, final_callback, redis_tuple, redis_pass):
        """
        :param io_loop: the IOLoop that drives this connection
        :param init_buf: payload to send first (None is treated as '')
        :param final_callback: called via io_loop.add_callback with each
            response/error dict
        :param redis_tuple: (ip, port, db)
        :param redis_pass: redis password, or None for no AUTH
        """
        self.__io_loop = io_loop
        self.__final_cb = final_callback
        self.__stream = None
        # Unparsed remainder of the redis reply stream.
        self.__recv_buf = ''

        select_and_init = chain_select_cmd(redis_tuple[2], init_buf or '')
        if redis_pass is not None:
            assert redis_pass and isinstance(redis_pass, str)
            pieces = (redis_auth(redis_pass), select_and_init)
        else:
            pieces = (select_and_init,)
        self.__haspass = redis_pass is not None
        self.__init_buf = ''.join(pieces)
        self.__connected = False
        # Queue of command contexts:
        # (future, connect command count (AUTH, SELECT, ...), transaction, cmd_count)
        self.__cmd_env = deque()
        self.__written = False

    def connect(self, init_future, redis_tuple, active_trans, cmd_count):
        """
        Start connecting, unless a stream already exists.

        :param init_future: future for the first batch of commands
        :param redis_tuple: (ip, port, db)
        :param active_trans: whether a transaction is active
        :param cmd_count: number of commands
        """
        if self.__stream is None:
            # One context covers the connect-time commands (SELECT, plus AUTH
            # when a password is set) and the first batch.
            connect_cmds = 1 + int(self.__haspass)
            self.__cmd_env.append((init_future, connect_cmds, active_trans, cmd_count))
            with ExceptionStackContext(self.__handle_ex):
                sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
                self.__stream = IOStream(sock, io_loop=self.__io_loop)
                self.__stream.connect(redis_tuple[:2], self.__on_connect)

    def write(self, write_buf, new_future, include_select, active_trans, cmd_count, by_connect=False):
        """
        Queue or send a batch of commands.

        :param write_buf: encoded commands to send
        :param new_future: future for this batch; it must be refreshed on
            every call because the response callback would otherwise keep
            the previous future through its closure
        :param include_select: whether the batch includes a SELECT command
        :param active_trans: whether a transaction is active
        :param cmd_count: number of commands
        :param by_connect: internal flag; when True, only flush the init
            buffer (used right after the connection completes)
        """
        if by_connect:
            pending, self.__init_buf = self.__init_buf, None
            self.__stream.write(pending)
            return

        context = (new_future, int(include_select), active_trans, cmd_count)
        self.__cmd_env.append(context)

        if self.__connected:
            outgoing = ''.join((self.__init_buf, write_buf)) if self.__init_buf else write_buf
            self.__stream.write(outgoing)
            self.__init_buf = None
        else:
            # Not connected yet: keep accumulating behind the init commands.
            self.__init_buf = ''.join((self.__init_buf, write_buf))

    def __on_connect(self):
        """Connected: send the buffered initial commands and start reading."""
        self.__connected = True
        self.__stream.set_nodelay(True)
        self.write(None, None, None, None, None, True)
        self.__stream.read_until_close(None, self.__on_resp)

    def __on_resp(self, recv):
        """
        Decode as many queued replies as the buffered data allows.

        :param recv: newly received data
        """
        pending = ''.join((self.__recv_buf, recv))

        completed = 0
        for fut, connect_count, trans, count in self.__cmd_env:
            ok, payload, pending = decode_resp_ondemand(pending, connect_count, trans, count)
            if not ok:
                break
            completed += 1
            self.__run_callback({_RESP_FUTURE: fut, RESP_RESULT: payload})

        self.__recv_buf = pending
        # The deque is popped only after iteration, since it cannot be
        # mutated while being iterated.
        for _ in range(completed):
            self.__cmd_env.popleft()

    def __run_callback(self, resp):
        """Schedule final_callback(resp) on the IOLoop, if one is set."""
        if self.__final_cb is not None:
            self.__io_loop.add_callback(self.__final_cb, resp)

    def __handle_ex(self, typ, value, tb):
        """
        Exception handler for the connect stack context.

        :param typ: exception type
        :returns: True (exception handled) when a final callback exists
        """
        if not self.__final_cb:
            return False
        self.__run_callback({RESP_ERR: value})
        return True
class _RedisConnection(object):
    """One asynchronous redis connection with an explicit connect state.

    Connect-time commands (optional AUTH plus SELECT) live in ``__init_buf``.
    User commands written while connecting accumulate in ``__write_buf``.
    Both are flushed once the connection completes.
    """

    def __init__(self, io_loop, write_buf, final_callback, redis_tuple, redis_pass):
        """
        :param io_loop: the IOLoop that drives this connection
        :param write_buf: first payload to write (may be None)
        :param final_callback: called via io_loop.add_callback with each
            response/error dict
        :param redis_tuple: (ip, port, db)
        :param redis_pass: redis password, or None for no AUTH
        """
        self.__io_loop = io_loop
        self.__final_cb = final_callback
        self.__stream = None
        # Unparsed remainder of the redis reply stream.
        self.__recv_buf = ''
        self.__write_buf = write_buf
        init_buf = chain_select_cmd(redis_tuple[2], '')
        if redis_pass is None:
            self.__init_buf = (init_buf, )
        else:
            assert redis_pass and isinstance(redis_pass, str)
            self.__init_buf = (redis_auth(redis_pass), init_buf)
        self.__haspass = redis_pass is not None
        self.__init_buf = ''.join(self.__init_buf)
        self.__connect_state = CONNECT_INIT
        # Queue of command contexts:
        # (future, connect command count (AUTH, SELECT, ...), transaction, cmd_count)
        self.__cmd_env = deque()
        self.__written = False

    def connect(self, init_future, redis_tuple, active_trans, cmd_count):
        """
        Start connecting, unless a stream already exists.

        :param init_future: future for the first batch of commands
        :param redis_tuple: (ip, port, db)
        :param active_trans: whether a transaction is active
        :param cmd_count: number of commands
        """
        if self.__stream is not None:
            return

        # First context consumes the connect-time replies (SELECT, plus AUTH
        # when a password is set). Its count of 0 means __on_resp does not
        # deliver it. The second context is the first user batch.
        self.__cmd_env.append((init_future, 1 + int(self.__haspass), False, 0))
        self.__cmd_env.append((init_future, 0, active_trans, cmd_count))
        with ExceptionStackContext(self.__handle_ex):
            self.__stream = IOStream(socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0),
                                     io_loop=self.__io_loop)
            self.__stream.set_close_callback(self.__on_close)
            self.__stream.connect(redis_tuple[:2], self.__on_connect)
            self.__connect_state = CONNECT_ING

    def connect_state(self):
        """Return the current state (CONNECT_INIT, CONNECT_ING or CONNECT_SUCC)."""
        return self.__connect_state

    def write(self, write_buf, new_future, include_select, active_trans, cmd_count, by_connect=False):
        """
        Queue or send a batch of commands.

        :param write_buf: encoded commands to send
        :param new_future: future for this batch; it must be refreshed on
            every call because the response callback would otherwise keep
            the previous future through its closure
        :param include_select: whether the batch includes a SELECT command
        :param active_trans: whether a transaction is active
        :param cmd_count: number of commands
        :param by_connect: internal flag; when True, flush the connect-time
            commands and any buffered writes (used right after connecting)
        """
        if by_connect:
            self.__stream.write(self.__init_buf)
            self.__init_buf = None
            if self.__write_buf:
                self.__stream.write(self.__write_buf)
                self.__write_buf = None
            return

        self.__cmd_env.append(
            (new_future, int(include_select), active_trans, cmd_count))
        if self.__connect_state == CONNECT_ING:
            # The pending buffer may still be None (e.g. the constructor got
            # no initial payload); treat that as empty instead of letting
            # str.join raise TypeError.
            self.__write_buf = ''.join((self.__write_buf or '', write_buf))
            return

        if self.__write_buf:
            write_buf = ''.join((self.__write_buf, write_buf))

        self.__stream.write(write_buf)
        self.__write_buf = None

    def __on_connect(self):
        """Connected: flush the initial commands and start reading replies."""
        self.__connect_state = CONNECT_SUCC
        self.__stream.set_nodelay(True)
        self.write(None, None, None, None, None, True)
        self.__stream.read_until_close(None, self.__on_resp)

    def __on_resp(self, recv):
        """
        Decode as many queued replies as the buffered data allows.

        :param recv: newly received data
        """
        recv = ''.join((self.__recv_buf, recv))

        idx = 0
        for future, connect_count, trans, count in self.__cmd_env:
            ok, payload, recv = decode_resp_ondemand(recv, connect_count, trans, count)
            if not ok:
                break

            idx += 1
            # Contexts with no user commands (connect-time replies) are
            # consumed without notifying the caller.
            if count > 0:
                self.__run_callback({
                    _RESP_FUTURE: future,
                    RESP_RESULT: payload
                })

        self.__recv_buf = recv
        # Pop only after iterating: a deque cannot be mutated mid-iteration.
        for _ in xrange(idx):
            self.__cmd_env.popleft()

    def __on_close(self):
        """Stream closed: reset the state and report any stream error."""
        self.__connect_state = CONNECT_INIT
        if self.__final_cb:
            if self.__stream.error:
                self.__run_callback({RESP_ERR: self.__stream.error})

    def __run_callback(self, resp):
        """Schedule final_callback(resp) on the IOLoop, if one is set."""
        if self.__final_cb is None:
            return

        self.__io_loop.add_callback(self.__final_cb, resp)

    def __handle_ex(self, typ, value, tb):
        """
        Exception handler for the connect stack context.

        :param typ: exception type
        :returns: True (exception handled) when a final callback exists
        """
        if self.__final_cb:
            self.__run_callback({RESP_ERR: value})
            return True
        return False
class KeepAliveTest(AsyncHTTPTestCase):
    """Tests various scenarios for HTTP 1.1 keep-alive support.

    These tests don't use AsyncHTTPClient because we want to control
    connection reuse and closing.
    """
    def get_app(self):
        class HelloHandler(RequestHandler):
            def get(self):
                self.finish('Hello world')

        class LargeHandler(RequestHandler):
            def get(self):
                # 512KB should be bigger than the socket buffers so it will
                # be written out in chunks.
                self.write(''.join(chr(i % 256) * 1024 for i in range(512)))

        class FinishOnCloseHandler(RequestHandler):
            @asynchronous
            def get(self):
                self.flush()

            def on_connection_close(self):
                # This is not very realistic, but finishing the request
                # from the close callback has the right timing to mimic
                # some errors seen in the wild.
                self.finish('closed')

        return Application([('/', HelloHandler),
                            ('/large', LargeHandler),
                            ('/finish_on_close', FinishOnCloseHandler)])

    def setUp(self):
        super(KeepAliveTest, self).setUp()
        self.http_version = b'HTTP/1.1'

    def tearDown(self):
        # We just closed the client side of the socket; let the IOLoop run
        # once to make sure the server side got the message.
        self.io_loop.add_timeout(datetime.timedelta(seconds=0.001), self.stop)
        self.wait()

        if hasattr(self, 'stream'):
            self.stream.close()
        super(KeepAliveTest, self).tearDown()

    # The next few methods are a crude manual http client
    def connect(self):
        """Open self.stream to the test server and wait until connected."""
        self.stream = IOStream(socket.socket(), io_loop=self.io_loop)
        self.stream.connect(('localhost', self.get_http_port()), self.stop)
        self.wait()

    def read_headers(self):
        """Read a 200 status line plus headers; return the parsed headers."""
        self.stream.read_until(b'\r\n', self.stop)
        first_line = self.wait()
        self.assertTrue(first_line.startswith(self.http_version + b' 200'), first_line)
        self.stream.read_until(b'\r\n\r\n', self.stop)
        header_bytes = self.wait()
        headers = HTTPHeaders.parse(header_bytes.decode('latin1'))
        return headers

    def read_response(self):
        """Read one full 'Hello world' response; keep its headers on self."""
        self.headers = self.read_headers()
        self.stream.read_bytes(int(self.headers['Content-Length']), self.stop)
        body = self.wait()
        self.assertEqual(b'Hello world', body)

    def close(self):
        """Close the client stream and forget it so tearDown skips it."""
        self.stream.close()
        del self.stream

    def test_two_requests(self):
        self.connect()
        self.stream.write(b'GET / HTTP/1.1\r\n\r\n')
        self.read_response()
        self.stream.write(b'GET / HTTP/1.1\r\n\r\n')
        self.read_response()
        self.close()

    def test_request_close(self):
        self.connect()
        self.stream.write(b'GET / HTTP/1.1\r\nConnection: close\r\n\r\n')
        self.read_response()
        self.stream.read_until_close(callback=self.stop)
        data = self.wait()
        self.assertFalse(data)
        self.close()

    # keepalive is supported for http 1.0 too, but it's opt-in
    def test_http10(self):
        self.http_version = b'HTTP/1.0'
        self.connect()
        self.stream.write(b'GET / HTTP/1.0\r\n\r\n')
        self.read_response()
        self.stream.read_until_close(callback=self.stop)
        data = self.wait()
        self.assertFalse(data)
        self.assertNotIn('Connection', self.headers)
        self.close()

    def test_http10_keepalive(self):
        self.http_version = b'HTTP/1.0'
        self.connect()
        self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n')
        self.read_response()
        self.assertEqual(self.headers['Connection'], 'Keep-Alive')
        self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n')
        self.read_response()
        self.assertEqual(self.headers['Connection'], 'Keep-Alive')
        self.close()

    def test_pipelined_requests(self):
        self.connect()
        self.stream.write(b'GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\n')
        self.read_response()
        self.read_response()
        self.close()

    def test_pipelined_cancel(self):
        self.connect()
        self.stream.write(b'GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\n')
        # only read once
        self.read_response()
        self.close()

    def test_cancel_during_download(self):
        self.connect()
        self.stream.write(b'GET /large HTTP/1.1\r\n\r\n')
        self.read_headers()
        self.stream.read_bytes(1024, self.stop)
        self.wait()
        self.close()

    def test_finish_while_closed(self):
        self.connect()
        self.stream.write(b'GET /finish_on_close HTTP/1.1\r\n\r\n')
        self.read_headers()
        self.close()
class _HTTPConnection(object):
    """Perform one HTTP request over a freshly opened connection.

    The constructor resolves the host and starts connecting. ``_on_connect``
    writes the request, and the response is parsed by ``_on_headers``,
    ``_on_body`` and the ``_on_chunk_*`` callbacks.

    ``final_callback`` is invoked at most once (``_run_callback`` clears it)
    with an ``HTTPResponse``. Timeouts, a closed connection, and exceptions
    raised inside the ``cleanup`` stack context are all reported as
    synthetic 599 responses.
    """

    _SUPPORTED_METHODS = set(["GET", "HEAD", "POST", "PUT", "DELETE"])

    def __init__(self, io_loop, client, request, release_callback,
                 final_callback, max_buffer_size):
        self.start_time = time.time()
        self.io_loop = io_loop
        self.client = client
        self.request = request
        self.release_callback = release_callback
        self.final_callback = final_callback
        self.code = None
        self.headers = None
        self.chunks = None
        self._decompressor = None
        # Timeout handle returned by IOLoop.add_timeout
        self._timeout = None
        with stack_context.StackContext(self.cleanup):
            parsed = urllib.parse.urlsplit(_unicode(self.request.url))
            if ssl is None and parsed.scheme == "https":
                raise ValueError("HTTPS requires either python2.6+ or "
                                 "curl_httpclient")
            if parsed.scheme not in ("http", "https"):
                raise ValueError("Unsupported url scheme: %s" %
                                 self.request.url)
            # urlsplit results have hostname and port results, but they
            # didn't support ipv6 literals until python 2.7.
            netloc = parsed.netloc
            if "@" in netloc:
                # Drop "user:pass@"; _on_connect reads the credentials from
                # parsed.username / parsed.password.
                _, _, netloc = netloc.rpartition("@")
            match = re.match(r'^(.+):(\d+)$', netloc)
            if match:
                host = match.group(1)
                port = int(match.group(2))
            else:
                host = netloc
                port = 443 if parsed.scheme == "https" else 80
            if re.match(r'^\[.*\]$', host):
                # raw ipv6 addresses in urls are enclosed in brackets
                host = host[1:-1]
            if self.client.hostname_mapping is not None:
                host = self.client.hostname_mapping.get(host, host)

            if request.allow_ipv6:
                af = socket.AF_UNSPEC
            else:
                # We only try the first IP we get from getaddrinfo,
                # so restrict to ipv4 by default.
                af = socket.AF_INET

            addrinfo = socket.getaddrinfo(host, port, af, socket.SOCK_STREAM,
                                          0, 0)
            af, socktype, proto, canonname, sockaddr = addrinfo[0]

            if parsed.scheme == "https":
                ssl_options = {}
                if request.validate_cert:
                    ssl_options["cert_reqs"] = ssl.CERT_REQUIRED
                if request.ca_certs is not None:
                    ssl_options["ca_certs"] = request.ca_certs
                else:
                    ssl_options["ca_certs"] = _DEFAULT_CA_CERTS
                if request.client_key is not None:
                    ssl_options["keyfile"] = request.client_key
                if request.client_cert is not None:
                    ssl_options["certfile"] = request.client_cert
                self.stream = SSLIOStream(socket.socket(af, socktype, proto),
                                          io_loop=self.io_loop,
                                          ssl_options=ssl_options,
                                          max_buffer_size=max_buffer_size)
            else:
                self.stream = IOStream(socket.socket(af, socktype, proto),
                                       io_loop=self.io_loop,
                                       max_buffer_size=max_buffer_size)
            # The connect phase is bounded by the smaller of the two timeouts;
            # _on_connect replaces this with the request timeout.
            timeout = min(request.connect_timeout, request.request_timeout)
            if timeout:
                self._timeout = self.io_loop.add_timeout(
                    self.start_time + timeout,
                    self._on_timeout)
            self.stream.set_close_callback(self._on_close)
            self.stream.connect(sockaddr,
                                functools.partial(self._on_connect, parsed))

    def _on_timeout(self):
        """Report a 599 "Timeout" response and abort the connection."""
        self._timeout = None
        self._run_callback(HTTPResponse(self.request, 599,
                                        request_time=time.time() - self.start_time,
                                        error=HTTPError(599, "Timeout")))
        self.stream.close()

    def _on_connect(self, parsed):
        """Send the request line, headers and body, then wait for headers.

        :param parsed: the ``urlsplit`` result computed in ``__init__``.
        :raises KeyError: for an unsupported method (unless
            ``allow_nonstandard_methods`` is set).
        :raises NotImplementedError: when proxy/interface options are set.
        :raises ValueError: when a header line contains a newline.
        """
        if self._timeout is not None:
            self.io_loop.remove_timeout(self._timeout)
            self._timeout = None
        if self.request.request_timeout:
            self._timeout = self.io_loop.add_timeout(
                self.start_time + self.request.request_timeout,
                self._on_timeout)
        if (self.request.validate_cert and
                isinstance(self.stream, SSLIOStream)):
            match_hostname(self.stream.socket.getpeercert(),
                           parsed.hostname)
        if (self.request.method not in self._SUPPORTED_METHODS and
                not self.request.allow_nonstandard_methods):
            raise KeyError("unknown method %s" % self.request.method)
        for key in ('network_interface',
                    'proxy_host', 'proxy_port',
                    'proxy_username', 'proxy_password'):
            if getattr(self.request, key, None):
                raise NotImplementedError('%s not supported' % key)
        if "Host" not in self.request.headers:
            self.request.headers["Host"] = parsed.netloc
        username, password = None, None
        if parsed.username is not None:
            username, password = parsed.username, parsed.password
        elif self.request.auth_username is not None:
            username = self.request.auth_username
            password = self.request.auth_password
        if username is not None:
            # "http://user@host/" or auth_username without auth_password
            # leaves password as None, which cannot be joined to bytes;
            # send an empty password instead of failing.
            if password is None:
                password = ""
            auth = utf8(username) + b(":") + utf8(password)
            self.request.headers["Authorization"] = (b("Basic ") +
                                                     base64.b64encode(auth))
        if self.request.user_agent:
            self.request.headers["User-Agent"] = self.request.user_agent
        if not self.request.allow_nonstandard_methods:
            if self.request.method in ("POST", "PUT"):
                assert self.request.body is not None
            else:
                assert self.request.body is None
        if self.request.body is not None:
            self.request.headers["Content-Length"] = str(len(
                self.request.body))
        if (self.request.method == "POST" and
                "Content-Type" not in self.request.headers):
            self.request.headers["Content-Type"] = "application/x-www-form-urlencoded"
        if self.request.use_gzip:
            self.request.headers["Accept-Encoding"] = "gzip"
        req_path = ((parsed.path or '/') +
                    (('?' + parsed.query) if parsed.query else ''))
        request_lines = [utf8("%s %s HTTP/1.1" % (self.request.method,
                                                  req_path))]
        for k, v in self.request.headers.get_all():
            line = utf8(k) + b(": ") + utf8(v)
            # Reject header injection via embedded newlines.
            if b('\n') in line:
                raise ValueError('Newline in header: ' + repr(line))
            request_lines.append(line)
        self.stream.write(b("\r\n").join(request_lines) + b("\r\n\r\n"))
        if self.request.body is not None:
            self.stream.write(self.request.body)
        self.stream.read_until_regex(b("\r?\n\r?\n"), self._on_headers)

    def _release(self):
        """Invoke ``release_callback`` once, then forget it."""
        if self.release_callback is not None:
            release_callback = self.release_callback
            self.release_callback = None
            release_callback()

    def _run_callback(self, response):
        """Release the slot and deliver ``response`` to ``final_callback`` once."""
        self._release()
        if self.final_callback is not None:
            final_callback = self.final_callback
            self.final_callback = None
            final_callback(response)

    @contextlib.contextmanager
    def cleanup(self):
        """Turn any exception escaping a callback into a 599 response."""
        try:
            yield
        except Exception as e:
            logging.warning("uncaught exception", exc_info=True)
            self._run_callback(HTTPResponse(self.request, 599, error=e,
                                request_time=time.time() - self.start_time,
                                ))

    def _on_close(self):
        """Stream close callback: report a 599 unless already answered."""
        self._run_callback(HTTPResponse(
                self.request, 599,
                request_time=time.time() - self.start_time,
                error=HTTPError(599, "Connection closed")))

    def _on_headers(self, data):
        """Parse the status line and headers, then start reading the body."""
        data = native_str(data.decode("latin1"))
        first_line, _, header_data = data.partition("\n")
        match = re.match(r"HTTP/1\.[01] ([0-9]+)", first_line)
        assert match
        self.code = int(match.group(1))
        self.headers = HTTPHeaders.parse(header_data)
        if self.request.header_callback is not None:
            for k, v in self.headers.get_all():
                self.request.header_callback("%s: %s\r\n" % (k, v))
        if "Content-Length" in self.headers:
            if "," in self.headers["Content-Length"]:
                # Proxies sometimes cause Content-Length headers to get
                # duplicated.  If all the values are identical then we can
                # use them but if they differ it's an error.
                pieces = re.split(r',\s*', self.headers["Content-Length"])
                if any(i != pieces[0] for i in pieces):
                    raise ValueError("Multiple unequal Content-Lengths: %r" %
                                     self.headers["Content-Length"])
                self.headers["Content-Length"] = pieces[0]
            content_length = int(self.headers["Content-Length"])
        else:
            content_length = None

        if self.request.method == "HEAD":
            # HEAD responses never carry a body, even when they advertise
            # a Content-Length; reading one would hang until the timeout.
            self._on_body(b(""))
            return
        if 100 <= self.code < 200 or self.code in (204, 304):
            # These status codes never have a body (RFC 2616 section 4.3).
            self._on_body(b(""))
            return

        if (self.request.use_gzip and
                self.headers.get("Content-Encoding") == "gzip"):
            # Magic parameter makes zlib module understand gzip header
            # http://stackoverflow.com/questions/1838699/how-can-i-decompress-a-gzip-stream-with-zlib
            self._decompressor = zlib.decompressobj(16+zlib.MAX_WBITS)
        if self.headers.get("Transfer-Encoding") == "chunked":
            self.chunks = []
            self.stream.read_until(b("\r\n"), self._on_chunk_length)
        elif content_length is not None:
            self.stream.read_bytes(content_length, self._on_body)
        else:
            self.stream.read_until_close(self._on_body)

    def _on_body(self, data):
        """Finish the exchange: follow a redirect or deliver the response."""
        if self._timeout is not None:
            self.io_loop.remove_timeout(self._timeout)
            self._timeout = None
        original_request = getattr(self.request, "original_request",
                                   self.request)
        # Decide on redirects before touching the body so a redirect's body
        # is neither decompressed nor passed to streaming_callback. A 3xx
        # without Location cannot be followed; return it as-is.
        if (self.request.follow_redirects and
                self.request.max_redirects > 0 and
                self.code in (301, 302) and
                "Location" in self.headers):
            new_request = copy.copy(self.request)
            new_request.url = urllib.parse.urljoin(self.request.url,
                                                   self.headers["Location"])
            new_request.max_redirects -= 1
            del new_request.headers["Host"]
            new_request.original_request = original_request
            final_callback = self.final_callback
            self.final_callback = None
            self._release()
            self.client.fetch(new_request, final_callback)
            self.stream.close()
            return
        if self._decompressor:
            data = self._decompressor.decompress(data)
        if self.request.streaming_callback:
            if self.chunks is None:
                # if chunks is not None, we already called streaming_callback
                # in _on_chunk_data
                self.request.streaming_callback(data)
            buffer = BytesIO()
        else:
            buffer = BytesIO(data) # TODO: don't require one big string?
        response = HTTPResponse(original_request,
                                self.code, headers=self.headers,
                                request_time=time.time() - self.start_time,
                                buffer=buffer,
                                effective_url=self.request.url)
        self._run_callback(response)
        self.stream.close()

    def _on_chunk_length(self, data):
        """Handle a chunk-size line; size 0 terminates the body."""
        # TODO: "chunk extensions" http://tools.ietf.org/html/rfc2616#section-3.6.1
        length = int(data.strip(), 16)
        if length == 0:
            # all the data has been decompressed, so we don't need to
            # decompress again in _on_body
            self._decompressor = None
            self._on_body(b('').join(self.chunks))
        else:
            self.stream.read_bytes(length + 2,  # chunk ends with \r\n
                                   self._on_chunk_data)

    def _on_chunk_data(self, data):
        """Handle one chunk (payload + CRLF), then read the next size line."""
        assert data[-2:] == b("\r\n")
        chunk = data[:-2]
        if self._decompressor:
            chunk = self._decompressor.decompress(chunk)
        if self.request.streaming_callback is not None:
            self.request.streaming_callback(chunk)
        else:
            self.chunks.append(chunk)
        self.stream.read_until(b("\r\n"), self._on_chunk_length)
class RedisPubSub(PubSubBase):
    """Redis publish/subscribe client built on a Tornado ``IOStream``.

    Subscriptions run over a dedicated connection to ``host:port``. Replies
    are parsed incrementally with ``hiredis.Reader`` and dispatched to
    ``subscribed`` / ``unsubscribed`` / ``on_message``, which are presumably
    provided by ``PubSubBase`` (not visible here). Publishing goes through a
    separate blocking ``redis.StrictRedis`` client.
    """

    def __init__(self, host='127.0.0.1', port=6379, *args, **kwargs):
        self.host = host
        self.port = port
        super(RedisPubSub, self).__init__(*args, **kwargs)

    @staticmethod
    def get_redis(host='127.0.0.1', port=6379, db=0):
        """Return a blocking ``redis.StrictRedis`` client.

        The defaults are the previously hard-coded values. This is a
        staticmethod, so it does NOT use any instance's ``host``/``port``.
        """
        return redis.StrictRedis(
            host=host,
            port=port,
            db=db
        )

    ##
    ## pubsub api
    ##

    def connect(self):
        """Open the subscriber connection; ``on_connect`` runs once it is up."""
        self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.stream = IOStream(self.socket)
        self.stream.connect((self.host, self.port), self.on_connect)

    def disconnect(self):
        """Unsubscribe from every channel and close the connection.

        Does nothing when there is no stream (never connected, or
        ``on_close`` already cleared it).
        """
        if getattr(self, 'stream', None) is None:
            return
        self.unsubscribe()
        self.stream.close()

    def subscribe(self, channel_id):
        """Send SUBSCRIBE for ``channel_id``."""
        self.send('SUBSCRIBE', channel_id)

    def unsubscribe(self, channel_id=None):
        """Send UNSUBSCRIBE for ``channel_id``, or for all channels if falsy."""
        if channel_id:
            self.send('UNSUBSCRIBE', channel_id)
        else:
            self.send('UNSUBSCRIBE')

    @staticmethod
    def publish(channel_id, message):
        """Publish ``message`` on ``channel_id`` using a blocking client."""
        r = RedisPubSub.get_redis()
        r.publish(channel_id, message)

    ##
    ## socket/stream callbacks
    ##

    def on_connect(self):
        """Install callbacks and start consuming the reply stream."""
        # Build the parser before reading starts so streamed data always has
        # a reader to go to.
        self.reader = hiredis.Reader()
        self.stream.set_close_callback(self.on_close)
        self.stream.read_until_close(self.on_data, self.on_streaming_data)
        self.connected()

    def on_data(self, *args, **kwargs):
        """Final read_until_close callback; nothing to do (data is streamed)."""
        pass

    def on_streaming_data(self, data):
        """Feed raw bytes to the parser and dispatch every complete reply.

        :raises Exception: for any reply that is not a recognised
            subscribe/unsubscribe/message push.
        """
        self.reader.feed(data)
        reply = self.reader.gets()
        # gets() returns False when no complete reply is buffered. Compare to
        # that sentinel explicitly so a falsy reply does not silently stop
        # the loop and leave later replies stuck in the parser.
        while reply is not False:
            kind = reply[0] if isinstance(reply, list) and reply else None
            if kind == 'subscribe':
                self.subscribed(reply[1])
            elif kind == 'unsubscribe':
                self.unsubscribed(reply[1])
            elif kind == 'message':
                self.on_message(reply[1], reply[2])
            else:
                raise Exception('Unhandled data from redis %s' % reply)
            reply = self.reader.gets()

    def on_close(self):
        """Stream close callback: drop the socket/stream and notify."""
        self.socket = None
        self.stream = None
        self.disconnected()

    ##
    ## redis protocol parser (derived from redis-py)
    ##

    def encode(self, value):
        """Convert ``value`` to a byte string for the wire (Python 2 types)."""
        if isinstance(value, bytes):
            return value
        if isinstance(value, float):
            value = repr(value)
        if not isinstance(value, basestring):
            value = str(value)
        if isinstance(value, unicode):
            value = value.encode('utf-8', 'strict')
        return value

    def pack_command(self, *args):
        """Serialize ``args`` as a RESP multi-bulk command."""
        cmd = io.BytesIO()
        cmd.write('*')
        cmd.write(str(len(args)))
        cmd.write('\r\n')
        for arg in args:
            arg = self.encode(arg)
            cmd.write('$')
            cmd.write(str(len(arg)))
            cmd.write('\r\n')
            cmd.write(arg)
            cmd.write('\r\n')
        return cmd.getvalue()

    def send(self, *args):
        """Send redis command."""
        cmd = self.pack_command(*args)
        self.stream.write(cmd)
class SoxDecoder:
    """Decodes a stream of encoded data to Sox via stdin and receives the
    output on stdout.

    Event hooks (optional; assign callables):

    * ``on_data_ready(data=...)`` -- bytes read from sox's stdout.
    * ``on_unhandled_error(error=...)`` -- bytes sox wrote to stderr.
    * ``on_close()`` -- sox's stdout closed while the decoder was running.
    * ``on_wav_params_ready()`` -- fired from ``start``; the output
      parameters come from the constructor arguments.
    """

    def __init__(self, codec, out_channels, out_samplerate, out_samplesize):
        self._started = False
        # Created by start(), released by stop().
        self._process = None
        self._input = None

        # codec and WAV params
        self._codec = codec
        self._channels = out_channels
        self._sample_rate = out_samplerate
        self._sample_size = out_samplesize

        # events
        self.on_close = None
        self.on_data_ready = None
        self.on_unhandled_error = None
        self.on_wav_params_ready = None

    @property
    def codec(self):
        return self._codec

    @property
    def channels(self):
        return self._channels

    @property
    def channel_mode(self):
        return {1: "Mono", 2: "Stereo"}.get(self._channels, "Unknown")

    @property
    def sample_rate(self):
        return self._sample_rate

    @property
    def sample_size(self):
        return self._sample_size

    def start(self, socket_or_fd, read_mtu):
        """Starts the decoder. If already started, this does nothing.

        :param socket_or_fd: a socket object, or an int file descriptor that
            is wrapped in a new ``socket.socket``; encoded input is read
            from it.
        :param read_mtu: not used by this method.
        """
        if self._started:
            return

        # process
        self._process = Subprocess([
            "sox",
            "-t", self._codec,
            "-",
            "--bits", str(self._sample_size),
            "--channels", str(self._channels),
            "--rate", str(self._sample_rate),
            "-t", "wav",
            "-"
        ], stdin=Subprocess.STREAM, stdout=Subprocess.STREAM,
           stderr=Subprocess.STREAM)
        self._process.stdout.set_close_callback(self._on_close)
        self._process.stdout.read_until_close(
            streaming_callback=self._out_data_ready)
        self._process.stderr.read_until_close(
            streaming_callback=self._sox_error)

        # did we get socket or fd?
        sock = socket_or_fd
        if isinstance(socket_or_fd, int):
            logger.debug("SoxDecoder received fd, building socket...")
            sock = socket.socket(fileno=socket_or_fd)
            sock.setblocking(True)

        # input pump: forward each chunk to sox as it arrives. Passing the
        # handler as the final `callback` would buffer the whole input until
        # the socket closes and only write it to sox then.
        self._input = IOStream(socket=sock)
        self._input.read_until_close(streaming_callback=self._in_data_ready)

        # start
        self._started = True

        # we know WAV params already
        if self.on_wav_params_ready:
            self.on_wav_params_ready()

    def stop(self):
        """Stops the decoder. If already stopped, this does nothing.

        Closes the input stream first, so nothing more is forwarded to the
        (about to be killed) sox process, then kills sox.
        """
        if not self._started:
            return
        self._started = False
        if self._input is not None:
            self._input.close()
            self._input = None
        self._process.proc.kill()
        self._process = None

    def _on_close(self, *args):
        """Called when the Sox process exits.
        """
        if not self._started:
            return
        self.stop()
        if self.on_close:
            self.on_close()

    def _in_data_ready(self, data):
        """Writes encoded data to the Sox input stream.
        """
        if not self._started:
            raise InvalidOperationError("Not started.")
        self._process.stdin.write(data)

    def _out_data_ready(self, data):
        """Called when decoded data is ready.
        """
        if self.on_data_ready:
            self.on_data_ready(data=data)

    def _sox_error(self, data):
        """Called when Sox writes to stderr. This isn't necessarily fatal, so
        we don't close the process.
        """
        if self.on_unhandled_error:
            self.on_unhandled_error(error=data)
class Connection(RedisCommandsMixin):
    """A single asynchronous Redis connection driven by an ``IOStream``.

    Replies are parsed with ``hiredis.Reader`` and matched, in order, with the
    ``callbacks`` queue that ``send_message`` fills. While the connection is
    in ``redis._shared`` it may be used by several callers. Commands that
    change per-connection state (WATCH, MULTI) are therefore rejected until
    ``lock()`` has taken the connection out of that list.
    """

    def __init__(self, redis, on_connect=None):
        logger.debug('Creating new Redis connection.')
        self.redis = redis
        self.reader = hiredis.Reader()
        self._watch = set()
        self._multi = False
        self.callbacks = deque()
        self._on_connect_callback = on_connect

        self.stream = IOStream(socket.socket(redis._family,
                                             socket.SOCK_STREAM, 0),
                               io_loop=redis._ioloop)
        self.stream.set_close_callback(self._on_close)
        self.stream.connect(redis._addr, self._on_connect)

    def _on_connect(self):
        """Start reading, join the shared list, fire the one-shot callback."""
        logger.debug('Connected!')
        # _on_read receives every chunk; _on_close gets any final data.
        self.stream.read_until_close(self._on_close, self._on_read)
        self.redis._shared.append(self)
        if self._on_connect_callback is not None:
            self._on_connect_callback(self)
            self._on_connect_callback = None

    def _on_read(self, data):
        """Parse replies and resolve the oldest pending callback for each."""
        self.reader.feed(data)
        while True:
            resp = self.reader.gets()
            if resp is False:
                break
            callback = self.callbacks.popleft()
            if callback is not None:
                self.redis._ioloop.add_callback(partial(callback, resp))

    def is_idle(self):
        """True when no replies are outstanding."""
        return len(self.callbacks) == 0

    def is_shared(self):
        """True while the connection is in ``redis._shared``."""
        return self in self.redis._shared

    def lock(self):
        """Remove the connection from ``redis._shared`` for exclusive use.

        :raises Exception: if it is not currently shared.
        """
        if not self.is_shared():
            raise Exception('Connection already is locked!')
        self.redis._shared.remove(self)

    def unlock(self, callback=None):
        """Reset transaction state, reselect the DB and re-share the connection.

        ``callback`` is accepted but not used.
        """
        def cb(future):
            # send_message registers this through Future.add_done_callback,
            # which passes the Future itself, not the reply.
            resp = future.result()
            # The status reply may be decoded as str or left as bytes.
            assert resp in ('OK', b'OK')
            self.redis._shared.append(self)

        if self._multi:
            self.send_message(['DISCARD'])
        elif self._watch:
            self.send_message(['UNWATCH'])

        self.send_message(['SELECT', self.redis._database], cb)

    def send_message(self, args, callback=None):
        """Send one command and return a Future for its reply.

        :param args: command name followed by its arguments.
        :param callback: optional; attached with ``add_done_callback``, so it
            is called with the Future.
        :raises NotImplementedError: for (P)SUBSCRIBE-style commands.
        :raises Exception: for WATCH/MULTI on a shared connection.
        """
        command = args[0]

        if 'SUBSCRIBE' in command:
            raise NotImplementedError('Not yet.')

        # Do not allow the commands, affecting the execution of other commands,
        # to be used on shared connection.
        if command in ('WATCH', 'MULTI'):
            if self.is_shared():
                raise Exception('Command %s is not allowed while connection '
                                'is shared!' % command)
            if command == 'WATCH':
                self._watch.add(args[1])
            if command == 'MULTI':
                self._multi = True

        # monitor transaction state, to unlock correctly
        if command in ('EXEC', 'DISCARD', 'UNWATCH'):
            if command in ('EXEC', 'DISCARD'):
                self._multi = False
            self._watch.clear()

        self.stream.write(self.format_message(args))

        future = Future()

        if callback is not None:
            future.add_done_callback(stack_context.wrap(callback))

        self.callbacks.append(future.set_result)

        return future

    def format_message(self, args):
        """Encode ``args`` as a RESP multi-bulk request.

        ``bytes`` arguments are sent verbatim; anything else goes through
        ``str()`` and UTF-8. Previously ``bytes`` were passed through
        ``str()`` too, which sends their repr (``"b'...'"``).
        """
        lines = [("*%d" % len(args)).encode('utf-8')]
        for arg in args:
            if not isinstance(arg, bytes):
                arg = str(arg).encode('utf-8')
            lines.append(("$%d" % len(arg)).encode('utf-8'))
            lines.append(arg)
        lines.append(b"")
        return b"\r\n".join(lines)

    def close(self):
        """Send QUIT and take the connection out of the shared list."""
        # send_message is the sending primitive defined on this class.
        self.send_message(['QUIT'])
        if self.is_shared():
            self.lock()

    def _on_close(self, data=None):
        """Close/final-read callback: flush trailing data, stop sharing."""
        logger.debug('Redis connection was closed.')
        if data is not None:
            self._on_read(data)
        if self.is_shared():
            self.lock()
def handle_connection(connection, address):
    """Print everything a client sends before it closes the connection.

    Generator-style coroutine: it yields the future returned by
    ``read_until_close``, so it must be driven by a coroutine runner
    (presumably wrapped by the caller). ``address`` is not used.
    """
    client_stream = IOStream(connection)
    payload = yield client_stream.read_until_close()
    text = payload.decode().strip()
    print("message from client:", text)
class Server(Session):
    """Listening session that serves one client connection at a time.

    While a client is connected (``SessionStream.CONNECTED``), further
    incoming connections are accepted and closed immediately.
    """

    def __init__(self, protocol, io_loop=None, sock=None):
        Session.__init__(self, protocol, io_loop)
        self.is_server = True
        # An already-connected socket can be adopted directly.
        if sock is not None:
            self._setup(sock)

    def bind(self, address):
        """Create a non-blocking, address-reusing TCP socket bound to ``address``."""
        self.address = address
        self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.sock.setblocking(0)
        self.sock.bind(address)

    def start(self):
        """Listen and register ``_accept`` for read events on the socket."""
        self.sock.listen(2)
        self.io_loop.add_handler(self.sock.fileno(),
                                 self._accept,
                                 self.io_loop.READ)
        self.status = SessionStream.LISTENING

    def listen(self, address):
        """``bind`` followed by ``start``."""
        self.bind(address)
        self.start()

    def _accept(self, fd, event):
        """IOLoop handler: accept a client, or reject it if one is active."""
        if fd != self.sock.fileno():
            print("panic: fd != sock.fileno() ....")
            return
        conn, address = self.sock.accept()
        if self.status == SessionStream.LISTENING:
            self._setup(conn)
        elif self.status == SessionStream.CONNECTED:
            logging.warning("already connected ...")
            conn.close()

    def _setup(self, conn):
        """Wrap ``conn`` in an IOStream and start streaming into the protocol."""
        self.ios = IOStream(conn, self.io_loop)
        self.status = SessionStream.CONNECTED
        self.protocol.connected(self)
        self.ios.set_close_callback(self._ios_closed)
        self.ios.read_until_close(self._disconnected, self._receiver)

    def close(self):
        """Close the client stream and the listening socket, then go IDLE."""
        logging.debug("Server.close() called...")
        if self.ios is not None:
            self.ios.close()
        if self.sock is not None:
            # Unregister only when there is a listening socket; the old code
            # called self.sock.fileno() before checking for None.
            self.io_loop.remove_handler(self.sock.fileno())
            self.sock.close()
        self.status = SessionStream.IDLE

    def _disconnected(self, data):
        """read_until_close callback: the client went away; keep listening."""
        logging.info("client disconnected ... ")
        self.protocol.disconnected(self)
        self.clear_buffer()
        self.clear_timers()
        self.ios.close()
        self.status = SessionStream.LISTENING
        logging.debug("end of client disconnect ...")

    def _ios_closed(self):
        """Stream close callback: ready for the next client."""
        self.status = SessionStream.LISTENING
class _HTTPConnection(object):
    """Perform one HTTP request over a freshly opened connection.

    The constructor resolves the host and starts connecting. ``_on_connect``
    writes the request, and the response is parsed by ``_on_headers``,
    ``_on_body`` and the ``_on_chunk_*`` callbacks.

    ``final_callback`` is invoked at most once (``_run_callback`` clears it)
    with an ``HTTPResponse``. Timeouts, a closed connection, and exceptions
    raised inside the ``cleanup`` stack context are all reported as
    synthetic 599 responses.
    """

    _SUPPORTED_METHODS = set(["GET", "HEAD", "POST", "PUT", "DELETE"])

    def __init__(self, io_loop, client, request, release_callback, final_callback, max_buffer_size):
        self.start_time = time.time()
        self.io_loop = io_loop
        self.client = client
        self.request = request
        self.release_callback = release_callback
        self.final_callback = final_callback
        self.code = None
        self.headers = None
        self.chunks = None
        self._decompressor = None
        # Timeout handle returned by IOLoop.add_timeout
        self._timeout = None
        with stack_context.StackContext(self.cleanup):
            parsed = urllib.parse.urlsplit(_unicode(self.request.url))
            if ssl is None and parsed.scheme == "https":
                raise ValueError("HTTPS requires either python2.6+ or " "curl_httpclient")
            if parsed.scheme not in ("http", "https"):
                raise ValueError("Unsupported url scheme: %s" % self.request.url)
            # urlsplit results have hostname and port results, but they
            # didn't support ipv6 literals until python 2.7.
            netloc = parsed.netloc
            if "@" in netloc:
                # Drop "user:pass@"; _on_connect reads the credentials from
                # parsed.username / parsed.password.
                _, _, netloc = netloc.rpartition("@")
            match = re.match(r"^(.+):(\d+)$", netloc)
            if match:
                host = match.group(1)
                port = int(match.group(2))
            else:
                host = netloc
                port = 443 if parsed.scheme == "https" else 80
            if re.match(r"^\[.*\]$", host):
                # raw ipv6 addresses in urls are enclosed in brackets
                host = host[1:-1]
            if self.client.hostname_mapping is not None:
                host = self.client.hostname_mapping.get(host, host)

            if request.allow_ipv6:
                af = socket.AF_UNSPEC
            else:
                # We only try the first IP we get from getaddrinfo,
                # so restrict to ipv4 by default.
                af = socket.AF_INET

            addrinfo = socket.getaddrinfo(host, port, af, socket.SOCK_STREAM, 0, 0)
            af, socktype, proto, canonname, sockaddr = addrinfo[0]

            if parsed.scheme == "https":
                ssl_options = {}
                if request.validate_cert:
                    ssl_options["cert_reqs"] = ssl.CERT_REQUIRED
                if request.ca_certs is not None:
                    ssl_options["ca_certs"] = request.ca_certs
                else:
                    ssl_options["ca_certs"] = _DEFAULT_CA_CERTS
                if request.client_key is not None:
                    ssl_options["keyfile"] = request.client_key
                if request.client_cert is not None:
                    ssl_options["certfile"] = request.client_cert

                # SSL interoperability is tricky.  We want to disable
                # SSLv2 for security reasons; it wasn't disabled by default
                # until openssl 1.0.  The best way to do this is to use
                # the SSL_OP_NO_SSLv2, but that wasn't exposed to python
                # until 3.2.  Python 2.7 adds the ciphers argument, which
                # can also be used to disable SSLv2.  As a last resort
                # on python 2.6, we set ssl_version to SSLv3.  This is
                # more narrow than we'd like since it also breaks
                # compatibility with servers configured for TLSv1 only,
                # but nearly all servers support SSLv3:
                # http://blog.ivanristic.com/2011/09/ssl-survey-protocol-support.html
                if sys.version_info >= (2, 7):
                    ssl_options["ciphers"] = "DEFAULT:!SSLv2"
                else:
                    # This is really only necessary for pre-1.0 versions
                    # of openssl, but python 2.6 doesn't expose version
                    # information.
                    ssl_options["ssl_version"] = ssl.PROTOCOL_SSLv3

                self.stream = SSLIOStream(
                    socket.socket(af, socktype, proto),
                    io_loop=self.io_loop,
                    ssl_options=ssl_options,
                    max_buffer_size=max_buffer_size,
                )
            else:
                self.stream = IOStream(
                    socket.socket(af, socktype, proto), io_loop=self.io_loop, max_buffer_size=max_buffer_size
                )
            # The connect phase is bounded by the smaller of the two timeouts;
            # _on_connect replaces this with the request timeout.
            timeout = min(request.connect_timeout, request.request_timeout)
            if timeout:
                self._timeout = self.io_loop.add_timeout(self.start_time + timeout, self._on_timeout)
            self.stream.set_close_callback(self._on_close)
            self.stream.connect(sockaddr, functools.partial(self._on_connect, parsed))

    def _on_timeout(self):
        """Report a 599 "Timeout" response and abort the connection."""
        self._timeout = None
        self._run_callback(
            HTTPResponse(self.request, 599, request_time=time.time() - self.start_time, error=HTTPError(599, "Timeout"))
        )
        self.stream.close()

    def _on_connect(self, parsed):
        """Send the request line, headers and body, then wait for headers.

        :param parsed: the ``urlsplit`` result computed in ``__init__``.
        :raises KeyError: for an unsupported method (unless
            ``allow_nonstandard_methods`` is set).
        :raises NotImplementedError: when proxy/interface options are set.
        :raises ValueError: when a header line contains a newline.
        """
        if self._timeout is not None:
            self.io_loop.remove_timeout(self._timeout)
            self._timeout = None
        if self.request.request_timeout:
            self._timeout = self.io_loop.add_timeout(self.start_time + self.request.request_timeout, self._on_timeout)
        if self.request.validate_cert and isinstance(self.stream, SSLIOStream):
            match_hostname(self.stream.socket.getpeercert(), parsed.hostname)
        if self.request.method not in self._SUPPORTED_METHODS and not self.request.allow_nonstandard_methods:
            raise KeyError("unknown method %s" % self.request.method)
        for key in ("network_interface", "proxy_host", "proxy_port", "proxy_username", "proxy_password"):
            if getattr(self.request, key, None):
                raise NotImplementedError("%s not supported" % key)
        if "Host" not in self.request.headers:
            self.request.headers["Host"] = parsed.netloc
        username, password = None, None
        if parsed.username is not None:
            username, password = parsed.username, parsed.password
        elif self.request.auth_username is not None:
            username = self.request.auth_username
            password = self.request.auth_password or ""
        if username is not None:
            # A URL like "http://user@host/" gives parsed.password == None,
            # which cannot be joined to bytes; send an empty password.
            if password is None:
                password = ""
            auth = utf8(username) + b(":") + utf8(password)
            self.request.headers["Authorization"] = b("Basic ") + base64.b64encode(auth)
        if self.request.user_agent:
            self.request.headers["User-Agent"] = self.request.user_agent
        if not self.request.allow_nonstandard_methods:
            if self.request.method in ("POST", "PUT"):
                assert self.request.body is not None
            else:
                assert self.request.body is None
        if self.request.body is not None:
            self.request.headers["Content-Length"] = str(len(self.request.body))
        if self.request.method == "POST" and "Content-Type" not in self.request.headers:
            self.request.headers["Content-Type"] = "application/x-www-form-urlencoded"
        if self.request.use_gzip:
            self.request.headers["Accept-Encoding"] = "gzip"
        req_path = (parsed.path or "/") + (("?" + parsed.query) if parsed.query else "")
        request_lines = [utf8("%s %s HTTP/1.1" % (self.request.method, req_path))]
        for k, v in self.request.headers.get_all():
            line = utf8(k) + b(": ") + utf8(v)
            # Reject header injection via embedded newlines.
            if b("\n") in line:
                raise ValueError("Newline in header: " + repr(line))
            request_lines.append(line)
        self.stream.write(b("\r\n").join(request_lines) + b("\r\n\r\n"))
        if self.request.body is not None:
            self.stream.write(self.request.body)
        self.stream.read_until_regex(b("\r?\n\r?\n"), self._on_headers)

    def _release(self):
        """Invoke ``release_callback`` once, then forget it."""
        if self.release_callback is not None:
            release_callback = self.release_callback
            self.release_callback = None
            release_callback()

    def _run_callback(self, response):
        """Release the slot and deliver ``response`` to ``final_callback`` once."""
        self._release()
        if self.final_callback is not None:
            final_callback = self.final_callback
            self.final_callback = None
            final_callback(response)

    @contextlib.contextmanager
    def cleanup(self):
        """Turn any exception escaping a callback into a 599 response."""
        try:
            yield
        except Exception as e:
            logging.warning("uncaught exception", exc_info=True)
            self._run_callback(HTTPResponse(self.request, 599, error=e, request_time=time.time() - self.start_time))

    def _on_close(self):
        """Stream close callback: report a 599 unless already answered."""
        self._run_callback(
            HTTPResponse(
                self.request, 599, request_time=time.time() - self.start_time, error=HTTPError(599, "Connection closed")
            )
        )

    def _on_headers(self, data):
        """Parse the status line and headers, then start reading the body."""
        data = native_str(data.decode("latin1"))
        first_line, _, header_data = data.partition("\n")
        match = re.match(r"HTTP/1\.[01] ([0-9]+)", first_line)
        assert match
        self.code = int(match.group(1))
        self.headers = HTTPHeaders.parse(header_data)

        if "Content-Length" in self.headers:
            if "," in self.headers["Content-Length"]:
                # Proxies sometimes cause Content-Length headers to get
                # duplicated.  If all the values are identical then we can
                # use them but if they differ it's an error.
                pieces = re.split(r",\s*", self.headers["Content-Length"])
                if any(i != pieces[0] for i in pieces):
                    raise ValueError("Multiple unequal Content-Lengths: %r" % self.headers["Content-Length"])
                self.headers["Content-Length"] = pieces[0]
            content_length = int(self.headers["Content-Length"])
        else:
            content_length = None

        if self.request.header_callback is not None:
            for k, v in self.headers.get_all():
                self.request.header_callback("%s: %s\r\n" % (k, v))

        if self.request.method == "HEAD":
            # HEAD requests never have content, even though they may have
            # content-length headers
            self._on_body(b(""))
            return
        if 100 <= self.code < 200 or self.code in (204, 304):
            # These response codes never have bodies
            # http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.3
            assert "Transfer-Encoding" not in self.headers
            assert content_length in (None, 0)
            self._on_body(b(""))
            return

        if self.request.use_gzip and self.headers.get("Content-Encoding") == "gzip":
            # Magic parameter makes zlib module understand gzip header
            # http://stackoverflow.com/questions/1838699/how-can-i-decompress-a-gzip-stream-with-zlib
            self._decompressor = zlib.decompressobj(16 + zlib.MAX_WBITS)
        if self.headers.get("Transfer-Encoding") == "chunked":
            self.chunks = []
            self.stream.read_until(b("\r\n"), self._on_chunk_length)
        elif content_length is not None:
            self.stream.read_bytes(content_length, self._on_body)
        else:
            self.stream.read_until_close(self._on_body)

    def _on_body(self, data):
        """Finish the exchange: follow a redirect or deliver the response."""
        if self._timeout is not None:
            self.io_loop.remove_timeout(self._timeout)
            self._timeout = None
        original_request = getattr(self.request, "original_request", self.request)
        # A 3xx without a Location header cannot be followed; fall through
        # and return it as a normal response instead of raising KeyError.
        if (
            self.request.follow_redirects
            and self.request.max_redirects > 0
            and self.code in (301, 302, 303, 307)
            and "Location" in self.headers
        ):
            new_request = copy.copy(self.request)
            new_request.url = urllib.parse.urljoin(self.request.url, self.headers["Location"])
            new_request.max_redirects -= 1
            del new_request.headers["Host"]
            # http://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.3.4
            # client SHOULD make a GET request
            if self.code == 303:
                new_request.method = "GET"
                new_request.body = None
                for h in ["Content-Length", "Content-Type", "Content-Encoding", "Transfer-Encoding"]:
                    try:
                        # copy.copy is shallow: new_request.headers is the
                        # same object as self.request.headers.
                        del new_request.headers[h]
                    except KeyError:
                        pass
            new_request.original_request = original_request
            final_callback = self.final_callback
            self.final_callback = None
            self._release()
            self.client.fetch(new_request, final_callback)
            self.stream.close()
            return
        if self._decompressor:
            data = self._decompressor.decompress(data)
        if self.request.streaming_callback:
            if self.chunks is None:
                # if chunks is not None, we already called streaming_callback
                # in _on_chunk_data
                self.request.streaming_callback(data)
            buffer = BytesIO()
        else:
            buffer = BytesIO(data)  # TODO: don't require one big string?
        response = HTTPResponse(
            original_request,
            self.code,
            headers=self.headers,
            request_time=time.time() - self.start_time,
            buffer=buffer,
            effective_url=self.request.url,
        )
        self._run_callback(response)
        self.stream.close()

    def _on_chunk_length(self, data):
        """Handle a chunk-size line; size 0 terminates the body."""
        # TODO: "chunk extensions" http://tools.ietf.org/html/rfc2616#section-3.6.1
        length = int(data.strip(), 16)
        if length == 0:
            # all the data has been decompressed, so we don't need to
            # decompress again in _on_body
            self._decompressor = None
            self._on_body(b("").join(self.chunks))
        else:
            self.stream.read_bytes(length + 2, self._on_chunk_data)  # chunk ends with \r\n

    def _on_chunk_data(self, data):
        """Handle one chunk (payload + CRLF), then read the next size line."""
        assert data[-2:] == b("\r\n")
        chunk = data[:-2]
        if self._decompressor:
            chunk = self._decompressor.decompress(chunk)
        if self.request.streaming_callback is not None:
            self.request.streaming_callback(chunk)
        else:
            self.chunks.append(chunk)
        self.stream.read_until(b("\r\n"), self._on_chunk_length)
class IRCClient(object):
    """Minimal IRC client: registers, joins channels and dispatches lines.

    Each complete CRLF-terminated line is parsed with
    ``Message.from_message``. It is passed as ``callback(message, stream)``
    to every callback registered for its command.
    """

    def __init__(self, uri, ioloop):
        self.ioloop = ioloop
        self.callbacks = {
            "PING": [pong_callback],
            "NOTICE": [debug_callback],
            "ERROR": [die_callback]
        }
        self.conn = IRCConnection.from_uri(uri)
        self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        self.stream = IOStream(self.socket, io_loop=ioloop)
        # Bytes received so far that do not yet form a complete line.
        self.current_chunk = ""
        # read_timeout() inspects this. The class never assigns it itself,
        # so start with None instead of leaving the attribute undefined.
        self.message_future = None

    def add_message_callback(self, command, func):
        """Register ``func`` for messages whose command is ``command``."""
        self.callbacks.setdefault(command, []).append(func)

    def read_timeout(self):
        """Fail a pending ``message_future`` with ``ReadTimeout``."""
        if self.message_future and not self.message_future.done():
            self.message_future.set_exception(ReadTimeout("TIMEOUT"))

    def stream_bytes(self, chunk):
        """Streaming callback: split input into lines and dispatch them.

        :raises StreamClosedError: re-raised from a callback; other callback
            exceptions are logged with their traceback and skipped.
        """
        self.current_chunk += chunk
        while "\r\n" in self.current_chunk:
            original_message, self.current_chunk = self.current_chunk.split(
                "\r\n", 1)
            message = Message.from_message(original_message)
            if message.ident[1:].startswith(self.conn.username):
                logging.debug(
                    "Skipping message from self: {0}".format(message.message))
                continue
            if message.command not in self.callbacks:
                logging.info("SKIPPING - {0}".format(original_message))
                continue
            for callback in self.callbacks[message.command]:
                try:
                    callback(message, self.stream)
                except StreamClosedError:
                    logging.error("Stream was closed.")
                    raise
                except Exception as exc:
                    # logging.exception keeps the traceback for debugging.
                    logging.exception("Exception {0} in callback.".format(exc))

    @gen.coroutine
    def listen(self):
        """Connect, register, join channels and read until the server closes.

        Resolves to ``False`` if the stream was closed with an error, else
        ``True``.
        """
        logging.info(
            "Connecting to {0}:{1}".format(self.conn.host, self.conn.port))
        self.stream.connect((self.conn.host, self.conn.port))

        logging.info("Registering the client.")
        self.stream.write("PASS {0}\r\n".format(self.conn.password))
        self.stream.write("NICK {0}\r\n".format(self.conn.username))
        self.stream.write("USER {0} {1} unused :{2}\r\n".format(
            self.conn.username, socket.gethostname(), self.conn.name))

        for channel in self.conn.channels:
            logging.info("Joining channel: #{0}".format(channel))
            self.stream.write("JOIN #{0}\r\n".format(channel))

        try:
            yield self.stream.read_until_close(
                streaming_callback=self.stream_bytes)
        except StreamClosedError:
            logging.error("Stream is closed.")
            raise gen.Return(False)
        raise gen.Return(True)

    def stop(self):
        """Send QUIT and close the stream."""
        logging.info("Client no longer listening.")
        self.stream.write("QUIT\r\n")
        self.stream.close()
class KeepAliveTest(AsyncHTTPTestCase):
    """Tests various scenarios for HTTP 1.1 keep-alive support.

    These tests don't use AsyncHTTPClient because we want to control
    connection reuse and closing.
    """

    def get_app(self):
        """Build the app: "/" and "/large" plus a handler that finishes on close."""

        class HelloHandler(RequestHandler):
            def get(self):
                self.finish("Hello world")

            def post(self):
                self.finish("Hello world")

        class LargeHandler(RequestHandler):
            def get(self):
                # 512KB should be bigger than the socket buffers so it will
                # be written out in chunks.
                self.write("".join(chr(i % 256) * 1024 for i in range(512)))

        class FinishOnCloseHandler(RequestHandler):
            def initialize(self, cleanup_event):
                self.cleanup_event = cleanup_event

            @gen.coroutine
            def get(self):
                # Send headers, then hang until the test sets cleanup_event.
                self.flush()
                yield self.cleanup_event.wait()

            def on_connection_close(self):
                # This is not very realistic, but finishing the request
                # from the close callback has the right timing to mimic
                # some errors seen in the wild.
                self.finish("closed")

        # Shared with FinishOnCloseHandler so test_finish_while_closed can
        # release the hanging coroutine.
        self.cleanup_event = Event()
        return Application(
            [
                ("/", HelloHandler),
                ("/large", LargeHandler),
                (
                    "/finish_on_close",
                    FinishOnCloseHandler,
                    dict(cleanup_event=self.cleanup_event),
                ),
            ]
        )

    def setUp(self):
        super(KeepAliveTest, self).setUp()
        # NOTE(review): set here and by some tests, but not read by any
        # method in this class (read_headers always expects HTTP/1.1).
        self.http_version = b"HTTP/1.1"

    def tearDown(self):
        # We just closed the client side of the socket; let the IOLoop run
        # once to make sure the server side got the message.
        self.io_loop.add_timeout(datetime.timedelta(seconds=0.001), self.stop)
        self.wait()

        # A test that failed before calling self.close() leaves its stream open.
        if hasattr(self, "stream"):
            self.stream.close()
        super(KeepAliveTest, self).tearDown()

    # The next few methods are a crude manual http client
    @gen.coroutine
    def connect(self):
        """Open a raw IOStream to the test server and store it on self.stream."""
        self.stream = IOStream(socket.socket())
        yield self.stream.connect(("127.0.0.1", self.get_http_port()))

    @gen.coroutine
    def read_headers(self):
        """Read one response head, assert a 200 status, return HTTPHeaders."""
        first_line = yield self.stream.read_until(b"\r\n")
        self.assertTrue(first_line.startswith(b"HTTP/1.1 200"), first_line)
        header_bytes = yield self.stream.read_until(b"\r\n\r\n")
        headers = HTTPHeaders.parse(header_bytes.decode("latin1"))
        raise gen.Return(headers)

    @gen.coroutine
    def read_response(self):
        """Read a full response into self.headers and assert the Hello body."""
        self.headers = yield self.read_headers()
        body = yield self.stream.read_bytes(int(self.headers["Content-Length"]))
        self.assertEqual(b"Hello world", body)

    def close(self):
        """Close the client stream; deleting it tells tearDown it is done."""
        self.stream.close()
        del self.stream

    @gen_test
    def test_two_requests(self):
        # Two sequential requests reuse the same connection.
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.1\r\n\r\n")
        yield self.read_response()
        self.stream.write(b"GET / HTTP/1.1\r\n\r\n")
        yield self.read_response()
        self.close()

    @gen_test
    def test_request_close(self):
        # "Connection: close" makes the server close after one response.
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.1\r\nConnection: close\r\n\r\n")
        yield self.read_response()
        data = yield self.stream.read_until_close()
        self.assertTrue(not data)
        self.assertEqual(self.headers["Connection"], "close")
        self.close()

    # keepalive is supported for http 1.0 too, but it's opt-in
    @gen_test
    def test_http10(self):
        self.http_version = b"HTTP/1.0"
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.0\r\n\r\n")
        yield self.read_response()
        data = yield self.stream.read_until_close()
        self.assertTrue(not data)
        self.assertTrue("Connection" not in self.headers)
        self.close()

    @gen_test
    def test_http10_keepalive(self):
        self.http_version = b"HTTP/1.0"
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.stream.write(b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.close()

    @gen_test
    def test_http10_keepalive_extra_crlf(self):
        # A stray CRLF after the first request must not break the connection.
        self.http_version = b"HTTP/1.0"
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.stream.write(b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.close()

    @gen_test
    def test_pipelined_requests(self):
        # Both requests are sent in one write before any response is read.
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\n")
        yield self.read_response()
        yield self.read_response()
        self.close()

    @gen_test
    def test_pipelined_cancel(self):
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\n")
        # only read once
        yield self.read_response()
        self.close()

    @gen_test
    def test_cancel_during_download(self):
        # Close after reading only 1KB of the 512KB /large body.
        yield self.connect()
        self.stream.write(b"GET /large HTTP/1.1\r\n\r\n")
        yield self.read_headers()
        yield self.stream.read_bytes(1024)
        self.close()

    @gen_test
    def test_finish_while_closed(self):
        # Client closes while the handler is blocked; the handler finishes
        # from on_connection_close.
        yield self.connect()
        self.stream.write(b"GET /finish_on_close HTTP/1.1\r\n\r\n")
        yield self.read_headers()
        self.close()
        # Let the hanging coroutine clean up after itself
        self.cleanup_event.set()

    @gen_test
    def test_keepalive_chunked(self):
        # An empty chunked POST body is fully consumed, so the HTTP/1.0
        # keep-alive connection can serve a second request.
        self.http_version = b"HTTP/1.0"
        yield self.connect()
        self.stream.write(
            b"POST / HTTP/1.0\r\n"
            b"Connection: keep-alive\r\n"
            b"Transfer-Encoding: chunked\r\n"
            b"\r\n"
            b"0\r\n"
            b"\r\n"
        )
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.stream.write(b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.close()
class BaseTornadoClient(AsyncModbusClientMixin):
    """
    Base Tornado client
    """
    stream = None
    io_loop = None

    def __init__(self, *args, **kwargs):
        """
        Initializes BaseTornadoClient.
        ioloop to be passed as part of kwargs ('ioloop')

        :param args: forwarded to AsyncModbusClientMixin
        :param kwargs: forwarded to AsyncModbusClientMixin after popping 'ioloop'
        """
        self.io_loop = kwargs.pop("ioloop", None)
        super(BaseTornadoClient, self).__init__(*args, **kwargs)

    @abc.abstractmethod
    def get_socket(self):
        """
        return instance of the socket to connect to
        """

    @gen.coroutine
    def connect(self):
        """
        Connect to the socket identified by host and port

        :returns: Future resolving to this client once the TCP connection
            is established (or failing with the connection error)
        :rtype: tornado.concurrent.Future
        """
        conn = self.get_socket()
        self.stream = IOStream(conn, io_loop=self.io_loop or IOLoop.current())
        # Wait for the connection: otherwise _connected became True (and
        # "Client connected" was logged) before the socket was connected, and
        # connection failures never reached the returned Future.
        yield self.stream.connect((self.host, self.port))
        # Incoming bytes are delivered to on_receive until the stream closes.
        self.stream.read_until_close(None, streaming_callback=self.on_receive)
        self._connected = True
        LOGGER.debug("Client connected")
        raise gen.Return(self)

    def on_receive(self, *args):
        """
        On data receive call back

        :param args: data received (first positional argument)
        :return: None
        """
        data = args[0] if len(args) > 0 else None
        if not data:
            return
        LOGGER.debug("recv: " + " ".join([hex(byte2int(x)) for x in data]))
        unit = self.framer.decode_data(data).get("uid", 0)
        self.framer.processIncomingPacket(data, self._handle_response, unit=unit)

    def execute(self, request=None):
        """
        Executes a transaction

        :param request: request to send; gets a fresh transaction id
        :return: Future resolved with the matching response
        """
        request.transaction_id = self.transaction.getNextTID()
        packet = self.framer.buildPacket(request)
        LOGGER.debug("send: " + " ".join([hex(byte2int(x)) for x in packet]))
        self.stream.write(packet)
        return self._build_response(request.transaction_id)

    def _handle_response(self, reply, **kwargs):
        """
        Handle response received

        :param reply: decoded reply (ignored when None)
        :param kwargs: unused
        :return: None
        """
        if reply is not None:
            tid = reply.transaction_id
            future = self.transaction.getTransaction(tid)
            if future:
                future.set_result(reply)
            else:
                LOGGER.debug("Unrequested message: {}".format(reply))

    def _build_response(self, tid):
        """
        Builds a future response

        :param tid: transaction id the future is registered under
        :return: Future; already failed with ConnectionException when the
            client is not connected
        """
        f = Future()
        if not self._connected:
            f.set_exception(ConnectionException("Client is not connected"))
            return f
        self.transaction.addTransaction(f, tid)
        return f

    def close(self):
        """
        Closes the underlying IOStream
        """
        LOGGER.debug("Client disconnected")
        if self.stream:
            # Tornado documents close_fd() as internal to BaseIOStream; close()
            # is the public call that performs the full stream shutdown.
            self.stream.close()
        self.stream = None
        self._connected = False
class _HTTPConnection(object):
    """One HTTP/1.x request/response exchange over a fresh IOStream.

    Callback chain: __init__ parses the URL and starts an async DNS lookup
    -> _on_resolve opens an IOStream/SSLIOStream and connects -> _on_connect
    writes the request -> _on_headers parses the status line and headers ->
    the body reaches _on_body either directly or via _on_chunk_length /
    _on_chunk_data for chunked responses. While final_callback is still
    pending, exceptions caught by the ExceptionStackContext are reported as
    a 599 response by _handle_exception.
    """

    _SUPPORTED_METHODS = set(["GET", "HEAD", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"])

    def __init__(self, io_loop, client, request, release_callback,
                 final_callback, max_buffer_size):
        self.start_time = io_loop.time()
        self.io_loop = io_loop
        self.client = client
        self.request = request
        self.release_callback = release_callback
        self.final_callback = final_callback
        self.max_buffer_size = max_buffer_size
        self.code = None
        self.reason = None
        self.headers = None
        self.chunks = None
        self._decompressor = None
        # Timeout handle returned by IOLoop.add_timeout
        self._timeout = None
        with stack_context.ExceptionStackContext(self._handle_exception):
            self.parsed = urlparse.urlsplit(_unicode(self.request.url))
            if self.parsed.scheme not in ("http", "https"):
                raise ValueError("Unsupported url scheme: %s" % self.request.url)
            # urlsplit results have hostname and port results, but they
            # didn't support ipv6 literals until python 2.7.
            netloc = self.parsed.netloc
            if "@" in netloc:
                # Drop the "user:pass@" prefix; credentials are read from
                # self.parsed in _on_connect.
                _, _, netloc = netloc.rpartition("@")
            match = re.match(r"^(.+):(\d+)$", netloc)
            if match:
                host = match.group(1)
                port = int(match.group(2))
            else:
                host = netloc
                port = 443 if self.parsed.scheme == "https" else 80
            if re.match(r"^\[.*\]$", host):
                # raw ipv6 addresses in urls are enclosed in brackets
                host = host[1:-1]
            self.parsed_hostname = host  # save final host for _on_connect
            if self.client.hostname_mapping is not None:
                host = self.client.hostname_mapping.get(host, host)

            if request.allow_ipv6:
                af = socket.AF_UNSPEC
            else:
                # We only try the first IP we get from getaddrinfo,
                # so restrict to ipv4 by default.
                af = socket.AF_INET

            self.client.resolver.getaddrinfo(
                host, port, af, socket.SOCK_STREAM, 0, 0,
                callback=self._on_resolve)

    def _on_resolve(self, future):
        """Open the (SSL)IOStream for the first resolved address and connect."""
        af, socktype, proto, canonname, sockaddr = future.result()[0]

        if self.parsed.scheme == "https":
            ssl_options = {}
            if self.request.validate_cert:
                ssl_options["cert_reqs"] = ssl.CERT_REQUIRED
            if self.request.ca_certs is not None:
                ssl_options["ca_certs"] = self.request.ca_certs
            else:
                ssl_options["ca_certs"] = _DEFAULT_CA_CERTS
            if self.request.client_key is not None:
                ssl_options["keyfile"] = self.request.client_key
            if self.request.client_cert is not None:
                ssl_options["certfile"] = self.request.client_cert

            # SSL interoperability is tricky.  We want to disable
            # SSLv2 for security reasons; it wasn't disabled by default
            # until openssl 1.0.  The best way to do this is to use
            # the SSL_OP_NO_SSLv2, but that wasn't exposed to python
            # until 3.2.  Python 2.7 adds the ciphers argument, which
            # can also be used to disable SSLv2.  As a last resort
            # on python 2.6, we set ssl_version to SSLv3.  This is
            # more narrow than we'd like since it also breaks
            # compatibility with servers configured for TLSv1 only,
            # but nearly all servers support SSLv3:
            # http://blog.ivanristic.com/2011/09/ssl-survey-protocol-support.html
            if sys.version_info >= (2, 7):
                ssl_options["ciphers"] = "DEFAULT:!SSLv2"
            else:
                # This is really only necessary for pre-1.0 versions
                # of openssl, but python 2.6 doesn't expose version
                # information.
                ssl_options["ssl_version"] = ssl.PROTOCOL_SSLv3

            self.stream = SSLIOStream(
                socket.socket(af, socktype, proto),
                io_loop=self.io_loop,
                ssl_options=ssl_options,
                max_buffer_size=self.max_buffer_size,
            )
        else:
            self.stream = IOStream(
                socket.socket(af, socktype, proto),
                io_loop=self.io_loop,
                max_buffer_size=self.max_buffer_size,
            )
        # Until connected, the tighter of the two timeouts applies;
        # _on_connect replaces it with request_timeout.
        timeout = min(self.request.connect_timeout, self.request.request_timeout)
        if timeout:
            self._timeout = self.io_loop.add_timeout(
                self.start_time + timeout,
                stack_context.wrap(self._on_timeout))
        self.stream.set_close_callback(self._on_close)
        # ipv6 addresses are broken (in self.parsed.hostname) until
        # 2.7, here is correctly parsed value calculated in __init__
        self.stream.connect(sockaddr, self._on_connect,
                            server_hostname=self.parsed_hostname)

    def _on_timeout(self):
        """Abort with a 599 error if the response is still pending."""
        self._timeout = None
        if self.final_callback is not None:
            raise HTTPError(599, "Timeout")

    def _on_connect(self):
        """Fill in default headers, write the request and start reading."""
        if self._timeout is not None:
            self.io_loop.remove_timeout(self._timeout)
            self._timeout = None
        if self.request.request_timeout:
            self._timeout = self.io_loop.add_timeout(
                self.start_time + self.request.request_timeout,
                stack_context.wrap(self._on_timeout))
        if (self.request.method not in self._SUPPORTED_METHODS and
                not self.request.allow_nonstandard_methods):
            raise KeyError("unknown method %s" % self.request.method)
        for key in ("network_interface",
                    "proxy_host", "proxy_port",
                    "proxy_username", "proxy_password"):
            if getattr(self.request, key, None):
                raise NotImplementedError("%s not supported" % key)
        if "Connection" not in self.request.headers:
            self.request.headers["Connection"] = "close"
        if "Host" not in self.request.headers:
            if "@" in self.parsed.netloc:
                self.request.headers["Host"] = self.parsed.netloc.rpartition("@")[-1]
            else:
                self.request.headers["Host"] = self.parsed.netloc
        username, password = None, None
        if self.parsed.username is not None:
            # urlsplit reports password=None for "user@host"; default to ""
            # like the auth_username branch, since utf8(None) is not bytes.
            username = self.parsed.username
            password = self.parsed.password or ""
        elif self.request.auth_username is not None:
            username = self.request.auth_username
            password = self.request.auth_password or ""
        if username is not None:
            auth = utf8(username) + b":" + utf8(password)
            self.request.headers["Authorization"] = (b"Basic " +
                                                     base64.b64encode(auth))
        if self.request.user_agent:
            self.request.headers["User-Agent"] = self.request.user_agent
        if not self.request.allow_nonstandard_methods:
            if self.request.method in ("POST", "PATCH", "PUT"):
                assert self.request.body is not None
            else:
                assert self.request.body is None
        if self.request.body is not None:
            self.request.headers["Content-Length"] = str(len(self.request.body))
        if (self.request.method == "POST" and
                "Content-Type" not in self.request.headers):
            self.request.headers["Content-Type"] = "application/x-www-form-urlencoded"
        if self.request.use_gzip:
            self.request.headers["Accept-Encoding"] = "gzip"
        req_path = ((self.parsed.path or "/") +
                    (("?" + self.parsed.query) if self.parsed.query else ""))
        request_lines = [utf8("%s %s HTTP/1.1" % (self.request.method,
                                                  req_path))]
        for k, v in self.request.headers.get_all():
            line = utf8(k) + b": " + utf8(v)
            # Reject header injection via embedded newlines.
            if b"\n" in line:
                raise ValueError("Newline in header: " + repr(line))
            request_lines.append(line)
        self.stream.write(b"\r\n".join(request_lines) + b"\r\n\r\n")
        if self.request.body is not None:
            self.stream.write(self.request.body)
        self.stream.read_until_regex(b"\r?\n\r?\n", self._on_headers)

    def _release(self):
        """Invoke release_callback at most once."""
        if self.release_callback is not None:
            release_callback = self.release_callback
            self.release_callback = None
            release_callback()

    def _run_callback(self, response):
        """Release, then schedule final_callback(response) at most once."""
        self._release()
        if self.final_callback is not None:
            final_callback = self.final_callback
            self.final_callback = None
            self.io_loop.add_callback(final_callback, response)

    def _handle_exception(self, typ, value, tb):
        """Stack-context handler: report errors as a 599 response.

        Returns True (handled) when the final callback was still pending,
        False otherwise so the exception propagates.
        """
        if self.final_callback:
            gen_log.warning("uncaught exception", exc_info=(typ, value, tb))
            self._run_callback(HTTPResponse(
                self.request, 599, error=value,
                request_time=self.io_loop.time() - self.start_time,
            ))

            if hasattr(self, "stream"):
                self.stream.close()
            return True
        else:
            # If our callback has already been called, we are probably
            # catching an exception that is not caused by us but rather
            # some child of our callback. Rather than drop it on the floor,
            # pass it along.
            return False

    def _on_close(self):
        """Stream close callback: fail the request if no response was delivered."""
        if self.final_callback is not None:
            message = "Connection closed"
            if self.stream.error:
                message = str(self.stream.error)
            raise HTTPError(599, message)

    def _on_headers(self, data):
        """Parse the status line and headers, then choose how to read the body."""
        data = native_str(data.decode("latin1"))
        first_line, sep, header_data = data.partition("\n")
        # Escape the dot so only "HTTP/1.0" / "HTTP/1.1" status lines match.
        match = re.match(r"HTTP/1\.[01] ([0-9]+) ([^\r]*)", first_line)
        if match is None:
            raise ValueError("Malformed HTTP response line: %r" % first_line)
        code = int(match.group(1))
        if 100 <= code < 200:
            # Interim (1xx) response: skip it and read the next header block.
            self.stream.read_until_regex(b"\r?\n\r?\n", self._on_headers)
            return
        else:
            self.code = code
            self.reason = match.group(2)
        self.headers = HTTPHeaders.parse(header_data)

        if "Content-Length" in self.headers:
            if "," in self.headers["Content-Length"]:
                # Proxies sometimes cause Content-Length headers to get
                # duplicated.  If all the values are identical then we can
                # use them but if they differ it's an error.
                pieces = re.split(r",\s*", self.headers["Content-Length"])
                if any(i != pieces[0] for i in pieces):
                    raise ValueError("Multiple unequal Content-Lengths: %r" %
                                     self.headers["Content-Length"])
                self.headers["Content-Length"] = pieces[0]
            content_length = int(self.headers["Content-Length"])
        else:
            content_length = None

        if self.request.header_callback is not None:
            # re-attach the newline we split on earlier
            self.request.header_callback(first_line + sep)
            for k, v in self.headers.get_all():
                self.request.header_callback("%s: %s\r\n" % (k, v))
            self.request.header_callback("\r\n")

        if self.request.method == "HEAD" or self.code == 304:
            # HEAD requests and 304 responses never have content, even
            # though they may have content-length headers
            self._on_body(b"")
            return
        if 100 <= self.code < 200 or self.code == 204:
            # These response codes never have bodies
            # http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.3
            if ("Transfer-Encoding" in self.headers or
                    content_length not in (None, 0)):
                raise ValueError("Response with code %d should not have body" %
                                 self.code)
            self._on_body(b"")
            return

        if (self.request.use_gzip and
                self.headers.get("Content-Encoding") == "gzip"):
            self._decompressor = GzipDecompressor()
        if self.headers.get("Transfer-Encoding") == "chunked":
            self.chunks = []
            self.stream.read_until(b"\r\n", self._on_chunk_length)
        elif content_length is not None:
            self.stream.read_bytes(content_length, self._on_body)
        else:
            self.stream.read_until_close(self._on_body)

    def _on_body(self, data):
        """Follow a redirect or deliver the final HTTPResponse."""
        if self._timeout is not None:
            self.io_loop.remove_timeout(self._timeout)
            self._timeout = None
        original_request = getattr(self.request, "original_request",
                                   self.request)
        # A 3xx without a Location header cannot be followed; return it as a
        # normal response instead of failing with KeyError (-> 599).
        if (self.request.follow_redirects and
                self.request.max_redirects > 0 and
                self.code in (301, 302, 303, 307) and
                "Location" in self.headers):
            assert isinstance(self.request, _RequestProxy)
            new_request = copy.copy(self.request.request)
            new_request.url = urlparse.urljoin(self.request.url,
                                               self.headers["Location"])
            new_request.max_redirects = self.request.max_redirects - 1
            # NOTE: copy.copy is shallow, so new_request.headers is the same
            # object as the original request's headers; edits affect both.
            del new_request.headers["Host"]
            # http://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.3.4
            # Client SHOULD make a GET request after a 303.
            # According to the spec, 302 should be followed by the same
            # method as the original request, but in practice browsers
            # treat 302 the same as 303, and many servers use 302 for
            # compatibility with pre-HTTP/1.1 user agents which don't
            # understand the 303 status.
            if self.code in (302, 303):
                new_request.method = "GET"
                new_request.body = None
                for h in ["Content-Length", "Content-Type",
                          "Content-Encoding", "Transfer-Encoding"]:
                    try:
                        del new_request.headers[h]
                    except KeyError:
                        pass
            new_request.original_request = original_request
            final_callback = self.final_callback
            self.final_callback = None
            self._release()
            self.client.fetch(new_request, final_callback)
            self.stream.close()
            return
        if self._decompressor:
            data = (self._decompressor.decompress(data) +
                    self._decompressor.flush())
        if self.request.streaming_callback:
            if self.chunks is None:
                # if chunks is not None, we already called streaming_callback
                # in _on_chunk_data
                self.request.streaming_callback(data)
            buffer = BytesIO()
        else:
            buffer = BytesIO(data)  # TODO: don't require one big string?
        response = HTTPResponse(original_request,
                                self.code, reason=self.reason,
                                headers=self.headers,
                                request_time=self.io_loop.time() - self.start_time,
                                buffer=buffer,
                                effective_url=self.request.url)
        self._run_callback(response)
        self.stream.close()

    def _on_chunk_length(self, data):
        """Parse a chunk-size line; size 0 ends the body."""
        # Ignore "chunk extensions" (";name=value" after the size):
        # http://tools.ietf.org/html/rfc2616#section-3.6.1
        length = int(data.split(b";", 1)[0].strip(), 16)
        if length == 0:
            if self._decompressor is not None:
                tail = self._decompressor.flush()
                if tail:
                    # I believe the tail will always be empty (i.e.
                    # decompress will return all it can).  The purpose
                    # of the flush call is to detect errors such
                    # as truncated input.  But in case it ever returns
                    # anything, treat it as an extra chunk
                    if self.request.streaming_callback is not None:
                        self.request.streaming_callback(tail)
                    else:
                        self.chunks.append(tail)
                # all the data has been decompressed, so we don't need to
                # decompress again in _on_body
                self._decompressor = None
            self._on_body(b"".join(self.chunks))
        else:
            self.stream.read_bytes(length + 2,  # chunk ends with \r\n
                                   self._on_chunk_data)

    def _on_chunk_data(self, data):
        """Handle one chunk (plus trailing CRLF), then read the next size line."""
        assert data[-2:] == b"\r\n"
        chunk = data[:-2]
        if self._decompressor:
            chunk = self._decompressor.decompress(chunk)
        if self.request.streaming_callback is not None:
            self.request.streaming_callback(chunk)
        else:
            self.chunks.append(chunk)
        self.stream.read_until(b"\r\n", self._on_chunk_length)
class _RedisConnection(object):
    """One Redis connection with pipelined commands and FIFO reply matching.

    Each pending command is tracked in ``__cmd_env`` as a tuple
    ``(future, connect_count, transaction, cmd_count)``; replies are decoded
    against those entries in order by ``decode_resp_ondemand``.
    """

    def __init__(self, final_callback, redis_tuple, redis_pwd):
        """
        :param final_callback: called (through the IOLoop) with a response
            dict whenever a reply has been decoded
        :param redis_tuple: (ip, port, db)
        :param redis_pwd: redis password
        """
        self.__io_loop = IOLoop.instance()
        self.__resp_cb = final_callback
        self.__stream = None
        # Unparsed remainder of the received reply data
        self.__recv_buf = ''
        self.__redis_tuple = redis_tuple
        self.__redis_pwd = redis_pwd
        # Command contexts: (future, number of connect commands such as
        # AUTH/SELECT, transaction flag, cmd_count)
        self.__cmd_env = deque()
        # Commands written before the TCP connection is established
        self.__cache_before_connect = []
        self.__connected = False

    def con_ok(self):
        """
        Whether the connection is currently established.

        :return: bool
        """
        return self.__connected

    def connect(self, init_future):
        """
        Open the connection; the initial connect commands are AUTH (only when
        a password is set) and SELECT.

        :param init_future: the first future object
        """
        # future, connect_count, transaction, cmd_count
        self.__cmd_env.append((init_future, 1 + int(bool(self.__redis_pwd)), False, 0))
        self.__stream = IOStream(socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0),
                                 io_loop=self.__io_loop)
        self.__stream.set_close_callback(self.__on_close)
        self.__stream.connect(self.__redis_tuple[:2], self.__on_connect)

    def __on_connect(self):
        """Connected: send the initial commands, then flush cached writes."""
        self.__connected = True
        self.__stream.set_nodelay(True)
        self.__stream.read_until_close(self.__last_closd_recv, self.__on_resp)
        self.__stream.write(chain_select_cmd(self.__redis_pwd, self.__redis_tuple[-1]))
        for x in self.__cache_before_connect:
            self.__stream.write(x)
        self.__cache_before_connect = []

    def write(self, buf, new_future, active_trans, cmd_count):
        """
        :param buf: encoded command bytes to send
        :param new_future: because of closures, the response callback would
            otherwise keep the previous future object, so a fresh one must be
            supplied for every write
        :param active_trans: whether a transaction is active
        :param cmd_count: number of commands in ``buf``
        """
        self.__cmd_env.append((new_future, 0, active_trans, cmd_count))
        if not self.__connected:
            self.__cache_before_connect.append(buf)
            return
        self.__stream.write(buf)

    def __last_closd_recv(self, data):
        """
        Final bytes delivered when the socket closes.
        """
        if not data:
            return
        self.__on_resp(data)

    def __on_resp(self, recv):
        """
        Decode as many complete replies as possible from the buffered data.

        :param recv: newly received data
        """
        recv = ''.join((self.__recv_buf, recv))

        idx = 0
        for future, connect, trans, cmd in self.__cmd_env:
            ok, payload, recv = decode_resp_ondemand(recv, connect, trans, cmd)
            if not ok:
                break

            idx += 1
            # Replies to the connect commands (AUTH/SELECT) are not reported.
            if not connect:
                self.__run_callback({_RESP_FUTURE: future, RESP_RESULT: payload})

        self.__recv_buf = recv
        for _ in xrange(idx):
            self.__cmd_env.popleft()

    def __run_callback(self, resp):
        """Schedule the response callback on the IOLoop (if one is set)."""
        if self.__resp_cb is None:
            return
        self.__io_loop.add_callback(self.__resp_cb, resp)

    def __on_close(self):
        """Stream closed: report every pending command with result 0."""
        self.__connected = False
        while self.__cmd_env:
            # Each entry is (future, connect, trans, cmd); pass only the
            # future, exactly as __on_resp does, not the whole tuple.
            future = self.__cmd_env.popleft()[0]
            self.__run_callback({_RESP_FUTURE: future, RESP_RESULT: 0})
class HTTPServerRawTest(AsyncHTTPTestCase):
    """Feed hand-written raw bytes to HTTPServer over a plain IOStream."""

    def get_app(self):
        return Application([
            ('/echo', EchoHandler),
        ])

    def setUp(self):
        super(HTTPServerRawTest, self).setUp()
        # Connect synchronously (stop/wait) so each test starts connected.
        self.stream = IOStream(socket.socket())
        self.stream.connect(('127.0.0.1', self.get_http_port()), self.stop)
        self.wait()

    def tearDown(self):
        self.stream.close()
        super(HTTPServerRawTest, self).tearDown()

    def test_empty_request(self):
        # Close without sending anything, then let the IOLoop run briefly.
        self.stream.close()
        self.io_loop.add_timeout(datetime.timedelta(seconds=0.001), self.stop)
        self.wait()

    def test_malformed_first_line(self):
        with ExpectLog(gen_log, '.*Malformed HTTP request line'):
            self.stream.write(b'asdf\r\n\r\n')
            # TODO: need an async version of ExpectLog so we don't need
            # hard-coded timeouts here.
            self.io_loop.add_timeout(datetime.timedelta(seconds=0.05),
                                     self.stop)
            self.wait()

    def test_malformed_headers(self):
        with ExpectLog(gen_log, '.*Malformed HTTP headers'):
            self.stream.write(b'GET / HTTP/1.0\r\nasdf\r\n\r\n')
            self.io_loop.add_timeout(datetime.timedelta(seconds=0.05),
                                     self.stop)
            self.wait()

    def test_chunked_request_body(self):
        # Chunked requests are not widely supported and we don't have a way
        # to generate them in AsyncHTTPClient, but HTTPServer will read them.
        # The literal is written with bare "\n" and converted to CRLF below.
        self.stream.write(b"""\
POST /echo HTTP/1.1
Transfer-Encoding: chunked
Content-Type: application/x-www-form-urlencoded

4
foo=
3
bar
0

""".replace(b"\n", b"\r\n"))
        read_stream_body(self.stream, self.stop)
        headers, response = self.wait()
        self.assertEqual(json_decode(response), {u'foo': [u'bar']})

    def test_chunked_request_uppercase(self):
        # As per RFC 2616 section 3.6, "Transfer-Encoding" header's value is
        # case-insensitive.
        self.stream.write(b"""\
POST /echo HTTP/1.1
Transfer-Encoding: Chunked
Content-Type: application/x-www-form-urlencoded

4
foo=
3
bar
0

""".replace(b"\n", b"\r\n"))
        read_stream_body(self.stream, self.stop)
        headers, response = self.wait()
        self.assertEqual(json_decode(response), {u'foo': [u'bar']})

    def test_invalid_content_length(self):
        # A non-integer Content-Length must be logged and the connection
        # closed (read_until_close returns).
        with ExpectLog(gen_log, '.*Only integer Content-Length is allowed'):
            self.stream.write(b"""\
POST /echo HTTP/1.1
Content-Length: foo

bar

""".replace(b"\n", b"\r\n"))
            self.stream.read_until_close(self.stop)
            self.wait()
class Client(RedisCommandsMixin):
    """
    Redis client class
    """

    def __init__(self, io_loop=None):
        """
        Constructor

        :param io_loop:
            Optional IOLoop instance
        """
        self._io_loop = io_loop or IOLoop.instance()

        self._stream = None

        self.reader = None
        # FIFO of (callback, pipeline_state); pipeline_state is None for a
        # single command or (expected_count, collected_replies).
        self.callbacks = deque()

        # None means "not in pub/sub mode". This must match _reset() and the
        # `is not None` checks in send_message/send_messages (it used to be
        # False, which those checks treated as pub/sub mode).
        self._sub_callback = None

    def connect(self, host='localhost', port=6379, callback=None):
        """
        Connect to redis server

        :param host:
            Host to connect to
        :param port:
            Port
        :param callback:
            Optional callback to be triggered upon connection
        """
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        return self._connect(sock, (host, port), callback)

    def connect_usocket(self, usock, callback=None):
        """
        Connect to redis server with unix socket
        """
        sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM, 0)
        return self._connect(sock, usock, callback)

    def on_disconnect(self):
        """
        Override this method if you want to handle disconnections
        """
        pass

    # State
    def is_idle(self):
        """
        Check if client is not waiting for any responses
        """
        return len(self.callbacks) == 0

    def is_connected(self):
        """
        Check if client is still connected
        """
        return bool(self._stream) and not self._stream.closed()

    def send_message(self, args, callback=None):
        """
        Send command to redis

        :param args:
            Arguments to send
        :param callback:
            Callback
        """
        # Special case for pub-sub
        cmd = args[0]

        if (self._sub_callback is not None and
                cmd not in ('PSUBSCRIBE', 'SUBSCRIBE', 'PUNSUBSCRIBE', 'UNSUBSCRIBE')):
            raise ValueError('Cannot run normal command over PUBSUB connection')

        # Send command
        self._stream.write(self.format_message(args))
        if callback is not None:
            callback = stack_context.wrap(callback)
        # Queue even a None callback so replies stay matched in order.
        self.callbacks.append((callback, None))

    def send_messages(self, args_pipeline, callback=None):
        """
        Send command pipeline to redis

        :param args_pipeline:
            Arguments pipeline to send
        :param callback:
            Callback
        """

        if self._sub_callback is not None:
            raise ValueError('Cannot run pipeline over PUBSUB connection')

        # Send command pipeline
        messages = [self.format_message(args) for args in args_pipeline]
        self._stream.write(b"".join(messages))
        if callback is not None:
            callback = stack_context.wrap(callback)
        self.callbacks.append((callback, (len(messages), [])))

    def format_message(self, args):
        """
        Create redis message

        :param args:
            Message data
        """
        l = "*%d" % len(args)
        lines = [l.encode('utf-8')]
        for arg in args:
            if not isinstance(arg, string_types):
                arg = str(arg)
            if isinstance(arg, text_type):
                arg = arg.encode('utf-8')
            l = "$%d" % len(arg)
            lines.append(l.encode('utf-8'))
            lines.append(arg)
        lines.append(b"")
        return b"\r\n".join(lines)

    def close(self):
        """
        Close redis connection
        """
        self.quit()
        self._stream.close()

    # Pub/sub commands
    def psubscribe(self, patterns, callback=None):
        """
        Customized psubscribe command - will keep one callback for all incoming messages

        :param patterns:
            string or list of strings
        :param callback:
            callback
        """
        self._set_sub_callback(callback)
        super(Client, self).psubscribe(patterns)

    def subscribe(self, channels, callback=None):
        """
        Customized subscribe command - will keep one callback for all incoming messages

        :param channels:
            string or list of strings
        :param callback:
            Callback
        """
        self._set_sub_callback(callback)
        super(Client, self).subscribe(channels)

    def _set_sub_callback(self, callback):
        # Only one pub/sub callback per connection is supported.
        if self._sub_callback is None:
            self._sub_callback = callback

        assert self._sub_callback == callback

    # Helpers
    def _connect(self, sock, addr, callback):
        self._reset()

        self._stream = IOStream(sock, io_loop=self._io_loop)
        self._stream.connect(addr, callback=callback)
        self._stream.read_until_close(self._on_close, self._on_read)

    # Event handlers
    def _on_read(self, data):
        """Feed data to the reader and dispatch every complete reply."""
        self.reader.feed(data)

        resp = self.reader.gets()

        while resp is not False:
            if self._sub_callback:
                try:
                    self._sub_callback(resp)
                except Exception:
                    logger.exception('SUB callback failed')
            else:
                if self.callbacks:
                    callback, callback_data = self.callbacks[0]
                    if callback_data is None:
                        callback_resp = resp
                    else:
                        # handle pipeline responses
                        num_resp, callback_resp = callback_data
                        callback_resp.append(resp)
                        while len(callback_resp) < num_resp:
                            resp = self.reader.gets()
                            if resp is False:
                                # callback_resp is yet incomplete
                                return
                            callback_resp.append(resp)
                    self.callbacks.popleft()
                    if callback is not None:
                        try:
                            callback(callback_resp)
                        except Exception:
                            logger.exception('Callback failed')
                else:
                    logger.debug('Ignored response: %s' % repr(resp))

            resp = self.reader.gets()

    def _on_close(self, data=None):
        """Process final data, fail pending callbacks with None, notify."""
        if data is not None:
            self._on_read(data)

        # Trigger any pending callbacks
        callbacks = self.callbacks
        self.callbacks = deque()

        if callbacks:
            for cb in callbacks:
                callback, callback_data = cb
                if callback is not None:
                    try:
                        callback(None)
                    except Exception:
                        logger.exception('Exception in callback')

        if self._sub_callback is not None:
            try:
                self._sub_callback(None)
            except Exception:
                logger.exception('Exception in SUB callback')
            self._sub_callback = None

        # Trigger on_disconnect
        self.on_disconnect()

    def _reset(self):
        self.reader = hiredis.Reader()
        self._sub_callback = None

    def pipeline(self):
        return Pipeline(self)
class Connection(RedisCommandsMixin):
    """Redis connection that can be shared or locked for exclusive use.

    Once connected it adds itself to ``redis._shared``; ``lock()`` removes
    it (exclusive use, required for WATCH/MULTI) and ``unlock()`` resets
    transaction state, re-SELECTs the database and shares it again.
    """

    def __init__(self, redis, on_connect=None):
        logger.debug('Creating new Redis connection.')
        self.redis = redis
        self.reader = hiredis.Reader()
        # Transaction state, tracked so unlock() can reset it.
        self._watch = set()
        self._multi = False
        # FIFO of reply handlers (Future.set_result), one per command sent.
        self.callbacks = deque()
        self._on_connect_callback = on_connect
        self.stream = IOStream(
            socket.socket(redis._family, socket.SOCK_STREAM, 0),
            io_loop=redis._ioloop
        )
        self.stream.set_close_callback(self._on_close)
        self.stream.connect(redis._addr, self._on_connect)

    def _on_connect(self):
        logger.debug('Connected!')
        self.stream.read_until_close(self._on_close, self._on_read)
        self.redis._shared.append(self)
        if self._on_connect_callback is not None:
            self._on_connect_callback(self)
            self._on_connect_callback = None

    def _on_read(self, data):
        """Feed data to the reader and resolve one pending handler per reply."""
        self.reader.feed(data)
        while True:
            resp = self.reader.gets()
            if resp is False:
                break
            callback = self.callbacks.popleft()
            if callback is not None:
                self.redis._ioloop.add_callback(partial(callback, resp))

    def is_idle(self):
        return len(self.callbacks) == 0

    def is_shared(self):
        return self in self.redis._shared

    def lock(self):
        """Take the connection out of the shared pool for exclusive use."""
        if not self.is_shared():
            raise Exception('Connection already is locked!')
        self.redis._shared.remove(self)

    def unlock(self, callback=None):
        """Reset transaction state, re-SELECT the database and share again.

        :param callback: optional; called with the SELECT reply once the
            connection has been returned to the shared pool.
        """
        def cb(future):
            # send_message() registers this via Future.add_done_callback,
            # so it receives the Future, not the reply (comparing the Future
            # itself to 'OK' always failed, so the connection was never
            # re-shared).
            resp = future.result()
            # NOTE(review): hiredis may return b'OK' unless an encoding is
            # configured on the Reader, so accept both forms.
            if resp not in ('OK', b'OK'):
                raise Exception('SELECT failed while unlocking: %r' % (resp,))
            self.redis._shared.append(self)
            if callback is not None:
                callback(resp)

        if self._multi:
            self.send_message(['DISCARD'])
        elif self._watch:
            self.send_message(['UNWATCH'])

        self.send_message(['SELECT', self.redis._database], cb)

    def send_message(self, args, callback=None):
        """Send one command; return a Future resolved with its reply.

        ``callback``, if given, is attached with Future.add_done_callback
        and therefore receives the Future.
        """
        command = args[0]

        if 'SUBSCRIBE' in command:
            raise NotImplementedError('Not yet.')

        # Do not allow the commands, affecting the execution of other commands,
        # to be used on shared connection.
        if command in ('WATCH', 'MULTI'):
            if self.is_shared():
                raise Exception('Command %s is not allowed while connection '
                                'is shared!' % command)
            if command == 'WATCH':
                self._watch.add(args[1])
            if command == 'MULTI':
                self._multi = True

        # monitor transaction state, to unlock correctly
        if command in ('EXEC', 'DISCARD', 'UNWATCH'):
            if command in ('EXEC', 'DISCARD'):
                self._multi = False
            self._watch.clear()

        self.stream.write(self.format_message(args))

        future = Future()

        if callback is not None:
            future.add_done_callback(stack_context.wrap(callback))

        self.callbacks.append(future.set_result)

        return future

    def format_message(self, args):
        """Encode ``args`` as a RESP array of bulk strings."""
        header = "*%d" % len(args)
        lines = [header.encode('utf-8')]
        for arg in args:
            if isinstance(arg, bytes):
                # Already encoded: str(b'..') would yield "b'..'" instead.
                data = arg
            else:
                if not isinstance(arg, str):
                    arg = str(arg)
                data = arg.encode('utf-8')
            length = "$%d" % len(data)
            lines.append(length.encode('utf-8'))
            lines.append(data)
        lines.append(b"")
        return b"\r\n".join(lines)

    def close(self):
        """Send QUIT and remove the connection from the shared pool."""
        # send_message is this class's send method (send_command is not
        # defined here).
        self.send_message(['QUIT'])
        if self.is_shared():
            self.lock()

    def _on_close(self, data=None):
        """Process any final data and drop the connection from the pool."""
        logger.debug('Redis connection was closed.')
        if data is not None:
            self._on_read(data)
        if self.is_shared():
            self.lock()
class KeepAliveTest(AsyncHTTPTestCase):
    """Tests various scenarios for HTTP 1.1 keep-alive support.

    These tests don't use AsyncHTTPClient because we want to control
    connection reuse and closing.
    """

    def get_app(self):
        class HelloHandler(RequestHandler):
            def get(self):
                self.finish("Hello world")

            def post(self):
                self.finish("Hello world")

        class LargeHandler(RequestHandler):
            def get(self):
                # 512KB should be bigger than the socket buffers so it will
                # be written out in chunks.
                self.write("".join(chr(i % 256) * 1024 for i in range(512)))

        class FinishOnCloseHandler(RequestHandler):
            def initialize(self, cleanup_event):
                self.cleanup_event = cleanup_event

            @gen.coroutine
            def get(self):
                self.flush()
                yield self.cleanup_event.wait()

            def on_connection_close(self):
                # This is not very realistic, but finishing the request
                # from the close callback has the right timing to mimic
                # some errors seen in the wild.
                self.finish("closed")

        self.cleanup_event = Event()
        return Application([
            ("/", HelloHandler),
            ("/large", LargeHandler),
            (
                "/finish_on_close",
                FinishOnCloseHandler,
                dict(cleanup_event=self.cleanup_event),
            ),
        ])

    def setUp(self):
        super(KeepAliveTest, self).setUp()
        self.http_version = b"HTTP/1.1"

    def tearDown(self):
        # We just closed the client side of the socket; let the IOLoop run
        # once to make sure the server side got the message.
        self.io_loop.add_timeout(datetime.timedelta(seconds=0.001), self.stop)
        self.wait()

        if hasattr(self, "stream"):
            self.stream.close()
        super(KeepAliveTest, self).tearDown()

    # The next few methods are a crude manual http client
    @gen.coroutine
    def connect(self):
        self.stream = IOStream(socket.socket())
        # The test server listens on the loopback interface; a hard-coded
        # LAN address would make every test depend on the host's network.
        yield self.stream.connect(("127.0.0.1", self.get_http_port()))

    @gen.coroutine
    def read_headers(self):
        first_line = yield self.stream.read_until(b"\r\n")
        self.assertTrue(first_line.startswith(b"HTTP/1.1 200"), first_line)
        header_bytes = yield self.stream.read_until(b"\r\n\r\n")
        headers = HTTPHeaders.parse(header_bytes.decode("latin1"))
        raise gen.Return(headers)

    @gen.coroutine
    def read_response(self):
        self.headers = yield self.read_headers()
        body = yield self.stream.read_bytes(int(
            self.headers["Content-Length"]))
        self.assertEqual(b"Hello world", body)

    def close(self):
        self.stream.close()
        del self.stream

    @gen_test
    def test_two_requests(self):
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.1\r\n\r\n")
        yield self.read_response()
        self.stream.write(b"GET / HTTP/1.1\r\n\r\n")
        yield self.read_response()
        self.close()

    @gen_test
    def test_request_close(self):
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.1\r\nConnection: close\r\n\r\n")
        yield self.read_response()
        data = yield self.stream.read_until_close()
        self.assertTrue(not data)
        self.assertEqual(self.headers["Connection"], "close")
        self.close()

    # keepalive is supported for http 1.0 too, but it's opt-in
    @gen_test
    def test_http10(self):
        self.http_version = b"HTTP/1.0"
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.0\r\n\r\n")
        yield self.read_response()
        data = yield self.stream.read_until_close()
        self.assertTrue(not data)
        self.assertTrue("Connection" not in self.headers)
        self.close()

    @gen_test
    def test_http10_keepalive(self):
        self.http_version = b"HTTP/1.0"
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.stream.write(b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.close()

    @gen_test
    def test_http10_keepalive_extra_crlf(self):
        self.http_version = b"HTTP/1.0"
        yield self.connect()
        self.stream.write(
            b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.stream.write(b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.close()

    @gen_test
    def test_pipelined_requests(self):
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\n")
        yield self.read_response()
        yield self.read_response()
        self.close()

    @gen_test
    def test_pipelined_cancel(self):
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\n")
        # only read once
        yield self.read_response()
        self.close()

    @gen_test
    def test_cancel_during_download(self):
        yield self.connect()
        self.stream.write(b"GET /large HTTP/1.1\r\n\r\n")
        yield self.read_headers()
        yield self.stream.read_bytes(1024)
        self.close()

    @gen_test
    def test_finish_while_closed(self):
        yield self.connect()
        self.stream.write(b"GET /finish_on_close HTTP/1.1\r\n\r\n")
        yield self.read_headers()
        self.close()
        # Let the hanging coroutine clean up after itself
        self.cleanup_event.set()

    @gen_test
    def test_keepalive_chunked(self):
        self.http_version = b"HTTP/1.0"
        yield self.connect()
        self.stream.write(b"POST / HTTP/1.0\r\n"
                          b"Connection: keep-alive\r\n"
                          b"Transfer-Encoding: chunked\r\n"
                          b"\r\n"
                          b"0\r\n"
                          b"\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.stream.write(b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.close()
class Client(RedisCommandsMixin):
    """
        Redis client class
    """
    def __init__(self, io_loop=None):
        """
            Constructor

            :param io_loop:
                Optional IOLoop instance
        """
        self._io_loop = io_loop or IOLoop.instance()

        self._stream = None

        self.reader = None
        self.callbacks = deque()

        # None means "not in pub/sub mode". This is the same sentinel used by
        # _reset(), _on_close(), send_message() and _set_sub_callback().
        self._sub_callback = None

    def connect(self, host='localhost', port=6379, callback=None):
        """
            Connect to redis server

            :param host:
                Host to connect to
            :param port:
                Port
            :param callback:
                Optional callback to be triggered upon connection
        """
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        return self._connect(sock, (host, port), callback)

    def connect_usocket(self, usock, callback=None):
        """
            Connect to redis server with unix socket
        """
        sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM, 0)
        return self._connect(sock, usock, callback)

    def on_disconnect(self):
        """
            Override this method if you want to handle disconnections
        """
        pass

    # State
    def is_idle(self):
        """
            Check if client is not waiting for any responses
        """
        return len(self.callbacks) == 0

    def is_connected(self):
        """
            Check if client is still connected
        """
        return bool(self._stream) and not self._stream.closed()

    def send_message(self, args, callback=None):
        """
            Send command to redis

            :param args:
                Arguments to send
            :param callback:
                Callback
            :raises ValueError:
                when a non-(un)subscribe command is sent while in pub/sub mode
        """
        # Special case for pub-sub
        cmd = args[0]

        if (self._sub_callback is not None and
                cmd not in ('PSUBSCRIBE', 'SUBSCRIBE',
                            'PUNSUBSCRIBE', 'UNSUBSCRIBE')):
            raise ValueError(
                'Cannot run normal command over PUBSUB connection')

        # Send command
        self._stream.write(self.format_message(args))
        if callback is not None:
            callback = stack_context.wrap(callback)
        self.callbacks.append(callback)

    def format_message(self, args):
        """
            Create redis message

            :param args:
                Message data
        """
        # NOTE(review): basestring is Python 2-only; presumably provided by a
        # compatibility import elsewhere in this module - confirm.
        l = "*%d" % len(args)
        lines = [l.encode('utf-8')]
        for arg in args:
            if not isinstance(arg, basestring):
                arg = str(arg)
            arg = arg.encode('utf-8')
            l = "$%d" % len(arg)
            lines.append(l.encode('utf-8'))
            lines.append(arg)
        lines.append(b"")
        return b"\r\n".join(lines)

    def close(self):
        """
            Close redis connection
        """
        self.quit()
        self._stream.close()

    # Pub/sub commands
    def psubscribe(self, patterns, callback=None):
        """
            Customized psubscribe command - will keep one callback for all
            incoming messages

            :param patterns:
                string or list of strings
            :param callback:
                callback
        """
        self._set_sub_callback(callback)
        super(Client, self).psubscribe(patterns)

    def subscribe(self, channels, callback=None):
        """
            Customized subscribe command - will keep one callback for all
            incoming messages

            :param channels:
                string or list of strings
            :param callback:
                Callback
        """
        self._set_sub_callback(callback)
        super(Client, self).subscribe(channels)

    def _set_sub_callback(self, callback):
        # Only one subscription callback per client is supported.
        if self._sub_callback is None:
            self._sub_callback = callback

        assert self._sub_callback == callback

    # Helpers
    def _connect(self, sock, addr, callback):
        self._reset()

        self._stream = IOStream(sock, io_loop=self._io_loop)
        self._stream.read_until_close(self._on_close, self._on_read)
        self._stream.connect(addr, callback=callback)

    # Event handlers
    def _on_read(self, data):
        """Dispatch every complete reply in ``data`` to its consumer."""
        self.reader.feed(data)

        resp = self.reader.gets()

        while resp is not False:
            if self._sub_callback:
                try:
                    self._sub_callback(resp)
                except Exception:
                    logger.exception('SUB callback failed')
            elif self.callbacks:
                callback = self.callbacks.popleft()
                if callback is not None:
                    try:
                        callback(resp)
                    except Exception:
                        logger.exception('Callback failed')
            else:
                logger.debug('Ignored response: %s' % repr(resp))

            resp = self.reader.gets()

    def _on_close(self, data=None):
        """Process final bytes, then fail pending callbacks with None."""
        if data is not None:
            self._on_read(data)

        # Trigger any pending callbacks
        callbacks, self.callbacks = self.callbacks, deque()
        for cb in callbacks:
            if cb is not None:
                try:
                    cb(None)
                except Exception:
                    logger.exception('Exception in callback')

        if self._sub_callback is not None:
            try:
                self._sub_callback(None)
            except Exception:
                logger.exception('Exception in SUB callback')
            self._sub_callback = None

        # Trigger on_disconnect
        self.on_disconnect()

    def _reset(self):
        self.reader = hiredis.Reader()
        self._sub_callback = None
class HTTPServerRawTest(AsyncHTTPTestCase):
    """Drive the HTTP server with hand-written bytes over a raw IOStream."""

    def get_app(self):
        routes = [
            ('/echo', EchoHandler),
        ]
        return Application(routes)

    def setUp(self):
        super(HTTPServerRawTest, self).setUp()
        self.stream = IOStream(socket.socket())
        address = ('127.0.0.1', self.get_http_port())
        self.stream.connect(address, self.stop)
        self.wait()

    def tearDown(self):
        self.stream.close()
        super(HTTPServerRawTest, self).tearDown()

    def _pause(self, seconds):
        """Run the IOLoop for ``seconds`` so the server can react."""
        delay = datetime.timedelta(seconds=seconds)
        self.io_loop.add_timeout(delay, self.stop)
        self.wait()

    def _post_chunked_form(self, transfer_encoding):
        """POST ``foo=bar`` to /echo as a two-chunk body; return the echo."""
        request = b"\r\n".join([
            b"POST /echo HTTP/1.1",
            b"Transfer-Encoding: " + transfer_encoding,
            b"Content-Type: application/x-www-form-urlencoded",
            b"",
            b"4",
            b"foo=",
            b"3",
            b"bar",
            b"0",
            b"",
            b"",
        ])
        self.stream.write(request)
        read_stream_body(self.stream, self.stop)
        _, _, response = self.wait()
        return response

    def test_empty_request(self):
        self.stream.close()
        self._pause(0.001)

    def test_malformed_first_line_response(self):
        with ExpectLog(gen_log, '.*Malformed HTTP request line'):
            self.stream.write(b'asdf\r\n\r\n')
            read_stream_body(self.stream, self.stop)
            start_line, _, _ = self.wait()
            self.assertEqual('HTTP/1.1', start_line.version)
            self.assertEqual(400, start_line.code)
            self.assertEqual('Bad Request', start_line.reason)

    def test_malformed_first_line_log(self):
        with ExpectLog(gen_log, '.*Malformed HTTP request line'):
            self.stream.write(b'asdf\r\n\r\n')
            # TODO: need an async version of ExpectLog so we don't need
            # hard-coded timeouts here.
            self._pause(0.05)

    def test_malformed_headers(self):
        with ExpectLog(gen_log, '.*Malformed HTTP headers'):
            self.stream.write(b'GET / HTTP/1.0\r\nasdf\r\n\r\n')
            self._pause(0.05)

    def test_chunked_request_body(self):
        # Chunked requests are not widely supported and we don't have a way
        # to generate them in AsyncHTTPClient, but HTTPServer will read them.
        response = self._post_chunked_form(b"chunked")
        self.assertEqual(json_decode(response), {u'foo': [u'bar']})

    def test_chunked_request_uppercase(self):
        # As per RFC 2616 section 3.6, "Transfer-Encoding" header's value is
        # case-insensitive.
        response = self._post_chunked_form(b"Chunked")
        self.assertEqual(json_decode(response), {u'foo': [u'bar']})

    def test_invalid_content_length(self):
        with ExpectLog(gen_log, '.*Only integer Content-Length is allowed'):
            request = b"\r\n".join([
                b"POST /echo HTTP/1.1",
                b"Content-Length: foo",
                b"",
                b"bar",
                b"",
                b"",
            ])
            self.stream.write(request)
            self.stream.read_until_close(self.stop)
            self.wait()
class _PrxConn(object): def __init__(self, handle_resp, svr_addr): assert callable(handle_resp) self._io_loop = IOLoop.instance() self.__resp_cb = handle_resp self.__svr_addr = svr_addr self._stream = None self._send_buf = deque() self._recv_buf = '' self.__cmd_env = deque() self.__con_ok = False def con_ok(self): """ 连接不可用 """ return self.__con_ok def connect(self): self._stream = IOStream(socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0), io_loop=self._io_loop) self._stream.set_close_callback(self._on_close) self._stream.connect(self.__svr_addr, self._on_connect) def _on_connect(self): self._stream.set_nodelay(True) while len(self._send_buf) > 0: self._stream.write(self._send_buf.popleft()) self._stream.read_until_close(self._last_closd_recv, self._on_recv) self.__con_ok = True def write(self, future, encode_result): self.__cmd_env.append(future) if not self.__con_ok: self._send_buf.append(encode_result) else: self._stream.write(encode_result) def _last_closd_recv(self, data): """ socket关闭时最后几个字节 """ self._on_recv(data) def _on_recv(self, buf): self._recv_buf += buf while 1: if not self._recv_buf: break ok, payload, self._recv_buf = decode_resp_ondemand( self._recv_buf, 0, False, 1) if not ok: break if payload and isinstance(payload, (list, tuple)) and 1 == len(payload): payload = payload[0] self.__run_callback({ _RESP_FUTURE: self.__cmd_env.popleft(), RESP_RESULT: payload }) def __run_callback(self, resp): if self.__resp_cb is None: return self._io_loop.add_callback(self.__resp_cb, resp) def _on_close(self): self.__con_ok = False while len(self.__cmd_env) > 0: self.__run_callback({ _RESP_FUTURE: self.__cmd_env.popleft(), RESP_RESULT: 0 }) self.__cmd_env.clear()
class KeepAliveTest(AsyncHTTPTestCase):
    """Exercise HTTP/1.1 (and opt-in HTTP/1.0) keep-alive handling.

    A hand-rolled client over a raw IOStream is used instead of
    AsyncHTTPClient, so each test decides exactly when the connection is
    reused or closed.
    """

    def get_app(self):
        class HelloHandler(RequestHandler):
            def get(self):
                self.finish('Hello world')

            def post(self):
                self.finish('Hello world')

        class LargeHandler(RequestHandler):
            def get(self):
                # 512KB should be bigger than the socket buffers so it will
                # be written out in chunks.
                chunks = [chr(i % 256) * 1024 for i in range(512)]
                self.write(''.join(chunks))

        class FinishOnCloseHandler(RequestHandler):
            @asynchronous
            def get(self):
                self.flush()

            def on_connection_close(self):
                # This is not very realistic, but finishing the request
                # from the close callback has the right timing to mimic
                # some errors seen in the wild.
                self.finish('closed')

        routes = [
            ('/', HelloHandler),
            ('/large', LargeHandler),
            ('/finish_on_close', FinishOnCloseHandler),
        ]
        return Application(routes)

    def setUp(self):
        super(KeepAliveTest, self).setUp()
        self.http_version = b'HTTP/1.1'

    def tearDown(self):
        # We just closed the client side of the socket; let the IOLoop run
        # once to make sure the server side got the message.
        self.io_loop.add_timeout(datetime.timedelta(seconds=0.001), self.stop)
        self.wait()

        if hasattr(self, 'stream'):
            self.stream.close()
        super(KeepAliveTest, self).tearDown()

    # -- minimal hand-written HTTP client ---------------------------------

    def connect(self):
        self.stream = IOStream(socket.socket())
        address = ('127.0.0.1', self.get_http_port())
        self.stream.connect(address, self.stop)
        self.wait()

    def _read_until(self, delimiter):
        """Block (via self.wait) until ``delimiter`` arrives; return data."""
        self.stream.read_until(delimiter, self.stop)
        return self.wait()

    def read_headers(self):
        status = self._read_until(b'\r\n')
        self.assertTrue(status.startswith(b'HTTP/1.1 200'), status)
        raw_headers = self._read_until(b'\r\n\r\n')
        return HTTPHeaders.parse(raw_headers.decode('latin1'))

    def read_response(self):
        self.headers = self.read_headers()
        length = int(self.headers['Content-Length'])
        self.stream.read_bytes(length, self.stop)
        body = self.wait()
        self.assertEqual(b'Hello world', body)

    def close(self):
        self.stream.close()
        del self.stream

    def _send_and_read(self, request):
        """Write one raw request and consume its 'Hello world' response."""
        self.stream.write(request)
        self.read_response()

    def _read_to_eof(self):
        """Return everything left on the stream once the server closes it."""
        self.stream.read_until_close(callback=self.stop)
        return self.wait()

    # -- tests -------------------------------------------------------------

    def test_two_requests(self):
        self.connect()
        for _ in range(2):
            self._send_and_read(b'GET / HTTP/1.1\r\n\r\n')
        self.close()

    def test_request_close(self):
        self.connect()
        self._send_and_read(b'GET / HTTP/1.1\r\nConnection: close\r\n\r\n')
        leftover = self._read_to_eof()
        self.assertTrue(not leftover)
        self.assertEqual(self.headers['Connection'], 'close')
        self.close()

    # keepalive is supported for http 1.0 too, but it's opt-in
    def test_http10(self):
        self.http_version = b'HTTP/1.0'
        self.connect()
        self._send_and_read(b'GET / HTTP/1.0\r\n\r\n')
        leftover = self._read_to_eof()
        self.assertTrue(not leftover)
        self.assertTrue('Connection' not in self.headers)
        self.close()

    def test_http10_keepalive(self):
        self.http_version = b'HTTP/1.0'
        self.connect()
        keepalive = b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n'
        for _ in range(2):
            self._send_and_read(keepalive)
            self.assertEqual(self.headers['Connection'], 'Keep-Alive')
        self.close()

    def test_http10_keepalive_extra_crlf(self):
        self.http_version = b'HTTP/1.0'
        self.connect()
        requests = [
            b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n\r\n',
            b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n',
        ]
        for request in requests:
            self._send_and_read(request)
            self.assertEqual(self.headers['Connection'], 'Keep-Alive')
        self.close()

    def test_pipelined_requests(self):
        self.connect()
        self.stream.write(b'GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\n')
        for _ in range(2):
            self.read_response()
        self.close()

    def test_pipelined_cancel(self):
        self.connect()
        self.stream.write(b'GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\n')
        # only read once
        self.read_response()
        self.close()

    def test_cancel_during_download(self):
        self.connect()
        self.stream.write(b'GET /large HTTP/1.1\r\n\r\n')
        self.read_headers()
        self.stream.read_bytes(1024, self.stop)
        self.wait()
        self.close()

    def test_finish_while_closed(self):
        self.connect()
        self.stream.write(b'GET /finish_on_close HTTP/1.1\r\n\r\n')
        self.read_headers()
        self.close()

    def test_keepalive_chunked(self):
        self.http_version = b'HTTP/1.0'
        self.connect()
        self._send_and_read(b'POST / HTTP/1.0\r\n'
                            b'Connection: keep-alive\r\n'
                            b'Transfer-Encoding: chunked\r\n'
                            b'\r\n'
                            b'0\r\n'
                            b'\r\n')
        self.assertEqual(self.headers['Connection'], 'Keep-Alive')
        self._send_and_read(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n')
        self.assertEqual(self.headers['Connection'], 'Keep-Alive')
        self.close()
class SerialPortConnection:
    """Models a serial connection to a remote device over a Bluetooth RFCOMM
    link.

    Provides send and receive functionality (with proper parsing), and can
    track replies to certain requests.
    """

    CHLD_MAP = {
        0: "ReleaseAllHeldOrUDUB",
        1: "ReleaseAllActive,AcceptOther",
        2: "HoldAllActive,AcceptOther",
        3: "AddCallToConference",
        4: "JoinCalls,HangUp"
    }

    CME_ERROR_MAP = {
        0: "AG failure",
        1: "No connection to phone",
        3: "Operation not allowed",
        4: "Operation not supported",
        5: "PH-SIM PIN required",
        10: "SIM not inserted",
        11: "SIM PIN required",
        12: "SIM PUK required",
        13: "SIM failure",
        14: "SIM busy",
        16: "Incorrect password",
        17: "SIM PIN2 required",
        18: "SIM PUK2 required",
        20: "Memory full",
        21: "Invalid index",
        23: "Memory failure",
        24: "Text string too long",
        25: "Invalid text string",
        26: "Dial string too long",
        27: "Invalid dial string",
        30: "No network service",
        31: "Network timeout",
        32: "Emergency calls only"
    }

    def __init__(self, socket, device_path, async_reply_delay, io_loop):
        self._async_reply_delay = async_reply_delay
        self._io_loop = io_loop
        # socket.getpeername() returns different address
        # so use the end of the device path instead
        self._peer = device_path[-17:].replace("_", ":")
        # Trailing bytes of a message whose <cr><lf> has not arrived yet.
        self._remainder = b''
        # <code>: [{"future": <future>, "handle": <timeout handle>}, ...]
        # Entries are kept oldest first. Replies and timeouts consume them
        # in FIFO order.
        self._reply_q = defaultdict(list)
        # Indicator index -> name, populated by the first +CIND response.
        self._indmap = {}
        self._socket = socket
        self.on_close = None
        self.on_error = None
        self.on_message = None
        self._stream = IOStream(socket=self._socket)
        self._stream.set_close_callback(self._on_close)
        self._stream.read_until_close(streaming_callback=self._data_ready)

    @property
    def peer(self):
        """Returns the address of the remote device.
        """
        return self._peer

    def close(self):
        """Voluntarily closes the RFCOMM connection.

        Does nothing if the connection has already been closed.
        """
        if self._stream is not None:
            self._stream.close()

    def _async_timeout(self, code):
        """Called when an expected async reply doesn't arrive in the expected
        timeframe.

        Every request uses the same delay, so the timeout that fires belongs
        to the oldest outstanding request for ``code``.
        """
        queue = self._reply_q.get(code)
        if not queue:
            # Already answered or cleared by _on_close(); nothing to fail.
            return
        qentry = queue.pop(0)
        qentry["future"].set_exception(TimeoutError("Did not receive reply."))

    def _pop_reply(self, code):
        """Remove and return the oldest tracked request for ``code``.

        The entry's timeout is cancelled. Returns None if nothing is tracked.
        """
        queue = self._reply_q.get(code)
        if not queue:
            return None
        qentry = queue.pop(0)
        self._io_loop.remove_timeout(qentry["handle"])
        return qentry

    def _data_ready(self, data):
        """Parses data that has been received over the serial connection.
        """
        logger.debug("Received {} bytes from AG over SPC - {}".format(
            len(data), data))
        if len(self._remainder) > 0:
            data = self._remainder + data
            logger.debug("Appended left-over bytes - {}".format(
                self._remainder))

        while True:
            # all AG -> HF messages are <cr><lf> delimited
            try:
                msg, data = data.split(b'\x0d\x0a', 1)
            except ValueError:
                self._remainder = data
                return

            # decode to ASCII, logging but ignoring decode errors
            try:
                msg = msg.decode('ascii', errors='strict')
            except UnicodeDecodeError as e:
                logger.warning("ASCII decode error, going to ignore dodgy "
                               "characters - {}".format(e))
                msg = msg.decode('ascii', errors='ignore')

            try:
                if len(msg) > 0:
                    self._on_message(msg)
            except Exception:
                logger.exception("Message handler threw an unhandled "
                                 "exception with data \"{}\"".format(msg))

            if data == b'':
                self._remainder = b''
                return

    def _on_close(self, *args):
        """The connection was closed by either side.
        """
        self._stream = None
        self._remainder = b''
        logger.info("Serial port connection to AG was closed.")

        # Error out any remaining futures. Cancel their timeouts first so
        # _async_timeout() doesn't fire for entries that no longer exist.
        for lst in self._reply_q.values():
            for item in lst:
                self._io_loop.remove_timeout(item["handle"])
                item["future"].set_exception(
                    ConnectionError("Connection was closed."))
        self._reply_q.clear()

        if self.on_close:
            self.on_close()

    def _on_message(self, msg):
        """Invoked with a parsed message that we must now process.
        """
        if msg == "ERROR":
            # cleaner to report errors separately
            if self.on_error:
                self.on_error(None)
        elif msg == "OK":
            # simple ACK: resolve a tracked Future if there is one
            qentry = self._pop_reply("OK")
            if qentry:
                qentry["future"].set_result("OK")
            elif self.on_message:
                self.on_message(code="OK", data=None)
        elif msg == "RING":
            # ringing alert
            if self.on_message:
                self.on_message(code="RING", data=None)
        else:
            # strip leading "+" and split from first ":"
            # e.g. +BRSF: ...
            code, sep, params = msg[1:].partition(":")
            if not sep:
                logger.warning(
                    "Unrecognised message \"{}\", ignoring...".format(msg))
                return

            # shortcut to CME error reporting handler
            if code == "CME ERROR":
                if self.on_error:
                    self.on_error(self._handle_cme_error(params))
                return

            # find a handler function
            func_name = "_handle_{}".format(code.lower())
            try:
                handler = getattr(self, func_name)
            except AttributeError:
                logger.warning(
                    "No handler for code {}, ignoring...".format(code))
                return

            # get a Future if async tracking
            qentry = self._pop_reply(code)

            # execute handler (and deal with Future)
            try:
                ret = handler(params=params.strip())
            except Exception as e:
                logger.error(
                    "Handler threw unhandled exception - {}".format(e))
                if qentry:
                    qentry["future"].set_exception(e)
                return

            if qentry:
                qentry["future"].set_result(ret)
            # Listeners are notified even when a tracked Future was also
            # resolved for this reply.
            if self.on_message:
                self.on_message(code=code, data=ret)

    def _handle_brsf(self, params):
        """Supported features of the AG.
        """
        params = int(params)
        return {
            "3WAY": (params & 0x1) == 0x1,
            "ECNR": (params & 0x2) == 0x2,
            "VOICE_RECOGNITION": (params & 0x4) == 0x4,
            "INBAND_RING": (params & 0x8) == 0x8,
            "PHONE_VTAG": (params & 0x10) == 0x10,
            "CALL_REJECT": (params & 0x20) == 0x20,
            "ECALL_STAT": (params & 0x40) == 0x40,
            "ECALL_CTRL": (params & 0x80) == 0x80,
            "EXTD_ERROR": (params & 0x100) == 0x100,
            "CODEC_NEG": (params & 0x200) == 0x200,
            "HF_INDICATORS": (params & 0x400) == 0x400,
            "ESCO_S4T2": (params & 0x800) == 0x800
        }

    def _handle_chld(self, params):
        """Info about how 3way/call wait is handled.

        ``params`` looks like ``(0,1,1x,2,2x,3,4)``. Numeric values are
        mapped through CHLD_MAP. Other tokens (e.g. ``1x``) are returned
        unchanged.
        """
        values = []
        for token in params.strip().strip("()").split(","):
            token = token.strip()
            if not token:
                continue
            try:
                values.append(int(token))
            except ValueError:
                values.append(token)
        return [SerialPortConnection.CHLD_MAP.get(f, f) for f in values]

    def _handle_ciev(self, params):
        """Single indicator update.

        Indicator indices are 1-based, while ``_indmap`` is 0-based.
        Returns None for malformed or unknown indicators.
        """
        try:
            params = params.split(",")
            return {self._indmap[int(params[0]) - 1]: params[1]}
        except (IndexError, KeyError, ValueError):
            logger.debug("Unknown indicator, will ignore it.")

    def _handle_cind(self, params):
        """Indicators available by the AG. This class maps the indices to
        actual names to make it easier upstream.
        """
        # either initial indicator info...
        # ("call",(0,1)),("callsetup",(0-3)),("service",(0-1)),("signal",(0-5)),
        # ("roam",(0,1)),("battchg",(0-5)),("callheld",(0-2))
        if "(" in params:
            # Indicator names are the double-quoted strings. Ranges such as
            # (0-3) are not Python literals, so they are not evaluated.
            names = params.split('"')[1::2]
            self._indmap = dict(enumerate(names))
            return names

        # ...or initial indicator values
        # 0,0,1,4,0,3,0
        return dict([(self._indmap[i], val)
                     for i, val in enumerate(params.split(","))])

    def _handle_clip(self, params):
        """Contains phone number of calling party (if CLI enabled).
        """
        # "0383417060",129
        if "," in params:
            params = params[:params.index(",")]
        return params.replace("\"", "")

    def _handle_cme_error(self, params):
        """Extended error code.
        """
        return SerialPortConnection.CME_ERROR_MAP.get(int(params), params)

    def _handle_cops(self, params):
        """Network operator query response.
        """
        params = params.split(",")
        return {"mode": params[0], "name": params[2]}

    def _handle_ccwa(self, params):
        """Contains phone number of calling party in a call-waiting scenario
        (if CLI enabled).
        """
        # "0383417060",129
        if "," in params:
            params = params[:params.index(",")]
        return params.replace("\"", "")

    def send_message(self, message, async_reply_code=None):
        """Sends a message.

        If ``async_reply_code`` is not None, this returns a Future that can
        be yielded. The Future resolves when that reply code is next
        received. It errors out if no reply arrives within the delay
        (seconds) given in the constructor.

        :raises ConnectionError: if the message could not be written.
        """
        try:
            logger.debug("Sending \"{}\" over SPC.".format(message))
            data = message + "\x0a"
            self._stream.write(data.encode("ascii"))
        except Exception as e:
            logger.exception("Error sending \"{}\" over SPC.".format(message))
            raise ConnectionError(
                "Error sending \"{}\" over SPC.".format(message)) from e

        # async tracking?
        if async_reply_code:
            queue = self._reply_q[async_reply_code]
            fut = Future()
            handle = self._io_loop.call_later(delay=self._async_reply_delay,
                                              callback=self._async_timeout,
                                              code=async_reply_code)
            queue.append({"future": fut, "handle": handle})
            return fut
        return None