예제 #1
0
파일: tirsk.py 프로젝트: nailor/tirsk
class IRCStream(object):
    """
    A connection to an IRC server utilizing IOStream.

    The server is named by an ``irc://host[:port]`` URL; when the URL
    carries no port, ``DEFAULT_PORT`` is used.
    """

    DEFAULT_PORT = 6667

    def __init__(self, nick, url, io_loop=None):
        """
        Parse ``url`` and store the connection parameters.

        :param nick: nickname sent with the NICK command after connecting.
        :param url: ``irc://`` URL of the server.
        :param io_loop: loop passed to the IOStream; ``IOLoop.instance()``
            is used when omitted.
        :raises ValueError: if the URL scheme is not ``irc`` or the port is
            not numeric.
        """
        self.nick = nick
        self.url = url
        self.io_loop = io_loop or IOLoop.instance()

        parsed = urlparse.urlsplit(self.url)
        # Explicit check instead of ``assert``: asserts vanish under -O.
        if parsed.scheme != 'irc':
            raise ValueError('expected an irc:// URL, got %r' % (self.url,))
        # ``hostname``/``port`` cope with bracketed IPv6 literals and
        # ``user@host`` netlocs, which a plain split on ':' mis-parses.
        # Note that ``hostname`` is lower-cased by the URL parser.
        self.host = parsed.hostname or ''
        self.port = parsed.port or self.DEFAULT_PORT

    def connect(self, callback):
        """Open the stream; ``callback(True)`` runs once NICK has been sent."""
        self.stream = IOStream(socket.socket(), io_loop=self.io_loop)
        self.stream.connect((self.host, self.port),
                            functools.partial(self._on_connect, callback))

    def _on_connect(self, callback):
        """Send the NICK command, then report success to ``callback``."""
        self.stream.write('NICK %s\r\n' % self.nick)
        callback(True)
예제 #2
0
파일: memnado.py 프로젝트: clofresh/memnado
class Memnado(object):
    """
    Small memcached text-protocol client on top of an IOStream.

    Keys and values are base64-encoded before being sent, and values are
    base64-decoded before being handed to the caller.
    """

    def __init__(self, host, port):
        """Connect (blocking) to ``host:port`` and wrap the socket in an IOStream."""
        self.host = host
        self.port = port

        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        try:
            s.connect((self.host, self.port))
        except socket.error:
            # Do not leak the descriptor when the server is unreachable.
            s.close()
            raise
        self.stream = IOStream(s)

    def set(self, key, value, callback, expiry=0):
        """
        Store ``value`` under ``key``.

        :param callback: called with the server's reply line.
        :param expiry: expiry value placed in the ``set`` command.
        """
        key = b64encode(key)
        value = b64encode(value)
        content_length = len(value)
        self.stream.write("set %s 1 %s %s\r\n%s\r\n" % (key, expiry,
                        content_length, value))
        self.stream.read_until("\r\n", callback)

    def get(self, key, callback):
        """
        Fetch ``key``.

        ``callback`` receives the decoded value, or ``None`` when the server
        answers ``END`` immediately (no such key).  The callback runs only
        after the trailing ``END`` marker has been consumed, so the stream is
        idle again when the caller regains control.
        """
        key = b64encode(key)
        stream = self.stream

        def on_header(data):
            if data[0:3] == 'END':  # key is empty
                callback(None)
                return
            _status, _key, _flags, content_length = data.strip().split(' ')

            def on_value(value):
                # Reads are chained: an IOStream supports one pending read at
                # a time, so the trailer is requested only after the value
                # has been delivered.
                stream.read_until(
                    "\r\nEND\r\n",
                    lambda _trailer: callback(b64decode(value)))

            stream.read_bytes(int(content_length), on_value)

        stream.write("get %s\r\n" % key)
        stream.read_until("\r\n", on_header)
예제 #3
0
class HTTPServerRawTest(AsyncHTTPTestCase):
    # Talks to the test HTTP server over a bare IOStream so that malformed
    # requests can be sent byte-for-byte.

    def get_app(self):
        routes = [("/echo", EchoHandler)]
        return Application(routes)

    def setUp(self):
        super(HTTPServerRawTest, self).setUp()
        self.stream = IOStream(socket.socket())
        target = ("localhost", self.get_http_port())
        self.stream.connect(target, self.stop)
        self.wait()

    def tearDown(self):
        self.stream.close()
        super(HTTPServerRawTest, self).tearDown()

    def _let_loop_run(self, seconds):
        # Schedule self.stop after `seconds` and spin the loop until it fires.
        delay = datetime.timedelta(seconds=seconds)
        self.io_loop.add_timeout(delay, self.stop)
        self.wait()

    def test_empty_request(self):
        self.stream.close()
        self._let_loop_run(0.001)

    def test_malformed_first_line(self):
        with ExpectLog(gen_log, ".*Malformed HTTP request line"):
            self.stream.write(b"asdf\r\n\r\n")
            # TODO: need an async version of ExpectLog so we don't need
            # hard-coded timeouts here.
            self._let_loop_run(0.01)

    def test_malformed_headers(self):
        with ExpectLog(gen_log, ".*Malformed HTTP headers"):
            self.stream.write(b"GET / HTTP/1.0\r\nasdf\r\n\r\n")
            self._let_loop_run(0.01)
예제 #4
0
class _UDPConnection(object):
    """One request/response exchange over a datagram socket wrapped in an IOStream."""

    def __init__(self, io_loop, client, request, release_callback, final_callback, max_buffer_size):
        """Resolve the request's address, open the stream and start connecting."""
        self.start_time = time.time()
        self.io_loop = io_loop
        self.client = client
        self.request = request
        self.release_callback = release_callback
        self.final_callback = final_callback

        candidates = socket.getaddrinfo(request.address, request.port,
                                        socket.AF_INET, socket.SOCK_DGRAM, 0, 0)
        # Only the first resolved address is used.
        family, kind, protocol, _canonname, destination = candidates[0]
        datagram_socket = socket.socket(family, kind, protocol)
        self.stream = IOStream(datagram_socket, io_loop=self.io_loop,
                               max_buffer_size=max_buffer_size)

        self.stream.connect(destination, self._on_connect)

    def _on_connect(self):
        """Send the request payload and read until the ``}}`` terminator."""
        self.stream.write(self.request.data)
        self.stream.read_until('}}', self._on_response)

    def _take(self, attribute):
        """Return the callback stored in ``attribute`` and clear the slot."""
        stored = getattr(self, attribute)
        if stored is not None:
            setattr(self, attribute, None)
        return stored

    def _on_response(self, data):
        """Release the slot, close the stream, then hand ``data`` to the final callback."""
        release = self._take('release_callback')
        if release is not None:
            release()
        self.stream.close()
        deliver = self._take('final_callback')
        if deliver is not None:
            deliver(data)
예제 #5
0
class ForwardConnection(object):
    """Relays bytes between an accepted stream and a newly opened remote stream."""

    def __init__(self, remote_address, stream, address):
        """Remember both endpoints and start connecting to ``remote_address``."""
        self.remote_address = remote_address
        self.stream = stream
        self.address = address
        self.remote_stream = IOStream(socket.socket())
        self.remote_stream.connect(self.remote_address, self._on_remote_connected)

    def _on_remote_connected(self):
        """Once the remote side is up, stream each side's reads into the other's write."""
        logging.info("forward %r to %r", self.address, self.remote_address)
        self.remote_stream.read_until_close(self._on_remote_read_close, self.stream.write)
        self.stream.read_until_close(self._on_read_close, self.remote_stream.write)

    @staticmethod
    def _finish(target, data):
        """Close ``target``; if it is still writing, queue ``data`` and close afterwards."""
        if not target.writing():
            target.close()
            return
        target.write(data, target.close)

    def _on_remote_read_close(self, data):
        """The remote side finished: wind down the local stream."""
        self._finish(self.stream, data)

    def _on_read_close(self, data):
        """The local side finished: wind down the remote stream."""
        self._finish(self.remote_stream, data)
예제 #6
0
 def test_100_continue(self):
     # Drive a 100-continue exchange manually: with "Expect: 100-continue"
     # the server answers 100 after the headers and sends the real
     # response only once the body has arrived.
     conn = IOStream(socket.socket())
     conn.connect(("127.0.0.1", self.get_http_port()), callback=self.stop)
     self.wait()
     request_head = b"\r\n".join([b"POST /hello HTTP/1.1",
                                  b"Content-Length: 1024",
                                  b"Expect: 100-continue",
                                  b"Connection: close",
                                  b"\r\n"])
     conn.write(request_head, callback=self.stop)
     self.wait()
     conn.read_until(b"\r\n\r\n", self.stop)
     interim = self.wait()
     self.assertTrue(interim.startswith(b"HTTP/1.1 100 "), interim)
     conn.write(b"a" * 1024)
     conn.read_until(b"\r\n", self.stop)
     status_line = self.wait()
     self.assertTrue(status_line.startswith(b"HTTP/1.1 200"), status_line)
     conn.read_until(b"\r\n\r\n", self.stop)
     raw_headers = self.wait()
     header_text = native_str(raw_headers.decode('latin1'))
     parsed_headers = HTTPHeaders.parse(header_text)
     body_length = int(parsed_headers["Content-Length"])
     conn.read_bytes(body_length, self.stop)
     payload = self.wait()
     self.assertEqual(payload, b"Got 1024 bytes in POST")
     conn.close()
예제 #7
0
파일: rproxy.py 프로젝트: dnslj/python-labs
class ForwardConnection(object):
    """Sends stored request headers to a remote server and relays its full reply."""

    def __init__(self, remote_address, stream, address, headers):
        """Store the endpoints and headers, then start connecting to the remote side."""
        self.remote_address = remote_address
        self.stream = stream
        self.address = address
        self.headers = headers
        self.remote_stream = IOStream(socket.socket())
        self.remote_stream.connect(self.remote_address, self._on_remote_connected)
        self.remote_stream.set_close_callback(self._on_close)

    def _on_remote_connected(self):
        """Connected: push the stored headers to the remote server."""
        logging.info('forward %r to %r', self.address, self.remote_address)
        upstream = self.remote_stream
        upstream.write(self.headers, self._on_remote_write_complete)

    def _on_remote_write_complete(self):
        """Headers sent: read from the remote side until it closes."""
        logging.info('send request to %s', self.remote_address)
        upstream = self.remote_stream
        upstream.read_until_close(self._on_remote_read_close)

    def _on_remote_read_close(self, data):
        """Write the remote reply to the local stream and close it when done."""
        downstream = self.stream
        downstream.write(data, downstream.close)

    def _on_close(self):
        """Close callback of the remote stream."""
        logging.info('remote quit %s', self.remote_address)
        self.remote_stream.close()
예제 #8
0
class ManualCapClient(BaseCapClient):
    """Capitalization client that wires the IOStream callbacks and the Future by hand."""

    def capitalize(self, request_data, callback=None):
        """Start a request; returns a Future and optionally also invokes ``callback``."""
        logging.debug("capitalize")
        self.request_data = request_data
        self.stream = IOStream(socket.socket())
        server = ('127.0.0.1', self.port)
        self.stream.connect(server, callback=self.handle_connect)
        self.future = Future()
        if callback is None:
            return self.future

        def deliver(future):
            return callback(future.result())

        self.future.add_done_callback(stack_context.wrap(deliver))
        return self.future

    def handle_connect(self):
        """Connected: send the request line and wait for a one-line answer."""
        logging.debug("handle_connect")
        outgoing = utf8(self.request_data + "\n")
        self.stream.write(outgoing)
        self.stream.read_until(b'\n', callback=self.handle_read)

    def handle_read(self, data):
        """Answer received: close the stream and resolve the Future."""
        logging.debug("handle_read")
        self.stream.close()
        try:
            outcome = self.process_response(data)
            self.future.set_result(outcome)
        except CapError as e:
            self.future.set_exception(e)
예제 #9
0
class TestIOStream(AsyncHTTPTestCase, LogTrapTestCase):
    # IOStream read/connect behaviour exercised against a tiny HTTP app.

    def get_app(self):
        routes = [('/', HelloHandler)]
        return Application(routes)

    def _open_stream(self):
        # Blocking connect to the test server, then wrap the socket.
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        sock.connect(("localhost", self.get_http_port()))
        return IOStream(sock, io_loop=self.io_loop)

    def test_read_zero_bytes(self):
        self.stream = self._open_stream()
        self.stream.write(b("GET / HTTP/1.0\r\n\r\n"))

        # A normal read, a zero-byte read, then another normal read.
        reads = (
            (9, b("HTTP/1.0 ")),
            (0, b("")),
            (3, b("200")),
        )
        for count, expected in reads:
            self.stream.read_bytes(count, self.stop)
            received = self.wait()
            self.assertEqual(received, expected)

    def test_connection_refused(self):
        # A refused connection must not run the connect callback.  (The
        # kqueue IOLoop used to differ from the epoll IOLoop here.)
        port = get_unused_port()
        stream = IOStream(socket.socket(), self.io_loop)
        self.connect_called = False

        def mark_connected():
            self.connect_called = True

        stream.set_close_callback(self.stop)
        stream.connect(("localhost", port), mark_connected)
        self.wait()
        self.assertFalse(self.connect_called)

    def test_connection_closed(self):
        # When a server replies and then closes the connection, the client
        # must be able to read the data before the IOStream closes itself.
        # Epoll reports the close as a separate EPOLLRDHUP event delivered
        # together with the read event; kqueue reports it as a second
        # read/write event carrying an EOF flag.
        response = self.fetch("/", headers={"Connection": "close"})
        response.rethrow()

    def test_read_until_close(self):
        stream = self._open_stream()
        stream.write(b("GET / HTTP/1.0\r\n\r\n"))

        stream.read_until_close(self.stop)
        everything = self.wait()
        self.assertTrue(everything.startswith(b("HTTP/1.0 200")))
        self.assertTrue(everything.endswith(b("Hello")))
예제 #10
0
class SubProcessApplication(Application):
    """Run application class in subprocess.

    The parent and the child talk over a ``socket.socketpair()``; every
    message is base64-encoded JSON terminated by ``\\r\\n``.
    """

    def __init__(self, target, io_loop=None):
        """
        :param target: the application object, or a module name (``str``)
            that is imported and used as the target.
        :param io_loop: loop for the parent-side IOStream; the global
            ``IOLoop.instance()`` when omitted.
        """
        Application.__init__(self)
        if isinstance(target, str):
            self.target = self._load_module(target)
        else:
            self.target = target
        # Call instance() on the class: ``IOLoop().instance()`` first built
        # a throwaway loop object just to reach the singleton accessor.
        self.io_loop = IOLoop.instance() if io_loop is None else io_loop
        self._process = None
        self.socket = None
        self.runner = None
        self.ios = None

    def start(self, config):
        """Fork the runner process and start reading its messages."""
        signal.signal(signal.SIGCHLD, self._sigchld)
        self.socket, child = socket.socketpair()
        self.runner = _Subprocess(self.target, child, config)
        self._process = multiprocessing.Process(target=self.runner.run)
        self._process.start()
        # The child end now lives in the subprocess; drop the parent's copy.
        child.close()
        self.ios = IOStream(self.socket, self.io_loop)
        self.ios.read_until('\r\n', self._receiver)

    def _close(self, timeout):
        """Wait up to ``timeout`` seconds for the subprocess to exit."""
        self._process.join(timeout)

    def _sigchld(self, signum, frame):
        """SIGCHLD handler: reap the subprocess."""
        self._close(0.5)

    def _receiver(self, data):
        """Receive data from subprocess. Forward to session.

        The read is re-armed afterwards; a single ``read_until`` delivers
        only one message, so without this every later message was lost.
        """
        try:
            msg = json.loads(binascii.a2b_base64(data))
            # NOTE(review): ``session`` is not set in this class; presumably
            # the Application base class provides it - confirm.
            if self.session:
                self.session.send(msg)
            else:
                logging.error("from app: %s", str(msg))
        finally:
            if not self.ios.closed():
                self.ios.read_until('\r\n', self._receiver)

    def stop(self):
        """Terminate the subprocess and wait briefly for it to exit."""
        self._process.terminate()
        self._close(2.0)

    def send(self, data):
        """Send data to application."""
        self.ios.write(data + '\r\n')

    def received(self, data):
        """Handle data from session. Forward to subprocess for handling."""
        self.send(binascii.b2a_base64(json.dumps(data)))
        return True

    def _load_module(self, modulename):
        """Import and return the module named ``modulename``."""
        import importlib
        return importlib.import_module(modulename)
예제 #11
0
    def test_read_until_close(self):
        # Issue an HTTP/1.0 request and collect everything the server sends
        # until it closes the connection.
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        sock.connect(("localhost", self.get_http_port()))
        client = IOStream(sock, io_loop=self.io_loop)
        client.write(b("GET / HTTP/1.0\r\n\r\n"))

        client.read_until_close(self.stop)
        everything = self.wait()
        self.assertTrue(everything.startswith(b("HTTP/1.0 200")))
        self.assertTrue(everything.endswith(b("Hello")))
예제 #12
0
 def capitalize(self, request_data, callback):
     # Generator-based client: send one line, read one line back, and pass
     # the processed answer to `callback`.
     logging.info('capitalize')
     conn = IOStream(socket.socket(), io_loop=self.io_loop)
     logging.info('connecting')
     yield gen.Task(conn.connect, ('127.0.0.1', self.port))
     outgoing = utf8(request_data + '\n')
     conn.write(outgoing)
     logging.info('reading')
     line = yield gen.Task(conn.read_until, b('\n'))
     logging.info('returning')
     conn.close()
     answer = self.process_response(line)
     callback(answer)
예제 #13
0
 def capitalize(self, request_data):
     # Generator-based client: send one line, read one line back, and
     # return the processed answer through gen.Return.
     logging.debug('capitalize')
     conn = IOStream(socket.socket())
     logging.debug('connecting')
     yield conn.connect(('127.0.0.1', self.port))
     outgoing = utf8(request_data + '\n')
     conn.write(outgoing)
     logging.debug('reading')
     line = yield conn.read_until(b'\n')
     logging.debug('returning')
     conn.close()
     answer = self.process_response(line)
     raise gen.Return(answer)
예제 #14
0
 def test_timeout(self):
     client = IOStream(socket.socket())
     try:
         yield client.connect(("127.0.0.1", self.get_http_port()))
         # A raw stream is used because AsyncHTTPClient won't let us read
         # a response without finishing a body.
         request = (b"PUT /streaming?body_timeout=0.1 HTTP/1.0\r\n"
                    b"Content-Length: 42\r\n\r\n")
         client.write(request)
         with ExpectLog(gen_log, "Timeout reading body"):
             received = yield client.read_until_close()
         self.assertEqual(received, b"")
     finally:
         client.close()
예제 #15
0
파일: link.py 프로젝트: 64b2b6d12b/lens
class LinkLayer(NetLayer):
    """Source layer that reads raw frames from two NICs, ALICE and BOB."""

    NAME="link"
    SNAPLEN=1550
    ETH_P_ALL = 3 

    ALICE = 0
    BOB = 1

    def __init__(self, alice_nic = "tapa", bob_nic = "tapb", *args, **kwargs):
        """Attach to both NICs and register read handlers on the IOLoop."""
        super(LinkLayer, self).__init__(*args, **kwargs)
        alice_raw = self.attach(alice_nic)
        bob_raw = self.attach(bob_nic)

        loop = IOLoop.instance()

        self.alice_stream = IOStream(alice_raw)
        self.bob_stream = IOStream(bob_raw)

        loop.add_handler(alice_raw.fileno(), self.alice_read, IOLoop.READ)
        loop.add_handler(bob_raw.fileno(), self.bob_read, IOLoop.READ)

    # This layer is a SOURCE
    # so it will never consume packets
    def match(self, src, header):
        return False

    @classmethod
    def attach(cls, nic):
        """Bring ``nic`` up in promiscuous mode and return a non-blocking raw socket bound to it."""
        command = ["ip","link","set","up","promisc","on","dev",nic]
        exit_code = subprocess.call(command)
        if exit_code:
            raise Exception("ip link dev {0} returned exit code {1}".format(nic,exit_code))
        protocol = socket.htons(cls.ETH_P_ALL)
        raw = socket.socket(socket.AF_PACKET, socket.SOCK_RAW, protocol)
        raw.bind((nic,0))
        raw.setblocking(0)
        return raw

    def _pump(self, stream, origin):
        """Receive one frame from ``stream`` and pass it on, minus its last two bytes."""
        frame = stream.socket.recv(self.SNAPLEN)
        self.add_future(self.on_read(origin, {}, frame[:-2]))

    def alice_read(self, fd, event):
        self._pump(self.alice_stream, self.ALICE)

    def bob_read(self, fd, event):
        self._pump(self.bob_stream, self.BOB)

    # coroutine
    def write(self, dst, header, data):
        if dst == self.ALICE:
            return self.alice_stream.write(data)
        if dst == self.BOB:
            return self.bob_stream.write(data)
        raise Exception("Bad destination")
예제 #16
0
 def _connect_to_node(self, host, data=None):
     """Open a stream to ``host`` ("name:port") and register it in ``self._streams``.

     Connecting is best-effort: a socket error is logged and the method
     returns without registering a stream.

     :param host: "hostname:port" string; also used as the ``_streams`` key.
     :param data: optional payload written right after connecting.
     """
     import logging

     sock = None
     stream = None
     try:
         sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
         sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
         address = host.split(':')
         sock.connect((address[0], int(address[1])))
         stream = IOStream(sock, io_loop=ioloop.IOLoop.instance())
         self._streams[host] = stream
         stream.set_close_callback(functools.partial(self._handle_close, host))
         stream.read_until("\r\n", functools.partial(self._handle_read, host))
         if data:
             stream.write(data)
     except socket.error as e:
         # The failure used to be discarded silently; keep the best-effort
         # contract but leave a trace of it.
         logging.warning("could not connect to node %s: %s", host, e)
         # Before the IOStream exists nobody else owns the socket, so
         # close it here to avoid leaking the descriptor.
         if stream is None and sock is not None:
             sock.close()
예제 #17
0
파일: remote.py 프로젝트: chijiao/Fukei
class RemoteUpstream(Upstream):

    """
    Most methods are the same as in LocalUpstream, but they may need to
    differ in the future.

    """

    def initialize(self):
        """Create the socket and its IOStream and hook up the close callback."""
        self.socket = socket.socket(self._address_type, socket.SOCK_STREAM)
        self.stream = IOStream(self.socket)
        self.stream.set_close_callback(self.on_close)

    def do_connect(self):
        """Start connecting the stream to ``self.dest``."""
        self.stream.connect(self.dest, self.on_connect)

    @property
    def address(self):
        """Local address of the socket, as returned by ``getsockname()``."""
        return self.socket.getsockname()

    @property
    def address_type(self):
        """Address family the socket was created with."""
        return self._address_type

    def on_connect(self):
        """Notify the owner, then stream incoming data until the peer closes."""
        self.connection_callback(self)
        on_finish = functools.partial(self.on_streaming_data, finished=True)
        self.stream.read_until_close(on_finish, self.on_streaming_data)

    def on_close(self):
        """Report the close, as an error when the stream recorded one."""
        if self.stream.error:
            self.error_callback(self, self.stream.error)
        else:
            self.close_callback(self)

    def on_streaming_data(self, data, finished=False):
        """Forward non-empty chunks to the streaming callback."""
        if len(data):
            self.streaming_callback(self, data)

    def do_write(self, data):
        """Write ``data``; an IOError from the stream closes this upstream."""
        try:
            self.stream.write(data)
        except IOError:
            self.close()

    def do_close(self):
        """Log the local endpoint and close the stream."""
        if self.socket:
            try:
                # getsockname() returns a 4-tuple for IPv6 sockets, which
                # made ``"%s:%s" % self.address`` raise TypeError; only
                # host and port are logged.
                host, port = self.address[:2]
            except socket.error:
                # The descriptor may already be closed.
                host, port = "?", "?"
            logger.info("close upstream: %s:%s", host, port)
            self.stream.close()
예제 #18
0
class HTTP1ConnectionTest(AsyncTestCase):
    # Exercises HTTP1Connection over a real client/server IOStream pair
    # created in asyncSetUp.

    def setUp(self):
        super(HTTP1ConnectionTest, self).setUp()
        self.asyncSetUp()

    @gen_test
    def asyncSetUp(self):
        # Build a connected pair: self.client_stream connects to a listener
        # on an unused port, and the accepted side becomes self.server_stream.
        listener, port = bind_unused_port()
        event = Event()

        def accept_callback(conn, addr):
            # Wrap the accepted connection and signal that the server side
            # is ready.
            self.server_stream = IOStream(conn)
            self.addCleanup(self.server_stream.close)
            event.set()

        add_accept_handler(listener, accept_callback)
        self.client_stream = IOStream(socket.socket())
        self.addCleanup(self.client_stream.close)
        # Wait for both the client connect and the server-side accept.
        yield [self.client_stream.connect(('127.0.0.1', port)),
               event.wait()]
        # Only one connection is needed; stop listening.
        self.io_loop.remove_handler(listener)
        listener.close()

    @gen_test
    def test_http10_no_content_length(self):
        # Regression test for a bug in which can_keep_alive would crash
        # for an HTTP/1.0 (not 1.1) response with no content-length.
        conn = HTTP1Connection(self.client_stream, True)
        # The server writes a response without Content-Length and closes,
        # so the end of the body is marked by the connection close.
        self.server_stream.write(b"HTTP/1.0 200 Not Modified\r\n\r\nhello")
        self.server_stream.close()

        event = Event()
        # `test` gives the nested Delegate access to this test case, since
        # `self` inside Delegate refers to the delegate instance.
        test = self
        body = []

        class Delegate(HTTPMessageDelegate):
            def headers_received(self, start_line, headers):
                # Record the status code on the test case.
                test.code = start_line.code

            def data_received(self, data):
                # Collect body chunks for the final comparison.
                body.append(data)

            def finish(self):
                # Signal that the whole message has been delivered.
                event.set()

        yield conn.read_response(Delegate())
        yield event.wait()
        self.assertEqual(self.code, 200)
        self.assertEqual(b''.join(body), b'hello')
예제 #19
0
    def fetch(self, request, callback, **kwargs):
        """Send ``request`` as a bare HTTP/1.0 request line and start reading the reply headers.

        :param request: an ``HTTPRequest`` or a URL; a URL is wrapped in an
            ``HTTPRequest`` built with ``**kwargs``.
        :param callback: passed on to ``self._on_headers`` together with the
            request and the stream once the header block has been read.
        """
        if not isinstance(request, HTTPRequest):
            request = HTTPRequest(url=request, **kwargs)
        callback = stack_context.wrap(callback)

        parsed = urlparse.urlsplit(request.url)
        # ``netloc`` may contain "user@" and ":port"; handing it to
        # connect() as a host name failed for any URL with an explicit
        # port.  Use the parsed host and honour the port when present.
        host = parsed.hostname
        port = parsed.port or 80  # TODO: https
        sock = socket.socket()
        # TODO: non-blocking connect
        try:
            sock.connect((host, port))
        except socket.error:
            # Do not leak the descriptor when the connect fails.
            sock.close()
            raise
        stream = IOStream(sock, io_loop=self.io_loop)
        # TODO: query parameters
        request_line = "%s %s HTTP/1.0\r\n\r\n" % (request.method, parsed.path or '/')
        logging.warning(request_line)
        stream.write(request_line)
        stream.read_until("\r\n\r\n", functools.partial(self._on_headers,
                                                        request, callback, stream))
예제 #20
0
        def accept_callback(conn, address):
            # fake an HTTP server using chunked encoding where the final chunks
            # and connection close all happen at once
            stream = IOStream(conn, io_loop=self.io_loop)
            # The literal below is a complete chunked response: two
            # one-byte chunks ("1" and "2") followed by the terminating
            # zero-length chunk.  It is written with bare "\n" line ends
            # that .replace() converts to "\r\n" before sending, and
            # stream.close is passed as the write callback.
            stream.write(b("""\
HTTP/1.1 200 OK
Transfer-Encoding: chunked

1
1
1
2
0

""").replace(b("\n"), b("\r\n")), callback=stream.close)
예제 #21
0
class UnixSocketTest(AsyncTestCase):
    """HTTPServers can also listen on Unix sockets.

    One reason to do so: Nginx can proxy to backends that listen on unix
    sockets, and a namespace of socket files can be easier to manage than
    a set of TCP port numbers.

    A unix socket cannot be named in a URL for an HTTP client, so the
    requests in this test are written by hand.
    """
    def setUp(self):
        super(UnixSocketTest, self).setUp()
        self.tmpdir = tempfile.mkdtemp()
        self.sockfile = os.path.join(self.tmpdir, "test.sock")
        listener = netutil.bind_unix_socket(self.sockfile)
        routes = [("/hello", HelloWorldRequestHandler)]
        self.server = HTTPServer(Application(routes))
        self.server.add_socket(listener)
        client_socket = socket.socket(socket.AF_UNIX)
        self.stream = IOStream(client_socket)
        self.stream.connect(self.sockfile, self.stop)
        self.wait()

    def tearDown(self):
        self.stream.close()
        self.io_loop.run_sync(self.server.close_all_connections)
        self.server.stop()
        shutil.rmtree(self.tmpdir)
        super(UnixSocketTest, self).tearDown()

    def test_unix_socket(self):
        self.stream.write(b"GET /hello HTTP/1.0\r\n\r\n")
        self.stream.read_until(b"\r\n", self.stop)
        status_line = self.wait()
        self.assertEqual(status_line, b"HTTP/1.1 200 OK\r\n")
        self.stream.read_until(b"\r\n\r\n", self.stop)
        raw_headers = self.wait()
        headers = HTTPHeaders.parse(raw_headers.decode('latin1'))
        body_length = int(headers["Content-Length"])
        self.stream.read_bytes(body_length, self.stop)
        payload = self.wait()
        self.assertEqual(payload, b"Hello world")

    def test_unix_socket_bad_request(self):
        # Unix sockets have no remote address, so the logged address is
        # just an empty string.
        with ExpectLog(gen_log, "Malformed HTTP message from"):
            self.stream.write(b"garbage\r\n\r\n")
            self.stream.read_until_close(self.stop)
            reply = self.wait()
        self.assertEqual(reply, b"HTTP/1.1 400 Bad Request\r\n\r\n")
예제 #22
0
class ForwardConnection(object):
    """Forwards an accepted stream to the remote address configured for its local port."""

    def __init__(self, server, stream, address):
        """Look up the remote address in ``server.conf`` and start connecting to it."""
        self._close_callback = None
        self.server = server
        self.stream = stream
        self.reverse_address = address
        self.address = stream.socket.getsockname()
        self.remote_address = server.conf[self.address]
        self.remote_stream = IOStream(socket.socket())
        self.remote_stream.connect(self.remote_address, self._on_remote_connected)

    def close(self):
        """Close the remote side of the forward."""
        self.remote_stream.close()

    def set_close_callback(self, callback):
        """Register ``callback(connection)`` to run when the forward is torn down."""
        self._close_callback = callback

    def _on_remote_connected(self):
        """Remote side is up: log the forward and pipe each side into the other."""
        local_host, local_port = self.address[0], self.address[1]
        remote_host, remote_port = self.remote_address[0], self.remote_address[1]
        ip_from = self.reverse_address[0]
        fwd_str = get_forwarding_str(local_host, local_port,
                                     remote_host, remote_port)
        logging.info('Connected ip: %s, forward %s', ip_from, fwd_str)
        self.remote_stream.read_until_close(self._on_remote_read_close, self.stream.write)
        self.stream.read_until_close(self._on_read_close, self.remote_stream.write)

    def _wind_down(self, target, data):
        """Finish ``target``: flush-and-close, report closure, or close it now."""
        if target.writing():
            target.write(data, target.close)
        elif target.closed():
            self._on_closed()
        else:
            target.close()

    def _on_remote_read_close(self, data):
        """The remote side finished reading: wind down the local stream."""
        self._wind_down(self.stream, data)

    def _on_read_close(self, data):
        """The local side finished reading: wind down the remote stream."""
        self._wind_down(self.remote_stream, data)

    def _on_closed(self):
        """Log the disconnect and notify the registered close callback, if any."""
        logging.info('Disconnected ip: %s', self.reverse_address[0])
        if self._close_callback:
            self._close_callback(self)
예제 #23
0
    def test_handle_stream_coroutine_logging(self):
        # handle_stream may be a coroutine and any exception in its
        # Future will be logged.
        class TestServer(TCPServer):
            @gen.coroutine
            def handle_stream(self, stream, address):
                # Read the client's greeting, close the stream, then fail
                # on purpose so there is an exception to be logged.
                yield stream.read_bytes(len(b"hello"))
                stream.close()
                1 / 0

        # Pre-bind both names so the finally block can tell what was created.
        server = client = None
        try:
            sock, port = bind_unused_port()
            server = TestServer()
            server.add_socket(sock)
            client = IOStream(socket.socket())
            with ExpectLog(app_log, "Exception in callback"):
                yield client.connect(("localhost", port))
                yield client.write(b"hello")
                # Returns once the server has closed its end.
                yield client.read_until_close()
                # One more yield to the loop while ExpectLog is still
                # active; presumably this lets the failure be logged
                # inside the block - confirm.
                yield gen.moment
        finally:
            if server is not None:
                server.stop()
            if client is not None:
                client.close()
예제 #24
0
class EchoClienAsync(object):
    """Line-oriented client: sends newline-terminated data and reads newline-terminated replies."""

    def __init__(self, host = "127.0.0.1", port=12345):
        """Create the (not yet connected) stream for ``host:port``."""
        self.host = host
        self.port = port
        tcp_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        self.stream = IOStream(tcp_socket)

    def connect(self):
        """Start connecting and return the underlying stream."""
        target = (self.host, self.port)
        self.stream.connect(target)
        return self.stream

    def send(self, data, callback=None):
        """Write ``data`` followed by a newline; ``callback`` is passed to the stream's write."""
        line = data+"\n"
        self.stream.write(line, callback)

    def recv(self, callback):
        """Read one line and pass it, without its last character, to ``callback``."""
        def strip_terminator(data):
            return callback(data[:-1])

        self.stream.read_until("\n", strip_terminator)
예제 #25
0
    def test_indexing_line(self):
        # NOTE(review): this snippet looks truncated by credential masking.
        # The string literal opened in the fetch() call below is never
        # closed, and `doc` is used in the final assertion without being
        # assigned in the visible text.  Restore the missing statements
        # from the original source before relying on this test.
        client = AsyncHTTPClient(io_loop=self.io_loop)
        ping = yield client.fetch("http://*****:*****@version'], 1)
        self.assertEqual(doc['message'], "My name is Yuri and I'm 6 years old.")
예제 #26
0
class UnixSocketTest(AsyncTestCase):
    """HTTPServers can also listen on Unix sockets.

    One reason to do so: Nginx can proxy to backends that listen on unix
    sockets, and a namespace of socket files can be easier to manage than
    a set of TCP port numbers.

    A unix socket cannot be named in a URL for an HTTP client, so the
    requests in this test are written by hand.
    """
    def setUp(self):
        super(UnixSocketTest, self).setUp()
        self.tmpdir = tempfile.mkdtemp()
        self.sockfile = os.path.join(self.tmpdir, "test.sock")
        listener = netutil.bind_unix_socket(self.sockfile)
        routes = [("/hello", HelloWorldRequestHandler)]
        self.server = HTTPServer(Application(routes))
        self.server.add_socket(listener)
        client_socket = socket.socket(socket.AF_UNIX)
        self.stream = IOStream(client_socket)
        self.stream.connect(self.sockfile, self.stop)
        self.wait()

    def tearDown(self):
        self.stream.close()
        self.server.stop()
        shutil.rmtree(self.tmpdir)
        super(UnixSocketTest, self).tearDown()

    def test_unix_socket(self):
        self.stream.write(b"GET /hello HTTP/1.0\r\n\r\n")
        self.stream.read_until(b"\r\n", self.stop)
        status_line = self.wait()
        self.assertEqual(status_line, b"HTTP/1.1 200 OK\r\n")
        self.stream.read_until(b"\r\n\r\n", self.stop)
        raw_headers = self.wait()
        headers = HTTPHeaders.parse(raw_headers.decode('latin1'))
        body_length = int(headers["Content-Length"])
        self.stream.read_bytes(body_length, self.stop)
        payload = self.wait()
        self.assertEqual(payload, b"Hello world")

    def test_unix_socket_bad_request(self):
        # Unix sockets have no remote address, so the logged address is
        # just an empty string.
        with ExpectLog(gen_log, "Malformed HTTP message from"):
            self.stream.write(b"garbage\r\n\r\n")
            self.stream.read_until_close(self.stop)
            reply = self.wait()
        self.assertEqual(reply, b"HTTP/1.1 400 Bad Request\r\n\r\n")
예제 #27
0
class IRCStream(object):
    """Singleton IRC bot connection that relays channel messages via ``broadcast_message``."""

    _instance = None

    @classmethod
    def instance(cls):
        """ Returns the singleton """
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    def __init__(self):
        """Connect using ``SETTINGS``, register the bot, join the channel and start reading."""
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        self.stream = IOStream(sock)
        self.host = SETTINGS["irc_host"]
        self.channel = SETTINGS["irc_channel"]
        self.stream.connect((self.host, SETTINGS["irc_port"]))
        self.nick = "PyTexasBot"
        self.ident = "pytexasbot"
        self.real_name = "PyTexas StreamBot"

        self.stream.write("NICK %s\r\n" % self.nick)
        self.stream.write("USER %s %s blah :%s\r\n" % (self.ident, self.host,
            self.real_name))
        self.stream.write("JOIN #"+self.channel+"\r\n")
        self.monitor_output()

    def monitor_output(self):
        """Arm a read for the next line from the server."""
        self.stream.read_until("\r\n", self.parse_line)

    def parse_line(self, response):
        """Handle one server line, then arm the next read.

        PING is answered with PONG, channel messages are broadcast, and a
        line starting with ``ERROR`` raises ``Exception`` (no further read
        is armed in that case).  Every other line is printed.
        """
        response = response.strip()
        if response.startswith("PING "):
            # Echo back everything after the command.  Slicing off the
            # prefix (rather than replacing every "PING " occurrence)
            # keeps payloads that themselves contain "PING " intact.
            request = response[len("PING "):]
            self.stream.write("PONG %s\r\n" % request)
        splitter = "PRIVMSG #%s :" % self.channel
        if splitter in response:
            # Split once only, so a message that repeats the splitter
            # text is not truncated.
            prefix, text = response.split(splitter, 1)
            if not text:
                # empty messages are not broadcast
                return self.monitor_output()
            nick = prefix[1:].split("!")[0].strip()
            message = {
                "time": int(time.time()),
                "text": xhtml_escape(text),
                "name": nick,
                "username": "******",
                "type": "tweet",
                "avatar": None
            }
            broadcast_message(message)
        if response.startswith("ERROR"):
            raise Exception(response)
        else:
            # Function-call form prints the same text on Python 2 and 3.
            print(response)
        self.monitor_output()
예제 #28
0
class Connection(object):
    """Connection to a Redis server backed by ``tornado.iostream.IOStream``.

    The TCP connection itself is established synchronously in ``connect``;
    reads and writes then go through the IOStream.
    """

    def __init__(self, host="localhost", port=6379, timeout=None, io_loop=None):
        self.host = host
        self.port = port
        self._io_loop = io_loop
        self._stream = None
        # NOTE(review): "in_porcess" looks like a typo for "in_process"; the
        # name is kept because code outside this class may read it.
        self.in_porcess = False
        self.timeout = timeout
        self._lock = 0
        self.info = {"db": 0, "pass": None}

    def __del__(self):
        self.disconnect()

    def connect(self):
        """Open the socket and wrap it in an IOStream if not yet connected.

        Raises:
            ConnectionError: if the socket cannot be created or connected.
        """
        if not self._stream:
            try:
                sock = socket.create_connection((self.host, self.port),
                                                timeout=self.timeout)
                # Disable Nagle's algorithm on the new socket.
                sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
                self._stream = IOStream(sock, io_loop=self._io_loop)
                self._stream.set_close_callback(self.on_stream_close)
                self.info["db"] = 0
                self.info["pass"] = None
            except socket.error as e:
                raise ConnectionError(str(e))

    def on_stream_close(self):
        """Close callback of the stream: drop our reference and clean up."""
        if self._stream:
            self.disconnect()

    # Old (misspelled) name kept so existing callers keep working.
    on_stram_close = on_stream_close

    def disconnect(self):
        """Close the connection if one is open; safe to call repeatedly."""
        s = self._stream
        if not s:
            return
        self._stream = None
        # Best effort: the peer may already be gone, and this also runs from
        # __del__, so failures are deliberately ignored.  The two steps are
        # guarded separately so a failed shutdown never skips the close.
        try:
            if s.socket:
                s.socket.shutdown(socket.SHUT_RDWR)
        except Exception:
            pass
        try:
            s.close()
        except Exception:
            pass

    @gen.coroutine
    def write(self, data):
        """Write ``data`` to the server.

        On Python 3 text is encoded as UTF-8 first; bytes are sent unchanged.

        Raises:
            ConnectionError: if there is no open stream or the write fails.
        """
        # Checked outside the try block so this error is never re-wrapped by
        # the IOError handler below.
        if not self._stream:
            self.disconnect()
            raise ConnectionError("Try to wrtie to non-exist Connection")
        if sys.version_info[0] >= 3 and not isinstance(data, bytes):
            data = bytes(data, encoding="utf-8")
        try:
            yield self._stream.write(data)
        except IOError as e:
            raise ConnectionError(str(e))
예제 #29
0
class ConnectionCloseTest(AsyncHTTPTestCase, LogTrapTestCase):
    """Exercises a handler that is told about a dropped client connection.

    The test case passes itself to ``ConnectionCloseHandler``; presumably the
    handler calls ``on_handler_waiting`` and ``on_connection_close`` back on
    it (the handler is defined outside this excerpt).
    """

    def get_app(self):
        """Return an application whose only route gets a reference to this test."""
        routes = [('/', ConnectionCloseHandler, {'test': self})]
        return Application(routes)

    def test_connection_close(self):
        """Send a raw HTTP/1.0 request and wait until the test is stopped."""
        client_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        server_addr = ("localhost", self.get_http_port())
        client_sock.connect(server_addr)
        self.stream = IOStream(client_sock, io_loop=self.io_loop)
        self.stream.write("GET / HTTP/1.0\r\n\r\n")
        self.wait()

    def on_handler_waiting(self):
        """Close the client side of the connection."""
        logging.info('handler waiting')
        self.stream.close()

    def on_connection_close(self):
        """Stop the IOLoop wait started by the test method."""
        logging.info('connection closed')
        self.stop()
예제 #30
0
class ConnectionCloseTest(AsyncHTTPTestCase, LogTrapTestCase):
    """Checks the close notification path of ``ConnectionCloseHandler``.

    The handler receives this test case as its ``test`` argument; the two
    ``on_*`` methods below are the hooks it is expected to call
    (NOTE(review): handler not visible here - confirm).
    """

    def get_app(self):
        """Build the single-route application under test."""
        handler_kwargs = dict(test=self)
        return Application([('/', ConnectionCloseHandler, handler_kwargs)])

    def test_connection_close(self):
        """Open a plain TCP connection, send a request and wait for stop()."""
        raw = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        raw.connect(("localhost", self.get_http_port()))
        self.stream = IOStream(raw, io_loop=self.io_loop)
        request = b("GET / HTTP/1.0\r\n\r\n")
        self.stream.write(request)
        self.wait()

    def on_handler_waiting(self):
        """Drop the client connection while the handler is waiting."""
        logging.info('handler waiting')
        self.stream.close()

    def on_connection_close(self):
        """End the test once the close has been reported."""
        logging.info('connection closed')
        self.stop()
예제 #31
0
class ConnectionCloseTest(WebTestCase):
    """Client-disconnect test built on ``WebTestCase``.

    ``ConnectionCloseHandler`` is given this test case as ``test``; it is
    presumably the caller of the ``on_*`` hooks defined here.
    """

    def get_handlers(self):
        """Return the route table: one handler bound to this test case."""
        route = ('/', ConnectionCloseHandler, {'test': self})
        return [route]

    def test_connection_close(self):
        """Issue a raw request over a fresh socket and wait for stop()."""
        conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        target = ("localhost", self.get_http_port())
        conn.connect(target)
        self.stream = IOStream(conn, io_loop=self.io_loop)
        self.stream.write(b("GET / HTTP/1.0\r\n\r\n"))
        self.wait()

    def on_handler_waiting(self):
        """Close our end of the connection."""
        logging.debug('handler waiting')
        self.stream.close()

    def on_connection_close(self):
        """Release the wait() in the test method."""
        logging.debug('connection closed')
        self.stop()
예제 #32
0
 def test_unix_socket(self):
     """Serve ``/hello`` over a Unix domain socket and read the reply by hand."""
     socket_path = os.path.join(self.tmpdir, "test.sock")
     listener = netutil.bind_unix_socket(socket_path)
     application = Application([("/hello", HelloWorldRequestHandler)])
     http_server = HTTPServer(application, io_loop=self.io_loop)
     http_server.add_socket(listener)

     client = IOStream(socket.socket(socket.AF_UNIX), io_loop=self.io_loop)
     client.connect(socket_path, self.stop)
     self.wait()
     client.write(b("GET /hello HTTP/1.0\r\n\r\n"))

     # Status line.
     client.read_until(b("\r\n"), self.stop)
     status_line = self.wait()
     self.assertEqual(status_line, b("HTTP/1.0 200 OK\r\n"))

     # Header block, terminated by an empty line.
     client.read_until(b("\r\n\r\n"), self.stop)
     raw_headers = self.wait().decode('latin1')
     headers = HTTPHeaders.parse(raw_headers)

     # Body, sized by Content-Length.
     content_length = int(headers["Content-Length"])
     client.read_bytes(content_length, self.stop)
     payload = self.wait()
     self.assertEqual(payload, b("Hello world"))
예제 #33
0
 def test_unix_socket(self):
     """Check that an HTTPServer bound to a Unix socket answers ``/hello``."""
     path = os.path.join(self.tmpdir, "test.sock")
     server_sock = netutil.bind_unix_socket(path)
     routes = [("/hello", HelloWorldRequestHandler)]
     srv = HTTPServer(Application(routes), io_loop=self.io_loop)
     srv.add_socket(server_sock)

     conn = IOStream(socket.socket(socket.AF_UNIX), io_loop=self.io_loop)
     conn.connect(path, self.stop)
     self.wait()

     conn.write(b("GET /hello HTTP/1.0\r\n\r\n"))
     conn.read_until(b("\r\n"), self.stop)
     first_line = self.wait()
     self.assertEqual(first_line, b("HTTP/1.0 200 OK\r\n"))

     conn.read_until(b("\r\n\r\n"), self.stop)
     header_text = self.wait().decode('latin1')
     parsed = HTTPHeaders.parse(header_text)

     # The body length is taken from the Content-Length header.
     conn.read_bytes(int(parsed["Content-Length"]), self.stop)
     self.assertEqual(self.wait(), b("Hello world"))
예제 #34
0
파일: web_test.py 프로젝트: akkakks/tornado
class ConnectionCloseTest(WebTestCase):
    """Verifies that a client disconnect is reported back to the test.

    The test case hands itself to ``ConnectionCloseHandler`` as ``test``;
    the handler (not shown here) is expected to drive the ``on_*`` hooks.
    """

    def get_handlers(self):
        """Return the handler list with this test case injected."""
        init_kwargs = dict(test=self)
        return [('/', ConnectionCloseHandler, init_kwargs)]

    def test_connection_close(self):
        """Send one request over a raw socket, then wait to be stopped."""
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        sock.connect(("localhost", self.get_http_port()))
        self.stream = IOStream(sock, io_loop=self.io_loop)
        request = b"GET / HTTP/1.0\r\n\r\n"
        self.stream.write(request)
        self.wait()

    def on_handler_waiting(self):
        """Close the client stream."""
        logging.debug('handler waiting')
        self.stream.close()

    def on_connection_close(self):
        """Stop the test's IOLoop wait."""
        logging.debug('connection closed')
        self.stop()
예제 #35
0
class DecoratorCapClient(BaseCapClient):
    """Callback-style capitalization client wrapped with ``future_wrap``."""

    @future_wrap
    def capitalize(self, request_data, callback):
        """Connect to the local server and remember the request and callback.

        The request itself is sent from ``handle_connect``.
        """
        logging.info("capitalize")
        self.request_data = request_data
        self.stream = IOStream(socket.socket(), io_loop=self.io_loop)
        server_address = ('127.0.0.1', self.port)
        self.stream.connect(server_address, callback=self.handle_connect)
        self.callback = callback

    def handle_connect(self):
        """Send the newline-terminated request and wait for one reply line."""
        logging.info("handle_connect")
        outgoing = utf8(self.request_data + "\n")
        self.stream.write(outgoing)
        self.stream.read_until(b('\n'), callback=self.handle_read)

    def handle_read(self, data):
        """Close the stream and pass the processed reply to the callback."""
        logging.info("handle_read")
        self.stream.close()
        self.callback(self.process_response(data))
예제 #36
0
class DecoratorCapClient(BaseCapClient):
    """Capitalization client whose callback API is wrapped by ``return_future``."""

    @return_future
    def capitalize(self, request_data, callback):
        """Start a request: open the stream and store the data and callback."""
        logging.info("capitalize")
        self.request_data = request_data
        self.stream = IOStream(socket.socket())
        target = ('127.0.0.1', self.port)
        self.stream.connect(target, callback=self.handle_connect)
        self.callback = callback

    def handle_connect(self):
        """Write the request line, then read up to the next newline."""
        logging.info("handle_connect")
        line = utf8(self.request_data + "\n")
        self.stream.write(line)
        self.stream.read_until(b'\n', callback=self.handle_read)

    def handle_read(self, data):
        """Finish the request: close the stream and invoke the callback."""
        logging.info("handle_read")
        self.stream.close()
        self.callback(self.process_response(data))
예제 #37
0
 def test_100_continue(self):
     """Walk through an ``Expect: 100-continue`` exchange on a raw stream."""
     # Done by hand: with Expect: 100-continue the server answers the
     # headers with a 100 response, and sends the real response only
     # after it has received the body.
     conn = IOStream(socket.socket())
     yield conn.connect(("127.0.0.1", self.get_http_port()))
     request_head = b"\r\n".join([
         b"POST /hello HTTP/1.1", b"Content-Length: 1024",
         b"Expect: 100-continue", b"Connection: close", b"\r\n"
     ])
     yield conn.write(request_head)

     interim = yield conn.read_until(b"\r\n\r\n")
     self.assertTrue(interim.startswith(b"HTTP/1.1 100 "), interim)

     # Now send the body the server agreed to receive.
     conn.write(b"a" * 1024)
     status_line = yield conn.read_until(b"\r\n")
     self.assertTrue(status_line.startswith(b"HTTP/1.1 200"), status_line)

     raw_headers = yield conn.read_until(b"\r\n\r\n")
     parsed = HTTPHeaders.parse(native_str(raw_headers.decode('latin1')))
     content_length = int(parsed["Content-Length"])
     payload = yield conn.read_bytes(content_length)
     self.assertEqual(payload, b"Got 1024 bytes in POST")
     conn.close()
예제 #38
0
            def accept_callback(conn, address):
                stream = IOStream(conn)
                request_data = yield stream.read_until(b"\r\n\r\n")
                if b"HTTP/1." not in request_data:
                    self.skipTest("requires HTTP/1.x")
                yield stream.write(b"""\
HTTP/1.1 200 OK
X-XSS-Protection: 1;
\tmode=block

""".replace(b"\n", b"\r\n"))
                stream.close()
예제 #39
0
class TestIOStream(AsyncHTTPTestCase, LogTrapTestCase):
    """IOStream tests run against a small local HTTP application."""

    def get_app(self):
        """Return an application that serves ``HelloHandler`` at ``/``."""
        routes = [('/', HelloHandler)]
        return Application(routes)

    def test_read_zero_bytes(self):
        """A zero-length read between two normal reads yields an empty string."""
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        sock.connect(("localhost", self.get_http_port()))
        self.stream = IOStream(sock, io_loop=self.io_loop)
        self.stream.write("GET / HTTP/1.0\r\n\r\n")

        # (byte count, expected data): a normal read, a zero-byte read,
        # then another normal read.
        expectations = (
            (9, "HTTP/1.0 "),
            (0, ""),
            (3, "200"),
        )
        for num_bytes, expected in expectations:
            self.stream.read_bytes(num_bytes, self.stop)
            data = self.wait()
            self.assertEqual(data, expected)

    def test_connection_refused(self):
        """A refused connection fires the close callback, not the connect one."""
        # The connect callback must not run when the connection is refused.
        # (The kqueue IOLoop used to differ from the epoll IOLoop here.)
        refused_port = get_unused_port()
        client = IOStream(socket.socket(), self.io_loop)
        self.connect_called = False

        def mark_connected():
            self.connect_called = True

        client.set_close_callback(self.stop)
        client.connect(("localhost", refused_port), mark_connected)
        self.wait()
        self.assertFalse(self.connect_called)
예제 #40
0
class TokenizerService(object):
    '''
    Client side of the IPC channel to the Java TokenizerService (which runs
    tokenization and named entity extraction through CoreNLP).

    Requests are numbered; each pending request is tracked as a Future that
    is resolved when the response carrying the same number arrives.
    '''

    def __init__(self):
        raw_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self._socket = IOStream(raw_socket)
        self._requests = {}
        self._next_id = 0

    @tornado.gen.coroutine
    def run(self):
        '''
        Connect to the service and dispatch responses until the stream closes.

        Each response is one newline-terminated JSON document.
        '''
        yield self._socket.connect(('127.0.0.1', PORT))

        while True:
            try:
                line = yield self._socket.read_until(b'\n')
            except StreamClosedError:
                line = None
            if not line:
                return
            payload = json.loads(str(line, encoding='utf-8'))

            req_id = int(payload['req'])
            result = TokenizerResult(tokens=list(clean_tokens(payload['tokens'])),
                                     values=payload['values'],
                                     constituency_parse=payload['constituencyParse'],
                                     pos_tags=payload['pos'],
                                     raw_tokens=payload['rawTokens'],
                                     sentiment=payload['sentiment'])
            # Resolve the waiting caller, then forget the request.
            pending = self._requests[req_id]
            pending.set_result(result)
            del self._requests[req_id]

    def tokenize(self, language_tag, query, expect=None):
        '''
        Send a tokenization request and return a Future for its result.

        If the write itself fails, the returned Future receives that
        exception and the request is dropped from the pending table.
        '''
        req_id = self._next_id
        self._next_id += 1

        req = {'req': req_id, 'utterance': query, 'languageTag': language_tag}
        if expect is not None:
            req['expect'] = expect
        outer = Future()
        self._requests[req_id] = outer

        def on_written(write_future):
            if write_future.exception():
                outer.set_exception(write_future.exception())
                del self._requests[req_id]

        write_future = self._socket.write(json.dumps(req).encode())
        write_future.add_done_callback(on_written)
        return outer
예제 #41
0
class AsyncSocketHanlder(SocketHandler):
    """Socket logging handler that sends records through a Tornado IOStream.

    ``createSocket`` is inherited from the base ``SocketHandler``; this class
    overrides ``makeSocket`` and ``send`` so the data goes through
    ``self._stream``.
    """

    def __init__(self, host, port, ioloop=None):
        """Store the target address and the optional IOLoop to use."""
        super(AsyncSocketHanlder, self).__init__(host, port)
        self._ioloop = ioloop
        self._stream = None

    def makeSocket(self, timeout=1):
        """Create the socket, wrap it in an IOStream and start connecting.

        ``timeout`` is accepted for signature compatibility and is unused.
        Returns the raw socket.
        """
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        # Honour the IOLoop given to the constructor; it used to be stored
        # but never used.  Without one, IOStream picks its own default.
        if self._ioloop is not None:
            self._stream = IOStream(s, io_loop=self._ioloop)
        else:
            self._stream = IOStream(s)
        self._stream.connect((self.host, self.port))
        return s

    def send(self, s):
        """Write the string ``s`` to the stream, connecting first if needed."""
        # Single-argument call form prints the same text under the Python 2
        # print statement and the Python 3 print function.
        print(s)
        if self._stream is None:
            self.createSocket()
        # If no stream could be created, drop the data instead of failing
        # with an AttributeError on None.
        if self._stream is not None:
            self._stream.write(s)
예제 #42
0
class HTTPServerRawTest(AsyncHTTPTestCase):
    """Sends hand-written bytes to the HTTP server over a raw IOStream."""

    def get_app(self):
        """Return an application exposing ``EchoHandler`` at ``/echo``."""
        routes = [
            ('/echo', EchoHandler),
        ]
        return Application(routes)

    def setUp(self):
        """Open a client stream to the test server and wait for the connect."""
        super(HTTPServerRawTest, self).setUp()
        self.stream = IOStream(socket.socket())
        server_addr = ('localhost', self.get_http_port())
        self.stream.connect(server_addr, self.stop)
        self.wait()

    def tearDown(self):
        """Close the client stream before the base-class teardown."""
        self.stream.close()
        super(HTTPServerRawTest, self).tearDown()

    def test_empty_request(self):
        """Close the connection without sending anything."""
        self.stream.close()
        pause = datetime.timedelta(seconds=0.001)
        self.io_loop.add_timeout(pause, self.stop)
        self.wait()

    def test_malformed_first_line(self):
        """A bad request line is expected to be logged."""
        with ExpectLog(gen_log, '.*Malformed HTTP request line'):
            self.stream.write(b'asdf\r\n\r\n')
            # TODO: an async version of ExpectLog would remove the need
            # for the hard-coded timeout below.
            pause = datetime.timedelta(seconds=0.01)
            self.io_loop.add_timeout(pause, self.stop)
            self.wait()

    def test_malformed_headers(self):
        """A bad header line is expected to be logged."""
        with ExpectLog(gen_log, '.*Malformed HTTP headers'):
            self.stream.write(b'GET / HTTP/1.0\r\nasdf\r\n\r\n')
            pause = datetime.timedelta(seconds=0.01)
            self.io_loop.add_timeout(pause, self.stop)
            self.wait()
예제 #43
0
class TestIOStream(AsyncHTTPTestCase, LogTrapTestCase):
    """IOStream read tests against a local ``HelloHandler`` application."""

    def get_app(self):
        """Return the one-route application used by the tests."""
        return Application([
            ('/', HelloHandler),
        ])

    def test_read_zero_bytes(self):
        """Reading zero bytes returns an empty string and consumes nothing."""
        raw = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        raw.connect(("localhost", self.get_http_port()))
        self.stream = IOStream(raw, io_loop=self.io_loop)
        self.stream.write("GET / HTTP/1.0\r\n\r\n")

        def read_exactly(count):
            # Issue the read and block until its callback has fired.
            self.stream.read_bytes(count, self.stop)
            return self.wait()

        # normal read
        self.assertEqual(read_exactly(9), "HTTP/1.0 ")

        # zero bytes
        self.assertEqual(read_exactly(0), "")

        # another normal read picks up right after the first one
        self.assertEqual(read_exactly(3), "200")
예제 #44
0
class ManualCapClient(BaseCapClient):
    """Capitalization client that creates and resolves its Future by hand."""

    def capitalize(self, request_data, callback=None):
        """Start a request and return the Future that will hold its result.

        If ``callback`` is given it is attached to the Future as a done
        callback.
        """
        logging.info("capitalize")
        self.request_data = request_data
        self.stream = IOStream(socket.socket(), io_loop=self.io_loop)
        server_address = ('127.0.0.1', self.port)
        self.stream.connect(server_address, callback=self.handle_connect)
        self.future = Future()
        if callback is not None:
            self.future.add_done_callback(callback)
        return self.future

    def handle_connect(self):
        """Send the request line and read one reply line."""
        logging.info("handle_connect")
        outgoing = utf8(self.request_data + "\n")
        self.stream.write(outgoing)
        self.stream.read_until(b('\n'), callback=self.handle_read)

    def handle_read(self, data):
        """Resolve the Future with the processed reply, or with a CapError."""
        logging.info("handle_read")
        self.stream.close()
        try:
            self.future.set_result(self.process_response(data))
        except CapError as err:
            self.future.set_exception(err)
예제 #45
0
class EchoClient(object):
    """
    Asynchronous client for EchoServer, built on a single IOStream.
    """

    def __init__(self, address, family=socket.AF_INET, socktype=socket.SOCK_STREAM):
        """Create the (not yet connected) stream for ``address``."""
        raw_socket = socket.socket(family, socktype, 0)
        self.io_stream = IOStream(raw_socket)
        self.address = address
        self.is_closed = False

    def handle_close(self, data):
        """Record that the stream has been closed; ``data`` is ignored."""
        self.is_closed = True

    def send_message(self, message, handle_response):
        """Connect, then send ``message``.

        ``handle_response`` is passed to ``read_until_close`` as its second
        argument and ``handle_close`` as its first.
        """
        def on_connected():
            self.io_stream.read_until_close(self.handle_close, handle_response)
            self.write(message)

        self.io_stream.connect(self.address, on_connected)

    def write(self, message):
        """Write ``message`` to the stream, encoding non-bytes input as UTF-8."""
        payload = message if isinstance(message, bytes) else message.encode("UTF-8")
        self.io_stream.write(payload)
예제 #46
0
class ESME(DeliverMixin, BaseESME):
    """Tornado-based ESME: moves bytes between an IOStream and ``BaseESME``.

    Incoming data is passed to ``self.feed``; ``on_send`` and ``on_close``
    are presumably the hooks ``BaseESME`` uses for output and shutdown
    (the base classes are not visible here).
    """

    def __init__(self, **kwargs):
        BaseESME.__init__(self, **kwargs)
        self.running = False
        self.closed = False

    @coroutine
    def connect(self, host, port):
        """Open a TCP stream to ``(host, port)``."""
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        self.ioloop = IOLoop.current()
        self.stream = IOStream(sock)
        yield self.stream.connect((host, port))

    def on_send(self, data):
        """Write ``data`` to the stream and return the write's result."""
        return self.stream.write(data)

    def on_close(self):
        """Mark the session closed and close the stream."""
        self.closed = True
        self.stream.close()

    @coroutine
    def readloop(self, future):
        """Feed incoming data to the parser.

        Stops when the session is closed, when ``future`` (if given) is
        done, or when the stream is closed.
        """
        while True:
            if self.closed:
                break
            if future and future.done():
                break
            try:
                chunk = yield self.stream.read_bytes(1024, partial=True)
            except StreamClosedError:  # pragma: no cover
                break
            self.feed(chunk)

    def wait_for(self, response):
        """Return a Future resolved with ``resp.response`` via ``response.callback``.

        If no read loop is running, one is started through ``run`` and its
        result is returned instead.
        """
        future = Future()

        def deliver(resp):
            return future.set_result(resp.response)

        response.callback = deliver
        if not self.running:
            return self.run(future)
        return future

    @coroutine
    def run(self, future=None):
        """Run the read loop; return ``future``'s result if it completed."""
        self.running = True
        try:
            yield self.readloop(future)
        finally:
            self.running = False

        if not future or not future.done():
            return
        raise Return(future.result())
예제 #47
0
class HTTPServerRawTest(AsyncHTTPTestCase):
    """Raw-socket tests: requests are written byte-for-byte to an IOStream."""

    def get_app(self):
        """Return an application exposing ``EchoHandler`` at ``/echo``."""
        handlers = [
            ('/echo', EchoHandler),
        ]
        return Application(handlers)

    def setUp(self):
        """Connect a fresh client stream to the test server."""
        super(HTTPServerRawTest, self).setUp()
        self.stream = IOStream(socket.socket())
        target = ('localhost', self.get_http_port())
        self.stream.connect(target, self.stop)
        self.wait()

    def tearDown(self):
        """Close the client stream, then run the base-class teardown."""
        self.stream.close()
        super(HTTPServerRawTest, self).tearDown()

    def test_empty_request(self):
        """Disconnect immediately without sending a request."""
        self.stream.close()
        delay = datetime.timedelta(seconds=0.001)
        self.io_loop.add_timeout(delay, self.stop)
        self.wait()

    def test_malformed_first_line(self):
        """A garbage request line should produce the expected log entry."""
        with ExpectLog(gen_log, '.*Malformed HTTP request line'):
            self.stream.write(b'asdf\r\n\r\n')
            # TODO: with an async ExpectLog the hard-coded timeout here
            # would not be needed.
            delay = datetime.timedelta(seconds=0.01)
            self.io_loop.add_timeout(delay, self.stop)
            self.wait()

    def test_malformed_headers(self):
        """A garbage header line should produce the expected log entry."""
        with ExpectLog(gen_log, '.*Malformed HTTP headers'):
            self.stream.write(b'GET / HTTP/1.0\r\nasdf\r\n\r\n')
            delay = datetime.timedelta(seconds=0.01)
            self.io_loop.add_timeout(delay, self.stop)
            self.wait()

    def test_chunked_request_body(self):
        """The server decodes a chunked request body sent by hand."""
        # AsyncHTTPClient cannot generate chunked requests (they are not
        # widely supported), but HTTPServer will read them, so the request
        # is spelled out here.
        request = b"""\
POST /echo HTTP/1.1
Transfer-Encoding: chunked
Content-Type: application/x-www-form-urlencoded

4
foo=
3
bar
0

""".replace(b"\n", b"\r\n")
        self.stream.write(request)
        read_stream_body(self.stream, self.stop)
        headers, response = self.wait()
        decoded = json_decode(response)
        self.assertEqual(decoded, {u('foo'): [u('bar')]})
예제 #48
0
            def accept_callback(conn, address):
                # fake an HTTP server using chunked encoding where the final chunks
                # and connection close all happen at once
                stream = IOStream(conn)
                request_data = yield stream.read_until(b"\r\n\r\n")
                if b"HTTP/1." not in request_data:
                    self.skipTest("requires HTTP/1.x")
                yield stream.write(b"""\
HTTP/1.1 200 OK
Transfer-Encoding: chunked

1
1
1
2
0

""".replace(b"\n", b"\r\n"))
                stream.close()
예제 #49
0
class DBusConnection:
    """D-Bus connection over a Unix socket, driven by a Tornado IOStream.

    Data is first fed to the SASL auth parser; once authentication has
    finished, it goes to the message parser and on to the router.
    """

    def __init__(self, bus_addr):
        """Create the parsers and router, then start connecting to ``bus_addr``."""
        self.auth_parser = SASLParser()
        self.parser = Parser()
        self.router = Router(Future)
        self.authentication = Future()
        self.unique_name = None

        self._sock = socket.socket(family=socket.AF_UNIX)
        self.stream = IOStream(self._sock, read_chunk_size=4096)

        def send_auth_request():
            # A single NUL byte precedes the EXTERNAL auth request.
            self.stream.write(b'\0' + make_auth_external())

        self.stream.connect(bus_addr, send_auth_request)
        self.stream.read_until_close(streaming_callback=self.data_received)

    def _authenticated(self):
        """Send BEGIN, mark authentication done and process buffered data."""
        self.stream.write(BEGIN)
        self.authentication.set_result(True)
        leftover = self.auth_parser.buffer
        self.data_received_post_auth(leftover)

    def data_received(self, data):
        """Streaming callback: route ``data`` by authentication state."""
        if not self.authentication.done():
            self.auth_parser.feed(data)
            if self.auth_parser.authenticated:
                self._authenticated()
            elif self.auth_parser.error:
                failure = AuthenticationError(self.auth_parser.error)
                self.authentication.set_exception(failure)
            return None

        return self.data_received_post_auth(data)

    def data_received_post_auth(self, data):
        """Parse ``data`` and hand each resulting message to the router."""
        messages = self.parser.feed(data)
        for message in messages:
            self.router.incoming(message)

    def send_message(self, message):
        """Serialise and send ``message``; return the router's future for it.

        Raises ``RuntimeError`` if authentication has not finished yet.
        """
        if not self.authentication.done():
            raise RuntimeError("Wait for authentication before sending messages")

        reply_future = self.router.outgoing(message)
        self.stream.write(message.serialise())
        return reply_future
예제 #50
0
    def test_indexing_line(self):
        client = AsyncHTTPClient(io_loop=self.io_loop)
        ping = yield client.fetch("http://*****:*****@version'], 1)
        self.assertEqual(doc['message'],
                         "My name is Yuri and I'm 6 years old.")
예제 #51
0
 def test_body_size_override_reset(self):
     """Two PUTs on one connection: the size override applies only to the first.

     The first request carries ``?expected_size=10240`` and must succeed;
     the second one omits it and must be rejected with a 400.
     """
     # The max_body_size override is reset between requests.
     stream = IOStream(socket.socket())
     try:
         # Connect over loopback, like the other raw-stream tests; the port
         # comes from the test's own server.
         yield stream.connect(("127.0.0.1", self.get_http_port()))
         # Use a raw stream so we can make sure it's all on one connection.
         stream.write(b"PUT /streaming?expected_size=10240 HTTP/1.1\r\n"
                      b"Content-Length: 10240\r\n\r\n")
         stream.write(b"a" * 10240)
         start_line, headers, response = yield read_stream_body(stream)
         self.assertEqual(response, b"10240")
         # Without the ?expected_size parameter, we get the old default value
         stream.write(b"PUT /streaming HTTP/1.1\r\n"
                      b"Content-Length: 10240\r\n\r\n")
         with ExpectLog(gen_log, ".*Content-Length too long"):
             data = yield stream.read_until_close()
         self.assertEqual(data, b"HTTP/1.1 400 Bad Request\r\n\r\n")
     finally:
         stream.close()
예제 #52
0
 def test_body_size_override_reset(self):
     """Two PUTs on one connection: the size override applies only to the first."""
     # The max_body_size override is reset between requests.
     conn = IOStream(socket.socket())
     try:
         yield conn.connect(('127.0.0.1', self.get_http_port()))
         # A raw stream guarantees both requests share one connection.
         first_request = (b'PUT /streaming?expected_size=10240 HTTP/1.1\r\n'
                          b'Content-Length: 10240\r\n\r\n')
         conn.write(first_request)
         conn.write(b'a' * 10240)
         headers, response = yield gen.Task(read_stream_body, conn)
         self.assertEqual(response, b'10240')
         # Without the ?expected_size parameter, we get the old default value
         second_request = (b'PUT /streaming HTTP/1.1\r\n'
                           b'Content-Length: 10240\r\n\r\n')
         conn.write(second_request)
         with ExpectLog(gen_log, '.*Content-Length too long'):
             remainder = yield conn.read_until_close()
         self.assertEqual(remainder, b'')
     finally:
         conn.close()
예제 #53
0
class _HTTPConnection(object):
    """One HTTP request/response exchange over an IOStream.

    The constructor parses the request URL, resolves the host, opens a plain
    or SSL stream and starts connecting; ``_on_connect`` then writes the
    request and starts reading the response headers.  Failures are reported
    to ``final_callback`` as an ``HTTPResponse`` with code 599.

    ``_on_headers`` and ``_on_close`` are referenced here but defined outside
    this excerpt.
    """
    # Methods accepted unless the request sets allow_nonstandard_methods.
    _SUPPORTED_METHODS = set(["GET", "HEAD", "POST", "PUT", "DELETE"])

    def __init__(self, io_loop, client, request, release_callback,
                 final_callback, max_buffer_size):
        """Set up state and begin connecting to the request's host.

        Args:
            io_loop: IOLoop used for the stream and for timeouts.
            client: owning client; its ``hostname_mapping`` is consulted.
            request: the request to send.
            release_callback: called once (see ``_release``) before the
                final callback runs.
            final_callback: called once with the ``HTTPResponse``.
            max_buffer_size: passed through to the stream.

        Raises:
            ValueError: for an unsupported URL scheme, or for https when the
                ``ssl`` module is unavailable.
        """
        self.start_time = time.time()
        self.io_loop = io_loop
        self.client = client
        self.request = request
        self.release_callback = release_callback
        self.final_callback = final_callback
        self.code = None
        self.headers = None
        self.chunks = None
        self._decompressor = None
        # Timeout handle returned by IOLoop.add_timeout
        self._timeout = None
        # Everything below runs inside the cleanup() context manager, which
        # turns an uncaught exception into a 599 response.
        with stack_context.StackContext(self.cleanup):
            parsed = urlparse.urlsplit(_unicode(self.request.url))
            if ssl is None and parsed.scheme == "https":
                raise ValueError("HTTPS requires either python2.6+ or "
                                 "curl_httpclient")
            if parsed.scheme not in ("http", "https"):
                raise ValueError("Unsupported url scheme: %s" %
                                 self.request.url)
            # urlsplit results have hostname and port results, but they
            # didn't support ipv6 literals until python 2.7.
            netloc = parsed.netloc
            if "@" in netloc:
                # Drop the "user:password@" part; credentials are read from
                # parsed.username / parsed.password in _on_connect.
                userpass, _, netloc = netloc.rpartition("@")
            match = re.match(r'^(.+):(\d+)$', netloc)
            if match:
                host = match.group(1)
                port = int(match.group(2))
            else:
                # No explicit port: use the scheme's default.
                host = netloc
                port = 443 if parsed.scheme == "https" else 80
            if re.match(r'^\[.*\]$', host):
                # raw ipv6 addresses in urls are enclosed in brackets
                host = host[1:-1]
            if self.client.hostname_mapping is not None:
                # Optional host override table supplied by the client.
                host = self.client.hostname_mapping.get(host, host)

            if request.allow_ipv6:
                af = socket.AF_UNSPEC
            else:
                # We only try the first IP we get from getaddrinfo,
                # so restrict to ipv4 by default.
                af = socket.AF_INET

            addrinfo = socket.getaddrinfo(host, port, af, socket.SOCK_STREAM,
                                          0, 0)
            af, socktype, proto, canonname, sockaddr = addrinfo[0]

            if parsed.scheme == "https":
                # Build the SSL options from the request's settings.
                ssl_options = {}
                if request.validate_cert:
                    ssl_options["cert_reqs"] = ssl.CERT_REQUIRED
                if request.ca_certs is not None:
                    ssl_options["ca_certs"] = request.ca_certs
                else:
                    ssl_options["ca_certs"] = _DEFAULT_CA_CERTS
                if request.client_key is not None:
                    ssl_options["keyfile"] = request.client_key
                if request.client_cert is not None:
                    ssl_options["certfile"] = request.client_cert
                self.stream = SSLIOStream(socket.socket(af, socktype, proto),
                                          io_loop=self.io_loop,
                                          ssl_options=ssl_options,
                                          max_buffer_size=max_buffer_size)
            else:
                self.stream = IOStream(socket.socket(af, socktype, proto),
                                       io_loop=self.io_loop,
                                       max_buffer_size=max_buffer_size)
            # While connecting, the shorter of the two timeouts applies;
            # _on_connect replaces it with the request timeout.
            timeout = min(request.connect_timeout, request.request_timeout)
            if timeout:
                self._timeout = self.io_loop.add_timeout(
                    self.start_time + timeout,
                    self._on_timeout)
            self.stream.set_close_callback(self._on_close)
            self.stream.connect(sockaddr,
                                functools.partial(self._on_connect, parsed))

    def _on_timeout(self):
        """Report a 599 "Timeout" response and close the stream."""
        self._timeout = None
        self._run_callback(HTTPResponse(self.request, 599,
                                        request_time=time.time() - self.start_time,
                                        error=HTTPError(599, "Timeout")))
        self.stream.close()

    def _on_connect(self, parsed):
        """Validate the request, send it, and start reading the response.

        Args:
            parsed: the ``urlsplit`` result for the request URL.

        Raises:
            KeyError: for a method outside ``_SUPPORTED_METHODS`` when
                nonstandard methods are not allowed.
            NotImplementedError: if a network-interface or proxy option is
                set on the request.
            ValueError: if a header line contains a newline.
        """
        # Swap the connect-phase timeout for the whole-request timeout,
        # still measured from start_time.
        if self._timeout is not None:
            self.io_loop.remove_timeout(self._timeout)
            self._timeout = None
        if self.request.request_timeout:
            self._timeout = self.io_loop.add_timeout(
                self.start_time + self.request.request_timeout,
                self._on_timeout)
        if (self.request.validate_cert and
            isinstance(self.stream, SSLIOStream)):
            # Check the peer certificate against the URL's hostname.
            match_hostname(self.stream.socket.getpeercert(),
                           parsed.hostname)
        if (self.request.method not in self._SUPPORTED_METHODS and
            not self.request.allow_nonstandard_methods):
            raise KeyError("unknown method %s" % self.request.method)
        # Options this implementation does not support.
        for key in ('network_interface',
                    'proxy_host', 'proxy_port',
                    'proxy_username', 'proxy_password'):
            if getattr(self.request, key, None):
                raise NotImplementedError('%s not supported' % key)
        if "Host" not in self.request.headers:
            self.request.headers["Host"] = parsed.netloc
        # Credentials in the URL take precedence over auth_username /
        # auth_password on the request.
        username, password = None, None
        if parsed.username is not None:
            username, password = parsed.username, parsed.password
        elif self.request.auth_username is not None:
            username = self.request.auth_username
            password = self.request.auth_password
        if username is not None:
            auth = utf8(username) + b(":") + utf8(password)
            self.request.headers["Authorization"] = (b("Basic ") +
                                                     base64.b64encode(auth))
        if self.request.user_agent:
            self.request.headers["User-Agent"] = self.request.user_agent
        if not self.request.allow_nonstandard_methods:
            # POST and PUT must carry a body; other methods must not.
            if self.request.method in ("POST", "PUT"):
                assert self.request.body is not None
            else:
                assert self.request.body is None
        if self.request.body is not None:
            self.request.headers["Content-Length"] = str(len(
                    self.request.body))
        if (self.request.method == "POST" and
            "Content-Type" not in self.request.headers):
            self.request.headers["Content-Type"] = "application/x-www-form-urlencoded"
        if self.request.use_gzip:
            self.request.headers["Accept-Encoding"] = "gzip"
        # Request target: path (default "/") plus the query string, if any.
        req_path = ((parsed.path or '/') +
                (('?' + parsed.query) if parsed.query else ''))
        request_lines = [utf8("%s %s HTTP/1.1" % (self.request.method,
                                                  req_path))]
        for k, v in self.request.headers.get_all():
            line = utf8(k) + b(": ") + utf8(v)
            # Reject header lines that contain a newline.
            if b('\n') in line:
                raise ValueError('Newline in header: ' + repr(line))
            request_lines.append(line)
        self.stream.write(b("\r\n").join(request_lines) + b("\r\n\r\n"))
        if self.request.body is not None:
            self.stream.write(self.request.body)
        # The header block ends at the first blank line (CRLF or bare LF).
        self.stream.read_until_regex(b("\r?\n\r?\n"), self._on_headers)

    def _release(self):
        """Call ``release_callback`` at most once."""
        if self.release_callback is not None:
            # Clear the attribute before calling so a repeated call is a no-op.
            release_callback = self.release_callback
            self.release_callback = None
            release_callback()

    def _run_callback(self, response):
        """Release the connection, then deliver ``response`` at most once."""
        self._release()
        if self.final_callback is not None:
            final_callback = self.final_callback
            self.final_callback = None
            final_callback(response)

    @contextlib.contextmanager
    def cleanup(self):
        """Context manager that turns an exception into a 599 response.

        The exception is logged as a warning, passed to the final callback
        as the response's ``error``, and not re-raised.
        """
        try:
            yield
        except Exception, e:
            logging.warning("uncaught exception", exc_info=True)
            self._run_callback(HTTPResponse(self.request, 599, error=e, 
                                request_time=time.time() - self.start_time,
                                ))
예제 #54
0
파일: connection.py 프로젝트: 736907871/zyl
class Connection(object):
    """Callback-driven client connection wrapped around a Tornado ``IOStream``.

    The connection keeps two collections of callbacks:

    * ``read_callbacks`` -- callbacks of reads/writes that are still in
      flight on the stream;
    * ``ready_callbacks`` -- callbacks queued by ``wait_until_ready`` that
      run one at a time once no read is outstanding.
    """
    def __init__(self,
                 host='localhost',
                 port=6379,
                 unix_socket_path=None,
                 event_handler_proxy=None,
                 stop_after=None,
                 io_loop=None):
        """Store connection parameters; no socket is opened here.

        :param host: TCP host, used when ``unix_socket_path`` is not set.
        :param port: TCP port.
        :param unix_socket_path: path of a UNIX socket; takes precedence
            over ``host``/``port`` when truthy.
        :param event_handler_proxy: optional object whose ``on_connect`` /
            ``on_disconnect`` methods are invoked by ``fire_event``.
        :param stop_after: socket timeout passed to ``settimeout``.
        :param io_loop: IOLoop handed to the ``IOStream``.
        """
        self.host = host
        self.port = port
        self.unix_socket_path = unix_socket_path
        self._event_handler = event_handler_proxy
        self.timeout = stop_after
        self._stream = None
        self._io_loop = io_loop

        self.in_progress = False
        self.read_callbacks = set()
        self.ready_callbacks = deque()
        self._lock = 0
        self.info = {'db': 0, 'pass': None}

    def __del__(self):
        # Best-effort close of the stream when the object is collected.
        self.disconnect()

    def execute_pending_command(self):
        """Run the next queued ready-callback if no read is outstanding."""
        # Continue with the pending command execution
        # if all read operations are completed.
        if not self.read_callbacks and self.ready_callbacks:
            # Pop a SINGLE callback from the queue and execute it.
            # The next one will be executed from the code
            # invoked by the callback
            callback = self.ready_callbacks.popleft()
            callback()

    def ready(self):
        """Return True when nothing is in flight and nothing is queued."""
        return (not self.read_callbacks and not self.ready_callbacks)

    def wait_until_ready(self, callback=None):
        """Call ``callback`` now if the connection is ready, else queue it.

        A falsy ``callback`` makes this a no-op.
        """
        if callback:
            if not self.ready():
                callback = stack_context.wrap(callback)
                self.ready_callbacks.append(callback)
            else:
                callback()

    def connect(self):
        """Open the socket and wrap it in an ``IOStream`` if not connected.

        :raises ConnectionError: when the socket cannot be created or
            connected; the half-open socket is closed before raising.
        """
        if not self._stream:
            sock = None
            try:
                if self.unix_socket_path:
                    sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
                    sock.settimeout(self.timeout)
                    sock.connect(self.unix_socket_path)
                else:
                    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
                    sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
                    sock.settimeout(self.timeout)
                    sock.connect((self.host, self.port))
                self._stream = IOStream(sock, io_loop=self._io_loop)
                self._stream.set_close_callback(self.on_stream_close)
                self.info['db'] = 0
                self.info['pass'] = None
            except socket.error as e:
                # Do not leak the file descriptor of a socket that never
                # made it into a stream.
                if sock is not None and not self._stream:
                    sock.close()
                raise ConnectionError(str(e))
            self.fire_event('on_connect')

    def on_stream_close(self):
        """Stream close hook: disconnect and flush pending read callbacks.

        Every callback still registered in ``read_callbacks`` is invoked
        with no arguments after the set has been reset.
        """
        if self._stream:
            self.disconnect()
            callbacks = self.read_callbacks
            self.read_callbacks = set()
            for callback in callbacks:
                callback()

    def disconnect(self):
        """Shut down and close the stream, ignoring errors while doing so."""
        if self._stream:
            s = self._stream
            # Clear the attribute first so re-entrant calls see "closed".
            self._stream = None
            try:
                if s.socket:
                    s.socket.shutdown(socket.SHUT_RDWR)
                s.close()
            except Exception:
                # Deliberately best-effort: the peer may already be gone.
                # ``Exception`` (not a bare except) lets KeyboardInterrupt
                # and SystemExit propagate.
                pass

    def fire_event(self, event):
        """Call method ``event`` on the event handler, if there is one.

        An ``AttributeError`` (e.g. the handler lacks that method) is
        swallowed.
        """
        event_handler = self._event_handler
        if event_handler:
            try:
                getattr(event_handler, event)()
            except AttributeError:
                pass

    def write(self, data, callback=None):
        """Write ``data`` to the stream.

        :param data: payload; under ``PY3`` text is encoded as UTF-8, while
            data that is already binary is passed through unchanged.
        :param callback: optional callable invoked with ``None`` once the
            stream reports the write as done.
        :raises ConnectionError: when there is no stream, or when the
            stream raises ``IOError`` (the connection is dropped first).
        """
        if not self._stream:
            raise ConnectionError('Tried to write to '
                                  'non-existent connection')

        if callback:
            callback = stack_context.wrap(callback)
            _callback = lambda: callback(None)
            self.read_callbacks.add(_callback)
            cb = partial(self.read_callback, _callback)
        else:
            cb = None
        try:
            if PY3 and isinstance(data, str):
                data = data.encode('utf-8')
            self._stream.write(data, callback=cb)
        except IOError as e:
            self.disconnect()
            # Exceptions have no ``.message`` attribute on Python 3.
            raise ConnectionError(str(e))

    def read(self, length, callback=None):
        """Read exactly ``length`` bytes and pass them to ``callback``.

        :raises ConnectionError: when there is no stream.

        An ``IOError`` from the stream is not raised; it fires the
        ``on_disconnect`` event instead.
        """
        try:
            if not self._stream:
                self.disconnect()
                raise ConnectionError('Tried to read from '
                                      'non-existent connection')
            callback = stack_context.wrap(callback)
            self.read_callbacks.add(callback)
            self._stream.read_bytes(length,
                                    callback=partial(self.read_callback,
                                                     callback))
        except IOError:
            self.fire_event('on_disconnect')

    def read_callback(self, callback, *args, **kwargs):
        """Unregister ``callback`` from the in-flight set, then invoke it."""
        try:
            self.read_callbacks.remove(callback)
        except KeyError:
            pass
        callback(*args, **kwargs)

    def readline(self, callback=None):
        """Read up to and including ``CRLF`` and pass the data to ``callback``.

        :raises ConnectionError: when there is no stream.

        An ``IOError`` from the stream is not raised; it fires the
        ``on_disconnect`` event instead.
        """
        try:
            if not self._stream:
                self.disconnect()
                raise ConnectionError('Tried to read from '
                                      'non-existent connection')
            callback = stack_context.wrap(callback)
            self.read_callbacks.add(callback)
            callback = partial(self.read_callback, callback)
            self._stream.read_until(CRLF, callback=callback)
        except IOError:
            self.fire_event('on_disconnect')

    def connected(self):
        """Return True when a stream is currently attached."""
        return bool(self._stream)
예제 #55
0
class BaseTornadoClient(AsyncModbusClientMixin):
    """
    Base Tornado client.

    Owns the ``IOStream`` used to exchange framed packets and resolves one
    ``Future`` per transaction id as replies arrive.
    """
    stream = None
    io_loop = None

    def __init__(self, *args, **kwargs):
        """
        Initializes BaseTornadoClient.
        ioloop to be passed as part of kwargs ('ioloop')
        :param args: forwarded to the parent initializer
        :param kwargs: forwarded to the parent initializer after ``ioloop``
            has been removed
        """
        self.io_loop = kwargs.pop("ioloop", None)
        super(BaseTornadoClient, self).__init__(*args, **kwargs)

    @abc.abstractmethod
    def get_socket(self):
        """
        return instance of the socket to connect to
        """

    @gen.coroutine
    def connect(self):
        """
        Connect to the socket identified by host and port

        :returns: Future
        :rtype: tornado.concurrent.Future
        """
        conn = self.get_socket()
        self.stream = IOStream(conn, io_loop=self.io_loop or IOLoop.current())
        # NOTE(review): the result of stream.connect() is not awaited, so
        # the client is flagged as connected before the connection attempt
        # has finished -- confirm this is intended.
        self.stream.connect((self.host, self.port))
        # Every chunk received until the stream closes goes to on_receive.
        self.stream.read_until_close(None, streaming_callback=self.on_receive)
        self._connected = True
        LOGGER.debug("Client connected")

        raise gen.Return(self)

    def on_receive(self, *args):
        """
        On data receive call back
        :param args: data received (first positional argument)
        :return: None
        """
        data = args[0] if len(args) > 0 else None

        if not data:
            return
        LOGGER.debug("recv: " + hexlify_packets(data))
        unit = self.framer.decode_data(data).get("unit", 0)
        self.framer.processIncomingPacket(data,
                                          self._handle_response,
                                          unit=unit)

    def execute(self, request=None):
        """
        Executes a transaction
        :param request: request to frame and send
        :return: Future resolved with the reply; when the client has no
            stream the Future already holds a ConnectionException
        """
        if self.stream is None:
            # Without a stream there is nothing to write to; report the
            # problem through the returned future instead of failing with
            # an AttributeError on ``None.write``.
            failed = Future()
            failed.set_exception(
                ConnectionException("Client is not connected"))
            return failed

        request.transaction_id = self.transaction.getNextTID()
        packet = self.framer.buildPacket(request)
        LOGGER.debug("send: " + hexlify_packets(packet))
        self.stream.write(packet)
        return self._build_response(request.transaction_id)

    def _handle_response(self, reply, **kwargs):
        """
        Handle response received
        :param reply: decoded reply, or None
        :param kwargs: unused
        :return: None
        """
        if reply is not None:
            tid = reply.transaction_id
            future = self.transaction.getTransaction(tid)
            if future:
                future.set_result(reply)
            else:
                LOGGER.debug("Unrequested message: {}".format(reply))

    def _build_response(self, tid):
        """
        Builds a future response
        :param tid: transaction id the future is registered under
        :return: Future; already failed with ConnectionException when the
            client is not connected
        """
        f = Future()

        if not self._connected:
            f.set_exception(ConnectionException("Client is not connected"))
            return f

        self.transaction.addTransaction(f, tid)
        return f

    def close(self):
        """
        Closes the underlying IOStream
        """
        LOGGER.debug("Client disconnected")
        if self.stream:
            # close() is the public IOStream API; close_fd() is an internal
            # hook that Tornado documents as not to be called by users.
            self.stream.close()

        self.stream = None
        self._connected = False
예제 #56
0
class KeepAliveTest(AsyncHTTPTestCase):
    """Tests various scenarios for HTTP 1.1 keep-alive support.

    These tests don't use AsyncHTTPClient because we want to control
    connection reuse and closing.
    """
    def get_app(self):
        class HelloHandler(RequestHandler):
            def get(self):
                self.finish("Hello world")

            def post(self):
                self.finish("Hello world")

        class LargeHandler(RequestHandler):
            def get(self):
                # 512KB should be bigger than the socket buffers so it will
                # be written out in chunks.
                self.write("".join(chr(i % 256) * 1024 for i in range(512)))

        class FinishOnCloseHandler(RequestHandler):
            def initialize(self, cleanup_event):
                self.cleanup_event = cleanup_event

            @gen.coroutine
            def get(self):
                self.flush()
                yield self.cleanup_event.wait()

            def on_connection_close(self):
                # This is not very realistic, but finishing the request
                # from the close callback has the right timing to mimic
                # some errors seen in the wild.
                self.finish("closed")

        self.cleanup_event = Event()
        return Application([
            ("/", HelloHandler),
            ("/large", LargeHandler),
            (
                "/finish_on_close",
                FinishOnCloseHandler,
                dict(cleanup_event=self.cleanup_event),
            ),
        ])

    def setUp(self):
        super(KeepAliveTest, self).setUp()
        self.http_version = b"HTTP/1.1"

    def tearDown(self):
        # We just closed the client side of the socket; let the IOLoop run
        # once to make sure the server side got the message.
        self.io_loop.add_timeout(datetime.timedelta(seconds=0.001), self.stop)
        self.wait()

        if hasattr(self, "stream"):
            self.stream.close()
        super(KeepAliveTest, self).tearDown()

    # The next few methods are a crude manual http client
    @gen.coroutine
    def connect(self):
        """Open a raw client stream to the test HTTP server."""
        self.stream = IOStream(socket.socket())
        # The test server listens on the loopback interface, so the client
        # has to dial 127.0.0.1; any other address never reaches it.
        yield self.stream.connect(("127.0.0.1", self.get_http_port()))

    @gen.coroutine
    def read_headers(self):
        """Read the status line and headers; return the parsed headers."""
        first_line = yield self.stream.read_until(b"\r\n")
        self.assertTrue(first_line.startswith(b"HTTP/1.1 200"), first_line)
        header_bytes = yield self.stream.read_until(b"\r\n\r\n")
        headers = HTTPHeaders.parse(header_bytes.decode("latin1"))
        raise gen.Return(headers)

    @gen.coroutine
    def read_response(self):
        """Read one full response and check that its body is the greeting."""
        self.headers = yield self.read_headers()
        body = yield self.stream.read_bytes(int(
            self.headers["Content-Length"]))
        self.assertEqual(b"Hello world", body)

    def close(self):
        """Close the client stream and forget it so tearDown skips it."""
        self.stream.close()
        del self.stream

    @gen_test
    def test_two_requests(self):
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.1\r\n\r\n")
        yield self.read_response()
        self.stream.write(b"GET / HTTP/1.1\r\n\r\n")
        yield self.read_response()
        self.close()

    @gen_test
    def test_request_close(self):
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.1\r\nConnection: close\r\n\r\n")
        yield self.read_response()
        data = yield self.stream.read_until_close()
        self.assertTrue(not data)
        self.assertEqual(self.headers["Connection"], "close")
        self.close()

    # keepalive is supported for http 1.0 too, but it's opt-in
    @gen_test
    def test_http10(self):
        self.http_version = b"HTTP/1.0"
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.0\r\n\r\n")
        yield self.read_response()
        data = yield self.stream.read_until_close()
        self.assertTrue(not data)
        self.assertTrue("Connection" not in self.headers)
        self.close()

    @gen_test
    def test_http10_keepalive(self):
        self.http_version = b"HTTP/1.0"
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.stream.write(b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.close()

    @gen_test
    def test_http10_keepalive_extra_crlf(self):
        self.http_version = b"HTTP/1.0"
        yield self.connect()
        self.stream.write(
            b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.stream.write(b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.close()

    @gen_test
    def test_pipelined_requests(self):
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\n")
        yield self.read_response()
        yield self.read_response()
        self.close()

    @gen_test
    def test_pipelined_cancel(self):
        yield self.connect()
        self.stream.write(b"GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\n")
        # only read once
        yield self.read_response()
        self.close()

    @gen_test
    def test_cancel_during_download(self):
        yield self.connect()
        self.stream.write(b"GET /large HTTP/1.1\r\n\r\n")
        yield self.read_headers()
        yield self.stream.read_bytes(1024)
        self.close()

    @gen_test
    def test_finish_while_closed(self):
        yield self.connect()
        self.stream.write(b"GET /finish_on_close HTTP/1.1\r\n\r\n")
        yield self.read_headers()
        self.close()
        # Let the hanging coroutine clean up after itself
        self.cleanup_event.set()

    @gen_test
    def test_keepalive_chunked(self):
        self.http_version = b"HTTP/1.0"
        yield self.connect()
        self.stream.write(b"POST / HTTP/1.0\r\n"
                          b"Connection: keep-alive\r\n"
                          b"Transfer-Encoding: chunked\r\n"
                          b"\r\n"
                          b"0\r\n"
                          b"\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.stream.write(b"GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n")
        yield self.read_response()
        self.assertEqual(self.headers["Connection"], "Keep-Alive")
        self.close()
예제 #57
0
class HTTPServerRawTest(AsyncHTTPTestCase):
    """Sends hand-written bytes to the HTTP server over a raw IOStream."""
    def get_app(self):
        return Application([("/echo", EchoHandler)])

    def setUp(self):
        super(HTTPServerRawTest, self).setUp()
        self.stream = IOStream(socket.socket())
        # The test server listens on the loopback interface, so the client
        # has to dial 127.0.0.1; any other address never reaches it.
        self.io_loop.run_sync(lambda: self.stream.connect(
            ("127.0.0.1", self.get_http_port())))

    def tearDown(self):
        self.stream.close()
        super(HTTPServerRawTest, self).tearDown()

    def test_empty_request(self):
        self.stream.close()
        # Give the IOLoop a moment to process the closed connection.
        self.io_loop.add_timeout(datetime.timedelta(seconds=0.001), self.stop)
        self.wait()

    def test_malformed_first_line_response(self):
        with ExpectLog(gen_log, ".*Malformed HTTP request line"):
            self.stream.write(b"asdf\r\n\r\n")
            start_line, headers, response = self.io_loop.run_sync(
                lambda: read_stream_body(self.stream))
            self.assertEqual("HTTP/1.1", start_line.version)
            self.assertEqual(400, start_line.code)
            self.assertEqual("Bad Request", start_line.reason)

    def test_malformed_first_line_log(self):
        with ExpectLog(gen_log, ".*Malformed HTTP request line"):
            self.stream.write(b"asdf\r\n\r\n")
            # TODO: need an async version of ExpectLog so we don't need
            # hard-coded timeouts here.
            self.io_loop.add_timeout(datetime.timedelta(seconds=0.05),
                                     self.stop)
            self.wait()

    def test_malformed_headers(self):
        with ExpectLog(gen_log,
                       ".*Malformed HTTP message.*no colon in header line"):
            self.stream.write(b"GET / HTTP/1.0\r\nasdf\r\n\r\n")
            self.io_loop.add_timeout(datetime.timedelta(seconds=0.05),
                                     self.stop)
            self.wait()

    def test_chunked_request_body(self):
        # Chunked requests are not widely supported and we don't have a way
        # to generate them in AsyncHTTPClient, but HTTPServer will read them.
        self.stream.write(b"""\
POST /echo HTTP/1.1
Transfer-Encoding: chunked
Content-Type: application/x-www-form-urlencoded

4
foo=
3
bar
0

""".replace(b"\n", b"\r\n"))
        start_line, headers, response = self.io_loop.run_sync(
            lambda: read_stream_body(self.stream))
        self.assertEqual(json_decode(response), {u"foo": [u"bar"]})

    def test_chunked_request_uppercase(self):
        # As per RFC 2616 section 3.6, "Transfer-Encoding" header's value is
        # case-insensitive.
        self.stream.write(b"""\
POST /echo HTTP/1.1
Transfer-Encoding: Chunked
Content-Type: application/x-www-form-urlencoded

4
foo=
3
bar
0

""".replace(b"\n", b"\r\n"))
        start_line, headers, response = self.io_loop.run_sync(
            lambda: read_stream_body(self.stream))
        self.assertEqual(json_decode(response), {u"foo": [u"bar"]})

    @gen_test
    def test_invalid_content_length(self):
        with ExpectLog(gen_log, ".*Only integer Content-Length is allowed"):
            self.stream.write(b"""\
POST /echo HTTP/1.1
Content-Length: foo

bar

""".replace(b"\n", b"\r\n"))
            yield self.stream.read_until_close()
예제 #58
0
class KeepAliveTest(AsyncHTTPTestCase):
    """Exercises HTTP keep-alive handling through a hand-rolled client.

    AsyncHTTPClient is deliberately avoided so that each test decides
    exactly when a connection is reused or closed.
    """
    def get_app(self):
        class HelloHandler(RequestHandler):
            def get(self):
                self.finish('Hello world')

        class LargeHandler(RequestHandler):
            def get(self):
                # 512 blocks of 1KB: larger than the socket buffers, so the
                # body has to be written out in several chunks.
                blocks = (chr(n % 256) * 1024 for n in range(512))
                self.write(''.join(blocks))

        class FinishOnCloseHandler(RequestHandler):
            @asynchronous
            def get(self):
                self.flush()

            def on_connection_close(self):
                # Unrealistic, but finishing from the close callback
                # reproduces the timing of errors seen in the wild.
                self.finish('closed')

        routes = [
            ('/', HelloHandler),
            ('/large', LargeHandler),
            ('/finish_on_close', FinishOnCloseHandler),
        ]
        return Application(routes)

    def setUp(self):
        super(KeepAliveTest, self).setUp()
        self.http_version = b'HTTP/1.1'

    def tearDown(self):
        # The client side was just closed; spin the IOLoop once so the
        # server side notices before the fixture is torn down.
        delay = datetime.timedelta(seconds=0.001)
        self.io_loop.add_timeout(delay, self.stop)
        self.wait()

        if hasattr(self, 'stream'):
            self.stream.close()
        super(KeepAliveTest, self).tearDown()

    # Crude manual HTTP client helpers follow.
    def connect(self):
        """Open a raw stream to the test server and wait until connected."""
        self.stream = IOStream(socket.socket(), io_loop=self.io_loop)
        address = ('localhost', self.get_http_port())
        self.stream.connect(address, self.stop)
        self.wait()

    def read_headers(self):
        """Check the status line, then read and parse the header block."""
        self.stream.read_until(b'\r\n', self.stop)
        status_line = self.wait()
        expected_prefix = self.http_version + b' 200'
        self.assertTrue(status_line.startswith(expected_prefix),
                        status_line)
        self.stream.read_until(b'\r\n\r\n', self.stop)
        raw_headers = self.wait()
        return HTTPHeaders.parse(raw_headers.decode('latin1'))

    def read_response(self):
        """Read one full response and verify its body is the greeting."""
        self.headers = self.read_headers()
        length = int(self.headers['Content-Length'])
        self.stream.read_bytes(length, self.stop)
        self.assertEqual(b'Hello world', self.wait())

    def close(self):
        """Close the stream and drop the attribute so tearDown skips it."""
        self.stream.close()
        del self.stream

    def test_two_requests(self):
        request = b'GET / HTTP/1.1\r\n\r\n'
        self.connect()
        self.stream.write(request)
        self.read_response()
        self.stream.write(request)
        self.read_response()
        self.close()

    def test_request_close(self):
        self.connect()
        self.stream.write(b'GET / HTTP/1.1\r\nConnection: close\r\n\r\n')
        self.read_response()
        self.stream.read_until_close(callback=self.stop)
        leftover = self.wait()
        self.assertTrue(not leftover)
        self.close()

    # keepalive is supported for http 1.0 too, but it's opt-in
    def test_http10(self):
        self.http_version = b'HTTP/1.0'
        self.connect()
        self.stream.write(b'GET / HTTP/1.0\r\n\r\n')
        self.read_response()
        self.stream.read_until_close(callback=self.stop)
        leftover = self.wait()
        self.assertTrue(not leftover)
        self.assertTrue('Connection' not in self.headers)
        self.close()

    def test_http10_keepalive(self):
        request = b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n'
        self.http_version = b'HTTP/1.0'
        self.connect()
        self.stream.write(request)
        self.read_response()
        self.assertEqual(self.headers['Connection'], 'Keep-Alive')
        self.stream.write(request)
        self.read_response()
        self.assertEqual(self.headers['Connection'], 'Keep-Alive')
        self.close()

    def test_pipelined_requests(self):
        self.connect()
        self.stream.write(b'GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\n')
        self.read_response()
        self.read_response()
        self.close()

    def test_pipelined_cancel(self):
        self.connect()
        self.stream.write(b'GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\n')
        # Only the first of the two responses is consumed.
        self.read_response()
        self.close()

    def test_cancel_during_download(self):
        self.connect()
        self.stream.write(b'GET /large HTTP/1.1\r\n\r\n')
        self.read_headers()
        self.stream.read_bytes(1024, self.stop)
        self.wait()
        self.close()

    def test_finish_while_closed(self):
        self.connect()
        self.stream.write(b'GET /finish_on_close HTTP/1.1\r\n\r\n')
        self.read_headers()
        self.close()
예제 #59
0
파일: irc.py 프로젝트: kaleozhou/labots
class IRC(object):
    """IRC client driven by a Tornado ``IOStream`` and periodic callbacks.

    NOTE: the attributes below are class-level defaults. The mutable ones
    (lists and dicts) are shared by every instance that does not rebind
    them.
    """
    # Private
    # IOStream carrying the server connection.
    _stream = None
    # Charset used to encode outgoing and decode incoming data.
    _charset = None
    # IOLoop the stream and the timers run on.
    _ioloop = None
    # PeriodicCallback that runs `_keep_alive`.
    _timer = None
    # time.time() stamp that `_keep_alive` compares against its timeout.
    _last_pong = None
    # Set to 1 by `_reconnect`; `_on_login` skips `login_callback` when set.
    _is_reconnect = 0
    # Encoded outgoing data queued by `_period_send`, drained by `_sock_send`.
    _buffers = []
    # PeriodicCallback that runs `_sock_send`.
    _send_timer = None

    host = None
    port = None
    nick = None
    # Channels this client has joined.
    chans = []
    # NOTE(review): not used in the visible part of the class -- presumably
    # a per-channel reference count; confirm against join/part.
    chans_ref = {}
    # Maps a channel name to the set of nicks known to be in it.
    names = {}
    # Nicks treated as relay bots when dispatching messages.
    relaybots = []
    # (opening, closing) delimiter pairs that wrap the original sender's
    # nick at the start of a message forwarded by a relay bot.
    delims = [
        ('<', '> '),
        ('[', '] '),
        ('(', ') '),
        ('{', '} '),
    ]

    # External callbacks
    # Called once you are logged in
    login_callback = None
    # Called when a matching IRC message is received;
    # for usage of event_callback, see `botbox.dispatch`
    event_callback = None

    def __init__(self,
                 host,
                 port,
                 nick,
                 relaybots=None,
                 charset='utf-8',
                 ioloop=False):
        """Start connecting to the server and arm the periodic timers.

        :param host: server host name.
        :param port: server port.
        :param nick: nickname to log in with.
        :param relaybots: nicks treated as relay bots; defaults to an
            empty list.
        :param charset: encoding for outgoing and incoming data.
        :param ioloop: IOLoop to use; a falsy value selects
            ``IOLoop.instance()``.
        """
        logger.info('Connecting to %s:%s', host, port)

        self.host = host
        self.port = port
        self.nick = nick
        # A fresh list per instance: a mutable default argument would be
        # shared by every client created without this parameter.
        self.relaybots = relaybots if relaybots is not None else []

        # Rebind the mutable class-level defaults so that several clients
        # in one process do not share channel state or the send queue.
        self._buffers = []
        self.chans = []
        self.chans_ref = {}
        self.names = {}

        self._charset = charset
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self._ioloop = ioloop or IOLoop.instance()
        self._stream = IOStream(sock, io_loop=self._ioloop)
        self._stream.connect((host, port), self._login)

        self._last_pong = time.time()
        # Check once a minute whether the connection has timed out.
        self._timer = PeriodicCallback(self._keep_alive,
                                       60 * 1000,
                                       io_loop=self._ioloop)
        self._timer.start()

        # Send at most one queued message every 600 ms.
        self._send_timer = PeriodicCallback(self._sock_send,
                                            600,
                                            io_loop=self._ioloop)
        self._send_timer.start()

    def _sock_send(self):
        if (self._buffers[0:]):
            data = self._buffers.pop(0)
            return self._stream.write(data)

    def _period_send(self, data):
        # Data will be sent in `self._sock_send()`
        self._buffers.append(bytes(data, self._charset))

    def _sock_recv(self):
        def _recv(data):
            msg = data.decode(self._charset, 'ignore')
            msg = msg[:-2]  # strip '\r\n'
            self._recv(msg)

        try:
            self._stream.read_until(b'\r\n', _recv)
        except Exception as err:
            logger.error('Read error: %s', err)
            self._reconnect()

    def _reconnect(self):
        """Close the current stream and start connecting with a new one."""
        logger.info('Reconnecting...')

        # Remembered so that the next login is handled as a reconnect.
        self._is_reconnect = 1

        self._stream.close()
        fresh_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self._stream = IOStream(fresh_sock, io_loop=self._ioloop)
        address = (self.host, self.port)
        self._stream.connect(address, self._login)

    # IRC message parser, return tuple (IRCMsgType, IRCMsg)
    def _parse(self, msg):
        """Parse one raw IRC line.

        :param msg: the line without its trailing CRLF.
        :return: ``(IRCMsgType, IRCMsg or None)``. PING, ``NOTICE AUTH``
            and ERROR lines give their own type with ``None``; a line that
            does not split into prefix, command and params gives
            ``(IRCMsgType.UNKNOW, None)``; anything else gives
            ``(IRCMsgType.MSG, ircmsg)``.
        """
        if msg.startswith('PING :'):
            logger.debug('PING')
            return (IRCMsgType.PING, None)
        elif msg.startswith('NOTICE AUTH :'):
            # The message has to be passed along for the placeholder.
            logger.debug('NOTIC AUTH: "%s"', msg)
            return (IRCMsgType.NOTICE, None)
        elif msg.startswith('ERROR :'):
            logger.debug('ERROR: "%s"', msg)
            return (IRCMsgType.ERROR, None)

        try:
            # <message> ::= [':' <prefix> <SPACE> ] <command> <params> <crlf>
            tmp = msg.split(' ', maxsplit=2)

            if len(tmp) != 3:
                raise Exception(
                    'Failed when parsing <prefix> <command> <params>')

            prefix, command, params = tmp
            logger.debug('prefix: "%s", command: "%s", params: "%s"', prefix,
                         command, params)

            # <params> ::= <SPACE> [ ':' <trailing> | <middle> <params> ]
            middle, _, trailing = params.partition(' :')
            if middle.startswith(':'):
                # The params consist of a trailing part only.
                trailing = middle[1:]
                middle = ''
            logger.debug('middle: "%s", trailing: "%s"', middle, trailing)

            if not middle and not trailing:
                middle = trailing = ''
                # raise Exception('No <middle> and <trailing>')

            # <middle> ::= <Any *non-empty* sequence of octets not including SPACE
            #              or NUL or CR or LF, the first of which may not be ':'>
            args = middle.split(' ')
            logger.debug('args: "%s"', args)

            # <prefix> ::= <servername> | <nick> [ '!' <user> ] [ '@' <host> ]
            tmp = prefix
            nick, _, tmp = tmp.partition('!')
            user, _, host = tmp.partition('@')
            # One placeholder per argument; a mismatch makes the logging
            # module report a formatting error instead of the record.
            logger.debug('nick: "%s", user: "%s", host: "%s"', nick, user,
                         host)

        except Exception as err:
            logger.error('Parsing error: %s', err)
            logger.error('    Message: %s', repr(msg))
            return (IRCMsgType.UNKNOW, None)
        else:
            ircmsg = IRCMsg()
            ircmsg.nick = nick[1:]  # strip ':'
            ircmsg.user = user
            ircmsg.host = host
            ircmsg.cmd = command
            ircmsg.args = args
            ircmsg.msg = trailing
            return (IRCMsgType.MSG, ircmsg)

    # Response server message
    def _resp(self, type_, ircmsg):
        """Update client state in response to a parsed server message.

        :param type_: the ``IRCMsgType`` returned by ``_parse``.
        :param ircmsg: the parsed message (only read for ``MSG``).

        PING is answered through ``_pong``; for ``MSG`` the joined-channel
        list, the per-channel nick sets and the client's own nick are kept
        up to date.
        """
        if type_ == IRCMsgType.PING:
            self._pong()
        elif type_ == IRCMsgType.ERROR:
            pass
        elif type_ == IRCMsgType.MSG:
            if ircmsg.cmd == RPL_WELCOME:
                self._on_login(ircmsg.args[0])
            elif ircmsg.cmd == ERR_NICKNAMEINUSE:
                new_nick = ircmsg.args[1] + '_'
                logger.info('Nick already in use, use "%s"', new_nick)
                self._chnick(new_nick)
            elif ircmsg.cmd == 'JOIN':
                # The channel is either a middle argument or the trailing
                # part of the message.
                chan = ircmsg.args[0] or ircmsg.msg
                if ircmsg.nick == self.nick:
                    self.chans.append(chan)
                    self.names[chan] = set()
                    logger.info('%s has joined %s', self.nick, chan)
                self.names[chan].add(ircmsg.nick)
            elif ircmsg.cmd == 'PART':
                chan = ircmsg.args[0]
                try:
                    self.names[chan].remove(ircmsg.nick)
                except KeyError as err:
                    logger.error('KeyError: %s', err)
                    logger.error('%s %s %s %s %s %s', ircmsg.nick, ircmsg.user,
                                 ircmsg.host, ircmsg.cmd, ircmsg.args,
                                 ircmsg.msg)
                if ircmsg.nick == self.nick:
                    self.chans.remove(chan)
                    self.names[chan].clear()
                    logger.info('%s has left %s', self.nick, ircmsg.args[0])
            elif ircmsg.cmd == 'NICK':
                new_nick, old_nick = ircmsg.msg, ircmsg.nick
                for chan in self.chans:
                    if old_nick in self.names[chan]:
                        self.names[chan].remove(old_nick)
                        self.names[chan].add(new_nick)
                if old_nick == self.nick:
                    # Our own nick changed: follow it, otherwise later
                    # comparisons against self.nick use a stale name.
                    self.nick = new_nick
                    logger.info('%s is now known as %s', old_nick, new_nick)
            elif ircmsg.cmd == 'QUIT':
                nick = ircmsg.nick
                for chan in self.chans:
                    if nick in self.names[chan]:
                        self.names[chan].remove(nick)
            elif ircmsg.cmd == RPL_NAMREPLY:
                chan = ircmsg.args[2]
                # Skip empty tokens (e.g. from a trailing space) before
                # looking at the first character, and drop a leading
                # '@' / '+' prefix from each name.
                names_list = [
                    x[1:] if x[0] in ['@', '+'] else x
                    for x in ircmsg.msg.split(' ') if x
                ]
                # A reply for a channel we hold no set for must not raise.
                self.names.setdefault(chan, set()).update(names_list)
                logger.debug('NAMES: %s', names_list)

    def _dispatch(self, type_, ircmsg):
        """Forward a parsed message to ``event_callback``.

        :param type_: the ``IRCMsgType`` returned by ``_parse``; anything
            other than ``MSG`` is ignored.
        :param ircmsg: the parsed message.

        Events emitted: the raw command for JOIN/PART/QUIT/NICK,
        ``ACTION``/``PRIVMSG``/``NOTICE`` for chat messages, followed by
        ``LABOTS_MSG`` and, when the message starts with this client's
        nick, ``LABOTS_MENTION_MSG``.
        """
        if type_ != IRCMsgType.MSG:
            return

        # Error message
        if ircmsg.cmd[0] in ['4', '5']:
            # Logger.warn is a deprecated alias of Logger.warning.
            logger.warning('Error message: %s', ircmsg.msg)
        elif ircmsg.cmd in ['JOIN', 'PART']:
            nick, chan = ircmsg.nick, ircmsg.args[0] or ircmsg.msg
            self.event_callback(ircmsg.cmd, chan, nick)
        elif ircmsg.cmd == 'QUIT':
            nick, reason = ircmsg.nick, ircmsg.msg
            for chan in self.chans:
                if nick in self.names[chan]:
                    self.event_callback(ircmsg.cmd, chan, nick, reason)
        elif ircmsg.cmd == 'NICK':
            new_nick, old_nick = ircmsg.msg, ircmsg.nick
            for chan in self.chans:
                if old_nick in self.names[chan]:
                    self.event_callback(ircmsg.cmd, chan, old_nick, new_nick)
        elif ircmsg.cmd in ['PRIVMSG', 'NOTICE']:
            if ircmsg.msg.startswith('\x01ACTION '):
                # Drop the leading marker and the final character.
                msg = ircmsg.msg[len('\x01ACTION '):-1]
                cmd = 'ACTION'
            else:
                msg, cmd = ircmsg.msg, ircmsg.cmd
            nick, target = ircmsg.nick, ircmsg.args[0]
            self.event_callback(cmd, target, nick, msg)

            # LABOTS_MSG = ACTION or PRIVMSG or NOTICE
            # And it will:
            # - Strip IRC color codes
            # - Replace relaybot's nick with human's nick
            bot = ''
            msg = strip(msg)
            if nick in self.relaybots:
                for opener, closer in self.delims:
                    end = msg.find(closer)
                    if msg.startswith(opener) and end != -1:
                        bot = nick
                        nick = msg[len(opener):end]
                        msg = msg[end + len(closer):]
                        break
            self.event_callback('LABOTS_MSG', target, bot, nick, msg)

            # LABOTS_MENTION_MSG = LABOTS_MSG + labots's nick is mentioned at
            # the head of message
            words = msg.split(' ', maxsplit=1)
            if words[0] in [self.nick + x for x in ['', ':', ',']]:
                if words[1:]:
                    msg = words[1]
                    self.event_callback('LABOTS_MENTION_MSG', target, bot,
                                        nick, msg)

    def _keep_alive(self):
        """Reconnect when more than 360 seconds passed since the last pong.

        After reconnecting, ``_last_pong`` is reset to the current time.
        """
        idle = time.time() - self._last_pong
        if idle <= 360:
            return

        # Ping time out
        logger.error('Ping time out')

        self._reconnect()
        self._last_pong = time.time()

    def _recv(self, msg):
        """Process one received line, then ask for the next one.

        A falsy ``msg`` is not parsed. Otherwise the parsed result is given
        to ``_dispatch`` and then to ``_resp``. ``_sock_recv`` is called
        afterwards in either case.

        :param msg: raw message handed over by the receive path
        """
        if msg:
            parsed = self._parse(msg)
            kind, message = parsed
            self._dispatch(kind, message)
            self._resp(kind, message)

        self._sock_recv()

    def _chnick(self, nick):
        """Hand a NICK command for ``nick`` to ``_period_send``."""
        command = 'NICK %s\r\n' % nick
        self._period_send(command)

    def _on_login(self, nick):
        """Finish a login: store the nick and (re)join the known channels.

        ``login_callback`` is invoked only when this is not a reconnect.
        ``self.chans`` is reset to an empty list and every channel it held
        before is joined again with ``force=True``.

        :param nick: nick we are logged in as
        """
        logger.info('You are logined as %s', nick)

        self.nick = nick
        # Grab the current channel list before the callback runs and before
        # the attribute is reset below
        previous = self.chans

        first_login = not self._is_reconnect
        if first_login:
            self.login_callback()

        self.chans = []
        for chan in previous:
            self.join(chan, force=True)

    def _login(self):
        """Send the NICK and USER commands for ``self.nick``, then start
        receiving via ``_sock_recv``.
        """
        logger.info('Try to login as "%s"', self.nick)

        self._chnick(self.nick)

        user_fields = (self.nick, 'labots', 'localhost',
                       'https://github.com/SilverRainZ/labots')
        self._period_send('USER %s %s %s %s\r\n' % user_fields)

        self._sock_recv()

    def _pong(self):
        """Record the current time as the last pong and send a PONG reply."""
        logger.debug('Pong!')

        self._last_pong = time.time()
        # IRC lines are terminated by CR-LF; the other commands sent by this
        # class already use '\r\n', so PONG does the same instead of a bare LF
        self._period_send('PONG :labots!\r\n')

    def set_callback(self,
                     login_callback=empty_callback,
                     event_callback=empty_callback):
        """Install the hooks this client calls.

        :param login_callback: stored as ``self.login_callback``
        :param event_callback: stored as ``self.event_callback``
        """
        self.event_callback = event_callback
        self.login_callback = login_callback

    def join(self, chan, force=False):
        """Join ``chan``, keeping a reference count per channel.

        Names that are empty or do not start with ``#`` or ``&`` are ignored.
        Without ``force``, a channel already present in ``self.chans_ref``
        only has its counter incremented and no JOIN is sent; a new channel
        gets a counter of 1. With ``force`` the counters are left untouched
        and the JOIN command is always sent.

        :param chan: channel name
        :param force: send JOIN regardless of the reference count
        """
        # An empty name has no prefix to inspect; treat it like any other
        # invalid channel name instead of raising IndexError
        if not chan or chan[0] not in ['#', '&']:
            return

        if not force:
            if chan in self.chans_ref:
                self.chans_ref[chan] += 1
                return
            self.chans_ref[chan] = 1

        logger.debug('Try to join %s', chan)
        self._period_send('JOIN %s\r\n' % chan)

    def part(self, chan):
        """Release one reference to ``chan`` and leave it on the last one.

        Names that are empty, do not start with ``#`` or ``&``, or are not
        present in ``self.chans_ref`` are ignored. While the counter is not 1
        it is only decremented; when it is 1 the entry is removed and a PART
        command is sent.

        :param chan: channel name
        """
        # An empty name has no prefix to inspect; ignore it rather than
        # raising IndexError
        if not chan or chan[0] not in ['#', '&']:
            return
        if chan not in self.chans_ref:
            return

        if self.chans_ref[chan] != 1:
            self.chans_ref[chan] -= 1
            return

        self.chans_ref.pop(chan, None)

        logger.debug('Try to part %s', chan)
        self._period_send('PART %s\r\n' % chan)

    # recv_msg: whether each sent line is also reported to event_callback
    def send(self, target, msg, recv_msg=True):
        """Send ``msg`` to ``target`` as one PRIVMSG per ``'\\n'``-separated line.

        :param target: destination of the PRIVMSG commands
        :param msg: text to send; split on ``'\\n'``
        :param recv_msg: when true, each line is also passed to
            ``event_callback`` as a ``'PRIVMSG'`` event from ``self.nick``
        """
        for text in msg.split('\n'):
            self._period_send('PRIVMSG %s :%s\r\n' % (target, text))
            if not recv_msg:
                continue
            # Report our own line to the event callback as well
            self.event_callback('PRIVMSG', target, self.nick, text)

    def action(self, target, msg):
        """Send ``msg`` to ``target`` as an ACTION wrapped in ``\\x01`` markers.

        The action is also reported to ``event_callback`` as an ``'ACTION'``
        event from ``self.nick``.

        :param target: destination of the PRIVMSG command
        :param msg: action text
        """
        # Fill in the placeholders; previously the raw template was sent
        self._period_send('PRIVMSG %s :\1ACTION %s\1\r\n' % (target, msg))
        # You will recv the message you sent
        self.event_callback('ACTION', target, self.nick, msg)

    def topic(self, chan, topic):
        """Hand a TOPIC command setting ``topic`` on ``chan`` to ``_period_send``."""
        command = 'TOPIC %s :%s\r\n' % (chan, topic)
        self._period_send(command)

    def kick(self, chan, nick, reason):
        """Hand a KICK command to ``_period_send``.

        :param chan: channel to kick from
        :param nick: nick to kick
        :param reason: text sent as the trailing part of the command
        """
        # Use the ``reason`` parameter; the name ``topic`` is not defined here
        self._period_send('KICK %s %s :%s\r\n' % (chan, nick, reason))

    def quit(self, reason='食饭'):
        """Hand a QUIT command carrying ``reason`` to ``_period_send``."""
        farewell = 'QUIT :%s\r\n' % reason
        self._period_send(farewell)

    def stop(self):
        """Send QUIT with the default reason, then close ``self._stream``."""
        logger.info('Stop')
        self.quit()
        stream = self._stream
        stream.close()
예제 #60
0
class _RedisConnection(object):
    """A single redis connection driven by an ``IOStream``.

    Commands written to the connection are tracked in a queue of
    ``(future, connect_count, transaction, cmd_count)`` entries; incoming
    data is decoded against that queue in order and each decoded result is
    delivered through the final callback on the IO loop.
    """

    def __init__(self, io_loop, write_buf, final_callback, redis_tuple,
                 redis_pass):
        """
        :param io_loop: IO loop used for the stream and for running callbacks
        :param write_buf: data for the first write
        :param final_callback: called with the response dict
        :param redis_tuple: (ip, port, db)
        :param redis_pass: redis password, or None when no AUTH is needed
        """
        self.__io_loop = io_loop
        self.__final_cb = final_callback
        self.__stream = None
        # Remainder left over from parsing redis replies
        self.__recv_buf = ''
        self.__write_buf = write_buf

        init_buf = ''
        init_buf = chain_select_cmd(redis_tuple[2], init_buf)
        if redis_pass is None:
            self.__init_buf = (init_buf, )
        else:
            assert redis_pass and isinstance(redis_pass, str)
            # AUTH goes in front of the SELECT command
            self.__init_buf = (redis_auth(redis_pass), init_buf)

        self.__haspass = redis_pass is not None
        self.__init_buf = ''.join(self.__init_buf)

        self.__connect_state = CONNECT_INIT
        # Redis command context: future, number of connect commands
        # (AUTH, SELECT etc.), transaction flag, cmd_count
        self.__cmd_env = deque()
        self.__written = False

    def connect(self, init_future, redis_tuple, active_trans, cmd_count):
        """Open the stream and start connecting; no-op if a stream exists.

        :param init_future: the first future object
        :param redis_tuple: (ip, port, db)
        :param active_trans: whether a transaction is active
        :param cmd_count: number of commands
        """
        if self.__stream is not None:
            return

        # future, connect_count, transaction, cmd_count
        # First entry covers the initial commands (SELECT, plus AUTH when a
        # password is set); second entry covers the first user write.
        self.__cmd_env.append((init_future, 1 + int(self.__haspass), False, 0))
        self.__cmd_env.append((init_future, 0, active_trans, cmd_count))

        with ExceptionStackContext(self.__handle_ex):
            self.__stream = IOStream(socket.socket(socket.AF_INET,
                                                   socket.SOCK_STREAM, 0),
                                     io_loop=self.__io_loop)
            self.__stream.set_close_callback(self.__on_close)
            self.__stream.connect(redis_tuple[:2], self.__on_connect)
            self.__connect_state = CONNECT_ING

    def connect_state(self):
        """Return the current connection state value."""
        return self.__connect_state

    def write(self,
              write_buf,
              new_future,
              include_select,
              active_trans,
              cmd_count,
              by_connect=False):
        """Write commands to the stream, or buffer them while connecting.

        With ``by_connect`` true, only the initial commands and any buffered
        data are flushed and the other arguments are ignored.

        :param write_buf: data to write
        :param new_future: because of closures, the response callback would
            otherwise keep the previous future object; it must be updated
        :param include_select: whether a SELECT command is included
        :param active_trans: whether a transaction is active
        :param cmd_count: number of commands
        :param by_connect: true when called from the connect callback
        """
        if by_connect:
            self.__stream.write(self.__init_buf)
            self.__init_buf = None

            if self.__write_buf:
                self.__stream.write(self.__write_buf)
                self.__write_buf = None
            return

        self.__cmd_env.append(
            (new_future, int(include_select), active_trans, cmd_count))
        if self.__connect_state == CONNECT_ING:
            # Not connected yet: keep accumulating until __on_connect flushes
            self.__write_buf = ''.join((self.__write_buf, write_buf))
            return

        if self.__write_buf:
            write_buf = ''.join((self.__write_buf, write_buf))

        self.__stream.write(write_buf)
        self.__write_buf = None

    def __on_connect(self):
        """Connected: only the initial commands need to be sent."""
        self.__connect_state = CONNECT_SUCC
        self.__stream.set_nodelay(True)
        self.write(None, None, None, None, None, True)
        # No final callback; every chunk is streamed to __on_resp
        self.__stream.read_until_close(None, self.__on_resp)

    def __on_resp(self, recv):
        """Decode received data against the pending command entries.

        :param recv: the received buffer
        """
        recv = ''.join((self.__recv_buf, recv))

        idx = 0
        for future, connect_count, trans, count in self.__cmd_env:
            ok, payload, recv = decode_resp_ondemand(recv, connect_count,
                                                     trans, count)
            if not ok:
                # Stop at the first entry that could not be decoded
                # (presumably the reply is still incomplete)
                break

            idx += 1
            if count > 0:
                self.__run_callback({
                    _RESP_FUTURE: future,
                    RESP_RESULT: payload
                })

        self.__recv_buf = recv
        # Drop the entries that were decoded; ``range`` rather than the
        # Python 2-only ``xrange`` so this also runs on Python 3
        for _ in range(idx):
            self.__cmd_env.popleft()

    def __on_close(self):
        """Reset the state on close and report a stream error, if any."""
        self.__connect_state = CONNECT_INIT
        if self.__final_cb and self.__stream.error:
            self.__run_callback({RESP_ERR: self.__stream.error})

    def __run_callback(self, resp):
        """Schedule the final callback with ``resp`` on the IO loop.

        Does nothing when no final callback was given.
        """
        if self.__final_cb is None:
            return

        self.__io_loop.add_callback(self.__final_cb, resp)

    def __handle_ex(self, typ, value, tb):
        """Exception handler passed to ``ExceptionStackContext``.

        :param typ: exception type
        :param value: exception instance, reported under ``RESP_ERR``
        :param tb: traceback (unused)
        :return: True when the exception was reported through the final
            callback, False otherwise
        """
        if self.__final_cb:
            self.__run_callback({RESP_ERR: value})
            return True
        return False