Ejemplo n.º 1
0
    def test_indexing_line(self):
        # NOTE(review): this snippet appears to have been damaged by
        # credential redaction -- the fetch() URL on the line below is never
        # closed, and whatever statements originally sat between the fetch
        # and the check ending in "@version'], 1)" are missing.  Restore the
        # original lines before relying on this test.
        client = AsyncHTTPClient(io_loop=self.io_loop)
        ping = yield client.fetch("http://*****:*****@version'], 1)
        self.assertEqual(doc['message'], "My name is Yuri and I'm 6 years old.")
Ejemplo n.º 2
0
    def test_connection_refused(self):
        # A refused connection must close the stream without ever running
        # the connect callback.  (The kqueue IOLoop used to behave
        # differently from the epoll IOLoop in this respect.)
        release_port, port = refusing_port()
        self.addCleanup(release_port)
        stream = IOStream(socket.socket(), self.io_loop)
        self.connect_called = False

        def on_connect():
            self.connect_called = True
            self.stop()

        stream.set_close_callback(self.stop)
        # Log output differs between platforms and IOLoop implementations,
        # so any message (or none at all) is accepted here.
        with ExpectLog(gen_log, ".*", required=False):
            stream.connect(("127.0.0.1", port), on_connect)
            self.wait()
        self.assertFalse(self.connect_called)
        self.assertTrue(isinstance(stream.error, socket.error), stream.error)
        if sys.platform == "cygwin":
            # cygwin's errnos don't match those used on native windows
            # python, so the errno comparison is skipped there.
            return
        refused_errnos = (errno.ECONNREFUSED,)
        if hasattr(errno, "WSAECONNREFUSED"):
            refused_errnos = refused_errnos + (errno.WSAECONNREFUSED,)
        self.assertTrue(stream.error.args[0] in refused_errnos)
Ejemplo n.º 3
0
    def __init__(self, io_loop, request, callback):
        """Parse ``request.url`` and start connecting a stream to its host.

        The work is done inside a ``StackContext`` built from
        ``self.cleanup``.  An ``SSLIOStream`` is used for ``https`` URLs and
        a plain ``IOStream`` otherwise; once connected, ``self._on_connect``
        is called with the parsed URL.

        Args:
            io_loop: loop handed to the stream.
            request: request object; only its ``url`` is read here.
            callback: stored on the instance as ``self.callback``.
        """
        self.io_loop = io_loop
        self.request = request
        self.callback = callback
        self.code = None
        self.headers = None
        self.chunks = None
        with stack_context.StackContext(self.cleanup):
            parsed = urlparse.urlsplit(self.request.url)
            is_https = parsed.scheme == "https"

            # partition() leaves the separator empty when the netloc holds
            # no ":"; in that case the scheme decides the port.
            host, colon, port_text = parsed.netloc.partition(":")
            if colon:
                port = int(port_text)
            else:
                port = 443 if is_https else 80

            # TODO: cert verification, etc
            stream_class = SSLIOStream if is_https else IOStream
            self.stream = stream_class(socket.socket(), io_loop=self.io_loop)

            on_connect = functools.partial(self._on_connect, parsed)
            self.stream.connect((host, port), on_connect)
Ejemplo n.º 4
0
 def _create_stream(self, max_buffer_size, af, addr):
     """Open a plaintext IOStream of family `af` and start connecting it.

     The connection is always made in plaintext; conversion to ssl, if
     necessary, happens after one connection has completed.

     Returns:
         Whatever ``IOStream.connect(addr)`` returns.
     """
     plain_socket = socket.socket(af)
     stream = IOStream(plain_socket, io_loop=self.io_loop,
                       max_buffer_size=max_buffer_size)
     return stream.connect(addr)
Ejemplo n.º 5
0
    def connect(self):
        """Open the connection to the MySQL server over a Tornado IOStream.

        Must be called from a child greenlet: after starting the
        non-blocking connect it switches to the parent greenlet and is
        resumed (or has the connect error thrown into it) by the
        ``connected`` callback.  Once connected it reads the server
        information, authenticates, and applies ``sql_mode``,
        ``init_command`` and ``autocommit_mode`` when they are set.

        Raises:
            err.OperationalError: code 2003, wrapping *any* exception raised
                while connecting or initialising the session (this includes
                the greenlet assertion below, which sits inside the ``try``).
        """
        self._loop = IOLoop.current()
        try:
            # A UNIX domain socket is used only when one is configured AND
            # the host is local; everything else goes over TCP.
            if self.unix_socket and self.host in ("localhost", "127.0.0.1"):
                sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
                self.host_info = "Localhost via UNIX socket"
                address = self.unix_socket
            else:
                sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
                self.host_info = "socket %s:%d" % (self.host, self.port)
                address = (self.host, self.port)
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
            if self.no_delay:
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
            # From here on ``sock`` is the IOStream wrapper, not the raw socket.
            sock = IOStream(sock)

            child_gr = greenlet.getcurrent()
            main = child_gr.parent
            assert main is not None, "Execut must be running in child greenlet"

            if self.connect_timeout:

                def timeout():
                    # ``self.socket`` is only assigned by ``connected`` below,
                    # so a falsy value here means the connect has not
                    # completed yet; closing the stream with an IOError
                    # aborts it.
                    # NOTE(review): assumes ``self.socket`` is initialised
                    # (e.g. to None) before connect() is called -- confirm.
                    if not self.socket:
                        sock.close((None, IOError("connection timeout")))

                # NOTE(review): the handle returned by call_later is not kept,
                # so this method never cancels the timer.
                self._loop.call_later(self.connect_timeout, timeout)

            def connected(future):
                # Runs on the IOLoop side: resume the waiting child greenlet,
                # either normally or by throwing the connect error into it.
                # NOTE(review): ``_exc_info`` is a private attribute of the
                # future object.
                if future._exc_info is not None:
                    child_gr.throw(future.exception())
                else:
                    self.socket = sock
                    child_gr.switch()

            future = sock.connect(address)
            self._loop.add_future(future, connected)
            # Hand control to the parent greenlet until ``connected`` switches
            # back (or throws) into this one.
            main.switch()

            self._rfile = self.socket
            self._get_server_information()
            self._request_authentication()

            if self.sql_mode is not None:
                c = self.cursor()
                c.execute("SET sql_mode=%s", (self.sql_mode,))

            if self.init_command is not None:
                c = self.cursor()
                c.execute(self.init_command)
                self.commit()

            if self.autocommit_mode is not None:
                self.autocommit(self.autocommit_mode)
        except Exception as e:
            # Tear down a stream that was already stored on the instance
            # before reporting the failure as a MySQL "can't connect" error.
            if self.socket:
                self._rfile = None
                self.socket.close()
                self.socket = None
            raise err.OperationalError(2003, "Can't connect to MySQL server on %s:%s (%s)" % (self.host, self.port, e))
Ejemplo n.º 6
0
class Connection(object):
    """Blocking-connect TCP connection that is then driven by an IOStream.

    Args:
        host: server host name or address.
        port: server TCP port.
        event_handler: object that receives connection events; only a weak
            proxy to it is stored.
        stop_after: socket timeout applied before the blocking connect
            (stored as ``self.timeout``).
        io_loop: loop handed to the ``IOStream``.
    """

    def __init__(self, host, port, event_handler,
                 stop_after=None, io_loop=None):
        self.host = host
        self.port = port
        # A weak proxy so the connection does not keep its handler alive.
        self._event_handler = weakref.proxy(event_handler)
        self.timeout = stop_after
        self._stream = None
        self._io_loop = io_loop
        self.try_left = 2

        self.in_progress = False
        self.read_queue = []
        self.read_callbacks = []

    def __del__(self):
        self.disconnect()

    def connect(self):
        """Connect synchronously and wrap the socket in an IOStream.

        Does nothing when a stream already exists.  After a successful
        connect, ``self.connected()`` is called and the ``on_connect`` event
        is fired.

        Raises:
            ConnectionError: when any socket-level error occurs.
        """
        if not self._stream:
            sock = None
            try:
                sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
                sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
                sock.settimeout(self.timeout)
                sock.connect((self.host, self.port))
                self._stream = IOStream(sock, io_loop=self._io_loop)
                self._stream.set_close_callback(self.on_stream_close)
                self.connected()
            except socket.error as e:
                # ``except X as e`` is valid on Python 2.6+ and 3.x; the old
                # comma form is a syntax error on Python 3.
                if sock is not None and not self._stream:
                    # No stream took ownership of the socket, so release the
                    # file descriptor instead of leaking it.
                    sock.close()
                raise ConnectionError(str(e))
            self.fire_event('on_connect')
Ejemplo n.º 7
0
class TCPClient(object):
    """Owns one IPv4 TCP socket and the IOStream wrapped around it.

    Attributes:
        io_loop: the loop passed to the constructor, or ``IOLoop.current()``.
        sock_fd: the underlying ``socket.socket`` object.
        stream: ``IOStream`` wrapping ``sock_fd``.
        shutdown: flag consulted by ``on_close``; False until
            ``set_shutdown`` is called.
    """

    # Class-level default so on_close() can run before set_shutdown() has
    # been called without raising AttributeError.
    shutdown = False

    def __init__(self, io_loop=None):
        self.io_loop = io_loop or IOLoop.current()

        self.sock_fd = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        self.sock_fd.settimeout(0.5)
        self.stream = IOStream(self.sock_fd)
        # NOTE: on_close is not registered as the stream's close callback
        # here; a caller that wants it has to register it.

    def connect(self, host, port):
        """Start connecting the stream to (host, port) and return the stream."""
        self.stream.connect((host, port))
        return self.stream

    @return_future
    def connect_server(self, host, port, callback=None):
        """Start connecting; `callback` is handed to ``IOStream.connect``."""
        self.stream.connect((host, port), callback=callback)

    def on_close(self):
        """Stop the IOLoop, but only once shutdown has been requested."""
        if self.shutdown:
            self.io_loop.stop()

    def set_shutdown(self):
        """Mark the client as shutting down (see ``on_close``)."""
        self.shutdown = True
Ejemplo n.º 8
0
 def test_empty_request(self):
     # Connect to the HTTP port, hang up without sending a single byte,
     # then keep the loop running for another millisecond.
     client_stream = IOStream(socket.socket(), io_loop=self.io_loop)
     server_address = ("localhost", self.get_http_port())
     client_stream.connect(server_address, self.stop)
     self.wait()
     client_stream.close()
     grace_period = datetime.timedelta(seconds=0.001)
     self.io_loop.add_timeout(grace_period, self.stop)
     self.wait()
Ejemplo n.º 9
0
 def _create_stream(self, max_buffer_size, af, addr, source_ip=None,
                    source_port=None):
     """Create a plaintext IOStream, optionally bound to a source ip/port.

     The connection is always made in plaintext; conversion to ssl, if
     necessary, happens after one connection has completed.

     Args:
         max_buffer_size: forwarded to ``IOStream``.
         af: address family for the new socket.
         addr: address handed to ``IOStream.connect``.
         source_ip: optional local IP to bind to.
         source_port: optional local port to bind to; any non-int value
             is treated as "no specific port" (0).

     Returns:
         ``(stream, connect_result)`` on success.
         NOTE(review): when ``IOStream`` construction fails with a socket
         error, a bare failed ``Future`` is returned instead of a 2-tuple;
         confirm that callers handle both shapes.

     Raises:
         socket.error: when binding to the requested source ip/port fails.
     """
     source_port_bind = source_port if isinstance(source_port, int) else 0
     source_ip_bind = source_ip
     if source_port_bind and not source_ip:
         # A specific port was requested without a source IP: bind to the
         # loopback address of the same family as the requested socket
         # (127.0.0.1 for IPv4, ::1 for IPv6).
         source_ip_bind = '::1' if af == socket.AF_INET6 else '127.0.0.1'
     socket_obj = socket.socket(af)
     set_close_exec(socket_obj.fileno())
     if source_port_bind or source_ip_bind:
         # The caller asked for a specific local IP and/or port.
         try:
             socket_obj.bind((source_ip_bind, source_port_bind))
         except socket.error:
             socket_obj.close()
             # Fail loudly if unable to use the IP/port.
             raise
     try:
         stream = IOStream(socket_obj,
                           max_buffer_size=max_buffer_size)
     except socket.error as e:
         # No stream took ownership of the socket, so close it here
         # (mirroring the bind failure path) instead of leaking it.
         socket_obj.close()
         fu = Future()
         fu.set_exception(e)
         return fu
     else:
         return stream, stream.connect(addr)
Ejemplo n.º 10
0
 def test_gaierror(self):
     # A getaddrinfo failure during connect must be recorded on the
     # IOStream as ``stream.error``.
     sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
     stream = IOStream(sock, io_loop=self.io_loop)
     stream.set_close_callback(self.stop)
     unresolvable = ('adomainthatdoesntexist.asdf', 54321)
     stream.connect(unresolvable)
     self.assertTrue(isinstance(stream.error, socket.gaierror), stream.error)
Ejemplo n.º 11
0
class IRCStream(object):
    """
    IOStream-backed connection to an IRC server.

    The server is given as an ``irc://host[:port]`` URL; port 6667 is used
    when the URL does not name one.
    """
    def __init__(self, nick, url, io_loop=None):
        self.nick = nick
        self.url = url
        self.io_loop = io_loop or IOLoop.instance()

        parts = urlparse.urlsplit(self.url)
        assert parts.scheme == 'irc'
        # partition() leaves the separator empty when there is no ':' in
        # the netloc, in which case the whole netloc is the host.
        host, colon, port_text = parts.netloc.partition(':')
        port = int(port_text) if colon else 6667
        self.host = host
        self.port = port

    def connect(self, callback):
        """Open a new stream and start connecting to ``(host, port)``.

        Once connected, the NICK line is written and ``callback(True)``
        is invoked.
        """
        self.stream = IOStream(socket.socket(), io_loop=self.io_loop)
        on_connected = functools.partial(self._on_connect, callback)
        self.stream.connect((self.host, self.port), on_connected)

    def _on_connect(self, callback):
        """Write the NICK line for ``self.nick``, then report success."""
        nick_line = 'NICK %s\r\n' % self.nick
        self.stream.write(nick_line)
        callback(True)
Ejemplo n.º 12
0
    def connect(self):
        """Create a connection object on top of a freshly opened stream.

        A new TCP socket is wrapped in an ``IOStream`` and a connect to
        ``(self.host, self.port)`` is started.  The stream, the socket's
        local address, ``self.data`` and ``self.terminator`` are then passed
        to ``self.connection`` and its result is returned.
        """
        raw_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        stream = IOStream(raw_socket)
        remote_address = (self.host, self.port)
        stream.connect(remote_address)
        local_address = raw_socket.getsockname()
        return self.connection(stream, local_address, self.data, self.terminator)
Ejemplo n.º 13
0
    def __init__(self, io_loop, request, callback):
        """Parse ``request.url``, arm the connect timeout and start connecting.

        The work is done inside a ``StackContext`` built from
        ``self.cleanup``.  ``https`` URLs get an ``SSLIOStream``, everything
        else a plain ``IOStream``.  When the smaller of the request's
        connect and request timeouts is truthy, ``self._on_timeout`` is
        scheduled that many seconds after ``self.start_time``.

        Args:
            io_loop: loop used for the stream and the timeout.
            request: request object (``url``, ``connect_timeout`` and
                ``request_timeout`` are read here).
            callback: stored on the instance as ``self.callback``.
        """
        self.start_time = time.time()
        self.io_loop = io_loop
        self.request = request
        self.callback = callback
        self.code = None
        self.headers = None
        self.chunks = None
        self._decompressor = None
        # Timeout handle returned by IOLoop.add_timeout
        self._timeout = None
        with stack_context.StackContext(self.cleanup):
            parsed = urlparse.urlsplit(self.request.url)
            is_https = parsed.scheme == "https"

            # partition() leaves the separator empty when the netloc holds
            # no ":"; in that case the scheme decides the port.
            host, colon, port_text = parsed.netloc.partition(":")
            if colon:
                port = int(port_text)
            else:
                port = 443 if is_https else 80

            # TODO: cert verification, etc
            stream_class = SSLIOStream if is_https else IOStream
            self.stream = stream_class(socket.socket(), io_loop=self.io_loop)

            timeout = min(request.connect_timeout, request.request_timeout)
            if timeout:
                deadline = self.start_time + timeout
                self._connect_timeout = self.io_loop.add_timeout(
                    deadline, self._on_timeout)

            on_connect = functools.partial(self._on_connect, parsed)
            self.stream.connect((host, port), on_connect)
Ejemplo n.º 14
0
    def test_handle_stream_coroutine_logging(self):
        # handle_stream is allowed to be a coroutine; an exception raised
        # inside its Future has to show up in the application log.
        class TestServer(TCPServer):
            @gen.coroutine
            def handle_stream(self, stream, address):
                yield gen.moment
                stream.close()
                1 / 0

        server = None
        client = None
        try:
            listener, port = bind_unused_port()
            with NullContext():
                server = TestServer()
                server.add_socket(listener)
            client = IOStream(socket.socket())
            server_address = ('localhost', port)
            with ExpectLog(app_log, "Exception in callback"):
                yield client.connect(server_address)
                yield client.read_until_close()
                yield gen.moment
        finally:
            # Either object may still be None if setup failed part-way.
            if server is not None:
                server.stop()
            if client is not None:
                client.close()
Ejemplo n.º 15
0
    def __init__(self, io_loop, client, request, callback):
        """Resolve the request URL and start connecting a stream to it.

        Inside a ``StackContext`` built from ``self.cleanup`` this splits the
        URL, strips userinfo and IPv6 brackets from the netloc, applies the
        client's ``hostname_mapping``, resolves the host with
        ``socket.getaddrinfo`` (first result only) and opens an
        ``SSLIOStream`` for ``https`` or an ``IOStream`` otherwise.  A connect
        timeout is scheduled when the smaller of the request's connect and
        request timeouts is truthy.

        Args:
            io_loop: loop used for the stream and the timeout.
            client: owning client; its ``hostname_mapping`` is consulted.
            request: request object describing URL, timeouts and ssl options.
            callback: stored on the instance as ``self.callback``.
        """
        self.start_time = time.time()
        self.io_loop = io_loop
        self.client = client
        self.request = request
        self.callback = callback
        self.code = None
        self.headers = None
        self.chunks = None
        self._decompressor = None
        # Timeout handle returned by IOLoop.add_timeout
        self._timeout = None
        with stack_context.StackContext(self.cleanup):
            parsed = urlparse.urlsplit(_unicode(self.request.url))
            is_https = parsed.scheme == "https"

            # urlsplit results have hostname and port results, but they
            # didn't support ipv6 literals until python 2.7, so the netloc
            # is taken apart by hand.  Everything up to the last "@" is
            # userinfo and is dropped.
            netloc = parsed.netloc.rpartition("@")[2]
            host_and_port = re.match(r"^(.+):(\d+)$", netloc)
            if host_and_port:
                host = host_and_port.group(1)
                port = int(host_and_port.group(2))
            else:
                host = netloc
                port = 443 if is_https else 80
            if re.match(r"^\[.*\]$", host):
                # raw ipv6 addresses in urls are enclosed in brackets
                host = host[1:-1]
            mapping = self.client.hostname_mapping
            if mapping is not None:
                host = mapping.get(host, host)

            # We only try the first IP we get from getaddrinfo, so the
            # lookup is restricted to ipv4 unless ipv6 is explicitly allowed.
            af = socket.AF_UNSPEC if request.allow_ipv6 else socket.AF_INET
            addrinfo = socket.getaddrinfo(host, port, af, socket.SOCK_STREAM, 0, 0)
            af, socktype, proto, canonname, sockaddr = addrinfo[0]

            if is_https:
                ssl_options = {}
                if request.validate_cert:
                    ssl_options["cert_reqs"] = ssl.CERT_REQUIRED
                if request.ca_certs is None:
                    ssl_options["ca_certs"] = _DEFAULT_CA_CERTS
                else:
                    ssl_options["ca_certs"] = request.ca_certs
                self.stream = SSLIOStream(
                    socket.socket(af, socktype, proto),
                    io_loop=self.io_loop,
                    ssl_options=ssl_options,
                )
            else:
                self.stream = IOStream(socket.socket(af, socktype, proto),
                                       io_loop=self.io_loop)

            timeout = min(request.connect_timeout, request.request_timeout)
            if timeout:
                deadline = self.start_time + timeout
                self._connect_timeout = self.io_loop.add_timeout(deadline, self._on_timeout)
            self.stream.set_close_callback(self._on_close)
            on_connect = functools.partial(self._on_connect, parsed)
            self.stream.connect(sockaddr, on_connect)
Ejemplo n.º 16
0
class ForwardConnection(object):
    """Send stored request headers to a remote host and relay its reply.

    Flow: connect to ``remote_address`` -> write ``headers`` -> read from
    the remote until it closes -> write what was read to ``stream`` and
    close ``stream`` once that write completes.
    """

    def __init__(self, remote_address, stream, address, headers):
        self.remote_address = remote_address
        self.stream = stream
        self.address = address
        self.headers = headers
        self.remote_stream = IOStream(socket.socket())
        self.remote_stream.connect(self.remote_address, self._on_remote_connected)
        self.remote_stream.set_close_callback(self._on_close)

    def _on_remote_connected(self):
        """Remote side is connected: forward the stored headers."""
        logging.info('forward %r to %r', self.address, self.remote_address)
        self.remote_stream.write(self.headers, self._on_remote_write_complete)

    def _on_remote_write_complete(self):
        """Headers were written: collect the reply until the remote closes."""
        logging.info('send request to %s', self.remote_address)
        self.remote_stream.read_until_close(self._on_remote_read_close)

    def _on_remote_read_close(self, data):
        """Hand the collected reply to the client stream, then close it."""
        client = self.stream
        client.write(data, client.close)

    def _on_close(self):
        """Close callback of the remote stream."""
        logging.info('remote quit %s', self.remote_address)
        self.remote_stream.close()
Ejemplo n.º 17
0
class HTTPServerRawTest(AsyncHTTPTestCase):
    """Tests that talk to the HTTP server over a raw, pre-connected IOStream."""

    def get_app(self):
        handlers = [("/echo", EchoHandler)]
        return Application(handlers)

    def setUp(self):
        super(HTTPServerRawTest, self).setUp()
        self.stream = IOStream(socket.socket())
        server_address = ("localhost", self.get_http_port())
        self.stream.connect(server_address, self.stop)
        self.wait()

    def tearDown(self):
        self.stream.close()
        super(HTTPServerRawTest, self).tearDown()

    def _run_loop_for(self, seconds):
        # Keep the IOLoop running for `seconds` before returning.
        self.io_loop.add_timeout(datetime.timedelta(seconds=seconds), self.stop)
        self.wait()

    def test_empty_request(self):
        self.stream.close()
        self._run_loop_for(0.001)

    def test_malformed_first_line(self):
        with ExpectLog(gen_log, ".*Malformed HTTP request line"):
            self.stream.write(b"asdf\r\n\r\n")
            # TODO: need an async version of ExpectLog so we don't need
            # hard-coded timeouts here.
            self._run_loop_for(0.01)

    def test_malformed_headers(self):
        with ExpectLog(gen_log, ".*Malformed HTTP headers"):
            self.stream.write(b"GET / HTTP/1.0\r\nasdf\r\n\r\n")
            self._run_loop_for(0.01)
Ejemplo n.º 18
0
    def execute(self, cmd):
        """Send `cmd` to the host and return whatever it answers.

        A new socket is opened to the address produced by ``self._resolve``,
        the stripped command plus a newline is written UTF-8 encoded, and
        the stream is read until the peer closes it.  The response is not
        validated.

        Note:
            No timeout is applied inside this method.  The original notes
            suggest implementing one with tornado.concurrent.chain_future:
            https://github.com/tornadoweb/tornado/blob/master/tornado/concurrent.py#L316

            or with the wrapper available in Tornado 4.0+, with_timeout:
            https://github.com/tornadoweb/tornado/blob/master/tornado/gen.py#L507

        Args:
            cmd: Four-letter string containing the command to execute.
        Returns:
            Raw response - bytes.
        Raises:
            HostConnectionTimeout: documented for the case where connection,
                request and response together exceed the timeout; it is not
                raised by the code in this method itself.
            Socket errors: such as ECONNREFUSED.
        """
        loop = IOLoop.current()
        family, address = yield self._resolve(loop)
        stream = IOStream(socket.socket(family), io_loop=loop)
        stream.connect(address)
        request = '{}\n'.format(cmd.strip()).encode('utf-8')
        yield gen.Task(stream.write, request)
        response = yield gen.Task(stream.read_until_close)
        raise gen.Return(response)
Ejemplo n.º 19
0
def handle_connection(connection, address):
    """Wrap an accepted connection in an IOStream and request the uuid size.

    The first 4 bytes read from the stream are delivered to
    ``read_uuid_size`` together with the stream itself.  A stream that is
    already closed is only logged, not raised.

    Args:
        connection: the accepted socket object.
        address: peer address, used for logging only.
    """
    peer = str(address)
    log.info("Connection received from %s" % peer)
    buffer_limit = 1024 * 1024 * 1024
    stream = IOStream(connection, ioloop, max_buffer_size=buffer_limit)
    # Getting uuid
    on_uuid_size = partial(read_uuid_size, stream)
    try:
        stream.read_bytes(4, on_uuid_size)
    except StreamClosedError:
        log.warn("Closed stream for getting uuid length")
Ejemplo n.º 20
0
def handle_connection(connection, address):
    """Start reading the 4-byte uuid length from a newly accepted connection.

    The connection is wrapped in an ``IOStream`` on the module's ``ioloop``;
    the 4 bytes read are passed to ``read_uuid_size`` along with the stream.
    If the stream is already closed, a warning is logged instead.

    Args:
        connection: the accepted socket object.
        address: peer address, used for logging only.
    """
    log.info('Connection received from %s' % str(address))
    stream = IOStream(connection, ioloop)
    # Getting uuid
    size_reader = partial(read_uuid_size, stream)
    try:
        stream.read_bytes(4, size_reader)
    except StreamClosedError:
        log.warn('Closed stream for getting uuid length')
Ejemplo n.º 21
0
    def connect(self):
        """Start connecting to the local port and register this client.

        ``IOStream.connect`` is invoked on ``self`` with ``self._on_connect``
        as the callback, and the client is stored in ``MjpgClient.clients``
        under its camera id.
        """
        address = ("localhost", self._port)
        IOStream.connect(self, address, self._on_connect)
        MjpgClient.clients[self._camera_id] = self

        details = {"port": self._port, "camera_id": self._camera_id}
        logging.debug(
            "mjpg client for camera %(camera_id)s connecting on port %(port)s..."
            % details
        )
Ejemplo n.º 22
0
 def initiate(cls, host, port, infohash):
     """Resolve (host, port) over IPv4 and start connecting a new IOStream.

     Only the first ``getaddrinfo`` result is used.  Once connected,
     ``cls.initiate_connected`` is called with the stream, the resolved
     socket address and `infohash`.
     """
     candidates = socket.getaddrinfo(host, port, socket.AF_INET,
                                     socket.SOCK_STREAM, 0, 0)
     family, socktype, proto, _canonname, sockaddr = candidates[0]
     stream = IOStream(socket.socket(family, socktype, proto),
                       io_loop=cls.io_loop)
     on_connected = functools.partial(cls.initiate_connected, stream,
                                      sockaddr, infohash)
     stream.connect(sockaddr, on_connected)
Ejemplo n.º 23
0
    def _handle_accept(self, fd, events):
        """Accept one pending connection and start reading lines from it.

        The new stream is stored in ``self._streams`` under the key
        ``"ip:port"``; the same key is passed as first argument to
        ``self._handle_close`` and ``self._handle_read``.
        """
        connection, address = self._socket.accept()
        stream = IOStream(connection)
        peer = "%s:%d" % address
        self._streams[peer] = stream

        # partial(f, peer) behaves like ``lambda *a: f(peer, *a)``
        on_close = functools.partial(self._handle_close, peer)
        stream.set_close_callback(on_close)
        on_line = functools.partial(self._handle_read, peer)
        stream.read_until("\r\n", on_line)
Ejemplo n.º 24
0
class SubProcessApplication(Application):
    """Run application class in subprocess.

    The parent talks to the child over one end of a ``socket.socketpair()``
    wrapped in an ``IOStream``.  In both directions a message is a JSON
    document, base64-encoded and terminated by CRLF.
    """

    def __init__(self, target, io_loop=None):
        """Store the target application and the IOLoop to use.

        Args:
            target: the application to run, or a module name (str) that is
                imported and used as the target.
            io_loop: loop for the IOStream; defaults to the global
                ``IOLoop.instance()``.
        """
        Application.__init__(self)
        if isinstance(target, str):
            self.target = self._load_module(target)
        else:
            self.target = target
        if io_loop is None:
            # Call instance() on the class: instantiating IOLoop just to
            # reach the singleton would create a throwaway loop object.
            self.io_loop = IOLoop.instance()
        else:
            self.io_loop = io_loop
        self._process = None
        self.socket = None
        self.runner = None

    def start(self, config):
        """Spawn the subprocess and begin reading messages from it."""
        signal.signal(signal.SIGCHLD, self._sigchld)
        self.socket, child = socket.socketpair()
        self.runner = _Subprocess(self.target, child, config)
        self._process = multiprocessing.Process(target=self.runner.run)
        self._process.start()
        # The parent's copy of the child's end is no longer needed here.
        child.close()
        self.ios = IOStream(self.socket, self.io_loop)
        # NOTE(review): only this single read is requested in the visible
        # code and _receiver does not request another one -- confirm that
        # further messages are not expected or are re-armed elsewhere.
        self.ios.read_until('\r\n', self._receiver)

    def _close(self, timeout):
        """Wait up to `timeout` seconds for the subprocess to exit."""
        self._process.join(timeout)

    def _sigchld(self, signum, frame):
        """SIGCHLD handler: reap the subprocess."""
        self._close(0.5)

    def _receiver(self, data):
        """Receive data from subprocess. Forward to session."""
        msg = json.loads(binascii.a2b_base64(data))
        # NOTE(review): ``self.session`` is not assigned in this class;
        # presumably it is provided by Application -- confirm.
        if self.session:
            self.session.send(msg)
        else:
            logging.error("from app: %s", str(msg))

    def stop(self):
        """Terminate the subprocess and wait up to two seconds for it."""
        self._process.terminate()
        self._close(2.0)

    def send(self, data):
        """Send data to application."""
        self.ios.write(data + '\r\n')

    def received(self, data):
        """Handle data from session. Forward to subprocess for handling."""
        self.send(binascii.b2a_base64(json.dumps(data)))
        return True

    def _load_module(self, modulename):
        """Import and return the module named `modulename`."""
        import importlib
        return importlib.import_module(modulename)
Ejemplo n.º 25
0
    def _on_resolve(self, addrinfo):
        """Open a stream to the first resolved address and start connecting.

        ``https`` URLs get an ``SSLIOStream`` configured from the request's
        certificate settings; everything else gets a plain ``IOStream``.
        A timeout is scheduled when the smaller of the request's connect and
        request timeouts is truthy.

        Args:
            addrinfo: sequence of ``(address_family, sockaddr)`` pairs; only
                the first entry is used.
        """
        af, sockaddr = addrinfo[0]
        request = self.request

        if self.parsed.scheme != "https":
            self.stream = IOStream(socket.socket(af),
                                   io_loop=self.io_loop,
                                   max_buffer_size=self.max_buffer_size)
        else:
            ssl_options = {}
            if request.validate_cert:
                ssl_options["cert_reqs"] = ssl.CERT_REQUIRED
            if request.ca_certs is None:
                ssl_options["ca_certs"] = _DEFAULT_CA_CERTS
            else:
                ssl_options["ca_certs"] = request.ca_certs
            if request.client_key is not None:
                ssl_options["keyfile"] = request.client_key
            if request.client_cert is not None:
                ssl_options["certfile"] = request.client_cert

            # SSL interoperability is tricky.  SSLv2 should be disabled for
            # security reasons, but openssl only disabled it by default in
            # 1.0.  SSL_OP_NO_SSLv2 would be the best tool, yet python only
            # exposes it from 3.2 on.  Python 2.7 added the ciphers
            # argument, which can disable SSLv2 as well.  On python 2.6 the
            # last resort is forcing ssl_version to SSLv3; that is narrower
            # than desired because it breaks servers configured for TLSv1
            # only, but nearly all servers support SSLv3:
            # http://blog.ivanristic.com/2011/09/ssl-survey-protocol-support.html
            if sys.version_info < (2, 7):
                # Only really needed for pre-1.0 openssl, but python 2.6
                # does not expose openssl version information.
                ssl_options["ssl_version"] = ssl.PROTOCOL_SSLv3
            else:
                ssl_options["ciphers"] = "DEFAULT:!SSLv2"

            self.stream = SSLIOStream(socket.socket(af),
                                      io_loop=self.io_loop,
                                      ssl_options=ssl_options,
                                      max_buffer_size=self.max_buffer_size)

        timeout = min(request.connect_timeout, request.request_timeout)
        if timeout:
            deadline = self.start_time + timeout
            self._timeout = self.io_loop.add_timeout(
                deadline, stack_context.wrap(self._on_timeout))

        # The close callback (self._on_close) is deliberately NOT registered:
        # the stream is closed from close(), and the callback would turn
        # that into a 599 error.

        # ipv6 addresses are broken (in self.parsed.hostname) until
        # 2.7; self.parsed_hostname is the correctly parsed value
        # calculated in __init__.
        self.stream.connect(sockaddr, self._on_connect,
                            server_hostname=self.parsed_hostname)
Ejemplo n.º 26
0
class Connection(object):
    """Blocking-connect TCP connection that is then driven by an IOStream.

    Besides the stream it keeps two callback queues: ``read_callbacks``
    (reads in flight) and ``ready_callbacks`` (commands waiting until all
    reads have completed).

    Args:
        host: server host name or address.
        port: server TCP port.
        weak_event_handler: optional event handler object; stored as is.
        stop_after: socket timeout applied before the blocking connect
            (stored as ``self.timeout``).
        io_loop: loop handed to the ``IOStream``.
    """

    def __init__(self, host='localhost', port=6379, weak_event_handler=None,
                 stop_after=None, io_loop=None):
        self.host = host
        self.port = port
        if weak_event_handler:
            self._event_handler = weak_event_handler
        else:
            self._event_handler = None
        self.timeout = stop_after
        self._stream = None
        self._io_loop = io_loop

        self.in_progress = False
        self.read_callbacks = []
        self.ready_callbacks = deque()
        self._lock = 0
        self.info = {'db': 0}

    def __del__(self):
        self.disconnect()

    def execute_pending_command(self):
        """Run the next queued command once no reads are outstanding."""
        # Continue with the pending command execution
        # if all read operations are completed.
        if not self.read_callbacks and self.ready_callbacks:
            # Pop a SINGLE callback from the queue and execute it.
            # The next one will be executed from the code
            # invoked by the callback
            callback = self.ready_callbacks.popleft()
            callback()

    def ready(self):
        """Return True when no read and no queued command is pending."""
        return (not self.read_callbacks and
                not self.ready_callbacks)

    def wait_until_ready(self, callback=None):
        """Call `callback` now if the connection is ready, else queue it.

        Does nothing when `callback` is falsy.
        """
        if callback:
            if not self.ready():
                self.ready_callbacks.append(callback)
            else:
                callback()

    def connect(self):
        """Connect synchronously and wrap the socket in an IOStream.

        Does nothing when a stream already exists.  After a successful
        connect, ``self.info['db']`` is reset to 0 and the ``on_connect``
        event is fired.

        Raises:
            ConnectionError: when any socket-level error occurs.
        """
        if not self._stream:
            sock = None
            try:
                sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
                sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
                sock.settimeout(self.timeout)
                sock.connect((self.host, self.port))
                self._stream = IOStream(sock, io_loop=self._io_loop)
                self._stream.set_close_callback(self.on_stream_close)
                self.info['db'] = 0
            except socket.error as e:
                # ``except X as e`` is valid on Python 2.6+ and 3.x; the old
                # comma form is a syntax error on Python 3.
                if sock is not None and not self._stream:
                    # No stream took ownership of the socket, so release the
                    # file descriptor instead of leaking it.
                    sock.close()
                raise ConnectionError(str(e))
            self.fire_event('on_connect')
Ejemplo n.º 27
0
    def test_read_until_close(self):
        # Send an HTTP/1.0 request over a pre-connected socket and collect
        # everything that arrives until the connection is closed.
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        server_address = ("localhost", self.get_http_port())
        sock.connect(server_address)
        stream = IOStream(sock, io_loop=self.io_loop)
        request = b("GET / HTTP/1.0\r\n\r\n")
        stream.write(request)

        stream.read_until_close(self.stop)
        response = self.wait()
        self.assertTrue(response.startswith(b("HTTP/1.0 200")))
        self.assertTrue(response.endswith(b("Hello")))
Ejemplo n.º 28
0
 def _create_stream(self, max_buffer_size, af, addr):
     """Open a plaintext IOStream bound to ``config.bind_ip`` and connect it.

     The connection is always made in plaintext; conversion to ssl, if
     necessary, happens after one connection has completed.

     Args:
         max_buffer_size: forwarded to ``IOStream``.
         af: address family for the new socket.
         addr: address handed to ``IOStream.connect``.

     Returns:
         Whatever ``IOStream.connect(addr)`` returns.

     Raises:
         Whatever ``socket.bind`` raises when the local address cannot be
         used; the socket is closed before the error propagates.
     """
     s = socket.socket(af)
     log.debug("connect:%s" % af)
     try:
         # Port 0 lets the OS pick the local port.
         s.bind((config.bind_ip, 0))
     except Exception:
         # Nothing owns the socket yet, so release it before propagating
         # the error instead of leaking the file descriptor.
         s.close()
         raise
     stream = IOStream(s,
                       io_loop=self.io_loop,
                       max_buffer_size=max_buffer_size)
     return stream.connect(addr)
Ejemplo n.º 29
0
def new_stream(ip, port, callback=None):
    """ Create an IOStream over a new TCP socket, start connecting it to
        ``(ip, port)`` and return it.

        `callback` is passed through to ``IOStream.connect``.  Intended for
        long-lived connections; for async connection see
        `async_stream_task`.
    """
    # TODO: handle exception on IOStream.connect()
    stream = IOStream(socket.socket(socket.AF_INET, socket.SOCK_STREAM))
    address = (ip, port)
    stream.connect(address, callback=callback)
    return stream
Ejemplo n.º 30
0
class ForwardConnection(object):
    """Pipe data between an existing stream and a newly connected remote one.

    The constructor opens an ``IOStream`` to ``remote_address``.  Once it is
    connected, each stream is read until close with the other stream's
    ``write`` as streaming callback.
    """

    def __init__(self, remote_address, stream, address):
        self.remote_address = remote_address
        self.stream = stream
        self.address = address
        self.remote_stream = IOStream(socket.socket())
        self.remote_stream.connect(self.remote_address, self._on_remote_connected)

    @staticmethod
    def _write_then_close(target, data):
        # While the target still has a write in progress, queue the final
        # data and close from the write callback; otherwise close right away.
        if not target.writing():
            target.close()
            return
        target.write(data, target.close)

    def _on_remote_connected(self):
        logging.info("forward %r to %r", self.address, self.remote_address)
        # remote -> local
        self.remote_stream.read_until_close(self._on_remote_read_close, self.stream.write)
        # local -> remote
        self.stream.read_until_close(self._on_read_close, self.remote_stream.write)

    def _on_remote_read_close(self, data):
        # The remote side finished: shut down the local stream.
        self._write_then_close(self.stream, data)

    def _on_read_close(self, data):
        # The local side finished: shut down the remote stream.
        self._write_then_close(self.remote_stream, data)
Ejemplo n.º 31
0
    def __init__(self, io_loop, client, request, callback):
        # Records the request state, resolves the target address from the
        # request URL and starts connecting a plain or SSL stream to it.
        # Everything after the attribute setup runs inside the cleanup
        # stack context.
        self.start_time = time.time()
        self.io_loop = io_loop
        self.client = client
        self.request = request
        self.callback = callback
        self.code = None
        self.headers = None
        self.chunks = None
        self._decompressor = None
        # Timeout handle returned by IOLoop.add_timeout
        self._timeout = None
        with stack_context.StackContext(self.cleanup):
            parsed = urlparse.urlsplit(_unicode(self.request.url))
            is_https = parsed.scheme == "https"
            if is_https and ssl is None:
                raise ValueError("HTTPS requires either python2.6+ or "
                                 "curl_httpclient")
            if parsed.scheme not in ("http", "https"):
                raise ValueError("Unsupported url scheme: %s" %
                                 self.request.url)
            # urlsplit results have hostname and port results, but they
            # didn't support ipv6 literals until python 2.7.
            netloc = parsed.netloc
            if "@" in netloc:
                # Drop the "user:password@" prefix.
                netloc = netloc.rpartition("@")[2]
            host_port = re.match(r'^(.+):(\d+)$', netloc)
            if host_port:
                host, port = host_port.group(1), int(host_port.group(2))
            else:
                host, port = netloc, (443 if is_https else 80)
            if re.match(r'^\[.*\]$', host):
                # raw ipv6 addresses in urls are enclosed in brackets
                host = host[1:-1]
            mapping = self.client.hostname_mapping
            if mapping is not None:
                host = mapping.get(host, host)

            # We only try the first IP we get from getaddrinfo,
            # so restrict to ipv4 unless ipv6 was explicitly allowed.
            af = socket.AF_UNSPEC if request.allow_ipv6 else socket.AF_INET

            addrinfo = socket.getaddrinfo(host, port, af, socket.SOCK_STREAM,
                                          0, 0)
            af, socktype, proto, _canonname, sockaddr = addrinfo[0]

            if is_https:
                ssl_options = {}
                if request.validate_cert:
                    ssl_options["cert_reqs"] = ssl.CERT_REQUIRED
                ssl_options["ca_certs"] = (
                    request.ca_certs if request.ca_certs is not None
                    else _DEFAULT_CA_CERTS)
                for option, value in (("keyfile", request.client_key),
                                      ("certfile", request.client_cert)):
                    if value is not None:
                        ssl_options[option] = value
                sock = socket.socket(af, socktype, proto)
                self.stream = SSLIOStream(sock,
                                          io_loop=self.io_loop,
                                          ssl_options=ssl_options)
            else:
                sock = socket.socket(af, socktype, proto)
                self.stream = IOStream(sock, io_loop=self.io_loop)
            timeout = min(request.connect_timeout, request.request_timeout)
            if timeout:
                # NOTE(review): the handle is kept in _connect_timeout while
                # _timeout stays None -- confirm the connect handler cancels
                # the attribute that is actually set here.
                deadline = self.start_time + timeout
                self._connect_timeout = self.io_loop.add_timeout(
                    deadline, self._on_timeout)
            self.stream.set_close_callback(self._on_close)
            on_connect = functools.partial(self._on_connect, parsed)
            self.stream.connect(sockaddr, on_connect)
Ejemplo n.º 32
0
 def _make_client_iostream(self, connection, **kwargs):
     # Wrap the given connection in a plain IOStream, forwarding any
     # keyword arguments unchanged.
     stream = IOStream(connection, **kwargs)
     return stream
Ejemplo n.º 33
0
 def makeSocket(self, timeout=1):
     """Create a TCP socket, start connecting it through an IOStream, return it.

     The stream is stored in ``self._stream``; the raw socket is returned.
     ``timeout`` is accepted but not used in this method.
     """
     sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
     self._stream = IOStream(sock)
     destination = (self.host, self.port)
     self._stream.connect(destination)
     return sock
Ejemplo n.º 34
0
 def accept_callback(conn, address):
     # Wrap the accepted connection and, once a blank line terminating the
     # headers has been read, hand the stream to write_response.
     stream = IOStream(conn, io_loop=self.io_loop)
     on_headers = functools.partial(write_response, stream)
     stream.read_until(b"\r\n\r\n", on_headers)
Ejemplo n.º 35
0
class UnixSocketTest(AsyncTestCase):
    """HTTPServers can listen on Unix sockets too.

    Why would you want to do this?  Nginx can proxy to backends listening
    on unix sockets, for one thing (and managing a namespace for unix
    sockets can be easier than managing a bunch of TCP port numbers).

    Unfortunately, there's no way to specify a unix socket in a url for
    an HTTP client, so we have to test this by hand.
    """
    def setUp(self):
        super(UnixSocketTest, self).setUp()
        # The server listens on a socket file inside a throwaway directory.
        self.tmpdir = tempfile.mkdtemp()
        self.sockfile = os.path.join(self.tmpdir, "test.sock")
        listener = netutil.bind_unix_socket(self.sockfile)
        routes = [("/hello", HelloWorldRequestHandler)]
        self.server = HTTPServer(Application(routes))
        self.server.add_socket(listener)
        # Raw client stream connected to the same socket file.
        client_socket = socket.socket(socket.AF_UNIX)
        self.stream = IOStream(client_socket)
        self.stream.connect(self.sockfile, self.stop)
        self.wait()

    def tearDown(self):
        self.stream.close()
        self.io_loop.run_sync(self.server.close_all_connections)
        self.server.stop()
        shutil.rmtree(self.tmpdir)
        super(UnixSocketTest, self).tearDown()

    def test_unix_socket(self):
        self.stream.write(b"GET /hello HTTP/1.0\r\n\r\n")
        # Status line first ...
        self.stream.read_until(b"\r\n", self.stop)
        status_line = self.wait()
        self.assertEqual(status_line, b"HTTP/1.1 200 OK\r\n")
        # ... then the header block ...
        self.stream.read_until(b"\r\n\r\n", self.stop)
        raw_headers = self.wait().decode('latin1')
        headers = HTTPHeaders.parse(raw_headers)
        # ... and finally exactly Content-Length bytes of body.
        body_size = int(headers["Content-Length"])
        self.stream.read_bytes(body_size, self.stop)
        self.assertEqual(self.wait(), b"Hello world")

    def test_unix_socket_bad_request(self):
        # Unix sockets don't have remote addresses so they just return an
        # empty string.
        with ExpectLog(gen_log, "Malformed HTTP message from"):
            self.stream.write(b"garbage\r\n\r\n")
            self.stream.read_until_close(self.stop)
            reply = self.wait()
        self.assertEqual(reply, b"HTTP/1.1 400 Bad Request\r\n\r\n")
 def __init__(self, target_ip, target_port, message):
     # Remember the message and start connecting to the target;
     # __sending_action is registered as the connect callback.
     self.__message = message
     sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
     self.__stream = IOStream(sock)
     target = (target_ip, target_port)
     self.__stream.connect(target, self.__sending_action)
Ejemplo n.º 37
0
class Client(RedisCommandsMixin):
    """
        Redis client class

        Commands are written to a tornado ``IOStream``; replies are parsed
        with ``hiredis.Reader`` and handed to the queued callbacks in the
        order the commands were sent, or to the single subscription
        callback while one is registered.
    """
    def __init__(self, io_loop=None):
        """
            Constructor

            :param io_loop:
                Optional IOLoop instance
        """
        self._io_loop = io_loop or IOLoop.instance()

        self._stream = None

        self.reader = None
        # One (callback, pipeline_state) entry per request in flight;
        # pipeline_state is None for a single command and
        # (expected_count, collected_replies) for a pipeline.
        self.callbacks = deque()

        # None means "no subscription callback registered".  It has to be
        # None (not False) because send_message, send_messages,
        # _set_sub_callback and _reset all test against None.
        self._sub_callback = None

    def connect(self, host='localhost', port=6379, callback=None):
        """
            Connect to redis server

            :param host:
                Host to connect to
            :param port:
                Port
            :param callback:
                Optional callback to be triggered upon connection
        """
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        return self._connect(sock, (host, port), callback)

    def connect_usocket(self, usock, callback=None):
        """
            Connect to redis server with unix socket

            :param usock:
                Address passed to the stream's connect (the unix socket path)
            :param callback:
                Optional callback to be triggered upon connection
        """
        sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM, 0)
        return self._connect(sock, usock, callback)

    def on_disconnect(self):
        """
            Override this method if you want to handle disconnections
        """
        pass

    # State
    def is_idle(self):
        """
            Check if client is not waiting for any responses
        """
        return len(self.callbacks) == 0

    def is_connected(self):
        """
            Check if client is still connected
        """
        return bool(self._stream) and not self._stream.closed()

    def send_message(self, args, callback=None):
        """
            Send command to redis

            :param args:
                Arguments to send
            :param callback:
                Callback

            Raises ValueError when a subscription callback is registered
            and the command is not one of the (un)subscribe commands or QUIT.
        """
        # Special case for pub-sub
        cmd = args[0]

        if (self._sub_callback is not None and
            cmd not in ('PSUBSCRIBE', 'SUBSCRIBE', 'PUNSUBSCRIBE', 'UNSUBSCRIBE', 'QUIT')):
            raise ValueError('Cannot run normal command over PUBSUB connection')

        # Send command
        self._stream.write(self.format_message(args))
        if callback is not None:
            callback = stack_context.wrap(callback)
        # An entry is queued even without a callback so replies stay aligned
        # with the requests that produced them.
        self.callbacks.append((callback, None))

    def send_messages(self, args_pipeline, callback=None):
        """
            Send command pipeline to redis

            :param args_pipeline:
                Arguments pipeline to send
            :param callback:
                Callback

            Raises ValueError when a subscription callback is registered.
        """

        if self._sub_callback is not None:
            raise ValueError('Cannot run pipeline over PUBSUB connection')

        # Send command pipeline
        messages = [self.format_message(args) for args in args_pipeline]
        self._stream.write(b"".join(messages))
        if callback is not None:
            callback = stack_context.wrap(callback)
        # The list collects the replies until len(messages) have arrived.
        self.callbacks.append((callback, (len(messages), [])))

    def format_message(self, args):
        """
            Create redis message

            :param args:
                Message data

            Returns the bytes of a ``*<count>`` header followed by a
            ``$<length>`` line and the value for every argument, each
            terminated by CRLF.  Non-string arguments are converted with
            ``str``; text is encoded as UTF-8.
        """
        header = "*%d" % len(args)
        lines = [header.encode('utf-8')]
        for arg in args:
            if not isinstance(arg, (string_types, bytes,)):
                arg = str(arg)
            if isinstance(arg, text_type):
                arg = arg.encode('utf-8')
            # The length is taken after encoding, so it counts bytes.
            size = "$%d" % len(arg)
            lines.append(size.encode('utf-8'))
            lines.append(arg)
        # Trailing empty element makes the join end with CRLF.
        lines.append(b"")
        return b"\r\n".join(lines)

    def close(self):
        """
            Close redis connection
        """
        self.quit()
        self._stream.close()

    # Pub/sub commands
    def psubscribe(self, patterns, callback=None):
        """
            Customized psubscribe command - will keep one callback for all incoming messages

            :param patterns:
                string or list of strings
            :param callback:
                callback
        """
        self._set_sub_callback(callback)
        super(Client, self).psubscribe(patterns)

    def subscribe(self, channels, callback=None):
        """
            Customized subscribe command - will keep one callback for all incoming messages

            :param channels:
                string or list of strings
            :param callback:
                Callback
        """
        self._set_sub_callback(callback)
        super(Client, self).subscribe(channels)

    def _set_sub_callback(self, callback):
        """
            Register the subscription callback; only one per connection
        """
        if self._sub_callback is None:
            self._sub_callback = callback

        # A later (p)subscribe has to pass the same callback again.
        assert self._sub_callback == callback

    # Helpers
    def _connect(self, sock, addr, callback):
        """
            Reset parser state, wrap ``sock`` in an IOStream and start
            connecting and reading
        """
        self._reset()

        self._stream = IOStream(sock, io_loop=self._io_loop)
        self._stream.connect(addr, callback=callback)
        # _on_read receives data as it arrives, _on_close runs at close.
        self._stream.read_until_close(self._on_close, self._on_read)

    # Event handlers
    def _on_read(self, data):
        """
            Feed received bytes to the parser and dispatch complete replies
        """
        self.reader.feed(data)

        resp = self.reader.gets()

        # reader.gets() returns False while no complete reply is buffered.
        while resp is not False:
            if self._sub_callback:
                try:
                    self._sub_callback(resp)
                except Exception:
                    logger.exception('SUB callback failed')
            else:
                if self.callbacks:
                    callback, callback_data = self.callbacks[0]
                    if callback_data is None:
                        callback_resp = resp
                    else:
                        # handle pipeline responses
                        num_resp, callback_resp = callback_data
                        callback_resp.append(resp)
                        while len(callback_resp) < num_resp:
                            resp = self.reader.gets()
                            if resp is False:
                                # callback_resp is yet incomplete; the entry
                                # stays queued and keeps the partial list
                                return
                            callback_resp.append(resp)
                    self.callbacks.popleft()
                    if callback is not None:
                        try:
                            callback(callback_resp)
                        except Exception:
                            logger.exception('Callback failed')
                else:
                    logger.debug('Ignored response: %s' % repr(resp))

            resp = self.reader.gets()

    def _on_close(self, data=None):
        """
            Handle stream close: process trailing data, call every pending
            callback and the subscription callback with None, then call
            on_disconnect
        """
        if data is not None:
            self._on_read(data)

        # Trigger any pending callbacks
        callbacks = self.callbacks
        self.callbacks = deque()

        for callback, _callback_data in callbacks:
            if callback is not None:
                try:
                    callback(None)
                except Exception:
                    logger.exception('Exception in callback')

        if self._sub_callback is not None:
            try:
                self._sub_callback(None)
            except Exception:
                logger.exception('Exception in SUB callback')
            self._sub_callback = None

        # Trigger on_disconnect
        self.on_disconnect()

    def _reset(self):
        """
            Start with a fresh reply parser and no subscription callback
        """
        self.reader = hiredis.Reader()
        self._sub_callback = None

    def pipeline(self):
        """
            Return a Pipeline bound to this client
        """
        return Pipeline(self)
Ejemplo n.º 38
0
 def test_connection_close(self):
     # Send an HTTP/1.0 request over a raw connection, then block in
     # wait() until self.stop is called.
     sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
     sock.connect(("localhost", self.get_http_port()))
     self.stream = IOStream(sock, io_loop=self.io_loop)
     request = b("GET / HTTP/1.0\r\n\r\n")
     self.stream.write(request)
     self.wait()
Ejemplo n.º 39
0
class TestIOStreamStartTLS(AsyncTestCase):
    # Tests for IOStream.start_tls using a connected client/server stream
    # pair over a local listener.

    def setUp(self):
        try:
            super(TestIOStreamStartTLS, self).setUp()
            self.listener, self.port = bind_unused_port()
            self.server_stream = None
            self.server_accepted = Future()  # type: Future[None]
            netutil.add_accept_handler(self.listener, self.accept)
            self.client_stream = IOStream(socket.socket())
            # Wait for the client connect first, then for the accept side.
            connected = self.client_stream.connect(('127.0.0.1', self.port))
            self.io_loop.add_future(connected, self.stop)
            self.wait()
            self.io_loop.add_future(self.server_accepted, self.stop)
            self.wait()
        except Exception as exc:
            print(exc)
            raise

    def tearDown(self):
        # Either stream may be None if a test took it over for the handshake.
        for stream in (self.server_stream, self.client_stream):
            if stream is not None:
                stream.close()
        self.listener.close()
        super(TestIOStreamStartTLS, self).tearDown()

    def accept(self, connection, address):
        if self.server_stream is not None:
            self.fail("should only get one connection")
        self.server_stream = IOStream(connection)
        self.server_accepted.set_result(None)

    @gen.coroutine
    def client_send_line(self, line):
        # Client writes a line; the server must read the same bytes back.
        self.client_stream.write(line)
        received = yield self.server_stream.read_until(b"\r\n")
        self.assertEqual(line, received)

    @gen.coroutine
    def server_send_line(self, line):
        # Server writes a line; the client must read the same bytes back.
        self.server_stream.write(line)
        received = yield self.client_stream.read_until(b"\r\n")
        self.assertEqual(line, received)

    def client_start_tls(self, ssl_options=None, server_hostname=None):
        # Detach the plain stream from the test case before upgrading it.
        stream, self.client_stream = self.client_stream, None
        return stream.start_tls(False, ssl_options, server_hostname)

    def server_start_tls(self, ssl_options=None):
        # Detach the plain stream from the test case before upgrading it.
        stream, self.server_stream = self.server_stream, None
        return stream.start_tls(True, ssl_options)

    @gen_test
    def test_start_tls_smtp(self):
        # This flow is simplified from RFC 3207 section 5.
        # We don't really need all of this, but it helps to make sure
        # that after realistic back-and-forth traffic the buffers end up
        # in a sane state.
        plaintext_dialogue = [
            (self.server_send_line, b"220 mail.example.com ready\r\n"),
            (self.client_send_line, b"EHLO mail.example.com\r\n"),
            (self.server_send_line, b"250-mail.example.com welcome\r\n"),
            (self.server_send_line, b"250 STARTTLS\r\n"),
            (self.client_send_line, b"STARTTLS\r\n"),
            (self.server_send_line, b"220 Go ahead\r\n"),
        ]
        for send_line, line in plaintext_dialogue:
            yield send_line(line)
        client_future = self.client_start_tls(dict(cert_reqs=ssl.CERT_NONE))
        server_future = self.server_start_tls(_server_ssl_options())
        self.client_stream = yield client_future
        self.server_stream = yield server_future
        self.assertTrue(isinstance(self.client_stream, SSLIOStream))
        self.assertTrue(isinstance(self.server_stream, SSLIOStream))
        # The upgraded streams must still carry traffic in both directions.
        yield self.client_send_line(b"EHLO mail.example.com\r\n")
        yield self.server_send_line(b"250 mail.example.com welcome\r\n")

    @gen_test
    def test_handshake_fail(self):
        server_future = self.server_start_tls(_server_ssl_options())
        # Certificates are verified with the default configuration.
        client_future = self.client_start_tls(server_hostname="localhost")
        with ExpectLog(gen_log, "SSL Error"), self.assertRaises(ssl.SSLError):
            yield client_future
        with self.assertRaises((ssl.SSLError, socket.error)):
            yield server_future

    @gen_test
    def test_check_hostname(self):
        # Test that server_hostname parameter to start_tls is being used.
        # The check_hostname functionality is only available in python 2.7 and
        # up and in python 3.4 and up.
        server_future = self.server_start_tls(_server_ssl_options())
        client_context = ssl.create_default_context()
        client_future = self.client_start_tls(client_context,
                                              server_hostname='127.0.0.1')
        with ExpectLog(gen_log, "SSL Error"), self.assertRaises(ssl.SSLError):
            # The client fails to connect with an SSL error.
            yield client_future
        with self.assertRaises(Exception):
            # The server fails to connect, but the exact error is unspecified.
            yield server_future
Ejemplo n.º 40
0
 def accept_callback(conn, address):
     # fake an HTTP server using chunked encoding where the final chunks
     # and connection close all happen at once
     stream = IOStream(conn, io_loop=self.io_loop)
     # Once the request headers are in, write_response gets the stream.
     respond = functools.partial(write_response, stream)
     stream.read_until(b"\r\n\r\n", respond)
Ejemplo n.º 41
0
 def accept(self, connection, address):
     # Accept handler: wrap the single expected connection in an IOStream
     # and resolve the server_accepted future.
     already_accepted = self.server_stream is not None
     if already_accepted:
         self.fail("should only get one connection")
     self.server_stream = IOStream(connection)
     self.server_accepted.set_result(None)
Ejemplo n.º 42
0
 def test_100_continue(self):
     # Run through a 100-continue interaction by hand:
     # When given Expect: 100-continue, we get a 100 response after the
     # headers, and then the real response after the body.
     stream = IOStream(socket.socket())
     stream.connect(("127.0.0.1", self.get_http_port()), callback=self.stop)
     self.wait()
     request_head = b"\r\n".join([
         b"POST /hello HTTP/1.1",
         b"Content-Length: 1024",
         b"Expect: 100-continue",
         b"Connection: close",
         b"\r\n",
     ])
     stream.write(request_head, callback=self.stop)
     self.wait()
     # Interim response, sent before any body bytes.
     stream.read_until(b"\r\n\r\n", self.stop)
     interim = self.wait()
     self.assertTrue(interim.startswith(b"HTTP/1.1 100 "), interim)
     # Now send the body and read the real response piece by piece.
     stream.write(b"a" * 1024)
     stream.read_until(b"\r\n", self.stop)
     status_line = self.wait()
     self.assertTrue(status_line.startswith(b"HTTP/1.1 200"), status_line)
     stream.read_until(b"\r\n\r\n", self.stop)
     raw_headers = self.wait()
     headers = HTTPHeaders.parse(native_str(raw_headers.decode('latin1')))
     body_size = int(headers["Content-Length"])
     stream.read_bytes(body_size, self.stop)
     self.assertEqual(self.wait(), b"Got 1024 bytes in POST")
     stream.close()
Ejemplo n.º 43
0
 def connect(self):
     # Open a new stream to the test HTTP server, wait until it is
     # connected, remember it in self.streams and return it.
     server_address = ('127.0.0.1', self.get_http_port())
     stream = IOStream(socket.socket())
     stream.connect(server_address, self.stop)
     self.wait()
     self.streams.append(stream)
     return stream
Ejemplo n.º 44
0
 def test_body_size_override_reset(self):
     # The max_body_size override is reset between requests.
     stream = IOStream(socket.socket())
     try:
         yield stream.connect(('127.0.0.1', self.get_http_port()))
         # Use a raw stream so we can make sure it's all on one connection.
         payload = b'a' * 10240
         stream.write(b'PUT /streaming?expected_size=10240 HTTP/1.1\r\n'
                      b'Content-Length: 10240\r\n\r\n')
         stream.write(payload)
         _start_line, _headers, body = yield gen.Task(
             read_stream_body, stream)
         self.assertEqual(body, b'10240')
         # Without the ?expected_size parameter, we get the old default value
         stream.write(b'PUT /streaming HTTP/1.1\r\n'
                      b'Content-Length: 10240\r\n\r\n')
         with ExpectLog(gen_log, '.*Content-Length too long'):
             rejection = yield stream.read_until_close()
         self.assertEqual(rejection, b'HTTP/1.1 400 Bad Request\r\n\r\n')
     finally:
         stream.close()
Ejemplo n.º 45
0
class _HTTPConnection(object):
    """One HTTP request over its own ``IOStream``.

    The constructor resolves the target from ``request.url``, opens a plain
    or SSL stream and starts connecting; ``_on_connect`` then writes the
    request and starts reading the response headers.  Failures raised inside
    the ``cleanup`` stack context are reported to ``callback`` as a 599
    ``HTTPResponse``.
    """
    _SUPPORTED_METHODS = set(["GET", "HEAD", "POST", "PUT", "DELETE"])

    def __init__(self, io_loop, client, request, callback):
        """Start the request.

        :param io_loop: IOLoop used for the stream and for timeouts.
        :param client: owner; its ``hostname_mapping`` is consulted.
        :param request: the request to send.
        :param callback: called with the ``HTTPResponse``.
        """
        self.start_time = time.time()
        self.io_loop = io_loop
        self.client = client
        self.request = request
        self.callback = callback
        self.code = None
        self.headers = None
        self.chunks = None
        self._decompressor = None
        # Timeout handle returned by IOLoop.add_timeout
        self._timeout = None
        with stack_context.StackContext(self.cleanup):
            parsed = urlparse.urlsplit(_unicode(self.request.url))
            if ssl is None and parsed.scheme == "https":
                raise ValueError("HTTPS requires either python2.6+ or "
                                 "curl_httpclient")
            if parsed.scheme not in ("http", "https"):
                raise ValueError("Unsupported url scheme: %s" %
                                 self.request.url)
            # urlsplit results have hostname and port results, but they
            # didn't support ipv6 literals until python 2.7.
            netloc = parsed.netloc
            if "@" in netloc:
                userpass, _, netloc = netloc.rpartition("@")
            match = re.match(r'^(.+):(\d+)$', netloc)
            if match:
                host = match.group(1)
                port = int(match.group(2))
            else:
                host = netloc
                port = 443 if parsed.scheme == "https" else 80
            if re.match(r'^\[.*\]$', host):
                # raw ipv6 addresses in urls are enclosed in brackets
                host = host[1:-1]
            if self.client.hostname_mapping is not None:
                host = self.client.hostname_mapping.get(host, host)

            if request.allow_ipv6:
                af = socket.AF_UNSPEC
            else:
                # We only try the first IP we get from getaddrinfo,
                # so restrict to ipv4 by default.
                af = socket.AF_INET

            addrinfo = socket.getaddrinfo(host, port, af, socket.SOCK_STREAM,
                                          0, 0)
            af, socktype, proto, canonname, sockaddr = addrinfo[0]

            if parsed.scheme == "https":
                ssl_options = {}
                if request.validate_cert:
                    ssl_options["cert_reqs"] = ssl.CERT_REQUIRED
                if request.ca_certs is not None:
                    ssl_options["ca_certs"] = request.ca_certs
                else:
                    ssl_options["ca_certs"] = _DEFAULT_CA_CERTS
                if request.client_key is not None:
                    ssl_options["keyfile"] = request.client_key
                if request.client_cert is not None:
                    ssl_options["certfile"] = request.client_cert
                self.stream = SSLIOStream(socket.socket(af, socktype, proto),
                                          io_loop=self.io_loop,
                                          ssl_options=ssl_options)
            else:
                self.stream = IOStream(socket.socket(af, socktype, proto),
                                       io_loop=self.io_loop)
            timeout = min(request.connect_timeout, request.request_timeout)
            if timeout:
                # Keep the handle in self._timeout: that is the attribute
                # _on_connect cancels.  Storing it anywhere else leaves the
                # connect timeout armed after the connection succeeds.
                self._timeout = self.io_loop.add_timeout(
                    self.start_time + timeout, self._on_timeout)
            self.stream.set_close_callback(self._on_close)
            self.stream.connect(sockaddr,
                                functools.partial(self._on_connect, parsed))

    def _on_timeout(self):
        """Report a 599 "Timeout" response (once) and close the stream."""
        self._timeout = None
        if self.callback is not None:
            self.callback(
                HTTPResponse(self.request,
                             599,
                             headers=self.headers,
                             error=HTTPError(599, "Timeout")))
            self.callback = None
        self.stream.close()

    def _on_connect(self, parsed):
        """Send the request once the stream is connected.

        Replaces the connect timeout by the request timeout, validates the
        request, fills in default headers, writes the request line, headers
        and body, and starts reading the response headers.

        :param parsed: the ``urlsplit`` result for the request URL.
        :raises KeyError: for a method outside ``_SUPPORTED_METHODS`` unless
            ``allow_nonstandard_methods`` is set.
        :raises NotImplementedError: when a proxy or network-interface
            option is set on the request.
        :raises ValueError: when a header line contains a newline.
        """
        if self._timeout is not None:
            # Handles returned by add_timeout are cancelled with
            # remove_timeout.
            self.io_loop.remove_timeout(self._timeout)
            self._timeout = None
        if self.request.request_timeout:
            self._timeout = self.io_loop.add_timeout(
                self.start_time + self.request.request_timeout,
                self._on_timeout)
        if (self.request.validate_cert
                and isinstance(self.stream, SSLIOStream)):
            match_hostname(self.stream.socket.getpeercert(), parsed.hostname)
        if (self.request.method not in self._SUPPORTED_METHODS
                and not self.request.allow_nonstandard_methods):
            raise KeyError("unknown method %s" % self.request.method)
        for key in ('network_interface', 'proxy_host', 'proxy_port',
                    'proxy_username', 'proxy_password'):
            if getattr(self.request, key, None):
                raise NotImplementedError('%s not supported' % key)
        if "Host" not in self.request.headers:
            self.request.headers["Host"] = parsed.netloc
        # Credentials embedded in the URL take precedence over the
        # auth_username/auth_password request options.
        username, password = None, None
        if parsed.username is not None:
            username, password = parsed.username, parsed.password
        elif self.request.auth_username is not None:
            username = self.request.auth_username
            password = self.request.auth_password
        if username is not None:
            auth = utf8(username) + b(":") + utf8(password)
            self.request.headers["Authorization"] = (b("Basic ") +
                                                     base64.b64encode(auth))
        if self.request.user_agent:
            self.request.headers["User-Agent"] = self.request.user_agent
        has_body = self.request.method in ("POST", "PUT")
        if has_body:
            assert self.request.body is not None
            self.request.headers["Content-Length"] = str(len(
                self.request.body))
        else:
            assert self.request.body is None
        if (self.request.method == "POST"
                and "Content-Type" not in self.request.headers):
            self.request.headers[
                "Content-Type"] = "application/x-www-form-urlencoded"
        if self.request.use_gzip:
            self.request.headers["Accept-Encoding"] = "gzip"
        req_path = ((parsed.path or '/') +
                    (('?' + parsed.query) if parsed.query else ''))
        request_lines = [
            utf8("%s %s HTTP/1.1" % (self.request.method, req_path))
        ]
        for k, v in self.request.headers.get_all():
            line = utf8(k) + b(": ") + utf8(v)
            # A newline inside a header would allow injecting extra headers.
            if b('\n') in line:
                raise ValueError('Newline in header: ' + repr(line))
            request_lines.append(line)
        self.stream.write(b("\r\n").join(request_lines) + b("\r\n\r\n"))
        if has_body:
            self.stream.write(self.request.body)
        self.stream.read_until(b("\r\n\r\n"), self._on_headers)

    @contextlib.contextmanager
    def cleanup(self):
        """Context manager turning uncaught exceptions into a 599 response.

        The exception is logged and, when the callback has not been used
        yet, passed to it as the ``error`` of an ``HTTPResponse``.
        """
        try:
            yield
        except Exception as e:
            logging.warning("uncaught exception", exc_info=True)
            if self.callback is not None:
                # Clear the attribute first so the callback runs at most once.
                callback = self.callback
                self.callback = None
                callback(
                    HTTPResponse(self.request,
                                 599,
                                 error=e,
                                 headers=self.headers))
Ejemplo n.º 46
0
 def setUp(self):
     # Each test gets a raw stream already connected to the test server.
     super(HTTPServerRawTest, self).setUp()
     server_address = ('127.0.0.1', self.get_http_port())
     self.stream = IOStream(socket.socket())
     self.stream.connect(server_address, self.stop)
     self.wait()
Ejemplo n.º 47
0
class KeepAliveTest(AsyncHTTPTestCase):
    """Scenarios covering HTTP/1.1 (and opt-in HTTP/1.0) keep-alive handling.

    AsyncHTTPClient is deliberately avoided: the tests drive a raw IOStream
    so that connection reuse and closing stay under their own control.
    """
    def get_app(self):
        """Return the application exposing the handlers these tests request."""
        class HelloHandler(RequestHandler):
            # Same fixed body for GET and POST.
            def get(self):
                self.finish('Hello world')

            def post(self):
                self.finish('Hello world')

        class LargeHandler(RequestHandler):
            def get(self):
                # 512 blocks of 1KB (512KB total) should be bigger than the
                # socket buffers, so the body is written out in chunks.
                blocks = [chr(n % 256) * 1024 for n in range(512)]
                self.write(''.join(blocks))

        class FinishOnCloseHandler(RequestHandler):
            @asynchronous
            def get(self):
                self.flush()

            def on_connection_close(self):
                # Not very realistic, but finishing the request from the
                # close callback has the right timing to mimic some errors
                # seen in the wild.
                self.finish('closed')

        routes = [
            ('/', HelloHandler),
            ('/large', LargeHandler),
            ('/finish_on_close', FinishOnCloseHandler),
        ]
        return Application(routes)

    def setUp(self):
        super(KeepAliveTest, self).setUp()
        self.http_version = b'HTTP/1.1'

    def tearDown(self):
        # The client side of the socket was just closed; let the IOLoop run
        # once so the server side gets the message.
        grace_period = datetime.timedelta(seconds=0.001)
        self.io_loop.add_timeout(grace_period, self.stop)
        self.wait()

        # close() deletes self.stream, so it only exists here when a test
        # did not reach its own close() call.
        if hasattr(self, 'stream'):
            self.stream.close()
        super(KeepAliveTest, self).tearDown()

    # The next few methods are a crude manual http client
    def connect(self):
        """Open a fresh IOStream to the test server and store it on self."""
        client = IOStream(socket.socket())
        self.stream = client
        client.connect(('127.0.0.1', self.get_http_port()), self.stop)
        self.wait()

    def read_headers(self):
        """Read the status line (asserting ``HTTP/1.1 200``) and the headers.

        Returns the parsed ``HTTPHeaders`` object.
        """
        self.stream.read_until(b'\r\n', self.stop)
        status_line = self.wait()
        self.assertTrue(status_line.startswith(b'HTTP/1.1 200'), status_line)
        self.stream.read_until(b'\r\n\r\n', self.stop)
        raw_headers = self.wait()
        return HTTPHeaders.parse(raw_headers.decode('latin1'))

    def read_response(self):
        """Read one full response and assert its body is ``Hello world``.

        The parsed headers are kept on ``self.headers`` for later assertions.
        """
        self.headers = self.read_headers()
        content_length = int(self.headers['Content-Length'])
        self.stream.read_bytes(content_length, self.stop)
        payload = self.wait()
        self.assertEqual(b'Hello world', payload)

    def close(self):
        """Close the client stream and drop the attribute."""
        self.stream.close()
        del self.stream

    def test_two_requests(self):
        self.connect()
        for _ in range(2):
            self.stream.write(b'GET / HTTP/1.1\r\n\r\n')
            self.read_response()
        self.close()

    def test_request_close(self):
        self.connect()
        self.stream.write(b'GET / HTTP/1.1\r\nConnection: close\r\n\r\n')
        self.read_response()
        self.stream.read_until_close(callback=self.stop)
        leftover = self.wait()
        self.assertTrue(not leftover)
        self.assertEqual(self.headers['Connection'], 'close')
        self.close()

    # keepalive is supported for http 1.0 too, but it's opt-in
    def test_http10(self):
        self.http_version = b'HTTP/1.0'
        self.connect()
        self.stream.write(b'GET / HTTP/1.0\r\n\r\n')
        self.read_response()
        self.stream.read_until_close(callback=self.stop)
        leftover = self.wait()
        self.assertTrue(not leftover)
        self.assertTrue('Connection' not in self.headers)
        self.close()

    def test_http10_keepalive(self):
        self.http_version = b'HTTP/1.0'
        self.connect()
        for _ in range(2):
            self.stream.write(
                b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n')
            self.read_response()
            self.assertEqual(self.headers['Connection'], 'Keep-Alive')
        self.close()

    def test_http10_keepalive_extra_crlf(self):
        self.http_version = b'HTTP/1.0'
        self.connect()
        # The first request carries one extra trailing CRLF.
        requests = (
            b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n\r\n',
            b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n',
        )
        for raw_request in requests:
            self.stream.write(raw_request)
            self.read_response()
            self.assertEqual(self.headers['Connection'], 'Keep-Alive')
        self.close()

    def test_pipelined_requests(self):
        self.connect()
        pipelined = b'GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\n'
        self.stream.write(pipelined)
        for _ in range(2):
            self.read_response()
        self.close()

    def test_pipelined_cancel(self):
        self.connect()
        pipelined = b'GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\n'
        self.stream.write(pipelined)
        # only read once
        self.read_response()
        self.close()

    def test_cancel_during_download(self):
        self.connect()
        self.stream.write(b'GET /large HTTP/1.1\r\n\r\n')
        self.read_headers()
        # Consume only the first 1024 bytes of the body, then hang up.
        self.stream.read_bytes(1024, self.stop)
        self.wait()
        self.close()

    def test_finish_while_closed(self):
        self.connect()
        self.stream.write(b'GET /finish_on_close HTTP/1.1\r\n\r\n')
        self.read_headers()
        self.close()

    def test_keepalive_chunked(self):
        self.http_version = b'HTTP/1.0'
        self.connect()
        chunked_post = (b'POST / HTTP/1.0\r\n'
                        b'Connection: keep-alive\r\n'
                        b'Transfer-Encoding: chunked\r\n'
                        b'\r\n'
                        b'0\r\n'
                        b'\r\n')
        plain_get = b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n'
        for raw_request in (chunked_post, plain_get):
            self.stream.write(raw_request)
            self.read_response()
            self.assertEqual(self.headers['Connection'], 'Keep-Alive')
        self.close()
Ejemplo n.º 48
0
 def accept_callback(conn, addr):
     """Wrap the accepted connection in an IOStream and signal ``event``.

     ``self`` and ``event`` are taken from the enclosing scope; ``addr`` is
     not used.
     """
     accepted = IOStream(conn)
     self.server_stream = accepted
     # Register the close so the stream is released during test cleanup.
     self.addCleanup(accepted.close)
     event.set()
Ejemplo n.º 49
0
                                             **self.ssl_options)
            except ssl.SSLError, err:
                if err.args[0] == ssl.SSL_ERROR_EOF:
                    return connection.close()
                else:
                    raise
            except socket.error, err:
                if err.args[0] == errno.ECONNABORTED:
                    return connection.close()
                else:
                    raise
        try:
            if self.ssl_options is not None:
                stream = SSLIOStream(connection, io_loop=self.io_loop)
            else:
                stream = IOStream(connection, io_loop=self.io_loop)
            self.handle_stream(stream, address)
        except Exception:
            logging.error("Error in connection callback", exc_info=True)


def bind_sockets(port, address=None, family=socket.AF_UNSPEC, backlog=128):
    """Creates listening sockets bound to the given port and address.

    Returns a list of socket objects (multiple sockets are returned if
    the given address maps to multiple IP addresses, which is most common
    for mixed IPv4 and IPv6 use).

    Address may be either an IP address or hostname.  If it's a hostname,
    the server will listen on all IP addresses associated with the
    name.  Address may be an empty string or None to listen on all
Ejemplo n.º 50
0
class IPCClient(object):
    '''
    Tornado IPC client, closely modelled on Tornado's TCPClient class, that
    talks over either a UNIX domain socket or a TCP socket.

    It exists because Tornado has no IPC server/client implementation of
    its own.

    Instances are shared: ``__new__`` returns the existing client for a
    given io_loop and ``str(socket_path)`` while one is registered.

    :param IOLoop io_loop: A Tornado ioloop to handle scheduling
    :param str/int socket_path: A path on the filesystem where a socket
                                belonging to a running IPCServer can be
                                found.
                                It may also be of type 'int', in which
                                case it is used as the port for a tcp
                                localhost connection.
    '''

    # Singleton registry: io_loop -> {str(socket_path) -> client}, with weak
    # references at both levels.
    instance_map = weakref.WeakKeyDictionary()

    def __new__(cls, socket_path, io_loop=None):
        loop = io_loop or tornado.ioloop.IOLoop.current()
        try:
            clients = IPCClient.instance_map[loop]
        except KeyError:
            clients = weakref.WeakValueDictionary()
            IPCClient.instance_map[loop] = clients

        # FIXME
        key = str(socket_path)

        if key in clients:
            log.debug('Re-using IPCClient for {0}'.format(key))
            return clients[key]

        log.debug('Initializing new IPCClient for path: {0}'.format(key))
        client = object.__new__(cls)
        # FIXME
        client.__singleton_init__(io_loop=loop, socket_path=socket_path)
        clients[key] = client
        return client

    def __singleton_init__(self, socket_path, io_loop=None):
        '''
        Initialise a newly created IPC client (called once, from __new__).

        IPC clients cannot bind to ports; they connect to an existing IPC
        server and can then send messages to it.
        '''
        self.io_loop = io_loop or tornado.ioloop.IOLoop.current()
        self.socket_path = socket_path
        self._closing = False
        self.stream = None
        # No unpacker encoding on Python 2, 'utf-8' otherwise.
        encoding = None if six.PY2 else 'utf-8'
        self.unpacker = msgpack.Unpacker(encoding=encoding)

    def __init__(self, socket_path, io_loop=None):
        # Intentionally empty: initialisation is done by __singleton_init__,
        # which the singleton __new__ invokes.
        pass

    def connected(self):
        '''
        Return True when a stream exists and has not been closed.
        '''
        if self.stream is None:
            return False
        return not self.stream.closed()

    def connect(self, callback=None, timeout=None):
        '''
        Connect to the IPC socket

        An unfinished earlier attempt is reused; otherwise a new future is
        created and ``_connect`` is started. The future is returned. When
        ``callback`` is given it is scheduled on the io_loop with the
        future's result once the future is done.
        '''
        previous = getattr(self, '_connecting_future', None)  # pylint: disable=E0203
        if previous is not None and not previous.done():
            future = previous
        else:
            if previous is not None:
                # read previous future result to prevent the "unhandled future exception" error
                previous.exc_info()
            future = tornado.concurrent.Future()
            self._connecting_future = future
            self._connect(timeout=timeout)

        if callback is not None:
            def handle_future(future):
                response = future.result()
                self.io_loop.add_callback(callback, response)
            future.add_done_callback(handle_future)

        return future

    @tornado.gen.coroutine
    def _connect(self, timeout=None):
        '''
        Connect to a running IPCServer

        An int ``socket_path`` means TCP to 127.0.0.1 on that port; anything
        else is used as a UNIX socket address. Without a timeout the first
        failure is reported on ``_connecting_future``; with one, the attempt
        is repeated after a one-second sleep until the deadline has passed.
        '''
        if isinstance(self.socket_path, int):
            family, address = socket.AF_INET, ('127.0.0.1', self.socket_path)
        else:
            family, address = socket.AF_UNIX, self.socket_path

        self.stream = None
        if timeout is not None:
            deadline = time.time() + timeout

        while not self._closing:
            if self.stream is None:
                self.stream = IOStream(
                    socket.socket(family, socket.SOCK_STREAM),
                    io_loop=self.io_loop,
                )

            try:
                log.trace('IPCClient: Connecting to socket: {0}'.format(self.socket_path))
                yield self.stream.connect(address)
                self._connecting_future.set_result(True)
                break
            except Exception as exc:
                # A closed stream is dropped so that a retry builds a new one.
                if self.stream.closed():
                    self.stream = None

                if timeout is None or time.time() > deadline:
                    if self.stream is not None:
                        self.stream.close()
                        self.stream = None
                    self._connecting_future.set_exception(exc)
                    break

                yield tornado.gen.sleep(1)

    def __del__(self):
        self.close()

    def close(self):
        '''
        Routines to handle any cleanup before the instance shuts down.
        Sockets and filehandles should be closed explicitly, to prevent
        leaks.
        '''
        if self._closing:
            return
        self._closing = True

        stream = self.stream
        if stream is not None and not stream.closed():
            stream.close()

        # Drop this client from the instance map so a closed entry cannot
        # be reused. This is done explicitly, even if the entry's reference
        # count has not yet gone to zero.
        if self.io_loop not in IPCClient.instance_map:
            return
        clients = IPCClient.instance_map[self.io_loop]
        key = str(self.socket_path)
        if key in clients:
            del clients[key]
Ejemplo n.º 51
0
    def _connect(self, sock, addr, callback):
        """Reset state, wrap ``sock`` in an IOStream and start connecting.

        ``callback`` is handed to the stream's ``connect``; a
        ``read_until_close`` is queued straight away with ``self._on_close``
        and ``self._on_read`` as its two callbacks.
        """
        self._reset()

        stream = IOStream(sock, io_loop=self._io_loop)
        self._stream = stream
        stream.connect(addr, callback=callback)
        stream.read_until_close(self._on_close, self._on_read)
Ejemplo n.º 52
0
class PAConnection(object):
    """Single IOStream connection that multiplexes "act" requests by Id.

    Requests passed to ``fetch`` are queued until the stream is connected,
    then written out.  Responses are read as a 4-byte Id, a 4-byte length
    and a body of that length; the callback registered for the Id is then
    called with an ``ActResponse`` whose ``result`` is the body.
    """
    def __init__(self, host, port, io_loop, key):
        """Start resolving ``host``/``port``; the connection is made in
        ``_on_resolve``.

        ``key`` is sent (via ``encode_act_key``) right after connecting.
        """
        self.io_loop = io_loop
        self.resolver = Resolver()
        self._callbacks = {}   # request Id -> callback
        self._connected = False
        self.queue = deque()   # requests waiting to be written
        self.key = key
        self.stream = None
        self.pepv_act_resp = None   # response currently being read
        self.prof = {}         # request Id -> list of time.time() stamps
        with stack_context.ExceptionStackContext(self._handle_exception):
            self.resolver.resolve(host,
                                  port,
                                  socket.AF_INET,
                                  callback=self._on_resolve)

    def _handle_exception(self, typ, value, tb):
        """Log an exception reported through the stack context.

        The exception triple is passed explicitly as ``exc_info``:
        ``Logger.exception`` takes it from ``sys.exc_info()``, which is only
        populated while an exception is actually being handled.
        """
        gen_log.error("pa connection error [%s] [%s] %s", typ, value, tb,
                      exc_info=(typ, value, tb))

    def _on_resolve(self, addrinfo):
        """Connect to the first resolved address."""
        af = addrinfo[0][0]
        self.stream = IOStream(socket.socket(af))
        self.stream.set_nodelay(True)
        self.stream.set_close_callback(self._on_close)
        sockaddr = addrinfo[0][1]
        self.stream.connect(sockaddr, self._on_connect)

    def _on_close(self):
        gen_log.info("pa conn close")

    def _on_connect(self):
        """Send the handshake, flush queued requests and start reading."""
        self._connected = True
        self.stream.write('\xab\xcd')  # magic number of act protocol
        self.stream.write(encode_act_key(self.key))
        self._process_queue()
        self.stream.read_bytes(4, self._on_id)

    def _on_id(self, data):
        """Handle the 4-byte response Id, then read the 4-byte length."""
        resp = ActResponse()
        resp.Id = bytes2int(data)
        self.prof[resp.Id].append(time.time())
        self.pepv_act_resp = resp
        self.stream.read_bytes(4, self._on_rlen)

    def _on_rlen(self, data):
        """Handle the 4-byte body length, then read the body."""
        self.stream.read_bytes(bytes2int(data), self._on_res_body)

    def _on_res_body(self, data):
        """Deliver the finished response to its callback and read the next Id."""
        resp = self.pepv_act_resp
        resp.result = data
        cb = self._callbacks[resp.Id]
        # Forget the request before calling out to the callback.
        del self.prof[resp.Id]
        del self._callbacks[resp.Id]
        cb(resp)
        self.stream.read_bytes(4, self._on_id)

    def fetch(self, act_request, callback):
        """Queue ``act_request``; ``callback`` is called with its response.

        A duplicate Id is logged as an error and the earlier registration
        is overwritten.
        """
        if act_request.Id in self._callbacks:
            gen_log.error("act Id {0} already in cbs !!".format(
                act_request.Id))
        self._callbacks[act_request.Id] = callback
        self.prof[act_request.Id] = [
            time.time(),
        ]
        self.queue.append(act_request)
        self._process_queue()

    def _process_queue(self):
        """Write every queued request to the stream once connected."""
        if not self._connected:
            # Not connected yet; _on_connect drains the queue later.
            return
        with stack_context.NullContext():
            while self.queue:
                act_request = self.queue.popleft()
                self.prof[act_request.Id].append(time.time())
                self.stream.write(act_request.encode_body())
Ejemplo n.º 53
0
class HTTPServerRawTest(AsyncHTTPTestCase):
    """Sends hand-written bytes to the HTTP server over a raw IOStream."""
    def get_app(self):
        routes = [('/echo', EchoHandler)]
        return Application(routes)

    def setUp(self):
        """Connect a fresh client stream to the test server."""
        super(HTTPServerRawTest, self).setUp()
        client = IOStream(socket.socket())
        self.stream = client
        client.connect(('127.0.0.1', self.get_http_port()), self.stop)
        self.wait()

    def tearDown(self):
        self.stream.close()
        super(HTTPServerRawTest, self).tearDown()

    def test_empty_request(self):
        # Hang up without sending anything, then give the IOLoop a moment.
        self.stream.close()
        pause = datetime.timedelta(seconds=0.001)
        self.io_loop.add_timeout(pause, self.stop)
        self.wait()

    def test_malformed_first_line(self):
        with ExpectLog(gen_log, '.*Malformed HTTP request line'):
            self.stream.write(b'asdf\r\n\r\n')
            # TODO: need an async version of ExpectLog so we don't need
            # hard-coded timeouts here.
            pause = datetime.timedelta(seconds=0.01)
            self.io_loop.add_timeout(pause, self.stop)
            self.wait()

    def test_malformed_headers(self):
        with ExpectLog(gen_log, '.*Malformed HTTP headers'):
            self.stream.write(b'GET / HTTP/1.0\r\nasdf\r\n\r\n')
            pause = datetime.timedelta(seconds=0.01)
            self.io_loop.add_timeout(pause, self.stop)
            self.wait()

    def test_chunked_request_body(self):
        # Chunked requests are not widely supported and we don't have a way
        # to generate them in AsyncHTTPClient, but HTTPServer will read them.
        raw_request = b"""\
POST /echo HTTP/1.1
Transfer-Encoding: chunked
Content-Type: application/x-www-form-urlencoded

4
foo=
3
bar
0

""".replace(b"\n", b"\r\n")
        self.stream.write(raw_request)
        read_stream_body(self.stream, self.stop)
        _headers, body = self.wait()
        decoded = json_decode(body)
        self.assertEqual(decoded, {u('foo'): [u('bar')]})
Ejemplo n.º 54
0
 def connect(self):
     """Create ``self.stream`` and connect it to the test server on 127.0.0.1."""
     client = IOStream(socket.socket())
     self.stream = client
     # self.stop is the connect callback; self.wait() is called right after.
     client.connect(('127.0.0.1', self.get_http_port()), self.stop)
     self.wait()
Ejemplo n.º 55
0
 def _make_client_iostream(self):
     """Return a plain IOStream on a new socket, bound to ``self.io_loop``."""
     client_socket = socket.socket()
     return IOStream(client_socket, io_loop=self.io_loop)
Ejemplo n.º 56
0
 def connect(self):
     """Connect to ``localhost`` on ``self._port`` via ``IOStream.connect``.

     ``self._on_connect`` is passed as the connect callback.
     """
     target = ('localhost', self._port)
     IOStream.connect(self, target, self._on_connect)
Ejemplo n.º 57
0
 def _make_client_iostream(self):
     """Return a plain IOStream wrapping a newly created socket."""
     client_socket = socket.socket()
     return IOStream(client_socket)
Ejemplo n.º 58
0
class HTTPServerRawTest(AsyncHTTPTestCase):
    """Sends hand-written bytes to the HTTP server over a raw IOStream."""
    def get_app(self):
        routes = [('/echo', EchoHandler)]
        return Application(routes)

    def setUp(self):
        """Connect a fresh client stream to the test server."""
        super(HTTPServerRawTest, self).setUp()
        client = IOStream(socket.socket())
        self.stream = client
        client.connect(('127.0.0.1', self.get_http_port()), self.stop)
        self.wait()

    def tearDown(self):
        self.stream.close()
        super(HTTPServerRawTest, self).tearDown()

    def test_empty_request(self):
        # Hang up without sending anything, then give the IOLoop a moment.
        self.stream.close()
        pause = datetime.timedelta(seconds=0.001)
        self.io_loop.add_timeout(pause, self.stop)
        self.wait()

    def test_malformed_first_line_response(self):
        with ExpectLog(gen_log, '.*Malformed HTTP request line'):
            self.stream.write(b'asdf\r\n\r\n')
            read_stream_body(self.stream, self.stop)
            status, _headers, _body = self.wait()
            self.assertEqual('HTTP/1.1', status.version)
            self.assertEqual(400, status.code)
            self.assertEqual('Bad Request', status.reason)

    def test_malformed_first_line_log(self):
        with ExpectLog(gen_log, '.*Malformed HTTP request line'):
            self.stream.write(b'asdf\r\n\r\n')
            # TODO: need an async version of ExpectLog so we don't need
            # hard-coded timeouts here.
            pause = datetime.timedelta(seconds=0.05)
            self.io_loop.add_timeout(pause, self.stop)
            self.wait()

    def test_malformed_headers(self):
        with ExpectLog(gen_log, '.*Malformed HTTP headers'):
            self.stream.write(b'GET / HTTP/1.0\r\nasdf\r\n\r\n')
            pause = datetime.timedelta(seconds=0.05)
            self.io_loop.add_timeout(pause, self.stop)
            self.wait()

    def test_chunked_request_body(self):
        # Chunked requests are not widely supported and we don't have a way
        # to generate them in AsyncHTTPClient, but HTTPServer will read them.
        raw_request = b"""\
POST /echo HTTP/1.1
Transfer-Encoding: chunked
Content-Type: application/x-www-form-urlencoded

4
foo=
3
bar
0

""".replace(b"\n", b"\r\n")
        self.stream.write(raw_request)
        read_stream_body(self.stream, self.stop)
        _status, _headers, body = self.wait()
        decoded = json_decode(body)
        self.assertEqual(decoded, {u'foo': [u'bar']})

    def test_chunked_request_uppercase(self):
        # As per RFC 2616 section 3.6, "Transfer-Encoding" header's value is
        # case-insensitive.
        raw_request = b"""\
POST /echo HTTP/1.1
Transfer-Encoding: Chunked
Content-Type: application/x-www-form-urlencoded

4
foo=
3
bar
0

""".replace(b"\n", b"\r\n")
        self.stream.write(raw_request)
        read_stream_body(self.stream, self.stop)
        _status, _headers, body = self.wait()
        decoded = json_decode(body)
        self.assertEqual(decoded, {u'foo': [u'bar']})

    def test_invalid_content_length(self):
        with ExpectLog(gen_log, '.*Only integer Content-Length is allowed'):
            raw_request = b"""\
POST /echo HTTP/1.1
Content-Length: foo

bar

""".replace(b"\n", b"\r\n")
            self.stream.write(raw_request)
            self.stream.read_until_close(self.stop)
            self.wait()
Ejemplo n.º 59
0
 def _make_server_iostream(self, connection, **kwargs):
     """Return a plain IOStream for ``connection``, forwarding ``kwargs``."""
     server_stream = IOStream(connection, **kwargs)
     return server_stream
Ejemplo n.º 60
0
 def test_unix_socket(self):
     """Serve ``/hello`` over a UNIX domain socket and check the response.

     The client stream and the server are released in ``finally`` blocks so
     that a failing assertion does not leave them open.
     """
     sockfile = os.path.join(self.tmpdir, "test.sock")
     sock = netutil.bind_unix_socket(sockfile)
     app = Application([("/hello", HelloWorldRequestHandler)])
     server = HTTPServer(app, io_loop=self.io_loop)
     server.add_socket(sock)
     try:
         stream = IOStream(socket.socket(socket.AF_UNIX),
                           io_loop=self.io_loop)
         try:
             stream.connect(sockfile, self.stop)
             self.wait()
             stream.write(b("GET /hello HTTP/1.0\r\n\r\n"))
             # Status line first, then the header block, then the body.
             stream.read_until(b("\r\n"), self.stop)
             response = self.wait()
             self.assertEqual(response, b("HTTP/1.0 200 OK\r\n"))
             stream.read_until(b("\r\n\r\n"), self.stop)
             headers = HTTPHeaders.parse(self.wait().decode('latin1'))
             stream.read_bytes(int(headers["Content-Length"]), self.stop)
             body = self.wait()
             self.assertEqual(body, b("Hello world"))
         finally:
             stream.close()
     finally:
         # Same order as before: the stream is closed, then the server stops.
         server.stop()