Example #1
 def test_100_continue(self):
     # Run through a 100-continue interaction by hand:
     # When given Expect: 100-continue, we get a 100 response after the
     # headers, and then the real response after the body.
     stream = IOStream(socket.socket())
     stream.connect(("127.0.0.1", self.get_http_port()), callback=self.stop)
     self.wait()
     stream.write(b"\r\n".join([b"POST /hello HTTP/1.1",
                                b"Content-Length: 1024",
                                b"Expect: 100-continue",
                                b"Connection: close",
                                b"\r\n"]), callback=self.stop)
     self.wait()
     stream.read_until(b"\r\n\r\n", self.stop)
     data = self.wait()
     self.assertTrue(data.startswith(b"HTTP/1.1 100 "), data)
     stream.write(b"a" * 1024)
     stream.read_until(b"\r\n", self.stop)
     first_line = self.wait()
     self.assertTrue(first_line.startswith(b"HTTP/1.1 200"), first_line)
     stream.read_until(b"\r\n\r\n", self.stop)
     header_data = self.wait()
     headers = HTTPHeaders.parse(native_str(header_data.decode('latin1')))
     stream.read_bytes(int(headers["Content-Length"]), self.stop)
     body = self.wait()
     self.assertEqual(body, b"Got 1024 bytes in POST")
     stream.close()
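The test asserts a specific response body, so the server side is presumably a small POST handler. A hedged sketch of such a handler (EchoPostHandler is a hypothetical name; the actual handler in Tornado's test suite is not shown here):

from tornado.web import RequestHandler

class EchoPostHandler(RequestHandler):
    # Hypothetical handler producing the body asserted above.
    def post(self):
        self.write("Got %d bytes in POST" % len(self.request.body))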
Example #2
def handle_connection(connection, address):
    log.info('Connection received from %s' % str(address))
    stream = IOStream(connection, ioloop, max_buffer_size=1024 * 1024 * 1024)
    # Getting uuid
    try:
        stream.read_bytes(4, partial(read_uuid_size, stream))
    except StreamClosedError:
        log.warn('Closed stream for getting uuid length')
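The continuation read_uuid_size is defined elsewhere in wdb. A hedged sketch of what such a length-prefixed read chain can look like (the struct format and handle_uuid are assumptions, not wdb's actual code):

import struct
from functools import partial

def read_uuid_size(stream, data):
    # Assumed: the 4 bytes read above are a big-endian length prefix
    # for the uuid payload that follows.
    (length,) = struct.unpack("!i", data)
    stream.read_bytes(length, partial(handle_uuid, stream))

def handle_uuid(stream, data):
    log.info('uuid received: %s' % data.decode())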
Example #3
File: streams.py Project: apriljdai/wdb
def handle_connection(connection, address):
    log.info("Connection received from %s" % str(address))
    stream = IOStream(connection, ioloop, max_buffer_size=1024 * 1024 * 1024)
    # Getting uuid
    try:
        stream.read_bytes(4, partial(read_uuid_size, stream))
    except StreamClosedError:
        log.warn("Closed stream for getting uuid length")
Example #4
File: streams.py Project: clawplach/wdb
def handle_connection(connection, address):
    log.info('Connection received from %s' % str(address))
    stream = IOStream(connection, ioloop)
    # Getting uuid
    try:
        stream.read_bytes(4, partial(read_uuid_size, stream))
    except StreamClosedError:
        log.warn('Closed stream for getting uuid length')
Example #5
class Client(object):
    def __init__(self, host, port, protocol, name, session, dial=True):
        self._host = host
        self._port = port
        self._protocol = protocol
        self._name = name
        self._session = session
        self._running = False
        self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        self._connection = IOStream(self._socket)
        if dial:
            self._connection.connect((host, port), self.on_dial)

    def dial(self):
        if self._running:
            return
        self._connection.connect((self._host, self._port), self.on_dial)
        self.read_loop()

    def on_dial(self):
        """
        Subclasses should override this method, for example:
        def on_dial(self):
            self._connection.write(msg)
            self.read_loop()
        :return: None
        """
        self._running = True

    def read_loop(self):
        #self._connection.read_bytes(self._protocol.head_size(), callback=self._debug)
        self._connection.read_bytes(self._protocol.head_size(),
                                    callback=self._handle_head)

    def _handle_head(self, head):
        receive_bytes, body_size = self._protocol.handle_head(
            self._session, head)
        # if self._session.receive_bytes >= receive_bytes:
        #     # TODO it is an error
        #     pass
        #     #return
        self._session.receive_bytes = receive_bytes

        def handler(body):
            self._protocol.handle(self._session, body)
            self.read_loop()

        if body_size > 0:
            self._connection.read_bytes(body_size, callback=handler)
        else:
            # No body for this frame; resume reading the next head.
            self.read_loop()

    def send_raw(self, data):
        self._connection.write(self._protocol.encode(
            self._session, data))

    def send(self, msg):
        raw_data = self._protocol.encode(self._session, msg.encode())
        self._connection.write(raw_data)
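Per the on_dial docstring, subclasses supply the dial-time behaviour. A minimal usage sketch (EchoClient and the message are illustrative; the protocol and session objects follow the methods called above):

class EchoClient(Client):
    def on_dial(self):
        # Connected: mark running, send a greeting, start reading.
        self._running = True
        self.send("hello")
        self.read_loop()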
Example #6
class Flash(object):

    def __init__(self, close_callback=None):
        self._iostream = None
        self._close_callback = close_callback

    def connect(self, host='127.0.0.1', port=9999):
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self._iostream = IOStream(sock)
        self._iostream.set_close_callback(self._on_connection_close)

        # коннектимся и начинаем слушать команды
        self._iostream.connect((host, port), self._read_head)

    def close(self):
        self._on_connection_close()

    def _on_connection_close(self):
        self._iostream.close()
        if self._close_callback:
            self._close_callback()

    def _read_head(self):
        self._iostream.read_bytes(BaseCommand.meta_size, self._on_read_head)

    def _on_read_head(self, data):
        ctype, length = struct.unpack(">BH", data)

        if length:
            self._iostream.read_bytes(length, partial(self.execute_command, ctype))
        else:
            self.execute_command(ctype)

    def execute_command(self, ctype, value=None):
        command = CommandsRegistry.get_by_type(ctype)

        if command is not None:
            command.execute(value)
        # else:
        #     print 'unknown command: type={:#x}'.format(ctype)

        self._read_head()

    @classmethod
    def start(cls, host, port):
        flash = cls(close_callback=IOLoop.instance().stop)
        flash.connect(host, port)

        signal.signal(signal.SIGINT, flash.close)

        IOLoop.instance().start()
        IOLoop.instance().close()
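The unpack format ">BH" implies a 3-byte header (a one-byte command type followed by a two-byte big-endian payload length), so BaseCommand.meta_size is presumably 3. A hedged sketch of the matching sender-side framing:

import struct

def encode_command(ctype, payload=b""):
    # One-byte type, two-byte big-endian length, then the payload.
    return struct.pack(">BH", ctype, len(payload)) + payload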
Example #7
class UnixSocketTest(AsyncTestCase):
    """HTTPServers can listen on Unix sockets too.

    Why would you want to do this?  Nginx can proxy to backends listening
    on unix sockets, for one thing (and managing a namespace for unix
    sockets can be easier than managing a bunch of TCP port numbers).

    Unfortunately, there's no way to specify a unix socket in a url for
    an HTTP client, so we have to test this by hand.
    """
    def setUp(self):
        super(UnixSocketTest, self).setUp()
        self.tmpdir = tempfile.mkdtemp()
        self.sockfile = os.path.join(self.tmpdir, "test.sock")
        sock = netutil.bind_unix_socket(self.sockfile)
        app = Application([("/hello", HelloWorldRequestHandler)])
        self.server = HTTPServer(app)
        self.server.add_socket(sock)
        self.stream = IOStream(socket.socket(socket.AF_UNIX))
        self.stream.connect(self.sockfile, self.stop)
        self.wait()

    def tearDown(self):
        self.stream.close()
        self.io_loop.run_sync(self.server.close_all_connections)
        self.server.stop()
        shutil.rmtree(self.tmpdir)
        super(UnixSocketTest, self).tearDown()

    def test_unix_socket(self):
        self.stream.write(b"GET /hello HTTP/1.0\r\n\r\n")
        self.stream.read_until(b"\r\n", self.stop)
        response = self.wait()
        self.assertEqual(response, b"HTTP/1.1 200 OK\r\n")
        self.stream.read_until(b"\r\n\r\n", self.stop)
        headers = HTTPHeaders.parse(self.wait().decode('latin1'))
        self.stream.read_bytes(int(headers["Content-Length"]), self.stop)
        body = self.wait()
        self.assertEqual(body, b"Hello world")

    def test_unix_socket_bad_request(self):
        # Unix sockets don't have remote addresses so they just return an
        # empty string.
        with ExpectLog(gen_log, "Malformed HTTP message from"):
            self.stream.write(b"garbage\r\n\r\n")
            self.stream.read_until_close(self.stop)
            response = self.wait()
        self.assertEqual(response, b"HTTP/1.1 400 Bad Request\r\n\r\n")
Example #8
class UnixSocketTest(AsyncTestCase):
    """HTTPServers can listen on Unix sockets too.

    Why would you want to do this?  Nginx can proxy to backends listening
    on unix sockets, for one thing (and managing a namespace for unix
    sockets can be easier than managing a bunch of TCP port numbers).

    Unfortunately, there's no way to specify a unix socket in a url for
    an HTTP client, so we have to test this by hand.
    """
    def setUp(self):
        super(UnixSocketTest, self).setUp()
        self.tmpdir = tempfile.mkdtemp()
        self.sockfile = os.path.join(self.tmpdir, "test.sock")
        sock = netutil.bind_unix_socket(self.sockfile)
        app = Application([("/hello", HelloWorldRequestHandler)])
        self.server = HTTPServer(app, io_loop=self.io_loop)
        self.server.add_socket(sock)
        self.stream = IOStream(socket.socket(socket.AF_UNIX),
                               io_loop=self.io_loop)
        self.stream.connect(self.sockfile, self.stop)
        self.wait()

    def tearDown(self):
        self.stream.close()
        self.server.stop()
        shutil.rmtree(self.tmpdir)
        super(UnixSocketTest, self).tearDown()

    def test_unix_socket(self):
        self.stream.write(b"GET /hello HTTP/1.0\r\n\r\n")
        self.stream.read_until(b"\r\n", self.stop)
        response = self.wait()
        self.assertEqual(response, b"HTTP/1.0 200 OK\r\n")
        self.stream.read_until(b"\r\n\r\n", self.stop)
        headers = HTTPHeaders.parse(self.wait().decode('latin1'))
        self.stream.read_bytes(int(headers["Content-Length"]), self.stop)
        body = self.wait()
        self.assertEqual(body, b"Hello world")

    def test_unix_socket_bad_request(self):
        # Unix sockets don't have remote addresses so they just return an
        # empty string.
        with ExpectLog(gen_log, "Malformed HTTP message from"):
            self.stream.write(b"garbage\r\n\r\n")
            self.stream.read_until_close(self.stop)
            response = self.wait()
        self.assertEqual(response, b"")
Example #9
 @gen_test
 def test_100_continue(self):
     # Run through a 100-continue interaction by hand:
     # When given Expect: 100-continue, we get a 100 response after the
     # headers, and then the real response after the body.
     stream = IOStream(socket.socket())
     yield stream.connect(("127.0.0.1", self.get_http_port()))
     yield stream.write(
         b"\r\n".join(
             [
                 b"POST /hello HTTP/1.1",
                 b"Content-Length: 1024",
                 b"Expect: 100-continue",
                 b"Connection: close",
                 b"\r\n",
             ]
         )
     )
     data = yield stream.read_until(b"\r\n\r\n")
     self.assertTrue(data.startswith(b"HTTP/1.1 100 "), data)
     stream.write(b"a" * 1024)
     first_line = yield stream.read_until(b"\r\n")
     self.assertTrue(first_line.startswith(b"HTTP/1.1 200"), first_line)
     header_data = yield stream.read_until(b"\r\n\r\n")
     headers = HTTPHeaders.parse(native_str(header_data.decode("latin1")))
     body = yield stream.read_bytes(int(headers["Content-Length"]))
     self.assertEqual(body, b"Got 1024 bytes in POST")
     stream.close()
Example #10
    @gen_test
    def test_message_response(self):
        # handle_stream may be a coroutine and any exception in its
        # Future will be logged.
        server = client = None
        try:
            sock, port = bind_unused_port()
            sock2, port2 = bind_unused_port()

            with NullContext():
                server = StatusServer()

                notify_server = NotifyServer()
                notify_server.add_socket(sock2)

                server.notify_server = notify_server

                server.add_socket(sock)

            client = IOStream(socket.socket())
            yield client.connect(('localhost', port))
            yield client.write(msg1)
            results = yield client.read_bytes(4)
            assert results == b'\x11\x00\x01\x10'

        finally:
            if server is not None:
                server.stop()
            if client is not None:
                client.close()
Example #11
    async def handle_stream(self, stream: IOStream, address):
        print("connect from {0:s}:{1:d}".format(address[0], address[1]))
        loop = IOLoop.current()  #type: IOLoop
        frameBuffer = b''
        Q = Queue(maxsize=10)
        while True:
            try:
                if not stream.reading():
                    dataFuture = stream.read_bytes(
                        12, partial=True)  #type:futures.Future
                frameBuffer = frameBuffer + await gen.with_timeout(
                    timedelta(seconds=12), dataFuture)
                print("CurrentBuffer:", frameBuffer)
                if len(frameBuffer) < 24:
                    continue
                loop.run_in_executor(
                    None, partial(self.wrappedDecode, frameBuffer, Q))

                status = Q.get()
                frameBuffer = b''
                if status == self.DECODE_SUC:
                    await stream.write(bytes([0x3e]))
                else:
                    await stream.write(bytes([0x6c]))

            except StreamClosedError:
                print("connection closed from {0:s}:{1:d}".format(
                    address[0], address[1]))
                break

            except gen.TimeoutError:
                frameBuffer = b''
                print("No response in 3 seconds {0:s}:{1:d}".format(
                    address[0], address[1]))
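A handler like this is normally installed by overriding TCPServer.handle_stream. A brief wiring sketch (FrameServer and the port are illustrative):

from tornado.ioloop import IOLoop
from tornado.tcpserver import TCPServer

class FrameServer(TCPServer):
    async def handle_stream(self, stream, address):
        ...  # the frame-decoding loop shown above

server = FrameServer()
server.listen(8888)
IOLoop.current().start()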
Example #12
File: network.py Project: xlybaby/VAR
 @gen.coroutine
 def handle_connection(self, connection, address):
     stream = IOStream(connection)
     print("start handle request...")
     #message = yield stream.read_until_close()
     message = yield stream.read_bytes(20, partial=True)
     #print ("delimiter: ", chr(self._delimiter).encode())
     #stream.read_until(chr(self._delimiter).encode(), self.on_body)
     print("message from client:", message.decode().strip())
Example #13
 def test_unix_socket(self):
     sockfile = os.path.join(self.tmpdir, "test.sock")
     sock = netutil.bind_unix_socket(sockfile)
     app = Application([("/hello", HelloWorldRequestHandler)])
     server = HTTPServer(app, io_loop=self.io_loop)
     server.add_socket(sock)
     stream = IOStream(socket.socket(socket.AF_UNIX), io_loop=self.io_loop)
     stream.connect(sockfile, self.stop)
     self.wait()
     stream.write(b("GET /hello HTTP/1.0\r\n\r\n"))
     stream.read_until(b("\r\n"), self.stop)
     response = self.wait()
     self.assertEqual(response, b("HTTP/1.0 200 OK\r\n"))
     stream.read_until(b("\r\n\r\n"), self.stop)
     headers = HTTPHeaders.parse(self.wait().decode('latin1'))
     stream.read_bytes(int(headers["Content-Length"]), self.stop)
     body = self.wait()
     self.assertEqual(body, b("Hello world"))
Example #14
 def test_unix_socket(self):
     sockfile = os.path.join(self.tmpdir, "test.sock")
     sock = netutil.bind_unix_socket(sockfile)
     app = Application([("/hello", HelloWorldRequestHandler)])
     server = HTTPServer(app, io_loop=self.io_loop)
     server.add_socket(sock)
     stream = IOStream(socket.socket(socket.AF_UNIX), io_loop=self.io_loop)
     stream.connect(sockfile, self.stop)
     self.wait()
     stream.write(b("GET /hello HTTP/1.0\r\n\r\n"))
     stream.read_until(b("\r\n"), self.stop)
     response = self.wait()
     self.assertEqual(response, b("HTTP/1.0 200 OK\r\n"))
     stream.read_until(b("\r\n\r\n"), self.stop)
     headers = HTTPHeaders.parse(self.wait().decode('latin1'))
     stream.read_bytes(int(headers["Content-Length"]), self.stop)
     body = self.wait()
     self.assertEqual(body, b("Hello world"))
Example #15
class TestIOStream(AsyncHTTPTestCase, LogTrapTestCase):
    def get_app(self):
        return Application([('/', HelloHandler)])

    def test_read_zero_bytes(self):
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        s.connect(("localhost", self.get_http_port()))
        self.stream = IOStream(s, io_loop=self.io_loop)
        self.stream.write(b("GET / HTTP/1.0\r\n\r\n"))

        # normal read
        self.stream.read_bytes(9, self.stop)
        data = self.wait()
        self.assertEqual(data, b("HTTP/1.0 "))

        # zero bytes
        self.stream.read_bytes(0, self.stop)
        data = self.wait()
        self.assertEqual(data, b(""))

        # another normal read
        self.stream.read_bytes(3, self.stop)
        data = self.wait()
        self.assertEqual(data, b("200"))

    def test_connection_refused(self):
        # When a connection is refused, the connect callback should not
        # be run.  (The kqueue IOLoop used to behave differently from the
        # epoll IOLoop in this respect)
        port = get_unused_port()
        stream = IOStream(socket.socket(), self.io_loop)
        self.connect_called = False
        def connect_callback():
            self.connect_called = True
        stream.set_close_callback(self.stop)
        stream.connect(("localhost", port), connect_callback)
        self.wait()
        self.assertFalse(self.connect_called)

    def test_connection_closed(self):
        # When a server sends a response and then closes the connection,
        # the client must be allowed to read the data before the IOStream
        # closes itself.  Epoll reports closed connections with a separate
        # EPOLLRDHUP event delivered at the same time as the read event,
        # while kqueue reports them as a second read/write event with an EOF
        # flag.
        response = self.fetch("/", headers={"Connection": "close"})
        response.rethrow()

    def test_read_until_close(self):
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        s.connect(("localhost", self.get_http_port()))
        stream = IOStream(s, io_loop=self.io_loop)
        stream.write(b("GET / HTTP/1.0\r\n\r\n"))
        
        stream.read_until_close(self.stop)
        data = self.wait()
        self.assertTrue(data.startswith(b("HTTP/1.0 200")))
        self.assertTrue(data.endswith(b("Hello")))
Example #16
class ESME(DeliverMixin, BaseESME):
    def __init__(self, **kwargs):
        BaseESME.__init__(self, **kwargs)
        self.running = False
        self.closed = False

    @coroutine
    def connect(self, host, port):
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        self.ioloop = IOLoop.current()
        self.stream = IOStream(s)
        yield self.stream.connect((host, port))

    def on_send(self, data):
        return self.stream.write(data)

    def on_close(self):
        self.closed = True
        self.stream.close()

    @coroutine
    def readloop(self, future):
        while not self.closed and (not future or not future.done()):
            try:
                data = yield self.stream.read_bytes(1024, partial=True)
            except StreamClosedError:  # pragma: no cover
                break
            else:
                self.feed(data)

    def wait_for(self, response):
        future = Future()
        response.callback = lambda resp: future.set_result(resp.response)
        if self.running:
            return future
        else:
            return self.run(future)

    @coroutine
    def run(self, future=None):
        self.running = True
        try:
            yield self.readloop(future)
        finally:
            self.running = False

        if future and future.done():
            raise Return(future.result())
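A hedged sketch of the connect/run pattern above (constructor arguments depend on BaseESME and are omitted; 2775 is the conventional SMPP port):

from tornado.gen import coroutine
from tornado.ioloop import IOLoop

@coroutine
def main():
    esme = ESME()                          # kwargs depend on BaseESME
    yield esme.connect("localhost", 2775)
    yield esme.run()                       # read until the stream closes

IOLoop.current().run_sync(main)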
Example #17
class TestIOStream(AsyncHTTPTestCase, LogTrapTestCase):
    def get_app(self):
        return Application([('/', HelloHandler)])

    def test_read_zero_bytes(self):
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        s.connect(("localhost", self.get_http_port()))
        self.stream = IOStream(s, io_loop=self.io_loop)
        self.stream.write("GET / HTTP/1.0\r\n\r\n")

        # normal read
        self.stream.read_bytes(9, self.stop)
        data = self.wait()
        self.assertEqual(data, "HTTP/1.0 ")

        # zero bytes
        self.stream.read_bytes(0, self.stop)
        data = self.wait()
        self.assertEqual(data, "")

        # another normal read
        self.stream.read_bytes(3, self.stop)
        data = self.wait()
        self.assertEqual(data, "200")

    def test_connection_refused(self):
        # When a connection is refused, the connect callback should not
        # be run.  (The kqueue IOLoop used to behave differently from the
        # epoll IOLoop in this respect)
        port = get_unused_port()
        stream = IOStream(socket.socket(), self.io_loop)
        self.connect_called = False

        def connect_callback():
            self.connect_called = True

        stream.set_close_callback(self.stop)
        stream.connect(("localhost", port), connect_callback)
        self.wait()
        self.assertFalse(self.connect_called)
Example #18
class TestIOStream(AsyncHTTPTestCase, LogTrapTestCase):
    def get_app(self):
        return Application([('/', HelloHandler)])

    def test_read_zero_bytes(self):
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        s.connect(("localhost", self.get_http_port()))
        self.stream = IOStream(s, io_loop=self.io_loop)
        self.stream.write("GET / HTTP/1.0\r\n\r\n")

        # normal read
        self.stream.read_bytes(9, self.stop)
        data = self.wait()
        self.assertEqual(data, "HTTP/1.0 ")

        # zero bytes
        self.stream.read_bytes(0, self.stop)
        data = self.wait()
        self.assertEqual(data, "")

        # another normal read
        self.stream.read_bytes(3, self.stop)
        data = self.wait()
        self.assertEqual(data, "200")

    def test_connection_refused(self):
        # When a connection is refused, the connect callback should not
        # be run.  (The kqueue IOLoop used to behave differently from the
        # epoll IOLoop in this respect)
        port = get_unused_port()
        stream = IOStream(socket.socket(), self.io_loop)
        self.connect_called = False
        def connect_callback():
            self.connect_called = True
        stream.set_close_callback(self.stop)
        stream.connect(("localhost", port), connect_callback)
        self.wait()
        self.assertFalse(self.connect_called)
Example #19
class TestIOStream(AsyncHTTPTestCase, LogTrapTestCase):
    def get_app(self):
        return Application([('/', HelloHandler)])

    def test_read_zero_bytes(self):
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        s.connect(("localhost", self.get_http_port()))
        self.stream = IOStream(s, io_loop=self.io_loop)
        self.stream.write("GET / HTTP/1.0\r\n\r\n")

        # normal read
        self.stream.read_bytes(9, self.stop)
        data = self.wait()
        self.assertEqual(data, "HTTP/1.0 ")

        # zero bytes
        self.stream.read_bytes(0, self.stop)
        data = self.wait()
        self.assertEqual(data, "")

        # another normal read
        self.stream.read_bytes(3, self.stop)
        data = self.wait()
        self.assertEqual(data, "200")
Example #20
class Connection(object):
    def __init__(self, host='localhost', port=6379, unix_socket_path=None,
                 event_handler_proxy=None, stop_after=None, io_loop=None):
        self.host = host
        self.port = port
        self.unix_socket_path = unix_socket_path
        self._event_handler = event_handler_proxy
        self.timeout = stop_after
        self._stream = None
        self._io_loop = io_loop

        self.in_progress = False
        self.read_callbacks = set()
        self.ready_callbacks = deque()
        self._lock = 0
        self.info = {'db': 0, 'pass': None}

    def __del__(self):
        self.disconnect()

    def execute_pending_command(self):
        # Continue with the pending command execution
        # if all read operations are completed.
        if not self.read_callbacks and self.ready_callbacks:
            # Pop a SINGLE callback from the queue and execute it.
            # The next one will be executed from the code
            # invoked by the callback
            callback = self.ready_callbacks.popleft()
            callback()

    def ready(self):
        return (not self.read_callbacks and
                not self.ready_callbacks)

    def wait_until_ready(self, callback=None):
        if callback:
            if not self.ready():
                callback = stack_context.wrap(callback)
                self.ready_callbacks.append(callback)
            else:
                callback()

    def connect(self):
        if not self._stream:
            try:
                if self.unix_socket_path:
                    sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
                    sock.settimeout(self.timeout)
                    sock.connect(self.unix_socket_path)
                else:
                    sock = socket.create_connection(
                        (self.host, self.port),
                        timeout=self.timeout
                    )
                    sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
                self._stream = IOStream(sock)
                self._stream.set_close_callback(self.on_stream_close)
                self.info['db'] = 0
                self.info['pass'] = None
            except socket.error as e:
                raise ConnectionError(str(e))
            self.fire_event('on_connect')

    def on_stream_close(self):
        if self._stream:
            self.disconnect()
            callbacks = self.read_callbacks
            self.read_callbacks = set()
            for callback in callbacks:
                callback()

    def disconnect(self):
        if self._stream:
            s = self._stream
            self._stream = None
            try:
                if s.socket:
                    s.socket.shutdown(socket.SHUT_RDWR)
                s.close()
            except:
                pass

    def fire_event(self, event):
        event_handler = self._event_handler
        if event_handler:
            try:
                getattr(event_handler, event)()
            except AttributeError:
                pass

    def write(self, data, callback=None):
        if not self._stream:
            raise ConnectionError('Tried to write to '
                                  'non-existent connection')

        if callback:
            callback = stack_context.wrap(callback)
            _callback = lambda: callback(None)
            self.read_callbacks.add(_callback)
            cb = partial(self.read_callback, _callback)
        else:
            cb = None
        try:
            self._stream.write(data, callback=cb)
        except IOError as e:
            self.disconnect()
            raise ConnectionError(e.message)

    def read(self, length, callback=None):
        try:
            if not self._stream:
                self.disconnect()
                raise ConnectionError('Tried to read from '
                                      'non-existent connection')
            callback = stack_context.wrap(callback)
            self.read_callbacks.add(callback)
            self._stream.read_bytes(length,
                                    callback=partial(self.read_callback,
                                                     callback))
        except IOError:
            self.fire_event('on_disconnect')

    def read_callback(self, callback, *args, **kwargs):
        try:
            self.read_callbacks.remove(callback)
        except KeyError:
            pass
        callback(*args, **kwargs)

    def readline(self, callback=None):
        try:
            if not self._stream:
                self.disconnect()
                raise ConnectionError('Tried to read from '
                                      'non-existent connection')
            callback = stack_context.wrap(callback)
            self.read_callbacks.add(callback)
            callback = partial(self.read_callback, callback)
            self._stream.read_until(CRLF, callback=callback)
        except IOError:
            self.fire_event('on_disconnect')

    def connected(self):
        if self._stream:
            return True
        return False
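A minimal usage sketch; the framing follows the Redis protocol, which this connection class appears to target (port 6379, CRLF-terminated replies):

from tornado.ioloop import IOLoop

def on_pong(line):
    print(line)                        # expected: b'+PONG\r\n'
    IOLoop.current().stop()

conn = Connection(host='localhost', port=6379)
conn.connect()
conn.write(b'*1\r\n$4\r\nPING\r\n')   # RESP encoding of PING
conn.readline(callback=on_pong)
IOLoop.current().start()               # callbacks fire once the loop runs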
Example #21
class TestIOStream(AsyncHTTPTestCase, LogTrapTestCase):
    def get_app(self):
        return Application([('/', HelloHandler)])

    def make_iostream_pair(self):
        port = get_unused_port()
        [listener] = netutil.bind_sockets(port, '127.0.0.1',
                                          family=socket.AF_INET)
        streams = [None, None]
        def accept_callback(connection, address):
            streams[0] = IOStream(connection, io_loop=self.io_loop)
            self.stop()
        def connect_callback():
            streams[1] = client_stream
            self.stop()
        netutil.add_accept_handler(listener, accept_callback,
                                   io_loop=self.io_loop)
        client_stream = IOStream(socket.socket(), io_loop=self.io_loop)
        client_stream.connect(('127.0.0.1', port),
                              callback=connect_callback)
        self.wait(condition=lambda: all(streams))
        self.io_loop.remove_handler(listener.fileno())
        listener.close()
        return streams

    def test_read_zero_bytes(self):
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        s.connect(("localhost", self.get_http_port()))
        self.stream = IOStream(s, io_loop=self.io_loop)
        self.stream.write(b("GET / HTTP/1.0\r\n\r\n"))

        # normal read
        self.stream.read_bytes(9, self.stop)
        data = self.wait()
        self.assertEqual(data, b("HTTP/1.0 "))

        # zero bytes
        self.stream.read_bytes(0, self.stop)
        data = self.wait()
        self.assertEqual(data, b(""))

        # another normal read
        self.stream.read_bytes(3, self.stop)
        data = self.wait()
        self.assertEqual(data, b("200"))

    def test_connection_refused(self):
        # When a connection is refused, the connect callback should not
        # be run.  (The kqueue IOLoop used to behave differently from the
        # epoll IOLoop in this respect)
        port = get_unused_port()
        stream = IOStream(socket.socket(), self.io_loop)
        self.connect_called = False
        def connect_callback():
            self.connect_called = True
        stream.set_close_callback(self.stop)
        stream.connect(("localhost", port), connect_callback)
        self.wait()
        self.assertFalse(self.connect_called)

    def test_connection_closed(self):
        # When a server sends a response and then closes the connection,
        # the client must be allowed to read the data before the IOStream
        # closes itself.  Epoll reports closed connections with a separate
        # EPOLLRDHUP event delivered at the same time as the read event,
        # while kqueue reports them as a second read/write event with an EOF
        # flag.
        response = self.fetch("/", headers={"Connection": "close"})
        response.rethrow()

    def test_read_until_close(self):
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        s.connect(("localhost", self.get_http_port()))
        stream = IOStream(s, io_loop=self.io_loop)
        stream.write(b("GET / HTTP/1.0\r\n\r\n"))
        
        stream.read_until_close(self.stop)
        data = self.wait()
        self.assertTrue(data.startswith(b("HTTP/1.0 200")))
        self.assertTrue(data.endswith(b("Hello")))

    def test_streaming_callback(self):
        server, client = self.make_iostream_pair()
        try:
            chunks = []
            final_called = []
            def streaming_callback(data):
                chunks.append(data)
                self.stop()
            def final_callback(data):
                assert not data
                final_called.append(True)
                self.stop()
            server.read_bytes(6, callback=final_callback,
                              streaming_callback=streaming_callback)
            client.write(b("1234"))
            self.wait(condition=lambda: chunks)
            client.write(b("5678"))
            self.wait(condition=lambda: final_called)
            self.assertEqual(chunks, [b("1234"), b("56")])

            # the rest of the last chunk is still in the buffer
            server.read_bytes(2, callback=self.stop)
            data = self.wait()
            self.assertEqual(data, b("78"))
        finally:
            server.close()
            client.close()

    def test_streaming_until_close(self):
        server, client = self.make_iostream_pair()
        try:
            chunks = []
            def callback(data):
                chunks.append(data)
                self.stop()
            client.read_until_close(callback=callback,
                                    streaming_callback=callback)
            server.write(b("1234"))
            self.wait()
            server.write(b("5678"))
            self.wait()
            server.close()
            self.wait()
            self.assertEqual(chunks, [b("1234"), b("5678"), b("")])
        finally:
            server.close()
            client.close()

    def test_delayed_close_callback(self):
        # The scenario:  Server closes the connection while there is a pending
        # read that can be served out of buffered data.  The client does not
        # run the close_callback as soon as it detects the close, but rather
        # defers it until after the buffered read has finished.
        server, client = self.make_iostream_pair()
        try:
            client.set_close_callback(self.stop)
            server.write(b("12"))
            chunks = []
            def callback1(data):
                chunks.append(data)
                client.read_bytes(1, callback2)
                server.close()
            def callback2(data):
                chunks.append(data)
            client.read_bytes(1, callback1)
            self.wait()  # stopped by close_callback
            self.assertEqual(chunks, [b("1"), b("2")])
        finally:
            server.close()
            client.close()
Example #22
class _HTTPConnection(object):
    _SUPPORTED_METHODS = set(
        ["GET", "HEAD", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"])

    def __init__(self, io_loop, client, request, release_callback,
                 final_callback, max_buffer_size, resolver):
        self.start_time = io_loop.time()
        self.io_loop = io_loop
        self.client = client
        self.request = request
        self.release_callback = release_callback
        self.final_callback = final_callback
        self.max_buffer_size = max_buffer_size
        self.resolver = resolver
        self.code = None
        self.headers = None
        self.chunks = None
        self._decompressor = None
        # Timeout handle returned by IOLoop.add_timeout
        self._timeout = None
        with stack_context.ExceptionStackContext(self._handle_exception):
            self.parsed = urlparse.urlsplit(_unicode(self.request.url))
            if self.parsed.scheme not in ("http", "https"):
                raise ValueError("Unsupported url scheme: %s" %
                                 self.request.url)
            # urlsplit results have hostname and port results, but they
            # didn't support ipv6 literals until python 2.7.
            netloc = self.parsed.netloc
            if "@" in netloc:
                userpass, _, netloc = netloc.rpartition("@")
            match = re.match(r'^(.+):(\d+)$', netloc)
            if match:
                host = match.group(1)
                port = int(match.group(2))
            else:
                host = netloc
                port = 443 if self.parsed.scheme == "https" else 80
            if re.match(r'^\[.*\]$', host):
                # raw ipv6 addresses in urls are enclosed in brackets
                host = host[1:-1]
            self.parsed_hostname = host  # save final host for _on_connect

            if request.allow_ipv6:
                af = socket.AF_UNSPEC
            else:
                # We only try the first IP we get from getaddrinfo,
                # so restrict to ipv4 by default.
                af = socket.AF_INET

            self.resolver.resolve(host, port, af, callback=self._on_resolve)

    def _on_resolve(self, addrinfo):
        af, sockaddr = addrinfo[0]

        if self.parsed.scheme == "https":
            ssl_options = {}
            if self.request.validate_cert:
                ssl_options["cert_reqs"] = ssl.CERT_REQUIRED
            if self.request.ca_certs is not None:
                ssl_options["ca_certs"] = self.request.ca_certs
            else:
                ssl_options["ca_certs"] = _DEFAULT_CA_CERTS
            if self.request.client_key is not None:
                ssl_options["keyfile"] = self.request.client_key
            if self.request.client_cert is not None:
                ssl_options["certfile"] = self.request.client_cert

            # SSL interoperability is tricky.  We want to disable
            # SSLv2 for security reasons; it wasn't disabled by default
            # until openssl 1.0.  The best way to do this is to use
            # the SSL_OP_NO_SSLv2, but that wasn't exposed to python
            # until 3.2.  Python 2.7 adds the ciphers argument, which
            # can also be used to disable SSLv2.  As a last resort
            # on python 2.6, we set ssl_version to SSLv3.  This is
            # more narrow than we'd like since it also breaks
            # compatibility with servers configured for TLSv1 only,
            # but nearly all servers support SSLv3:
            # http://blog.ivanristic.com/2011/09/ssl-survey-protocol-support.html
            if sys.version_info >= (2, 7):
                ssl_options["ciphers"] = "DEFAULT:!SSLv2"
            else:
                # This is really only necessary for pre-1.0 versions
                # of openssl, but python 2.6 doesn't expose version
                # information.
                ssl_options["ssl_version"] = ssl.PROTOCOL_SSLv3

            self.stream = SSLIOStream(socket.socket(af),
                                      io_loop=self.io_loop,
                                      ssl_options=ssl_options,
                                      max_buffer_size=self.max_buffer_size)
        else:
            self.stream = IOStream(socket.socket(af),
                                   io_loop=self.io_loop,
                                   max_buffer_size=self.max_buffer_size)
        timeout = min(self.request.connect_timeout,
                      self.request.request_timeout)
        if timeout:
            self._timeout = self.io_loop.add_timeout(
                self.start_time + timeout,
                stack_context.wrap(self._on_timeout))
        self.stream.set_close_callback(self._on_close)
        # ipv6 addresses are broken (in self.parsed.hostname) until
        # 2.7, here is correctly parsed value calculated in __init__
        self.stream.connect(sockaddr,
                            self._on_connect,
                            server_hostname=self.parsed_hostname)

    def _on_timeout(self):
        self._timeout = None
        if self.final_callback is not None:
            raise HTTPError(599, "Timeout")

    def _remove_timeout(self):
        if self._timeout is not None:
            self.io_loop.remove_timeout(self._timeout)
            self._timeout = None

    def _on_connect(self):
        self._remove_timeout()
        if self.request.request_timeout:
            self._timeout = self.io_loop.add_timeout(
                self.start_time + self.request.request_timeout,
                stack_context.wrap(self._on_timeout))
        if (self.request.method not in self._SUPPORTED_METHODS
                and not self.request.allow_nonstandard_methods):
            raise KeyError("unknown method %s" % self.request.method)
        for key in ('network_interface', 'proxy_host', 'proxy_port',
                    'proxy_username', 'proxy_password'):
            if getattr(self.request, key, None):
                raise NotImplementedError('%s not supported' % key)
        if "Connection" not in self.request.headers:
            self.request.headers["Connection"] = "close"
        if "Host" not in self.request.headers:
            if '@' in self.parsed.netloc:
                self.request.headers["Host"] = self.parsed.netloc.rpartition(
                    '@')[-1]
            else:
                self.request.headers["Host"] = self.parsed.netloc
        username, password = None, None
        if self.parsed.username is not None:
            username, password = self.parsed.username, self.parsed.password
        elif self.request.auth_username is not None:
            username = self.request.auth_username
            password = self.request.auth_password or ''
        if username is not None:
            auth = utf8(username) + b":" + utf8(password)
            self.request.headers["Authorization"] = (b"Basic " +
                                                     base64.b64encode(auth))
        if self.request.user_agent:
            self.request.headers["User-Agent"] = self.request.user_agent
        if not self.request.allow_nonstandard_methods:
            if self.request.method in ("POST", "PATCH", "PUT"):
                assert self.request.body is not None
            else:
                assert self.request.body is None
        if self.request.body is not None:
            self.request.headers["Content-Length"] = str(len(
                self.request.body))
        if (self.request.method == "POST"
                and "Content-Type" not in self.request.headers):
            self.request.headers[
                "Content-Type"] = "application/x-www-form-urlencoded"
        if self.request.use_gzip:
            self.request.headers["Accept-Encoding"] = "gzip"
        req_path = ((self.parsed.path or '/') +
                    (('?' + self.parsed.query) if self.parsed.query else ''))
        request_lines = [
            utf8("%s %s HTTP/1.1" % (self.request.method, req_path))
        ]
        for k, v in self.request.headers.get_all():
            line = utf8(k) + b": " + utf8(v)
            if b'\n' in line:
                raise ValueError('Newline in header: ' + repr(line))
            request_lines.append(line)
        self.stream.write(b"\r\n".join(request_lines) + b"\r\n\r\n")
        if self.request.body is not None:
            self.stream.write(self.request.body)
        self.stream.read_until_regex(b"\r?\n\r?\n", self._on_headers)

    def _release(self):
        if self.release_callback is not None:
            release_callback = self.release_callback
            self.release_callback = None
            release_callback()

    def _run_callback(self, response):
        self._release()
        if self.final_callback is not None:
            final_callback = self.final_callback
            self.final_callback = None
            self.io_loop.add_callback(final_callback, response)

    def _handle_exception(self, typ, value, tb):
        if self.final_callback:
            self._remove_timeout()
            gen_log.warning("uncaught exception", exc_info=(typ, value, tb))
            self._run_callback(
                HTTPResponse(
                    self.request,
                    599,
                    error=value,
                    request_time=self.io_loop.time() - self.start_time,
                ))

            if hasattr(self, "stream"):
                self.stream.close()
            return True
        else:
            # If our callback has already been called, we are probably
            # catching an exception that is not caused by us but rather
            # some child of our callback. Rather than drop it on the floor,
            # pass it along.
            return False

    def _on_close(self):
        if self.final_callback is not None:
            message = "Connection closed"
            if self.stream.error:
                message = str(self.stream.error)
            raise HTTPError(599, message)

    def _handle_1xx(self, code):
        self.stream.read_until_regex(b"\r?\n\r?\n", self._on_headers)

    def _on_headers(self, data):
        data = native_str(data.decode("latin1"))
        first_line, _, header_data = data.partition("\n")
        match = re.match("HTTP/1.[01] ([0-9]+) ([^\r]*)", first_line)
        assert match
        code = int(match.group(1))
        self.headers = HTTPHeaders.parse(header_data)
        if 100 <= code < 200:
            self._handle_1xx(code)
            return
        else:
            self.code = code
            self.reason = match.group(2)

        if "Content-Length" in self.headers:
            if "," in self.headers["Content-Length"]:
                # Proxies sometimes cause Content-Length headers to get
                # duplicated.  If all the values are identical then we can
                # use them but if they differ it's an error.
                pieces = re.split(r',\s*', self.headers["Content-Length"])
                if any(i != pieces[0] for i in pieces):
                    raise ValueError("Multiple unequal Content-Lengths: %r" %
                                     self.headers["Content-Length"])
                self.headers["Content-Length"] = pieces[0]
            content_length = int(self.headers["Content-Length"])
        else:
            content_length = None

        if self.request.header_callback is not None:
            # re-attach the newline we split on earlier
            self.request.header_callback(first_line + _)
            for k, v in self.headers.get_all():
                self.request.header_callback("%s: %s\r\n" % (k, v))
            self.request.header_callback('\r\n')

        if self.request.method == "HEAD" or self.code == 304:
            # HEAD requests and 304 responses never have content, even
            # though they may have content-length headers
            self._on_body(b"")
            return
        if 100 <= self.code < 200 or self.code == 204:
            # These response codes never have bodies
            # http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.3
            if ("Transfer-Encoding" in self.headers
                    or content_length not in (None, 0)):
                raise ValueError("Response with code %d should not have body" %
                                 self.code)
            self._on_body(b"")
            return

        if (self.request.use_gzip
                and self.headers.get("Content-Encoding") == "gzip"):
            self._decompressor = GzipDecompressor()
        if self.headers.get("Transfer-Encoding") == "chunked":
            self.chunks = []
            self.stream.read_until(b"\r\n", self._on_chunk_length)
        elif content_length is not None:
            self.stream.read_bytes(content_length, self._on_body)
        else:
            self.stream.read_until_close(self._on_body)

    def _on_body(self, data):
        self._remove_timeout()
        original_request = getattr(self.request, "original_request",
                                   self.request)
        if (self.request.follow_redirects and self.request.max_redirects > 0
                and self.code in (301, 302, 303, 307)):
            assert isinstance(self.request, _RequestProxy)
            new_request = copy.copy(self.request.request)
            new_request.url = urlparse.urljoin(self.request.url,
                                               self.headers["Location"])
            new_request.max_redirects = self.request.max_redirects - 1
            del new_request.headers["Host"]
            # http://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.3.4
            # Client SHOULD make a GET request after a 303.
            # According to the spec, 302 should be followed by the same
            # method as the original request, but in practice browsers
            # treat 302 the same as 303, and many servers use 302 for
            # compatibility with pre-HTTP/1.1 user agents which don't
            # understand the 303 status.
            if self.code in (302, 303):
                new_request.method = "GET"
                new_request.body = None
                for h in [
                        "Content-Length", "Content-Type", "Content-Encoding",
                        "Transfer-Encoding"
                ]:
                    try:
                        del self.request.headers[h]
                    except KeyError:
                        pass
            new_request.original_request = original_request
            final_callback = self.final_callback
            self.final_callback = None
            self._release()
            self.client.fetch(new_request, final_callback)
            self.stream.close()
            return
        if self._decompressor:
            data = (self._decompressor.decompress(data) +
                    self._decompressor.flush())
        if self.request.streaming_callback:
            if self.chunks is None:
                # if chunks is not None, we already called streaming_callback
                # in _on_chunk_data
                self.request.streaming_callback(data)
            buffer = BytesIO()
        else:
            buffer = BytesIO(data)  # TODO: don't require one big string?
        response = HTTPResponse(original_request,
                                self.code,
                                reason=self.reason,
                                headers=self.headers,
                                request_time=self.io_loop.time() -
                                self.start_time,
                                buffer=buffer,
                                effective_url=self.request.url)
        self._run_callback(response)
        self.stream.close()

    def _on_chunk_length(self, data):
        # TODO: "chunk extensions" http://tools.ietf.org/html/rfc2616#section-3.6.1
        length = int(data.strip(), 16)
        if length == 0:
            if self._decompressor is not None:
                tail = self._decompressor.flush()
                if tail:
                    # I believe the tail will always be empty (i.e.
                    # decompress will return all it can).  The purpose
                    # of the flush call is to detect errors such
                    # as truncated input.  But in case it ever returns
                    # anything, treat it as an extra chunk
                    if self.request.streaming_callback is not None:
                        self.request.streaming_callback(tail)
                    else:
                        self.chunks.append(tail)
                # all the data has been decompressed, so we don't need to
                # decompress again in _on_body
                self._decompressor = None
            self._on_body(b''.join(self.chunks))
        else:
            self.stream.read_bytes(
                length + 2,  # chunk ends with \r\n
                self._on_chunk_data)

    def _on_chunk_data(self, data):
        assert data[-2:] == b"\r\n"
        chunk = data[:-2]
        if self._decompressor:
            chunk = self._decompressor.decompress(chunk)
        if self.request.streaming_callback is not None:
            self.request.streaming_callback(chunk)
        else:
            self.chunks.append(chunk)
        self.stream.read_until(b"\r\n", self._on_chunk_length)
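For reference, a hedged encoder for the chunked framing that _on_chunk_length and _on_chunk_data consume (chunk extensions and trailers are ignored here):

def encode_chunked(chunks):
    # Each chunk: hex length, CRLF, data, CRLF; a zero-length chunk
    # terminates the body.
    out = b""
    for chunk in chunks:
        out += b"%x\r\n" % len(chunk) + chunk + b"\r\n"
    return out + b"0\r\n\r\n"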
Example #23
class TestIOStream(AsyncHTTPTestCase, LogTrapTestCase):
    def get_app(self):
        return Application([('/', HelloHandler)])

    def make_iostream_pair(self, **kwargs):
        port = get_unused_port()
        [listener] = netutil.bind_sockets(port, '127.0.0.1',
                                          family=socket.AF_INET)
        streams = [None, None]

        def accept_callback(connection, address):
            streams[0] = IOStream(connection, io_loop=self.io_loop, **kwargs)
            self.stop()

        def connect_callback():
            streams[1] = client_stream
            self.stop()
        netutil.add_accept_handler(listener, accept_callback,
                                   io_loop=self.io_loop)
        client_stream = IOStream(socket.socket(), io_loop=self.io_loop,
                                 **kwargs)
        client_stream.connect(('127.0.0.1', port),
                              callback=connect_callback)
        self.wait(condition=lambda: all(streams))
        self.io_loop.remove_handler(listener.fileno())
        listener.close()
        return streams

    def test_read_zero_bytes(self):
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        s.connect(("localhost", self.get_http_port()))
        self.stream = IOStream(s, io_loop=self.io_loop)
        self.stream.write(b("GET / HTTP/1.0\r\n\r\n"))

        # normal read
        self.stream.read_bytes(9, self.stop)
        data = self.wait()
        self.assertEqual(data, b("HTTP/1.0 "))

        # zero bytes
        self.stream.read_bytes(0, self.stop)
        data = self.wait()
        self.assertEqual(data, b(""))

        # another normal read
        self.stream.read_bytes(3, self.stop)
        data = self.wait()
        self.assertEqual(data, b("200"))

        s.close()

    def test_write_zero_bytes(self):
        # Attempting to write zero bytes should run the callback without
        # going into an infinite loop.
        server, client = self.make_iostream_pair()
        server.write(b(''), callback=self.stop)
        self.wait()
        # As a side effect, the stream is now listening for connection
        # close (if it wasn't already), but is not listening for writes
        self.assertEqual(server._state, IOLoop.READ | IOLoop.ERROR)
        server.close()
        client.close()

    def test_connection_refused(self):
        # When a connection is refused, the connect callback should not
        # be run.  (The kqueue IOLoop used to behave differently from the
        # epoll IOLoop in this respect)
        port = get_unused_port()
        stream = IOStream(socket.socket(), self.io_loop)
        self.connect_called = False

        def connect_callback():
            self.connect_called = True
        stream.set_close_callback(self.stop)
        stream.connect(("localhost", port), connect_callback)
        self.wait()
        self.assertFalse(self.connect_called)
        self.assertTrue(isinstance(stream.error, socket.error), stream.error)
        if sys.platform != 'cygwin':
            # cygwin's errnos don't match those used on native windows python
            self.assertEqual(stream.error.args[0], errno.ECONNREFUSED)

    def test_gaierror(self):
        # Test that IOStream sets its exc_info on getaddrinfo error
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        stream = IOStream(s, io_loop=self.io_loop)
        stream.set_close_callback(self.stop)
        stream.connect(('adomainthatdoesntexist.asdf', 54321))
        self.assertTrue(isinstance(stream.error, socket.gaierror), stream.error)

    def test_connection_closed(self):
        # When a server sends a response and then closes the connection,
        # the client must be allowed to read the data before the IOStream
        # closes itself.  Epoll reports closed connections with a separate
        # EPOLLRDHUP event delivered at the same time as the read event,
        # while kqueue reports them as a second read/write event with an EOF
        # flag.
        response = self.fetch("/", headers={"Connection": "close"})
        response.rethrow()

    def test_read_until_close(self):
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        s.connect(("localhost", self.get_http_port()))
        stream = IOStream(s, io_loop=self.io_loop)
        stream.write(b("GET / HTTP/1.0\r\n\r\n"))

        stream.read_until_close(self.stop)
        data = self.wait()
        self.assertTrue(data.startswith(b("HTTP/1.0 200")))
        self.assertTrue(data.endswith(b("Hello")))

    def test_streaming_callback(self):
        server, client = self.make_iostream_pair()
        try:
            chunks = []
            final_called = []

            def streaming_callback(data):
                chunks.append(data)
                self.stop()

            def final_callback(data):
                assert not data
                final_called.append(True)
                self.stop()
            server.read_bytes(6, callback=final_callback,
                              streaming_callback=streaming_callback)
            client.write(b("1234"))
            self.wait(condition=lambda: chunks)
            client.write(b("5678"))
            self.wait(condition=lambda: final_called)
            self.assertEqual(chunks, [b("1234"), b("56")])

            # the rest of the last chunk is still in the buffer
            server.read_bytes(2, callback=self.stop)
            data = self.wait()
            self.assertEqual(data, b("78"))
        finally:
            server.close()
            client.close()

    def test_streaming_until_close(self):
        server, client = self.make_iostream_pair()
        try:
            chunks = []

            def callback(data):
                chunks.append(data)
                self.stop()
            client.read_until_close(callback=callback,
                                    streaming_callback=callback)
            server.write(b("1234"))
            self.wait()
            server.write(b("5678"))
            self.wait()
            server.close()
            self.wait()
            self.assertEqual(chunks, [b("1234"), b("5678"), b("")])
        finally:
            server.close()
            client.close()

    def test_delayed_close_callback(self):
        # The scenario:  Server closes the connection while there is a pending
        # read that can be served out of buffered data.  The client does not
        # run the close_callback as soon as it detects the close, but rather
        # defers it until after the buffered read has finished.
        server, client = self.make_iostream_pair()
        try:
            client.set_close_callback(self.stop)
            server.write(b("12"))
            chunks = []

            def callback1(data):
                chunks.append(data)
                client.read_bytes(1, callback2)
                server.close()

            def callback2(data):
                chunks.append(data)
            client.read_bytes(1, callback1)
            self.wait()  # stopped by close_callback
            self.assertEqual(chunks, [b("1"), b("2")])
        finally:
            server.close()
            client.close()

    def test_close_buffered_data(self):
        # Similar to the previous test, but with data stored in the OS's
        # socket buffers instead of the IOStream's read buffer.  Out-of-band
        # close notifications must be delayed until all data has been
        # drained into the IOStream buffer. (epoll used to use out-of-band
        # close events with EPOLLRDHUP, but no longer)
        #
        # This depends on the read_chunk_size being smaller than the
        # OS socket buffer, so make it small.
        server, client = self.make_iostream_pair(read_chunk_size=256)
        try:
            server.write(b("A") * 512)
            client.read_bytes(256, self.stop)
            data = self.wait()
            self.assertEqual(b("A") * 256, data)
            server.close()
            # Allow the close to propagate to the client side of the
            # connection.  Using add_callback instead of add_timeout
            # doesn't seem to work, even with multiple iterations
            self.io_loop.add_timeout(datetime.timedelta(seconds=0.01), self.stop)
            self.wait()
            client.read_bytes(256, self.stop)
            data = self.wait()
            self.assertEqual(b("A") * 256, data)
        finally:
            server.close()
            client.close()

    def test_large_read_until(self):
        # Performance test: read_until used to have a quadratic component
        # so a read_until of 4MB would take 8 seconds; now it takes 0.25
        # seconds.
        server, client = self.make_iostream_pair()
        try:
            NUM_KB = 4096
            for i in xrange(NUM_KB):
                client.write(b("A") * 1024)
            client.write(b("\r\n"))
            server.read_until(b("\r\n"), self.stop)
            data = self.wait()
            self.assertEqual(len(data), NUM_KB * 1024 + 2)
        finally:
            server.close()
            client.close()
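
The streaming tests above rely on read_bytes delivering partial chunks to
streaming_callback and firing the final callback with an empty result once the
requested byte count is reached. A minimal sketch of that behaviour, assuming
Tornado's pre-4.0 callback-style API and a Unix socketpair (streaming_demo is
a hypothetical helper, not part of the test suite):

import socket
from tornado.ioloop import IOLoop
from tornado.iostream import IOStream

def streaming_demo():
    sock_a, sock_b = socket.socketpair()
    reader, writer = IOStream(sock_a), IOStream(sock_b)

    def on_chunk(data):
        # invoked with each partial chunk as it arrives
        print("chunk:", data)

    def on_done(data):
        # fires with an empty byte string once all 6 bytes are consumed
        print("done:", data)
        IOLoop.instance().stop()

    reader.read_bytes(6, callback=on_done, streaming_callback=on_chunk)
    writer.write(b"123456")
    IOLoop.instance().start()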
Example #24
class _HTTPConnection(object):
    _SUPPORTED_METHODS = set(["GET", "HEAD", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"])

    def __init__(self, io_loop, client, request, release_callback, final_callback, max_buffer_size):
        self.start_time = io_loop.time()
        self.io_loop = io_loop
        self.client = client
        self.request = request
        self.release_callback = release_callback
        self.final_callback = final_callback
        self.max_buffer_size = max_buffer_size
        self.code = None
        self.headers = None
        self.chunks = None
        self._decompressor = None
        # Timeout handle returned by IOLoop.add_timeout
        self._timeout = None
        with stack_context.ExceptionStackContext(self._handle_exception):
            self.parsed = urlparse.urlsplit(_unicode(self.request.url))
            if self.parsed.scheme not in ("http", "https"):
                raise ValueError("Unsupported url scheme: %s" % self.request.url)
            # urlsplit results have hostname and port attributes, but they
            # didn't support ipv6 literals until python 2.7.
            netloc = self.parsed.netloc
            if "@" in netloc:
                userpass, _, netloc = netloc.rpartition("@")
            match = re.match(r"^(.+):(\d+)$", netloc)
            if match:
                host = match.group(1)
                port = int(match.group(2))
            else:
                host = netloc
                port = 443 if self.parsed.scheme == "https" else 80
            if re.match(r"^\[.*\]$", host):
                # raw ipv6 addresses in urls are enclosed in brackets
                host = host[1:-1]
            self.parsed_hostname = host  # save final host for _on_connect
            if self.client.hostname_mapping is not None:
                host = self.client.hostname_mapping.get(host, host)

            if request.allow_ipv6:
                af = socket.AF_UNSPEC
            else:
                # We only try the first IP we get from getaddrinfo,
                # so restrict to ipv4 by default.
                af = socket.AF_INET

            self.client.resolver.getaddrinfo(host, port, af, socket.SOCK_STREAM, 0, 0, callback=self._on_resolve)

    def _on_resolve(self, future):
        af, socktype, proto, canonname, sockaddr = future.result()[0]

        if self.parsed.scheme == "https":
            ssl_options = {}
            if self.request.validate_cert:
                ssl_options["cert_reqs"] = ssl.CERT_REQUIRED
            if self.request.ca_certs is not None:
                ssl_options["ca_certs"] = self.request.ca_certs
            else:
                ssl_options["ca_certs"] = _DEFAULT_CA_CERTS
            if self.request.client_key is not None:
                ssl_options["keyfile"] = self.request.client_key
            if self.request.client_cert is not None:
                ssl_options["certfile"] = self.request.client_cert

            # SSL interoperability is tricky.  We want to disable
            # SSLv2 for security reasons; it wasn't disabled by default
            # until openssl 1.0.  The best way to do this is to use
            # the SSL_OP_NO_SSLv2, but that wasn't exposed to python
            # until 3.2.  Python 2.7 adds the ciphers argument, which
            # can also be used to disable SSLv2.  As a last resort
            # on python 2.6, we set ssl_version to SSLv3.  This is
            # more narrow than we'd like since it also breaks
            # compatibility with servers configured for TLSv1 only,
            # but nearly all servers support SSLv3:
            # http://blog.ivanristic.com/2011/09/ssl-survey-protocol-support.html
            if sys.version_info >= (2, 7):
                ssl_options["ciphers"] = "DEFAULT:!SSLv2"
            else:
                # This is really only necessary for pre-1.0 versions
                # of openssl, but python 2.6 doesn't expose version
                # information.
                ssl_options["ssl_version"] = ssl.PROTOCOL_SSLv3

            self.stream = SSLIOStream(
                socket.socket(af, socktype, proto),
                io_loop=self.io_loop,
                ssl_options=ssl_options,
                max_buffer_size=self.max_buffer_size,
            )
        else:
            self.stream = IOStream(
                socket.socket(af, socktype, proto), io_loop=self.io_loop, max_buffer_size=self.max_buffer_size
            )
        timeout = min(self.request.connect_timeout, self.request.request_timeout)
        if timeout:
            self._timeout = self.io_loop.add_timeout(self.start_time + timeout, stack_context.wrap(self._on_timeout))
        self.stream.set_close_callback(self._on_close)
        # ipv6 addresses are broken (in self.parsed.hostname) until 2.7;
        # pass the correctly parsed value computed in __init__ instead
        self.stream.connect(sockaddr, self._on_connect, server_hostname=self.parsed_hostname)

    def _on_timeout(self):
        self._timeout = None
        if self.final_callback is not None:
            raise HTTPError(599, "Timeout")

    def _on_connect(self):
        if self._timeout is not None:
            self.io_loop.remove_timeout(self._timeout)
            self._timeout = None
        if self.request.request_timeout:
            self._timeout = self.io_loop.add_timeout(
                self.start_time + self.request.request_timeout, stack_context.wrap(self._on_timeout)
            )
        if self.request.method not in self._SUPPORTED_METHODS and not self.request.allow_nonstandard_methods:
            raise KeyError("unknown method %s" % self.request.method)
        for key in ("network_interface", "proxy_host", "proxy_port", "proxy_username", "proxy_password"):
            if getattr(self.request, key, None):
                raise NotImplementedError("%s not supported" % key)
        if "Connection" not in self.request.headers:
            self.request.headers["Connection"] = "close"
        if "Host" not in self.request.headers:
            if "@" in self.parsed.netloc:
                self.request.headers["Host"] = self.parsed.netloc.rpartition("@")[-1]
            else:
                self.request.headers["Host"] = self.parsed.netloc
        username, password = None, None
        if self.parsed.username is not None:
            username, password = self.parsed.username, self.parsed.password
        elif self.request.auth_username is not None:
            username = self.request.auth_username
            password = self.request.auth_password or ""
        if username is not None:
            auth = utf8(username) + b":" + utf8(password)
            self.request.headers["Authorization"] = b"Basic " + base64.b64encode(auth)
        if self.request.user_agent:
            self.request.headers["User-Agent"] = self.request.user_agent
        if not self.request.allow_nonstandard_methods:
            if self.request.method in ("POST", "PATCH", "PUT"):
                assert self.request.body is not None
            else:
                assert self.request.body is None
        if self.request.body is not None:
            self.request.headers["Content-Length"] = str(len(self.request.body))
        if self.request.method == "POST" and "Content-Type" not in self.request.headers:
            self.request.headers["Content-Type"] = "application/x-www-form-urlencoded"
        if self.request.use_gzip:
            self.request.headers["Accept-Encoding"] = "gzip"
        req_path = (self.parsed.path or "/") + (("?" + self.parsed.query) if self.parsed.query else "")
        request_lines = [utf8("%s %s HTTP/1.1" % (self.request.method, req_path))]
        for k, v in self.request.headers.get_all():
            line = utf8(k) + b": " + utf8(v)
            if b"\n" in line:
                raise ValueError("Newline in header: " + repr(line))
            request_lines.append(line)
        self.stream.write(b"\r\n".join(request_lines) + b"\r\n\r\n")
        if self.request.body is not None:
            self.stream.write(self.request.body)
        self.stream.read_until_regex(b"\r?\n\r?\n", self._on_headers)

    def _release(self):
        if self.release_callback is not None:
            release_callback = self.release_callback
            self.release_callback = None
            release_callback()

    def _run_callback(self, response):
        self._release()
        if self.final_callback is not None:
            final_callback = self.final_callback
            self.final_callback = None
            self.io_loop.add_callback(final_callback, response)

    def _handle_exception(self, typ, value, tb):
        if self.final_callback:
            gen_log.warning("uncaught exception", exc_info=(typ, value, tb))
            self._run_callback(
                HTTPResponse(self.request, 599, error=value, request_time=self.io_loop.time() - self.start_time)
            )

            if hasattr(self, "stream"):
                self.stream.close()
            return True
        else:
            # If our callback has already been called, we are probably
            # catching an exception that is not caused by us but rather
            # some child of our callback. Rather than drop it on the floor,
            # pass it along.
            return False

    def _on_close(self):
        if self.final_callback is not None:
            message = "Connection closed"
            if self.stream.error:
                message = str(self.stream.error)
            raise HTTPError(599, message)

    def _on_headers(self, data):
        data = native_str(data.decode("latin1"))
        first_line, _, header_data = data.partition("\n")
        match = re.match("HTTP/1.[01] ([0-9]+) ([^\r]*)", first_line)
        assert match
        code = int(match.group(1))
        if 100 <= code < 200:
            self.stream.read_until_regex(b"\r?\n\r?\n", self._on_headers)
            return
        else:
            self.code = code
            self.reason = match.group(2)
        self.headers = HTTPHeaders.parse(header_data)

        if "Content-Length" in self.headers:
            if "," in self.headers["Content-Length"]:
                # Proxies sometimes cause Content-Length headers to get
                # duplicated.  If all the values are identical then we can
                # use them but if they differ it's an error.
                pieces = re.split(r",\s*", self.headers["Content-Length"])
                if any(i != pieces[0] for i in pieces):
                    raise ValueError("Multiple unequal Content-Lengths: %r" % self.headers["Content-Length"])
                self.headers["Content-Length"] = pieces[0]
            content_length = int(self.headers["Content-Length"])
        else:
            content_length = None

        if self.request.header_callback is not None:
            # re-attach the newline we split on earlier
            self.request.header_callback(first_line + _)
            for k, v in self.headers.get_all():
                self.request.header_callback("%s: %s\r\n" % (k, v))
            self.request.header_callback("\r\n")

        if self.request.method == "HEAD" or self.code == 304:
            # HEAD requests and 304 responses never have content, even
            # though they may have content-length headers
            self._on_body(b"")
            return
        if 100 <= self.code < 200 or self.code == 204:
            # These response codes never have bodies
            # http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.3
            if "Transfer-Encoding" in self.headers or content_length not in (None, 0):
                raise ValueError("Response with code %d should not have body" % self.code)
            self._on_body(b"")
            return

        if self.request.use_gzip and self.headers.get("Content-Encoding") == "gzip":
            self._decompressor = GzipDecompressor()
        if self.headers.get("Transfer-Encoding") == "chunked":
            self.chunks = []
            self.stream.read_until(b"\r\n", self._on_chunk_length)
        elif content_length is not None:
            self.stream.read_bytes(content_length, self._on_body)
        else:
            self.stream.read_until_close(self._on_body)

    def _on_body(self, data):
        if self._timeout is not None:
            self.io_loop.remove_timeout(self._timeout)
            self._timeout = None
        original_request = getattr(self.request, "original_request", self.request)
        if self.request.follow_redirects and self.request.max_redirects > 0 and self.code in (301, 302, 303, 307):
            assert isinstance(self.request, _RequestProxy)
            new_request = copy.copy(self.request.request)
            new_request.url = urlparse.urljoin(self.request.url, self.headers["Location"])
            new_request.max_redirects = self.request.max_redirects - 1
            del new_request.headers["Host"]
            # http://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.3.4
            # Client SHOULD make a GET request after a 303.
            # According to the spec, 302 should be followed by the same
            # method as the original request, but in practice browsers
            # treat 302 the same as 303, and many servers use 302 for
            # compatibility with pre-HTTP/1.1 user agents which don't
            # understand the 303 status.
            if self.code in (302, 303):
                new_request.method = "GET"
                new_request.body = None
                for h in ["Content-Length", "Content-Type", "Content-Encoding", "Transfer-Encoding"]:
                    try:
                        del self.request.headers[h]
                    except KeyError:
                        pass
            new_request.original_request = original_request
            final_callback = self.final_callback
            self.final_callback = None
            self._release()
            self.client.fetch(new_request, final_callback)
            self.stream.close()
            return
        if self._decompressor:
            data = self._decompressor.decompress(data) + self._decompressor.flush()
        if self.request.streaming_callback:
            if self.chunks is None:
                # if chunks is not None, we already called streaming_callback
                # in _on_chunk_data
                self.request.streaming_callback(data)
            buffer = BytesIO()
        else:
            buffer = BytesIO(data)  # TODO: don't require one big string?
        response = HTTPResponse(
            original_request,
            self.code,
            reason=self.reason,
            headers=self.headers,
            request_time=self.io_loop.time() - self.start_time,
            buffer=buffer,
            effective_url=self.request.url,
        )
        self._run_callback(response)
        self.stream.close()

    def _on_chunk_length(self, data):
        # TODO: "chunk extensions" http://tools.ietf.org/html/rfc2616#section-3.6.1
        length = int(data.strip(), 16)
        if length == 0:
            if self._decompressor is not None:
                tail = self._decompressor.flush()
                if tail:
                    # I believe the tail will always be empty (i.e.
                    # decompress will return all it can).  The purpose
                    # of the flush call is to detect errors such
                    # as truncated input.  But in case it ever returns
                    # anything, treat it as an extra chunk
                    if self.request.streaming_callback is not None:
                        self.request.streaming_callback(tail)
                    else:
                        self.chunks.append(tail)
                # all the data has been decompressed, so we don't need to
                # decompress again in _on_body
                self._decompressor = None
            self._on_body(b"".join(self.chunks))
        else:
            self.stream.read_bytes(length + 2, self._on_chunk_data)  # chunk ends with \r\n

    def _on_chunk_data(self, data):
        assert data[-2:] == b"\r\n"
        chunk = data[:-2]
        if self._decompressor:
            chunk = self._decompressor.decompress(chunk)
        if self.request.streaming_callback is not None:
            self.request.streaming_callback(chunk)
        else:
            self.chunks.append(chunk)
        self.stream.read_until(b"\r\n", self._on_chunk_length)
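
_on_chunk_length and _on_chunk_data above alternate between reading a hex
length line and reading that many bytes plus the trailing CRLF, until a
zero-length chunk ends the body. The same framing, sketched as a standalone
function over a complete buffer (parse_chunked is a hypothetical helper, not
part of the client):

def parse_chunked(buf):
    # "<hex length>\r\n<data>\r\n" repeated until a zero-length chunk,
    # mirroring the incremental parser above.
    out = []
    while True:
        line, _, buf = buf.partition(b"\r\n")
        length = int(line.split(b";")[0], 16)  # drop any chunk extensions
        if length == 0:
            return b"".join(out)
        out.append(buf[:length])
        buf = buf[length + 2:]  # skip the chunk's trailing CRLF

# parse_chunked(b"4\r\nWiki\r\n5\r\npedia\r\n0\r\n\r\n") == b"Wikipedia"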
Example #25
class _HTTPConnection(object):
    _SUPPORTED_METHODS = set(["GET", "HEAD", "POST", "PUT", "DELETE"])

    def __init__(self, io_loop, client, request, release_callback,
                 final_callback, max_buffer_size):
        self.start_time = time.time()
        self.io_loop = io_loop
        self.client = client
        self.request = request
        self.release_callback = release_callback
        self.final_callback = final_callback
        self.code = None
        self.headers = None
        self.chunks = None
        self._decompressor = None
        # Timeout handle returned by IOLoop.add_timeout
        self._timeout = None
        with stack_context.StackContext(self.cleanup):
            parsed = urllib.parse.urlsplit(_unicode(self.request.url))
            if ssl is None and parsed.scheme == "https":
                raise ValueError("HTTPS requires either python2.6+ or "
                                 "curl_httpclient")
            if parsed.scheme not in ("http", "https"):
                raise ValueError("Unsupported url scheme: %s" %
                                 self.request.url)
            # urlsplit results have hostname and port attributes, but they
            # didn't support ipv6 literals until python 2.7.
            netloc = parsed.netloc
            if "@" in netloc:
                userpass, _, netloc = netloc.rpartition("@")
            match = re.match(r'^(.+):(\d+)$', netloc)
            if match:
                host = match.group(1)
                port = int(match.group(2))
            else:
                host = netloc
                port = 443 if parsed.scheme == "https" else 80
            if re.match(r'^\[.*\]$', host):
                # raw ipv6 addresses in urls are enclosed in brackets
                host = host[1:-1]
            if self.client.hostname_mapping is not None:
                host = self.client.hostname_mapping.get(host, host)

            if request.allow_ipv6:
                af = socket.AF_UNSPEC
            else:
                # We only try the first IP we get from getaddrinfo,
                # so restrict to ipv4 by default.
                af = socket.AF_INET

            addrinfo = socket.getaddrinfo(host, port, af, socket.SOCK_STREAM,
                                          0, 0)
            af, socktype, proto, canonname, sockaddr = addrinfo[0]

            if parsed.scheme == "https":
                ssl_options = {}
                if request.validate_cert:
                    ssl_options["cert_reqs"] = ssl.CERT_REQUIRED
                if request.ca_certs is not None:
                    ssl_options["ca_certs"] = request.ca_certs
                else:
                    ssl_options["ca_certs"] = _DEFAULT_CA_CERTS
                if request.client_key is not None:
                    ssl_options["keyfile"] = request.client_key
                if request.client_cert is not None:
                    ssl_options["certfile"] = request.client_cert
                self.stream = SSLIOStream(socket.socket(af, socktype, proto),
                                          io_loop=self.io_loop,
                                          ssl_options=ssl_options,
                                          max_buffer_size=max_buffer_size)
            else:
                self.stream = IOStream(socket.socket(af, socktype, proto),
                                       io_loop=self.io_loop,
                                       max_buffer_size=max_buffer_size)
            timeout = min(request.connect_timeout, request.request_timeout)
            if timeout:
                self._timeout = self.io_loop.add_timeout(
                    self.start_time + timeout,
                    self._on_timeout)
            self.stream.set_close_callback(self._on_close)
            self.stream.connect(sockaddr,
                                functools.partial(self._on_connect, parsed))

    def _on_timeout(self):
        self._timeout = None
        self._run_callback(HTTPResponse(self.request, 599,
                                        request_time=time.time() - self.start_time,
                                        error=HTTPError(599, "Timeout")))
        self.stream.close()

    def _on_connect(self, parsed):
        if self._timeout is not None:
            self.io_loop.remove_timeout(self._timeout)
            self._timeout = None
        if self.request.request_timeout:
            self._timeout = self.io_loop.add_timeout(
                self.start_time + self.request.request_timeout,
                self._on_timeout)
        if (self.request.validate_cert and
            isinstance(self.stream, SSLIOStream)):
            match_hostname(self.stream.socket.getpeercert(),
                           parsed.hostname)
        if (self.request.method not in self._SUPPORTED_METHODS and
            not self.request.allow_nonstandard_methods):
            raise KeyError("unknown method %s" % self.request.method)
        for key in ('network_interface',
                    'proxy_host', 'proxy_port',
                    'proxy_username', 'proxy_password'):
            if getattr(self.request, key, None):
                raise NotImplementedError('%s not supported' % key)
        if "Host" not in self.request.headers:
            self.request.headers["Host"] = parsed.netloc
        username, password = None, None
        if parsed.username is not None:
            username, password = parsed.username, parsed.password
        elif self.request.auth_username is not None:
            username = self.request.auth_username
            password = self.request.auth_password
        if username is not None:
            auth = utf8(username) + b(":") + utf8(password)
            self.request.headers["Authorization"] = (b("Basic ") +
                                                     base64.b64encode(auth))
        if self.request.user_agent:
            self.request.headers["User-Agent"] = self.request.user_agent
        if not self.request.allow_nonstandard_methods:
            if self.request.method in ("POST", "PUT"):
                assert self.request.body is not None
            else:
                assert self.request.body is None
        if self.request.body is not None:
            self.request.headers["Content-Length"] = str(len(
                    self.request.body))
        if (self.request.method == "POST" and
            "Content-Type" not in self.request.headers):
            self.request.headers["Content-Type"] = "application/x-www-form-urlencoded"
        if self.request.use_gzip:
            self.request.headers["Accept-Encoding"] = "gzip"
        req_path = ((parsed.path or '/') +
                (('?' + parsed.query) if parsed.query else ''))
        request_lines = [utf8("%s %s HTTP/1.1" % (self.request.method,
                                                  req_path))]
        for k, v in self.request.headers.get_all():
            line = utf8(k) + b(": ") + utf8(v)
            if b('\n') in line:
                raise ValueError('Newline in header: ' + repr(line))
            request_lines.append(line)
        self.stream.write(b("\r\n").join(request_lines) + b("\r\n\r\n"))
        if self.request.body is not None:
            self.stream.write(self.request.body)
        self.stream.read_until_regex(b("\r?\n\r?\n"), self._on_headers)

    def _release(self):
        if self.release_callback is not None:
            release_callback = self.release_callback
            self.release_callback = None
            release_callback()

    def _run_callback(self, response):
        self._release()
        if self.final_callback is not None:
            final_callback = self.final_callback
            self.final_callback = None
            final_callback(response)

    @contextlib.contextmanager
    def cleanup(self):
        try:
            yield
        except Exception as e:
            logging.warning("uncaught exception", exc_info=True)
            self._run_callback(HTTPResponse(self.request, 599, error=e, 
                                request_time=time.time() - self.start_time,
                                ))

    def _on_close(self):
        self._run_callback(HTTPResponse(
                self.request, 599,
                request_time=time.time() - self.start_time,
                error=HTTPError(599, "Connection closed")))

    def _on_headers(self, data):
        data = native_str(data.decode("latin1"))
        first_line, _, header_data = data.partition("\n")
        match = re.match("HTTP/1.[01] ([0-9]+)", first_line)
        assert match
        self.code = int(match.group(1))
        self.headers = HTTPHeaders.parse(header_data)
        if self.request.header_callback is not None:
            for k, v in self.headers.get_all():
                self.request.header_callback("%s: %s\r\n" % (k, v))
        if (self.request.use_gzip and
            self.headers.get("Content-Encoding") == "gzip"):
            # Magic parameter makes zlib module understand gzip header
            # http://stackoverflow.com/questions/1838699/how-can-i-decompress-a-gzip-stream-with-zlib
            self._decompressor = zlib.decompressobj(16+zlib.MAX_WBITS)
        if self.headers.get("Transfer-Encoding") == "chunked":
            self.chunks = []
            self.stream.read_until(b("\r\n"), self._on_chunk_length)
        elif "Content-Length" in self.headers:
            if "," in self.headers["Content-Length"]:
                # Proxies sometimes cause Content-Length headers to get
                # duplicated.  If all the values are identical then we can
                # use them but if they differ it's an error.
                pieces = re.split(r',\s*', self.headers["Content-Length"])
                if any(i != pieces[0] for i in pieces):
                    raise ValueError("Multiple unequal Content-Lengths: %r" % 
                                     self.headers["Content-Length"])
                self.headers["Content-Length"] = pieces[0]
            self.stream.read_bytes(int(self.headers["Content-Length"]),
                                   self._on_body)
        else:
            self.stream.read_until_close(self._on_body)

    def _on_body(self, data):
        if self._timeout is not None:
            self.io_loop.remove_timeout(self._timeout)
            self._timeout = None
        if self._decompressor:
            data = self._decompressor.decompress(data)
        if self.request.streaming_callback:
            if self.chunks is None:
                # if chunks is not None, we already called streaming_callback
                # in _on_chunk_data
                self.request.streaming_callback(data)
            buffer = BytesIO()
        else:
            buffer = BytesIO(data) # TODO: don't require one big string?
        original_request = getattr(self.request, "original_request",
                                   self.request)
        if (self.request.follow_redirects and
            self.request.max_redirects > 0 and
            self.code in (301, 302)):
            new_request = copy.copy(self.request)
            new_request.url = urllib.parse.urljoin(self.request.url,
                                               self.headers["Location"])
            new_request.max_redirects -= 1
            del new_request.headers["Host"]
            new_request.original_request = original_request
            final_callback = self.final_callback
            self.final_callback = None
            self._release()
            self.client.fetch(new_request, final_callback)
            self.stream.close()
            return
        response = HTTPResponse(original_request,
                                self.code, headers=self.headers,
                                request_time=time.time() - self.start_time,
                                buffer=buffer,
                                effective_url=self.request.url)
        self._run_callback(response)
        self.stream.close()

    def _on_chunk_length(self, data):
        # TODO: "chunk extensions" http://tools.ietf.org/html/rfc2616#section-3.6.1
        length = int(data.strip(), 16)
        if length == 0:
            # all the data has been decompressed, so we don't need to
            # decompress again in _on_body
            self._decompressor = None
            self._on_body(b('').join(self.chunks))
        else:
            self.stream.read_bytes(length + 2,  # chunk ends with \r\n
                              self._on_chunk_data)

    def _on_chunk_data(self, data):
        assert data[-2:] == b("\r\n")
        chunk = data[:-2]
        if self._decompressor:
            chunk = self._decompressor.decompress(chunk)
        if self.request.streaming_callback is not None:
            self.request.streaming_callback(chunk)
        else:
            self.chunks.append(chunk)
        self.stream.read_until(b("\r\n"), self._on_chunk_length)
Example #26
class TwitterStream(object):
    """ Twitter stream connection """

    _instance = None

    def __init__(self):
        """ Just set up the cache list and get first set """
        # prepopulating cache
        client = AsyncHTTPClient()
        client.fetch("http://search.twitter.com/search.json?q="+
            SETTINGS["track"], self.cache_callback)

    def cache_callback(self, response):
        """ Set up last fifty messages """
        messages = json.loads(response.body)["results"][:50]
        messages.reverse()
        for message in messages:
            try:
                text = message["text"]
                name = ""
                username = message["from_user"]
                avatar = message["profile_image_url"]
                CACHE.append({
                    "type": "tweet",
                    "text": text,
                    "name": name,
                    "username": username,
                    "avatar": avatar,
                    "time": 1
                })
            except KeyError:
                print "invalid", message
                continue
        self.open_twitter_stream()


    @classmethod
    def instance(cls):
        """ Returns the singleton """
        if not cls._instance:
            cls._instance = cls()
        return cls._instance

    def open_twitter_stream(self):
        """ Creates the client and watches stream """
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        self.twitter_stream = IOStream(sock)
        self.twitter_stream.connect(("stream.twitter.com", 80))
        import base64
        base64string = base64.encodestring("%s:%s" % (SETTINGS["username"],
            SETTINGS["password"]))[:-1]
        headers = {"Authorization": "Basic %s" % base64string,
                   "Host": "stream.twitter.com"}
        request = ["GET /1/statuses/filter.json?track=%s HTTP/1.1" %
            SETTINGS["track"]]
        for key, value in headers.iteritems():
            request.append("%s: %s" % (key, value))
        request = "\r\n".join(request) + "\r\n\r\n"
        self.twitter_stream.write(request)
        self.twitter_stream.read_until("\r\n\r\n", self.on_headers)

    def on_headers(self, response):
        """ Starts monitoring for results. """
        status_line = response.splitlines()[0]
        response_code = status_line.replace("HTTP/1.1", "")
        response_code = int(response_code.split()[0].strip())
        if response_code != 200:
            raise Exception("Twitter could not connect: %s" % status_line)
        self.wait_for_message()

    def wait_for_message(self):
        """ Throw a read event on the stack. """
        self.twitter_stream.read_until("\r\n", self.on_result)

    def on_result(self, response):
        """ Gets length of next message and reads it """
        if (response.strip() == ""):
            return self.wait_for_message()
        length = int(response.strip(), 16)
        self.twitter_stream.read_bytes(length, self.parse_json)

    def parse_json(self, response):
        """ Checks JSON message """
        if not response.strip():
            # Empty line, happens sometimes for keep alive
            return self.wait_for_message()
        try:
            response = json.loads(response)
        except ValueError:
            print "Invalid response:"
            print response
            return self.wait_for_message()

        self.parse_response(response)

    def parse_response(self, response):
        """ Parse the twitter message """
        try:
            text = response["text"]
            name = response["user"]["name"]
            username = response["user"]["screen_name"]
            avatar = response["user"]["profile_image_url_https"]
        except KeyError, exc:
            print "Invalid tweet structure, missing %s" % exc
            return self.wait_for_message()

        message = {
            "type": "tweet",
            "text": text,
            "avatar": avatar,
            "name": name,
            "username": username,
            "time": int(time.time())
        }

        broadcast_message(message)
        self.wait_for_message()
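
open_twitter_stream above builds its Authorization header by base64-encoding
"username:password" with the Python 2-only base64.encodestring, stripping the
trailing newline it appends. A hedged, Python-3-friendly equivalent
(basic_auth_header is a hypothetical helper):

import base64

def basic_auth_header(username, password):
    # Same construction as above; b64encode adds no trailing newline,
    # so there is nothing to strip.
    token = base64.b64encode(("%s:%s" % (username, password)).encode("utf8"))
    return "Basic " + token.decode("ascii")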
Example #27
File: client.py Project: gnap/stpclient-py
class Connection(object):
    # Constants for connection state
    _CLOSED = 0x001
    _CONNECTING = 0x002
    _STREAMING = 0x004

    '''
    timeout -1: no timeout; None: per-request setting; other: overrides the per-request setting
    '''
    def __init__(self, io_loop, client, timeout=-1, connect_timeout=-1, max_buffer_size=104857600):
        self.io_loop = io_loop
        self.client = client
        self.timeout = timeout
        self.connect_timeout = connect_timeout
        self.start_time = time.time()
        self.stream = None
        self._timeoutevent = None
        self._callback = None
        self._request_queue = collections.deque()
        self._request = None
        self._response = STPResponse()
        self._state = Connection._CLOSED

    @property
    def closed(self):
        return self._state == Connection._CLOSED

    def close(self):
        if self.stream is not None and not self.stream.closed():
            self.stream.close()
        self.stream = None

    def _connect(self):
        self._state = Connection._CONNECTING
        af = socket.AF_INET if self.client.unix_socket is None else socket.AF_UNIX
        self.stream = IOStream(socket.socket(af, socket.SOCK_STREAM),
                                io_loop=self.io_loop,
                                max_buffer_size=self.client.max_buffer_size)
        if self.connect_timeout is not None and self.connect_timeout > 0:
            self._timeoutevent = self.io_loop.add_timeout(time.time() + self.connect_timeout,
                                                            self._on_timeout)
        self.stream.set_close_callback(self._on_close)
        addr = self.client.unix_socket if self.client.unix_socket is not None else (self.client.host, self.client.port)
        self.stream.connect(addr, self._on_connect)

    def _on_connect(self):
        if self._timeoutevent is not None:
            self.io_loop.remove_timeout(self._timeoutevent)
            self._timeoutevent = None
        self._state = Connection._STREAMING
        self._send_request()

    def _on_timeout(self):
        self._timeoutevent = None
        self._run_callback(STPResponse(request_time=time.time() - self.start_time,
                error=exceptions.STPTimeoutError('Timeout')))
        if self.stream is not None:
            self.stream.close()
        self.stream = None
        self._state = Connection._CLOSED
        self._request = None
        if len(self._request_queue) > 0:
            self._connect_and_send_request()

    def _on_close(self):
        self._run_callback(STPResponse(request_time=time.time() - self.start_time,
                error=exceptions.STPNetworkError('Connection error')))
        self._state = Connection._CLOSED
        self._request = None
        if len(self._request_queue) > 0:
            self._connect_and_send_request()

    def send_request(self, request, callback):
        self._request_queue.append((request, callback))
        self._connect_and_send_request()

    def _connect_and_send_request(self):
        if len(self._request_queue) > 0 and self._request is None:
            self._request, self._callback = self._request_queue.popleft()
            if self.stream is None or self._state == Connection._CLOSED:
                self._connect()
            elif self._state == Connection._STREAMING:
                self._send_request()

    def _send_request(self):
        def write_callback():
            '''tornado needs it'''
            pass
        timeout = self.timeout
        if self._request.request_timeout is not None:
            timeout = self._request.request_timeout
        if timeout is not None and timeout > 0:
            self._timeoutevent = self.io_loop.add_timeout(time.time() + timeout, self._on_timeout)
        self.start_time = time.time()
        self.stream.write(self._request.serialize(), write_callback)
        self._read_arg()

    def _run_callback(self, response):
        if self._callback is not None:
            callback = self._callback
            self._callback = None
            callback(response)

    def _read_arg(self):
        self.stream.read_until(b'\r\n', self._on_arglen)

    def _on_arglen(self, data):
        if data == b'\r\n':
            response = self._response
            self._response = STPResponse()
            response.request_time = time.time() - self.start_time
            if self._timeoutevent is not None:
                self.io_loop.remove_timeout(self._timeoutevent)
            self._run_callback(response)
            self._request = None
            self._connect_and_send_request()
        else:
            try:
                arglen = int(data[:-2])
                self.stream.read_bytes(arglen, self._on_arg)
            except Exception as e:
                self._run_callback(STPResponse(request_time=time.time() - self.start_time,
                        error=exceptions.STPProtocolError(str(e))))

    def _on_arg(self, data):
        self._response._argv.append(data)
        self.stream.read_until(b'\r\n', self._on_strip_arg_eol)

    def _on_strip_arg_eol(self, data):
        self._read_arg()
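
_on_arglen and _on_arg above imply a simple framing: each argument is a
decimal length line followed by that many bytes and a CRLF, and a bare CRLF
terminates the response. That reading of the protocol, sketched over a
complete buffer (parse_stp_args is a hypothetical helper, and the framing is
inferred from the reader code rather than from a protocol spec):

def parse_stp_args(buf):
    # "<decimal length>\r\n<bytes>\r\n" per argument; a bare "\r\n" ends
    # the response (the branch _on_arglen takes for an empty line).
    argv = []
    while True:
        line, _, buf = buf.partition(b"\r\n")
        if line == b"":
            return argv
        n = int(line)
        argv.append(buf[:n])
        buf = buf[n + 2:]  # skip the argument's trailing CRLF

# parse_stp_args(b"3\r\nfoo\r\n6\r\nbarbaz\r\n\r\n") == [b"foo", b"barbaz"]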
Example #28
class Connection(object):
    def __init__(self, host, port, event_handler,
                 stop_after=None, io_loop=None):
        self.host = host
        self.port = port
        self._event_handler = weakref.proxy(event_handler)
        self.timeout = stop_after
        self._stream = None
        self._io_loop = io_loop
        self.try_left = 2

        self.in_progress = False
        self.read_callbacks = []
        self.ready_callbacks = deque()

    def __del__(self):
        self.disconnect()

    def __enter__(self):
        return self

    def __exit__(self, *args, **kwargs):
        if self.ready_callbacks:
            # Pop a SINGLE callback from the queue and execute it.
            # The next one will be executed from the code
            # invoked by the callback
            callback = self.ready_callbacks.popleft()
            callback()

    def ready(self):
        return not self.read_callbacks and not self.ready_callbacks

    def wait_until_ready(self, callback=None):
        if callback:
            if not self.ready():
                self.ready_callbacks.append(callback)
            else:
                callback()
        return self

    def connect(self):
        if not self._stream:
            try:
                sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
                sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
                sock.settimeout(self.timeout)
                sock.connect((self.host, self.port))
                self._stream = IOStream(sock, io_loop=self._io_loop)
                self._stream.set_close_callback(self.on_stream_close)
                self.connected()
            except socket.error as e:
                raise ConnectionError(str(e))
            self.fire_event('on_connect')

    def on_stream_close(self):
        if self._stream:
            self._stream = None
            callbacks = self.read_callbacks
            self.read_callbacks = []
            for callback in callbacks:
                callback(None)

    def disconnect(self):
        if self._stream:
            s = self._stream
            self._stream = None
            try:
                s.socket.shutdown(socket.SHUT_RDWR)
                s.close()
            except socket.error:
                pass

    def fire_event(self, event):
        if self._event_handler:
            try:
                getattr(self._event_handler, event)()
            except AttributeError:
                pass

    def write(self, data, try_left=None):
        if try_left is None:
            try_left = self.try_left
        if not self._stream:
            self.connect()
            if not self._stream:
                raise ConnectionError('Tried to write to '
                                      'non-existent connection')

        if try_left > 0:
            try:
                self._stream.write(data)
            except IOError:
                self.disconnect()
                self.write(data, try_left - 1)
        else:
            raise ConnectionError('Tried to write to non-existent connection')

    def read(self, length, callback=None):
        try:
            if not self._stream:
                self.disconnect()
                raise ConnectionError('Tried to read from '
                                      'non-existent connection')
            self.read_callbacks.append(callback)
            self._stream.read_bytes(length,
                                    callback=partial(self.read_callback,
                                                     callback))
        except IOError:
            self.fire_event('on_disconnect')

    def read_callback(self, callback, *args, **kwargs):
        self.read_callbacks.remove(callback)
        callback(*args, **kwargs)

    def readline(self, callback=None):
        try:
            if not self._stream:
                self.disconnect()
                raise ConnectionError('Tried to read from '
                                      'non-existent connection')
            self.read_callbacks.append(callback)
            self._stream.read_until('\r\n',
                                    callback=partial(self.read_callback,
                                                     callback))
        except IOError:
            self.fire_event('on_disconnect')

    def connected(self):
        if self._stream:
            return True
        return False
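
The Connection above holds its event handler through weakref.proxy, so the
connection never keeps the handler alive by itself; once the handler is
garbage-collected, touching the proxy raises ReferenceError. In isolation
(Handler is a hypothetical class for illustration):

import weakref

class Handler(object):
    def on_connect(self):
        print("connected")

handler = Handler()
proxy = weakref.proxy(handler)
proxy.on_connect()   # forwards to the live handler
del handler
# proxy.on_connect() would now raise ReferenceError, so the real handler
# must be kept alive elsewhere for events to keep firing.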
Example #29
class KafkaTornado(BaseKafka):
    def __init__(self, *args, **kwargs):
        if 'io_loop' in kwargs:
            self._io_loop = kwargs['io_loop']
            del kwargs['io_loop']
        else:
            self._io_loop = None
        BaseKafka.__init__(self, *args, **kwargs)

        self._stream = None

    # Socket management methods

    def _connect(self):
        """ Connect to the Kafka server. """

        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)

        try:
            sock.connect((self.host, self.port))
        except Exception:
            raise ConnectionFailure("Could not connect to kafka at {0}:{1}".format(self.host, self.port))
        else:
            self._stream = IOStream(sock, io_loop=self._io_loop)

    def _disconnect(self):
        """ Disconnect from the remote server & close the socket. """
        try:
            self._stream.close()
        except IOError:
            pass
        finally:
            self._stream = None

    def _read(self, length, callback=None):
        """ Send a read request to the remote Kafka server. """

        if callback is None:
            callback = lambda v: v

        if not self._stream:
            self._connect()

        return self._stream.read_bytes(length, callback)

    def _write(self, data, callback=None, retries=BaseKafka.MAX_RETRY):
        """ Write `data` to the remote Kafka server. """

        if callback is None:
            callback = lambda: None

        if not self._stream:
            self._connect()

        try:
            return self._stream.write(data, callback)
        except IOError:
            if retries > 0:
                self._stream = None
                retries_left = retries - 1
                socket_log.warn('Write failure, retrying ({0} retries left)'.format(retries_left))
                return self._write(data, callback, retries_left)
            else:
                raise
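
_write above retries a failed write by dropping the stream (forcing a
reconnect on the next attempt) and decrementing a retry budget. The same
pattern as a standalone loop (write_with_retries, send and reconnect are
hypothetical callables):

def write_with_retries(send, reconnect, data, retries=3):
    # On IOError: reconnect and retry until the budget runs out, then let
    # the last error propagate, as KafkaTornado._write does.
    while True:
        try:
            return send(data)
        except IOError:
            if retries <= 0:
                raise
            retries -= 1
            reconnect()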
Example #30
class Client(object):

    def __init__(self, host='localhost', port=11300,
                 connect_timeout=socket.getdefaulttimeout(), io_loop=None):
        self._connect_timeout = connect_timeout
        self.host = host
        self.port = port
        self.io_loop = io_loop or IOLoop.instance()
        self._stream = None
        self._using = 'default'  # current tube
        self._watching = set(['default'])   # set of watched tubes
        self._queue = deque()
        self._talking = False
        self._reconnect_cb = None

    def _reconnect(self):
        # wait some time before trying to re-connect
        self.io_loop.add_timeout(time.time() + RECONNECT_TIMEOUT,
                lambda: self.connect(self._reconnected))

    def _reconnected(self):
        # re-establish the used tube and tubes being watched
        watch = self._watching.difference(['default'])
        # ignore "default", if it is not in the client's watch list
        ignore = set(['default']).difference(self._watching)

        def do_next(_=None):
            try:
                if watch:
                    self.watch(watch.pop(), do_next)
                elif ignore:
                    self.ignore(ignore.pop(), do_next)
                elif self._using != 'default':
                    # change the tube used, and callback to user
                    self.use(self._using, self._reconnect_cb)
                elif self._reconnect_cb:
                    # callback to user
                    self._reconnect_cb()
            except:
                # ignored, as next re-connect will retry the operation
                pass

        do_next()

    @coroutine
    def connect(self):
        """Connect to beanstalkd server."""
        if not self.closed():
            return
        self._talking = False
        self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM,
                socket.IPPROTO_TCP)
        self._stream = IOStream(self._socket, io_loop=self.io_loop)
        self._stream.set_close_callback(self._reconnect)
        yield Task(self._stream.connect, (self.host, self.port))

    def set_reconnect_callback(self, callback):
        """Set callback to be called if connection has been lost and
        re-established again.

        If the connection is closed unexpectedly, the client will automatically
        attempt to re-connect with 1 second intervals. After re-connecting, the
        client will attempt to re-establish the used tube and watched tubes.
        """
        self._reconnect_cb = callback

    @coroutine
    def close(self):
        """Close connection to server."""
        key = object()
        if self._stream:
            self._stream.set_close_callback((yield Callback(key)))
        if not self.closed():
            yield Task(self._stream.write, b'quit\r\n')
            self._stream.close()
            yield Wait(key)

    def closed(self):
        """"Returns True if the connection is closed."""
        return not self._stream or self._stream.closed()

    def _interact(self, request, callback):
        # put the interaction request into the FIFO queue
        cb = stack_context.wrap(callback)
        self._queue.append((request, cb))
        self._process_queue()

    def _process_queue(self):
        if self._talking or not self._queue:
            return
        # pop a request off the queue and perform the send-receive interaction
        self._talking = True
        with stack_context.NullContext():
            req, cb = self._queue.popleft()
            command = req.cmd + b'\r\n'
            if req.body:
                command += req.body + b'\r\n'

            # write command and body to socket stream
            self._stream.write(command,
                    # when command is written: read line from socket stream
                    lambda: self._stream.read_until(b'\r\n',
                    # when a line has been read: return status and results
                    lambda data: self._recv(req, data, cb)))

    def _recv(self, req, data, cb):
        # parse the data received as server response
        spl = data.decode('utf8').split()
        status, values = spl[0], spl[1:]

        error = None
        err_args = ObjectDict(request=req, status=status, values=values)

        if req.ok and status in req.ok:
            # avoid raising a Buried exception when using the bury command
            pass
        elif status == 'BURIED':
            error = Buried(**err_args)
        elif status == 'TIMED_OUT':
            error = TimedOut(**err_args)
        elif status == 'DEADLINE_SOON':
            error = DeadlineSoon(**err_args)
        elif req.err and status in req.err:
            error = CommandFailed(**err_args)
        else:
            error = UnexpectedResponse(**err_args)

        resp = Bunch(req=req, status=status, values=values, error=error)

        if error or not req.read_body:
            # end the request and callback with results
            self._do_callback(cb, resp)
        else:
            # read the body including the terminating two bytes of crlf
            if len(values) == 2:
                job_id, size = int(values[0]), int(values[1])
                resp.job_id = int(job_id)
            else:
                size = int(values[0])
            self._stream.read_bytes(size + 2,
                    lambda data: self._recv_body(data[:-2], resp, cb))

    def _recv_body(self, data, resp, cb):
        if resp.req.parse_yaml:
            # parse the yaml encoded body
            self._parse_yaml(data, resp, cb)
        else:
            # don't parse body, it is a job!
            # end the request and callback with results
            resp.body = ObjectDict(id=resp.job_id, body=data)
            self._do_callback(cb, resp)

    def _parse_yaml(self, data, resp, cb):
        # dirty parsing of yaml data
        # (assumes that data is a yaml encoded list or dict)
        spl = data.decode('utf8').split('\n')[1:-1]
        if spl[0].startswith('- '):
            # it is a list
            resp.body = [s[2:] for s in spl]
        else:
            # it is a dict
            conv = lambda v: ((float(v) if '.' in v else int(v))
                if v.replace('.', '', 1).isdigit() else v)
            resp.body = ObjectDict((k, conv(v.strip())) for k, v in
                    (s.split(':') for s in spl))
        self._do_callback(cb, resp)

    def _do_callback(self, cb, resp):
        # end the request and process next item in the queue
        # and callback with results
        self._talking = False
        self.io_loop.add_callback(self._process_queue)

        if not cb:
            return

        # default is to callback with error state (None or exception)
        obj = None
        req = resp.req

        if resp.error:
            obj = resp.error

        elif req.read_value:
            # callback with an integer value or a string
            if resp.values[0].isdigit():
                obj = int(resp.values[0])
            else:
                obj = resp.values[0]

        elif req.read_body:
            # callback with the body (job or parsed yaml)
            obj = resp.body

        self.io_loop.add_callback(lambda: cb(obj))

    #
    #  Producer commands
    #

    @coroutine
    def put(self, body, priority=DEFAULT_PRIORITY, delay=0, ttr=120):
        """Put a job body (a byte string) into the current tube.

        The job can be delayed a number of seconds before it is put in the
        ready queue; the default is no delay.

        The job is assigned a Time To Run (ttr, in seconds); the minimum is
        1 sec. and the default is ttr=120 sec.

        Calls back with the job id when the job is inserted. If an error
        occurred, the callback gets a Buried or CommandFailed exception. The
        job is buried either when the body is too big (the server ran out of
        memory) or when the server is in draining mode.
        """
        assert isinstance(body, bytes)
        cmd = 'put {} {} {} {}'.format(priority, delay, ttr,
            len(body)).encode('utf8')
        request = Bunch(cmd=cmd, ok=['INSERTED'], err=['BURIED', 'JOB_TOO_BIG',
                'DRAINING'], body=body, read_value=True)
        resp = yield Task(self._interact, request)
        raise Return(resp)

    @coroutine
    def use(self, name):
        """Use the tube with given name.

        Calls back with the name of the tube now being used.
        """
        cmd = 'use {}'.format(name).encode('utf8')
        request = Bunch(cmd=cmd, ok=['USING'],
                read_value=True)
        resp = yield Task(self._interact, request)
        if not isinstance(resp, Exception):
            self._using = resp
        raise Return(resp)

    #
    #  Worker commands
    #

    @coroutine
    def reserve(self, timeout=None):
        """Reserve a job from one of the watched tubes, with optional timeout
        in seconds.

        Not specifying a timeout (timeout=None, the default) will make the
        client put the communication with beanstalkd on hold, until either a
        job is reserved, or an already reserved job is approaching its TTR
        deadline. Commands issued while waiting for the "reserve" callback
        will be queued and sent in FIFO order when communication is resumed.

        A timeout value of 0 will cause the server to immediately return
        either a response or TIMED_OUT. A positive timeout will limit the
        amount of time the client holds communication open while waiting for
        a job to become available.

        Calls back with a job dict (keys id and body). If the request timed
        out, the callback gets a TimedOut exception. If a reserved job has a
        deadline within the next second, the callback gets a DeadlineSoon
        exception.
        """
        if timeout is not None:
            cmd = 'reserve-with-timeout {}'.format(timeout).encode('utf8')
        else:
            cmd = b'reserve'
        request = Bunch(cmd=cmd, ok=['RESERVED'], err=['DEADLINE_SOON',
                'TIMED_OUT'], read_body=True)
        resp = yield Task(self._interact, request)
        raise Return(resp)

    @coroutine
    def delete(self, job_id):
        """Delete job with given id.

        Calls back when job is deleted. If the job does not exist, or it is
        neither reserved by the client, ready, nor buried, the callback gets
        a CommandFailed exception.
        """
        cmd = 'delete {}'.format(job_id).encode('utf8')
        request = Bunch(cmd=cmd, ok=['DELETED'], err=['NOT_FOUND'])
        resp = yield Task(self._interact, request)
        raise Return(resp)

    @coroutine
    def release(self, job_id, priority=DEFAULT_PRIORITY, delay=0):
        """Release a reserved job back into the ready queue.

        A new priority can be assigned to the job.

        It is also possible to specify a delay (in seconds) to wait before
        putting the job in the ready queue. The job will be in the "delayed"
        state during this time.

        Calls back when job is released. If the job was buried, the callback
        gets a Buried exception. If the job does not exist, or it is not
        reserved by the client, the callback gets a CommandFailed exception.
        """
        cmd = 'release {} {} {}'.format(job_id, priority, delay).encode('utf8')
        request = Bunch(cmd=cmd, ok=['RELEASED'], err=['BURIED', 'NOT_FOUND'])
        resp = yield Task(self._interact, request)
        raise Return(resp)

    @coroutine
    def bury(self, job_id, priority=DEFAULT_PRIORITY):
        """Bury job with given id.

        A new priority can be assigned to the job.

        Calls back when job is buried. If the job does not exist, or it is not
        reserved by the client, the callback gets a CommandFailed exception.
        """
        cmd = 'bury {} {}'.format(job_id, priority).encode('utf8')
        request = Bunch(cmd=cmd, ok=['BURIED'], err=['NOT_FOUND'])
        resp = yield Task(self._interact, request)
        raise Return(resp)

    @coroutine
    def touch(self, job_id):
        """Touch job with given id.

        This is for requesting more time to work on a reserved job before it
        expires.

        Calls back when job is touched. If the job does not exist, or it is not
        reserved by the client, the callback gets a CommandFailed exception.
        """
        cmd = 'touch {}'.format(job_id).encode('utf8')
        request = Bunch(cmd=cmd, ok=['TOUCHED'], err=['NOT_FOUND'])
        resp = yield Task(self._interact, request)
        raise Return(resp)

    @coroutine
    def watch(self, name):
        """Watch tube with given name.

        Calls back with number of tubes currently in the watch list.
        """
        cmd = 'watch {}'.format(name).encode('utf8')
        request = Bunch(cmd=cmd, ok=['WATCHING'], read_value=True)
        resp = yield Task(self._interact, request)
        if not isinstance(resp, Exception):
            # add to the client's watch list
            self._watching.add(name)
        raise Return(resp)

    @coroutine
    def ignore(self, name):
        """Stop watching tube with given name.

        Calls back with the number of tubes currently in the watch list. On an
        attempt to ignore the only tube in the watch list, the callback gets a
        CommandFailed exception.
        """
        cmd = 'ignore {}'.format(name).encode('utf8')
        request = Bunch(cmd=cmd, ok=['WATCHING'], err=['NOT_IGNORED'],
                read_value=True)
        resp = yield Task(self._interact, request)
        if name in self._watching:
            # remove from the client's watch list
            self._watching.remove(name)
        raise Return(resp)

    #
    #  Other commands
    #

    def _peek(self, variant, callback):
        # a shared gateway for the peek* commands
        cmd = 'peek{}'.format(variant).encode('utf8')
        request = Bunch(cmd=cmd, ok=['FOUND'], err=['NOT_FOUND'],
                read_body=True)
        self._interact(request, callback)

    @coroutine
    def peek(self, job_id):
        """Peek at job with given id.

        Calls back with a job dict (keys id and body). If no job exists with
        that id, the callback gets a CommandFailed exception.
        """
        resp = yield Task(self._peek, ' {}'.format(job_id))
        raise Return(resp)

    @coroutine
    def peek_ready(self):
        """Peek at next ready job in the current tube.

        Calls back with a job dict (keys id and body). If no ready jobs exist,
        the callback gets a CommandFailed exception.
        """
        resp = yield Task(self._peek, '-ready')
        raise Return(resp)

    @coroutine
    def peek_delayed(self):
        """Peek at next delayed job in the current tube.

        Calls back with a job dict (keys id and body). If no delayed jobs exist,
        the callback gets a CommandFailed exception.
        """
        resp = yield Task(self._peek, '-delayed')
        raise Return(resp)

    @coroutine
    def peek_buried(self):
        """Peek at next buried job in the current tube.

        Calls back with a job dict (keys id and body). If no buried jobs exist,
        the callback gets a CommandFailed exception.
        """
        resp = yield Task(self._peek, '-buried')
        raise Return(resp)

    @coroutine
    def kick(self, bound=1):
        """Kick at most `bound` jobs into the ready queue from the current tube.

        Calls back with the number of jobs actually kicked.
        """
        cmd = 'kick {}'.format(bound).encode('utf8')
        request = Bunch(cmd=cmd, ok=['KICKED'], read_value=True)
        resp = yield Task(self._interact, request)
        raise Return(resp)

    @coroutine
    def kick_job(self, job_id):
        """Kick job with given id into the ready queue.
        (Requires Beanstalkd version >= 1.8)

        Calls back when job is kicked. If no job exists with that id, or if
        job is not in a kickable state, the callback gets a CommandFailed
        exception.
        """
        cmd = 'kick-job {}'.format(job_id).encode('utf8')
        request = Bunch(cmd=cmd, ok=['KICKED'], err=['NOT_FOUND'])
        resp = yield Task(self._interact, request)
        raise Return(resp)

    @coroutine
    def stats_job(self, job_id):
        """A dict of stats about the job with given id.

        If no job exists with that id, the callback gets a CommandFailed
        exception.
        """
        cmd = 'stats-job {}'.format(job_id).encode('utf8')
        request = Bunch(cmd=cmd, ok=['OK'], err=['NOT_FOUND'], read_body=True,
                parse_yaml=True)
        resp = yield Task(self._interact, request)
        raise Return(resp)

    @coroutine
    def stats_tube(self, name):
        """A dict of stats about the tube with given name.

        If no tube exists with that name, the callback gets a CommandFailed
        exception.
        """
        cmd = 'stats-tube {}'.format(name).encode('utf8')
        request = Bunch(cmd=cmd, ok=['OK'], err=['NOT_FOUND'], read_body=True,
                parse_yaml=True)
        resp = yield Task(self._interact, request)
        raise Return(resp)

    @coroutine
    def stats(self):
        """A dict of beanstalkd statistics."""
        request = Bunch(cmd=b'stats', ok=['OK'], read_body=True,
                parse_yaml=True)
        resp = yield Task(self._interact, request)
        raise Return(resp)

    @coroutine
    def list_tubes(self):
        """List of all existing tubes."""
        request = Bunch(cmd=b'list-tubes', ok=['OK'], read_body=True,
                parse_yaml=True)
        resp = yield Task(self._interact, request)
        raise Return(resp)

    @coroutine
    def list_tube_used(self):
        """Name of the tube currently being used."""
        request = Bunch(cmd=b'list-tube-used', ok=['USING'], read_value=True)
        resp = yield Task(self._interact, request)
        raise Return(resp)

    @coroutine
    def list_tubes_watched(self):
        """List of tubes currently being watched."""
        request = Bunch(cmd=b'list-tubes-watched', ok=['OK'], read_body=True,
                parse_yaml=True)
        resp = yield Task(self._interact, request)
        raise Return(resp)

    @coroutine
    def pause_tube(self, name, delay):
        """Delay any new job being reserved from the tube for a given time.

        The delay is an integer number of seconds to wait before reserving any
        more jobs from the queue.

        Calls back when the tube is paused. If the tube does not exist, the
        callback gets a CommandFailed exception.
        """
        cmd = 'pause-tube {} {}'.format(name, delay).encode('utf8')
        request = Bunch(cmd=cmd, ok=['PAUSED'], err=['NOT_FOUND'])
        resp = yield Task(self._interact, request)
        raise Return(resp)
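
# Hedged usage sketch (added; assumes the class above is importable as
# `Client` with a (host, port) constructor and that beanstalkd listens on
# localhost:11300 -- neither is shown in this excerpt):
#
#     from tornado import gen
#     from tornado.ioloop import IOLoop
#
#     @gen.coroutine
#     def demo():
#         client = Client('127.0.0.1', 11300)
#         job_id = yield client.put(b'hello')    # -> integer job id
#         job = yield client.reserve(timeout=0)  # -> dict with id and body,
#                                                #    or a TimedOut instance
#         yield client.delete(job.id)
#
#     IOLoop.current().run_sync(demo)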
Example #31
0
class TestIOStream(AsyncHTTPTestCase, LogTrapTestCase):
    def get_app(self):
        return Application([('/', HelloHandler)])

    def make_iostream_pair(self, **kwargs):
        port = get_unused_port()
        [listener] = netutil.bind_sockets(port,
                                          '127.0.0.1',
                                          family=socket.AF_INET)
        streams = [None, None]

        def accept_callback(connection, address):
            streams[0] = IOStream(connection, io_loop=self.io_loop, **kwargs)
            self.stop()

        def connect_callback():
            streams[1] = client_stream
            self.stop()

        netutil.add_accept_handler(listener,
                                   accept_callback,
                                   io_loop=self.io_loop)
        client_stream = IOStream(socket.socket(),
                                 io_loop=self.io_loop,
                                 **kwargs)
        client_stream.connect(('127.0.0.1', port), callback=connect_callback)
        self.wait(condition=lambda: all(streams))
        self.io_loop.remove_handler(listener.fileno())
        listener.close()
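        # streams now holds [server-side stream, client-side stream]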
        return streams

    def test_read_zero_bytes(self):
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        s.connect(("localhost", self.get_http_port()))
        self.stream = IOStream(s, io_loop=self.io_loop)
        self.stream.write(b("GET / HTTP/1.0\r\n\r\n"))

        # normal read
        self.stream.read_bytes(9, self.stop)
        data = self.wait()
        self.assertEqual(data, b("HTTP/1.0 "))

        # zero bytes
        self.stream.read_bytes(0, self.stop)
        data = self.wait()
        self.assertEqual(data, b(""))

        # another normal read
        self.stream.read_bytes(3, self.stop)
        data = self.wait()
        self.assertEqual(data, b("200"))

    def test_connection_refused(self):
        # When a connection is refused, the connect callback should not
        # be run.  (The kqueue IOLoop used to behave differently from the
        # epoll IOLoop in this respect)
        port = get_unused_port()
        stream = IOStream(socket.socket(), self.io_loop)
        self.connect_called = False

        def connect_callback():
            self.connect_called = True

        stream.set_close_callback(self.stop)
        stream.connect(("localhost", port), connect_callback)
        self.wait()
        self.assertFalse(self.connect_called)

    def test_connection_closed(self):
        # When a server sends a response and then closes the connection,
        # the client must be allowed to read the data before the IOStream
        # closes itself.  Epoll reports closed connections with a separate
        # EPOLLRDHUP event delivered at the same time as the read event,
        # while kqueue reports them as a second read/write event with an EOF
        # flag.
        response = self.fetch("/", headers={"Connection": "close"})
        response.rethrow()

    def test_read_until_close(self):
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        s.connect(("localhost", self.get_http_port()))
        stream = IOStream(s, io_loop=self.io_loop)
        stream.write(b("GET / HTTP/1.0\r\n\r\n"))

        stream.read_until_close(self.stop)
        data = self.wait()
        self.assertTrue(data.startswith(b("HTTP/1.0 200")))
        self.assertTrue(data.endswith(b("Hello")))

    def test_streaming_callback(self):
        server, client = self.make_iostream_pair()
        try:
            chunks = []
            final_called = []

            def streaming_callback(data):
                chunks.append(data)
                self.stop()

            def final_callback(data):
                assert not data
                final_called.append(True)
                self.stop()

            server.read_bytes(6,
                              callback=final_callback,
                              streaming_callback=streaming_callback)
            client.write(b("1234"))
            self.wait(condition=lambda: chunks)
            client.write(b("5678"))
            self.wait(condition=lambda: final_called)
            self.assertEqual(chunks, [b("1234"), b("56")])

            # the rest of the last chunk is still in the buffer
            server.read_bytes(2, callback=self.stop)
            data = self.wait()
            self.assertEqual(data, b("78"))
        finally:
            server.close()
            client.close()

    def test_streaming_until_close(self):
        server, client = self.make_iostream_pair()
        try:
            chunks = []

            def callback(data):
                chunks.append(data)
                self.stop()

            client.read_until_close(callback=callback,
                                    streaming_callback=callback)
            server.write(b("1234"))
            self.wait()
            server.write(b("5678"))
            self.wait()
            server.close()
            self.wait()
            self.assertEqual(chunks, [b("1234"), b("5678"), b("")])
        finally:
            server.close()
            client.close()

    def test_delayed_close_callback(self):
        # The scenario:  Server closes the connection while there is a pending
        # read that can be served out of buffered data.  The client does not
        # run the close_callback as soon as it detects the close, but rather
        # defers it until after the buffered read has finished.
        server, client = self.make_iostream_pair()
        try:
            client.set_close_callback(self.stop)
            server.write(b("12"))
            chunks = []

            def callback1(data):
                chunks.append(data)
                client.read_bytes(1, callback2)
                server.close()

            def callback2(data):
                chunks.append(data)

            client.read_bytes(1, callback1)
            self.wait()  # stopped by close_callback
            self.assertEqual(chunks, [b("1"), b("2")])
        finally:
            server.close()
            client.close()

    def test_close_buffered_data(self):
        # Similar to the previous test, but with data stored in the OS's
        # socket buffers instead of the IOStream's read buffer.  Out-of-band
        # close notifications must be delayed until all data has been
        # drained into the IOStream buffer. (epoll used to use out-of-band
        # close events with EPOLLRDHUP, but no longer)
        #
        # This depends on the read_chunk_size being smaller than the
        # OS socket buffer, so make it small.
        server, client = self.make_iostream_pair(read_chunk_size=256)
        try:
            server.write(b("A") * 512)
            client.read_bytes(256, self.stop)
            data = self.wait()
            self.assertEqual(b("A") * 256, data)
            server.close()
            # Allow the close to propagate to the client side of the
            # connection.  Using add_callback instead of add_timeout
            # doesn't seem to work, even with multiple iterations
            self.io_loop.add_timeout(time.time() + 0.01, self.stop)
            self.wait()
            client.read_bytes(256, self.stop)
            data = self.wait()
            self.assertEqual(b("A") * 256, data)
        finally:
            server.close()
            client.close()
Example #32
0
File: green.py Project: alex8224/gTornado
class AsyncSocket(object):
    def __init__(self, sock):
        self._iostream = IOStream(sock)
        self._resolver = Resolver()
        self._readtimeout = 0
        self._connecttimeout = 0
   
    def set_readtimeout(self, timeout):
        self._readtimeout = timeout

    def set_connecttimeout(self, timeout):
        self._connecttimeout = timeout

    @synclize
    def connect(self, address):
        host, port = address
        timer = None
        try:
            if self._connecttimeout:
                timer = Timeout(self._connecttimeout)
                timer.start()
            resolved_addrs = yield self._resolver.resolve(host, port, family=socket.AF_INET)
            for addr in resolved_addrs:
                family, host_port = addr
                yield self._iostream.connect(host_port)
                break
        except TimeoutException:
            self.close()
            raise
        finally:
            if timer:
                timer.cancel()

    # @synclize
    def sendall(self, buff):
        self._iostream.write(buff)

    @synclize
    def read(self, nbytes, partial=False):
        timer = None
        try:
            if self._readtimeout:
                timer = Timeout(self._readtimeout)
                timer.start()
            buff = yield self._iostream.read_bytes(nbytes, partial=partial)
            raise Return(buff)
        except TimeoutException:
            self.close()
            raise
        finally:
            if timer:
                timer.cancel()

    def recv(self, nbytes):
        return self.read(nbytes, partial=True)

    @synclize
    def readline(self, max_bytes=-1):
        timer = None
        if self._readtimeout:
            timer = Timeout(self._readtimeout)
            timer.start()
        try:
            if max_bytes > 0:
                buff = yield self._iostream.read_until('\n', max_bytes=max_bytes)
            else:
                buff = yield self._iostream.read_until('\n')
            raise Return(buff)
        except TimeoutException:
            self.close()
            raise
        finally:
            if timer:
                timer.cancel()

    def close(self):
        self._iostream.close()

    def set_nodelay(self, flag):
        self._iostream.set_nodelay(flag)

    def settimeout(self, timeout):
        pass

    def shutdown(self, direction):
        if self._iostream.fileno():
            self._iostream.fileno().shutdown(direction)

    def recv_into(self, buff):
        expected_rbytes = len(buff)
        data = self.read(expected_rbytes, True)
        srcarray = bytearray(data)
        nbytes = len(srcarray)
        buff[0:nbytes] = srcarray
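        # Worked example (added): for buff = bytearray(4) and a peer that
        # sent b'ab', read() returns b'ab', buff becomes
        # bytearray(b'ab\x00\x00') and recv_into returns 2.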
        return nbytes

    def makefile(self, mode, other):
        return self
Example #33
0
class Connection(object):
    def __init__(self,
                 host,
                 port,
                 on_connect,
                 on_disconnect,
                 timeout=None,
                 io_loop=None):
        self.host = host
        self.port = port
        self.on_connect = on_connect
        self.on_disconnect = on_disconnect
        self.timeout = timeout
        self._stream = None
        self._io_loop = io_loop
        self.try_left = 2

        self.in_progress = False
        self.read_queue = []

    def connect(self):
        try:
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
            sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
            sock.settimeout(self.timeout)
            sock.connect((self.host, self.port))
            self._stream = IOStream(sock, io_loop=self._io_loop)
            self.connected()
        except socket.error as e:
            raise ConnectionError(str(e))
        self.on_connect()

    def disconnect(self):
        if self._stream:
            try:
                self._stream.close()
            except socket.error:
                pass
            self._stream = None

    def write(self, data, try_left=None):
        if try_left is None:
            try_left = self.try_left
        if not self._stream:
            self.connect()
            if not self._stream:
                raise ConnectionError(
                    'Tried to write to non-existent connection')

        if try_left > 0:
            try:
                self._stream.write(data)
            except IOError:
                self.disconnect()
                self.write(data, try_left - 1)
        else:
            raise ConnectionError('Tried to write to non-existent connection')

    def read(self, length, callback):
        try:
            if not self._stream:
                self.disconnect()
                raise ConnectionError(
                    'Tried to read from non-existent connection')
            self._stream.read_bytes(length, callback)
        except IOError:
            self.on_disconnect()

    def readline(self, callback):
        try:
            if not self._stream:
                self.disconnect()
                raise ConnectionError(
                    'Tried to read from non-existent connection')
            self._stream.read_until(b'\r\n', callback)
        except Exception:
            self.on_disconnect()

    def try_to_perform_read(self):
        if not self.in_progress and self.read_queue:
            self.in_progress = True
            self._io_loop.add_callback(partial(self.read_queue.pop(0), None))

    @async
    def queue_wait(self, callback):
        self.read_queue.append(callback)
        self.try_to_perform_read()

    def read_done(self):
        self.in_progress = False
        self.try_to_perform_read()

    def connected(self):
        if self._stream:
            return True
        return False
Example #34
0
File: tcpconn.py Project: goodfeng/hello
class NTcpConnector(object):
    def __init__(self, host, port):
        self.routes = {}
        self.host = host
        self.port = port
        self._s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        self.stream = IOStream(self._s)
        self.stream.connect((self.host, self.port), self._start_recv)

    def unregister(self, client):
        self.routes = dict(
            filter(lambda x: x[1] != client, self.routes.items()))

    def __lt__(self, other):
        return id(self) < id(other)

    def sendMsg(self, client, content):
        sn = client.application.proxys.getSN()
        self.routes[sn] = client
        # encode first so the packed length counts bytes, not characters
        payload = content.encode('utf-8')
        data = struct.pack('<i6I%dsI' % len(payload), int(-1), 10020,
                           20 + len(payload), sn, 0, int(time.time()), 1,
                           payload,
                           int((20 + len(payload)) ^ 0xaaaaaaaa))
        self.stream.write(data)

    def is_connected(self):
        return not self.stream.closed()

    def invalidate(self):
        self.stream.close_fd()

    def _start_recv(self):
        self.stream.read_bytes(12, self._on_frame)

    def _on_frame(self, data):
        nLen = struct.unpack('<i2I', data)[2]
        self.stream.read_bytes(nLen, self._on_msg)

    def _on_msg(self, data):
        nLen = len(data)
        sn, nTag, nTime, nCmdId, dataS = struct.unpack('<4I%dsI' % (nLen - 20),
                                                       data)[0:-1]

        if sn == 0:
            self.stream.write(
                struct.pack('<i7I', int(-1), 10000, 20, 0, 0, int(time.time()),
                            0, int(20 ^ 0xaaaaaaaa)))
        elif sn > 0 and (sn in self.routes):
            fs, strField = {}, ''
            if nCmdId == 110 and nLen == 292:  # 10-level order book quote
                ds = struct.unpack(
                    '<2iIqIqIqIqIqIqIqIqIqIqIqIqIqIqIqIqIqIqIqIqIqIq', dataS)
                strField = 'nSecurityID,nTime,nPxBid1,llVolumeBid1,nPxBid2,llVolumeBid2,nPxBid3,llVolumeBid3,nPxBid4,llVolumeBid4,nPxBid5,llVolumeBid5,nPxBid6,llVolumeBid6,nPxBid7,llVolumeBid7,nPxBid8,llVolumeBid8,nPxBid9,llVolumeBid9,nPxBid10,llVolumeBid10,nWeightedAvgBidPx,llTotalBidVolume,nPxOffer1,llVolumeOffer1,nPxOffer2,llVolumeOffer2,nPxOffer3,llVolumeOffer3,nPxOffer4,llVolumeOffer4,nPxOffer5,llVolumeOffer5,nPxOffer6,llVolumeOffer6,nPxOffer7,llVolumeOffer7,nPxOffer8,llVolumeOffer8,nPxOffer9,llVolumeOffer9,nPxOffer10,llVolumeOffer10,nWeightedAvgOfferPx,llTotalOfferVolume'
            elif nCmdId == 165 and nLen == 644:  # order queue details
                ds = struct.unpack('<2iI3i150i', dataS)
                strField = 'nSecurityID,nTime,nPx,nLevel,nOrderCount,nRevealCount,nStatus1,nVolume1,nChangeVolume1,nStatus2,nVolume2,nChangeVolume2,nStatus3,nVolume3,nChangeVolume3,nStatus4,nVolume4,nChangeVolume4,nStatus5,nVolume5,nChangeVolume5,nStatus6,nVolume6,nChangeVolume6,nStatus7,nVolume7,nChangeVolume7,nStatus8,nVolume8,nChangeVolume8,nStatus9,nVolume9,nChangeVolume9,nStatus10,nVolume10,nChangeVolume10,nStatus11,nVolume11,nChangeVolume11,nStatus12,nVolume12,nChangeVolume12,nStatus13,nVolume13,nChangeVolume13,nStatus14,nVolume14,nChangeVolume14,nStatus15,nVolume15,nChangeVolume15,nStatus16,nVolume16,nChangeVolume16,nStatus17,nVolume17,nChangeVolume17,nStatus18,nVolume18,nChangeVolume18,nStatus19,nVolume19,nChangeVolume19,nStatus20,nVolume20,nChangeVolume20,nStatus21,nVolume21,nChangeVolume21,nStatus22,nVolume22,nChangeVolume22,nStatus23,nVolume23,nChangeVolume23,nStatus24,nVolume24,nChangeVolume24,nStatus25,nVolume25,nChangeVolume25,nStatus26,nVolume26,nChangeVolume26,nStatus27,nVolume27,nChangeVolume27,nStatus28,nVolume28,nChangeVolume28,nStatus29,nVolume29,nChangeVolume29,nStatus30,nVolume30,nChangeVolume30,nStatus31,nVolume31,nChangeVolume31,nStatus32,nVolume32,nChangeVolume32,nStatus33,nVolume33,nChangeVolume33,nStatus34,nVolume34,nChangeVolume34,nStatus35,nVolume35,nChangeVolume35,nStatus36,nVolume36,nChangeVolume36,nStatus37,nVolume37,nChangeVolume37,nStatus38,nVolume38,nChangeVolume38,nStatus39,nVolume39,nChangeVolume39,nStatus40,nVolume40,nChangeVolume40,nStatus41,nVolume41,nChangeVolume41,nStatus42,nVolume42,nChangeVolume42,nStatus43,nVolume43,nChangeVolume43,nStatus44,nVolume44,nChangeVolume44,nStatus45,nVolume45,nChangeVolume45,nStatus46,nVolume46,nChangeVolume46,nStatus47,nVolume47,nChangeVolume47,nStatus48,nVolume48,nChangeVolume48,nStatus49,nVolume49,nChangeVolume49,nStatus50,nVolume50,nChangeVolume50'
            else:
                pass

            if strField:
                fields = strField.split(',')
                for i in range(0, len(fields)):
                    fs[fields[i]] = ds[i]
                fs['nCmdId'] = nCmdId
                self.routes[sn].callback(fs)

        self._start_recv()
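
# Frame-format note (added; inferred from the struct.pack/unpack calls above,
# so treat it as an assumption): frames are little-endian with a 12-byte
# header (flag, command id, payload length) followed by the payload
# (sn, tag, timestamp, command id, data, checksum), where the checksum is
# the payload length XOR 0xaaaaaaaa.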
Example #35
0
class TornadoClient(Client):
    """A non-blocking Pomelo client by tornado ioloop

    Usage:

        class ClientHandler(object) :

            def on_recv_data(self, client, proto_type, data) :
                print "recv_data..."
                return data

            def on_connected(self, client, user_data) :
                print "connect..."
                client.send_heartbeat()

            def on_disconnect(self, client) :
                print "disconnect..."

            def on_heartbeat(self, client) :
                print "heartbeat..."
                send request ...

            def on_response(self, client, route, request, response) :
                print "response..."

            def on_push(self, client, route, push_data) :
                print "notify..."

        handler = ClientHandler()
        client = TornadoClient(handler)
        client.connect(host, int(port))
        client.run()
        tornado.ioloop.IOLoop.current().start()
    """
    def __init__(self, handler):
        self.socket = socket(AF_INET, SOCK_STREAM)
        self.iostream = None
        self.protocol_package = None
        super(TornadoClient, self).__init__(handler)

    def connect(self, host, port):
        self.iostream = IOStream(self.socket)
        self.iostream.set_close_callback(self.on_close)
        self.iostream.connect((host, port), self.on_connect)

    def on_connect(self):
        self.send_sync()
        self.on_data()

    def on_close(self):
        if hasattr(self.handler, 'on_disconnect'):
            self.handler.on_disconnect(self)

    def send(self, data):
        assert not self.iostream.closed(), "iostream has closed"
        if not isinstance(data, bytes):
            data = bytes(data)
        self.iostream.write(data)

    def on_data(self):
        assert not self.iostream.closed(), "iostream has closed"
        if self.protocol_package is None or self.protocol_package.completed():
            self.iostream.read_bytes(4, self.on_head)

    def on_head(self, head):
        self.protocol_package = Protocol.unpack(head)
        self.iostream.read_bytes(self.protocol_package.length, self.on_body)

    def on_body(self, body):
        if hasattr(self.handler, 'on_recv_data'):
            body = self.handler.on_recv_data(self,
                                             self.protocol_package.proto_type,
                                             body)
        self.protocol_package.append(body)
        self.on_protocol(self.protocol_package)
        self.on_data()

    def close(self):
        if self.iostream:
            self.iostream.close()
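
# Read-loop note (added): on_data arms a 4-byte header read, on_head asks
# Protocol.unpack for the body length, and on_body hands the payload to the
# handler before re-arming on_data, so exactly one read is pending at a time.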
Example #36
0
File: tornstalk.py Project: nod/tornstalk
class Connection(object):
    """
    Encapsulates the communication, including parsing, with the beanstalkd
    """

    def __init__(self, host, port, io_loop=None):
        self._ioloop = io_loop or IOLoop.instance()

        # setup our protocol callbacks
        # beanstalkd will reply with a superset of these replies, but these
        # are the only ones we handle today.  patches gleefully accepted.
        self._beanstalk_protocol_1x = dict(
            # generic returns
            OUT_OF_MEMORY = self.fail,
            INTERNAL_ERROR = self.fail,
            DRAINING = self.fail,
            BAD_FORMAT = self.fail,
            UNKNOWN_COMMAND = self.fail,
            # put <pri> <delay> <ttr> <bytes>
            INSERTED = self.ret_inserted,
            BURIED = self.ret_inserted,
            EXPECTED_CRLF = self.fail,
            JOB_TOO_BIG = self.fail,
            # use
            USING = None,
            # reserve
            RESERVED = self.ret_reserved,
            DEADLINE_SOON = None,
            TIMED_OUT = None,
            # delete <id>
            DELETED = None,
            NOT_FOUND = None,
            # touch <id>
            TOUCHED = None,
            # watch <tube>
            WATCHING = None,
            #ignore <tube>
            NOT_IGNORED = None,
            )

        # open a connection to the beanstalkd
        _sock = socket.socket(
                        socket.AF_INET,
                        socket.SOCK_STREAM,
                        socket.IPPROTO_TCP
                        )
        _sock.connect((host, port))
        _sock.setblocking(False)
        self.stream = IOStream(_sock, io_loop=self._ioloop)

        # i like a placeholder for this. we'll assign it later
        self.callback = None
        self.tsr = TornStalkResponse()

    def _parse_response(self, resp):
        print "parse_response"
        tokens = resp.strip().split()
        if not tokens: return
        print 'tok:', tokens[1:]
        self._beanstalk_protocol_1x.get(tokens[0])(tokens)

    def _payload_rcvd(self, payload):
        self.tsr.data = payload[:-2]  # lose the trailing \r\n
        self.callback(self.tsr)

    def _command(self, contents):
        print "sending>%s<" % contents
        self.stream.write(contents)
        self.stream.read_until('\r\n', self._parse_response)

    def cmd_put(self, body, callback, priority=10000, delay=0, ttr=1):
        """
        send the put command to the beanstalkd with a message body

        priority needs to be between 0 and 2**32. lower gets done first
        delay is number of seconds before job is available in queue
        ttr is number of seconds the job has to run by a worker

        bs: put <pri> <delay> <ttr> <bytes>
        """
        self.callback = callback
        cmd = 'put {priority} {delay} {ttr} {size}'.format(
                priority = priority,
                delay = delay,
                ttr = ttr,
                size = len(body)
                )
        payload = '{}\r\n{}\r\n'.format(cmd, body)
        self._command(payload)

    def cmd_reserve(self, callback):
        self.callback = callback
        cmd = 'reserve\r\n'
        self._command(cmd)

    def ret_inserted(self, toks):
        """ handles both INSERTED and BURIED """
        jobid = int(toks[1])
        self.callback(TornStalkResponse(data=jobid))

    def ret_reserved(self, toks):
        jobid, size = toks[1:]
        jobid = int(jobid)
        size = int(size) + 2 # len('\r\n')
        self.stream.read_bytes(size, self._payload_rcvd)

    def handle_error(self, *a):
        print "error", a
        raise TornStalkError(a)

    def ok(self, *a):
        print "ok", a
        return True

    def fail(self, toks):
        self.callback(TornStalkResponse(result=False, msg=toks[1]))
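
# Hedged usage sketch (added; assumes a beanstalkd on localhost:11300 and the
# module's own TornStalkResponse/IOLoop imports, none of which are shown):
#
#     def on_put(resp):
#         print "inserted job", resp.data
#
#     conn = Connection('127.0.0.1', 11300)
#     conn.cmd_put('hello world', on_put)
#     IOLoop.instance().start()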
Example #37
0
class KeepAliveTest(AsyncHTTPTestCase):
    """Tests various scenarios for HTTP 1.1 keep-alive support.

    These tests don't use AsyncHTTPClient because we want to control
    connection reuse and closing.
    """

    def get_app(self):
        class HelloHandler(RequestHandler):
            def get(self):
                self.finish("Hello world")

            def post(self):
                self.finish("Hello world")

        class LargeHandler(RequestHandler):
            def get(self):
                # 512KB should be bigger than the socket buffers so it will
                # be written out in chunks.
                self.write("".join(chr(i % 256) * 1024 for i in range(512)))

        class FinishOnCloseHandler(RequestHandler):
            def initialize(self, cleanup_event):
                self.cleanup_event = cleanup_event

            @gen.coroutine
            def get(self):
                self.flush()
                yield self.cleanup_event.wait()

            def on_connection_close(self):
                # This is not very realistic, but finishing the request
                # from the close callback has the right timing to mimic
                # some errors seen in the wild.
                self.finish("closed")

        self.cleanup_event = Event()
        return Application(
            [
                ("/", HelloHandler),
                ("/large", LargeHandler),
                (
                    "/finish_on_close",
                    FinishOnCloseHandler,
                    dict(cleanup_event=self.cleanup_event),
                ),
            ]
        )

    def setUp(self):
        super(KeepAliveTest, self).setUp()
        self.http_version = b"HTTP/1.1"

    def tearDown(self):
        # We just closed the client side of the socket; let the IOLoop run
        # once to make sure the server side got the message.
        self.io_loop.add_timeout(datetime.timedelta(seconds=0.001), self.stop)
        self.wait()

        if hasattr(self, "stream"):
            self.stream.close()
        super(KeepAliveTest, self).tearDown()

    # The next few methods are a crude manual http client
    @gen.coroutine
    def connect(self):
        self.stream = IOStream(socket.socket())
        yield self.stream.connect(("127.0.0.1", self.get_http_port()))

    @gen.coroutine
    def read_headers(self):
        first_line = yield self.stream.read_until(b"\r\n")
        self.assertTrue(first_line.startswith(b"HTTP/1.1 200"), first_line)
        header_bytes = yield self.stream.read_until(b"\r\n\r\n")
        headers = HTTPHeaders.parse(header_bytes.decode("latin1"))
        raise gen.Return(headers)

    @gen.coroutine
    def read_response(self):
        self.headers = yield self.read_headers()
        body = yield self.stream.read_bytes(int(self.headers["Content-Length"]))
        self.assertEqual(b"Hello world", body)

    def close(self):
        self.stream.close()
        del self.stream

    @gen_test
    def test_two_requests(self):
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.1\r\n\r\n")
        yield self.read_response()
        self.stream.write(b"GET / HTTP/1.1\r\n\r\n")
        yield self.read_response()
        self.close()

    @gen_test
    def test_request_close(self):
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.1\r\nConnection: close\r\n\r\n")
        yield self.read_response()
        data = yield self.stream.read_until_close()
        self.assertTrue(not data)
        self.assertEqual(self.headers["Connection"], "close")
        self.close()

    # keepalive is supported for http 1.0 too, but it's opt-in
    @gen_test
    def test_http10(self):
        self.http_version = b"HTTP/1.0"
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.0\r\n\r\n")
        yield self.read_response()
        data = yield self.stream.read_until_close()
        self.assertTrue(not data)
        self.assertTrue("Connection" not in self.headers)
        self.close()

    @gen_test
    def test_http10_keepalive(self):
        self.http_version = b"HTTP/1.0"
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.stream.write(b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.close()

    @gen_test
    def test_http10_keepalive_extra_crlf(self):
        self.http_version = b"HTTP/1.0"
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.stream.write(b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.close()

    @gen_test
    def test_pipelined_requests(self):
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\n")
        yield self.read_response()
        yield self.read_response()
        self.close()

    @gen_test
    def test_pipelined_cancel(self):
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\n")
        # only read once
        yield self.read_response()
        self.close()

    @gen_test
    def test_cancel_during_download(self):
        yield self.connect()
        self.stream.write(b"GET /large HTTP/1.1\r\n\r\n")
        yield self.read_headers()
        yield self.stream.read_bytes(1024)
        self.close()

    @gen_test
    def test_finish_while_closed(self):
        yield self.connect()
        self.stream.write(b"GET /finish_on_close HTTP/1.1\r\n\r\n")
        yield self.read_headers()
        self.close()
        # Let the hanging coroutine clean up after itself
        self.cleanup_event.set()

    @gen_test
    def test_keepalive_chunked(self):
        self.http_version = b"HTTP/1.0"
        yield self.connect()
        self.stream.write(
            b"POST / HTTP/1.0\r\n"
            b"Connection: keep-alive\r\n"
            b"Transfer-Encoding: chunked\r\n"
            b"\r\n"
            b"0\r\n"
            b"\r\n"
        )
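        # Note (added): b"0\r\n\r\n" above is the zero-length final chunk, so
        # this chunked POST carries an empty body.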
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.stream.write(b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.close()
Example #38
0
class TestIOStream(AsyncHTTPTestCase, LogTrapTestCase):
    def get_app(self):
        return Application([('/', HelloHandler)])

    def make_iostream_pair(self):
        port = get_unused_port()
        [listener] = netutil.bind_sockets(port, '127.0.0.1',
                                          family=socket.AF_INET)
        streams = [None, None]
        def accept_callback(connection, address):
            streams[0] = IOStream(connection, io_loop=self.io_loop)
            self.stop()
        def connect_callback():
            streams[1] = client_stream
            self.stop()
        netutil.add_accept_handler(listener, accept_callback,
                                   io_loop=self.io_loop)
        client_stream = IOStream(socket.socket(), io_loop=self.io_loop)
        client_stream.connect(('127.0.0.1', port),
                              callback=connect_callback)
        self.wait(condition=lambda: all(streams))
        return streams

    def test_read_zero_bytes(self):
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        s.connect(("localhost", self.get_http_port()))
        self.stream = IOStream(s, io_loop=self.io_loop)
        self.stream.write(b("GET / HTTP/1.0\r\n\r\n"))

        # normal read
        self.stream.read_bytes(9, self.stop)
        data = self.wait()
        self.assertEqual(data, b("HTTP/1.0 "))

        # zero bytes
        self.stream.read_bytes(0, self.stop)
        data = self.wait()
        self.assertEqual(data, b(""))

        # another normal read
        self.stream.read_bytes(3, self.stop)
        data = self.wait()
        self.assertEqual(data, b("200"))

    def test_connection_refused(self):
        # When a connection is refused, the connect callback should not
        # be run.  (The kqueue IOLoop used to behave differently from the
        # epoll IOLoop in this respect)
        port = get_unused_port()
        stream = IOStream(socket.socket(), self.io_loop)
        self.connect_called = False
        def connect_callback():
            self.connect_called = True
        stream.set_close_callback(self.stop)
        stream.connect(("localhost", port), connect_callback)
        self.wait()
        self.assertFalse(self.connect_called)

    def test_connection_closed(self):
        # When a server sends a response and then closes the connection,
        # the client must be allowed to read the data before the IOStream
        # closes itself.  Epoll reports closed connections with a separate
        # EPOLLRDHUP event delivered at the same time as the read event,
        # while kqueue reports them as a second read/write event with an EOF
        # flag.
        response = self.fetch("/", headers={"Connection": "close"})
        response.rethrow()

    def test_read_until_close(self):
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        s.connect(("localhost", self.get_http_port()))
        stream = IOStream(s, io_loop=self.io_loop)
        stream.write(b("GET / HTTP/1.0\r\n\r\n"))
        
        stream.read_until_close(self.stop)
        data = self.wait()
        self.assertTrue(data.startswith(b("HTTP/1.0 200")))
        self.assertTrue(data.endswith(b("Hello")))

    def test_streaming_callback(self):
        server, client = self.make_iostream_pair()
        try:
            chunks = []
            final_called = []
            def streaming_callback(data):
                chunks.append(data)
                self.stop()
            def final_callback(data):
                assert not data
                final_called.append(True)
                self.stop()
            server.read_bytes(6, callback=final_callback,
                              streaming_callback=streaming_callback)
            client.write(b("1234"))
            self.wait(condition=lambda: chunks)
            client.write(b("5678"))
            self.wait(condition=lambda: final_called)
            self.assertEqual(chunks, [b("1234"), b("56")])

            # the rest of the last chunk is still in the buffer
            server.read_bytes(2, callback=self.stop)
            data = self.wait()
            self.assertEqual(data, b("78"))
        finally:
            server.close()
            client.close()

    def test_streaming_until_close(self):
        server, client = self.make_iostream_pair()
        try:
            chunks = []
            def callback(data):
                chunks.append(data)
                self.stop()
            client.read_until_close(callback=callback,
                                    streaming_callback=callback)
            server.write(b("1234"))
            self.wait()
            server.write(b("5678"))
            self.wait()
            server.close()
            self.wait()
            self.assertEqual(chunks, [b("1234"), b("5678"), b("")])
        finally:
            server.close()
            client.close()
Example #39
0
class TestIOStream(AsyncHTTPTestCase, LogTrapTestCase):
    def get_app(self):
        return Application([('/', HelloHandler)])

    def make_iostream_pair(self, **kwargs):
        port = get_unused_port()
        [listener] = netutil.bind_sockets(port,
                                          '127.0.0.1',
                                          family=socket.AF_INET)
        streams = [None, None]

        def accept_callback(connection, address):
            streams[0] = IOStream(connection, io_loop=self.io_loop, **kwargs)
            self.stop()

        def connect_callback():
            streams[1] = client_stream
            self.stop()

        netutil.add_accept_handler(listener,
                                   accept_callback,
                                   io_loop=self.io_loop)
        client_stream = IOStream(socket.socket(),
                                 io_loop=self.io_loop,
                                 **kwargs)
        client_stream.connect(('127.0.0.1', port), callback=connect_callback)
        self.wait(condition=lambda: all(streams))
        self.io_loop.remove_handler(listener.fileno())
        listener.close()
        return streams

    def test_read_zero_bytes(self):
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        s.connect(("localhost", self.get_http_port()))
        self.stream = IOStream(s, io_loop=self.io_loop)
        self.stream.write(b("GET / HTTP/1.0\r\n\r\n"))

        # normal read
        self.stream.read_bytes(9, self.stop)
        data = self.wait()
        self.assertEqual(data, b("HTTP/1.0 "))

        # zero bytes
        self.stream.read_bytes(0, self.stop)
        data = self.wait()
        self.assertEqual(data, b(""))

        # another normal read
        self.stream.read_bytes(3, self.stop)
        data = self.wait()
        self.assertEqual(data, b("200"))

        s.close()

    def test_write_zero_bytes(self):
        # Attempting to write zero bytes should run the callback without
        # going into an infinite loop.
        server, client = self.make_iostream_pair()
        server.write(b(''), callback=self.stop)
        self.wait()
        # As a side effect, the stream is now listening for connection
        # close (if it wasn't already), but is not listening for writes
        self.assertEqual(server._state, IOLoop.READ | IOLoop.ERROR)
        server.close()
        client.close()

    def test_connection_refused(self):
        # When a connection is refused, the connect callback should not
        # be run.  (The kqueue IOLoop used to behave differently from the
        # epoll IOLoop in this respect)
        port = get_unused_port()
        stream = IOStream(socket.socket(), self.io_loop)
        self.connect_called = False

        def connect_callback():
            self.connect_called = True

        stream.set_close_callback(self.stop)
        stream.connect(("localhost", port), connect_callback)
        self.wait()
        self.assertFalse(self.connect_called)
        self.assertTrue(isinstance(stream.error, socket.error), stream.error)
        if sys.platform != 'cygwin':
            # cygwin's errnos don't match those used on native windows python
            self.assertEqual(stream.error.args[0], errno.ECONNREFUSED)

    def test_gaierror(self):
        # Test that IOStream sets its exc_info on getaddrinfo error
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        stream = IOStream(s, io_loop=self.io_loop)
        stream.set_close_callback(self.stop)
        stream.connect(('adomainthatdoesntexist.asdf', 54321))
        self.assertTrue(isinstance(stream.error, socket.gaierror),
                        stream.error)

    def test_connection_closed(self):
        # When a server sends a response and then closes the connection,
        # the client must be allowed to read the data before the IOStream
        # closes itself.  Epoll reports closed connections with a separate
        # EPOLLRDHUP event delivered at the same time as the read event,
        # while kqueue reports them as a second read/write event with an EOF
        # flag.
        response = self.fetch("/", headers={"Connection": "close"})
        response.rethrow()

    def test_read_until_close(self):
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        s.connect(("localhost", self.get_http_port()))
        stream = IOStream(s, io_loop=self.io_loop)
        stream.write(b("GET / HTTP/1.0\r\n\r\n"))

        stream.read_until_close(self.stop)
        data = self.wait()
        self.assertTrue(data.startswith(b("HTTP/1.0 200")))
        self.assertTrue(data.endswith(b("Hello")))

    def test_streaming_callback(self):
        server, client = self.make_iostream_pair()
        try:
            chunks = []
            final_called = []

            def streaming_callback(data):
                chunks.append(data)
                self.stop()

            def final_callback(data):
                assert not data
                final_called.append(True)
                self.stop()

            server.read_bytes(6,
                              callback=final_callback,
                              streaming_callback=streaming_callback)
            client.write(b("1234"))
            self.wait(condition=lambda: chunks)
            client.write(b("5678"))
            self.wait(condition=lambda: final_called)
            self.assertEqual(chunks, [b("1234"), b("56")])

            # the rest of the last chunk is still in the buffer
            server.read_bytes(2, callback=self.stop)
            data = self.wait()
            self.assertEqual(data, b("78"))
        finally:
            server.close()
            client.close()

    def test_streaming_until_close(self):
        server, client = self.make_iostream_pair()
        try:
            chunks = []

            def callback(data):
                chunks.append(data)
                self.stop()

            client.read_until_close(callback=callback,
                                    streaming_callback=callback)
            server.write(b("1234"))
            self.wait()
            server.write(b("5678"))
            self.wait()
            server.close()
            self.wait()
            self.assertEqual(chunks, [b("1234"), b("5678"), b("")])
        finally:
            server.close()
            client.close()

    def test_delayed_close_callback(self):
        # The scenario:  Server closes the connection while there is a pending
        # read that can be served out of buffered data.  The client does not
        # run the close_callback as soon as it detects the close, but rather
        # defers it until after the buffered read has finished.
        server, client = self.make_iostream_pair()
        try:
            client.set_close_callback(self.stop)
            server.write(b("12"))
            chunks = []

            def callback1(data):
                chunks.append(data)
                client.read_bytes(1, callback2)
                server.close()

            def callback2(data):
                chunks.append(data)

            client.read_bytes(1, callback1)
            self.wait()  # stopped by close_callback
            self.assertEqual(chunks, [b("1"), b("2")])
        finally:
            server.close()
            client.close()

    def test_close_buffered_data(self):
        # Similar to the previous test, but with data stored in the OS's
        # socket buffers instead of the IOStream's read buffer.  Out-of-band
        # close notifications must be delayed until all data has been
        # drained into the IOStream buffer. (epoll used to use out-of-band
        # close events with EPOLLRDHUP, but no longer)
        #
        # This depends on the read_chunk_size being smaller than the
        # OS socket buffer, so make it small.
        server, client = self.make_iostream_pair(read_chunk_size=256)
        try:
            server.write(b("A") * 512)
            client.read_bytes(256, self.stop)
            data = self.wait()
            self.assertEqual(b("A") * 256, data)
            server.close()
            # Allow the close to propagate to the client side of the
            # connection.  Using add_callback instead of add_timeout
            # doesn't seem to work, even with multiple iterations
            self.io_loop.add_timeout(datetime.timedelta(seconds=0.01),
                                     self.stop)
            self.wait()
            client.read_bytes(256, self.stop)
            data = self.wait()
            self.assertEqual(b("A") * 256, data)
        finally:
            server.close()
            client.close()

    def test_large_read_until(self):
        # Performance test: read_until used to have a quadratic component
        # so a read_until of 4MB would take 8 seconds; now it takes 0.25
        # seconds.
        server, client = self.make_iostream_pair()
        try:
            NUM_KB = 4096
            for i in xrange(NUM_KB):
                client.write(b("A") * 1024)
            client.write(b("\r\n"))
            server.read_until(b("\r\n"), self.stop)
            data = self.wait()
            self.assertEqual(len(data), NUM_KB * 1024 + 2)
        finally:
            server.close()
            client.close()
Example #40
0
class KeepAliveTest(AsyncHTTPTestCase):
    """Tests various scenarios for HTTP 1.1 keep-alive support.

    These tests don't use AsyncHTTPClient because we want to control
    connection reuse and closing.
    """
    def get_app(self):
        class HelloHandler(RequestHandler):
            def get(self):
                self.finish('Hello world')

        class LargeHandler(RequestHandler):
            def get(self):
                # 512KB should be bigger than the socket buffers so it will
                # be written out in chunks.
                self.write(''.join(chr(i % 256) * 1024 for i in range(512)))

        class FinishOnCloseHandler(RequestHandler):
            @asynchronous
            def get(self):
                self.flush()

            def on_connection_close(self):
                # This is not very realistic, but finishing the request
                # from the close callback has the right timing to mimic
                # some errors seen in the wild.
                self.finish('closed')

        return Application([('/', HelloHandler), ('/large', LargeHandler),
                            ('/finish_on_close', FinishOnCloseHandler)])

    def setUp(self):
        super(KeepAliveTest, self).setUp()
        self.http_version = b'HTTP/1.1'

    def tearDown(self):
        # We just closed the client side of the socket; let the IOLoop run
        # once to make sure the server side got the message.
        self.io_loop.add_timeout(datetime.timedelta(seconds=0.001), self.stop)
        self.wait()

        if hasattr(self, 'stream'):
            self.stream.close()
        super(KeepAliveTest, self).tearDown()

    # The next few methods are a crude manual http client
    def connect(self):
        self.stream = IOStream(socket.socket(), io_loop=self.io_loop)
        self.stream.connect(('localhost', self.get_http_port()), self.stop)
        self.wait()

    def read_headers(self):
        self.stream.read_until(b'\r\n', self.stop)
        first_line = self.wait()
        self.assertTrue(first_line.startswith(self.http_version + b' 200'),
                        first_line)
        self.stream.read_until(b'\r\n\r\n', self.stop)
        header_bytes = self.wait()
        headers = HTTPHeaders.parse(header_bytes.decode('latin1'))
        return headers

    def read_response(self):
        self.headers = self.read_headers()
        self.stream.read_bytes(int(self.headers['Content-Length']), self.stop)
        body = self.wait()
        self.assertEqual(b'Hello world', body)

    def close(self):
        self.stream.close()
        del self.stream

    def test_two_requests(self):
        self.connect()
        self.stream.write(b'GET / HTTP/1.1\r\n\r\n')
        self.read_response()
        self.stream.write(b'GET / HTTP/1.1\r\n\r\n')
        self.read_response()
        self.close()

    def test_request_close(self):
        self.connect()
        self.stream.write(b'GET / HTTP/1.1\r\nConnection: close\r\n\r\n')
        self.read_response()
        self.stream.read_until_close(callback=self.stop)
        data = self.wait()
        self.assertTrue(not data)
        self.close()

    # keepalive is supported for http 1.0 too, but it's opt-in
    def test_http10(self):
        self.http_version = b'HTTP/1.0'
        self.connect()
        self.stream.write(b'GET / HTTP/1.0\r\n\r\n')
        self.read_response()
        self.stream.read_until_close(callback=self.stop)
        data = self.wait()
        self.assertTrue(not data)
        self.assertTrue('Connection' not in self.headers)
        self.close()

    def test_http10_keepalive(self):
        self.http_version = b'HTTP/1.0'
        self.connect()
        self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n')
        self.read_response()
        self.assertEqual(self.headers['Connection'], 'Keep-Alive')
        self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n')
        self.read_response()
        self.assertEqual(self.headers['Connection'], 'Keep-Alive')
        self.close()

    def test_pipelined_requests(self):
        self.connect()
        self.stream.write(b'GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\n')
        self.read_response()
        self.read_response()
        self.close()

    def test_pipelined_cancel(self):
        self.connect()
        self.stream.write(b'GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\n')
        # only read once
        self.read_response()
        self.close()

    def test_cancel_during_download(self):
        self.connect()
        self.stream.write(b'GET /large HTTP/1.1\r\n\r\n')
        self.read_headers()
        self.stream.read_bytes(1024, self.stop)
        self.wait()
        self.close()

    def test_finish_while_closed(self):
        self.connect()
        self.stream.write(b'GET /finish_on_close HTTP/1.1\r\n\r\n')
        self.read_headers()
        self.close()
Example #41
0
class Connection(Connection):

    def __init__(self, pool=None, *args, **kwargs):
        super(Connection, self).__init__(*args, **kwargs)
        self._pool = pool
        self._stream = None
        self._callbacks = []
        self._ready = False

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._pool.release(self)
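
    # Callbacks registered before the connection is ready are queued and
    # replayed by _do_callbacks, which runs as the IOStream connect callback.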

    def _add_callback(self, func):
        self._callbacks.append(func)

    def _do_callbacks(self):
        self._ready = True
        while 1:
            try:
                func = self._callbacks.pop()
                func()
            except IndexError:
                # all done
                break
            except:
                # other error
                continue

    def connect(self):
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        self._sock = IOStream(s)  # tornado iostream
        self._sock.connect((self._host, self._port), self._do_callbacks)

    def send(self, payload, correlation_id=-1, callback=None):
        """
        :param payload: an encoded kafka packet
        :param correlation_id: for now, just for debug logging
        :return:
        """
        if not self._ready:
            def _callback(*args, **kwargs):
                self.send(payload, correlation_id, callback)
            self._add_callback(_callback)
            return
        log.debug("About to send %d bytes to Kafka, request %d" % (len(payload), correlation_id))
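        # Kafka wire framing: a 4-byte big-endian length prefix precedes the
        # payload; an empty payload is sent as a length of -1.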
        if payload:
            _bytes = struct.pack('>i%ds' % len(payload), len(payload), payload)
        else:
            _bytes = struct.pack('>i', -1)
        try:
            self._sock.write(_bytes, callback)  # asynchronous IOStream write; callback fires when flushed
        except:
            self.close()
            callback(None)
            self._log_and_raise('Unable to send payload to Kafka')

    def _recv(self, size, callback):
        try:
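            # NOTE: reads at most 4096 bytes per call, so a response larger
            # than 4KB reaches this callback truncated.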
            self._sock.read_bytes(min(size, 4096), callback)
        except:
            self.close()
            callback(None)  # if error, set None
            self._log_and_raise('Unable to receive data from Kafka')

    def recv(self, correlation_id=-1, callback=None):
        """
        :param correlation_id: for now, just for debug logging
        :return: kafka response packet
        """
        log.debug("Reading response %d from Kafka" % correlation_id)
        if not self._ready:
            def _callback():
                self.recv(correlation_id, callback)
            self._add_callback(_callback)
            return
        def get_size(resp):
            if resp is None:
                callback(None)
                return
            size, = struct.unpack('>i', resp)
            self._recv(size, callback)
        self._recv(4, get_size)  # read the response length

    def close(self):
        self._callbacks = []
        log.debug("Closing socket connection" + self._log_tail)
        if self._sock:
            self._sock.close()
            self._sock = None
        else:
            log.debug("Socket connection does not exist" + self._log_tail)

    def closed(self):
        return self._sock.closed()
Example #42
0
class _RPCClientConnection(object):
    '''An RPC client connection.'''

    def __init__(self, close_callback):
        self._stream = None
        self._sequence = itertools.count()
        self._pending = {}  # sequence -> callback
        self._pending_read = None
        self._close_callback = stack_context.wrap(close_callback)

    @gen.engine
    def connect(self, address, nonce=protocol.NULL_NONCE, callback=None):
        if self._stream is not None:
            raise RuntimeError('Attempting to reconnect existing connection')
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
        self._stream = IOStream(sock)
        self._stream.set_close_callback(self._handle_close)
        # If the connect fails, our close callback will be called, and
        # the Wait will never return.
        self._stream.connect((address, protocol.PORT),
                callback=(yield gen.Callback('connect')))
        self._write(nonce)
        yield gen.Wait('connect')

        rnonce = yield gen.Task(self._read, protocol.NONCE_LEN)
        # Start reply-handling coroutine
        self._reply_handler()
        if callback is not None:
            callback(rnonce)

    def close(self):
        if self._stream is not None:
            self._stream.close()
        else:
            # close() called before connect().  Synthesize the close event
            # ourselves.
            self._handle_close()

    def _handle_close(self):
        if self._pending_read is not None:
            # The pending read callback will never be called.  Call it
            # ourselves to clean up.
            self._pending_read(None)
        if self._close_callback is not None:
            cb = self._close_callback
            self._close_callback = None
            cb()

    @gen.engine
    def _read(self, count, callback=None):
        if self._pending_read is not None:
            raise RuntimeError('Double read on connection')
        self._pending_read = stack_context.wrap((yield gen.Callback('read')))
        try:
            self._stream.read_bytes(count, callback=self._pending_read)
            buf = yield gen.Wait('read')
            if buf is None:
                # _handle_close() is cleaning us up
                raise ConnectionFailure('Connection closed')
        except IOError, e:
            self.close()
            raise ConnectionFailure(str(e))
        finally:
            # The snippet breaks off here; clearing the pending read is the
            # cleanup required by the double-read guard above.
            self._pending_read = None
Example #43
0
class AsyncRedisClient(object):
    """A non-blocking Redis client.

    Example usage::

        import ioloop

        def handle_request(result):
            print 'Redis reply: %r' % result
            ioloop.IOLoop.instance().stop()

        redis_client = AsyncRedisClient(('127.0.0.1', 6379))
        redis_client.fetch(('set', 'foo', 'bar'), None)
        redis_client.fetch(('get', 'foo'), handle_request)
        ioloop.IOLoop.instance().start()

    This class implements a Redis client on top of Tornado's IOStreams.
    It does not currently implement all applicable parts of the Redis
    specification, but it does enough to work with major redis server APIs
    (mostly tested against the LIST/HASH/PUBSUB API so far).

    This class has not been tested extensively in production and
    should be considered somewhat experimental as of the release of
    tornado 1.2.  It is intended to become the default tornado
    AsyncRedisClient implementation.
    """

    def __init__(self, address, io_loop=None):
        """Creates an AsyncRedisClient.

        address is the redis server address tuple that IOStream can
        connect to, e.g. ('127.0.0.1', 6379).
        """
        self.address         = address
        self.io_loop         = io_loop or IOLoop.instance()
        self._callback_queue = deque()
        self._callback       = None
        self._read_buffer    = None
        self._result_queue   = deque()
        self.socket          = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.stream          = IOStream(self.socket, self.io_loop)
        self.stream.connect(self.address, self._wait_result)

    def close(self):
        """Destroys this redis client, freeing any file descriptors used.
        Not needed in normal use, but may be helpful in unittests that
        create and destroy redis clients.  No other methods may be called
        on the AsyncRedisClient after close().
        """
        self.stream.close()

    def fetch(self, request, callback):
        """Executes a request, calling callback with a redis `result`.

        The request should be a tuple of strings, like ('set', 'foo', 'bar').

        If an error occurs during the fetch, a `RedisError` exception will
        be raised. You can use try...except to catch it (if any) in the
        callback.
        """
        self._callback_queue.append(callback)
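        # Callbacks queue FIFO and are matched to replies in arrival order,
        # so multiple fetches can be pipelined on one connection.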
        self.stream.write(encode(request))

    def _wait_result(self):
        """Read a completed result data from the redis server."""
        self._read_buffer = deque()
        self.stream.read_until('\r\n', self._on_read_first_line)

    def _maybe_callback(self):
        """Try call callback in _callback_queue when we read a redis result."""
        try:
            read_buffer    = self._read_buffer
            callback       = self._callback
            result_queue   = self._result_queue
            callback_queue = self._callback_queue
            if result_queue:
                result_queue.append(read_buffer)
                read_buffer = result_queue.popleft()
            if callback_queue:
                callback = self._callback = callback_queue.popleft()
            if callback:
                callback(decode(read_buffer))
        except Exception:
            logging.error('Uncaught callback exception', exc_info=True)
            self.close()
            raise
        finally:
            self._wait_result()

    def _on_read_first_line(self, data):
        self._read_buffer.append(data)
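        # Dispatch on the RESP type byte: ':' integer, '+' status and
        # '-' error are single-line; '$' is a bulk string; '*' is multi-bulk.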
        c = data[0]
        if c in ':+-':
            self._maybe_callback()
        elif c == '$':
            if data[:3] == '$-1':
                self._maybe_callback()
            else:
                length = int(data[1:])
                self.stream.read_bytes(length+2, self._on_read_bulk_body)
        elif c == '*':
            if data[1] in '-0':
                self._maybe_callback()
            else:
                self._multibulk_number = int(data[1:])
                self.stream.read_until('\r\n', self._on_read_multibulk_bulk_head)

    def _on_read_bulk_body(self, data):
        self._read_buffer.append(data)
        self._maybe_callback()

    def _on_read_multibulk_bulk_head(self, data):
        self._read_buffer.append(data)
        c = data[0]
        if c == '$':
            length = int(data[1:])
            self.stream.read_bytes(length+2, self._on_read_multibulk_bulk_body)
        else:
            self._maybe_callback()

    def _on_read_multibulk_bulk_body(self, data):
        self._read_buffer.append(data)
        self._multibulk_number -= 1
        if self._multibulk_number:
            self.stream.read_until('\r\n', self._on_read_multibulk_bulk_head)
        else:
            self._maybe_callback()
Example #44
0
class TTornadoTransport(TTransportBase):
    """A non-blocking Thrift client.

  Example usage::

    import greenlet
    from tornado import ioloop
    from thrift.transport import TTransport
    from thrift.protocol import TBinaryProtocol

    from viewfinder.backend.thrift import TTornadoTransport

    transport = TTransport.TFramedTransport(TTornadoTransport('localhost', 9090))
    protocol = TBinaryProtocol.TBinaryProtocol(transport)
    client = Service.Client(protocol)
    ioloop.IOLoop.instance().start()

  Then, from within an asynchronous tornado request handler::

    class MyApp(tornado.web.RequestHandler):
      @tornado.web.asynchronous
      def post(self):
        def business_logic():
          ...any thrift calls...
          self.write(...stuff that gets returned to client...)
          self.finish()  # end the asynchronous request
        gr = greenlet.greenlet(business_logic)
        gr.switch()
  """
    def __init__(self, host='localhost', port=9090):
        """Initialize a TTornadoTransport with a Tornado IOStream.

    @param host(str) The host to connect to.
    @param port(int) The (TCP) port to connect to.
    """
        self.host = host
        self.port = port
        self._stream = None
        self._io_loop = ioloop.IOLoop.current()
        self._timeout_secs = None

    def set_timeout(self, timeout_secs):
        """Sets a timeout for use with open/read/write operations."""
        self._timeout_secs = timeout_secs

    def isOpen(self):
        return self._stream is not None

    def open(self):
        """Creates a connection to host:port and spins up a tornado
    IOStream object to write requests and read responses from the
    thrift server. After making the asynchronous connect call to
    _stream, the current greenlet yields control back to the parent
    greenlet (presumably the "master" greenlet).
    """
        assert greenlet.getcurrent().parent is not None
        # TODO(spencer): allow ipv6? (af = socket.AF_UNSPEC)
        addrinfo = socket.getaddrinfo(self.host, self.port, socket.AF_INET,
                                      socket.SOCK_STREAM, 0, 0)
        af, socktype, proto, canonname, sockaddr = addrinfo[0]
        self._stream = IOStream(socket.socket(af, socktype, proto),
                                io_loop=self._io_loop)
        self._open_internal(sockaddr)

    def close(self):
        if self._stream:
            self._stream.set_close_callback(None)
            self._stream.close()
            self._stream = None

    @_wrap_transport
    def read(self, sz):
        logging.debug("reading %d bytes from %s:%d" %
                      (sz, self.host, self.port))
        cur_gr = greenlet.getcurrent()
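
        # The IOLoop (parent greenlet) keeps running while this greenlet is
        # parked in parent.switch(); _on_read resumes it with the bytes read.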

        def _on_read(buf):
            if self._stream:
                cur_gr.switch(buf)

        self._stream.read_bytes(sz, _on_read)
        buf = cur_gr.parent.switch()
        if len(buf) == 0:
            raise TTransportException(type=TTransportException.END_OF_FILE,
                                      message='TTornadoTransport read 0 bytes')
        logging.debug("read %d bytes in %.2fms" %
                      (len(buf), (time.time() - self._start_time) * 1000))
        return buf

    @_wrap_transport
    def write(self, buf):
        logging.debug("writing %d bytes to %s:%d" %
                      (len(buf), self.host, self.port))
        cur_gr = greenlet.getcurrent()

        def _on_write():
            if self._stream:
                cur_gr.switch()

        self._stream.write(buf, _on_write)
        cur_gr.parent.switch()
        logging.debug("wrote %d bytes in %.2fms" %
                      (len(buf), (time.time() - self._start_time) * 1000))

    @_wrap_transport
    def flush(self):
        pass

    @_wrap_transport
    def _open_internal(self, sockaddr):
        logging.debug("opening connection to %s:%d" % (self.host, self.port))
        cur_gr = greenlet.getcurrent()

        def _on_connect():
            if self._stream:
                cur_gr.switch()

        self._stream.connect(sockaddr, _on_connect)
        cur_gr.parent.switch()
        logging.info("opened connection to %s:%d" % (self.host, self.port))

    def _check_stream(self):
        if not self._stream:
            raise TTransportException(type=TTransportException.NOT_OPEN,
                                      message='transport not open')

    def _set_timeout(self):
        if self._timeout_secs:
            return self._io_loop.add_timeout(
                time.time() + self._timeout_secs,
                functools.partial(self._on_timeout, gr=greenlet.getcurrent()))
        return None

    def _clear_timeout(self, timeout):
        if timeout:
            self._io_loop.remove_timeout(timeout)
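
    # Timeouts and closes are delivered by throwing a TTransportException
    # into the blocked greenlet, so the error surfaces inside open/read/write.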

    def _on_timeout(self, gr):
        gr.throw(
            TTransportException(type=TTransportException.TIMED_OUT,
                                message="connection timed out to %s:%d" %
                                (self.host, self.port)))

    def _on_close(self, gr):
        self._stream = None
        message = "connection to %s:%d closed" % (self.host, self.port)
        if gr:
            gr.throw(
                TTransportException(type=TTransportException.NOT_OPEN,
                                    message=message))
        else:
            logging.error(message)
Example #45
0
class KeepAliveTest(AsyncHTTPTestCase):
    """Tests various scenarios for HTTP 1.1 keep-alive support.

    These tests don't use AsyncHTTPClient because we want to control
    connection reuse and closing.
    """
    def get_app(self):
        class HelloHandler(RequestHandler):
            def get(self):
                self.finish("Hello world")

            def post(self):
                self.finish("Hello world")

        class LargeHandler(RequestHandler):
            def get(self):
                # 512KB should be bigger than the socket buffers so it will
                # be written out in chunks.
                self.write("".join(chr(i % 256) * 1024 for i in range(512)))

        class FinishOnCloseHandler(RequestHandler):
            def initialize(self, cleanup_event):
                self.cleanup_event = cleanup_event

            @gen.coroutine
            def get(self):
                self.flush()
                yield self.cleanup_event.wait()

            def on_connection_close(self):
                # This is not very realistic, but finishing the request
                # from the close callback has the right timing to mimic
                # some errors seen in the wild.
                self.finish("closed")

        self.cleanup_event = Event()
        return Application([
            ("/", HelloHandler),
            ("/large", LargeHandler),
            (
                "/finish_on_close",
                FinishOnCloseHandler,
                dict(cleanup_event=self.cleanup_event),
            ),
        ])

    def setUp(self):
        super(KeepAliveTest, self).setUp()
        self.http_version = b"HTTP/1.1"

    def tearDown(self):
        # We just closed the client side of the socket; let the IOLoop run
        # once to make sure the server side got the message.
        self.io_loop.add_timeout(datetime.timedelta(seconds=0.001), self.stop)
        self.wait()

        if hasattr(self, "stream"):
            self.stream.close()
        super(KeepAliveTest, self).tearDown()

    # The next few methods are a crude manual http client
    @gen.coroutine
    def connect(self):
        self.stream = IOStream(socket.socket())
        yield self.stream.connect(("127.0.0.1", self.get_http_port()))

    @gen.coroutine
    def read_headers(self):
        first_line = yield self.stream.read_until(b"\r\n")
        self.assertTrue(first_line.startswith(b"HTTP/1.1 200"), first_line)
        header_bytes = yield self.stream.read_until(b"\r\n\r\n")
        headers = HTTPHeaders.parse(header_bytes.decode("latin1"))
        raise gen.Return(headers)

    @gen.coroutine
    def read_response(self):
        self.headers = yield self.read_headers()
        body = yield self.stream.read_bytes(int(
            self.headers["Content-Length"]))
        self.assertEqual(b"Hello world", body)

    def close(self):
        self.stream.close()
        del self.stream

    @gen_test
    def test_two_requests(self):
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.1\r\n\r\n")
        yield self.read_response()
        self.stream.write(b"GET / HTTP/1.1\r\n\r\n")
        yield self.read_response()
        self.close()

    @gen_test
    def test_request_close(self):
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.1\r\nConnection: close\r\n\r\n")
        yield self.read_response()
        data = yield self.stream.read_until_close()
        self.assertTrue(not data)
        self.assertEqual(self.headers["Connection"], "close")
        self.close()

    # keepalive is supported for http 1.0 too, but it's opt-in
    @gen_test
    def test_http10(self):
        self.http_version = b"HTTP/1.0"
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.0\r\n\r\n")
        yield self.read_response()
        data = yield self.stream.read_until_close()
        self.assertTrue(not data)
        self.assertTrue("Connection" not in self.headers)
        self.close()

    @gen_test
    def test_http10_keepalive(self):
        self.http_version = b"HTTP/1.0"
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.stream.write(b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.close()

    @gen_test
    def test_http10_keepalive_extra_crlf(self):
        self.http_version = b"HTTP/1.0"
        yield self.connect()
        self.stream.write(
            b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.stream.write(b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.close()

    @gen_test
    def test_pipelined_requests(self):
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\n")
        yield self.read_response()
        yield self.read_response()
        self.close()

    @gen_test
    def test_pipelined_cancel(self):
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\n")
        # only read once
        yield self.read_response()
        self.close()

    @gen_test
    def test_cancel_during_download(self):
        yield self.connect()
        self.stream.write(b"GET /large HTTP/1.1\r\n\r\n")
        yield self.read_headers()
        yield self.stream.read_bytes(1024)
        self.close()

    @gen_test
    def test_finish_while_closed(self):
        yield self.connect()
        self.stream.write(b"GET /finish_on_close HTTP/1.1\r\n\r\n")
        yield self.read_headers()
        self.close()
        # Let the hanging coroutine clean up after itself
        self.cleanup_event.set()

    @gen_test
    def test_keepalive_chunked(self):
        self.http_version = b"HTTP/1.0"
        yield self.connect()
        self.stream.write(b"POST / HTTP/1.0\r\n"
                          b"Connection: keep-alive\r\n"
                          b"Transfer-Encoding: chunked\r\n"
                          b"\r\n"
                          b"0\r\n"
                          b"\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.stream.write(b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.close()
Example #46
0
File: sockets.py Project: TFenby/wdb
def handle_connection(connection, address):
    log.info('Connection received from %s' % str(address))
    stream = IOStream(connection, ioloop)
    # Getting uuid
    stream.read_bytes(4, partial(read_uuid_size, stream))
Example #47
0
class _RPCClientConnection(object):
    '''An RPC client connection.'''
    def __init__(self, close_callback):
        self._stream = None
        self._sequence = itertools.count()
        self._pending = {}  # sequence -> callback
        self._pending_read = None
        self._close_callback = stack_context.wrap(close_callback)

    @gen.engine
    def connect(self, address, nonce=protocol.NULL_NONCE, callback=None):
        if self._stream is not None:
            raise RuntimeError('Attempting to reconnect existing connection')
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
        self._stream = IOStream(sock)
        self._stream.set_close_callback(self._handle_close)
        # If the connect fails, our close callback will be called, and
        # the Wait will never return.
        self._stream.connect((address, protocol.PORT),
                             callback=(yield gen.Callback('connect')))
        self._write(nonce)
        yield gen.Wait('connect')

        rnonce = yield gen.Task(self._read, protocol.NONCE_LEN)
        # Start reply-handling coroutine
        self._reply_handler()
        if callback is not None:
            callback(rnonce)

    def close(self):
        if self._stream is not None:
            self._stream.close()
        else:
            # close() called before connect().  Synthesize the close event
            # ourselves.
            self._handle_close()

    def _handle_close(self):
        if self._pending_read is not None:
            # The pending read callback will never be called.  Call it
            # ourselves to clean up.
            self._pending_read(None)
        if self._close_callback is not None:
            cb = self._close_callback
            self._close_callback = None
            cb()

    @gen.engine
    def _read(self, count, callback=None):
        if self._pending_read is not None:
            raise RuntimeError('Double read on connection')
        self._pending_read = stack_context.wrap((yield gen.Callback('read')))
        try:
            self._stream.read_bytes(count, callback=self._pending_read)
            buf = yield gen.Wait('read')
            if buf is None:
                # _handle_close() is cleaning us up
                raise ConnectionFailure('Connection closed')
        except IOError, e:
            self.close()
            raise ConnectionFailure(str(e))
        finally:
            # The snippet breaks off here; clearing the pending read is the
            # cleanup required by the double-read guard above.
            self._pending_read = None
Example #48
0
class _HTTPConnection(object):
    _SUPPORTED_METHODS = set(["GET", "HEAD", "POST", "PUT", "DELETE"])

    def __init__(self, io_loop, client, request, release_callback, final_callback, max_buffer_size):
        self.start_time = time.time()
        self.io_loop = io_loop
        self.client = client
        self.request = request
        self.release_callback = release_callback
        self.final_callback = final_callback
        self.code = None
        self.headers = None
        self.chunks = None
        self._decompressor = None
        # Timeout handle returned by IOLoop.add_timeout
        self._timeout = None
        with stack_context.StackContext(self.cleanup):
            parsed = urllib.parse.urlsplit(_unicode(self.request.url))
            if ssl is None and parsed.scheme == "https":
                raise ValueError("HTTPS requires either python2.6+ or " "curl_httpclient")
            if parsed.scheme not in ("http", "https"):
                raise ValueError("Unsupported url scheme: %s" % self.request.url)
            # urlsplit results have hostname and port results, but they
            # didn't support ipv6 literals until python 2.7.
            netloc = parsed.netloc
            if "@" in netloc:
                userpass, _, netloc = netloc.rpartition("@")
            match = re.match(r"^(.+):(\d+)$", netloc)
            if match:
                host = match.group(1)
                port = int(match.group(2))
            else:
                host = netloc
                port = 443 if parsed.scheme == "https" else 80
            if re.match(r"^\[.*\]$", host):
                # raw ipv6 addresses in urls are enclosed in brackets
                host = host[1:-1]
            if self.client.hostname_mapping is not None:
                host = self.client.hostname_mapping.get(host, host)

            if request.allow_ipv6:
                af = socket.AF_UNSPEC
            else:
                # We only try the first IP we get from getaddrinfo,
                # so restrict to ipv4 by default.
                af = socket.AF_INET

            addrinfo = socket.getaddrinfo(host, port, af, socket.SOCK_STREAM, 0, 0)
            af, socktype, proto, canonname, sockaddr = addrinfo[0]

            if parsed.scheme == "https":
                ssl_options = {}
                if request.validate_cert:
                    ssl_options["cert_reqs"] = ssl.CERT_REQUIRED
                if request.ca_certs is not None:
                    ssl_options["ca_certs"] = request.ca_certs
                else:
                    ssl_options["ca_certs"] = _DEFAULT_CA_CERTS
                if request.client_key is not None:
                    ssl_options["keyfile"] = request.client_key
                if request.client_cert is not None:
                    ssl_options["certfile"] = request.client_cert

                # SSL interoperability is tricky.  We want to disable
                # SSLv2 for security reasons; it wasn't disabled by default
                # until openssl 1.0.  The best way to do this is to use
                # the SSL_OP_NO_SSLv2, but that wasn't exposed to python
                # until 3.2.  Python 2.7 adds the ciphers argument, which
                # can also be used to disable SSLv2.  As a last resort
                # on python 2.6, we set ssl_version to SSLv3.  This is
                # more narrow than we'd like since it also breaks
                # compatibility with servers configured for TLSv1 only,
                # but nearly all servers support SSLv3:
                # http://blog.ivanristic.com/2011/09/ssl-survey-protocol-support.html
                if sys.version_info >= (2, 7):
                    ssl_options["ciphers"] = "DEFAULT:!SSLv2"
                else:
                    # This is really only necessary for pre-1.0 versions
                    # of openssl, but python 2.6 doesn't expose version
                    # information.
                    ssl_options["ssl_version"] = ssl.PROTOCOL_SSLv3

                self.stream = SSLIOStream(
                    socket.socket(af, socktype, proto),
                    io_loop=self.io_loop,
                    ssl_options=ssl_options,
                    max_buffer_size=max_buffer_size,
                )
            else:
                self.stream = IOStream(
                    socket.socket(af, socktype, proto), io_loop=self.io_loop, max_buffer_size=max_buffer_size
                )
            timeout = min(request.connect_timeout, request.request_timeout)
            if timeout:
                self._timeout = self.io_loop.add_timeout(self.start_time + timeout, self._on_timeout)
            self.stream.set_close_callback(self._on_close)
            self.stream.connect(sockaddr, functools.partial(self._on_connect, parsed))

    def _on_timeout(self):
        self._timeout = None
        self._run_callback(
            HTTPResponse(self.request, 599, request_time=time.time() - self.start_time, error=HTTPError(599, "Timeout"))
        )
        self.stream.close()

    def _on_connect(self, parsed):
        if self._timeout is not None:
            self.io_loop.remove_timeout(self._timeout)
            self._timeout = None
        if self.request.request_timeout:
            self._timeout = self.io_loop.add_timeout(self.start_time + self.request.request_timeout, self._on_timeout)
        if self.request.validate_cert and isinstance(self.stream, SSLIOStream):
            match_hostname(self.stream.socket.getpeercert(), parsed.hostname)
        if self.request.method not in self._SUPPORTED_METHODS and not self.request.allow_nonstandard_methods:
            raise KeyError("unknown method %s" % self.request.method)
        for key in ("network_interface", "proxy_host", "proxy_port", "proxy_username", "proxy_password"):
            if getattr(self.request, key, None):
                raise NotImplementedError("%s not supported" % key)
        if "Host" not in self.request.headers:
            self.request.headers["Host"] = parsed.netloc
        username, password = None, None
        if parsed.username is not None:
            username, password = parsed.username, parsed.password
        elif self.request.auth_username is not None:
            username = self.request.auth_username
            password = self.request.auth_password or ""
        if username is not None:
            auth = utf8(username) + b(":") + utf8(password)
            self.request.headers["Authorization"] = b("Basic ") + base64.b64encode(auth)
        if self.request.user_agent:
            self.request.headers["User-Agent"] = self.request.user_agent
        if not self.request.allow_nonstandard_methods:
            if self.request.method in ("POST", "PUT"):
                assert self.request.body is not None
            else:
                assert self.request.body is None
        if self.request.body is not None:
            self.request.headers["Content-Length"] = str(len(self.request.body))
        if self.request.method == "POST" and "Content-Type" not in self.request.headers:
            self.request.headers["Content-Type"] = "application/x-www-form-urlencoded"
        if self.request.use_gzip:
            self.request.headers["Accept-Encoding"] = "gzip"
        req_path = (parsed.path or "/") + (("?" + parsed.query) if parsed.query else "")
        request_lines = [utf8("%s %s HTTP/1.1" % (self.request.method, req_path))]
        for k, v in self.request.headers.get_all():
            line = utf8(k) + b(": ") + utf8(v)
            if b("\n") in line:
                raise ValueError("Newline in header: " + repr(line))
            request_lines.append(line)
        self.stream.write(b("\r\n").join(request_lines) + b("\r\n\r\n"))
        if self.request.body is not None:
            self.stream.write(self.request.body)
        self.stream.read_until_regex(b("\r?\n\r?\n"), self._on_headers)

    def _release(self):
        if self.release_callback is not None:
            release_callback = self.release_callback
            self.release_callback = None
            release_callback()

    def _run_callback(self, response):
        self._release()
        if self.final_callback is not None:
            final_callback = self.final_callback
            self.final_callback = None
            final_callback(response)

    @contextlib.contextmanager
    def cleanup(self):
        try:
            yield
        except Exception as e:
            logging.warning("uncaught exception", exc_info=True)
            self._run_callback(HTTPResponse(self.request, 599, error=e, request_time=time.time() - self.start_time))

    def _on_close(self):
        self._run_callback(
            HTTPResponse(
                self.request, 599, request_time=time.time() - self.start_time, error=HTTPError(599, "Connection closed")
            )
        )

    def _on_headers(self, data):
        data = native_str(data.decode("latin1"))
        first_line, _, header_data = data.partition("\n")
        match = re.match("HTTP/1.[01] ([0-9]+)", first_line)
        assert match
        self.code = int(match.group(1))
        self.headers = HTTPHeaders.parse(header_data)

        if "Content-Length" in self.headers:
            if "," in self.headers["Content-Length"]:
                # Proxies sometimes cause Content-Length headers to get
                # duplicated.  If all the values are identical then we can
                # use them but if they differ it's an error.
                pieces = re.split(r",\s*", self.headers["Content-Length"])
                if any(i != pieces[0] for i in pieces):
                    raise ValueError("Multiple unequal Content-Lengths: %r" % self.headers["Content-Length"])
                self.headers["Content-Length"] = pieces[0]
            content_length = int(self.headers["Content-Length"])
        else:
            content_length = None

        if self.request.header_callback is not None:
            for k, v in self.headers.get_all():
                self.request.header_callback("%s: %s\r\n" % (k, v))

        if self.request.method == "HEAD":
            # HEAD requests never have content, even though they may have
            # content-length headers
            self._on_body(b(""))
            return
        if 100 <= self.code < 200 or self.code in (204, 304):
            # These response codes never have bodies
            # http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.3
            assert "Transfer-Encoding" not in self.headers
            assert content_length in (None, 0)
            self._on_body(b(""))
            return

        if self.request.use_gzip and self.headers.get("Content-Encoding") == "gzip":
            # Magic parameter makes zlib module understand gzip header
            # http://stackoverflow.com/questions/1838699/how-can-i-decompress-a-gzip-stream-with-zlib
            self._decompressor = zlib.decompressobj(16 + zlib.MAX_WBITS)
        if self.headers.get("Transfer-Encoding") == "chunked":
            self.chunks = []
            self.stream.read_until(b("\r\n"), self._on_chunk_length)
        elif content_length is not None:
            self.stream.read_bytes(content_length, self._on_body)
        else:
            self.stream.read_until_close(self._on_body)

    def _on_body(self, data):
        if self._timeout is not None:
            self.io_loop.remove_timeout(self._timeout)
            self._timeout = None
        original_request = getattr(self.request, "original_request", self.request)
        if self.request.follow_redirects and self.request.max_redirects > 0 and self.code in (301, 302, 303, 307):
            new_request = copy.copy(self.request)
            new_request.url = urllib.parse.urljoin(self.request.url, self.headers["Location"])
            new_request.max_redirects -= 1
            del new_request.headers["Host"]
            # http://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.3.4
            # client SHOULD make a GET request
            if self.code == 303:
                new_request.method = "GET"
                new_request.body = None
                for h in ["Content-Length", "Content-Type", "Content-Encoding", "Transfer-Encoding"]:
                    try:
                        del self.request.headers[h]
                    except KeyError:
                        pass
            new_request.original_request = original_request
            final_callback = self.final_callback
            self.final_callback = None
            self._release()
            self.client.fetch(new_request, final_callback)
            self.stream.close()
            return
        if self._decompressor:
            data = self._decompressor.decompress(data)
        if self.request.streaming_callback:
            if self.chunks is None:
                # if chunks is not None, we already called streaming_callback
                # in _on_chunk_data
                self.request.streaming_callback(data)
            buffer = BytesIO()
        else:
            buffer = BytesIO(data)  # TODO: don't require one big string?
        response = HTTPResponse(
            original_request,
            self.code,
            headers=self.headers,
            request_time=time.time() - self.start_time,
            buffer=buffer,
            effective_url=self.request.url,
        )
        self._run_callback(response)
        self.stream.close()
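
    # Chunked body framing: each chunk is "<hex-size>\r\n<data>\r\n"; a
    # zero-size chunk terminates the body.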

    def _on_chunk_length(self, data):
        # TODO: "chunk extensions" http://tools.ietf.org/html/rfc2616#section-3.6.1
        length = int(data.strip(), 16)
        if length == 0:
            # all the data has been decompressed, so we don't need to
            # decompress again in _on_body
            self._decompressor = None
            self._on_body(b("").join(self.chunks))
        else:
            self.stream.read_bytes(length + 2, self._on_chunk_data)  # chunk ends with \r\n

    def _on_chunk_data(self, data):
        assert data[-2:] == b("\r\n")
        chunk = data[:-2]
        if self._decompressor:
            chunk = self._decompressor.decompress(chunk)
        if self.request.streaming_callback is not None:
            self.request.streaming_callback(chunk)
        else:
            self.chunks.append(chunk)
        self.stream.read_until(b("\r\n"), self._on_chunk_length)
Example #49
0
File: connection.py Project: 736907871/zyl
class Connection(object):
    def __init__(self,
                 host='localhost',
                 port=6379,
                 unix_socket_path=None,
                 event_handler_proxy=None,
                 stop_after=None,
                 io_loop=None):
        self.host = host
        self.port = port
        self.unix_socket_path = unix_socket_path
        self._event_handler = event_handler_proxy
        self.timeout = stop_after
        self._stream = None
        self._io_loop = io_loop

        self.in_progress = False
        self.read_callbacks = set()
        self.ready_callbacks = deque()
        self._lock = 0
        self.info = {'db': 0, 'pass': None}

    def __del__(self):
        self.disconnect()

    def execute_pending_command(self):
        # Continue with the pending command execution
        # if all read operations are completed.
        if not self.read_callbacks and self.ready_callbacks:
            # Pop a SINGLE callback from the queue and execute it.
            # The next one will be executed from the code
            # invoked by the callback
            callback = self.ready_callbacks.popleft()
            callback()

    def ready(self):
        return (not self.read_callbacks and not self.ready_callbacks)

    def wait_until_ready(self, callback=None):
        if callback:
            if not self.ready():
                callback = stack_context.wrap(callback)
                self.ready_callbacks.append(callback)
            else:
                callback()

    def connect(self):
        if not self._stream:
            try:
                if self.unix_socket_path:
                    sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
                    sock.settimeout(self.timeout)
                    sock.connect(self.unix_socket_path)
                else:
                    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
                    sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
                    sock.settimeout(self.timeout)
                    sock.connect((self.host, self.port))
                self._stream = IOStream(sock, io_loop=self._io_loop)
                self._stream.set_close_callback(self.on_stream_close)
                self.info['db'] = 0
                self.info['pass'] = None
            except socket.error as e:
                raise ConnectionError(str(e))
            self.fire_event('on_connect')

    def on_stream_close(self):
        if self._stream:
            self.disconnect()
            callbacks = self.read_callbacks
            self.read_callbacks = set()
            for callback in callbacks:
                callback()

    def disconnect(self):
        if self._stream:
            s = self._stream
            self._stream = None
            try:
                if s.socket:
                    s.socket.shutdown(socket.SHUT_RDWR)
                s.close()
            except:
                pass

    def fire_event(self, event):
        event_handler = self._event_handler
        if event_handler:
            try:
                getattr(event_handler, event)()
            except AttributeError:
                pass

    def write(self, data, callback=None):
        if not self._stream:
            raise ConnectionError('Tried to write to '
                                  'non-existent connection')

        if callback:
            callback = stack_context.wrap(callback)
            _callback = lambda: callback(None)
            self.read_callbacks.add(_callback)
            cb = partial(self.read_callback, _callback)
        else:
            cb = None
        try:
            if PY3:
                data = bytes(data, encoding='utf-8')
            self._stream.write(data, callback=cb)
        except IOError as e:
            self.disconnect()
            raise ConnectionError(e.message)

    def read(self, length, callback=None):
        try:
            if not self._stream:
                self.disconnect()
                raise ConnectionError('Tried to read from '
                                      'non-existent connection')
            callback = stack_context.wrap(callback)
            self.read_callbacks.add(callback)
            self._stream.read_bytes(length,
                                    callback=partial(self.read_callback,
                                                     callback))
        except IOError:
            self.fire_event('on_disconnect')

    def read_callback(self, callback, *args, **kwargs):
        try:
            self.read_callbacks.remove(callback)
        except KeyError:
            pass
        callback(*args, **kwargs)

    def readline(self, callback=None):
        try:
            if not self._stream:
                self.disconnect()
                raise ConnectionError('Tried to read from '
                                      'non-existent connection')
            callback = stack_context.wrap(callback)
            self.read_callbacks.add(callback)
            callback = partial(self.read_callback, callback)
            self._stream.read_until(CRLF, callback=callback)
        except IOError:
            self.fire_event('on_disconnect')

    def connected(self):
        if self._stream:
            return True
        return False
Example #50
0
class AsyncRedisClient(object):
    """A non-blocking Redis client.

    Example usage::

        import ioloop

        def handle_request(result):
            print 'Redis reply: %r' % result
            ioloop.IOLoop.instance().stop()

        redis_client = AsyncRedisClient(('127.0.0.1', 6379))
        redis_client.fetch(('set', 'foo', 'bar'), None)
        redis_client.fetch(('get', 'foo'), handle_request)
        ioloop.IOLoop.instance().start()

    This class implements a Redis client on top of Tornado's IOStreams.
    It does not currently implement all applicable parts of the Redis
    specification, but it does enough to work with major redis server APIs
    (mostly tested against the LIST/HASH/PUBSUB API so far).

    This class has not been tested extensively in production and
    should be considered somewhat experimental as of the release of
    tornado 1.2.  It is intended to become the default tornado
    AsyncRedisClient implementation.
    """
    def __init__(self, address, io_loop=None, socket_timeout=10):
        """Creates an AsyncRedisClient.

        address is the redis server address tuple that IOStream can
        connect to, e.g. ('127.0.0.1', 6379).
        """
        self.address = address
        self.io_loop = io_loop or IOLoop.instance()
        self._callback_queue = deque()
        self._callback = None
        self._read_buffer = None
        self._result_queue = deque()
        self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.socket.settimeout(socket_timeout)
        self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.stream = IOStream(self.socket, self.io_loop)
        self.stream.connect(self.address, self._wait_result)

    def close(self):
        """Destroys this redis client, freeing any file descriptors used.
        Not needed in normal use, but may be helpful in unittests that
        create and destroy redis clients.  No other methods may be called
        on the AsyncRedisClient after close().
        """
        self.stream.close()

    def fetch(self, request, callback):
        """Executes a request, calling callback with a redis `result`.

        The request should be a tuple of strings, like ('set', 'foo', 'bar').

        If an error occurs during the fetch, a `RedisError` exception will
        be raised. You can use try...except to catch it (if any) in the
        callback.
        """
        self._callback_queue.append(callback)
        self.stream.write(encode(request))

    def _wait_result(self):
        """Read a completed result data from the redis server."""
        self._read_buffer = deque()
        self.stream.read_until('\r\n', self._on_read_first_line)

    def _maybe_callback(self):
        """Try call callback in _callback_queue when we read a redis result."""
        try:
            read_buffer = self._read_buffer
            callback = self._callback
            result_queue = self._result_queue
            callback_queue = self._callback_queue
            if result_queue:
                result_queue.append(read_buffer)
                read_buffer = result_queue.popleft()
            if callback_queue:
                callback = self._callback = callback_queue.popleft()
            if callback:
                callback(decode(read_buffer))
        except Exception:
            logging.error('Uncaught callback exception', exc_info=True)
            self.close()
            raise
        finally:
            self._wait_result()

    def _on_read_first_line(self, data):
        self._read_buffer.append(data)
        c = data[0]
        if c in ':+-':
            self._maybe_callback()
        elif c == '$':
            if data[:3] == '$-1':
                self._maybe_callback()
            else:
                length = int(data[1:])
                self.stream.read_bytes(length + 2, self._on_read_bulk_body)
        elif c == '*':
            if data[1] in '-0':
                self._maybe_callback()
            else:
                self._multibulk_number = int(data[1:])
                self.stream.read_until('\r\n',
                                       self._on_read_multibulk_bulk_head)

    def _on_read_bulk_body(self, data):
        self._read_buffer.append(data)
        self._maybe_callback()

    def _on_read_multibulk_bulk_head(self, data):
        self._read_buffer.append(data)
        c = data[0]
        if c == '$':
            length = int(data[1:])
            self.stream.read_bytes(length + 2,
                                   self._on_read_multibulk_bulk_body)
        else:
            self._maybe_callback()

    def _on_read_multibulk_bulk_body(self, data):
        self._read_buffer.append(data)
        self._multibulk_number -= 1
        if self._multibulk_number:
            self.stream.read_until('\r\n', self._on_read_multibulk_bulk_head)
        else:
            self._maybe_callback()
Example #51
0
class AsyncConn(event.EventedMixin):
    """
    Low level object representing a TCP connection to nsqd.

    When a message on this connection is requeued and the requeue delay
    has not been specified, it calculates the delay automatically by an
    increasing multiple of ``requeue_delay``.

    Generates the following events that can be listened to with
    :meth:`nsq.AsyncConn.on`:

     * ``connect``
     * ``close``
     * ``error``
     * ``identify``
     * ``identify_response``
     * ``auth``
     * ``auth_response``
     * ``heartbeat``
     * ``ready``
     * ``message``
     * ``response``
     * ``backoff``
     * ``resume``

    :param host: the host to connect to

    :param port: the port to connect to

    :param timeout: the timeout for read/write operations (in seconds)

    :param heartbeat_interval: the amount of time (in seconds) to negotiate
        with the connected producers to send heartbeats (requires nsqd 0.2.19+)

    :param requeue_delay: the base multiple used when calculating requeue delay
        (multiplied by # of attempts)

    :param tls_v1: enable TLS v1 encryption (requires nsqd 0.2.22+)

    :param tls_options: dictionary of options to pass to `ssl.wrap_socket()
        <http://docs.python.org/2/library/ssl.html#ssl.wrap_socket>`_ as
        ``**kwargs``

    :param snappy: enable Snappy stream compression (requires nsqd 0.2.23+)

    :param deflate: enable deflate stream compression (requires nsqd 0.2.23+)

    :param deflate_level: configure the deflate compression level for this
        connection (requires nsqd 0.2.23+)

    :param output_buffer_size: size of the buffer (in bytes) used by nsqd
        for buffering writes to this connection

    :param output_buffer_timeout: timeout (in ms) used by nsqd before
        flushing buffered writes (set to 0 to disable).  **Warning**:
        configuring clients with an extremely low (``< 25ms``)
        ``output_buffer_timeout`` has a significant effect on ``nsqd``
        CPU usage (particularly with ``> 50`` clients connected).

    :param sample_rate: take only a sample of the messages being sent
        to the client. Not setting this or setting it to 0 ensures you
        get all the messages destined for the client. A sample rate
        between 0 and 100 means the client will receive that percentage
        of the message traffic.
        (requires nsqd 0.2.25+)

    :param user_agent: a string identifying the agent for this client
        in the spirit of HTTP (default: ``<client_library_name>/<version>``)
        (requires nsqd 0.2.25+)

    :param auth_secret: a string passed when using nsq auth
        (requires nsqd 1.0+)

    :param msg_timeout: the amount of time (in seconds) that nsqd will wait
        before considering messages that have been delivered to this
        consumer timed out (requires nsqd 0.2.28+)

    :param hostname: a string identifying the host where this client runs
        (default: ``<hostname>``)
    """
    def __init__(self,
                 host,
                 port,
                 timeout=1.0,
                 heartbeat_interval=30,
                 requeue_delay=90,
                 tls_v1=False,
                 tls_options=None,
                 snappy=False,
                 deflate=False,
                 deflate_level=6,
                 user_agent=DEFAULT_USER_AGENT,
                 output_buffer_size=16 * 1024,
                 output_buffer_timeout=250,
                 sample_rate=0,
                 auth_secret=None,
                 msg_timeout=None,
                 hostname=None):
        assert isinstance(host, string_types)
        assert isinstance(port, int)
        assert isinstance(timeout, float)
        assert isinstance(tls_options, (dict, type(None)))
        assert isinstance(deflate_level, int)
        assert isinstance(heartbeat_interval, int) and heartbeat_interval >= 1
        assert isinstance(requeue_delay, int) and requeue_delay >= 0
        assert isinstance(output_buffer_size, int) and output_buffer_size >= 0
        assert isinstance(output_buffer_timeout,
                          int) and output_buffer_timeout >= 0
        assert isinstance(sample_rate,
                          int) and sample_rate >= 0 and sample_rate < 100
        assert msg_timeout is None or (isinstance(msg_timeout, (float, int))
                                       and msg_timeout > 0)
        # auth_secret validated by to_bytes() below

        self.state = INIT
        self.host = host
        self.port = port
        self.timeout = timeout
        self.last_recv_timestamp = time.time()
        self.last_msg_timestamp = time.time()
        self.in_flight = 0
        self.rdy = 0
        self.rdy_timeout = None
        # for backwards compatibility when interacting with older nsqd
        # (pre 0.2.20), default this to their hard-coded max
        self.max_rdy_count = 2500
        self.tls_v1 = tls_v1
        self.tls_options = tls_options
        self.snappy = snappy
        self.deflate = deflate
        self.deflate_level = deflate_level
        self.hostname = hostname
        if self.hostname is None:
            self.hostname = socket.gethostname()
        self.short_hostname = self.hostname.split('.')[0]
        self.heartbeat_interval = heartbeat_interval * 1000
        self.msg_timeout = int(msg_timeout * 1000) if msg_timeout else None
        self.requeue_delay = requeue_delay

        self.output_buffer_size = output_buffer_size
        self.output_buffer_timeout = output_buffer_timeout
        self.sample_rate = sample_rate
        self.user_agent = user_agent

        self._authentication_required = False  # tracking server auth state
        self.auth_secret = to_bytes(auth_secret) if auth_secret else None

        self.socket = None
        self.stream = None
        self._features_to_enable = []

        self.last_rdy = 0
        self.rdy = 0

        self.callback_queue = []
        self.encoder = DefaultEncoder()

        super(AsyncConn, self).__init__()

    @property
    def id(self):
        return str(self)

    def __str__(self):
        return self.host + ':' + str(self.port)

    def connected(self):
        return self.state == CONNECTED

    def connecting(self):
        return self.state == CONNECTING

    def closed(self):
        return self.state in (INIT, DISCONNECTED)

    def connect(self):
        if not self.closed():
            return

        # Assume host is an ipv6 address if it has a colon.
        if ':' in self.host:
            self.socket = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
        else:
            self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

        self.socket.settimeout(self.timeout)
        # IOStream needs a non-blocking socket; this overrides the
        # settimeout() above for all reads and writes on the stream.
        self.socket.setblocking(0)

        self.stream = IOStream(self.socket)
        self.stream.set_close_callback(self._socket_close)
        self.stream.set_nodelay(True)

        self.state = CONNECTING
        self.on(event.CONNECT, self._on_connect)
        self.on(event.DATA, self._on_data)

        fut = self.stream.connect((self.host, self.port))
        IOLoop.current().add_future(fut, self._connect_callback)

    def _connect_callback(self, fut):
        fut.result()
        self.state = CONNECTED
        self.stream.write(protocol.MAGIC_V2)
        self._start_read()
        self.trigger(event.CONNECT, conn=self)

    def _read_bytes(self, size, callback):
        try:
            fut = self.stream.read_bytes(size)
            IOLoop.current().add_future(fut, callback)
        except IOError:
            self.close()
            self.trigger(
                event.ERROR,
                conn=self,
                error=protocol.ConnectionClosedError('Stream is closed'),
            )

    def _start_read(self):
        if self.stream is None:
            return  # IOStream.start_tls() invalidates stream, will call again when ready
        self._read_bytes(4, self._read_size)

    def _socket_close(self):
        self.state = DISCONNECTED
        self.trigger(event.CLOSE, conn=self)

    def close(self):
        self.stream.close()

    def _read_size(self, fut):
        try:
            data = fut.result()
            size = struct_l.unpack(data)[0]
        except Exception:
            self.close()
            self.trigger(
                event.ERROR,
                conn=self,
                error=protocol.IntegrityError('failed to unpack size'),
            )
            return
        self._read_bytes(size, self._read_body)

    def _read_body(self, fut):
        try:
            data = fut.result()
            self.trigger(event.DATA, conn=self, data=data)
        except Exception:
            logger.exception('uncaught exception in data event')
        self._start_read()

    def send(self, data):
        return self.stream.write(self.encoder.encode(data))

    def upgrade_to_tls(self, options=None):
        # in order to upgrade to TLS we need to *replace* the IOStream...
        opts = {
            'cert_reqs': ssl.CERT_REQUIRED,
            'ssl_version': ssl.PROTOCOL_TLSv1_2
        }
        opts.update(options or {})

        fut = self.stream.start_tls(False,
                                    ssl_options=opts,
                                    server_hostname=self.host)
        self.stream = None

        def finish_upgrade_tls(fut):
            try:
                self.stream = fut.result()
                self.socket = self.stream.socket
                self._start_read()
            except Exception as e:
                # skip self.close() because no stream
                self.trigger(
                    event.ERROR,
                    conn=self,
                    error=protocol.SendError('failed to upgrade to TLS', e),
                )

        IOLoop.current().add_future(fut, finish_upgrade_tls)

    def upgrade_to_snappy(self):
        assert SnappySocket, 'snappy requires the python-snappy package'

        # in order to upgrade to Snappy we need to use whatever IOStream
        # is currently in place (normal or SSL)...
        #
        # first read any compressed bytes the existing IOStream might have
        # already buffered and use that to bootstrap the SnappySocket, then
        # monkey patch the existing IOStream by replacing its socket
        # with a wrapper that will automagically handle compression.
        existing_data = self.stream._consume(self.stream._read_buffer_size)
        self.socket = SnappySocket(self.socket)
        self.socket.bootstrap(existing_data)
        self.stream.socket = self.socket
        self.encoder = SnappyEncoder()

    def upgrade_to_deflate(self):
        # in order to upgrade to DEFLATE we need to use whatever IOStream
        # is currently in place (normal or SSL)...
        #
        # first read any compressed bytes the existing IOStream might have
        # already buffered and use that to bootstrap the DeflateSocket, then
        # monkey patch the existing IOStream by replacing its socket
        # with a wrapper that will automagically handle compression.
        existing_data = self.stream._consume(self.stream._read_buffer_size)
        self.socket = DeflateSocket(self.socket, self.deflate_level)
        self.socket.bootstrap(existing_data)
        self.stream.socket = self.socket
        self.encoder = DeflateEncoder(level=self.deflate_level)

    def send_rdy(self, value):
        if self.last_rdy != value:
            try:
                self.send(protocol.ready(value))
            except Exception as e:
                self.close()
                self.trigger(
                    event.ERROR,
                    conn=self,
                    error=protocol.SendError('failed to send RDY %d' % value,
                                             e),
                )
                return False
        self.last_rdy = value
        self.rdy = value
        return True

    def _on_connect(self, **kwargs):
        identify_data = {
            # short_id/long_id: TODO remove when deprecating pre-1.0 support
            'short_id': self.short_hostname,
            'long_id': self.hostname,
            'client_id': self.short_hostname,
            'hostname': self.hostname,
            'heartbeat_interval': self.heartbeat_interval,
            'feature_negotiation': True,
            'tls_v1': self.tls_v1,
            'snappy': self.snappy,
            'deflate': self.deflate,
            'deflate_level': self.deflate_level,
            'output_buffer_timeout': self.output_buffer_timeout,
            'output_buffer_size': self.output_buffer_size,
            'sample_rate': self.sample_rate,
            'user_agent': self.user_agent
        }
        if self.msg_timeout:
            identify_data['msg_timeout'] = self.msg_timeout
        self.trigger(event.IDENTIFY, conn=self, data=identify_data)
        self.on(event.RESPONSE, self._on_identify_response)
        try:
            self.send(protocol.identify(identify_data))
        except Exception as e:
            self.close()
            self.trigger(
                event.ERROR,
                conn=self,
                error=protocol.SendError('failed to bootstrap connection', e),
            )

    def _on_identify_response(self, data, **kwargs):
        self.off(event.RESPONSE, self._on_identify_response)

        if data == b'OK':
            logger.warning(
                'nsqd version does not support feature negotiation')
            return self.trigger(event.READY, conn=self)

        try:
            data = json.loads(data.decode('utf-8'))
        except ValueError:
            self.close()
            self.trigger(
                event.ERROR,
                conn=self,
                error=protocol.IntegrityError(
                    'failed to parse IDENTIFY response JSON from nsqd - %r' %
                    data),
            )
            return

        self.trigger(event.IDENTIFY_RESPONSE, conn=self, data=data)

        if self.tls_v1 and data.get('tls_v1'):
            self._features_to_enable.append('tls_v1')
        if self.snappy and data.get('snappy'):
            self._features_to_enable.append('snappy')
        if self.deflate and data.get('deflate'):
            self._features_to_enable.append('deflate')

        if data.get('auth_required'):
            self._authentication_required = True

        if data.get('max_rdy_count'):
            self.max_rdy_count = data.get('max_rdy_count')
        else:
            # for backwards compatibility when interacting with older nsqd
            # (pre 0.2.20), default this to their hard-coded max
            logger.warning('setting max_rdy_count to default value of 2500')
            self.max_rdy_count = 2500

        self.on(event.RESPONSE, self._on_response_continue)
        self._on_response_continue(conn=self, data=None)

    def _on_response_continue(self, data, **kwargs):
        if self._features_to_enable:
            feature = self._features_to_enable.pop(0)
            if feature == 'tls_v1':
                self.upgrade_to_tls(self.tls_options)
            elif feature == 'snappy':
                self.upgrade_to_snappy()
            elif feature == 'deflate':
                self.upgrade_to_deflate()
            # the server will 'OK' after these connection upgrades triggering another response
            return

        self.off(event.RESPONSE, self._on_response_continue)
        if self.auth_secret and self._authentication_required:
            self.on(event.RESPONSE, self._on_auth_response)
            self.trigger(event.AUTH, conn=self, data=self.auth_secret)
            try:
                self.send(protocol.auth(self.auth_secret))
            except Exception as e:
                self.close()
                self.trigger(
                    event.ERROR,
                    conn=self,
                    error=protocol.SendError('Error sending AUTH', e),
                )
            return
        self.trigger(event.READY, conn=self)

    def _on_auth_response(self, data, **kwargs):
        try:
            data = json.loads(data.decode('utf-8'))
        except ValueError:
            self.close()
            self.trigger(
                event.ERROR,
                conn=self,
                error=protocol.IntegrityError(
                    'failed to parse AUTH response JSON from nsqd - %r' %
                    data),
            )
            return

        self.off(event.RESPONSE, self._on_auth_response)
        self.trigger(event.AUTH_RESPONSE, conn=self, data=data)
        return self.trigger(event.READY, conn=self)

    def _on_data(self, data, **kwargs):
        self.last_recv_timestamp = time.time()
        frame, data = protocol.unpack_response(data)
        if frame == protocol.FRAME_TYPE_MESSAGE:
            self.last_msg_timestamp = time.time()
            self.in_flight += 1

            message = protocol.decode_message(data)
            message.on(event.FINISH, self._on_message_finish)
            message.on(event.REQUEUE, self._on_message_requeue)
            message.on(event.TOUCH, self._on_message_touch)

            self.trigger(event.MESSAGE, conn=self, message=message)
        elif frame == protocol.FRAME_TYPE_RESPONSE and data == b'_heartbeat_':
            self.send(protocol.nop())
            self.trigger(event.HEARTBEAT, conn=self)
        elif frame == protocol.FRAME_TYPE_RESPONSE:
            self.trigger(event.RESPONSE, conn=self, data=data)
        elif frame == protocol.FRAME_TYPE_ERROR:
            self.trigger(event.ERROR, conn=self, error=protocol.Error(data))

    def _on_message_requeue(self, message, backoff=True, time_ms=-1, **kwargs):
        if backoff:
            self.trigger(event.BACKOFF, conn=self)
        else:
            self.trigger(event.CONTINUE, conn=self)

        self.in_flight -= 1
        try:
            # with no explicit delay, back off linearly: requeue_delay
            # (seconds) * number of attempts, converted to milliseconds
            if time_ms < 0:
                time_ms = self.requeue_delay * message.attempts * 1000
            self.send(protocol.requeue(message.id, time_ms))
        except Exception as e:
            self.close()
            self.trigger(event.ERROR,
                         conn=self,
                         error=protocol.SendError(
                             'failed to send REQ %s @ %d' %
                             (message.id, time_ms), e))

    def _on_message_finish(self, message, **kwargs):
        self.trigger(event.RESUME, conn=self)

        self.in_flight -= 1
        try:
            self.send(protocol.finish(message.id))
        except Exception as e:
            self.close()
            self.trigger(
                event.ERROR,
                conn=self,
                error=protocol.SendError('failed to send FIN %s' % message.id,
                                         e),
            )

    def _on_message_touch(self, message, **kwargs):
        try:
            self.send(protocol.touch(message.id))
        except Exception as e:
            self.close()
            self.trigger(
                event.ERROR,
                conn=self,
                error=protocol.SendError(
                    'failed to send TOUCH %s' % message.id, e),
            )
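
A minimal, hypothetical wiring of AsyncConn follows; in practice a higher-level reader (e.g. nsq.Reader) manages connections for you. The event constants match the list in the class docstring, and protocol.subscribe / message.finish are assumed from the surrounding pynsq-style module.

def on_ready(conn, **kwargs):
    # subscribe, then ask nsqd for one in-flight message at a time
    conn.send(protocol.subscribe('topic', 'channel'))
    conn.send_rdy(1)

def on_message(conn, message, **kwargs):
    print(message.body)
    message.finish()  # sends FIN back to nsqd via _on_message_finish

conn = AsyncConn('127.0.0.1', 4150)
conn.on(event.READY, on_ready)
conn.on(event.MESSAGE, on_message)
conn.connect()
IOLoop.current().start()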
Example #52
class Connection(object):

    def __init__(self, host='localhost', port=11211, pool=None):
        self._host = host
        self._port = port
        self._pool = pool
        self._socket = None
        self._stream = None
        self._final_callback = None  # set per-request in send_command()
        self._ioloop = IOLoop.instance()
        self.connect()

    def connect(self):
        try:       
            self._socket = socket(AF_INET, SOCK_STREAM, 0)
            self._socket.connect((self._host, self._port))
            self._stream = IOStream(self._socket, io_loop=self._ioloop)
            self._stream.set_close_callback(self.on_disconnect)
        except error as e:
            raise ConnectionError(e)

    def disconnect(self):
        callback = self._final_callback
        self._final_callback = None
        try:
            if callback:
                callback(None)
        finally:
            self._stream._close_callback = None
            self._stream.close()

    def on_disconnect(self):
        callback = self._final_callback
        self._final_callback = None
        try:
            if callback:
                callback(None)
        finally:
            logging.debug('asyncmemcached closing connection')
            self._pool.release(self)

    def closed(self):
        return self._stream.closed()

    def send_command(self, fullcmd, expect_str, callback):
        self._final_callback = callback
        if self._stream.closed():
            self.connect()
        with stack_context.StackContext(self.cleanup):
            if fullcmd[0:3] == 'get' or \
                    fullcmd[0:4] == 'incr' or \
                    fullcmd[0:4] == 'decr':
                self._stream.write(fullcmd, self.read_value)
            else:
                self._stream.write(fullcmd,
                        functools.partial(self.read_response, expect_str))
    
    def read_response(self, expect_str):
        self._stream.read_until('\r\n',
                                functools.partial(self._expect_callback,
                                                  expect_str))

    def read_value(self):
        self._stream.read_until('\r\n', self._expect_value_header_callback)

    def _expect_value_header_callback(self, response):
        response = response[:-2]

        if response[:5] == 'VALUE':
            resp, key, flag, length = response.split()
            flag = int(flag)
            length = int(length)
            self._stream.read_bytes(length+2, self._expect_value_callback)
        elif response.isdigit():
            try:
                callback = self._final_callback
                self._final_callback = None
                if callback:
                    callback(int(response))
            finally:
                self._pool.release(self)
        else:
            try:
                callback = self._final_callback
                self._final_callback = None
                if callback:
                    callback(None)
            finally:
                self._pool.release(self)

    def _expect_value_callback(self, value):
        # strip the trailing \r\n included by read_bytes(length + 2)
        value = value[:-2]
        self._stream.read_until('\r\n',
                functools.partial(self._end_value_callback, value))

    def _end_value_callback(self, value, response):
        response = response.rstrip('\r\n')

        if response == 'END':
            try:
                callback = self._final_callback
                self._final_callback = None
                if callback:
                    callback(value)
            finally:
                self._pool.release(self)
        else:
            raise RedisError('error %s' % response)

    def _expect_callback(self, expect_str, response):
        response = response.rstrip('\r\n')
        if response == expect_str:
            try:
                callback = self._final_callback
                self._final_callback = None
                if callback:
                    callback(None)
            finally:
                self._pool.release(self)
        else:
            raise RedisError('error %s' % response)

    @contextlib.contextmanager
    def cleanup(self):
        try:
            yield
        except Exception as e:
            logging.warning("uncaught exception", exc_info=True)
            try:
                callback = self._final_callback
                self._final_callback = None
                if callback:
                    callback(None)
            finally:
                self._pool.release(self)
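
The read_value() path above follows the memcached text protocol's retrieval format: a 'VALUE <key> <flags> <bytes>' header line, the payload plus '\r\n', then an 'END' trailer. A standalone illustration of the bytes each callback consumes, parsed here with plain string slicing rather than an IOStream:

server_reply = ('VALUE k 0 5\r\n'   # read_until -> _expect_value_header_callback
                'hello\r\n'         # read_bytes(5 + 2) -> _expect_value_callback
                'END\r\n')          # read_until -> _end_value_callback

header, rest = server_reply.split('\r\n', 1)
_, key, flags, length = header.split()
value = rest[:int(length)]
assert (key, int(flags), value) == ('k', 0, 'hello')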
Example #53
class KeepAliveTest(AsyncHTTPTestCase):
    """Tests various scenarios for HTTP 1.1 keep-alive support.

    These tests don't use AsyncHTTPClient because we want to control
    connection reuse and closing.
    """
    def get_app(self):
        class HelloHandler(RequestHandler):
            def get(self):
                self.finish('Hello world')

            def post(self):
                self.finish('Hello world')

        class LargeHandler(RequestHandler):
            def get(self):
                # 512KB should be bigger than the socket buffers so it will
                # be written out in chunks.
                self.write(''.join(chr(i % 256) * 1024 for i in range(512)))

        class FinishOnCloseHandler(RequestHandler):
            @asynchronous
            def get(self):
                self.flush()

            def on_connection_close(self):
                # This is not very realistic, but finishing the request
                # from the close callback has the right timing to mimic
                # some errors seen in the wild.
                self.finish('closed')

        return Application([('/', HelloHandler),
                            ('/large', LargeHandler),
                            ('/finish_on_close', FinishOnCloseHandler)])

    def setUp(self):
        super(KeepAliveTest, self).setUp()
        self.http_version = b'HTTP/1.1'

    def tearDown(self):
        # We just closed the client side of the socket; let the IOLoop run
        # once to make sure the server side got the message.
        self.io_loop.add_timeout(datetime.timedelta(seconds=0.001), self.stop)
        self.wait()

        if hasattr(self, 'stream'):
            self.stream.close()
        super(KeepAliveTest, self).tearDown()

    # The next few methods are a crude manual http client
    def connect(self):
        self.stream = IOStream(socket.socket())
        self.stream.connect(('127.0.0.1', self.get_http_port()), self.stop)
        self.wait()

    def read_headers(self):
        self.stream.read_until(b'\r\n', self.stop)
        first_line = self.wait()
        self.assertTrue(first_line.startswith(b'HTTP/1.1 200'), first_line)
        self.stream.read_until(b'\r\n\r\n', self.stop)
        header_bytes = self.wait()
        headers = HTTPHeaders.parse(header_bytes.decode('latin1'))
        return headers

    def read_response(self):
        self.headers = self.read_headers()
        self.stream.read_bytes(int(self.headers['Content-Length']), self.stop)
        body = self.wait()
        self.assertEqual(b'Hello world', body)

    def close(self):
        self.stream.close()
        del self.stream

    def test_two_requests(self):
        self.connect()
        self.stream.write(b'GET / HTTP/1.1\r\n\r\n')
        self.read_response()
        self.stream.write(b'GET / HTTP/1.1\r\n\r\n')
        self.read_response()
        self.close()

    def test_request_close(self):
        self.connect()
        self.stream.write(b'GET / HTTP/1.1\r\nConnection: close\r\n\r\n')
        self.read_response()
        self.stream.read_until_close(callback=self.stop)
        data = self.wait()
        self.assertTrue(not data)
        self.assertEqual(self.headers['Connection'], 'close')
        self.close()

    # keepalive is supported for http 1.0 too, but it's opt-in
    def test_http10(self):
        self.http_version = b'HTTP/1.0'
        self.connect()
        self.stream.write(b'GET / HTTP/1.0\r\n\r\n')
        self.read_response()
        self.stream.read_until_close(callback=self.stop)
        data = self.wait()
        self.assertTrue(not data)
        self.assertTrue('Connection' not in self.headers)
        self.close()

    def test_http10_keepalive(self):
        self.http_version = b'HTTP/1.0'
        self.connect()
        self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n')
        self.read_response()
        self.assertEqual(self.headers['Connection'], 'Keep-Alive')
        self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n')
        self.read_response()
        self.assertEqual(self.headers['Connection'], 'Keep-Alive')
        self.close()

    def test_http10_keepalive_extra_crlf(self):
        self.http_version = b'HTTP/1.0'
        self.connect()
        self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n\r\n')
        self.read_response()
        self.assertEqual(self.headers['Connection'], 'Keep-Alive')
        self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n')
        self.read_response()
        self.assertEqual(self.headers['Connection'], 'Keep-Alive')
        self.close()

    def test_pipelined_requests(self):
        self.connect()
        self.stream.write(b'GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\n')
        self.read_response()
        self.read_response()
        self.close()

    def test_pipelined_cancel(self):
        self.connect()
        self.stream.write(b'GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\n')
        # only read once
        self.read_response()
        self.close()

    def test_cancel_during_download(self):
        self.connect()
        self.stream.write(b'GET /large HTTP/1.1\r\n\r\n')
        self.read_headers()
        self.stream.read_bytes(1024, self.stop)
        self.wait()
        self.close()

    def test_finish_while_closed(self):
        self.connect()
        self.stream.write(b'GET /finish_on_close HTTP/1.1\r\n\r\n')
        self.read_headers()
        self.close()

    def test_keepalive_chunked(self):
        self.http_version = b'HTTP/1.0'
        self.connect()
        self.stream.write(b'POST / HTTP/1.0\r\n'
                          b'Connection: keep-alive\r\n'
                          b'Transfer-Encoding: chunked\r\n'
                          b'\r\n'
                          b'0\r\n'
                          b'\r\n')
        self.read_response()
        self.assertEqual(self.headers['Connection'], 'Keep-Alive')
        self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n')
        self.read_response()
        self.assertEqual(self.headers['Connection'], 'Keep-Alive')
        self.close()
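
For reference, the HTTP/1.0 opt-in these tests exercise, as raw bytes (response headers abridged; the exact set depends on the server, but the asserts above pin the status line and Connection header):

request = (b'GET / HTTP/1.0\r\n'
           b'Connection: keep-alive\r\n'
           b'\r\n')
# Expected reply shape:
#   HTTP/1.1 200 OK
#   Connection: Keep-Alive
#   Content-Length: 11
#
#   Hello world
# Without that Connection header, an HTTP/1.0 client must treat the
# connection as closed once the body has been read.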
Example #54
class TTornadoTransport(TTransportBase):
  """A non-blocking Thrift client.

  Example usage::

    import greenlet
    from tornado import ioloop
    from thrift.transport import TTransport
    from thrift.protocol import TBinaryProtocol

    from viewfinder.backend.thrift import TTornadoTransport

    transport = TTransport.TFramedTransport(TTornadoTransport('localhost', 9090))
    protocol = TBinaryProtocol.TBinaryProtocol(transport)
    client = Service.Client(protocol)
    ioloop.IOLoop.instance().start()

  Then, from within an asynchronous tornado request handler::

    class MyApp(tornado.web.RequestHandler):
      @tornado.web.asynchronous
      def post(self):
        def business_logic():
          ...any thrift calls...
          self.write(...stuff that gets returned to client...)
          self.finish()  # end the asynchronous request
        gr = greenlet.greenlet(business_logic)
        gr.switch()
  """

  def __init__(self, host='localhost', port=9090):
    """Initialize a TTornadoTransport with a Tornado IOStream.

    @param host(str) The host to connect to.
    @param port(int) The (TCP) port to connect to.
    """
    self.host = host
    self.port = port
    self._stream = None
    self._io_loop = ioloop.IOLoop.current()
    self._timeout_secs = None

  def set_timeout(self, timeout_secs):
    """Sets a timeout for use with open/read/write operations."""
    self._timeout_secs = timeout_secs

  def isOpen(self):
    return self._stream is not None

  def open(self):
    """Creates a connection to host:port and spins up a tornado
    IOStream object to write requests and read responses from the
    thrift server. After making the asynchronous connect call to
    _stream, the current greenlet yields control back to the parent
    greenlet (presumably the "master" greenlet).
    """
    assert greenlet.getcurrent().parent is not None
    # TODO(spencer): allow ipv6? (af = socket.AF_UNSPEC)
    addrinfo = socket.getaddrinfo(self.host, self.port, socket.AF_INET,
                                  socket.SOCK_STREAM, 0, 0)
    af, socktype, proto, canonname, sockaddr = addrinfo[0]
    self._stream = IOStream(socket.socket(af, socktype, proto),
                            io_loop=self._io_loop)
    self._open_internal(sockaddr)

  def close(self):
    if self._stream:
      self._stream.set_close_callback(None)
      self._stream.close()
      self._stream = None

  @_wrap_transport
  def read(self, sz):
    logging.debug("reading %d bytes from %s:%d" % (sz, self.host, self.port))
    cur_gr = greenlet.getcurrent()
    def _on_read(buf):
      if self._stream:
        cur_gr.switch(buf)
    self._stream.read_bytes(sz, _on_read)
    buf = cur_gr.parent.switch()
    if len(buf) == 0:
      raise TTransportException(type=TTransportException.END_OF_FILE,
                                message='TTornadoTransport read 0 bytes')
    logging.debug("read %d bytes in %.2fms" %
                  (len(buf), (time.time() - self._start_time) * 1000))
    return buf

  @_wrap_transport
  def write(self, buf):
    logging.debug("writing %d bytes to %s:%d" % (len(buf), self.host, self.port))
    cur_gr = greenlet.getcurrent()
    def _on_write():
      if self._stream:
        cur_gr.switch()
    self._stream.write(buf, _on_write)
    cur_gr.parent.switch()
    logging.debug("wrote %d bytes in %.2fms" %
                  (len(buf), (time.time() - self._start_time) * 1000))

  @_wrap_transport
  def flush(self):
    pass

  @_wrap_transport
  def _open_internal(self, sockaddr):
    logging.debug("opening connection to %s:%d" % (self.host, self.port))
    cur_gr = greenlet.getcurrent()
    def _on_connect():
      if self._stream:
        cur_gr.switch()
    self._stream.connect(sockaddr, _on_connect)
    cur_gr.parent.switch()
    logging.info("opened connection to %s:%d" % (self.host, self.port))

  def _check_stream(self):
    if not self._stream:
      raise TTransportException(
        type=TTransportException.NOT_OPEN, message='transport not open')

  def _set_timeout(self):
    if self._timeout_secs:
      return self._io_loop.add_timeout(
        time.time() + self._timeout_secs, functools.partial(
          self._on_timeout, gr=greenlet.getcurrent()))
    return None

  def _clear_timeout(self, timeout):
    if timeout:
      self._io_loop.remove_timeout(timeout)

  def _on_timeout(self, gr):
    gr.throw(TTransportException(
        type=TTransportException.TIMED_OUT,
        message="connection timed out to %s:%d" % (self.host, self.port)))

  def _on_close(self, gr):
    self._stream = None
    message = "connection to %s:%d closed" % (self.host, self.port)
    if gr:
      gr.throw(TTransportException(
          type=TTransportException.NOT_OPEN, message=message))
    else:
      logging.error(message)
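
A hypothetical end-to-end sketch tying the pieces together: set_timeout() applies to open/read/write, and open() must run inside a child greenlet so it can yield to the IOLoop while connecting. Service and its some_call() stand in for thrift-generated code.

def business_logic():
    inner = TTornadoTransport('localhost', 9090)
    inner.set_timeout(5.0)  # seconds, enforced via _set_timeout/_on_timeout
    transport = TTransport.TFramedTransport(inner)
    transport.open()        # yields here until the connection completes
    protocol = TBinaryProtocol.TBinaryProtocol(transport)
    client = Service.Client(protocol)
    result = client.some_call()  # placeholder generated method
    transport.close()

greenlet.greenlet(business_logic).switch()
ioloop.IOLoop.current().start()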