def _get_conn(self):  # -> Future[connection]
    """Return a Future that yields a connection.

    Resolution order:

    1. Reuse a pooled connection; pooled connections older than
       ``self.max_recycle_sec`` are closed and skipped.
    2. Open a new connection when ``self.max_open`` is 0 or not yet reached.
    3. Otherwise queue a Future on ``self._waitings`` and return it
       unresolved.
    """
    started_at = self.io_loop.time()

    # Drain the free pool until a connection that is still fresh turns up.
    while self._free_conn:
        candidate = self._free_conn.popleft()
        if not (started_at - candidate.connected_time > self.max_recycle_sec):
            log.debug("Reusing connection from pool: %s", self.stat())
            ready = Future()
            ready.set_result(candidate)
            return ready
        # Too old: recycle it and keep looking.
        self._close_async(candidate)

    at_capacity = not (self.max_open == 0 or self._opened_conns < self.max_open)
    if at_capacity:
        # Nothing free and no room to open another one: park the caller.
        parked = Future()
        self._waitings.append(parked)
        return parked

    self._opened_conns += 1
    log.debug("Creating new connection: %s", self.stat())
    connecting = connect(**self.connect_kwargs)
    # self._on_connect does self._opened_conns -= 1 on exception
    connecting.add_done_callback(self._on_connect)
    return connecting
def put_task(self, inputs, callback=None):
    """Queue ``inputs`` and return a Future of the output.

    If ``callback`` is given it is registered as a done-callback on the
    Future before the task is put on ``self.input_queue``.
    """
    output_future = Future()
    if callback is not None:
        output_future.add_done_callback(callback)
    task = (inputs, output_future)
    self.input_queue.put(task)
    return output_future
def send_message(self, args, callback=None):
    """Write a command to the stream and return a Future for its reply.

    ``args[0]`` is the command name. Commands containing ``SUBSCRIBE`` are
    rejected. ``WATCH``/``MULTI`` are refused while the connection is
    shared, and the watch set / multi flag are tracked so that
    ``EXEC``/``DISCARD``/``UNWATCH`` can reset them.

    The Future's ``set_result`` is appended to ``self.callbacks``; if
    ``callback`` is given it is attached (stack-context wrapped) as a
    done-callback.
    """
    command = args[0]
    if 'SUBSCRIBE' in command:
        raise NotImplementedError('Not yet.')

    if command == 'WATCH' or command == 'MULTI':
        # Commands that affect the execution of other commands must not be
        # used on a shared connection.
        if self.is_shared():
            raise Exception('Command %s is not allowed while connection '
                            'is shared!' % command)
        if command == 'WATCH':
            self._watch.add(args[1])
        else:
            self._multi = True
    elif command == 'EXEC' or command == 'DISCARD':
        # Transaction finished: reset state so the connection can unlock.
        self._multi = False
        self._watch.clear()
    elif command == 'UNWATCH':
        self._watch.clear()

    payload = self.format_message(args)
    self.stream.write(payload)

    reply = Future()
    if callback is not None:
        reply.add_done_callback(stack_context.wrap(callback))
    self.callbacks.append(reply.set_result)
    return reply
class ManualCapClient(BaseCapClient):
    """Capitalization client that wires the Future up by hand.

    ``capitalize`` opens a stream to 127.0.0.1 on ``self.port`` and returns
    ``self.future``; the connect/read handlers below complete it.
    """

    def capitalize(self, request_data, callback=None):
        """Send ``request_data`` and return a Future of the processed reply.

        If ``callback`` is given it is invoked with the Future's result.
        """
        logging.debug("capitalize")
        self.request_data = request_data
        self.stream = IOStream(socket.socket())
        address = ('127.0.0.1', self.port)
        self.stream.connect(address, callback=self.handle_connect)
        self.future = Future()
        if callback is not None:
            def forward_result(future):
                return callback(future.result())

            self.future.add_done_callback(stack_context.wrap(forward_result))
        return self.future

    def handle_connect(self):
        """Write the request line, then wait for a newline-terminated reply."""
        logging.debug("handle_connect")
        request_line = utf8(self.request_data + "\n")
        self.stream.write(request_line)
        self.stream.read_until(b'\n', callback=self.handle_read)

    def handle_read(self, data):
        """Close the stream and resolve ``self.future`` from ``data``.

        A ``CapError`` becomes the Future's exception.
        """
        logging.debug("handle_read")
        self.stream.close()
        try:
            self.future.set_result(self.process_response(data))
        except CapError as cap_error:
            self.future.set_exception(cap_error)
def wait(self, timeout: Union[float, datetime.timedelta] = None) -> "Future[None]":
    """Block until the internal flag is true.

    Returns a Future, which raises `tornado.util.TimeoutError` after a
    timeout. With ``timeout=None`` the waiter Future itself is returned.
    """
    waiter = Future()  # type: Future[None]
    if self._value:
        waiter.set_result(None)
        return waiter

    def forget_waiter(done: "Future[None]") -> None:
        self._waiters.remove(done)

    self._waiters.add(waiter)
    waiter.add_done_callback(forget_waiter)
    if timeout is None:
        return waiter

    timed = gen.with_timeout(timeout, waiter, quiet_exceptions=(CancelledError,))

    def cancel_waiter(_timed: "Future[None]") -> None:
        # gen.with_timeout doesn't cancel its futures, so do it here;
        # cancelling the waiter removes it from the waiters set.
        if not waiter.done():
            waiter.cancel()

    timed.add_done_callback(cancel_waiter)
    return timed
def acquire(
    self, timeout: Union[float, datetime.timedelta] = None
) -> "Future[_ReleasingContextManager]":
    """Decrement the counter. Returns a Future.

    Block if the counter is zero and wait for a `.release`. The Future
    raises `.TimeoutError` after the deadline.
    """
    waiter = Future()  # type: Future[_ReleasingContextManager]
    if self._value > 0:
        # A slot is free: take it and resolve immediately.
        self._value -= 1
        waiter.set_result(_ReleasingContextManager(self))
        return waiter

    self._waiters.append(waiter)
    if not timeout:
        return waiter

    def expire() -> None:
        if not waiter.done():
            waiter.set_exception(gen.TimeoutError())
        self._garbage_collect()

    loop = ioloop.IOLoop.current()
    handle = loop.add_timeout(timeout, expire)

    def drop_timer(_done: "Future[_ReleasingContextManager]") -> None:
        # Once the waiter is done the timer is no longer needed.
        loop.remove_timeout(handle)

    waiter.add_done_callback(drop_timer)
    return waiter
def fetch(self, request, callback=None, raise_error=True, **kwargs):
    """Fetch ``request``, de-duplicating in-flight requests and using the cache.

    ``request`` may be an `HTTPRequest` or a URL; a URL is turned into an
    `HTTPRequest` using ``kwargs``.

    * If a request with the same cache key is already in flight, its Future
      is returned as-is (``callback`` is not invoked on this path).
    * If the cache has a response for the key, it is marked ``cached``,
      ``callback`` (if any) is scheduled on the IOLoop with it, and an
      already-resolved Future is returned.
    * Otherwise the request is delegated to ``orig_fetch`` and a successful
      response is saved to the cache when the Future completes.

    Returns a Future of the response.
    """
    if not isinstance(request, HTTPRequest):
        request = HTTPRequest(url=request, **kwargs)

    key = self.cache.create_key(request)

    # Check and return future if there is a pending request
    pending = self.pending_requests.get(key)
    if pending:
        return pending

    response = self.cache.get_response_and_time(key)
    if response:
        response.cached = True
        if callback:
            self.io_loop.add_callback(callback, response)
        future = Future()
        future.set_result(response)
        return future

    future = orig_fetch(self, request, callback, raise_error, **kwargs)
    self.pending_requests[key] = future

    def cache_response(future):
        # The request is no longer in flight. Without this, the finished
        # (or failed) Future would be handed out for this key forever and
        # the cache would never be consulted again. Only remove our own
        # entry in case the key has since been re-registered.
        if self.pending_requests.get(key) is future:
            del self.pending_requests[key]
        exc = future.exception()
        if exc is None:
            self.cache.save_response(key, future.result())

    future.add_done_callback(cache_response)
    return future
def _start_processing_requests(self):
    """Read newline-delimited JSON requests forever and dispatch them.

    Each request must be a JSON object with ``key``, ``method``, ``args``
    and ``kwargs``; malformed requests are logged and skipped. The handler
    result (or a raised exception) is wrapped in a Future if it is not one
    already, and ``self._on_future_finished`` is attached with the request
    key bound as its first argument.
    """
    while True:
        raw = yield gen.Task(self._stream.read_until, '\r\n')
        log.debug('New request: %r', raw)

        try:
            request = json.loads(raw)
            key = request['key']
            method = request['method']
            args = request['args']
            kwargs = request['kwargs']
        except (KeyError, ValueError):
            log.error('Malformed request data: %s', raw)
            continue

        try:
            outcome = self._handler(method, *args, **kwargs)
            if not isinstance(outcome, Future):
                # Plain return value: wrap it so both paths look the same.
                future = Future()
                future.set_result(outcome)
            else:
                future = outcome
        except Exception as error:
            log.exception('Failed to handle request: %s', key)
            future = concurrent.TracebackFuture()
            future.set_exception(error)

        on_finished = partial(self._on_future_finished, key)
        future.add_done_callback(on_finished)
def wrapper(*args, **kwargs):
    """Call ``f`` with a Future substituted for its callback argument.

    ``replacer.replace`` swaps the Future into the arguments and returns
    the callback it displaced (or ``None``). A displaced callback is run
    via ``_auth_future_to_callback`` when the Future completes. The Future
    is returned.
    """
    result_future = Future()
    user_callback, args, kwargs = replacer.replace(result_future, args, kwargs)
    if user_callback is not None:
        bridge = functools.partial(_auth_future_to_callback, user_callback)
        result_future.add_done_callback(bridge)
    f(*args, **kwargs)
    return result_future
def put_task(self, dp, callback=None):
    """Same as in :meth:`AsyncPredictorBase.put_task`.

    Puts ``(dp, future)`` on ``self.input_queue`` and returns the future;
    ``callback``, when given, is registered as its done-callback first.
    """
    pending = Future()
    if callback is not None:
        pending.add_done_callback(callback)
    entry = (dp, pending)
    self.input_queue.put(entry)
    return pending
def put_task(self, dp, callback=None):
    """Queue a single data point and return a Future for its result.

    ``dp`` must be non-batched, i.e. a single instance. ``callback``, when
    given, is attached as a done-callback before the task is queued.
    """
    result = Future()
    wants_callback = callback is not None
    if wants_callback:
        result.add_done_callback(callback)
    self.input_queue.put((dp, result))
    return result
def register(self, name, callback=None):
    """Send a ``register`` request for ``name`` and return a Future.

    The Future is stored in the request table under the new message id; a
    truthy ``callback`` is attached as its done-callback.
    """
    request_id = next(self._generator)
    # NOTE(review): the parameter is the bare ``name`` (the original wrote
    # ``(name)``, which is not a 1-tuple). If the peer expects a parameter
    # list, this should be ``(name,)`` -- confirm against the server.
    packet = self._pack_request(request_id, RPC_REGISTER, 'register', name)

    reply = Future()
    if callback:
        reply.add_done_callback(callback)
    self.add_request_table(request_id, reply)
    self.write(packet)
    return reply
def fetch(self, request, callback=None, raise_error=True, **kwargs):
    """Executes a request, asynchronously returning an `HTTPResponse`.

    The request may be either a string URL or an `HTTPRequest` object.
    A string URL is turned into an `HTTPRequest` using any additional
    kwargs: ``HTTPRequest(request, **kwargs)``. Passing kwargs together
    with an `HTTPRequest` raises `ValueError`.

    Returns a `.Future` whose result is an `HTTPResponse`. By default the
    ``Future`` raises the response's error (e.g. an `HTTPError` for a
    non-200 response); with ``raise_error=False`` the response is always
    returned regardless of the response code.

    If a ``callback`` is given, it is invoked with the `HTTPResponse`.
    In the callback interface `HTTPError` is not raised automatically;
    check the response's ``error`` attribute or call its
    `~HTTPResponse.rethrow` method.

    Raises `RuntimeError` if the client has been closed.
    """
    if self._closed:
        raise RuntimeError("fetch() called on closed AsyncHTTPClient")

    if isinstance(request, HTTPRequest):
        if kwargs:
            raise ValueError("kwargs can't be used if request is an HTTPRequest object")
    else:
        request = HTTPRequest(url=request, **kwargs)

    # Headers may be modified later (Host, Accept-Encoding, ...), so work
    # on a copy rather than the caller's object. This also converts plain
    # dicts into HTTPHeaders.
    request.headers = httputil.HTTPHeaders(request.headers)
    request = _RequestProxy(request, self.defaults)
    result_future = Future()

    if callback is not None:
        wrapped_callback = stack_context.wrap(callback)

        def notify_callback(done_future):
            error = done_future.exception()
            if error is None:
                response = done_future.result()
            elif isinstance(error, HTTPError) and error.response is not None:
                response = error.response
            else:
                # No response to hand over: synthesize a 599.
                elapsed = time.time() - request.start_time
                response = HTTPResponse(
                    request, 599, error=error, request_time=elapsed)
            self.io_loop.add_callback(wrapped_callback, response)

        result_future.add_done_callback(notify_callback)

    def deliver(response):
        if raise_error and response.error:
            result_future.set_exception(response.error)
            return
        future_set_result_unless_cancelled(result_future, response)

    self.fetch_impl(request, deliver)
    return result_future
def close(self):
    """Close the underlying cursor and return a Future.

    If there is no cursor, an already-resolved Future (result ``None``) is
    returned; otherwise the cursor's ``close`` is run through
    ``async_call_method``. In both cases ``self._cursor`` is reset to
    ``None`` and ``self._release_lock`` is attached as a done-callback.
    """
    if self._cursor is None:
        # Nothing to close. (Calling ``close()`` on the missing cursor
        # here would raise AttributeError.)
        future = Future()
        future.set_result(None)
    else:
        future = async_call_method(self._cursor.close)
    self._cursor = None
    future.add_done_callback(self._release_lock)
    return future
def put_task(self, inputs, callback=None):
    """Queue one data point and return a Future of its output.

    :params inputs: a data point (list of component) matching input_names
        (not batched)
    :params callback: a callback to get called with the list of outputs
    :returns: a Future of output.
    """
    future_output = Future()
    if callback is not None:
        future_output.add_done_callback(callback)
    queue_item = (inputs, future_output)
    self.input_queue.put(queue_item)
    return future_output
def call(self, method_name, *arg, **kwargs):
    """Send a request for ``method_name`` and return a Future.

    Positional arguments are packed as the request parameters. Only the
    ``callback`` keyword is read from ``kwargs``; when truthy it is
    attached as a done-callback on the returned Future.
    """
    on_done = kwargs.get('callback')
    request_id = next(self._generator)
    packet = self._pack_request(request_id, RPC_REQUEST, method_name, arg)

    reply = Future()
    self.add_request_table(request_id, reply)
    if on_done:
        reply.add_done_callback(on_done)
    self.write(packet)
    return reply
def feed_puppy(callback=None):
    """Run a short blocking job on a worker thread and return its Future.

    The worker sleeps 0.2 seconds and then resolves the Future with
    ``True``. If ``callback`` is truthy, it is scheduled on the IOLoop
    with the finished Future.
    """
    def do_thing(future):
        time.sleep(0.2)
        future.set_result(True)

    future = Future()
    if callback:
        # Resolve the IOLoop on the calling thread. The done-callback can
        # run on the worker thread, where IOLoop.current() is not
        # guaranteed to be the caller's loop; add_callback is the
        # thread-safe hand-off back to it.
        io_loop = tornado.ioloop.IOLoop.current()
        future.add_done_callback(
            lambda f: io_loop.add_callback(callback, f))
    t = threading.Thread(target=do_thing, args=(future,))
    t.start()
    return future
def _set_timeout(
    future: Future, timeout: Union[None, float, datetime.timedelta]
) -> None:
    """Fail ``future`` with `gen.TimeoutError` once ``timeout`` elapses.

    A falsy ``timeout`` does nothing. The timer is removed from the IOLoop
    as soon as ``future`` completes.
    """
    if not timeout:
        return

    def expire() -> None:
        if not future.done():
            future.set_exception(gen.TimeoutError())

    loop = ioloop.IOLoop.current()
    handle = loop.add_timeout(timeout, expire)

    def cancel_timer(_done: Future) -> None:
        loop.remove_timeout(handle)

    future.add_done_callback(cancel_timer)
def generate_challenge(self):
    """Returns authentication challenge message, with OAuth URL.

    A fresh nonce is generated and a Future is stored under it in
    ``self.session.properties.futures``. When that Future resolves, its
    result is passed to ``self.authenticate`` and the outcome is written
    to the session transport as JSON.
    """
    token = uuid.uuid4().hex

    def send_authentication(done):
        reply = self.authenticate(done.result())
        return self.session.transport.write_json(reply)

    auth_future = Future()
    auth_future.add_done_callback(send_authentication)
    self.session.properties.futures[token] = auth_future

    url = "http://robust.brendan.so/auth/twitter?robust_token=%s" % token
    return {"url": url}
def queue_request(queue, queue_name, **kwargs):
    """Send a request over ``queue`` and wait for the reply (generator coroutine).

    ``kwargs`` is the request payload. A default ``timeout`` of 600 seconds
    is added when missing, and a ``reply_to`` queue name derived from
    ``queue_name`` and ``id(kwargs)`` is always set. The payload is
    serialized with ``dumps`` and sent to ``queue_name``; the first reply
    on the ``reply_to`` queue is decoded as UTF-8, parsed with ``loads``
    and returned.

    Whatever ``with_timeout`` raises when ``timeout`` seconds pass without
    a reply propagates to the caller.
    """
    def queue_listener(queue_name, body):
        f.set_result(loads(body.decode('utf-8')))

    if 'timeout' not in kwargs:
        kwargs['timeout'] = 600
    kwargs['reply_to'] = '%s-reply-%s' % (queue_name, id(kwargs))
    yield queue.send(queue_name, dumps(kwargs))

    f = Future()
    queue.listen([kwargs['reply_to']], queue_listener, workers_count=1)
    # Normal path: stop listening as soon as the reply arrives.
    f.add_done_callback(lambda f: queue.stop([kwargs['reply_to']]))
    try:
        reply = yield with_timeout(timedelta(seconds=kwargs['timeout']), f)
    except Exception:
        # with_timeout does not resolve ``f`` on expiry, so the
        # done-callback above never fires and the reply listener would
        # stay registered. Stop it explicitly.
        if not f.done():
            queue.stop([kwargs['reply_to']])
        raise
    return reply
def wrapper(*args, **kwargs):
    """Call ``f`` with a Future in place of its callback; return the Future.

    ``replacer.replace`` substitutes the Future into the arguments and
    returns the displaced callback (or ``None``), which is then run via
    ``_auth_future_to_callback`` on completion. Exceptions escaping ``f``
    inside the stack context are stored on the Future unless it is already
    done.
    """
    result_future = Future()
    user_callback, args, kwargs = replacer.replace(result_future, args, kwargs)
    if user_callback is not None:
        bridge = functools.partial(_auth_future_to_callback, user_callback)
        result_future.add_done_callback(bridge)

    def capture_error(typ, value, tb):
        # Returning True marks the exception as handled.
        if not future_is_pending():
            return False
        future_set_exc_info(result_future, (typ, value, tb))
        return True

    def future_is_pending():
        return not result_future.done()

    with ExceptionStackContext(capture_error):
        f(*args, **kwargs)
    return result_future
def fetch(self, request, callback=None, **kwargs):
    """Executes a request, asynchronously returning an `HTTPResponse`.

    The request may be either a string URL or an `HTTPRequest` object.
    A string URL is turned into an `HTTPRequest` using any additional
    kwargs: ``HTTPRequest(request, **kwargs)``

    Returns a `~concurrent.futures.Future` whose result is an
    `HTTPResponse`. The ``Future`` will raise the response's error (e.g.
    an `HTTPError` for a non-200 response code) when one is set.

    If a ``callback`` is given, it is invoked with the `HTTPResponse`.
    In the callback interface `HTTPError` is not raised automatically;
    check the response's ``error`` attribute or call its
    `~HTTPResponse.rethrow` method.
    """
    if not isinstance(request, HTTPRequest):
        request = HTTPRequest(url=request, **kwargs)
    # Headers may be modified later (Host, Accept-Encoding, ...), so work
    # on a copy rather than the caller's object. This also converts plain
    # dicts into HTTPHeaders.
    request.headers = httputil.HTTPHeaders(request.headers)
    request = _RequestProxy(request, self.defaults)
    result_future = Future()

    if callback is not None:
        wrapped_callback = stack_context.wrap(callback)

        def notify_callback(done_future):
            error = done_future.exception()
            if error is None:
                response = done_future.result()
            elif isinstance(error, HTTPError) and error.response is not None:
                response = error.response
            else:
                # No response to hand over: synthesize a 599.
                elapsed = time.time() - request.start_time
                response = HTTPResponse(
                    request, 599, error=error, request_time=elapsed)
            self.io_loop.add_callback(wrapped_callback, response)

        result_future.add_done_callback(notify_callback)

    def deliver(response):
        if response.error:
            result_future.set_exception(response.error)
            return
        result_future.set_result(response)

    self.fetch_impl(request, deliver)
    return result_future
def wait(self, timeout=None):
    """Wait for `.notify`.

    Returns a `.Future` that resolves to ``True`` if the condition is
    notified, or to ``False`` after a timeout.

    ``timeout`` is handed to ``IOLoop.add_timeout``; a falsy value means
    no timeout is scheduled.
    """
    waiter = Future()
    self._waiters.append(waiter)
    if timeout:
        def on_timeout():
            # The waiter may already have been resolved by the time the
            # timer fires; resolving it a second time would raise.
            if not waiter.done():
                waiter.set_result(False)
            self._garbage_collect()

        io_loop = ioloop.IOLoop.current()
        timeout_handle = io_loop.add_timeout(timeout, on_timeout)
        # The timer is pointless once the waiter is done.
        waiter.add_done_callback(
            lambda _: io_loop.remove_timeout(timeout_handle))
    return waiter
def wait(self, timeout=None):
    """Wait for `.notify`.

    Returns a `.Future` that resolves ``True`` if the condition is
    notified, or ``False`` after a timeout.

    ``timeout`` is handed to ``IOLoop.add_timeout``; a falsy value means
    no timeout is scheduled.
    """
    waiter = Future()
    self._waiters.append(waiter)
    if timeout:
        def on_timeout():
            # The waiter may already have been resolved by the time the
            # timer fires; resolving it a second time would raise.
            if not waiter.done():
                waiter.set_result(False)
            self._garbage_collect()

        io_loop = ioloop.IOLoop.current()
        timeout_handle = io_loop.add_timeout(timeout, on_timeout)
        # The timer is pointless once the waiter is done.
        waiter.add_done_callback(
            lambda _: io_loop.remove_timeout(timeout_handle))
    return waiter
def wrapper(*args, **kwargs):
    """Call ``f`` with a Future in place of its callback; return the Future.

    ``replacer.replace`` substitutes the Future into the arguments and
    returns the displaced callback (or ``None``). Supplying a callback
    emits a `DeprecationWarning`; the callback is still run, wrapped, via
    ``_auth_future_to_callback`` when the Future completes. Exceptions
    escaping ``f`` inside the stack context are stored on the Future
    unless it is already done.
    """
    result_future = Future()
    user_callback, args, kwargs = replacer.replace(result_future, args, kwargs)
    if user_callback is not None:
        warnings.warn(
            "callback arguments are deprecated, "
            "use the returned Future instead",
            DeprecationWarning)
        bridge = functools.partial(_auth_future_to_callback, user_callback)
        result_future.add_done_callback(wrap(bridge))

    def capture_error(typ, value, tb):
        # Returning True marks the exception as handled.
        if result_future.done():
            return False
        future_set_exc_info(result_future, (typ, value, tb))
        return True

    with ExceptionStackContext(capture_error, delay_warning=True):
        f(*args, **kwargs)
    return result_future
def _get_frame(self, timeout=None):
    """Return a Future for the next frame.

    A queued frame resolves the Future immediately. Otherwise the Future
    is stored as ``self._frame_future``; when ``timeout`` is not ``None``
    a timer ``timeout`` seconds past the loop's current time fails it with
    ``_TimeoutException``.
    """
    # NOTE(review): the pending Future is stored on ``self._frame_future``
    # only on the empty-queue path -- the nesting was reconstructed from
    # collapsed source; confirm against the original file.
    frame_future = Future()
    if self._frame_queue:
        frame_future.set_result(self._frame_queue.popleft())
        return frame_future

    if timeout is not None:
        def expire():
            frame_future.set_exception(_TimeoutException())

        deadline = self._ioloop.time() + timeout
        timer = self._ioloop.add_timeout(deadline, expire)
        # Drop the timer as soon as the Future completes.
        frame_future.add_done_callback(
            lambda _: self._ioloop.remove_timeout(timer))
    self._frame_future = frame_future
    return frame_future
def get_event(self, request, tag='', callback=None, ):
    '''
    Get an event (async of course).

    Returns a Future that is registered under ``tag`` in ``self.tag_map``
    and, as a ``(tag, future)`` pair, under ``request`` in
    ``self.request_map``. If ``callback`` is given it is scheduled on the
    IOLoop with the Future's result once the Future completes.
    '''
    event_future = Future()

    if callback is not None:
        def relay_result(done):
            self.io_loop.add_callback(callback, done.result())

        event_future.add_done_callback(relay_result)

    # Remember the future so it can be found by tag and by request.
    self.tag_map[tag].append(event_future)
    registration = (tag, event_future)
    self.request_map[request].append(registration)
    return event_future
def fetch(self, request, callback=None, **kwargs):
    """Executes a request, calling callback with an `HTTPResponse`.

    The request may be either a string URL or an `HTTPRequest` object.
    A string URL is turned into an `HTTPRequest` using any additional
    kwargs: ``HTTPRequest(request, **kwargs)``

    If an error occurs during the fetch, the HTTPResponse given to the
    callback has a non-None error attribute that contains the exception
    encountered during the request. You can call response.rethrow() to
    throw the exception (if any) in the callback.

    A Future is also returned; it resolves to the response, or raises the
    response's error when one is set.
    """
    if not isinstance(request, HTTPRequest):
        request = HTTPRequest(url=request, **kwargs)
    # Headers may be modified later (Host, Accept-Encoding, ...), so work
    # on a copy rather than the caller's object. This also converts plain
    # dicts into HTTPHeaders.
    request.headers = httputil.HTTPHeaders(request.headers)
    request = _RequestProxy(request, self.defaults)
    response_future = Future()

    if callback is not None:
        user_callback = stack_context.wrap(callback)

        def run_callback(finished):
            failure = finished.exception()
            has_http_response = (
                isinstance(failure, HTTPError) and failure.response is not None)
            if has_http_response:
                response = failure.response
            elif failure is not None:
                # No response to hand over: synthesize a 599.
                response = HTTPResponse(
                    request, 599, error=failure,
                    request_time=time.time() - request.start_time)
            else:
                response = finished.result()
            self.io_loop.add_callback(user_callback, response)

        response_future.add_done_callback(run_callback)

    def on_response(response):
        if not response.error:
            response_future.set_result(response)
        else:
            response_future.set_exception(response.error)

    self.fetch_impl(request, on_response)
    return response_future
def wait(self, timeout: Union[float, datetime.timedelta] = None) -> "Future[bool]":
    """Wait for `.notify`.

    Returns a `.Future` that resolves ``True`` if the condition is
    notified, or ``False`` after a timeout.
    """
    waiter = Future()  # type: Future[bool]
    self._waiters.append(waiter)
    if not timeout:
        return waiter

    def expire() -> None:
        if not waiter.done():
            future_set_result_unless_cancelled(waiter, False)
        self._garbage_collect()

    loop = ioloop.IOLoop.current()
    handle = loop.add_timeout(timeout, expire)

    def cancel_timer(_done: "Future[bool]") -> None:
        # The timer is pointless once the waiter is done.
        loop.remove_timeout(handle)

    waiter.add_done_callback(cancel_timer)
    return waiter
def acquire(self, timeout=None):
    """Decrement the counter. Returns a Future.

    Block if the counter is zero and wait for a `.release`. The Future
    raises `.TimeoutError` after the deadline.

    ``timeout`` is handed to ``IOLoop.add_timeout``; a falsy value means
    no timeout is scheduled.
    """
    waiter = Future()
    if self._value > 0:
        self._value -= 1
        waiter.set_result(_ReleasingContextManager(self))
    else:
        self._waiters.append(waiter)
        if timeout:
            def on_timeout():
                # The waiter may already have been resolved by the time
                # the timer fires; failing it a second time would raise.
                if not waiter.done():
                    waiter.set_exception(gen.TimeoutError())
                self._garbage_collect()

            io_loop = ioloop.IOLoop.current()
            timeout_handle = io_loop.add_timeout(timeout, on_timeout)
            # The timer is pointless once the waiter is done.
            waiter.add_done_callback(
                lambda _: io_loop.remove_timeout(timeout_handle))
    return waiter
def execute_code(self, code, result_extractor=None, done_callback=None, timeout=None):
    """
    Asynchronously execute the given code using the underlying managed kernel client
    Note: this method is not synchronized, it is the responsibility of the caller to synchronize using the lock member variable
    e.g.
        with (yield managed_client.lock.acquire()):
            yield managed_client.execute_code( code )

    The application's ``prepend_execute_code`` is prepended to ``code``
    before it is sent for execution.

    Parameters
    ----------
    code : String
        Python code to be executed

    result_extractor : function [Optional]
        Called when the code has finished executing to extract the results into the returned Future.
        Defaults to ``self._result_extractor``.

    done_callback : function [Optional]
        Added as a done-callback on the Future.

    timeout : number [Optional]
        When given, the Future is wrapped with ``gen.with_timeout`` using
        this many seconds.

    Returns
    -------
    Future
    """
    if result_extractor is None:
        result_extractor = self._result_extractor
    code = PixieGatewayApp.instance().prepend_execute_code + "\n" + code
    app_log.debug("Executing Code: %s", code)
    future = Future()
    # Value returned by execute(); replies are matched against it below.
    parent_header = self.kernel_manager.execute(self.kernel_handle, code)
    # Every reply belonging to this execution, in arrival order.
    result_accumulator = []

    def on_reply(msg):
        # Only handle replies whose parent msg_id matches this execution.
        if 'msg_id' in msg['parent_header'] and msg['parent_header'][
                'msg_id'] == parent_header:
            # Ignore anything that arrives after the future was completed.
            if not future.done():
                if "channel" not in msg:
                    msg["channel"] = "iopub"
                result_accumulator.append(msg)
                # Complete the future on idle status
                if msg['header']['msg_type'] == 'status' and msg[
                        'content']['execution_state'] == 'idle':
                    future.set_result(result_extractor(result_accumulator))
                elif msg['header']['msg_type'] == 'error':
                    error_name = msg['content']['ename']
                    error_value = msg['content']['evalue']
                    trace = sanitize_traceback(msg['content']['traceback'])
                    future.set_exception(
                        CodeExecutionError(error_name, error_value, trace, code))
        else:
            app_log.warning("Got an orphan message %s", msg['parent_header'])

    # Install the handler; this replaces any previously installed handler.
    self.current_iopub_handler = on_reply

    if done_callback is not None:
        future.add_done_callback(done_callback)

    # attach the future to the kernel to be notified if it dies
    self.kernel_manager.register_execute_future(self.kernel_handle, future)

    if timeout is None:
        return future
    else:
        return gen.with_timeout(timedelta(seconds=timeout), future)
def get(self):
    """Park the request on a new module-global Future.

    The Future is published as the global ``future`` with ``self.done``
    attached as its done-callback, and is then yielded.
    """
    global future
    pending = future = Future()
    pending.add_done_callback(self.done)
    # Blocks here: the current request is suspended so the thread can
    # handle other requests.
    yield pending
def execute_code(self, code, result_extractor=None, done_callback=None):
    """
    Asynchronously execute the given code using the underlying managed kernel client
    Note: this method is not synchronized, it is the responsibility of the caller to synchronize using the lock member variable
    e.g.
        with (yield managed_client.lock.acquire()):
            yield managed_client.execute_code( code )

    The application's ``prepend_execute_code`` is prepended to ``code``
    before it is sent for execution.

    Parameters
    ----------
    code : String
        Python code to be executed

    result_extractor : function [Optional]
        Called when the code has finished executing to extract the results into the returned Future.
        Defaults to ``self._result_extractor``.

    done_callback : function [Optional]
        Added as a done-callback on the Future.

    Returns
    -------
    Future
    """
    if result_extractor is None:
        result_extractor = self._result_extractor
    code = PixieGatewayApp.instance().prepend_execute_code + "\n" + code
    app_log.debug("Executing Code: %s", code)
    future = Future()
    # Value returned by execute(); replies are matched against it below.
    parent_header = self.kernel_client.execute(code)
    # Every reply belonging to this execution, in arrival order.
    result_accumulator = []

    def on_reply(msgList):
        # Build a session with the client's config and key, and use it to
        # split off the identities and deserialize the raw message parts.
        session = type(self.kernel_client.session)(
            config=self.kernel_client.session.config,
            key=self.kernel_client.session.key,
        )
        _, msgList = session.feed_identities(msgList)
        msg = session.deserialize(msgList)
        # Only handle replies whose parent msg_id matches this execution.
        if 'msg_id' in msg['parent_header'] and msg['parent_header'][
                'msg_id'] == parent_header:
            # Ignore anything that arrives after the future was completed.
            if not future.done():
                if "channel" not in msg:
                    msg["channel"] = "iopub"
                result_accumulator.append(msg)
                # Complete the future on idle status
                if msg['header']['msg_type'] == 'status' and msg[
                        'content']['execution_state'] == 'idle':
                    future.set_result(result_extractor(result_accumulator))
                elif msg['header']['msg_type'] == 'error':
                    error_name = msg['content']['ename']
                    error_value = msg['content']['evalue']
                    traceback = "\n".join(msg['content']['traceback'])
                    future.set_exception(
                        Exception(
                            'Code execution Error {}: {} \nTraceback: {}\nRunning code: {}'
                            .format(error_name, error_value, traceback, code)))
        else:
            app_log.warning("Got an orphan message %s", msg)

    # Register the handler for messages received on the iopub stream.
    self.iopub.on_recv(on_reply)

    if done_callback is not None:
        future.add_done_callback(done_callback)
    return future
def put_task(self, datapoint, callback):
    """Queue ``datapoint`` and return a Future for its result.

    ``callback`` is always registered as a done-callback on the Future
    before the ``(datapoint, future)`` pair is put on ``self.input_queue``.
    """
    result_future = Future()
    result_future.add_done_callback(callback)
    work_item = (datapoint, result_future)
    self.input_queue.put(work_item)
    return result_future
def get(self, *args, **kwargs):
    """Park the request on a new module-global Future.

    The Future is published as the global ``future`` with ``self.done``
    attached as its done-callback, and is then yielded. Positional and
    keyword arguments are accepted but unused.
    """
    global future
    parked = future = Future()
    parked.add_done_callback(self.done)
    yield parked
class KernelGatewayWSClient(LoggingConfigurable):
    """Proxy web socket connection to a kernel/enterprise gateway.

    ``on_open`` starts the connection and the read loop, ``on_message``
    forwards a message to the gateway (deferring it until the connection
    Future completes if the socket is not ready yet), and ``on_close``
    closes the socket or flags a still-pending connection as cancelled.
    """

    def __init__(self, **kwargs):
        super(KernelGatewayWSClient, self).__init__(**kwargs)
        self.kernel_id = None
        # Connected websocket; stays None until _connection_done sets it.
        self.ws = None
        # Placeholder; replaced by the websocket_connect future in _connect.
        self.ws_future = Future()
        # Set by _disconnect when a pending connection is abandoned.
        self.ws_future_cancelled = False

    @gen.coroutine
    def _connect(self, kernel_id):
        """Start the websocket connection to the kernel's ``channels`` endpoint.

        The base URL comes from the ``KG_WS_URL`` environment variable,
        falling back to ``KG_URL`` with ``http`` replaced by ``ws``.
        """
        self.kernel_id = kernel_id
        ws_url = url_path_join(
            os.getenv('KG_WS_URL', KG_URL.replace('http', 'ws')),
            '/api/kernels',
            url_escape(kernel_id),
            'channels')
        self.log.info('Connecting to {}'.format(ws_url))
        parameters = {
            "headers": KG_HEADERS,
            "validate_cert": VALIDATE_KG_CERT,
            "connect_timeout": KG_CONNECT_TIMEOUT,
            "request_timeout": KG_REQUEST_TIMEOUT
        }
        # Optional credentials / TLS material, added only when configured.
        if KG_HTTP_USER:
            parameters["auth_username"] = KG_HTTP_USER
        if KG_HTTP_PASS:
            parameters["auth_password"] = KG_HTTP_PASS
        if KG_CLIENT_KEY:
            parameters["client_key"] = KG_CLIENT_KEY
            parameters["client_cert"] = KG_CLIENT_CERT
        if KG_CLIENT_CA:
            parameters["ca_certs"] = KG_CLIENT_CA

        request = HTTPRequest(ws_url, **parameters)
        self.ws_future = websocket_connect(request)
        self.ws_future.add_done_callback(self._connection_done)

    def _connection_done(self, fut):
        """Store the connected websocket, unless the connection was cancelled."""
        if not self.ws_future_cancelled:  # prevent concurrent.futures._base.CancelledError
            # fut.result() re-raises here if the connection attempt failed.
            self.ws = fut.result()
            self.log.debug("Connection is ready: ws: {}".format(self.ws))
        else:
            self.log.warning(
                "Websocket connection has been cancelled via client disconnect before its establishment.  "
                "Kernel with ID '{}' may not be terminated on Gateway: {}".
                format(self.kernel_id, KG_URL))

    def _disconnect(self):
        """Close the websocket, or flag a still-pending connection as cancelled."""
        if self.ws is not None:
            # Close connection
            self.ws.close()
        elif not self.ws_future.done():
            # Cancel pending connection.  Since future.cancel() is a noop on tornado, we'll track cancellation locally
            self.ws_future.cancel()
            self.ws_future_cancelled = True
            self.log.debug("_disconnect: ws_future_cancelled: {}".format(
                self.ws_future_cancelled))

    @gen.coroutine
    def _read_messages(self, callback):
        """Read messages from gateway server.

        Each message is passed to ``callback``. The loop ends when a read
        yields ``None`` (including after a logged read error) or when the
        connection has been flagged as cancelled.
        """
        while True:
            message = None
            if not self.ws_future_cancelled:
                try:
                    message = yield self.ws.read_message()
                except Exception as e:
                    self.log.error(
                        "Exception reading message from websocket: {}".format(
                            e))  # , exc_info=True)
                if message is None:
                    break
                callback(
                    message
                )  # pass back to notebook client (see self.on_open and WebSocketChannelsHandler.open)
            else:  # ws cancelled - stop reading
                break

    def on_open(self, kernel_id, message_callback, **kwargs):
        """Web socket connection open against gateway server."""
        self._connect(kernel_id)
        loop = IOLoop.current()
        # Start reading once the connection future completes.
        loop.add_future(self.ws_future,
                        lambda future: self._read_messages(message_callback))

    def on_message(self, message):
        """Send message to gateway server."""
        if self.ws is None:
            # Not connected yet: defer the write until the connection
            # future completes.
            loop = IOLoop.current()
            loop.add_future(self.ws_future,
                            lambda future: self._write_message(message))
        else:
            self._write_message(message)

    def _write_message(self, message):
        """Send message to gateway server."""
        try:
            if not self.ws_future_cancelled:
                self.ws.write_message(message)
        except Exception as e:
            self.log.error("Exception writing message to websocket: {}".format(
                e))  # , exc_info=True)

    def on_close(self):
        """Web socket closed event."""
        self._disconnect()
class WebSocketTest(AsyncHTTPTestCase):
    # Tests for websocket_connect against a small test application.
    # Handlers that receive ``close_future`` share the single Future created
    # in get_app; tests yield it to wait for the server-side close.

    def get_app(self):
        self.close_future = Future()
        return Application([
            ('/echo', EchoHandler, dict(close_future=self.close_future)),
            ('/non_ws', NonWebSocketHandler),
            ('/header', HeaderHandler, dict(close_future=self.close_future)),
            ('/close_reason', CloseReasonHandler,
             dict(close_future=self.close_future)),
        ])

    def test_http_request(self):
        # WS server, HTTP client.
        response = self.fetch('/echo')
        self.assertEqual(response.code, 400)

    @gen_test
    def test_websocket_gen(self):
        # Coroutine-style round trip through the echo handler.
        ws = yield websocket_connect(
            'ws://localhost:%d/echo' % self.get_http_port(),
            io_loop=self.io_loop)
        ws.write_message('hello')
        response = yield ws.read_message()
        self.assertEqual(response, 'hello')
        ws.close()
        yield self.close_future

    def test_websocket_callbacks(self):
        # Callback-style round trip: self.stop / self.wait drive the loop.
        websocket_connect(
            'ws://localhost:%d/echo' % self.get_http_port(),
            io_loop=self.io_loop, callback=self.stop)
        ws = self.wait().result()
        ws.write_message('hello')
        ws.read_message(self.stop)
        response = self.wait().result()
        self.assertEqual(response, 'hello')
        self.close_future.add_done_callback(lambda f: self.stop())
        ws.close()
        self.wait()

    @gen_test
    def test_websocket_http_fail(self):
        # Connecting to a path with no handler surfaces the 404.
        with self.assertRaises(HTTPError) as cm:
            yield websocket_connect(
                'ws://localhost:%d/notfound' % self.get_http_port(),
                io_loop=self.io_loop)
        self.assertEqual(cm.exception.code, 404)

    @gen_test
    def test_websocket_http_success(self):
        # Connecting to the non-websocket handler raises WebSocketError.
        with self.assertRaises(WebSocketError):
            yield websocket_connect(
                'ws://localhost:%d/non_ws' % self.get_http_port(),
                io_loop=self.io_loop)

    @gen_test
    def test_websocket_network_fail(self):
        # Bind a port and close it again so nothing is listening there.
        sock, port = bind_unused_port()
        sock.close()
        with self.assertRaises(IOError):
            with ExpectLog(gen_log, ".*"):
                yield websocket_connect(
                    'ws://localhost:%d/' % port,
                    io_loop=self.io_loop,
                    connect_timeout=3600)

    @gen_test
    def test_websocket_close_buffered_data(self):
        # Close the underlying stream directly after writing two messages.
        ws = yield websocket_connect(
            'ws://localhost:%d/echo' % self.get_http_port())
        ws.write_message('hello')
        ws.write_message('world')
        ws.stream.close()
        yield self.close_future

    @gen_test
    def test_websocket_headers(self):
        # Ensure that arbitrary headers can be passed through websocket_connect.
        ws = yield websocket_connect(
            HTTPRequest('ws://localhost:%d/header' % self.get_http_port(),
                        headers={'X-Test': 'hello'}))
        response = yield ws.read_message()
        self.assertEqual(response, 'hello')
        ws.close()
        yield self.close_future

    @gen_test
    def test_server_close_reason(self):
        ws = yield websocket_connect(
            'ws://localhost:%d/close_reason' % self.get_http_port())
        msg = yield ws.read_message()
        # A message of None means the other side closed the connection.
        self.assertIs(msg, None)
        self.assertEqual(ws.close_code, 1001)
        self.assertEqual(ws.close_reason, "goodbye")

    @gen_test
    def test_client_close_reason(self):
        # The code and reason passed to ws.close() arrive via close_future.
        ws = yield websocket_connect(
            'ws://localhost:%d/echo' % self.get_http_port())
        ws.close(1001, 'goodbye')
        code, reason = yield self.close_future
        self.assertEqual(code, 1001)
        self.assertEqual(reason, 'goodbye')

    @gen_test
    def test_check_origin_valid_no_path(self):
        # Origin is the server's own host and port, without a path.
        port = self.get_http_port()

        url = 'ws://localhost:%d/echo' % port
        headers = {'Origin': 'http://localhost:%d' % port}

        ws = yield websocket_connect(HTTPRequest(url, headers=headers),
                                     io_loop=self.io_loop)
        ws.write_message('hello')
        response = yield ws.read_message()
        self.assertEqual(response, 'hello')
        ws.close()
        yield self.close_future

    @gen_test
    def test_check_origin_valid_with_path(self):
        # Origin is the server's own host and port, with a path appended.
        port = self.get_http_port()

        url = 'ws://localhost:%d/echo' % port
        headers = {'Origin': 'http://localhost:%d/something' % port}

        ws = yield websocket_connect(HTTPRequest(url, headers=headers),
                                     io_loop=self.io_loop)
        ws.write_message('hello')
        response = yield ws.read_message()
        self.assertEqual(response, 'hello')
        ws.close()
        yield self.close_future

    @gen_test
    def test_check_origin_invalid_partial_url(self):
        # Origin without a scheme is expected to be rejected with 403.
        port = self.get_http_port()

        url = 'ws://localhost:%d/echo' % port
        headers = {'Origin': 'localhost:%d' % port}

        with self.assertRaises(HTTPError) as cm:
            yield websocket_connect(HTTPRequest(url, headers=headers),
                                    io_loop=self.io_loop)
        self.assertEqual(cm.exception.code, 403)

    @gen_test
    def test_check_origin_invalid(self):
        port = self.get_http_port()

        url = 'ws://localhost:%d/echo' % port
        # Host is localhost, which should not be accessible from some other
        # domain
        headers = {'Origin': 'http://somewhereelse.com'}

        with self.assertRaises(HTTPError) as cm:
            yield websocket_connect(HTTPRequest(url, headers=headers),
                                    io_loop=self.io_loop)

        self.assertEqual(cm.exception.code, 403)

    @gen_test
    def test_check_origin_invalid_subdomains(self):
        port = self.get_http_port()

        url = 'ws://localhost:%d/echo' % port
        # Subdomains should be disallowed by default.  If we could pass a
        # resolver to websocket_connect we could test sibling domains as well.
        headers = {'Origin': 'http://subtenant.localhost'}

        with self.assertRaises(HTTPError) as cm:
            yield websocket_connect(HTTPRequest(url, headers=headers),
                                    io_loop=self.io_loop)

        self.assertEqual(cm.exception.code, 403)
def get(self):
    """Park the request on a new module-global Future.

    The Future is published as the global ``future`` with ``self.done``
    attached as its done-callback, and is then yielded.
    """
    global future
    waiting = future = Future()
    waiting.add_done_callback(self.done)
    yield waiting
class WebSocketTest(WebSocketBaseTestCase):
    """Websocket client/server tests run against a local test application.

    ``get_app`` creates a single ``close_future`` and passes it to every
    handler that accepts one; several tests below yield on that future.
    """

    def get_app(self):
        self.close_future = Future()

        def closing(pattern, handler_class):
            # Each route gets its own kwargs dict holding the shared future.
            return (pattern, handler_class,
                    dict(close_future=self.close_future))

        routes = [
            closing('/echo', EchoHandler),
            ('/non_ws', NonWebSocketHandler),
            closing('/header', HeaderHandler),
            closing('/header_echo', HeaderEchoHandler),
            closing('/close_reason', CloseReasonHandler),
            closing('/error_in_on_message', ErrorInOnMessageHandler),
            closing('/async_prepare', AsyncPrepareHandler),
            closing('/path_args/(.*)', PathArgsHandler),
            closing('/coroutine', CoroutineOnMessageHandler),
            closing('/render', RenderMessageHandler),
        ]
        templates = DictLoader({
            'message.html': '<b>{{ message }}</b>',
        })
        return Application(routes, template_loader=templates)

    def tearDown(self):
        # Base-class teardown runs first, then the class-level
        # RequestHandler._template_loaders mapping is emptied.
        super(WebSocketTest, self).tearDown()
        RequestHandler._template_loaders.clear()

    def test_http_request(self):
        # WS server, HTTP client.
        plain_response = self.fetch('/echo')
        self.assertEqual(plain_response.code, 400)

    def test_bad_websocket_version(self):
        upgrade_headers = {
            'Connection': 'Upgrade',
            'Upgrade': 'WebSocket',
            'Sec-WebSocket-Version': '12',
        }
        plain_response = self.fetch('/echo', headers=upgrade_headers)
        self.assertEqual(plain_response.code, 426)

    @gen_test
    def test_websocket_gen(self):
        conn = yield self.ws_connect('/echo')
        yield conn.write_message('hello')
        echoed = yield conn.read_message()
        self.assertEqual(echoed, 'hello')
        yield self.close(conn)

    def test_websocket_callbacks(self):
        # Callback-style API: every step hands control back via self.stop.
        echo_url = 'ws://127.0.0.1:%d/echo' % self.get_http_port()
        websocket_connect(echo_url, callback=self.stop)
        conn = self.wait().result()
        conn.write_message('hello')
        conn.read_message(self.stop)
        echoed = self.wait().result()
        self.assertEqual(echoed, 'hello')
        self.close_future.add_done_callback(lambda f: self.stop())
        conn.close()
        self.wait()

    @gen_test
    def test_binary_message(self):
        conn = yield self.ws_connect('/echo')
        conn.write_message(b'hello \xe9', binary=True)
        echoed = yield conn.read_message()
        self.assertEqual(echoed, b'hello \xe9')
        yield self.close(conn)

    @gen_test
    def test_unicode_message(self):
        conn = yield self.ws_connect('/echo')
        conn.write_message(u'hello \u00e9')
        echoed = yield conn.read_message()
        self.assertEqual(echoed, u'hello \u00e9')
        yield self.close(conn)

    @gen_test
    def test_render_message(self):
        conn = yield self.ws_connect('/render')
        conn.write_message('hello')
        rendered = yield conn.read_message()
        self.assertEqual(rendered, '<b>hello</b>')
        yield self.close(conn)

    @gen_test
    def test_error_in_on_message(self):
        conn = yield self.ws_connect('/error_in_on_message')
        conn.write_message('hello')
        with ExpectLog(app_log, "Uncaught exception"):
            received = yield conn.read_message()
        self.assertIs(received, None)
        yield self.close(conn)

    @gen_test
    def test_websocket_http_fail(self):
        with self.assertRaises(HTTPError) as caught:
            yield self.ws_connect('/notfound')
        self.assertEqual(caught.exception.code, 404)

    @gen_test
    def test_websocket_http_success(self):
        with self.assertRaises(WebSocketError):
            yield self.ws_connect('/non_ws')

    @gen_test
    def test_websocket_network_fail(self):
        # Bind a port and close it again so nothing is listening there.
        listener, unused_port = bind_unused_port()
        listener.close()
        dead_url = 'ws://127.0.0.1:%d/' % unused_port
        with self.assertRaises(IOError), ExpectLog(gen_log, ".*"):
            yield websocket_connect(dead_url, connect_timeout=3600)

    @gen_test
    def test_websocket_close_buffered_data(self):
        echo_url = 'ws://127.0.0.1:%d/echo' % self.get_http_port()
        conn = yield websocket_connect(echo_url)
        conn.write_message('hello')
        conn.write_message('world')
        # Close the underlying stream.
        conn.stream.close()
        yield self.close_future

    @gen_test
    def test_websocket_headers(self):
        # Ensure that arbitrary headers can be passed through websocket_connect.
        header_url = 'ws://127.0.0.1:%d/header' % self.get_http_port()
        handshake = HTTPRequest(header_url, headers={'X-Test': 'hello'})
        conn = yield websocket_connect(handshake)
        received = yield conn.read_message()
        self.assertEqual(received, 'hello')
        yield self.close(conn)

    @gen_test
    def test_websocket_header_echo(self):
        # Ensure that headers can be returned in the response.
        # Specifically, that arbitrary headers passed through websocket_connect
        # can be returned.
        echo_url = 'ws://127.0.0.1:%d/header_echo' % self.get_http_port()
        handshake = HTTPRequest(echo_url, headers={'X-Test-Hello': 'hello'})
        conn = yield websocket_connect(handshake)
        self.assertEqual(conn.headers.get('X-Test-Hello'), 'hello')
        self.assertEqual(conn.headers.get('X-Extra-Response-Header'),
                         'Extra-Response-Value')
        yield self.close(conn)

    @gen_test
    def test_server_close_reason(self):
        conn = yield self.ws_connect('/close_reason')
        received = yield conn.read_message()
        # A message of None means the other side closed the connection.
        self.assertIs(received, None)
        self.assertEqual(conn.close_code, 1001)
        self.assertEqual(conn.close_reason, "goodbye")
        # The on_close callback is called no matter which side closed.
        server_code, server_reason = yield self.close_future
        # The client echoed the close code it received to the server,
        # so the server's close code (returned via close_future) is
        # the same.
        self.assertEqual(server_code, 1001)

    @gen_test
    def test_client_close_reason(self):
        conn = yield self.ws_connect('/echo')
        conn.close(1001, 'goodbye')
        server_code, server_reason = yield self.close_future
        self.assertEqual(server_code, 1001)
        self.assertEqual(server_reason, 'goodbye')

    @gen_test
    def test_write_after_close(self):
        conn = yield self.ws_connect('/close_reason')
        received = yield conn.read_message()
        self.assertIs(received, None)
        with self.assertRaises(StreamClosedError):
            conn.write_message('hello')

    @gen_test
    def test_async_prepare(self):
        # Previously, an async prepare method triggered a bug that would
        # result in a timeout on test shutdown (and a memory leak).
        conn = yield self.ws_connect('/async_prepare')
        conn.write_message('hello')
        received = yield conn.read_message()
        self.assertEqual(received, 'hello')

    @gen_test
    def test_path_args(self):
        conn = yield self.ws_connect('/path_args/hello')
        received = yield conn.read_message()
        self.assertEqual(received, 'hello')

    @gen_test
    def test_coroutine(self):
        conn = yield self.ws_connect('/coroutine')
        # Send both messages immediately, coroutine must process one at a time.
        yield conn.write_message('hello1')
        yield conn.write_message('hello2')
        first = yield conn.read_message()
        self.assertEqual(first, 'hello1')
        second = yield conn.read_message()
        self.assertEqual(second, 'hello2')

    @gen_test
    def test_check_origin_valid_no_path(self):
        http_port = self.get_http_port()
        target = 'ws://127.0.0.1:%d/echo' % http_port
        origin_headers = {'Origin': 'http://127.0.0.1:%d' % http_port}
        handshake = HTTPRequest(target, headers=origin_headers)
        conn = yield websocket_connect(handshake)
        conn.write_message('hello')
        echoed = yield conn.read_message()
        self.assertEqual(echoed, 'hello')
        yield self.close(conn)

    @gen_test
    def test_check_origin_valid_with_path(self):
        http_port = self.get_http_port()
        target = 'ws://127.0.0.1:%d/echo' % http_port
        origin_headers = {
            'Origin': 'http://127.0.0.1:%d/something' % http_port,
        }
        handshake = HTTPRequest(target, headers=origin_headers)
        conn = yield websocket_connect(handshake)
        conn.write_message('hello')
        echoed = yield conn.read_message()
        self.assertEqual(echoed, 'hello')
        yield self.close(conn)

    @gen_test
    def test_check_origin_invalid_partial_url(self):
        http_port = self.get_http_port()
        target = 'ws://127.0.0.1:%d/echo' % http_port
        origin_headers = {'Origin': '127.0.0.1:%d' % http_port}
        with self.assertRaises(HTTPError) as caught:
            yield websocket_connect(
                HTTPRequest(target, headers=origin_headers))
        self.assertEqual(caught.exception.code, 403)

    @gen_test
    def test_check_origin_invalid(self):
        http_port = self.get_http_port()
        target = 'ws://127.0.0.1:%d/echo' % http_port
        # Host is 127.0.0.1, which should not be accessible from some other
        # domain
        origin_headers = {'Origin': 'http://somewhereelse.com'}
        with self.assertRaises(HTTPError) as caught:
            yield websocket_connect(
                HTTPRequest(target, headers=origin_headers))
        self.assertEqual(caught.exception.code, 403)

    @gen_test
    def test_check_origin_invalid_subdomains(self):
        http_port = self.get_http_port()
        target = 'ws://localhost:%d/echo' % http_port
        # Subdomains should be disallowed by default.  If we could pass a
        # resolver to websocket_connect we could test sibling domains as well.
        origin_headers = {'Origin': 'http://subtenant.localhost'}
        with self.assertRaises(HTTPError) as caught:
            yield websocket_connect(
                HTTPRequest(target, headers=origin_headers))
        self.assertEqual(caught.exception.code, 403)
class GatewayWebSocketClient(LoggingConfigurable):
    """Proxy web socket connection to a kernel/enterprise gateway.

    ``on_open`` starts a connection attempt and schedules the read loop,
    ``on_message`` forwards a message to the gateway (deferring it until the
    pending connection attempt completes when no socket is available yet),
    and ``on_close`` disconnects.
    """

    def __init__(self, **kwargs):
        super(GatewayWebSocketClient, self).__init__(**kwargs)
        self.kernel_id = None
        # Connected websocket; assigned by _connection_done on success.
        self.ws = None
        # Placeholder until the first connection attempt replaces it.
        self.ws_future = Future()
        # Set to True by _disconnect(); checked before using the socket.
        self.disconnected = False

    def _start_connection(self, kernel_id):
        """Start a websocket connection attempt for ``kernel_id``.

        This is intentionally a plain method rather than a coroutine: it
        must have run by the time the caller schedules work on
        ``self.ws_future``, so that the future in question is the one
        returned by ``websocket_connect`` and not a stale one.
        """
        # websocket is initialized before connection
        self.ws = None
        self.kernel_id = kernel_id
        ws_url = url_path_join(
            GatewayClient.instance().ws_url,
            GatewayClient.instance().kernels_endpoint,
            url_escape(kernel_id),
            'channels'
        )
        self.log.info('Connecting to {}'.format(ws_url))
        kwargs = {}
        kwargs = GatewayClient.instance().load_connection_args(**kwargs)

        request = HTTPRequest(ws_url, **kwargs)
        self.ws_future = websocket_connect(request)
        self.ws_future.add_done_callback(self._connection_done)

    async def _connect(self, kernel_id):
        """Awaitable form of `_start_connection`.

        Kept so that any caller awaiting ``_connect`` keeps working.  Code
        that cannot ``await`` must call `_start_connection` directly:
        invoking this coroutine function without awaiting it only creates a
        coroutine object and never executes the body.
        """
        self._start_connection(kernel_id)

    def _connection_done(self, fut):
        """Done-callback for ``ws_future``: record the socket or log failure."""
        # ``disconnected`` is tested first so that ``fut.exception()`` is not
        # called after a client disconnect.
        if not self.disconnected and fut.exception() is None:  # prevent concurrent.futures._base.CancelledError
            self.ws = fut.result()
            self.log.debug("Connection is ready: ws: {}".format(self.ws))
        else:
            self.log.warning("Websocket connection has been closed via client disconnect or due to error.  "
                             "Kernel with ID '{}' may not be terminated on GatewayClient: {}".
                             format(self.kernel_id, GatewayClient.instance().url))

    def _disconnect(self):
        """Mark the client disconnected and close or cancel the connection."""
        self.disconnected = True
        if self.ws is not None:
            # Close connection
            self.ws.close()
        elif not self.ws_future.done():
            # Cancel pending connection.  Since future.cancel() is a noop on tornado, we'll track cancellation locally
            self.ws_future.cancel()
            self.log.debug("_disconnect: future cancelled, disconnected: {}".format(self.disconnected))

    async def _read_messages(self, callback):
        """Read messages from gateway server.

        Each message is passed to ``callback``.  The loop ends when the
        client has disconnected or when a read yields ``None`` (including
        the case where the read raised and the error was logged).  Unless
        the client disconnected, a new connection attempt is then started
        and this reader is rescheduled on it.
        """
        while self.ws is not None:
            if self.disconnected:
                # ws cancelled - stop reading
                break

            message = None
            try:
                message = await self.ws.read_message()
            except Exception as e:
                self.log.error("Exception reading message from websocket: {}".format(e))

            if message is None:
                if not self.disconnected:
                    self.log.warning("Lost connection to Gateway: {}".format(self.kernel_id))
                break
            callback(message)  # pass back to notebook client (see self.on_open and WebSocketChannelsHandler.open)

        if not self.disconnected:
            # if websocket is not disconnected by client, attempt to reconnect to Gateway
            self.log.info("Attempting to re-establish the connection to Gateway: {}".format(self.kernel_id))
            # Synchronous on purpose: ws_future must be the new attempt
            # before the reader is attached to it below.
            # NOTE(review): there is no retry limit or back-off here.
            self._start_connection(self.kernel_id)
            loop = IOLoop.current()
            loop.add_future(self.ws_future, lambda future: self._read_messages(callback))

    def on_open(self, kernel_id, message_callback, **kwargs):
        """Web socket connection open against gateway server."""
        # Must not call the ``_connect`` coroutine here without awaiting it:
        # its body would never run and ws_future would stay the unresolved
        # placeholder created in __init__.
        self._start_connection(kernel_id)
        loop = IOLoop.current()
        loop.add_future(
            self.ws_future,
            lambda future: self._read_messages(message_callback)
        )

    def on_message(self, message):
        """Send message to gateway server."""
        if self.ws is None:
            # No socket yet: send once the pending connection attempt is done.
            loop = IOLoop.current()
            loop.add_future(
                self.ws_future,
                lambda future: self._write_message(message)
            )
        else:
            self._write_message(message)

    def _write_message(self, message):
        """Send message to gateway server."""
        try:
            if not self.disconnected and self.ws is not None:
                self.ws.write_message(message)
        except Exception as e:
            self.log.error("Exception writing message to websocket: {}".format(e))

    def on_close(self):
        """Web socket closed event."""
        self._disconnect()
class WebSocketTest(WebSocketBaseTestCase):
    """Websocket client/server tests run against a local test application.

    ``get_app`` creates a single ``close_future`` and passes it to every
    handler that accepts one; some tests below yield on that future.
    """

    def get_app(self):
        self.close_future = Future()  # type: Future[None]

        def closing(pattern, handler_class, **extra):
            # Each route gets its own kwargs dict holding the shared future
            # plus any route-specific extras.
            return (
                pattern,
                handler_class,
                dict(close_future=self.close_future, **extra),
            )

        routes = [
            closing("/echo", EchoHandler),
            ("/redirect", RedirectHandler, dict(url="/echo")),
            ("/double_redirect", RedirectHandler, dict(url="/redirect")),
            ("/non_ws", NonWebSocketHandler),
            closing("/header", HeaderHandler),
            closing("/header_echo", HeaderEchoHandler),
            closing("/close_reason", CloseReasonHandler),
            closing("/error_in_on_message", ErrorInOnMessageHandler),
            closing("/async_prepare", AsyncPrepareHandler),
            closing("/path_args/(.*)", PathArgsHandler),
            closing("/coroutine", CoroutineOnMessageHandler),
            closing("/render", RenderMessageHandler),
            closing("/subprotocol", SubprotocolHandler),
            closing("/open_coroutine", OpenCoroutineHandler, test=self),
        ]
        templates = DictLoader({"message.html": "<b>{{ message }}</b>"})
        return Application(routes, template_loader=templates)

    def get_http_client(self):
        # These tests require HTTP/1; force the use of SimpleAsyncHTTPClient.
        return SimpleAsyncHTTPClient()

    def tearDown(self):
        # Base-class teardown runs first, then the class-level
        # RequestHandler._template_loaders mapping is emptied.
        super(WebSocketTest, self).tearDown()
        RequestHandler._template_loaders.clear()

    def test_http_request(self):
        # WS server, HTTP client.
        plain_response = self.fetch("/echo")
        self.assertEqual(plain_response.code, 400)

    def test_missing_websocket_key(self):
        upgrade_headers = {
            "Connection": "Upgrade",
            "Upgrade": "WebSocket",
            "Sec-WebSocket-Version": "13",
        }
        plain_response = self.fetch("/echo", headers=upgrade_headers)
        self.assertEqual(plain_response.code, 400)

    def test_bad_websocket_version(self):
        upgrade_headers = {
            "Connection": "Upgrade",
            "Upgrade": "WebSocket",
            "Sec-WebSocket-Version": "12",
        }
        plain_response = self.fetch("/echo", headers=upgrade_headers)
        self.assertEqual(plain_response.code, 426)

    @gen_test
    def test_websocket_gen(self):
        conn = yield self.ws_connect("/echo")
        yield conn.write_message("hello")
        echoed = yield conn.read_message()
        self.assertEqual(echoed, "hello")

    def test_websocket_callbacks(self):
        # Callback-style API: every step hands control back via self.stop.
        echo_url = "ws://127.0.0.1:%d/echo" % self.get_http_port()
        websocket_connect(echo_url, callback=self.stop)
        conn = self.wait().result()
        conn.write_message("hello")
        conn.read_message(self.stop)
        echoed = self.wait().result()
        self.assertEqual(echoed, "hello")
        self.close_future.add_done_callback(lambda f: self.stop())
        conn.close()
        self.wait()

    @gen_test
    def test_binary_message(self):
        conn = yield self.ws_connect("/echo")
        conn.write_message(b"hello \xe9", binary=True)
        echoed = yield conn.read_message()
        self.assertEqual(echoed, b"hello \xe9")

    @gen_test
    def test_unicode_message(self):
        conn = yield self.ws_connect("/echo")
        conn.write_message(u"hello \u00e9")
        echoed = yield conn.read_message()
        self.assertEqual(echoed, u"hello \u00e9")

    @gen_test
    def test_render_message(self):
        conn = yield self.ws_connect("/render")
        conn.write_message("hello")
        rendered = yield conn.read_message()
        self.assertEqual(rendered, "<b>hello</b>")

    @gen_test
    def test_error_in_on_message(self):
        conn = yield self.ws_connect("/error_in_on_message")
        conn.write_message("hello")
        with ExpectLog(app_log, "Uncaught exception"):
            received = yield conn.read_message()
        self.assertIs(received, None)

    @gen_test
    def test_websocket_http_fail(self):
        with self.assertRaises(HTTPError) as caught:
            yield self.ws_connect("/notfound")
        self.assertEqual(caught.exception.code, 404)

    @gen_test
    def test_websocket_http_success(self):
        with self.assertRaises(WebSocketError):
            yield self.ws_connect("/non_ws")

    @gen_test
    def test_websocket_redirect(self):
        redirected = yield self.ws_connect("/redirect")
        yield redirected.write_message("hello redirect")
        echoed = yield redirected.read_message()
        self.assertEqual("hello redirect", echoed)

    @gen_test
    def test_websocket_double_redirect(self):
        redirected = yield self.ws_connect("/double_redirect")
        yield redirected.write_message("hello redirect")
        echoed = yield redirected.read_message()
        self.assertEqual("hello redirect", echoed)

    @gen_test
    def test_websocket_network_fail(self):
        # Bind a port and close it again so nothing is listening there.
        listener, unused_port = bind_unused_port()
        listener.close()
        dead_url = "ws://127.0.0.1:%d/" % unused_port
        with self.assertRaises(IOError), ExpectLog(gen_log, ".*"):
            yield websocket_connect(dead_url, connect_timeout=3600)

    @gen_test
    def test_websocket_close_buffered_data(self):
        echo_url = "ws://127.0.0.1:%d/echo" % self.get_http_port()
        conn = yield websocket_connect(echo_url)
        conn.write_message("hello")
        conn.write_message("world")
        # Close the underlying stream.
        conn.stream.close()

    @gen_test
    def test_websocket_headers(self):
        # Ensure that arbitrary headers can be passed through websocket_connect.
        header_url = "ws://127.0.0.1:%d/header" % self.get_http_port()
        handshake = HTTPRequest(header_url, headers={"X-Test": "hello"})
        conn = yield websocket_connect(handshake)
        received = yield conn.read_message()
        self.assertEqual(received, "hello")

    @gen_test
    def test_websocket_header_echo(self):
        # Ensure that headers can be returned in the response.
        # Specifically, that arbitrary headers passed through websocket_connect
        # can be returned.
        echo_url = "ws://127.0.0.1:%d/header_echo" % self.get_http_port()
        handshake = HTTPRequest(echo_url, headers={"X-Test-Hello": "hello"})
        conn = yield websocket_connect(handshake)
        self.assertEqual(conn.headers.get("X-Test-Hello"), "hello")
        self.assertEqual(
            conn.headers.get("X-Extra-Response-Header"),
            "Extra-Response-Value",
        )

    @gen_test
    def test_server_close_reason(self):
        conn = yield self.ws_connect("/close_reason")
        received = yield conn.read_message()
        # A message of None means the other side closed the connection.
        self.assertIs(received, None)
        self.assertEqual(conn.close_code, 1001)
        self.assertEqual(conn.close_reason, "goodbye")
        # The on_close callback is called no matter which side closed.
        server_code, server_reason = yield self.close_future
        # The client echoed the close code it received to the server,
        # so the server's close code (returned via close_future) is
        # the same.
        self.assertEqual(server_code, 1001)

    @gen_test
    def test_client_close_reason(self):
        conn = yield self.ws_connect("/echo")
        conn.close(1001, "goodbye")
        server_code, server_reason = yield self.close_future
        self.assertEqual(server_code, 1001)
        self.assertEqual(server_reason, "goodbye")

    @gen_test
    def test_write_after_close(self):
        conn = yield self.ws_connect("/close_reason")
        received = yield conn.read_message()
        self.assertIs(received, None)
        with self.assertRaises(WebSocketClosedError):
            conn.write_message("hello")

    @gen_test
    def test_async_prepare(self):
        # Previously, an async prepare method triggered a bug that would
        # result in a timeout on test shutdown (and a memory leak).
        conn = yield self.ws_connect("/async_prepare")
        conn.write_message("hello")
        received = yield conn.read_message()
        self.assertEqual(received, "hello")

    @gen_test
    def test_path_args(self):
        conn = yield self.ws_connect("/path_args/hello")
        received = yield conn.read_message()
        self.assertEqual(received, "hello")

    @gen_test
    def test_coroutine(self):
        conn = yield self.ws_connect("/coroutine")
        # Send both messages immediately, coroutine must process one at a time.
        yield conn.write_message("hello1")
        yield conn.write_message("hello2")
        first = yield conn.read_message()
        self.assertEqual(first, "hello1")
        second = yield conn.read_message()
        self.assertEqual(second, "hello2")

    @gen_test
    def test_check_origin_valid_no_path(self):
        http_port = self.get_http_port()
        target = "ws://127.0.0.1:%d/echo" % http_port
        origin_headers = {"Origin": "http://127.0.0.1:%d" % http_port}
        handshake = HTTPRequest(target, headers=origin_headers)
        conn = yield websocket_connect(handshake)
        conn.write_message("hello")
        echoed = yield conn.read_message()
        self.assertEqual(echoed, "hello")

    @gen_test
    def test_check_origin_valid_with_path(self):
        http_port = self.get_http_port()
        target = "ws://127.0.0.1:%d/echo" % http_port
        origin_headers = {
            "Origin": "http://127.0.0.1:%d/something" % http_port,
        }
        handshake = HTTPRequest(target, headers=origin_headers)
        conn = yield websocket_connect(handshake)
        conn.write_message("hello")
        echoed = yield conn.read_message()
        self.assertEqual(echoed, "hello")

    @gen_test
    def test_check_origin_invalid_partial_url(self):
        http_port = self.get_http_port()
        target = "ws://127.0.0.1:%d/echo" % http_port
        origin_headers = {"Origin": "127.0.0.1:%d" % http_port}
        with self.assertRaises(HTTPError) as caught:
            yield websocket_connect(
                HTTPRequest(target, headers=origin_headers))
        self.assertEqual(caught.exception.code, 403)

    @gen_test
    def test_check_origin_invalid(self):
        http_port = self.get_http_port()
        target = "ws://127.0.0.1:%d/echo" % http_port
        # Host is 127.0.0.1, which should not be accessible from some other
        # domain
        origin_headers = {"Origin": "http://somewhereelse.com"}
        with self.assertRaises(HTTPError) as caught:
            yield websocket_connect(
                HTTPRequest(target, headers=origin_headers))
        self.assertEqual(caught.exception.code, 403)

    @gen_test
    def test_check_origin_invalid_subdomains(self):
        http_port = self.get_http_port()
        target = "ws://localhost:%d/echo" % http_port
        # Subdomains should be disallowed by default.  If we could pass a
        # resolver to websocket_connect we could test sibling domains as well.
        origin_headers = {"Origin": "http://subtenant.localhost"}
        with self.assertRaises(HTTPError) as caught:
            yield websocket_connect(
                HTTPRequest(target, headers=origin_headers))
        self.assertEqual(caught.exception.code, 403)

    @gen_test
    def test_subprotocols(self):
        offered = ["badproto", "goodproto"]
        conn = yield self.ws_connect("/subprotocol", subprotocols=offered)
        self.assertEqual(conn.selected_subprotocol, "goodproto")
        received = yield conn.read_message()
        self.assertEqual(received, "subprotocol=goodproto")

    @gen_test
    def test_subprotocols_not_offered(self):
        conn = yield self.ws_connect("/subprotocol")
        self.assertIs(conn.selected_subprotocol, None)
        received = yield conn.read_message()
        self.assertEqual(received, "subprotocol=None")

    @gen_test
    def test_open_coroutine(self):
        # The event is created before connecting and set only after the
        # first message has been written.
        self.message_sent = Event()
        conn = yield self.ws_connect("/open_coroutine")
        yield conn.write_message("hello")
        self.message_sent.set()
        received = yield conn.read_message()
        self.assertEqual(received, "ok")
def fetch(self, request, callback=None, raise_error=True, **kwargs):
    """Executes a request, asynchronously returning an `HTTPResponse`.

    The request may be either a string URL or an `HTTPRequest` object.
    If it is a string, we construct an `HTTPRequest` using any additional
    kwargs: ``HTTPRequest(request, **kwargs)``

    This method returns a `.Future` whose result is an
    `HTTPResponse`. By default, the ``Future`` will raise an
    `HTTPError` if the request returned a non-200 response code
    (other errors may also be raised if the server could not be
    contacted). Instead, if ``raise_error`` is set to False, the
    response will always be returned regardless of the response
    code.

    If a ``callback`` is given, it will be invoked with the `HTTPResponse`.
    In the callback interface, `HTTPError` is not automatically raised.
    Instead, you must check the response's ``error`` attribute or
    call its `~HTTPResponse.rethrow` method.

    Raises `RuntimeError` if the client has been closed, and `ValueError`
    if ``kwargs`` are combined with an `HTTPRequest` object.
    """
    if self._closed:
        raise RuntimeError("fetch() called on closed AsyncHTTPClient")
    if not isinstance(request, HTTPRequest):
        request = HTTPRequest(url=request, **kwargs)
    elif kwargs:
        raise ValueError(
            "kwargs can't be used if request is an HTTPRequest object")
    # We may modify this (to add Host, Accept-Encoding, etc),
    # so make sure we don't modify the caller's object.  This is also
    # where normal dicts get converted to HTTPHeaders objects.
    request.headers = httputil.HTTPHeaders(request.headers)
    request = _RequestProxy(request, self.defaults)
    future = Future()
    if callback is not None:
        callback = stack_context.wrap(callback)

        def handle_future(fut):
            # Turn the future's outcome back into a response object so the
            # callback always receives an HTTPResponse, never an exception.
            exc = fut.exception()
            if isinstance(exc, HTTPError) and exc.response is not None:
                response = exc.response
            elif exc is not None:
                response = HTTPResponse(
                    request, 599, error=exc,
                    request_time=time.time() - request.start_time)
            else:
                response = fut.result()
            self.io_loop.add_callback(callback, response)
        future.add_done_callback(handle_future)

    def handle_response(response):
        if raise_error and response.error:
            # Mirror the guard used on the result path: setting an
            # exception on a future the caller already cancelled would
            # raise from inside this response callback.
            if not future.cancelled():
                future.set_exception(response.error)
        else:
            future_set_result_unless_cancelled(future, response)
    self.fetch_impl(request, handle_response)
    return future