Ejemplo n.º 1
0
class ManualCapClient(BaseCapClient):
    """Capitalization client that chains socket callbacks by hand and
    exposes the result as a Future."""

    def capitalize(self, request_data, callback=None):
        """Send *request_data* to the server; return a Future for the reply.

        If *callback* is given it is invoked with the capitalized string
        once the Future resolves (legacy callback interface).
        """
        logging.info("capitalize")
        self.request_data = request_data
        self.stream = IOStream(socket.socket())
        self.stream.connect(('127.0.0.1', self.port),
                            callback=self.handle_connect)
        self.future = Future()
        if callback is not None:
            # Bridge the callback style onto the Future; wrap() preserves
            # the stack context active at registration time.
            self.future.add_done_callback(
                stack_context.wrap(lambda future: callback(future.result())))
        return self.future

    def handle_connect(self):
        # Connected: send one newline-terminated request line and wait for
        # one newline-terminated reply.
        logging.info("handle_connect")
        self.stream.write(utf8(self.request_data + "\n"))
        self.stream.read_until(b'\n', callback=self.handle_read)

    def handle_read(self, data):
        # Reply received: resolve the Future with the processed result, or
        # with CapError if process_response rejects the reply.
        logging.info("handle_read")
        self.stream.close()
        try:
            self.future.set_result(self.process_response(data))
        except CapError as e:
            self.future.set_exception(e)
Ejemplo n.º 2
0
class _UDPConnection(object):
    """One request/response exchange over a connected datagram socket.

    NOTE(review): IOStream is designed for stream sockets; wrapping a
    SOCK_DGRAM socket here relies on connected-UDP recv semantics --
    confirm this works with the Tornado version in use.
    """

    def __init__(self, io_loop, client, request, release_callback, final_callback, max_buffer_size):
        self.start_time = time.time()
        self.io_loop = io_loop
        self.client = client
        self.request = request
        self.release_callback = release_callback
        self.final_callback = final_callback

        # Resolve the target; only the first getaddrinfo result is used.
        address_info = socket.getaddrinfo(request.address, request.port, socket.AF_INET, socket.SOCK_DGRAM, 0, 0)
        af, socket_type, proto, _, socket_address = address_info[0]
        self.stream = IOStream(socket.socket(af, socket_type, proto), io_loop=self.io_loop,
                               max_buffer_size=max_buffer_size)

        self.stream.connect(socket_address, self._on_connect)

    def _on_connect(self):
        # Send the request payload, then read until the closing braces of
        # the (presumably JSON) response.
        self.stream.write(self.request.data)
        # BUG FIX: read_until expects bytes on Python 3; b'}}' is
        # byte-identical to the old '}}' on Python 2.
        self.stream.read_until(b'}}', self._on_response)

    def _on_response(self, data):
        """Release the connection, close the stream, then deliver *data*.

        Both callbacks are cleared before invocation so they fire at most
        once.
        """
        if self.release_callback is not None:
            release_callback = self.release_callback
            self.release_callback = None
            release_callback()
        self.stream.close()
        if self.final_callback is not None:
            final_callback = self.final_callback
            self.final_callback = None
            final_callback(data)
Ejemplo n.º 3
0
 def test_100_continue(self):
     # Run through a 100-continue interaction by hand:
     # When given Expect: 100-continue, we get a 100 response after the
     # headers, and then the real response after the body.
     # NOTE(review): generator test; presumably decorated with @gen_test
     # outside this view -- confirm.
     stream = IOStream(socket.socket())
     yield stream.connect(("127.0.0.1", self.get_http_port()))
     # Send only the request line and headers; the body is withheld until
     # the server acknowledges with "100 Continue".
     yield stream.write(
         b"\r\n".join(
             [
                 b"POST /hello HTTP/1.1",
                 b"Content-Length: 1024",
                 b"Expect: 100-continue",
                 b"Connection: close",
                 b"\r\n",
             ]
         )
     )
     data = yield stream.read_until(b"\r\n\r\n")
     self.assertTrue(data.startswith(b"HTTP/1.1 100 "), data)
     # Now send the promised 1024-byte body and read the real response:
     # status line, headers, then Content-Length bytes of body.
     stream.write(b"a" * 1024)
     first_line = yield stream.read_until(b"\r\n")
     self.assertTrue(first_line.startswith(b"HTTP/1.1 200"), first_line)
     header_data = yield stream.read_until(b"\r\n\r\n")
     headers = HTTPHeaders.parse(native_str(header_data.decode("latin1")))
     body = yield stream.read_bytes(int(headers["Content-Length"]))
     self.assertEqual(body, b"Got 1024 bytes in POST")
     stream.close()
Ejemplo n.º 4
0
 def test_100_continue(self):
     # Drive the Expect: 100-continue handshake manually: the server must
     # emit "100 Continue" after the headers, and the real response only
     # after the body has been sent.
     conn = IOStream(socket.socket())
     yield conn.connect(("127.0.0.1", self.get_http_port()))
     request_head = b"\r\n".join(
         [
             b"POST /hello HTTP/1.1",
             b"Content-Length: 1024",
             b"Expect: 100-continue",
             b"Connection: close",
             b"\r\n",
         ]
     )
     yield conn.write(request_head)
     interim = yield conn.read_until(b"\r\n\r\n")
     self.assertTrue(interim.startswith(b"HTTP/1.1 100 "), interim)
     # Server accepted the headers; now send the promised body.
     conn.write(b"a" * 1024)
     status_line = yield conn.read_until(b"\r\n")
     self.assertTrue(status_line.startswith(b"HTTP/1.1 200"), status_line)
     raw_headers = yield conn.read_until(b"\r\n\r\n")
     parsed_headers = HTTPHeaders.parse(native_str(raw_headers.decode("latin1")))
     body = yield conn.read_bytes(int(parsed_headers["Content-Length"]))
     self.assertEqual(body, b"Got 1024 bytes in POST")
     conn.close()
Ejemplo n.º 5
0
class Memnado(object):
    """Minimal async memcached client over a Tornado IOStream.

    Keys and values are base64-encoded before hitting the wire so that
    arbitrary bytes survive the text protocol.
    NOTE(review): commands are written as str (not bytes) and the initial
    connect is blocking -- this looks like Python 2 era code; confirm.
    """

    def __init__(self, host, port):
        self.host = host
        self.port = port
        
        # Blocking connect, then wrap the connected socket for async I/O.
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        s.connect((self.host, self.port))
        self.stream = IOStream(s)
    
    def set(self, key, value, callback, expiry=0):
        """Store *value* under *key*; *callback* gets the raw reply line."""
        key = b64encode(key)
        value = b64encode(value)
        content_length = len(value)
        # memcached wire format: set <key> <flags> <exptime> <bytes>\r\n
        # <data>\r\n  (flags are hard-coded to 1 here).
        self.stream.write("set %s 1 %s %s\r\n%s\r\n" % (key, expiry, 
                        content_length, value))
        self.stream.read_until("\r\n", callback)
    
    def get(self, key, callback):
        """Fetch *key*; *callback* receives the decoded value, or None on miss."""
        key = b64encode(key)
        
        def process_get(stream, cb, data):
            # First reply line is either "END" (miss) or
            # "VALUE <key> <flags> <bytes>".
            if data[0:3] == 'END': # key is empty
                cb(None)
            else:
                status, k, flags, content_length = data.strip().split(' ')
                
                def wrapped_cb(f):
                    # Decode the base64 payload before handing it back.
                    return lambda data: f(b64decode(data))
                
                stream.read_bytes(int(content_length), wrapped_cb(cb))
                # Drain the trailing "END" terminator so the stream stays
                # in sync for the next command.
                stream.read_until("\r\nEND\r\n", lambda d: d)
        
        self.stream.write("get %s\r\n" % key)
        self.stream.read_until("\r\n", functools.partial(process_get, self.stream, callback))
Ejemplo n.º 6
0
class ManualCapClient(BaseCapClient):
    """Client that manually chains socket callbacks and completes a Future."""

    def capitalize(self, request_data, callback=None):
        """Start a capitalize request; returns a Future for the result."""
        logging.debug("capitalize")
        self.request_data = request_data
        self.stream = IOStream(socket.socket())
        self.stream.connect(('127.0.0.1', self.port),
                            callback=self.handle_connect)
        self.future = Future()
        if callback is not None:
            def relay(future):
                callback(future.result())
            # Preserve the caller's stack context for the legacy callback.
            self.future.add_done_callback(stack_context.wrap(relay))
        return self.future

    def handle_connect(self):
        """Connection established: send the request line, await one reply."""
        logging.debug("handle_connect")
        self.stream.write(utf8(self.request_data + "\n"))
        self.stream.read_until(b'\n', callback=self.handle_read)

    def handle_read(self, data):
        """Reply line received: resolve the Future."""
        logging.debug("handle_read")
        self.stream.close()
        try:
            result = self.process_response(data)
        except CapError as e:
            self.future.set_exception(e)
        else:
            self.future.set_result(result)
Ejemplo n.º 7
0
class SubProcessApplication(Application):
    """Run application class in subprocess."""

    def __init__(self, target, io_loop=None):
        """*target* is an application object or a dotted module name to load."""
        Application.__init__(self)
        if isinstance(target, str):
            self.target = self._load_module(target)
        else:
            self.target = target
        if io_loop is None:
            # BUG FIX: was IOLoop().instance(), which constructed a throwaway
            # IOLoop just to call the classmethod; use the singleton directly.
            self.io_loop = IOLoop.instance()
        else:
            self.io_loop = io_loop
        self._process = None
        self.socket = None
        self.runner = None

    def start(self, config):
        """Fork the worker process and begin reading its messages."""
        signal.signal(signal.SIGCHLD, self._sigchld)
        # socketpair: parent keeps self.socket; the child end goes to the
        # runner and is closed in the parent after the fork.
        self.socket, child = socket.socketpair()
        self.runner = _Subprocess(self.target, child, config)
        self._process = multiprocessing.Process(target=self.runner.run)
        self._process.start()
        child.close()
        self.ios = IOStream(self.socket, self.io_loop)
        self.ios.read_until('\r\n', self._receiver)

    def _close(self, timeout):
        # Reap the child, waiting at most *timeout* seconds.
        self._process.join(timeout)

    def _sigchld(self, signum, frame):
        # Child exited; join it with a short grace period.
        self._close(0.5)

    def _receiver(self, data):
        """Receive data from subprocess. Forward to session.

        NOTE(review): read_until is not re-armed here, so only the first
        message is delivered -- confirm whether that is intended.
        NOTE(review): self.session is never assigned in this class;
        presumably set by the Application base or externally.
        """
        msg = json.loads(binascii.a2b_base64(data))
        if self.session:
            self.session.send(msg)
        else:
            logging.error("from app: %s", str(msg))

    def stop(self):
        """Terminate the subprocess and wait briefly for it to exit."""
        self._process.terminate()
        self._close(2.0)

    def send(self, data):
        """Send data to application."""
        self.ios.write(data + '\r\n')

    def received(self, data):
        """Handle data from session. Forward to subprocess for handling."""
        self.send(binascii.b2a_base64(json.dumps(data)))
        return True

    def _load_module(self, modulename):
        """Import and return the module named *modulename*."""
        import importlib
        return importlib.import_module(modulename)
Ejemplo n.º 8
0
    def _handle_accept(self, fd, events):
        """Accept one pending connection and start reading lines from it."""
        connection, address = self._socket.accept()
        host = "%s:%d" % address  # "ip:port" key for the stream registry
        stream = IOStream(connection)
        self._streams[host] = stream
        # Fire _handle_close(host) when the peer disconnects.
        stream.set_close_callback(functools.partial(self._handle_close, host))
        stream.read_until("\r\n", functools.partial(self._handle_read, host))
Ejemplo n.º 9
0
class IRCStream(object):
    """Singleton IRC client that relays channel messages to the broadcast
    system.  Python 2 code (uses a print statement and str socket I/O)."""

    _instance = None

    @classmethod
    def instance(cls):
        """ Returns the singleton """
        if not cls._instance:
            cls._instance = cls()
        return cls._instance

    def __init__(self):
        # Connect, register (NICK/USER), join the configured channel, and
        # start the read loop.
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        self.stream = IOStream(sock)
        self.host = SETTINGS["irc_host"]
        self.channel = SETTINGS["irc_channel"]
        self.stream.connect((self.host, SETTINGS["irc_port"]))
        self.nick = "PyTexasBot"
        self.ident = "pytexasbot"
        self.real_name = "PyTexas StreamBot"

        self.stream.write("NICK %s\r\n" % self.nick)
        self.stream.write("USER %s %s blah :%s\r\n" % (self.ident, self.host,
            self.real_name))
        self.stream.write("JOIN #"+self.channel+"\r\n")
        self.monitor_output()

    def monitor_output(self):
        # Re-arm the line reader; parse_line calls this again when done.
        self.stream.read_until("\r\n", self.parse_line)

    def parse_line(self, response):
        """Handle one IRC line: answer PINGs, relay PRIVMSGs, raise on ERROR."""
        response = response.strip()
        if response.startswith("PING "):
            # Keepalive: echo the server token back as PONG.
            request = response.replace("PING ", "")
            self.stream.write("PONG %s\r\n" % request)
        splitter = "PRIVMSG #%s :" % self.channel
        if splitter in response:
            parts = response.split(splitter)
            text = parts[1]
            if not text:
                # not going to throw out empty messages
                return self.monitor_output()
            # Sender nick is the prefix before "!" with the leading ":".
            nick = parts[0][1:].split("!")[0].strip()
            message = {
                "time": int(time.time()),
                "text": xhtml_escape(text),
                "name": nick,
                "username": "******",
                "type": "tweet",
                "avatar": None
            }
            broadcast_message(message)
        if response.startswith("ERROR"):
            raise Exception(response)
        else:
            # NOTE(review): every non-ERROR line is printed -- presumably
            # debug output; confirm it is wanted in production.
            print response
        self.monitor_output()
Ejemplo n.º 10
0
    def fetch(self, request, callback, **kwargs):
        """Issue *request* (an HTTPRequest or a URL string) and invoke
        *callback* with the response once the headers have been read.
        """
        if not isinstance(request, HTTPRequest):
            request = HTTPRequest(url=request, **kwargs)
        callback = stack_context.wrap(callback)

        parsed = urlparse.urlsplit(request.url)
        sock = socket.socket()
        #sock.setblocking(False) # TODO non-blocking connect
        # BUG FIX: netloc may contain "host:port", which must not be passed
        # to connect() verbatim; use the parsed host and explicit port,
        # defaulting to 80.  TODO: https.
        sock.connect((parsed.hostname, parsed.port or 80))
        stream = IOStream(sock, io_loop=self.io_loop)
        # TODO: query parameters
        # NOTE(review): logged at warning level although it is a normal
        # request trace -- confirm the intended log level.
        logging.warning("%s %s HTTP/1.0\r\n\r\n" % (request.method, parsed.path or '/'))
        stream.write("%s %s HTTP/1.0\r\n\r\n" % (request.method, parsed.path or '/'))
        stream.read_until("\r\n\r\n", functools.partial(self._on_headers,
                                                        request, callback, stream))
Ejemplo n.º 11
0
class UnixSocketTest(AsyncTestCase):
    """HTTPServers can listen on Unix sockets too.

    Why would you want to do this?  Nginx can proxy to backends listening
    on unix sockets, for one thing (and managing a namespace for unix
    sockets can be easier than managing a bunch of TCP port numbers).

    Unfortunately, there's no way to specify a unix socket in a url for
    an HTTP client, so we have to test this by hand.
    """
    def setUp(self):
        # Bind an HTTPServer to a socket file in a fresh temp dir and open
        # a client IOStream connected to it.
        super(UnixSocketTest, self).setUp()
        self.tmpdir = tempfile.mkdtemp()
        self.sockfile = os.path.join(self.tmpdir, "test.sock")
        sock = netutil.bind_unix_socket(self.sockfile)
        app = Application([("/hello", HelloWorldRequestHandler)])
        self.server = HTTPServer(app, io_loop=self.io_loop)
        self.server.add_socket(sock)
        self.stream = IOStream(socket.socket(socket.AF_UNIX),
                               io_loop=self.io_loop)
        self.stream.connect(self.sockfile, self.stop)
        self.wait()

    def tearDown(self):
        # Close the client first, then the server, then remove the temp dir.
        self.stream.close()
        self.server.stop()
        shutil.rmtree(self.tmpdir)
        super(UnixSocketTest, self).tearDown()

    def test_unix_socket(self):
        # Hand-written HTTP/1.0 GET: read status line, headers, then body.
        self.stream.write(b"GET /hello HTTP/1.0\r\n\r\n")
        self.stream.read_until(b"\r\n", self.stop)
        response = self.wait()
        self.assertEqual(response, b"HTTP/1.0 200 OK\r\n")
        self.stream.read_until(b"\r\n\r\n", self.stop)
        headers = HTTPHeaders.parse(self.wait().decode('latin1'))
        self.stream.read_bytes(int(headers["Content-Length"]), self.stop)
        body = self.wait()
        self.assertEqual(body, b"Hello world")

    def test_unix_socket_bad_request(self):
        # Unix sockets don't have remote addresses so they just return an
        # empty string.
        with ExpectLog(gen_log, "Malformed HTTP message from"):
            self.stream.write(b"garbage\r\n\r\n")
            self.stream.read_until_close(self.stop)
            response = self.wait()
        self.assertEqual(response, b"")
Ejemplo n.º 12
0
class UnixSocketTest(AsyncTestCase):
    """HTTPServers can listen on Unix sockets too.

    Why would you want to do this?  Nginx can proxy to backends listening
    on unix sockets, for one thing (and managing a namespace for unix
    sockets can be easier than managing a bunch of TCP port numbers).

    Unfortunately, there's no way to specify a unix socket in a url for
    an HTTP client, so we have to test this by hand.
    """
    def setUp(self):
        # Bind an HTTPServer to a socket file in a fresh temp dir and open
        # a client IOStream connected to it.
        super(UnixSocketTest, self).setUp()
        self.tmpdir = tempfile.mkdtemp()
        self.sockfile = os.path.join(self.tmpdir, "test.sock")
        sock = netutil.bind_unix_socket(self.sockfile)
        app = Application([("/hello", HelloWorldRequestHandler)])
        self.server = HTTPServer(app)
        self.server.add_socket(sock)
        self.stream = IOStream(socket.socket(socket.AF_UNIX))
        self.stream.connect(self.sockfile, self.stop)
        self.wait()

    def tearDown(self):
        # Drain open server connections before stopping, then clean up.
        self.stream.close()
        self.io_loop.run_sync(self.server.close_all_connections)
        self.server.stop()
        shutil.rmtree(self.tmpdir)
        super(UnixSocketTest, self).tearDown()

    def test_unix_socket(self):
        # Hand-written HTTP/1.0 GET: read status line, headers, then body.
        self.stream.write(b"GET /hello HTTP/1.0\r\n\r\n")
        self.stream.read_until(b"\r\n", self.stop)
        response = self.wait()
        self.assertEqual(response, b"HTTP/1.1 200 OK\r\n")
        self.stream.read_until(b"\r\n\r\n", self.stop)
        headers = HTTPHeaders.parse(self.wait().decode('latin1'))
        self.stream.read_bytes(int(headers["Content-Length"]), self.stop)
        body = self.wait()
        self.assertEqual(body, b"Hello world")

    def test_unix_socket_bad_request(self):
        # Unix sockets don't have remote addresses so they just return an
        # empty string.
        with ExpectLog(gen_log, "Malformed HTTP message from"):
            self.stream.write(b"garbage\r\n\r\n")
            self.stream.read_until_close(self.stop)
            response = self.wait()
        self.assertEqual(response, b"HTTP/1.1 400 Bad Request\r\n\r\n")
Ejemplo n.º 13
0
class EchoClienAsync(object):
    """Line-oriented echo client built on a Tornado IOStream."""

    def __init__(self, host="127.0.0.1", port=12345):
        self.host = host
        self.port = port
        raw_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        self.stream = IOStream(raw_socket)

    def connect(self):
        """Begin connecting; returns the underlying stream."""
        self.stream.connect((self.host, self.port))
        return self.stream

    def send(self, data, callback=None):
        """Write *data* followed by a newline terminator."""
        self.stream.write(data + "\n", callback)

    def recv(self, callback):
        """Read one line and hand it to *callback* without the newline."""
        def strip_newline(line):
            callback(line[:-1])
        self.stream.read_until("\n", strip_newline)
Ejemplo n.º 14
0
def talk(socket, io_loop):
    """A client connection handler that says hello to the echo server,
    waits for a response, then disconnects."""

    # NOTE: the *socket* parameter shadows the stdlib socket module here.
    stream = IOStream(socket, io_loop=io_loop)
    messages = [0]  # NOTE(review): unused; looks like a leftover counter.

    def write(data, *args):
        # Log outgoing traffic (Python 2 print), then forward to the stream.
        print 'C: %r' % data
        stream.write(data, *args)

    def handle(data):
        # Server replied: say goodbye and close once the write completes.
        write('goodbye\n', stream.close)

    stream.read_until("\n", handle)
    write('hello!\n')
Ejemplo n.º 15
0
def talk(socket, io_loop):
    """A client connection handler that says hello to the echo server,
    waits for a response, then disconnects."""

    stream = IOStream(socket, io_loop=io_loop)
    messages = [0]

    def write(data, *args):
        print 'C: %r' % data
        stream.write(data, *args)

    def handle(data):
        write('goodbye\n', stream.close)

    stream.read_until("\n", handle)
    write('hello!\n')
Ejemplo n.º 16
0
 def test_unix_socket(self):
     # End-to-end GET over a Unix domain socket served by HTTPServer.
     sockfile = os.path.join(self.tmpdir, "test.sock")
     sock = netutil.bind_unix_socket(sockfile)
     app = Application([("/hello", HelloWorldRequestHandler)])
     server = HTTPServer(app, io_loop=self.io_loop)
     server.add_socket(sock)
     stream = IOStream(socket.socket(socket.AF_UNIX), io_loop=self.io_loop)
     stream.connect(sockfile, self.stop)
     self.wait()
     # Hand-written HTTP/1.0 GET: status line, then headers, then body.
     stream.write(b("GET /hello HTTP/1.0\r\n\r\n"))
     stream.read_until(b("\r\n"), self.stop)
     response = self.wait()
     self.assertEqual(response, b("HTTP/1.0 200 OK\r\n"))
     stream.read_until(b("\r\n\r\n"), self.stop)
     headers = HTTPHeaders.parse(self.wait().decode('latin1'))
     stream.read_bytes(int(headers["Content-Length"]), self.stop)
     body = self.wait()
     self.assertEqual(body, b("Hello world"))
Ejemplo n.º 17
0
 def test_unix_socket(self):
     # Serve /hello on a Unix domain socket and fetch it with a raw client.
     path = os.path.join(self.tmpdir, "test.sock")
     listener = netutil.bind_unix_socket(path)
     app = Application([("/hello", HelloWorldRequestHandler)])
     server = HTTPServer(app, io_loop=self.io_loop)
     server.add_socket(listener)
     client = IOStream(socket.socket(socket.AF_UNIX), io_loop=self.io_loop)
     client.connect(path, self.stop)
     self.wait()
     client.write(b("GET /hello HTTP/1.0\r\n\r\n"))
     client.read_until(b("\r\n"), self.stop)
     status_line = self.wait()
     self.assertEqual(status_line, b("HTTP/1.0 200 OK\r\n"))
     client.read_until(b("\r\n\r\n"), self.stop)
     parsed_headers = HTTPHeaders.parse(self.wait().decode('latin1'))
     client.read_bytes(int(parsed_headers["Content-Length"]), self.stop)
     body = self.wait()
     self.assertEqual(body, b("Hello world"))
Ejemplo n.º 18
0
def connection_ready(sock, core, fd, event):
    """
    handler of socket connection
    """
    #print 'in'
    while True:
        try:
            connection, address = sock.accept()
        except socket.error as e:
            # Nothing more to accept right now; anything other than
            # would-block is a real error.
            if e.args[0] not in (errno.EWOULDBLOCK, errno.EAGAIN):
                raise
            return

        # Ensure the data can be received.
        # Data from localhost may arrive before the readiness event fires.
        #connection.settimeout(2)
        stream = IOStream(connection)
        stream.read_until(SOCKET_EOF, partial(handle, stream))
        # NOTE(review): this return means the while-loop accepts at most one
        # connection per readiness event; with level-triggered polling the
        # remaining queued connections re-trigger, but confirm this is the
        # intended behavior (the usual pattern drains until EWOULDBLOCK).
        return
Ejemplo n.º 19
0
class DecoratorCapClient(BaseCapClient):
    """Callback-style client; @return_future adapts capitalize() so callers
    receive a Future instead of passing the callback themselves."""

    @return_future
    def capitalize(self, request_data, callback):
        """Start the request; *callback* is supplied by @return_future."""
        logging.info("capitalize")
        self.request_data = request_data
        self.stream = IOStream(socket.socket())
        self.stream.connect(('127.0.0.1', self.port),
                            callback=self.handle_connect)
        self.callback = callback

    def handle_connect(self):
        # Connected: send the request line and wait for one reply line.
        logging.info("handle_connect")
        self.stream.write(utf8(self.request_data + "\n"))
        self.stream.read_until(b'\n', callback=self.handle_read)

    def handle_read(self, data):
        # Reply received: close and deliver the processed result.
        logging.info("handle_read")
        self.stream.close()
        self.callback(self.process_response(data))
Ejemplo n.º 20
0
class DecoratorCapClient(BaseCapClient):
    """Callback-style client; @future_wrap adapts capitalize() to a Future."""

    @future_wrap
    def capitalize(self, request_data, callback):
        """Open the connection and start the request; result via *callback*."""
        logging.info("capitalize")
        self.request_data = request_data
        self.callback = callback
        self.stream = IOStream(socket.socket(), io_loop=self.io_loop)
        self.stream.connect(('127.0.0.1', self.port), callback=self.handle_connect)

    def handle_connect(self):
        """Connected: send the request line and wait for one reply line."""
        logging.info("handle_connect")
        request_line = utf8(self.request_data + "\n")
        self.stream.write(request_line)
        self.stream.read_until(b('\n'), callback=self.handle_read)

    def handle_read(self, data):
        """Reply received: close the stream and deliver the result."""
        logging.info("handle_read")
        self.stream.close()
        self.callback(self.process_response(data))
Ejemplo n.º 21
0
 def capitalize(self, request_data):
     # Coroutine (presumably decorated with @gen.coroutine outside this
     # view): one request/response round-trip with the cap server.
     logging.debug('capitalize')
     stream = IOStream(socket.socket())
     logging.debug('connecting')
     yield stream.connect(('127.0.0.1', self.port))
     stream.write(utf8(request_data + '\n'))
     logging.debug('reading')
     data = yield stream.read_until(b'\n')
     logging.debug('returning')
     stream.close()
     # gen.Return carries the value out of the generator (pre-3.3 idiom).
     raise gen.Return(self.process_response(data))
Ejemplo n.º 22
0
 def capitalize(self, request_data):
     # Coroutine round-trip: connect, send one line, read one line back.
     logging.debug('capitalize')
     conn = IOStream(socket.socket())
     logging.debug('connecting')
     yield conn.connect(('127.0.0.1', self.port))
     conn.write(utf8(request_data + '\n'))
     logging.debug('reading')
     reply = yield conn.read_until(b'\n')
     logging.debug('returning')
     conn.close()
     raise gen.Return(self.process_response(reply))
 def capitalize(self, request_data):
     # Same round-trip as above, but against a hard-coded remote host.
     # NOTE(review): 10.0.0.7 looks environment-specific -- confirm.
     logging.debug("capitalize")
     stream = IOStream(socket.socket())
     logging.debug("connecting")
     yield stream.connect(("10.0.0.7", self.port))
     stream.write(utf8(request_data + "\n"))
     logging.debug("reading")
     data = yield stream.read_until(b"\n")
     logging.debug("returning")
     stream.close()
     raise gen.Return(self.process_response(data))
Ejemplo n.º 24
0
class DecoratorCapClient(BaseCapClient):
    """Callback-style client using the deprecated @return_future adapter;
    the decorator application is wrapped in ignore_deprecation() so the
    class body evaluates without warnings."""

    with ignore_deprecation():
        @return_future
        def capitalize(self, request_data, callback):
            """Start the request; *callback* is supplied by @return_future."""
            logging.debug("capitalize")
            self.request_data = request_data
            self.stream = IOStream(socket.socket())
            self.stream.connect(('127.0.0.1', self.port),
                                callback=self.handle_connect)
            self.callback = callback

    def handle_connect(self):
        # Connected: send the request line and wait for one reply line.
        logging.debug("handle_connect")
        self.stream.write(utf8(self.request_data + "\n"))
        self.stream.read_until(b'\n', callback=self.handle_read)

    def handle_read(self, data):
        # Reply received: close and deliver the processed result.
        logging.debug("handle_read")
        self.stream.close()
        self.callback(self.process_response(data))
Ejemplo n.º 25
0
class _Subprocess(object):
    """Worker side of the parent/child socketpair (Python 2 code: note the
    old `except X, e` syntax below)."""
    
    def __init__(self, target, socket, config):
        # *target* may provide optional setup/finish hooks and either a
        # handle() method or be callable itself.
        self.target = target
        self.socket = socket
        self.running = False
        self.config = config
        self.init = getattr(self.target, 'setup', None)
        self.finish = getattr(self.target, 'finish', None)
        if hasattr(self.target, 'handle'):
            self.handler = getattr(self.target, 'handle')
        else:
            self.handler = target
        if not callable(self.handler):
            raise NotImplementedError('handle')
        self.io_loop = None

    def run(self):
        """Run the application data receiving loop.

        This is executed in the subprocess context.
        """
        signal.signal(signal.SIGTERM, self.sigterm)
        logging.debug("starting subprocess in pid: %d [fd %d]",
                      os.getpid(), self.socket.fileno())

        # close extra files descriptor from master, [3 -> fileno(self.socket)-1]
        os.closerange(3, self.socket.fileno())

        # Fresh IOLoop for the child; the parent's loop must not be reused.
        self.io_loop = IOLoop()
        self.ios = IOStream(self.socket, self.io_loop)
        if callable(self.init):
            self.init(self._writer, self.config)
        self.running = True
        self.ios.read_until('\r\n', self._receiver)
        try:
            self.io_loop.start()
        except (IOError, KeyboardInterrupt), ex:
            # Normal shutdown paths; exit quietly.
            pass
        except Exception, e:
            logging.error("Subprocess recv failed: %s", str(e))
Ejemplo n.º 26
0
            def accept_callback(conn, address):
                # Minimal hand-rolled HTTP/1.1 server for one connection.
                # NOTE(review): uses yield, so presumably wrapped by
                # @gen.coroutine above this view -- confirm.
                stream = IOStream(conn)
                request_data = yield stream.read_until(b"\r\n\r\n")
                if b"HTTP/1." not in request_data:
                    self.skipTest("requires HTTP/1.x")
                # Response contains a folded (continuation-line) header;
                # LFs are rewritten to CRLFs to form valid HTTP.
                yield stream.write(b"""\
HTTP/1.1 200 OK
X-XSS-Protection: 1;
\tmode=block

""".replace(b"\n", b"\r\n"))
                stream.close()
Ejemplo n.º 27
0
            def accept_callback(conn, address):
                # Minimal hand-rolled HTTP/1.1 server for one connection.
                # NOTE(review): uses yield, so presumably wrapped by
                # @gen.coroutine above this view -- confirm.
                stream = IOStream(conn)
                request_data = yield stream.read_until(b"\r\n\r\n")
                if b"HTTP/1." not in request_data:
                    self.skipTest("requires HTTP/1.x")
                # Response contains a folded (continuation-line) header;
                # LFs are rewritten to CRLFs to form valid HTTP.
                yield stream.write(b"""\
HTTP/1.1 200 OK
X-XSS-Protection: 1;
\tmode=block

""".replace(b"\n", b"\r\n"))
                stream.close()
Ejemplo n.º 28
0
 def test_100_continue(self):
     # Run through a 100-continue interaction by hand:
     # When given Expect: 100-continue, we get a 100 response after the
     # headers, and then the real response after the body.
     # (Old-Tornado style: b() byte helper and stop/wait instead of yield.)
     stream = IOStream(socket.socket(), io_loop=self.io_loop)
     stream.connect(("localhost", self.get_http_port()), callback=self.stop)
     self.wait()
     # Send only the request line and headers; the body is withheld until
     # the server acknowledges with "100 Continue".
     stream.write(b("\r\n").join([b("POST /hello HTTP/1.1"),
                                  b("Content-Length: 1024"),
                                  b("Expect: 100-continue"),
                                  b("Connection: close"),
                                  b("\r\n")]), callback=self.stop)
     self.wait()
     stream.read_until(b("\r\n\r\n"), self.stop)
     data = self.wait()
     self.assertTrue(data.startswith(b("HTTP/1.1 100 ")), data)
     # Now send the promised body and read the real response.
     stream.write(b("a") * 1024)
     stream.read_until(b("\r\n"), self.stop)
     first_line = self.wait()
     self.assertTrue(first_line.startswith(b("HTTP/1.1 200")), first_line)
     stream.read_until(b("\r\n\r\n"), self.stop)
     header_data = self.wait()
     headers = HTTPHeaders.parse(native_str(header_data.decode('latin1')))
     stream.read_bytes(int(headers["Content-Length"]), self.stop)
     body = self.wait()
     self.assertEqual(body, b("Got 1024 bytes in POST"))
     stream.close()
Ejemplo n.º 29
0
 def test_100_continue(self):
     # Run through a 100-continue interaction by hand:
     # When given Expect: 100-continue, we get a 100 response after the
     # headers, and then the real response after the body.
     # (Old-Tornado style: b() byte helper and stop/wait instead of yield.)
     stream = IOStream(socket.socket(), io_loop=self.io_loop)
     stream.connect(("localhost", self.get_http_port()), callback=self.stop)
     self.wait()
     # Send only the request line and headers; the body is withheld until
     # the server acknowledges with "100 Continue".
     stream.write(b("\r\n").join([
         b("POST /hello HTTP/1.1"),
         b("Content-Length: 1024"),
         b("Expect: 100-continue"),
         b("Connection: close"),
         b("\r\n")
     ]),
                  callback=self.stop)
     self.wait()
     stream.read_until(b("\r\n\r\n"), self.stop)
     data = self.wait()
     self.assertTrue(data.startswith(b("HTTP/1.1 100 ")), data)
     # Now send the promised body and read the real response.
     stream.write(b("a") * 1024)
     stream.read_until(b("\r\n"), self.stop)
     first_line = self.wait()
     self.assertTrue(first_line.startswith(b("HTTP/1.1 200")), first_line)
     stream.read_until(b("\r\n\r\n"), self.stop)
     header_data = self.wait()
     headers = HTTPHeaders.parse(native_str(header_data.decode('latin1')))
     stream.read_bytes(int(headers["Content-Length"]), self.stop)
     body = self.wait()
     self.assertEqual(body, b("Got 1024 bytes in POST"))
     stream.close()
Ejemplo n.º 30
0
class TokenizerService(object):
    '''
    Wraps the IPC to the Java TokenizerService (which runs tokenization and named
    entity extraction through CoreNLP)
    '''

    def __init__(self):
        # One persistent connection; requests are matched to replies by an
        # integer id carried in the 'req' field of each JSON message.
        self._socket = IOStream(socket.socket(socket.AF_INET, socket.SOCK_STREAM))
        self._requests = dict()
        self._next_id = 0

    @tornado.gen.coroutine
    def run(self):
        """Reader loop: connect, then dispatch replies to pending futures."""
        yield self._socket.connect(('127.0.0.1', PORT))

        while True:
            try:
                response = yield self._socket.read_until(b'\n')
            except StreamClosedError:
                response = None
            if not response:
                # Connection closed; stop the reader loop.
                return
            response = json.loads(str(response, encoding='utf-8'))

            # 'req' echoes the id we sent with the request.
            # (renamed from 'id', which shadowed the builtin)
            req_id = int(response['req'])
            result = TokenizerResult(tokens=list(clean_tokens(response['tokens'])),
                                     values=response['values'],
                                     constituency_parse=response['constituencyParse'],
                                     pos_tags=response['pos'],
                                     raw_tokens=response['rawTokens'],
                                     sentiment=response['sentiment'])
            # NOTE(review): raises KeyError if the server replies with an
            # unknown id -- confirm the server never does.
            self._requests[req_id].set_result(result)
            del self._requests[req_id]

    def tokenize(self, language_tag, query, expect=None):
        """Queue *query* for tokenization; returns a Future for the result.

        *expect* is forwarded to the server when given.
        """
        req_id = self._next_id
        self._next_id += 1

        req = dict(req=req_id, utterance=query, languageTag=language_tag)
        if expect is not None:
            req['expect'] = expect
        outer = Future()
        self._requests[req_id] = outer

        def then(future):
            # Propagate write failures to the caller and drop the pending
            # entry so it cannot leak.
            if future.exception():
                outer.set_exception(future.exception())
                del self._requests[req_id]

        future = self._socket.write(json.dumps(req).encode())
        future.add_done_callback(then)
        return outer
Ejemplo n.º 31
0
Archivo: server.py Proyecto: BYK/irc2ws
class WS2IRCBridge(WebSocketHandler):
    """Bridge one WebSocket client to one IRC server connection."""

    def open(self, host='irc.freenode.net', port=None):
        # New WebSocket: dial the requested IRC server and start relaying.
        port = int(port or 6667)
        self.sock = IOStream(socket.socket(socket.AF_INET,
                                           socket.SOCK_STREAM, 0))
        self.sock.connect((host, port), self.sock_loop)
        logging.debug("Request received for %s:%d", host, port)

    def sock_loop(self, data=None):
        # Relay one IRC line to the WebSocket, then re-arm the reader; the
        # first call (from connect) has data=None.
        if data:
            self.write_message(data)

        if self.sock.closed():
            self.close()
            logging.debug("IRC socket closed. Closing active WebSocket too.")
        else:
            self.sock.read_until("\r\n", self.sock_loop)

    def on_message(self, message):
        # NOTE(review): bytes + str concatenation -- works on Python 2 only;
        # confirm the target interpreter.
        self.sock.write(message.encode('utf-8') + "\r\n")

    def on_close(self):
        # Client went away: tear down the IRC side as well.
        self.sock.close()
        logging.debug("Client closed the WebSocket.")
Ejemplo n.º 32
0
            def accept_callback(conn, address):
                # fake an HTTP server using chunked encoding where the final chunks
                # and connection close all happen at once
                # NOTE(review): uses yield, so presumably wrapped by
                # @gen.coroutine above this view -- confirm.
                stream = IOStream(conn)
                request_data = yield stream.read_until(b"\r\n\r\n")
                if b"HTTP/1." not in request_data:
                    self.skipTest("requires HTTP/1.x")
                # Two 1-byte chunks ("1"/"2") plus the zero-length terminator,
                # written as a single burst; LFs become CRLFs.
                yield stream.write(b"""\
HTTP/1.1 200 OK
Transfer-Encoding: chunked

1
1
1
2
0

""".replace(b"\n", b"\r\n"))
                stream.close()
Ejemplo n.º 33
0
            def accept_callback(conn, address):
                # fake an HTTP server using chunked encoding where the final chunks
                # and connection close all happen at once
                # NOTE(review): uses yield, so presumably wrapped by
                # @gen.coroutine above this view -- confirm.
                stream = IOStream(conn)
                request_data = yield stream.read_until(b"\r\n\r\n")
                if b"HTTP/1." not in request_data:
                    self.skipTest("requires HTTP/1.x")
                # Two 1-byte chunks ("1"/"2") plus the zero-length terminator,
                # written as a single burst; LFs become CRLFs.
                yield stream.write(b"""\
HTTP/1.1 200 OK
Transfer-Encoding: chunked

1
1
1
2
0

""".replace(b"\n", b"\r\n"))
                stream.close()
Ejemplo n.º 34
0
class TestIOStreamStartTLS(AsyncTestCase):
    """Exercises ``IOStream.start_tls`` over a real loopback socket pair.

    setUp wires one client IOStream to one accepted server IOStream; the
    ``*_send_line`` helpers shuttle traffic both ways so the buffers are in
    a realistic state before TLS is negotiated mid-stream.
    """

    def setUp(self):
        try:
            super(TestIOStreamStartTLS, self).setUp()
            self.listener, self.port = bind_unused_port()
            self.server_stream = None
            self.server_accepted = Future()
            netutil.add_accept_handler(self.listener, self.accept)
            self.client_stream = IOStream(socket.socket())
            # Wait for the client connect and the server accept separately
            # so both ends exist before any test body runs.
            self.io_loop.add_future(self.client_stream.connect(("127.0.0.1", self.port)), self.stop)
            self.wait()
            self.io_loop.add_future(self.server_accepted, self.stop)
            self.wait()
        except Exception as e:
            print(e)
            raise

    def tearDown(self):
        # Either stream may be None: start_tls takes ownership (see
        # client_start_tls/server_start_tls below).
        if self.server_stream is not None:
            self.server_stream.close()
        if self.client_stream is not None:
            self.client_stream.close()
        self.listener.close()
        super(TestIOStreamStartTLS, self).tearDown()

    def accept(self, connection, address):
        # Each test drives exactly one connection.
        if self.server_stream is not None:
            self.fail("should only get one connection")
        self.server_stream = IOStream(connection)
        self.server_accepted.set_result(None)

    @gen.coroutine
    def client_send_line(self, line):
        """Write *line* from the client and assert the server reads it intact."""
        self.client_stream.write(line)
        recv_line = yield self.server_stream.read_until(b"\r\n")
        self.assertEqual(line, recv_line)

    @gen.coroutine
    def server_send_line(self, line):
        """Write *line* from the server and assert the client reads it intact."""
        self.server_stream.write(line)
        recv_line = yield self.client_stream.read_until(b"\r\n")
        self.assertEqual(line, recv_line)

    def client_start_tls(self, ssl_options=None, server_hostname=None):
        # Hand the plaintext stream to start_tls and drop our reference so
        # tearDown doesn't close it out from under the handshake.
        client_stream = self.client_stream
        self.client_stream = None
        return client_stream.start_tls(False, ssl_options, server_hostname)

    def server_start_tls(self, ssl_options=None):
        # Server-side mirror of client_start_tls.
        server_stream = self.server_stream
        self.server_stream = None
        return server_stream.start_tls(True, ssl_options)

    @gen_test
    def test_start_tls_smtp(self):
        # This flow is simplified from RFC 3207 section 5.
        # We don't really need all of this, but it helps to make sure
        # that after realistic back-and-forth traffic the buffers end up
        # in a sane state.
        yield self.server_send_line(b"220 mail.example.com ready\r\n")
        yield self.client_send_line(b"EHLO mail.example.com\r\n")
        yield self.server_send_line(b"250-mail.example.com welcome\r\n")
        yield self.server_send_line(b"250 STARTTLS\r\n")
        yield self.client_send_line(b"STARTTLS\r\n")
        yield self.server_send_line(b"220 Go ahead\r\n")
        client_future = self.client_start_tls()
        server_future = self.server_start_tls(_server_ssl_options())
        self.client_stream = yield client_future
        self.server_stream = yield server_future
        self.assertTrue(isinstance(self.client_stream, SSLIOStream))
        self.assertTrue(isinstance(self.server_stream, SSLIOStream))
        yield self.client_send_line(b"EHLO mail.example.com\r\n")
        yield self.server_send_line(b"250 mail.example.com welcome\r\n")

    @gen_test
    def test_handshake_fail(self):
        # The client demands a CA-verified certificate; the test server's
        # self-signed cert cannot satisfy it, so both futures must fail.
        server_future = self.server_start_tls(_server_ssl_options())
        client_future = self.client_start_tls(dict(cert_reqs=ssl.CERT_REQUIRED, ca_certs=certifi.where()))
        with ExpectLog(gen_log, "SSL Error"):
            with self.assertRaises(ssl.SSLError):
                yield client_future
        with self.assertRaises((ssl.SSLError, socket.error)):
            yield server_future

    @unittest.skipIf(not hasattr(ssl, "create_default_context"), "ssl.create_default_context not present")
    @gen_test
    def test_check_hostname(self):
        # Test that server_hostname parameter to start_tls is being used.
        # The check_hostname functionality is only available in python 2.7 and
        # up and in python 3.4 and up.
        server_future = self.server_start_tls(_server_ssl_options())
        client_future = self.client_start_tls(ssl.create_default_context(), server_hostname=b"127.0.0.1")
        with ExpectLog(gen_log, "SSL Error"):
            with self.assertRaises(ssl.SSLError):
                yield client_future
        with self.assertRaises((ssl.SSLError, socket.error)):
            yield server_future
Ejemplo n.º 35
0
class _HTTPConnection(object):
    _SUPPORTED_METHODS = set(["GET", "HEAD", "POST", "PUT", "DELETE"])

    def __init__(self, io_loop, client, request, callback):
        self.start_time = time.time()
        self.io_loop = io_loop
        self.client = client
        self.request = request
        self.callback = callback
        self.code = None
        self.headers = None
        self.chunks = None
        self._decompressor = None
        # Timeout handle returned by IOLoop.add_timeout
        self._timeout = None
        with stack_context.StackContext(self.cleanup):
            parsed = urlparse.urlsplit(_unicode(self.request.url))
            if ssl is None and parsed.scheme == "https":
                raise ValueError("HTTPS requires either python2.6+ or "
                                 "curl_httpclient")
            if parsed.scheme not in ("http", "https"):
                raise ValueError("Unsupported url scheme: %s" %
                                 self.request.url)
            # urlsplit results have hostname and port results, but they
            # didn't support ipv6 literals until python 2.7.
            netloc = parsed.netloc
            if "@" in netloc:
                userpass, _, netloc = netloc.rpartition("@")
            match = re.match(r'^(.+):(\d+)$', netloc)
            if match:
                host = match.group(1)
                port = int(match.group(2))
            else:
                host = netloc
                port = 443 if parsed.scheme == "https" else 80
            if re.match(r'^\[.*\]$', host):
                # raw ipv6 addresses in urls are enclosed in brackets
                host = host[1:-1]
            if self.client.hostname_mapping is not None:
                host = self.client.hostname_mapping.get(host, host)

            if request.allow_ipv6:
                af = socket.AF_UNSPEC
            else:
                # We only try the first IP we get from getaddrinfo,
                # so restrict to ipv4 by default.
                af = socket.AF_INET

            addrinfo = socket.getaddrinfo(host, port, af, socket.SOCK_STREAM,
                                          0, 0)
            af, socktype, proto, canonname, sockaddr = addrinfo[0]

            if parsed.scheme == "https":
                ssl_options = {}
                if request.validate_cert:
                    ssl_options["cert_reqs"] = ssl.CERT_REQUIRED
                if request.ca_certs is not None:
                    ssl_options["ca_certs"] = request.ca_certs
                else:
                    ssl_options["ca_certs"] = _DEFAULT_CA_CERTS
                if request.client_key is not None:
                    ssl_options["keyfile"] = request.client_key
                if request.client_cert is not None:
                    ssl_options["certfile"] = request.client_cert
                self.stream = SSLIOStream(socket.socket(af, socktype, proto),
                                          io_loop=self.io_loop,
                                          ssl_options=ssl_options)
            else:
                self.stream = IOStream(socket.socket(af, socktype, proto),
                                       io_loop=self.io_loop)
            timeout = min(request.connect_timeout, request.request_timeout)
            if timeout:
                self._connect_timeout = self.io_loop.add_timeout(
                    self.start_time + timeout,
                    self._on_timeout)
            self.stream.set_close_callback(self._on_close)
            self.stream.connect(sockaddr,
                                functools.partial(self._on_connect, parsed))

    def _on_timeout(self):
        self._timeout = None
        if self.callback is not None:
            self.callback(HTTPResponse(self.request, 599,
                                       headers=self.headers,
                                       error=HTTPError(599, "Timeout")))
            self.callback = None
        self.stream.close()

    def _on_connect(self, parsed):
        if self._timeout is not None:
            self.io_loop.remove_callback(self._timeout)
            self._timeout = None
        if self.request.request_timeout:
            self._timeout = self.io_loop.add_timeout(
                self.start_time + self.request.request_timeout,
                self._on_timeout)
        if (self.request.validate_cert and
            isinstance(self.stream, SSLIOStream)):
            match_hostname(self.stream.socket.getpeercert(),
                           parsed.hostname)
        if (self.request.method not in self._SUPPORTED_METHODS and
            not self.request.allow_nonstandard_methods):
            raise KeyError("unknown method %s" % self.request.method)
        for key in ('network_interface',
                    'proxy_host', 'proxy_port',
                    'proxy_username', 'proxy_password'):
            if getattr(self.request, key, None):
                raise NotImplementedError('%s not supported' % key)
        if "Host" not in self.request.headers:
            self.request.headers["Host"] = parsed.netloc
        username, password = None, None
        if parsed.username is not None:
            username, password = parsed.username, parsed.password
        elif self.request.auth_username is not None:
            username = self.request.auth_username
            password = self.request.auth_password
        if username is not None:
            auth = utf8(username) + b(":") + utf8(password)
            self.request.headers["Authorization"] = (b("Basic ") +
                                                     base64.b64encode(auth))
        if self.request.user_agent:
            self.request.headers["User-Agent"] = self.request.user_agent
        has_body = self.request.method in ("POST", "PUT")
        if has_body:
            assert self.request.body is not None
            self.request.headers["Content-Length"] = str(len(
                self.request.body))
        else:
            assert self.request.body is None
        if (self.request.method == "POST" and
            "Content-Type" not in self.request.headers):
            self.request.headers["Content-Type"] = "application/x-www-form-urlencoded"
        if self.request.use_gzip:
            self.request.headers["Accept-Encoding"] = "gzip"
        req_path = ((parsed.path or '/') +
                (('?' + parsed.query) if parsed.query else ''))
        request_lines = [utf8("%s %s HTTP/1.1" % (self.request.method,
                                                  req_path))]
        for k, v in self.request.headers.get_all():
            line = utf8(k) + b(": ") + utf8(v)
            if b('\n') in line:
                raise ValueError('Newline in header: ' + repr(line))
            request_lines.append(line)
        self.stream.write(b("\r\n").join(request_lines) + b("\r\n\r\n"))
        if has_body:
            self.stream.write(self.request.body)
        self.stream.read_until(b("\r\n\r\n"), self._on_headers)

    @contextlib.contextmanager
    def cleanup(self):
        try:
            yield
        except Exception, e:
            logging.warning("uncaught exception", exc_info=True)
            if self.callback is not None:
                callback = self.callback
                self.callback = None
                callback(HTTPResponse(self.request, 599, error=e, headers=self.headers))
Ejemplo n.º 36
0
class SMTPClient(object):
    def __init__(self, host, port, hostname=None,
                 timeout=socket._GLOBAL_DEFAULT_TIMEOUT):
        self.host = host
        self.port = port
        self.hostname = hostname
        self.timeout = timeout
        self.connect(self.host, self.port)
        self.connected = False


    def _get_socket(self, host, port, timeout):
        return socket.create_connection((host, port), timeout)


    def _get_hostname(self):
        fqdn = socket.getfqdn()
        if "." in fqdn: return fqdn

        return socket.gethostbyname(socket.gethostname())


    def _on_connect(self, data):
        code, msg = SMTPClient.fmt_reply(data)
        print "on_connect: ", code, msg

        if code != 220: raise SMTPClientConnectError(code, msg)
        if not self.hostname: self.hostname = self._get_hostname()
        # self.cmd_helo()


    def _on_cmd_helo(self, data):
        code, msg = SMTPClient.fmt_reply(data)
        print "helo resp:", code, msg
        self.close()


    def _on_cmd_ehlo(self, data):
        # TODO implement.
        code, msg = SMTPClient.fmt_reply(data)
        print "helo resp:", code, msg
        self.close()


    def _on_close(self, data):
        print "close:", data
        self.stream.close()
        self.sock.close()


    def auth(self, user, passwd):
        pass


    def connect(self, host=None, port=None):
        host = host or "localhost"
        port = port or SMTP_PORT
        self.sock = self._get_socket(host, port, self.timeout)
        self.stream = IOStream(self.sock)
        # TODO multiline
        self.stream.read_until("\n", self._on_connect)


    def cmd_helo(self, hostname=""):
        self.execute("helo", self._on_cmd_helo, hostname or self.hostname)


    def cmd_ehlo(self, hostname=""):
        self.execute("ehlo", self._on_cmd_ehlo, hostname or self.hostname)


    def execute(self, cmd, callback, params="", end=CRLF):
        if not self.connected: raise SMTPClientException("Unconnected.")
        cmdline = "%s %s" % (cmd, params)
        cmdline = "%s%s" % (cmdline.strip(), end)
        print "cmdline:", repr(cmdline)
        self.stream.write(cmdline)
        self.stream.read_until(end, callback)


    def send_mail(self, frm, to_addrs, msg,
                  options=None, host=None, port=None):
        options = options or []
        yield self.cmd_helo(self.hostname)
        # FIXME.


    @classmethod
    def fmt_reply(cls, line):
        code, msg = line[:3], line[4:].strip()
        try:
            code = int(code)
        except:
            code = -1
        return code, msg


    def close(self):
        self.execute("quit", self._on_close)
Ejemplo n.º 37
0
class Script(object):
    """IRC bot: connects to a server, dispatches incoming lines to
    ``handle_<command>`` methods, runs ``!command`` chat commands from the
    ``commands`` package, and logs channel chat to per-day files.
    """

    def __init__(self, nick=None, user=None, logdir=None):
        self.config = config
        self.nickname = nick or config.BOT_NICKNAME
        self.username = user or config.BOT_USERNAME
        self.logdir = logdir or config.LOG_DIRECTORY
        self.stream = None  # IOStream; created in start()
        self.ready = not config.WAIT_FOR_PING
        self.onready = []  # callables deferred until the server's first PING
        self.channels = []  # channels we have successfully joined
        self.lookup_callbacks = {}  # nickname -> callback awaiting USERHOST reply
        self.host = '%s!%s@localhost' % (self.nickname, self.username)
        self.db = sqlite3.connect(config.DATABASE_FILE)
        self.db.row_factory = sqlite3.Row

    def start(self, host=None, port=None):
        """Connect, register NICK/USER, and begin the read loop."""
        host = host or config.SERVER_HOST
        port = port or config.SERVER_PORT
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        s.connect((host, port))
        self.stream = IOStream(s)
        self.write('NICK', self.nickname)
        self.write('USER', self.username, config.BOT_MODES, self.username,
                   config.BOT_REALNAME)
        self.read_line(self.parse_line)

    def read_line(self, callback):
        # IRC lines are CRLF-terminated.
        self.stream.read_until(b'\r\n', callback)

    def parse_line(self, data):
        """Decode and parse one raw IRC line, then dispatch and re-arm the read."""
        # IRC doesn't really define an encoding. Try UTF-8 first, fall back to latin-1.
        try:
            data = data.decode('utf-8')
        except UnicodeDecodeError:
            # Narrowed from a bare except: only decode failures are expected,
            # and latin-1 can decode any byte sequence.
            data = data.decode('latin-1')
        parts = line_regex.match(data.strip()).groupdict()
        cmd = parts['command'].lower()
        args = []
        if parts['argument']:
            if ':' in parts['argument']:
                # Everything after ':' is a single trailing parameter.
                arg, text = parts['argument'].split(':', 1)
                args.extend(arg.strip().split())
                args.append(text.strip())
            else:
                args.extend(parts['argument'].split())
        func = getattr(self, 'handle_%s' % cmd, None)
        if func and callable(func):
            try:
                func(parts['prefix'], *args)
            except Exception as ex:
                print('Error handling "%s": %s' % (cmd, str(ex)))
        self.read_line(self.parse_line)

    def write(self, cmd, *args):
        """Send one IRC command; the last argument gets a ':' prefix if it
        contains spaces (IRC trailing-parameter syntax)."""
        if not isinstance(cmd, bytes):
            cmd = cmd.encode('ascii')
        line = cmd.upper()
        if args:
            byte_args = []
            for a in args:
                if isinstance(a, str):
                    byte_args.append(a.encode('utf-8'))
                elif isinstance(a, bytes):
                    byte_args.append(a)
                else:
                    raise Exception('Arguments to write must be str or bytes.')
            if b' ' in byte_args[-1]:
                byte_args[-1] = b':' + byte_args[-1]
            line += b' ' + b' '.join(byte_args)
        self.stream.write(line + b'\r\n')

    def chat(self, line, destination=None):
        """Say *line* to one destination, a list of them, or all joined channels."""
        destination = destination or self.channels
        if not destination:
            return
        if not isinstance(destination, (list, tuple)):
            destination = (destination, )
        # Since we don't receive our own chat, log it manually.
        # NOTE(review): self.user is only assigned in handle_join — chat()
        # before the bot has joined would raise AttributeError; confirm flow.
        for dest in destination:
            self.log_chat(self.user, dest, line)
            self.write('PRIVMSG', dest, line)

    def join(self, channel):
        """Join *channel* now, or defer it until the server is ready."""
        if not channel.startswith(config.CHANNEL_PREFIX):
            return
        if self.ready:
            self.write('JOIN', channel)
        else:
            self.onready.append(lambda bot: bot.write('JOIN', channel))

    def lookup(self, nickname, callback):
        """Ask the server for *nickname*'s userhost; *callback* fires in handle_302."""
        self.lookup_callbacks[nickname] = callback
        self.write('USERHOST', nickname)

    def get_user(self, host):
        """Return the user row registered for *host*, or None."""
        host = host.lower().strip()
        cur = self.db.cursor()
        try:
            cur.execute(
                """
				select nickname, password, name, email, sms_email, can_op, auto_op
				from users u inner join hosts h on h.user_id = u.id where h.host = ?
			""", (host, ))
            return cur.fetchone()
        finally:
            cur.close()

    def log_chat(self, user, destination, line):
        """Append a chat line to <logdir>/<channel>/<YYYY-MM-DD>.txt."""
        if not destination.startswith(
                config.CHANNEL_PREFIX
        ) or not self.logdir or not config.ENABLE_LOGGING:
            return
        if not os.path.exists(self.logdir):
            os.mkdir(self.logdir)
        chandir = os.path.join(self.logdir, destination[1:])
        if not os.path.exists(chandir):
            os.mkdir(chandir)
        now = datetime.datetime.now()
        filename = now.strftime('%Y-%m-%d') + '.txt'
        time = now.strftime('%H:%M:%S')
        logfile = os.path.join(chandir, filename)
        with open(logfile, 'ab') as f:
            log_line = '%s\t%s\t%s\n' % (time, user, line)
            f.write(log_line.encode('utf-8'))

    def handle_ping(self, prefix, *args):
        self.write('PONG', *args)
        # I don't know if this is universally accepted, but our server doesn't
        # let you do anything until you respond to the initial PING.
        if not self.ready:
            self.ready = True
            for func in self.onready:
                func(self)

    def handle_join(self, user, channel):
        """Track our own joins; auto-op registered users on theirs."""
        nick = user.split('!')[0]
        if nick == self.nickname:
            self.user = user
            if channel.startswith(
                    config.CHANNEL_PREFIX) and channel not in self.channels:
                self.channels.append(channel)
        else:
            u = self.get_user(user)
            if u and u['auto_op']:
                self.write('MODE', channel, '+o', nick)

    def handle_part(self, user, channel, msg=None):
        pass

    def handle_302(self, server, mynick, userhost):
        """USERHOST reply: "<nick>[*]=<+|-><host>"; '*' marks an IRC operator,
        '-' marks away.  Fires and removes any pending lookup callback."""
        nick, host = userhost.split('=', 1)
        is_op = False
        if nick.endswith('*'):
            is_op = True
            nick = nick[:-1]
        is_away = host.startswith('-')
        host = host[1:]
        if nick in self.lookup_callbacks:
            self.lookup_callbacks[nick](self, nick, host, is_op, is_away)
            del self.lookup_callbacks[nick]

    def handle_privmsg(self, user, destination, line):
        """Log channel chat and run "!command arg" chat commands."""
        if destination.startswith(config.CHANNEL_PREFIX):
            self.log_chat(user, destination, line)
        if line.startswith('!'):
            parts = line.split(None, 1)
            cmd = parts[0].strip()[1:]
            arg = None if len(parts) < 2 else parts[1].strip()
            self.run_command(user, destination, cmd, arg)

    def run_command(self, user, destination, cmd, arg):
        """Import commands.<cmd> fresh (dropping any cached module) and run it."""
        try:
            mod_name = 'commands.%s' % cmd
            if mod_name in sys.modules:
                del sys.modules[mod_name]
            mod = __import__(mod_name, [], [], 'commands')
            mod.handle(self, user, destination, arg)
        except Exception as ex:
            print('Error running command "%s": %s' % (cmd, str(ex)))
Ejemplo n.º 38
0
class KeepAliveTest(AsyncHTTPTestCase):
    """Tests various scenarios for HTTP 1.1 keep-alive support.

    These tests don't use AsyncHTTPClient because we want to control
    connection reuse and closing.
    """
    def get_app(self):
        class HelloHandler(RequestHandler):
            def get(self):
                self.finish("Hello world")

            def post(self):
                self.finish("Hello world")

        class LargeHandler(RequestHandler):
            def get(self):
                # 512KB should be bigger than the socket buffers so it will
                # be written out in chunks.
                self.write("".join(chr(i % 256) * 1024 for i in range(512)))

        class FinishOnCloseHandler(RequestHandler):
            def initialize(self, cleanup_event):
                self.cleanup_event = cleanup_event

            @gen.coroutine
            def get(self):
                self.flush()
                yield self.cleanup_event.wait()

            def on_connection_close(self):
                # This is not very realistic, but finishing the request
                # from the close callback has the right timing to mimic
                # some errors seen in the wild.
                self.finish("closed")

        self.cleanup_event = Event()
        return Application([
            ("/", HelloHandler),
            ("/large", LargeHandler),
            (
                "/finish_on_close",
                FinishOnCloseHandler,
                dict(cleanup_event=self.cleanup_event),
            ),
        ])

    def setUp(self):
        super(KeepAliveTest, self).setUp()
        self.http_version = b"HTTP/1.1"

    def tearDown(self):
        # We just closed the client side of the socket; let the IOLoop run
        # once to make sure the server side got the message.
        self.io_loop.add_timeout(datetime.timedelta(seconds=0.001), self.stop)
        self.wait()

        if hasattr(self, "stream"):
            self.stream.close()
        super(KeepAliveTest, self).tearDown()

    # The next few methods are a crude manual http client
    @gen.coroutine
    def connect(self):
        """Open a raw stream to the test server."""
        self.stream = IOStream(socket.socket())
        # BUG FIX: connect to loopback.  The AsyncHTTPTestCase server is
        # bound locally; dialing "10.0.0.7" could never reach it.
        yield self.stream.connect(("127.0.0.1", self.get_http_port()))

    @gen.coroutine
    def read_headers(self):
        """Read and parse the status line and headers of a 200 response."""
        first_line = yield self.stream.read_until(b"\r\n")
        self.assertTrue(first_line.startswith(b"HTTP/1.1 200"), first_line)
        header_bytes = yield self.stream.read_until(b"\r\n\r\n")
        headers = HTTPHeaders.parse(header_bytes.decode("latin1"))
        raise gen.Return(headers)

    @gen.coroutine
    def read_response(self):
        """Read headers + Content-Length body and check the canned payload."""
        self.headers = yield self.read_headers()
        body = yield self.stream.read_bytes(int(
            self.headers["Content-Length"]))
        self.assertEqual(b"Hello world", body)

    def close(self):
        self.stream.close()
        del self.stream

    @gen_test
    def test_two_requests(self):
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.1\r\n\r\n")
        yield self.read_response()
        self.stream.write(b"GET / HTTP/1.1\r\n\r\n")
        yield self.read_response()
        self.close()

    @gen_test
    def test_request_close(self):
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.1\r\nConnection: close\r\n\r\n")
        yield self.read_response()
        data = yield self.stream.read_until_close()
        self.assertTrue(not data)
        self.assertEqual(self.headers["Connection"], "close")
        self.close()

    # keepalive is supported for http 1.0 too, but it's opt-in
    @gen_test
    def test_http10(self):
        self.http_version = b"HTTP/1.0"
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.0\r\n\r\n")
        yield self.read_response()
        data = yield self.stream.read_until_close()
        self.assertTrue(not data)
        self.assertTrue("Connection" not in self.headers)
        self.close()

    @gen_test
    def test_http10_keepalive(self):
        self.http_version = b"HTTP/1.0"
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.stream.write(b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.close()

    @gen_test
    def test_http10_keepalive_extra_crlf(self):
        self.http_version = b"HTTP/1.0"
        yield self.connect()
        self.stream.write(
            b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.stream.write(b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.close()

    @gen_test
    def test_pipelined_requests(self):
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\n")
        yield self.read_response()
        yield self.read_response()
        self.close()

    @gen_test
    def test_pipelined_cancel(self):
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\n")
        # only read once
        yield self.read_response()
        self.close()

    @gen_test
    def test_cancel_during_download(self):
        yield self.connect()
        self.stream.write(b"GET /large HTTP/1.1\r\n\r\n")
        yield self.read_headers()
        yield self.stream.read_bytes(1024)
        self.close()

    @gen_test
    def test_finish_while_closed(self):
        yield self.connect()
        self.stream.write(b"GET /finish_on_close HTTP/1.1\r\n\r\n")
        yield self.read_headers()
        self.close()
        # Let the hanging coroutine clean up after itself
        self.cleanup_event.set()

    @gen_test
    def test_keepalive_chunked(self):
        self.http_version = b"HTTP/1.0"
        yield self.connect()
        self.stream.write(b"POST / HTTP/1.0\r\n"
                          b"Connection: keep-alive\r\n"
                          b"Transfer-Encoding: chunked\r\n"
                          b"\r\n"
                          b"0\r\n"
                          b"\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.stream.write(b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.close()
Ejemplo n.º 39
0
 def accept_callback(conn, address):
     # Wrap the accepted connection and hand the full request header
     # block off to write_response once it has been read.
     server_stream = IOStream(conn)
     server_stream.read_until(
         b"\r\n\r\n", functools.partial(write_response, server_stream))
Ejemplo n.º 40
0
class _HTTPConnection(object):
    """One HTTP request/response cycle for the non-blocking HTTP client.

    Resolves the request's host, opens an (optionally SSL) IOStream,
    writes the request, parses the response (including chunked and
    gzipped bodies and redirects), and reports the result to
    ``final_callback`` as an `HTTPResponse`.
    """

    _SUPPORTED_METHODS = set(["GET", "HEAD", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"])

    def __init__(self, io_loop, client, request, release_callback, final_callback, max_buffer_size):
        self.start_time = io_loop.time()
        self.io_loop = io_loop
        self.client = client
        self.request = request
        self.release_callback = release_callback
        self.final_callback = final_callback
        self.max_buffer_size = max_buffer_size
        self.code = None
        self.headers = None
        self.chunks = None
        self._decompressor = None
        # Timeout handle returned by IOLoop.add_timeout
        self._timeout = None
        with stack_context.ExceptionStackContext(self._handle_exception):
            self.parsed = urlparse.urlsplit(_unicode(self.request.url))
            if self.parsed.scheme not in ("http", "https"):
                raise ValueError("Unsupported url scheme: %s" % self.request.url)
            # urlsplit results have hostname and port results, but they
            # didn't support ipv6 literals until python 2.7.
            netloc = self.parsed.netloc
            if "@" in netloc:
                userpass, _, netloc = netloc.rpartition("@")
            match = re.match(r"^(.+):(\d+)$", netloc)
            if match:
                host = match.group(1)
                port = int(match.group(2))
            else:
                host = netloc
                port = 443 if self.parsed.scheme == "https" else 80
            if re.match(r"^\[.*\]$", host):
                # raw ipv6 addresses in urls are enclosed in brackets
                host = host[1:-1]
            self.parsed_hostname = host  # save final host for _on_connect
            if self.client.hostname_mapping is not None:
                host = self.client.hostname_mapping.get(host, host)

            if request.allow_ipv6:
                af = socket.AF_UNSPEC
            else:
                # We only try the first IP we get from getaddrinfo,
                # so restrict to ipv4 by default.
                af = socket.AF_INET

            self.client.resolver.getaddrinfo(host, port, af, socket.SOCK_STREAM, 0, 0, callback=self._on_resolve)

    def _on_resolve(self, future):
        """Create the (SSL)IOStream and start connecting once DNS is done."""
        af, socktype, proto, canonname, sockaddr = future.result()[0]

        if self.parsed.scheme == "https":
            ssl_options = {}
            if self.request.validate_cert:
                ssl_options["cert_reqs"] = ssl.CERT_REQUIRED
            if self.request.ca_certs is not None:
                ssl_options["ca_certs"] = self.request.ca_certs
            else:
                ssl_options["ca_certs"] = _DEFAULT_CA_CERTS
            if self.request.client_key is not None:
                ssl_options["keyfile"] = self.request.client_key
            if self.request.client_cert is not None:
                ssl_options["certfile"] = self.request.client_cert

            # SSL interoperability is tricky.  We want to disable
            # SSLv2 for security reasons; it wasn't disabled by default
            # until openssl 1.0.  The best way to do this is to use
            # the SSL_OP_NO_SSLv2, but that wasn't exposed to python
            # until 3.2.  Python 2.7 adds the ciphers argument, which
            # can also be used to disable SSLv2.  As a last resort
            # on python 2.6, we set ssl_version to SSLv3.  This is
            # more narrow than we'd like since it also breaks
            # compatibility with servers configured for TLSv1 only,
            # but nearly all servers support SSLv3:
            # http://blog.ivanristic.com/2011/09/ssl-survey-protocol-support.html
            if sys.version_info >= (2, 7):
                ssl_options["ciphers"] = "DEFAULT:!SSLv2"
            else:
                # This is really only necessary for pre-1.0 versions
                # of openssl, but python 2.6 doesn't expose version
                # information.
                ssl_options["ssl_version"] = ssl.PROTOCOL_SSLv3

            self.stream = SSLIOStream(
                socket.socket(af, socktype, proto),
                io_loop=self.io_loop,
                ssl_options=ssl_options,
                max_buffer_size=self.max_buffer_size,
            )
        else:
            self.stream = IOStream(
                socket.socket(af, socktype, proto), io_loop=self.io_loop, max_buffer_size=self.max_buffer_size
            )
        timeout = min(self.request.connect_timeout, self.request.request_timeout)
        if timeout:
            self._timeout = self.io_loop.add_timeout(self.start_time + timeout, stack_context.wrap(self._on_timeout))
        self.stream.set_close_callback(self._on_close)
        # ipv6 addresses are broken (in self.parsed.hostname) until
        # 2.7, here is correctly parsed value calculated in __init__
        self.stream.connect(sockaddr, self._on_connect, server_hostname=self.parsed_hostname)

    def _on_timeout(self):
        """Raise a 599 HTTPError if the request is still in flight."""
        self._timeout = None
        if self.final_callback is not None:
            raise HTTPError(599, "Timeout")

    def _remove_timeout(self):
        """Cancel the pending connect/request timeout, if any."""
        if self._timeout is not None:
            self.io_loop.remove_timeout(self._timeout)
            self._timeout = None

    def _on_connect(self):
        """Serialize and write the request, then wait for response headers."""
        # The connect timeout no longer applies; switch to the request timeout.
        self._remove_timeout()
        if self.request.request_timeout:
            self._timeout = self.io_loop.add_timeout(
                self.start_time + self.request.request_timeout, stack_context.wrap(self._on_timeout)
            )
        if self.request.method not in self._SUPPORTED_METHODS and not self.request.allow_nonstandard_methods:
            raise KeyError("unknown method %s" % self.request.method)
        for key in ("network_interface", "proxy_host", "proxy_port", "proxy_username", "proxy_password"):
            if getattr(self.request, key, None):
                raise NotImplementedError("%s not supported" % key)
        if "Connection" not in self.request.headers:
            self.request.headers["Connection"] = "close"
        if "Host" not in self.request.headers:
            if "@" in self.parsed.netloc:
                # Strip any userinfo ("user:pass@") from the Host header.
                self.request.headers["Host"] = self.parsed.netloc.rpartition("@")[-1]
            else:
                self.request.headers["Host"] = self.parsed.netloc
        username, password = None, None
        if self.parsed.username is not None:
            username, password = self.parsed.username, self.parsed.password
        elif self.request.auth_username is not None:
            username = self.request.auth_username
            password = self.request.auth_password or ""
        if username is not None:
            auth = utf8(username) + b":" + utf8(password)
            self.request.headers["Authorization"] = b"Basic " + base64.b64encode(auth)
        if self.request.user_agent:
            self.request.headers["User-Agent"] = self.request.user_agent
        if not self.request.allow_nonstandard_methods:
            # Bodies are required for POST/PATCH/PUT and forbidden otherwise.
            if self.request.method in ("POST", "PATCH", "PUT"):
                assert self.request.body is not None
            else:
                assert self.request.body is None
        if self.request.body is not None:
            self.request.headers["Content-Length"] = str(len(self.request.body))
        if self.request.method == "POST" and "Content-Type" not in self.request.headers:
            self.request.headers["Content-Type"] = "application/x-www-form-urlencoded"
        if self.request.use_gzip:
            self.request.headers["Accept-Encoding"] = "gzip"
        req_path = (self.parsed.path or "/") + (("?" + self.parsed.query) if self.parsed.query else "")
        request_lines = [utf8("%s %s HTTP/1.1" % (self.request.method, req_path))]
        for k, v in self.request.headers.get_all():
            line = utf8(k) + b": " + utf8(v)
            if b"\n" in line:
                # A newline in a header value would allow request smuggling.
                raise ValueError("Newline in header: " + repr(line))
            request_lines.append(line)
        self.stream.write(b"\r\n".join(request_lines) + b"\r\n\r\n")
        if self.request.body is not None:
            self.stream.write(self.request.body)
        self.stream.read_until_regex(b"\r?\n\r?\n", self._on_headers)

    def _release(self):
        """Return this connection to the client pool (at most once)."""
        if self.release_callback is not None:
            release_callback = self.release_callback
            self.release_callback = None
            release_callback()

    def _run_callback(self, response):
        """Deliver the final response to the caller (at most once)."""
        self._release()
        if self.final_callback is not None:
            final_callback = self.final_callback
            self.final_callback = None
            self.io_loop.add_callback(final_callback, response)

    def _handle_exception(self, typ, value, tb):
        """Convert an uncaught exception into a 599 HTTPResponse."""
        if self.final_callback:
            # Cancel any pending timeout so it can't linger on the IOLoop
            # after the request has already been reported as failed.
            self._remove_timeout()
            gen_log.warning("uncaught exception", exc_info=(typ, value, tb))
            self._run_callback(
                HTTPResponse(self.request, 599, error=value, request_time=self.io_loop.time() - self.start_time)
            )

            if hasattr(self, "stream"):
                self.stream.close()
            return True
        else:
            # If our callback has already been called, we are probably
            # catching an exception that is not caused by us but rather
            # some child of our callback. Rather than drop it on the floor,
            # pass it along.
            return False

    def _on_close(self):
        """Report an unexpected connection close as a 599 error."""
        if self.final_callback is not None:
            message = "Connection closed"
            if self.stream.error:
                message = str(self.stream.error)
            raise HTTPError(599, message)

    def _on_headers(self, data):
        """Parse the status line and headers, then start reading the body."""
        data = native_str(data.decode("latin1"))
        first_line, _, header_data = data.partition("\n")
        match = re.match("HTTP/1.[01] ([0-9]+) ([^\r]*)", first_line)
        assert match
        code = int(match.group(1))
        if 100 <= code < 200:
            # Informational response; the real response headers follow.
            self.stream.read_until_regex(b"\r?\n\r?\n", self._on_headers)
            return
        else:
            self.code = code
            self.reason = match.group(2)
        self.headers = HTTPHeaders.parse(header_data)

        if "Content-Length" in self.headers:
            if "," in self.headers["Content-Length"]:
                # Proxies sometimes cause Content-Length headers to get
                # duplicated.  If all the values are identical then we can
                # use them but if they differ it's an error.
                pieces = re.split(r",\s*", self.headers["Content-Length"])
                if any(i != pieces[0] for i in pieces):
                    raise ValueError("Multiple unequal Content-Lengths: %r" % self.headers["Content-Length"])
                self.headers["Content-Length"] = pieces[0]
            content_length = int(self.headers["Content-Length"])
        else:
            content_length = None

        if self.request.header_callback is not None:
            # re-attach the newline we split on earlier
            self.request.header_callback(first_line + _)
            for k, v in self.headers.get_all():
                self.request.header_callback("%s: %s\r\n" % (k, v))
            self.request.header_callback("\r\n")

        if self.request.method == "HEAD" or self.code == 304:
            # HEAD requests and 304 responses never have content, even
            # though they may have content-length headers
            self._on_body(b"")
            return
        if 100 <= self.code < 200 or self.code == 204:
            # These response codes never have bodies
            # http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.3
            if "Transfer-Encoding" in self.headers or content_length not in (None, 0):
                raise ValueError("Response with code %d should not have body" % self.code)
            self._on_body(b"")
            return

        if self.request.use_gzip and self.headers.get("Content-Encoding") == "gzip":
            self._decompressor = GzipDecompressor()
        if self.headers.get("Transfer-Encoding") == "chunked":
            self.chunks = []
            self.stream.read_until(b"\r\n", self._on_chunk_length)
        elif content_length is not None:
            self.stream.read_bytes(content_length, self._on_body)
        else:
            # No framing information: read until the server closes.
            self.stream.read_until_close(self._on_body)

    def _on_body(self, data):
        """Finish the request: follow redirects or build the HTTPResponse."""
        self._remove_timeout()
        original_request = getattr(self.request, "original_request", self.request)
        if self.request.follow_redirects and self.request.max_redirects > 0 and self.code in (301, 302, 303, 307):
            assert isinstance(self.request, _RequestProxy)
            new_request = copy.copy(self.request.request)
            new_request.url = urlparse.urljoin(self.request.url, self.headers["Location"])
            new_request.max_redirects = self.request.max_redirects - 1
            del new_request.headers["Host"]
            # http://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.3.4
            # Client SHOULD make a GET request after a 303.
            # According to the spec, 302 should be followed by the same
            # method as the original request, but in practice browsers
            # treat 302 the same as 303, and many servers use 302 for
            # compatibility with pre-HTTP/1.1 user agents which don't
            # understand the 303 status.
            if self.code in (302, 303):
                new_request.method = "GET"
                new_request.body = None
                for h in ["Content-Length", "Content-Type", "Content-Encoding", "Transfer-Encoding"]:
                    try:
                        del self.request.headers[h]
                    except KeyError:
                        pass
            new_request.original_request = original_request
            final_callback = self.final_callback
            self.final_callback = None
            self._release()
            self.client.fetch(new_request, final_callback)
            self.stream.close()
            return
        if self._decompressor:
            data = self._decompressor.decompress(data) + self._decompressor.flush()
        if self.request.streaming_callback:
            if self.chunks is None:
                # if chunks is not None, we already called streaming_callback
                # in _on_chunk_data
                self.request.streaming_callback(data)
            buffer = BytesIO()
        else:
            buffer = BytesIO(data)  # TODO: don't require one big string?
        response = HTTPResponse(
            original_request,
            self.code,
            reason=self.reason,
            headers=self.headers,
            request_time=self.io_loop.time() - self.start_time,
            buffer=buffer,
            effective_url=self.request.url,
        )
        self._run_callback(response)
        self.stream.close()

    def _on_chunk_length(self, data):
        """Handle one chunk-size line of a chunked response body."""
        # TODO: "chunk extensions" http://tools.ietf.org/html/rfc2616#section-3.6.1
        length = int(data.strip(), 16)
        if length == 0:
            # Final (zero-length) chunk: flush the decompressor and finish.
            if self._decompressor is not None:
                tail = self._decompressor.flush()
                if tail:
                    # I believe the tail will always be empty (i.e.
                    # decompress will return all it can).  The purpose
                    # of the flush call is to detect errors such
                    # as truncated input.  But in case it ever returns
                    # anything, treat it as an extra chunk
                    if self.request.streaming_callback is not None:
                        self.request.streaming_callback(tail)
                    else:
                        self.chunks.append(tail)
                # all the data has been decompressed, so we don't need to
                # decompress again in _on_body
                self._decompressor = None
            self._on_body(b"".join(self.chunks))
        else:
            self.stream.read_bytes(length + 2, self._on_chunk_data)  # chunk ends with \r\n
# Ejemplo n.º 41
# 0
class AsyncRedisClient(object):
    """A non-blocking Redis client.

    Example usage::

        import ioloop

        def handle_request(result):
            print 'Redis reply: %r' % result
            ioloop.IOLoop.instance().stop()

        redis_client = AsyncRedisClient(('127.0.0.1', 6379))
        redis_client.fetch(('set', 'foo', 'bar'), None)
        redis_client.fetch(('get', 'foo'), handle_request)
        ioloop.IOLoop.instance().start()

    This class implements a Redis client on top of Tornado's IOStreams.
    It does not currently implement all applicable parts of the Redis
    specification, but it does enough to work with major redis server APIs
    (mostly tested against the LIST/HASH/PUBSUB API so far).

    This class has not been tested extensively in production and
    should be considered somewhat experimental as of the release of
    tornado 1.2.  It is intended to become the default tornado
    AsyncRedisClient implementation.
    """

    def __init__(self, address, io_loop=None):
        """Creates an AsyncRedisClient.

        address is the tuple of redis server address that can be connected
        to by IOStream, e.g. ('127.0.0.1', 6379).
        """
        # Queue of user callbacks, one per request sent via fetch().
        self.address         = address
        self.io_loop         = io_loop or IOLoop.instance()
        self._callback_queue = deque()
        self._callback       = None
        # Accumulates the raw lines/bulks of the reply currently being read.
        self._read_buffer    = None
        self._result_queue   = deque()
        self.socket          = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.stream          = IOStream(self.socket, self.io_loop)
        # Start the read loop as soon as the connection is established.
        self.stream.connect(self.address, self._wait_result)

    def close(self):
        """Destroys this redis client, freeing any file descriptors used.
        Not needed in normal use, but may be helpful in unittests that
        create and destroy redis clients.  No other methods may be called
        on the AsyncRedisClient after close().
        """
        self.stream.close()

    def fetch(self, request, callback):
        """Executes a request, calling callback with a redis `result`.

        The request should be a string tuple, like ('set', 'foo', 'bar').

        If an error occurs during the fetch, a `RedisError` exception will
        be thrown. You can use try...except to catch the exception (if any)
        in the callback.
        """
        self._callback_queue.append(callback)
        self.stream.write(encode(request))

    def _wait_result(self):
        """Read a completed result data from the redis server."""
        # Fresh buffer for the next reply; the first line determines the
        # reply type (see _on_read_first_line).
        self._read_buffer = deque()
        self.stream.read_until('\r\n', self._on_read_first_line)

    def _maybe_callback(self):
        """Try call callback in _callback_queue when we read a redis result."""
        # NOTE(review): _result_queue is only ever appended to here, and only
        # when it is already non-empty — so it appears to remain empty for
        # the lifetime of the client. Confirm whether this is dead code or a
        # latent bug before relying on it.
        try:
            read_buffer    = self._read_buffer
            callback       = self._callback
            result_queue   = self._result_queue
            callback_queue = self._callback_queue
            if result_queue:
                result_queue.append(read_buffer)
                read_buffer = result_queue.popleft()
            if callback_queue:
                # Each reply is matched FIFO with the callback registered
                # by the corresponding fetch() call.
                callback = self._callback = callback_queue.popleft()
            if callback:
                callback(decode(read_buffer))
        except Exception:
            logging.error('Uncaught callback exception', exc_info=True)
            self.close()
            raise
        finally:
            # Always re-arm the read loop for the next reply.
            self._wait_result()

    def _on_read_first_line(self, data):
        # Dispatch on the reply-type marker in the first byte of the reply.
        self._read_buffer.append(data)
        c = data[0]
        if c in ':+-':
            # Integer, status, or error reply: single line, complete.
            self._maybe_callback()
        elif c == '$':
            # Bulk reply: '$-1' means nil, otherwise '$<len>' is followed
            # by <len> bytes of payload plus the trailing CRLF.
            if data[:3] == '$-1':
                self._maybe_callback()
            else:
                length = int(data[1:])
                self.stream.read_bytes(length+2, self._on_read_bulk_body)
        elif c == '*':
            # Multi-bulk reply: '*-1' (nil) and '*0' (empty) are complete;
            # otherwise '*<count>' elements follow, each a bulk reply.
            if data[1] in '-0' :
                self._maybe_callback()
            else:
                self._multibulk_number = int(data[1:])
                self.stream.read_until('\r\n', self._on_read_multibulk_bulk_head)

    def _on_read_bulk_body(self, data):
        # Payload (plus CRLF) of a single bulk reply; reply is complete.
        self._read_buffer.append(data)
        self._maybe_callback()

    def _on_read_multibulk_bulk_head(self, data):
        # Header line of one element inside a multi-bulk reply.
        self._read_buffer.append(data)
        c = data[0]
        if c == '$':
            length = int(data[1:])
            self.stream.read_bytes(length+2, self._on_read_multibulk_bulk_body)
        else:
            # Non-bulk element (e.g. integer); nothing more to read for it.
            self._maybe_callback()

    def _on_read_multibulk_bulk_body(self, data):
        # Payload of one multi-bulk element; keep reading until all
        # _multibulk_number elements have been consumed.
        self._read_buffer.append(data)
        self._multibulk_number -= 1
        if self._multibulk_number:
            self.stream.read_until('\r\n', self._on_read_multibulk_bulk_head)
        else:
            self._maybe_callback()
class _HTTPConnection(object):
    _SUPPORTED_METHODS = set(
        ["GET", "HEAD", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"])

    def __init__(self, io_loop, client, request, release_callback,
                 final_callback, max_buffer_size, resolver):
        self.start_time = io_loop.time()
        self.io_loop = io_loop
        self.client = client
        self.request = request
        self.release_callback = release_callback
        self.final_callback = final_callback
        self.max_buffer_size = max_buffer_size
        self.resolver = resolver
        self.code = None
        self.headers = None
        self.chunks = None
        self._decompressor = None
        # Timeout handle returned by IOLoop.add_timeout
        self._timeout = None
        with stack_context.ExceptionStackContext(self._handle_exception):
            self.parsed = urlparse.urlsplit(_unicode(self.request.url))
            if self.parsed.scheme not in ("http", "https"):
                raise ValueError("Unsupported url scheme: %s" %
                                 self.request.url)
            # urlsplit results have hostname and port results, but they
            # didn't support ipv6 literals until python 2.7.
            netloc = self.parsed.netloc
            if "@" in netloc:
                userpass, _, netloc = netloc.rpartition("@")
            match = re.match(r'^(.+):(\d+)$', netloc)
            if match:
                host = match.group(1)
                port = int(match.group(2))
            else:
                host = netloc
                port = 443 if self.parsed.scheme == "https" else 80
            if re.match(r'^\[.*\]$', host):
                # raw ipv6 addresses in urls are enclosed in brackets
                host = host[1:-1]
            self.parsed_hostname = host  # save final host for _on_connect

            if request.allow_ipv6:
                af = socket.AF_UNSPEC
            else:
                # We only try the first IP we get from getaddrinfo,
                # so restrict to ipv4 by default.
                af = socket.AF_INET

            self.resolver.resolve(host, port, af, callback=self._on_resolve)

    def _on_resolve(self, addrinfo):
        af, sockaddr = addrinfo[0]

        if self.parsed.scheme == "https":
            ssl_options = {}
            if self.request.validate_cert:
                ssl_options["cert_reqs"] = ssl.CERT_REQUIRED
            if self.request.ca_certs is not None:
                ssl_options["ca_certs"] = self.request.ca_certs
            else:
                ssl_options["ca_certs"] = _DEFAULT_CA_CERTS
            if self.request.client_key is not None:
                ssl_options["keyfile"] = self.request.client_key
            if self.request.client_cert is not None:
                ssl_options["certfile"] = self.request.client_cert

            # SSL interoperability is tricky.  We want to disable
            # SSLv2 for security reasons; it wasn't disabled by default
            # until openssl 1.0.  The best way to do this is to use
            # the SSL_OP_NO_SSLv2, but that wasn't exposed to python
            # until 3.2.  Python 2.7 adds the ciphers argument, which
            # can also be used to disable SSLv2.  As a last resort
            # on python 2.6, we set ssl_version to SSLv3.  This is
            # more narrow than we'd like since it also breaks
            # compatibility with servers configured for TLSv1 only,
            # but nearly all servers support SSLv3:
            # http://blog.ivanristic.com/2011/09/ssl-survey-protocol-support.html
            if sys.version_info >= (2, 7):
                ssl_options["ciphers"] = "DEFAULT:!SSLv2"
            else:
                # This is really only necessary for pre-1.0 versions
                # of openssl, but python 2.6 doesn't expose version
                # information.
                ssl_options["ssl_version"] = ssl.PROTOCOL_SSLv3

            self.stream = SSLIOStream(socket.socket(af),
                                      io_loop=self.io_loop,
                                      ssl_options=ssl_options,
                                      max_buffer_size=self.max_buffer_size)
        else:
            self.stream = IOStream(socket.socket(af),
                                   io_loop=self.io_loop,
                                   max_buffer_size=self.max_buffer_size)
        timeout = min(self.request.connect_timeout,
                      self.request.request_timeout)
        if timeout:
            self._timeout = self.io_loop.add_timeout(
                self.start_time + timeout,
                stack_context.wrap(self._on_timeout))
        self.stream.set_close_callback(self._on_close)
        # ipv6 addresses are broken (in self.parsed.hostname) until
        # 2.7, here is correctly parsed value calculated in __init__
        self.stream.connect(sockaddr,
                            self._on_connect,
                            server_hostname=self.parsed_hostname)

    def _on_timeout(self):
        self._timeout = None
        if self.final_callback is not None:
            raise HTTPError(599, "Timeout")

    def _remove_timeout(self):
        if self._timeout is not None:
            self.io_loop.remove_timeout(self._timeout)
            self._timeout = None

    def _on_connect(self):
        self._remove_timeout()
        if self.request.request_timeout:
            self._timeout = self.io_loop.add_timeout(
                self.start_time + self.request.request_timeout,
                stack_context.wrap(self._on_timeout))
        if (self.request.method not in self._SUPPORTED_METHODS
                and not self.request.allow_nonstandard_methods):
            raise KeyError("unknown method %s" % self.request.method)
        for key in ('network_interface', 'proxy_host', 'proxy_port',
                    'proxy_username', 'proxy_password'):
            if getattr(self.request, key, None):
                raise NotImplementedError('%s not supported' % key)
        if "Connection" not in self.request.headers:
            self.request.headers["Connection"] = "close"
        if "Host" not in self.request.headers:
            if '@' in self.parsed.netloc:
                self.request.headers["Host"] = self.parsed.netloc.rpartition(
                    '@')[-1]
            else:
                self.request.headers["Host"] = self.parsed.netloc
        username, password = None, None
        if self.parsed.username is not None:
            username, password = self.parsed.username, self.parsed.password
        elif self.request.auth_username is not None:
            username = self.request.auth_username
            password = self.request.auth_password or ''
        if username is not None:
            auth = utf8(username) + b":" + utf8(password)
            self.request.headers["Authorization"] = (b"Basic " +
                                                     base64.b64encode(auth))
        if self.request.user_agent:
            self.request.headers["User-Agent"] = self.request.user_agent
        if not self.request.allow_nonstandard_methods:
            if self.request.method in ("POST", "PATCH", "PUT"):
                assert self.request.body is not None
            else:
                assert self.request.body is None
        if self.request.body is not None:
            self.request.headers["Content-Length"] = str(len(
                self.request.body))
        if (self.request.method == "POST"
                and "Content-Type" not in self.request.headers):
            self.request.headers[
                "Content-Type"] = "application/x-www-form-urlencoded"
        if self.request.use_gzip:
            self.request.headers["Accept-Encoding"] = "gzip"
        req_path = ((self.parsed.path or '/') +
                    (('?' + self.parsed.query) if self.parsed.query else ''))
        request_lines = [
            utf8("%s %s HTTP/1.1" % (self.request.method, req_path))
        ]
        for k, v in self.request.headers.get_all():
            line = utf8(k) + b": " + utf8(v)
            if b'\n' in line:
                raise ValueError('Newline in header: ' + repr(line))
            request_lines.append(line)
        self.stream.write(b"\r\n".join(request_lines) + b"\r\n\r\n")
        if self.request.body is not None:
            self.stream.write(self.request.body)
        self.stream.read_until_regex(b"\r?\n\r?\n", self._on_headers)

    def _release(self):
        if self.release_callback is not None:
            release_callback = self.release_callback
            self.release_callback = None
            release_callback()

    def _run_callback(self, response):
        self._release()
        if self.final_callback is not None:
            final_callback = self.final_callback
            self.final_callback = None
            self.io_loop.add_callback(final_callback, response)

    def _handle_exception(self, typ, value, tb):
        if self.final_callback:
            self._remove_timeout()
            gen_log.warning("uncaught exception", exc_info=(typ, value, tb))
            self._run_callback(
                HTTPResponse(
                    self.request,
                    599,
                    error=value,
                    request_time=self.io_loop.time() - self.start_time,
                ))

            if hasattr(self, "stream"):
                self.stream.close()
            return True
        else:
            # If our callback has already been called, we are probably
            # catching an exception that is not caused by us but rather
            # some child of our callback. Rather than drop it on the floor,
            # pass it along.
            return False

    def _on_close(self):
        if self.final_callback is not None:
            message = "Connection closed"
            if self.stream.error:
                message = str(self.stream.error)
            raise HTTPError(599, message)

    def _handle_1xx(self, code):
        self.stream.read_until_regex(b"\r?\n\r?\n", self._on_headers)

    def _on_headers(self, data):
        """Parse the status line and headers, then start reading the body.

        Chooses among chunked, content-length, and read-until-close body
        handling, and short-circuits bodyless responses (1xx, HEAD, 304,
        204) straight to ``_on_body``.
        """
        data = native_str(data.decode("latin1"))
        first_line, _, header_data = data.partition("\n")
        match = re.match("HTTP/1.[01] ([0-9]+) ([^\r]*)", first_line)
        assert match
        code = int(match.group(1))
        self.headers = HTTPHeaders.parse(header_data)
        if 100 <= code < 200:
            # Informational response; keep waiting for the real one.
            self._handle_1xx(code)
            return
        else:
            self.code = code
            self.reason = match.group(2)

        if "Content-Length" in self.headers:
            if "," in self.headers["Content-Length"]:
                # Proxies sometimes cause Content-Length headers to get
                # duplicated.  If all the values are identical then we can
                # use them but if they differ it's an error.
                pieces = re.split(r',\s*', self.headers["Content-Length"])
                if any(i != pieces[0] for i in pieces):
                    raise ValueError("Multiple unequal Content-Lengths: %r" %
                                     self.headers["Content-Length"])
                self.headers["Content-Length"] = pieces[0]
            content_length = int(self.headers["Content-Length"])
        else:
            content_length = None

        if self.request.header_callback is not None:
            # re-attach the newline we split on earlier
            self.request.header_callback(first_line + _)
            for k, v in self.headers.get_all():
                self.request.header_callback("%s: %s\r\n" % (k, v))
            self.request.header_callback('\r\n')

        if self.request.method == "HEAD" or self.code == 304:
            # HEAD requests and 304 responses never have content, even
            # though they may have content-length headers
            self._on_body(b"")
            return
        if 100 <= self.code < 200 or self.code == 204:
            # These response codes never have bodies
            # http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.3
            if ("Transfer-Encoding" in self.headers
                    or content_length not in (None, 0)):
                raise ValueError("Response with code %d should not have body" %
                                 self.code)
            self._on_body(b"")
            return

        # Decompress only when the caller asked for gzip and the server
        # actually sent it.
        if (self.request.use_gzip
                and self.headers.get("Content-Encoding") == "gzip"):
            self._decompressor = GzipDecompressor()
        if self.headers.get("Transfer-Encoding") == "chunked":
            self.chunks = []
            self.stream.read_until(b"\r\n", self._on_chunk_length)
        elif content_length is not None:
            self.stream.read_bytes(content_length, self._on_body)
        else:
            self.stream.read_until_close(self._on_body)

    def _on_body(self, data):
        """Finish the request: follow a redirect or deliver the response."""
        self._remove_timeout()
        original_request = getattr(self.request, "original_request",
                                   self.request)
        if (self.request.follow_redirects and self.request.max_redirects > 0
                and self.code in (301, 302, 303, 307)):
            assert isinstance(self.request, _RequestProxy)
            new_request = copy.copy(self.request.request)
            new_request.url = urlparse.urljoin(self.request.url,
                                               self.headers["Location"])
            new_request.max_redirects = self.request.max_redirects - 1
            # The redirect target may be on a different host; let the
            # Host header be recomputed for the new request.
            del new_request.headers["Host"]
            # http://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.3.4
            # Client SHOULD make a GET request after a 303.
            # According to the spec, 302 should be followed by the same
            # method as the original request, but in practice browsers
            # treat 302 the same as 303, and many servers use 302 for
            # compatibility with pre-HTTP/1.1 user agents which don't
            # understand the 303 status.
            if self.code in (302, 303):
                new_request.method = "GET"
                new_request.body = None
                # Body-related headers make no sense on the bodyless GET.
                for h in [
                        "Content-Length", "Content-Type", "Content-Encoding",
                        "Transfer-Encoding"
                ]:
                    try:
                        del self.request.headers[h]
                    except KeyError:
                        pass
            new_request.original_request = original_request
            final_callback = self.final_callback
            self.final_callback = None
            self._release()
            self.client.fetch(new_request, final_callback)
            self.stream.close()
            return
        if self._decompressor:
            # Chunked bodies are decompressed incrementally and clear
            # _decompressor before reaching here; this path handles
            # non-chunked gzip bodies in one shot.
            data = (self._decompressor.decompress(data) +
                    self._decompressor.flush())
        if self.request.streaming_callback:
            if self.chunks is None:
                # if chunks is not None, we already called streaming_callback
                # in _on_chunk_data
                self.request.streaming_callback(data)
            buffer = BytesIO()
        else:
            buffer = BytesIO(data)  # TODO: don't require one big string?
        response = HTTPResponse(original_request,
                                self.code,
                                reason=self.reason,
                                headers=self.headers,
                                request_time=self.io_loop.time() -
                                self.start_time,
                                buffer=buffer,
                                effective_url=self.request.url)
        self._run_callback(response)
        self.stream.close()

    def _on_chunk_length(self, data):
        """Parse a chunk-size line; read the chunk, or finish on size 0.

        Generalized to tolerate chunk extensions (a ";name=value" suffix
        after the hex size, RFC 2616 section 3.6.1), which the previous
        ``int(data.strip(), 16)`` would have crashed on.
        """
        length = int(data.split(b";", 1)[0].strip(), 16)
        if length == 0:
            if self._decompressor is not None:
                tail = self._decompressor.flush()
                if tail:
                    # I believe the tail will always be empty (i.e.
                    # decompress will return all it can).  The purpose
                    # of the flush call is to detect errors such
                    # as truncated input.  But in case it ever returns
                    # anything, treat it as an extra chunk
                    if self.request.streaming_callback is not None:
                        self.request.streaming_callback(tail)
                    else:
                        self.chunks.append(tail)
                # all the data has been decompressed, so we don't need to
                # decompress again in _on_body
                self._decompressor = None
            self._on_body(b''.join(self.chunks))
        else:
            self.stream.read_bytes(
                length + 2,  # chunk ends with \r\n
                self._on_chunk_data)

    def _on_chunk_data(self, data):
        """Handle one chunk body (including its trailing CRLF)."""
        # Every chunk is terminated by CRLF; strip it off.
        assert data[-2:] == b"\r\n"
        payload = data[:-2]
        if self._decompressor:
            payload = self._decompressor.decompress(payload)
        # Hand the chunk to the streaming callback if one was given,
        # otherwise accumulate it for the final response body.
        callback = self.request.streaming_callback
        if callback is not None:
            callback(payload)
        else:
            self.chunks.append(payload)
        # Arm a read for the next chunk-size line.
        self.stream.read_until(b"\r\n", self._on_chunk_length)
Ejemplo n.º 43
0
class PushMessageReceiver(object):
    """Receives newline-delimited JSON push messages over a TCP stream.

    Each line is a JSON object of the form ``{"type": ..., "content": ...}``;
    the handler registered for the message's type receives its content.
    Automatically reconnects (via a retry timer) if the peer closes the
    connection.
    """

    def __init__(self):
        self.connected = False
        self._stream = None
        self.address = None
        self.family = None
        # Seconds between connection attempts; overridden by connect().
        self.retry_interval = 1
        self.delayed_call = None
        self._registered_handlers = {}

    def register_message_handler(self, message_type, handler):
        """Call ``handler(content)`` for every message of ``message_type``
        (one handler per type; a new registration replaces the old)."""
        self._registered_handlers[message_type] = handler

    def unregister_message_handler(self, message_type):
        """Remove the handler for ``message_type``; no-op if absent."""
        self._registered_handlers.pop(message_type, None)

    def unregister_all(self):
        """Drop all registered handlers."""
        self._registered_handlers = {}

    def connect(self, address, family=socket.AF_INET, retry_interval=1):
        """Connect to ``address`` and start the receive loop."""
        self.address = address
        self.family = family
        self.retry_interval = retry_interval
        self._connect()

    def _connect(self):
        if self.connected:
            if self.delayed_call:
                # A scheduled retry fired after we already connected;
                # cancel it.
                IOLoop.current().remove_timeout(self.delayed_call)
                self.delayed_call = None
            else:
                raise RuntimeError("Already connected")
        else:
            # Schedule a retry in case this attempt never completes.
            # Fix: honor the configured retry_interval (was hard-coded 1s).
            self.delayed_call = IOLoop.current().call_later(
                self.retry_interval, self._connect)
            stream_socket = socket.socket(family=self.family)
            self._stream = IOStream(stream_socket)
            self._stream.connect(self.address, self._on_connect)

    def _on_connect(self):
        self.connected = True
        self._stream.set_close_callback(self._on_closed)
        # Fix: IOStream delimiters must be byte strings.
        self._stream.read_until(b'\n', self._handle_greeting)

    def _on_closed(self):
        if self.connected:
            # The other side closed us, try to reconnect.
            # Fix: preserve the configured retry interval on reconnect.
            self.connected = False
            self.connect(self.address, self.family, self.retry_interval)

    def _handle_greeting(self, greetings):
        # The first line is a greeting; ignore it and enter the message loop.
        self._stream.read_until(b'\n', self._handle_message)

    def _handle_message(self, raw_message):
        """Decode one JSON message and dispatch it to its handler."""
        message = json.loads(raw_message.strip())
        try:
            handler = self._registered_handlers[message['type']]
            handler(message['content'])
        except KeyError:
            # No handler registered for this message type; ignore it.
            pass
        finally:
            # continue getting messages as long as we are connected
            if self.connected:
                self._stream.read_until(b'\n', self._handle_message)

    def close(self):
        """Close the stream and stop the receive loop."""
        if self._stream:
            self.connected = False
            self._stream.close()
            self._stream = None
Ejemplo n.º 44
0
class _HTTPConnection(object):
    """One HTTP client request/response cycle over a fresh connection."""

    _SUPPORTED_METHODS = set(["GET", "HEAD", "POST", "PUT", "DELETE"])

    def __init__(self, io_loop, client, request, callback):
        self.start_time = time.time()
        self.io_loop = io_loop
        self.client = client
        self.request = request
        self.callback = callback
        self.code = None
        self.headers = None
        self.chunks = None
        self._decompressor = None
        # Timeout handle returned by IOLoop.add_timeout
        self._timeout = None
        with stack_context.StackContext(self.cleanup):
            parsed = urlparse.urlsplit(self.request.url)
            host = parsed.hostname
            if parsed.port is None:
                port = 443 if parsed.scheme == "https" else 80
            else:
                port = parsed.port
            if self.client.hostname_mapping is not None:
                host = self.client.hostname_mapping.get(host, host)

            if parsed.scheme == "https":
                ssl_options = {}
                if request.validate_cert:
                    ssl_options["cert_reqs"] = ssl.CERT_REQUIRED
                if request.ca_certs is not None:
                    ssl_options["ca_certs"] = request.ca_certs
                else:
                    ssl_options["ca_certs"] = _DEFAULT_CA_CERTS
                self.stream = SSLIOStream(socket.socket(),
                                          io_loop=self.io_loop,
                                          ssl_options=ssl_options)
            else:
                self.stream = IOStream(socket.socket(), io_loop=self.io_loop)
            timeout = min(request.connect_timeout, request.request_timeout)
            if timeout:
                # Bug fix: store the handle in self._timeout; it was
                # stored in self._connect_timeout, which nothing ever
                # read, so _on_connect could never cancel it.
                self._timeout = self.io_loop.add_timeout(
                    self.start_time + timeout, self._on_timeout)
            self.stream.set_close_callback(self._on_close)
            self.stream.connect((host, port),
                                functools.partial(self._on_connect, parsed))

    def _on_timeout(self):
        """Deliver a 599 response and abort the connection."""
        self._timeout = None
        if self.callback is not None:
            self.callback(
                HTTPResponse(self.request,
                             599,
                             error=HTTPError(599, "Timeout")))
            self.callback = None
        self.stream.close()

    def _on_connect(self, parsed):
        """Connection established: validate, build, and send the request."""
        if self._timeout is not None:
            # Bug fix: timeouts are cancelled with remove_timeout
            # (remove_callback is not an IOLoop method).
            self.io_loop.remove_timeout(self._timeout)
            self._timeout = None
        if self.request.request_timeout:
            self._timeout = self.io_loop.add_timeout(
                self.start_time + self.request.request_timeout,
                self._on_timeout)
        if (self.request.validate_cert
                and isinstance(self.stream, SSLIOStream)):
            match_hostname(self.stream.socket.getpeercert(), parsed.hostname)
        if (self.request.method not in self._SUPPORTED_METHODS
                and not self.request.allow_nonstandard_methods):
            raise KeyError("unknown method %s" % self.request.method)
        # These request options are accepted by the API but not
        # implemented by this client; fail loudly rather than ignore.
        for key in ('network_interface', 'proxy_host', 'proxy_port',
                    'proxy_username', 'proxy_password'):
            if getattr(self.request, key, None):
                raise NotImplementedError('%s not supported' % key)
        if "Host" not in self.request.headers:
            self.request.headers["Host"] = parsed.netloc
        username, password = None, None
        if parsed.username is not None:
            username, password = parsed.username, parsed.password
        elif self.request.auth_username is not None:
            username = self.request.auth_username
            password = self.request.auth_password
        if username is not None:
            auth = "%s:%s" % (username, password)
            self.request.headers["Authorization"] = ("Basic %s" %
                                                     base64.b64encode(auth))
        if self.request.user_agent:
            self.request.headers["User-Agent"] = self.request.user_agent
        has_body = self.request.method in ("POST", "PUT")
        if has_body:
            assert self.request.body is not None
            self.request.headers["Content-Length"] = len(self.request.body)
        else:
            assert self.request.body is None
        if (self.request.method == "POST"
                and "Content-Type" not in self.request.headers):
            self.request.headers[
                "Content-Type"] = "application/x-www-form-urlencoded"
        if self.request.use_gzip:
            self.request.headers["Accept-Encoding"] = "gzip"
        req_path = ((parsed.path or '/') +
                    (('?' + parsed.query) if parsed.query else ''))
        request_lines = ["%s %s HTTP/1.1" % (self.request.method, req_path)]
        for k, v in self.request.headers.get_all():
            line = "%s: %s" % (k, v)
            # Reject header injection attempts.
            if '\n' in line:
                raise ValueError('Newline in header: ' + repr(line))
            request_lines.append(line)
        self.stream.write("\r\n".join(request_lines) + "\r\n\r\n")
        if has_body:
            self.stream.write(self.request.body)
        self.stream.read_until("\r\n\r\n", self._on_headers)

    @contextlib.contextmanager
    def cleanup(self):
        """Stack context: convert any uncaught exception into a 599."""
        try:
            yield
        except Exception as e:
            # (was `except Exception, e` -- py2-only syntax; `as` works
            # on both py2.6+ and py3)
            logging.warning("uncaught exception", exc_info=True)
            if self.callback is not None:
                callback = self.callback
                self.callback = None
                callback(HTTPResponse(self.request, 599, error=e))
Ejemplo n.º 45
0
class Connection(object):
    """
    Encapsulates the communication, including parsing, with the beanstalkd
    """

    def __init__(self, host, port, io_loop=None):
        # Default to the process-wide IOLoop unless one is injected.
        self._ioloop = io_loop or IOLoop.instance()

        # setup our protocol callbacks
        # beanstalkd will reply with a superset of these replies, but these
        # are the only ones we handle today.  patches gleefully accepted.
        # NOTE(review): replies mapped to None (USING, DELETED, ...) and
        # replies absent from this table make _parse_response call
        # None(tokens), raising TypeError -- confirm that is intended.
        self._beanstalk_protocol_1x = dict(
            # generic returns
            OUT_OF_MEMORY = self.fail,
            INTERNAL_ERROR = self.fail,
            DRAINING = self.fail,
            BAD_FORMAT = self.fail,
            UNKNOWN_COMMAND = self.fail,
            # put <pri> <delay> <ttr> <bytes>
            INSERTED = self.ret_inserted,
            BURIED = self.ret_inserted,
            EXPECTED_CRLF = self.fail,
            JOB_TOO_BIG = self.fail,
            # use
            USING = None,
            # reserve
            RESERVED = self.ret_reserved,
            DEADLINE_SOON = None,
            TIMED_OUT = None,
            # delete <id>
            DELETED = None,
            NOT_FOUND = None,
            # touch <id>
            TOUCHED = None,
            # watch <tube>
            WATCHING = None,
            #ignore <tube>
            NOT_IGNORED = None,
            )

        # open a connection to the beanstalkd
        _sock = socket.socket(
                        socket.AF_INET,
                        socket.SOCK_STREAM,
                        socket.IPPROTO_TCP
                        )
        # Blocking connect first, then hand the socket to IOStream in
        # non-blocking mode for all further I/O.
        _sock.connect((host, port))
        _sock.setblocking(False)
        self.stream = IOStream(_sock, io_loop=self._ioloop)

        # i like a placeholder for this. we'll assign it later
        self.callback = None
        self.tsr = TornStalkResponse()

    def _parse_response(self, resp):
        # Dispatch the first token of the reply line to its protocol
        # handler (see _beanstalk_protocol_1x above).
        print "parse_response"
        tokens = resp.strip().split()
        if not tokens: return
        print 'tok:', tokens[1:]
        self._beanstalk_protocol_1x.get(tokens[0])(tokens)

    def _payload_rcvd(self, payload):
        # A reserved job body arrived: strip the trailing CRLF and hand
        # the response object to the stored caller callback.
        self.tsr.data = payload[:-2] # lose the \r\n
        self.callback(self.tsr)

    def _command(self, contents):
        # Write a command (plus optional payload) and arm a read for the
        # single CRLF-terminated reply line.
        print "sending>%s<" % contents
        self.stream.write(contents)
        self.stream.read_until('\r\n', self._parse_response)

    def cmd_put(self, body, callback, priority=10000, delay=0, ttr=1):
        """
        send the put command to the beanstalkd with a message body

        priority needs to be between 0 and 2**32. lower gets done first
        delay is number of seconds before job is available in queue
        ttr is number of seconds the job has to run by a worker

        bs: put <pri> <delay> <ttr> <bytes>
        """
        self.callback = callback
        cmd = 'put {priority} {delay} {ttr} {size}'.format(
                priority = priority,
                delay = delay,
                ttr = ttr,
                size = len(body)
                )
        payload = '{}\r\n{}\r\n'.format(cmd, body)
        self._command(payload)

    def cmd_reserve(self, callback):
        """Ask beanstalkd for the next available job."""
        self.callback = callback
        cmd = 'reserve\r\n'
        self._command(cmd)

    def ret_inserted(self, toks):
        """ handles both INSERTED and BURIED """
        # Reply shape: "INSERTED <id>" / "BURIED <id>".
        jobid = int(toks[1])
        self.callback(TornStalkResponse(data=jobid))

    def ret_reserved(self, toks):
        # Reply shape: "RESERVED <id> <bytes>"; read the job body plus
        # its trailing CRLF.
        jobid, size = toks[1:]
        jobid = int(jobid)
        size = int(size) + 2 # len('\r\n')
        self.stream.read_bytes(size, self._payload_rcvd)

    def handle_error(self, *a):
        # Escalate protocol-level errors to the caller.
        print "error", a
        raise TornStalkError(a)

    def ok(self, *a):
        print "ok", a
        return True

    def fail(self, toks):
        # NOTE(review): single-word error replies (e.g. BAD_FORMAT) have
        # no toks[1], which would raise IndexError here -- verify.
        self.callback(TornStalkResponse(result=False, msg=toks[1]))
Ejemplo n.º 46
0
class TestIOStreamStartTLS(AsyncTestCase):
    """Tests for IOStream.start_tls over a loopback client/server pair."""

    def setUp(self):
        try:
            super(TestIOStreamStartTLS, self).setUp()
            self.listener, self.port = bind_unused_port()
            self.server_stream = None
            self.server_accepted = Future()  # type: Future[None]
            netutil.add_accept_handler(self.listener, self.accept)
            self.client_stream = IOStream(socket.socket())
            # Bug fix: connect to the loopback address the listener from
            # bind_unused_port is actually bound on (was the non-local
            # address 10.0.0.7, which can never reach the listener).
            self.io_loop.add_future(
                self.client_stream.connect(("127.0.0.1", self.port)),
                self.stop)
            self.wait()
            self.io_loop.add_future(self.server_accepted, self.stop)
            self.wait()
        except Exception as e:
            print(e)
            raise

    def tearDown(self):
        if self.server_stream is not None:
            self.server_stream.close()
        if self.client_stream is not None:
            self.client_stream.close()
        self.listener.close()
        super(TestIOStreamStartTLS, self).tearDown()

    def accept(self, connection, address):
        if self.server_stream is not None:
            self.fail("should only get one connection")
        self.server_stream = IOStream(connection)
        self.server_accepted.set_result(None)

    @gen.coroutine
    def client_send_line(self, line):
        """Send *line* from the client and assert the server receives it."""
        self.client_stream.write(line)
        recv_line = yield self.server_stream.read_until(b"\r\n")
        self.assertEqual(line, recv_line)

    @gen.coroutine
    def server_send_line(self, line):
        """Send *line* from the server and assert the client receives it."""
        self.server_stream.write(line)
        recv_line = yield self.client_stream.read_until(b"\r\n")
        self.assertEqual(line, recv_line)

    def client_start_tls(self, ssl_options=None, server_hostname=None):
        client_stream = self.client_stream
        self.client_stream = None
        return client_stream.start_tls(False, ssl_options, server_hostname)

    def server_start_tls(self, ssl_options=None):
        server_stream = self.server_stream
        self.server_stream = None
        return server_stream.start_tls(True, ssl_options)

    @gen_test
    def test_start_tls_smtp(self):
        # This flow is simplified from RFC 3207 section 5.
        # We don't really need all of this, but it helps to make sure
        # that after realistic back-and-forth traffic the buffers end up
        # in a sane state.
        yield self.server_send_line(b"220 mail.example.com ready\r\n")
        yield self.client_send_line(b"EHLO mail.example.com\r\n")
        yield self.server_send_line(b"250-mail.example.com welcome\r\n")
        yield self.server_send_line(b"250 STARTTLS\r\n")
        yield self.client_send_line(b"STARTTLS\r\n")
        yield self.server_send_line(b"220 Go ahead\r\n")
        client_future = self.client_start_tls(dict(cert_reqs=ssl.CERT_NONE))
        server_future = self.server_start_tls(_server_ssl_options())
        self.client_stream = yield client_future
        self.server_stream = yield server_future
        self.assertTrue(isinstance(self.client_stream, SSLIOStream))
        self.assertTrue(isinstance(self.server_stream, SSLIOStream))
        yield self.client_send_line(b"EHLO mail.example.com\r\n")
        yield self.server_send_line(b"250 mail.example.com welcome\r\n")

    @gen_test
    def test_handshake_fail(self):
        server_future = self.server_start_tls(_server_ssl_options())
        # Certificates are verified with the default configuration.
        with ExpectLog(gen_log, "SSL Error"):
            # Bug fix: use the loopback hostname (was 10.0.0.7).
            client_future = self.client_start_tls(server_hostname="127.0.0.1")
            with self.assertRaises(ssl.SSLError):
                yield client_future
            with self.assertRaises((ssl.SSLError, socket.error)):
                yield server_future

    @gen_test
    def test_check_hostname(self):
        # Test that server_hostname parameter to start_tls is being used.
        # The check_hostname functionality is only available in python 2.7 and
        # up and in python 3.4 and up.
        server_future = self.server_start_tls(_server_ssl_options())
        with ExpectLog(gen_log, "SSL Error"):
            # Bug fix: use the loopback hostname (was 10.0.0.7).
            client_future = self.client_start_tls(ssl.create_default_context(),
                                                  server_hostname="127.0.0.1")
            with self.assertRaises(ssl.SSLError):
                # The client fails to connect with an SSL error.
                yield client_future
            with self.assertRaises(Exception):
                # The server fails to connect, but the exact error is unspecified.
                yield server_future
Ejemplo n.º 47
0
class IRC(object):
    """Asynchronous IRC client built on tornado's IOStream."""

    # Private
    # NOTE(review): these mutable class attributes (_buffers, chans,
    # names, relaybots, ...) are shared across all instances until an
    # instance reassigns them -- fine for a single bot instance, but
    # worth confirming if multiple IRC objects are ever created.
    _stream = None
    _charset = None
    _ioloop = None
    _timer = None
    _last_pong = None
    _is_reconnect = 0
    _buffers = []
    _send_timer = None

    host = None
    port = None
    nick = None
    chans = []
    chans_ref = {}
    names = {}
    relaybots = []
    # (open, close) delimiter pairs that relay bots use to embed the
    # original speaker's nick in a forwarded message.
    delims = [
        ('<', '> '),
        ('[', '] '),
        ('(', ') '),
        ('{', '} '),
    ]

    # External callbacks
    # Called when you are logined
    login_callback = None
    # Called when you received specified IRC message
    # for usage of event_callback, see `botbox.dispatch`
    event_callback = None

    def __init__(self,
                 host,
                 port,
                 nick,
                 relaybots=None,
                 charset='utf-8',
                 ioloop=False):
        """Connect to the IRC server and start the periodic timers.

        :param relaybots: nicks of relay bots whose messages embed the
            real speaker's nick (defaults to none).
        """
        logger.info('Connecting to %s:%s', host, port)

        self.host = host
        self.port = port
        self.nick = nick
        # Bug fix: `relaybots=[]` was a shared mutable default argument;
        # default to a fresh list per instance instead.
        self.relaybots = relaybots if relaybots is not None else []

        self._charset = charset
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self._ioloop = ioloop or IOLoop.instance()
        self._stream = IOStream(sock, io_loop=self._ioloop)
        self._stream.connect((host, port), self._login)

        self._last_pong = time.time()
        # Liveness check once a minute; _keep_alive reconnects after
        # 360s without a PONG.
        self._timer = PeriodicCallback(self._keep_alive,
                                       60 * 1000,
                                       io_loop=self._ioloop)
        self._timer.start()

        # Drain one queued outbound line every 600 ms (paces writes).
        self._send_timer = PeriodicCallback(self._sock_send,
                                            600,
                                            io_loop=self._ioloop)
        self._send_timer.start()

    def _sock_send(self):
        """Flush one queued outbound message, if any are pending."""
        if not self._buffers:
            return
        data = self._buffers.pop(0)
        return self._stream.write(data)

    def _period_send(self, data):
        """Queue *data* for sending; _sock_send() drains the queue."""
        encoded = data.encode(self._charset)
        self._buffers.append(encoded)

    def _sock_recv(self):
        """Arm a read for the next CRLF-terminated IRC line."""
        def _recv(data):
            # Decode with the configured charset (dropping undecodable
            # bytes) and strip the trailing CRLF before dispatching.
            msg = data.decode(self._charset, 'ignore')
            msg = msg[:-2]  # strip '\r\n'
            self._recv(msg)

        try:
            self._stream.read_until(b'\r\n', _recv)
        except Exception as err:
            # Any stream failure triggers a reconnect.
            logger.error('Read error: %s', err)
            self._reconnect()

    def _reconnect(self):
        """Tear down the current stream and open a fresh connection."""
        logger.info('Reconnecting...')

        # Flag lets _on_login skip login_callback on reconnects.
        self._is_reconnect = 1

        self._stream.close()
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self._stream = IOStream(sock, io_loop=self._ioloop)
        self._stream.connect((self.host, self.port), self._login)

    # IRC message parser, return tuple (IRCMsgType, IRCMsg)
    def _parse(self, msg):
        """Parse a raw IRC line into an ``(IRCMsgType, IRCMsg)`` tuple.

        Returns ``(PING/NOTICE/ERROR, None)`` for housekeeping lines,
        ``(UNKNOW, None)`` when parsing fails, and ``(MSG, ircmsg)``
        for ordinary messages.
        """
        if msg.startswith('PING :'):
            logger.debug('PING')
            return (IRCMsgType.PING, None)
        elif msg.startswith('NOTICE AUTH :'):
            # Bug fix: the format string had a '%s' but no argument, so
            # the literal placeholder was logged; pass the message.
            logger.debug('NOTICE AUTH: "%s"', msg)
            return (IRCMsgType.NOTICE, None)
        elif msg.startswith('ERROR :'):
            logger.debug('ERROR: "%s"', msg)
            return (IRCMsgType.ERROR, None)

        try:
            # <message> ::= [':' <prefix> <SPACE> ] <command> <params> <crlf>
            tmp = msg.split(' ', maxsplit=2)

            if len(tmp) != 3:
                raise Exception(
                    'Failed when parsing <prefix> <command> <params>')

            prefix, command, params = tmp
            logger.debug('prefix: "%s", command: "%s", params: "%s"', prefix,
                         command, params)

            # <params> ::= <SPACE> [ ':' <trailing> | <middle> <params> ]
            middle, _, trailing = params.partition(' :')
            if middle.startswith(':'):
                trailing = middle[1:]
                middle = ''
            logger.debug('middle: "%s", trailing: "%s"', middle, trailing)

            if not middle and not trailing:
                middle = trailing = ''
                # raise Exception('No <middle> and <trailing>')

            # <middle> ::= <Any *non-empty* sequence of octets not including SPACE
            #              or NUL or CR or LF, the first of which may not be ':'>
            args = middle.split(' ')
            logger.debug('args: "%s"', args)

            # <prefix> ::= <servername> | <nick> [ '!' <user> ] [ '@' <host> ]
            tmp = prefix
            nick, _, tmp = tmp.partition('!')
            user, _, host = tmp.partition('@')
            # Bug fix: the format string had only two '%s' placeholders
            # for three arguments (the user placeholder was mangled),
            # which made logging raise a formatting error on every line.
            logger.debug('nick: "%s", user: "%s", host: "%s"', nick, user,
                         host)

        except Exception as err:
            logger.error('Parsing error: %s', err)
            logger.error('    Message: %s', repr(msg))
            return (IRCMsgType.UNKNOW, None)
        else:
            ircmsg = IRCMsg()
            ircmsg.nick = nick[1:]  # strip the leading ':'
            ircmsg.user = user
            ircmsg.host = host
            ircmsg.cmd = command
            ircmsg.args = args
            ircmsg.msg = trailing
            return (IRCMsgType.MSG, ircmsg)

    # Response server message
    def _resp(self, type_, ircmsg):
        """React to a parsed server message.

        Answers PINGs and keeps channel/nick bookkeeping (``self.chans``,
        ``self.names``, ``self.nick``) in sync with the server.
        """
        if type_ == IRCMsgType.PING:
            self._pong()
        elif type_ == IRCMsgType.ERROR:
            pass
        elif type_ == IRCMsgType.MSG:
            if ircmsg.cmd == RPL_WELCOME:
                self._on_login(ircmsg.args[0])
            elif ircmsg.cmd == ERR_NICKNAMEINUSE:
                # Retry with a trailing underscore appended.
                new_nick = ircmsg.args[1] + '_'
                logger.info('Nick already in use, use "%s"', new_nick)
                self._chnick(new_nick)
            elif ircmsg.cmd == 'JOIN':
                chan = ircmsg.args[0] or ircmsg.msg
                if ircmsg.nick == self.nick:
                    self.chans.append(chan)
                    self.names[chan] = set()
                    logger.info('%s has joined %s', self.nick, chan)
                self.names[chan].add(ircmsg.nick)
            elif ircmsg.cmd == 'PART':
                chan = ircmsg.args[0]
                try:
                    self.names[chan].remove(ircmsg.nick)
                except KeyError as err:
                    logger.error('KeyError: %s', err)
                    logger.error('%s %s %s %s %s %s', ircmsg.nick, ircmsg.user,
                                 ircmsg.host, ircmsg.cmd, ircmsg.args,
                                 ircmsg.msg)
                if ircmsg.nick == self.nick:
                    self.chans.remove(chan)
                    self.names[chan].clear()
                    logger.info('%s has left %s', self.nick, ircmsg.args[0])
            elif ircmsg.cmd == 'NICK':
                new_nick, old_nick = ircmsg.msg, ircmsg.nick
                for chan in self.chans:
                    if old_nick in self.names[chan]:
                        self.names[chan].remove(old_nick)
                        self.names[chan].add(new_nick)
                if old_nick == self.nick:
                    # Bug fix: record the *new* nick -- the original
                    # assigned old_nick back, which was a no-op.
                    self.nick = new_nick
                    logger.info('%s is now known as %s', old_nick, new_nick)
            elif ircmsg.cmd == 'QUIT':
                nick = ircmsg.nick
                for chan in self.chans:
                    if nick in self.names[chan]:
                        self.names[chan].remove(nick)
            elif ircmsg.cmd == RPL_NAMREPLY:
                chan = ircmsg.args[2]
                # Strip op/voice sigils ('@'/'+') from each name.
                names_list = [
                    x[1:] if x[0] in ['@', '+'] else x
                    for x in ircmsg.msg.split(' ')
                ]
                self.names[chan].update(names_list)
                logger.debug('NAMES: %s' % names_list)

    def _dispatch(self, type_, ircmsg):
        """Translate a parsed server message into event_callback events.

        Emits the raw IRC events (JOIN/PART/QUIT/NICK/PRIVMSG/NOTICE/
        ACTION) plus the synthesized LABOTS_MSG and LABOTS_MENTION_MSG
        events with relay-bot nicks unwrapped.
        """
        if type_ != IRCMsgType.MSG:
            return

        # Error message
        if ircmsg.cmd[0] in ['4', '5']:
            logger.warn('Error message: %s', ircmsg.msg)
        elif ircmsg.cmd in ['JOIN', 'PART']:
            nick, chan = ircmsg.nick, ircmsg.args[0] or ircmsg.msg
            self.event_callback(ircmsg.cmd, chan, nick)
        elif ircmsg.cmd == 'QUIT':
            nick, reason = ircmsg.nick, ircmsg.msg
            # A QUIT carries no channel; fan it out to every channel the
            # quitter was in.
            for chan in self.chans:
                if nick in self.names[chan]:
                    self.event_callback(ircmsg.cmd, chan, nick, reason)
        elif ircmsg.cmd == 'NICK':
            new_nick, old_nick = ircmsg.msg, ircmsg.nick
            for chan in self.chans:
                if old_nick in self.names[chan]:
                    self.event_callback(ircmsg.cmd, chan, old_nick, new_nick)
        elif ircmsg.cmd in ['PRIVMSG', 'NOTICE']:
            # CTCP ACTION ("/me") messages are wrapped in \x01 markers.
            if ircmsg.msg.startswith('\x01ACTION '):
                msg = ircmsg.msg[len('\x01ACTION '):-1]
                cmd = 'ACTION'
            else:
                msg, cmd = ircmsg.msg, ircmsg.cmd
            nick, target = ircmsg.nick, ircmsg.args[0]
            self.event_callback(cmd, target, nick, msg)

            # LABOTS_MSG = ACTION or PRIVMSG or NOTICE
            # And it will:
            # - Strip IRC color codes
            # - Replace relaybot's nick with human's nick
            bot = ''
            msg = strip(msg)
            if nick in self.relaybots:
                # Unwrap "<realnick> message" style relayed messages
                # using the configured delimiter pairs.
                for d in self.delims:
                    if msg.startswith(d[0]) and msg.find(d[1]) != -1:
                        bot = nick
                        nick = msg[len(d[0]):msg.find(d[1])]
                        msg = msg[msg.find(d[1]) + len(d[1]):]
                        break
            self.event_callback('LABOTS_MSG', target, bot, nick, msg)

            # LABOTS_MENTION_MSG = LABOTS_MSG + labots's nick is mentioned at
            # the head of message
            words = msg.split(' ', maxsplit=1)
            if words[0] in [self.nick + x for x in ['', ':', ',']]:
                if words[1:]:
                    msg = words[1]
                    self.event_callback('LABOTS_MENTION_MSG', target, bot,
                                        nick, msg)

    def _keep_alive(self):
        # Ping time out
        if time.time() - self._last_pong > 360:
            logger.error('Ping time out')

            self._reconnect()
            self._last_pong = time.time()

    def _recv(self, msg):
        if msg:
            type_, ircmsg = self._parse(msg)
            self._dispatch(type_, ircmsg)
            self._resp(type_, ircmsg)

        self._sock_recv()

    def _chnick(self, nick):
        self._period_send('NICK %s\r\n' % nick)

    def _on_login(self, nick):
        """Record the confirmed nick and (re)join all known channels."""
        logger.info('You are logined as %s', nick)

        self.nick = nick
        previous_chans = self.chans

        # Fire the user callback only on the first login, not on reconnects.
        if not self._is_reconnect:
            self.login_callback()

        # Reset the channel list; join() repopulates it as JOINs succeed.
        self.chans = []
        for chan in previous_chans:
            self.join(chan, force=True)

    def _login(self):
        """Send the NICK/USER registration and start reading the socket."""
        logger.info('Try to login as "%s"', self.nick)

        self._chnick(self.nick)
        user_line = 'USER %s %s %s %s\r\n' % (
            self.nick, 'labots', 'localhost',
            'https://github.com/SilverRainZ/labots')
        self._period_send(user_line)

        self._sock_recv()

    def _pong(self):
        """Answer a server PING and refresh the keep-alive timestamp."""
        logger.debug('Pong!')

        # _keep_alive() reconnects when this timestamp is older than 360s.
        self._last_pong = time.time()
        # NOTE(review): terminated with a bare '\n' while every other command
        # here uses '\r\n' — most servers tolerate it, but confirm intent.
        self._period_send('PONG :labots!\n')

    def set_callback(self,
                     login_callback=empty_callback,
                     event_callback=empty_callback):
        """Install user hooks; omitted callbacks default to no-ops."""
        self.login_callback = login_callback
        self.event_callback = event_callback

    def join(self, chan, force=False):
        """Join an IRC channel, reference-counting duplicate requests.

        force=True bypasses the refcount (used when rejoining after a
        reconnect). Non-channel targets (no '#'/'&' prefix) are ignored.
        """
        if chan[0] not in ('#', '&'):
            return

        if not force:
            count = self.chans_ref.get(chan)
            if count is not None:
                # Already joined: just bump the reference count.
                self.chans_ref[chan] = count + 1
                return
            self.chans_ref[chan] = 1

        logger.debug('Try to join %s', chan)
        self._period_send('JOIN %s\r\n' % chan)

    def part(self, chan):
        """Leave a channel once its reference count drops to zero."""
        if chan[0] not in ('#', '&'):
            return
        if chan not in self.chans_ref:
            return

        count = self.chans_ref[chan]
        if count != 1:
            # Other users of this channel remain: only decrement.
            self.chans_ref[chan] = count - 1
            return

        self.chans_ref.pop(chan, None)

        logger.debug('Try to part %s', chan)
        self._period_send('PART %s\r\n' % chan)

    # recv_msg: Whether receive the message you sent
    def send(self, target, msg, recv_msg=True):
        """Send *msg* to *target*, one PRIVMSG per line.

        recv_msg: when True, each sent line is also echoed back through
        event_callback so the bot sees its own messages.
        """
        for line in msg.split('\n'):
            self._period_send('PRIVMSG %s :%s\r\n' % (target, line))
            if recv_msg:
                # Echo the line we just sent back to our own handlers.
                self.event_callback('PRIVMSG', target, self.nick, line)

    def action(self, target, msg):
        """Send a CTCP ACTION ("/me") to *target* and echo it locally.

        Bug fix: the format string was never interpolated — the literal
        template 'PRIVMSG %s :\\x01ACTION %s\\x01\\r\\n' was sent to the
        server instead of the target/message.
        """
        self._period_send('PRIVMSG %s :\x01ACTION %s\x01\r\n' % (target, msg))
        # You will recv the message you sent
        self.event_callback('ACTION', target, self.nick, msg)

    def topic(self, chan, topic):
        """Set *chan*'s topic to *topic*."""
        line = 'TOPIC %s :%s\r\n' % (chan, topic)
        self._period_send(line)

    def kick(self, chan, nick, reason):
        """Kick *nick* from *chan*, citing *reason*.

        Bug fix: the format arguments referenced the undefined name
        ``topic`` instead of the ``reason`` parameter, so every call
        raised NameError.
        """
        self._period_send('KICK %s %s :%s\r\n' % (chan, nick, reason))

    def quit(self, reason='食饭'):
        """Send QUIT with an optional reason (default: "食饭")."""
        line = 'QUIT :%s\r\n' % reason
        self._period_send(line)

    def stop(self):
        """Quit the IRC session and close the underlying stream."""
        logger.info('Stop')
        self.quit()
        self._stream.close()
Ejemplo n.º 48
0
class Connection(object):
    """TCP connection helper built on Tornado's IOStream.

    Writes retry once through a transparent reconnect; reads are serialized
    through a FIFO queue (``read_queue``) so only one consumer reads at a
    time.  ``ConnectionError`` is a project-defined exception (declared
    elsewhere in this file).
    """

    def __init__(self,
                 host,
                 port,
                 on_connect,
                 on_disconnect,
                 timeout=None,
                 io_loop=None):
        self.host = host
        self.port = port
        self.on_connect = on_connect
        self.on_disconnect = on_disconnect
        self.timeout = timeout
        self._stream = None
        self._io_loop = io_loop
        # Total write attempts: the initial try plus one retry after reconnect.
        self.try_left = 2

        # True while a queued read callback is being serviced.
        self.in_progress = False
        # FIFO of callbacks waiting for their turn to read.
        self.read_queue = []

    def connect(self):
        """Open the socket (blocking connect) and wrap it in an IOStream.

        Raises ConnectionError on any socket-level failure.
        """
        try:
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
            sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
            sock.settimeout(self.timeout)
            sock.connect((self.host, self.port))
            self._stream = IOStream(sock, io_loop=self._io_loop)
            # NOTE(review): connected() only returns a bool, and its result
            # is ignored here — looks like a leftover; confirm intent.
            self.connected()
        except socket.error as e:
            raise ConnectionError(str(e))
        self.on_connect()

    def disconnect(self):
        """Close and drop the stream, ignoring socket errors; idempotent."""
        if self._stream:
            try:
                self._stream.close()
            except socket.error as e:
                pass
            self._stream = None

    def write(self, data, try_left=None):
        """Write raw data, reconnecting and retrying once on IOError.

        Raises ConnectionError when the connection cannot be established
        or all retries are exhausted.
        """
        if try_left is None:
            try_left = self.try_left
        if not self._stream:
            self.connect()
            if not self._stream:
                raise ConnectionError(
                    'Tried to write to non-existent connection')

        if try_left > 0:
            try:
                self._stream.write(data)
            except IOError:
                # Drop the broken stream and retry with a fresh connection.
                self.disconnect()
                self.write(data, try_left - 1)
        else:
            raise ConnectionError('Tried to write to non-existent connection')

    def read(self, length, callback):
        """Read exactly *length* bytes, passing them to *callback*.

        Invokes on_disconnect() if the stream raises IOError.
        """
        try:
            if not self._stream:
                self.disconnect()
                raise ConnectionError(
                    'Tried to read from non-existent connection')
            self._stream.read_bytes(length, callback)
        except IOError:
            self.on_disconnect()

    def readline(self, callback):
        """Read up to the next CRLF, passing the line to *callback*.

        NOTE(review): catches bare Exception (broader than read()'s
        IOError) — confirm this difference is intentional.
        """
        try:
            if not self._stream:
                self.disconnect()
                raise ConnectionError(
                    'Tried to read from non-existent connection')
            self._stream.read_until(b'\r\n', callback)
        except Exception as e:
            self.on_disconnect()

    def try_to_perform_read(self):
        """Pop and schedule the next queued read callback, if idle."""
        if not self.in_progress and self.read_queue:
            self.in_progress = True
            # The queued callback receives None as its argument.
            self._io_loop.add_callback(partial(self.read_queue.pop(0), None))

    @async
    def queue_wait(self, callback):
        """Queue *callback* to run when it is this caller's turn to read.

        NOTE(review): the ``async`` decorator is defined outside this chunk
        (and collides with the Python 3.5+ keyword) — confirm its semantics
        and the target Python version.
        """
        self.read_queue.append(callback)
        self.try_to_perform_read()

    def read_done(self):
        """Mark the current read finished and service the next in queue."""
        self.in_progress = False
        self.try_to_perform_read()

    def connected(self):
        """Return True if an IOStream is currently attached."""
        if self._stream:
            return True
        return False
Ejemplo n.º 49
0
class Connection(object):
    """Redis connection over a Tornado IOStream (TCP or UNIX socket).

    Tracks in-flight read callbacks (``read_callbacks``) so pending
    commands can be deferred until the connection is idle, and notifies an
    optional event-handler proxy on connect/disconnect.  ``CRLF``, ``PY3``
    and ``ConnectionError`` are defined elsewhere in this file.
    """

    def __init__(self,
                 host='localhost',
                 port=6379,
                 unix_socket_path=None,
                 event_handler_proxy=None,
                 stop_after=None,
                 io_loop=None):
        self.host = host
        self.port = port
        self.unix_socket_path = unix_socket_path
        self._event_handler = event_handler_proxy
        self.timeout = stop_after
        self._stream = None
        self._io_loop = io_loop

        self.in_progress = False
        # Callbacks for reads currently in flight on the stream.
        self.read_callbacks = set()
        # Commands waiting for the connection to become idle.
        self.ready_callbacks = deque()
        self._lock = 0
        # Per-connection state: SELECTed db index and AUTH password.
        self.info = {'db': 0, 'pass': None}

    def __del__(self):
        self.disconnect()

    def execute_pending_command(self):
        # Continue with the pending command execution
        # if all read operations are completed.
        if not self.read_callbacks and self.ready_callbacks:
            # Pop a SINGLE callback from the queue and execute it.
            # The next one will be executed from the code
            # invoked by the callback
            callback = self.ready_callbacks.popleft()
            callback()

    def ready(self):
        """True when no reads are in flight and nothing is queued."""
        return (not self.read_callbacks and not self.ready_callbacks)

    def wait_until_ready(self, callback=None):
        """Run *callback* now if idle, otherwise queue it."""
        if callback:
            if not self.ready():
                callback = stack_context.wrap(callback)
                self.ready_callbacks.append(callback)
            else:
                callback()

    def connect(self):
        """Open a TCP or UNIX-domain socket and wrap it in an IOStream.

        No-op if already connected.  Raises ConnectionError on socket
        errors; fires the 'on_connect' event on success.
        """
        if not self._stream:
            try:
                if self.unix_socket_path:
                    sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
                    sock.settimeout(self.timeout)
                    sock.connect(self.unix_socket_path)
                else:
                    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
                    sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
                    sock.settimeout(self.timeout)
                    sock.connect((self.host, self.port))
                self._stream = IOStream(sock, io_loop=self._io_loop)
                self._stream.set_close_callback(self.on_stream_close)
                # Fresh connection starts unauthenticated on db 0.
                self.info['db'] = 0
                self.info['pass'] = None
            except socket.error as e:
                raise ConnectionError(str(e))
            self.fire_event('on_connect')

    def on_stream_close(self):
        """Stream closed underneath us: flush all pending read callbacks."""
        if self._stream:
            self.disconnect()
            callbacks = self.read_callbacks
            self.read_callbacks = set()
            for callback in callbacks:
                callback()

    def disconnect(self):
        """Shut down and close the stream, swallowing all errors."""
        if self._stream:
            s = self._stream
            self._stream = None
            try:
                if s.socket:
                    s.socket.shutdown(socket.SHUT_RDWR)
                s.close()
            except:
                pass

    def fire_event(self, event):
        """Invoke the named method on the event-handler proxy, if any."""
        event_handler = self._event_handler
        if event_handler:
            try:
                getattr(event_handler, event)()
            except AttributeError:
                pass

    def write(self, data, callback=None):
        """Write *data* to the stream, tracking *callback* as an in-flight read.

        Raises ConnectionError if not connected or on write failure.
        NOTE(review): ``e.message`` is Python 2 only — on Python 3 this
        would raise AttributeError; confirm the target version.
        """
        if not self._stream:
            raise ConnectionError('Tried to write to '
                                  'non-existent connection')

        if callback:
            callback = stack_context.wrap(callback)
            _callback = lambda: callback(None)
            self.read_callbacks.add(_callback)
            cb = partial(self.read_callback, _callback)
        else:
            cb = None
        try:
            if PY3:
                data = bytes(data, encoding='utf-8')
            self._stream.write(data, callback=cb)
        except IOError as e:
            self.disconnect()
            raise ConnectionError(e.message)

    def read(self, length, callback=None):
        """Read exactly *length* bytes; fires 'on_disconnect' on IOError."""
        try:
            if not self._stream:
                self.disconnect()
                raise ConnectionError('Tried to read from '
                                      'non-existent connection')
            callback = stack_context.wrap(callback)
            self.read_callbacks.add(callback)
            self._stream.read_bytes(length,
                                    callback=partial(self.read_callback,
                                                     callback))
        except IOError:
            self.fire_event('on_disconnect')

    def read_callback(self, callback, *args, **kwargs):
        """Unregister *callback* from the in-flight set, then invoke it."""
        try:
            self.read_callbacks.remove(callback)
        except KeyError:
            # Already flushed by on_stream_close(); still deliver the data.
            pass
        callback(*args, **kwargs)

    def readline(self, callback=None):
        """Read up to the next CRLF; fires 'on_disconnect' on IOError."""
        try:
            if not self._stream:
                self.disconnect()
                raise ConnectionError('Tried to read from '
                                      'non-existent connection')
            callback = stack_context.wrap(callback)
            self.read_callbacks.add(callback)
            callback = partial(self.read_callback, callback)
            self._stream.read_until(CRLF, callback=callback)
        except IOError:
            self.fire_event('on_disconnect')

    def connected(self):
        """Return True if an IOStream is currently attached."""
        if self._stream:
            return True
        return False
Ejemplo n.º 50
0
class AsyncRedisClient(object):
    """An non-blocking Redis client.

    Example usage::

        import ioloop

        def handle_request(result):
            print 'Redis reply: %r' % result
            ioloop.IOLoop.instance().stop()

        redis_client = AsyncRedisClient(('127.0.0.1', 6379))
        redis_client.fetch(('set', 'foo', 'bar'), None)
        redis_client.fetch(('get', 'foo'), handle_request)
        ioloop.IOLoop.instance().start()

    This class implements a Redis client on top of Tornado's IOStreams.
    It does not currently implement all applicable parts of the Redis
    specification, but it does enough to work with major redis server APIs
    (mostly tested against the LIST/HASH/PUBSUB API so far).

    This class has not been tested extensively in production and
    should be considered somewhat experimental as of the release of
    tornado 1.2.  It is intended to become the default tornado
    AsyncRedisClient implementation.
    """
    def __init__(self, address, io_loop=None, socket_timeout=10):
        """Creates a AsyncRedisClient.

        address is the tuple of redis server address that can be connect by
        IOStream. It can be to ('127.0.0.1', 6379).
        """
        # NOTE(review): this code appears to target Python 2 — str
        # delimiters for read_until and single-char indexing of replies;
        # confirm before running under Python 3 (where bytes differ).
        self.address = address
        self.io_loop = io_loop or IOLoop.instance()
        # FIFO of user callbacks, one per fetch() issued.
        self._callback_queue = deque()
        self._callback = None
        # Accumulates the raw lines/chunks of the reply being parsed.
        self._read_buffer = None
        self._result_queue = deque()
        self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.socket.settimeout(socket_timeout)
        self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.stream = IOStream(self.socket, self.io_loop)
        self.stream.connect(self.address, self._wait_result)

    def close(self):
        """Destroys this redis client, freeing any file descriptors used.
        Not needed in normal use, but may be helpful in unittests that
        create and destroy redis clients.  No other methods may be called
        on the AsyncRedisClient after close().
        """
        self.stream.close()

    def fetch(self, request, callback):
        """Executes a request, calling callback with an redis `result`.

        The request should be a string tuple. like ('set', 'foo', 'bar')

        If an error occurs during the fetch, a `RedisError` exception will
        throw out. You can use try...except to catch the exception (if any)
        in the callback.
        """
        self._callback_queue.append(callback)
        self.stream.write(encode(request))

    def _wait_result(self):
        """Read a completed result data from the redis server."""
        # Fresh buffer per reply; the first RESP line determines the type.
        self._read_buffer = deque()
        self.stream.read_until('\r\n', self._on_read_first_line)

    def _maybe_callback(self):
        """Try call callback in _callback_queue when we read a redis result."""
        try:
            read_buffer = self._read_buffer
            callback = self._callback
            result_queue = self._result_queue
            callback_queue = self._callback_queue
            if result_queue:
                # Preserve FIFO ordering of results when some are queued.
                result_queue.append(read_buffer)
                read_buffer = result_queue.popleft()
            if callback_queue:
                callback = self._callback = callback_queue.popleft()
            if callback:
                callback(decode(read_buffer))
        except Exception:
            logging.error('Uncaught callback exception', exc_info=True)
            self.close()
            raise
        finally:
            # Always re-arm for the next server reply.
            self._wait_result()

    def _on_read_first_line(self, data):
        """Dispatch on the RESP type byte: :int +ok -err $bulk *multibulk."""
        self._read_buffer.append(data)
        c = data[0]
        if c in ':+-':
            # Integer, simple string, or error: single-line reply.
            self._maybe_callback()
        elif c == '$':
            if data[:3] == '$-1':
                # Null bulk reply.
                self._maybe_callback()
            else:
                # Bulk payload plus its trailing CRLF.
                length = int(data[1:])
                self.stream.read_bytes(length + 2, self._on_read_bulk_body)
        elif c == '*':
            if data[1] in '-0':
                # Null or empty multi-bulk reply.
                self._maybe_callback()
            else:
                self._multibulk_number = int(data[1:])
                self.stream.read_until('\r\n',
                                       self._on_read_multibulk_bulk_head)

    def _on_read_bulk_body(self, data):
        self._read_buffer.append(data)
        self._maybe_callback()

    def _on_read_multibulk_bulk_head(self, data):
        """Read one element header of a multi-bulk reply."""
        self._read_buffer.append(data)
        c = data[0]
        if c == '$':
            length = int(data[1:])
            self.stream.read_bytes(length + 2,
                                   self._on_read_multibulk_bulk_body)
        else:
            self._maybe_callback()

    def _on_read_multibulk_bulk_body(self, data):
        """Collect one element body; recurse until all elements are read."""
        self._read_buffer.append(data)
        self._multibulk_number -= 1
        if self._multibulk_number:
            self.stream.read_until('\r\n', self._on_read_multibulk_bulk_head)
        else:
            self._maybe_callback()
Ejemplo n.º 51
0
 def accept_callback(conn, address):
     """Wrap an accepted socket and read one HTTP request head.

     NOTE(review): ``self`` and ``write_response`` come from the enclosing
     scope of this excerpt — this def is not self-contained here.
     """
     stream = IOStream(conn, io_loop=self.io_loop)
     stream.read_until(b"\r\n\r\n",
                       functools.partial(write_response, stream))
Ejemplo n.º 52
0
class KeepAliveTest(AsyncHTTPTestCase):
    """Tests various scenarios for HTTP 1.1 keep-alive support.

    These tests don't use AsyncHTTPClient because we want to control
    connection reuse and closing.
    """
    def get_app(self):
        class HelloHandler(RequestHandler):
            def get(self):
                self.finish('Hello world')

        class LargeHandler(RequestHandler):
            def get(self):
                # 512KB should be bigger than the socket buffers so it will
                # be written out in chunks.
                self.write(''.join(chr(i % 256) * 1024 for i in range(512)))

        class FinishOnCloseHandler(RequestHandler):
            @asynchronous
            def get(self):
                self.flush()

            def on_connection_close(self):
                # This is not very realistic, but finishing the request
                # from the close callback has the right timing to mimic
                # some errors seen in the wild.
                self.finish('closed')

        return Application([('/', HelloHandler), ('/large', LargeHandler),
                            ('/finish_on_close', FinishOnCloseHandler)])

    def setUp(self):
        super(KeepAliveTest, self).setUp()
        # Expected HTTP version of responses; tests may override to 1.0.
        self.http_version = b'HTTP/1.1'

    def tearDown(self):
        # We just closed the client side of the socket; let the IOLoop run
        # once to make sure the server side got the message.
        self.io_loop.add_timeout(datetime.timedelta(seconds=0.001), self.stop)
        self.wait()

        if hasattr(self, 'stream'):
            self.stream.close()
        super(KeepAliveTest, self).tearDown()

    # The next few methods are a crude manual http client
    def connect(self):
        """Open a raw client connection to the test server."""
        self.stream = IOStream(socket.socket(), io_loop=self.io_loop)
        self.stream.connect(('localhost', self.get_http_port()), self.stop)
        self.wait()

    def read_headers(self):
        """Read and assert a 200 status line; return the parsed headers."""
        self.stream.read_until(b'\r\n', self.stop)
        first_line = self.wait()
        self.assertTrue(first_line.startswith(self.http_version + b' 200'),
                        first_line)
        self.stream.read_until(b'\r\n\r\n', self.stop)
        header_bytes = self.wait()
        headers = HTTPHeaders.parse(header_bytes.decode('latin1'))
        return headers

    def read_response(self):
        """Read headers and a Content-Length body; expect 'Hello world'."""
        self.headers = self.read_headers()
        self.stream.read_bytes(int(self.headers['Content-Length']), self.stop)
        body = self.wait()
        self.assertEqual(b'Hello world', body)

    def close(self):
        """Close the client stream and forget it (tearDown checks hasattr)."""
        self.stream.close()
        del self.stream

    def test_two_requests(self):
        # Two sequential requests must reuse the same connection.
        self.connect()
        self.stream.write(b'GET / HTTP/1.1\r\n\r\n')
        self.read_response()
        self.stream.write(b'GET / HTTP/1.1\r\n\r\n')
        self.read_response()
        self.close()

    def test_request_close(self):
        # 'Connection: close' must make the server close after responding.
        self.connect()
        self.stream.write(b'GET / HTTP/1.1\r\nConnection: close\r\n\r\n')
        self.read_response()
        self.stream.read_until_close(callback=self.stop)
        data = self.wait()
        self.assertTrue(not data)
        self.close()

    # keepalive is supported for http 1.0 too, but it's opt-in
    def test_http10(self):
        self.http_version = b'HTTP/1.0'
        self.connect()
        self.stream.write(b'GET / HTTP/1.0\r\n\r\n')
        self.read_response()
        self.stream.read_until_close(callback=self.stop)
        data = self.wait()
        self.assertTrue(not data)
        self.assertTrue('Connection' not in self.headers)
        self.close()

    def test_http10_keepalive(self):
        self.http_version = b'HTTP/1.0'
        self.connect()
        self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n')
        self.read_response()
        self.assertEqual(self.headers['Connection'], 'Keep-Alive')
        self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n')
        self.read_response()
        self.assertEqual(self.headers['Connection'], 'Keep-Alive')
        self.close()

    def test_pipelined_requests(self):
        # Both requests written before either response is read.
        self.connect()
        self.stream.write(b'GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\n')
        self.read_response()
        self.read_response()
        self.close()

    def test_pipelined_cancel(self):
        # Close with a pipelined response still unread.
        self.connect()
        self.stream.write(b'GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\n')
        # only read once
        self.read_response()
        self.close()

    def test_cancel_during_download(self):
        # Close mid-body while the server is still streaming chunks.
        self.connect()
        self.stream.write(b'GET /large HTTP/1.1\r\n\r\n')
        self.read_headers()
        self.stream.read_bytes(1024, self.stop)
        self.wait()
        self.close()

    def test_finish_while_closed(self):
        # Server finishes the request from its close callback.
        self.connect()
        self.stream.write(b'GET /finish_on_close HTTP/1.1\r\n\r\n')
        self.read_headers()
        self.close()
Ejemplo n.º 53
0
class Connection(object):
    """IOStream-backed TCP connection with write retry and read tracking.

    Holds only a weak proxy to the event handler so the connection does not
    keep it alive.  ``ConnectionError`` is project-defined elsewhere in
    this file.
    """
    def __init__(self, host, port, event_handler,
                 stop_after=None, io_loop=None):
        self.host = host
        self.port = port
        # Weak proxy: the handler's lifetime is owned by the caller.
        self._event_handler = weakref.proxy(event_handler)
        self.timeout = stop_after
        self._stream = None
        self._io_loop = io_loop
        # Total write attempts (initial try + one retry after reconnect).
        self.try_left = 2

        self.in_progress = False
        # Callbacks for reads currently in flight on the stream.
        self.read_callbacks = []
        # Commands waiting for the connection to become idle.
        self.ready_callbacks = deque()

    def __del__(self):
        self.disconnect()

    def __enter__(self):
        return self

    def __exit__(self, *args, **kwargs):
        if self.ready_callbacks:
            # Pop a SINGLE callback from the queue and execute it.
            # The next one will be executed from the code
            # invoked by the callback
            callback = self.ready_callbacks.popleft()
            callback()

    def ready(self):
        """True when no reads are in flight and nothing is queued."""
        return not self.read_callbacks and not self.ready_callbacks

    def wait_until_ready(self, callback=None):
        """Run *callback* now if idle, else queue it; returns self."""
        if callback:
            if not self.ready():
                self.ready_callbacks.append(callback)
            else:
                callback()
        return self

    def connect(self):
        """Open the socket (blocking connect) and wrap it in an IOStream.

        No-op if already connected.  Raises ConnectionError on socket
        errors; fires the 'on_connect' event on success.
        """
        if not self._stream:
            try:
                sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
                sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
                sock.settimeout(self.timeout)
                sock.connect((self.host, self.port))
                self._stream = IOStream(sock, io_loop=self._io_loop)
                self._stream.set_close_callback(self.on_stream_close)
                # NOTE(review): connected() only returns a bool, ignored
                # here — looks like a leftover; confirm it is intentional.
                self.connected()
            except socket.error as e:
                raise ConnectionError(str(e))
            self.fire_event('on_connect')

    def on_stream_close(self):
        """Stream closed underneath us: flush pending reads with None."""
        if self._stream:
            self._stream = None
            callbacks = self.read_callbacks
            self.read_callbacks = []
            for callback in callbacks:
                callback(None)

    def disconnect(self):
        """Shut down and close the stream, ignoring socket errors."""
        if self._stream:
            s = self._stream
            self._stream = None
            try:
                s.socket.shutdown(socket.SHUT_RDWR)
                s.close()
            except socket.error:
                pass

    def fire_event(self, event):
        """Invoke the named method on the event-handler proxy, if present."""
        if self._event_handler:
            try:
                getattr(self._event_handler, event)()
            except AttributeError:
                pass

    def write(self, data, try_left=None):
        """Write raw data, reconnecting and retrying once on IOError.

        Raises ConnectionError when the connection cannot be established
        or all retries are exhausted.
        """
        if try_left is None:
            try_left = self.try_left
        if not self._stream:
            self.connect()
            if not self._stream:
                raise ConnectionError('Tried to write to '
                                      'non-existent connection')

        if try_left > 0:
            try:
                self._stream.write(data)
            except IOError:
                self.disconnect()
                self.write(data, try_left - 1)
        else:
            raise ConnectionError('Tried to write to non-existent connection')

    def read(self, length, callback=None):
        """Read exactly *length* bytes; fires 'on_disconnect' on IOError."""
        try:
            if not self._stream:
                self.disconnect()
                raise ConnectionError('Tried to read from '
                                      'non-existent connection')
            self.read_callbacks.append(callback)
            self._stream.read_bytes(length,
                                    callback=partial(self.read_callback,
                                                     callback))
        except IOError:
            self.fire_event('on_disconnect')

    def read_callback(self, callback, *args, **kwargs):
        """Unregister *callback* from the in-flight list, then invoke it."""
        self.read_callbacks.remove(callback)
        callback(*args, **kwargs)

    def readline(self, callback=None):
        """Read up to the next CRLF; fires 'on_disconnect' on IOError.

        NOTE(review): str delimiter '\\r\\n' (not bytes) suggests Python 2;
        confirm before running under Python 3.
        """
        try:
            if not self._stream:
                self.disconnect()
                raise ConnectionError('Tried to read from '
                                      'non-existent connection')
            self.read_callbacks.append(callback)
            self._stream.read_until('\r\n',
                                    callback=partial(self.read_callback,
                                                     callback))
        except IOError:
            self.fire_event('on_disconnect')

    def connected(self):
        """Return True if an IOStream is currently attached."""
        if self._stream:
            return True
        return False
Ejemplo n.º 54
0
class IOBot(object):
    """A minimal IRC bot built on Tornado's IOStream.

    Parses a small subset of the IRC protocol, dispatches user commands
    (lines starting with ``char``) to registered plugins, and wires up a
    Store plus an API server for plugin state.
    """

    def __init__(
        self,
        host,
        nick='hircules',
        port=6667,
        char='@',
        owner='owner',
        initial_chans=None,
        on_ready=None,
    ):
        """
        create an irc bot instance.
        @params
        initial_chans: None or list of strings representing channels to join
        """
        # Move state variables into whatever data store used
        # and access from that.
        self.nick = nick
        self.chans = set()  # chans we're a member of
        self.owner = owner
        self.host = host
        self.port = port
        self.char = char  # command prefix, e.g. '@cmd'
        self._plugins = dict()       # command name -> plugin instance
        self._core_plugins = dict()  # plugin NAME -> core plugin instance
        self._connected = False
        self._initial_chans = initial_chans
        self._on_ready = on_ready
        self._registered = []        # names of user plugins loaded so far
        self._initialized = False
        # used for parsing out nicks later, just wanted to compile it once
        # server protocol gorp
        self._irc_proto = {
            'PRIVMSG': IrcProtoCmd(self._p_privmsg),
            'PING': IrcProtoCmd(self._p_ping),
            'JOIN': IrcProtoCmd(self._p_afterjoin),
            '401': IrcProtoCmd(self._p_nochan),
        }
        # build our user command list
        self.cmds = dict()

        # initialize core plugins
        self._initialize_core_plugins()

        # initialize the Store
        self.store = Store()

        # initialize API server
        self._api = APIServer(self.store)

        # finally, connect.
        self._connect()

        #
        self._initialized = True

    def hook(self, cmd, hook_f):
        """
        allows easy hooking of any raw irc protocol statement.  These will be
        executed after the initial protocol parsing occurs.  Plugins can use this
        to extend their reach lower into the protocol.
        """
        assert (cmd in self._irc_proto)
        self._irc_proto[cmd].hooks.add(hook_f)

    def joinchan(self, chan):
        """Send a JOIN for *chan*."""
        self._stream.write("JOIN :%s\r\n" % chan)

    def say(self, chan, msg):
        """
        sends a message to a chan or user
        """
        self._stream.write("PRIVMSG {} :{}\r\n".format(chan, msg))

    def register(self, plugins):
        """
        accepts an instance of Plugin to add to the callback chain
        """

        for p in plugins:
            # update to support custom paths?
            p_module = __import__('iobot.plugins.%s.plugin' % p,
                                  fromlist=['Plugin'])
            p_obj = p_module.Plugin()

            cmds = self._get_commands_from_plugin(p_obj)
            self._add_plugin_commands(cmds, p_obj)

            # append the module as "registered"
            if p not in self._registered:
                self._registered.append(p)

    def _initialize_core_plugins(self):
        """Instantiate each bundled core plugin and register its commands."""
        for _cp in core_plugins:
            cp_obj = _cp()
            cmds = self._get_commands_from_plugin(cp_obj)
            self._add_plugin_commands(cmds, cp_obj)
            self._core_plugins[cp_obj.NAME] = cp_obj

    def _add_plugin_commands(self, cmds, obj):
        """Map each command name to *obj*, rejecting duplicates at startup."""
        # don't allow other people to stomp on existing plugins during the
        # initial load; after initialization, re-registration overwrites.
        for cmd in cmds:
            if cmd in self._plugins and not self._initialized:
                raise ValueError('command %s already exists' % cmd)
            self._plugins[cmd] = obj

    def _get_commands_from_plugin(self, obj):
        """Return the names of callable attributes tagged with ``cmd``."""
        cmds = []
        for method in dir(obj):
            if callable(getattr(obj, method)) \
               and hasattr(getattr(obj, method), 'cmd'):
                cmds.append(method)
        return cmds

    def _connect(self):
        """Open the TCP connection to the IRC server asynchronously."""
        _sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        self._stream = IOStream(_sock)
        self._stream.connect((self.host, self.port), self._after_connect)

    def _after_connect(self):
        """Register with the server, join initial channels, start reading."""
        self._stream.write("NICK %s\r\n" % self.nick)
        self._stream.write("USER %s 0 * :%s\r\n" % ("iobot", "iobot"))

        if self._initial_chans:
            for c in self._initial_chans:
                self.joinchan(c)
            del self._initial_chans
        if self._on_ready:
            self._on_ready()
        self._next()

    def _parse_line(self, line):
        """Parse a raw server line into an IrcObj and run protocol hooks."""
        irc = IrcObj(line, self)
        if irc.server_cmd in self._irc_proto:
            self._irc_proto[irc.server_cmd](irc, line)
        return irc

    def _p_ping(self, irc, line):
        """Answer a server PING to keep the connection alive."""
        self._stream.write("PONG %s\r\n" % line[1])

    def _p_privmsg(self, irc, line):
        # :[email protected] PRIVMSG #xx :hi
        toks = line[1:].split(':')[0].split()
        irc.chan = toks[-1]  # should be last token after last :
        irc.text = line[line.find(':', 1) + 1:].strip()
        # Lines starting with our command char become bot commands.
        if irc.text and irc.text.startswith(self.char):
            text_split = irc.text.split()
            irc.command = text_split[0][1:]
            irc.command_args = ' '.join(text_split[1:])

    def _p_afterjoin(self, irc, line):
        """Track channels we have successfully joined."""
        toks = line.strip().split(':')
        if irc.nick != self.nick:
            return  # we don't care right now if others join
        irc.chan = toks[-1]  # should be last token after last :
        self.chans.add(irc.chan)

    def _p_nochan(self, irc, line):
        # :senor.crunchybueno.com 401 nodnc  #xx :No such nick/channel
        toks = line.strip().split(':')
        irc.chan = toks[1].strip().split()[-1]
        if irc.chan in self.chans:
            self.chans.remove(irc.chan)

    def _process_plugins(self, irc):
        """Dispatch a parsed IrcObj to the plugin owning its command.

        Bug fix: the old version wrapped ``dict.get`` (which never raises
        KeyError) in a pointless try/except, and its bare ``except``
        referenced ``plugin_method`` which could be unbound.
        """
        plugin = self._plugins.get(irc.command) if irc.command else None
        if not plugin:
            return
        plugin_method = getattr(plugin, irc.command)
        try:
            plugin_method(irc)
        except Exception:
            # On any plugin failure, reply with the command's usage string.
            doc = "usage: %s %s" % (irc.command, plugin_method.__doc__)
            irc.say(doc)

    def _next(self):
        # go back on the loop looking for the next line of input
        self._stream.read_until('\r\n', self._incoming)

    def _incoming(self, line):
        """Handle one incoming line, then queue the next read."""
        self._process_plugins(self._parse_line(line))
        self._next()

    def reload_plugin(self, plugin):
        """Reload a plugin module by name and re-register its commands.

        Bug fix: the non-core path previously called register() twice —
        once *before* the module was reloaded (re-registering the stale
        code) and once after.
        """
        core = False
        for _cp in core_plugins:
            if _cp.NAME == plugin:
                reload(_cp)
                core = True
                break
        if not core:
            p_module = __import__('iobot.plugins.%s.plugin' % plugin,
                                  fromlist=['Plugin'])
            reload(p_module)
            self.register([plugin])
Ejemplo n.º 55
0
class TestIOStreamStartTLS(AsyncTestCase):
    """Exercises IOStream.start_tls by scripting an SMTP-style STARTTLS
    exchange between a client stream and the accepted server stream."""

    def setUp(self):
        # Bind a listener, connect a client IOStream to it, and wait
        # until accept() below has produced the matching server stream.
        try:
            super(TestIOStreamStartTLS, self).setUp()
            self.listener, self.port = bind_unused_port()
            self.server_stream = None
            self.server_accepted = Future()
            netutil.add_accept_handler(self.listener, self.accept)
            self.client_stream = IOStream(socket.socket())
            self.io_loop.add_future(
                self.client_stream.connect(('127.0.0.1', self.port)),
                self.stop)
            self.wait()
            self.io_loop.add_future(self.server_accepted, self.stop)
            self.wait()
        except Exception as e:
            # NOTE(review): printing before re-raising helps diagnose
            # setUp failures the test runner would otherwise obscure.
            print(e)
            raise

    def tearDown(self):
        # Streams may be None mid-handshake (start_tls consumes them),
        # so close only what is still held before closing the listener.
        if self.server_stream is not None:
            self.server_stream.close()
        if self.client_stream is not None:
            self.client_stream.close()
        self.listener.close()
        super(TestIOStreamStartTLS, self).tearDown()

    def accept(self, connection, address):
        # Accept handler: exactly one connection is expected per test.
        if self.server_stream is not None:
            self.fail("should only get one connection")
        self.server_stream = IOStream(connection)
        self.server_accepted.set_result(None)

    @gen.coroutine
    def client_send_line(self, line):
        # Write a line client->server and assert it arrives intact.
        self.client_stream.write(line)
        recv_line = yield self.server_stream.read_until(b"\r\n")
        self.assertEqual(line, recv_line)

    @gen.coroutine
    def server_send_line(self, line):
        # Write a line server->client and assert it arrives intact.
        self.server_stream.write(line)
        recv_line = yield self.client_stream.read_until(b"\r\n")
        self.assertEqual(line, recv_line)

    def client_start_tls(self, ssl_options=None):
        # start_tls takes ownership of the stream, so drop our reference
        # before handing it over (same pattern on the server side).
        client_stream = self.client_stream
        self.client_stream = None
        return client_stream.start_tls(False, ssl_options)

    def server_start_tls(self, ssl_options=None):
        server_stream = self.server_stream
        self.server_stream = None
        return server_stream.start_tls(True, ssl_options)

    @gen_test
    def test_start_tls_smtp(self):
        # This flow is simplified from RFC 3207 section 5.
        # We don't really need all of this, but it helps to make sure
        # that after realistic back-and-forth traffic the buffers end up
        # in a sane state.
        yield self.server_send_line(b"220 mail.example.com ready\r\n")
        yield self.client_send_line(b"EHLO mail.example.com\r\n")
        yield self.server_send_line(b"250-mail.example.com welcome\r\n")
        yield self.server_send_line(b"250 STARTTLS\r\n")
        yield self.client_send_line(b"STARTTLS\r\n")
        yield self.server_send_line(b"220 Go ahead\r\n")
        # Both sides must start the handshake before either can finish;
        # kick off both futures, then wait for the wrapped SSL streams.
        client_future = self.client_start_tls()
        server_future = self.server_start_tls(_server_ssl_options())
        self.client_stream = yield client_future
        self.server_stream = yield server_future
        self.assertTrue(isinstance(self.client_stream, SSLIOStream))
        self.assertTrue(isinstance(self.server_stream, SSLIOStream))
        # Traffic still flows cleanly over the upgraded connection.
        yield self.client_send_line(b"EHLO mail.example.com\r\n")
        yield self.server_send_line(b"250 mail.example.com welcome\r\n")

    @gen_test
    def test_handshake_fail(self):
        # The client demands certificate validation against the public
        # CA bundle; the test server's self-signed cert must be rejected.
        self.server_start_tls(_server_ssl_options())
        client_future = self.client_start_tls(
            dict(cert_reqs=ssl.CERT_REQUIRED, ca_certs=certifi.where()))
        with ExpectLog(gen_log, "SSL Error"):
            with self.assertRaises(ssl.SSLError):
                yield client_future
Ejemplo n.º 56
0
class KeepAliveTest(AsyncHTTPTestCase):
    """Tests various scenarios for HTTP 1.1 keep-alive support.

    These tests don't use AsyncHTTPClient because we want to control
    connection reuse and closing.
    """
    def get_app(self):
        # Three handlers: a trivial echo, one whose body exceeds the
        # socket buffers, and one that only finishes when the client
        # disconnects.
        class HelloHandler(RequestHandler):
            def get(self):
                self.finish('Hello world')

            def post(self):
                self.finish('Hello world')

        class LargeHandler(RequestHandler):
            def get(self):
                # 512KB should be bigger than the socket buffers so it will
                # be written out in chunks.
                self.write(''.join(chr(i % 256) * 1024 for i in range(512)))

        class FinishOnCloseHandler(RequestHandler):
            @asynchronous
            def get(self):
                self.flush()

            def on_connection_close(self):
                # This is not very realistic, but finishing the request
                # from the close callback has the right timing to mimic
                # some errors seen in the wild.
                self.finish('closed')

        return Application([('/', HelloHandler),
                            ('/large', LargeHandler),
                            ('/finish_on_close', FinishOnCloseHandler)])

    def setUp(self):
        super(KeepAliveTest, self).setUp()
        # Default protocol version used by the manual client below;
        # individual tests override this for HTTP/1.0 scenarios.
        self.http_version = b'HTTP/1.1'

    def tearDown(self):
        # We just closed the client side of the socket; let the IOLoop run
        # once to make sure the server side got the message.
        self.io_loop.add_timeout(datetime.timedelta(seconds=0.001), self.stop)
        self.wait()

        if hasattr(self, 'stream'):
            self.stream.close()
        super(KeepAliveTest, self).tearDown()

    # The next few methods are a crude manual http client
    def connect(self):
        # Open a raw IOStream to the test server and block until connected.
        self.stream = IOStream(socket.socket())
        self.stream.connect(('127.0.0.1', self.get_http_port()), self.stop)
        self.wait()

    def read_headers(self):
        # Read the status line (asserting a 200) plus the header block,
        # and return the parsed HTTPHeaders.
        self.stream.read_until(b'\r\n', self.stop)
        first_line = self.wait()
        self.assertTrue(first_line.startswith(b'HTTP/1.1 200'), first_line)
        self.stream.read_until(b'\r\n\r\n', self.stop)
        header_bytes = self.wait()
        headers = HTTPHeaders.parse(header_bytes.decode('latin1'))
        return headers

    def read_response(self):
        # Read headers + a Content-Length body and assert the hello body.
        # Headers are stashed on self for the caller to inspect.
        self.headers = self.read_headers()
        self.stream.read_bytes(int(self.headers['Content-Length']), self.stop)
        body = self.wait()
        self.assertEqual(b'Hello world', body)

    def close(self):
        # Close and drop the stream so tearDown's hasattr check skips it.
        self.stream.close()
        del self.stream

    def test_two_requests(self):
        # Two sequential requests reuse the same connection under 1.1.
        self.connect()
        self.stream.write(b'GET / HTTP/1.1\r\n\r\n')
        self.read_response()
        self.stream.write(b'GET / HTTP/1.1\r\n\r\n')
        self.read_response()
        self.close()

    def test_request_close(self):
        # "Connection: close" makes the server close after responding.
        self.connect()
        self.stream.write(b'GET / HTTP/1.1\r\nConnection: close\r\n\r\n')
        self.read_response()
        self.stream.read_until_close(callback=self.stop)
        data = self.wait()
        self.assertTrue(not data)
        self.assertEqual(self.headers['Connection'], 'close')
        self.close()

    # keepalive is supported for http 1.0 too, but it's opt-in
    def test_http10(self):
        # Plain HTTP/1.0 without the opt-in header: server must close.
        self.http_version = b'HTTP/1.0'
        self.connect()
        self.stream.write(b'GET / HTTP/1.0\r\n\r\n')
        self.read_response()
        self.stream.read_until_close(callback=self.stop)
        data = self.wait()
        self.assertTrue(not data)
        self.assertTrue('Connection' not in self.headers)
        self.close()

    def test_http10_keepalive(self):
        # HTTP/1.0 with "Connection: keep-alive" keeps the socket open.
        self.http_version = b'HTTP/1.0'
        self.connect()
        self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n')
        self.read_response()
        self.assertEqual(self.headers['Connection'], 'Keep-Alive')
        self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n')
        self.read_response()
        self.assertEqual(self.headers['Connection'], 'Keep-Alive')
        self.close()

    def test_http10_keepalive_extra_crlf(self):
        # A stray CRLF after the first request must not break keep-alive.
        self.http_version = b'HTTP/1.0'
        self.connect()
        self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n\r\n')
        self.read_response()
        self.assertEqual(self.headers['Connection'], 'Keep-Alive')
        self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n')
        self.read_response()
        self.assertEqual(self.headers['Connection'], 'Keep-Alive')
        self.close()

    def test_pipelined_requests(self):
        # Two requests written back-to-back before reading any response.
        self.connect()
        self.stream.write(b'GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\n')
        self.read_response()
        self.read_response()
        self.close()

    def test_pipelined_cancel(self):
        # Closing with a pipelined response still pending must be safe.
        self.connect()
        self.stream.write(b'GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\n')
        # only read once
        self.read_response()
        self.close()

    def test_cancel_during_download(self):
        # Close mid-body while the server is still flushing chunks.
        self.connect()
        self.stream.write(b'GET /large HTTP/1.1\r\n\r\n')
        self.read_headers()
        self.stream.read_bytes(1024, self.stop)
        self.wait()
        self.close()

    def test_finish_while_closed(self):
        # Handler finishes from its on_connection_close callback after
        # we disconnect; see FinishOnCloseHandler above.
        self.connect()
        self.stream.write(b'GET /finish_on_close HTTP/1.1\r\n\r\n')
        self.read_headers()
        self.close()

    def test_keepalive_chunked(self):
        # A chunked zero-length POST body must not poison keep-alive.
        self.http_version = b'HTTP/1.0'
        self.connect()
        self.stream.write(b'POST / HTTP/1.0\r\n'
                          b'Connection: keep-alive\r\n'
                          b'Transfer-Encoding: chunked\r\n'
                          b'\r\n'
                          b'0\r\n'
                          b'\r\n')
        self.read_response()
        self.assertEqual(self.headers['Connection'], 'Keep-Alive')
        self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n')
        self.read_response()
        self.assertEqual(self.headers['Connection'], 'Keep-Alive')
        self.close()
Ejemplo n.º 57
0
 def accept_callback(conn, address):
     """Accept handler: wrap the new connection in an IOStream and hand
     off to write_response once the request headers (terminated by a
     blank line) have been read.

     NOTE(review): relies on `self` and `write_response` from an
     enclosing scope not visible here -- confirm against the caller.
     """
     # fake an HTTP server using chunked encoding where the final chunks
     # and connection close all happen at once
     stream = IOStream(conn, io_loop=self.io_loop)
     # b() is the old Tornado byte-literal helper
     stream.read_until(b("\r\n\r\n"),
                       functools.partial(write_response, stream))
Ejemplo n.º 58
0
class _HTTPConnection(object):
    _SUPPORTED_METHODS = set(["GET", "HEAD", "POST", "PUT", "DELETE"])

    def __init__(self, io_loop, client, request, callback):
        """Open a (possibly SSL) connection for `request` and start the
        HTTP exchange; `callback` eventually receives the HTTPResponse.
        """
        self.start_time = time.time()
        self.io_loop = io_loop
        self.client = client
        self.request = request
        self.callback = callback
        self.code = None
        self.headers = None
        self.chunks = None
        self._decompressor = None
        # Timeout handle returned by IOLoop.add_timeout
        self._timeout = None
        with stack_context.StackContext(self.cleanup):
            parsed = urlparse.urlsplit(self.request.url)
            host = parsed.hostname
            if parsed.port is None:
                # default ports by scheme
                port = 443 if parsed.scheme == "https" else 80
            else:
                port = parsed.port
            if self.client.hostname_mapping is not None:
                host = self.client.hostname_mapping.get(host, host)

            if parsed.scheme == "https":
                ssl_options = {}
                if request.validate_cert:
                    ssl_options["cert_reqs"] = ssl.CERT_REQUIRED
                if request.ca_certs is not None:
                    ssl_options["ca_certs"] = request.ca_certs
                else:
                    ssl_options["ca_certs"] = _DEFAULT_CA_CERTS
                self.stream = SSLIOStream(socket.socket(),
                                          io_loop=self.io_loop,
                                          ssl_options=ssl_options)
            else:
                self.stream = IOStream(socket.socket(),
                                       io_loop=self.io_loop)
            timeout = min(request.connect_timeout, request.request_timeout)
            if timeout:
                # BUG FIX: the handle must be kept in self._timeout --
                # _on_connect cancels self._timeout, but the old code
                # stashed the handle in self._connect_timeout, so the
                # connect timeout was never cancelled and would fire
                # (closing the stream) even after a successful connect.
                self._timeout = self.io_loop.add_timeout(
                    self.start_time + timeout,
                    self._on_timeout)
            self.stream.set_close_callback(self._on_close)
            self.stream.connect((host, port),
                                functools.partial(self._on_connect, parsed))

    def _on_timeout(self):
        """Deadline reached: deliver a 599 to the pending callback (at
        most once) and tear down the stream."""
        self._timeout = None
        # grab-and-clear so the callback can never fire twice
        callback, self.callback = self.callback, None
        if callback is not None:
            callback(HTTPResponse(self.request, 599,
                                  error=HTTPError(599, "Timeout")))
        self.stream.close()

    def _on_connect(self, parsed):
        """Connection established: cancel the connect timeout, arm the
        request timeout, validate the request, and write it out."""
        if self._timeout is not None:
            # BUG FIX: timeout handles are cancelled with remove_timeout;
            # remove_callback is the wrong IOLoop API for add_timeout
            # handles and would leave the timeout armed.
            self.io_loop.remove_timeout(self._timeout)
            self._timeout = None
        if self.request.request_timeout:
            self._timeout = self.io_loop.add_timeout(
                self.start_time + self.request.request_timeout,
                self._on_timeout)
        if (self.request.validate_cert and
            isinstance(self.stream, SSLIOStream)):
            # raises if the server cert does not match the requested host
            match_hostname(self.stream.socket.getpeercert(),
                           parsed.hostname)
        if (self.request.method not in self._SUPPORTED_METHODS and
            not self.request.allow_nonstandard_methods):
            raise KeyError("unknown method %s" % self.request.method)
        for key in ('network_interface',
                    'proxy_host', 'proxy_port',
                    'proxy_username', 'proxy_password'):
            if getattr(self.request, key, None):
                raise NotImplementedError('%s not supported' % key)
        if "Host" not in self.request.headers:
            self.request.headers["Host"] = parsed.netloc
        # Basic auth: URL credentials win over explicit request fields.
        username, password = None, None
        if parsed.username is not None:
            username, password = parsed.username, parsed.password
        elif self.request.auth_username is not None:
            username = self.request.auth_username
            password = self.request.auth_password
        if username is not None:
            import base64  # local import keeps this fix self-contained
            auth = "%s:%s" % (username, password)
            # BUG FIX: str.encode("base64") appends a trailing newline
            # (and wraps long values across lines), producing a malformed
            # header; b64encode emits one clean token.
            self.request.headers["Authorization"] = ("Basic %s" %
                                                     base64.b64encode(auth))
        if self.request.user_agent:
            self.request.headers["User-Agent"] = self.request.user_agent
        has_body = self.request.method in ("POST", "PUT")
        if has_body:
            assert self.request.body is not None
            self.request.headers["Content-Length"] = len(
                self.request.body)
        else:
            assert self.request.body is None
        if (self.request.method == "POST" and
            "Content-Type" not in self.request.headers):
            self.request.headers["Content-Type"] = "application/x-www-form-urlencoded"
        if self.request.use_gzip:
            self.request.headers["Accept-Encoding"] = "gzip"
        req_path = ((parsed.path or '/') +
                (('?' + parsed.query) if parsed.query else ''))
        request_lines = ["%s %s HTTP/1.1" % (self.request.method,
                                             req_path)]
        for k, v in self.request.headers.get_all():
            request_lines.append("%s: %s" % (k, v))
        self.stream.write("\r\n".join(request_lines) + "\r\n\r\n")
        if has_body:
            self.stream.write(self.request.body)
        # hand off to the header parser once the full header block arrives
        self.stream.read_until("\r\n\r\n", self._on_headers)

    @contextlib.contextmanager
    def cleanup(self):
        """StackContext guard: any exception escaping the request flow is
        logged with its traceback and converted into a 599 HTTPResponse
        for the pending callback, which fires at most once."""
        try:
            yield
        except Exception, e:
            # Python 2 except syntax; the original error object is
            # attached to the synthetic 599 response below.
            logging.warning("uncaught exception", exc_info=True)
            if self.callback is not None:
                callback = self.callback
                self.callback = None
                callback(HTTPResponse(self.request, 599, error=e))