Ejemplo n.º 1
0
    def on_closed(self, exc_info=False):
        """Fail every pending future and schedule the close callback.

        ``exc_info`` may be a tuple (its second element is used as the
        exception), an exception instance, or a falsy value. In the falsy
        case each pending future fails with a fresh
        ``StreamClosedError(None)``. Each handled future attribute is reset
        to ``None``, the close callback (if any) is scheduled with
        ``self._loop.call_soon`` and cleared, and the object is marked
        closed.
        """
        # Pending futures are failed in a fixed order: connect, TLS connect,
        # then read.
        for attr in ('_connect_future', '_connect_ssl_future',
                     '_read_future'):
            pending = getattr(self, attr)
            if not pending:
                continue
            if not exc_info:
                error = StreamClosedError(None)
            elif isinstance(exc_info, tuple):
                error = exc_info[1]
            else:
                error = exc_info
            pending.set_exception(error)
            setattr(self, attr, None)

        callback = self._close_callback
        if callback:
            self._close_callback = None
            self._loop.call_soon(callback)

        self._closed = True
Ejemplo n.º 2
0
    def read(self, length):
        '''Read up to ``length`` bytes from the socket.

        This function reads only from the socket and not the buffer.

        While ``self._blocking_counter`` is below 100, a direct ``recv`` is
        tried first (the fast path). If it returns data, the counter is
        incremented and the data is returned without waiting for an event.
        Once the counter reaches 100 it is reset and this call waits for a
        READ/ERROR event first. The same wait is used whenever the fast-path
        ``recv`` would block.

        Raises StreamClosedError (after calling ``self.close()``) when
        ``recv`` returns no data. An ERROR event is reported through
        ``self._raise_socket_error()``.
        '''
        self.stop_monitor_for_close()

        if self._blocking_counter < 100:
            try:
                data = self._socket.recv(length)
            except ssl.SSLError as error:
                if error.errno != ssl.SSL_ERROR_WANT_READ:
                    raise
            except IOError as error:
                if error.errno not in (errno.EWOULDBLOCK, errno.EINPROGRESS):
                    raise
            else:
                if data:
                    self._blocking_counter += 1
                    raise tornado.gen.Return(data)
                else:
                    self.close()
                    raise StreamClosedError('Stream unexpectedly closed.')
        else:
            self._blocking_counter = 0

        while True:
            events = yield self._wait_event(READ | ERROR,
                                            timeout=self._rw_timeout)

            if events & ERROR:
                self._raise_socket_error()

            try:
                data = self._socket.recv(length)
            except ssl.SSLError as error:
                if error.errno == ssl.SSL_ERROR_WANT_READ:
                    continue
                raise
            except IOError as error:
                # A readiness notification can be spurious. Treat "would
                # block" the same way the fast path does and wait again
                # instead of failing the read.
                if error.errno in (errno.EWOULDBLOCK, errno.EINPROGRESS):
                    continue
                raise
            else:
                break

        if not data:
            self.close()
            raise StreamClosedError('Stream unexpectedly closed.')

        raise tornado.gen.Return(data)
Ejemplo n.º 3
0
    def write(self, data):
        """Queue ``data`` for writing and return a future for the write.

        Raises StreamClosedError (carrying ``self.error``) if the stream is
        already closed. For empty ``data``, returns the current write future
        if one exists, otherwise a future that is already resolved with
        ``None``. Non-empty data is added to the write buffer and a new
        write future is created. ``_handle_write`` is then called, and if
        bytes remain buffered, WRITE interest is registered with the
        io_loop.
        """
        assert isinstance(data, bytes)
        if self._closed:
            raise StreamClosedError(real_error=self.error)

        if not data:
            pending = self._write_future
            if pending:
                return pending
            done = Future()
            done.set_result(None)
            return done

        if not self._write_buffer_size:
            self._write_buffer = bytearray(data)
        else:
            self._write_buffer += data
        self._write_buffer_size += len(data)
        self._write_future = Future()
        result = self._write_future

        self._handle_write()
        if self._write_buffer_size and not self._state & self.io_loop.WRITE:
            self._state |= self.io_loop.WRITE
            self.io_loop.update_handler(self.fileno(), self._state)

        return result
Ejemplo n.º 4
0
    async def connect(self, address, deserialize=True, **connection_args):
        """Open a TCP stream to ``address`` and wrap it in ``self.comm_class``.

        The encryption settings are checked first with
        ``_check_encryption``.

        Error handling:
        - ``StreamClosedError`` from connecting is passed to
          ``convert_stream_closed_error``.
        - Certificate verification failures are re-raised as
          ``FatalCommClosedError`` with a hint about TLS settings.
        - Any other ``SSLError`` becomes a bare ``FatalCommClosedError``.
        """
        self._check_encryption(address, connection_args)
        ip, port = parse_host_port(address)
        connect_kwargs = self._get_connect_args(**connection_args)

        try:
            stream = await self.client.connect(
                ip, port, max_buffer_size=MAX_BUFFER_SIZE, **connect_kwargs
            )
            # With tornado 5.x and OpenSSL 1.1+, tornado can hand back a
            # closed stream that carries an error instead of raising
            # StreamClosedError. Surface that case the same way.
            if stream.closed() and stream.error:
                raise StreamClosedError(stream.error)
        except StreamClosedError as exc:
            # The socket connect() call failed.
            convert_stream_closed_error(self, exc)
        except SSLCertVerificationError as exc:
            raise FatalCommClosedError(
                "TLS certificate does not match. Check your security settings. "
                "More info at https://distributed.dask.org/en/latest/tls.html"
            ) from exc
        except SSLError as exc:
            raise FatalCommClosedError() from exc

        local_address = self.prefix + get_stream_address(stream)
        peer_address = self.prefix + address
        return self.comm_class(stream, local_address, peer_address, deserialize)
Ejemplo n.º 5
0
    async def write(self, msg, serializers=None, on_error="message"):
        """Serialize ``msg`` into frames and queue them on the tornado stream.

        ``to_frames`` produces the payload frames. A header is prepended:
        the total byte count packed with ``struct`` format ``"Q"``, followed
        by ``pack_frames_prelude(frames)``. If the total (header included)
        is under 2**17 bytes (128 KiB), all frames are joined into one
        buffer before queuing.

        Returns ``frames_nbytes``, the total size including the header.

        Raises ``CommClosedError`` if this comm has no stream. A
        ``StreamClosedError`` while queuing clears ``self.stream``, marks
        the comm closed and, unless ``shutting_down()`` is true, is passed
        to ``convert_stream_closed_error``. Any other exception aborts the
        comm and is re-raised.
        """
        stream = self.stream
        if stream is None:
            raise CommClosedError()

        frames = await to_frames(
            msg,
            allow_offload=self.allow_offload,
            serializers=serializers,
            on_error=on_error,
            context={
                "sender": self.local_info,
                "recipient": self.remote_info,
                **self.handshake_options,
            },
        )
        frames_nbytes = sum(map(nbytes, frames))

        # The header's length field covers the prelude plus every frame.
        header = pack_frames_prelude(frames)
        header = struct.pack("Q", nbytes(header) + frames_nbytes) + header

        frames = [header, *frames]
        frames_nbytes += nbytes(header)

        if frames_nbytes < 2 ** 17:  # 128kiB
            # small enough, send in one go
            frames = [b"".join(frames)]

        try:
            # Queue every frame directly on tornado's (private) write buffer
            # up front, then trigger a single flush below.
            # NOTE(review): relies on tornado internals `_write_buffer` and
            # `_total_write_index`; re-check when upgrading tornado.
            for each_frame in frames:
                each_frame_nbytes = nbytes(each_frame)
                if each_frame_nbytes:
                    # A None buffer is treated as a closed stream here.
                    if stream._write_buffer is None:
                        raise StreamClosedError()

                    if isinstance(each_frame, memoryview):
                        # Make sure that len(data) == data.nbytes
                        # See <https://github.com/tornadoweb/tornado/pull/2996>
                        each_frame = memoryview(each_frame).cast("B")

                    stream._write_buffer.append(each_frame)
                    stream._total_write_index += each_frame_nbytes

            # start writing frames: an empty write flushes what was queued above
            stream.write(b"")
        except StreamClosedError as e:
            self.stream = None
            self._closed = True
            if not shutting_down():
                convert_stream_closed_error(self, e)
        except Exception:
            # Some OSError or a another "low-level" exception. We do not really know
            # what was already written to the underlying socket, so it is not even safe
            # to retry here using the same stream. The only safe thing to do is to
            # abort. (See also GitHub #4133).
            self.abort()
            raise

        return frames_nbytes
Ejemplo n.º 6
0
    async def connect(self, address, deserialize=True, **connection_args):
        """Open a TCP stream to ``address`` and wrap it in ``self.comm_class``.

        The encryption settings are checked first with
        ``_check_encryption``. A ``StreamClosedError`` while connecting
        (including a stream that comes back already closed with an error)
        is passed to ``convert_stream_closed_error``.
        """
        self._check_encryption(address, connection_args)
        ip, port = parse_host_port(address)
        connect_kwargs = self._get_connect_args(**connection_args)

        try:
            stream = await self.client.connect(
                ip, port, max_buffer_size=MAX_BUFFER_SIZE, **connect_kwargs
            )
            # With tornado 5.x and OpenSSL 1.1+, tornado can return a closed
            # stream carrying an error instead of raising StreamClosedError.
            if stream.closed() and stream.error:
                raise StreamClosedError(stream.error)
        except StreamClosedError as exc:
            # The socket connect() call failed.
            convert_stream_closed_error(self, exc)

        local_address = self.prefix + get_stream_address(stream)
        peer_address = self.prefix + address
        return self.comm_class(stream, local_address, peer_address, deserialize)
Ejemplo n.º 7
0
 def read_until(end):
     nonlocal lineno
     yield gen.sleep(0.01)
     if lineno == len(stdout):
         raise StreamClosedError()
     lineno += 1
     return '{}\n'.format(stdout[lineno - 1])
Ejemplo n.º 8
0
    def test_handle_stream_closed(self):
        """Each StreamClosedError errno maps to the expected SOCKS reply.

        For every (errno, status) pair, ``create_dest_stream`` is made to
        raise ``StreamClosedError((errno,))``. The request must then fail
        with DestNotConnectedError. The Response captured from
        ``socks_conn.send`` must carry the expected status and echo the
        requested address type, address and port.
        """
        self.layer.socks_conn = Mock()
        self.layer.socks_conn.send = Mock(side_effect=self.collect_send_event)

        socks_request = Request(REQ_COMMAND["CONNECT"], ADDR_TYPE["IPV4"],
                                u"1.2.3.4", self.port)

        addr_not_supported = RESP_STATUS["ADDRESS_TYPE_NOT_SUPPORTED"]
        net_unreachable = RESP_STATUS["NETWORK_UNREACHABLE"]
        general_failure = RESP_STATUS["GENRAL_FAILURE"]

        expected_status_by_errno = [
            (errno.ENOEXEC, addr_not_supported),
            (errno.EBADF, addr_not_supported),
            (errno.ETIMEDOUT, net_unreachable),
            (errno.EADDRINUSE, general_failure),
            (55566, general_failure),  # arbitrary non-standard errno value
        ]

        for code, expected_status in expected_status_by_errno:
            failing_connect = self.create_raise_exception_function(
                StreamClosedError((code, )))
            self.layer.create_dest_stream = Mock(side_effect=failing_connect)
            pending = self.layer.handle_request_and_create_destination(
                socks_request)
            with self.assertRaises(DestNotConnectedError):
                yield pending

            self.assertIsNotNone(self.event)
            self.assertIsInstance(self.event, Response)
            self.assertEqual(self.event.status, expected_status)
            self.assertEqual(self.event.atyp, ADDR_TYPE["IPV4"])
            self.assertEqual(str(self.event.addr), "1.2.3.4")
            self.assertEqual(self.event.port, self.port)
Ejemplo n.º 9
0
    def send(self, msg):
        """Buffer ``msg`` for the other side and schedule a flush if needed.

        Returns immediately. Before the stream exists, or while earlier
        messages are still buffered, the message is only appended. When the
        buffer was empty, ``send_next`` is scheduled on ``self.loop``:
        - normally (it may wait) if the last transmission was less than
          ``self.interval`` ago or the previous send is still in flight;
        - otherwise with ``wait=False``.

        Raises StreamClosedError if the stream is closed. Any exception is
        logged with ``logger.exception`` and re-raised.
        """
        try:
            if self.stream is not None:
                if self.stream._closed:
                    raise StreamClosedError()
                if not self.buffer:
                    # Nothing queued yet: this message drives the next flush.
                    too_early = (
                        default_timer() < self.last_transmission + self.interval
                        or not self.last_send._done
                    )
                    self.buffer.append(msg)
                    if too_early:
                        self.loop.add_callback(self.send_next)
                    else:
                        self.loop.add_callback(self.send_next, wait=False)
                    return
            # Not started yet, or a flush is already pending: just enqueue.
            self.buffer.append(msg)
        except Exception as e:
            logger.exception(e)
            raise
Ejemplo n.º 10
0
    def write(self, data):
        '''Write all data to socket.

        While ``self._blocking_counter`` is below 100, each chunk is first
        sent directly (the fast path), and every successful fast-path send
        increments the counter. Once it reaches 100 the counter is reset and
        the next chunk waits for a WRITE/ERROR event first. The same wait is
        used whenever a send would block. This mirrors the counter handling
        in ``read``.

        Raises StreamClosedError (after calling ``self.close()``) when
        ``send`` reports zero bytes written. An ERROR event is reported
        through ``self._raise_socket_error()``.
        '''
        self.stop_monitor_for_close()

        total_bytes_sent = 0

        while total_bytes_sent < len(data):
            if self._blocking_counter < 100:
                try:
                    bytes_sent = self._socket.send(data[total_bytes_sent:])
                except ssl.SSLError as error:
                    if error.errno != ssl.SSL_ERROR_WANT_WRITE:
                        raise
                except IOError as error:
                    if error.errno not in (errno.EWOULDBLOCK,
                                           errno.EINPROGRESS):
                        raise
                else:
                    if not bytes_sent:
                        self.close()
                        raise StreamClosedError('Stream unexpectedly closed.')
                    # Count fast-path successes so the event wait below is
                    # forced periodically, as the docstring promises.
                    self._blocking_counter += 1
                    total_bytes_sent += bytes_sent
                    continue
            else:
                self._blocking_counter = 0

            events = yield self._wait_event(WRITE | ERROR,
                                            timeout=self._rw_timeout)

            if events & ERROR:
                self._raise_socket_error()

            try:
                bytes_sent = self._socket.send(data[total_bytes_sent:])
            except ssl.SSLError as error:
                if error.errno == ssl.SSL_ERROR_WANT_WRITE:
                    continue
                raise
            except IOError as error:
                # A readiness notification can be spurious. Retry instead of
                # failing the whole write.
                if error.errno in (errno.EWOULDBLOCK, errno.EINPROGRESS):
                    continue
                raise

            if not bytes_sent:
                self.close()
                raise StreamClosedError('Stream unexpectedly closed.')

            total_bytes_sent += bytes_sent
Ejemplo n.º 11
0
 def consume(self, amount):
     """Take up to ``amount`` from ``self.size`` once it is positive.

     Waits on ``self.cond`` while not closed and ``self.size <= 0``. The
     request is clamped to ``self.size`` and, when there is a parent,
     further to whatever ``self.parent.consume`` grants. The granted amount
     is subtracted from ``self.size`` and returned via ``gen.Return``.
     Raises StreamClosedError if the object is (or becomes) closed.
     """
     while not self.closed and self.size <= 0:
         yield self.cond.wait()
     if self.closed:
         raise StreamClosedError()
     granted = min(amount, self.size)
     if self.parent is not None:
         granted = yield self.parent.consume(granted)
     self.size -= granted
     raise gen.Return(granted)
Ejemplo n.º 12
0
 async def connect(self, address, deserialize=True, **connection_args):
     """Open a websocket to ``self.prefix + address`` and wrap it in a comm.

     ``websocket_connect`` is called with ``max_message_size=10_000_000_000``.
     A ``StreamClosedError`` (including a socket whose stream is closed with
     an error) is passed to ``convert_stream_closed_error``. An ``SSLError``
     is re-raised as ``FatalCommClosedError``.
     """
     connect_kwargs = self._get_connect_args(**connection_args)
     try:
         url = f"{self.prefix}{address}"
         sock = await websocket_connect(HTTPRequest(url, **connect_kwargs),
                                        max_message_size=10_000_000_000)
         stream = sock.stream
         if stream.closed() and stream.error:
             raise StreamClosedError(stream.error)
     except StreamClosedError as exc:
         convert_stream_closed_error(self, exc)
     except SSLError as exc:
         raise FatalCommClosedError() from exc
     return self.comm_class(sock, deserialize=deserialize)
Ejemplo n.º 13
0
    def send(self, msg):
        """Queue ``msg`` for the background sender and return immediately.

        Raises StreamClosedError if a stream is attached and already closed.
        Otherwise increments ``message_count``, appends to ``buffer`` and
        sets ``waker`` when no ``next_deadline`` is pending.
        """
        stream = self.stream
        if stream is not None and stream._closed:
            raise StreamClosedError()

        self.message_count += 1
        self.buffer.append(msg)
        # Only wake the sender when no deadline is already scheduled, to
        # avoid spurious wakeups.
        if self.next_deadline is None:
            self.waker.set()
Ejemplo n.º 14
0
    def read_bytes(self, num_bytes):
        """Return a future that resolves to exactly ``num_bytes`` bytes.

        If at least ``num_bytes`` are already buffered, the future is
        resolved immediately with the first ``num_bytes`` bytes. Any surplus
        stays in the read buffer for the next read. Otherwise the future is
        left pending in ``self._read_future``, with ``self._read_bytes``
        holding the requested size.

        Raises StreamClosedError if the stream is already closed.
        """
        assert self._read_future is None, "Already reading"
        if self._closed:
            raise StreamClosedError(IOError('Already Closed'))

        future = self._read_future = Future()
        self._read_bytes = num_bytes
        if self._read_buffer_size >= num_bytes:
            self._read_future = None
            buffered = self._read_buffer
            # Hand over only the requested bytes; returning the whole buffer
            # would break fixed-size reads (e.g. an 8-byte length prefix)
            # whenever more data has already arrived.
            data = buffered[:num_bytes]
            self._read_buffer = bytearray(buffered[num_bytes:])
            self._read_buffer_size -= num_bytes
            self._read_bytes = 0
            future.set_result(data)
        return future
Ejemplo n.º 15
0
 def close_callback():
     """Fail the pending TLS ``future`` if needed, then chain the old callback.

     Unlike most futures returned by IOStream, this one receives the
     underlying error (``ssl_stream.error``) directly rather than a
     StreamClosedError wrapping it. Once the connection is established the
     SSLError itself is more helpful, and a caller of ``start_tls`` expects
     SSL problems rather than network ones. A bare StreamClosedError is
     used only when ``ssl_stream`` recorded no error.
     """
     if not future.done():
         failure = ssl_stream.error or StreamClosedError()
         future.set_exception(failure)
     if orig_close_callback is not None:
         orig_close_callback()
Ejemplo n.º 16
0
    def write(self, data):
        """Append ``data`` to the write buffer and try to flush it.

        Raises StreamClosedError (carrying ``self.error``) if the stream is
        closed. Empty data is not buffered. While still connecting, nothing
        is flushed. Otherwise ``_handle_write`` runs, and if the buffer is
        still non-empty, WRITE interest is registered with the io_loop.
        """
        assert isinstance(data, bytes)
        if self._closed:
            raise StreamClosedError(real_error=self.error)

        if data:
            self._write_buffer.append(data)
            self._write_buffer_size += len(data)

        if self._connecting:
            return

        self._handle_write()
        if self._write_buffer and not self._state & self.io_loop.WRITE:
            self._state |= self.io_loop.WRITE
            self.io_loop.update_handler(self.fileno(), self._state)
Ejemplo n.º 17
0
    def read(self, num_bytes):
        """Return a future for a read of at least ``num_bytes`` bytes.

        If ``self._read_buffer_size`` already reaches ``num_bytes``, the
        future is resolved at once with the entire buffer, and the buffer
        is replaced by an empty ``bytearray``. Otherwise the future is left
        pending in ``self._read_future``. Raises StreamClosedError
        (carrying ``self.error``) if the stream is closed.
        """
        assert self._read_future is None, "Already reading"
        if self._closed:
            raise StreamClosedError(real_error=self.error)

        self._read_future = Future()
        future = self._read_future
        self._read_bytes = num_bytes
        self._read_partial = False
        if self._read_buffer_size < self._read_bytes:
            return future

        self._read_future = None
        data = self._read_buffer
        self._read_buffer = bytearray()
        self._read_buffer_size = 0
        self._read_bytes = 0
        future.set_result(data)
        return future
Ejemplo n.º 18
0
    def read(self, num_bytes):
        """Return a future for a read of at least ``num_bytes`` bytes.

        If ``self._read_buffer_size`` already reaches ``num_bytes``, the
        buffered chunks are joined into one ``bytes`` result, the chunk
        container is cleared and the future resolves immediately. Otherwise
        the future is left pending in ``self._read_future``. Raises
        StreamClosedError (carrying ``self.error``) if the stream is closed.
        """
        assert self._read_future is None, "Already reading"
        if self._closed:
            raise StreamClosedError(real_error=self.error)

        self._read_future = TracebackFuture()
        future = self._read_future
        self._read_bytes = num_bytes
        self._read_partial = False
        if self._read_buffer_size < self._read_bytes:
            return future

        self._read_future = None
        chunks = self._read_buffer
        data = b"".join(chunks)
        chunks.clear()
        self._read_buffer_size = 0
        self._read_bytes = 0
        future.set_result(data)
        return future
Ejemplo n.º 19
0
 async def connect(self, address, deserialize=True, **connection_args):
     """Open a websocket to ``self.prefix + address`` and wrap it in a comm.

     ``websocket_connect`` is called with ``max_message_size=10_000_000_000``.
     A ``StreamClosedError`` (including a socket whose stream is closed with
     an error) is passed to ``convert_stream_closed_error``. An ``SSLError``
     is re-raised as ``FatalCommClosedError`` with a hint about the
     ``ssl_context`` argument.
     """
     connect_kwargs = self._get_connect_args(**connection_args)
     try:
         url = f"{self.prefix}{address}"
         sock = await websocket_connect(HTTPRequest(url, **connect_kwargs),
                                        max_message_size=10_000_000_000)
         stream = sock.stream
         if stream.closed() and stream.error:
             raise StreamClosedError(stream.error)
     except StreamClosedError as exc:
         convert_stream_closed_error(self, exc)
     except SSLError as exc:
         raise FatalCommClosedError(
             "TLS expects a `ssl_context` argument of type "
             "ssl.SSLContext (perhaps check your TLS configuration?)"
         ) from exc
     return self.comm_class(sock, deserialize=deserialize)
Ejemplo n.º 20
0
def connect(address, deserialize=True, **connection_args):
    """Connect to ``address`` over TCP and return a ``TCP`` comm.

    Tornado-style coroutine: the result is delivered with ``gen.Return``.
    ``connection_args`` is accepted for interface compatibility but is not
    used; no extra arguments are passed to ``TCPClient.connect``. A
    ``StreamClosedError`` while connecting is passed to
    ``convert_stream_closed_error``.

    NOTE(review): ``prefix`` is a module-level name defined elsewhere,
    presumably the address scheme prefix -- confirm.
    """
    ip, port = parse_host_port(address)
    # Dask's ``_get_connect_args`` appears to return ``{}`` for plain TCP,
    # so no connection arguments are forwarded.
    kwargs = {}
    client = TCPClient()
    try:
        stream = yield client.connect(ip, port, max_buffer_size=MAX_BUFFER_SIZE, **kwargs)

        # Under certain circumstances tornado will have a closed connection
        # with an error and not raise a StreamClosedError.
        #
        # This occurs with tornado 5.x and openssl 1.1+
        if stream.closed() and stream.error:
            raise StreamClosedError(stream.error)
    except StreamClosedError as e:
        # The socket connect() call failed
        convert_stream_closed_error("Lambda", e)
    finally:
        # The client is only needed for this connection. Closing it releases
        # what the client owns (its resolver); the connected stream is
        # unaffected.
        client.close()

    local_address = prefix + get_stream_address(stream)
    raise gen.Return(TCP(stream, local_address, prefix + address, deserialize))
Ejemplo n.º 21
0
 def handle_frame(self, frame):
     """Handle a frame received on stream 0 (the connection itself).

     SETTINGS, WINDOW_UPDATE and PING frames go to their ``_handle_*``
     methods. GOAWAY closes ``self.stream`` and raises StreamClosedError.
     Frame types that are invalid on stream 0 raise
     ``ConnectionError(PROTOCOL_ERROR, ...)``. Any other type is ignored.
     """
     kind = frame.type
     frame_types = constants.FrameType

     if kind == frame_types.SETTINGS:
         self._handle_settings_frame(frame)
         return
     if kind == frame_types.WINDOW_UPDATE:
         self._handle_window_update_frame(frame)
         return
     if kind == frame_types.PING:
         self._handle_ping_frame(frame)
         return
     if kind == frame_types.GOAWAY:
         self.stream.close()
         # TODO: shut down all open streams.
         raise StreamClosedError()

     invalid_on_stream_zero = (frame_types.DATA,
                               frame_types.HEADERS,
                               frame_types.PRIORITY,
                               frame_types.RST_STREAM,
                               frame_types.PUSH_PROMISE,
                               frame_types.CONTINUATION)
     if kind in invalid_on_stream_zero:
         raise ConnectionError(
             constants.ErrorCode.PROTOCOL_ERROR,
             "invalid frame type %s for stream 0" % frame.type)
Ejemplo n.º 22
0
    def read(self, deserializers=None):
        """Read one framed message from the stream and deserialize it.

        Tornado-style coroutine: the message is delivered with
        ``gen.Return``. The wire layout read here is:
        - a frame count packed with ``struct`` format ``"Q"`` (8 bytes);
        - one ``"Q"`` length per frame;
        - the frame payloads.

        Raises:
            CommClosedError: if there is no stream, or if
                ``serialization.from_frames`` raises EOFError (garbled or
                truncated frames). In the EOFError case the comm is aborted
                first.
            StreamClosedError: if the stream closes mid-read. ``self.stream``
                is cleared first and the original error is re-raised with
                its ``real_error`` and traceback intact.
        """
        stream = self.stream
        if stream is None:
            raise CommClosedError

        try:
            n_frames = yield stream.read_bytes(8)
            n_frames = struct.unpack("Q", n_frames)[0]
            lengths = yield stream.read_bytes(8 * n_frames)
            lengths = struct.unpack("Q" * n_frames, lengths)

            frames = []
            for length in lengths:
                if length:
                    if PY3 and self._iostream_has_read_into:
                        # Read straight into a preallocated buffer of the
                        # exact frame size.
                        frame = bytearray(length)
                        n = yield stream.read_into(frame)
                        assert n == length, (n, length)
                    else:
                        frame = yield stream.read_bytes(length)
                else:
                    frame = b""
                frames.append(frame)
        except StreamClosedError:
            # Drop the dead stream so later calls fail fast with
            # CommClosedError, then propagate the original error rather than
            # a new one that would lose the underlying cause.
            self.stream = None
            raise
        else:
            try:
                msg = yield serialization.from_frames(
                    frames,
                    deserialize=self.deserialize,
                    deserializers=deserializers)
            except EOFError:
                # Frames possibly garbled or truncated by communication error
                self.abort()
                raise CommClosedError("aborted stream on truncated data")
        raise gen.Return(msg)
    def _invoke(self, method_name, *args, **kwargs):
        """Pack and send a ``method_name`` call over ``self.pipe``.

        Steps:
        1. An optional ``trace`` kwarg is popped and merged into ``kwargs``.
        2. A changed ``trace_id`` switches ``self.log`` to a new trace
           adapter.
        3. The matching method is looked up in ``self.tx_tree``; the
           session id, method id, args and headers are written as one
           msgpack message.
        4. The transmit tree advances: an empty subtree marks the channel
           done, and a ``None`` subtree keeps the current tree.

        Raises ChokeEvent if the channel is done, StreamClosedError if there
        is no pipe, and AttributeError if ``method_name`` is not in the
        current tree. Success is signalled with ``Return(None)``.
        """
        trace = kwargs.pop("trace", None)
        if trace:
            update_dict_with_trace(kwargs, trace)
        trace_id = kwargs.get('trace_id', self.trace_id)
        if trace_id != self.trace_id:
            self.log = get_trace_adapter(log, trace_id)
            self.trace_id = trace_id

        self.log.debug("`%s` Tx method `%s` call: %.300s %.300s",
                       self.service_name, method_name, args, kwargs)

        if self._done:
            raise ChokeEvent()
        if self.pipe is None:
            raise StreamClosedError()

        for method_id, (method, next_tree) in six.iteritems(self.tx_tree):
            if method != method_name:
                continue

            self.log.debug("method `%s` has been found in API map",
                           method_name)
            headers = manage_headers(kwargs, self._header_table)
            payload = msgpack_packb(
                [self.session_id, method_id, args, headers])
            self.log.info(
                'send message to `%s`: channel id: %s, type: %s, length: %s bytes',
                self.service_name, self.session_id, method_name,
                len(payload))
            self.pipe.write(payload)

            if next_tree == {}:  # last transition
                self.done()
            elif next_tree is not None:  # not a recursive transition
                self.tx_tree = next_tree
            raise Return(None)

        raise AttributeError(method_name)
Ejemplo n.º 24
0
    async def connect(self, address, deserialize=True, **connection_args):
        """Connect through a gateway ``address`` and return a ``TLS`` comm.

        The address is split into ip, port and SNI hostname. A plain TCP
        connection is opened and upgraded with ``start_tls`` (client side),
        using the caller's ``ssl_context`` and the SNI hostname.

        Raises TypeError if ``connection_args["ssl_context"]`` is missing or
        is not an ``ssl.SSLContext``. A ``StreamClosedError`` while
        connecting is passed to ``convert_stream_closed_error``.
        """
        ip, port, sni = parse_gateway_address(address)
        ctx = connection_args.get("ssl_context")
        if not isinstance(ctx, ssl.SSLContext):
            # Wrap ``ctx`` in a 1-tuple so that a tuple value is rendered in
            # the message instead of being taken as the format arguments.
            raise TypeError("Gateway expects a `ssl_context` argument of type "
                            "ssl.SSLContext, instead got %s" % (ctx,))

        try:
            plain_stream = await self.client.connect(
                ip, port, max_buffer_size=MAX_BUFFER_SIZE)
            stream = await plain_stream.start_tls(False,
                                                  ssl_options=ctx,
                                                  server_hostname=sni)
            # Some tornado/OpenSSL combinations return a closed stream that
            # carries an error instead of raising.
            if stream.closed() and stream.error:
                raise StreamClosedError(stream.error)

        except StreamClosedError as e:
            # The socket connect() call failed
            convert_stream_closed_error(self, e)

        local_address = "tls://" + get_stream_address(stream)
        peer_address = "gateway://" + address
        return TLS(stream, local_address, peer_address, deserialize)
Ejemplo n.º 25
0
    def call(self, clazz, method, request, response):
        """Send one protobuf RPC request and deliver its result.

        Builds a ``RequestHeaderProto`` for ``clazz.method``, registers the
        call via ``_next_call(response)`` and writes the RPC header, request
        header and request. It then waits on the future stored for this call
        id and returns that future's result via ``gen.Return``. Raises
        StreamClosedError if there is no stream.
        """
        if self._stream is None:
            raise StreamClosedError()

        header = ProtobufRpcEngine_pb2.RequestHeaderProto()
        header.methodName = method
        header.declaringClassProtocolName = clazz
        header.clientProtocolVersion = 1

        # Allocate the call id before sending so the reply can be matched.
        call_id = self._next_call(response)

        yield self._write([self._rpc_request_header, header, request])

        pending = self._callbacks[call_id][1]
        yield pending
        raise gen.Return(pending.result())
Ejemplo n.º 26
0
    def write(self, data):
        """Pass ``data`` straight to the underlying transport.

        Raises StreamClosedError (wrapping ``IOError('Already Closed')``)
        if the stream has been closed.
        """
        if not self._closed:
            self._transport.write(data)
            return
        raise StreamClosedError(IOError('Already Closed'))
Ejemplo n.º 27
0
 def on_connection_close(self):
     """React to the connection closing.

     Steps, in order:
     1. Fail the connect future if still pending.
     2. Deliver a ``None`` message.
     3. Close the TCP client.
     4. Defer to the base class.
     """
     connect_future = self.connect_future
     if not connect_future.done():
         connect_future.set_exception(StreamClosedError())
     self.on_message(None)
     self.tcp_client.close()
     super(WebSocketClientConnection, self).on_connection_close()
Ejemplo n.º 28
0
 def on_connection_close(self):
     """Always raise: the connection closed before the request finished."""
     message = "Connection is closed without finishing the request"
     raise StreamClosedError(message)
Ejemplo n.º 29
0
 def send(self, msg):
     """Enqueue ``msg`` without blocking; raise if the stream is broken."""
     if not self._broken:
         self.send_q.put_nowait(msg)
         return
     raise StreamClosedError('Batch Stream is Closed')
Ejemplo n.º 30
0
 def recv(self):
     """Wait for the next item from ``recv_q`` and return it.

     Tornado-style coroutine: the item is delivered with ``gen.Return``.
     Raises StreamClosedError if the item is the ``'close'`` sentinel.
     """
     item = yield self.recv_q.get()
     if item == 'close':
         raise StreamClosedError('Batched Stream is Closed')
     raise gen.Return(item)