示例#1
0
    def _get_conn(self):  # -> Future[connection]
        now = self.io_loop.time()

        # Try to reuse in free pool
        while self._free_conn:
            conn = self._free_conn.popleft()
            if now - conn.connected_time > self.max_recycle_sec:
                self._close_async(conn)
                continue
            log.debug("Reusing connection from pool: %s", self.stat())
            fut = Future()
            fut.set_result(conn)
            return fut

        # Open new connection
        if self.max_open == 0 or self._opened_conns < self.max_open:
            self._opened_conns += 1
            log.debug("Creating new connection: %s", self.stat())
            fut = connect(**self.connect_kwargs)
            fut.add_done_callback(self._on_connect)  # self._opened_conns -=1 on exception
            return fut

        # Wait to other connection is released.
        fut = Future()
        self._waitings.append(fut)
        return fut
示例#2
0
 def put_task(self, inputs, callback=None):
     """Queue ``inputs`` for processing and return a Future of the output.

     ``callback``, when given, is attached to the Future as a done-callback.
     The ``(inputs, future)`` pair is put on ``self.input_queue``.
     """
     pending = Future()
     if callback is not None:
         pending.add_done_callback(callback)
     job = (inputs, pending)
     self.input_queue.put(job)
     return pending
示例#3
0
文件: client.py 项目: ei-grad/toredis
    def send_message(self, args, callback=None):
        """Write a command to the stream and return a Future for its reply.

        ``args[0]`` is the command name. Commands containing ``SUBSCRIBE`` are
        rejected. WATCH/MULTI are refused on a shared connection. The local
        transaction state (``_watch`` / ``_multi``) is tracked so the
        connection can be unlocked correctly. The Future's ``set_result`` is
        appended to ``self.callbacks`` to receive the reply.
        """
        command = args[0]

        if 'SUBSCRIBE' in command:
            raise NotImplementedError('Not yet.')

        if command in ('WATCH', 'MULTI'):
            # These change how later commands execute, so they must not be
            # issued on a connection that other callers share.
            if self.is_shared():
                raise Exception('Command %s is not allowed while connection '
                                'is shared!' % command)
            if command == 'WATCH':
                self._watch.add(args[1])
            else:
                self._multi = True
        elif command in ('EXEC', 'DISCARD', 'UNWATCH'):
            # The transaction (or watch) ends here; reset local state.
            if command != 'UNWATCH':
                self._multi = False
            self._watch.clear()

        self.stream.write(self.format_message(args))

        reply = Future()
        if callback is not None:
            reply.add_done_callback(stack_context.wrap(callback))
        self.callbacks.append(reply.set_result)
        return reply
示例#4
0
class ManualCapClient(BaseCapClient):
    """Capitalization client driven by explicit IOStream callbacks.

    ``capitalize`` connects to ``127.0.0.1:self.port``. On connect, the
    request line is written and a single newline-terminated reply is read.
    The stored Future is then resolved with ``process_response(data)``, or
    failed with the ``CapError`` that call raises.
    """

    def capitalize(self, request_data, callback=None):
        """Start a request and return a Future for the processed response.

        If ``callback`` is given, it is called with the Future's result.
        """
        logging.debug("capitalize")
        self.request_data = request_data
        self.stream = IOStream(socket.socket())
        self.stream.connect(('127.0.0.1', self.port),
                            callback=self.handle_connect)
        self.future = Future()
        if callback is not None:
            def _deliver(done):
                callback(done.result())
            self.future.add_done_callback(stack_context.wrap(_deliver))
        return self.future

    def handle_connect(self):
        """Send the request line and wait for one newline-terminated reply."""
        logging.debug("handle_connect")
        line = utf8(self.request_data + "\n")
        self.stream.write(line)
        self.stream.read_until(b'\n', callback=self.handle_read)

    def handle_read(self, data):
        """Close the stream and settle the Future from ``data``."""
        logging.debug("handle_read")
        self.stream.close()
        try:
            outcome = self.process_response(data)
        except CapError as err:
            self.future.set_exception(err)
        else:
            self.future.set_result(outcome)
示例#5
0
文件: locks.py 项目: rgbkrk/tornado
    def wait(self, timeout: Union[float, datetime.timedelta] = None) -> "Future[None]":
        """Return a Future that resolves once the internal flag is true.

        If the flag is already set, the Future comes back already resolved.
        With a ``timeout``, the Future returned is the one from
        `gen.with_timeout`, which raises `tornado.util.TimeoutError` after
        the deadline.
        """
        waiter = Future()  # type: Future[None]
        if self._value:
            waiter.set_result(None)
            return waiter

        self._waiters.add(waiter)
        # Drop the waiter from the set once it settles (result or cancel).
        waiter.add_done_callback(lambda done: self._waiters.remove(done))
        if timeout is None:
            return waiter

        limited = gen.with_timeout(
            timeout, waiter, quiet_exceptions=(CancelledError,)
        )

        def _cancel_waiter(_):
            # gen.with_timeout does not cancel the wrapped Future itself;
            # cancelling it here also removes it from ``_waiters``.
            if not waiter.done():
                waiter.cancel()

        limited.add_done_callback(_cancel_waiter)
        return limited
示例#6
0
文件: locks.py 项目: rgbkrk/tornado
    def acquire(
        self, timeout: Union[float, datetime.timedelta] = None
    ) -> "Future[_ReleasingContextManager]":
        """Decrement the counter and return a Future for the release guard.

        If the counter is positive, the Future is already resolved with a
        `_ReleasingContextManager`. Otherwise the Future is queued in
        ``_waiters`` until a `.release`. With a ``timeout``, it fails with
        `.TimeoutError` if the deadline passes first.
        """
        waiter = Future()  # type: Future[_ReleasingContextManager]
        if self._value > 0:
            self._value -= 1
            waiter.set_result(_ReleasingContextManager(self))
            return waiter

        self._waiters.append(waiter)
        if not timeout:
            return waiter

        def on_timeout() -> None:
            if not waiter.done():
                waiter.set_exception(gen.TimeoutError())
            self._garbage_collect()

        io_loop = ioloop.IOLoop.current()
        timeout_handle = io_loop.add_timeout(timeout, on_timeout)
        # Cancel the pending timeout as soon as the waiter settles.
        waiter.add_done_callback(lambda _: io_loop.remove_timeout(timeout_handle))
        return waiter
    def fetch(self, request, callback=None, raise_error=True, **kwargs):
        """Fetch ``request``, serving cached responses and sharing in-flight ones.

        :param request: an ``HTTPRequest``, or a URL string that is wrapped as
            ``HTTPRequest(url=request, **kwargs)``.
        :param callback: passed through to ``orig_fetch``. On a cache hit it
            is scheduled on ``self.io_loop`` with the cached response.
        :param raise_error: passed through to ``orig_fetch``.
        :returns: a Future for the response. While a request with the same
            cache key is still in flight, its Future is returned instead.
        """
        if not isinstance(request, HTTPRequest):
            request = HTTPRequest(url=request, **kwargs)

        key = self.cache.create_key(request)

        # Share the Future of an identical request that is still in flight.
        # NOTE(review): ``callback`` is not attached to the shared Future.
        pending = self.pending_requests.get(key)
        if pending:
            return pending

        response = self.cache.get_response_and_time(key)
        if response:
            response.cached = True
            if callback:
                self.io_loop.add_callback(callback, response)
            future = Future()
            future.set_result(response)
            return future

        future = orig_fetch(self, request, callback, raise_error, **kwargs)
        self.pending_requests[key] = future

        def cache_response(future):
            # The request is no longer in flight. Stop handing this Future
            # out; otherwise later fetches would get its (possibly failed)
            # outcome forever instead of consulting the cache, and the dict
            # would grow without bound.
            if self.pending_requests.get(key) is future:
                del self.pending_requests[key]
            # exception() raises on a cancelled Future; nothing to cache.
            if future.cancelled():
                return
            if future.exception() is None:
                self.cache.save_response(key, future.result())

        future.add_done_callback(cache_response)
        return future
示例#8
0
    def _start_processing_requests(self):
        """Read CRLF-terminated JSON requests from ``self._stream`` forever.

        Each request must be a JSON object with ``key``, ``method``, ``args``
        and ``kwargs``; malformed requests are logged and skipped. The request
        is dispatched as ``self._handler(method, *args, **kwargs)``. A
        non-Future return value is wrapped in a resolved Future, and a raised
        exception in a failed ``TracebackFuture``. Once the Future completes,
        ``self._on_future_finished(key, future)`` is called.
        """
        while True:
            data = yield gen.Task(self._stream.read_until, '\r\n')
            log.debug('New request: %r', data)
            try:
                msg = json.loads(data)
                key = msg['key']
                method = msg['method']
                args = msg['args']
                kwargs = msg['kwargs']
            except (KeyError, TypeError, ValueError):
                # TypeError covers valid JSON that is not an object (e.g. a
                # list or a number). Without it, one bad request would escape
                # this loop and stop all further request processing.
                log.error('Malformed request data: %s', data)
                continue
            try:
                res = self._handler(method, *args, **kwargs)
                if isinstance(res, Future):
                    future = res
                else:
                    future = Future()
                    future.set_result(res)
            except Exception as e:
                log.exception('Failed to handle request: %s', key)
                future = concurrent.TracebackFuture()
                future.set_exception(e)

            future.add_done_callback(partial(self._on_future_finished, key))
示例#9
0
	def wrapper(*args, **kwargs):
		"""Call ``f`` with a Future substituted via ``replacer`` and return it.

		``replacer.replace`` returns the caller's callback, if any. That
		callback is chained onto the Future through ``_auth_future_to_callback``.
		"""
		result_future = Future()
		callback, args, kwargs = replacer.replace(result_future, args, kwargs)
		if callback is not None:
			adapter = functools.partial(_auth_future_to_callback, callback)
			result_future.add_done_callback(adapter)
		f(*args, **kwargs)
		return result_future
示例#10
0
 def put_task(self, dp, callback=None):
     """
     Same as in :meth:`AsyncPredictorBase.put_task`: enqueue ``dp`` paired
     with a new Future and return that Future. An optional ``callback`` is
     attached as a done-callback.
     """
     outcome = Future()
     if callback is not None:
         outcome.add_done_callback(callback)
     self.input_queue.put((dp, outcome))
     return outcome
示例#11
0
 def put_task(self, dp, callback=None):
     """
     Enqueue a single (non-batched) data point ``dp`` and return a Future
     for its result. ``callback``, if given, runs when the Future completes.
     """
     task_future = Future()
     if callback is not None:
         task_future.add_done_callback(callback)
     entry = (dp, task_future)
     self.input_queue.put(entry)
     return task_future
示例#12
0
文件: rpc.py 项目: yoki123/torpc
 def register(self, name, callback=None):
     """Send an RPC_REGISTER request for ``name`` and return its Future.

     The Future is recorded in the request table under a fresh message id
     before the packed request is written. ``callback``, if truthy, is
     attached as a done-callback.
     """
     request_id = next(self._generator)
     # NOTE(review): ``name`` is passed as-is, not wrapped in a 1-tuple the
     # way positional RPC args usually are; confirm the server's register
     # handler expects a bare value.
     payload = self._pack_request(request_id, RPC_REGISTER, 'register', name)
     reply = Future()
     if callback:
         reply.add_done_callback(callback)
     self.add_request_table(request_id, reply)
     self.write(payload)
     return reply
示例#13
0
    def fetch(self, request, callback=None, raise_error=True, **kwargs):
        """Executes a request, asynchronously returning an `HTTPResponse`.

        The request may be either a string URL or an `HTTPRequest` object.
        If it is a string, we construct an `HTTPRequest` using any additional
        kwargs: ``HTTPRequest(request, **kwargs)``. Passing kwargs together
        with an `HTTPRequest` object raises `ValueError`; calling this on a
        closed client raises `RuntimeError`.

        This method returns a `.Future` whose result is an
        `HTTPResponse`. By default, the ``Future`` will raise an
        `HTTPError` if the request returned a non-200 response code
        (other errors may also be raised if the server could not be
        contacted). Instead, if ``raise_error`` is set to False, the
        response will always be returned regardless of the response
        code.

        If a ``callback`` is given, it will be invoked with the `HTTPResponse`.
        In the callback interface, `HTTPError` is not automatically raised.
        Instead, you must check the response's ``error`` attribute or
        call its `~HTTPResponse.rethrow` method.
        """
        if self._closed:
            raise RuntimeError("fetch() called on closed AsyncHTTPClient")
        if not isinstance(request, HTTPRequest):
            request = HTTPRequest(url=request, **kwargs)
        else:
            if kwargs:
                raise ValueError("kwargs can't be used if request is an HTTPRequest object")
        # We may modify this (to add Host, Accept-Encoding, etc),
        # so make sure we don't modify the caller's object.  This is also
        # where normal dicts get converted to HTTPHeaders objects.
        request.headers = httputil.HTTPHeaders(request.headers)
        request = _RequestProxy(request, self.defaults)
        future = Future()
        if callback is not None:
            callback = stack_context.wrap(callback)

            def handle_future(future):
                exc = future.exception()
                if isinstance(exc, HTTPError) and exc.response is not None:
                    response = exc.response
                elif exc is not None:
                    # No response object is available: synthesize a 599.
                    response = HTTPResponse(
                        request, 599, error=exc,
                        request_time=time.time() - request.start_time)
                else:
                    response = future.result()
                self.io_loop.add_callback(callback, response)
            future.add_done_callback(handle_future)

        def handle_response(response):
            if raise_error and response.error:
                # Mirror the ``unless_cancelled`` success path below: failing
                # a Future the caller already cancelled would raise (e.g.
                # InvalidStateError for asyncio Futures) inside fetch_impl.
                if not future.cancelled():
                    future.set_exception(response.error)
            else:
                future_set_result_unless_cancelled(future, response)
        self.fetch_impl(request, handle_response)
        return future
示例#14
0
文件: cursor.py 项目: mosquito/mytor
 def close(self):
     """Close the wrapped cursor and return a Future that completes afterwards.

     If no cursor is held, the Future is already resolved with ``None``.
     Otherwise it is the Future from running ``self._cursor.close`` through
     ``async_call_method``. In both cases the cursor reference is dropped and
     ``self._release_lock`` is attached as a done-callback.
     """
     if self._cursor is None:
         # Nothing to close. The previous ``self._cursor.close()`` call here
         # raised AttributeError on None.
         future = Future()
         future.set_result(None)
     else:
         future = async_call_method(self._cursor.close)
     self._cursor = None
     future.add_done_callback(self._release_lock)
     return future
示例#15
0
 def put_task(self, inputs, callback=None):
     """
     :params inputs: a data point (list of component) matching input_names (not batched)
     :params callback: a callback to get called with the list of outputs
     :returns: a Future of output, queued on ``self.input_queue`` with ``inputs``."""
     output_future = Future()
     if callback is not None:
         output_future.add_done_callback(callback)
     work_item = (inputs, output_future)
     self.input_queue.put(work_item)
     return output_future
示例#16
0
文件: rpc.py 项目: yoki123/torpc
    def call(self, method_name, *arg, **kwargs):
        """Send an RPC request for ``method_name`` with positional ``arg``.

        Returns a Future registered in the request table under a fresh
        message id. A ``callback`` keyword, if truthy, is attached as a
        done-callback; any other keyword arguments are ignored.
        """
        on_done = kwargs.get('callback')
        request_id = next(self._generator)
        payload = self._pack_request(request_id, RPC_REQUEST, method_name, arg)
        reply = Future()
        self.add_request_table(request_id, reply)

        if on_done:
            reply.add_done_callback(on_done)
        self.write(payload)
        return reply
def feed_puppy(callback=None):
    """Run a slow (0.2 s) job on a worker thread; return a Future of ``True``.

    The current IOLoop is captured in the calling thread. The worker hands
    its result back with ``IOLoop.add_callback``, which may be called from
    any thread, so the Future is resolved on the loop's thread rather than
    on the worker. If ``callback`` is given, it is scheduled on that same
    loop with the finished Future.
    """
    io_loop = tornado.ioloop.IOLoop.current()
    future = Future()

    def do_thing():
        time.sleep(0.2)
        # Futures are not thread-safe: resolve on the loop's own thread.
        io_loop.add_callback(future.set_result, True)

    if callback:
        # Use the captured loop; IOLoop.current() inside the done-callback
        # could resolve to a different loop.
        future.add_done_callback(lambda f: io_loop.add_callback(callback, f))

    t = threading.Thread(target=do_thing)
    t.start()
    return future
示例#18
0
def _set_timeout(
    future: Future, timeout: Union[None, float, datetime.timedelta]
) -> None:
    if timeout:

        def on_timeout() -> None:
            if not future.done():
                future.set_exception(gen.TimeoutError())

        io_loop = ioloop.IOLoop.current()
        timeout_handle = io_loop.add_timeout(timeout, on_timeout)
        future.add_done_callback(lambda _: io_loop.remove_timeout(timeout_handle))
示例#19
0
文件: messages.py 项目: bbqsrc/robust
    def generate_challenge(self):
        """Register a pending authentication and return the challenge message.

        A random hex nonce keys a new Future in
        ``self.session.properties.futures``. When that Future is resolved,
        its result is passed to ``self.authenticate`` and the outcome is sent
        with ``self.session.transport.write_json``. The returned dict holds
        the OAuth URL carrying the nonce.
        """
        token = uuid.uuid4().hex

        def _send_authentication(done):
            write = self.session.transport.write_json
            write(self.authenticate(done.result()))

        pending = Future()
        pending.add_done_callback(_send_authentication)
        self.session.properties.futures[token] = pending

        url = "http://robust.brendan.so/auth/twitter?robust_token=%s" % token
        return {"url": url}
示例#20
0
def queue_request(queue, queue_name, **kwargs):
    """Send ``kwargs`` as a request on ``queue_name`` and wait for the reply.

    ``kwargs['reply_to']`` is set to a per-call queue name
    (``<queue_name>-reply-<id>``), which is listened on. The first message
    received there is JSON-decoded (via ``loads``) and returned.
    ``kwargs['timeout']`` (default 600 seconds) bounds the wait through
    ``with_timeout``. The reply listener is stopped whether the reply arrives
    or the wait times out.
    """
    reply = Future()

    def queue_listener(queue_name, body):
        # Keep only the first reply; resolving the Future twice would raise.
        if not reply.done():
            reply.set_result(loads(body.decode('utf-8')))

    kwargs.setdefault('timeout', 600)
    reply_to = '%s-reply-%s' % (queue_name, id(kwargs))
    kwargs['reply_to'] = reply_to

    yield queue.send(queue_name, dumps(kwargs))
    queue.listen([reply_to], queue_listener, workers_count=1)
    try:
        return (yield with_timeout(timedelta(seconds=kwargs['timeout']), reply))
    finally:
        # A done-callback on ``reply`` never fires when the wait times out,
        # which left the listener running; stop it on every exit path.
        queue.stop([reply_to])
示例#21
0
文件: auth.py 项目: alexdxy/tornado
    def wrapper(*args, **kwargs):
        """Run ``f`` with its callback argument replaced by a Future.

        A caller-supplied callback (returned by ``replacer.replace``) is
        chained onto the Future through ``_auth_future_to_callback``. An
        exception raised under the ``ExceptionStackContext`` fails the Future
        while it is still pending (the handler returns True). Once the Future
        is done, the handler returns False and leaves the exception alone.
        """
        future = Future()
        callback, args, kwargs = replacer.replace(future, args, kwargs)
        if callback is not None:
            future.add_done_callback(
                functools.partial(_auth_future_to_callback, callback))

        def handle_exception(typ, value, tb):
            if not future.done():
                future_set_exc_info(future, (typ, value, tb))
                return True
            return False

        with ExceptionStackContext(handle_exception):
            f(*args, **kwargs)
        return future
示例#22
0
    def fetch(self, request, callback=None, **kwargs):
        """Executes a request, asynchronously returning an `HTTPResponse`.

        ``request`` may be a string URL or an `HTTPRequest` object. A string
        is turned into ``HTTPRequest(request, **kwargs)``.

        Returns a `~concurrent.futures.Future` whose result is an
        `HTTPResponse`. The ``Future`` will raise an `HTTPError` if the
        request returned a non-200 response code.

        If a ``callback`` is given, it will be invoked with the `HTTPResponse`.
        In the callback interface, `HTTPError` is not automatically raised.
        Instead, check the response's ``error`` attribute or call its
        `~HTTPResponse.rethrow` method.
        """
        if not isinstance(request, HTTPRequest):
            request = HTTPRequest(url=request, **kwargs)
        # Copy the headers (plain dicts become HTTPHeaders here) so that
        # adding Host, Accept-Encoding, etc. never touches the caller's
        # object, then layer the client defaults on top via a proxy.
        request.headers = httputil.HTTPHeaders(request.headers)
        request = _RequestProxy(request, self.defaults)
        future = Future()

        def on_response(response):
            error = response.error
            if error:
                future.set_exception(error)
            else:
                future.set_result(response)

        if callback is not None:
            callback = stack_context.wrap(callback)

            def on_done(done):
                exc = done.exception()
                if exc is None:
                    response = done.result()
                elif isinstance(exc, HTTPError) and exc.response is not None:
                    response = exc.response
                else:
                    # No response object exists: synthesize a 599.
                    response = HTTPResponse(
                        request, 599, error=exc,
                        request_time=time.time() - request.start_time)
                self.io_loop.add_callback(callback, response)
            future.add_done_callback(on_done)

        self.fetch_impl(request, on_response)
        return future
示例#23
0
    def wait(self, timeout=None):
        """Wait for `.notify`.

        Returns a `.Future` that resolves to ``True`` if the condition is
        notified, or to ``False`` after a timeout.
        """
        waiter = Future()
        self._waiters.append(waiter)
        if timeout:
            def on_timeout():
                # notify() may already have resolved the waiter before its
                # remove_timeout callback ran; setting a second result would
                # raise.
                if not waiter.done():
                    waiter.set_result(False)
                self._garbage_collect()
            io_loop = ioloop.IOLoop.current()
            timeout_handle = io_loop.add_timeout(timeout, on_timeout)
            waiter.add_done_callback(
                lambda _: io_loop.remove_timeout(timeout_handle))
        return waiter
示例#24
0
    def wait(self, timeout=None):
        """Wait for `.notify`.

        Returns a `.Future` that resolves ``True`` if the condition is notified,
        or ``False`` after a timeout.
        """
        waiter = Future()
        self._waiters.append(waiter)
        if timeout:
            def on_timeout():
                # The waiter can be resolved by notify() before its
                # remove_timeout callback runs; don't set it twice.
                if not waiter.done():
                    waiter.set_result(False)
                self._garbage_collect()
            io_loop = ioloop.IOLoop.current()
            timeout_handle = io_loop.add_timeout(timeout, on_timeout)
            waiter.add_done_callback(
                lambda _: io_loop.remove_timeout(timeout_handle))
        return waiter
示例#25
0
    def wrapper(*args, **kwargs):
        """Run ``f`` with its callback argument replaced by a Future.

        Passing a callback is deprecated and emits a DeprecationWarning. If
        one is given, it is chained onto the Future through
        ``_auth_future_to_callback``, wrapped with ``wrap``. An exception
        raised under the ``ExceptionStackContext`` fails the Future while it
        is still pending. Once the Future is done, the handler returns False.
        """
        future = Future()
        callback, args, kwargs = replacer.replace(future, args, kwargs)
        if callback is not None:
            warnings.warn("callback arguments are deprecated, use the returned Future instead",
                          DeprecationWarning)
            chained = functools.partial(_auth_future_to_callback, callback)
            future.add_done_callback(wrap(chained))

        def handle_exception(typ, value, tb):
            if not future.done():
                future_set_exc_info(future, (typ, value, tb))
                return True
            return False

        with ExceptionStackContext(handle_exception, delay_warning=True):
            f(*args, **kwargs)
        return future
示例#26
0
    def _get_frame(self, timeout=None):
        future = Future()
        if self._frame_queue:
            future.set_result(self._frame_queue.popleft())
        else:
            if timeout is not None:
                def on_timeout():
                    future.set_exception(_TimeoutException())

                handle = self._ioloop.add_timeout(
                    self._ioloop.time() + timeout, on_timeout
                )
                future.add_done_callback(lambda _:
                                         self._ioloop.remove_timeout(handle))

            self._frame_future = future

        return future
示例#27
0
    def get_event(self,
                  request,
                  tag='',
                  callback=None,
                  ):
        '''
        Return a Future that will be resolved later with an event.

        The Future is appended to ``self.tag_map[tag]``, and as a
        ``(tag, future)`` pair to ``self.request_map[request]``. If
        ``callback`` is given, it is scheduled on ``self.io_loop`` with the
        Future's result once that is available.
        '''
        pending = Future()
        if callback is not None:
            def _forward(done):
                event = done.result()
                self.io_loop.add_callback(callback, event)
            pending.add_done_callback(_forward)

        self.tag_map[tag].append(pending)
        self.request_map[request].append((tag, pending))
        return pending
示例#28
0
    def fetch(self, request, callback=None, **kwargs):
        """Executes a request, returning a `.Future` of an `HTTPResponse`.

        The request may be either a string URL or an `HTTPRequest` object.
        A string is turned into ``HTTPRequest(request, **kwargs)``.

        The returned Future fails with the response's ``error`` when it has
        one, and otherwise resolves with the response.

        If a ``callback`` is given, it is called with an `HTTPResponse` even
        on failure. That is the error's own response when the error is an
        `HTTPError` carrying one, otherwise a synthetic 599 response whose
        ``error`` is the exception. Call response.rethrow() to raise the
        exception (if any) in the callback.
        """
        if not isinstance(request, HTTPRequest):
            request = HTTPRequest(url=request, **kwargs)
        # Work on a copy of the headers (plain dicts become HTTPHeaders here)
        # so that adding Host, Accept-Encoding, etc. leaves the caller's
        # object alone, then apply the client defaults through a proxy.
        request.headers = httputil.HTTPHeaders(request.headers)
        request = _RequestProxy(request, self.defaults)
        result = Future()

        if callback is not None:
            callback = stack_context.wrap(callback)

            def deliver_to_callback(done):
                error = done.exception()
                if error is None:
                    response = done.result()
                elif isinstance(error, HTTPError) and error.response is not None:
                    response = error.response
                else:
                    response = HTTPResponse(
                        request, 599, error=error,
                        request_time=time.time() - request.start_time)
                self.io_loop.add_callback(callback, response)
            result.add_done_callback(deliver_to_callback)

        def settle(response):
            if response.error:
                result.set_exception(response.error)
            else:
                result.set_result(response)

        self.fetch_impl(request, settle)
        return result
示例#29
0
文件: locks.py 项目: rgbkrk/tornado
    def wait(self, timeout: Union[float, datetime.timedelta] = None) -> "Future[bool]":
        """Return a Future that resolves when the condition is notified.

        The Future resolves ``True`` on `.notify`, or ``False`` if
        ``timeout`` elapses first.
        """
        waiter = Future()  # type: Future[bool]
        self._waiters.append(waiter)
        if not timeout:
            return waiter

        io_loop = ioloop.IOLoop.current()

        def on_timeout() -> None:
            if not waiter.done():
                future_set_result_unless_cancelled(waiter, False)
            self._garbage_collect()

        timeout_handle = io_loop.add_timeout(timeout, on_timeout)
        # Cancel the pending timeout as soon as the waiter settles.
        waiter.add_done_callback(lambda _: io_loop.remove_timeout(timeout_handle))
        return waiter
示例#30
0
    def acquire(self, timeout=None):
        """Decrement the counter. Returns a Future.

        Blocks if the counter is zero, waiting for a `.release`. After the
        timeout, the Future raises `.TimeoutError`.
        """
        waiter = Future()
        if self._value > 0:
            self._value -= 1
            waiter.set_result(_ReleasingContextManager(self))
        else:
            self._waiters.append(waiter)
            if timeout:
                def on_timeout():
                    # release() may already have resolved the waiter before
                    # its remove_timeout callback ran; don't settle it twice.
                    if not waiter.done():
                        waiter.set_exception(gen.TimeoutError())
                    self._garbage_collect()
                io_loop = ioloop.IOLoop.current()
                timeout_handle = io_loop.add_timeout(timeout, on_timeout)
                waiter.add_done_callback(
                    lambda _: io_loop.remove_timeout(timeout_handle))
        return waiter
示例#31
0
    def execute_code(self,
                     code,
                     result_extractor=None,
                     done_callback=None,
                     timeout=None):
        """
        Asynchronously execute ``code`` on the managed kernel.

        Not synchronized: the caller must hold the ``lock`` member, e.g.
            with (yield managed_client.lock.acquire()):
                yield managed_client.execute_code( code )

        Parameters
        ----------
        code : String
            Python code to run. The app's ``prepend_execute_code`` is
            prepended on its own line.

        result_extractor : function [Optional]
            Called with the accumulated messages when the kernel reports the
            ``idle`` status, to build the Future's result. Defaults to
            ``self._result_extractor``.

        done_callback : function [Optional]
            Attached with ``add_done_callback`` to the underlying Future.

        timeout : number [Optional]
            Seconds. When given, the returned Future is wrapped in
            ``gen.with_timeout``.

        Returns
        -------
        Future
        """
        if result_extractor is None:
            result_extractor = self._result_extractor
        code = PixieGatewayApp.instance().prepend_execute_code + "\n" + code
        app_log.debug("Executing Code: %s", code)
        future = Future()
        parent_header = self.kernel_manager.execute(self.kernel_handle, code)
        collected = []

        def on_reply(msg):
            parent = msg['parent_header']
            if not ('msg_id' in parent and parent['msg_id'] == parent_header):
                app_log.warning("Got an orphan message %s", parent)
                return
            if future.done():
                return
            if "channel" not in msg:
                msg["channel"] = "iopub"
            collected.append(msg)
            msg_type = msg['header']['msg_type']
            if msg_type == 'status' and msg['content']['execution_state'] == 'idle':
                # Idle status marks the end of this execution.
                future.set_result(result_extractor(collected))
            elif msg_type == 'error':
                details = msg['content']
                error_name = details['ename']
                error_value = details['evalue']
                trace = sanitize_traceback(details['traceback'])
                future.set_exception(
                    CodeExecutionError(error_name, error_value, trace, code))

        self.current_iopub_handler = on_reply

        if done_callback is not None:
            future.add_done_callback(done_callback)

        # Attach the future to the kernel to be notified if it dies.
        self.kernel_manager.register_execute_future(self.kernel_handle, future)
        if timeout is None:
            return future
        return gen.with_timeout(timedelta(seconds=timeout), future)
示例#32
0
文件: s5.py 项目: zxy1013/tornado
 def get(self):
     """Suspend this request on a fresh module-level ``future``.

     ``self.done`` runs when that Future completes. Whoever resolves the
     global ``future`` resumes the request.
     NOTE(review): the global is shared, so a second request replaces the
     first request's Future; confirm only one waiter is expected.
     """
     global future
     future = pending = Future()
     pending.add_done_callback(self.done)
     # Suspend the current request; the thread keeps handling other requests.
     yield pending
示例#33
0
    def execute_code(self, code, result_extractor=None, done_callback=None):
        """
        Asynchronously execute ``code`` through the managed kernel client.

        Not synchronized: the caller must hold the ``lock`` member, e.g.
            with (yield managed_client.lock.acquire()):
                yield managed_client.execute_code( code )

        Parameters
        ----------
        code : String
            Python code to run. The app's ``prepend_execute_code`` is
            prepended on its own line.

        result_extractor : function [Optional]
            Called with the accumulated messages when the kernel reports the
            ``idle`` status, to build the Future's result. Defaults to
            ``self._result_extractor``.

        done_callback : function [Optional]
            Attached with ``add_done_callback`` to the returned Future.

        Returns
        -------
        Future
        """
        if result_extractor is None:
            result_extractor = self._result_extractor
        code = PixieGatewayApp.instance().prepend_execute_code + "\n" + code
        app_log.debug("Executing Code: %s", code)
        future = Future()
        parent_header = self.kernel_client.execute(code)
        collected = []

        def on_reply(raw_parts):
            client_session = self.kernel_client.session
            # Decode the raw multipart message with a Session built from the
            # client's config and signing key.
            session = type(client_session)(
                config=client_session.config,
                key=client_session.key,
            )
            _, frames = session.feed_identities(raw_parts)
            msg = session.deserialize(frames)
            parent = msg['parent_header']
            if not ('msg_id' in parent and parent['msg_id'] == parent_header):
                app_log.warning("Got an orphan message %s", msg)
                return
            if future.done():
                return
            if "channel" not in msg:
                msg["channel"] = "iopub"
            collected.append(msg)
            msg_type = msg['header']['msg_type']
            if msg_type == 'status' and msg['content']['execution_state'] == 'idle':
                # Idle status marks the end of this execution.
                future.set_result(result_extractor(collected))
            elif msg_type == 'error':
                error_name = msg['content']['ename']
                error_value = msg['content']['evalue']
                trace_text = "\n".join(msg['content']['traceback'])
                future.set_exception(
                    Exception(
                        'Code execution Error {}: {} \nTraceback: {}\nRunning code: {}'
                        .format(error_name, error_value, trace_text, code)))

        self.iopub.on_recv(on_reply)

        if done_callback is not None:
            future.add_done_callback(done_callback)
        return future
示例#34
0
 def put_task(self, datapoint, callback):
     """Queue ``datapoint`` with a new Future and return that Future.

     ``callback`` is always attached via ``add_done_callback``.
     """
     future_out = Future()
     future_out.add_done_callback(callback)
     entry = (datapoint, future_out)
     self.input_queue.put(entry)
     return future_out
    def get(self, *args, **kwargs):
        """Suspend this request on a fresh module-level ``future``.

        ``self.done`` is attached as its done-callback. The request resumes
        when something resolves the global ``future``. Extra arguments are
        ignored.
        """
        global future
        future = waiting = Future()
        waiting.add_done_callback(self.done)

        yield waiting
示例#36
0
class KernelGatewayWSClient(LoggingConfigurable):
    """Proxy web socket connection to a kernel/enterprise gateway."""
    def __init__(self, **kwargs):
        super(KernelGatewayWSClient, self).__init__(**kwargs)
        self.kernel_id = None
        # Established gateway websocket; set by _connection_done on success.
        self.ws = None
        # Current connection attempt; replaced by _connect.
        self.ws_future = Future()
        # Cancellation is tracked locally because future.cancel() may be a noop.
        self.ws_future_cancelled = False

    @gen.coroutine
    def _connect(self, kernel_id):
        """Start connecting to the gateway channels websocket for *kernel_id*.

        Builds the URL and request options from the KG_* settings, stores the
        pending connection in ``self.ws_future`` and registers
        ``_connection_done`` to run when the attempt completes.
        """
        self.kernel_id = kernel_id
        ws_url = url_path_join(
            os.getenv('KG_WS_URL', KG_URL.replace('http', 'ws')),
            '/api/kernels', url_escape(kernel_id), 'channels')
        self.log.info('Connecting to {}'.format(ws_url))
        parameters = {
            "headers": KG_HEADERS,
            "validate_cert": VALIDATE_KG_CERT,
            "connect_timeout": KG_CONNECT_TIMEOUT,
            "request_timeout": KG_REQUEST_TIMEOUT
        }
        if KG_HTTP_USER:
            parameters["auth_username"] = KG_HTTP_USER
        if KG_HTTP_PASS:
            parameters["auth_password"] = KG_HTTP_PASS
        if KG_CLIENT_KEY:
            parameters["client_key"] = KG_CLIENT_KEY
            parameters["client_cert"] = KG_CLIENT_CERT
            if KG_CLIENT_CA:
                parameters["ca_certs"] = KG_CLIENT_CA
        request = HTTPRequest(ws_url, **parameters)
        self.ws_future = websocket_connect(request)
        self.ws_future.add_done_callback(self._connection_done)

    def _connection_done(self, fut):
        """Record the websocket if the connection attempt succeeded.

        Cancelled and failed attempts are logged and leave ``self.ws`` as
        ``None``; nothing is re-raised from this done-callback.
        """
        if self.ws_future_cancelled:  # prevent concurrent.futures._base.CancelledError
            self.log.warning(
                "Websocket connection has been cancelled via client disconnect before its establishment.  "
                "Kernel with ID '{}' may not be terminated on Gateway: {}".
                format(self.kernel_id, KG_URL))
            return
        error = fut.exception()
        if error is not None:
            # Calling fut.result() here would raise out of the callback.
            self.log.error(
                "Websocket connection to Gateway failed for kernel '{}': {}".
                format(self.kernel_id, error))
            return
        self.ws = fut.result()
        self.log.debug("Connection is ready: ws: {}".format(self.ws))

    def _disconnect(self):
        """Close the websocket, or cancel a connection attempt still in flight."""
        if self.ws is not None:
            # Close connection
            self.ws.close()
        elif not self.ws_future.done():
            # Cancel pending connection.  Since future.cancel() is a noop on tornado, we'll track cancellation locally.
            # The flag is set first so _connection_done sees it even if cancel()
            # runs done-callbacks synchronously.
            self.ws_future_cancelled = True
            self.ws_future.cancel()
            self.log.debug("_disconnect: ws_future_cancelled: {}".format(
                self.ws_future_cancelled))

    @gen.coroutine
    def _read_messages(self, callback):
        """Read messages from gateway server.

        Each message is passed to *callback*.  Reading stops when the
        connection was cancelled or never established, when the gateway closes
        the socket (``None`` message), or when a read raises.
        """
        while not self.ws_future_cancelled and self.ws is not None:
            message = None
            try:
                message = yield self.ws.read_message()
            except Exception as e:
                self.log.error(
                    "Exception reading message from websocket: {}".format(
                        e))  # , exc_info=True)
            if message is None:
                break
            callback(
                message
            )  # pass back to notebook client (see self.on_open and WebSocketChannelsHandler.open)

    def on_open(self, kernel_id, message_callback, **kwargs):
        """Web socket connection open against gateway server."""
        self._connect(kernel_id)
        loop = IOLoop.current()
        loop.add_future(self.ws_future,
                        lambda future: self._read_messages(message_callback))

    def on_message(self, message):
        """Send message to gateway server.

        Before the connection exists, the write is deferred until
        ``self.ws_future`` completes.
        """
        if self.ws is None:
            loop = IOLoop.current()
            loop.add_future(self.ws_future,
                            lambda future: self._write_message(message))
        else:
            self._write_message(message)

    def _write_message(self, message):
        """Send message to gateway server.

        The message is skipped if the connection was cancelled or never
        established.
        """
        try:
            if not self.ws_future_cancelled and self.ws is not None:
                self.ws.write_message(message)
        except Exception as e:
            self.log.error("Exception writing message to websocket: {}".format(
                e))  # , exc_info=True)

    def on_close(self):
        """Web socket closed event."""
        self._disconnect()
示例#37
0
class WebSocketTest(AsyncHTTPTestCase):
    """Websocket server/client tests run against a live test application.

    The handlers that accept ``close_future`` all share one Future, which
    several tests await after closing a connection.  The handlers are defined
    elsewhere; presumably they resolve the future when the server side closes
    (confirm against the handler definitions).
    """
    def get_app(self):
        # One Future shared by every handler that takes close_future.
        self.close_future = Future()
        return Application([
            ('/echo', EchoHandler, dict(close_future=self.close_future)),
            ('/non_ws', NonWebSocketHandler),
            ('/header', HeaderHandler, dict(close_future=self.close_future)),
            ('/close_reason', CloseReasonHandler,
             dict(close_future=self.close_future)),
        ])

    def test_http_request(self):
        # WS server, HTTP client.
        # A plain HTTP GET on a websocket endpoint is rejected with 400.
        response = self.fetch('/echo')
        self.assertEqual(response.code, 400)

    @gen_test
    def test_websocket_gen(self):
        # Coroutine-style echo round trip, then wait for the server-side close.
        ws = yield websocket_connect('ws://localhost:%d/echo' %
                                     self.get_http_port(),
                                     io_loop=self.io_loop)
        ws.write_message('hello')
        response = yield ws.read_message()
        self.assertEqual(response, 'hello')
        ws.close()
        yield self.close_future

    def test_websocket_callbacks(self):
        # Callback-style API: self.stop receives each result and self.wait
        # returns it to the test.
        websocket_connect('ws://localhost:%d/echo' % self.get_http_port(),
                          io_loop=self.io_loop,
                          callback=self.stop)
        ws = self.wait().result()
        ws.write_message('hello')
        ws.read_message(self.stop)
        response = self.wait().result()
        self.assertEqual(response, 'hello')
        self.close_future.add_done_callback(lambda f: self.stop())
        ws.close()
        self.wait()

    @gen_test
    def test_websocket_http_fail(self):
        # An unknown path fails the handshake with HTTP 404.
        with self.assertRaises(HTTPError) as cm:
            yield websocket_connect('ws://localhost:%d/notfound' %
                                    self.get_http_port(),
                                    io_loop=self.io_loop)
        self.assertEqual(cm.exception.code, 404)

    @gen_test
    def test_websocket_http_success(self):
        # A non-websocket handler answering the upgrade request makes the
        # client raise WebSocketError.
        with self.assertRaises(WebSocketError):
            yield websocket_connect('ws://localhost:%d/non_ws' %
                                    self.get_http_port(),
                                    io_loop=self.io_loop)

    @gen_test
    def test_websocket_network_fail(self):
        # Connecting to a port that was bound and then released raises IOError;
        # ExpectLog absorbs whatever gen_log emits along the way.
        sock, port = bind_unused_port()
        sock.close()
        with self.assertRaises(IOError):
            with ExpectLog(gen_log, ".*"):
                yield websocket_connect('ws://localhost:%d/' % port,
                                        io_loop=self.io_loop,
                                        connect_timeout=3600)

    @gen_test
    def test_websocket_close_buffered_data(self):
        # Closing the raw stream right after queuing writes must still let the
        # server-side close complete (close_future resolves).
        ws = yield websocket_connect('ws://localhost:%d/echo' %
                                     self.get_http_port())
        ws.write_message('hello')
        ws.write_message('world')
        ws.stream.close()
        yield self.close_future

    @gen_test
    def test_websocket_headers(self):
        # Ensure that arbitrary headers can be passed through websocket_connect.
        # The test expects /header to reply with the X-Test value.
        ws = yield websocket_connect(
            HTTPRequest('ws://localhost:%d/header' % self.get_http_port(),
                        headers={'X-Test': 'hello'}))
        response = yield ws.read_message()
        self.assertEqual(response, 'hello')
        ws.close()
        yield self.close_future

    @gen_test
    def test_server_close_reason(self):
        # Server-initiated close: the client exposes the server's code/reason.
        ws = yield websocket_connect('ws://localhost:%d/close_reason' %
                                     self.get_http_port())
        msg = yield ws.read_message()
        # A message of None means the other side closed the connection.
        self.assertIs(msg, None)
        self.assertEqual(ws.close_code, 1001)
        self.assertEqual(ws.close_reason, "goodbye")

    @gen_test
    def test_client_close_reason(self):
        # Client-initiated close: code and reason reach the server, which
        # reports them through close_future as a (code, reason) pair.
        ws = yield websocket_connect('ws://localhost:%d/echo' %
                                     self.get_http_port())
        ws.close(1001, 'goodbye')
        code, reason = yield self.close_future
        self.assertEqual(code, 1001)
        self.assertEqual(reason, 'goodbye')

    @gen_test
    def test_check_origin_valid_no_path(self):
        # An Origin matching the request host (no path) is accepted.
        port = self.get_http_port()

        url = 'ws://localhost:%d/echo' % port
        headers = {'Origin': 'http://localhost:%d' % port}

        ws = yield websocket_connect(HTTPRequest(url, headers=headers),
                                     io_loop=self.io_loop)
        ws.write_message('hello')
        response = yield ws.read_message()
        self.assertEqual(response, 'hello')
        ws.close()
        yield self.close_future

    @gen_test
    def test_check_origin_valid_with_path(self):
        # A matching Origin that carries a path is accepted as well.
        port = self.get_http_port()

        url = 'ws://localhost:%d/echo' % port
        headers = {'Origin': 'http://localhost:%d/something' % port}

        ws = yield websocket_connect(HTTPRequest(url, headers=headers),
                                     io_loop=self.io_loop)
        ws.write_message('hello')
        response = yield ws.read_message()
        self.assertEqual(response, 'hello')
        ws.close()
        yield self.close_future

    @gen_test
    def test_check_origin_invalid_partial_url(self):
        # An Origin without a scheme is rejected with 403.
        port = self.get_http_port()

        url = 'ws://localhost:%d/echo' % port
        headers = {'Origin': 'localhost:%d' % port}

        with self.assertRaises(HTTPError) as cm:
            yield websocket_connect(HTTPRequest(url, headers=headers),
                                    io_loop=self.io_loop)
        self.assertEqual(cm.exception.code, 403)

    @gen_test
    def test_check_origin_invalid(self):
        port = self.get_http_port()

        url = 'ws://localhost:%d/echo' % port
        # Host is localhost, which should not be accessible from some other
        # domain
        headers = {'Origin': 'http://somewhereelse.com'}

        with self.assertRaises(HTTPError) as cm:
            yield websocket_connect(HTTPRequest(url, headers=headers),
                                    io_loop=self.io_loop)

        self.assertEqual(cm.exception.code, 403)

    @gen_test
    def test_check_origin_invalid_subdomains(self):
        port = self.get_http_port()

        url = 'ws://localhost:%d/echo' % port
        # Subdomains should be disallowed by default.  If we could pass a
        # resolver to websocket_connect we could test sibling domains as well.
        headers = {'Origin': 'http://subtenant.localhost'}

        with self.assertRaises(HTTPError) as cm:
            yield websocket_connect(HTTPRequest(url, headers=headers),
                                    io_loop=self.io_loop)

        self.assertEqual(cm.exception.code, 403)
示例#38
0
    def get(self):
        """Expose a new pending Future as module-global ``future``, then yield it.

        ``self.done`` is attached as the completion callback before yielding.
        This is a generator; presumably a coroutine decorator outside this
        view drives it.
        """
        global future
        pending = future = Future()
        pending.add_done_callback(self.done)
        yield pending
示例#39
0
class WebSocketTest(WebSocketBaseTestCase):
    """Websocket server/client tests run against a live test application.

    Handlers that accept ``close_future`` share one Future.  The handlers,
    the ``ws_connect``/``close`` helpers and the base class are defined
    elsewhere (presumably the handlers resolve the future on close — confirm).
    """
    def get_app(self):
        # One Future shared by every handler that takes close_future.
        self.close_future = Future()
        return Application([
            ('/echo', EchoHandler, dict(close_future=self.close_future)),
            ('/non_ws', NonWebSocketHandler),
            ('/header', HeaderHandler, dict(close_future=self.close_future)),
            ('/header_echo', HeaderEchoHandler,
             dict(close_future=self.close_future)),
            ('/close_reason', CloseReasonHandler,
             dict(close_future=self.close_future)),
            ('/error_in_on_message', ErrorInOnMessageHandler,
             dict(close_future=self.close_future)),
            ('/async_prepare', AsyncPrepareHandler,
             dict(close_future=self.close_future)),
            ('/path_args/(.*)', PathArgsHandler,
             dict(close_future=self.close_future)),
            ('/coroutine', CoroutineOnMessageHandler,
             dict(close_future=self.close_future)),
            ('/render', RenderMessageHandler,
             dict(close_future=self.close_future)),
        ],
                           # In-memory template used by the /render handler test.
                           template_loader=DictLoader({
                               'message.html':
                               '<b>{{ message }}</b>',
                           }))

    def tearDown(self):
        super(WebSocketTest, self).tearDown()
        # Clear template loaders cached on the RequestHandler class, presumably
        # so this test's DictLoader does not leak into other tests.
        RequestHandler._template_loaders.clear()

    def test_http_request(self):
        # WS server, HTTP client.
        # A plain HTTP GET on a websocket endpoint is rejected with 400.
        response = self.fetch('/echo')
        self.assertEqual(response.code, 400)

    def test_bad_websocket_version(self):
        # An unsupported Sec-WebSocket-Version is answered with 426.
        response = self.fetch('/echo',
                              headers={
                                  'Connection': 'Upgrade',
                                  'Upgrade': 'WebSocket',
                                  'Sec-WebSocket-Version': '12'
                              })
        self.assertEqual(response.code, 426)

    @gen_test
    def test_websocket_gen(self):
        # Coroutine-style echo round trip.
        ws = yield self.ws_connect('/echo')
        yield ws.write_message('hello')
        response = yield ws.read_message()
        self.assertEqual(response, 'hello')
        yield self.close(ws)

    def test_websocket_callbacks(self):
        # Callback-style API: self.stop receives each result and self.wait
        # returns it to the test.
        websocket_connect('ws://127.0.0.1:%d/echo' % self.get_http_port(),
                          callback=self.stop)
        ws = self.wait().result()
        ws.write_message('hello')
        ws.read_message(self.stop)
        response = self.wait().result()
        self.assertEqual(response, 'hello')
        self.close_future.add_done_callback(lambda f: self.stop())
        ws.close()
        self.wait()

    @gen_test
    def test_binary_message(self):
        # Binary frames round-trip as bytes.
        ws = yield self.ws_connect('/echo')
        ws.write_message(b'hello \xe9', binary=True)
        response = yield ws.read_message()
        self.assertEqual(response, b'hello \xe9')
        yield self.close(ws)

    @gen_test
    def test_unicode_message(self):
        # Non-ASCII text round-trips unchanged.
        ws = yield self.ws_connect('/echo')
        ws.write_message(u'hello \u00e9')
        response = yield ws.read_message()
        self.assertEqual(response, u'hello \u00e9')
        yield self.close(ws)

    @gen_test
    def test_render_message(self):
        # The reply is expected to be message.html rendered with the message.
        ws = yield self.ws_connect('/render')
        ws.write_message('hello')
        response = yield ws.read_message()
        self.assertEqual(response, '<b>hello</b>')
        yield self.close(ws)

    @gen_test
    def test_error_in_on_message(self):
        # An exception raised while handling a message is logged as uncaught
        # and the client then reads None (connection closed).
        ws = yield self.ws_connect('/error_in_on_message')
        ws.write_message('hello')
        with ExpectLog(app_log, "Uncaught exception"):
            response = yield ws.read_message()
        self.assertIs(response, None)
        yield self.close(ws)

    @gen_test
    def test_websocket_http_fail(self):
        # An unknown path fails the handshake with HTTP 404.
        with self.assertRaises(HTTPError) as cm:
            yield self.ws_connect('/notfound')
        self.assertEqual(cm.exception.code, 404)

    @gen_test
    def test_websocket_http_success(self):
        # A non-websocket handler answering the upgrade request makes the
        # client raise WebSocketError.
        with self.assertRaises(WebSocketError):
            yield self.ws_connect('/non_ws')

    @gen_test
    def test_websocket_network_fail(self):
        # Connecting to a port that was bound and then released raises IOError;
        # ExpectLog absorbs whatever gen_log emits along the way.
        sock, port = bind_unused_port()
        sock.close()
        with self.assertRaises(IOError):
            with ExpectLog(gen_log, ".*"):
                yield websocket_connect('ws://127.0.0.1:%d/' % port,
                                        connect_timeout=3600)

    @gen_test
    def test_websocket_close_buffered_data(self):
        # Closing the raw stream right after queuing writes must still let the
        # server-side close complete (close_future resolves).
        ws = yield websocket_connect('ws://127.0.0.1:%d/echo' %
                                     self.get_http_port())
        ws.write_message('hello')
        ws.write_message('world')
        # Close the underlying stream.
        ws.stream.close()
        yield self.close_future

    @gen_test
    def test_websocket_headers(self):
        # Ensure that arbitrary headers can be passed through websocket_connect.
        # The test expects /header to reply with the X-Test value.
        ws = yield websocket_connect(
            HTTPRequest('ws://127.0.0.1:%d/header' % self.get_http_port(),
                        headers={'X-Test': 'hello'}))
        response = yield ws.read_message()
        self.assertEqual(response, 'hello')
        yield self.close(ws)

    @gen_test
    def test_websocket_header_echo(self):
        # Ensure that headers can be returned in the response.
        # Specifically, that arbitrary headers passed through websocket_connect
        # can be returned.
        ws = yield websocket_connect(
            HTTPRequest('ws://127.0.0.1:%d/header_echo' % self.get_http_port(),
                        headers={'X-Test-Hello': 'hello'}))
        self.assertEqual(ws.headers.get('X-Test-Hello'), 'hello')
        self.assertEqual(ws.headers.get('X-Extra-Response-Header'),
                         'Extra-Response-Value')
        yield self.close(ws)

    @gen_test
    def test_server_close_reason(self):
        # Server-initiated close: the client exposes the server's code/reason.
        ws = yield self.ws_connect('/close_reason')
        msg = yield ws.read_message()
        # A message of None means the other side closed the connection.
        self.assertIs(msg, None)
        self.assertEqual(ws.close_code, 1001)
        self.assertEqual(ws.close_reason, "goodbye")
        # The on_close callback is called no matter which side closed.
        code, reason = yield self.close_future
        # The client echoed the close code it received to the server,
        # so the server's close code (returned via close_future) is
        # the same.
        self.assertEqual(code, 1001)

    @gen_test
    def test_client_close_reason(self):
        # Client-initiated close: code and reason reach the server, which
        # reports them through close_future as a (code, reason) pair.
        ws = yield self.ws_connect('/echo')
        ws.close(1001, 'goodbye')
        code, reason = yield self.close_future
        self.assertEqual(code, 1001)
        self.assertEqual(reason, 'goodbye')

    @gen_test
    def test_write_after_close(self):
        # Writing after the server closed the connection raises
        # StreamClosedError.
        ws = yield self.ws_connect('/close_reason')
        msg = yield ws.read_message()
        self.assertIs(msg, None)
        with self.assertRaises(StreamClosedError):
            ws.write_message('hello')

    @gen_test
    def test_async_prepare(self):
        # Previously, an async prepare method triggered a bug that would
        # result in a timeout on test shutdown (and a memory leak).
        ws = yield self.ws_connect('/async_prepare')
        ws.write_message('hello')
        res = yield ws.read_message()
        self.assertEqual(res, 'hello')

    @gen_test
    def test_path_args(self):
        # The captured path segment is expected to be sent back to the client.
        ws = yield self.ws_connect('/path_args/hello')
        res = yield ws.read_message()
        self.assertEqual(res, 'hello')

    @gen_test
    def test_coroutine(self):
        ws = yield self.ws_connect('/coroutine')
        # Send both messages immediately, coroutine must process one at a time.
        yield ws.write_message('hello1')
        yield ws.write_message('hello2')
        res = yield ws.read_message()
        self.assertEqual(res, 'hello1')
        res = yield ws.read_message()
        self.assertEqual(res, 'hello2')

    @gen_test
    def test_check_origin_valid_no_path(self):
        # An Origin matching the request host (no path) is accepted.
        port = self.get_http_port()

        url = 'ws://127.0.0.1:%d/echo' % port
        headers = {'Origin': 'http://127.0.0.1:%d' % port}

        ws = yield websocket_connect(HTTPRequest(url, headers=headers))
        ws.write_message('hello')
        response = yield ws.read_message()
        self.assertEqual(response, 'hello')
        yield self.close(ws)

    @gen_test
    def test_check_origin_valid_with_path(self):
        # A matching Origin that carries a path is accepted as well.
        port = self.get_http_port()

        url = 'ws://127.0.0.1:%d/echo' % port
        headers = {'Origin': 'http://127.0.0.1:%d/something' % port}

        ws = yield websocket_connect(HTTPRequest(url, headers=headers))
        ws.write_message('hello')
        response = yield ws.read_message()
        self.assertEqual(response, 'hello')
        yield self.close(ws)

    @gen_test
    def test_check_origin_invalid_partial_url(self):
        # An Origin without a scheme is rejected with 403.
        port = self.get_http_port()

        url = 'ws://127.0.0.1:%d/echo' % port
        headers = {'Origin': '127.0.0.1:%d' % port}

        with self.assertRaises(HTTPError) as cm:
            yield websocket_connect(HTTPRequest(url, headers=headers))
        self.assertEqual(cm.exception.code, 403)

    @gen_test
    def test_check_origin_invalid(self):
        port = self.get_http_port()

        url = 'ws://127.0.0.1:%d/echo' % port
        # Host is 127.0.0.1, which should not be accessible from some other
        # domain
        headers = {'Origin': 'http://somewhereelse.com'}

        with self.assertRaises(HTTPError) as cm:
            yield websocket_connect(HTTPRequest(url, headers=headers))

        self.assertEqual(cm.exception.code, 403)

    @gen_test
    def test_check_origin_invalid_subdomains(self):
        port = self.get_http_port()

        url = 'ws://localhost:%d/echo' % port
        # Subdomains should be disallowed by default.  If we could pass a
        # resolver to websocket_connect we could test sibling domains as well.
        headers = {'Origin': 'http://subtenant.localhost'}

        with self.assertRaises(HTTPError) as cm:
            yield websocket_connect(HTTPRequest(url, headers=headers))

        self.assertEqual(cm.exception.code, 403)
示例#40
0
class GatewayWebSocketClient(LoggingConfigurable):
    """Proxy web socket connection to a kernel/enterprise gateway."""

    def __init__(self, **kwargs):
        super(GatewayWebSocketClient, self).__init__(**kwargs)
        self.kernel_id = None
        # Established gateway websocket; set by _connection_done on success.
        self.ws = None
        # Current connection attempt; replaced by _connect.
        self.ws_future = Future()
        # Set once the client side disconnects; stops reading and reconnecting.
        self.disconnected = False

    def _connect(self, kernel_id):
        """Start connecting to the gateway channels websocket for *kernel_id*.

        This must run synchronously.  ``on_open`` and the reconnect path in
        ``_read_messages`` call it without awaiting and then immediately
        attach callbacks to ``self.ws_future``.  As an ``async def``, its body
        never executed, because the coroutine object was never awaited.

        Returns:
            The pending connection future, which is also stored in
            ``self.ws_future``.  It is awaitable, so callers that ``await``
            this method still work.
        """
        # websocket is initialized before connection
        self.ws = None
        self.kernel_id = kernel_id
        ws_url = url_path_join(
            GatewayClient.instance().ws_url,
            GatewayClient.instance().kernels_endpoint, url_escape(kernel_id), 'channels'
        )
        self.log.info('Connecting to {}'.format(ws_url))
        kwargs = {}
        kwargs = GatewayClient.instance().load_connection_args(**kwargs)

        request = HTTPRequest(ws_url, **kwargs)
        self.ws_future = websocket_connect(request)
        self.ws_future.add_done_callback(self._connection_done)
        return self.ws_future

    def _connection_done(self, fut):
        """Record the websocket if the attempt succeeded and we are still wanted."""
        if not self.disconnected and fut.exception() is None:  # prevent concurrent.futures._base.CancelledError
            self.ws = fut.result()
            self.log.debug("Connection is ready: ws: {}".format(self.ws))
        else:
            self.log.warning("Websocket connection has been closed via client disconnect or due to error.  "
                             "Kernel with ID '{}' may not be terminated on GatewayClient: {}".
                             format(self.kernel_id, GatewayClient.instance().url))

    def _disconnect(self):
        """Mark the client disconnected and close or cancel the connection."""
        self.disconnected = True
        if self.ws is not None:
            # Close connection
            self.ws.close()
        elif not self.ws_future.done():
            # Cancel pending connection.  Since future.cancel() is a noop on tornado, we'll track cancellation locally
            self.ws_future.cancel()
            self.log.debug("_disconnect: future cancelled, disconnected: {}".format(self.disconnected))

    async def _read_messages(self, callback):
        """Read messages from gateway server.

        Each message is passed to *callback*.  When reading stops without a
        client disconnect (gateway closed the socket, a read failed, or no
        connection was made), a new connection is started and reading resumes
        once that attempt completes.
        """
        while self.ws is not None:
            message = None
            if not self.disconnected:
                try:
                    message = await self.ws.read_message()
                except Exception as e:
                    self.log.error("Exception reading message from websocket: {}".format(e))  # , exc_info=True)
                if message is None:
                    if not self.disconnected:
                        self.log.warning("Lost connection to Gateway: {}".format(self.kernel_id))
                    break
                callback(message)  # pass back to notebook client (see self.on_open and WebSocketChannelsHandler.open)
            else:  # ws cancelled - stop reading
                break

        if not self.disconnected:  # if websocket is not disconnected by client, attempt to reconnect to Gateway
            self.log.info("Attempting to re-establish the connection to Gateway: {}".format(self.kernel_id))
            self._connect(self.kernel_id)
            loop = IOLoop.current()
            loop.add_future(self.ws_future, lambda future: self._read_messages(callback))

    def on_open(self, kernel_id, message_callback, **kwargs):
        """Web socket connection open against gateway server."""
        self._connect(kernel_id)
        loop = IOLoop.current()
        loop.add_future(
            self.ws_future,
            lambda future: self._read_messages(message_callback)
        )

    def on_message(self, message):
        """Send message to gateway server.

        Before the connection exists, the write is deferred until
        ``self.ws_future`` completes.
        """
        if self.ws is None:
            loop = IOLoop.current()
            loop.add_future(
                self.ws_future,
                lambda future: self._write_message(message)
            )
        else:
            self._write_message(message)

    def _write_message(self, message):
        """Send message to gateway server (skipped when not connected)."""
        try:
            if not self.disconnected and self.ws is not None:
                self.ws.write_message(message)
        except Exception as e:
            self.log.error("Exception writing message to websocket: {}".format(e))  # , exc_info=True)

    def on_close(self):
        """Web socket closed event."""
        self._disconnect()
示例#41
0
class WebSocketTest(WebSocketBaseTestCase):
    def get_app(self):
        self.close_future = Future()  # type: Future[None]
        return Application(
            [
                ("/echo", EchoHandler, dict(close_future=self.close_future)),
                ("/redirect", RedirectHandler, dict(url="/echo")),
                ("/double_redirect", RedirectHandler, dict(url="/redirect")),
                ("/non_ws", NonWebSocketHandler),
                ("/header", HeaderHandler,
                 dict(close_future=self.close_future)),
                (
                    "/header_echo",
                    HeaderEchoHandler,
                    dict(close_future=self.close_future),
                ),
                (
                    "/close_reason",
                    CloseReasonHandler,
                    dict(close_future=self.close_future),
                ),
                (
                    "/error_in_on_message",
                    ErrorInOnMessageHandler,
                    dict(close_future=self.close_future),
                ),
                (
                    "/async_prepare",
                    AsyncPrepareHandler,
                    dict(close_future=self.close_future),
                ),
                (
                    "/path_args/(.*)",
                    PathArgsHandler,
                    dict(close_future=self.close_future),
                ),
                (
                    "/coroutine",
                    CoroutineOnMessageHandler,
                    dict(close_future=self.close_future),
                ),
                ("/render", RenderMessageHandler,
                 dict(close_future=self.close_future)),
                (
                    "/subprotocol",
                    SubprotocolHandler,
                    dict(close_future=self.close_future),
                ),
                (
                    "/open_coroutine",
                    OpenCoroutineHandler,
                    dict(close_future=self.close_future, test=self),
                ),
            ],
            template_loader=DictLoader(
                {"message.html": "<b>{{ message }}</b>"}),
        )

    def get_http_client(self):
        # These tests require HTTP/1; force the use of SimpleAsyncHTTPClient.
        return SimpleAsyncHTTPClient()

    def tearDown(self):
        super(WebSocketTest, self).tearDown()
        RequestHandler._template_loaders.clear()

    def test_http_request(self):
        # WS server, HTTP client.
        response = self.fetch("/echo")
        self.assertEqual(response.code, 400)

    def test_missing_websocket_key(self):
        response = self.fetch(
            "/echo",
            headers={
                "Connection": "Upgrade",
                "Upgrade": "WebSocket",
                "Sec-WebSocket-Version": "13",
            },
        )
        self.assertEqual(response.code, 400)

    def test_bad_websocket_version(self):
        response = self.fetch(
            "/echo",
            headers={
                "Connection": "Upgrade",
                "Upgrade": "WebSocket",
                "Sec-WebSocket-Version": "12",
            },
        )
        self.assertEqual(response.code, 426)

    @gen_test
    def test_websocket_gen(self):
        ws = yield self.ws_connect("/echo")
        yield ws.write_message("hello")
        response = yield ws.read_message()
        self.assertEqual(response, "hello")

    def test_websocket_callbacks(self):
        websocket_connect("ws://127.0.0.1:%d/echo" % self.get_http_port(),
                          callback=self.stop)
        ws = self.wait().result()
        ws.write_message("hello")
        ws.read_message(self.stop)
        response = self.wait().result()
        self.assertEqual(response, "hello")
        self.close_future.add_done_callback(lambda f: self.stop())
        ws.close()
        self.wait()

    @gen_test
    def test_binary_message(self):
        ws = yield self.ws_connect("/echo")
        ws.write_message(b"hello \xe9", binary=True)
        response = yield ws.read_message()
        self.assertEqual(response, b"hello \xe9")

    @gen_test
    def test_unicode_message(self):
        ws = yield self.ws_connect("/echo")
        ws.write_message(u"hello \u00e9")
        response = yield ws.read_message()
        self.assertEqual(response, u"hello \u00e9")

    @gen_test
    def test_render_message(self):
        ws = yield self.ws_connect("/render")
        ws.write_message("hello")
        response = yield ws.read_message()
        self.assertEqual(response, "<b>hello</b>")

    @gen_test
    def test_error_in_on_message(self):
        ws = yield self.ws_connect("/error_in_on_message")
        ws.write_message("hello")
        with ExpectLog(app_log, "Uncaught exception"):
            response = yield ws.read_message()
        self.assertIs(response, None)

    @gen_test
    def test_websocket_http_fail(self):
        with self.assertRaises(HTTPError) as cm:
            yield self.ws_connect("/notfound")
        self.assertEqual(cm.exception.code, 404)

    @gen_test
    def test_websocket_http_success(self):
        with self.assertRaises(WebSocketError):
            yield self.ws_connect("/non_ws")

    @gen_test
    def test_websocket_redirect(self):
        conn = yield self.ws_connect("/redirect")
        yield conn.write_message("hello redirect")
        msg = yield conn.read_message()
        self.assertEqual("hello redirect", msg)

    @gen_test
    def test_websocket_double_redirect(self):
        conn = yield self.ws_connect("/double_redirect")
        yield conn.write_message("hello redirect")
        msg = yield conn.read_message()
        self.assertEqual("hello redirect", msg)

    @gen_test
    def test_websocket_network_fail(self):
        sock, port = bind_unused_port()
        sock.close()
        with self.assertRaises(IOError):
            with ExpectLog(gen_log, ".*"):
                yield websocket_connect("ws://127.0.0.1:%d/" % port,
                                        connect_timeout=3600)

    @gen_test
    def test_websocket_close_buffered_data(self):
        ws = yield websocket_connect("ws://127.0.0.1:%d/echo" %
                                     self.get_http_port())
        ws.write_message("hello")
        ws.write_message("world")
        # Close the underlying stream.
        ws.stream.close()

    @gen_test
    def test_websocket_headers(self):
        # Ensure that arbitrary headers can be passed through websocket_connect:
        # the X-Test value we send is expected back as the first message.
        header_url = "ws://127.0.0.1:%d/header" % self.get_http_port()
        request = HTTPRequest(header_url, headers={"X-Test": "hello"})
        ws = yield websocket_connect(request)
        reply = yield ws.read_message()
        self.assertEqual(reply, "hello")

    @gen_test
    def test_websocket_header_echo(self):
        # Ensure that headers can be returned in the handshake response: a
        # header we send through websocket_connect must come back, together
        # with an additional X-Extra-Response-Header.
        echo_url = "ws://127.0.0.1:%d/header_echo" % self.get_http_port()
        request = HTTPRequest(echo_url, headers={"X-Test-Hello": "hello"})
        ws = yield websocket_connect(request)
        self.assertEqual(ws.headers.get("X-Test-Hello"), "hello")
        self.assertEqual(ws.headers.get("X-Extra-Response-Header"),
                         "Extra-Response-Value")

    @gen_test
    def test_server_close_reason(self):
        conn = yield self.ws_connect("/close_reason")
        # Reading None means the server closed the connection; the client
        # must expose the code and reason the server sent.
        msg = yield conn.read_message()
        self.assertIs(msg, None)
        self.assertEqual(conn.close_code, 1001)
        self.assertEqual(conn.close_reason, "goodbye")
        # The server's on_close runs no matter which side closed. The client
        # echoes the close code back, so the code delivered to the server via
        # close_future is the same one.
        server_code, _server_reason = yield self.close_future
        self.assertEqual(server_code, 1001)

    @gen_test
    def test_client_close_reason(self):
        # A client-initiated close must deliver both the code and the reason
        # to the server, as observed through close_future.
        conn = yield self.ws_connect("/echo")
        conn.close(1001, "goodbye")
        close_code, close_reason = yield self.close_future
        self.assertEqual(close_code, 1001)
        self.assertEqual(close_reason, "goodbye")

    @gen_test
    def test_write_after_close(self):
        # Once the server has closed the connection (read returns None), any
        # further write must raise WebSocketClosedError.
        conn = yield self.ws_connect("/close_reason")
        closing_msg = yield conn.read_message()
        self.assertIs(closing_msg, None)
        with self.assertRaises(WebSocketClosedError):
            conn.write_message("hello")

    @gen_test
    def test_async_prepare(self):
        # Regression test: an asynchronous prepare() used to cause a timeout
        # at test shutdown (and leak memory). A simple echo round-trip must
        # now work.
        conn = yield self.ws_connect("/async_prepare")
        conn.write_message("hello")
        echoed = yield conn.read_message()
        self.assertEqual(echoed, "hello")

    @gen_test
    def test_path_args(self):
        # The trailing path segment ("hello") is expected back as the first
        # message, presumably captured as a path argument by the handler.
        conn = yield self.ws_connect("/path_args/hello")
        first = yield conn.read_message()
        self.assertEqual(first, "hello")

    @gen_test
    def test_coroutine(self):
        # Both messages are sent up front; the coroutine-based handler must
        # still process them one at a time and reply in order.
        ws = yield self.ws_connect("/coroutine")
        payloads = ("hello1", "hello2")
        for text in payloads:
            yield ws.write_message(text)
        for expected in payloads:
            got = yield ws.read_message()
            self.assertEqual(got, expected)

    @gen_test
    def test_check_origin_valid_no_path(self):
        # An Origin with the same host and port (and no path) is accepted.
        port = self.get_http_port()
        request = HTTPRequest("ws://127.0.0.1:%d/echo" % port,
                              headers={"Origin": "http://127.0.0.1:%d" % port})
        ws = yield websocket_connect(request)
        ws.write_message("hello")
        echoed = yield ws.read_message()
        self.assertEqual(echoed, "hello")

    @gen_test
    def test_check_origin_valid_with_path(self):
        # A same-host Origin that also carries a path is still accepted.
        port = self.get_http_port()
        origin = "http://127.0.0.1:%d/something" % port
        request = HTTPRequest("ws://127.0.0.1:%d/echo" % port,
                              headers={"Origin": origin})
        ws = yield websocket_connect(request)
        ws.write_message("hello")
        echoed = yield ws.read_message()
        self.assertEqual(echoed, "hello")

    @gen_test
    def test_check_origin_invalid_partial_url(self):
        # An Origin without a scheme ("host:port") is rejected with 403.
        port = self.get_http_port()
        echo_url = "ws://127.0.0.1:%d/echo" % port
        bad_origin = {"Origin": "127.0.0.1:%d" % port}

        with self.assertRaises(HTTPError) as raised:
            yield websocket_connect(HTTPRequest(echo_url, headers=bad_origin))
        self.assertEqual(raised.exception.code, 403)

    @gen_test
    def test_check_origin_invalid(self):
        port = self.get_http_port()
        echo_url = "ws://127.0.0.1:%d/echo" % port
        # The host is 127.0.0.1, which must not be reachable from an
        # unrelated origin domain: expect a 403.
        foreign_origin = {"Origin": "http://somewhereelse.com"}

        with self.assertRaises(HTTPError) as raised:
            yield websocket_connect(HTTPRequest(echo_url,
                                                headers=foreign_origin))

        self.assertEqual(raised.exception.code, 403)

    @gen_test
    def test_check_origin_invalid_subdomains(self):
        port = self.get_http_port()
        echo_url = "ws://localhost:%d/echo" % port
        # Subdomains are disallowed by default, so expect a 403. Sibling
        # domains could be tested too if websocket_connect accepted a
        # resolver.
        sub_origin = {"Origin": "http://subtenant.localhost"}

        with self.assertRaises(HTTPError) as raised:
            yield websocket_connect(HTTPRequest(echo_url, headers=sub_origin))

        self.assertEqual(raised.exception.code, 403)

    @gen_test
    def test_subprotocols(self):
        # Of the two offered subprotocols, the server must select
        # "goodproto" and report its choice in the first message.
        offered = ["badproto", "goodproto"]
        ws = yield self.ws_connect("/subprotocol", subprotocols=offered)
        self.assertEqual(ws.selected_subprotocol, "goodproto")
        greeting = yield ws.read_message()
        self.assertEqual(greeting, "subprotocol=goodproto")

    @gen_test
    def test_subprotocols_not_offered(self):
        # When the client offers no subprotocols, none is selected on
        # either side.
        ws = yield self.ws_connect("/subprotocol")
        self.assertIs(ws.selected_subprotocol, None)
        greeting = yield ws.read_message()
        self.assertEqual(greeting, "subprotocol=None")

    @gen_test
    def test_open_coroutine(self):
        # message_sent must exist before connecting and is only set after
        # our write completes; the handler is presumably waiting on it in
        # its coroutine open() before replying "ok".
        self.message_sent = Event()
        conn = yield self.ws_connect("/open_coroutine")
        yield conn.write_message("hello")
        self.message_sent.set()
        reply = yield conn.read_message()
        self.assertEqual(reply, "ok")
示例#42
0
    def fetch(self, request, callback=None, raise_error=True, **kwargs):
        """Executes a request, asynchronously returning an `HTTPResponse`.

        The request may be either a string URL or an `HTTPRequest` object.
        If it is a string, we construct an `HTTPRequest` using any additional
        kwargs: ``HTTPRequest(request, **kwargs)``

        This method returns a `.Future` whose result is an
        `HTTPResponse`. By default, the ``Future`` will raise an
        `HTTPError` if the request returned a non-200 response code
        (other errors may also be raised if the server could not be
        contacted). Instead, if ``raise_error`` is set to False, the
        response will always be returned regardless of the response
        code.

        If a ``callback`` is given, it will be invoked with the `HTTPResponse`.
        In the callback interface, `HTTPError` is not automatically raised.
        Instead, you must check the response's ``error`` attribute or
        call its `~HTTPResponse.rethrow` method.

        If the returned ``Future`` is cancelled before the request finishes,
        the eventual response or error is discarded rather than being set
        on the cancelled ``Future``.

        :raises RuntimeError: if this client has already been closed.
        :raises ValueError: if ``kwargs`` are given together with an
            `HTTPRequest` object.
        """
        if self._closed:
            raise RuntimeError("fetch() called on closed AsyncHTTPClient")
        if not isinstance(request, HTTPRequest):
            request = HTTPRequest(url=request, **kwargs)
        elif kwargs:
            raise ValueError(
                "kwargs can't be used if request is an HTTPRequest object")
        # We may modify this (to add Host, Accept-Encoding, etc),
        # so make sure we don't modify the caller's object.  This is also
        # where normal dicts get converted to HTTPHeaders objects.
        request.headers = httputil.HTTPHeaders(request.headers)
        request = _RequestProxy(request, self.defaults)
        future = Future()
        if callback is not None:
            callback = stack_context.wrap(callback)

            def handle_future(future):
                # Translate the future's outcome into an HTTPResponse for the
                # callback interface: an HTTPError that carries a response
                # is unwrapped, and any other exception becomes a synthetic
                # 599 response.
                exc = future.exception()
                if isinstance(exc, HTTPError) and exc.response is not None:
                    response = exc.response
                elif exc is not None:
                    response = HTTPResponse(request,
                                            599,
                                            error=exc,
                                            request_time=time.time() -
                                            request.start_time)
                else:
                    response = future.result()
                self.io_loop.add_callback(callback, response)

            future.add_done_callback(handle_future)

        def handle_response(response):
            # Called by fetch_impl when the request completes.
            if raise_error and response.error:
                # Mirror future_set_result_unless_cancelled below: a
                # cancelled future cannot take an exception (asyncio raises
                # InvalidStateError), so drop the error in that case.
                if not future.cancelled():
                    future.set_exception(response.error)
            else:
                future_set_result_unless_cancelled(future, response)

        self.fetch_impl(request, handle_response)
        return future