def test_100_continue(self):
    """Walk through an ``Expect: 100-continue`` exchange by hand.

    After the request headers the server must answer with an interim
    ``100`` response; the real ``200`` response only follows once the
    request body has been sent.
    """
    stream = IOStream(socket.socket())
    stream.connect(("127.0.0.1", self.get_http_port()), callback=self.stop)
    self.wait()

    request_head = b"\r\n".join([
        b"POST /hello HTTP/1.1",
        b"Content-Length: 1024",
        b"Expect: 100-continue",
        b"Connection: close",
        b"\r\n",
    ])
    stream.write(request_head, callback=self.stop)
    self.wait()

    # Interim response: must arrive before any body bytes are sent.
    stream.read_until(b"\r\n\r\n", self.stop)
    interim = self.wait()
    self.assertTrue(interim.startswith(b"HTTP/1.1 100 "), interim)

    stream.write(b"a" * 1024)

    stream.read_until(b"\r\n", self.stop)
    status_line = self.wait()
    self.assertTrue(status_line.startswith(b"HTTP/1.1 200"), status_line)

    stream.read_until(b"\r\n\r\n", self.stop)
    raw_headers = self.wait()
    headers = HTTPHeaders.parse(native_str(raw_headers.decode('latin1')))

    stream.read_bytes(int(headers["Content-Length"]), self.stop)
    payload = self.wait()
    self.assertEqual(payload, b"Got 1024 bytes in POST")
    stream.close()
def handle_connection(connection, address):
    """Wrap an accepted socket in an IOStream and start reading its uuid header.

    The first 4 bytes (the uuid length, per the log message below) are
    handed to ``read_uuid_size`` together with the stream.
    """
    # Lazy %-args: logging formats the tuple address the same way as
    # '%s' % str(address), but only if INFO is enabled.
    log.info('Connection received from %s', address)
    stream = IOStream(connection, ioloop, max_buffer_size=1024 * 1024 * 1024)
    # Getting uuid
    try:
        stream.read_bytes(4, partial(read_uuid_size, stream))
    except StreamClosedError:
        # Logger.warn is a deprecated alias; use Logger.warning.
        log.warning('Closed stream for getting uuid length')
def handle_connection(connection, address):
    """Wrap an accepted socket in an IOStream and start reading its uuid header.

    The first 4 bytes (the uuid length, per the log message below) are
    handed to ``read_uuid_size`` together with the stream.
    """
    # Lazy %-args: logging formats the tuple address the same way as
    # "%s" % str(address), but only if INFO is enabled.
    log.info("Connection received from %s", address)
    stream = IOStream(connection, ioloop, max_buffer_size=1024 * 1024 * 1024)
    # Getting uuid
    try:
        stream.read_bytes(4, partial(read_uuid_size, stream))
    except StreamClosedError:
        # Logger.warn is a deprecated alias; use Logger.warning.
        log.warning("Closed stream for getting uuid length")
def handle_connection(connection, address):
    """Wrap an accepted socket in an IOStream and start reading its uuid header.

    The first 4 bytes (the uuid length, per the log message below) are
    handed to ``read_uuid_size`` together with the stream.
    """
    # Lazy %-args: logging formats the tuple address the same way as
    # '%s' % str(address), but only if INFO is enabled.
    log.info('Connection received from %s', address)
    stream = IOStream(connection, ioloop)
    # Getting uuid
    try:
        stream.read_bytes(4, partial(read_uuid_size, stream))
    except StreamClosedError:
        # Logger.warn is a deprecated alias; use Logger.warning.
        log.warning('Closed stream for getting uuid length')
class Client(object):
    """TCP client that frames traffic through a pluggable protocol object.

    The only calls made on ``protocol`` here are ``head_size()``,
    ``handle_head(session, head)`` (returning ``(receive_bytes, body_size)``),
    ``handle(session, body)`` and ``encode(session, data)``.
    """

    def __init__(self, host, port, protocol, name, session, dial=True):
        self._host = host
        self._port = port
        self._protocol = protocol
        self._name = name
        self._session = session
        self._running = False
        self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        self._connection = IOStream(self._socket)
        if dial:
            self._connection.connect((host, port), self.on_dial)

    def dial(self):
        """Connect to the configured address and start the read loop.

        Does nothing if ``on_dial`` has already marked the client running.
        """
        if self._running:
            return
        self._connection.connect((self._host, self._port), self.on_dial)
        # The original called self._read_loop(), which does not exist and
        # raised AttributeError; the loop method is read_loop().
        self.read_loop()

    def on_dial(self):
        """Connection-established hook; marks the client as running.

        Subclasses may override it, for example::

            def on_dial(self):
                self.send(msg)
                self.read_loop()

        :return: None
        """
        self._running = True

    def read_loop(self):
        """Read one frame header; ``_handle_head`` keeps the loop going."""
        self._connection.read_bytes(self._protocol.head_size(),
                                    callback=self._handle_head)

    def _handle_head(self, head):
        """Decode a frame header, then read its body (if any) and loop."""
        receive_bytes, body_size = self._protocol.handle_head(
            self._session, head)
        self._session.receive_bytes = receive_bytes

        def handler(body):
            self._protocol.handle(self._session, body)
            self.read_loop()

        if body_size > 0:
            self._connection.read_bytes(body_size, callback=handler)
        else:
            # Header-only frame: nothing to read, but the loop must continue
            # (previously it silently stopped reading here).
            self.read_loop()

    def send_raw(self, data):
        """Encode ``data`` with the protocol and queue it for sending."""
        # IOStream.write buffers and retries; write_to_fd is a low-level
        # single send() that can drop the unsent tail or raise EAGAIN.
        return self._connection.write(
            self._protocol.encode(self._session, data))

    def send(self, msg):
        """Serialize ``msg`` via ``msg.encode()``, frame it and queue it."""
        raw_data = self._protocol.encode(self._session, msg.encode())
        return self._connection.write(raw_data)
class Flash(object):
    """TCP client that reads framed commands from a stream and executes them.

    Each frame starts with a header unpacked as ``>BH`` (1-byte command type,
    big-endian 2-byte payload length), followed by ``length`` payload bytes.
    """

    def __init__(self, close_callback=None):
        self._iostream = None
        self._close_callback = close_callback

    def connect(self, host='127.0.0.1', port=9999):
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self._iostream = IOStream(sock)
        self._iostream.set_close_callback(self._on_connection_close)
        # Connect, then start listening for commands.
        self._iostream.connect((host, port), self._read_head)

    def close(self):
        self._on_connection_close()

    def _on_connection_close(self):
        # Reached from close() and from the stream's close callback, which
        # fires again once we close the stream. Guard against a missing
        # stream and notify the owner only once.
        if self._iostream is not None:
            self._iostream.close()
        callback, self._close_callback = self._close_callback, None
        if callback:
            callback()

    def _read_head(self):
        self._iostream.read_bytes(BaseCommand.meta_size, self._on_read_head)

    def _on_read_head(self, data):
        # struct.unpack requires exactly 3 bytes for ">BH".
        ctype, length = struct.unpack(">BH", data)
        if length:
            self._iostream.read_bytes(length,
                                      partial(self.execute_command, ctype))
        else:
            self.execute_command(ctype)

    def execute_command(self, ctype, value=None):
        """Run the registered command for ``ctype`` (if any) and read the next header."""
        command = CommandsRegistry.get_by_type(ctype)
        if command is not None:
            command.execute(value)
        self._read_head()

    @classmethod
    def start(cls, host, port):
        """Connect and run the IOLoop until the connection closes or SIGINT."""
        flash = cls(close_callback=IOLoop.instance().stop)
        flash.connect(host, port)
        # Signal handlers are called as handler(signum, frame); Flash.close
        # takes no arguments, so passing it directly raised TypeError.
        # Defer the actual close to the IOLoop, which is the safe way to
        # touch it from a signal handler.
        signal.signal(
            signal.SIGINT,
            lambda signum, frame:
                IOLoop.instance().add_callback_from_signal(flash.close))
        IOLoop.instance().start()
        IOLoop.instance().close()
class UnixSocketTest(AsyncTestCase):
    """HTTPServer must also be able to listen on a Unix domain socket.

    This is useful because proxies such as nginx can forward to backends
    on Unix sockets, and a namespace of socket files can be easier to
    manage than many TCP port numbers. A URL cannot name a Unix socket,
    so the HTTP client is not usable here and requests are written by hand.
    """

    def setUp(self):
        super(UnixSocketTest, self).setUp()
        self.tmpdir = tempfile.mkdtemp()
        self.sockfile = os.path.join(self.tmpdir, "test.sock")
        listener = netutil.bind_unix_socket(self.sockfile)
        routes = [("/hello", HelloWorldRequestHandler)]
        self.server = HTTPServer(Application(routes))
        self.server.add_socket(listener)
        client_sock = socket.socket(socket.AF_UNIX)
        self.stream = IOStream(client_sock)
        self.stream.connect(self.sockfile, self.stop)
        self.wait()

    def tearDown(self):
        self.stream.close()
        self.io_loop.run_sync(self.server.close_all_connections)
        self.server.stop()
        shutil.rmtree(self.tmpdir)
        super(UnixSocketTest, self).tearDown()

    def test_unix_socket(self):
        self.stream.write(b"GET /hello HTTP/1.0\r\n\r\n")
        self.stream.read_until(b"\r\n", self.stop)
        status_line = self.wait()
        self.assertEqual(status_line, b"HTTP/1.1 200 OK\r\n")
        self.stream.read_until(b"\r\n\r\n", self.stop)
        header_text = self.wait().decode('latin1')
        headers = HTTPHeaders.parse(header_text)
        body_length = int(headers["Content-Length"])
        self.stream.read_bytes(body_length, self.stop)
        payload = self.wait()
        self.assertEqual(payload, b"Hello world")

    def test_unix_socket_bad_request(self):
        # A Unix socket has no remote address, so the peer shown in the
        # log message is just an empty string.
        with ExpectLog(gen_log, "Malformed HTTP message from"):
            self.stream.write(b"garbage\r\n\r\n")
            self.stream.read_until_close(self.stop)
            reply = self.wait()
        self.assertEqual(reply, b"HTTP/1.1 400 Bad Request\r\n\r\n")
class UnixSocketTest(AsyncTestCase):
    """HTTPServer must also be able to listen on a Unix domain socket.

    This is useful because proxies such as nginx can forward to backends
    on Unix sockets, and a namespace of socket files can be easier to
    manage than many TCP port numbers. A URL cannot name a Unix socket,
    so the HTTP client is not usable here and requests are written by hand.
    """

    def setUp(self):
        super(UnixSocketTest, self).setUp()
        self.tmpdir = tempfile.mkdtemp()
        self.sockfile = os.path.join(self.tmpdir, "test.sock")
        listener = netutil.bind_unix_socket(self.sockfile)
        routes = [("/hello", HelloWorldRequestHandler)]
        self.server = HTTPServer(Application(routes), io_loop=self.io_loop)
        self.server.add_socket(listener)
        client_sock = socket.socket(socket.AF_UNIX)
        self.stream = IOStream(client_sock, io_loop=self.io_loop)
        self.stream.connect(self.sockfile, self.stop)
        self.wait()

    def tearDown(self):
        self.stream.close()
        self.server.stop()
        shutil.rmtree(self.tmpdir)
        super(UnixSocketTest, self).tearDown()

    def test_unix_socket(self):
        self.stream.write(b"GET /hello HTTP/1.0\r\n\r\n")
        self.stream.read_until(b"\r\n", self.stop)
        status_line = self.wait()
        self.assertEqual(status_line, b"HTTP/1.0 200 OK\r\n")
        self.stream.read_until(b"\r\n\r\n", self.stop)
        header_text = self.wait().decode('latin1')
        headers = HTTPHeaders.parse(header_text)
        body_length = int(headers["Content-Length"])
        self.stream.read_bytes(body_length, self.stop)
        payload = self.wait()
        self.assertEqual(payload, b"Hello world")

    def test_unix_socket_bad_request(self):
        # A Unix socket has no remote address, so the peer shown in the
        # log message is just an empty string.
        with ExpectLog(gen_log, "Malformed HTTP message from"):
            self.stream.write(b"garbage\r\n\r\n")
            self.stream.read_until_close(self.stop)
            reply = self.wait()
        self.assertEqual(reply, b"")
def test_100_continue(self):
    """Walk through an ``Expect: 100-continue`` exchange by hand.

    The server must send an interim ``100`` response right after the
    request headers, and the final response only once the body arrives.
    """
    stream = IOStream(socket.socket())
    yield stream.connect(("127.0.0.1", self.get_http_port()))

    head = b"\r\n".join([
        b"POST /hello HTTP/1.1",
        b"Content-Length: 1024",
        b"Expect: 100-continue",
        b"Connection: close",
        b"\r\n",
    ])
    yield stream.write(head)

    interim = yield stream.read_until(b"\r\n\r\n")
    self.assertTrue(interim.startswith(b"HTTP/1.1 100 "), interim)

    # Only now send the body the server agreed to accept.
    stream.write(b"a" * 1024)

    status_line = yield stream.read_until(b"\r\n")
    self.assertTrue(status_line.startswith(b"HTTP/1.1 200"), status_line)

    raw_headers = yield stream.read_until(b"\r\n\r\n")
    headers = HTTPHeaders.parse(native_str(raw_headers.decode("latin1")))

    payload = yield stream.read_bytes(int(headers["Content-Length"]))
    self.assertEqual(payload, b"Got 1024 bytes in POST")
    stream.close()
def test_message_response(self):
    """Send ``msg1`` to a StatusServer and check its 4-byte reply.

    A NotifyServer is started on a second unused port and attached to the
    StatusServer as ``notify_server``. Both servers and the client are
    shut down in ``finally`` so a failure cannot leak listening sockets.
    """
    server = notify_server = client = None
    try:
        sock, port = bind_unused_port()
        # The notify server's port is never addressed directly here.
        notify_sock, _ = bind_unused_port()
        with NullContext():
            server = StatusServer()
            notify_server = NotifyServer()
            notify_server.add_socket(notify_sock)
            server.notify_server = notify_server
            server.add_socket(sock)
        client = IOStream(socket.socket())
        yield client.connect(('localhost', port))
        yield client.write(msg1)
        results = yield client.read_bytes(4)
        # assertEqual (unlike bare assert) survives python -O and reports
        # both values on failure.
        self.assertEqual(results, b'\x11\x00\x01\x10')
    finally:
        if server is not None:
            server.stop()
        # Previously never stopped, leaking its listening socket.
        if notify_server is not None:
            notify_server.stop()
        if client is not None:
            client.close()
async def handle_stream(self, stream: IOStream, address):
    """Accumulate frames from *stream*, decode them off-loop and ACK/NAK each.

    Bytes are read (up to 12 per read) until at least 24 are buffered. The
    buffer is then handed to ``self.wrappedDecode`` in the default executor,
    which is expected to put a status into the queue (assumption; confirm
    against wrappedDecode). ``0x3e`` is written back when the status equals
    ``self.DECODE_SUC``, ``0x6c`` otherwise. If no data arrives within the
    read timeout, the partial frame is discarded.
    """
    host, port = address[0], address[1]
    print("connect from {0:s}:{1:d}".format(host, port))
    loop = IOLoop.current()  # type: IOLoop
    read_timeout = timedelta(seconds=12)
    frame_buffer = b''
    status_queue = Queue(maxsize=10)
    # gen.with_timeout does not cancel the read it wraps, so after a timeout
    # that read is still pending on the stream. Keep awaiting it rather than
    # skipping the read while stream.reading() is True; the original
    # version then spun forever without awaiting and froze the IOLoop.
    pending_read = None
    while True:
        try:
            if pending_read is None:
                pending_read = stream.read_bytes(12, partial=True)
            chunk = await gen.with_timeout(read_timeout, pending_read)
            pending_read = None
            frame_buffer += chunk
            print("CurrentBuffer:", frame_buffer)
            if len(frame_buffer) < 24:
                continue
            # Await the executor job instead of blocking the event loop in a
            # blocking Queue.get(). Once the job has finished, its status is
            # already queued. An exception in wrappedDecode now propagates
            # instead of leaving get() blocked forever.
            await loop.run_in_executor(
                None, partial(self.wrappedDecode, frame_buffer, status_queue))
            status = status_queue.get_nowait()
            frame_buffer = b''
            if status == self.DECODE_SUC:
                await stream.write(bytes([0x3e]))
            else:
                await stream.write(bytes([0x6c]))
        except StreamClosedError:
            print("connection closed from {0:s}:{1:d}".format(host, port))
            break
        except gen.TimeoutError:
            frame_buffer = b''
            # The message previously claimed 3 seconds; report the real value.
            print("No response in {0:g} seconds {1:s}:{2:d}".format(
                read_timeout.total_seconds(), host, port))
def handle_connection(self, connection, address):
    """Read one message (at most 20 bytes) from a new connection and print it.

    ``partial=True`` resolves the read with whatever has arrived, up to 20
    bytes. The stream is closed afterwards because nothing else holds a
    reference to it, and leaving it open would leak the socket.
    """
    stream = IOStream(connection)
    print("start handle request...")
    try:
        message = yield stream.read_bytes(20, partial=True)
        # Client bytes are untrusted: invalid UTF-8 must not crash the handler.
        print("message from client:",
              message.decode("utf-8", "replace").strip())
    finally:
        stream.close()
def test_unix_socket(self):
    """Serve HTTP over a Unix domain socket and issue a request by hand.

    The client stream and server are torn down in ``finally`` (as the
    fixture-based UnixSocketTest does in tearDown), so a failing assertion
    does not leave the listener or connection open.
    """
    sockfile = os.path.join(self.tmpdir, "test.sock")
    sock = netutil.bind_unix_socket(sockfile)
    app = Application([("/hello", HelloWorldRequestHandler)])
    server = HTTPServer(app, io_loop=self.io_loop)
    server.add_socket(sock)
    stream = IOStream(socket.socket(socket.AF_UNIX), io_loop=self.io_loop)
    try:
        stream.connect(sockfile, self.stop)
        self.wait()
        stream.write(b("GET /hello HTTP/1.0\r\n\r\n"))
        stream.read_until(b("\r\n"), self.stop)
        response = self.wait()
        self.assertEqual(response, b("HTTP/1.0 200 OK\r\n"))
        stream.read_until(b("\r\n\r\n"), self.stop)
        headers = HTTPHeaders.parse(self.wait().decode('latin1'))
        stream.read_bytes(int(headers["Content-Length"]), self.stop)
        body = self.wait()
        self.assertEqual(body, b("Hello world"))
    finally:
        stream.close()
        server.stop()
class TestIOStream(AsyncHTTPTestCase, LogTrapTestCase):
    def get_app(self):
        return Application([('/', HelloHandler)])

    def _raw_stream_to_server(self):
        """Open a blocking TCP connection to the test server, wrapped in an IOStream."""
        conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        conn.connect(("localhost", self.get_http_port()))
        return IOStream(conn, io_loop=self.io_loop)

    def test_read_zero_bytes(self):
        self.stream = self._raw_stream_to_server()
        self.stream.write(b("GET / HTTP/1.0\r\n\r\n"))
        # A normal read, a zero-byte read, then another normal read: the
        # empty read must complete without consuming any data.
        steps = ((9, b("HTTP/1.0 ")), (0, b("")), (3, b("200")))
        for size, expected in steps:
            self.stream.read_bytes(size, self.stop)
            self.assertEqual(self.wait(), expected)

    def test_connection_refused(self):
        # A refused connection must not run the connect callback. (The
        # kqueue IOLoop once behaved differently from epoll here.)
        port = get_unused_port()
        stream = IOStream(socket.socket(), self.io_loop)
        self.connect_called = False
        stream.set_close_callback(self.stop)
        stream.connect(("localhost", port),
                       lambda: setattr(self, "connect_called", True))
        self.wait()
        self.assertFalse(self.connect_called)

    def test_connection_closed(self):
        # If the server replies and then closes, the client must still be
        # able to read the reply before its IOStream closes itself. epoll
        # reports the close as a separate EPOLLRDHUP event alongside the
        # read event; kqueue reports it as a later read/write event with an
        # EOF flag.
        self.fetch("/", headers={"Connection": "close"}).rethrow()

    def test_read_until_close(self):
        stream = self._raw_stream_to_server()
        stream.write(b("GET / HTTP/1.0\r\n\r\n"))
        stream.read_until_close(self.stop)
        everything = self.wait()
        self.assertTrue(everything.startswith(b("HTTP/1.0 200")))
        self.assertTrue(everything.endswith(b("Hello")))
class ESME(DeliverMixin, BaseESME):
    """ESME client driven by a Tornado IOStream over TCP.

    Incoming bytes are passed to ``self.feed``. ``wait_for`` returns a
    future resolved with ``resp.response`` when ``response.callback`` fires,
    running the read loop if it is not already running.
    """

    def __init__(self, **kwargs):
        BaseESME.__init__(self, **kwargs)
        self.running = False
        self.closed = False

    @coroutine
    def connect(self, host, port):
        """Open a TCP connection to ``(host, port)`` wrapped in an IOStream."""
        tcp_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        self.ioloop = IOLoop.current()
        self.stream = IOStream(tcp_sock)
        yield self.stream.connect((host, port))

    def on_send(self, data):
        return self.stream.write(data)

    def on_close(self):
        self.closed = True
        self.stream.close()

    @coroutine
    def readloop(self, future):
        """Feed chunks until closed, the stream ends, or *future* is done."""
        def keep_going():
            if self.closed:
                return False
            return not future or not future.done()

        while keep_going():
            try:
                chunk = yield self.stream.read_bytes(1024, partial=True)
            except StreamClosedError:  # pragma: no cover
                break
            self.feed(chunk)

    def wait_for(self, response):
        pending = Future()

        def resolve(resp):
            pending.set_result(resp.response)

        response.callback = resolve
        return pending if self.running else self.run(pending)

    @coroutine
    def run(self, future=None):
        """Run the read loop; if *future* completed, return its result."""
        self.running = True
        try:
            yield self.readloop(future)
        finally:
            self.running = False
        if not (future and future.done()):
            return
        raise Return(future.result())
class TestIOStream(AsyncHTTPTestCase, LogTrapTestCase):
    """IOStream behaviour against a live HTTP test server.

    IOStream writes and returns bytes, so every literal written or compared
    here is a bytes literal. Plain ``str`` literals break on Python 3: the
    write is rejected and the comparisons never match.
    """

    def get_app(self):
        return Application([('/', HelloHandler)])

    def test_read_zero_bytes(self):
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        s.connect(("localhost", self.get_http_port()))
        self.stream = IOStream(s, io_loop=self.io_loop)
        self.stream.write(b"GET / HTTP/1.0\r\n\r\n")

        # normal read
        self.stream.read_bytes(9, self.stop)
        data = self.wait()
        self.assertEqual(data, b"HTTP/1.0 ")

        # zero bytes
        self.stream.read_bytes(0, self.stop)
        data = self.wait()
        self.assertEqual(data, b"")

        # another normal read
        self.stream.read_bytes(3, self.stop)
        data = self.wait()
        self.assertEqual(data, b"200")

    def test_connection_refused(self):
        # When a connection is refused, the connect callback should not
        # be run.  (The kqueue IOLoop used to behave differently from the
        # epoll IOLoop in this respect)
        port = get_unused_port()
        stream = IOStream(socket.socket(), self.io_loop)
        self.connect_called = False

        def connect_callback():
            self.connect_called = True
        stream.set_close_callback(self.stop)
        stream.connect(("localhost", port), connect_callback)
        self.wait()
        self.assertFalse(self.connect_called)
class TestIOStream(AsyncHTTPTestCase, LogTrapTestCase):
    """IOStream behaviour against a live HTTP test server.

    IOStream writes and returns bytes, so every literal written or compared
    here is a bytes literal. Plain ``str`` literals break on Python 3.
    """

    def get_app(self):
        return Application([('/', HelloHandler)])

    def test_read_zero_bytes(self):
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        s.connect(("localhost", self.get_http_port()))
        self.stream = IOStream(s, io_loop=self.io_loop)
        self.stream.write(b"GET / HTTP/1.0\r\n\r\n")

        # normal read
        self.stream.read_bytes(9, self.stop)
        data = self.wait()
        self.assertEqual(data, b"HTTP/1.0 ")

        # zero bytes
        self.stream.read_bytes(0, self.stop)
        data = self.wait()
        self.assertEqual(data, b"")

        # another normal read
        self.stream.read_bytes(3, self.stop)
        data = self.wait()
        self.assertEqual(data, b"200")
class Connection(object):
    """A single connection (presumably to Redis, default port 6379) over an IOStream.

    Commands waiting to run are queued in ``ready_callbacks``. They are
    released one at a time only when ``read_callbacks`` (outstanding reads)
    is empty.
    """

    def __init__(self, host='localhost', port=6379, unix_socket_path=None,
                 event_handler_proxy=None, stop_after=None, io_loop=None):
        self.host = host
        self.port = port
        self.unix_socket_path = unix_socket_path
        self._event_handler = event_handler_proxy
        # stop_after is used as the blocking connect timeout (seconds).
        self.timeout = stop_after
        self._stream = None
        self._io_loop = io_loop

        self.in_progress = False
        self.read_callbacks = set()
        self.ready_callbacks = deque()
        self._lock = 0
        self.info = {'db': 0, 'pass': None}

    def __del__(self):
        self.disconnect()

    def execute_pending_command(self):
        # Continue with the pending command execution
        # if all read operations are completed.
        if not self.read_callbacks and self.ready_callbacks:
            # Pop a SINGLE callback from the queue and execute it.
            # The next one will be executed from the code
            # invoked by the callback
            callback = self.ready_callbacks.popleft()
            callback()

    def ready(self):
        return (not self.read_callbacks and not self.ready_callbacks)

    def wait_until_ready(self, callback=None):
        """Run ``callback`` now if idle, otherwise queue it for later."""
        if callback:
            if not self.ready():
                callback = stack_context.wrap(callback)
                self.ready_callbacks.append(callback)
            else:
                callback()

    def connect(self):
        """Open the socket (unix or TCP) if needed and fire ``on_connect``.

        Raises ConnectionError when the socket cannot be connected.
        """
        if not self._stream:
            sock = None
            try:
                if self.unix_socket_path:
                    sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
                    sock.settimeout(self.timeout)
                    sock.connect(self.unix_socket_path)
                else:
                    sock = socket.create_connection(
                        (self.host, self.port),
                        timeout=self.timeout
                    )
                    # IPPROTO_TCP is the portable spelling; SOL_TCP is
                    # missing on some platforms (e.g. Windows).
                    sock.setsockopt(socket.IPPROTO_TCP,
                                    socket.TCP_NODELAY, 1)
                self._stream = IOStream(sock)
                self._stream.set_close_callback(self.on_stream_close)
                self.info['db'] = 0
                self.info['pass'] = None
            except socket.error as e:
                # Don't leak the socket when connect()/setsockopt() fails.
                if sock is not None and self._stream is None:
                    sock.close()
                raise ConnectionError(str(e))
            self.fire_event('on_connect')

    def on_stream_close(self):
        """Stream close callback: drop the stream and flush pending reads."""
        if self._stream:
            self.disconnect()
            callbacks = self.read_callbacks
            self.read_callbacks = set()
            for callback in callbacks:
                callback()

    def disconnect(self):
        """Best-effort shutdown and close of the current stream (if any)."""
        if self._stream:
            s = self._stream
            self._stream = None
            # shutdown() commonly fails (e.g. ENOTCONN after a peer reset);
            # that must not prevent the stream itself from being closed.
            try:
                if s.socket:
                    s.socket.shutdown(socket.SHUT_RDWR)
            except socket.error:
                pass
            try:
                s.close()
            except Exception:
                pass

    def fire_event(self, event):
        """Call ``event`` on the event handler proxy, if it defines it."""
        event_handler = self._event_handler
        if event_handler:
            # Look the hook up explicitly so AttributeErrors raised *inside*
            # the handler are no longer silently swallowed.
            handler = getattr(event_handler, event, None)
            if handler is not None:
                handler()

    def write(self, data, callback=None):
        """Write ``data``; ``callback(None)`` runs once the write completes.

        Raises ConnectionError if there is no stream or the write fails.
        """
        if not self._stream:
            raise ConnectionError('Tried to write to '
                                  'non-existent connection')

        if callback:
            callback = stack_context.wrap(callback)
            _callback = lambda: callback(None)
            self.read_callbacks.add(_callback)
            cb = partial(self.read_callback, _callback)
        else:
            cb = None
        try:
            self._stream.write(data, callback=cb)
        except IOError as e:
            self.disconnect()
            # IOError has no .message attribute on Python 3.
            raise ConnectionError(str(e))

    def read(self, length, callback=None):
        """Read exactly ``length`` bytes and pass them to ``callback``."""
        try:
            if not self._stream:
                self.disconnect()
                raise ConnectionError('Tried to read from '
                                      'non-existent connection')
            callback = stack_context.wrap(callback)
            self.read_callbacks.add(callback)
            self._stream.read_bytes(length,
                                    callback=partial(self.read_callback,
                                                     callback))
        except IOError:
            self.fire_event('on_disconnect')

    def read_callback(self, callback, *args, **kwargs):
        """Untrack ``callback`` as an outstanding read, then invoke it."""
        try:
            self.read_callbacks.remove(callback)
        except KeyError:
            pass
        callback(*args, **kwargs)

    def readline(self, callback=None):
        """Read up to and including the next CRLF and pass it to ``callback``."""
        try:
            if not self._stream:
                self.disconnect()
                raise ConnectionError('Tried to read from '
                                      'non-existent connection')
            callback = stack_context.wrap(callback)
            self.read_callbacks.add(callback)
            callback = partial(self.read_callback, callback)
            self._stream.read_until(CRLF, callback=callback)
        except IOError:
            self.fire_event('on_disconnect')

    def connected(self):
        if self._stream:
            return True
        return False
class TestIOStream(AsyncHTTPTestCase, LogTrapTestCase):
    """IOStream tests against an HTTP test server and loopback stream pairs."""

    def get_app(self):
        return Application([('/', HelloHandler)])

    def make_iostream_pair(self):
        """Return ``[server_side, client_side]`` IOStreams connected over 127.0.0.1.

        Runs the IOLoop until both the accept and the connect callbacks
        have fired, then removes and closes the temporary listener.
        """
        port = get_unused_port()
        [listener] = netutil.bind_sockets(port, '127.0.0.1',
                                          family=socket.AF_INET)
        # streams[0] is filled by the accept side, streams[1] by the
        # connect side; the wait below ends once both are set.
        streams = [None, None]

        def accept_callback(connection, address):
            streams[0] = IOStream(connection, io_loop=self.io_loop)
            self.stop()

        def connect_callback():
            # client_stream is bound below, before connect() can complete.
            streams[1] = client_stream
            self.stop()
        netutil.add_accept_handler(listener, accept_callback,
                                   io_loop=self.io_loop)
        client_stream = IOStream(socket.socket(), io_loop=self.io_loop)
        client_stream.connect(('127.0.0.1', port),
                              callback=connect_callback)
        self.wait(condition=lambda: all(streams))
        self.io_loop.remove_handler(listener.fileno())
        listener.close()
        return streams

    def test_read_zero_bytes(self):
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        s.connect(("localhost", self.get_http_port()))
        self.stream = IOStream(s, io_loop=self.io_loop)
        self.stream.write(b("GET / HTTP/1.0\r\n\r\n"))

        # normal read
        self.stream.read_bytes(9, self.stop)
        data = self.wait()
        self.assertEqual(data, b("HTTP/1.0 "))

        # zero bytes
        self.stream.read_bytes(0, self.stop)
        data = self.wait()
        self.assertEqual(data, b(""))

        # another normal read
        self.stream.read_bytes(3, self.stop)
        data = self.wait()
        self.assertEqual(data, b("200"))

    def test_connection_refused(self):
        # When a connection is refused, the connect callback should not
        # be run.  (The kqueue IOLoop used to behave differently from the
        # epoll IOLoop in this respect)
        port = get_unused_port()
        stream = IOStream(socket.socket(), self.io_loop)
        self.connect_called = False

        def connect_callback():
            self.connect_called = True
        # The close callback (not the connect callback) ends the wait.
        stream.set_close_callback(self.stop)
        stream.connect(("localhost", port), connect_callback)
        self.wait()
        self.assertFalse(self.connect_called)

    def test_connection_closed(self):
        # When a server sends a response and then closes the connection,
        # the client must be allowed to read the data before the IOStream
        # closes itself.  Epoll reports closed connections with a separate
        # EPOLLRDHUP event delivered at the same time as the read event,
        # while kqueue reports them as a second read/write event with an EOF
        # flag.
        response = self.fetch("/", headers={"Connection": "close"})
        response.rethrow()

    def test_read_until_close(self):
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        s.connect(("localhost", self.get_http_port()))
        stream = IOStream(s, io_loop=self.io_loop)
        stream.write(b("GET / HTTP/1.0\r\n\r\n"))

        stream.read_until_close(self.stop)
        data = self.wait()
        self.assertTrue(data.startswith(b("HTTP/1.0 200")))
        self.assertTrue(data.endswith(b("Hello")))

    def test_streaming_callback(self):
        # read_bytes(6) with a streaming_callback: data is delivered to the
        # streaming callback as it arrives (capped at the 6 requested bytes),
        # and the final callback receives empty data.
        server, client = self.make_iostream_pair()
        try:
            chunks = []
            final_called = []

            def streaming_callback(data):
                chunks.append(data)
                self.stop()

            def final_callback(data):
                assert not data
                final_called.append(True)
                self.stop()
            server.read_bytes(6, callback=final_callback,
                              streaming_callback=streaming_callback)
            client.write(b("1234"))
            self.wait(condition=lambda: chunks)
            client.write(b("5678"))
            self.wait(condition=lambda: final_called)
            self.assertEqual(chunks, [b("1234"), b("56")])

            # the rest of the last chunk is still in the buffer
            server.read_bytes(2, callback=self.stop)
            data = self.wait()
            self.assertEqual(data, b("78"))
        finally:
            server.close()
            client.close()

    def test_streaming_until_close(self):
        # read_until_close with the same function as both callbacks: each
        # write arrives as a streamed chunk, and the final callback runs
        # with b("") once the server closes.
        server, client = self.make_iostream_pair()
        try:
            chunks = []

            def callback(data):
                chunks.append(data)
                self.stop()
            client.read_until_close(callback=callback,
                                    streaming_callback=callback)
            server.write(b("1234"))
            self.wait()
            server.write(b("5678"))
            self.wait()
            server.close()
            self.wait()
            self.assertEqual(chunks, [b("1234"), b("5678"), b("")])
        finally:
            server.close()
            client.close()

    def test_delayed_close_callback(self):
        # The scenario:  Server closes the connection while there is a pending
        # read that can be served out of buffered data.  The client does not
        # run the close_callback as soon as it detects the close, but rather
        # defers it until after the buffered read has finished.
        server, client = self.make_iostream_pair()
        try:
            client.set_close_callback(self.stop)
            server.write(b("12"))
            chunks = []

            def callback1(data):
                chunks.append(data)
                # Queue the second read, then close the server side while
                # that read can still be served from buffered data.
                client.read_bytes(1, callback2)
                server.close()

            def callback2(data):
                chunks.append(data)
            client.read_bytes(1, callback1)
            self.wait()  # stopped by close_callback
            self.assertEqual(chunks, [b("1"), b("2")])
        finally:
            server.close()
            client.close()
class _HTTPConnection(object): _SUPPORTED_METHODS = set( ["GET", "HEAD", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"]) def __init__(self, io_loop, client, request, release_callback, final_callback, max_buffer_size, resolver): self.start_time = io_loop.time() self.io_loop = io_loop self.client = client self.request = request self.release_callback = release_callback self.final_callback = final_callback self.max_buffer_size = max_buffer_size self.resolver = resolver self.code = None self.headers = None self.chunks = None self._decompressor = None # Timeout handle returned by IOLoop.add_timeout self._timeout = None with stack_context.ExceptionStackContext(self._handle_exception): self.parsed = urlparse.urlsplit(_unicode(self.request.url)) if self.parsed.scheme not in ("http", "https"): raise ValueError("Unsupported url scheme: %s" % self.request.url) # urlsplit results have hostname and port results, but they # didn't support ipv6 literals until python 2.7. netloc = self.parsed.netloc if "@" in netloc: userpass, _, netloc = netloc.rpartition("@") match = re.match(r'^(.+):(\d+)$', netloc) if match: host = match.group(1) port = int(match.group(2)) else: host = netloc port = 443 if self.parsed.scheme == "https" else 80 if re.match(r'^\[.*\]$', host): # raw ipv6 addresses in urls are enclosed in brackets host = host[1:-1] self.parsed_hostname = host # save final host for _on_connect if request.allow_ipv6: af = socket.AF_UNSPEC else: # We only try the first IP we get from getaddrinfo, # so restrict to ipv4 by default. 
af = socket.AF_INET self.resolver.resolve(host, port, af, callback=self._on_resolve) def _on_resolve(self, addrinfo): af, sockaddr = addrinfo[0] if self.parsed.scheme == "https": ssl_options = {} if self.request.validate_cert: ssl_options["cert_reqs"] = ssl.CERT_REQUIRED if self.request.ca_certs is not None: ssl_options["ca_certs"] = self.request.ca_certs else: ssl_options["ca_certs"] = _DEFAULT_CA_CERTS if self.request.client_key is not None: ssl_options["keyfile"] = self.request.client_key if self.request.client_cert is not None: ssl_options["certfile"] = self.request.client_cert # SSL interoperability is tricky. We want to disable # SSLv2 for security reasons; it wasn't disabled by default # until openssl 1.0. The best way to do this is to use # the SSL_OP_NO_SSLv2, but that wasn't exposed to python # until 3.2. Python 2.7 adds the ciphers argument, which # can also be used to disable SSLv2. As a last resort # on python 2.6, we set ssl_version to SSLv3. This is # more narrow than we'd like since it also breaks # compatibility with servers configured for TLSv1 only, # but nearly all servers support SSLv3: # http://blog.ivanristic.com/2011/09/ssl-survey-protocol-support.html if sys.version_info >= (2, 7): ssl_options["ciphers"] = "DEFAULT:!SSLv2" else: # This is really only necessary for pre-1.0 versions # of openssl, but python 2.6 doesn't expose version # information. 
ssl_options["ssl_version"] = ssl.PROTOCOL_SSLv3 self.stream = SSLIOStream(socket.socket(af), io_loop=self.io_loop, ssl_options=ssl_options, max_buffer_size=self.max_buffer_size) else: self.stream = IOStream(socket.socket(af), io_loop=self.io_loop, max_buffer_size=self.max_buffer_size) timeout = min(self.request.connect_timeout, self.request.request_timeout) if timeout: self._timeout = self.io_loop.add_timeout( self.start_time + timeout, stack_context.wrap(self._on_timeout)) self.stream.set_close_callback(self._on_close) # ipv6 addresses are broken (in self.parsed.hostname) until # 2.7, here is correctly parsed value calculated in __init__ self.stream.connect(sockaddr, self._on_connect, server_hostname=self.parsed_hostname) def _on_timeout(self): self._timeout = None if self.final_callback is not None: raise HTTPError(599, "Timeout") def _remove_timeout(self): if self._timeout is not None: self.io_loop.remove_timeout(self._timeout) self._timeout = None def _on_connect(self): self._remove_timeout() if self.request.request_timeout: self._timeout = self.io_loop.add_timeout( self.start_time + self.request.request_timeout, stack_context.wrap(self._on_timeout)) if (self.request.method not in self._SUPPORTED_METHODS and not self.request.allow_nonstandard_methods): raise KeyError("unknown method %s" % self.request.method) for key in ('network_interface', 'proxy_host', 'proxy_port', 'proxy_username', 'proxy_password'): if getattr(self.request, key, None): raise NotImplementedError('%s not supported' % key) if "Connection" not in self.request.headers: self.request.headers["Connection"] = "close" if "Host" not in self.request.headers: if '@' in self.parsed.netloc: self.request.headers["Host"] = self.parsed.netloc.rpartition( '@')[-1] else: self.request.headers["Host"] = self.parsed.netloc username, password = None, None if self.parsed.username is not None: username, password = self.parsed.username, self.parsed.password elif self.request.auth_username is not None: username = 
self.request.auth_username password = self.request.auth_password or '' if username is not None: auth = utf8(username) + b":" + utf8(password) self.request.headers["Authorization"] = (b"Basic " + base64.b64encode(auth)) if self.request.user_agent: self.request.headers["User-Agent"] = self.request.user_agent if not self.request.allow_nonstandard_methods: if self.request.method in ("POST", "PATCH", "PUT"): assert self.request.body is not None else: assert self.request.body is None if self.request.body is not None: self.request.headers["Content-Length"] = str(len( self.request.body)) if (self.request.method == "POST" and "Content-Type" not in self.request.headers): self.request.headers[ "Content-Type"] = "application/x-www-form-urlencoded" if self.request.use_gzip: self.request.headers["Accept-Encoding"] = "gzip" req_path = ((self.parsed.path or '/') + (('?' + self.parsed.query) if self.parsed.query else '')) request_lines = [ utf8("%s %s HTTP/1.1" % (self.request.method, req_path)) ] for k, v in self.request.headers.get_all(): line = utf8(k) + b": " + utf8(v) if b'\n' in line: raise ValueError('Newline in header: ' + repr(line)) request_lines.append(line) self.stream.write(b"\r\n".join(request_lines) + b"\r\n\r\n") if self.request.body is not None: self.stream.write(self.request.body) self.stream.read_until_regex(b"\r?\n\r?\n", self._on_headers) def _release(self): if self.release_callback is not None: release_callback = self.release_callback self.release_callback = None release_callback() def _run_callback(self, response): self._release() if self.final_callback is not None: final_callback = self.final_callback self.final_callback = None self.io_loop.add_callback(final_callback, response) def _handle_exception(self, typ, value, tb): if self.final_callback: self._remove_timeout() gen_log.warning("uncaught exception", exc_info=(typ, value, tb)) self._run_callback( HTTPResponse( self.request, 599, error=value, request_time=self.io_loop.time() - self.start_time, )) if 
hasattr(self, "stream"): self.stream.close() return True else: # If our callback has already been called, we are probably # catching an exception that is not caused by us but rather # some child of our callback. Rather than drop it on the floor, # pass it along. return False def _on_close(self): if self.final_callback is not None: message = "Connection closed" if self.stream.error: message = str(self.stream.error) raise HTTPError(599, message) def _handle_1xx(self, code): self.stream.read_until_regex(b"\r?\n\r?\n", self._on_headers) def _on_headers(self, data): data = native_str(data.decode("latin1")) first_line, _, header_data = data.partition("\n") match = re.match("HTTP/1.[01] ([0-9]+) ([^\r]*)", first_line) assert match code = int(match.group(1)) self.headers = HTTPHeaders.parse(header_data) if 100 <= code < 200: self._handle_1xx(code) return else: self.code = code self.reason = match.group(2) if "Content-Length" in self.headers: if "," in self.headers["Content-Length"]: # Proxies sometimes cause Content-Length headers to get # duplicated. If all the values are identical then we can # use them but if they differ it's an error. 
pieces = re.split(r',\s*', self.headers["Content-Length"]) if any(i != pieces[0] for i in pieces): raise ValueError("Multiple unequal Content-Lengths: %r" % self.headers["Content-Length"]) self.headers["Content-Length"] = pieces[0] content_length = int(self.headers["Content-Length"]) else: content_length = None if self.request.header_callback is not None: # re-attach the newline we split on earlier self.request.header_callback(first_line + _) for k, v in self.headers.get_all(): self.request.header_callback("%s: %s\r\n" % (k, v)) self.request.header_callback('\r\n') if self.request.method == "HEAD" or self.code == 304: # HEAD requests and 304 responses never have content, even # though they may have content-length headers self._on_body(b"") return if 100 <= self.code < 200 or self.code == 204: # These response codes never have bodies # http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.3 if ("Transfer-Encoding" in self.headers or content_length not in (None, 0)): raise ValueError("Response with code %d should not have body" % self.code) self._on_body(b"") return if (self.request.use_gzip and self.headers.get("Content-Encoding") == "gzip"): self._decompressor = GzipDecompressor() if self.headers.get("Transfer-Encoding") == "chunked": self.chunks = [] self.stream.read_until(b"\r\n", self._on_chunk_length) elif content_length is not None: self.stream.read_bytes(content_length, self._on_body) else: self.stream.read_until_close(self._on_body) def _on_body(self, data): self._remove_timeout() original_request = getattr(self.request, "original_request", self.request) if (self.request.follow_redirects and self.request.max_redirects > 0 and self.code in (301, 302, 303, 307)): assert isinstance(self.request, _RequestProxy) new_request = copy.copy(self.request.request) new_request.url = urlparse.urljoin(self.request.url, self.headers["Location"]) new_request.max_redirects = self.request.max_redirects - 1 del new_request.headers["Host"] # 
http://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.3.4 # Client SHOULD make a GET request after a 303. # According to the spec, 302 should be followed by the same # method as the original request, but in practice browsers # treat 302 the same as 303, and many servers use 302 for # compatibility with pre-HTTP/1.1 user agents which don't # understand the 303 status. if self.code in (302, 303): new_request.method = "GET" new_request.body = None for h in [ "Content-Length", "Content-Type", "Content-Encoding", "Transfer-Encoding" ]: try: del self.request.headers[h] except KeyError: pass new_request.original_request = original_request final_callback = self.final_callback self.final_callback = None self._release() self.client.fetch(new_request, final_callback) self.stream.close() return if self._decompressor: data = (self._decompressor.decompress(data) + self._decompressor.flush()) if self.request.streaming_callback: if self.chunks is None: # if chunks is not None, we already called streaming_callback # in _on_chunk_data self.request.streaming_callback(data) buffer = BytesIO() else: buffer = BytesIO(data) # TODO: don't require one big string? response = HTTPResponse(original_request, self.code, reason=self.reason, headers=self.headers, request_time=self.io_loop.time() - self.start_time, buffer=buffer, effective_url=self.request.url) self._run_callback(response) self.stream.close() def _on_chunk_length(self, data): # TODO: "chunk extensions" http://tools.ietf.org/html/rfc2616#section-3.6.1 length = int(data.strip(), 16) if length == 0: if self._decompressor is not None: tail = self._decompressor.flush() if tail: # I believe the tail will always be empty (i.e. # decompress will return all it can). The purpose # of the flush call is to detect errors such # as truncated input. 
But in case it ever returns # anything, treat it as an extra chunk if self.request.streaming_callback is not None: self.request.streaming_callback(tail) else: self.chunks.append(tail) # all the data has been decompressed, so we don't need to # decompress again in _on_body self._decompressor = None self._on_body(b''.join(self.chunks)) else: self.stream.read_bytes( length + 2, # chunk ends with \r\n self._on_chunk_data) def _on_chunk_data(self, data): assert data[-2:] == b"\r\n" chunk = data[:-2] if self._decompressor: chunk = self._decompressor.decompress(chunk) if self.request.streaming_callback is not None: self.request.streaming_callback(chunk) else: self.chunks.append(chunk) self.stream.read_until(b"\r\n", self._on_chunk_length)
class TestIOStream(AsyncHTTPTestCase, LogTrapTestCase):
    """IOStream tests, run against the HelloHandler app and against
    directly connected pairs of IOStreams built by make_iostream_pair."""

    def get_app(self):
        # Serve HelloHandler at "/" so tests speaking raw HTTP have a target.
        return Application([('/', HelloHandler)])

    def make_iostream_pair(self, **kwargs):
        """Return a connected ``[server_stream, client_stream]`` list.

        Binds a listener on an unused 127.0.0.1 port, connects a client
        IOStream to it and waits until both the accept and the connect
        callbacks have fired. ``kwargs`` are passed to both IOStream
        constructors. The listener is unregistered from the IOLoop and
        closed before returning.
        """
        port = get_unused_port()
        [listener] = netutil.bind_sockets(port, '127.0.0.1',
                                          family=socket.AF_INET)
        streams = [None, None]

        def accept_callback(connection, address):
            streams[0] = IOStream(connection, io_loop=self.io_loop, **kwargs)
            self.stop()

        def connect_callback():
            streams[1] = client_stream
            self.stop()
        netutil.add_accept_handler(listener, accept_callback,
                                   io_loop=self.io_loop)
        client_stream = IOStream(socket.socket(), io_loop=self.io_loop,
                                 **kwargs)
        client_stream.connect(('127.0.0.1', port),
                              callback=connect_callback)
        # Both callbacks call self.stop(); keep waiting until both slots
        # of `streams` have been filled.
        self.wait(condition=lambda: all(streams))
        self.io_loop.remove_handler(listener.fileno())
        listener.close()
        return streams

    def test_read_zero_bytes(self):
        # read_bytes(0) must yield b"" without disturbing the reads
        # issued before and after it.
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        s.connect(("localhost", self.get_http_port()))
        self.stream = IOStream(s, io_loop=self.io_loop)
        self.stream.write(b("GET / HTTP/1.0\r\n\r\n"))

        # normal read
        self.stream.read_bytes(9, self.stop)
        data = self.wait()
        self.assertEqual(data, b("HTTP/1.0 "))

        # zero bytes
        self.stream.read_bytes(0, self.stop)
        data = self.wait()
        self.assertEqual(data, b(""))

        # another normal read
        self.stream.read_bytes(3, self.stop)
        data = self.wait()
        self.assertEqual(data, b("200"))

        # NOTE(review): this closes the raw socket rather than self.stream;
        # presumably self.stream.close() would also tear down the stream's
        # IOLoop registration -- confirm before changing.
        s.close()

    def test_write_zero_bytes(self):
        # Attempting to write zero bytes should run the callback without
        # going into an infinite loop.
        server, client = self.make_iostream_pair()
        server.write(b(''), callback=self.stop)
        self.wait()
        # As a side effect, the stream is now listening for connection
        # close (if it wasn't already), but is not listening for writes
        self.assertEqual(server._state, IOLoop.READ | IOLoop.ERROR)
        server.close()
        client.close()

    def test_connection_refused(self):
        # When a connection is refused, the connect callback should not
        # be run.  (The kqueue IOLoop used to behave differently from the
        # epoll IOLoop in this respect)
        port = get_unused_port()
        stream = IOStream(socket.socket(), self.io_loop)
        self.connect_called = False

        def connect_callback():
            self.connect_called = True
        stream.set_close_callback(self.stop)
        stream.connect(("localhost", port), connect_callback)
        self.wait()
        self.assertFalse(self.connect_called)
        self.assertTrue(isinstance(stream.error, socket.error), stream.error)
        if sys.platform != 'cygwin':
            # cygwin's errnos don't match those used on native windows python
            self.assertEqual(stream.error.args[0], errno.ECONNREFUSED)

    def test_gaierror(self):
        # Test that IOStream sets its exc_info on getaddrinfo error
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        stream = IOStream(s, io_loop=self.io_loop)
        stream.set_close_callback(self.stop)
        stream.connect(('adomainthatdoesntexist.asdf', 54321))
        self.assertTrue(isinstance(stream.error, socket.gaierror), stream.error)

    def test_connection_closed(self):
        # When a server sends a response and then closes the connection,
        # the client must be allowed to read the data before the IOStream
        # closes itself.  Epoll reports closed connections with a separate
        # EPOLLRDHUP event delivered at the same time as the read event,
        # while kqueue reports them as a second read/write event with an EOF
        # flag.
        response = self.fetch("/", headers={"Connection": "close"})
        response.rethrow()

    def test_read_until_close(self):
        # Read a whole HTTP/1.0 response with read_until_close and check
        # both its status line and its final bytes.
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        s.connect(("localhost", self.get_http_port()))
        stream = IOStream(s, io_loop=self.io_loop)
        stream.write(b("GET / HTTP/1.0\r\n\r\n"))

        stream.read_until_close(self.stop)
        data = self.wait()
        self.assertTrue(data.startswith(b("HTTP/1.0 200")))
        self.assertTrue(data.endswith(b("Hello")))

    def test_streaming_callback(self):
        # read_bytes(6) with a streaming_callback: data is streamed until
        # 6 bytes are consumed, the final callback receives nothing, and
        # the surplus bytes remain buffered for the next read.
        server, client = self.make_iostream_pair()
        try:
            chunks = []
            final_called = []

            def streaming_callback(data):
                chunks.append(data)
                self.stop()

            def final_callback(data):
                assert not data
                final_called.append(True)
                self.stop()
            server.read_bytes(6, callback=final_callback,
                              streaming_callback=streaming_callback)
            client.write(b("1234"))
            self.wait(condition=lambda: chunks)
            client.write(b("5678"))
            self.wait(condition=lambda: final_called)
            self.assertEqual(chunks, [b("1234"), b("56")])

            # the rest of the last chunk is still in the buffer
            server.read_bytes(2, callback=self.stop)
            data = self.wait()
            self.assertEqual(data, b("78"))
        finally:
            server.close()
            client.close()

    def test_streaming_until_close(self):
        # One function serves as both the streaming and the final
        # callback, so `chunks` records every streamed write followed by
        # the empty final result delivered after the server closes.
        server, client = self.make_iostream_pair()
        try:
            chunks = []

            def callback(data):
                chunks.append(data)
                self.stop()
            client.read_until_close(callback=callback,
                                    streaming_callback=callback)
            server.write(b("1234"))
            self.wait()
            server.write(b("5678"))
            self.wait()
            server.close()
            self.wait()
            self.assertEqual(chunks, [b("1234"), b("5678"), b("")])
        finally:
            server.close()
            client.close()

    def test_delayed_close_callback(self):
        # The scenario:  Server closes the connection while there is a pending
        # read that can be served out of buffered data.  The client does not
        # run the close_callback as soon as it detects the close, but rather
        # defers it until after the buffered read has finished.
        server, client = self.make_iostream_pair()
        try:
            client.set_close_callback(self.stop)
            server.write(b("12"))
            chunks = []

            def callback1(data):
                chunks.append(data)
                client.read_bytes(1, callback2)
                server.close()

            def callback2(data):
                chunks.append(data)
            client.read_bytes(1, callback1)
            self.wait()  # stopped by close_callback
            self.assertEqual(chunks, [b("1"), b("2")])
        finally:
            server.close()
            client.close()

    def test_close_buffered_data(self):
        # Similar to the previous test, but with data stored in the OS's
        # socket buffers instead of the IOStream's read buffer.  Out-of-band
        # close notifications must be delayed until all data has been
        # drained into the IOStream buffer. (epoll used to use out-of-band
        # close events with EPOLLRDHUP, but no longer)
        #
        # This depends on the read_chunk_size being smaller than the
        # OS socket buffer, so make it small.
        server, client = self.make_iostream_pair(read_chunk_size=256)
        try:
            server.write(b("A") * 512)
            client.read_bytes(256, self.stop)
            data = self.wait()
            self.assertEqual(b("A") * 256, data)
            server.close()
            # Allow the close to propagate to the client side of the
            # connection.  Using add_callback instead of add_timeout
            # doesn't seem to work, even with multiple iterations
            self.io_loop.add_timeout(datetime.timedelta(seconds=0.01), self.stop)
            self.wait()
            client.read_bytes(256, self.stop)
            data = self.wait()
            self.assertEqual(b("A") * 256, data)
        finally:
            server.close()
            client.close()

    def test_large_read_until(self):
        # Performance test: read_until used to have a quadratic component
        # so a read_until of 4MB would take 8 seconds; now it takes 0.25
        # seconds.
        server, client = self.make_iostream_pair()
        try:
            NUM_KB = 4096
            for i in xrange(NUM_KB):
                client.write(b("A") * 1024)
            # A single delimiter after all NUM_KB kilobytes of payload.
            client.write(b("\r\n"))
            server.read_until(b("\r\n"), self.stop)
            data = self.wait()
            self.assertEqual(len(data), NUM_KB * 1024 + 2)
        finally:
            server.close()
            client.close()
class _HTTPConnection(object): _SUPPORTED_METHODS = set(["GET", "HEAD", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"]) def __init__(self, io_loop, client, request, release_callback, final_callback, max_buffer_size): self.start_time = io_loop.time() self.io_loop = io_loop self.client = client self.request = request self.release_callback = release_callback self.final_callback = final_callback self.max_buffer_size = max_buffer_size self.code = None self.headers = None self.chunks = None self._decompressor = None # Timeout handle returned by IOLoop.add_timeout self._timeout = None with stack_context.ExceptionStackContext(self._handle_exception): self.parsed = urlparse.urlsplit(_unicode(self.request.url)) if self.parsed.scheme not in ("http", "https"): raise ValueError("Unsupported url scheme: %s" % self.request.url) # urlsplit results have hostname and port results, but they # didn't support ipv6 literals until python 2.7. netloc = self.parsed.netloc if "@" in netloc: userpass, _, netloc = netloc.rpartition("@") match = re.match(r"^(.+):(\d+)$", netloc) if match: host = match.group(1) port = int(match.group(2)) else: host = netloc port = 443 if self.parsed.scheme == "https" else 80 if re.match(r"^\[.*\]$", host): # raw ipv6 addresses in urls are enclosed in brackets host = host[1:-1] self.parsed_hostname = host # save final host for _on_connect if self.client.hostname_mapping is not None: host = self.client.hostname_mapping.get(host, host) if request.allow_ipv6: af = socket.AF_UNSPEC else: # We only try the first IP we get from getaddrinfo, # so restrict to ipv4 by default. 
af = socket.AF_INET self.client.resolver.getaddrinfo(host, port, af, socket.SOCK_STREAM, 0, 0, callback=self._on_resolve) def _on_resolve(self, future): af, socktype, proto, canonname, sockaddr = future.result()[0] if self.parsed.scheme == "https": ssl_options = {} if self.request.validate_cert: ssl_options["cert_reqs"] = ssl.CERT_REQUIRED if self.request.ca_certs is not None: ssl_options["ca_certs"] = self.request.ca_certs else: ssl_options["ca_certs"] = _DEFAULT_CA_CERTS if self.request.client_key is not None: ssl_options["keyfile"] = self.request.client_key if self.request.client_cert is not None: ssl_options["certfile"] = self.request.client_cert # SSL interoperability is tricky. We want to disable # SSLv2 for security reasons; it wasn't disabled by default # until openssl 1.0. The best way to do this is to use # the SSL_OP_NO_SSLv2, but that wasn't exposed to python # until 3.2. Python 2.7 adds the ciphers argument, which # can also be used to disable SSLv2. As a last resort # on python 2.6, we set ssl_version to SSLv3. This is # more narrow than we'd like since it also breaks # compatibility with servers configured for TLSv1 only, # but nearly all servers support SSLv3: # http://blog.ivanristic.com/2011/09/ssl-survey-protocol-support.html if sys.version_info >= (2, 7): ssl_options["ciphers"] = "DEFAULT:!SSLv2" else: # This is really only necessary for pre-1.0 versions # of openssl, but python 2.6 doesn't expose version # information. 
ssl_options["ssl_version"] = ssl.PROTOCOL_SSLv3 self.stream = SSLIOStream( socket.socket(af, socktype, proto), io_loop=self.io_loop, ssl_options=ssl_options, max_buffer_size=self.max_buffer_size, ) else: self.stream = IOStream( socket.socket(af, socktype, proto), io_loop=self.io_loop, max_buffer_size=self.max_buffer_size ) timeout = min(self.request.connect_timeout, self.request.request_timeout) if timeout: self._timeout = self.io_loop.add_timeout(self.start_time + timeout, stack_context.wrap(self._on_timeout)) self.stream.set_close_callback(self._on_close) # ipv6 addresses are broken (in self.parsed.hostname) until # 2.7, here is correctly parsed value calculated in __init__ self.stream.connect(sockaddr, self._on_connect, server_hostname=self.parsed_hostname) def _on_timeout(self): self._timeout = None if self.final_callback is not None: raise HTTPError(599, "Timeout") def _on_connect(self): if self._timeout is not None: self.io_loop.remove_timeout(self._timeout) self._timeout = None if self.request.request_timeout: self._timeout = self.io_loop.add_timeout( self.start_time + self.request.request_timeout, stack_context.wrap(self._on_timeout) ) if self.request.method not in self._SUPPORTED_METHODS and not self.request.allow_nonstandard_methods: raise KeyError("unknown method %s" % self.request.method) for key in ("network_interface", "proxy_host", "proxy_port", "proxy_username", "proxy_password"): if getattr(self.request, key, None): raise NotImplementedError("%s not supported" % key) if "Connection" not in self.request.headers: self.request.headers["Connection"] = "close" if "Host" not in self.request.headers: if "@" in self.parsed.netloc: self.request.headers["Host"] = self.parsed.netloc.rpartition("@")[-1] else: self.request.headers["Host"] = self.parsed.netloc username, password = None, None if self.parsed.username is not None: username, password = self.parsed.username, self.parsed.password elif self.request.auth_username is not None: username = 
self.request.auth_username password = self.request.auth_password or "" if username is not None: auth = utf8(username) + b":" + utf8(password) self.request.headers["Authorization"] = b"Basic " + base64.b64encode(auth) if self.request.user_agent: self.request.headers["User-Agent"] = self.request.user_agent if not self.request.allow_nonstandard_methods: if self.request.method in ("POST", "PATCH", "PUT"): assert self.request.body is not None else: assert self.request.body is None if self.request.body is not None: self.request.headers["Content-Length"] = str(len(self.request.body)) if self.request.method == "POST" and "Content-Type" not in self.request.headers: self.request.headers["Content-Type"] = "application/x-www-form-urlencoded" if self.request.use_gzip: self.request.headers["Accept-Encoding"] = "gzip" req_path = (self.parsed.path or "/") + (("?" + self.parsed.query) if self.parsed.query else "") request_lines = [utf8("%s %s HTTP/1.1" % (self.request.method, req_path))] for k, v in self.request.headers.get_all(): line = utf8(k) + b": " + utf8(v) if b"\n" in line: raise ValueError("Newline in header: " + repr(line)) request_lines.append(line) self.stream.write(b"\r\n".join(request_lines) + b"\r\n\r\n") if self.request.body is not None: self.stream.write(self.request.body) self.stream.read_until_regex(b"\r?\n\r?\n", self._on_headers) def _release(self): if self.release_callback is not None: release_callback = self.release_callback self.release_callback = None release_callback() def _run_callback(self, response): self._release() if self.final_callback is not None: final_callback = self.final_callback self.final_callback = None self.io_loop.add_callback(final_callback, response) def _handle_exception(self, typ, value, tb): if self.final_callback: gen_log.warning("uncaught exception", exc_info=(typ, value, tb)) self._run_callback( HTTPResponse(self.request, 599, error=value, request_time=self.io_loop.time() - self.start_time) ) if hasattr(self, "stream"): 
self.stream.close() return True else: # If our callback has already been called, we are probably # catching an exception that is not caused by us but rather # some child of our callback. Rather than drop it on the floor, # pass it along. return False def _on_close(self): if self.final_callback is not None: message = "Connection closed" if self.stream.error: message = str(self.stream.error) raise HTTPError(599, message) def _on_headers(self, data): data = native_str(data.decode("latin1")) first_line, _, header_data = data.partition("\n") match = re.match("HTTP/1.[01] ([0-9]+) ([^\r]*)", first_line) assert match code = int(match.group(1)) if 100 <= code < 200: self.stream.read_until_regex(b"\r?\n\r?\n", self._on_headers) return else: self.code = code self.reason = match.group(2) self.headers = HTTPHeaders.parse(header_data) if "Content-Length" in self.headers: if "," in self.headers["Content-Length"]: # Proxies sometimes cause Content-Length headers to get # duplicated. If all the values are identical then we can # use them but if they differ it's an error. 
pieces = re.split(r",\s*", self.headers["Content-Length"]) if any(i != pieces[0] for i in pieces): raise ValueError("Multiple unequal Content-Lengths: %r" % self.headers["Content-Length"]) self.headers["Content-Length"] = pieces[0] content_length = int(self.headers["Content-Length"]) else: content_length = None if self.request.header_callback is not None: # re-attach the newline we split on earlier self.request.header_callback(first_line + _) for k, v in self.headers.get_all(): self.request.header_callback("%s: %s\r\n" % (k, v)) self.request.header_callback("\r\n") if self.request.method == "HEAD" or self.code == 304: # HEAD requests and 304 responses never have content, even # though they may have content-length headers self._on_body(b"") return if 100 <= self.code < 200 or self.code == 204: # These response codes never have bodies # http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.3 if "Transfer-Encoding" in self.headers or content_length not in (None, 0): raise ValueError("Response with code %d should not have body" % self.code) self._on_body(b"") return if self.request.use_gzip and self.headers.get("Content-Encoding") == "gzip": self._decompressor = GzipDecompressor() if self.headers.get("Transfer-Encoding") == "chunked": self.chunks = [] self.stream.read_until(b"\r\n", self._on_chunk_length) elif content_length is not None: self.stream.read_bytes(content_length, self._on_body) else: self.stream.read_until_close(self._on_body) def _on_body(self, data): if self._timeout is not None: self.io_loop.remove_timeout(self._timeout) self._timeout = None original_request = getattr(self.request, "original_request", self.request) if self.request.follow_redirects and self.request.max_redirects > 0 and self.code in (301, 302, 303, 307): assert isinstance(self.request, _RequestProxy) new_request = copy.copy(self.request.request) new_request.url = urlparse.urljoin(self.request.url, self.headers["Location"]) new_request.max_redirects = self.request.max_redirects - 1 
del new_request.headers["Host"] # http://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.3.4 # Client SHOULD make a GET request after a 303. # According to the spec, 302 should be followed by the same # method as the original request, but in practice browsers # treat 302 the same as 303, and many servers use 302 for # compatibility with pre-HTTP/1.1 user agents which don't # understand the 303 status. if self.code in (302, 303): new_request.method = "GET" new_request.body = None for h in ["Content-Length", "Content-Type", "Content-Encoding", "Transfer-Encoding"]: try: del self.request.headers[h] except KeyError: pass new_request.original_request = original_request final_callback = self.final_callback self.final_callback = None self._release() self.client.fetch(new_request, final_callback) self.stream.close() return if self._decompressor: data = self._decompressor.decompress(data) + self._decompressor.flush() if self.request.streaming_callback: if self.chunks is None: # if chunks is not None, we already called streaming_callback # in _on_chunk_data self.request.streaming_callback(data) buffer = BytesIO() else: buffer = BytesIO(data) # TODO: don't require one big string? response = HTTPResponse( original_request, self.code, reason=self.reason, headers=self.headers, request_time=self.io_loop.time() - self.start_time, buffer=buffer, effective_url=self.request.url, ) self._run_callback(response) self.stream.close() def _on_chunk_length(self, data): # TODO: "chunk extensions" http://tools.ietf.org/html/rfc2616#section-3.6.1 length = int(data.strip(), 16) if length == 0: if self._decompressor is not None: tail = self._decompressor.flush() if tail: # I believe the tail will always be empty (i.e. # decompress will return all it can). The purpose # of the flush call is to detect errors such # as truncated input. 
But in case it ever returns # anything, treat it as an extra chunk if self.request.streaming_callback is not None: self.request.streaming_callback(tail) else: self.chunks.append(tail) # all the data has been decompressed, so we don't need to # decompress again in _on_body self._decompressor = None self._on_body(b"".join(self.chunks)) else: self.stream.read_bytes(length + 2, self._on_chunk_data) # chunk ends with \r\n def _on_chunk_data(self, data): assert data[-2:] == b"\r\n" chunk = data[:-2] if self._decompressor: chunk = self._decompressor.decompress(chunk) if self.request.streaming_callback is not None: self.request.streaming_callback(chunk) else: self.chunks.append(chunk) self.stream.read_until(b"\r\n", self._on_chunk_length)
class _HTTPConnection(object): _SUPPORTED_METHODS = set(["GET", "HEAD", "POST", "PUT", "DELETE"]) def __init__(self, io_loop, client, request, release_callback, final_callback, max_buffer_size): self.start_time = time.time() self.io_loop = io_loop self.client = client self.request = request self.release_callback = release_callback self.final_callback = final_callback self.code = None self.headers = None self.chunks = None self._decompressor = None # Timeout handle returned by IOLoop.add_timeout self._timeout = None with stack_context.StackContext(self.cleanup): parsed = urllib.parse.urlsplit(_unicode(self.request.url)) if ssl is None and parsed.scheme == "https": raise ValueError("HTTPS requires either python2.6+ or " "curl_httpclient") if parsed.scheme not in ("http", "https"): raise ValueError("Unsupported url scheme: %s" % self.request.url) # urlsplit results have hostname and port results, but they # didn't support ipv6 literals until python 2.7. netloc = parsed.netloc if "@" in netloc: userpass, _, netloc = netloc.rpartition("@") match = re.match(r'^(.+):(\d+)$', netloc) if match: host = match.group(1) port = int(match.group(2)) else: host = netloc port = 443 if parsed.scheme == "https" else 80 if re.match(r'^\[.*\]$', host): # raw ipv6 addresses in urls are enclosed in brackets host = host[1:-1] if self.client.hostname_mapping is not None: host = self.client.hostname_mapping.get(host, host) if request.allow_ipv6: af = socket.AF_UNSPEC else: # We only try the first IP we get from getaddrinfo, # so restrict to ipv4 by default. 
af = socket.AF_INET addrinfo = socket.getaddrinfo(host, port, af, socket.SOCK_STREAM, 0, 0) af, socktype, proto, canonname, sockaddr = addrinfo[0] if parsed.scheme == "https": ssl_options = {} if request.validate_cert: ssl_options["cert_reqs"] = ssl.CERT_REQUIRED if request.ca_certs is not None: ssl_options["ca_certs"] = request.ca_certs else: ssl_options["ca_certs"] = _DEFAULT_CA_CERTS if request.client_key is not None: ssl_options["keyfile"] = request.client_key if request.client_cert is not None: ssl_options["certfile"] = request.client_cert self.stream = SSLIOStream(socket.socket(af, socktype, proto), io_loop=self.io_loop, ssl_options=ssl_options, max_buffer_size=max_buffer_size) else: self.stream = IOStream(socket.socket(af, socktype, proto), io_loop=self.io_loop, max_buffer_size=max_buffer_size) timeout = min(request.connect_timeout, request.request_timeout) if timeout: self._timeout = self.io_loop.add_timeout( self.start_time + timeout, self._on_timeout) self.stream.set_close_callback(self._on_close) self.stream.connect(sockaddr, functools.partial(self._on_connect, parsed)) def _on_timeout(self): self._timeout = None self._run_callback(HTTPResponse(self.request, 599, request_time=time.time() - self.start_time, error=HTTPError(599, "Timeout"))) self.stream.close() def _on_connect(self, parsed): if self._timeout is not None: self.io_loop.remove_timeout(self._timeout) self._timeout = None if self.request.request_timeout: self._timeout = self.io_loop.add_timeout( self.start_time + self.request.request_timeout, self._on_timeout) if (self.request.validate_cert and isinstance(self.stream, SSLIOStream)): match_hostname(self.stream.socket.getpeercert(), parsed.hostname) if (self.request.method not in self._SUPPORTED_METHODS and not self.request.allow_nonstandard_methods): raise KeyError("unknown method %s" % self.request.method) for key in ('network_interface', 'proxy_host', 'proxy_port', 'proxy_username', 'proxy_password'): if getattr(self.request, key, None): raise 
NotImplementedError('%s not supported' % key) if "Host" not in self.request.headers: self.request.headers["Host"] = parsed.netloc username, password = None, None if parsed.username is not None: username, password = parsed.username, parsed.password elif self.request.auth_username is not None: username = self.request.auth_username password = self.request.auth_password if username is not None: auth = utf8(username) + b(":") + utf8(password) self.request.headers["Authorization"] = (b("Basic ") + base64.b64encode(auth)) if self.request.user_agent: self.request.headers["User-Agent"] = self.request.user_agent if not self.request.allow_nonstandard_methods: if self.request.method in ("POST", "PUT"): assert self.request.body is not None else: assert self.request.body is None if self.request.body is not None: self.request.headers["Content-Length"] = str(len( self.request.body)) if (self.request.method == "POST" and "Content-Type" not in self.request.headers): self.request.headers["Content-Type"] = "application/x-www-form-urlencoded" if self.request.use_gzip: self.request.headers["Accept-Encoding"] = "gzip" req_path = ((parsed.path or '/') + (('?' 
+ parsed.query) if parsed.query else '')) request_lines = [utf8("%s %s HTTP/1.1" % (self.request.method, req_path))] for k, v in self.request.headers.get_all(): line = utf8(k) + b(": ") + utf8(v) if b('\n') in line: raise ValueError('Newline in header: ' + repr(line)) request_lines.append(line) self.stream.write(b("\r\n").join(request_lines) + b("\r\n\r\n")) if self.request.body is not None: self.stream.write(self.request.body) self.stream.read_until_regex(b("\r?\n\r?\n"), self._on_headers) def _release(self): if self.release_callback is not None: release_callback = self.release_callback self.release_callback = None release_callback() def _run_callback(self, response): self._release() if self.final_callback is not None: final_callback = self.final_callback self.final_callback = None final_callback(response) @contextlib.contextmanager def cleanup(self): try: yield except Exception as e: logging.warning("uncaught exception", exc_info=True) self._run_callback(HTTPResponse(self.request, 599, error=e, request_time=time.time() - self.start_time, )) def _on_close(self): self._run_callback(HTTPResponse( self.request, 599, request_time=time.time() - self.start_time, error=HTTPError(599, "Connection closed"))) def _on_headers(self, data): data = native_str(data.decode("latin1")) first_line, _, header_data = data.partition("\n") match = re.match("HTTP/1.[01] ([0-9]+)", first_line) assert match self.code = int(match.group(1)) self.headers = HTTPHeaders.parse(header_data) if self.request.header_callback is not None: for k, v in self.headers.get_all(): self.request.header_callback("%s: %s\r\n" % (k, v)) if (self.request.use_gzip and self.headers.get("Content-Encoding") == "gzip"): # Magic parameter makes zlib module understand gzip header # http://stackoverflow.com/questions/1838699/how-can-i-decompress-a-gzip-stream-with-zlib self._decompressor = zlib.decompressobj(16+zlib.MAX_WBITS) if self.headers.get("Transfer-Encoding") == "chunked": self.chunks = [] 
self.stream.read_until(b("\r\n"), self._on_chunk_length) elif "Content-Length" in self.headers: if "," in self.headers["Content-Length"]: # Proxies sometimes cause Content-Length headers to get # duplicated. If all the values are identical then we can # use them but if they differ it's an error. pieces = re.split(r',\s*', self.headers["Content-Length"]) if any(i != pieces[0] for i in pieces): raise ValueError("Multiple unequal Content-Lengths: %r" % self.headers["Content-Length"]) self.headers["Content-Length"] = pieces[0] self.stream.read_bytes(int(self.headers["Content-Length"]), self._on_body) else: self.stream.read_until_close(self._on_body) def _on_body(self, data): if self._timeout is not None: self.io_loop.remove_timeout(self._timeout) self._timeout = None if self._decompressor: data = self._decompressor.decompress(data) if self.request.streaming_callback: if self.chunks is None: # if chunks is not None, we already called streaming_callback # in _on_chunk_data self.request.streaming_callback(data) buffer = BytesIO() else: buffer = BytesIO(data) # TODO: don't require one big string? 
original_request = getattr(self.request, "original_request", self.request) if (self.request.follow_redirects and self.request.max_redirects > 0 and self.code in (301, 302)): new_request = copy.copy(self.request) new_request.url = urllib.parse.urljoin(self.request.url, self.headers["Location"]) new_request.max_redirects -= 1 del new_request.headers["Host"] new_request.original_request = original_request final_callback = self.final_callback self.final_callback = None self._release() self.client.fetch(new_request, final_callback) self.stream.close() return response = HTTPResponse(original_request, self.code, headers=self.headers, request_time=time.time() - self.start_time, buffer=buffer, effective_url=self.request.url) self._run_callback(response) self.stream.close() def _on_chunk_length(self, data): # TODO: "chunk extensions" http://tools.ietf.org/html/rfc2616#section-3.6.1 length = int(data.strip(), 16) if length == 0: # all the data has been decompressed, so we don't need to # decompress again in _on_body self._decompressor = None self._on_body(b('').join(self.chunks)) else: self.stream.read_bytes(length + 2, # chunk ends with \r\n self._on_chunk_data) def _on_chunk_data(self, data): assert data[-2:] == b("\r\n") chunk = data[:-2] if self._decompressor: chunk = self._decompressor.decompress(chunk) if self.request.streaming_callback is not None: self.request.streaming_callback(chunk) else: self.chunks.append(chunk) self.stream.read_until(b("\r\n"), self._on_chunk_length)
class TwitterStream(object):
    """ Twitter stream connection (singleton, see ``instance``).

    Construction fetches recent search results for ``SETTINGS["track"]`` to
    prepopulate ``CACHE``; afterwards a streaming connection to
    stream.twitter.com is opened and each tweet received is passed to
    ``broadcast_message``.
    """
    _instance = None

    def __init__(self):
        """ Just set up the cache list and get first set """
        # prepopulating cache
        client = AsyncHTTPClient()
        client.fetch("http://search.twitter.com/search.json?q=" +
                     SETTINGS["track"], self.cache_callback)

    def cache_callback(self, response):
        """ Set up last fifty messages, then open the stream.

        A failed search request no longer prevents the stream from being
        opened; the cache simply stays empty.
        """
        if response.error:
            print("Could not prepopulate cache: %s" % (response.error,))
            messages = []
        else:
            messages = json.loads(response.body)["results"][:50]
        messages.reverse()
        for message in messages:
            try:
                text = message["text"]
                name = ""
                username = message["from_user"]
                avatar = message["profile_image_url"]
                CACHE.append({
                    "type": "tweet",
                    "text": text,
                    "name": name,
                    "username": username,
                    "avatar": avatar,
                    "time": 1
                })
            except KeyError:
                print("invalid %s" % (message,))
                continue
        self.open_twitter_stream()

    @classmethod
    def instance(cls):
        """ Returns the singleton """
        if not cls._instance:
            cls._instance = cls()
        return cls._instance

    def open_twitter_stream(self):
        """ Creates the client and watches stream """
        import base64
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        self.twitter_stream = IOStream(sock)
        self.twitter_stream.connect(("stream.twitter.com", 80))
        # b64encode (unlike encodestring) never inserts line breaks, which
        # would corrupt the header for long username/password pairs.
        base64string = base64.b64encode(
            "%s:%s" % (SETTINGS["username"], SETTINGS["password"]))
        headers = {"Authorization": "Basic %s" % base64string,
                   "Host": "stream.twitter.com"}
        request = ["GET /1/statuses/filter.json?track=%s HTTP/1.1" %
                   SETTINGS["track"]]
        for key, value in headers.items():
            request.append("%s: %s" % (key, value))
        request = "\r\n".join(request) + "\r\n\r\n"
        self.twitter_stream.write(request)
        self.twitter_stream.read_until("\r\n\r\n", self.on_headers)

    def on_headers(self, response):
        """ Checks the status line, then starts monitoring for results.

        Raises Exception when the status code is not 200 (or cannot be
        parsed).
        """
        status_line = response.splitlines()[0]
        # Status line is "HTTP/<version> <code> <reason>"; accept any version.
        parts = status_line.split()
        try:
            response_code = int(parts[1])
        except (IndexError, ValueError):
            response_code = None
        if response_code != 200:
            raise Exception("Twitter could not connect: %s" % status_line)
        self.wait_for_message()

    def wait_for_message(self):
        """ Throw a read event on the stack. """
        self.twitter_stream.read_until("\r\n", self.on_result)

    def on_result(self, response):
        """ Gets length of next message (hex chunk size) and reads it """
        if response.strip() == "":
            return self.wait_for_message()
        length = int(response.strip(), 16)
        self.twitter_stream.read_bytes(length, self.parse_json)

    def parse_json(self, response):
        """ Checks JSON message """
        if not response.strip():
            # Empty line, happens sometimes for keep alive
            return self.wait_for_message()
        try:
            response = json.loads(response)
        except ValueError:
            print("Invalid response:")
            print(response)
            return self.wait_for_message()
        self.parse_response(response)

    def parse_response(self, response):
        """ Parse the twitter message and broadcast it """
        try:
            text = response["text"]
            name = response["user"]["name"]
            username = response["user"]["screen_name"]
            avatar = response["user"]["profile_image_url_https"]
        except KeyError as exc:
            print("Invalid tweet structure, missing %s" % exc)
            return self.wait_for_message()

        message = {
            "type": "tweet",
            "text": text,
            "avatar": avatar,
            "name": name,
            "username": username,
            "time": int(time.time())
        }
        broadcast_message(message)
        self.wait_for_message()
class Connection(object):
    """Send queued STP requests one at a time over a single IOStream.

    ``send_request`` queues ``(request, callback)`` pairs.  A stream is opened
    lazily and each request's callback receives an ``STPResponse``: either the
    parsed reply or one carrying a timeout, network or protocol error.  A
    reply is read as repeated ``<arglen>\\r\\n<arg>\\r\\n`` items terminated
    by an empty ``\\r\\n`` line.
    """
    # Constants for connection state
    _CLOSED = 0x001
    _CONNECTING = 0x002
    _STREAMING = 0x004

    # Timeouts (seconds): for ``timeout`` and ``connect_timeout``, None or a
    # value <= 0 (default -1) disables the timer.  A request's own
    # ``request_timeout``, when not None, replaces ``timeout`` for that
    # request.

    def __init__(self, io_loop, client, timeout=-1, connect_timeout=-1,
                 max_buffer_size=104857600):
        # NOTE: ``max_buffer_size`` is accepted for compatibility only; the
        # stream's buffer size is taken from ``client.max_buffer_size``.
        self.io_loop = io_loop
        self.client = client
        self.timeout = timeout
        self.connect_timeout = connect_timeout
        self.start_time = time.time()
        self.stream = None
        self._timeoutevent = None
        self._callback = None
        self._request_queue = collections.deque()
        self._request = None
        self._response = STPResponse()
        self._state = Connection._CLOSED

    @property
    def closed(self):
        """True when the connection is in the closed state."""
        return self._state == Connection._CLOSED

    def close(self):
        """Close the underlying stream if it is still open."""
        if self.stream is not None and not self.stream.closed():
            self.stream.close()
        self.stream = None

    def _clear_timeout(self):
        """Cancel the pending connect/request timer, if any."""
        if self._timeoutevent is not None:
            self.io_loop.remove_timeout(self._timeoutevent)
            self._timeoutevent = None

    def _drop_stream(self):
        """Close the current stream without triggering ``_on_close``."""
        if self.stream is not None:
            # Detach the close callback first; otherwise _on_close would run
            # later for this dead stream and fail whichever request is in
            # flight by then.
            self.stream.set_close_callback(None)
            self.stream.close()
            self.stream = None

    def _finish_failed_request(self):
        """Reset per-request state after a failure; start the next request."""
        self._state = Connection._CLOSED
        self._request = None
        # Discard arguments read before the failure so they do not leak into
        # the next response.
        self._response = STPResponse()
        if self._request_queue:
            self._connect_and_send_request()

    def _connect(self):
        self._state = Connection._CONNECTING
        if self.client.unix_socket is None:
            af = socket.AF_INET
        else:
            af = socket.AF_UNIX
        self.stream = IOStream(socket.socket(af, socket.SOCK_STREAM),
                               io_loop=self.io_loop,
                               max_buffer_size=self.client.max_buffer_size)
        if self.connect_timeout is not None and self.connect_timeout > 0:
            self._timeoutevent = self.io_loop.add_timeout(
                time.time() + self.connect_timeout, self._on_timeout)
        self.stream.set_close_callback(self._on_close)
        if self.client.unix_socket is not None:
            addr = self.client.unix_socket
        else:
            addr = (self.client.host, self.client.port)
        self.stream.connect(addr, self._on_connect)

    def _on_connect(self):
        self._clear_timeout()
        self._state = Connection._STREAMING
        self._send_request()

    def _on_timeout(self):
        self._timeoutevent = None
        self._run_callback(STPResponse(
            request_time=time.time() - self.start_time,
            error=exceptions.STPTimeoutError('Timeout')))
        self._drop_stream()
        self._finish_failed_request()

    def _on_close(self):
        # A still-pending timer belongs to the stream that just closed; left
        # alone it would later fire against a newer connection.
        self._clear_timeout()
        self._run_callback(STPResponse(
            request_time=time.time() - self.start_time,
            error=exceptions.STPNetworkError('Connection error')))
        self._finish_failed_request()

    def send_request(self, request, callback):
        """Queue ``request``; ``callback`` is called with its STPResponse."""
        self._request_queue.append((request, callback))
        self._connect_and_send_request()

    def _connect_and_send_request(self):
        if len(self._request_queue) > 0 and self._request is None:
            self._request, self._callback = self._request_queue.popleft()
            if self.stream is None or self._state == Connection._CLOSED:
                self._connect()
            elif self._state == Connection._STREAMING:
                self._send_request()

    def _send_request(self):
        def write_callback():
            '''tornado needs it'''
            pass
        timeout = self.timeout
        if self._request.request_timeout is not None:
            timeout = self._request.request_timeout
        if timeout is not None and timeout > 0:
            self._timeoutevent = self.io_loop.add_timeout(
                time.time() + timeout, self._on_timeout)
        self.start_time = time.time()
        self.stream.write(self._request.serialize(), write_callback)
        self._read_arg()

    def _run_callback(self, response):
        # Clear the callback before invoking it so it runs at most once.
        if self._callback is not None:
            callback = self._callback
            self._callback = None
            callback(response)

    def _read_arg(self):
        self.stream.read_until(b'\r\n', self._on_arglen)

    def _on_arglen(self, data):
        # An empty line terminates the response.  Compare against bytes: on
        # Python 3 the stream yields bytes, which never equal the str '\r\n'.
        if data == b'\r\n':
            response = self._response
            self._response = STPResponse()
            response.request_time = time.time() - self.start_time
            self._clear_timeout()
            self._run_callback(response)
            self._request = None
            self._connect_and_send_request()
        else:
            try:
                arglen = int(data[:-2])
                self.stream.read_bytes(arglen, self._on_arg)
            except Exception as e:
                self._clear_timeout()
                self._run_callback(STPResponse(
                    request_time=time.time() - self.start_time,
                    error=exceptions.STPProtocolError(str(e))))
                # The stream is out of sync with the protocol; drop it so the
                # next queued request gets a fresh connection instead of
                # waiting forever behind this one.
                self._drop_stream()
                self._finish_failed_request()

    def _on_arg(self, data):
        self._response._argv.append(data)
        # Consume the CRLF that follows every argument.
        self.stream.read_until(b'\r\n', self._on_strip_arg_eol)

    def _on_strip_arg_eol(self, data):
        self._read_arg()
class Connection(object):
    """Lazily opened connection to ``host:port`` read through an IOStream.

    ``connect`` performs a blocking socket connect using ``stop_after`` as the
    socket timeout.  Reads are asynchronous: pending read callbacks are kept
    in ``read_callbacks`` and are called with ``None`` if the stream closes.
    Events (``on_connect``, ``on_disconnect``) are forwarded to
    ``event_handler``, held through a weak proxy, when it defines them.
    """

    def __init__(self, host, port, event_handler, stop_after=None,
                 io_loop=None):
        self.host = host
        self.port = port
        # Weak proxy so the connection does not keep its owner alive.
        self._event_handler = weakref.proxy(event_handler)
        self.timeout = stop_after
        self._stream = None
        self._io_loop = io_loop
        self.try_left = 2
        self.in_progress = False
        self.read_callbacks = []
        self.ready_callbacks = deque()

    def __del__(self):
        self.disconnect()

    def __enter__(self):
        return self

    def __exit__(self, *args, **kwargs):
        if self.ready_callbacks:
            # Pop a SINGLE callback from the queue and execute it.
            # The next one will be executed from the code
            # invoked by the callback
            callback = self.ready_callbacks.popleft()
            callback()

    def ready(self):
        """True when no read or ready callbacks are pending."""
        return not self.read_callbacks and not self.ready_callbacks

    def wait_until_ready(self, callback=None):
        """Run ``callback`` now if ready, otherwise queue it; return self."""
        if callback:
            if not self.ready():
                self.ready_callbacks.append(callback)
            else:
                callback()
        return self

    def connect(self):
        """Open the connection if needed; raise ConnectionError on failure."""
        if not self._stream:
            sock = None
            try:
                sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
                sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
                sock.settimeout(self.timeout)
                sock.connect((self.host, self.port))
                self._stream = IOStream(sock, io_loop=self._io_loop)
                self._stream.set_close_callback(self.on_stream_close)
            except socket.error as e:
                # Release the descriptor of the half-opened socket.
                if sock is not None:
                    sock.close()
                raise ConnectionError(str(e))
            self.fire_event('on_connect')

    def on_stream_close(self):
        """Close callback: forget the stream, fail pending reads with None."""
        if self._stream:
            self._stream = None
            callbacks = self.read_callbacks
            self.read_callbacks = []
            for callback in callbacks:
                if callback is not None:
                    callback(None)

    def disconnect(self):
        """Shut down and close the stream; errors while doing so are ignored."""
        if self._stream:
            s = self._stream
            self._stream = None
            try:
                if s.socket is not None:
                    s.socket.shutdown(socket.SHUT_RDWR)
            except socket.error:
                pass
            # Always close, even when shutdown() failed (e.g. peer gone),
            # so the descriptor is not leaked.
            try:
                s.close()
            except socket.error:
                pass

    def fire_event(self, event):
        """Call ``event_handler.<event>()`` if the handler defines it."""
        try:
            handler = getattr(self._event_handler, event, None)
        except ReferenceError:
            # The event handler behind the weak proxy was garbage-collected.
            return
        if handler is not None:
            handler()

    def write(self, data, try_left=None):
        """Write ``data``, reconnecting and retrying on IOError.

        Raises ConnectionError when no connection can be made or the retries
        (``self.try_left`` by default) are exhausted.
        """
        if try_left is None:
            try_left = self.try_left
        if not self._stream:
            self.connect()
            if not self._stream:
                raise ConnectionError('Tried to write to '
                                      'non-existent connection')

        if try_left > 0:
            try:
                self._stream.write(data)
            except IOError:
                self.disconnect()
                self.write(data, try_left - 1)
        else:
            raise ConnectionError('Tried to write to non-existent connection')

    def _queue_read(self, method, arg, callback):
        """Register ``callback`` and start ``self._stream.<method>(arg)``."""
        try:
            if not self._stream:
                self.disconnect()
                raise ConnectionError('Tried to read from '
                                      'non-existent connection')
            self.read_callbacks.append(callback)
            getattr(self._stream, method)(
                arg, callback=partial(self.read_callback, callback))
        except IOError:
            self.fire_event('on_disconnect')

    def read(self, length, callback=None):
        """Read exactly ``length`` bytes and pass them to ``callback``."""
        self._queue_read('read_bytes', length, callback)

    def read_callback(self, callback, *args, **kwargs):
        """Unregister ``callback`` and invoke it (if any) with the data."""
        # The list may already have been flushed by on_stream_close.
        if callback in self.read_callbacks:
            self.read_callbacks.remove(callback)
        if callback is not None:
            callback(*args, **kwargs)

    def readline(self, callback=None):
        """Read up to and including the next CRLF."""
        # The delimiter must be bytes: on Python 3 the stream buffer is bytes.
        self._queue_read('read_until', b'\r\n', callback)

    def connected(self):
        """True when a stream is currently attached."""
        return bool(self._stream)
class KafkaTornado(BaseKafka):
    """Kafka client whose socket I/O goes through a tornado IOStream.

    Accepts the same arguments as ``BaseKafka`` plus an optional ``io_loop``
    keyword that is passed to the stream.
    """

    def __init__(self, *args, **kwargs):
        # Strip io_loop before BaseKafka sees the keyword arguments.
        self._io_loop = kwargs.pop('io_loop', None)
        BaseKafka.__init__(self, *args, **kwargs)
        self._stream = None

    # Socket management methods

    def _connect(self):
        """ Connect to the Kafka server.

        Raises ConnectionFailure if the TCP connection cannot be made.
        """
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        try:
            sock.connect((self.host, self.port))
        except Exception:
            # Release the descriptor before reporting the failure.
            sock.close()
            raise ConnectionFailure(
                "Could not connect to kafka at {0}:{1}".format(
                    self.host, self.port))
        else:
            self._stream = IOStream(sock, io_loop=self._io_loop)

    def _disconnect(self):
        """ Disconnect from the remote server & close the socket.

        Does nothing when not connected.
        """
        if self._stream is None:
            return
        try:
            self._stream.close()
        except IOError:
            pass
        finally:
            self._stream = None

    def _read(self, length, callback=None):
        """ Send a read request to the remote Kafka server. """
        if callback is None:
            callback = lambda v: v

        if not self._stream:
            self._connect()

        return self._stream.read_bytes(length, callback)

    def _write(self, data, callback=None, retries=BaseKafka.MAX_RETRY):
        """ Write `data` to the remote Kafka server.

        On IOError the stream is closed and the write retried on a new
        connection, up to ``retries`` times; the last IOError is re-raised.
        """
        if callback is None:
            callback = lambda: None

        if not self._stream:
            self._connect()

        try:
            return self._stream.write(data, callback)
        except IOError:
            if retries > 0:
                # Close (not just forget) the broken stream so its socket is
                # released; the recursive call reconnects.
                self._disconnect()
                retries_left = retries - 1
                # Logger.warn is a deprecated alias removed in Python 3.13.
                socket_log.warning(
                    'Write failure, retrying ({0} retries left)'.format(
                        retries_left))
                return self._write(data, callback, retries_left)
            else:
                raise
class Client(object):
    """Asynchronous beanstalkd client built on a tornado IOStream.

    Commands are queued and exchanged with the server one at a time, in FIFO
    order.  Every command method is a coroutine.  Server-side errors are not
    raised but returned as exception instances (Buried, TimedOut,
    DeadlineSoon, CommandFailed or UnexpectedResponse).  A lost connection is
    re-established automatically (see ``set_reconnect_callback``).
    """

    def __init__(self, host='localhost', port=11300,
                 connect_timeout=socket.getdefaulttimeout(), io_loop=None):
        self._connect_timeout = connect_timeout
        self.host = host
        self.port = port
        self.io_loop = io_loop or IOLoop.instance()
        self._stream = None
        self._using = 'default'  # current tube
        self._watching = set(['default'])  # set of watched tubes
        self._queue = deque()
        self._talking = False
        self._reconnect_cb = None

    def _reconnect(self):
        """Schedule a re-connect attempt after RECONNECT_TIMEOUT seconds.

        Installed by ``connect`` as the stream's close callback.
        """
        def attempt():
            # connect() is a coroutine that takes no callback argument
            # (passing one made every attempt fail); follow its future.
            self.io_loop.add_future(self.connect(),
                                    self._on_reconnect_attempt)
        # wait some time before trying to re-connect
        self.io_loop.add_timeout(time.time() + RECONNECT_TIMEOUT, attempt)

    def _on_reconnect_attempt(self, future):
        """Restore the tube state once a re-connect attempt has succeeded."""
        if future.exception() is None and not self.closed():
            self._reconnected()

    @coroutine
    def _reconnected(self):
        """Re-establish the used and watched tubes, then notify the user.

        The user's reconnect callback, if set, is called without arguments.
        """
        # re-establish the used tube and tubes being watched
        watch = self._watching.difference(['default'])
        # ignore "default", if it is not in the client's watch list
        ignore = set(['default']).difference(self._watching)
        try:
            # Watch first: ignoring the only watched tube is refused.
            for name in watch:
                yield self.watch(name)
            for name in ignore:
                yield self.ignore(name)
            if self._using != 'default':
                # change the tube used
                yield self.use(self._using)
        except Exception:
            # ignored, as next re-connect will retry the operation
            return
        if self._reconnect_cb:
            # callback to user
            self._reconnect_cb()

    @coroutine
    def connect(self):
        """Connect to beanstalkd server (no-op if already connected)."""
        if not self.closed():
            return
        self._talking = False
        self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM,
                                     socket.IPPROTO_TCP)
        self._stream = IOStream(self._socket, io_loop=self.io_loop)
        self._stream.set_close_callback(self._reconnect)
        yield Task(self._stream.connect, (self.host, self.port))

    def set_reconnect_callback(self, callback):
        """Set callback to be called if connection has been lost and
        re-established again.

        If the connection is closed unexpectedly, the client will
        automatically attempt to re-connect after RECONNECT_TIMEOUT seconds.
        After re-connecting, the client will attempt to re-establish the used
        tube and watched tubes, then call ``callback()``.
        """
        self._reconnect_cb = callback

    @coroutine
    def close(self):
        """Close connection to server (no-op if already closed)."""
        if self.closed():
            # A Callback registered on a missing or already closed stream
            # would never fire and Wait() would hang or fail.
            return
        key = object()
        # Replacing the close callback also disables automatic re-connect.
        self._stream.set_close_callback((yield Callback(key)))
        yield Task(self._stream.write, b'quit\r\n')
        self._stream.close()
        yield Wait(key)

    def closed(self):
        """Returns True if the connection is closed."""
        return not self._stream or self._stream.closed()

    def _interact(self, request, callback):
        # put the interaction request into the FIFO queue
        cb = stack_context.wrap(callback)
        self._queue.append((request, cb))
        self._process_queue()

    def _process_queue(self):
        if self._talking or not self._queue:
            return
        # pop a request of the queue and perform the send-receive interaction
        self._talking = True
        with stack_context.NullContext():
            req, cb = self._queue.popleft()
            command = req.cmd + b'\r\n'
            if req.body:
                command += req.body + b'\r\n'

            # write command and body to socket stream
            self._stream.write(
                command,
                # when command is written: read line from socket stream
                lambda: self._stream.read_until(
                    b'\r\n',
                    # when a line has been read: return status and results
                    lambda data: self._recv(req, data, cb)))

    def _recv(self, req, data, cb):
        # parse the data received as server response
        spl = data.decode('utf8').split()
        status, values = spl[0], spl[1:]

        error = None
        err_args = ObjectDict(request=req, status=status, values=values)

        if req.ok and status in req.ok:
            # avoid raising a Buried exception when using the bury command
            pass
        elif status == 'BURIED':
            error = Buried(**err_args)
        elif status == 'TIMED_OUT':
            error = TimedOut(**err_args)
        elif status == 'DEADLINE_SOON':
            error = DeadlineSoon(**err_args)
        elif req.err and status in req.err:
            error = CommandFailed(**err_args)
        else:
            error = UnexpectedResponse(**err_args)

        resp = Bunch(req=req, status=status, values=values, error=error)

        if error or not req.read_body:
            # end the request and callback with results
            self._do_callback(cb, resp)
        else:
            # read the body including the terminating two bytes of crlf
            if len(values) == 2:
                job_id, size = int(values[0]), int(values[1])
                resp.job_id = job_id
            else:
                size = int(values[0])
            self._stream.read_bytes(
                size + 2,
                lambda data: self._recv_body(data[:-2], resp, cb))

    def _recv_body(self, data, resp, cb):
        if resp.req.parse_yaml:
            # parse the yaml encoded body
            self._parse_yaml(data, resp, cb)
        else:
            # don't parse body, it is a job!
            # end the request and callback with results
            resp.body = ObjectDict(id=resp.job_id, body=data)
            self._do_callback(cb, resp)

    def _parse_yaml(self, data, resp, cb):
        # dirty parsing of yaml data
        # (assumes that data is a yaml encoded list or dict)
        spl = data.decode('utf8').split('\n')[1:-1]
        if spl[0].startswith('- '):
            # it is a list
            resp.body = [s[2:] for s in spl]
        else:
            # it is a dict
            def conv(v):
                # numeric strings become int/float, anything else stays str
                if v.replace('.', '', 1).isdigit():
                    return float(v) if '.' in v else int(v)
                return v
            # split on the first ':' only, so values containing a colon
            # do not break the key/value unpacking
            resp.body = ObjectDict((k, conv(v.strip())) for k, v in
                                   (s.split(':', 1) for s in spl))
        self._do_callback(cb, resp)

    def _do_callback(self, cb, resp):
        # end the request and process next item in the queue
        # and callback with results
        self._talking = False
        self.io_loop.add_callback(self._process_queue)

        if not cb:
            return

        # default is to callback with error state (None or exception)
        obj = None
        req = resp.req

        if resp.error:
            obj = resp.error

        elif req.read_value:
            # callback with an integer value or a string
            if resp.values[0].isdigit():
                obj = int(resp.values[0])
            else:
                obj = resp.values[0]

        elif req.read_body:
            # callback with the body (job or parsed yaml)
            obj = resp.body

        self.io_loop.add_callback(lambda: cb(obj))

    #
    # Producer commands
    #

    @coroutine
    def put(self, body, priority=DEFAULT_PRIORITY, delay=0, ttr=120):
        """Put a job body (a byte string) into the current tube.

        The job can be delayed a number of seconds, before it is put in the
        ready queue, default is no delay.

        The job is assigned a Time To Run (ttr, in seconds), the minimum is
        1 sec., default is ttr=120 sec.

        Calls back with id when job is inserted. The callback gets a Buried
        exception if the server answered BURIED (it ran out of memory), or a
        CommandFailed exception if the body is too big (JOB_TOO_BIG) or the
        server is in draining mode (DRAINING).

        Raises TypeError if body is not a byte string.
        """
        if not isinstance(body, bytes):
            raise TypeError('body must be bytes')
        cmd = 'put {} {} {} {}'.format(priority, delay, ttr,
                                       len(body)).encode('utf8')
        request = Bunch(cmd=cmd, ok=['INSERTED'],
                        err=['BURIED', 'JOB_TOO_BIG', 'DRAINING'],
                        body=body, read_value=True)
        resp = yield Task(self._interact, request)
        raise Return(resp)

    @coroutine
    def use(self, name):
        """Use the tube with given name.

        Calls back with the name of the tube now being used.
        """
        cmd = 'use {}'.format(name).encode('utf8')
        request = Bunch(cmd=cmd, ok=['USING'], read_value=True)
        resp = yield Task(self._interact, request)
        if not isinstance(resp, Exception):
            self._using = resp
        raise Return(resp)

    #
    # Worker commands
    #

    @coroutine
    def reserve(self, timeout=None):
        """Reserve a job from one of the watched tubes, with optional timeout
        in seconds.

        Not specifying a timeout (timeout=None, the default) will make the
        client put the communication with beanstalkd on hold, until either a
        job is reserved, or an already reserved job is approaching its TTR
        deadline. Commands issued while waiting for the "reserve" callback
        will be queued and sent in FIFO order, when communication is resumed.

        A timeout value of 0 will cause the server to immediately return
        either a response or TIMED_OUT. A positive value of timeout will
        limit the amount of time the client will hold communication until a
        job becomes available.

        Calls back with a job dict (keys id and body). If the request timed
        out, the callback gets a TimedOut exception. If a reserved job has
        deadline within the next second, the callback gets a DeadlineSoon
        exception.
        """
        if timeout is not None:
            cmd = 'reserve-with-timeout {}'.format(timeout).encode('utf8')
        else:
            cmd = b'reserve'
        request = Bunch(cmd=cmd, ok=['RESERVED'],
                        err=['DEADLINE_SOON', 'TIMED_OUT'], read_body=True)
        resp = yield Task(self._interact, request)
        raise Return(resp)

    @coroutine
    def delete(self, job_id):
        """Delete job with given id.

        Calls back when job is deleted. If the job does not exist, or it is
        neither reserved by the client, ready nor buried, the callback gets
        a CommandFailed exception.
        """
        cmd = 'delete {}'.format(job_id).encode('utf8')
        request = Bunch(cmd=cmd, ok=['DELETED'], err=['NOT_FOUND'])
        resp = yield Task(self._interact, request)
        raise Return(resp)

    @coroutine
    def release(self, job_id, priority=DEFAULT_PRIORITY, delay=0):
        """Release a reserved job back into the ready queue.

        A new priority can be assigned to the job.

        It is also possible to specify a delay (in seconds) to wait before
        putting the job in the ready queue. The job will be in the "delayed"
        state during this time.

        Calls back when job is released. If the job was buried, the callback
        gets a Buried exception. If the job does not exist, or it is not
        reserved by the client, the callback gets a CommandFailed exception.
        """
        cmd = 'release {} {} {}'.format(job_id, priority,
                                        delay).encode('utf8')
        request = Bunch(cmd=cmd, ok=['RELEASED'], err=['BURIED', 'NOT_FOUND'])
        resp = yield Task(self._interact, request)
        raise Return(resp)

    @coroutine
    def bury(self, job_id, priority=DEFAULT_PRIORITY):
        """Bury job with given id.

        A new priority can be assigned to the job.

        Calls back when job is buried. If the job does not exist, or it is
        not reserved by the client, the callback gets a CommandFailed
        exception.
        """
        cmd = 'bury {} {}'.format(job_id, priority).encode('utf8')
        request = Bunch(cmd=cmd, ok=['BURIED'], err=['NOT_FOUND'])
        resp = yield Task(self._interact, request)
        raise Return(resp)

    @coroutine
    def touch(self, job_id):
        """Touch job with given id.

        This is for requesting more time to work on a reserved job before it
        expires.

        Calls back when job is touched. If the job does not exist, or it is
        not reserved by the client, the callback gets a CommandFailed
        exception.
        """
        cmd = 'touch {}'.format(job_id).encode('utf8')
        request = Bunch(cmd=cmd, ok=['TOUCHED'], err=['NOT_FOUND'])
        resp = yield Task(self._interact, request)
        raise Return(resp)

    @coroutine
    def watch(self, name):
        """Watch tube with given name.

        Calls back with number of tubes currently in the watch list.
        """
        cmd = 'watch {}'.format(name).encode('utf8')
        request = Bunch(cmd=cmd, ok=['WATCHING'], read_value=True)
        resp = yield Task(self._interact, request)
        # add to the client's watch list
        self._watching.add(name)
        raise Return(resp)

    @coroutine
    def ignore(self, name):
        """Stop watching tube with given name.

        Calls back with the number of tubes currently in the watch list. On
        an attempt to ignore the only tube in the watch list, the callback
        gets a CommandFailed exception.
        """
        cmd = 'ignore {}'.format(name).encode('utf8')
        request = Bunch(cmd=cmd, ok=['WATCHING'], err=['NOT_IGNORED'],
                        read_value=True)
        resp = yield Task(self._interact, request)
        if name in self._watching:
            # remove from the client's watch list
            self._watching.remove(name)
        raise Return(resp)

    #
    # Other commands
    #

    def _peek(self, variant, callback):
        # a shared gateway for the peek* commands
        cmd = 'peek{}'.format(variant).encode('utf8')
        request = Bunch(cmd=cmd, ok=['FOUND'], err=['NOT_FOUND'],
                        read_body=True)
        self._interact(request, callback)

    @coroutine
    def peek(self, job_id):
        """Peek at job with given id.

        Calls back with a job dict (keys id and body). If no job exists with
        that id, the callback gets a CommandFailed exception.
        """
        resp = yield Task(self._peek, ' {}'.format(job_id))
        raise Return(resp)

    @coroutine
    def peek_ready(self):
        """Peek at next ready job in the current tube.

        Calls back with a job dict (keys id and body). If no ready jobs
        exist, the callback gets a CommandFailed exception.
        """
        resp = yield Task(self._peek, '-ready')
        raise Return(resp)

    @coroutine
    def peek_delayed(self):
        """Peek at next delayed job in the current tube.

        Calls back with a job dict (keys id and body). If no delayed jobs
        exist, the callback gets a CommandFailed exception.
        """
        resp = yield Task(self._peek, '-delayed')
        raise Return(resp)

    @coroutine
    def peek_buried(self):
        """Peek at next buried job in the current tube.

        Calls back with a job dict (keys id and body). If no buried jobs
        exist, the callback gets a CommandFailed exception.
        """
        resp = yield Task(self._peek, '-buried')
        raise Return(resp)

    @coroutine
    def kick(self, bound=1):
        """Kick at most `bound` jobs into the ready queue from the current
        tube.

        Calls back with the number of jobs actually kicked.
        """
        cmd = 'kick {}'.format(bound).encode('utf8')
        request = Bunch(cmd=cmd, ok=['KICKED'], read_value=True)
        resp = yield Task(self._interact, request)
        raise Return(resp)

    @coroutine
    def kick_job(self, job_id):
        """Kick job with given id into the ready queue.
        (Requires Beanstalkd version >= 1.8)

        Calls back when job is kicked. If no job exists with that id, or if
        job is not in a kickable state, the callback gets a CommandFailed
        exception.
        """
        cmd = 'kick-job {}'.format(job_id).encode('utf8')
        request = Bunch(cmd=cmd, ok=['KICKED'], err=['NOT_FOUND'])
        resp = yield Task(self._interact, request)
        raise Return(resp)

    @coroutine
    def stats_job(self, job_id):
        """A dict of stats about the job with given id.

        If no job exists with that id, the callback gets a CommandFailed
        exception.
        """
        cmd = 'stats-job {}'.format(job_id).encode('utf8')
        request = Bunch(cmd=cmd, ok=['OK'], err=['NOT_FOUND'], read_body=True,
                        parse_yaml=True)
        resp = yield Task(self._interact, request)
        raise Return(resp)

    @coroutine
    def stats_tube(self, name):
        """A dict of stats about the tube with given name.

        If no tube exists with that name, the callback gets a CommandFailed
        exception.
        """
        cmd = 'stats-tube {}'.format(name).encode('utf8')
        request = Bunch(cmd=cmd, ok=['OK'], err=['NOT_FOUND'], read_body=True,
                        parse_yaml=True)
        resp = yield Task(self._interact, request)
        raise Return(resp)

    @coroutine
    def stats(self):
        """A dict of beanstalkd statistics."""
        request = Bunch(cmd=b'stats', ok=['OK'], read_body=True,
                        parse_yaml=True)
        resp = yield Task(self._interact, request)
        raise Return(resp)

    @coroutine
    def list_tubes(self):
        """List of all existing tubes."""
        request = Bunch(cmd=b'list-tubes', ok=['OK'], read_body=True,
                        parse_yaml=True)
        resp = yield Task(self._interact, request)
        raise Return(resp)

    @coroutine
    def list_tube_used(self):
        """Name of the tube currently being used."""
        request = Bunch(cmd=b'list-tube-used', ok=['USING'], read_value=True)
        resp = yield Task(self._interact, request)
        raise Return(resp)

    @coroutine
    def list_tubes_watched(self):
        """List of tubes currently being watched."""
        request = Bunch(cmd=b'list-tubes-watched', ok=['OK'], read_body=True,
                        parse_yaml=True)
        resp = yield Task(self._interact, request)
        raise Return(resp)

    @coroutine
    def pause_tube(self, name, delay):
        """Delay any new job being reserved from the tube for a given time.

        The delay is an integer number of seconds to wait before reserving
        any more jobs from the queue.

        Calls back when tube is paused. If tube does not exist, the callback
        will get a CommandFailed exception.
        """
        cmd = 'pause-tube {} {}'.format(name, delay).encode('utf8')
        request = Bunch(cmd=cmd, ok=['PAUSED'], err=['NOT_FOUND'])
        resp = yield Task(self._interact, request)
        raise Return(resp)
class TestIOStream(AsyncHTTPTestCase, LogTrapTestCase):
    def get_app(self):
        return Application([('/', HelloHandler)])

    def make_iostream_pair(self, **kwargs):
        """Return a connected ``[server_stream, client_stream]`` pair.

        Extra keyword arguments are passed to both IOStream constructors.
        """
        port = get_unused_port()
        [listener] = netutil.bind_sockets(port, '127.0.0.1',
                                          family=socket.AF_INET)
        streams = [None, None]

        def accept_callback(connection, address):
            streams[0] = IOStream(connection, io_loop=self.io_loop, **kwargs)
            self.stop()

        def connect_callback():
            streams[1] = client_stream
            self.stop()
        netutil.add_accept_handler(listener, accept_callback,
                                   io_loop=self.io_loop)
        client_stream = IOStream(socket.socket(), io_loop=self.io_loop,
                                 **kwargs)
        client_stream.connect(('127.0.0.1', port),
                              callback=connect_callback)
        self.wait(condition=lambda: all(streams))
        self.io_loop.remove_handler(listener.fileno())
        listener.close()
        return streams

    def test_read_zero_bytes(self):
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        s.connect(("localhost", self.get_http_port()))
        self.stream = IOStream(s, io_loop=self.io_loop)
        try:
            self.stream.write(b("GET / HTTP/1.0\r\n\r\n"))

            # normal read
            self.stream.read_bytes(9, self.stop)
            data = self.wait()
            self.assertEqual(data, b("HTTP/1.0 "))

            # zero bytes
            self.stream.read_bytes(0, self.stop)
            data = self.wait()
            self.assertEqual(data, b(""))

            # another normal read
            self.stream.read_bytes(3, self.stop)
            data = self.wait()
            self.assertEqual(data, b("200"))
        finally:
            # Release the client socket even when an assertion fails.
            self.stream.close()

    def test_connection_refused(self):
        # When a connection is refused, the connect callback should not
        # be run.  (The kqueue IOLoop used to behave differently from the
        # epoll IOLoop in this respect)
        port = get_unused_port()
        stream = IOStream(socket.socket(), self.io_loop)
        self.connect_called = False

        def connect_callback():
            self.connect_called = True
        stream.set_close_callback(self.stop)
        stream.connect(("localhost", port), connect_callback)
        self.wait()
        self.assertFalse(self.connect_called)

    def test_connection_closed(self):
        # When a server sends a response and then closes the connection,
        # the client must be allowed to read the data before the IOStream
        # closes itself.  Epoll reports closed connections with a separate
        # EPOLLRDHUP event delivered at the same time as the read event,
        # while kqueue reports them as a second read/write event with an EOF
        # flag.
        response = self.fetch("/", headers={"Connection": "close"})
        response.rethrow()

    def test_read_until_close(self):
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        s.connect(("localhost", self.get_http_port()))
        stream = IOStream(s, io_loop=self.io_loop)
        stream.write(b("GET / HTTP/1.0\r\n\r\n"))

        stream.read_until_close(self.stop)
        data = self.wait()
        self.assertTrue(data.startswith(b("HTTP/1.0 200")))
        self.assertTrue(data.endswith(b("Hello")))

    def test_streaming_callback(self):
        server, client = self.make_iostream_pair()
        try:
            chunks = []
            final_called = []

            def streaming_callback(data):
                chunks.append(data)
                self.stop()

            def final_callback(data):
                assert not data
                final_called.append(True)
                self.stop()
            server.read_bytes(6, callback=final_callback,
                              streaming_callback=streaming_callback)
            client.write(b("1234"))
            self.wait(condition=lambda: chunks)
            client.write(b("5678"))
            self.wait(condition=lambda: final_called)
            self.assertEqual(chunks, [b("1234"), b("56")])

            # the rest of the last chunk is still in the buffer
            server.read_bytes(2, callback=self.stop)
            data = self.wait()
            self.assertEqual(data, b("78"))
        finally:
            server.close()
            client.close()

    def test_streaming_until_close(self):
        server, client = self.make_iostream_pair()
        try:
            chunks = []

            def callback(data):
                chunks.append(data)
                self.stop()
            client.read_until_close(callback=callback,
                                    streaming_callback=callback)
            server.write(b("1234"))
            self.wait()
            server.write(b("5678"))
            self.wait()
            server.close()
            self.wait()
            self.assertEqual(chunks, [b("1234"), b("5678"), b("")])
        finally:
            server.close()
            client.close()

    def test_delayed_close_callback(self):
        # The scenario:  Server closes the connection while there is a pending
        # read that can be served out of buffered data.  The client does not
        # run the close_callback as soon as it detects the close, but rather
        # defers it until after the buffered read has finished.
        server, client = self.make_iostream_pair()
        try:
            client.set_close_callback(self.stop)
            server.write(b("12"))
            chunks = []

            def callback1(data):
                chunks.append(data)
                client.read_bytes(1, callback2)
                server.close()

            def callback2(data):
                chunks.append(data)
            client.read_bytes(1, callback1)
            self.wait()  # stopped by close_callback
            self.assertEqual(chunks, [b("1"), b("2")])
        finally:
            server.close()
            client.close()

    def test_close_buffered_data(self):
        # Similar to the previous test, but with data stored in the OS's
        # socket buffers instead of the IOStream's read buffer.  Out-of-band
        # close notifications must be delayed until all data has been
        # drained into the IOStream buffer. (epoll used to use out-of-band
        # close events with EPOLLRDHUP, but no longer)
        #
        # This depends on the read_chunk_size being smaller than the
        # OS socket buffer, so make it small.
        server, client = self.make_iostream_pair(read_chunk_size=256)
        try:
            server.write(b("A") * 512)
            client.read_bytes(256, self.stop)
            data = self.wait()
            self.assertEqual(b("A") * 256, data)
            server.close()
            # Allow the close to propagate to the client side of the
            # connection.  Using add_callback instead of add_timeout
            # doesn't seem to work, even with multiple iterations
            self.io_loop.add_timeout(time.time() + 0.01, self.stop)
            self.wait()
            client.read_bytes(256, self.stop)
            data = self.wait()
            self.assertEqual(b("A") * 256, data)
        finally:
            server.close()
            client.close()
class AsyncSocket(object):
    """Socket-like facade over a tornado ``IOStream``.

    The ``@synclize``-decorated methods are tornado-style generator
    coroutines (they ``yield`` futures and ``raise Return(...)``); presumably
    ``synclize`` drives them to completion so callers receive plain values --
    TODO confirm against its definition.  Optional connect/read timeouts are
    enforced with ``Timeout`` objects; on ``TimeoutException`` the stream is
    closed and the exception re-raised.
    """

    def __init__(self, sock):
        self._iostream = IOStream(sock)
        self._resolver = Resolver()
        # 0 disables the corresponding timeout.
        self._readtimeout = 0
        self._connecttimeout = 0

    def set_readtimeout(self, timeout):
        self._readtimeout = timeout

    def set_connecttimeout(self, timeout):
        self._connecttimeout = timeout

    @synclize
    def connect(self, address):
        """Resolve ``address`` ``(host, port)`` over IPv4 and connect to the
        first resolved address.

        Raises ``socket.error`` when resolution yields no address (this used
        to return silently without connecting).
        """
        host, port = address
        timer = None
        try:
            if self._connecttimeout:
                timer = Timeout(self._connecttimeout)
                timer.start()
            resolved_addrs = yield self._resolver.resolve(host, port, family=socket.AF_INET)
            # Only the first resolved address is tried.
            for family, host_port in resolved_addrs:
                yield self._iostream.connect(host_port)
                break
            else:
                raise socket.error("getaddrinfo returns an empty list")
        except TimeoutException:
            self.close()
            raise
        finally:
            if timer:
                timer.cancel()

    #@synclize
    def sendall(self, buff):
        """Queue ``buff`` on the stream; not synclized, so it does not wait
        for the write to complete."""
        self._iostream.write(buff)

    @synclize
    def read(self, nbytes, partial=False):
        """Read ``nbytes`` bytes, or with ``partial=True`` whatever is
        available up to ``nbytes``."""
        timer = None
        try:
            if self._readtimeout:
                timer = Timeout(self._readtimeout)
                timer.start()
            buff = yield self._iostream.read_bytes(nbytes, partial=partial)
            raise Return(buff)
        except TimeoutException:
            self.close()
            raise
        finally:
            if timer:
                timer.cancel()

    def recv(self, nbytes):
        """Like ``socket.recv``: return up to ``nbytes`` bytes."""
        return self.read(nbytes, partial=True)

    @synclize
    def readline(self, max_bytes=-1):
        """Read through the next ``'\\n'``; ``max_bytes > 0`` bounds the read."""
        timer = None
        try:
            # Timer is started inside the try (as in read) so it is always
            # cancelled by the finally clause.
            if self._readtimeout:
                timer = Timeout(self._readtimeout)
                timer.start()
            if max_bytes > 0:
                buff = yield self._iostream.read_until('\n', max_bytes=max_bytes)
            else:
                buff = yield self._iostream.read_until('\n')
            raise Return(buff)
        except TimeoutException:
            self.close()
            raise
        finally:
            if timer:
                timer.cancel()

    def close(self):
        self._iostream.close()

    def set_nodelay(self, flag):
        self._iostream.set_nodelay(flag)

    def settimeout(self, timeout):
        """Accepted for socket API compatibility and ignored; use
        ``set_readtimeout`` / ``set_connecttimeout`` instead."""
        pass

    def shutdown(self, direction):
        # fileno() on the stream yields the underlying socket object here
        # (it is used for .shutdown); it is falsy once the stream is closed.
        if self._iostream.fileno():
            self._iostream.fileno().shutdown(direction)

    def recv_into(self, buff, nbytes=0):
        """Like ``socket.recv_into``: read up to ``nbytes`` bytes (``0`` means
        ``len(buff)``) into ``buff`` and return the number stored.

        Raises ``ValueError`` if ``nbytes`` exceeds the buffer length.
        """
        if nbytes > len(buff):
            raise ValueError("nbytes is greater than the length of the buffer")
        expected_rbytes = nbytes or len(buff)
        data = self.read(expected_rbytes, True)
        srcarray = bytearray(data)
        nread = len(srcarray)
        buff[0:nread] = srcarray
        return nread

    def makefile(self, mode='r', other=None, *args, **kwargs):
        """Return ``self`` as the file object; arguments are accepted for
        ``socket.makefile`` compatibility (e.g. ``makefile("rb")``)."""
        return self
class Connection(object): def __init__(self, host, port, on_connect, on_disconnect, timeout=None, io_loop=None): self.host = host self.port = port self.on_connect = on_connect self.on_disconnect = on_disconnect self.timeout = timeout self._stream = None self._io_loop = io_loop self.try_left = 2 self.in_progress = False self.read_queue = [] def connect(self): try: sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0) sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1) sock.settimeout(self.timeout) sock.connect((self.host, self.port)) self._stream = IOStream(sock, io_loop=self._io_loop) self.connected() except socket.error as e: raise ConnectionError(str(e)) self.on_connect() def disconnect(self): if self._stream: try: self._stream.close() except socket.error as e: pass self._stream = None def write(self, data, try_left=None): if try_left is None: try_left = self.try_left if not self._stream: self.connect() if not self._stream: raise ConnectionError( 'Tried to write to non-existent connection') if try_left > 0: try: self._stream.write(data) except IOError: self.disconnect() self.write(data, try_left - 1) else: raise ConnectionError('Tried to write to non-existent connection') def read(self, length, callback): try: if not self._stream: self.disconnect() raise ConnectionError( 'Tried to read from non-existent connection') self._stream.read_bytes(length, callback) except IOError: self.on_disconnect() def readline(self, callback): try: if not self._stream: self.disconnect() raise ConnectionError( 'Tried to read from non-existent connection') self._stream.read_until(b'\r\n', callback) except Exception as e: self.on_disconnect() def try_to_perform_read(self): if not self.in_progress and self.read_queue: self.in_progress = True self._io_loop.add_callback(partial(self.read_queue.pop(0), None)) @async def queue_wait(self, callback): self.read_queue.append(callback) self.try_to_perform_read() def read_done(self): self.in_progress = False self.try_to_perform_read() def 
connected(self): if self._stream: return True return False
class NTcpConnector(object):
    """Client for a little-endian, binary-framed TCP protocol.

    Requests are tagged with a serial number (``sn``) from the caller's
    application.  Replies carrying a registered ``sn`` are decoded into a dict
    of named fields and passed to that client's ``callback``.  The field names
    suggest order-book market data.
    """

    # Command 110 ("ten-level quotes"): 2 x int32, then 22 (uint32, int64)
    # pairs: bids 1..10, bid aggregate, offers 1..10, offer aggregate.
    # (Generator expressions keep loop names out of the class namespace on
    # Python 2.)
    _QUOTE_FIELDS = (
        ('nSecurityID', 'nTime')
        + tuple('%s%d' % (name, k) for k in range(1, 11)
                for name in ('nPxBid', 'llVolumeBid'))
        + ('nWeightedAvgBidPx', 'llTotalBidVolume')
        + tuple('%s%d' % (name, k) for k in range(1, 11)
                for name in ('nPxOffer', 'llVolumeOffer'))
        + ('nWeightedAvgOfferPx', 'llTotalOfferVolume')
    )

    # Command 165 ("order details"): int32, int32, uint32, then 153 int32.
    # That is six named header values followed by 50 (status, volume, change)
    # triples.
    _ORDER_FIELDS = (
        ('nSecurityID', 'nTime', 'nPx', 'nLevel', 'nOrderCount', 'nRevealCount')
        + tuple('%s%d' % (name, k) for k in range(1, 51)
                for name in ('nStatus', 'nVolume', 'nChangeVolume'))
    )

    # nCmdId -> (required frame length, struct format of the payload, field names)
    _LAYOUTS = {
        110: (292, '<2i' + 'Iq' * 22, _QUOTE_FIELDS),
        165: (644, '<2iI3i150i', _ORDER_FIELDS),
    }

    def __init__(self, host, port):
        self.routes = {}  # sn -> client awaiting the reply
        self.host = host
        self.port = port
        self._s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        self.stream = IOStream(self._s)
        self.stream.connect((self.host, self.port), self._start_recv)

    def unregister(self, client):
        """Drop every route that points at ``client``."""
        self.routes = {sn: c for sn, c in self.routes.items() if c != client}

    def __lt__(self, other):
        # Arbitrary but consistent ordering by identity, presumably so
        # connectors can sit in sorted/heap containers.
        return id(self) < id(other)

    def sendMsg(self, client, content):
        """Send ``content`` (text) as a command-10020 request for ``client``.

        The frame is a header ``<i6I`` followed by the UTF-8 payload and a
        uint32 check value ``(20 + payload length) ^ 0xaaaaaaaa``.
        """
        sn = client.application.proxys.getSN()
        self.routes[sn] = client
        # Encode before measuring: lengths must count bytes, not characters.
        # Using len(content) truncated non-ASCII payloads via '%ds' and
        # produced a wrong frame length.
        payload = content.encode('utf-8')
        body_len = 20 + len(payload)
        data = struct.pack('<i6I%dsI' % len(payload), int(-1), 10020,
                           body_len, sn, 0, int(time.time()), 1, payload,
                           int(body_len ^ 0xaaaaaaaa))
        self.stream.write(data)

    def is_connected(self):
        return not self.stream.closed()

    def invalidate(self):
        self.stream.close_fd()

    def _start_recv(self):
        # Every frame starts with a 12-byte header (int32, uint32, uint32).
        self.stream.read_bytes(12, self._on_frame)

    def _on_frame(self, data):
        # The third header field is the length of the frame body.
        nLen = struct.unpack('<i2I', data)[2]
        self.stream.read_bytes(nLen, self._on_msg)

    def _on_msg(self, data):
        """Handle one frame body, then wait for the next frame."""
        nLen = len(data)
        # sn, tag, time and command id (4 x uint32), the payload, and a
        # trailing uint32 that is ignored.
        sn, nTag, nTime, nCmdId, dataS = struct.unpack(
            '<4I%dsI' % (nLen - 20), data)[0:-1]
        if sn == 0:
            # Unsolicited frame (presumably a keep-alive): answer with command 10000.
            self.stream.write(
                struct.pack('<i7I', int(-1), 10000, 20, 0, 0,
                            int(time.time()), 0, int(20 ^ 0xaaaaaaaa)))
        elif sn > 0 and (sn in self.routes):
            fs = {}
            layout = self._LAYOUTS.get(nCmdId)
            # Decode only known commands whose frame length matches.
            if layout is not None and layout[0] == nLen:
                _, body_format, fields = layout
                fs.update(zip(fields, struct.unpack(body_format, dataS)))
            fs['nCmdId'] = nCmdId
            # NOTE(review): the original formatting is ambiguous about whether
            # undecoded commands should be delivered.  Here every routed reply
            # is delivered; undecoded ones carry only 'nCmdId'.
            self.routes[sn].callback(fs)
        self._start_recv()
class TornadoClient(Client):
    """Non-blocking Pomelo client running on the tornado IOLoop.

    The ``handler`` object receives optional event hooks, for example::

        class ClientHandler(object):
            def on_recv_data(self, client, proto_type, data):
                print("recv_data...")
                return data

            def on_connected(self, client, user_data):
                print("connect...")
                client.send_heartbeat()

            def on_disconnect(self, client):
                print("disconnect...")

            def on_heartbeat(self, client):
                print("heartbeat...")
                # send requests here ...

            def on_response(self, client, route, request, response):
                print("response...")

            def on_push(self, client, route, push_data):
                print("notify...")

        handler = ClientHandler()
        client = TornadoClient(handler)
        client.connect(host, int(port))
        client.run()
        tornado.ioloop.IOLoop.current().start()
    """

    def __init__(self, handler):
        self.socket = socket(AF_INET, SOCK_STREAM)
        self.iostream = None
        self.protocol_package = None  # package currently being received
        super(TornadoClient, self).__init__(handler)

    def connect(self, host, port):
        """Wrap the socket in an IOStream and start connecting to ``host:port``."""
        stream = IOStream(self.socket)
        self.iostream = stream
        stream.set_close_callback(self.on_close)
        stream.connect((host, port), self.on_connect)

    def on_connect(self):
        # Send the handshake first, then enter the read loop.
        self.send_sync()
        self.on_data()

    def on_close(self):
        if hasattr(self.handler, 'on_disconnect'):
            self.handler.on_disconnect(self)

    def send(self, data):
        """Write ``data`` to the stream, coercing it to ``bytes`` first."""
        assert not self.iostream.closed(), "iostream has closed"
        payload = data if isinstance(data, bytes) else bytes(data)
        self.iostream.write(payload)

    def on_data(self):
        """Request the next 4-byte package header unless a package is mid-read."""
        assert not self.iostream.closed(), "iostream has closed"
        package = self.protocol_package
        if package is None or package.completed():
            self.iostream.read_bytes(4, self.on_head)

    def on_head(self, head):
        """Decode a package header, then read that package's body."""
        package = Protocol.unpack(head)
        self.protocol_package = package
        self.iostream.read_bytes(package.length, self.on_body)

    def on_body(self, body):
        """Let the handler filter the body, complete the package, and loop."""
        if hasattr(self.handler, 'on_recv_data'):
            body = self.handler.on_recv_data(
                self, self.protocol_package.proto_type, body)
        package = self.protocol_package
        package.append(body)
        self.on_protocol(package)
        self.on_data()

    def close(self):
        stream = self.iostream
        if stream:
            stream.close()
class Connection(object): """ Encapsulates the communication, including parsing, with the beanstalkd """ def __init__(self, host, port, io_loop=None): self._ioloop = io_loop or IOLoop.instance() # setup our protocol callbacks # beanstalkd will reply with a superset of these replies, but these # are the only ones we handle today. patches gleefully accepted. self._beanstalk_protocol_1x = dict( # generic returns OUT_OF_MEMORY = self.fail, INTERNAL_ERROR = self.fail, DRAINING = self.fail, BAD_FORMAT = self.fail, UNKNOWN_COMMAND = self.fail, # put <pri> <delay> <ttr> <bytes> INSERTED = self.ret_inserted, BURIED = self.ret_inserted, EXPECTED_CRLF = self.fail, JOB_TOO_BIG = self.fail, # use USING = None, # reserve RESERVED = self.ret_reserved, DEADLINE_SOON = None, TIMED_OUT = None, # delete <id> DELETED = None, NOT_FOUND = None, # touch <id> TOUCHED = None, # watch <tube> WATCHING = None, #ignore <tube> NOT_IGNORED = None, ) # open a connection to the beanstalkd _sock = socket.socket( socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP ) _sock.connect((host, port)) _sock.setblocking(False) self.stream = IOStream(_sock, io_loop=self._ioloop) # i like a placeholder for this. we'll assign it later self.callback = None self.tsr = TornStalkResponse() def _parse_response(self, resp): print "parse_response" tokens = resp.strip().split() if not tokens: return print 'tok:', tokens[1:] self._beanstalk_protocol_1x.get(tokens[0])(tokens) def _payload_rcvd(self, payload): self.tsr.data = payload[:-2] # lose the \r\n self.callback(self.tsr) # lose the \r\n def _command(self, contents): print "sending>%s<" % contents self.stream.write(contents) self.stream.read_until('\r\n', self._parse_response) def cmd_put(self, body, callback, priority=10000, delay=0, ttr=1): """ send the put command to the beanstalkd with a message body priority needs to be between 0 and 2**32. 
lower gets done first delay is number of seconds before job is available in queue ttr is number of seconds the job has to run by a worker bs: put <pri> <delay> <ttr> <bytes> """ self.callback = callback cmd = 'put {priority} {delay} {ttr} {size}'.format( priority = priority, delay = delay, ttr = ttr, size = len(body) ) payload = '{}\r\n{}\r\n'.format(cmd, body) self._command(payload) def cmd_reserve(self, callback): self.callback = callback cmd = 'reserve\r\n' self._command(cmd) def ret_inserted(self, toks): """ handles both INSERTED and BURIED """ jobid = int(toks[1]) self.callback(TornStalkResponse(data=jobid)) def ret_reserved(self, toks): jobid, size = toks[1:] jobid = int(jobid) size = int(size) + 2 # len('\r\n') self.stream.read_bytes(size, self._payload_rcvd) def handle_error(self, *a): print "error", a raise TornStalkError(a) def ok(self, *a): print "ok", a return True def fail(self, toks): self.callback(TornStalkResponse(result=False, msg=toks[1]))
class KeepAliveTest(AsyncHTTPTestCase):
    """Tests various scenarios for HTTP 1.1 keep-alive support.

    These tests don't use AsyncHTTPClient because we want to control
    connection reuse and closing.  Each test drives a raw IOStream through
    the small manual client below (connect/read_headers/read_response/close).
    """

    def get_app(self):
        """Build the app under test: "/" (hello), "/large" (512KB body) and
        "/finish_on_close" (held open until the test releases it)."""

        class HelloHandler(RequestHandler):
            def get(self):
                self.finish("Hello world")

            def post(self):
                self.finish("Hello world")

        class LargeHandler(RequestHandler):
            def get(self):
                # 512KB should be bigger than the socket buffers so it will
                # be written out in chunks.
                self.write("".join(chr(i % 256) * 1024 for i in range(512)))

        class FinishOnCloseHandler(RequestHandler):
            def initialize(self, cleanup_event):
                self.cleanup_event = cleanup_event

            @gen.coroutine
            def get(self):
                self.flush()
                # Keep the request open until the test sets cleanup_event.
                yield self.cleanup_event.wait()

            def on_connection_close(self):
                # This is not very realistic, but finishing the request
                # from the close callback has the right timing to mimic
                # some errors seen in the wild.
                self.finish("closed")

        self.cleanup_event = Event()
        return Application(
            [
                ("/", HelloHandler),
                ("/large", LargeHandler),
                (
                    "/finish_on_close",
                    FinishOnCloseHandler,
                    dict(cleanup_event=self.cleanup_event),
                ),
            ]
        )

    def setUp(self):
        super(KeepAliveTest, self).setUp()
        # Recorded by some tests; no method in this class reads it
        # (read_headers expects "HTTP/1.1 200" regardless).
        self.http_version = b"HTTP/1.1"

    def tearDown(self):
        # We just closed the client side of the socket; let the IOLoop run
        # once to make sure the server side got the message.
        self.io_loop.add_timeout(datetime.timedelta(seconds=0.001), self.stop)
        self.wait()

        # Close the stream if a test exited without calling self.close().
        if hasattr(self, "stream"):
            self.stream.close()
        super(KeepAliveTest, self).tearDown()

    # The next few methods are a crude manual http client
    @gen.coroutine
    def connect(self):
        """Open self.stream to the test server."""
        self.stream = IOStream(socket.socket())
        yield self.stream.connect(("127.0.0.1", self.get_http_port()))

    @gen.coroutine
    def read_headers(self):
        """Read one response's status line (asserting 200) and headers."""
        first_line = yield self.stream.read_until(b"\r\n")
        self.assertTrue(first_line.startswith(b"HTTP/1.1 200"), first_line)
        header_bytes = yield self.stream.read_until(b"\r\n\r\n")
        headers = HTTPHeaders.parse(header_bytes.decode("latin1"))
        raise gen.Return(headers)

    @gen.coroutine
    def read_response(self):
        """Read a full response into self.headers and check the hello body."""
        self.headers = yield self.read_headers()
        body = yield self.stream.read_bytes(int(self.headers["Content-Length"]))
        self.assertEqual(b"Hello world", body)

    def close(self):
        """Close the client stream and drop it so tearDown won't close it again."""
        self.stream.close()
        del self.stream

    @gen_test
    def test_two_requests(self):
        # Two sequential requests reuse one connection.
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.1\r\n\r\n")
        yield self.read_response()
        self.stream.write(b"GET / HTTP/1.1\r\n\r\n")
        yield self.read_response()
        self.close()

    @gen_test
    def test_request_close(self):
        # "Connection: close" makes the server close after the response.
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.1\r\nConnection: close\r\n\r\n")
        yield self.read_response()
        data = yield self.stream.read_until_close()
        self.assertTrue(not data)
        self.assertEqual(self.headers["Connection"], "close")
        self.close()

    # keepalive is supported for http 1.0 too, but it's opt-in
    @gen_test
    def test_http10(self):
        self.http_version = b"HTTP/1.0"
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.0\r\n\r\n")
        yield self.read_response()
        data = yield self.stream.read_until_close()
        self.assertTrue(not data)
        self.assertTrue("Connection" not in self.headers)
        self.close()

    @gen_test
    def test_http10_keepalive(self):
        self.http_version = b"HTTP/1.0"
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.stream.write(b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.close()

    @gen_test
    def test_http10_keepalive_extra_crlf(self):
        # A stray CRLF after the first request must not break the second.
        self.http_version = b"HTTP/1.0"
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.stream.write(b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.close()

    @gen_test
    def test_pipelined_requests(self):
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\n")
        yield self.read_response()
        yield self.read_response()
        self.close()

    @gen_test
    def test_pipelined_cancel(self):
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\n")
        # only read once
        yield self.read_response()
        self.close()

    @gen_test
    def test_cancel_during_download(self):
        # Close the client while the large body is still being written.
        yield self.connect()
        self.stream.write(b"GET /large HTTP/1.1\r\n\r\n")
        yield self.read_headers()
        yield self.stream.read_bytes(1024)
        self.close()

    @gen_test
    def test_finish_while_closed(self):
        yield self.connect()
        self.stream.write(b"GET /finish_on_close HTTP/1.1\r\n\r\n")
        yield self.read_headers()
        self.close()
        # Let the hanging coroutine clean up after itself
        self.cleanup_event.set()

    @gen_test
    def test_keepalive_chunked(self):
        # An HTTP/1.0 keep-alive request with an empty chunked body.
        self.http_version = b"HTTP/1.0"
        yield self.connect()
        self.stream.write(
            b"POST / HTTP/1.0\r\n"
            b"Connection: keep-alive\r\n"
            b"Transfer-Encoding: chunked\r\n"
            b"\r\n"
            b"0\r\n"
            b"\r\n"
        )
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.stream.write(b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.close()
class TestIOStream(AsyncHTTPTestCase, LogTrapTestCase):
    """IOStream tests against the test HTTP server and loopback stream pairs."""

    def get_app(self):
        return Application([('/', HelloHandler)])

    def make_iostream_pair(self, **kwargs):
        """Return ``[server_stream, client_stream]`` connected over loopback.

        Extra keyword arguments are passed to both ``IOStream`` constructors
        (e.g. ``read_chunk_size``).  The temporary listener is unregistered
        from the IOLoop and closed afterwards so it does not leak.
        """
        port = get_unused_port()
        [listener] = netutil.bind_sockets(port, '127.0.0.1',
                                          family=socket.AF_INET)
        streams = [None, None]

        def accept_callback(connection, address):
            streams[0] = IOStream(connection, io_loop=self.io_loop, **kwargs)
            self.stop()

        def connect_callback():
            streams[1] = client_stream
            self.stop()

        netutil.add_accept_handler(listener, accept_callback,
                                   io_loop=self.io_loop)
        try:
            client_stream = IOStream(socket.socket(), io_loop=self.io_loop,
                                     **kwargs)
            client_stream.connect(('127.0.0.1', port),
                                  callback=connect_callback)
            self.wait(condition=lambda: all(streams))
        finally:
            # Only one connection is needed: stop accepting, free the port.
            self.io_loop.remove_handler(listener.fileno())
            listener.close()
        return streams

    def test_read_zero_bytes(self):
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        s.connect(("localhost", self.get_http_port()))
        self.stream = IOStream(s, io_loop=self.io_loop)
        try:
            self.stream.write(b("GET / HTTP/1.0\r\n\r\n"))

            # normal read
            self.stream.read_bytes(9, self.stop)
            data = self.wait()
            self.assertEqual(data, b("HTTP/1.0 "))

            # zero bytes
            self.stream.read_bytes(0, self.stop)
            data = self.wait()
            self.assertEqual(data, b(""))

            # another normal read
            self.stream.read_bytes(3, self.stop)
            data = self.wait()
            self.assertEqual(data, b("200"))
        finally:
            # Don't leak the client socket into later tests.
            self.stream.close()

    def test_connection_refused(self):
        # When a connection is refused, the connect callback should not
        # be run.  (The kqueue IOLoop used to behave differently from the
        # epoll IOLoop in this respect)
        port = get_unused_port()
        stream = IOStream(socket.socket(), self.io_loop)
        self.connect_called = False

        def connect_callback():
            self.connect_called = True
        stream.set_close_callback(self.stop)
        stream.connect(("localhost", port), connect_callback)
        self.wait()
        self.assertFalse(self.connect_called)

    def test_connection_closed(self):
        # When a server sends a response and then closes the connection,
        # the client must be allowed to read the data before the IOStream
        # closes itself.  Epoll reports closed connections with a separate
        # EPOLLRDHUP event delivered at the same time as the read event,
        # while kqueue reports them as a second read/write event with an EOF
        # flag.
        response = self.fetch("/", headers={"Connection": "close"})
        response.rethrow()

    def test_read_until_close(self):
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        s.connect(("localhost", self.get_http_port()))
        stream = IOStream(s, io_loop=self.io_loop)
        stream.write(b("GET / HTTP/1.0\r\n\r\n"))

        stream.read_until_close(self.stop)
        data = self.wait()
        self.assertTrue(data.startswith(b("HTTP/1.0 200")))
        self.assertTrue(data.endswith(b("Hello")))

    def test_streaming_callback(self):
        server, client = self.make_iostream_pair()
        try:
            chunks = []
            final_called = []

            def streaming_callback(data):
                chunks.append(data)
                self.stop()

            def final_callback(data):
                assert not data
                final_called.append(True)
                self.stop()
            server.read_bytes(6, callback=final_callback,
                              streaming_callback=streaming_callback)
            client.write(b("1234"))
            self.wait(condition=lambda: chunks)
            client.write(b("5678"))
            self.wait(condition=lambda: final_called)
            self.assertEqual(chunks, [b("1234"), b("56")])

            # the rest of the last chunk is still in the buffer
            server.read_bytes(2, callback=self.stop)
            data = self.wait()
            self.assertEqual(data, b("78"))
        finally:
            server.close()
            client.close()

    def test_streaming_until_close(self):
        server, client = self.make_iostream_pair()
        try:
            chunks = []

            def callback(data):
                chunks.append(data)
                self.stop()
            client.read_until_close(callback=callback,
                                    streaming_callback=callback)
            server.write(b("1234"))
            self.wait()
            server.write(b("5678"))
            self.wait()
            server.close()
            self.wait()
            self.assertEqual(chunks, [b("1234"), b("5678"), b("")])
        finally:
            server.close()
            client.close()
class KeepAliveTest(AsyncHTTPTestCase):
    """Tests various scenarios for HTTP 1.1 keep-alive support.

    These tests don't use AsyncHTTPClient because we want to control
    connection reuse and closing.  A raw IOStream is driven synchronously
    with the stop/wait helpers through the manual client methods below.
    """
    def get_app(self):
        """Build the app under test: "/" (hello), "/large" (512KB body) and
        "/finish_on_close" (only finished from its close callback)."""
        class HelloHandler(RequestHandler):
            def get(self):
                self.finish('Hello world')

        class LargeHandler(RequestHandler):
            def get(self):
                # 512KB should be bigger than the socket buffers so it will
                # be written out in chunks.
                self.write(''.join(chr(i % 256) * 1024 for i in range(512)))

        class FinishOnCloseHandler(RequestHandler):
            # @asynchronous: the response is not finished when get() returns;
            # this handler only calls finish() from on_connection_close.
            @asynchronous
            def get(self):
                self.flush()

            def on_connection_close(self):
                # This is not very realistic, but finishing the request
                # from the close callback has the right timing to mimic
                # some errors seen in the wild.
                self.finish('closed')

        return Application([('/', HelloHandler),
                            ('/large', LargeHandler),
                            ('/finish_on_close', FinishOnCloseHandler)])

    def setUp(self):
        super(KeepAliveTest, self).setUp()
        # Status-line version that read_headers expects; HTTP/1.0 tests
        # override it.
        self.http_version = b'HTTP/1.1'

    def tearDown(self):
        # We just closed the client side of the socket; let the IOLoop run
        # once to make sure the server side got the message.
        self.io_loop.add_timeout(datetime.timedelta(seconds=0.001), self.stop)
        self.wait()

        # Close the stream if a test exited without calling self.close().
        if hasattr(self, 'stream'):
            self.stream.close()
        super(KeepAliveTest, self).tearDown()

    # The next few methods are a crude manual http client
    def connect(self):
        """Open self.stream to the test server and wait for the connection."""
        self.stream = IOStream(socket.socket(), io_loop=self.io_loop)
        self.stream.connect(('localhost', self.get_http_port()), self.stop)
        self.wait()

    def read_headers(self):
        """Read one response's status line and headers; return the headers.

        Asserts the status line is ``<self.http_version> 200``.
        """
        self.stream.read_until(b'\r\n', self.stop)
        first_line = self.wait()
        self.assertTrue(first_line.startswith(self.http_version + b' 200'), first_line)
        self.stream.read_until(b'\r\n\r\n', self.stop)
        header_bytes = self.wait()
        headers = HTTPHeaders.parse(header_bytes.decode('latin1'))
        return headers

    def read_response(self):
        """Read a full response into self.headers and check the hello body."""
        self.headers = self.read_headers()
        self.stream.read_bytes(int(self.headers['Content-Length']), self.stop)
        body = self.wait()
        self.assertEqual(b'Hello world', body)

    def close(self):
        """Close the client stream and drop it so tearDown won't close it again."""
        self.stream.close()
        del self.stream

    def test_two_requests(self):
        # Two sequential requests reuse one connection.
        self.connect()
        self.stream.write(b'GET / HTTP/1.1\r\n\r\n')
        self.read_response()
        self.stream.write(b'GET / HTTP/1.1\r\n\r\n')
        self.read_response()
        self.close()

    def test_request_close(self):
        # "Connection: close" makes the server close after the response.
        self.connect()
        self.stream.write(b'GET / HTTP/1.1\r\nConnection: close\r\n\r\n')
        self.read_response()
        self.stream.read_until_close(callback=self.stop)
        data = self.wait()
        self.assertTrue(not data)
        self.close()

    # keepalive is supported for http 1.0 too, but it's opt-in
    def test_http10(self):
        self.http_version = b'HTTP/1.0'
        self.connect()
        self.stream.write(b'GET / HTTP/1.0\r\n\r\n')
        self.read_response()
        self.stream.read_until_close(callback=self.stop)
        data = self.wait()
        self.assertTrue(not data)
        self.assertTrue('Connection' not in self.headers)
        self.close()

    def test_http10_keepalive(self):
        self.http_version = b'HTTP/1.0'
        self.connect()
        self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n')
        self.read_response()
        self.assertEqual(self.headers['Connection'], 'Keep-Alive')
        self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n')
        self.read_response()
        self.assertEqual(self.headers['Connection'], 'Keep-Alive')
        self.close()

    def test_pipelined_requests(self):
        self.connect()
        self.stream.write(b'GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\n')
        self.read_response()
        self.read_response()
        self.close()

    def test_pipelined_cancel(self):
        self.connect()
        self.stream.write(b'GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\n')
        # only read once
        self.read_response()
        self.close()

    def test_cancel_during_download(self):
        # Close the client while the large body is still being written.
        self.connect()
        self.stream.write(b'GET /large HTTP/1.1\r\n\r\n')
        self.read_headers()
        self.stream.read_bytes(1024, self.stop)
        self.wait()
        self.close()

    def test_finish_while_closed(self):
        self.connect()
        self.stream.write(b'GET /finish_on_close HTTP/1.1\r\n\r\n')
        self.read_headers()
        self.close()
class Connection(Connection):
    """Tornado ``IOStream`` variant of the base Kafka ``Connection``.

    ``connect`` is asynchronous.  ``send``/``recv`` calls made before it
    completes are queued and replayed in request order once the stream is
    connected.
    """

    def __init__(self, pool=None, *args, **kwargs):
        super(Connection, self).__init__(*args, **kwargs)
        self._pool = pool
        self._stream = None
        self._callbacks = []  # operations deferred until the connect completes
        self._ready = False

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Return this connection to the pool it came from.
        self._pool.release(self)

    def _add_callback(self, func):
        self._callbacks.append(func)

    def _do_callbacks(self):
        """Mark the connection ready and replay deferred operations FIFO.

        Earlier versions popped from the end of the list (LIFO), which
        replayed queued requests in reverse.  They also treated any
        ``IndexError`` raised *inside* a callback as "queue empty".  A
        ``close()`` from a callback still stops the drain by clearing the list.
        """
        self._ready = True
        while self._callbacks:
            func = self._callbacks.pop(0)
            try:
                func()
            except Exception:
                # Best-effort: keep draining, but don't hide the failure.
                log.exception("Deferred connection callback failed")

    def connect(self):
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        self._sock = IOStream(s)  # tornado iostream
        self._sock.connect((self._host, self._port), self._do_callbacks)

    def send(self, payload, correlation_id=-1, callback=None):
        """
        :param payload: an encoded kafka packet (None/empty is sent as -1)
        :param correlation_id: for now, just for debug logging
        :param callback: passed to ``IOStream.write``; called with None if
            the write fails
        :return:
        """
        if not self._ready:
            def _callback(*args, **kwargs):
                self.send(payload, correlation_id, callback)
            self._add_callback(_callback)
            return

        # len(payload or b'') so a None payload reaches the -1 branch below
        # instead of failing here.
        log.debug("About to send %d bytes to Kafka, request %d",
                  len(payload or b''), correlation_id)
        if payload:
            _bytes = struct.pack('>i%ds' % len(payload), len(payload), payload)
        else:
            _bytes = struct.pack('>i', -1)
        try:
            self._sock.write(_bytes, callback)  # simply using sendall
        except Exception:
            self.close()
            if callback is not None:
                callback(None)
            self._log_and_raise('Unable to send payload to Kafka')

    def _recv(self, size, callback):
        """Read exactly ``size`` bytes and pass them to ``callback``.

        Previously reads were capped at 4096 bytes.  That truncated larger
        responses and left the rest to be misread as the next length prefix.
        """
        try:
            self._sock.read_bytes(size, callback)
        except Exception:
            self.close()
            callback(None)  # if error, set None
            self._log_and_raise('Unable to receive data from Kafka')

    def recv(self, correlation_id=-1, callback=None):
        """
        :param correlation_id: for now, just for debug logging
        :param callback: receives the kafka response packet, or None on error
        :return:
        """
        log.debug("Reading response %d from Kafka", correlation_id)
        if not self._ready:
            def _callback():
                self.recv(correlation_id, callback)
            self._add_callback(_callback)
            return

        def get_size(resp):
            if resp is None:
                # The length read failed; report it and stop (the original
                # fell through to struct.unpack(None)).
                callback(None)
                return
            size, = struct.unpack('>i', resp)
            self._recv(size, callback)

        self._recv(4, get_size)  # read the response length

    def close(self):
        self._callbacks = []
        log.debug("Closing socket connection" + self._log_tail)
        if self._sock:
            self._sock.close()
            self._sock = None
        else:
            log.debug("Socket connection not exists" + self._log_tail)

    def closed(self):
        """True if there is no stream (e.g. after ``close()``) or it is closed."""
        sock = getattr(self, '_sock', None)
        return sock is None or sock.closed()
class _RPCClientConnection(object): '''An RPC client connection.''' def __init__(self, close_callback): self._stream = None self._sequence = itertools.count() self._pending = {} # sequence -> callback self._pending_read = None self._close_callback = stack_context.wrap(close_callback) @gen.engine def connect(self, address, nonce=protocol.NULL_NONCE, callback=None): if self._stream is not None: raise RuntimeError('Attempting to reconnect existing connection') sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0) sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1) self._stream = IOStream(sock) self._stream.set_close_callback(self._handle_close) # If the connect fails, our close callback will be called, and # the Wait will never return. self._stream.connect((address, protocol.PORT), callback=(yield gen.Callback('connect'))) self._write(nonce) yield gen.Wait('connect') rnonce = yield gen.Task(self._read, protocol.NONCE_LEN) # Start reply-handling coroutine self._reply_handler() if callback is not None: callback(rnonce) def close(self): if self._stream is not None: self._stream.close() else: # close() called before connect(). Synthesize the close event # ourselves. self._handle_close() def _handle_close(self): if self._pending_read is not None: # The pending read callback will never be called. Call it # ourselves to clean up. self._pending_read(None) if self._close_callback is not None: cb = self._close_callback self._close_callback = None cb() @gen.engine def _read(self, count, callback=None): if self._pending_read is not None: raise RuntimeError('Double read on connection') self._pending_read = stack_context.wrap((yield gen.Callback('read'))) try: self._stream.read_bytes(count, callback=self._pending_read) buf = yield gen.Wait('read') if buf is None: # _handle_close() is cleaning us up raise ConnectionFailure('Connection closed') except IOError, e: self.close() raise ConnectionFailure(str(e)) finally:
class AsyncRedisClient(object):
    """A non-blocking Redis client.

    Example usage::

        import ioloop

        def handle_request(result):
            print 'Redis reply: %r' % result
            ioloop.IOLoop.instance().stop()

        redis_client = AsyncRedisClient(('127.0.0.1', 6379))
        redis_client.fetch(('set', 'foo', 'bar'), None)
        redis_client.fetch(('get', 'foo'), handle_request)
        ioloop.IOLoop.instance().start()

    This class implements a Redis client on top of Tornado's IOStreams.
    It does not currently implement all applicable parts of the Redis
    specification, but it does enough to work with major redis server APIs
    (mostly tested against the LIST/HASH/PUBSUB API so far).

    This class has not been tested extensively in production and
    should be considered somewhat experimental as of the release of
    tornado 1.2.  It is intended to become the default tornado
    AsyncRedisClient implementation.
    """

    def __init__(self, address, io_loop=None):
        """Creates an AsyncRedisClient.

        ``address`` is the (host, port) tuple of the redis server that is
        passed to ``IOStream.connect``, e.g. ``('127.0.0.1', 6379)``.
        """
        self.address = address
        self.io_loop = io_loop or IOLoop.instance()
        self._callback_queue = deque()
        self._callback = None
        self._read_buffer = None
        self._result_queue = deque()
        self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.stream = IOStream(self.socket, self.io_loop)
        # Start reading replies as soon as the connection is established.
        self.stream.connect(self.address, self._wait_result)

    def close(self):
        """Destroys this redis client, freeing any file descriptors used.

        Not needed in normal use, but may be helpful in unittests that
        create and destroy redis clients.  No other methods may be called
        on the AsyncRedisClient after close().
        """
        self.stream.close()

    def fetch(self, request, callback):
        """Sends ``request`` and queues ``callback`` for its reply.

        The request should be a tuple of strings, like
        ``('set', 'foo', 'bar')``.  ``callback`` is called with
        ``decode(...)`` of the raw reply lines; it may be None to ignore the
        reply.  If decoding or the callback raises, the error is logged, the
        client is closed and the exception is re-raised (``decode``
        presumably raises ``RedisError`` for error replies -- confirm).
        """
        self._callback_queue.append(callback)
        self.stream.write(encode(request))

    def _wait_result(self):
        """Start reading the first line of the next reply."""
        self._read_buffer = deque()
        self.stream.read_until('\r\n', self._on_read_first_line)

    def _maybe_callback(self):
        """Deliver the reply collected in ``_read_buffer`` to its callback.

        When no callback is queued, the previously used callback is reused
        (this lets PUBSUB messages reach the subscriber's callback).
        """
        try:
            read_buffer = self._read_buffer
            callback = self._callback
            result_queue = self._result_queue
            callback_queue = self._callback_queue
            if result_queue:
                result_queue.append(read_buffer)
                read_buffer = result_queue.popleft()
            if callback_queue:
                callback = self._callback = callback_queue.popleft()
            if callback:
                callback(decode(read_buffer))
        except Exception:
            logging.error('Uncaught callback exception', exc_info=True)
            self.close()
            raise
        finally:
            # Re-arming a read on a closed stream raises, which would replace
            # the exception being propagated above (or break a callback that
            # closed the client on purpose).
            if not self.stream.closed():
                self._wait_result()

    def _on_read_first_line(self, data):
        """Dispatch on the RESP type byte of a reply's first line."""
        self._read_buffer.append(data)
        c = data[0]
        if c in ':+-':
            self._maybe_callback()
        elif c == '$':
            if data[:3] == '$-1':
                self._maybe_callback()
            else:
                length = int(data[1:])
                # +2 for the trailing CRLF of the bulk payload.
                self.stream.read_bytes(length + 2, self._on_read_bulk_body)
        elif c == '*':
            if data[1] in '-0':
                # Empty ('*0') or null ('*-1') array: nothing more to read.
                self._maybe_callback()
            else:
                self._multibulk_number = int(data[1:])
                self.stream.read_until('\r\n',
                                       self._on_read_multibulk_bulk_head)

    def _on_read_bulk_body(self, data):
        """Store a top-level bulk payload and deliver the reply."""
        self._read_buffer.append(data)
        self._maybe_callback()

    def _on_read_multibulk_bulk_head(self, data):
        """Handle the header line of one element of a multibulk reply."""
        self._read_buffer.append(data)
        c = data[0]
        if c == '$':
            length = int(data[1:])
            if length >= 0:
                self.stream.read_bytes(length + 2,
                                       self._on_read_multibulk_bulk_body)
            else:
                # '$-1' is a null element: there is no payload line to read.
                self._on_multibulk_element_done()
        elif c in ':+-':
            # Single-line elements (integer/status/error) are complete.
            self._on_multibulk_element_done()
        else:
            # Nested arrays are not tracked; deliver what we have as before.
            self._maybe_callback()

    def _on_read_multibulk_bulk_body(self, data):
        """Store one multibulk payload and move on to the next element."""
        self._read_buffer.append(data)
        self._on_multibulk_element_done()

    def _on_multibulk_element_done(self):
        """Count one finished array element; read the next or deliver."""
        self._multibulk_number -= 1
        if self._multibulk_number:
            self.stream.read_until('\r\n', self._on_read_multibulk_bulk_head)
        else:
            self._maybe_callback()
class TTornadoTransport(TTransportBase):
    """A non-blocking Thrift client transport built on a Tornado IOStream.

    Every blocking-looking operation (open/read/write) issues the async
    IOStream call and then switches to the parent greenlet; the IOStream
    callback switches back into the calling greenlet with the result.  The
    transport must therefore be used from a child greenlet.

    Example usage::

        import greenlet
        from tornado import ioloop
        from thrift.transport import TTransport
        from thrift.protocol import TBinaryProtocol
        from viewfinder.backend.thrift import TTornadoTransport

        transport = TTransport.TFramedTransport(TTornadoTransport('localhost', 9090))
        protocol = TBinaryProtocol.TBinaryProtocol(transport)
        client = Service.Client(protocol)
        ioloop.IOLoop.instance().start()

    Then, from within an asynchronous tornado request handler:

        class MyApp(tornado.web.RequestHandler):
            @tornado.web.asynchronous
            def post(self):
                def business_logic():
                    ...any thrift calls...
                    self.write(...stuff that gets returned to client...)
                    self.finish()  # end the asynchronous request

                gr = greenlet.greenlet(business_logic)
                gr.switch()
    """
    def __init__(self, host='localhost', port=9090):
        """Initialize a TTornadoTransport; no connection is made until open().

        @param host(str)  The host to connect to.
        @param port(int)  The (TCP) port to connect to.
        """
        self.host = host
        self.port = port
        # None means "not open" (see isOpen / _check_stream).
        self._stream = None
        self._io_loop = ioloop.IOLoop.current()
        # Seconds; falsy disables the timeout in _set_timeout.
        self._timeout_secs = None

    def set_timeout(self, timeout_secs):
        """Sets a timeout (in seconds) for use with open/read/write operations."""
        self._timeout_secs = timeout_secs

    def isOpen(self):
        """Return True while an IOStream is attached to this transport."""
        return self._stream is not None

    def open(self):
        """Connect to host:port using a new tornado IOStream.

        Resolves the address (IPv4 only), creates the IOStream and then
        delegates to _open_internal, which issues the asynchronous connect
        and switches to the parent greenlet (presumably the "master"
        greenlet) until the connect callback resumes this one.
        """
        # Suspending requires a parent greenlet to switch to.
        assert greenlet.getcurrent().parent is not None
        # TODO(spencer): allow ipv6? (af = socket.AF_UNSPEC)
        addrinfo = socket.getaddrinfo(self.host, self.port, socket.AF_INET,
                                      socket.SOCK_STREAM, 0, 0)
        # Only the first resolved address is tried.
        af, socktype, proto, canonname, sockaddr = addrinfo[0]
        self._stream = IOStream(socket.socket(af, socktype, proto),
                                io_loop=self._io_loop)
        self._open_internal(sockaddr)

    def close(self):
        """Close the stream, if any, without triggering the close callback."""
        if self._stream:
            # Clear the close callback first so closing on purpose does not
            # go through the error path.
            self._stream.set_close_callback(None)
            self._stream.close()
            self._stream = None

    @_wrap_transport
    def read(self, sz):
        """Read exactly ``sz`` bytes, suspending this greenlet until they arrive.

        Raises TTransportException(END_OF_FILE) if an empty buffer is
        returned.
        """
        logging.debug("reading %d bytes from %s:%d" % (sz, self.host, self.port))
        cur_gr = greenlet.getcurrent()
        def _on_read(buf):
            # Only resume the reader while the transport is still open;
            # close()/_on_close set _stream to None.
            if self._stream:
                cur_gr.switch(buf)
        self._stream.read_bytes(sz, _on_read)
        # Suspend; _on_read resumes us and the switch returns ``buf``.
        buf = cur_gr.parent.switch()
        if len(buf) == 0:
            raise TTransportException(type=TTransportException.END_OF_FILE,
                                      message='TTornadoTransport read 0 bytes')
        # NOTE(review): _start_time is never set in this class; presumably
        # the _wrap_transport decorator sets it -- confirm.
        logging.debug("read %d bytes in %.2fms" %
                      (len(buf), (time.time() - self._start_time) * 1000))
        return buf

    @_wrap_transport
    def write(self, buf):
        """Write ``buf``, suspending this greenlet until the write completes."""
        logging.debug("writing %d bytes to %s:%d" % (len(buf), self.host, self.port))
        cur_gr = greenlet.getcurrent()
        def _on_write():
            # Resume only while the transport is still open.
            if self._stream:
                cur_gr.switch()
        self._stream.write(buf, _on_write)
        cur_gr.parent.switch()
        logging.debug("wrote %d bytes in %.2fms" %
                      (len(buf), (time.time() - self._start_time) * 1000))

    @_wrap_transport
    def flush(self):
        """No-op: write() already waits for the IOStream write to complete."""
        pass

    @_wrap_transport
    def _open_internal(self, sockaddr):
        """Connect the stream to ``sockaddr``, suspending until connected."""
        logging.debug("opening connection to %s:%d" % (self.host, self.port))
        cur_gr = greenlet.getcurrent()
        def _on_connect():
            if self._stream:
                cur_gr.switch()
        self._stream.connect(sockaddr, _on_connect)
        cur_gr.parent.switch()
        logging.info("opened connection to %s:%d" % (self.host, self.port))

    # NOTE(review): _check_stream, _set_timeout, _clear_timeout and _on_close
    # have no callers in this class; presumably _wrap_transport uses them --
    # verify against its definition.
    def _check_stream(self):
        """Raise TTransportException(NOT_OPEN) if there is no stream."""
        if not self._stream:
            raise TTransportException(type=TTransportException.NOT_OPEN,
                                      message='transport not open')

    def _set_timeout(self):
        """Schedule _on_timeout for the current greenlet.

        Returns the IOLoop timeout handle, or None when no timeout is set.
        """
        if self._timeout_secs:
            return self._io_loop.add_timeout(
                time.time() + self._timeout_secs,
                functools.partial(self._on_timeout, gr=greenlet.getcurrent()))
        return None

    def _clear_timeout(self, timeout):
        """Cancel a handle returned by _set_timeout (None is ignored)."""
        if timeout:
            self._io_loop.remove_timeout(timeout)

    def _on_timeout(self, gr):
        """Raise TTransportException(TIMED_OUT) inside the waiting greenlet."""
        gr.throw(
            TTransportException(type=TTransportException.TIMED_OUT,
                                message="connection timed out to %s:%d" % (self.host, self.port)))

    def _on_close(self, gr):
        """Mark the transport closed and report it.

        Throws TTransportException(NOT_OPEN) into ``gr`` when given;
        otherwise just logs the closure.
        """
        self._stream = None
        message = "connection to %s:%d closed" % (self.host, self.port)
        if gr:
            gr.throw(
                TTransportException(type=TTransportException.NOT_OPEN,
                                    message=message))
        else:
            logging.error(message)
class KeepAliveTest(AsyncHTTPTestCase):
    """Tests various scenarios for HTTP 1.1 keep-alive support.

    These tests don't use AsyncHTTPClient because we want to control
    connection reuse and closing.
    """

    def get_app(self):
        class HelloHandler(RequestHandler):
            def get(self):
                self.finish("Hello world")

            def post(self):
                self.finish("Hello world")

        class LargeHandler(RequestHandler):
            def get(self):
                # 512KB should be bigger than the socket buffers so it will
                # be written out in chunks.
                self.write("".join(chr(i % 256) * 1024 for i in range(512)))

        class FinishOnCloseHandler(RequestHandler):
            def initialize(self, cleanup_event):
                self.cleanup_event = cleanup_event

            @gen.coroutine
            def get(self):
                self.flush()
                yield self.cleanup_event.wait()

            def on_connection_close(self):
                # This is not very realistic, but finishing the request
                # from the close callback has the right timing to mimic
                # some errors seen in the wild.
                self.finish("closed")

        self.cleanup_event = Event()
        return Application([
            ("/", HelloHandler),
            ("/large", LargeHandler),
            (
                "/finish_on_close",
                FinishOnCloseHandler,
                dict(cleanup_event=self.cleanup_event),
            ),
        ])

    def setUp(self):
        super(KeepAliveTest, self).setUp()
        self.http_version = b"HTTP/1.1"

    def tearDown(self):
        # We just closed the client side of the socket; let the IOLoop run
        # once to make sure the server side got the message.
        self.io_loop.add_timeout(datetime.timedelta(seconds=0.001), self.stop)
        self.wait()

        if hasattr(self, "stream"):
            self.stream.close()
        super(KeepAliveTest, self).tearDown()

    # The next few methods are a crude manual http client
    @gen.coroutine
    def connect(self):
        """Open a raw IOStream to this test's HTTP server."""
        self.stream = IOStream(socket.socket())
        # The test server listens on the loopback interface; dialing any
        # other address would fail or hang.
        yield self.stream.connect(("127.0.0.1", self.get_http_port()))

    @gen.coroutine
    def read_headers(self):
        """Read a status line plus headers; return the parsed HTTPHeaders.

        The assertion expects an HTTP/1.1 200 status line regardless of the
        version used in the request.
        """
        first_line = yield self.stream.read_until(b"\r\n")
        self.assertTrue(first_line.startswith(b"HTTP/1.1 200"), first_line)
        header_bytes = yield self.stream.read_until(b"\r\n\r\n")
        headers = HTTPHeaders.parse(header_bytes.decode("latin1"))
        raise gen.Return(headers)

    @gen.coroutine
    def read_response(self):
        """Read one full response into self.headers and check its body."""
        self.headers = yield self.read_headers()
        body = yield self.stream.read_bytes(int(
            self.headers["Content-Length"]))
        self.assertEqual(b"Hello world", body)

    def close(self):
        """Close the client stream and forget it so tearDown skips it."""
        self.stream.close()
        del self.stream

    @gen_test
    def test_two_requests(self):
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.1\r\n\r\n")
        yield self.read_response()
        self.stream.write(b"GET / HTTP/1.1\r\n\r\n")
        yield self.read_response()
        self.close()

    @gen_test
    def test_request_close(self):
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.1\r\nConnection: close\r\n\r\n")
        yield self.read_response()
        data = yield self.stream.read_until_close()
        self.assertTrue(not data)
        self.assertEqual(self.headers["Connection"], "close")
        self.close()

    # keepalive is supported for http 1.0 too, but it's opt-in
    @gen_test
    def test_http10(self):
        self.http_version = b"HTTP/1.0"
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.0\r\n\r\n")
        yield self.read_response()
        data = yield self.stream.read_until_close()
        self.assertTrue(not data)
        self.assertTrue("Connection" not in self.headers)
        self.close()

    @gen_test
    def test_http10_keepalive(self):
        self.http_version = b"HTTP/1.0"
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.stream.write(b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.close()

    @gen_test
    def test_http10_keepalive_extra_crlf(self):
        self.http_version = b"HTTP/1.0"
        yield self.connect()
        self.stream.write(
            b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.stream.write(b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.close()

    @gen_test
    def test_pipelined_requests(self):
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\n")
        yield self.read_response()
        yield self.read_response()
        self.close()

    @gen_test
    def test_pipelined_cancel(self):
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\n")
        # only read once
        yield self.read_response()
        self.close()

    @gen_test
    def test_cancel_during_download(self):
        yield self.connect()
        self.stream.write(b"GET /large HTTP/1.1\r\n\r\n")
        yield self.read_headers()
        yield self.stream.read_bytes(1024)
        self.close()

    @gen_test
    def test_finish_while_closed(self):
        yield self.connect()
        self.stream.write(b"GET /finish_on_close HTTP/1.1\r\n\r\n")
        yield self.read_headers()
        self.close()
        # Let the hanging coroutine clean up after itself
        self.cleanup_event.set()

    @gen_test
    def test_keepalive_chunked(self):
        self.http_version = b"HTTP/1.0"
        yield self.connect()
        self.stream.write(b"POST / HTTP/1.0\r\n"
                          b"Connection: keep-alive\r\n"
                          b"Transfer-Encoding: chunked\r\n"
                          b"\r\n"
                          b"0\r\n"
                          b"\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.stream.write(b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.close()
def handle_connection(connection, address):
    """Wrap an accepted socket in an IOStream and start the uuid handshake.

    Reads the 4-byte uuid-length prefix and hands it to ``read_uuid_size``.
    If the peer has already gone away, the read raises StreamClosedError;
    that is logged as a warning instead of escaping into the IOLoop.

    :param connection: the accepted socket.
    :param address: the peer address (only used for logging).
    """
    log.info('Connection received from %s' % str(address))
    stream = IOStream(connection, ioloop)
    # Getting uuid
    try:
        stream.read_bytes(4, partial(read_uuid_size, stream))
    except StreamClosedError:
        log.warning('Closed stream for getting uuid length')
class _HTTPConnection(object): _SUPPORTED_METHODS = set(["GET", "HEAD", "POST", "PUT", "DELETE"]) def __init__(self, io_loop, client, request, release_callback, final_callback, max_buffer_size): self.start_time = time.time() self.io_loop = io_loop self.client = client self.request = request self.release_callback = release_callback self.final_callback = final_callback self.code = None self.headers = None self.chunks = None self._decompressor = None # Timeout handle returned by IOLoop.add_timeout self._timeout = None with stack_context.StackContext(self.cleanup): parsed = urllib.parse.urlsplit(_unicode(self.request.url)) if ssl is None and parsed.scheme == "https": raise ValueError("HTTPS requires either python2.6+ or " "curl_httpclient") if parsed.scheme not in ("http", "https"): raise ValueError("Unsupported url scheme: %s" % self.request.url) # urlsplit results have hostname and port results, but they # didn't support ipv6 literals until python 2.7. netloc = parsed.netloc if "@" in netloc: userpass, _, netloc = netloc.rpartition("@") match = re.match(r"^(.+):(\d+)$", netloc) if match: host = match.group(1) port = int(match.group(2)) else: host = netloc port = 443 if parsed.scheme == "https" else 80 if re.match(r"^\[.*\]$", host): # raw ipv6 addresses in urls are enclosed in brackets host = host[1:-1] if self.client.hostname_mapping is not None: host = self.client.hostname_mapping.get(host, host) if request.allow_ipv6: af = socket.AF_UNSPEC else: # We only try the first IP we get from getaddrinfo, # so restrict to ipv4 by default. 
af = socket.AF_INET addrinfo = socket.getaddrinfo(host, port, af, socket.SOCK_STREAM, 0, 0) af, socktype, proto, canonname, sockaddr = addrinfo[0] if parsed.scheme == "https": ssl_options = {} if request.validate_cert: ssl_options["cert_reqs"] = ssl.CERT_REQUIRED if request.ca_certs is not None: ssl_options["ca_certs"] = request.ca_certs else: ssl_options["ca_certs"] = _DEFAULT_CA_CERTS if request.client_key is not None: ssl_options["keyfile"] = request.client_key if request.client_cert is not None: ssl_options["certfile"] = request.client_cert # SSL interoperability is tricky. We want to disable # SSLv2 for security reasons; it wasn't disabled by default # until openssl 1.0. The best way to do this is to use # the SSL_OP_NO_SSLv2, but that wasn't exposed to python # until 3.2. Python 2.7 adds the ciphers argument, which # can also be used to disable SSLv2. As a last resort # on python 2.6, we set ssl_version to SSLv3. This is # more narrow than we'd like since it also breaks # compatibility with servers configured for TLSv1 only, # but nearly all servers support SSLv3: # http://blog.ivanristic.com/2011/09/ssl-survey-protocol-support.html if sys.version_info >= (2, 7): ssl_options["ciphers"] = "DEFAULT:!SSLv2" else: # This is really only necessary for pre-1.0 versions # of openssl, but python 2.6 doesn't expose version # information. 
ssl_options["ssl_version"] = ssl.PROTOCOL_SSLv3 self.stream = SSLIOStream( socket.socket(af, socktype, proto), io_loop=self.io_loop, ssl_options=ssl_options, max_buffer_size=max_buffer_size, ) else: self.stream = IOStream( socket.socket(af, socktype, proto), io_loop=self.io_loop, max_buffer_size=max_buffer_size ) timeout = min(request.connect_timeout, request.request_timeout) if timeout: self._timeout = self.io_loop.add_timeout(self.start_time + timeout, self._on_timeout) self.stream.set_close_callback(self._on_close) self.stream.connect(sockaddr, functools.partial(self._on_connect, parsed)) def _on_timeout(self): self._timeout = None self._run_callback( HTTPResponse(self.request, 599, request_time=time.time() - self.start_time, error=HTTPError(599, "Timeout")) ) self.stream.close() def _on_connect(self, parsed): if self._timeout is not None: self.io_loop.remove_timeout(self._timeout) self._timeout = None if self.request.request_timeout: self._timeout = self.io_loop.add_timeout(self.start_time + self.request.request_timeout, self._on_timeout) if self.request.validate_cert and isinstance(self.stream, SSLIOStream): match_hostname(self.stream.socket.getpeercert(), parsed.hostname) if self.request.method not in self._SUPPORTED_METHODS and not self.request.allow_nonstandard_methods: raise KeyError("unknown method %s" % self.request.method) for key in ("network_interface", "proxy_host", "proxy_port", "proxy_username", "proxy_password"): if getattr(self.request, key, None): raise NotImplementedError("%s not supported" % key) if "Host" not in self.request.headers: self.request.headers["Host"] = parsed.netloc username, password = None, None if parsed.username is not None: username, password = parsed.username, parsed.password elif self.request.auth_username is not None: username = self.request.auth_username password = self.request.auth_password or "" if username is not None: auth = utf8(username) + b(":") + utf8(password) self.request.headers["Authorization"] = b("Basic ") + 
base64.b64encode(auth) if self.request.user_agent: self.request.headers["User-Agent"] = self.request.user_agent if not self.request.allow_nonstandard_methods: if self.request.method in ("POST", "PUT"): assert self.request.body is not None else: assert self.request.body is None if self.request.body is not None: self.request.headers["Content-Length"] = str(len(self.request.body)) if self.request.method == "POST" and "Content-Type" not in self.request.headers: self.request.headers["Content-Type"] = "application/x-www-form-urlencoded" if self.request.use_gzip: self.request.headers["Accept-Encoding"] = "gzip" req_path = (parsed.path or "/") + (("?" + parsed.query) if parsed.query else "") request_lines = [utf8("%s %s HTTP/1.1" % (self.request.method, req_path))] for k, v in self.request.headers.get_all(): line = utf8(k) + b(": ") + utf8(v) if b("\n") in line: raise ValueError("Newline in header: " + repr(line)) request_lines.append(line) self.stream.write(b("\r\n").join(request_lines) + b("\r\n\r\n")) if self.request.body is not None: self.stream.write(self.request.body) self.stream.read_until_regex(b("\r?\n\r?\n"), self._on_headers) def _release(self): if self.release_callback is not None: release_callback = self.release_callback self.release_callback = None release_callback() def _run_callback(self, response): self._release() if self.final_callback is not None: final_callback = self.final_callback self.final_callback = None final_callback(response) @contextlib.contextmanager def cleanup(self): try: yield except Exception as e: logging.warning("uncaught exception", exc_info=True) self._run_callback(HTTPResponse(self.request, 599, error=e, request_time=time.time() - self.start_time)) def _on_close(self): self._run_callback( HTTPResponse( self.request, 599, request_time=time.time() - self.start_time, error=HTTPError(599, "Connection closed") ) ) def _on_headers(self, data): data = native_str(data.decode("latin1")) first_line, _, header_data = data.partition("\n") match = 
re.match("HTTP/1.[01] ([0-9]+)", first_line) assert match self.code = int(match.group(1)) self.headers = HTTPHeaders.parse(header_data) if "Content-Length" in self.headers: if "," in self.headers["Content-Length"]: # Proxies sometimes cause Content-Length headers to get # duplicated. If all the values are identical then we can # use them but if they differ it's an error. pieces = re.split(r",\s*", self.headers["Content-Length"]) if any(i != pieces[0] for i in pieces): raise ValueError("Multiple unequal Content-Lengths: %r" % self.headers["Content-Length"]) self.headers["Content-Length"] = pieces[0] content_length = int(self.headers["Content-Length"]) else: content_length = None if self.request.header_callback is not None: for k, v in self.headers.get_all(): self.request.header_callback("%s: %s\r\n" % (k, v)) if self.request.method == "HEAD": # HEAD requests never have content, even though they may have # content-length headers self._on_body(b("")) return if 100 <= self.code < 200 or self.code in (204, 304): # These response codes never have bodies # http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.3 assert "Transfer-Encoding" not in self.headers assert content_length in (None, 0) self._on_body(b("")) return if self.request.use_gzip and self.headers.get("Content-Encoding") == "gzip": # Magic parameter makes zlib module understand gzip header # http://stackoverflow.com/questions/1838699/how-can-i-decompress-a-gzip-stream-with-zlib self._decompressor = zlib.decompressobj(16 + zlib.MAX_WBITS) if self.headers.get("Transfer-Encoding") == "chunked": self.chunks = [] self.stream.read_until(b("\r\n"), self._on_chunk_length) elif content_length is not None: self.stream.read_bytes(content_length, self._on_body) else: self.stream.read_until_close(self._on_body) def _on_body(self, data): if self._timeout is not None: self.io_loop.remove_timeout(self._timeout) self._timeout = None original_request = getattr(self.request, "original_request", self.request) if 
self.request.follow_redirects and self.request.max_redirects > 0 and self.code in (301, 302, 303, 307): new_request = copy.copy(self.request) new_request.url = urllib.parse.urljoin(self.request.url, self.headers["Location"]) new_request.max_redirects -= 1 del new_request.headers["Host"] # http://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.3.4 # client SHOULD make a GET request if self.code == 303: new_request.method = "GET" new_request.body = None for h in ["Content-Length", "Content-Type", "Content-Encoding", "Transfer-Encoding"]: try: del self.request.headers[h] except KeyError: pass new_request.original_request = original_request final_callback = self.final_callback self.final_callback = None self._release() self.client.fetch(new_request, final_callback) self.stream.close() return if self._decompressor: data = self._decompressor.decompress(data) if self.request.streaming_callback: if self.chunks is None: # if chunks is not None, we already called streaming_callback # in _on_chunk_data self.request.streaming_callback(data) buffer = BytesIO() else: buffer = BytesIO(data) # TODO: don't require one big string? 
response = HTTPResponse( original_request, self.code, headers=self.headers, request_time=time.time() - self.start_time, buffer=buffer, effective_url=self.request.url, ) self._run_callback(response) self.stream.close() def _on_chunk_length(self, data): # TODO: "chunk extensions" http://tools.ietf.org/html/rfc2616#section-3.6.1 length = int(data.strip(), 16) if length == 0: # all the data has been decompressed, so we don't need to # decompress again in _on_body self._decompressor = None self._on_body(b("").join(self.chunks)) else: self.stream.read_bytes(length + 2, self._on_chunk_data) # chunk ends with \r\n def _on_chunk_data(self, data): assert data[-2:] == b("\r\n") chunk = data[:-2] if self._decompressor: chunk = self._decompressor.decompress(chunk) if self.request.streaming_callback is not None: self.request.streaming_callback(chunk) else: self.chunks.append(chunk) self.stream.read_until(b("\r\n"), self._on_chunk_length)
class Connection(object):
    """A single Tornado IOStream connection to a Redis server.

    Tracks outstanding read/write callbacks (``read_callbacks``) and work
    waiting for them to finish (``ready_callbacks``), and forwards
    connection events to an optional event-handler object.
    """

    def __init__(self, host='localhost', port=6379, unix_socket_path=None,
                 event_handler_proxy=None, stop_after=None, io_loop=None):
        self.host = host
        self.port = port
        self.unix_socket_path = unix_socket_path
        self._event_handler = event_handler_proxy
        # Used as the socket timeout for the blocking connect in connect().
        self.timeout = stop_after
        self._stream = None
        self._io_loop = io_loop

        self.in_progress = False
        self.read_callbacks = set()
        self.ready_callbacks = deque()
        self._lock = 0
        self.info = {'db': 0, 'pass': None}

    def __del__(self):
        self.disconnect()

    def execute_pending_command(self):
        """Run one queued ready-callback once no reads are outstanding."""
        # Continue with the pending command execution
        # if all read operations are completed.
        if not self.read_callbacks and self.ready_callbacks:
            # Pop a SINGLE callback from the queue and execute it.
            # The next one will be executed from the code
            # invoked by the callback
            callback = self.ready_callbacks.popleft()
            callback()

    def ready(self):
        """Return True when nothing is pending or queued."""
        return (not self.read_callbacks and not self.ready_callbacks)

    def wait_until_ready(self, callback=None):
        """Run ``callback`` now if ready, otherwise queue it."""
        if callback:
            if not self.ready():
                callback = stack_context.wrap(callback)
                self.ready_callbacks.append(callback)
            else:
                callback()

    def connect(self):
        """Open the connection (blocking) if not already connected.

        Raises ConnectionError when the socket cannot connect.
        """
        if not self._stream:
            try:
                if self.unix_socket_path:
                    sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
                    sock.settimeout(self.timeout)
                    sock.connect(self.unix_socket_path)
                else:
                    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
                    # IPPROTO_TCP is the portable spelling of SOL_TCP.
                    sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
                    sock.settimeout(self.timeout)
                    sock.connect((self.host, self.port))
                self._stream = IOStream(sock, io_loop=self._io_loop)
                self._stream.set_close_callback(self.on_stream_close)
                # A fresh connection starts on db 0 without auth.
                self.info['db'] = 0
                self.info['pass'] = None
            except socket.error as e:
                raise ConnectionError(str(e))
            self.fire_event('on_connect')

    def on_stream_close(self):
        """Tear down and flush every pending read callback (with no args)."""
        if self._stream:
            self.disconnect()
            callbacks = self.read_callbacks
            self.read_callbacks = set()
            for callback in callbacks:
                callback()

    def disconnect(self):
        """Best-effort shutdown and close of the stream; never raises."""
        if self._stream:
            s = self._stream
            self._stream = None
            try:
                if s.socket:
                    s.socket.shutdown(socket.SHUT_RDWR)
                s.close()
            except Exception:
                # Best effort: the socket may already be gone (this also
                # runs from __del__).
                pass

    def fire_event(self, event):
        """Call ``event`` on the event handler if it defines that method.

        A missing method is ignored, but errors raised by the handler itself
        propagate instead of being silently swallowed.
        """
        event_handler = self._event_handler
        if event_handler:
            handler = getattr(event_handler, event, None)
            if handler is not None:
                handler()

    def write(self, data, callback=None):
        """Write ``data``; ``callback(None)`` runs once the write completes.

        Raises ConnectionError if not connected or if the write fails (the
        connection is dropped first in the latter case).
        """
        if not self._stream:
            raise ConnectionError('Tried to write to '
                                  'non-existent connection')

        if callback:
            callback = stack_context.wrap(callback)
            _callback = lambda: callback(None)
            # Registered as pending so ready() stays False until the write
            # has completed.
            self.read_callbacks.add(_callback)
            cb = partial(self.read_callback, _callback)
        else:
            cb = None
        try:
            if PY3:
                data = bytes(data, encoding='utf-8')
            self._stream.write(data, callback=cb)
        except IOError as e:
            self.disconnect()
            # str(e) works on Python 2 and 3; ``e.message`` does not exist
            # on Python 3.
            raise ConnectionError(str(e))

    def read(self, length, callback=None):
        """Read ``length`` bytes and pass them to ``callback``."""
        # NOTE(review): if ConnectionError is the Python 3 builtin (an
        # IOError subclass), the raise below is caught by the except clause.
        try:
            if not self._stream:
                self.disconnect()
                raise ConnectionError('Tried to read from '
                                      'non-existent connection')
            callback = stack_context.wrap(callback)
            self.read_callbacks.add(callback)
            self._stream.read_bytes(length,
                                    callback=partial(self.read_callback,
                                                     callback))
        except IOError:
            self.fire_event('on_disconnect')

    def read_callback(self, callback, *args, **kwargs):
        """Unregister ``callback`` from the pending set, then invoke it."""
        try:
            self.read_callbacks.remove(callback)
        except KeyError:
            pass
        callback(*args, **kwargs)

    def readline(self, callback=None):
        """Read up to and including the next CRLF and pass it to ``callback``."""
        try:
            if not self._stream:
                self.disconnect()
                raise ConnectionError('Tried to read from '
                                      'non-existent connection')
            callback = stack_context.wrap(callback)
            self.read_callbacks.add(callback)
            callback = partial(self.read_callback, callback)
            self._stream.read_until(CRLF, callback=callback)
        except IOError:
            self.fire_event('on_disconnect')

    def connected(self):
        """Return True while a stream is attached."""
        if self._stream:
            return True
        return False
class AsyncRedisClient(object):
    """A non-blocking Redis client.

    Example usage::

        import ioloop

        def handle_request(result):
            print 'Redis reply: %r' % result
            ioloop.IOLoop.instance().stop()

        redis_client = AsyncRedisClient(('127.0.0.1', 6379))
        redis_client.fetch(('set', 'foo', 'bar'), None)
        redis_client.fetch(('get', 'foo'), handle_request)
        ioloop.IOLoop.instance().start()

    This class implements a Redis client on top of Tornado's IOStreams.
    It does not currently implement all applicable parts of the Redis
    specification, but it does enough to work with major redis server APIs
    (mostly tested against the LIST/HASH/PUBSUB API so far).

    This class has not been tested extensively in production and
    should be considered somewhat experimental as of the release of
    tornado 1.2.  It is intended to become the default tornado
    AsyncRedisClient implementation.
    """

    def __init__(self, address, io_loop=None, socket_timeout=10):
        """Creates an AsyncRedisClient.

        ``address`` is the (host, port) tuple of the redis server that is
        passed to ``IOStream.connect``, e.g. ``('127.0.0.1', 6379)``.
        ``socket_timeout`` is applied to the socket via ``settimeout``
        before it is wrapped in the IOStream.
        """
        self.address = address
        self.io_loop = io_loop or IOLoop.instance()
        self._callback_queue = deque()
        self._callback = None
        self._read_buffer = None
        self._result_queue = deque()
        self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.socket.settimeout(socket_timeout)
        self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.stream = IOStream(self.socket, self.io_loop)
        # Start reading replies as soon as the connection is established.
        self.stream.connect(self.address, self._wait_result)

    def close(self):
        """Destroys this redis client, freeing any file descriptors used.

        Not needed in normal use, but may be helpful in unittests that
        create and destroy redis clients.  No other methods may be called
        on the AsyncRedisClient after close().
        """
        self.stream.close()

    def fetch(self, request, callback):
        """Sends ``request`` and queues ``callback`` for its reply.

        The request should be a tuple of strings, like
        ``('set', 'foo', 'bar')``.  ``callback`` is called with
        ``decode(...)`` of the raw reply lines; it may be None to ignore the
        reply.  If decoding or the callback raises, the error is logged, the
        client is closed and the exception is re-raised (``decode``
        presumably raises ``RedisError`` for error replies -- confirm).
        """
        self._callback_queue.append(callback)
        self.stream.write(encode(request))

    def _wait_result(self):
        """Start reading the first line of the next reply."""
        self._read_buffer = deque()
        self.stream.read_until('\r\n', self._on_read_first_line)

    def _maybe_callback(self):
        """Deliver the reply collected in ``_read_buffer`` to its callback.

        When no callback is queued, the previously used callback is reused
        (this lets PUBSUB messages reach the subscriber's callback).
        """
        try:
            read_buffer = self._read_buffer
            callback = self._callback
            result_queue = self._result_queue
            callback_queue = self._callback_queue
            if result_queue:
                result_queue.append(read_buffer)
                read_buffer = result_queue.popleft()
            if callback_queue:
                callback = self._callback = callback_queue.popleft()
            if callback:
                callback(decode(read_buffer))
        except Exception:
            logging.error('Uncaught callback exception', exc_info=True)
            self.close()
            raise
        finally:
            # Re-arming a read on a closed stream raises, which would replace
            # the exception being propagated above (or break a callback that
            # closed the client on purpose).
            if not self.stream.closed():
                self._wait_result()

    def _on_read_first_line(self, data):
        """Dispatch on the RESP type byte of a reply's first line."""
        self._read_buffer.append(data)
        c = data[0]
        if c in ':+-':
            self._maybe_callback()
        elif c == '$':
            if data[:3] == '$-1':
                self._maybe_callback()
            else:
                length = int(data[1:])
                # +2 for the trailing CRLF of the bulk payload.
                self.stream.read_bytes(length + 2, self._on_read_bulk_body)
        elif c == '*':
            if data[1] in '-0':
                # Empty ('*0') or null ('*-1') array: nothing more to read.
                self._maybe_callback()
            else:
                self._multibulk_number = int(data[1:])
                self.stream.read_until('\r\n',
                                       self._on_read_multibulk_bulk_head)

    def _on_read_bulk_body(self, data):
        """Store a top-level bulk payload and deliver the reply."""
        self._read_buffer.append(data)
        self._maybe_callback()

    def _on_read_multibulk_bulk_head(self, data):
        """Handle the header line of one element of a multibulk reply."""
        self._read_buffer.append(data)
        c = data[0]
        if c == '$':
            length = int(data[1:])
            if length >= 0:
                self.stream.read_bytes(length + 2,
                                       self._on_read_multibulk_bulk_body)
            else:
                # '$-1' is a null element: there is no payload line to read.
                self._on_multibulk_element_done()
        elif c in ':+-':
            # Single-line elements (integer/status/error) are complete.
            self._on_multibulk_element_done()
        else:
            # Nested arrays are not tracked; deliver what we have as before.
            self._maybe_callback()

    def _on_read_multibulk_bulk_body(self, data):
        """Store one multibulk payload and move on to the next element."""
        self._read_buffer.append(data)
        self._on_multibulk_element_done()

    def _on_multibulk_element_done(self):
        """Count one finished array element; read the next or deliver."""
        self._multibulk_number -= 1
        if self._multibulk_number:
            self.stream.read_until('\r\n', self._on_read_multibulk_bulk_head)
        else:
            self._maybe_callback()
class AsyncConn(event.EventedMixin):
    """
    Low level object representing a TCP connection to nsqd.

    When a message on this connection is requeued and the requeue delay
    has not been specified, it calculates the delay automatically by an
    increasing multiple of ``requeue_delay``.

    Generates the following events that can be listened to with
    :meth:`nsq.AsyncConn.on`:

     * ``connect``
     * ``close``
     * ``error``
     * ``identify``
     * ``identify_response``
     * ``auth``
     * ``auth_response``
     * ``heartbeat``
     * ``ready``
     * ``message``
     * ``response``
     * ``backoff``
     * ``resume``

    :param host: the host to connect to

    :param port: the port to connect to

    :param timeout: the timeout for read/write operations (in seconds)

    :param heartbeat_interval: the amount of time (in seconds) to negotiate
        with the connected producers to send heartbeats (requires nsqd 0.2.19+)

    :param requeue_delay: the base multiple used when calculating requeue delay
        (multiplied by # of attempts)

    :param tls_v1: enable TLS v1 encryption (requires nsqd 0.2.22+)

    :param tls_options: dictionary of options to pass to `ssl.wrap_socket()
        <http://docs.python.org/2/library/ssl.html#ssl.wrap_socket>`_ as
        ``**kwargs``; merged over defaults of ``cert_reqs=CERT_REQUIRED``
        and ``ssl_version=PROTOCOL_TLSv1_2``

    :param snappy: enable Snappy stream compression (requires nsqd 0.2.23+)

    :param deflate: enable deflate stream compression (requires nsqd 0.2.23+)

    :param deflate_level: configure the deflate compression level for this
        connection (requires nsqd 0.2.23+)

    :param output_buffer_size: size of the buffer (in bytes) used by nsqd
        for buffering writes to this connection

    :param output_buffer_timeout: timeout (in ms) used by nsqd before
        flushing buffered writes (set to 0 to disable).  **Warning**:
        configuring clients with an extremely low
        (``< 25ms``) ``output_buffer_timeout`` has a significant effect
        on ``nsqd`` CPU usage (particularly with ``> 50`` clients connected).

    :param sample_rate: take only a sample of the messages being sent
        to the client. Not setting this or setting it to 0 will ensure
        you get all the messages destined for the client.
        Sample rate must be between 0 and 99 (inclusive) and the client
        will receive that percentage of the message traffic.
        (requires nsqd 0.2.25+)

    :param user_agent: a string identifying the agent for this client
        in the spirit of HTTP (default: ``<client_library_name>/<version>``)
        (requires nsqd 0.2.25+)

    :param auth_secret: a string passed when using nsq auth
        (requires nsqd 1.0+)

    :param msg_timeout: the amount of time (in seconds) that nsqd will wait
        before considering messages that have been delivered to this
        consumer timed out (requires nsqd 0.2.28+)

    :param hostname: a string identifying the host where this client runs
        (default: ``<hostname>``)
    """
    def __init__(
            self,
            host,
            port,
            timeout=1.0,
            heartbeat_interval=30,
            requeue_delay=90,
            tls_v1=False,
            tls_options=None,
            snappy=False,
            deflate=False,
            deflate_level=6,
            user_agent=DEFAULT_USER_AGENT,
            output_buffer_size=16 * 1024,
            output_buffer_timeout=250,
            sample_rate=0,
            auth_secret=None,
            msg_timeout=None,
            hostname=None):
        assert isinstance(host, string_types)
        assert isinstance(port, int)
        assert isinstance(timeout, float)
        assert isinstance(tls_options, (dict, None.__class__))
        assert isinstance(deflate_level, int)
        assert isinstance(heartbeat_interval, int) and heartbeat_interval >= 1
        assert isinstance(requeue_delay, int) and requeue_delay >= 0
        assert isinstance(output_buffer_size, int) and output_buffer_size >= 0
        assert isinstance(output_buffer_timeout, int) and output_buffer_timeout >= 0
        assert isinstance(sample_rate, int) and sample_rate >= 0 and sample_rate < 100
        assert msg_timeout is None or (isinstance(msg_timeout, (float, int)) and msg_timeout > 0)
        # auth_secret validated by to_bytes() below

        self.state = INIT
        self.host = host
        self.port = port
        self.timeout = timeout
        self.last_recv_timestamp = time.time()
        self.last_msg_timestamp = time.time()
        self.in_flight = 0
        self.rdy = 0
        self.rdy_timeout = None
        # for backwards compatibility when interacting with older nsqd
        # (pre 0.2.20), default this to their hard-coded max
        self.max_rdy_count = 2500
        self.tls_v1 = tls_v1
        self.tls_options = tls_options
        self.snappy = snappy
        self.deflate = deflate
        self.deflate_level = deflate_level
        self.hostname = hostname
        if self.hostname is None:
            self.hostname = socket.gethostname()
        self.short_hostname = self.hostname.split('.')[0]
        # nsqd expects both of these in milliseconds
        self.heartbeat_interval = heartbeat_interval * 1000
        self.msg_timeout = int(msg_timeout * 1000) if msg_timeout else None
        self.requeue_delay = requeue_delay

        self.output_buffer_size = output_buffer_size
        self.output_buffer_timeout = output_buffer_timeout
        self.sample_rate = sample_rate
        self.user_agent = user_agent

        self._authentication_required = False  # tracking server auth state
        self.auth_secret = to_bytes(auth_secret) if auth_secret else None

        self.socket = None
        self.stream = None
        self._features_to_enable = []

        self.last_rdy = 0

        self.callback_queue = []
        self.encoder = DefaultEncoder()

        super(AsyncConn, self).__init__()

    @property
    def id(self):
        return str(self)

    def __str__(self):
        return self.host + ':' + str(self.port)

    def connected(self):
        return self.state == CONNECTED

    def connecting(self):
        return self.state == CONNECTING

    def closed(self):
        return self.state in (INIT, DISCONNECTED)

    def connect(self):
        """Start a non-blocking connect to nsqd (no-op unless closed)."""
        if not self.closed():
            return

        # Assume host is an ipv6 address if it has a colon.
        if ':' in self.host:
            self.socket = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
        else:
            self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

        # NOTE(review): setblocking(0) is equivalent to settimeout(0.0), so
        # the timeout set here does not survive the next call.
        self.socket.settimeout(self.timeout)
        self.socket.setblocking(0)

        self.stream = IOStream(self.socket)
        self.stream.set_close_callback(self._socket_close)
        self.stream.set_nodelay(True)

        self.state = CONNECTING
        self.on(event.CONNECT, self._on_connect)
        self.on(event.DATA, self._on_data)

        fut = self.stream.connect((self.host, self.port))
        IOLoop.current().add_future(fut, self._connect_callback)

    def _connect_callback(self, fut):
        fut.result()
        self.state = CONNECTED
        self.stream.write(protocol.MAGIC_V2)
        self._start_read()
        self.trigger(event.CONNECT, conn=self)

    def _read_bytes(self, size, callback):
        """Read exactly ``size`` bytes, then call ``callback`` with the future."""
        try:
            fut = self.stream.read_bytes(size)
            IOLoop.current().add_future(fut, callback)
        except IOError:
            self.close()
            self.trigger(
                event.ERROR,
                conn=self,
                error=protocol.ConnectionClosedError('Stream is closed'),
            )

    def _start_read(self):
        if self.stream is None:
            return  # IOStream.start_tls() invalidates stream, will call again when ready
        # Every frame is prefixed with a 4-byte size (unpacked in _read_size).
        self._read_bytes(4, self._read_size)

    def _socket_close(self):
        self.state = DISCONNECTED
        self.trigger(event.CLOSE, conn=self)

    def close(self):
        # self.stream is None while a TLS upgrade is in progress (see
        # upgrade_to_tls); error paths call close() and must not crash then.
        if self.stream is not None:
            self.stream.close()

    def _read_size(self, fut):
        try:
            data = fut.result()
            size = struct_l.unpack(data)[0]
        except Exception:
            self.close()
            self.trigger(
                event.ERROR,
                conn=self,
                error=protocol.IntegrityError('failed to unpack size'),
            )
            return
        self._read_bytes(size, self._read_body)

    def _read_body(self, fut):
        try:
            data = fut.result()
            self.trigger(event.DATA, conn=self, data=data)
        except Exception:
            logger.exception('uncaught exception in data event')
        self._start_read()

    def send(self, data):
        return self.stream.write(self.encoder.encode(data))

    def upgrade_to_tls(self, options=None):
        # in order to upgrade to TLS we need to *replace* the IOStream...
        opts = {
            'cert_reqs': ssl.CERT_REQUIRED,
            'ssl_version': ssl.PROTOCOL_TLSv1_2
        }
        opts.update(options or {})

        fut = self.stream.start_tls(False, ssl_options=opts, server_hostname=self.host)
        self.stream = None

        def finish_upgrade_tls(fut):
            try:
                self.stream = fut.result()
                self.socket = self.stream.socket
                self._start_read()
            except Exception as e:
                # skip self.close() because no stream
                self.trigger(
                    event.ERROR,
                    conn=self,
                    error=protocol.SendError('failed to upgrade to TLS', e),
                )

        IOLoop.current().add_future(fut, finish_upgrade_tls)

    def upgrade_to_snappy(self):
        assert SnappySocket, 'snappy requires the python-snappy package'

        # in order to upgrade to Snappy we need to use whatever IOStream
        # is currently in place (normal or SSL)...
        #
        # first read any compressed bytes the existing IOStream might have
        # already buffered and use that to bootstrap the SnappySocket, then
        # monkey patch the existing IOStream by replacing its socket
        # with a wrapper that will automagically handle compression.
        existing_data = self.stream._consume(self.stream._read_buffer_size)
        self.socket = SnappySocket(self.socket)
        self.socket.bootstrap(existing_data)
        self.stream.socket = self.socket
        self.encoder = SnappyEncoder()

    def upgrade_to_deflate(self):
        # in order to upgrade to DEFLATE we need to use whatever IOStream
        # is currently in place (normal or SSL)...
        #
        # first read any compressed bytes the existing IOStream might have
        # already buffered and use that to bootstrap the DeflateSocket, then
        # monkey patch the existing IOStream by replacing its socket
        # with a wrapper that will automagically handle compression.
        existing_data = self.stream._consume(self.stream._read_buffer_size)
        self.socket = DeflateSocket(self.socket, self.deflate_level)
        self.socket.bootstrap(existing_data)
        self.stream.socket = self.socket
        self.encoder = DeflateEncoder(level=self.deflate_level)

    def send_rdy(self, value):
        """Send RDY ``value`` if it differs from the last one; False on error."""
        if self.last_rdy != value:
            try:
                self.send(protocol.ready(value))
            except Exception as e:
                self.close()
                self.trigger(
                    event.ERROR,
                    conn=self,
                    error=protocol.SendError('failed to send RDY %d' % value, e),
                )
                return False
        self.last_rdy = value
        self.rdy = value
        return True

    def _on_connect(self, **kwargs):
        identify_data = {
            'short_id': self.short_hostname,  # TODO remove when deprecating pre 1.0 support
            'long_id': self.hostname,  # TODO remove when deprecating pre 1.0 support
            'client_id': self.short_hostname,
            'hostname': self.hostname,
            'heartbeat_interval': self.heartbeat_interval,
            'feature_negotiation': True,
            'tls_v1': self.tls_v1,
            'snappy': self.snappy,
            'deflate': self.deflate,
            'deflate_level': self.deflate_level,
            'output_buffer_timeout': self.output_buffer_timeout,
            'output_buffer_size': self.output_buffer_size,
            'sample_rate': self.sample_rate,
            'user_agent': self.user_agent
        }
        if self.msg_timeout:
            identify_data['msg_timeout'] = self.msg_timeout
        self.trigger(event.IDENTIFY, conn=self, data=identify_data)
        self.on(event.RESPONSE, self._on_identify_response)
        try:
            self.send(protocol.identify(identify_data))
        except Exception as e:
            self.close()
            self.trigger(
                event.ERROR,
                conn=self,
                error=protocol.SendError('failed to bootstrap connection', e),
            )

    def _on_identify_response(self, data, **kwargs):
        self.off(event.RESPONSE, self._on_identify_response)

        if data == b'OK':
            logger.warning('nsqd version does not support feature netgotiation')
            return self.trigger(event.READY, conn=self)

        try:
            data = json.loads(data.decode('utf-8'))
        except ValueError:
            self.close()
            self.trigger(
                event.ERROR,
                conn=self,
                error=protocol.IntegrityError(
                    'failed to parse IDENTIFY response JSON from nsqd - %r' %
                    data),
            )
            return

        self.trigger(event.IDENTIFY_RESPONSE, conn=self, data=data)

        # Queue only features both requested locally and accepted by nsqd.
        if self.tls_v1 and data.get('tls_v1'):
            self._features_to_enable.append('tls_v1')
        if self.snappy and data.get('snappy'):
            self._features_to_enable.append('snappy')
        if self.deflate and data.get('deflate'):
            self._features_to_enable.append('deflate')

        if data.get('auth_required'):
            self._authentication_required = True

        if data.get('max_rdy_count'):
            self.max_rdy_count = data.get('max_rdy_count')
        else:
            # for backwards compatibility when interacting with older nsqd
            # (pre 0.2.20), default this to their hard-coded max
            logger.warning('setting max_rdy_count to default value of 2500')
            self.max_rdy_count = 2500

        self.on(event.RESPONSE, self._on_response_continue)
        self._on_response_continue(conn=self, data=None)

    def _on_response_continue(self, data, **kwargs):
        """Enable queued features one per response, then AUTH or signal READY."""
        if self._features_to_enable:
            feature = self._features_to_enable.pop(0)
            if feature == 'tls_v1':
                self.upgrade_to_tls(self.tls_options)
            elif feature == 'snappy':
                self.upgrade_to_snappy()
            elif feature == 'deflate':
                self.upgrade_to_deflate()
            # the server will 'OK' after these connection upgrades triggering another response
            return

        self.off(event.RESPONSE, self._on_response_continue)
        if self.auth_secret and self._authentication_required:
            self.on(event.RESPONSE, self._on_auth_response)
            self.trigger(event.AUTH, conn=self, data=self.auth_secret)
            try:
                self.send(protocol.auth(self.auth_secret))
            except Exception as e:
                self.close()
                self.trigger(
                    event.ERROR,
                    conn=self,
                    error=protocol.SendError('Error sending AUTH', e),
                )
            return
        self.trigger(event.READY, conn=self)

    def _on_auth_response(self, data, **kwargs):
        try:
            data = json.loads(data.decode('utf-8'))
        except ValueError:
            self.close()
            self.trigger(
                event.ERROR,
                conn=self,
                error=protocol.IntegrityError(
                    'failed to parse AUTH response JSON from nsqd - %r' % data),
            )
            return

        self.off(event.RESPONSE, self._on_auth_response)
        self.trigger(event.AUTH_RESPONSE, conn=self, data=data)
        return self.trigger(event.READY, conn=self)

    def _on_data(self, data, **kwargs):
        """Decode one frame and dispatch it as a message/heartbeat/response/error."""
        self.last_recv_timestamp = time.time()
        frame, data = protocol.unpack_response(data)
        if frame == protocol.FRAME_TYPE_MESSAGE:
            self.last_msg_timestamp = time.time()
            self.in_flight += 1

            message = protocol.decode_message(data)
            message.on(event.FINISH, self._on_message_finish)
            message.on(event.REQUEUE, self._on_message_requeue)
            message.on(event.TOUCH, self._on_message_touch)

            self.trigger(event.MESSAGE, conn=self, message=message)
        elif frame == protocol.FRAME_TYPE_RESPONSE and data == b'_heartbeat_':
            self.send(protocol.nop())
            self.trigger(event.HEARTBEAT, conn=self)
        elif frame == protocol.FRAME_TYPE_RESPONSE:
            self.trigger(event.RESPONSE, conn=self, data=data)
        elif frame == protocol.FRAME_TYPE_ERROR:
            self.trigger(event.ERROR, conn=self, error=protocol.Error(data))

    def _on_message_requeue(self, message, backoff=True, time_ms=-1, **kwargs):
        if backoff:
            self.trigger(event.BACKOFF, conn=self)
        else:
            self.trigger(event.CONTINUE, conn=self)

        self.in_flight -= 1
        try:
            # A negative time_ms means "not specified": scale by attempts.
            time_ms = self.requeue_delay * message.attempts * 1000 if time_ms < 0 else time_ms
            self.send(protocol.requeue(message.id, time_ms))
        except Exception as e:
            self.close()
            self.trigger(event.ERROR, conn=self, error=protocol.SendError(
                'failed to send REQ %s @ %d' % (message.id, time_ms), e))

    def _on_message_finish(self, message, **kwargs):
        self.trigger(event.RESUME, conn=self)

        self.in_flight -= 1
        try:
            self.send(protocol.finish(message.id))
        except Exception as e:
            self.close()
            self.trigger(
                event.ERROR,
                conn=self,
                error=protocol.SendError('failed to send FIN %s' % message.id, e),
            )

    def _on_message_touch(self, message, **kwargs):
        try:
            self.send(protocol.touch(message.id))
        except Exception as e:
            self.close()
            self.trigger(
                event.ERROR,
                conn=self,
                error=protocol.SendError(
                    'failed to send TOUCH %s' % message.id, e),
            )
class Connection(object):
    """A single memcached text-protocol connection driven by Tornado's IOStream.

    Each command is sent with ``send_command``.  The reply is parsed through
    a chain of IOStream callbacks, and the caller's callback is invoked once
    with the decoded value (or ``None``).  After a command completes, the
    connection hands itself back to ``pool`` via ``pool.release``.
    """

    def __init__(self, host='localhost', port=11211, pool=None):
        self._host = host
        self._port = port
        self._pool = pool
        self._socket = None
        self._stream = None
        # Callback of the in-flight command.  Defined up front so that a
        # close before the first command doesn't raise AttributeError.
        self._final_callback = None
        self._ioloop = IOLoop.instance()
        self.connect()

    def connect(self):
        """Open a blocking TCP connection and wrap it in an IOStream.

        :raises ConnectionError: if the socket cannot be created or connected.
        """
        sock = None
        try:
            sock = socket(AF_INET, SOCK_STREAM, 0)
            self._socket = sock
            sock.connect((self._host, self._port))
            self._stream = IOStream(sock, io_loop=self._ioloop)
            self._stream.set_close_callback(self.on_disconnect)
        except error as e:
            # Don't leak the descriptor of a socket that never connected.
            if sock is not None:
                sock.close()
            raise ConnectionError(e)

    def _fire_final_callback(self, value):
        """Clear the pending callback and, if there was one, call it with value."""
        callback, self._final_callback = self._final_callback, None
        if callback:
            callback(value)

    def _complete(self, value):
        """Deliver value to the pending callback, then release to the pool."""
        try:
            self._fire_final_callback(value)
        finally:
            self._pool.release(self)

    def disconect(self):
        """Close the stream, first firing any pending callback with None.

        The stream's close callback (on_disconnect) is detached before
        closing, so this path does not release the connection to the pool.
        The method name's spelling is kept for backward compatibility.
        """
        try:
            self._fire_final_callback(None)
        finally:
            self._stream._close_callback = None
            self._stream.close()

    def on_disconnect(self):
        """Close callback: fail any pending command and return to the pool."""
        try:
            self._fire_final_callback(None)
        finally:
            logging.debug('asyncmemcached closing connection')
            self._pool.release(self)

    def closed(self):
        return self._stream.closed()

    def send_command(self, fullcmd, expect_str, callback):
        """Write fullcmd and arrange for callback to receive the parsed reply.

        get/incr/decr replies are parsed as values; any other command's reply
        must equal expect_str.
        """
        self._final_callback = callback
        if self._stream.closed():
            self.connect()
        with stack_context.StackContext(self.cleanup):
            if fullcmd[0:3] == 'get' or \
                    fullcmd[0:4] == 'incr' or \
                    fullcmd[0:4] == 'decr':
                self._stream.write(fullcmd, self.read_value)
            else:
                self._stream.write(fullcmd,
                                   functools.partial(self.read_response, expect_str))

    def read_response(self, expect_str):
        self._stream.read_until('\r\n',
                                functools.partial(self._expect_callback, expect_str))

    def read_value(self):
        self._stream.read_until('\r\n', self._expect_value_header_callback)

    def _expect_value_header_callback(self, response):
        response = response[:-2]
        if response[:5] == 'VALUE':
            # "VALUE <key> <flags> <bytes>": only the byte count is needed.
            _resp, _key, _flag, length = response.split()
            length = int(length)
            # +2 for the CRLF that terminates the data block.
            self._stream.read_bytes(length + 2, self._expect_value_callback)
        elif response.isdigit():
            # Numeric reply (incr/decr result).
            self._complete(int(response))
        else:
            # Any other reply (presumably END for a missing key, or NOT_FOUND).
            self._complete(None)

    def _expect_value_callback(self, value):
        value = value[:-2]
        self._stream.read_until('\r\n',
                                functools.partial(self._end_value_callback, value))

    def _end_value_callback(self, value, response):
        response = response.rstrip('\r\n')
        if response == 'END':
            self._complete(value)
        else:
            raise RedisError('error %s' % response)

    def _expect_callback(self, expect_str, response):
        response = response.rstrip('\r\n')
        if response == expect_str:
            self._complete(None)
        else:
            raise RedisError('error %s' % response)

    @contextlib.contextmanager
    def cleanup(self):
        """Stack context: log escaped exceptions, fail the command, release."""
        try:
            yield
        except Exception:
            logging.warning("uncaught exception", exc_info=True)
            self._complete(None)
class KeepAliveTest(AsyncHTTPTestCase):
    """Exercise HTTP/1.1 (and opt-in HTTP/1.0) keep-alive handling.

    A raw IOStream acts as a hand-rolled client instead of AsyncHTTPClient,
    so each test decides exactly when a connection is reused or closed.
    """
    def get_app(self):
        class HelloHandler(RequestHandler):
            def get(self):
                self.finish('Hello world')

            def post(self):
                self.finish('Hello world')

        class LargeHandler(RequestHandler):
            def get(self):
                # 512 chunks of 1KB: big enough to exceed the socket
                # buffers, so the response goes out in several writes.
                chunks = [chr(i % 256) * 1024 for i in range(512)]
                self.write(''.join(chunks))

        class FinishOnCloseHandler(RequestHandler):
            @asynchronous
            def get(self):
                self.flush()

            def on_connection_close(self):
                # Finishing from the close callback is contrived, but its
                # timing reproduces errors observed in real deployments.
                self.finish('closed')

        routes = [
            ('/', HelloHandler),
            ('/large', LargeHandler),
            ('/finish_on_close', FinishOnCloseHandler),
        ]
        return Application(routes)

    def setUp(self):
        super(KeepAliveTest, self).setUp()
        self.http_version = b'HTTP/1.1'

    def tearDown(self):
        # Give the IOLoop one short turn so the server side notices the
        # client having closed its socket.
        self.io_loop.add_timeout(datetime.timedelta(seconds=0.001), self.stop)
        self.wait()
        if hasattr(self, 'stream'):
            self.stream.close()
        super(KeepAliveTest, self).tearDown()

    # --- minimal hand-written HTTP client ---------------------------------

    def connect(self):
        self.stream = IOStream(socket.socket())
        self.stream.connect(('127.0.0.1', self.get_http_port()), self.stop)
        self.wait()

    def _read_until_and_wait(self, delimiter):
        """Read through ``delimiter``, blocking the test until it arrives."""
        self.stream.read_until(delimiter, self.stop)
        return self.wait()

    def read_headers(self):
        status_line = self._read_until_and_wait(b'\r\n')
        self.assertTrue(status_line.startswith(b'HTTP/1.1 200'), status_line)
        raw_headers = self._read_until_and_wait(b'\r\n\r\n')
        return HTTPHeaders.parse(raw_headers.decode('latin1'))

    def read_response(self):
        self.headers = self.read_headers()
        content_length = int(self.headers['Content-Length'])
        self.stream.read_bytes(content_length, self.stop)
        self.assertEqual(b'Hello world', self.wait())

    def close(self):
        self.stream.close()
        del self.stream

    def _assert_keep_alive(self):
        self.assertEqual(self.headers['Connection'], 'Keep-Alive')

    def _assert_closed_by_server(self):
        """Wait for the server to close; no further bytes may arrive."""
        self.stream.read_until_close(callback=self.stop)
        self.assertTrue(not self.wait())

    # --- tests ----------------------------------------------------------------

    def test_two_requests(self):
        self.connect()
        for _ in range(2):
            self.stream.write(b'GET / HTTP/1.1\r\n\r\n')
            self.read_response()
        self.close()

    def test_request_close(self):
        self.connect()
        self.stream.write(b'GET / HTTP/1.1\r\nConnection: close\r\n\r\n')
        self.read_response()
        self._assert_closed_by_server()
        self.assertEqual(self.headers['Connection'], 'close')
        self.close()

    # keepalive is supported for http 1.0 too, but it's opt-in
    def test_http10(self):
        self.http_version = b'HTTP/1.0'
        self.connect()
        self.stream.write(b'GET / HTTP/1.0\r\n\r\n')
        self.read_response()
        self._assert_closed_by_server()
        self.assertTrue('Connection' not in self.headers)
        self.close()

    def test_http10_keepalive(self):
        self.http_version = b'HTTP/1.0'
        self.connect()
        for _ in range(2):
            self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n')
            self.read_response()
            self._assert_keep_alive()
        self.close()

    def test_http10_keepalive_extra_crlf(self):
        self.http_version = b'HTTP/1.0'
        self.connect()
        # An extra CRLF trails the first request; the second request on the
        # same connection should still be served.
        self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n\r\n')
        self.read_response()
        self._assert_keep_alive()
        self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n')
        self.read_response()
        self._assert_keep_alive()
        self.close()

    def test_pipelined_requests(self):
        self.connect()
        self.stream.write(b'GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\n')
        self.read_response()
        self.read_response()
        self.close()

    def test_pipelined_cancel(self):
        self.connect()
        self.stream.write(b'GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\n')
        # Read just the first response, then hang up.
        self.read_response()
        self.close()

    def test_cancel_during_download(self):
        self.connect()
        self.stream.write(b'GET /large HTTP/1.1\r\n\r\n')
        self.read_headers()
        # Consume only part of the large body before closing.
        self.stream.read_bytes(1024, self.stop)
        self.wait()
        self.close()

    def test_finish_while_closed(self):
        self.connect()
        self.stream.write(b'GET /finish_on_close HTTP/1.1\r\n\r\n')
        self.read_headers()
        self.close()

    def test_keepalive_chunked(self):
        self.http_version = b'HTTP/1.0'
        self.connect()
        self.stream.write(b'POST / HTTP/1.0\r\n'
                          b'Connection: keep-alive\r\n'
                          b'Transfer-Encoding: chunked\r\n'
                          b'\r\n'
                          b'0\r\n'
                          b'\r\n')
        self.read_response()
        self._assert_keep_alive()
        self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n')
        self.read_response()
        self._assert_keep_alive()
        self.close()
class TTornadoTransport(TTransportBase):
    """A non-blocking Thrift client.

    Example usage::

      import greenlet

      from tornado import ioloop
      from thrift.transport import TTransport
      from thrift.protocol import TBinaryProtocol

      from viewfinder.backend.thrift import TTornadoTransport

      transport = TTransport.TFramedTransport(TTornadoTransport('localhost', 9090))
      protocol = TBinaryProtocol.TBinaryProtocol(transport)
      client = Service.Client(protocol)
      ioloop.IOLoop.instance().start()

    Then, from within an asynchronous tornado request handler::

      class MyApp(tornado.web.RequestHandler):
        @tornado.web.asynchronous
        def post(self):
          def business_logic():
            ...any thrift calls...
            self.write(...stuff that gets returned to client...)
            self.finish() #end the asynchronous request
          gr = greenlet.greenlet(business_logic)
          gr.switch()

    Blocking-looking calls (open/read/write) start an IOStream operation and
    switch to the parent greenlet; the IOStream callback switches back with
    the result.
    """

    def __init__(self, host='localhost', port=9090):
        """Initialize a TTornadoTransport with a Tornado IOStream.

        @param host(str)  The host to connect to.
        @param port(int)  The (TCP) port to connect to.
        """
        self.host = host
        self.port = port
        self._stream = None
        self._io_loop = ioloop.IOLoop.current()
        self._timeout_secs = None

    def set_timeout(self, timeout_secs):
        """Sets a timeout for use with open/read/write operations."""
        self._timeout_secs = timeout_secs

    def isOpen(self):
        return self._stream is not None

    def open(self):
        """Creates a connection to host:port and spins up a tornado
        IOStream object to write requests and read responses from the
        thrift server. After making the asynchronous connect call to
        _stream, the current greenlet yields control back to the parent
        greenlet (presumably the "master" greenlet).
        """
        assert greenlet.getcurrent().parent is not None
        # TODO(spencer): allow ipv6? (af = socket.AF_UNSPEC)
        addrinfo = socket.getaddrinfo(self.host, self.port, socket.AF_INET,
                                      socket.SOCK_STREAM, 0, 0)
        af, socktype, proto, _canonname, sockaddr = addrinfo[0]
        self._stream = IOStream(socket.socket(af, socktype, proto),
                                io_loop=self._io_loop)
        self._open_internal(sockaddr)

    def close(self):
        if self._stream:
            # Detach the close callback so an explicit close isn't reported
            # as an unexpected disconnect.
            self._stream.set_close_callback(None)
            self._stream.close()
            self._stream = None

    @_wrap_transport
    def read(self, sz):
        logging.debug("reading %d bytes from %s:%d", sz, self.host, self.port)
        cur_gr = greenlet.getcurrent()

        def _on_read(buf):
            # Drop data that arrives after the transport was closed.
            if self._stream:
                cur_gr.switch(buf)

        self._stream.read_bytes(sz, _on_read)
        buf = cur_gr.parent.switch()
        if len(buf) == 0:
            raise TTransportException(type=TTransportException.END_OF_FILE,
                                      message='TTornadoTransport read 0 bytes')
        logging.debug("read %d bytes in %.2fms", len(buf),
                      (time.time() - self._start_time) * 1000)
        return buf

    @_wrap_transport
    def write(self, buf):
        logging.debug("writing %d bytes to %s:%d", len(buf), self.host, self.port)
        cur_gr = greenlet.getcurrent()

        def _on_write():
            if self._stream:
                cur_gr.switch()

        self._stream.write(buf, _on_write)
        cur_gr.parent.switch()
        logging.debug("wrote %d bytes in %.2fms", len(buf),
                      (time.time() - self._start_time) * 1000)

    @_wrap_transport
    def flush(self):
        pass

    @_wrap_transport
    def _open_internal(self, sockaddr):
        logging.debug("opening connection to %s:%d", self.host, self.port)
        cur_gr = greenlet.getcurrent()

        def _on_connect():
            if self._stream:
                cur_gr.switch()

        self._stream.connect(sockaddr, _on_connect)
        cur_gr.parent.switch()
        logging.info("opened connection to %s:%d", self.host, self.port)

    def _check_stream(self):
        if not self._stream:
            raise TTransportException(
                type=TTransportException.NOT_OPEN, message='transport not open')

    def _set_timeout(self):
        """Schedule _on_timeout for the current greenlet if a timeout is set.

        Returns the IOLoop timeout handle, or None when no timeout is set.
        """
        if self._timeout_secs:
            # add_timeout deadlines must be on the IOLoop's own clock, which
            # is not necessarily time.time().
            deadline = self._io_loop.time() + self._timeout_secs
            return self._io_loop.add_timeout(
                deadline,
                functools.partial(self._on_timeout, gr=greenlet.getcurrent()))
        return None

    def _clear_timeout(self, timeout):
        if timeout:
            self._io_loop.remove_timeout(timeout)

    def _on_timeout(self, gr):
        gr.throw(TTransportException(
            type=TTransportException.TIMED_OUT,
            message="connection timed out to %s:%d" % (self.host, self.port)))

    def _on_close(self, gr):
        """Mark the transport closed and fail the waiting greenlet, if any."""
        self._stream = None
        message = "connection to %s:%d closed" % (self.host, self.port)
        if gr:
            gr.throw(TTransportException(
                type=TTransportException.NOT_OPEN, message=message))
        else:
            logging.error(message)