Example #1
    def start_tcp_proxy(self, proto, data):
        log("start_tcp_proxy(%s, %s)", proto, data[:10])
        #any buffers read after we steal the connection will be placed in this temporary queue:
        temp_read_buffer = Queue()
        client_connection = proto.steal_connection(temp_read_buffer.put)
        try:
            self._potential_protocols.remove(proto)
        except:
            pass  #might already have been removed by now
        #connect to web server:
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.settimeout(10)
        host, port = self._tcp_proxy.split(":", 1)
        try:
            web_server_connection = _socket_connect(sock, (host, int(port)),
                                                    "web-proxy-for-%s" % proto,
                                                    "tcp")
        except:
            log.warn("failed to connect to proxy: %s:%s", host, port)
            proto.gibberish("invalid packet header", data)
            return
        log("proxy connected to tcp server at %s:%s : %s", host, port,
            web_server_connection)
        sock.settimeout(self._socket_timeout)

        ioe = proto.wait_for_io_threads_exit(0.5 + self._socket_timeout)
        if not ioe:
            log.warn("proxy failed to stop all existing network threads!")
            self.disconnect_protocol(proto, "internal threading error")
            return
        #now that we own it, we can start it again:
        client_connection.set_active(True)
        #and we can use blocking sockets:
        self.set_socket_timeout(client_connection, None)
        #prevent deadlocks on exit:
        sock.settimeout(1)

        log("pushing initial buffer to its new destination: %s",
            repr_ellipsized(data))
        web_server_connection.write(data)
        while not temp_read_buffer.empty():
            buf = temp_read_buffer.get()
            if buf:
                log("pushing read buffer to its new destination: %s",
                    repr_ellipsized(buf))
                web_server_connection.write(buf)
        p = XpraProxy(client_connection, web_server_connection)
        self._tcp_proxy_clients.append(p)

        def run_proxy():
            p.run()
            log("run_proxy() %s ended", p)
            if p in self._tcp_proxy_clients:
                self._tcp_proxy_clients.remove(p)

        t = make_daemon_thread(run_proxy, "web-proxy-for-%s" % proto)
        t.start()
        log.info("client %s forwarded to proxy server %s:%s",
                 client_connection, host, port)
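
The temp_read_buffer hand-off above is what prevents byte loss while ownership
of the socket changes: anything the old read thread picks up between
steal_connection() and the I/O threads exiting goes into the queue, and is
replayed to the new destination before the relay starts. A minimal standalone
sketch of that pattern (plain Python for illustration, not xpra API; dest_write
stands in for web_server_connection.write):

from queue import Queue

def replay_handoff(dest_write, initial_data, temp_read_buffer):
    #the packet that triggered the proxying decision goes first:
    dest_write(initial_data)
    #then any bytes that raced in while the connection was being stolen:
    while not temp_read_buffer.empty():
        buf = temp_read_buffer.get()
        if buf:
            dest_write(buf)

sent = []
q = Queue()
q.put(b"GET / HTTP/1.1\r\n")             #arrived during the steal
replay_handoff(sent.append, b"P...", q)  #sent.append stands in for a socket write
assert sent == [b"P...", b"GET / HTTP/1.1\r\n"]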
Example #2
    def start_tcp_proxy(self, proto, data):
        proxylog("start_tcp_proxy(%s, '%s')", proto, repr_ellipsized(data))
        try:
            self._potential_protocols.remove(proto)
        except:
            pass  # might already have been removed by now
        proxylog("start_tcp_proxy: protocol state before stealing: %s", proto.get_info(alias_info=False))
        # any buffers read after we steal the connection will be placed in this temporary queue:
        temp_read_buffer = Queue()
        client_connection = proto.steal_connection(temp_read_buffer.put)
        # connect to web server:
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.settimeout(10)
        host, port = self._tcp_proxy.split(":", 1)
        try:
            web_server_connection = _socket_connect(sock, (host, int(port)), "web-proxy-for-%s" % proto, "tcp")
        except:
            proxylog.warn("failed to connect to proxy: %s:%s", host, port)
            proto.gibberish("invalid packet header", data)
            return
        proxylog("proxy connected to tcp server at %s:%s : %s", host, port, web_server_connection)
        sock.settimeout(self._socket_timeout)

        ioe = proto.wait_for_io_threads_exit(0.5 + self._socket_timeout)
        if not ioe:
            proxylog.warn("proxy failed to stop all existing network threads!")
            self.disconnect_protocol(proto, "internal threading error")
            return
        # now that we own it, we can start it again:
        client_connection.set_active(True)
        # and we can use blocking sockets:
        self.set_socket_timeout(client_connection, None)
        # prevent deadlocks on exit:
        sock.settimeout(1)

        proxylog("pushing initial buffer to its new destination: %s", repr_ellipsized(data))
        web_server_connection.write(data)
        while not temp_read_buffer.empty():
            buf = temp_read_buffer.get()
            if buf:
                proxylog("pushing read buffer to its new destination: %s", repr_ellipsized(buf))
                web_server_connection.write(buf)
        p = XpraProxy(client_connection.target, client_connection, web_server_connection)
        self._tcp_proxy_clients.append(p)
        proxylog.info(
            "client connection from %s forwarded to proxy server on %s:%s", client_connection.target, host, port
        )
        p.run()
        proxylog("run_proxy() %s ended", p)
        if p in self._tcp_proxy_clients:
            self._tcp_proxy_clients.remove(p)
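
Both versions hand the two connections to an XpraProxy and run it. Its
implementation is not shown here, but as a mental model it is a blocking
bidirectional relay, which is why Example #1 wraps p.run() in a daemon thread
while this version runs it on the calling thread. A hedged sketch of such a
relay over plain sockets (an assumption for illustration, not XpraProxy's
actual code):

import socket
import threading

def _pump(src, dst):
    #copy one direction until EOF, then close the peer's write side:
    while True:
        data = src.recv(65536)
        if not data:
            break
        dst.sendall(data)
    try:
        dst.shutdown(socket.SHUT_WR)
    except OSError:
        pass

def relay(a, b):
    #blocking bidirectional relay: returns when both directions hit EOF
    t = threading.Thread(target=_pump, args=(b, a), daemon=True)
    t.start()
    _pump(a, b)
    t.join()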
Example #3
class Protocol(object):
    CONNECTION_LOST = "connection-lost"
    GIBBERISH = "gibberish"
    INVALID = "invalid"

    def __init__(self, scheduler, conn, process_packet_cb, get_packet_cb=None):
        """
            You must call this constructor and source_has_more() from the main thread.
        """
        assert scheduler is not None
        assert conn is not None
        self.timeout_add = scheduler.timeout_add
        self.idle_add = scheduler.idle_add
        self._conn = conn
        if FAKE_JITTER > 0:
            from xpra.net.fake_jitter import FakeJitter
            fj = FakeJitter(self.timeout_add, process_packet_cb)
            self._process_packet_cb = fj.process_packet_cb
        else:
            self._process_packet_cb = process_packet_cb
        self._write_queue = Queue(1)
        self._read_queue = Queue(20)
        self._read_queue_put = self.read_queue_put
        # Invariant: if _get_packet_cb is None, then _source_has_more is not set
        self._get_packet_cb = get_packet_cb
        #counters:
        self.input_stats = {}
        self.input_packetcount = 0
        self.input_raw_packetcount = 0
        self.output_stats = {}
        self.output_packetcount = 0
        self.output_raw_packetcount = 0
        #initial value which may get increased by client/server after handshake:
        self.max_packet_size = 256 * 1024
        self.abs_max_packet_size = 256 * 1024 * 1024
        self.large_packets = ["hello", "window-metadata", "sound-data"]
        self.send_aliases = {}
        self.receive_aliases = {}
        self._log_stats = None  #None here means auto-detect
        self._closed = False
        self.encoder = "none"
        self._encoder = self.noencode
        self.compressor = "none"
        self._compress = compression.nocompress
        self.compression_level = 0
        self.cipher_in = None
        self.cipher_in_name = None
        self.cipher_in_block_size = 0
        self.cipher_in_padding = INITIAL_PADDING
        self.cipher_out = None
        self.cipher_out_name = None
        self.cipher_out_block_size = 0
        self.cipher_out_padding = INITIAL_PADDING
        self._write_lock = Lock()
        from xpra.make_thread import make_thread
        self._write_thread = make_thread(self._write_thread_loop,
                                         "write",
                                         daemon=True)
        self._read_thread = make_thread(self._read_thread_loop,
                                        "read",
                                        daemon=True)
        self._read_parser_thread = None  #started when needed
        self._write_format_thread = None  #started when needed
        self._source_has_more = Event()

    STATE_FIELDS = ("max_packet_size", "large_packets", "send_aliases",
                    "receive_aliases", "cipher_in", "cipher_in_name",
                    "cipher_in_block_size", "cipher_in_padding", "cipher_out",
                    "cipher_out_name", "cipher_out_block_size",
                    "cipher_out_padding", "compression_level", "encoder",
                    "compressor")

    def save_state(self):
        state = {}
        for x in Protocol.STATE_FIELDS:
            state[x] = getattr(self, x)
        return state

    def restore_state(self, state):
        assert state is not None
        for x in Protocol.STATE_FIELDS:
            assert x in state, "field %s is missing" % x
            setattr(self, x, state[x])
        #special handling for compressor / encoder which are named objects:
        self.enable_compressor(self.compressor)
        self.enable_encoder(self.encoder)
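
    #usage sketch (an assumption, not code from this file): a replacement
    #protocol instance can inherit the negotiated settings of the one it
    #replaces:
    #    state = old_proto.save_state()
    #    new_proto = Protocol(scheduler, conn, process_packet_cb)
    #    new_proto.restore_state(state)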

    def wait_for_io_threads_exit(self, timeout=None):
        for t in (self._read_thread, self._write_thread):
            if t and t.isAlive():
                t.join(timeout)
        exited = True
        cinfo = self._conn or "cleared connection"
        for t in (self._read_thread, self._write_thread):
            if t and t.isAlive():
                log.warn(
                    "Warning: %s thread of %s is still alive (timeout=%s)",
                    t.name, cinfo, timeout)
                exited = False
        return exited

    def set_packet_source(self, get_packet_cb):
        self._get_packet_cb = get_packet_cb

    def set_cipher_in(self, ciphername, iv, password, key_salt, iterations,
                      padding):
        if self.cipher_in_name != ciphername:
            cryptolog.info("receiving data using %s encryption", ciphername)
            self.cipher_in_name = ciphername
        cryptolog("set_cipher_in%s",
                  (ciphername, iv, password, key_salt, iterations))
        self.cipher_in, self.cipher_in_block_size = get_decryptor(
            ciphername, iv, password, key_salt, iterations)
        self.cipher_in_padding = padding

    def set_cipher_out(self, ciphername, iv, password, key_salt, iterations,
                       padding):
        if self.cipher_out_name != ciphername:
            cryptolog.info("sending data using %s encryption", ciphername)
            self.cipher_out_name = ciphername
        cryptolog("set_cipher_out%s",
                  (ciphername, iv, password, key_salt, iterations, padding))
        self.cipher_out, self.cipher_out_block_size = get_encryptor(
            ciphername, iv, password, key_salt, iterations)
        self.cipher_out_padding = padding

    def __repr__(self):
        return "Protocol(%s)" % self._conn

    def get_threads(self):
        return [
            x for x in [
                self._write_thread, self._read_thread,
                self._read_parser_thread, self._write_format_thread
            ] if x is not None
        ]

    def get_info(self, alias_info=True):
        info = {
            "large_packets": self.large_packets,
            "compression_level": self.compression_level,
            "max_packet_size": self.max_packet_size,
            "aliases": USE_ALIASES,
            "input": {
                "buffer-size": READ_BUFFER_SIZE,
                "packetcount": self.input_packetcount,
                "raw_packetcount": self.input_raw_packetcount,
                "count": self.input_stats,
                "cipher": {
                    "": self.cipher_in_name or "",
                    "padding": self.cipher_in_padding,
                },
            },
            "output": {
                "packet-join-size": PACKET_JOIN_SIZE,
                "large-packet-size": LARGE_PACKET_SIZE,
                "inline-size": INLINE_SIZE,
                "min-compress-size": MIN_COMPRESS_SIZE,
                "packetcount": self.output_packetcount,
                "raw_packetcount": self.output_raw_packetcount,
                "count": self.output_stats,
                "cipher": {
                    "": self.cipher_out_name or "",
                    "padding": self.cipher_out_padding
                },
            },
        }
        c = self._compress
        if c:
            info["compressor"] = compression.get_compressor_name(
                self._compress)
        e = self._encoder
        if e:
            if self._encoder == self.noencode:
                info["encoder"] = "noencode"
            else:
                info["encoder"] = packet_encoding.get_encoder_name(
                    self._encoder)
        if alias_info:
            info["send_alias"] = self.send_aliases
            info["receive_alias"] = self.receive_aliases
        c = self._conn
        if c:
            try:
                info.update(self._conn.get_info())
            except:
                log.error("error collecting connection information on %s",
                          self._conn,
                          exc_info=True)
        info["has_more"] = self._source_has_more.is_set()
        for t in (self._write_thread, self._read_thread,
                  self._read_parser_thread, self._write_format_thread):
            if t:
                info.setdefault("thread", {})[t.name] = t.is_alive()
        return info

    def start(self):
        def do_start():
            if not self._closed:
                self._write_thread.start()
                self._read_thread.start()

        self.idle_add(do_start)

    def send_now(self, packet):
        if self._closed:
            log("send_now(%s ...) connection is closed already, not sending",
                packet[0])
            return
        log("send_now(%s ...)", packet[0])
        assert self._get_packet_cb is None, "cannot use send_now when a packet source exists! (set to %s)" % self._get_packet_cb

        def packet_cb():
            self._get_packet_cb = None
            return (packet, )

        self._get_packet_cb = packet_cb
        self.source_has_more()

    def source_has_more(self):
        self._source_has_more.set()
        #start the format thread:
        if not self._write_format_thread and not self._closed:
            from xpra.make_thread import make_thread
            self._write_format_thread = make_thread(
                self._write_format_thread_loop, "format", daemon=True)
            self._write_format_thread.start()
        INJECT_FAULT(self)

    def _write_format_thread_loop(self):
        log("write_format_thread_loop starting")
        try:
            while not self._closed:
                self._source_has_more.wait()
                gpc = self._get_packet_cb
                if self._closed or not gpc:
                    return
                self._source_has_more.clear()
                self._add_packet_to_queue(*gpc())
        except Exception as e:
            if self._closed:
                return
            self._internal_error("error in network packet write/format",
                                 e,
                                 exc_info=True)

    def _add_packet_to_queue(self,
                             packet,
                             start_send_cb=None,
                             end_send_cb=None,
                             has_more=False):
        if has_more:
            self._source_has_more.set()
        if packet is None:
            return
        log("add_packet_to_queue(%s ...)", packet[0])
        chunks = self.encode(packet)
        with self._write_lock:
            if self._closed:
                return
            self._add_chunks_to_queue(chunks, start_send_cb, end_send_cb)

    def _add_chunks_to_queue(self,
                             chunks,
                             start_send_cb=None,
                             end_send_cb=None):
        """ the write_lock must be held when calling this function """
        counter = 0
        items = []
        for proto_flags, index, level, data in chunks:
            scb, ecb = None, None
            #fire the start_send_callback just before the first packet is processed:
            if counter == 0:
                scb = start_send_cb
            #fire the end_send callback when the last packet (index==0) makes it out:
            if index == 0:
                ecb = end_send_cb
            payload_size = len(data)
            actual_size = payload_size
            if self.cipher_out:
                proto_flags |= FLAGS_CIPHER
                #note: since we are padding, len(padded) != payload_size
                padding_size = self.cipher_out_block_size - (
                    payload_size % self.cipher_out_block_size)
                if padding_size == 0:
                    padded = data
                else:
                    # pad byte value is number of padding bytes added
                    padded = data + pad(self.cipher_out_padding, padding_size)
                    actual_size += padding_size
                assert len(
                    padded
                ) == actual_size, "expected padded size to be %i, but got %i" % (
                    len(padded), actual_size)
                data = self.cipher_out.encrypt(padded)
                assert len(
                    data
                ) == actual_size, "expected encrypted size to be %i, but got %i" % (
                    len(data), actual_size)
                cryptolog("sending %s bytes %s encrypted with %s padding",
                          payload_size, self.cipher_out_name, padding_size)
            if proto_flags & FLAGS_NOHEADER:
                assert not self.cipher_out
                #for plain/text packets (ie: gibberish response)
                log("sending %s bytes without header", payload_size)
                items.append((data, scb, ecb))
            elif actual_size < PACKET_JOIN_SIZE:
                if type(data) not in JOIN_TYPES:
                    data = bytes(data)
                header_and_data = pack_header(proto_flags, level, index,
                                              payload_size) + data
                items.append((header_and_data, scb, ecb))
            else:
                header = pack_header(proto_flags, level, index, payload_size)
                items.append((header, scb, None))
                items.append((strtobytes(data), None, ecb))
            counter += 1
        self._write_queue.put(items)
        self.output_packetcount += 1

    def raw_write(self, contents, start_cb=None, end_cb=None):
        """ Warning: this bypasses the compression and packet encoder! """
        self._write_queue.put(((contents, start_cb, end_cb), ))

    def verify_packet(self, packet):
        """ look for None values which may have caused the packet to fail encoding """
        if type(packet) != list:
            return
        assert len(packet) > 0, "invalid packet: %s" % packet
        tree = ["'%s' packet" % packet[0]]
        self.do_verify_packet(tree, packet)

    def do_verify_packet(self, tree, packet):
        def err(msg):
            log.error("%s in %s", msg, "->".join(tree))

        def new_tree(append):
            nt = tree[:]
            nt.append(append)
            return nt

        if packet is None:
            return err("None value")
        if type(packet) == list:
            for i, x in enumerate(packet):
                self.do_verify_packet(new_tree("[%s]" % i), x)
        elif type(packet) == dict:
            for k, v in packet.items():
                self.do_verify_packet(new_tree("key for value='%s'" % str(v)),
                                      k)
                self.do_verify_packet(new_tree("value for key='%s'" % str(k)),
                                      v)

    def enable_default_encoder(self):
        opts = packet_encoding.get_enabled_encoders()
        assert len(opts) > 0, "no packet encoders available!"
        self.enable_encoder(opts[0])

    def enable_encoder_from_caps(self, caps):
        opts = packet_encoding.get_enabled_encoders(
            order=packet_encoding.PERFORMANCE_ORDER)
        log("enable_encoder_from_caps(..) options=%s", opts)
        for e in opts:
            if caps.boolget(e, e == "bencode"):
                self.enable_encoder(e)
                return True
        log.error("no matching packet encoder found!")
        return False

    def enable_encoder(self, e):
        self._encoder = packet_encoding.get_encoder(e)
        self.encoder = e
        log("enable_encoder(%s): %s", e, self._encoder)

    def enable_default_compressor(self):
        opts = compression.get_enabled_compressors()
        if len(opts) > 0:
            self.enable_compressor(opts[0])
        else:
            self.enable_compressor("none")

    def enable_compressor_from_caps(self, caps):
        if self.compression_level == 0:
            self.enable_compressor("none")
            return
        opts = compression.get_enabled_compressors(
            order=compression.PERFORMANCE_ORDER)
        log("enable_compressor_from_caps(..) options=%s", opts)
        for c in opts:  #ie: [zlib, lz4, lzo]
            if caps.boolget(c):
                self.enable_compressor(c)
                return
        log.warn("compression disabled: no matching compressor found")
        self.enable_compressor("none")

    def enable_compressor(self, compressor):
        self._compress = compression.get_compressor(compressor)
        self.compressor = compressor
        log("enable_compressor(%s): %s", compressor, self._compress)

    def noencode(self, data):
        #just send data as a string for clients that don't understand xpra packet format:
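        #e.g. noencode(("hello", 123)) returns (b"hello: 123\n", FLAGS_NOHEADER)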
        if sys.version_info[0] >= 3:
            import codecs

            def b(x):
                if type(x) == bytes:
                    return x
                return codecs.latin_1_encode(x)[0]
        else:

            def b(x):  #@DuplicatedSignature
                return x

        return b(": ".join(str(x) for x in data) + "\n"), FLAGS_NOHEADER

    def encode(self, packet_in):
        """
        Given a packet (tuple or list of items), converts it for the wire.
        This method returns all the binary packets to send, as an array of:
        (index, compression_level and compression flags, binary_data)
        The index, if positive indicates the item to populate in the packet
        whose index is zero.
        ie: ["blah", [large binary data], "hello", 200]
        may get converted to:
        [
            (1, compression_level, [large binary data now zlib compressed]),
            (0,                 0, bencoded/rencoded(["blah", '', "hello", 200]))
        ]
        """
        packets = []
        packet = list(packet_in)
        level = self.compression_level
        size_check = LARGE_PACKET_SIZE
        min_comp_size = MIN_COMPRESS_SIZE
        for i in range(1, len(packet)):
            item = packet[i]
            if item is None:
                raise TypeError("invalid None value in %s packet at index %s" %
                                (packet[0], i))
            ti = type(item)
            if ti in (int, long, bool, dict, list, tuple):
                continue
            try:
                l = len(item)
            except TypeError as e:
                raise TypeError(
                    "invalid type %s in %s packet at index %s: %s" %
                    (ti, packet[0], i, e))
            if ti == LargeStructure:
                item = item.data
                packet[i] = item
                ti = type(item)
                continue
            elif ti == Compressible:
                #this is a marker used to tell us we should compress it now
                #(used by the client for clipboard data)
                item = item.compress()
                packet[i] = item
                ti = type(item)
                #(it may now be a "Compressed" item and be processed further)
            if ti in (Compressed, LevelCompressed):
                #already compressed data (usually pixels, cursors, etc)
                if not item.can_inline or l > INLINE_SIZE:
                    il = 0
                    if ti == LevelCompressed:
                        #unlike Compressed (usually pixels, decompressed in the paint thread),
                        #LevelCompressed is decompressed by the network layer
                        #so we must tell it how to do that and pass the level flag
                        il = item.level
                    packets.append((0, i, il, item.data))
                    packet[i] = ''
                else:
                    #data is small enough, inline it:
                    packet[i] = item.data
                    min_comp_size += l
                    size_check += l
            elif ti in (str, bytes) and level > 0 and l > LARGE_PACKET_SIZE:
                log.warn(
                    "found a large uncompressed item in packet '%s' at position %s: %s bytes",
                    packet[0], i, len(item))
                #add new binary packet with large item:
                cl, cdata = self._compress(item, level)
                packets.append((0, i, cl, cdata))
                #replace this item with an empty string placeholder:
                packet[i] = ''
            elif ti not in (str, bytes):
                log.warn("unexpected data type %s in %s packet: %s", ti,
                         packet[0], repr_ellipsized(item))
        #now the main packet (or what is left of it):
        packet_type = packet[0]
        self.output_stats[packet_type] = self.output_stats.get(packet_type,
                                                               0) + 1
        if USE_ALIASES and self.send_aliases and packet_type in self.send_aliases:
            #replace the packet type with the alias:
            packet[0] = self.send_aliases[packet_type]
        try:
            main_packet, proto_flags = self._encoder(packet)
        except Exception:
            if self._closed:
                return [], 0
            log.error("failed to encode packet: %s", packet, exc_info=True)
            #make the error a bit nicer to parse: undo aliases:
            packet[0] = packet_type
            self.verify_packet(packet)
            raise
        if len(main_packet
               ) > size_check and packet_in[0] not in self.large_packets:
            log.warn(
                "found large packet (%s bytes): %s, argument types:%s, sizes: %s, packet head=%s",
                len(main_packet), packet_in[0], [type(x) for x in packet[1:]],
                [len(str(x)) for x in packet[1:]], repr_ellipsized(packet))
        #compress, but don't bother for small packets:
        if level > 0 and len(main_packet) > min_comp_size:
            cl, cdata = self._compress(main_packet, level)
            packets.append((proto_flags, 0, cl, cdata))
        else:
            packets.append((proto_flags, 0, 0, main_packet))
        return packets

    def set_compression_level(self, level):
        #this may be used next time encode() is called
        assert 0 <= level <= 10, "invalid compression level: %s (must be between 0 and 10)" % level
        self.compression_level = level

    def _io_thread_loop(self, name, callback):
        try:
            log("io_thread_loop(%s, %s) loop starting", name, callback)
            while not self._closed and callback():
                pass
            log("io_thread_loop(%s, %s) loop ended, closed=%s", name, callback,
                self._closed)
        except ConnectionClosedException as e:
            log("%s closed", self._conn, exc_info=True)
            if not self._closed:
                #ConnectionClosedException means the warning has been logged already
                self._connection_lost("%s connection %s closed" %
                                      (name, self._conn))
        except (OSError, IOError, socket_error) as e:
            if not self._closed:
                self._internal_error("%s connection %s reset" %
                                     (name, self._conn),
                                     e,
                                     exc_info=e.args[0] not in ABORT)
        except Exception as e:
            #can happen during close(), in which case we just ignore:
            if not self._closed:
                log.error("Error: %s on %s failed: %s",
                          name,
                          self._conn,
                          type(e),
                          exc_info=True)
                self.close()

    def _write_thread_loop(self):
        self._io_thread_loop("write", self._write)

    def _write(self):
        items = self._write_queue.get()
        # Used to signal that we should exit:
        if items is None:
            log("write thread: empty marker, exiting")
            self.close()
            return False
        for buf, start_cb, end_cb in items:
            con = self._conn
            if not con:
                return False
            if start_cb:
                try:
                    start_cb(con.output_bytecount)
                except:
                    if not self._closed:
                        log.error("error on %s", start_cb, exc_info=True)
            while buf and not self._closed:
                written = con.write(buf)
                if written:
                    buf = buf[written:]
                    self.output_raw_packetcount += 1
            if end_cb:
                try:
                    end_cb(self._conn.output_bytecount)
                except:
                    if not self._closed:
                        log.error("error on %s", end_cb, exc_info=True)
        return True

    def _read_thread_loop(self):
        self._io_thread_loop("read", self._read)

    def _read(self):
        buf = self._conn.read(READ_BUFFER_SIZE)
        #log("read thread: got data of size %s: %s", len(buf), repr_ellipsized(buf))
        #add to the read queue (or whatever takes its place - see steal_connection)
        self._read_queue_put(buf)
        if not buf:
            log("read thread: eof")
            #give time to the parse thread to call close itself
            #so it has time to parse and process the last packet received
            self.timeout_add(1000, self.close)
            return False
        self.input_raw_packetcount += 1
        return True

    def _internal_error(self, message="", exc=None, exc_info=False):
        #log exception info with last log message
        if self._closed:
            return
        ei = exc_info
        if exc:
            ei = None  #log it separately below
        log.error("Error: %s", message, exc_info=ei)
        if exc:
            log.error(" %s", exc, exc_info=exc_info)
        self.idle_add(self._connection_lost, message)

    def _connection_lost(self, message="", exc_info=False):
        log("connection lost: %s", message, exc_info=exc_info)
        self.close()
        return False

    def invalid(self, msg, data):
        self.idle_add(self._process_packet_cb, self,
                      [Protocol.INVALID, msg, data])
        # Then hang up:
        self.timeout_add(1000, self._connection_lost, msg)

    def gibberish(self, msg, data):
        self.idle_add(self._process_packet_cb, self,
                      [Protocol.GIBBERISH, msg, data])
        # Then hang up:
        self.timeout_add(1000, self._connection_lost, msg)

    #delegates to invalid_header()
    #(so this can more easily be intercepted and overridden
    # see tcp-proxy)
    def _invalid_header(self, data):
        self.invalid_header(self, data)

    def invalid_header(self, proto, data):
        err = "invalid packet header: '%s'" % binascii.hexlify(data[:8])
        if len(data) > 1:
            err += " read buffer=%s" % repr_ellipsized(data)
        self.gibberish(err, data)

    def read_queue_put(self, data):
        #start the parse thread if needed:
        if not self._read_parser_thread and not self._closed:
            from xpra.make_thread import make_thread
            self._read_parser_thread = make_thread(
                self._read_parse_thread_loop, "parse", daemon=True)
            self._read_parser_thread.start()
        self._read_queue.put(data)

    def _read_parse_thread_loop(self):
        log("read_parse_thread_loop starting")
        try:
            self.do_read_parse_thread_loop()
        except Exception as e:
            if self._closed:
                return
            self._internal_error("error in network packet reading/parsing",
                                 e,
                                 exc_info=True)

    def do_read_parse_thread_loop(self):
        """
            Process the individual network packets placed in _read_queue.
            Concatenate the raw packet data, then try to parse it.
            Extract the individual packets from the potentially large buffer,
            saving the rest of the buffer for later, optionally decompressing the data,
            and re-constructing a single python object from what may span multiple
            raw packets (see packet_index).
            The 8-byte packet header gives us the packet index, packet size and compression information.
            The actual processing of the packet is done via the callback process_packet_cb,
            which is called from this parsing thread, so any calls that need to be made
            from the UI thread will need to use a callback (usually via 'idle_add').
        """
        read_buffer = None
        payload_size = -1
        padding_size = 0
        packet_index = 0
        compression_level = False
        packet = None
        raw_packets = {}
        while not self._closed:
            buf = self._read_queue.get()
            if not buf:
                log("parse thread: empty marker, exiting")
                self.idle_add(self.close)
                return
            if read_buffer:
                read_buffer = read_buffer + buf
            else:
                read_buffer = buf
            bl = len(read_buffer)
            while not self._closed:
                packet = None
                bl = len(read_buffer)
                if bl <= 0:
                    break
                if payload_size < 0:
                    if read_buffer[0] not in ("P", ord("P")):
                        self._invalid_header(read_buffer)
                        return
                    if bl < 8:
                        break  #packet still too small
                    #packet format: struct.pack('cBBBL', ...) - 8 bytes
                    _, protocol_flags, compression_level, packet_index, data_size = unpack_header(
                        read_buffer[:8])

                    #sanity check size (will often fail if not an xpra client):
                    if data_size > self.abs_max_packet_size:
                        self._invalid_header(read_buffer)
                        return

                    bl = len(read_buffer) - 8
                    if protocol_flags & FLAGS_CIPHER:
                        if self.cipher_in_block_size == 0 or not self.cipher_in_name:
                            cryptolog.warn(
                                "received cipher block but we don't have a cipher to decrypt it with, not an xpra client?"
                            )
                            self._invalid_header(read_buffer)
                            return
                        padding_size = self.cipher_in_block_size - (
                            data_size % self.cipher_in_block_size)
                        payload_size = data_size + padding_size
                    else:
                        #no cipher, no padding:
                        padding_size = 0
                        payload_size = data_size
                    assert payload_size > 0, "invalid payload size: %i" % payload_size
                    read_buffer = read_buffer[8:]

                    if payload_size > self.max_packet_size:
                        #this packet is seemingly too big, but check again from the main UI thread
                        #this gives 'set_max_packet_size' a chance to run from "hello"
                        def check_packet_size(size_to_check, packet_header):
                            if self._closed:
                                return False
                            log("check_packet_size(%s, 0x%s) limit is %s",
                                size_to_check, repr_ellipsized(packet_header),
                                self.max_packet_size)
                            if size_to_check > self.max_packet_size:
                                msg = "packet size requested is %s but maximum allowed is %s" % \
                                              (size_to_check, self.max_packet_size)
                                self.invalid(msg, packet_header)
                            return False

                        self.timeout_add(1000, check_packet_size, payload_size,
                                         read_buffer[:32])

                if bl < payload_size:
                    # incomplete packet, wait for the rest to arrive
                    break

                #chop this packet from the buffer:
                if len(read_buffer) == payload_size:
                    raw_string = read_buffer
                    read_buffer = ''
                else:
                    raw_string = read_buffer[:payload_size]
                    read_buffer = read_buffer[payload_size:]
                #decrypt if needed:
                data = raw_string
                if self.cipher_in and protocol_flags & FLAGS_CIPHER:
                    cryptolog("received %i %s encrypted bytes with %s padding",
                              payload_size, self.cipher_in_name, padding_size)
                    data = self.cipher_in.decrypt(raw_string)
                    if padding_size > 0:

                        def debug_str(s):
                            try:
                                return binascii.hexlify(bytearray(s))
                            except:
                                return csv(list(str(s)))

                        # pad byte value is number of padding bytes added
                        padtext = pad(self.cipher_in_padding, padding_size)
                        if data.endswith(padtext):
                            cryptolog("found %s %s padding",
                                      self.cipher_in_padding,
                                      self.cipher_in_name)
                        else:
                            actual_padding = data[-padding_size:]
                            cryptolog.warn(
                                "Warning: %s decryption failed: invalid padding",
                                self.cipher_in_name)
                            cryptolog(
                                " data does not end with %s padding bytes %s",
                                self.cipher_in_padding, debug_str(padtext))
                            cryptolog(" but with %s (%s)",
                                      debug_str(actual_padding), type(data))
                            cryptolog(" decrypted data: %s",
                                      debug_str(data[:128]))
                            return self._internal_error(
                                "%s encryption padding error - wrong key?" %
                                self.cipher_in_name)
                        data = data[:-padding_size]
                #uncompress if needed:
                if compression_level > 0:
                    try:
                        data = decompress(data, compression_level)
                    except InvalidCompressionException as e:
                        self.invalid("invalid compression: %s" % e, data)
                        return
                    except Exception as e:
                        ctype = compression.get_compression_type(
                            compression_level)
                        log("%s packet decompression failed",
                            ctype,
                            exc_info=True)
                        msg = "%s packet decompression failed" % ctype
                        if self.cipher_in:
                            msg += " (invalid encryption key?)"
                        else:
                            #only include the exception text when not using encryption
                            #as this may leak crypto information:
                            msg += " %s" % e
                        return self.gibberish(msg, data)

                if self.cipher_in and not (protocol_flags & FLAGS_CIPHER):
                    self.invalid("unencrypted packet dropped", data)
                    return

                if self._closed:
                    return
                if packet_index > 0:
                    #raw packet, store it and continue:
                    raw_packets[packet_index] = data
                    payload_size = -1
                    packet_index = 0
                    if len(raw_packets) >= 4:
                        self.invalid(
                            "too many raw packets: %s" % len(raw_packets),
                            data)
                        return
                    continue
                #final packet (packet_index==0), decode it:
                try:
                    packet = decode(data, protocol_flags)
                except InvalidPacketEncodingException as e:
                    self.invalid("invalid packet encoding: %s" % e, data)
                    return
                except ValueError as e:
                    etype = packet_encoding.get_packet_encoding_type(
                        protocol_flags)
                    log.error("Error parsing %s packet:", etype)
                    log.error(" %s", e)
                    if self._closed:
                        return
                    log("failed to parse %s packet: %s", etype,
                        binascii.hexlify(data[:128]))
                    log(" %s", e)
                    log(" data: %s", repr_ellipsized(data))
                    log(" packet index=%i, packet size=%i, buffer size=%s",
                        packet_index, payload_size, bl)
                    self.gibberish("failed to parse %s packet" % etype, data)
                    return

                if self._closed:
                    return
                payload_size = -1
                padding_size = 0
                #add any raw packets back into it:
                if raw_packets:
                    for index, raw_data in raw_packets.items():
                        #replace placeholder with the raw_data packet data:
                        packet[index] = raw_data
                    raw_packets = {}

                packet_type = packet[0]
                if self.receive_aliases and type(
                        packet_type
                ) == int and packet_type in self.receive_aliases:
                    packet_type = self.receive_aliases.get(packet_type)
                    packet[0] = packet_type
                self.input_stats[packet_type] = self.input_stats.get(
                    packet_type, 0) + 1

                self.input_packetcount += 1
                log("processing packet %s", packet_type)
                self._process_packet_cb(self, packet)
                packet = None
                INJECT_FAULT(self)

    def flush_then_close(self, last_packet, done_callback=None):
        """ Note: this is best effort only
            the packet may not get sent.

            We try to get the write lock,
            we try to wait for the write queue to flush,
            we queue our last packet,
            we wait again for the queue to flush,
            then no matter what, we close the connection and stop the threads.
        """
        log("flush_then_close(%s, %s) closed=%s", last_packet, done_callback,
            self._closed)

        def done():
            log("flush_then_close: done, callback=%s", done_callback)
            if done_callback:
                done_callback()

        if self._closed:
            log("flush_then_close: already closed")
            return done()

        def wait_for_queue(timeout=10):
            #IMPORTANT: if we are here, we have the write lock held!
            if not self._write_queue.empty():
                #write queue still has stuff in it..
                if timeout <= 0:
                    log("flush_then_close: queue still busy, closing without sending the last packet"
                        )
                    self._write_lock.release()
                    self.close()
                    done()
                else:
                    log("flush_then_close: still waiting for queue to flush")
                    self.timeout_add(100, wait_for_queue, timeout - 1)
            else:
                log("flush_then_close: queue is now empty, sending the last packet and closing"
                    )
                chunks = self.encode(last_packet)

                def close_and_release():
                    log("flush_then_close: wait_for_packet_sent() close_and_release()"
                        )
                    self.close()
                    try:
                        self._write_lock.release()
                    except:
                        pass
                    done()

                def wait_for_packet_sent():
                    log(
                        "flush_then_close: wait_for_packet_sent() queue.empty()=%s, closed=%s",
                        self._write_queue.empty(), self._closed)
                    if self._write_queue.empty() or self._closed:
                        #it got sent, we're done!
                        close_and_release()
                        return False
                    return not self._closed  #run until we manage to close (here or via the timeout)

                def packet_queued(*args):
                    #if we're here, we have the lock and the packet is in the write queue
                    log("flush_then_close: packet_queued() closed=%s",
                        self._closed)
                    if wait_for_packet_sent():
                        #check again every 100ms
                        self.timeout_add(100, wait_for_packet_sent)

                self._add_chunks_to_queue(chunks,
                                          start_send_cb=None,
                                          end_send_cb=packet_queued)
                #just in case wait_for_packet_sent never fires:
                self.timeout_add(5 * 1000, close_and_release)

        def wait_for_write_lock(timeout=100):
            if not self._write_lock.acquire(False):
                if timeout <= 0:
                    log("flush_then_close: timeout waiting for the write lock")
                    self.close()
                    done()
                else:
                    log(
                        "flush_then_close: write lock is busy, will retry %s more times",
                        timeout)
                    self.timeout_add(10, wait_for_write_lock, timeout - 1)
            else:
                log("flush_then_close: acquired the write lock")
                #we have the write lock - we MUST free it!
                wait_for_queue()

        #normal codepath:
        # -> wait_for_write_lock
        # -> wait_for_queue
        # -> _add_chunks_to_queue
        # -> packet_queued
        # -> wait_for_packet_sent
        # -> close_and_release
        log("flush_then_close: wait_for_write_lock()")
        wait_for_write_lock()

    def close(self):
        log("Protocol.close() closed=%s, connection=%s", self._closed,
            self._conn)
        if self._closed:
            return
        self._closed = True
        self.idle_add(self._process_packet_cb, self,
                      [Protocol.CONNECTION_LOST])
        c = self._conn
        if c:
            try:
                log("Protocol.close() calling %s", c.close)
                c.close()
                if self._log_stats is None and self._conn.input_bytecount == 0 and self._conn.output_bytecount == 0:
                    #no data sent or received, skip logging of stats:
                    self._log_stats = False
                if self._log_stats:
                    from xpra.simple_stats import std_unit, std_unit_dec
                    log.info(
                        "connection closed after %s packets received (%s bytes) and %s packets sent (%s bytes)",
                        std_unit(self.input_packetcount),
                        std_unit_dec(self._conn.input_bytecount),
                        std_unit(self.output_packetcount),
                        std_unit_dec(self._conn.output_bytecount))
            except:
                log.error("error closing %s", self._conn, exc_info=True)
            self._conn = None
        self.terminate_queue_threads()
        self.idle_add(self.clean)
        log("Protocol.close() done")

    def steal_connection(self, read_callback=None):
        #so we can re-use this connection somewhere else
        #(frees all protocol threads and resources)
        #Note: this method can only be used with non-blocking sockets,
        #and if more than one packet can arrive, the read_callback should be used
        #to ensure that no packets get lost.
        #The caller must call wait_for_io_threads_exit() to ensure that this
        #class is no longer reading from the connection before it can re-use it
        assert not self._closed, "cannot steal a closed connection"
        if read_callback:
            self._read_queue_put = read_callback
        conn = self._conn
        self._closed = True
        self._conn = None
        if conn:
            #this ensures that we exit the untilConcludes() read/write loop
            conn.set_active(False)
        self.terminate_queue_threads()
        return conn

    def clean(self):
        #clear all references to ensure we can get garbage collected quickly:
        self._get_packet_cb = None
        self._encoder = None
        self._write_thread = None
        self._read_thread = None
        self._read_parser_thread = None
        self._write_format_thread = None
        self._process_packet_cb = None

    def terminate_queue_threads(self):
        log("terminate_queue_threads()")
        #the format thread will exit:
        self._get_packet_cb = None
        self._source_has_more.set()
        #make all the queue based threads exit by adding the empty marker:
        exit_queue = Queue()
        for _ in range(10):  #just 2 should be enough!
            exit_queue.put(None)
        try:
            owq = self._write_queue
            self._write_queue = exit_queue
            #discard all elements in the old queue and push the None marker:
            try:
                while owq.qsize() > 0:
                    owq.get(False)
            except:
                pass
            owq.put_nowait(None)
        except:
            pass
        try:
            orq = self._read_queue
            self._read_queue = exit_queue
            #discard all elements in the old queue and push the None marker:
            try:
                while orq.qsize() > 0:
                    orq.get(False)
            except:
                pass
            orq.put_nowait(None)
        except:
            pass
        #just in case the read thread is waiting again:
        self._source_has_more.set()
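
The parser above revolves around the 8-byte packet header described in
do_read_parse_thread_loop: a "P" magic byte, protocol flags, compression
level, packet index, and a 4-byte payload size. A standalone illustration of
packing and unpacking it with struct (the "!cBBBL" network byte order is an
assumption consistent with the 8-byte size noted in the code; the real
pack_header/unpack_header helpers are imported from elsewhere in the project):

import struct

#magic "P", protocol flags, compression level, packet index, payload size:
def pack_header(flags, level, index, payload_size):
    return struct.pack("!cBBBL", b"P", flags, level, index, payload_size)

def unpack_header(buf):
    return struct.unpack("!cBBBL", buf[:8])

header = pack_header(0, 1, 0, 1234)
assert len(header) == 8
assert unpack_header(header) == (b"P", 0, 1, 0, 1234)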
Example #4
File: src.py Project: svn2github/Xpra
class SoundSource(SoundPipeline):

    __gsignals__ = SoundPipeline.__generic_signals__.copy()
    __gsignals__.update({
        "new-buffer"    : n_arg_signal(2),
        })

    def __init__(self, src_type=None, src_options={}, codecs=get_codecs(), codec_options={}, volume=1.0):
        if not src_type:
            try:
                from xpra.sound.pulseaudio.pulseaudio_util import get_pa_device_options
                monitor_devices = get_pa_device_options(True, False)
                log.info("found pulseaudio monitor devices: %s", monitor_devices)
            except ImportError as e:
                log.warn("Warning: pulseaudio is not available!")
                log.warn(" %s", e)
                monitor_devices = []
            if len(monitor_devices)==0:
                log.warn("could not detect any pulseaudio monitor devices")
                log.warn(" a test source will be used instead")
                src_type = "audiotestsrc"
                default_src_options = {"wave":2, "freq":100, "volume":0.4}
            else:
                monitor_device = monitor_devices.items()[0][0]
                log.info("using pulseaudio source device:")
                log.info(" '%s'", monitor_device)
                src_type = "pulsesrc"
                default_src_options = {"device" : monitor_device}
            src_options = default_src_options
        if src_type not in get_source_plugins():
            raise InitExit(1, "invalid source plugin '%s', valid options are: %s" % (src_type, ",".join(get_source_plugins())))
        matching = [x for x in CODEC_ORDER if (x in codecs and x in get_codecs())]
        log("SoundSource(..) found matching codecs %s", matching)
        if not matching:
            raise InitExit(1, "no matching codecs between arguments '%s' and supported list '%s'" % (csv(codecs), csv(get_codecs().keys())))
        codec = matching[0]
        encoder, fmt = get_encoder_formatter(codec)
        self.queue = None
        self.caps = None
        self.volume = None
        self.sink = None
        self.src = None
        self.src_type = src_type
        self.buffer_latency = False
        self.jitter_queue = None
        self.file = None
        SoundPipeline.__init__(self, codec)
        src_options["name"] = "src"
        source_str = plugin_str(src_type, src_options)
        #FIXME: this is ugly and relies on the fact that we don't pass any codec options to work!
        encoder_str = plugin_str(encoder, codec_options or get_encoder_default_options(encoder))
        fmt_str = plugin_str(fmt, MUXER_DEFAULT_OPTIONS.get(fmt, {}))
        pipeline_els = [source_str]
        if SOURCE_QUEUE_TIME>0:
            queue_el = ["queue",
                        "name=queue",
                        "min-threshold-time=0",
                        "max-size-buffers=0",
                        "max-size-bytes=0",
                        "max-size-time=%s" % (SOURCE_QUEUE_TIME*MS_TO_NS),
                        "leaky=%s" % GST_QUEUE_LEAK_DOWNSTREAM]
            pipeline_els += [" ".join(queue_el)]
        if encoder in ENCODER_NEEDS_AUDIOCONVERT or src_type in SOURCE_NEEDS_AUDIOCONVERT:
            pipeline_els += ["audioconvert"]
        pipeline_els.append("volume name=volume volume=%s" % volume)
        pipeline_els += [encoder_str,
                         fmt_str,
                         APPSINK]
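        #illustrative joined pipeline (an assumption, the exact element strings
        #depend on src_type, codec and gstreamer version), e.g.:
        #  pulsesrc device=... name=src ! queue ... ! audioconvert
        #   ! volume name=volume volume=1.0 ! vorbisenc ! oggmux ! appsink name=sink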
        if not self.setup_pipeline_and_bus(pipeline_els):
            return
        self.volume = self.pipeline.get_by_name("volume")
        self.sink = self.pipeline.get_by_name("sink")
        if SOURCE_QUEUE_TIME>0:
            self.queue  = self.pipeline.get_by_name("queue")
        if self.queue:
            try:
                self.queue.set_property("silent", True)
            except Exception as e:
                log("cannot make queue silent: %s", e)
        try:
            if get_gst_version()<(1,0):
                self.sink.set_property("enable-last-buffer", False)
            else:
                self.sink.set_property("enable-last-sample", False)
        except Exception as e:
            log("failed to disable last buffer: %s", e)
        self.skipped_caps = set()
        if JITTER>0:
            self.jitter_queue = Queue()
        try:
            #Gst 1.0:
            self.sink.connect("new-sample", self.on_new_sample)
            self.sink.connect("new-preroll", self.on_new_preroll1)
        except:
            #Gst 0.10:
            self.sink.connect("new-buffer", self.on_new_buffer)
            self.sink.connect("new-preroll", self.on_new_preroll0)
        self.src = self.pipeline.get_by_name("src")
        try:
            for x in ("actual-buffer-time", "actual-latency-time"):
                #don't comment this out, it is used to verify the attributes are present:
                gstlog("initial %s: %s", x, self.src.get_property(x))
            self.buffer_latency = True
        except Exception as e:
            log.info("source %s does not support 'buffer-time' or 'latency-time':", self.src_type)
            log.info(" %s", e)
        else:
            #if the env vars have been set, try to honour the settings:
            global BUFFER_TIME, LATENCY_TIME
            if BUFFER_TIME>0:
                if BUFFER_TIME<LATENCY_TIME:
                    log.warn("Warning: latency (%ims) must be lower than the buffer time (%ims)", LATENCY_TIME, BUFFER_TIME)
                else:
                    log("latency tuning for %s, will try to set buffer-time=%i, latency-time=%i", src_type, BUFFER_TIME, LATENCY_TIME)
                    def settime(attr, v):
                        try:
                            cval = self.src.get_property(attr)
                            gstlog("default: %s=%i", attr, cval//1000)
                            if v>=0:
                                self.src.set_property(attr, v*1000)
                                gstlog("overriding with: %s=%i", attr, v)
                        except Exception as e:
                            log.warn("source %s does not support '%s': %s", self.src_type, attr, e)
                    settime("buffer-time", BUFFER_TIME)
                    settime("latency-time", LATENCY_TIME)
        gen = generation.increase()
        if SAVE_TO_FILE is not None:
            parts = codec.split("+")
            if len(parts)>1:
                filename = SAVE_TO_FILE+str(gen)+"-"+parts[0]+".%s" % parts[1]
            else:
                filename = SAVE_TO_FILE+str(gen)+".%s" % codec
            self.file = open(filename, 'wb')
            log.info("saving %s stream to %s", codec, filename)


    def __repr__(self):
        return "SoundSource('%s' - %s)" % (self.pipeline_str, self.state)

    def cleanup(self):
        SoundPipeline.cleanup(self)
        self.src_type = ""
        self.sink = None
        self.caps = None
        f = self.file
        if f:
            self.file = None
            f.close()

    def get_info(self):
        info = SoundPipeline.get_info(self)
        if self.queue:
            info["queue"] = {"cur" : self.queue.get_property("current-level-time")//MS_TO_NS}
        if self.buffer_latency:
            for x in ("actual-buffer-time", "actual-latency-time"):
                v = self.src.get_property(x)
                if v>=0:
                    info[x] = v
        return info


    def on_new_preroll1(self, appsink):
        sample = appsink.emit('pull-preroll')
        gstlog('new preroll1: %s', sample)
        return self.emit_buffer1(sample)

    def on_new_sample(self, bus):
        #Gst 1.0
        sample = self.sink.emit("pull-sample")
        return self.emit_buffer1(sample)

    def emit_buffer1(self, sample):
        buf = sample.get_buffer()
        #info = sample.get_info()
        size = buf.get_size()
        extract_dup = getattr(buf, "extract_dup", None)
        if extract_dup:
            data = extract_dup(0, size)
        else:
            #crappy gi bindings detected, using workaround:
            from xpra.sound.gst_hacks import map_gst_buffer
            with map_gst_buffer(buf) as a:
                data = bytes(a[:])
        return self.emit_buffer(data, {"timestamp" : normv(buf.pts),
                                       "duration"  : normv(buf.duration),
                                       })


    def on_new_preroll0(self, appsink):
        buf = appsink.emit('pull-preroll')
        gstlog('new preroll0: %s bytes', len(buf))
        return self.emit_buffer0(buf)

    def on_new_buffer(self, bus):
        #pygst 0.10
        buf = self.sink.emit("pull-buffer")
        return self.emit_buffer0(buf)


    def caps_to_dict(self, caps):
        if not caps:
            return {}
        d = {}
        try:
            for cap in caps:
                name = cap.get_name()
                capd = {}
                for k in cap.keys():
                    v = cap[k]
                    if type(v) in (str, int):
                        capd[k] = cap[k]
                    elif k not in self.skipped_caps:
                        log("skipping %s cap key %s=%s of type %s", name, k, v, type(v))
                d[name] = capd
        except Exception as e:
            log.error("Error parsing '%s':", caps)
            log.error(" %s", e)
        return d
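
    #illustration (not in the original): for a raw audio stream, the dict built
    #above would look roughly like:
    #    {"audio/x-raw": {"rate": 44100, "channels": 2}}
    #(values that are not str or int, such as fractions or lists, are skipped)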

    def emit_buffer0(self, buf):
        """ convert pygst structure into something more generic for the wire """
        #none of the metadata is really needed at present, but it may be in the future:
        #metadata = {"caps"      : buf.get_caps().to_string(),
        #            "size"      : buf.size,
        #            "timestamp" : buf.timestamp,
        #            "duration"  : buf.duration,
        #            "offset"    : buf.offset,
        #            "offset_end": buf.offset_end}
        log("emit buffer: %s bytes, timestamp=%s", len(buf.data), buf.timestamp//MS_TO_NS)
        metadata = {
                   "timestamp" : normv(buf.timestamp),
                   "duration"  : normv(buf.duration)
                   }
        d = self.caps_to_dict(buf.get_caps())
        if not self.caps or self.caps!=d:
            self.caps = d
            self.info["caps"] = self.caps
            metadata["caps"] = self.caps
        return self.emit_buffer(buf.data, metadata)

    def emit_buffer(self, data, metadata={}):
        f = self.file
        if f and data:
            #use the local alias in case cleanup() clears self.file concurrently:
            f.write(data)
            f.flush()
        if self.state=="stopped":
            #don't bother
            return 0
        if JITTER>0:
            #will actually emit the buffer after a random delay
            if self.jitter_queue.empty():
                #queue was empty, schedule a timer to flush it
                from random import randint
                jitter = randint(1, JITTER)
                self.timeout_add(jitter, self.flush_jitter_queue)
                log("emit_buffer: will flush jitter queue in %ims", jitter)
            self.jitter_queue.put((data, metadata))
            return 0
        log("emit_buffer data=%s, len=%i, metadata=%s", type(data), len(data), metadata)
        return self.do_emit_buffer(data, metadata)

    def flush_jitter_queue(self):
        while not self.jitter_queue.empty():
            d,m = self.jitter_queue.get(False)
            self.do_emit_buffer(d, m)

    def do_emit_buffer(self, data, metadata={}):
        self.inc_buffer_count()
        self.inc_byte_count(len(data))
        metadata["time"] = int(time.time()*1000)
        self.idle_emit("new-buffer", data, metadata)
        self.emit_info()
        return 0
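
A note on the JITTER path above (identical in every version of this class): the
first buffer to land in an empty queue arms a one-shot flush timer, and every
buffer queued in the meantime rides along in the same flush. A minimal,
self-contained sketch of the pattern, with hypothetical names and
threading.Timer standing in for the pipeline's timeout_add():

from queue import Queue
from random import randint
from threading import Timer

JITTER = 50  # maximum extra delay, in milliseconds

class JitterDemo:
    """Hold buffers back for a random delay to simulate network jitter."""

    def __init__(self):
        self.jitter_queue = Queue()

    def emit_buffer(self, data, metadata):
        if self.jitter_queue.empty():
            #queue was empty: schedule a single flush for the whole batch
            delay = randint(1, JITTER)
            Timer(delay / 1000.0, self.flush_jitter_queue).start()
        self.jitter_queue.put((data, metadata))

    def flush_jitter_queue(self):
        while not self.jitter_queue.empty():
            data, metadata = self.jitter_queue.get(False)
            print("emitting %i bytes, metadata=%s" % (len(data), metadata))

demo = JitterDemo()
demo.emit_buffer(b"\x00" * 1024, {"timestamp": 0})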
Example #5
class SoundSource(SoundPipeline):

    __gsignals__ = SoundPipeline.__generic_signals__.copy()
    __gsignals__.update({
        "new-buffer": n_arg_signal(3),
    })

    def __init__(self,
                 src_type=None,
                 src_options={},
                 codecs=get_encoders(),
                 codec_options={},
                 volume=1.0):
        if not src_type:
            try:
                from xpra.sound.pulseaudio.pulseaudio_util import get_pa_device_options
                monitor_devices = get_pa_device_options(True, False)
                log.info("found pulseaudio monitor devices: %s",
                         monitor_devices)
            except ImportError as e:
                log.warn("Warning: pulseaudio is not available!")
                log.warn(" %s", e)
                monitor_devices = []
            if len(monitor_devices) == 0:
                log.warn("could not detect any pulseaudio monitor devices")
                log.warn(" a test source will be used instead")
                src_type = "audiotestsrc"
                default_src_options = {"wave": 2, "freq": 100, "volume": 0.4}
            else:
                monitor_device = monitor_devices.items()[0][0]
                log.info("using pulseaudio source device:")
                log.info(" '%s'", monitor_device)
                src_type = "pulsesrc"
                default_src_options = {"device": monitor_device}
            src_options = default_src_options
        if src_type not in get_source_plugins():
            raise InitExit(
                1, "invalid source plugin '%s', valid options are: %s" %
                (src_type, ",".join(get_source_plugins())))
        matching = [
            x for x in CODEC_ORDER if (x in codecs and x in get_encoders())
        ]
        log("SoundSource(..) found matching codecs %s", matching)
        if not matching:
            raise InitExit(
                1,
                "no matching codecs between arguments '%s' and supported list '%s'"
                % (csv(codecs), csv(get_encoders().keys())))
        codec = matching[0]
        encoder, fmt, stream_compressor = get_encoder_elements(codec)
        SoundPipeline.__init__(self, codec)
        self.queue = None
        self.caps = None
        self.volume = None
        self.sink = None
        self.src = None
        self.src_type = src_type
        self.pending_metadata = []
        self.buffer_latency = True
        self.jitter_queue = None
        self.file = None
        self.container_format = (fmt or "").replace("mux",
                                                    "").replace("pay", "")
        self.stream_compressor = stream_compressor
        src_options["name"] = "src"
        source_str = plugin_str(src_type, src_options)
        #FIXME: this is ugly and relies on the fact that we don't pass any codec options to work!
        pipeline_els = [source_str]
        if SOURCE_QUEUE_TIME > 0:
            queue_el = [
                "queue", "name=queue", "min-threshold-time=0",
                "max-size-buffers=0", "max-size-bytes=0",
                "max-size-time=%s" % (SOURCE_QUEUE_TIME * MS_TO_NS),
                "leaky=%s" % GST_QUEUE_LEAK_DOWNSTREAM
            ]
            pipeline_els += [" ".join(queue_el)]
        if encoder in ENCODER_NEEDS_AUDIOCONVERT or src_type in SOURCE_NEEDS_AUDIOCONVERT:
            pipeline_els += ["audioconvert"]
        pipeline_els.append("volume name=volume volume=%s" % volume)
        if encoder:
            encoder_str = plugin_str(
                encoder, codec_options or get_encoder_default_options(encoder))
            pipeline_els.append(encoder_str)
        if fmt:
            fmt_str = plugin_str(fmt, MUXER_DEFAULT_OPTIONS.get(fmt, {}))
            pipeline_els.append(fmt_str)
        pipeline_els.append(APPSINK)
        if not self.setup_pipeline_and_bus(pipeline_els):
            return
        self.volume = self.pipeline.get_by_name("volume")
        self.sink = self.pipeline.get_by_name("sink")
        if SOURCE_QUEUE_TIME > 0:
            self.queue = self.pipeline.get_by_name("queue")
        if self.queue:
            try:
                self.queue.set_property("silent", True)
            except Exception as e:
                log("cannot make queue silent: %s", e)
        try:
            if get_gst_version() < (1, 0):
                self.sink.set_property("enable-last-buffer", False)
            else:
                self.sink.set_property("enable-last-sample", False)
        except Exception as e:
            log("failed to disable last buffer: %s", e)
        self.skipped_caps = set()
        if JITTER > 0:
            self.jitter_queue = Queue()
        try:
            #Gst 1.0:
            self.sink.connect("new-sample", self.on_new_sample)
            self.sink.connect("new-preroll", self.on_new_preroll1)
        except:
            #Gst 0.10:
            self.sink.connect("new-buffer", self.on_new_buffer)
            self.sink.connect("new-preroll", self.on_new_preroll0)
        self.src = self.pipeline.get_by_name("src")
        try:
            for x in ("actual-buffer-time", "actual-latency-time"):
                #don't comment this out, it is used to verify the attributes are present:
                try:
                    gstlog("initial %s: %s", x, self.src.get_property(x))
                except Exception as e:
                    gstlog("no %s property on %s: %s", x, self.src, e)
                    self.buffer_latency = False
        except Exception as e:
            log.info(
                "source %s does not support 'buffer-time' or 'latency-time':",
                self.src_type)
            log.info(" %s", e)
        else:
            #if the env vars have been set, try to honour the settings:
            global BUFFER_TIME, LATENCY_TIME
            if BUFFER_TIME > 0:
                if BUFFER_TIME < LATENCY_TIME:
                    log.warn(
                        "Warning: latency (%ims) must be lower than the buffer time (%ims)",
                        LATENCY_TIME, BUFFER_TIME)
                else:
                    log(
                        "latency tuning for %s, will try to set buffer-time=%i, latency-time=%i",
                        src_type, BUFFER_TIME, LATENCY_TIME)

                    def settime(attr, v):
                        try:
                            cval = self.src.get_property(attr)
                            gstlog("default: %s=%i", attr, cval // 1000)
                            if v >= 0:
                                self.src.set_property(attr, v * 1000)
                                gstlog("overriding with: %s=%i", attr, v)
                        except Exception as e:
                            log.warn("source %s does not support '%s': %s",
                                     self.src_type, attr, e)

                    settime("buffer-time", BUFFER_TIME)
                    settime("latency-time", LATENCY_TIME)
        gen = generation.increase()
        if SAVE_TO_FILE is not None:
            parts = codec.split("+")
            if len(parts) > 1:
                filename = SAVE_TO_FILE + str(
                    gen) + "-" + parts[0] + ".%s" % parts[1]
            else:
                filename = SAVE_TO_FILE + str(gen) + ".%s" % codec
            self.file = open(filename, 'wb')
            log.info("saving %s stream to %s", codec, filename)

    def __repr__(self):
        return "SoundSource('%s' - %s)" % (self.pipeline_str, self.state)

    def cleanup(self):
        SoundPipeline.cleanup(self)
        self.src_type = ""
        self.sink = None
        self.caps = None
        f = self.file
        if f:
            self.file = None
            f.close()

    def get_info(self):
        info = SoundPipeline.get_info(self)
        if self.queue:
            info["queue"] = {
                "cur":
                self.queue.get_property("current-level-time") // MS_TO_NS
            }
        if self.buffer_latency:
            for x in ("actual-buffer-time", "actual-latency-time"):
                v = self.src.get_property(x)
                if v >= 0:
                    info[x] = v
        return info

    def on_new_preroll1(self, appsink):
        gstlog('new preroll')
        return 0

    def on_new_sample(self, bus):
        #Gst 1.0
        sample = self.sink.emit("pull-sample")
        return self.emit_buffer1(sample)

    def emit_buffer1(self, sample):
        buf = sample.get_buffer()
        #info = sample.get_info()
        size = buf.get_size()
        extract_dup = getattr(buf, "extract_dup", None)
        if extract_dup:
            data = extract_dup(0, size)
        else:
            #crappy gi bindings detected, using workaround:
            from xpra.sound.gst_hacks import map_gst_buffer
            with map_gst_buffer(buf) as a:
                data = bytes(a[:])
        pts = normv(buf.pts)
        duration = normv(buf.duration)
        if pts == -1 and duration == -1 and BUNDLE_METADATA and len(
                self.pending_metadata) < 10:
            self.pending_metadata.append(data)
            return 0
        return self.emit_buffer(data, {
            "timestamp": pts,
            "duration": duration,
        })

    def on_new_preroll0(self, appsink):
        gstlog('new preroll')
        return 0

    def on_new_buffer(self, bus):
        #pygst 0.10
        buf = self.sink.emit("pull-buffer")
        return self.emit_buffer0(buf)

    def caps_to_dict(self, caps):
        if not caps:
            return {}
        d = {}
        try:
            for cap in caps:
                name = cap.get_name()
                capd = {}
                for k in cap.keys():
                    v = cap[k]
                    if type(v) in (str, int):
                        capd[k] = cap[k]
                    elif k not in self.skipped_caps:
                        log("skipping %s cap key %s=%s of type %s", name, k, v,
                            type(v))
                d[name] = capd
        except Exception as e:
            log.error("Error parsing '%s':", caps)
            log.error(" %s", e)
        return d

    def emit_buffer0(self, buf):
        """ convert pygst structure into something more generic for the wire """
        #none of the metadata is really needed at present, but it may be in the future:
        #metadata = {"caps"      : buf.get_caps().to_string(),
        #            "size"      : buf.size,
        #            "timestamp" : buf.timestamp,
        #            "duration"  : buf.duration,
        #            "offset"    : buf.offset,
        #            "offset_end": buf.offset_end}
        log("emit buffer: %s bytes, timestamp=%s", len(buf.data),
            buf.timestamp // MS_TO_NS)
        metadata = {
            "timestamp": normv(buf.timestamp),
            "duration": normv(buf.duration)
        }
        d = self.caps_to_dict(buf.get_caps())
        if not self.caps or self.caps != d:
            self.caps = d
            self.info["caps"] = self.caps
            metadata["caps"] = self.caps
        return self.emit_buffer(buf.data, metadata)

    def emit_buffer(self, data, metadata={}):
        if self.stream_compressor and data:
            data = compressed_wrapper("sound",
                                      data,
                                      level=9,
                                      zlib=False,
                                      lz4=(self.stream_compressor == "lz4"),
                                      lzo=(self.stream_compressor == "lzo"),
                                      can_inline=True)
            #log("compressed using %s from %i bytes down to %i bytes", self.stream_compressor, len(odata), len(data))
            metadata["compress"] = self.stream_compressor
        f = self.file
        if f:
            for x in self.pending_metadata:
                f.write(x)
            if data:
                f.write(data)
            f.flush()
        if self.state == "stopped":
            #don't bother
            return 0
        if JITTER > 0:
            #will actually emit the buffer after a random delay
            if self.jitter_queue.empty():
                #queue was empty, schedule a timer to flush it
                from random import randint
                jitter = randint(1, JITTER)
                self.timeout_add(jitter, self.flush_jitter_queue)
                log("emit_buffer: will flush jitter queue in %ims", jitter)
            for x in self.pending_metadata:
                self.jitter_queue.put((x, {}))
            self.pending_metadata = []
            self.jitter_queue.put((data, metadata))
            return 0
        log("emit_buffer data=%s, len=%i, metadata=%s", type(data), len(data),
            metadata)
        return self.do_emit_buffer(data, metadata)

    def flush_jitter_queue(self):
        while not self.jitter_queue.empty():
            d, m = self.jitter_queue.get(False)
            self.do_emit_buffer(d, m)

    def do_emit_buffer(self, data, metadata={}):
        self.inc_buffer_count()
        self.inc_byte_count(len(data))
        for x in self.pending_metadata:
            self.inc_buffer_count()
            self.inc_byte_count(len(x))
        metadata["time"] = int(time.time() * 1000)
        self.idle_emit("new-buffer", data, metadata, self.pending_metadata)
        self.pending_metadata = []
        self.emit_info()
        return 0
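
emit_buffer1() above holds back buffers whose pts and duration both normalise
to -1 (no usable timestamps) and ships them together with the next timestamped
buffer, capped at ten pending items. A rough standalone sketch of that
accumulation logic, using hypothetical names:

BUNDLE_METADATA = True
MAX_PENDING = 10  # mirrors the len(self.pending_metadata)<10 guard above

class BundleDemo:

    def __init__(self):
        self.pending_metadata = []

    def add_buffer(self, data, pts, duration):
        if pts==-1 and duration==-1 and BUNDLE_METADATA and len(self.pending_metadata)<MAX_PENDING:
            #no usable timestamps: hold on to the raw data for now
            self.pending_metadata.append(data)
            return
        #a timestamped buffer flushes everything accumulated so far:
        bundle, self.pending_metadata = self.pending_metadata, []
        print("emitting %i bytes with %i bundled buffers" % (len(data), len(bundle)))

demo = BundleDemo()
demo.add_buffer(b"a"*100, -1, -1)       #queued
demo.add_buffer(b"b"*100, 5000, 20)     #flushes both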
Example #6
File: src.py Project: rudresh2319/Xpra
class SoundSource(SoundPipeline):

    __gsignals__ = SoundPipeline.__generic_signals__.copy()
    __gsignals__.update({
        "new-buffer"    : n_arg_signal(2),
        })

    def __init__(self, src_type=None, src_options={}, codecs=get_codecs(), codec_options={}, volume=1.0):
        if not src_type:
            from xpra.sound.pulseaudio_util import get_pa_device_options
            monitor_devices = get_pa_device_options(True, False)
            log.info("found pulseaudio monitor devices: %s", monitor_devices)
            if len(monitor_devices)==0:
                log.warn("could not detect any pulseaudio monitor devices")
                log.warn(" a test source will be used instead")
                src_type = "audiotestsrc"
                default_src_options = {"wave":2, "freq":100, "volume":0.4}
            else:
                monitor_device = monitor_devices.items()[0][0]
                log.info("using pulseaudio source device:")
                log.info(" '%s'", monitor_device)
                src_type = "pulsesrc"
                default_src_options = {"device" : monitor_device}
            src_options = default_src_options
        if src_type not in get_source_plugins():
            raise InitExit(1, "invalid source plugin '%s', valid options are: %s" % (src_type, ",".join(get_source_plugins())))
        matching = [x for x in CODEC_ORDER if (x in codecs and x in get_codecs())]
        log("SoundSource(..) found matching codecs %s", matching)
        if not matching:
            raise InitExit(1, "no matching codecs between arguments '%s' and supported list '%s'" % (csv(codecs), csv(get_codecs().keys())))
        codec = matching[0]
        encoder, fmt = get_encoder_formatter(codec)
        SoundPipeline.__init__(self, codec)
        self.src_type = src_type
        source_str = plugin_str(src_type, src_options)
        #FIXME: this is ugly and relies on the fact that we don't pass any codec options to work!
        encoder_str = plugin_str(encoder, codec_options or ENCODER_DEFAULT_OPTIONS.get(encoder, {}))
        fmt_str = plugin_str(fmt, MUXER_DEFAULT_OPTIONS.get(fmt, {}))
        pipeline_els = [source_str]
        if encoder in ENCODER_NEEDS_AUDIOCONVERT or src_type in SOURCE_NEEDS_AUDIOCONVERT:
            pipeline_els += ["audioconvert"]
        pipeline_els.append("volume name=volume volume=%s" % volume)
        pipeline_els += [encoder_str,
                        fmt_str,
                        APPSINK]
        self.setup_pipeline_and_bus(pipeline_els)
        self.volume = self.pipeline.get_by_name("volume")
        self.sink = self.pipeline.get_by_name("sink")
        try:
            if get_gst_version()<(1,0):
                self.sink.set_property("enable-last-buffer", False)
            else:
                self.sink.set_property("enable-last-sample", False)
        except Exception as e:
            log("failed to disable last buffer: %s", e)
        self.caps = None
        self.skipped_caps = set()
        if JITTER>0:
            self.jitter_queue = Queue()
        try:
            #Gst 1.0:
            self.sink.connect("new-sample", self.on_new_sample)
            self.sink.connect("new-preroll", self.on_new_preroll1)
        except:
            #Gst 0.10:
            self.sink.connect("new-buffer", self.on_new_buffer)
            self.sink.connect("new-preroll", self.on_new_preroll0)

    def __repr__(self):
        return "SoundSource('%s' - %s)" % (self.pipeline_str, self.state)

    def cleanup(self):
        SoundPipeline.cleanup(self)
        self.src_type = ""
        self.sink = None
        self.caps = None

    def get_info(self):
        info = SoundPipeline.get_info(self)
        if self.caps:
            info["caps"] = self.caps
        return info


    def on_new_preroll1(self, appsink):
        sample = appsink.emit('pull-preroll')
        log('new preroll1: %s', sample)
        return self.emit_buffer1(sample)

    def on_new_sample(self, bus):
        #Gst 1.0
        sample = self.sink.emit("pull-sample")
        return self.emit_buffer1(sample)

    def emit_buffer1(self, sample):
        buf = sample.get_buffer()
        #info = sample.get_info()
        size = buf.get_size()
        extract_dup = getattr(buf, "extract_dup", None)
        if extract_dup:
            data = extract_dup(0, size)
        else:
            #crappy gi bindings detected, using workaround:
            from xpra.sound.gst_hacks import map_gst_buffer
            with map_gst_buffer(buf) as a:
                data = bytes(a[:])
        return self.emit_buffer(data, {"timestamp" : normv(buf.pts),
                                       "duration"  : normv(buf.duration),
                                       })


    def on_new_preroll0(self, appsink):
        buf = appsink.emit('pull-preroll')
        log('new preroll0: %s bytes', len(buf))
        return self.emit_buffer0(buf)

    def on_new_buffer(self, bus):
        #pygst 0.10
        buf = self.sink.emit("pull-buffer")
        return self.emit_buffer0(buf)


    def caps_to_dict(self, caps):
        if not caps:
            return {}
        d = {}
        try:
            for cap in caps:
                name = cap.get_name()
                capd = {}
                for k in cap.keys():
                    v = cap[k]
                    if type(v) in (str, int):
                        capd[k] = cap[k]
                    elif k not in self.skipped_caps:
                        log("skipping %s cap key %s=%s of type %s", name, k, v, type(v))
                d[name] = capd
        except Exception as e:
            log.error("Error parsing '%s':", caps)
            log.error(" %s", e)
        return d

    def emit_buffer0(self, buf):
        """ convert pygst structure into something more generic for the wire """
        #none of the metadata is really needed at present, but it may be in the future:
        #metadata = {"caps"      : buf.get_caps().to_string(),
        #            "size"      : buf.size,
        #            "timestamp" : buf.timestamp,
        #            "duration"  : buf.duration,
        #            "offset"    : buf.offset,
        #            "offset_end": buf.offset_end}
        log("emit buffer: %s bytes, timestamp=%s", len(buf.data), buf.timestamp//MS_TO_NS)
        metadata = {
                   "timestamp" : normv(buf.timestamp),
                   "duration"  : normv(buf.duration)
                   }
        d = self.caps_to_dict(buf.get_caps())
        if not self.caps or self.caps!=d:
            self.caps = d
            metadata["caps"] = self.caps
        return self.emit_buffer(buf.data, metadata)

    def emit_buffer(self, data, metadata={}):
        if JITTER>0:
            #will actually emit the buffer after a random delay
            if self.jitter_queue.empty():
                #queue was empty, schedule a timer to flush it
                from random import randint
                jitter = randint(1, JITTER)
                self.timeout_add(jitter, self.flush_jitter_queue)
                log("emit_buffer: will flush jitter queue in %ims", jitter)
            self.jitter_queue.put((data, metadata))
            return 0
        log("emit_buffer data=%s, len=%i, metadata=%s", type(data), len(data), metadata)
        return self.do_emit_buffer(data, metadata)

    def flush_jitter_queue(self):
        while not self.jitter_queue.empty():
            d,m = self.jitter_queue.get(False)
            self.do_emit_buffer(d, m)

    def do_emit_buffer(self, data, metadata={}):
        self.buffer_count += 1
        self.byte_count += len(data)
        metadata["time"] = int(time.time()*1000)
        self.idle_emit("new-buffer", data, metadata)
        self.emit_info()
        return 0
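
All of these constructors build the pipeline as a list of gst-launch style
element strings which setup_pipeline_and_bus() then turns into a real pipeline
(presumably by joining them with " ! " and parsing the result). Assuming
plugin_str() simply renders an element name followed by its options (an
assumption; the real helper is defined elsewhere in xpra), the description
ends up looking roughly like this:

def plugin_str(plugin, options):
    #hypothetical stand-in for xpra's helper:
    return " ".join([plugin] + ["%s=%s" % (k, v) for k, v in options.items()])

src = plugin_str("pulsesrc", {"device": "output.monitor", "name": "src"})
pipeline_els = [src,
                "volume name=volume volume=1.0",
                "vorbisenc",
                "oggmux",
                "appsink name=sink"]
print(" ! ".join(pipeline_els))
#pulsesrc device=output.monitor name=src ! volume name=volume volume=1.0 ! vorbisenc ! oggmux ! appsink name=sink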
Example #7
class SoundSource(SoundPipeline):

    __gsignals__ = SoundPipeline.__generic_signals__.copy()
    __gsignals__.update({
        "new-buffer": n_arg_signal(3),
    })

    def __init__(self,
                 src_type=None,
                 src_options={},
                 codecs=get_encoders(),
                 codec_options={},
                 volume=1.0):
        if not src_type:
            try:
                from xpra.sound.pulseaudio.pulseaudio_util import get_pa_device_options
                monitor_devices = get_pa_device_options(True, False)
                log.info("found pulseaudio monitor devices: %s",
                         monitor_devices)
            except ImportError as e:
                log.warn("Warning: pulseaudio is not available!")
                log.warn(" %s", e)
                monitor_devices = []
            if len(monitor_devices) == 0:
                log.warn("could not detect any pulseaudio monitor devices")
                log.warn(" a test source will be used instead")
                src_type = "audiotestsrc"
                default_src_options = {"wave": 2, "freq": 100, "volume": 0.4}
            else:
                monitor_device = monitor_devices.items()[0][0]
                log.info("using pulseaudio source device:")
                log.info(" '%s'", monitor_device)
                src_type = "pulsesrc"
                default_src_options = {"device": monitor_device}
            src_options = default_src_options
        if src_type not in get_source_plugins():
            raise InitExit(
                1, "invalid source plugin '%s', valid options are: %s" %
                (src_type, ",".join(get_source_plugins())))
        matching = [
            x for x in CODEC_ORDER if (x in codecs and x in get_encoders())
        ]
        log("SoundSource(..) found matching codecs %s", matching)
        if not matching:
            raise InitExit(
                1,
                "no matching codecs between arguments '%s' and supported list '%s'"
                % (csv(codecs), csv(get_encoders().keys())))
        codec = matching[0]
        encoder, fmt, stream_compressor = get_encoder_elements(codec)
        SoundPipeline.__init__(self, codec)
        self.queue = None
        self.caps = None
        self.volume = None
        self.sink = None
        self.src = None
        self.src_type = src_type
        self.timestamp = None
        self.min_timestamp = 0
        self.max_timestamp = 0
        self.pending_metadata = []
        self.buffer_latency = True
        self.jitter_queue = None
        self.file = None
        self.container_format = (fmt or "").replace("mux",
                                                    "").replace("pay", "")
        self.stream_compressor = stream_compressor
        src_options["name"] = "src"
        source_str = plugin_str(src_type, src_options)
        #FIXME: this is ugly and relies on the fact that we don't pass any codec options to work!
        pipeline_els = [source_str]
        log("has plugin(timestamp)=%s", has_plugins("timestamp"))
        if has_plugins("timestamp"):
            pipeline_els.append("timestamp name=timestamp")
        if SOURCE_QUEUE_TIME > 0:
            queue_el = [
                "queue", "name=queue", "min-threshold-time=0",
                "max-size-buffers=0", "max-size-bytes=0",
                "max-size-time=%s" % (SOURCE_QUEUE_TIME * MS_TO_NS),
                "leaky=%s" % GST_QUEUE_LEAK_DOWNSTREAM
            ]
            pipeline_els += [" ".join(queue_el)]
        if encoder in ENCODER_NEEDS_AUDIOCONVERT or src_type in SOURCE_NEEDS_AUDIOCONVERT:
            pipeline_els += ["audioconvert"]
        if CUTTER_THRESHOLD > 0 and encoder not in ENCODER_CANNOT_USE_CUTTER and not fmt:
            pipeline_els.append(
                "cutter threshold=%.4f run-length=%i pre-length=%i leaky=false name=cutter"
                % (CUTTER_THRESHOLD, CUTTER_RUN_LENGTH * MS_TO_NS,
                   CUTTER_PRE_LENGTH * MS_TO_NS))
            if encoder in CUTTER_NEEDS_CONVERT:
                pipeline_els.append("audioconvert")
            if encoder in CUTTER_NEEDS_RESAMPLE:
                pipeline_els.append("audioresample")
        pipeline_els.append("volume name=volume volume=%s" % volume)
        if encoder:
            encoder_str = plugin_str(
                encoder, codec_options or get_encoder_default_options(encoder))
            pipeline_els.append(encoder_str)
        if fmt:
            fmt_str = plugin_str(fmt, MUXER_DEFAULT_OPTIONS.get(fmt, {}))
            pipeline_els.append(fmt_str)
        pipeline_els.append(APPSINK)
        if not self.setup_pipeline_and_bus(pipeline_els):
            return
        self.timestamp = self.pipeline.get_by_name("timestamp")
        self.volume = self.pipeline.get_by_name("volume")
        self.sink = self.pipeline.get_by_name("sink")
        if SOURCE_QUEUE_TIME > 0:
            self.queue = self.pipeline.get_by_name("queue")
        if self.queue:
            try:
                self.queue.set_property("silent", True)
            except Exception as e:
                log("cannot make queue silent: %s", e)
        self.sink.set_property("enable-last-sample", False)
        self.skipped_caps = set()
        if JITTER > 0:
            self.jitter_queue = Queue()
        #Gst 1.0:
        self.sink.connect("new-sample", self.on_new_sample)
        self.sink.connect("new-preroll", self.on_new_preroll)
        self.src = self.pipeline.get_by_name("src")
        for x in ("actual-buffer-time", "actual-latency-time"):
            try:
                gstlog("initial %s: %s", x, self.src.get_property(x))
            except Exception as e:
                gstlog("no %s property on %s: %s", x, self.src, e)
                self.buffer_latency = False
        #if the env vars have been set, try to honour the settings:
        global BUFFER_TIME, LATENCY_TIME
        if BUFFER_TIME > 0:
            if BUFFER_TIME < LATENCY_TIME:
                log.warn(
                    "Warning: latency (%ims) must be lower than the buffer time (%ims)",
                    LATENCY_TIME, BUFFER_TIME)
            else:
                log(
                    "latency tuning for %s, will try to set buffer-time=%i, latency-time=%i",
                    src_type, BUFFER_TIME, LATENCY_TIME)

                def settime(attr, v):
                    try:
                        cval = self.src.get_property(attr)
                        gstlog("default: %s=%i", attr, cval // 1000)
                        if v >= 0:
                            self.src.set_property(attr, v * 1000)
                            gstlog("overriding with: %s=%i", attr, v)
                    except Exception as e:
                        log.warn("source %s does not support '%s': %s",
                                 self.src_type, attr, e)

                settime("buffer-time", BUFFER_TIME)
                settime("latency-time", LATENCY_TIME)
        gen = generation.increase()
        if SAVE_TO_FILE is not None:
            parts = codec.split("+")
            if len(parts) > 1:
                filename = SAVE_TO_FILE + str(
                    gen) + "-" + parts[0] + ".%s" % parts[1]
            else:
                filename = SAVE_TO_FILE + str(gen) + ".%s" % codec
            self.file = open(filename, 'wb')
            log.info("saving %s stream to %s", codec, filename)

    def __repr__(self):
        return "SoundSource('%s' - %s)" % (self.pipeline_str, self.state)

    def cleanup(self):
        SoundPipeline.cleanup(self)
        self.src_type = ""
        self.sink = None
        self.caps = None
        f = self.file
        if f:
            self.file = None
            f.close()

    def get_info(self):
        info = SoundPipeline.get_info(self)
        if self.queue:
            info["queue"] = {
                "cur":
                self.queue.get_property("current-level-time") // MS_TO_NS
            }
        if CUTTER_THRESHOLD > 0 and (self.min_timestamp or self.max_timestamp):
            info["cutter.min-timestamp"] = self.min_timestamp
            info["cutter.max-timestamp"] = self.max_timestamp
        if self.buffer_latency:
            for x in ("actual-buffer-time", "actual-latency-time"):
                v = self.src.get_property(x)
                if v >= 0:
                    info[x] = v
        return info

    def do_parse_element_message(self, _message, name, props={}):
        if name == "cutter":
            above = props.get("above")
            ts = props.get("timestamp", 0)
            if above is False:
                self.max_timestamp = ts
                self.min_timestamp = 0
            elif above is True:
                self.max_timestamp = 0
                self.min_timestamp = ts
            if LOG_CUTTER:
                l = gstlog.info
            else:
                l = gstlog
            l("cutter message, above=%s, min-timestamp=%s, max-timestamp=%s",
              above, self.min_timestamp, self.max_timestamp)

    def on_new_preroll(self, _appsink):
        gstlog('new preroll')
        return 0

    def on_new_sample(self, _bus):
        #Gst 1.0
        sample = self.sink.emit("pull-sample")
        return self.emit_buffer(sample)

    def emit_buffer(self, sample):
        buf = sample.get_buffer()
        pts = normv(buf.pts)
        if self.min_timestamp > 0 and pts < self.min_timestamp:
            gstlog("cutter: skipping buffer with pts=%s (min-timestamp=%s)",
                   pts, self.min_timestamp)
            return 0
        elif self.max_timestamp > 0 and pts > self.max_timestamp:
            gstlog("cutter: skipping buffer with pts=%s (max-timestamp=%s)",
                   pts, self.max_timestamp)
            return 0
        size = buf.get_size()
        data = buf.extract_dup(0, size)
        duration = normv(buf.duration)
        metadata = {
            "timestamp": pts,
            "duration": duration,
        }
        if self.timestamp:
            delta = self.timestamp.get_property("delta")
            ts = (pts + delta) // 1000000  #ns to ms
            now = monotonic_time()
            latency = int(1000 * now) - ts
            #log.info("emit_buffer: delta=%i, pts=%i, ts=%s, time=%s, latency=%ims", delta, pts, ts, now, (latency//1000000))
            ts_info = {
                "ts": ts,
                "latency": latency,
            }
            metadata.update(ts_info)
            self.info.update(ts_info)
        if pts == -1 and duration == -1 and BUNDLE_METADATA and len(
                self.pending_metadata) < 10:
            self.pending_metadata.append(data)
            return 0
        return self._emit_buffer(data, metadata)

    def _emit_buffer(self, data, metadata={}):
        if self.stream_compressor and data:
            cdata = compressed_wrapper("sound",
                                       data,
                                       level=9,
                                       zlib=False,
                                       lz4=(self.stream_compressor == "lz4"),
                                       lzo=(self.stream_compressor == "lzo"),
                                       can_inline=True)
            if len(cdata) < len(data) * 90 // 100:
                log("compressed using %s from %i bytes down to %i bytes",
                    self.stream_compressor, len(data), len(cdata))
                metadata["compress"] = self.stream_compressor
                data = cdata
            else:
                log(
                    "skipped inefficient %s stream compression: %i bytes down to %i bytes",
                    self.stream_compressor, len(data), len(cdata))
        f = self.file
        if f:
            for x in self.pending_metadata:
                f.write(x)
            if data:
                f.write(data)
            f.flush()
        if self.state == "stopped":
            #don't bother
            return 0
        if JITTER > 0:
            #will actually emit the buffer after a random delay
            if self.jitter_queue.empty():
                #queue was empty, schedule a timer to flush it
                from random import randint
                jitter = randint(1, JITTER)
                self.timeout_add(jitter, self.flush_jitter_queue)
                log("emit_buffer: will flush jitter queue in %ims", jitter)
            for x in self.pending_metadata:
                self.jitter_queue.put((x, {}))
            self.pending_metadata = []
            self.jitter_queue.put((data, metadata))
            return 0
        log("emit_buffer data=%s, len=%i, metadata=%s", type(data), len(data),
            metadata)
        return self.do_emit_buffer(data, metadata)

    def caps_to_dict(self, caps):
        if not caps:
            return {}
        d = {}
        try:
            for cap in caps:
                name = cap.get_name()
                capd = {}
                for k in cap.keys():
                    v = cap[k]
                    if type(v) in (str, int):
                        capd[k] = cap[k]
                    elif k not in self.skipped_caps:
                        log("skipping %s cap key %s=%s of type %s", name, k, v,
                            type(v))
                d[name] = capd
        except Exception as e:
            log.error("Error parsing '%s':", caps)
            log.error(" %s", e)
        return d

    def flush_jitter_queue(self):
        while not self.jitter_queue.empty():
            d, m = self.jitter_queue.get(False)
            self.do_emit_buffer(d, m)

    def do_emit_buffer(self, data, metadata={}):
        self.inc_buffer_count()
        self.inc_byte_count(len(data))
        for x in self.pending_metadata:
            self.inc_buffer_count()
            self.inc_byte_count(len(x))
        metadata["time"] = int(monotonic_time() * 1000)
        self.idle_emit("new-buffer", data, metadata, self.pending_metadata)
        self.pending_metadata = []
        self.emit_info()
        return 0
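
_emit_buffer() above only keeps the compressed form of a buffer when it is at
least 10% smaller than the original, so already-dense audio data is sent as-is
rather than paying for useless decompression on the other end. The same guard,
sketched with zlib standing in for xpra's compressed_wrapper (an assumption;
the real wrapper also handles lz4 and lzo and can inline small buffers):

import zlib

def maybe_compress(data, metadata):
    cdata = zlib.compress(data, 9)
    if len(cdata) < len(data)*90//100:
        #keep the compressed form only if it saves at least 10%:
        metadata["compress"] = "zlib"
        return cdata
    return data

meta = {}
payload = maybe_compress(b"\x00"*4096, meta)    #silence compresses very well
print(len(payload), meta)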
Example #8
File: protocol.py Project: svn2github/Xpra
class Protocol(object):
    CONNECTION_LOST = "connection-lost"
    GIBBERISH = "gibberish"
    INVALID = "invalid"

    def __init__(self, scheduler, conn, process_packet_cb, get_packet_cb=None):
        """
            You must call this constructor and source_has_more() from the main thread.
        """
        assert scheduler is not None
        assert conn is not None
        self.timeout_add = scheduler.timeout_add
        self.idle_add = scheduler.idle_add
        self._conn = conn
        if FAKE_JITTER > 0:
            from xpra.net.fake_jitter import FakeJitter

            fj = FakeJitter(self.timeout_add, process_packet_cb)
            self._process_packet_cb = fj.process_packet_cb
        else:
            self._process_packet_cb = process_packet_cb
        self._write_queue = Queue(1)
        self._read_queue = Queue(20)
        self._read_queue_put = self._read_queue.put
        # Invariant: if .source is None, then _source_has_more == False
        self._get_packet_cb = get_packet_cb
        # counters:
        self.input_stats = {}
        self.input_packetcount = 0
        self.input_raw_packetcount = 0
        self.output_stats = {}
        self.output_packetcount = 0
        self.output_raw_packetcount = 0
        # initial value which may get increased by client/server after handshake:
        self.max_packet_size = 256 * 1024
        self.abs_max_packet_size = 256 * 1024 * 1024
        self.large_packets = ["hello"]
        self.send_aliases = {}
        self.receive_aliases = {}
        self._log_stats = None  # None here means auto-detect
        self._closed = False
        self.encoder = "none"
        self._encoder = self.noencode
        self.compressor = "none"
        self._compress = compression.nocompress
        self.compression_level = 0
        self.cipher_in = None
        self.cipher_in_name = None
        self.cipher_in_block_size = 0
        self.cipher_out = None
        self.cipher_out_name = None
        self.cipher_out_block_size = 0
        self._write_lock = Lock()
        from xpra.daemon_thread import make_daemon_thread

        self._write_thread = make_daemon_thread(self._write_thread_loop, "write")
        self._read_thread = make_daemon_thread(self._read_thread_loop, "read")
        self._read_parser_thread = make_daemon_thread(self._read_parse_thread_loop, "parse")
        self._write_format_thread = make_daemon_thread(self._write_format_thread_loop, "format")
        self._source_has_more = threading.Event()

    STATE_FIELDS = (
        "max_packet_size",
        "large_packets",
        "send_aliases",
        "receive_aliases",
        "cipher_in",
        "cipher_in_name",
        "cipher_in_block_size",
        "cipher_out",
        "cipher_out_name",
        "cipher_out_block_size",
        "compression_level",
        "encoder",
        "compressor",
    )

    def save_state(self):
        state = {}
        for x in Protocol.STATE_FIELDS:
            state[x] = getattr(self, x)
        return state

    def restore_state(self, state):
        assert state is not None
        for x in Protocol.STATE_FIELDS:
            assert x in state, "field %s is missing" % x
            setattr(self, x, state[x])
        # special handling for compressor / encoder which are named objects:
        self.enable_compressor(self.compressor)
        self.enable_encoder(self.encoder)

    def wait_for_io_threads_exit(self, timeout=None):
        for t in (self._read_thread, self._write_thread):
            t.join(timeout)
        exited = True
        for t in (self._read_thread, self._write_thread):
            if t.isAlive():
                log.warn("%s thread of %s has not yet exited (timeout=%s)", t.name, self._conn, timeout)
                exited = False
                break
        return exited

    def set_packet_source(self, get_packet_cb):
        self._get_packet_cb = get_packet_cb

    def set_cipher_in(self, ciphername, iv, password, key_salt, iterations):
        if self.cipher_in_name != ciphername:
            log.info("receiving data using %s encryption", ciphername)
            self.cipher_in_name = ciphername
        log("set_cipher_in%s", (ciphername, iv, password, key_salt, iterations))
        self.cipher_in, self.cipher_in_block_size = get_cipher(ciphername, iv, password, key_salt, iterations)

    def set_cipher_out(self, ciphername, iv, password, key_salt, iterations):
        if self.cipher_out_name != ciphername:
            log.info("sending data using %s encryption", ciphername)
            self.cipher_out_name = ciphername
        log("set_cipher_out%s", (ciphername, iv, password, key_salt, iterations))
        self.cipher_out, self.cipher_out_block_size = get_cipher(ciphername, iv, password, key_salt, iterations)

    def __repr__(self):
        return "Protocol(%s)" % self._conn

    def get_threads(self):
        return [
            x
            for x in [self._write_thread, self._read_thread, self._read_parser_thread, self._write_format_thread]
            if x is not None
        ]

    def get_info(self, alias_info=True):
        info = {
            "input.packetcount": self.input_packetcount,
            "input.raw_packetcount": self.input_raw_packetcount,
            "input.cipher": self.cipher_in_name or "",
            "output.packetcount": self.output_packetcount,
            "output.raw_packetcount": self.output_raw_packetcount,
            "output.cipher": self.cipher_out_name or "",
            "large_packets": self.large_packets,
            "compression_level": self.compression_level,
            "max_packet_size": self.max_packet_size,
        }
        updict(info, "input.count", self.input_stats)
        updict(info, "output.count", self.output_stats)
        c = self._compress
        if c:
            info["compressor"] = compression.get_compressor_name(self._compress)
        e = self._encoder
        if e:
            if self._encoder == self.noencode:
                info["encoder"] = "noencode"
            else:
                info["encoder"] = packet_encoding.get_encoder_name(self._encoder)
        if alias_info:
            for k, v in self.send_aliases.items():
                info["send_alias." + str(k)] = v
                info["send_alias." + str(v)] = k
            for k, v in self.receive_aliases.items():
                info["receive_alias." + str(k)] = v
                info["receive_alias." + str(v)] = k
        c = self._conn
        if c:
            try:
                info.update(self._conn.get_info())
            except:
                log.error("error collecting connection information on %s", self._conn, exc_info=True)
        info["has_more"] = self._source_has_more.is_set()
        for t in (self._write_thread, self._read_thread, self._read_parser_thread, self._write_format_thread):
            if t:
                info["thread.%s" % t.name] = t.is_alive()
        return info

    def start(self):
        def do_start():
            if not self._closed:
                self._write_thread.start()
                self._read_thread.start()
                self._read_parser_thread.start()
                self._write_format_thread.start()

        self.idle_add(do_start)

    def send_now(self, packet):
        if self._closed:
            log("send_now(%s ...) connection is closed already, not sending", packet[0])
            return
        log("send_now(%s ...)", packet[0])
        assert self._get_packet_cb == None, (
            "cannot use send_now when a packet source exists! (set to %s)" % self._get_packet_cb
        )

        def packet_cb():
            self._get_packet_cb = None
            return (packet,)

        self._get_packet_cb = packet_cb
        self.source_has_more()

    def source_has_more(self):
        self._source_has_more.set()

    def _write_format_thread_loop(self):
        log("write_format_thread_loop starting")
        try:
            while not self._closed:
                self._source_has_more.wait()
                if self._closed:
                    return
                self._source_has_more.clear()
                self._add_packet_to_queue(*self._get_packet_cb())
        except:
            self._internal_error("error in network packet write/format", True)
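
    #illustration (added): the format thread is an event-driven consumer:
    #send_now() and source_has_more() set the event, the loop wakes, clears it,
    #and drains one batch from _get_packet_cb(); _add_packet_to_queue() will
    #re-set the event when the callback reports has_more=True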

    def _add_packet_to_queue(self, packet, start_send_cb=None, end_send_cb=None, has_more=False):
        if has_more:
            self._source_has_more.set()
        if packet is None:
            return
        log("add_packet_to_queue(%s ...)", packet[0])
        chunks, proto_flags = self.encode(packet)
        with self._write_lock:
            if self._closed:
                return
            self._add_chunks_to_queue(chunks, proto_flags, start_send_cb, end_send_cb)

    def _add_chunks_to_queue(self, chunks, proto_flags, start_send_cb=None, end_send_cb=None):
        """ the write_lock must be held when calling this function """
        counter = 0
        items = []
        for index, level, data in chunks:
            scb, ecb = None, None
            # fire the start_send_callback just before the first packet is processed:
            if counter == 0:
                scb = start_send_cb
            # fire the end_send callback when the last packet (index==0) makes it out:
            if index == 0:
                ecb = end_send_cb
            payload_size = len(data)
            actual_size = payload_size
            if self.cipher_out:
                proto_flags |= FLAGS_CIPHER
                # note: since we are padding: l!=len(data)
                padding = (self.cipher_out_block_size - len(data) % self.cipher_out_block_size) * " "
                if len(padding) == 0:
                    padded = data
                else:
                    padded = data + padding
                actual_size = payload_size + len(padding)
                assert len(padded) == actual_size
                data = self.cipher_out.encrypt(padded)
                assert len(data) == actual_size
                log("sending %s bytes encrypted with %s padding", payload_size, len(padding))
            if proto_flags & FLAGS_NOHEADER:
                # for plain/text packets (ie: gibberish response)
                items.append((data, scb, ecb))
            elif actual_size < PACKET_JOIN_SIZE:
                if type(data) not in JOIN_TYPES:
                    data = bytes(data)
                header_and_data = pack_header(proto_flags, level, index, payload_size) + data
                items.append((header_and_data, scb, ecb))
            else:
                header = pack_header(proto_flags, level, index, payload_size)
                items.append((header, scb, None))
                items.append((strtobytes(data), None, ecb))
            counter += 1
        self._write_queue.put(items)
        self.output_packetcount += 1
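
        #worked example (added illustration): with a 16-byte cipher block and a
        #100-byte chunk, padding = (16 - 100 % 16) = 12 spaces, so the encrypted
        #payload grows to 112 bytes while the header still advertises
        #payload_size=100, letting the receiver strip the padding after decryption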

    def verify_packet(self, packet):
        """ look for None values which may have caused the packet to fail encoding """
        if type(packet) != list:
            return
        assert len(packet) > 0
        tree = ["'%s' packet" % packet[0]]
        self.do_verify_packet(tree, packet)

    def do_verify_packet(self, tree, packet):
        def err(msg):
            log.error("%s in %s", msg, "->".join(tree))

        def new_tree(append):
            nt = tree[:]
            nt.append(append)
            return nt

        if packet is None:
            return err("None value")
        if type(packet) == list:
            for i, x in enumerate(packet):
                self.do_verify_packet(new_tree("[%s]" % i), x)
        elif type(packet) == dict:
            for k, v in packet.items():
                self.do_verify_packet(new_tree("key for value='%s'" % str(v)), k)
                self.do_verify_packet(new_tree("value for key='%s'" % str(k)), v)

    def enable_default_encoder(self):
        opts = packet_encoding.get_enabled_encoders()
        assert len(opts) > 0, "no packet encoders available!"
        self.enable_encoder(opts[0])

    def enable_encoder_from_caps(self, caps):
        opts = packet_encoding.get_enabled_encoders(order=packet_encoding.PERFORMANCE_ORDER)
        log("enable_encoder_from_caps(..) options=%s", opts)
        for e in opts:
            if caps.boolget(e, e == "bencode"):
                self.enable_encoder(e)
                return True
        log.error("no matching packet encoder found!")
        return False

    def enable_encoder(self, e):
        self._encoder = packet_encoding.get_encoder(e)
        self.encoder = e
        log("enable_encoder(%s): %s", e, self._encoder)

    def enable_default_compressor(self):
        opts = compression.get_enabled_compressors()
        if len(opts) > 0:
            self.enable_compressor(opts[0])
        else:
            self.enable_compressor("none")

    def enable_compressor_from_caps(self, caps):
        if self.compression_level == 0:
            self.enable_compressor("none")
            return
        opts = compression.get_enabled_compressors(order=compression.PERFORMANCE_ORDER)
        log("enable_compressor_from_caps(..) options=%s", opts)
        for c in opts:  # ie: [zlib, lz4, lzo]
            if caps.boolget(c):
                self.enable_compressor(c)
                return
        log.warn("compression disabled: no matching compressor found")
        self.enable_compressor("none")

    def enable_compressor(self, compressor):
        self._compress = compression.get_compressor(compressor)
        self.compressor = compressor
        log("enable_compressor(%s): %s", compressor, self._compress)

    def noencode(self, data):
        # just send data as a string for clients that don't understand xpra packet format:
        if sys.version_info[0] >= 3:
            import codecs

            def b(x):
                if type(x) == bytes:
                    return x
                return codecs.latin_1_encode(x)[0]

        else:

            def b(x):  # @DuplicatedSignature
                return x

        return b(": ".join(str(x) for x in data) + "\n"), FLAGS_NOHEADER

    def encode(self, packet_in):
        """
        Given a packet (tuple or list of items), converts it for the wire.
        This method returns all the binary packets to send, as an array of:
        (index, compression_level and compression flags, binary_data)
        The index, if positive indicates the item to populate in the packet
        whose index is zero.
        ie: ["blah", [large binary data], "hello", 200]
        may get converted to:
        [
            (1, compression_level, [large binary data now zlib compressed]),
            (0,                 0, bencoded/rencoded(["blah", '', "hello", 200]))
        ]
        """
        packets = []
        packet = list(packet_in)
        level = self.compression_level
        size_check = LARGE_PACKET_SIZE
        min_comp_size = 378
        for i in range(1, len(packet)):
            item = packet[i]
            ti = type(item)
            if ti in (int, long, bool, dict, list, tuple):
                continue
            l = len(item)
            if ti == Uncompressed:
                # this is a marker used to tell us we should compress it now
                # (used by the client for clipboard data)
                item = item.compress()
                packet[i] = item
                ti = type(item)
                # (it may now be a "Compressed" item and be processed further)
            if ti in (Compressed, LevelCompressed):
                # already compressed data (usually pixels, cursors, etc)
                if not item.can_inline or l > INLINE_SIZE:
                    il = 0
                    if ti == LevelCompressed:
                        # unlike Compressed (usually pixels, decompressed in the paint thread),
                        # LevelCompressed is decompressed by the network layer
                        # so we must tell it how to do that and pass the level flag
                        il = item.level
                    packets.append((i, il, item.data))
                    packet[i] = ""
                else:
                    # data is small enough, inline it:
                    packet[i] = item.data
                    min_comp_size += l
                    size_check += l
            elif ti in (str, bytes) and level > 0 and l > LARGE_PACKET_SIZE:
                log.warn(
                    "found a large uncompressed item in packet '%s' at position %s: %s bytes", packet[0], i, len(item)
                )
                # add new binary packet with large item:
                cl, cdata = self._compress(item, level)
                packets.append((i, cl, cdata))
                # replace this item with an empty string placeholder:
                packet[i] = ""
            elif ti not in (str, bytes):
                log.warn("unexpected data type %s in %s packet: %s", ti, packet[0], repr_ellipsized(item))
        # now the main packet (or what is left of it):
        packet_type = packet[0]
        self.output_stats[packet_type] = self.output_stats.get(packet_type, 0) + 1
        if USE_ALIASES and self.send_aliases and packet_type in self.send_aliases:
            # replace the packet type with the alias:
            packet[0] = self.send_aliases[packet_type]
        try:
            main_packet, proto_version = self._encoder(packet)
        except Exception as e:
            if self._closed:
                return [], 0
            log.error("failed to encode packet: %s", packet, exc_info=True)
            # make the error a bit nicer to parse: undo aliases:
            packet[0] = packet_type
            self.verify_packet(packet)
            raise e
        if len(main_packet) > size_check and packet_in[0] not in self.large_packets:
            log.warn(
                "found large packet (%s bytes): %s, argument types:%s, sizes: %s, packet head=%s",
                len(main_packet),
                packet_in[0],
                [type(x) for x in packet[1:]],
                [len(str(x)) for x in packet[1:]],
                repr_ellipsized(packet),
            )
        # compress, but don't bother for small packets:
        if level > 0 and len(main_packet) > min_comp_size:
            cl, cdata = self._compress(main_packet, level)
            packets.append((0, cl, cdata))
        else:
            packets.append((0, 0, main_packet))
        return packets, proto_version
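
    # A hedged example of the return value: a hypothetical "draw" packet whose
    # item at index 4 is a large Compressed pixel buffer might yield:
    #   [(4, 0, <raw pixel data>), (0, 0, <encoded main packet, '' at index 4>)]
    # each tuple becomes one wire packet; its index and compression flags travel
    # in the 8-byte header and are reassembled in do_read_parse_thread_loop().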

    def set_compression_level(self, level):
        # this may be used next time encode() is called
        assert 0 <= level <= 10, "invalid compression level: %s (must be between 0 and 10)" % level
        self.compression_level = level

    def _io_thread_loop(self, name, callback):
        try:
            log("io_thread_loop(%s, %s) loop starting", name, callback)
            while not self._closed:
                callback()
            log("io_thread_loop(%s, %s) loop ended, closed=%s", name, callback, self._closed)
        except ConnectionClosedException as e:
            if not self._closed:
                self._internal_error("%s connection %s closed: %s" % (name, self._conn, e))
        except (OSError, IOError, socket_error) as e:
            if not self._closed:
                self._internal_error(
                    "%s connection %s reset: %s" % (name, self._conn, e), exc_info=e.args[0] not in ABORT
                )
        except:
            # can happen during close(), in which case we just ignore:
            if not self._closed:
                log.error("%s error on %s", name, self._conn, exc_info=True)
                self.close()

    def _write_thread_loop(self):
        self._io_thread_loop("write", self._write)

    def _write(self):
        items = self._write_queue.get()
        # Used to signal that we should exit:
        if items is None:
            log("write thread: empty marker, exiting")
            self.close()
            return
        for buf, start_cb, end_cb in items:
            con = self._conn
            if not con:
                return
            if start_cb:
                try:
                    start_cb(con.output_bytecount)
                except:
                    if not self._closed:
                        log.error("error on %s", start_cb, exc_info=True)
            while buf and not self._closed:
                written = con.write(buf)
                if written:
                    buf = buf[written:]
                    self.output_raw_packetcount += 1
            if end_cb:
                try:
                    end_cb(con.output_bytecount)  # use the local reference: self._conn may have been cleared by close()
                except:
                    if not self._closed:
                        log.error("error on %s", end_cb, exc_info=True)

    def _read_thread_loop(self):
        self._io_thread_loop("read", self._read)

    def _read(self):
        buf = self._conn.read(READ_BUFFER_SIZE)
        # log("read thread: got data of size %s: %s", len(buf), repr_ellipsized(buf))
        # add to the read queue (or whatever takes its place - see steal_connection)
        self._read_queue_put(buf)
        if not buf:
            log("read thread: eof")
            self.close()
            return
        self.input_raw_packetcount += 1

    def _internal_error(self, message="", exc_info=False):
        log.error("internal error: %s", message, exc_info=exc_info)
        self.idle_add(self._connection_lost, message)

    def _connection_lost(self, message="", exc_info=False):
        log("connection lost: %s", message, exc_info=exc_info)
        self.close()
        return False

    def invalid(self, msg, data):
        self.idle_add(self._process_packet_cb, self, [Protocol.INVALID, msg, data])
        # Then hang up:
        self.timeout_add(1000, self._connection_lost, msg)

    def gibberish(self, msg, data):
        self.idle_add(self._process_packet_cb, self, [Protocol.GIBBERISH, msg, data])
        # Then hang up:
        self.timeout_add(1000, self._connection_lost, msg)

    # delegates to invalid_header()
    # (so this can more easily be intercepted and overridden,
    # see tcp-proxy)
    def _invalid_header(self, data):
        self.invalid_header(self, data)

    def invalid_header(self, proto, data):
        err = "invalid packet header: '%s'" % binascii.hexlify(data[:8])
        if len(data) > 1:
            err += " read buffer=%s" % repr_ellipsized(data)
        self.gibberish(err, data)
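
    # For example, a plain HTTP client hitting this port would produce a header
    # starting with '47455420' (the hex for "GET "), which fails the 'P' magic
    # byte check in do_read_parse_thread_loop() and ends up here.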

    def _read_parse_thread_loop(self):
        log("read_parse_thread_loop starting")
        try:
            self.do_read_parse_thread_loop()
        except:
            self._internal_error("error in network packet reading/parsing", True)

    def do_read_parse_thread_loop(self):
        """
            Process the individual network packets placed in _read_queue.
            Concatenate the raw packet data, then try to parse it.
            Extract the individual packets from the potentially large buffer,
            saving the rest of the buffer for later; optionally decompress the data
            and re-construct one python-object packet from potentially multiple chunks (see packet_index).
            The 8-byte packet header gives us the packet index, packet size and compression flags.
            The actual processing of the packet is done via the process_packet_cb callback;
            since it runs on this parsing thread, any calls that need to be made
            from the UI thread must be dispatched via a callback (usually 'idle_add').
        """
        read_buffer = None
        payload_size = -1
        padding = None
        packet_index = 0
        compression_level = False
        packet = None
        raw_packets = {}
        while not self._closed:
            buf = self._read_queue.get()
            if not buf:
                log("read thread: empty marker, exiting")
                self.idle_add(self.close)
                return
            if read_buffer:
                read_buffer = read_buffer + buf
            else:
                read_buffer = buf
            bl = len(read_buffer)
            while not self._closed:
                packet = None
                bl = len(read_buffer)
                if bl <= 0:
                    break
                if payload_size < 0:
                    if read_buffer[0] not in ("P", ord("P")):
                        self._invalid_header(read_buffer)
                        return
                    if bl < 8:
                        break  # packet still too small
                    # packet format: struct.pack('cBBBL', ...) - 8 bytes
                    _, protocol_flags, compression_level, packet_index, data_size = unpack_header(read_buffer[:8])
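                    # a hedged reading of the 'cBBBL' layout above:
                    #   byte 0:    b'P' magic
                    #   byte 1:    protocol flags (packet encoder, cipher, ...)
                    #   byte 2:    compression level (0 = uncompressed)
                    #   byte 3:    packet index (0 = main packet)
                    #   bytes 4-7: payload size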

                    # sanity check size (will often fail if not an xpra client):
                    if data_size > self.abs_max_packet_size:
                        self._invalid_header(read_buffer)
                        return

                    bl = len(read_buffer) - 8
                    if protocol_flags & FLAGS_CIPHER:
                        if self.cipher_in_block_size == 0 or not self.cipher_in_name:
                            log.warn(
                                "received cipher block but we don't have a cipher to decrypt it with, not an xpra client?"
                            )
                            self._invalid_header(read_buffer)
                            return
                        padding = (self.cipher_in_block_size - data_size % self.cipher_in_block_size) * " "
                        payload_size = data_size + len(padding)
                    else:
                        # no cipher, no padding:
                        padding = None
                        payload_size = data_size
                    assert payload_size > 0
                    read_buffer = read_buffer[8:]

                    if payload_size > self.max_packet_size:
                        # this packet is seemingly too big, but check again from the main UI thread
                        # this gives 'set_max_packet_size' a chance to run from "hello"
                        def check_packet_size(size_to_check, packet_header):
                            if self._closed:
                                return False
                            log(
                                "check_packet_size(%s, 0x%s) limit is %s",
                                size_to_check,
                                repr_ellipsized(packet_header),
                                self.max_packet_size,
                            )
                            if size_to_check > self.max_packet_size:
                                msg = "packet size requested is %s but maximum allowed is %s" % (
                                    size_to_check,
                                    self.max_packet_size,
                                )
                                self.invalid(msg, packet_header)
                            return False

                        self.timeout_add(1000, check_packet_size, payload_size, read_buffer[:32])

                if bl < payload_size:
                    # incomplete packet, wait for the rest to arrive
                    break

                # chop this packet from the buffer:
                if len(read_buffer) == payload_size:
                    raw_string = read_buffer
                    read_buffer = ""
                else:
                    raw_string = read_buffer[:payload_size]
                    read_buffer = read_buffer[payload_size:]
                # decrypt if needed:
                data = raw_string
                if self.cipher_in and protocol_flags & FLAGS_CIPHER:
                    log("received %s encrypted bytes with %s padding", payload_size, len(padding))
                    data = self.cipher_in.decrypt(raw_string)
                    if padding:

                        def debug_str(s):
                            try:
                                return list(bytearray(s))
                            except:
                                return list(str(s))

                        if not data.endswith(padding):
                            log(
                                "decryption failed: string does not end with '%s': %s (%s) -> %s (%s)",
                                padding,
                                debug_str(raw_string),
                                type(raw_string),
                                debug_str(data),
                                type(data),
                            )
                            self._internal_error("encryption error (wrong key?)")
                            return
                        data = data[: -len(padding)]
                # uncompress if needed:
                if compression_level > 0:
                    try:
                        data = decompress(data, compression_level)
                    except InvalidCompressionException as e:
                        self.invalid("invalid compression: %s" % e, data)
                        return
                    except Exception as e:
                        ctype = compression.get_compression_type(compression_level)
                        log("%s packet decompression failed", ctype, exc_info=True)
                        msg = "%s packet decompression failed" % ctype
                        if self.cipher_in:
                            msg += " (invalid encryption key?)"
                        else:
                            msg += " %s" % e
                        return self.gibberish(msg, data)

                if self.cipher_in and not (protocol_flags & FLAGS_CIPHER):
                    self.invalid("unencrypted packet dropped", data)
                    return

                if self._closed:
                    return
                if packet_index > 0:
                    # raw packet, store it and continue:
                    raw_packets[packet_index] = data
                    payload_size = -1
                    packet_index = 0
                    if len(raw_packets) >= 4:
                        self.invalid("too many raw packets: %s" % len(raw_packets), data)
                        return
                    continue
                # final packet (packet_index==0), decode it:
                try:
                    packet = decode(data, protocol_flags)
                except InvalidPacketEncodingException as e:
                    self.invalid("invalid packet encoding: %s" % e, data)
                    return
                except ValueError as e:
                    etype = packet_encoding.get_packet_encoding_type(protocol_flags)
                    log.error("failed to parse %s packet: %s", etype, e, exc_info=not self._closed)
                    if self._closed:
                        return
                    log("failed to parse %s packet: %s", etype, binascii.hexlify(data))
                    msg = "packet index=%s, packet size=%s, buffer size=%s, error=%s" % (
                        packet_index,
                        payload_size,
                        bl,
                        e,
                    )
                    self.gibberish("failed to parse %s packet: %s" % (etype, msg), data)
                    return

                if self._closed:
                    return
                payload_size = -1
                padding = None
                # add any raw packets back into it:
                if raw_packets:
                    for index, raw_data in raw_packets.items():
                        # replace placeholder with the raw_data packet data:
                        packet[index] = raw_data
                    raw_packets = {}

                packet_type = packet[0]
                if self.receive_aliases and type(packet_type) == int and packet_type in self.receive_aliases:
                    packet_type = self.receive_aliases.get(packet_type)
                    packet[0] = packet_type
                self.input_stats[packet_type] = self.input_stats.get(packet_type, 0) + 1

                self.input_packetcount += 1
                log("processing packet %s", packet_type)
                self._process_packet_cb(self, packet)
                packet = None

    def flush_then_close(self, last_packet, done_callback=None):
        """ Note: this is best effort only
            the packet may not get sent.

            We try to get the write lock,
            we try to wait for the write queue to flush
            we queue our last packet,
            we wait again for the queue to flush,
            then no matter what, we close the connection and stop the threads.
        """

        def done():
            if done_callback:
                done_callback()

        if self._closed:
            log("flush_then_close: already closed")
            return done()

        def wait_for_queue(timeout=10):
            # IMPORTANT: if we are here, we have the write lock held!
            if not self._write_queue.empty():
                # write queue still has stuff in it..
                if timeout <= 0:
                    log("flush_then_close: queue still busy, closing without sending the last packet")
                    self._write_lock.release()
                    self.close()
                    done()
                else:
                    log("flush_then_close: still waiting for queue to flush")
                    self.timeout_add(100, wait_for_queue, timeout - 1)
            else:
                log("flush_then_close: queue is now empty, sending the last packet and closing")
                chunks, proto_flags = self.encode(last_packet)

                def close_and_release():
                    log("flush_then_close: wait_for_packet_sent() close_and_release()")
                    self.close()
                    try:
                        self._write_lock.release()
                    except:
                        pass
                    done()

                def wait_for_packet_sent():
                    log(
                        "flush_then_close: wait_for_packet_sent() queue.empty()=%s, closed=%s",
                        self._write_queue.empty(),
                        self._closed,
                    )
                    if self._write_queue.empty() or self._closed:
                        # it got sent, we're done!
                        close_and_release()
                        return False
                    return not self._closed  # run until we manage to close (here or via the timeout)

                def packet_queued(*args):
                    # if we're here, we have the lock and the packet is in the write queue
                    log("flush_then_close: packet_queued() closed=%s", self._closed)
                    if wait_for_packet_sent():
                        # check every 100ms
                        self.timeout_add(100, wait_for_packet_sent)

                self._add_chunks_to_queue(chunks, proto_flags, start_send_cb=None, end_send_cb=packet_queued)
                # just in case wait_for_packet_sent never fires:
                self.timeout_add(5 * 1000, close_and_release)

        def wait_for_write_lock(timeout=100):
            if not self._write_lock.acquire(False):
                if timeout <= 0:
                    log("flush_then_close: timeout waiting for the write lock")
                    self.close()
                    done()
                else:
                    log("flush_then_close: write lock is busy, will retry %s more times", timeout)
                    self.timeout_add(10, wait_for_write_lock, timeout - 1)
            else:
                log("flush_then_close: acquired the write lock")
                # we have the write lock - we MUST free it!
                wait_for_queue()

        # normal codepath:
        # -> wait_for_write_lock
        # -> wait_for_queue
        # -> _add_chunks_to_queue
        # -> packet_queued
        # -> wait_for_packet_sent
        # -> close_and_release
        log("flush_then_close: wait_for_write_lock()")
        wait_for_write_lock()

    def close(self):
        log("close() closed=%s", self._closed)
        if self._closed:
            return
        self._closed = True
        self.idle_add(self._process_packet_cb, self, [Protocol.CONNECTION_LOST])
        if self._conn:
            try:
                self._conn.close()
                if self._log_stats is None and self._conn.input_bytecount == 0 and self._conn.output_bytecount == 0:
                    # no data sent or received, skip logging of stats:
                    self._log_stats = False
                if self._log_stats:
                    from xpra.simple_stats import std_unit, std_unit_dec

                    log.info(
                        "connection closed after %s packets received (%s bytes) and %s packets sent (%s bytes)",
                        std_unit(self.input_packetcount),
                        std_unit_dec(self._conn.input_bytecount),
                        std_unit(self.output_packetcount),
                        std_unit_dec(self._conn.output_bytecount),
                    )
            except:
                log.error("error closing %s", self._conn, exc_info=True)
            self._conn = None
        self.terminate_queue_threads()
        self.idle_add(self.clean)

    def steal_connection(self, read_callback=None):
        # so we can re-use this connection somewhere else
        # (frees all protocol threads and resources)
        # Note: this method can only be used with non-blocking sockets,
        # and if more than one packet can arrive, the read_callback should be used
        # to ensure that no packets get lost.
        # The caller must call wait_for_io_threads_exit() to ensure that this
        # class is no longer reading from the connection before it can re-use it
        assert not self._closed
        if read_callback:
            self._read_queue_put = read_callback
        conn = self._conn
        self._closed = True
        self._conn = None
        if conn:
            # this ensures that we exit the untilConcludes() read/write loop
            conn.set_active(False)
        self.terminate_queue_threads()
        return conn
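
    # A hedged usage sketch (names illustrative): pass a queue's put method so
    # that data read before the threads stop is preserved, then wait for the
    # I/O threads to exit before re-using the connection:
    #   pending = Queue()
    #   conn = proto.steal_connection(pending.put)
    #   proto.wait_for_io_threads_exit(timeout)
    #   conn.set_active(True)  # re-enable reads/writes on the stolen connection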

    def clean(self):
        # clear all references to ensure we can get garbage collected quickly:
        self._get_packet_cb = None
        self._encoder = None
        self._write_thread = None
        self._read_thread = None
        self._read_parser_thread = None
        self._write_format_thread = None
        self._process_packet_cb = None

    def terminate_queue_threads(self):
        log("terminate_queue_threads()")
        # the format thread will exit since closed is set too:
        self._source_has_more.set()
        # make the threads exit by adding the empty marker:
        exit_queue = Queue()
        for _ in range(10):  # just 2 should be enough!
            exit_queue.put(None)
        try:
            owq = self._write_queue
            self._write_queue = exit_queue
            owq.put_nowait(None)
        except:
            pass
        try:
            orq = self._read_queue
            self._read_queue = exit_queue
            orq.put_nowait(None)
        except:
            pass