Beispiel #1
0
class ForwardConnection(object):
    """One-shot proxy: forwards captured request headers to a remote host
    and relays the remote's full response back to the client stream."""

    def __init__(self, remote_address, stream, address, headers):
        self.remote_address = remote_address
        self.stream = stream
        self.address = address
        self.headers = headers
        remote_sock = socket.socket()
        self.remote_stream = IOStream(remote_sock)
        self.remote_stream.connect(self.remote_address, self._on_remote_connected)
        self.remote_stream.set_close_callback(self._on_close)

    def _on_remote_connected(self):
        # Remote side is up: push the buffered request headers across.
        logging.info('forward %r to %r', self.address, self.remote_address)
        self.remote_stream.write(self.headers, self._on_remote_write_complete)

    def _on_remote_write_complete(self):
        # Request sent; collect the whole response until the remote closes.
        logging.info('send request to %s', self.remote_address)
        self.remote_stream.read_until_close(self._on_remote_read_close)

    def _on_remote_read_close(self, data):
        # Relay the response to the client, then close the client stream.
        self.stream.write(data, self.stream.close)

    def _on_close(self):
        # Remote hung up; make sure our side of the remote stream is closed.
        logging.info('remote quit %s', self.remote_address)
        self.remote_stream.close()
Beispiel #2
0
class ForwardConnection(object):
    """Bidirectional TCP relay between a client stream and a remote host."""

    def __init__(self, remote_address, stream, address):
        self.remote_address = remote_address
        self.stream = stream
        self.address = address
        self.remote_stream = IOStream(socket.socket())
        self.remote_stream.connect(self.remote_address, self._on_remote_connected)

    def _on_remote_connected(self):
        logging.info("forward %r to %r", self.address, self.remote_address)
        # Pump bytes in both directions; the first callback of each
        # read_until_close receives any trailing data when that side closes.
        self.remote_stream.read_until_close(self._on_remote_read_close, self.stream.write)
        self.stream.read_until_close(self._on_read_close, self.remote_stream.write)

    def _on_remote_read_close(self, data):
        # Remote side closed.  If a write to the client is still in flight,
        # queue the trailing data and close afterwards; otherwise close now.
        if not self.stream.writing():
            self.stream.close()
        else:
            self.stream.write(data, self.stream.close)

    def _on_read_close(self, data):
        # Client side closed; mirror image of _on_remote_read_close.
        if not self.remote_stream.writing():
            self.remote_stream.close()
        else:
            self.remote_stream.write(data, self.remote_stream.close)
Beispiel #3
0
    def test_read_until_close(self):
        # Issue a bare HTTP/1.0 request over a raw socket and collect the
        # entire response in a single read_until_close callback.
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        sock.connect(("localhost", self.get_http_port()))
        stream = IOStream(sock, io_loop=self.io_loop)
        stream.write(b("GET / HTTP/1.0\r\n\r\n"))
        stream.read_until_close(self.stop)

        data = self.wait()
        # The response should contain the status line and end with the body.
        self.assertTrue(data.startswith(b("HTTP/1.0 200")))
        self.assertTrue(data.endswith(b("Hello")))
Beispiel #4
0
class RemoteUpstream(Upstream):

    """Upstream connection to a remote host.

    Most methods are the same as in LocalUpstream, but they may need to
    diverge in the future.
    """

    def initialize(self):
        # Create the outgoing socket and wrap it in an IOStream so we get
        # asynchronous connect/read/write plus a close notification.
        self.socket = socket.socket(self._address_type, socket.SOCK_STREAM)
        self.stream = IOStream(self.socket)
        self.stream.set_close_callback(self.on_close)

    def do_connect(self):
        self.stream.connect(self.dest, self.on_connect)

    @property
    def address(self):
        # Local (host, port) of the outgoing socket.
        return self.socket.getsockname()

    @property
    def address_type(self):
        return self._address_type

    def on_connect(self):
        # Notify the owner, then start streaming incoming data.  The partial
        # marks the final chunk (delivered at close time) as finished.
        self.connection_callback(self)
        on_finish = functools.partial(self.on_streaming_data, finished=True)
        self.stream.read_until_close(on_finish, self.on_streaming_data)

    def on_close(self):
        # Distinguish an error close from an orderly close for the callbacks.
        if self.stream.error:
            self.error_callback(self, self.stream.error)
        else:
            self.close_callback(self)

    def on_streaming_data(self, data, finished=False):
        # Skip empty chunks: read_until_close may deliver b"" at the end.
        if len(data):
            self.streaming_callback(self, data)

    def do_write(self, data):
        try:
            self.stream.write(data)
        except IOError:
            # The stream is already broken; tear the upstream down.
            self.close()

    def do_close(self):
        if self.socket:
            # Lazy %-formatting: address is a (host, port) tuple.
            logger.info("close upstream: %s:%s", *self.address)
            self.stream.close()
Beispiel #5
0
class UnixSocketTest(AsyncTestCase):
    """HTTPServers can listen on Unix sockets too.

    Why would you want to do this?  Nginx can proxy to backends listening
    on unix sockets, for one thing (and managing a namespace for unix
    sockets can be easier than managing a bunch of TCP port numbers).

    Unfortunately, there's no way to specify a unix socket in a url for
    an HTTP client, so we have to test this by hand.
    """
    def setUp(self):
        super(UnixSocketTest, self).setUp()
        # Unix socket paths live on the filesystem; a private tmpdir keeps
        # parallel test runs from colliding.
        self.tmpdir = tempfile.mkdtemp()
        self.sockfile = os.path.join(self.tmpdir, "test.sock")
        sock = netutil.bind_unix_socket(self.sockfile)
        app = Application([("/hello", HelloWorldRequestHandler)])
        self.server = HTTPServer(app)
        self.server.add_socket(sock)
        # Client side: a raw IOStream over AF_UNIX, connected synchronously
        # via the stop/wait test harness.
        self.stream = IOStream(socket.socket(socket.AF_UNIX))
        self.stream.connect(self.sockfile, self.stop)
        self.wait()

    def tearDown(self):
        # Order matters: close the client stream first so the server can
        # finish close_all_connections, then stop the listener.
        self.stream.close()
        self.io_loop.run_sync(self.server.close_all_connections)
        self.server.stop()
        shutil.rmtree(self.tmpdir)
        super(UnixSocketTest, self).tearDown()

    def test_unix_socket(self):
        # Speak HTTP by hand: status line, then headers, then exactly
        # Content-Length bytes of body.
        self.stream.write(b"GET /hello HTTP/1.0\r\n\r\n")
        self.stream.read_until(b"\r\n", self.stop)
        response = self.wait()
        self.assertEqual(response, b"HTTP/1.1 200 OK\r\n")
        self.stream.read_until(b"\r\n\r\n", self.stop)
        headers = HTTPHeaders.parse(self.wait().decode('latin1'))
        self.stream.read_bytes(int(headers["Content-Length"]), self.stop)
        body = self.wait()
        self.assertEqual(body, b"Hello world")

    def test_unix_socket_bad_request(self):
        # Unix sockets don't have remote addresses so they just return an
        # empty string.
        with ExpectLog(gen_log, "Malformed HTTP message from"):
            self.stream.write(b"garbage\r\n\r\n")
            self.stream.read_until_close(self.stop)
            response = self.wait()
        self.assertEqual(response, b"HTTP/1.1 400 Bad Request\r\n\r\n")
Beispiel #6
0
class ForwardConnection(object):
    """Relays traffic between an accepted client stream and the remote
    endpoint configured for the local listening address."""

    def __init__(self, server, stream, address):
        self._close_callback = None
        self.server = server
        self.stream = stream
        self.reverse_address = address
        self.address = stream.socket.getsockname()
        # The server config maps each local address to its forward target.
        self.remote_address = server.conf[self.address]
        self.remote_stream = IOStream(socket.socket())
        self.remote_stream.connect(self.remote_address, self._on_remote_connected)

    def close(self):
        self.remote_stream.close()

    def set_close_callback(self, callback):
        self._close_callback = callback

    def _on_remote_connected(self):
        ip_from = self.reverse_address[0]
        fwd_str = get_forwarding_str(self.address[0], self.address[1],
                                      self.remote_address[0], self.remote_address[1])
        logging.info('Connected ip: %s, forward %s', ip_from, fwd_str)
        # Pump data both ways until either side closes.
        self.remote_stream.read_until_close(self._on_remote_read_close, self.stream.write)
        self.stream.read_until_close(self._on_read_close, self.remote_stream.write)

    def _on_remote_read_close(self, data):
        # Remote closed.  Flush trailing data if a write is in flight;
        # otherwise close the client side (or report if already closed).
        if self.stream.writing():
            self.stream.write(data, self.stream.close)
        elif self.stream.closed():
            self._on_closed()
        else:
            self.stream.close()

    def _on_read_close(self, data):
        # Client closed; mirror image of _on_remote_read_close.
        if self.remote_stream.writing():
            self.remote_stream.write(data, self.remote_stream.close)
        elif self.remote_stream.closed():
            self._on_closed()
        else:
            self.remote_stream.close()

    def _on_closed(self):
        logging.info('Disconnected ip: %s', self.reverse_address[0])
        if self._close_callback:
            self._close_callback(self)
Beispiel #7
0
    def test_handle_stream_coroutine_logging(self):
        # handle_stream may be a coroutine and any exception in its
        # Future will be logged.
        class TestServer(TCPServer):
            @gen.coroutine
            def handle_stream(self, stream, address):
                # Yield once so the ZeroDivisionError is raised from the
                # coroutine's Future rather than synchronously.
                yield gen.moment
                stream.close()
                1 / 0

        server = client = None
        try:
            sock, port = bind_unused_port()
            # NullContext keeps the server's stack context separate from
            # the test's, so the exception goes through app_log.
            with NullContext():
                server = TestServer()
                server.add_socket(sock)
            client = IOStream(socket.socket())
            with ExpectLog(app_log, "Exception in callback"):
                yield client.connect(('localhost', port))
                yield client.read_until_close()
                # One extra turn of the event loop so the server-side
                # exception is actually logged before ExpectLog exits.
                yield gen.moment
        finally:
            # Clean up even if the assertions above fail.
            if server is not None:
                server.stop()
            if client is not None:
                client.close()
Beispiel #8
0
 def test_timeout(self):
     # Promise a 42-byte body (Content-Length) but never send it; with
     # body_timeout=0.1 the server should log a timeout and close the
     # connection without producing any response bytes.
     stream = IOStream(socket.socket())
     try:
         yield stream.connect(("127.0.0.1", self.get_http_port()))
         # Use a raw stream because AsyncHTTPClient won't let us read a
         # response without finishing a body.
         stream.write(b"PUT /streaming?body_timeout=0.1 HTTP/1.0\r\n" b"Content-Length: 42\r\n\r\n")
         with ExpectLog(gen_log, "Timeout reading body"):
             response = yield stream.read_until_close()
         self.assertEqual(response, b"")
     finally:
         stream.close()
 def test_timeout(self):
     # Promise a 42-byte body (Content-Length) but never send it; with
     # body_timeout=0.1 the server should log a timeout and close the
     # connection without producing any response bytes.
     stream = IOStream(socket.socket())
     try:
         yield stream.connect(('127.0.0.1', self.get_http_port()))
         # Use a raw stream because AsyncHTTPClient won't let us read a
         # response without finishing a body.
         stream.write(b'PUT /streaming?body_timeout=0.1 HTTP/1.0\r\n'
                      b'Content-Length: 42\r\n\r\n')
         with ExpectLog(gen_log, 'Timeout reading body'):
             response = yield stream.read_until_close()
         self.assertEqual(response, b'')
     finally:
         stream.close()
Beispiel #10
0
class Client(Session):
    """Outbound session that dials a server and automatically reconnects
    when the connection drops or the connect attempt times out."""

    def __init__(self, protocol, io_loop=None):
        Session.__init__(self, protocol, io_loop)
        self.auto_reconnect = True  # retry automatically on close/timeout
        self.reconnect_time = 5     # seconds between reconnect attempts
        self.connect_time = 5       # seconds allowed for a connect attempt

    def connect(self, address):
        self.status = SessionStream.CONNECTING
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        self.ios = IOStream(sock, self.io_loop)
        self.ios.set_close_callback(self._ios_closed)
        self.address = address
        self.clear_buffer()
        # Arm the connect timeout before starting the async connect.
        self.add_timer(self._connect_timeout, self.connect_time, "connect")
        self.ios.connect(address, self._connected)

    def _connected(self):
        self.status = SessionStream.CONNECTED
        self.protocol.connected(self)
        self.remove_timer("connect")
        # Stream incoming data to _receiver; _disconnected fires with any
        # trailing bytes when the stream closes.
        self.ios.read_until_close(self._disconnected, self._receiver)

    def _connect_timeout(self):
        # The connect attempt took too long: tear down and maybe retry.
        self.remove_timer("connect")
        self.status = SessionStream.IDLE
        self.ios.close()
        if self.auto_reconnect:
            self.add_timer(self._do_reconnect, self.reconnect_time, "reconnect")

    def _ios_closed(self):
        # A close while still CONNECTING means the connect timer is stale.
        if self.status == SessionStream.CONNECTING:
            self.remove_timer("connect")
        if self.auto_reconnect:
            self.add_timer(self._do_reconnect, self.reconnect_time, "reconnect")

    def _do_reconnect(self):
        self.remove_timer("reconnect")
        self.connect(self.address)
Beispiel #11
0
class EchoClient(object):
    """
    An asynchronous client for EchoServer
    """

    def __init__(self, address, family=socket.AF_INET, socktype=socket.SOCK_STREAM):
        self.io_stream = IOStream(socket.socket(family, socktype, 0))
        self.address = address
        self.is_closed = False

    def handle_close(self, data):
        # Final-data callback from read_until_close; just record the close.
        self.is_closed = True

    def send_message(self, message, handle_response):
        def handle_connect():
            # Start listening for the echo before sending, so no response
            # bytes can be missed.
            self.io_stream.read_until_close(self.handle_close, handle_response)
            self.write(message)

        self.io_stream.connect(self.address, handle_connect)

    def write(self, message):
        # Accept either text or bytes; the wire wants bytes.
        payload = message if isinstance(message, bytes) else message.encode("UTF-8")
        self.io_stream.write(payload)
Beispiel #12
0
class UnixSocketTest(AsyncTestCase):
    """HTTPServers can listen on Unix sockets too.

    Why would you want to do this?  Nginx can proxy to backends listening
    on unix sockets, for one thing (and managing a namespace for unix
    sockets can be easier than managing a bunch of TCP port numbers).

    Unfortunately, there's no way to specify a unix socket in a url for
    an HTTP client, so we have to test this by hand.
    """

    def setUp(self):
        super().setUp()
        # Unix socket paths live on the filesystem; a private tmpdir keeps
        # parallel test runs from colliding.
        self.tmpdir = tempfile.mkdtemp()
        self.sockfile = os.path.join(self.tmpdir, "test.sock")
        sock = netutil.bind_unix_socket(self.sockfile)
        app = Application([("/hello", HelloWorldRequestHandler)])
        self.server = HTTPServer(app)
        self.server.add_socket(sock)
        # Client side: a raw IOStream over AF_UNIX, connected synchronously.
        self.stream = IOStream(socket.socket(socket.AF_UNIX))
        self.io_loop.run_sync(lambda: self.stream.connect(self.sockfile))

    def tearDown(self):
        # Order matters: close the client stream first so the server can
        # finish close_all_connections, then stop the listener.
        self.stream.close()
        self.io_loop.run_sync(self.server.close_all_connections)
        self.server.stop()
        shutil.rmtree(self.tmpdir)
        super().tearDown()

    @gen_test
    def test_unix_socket(self):
        # Speak HTTP by hand: status line, then headers, then exactly
        # Content-Length bytes of body.
        self.stream.write(b"GET /hello HTTP/1.0\r\n\r\n")
        response = yield self.stream.read_until(b"\r\n")
        self.assertEqual(response, b"HTTP/1.1 200 OK\r\n")
        header_data = yield self.stream.read_until(b"\r\n\r\n")
        headers = HTTPHeaders.parse(header_data.decode("latin1"))
        body = yield self.stream.read_bytes(int(headers["Content-Length"]))
        self.assertEqual(body, b"Hello world")

    @gen_test
    def test_unix_socket_bad_request(self):
        # Unix sockets don't have remote addresses so they just return an
        # empty string.
        with ExpectLog(gen_log, "Malformed HTTP message from"):
            self.stream.write(b"garbage\r\n\r\n")
            response = yield self.stream.read_until_close()
        self.assertEqual(response, b"HTTP/1.1 400 Bad Request\r\n\r\n")
Beispiel #13
0
    def test_handle_stream_native_coroutine(self):
        # handle_stream may be a native coroutine.

        class TestServer(TCPServer):
            async def handle_stream(self, stream, address):
                stream.write(b"data")
                stream.close()

        sock, port = bind_unused_port()
        server = TestServer()
        server.add_socket(sock)
        client = IOStream(socket.socket())
        try:
            yield client.connect(("localhost", port))
            result = yield client.read_until_close()
            self.assertEqual(result, b"data")
        finally:
            # Release the port and the client socket even if the assertion
            # above fails, so other tests aren't affected.
            server.stop()
            client.close()
Beispiel #14
0
 def test_body_size_override_reset(self):
     # The max_body_size override is reset between requests.
     stream = IOStream(socket.socket())
     try:
         yield stream.connect(("127.0.0.1", self.get_http_port()))
         # Use a raw stream so we can make sure it's all on one connection.
         stream.write(b"PUT /streaming?expected_size=10240 HTTP/1.1\r\n" b"Content-Length: 10240\r\n\r\n")
         stream.write(b"a" * 10240)
         headers, response = yield gen.Task(read_stream_body, stream)
         self.assertEqual(response, b"10240")
         # Without the ?expected_size parameter, we get the old default value
         stream.write(b"PUT /streaming HTTP/1.1\r\n" b"Content-Length: 10240\r\n\r\n")
         with ExpectLog(gen_log, ".*Content-Length too long"):
             data = yield stream.read_until_close()
         # The server drops the connection without a response body.
         self.assertEqual(data, b"")
     finally:
         stream.close()
Beispiel #15
0
    def test_handle_stream_native_coroutine(self):
        # handle_stream may be a native coroutine.

        class TestServer(TCPServer):
            async def handle_stream(self, stream, address):
                stream.write(b"data")
                stream.close()

        sock, port = bind_unused_port()
        server = TestServer()
        server.add_socket(sock)
        client = IOStream(socket.socket())
        try:
            # bind_unused_port binds a local listener; connect to localhost,
            # not a hard-coded external address.
            yield client.connect(("localhost", port))
            result = yield client.read_until_close()
            self.assertEqual(result, b"data")
        finally:
            # Release the port and the client socket even if the assertion
            # above fails.
            server.stop()
            client.close()
Beispiel #16
0
 def test_body_size_override_reset(self):
     # The max_body_size override is reset between requests.
     stream = IOStream(socket.socket())
     try:
         yield stream.connect(('127.0.0.1', self.get_http_port()))
         # Use a raw stream so we can make sure it's all on one connection.
         stream.write(b'PUT /streaming?expected_size=10240 HTTP/1.1\r\n'
                      b'Content-Length: 10240\r\n\r\n')
         stream.write(b'a' * 10240)
         headers, response = yield gen.Task(read_stream_body, stream)
         self.assertEqual(response, b'10240')
         # Without the ?expected_size parameter, we get the old default value
         stream.write(b'PUT /streaming HTTP/1.1\r\n'
                      b'Content-Length: 10240\r\n\r\n')
         with ExpectLog(gen_log, '.*Content-Length too long'):
             data = yield stream.read_until_close()
         # The server drops the connection without a response body.
         self.assertEqual(data, b'')
     finally:
         stream.close()
Beispiel #17
0
 def test_body_size_override_reset(self):
     # The max_body_size override is reset between requests.
     stream = IOStream(socket.socket())
     try:
         # The test server listens locally (get_http_port); connect to the
         # loopback address, not a hard-coded external host.
         yield stream.connect(("127.0.0.1", self.get_http_port()))
         # Use a raw stream so we can make sure it's all on one connection.
         stream.write(b"PUT /streaming?expected_size=10240 HTTP/1.1\r\n"
                      b"Content-Length: 10240\r\n\r\n")
         stream.write(b"a" * 10240)
         start_line, headers, response = yield read_stream_body(stream)
         self.assertEqual(response, b"10240")
         # Without the ?expected_size parameter, we get the old default value
         stream.write(b"PUT /streaming HTTP/1.1\r\n"
                      b"Content-Length: 10240\r\n\r\n")
         with ExpectLog(gen_log, ".*Content-Length too long"):
             data = yield stream.read_until_close()
         # The oversized request is rejected with a 400 before the body.
         self.assertEqual(data, b"HTTP/1.1 400 Bad Request\r\n\r\n")
     finally:
         stream.close()
Beispiel #18
0
    def test_handle_stream_native_coroutine(self):
        # handle_stream may be a native coroutine.  The class is compiled
        # via exec_test so this file still parses on pre-async interpreters.

        ns = exec_test(globals(), locals(), """
        class TestServer(TCPServer):
            async def handle_stream(self, stream, address):
                stream.write(b'data')
                stream.close()
        """)

        sock, port = bind_unused_port()
        server = ns['TestServer']()
        server.add_socket(sock)
        client = IOStream(socket.socket())
        yield client.connect(('localhost', port))
        # The server writes b'data' and immediately closes the connection.
        result = yield client.read_until_close()
        self.assertEqual(result, b'data')
        server.stop()
        client.close()
Beispiel #19
0
    def test_handle_stream_native_coroutine(self):
        # handle_stream may be a native coroutine.

        # The class is compiled via exec_test so this file still parses on
        # interpreters without async/await syntax.
        namespace = exec_test(globals(), locals(), """
        class TestServer(TCPServer):
            async def handle_stream(self, stream, address):
                stream.write(b'data')
                stream.close()
        """)

        sock, port = bind_unused_port()
        server = namespace['TestServer']()
        server.add_socket(sock)
        client = IOStream(socket.socket())
        yield client.connect(('localhost', port))
        # The server writes b'data' and immediately closes the connection.
        result = yield client.read_until_close()
        self.assertEqual(result, b'data')
        server.stop()
        client.close()
Beispiel #20
0
 def test_body_size_override_reset(self):
     # The max_body_size override is reset between requests.
     stream = IOStream(socket.socket())
     try:
         yield stream.connect(('127.0.0.1', self.get_http_port()))
         # Use a raw stream so we can make sure it's all on one connection.
         stream.write(b'PUT /streaming?expected_size=10240 HTTP/1.1\r\n'
                      b'Content-Length: 10240\r\n\r\n')
         stream.write(b'a' * 10240)
         # read_stream_body reports via callback; bridge it to a Future so
         # it can be yielded like the other reads.
         fut = Future()
         read_stream_body(stream, callback=fut.set_result)
         start_line, headers, response = yield fut
         self.assertEqual(response, b'10240')
         # Without the ?expected_size parameter, we get the old default value
         stream.write(b'PUT /streaming HTTP/1.1\r\n'
                      b'Content-Length: 10240\r\n\r\n')
         with ExpectLog(gen_log, '.*Content-Length too long'):
             data = yield stream.read_until_close()
         # The oversized request is rejected with a 400 before the body.
         self.assertEqual(data, b'HTTP/1.1 400 Bad Request\r\n\r\n')
     finally:
         stream.close()
class _HTTPConnection(object):
    """Carries out one HTTP client request/response over its own stream.

    Lifecycle: __init__ parses the URL and resolves DNS; _on_resolve opens
    the (SSL)IOStream; _on_connect writes the request; _on_headers and the
    body callbacks consume the response.  Errors at any stage are routed to
    _handle_exception and reported as a 599 response.
    """

    _SUPPORTED_METHODS = set(
        ["GET", "HEAD", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"])

    def __init__(self, io_loop, client, request, release_callback,
                 final_callback, max_buffer_size, resolver):
        """Start the request: parse the URL and kick off DNS resolution.

        Processing continues asynchronously via _on_resolve/_on_connect;
        any exception inside the stack context goes to _handle_exception.
        """
        self.start_time = io_loop.time()
        self.io_loop = io_loop
        self.client = client
        self.request = request
        self.release_callback = release_callback
        self.final_callback = final_callback
        self.max_buffer_size = max_buffer_size
        self.resolver = resolver
        self.code = None
        self.headers = None
        self.chunks = None
        self._decompressor = None
        # Timeout handle returned by IOLoop.add_timeout
        self._timeout = None
        with stack_context.ExceptionStackContext(self._handle_exception):
            self.parsed = urlparse.urlsplit(_unicode(self.request.url))
            if self.parsed.scheme not in ("http", "https"):
                raise ValueError("Unsupported url scheme: %s" %
                                 self.request.url)
            # urlsplit results have hostname and port results, but they
            # didn't support ipv6 literals until python 2.7.
            netloc = self.parsed.netloc
            if "@" in netloc:
                # Strip userinfo; credentials are applied in _on_connect.
                userpass, _, netloc = netloc.rpartition("@")
            match = re.match(r'^(.+):(\d+)$', netloc)
            if match:
                host = match.group(1)
                port = int(match.group(2))
            else:
                host = netloc
                # Default port follows the scheme.
                port = 443 if self.parsed.scheme == "https" else 80
            if re.match(r'^\[.*\]$', host):
                # raw ipv6 addresses in urls are enclosed in brackets
                host = host[1:-1]
            self.parsed_hostname = host  # save final host for _on_connect

            if request.allow_ipv6:
                af = socket.AF_UNSPEC
            else:
                # We only try the first IP we get from getaddrinfo,
                # so restrict to ipv4 by default.
                af = socket.AF_INET

            self.resolver.resolve(host, port, af, callback=self._on_resolve)

    def _on_resolve(self, addrinfo):
        """DNS callback: build the (SSL)IOStream and start the TCP connect."""
        # Only the first resolved address is attempted.
        af, sockaddr = addrinfo[0]

        if self.parsed.scheme == "https":
            # Assemble ssl_options from the request's certificate settings.
            ssl_options = {}
            if self.request.validate_cert:
                ssl_options["cert_reqs"] = ssl.CERT_REQUIRED
            if self.request.ca_certs is not None:
                ssl_options["ca_certs"] = self.request.ca_certs
            else:
                ssl_options["ca_certs"] = _DEFAULT_CA_CERTS
            if self.request.client_key is not None:
                ssl_options["keyfile"] = self.request.client_key
            if self.request.client_cert is not None:
                ssl_options["certfile"] = self.request.client_cert

            # SSL interoperability is tricky.  We want to disable
            # SSLv2 for security reasons; it wasn't disabled by default
            # until openssl 1.0.  The best way to do this is to use
            # the SSL_OP_NO_SSLv2, but that wasn't exposed to python
            # until 3.2.  Python 2.7 adds the ciphers argument, which
            # can also be used to disable SSLv2.  As a last resort
            # on python 2.6, we set ssl_version to SSLv3.  This is
            # more narrow than we'd like since it also breaks
            # compatibility with servers configured for TLSv1 only,
            # but nearly all servers support SSLv3:
            # http://blog.ivanristic.com/2011/09/ssl-survey-protocol-support.html
            if sys.version_info >= (2, 7):
                ssl_options["ciphers"] = "DEFAULT:!SSLv2"
            else:
                # This is really only necessary for pre-1.0 versions
                # of openssl, but python 2.6 doesn't expose version
                # information.
                ssl_options["ssl_version"] = ssl.PROTOCOL_SSLv3

            self.stream = SSLIOStream(socket.socket(af),
                                      io_loop=self.io_loop,
                                      ssl_options=ssl_options,
                                      max_buffer_size=self.max_buffer_size)
        else:
            self.stream = IOStream(socket.socket(af),
                                   io_loop=self.io_loop,
                                   max_buffer_size=self.max_buffer_size)
        # The connect phase is bounded by the smaller of the two timeouts;
        # _on_connect re-arms a fresh timer for the request itself.
        timeout = min(self.request.connect_timeout,
                      self.request.request_timeout)
        if timeout:
            self._timeout = self.io_loop.add_timeout(
                self.start_time + timeout,
                stack_context.wrap(self._on_timeout))
        self.stream.set_close_callback(self._on_close)
        # ipv6 addresses are broken (in self.parsed.hostname) until
        # 2.7, here is correctly parsed value calculated in __init__
        self.stream.connect(sockaddr,
                            self._on_connect,
                            server_hostname=self.parsed_hostname)

    def _on_timeout(self):
        # The IOLoop timeout fired; clear the handle so it isn't removed
        # again, then report a 599 unless the request already completed.
        self._timeout = None
        if self.final_callback is None:
            return
        raise HTTPError(599, "Timeout")

    def _remove_timeout(self):
        # Cancel the pending IOLoop timeout, if one is armed.
        if self._timeout is None:
            return
        self.io_loop.remove_timeout(self._timeout)
        self._timeout = None

    def _on_connect(self):
        """TCP connect succeeded: validate the request, build headers, and
        write the serialized request to the stream."""
        # Swap the connect timeout for the request timeout.
        self._remove_timeout()
        if self.request.request_timeout:
            self._timeout = self.io_loop.add_timeout(
                self.start_time + self.request.request_timeout,
                stack_context.wrap(self._on_timeout))
        if (self.request.method not in self._SUPPORTED_METHODS
                and not self.request.allow_nonstandard_methods):
            raise KeyError("unknown method %s" % self.request.method)
        # These request options are not implemented by this client.
        for key in ('network_interface', 'proxy_host', 'proxy_port',
                    'proxy_username', 'proxy_password'):
            if getattr(self.request, key, None):
                raise NotImplementedError('%s not supported' % key)
        if "Connection" not in self.request.headers:
            self.request.headers["Connection"] = "close"
        if "Host" not in self.request.headers:
            if '@' in self.parsed.netloc:
                # Never leak userinfo credentials into the Host header.
                self.request.headers["Host"] = self.parsed.netloc.rpartition(
                    '@')[-1]
            else:
                self.request.headers["Host"] = self.parsed.netloc
        # Basic auth: URL-embedded credentials win over request options.
        username, password = None, None
        if self.parsed.username is not None:
            username, password = self.parsed.username, self.parsed.password
        elif self.request.auth_username is not None:
            username = self.request.auth_username
            password = self.request.auth_password or ''
        if username is not None:
            auth = utf8(username) + b":" + utf8(password)
            self.request.headers["Authorization"] = (b"Basic " +
                                                     base64.b64encode(auth))
        if self.request.user_agent:
            self.request.headers["User-Agent"] = self.request.user_agent
        if not self.request.allow_nonstandard_methods:
            # Standard methods either require a body or forbid one.
            if self.request.method in ("POST", "PATCH", "PUT"):
                assert self.request.body is not None
            else:
                assert self.request.body is None
        if self.request.body is not None:
            self.request.headers["Content-Length"] = str(len(
                self.request.body))
        if (self.request.method == "POST"
                and "Content-Type" not in self.request.headers):
            self.request.headers[
                "Content-Type"] = "application/x-www-form-urlencoded"
        if self.request.use_gzip:
            self.request.headers["Accept-Encoding"] = "gzip"
        req_path = ((self.parsed.path or '/') +
                    (('?' + self.parsed.query) if self.parsed.query else ''))
        request_lines = [
            utf8("%s %s HTTP/1.1" % (self.request.method, req_path))
        ]
        for k, v in self.request.headers.get_all():
            line = utf8(k) + b": " + utf8(v)
            # Guard against header-injection via embedded newlines.
            if b'\n' in line:
                raise ValueError('Newline in header: ' + repr(line))
            request_lines.append(line)
        self.stream.write(b"\r\n".join(request_lines) + b"\r\n\r\n")
        if self.request.body is not None:
            self.stream.write(self.request.body)
        # Wait for the end of the response headers (tolerates bare \n).
        self.stream.read_until_regex(b"\r?\n\r?\n", self._on_headers)

    def _release(self):
        # Run the release callback at most once: clear the attribute first
        # so a re-entrant call becomes a no-op.
        callback, self.release_callback = self.release_callback, None
        if callback is not None:
            callback()

    def _run_callback(self, response):
        # Release the connection, then deliver the response to the final
        # callback exactly once (scheduled on the IOLoop, not inline).
        self._release()
        callback, self.final_callback = self.final_callback, None
        if callback is not None:
            self.io_loop.add_callback(callback, response)

    def _handle_exception(self, typ, value, tb):
        """Stack-context handler: convert uncaught errors into a 599 response.

        Returns True when the exception was consumed (reported through the
        final callback), False to let it propagate to an outer handler.
        """
        if self.final_callback:
            self._remove_timeout()
            gen_log.warning("uncaught exception", exc_info=(typ, value, tb))
            self._run_callback(
                HTTPResponse(
                    self.request,
                    599,
                    error=value,
                    request_time=self.io_loop.time() - self.start_time,
                ))

            # The stream may not exist yet if the failure happened during
            # URL parsing or DNS resolution.
            if hasattr(self, "stream"):
                self.stream.close()
            return True
        else:
            # If our callback has already been called, we are probably
            # catching an exception that is not caused by us but rather
            # some child of our callback. Rather than drop it on the floor,
            # pass it along.
            return False

    def _on_close(self):
        # A close before the response finished is an error; surface the
        # underlying stream error message when one is available.
        if self.final_callback is None:
            return
        if self.stream.error:
            raise HTTPError(599, str(self.stream.error))
        raise HTTPError(599, "Connection closed")

    def _handle_1xx(self, code):
        # Informational responses carry no body; read the next header block
        # for the real response.
        self.stream.read_until_regex(b"\r?\n\r?\n", self._on_headers)

    def _on_headers(self, data):
        """Parse the status line and headers, then start the body read
        appropriate to the framing (chunked / content-length / read-to-close).
        """
        data = native_str(data.decode("latin1"))
        first_line, _, header_data = data.partition("\n")
        match = re.match("HTTP/1.[01] ([0-9]+) ([^\r]*)", first_line)
        assert match
        code = int(match.group(1))
        self.headers = HTTPHeaders.parse(header_data)
        if 100 <= code < 200:
            # 1xx: informational only; keep reading for the real response.
            self._handle_1xx(code)
            return
        else:
            self.code = code
            self.reason = match.group(2)

        if "Content-Length" in self.headers:
            if "," in self.headers["Content-Length"]:
                # Proxies sometimes cause Content-Length headers to get
                # duplicated.  If all the values are identical then we can
                # use them but if they differ it's an error.
                pieces = re.split(r',\s*', self.headers["Content-Length"])
                if any(i != pieces[0] for i in pieces):
                    raise ValueError("Multiple unequal Content-Lengths: %r" %
                                     self.headers["Content-Length"])
                self.headers["Content-Length"] = pieces[0]
            content_length = int(self.headers["Content-Length"])
        else:
            content_length = None

        if self.request.header_callback is not None:
            # re-attach the newline we split on earlier
            self.request.header_callback(first_line + _)
            for k, v in self.headers.get_all():
                self.request.header_callback("%s: %s\r\n" % (k, v))
            self.request.header_callback('\r\n')

        if self.request.method == "HEAD" or self.code == 304:
            # HEAD requests and 304 responses never have content, even
            # though they may have content-length headers
            self._on_body(b"")
            return
        if 100 <= self.code < 200 or self.code == 204:
            # These response codes never have bodies
            # http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.3
            if ("Transfer-Encoding" in self.headers
                    or content_length not in (None, 0)):
                raise ValueError("Response with code %d should not have body" %
                                 self.code)
            self._on_body(b"")
            return

        if (self.request.use_gzip
                and self.headers.get("Content-Encoding") == "gzip"):
            self._decompressor = GzipDecompressor()
        # Pick the body-framing strategy: chunked, fixed length, or
        # read-until-close (HTTP/1.0 style).
        if self.headers.get("Transfer-Encoding") == "chunked":
            self.chunks = []
            self.stream.read_until(b"\r\n", self._on_chunk_length)
        elif content_length is not None:
            self.stream.read_bytes(content_length, self._on_body)
        else:
            self.stream.read_until_close(self._on_body)

    def _on_body(self, data):
        """Finish the request: follow redirects or build the HTTPResponse.

        ``data`` is the complete (still-compressed, unless chunked) body.
        Either re-issues the request toward a redirect Location, or
        decompresses/buffers the body, invokes the final callback with an
        HTTPResponse, and closes the stream.
        """
        self._remove_timeout()
        # Redirect chains report the *original* request in the response.
        original_request = getattr(self.request, "original_request",
                                   self.request)
        if (self.request.follow_redirects and self.request.max_redirects > 0
                and self.code in (301, 302, 303, 307)):
            assert isinstance(self.request, _RequestProxy)
            new_request = copy.copy(self.request.request)
            new_request.url = urlparse.urljoin(self.request.url,
                                               self.headers["Location"])
            new_request.max_redirects = self.request.max_redirects - 1
            del new_request.headers["Host"]
            # http://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.3.4
            # Client SHOULD make a GET request after a 303.
            # According to the spec, 302 should be followed by the same
            # method as the original request, but in practice browsers
            # treat 302 the same as 303, and many servers use 302 for
            # compatibility with pre-HTTP/1.1 user agents which don't
            # understand the 303 status.
            if self.code in (302, 303):
                new_request.method = "GET"
                new_request.body = None
                # Body-describing headers no longer apply to the GET.
                for h in [
                        "Content-Length", "Content-Type", "Content-Encoding",
                        "Transfer-Encoding"
                ]:
                    try:
                        del self.request.headers[h]
                    except KeyError:
                        pass
            new_request.original_request = original_request
            # Hand the callback to the redirected fetch so it fires only once.
            final_callback = self.final_callback
            self.final_callback = None
            self._release()
            self.client.fetch(new_request, final_callback)
            self.stream.close()
            return
        if self._decompressor:
            # flush() also detects errors such as truncated gzip input.
            data = (self._decompressor.decompress(data) +
                    self._decompressor.flush())
        if self.request.streaming_callback:
            if self.chunks is None:
                # if chunks is not None, we already called streaming_callback
                # in _on_chunk_data
                self.request.streaming_callback(data)
            buffer = BytesIO()
        else:
            buffer = BytesIO(data)  # TODO: don't require one big string?
        response = HTTPResponse(original_request,
                                self.code,
                                reason=self.reason,
                                headers=self.headers,
                                request_time=self.io_loop.time() -
                                self.start_time,
                                buffer=buffer,
                                effective_url=self.request.url)
        self._run_callback(response)
        self.stream.close()

    def _on_chunk_length(self, data):
        """Parse one hex chunk-size line of a chunked body.

        A non-zero size schedules a read of that many payload bytes (plus
        the trailing CRLF); a zero size terminates the body and hands the
        accumulated chunks to ``_on_body``.
        """
        # TODO: "chunk extensions" http://tools.ietf.org/html/rfc2616#section-3.6.1
        size = int(data.strip(), 16)
        if size:
            # chunk payload ends with \r\n
            self.stream.read_bytes(size + 2, self._on_chunk_data)
            return
        # Zero-length chunk: end of body.
        if self._decompressor is not None:
            tail = self._decompressor.flush()
            if tail:
                # I believe the tail will always be empty (i.e.
                # decompress will return all it can).  The purpose
                # of the flush call is to detect errors such
                # as truncated input.  But in case it ever returns
                # anything, treat it as an extra chunk
                consume = self.request.streaming_callback
                if consume is not None:
                    consume(tail)
                else:
                    self.chunks.append(tail)
            # all the data has been decompressed, so we don't need to
            # decompress again in _on_body
            self._decompressor = None
        self._on_body(b''.join(self.chunks))

    def _on_chunk_data(self, data):
        """Consume one chunk payload (including its trailing CRLF) and
        queue a read for the next chunk-size line."""
        assert data[-2:] == b"\r\n"
        payload = data[:-2]
        if self._decompressor:
            payload = self._decompressor.decompress(payload)
        consume = self.request.streaming_callback
        if consume is not None:
            consume(payload)
        else:
            self.chunks.append(payload)
        # The next line on the wire is the following chunk's length header.
        self.stream.read_until(b"\r\n", self._on_chunk_length)
Beispiel #22
0
class _HTTPConnection(object):
    """Performs one non-blocking HTTP request/response cycle.

    Resolves and connects to the host in ``request.url`` (plain or SSL),
    writes the request, parses status line/headers/body (including chunked
    transfer encoding and gzip content encoding), follows redirects, and
    reports the result — or a synthetic 599 error — through
    ``final_callback`` exactly once.
    """

    # Methods for which a request body is validated in _on_connect.
    _SUPPORTED_METHODS = set(
        ["GET", "HEAD", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"])

    def __init__(self, io_loop, client, request, release_callback,
                 final_callback, max_buffer_size):
        """Begin connecting immediately; failures surface as 599 responses.

        :param io_loop: IOLoop used for timeouts and stream callbacks.
        :param client: owning client; consulted for hostname mapping and
            used to re-issue redirected requests.
        :param request: the HTTPRequest to perform.
        :param release_callback: invoked once to release this connection's
            slot back to the client.
        :param final_callback: invoked exactly once with the HTTPResponse.
        :param max_buffer_size: forwarded to the underlying IOStream.
        """
        self.start_time = time.time()
        self.io_loop = io_loop
        self.client = client
        self.request = request
        self.release_callback = release_callback
        self.final_callback = final_callback
        self.code = None
        self.headers = None
        self.chunks = None
        self._decompressor = None
        # Timeout handle returned by IOLoop.add_timeout
        self._timeout = None
        with stack_context.StackContext(self.cleanup):
            parsed = urlparse.urlsplit(_unicode(self.request.url))
            if ssl is None and parsed.scheme == "https":
                raise ValueError("HTTPS requires either python2.6+ or "
                                 "curl_httpclient")
            if parsed.scheme not in ("http", "https"):
                raise ValueError("Unsupported url scheme: %s" %
                                 self.request.url)
            # urlsplit results have hostname and port results, but they
            # didn't support ipv6 literals until python 2.7.
            netloc = parsed.netloc
            if "@" in netloc:
                # Strip userinfo; credentials are applied in _on_connect.
                userpass, _, netloc = netloc.rpartition("@")
            match = re.match(r'^(.+):(\d+)$', netloc)
            if match:
                host = match.group(1)
                port = int(match.group(2))
            else:
                host = netloc
                port = 443 if parsed.scheme == "https" else 80
            if re.match(r'^\[.*\]$', host):
                # raw ipv6 addresses in urls are enclosed in brackets
                host = host[1:-1]
            parsed_hostname = host  # save final parsed host for _on_connect
            if self.client.hostname_mapping is not None:
                host = self.client.hostname_mapping.get(host, host)

            if request.allow_ipv6:
                af = socket.AF_UNSPEC
            else:
                # We only try the first IP we get from getaddrinfo,
                # so restrict to ipv4 by default.
                af = socket.AF_INET

            # NOTE(review): getaddrinfo is a blocking call made on the
            # IOLoop thread.
            addrinfo = socket.getaddrinfo(host, port, af, socket.SOCK_STREAM,
                                          0, 0)
            af, socktype, proto, canonname, sockaddr = addrinfo[0]

            if parsed.scheme == "https":
                ssl_options = {}
                if request.validate_cert:
                    ssl_options["cert_reqs"] = ssl.CERT_REQUIRED
                if request.ca_certs is not None:
                    ssl_options["ca_certs"] = request.ca_certs
                else:
                    ssl_options["ca_certs"] = _DEFAULT_CA_CERTS
                if request.client_key is not None:
                    ssl_options["keyfile"] = request.client_key
                if request.client_cert is not None:
                    ssl_options["certfile"] = request.client_cert

                # SSL interoperability is tricky.  We want to disable
                # SSLv2 for security reasons; it wasn't disabled by default
                # until openssl 1.0.  The best way to do this is to use
                # the SSL_OP_NO_SSLv2, but that wasn't exposed to python
                # until 3.2.  Python 2.7 adds the ciphers argument, which
                # can also be used to disable SSLv2.  As a last resort
                # on python 2.6, we set ssl_version to SSLv3.  This is
                # more narrow than we'd like since it also breaks
                # compatibility with servers configured for TLSv1 only,
                # but nearly all servers support SSLv3:
                # http://blog.ivanristic.com/2011/09/ssl-survey-protocol-support.html
                if sys.version_info >= (2, 7):
                    ssl_options["ciphers"] = "DEFAULT:!SSLv2"
                else:
                    # This is really only necessary for pre-1.0 versions
                    # of openssl, but python 2.6 doesn't expose version
                    # information.
                    ssl_options["ssl_version"] = ssl.PROTOCOL_SSLv3

                self.stream = SSLIOStream(socket.socket(af, socktype, proto),
                                          io_loop=self.io_loop,
                                          ssl_options=ssl_options,
                                          max_buffer_size=max_buffer_size)
            else:
                self.stream = IOStream(socket.socket(af, socktype, proto),
                                       io_loop=self.io_loop,
                                       max_buffer_size=max_buffer_size)
            # Connect timeout covers only the connect phase; it is replaced
            # by the request timeout once connected (see _on_connect).
            timeout = min(request.connect_timeout, request.request_timeout)
            if timeout:
                self._timeout = self.io_loop.add_timeout(
                    self.start_time + timeout, self._on_timeout)
            self.stream.set_close_callback(self._on_close)
            self.stream.connect(
                sockaddr,
                functools.partial(self._on_connect, parsed, parsed_hostname))

    def _on_timeout(self):
        """Report a connect/request timeout as a 599 response and close."""
        self._timeout = None
        self._run_callback(
            HTTPResponse(self.request,
                         599,
                         request_time=time.time() - self.start_time,
                         error=HTTPError(599, "Timeout")))
        self.stream.close()

    def _on_connect(self, parsed, parsed_hostname):
        """Validate the request, write it to the stream, and await headers.

        :param parsed: urlsplit result for the request URL.
        :param parsed_hostname: hostname extracted in ``__init__`` (correct
            even for ipv6 literals on pre-2.7 Pythons).
        """
        # Swap the connect timeout for the full request timeout.
        if self._timeout is not None:
            self.io_loop.remove_timeout(self._timeout)
            self._timeout = None
        if self.request.request_timeout:
            self._timeout = self.io_loop.add_timeout(
                self.start_time + self.request.request_timeout,
                self._on_timeout)
        if (self.request.validate_cert
                and isinstance(self.stream, SSLIOStream)):
            match_hostname(
                self.stream.socket.getpeercert(),
                # ipv6 addresses are broken (in
                # parsed.hostname) until 2.7, here is
                # correctly parsed value calculated in
                # __init__
                parsed_hostname)
        if (self.request.method not in self._SUPPORTED_METHODS
                and not self.request.allow_nonstandard_methods):
            raise KeyError("unknown method %s" % self.request.method)
        # Curl-only options are rejected rather than silently ignored.
        for key in ('network_interface', 'proxy_host', 'proxy_port',
                    'proxy_username', 'proxy_password'):
            if getattr(self.request, key, None):
                raise NotImplementedError('%s not supported' % key)
        if "Connection" not in self.request.headers:
            self.request.headers["Connection"] = "close"
        if "Host" not in self.request.headers:
            if '@' in parsed.netloc:
                # Userinfo must not leak into the Host header.
                self.request.headers["Host"] = parsed.netloc.rpartition(
                    '@')[-1]
            else:
                self.request.headers["Host"] = parsed.netloc
        # URL credentials take precedence over request auth_username.
        username, password = None, None
        if parsed.username is not None:
            username, password = parsed.username, parsed.password
        elif self.request.auth_username is not None:
            username = self.request.auth_username
            password = self.request.auth_password or ''
        if username is not None:
            auth = utf8(username) + b(":") + utf8(password)
            self.request.headers["Authorization"] = (b("Basic ") +
                                                     base64.b64encode(auth))
        if self.request.user_agent:
            self.request.headers["User-Agent"] = self.request.user_agent
        if not self.request.allow_nonstandard_methods:
            # NOTE(review): request validation via assert is stripped
            # under ``python -O`` — confirm this is acceptable.
            if self.request.method in ("POST", "PATCH", "PUT"):
                assert self.request.body is not None
            else:
                assert self.request.body is None
        if self.request.body is not None:
            self.request.headers["Content-Length"] = str(len(
                self.request.body))
        if (self.request.method == "POST"
                and "Content-Type" not in self.request.headers):
            self.request.headers[
                "Content-Type"] = "application/x-www-form-urlencoded"
        if self.request.use_gzip:
            self.request.headers["Accept-Encoding"] = "gzip"
        req_path = ((parsed.path or '/') +
                    (('?' + parsed.query) if parsed.query else ''))
        request_lines = [
            utf8("%s %s HTTP/1.1" % (self.request.method, req_path))
        ]
        for k, v in self.request.headers.get_all():
            line = utf8(k) + b(": ") + utf8(v)
            # Guard against header-injection via embedded newlines.
            if b('\n') in line:
                raise ValueError('Newline in header: ' + repr(line))
            request_lines.append(line)
        self.stream.write(b("\r\n").join(request_lines) + b("\r\n\r\n"))
        if self.request.body is not None:
            self.stream.write(self.request.body)
        self.stream.read_until_regex(b("\r?\n\r?\n"), self._on_headers)

    def _release(self):
        """Return this connection's slot to the client (at most once)."""
        if self.release_callback is not None:
            release_callback = self.release_callback
            self.release_callback = None
            release_callback()

    def _run_callback(self, response):
        """Release the connection and invoke final_callback (at most once)."""
        self._release()
        if self.final_callback is not None:
            final_callback = self.final_callback
            self.final_callback = None
            final_callback(response)

    @contextlib.contextmanager
    def cleanup(self):
        """Stack context that converts uncaught exceptions into 599s."""
        try:
            yield
        except Exception as e:
            logging.warning("uncaught exception", exc_info=True)
            self._run_callback(
                HTTPResponse(
                    self.request,
                    599,
                    error=e,
                    request_time=time.time() - self.start_time,
                ))
            # The stream may not exist yet if __init__ failed early.
            if hasattr(self, "stream"):
                self.stream.close()

    def _on_close(self):
        """Report an unexpected stream close as a 599 response."""
        self._run_callback(
            HTTPResponse(self.request,
                         599,
                         request_time=time.time() - self.start_time,
                         error=HTTPError(599, "Connection closed")))

    def _on_headers(self, data):
        """Parse status line and headers, then start reading the body."""
        data = native_str(data.decode("latin1"))
        first_line, _, header_data = data.partition("\n")
        match = re.match("HTTP/1.[01] ([0-9]+)", first_line)
        assert match
        self.code = int(match.group(1))
        self.headers = HTTPHeaders.parse(header_data)

        if "Content-Length" in self.headers:
            if "," in self.headers["Content-Length"]:
                # Proxies sometimes cause Content-Length headers to get
                # duplicated.  If all the values are identical then we can
                # use them but if they differ it's an error.
                pieces = re.split(r',\s*', self.headers["Content-Length"])
                if any(i != pieces[0] for i in pieces):
                    raise ValueError("Multiple unequal Content-Lengths: %r" %
                                     self.headers["Content-Length"])
                self.headers["Content-Length"] = pieces[0]
            content_length = int(self.headers["Content-Length"])
        else:
            content_length = None

        if self.request.header_callback is not None:
            for k, v in self.headers.get_all():
                self.request.header_callback("%s: %s\r\n" % (k, v))

        if self.request.method == "HEAD":
            # HEAD requests never have content, even though they may have
            # content-length headers
            self._on_body(b(""))
            return
        if 100 <= self.code < 200 or self.code in (204, 304):
            # These response codes never have bodies
            # http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.3
            # NOTE(review): server-controlled data is validated with
            # asserts here (stripped under -O); a ValueError, as raised
            # by the newer copy of this class, would be more robust.
            assert "Transfer-Encoding" not in self.headers
            assert content_length in (None, 0)
            self._on_body(b(""))
            return

        if (self.request.use_gzip
                and self.headers.get("Content-Encoding") == "gzip"):
            # Magic parameter makes zlib module understand gzip header
            # http://stackoverflow.com/questions/1838699/how-can-i-decompress-a-gzip-stream-with-zlib
            self._decompressor = zlib.decompressobj(16 + zlib.MAX_WBITS)
        if self.headers.get("Transfer-Encoding") == "chunked":
            self.chunks = []
            self.stream.read_until(b("\r\n"), self._on_chunk_length)
        elif content_length is not None:
            self.stream.read_bytes(content_length, self._on_body)
        else:
            # No framing information: read until the server closes.
            self.stream.read_until_close(self._on_body)

    def _on_body(self, data):
        """Follow a redirect or deliver the final HTTPResponse."""
        if self._timeout is not None:
            self.io_loop.remove_timeout(self._timeout)
            self._timeout = None
        # Redirect chains report the *original* request in the response.
        original_request = getattr(self.request, "original_request",
                                   self.request)
        if (self.request.follow_redirects and self.request.max_redirects > 0
                and self.code in (301, 302, 303, 307)):
            new_request = copy.copy(self.request)
            new_request.url = urlparse.urljoin(self.request.url,
                                               self.headers["Location"])
            new_request.max_redirects -= 1
            del new_request.headers["Host"]
            # http://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.3.4
            # client SHOULD make a GET request
            if self.code == 303:
                new_request.method = "GET"
                new_request.body = None
                # Body-describing headers no longer apply to the GET.
                for h in [
                        "Content-Length", "Content-Type", "Content-Encoding",
                        "Transfer-Encoding"
                ]:
                    try:
                        del self.request.headers[h]
                    except KeyError:
                        pass
            new_request.original_request = original_request
            # Hand the callback over to the redirected fetch so it fires
            # exactly once.
            final_callback = self.final_callback
            self.final_callback = None
            self._release()
            self.client.fetch(new_request, final_callback)
            self.stream.close()
            return
        if self._decompressor:
            data = self._decompressor.decompress(data)
        if self.request.streaming_callback:
            if self.chunks is None:
                # if chunks is not None, we already called streaming_callback
                # in _on_chunk_data
                self.request.streaming_callback(data)
            buffer = BytesIO()
        else:
            buffer = BytesIO(data)  # TODO: don't require one big string?
        response = HTTPResponse(original_request,
                                self.code,
                                headers=self.headers,
                                request_time=time.time() - self.start_time,
                                buffer=buffer,
                                effective_url=self.request.url)
        self._run_callback(response)
        self.stream.close()

    def _on_chunk_length(self, data):
        """Parse a hex chunk-size line; zero terminates the body."""
        # TODO: "chunk extensions" http://tools.ietf.org/html/rfc2616#section-3.6.1
        length = int(data.strip(), 16)
        if length == 0:
            # all the data has been decompressed, so we don't need to
            # decompress again in _on_body
            self._decompressor = None
            self._on_body(b('').join(self.chunks))
        else:
            self.stream.read_bytes(
                length + 2,  # chunk ends with \r\n
                self._on_chunk_data)

    def _on_chunk_data(self, data):
        """Consume one chunk payload, then read the next chunk-size line."""
        assert data[-2:] == b("\r\n")
        chunk = data[:-2]
        if self._decompressor:
            chunk = self._decompressor.decompress(chunk)
        if self.request.streaming_callback is not None:
            self.request.streaming_callback(chunk)
        else:
            self.chunks.append(chunk)
        self.stream.read_until(b("\r\n"), self._on_chunk_length)
Beispiel #23
0
class WebSocketProxyHandler(WebSocketHandler):
    """
    Proxy a websocket connection to a service listening on a given (host, port) pair
    """

    def initialize(self, **kwargs):
        """Set up the backing socket stream.

        Expected kwargs: ``address`` (host, port) of the backend service;
        optional ``family``/``type`` for the socket; optional ``filters``,
        a list of objects with ``ws_to_socket``/``socket_to_ws`` hooks.
        """
        self.remote_address = kwargs.get("address")
        self.io_stream = IOStream(socket.socket(kwargs.get("family", socket.AF_INET),
                                                kwargs.get("type", socket.SOCK_STREAM),
                                                0))
        self.filters = kwargs.get("filters", [])
        # on_close doubles as the stream's close callback (see on_close).
        self.io_stream.set_close_callback(self.on_close)

    def open(self):
        """
        Open the connection to the service when the WebSocket connection has been established
        """
        logger.info("Forwarding connection to server %s" % tuple_to_address(self.remote_address))
        self.io_stream.connect(self.remote_address, self.on_connect)

    def on_message(self, message):
        """
        On message received from WebSocket, forward data to the service
        """
        try:
            data = None if message is None else bytes(message)
            # Filters may transform or suppress the payload (ws -> socket).
            for filtr in self.filters:
                data = filtr.ws_to_socket(data=data)
            if data:
                self.io_stream.write(data)
        except Exception as e:
            logger.exception(e)
            self.close()

    def on_close(self, *args, **kwargs):
        """
        When web socket gets closed, close the connection to the service too
        """
        # Also invoked as the IOStream close/read_until_close callback, so
        # any final buffered data from the peer arrives as positional args
        # and is flushed to the client before tearing down.
        logger.info("Closing connection with peer at %s" % tuple_to_address(self.remote_address))
        logger.debug("Received args %s and %s", args, kwargs)
        #if not self.io_stream._closed:
        for message in args:
            self.on_peer_message(message)
        if not self.io_stream.closed():
            self.io_stream.close()
        self.close()

    def on_connect(self):
        """
        Callback invoked on connection with mapped service
        """
        logger.info("Connection established with peer at %s" % tuple_to_address(self.remote_address))
        # Stream peer data back as it arrives; the remainder goes to on_close.
        self.io_stream.read_until_close(self.on_close, self.on_peer_message)

    def on_peer_message(self, message):
        """
        On message received from peer service, send back to client through WebSocket
        """
        try:
            data = None if message is None else bytes(message)
            # Filters may transform or suppress the payload (socket -> ws).
            for filtr in self.filters:
                data = filtr.socket_to_ws(data=data)
            if data:
                self.write_message(data, binary=True)
        except FilterException as e:
            logger.exception(e)
            self.on_close()
Beispiel #24
0
class BaseTornadoClient(AsyncModbusClientMixin):
    """
    Base Tornado client
    """
    # Tornado IOStream wrapping the transport; created in connect().
    stream = None
    # IOLoop the stream runs on; falls back to IOLoop.current().
    io_loop = None

    def __init__(self, *args, **kwargs):
        """
        Initializes BaseTornadoClient.
        ioloop to be passed as part of kwargs ('ioloop')
        :param args:
        :param kwargs:
        """
        self.io_loop = kwargs.pop("ioloop", None)
        super(BaseTornadoClient, self).__init__(*args, **kwargs)

    @abc.abstractmethod
    def get_socket(self):
        """
        return instance of the socket to connect to
        """

    @gen.coroutine
    def connect(self):
        """
        Connect to the socket identified by host and port

        :returns: Future
        :rtype: tornado.concurrent.Future
        """
        conn = self.get_socket()
        self.stream = IOStream(conn, io_loop=self.io_loop or IOLoop.current())
        self.stream.connect((self.host, self.port))
        # All inbound bytes are streamed into on_receive until the
        # connection closes.
        self.stream.read_until_close(None, streaming_callback=self.on_receive)
        self._connected = True
        LOGGER.debug("Client connected")

        raise gen.Return(self)

    def on_receive(self, *args):
        """
        On data receive call back
        :param args: data received
        :return:
        """
        data = args[0] if len(args) > 0 else None

        if not data:
            return
        LOGGER.debug("recv: " + hexlify_packets(data))
        unit = self.framer.decode_data(data).get("unit", 0)
        # The framer matches responses to pending transactions and
        # resolves their futures via _handle_response.
        self.framer.processIncomingPacket(data,
                                          self._handle_response,
                                          unit=unit)

    def execute(self, request=None):
        """
        Executes a transaction
        :param request:
        :return:
        """
        request.transaction_id = self.transaction.getNextTID()
        packet = self.framer.buildPacket(request)
        LOGGER.debug("send: " + hexlify_packets(packet))
        self.stream.write(packet)
        return self._build_response(request.transaction_id)

    def _handle_response(self, reply, **kwargs):
        """
        Handle response received
        :param reply:
        :param kwargs:
        :return:
        """
        if reply is not None:
            tid = reply.transaction_id
            future = self.transaction.getTransaction(tid)
            if future:
                future.set_result(reply)
            else:
                LOGGER.debug("Unrequested message: {}".format(reply))

    def _build_response(self, tid):
        """
        Builds a future response
        :param tid: transaction id the future is registered under
        :return: a Future resolved when the matching reply arrives
        """
        f = Future()

        if not self._connected:
            f.set_exception(ConnectionException("Client is not connected"))
            return f

        self.transaction.addTransaction(f, tid)
        return f

    def close(self):
        """
        Closes the underlying IOStream
        """
        LOGGER.debug("Client disconnected")
        if self.stream:
            # NOTE(review): close_fd() is IOStream's low-level descriptor
            # close and bypasses the stream's own close/callback handling;
            # the public close() is the usual API — confirm intentional.
            self.stream.close_fd()

        self.stream = None
        self._connected = False
Beispiel #25
0
class HTTPServerRawTest(AsyncHTTPTestCase):
    """Exercises HTTPServer with hand-crafted raw bytes.

    Bypasses AsyncHTTPClient entirely: each test connects a bare IOStream
    to the test server and writes a literal (sometimes deliberately
    malformed) HTTP request.
    """

    def get_app(self):
        """Application under test: a single /echo endpoint."""
        return Application([("/echo", EchoHandler)])

    def setUp(self):
        super(HTTPServerRawTest, self).setUp()
        self.stream = IOStream(socket.socket())
        # Fix: AsyncHTTPTestCase's server listens on the loopback
        # interface, so connect to 127.0.0.1 — the previous hard-coded
        # LAN address ("10.0.0.7") could never reach it.
        self.io_loop.run_sync(lambda: self.stream.connect(
            ("127.0.0.1", self.get_http_port())))

    def tearDown(self):
        self.stream.close()
        super(HTTPServerRawTest, self).tearDown()

    def test_empty_request(self):
        # Closing without writing anything must not wedge the server.
        self.stream.close()
        self.io_loop.add_timeout(datetime.timedelta(seconds=0.001), self.stop)
        self.wait()

    def test_malformed_first_line_response(self):
        # A garbage request line should produce a 400 Bad Request.
        with ExpectLog(gen_log, ".*Malformed HTTP request line"):
            self.stream.write(b"asdf\r\n\r\n")
            start_line, headers, response = self.io_loop.run_sync(
                lambda: read_stream_body(self.stream))
            self.assertEqual("HTTP/1.1", start_line.version)
            self.assertEqual(400, start_line.code)
            self.assertEqual("Bad Request", start_line.reason)

    def test_malformed_first_line_log(self):
        with ExpectLog(gen_log, ".*Malformed HTTP request line"):
            self.stream.write(b"asdf\r\n\r\n")
            # TODO: need an async version of ExpectLog so we don't need
            # hard-coded timeouts here.
            self.io_loop.add_timeout(datetime.timedelta(seconds=0.05),
                                     self.stop)
            self.wait()

    def test_malformed_headers(self):
        with ExpectLog(gen_log,
                       ".*Malformed HTTP message.*no colon in header line"):
            self.stream.write(b"GET / HTTP/1.0\r\nasdf\r\n\r\n")
            self.io_loop.add_timeout(datetime.timedelta(seconds=0.05),
                                     self.stop)
            self.wait()

    def test_chunked_request_body(self):
        # Chunked requests are not widely supported and we don't have a way
        # to generate them in AsyncHTTPClient, but HTTPServer will read them.
        self.stream.write(b"""\
POST /echo HTTP/1.1
Transfer-Encoding: chunked
Content-Type: application/x-www-form-urlencoded

4
foo=
3
bar
0

""".replace(b"\n", b"\r\n"))
        start_line, headers, response = self.io_loop.run_sync(
            lambda: read_stream_body(self.stream))
        self.assertEqual(json_decode(response), {u"foo": [u"bar"]})

    def test_chunked_request_uppercase(self):
        # As per RFC 2616 section 3.6, "Transfer-Encoding" header's value is
        # case-insensitive.
        self.stream.write(b"""\
POST /echo HTTP/1.1
Transfer-Encoding: Chunked
Content-Type: application/x-www-form-urlencoded

4
foo=
3
bar
0

""".replace(b"\n", b"\r\n"))
        start_line, headers, response = self.io_loop.run_sync(
            lambda: read_stream_body(self.stream))
        self.assertEqual(json_decode(response), {u"foo": [u"bar"]})

    @gen_test
    def test_invalid_content_length(self):
        # A non-integer Content-Length should be rejected and logged.
        with ExpectLog(gen_log, ".*Only integer Content-Length is allowed"):
            self.stream.write(b"""\
POST /echo HTTP/1.1
Content-Length: foo

bar

""".replace(b"\n", b"\r\n"))
            yield self.stream.read_until_close()
Beispiel #26
0
class _RedisConnection(object):
    """Asynchronous, pipelined Redis connection on a Tornado IOStream.

    Every issued command appends a context tuple to ``__cmd_env``; replies
    are decoded in arrival order in ``__on_resp`` and matched FIFO against
    those contexts.
    """

    def __init__(self, io_loop, init_buf, final_callback, redis_tuple, redis_pass):
        """
        :param io_loop: the Tornado IOLoop this connection runs on
        :param init_buf: first command buffer to send after connecting
        :param final_callback: invoked whenever a response value is ready
        :param redis_tuple: (ip, port, db)
        :param redis_pass: redis password, or None when no AUTH is needed
        """
        self.__io_loop = io_loop
        self.__final_cb = final_callback
        self.__stream = None
        # Unparsed remainder of the Redis reply stream.
        self.__recv_buf = ''

        # Prepend SELECT <db> (and AUTH when a password is set) so the
        # connection is initialized before any user command is executed.
        init_buf = init_buf or ''
        init_buf = chain_select_cmd(redis_tuple[2], init_buf)
        if redis_pass is None:
            self.__init_buf = (init_buf,)
        else:
            assert redis_pass and isinstance(redis_pass, str)
            self.__init_buf = (redis_auth(redis_pass), init_buf)

        self.__haspass = redis_pass is not None
        self.__init_buf = ''.join(self.__init_buf)

        self.__connected = False
        # Per-command context queue: (future, number of connect-time commands
        # (AUTH, SELECT, etc.), transaction flag, command count).
        self.__cmd_env = deque()
        # NOTE(review): __written is never read anywhere in this class.
        self.__written = False

    def connect(self, init_future, redis_tuple, active_trans, cmd_count):
        """Open the TCP connection (idempotent once a stream exists).

        :param init_future: future resolved by the first response
        :param redis_tuple: (ip, port, db)
        :param active_trans: whether a transaction is active
        :param cmd_count: number of commands in the initial buffer
        """
        if self.__stream is not None:
            return
        # (future, connect_count, transaction, cmd_count)
        self.__cmd_env.append((init_future, 1 + int(self.__haspass), active_trans, cmd_count))
        with ExceptionStackContext(self.__handle_ex):
            self.__stream = IOStream(socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0),
                                     io_loop=self.__io_loop)
            self.__stream.connect(redis_tuple[:2], self.__on_connect)

    def write(self, write_buf, new_future, include_select, active_trans, cmd_count, by_connect=False):
        """Queue and/or send a command buffer.

        :param write_buf: raw command bytes to send
        :param new_future: because of closure capture, the resp callback keeps
            the previous future; it must be refreshed on every write
        :param include_select: whether the buffer includes a SELECT command
        :param active_trans: whether a transaction is active
        :param cmd_count: number of commands in the buffer
        :param by_connect: True when called from __on_connect to flush the
            initial AUTH/SELECT buffer
        """
        if by_connect:
            self.__stream.write(self.__init_buf)
            self.__init_buf = None
            return

        self.__cmd_env.append((new_future, int(include_select), active_trans, cmd_count))
        if not self.__connected:
            # Not connected yet: accumulate into the initial buffer; it is
            # flushed by __on_connect.
            self.__init_buf = ''.join((self.__init_buf, write_buf))
            return

        if self.__init_buf:
            write_buf = ''.join((self.__init_buf, write_buf))

        self.__stream.write(write_buf)
        self.__init_buf = None

    def __on_connect(self):
        """Connected: only the initial command buffer needs to be sent."""
        self.__connected = True
        self.__stream.set_nodelay(True)
        self.write(None, None, None, None, None, True)
        self.__stream.read_until_close(None, self.__on_resp)

    def __on_resp(self, recv):
        """Streaming read callback; decodes as many full replies as possible.

        :param recv: newly received bytes
        """
        recv = ''.join((self.__recv_buf, recv))

        idx = 0
        for future, connect_count, trans, count in self.__cmd_env:
            ok, payload, recv = decode_resp_ondemand(recv, connect_count, trans, count)
            if not ok:
                # Incomplete reply: keep the remainder for the next read.
                break

            idx += 1
            self.__run_callback({_RESP_FUTURE: future, RESP_RESULT: payload})

        self.__recv_buf = recv
        # Pop only the contexts that were fully answered.
        for _ in xrange(idx):
            self.__cmd_env.popleft()

    def __run_callback(self, resp):
        # Dispatch on the IOLoop so callbacks never run re-entrantly.
        if self.__final_cb is None:
            return

        self.__io_loop.add_callback(self.__final_cb, resp)

    def __handle_ex(self, typ, value, tb):
        """StackContext exception handler.

        :param typ: exception type
        :returns: True when the exception was handled (reported via callback)
        """
        if self.__final_cb:
            self.__run_callback({RESP_ERR: value})
            return True
        return False
Beispiel #27
0
class _RedisConnection(object):
    """Asynchronous, pipelined Redis connection on a Tornado IOStream.

    A variant that tracks an explicit connect state (CONNECT_INIT /
    CONNECT_ING / CONNECT_SUCC) and buffers user writes issued while the
    connection is still being established.
    """

    def __init__(self, io_loop, write_buf, final_callback, redis_tuple,
                 redis_pass):
        """
        :param io_loop: the Tornado IOLoop this connection runs on
        :param write_buf: first command buffer to send after connecting
        :param final_callback: invoked whenever a response value is ready
        :param redis_tuple: (ip, port, db)
        :param redis_pass: redis password, or None when no AUTH is needed
        """
        self.__io_loop = io_loop
        self.__final_cb = final_callback
        self.__stream = None
        # Unparsed remainder of the Redis reply stream.
        self.__recv_buf = ''
        self.__write_buf = write_buf

        # Build the connect-time buffer: SELECT <db>, preceded by AUTH when
        # a password is configured.
        init_buf = ''
        init_buf = chain_select_cmd(redis_tuple[2], init_buf)
        if redis_pass is None:
            self.__init_buf = (init_buf, )
        else:
            assert redis_pass and isinstance(redis_pass, str)
            self.__init_buf = (redis_auth(redis_pass), init_buf)

        self.__haspass = redis_pass is not None
        self.__init_buf = ''.join(self.__init_buf)

        self.__connect_state = CONNECT_INIT
        # Per-command context queue: (future, number of connect-time commands
        # (AUTH, SELECT, etc.), transaction flag, command count).
        self.__cmd_env = deque()
        # NOTE(review): __written is never read anywhere in this class.
        self.__written = False

    def connect(self, init_future, redis_tuple, active_trans, cmd_count):
        """Open the TCP connection (idempotent once a stream exists).

        :param init_future: future resolved by the first response
        :param redis_tuple: (ip, port, db)
        :param active_trans: whether a transaction is active
        :param cmd_count: number of commands in the initial buffer
        """
        if self.__stream is not None:
            return

        # Two contexts: one for the connect-time AUTH/SELECT replies
        # (cmd_count 0, so no user callback), one for the user commands.
        # (future, connect_count, transaction, cmd_count)
        self.__cmd_env.append((init_future, 1 + int(self.__haspass), False, 0))
        self.__cmd_env.append((init_future, 0, active_trans, cmd_count))

        with ExceptionStackContext(self.__handle_ex):
            self.__stream = IOStream(socket.socket(socket.AF_INET,
                                                   socket.SOCK_STREAM, 0),
                                     io_loop=self.__io_loop)
            self.__stream.set_close_callback(self.__on_close)
            self.__stream.connect(redis_tuple[:2], self.__on_connect)
            self.__connect_state = CONNECT_ING

    def connect_state(self):
        # Current CONNECT_* state of this connection.
        return self.__connect_state

    def write(self,
              write_buf,
              new_future,
              include_select,
              active_trans,
              cmd_count,
              by_connect=False):
        """Queue and/or send a command buffer.

        :param write_buf: raw command bytes to send
        :param new_future: because of closure capture, the resp callback keeps
            the previous future; it must be refreshed on every write
        :param include_select: whether the buffer includes a SELECT command
        :param active_trans: whether a transaction is active
        :param cmd_count: number of commands in the buffer
        :param by_connect: True when called from __on_connect to flush the
            connect-time and pre-connect buffers
        """
        if by_connect:
            self.__stream.write(self.__init_buf)
            self.__init_buf = None

            # Flush anything queued while the connect was in flight.
            if self.__write_buf:
                self.__stream.write(self.__write_buf)
                self.__write_buf = None
            return

        self.__cmd_env.append(
            (new_future, int(include_select), active_trans, cmd_count))
        if self.__connect_state == CONNECT_ING:
            # Still connecting: accumulate; __on_connect will flush.
            self.__write_buf = ''.join((self.__write_buf, write_buf))
            return

        if self.__write_buf:
            write_buf = ''.join((self.__write_buf, write_buf))

        self.__stream.write(write_buf)
        self.__write_buf = None

    def __on_connect(self):
        """Connected: only the initial command buffer needs to be sent."""
        self.__connect_state = CONNECT_SUCC
        self.__stream.set_nodelay(True)
        self.write(None, None, None, None, None, True)
        self.__stream.read_until_close(None, self.__on_resp)

    def __on_resp(self, recv):
        """Streaming read callback; decodes as many full replies as possible.

        :param recv: newly received bytes
        """
        recv = ''.join((self.__recv_buf, recv))

        idx = 0
        for future, connect_count, trans, count in self.__cmd_env:
            ok, payload, recv = decode_resp_ondemand(recv, connect_count,
                                                     trans, count)
            if not ok:
                # Incomplete reply: keep the remainder for the next read.
                break

            idx += 1
            # Connect-time contexts carry count 0; their replies (AUTH/SELECT
            # acks) are consumed silently.
            if count > 0:
                self.__run_callback({
                    _RESP_FUTURE: future,
                    RESP_RESULT: payload
                })

        self.__recv_buf = recv
        # Pop only the contexts that were fully answered.
        for _ in xrange(idx):
            self.__cmd_env.popleft()

    def __on_close(self):
        """Stream close callback: reset state and surface any socket error."""
        self.__connect_state = CONNECT_INIT
        if self.__final_cb:
            if self.__stream.error:
                self.__run_callback({RESP_ERR: self.__stream.error})

    def __run_callback(self, resp):
        # Dispatch on the IOLoop so callbacks never run re-entrantly.
        if self.__final_cb is None:
            return

        self.__io_loop.add_callback(self.__final_cb, resp)

    def __handle_ex(self, typ, value, tb):
        """StackContext exception handler.

        :param typ: exception type
        :returns: True when the exception was handled (reported via callback)
        """
        if self.__final_cb:
            self.__run_callback({RESP_ERR: value})
            return True
        return False
Beispiel #28
0
class KeepAliveTest(AsyncHTTPTestCase):
    """Tests various scenarios for HTTP 1.1 keep-alive support.

    These tests don't use AsyncHTTPClient because we want to control
    connection reuse and closing.
    """
    def get_app(self):
        # Minimal app with three endpoints: a trivial body, a body large
        # enough to require chunked socket writes, and a handler that
        # finishes from its close callback.
        class HelloHandler(RequestHandler):
            def get(self):
                self.finish('Hello world')

        class LargeHandler(RequestHandler):
            def get(self):
                # 512KB should be bigger than the socket buffers so it will
                # be written out in chunks.
                self.write(''.join(chr(i % 256) * 1024 for i in range(512)))

        class FinishOnCloseHandler(RequestHandler):
            @asynchronous
            def get(self):
                self.flush()

            def on_connection_close(self):
                # This is not very realistic, but finishing the request
                # from the close callback has the right timing to mimic
                # some errors seen in the wild.
                self.finish('closed')

        return Application([('/', HelloHandler), ('/large', LargeHandler),
                            ('/finish_on_close', FinishOnCloseHandler)])

    def setUp(self):
        super(KeepAliveTest, self).setUp()
        # Expected HTTP version of response status lines; HTTP/1.0 tests
        # override this.
        self.http_version = b'HTTP/1.1'

    def tearDown(self):
        # We just closed the client side of the socket; let the IOLoop run
        # once to make sure the server side got the message.
        self.io_loop.add_timeout(datetime.timedelta(seconds=0.001), self.stop)
        self.wait()

        if hasattr(self, 'stream'):
            self.stream.close()
        super(KeepAliveTest, self).tearDown()

    # The next few methods are a crude manual http client
    def connect(self):
        # Open a raw IOStream to the test server.
        self.stream = IOStream(socket.socket(), io_loop=self.io_loop)
        self.stream.connect(('localhost', self.get_http_port()), self.stop)
        self.wait()

    def read_headers(self):
        # Read and parse the status line and headers of one response,
        # asserting a 200 status with the expected HTTP version.
        self.stream.read_until(b'\r\n', self.stop)
        first_line = self.wait()
        self.assertTrue(first_line.startswith(self.http_version + b' 200'),
                        first_line)
        self.stream.read_until(b'\r\n\r\n', self.stop)
        header_bytes = self.wait()
        headers = HTTPHeaders.parse(header_bytes.decode('latin1'))
        return headers

    def read_response(self):
        # Read one complete "Hello world" response (headers + body).
        self.headers = self.read_headers()
        self.stream.read_bytes(int(self.headers['Content-Length']), self.stop)
        body = self.wait()
        self.assertEqual(b'Hello world', body)

    def close(self):
        # Close the client stream; delete it so tearDown doesn't close twice.
        self.stream.close()
        del self.stream

    def test_two_requests(self):
        # Two sequential requests reuse the same connection.
        self.connect()
        self.stream.write(b'GET / HTTP/1.1\r\n\r\n')
        self.read_response()
        self.stream.write(b'GET / HTTP/1.1\r\n\r\n')
        self.read_response()
        self.close()

    def test_request_close(self):
        # "Connection: close" makes the server close after the response.
        self.connect()
        self.stream.write(b'GET / HTTP/1.1\r\nConnection: close\r\n\r\n')
        self.read_response()
        self.stream.read_until_close(callback=self.stop)
        data = self.wait()
        self.assertTrue(not data)
        self.close()

    # keepalive is supported for http 1.0 too, but it's opt-in
    def test_http10(self):
        # Plain HTTP/1.0 closes the connection and sends no Connection header.
        self.http_version = b'HTTP/1.0'
        self.connect()
        self.stream.write(b'GET / HTTP/1.0\r\n\r\n')
        self.read_response()
        self.stream.read_until_close(callback=self.stop)
        data = self.wait()
        self.assertTrue(not data)
        self.assertTrue('Connection' not in self.headers)
        self.close()

    def test_http10_keepalive(self):
        # HTTP/1.0 with explicit "Connection: keep-alive" keeps the
        # connection open and echoes the header back.
        self.http_version = b'HTTP/1.0'
        self.connect()
        self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n')
        self.read_response()
        self.assertEqual(self.headers['Connection'], 'Keep-Alive')
        self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n')
        self.read_response()
        self.assertEqual(self.headers['Connection'], 'Keep-Alive')
        self.close()

    def test_pipelined_requests(self):
        # Two requests written back-to-back are answered in order.
        self.connect()
        self.stream.write(b'GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\n')
        self.read_response()
        self.read_response()
        self.close()

    def test_pipelined_cancel(self):
        # Closing mid-pipeline must not wedge the server.
        self.connect()
        self.stream.write(b'GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\n')
        # only read once
        self.read_response()
        self.close()

    def test_cancel_during_download(self):
        # Close while the server is still streaming a large body.
        self.connect()
        self.stream.write(b'GET /large HTTP/1.1\r\n\r\n')
        self.read_headers()
        self.stream.read_bytes(1024, self.stop)
        self.wait()
        self.close()

    def test_finish_while_closed(self):
        # Client closes after headers; handler finishes from its close
        # callback (see FinishOnCloseHandler).
        self.connect()
        self.stream.write(b'GET /finish_on_close HTTP/1.1\r\n\r\n')
        self.read_headers()
        self.close()
Beispiel #29
0
class _HTTPConnection(object):
    """One non-blocking HTTP client request/response exchange.

    Connects, writes the request, parses status line/headers, reads the
    body (Content-Length, chunked, or read-until-close), follows 301/302
    redirects, and reports the result through ``final_callback``.
    NOTE(review): uses py2-era helpers (``b(...)``, old IOStream kwargs).
    """
    _SUPPORTED_METHODS = set(["GET", "HEAD", "POST", "PUT", "DELETE"])

    def __init__(self, io_loop, client, request, release_callback,
                 final_callback, max_buffer_size):
        """
        :param io_loop: IOLoop to run on
        :param client: owning AsyncHTTPClient (used for redirects/mapping)
        :param request: HTTPRequest to execute
        :param release_callback: returns the connection slot to the client
        :param final_callback: receives the HTTPResponse exactly once
        :param max_buffer_size: IOStream read buffer cap
        """
        self.start_time = time.time()
        self.io_loop = io_loop
        self.client = client
        self.request = request
        self.release_callback = release_callback
        self.final_callback = final_callback
        self.code = None
        self.headers = None
        self.chunks = None
        self._decompressor = None
        # Timeout handle returned by IOLoop.add_timeout
        self._timeout = None
        with stack_context.StackContext(self.cleanup):
            parsed = urllib.parse.urlsplit(_unicode(self.request.url))
            if ssl is None and parsed.scheme == "https":
                raise ValueError("HTTPS requires either python2.6+ or "
                                 "curl_httpclient")
            if parsed.scheme not in ("http", "https"):
                raise ValueError("Unsupported url scheme: %s" %
                                 self.request.url)
            # urlsplit results have hostname and port results, but they
            # didn't support ipv6 literals until python 2.7.
            netloc = parsed.netloc
            if "@" in netloc:
                # Strip userinfo; credentials are handled via the
                # Authorization header below.
                userpass, _, netloc = netloc.rpartition("@")
            match = re.match(r'^(.+):(\d+)$', netloc)
            if match:
                host = match.group(1)
                port = int(match.group(2))
            else:
                host = netloc
                port = 443 if parsed.scheme == "https" else 80
            if re.match(r'^\[.*\]$', host):
                # raw ipv6 addresses in urls are enclosed in brackets
                host = host[1:-1]
            if self.client.hostname_mapping is not None:
                host = self.client.hostname_mapping.get(host, host)

            if request.allow_ipv6:
                af = socket.AF_UNSPEC
            else:
                # We only try the first IP we get from getaddrinfo,
                # so restrict to ipv4 by default.
                af = socket.AF_INET

            # NOTE(review): blocking DNS resolution on the IOLoop thread.
            addrinfo = socket.getaddrinfo(host, port, af, socket.SOCK_STREAM,
                                          0, 0)
            af, socktype, proto, canonname, sockaddr = addrinfo[0]

            if parsed.scheme == "https":
                ssl_options = {}
                if request.validate_cert:
                    ssl_options["cert_reqs"] = ssl.CERT_REQUIRED
                if request.ca_certs is not None:
                    ssl_options["ca_certs"] = request.ca_certs
                else:
                    ssl_options["ca_certs"] = _DEFAULT_CA_CERTS
                if request.client_key is not None:
                    ssl_options["keyfile"] = request.client_key
                if request.client_cert is not None:
                    ssl_options["certfile"] = request.client_cert
                self.stream = SSLIOStream(socket.socket(af, socktype, proto),
                                          io_loop=self.io_loop,
                                          ssl_options=ssl_options,
                                          max_buffer_size=max_buffer_size)
            else:
                self.stream = IOStream(socket.socket(af, socktype, proto),
                                       io_loop=self.io_loop,
                                       max_buffer_size=max_buffer_size)
            # Connect timeout; replaced with the request timeout once
            # connected (see _on_connect).
            timeout = min(request.connect_timeout, request.request_timeout)
            if timeout:
                self._timeout = self.io_loop.add_timeout(
                    self.start_time + timeout,
                    self._on_timeout)
            self.stream.set_close_callback(self._on_close)
            self.stream.connect(sockaddr,
                                functools.partial(self._on_connect, parsed))

    def _on_timeout(self):
        """Report a 599 timeout response and tear down the stream."""
        self._timeout = None
        self._run_callback(HTTPResponse(self.request, 599,
                                        request_time=time.time() - self.start_time,
                                        error=HTTPError(599, "Timeout")))
        self.stream.close()

    def _on_connect(self, parsed):
        """Socket connected: validate TLS, build and write the request."""
        if self._timeout is not None:
            self.io_loop.remove_timeout(self._timeout)
            self._timeout = None
        if self.request.request_timeout:
            self._timeout = self.io_loop.add_timeout(
                self.start_time + self.request.request_timeout,
                self._on_timeout)
        if (self.request.validate_cert and
            isinstance(self.stream, SSLIOStream)):
            match_hostname(self.stream.socket.getpeercert(),
                           parsed.hostname)
        if (self.request.method not in self._SUPPORTED_METHODS and
            not self.request.allow_nonstandard_methods):
            raise KeyError("unknown method %s" % self.request.method)
        for key in ('network_interface',
                    'proxy_host', 'proxy_port',
                    'proxy_username', 'proxy_password'):
            if getattr(self.request, key, None):
                raise NotImplementedError('%s not supported' % key)
        if "Host" not in self.request.headers:
            self.request.headers["Host"] = parsed.netloc
        username, password = None, None
        if parsed.username is not None:
            username, password = parsed.username, parsed.password
        elif self.request.auth_username is not None:
            username = self.request.auth_username
            password = self.request.auth_password
        if username is not None:
            auth = utf8(username) + b(":") + utf8(password)
            self.request.headers["Authorization"] = (b("Basic ") +
                                                     base64.b64encode(auth))
        if self.request.user_agent:
            self.request.headers["User-Agent"] = self.request.user_agent
        if not self.request.allow_nonstandard_methods:
            if self.request.method in ("POST", "PUT"):
                assert self.request.body is not None
            else:
                assert self.request.body is None
        if self.request.body is not None:
            self.request.headers["Content-Length"] = str(len(
                    self.request.body))
        if (self.request.method == "POST" and
            "Content-Type" not in self.request.headers):
            self.request.headers["Content-Type"] = "application/x-www-form-urlencoded"
        if self.request.use_gzip:
            self.request.headers["Accept-Encoding"] = "gzip"
        req_path = ((parsed.path or '/') +
                (('?' + parsed.query) if parsed.query else ''))
        request_lines = [utf8("%s %s HTTP/1.1" % (self.request.method,
                                                  req_path))]
        for k, v in self.request.headers.get_all():
            line = utf8(k) + b(": ") + utf8(v)
            # Reject header injection via embedded newlines.
            if b('\n') in line:
                raise ValueError('Newline in header: ' + repr(line))
            request_lines.append(line)
        self.stream.write(b("\r\n").join(request_lines) + b("\r\n\r\n"))
        if self.request.body is not None:
            self.stream.write(self.request.body)
        self.stream.read_until_regex(b("\r?\n\r?\n"), self._on_headers)

    def _release(self):
        """Return the connection slot to the client (at most once)."""
        if self.release_callback is not None:
            release_callback = self.release_callback
            self.release_callback = None
            release_callback()

    def _run_callback(self, response):
        """Deliver the response to final_callback (at most once)."""
        self._release()
        if self.final_callback is not None:
            final_callback = self.final_callback
            self.final_callback = None
            final_callback(response)

    @contextlib.contextmanager
    def cleanup(self):
        """StackContext handler: turn uncaught exceptions into 599 responses."""
        try:
            yield
        except Exception as e:
            logging.warning("uncaught exception", exc_info=True)
            self._run_callback(HTTPResponse(self.request, 599, error=e, 
                                request_time=time.time() - self.start_time,
                                ))

    def _on_close(self):
        """Stream closed before completion: report a 599 connection error."""
        self._run_callback(HTTPResponse(
                self.request, 599,
                request_time=time.time() - self.start_time,
                error=HTTPError(599, "Connection closed")))

    def _on_headers(self, data):
        """Parse the status line and headers; start the right body read."""
        data = native_str(data.decode("latin1"))
        first_line, _, header_data = data.partition("\n")
        match = re.match("HTTP/1.[01] ([0-9]+)", first_line)
        assert match
        self.code = int(match.group(1))
        self.headers = HTTPHeaders.parse(header_data)
        if self.request.header_callback is not None:
            for k, v in self.headers.get_all():
                self.request.header_callback("%s: %s\r\n" % (k, v))
        if (self.request.use_gzip and
            self.headers.get("Content-Encoding") == "gzip"):
            # Magic parameter makes zlib module understand gzip header
            # http://stackoverflow.com/questions/1838699/how-can-i-decompress-a-gzip-stream-with-zlib
            self._decompressor = zlib.decompressobj(16+zlib.MAX_WBITS)
        if self.headers.get("Transfer-Encoding") == "chunked":
            self.chunks = []
            self.stream.read_until(b("\r\n"), self._on_chunk_length)
        elif "Content-Length" in self.headers:
            if "," in self.headers["Content-Length"]:
                # Proxies sometimes cause Content-Length headers to get
                # duplicated.  If all the values are identical then we can
                # use them but if they differ it's an error.
                pieces = re.split(r',\s*', self.headers["Content-Length"])
                if any(i != pieces[0] for i in pieces):
                    raise ValueError("Multiple unequal Content-Lengths: %r" % 
                                     self.headers["Content-Length"])
                self.headers["Content-Length"] = pieces[0]
            self.stream.read_bytes(int(self.headers["Content-Length"]),
                                   self._on_body)
        else:
            # No framing information: body ends when the server closes.
            self.stream.read_until_close(self._on_body)

    def _on_body(self, data):
        """Body complete: decompress, handle redirects, build the response."""
        if self._timeout is not None:
            self.io_loop.remove_timeout(self._timeout)
            self._timeout = None
        if self._decompressor:
            data = self._decompressor.decompress(data)
        if self.request.streaming_callback:
            if self.chunks is None:
                # if chunks is not None, we already called streaming_callback
                # in _on_chunk_data
                self.request.streaming_callback(data)
            buffer = BytesIO()
        else:
            buffer = BytesIO(data) # TODO: don't require one big string?
        original_request = getattr(self.request, "original_request",
                                   self.request)
        # NOTE(review): only 301/302 are followed; 303/307 are not.
        if (self.request.follow_redirects and
            self.request.max_redirects > 0 and
            self.code in (301, 302)):
            new_request = copy.copy(self.request)
            new_request.url = urllib.parse.urljoin(self.request.url,
                                               self.headers["Location"])
            new_request.max_redirects -= 1
            # Host must be recomputed for the redirect target.
            del new_request.headers["Host"]
            new_request.original_request = original_request
            final_callback = self.final_callback
            self.final_callback = None
            self._release()
            self.client.fetch(new_request, final_callback)
            self.stream.close()
            return
        response = HTTPResponse(original_request,
                                self.code, headers=self.headers,
                                request_time=time.time() - self.start_time,
                                buffer=buffer,
                                effective_url=self.request.url)
        self._run_callback(response)
        self.stream.close()

    def _on_chunk_length(self, data):
        """Parse a hex chunk-size line; 0 terminates the chunked body."""
        # TODO: "chunk extensions" http://tools.ietf.org/html/rfc2616#section-3.6.1
        length = int(data.strip(), 16)
        if length == 0:
            # all the data has been decompressed, so we don't need to
            # decompress again in _on_body
            self._decompressor = None
            self._on_body(b('').join(self.chunks))
        else:
            self.stream.read_bytes(length + 2,  # chunk ends with \r\n
                              self._on_chunk_data)

    def _on_chunk_data(self, data):
        """One chunk payload read; stream or accumulate it, then read on."""
        assert data[-2:] == b("\r\n")
        chunk = data[:-2]
        if self._decompressor:
            chunk = self._decompressor.decompress(chunk)
        if self.request.streaming_callback is not None:
            self.request.streaming_callback(chunk)
        else:
            self.chunks.append(chunk)
        self.stream.read_until(b("\r\n"), self._on_chunk_length)
Beispiel #30
0
class RedisPubSub(PubSubBase):
    """Redis pub/sub transport over a raw Tornado IOStream.

    Subscribes via hand-packed RESP commands and parses incoming replies
    with hiredis. NOTE(review): encode/pack_command use ``basestring``/
    ``unicode`` and write ``str`` into a BytesIO — Python 2 only as written.
    """

    def __init__(self, host='127.0.0.1', port=6379, *args, **kwargs):
        # Redis server address for the pub/sub stream connection.
        self.host = host
        self.port = port
        super(RedisPubSub, self).__init__(*args, **kwargs)

    @staticmethod
    def get_redis():
        # Blocking client used only for publishing.
        # NOTE(review): hard-coded address; ignores the host/port given to
        # __init__ — confirm whether that is intentional.
        return redis.StrictRedis(
            host = '127.0.0.1',
            port = 6379,
            db   = 0
        )

    ##
    ## pubsub api
    ##

    def connect(self):
        # Open the non-blocking stream; on_connect finishes the setup.
        self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.stream = IOStream(self.socket)
        self.stream.connect((self.host, self.port), self.on_connect)

    def disconnect(self):
        # Best-effort UNSUBSCRIBE, then close the stream.
        self.unsubscribe()
        self.stream.close()

    def subscribe(self, channel_id):
        self.send('SUBSCRIBE', channel_id)

    def unsubscribe(self, channel_id=None):
        # Without a channel_id, unsubscribes from all channels.
        if channel_id:
            self.send('UNSUBSCRIBE', channel_id)
        else:
            self.send('UNSUBSCRIBE')

    @staticmethod
    def publish(channel_id, message):
        # Publish synchronously through a throwaway blocking client.
        r = RedisPubSub.get_redis()
        r.publish(channel_id, message)

    ##
    ## socket/stream callbacks
    ##

    def on_connect(self):
        # Stream is up: start the streaming reader and the RESP parser.
        self.stream.set_close_callback(self.on_close)
        self.stream.read_until_close(self.on_data, self.on_streaming_data)
        self.reader = hiredis.Reader()
        self.connected()

    def on_data(self, *args, **kwargs):
        # Final read_until_close callback; all work happens in
        # on_streaming_data.
        pass

    def on_streaming_data(self, data):
        # Feed raw bytes to hiredis and drain every complete reply.
        # hiredis returns False while a reply is still incomplete.
        self.reader.feed(data)
        reply = self.reader.gets()
        while reply:
            if reply[0] == 'subscribe':
                self.subscribed(reply[1])
            elif reply[0] == 'unsubscribe':
                self.unsubscribed(reply[1])
            elif reply[0] == 'message':
                self.on_message(reply[1], reply[2])
            else:
                raise Exception('Unhandled data from redis %s' % reply)
            reply = self.reader.gets()

    def on_close(self):
        # Drop dead handles so a later connect() starts fresh.
        self.socket = None
        self.stream = None
        self.disconnected()

    ##
    ## redis protocol parser (derived from redis-py)
    ##

    def encode(self, value):
        # Normalize an arbitrary value to bytes for the RESP wire format.
        if isinstance(value, bytes):
            return value
        if isinstance(value, float):
            # repr keeps full float precision.
            value = repr(value)
        if not isinstance(value, basestring):
            value = str(value)
        if isinstance(value, unicode):
            value = value.encode('utf-8', 'strict')
        return value

    def pack_command(self, *args):
        # Serialize a command as a RESP multi-bulk:
        # *<argc>\r\n then $<len>\r\n<arg>\r\n per argument.
        cmd = io.BytesIO()
        cmd.write('*')
        cmd.write(str(len(args)))
        cmd.write('\r\n')
        for arg in args:
            arg = self.encode(arg)
            cmd.write('$')
            cmd.write(str(len(arg)))
            cmd.write('\r\n')
            cmd.write(arg)
            cmd.write('\r\n')
        return cmd.getvalue()

    def send(self, *args):
        """Send redis command."""
        cmd = self.pack_command(*args)
        self.stream.write(cmd)
Beispiel #31
0
class SoxDecoder:
    """Decodes a stream of encoded data to Sox via stdin and receives the output
    on stdout.

    Events (assign callables before calling start()): on_close, on_data_ready,
    on_unhandled_error, on_wav_params_ready.
    """
    def __init__(self, codec, out_channels, out_samplerate, out_samplesize):
        """
        :param codec: input format name understood by sox (e.g. "sbc")
        :param out_channels: channel count of the decoded WAV output
        :param out_samplerate: sample rate (Hz) of the decoded WAV output
        :param out_samplesize: bits per sample of the decoded WAV output
        """
        self._started = False
        # Initialize handles up front so stop(), _on_close() or any callback
        # arriving before start() cannot raise AttributeError.
        self._process = None
        self._input = None

        # codec and WAV params
        self._codec = codec
        self._channels = out_channels
        self._sample_rate = out_samplerate
        self._sample_size = out_samplesize

        # events
        self.on_close = None
        self.on_data_ready = None
        self.on_unhandled_error = None
        self.on_wav_params_ready = None

    @property
    def codec(self):
        """Input codec name passed to sox."""
        return self._codec

    @property
    def channels(self):
        """Channel count of the decoded output."""
        return self._channels

    @property
    def channel_mode(self):
        """Human-readable channel mode ("Mono", "Stereo" or "Unknown")."""
        return {1: "Mono", 2: "Stereo"}.get(self._channels, "Unknown")

    @property
    def sample_rate(self):
        """Sample rate (Hz) of the decoded output."""
        return self._sample_rate

    @property
    def sample_size(self):
        """Bits per sample of the decoded output."""
        return self._sample_size

    def start(self, socket_or_fd, read_mtu):
        """Starts the decoder. If already started, this does nothing.

        :param socket_or_fd: socket object or raw fd supplying encoded data
        :param read_mtu: unused here; kept for interface compatibility
        """
        if self._started:
            return

        # process: sox reads the codec stream on stdin and emits WAV on
        # stdout; stderr is drained so sox can never block on it.
        self._process = Subprocess([
            "sox", "-t", self._codec, "-", "--bits",
            str(self._sample_size), "--channels",
            str(self._channels), "--rate",
            str(self._sample_rate), "-t", "wav", "-"
        ],
                                   stdin=Subprocess.STREAM,
                                   stdout=Subprocess.STREAM,
                                   stderr=Subprocess.STREAM)
        self._process.stdout.set_close_callback(self._on_close)
        self._process.stdout.read_until_close(
            streaming_callback=self._out_data_ready)
        self._process.stderr.read_until_close(
            streaming_callback=self._sox_error)

        # did we get socket or fd?
        sock = socket_or_fd
        if isinstance(socket_or_fd, int):
            logger.debug("SoxDecoder received fd, building socket...")
            sock = socket.socket(fileno=socket_or_fd)
        sock.setblocking(True)

        # input pump: encoded bytes from the socket are fed to sox stdin.
        self._input = IOStream(socket=sock)
        self._input.read_until_close(self._in_data_ready)

        # start
        self._started = True

        # we know WAV params already
        if self.on_wav_params_ready:
            self.on_wav_params_ready()

    def stop(self):
        """Stops the decoder. If already stopped, this does nothing."""
        if not self._started:
            return

        self._started = False
        self._process.proc.kill()
        self._process = None
        # NOTE(review): self._input (and its socket) is left open here;
        # closing it would fire _in_data_ready after _started is cleared.
        # Candidate resource leak — confirm intended lifecycle with callers.

    def _on_close(self, *args):
        """Called when the Sox process exits; stops and fires on_close."""
        if not self._started:
            return

        self.stop()

        if self.on_close:
            self.on_close()

    def _in_data_ready(self, data):
        """Writes encoded data to the Sox input stream.

        :raises InvalidOperationError: if the decoder is not started
        """
        if not self._started:
            raise InvalidOperationError("Not started.")

        self._process.stdin.write(data)

    def _out_data_ready(self, data):
        """Called when decoded data is ready; forwards it to on_data_ready."""
        if self.on_data_ready:
            self.on_data_ready(data=data)

    def _sox_error(self, data):
        """Called when Sox writes to stderr. This isn't necessarily fatal, so
        we don't close the process.
        """
        if self.on_unhandled_error:
            self.on_unhandled_error(error=data)
Beispiel #32
0
class Connection(RedisCommandsMixin):
    """A pipelined connection to a Redis server.

    Replies are parsed by hiredis and matched FIFO against the callbacks
    queued by :meth:`send_message`.  A connection is "shared" while it is
    in ``redis._shared``; WATCH/MULTI require an exclusive (locked)
    connection.
    """

    def __init__(self, redis, on_connect=None):
        """Open a TCP connection to ``redis._addr`` on ``redis._ioloop``.

        :param redis: owning client (supplies ``_family``, ``_addr``,
            ``_ioloop``, ``_shared``, ``_database``).
        :param on_connect: optional callback invoked once with this
            connection after the socket connects.
        """
        logger.debug('Creating new Redis connection.')
        self.redis = redis
        self.reader = hiredis.Reader()
        self._watch = set()       # keys WATCHed on this connection
        self._multi = False       # True while inside MULTI ... EXEC/DISCARD
        self.callbacks = deque()  # FIFO of reply consumers (future.set_result)
        self._on_connect_callback = on_connect

        self.stream = IOStream(socket.socket(redis._family, socket.SOCK_STREAM,
                                             0),
                               io_loop=redis._ioloop)
        self.stream.set_close_callback(self._on_close)
        self.stream.connect(redis._addr, self._on_connect)

    def _on_connect(self):
        """Socket connected: start the read loop and join the shared pool."""
        logger.debug('Connected!')
        self.stream.read_until_close(self._on_close, self._on_read)
        self.redis._shared.append(self)
        if self._on_connect_callback is not None:
            self._on_connect_callback(self)
            self._on_connect_callback = None

    def _on_read(self, data):
        """Feed bytes to hiredis and dispatch every complete reply."""
        self.reader.feed(data)
        while True:
            resp = self.reader.gets()
            if resp is False:
                break  # no complete reply buffered yet
            callback = self.callbacks.popleft()
            if callback is not None:
                self.redis._ioloop.add_callback(partial(callback, resp))

    def is_idle(self):
        """Return True when no replies are outstanding."""
        return len(self.callbacks) == 0

    def is_shared(self):
        """Return True while this connection sits in the shared pool."""
        return self in self.redis._shared

    def lock(self):
        """Take this connection out of the shared pool for exclusive use."""
        if not self.is_shared():
            raise Exception('Connection already is locked!')
        self.redis._shared.remove(self)

    def unlock(self, callback=None):
        """Reset transaction state and return the connection to the pool."""
        def cb(resp):
            # NOTE(review): hiredis may deliver b'OK' rather than 'OK'
            # depending on the Reader's encoding setting -- confirm before
            # relying on this assert.
            assert resp == 'OK'
            self.redis._shared.append(self)

        if self._multi:
            self.send_message(['DISCARD'])
        elif self._watch:
            self.send_message(['UNWATCH'])

        self.send_message(['SELECT', self.redis._database], cb)

    def send_message(self, args, callback=None):
        """Write one command to the server and return a Future for its reply.

        :param args: command and arguments, e.g. ``['SET', 'key', 'value']``.
        :param callback: optional; invoked with the reply when it arrives.
        :raises NotImplementedError: for pub/sub commands.
        :raises Exception: for WATCH/MULTI issued on a shared connection.
        """
        command = args[0]

        if 'SUBSCRIBE' in command:
            raise NotImplementedError('Not yet.')

        # Do not allow the commands, affecting the execution of other commands,
        # to be used on shared connection.
        if command in ('WATCH', 'MULTI'):
            if self.is_shared():
                raise Exception('Command %s is not allowed while connection '
                                'is shared!' % command)
            if command == 'WATCH':
                self._watch.add(args[1])
            if command == 'MULTI':
                self._multi = True

        # monitor transaction state, to unlock correctly
        if command in ('EXEC', 'DISCARD', 'UNWATCH'):
            if command in ('EXEC', 'DISCARD'):
                self._multi = False
            self._watch.clear()

        self.stream.write(self.format_message(args))

        future = Future()

        if callback is not None:
            future.add_done_callback(stack_context.wrap(callback))

        self.callbacks.append(future.set_result)

        return future

    def format_message(self, args):
        """Serialize ``args`` as a RESP multi-bulk request.

        Accepts ``str``, ``bytes``, and anything ``str()`` can render
        (e.g. ints).  Returns the encoded request as ``bytes``.
        """
        lines = [("*%d" % len(args)).encode('utf-8')]
        for arg in args:
            # Bug fix: bytes previously went through str(), which would
            # serialize the literal "b'...'" representation.
            if not isinstance(arg, (str, bytes)):
                arg = str(arg)
            if isinstance(arg, str):
                arg = arg.encode('utf-8')
            lines.append(("$%d" % len(arg)).encode('utf-8'))
            lines.append(arg)
        lines.append(b"")
        return b"\r\n".join(lines)

    def close(self):
        """Send QUIT and remove the connection from the shared pool.

        Bug fix: this previously called the undefined ``send_command``;
        the method implemented on this class is ``send_message``.
        """
        self.send_message(['QUIT'])
        if self.is_shared():
            self.lock()

    def _on_close(self, data=None):
        """Stream closed: flush any trailing reply bytes and unshare."""
        logger.debug('Redis connection was closed.')
        if data is not None:
            self._on_read(data)
        if self.is_shared():
            self.lock()
Beispiel #33
0
def handle_connection(connection, address):
    """Read until the client closes the connection, then print the message.

    NOTE(review): this function uses ``yield`` and therefore must run under
    a coroutine runner (e.g. decorated with ``gen.coroutine``); the
    decorator is not visible here -- confirm at the call site.

    :param connection: accepted socket to wrap in an IOStream.
    :param address: client address (unused).
    """
    stream = IOStream(connection)
    message = yield stream.read_until_close()
    print("message from client:", message.decode().strip())
Beispiel #34
0
class Server(Session):
    """Single-client TCP server session built on tornado's IOStream.

    Listens on a socket and promotes the first accepted connection to a
    CONNECTED session; further connections are closed while one client is
    connected.  After a client disconnects the server returns to the
    LISTENING state.
    """

    def __init__(self, protocol, io_loop=None, sock=None):
        """Create a server session; if ``sock`` is given, adopt it at once."""
        Session.__init__(self, protocol, io_loop)
        self.is_server = True
        if sock is not None:
            self._setup(sock)

    def bind(self, address):
        """Create the listening socket and bind it to ``address``."""
        self.address = address
        self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.sock.setblocking(0)
        self.sock.bind(address)

    def start(self):
        """Start listening and register the accept handler on the io_loop."""
        self.sock.listen(2)
        self.io_loop.add_handler(self.sock.fileno(), self._accept, self.io_loop.READ)
        self.status = SessionStream.LISTENING

    def listen(self, address):
        """Convenience wrapper: bind() then start()."""
        self.bind(address)
        self.start()

    def _accept(self, fd, event):
        """IOLoop callback: accept a pending connection on the listen socket."""
        if fd != self.sock.fileno():
            # Bug fix: this was a Python 2 `print` statement, a syntax
            # error under Python 3; use logging like the rest of the class.
            logging.error("panic: fd != sock.fileno() ....")
            return
        conn, address = self.sock.accept()
        if self.status == SessionStream.LISTENING:
            self._setup(conn)
        elif self.status == SessionStream.CONNECTED:
            logging.warning("already connected ...")
            conn.close()

    def _setup(self, conn):
        """Wrap ``conn`` in an IOStream and hand the session to the protocol."""
        self.ios = IOStream(conn, self.io_loop)
        self.status = SessionStream.CONNECTED
        self.protocol.connected(self)
        self.ios.set_close_callback(self._ios_closed)
        self.ios.read_until_close(self._disconnected, self._receiver)

    def close(self):
        """Close the client stream and the listening socket."""
        logging.debug("Server.close() called...")
        if self.ios is not None:
            self.ios.close()
        self.io_loop.remove_handler(self.sock.fileno())
        if self.sock is not None:
            self.sock.close()
        self.status = SessionStream.IDLE

    def _disconnected(self, data):
        """Final read callback: tear down per-client state, keep listening."""
        logging.info("client disconnected ... ")
        self.protocol.disconnected(self)
        self.clear_buffer()
        self.clear_timers()
        self.ios.close()
        self.status = SessionStream.LISTENING
        logging.debug("end of client disconnect ...")

    def _ios_closed(self):
        """Stream close callback: go back to accepting connections."""
        self.status = SessionStream.LISTENING
Beispiel #35
0
class _HTTPConnection(object):
    """One HTTP/1.x request/response exchange for a simple async HTTP client.

    Resolves the request URL, opens a plain or SSL IOStream, writes the
    request line, headers, and body, then parses the status line, headers
    and body (fixed Content-Length, chunked, or read-until-close), handling
    gzip decompression and redirects along the way.  Exactly one of the
    release/final callbacks sequence runs per request, via _run_callback.
    """

    _SUPPORTED_METHODS = set(["GET", "HEAD", "POST", "PUT", "DELETE"])

    def __init__(self, io_loop, client, request, release_callback, final_callback, max_buffer_size):
        """Start the request: resolve the host, connect, set timeouts.

        :param io_loop: IOLoop used for timeouts and the stream.
        :param client: owning client; supplies hostname_mapping.
        :param request: HTTPRequest-like object describing the request.
        :param release_callback: called once when the connection is released.
        :param final_callback: called once with the HTTPResponse.
        :param max_buffer_size: passed through to the IOStream.
        """
        self.start_time = time.time()
        self.io_loop = io_loop
        self.client = client
        self.request = request
        self.release_callback = release_callback
        self.final_callback = final_callback
        self.code = None
        self.headers = None
        self.chunks = None
        self._decompressor = None
        # Timeout handle returned by IOLoop.add_timeout
        self._timeout = None
        with stack_context.StackContext(self.cleanup):
            parsed = urllib.parse.urlsplit(_unicode(self.request.url))
            if ssl is None and parsed.scheme == "https":
                raise ValueError("HTTPS requires either python2.6+ or " "curl_httpclient")
            if parsed.scheme not in ("http", "https"):
                raise ValueError("Unsupported url scheme: %s" % self.request.url)
            # urlsplit results have hostname and port results, but they
            # didn't support ipv6 literals until python 2.7.
            netloc = parsed.netloc
            if "@" in netloc:
                userpass, _, netloc = netloc.rpartition("@")
            match = re.match(r"^(.+):(\d+)$", netloc)
            if match:
                host = match.group(1)
                port = int(match.group(2))
            else:
                host = netloc
                port = 443 if parsed.scheme == "https" else 80
            if re.match(r"^\[.*\]$", host):
                # raw ipv6 addresses in urls are enclosed in brackets
                host = host[1:-1]
            if self.client.hostname_mapping is not None:
                host = self.client.hostname_mapping.get(host, host)

            if request.allow_ipv6:
                af = socket.AF_UNSPEC
            else:
                # We only try the first IP we get from getaddrinfo,
                # so restrict to ipv4 by default.
                af = socket.AF_INET

            # NOTE(review): getaddrinfo is a blocking call on the IOLoop
            # thread -- acceptable here only because this client is the
            # "simple" implementation.
            addrinfo = socket.getaddrinfo(host, port, af, socket.SOCK_STREAM, 0, 0)
            af, socktype, proto, canonname, sockaddr = addrinfo[0]

            if parsed.scheme == "https":
                ssl_options = {}
                if request.validate_cert:
                    ssl_options["cert_reqs"] = ssl.CERT_REQUIRED
                if request.ca_certs is not None:
                    ssl_options["ca_certs"] = request.ca_certs
                else:
                    ssl_options["ca_certs"] = _DEFAULT_CA_CERTS
                if request.client_key is not None:
                    ssl_options["keyfile"] = request.client_key
                if request.client_cert is not None:
                    ssl_options["certfile"] = request.client_cert

                # SSL interoperability is tricky.  We want to disable
                # SSLv2 for security reasons; it wasn't disabled by default
                # until openssl 1.0.  The best way to do this is to use
                # the SSL_OP_NO_SSLv2, but that wasn't exposed to python
                # until 3.2.  Python 2.7 adds the ciphers argument, which
                # can also be used to disable SSLv2.  As a last resort
                # on python 2.6, we set ssl_version to SSLv3.  This is
                # more narrow than we'd like since it also breaks
                # compatibility with servers configured for TLSv1 only,
                # but nearly all servers support SSLv3:
                # http://blog.ivanristic.com/2011/09/ssl-survey-protocol-support.html
                if sys.version_info >= (2, 7):
                    ssl_options["ciphers"] = "DEFAULT:!SSLv2"
                else:
                    # This is really only necessary for pre-1.0 versions
                    # of openssl, but python 2.6 doesn't expose version
                    # information.
                    ssl_options["ssl_version"] = ssl.PROTOCOL_SSLv3

                self.stream = SSLIOStream(
                    socket.socket(af, socktype, proto),
                    io_loop=self.io_loop,
                    ssl_options=ssl_options,
                    max_buffer_size=max_buffer_size,
                )
            else:
                self.stream = IOStream(
                    socket.socket(af, socktype, proto), io_loop=self.io_loop, max_buffer_size=max_buffer_size
                )
            timeout = min(request.connect_timeout, request.request_timeout)
            if timeout:
                self._timeout = self.io_loop.add_timeout(self.start_time + timeout, self._on_timeout)
            self.stream.set_close_callback(self._on_close)
            self.stream.connect(sockaddr, functools.partial(self._on_connect, parsed))

    def _on_timeout(self):
        """Connect/request timeout fired: report a 599 and close the stream."""
        self._timeout = None
        self._run_callback(
            HTTPResponse(self.request, 599, request_time=time.time() - self.start_time, error=HTTPError(599, "Timeout"))
        )
        self.stream.close()

    def _on_connect(self, parsed):
        """Socket connected: validate, build and write the request, read headers."""
        if self._timeout is not None:
            self.io_loop.remove_timeout(self._timeout)
            self._timeout = None
        if self.request.request_timeout:
            self._timeout = self.io_loop.add_timeout(self.start_time + self.request.request_timeout, self._on_timeout)
        if self.request.validate_cert and isinstance(self.stream, SSLIOStream):
            match_hostname(self.stream.socket.getpeercert(), parsed.hostname)
        if self.request.method not in self._SUPPORTED_METHODS and not self.request.allow_nonstandard_methods:
            raise KeyError("unknown method %s" % self.request.method)
        for key in ("network_interface", "proxy_host", "proxy_port", "proxy_username", "proxy_password"):
            if getattr(self.request, key, None):
                raise NotImplementedError("%s not supported" % key)
        if "Host" not in self.request.headers:
            self.request.headers["Host"] = parsed.netloc
        username, password = None, None
        if parsed.username is not None:
            username, password = parsed.username, parsed.password
        elif self.request.auth_username is not None:
            username = self.request.auth_username
            password = self.request.auth_password or ""
        if username is not None:
            auth = utf8(username) + b(":") + utf8(password)
            self.request.headers["Authorization"] = b("Basic ") + base64.b64encode(auth)
        if self.request.user_agent:
            self.request.headers["User-Agent"] = self.request.user_agent
        if not self.request.allow_nonstandard_methods:
            if self.request.method in ("POST", "PUT"):
                assert self.request.body is not None
            else:
                assert self.request.body is None
        if self.request.body is not None:
            self.request.headers["Content-Length"] = str(len(self.request.body))
        if self.request.method == "POST" and "Content-Type" not in self.request.headers:
            self.request.headers["Content-Type"] = "application/x-www-form-urlencoded"
        if self.request.use_gzip:
            self.request.headers["Accept-Encoding"] = "gzip"
        req_path = (parsed.path or "/") + (("?" + parsed.query) if parsed.query else "")
        request_lines = [utf8("%s %s HTTP/1.1" % (self.request.method, req_path))]
        for k, v in self.request.headers.get_all():
            line = utf8(k) + b(": ") + utf8(v)
            if b("\n") in line:
                raise ValueError("Newline in header: " + repr(line))
            request_lines.append(line)
        self.stream.write(b("\r\n").join(request_lines) + b("\r\n\r\n"))
        if self.request.body is not None:
            self.stream.write(self.request.body)
        self.stream.read_until_regex(b("\r?\n\r?\n"), self._on_headers)

    def _release(self):
        """Run the release callback exactly once."""
        if self.release_callback is not None:
            release_callback = self.release_callback
            self.release_callback = None
            release_callback()

    def _run_callback(self, response):
        """Release the connection and run the final callback exactly once."""
        self._release()
        if self.final_callback is not None:
            final_callback = self.final_callback
            self.final_callback = None
            final_callback(response)

    @contextlib.contextmanager
    def cleanup(self):
        """Stack context: convert uncaught exceptions into a 599 response."""
        try:
            yield
        except Exception as e:
            logging.warning("uncaught exception", exc_info=True)
            self._run_callback(HTTPResponse(self.request, 599, error=e, request_time=time.time() - self.start_time))

    def _on_close(self):
        """Stream closed before completion: report a 599 connection error."""
        self._run_callback(
            HTTPResponse(
                self.request, 599, request_time=time.time() - self.start_time, error=HTTPError(599, "Connection closed")
            )
        )

    def _on_headers(self, data):
        """Parse the status line and headers, then schedule the body read."""
        data = native_str(data.decode("latin1"))
        first_line, _, header_data = data.partition("\n")
        match = re.match("HTTP/1.[01] ([0-9]+)", first_line)
        assert match
        self.code = int(match.group(1))
        self.headers = HTTPHeaders.parse(header_data)

        if "Content-Length" in self.headers:
            if "," in self.headers["Content-Length"]:
                # Proxies sometimes cause Content-Length headers to get
                # duplicated.  If all the values are identical then we can
                # use them but if they differ it's an error.
                pieces = re.split(r",\s*", self.headers["Content-Length"])
                if any(i != pieces[0] for i in pieces):
                    raise ValueError("Multiple unequal Content-Lengths: %r" % self.headers["Content-Length"])
                self.headers["Content-Length"] = pieces[0]
            content_length = int(self.headers["Content-Length"])
        else:
            content_length = None

        if self.request.header_callback is not None:
            for k, v in self.headers.get_all():
                self.request.header_callback("%s: %s\r\n" % (k, v))

        if self.request.method == "HEAD":
            # HEAD requests never have content, even though they may have
            # content-length headers
            self._on_body(b(""))
            return
        if 100 <= self.code < 200 or self.code in (204, 304):
            # These response codes never have bodies
            # http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.3
            assert "Transfer-Encoding" not in self.headers
            assert content_length in (None, 0)
            self._on_body(b(""))
            return

        if self.request.use_gzip and self.headers.get("Content-Encoding") == "gzip":
            # Magic parameter makes zlib module understand gzip header
            # http://stackoverflow.com/questions/1838699/how-can-i-decompress-a-gzip-stream-with-zlib
            self._decompressor = zlib.decompressobj(16 + zlib.MAX_WBITS)
        if self.headers.get("Transfer-Encoding") == "chunked":
            self.chunks = []
            self.stream.read_until(b("\r\n"), self._on_chunk_length)
        elif content_length is not None:
            self.stream.read_bytes(content_length, self._on_body)
        else:
            self.stream.read_until_close(self._on_body)

    def _on_body(self, data):
        """Body complete: handle redirects, decompress, build the response."""
        if self._timeout is not None:
            self.io_loop.remove_timeout(self._timeout)
            self._timeout = None
        original_request = getattr(self.request, "original_request", self.request)
        if self.request.follow_redirects and self.request.max_redirects > 0 and self.code in (301, 302, 303, 307):
            new_request = copy.copy(self.request)
            new_request.url = urllib.parse.urljoin(self.request.url, self.headers["Location"])
            new_request.max_redirects -= 1
            del new_request.headers["Host"]
            # http://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.3.4
            # client SHOULD make a GET request
            if self.code == 303:
                new_request.method = "GET"
                new_request.body = None
                # NOTE(review): copy.copy is shallow, so new_request shares
                # the headers object with self.request; deleting here
                # affects both.
                for h in ["Content-Length", "Content-Type", "Content-Encoding", "Transfer-Encoding"]:
                    try:
                        del self.request.headers[h]
                    except KeyError:
                        pass
            new_request.original_request = original_request
            final_callback = self.final_callback
            self.final_callback = None
            self._release()
            self.client.fetch(new_request, final_callback)
            self.stream.close()
            return
        if self._decompressor:
            data = self._decompressor.decompress(data)
        if self.request.streaming_callback:
            if self.chunks is None:
                # if chunks is not None, we already called streaming_callback
                # in _on_chunk_data
                self.request.streaming_callback(data)
            buffer = BytesIO()
        else:
            buffer = BytesIO(data)  # TODO: don't require one big string?
        response = HTTPResponse(
            original_request,
            self.code,
            headers=self.headers,
            request_time=time.time() - self.start_time,
            buffer=buffer,
            effective_url=self.request.url,
        )
        self._run_callback(response)
        self.stream.close()

    def _on_chunk_length(self, data):
        """Parse a hex chunk-size line; zero means the body is complete."""
        # TODO: "chunk extensions" http://tools.ietf.org/html/rfc2616#section-3.6.1
        length = int(data.strip(), 16)
        if length == 0:
            # all the data has been decompressed, so we don't need to
            # decompress again in _on_body
            self._decompressor = None
            self._on_body(b("").join(self.chunks))
        else:
            self.stream.read_bytes(length + 2, self._on_chunk_data)  # chunk ends with \r\n

    def _on_chunk_data(self, data):
        """Handle one chunk: decompress, stream or buffer, read next length."""
        assert data[-2:] == b("\r\n")
        chunk = data[:-2]
        if self._decompressor:
            chunk = self._decompressor.decompress(chunk)
        if self.request.streaming_callback is not None:
            self.request.streaming_callback(chunk)
        else:
            self.chunks.append(chunk)
        self.stream.read_until(b("\r\n"), self._on_chunk_length)
Beispiel #36
0
class IRCClient(object):
    """Line-oriented IRC client on top of a tornado IOStream.

    Received bytes accumulate in ``current_chunk``; complete CRLF-terminated
    lines are parsed into Message objects and dispatched to the callbacks
    registered for their command.
    """

    def __init__(self, uri, ioloop):
        self.ioloop = ioloop
        self.callbacks = {
            "PING": [pong_callback],
            "NOTICE": [debug_callback],
            "ERROR": [die_callback]
        }
        self.conn = IRCConnection.from_uri(uri)
        self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        self.stream = IOStream(self.socket, io_loop=ioloop)
        self.current_chunk = ""
        # Bug fix: read_timeout() dereferences this attribute, but it was
        # never initialized anywhere in the class, so a timeout firing
        # before any assignment raised AttributeError.
        self.message_future = None

    def add_message_callback(self, command, func):
        """Register ``func`` to be invoked for messages with ``command``."""
        self.callbacks.setdefault(command, []).append(func)

    def read_timeout(self):
        """Fail the pending message future with ReadTimeout, if one exists."""
        if self.message_future and not self.message_future.done():
            self.message_future.set_exception(ReadTimeout("TIMEOUT"))

    def stream_bytes(self, chunk):
        """Streaming read callback: buffer, split CRLF lines, dispatch."""
        self.current_chunk += chunk
        while "\r\n" in self.current_chunk:
            original_message, self.current_chunk = self.current_chunk.split(
                "\r\n", 1)

            message = Message.from_message(original_message)

            # Ignore echoes of our own traffic.
            if message.ident[1:].startswith(self.conn.username):
                logging.debug(
                    "Skipping message from self: {0}".format(message.message))
                continue

            if message.command not in self.callbacks:
                logging.info("SKIPPING - {0}".format(original_message))
                continue

            for callback in self.callbacks[message.command]:
                try:
                    callback(message, self.stream)
                except StreamClosedError:
                    logging.error("Stream was closed.")
                    raise
                except Exception as exc:
                    logging.error("Exception {0} in callback.".format(exc))

    @gen.coroutine
    def listen(self):
        """Connect, register the client, join channels, pump until EOF.

        :returns: True on a clean close, False when the stream closed
            with an error during the read loop.
        """
        logging.info(
            "Connecting to {0}:{1}".format(self.conn.host, self.conn.port))
        self.stream.connect((self.conn.host, self.conn.port))

        logging.info("Registering the client.")

        self.stream.write("PASS {0}\r\n".format(self.conn.password))
        self.stream.write("NICK {0}\r\n".format(self.conn.username))
        self.stream.write("USER {0} {1} unused :{2}\r\n".format(
            self.conn.username, socket.gethostname(), self.conn.name))
        for channel in self.conn.channels:
            logging.info("Joining channel: #{0}".format(channel))
            self.stream.write("JOIN #{0}\r\n".format(channel))

        try:
            yield self.stream.read_until_close(
                streaming_callback=self.stream_bytes)
        except StreamClosedError:
            logging.error("Stream is closed.")
            raise gen.Return(False)
        raise gen.Return(True)

    def stop(self):
        """Send QUIT and close the stream."""
        logging.info("Client no longer listening.")
        self.stream.write("QUIT\r\n")
        self.stream.close()
Beispiel #37
0
class KeepAliveTest(AsyncHTTPTestCase):
    """Tests various scenarios for HTTP 1.1 keep-alive support.

    These tests don't use AsyncHTTPClient because we want to control
    connection reuse and closing.
    """

    def get_app(self):
        """Build the test Application with hello, large, and finish-on-close handlers."""
        class HelloHandler(RequestHandler):
            def get(self):
                self.finish("Hello world")

            def post(self):
                self.finish("Hello world")

        class LargeHandler(RequestHandler):
            def get(self):
                # 512KB should be bigger than the socket buffers so it will
                # be written out in chunks.
                self.write("".join(chr(i % 256) * 1024 for i in range(512)))

        class FinishOnCloseHandler(RequestHandler):
            def initialize(self, cleanup_event):
                self.cleanup_event = cleanup_event

            @gen.coroutine
            def get(self):
                self.flush()
                yield self.cleanup_event.wait()

            def on_connection_close(self):
                # This is not very realistic, but finishing the request
                # from the close callback has the right timing to mimic
                # some errors seen in the wild.
                self.finish("closed")

        self.cleanup_event = Event()
        return Application(
            [
                ("/", HelloHandler),
                ("/large", LargeHandler),
                (
                    "/finish_on_close",
                    FinishOnCloseHandler,
                    dict(cleanup_event=self.cleanup_event),
                ),
            ]
        )

    def setUp(self):
        super(KeepAliveTest, self).setUp()
        self.http_version = b"HTTP/1.1"

    def tearDown(self):
        # We just closed the client side of the socket; let the IOLoop run
        # once to make sure the server side got the message.
        self.io_loop.add_timeout(datetime.timedelta(seconds=0.001), self.stop)
        self.wait()

        if hasattr(self, "stream"):
            self.stream.close()
        super(KeepAliveTest, self).tearDown()

    # The next few methods are a crude manual http client
    @gen.coroutine
    def connect(self):
        """Open a raw IOStream to the test server."""
        self.stream = IOStream(socket.socket())
        yield self.stream.connect(("127.0.0.1", self.get_http_port()))

    @gen.coroutine
    def read_headers(self):
        """Read and parse a 200 status line plus headers from the stream."""
        first_line = yield self.stream.read_until(b"\r\n")
        self.assertTrue(first_line.startswith(b"HTTP/1.1 200"), first_line)
        header_bytes = yield self.stream.read_until(b"\r\n\r\n")
        headers = HTTPHeaders.parse(header_bytes.decode("latin1"))
        raise gen.Return(headers)

    @gen.coroutine
    def read_response(self):
        """Read headers and a 'Hello world' body; stores headers on self."""
        self.headers = yield self.read_headers()
        body = yield self.stream.read_bytes(int(self.headers["Content-Length"]))
        self.assertEqual(b"Hello world", body)

    def close(self):
        """Close and forget the client stream."""
        self.stream.close()
        del self.stream

    @gen_test
    def test_two_requests(self):
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.1\r\n\r\n")
        yield self.read_response()
        self.stream.write(b"GET / HTTP/1.1\r\n\r\n")
        yield self.read_response()
        self.close()

    @gen_test
    def test_request_close(self):
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.1\r\nConnection: close\r\n\r\n")
        yield self.read_response()
        data = yield self.stream.read_until_close()
        self.assertTrue(not data)
        self.assertEqual(self.headers["Connection"], "close")
        self.close()

    # keepalive is supported for http 1.0 too, but it's opt-in
    @gen_test
    def test_http10(self):
        self.http_version = b"HTTP/1.0"
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.0\r\n\r\n")
        yield self.read_response()
        data = yield self.stream.read_until_close()
        self.assertTrue(not data)
        self.assertTrue("Connection" not in self.headers)
        self.close()

    @gen_test
    def test_http10_keepalive(self):
        self.http_version = b"HTTP/1.0"
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.stream.write(b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.close()

    @gen_test
    def test_http10_keepalive_extra_crlf(self):
        self.http_version = b"HTTP/1.0"
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.stream.write(b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.close()

    @gen_test
    def test_pipelined_requests(self):
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\n")
        yield self.read_response()
        yield self.read_response()
        self.close()

    @gen_test
    def test_pipelined_cancel(self):
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\n")
        # only read once
        yield self.read_response()
        self.close()

    @gen_test
    def test_cancel_during_download(self):
        yield self.connect()
        self.stream.write(b"GET /large HTTP/1.1\r\n\r\n")
        yield self.read_headers()
        yield self.stream.read_bytes(1024)
        self.close()

    @gen_test
    def test_finish_while_closed(self):
        yield self.connect()
        self.stream.write(b"GET /finish_on_close HTTP/1.1\r\n\r\n")
        yield self.read_headers()
        self.close()
        # Let the hanging coroutine clean up after itself
        self.cleanup_event.set()

    @gen_test
    def test_keepalive_chunked(self):
        self.http_version = b"HTTP/1.0"
        yield self.connect()
        self.stream.write(
            b"POST / HTTP/1.0\r\n"
            b"Connection: keep-alive\r\n"
            b"Transfer-Encoding: chunked\r\n"
            b"\r\n"
            b"0\r\n"
            b"\r\n"
        )
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.stream.write(b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.close()
Beispiel #38
0
class BaseTornadoClient(AsyncModbusClientMixin):
    """
    Base Tornado client

    Wraps a tornado IOStream around the socket supplied by get_socket(),
    and matches Modbus responses to pending request futures by
    transaction id.
    """
    stream = None
    io_loop = None

    def __init__(self, *args, **kwargs):
        """
        Initializes BaseTornadoClient.
        ioloop to be passed as part of kwargs ('ioloop')
        :param args:
        :param kwargs:
        """
        self.io_loop = kwargs.pop("ioloop", None)
        super(BaseTornadoClient, self).__init__(*args, **kwargs)

    @abc.abstractmethod
    def get_socket(self):
        """
        return instance of the socket to connect to
        """

    @gen.coroutine
    def connect(self):
        """
        Connect to the socket identified by host and port

        :returns: Future
        :rtype: tornado.concurrent.Future
        """
        conn = self.get_socket()
        self.stream = IOStream(conn, io_loop=self.io_loop or IOLoop.current())
        # Bug fix: yield the connect so the coroutine waits for the
        # connection to complete; previously _connected was set while the
        # connect was still in flight.
        yield self.stream.connect((self.host, self.port))
        self.stream.read_until_close(None,
                                     streaming_callback=self.on_receive)
        self._connected = True
        LOGGER.debug("Client connected")

        raise gen.Return(self)

    def on_receive(self, *args):
        """
        On data recieve call back
        :param args: data received
        :return:
        """
        data = args[0] if len(args) > 0 else None

        if not data:
            return
        LOGGER.debug("recv: " + " ".join([hex(byte2int(x)) for x in data]))
        unit = self.framer.decode_data(data).get("uid", 0)
        self.framer.processIncomingPacket(data, self._handle_response, unit=unit)

    def execute(self, request=None):
        """
        Executes a transaction
        :param request:
        :return: Future resolving to the matching response
        """
        request.transaction_id = self.transaction.getNextTID()
        packet = self.framer.buildPacket(request)
        LOGGER.debug("send: " + " ".join([hex(byte2int(x)) for x in packet]))
        self.stream.write(packet)
        return self._build_response(request.transaction_id)

    def _handle_response(self, reply, **kwargs):
        """
        Handle response received
        :param reply:
        :param kwargs:
        :return:
        """
        if reply is not None:
            tid = reply.transaction_id
            future = self.transaction.getTransaction(tid)
            if future:
                future.set_result(reply)
            else:
                LOGGER.debug("Unrequested message: {}".format(reply))

    def _build_response(self, tid):
        """
        Builds a future response
        :param tid:
        :return:
        """
        f = Future()

        if not self._connected:
            f.set_exception(ConnectionException("Client is not connected"))
            return f

        self.transaction.addTransaction(f, tid)
        return f

    def close(self):
        """
        Closes the underlying IOStream
        """
        LOGGER.debug("Client disconnected")
        if self.stream:
            # Bug fix: use close() so the stream runs its close callbacks
            # and releases its buffers; close_fd() is internal to IOStream
            # and skips that teardown.
            self.stream.close()

        self.stream = None
        self._connected = False
class _HTTPConnection(object):
    """One HTTP/1.1 request/response exchange over a dedicated (SSL)IOStream.

    Lifecycle: resolve host -> connect (TLS for https) -> write request
    line/headers/body -> parse the response (fixed-length, chunked, or
    read-until-close), following redirects when requested -> invoke
    ``final_callback`` exactly once with an HTTPResponse, then close.

    Errors raised anywhere in the chain are converted into a 599 response
    by the ExceptionStackContext installed in ``__init__``.
    """

    # Methods outside this set raise KeyError unless
    # request.allow_nonstandard_methods is set.
    _SUPPORTED_METHODS = set(["GET", "HEAD", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"])

    def __init__(self, io_loop, client, request, release_callback, final_callback, max_buffer_size):
        """Validate/parse the request URL and start async name resolution.

        :param io_loop: IOLoop used for scheduling timeouts and callbacks
        :param client: owning client; supplies ``resolver`` and
            ``hostname_mapping``
        :param request: the request object to execute
        :param release_callback: invoked once when this connection is done
        :param final_callback: invoked once with the resulting HTTPResponse
        :param max_buffer_size: maximum stream read-buffer size in bytes
        """
        self.start_time = io_loop.time()
        self.io_loop = io_loop
        self.client = client
        self.request = request
        self.release_callback = release_callback
        self.final_callback = final_callback
        self.max_buffer_size = max_buffer_size
        self.code = None
        self.headers = None
        self.chunks = None
        self._decompressor = None
        # Timeout handle returned by IOLoop.add_timeout
        self._timeout = None
        with stack_context.ExceptionStackContext(self._handle_exception):
            self.parsed = urlparse.urlsplit(_unicode(self.request.url))
            if self.parsed.scheme not in ("http", "https"):
                raise ValueError("Unsupported url scheme: %s" % self.request.url)
            # urlsplit results have hostname and port results, but they
            # didn't support ipv6 literals until python 2.7.
            netloc = self.parsed.netloc
            if "@" in netloc:
                userpass, _, netloc = netloc.rpartition("@")
            match = re.match(r"^(.+):(\d+)$", netloc)
            if match:
                host = match.group(1)
                port = int(match.group(2))
            else:
                host = netloc
                port = 443 if self.parsed.scheme == "https" else 80
            if re.match(r"^\[.*\]$", host):
                # raw ipv6 addresses in urls are enclosed in brackets
                host = host[1:-1]
            self.parsed_hostname = host  # save final host for _on_connect
            if self.client.hostname_mapping is not None:
                host = self.client.hostname_mapping.get(host, host)

            if request.allow_ipv6:
                af = socket.AF_UNSPEC
            else:
                # We only try the first IP we get from getaddrinfo,
                # so restrict to ipv4 by default.
                af = socket.AF_INET

            self.client.resolver.getaddrinfo(host, port, af, socket.SOCK_STREAM, 0, 0, callback=self._on_resolve)

    def _on_resolve(self, future):
        """Build the stream for the first resolved address and connect.

        For https, assembles ssl_options (cert validation, client certs,
        SSLv2 workarounds) and uses SSLIOStream; otherwise a plain IOStream.
        Also arms the connect timeout.
        """
        af, socktype, proto, canonname, sockaddr = future.result()[0]

        if self.parsed.scheme == "https":
            ssl_options = {}
            if self.request.validate_cert:
                ssl_options["cert_reqs"] = ssl.CERT_REQUIRED
            if self.request.ca_certs is not None:
                ssl_options["ca_certs"] = self.request.ca_certs
            else:
                ssl_options["ca_certs"] = _DEFAULT_CA_CERTS
            if self.request.client_key is not None:
                ssl_options["keyfile"] = self.request.client_key
            if self.request.client_cert is not None:
                ssl_options["certfile"] = self.request.client_cert

            # SSL interoperability is tricky.  We want to disable
            # SSLv2 for security reasons; it wasn't disabled by default
            # until openssl 1.0.  The best way to do this is to use
            # the SSL_OP_NO_SSLv2, but that wasn't exposed to python
            # until 3.2.  Python 2.7 adds the ciphers argument, which
            # can also be used to disable SSLv2.  As a last resort
            # on python 2.6, we set ssl_version to SSLv3.  This is
            # more narrow than we'd like since it also breaks
            # compatibility with servers configured for TLSv1 only,
            # but nearly all servers support SSLv3:
            # http://blog.ivanristic.com/2011/09/ssl-survey-protocol-support.html
            if sys.version_info >= (2, 7):
                ssl_options["ciphers"] = "DEFAULT:!SSLv2"
            else:
                # This is really only necessary for pre-1.0 versions
                # of openssl, but python 2.6 doesn't expose version
                # information.
                ssl_options["ssl_version"] = ssl.PROTOCOL_SSLv3

            self.stream = SSLIOStream(
                socket.socket(af, socktype, proto),
                io_loop=self.io_loop,
                ssl_options=ssl_options,
                max_buffer_size=self.max_buffer_size,
            )
        else:
            self.stream = IOStream(
                socket.socket(af, socktype, proto), io_loop=self.io_loop, max_buffer_size=self.max_buffer_size
            )
        timeout = min(self.request.connect_timeout, self.request.request_timeout)
        if timeout:
            self._timeout = self.io_loop.add_timeout(self.start_time + timeout, stack_context.wrap(self._on_timeout))
        self.stream.set_close_callback(self._on_close)
        # ipv6 addresses are broken (in self.parsed.hostname) until
        # 2.7, here is correctly parsed value calculated in __init__
        self.stream.connect(sockaddr, self._on_connect, server_hostname=self.parsed_hostname)

    def _on_timeout(self):
        """Timeout fired: raise HTTPError(599) into the surrounding
        ExceptionStackContext (only while the final callback is pending)."""
        self._timeout = None
        if self.final_callback is not None:
            raise HTTPError(599, "Timeout")

    def _on_connect(self):
        """Connection established: swap the connect timeout for the request
        timeout, validate the method, fill in default headers (Connection,
        Host, Basic auth, User-Agent, Content-Length/Type, gzip), write the
        request, and start reading the response headers."""
        if self._timeout is not None:
            self.io_loop.remove_timeout(self._timeout)
            self._timeout = None
        if self.request.request_timeout:
            self._timeout = self.io_loop.add_timeout(
                self.start_time + self.request.request_timeout, stack_context.wrap(self._on_timeout)
            )
        if self.request.method not in self._SUPPORTED_METHODS and not self.request.allow_nonstandard_methods:
            raise KeyError("unknown method %s" % self.request.method)
        for key in ("network_interface", "proxy_host", "proxy_port", "proxy_username", "proxy_password"):
            if getattr(self.request, key, None):
                raise NotImplementedError("%s not supported" % key)
        if "Connection" not in self.request.headers:
            self.request.headers["Connection"] = "close"
        if "Host" not in self.request.headers:
            if "@" in self.parsed.netloc:
                self.request.headers["Host"] = self.parsed.netloc.rpartition("@")[-1]
            else:
                self.request.headers["Host"] = self.parsed.netloc
        username, password = None, None
        if self.parsed.username is not None:
            username, password = self.parsed.username, self.parsed.password
        elif self.request.auth_username is not None:
            username = self.request.auth_username
            password = self.request.auth_password or ""
        if username is not None:
            auth = utf8(username) + b":" + utf8(password)
            self.request.headers["Authorization"] = b"Basic " + base64.b64encode(auth)
        if self.request.user_agent:
            self.request.headers["User-Agent"] = self.request.user_agent
        if not self.request.allow_nonstandard_methods:
            if self.request.method in ("POST", "PATCH", "PUT"):
                assert self.request.body is not None
            else:
                assert self.request.body is None
        if self.request.body is not None:
            self.request.headers["Content-Length"] = str(len(self.request.body))
        if self.request.method == "POST" and "Content-Type" not in self.request.headers:
            self.request.headers["Content-Type"] = "application/x-www-form-urlencoded"
        if self.request.use_gzip:
            self.request.headers["Accept-Encoding"] = "gzip"
        req_path = (self.parsed.path or "/") + (("?" + self.parsed.query) if self.parsed.query else "")
        request_lines = [utf8("%s %s HTTP/1.1" % (self.request.method, req_path))]
        for k, v in self.request.headers.get_all():
            line = utf8(k) + b": " + utf8(v)
            if b"\n" in line:
                raise ValueError("Newline in header: " + repr(line))
            request_lines.append(line)
        self.stream.write(b"\r\n".join(request_lines) + b"\r\n\r\n")
        if self.request.body is not None:
            self.stream.write(self.request.body)
        self.stream.read_until_regex(b"\r?\n\r?\n", self._on_headers)

    def _release(self):
        """Invoke release_callback exactly once (idempotent)."""
        if self.release_callback is not None:
            release_callback = self.release_callback
            self.release_callback = None
            release_callback()

    def _run_callback(self, response):
        """Release the connection, then schedule final_callback with the
        response on the IOLoop (at most once)."""
        self._release()
        if self.final_callback is not None:
            final_callback = self.final_callback
            self.final_callback = None
            self.io_loop.add_callback(final_callback, response)

    def _handle_exception(self, typ, value, tb):
        """ExceptionStackContext handler: turn the first uncaught exception
        into a 599 response and close the stream; returns True when the
        exception was consumed."""
        if self.final_callback:
            gen_log.warning("uncaught exception", exc_info=(typ, value, tb))
            self._run_callback(
                HTTPResponse(self.request, 599, error=value, request_time=self.io_loop.time() - self.start_time)
            )

            if hasattr(self, "stream"):
                self.stream.close()
            return True
        else:
            # If our callback has already been called, we are probably
            # catching an exception that is not caused by us but rather
            # some child of our callback. Rather than drop it on the floor,
            # pass it along.
            return False

    def _on_close(self):
        """Stream closed before a response was delivered: surface a 599
        (with the stream's error, if any) via the ExceptionStackContext."""
        if self.final_callback is not None:
            message = "Connection closed"
            if self.stream.error:
                message = str(self.stream.error)
            raise HTTPError(599, message)

    def _on_headers(self, data):
        """Parse the status line and header block, then choose how to read
        the body: none (HEAD/304/1xx/204), chunked, fixed Content-Length,
        or read-until-close."""
        data = native_str(data.decode("latin1"))
        first_line, _, header_data = data.partition("\n")
        match = re.match("HTTP/1.[01] ([0-9]+) ([^\r]*)", first_line)
        assert match
        code = int(match.group(1))
        if 100 <= code < 200:
            # informational response: keep waiting for the real headers
            self.stream.read_until_regex(b"\r?\n\r?\n", self._on_headers)
            return
        else:
            self.code = code
            self.reason = match.group(2)
        self.headers = HTTPHeaders.parse(header_data)

        if "Content-Length" in self.headers:
            if "," in self.headers["Content-Length"]:
                # Proxies sometimes cause Content-Length headers to get
                # duplicated.  If all the values are identical then we can
                # use them but if they differ it's an error.
                pieces = re.split(r",\s*", self.headers["Content-Length"])
                if any(i != pieces[0] for i in pieces):
                    raise ValueError("Multiple unequal Content-Lengths: %r" % self.headers["Content-Length"])
                self.headers["Content-Length"] = pieces[0]
            content_length = int(self.headers["Content-Length"])
        else:
            content_length = None

        if self.request.header_callback is not None:
            # re-attach the newline we split on earlier
            self.request.header_callback(first_line + _)
            for k, v in self.headers.get_all():
                self.request.header_callback("%s: %s\r\n" % (k, v))
            self.request.header_callback("\r\n")

        if self.request.method == "HEAD" or self.code == 304:
            # HEAD requests and 304 responses never have content, even
            # though they may have content-length headers
            self._on_body(b"")
            return
        if 100 <= self.code < 200 or self.code == 204:
            # These response codes never have bodies
            # http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.3
            if "Transfer-Encoding" in self.headers or content_length not in (None, 0):
                raise ValueError("Response with code %d should not have body" % self.code)
            self._on_body(b"")
            return

        if self.request.use_gzip and self.headers.get("Content-Encoding") == "gzip":
            self._decompressor = GzipDecompressor()
        if self.headers.get("Transfer-Encoding") == "chunked":
            self.chunks = []
            self.stream.read_until(b"\r\n", self._on_chunk_length)
        elif content_length is not None:
            self.stream.read_bytes(content_length, self._on_body)
        else:
            self.stream.read_until_close(self._on_body)

    def _on_body(self, data):
        """Assemble the final body: follow redirects (301/302/303/307) by
        re-fetching, decompress gzip if needed, run the streaming callback,
        then deliver the HTTPResponse and close the stream."""
        if self._timeout is not None:
            self.io_loop.remove_timeout(self._timeout)
            self._timeout = None
        original_request = getattr(self.request, "original_request", self.request)
        if self.request.follow_redirects and self.request.max_redirects > 0 and self.code in (301, 302, 303, 307):
            assert isinstance(self.request, _RequestProxy)
            new_request = copy.copy(self.request.request)
            new_request.url = urlparse.urljoin(self.request.url, self.headers["Location"])
            new_request.max_redirects = self.request.max_redirects - 1
            del new_request.headers["Host"]
            # http://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.3.4
            # Client SHOULD make a GET request after a 303.
            # According to the spec, 302 should be followed by the same
            # method as the original request, but in practice browsers
            # treat 302 the same as 303, and many servers use 302 for
            # compatibility with pre-HTTP/1.1 user agents which don't
            # understand the 303 status.
            if self.code in (302, 303):
                new_request.method = "GET"
                new_request.body = None
                for h in ["Content-Length", "Content-Type", "Content-Encoding", "Transfer-Encoding"]:
                    try:
                        del self.request.headers[h]
                    except KeyError:
                        pass
            new_request.original_request = original_request
            final_callback = self.final_callback
            self.final_callback = None
            self._release()
            self.client.fetch(new_request, final_callback)
            self.stream.close()
            return
        if self._decompressor:
            data = self._decompressor.decompress(data) + self._decompressor.flush()
        if self.request.streaming_callback:
            if self.chunks is None:
                # if chunks is not None, we already called streaming_callback
                # in _on_chunk_data
                self.request.streaming_callback(data)
            buffer = BytesIO()
        else:
            buffer = BytesIO(data)  # TODO: don't require one big string?
        response = HTTPResponse(
            original_request,
            self.code,
            reason=self.reason,
            headers=self.headers,
            request_time=self.io_loop.time() - self.start_time,
            buffer=buffer,
            effective_url=self.request.url,
        )
        self._run_callback(response)
        self.stream.close()

    def _on_chunk_length(self, data):
        """Parse a hex chunk-size line; a size of zero marks the last chunk
        (flush the decompressor and finish the body)."""
        # TODO: "chunk extensions" http://tools.ietf.org/html/rfc2616#section-3.6.1
        length = int(data.strip(), 16)
        if length == 0:
            if self._decompressor is not None:
                tail = self._decompressor.flush()
                if tail:
                    # I believe the tail will always be empty (i.e.
                    # decompress will return all it can).  The purpose
                    # of the flush call is to detect errors such
                    # as truncated input.  But in case it ever returns
                    # anything, treat it as an extra chunk
                    if self.request.streaming_callback is not None:
                        self.request.streaming_callback(tail)
                    else:
                        self.chunks.append(tail)
                # all the data has been decompressed, so we don't need to
                # decompress again in _on_body
                self._decompressor = None
            self._on_body(b"".join(self.chunks))
        else:
            self.stream.read_bytes(length + 2, self._on_chunk_data)  # chunk ends with \r\n

    def _on_chunk_data(self, data):
        """Consume one chunk (strip the trailing CRLF), pass it to the
        streaming callback or accumulate it, then read the next size line."""
        assert data[-2:] == b"\r\n"
        chunk = data[:-2]
        if self._decompressor:
            chunk = self._decompressor.decompress(chunk)
        if self.request.streaming_callback is not None:
            self.request.streaming_callback(chunk)
        else:
            self.chunks.append(chunk)
        self.stream.read_until(b"\r\n", self._on_chunk_length)
Beispiel #40
0
class _RedisConnection(object):
    def __init__(self, final_callback, redis_tuple, redis_pwd):
        """
        :param final_callback: resp赋值时调用
        :param redis_tuple: (ip, port, db)
        :param redis_pwd: redis密码
        """
        self.__io_loop = IOLoop.instance()
        self.__resp_cb = final_callback
        self.__stream = None
        #redis应答解析remain
        self.__recv_buf = ''
        self.__redis_tuple = redis_tuple
        self.__redis_pwd = redis_pwd
        #redis指令上下文, connect指令个数(AUTH, SELECT .etc),trans,cmd_count
        self.__cmd_env = deque()
        self.__cache_before_connect = []
        self.__connected = False

    def con_ok(self):
        """
        连接对象是否ok
        :return:
        """
        return self.__connected

    def connect(self, init_future):
        """
        connect指令包括:AUTH, SELECT
        :param init_future: 第一个future对象
        """
        #future, connect_count, transaction, cmd_count
        self.__cmd_env.append((init_future, 1 + int(bool(self.__redis_pwd)), False, 0))
        self.__stream = IOStream(socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0),
                                 io_loop=self.__io_loop)
        self.__stream.set_close_callback(self.__on_close)
        self.__stream.connect(self.__redis_tuple[:2], self.__on_connect)

    def __on_connect(self):
        """连接,只需要发送初始cmd即可
        """
        self.__connected = True
        self.__stream.set_nodelay(True)
        self.__stream.read_until_close(self.__last_closd_recv, self.__on_resp)
        self.__stream.write(chain_select_cmd(self.__redis_pwd, self.__redis_tuple[-1]))
        for x in self.__cache_before_connect:
            self.__stream.write(x)
        self.__cache_before_connect = []

    def write(self, buf, new_future, active_trans, cmd_count):
        """
        :param new_future: 由于闭包的影响,在resp回调函数中会保存上一次的future对象,该对象必须得到更新
        :param active_trans: 事务是否激活
        :param cmd_count: 指令个数
        """
        self.__cmd_env.append((new_future, 0, active_trans, cmd_count))
        if not self.__connected:
            self.__cache_before_connect.append(buf)
            return
        self.__stream.write(buf)

    def __last_closd_recv(self, data):
        """
        socket关闭时最后几个字节
        """
        if not data:
            return
        self.__on_resp(data)

    def __on_resp(self, recv):
        """
        :param recv: 收到的buf
        """
        recv = ''.join((self.__recv_buf, recv))

        idx = 0
        for future, connect, trans, cmd in self.__cmd_env:
            ok, payload, recv = decode_resp_ondemand(recv, connect, trans, cmd)
            if not ok:
                break

            idx += 1
            if not connect:
                self.__run_callback({_RESP_FUTURE: future, RESP_RESULT: payload})

        self.__recv_buf = recv
        for _ in xrange(idx):
            self.__cmd_env.popleft()

    def __run_callback(self, resp):
        if self.__resp_cb is None:
            return
        self.__io_loop.add_callback(self.__resp_cb, resp)

    def __on_close(self):
        self.__connected = False
        while len(self.__cmd_env) > 0:
            self.__run_callback({_RESP_FUTURE: self.__cmd_env.popleft(), RESP_RESULT: 0})
        self.__cmd_env.clear()
class HTTPServerRawTest(AsyncHTTPTestCase):
    """Drives HTTPServer with raw bytes over a plain IOStream, bypassing
    AsyncHTTPClient so malformed and low-level requests can be produced."""

    def get_app(self):
        return Application([
            ('/echo', EchoHandler),
        ])

    def setUp(self):
        """Open one raw client connection to the test server per test."""
        super(HTTPServerRawTest, self).setUp()
        self.stream = IOStream(socket.socket())
        self.stream.connect(('127.0.0.1', self.get_http_port()), self.stop)
        self.wait()

    def tearDown(self):
        self.stream.close()
        super(HTTPServerRawTest, self).tearDown()

    def test_empty_request(self):
        """Closing without sending anything must not wedge the server."""
        self.stream.close()
        self.io_loop.add_timeout(datetime.timedelta(seconds=0.001), self.stop)
        self.wait()

    def test_malformed_first_line(self):
        """A garbage request line is logged and rejected, not crashed on."""
        with ExpectLog(gen_log, '.*Malformed HTTP request line'):
            self.stream.write(b'asdf\r\n\r\n')
            # TODO: need an async version of ExpectLog so we don't need
            # hard-coded timeouts here.
            self.io_loop.add_timeout(datetime.timedelta(seconds=0.05),
                                     self.stop)
            self.wait()

    def test_malformed_headers(self):
        """A header line without a colon is logged and rejected."""
        with ExpectLog(gen_log, '.*Malformed HTTP headers'):
            self.stream.write(b'GET / HTTP/1.0\r\nasdf\r\n\r\n')
            self.io_loop.add_timeout(datetime.timedelta(seconds=0.05),
                                     self.stop)
            self.wait()

    def test_chunked_request_body(self):
        # Chunked requests are not widely supported and we don't have a way
        # to generate them in AsyncHTTPClient, but HTTPServer will read them.
        self.stream.write(b"""\
POST /echo HTTP/1.1
Transfer-Encoding: chunked
Content-Type: application/x-www-form-urlencoded

4
foo=
3
bar
0

""".replace(b"\n", b"\r\n"))
        read_stream_body(self.stream, self.stop)
        headers, response = self.wait()
        self.assertEqual(json_decode(response), {u'foo': [u'bar']})

    def test_chunked_request_uppercase(self):
        # As per RFC 2616 section 3.6, "Transfer-Encoding" header's value is
        # case-insensitive.
        self.stream.write(b"""\
POST /echo HTTP/1.1
Transfer-Encoding: Chunked
Content-Type: application/x-www-form-urlencoded

4
foo=
3
bar
0

""".replace(b"\n", b"\r\n"))
        read_stream_body(self.stream, self.stop)
        headers, response = self.wait()
        self.assertEqual(json_decode(response), {u'foo': [u'bar']})

    def test_invalid_content_length(self):
        """A non-integer Content-Length is logged and rejected."""
        with ExpectLog(gen_log, '.*Only integer Content-Length is allowed'):
            self.stream.write(b"""\
POST /echo HTTP/1.1
Content-Length: foo

bar

""".replace(b"\n", b"\r\n"))
            self.stream.read_until_close(self.stop)
            self.wait()
Beispiel #42
0
class Client(RedisCommandsMixin):
    """
        Redis client class
    """
    def __init__(self, io_loop=None):
        """
            Constructor

            :param io_loop:
                Optional IOLoop instance
        """
        self._io_loop = io_loop or IOLoop.instance()

        self._stream = None

        self.reader = None
        self.callbacks = deque()

        # BUG FIX: was initialized to False, but every check in
        # send_message/send_messages/_set_sub_callback compares against
        # None ("is not None"/"is None"), so the False sentinel made a
        # fresh, never-reset client look like a PUBSUB connection.
        self._sub_callback = None

    def connect(self, host='localhost', port=6379, callback=None):
        """
            Connect to redis server

            :param host:
                Host to connect to
            :param port:
                Port
            :param callback:
                Optional callback to be triggered upon connection
        """
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        return self._connect(sock, (host, port), callback)

    def connect_usocket(self, usock, callback=None):
        """
            Connect to redis server with unix socket
        """
        sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM, 0)
        return self._connect(sock, usock, callback)

    def on_disconnect(self):
        """
            Override this method if you want to handle disconnections
        """
        pass

    # State
    def is_idle(self):
        """
            Check if client is not waiting for any responses
        """
        return len(self.callbacks) == 0

    def is_connected(self):
        """
            Check if client is still connected
        """
        return bool(self._stream) and not self._stream.closed()

    def send_message(self, args, callback=None):
        """
            Send command to redis

            :param args:
                Arguments to send
            :param callback:
                Callback
        """
        # Special case for pub-sub
        cmd = args[0]

        if (self._sub_callback is not None and
            cmd not in ('PSUBSCRIBE', 'SUBSCRIBE', 'PUNSUBSCRIBE', 'UNSUBSCRIBE')):
            raise ValueError('Cannot run normal command over PUBSUB connection')

        # Send command
        self._stream.write(self.format_message(args))
        if callback is not None:
            callback = stack_context.wrap(callback)
        self.callbacks.append((callback, None))

    def send_messages(self, args_pipeline, callback=None):
        """
            Send command pipeline to redis

            :param args_pipeline:
                Arguments pipeline to send
            :param callback:
                Callback
        """

        if self._sub_callback is not None:
            raise ValueError('Cannot run pipeline over PUBSUB connection')

        # Send command pipeline
        messages = [self.format_message(args) for args in args_pipeline]
        self._stream.write(b"".join(messages))
        if callback is not None:
            callback = stack_context.wrap(callback)
        self.callbacks.append((callback, (len(messages), [])))

    def format_message(self, args):
        """
            Create redis message

            :param args:
                Message data
        """
        # RESP encoding: "*<argc>" followed by "$<len>\r\n<arg>" per argument.
        l = "*%d" % len(args)
        lines = [l.encode('utf-8')]
        for arg in args:
            if not isinstance(arg, string_types):
                arg = str(arg)
            if isinstance(arg, text_type):
                arg = arg.encode('utf-8')
            l = "$%d" % len(arg)
            lines.append(l.encode('utf-8'))
            lines.append(arg)
        lines.append(b"")
        return b"\r\n".join(lines)

    def close(self):
        """
            Close redis connection
        """
        self.quit()
        self._stream.close()

    # Pub/sub commands
    def psubscribe(self, patterns, callback=None):
        """
            Customized psubscribe command - will keep one callback for all incoming messages

            :param patterns:
                string or list of strings
            :param callback:
                callback
        """
        self._set_sub_callback(callback)
        super(Client, self).psubscribe(patterns)

    def subscribe(self, channels, callback=None):
        """
            Customized subscribe command - will keep one callback for all incoming messages

            :param channels:
                string or list of strings
            :param callback:
                Callback
        """
        self._set_sub_callback(callback)
        super(Client, self).subscribe(channels)

    def _set_sub_callback(self, callback):
        """
            Install the single PUBSUB callback; it may only be set once.
        """
        if self._sub_callback is None:
            self._sub_callback = callback

        assert self._sub_callback == callback

    # Helpers
    def _connect(self, sock, addr, callback):
        """
            Reset parser state, wrap the socket in an IOStream and connect.
        """
        self._reset()

        self._stream = IOStream(sock, io_loop=self._io_loop)
        self._stream.connect(addr, callback=callback)
        self._stream.read_until_close(self._on_close, self._on_read)

    # Event handlers
    def _on_read(self, data):
        """
            Feed received bytes to the reply parser and dispatch callbacks.
        """
        self.reader.feed(data)

        resp = self.reader.gets()

        while resp is not False:
            if self._sub_callback:
                try:
                    self._sub_callback(resp)
                except:
                    logger.exception('SUB callback failed')
            else:
                if self.callbacks:
                    callback, callback_data = self.callbacks[0]
                    if callback_data is None:
                        callback_resp = resp
                    else:
                        # handle pipeline responses
                        num_resp, callback_resp = callback_data
                        callback_resp.append(resp)
                        while len(callback_resp) < num_resp:
                            resp = self.reader.gets()
                            if resp is False:
                                # callback_resp is yet incomplete
                                return
                            callback_resp.append(resp)
                    self.callbacks.popleft()
                    if callback is not None:
                        try:
                            callback(callback_resp)
                        except:
                            logger.exception('Callback failed')
                else:
                    logger.debug('Ignored response: %s' % repr(resp))

            resp = self.reader.gets()

    def _on_close(self, data=None):
        """
            Stream closed: flush remaining data, fail pending callbacks
            with None and notify on_disconnect().
        """
        if data is not None:
            self._on_read(data)

        # Trigger any pending callbacks
        callbacks = self.callbacks
        self.callbacks = deque()

        if callbacks:
            for cb in callbacks:
                callback, callback_data = cb
                if callback is not None:
                    try:
                        callback(None)
                    except:
                        logger.exception('Exception in callback')

        if self._sub_callback is not None:
            try:
                self._sub_callback(None)
            except:
                logger.exception('Exception in SUB callback')
            self._sub_callback = None

        # Trigger on_disconnect
        self.on_disconnect()

    def _reset(self):
        """
            Create a fresh reply parser and clear the PUBSUB callback.
        """
        self.reader = hiredis.Reader()
        self._sub_callback = None

    def pipeline(self):
        """
            Return a Pipeline bound to this client.
        """
        return Pipeline(self)
Beispiel #43
0
class Connection(RedisCommandsMixin):
    """A single Redis connection that registers itself in the owning
    client's shared pool and can be locked for exclusive (WATCH/MULTI) use."""

    def __init__(self, redis, on_connect=None):
        """
        :param redis: owning client (supplies _family, _ioloop, _addr,
            _shared, _database)
        :param on_connect: optional callback invoked once connected
        """
        logger.debug('Creating new Redis connection.')
        self.redis = redis
        self.reader = hiredis.Reader()
        self._watch = set()
        self._multi = False
        self.callbacks = deque()
        self._on_connect_callback = on_connect

        self.stream = IOStream(
            socket.socket(redis._family, socket.SOCK_STREAM, 0),
            io_loop=redis._ioloop
        )
        self.stream.set_close_callback(self._on_close)
        self.stream.connect(redis._addr, self._on_connect)

    def _on_connect(self):
        """Connected: start reading replies and join the shared pool."""
        logger.debug('Connected!')
        self.stream.read_until_close(self._on_close, self._on_read)
        self.redis._shared.append(self)
        if self._on_connect_callback is not None:
            self._on_connect_callback(self)
            self._on_connect_callback = None

    def _on_read(self, data):
        """Feed bytes to the parser and dispatch completed replies in order."""
        self.reader.feed(data)
        while True:
            resp = self.reader.gets()
            if resp is False:
                break
            callback = self.callbacks.popleft()
            if callback is not None:
                self.redis._ioloop.add_callback(partial(callback, resp))

    def is_idle(self):
        """True when no replies are pending."""
        return len(self.callbacks) == 0

    def is_shared(self):
        """True while this connection sits in the shared pool."""
        return self in self.redis._shared

    def lock(self):
        """Take the connection out of the shared pool for exclusive use."""
        if not self.is_shared():
            raise Exception('Connection already is locked!')
        self.redis._shared.remove(self)

    def unlock(self, callback=None):
        """Roll back any transaction state and return to the shared pool."""

        def cb(resp):
            assert resp == 'OK'
            self.redis._shared.append(self)

        if self._multi:
            self.send_message(['DISCARD'])
        elif self._watch:
            self.send_message(['UNWATCH'])

        self.send_message(['SELECT', self.redis._database], cb)

    def send_message(self, args, callback=None):
        """Encode and send one command; returns a Future for its reply.

        :param args: command and its arguments
        :param callback: optional callback attached to the returned future
        """

        command = args[0]

        if 'SUBSCRIBE' in command:
            raise NotImplementedError('Not yet.')

        # Do not allow the commands, affecting the execution of other commands,
        # to be used on shared connection.
        if command in ('WATCH', 'MULTI'):
            if self.is_shared():
                raise Exception('Command %s is not allowed while connection '
                                'is shared!' % command)
            if command == 'WATCH':
                self._watch.add(args[1])
            if command == 'MULTI':
                self._multi = True

        # monitor transaction state, to unlock correctly
        if command in ('EXEC', 'DISCARD', 'UNWATCH'):
            if command in ('EXEC', 'DISCARD'):
                self._multi = False
            self._watch.clear()

        self.stream.write(self.format_message(args))

        future = Future()

        if callback is not None:
            future.add_done_callback(stack_context.wrap(callback))

        self.callbacks.append(future.set_result)

        return future

    def format_message(self, args):
        """Encode a command as a RESP multi-bulk message."""
        l = "*%d" % len(args)
        lines = [l.encode('utf-8')]
        for arg in args:
            if not isinstance(arg, str):
                arg = str(arg)
            arg = arg.encode('utf-8')
            l = "$%d" % len(arg)
            lines.append(l.encode('utf-8'))
            lines.append(arg)
        lines.append(b"")
        return b"\r\n".join(lines)

    def close(self):
        """Send QUIT and remove the connection from the shared pool."""
        # BUG FIX: this class only defines send_message (see also the
        # sibling Client class) -- the original self.send_command(['QUIT'])
        # raised AttributeError.
        self.send_message(['QUIT'])
        if self.is_shared():
            self.lock()

    def _on_close(self, data=None):
        """Stream closed: flush trailing bytes and leave the shared pool."""
        logger.debug('Redis connection was closed.')
        if data is not None:
            self._on_read(data)
        if self.is_shared():
            self.lock()
Beispiel #44
0
class KeepAliveTest(AsyncHTTPTestCase):
    """Tests various scenarios for HTTP 1.1 keep-alive support.

    These tests don't use AsyncHTTPClient because we want to control
    connection reuse and closing.
    """
    def get_app(self):
        class HelloHandler(RequestHandler):
            def get(self):
                self.finish("Hello world")

            def post(self):
                self.finish("Hello world")

        class LargeHandler(RequestHandler):
            def get(self):
                # 512KB should be bigger than the socket buffers so it will
                # be written out in chunks.
                self.write("".join(chr(i % 256) * 1024 for i in range(512)))

        class FinishOnCloseHandler(RequestHandler):
            def initialize(self, cleanup_event):
                self.cleanup_event = cleanup_event

            @gen.coroutine
            def get(self):
                self.flush()
                # Hang until the test releases us via cleanup_event.set().
                yield self.cleanup_event.wait()

            def on_connection_close(self):
                # This is not very realistic, but finishing the request
                # from the close callback has the right timing to mimic
                # some errors seen in the wild.
                self.finish("closed")

        self.cleanup_event = Event()
        return Application([
            ("/", HelloHandler),
            ("/large", LargeHandler),
            (
                "/finish_on_close",
                FinishOnCloseHandler,
                dict(cleanup_event=self.cleanup_event),
            ),
        ])

    def setUp(self):
        super(KeepAliveTest, self).setUp()
        self.http_version = b"HTTP/1.1"

    def tearDown(self):
        # We just closed the client side of the socket; let the IOLoop run
        # once to make sure the server side got the message.
        self.io_loop.add_timeout(datetime.timedelta(seconds=0.001), self.stop)
        self.wait()

        if hasattr(self, "stream"):
            self.stream.close()
        super(KeepAliveTest, self).tearDown()

    # The next few methods are a crude manual http client
    @gen.coroutine
    def connect(self):
        """Open a raw client connection to the test server."""
        self.stream = IOStream(socket.socket())
        # Bug fix: the server started by AsyncHTTPTestCase listens on
        # loopback; the previous hard-coded "10.0.0.7" could never reach it.
        yield self.stream.connect(("127.0.0.1", self.get_http_port()))

    @gen.coroutine
    def read_headers(self):
        """Read and parse a 200 response's status line and headers."""
        first_line = yield self.stream.read_until(b"\r\n")
        self.assertTrue(first_line.startswith(b"HTTP/1.1 200"), first_line)
        header_bytes = yield self.stream.read_until(b"\r\n\r\n")
        headers = HTTPHeaders.parse(header_bytes.decode("latin1"))
        raise gen.Return(headers)

    @gen.coroutine
    def read_response(self):
        """Read headers plus the fixed "Hello world" body."""
        self.headers = yield self.read_headers()
        body = yield self.stream.read_bytes(int(
            self.headers["Content-Length"]))
        self.assertEqual(b"Hello world", body)

    def close(self):
        """Close the client stream and forget it (tearDown checks hasattr)."""
        self.stream.close()
        del self.stream

    @gen_test
    def test_two_requests(self):
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.1\r\n\r\n")
        yield self.read_response()
        self.stream.write(b"GET / HTTP/1.1\r\n\r\n")
        yield self.read_response()
        self.close()

    @gen_test
    def test_request_close(self):
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.1\r\nConnection: close\r\n\r\n")
        yield self.read_response()
        data = yield self.stream.read_until_close()
        self.assertTrue(not data)
        self.assertEqual(self.headers["Connection"], "close")
        self.close()

    # keepalive is supported for http 1.0 too, but it's opt-in
    @gen_test
    def test_http10(self):
        self.http_version = b"HTTP/1.0"
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.0\r\n\r\n")
        yield self.read_response()
        data = yield self.stream.read_until_close()
        self.assertTrue(not data)
        self.assertTrue("Connection" not in self.headers)
        self.close()

    @gen_test
    def test_http10_keepalive(self):
        self.http_version = b"HTTP/1.0"
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.stream.write(b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.close()

    @gen_test
    def test_http10_keepalive_extra_crlf(self):
        self.http_version = b"HTTP/1.0"
        yield self.connect()
        self.stream.write(
            b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.stream.write(b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.close()

    @gen_test
    def test_pipelined_requests(self):
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\n")
        yield self.read_response()
        yield self.read_response()
        self.close()

    @gen_test
    def test_pipelined_cancel(self):
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\n")
        # only read once
        yield self.read_response()
        self.close()

    @gen_test
    def test_cancel_during_download(self):
        yield self.connect()
        self.stream.write(b"GET /large HTTP/1.1\r\n\r\n")
        yield self.read_headers()
        yield self.stream.read_bytes(1024)
        self.close()

    @gen_test
    def test_finish_while_closed(self):
        yield self.connect()
        self.stream.write(b"GET /finish_on_close HTTP/1.1\r\n\r\n")
        yield self.read_headers()
        self.close()
        # Let the hanging coroutine clean up after itself
        self.cleanup_event.set()

    @gen_test
    def test_keepalive_chunked(self):
        self.http_version = b"HTTP/1.0"
        yield self.connect()
        self.stream.write(b"POST / HTTP/1.0\r\n"
                          b"Connection: keep-alive\r\n"
                          b"Transfer-Encoding: chunked\r\n"
                          b"\r\n"
                          b"0\r\n"
                          b"\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.stream.write(b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.close()
Beispiel #45
0
class Client(RedisCommandsMixin):
    """
        Redis client class
    """
    def __init__(self, io_loop=None):
        """
            Constructor

            :param io_loop:
                Optional IOLoop instance
        """
        self._io_loop = io_loop or IOLoop.instance()

        self._stream = None

        self.reader = None
        self.callbacks = deque()

        # None means "not in pub/sub mode".  Every check in this class
        # compares the callback against None (see send_message and
        # _set_sub_callback), and _reset() also assigns None; the previous
        # ``False`` initializer wrongly satisfied ``is not None`` until the
        # first _reset().
        self._sub_callback = None

    def connect(self, host='localhost', port=6379, callback=None):
        """
            Connect to redis server

            :param host:
                Host to connect to
            :param port:
                Port
            :param callback:
                Optional callback to be triggered upon connection
        """
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        return self._connect(sock, (host, port), callback)

    def connect_usocket(self, usock, callback=None):
        """
            Connect to redis server with unix socket
        """
        sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM, 0)
        return self._connect(sock, usock, callback)

    def on_disconnect(self):
        """
            Override this method if you want to handle disconnections
        """
        pass

    # State
    def is_idle(self):
        """
            Check if client is not waiting for any responses
        """
        return len(self.callbacks) == 0

    def is_connected(self):
        """
            Check if client is still connected
        """
        return bool(self._stream) and not self._stream.closed()

    def send_message(self, args, callback=None):
        """
            Send command to redis

            :param args:
                Arguments to send
            :param callback:
                Callback
        """
        # Special case for pub-sub: once subscribed, only (un)subscribe
        # commands are allowed on this connection.
        cmd = args[0]

        if (self._sub_callback is not None
                and cmd not in ('PSUBSCRIBE', 'SUBSCRIBE', 'PUNSUBSCRIBE',
                                'UNSUBSCRIBE')):
            raise ValueError(
                'Cannot run normal command over PUBSUB connection')

        # Send command
        self._stream.write(self.format_message(args))
        if callback is not None:
            callback = stack_context.wrap(callback)
        self.callbacks.append(callback)

    def format_message(self, args):
        """
            Create redis message

            :param args:
                Message data
        """
        l = "*%d" % len(args)
        lines = [l.encode('utf-8')]
        for arg in args:
            # ``str`` (not the Python-2-only ``basestring``) keeps this
            # consistent with the other client class in this file and
            # working on Python 3.
            if not isinstance(arg, str):
                arg = str(arg)
            arg = arg.encode('utf-8')
            l = "$%d" % len(arg)
            lines.append(l.encode('utf-8'))
            lines.append(arg)
        lines.append(b"")
        return b"\r\n".join(lines)

    def close(self):
        """
            Close redis connection
        """
        self.quit()
        self._stream.close()

    # Pub/sub commands
    def psubscribe(self, patterns, callback=None):
        """
            Customized psubscribe command - will keep one callback for all incoming messages

            :param patterns:
                string or list of strings
            :param callback:
                callback
        """
        self._set_sub_callback(callback)
        super(Client, self).psubscribe(patterns)

    def subscribe(self, channels, callback=None):
        """
            Customized subscribe command - will keep one callback for all incoming messages

            :param channels:
                string or list of strings
            :param callback:
                Callback
        """
        self._set_sub_callback(callback)
        super(Client, self).subscribe(channels)

    def _set_sub_callback(self, callback):
        # Only one pub/sub callback is supported per connection; repeated
        # subscriptions must reuse the same callback.
        if self._sub_callback is None:
            self._sub_callback = callback

        assert self._sub_callback == callback

    # Helpers
    def _connect(self, sock, addr, callback):
        self._reset()

        self._stream = IOStream(sock, io_loop=self._io_loop)
        self._stream.read_until_close(self._on_close, self._on_read)
        self._stream.connect(addr, callback=callback)

    # Event handlers
    def _on_read(self, data):
        """Feed raw bytes to the hiredis reader and dispatch full replies."""
        self.reader.feed(data)

        resp = self.reader.gets()

        while resp is not False:
            if self._sub_callback:
                try:
                    self._sub_callback(resp)
                except Exception:
                    logger.exception('SUB callback failed')
            else:
                if self.callbacks:
                    callback = self.callbacks.popleft()
                    if callback is not None:
                        try:
                            callback(resp)
                        except Exception:
                            logger.exception('Callback failed')
                else:
                    logger.debug('Ignored response: %s' % repr(resp))

            resp = self.reader.gets()

    def _on_close(self, data=None):
        """Stream-closed handler: flush pending data and fail callbacks."""
        if data is not None:
            self._on_read(data)

        # Trigger any pending callbacks
        callbacks = self.callbacks
        self.callbacks = deque()

        if callbacks:
            for cb in callbacks:
                if cb is not None:
                    try:
                        cb(None)
                    except Exception:
                        logger.exception('Exception in callback')

        if self._sub_callback is not None:
            try:
                self._sub_callback(None)
            except Exception:
                logger.exception('Exception in SUB callback')
            self._sub_callback = None

        # Trigger on_disconnect
        self.on_disconnect()

    def _reset(self):
        self.reader = hiredis.Reader()
        self._sub_callback = None
Beispiel #46
0
class HTTPServerRawTest(AsyncHTTPTestCase):
    """Exercises HTTPServer with hand-crafted (possibly malformed) requests
    written directly to a raw client socket.
    """
    def get_app(self):
        return Application([
            ('/echo', EchoHandler),
        ])

    def setUp(self):
        super(HTTPServerRawTest, self).setUp()
        # Raw client connection to the server under test.
        self.stream = IOStream(socket.socket())
        self.stream.connect(('127.0.0.1', self.get_http_port()), self.stop)
        self.wait()

    def tearDown(self):
        self.stream.close()
        super(HTTPServerRawTest, self).tearDown()

    def test_empty_request(self):
        """Closing without sending anything must not wedge the server."""
        self.stream.close()
        self.io_loop.add_timeout(datetime.timedelta(seconds=0.001), self.stop)
        self.wait()

    def test_malformed_first_line_response(self):
        """A bad request line yields a 400 response."""
        with ExpectLog(gen_log, '.*Malformed HTTP request line'):
            self.stream.write(b'asdf\r\n\r\n')
            read_stream_body(self.stream, self.stop)
            start_line, headers, response = self.wait()
            self.assertEqual('HTTP/1.1', start_line.version)
            self.assertEqual(400, start_line.code)
            self.assertEqual('Bad Request', start_line.reason)

    def test_malformed_first_line_log(self):
        """A bad request line is logged server-side."""
        with ExpectLog(gen_log, '.*Malformed HTTP request line'):
            self.stream.write(b'asdf\r\n\r\n')
            # TODO: need an async version of ExpectLog so we don't need
            # hard-coded timeouts here.
            self.io_loop.add_timeout(datetime.timedelta(seconds=0.05),
                                     self.stop)
            self.wait()

    def test_malformed_headers(self):
        """A header line with no colon is rejected and logged."""
        with ExpectLog(gen_log, '.*Malformed HTTP headers'):
            self.stream.write(b'GET / HTTP/1.0\r\nasdf\r\n\r\n')
            self.io_loop.add_timeout(datetime.timedelta(seconds=0.05),
                                     self.stop)
            self.wait()

    def test_chunked_request_body(self):
        # Chunked requests are not widely supported and we don't have a way
        # to generate them in AsyncHTTPClient, but HTTPServer will read them.
        self.stream.write(b"""\
POST /echo HTTP/1.1
Transfer-Encoding: chunked
Content-Type: application/x-www-form-urlencoded

4
foo=
3
bar
0

""".replace(b"\n", b"\r\n"))
        read_stream_body(self.stream, self.stop)
        start_line, headers, response = self.wait()
        self.assertEqual(json_decode(response), {u'foo': [u'bar']})

    def test_chunked_request_uppercase(self):
        # As per RFC 2616 section 3.6, "Transfer-Encoding" header's value is
        # case-insensitive.
        self.stream.write(b"""\
POST /echo HTTP/1.1
Transfer-Encoding: Chunked
Content-Type: application/x-www-form-urlencoded

4
foo=
3
bar
0

""".replace(b"\n", b"\r\n"))
        read_stream_body(self.stream, self.stop)
        start_line, headers, response = self.wait()
        self.assertEqual(json_decode(response), {u'foo': [u'bar']})

    def test_invalid_content_length(self):
        """A non-numeric Content-Length closes the connection with a log."""
        with ExpectLog(gen_log, '.*Only integer Content-Length is allowed'):
            self.stream.write(b"""\
POST /echo HTTP/1.1
Content-Length: foo

bar

""".replace(b"\n", b"\r\n"))
            self.stream.read_until_close(self.stop)
            self.wait()
class _PrxConn(object):
    """A single proxied connection to a backend server.

    Buffers writes issued before the TCP connection is established, decodes
    replies via ``decode_resp_ondemand``, and resolves one queued future per
    decoded payload.
    """
    def __init__(self, handle_resp, svr_addr):
        assert callable(handle_resp)
        self._io_loop = IOLoop.instance()

        # Called (on the IOLoop) with each decoded response dict.
        self.__resp_cb = handle_resp
        # Backend server address tuple.
        self.__svr_addr = svr_addr
        self._stream = None
        # Writes queued while the connection is still pending.
        self._send_buf = deque()
        # NOTE(review): a str buffer concatenated with stream data in
        # _on_recv — assumes Python 2 bytes-as-str; confirm before porting.
        self._recv_buf = ''
        # Futures awaiting a response, in FIFO order (one per command).
        self.__cmd_env = deque()
        self.__con_ok = False

    def con_ok(self):
        """
        Return whether the connection is currently established and usable.
        """
        return self.__con_ok

    def connect(self):
        """Open the TCP connection to the backend server."""
        self._stream = IOStream(socket.socket(socket.AF_INET,
                                              socket.SOCK_STREAM, 0),
                                io_loop=self._io_loop)
        self._stream.set_close_callback(self._on_close)
        self._stream.connect(self.__svr_addr, self._on_connect)

    def _on_connect(self):
        # Flush everything queued while the connection was pending, then
        # start streaming incoming data into _on_recv.
        self._stream.set_nodelay(True)
        while len(self._send_buf) > 0:
            self._stream.write(self._send_buf.popleft())
        self._stream.read_until_close(self._last_closd_recv, self._on_recv)
        self.__con_ok = True

    def write(self, future, encode_result):
        """Queue *future* for the reply and send (or buffer) the encoded data."""
        self.__cmd_env.append(future)
        if not self.__con_ok:
            self._send_buf.append(encode_result)
        else:
            self._stream.write(encode_result)

    def _last_closd_recv(self, data):
        """
        Handle the final few bytes delivered when the socket closes.
        """
        self._on_recv(data)

    def _on_recv(self, buf):
        # Accumulate bytes and decode as many complete payloads as possible.
        self._recv_buf += buf
        while 1:
            if not self._recv_buf:
                break
            ok, payload, self._recv_buf = decode_resp_ondemand(
                self._recv_buf, 0, False, 1)
            if not ok:
                # Incomplete message; wait for more bytes.
                break
            # Unwrap single-element replies for convenience.
            if payload and isinstance(payload,
                                      (list, tuple)) and 1 == len(payload):
                payload = payload[0]
            self.__run_callback({
                _RESP_FUTURE: self.__cmd_env.popleft(),
                RESP_RESULT: payload
            })

    def __run_callback(self, resp):
        # Dispatch via the IOLoop so the callback never runs re-entrantly.
        if self.__resp_cb is None:
            return
        self._io_loop.add_callback(self.__resp_cb, resp)

    def _on_close(self):
        # Connection lost: answer every outstanding command with result 0.
        self.__con_ok = False
        while len(self.__cmd_env) > 0:
            self.__run_callback({
                _RESP_FUTURE: self.__cmd_env.popleft(),
                RESP_RESULT: 0
            })
        self.__cmd_env.clear()
Beispiel #48
0
class KeepAliveTest(AsyncHTTPTestCase):
    """Tests various scenarios for HTTP 1.1 keep-alive support.

    These tests don't use AsyncHTTPClient because we want to control
    connection reuse and closing.
    """
    def get_app(self):
        class HelloHandler(RequestHandler):
            def get(self):
                self.finish('Hello world')

            def post(self):
                self.finish('Hello world')

        class LargeHandler(RequestHandler):
            def get(self):
                # 512KB should be bigger than the socket buffers so it will
                # be written out in chunks.
                self.write(''.join(chr(i % 256) * 1024 for i in range(512)))

        class FinishOnCloseHandler(RequestHandler):
            @asynchronous
            def get(self):
                # Deliberately never calls finish(); the connection-close
                # callback below does it instead.
                self.flush()

            def on_connection_close(self):
                # This is not very realistic, but finishing the request
                # from the close callback has the right timing to mimic
                # some errors seen in the wild.
                self.finish('closed')

        return Application([('/', HelloHandler),
                            ('/large', LargeHandler),
                            ('/finish_on_close', FinishOnCloseHandler)])

    def setUp(self):
        super(KeepAliveTest, self).setUp()
        self.http_version = b'HTTP/1.1'

    def tearDown(self):
        # We just closed the client side of the socket; let the IOLoop run
        # once to make sure the server side got the message.
        self.io_loop.add_timeout(datetime.timedelta(seconds=0.001), self.stop)
        self.wait()

        if hasattr(self, 'stream'):
            self.stream.close()
        super(KeepAliveTest, self).tearDown()

    # The next few methods are a crude manual http client
    def connect(self):
        """Open a raw client connection to the test server on loopback."""
        self.stream = IOStream(socket.socket())
        self.stream.connect(('127.0.0.1', self.get_http_port()), self.stop)
        self.wait()

    def read_headers(self):
        """Read and parse a 200 response's status line and headers."""
        self.stream.read_until(b'\r\n', self.stop)
        first_line = self.wait()
        self.assertTrue(first_line.startswith(b'HTTP/1.1 200'), first_line)
        self.stream.read_until(b'\r\n\r\n', self.stop)
        header_bytes = self.wait()
        headers = HTTPHeaders.parse(header_bytes.decode('latin1'))
        return headers

    def read_response(self):
        """Read headers plus the fixed 'Hello world' body."""
        self.headers = self.read_headers()
        self.stream.read_bytes(int(self.headers['Content-Length']), self.stop)
        body = self.wait()
        self.assertEqual(b'Hello world', body)

    def close(self):
        """Close the client stream and forget it (tearDown checks hasattr)."""
        self.stream.close()
        del self.stream

    def test_two_requests(self):
        self.connect()
        self.stream.write(b'GET / HTTP/1.1\r\n\r\n')
        self.read_response()
        self.stream.write(b'GET / HTTP/1.1\r\n\r\n')
        self.read_response()
        self.close()

    def test_request_close(self):
        self.connect()
        self.stream.write(b'GET / HTTP/1.1\r\nConnection: close\r\n\r\n')
        self.read_response()
        self.stream.read_until_close(callback=self.stop)
        data = self.wait()
        self.assertTrue(not data)
        self.assertEqual(self.headers['Connection'], 'close')
        self.close()

    # keepalive is supported for http 1.0 too, but it's opt-in
    def test_http10(self):
        self.http_version = b'HTTP/1.0'
        self.connect()
        self.stream.write(b'GET / HTTP/1.0\r\n\r\n')
        self.read_response()
        self.stream.read_until_close(callback=self.stop)
        data = self.wait()
        self.assertTrue(not data)
        self.assertTrue('Connection' not in self.headers)
        self.close()

    def test_http10_keepalive(self):
        self.http_version = b'HTTP/1.0'
        self.connect()
        self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n')
        self.read_response()
        self.assertEqual(self.headers['Connection'], 'Keep-Alive')
        self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n')
        self.read_response()
        self.assertEqual(self.headers['Connection'], 'Keep-Alive')
        self.close()

    def test_http10_keepalive_extra_crlf(self):
        self.http_version = b'HTTP/1.0'
        self.connect()
        self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n\r\n')
        self.read_response()
        self.assertEqual(self.headers['Connection'], 'Keep-Alive')
        self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n')
        self.read_response()
        self.assertEqual(self.headers['Connection'], 'Keep-Alive')
        self.close()

    def test_pipelined_requests(self):
        self.connect()
        self.stream.write(b'GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\n')
        self.read_response()
        self.read_response()
        self.close()

    def test_pipelined_cancel(self):
        self.connect()
        self.stream.write(b'GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\n')
        # only read once
        self.read_response()
        self.close()

    def test_cancel_during_download(self):
        self.connect()
        self.stream.write(b'GET /large HTTP/1.1\r\n\r\n')
        self.read_headers()
        self.stream.read_bytes(1024, self.stop)
        self.wait()
        self.close()

    def test_finish_while_closed(self):
        self.connect()
        self.stream.write(b'GET /finish_on_close HTTP/1.1\r\n\r\n')
        self.read_headers()
        self.close()

    def test_keepalive_chunked(self):
        self.http_version = b'HTTP/1.0'
        self.connect()
        self.stream.write(b'POST / HTTP/1.0\r\n'
                          b'Connection: keep-alive\r\n'
                          b'Transfer-Encoding: chunked\r\n'
                          b'\r\n'
                          b'0\r\n'
                          b'\r\n')
        self.read_response()
        self.assertEqual(self.headers['Connection'], 'Keep-Alive')
        self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n')
        self.read_response()
        self.assertEqual(self.headers['Connection'], 'Keep-Alive')
        self.close()
Beispiel #49
0
class SerialPortConnection:
    """Models a serial connection to a remote device over a Bluetooth RFCOMM
    link. Provides send and receive functionality (with proper parsing), and
    can track replies to certain requests.
    """

    # +CHLD capability values -> human-readable call-hold descriptions.
    CHLD_MAP = {
        0: "ReleaseAllHeldOrUDUB",
        1: "ReleaseAllActive,AcceptOther",
        2: "HoldAllActive,AcceptOther",
        3: "AddCallToConference",
        4: "JoinCalls,HangUp"
    }

    # +CME ERROR codes -> descriptive text.
    CME_ERROR_MAP = {
        0: "AG failure",
        1: "No connection to phone",
        3: "Operation not allowed",
        4: "Operation not supported",
        5: "PH-SIM PIN required",
        10: "SIM not inserted",
        11: "SIM PIN required",
        12: "SIM PUK required",
        13: "SIM failure",
        14: "SIM busy",
        16: "Incorrect password",
        17: "SIM PIN2 required",
        18: "SIM PUK2 required",
        20: "Memory full",
        21: "Invalid index",
        23: "Memory failure",
        24: "Text string too long",
        25: "Invalid text string",
        26: "Dial string too long",
        27: "Invalid dial string",
        30: "No network service",
        31: "Network timeout",
        32: "Emergency calls only"
    }

    def __init__(self, socket, device_path, async_reply_delay, io_loop):
        # NOTE: the ``socket`` parameter shadows the socket module here.
        self._async_reply_delay = async_reply_delay
        self._io_loop = io_loop
        # socket.getpeername() returns different address
        # so use the end of the device path instead
        self._peer = device_path[-17:].replace("_", ":")
        # Partial message carried over between _data_ready() calls.
        self._remainder = b''
        # <code>: [{} ->
        #   "future": <future>
        #   "handle": <timeout handle>]
        self._reply_q = defaultdict(list)
        self._socket = socket

        # Optional hooks, assigned by the owner of this connection.
        self.on_close = None
        self.on_error = None
        self.on_message = None

        self._stream = IOStream(socket=self._socket)
        self._stream.set_close_callback(self._on_close)
        self._stream.read_until_close(streaming_callback=self._data_ready)

    @property
    def peer(self):
        """Returns the address of the remote device.

        Derived from the last 17 characters of the device path rather than
        getpeername() (see __init__ for why).
        """
        return self._peer

    def close(self):
        """Voluntarily closes the RFCOMM connection.

        The stream's close callback (_on_close) performs the actual cleanup.
        """
        self._stream.close()

    def _async_timeout(self, code):
        """Called when an expected async reply doesn't arrive in the expected
        timeframe.

        :param code: reply code whose most recently queued entry timed out.
        """
        # Drop the newest pending entry for this code and fail its future.
        qentry = self._reply_q[code].pop()
        qentry["future"].set_exception(TimeoutError("Did not receive reply."))

    def _data_ready(self, data):
        """Parses data that has been received over the serial connection.

        Messages are <cr><lf>-delimited; a trailing partial message is kept
        in ``self._remainder`` until more bytes arrive.
        """
        logger.debug("Received {} bytes from AG over SPC - {}".format(
            len(data), data))
        if len(self._remainder) > 0:
            # Prepend bytes left over from the previous chunk.
            data = self._remainder + data
            logger.debug("Appended left-over bytes - {}".format(
                self._remainder))

        while True:
            # all AG -> HF messages are <cr><lf> delimited
            try:
                msg, data = data.split(b'\x0d\x0a', 1)
            except ValueError:
                # No complete message left; stash the partial bytes.
                self._remainder = data
                return

            # decode to ASCII, logging but ignoring decode errors
            try:
                msg = msg.decode('ascii', errors='strict')
            except UnicodeDecodeError as e:
                logger.warning("ASCII decode error, going to ignore dodgy "
                               "characters - {}".format(e))
                msg = msg.decode('ascii', errors='ignore')

            try:
                # Empty messages (bare CRLF) are skipped.
                if len(msg) > 0:
                    self._on_message(msg)
            except Exception:
                logger.exception("Message handler threw an unhandled "
                                 "exception with data \"{}\"".format(msg))

            if data == b'':
                self._remainder = b''
                return

    def _on_close(self, *args):
        """Handle the RFCOMM stream being closed by either side.

        Fails every future still awaiting a reply, then fires the optional
        ``on_close`` hook.
        """
        self._stream = None
        self._remainder = b''
        logger.info("Serial port connection to AG was closed.")

        # error out any remaining futures
        for entry in (e for pending in self._reply_q.values() for e in pending):
            entry["future"].set_exception(
                ConnectionError("Connection was closed."))
        self._reply_q.clear()

        if self.on_close:
            self.on_close()

    def _on_message(self, msg):
        """Invoked with a parsed message that we must now process.

        Routes replies to pending futures (when a request is tracked in
        ``_reply_q``) and/or to the ``on_message``/``on_error`` hooks.
        """

        if msg == "ERROR":
            # cleaner to report errors separately
            if self.on_error:
                self.on_error(None)

        elif msg == "OK":
            # simple ACK
            # get a Future if async tracking
            try:
                qentry = self._reply_q["OK"].pop()
                self._io_loop.remove_timeout(qentry["handle"])
            except IndexError:
                qentry = None
            if qentry:
                qentry["future"].set_result("OK")
            else:
                if self.on_message:
                    self.on_message(code="OK", data=None)

        elif msg == "RING":
            # ringing alert
            if self.on_message:
                self.on_message(code="RING", data=None)

        else:
            # strip leading "+" and split from first ":"
            # e.g. +BRSF: ...
            code, params = msg[1:].split(":", maxsplit=1)

            # shortcut to CME error reporting handler
            if code == "CME ERROR":
                if self.on_error:
                    self.on_error(self._handle_cme_error(params))
                return

            # find a handler function
            func_name = "_handle_{}".format(code.lower())
            try:
                handler = getattr(self, func_name)
            except AttributeError:
                logger.warning(
                    "No handler for code {}, ignoring...".format(code))
                return

            # get a Future if async tracking
            try:
                qentry = self._reply_q[code].pop()
                self._io_loop.remove_timeout(qentry["handle"])
            except IndexError:
                qentry = None

            # execute handler (and deal with Future)
            try:
                ret = handler(params=params.strip())
            except Exception as e:
                logger.error(
                    "Handler threw unhandled exception - {}".format(e))
                if qentry:
                    qentry["future"].set_exception(e)
                return
            if qentry:
                qentry["future"].set_result(ret)
            # NOTE(review): an "else:" guard appears to have been commented
            # out below, so on_message fires even when a future already
            # consumed the reply - confirm the double delivery is intended.
            if self.on_message:
                self.on_message(code=code, data=ret)

    def _handle_brsf(self, params):
        """Supported features of the AG.
        """
        params = int(params)

        return {
            "3WAY": (params & 0x1) == 0x1,
            "ECNR": (params & 0x2) == 0x2,
            "VOICE_RECOGNITION": (params & 0x4) == 0x4,
            "INBAND_RING": (params & 0x8) == 0x8,
            "PHONE_VTAG": (params & 0x10) == 0x10,
            "CALL_REJECT": (params & 0x20) == 0x20,
            "ECALL_STAT": (params & 0x40) == 0x40,
            "ECALL_CTRL": (params & 0x80) == 0x80,
            "EXTD_ERROR": (params & 0x100) == 0x100,
            "CODEC_NEG": (params & 0x200) == 0x200,
            "HF_INDICATORS": (params & 0x400) == 0x400,
            "ESCO_S4T2": (params & 0x800) == 0x800
        }

    def _handle_chld(self, params):
        """Info about how 3way/call wait is handled.

        *params* is a parenthesised sequence parsed with ast.literal_eval;
        each code is translated through CHLD_MAP when known, otherwise
        passed through unchanged.
        """
        codes = ast.literal_eval(params)
        chld_map = SerialPortConnection.CHLD_MAP
        translated = []
        for code in codes:
            translated.append(chld_map.get(code, code))
        return translated

    def _handle_ciev(self, params):
        """Single indicator update.
        """
        try:
            params = params.split(",")
            return {self._indmap[int(params[0]) - 1]: params[1]}
        except IndexError:
            logger.debug("Unknown indicator, will ignore it.")

    def _handle_cind(self, params):
        """Indicators available by the AG. This class maps the indices to actual
        names to make it easier upstream.
        """
        # either initial indicator info...
        # ("call",(0,1)),("callsetup",(0-3)),("service",(0-1)),("signal",(0-5)),
        # ("roam",(0,1)),("battchg",(0-5)),("callheld",(0-2))
        if "(" in params:
            params = ast.literal_eval(params)
            self._indmap = dict([(i, name)
                                 for i, (name, _) in enumerate(params)])
            return [name for name, _ in params]

        # ...or initial indicator values
        # 0,0,1,4,0,3,0
        return dict([(self._indmap[i], val)
                     for i, val in enumerate(params.split(","))])

    def _handle_clip(self, params):
        """Contains phone number of calling party (if CLI enabled).
        """
        # "0383417060",129
        if "," in params:
            params = params[:params.index(",")]
        return params.replace("\"", "")

    def _handle_cme_error(self, params):
        """Extended error code (+CME ERROR).

        Returns the mapped description when the numeric code is known,
        otherwise the raw params string.
        """
        error_map = SerialPortConnection.CME_ERROR_MAP
        return error_map.get(int(params), params)

    def _handle_cops(self, params):
        """Network operator query response.
        """
        params = params.split(",")
        return {"mode": params[0], "name": params[2]}

    def _handle_ccwa(self, params):
        """Contains phone number of calling party in a call-waiting scenario
        (if CLI enabled).
        """
        # "0383417060",129
        if "," in params:
            params = params[:params.index(",")]
        return params.replace("\"", "")

    def send_message(self, message, async_reply_code=None):
        """Sends a message over the serial port connection.

        If async_reply_code is not None, this returns a Future that can be
        yielded. The Future will resolve when the supplied reply code is
        next received. The Future will error-out if no reply is received
        in the delay (seconds) given in the constructor.

        Raises:
            ConnectionError: if writing to the underlying stream fails.
        """
        logger.debug("Sending \"{}\" over SPC.".format(message))
        # commands are terminated with <LF>
        data = message + "\x0a"
        try:
            self._stream.write(data.encode("ascii"))
        except Exception as e:
            logger.exception("Error sending \"{}\" over SPC.".format(message))
            # chain the underlying failure so callers can inspect the cause
            raise ConnectionError(
                "Error sending \"{}\" over SPC.".format(message)) from e

        # async tracking: queue a Future keyed by the expected reply code,
        # plus a timeout handle that will error the Future on no reply
        if async_reply_code:
            queue = self._reply_q[async_reply_code]
            fut = Future()
            handle = self._io_loop.call_later(delay=self._async_reply_delay,
                                              callback=self._async_timeout,
                                              code=async_reply_code)
            queue.append({"future": fut, "handle": handle})
            return fut

        return None