コード例 #1
0
ファイル: test_protocol_base.py プロジェクト: svn2github/Xpra
 def new_connection(self, *args):
     """
     Accept one pending connection from self.listener and start
     a Protocol instance on it.

     The extra positional arguments are only logged.
     Always returns True.
     """
     log.info("new_connection(%s)", args)
     accepted = self.listener.accept()
     conn, peer = accepted
     log.info("new_connection(%s) sock=%s, address=%s", args, conn, peer)
     #no timeout, blocking mode:
     conn.settimeout(None)
     conn.setblocking(1)
     target = str(peer)+"server"
     socket_connection = makeSocketConnection(conn, target)
     proto = Protocol(gobject, socket_connection, self.process_packet)
     proto.salt = None
     proto.set_compression_level(1)
     proto.start()
     return True
コード例 #2
0
 def new_connection(self, *args):
     """
     Accept one pending connection from self.listener and start
     a Protocol instance on it.

     The extra positional arguments are only logged.
     Always returns True.
     """
     log.info("new_connection(%s)", args)
     accepted = self.listener.accept()
     conn, peer = accepted
     log.info("new_connection(%s) sock=%s, address=%s", args, conn, peer)
     #no timeout, blocking mode:
     conn.settimeout(None)
     conn.setblocking(1)
     target = str(peer) + "server"
     socket_connection = makeSocketConnection(conn, target)
     proto = Protocol(socket_connection, self.process_packet)
     proto.salt = None
     proto.set_compression_level(1)
     proto.start()
     return True
コード例 #3
0
class ProxyInstanceProcess(Process):

    def __init__(self, uid, gid, env_options, session_options, socket_dir,
                 video_encoder_modules, csc_modules,
                 client_conn, disp_desc, client_state, cipher, encryption_key, server_conn, caps, message_queue):
        """
        Store the settings and connections this proxy instance works with.

        Nothing is started here: queues, threads, protocols and
        the control socket are all initialized to None (or empty)
        and only get created later.
        The process name is derived from the client connection.
        """
        Process.__init__(self, name=str(client_conn))
        init_args = (uid, gid, env_options, session_options, socket_dir,
                     video_encoder_modules, csc_modules,
                     client_conn, disp_desc, repr_ellipsized(str(client_state)), cipher, encryption_key, server_conn,
                     "%s: %s.." % (type(caps), repr_ellipsized(str(caps))), message_queue)
        log("ProxyProcess%s", init_args)
        #identity and environment:
        self.uid = uid
        self.gid = gid
        self.env_options = env_options
        self.session_options = session_options
        self.socket_dir = socket_dir
        #codecs:
        self.video_encoder_modules = video_encoder_modules
        self.csc_modules = csc_modules
        #the two connections we sit in between, and the client's settings:
        self.client_conn = client_conn
        self.server_conn = server_conn
        self.disp_desc = disp_desc
        self.client_state = client_state
        self.caps = caps
        #encryption:
        self.cipher = cipher
        self.encryption_key = encryption_key
        #runtime state:
        self.exit = False
        self.client_protocol = None
        self.server_protocol = None
        self.message_queue = message_queue
        self.main_queue = None
        #video encoding state:
        self.encode_queue = None            #holds draw packets to encode
        self.encode_thread = None
        self.video_helper = None
        self.video_encoding_defs = None
        self.video_encoders = None
        self.video_encoders_last_used_time = None
        self.video_encoder_types = None
        self.lost_windows = None
        #for handling the local unix domain socket:
        self.control_socket = None
        self.control_socket_path = None
        self.control_socket_thread = None
        self.control_socket_cleanup = None
        self.potential_protocols = []
        self.max_connections = MAX_CONCURRENT_CONNECTIONS

    def server_message_queue(self):
        """
        Blocking loop reading the messages placed on self.message_queue.

        "stop" calls self.stop() and ends the loop,
        "socket-handover-complete" switches both connections to blocking mode,
        anything else is logged as an error.
        """
        while True:
            log("waiting for server message on %s", self.message_queue)
            message = self.message_queue.get()
            log("received proxy server message: %s", message)
            if message=="stop":
                self.stop("proxy server request")
                return
            if message!="socket-handover-complete":
                log.error("unexpected proxy server message: %s", message)
                continue
            log("setting sockets to blocking mode: %s", (self.client_conn, self.server_conn))
            #set sockets to blocking mode:
            for conn in (self.client_conn, self.server_conn):
                set_blocking(conn)

    def signal_quit(self, signum, frame):
        """
        Signal handler: logs the signal received, flags this instance
        as exiting, re-registers SIGINT and SIGTERM with deadly_signal
        and then calls self.stop() with the signal name as reason.
        """
        signame = SIGNAMES.get(signum, signum)
        log.info("")
        log.info("proxy process pid %s got signal %s, exiting", os.getpid(), signame)
        self.exit = True
        #any further signal goes to deadly_signal instead of this handler:
        for sig in (signal.SIGINT, signal.SIGTERM):
            signal.signal(sig, deadly_signal)
        self.stop(signame)

    def idle_add(self, fn, *args, **kwargs):
        """
        Emulates gobject's idle_add:
        the callback and its arguments are placed on self.main_queue
        as a (fn, args, kwargs) tuple.
        """
        callback = (fn, args, kwargs)
        self.main_queue.put(callback)

    def timeout_add(self, timeout, fn, *args, **kwargs):
        """
        Emulates gobject's timeout_add using a Timer and idle_add.

        `timeout` is in milliseconds (it is converted to seconds for the Timer).
        When the timer fires, the call is scheduled via self.idle_add;
        if `fn` returns a true value, the timeout is armed again.
        """
        def run_and_rearm():
            if fn(*args, **kwargs):
                #the callback wants to be called again:
                self.timeout_add(timeout, fn, *args, **kwargs)
            return False
        def on_timer():
            #don't run the callback from the timer thread, just schedule it:
            self.idle_add(run_and_rearm)
        delay = timeout/1000.0
        timer = Timer(delay, on_timer)
        timer.start()

    def run(self):
        """
        Body of the proxy process.

        Calls setuidgid with our uid and gid, applies the environment options,
        initializes the video encoders, registers the signal handlers,
        starts the helper threads, wraps both connections in Protocol
        instances and then runs the main queue until we are told to stop.
        Returns early if the control socket cannot be created.
        """
        log("ProxyProcess.run() pid=%s, uid=%s, gid=%s", os.getpid(), getuid(), getgid())
        setuidgid(self.uid, self.gid)
        env = self.env_options
        if env:
            #TODO: whitelist env update?
            os.environ.update(env)
        self.video_init()

        log.info("new proxy instance started")
        log.info(" for client %s", self.client_conn)
        log.info(" and server %s", self.server_conn)

        for signum in (signal.SIGTERM, signal.SIGINT):
            signal.signal(signum, self.signal_quit)
        log("registered signal handler %s", self.signal_quit)

        start_thread(self.server_message_queue, "server message queue")

        if not self.create_control_socket():
            #TODO: should send a message to the client
            return
        self.control_socket_thread = start_thread(self.control_socket_loop, "control")

        self.main_queue = Queue()
        #setup protocol wrappers:
        self.server_packets = Queue(PROXY_QUEUE_SIZE)
        self.client_packets = Queue(PROXY_QUEUE_SIZE)
        self.client_protocol = Protocol(self, self.client_conn, self.process_client_packet, self.get_client_packet)
        self.client_protocol.restore_state(self.client_state)
        self.server_protocol = Protocol(self, self.server_conn, self.process_server_packet, self.get_server_packet)
        client = self.client_protocol
        server = self.server_protocol
        #server connection tweaks:
        for packet_type in ("input-devices", "draw", "window-icon", "keymap-changed", "server-settings"):
            server.large_packets.append(packet_type)
        if self.caps.boolget("file-transfer"):
            for protocol in (client, server):
                for packet_type in ("send-file", "send-file-chunk"):
                    protocol.large_packets.append(packet_type)
        compression_level = self.session_options.get("compression_level", 0)
        server.set_compression_level(compression_level)
        server.enable_default_encoder()

        self.lost_windows = set()
        self.encode_queue = Queue()
        self.encode_thread = start_thread(self.encode_loop, "encode")

        log("starting network threads")
        server.start()
        client.start()

        self.send_hello()
        self.timeout_add(VIDEO_TIMEOUT*1000, self.timeout_video_encoders)

        try:
            self.run_queue()
        except KeyboardInterrupt as e:
            self.stop(str(e))
        finally:
            log("ProxyProcess.run() ending %s", os.getpid())

    def video_init(self):
        """
        Load the codecs and initialize the video helper with our
        video encoder modules, then populate:
        - self.video_encoding_defs: encoding -> colorspace -> list of encoder spec dicts
        - self.video_encoder_types: the encoder types found,
          those listed in PREFERRED_ENCODER_ORDER first
        The other video encoder attributes are reset to empty containers.
        """
        enclog("video_init() loading codecs")
        load_codecs(decoders=False)
        enclog("video_init() will try video encoders: %s", csv(self.video_encoder_modules) or "none")
        helper = getVideoHelper()
        self.video_helper = helper
        #only use video encoders (no CSC supported in proxy)
        helper.set_modules(video_encoders=self.video_encoder_modules)
        helper.init()

        self.video_encoding_defs = {}
        self.video_encoders = {}
        self.video_encoders_dst_formats = []
        self.video_encoders_last_used_time = {}
        self.video_encoder_types = []

        #only deal with encoders that can handle plain RGB directly:
        rgb_colorspaces = ("BGRX", "BGRA", "RGBX", "RGBA")
        #figure out which encoders we want to proxy for (if any):
        found_types = set()
        for encoding in helper.get_encodings():
            for colorspace, especs in helper.get_encoder_specs(encoding).items():
                if colorspace not in rgb_colorspaces:
                    continue
                for spec in especs:                             #ie: video_spec("x264")
                    props = spec.to_dict()
                    #not serializable!
                    del props["codec_class"]
                    #we want to win scoring so we get used ahead of other encoders:
                    props["score_boost"] = 50
                    #limit to 3 video streams we proxy for (we really want 2,
                    # but because of races with garbage collection, we need to allow more)
                    props["max_instances"] = 3
                    #store it in encoding defs:
                    self.video_encoding_defs.setdefault(encoding, {}).setdefault(colorspace, []).append(props)
                    found_types.add(spec.codec_type)

        enclog("encoder types found: %s", tuple(found_types))
        #preferred encoders first (in their preferred order), then any others:
        preferred = [x for x in PREFERRED_ENCODER_ORDER if x in found_types]
        others = [x for x in found_types if x not in PREFERRED_ENCODER_ORDER]
        self.video_encoder_types = preferred + others
        enclog.info("proxy video encoders: %s", ", ".join(self.video_encoder_types))


    def create_control_socket(self):
        """
        Create the unix domain socket used to control this proxy instance.

        The socket path is derived from our pid.
        Returns False (after logging an error) if a server is already
        using that path or if the socket cannot be created,
        True once the socket is listening.
        On success, self.control_socket and self.control_socket_path are set.
        """
        assert self.socket_dir
        dotxpra = DotXpra(self.socket_dir, actual_username=get_username_for_uid(self.uid), uid=self.uid, gid=self.gid)
        sockpath = dotxpra.socket_path(":proxy-%s" % os.getpid())
        state = dotxpra.get_server_state(sockpath)
        log("create_control_socket: socket path='%s', uid=%i, gid=%i, state=%s", sockpath, getuid(), getgid(), state)
        if state in (DotXpra.LIVE, DotXpra.UNKNOWN):
            log.error("Error: you already have a proxy server running at '%s'", sockpath)
            log.error(" the control socket will not be created")
            return False
        sockdir = os.path.dirname(sockpath)
        try:
            dotxpra.mksockdir(sockdir)
        except Exception as e:
            #not fatal: we still try to create the socket below
            log.warn("Warning: failed to create socket directory '%s'", sockdir)
            log.warn(" %s", e)
        try:
            listener, self.control_socket_cleanup = create_unix_domain_socket(sockpath, None, 0o600)
            listener.listen(5)
        except Exception as e:
            log("create_unix_domain_socket failed for '%s'", sockpath, exc_info=True)
            log.error("Error: failed to setup control socket '%s':", sockpath)
            log.error(" %s", e)
            return False
        self.control_socket_path = sockpath
        self.control_socket = listener
        log.info("proxy instance now also available using unix domain socket:")
        log.info(" %s", self.control_socket_path)
        return True

    def control_socket_loop(self):
        """
        Accept connections on the control socket until we are told to exit,
        passing each one to new_control_connection().

        stop() closes the control socket: an accept() failure
        which occurs once self.exit is set just ends the loop,
        any other failure is re-raised.
        """
        while not self.exit:
            log("waiting for connection on %s", self.control_socket_path)
            try:
                sock, address = self.control_socket.accept()
            except Exception:
                if self.exit:
                    #we are shutting down, this is expected:
                    log("control socket accept() failed during exit", exc_info=True)
                    return
                raise
            self.new_control_connection(sock, address)

    def new_control_connection(self, sock, address):
        """
        Handle a new connection accepted on the control socket.

        The connection is closed straight away if we already have
        self.max_connections potential protocols.
        Otherwise it is wrapped in a Protocol which is started and added
        to self.potential_protocols, and a timeout is scheduled
        to call verify_connection_accepted for it.
        Always returns True.
        """
        if len(self.potential_protocols)>=self.max_connections:
            log.error("too many connections (%s), ignoring new one", len(self.potential_protocols))
            sock.close()
            return  True
        try:
            peername = sock.getpeername()
        except Exception:
            #best effort: fall back to the address we were given
            #(don't use a bare except here, it would also swallow KeyboardInterrupt)
            peername = str(address)
        sockname = sock.getsockname()
        target = peername or sockname
        #sock.settimeout(0)
        log("new_control_connection() sock=%s, sockname=%s, address=%s, peername=%s", sock, sockname, address, peername)
        sc = SocketConnection(sock, sockname, address, target, "unix-domain")
        log.info("New proxy instance control connection received: %s", sc)
        protocol = Protocol(self, sc, self.process_control_packet)
        protocol.large_packets.append("info-response")
        self.potential_protocols.append(protocol)
        protocol.enable_default_encoder()
        protocol.start()
        self.timeout_add(SOCKET_TIMEOUT*1000, self.verify_connection_accepted, protocol)
        return True

    def verify_connection_accepted(self, protocol):
        """
        Timeout callback for control connections:
        a protocol which is still open and still listed in
        self.potential_protocols gets disconnected with LOGIN_TIMEOUT.
        """
        if protocol._closed:
            return
        if protocol not in self.potential_protocols:
            return
        log.error("connection timedout: %s", protocol)
        self.send_disconnect(protocol, LOGIN_TIMEOUT)

    def process_control_packet(self, proto, packet):
        """
        Packet handler for control connections:
        delegates to do_process_control_packet and, if that raises,
        logs the error and disconnects `proto` with CONTROL_COMMAND_ERROR.
        """
        handler = self.do_process_control_packet
        try:
            handler(proto, packet)
        except Exception as e:
            log.error("error processing control packet", exc_info=True)
            message = str(e)
            self.send_disconnect(proto, CONTROL_COMMAND_ERROR, message)

    def do_process_control_packet(self, proto, packet):
        """
        Handle a packet received on a control connection.

        A lost connection just removes `proto` from self.potential_protocols.
        A "hello" packet may carry an info, stop or version request
        (checked in that order); challenge requests are refused.
        Everything else is answered with a CONTROL_COMMAND_ERROR disconnect.
        """
        log("process_control_packet(%s, %s)", proto, packet)
        packet_type = packet[0]
        if packet_type==Protocol.CONNECTION_LOST:
            log.info("Connection lost")
            if proto in self.potential_protocols:
                self.potential_protocols.remove(proto)
            return
        if packet_type=="hello":
            caps = typedict(packet[1])
            if caps.boolget("challenge"):
                self.send_disconnect(proto, AUTHENTICATION_ERROR, "this socket does not use authentication")
                return
            if caps.get("info_request", False):
                proto.send_now(("hello", self.get_proxy_info(proto)))
                self.timeout_add(5*1000, self.send_disconnect, proto, CLIENT_EXIT_TIMEOUT, "info sent")
                return
            if caps.get("stop_request", False):
                self.stop("socket request", None)
                return
            if caps.get("version_request", False):
                from xpra import __version__
                version_caps = {"version" : __version__}
                proto.send_now(("hello", version_caps))
                self.timeout_add(5*1000, self.send_disconnect, proto, CLIENT_EXIT_TIMEOUT, "version sent")
                return
        #not a request we know how to handle:
        self.send_disconnect(proto, CONTROL_COMMAND_ERROR, "this socket only handles 'info', 'version' and 'stop' requests")

    def send_disconnect(self, proto, reason, *extra):
        """
        Send a "disconnect" packet with the reason (and any extra values)
        to `proto`, then schedule force_disconnect one second later.
        Does nothing if the protocol is already closed.
        """
        log("send_disconnect(%s, %s, %s)", proto, reason, extra)
        if proto._closed:
            return
        disconnect_packet = ["disconnect", reason]
        disconnect_packet.extend(extra)
        proto.send_now(disconnect_packet)
        self.timeout_add(1000, self.force_disconnect, proto)

    def force_disconnect(self, proto):
        """
        Close the protocol.
        Scheduled by send_disconnect, one second after it sends the disconnect packet.
        """
        proto.close()


    def get_proxy_info(self, proto):
        """
        Returns a dictionary with two entries:
        "proxy" holds the local version and, under the empty string key,
        the server info merged with the thread info for `proto`;
        "window" holds the value returned by self.get_window_info().
        """
        details = dict(get_server_info())
        details.update(get_thread_info(proto))
        proxy_info = {
            "version"    : local_version,
            ""           : details,
            }
        info = {"proxy" : proxy_info}
        info["window"] = self.get_window_info()
        return info

    def send_hello(self, challenge_response=None, client_salt=None):
        """
        Queue a "hello" packet for the server, using the client's
        capabilities as filtered by filter_client_caps.
        When a challenge response is given, it is added to the packet
        together with the client salt.
        """
        caps = self.filter_client_caps(self.caps)
        if challenge_response:
            caps["challenge_response"] = challenge_response
            caps["challenge_client_salt"] = client_salt
        hello_packet = ("hello", caps)
        self.queue_server_packet(hello_packet)


    def sanitize_session_options(self, options):
        """
        Returns a new dictionary containing only the whitelisted options,
        each value converted using the parser registered for its key.
        Values which fail to parse are logged and left out.
        """
        def number(k, v):
            return parse_number(int, k, v)
        OPTION_WHITELIST = {
            "compression_level" : number,
            "lz4"               : parse_bool,
            "lzo"               : parse_bool,
            "zlib"              : parse_bool,
            "rencode"           : parse_bool,
            "bencode"           : parse_bool,
            "yaml"              : parse_bool,
            }
        sanitized = {}
        for key, value in options.items():
            parser = OPTION_WHITELIST.get(key)
            if not parser:
                #not whitelisted
                continue
            log("trying to add %s=%s using %s", key, value, parser)
            try:
                sanitized[key] = parser(key, value)
            except Exception as e:
                log.warn("failed to parse value %s for %s using %s: %s", value, key, parser, e)
        return sanitized

    def filter_client_caps(self, caps):
        """
        Returns the capabilities to send to the server on behalf of the client.

        The client's encryption, authentication, alias and compression
        capabilities are removed by filter_caps, then we add:
        the username from the display description (if it has one),
        the sanitized session options and the proxy video encoding definitions.
        """
        #note: this used to list "lz0", a typo which let the "lzo" capabilities through
        fc = self.filter_caps(caps, ("cipher", "challenge", "digest", "aliases", "compression", "lz4", "lzo", "zlib"))
        #the display string may override the username:
        username = self.disp_desc.get("username")
        if username:
            fc["username"] = username
        #update with options provided via config if any:
        fc.update(self.sanitize_session_options(self.session_options))
        #add video proxies if any:
        fc["encoding.proxy.video"] = len(self.video_encoding_defs)>0
        if self.video_encoding_defs:
            fc["encoding.proxy.video.encodings"] = self.video_encoding_defs
        return fc

    def filter_server_caps(self, caps):
        """
        Enable the server protocol's encoder from the server's capabilities,
        then return those capabilities filtered by filter_caps
        (removing the "aliases").
        """
        self.server_protocol.enable_encoder_from_caps(caps)
        removed_prefixes = ("aliases", )
        filtered = self.filter_caps(caps, removed_prefixes)
        return filtered

    def filter_caps(self, caps, prefixes):
        """
        Returns a copy of `caps` without the keys starting with
        any of the `prefixes` (the caps the proxy overrides or does not use).
        The proxy's own network capabilities and proxy information are then added.
        """
        pcaps = {}
        removed = []
        for key in caps.keys():
            if any(key.startswith(prefix) for prefix in prefixes):
                removed.append(key)
            else:
                pcaps[key] = caps[key]
        log("filtered out %s matching %s", removed, prefixes)
        #replace the network caps with the proxy's own:
        network_caps = flatten_dict(get_network_caps())
        pcaps.update(network_caps)
        #then add the proxy info:
        updict(pcaps, "proxy", get_server_info(), flatten_dicts=True)
        pcaps["proxy"] = True
        pcaps["proxy.hostname"] = socket.gethostname()
        return pcaps


    def run_queue(self):
        """
        Main loop: runs the callbacks queued by idle_add / timeout_add
        until self.exit is set or the None exit marker is found on the queue.

        A callback which returns a true value is queued again.
        Errors raised by callbacks are logged and do not stop the loop.
        Once the loop ends, we wait up to one second for both
        network connections to close.
        """
        log("run_queue() queue has %s items already in it", self.main_queue.qsize())
        #process "idle_add"/"timeout_add" events in the main loop:
        while not self.exit:
            log("run_queue() size=%s", self.main_queue.qsize())
            item = self.main_queue.get()
            if item is None:
                log("run_queue() None exit marker")
                break
            fn, args, kwargs = item
            log("run_queue() %s%s%s", fn, args, kwargs)
            try:
                if fn(*args, **kwargs):
                    #re-run it:
                    #(we must queue the callback tuple again, not the value it returned,
                    # which could not be unpacked on the next iteration)
                    self.main_queue.put(item)
            except Exception:
                #only catch regular errors: KeyboardInterrupt is handled by run()
                log.error("error during main loop callback %s", fn, exc_info=True)
        self.exit = True
        #wait for connections to close down cleanly before we exit
        for i in range(10):
            if self.client_protocol._closed and self.server_protocol._closed:
                break
            if i==0:
                log.info("waiting for network connections to close")
            else:
                log("still waiting %i/10 - client.closed=%s, server.closed=%s", i+1, self.client_protocol._closed, self.server_protocol._closed)
            time.sleep(0.1)
        log.info("proxy instance %s stopped", os.getpid())

    def stop(self, reason="proxy terminating", skip_proto=None):
        """
        Stop this proxy instance.

        Sets the exit flag, closes the control socket and runs its cleanup
        function, replaces the main and encode queues with queues which only
        contain the None exit marker, then sends a disconnect packet
        (with SERVER_SHUTDOWN and the reason) to the client and server protocols.

        reason:     included in the disconnect packets
        skip_proto: a protocol which should not be sent the disconnect packet
        """
        log.info("stop(%s, %s)", reason, skip_proto)
        self.exit = True
        control_socket = self.control_socket
        if control_socket is not None:
            try:
                control_socket.close()
            except Exception:
                #best effort only, carry on with the shutdown:
                log("error closing control socket %s", control_socket, exc_info=True)
        csc = self.control_socket_cleanup
        if csc:
            self.control_socket_cleanup = None
            csc()
        #the main queue does not exist until run() creates it,
        #and the server message queue thread which may call us is started before that:
        main_queue = self.main_queue
        if main_queue is not None:
            main_queue.put(None)
        #empty the main queue:
        q = Queue()
        q.put(None)
        self.main_queue = q
        #empty the encode queue:
        q = Queue()
        q.put(None)
        self.encode_queue = q
        for proto in (self.client_protocol, self.server_protocol):
            if proto and proto!=skip_proto:
                log("sending disconnect to %s", proto)
                proto.flush_then_close(["disconnect", SERVER_SHUTDOWN, reason])


    def queue_client_packet(self, packet):
        """
        Add the packet to the queue of packets for the client
        and tell the client protocol that there is more to send.
        """
        packet_type = packet[0]
        log("queueing client packet: %s", packet_type)
        self.client_packets.put(packet)
        self.client_protocol.source_has_more()

    def get_client_packet(self):
        """
        Take the next packet from the queue of packets for the client
        (blocks until there is one).
        This method is handed to the client Protocol in run().

        Returns a tuple: the packet, None, None, and a flag
        telling if more packets are still queued.
        """
        packet = self.client_packets.get()
        log("sending to client: %s", packet[0])
        has_more = self.client_packets.qsize()>0
        return packet, None, None, has_more

    def process_client_packet(self, proto, packet):
        """
        Handle a packet received from the client.

        Connection loss stops the proxy, "set_deflate" is echoed back
        to the client and late "hello" packets are dropped.
        All the other packets, including "disconnect" (which also
        stops the proxy or closes the client protocol), are forwarded to the server.
        """
        packet_type = packet[0]
        log("process_client_packet: %s", packet_type)
        if packet_type==Protocol.CONNECTION_LOST:
            self.stop("client connection lost", proto)
            return
        if packet_type=="set_deflate":
            #echo it back to the client:
            self.client_packets.put(packet)
            self.client_protocol.source_has_more()
            return
        if packet_type=="hello":
            log.warn("Warning: invalid hello packet received after initial authentication (dropped)")
            return
        if packet_type=="disconnect":
            reason = packet[1]
            log("got disconnect from client: %s", reason)
            if not self.exit:
                self.stop("disconnect from client: %s" % reason)
            else:
                self.client_protocol.close()
        self.queue_server_packet(packet)


    def queue_server_packet(self, packet):
        """
        Add the packet to the queue of packets for the server
        and tell the server protocol that there is more to send.
        """
        packet_type = packet[0]
        log("queueing server packet: %s", packet_type)
        self.server_packets.put(packet)
        self.server_protocol.source_has_more()

    def get_server_packet(self):
        """
        Take the next packet from the queue of packets for the server
        (blocks until there is one).
        This method is handed to the server Protocol in run().

        Returns a tuple: the packet, None, None, and a flag
        telling if more packets are still queued.
        """
        packet = self.server_packets.get()
        log("sending to server: %s", packet[0])
        has_more = self.server_packets.qsize()>0
        return packet, None, None, has_more


    def _packet_recompress(self, packet, index, name):
        if len(packet)>index:
            data = packet[index]
            if len(data)<512:
                packet[index] = str(data)
                return
            #FIXME: this is ugly and not generic!
            zlib = compression.use_zlib and self.caps.boolget("zlib", True)
            lz4 = compression.use_lz4 and self.caps.boolget("lz4", False)
            lzo = compression.use_lzo and self.caps.boolget("lzo", False)
            if zlib or lz4 or lzo:
                packet[index] = compressed_wrapper(name, data, zlib=zlib, lz4=lz4, lzo=lzo, can_inline=False)
            else:
                #prevent warnings about large uncompressed data
                packet[index] = Compressed("raw %s" % name, data, can_inline=True)

    def process_server_packet(self, proto, packet):
        """
        Handle a packet received from the server.

        Most packets are forwarded to the client via queue_client_packet,
        some are modified first ("hello", "info-response", "cursor",
        "window-icon", "send-file", "send-file-chunk").
        "draw" packets are handed to the encode thread instead,
        and "challenge" packets are answered here with a new hello packet.
        Connection loss and "disconnect" stop the proxy.
        """
        packet_type = packet[0]
        log("process_server_packet: %s", packet_type)
        if packet_type==Protocol.CONNECTION_LOST:
            self.stop("server connection lost", proto)
            return
        elif packet_type=="disconnect":
            log("got disconnect from server: %s", packet[1])
            if self.exit:
                self.server_protocol.close()
            else:
                self.stop("disconnect from server: %s" % packet[1])
        elif packet_type=="hello":
            c = typedict(packet[1])
            maxw, maxh = c.intpair("max_desktop_size", (4096, 4096))
            caps = self.filter_server_caps(c)
            #add new encryption caps:
            if self.cipher:
                from xpra.net.crypto import crypto_backend_init, new_cipher_caps, DEFAULT_PADDING
                crypto_backend_init()
                padding_options = self.caps.strlistget("cipher.padding.options", [DEFAULT_PADDING])
                auth_caps = new_cipher_caps(self.client_protocol, self.cipher, self.encryption_key, padding_options)
                caps.update(auth_caps)
            #may need to bump packet size:
            proto.max_packet_size = maxw*maxh*4*4
            file_transfer = self.caps.boolget("file-transfer") and c.boolget("file-transfer")
            file_size_limit = max(self.caps.intget("file-size-limit"), c.intget("file-size-limit"))
            file_max_packet_size = int(file_transfer) * (1024 + file_size_limit*1024*1024)
            self.client_protocol.max_packet_size = max(self.client_protocol.max_packet_size, file_max_packet_size)
            self.server_protocol.max_packet_size = max(self.server_protocol.max_packet_size, file_max_packet_size)
            packet = ("hello", caps)
        elif packet_type=="info-response":
            #adds proxy info:
            #note: this is only seen by the client application
            #"xpra info" is a new connection, which talks to the proxy server...
            info = packet[1]
            info.update(self.get_proxy_info(proto))
        elif packet_type=="lost-window":
            wid = packet[1]
            #mark it as lost so we can drop any current/pending frames
            self.lost_windows.add(wid)
            #queue it so it gets cleaned safely (for video encoders mostly):
            self.encode_queue.put(packet)
            #and fall through so tell the client immediately
        elif packet_type=="draw":
            #use encoder thread:
            self.encode_queue.put(packet)
            #which will queue the packet itself when done:
            return
        #we do want to reformat cursor packets...
        #as they will have been uncompressed by the network layer already:
        elif packet_type=="cursor":
            #packet = ["cursor", x, y, width, height, xhot, yhot, serial, pixels, name]
            #or:
            #packet = ["cursor", "png", x, y, width, height, xhot, yhot, serial, pixels, name]
            #or:
            #packet = ["cursor", ""]
            if len(packet)>=8:
                #hard to distinguish png cursors from normal cursors...
                try:
                    int(packet[1])
                    self._packet_recompress(packet, 8, "cursor")
                except Exception:
                    #not a number: the pixel data is one position further
                    self._packet_recompress(packet, 9, "cursor")
        elif packet_type=="window-icon":
            self._packet_recompress(packet, 5, "icon")
        elif packet_type=="send-file":
            if packet[6]:
                packet[6] = Compressed("file-data", packet[6])
        elif packet_type=="send-file-chunk":
            if packet[3]:
                packet[3] = Compressed("file-chunk-data", packet[3])
        elif packet_type=="challenge":
            from xpra.net.crypto import get_salt
            #client may have already responded to the challenge,
            #so we have to handle authentication from this end
            salt = packet[1]
            digest = packet[3]
            client_salt = get_salt(len(salt))
            salt = xor_str(salt, client_salt)
            if digest!=b"hmac":
                #the digest must be formatted into the reason:
                #the second argument of stop() is the protocol to skip
                self.stop("digest mode '%s' not supported" % std(digest))
                return
            password = self.disp_desc.get("password", self.session_options.get("password"))
            #NOTE(review): this writes the password to the debug log
            log("password from %s / %s = %s", self.disp_desc, self.session_options, password)
            if not password:
                self.stop("authentication requested by the server, but no password available for this session")
                return
            import hmac, hashlib
            password = strtobytes(password)
            salt = strtobytes(salt)
            challenge_response = hmac.HMAC(password, salt, digestmod=hashlib.md5).hexdigest()
            log.info("sending %s challenge response", digest)
            self.send_hello(challenge_response, client_salt)
            return
        self.queue_client_packet(packet)


    def encode_loop(self):
        """
        Thread for slower encoding related work.

        Processes the packets placed on self.encode_queue
        until self.exit is set or the None exit marker is received:
        - "lost-window": forget the window and clean its video encoder
        - "draw": run process_draw and, if it returns a true value,
          queue the packet for the client
        - "check-video-timeout": clean the window's video encoder
          if it has been idle for longer than VIDEO_TIMEOUT
        Errors are logged and do not stop the loop.
        """
        while not self.exit:
            packet = self.encode_queue.get()
            if packet is None:
                return
            try:
                packet_type = packet[0]
                if packet_type=="lost-window":
                    wid = packet[1]
                    #tolerate missing entries so that the encoder
                    #always gets cleaned if there is one:
                    self.lost_windows.discard(wid)
                    ve = self.video_encoders.pop(wid, None)
                    self.video_encoders_last_used_time.pop(wid, None)
                    if ve:
                        ve.clean()
                elif packet_type=="draw":
                    #modify the packet with the video encoder:
                    if self.process_draw(packet):
                        #then send it as normal:
                        self.queue_client_packet(packet)
                elif packet_type=="check-video-timeout":
                    #not a real packet, this is added by the timeout check:
                    wid = packet[1]
                    ve = self.video_encoders.get(wid)
                    #the encoder may have been removed since this check was queued:
                    last_used = self.video_encoders_last_used_time.get(wid)
                    if ve and last_used is not None and time.time()-last_used>VIDEO_TIMEOUT:
                        enclog("timing out the video encoder context for window %s", wid)
                        #timeout is confirmed, we are in the encoding thread,
                        #so it is now safe to clean it up:
                        ve.clean()
                        self.video_encoders.pop(wid, None)
                        self.video_encoders_last_used_time.pop(wid, None)
                else:
                    enclog.warn("unexpected encode packet: %s", packet_type)
            except Exception:
                enclog.warn("error encoding packet", exc_info=True)


    def process_draw(self, packet):
        wid, x, y, width, height, encoding, pixels, _, rowstride, client_options = packet[1:11]
        #never modify mmap packets
        if encoding in ("mmap", "scroll"):
            return True

        #we have a proxy video packet:
        rgb_format = client_options.get("rgb_format", "")
        enclog("proxy draw: client_options=%s", client_options)

        def send_updated(encoding, compressed_data, updated_client_options):
            #update the packet with actual encoding data used:
            packet[6] = encoding
            packet[7] = compressed_data
            packet[10] = updated_client_options
            enclog("returning %s bytes from %s, options=%s", len(compressed_data), len(pixels), updated_client_options)
            return (wid not in self.lost_windows)

        def passthrough(strip_alpha=True):
            enclog("proxy draw: %s passthrough (rowstride: %s vs %s, strip alpha=%s)", rgb_format, rowstride, client_options.get("rowstride", 0), strip_alpha)
            if strip_alpha:
                #passthrough as plain RGB:
                Xindex = rgb_format.upper().find("X")
                if Xindex>=0 and len(rgb_format)==4:
                    #force clear alpha (which may be garbage):
                    newdata = bytearray(pixels)
                    for i in range(len(pixels)/4):
                        newdata[i*4+Xindex] = chr(255)
                    packet[9] = client_options.get("rowstride", 0)
                    cdata = bytes(newdata)
                else:
                    cdata = pixels
                new_client_options = {"rgb_format" : rgb_format}
            else:
                #preserve
                cdata = pixels
                new_client_options = client_options
            wrapped = Compressed("%s pixels" % encoding, cdata)
            #FIXME: we should not assume that rgb32 is supported here...
            #(we may have to convert to rgb24..)
            return send_updated("rgb32", wrapped, new_client_options)

        proxy_video = client_options.get("proxy", False)
        if PASSTHROUGH and (encoding in ("rgb32", "rgb24") or proxy_video):
            #we are dealing with rgb data, so we can pass it through:
            return passthrough(proxy_video)
        elif not self.video_encoder_types or not client_options or not proxy_video:
            #ensure we don't try to re-compress the pixel data in the network layer:
            #(re-add the "compressed" marker that gets lost when we re-assemble packets)
            packet[7] = Compressed("%s pixels" % encoding, packet[7])
            return True

        #video encoding: find existing encoder
        ve = self.video_encoders.get(wid)
        if ve:
            if ve in self.lost_windows:
                #we cannot clean the video encoder here, there may be more frames queue up
                #"lost-window" in encode_loop will take care of it safely
                return False
            #we must verify that the encoder is still valid
            #and scrap it if not (ie: when window is resized)
            if ve.get_width()!=width or ve.get_height()!=height:
                enclog("closing existing video encoder %s because dimensions have changed from %sx%s to %sx%s", ve, ve.get_width(), ve.get_height(), width, height)
                ve.clean()
                ve = None
            elif ve.get_encoding()!=encoding:
                enclog("closing existing video encoder %s because encoding has changed from %s to %s", ve.get_encoding(), encoding)
                ve.clean()
                ve = None
        #scaling and depth are proxy-encoder attributes:
        scaling = client_options.get("scaling", (1, 1))
        depth   = client_options.get("depth", 24)
        rowstride = client_options.get("rowstride", rowstride)
        quality = client_options.get("quality", -1)
        speed   = client_options.get("speed", -1)
        timestamp = client_options.get("timestamp")

        image = ImageWrapper(x, y, width, height, pixels, rgb_format, depth, rowstride, planes=ImageWrapper.PACKED)
        if timestamp is not None:
            image.set_timestamp(timestamp)

        #the encoder options are passed through:
        encoder_options = client_options.get("options", {})
        if not ve:
            #make a new video encoder:
            spec = self._find_video_encoder(encoding, rgb_format)
            if spec is None:
                #no video encoder!
                enc_pillow = get_codec("enc_pillow")
                if not enc_pillow:
                    from xpra.server.picture_encode import warn_encoding_once
                    warn_encoding_once("no-video-no-PIL", "no video encoder found for rgb format %s, sending as plain RGB!" % rgb_format)
                    return passthrough(True)
                enclog("no video encoder available: sending as jpeg")
                coding, compressed_data, client_options, _, _, _, _ = enc_pillow.encode("jpeg", image, quality, speed, False)
                return send_updated(coding, compressed_data, client_options)

            enclog("creating new video encoder %s for window %s", spec, wid)
            ve = spec.make_instance()
            #dst_formats is specified with first frame only:
            dst_formats = client_options.get("dst_formats")
            if dst_formats is not None:
                #save it in case we timeout the video encoder,
                #so we can instantiate it again, even from a frame no>1
                self.video_encoders_dst_formats = dst_formats
            else:
                assert self.video_encoders_dst_formats, "BUG: dst_formats not specified for proxy and we don't have it either"
                dst_formats = self.video_encoders_dst_formats
            ve.init_context(width, height, rgb_format, dst_formats, encoding, quality, speed, scaling, {})
            self.video_encoders[wid] = ve
            self.video_encoders_last_used_time[wid] = time.time()       #just to make sure this is always set
        #actual video compression:
        enclog("proxy compression using %s with quality=%s, speed=%s", ve, quality, speed)
        data, out_options = ve.compress_image(image, quality, speed, encoder_options)
        #pass through some options if we don't have them from the encoder
        #(maybe we should also use the "pts" from the real server?)
        for k in ("timestamp", "rgb_format", "depth", "csc"):
            if k not in out_options and k in client_options:
                out_options[k] = client_options[k]
        self.video_encoders_last_used_time[wid] = time.time()
        return send_updated(ve.get_encoding(), Compressed(encoding, data), out_options)

    def timeout_video_encoders(self):
        """
        Ask the encode thread to check on video encoders that have been idle.

        For every window id found in `video_encoders_last_used_time`
        whose last use is more than VIDEO_TIMEOUT seconds old,
        a ["check-video-timeout", wid] item is placed on `encode_queue`.

        Always returns True (so that we get to run again).
        """
        #have to be careful as another thread may come in...
        #so we just ask the encode thread (which deals with encoders already)
        #to do what may need to be done if we find a timeout:
        now = time.time()
        #iterate over a copy of the keys: the dictionary may change while we loop
        for wid in list(self.video_encoders_last_used_time.keys()):
            last_used = self.video_encoders_last_used_time.get(wid)
            if last_used is None:
                #no timestamp for this window (the entry may have been removed
                #since we copied the keys), there is nothing to time out:
                continue
            idle_time = int(now-last_used)
            enclog("timeout_video_encoders() wid=%s, idle_time=%s", wid, idle_time)
            if idle_time and idle_time>VIDEO_TIMEOUT:
                self.encode_queue.put(["check-video-timeout", wid])
        return True     #run again

    def _find_video_encoder(self, encoding, rgb_format):
        """
        Find a video encoder spec which takes `rgb_format` as input.

        The requested `encoding` is tried first, then all the other encodings
        the video helper knows about. For each encoding, the encoder types
        are tried in the order they are listed in `video_encoder_types`.

        Returns the first matching spec, or None if there isn't one.
        """
        #try the one specified first, then all the others:
        try_encodings = [encoding] + [x for x in self.video_helper.get_encodings() if x!=encoding]
        #(use a separate loop variable so that the log messages
        # can still show the encoding we were asked for)
        for try_encoding in try_encodings:
            colorspace_specs = self.video_helper.get_encoder_specs(try_encoding)
            #there may be no entry at all for this pixel format,
            #in which case get() gives us None rather than an empty list:
            especs = colorspace_specs.get(rgb_format)
            if not especs:
                continue
            for etype in self.video_encoder_types:
                for spec in especs:
                    if etype==spec.codec_type:
                        enclog("_find_video_encoder(%s, %s)=%s", encoding, rgb_format, spec)
                        return spec
        enclog("_find_video_encoder(%s, %s) not found", encoding, rgb_format)
        return None

    def get_window_info(self):
        """
        Describe the video encoder used for each window.

        Returns a dictionary keyed by window id, each value has a single
        "proxy" entry holding the encoder type (under the empty key)
        and the encoder's own info (under "encoder"),
        to which the number of seconds since it was last used is added.
        """
        current_time = time.time()
        window_info = {}
        for wid, encoder in self.video_encoders.items():
            encoder_info = encoder.get_info()
            #encoders without a "last used" timestamp are measured from zero:
            last_used = self.video_encoders_last_used_time.get(wid, 0)
            encoder_info["idle_time"] = int(current_time-last_used)
            proxy_info = {
                          ""        : encoder.get_type(),
                          "encoder" : encoder_info,
                          }
            window_info[wid] = {"proxy" : proxy_info}
        enclog("get_window_info()=%s", window_info)
        return window_info
コード例 #4
0
class ProxyInstanceProcess(Process):
    """
    Process which proxies one client connection to one server connection.

    The constructor only records its arguments, the real work happens in
    run(): it switches to the requested uid and gid, initializes the video
    encoders, wraps both connections in a Protocol, forwards the client's
    "hello" packet to the server and then runs the main queue.
    """

    def __init__(self, uid, gid, env_options, session_options, socket_dir,
                 video_encoder_modules, csc_modules,
                 client_conn, client_state, cipher, encryption_key, server_conn, caps, message_queue):
        """
        Record the settings and the two connections, nothing is started here.
        The session options are filtered through sanitize_session_options().
        """
        Process.__init__(self, name=str(client_conn))
        self.uid = uid
        self.gid = gid
        self.env_options = env_options
        self.session_options = self.sanitize_session_options(session_options)
        self.socket_dir = socket_dir
        self.video_encoder_modules = video_encoder_modules
        self.csc_modules = csc_modules
        self.client_conn = client_conn
        self.client_state = client_state
        self.cipher = cipher
        self.encryption_key = encryption_key
        self.server_conn = server_conn
        self.caps = caps
        #NOTE(review): the encryption key is part of this log message,
        #confirm that this is acceptable for the log level used by log()
        log("ProxyProcess%s", (uid, gid, client_conn, client_state, cipher, encryption_key, server_conn, "{..}"))
        self.client_protocol = None
        self.server_protocol = None
        self.exit = False
        self.main_queue = None
        self.message_queue = message_queue
        self.encode_queue = None            #holds draw packets to encode
        self.encode_thread = None
        #the video attributes get their real values in video_init():
        self.video_encoding_defs = None
        self.video_encoders = None
        self.video_encoders_dst_formats = None
        self.video_encoders_last_used_time = None
        self.video_encoder_types = None
        self.video_helper = None
        self.lost_windows = None
        #for handling the local unix domain socket:
        self.control_socket = None
        self.control_socket_thread = None
        self.control_socket_path = None
        self.potential_protocols = []
        self.max_connections = MAX_CONCURRENT_CONNECTIONS

    def server_message_queue(self):
        """
        Wait for messages on the message queue (forever),
        the "stop" message stops this proxy instance and ends the loop.
        """
        while True:
            log("waiting for server message on %s", self.message_queue)
            m = self.message_queue.get()
            log.info("proxy server message: %s", m)
            if m=="stop":
                self.stop("proxy server request")
                return

    def signal_quit(self, signum, frame):
        """
        Signal handler for SIGINT and SIGTERM: flags the exit and stops the proxy.
        """
        log.info("")
        log.info("proxy process pid %s got signal %s, exiting", os.getpid(), SIGNAMES.get(signum, signum))
        self.exit = True
        #any further signal is handled by deadly_signal instead of coming back here:
        signal.signal(signal.SIGINT, deadly_signal)
        signal.signal(signal.SIGTERM, deadly_signal)
        self.stop(SIGNAMES.get(signum, signum))

    def idle_add(self, fn, *args, **kwargs):
        """ Queue fn(*args, **kwargs) on the main queue. """
        #we emulate gobject's idle_add using a simple queue
        self.main_queue.put((fn, args, kwargs))

    def timeout_add(self, timeout, fn, *args, **kwargs):
        """
        Queue fn(*args, **kwargs) on the main queue after `timeout` milliseconds,
        and keep doing so for as long as the function returns a true value.
        """
        #emulate gobject's timeout_add using idle add and a Timer:
        #the timer is re-armed only if the function asks to run again
        def idle_exec():
            v = fn(*args, **kwargs)
            if bool(v):
                self.timeout_add(timeout, fn, *args, **kwargs)
            return False
        def timer_exec():
            #just run via idle_add:
            self.idle_add(idle_exec)
        #Timer takes seconds, we are given milliseconds:
        Timer(timeout/1000.0, timer_exec).start()

    def run(self):
        """
        Process entry point: switch user, set everything up,
        then run the main queue until it ends.
        """
        log("ProxyProcess.run() pid=%s, uid=%s, gid=%s", os.getpid(), os.getuid(), os.getgid())
        #change uid and gid:
        #(gid first: after setuid() we may no longer have the privileges to change it)
        if os.getgid()!=self.gid:
            os.setgid(self.gid)
        if os.getuid()!=self.uid:
            os.setuid(self.uid)
        log("ProxyProcess.run() new uid=%s, gid=%s", os.getuid(), os.getgid())

        if self.env_options:
            #TODO: whitelist env update?
            os.environ.update(self.env_options)
        self.video_init()

        log.info("new proxy started for client %s and server %s", self.client_conn, self.server_conn)

        signal.signal(signal.SIGTERM, self.signal_quit)
        signal.signal(signal.SIGINT, self.signal_quit)
        log("registered signal handler %s", self.signal_quit)

        make_daemon_thread(self.server_message_queue, "server message queue").start()

        if self.create_control_socket():
            self.control_socket_thread = make_daemon_thread(self.control_socket_loop, "control")
            self.control_socket_thread.start()

        self.main_queue = Queue()
        #setup protocol wrappers:
        self.server_packets = Queue(PROXY_QUEUE_SIZE)
        self.client_packets = Queue(PROXY_QUEUE_SIZE)
        self.client_protocol = Protocol(self, self.client_conn, self.process_client_packet, self.get_client_packet)
        self.client_protocol.restore_state(self.client_state)
        self.server_protocol = Protocol(self, self.server_conn, self.process_server_packet, self.get_server_packet)
        #server connection tweaks:
        self.server_protocol.large_packets.append("draw")
        self.server_protocol.large_packets.append("window-icon")
        self.server_protocol.large_packets.append("keymap-changed")
        self.server_protocol.large_packets.append("server-settings")
        self.server_protocol.set_compression_level(self.session_options.get("compression_level", 0))
        self.server_protocol.enable_default_encoder()

        self.lost_windows = set()
        self.encode_queue = Queue()
        self.encode_thread = make_daemon_thread(self.encode_loop, "encode")
        self.encode_thread.start()

        log("starting network threads")
        self.server_protocol.start()
        self.client_protocol.start()

        #forward the hello packet:
        hello_packet = ("hello", self.filter_client_caps(self.caps))
        self.queue_server_packet(hello_packet)
        #timeout_add takes milliseconds:
        self.timeout_add(VIDEO_TIMEOUT*1000, self.timeout_video_encoders)

        try:
            self.run_queue()
        except KeyboardInterrupt as e:
            self.stop(str(e))
        finally:
            log("ProxyProcess.run() ending %s", os.getpid())

    def video_init(self):
        """
        Load the video encoders and work out which ones we can proxy for.

        Fills in `video_encoding_defs` (encoding -> colorspace -> list of
        spec properties) and `video_encoder_types` (the encoder types found,
        preferred ones first), and resets the per-window encoder state.
        """
        log("video_init() will try video encoders: %s", self.video_encoder_modules)
        self.video_helper = getVideoHelper()
        #only use video encoders (no CSC supported in proxy)
        self.video_helper.set_modules(video_encoders=self.video_encoder_modules)
        self.video_helper.init()

        self.video_encoding_defs = {}
        self.video_encoders = {}
        self.video_encoders_dst_formats = []
        self.video_encoders_last_used_time = {}
        self.video_encoder_types = []

        #figure out which encoders we want to proxy for (if any):
        encoder_types = set()
        for encoding in self.video_helper.get_encodings():
            colorspace_specs = self.video_helper.get_encoder_specs(encoding)
            for colorspace, especs in colorspace_specs.items():
                if colorspace not in ("BGRX", "BGRA", "RGBX", "RGBA"):
                    #only deal with encoders that can handle plain RGB directly
                    continue

                for spec in especs:                             #ie: codec_spec("x264")
                    spec_props = spec.to_dict()
                    del spec_props["codec_class"]               #not serializable!
                    spec_props["score_boost"] = 50              #we want to win scoring so we get used ahead of other encoders
                    spec_props["max_instances"] = 3             #limit to 3 video streams we proxy for (we really want 2,
                                                                # but because of races with garbage collection, we need to allow more)
                    #store it in encoding defs:
                    self.video_encoding_defs.setdefault(encoding, {}).setdefault(colorspace, []).append(spec_props)
                    encoder_types.add(spec.codec_type)

        log("encoder types found: %s", tuple(encoder_types))
        #remove duplicates and use preferred order:
        order = PREFERRED_ENCODER_ORDER[:]
        for x in list(encoder_types):
            if x not in order:
                order.append(x)
        self.video_encoder_types = [x for x in order if x in encoder_types]
        log.info("proxy video encoders: %s", self.video_encoder_types)


    def create_control_socket(self):
        """
        Create the unix domain socket "proxy-$PID" for this proxy instance.

        Returns True if the socket was created (it is then stored in
        `control_socket`, and its path in `control_socket_path`),
        False if a socket already exists at that path or if creating it failed.
        """
        dotxpra = DotXpra(self.socket_dir)
        name = "proxy-%s" % os.getpid()
        sockpath = dotxpra.norm_make_path(name, dotxpra.sockdir())
        state = dotxpra.get_server_state(sockpath)
        if state in (DotXpra.LIVE, DotXpra.UNKNOWN):
            log.warn("You already have a proxy server running at %s, the control socket will not be created!", sockpath)
            return False
        try:
            sock = create_unix_domain_socket(sockpath, None)
            sock.listen(5)
        except Exception as e:
            #best effort: the proxy still works without its control socket
            log.warn("failed to setup control socket %s: %s", sockpath, e)
            return False
        self.control_socket = sock
        self.control_socket_path = sockpath
        log.info("proxy instance now also available using unix domain socket: %s", self.control_socket_path)
        return True
コード例 #5
0
class ProxyInstanceProcess(Process):
    """
    Process which proxies one client connection to one server connection.

    The constructor only records its arguments, the real work happens in
    run(): it switches to the requested uid and gid, wraps both connections
    in a Protocol, forwards the client's "hello" packet to the server
    and then runs the main queue.
    """

    def __init__(self, uid, gid, env_options, session_options, client_conn, client_state, cipher, encryption_key, server_conn, caps, message_queue):
        """
        Record the settings and the two connections, nothing is started here.
        The session options are filtered through sanitize_session_options().
        """
        Process.__init__(self, name=str(client_conn))
        self.uid = uid
        self.gid = gid
        self.env_options = env_options
        self.session_options = self.sanitize_session_options(session_options)
        self.client_conn = client_conn
        self.client_state = client_state
        self.cipher = cipher
        self.encryption_key = encryption_key
        self.server_conn = server_conn
        self.caps = caps
        #NOTE(review): the encryption key is part of this debug message,
        #confirm that this is acceptable
        debug("ProxyProcess%s", (uid, gid, client_conn, client_state, cipher, encryption_key, server_conn, "{..}"))
        self.client_protocol = None
        self.server_protocol = None
        self.main_queue = None
        self.message_queue = message_queue
        self.video_encoder_types = ["nvenc", "x264"]
        self.video_encoders = {}
        self.video_helper = None

    def server_message_queue(self):
        """
        Wait for messages on the message queue (forever),
        the "stop" message stops this proxy instance and ends the loop.
        """
        while True:
            log.info("waiting for server message on %s", self.message_queue)
            m = self.message_queue.get()
            log.info("proxy server message: %s", m)
            if m=="stop":
                self.stop("proxy server request")
                return

    def signal_quit(self, signum, frame):
        """
        Signal handler for SIGINT and SIGTERM: stops the proxy.
        """
        log.info("")
        log.info("proxy process pid %s got signal %s, exiting", os.getpid(), SIGNAMES.get(signum, signum))
        #any further signal is handled by deadly_signal instead of coming back here:
        signal.signal(signal.SIGINT, deadly_signal)
        signal.signal(signal.SIGTERM, deadly_signal)
        self.stop(SIGNAMES.get(signum, signum))

    def idle_add(self, fn, *args, **kwargs):
        """ Queue fn(*args, **kwargs) on the main queue. """
        #we emulate gobject's idle_add using a simple queue
        self.main_queue.put((fn, args, kwargs))

    def timeout_add(self, timeout, fn, *args, **kwargs):
        """
        Queue fn(*args, **kwargs) on the main queue after `timeout` milliseconds,
        and keep doing so for as long as the function returns a true value
        (as gobject's timeout_add does).
        """
        #emulate gobject's timeout_add using idle add and a Timer:
        #a Timer only fires once (and cancelling it after it has fired
        #does nothing), so we arm a new one if the function asks to run again
        def idle_exec():
            v = fn(*args, **kwargs)
            if bool(v):
                self.timeout_add(timeout, fn, *args, **kwargs)
            return False
        def timer_exec():
            #just run via idle_add:
            self.idle_add(idle_exec)
        #Timer takes seconds, we are given milliseconds:
        Timer(timeout/1000.0, timer_exec).start()

    def run(self):
        """
        Process entry point: switch user, set everything up,
        then run the main queue until it ends.
        """
        debug("ProxyProcess.run() pid=%s, uid=%s, gid=%s", os.getpid(), os.getuid(), os.getgid())
        #change uid and gid:
        #(gid first: after setuid() we may no longer have the privileges to change it)
        if os.getgid()!=self.gid:
            os.setgid(self.gid)
        if os.getuid()!=self.uid:
            os.setuid(self.uid)
        debug("ProxyProcess.run() new uid=%s, gid=%s", os.getuid(), os.getgid())

        if self.env_options:
            #TODO: whitelist env update?
            os.environ.update(self.env_options)

        log.info("new proxy started for client %s and server %s", self.client_conn, self.server_conn)

        signal.signal(signal.SIGTERM, self.signal_quit)
        signal.signal(signal.SIGINT, self.signal_quit)
        debug("registered signal handler %s", self.signal_quit)

        make_daemon_thread(self.server_message_queue, "server message queue").start()

        self.main_queue = Queue()
        #setup protocol wrappers:
        self.server_packets = Queue(PROXY_QUEUE_SIZE)
        self.client_packets = Queue(PROXY_QUEUE_SIZE)
        self.client_protocol = Protocol(self, self.client_conn, self.process_client_packet, self.get_client_packet)
        self.client_protocol.restore_state(self.client_state)
        self.server_protocol = Protocol(self, self.server_conn, self.process_server_packet, self.get_server_packet)
        #server connection tweaks:
        self.server_protocol.large_packets.append("draw")
        self.server_protocol.large_packets.append("keymap-changed")
        self.server_protocol.large_packets.append("server-settings")
        self.server_protocol.set_compression_level(self.session_options.get("compression_level", 0))

        debug("starting network threads")
        self.server_protocol.start()
        self.client_protocol.start()

        #forward the hello packet:
        hello_packet = ("hello", self.filter_client_caps(self.caps))
        self.queue_server_packet(hello_packet)

        try:
            self.run_queue()
        except KeyboardInterrupt as e:
            self.stop(str(e))
        finally:
            debug("ProxyProcess.run() ending %s", os.getpid())


    def sanitize_session_options(self, options):
        """
        Filter the session options: only the whitelisted keys are kept,
        and only if their value can be parsed.
        Returns a new dictionary with the parsed values.
        """
        d = {}
        def number(k, v):
            return parse_number(int, k, v)
        #maps each option we accept to the function used to parse its value:
        OPTION_WHITELIST = {"compression_level" : number,
                            "lz4"               : parse_bool}
        for k,v in options.items():
            parser = OPTION_WHITELIST.get(k)
            if parser:
                log("trying to add %s=%s using %s", k, v, parser)
                try:
                    d[k] = parser(k, v)
                except Exception as e:
                    #an invalid value is dropped, not fatal
                    log.warn("failed to parse value %s for %s using %s: %s", v, k, parser, e)
        return d
コード例 #6
0
ファイル: proxy_server.py プロジェクト: rudresh2319/Xpra
class ProxyProcess(Process):
    """
    Process which proxies one client connection to one server connection.

    The constructor only records its arguments, the real work happens in
    run(): it switches to the requested uid and gid, wraps both connections
    in a Protocol, forwards the client's "hello" packet to the server
    and then runs the main queue.
    """

    def __init__(self, uid, gid, env_options, session_options, client_conn,
                 client_state, cipher, encryption_key, server_conn, caps,
                 message_queue):
        """
        Record the settings and the two connections, nothing is started here.
        The session options are filtered through sanitize_session_options().
        """
        Process.__init__(self, name=str(client_conn))
        self.uid = uid
        self.gid = gid
        self.env_options = env_options
        self.session_options = self.sanitize_session_options(session_options)
        self.client_conn = client_conn
        self.client_state = client_state
        self.cipher = cipher
        self.encryption_key = encryption_key
        self.server_conn = server_conn
        self.caps = caps
        #NOTE(review): the encryption key is part of this debug message,
        #confirm that this is acceptable
        debug("ProxyProcess%s", (uid, gid, client_conn, client_state, cipher,
                                 encryption_key, server_conn, "{..}"))
        self.client_protocol = None
        self.server_protocol = None
        self.main_queue = None
        self.message_queue = message_queue

    def server_message_queue(self):
        """
        Wait for messages on the message queue (forever),
        the "stop" message stops this proxy instance and ends the loop.
        """
        while True:
            log.info("waiting for server message on %s", self.message_queue)
            m = self.message_queue.get()
            log.info("proxy server message: %s", m)
            if m == "stop":
                self.stop("proxy server request")
                return

    def signal_quit(self, signum, frame):
        """
        Signal handler for SIGINT and SIGTERM: stops the proxy.
        """
        log.info("")
        log.info("proxy process pid %s got signal %s, exiting", os.getpid(),
                 SIGNAMES.get(signum, signum))
        #any further signal is handled by deadly_signal instead of coming back here:
        signal.signal(signal.SIGINT, deadly_signal)
        signal.signal(signal.SIGTERM, deadly_signal)
        self.stop(SIGNAMES.get(signum, signum))

    def idle_add(self, fn, *args, **kwargs):
        """ Queue fn(*args, **kwargs) on the main queue. """
        #we emulate gobject's idle_add using a simple queue
        self.main_queue.put((fn, args, kwargs))

    def timeout_add(self, timeout, fn, *args, **kwargs):
        """
        Queue fn(*args, **kwargs) on the main queue after `timeout` milliseconds,
        and keep doing so for as long as the function returns a true value
        (as gobject's timeout_add does).
        """
        #emulate gobject's timeout_add using idle add and a Timer:
        #a Timer only fires once (and cancelling it after it has fired
        #does nothing), so we arm a new one if the function asks to run again

        def idle_exec():
            v = fn(*args, **kwargs)
            if bool(v):
                self.timeout_add(timeout, fn, *args, **kwargs)
            return False

        def timer_exec():
            #just run via idle_add:
            self.idle_add(idle_exec)

        #Timer takes seconds, we are given milliseconds:
        Timer(timeout / 1000.0, timer_exec).start()

    def run(self):
        """
        Process entry point: switch user, set everything up,
        then run the main queue until it ends.
        """
        debug("ProxyProcess.run() pid=%s, uid=%s, gid=%s", os.getpid(),
              os.getuid(), os.getgid())
        #change uid and gid:
        #(gid first: after setuid() we may no longer have the privileges to change it)
        if os.getgid() != self.gid:
            os.setgid(self.gid)
        if os.getuid() != self.uid:
            os.setuid(self.uid)
        debug("ProxyProcess.run() new uid=%s, gid=%s", os.getuid(),
              os.getgid())

        if self.env_options:
            #TODO: whitelist env update?
            os.environ.update(self.env_options)

        log.info("new proxy started for client %s and server %s",
                 self.client_conn, self.server_conn)

        #the signal handlers are only installed when USE_THREADING is off:
        if not USE_THREADING:
            signal.signal(signal.SIGTERM, self.signal_quit)
            signal.signal(signal.SIGINT, self.signal_quit)
            debug("registered signal handler %s", self.signal_quit)

        make_daemon_thread(self.server_message_queue,
                           "server message queue").start()

        self.main_queue = Queue()
        #setup protocol wrappers:
        self.server_packets = Queue(PROXY_QUEUE_SIZE)
        self.client_packets = Queue(PROXY_QUEUE_SIZE)
        self.client_protocol = Protocol(self, self.client_conn,
                                        self.process_client_packet,
                                        self.get_client_packet)
        self.client_protocol.restore_state(self.client_state)
        self.server_protocol = Protocol(self, self.server_conn,
                                        self.process_server_packet,
                                        self.get_server_packet)
        #server connection tweaks:
        self.server_protocol.large_packets.append("draw")
        self.server_protocol.large_packets.append("keymap-changed")
        self.server_protocol.large_packets.append("server-settings")
        self.server_protocol.set_compression_level(
            self.session_options.get("compression_level", 0))

        debug("starting network threads")
        self.server_protocol.start()
        self.client_protocol.start()

        #forward the hello packet:
        hello_packet = ("hello", self.filter_client_caps(self.caps))
        self.queue_server_packet(hello_packet)

        try:
            self.run_queue()
        except KeyboardInterrupt as e:
            self.stop(str(e))
        finally:
            debug("ProxyProcess.run() ending %s", os.getpid())

    def sanitize_session_options(self, options):
        """
        Filter the session options: only the whitelisted keys are kept,
        and only if their value can be parsed.
        Returns a new dictionary with the parsed values.
        """
        d = {}

        def number(k, v):
            return parse_number(int, k, v)

        #maps each option we accept to the function used to parse its value:
        OPTION_WHITELIST = {"compression_level": number, "lz4": parse_bool}
        for k, v in options.items():
            parser = OPTION_WHITELIST.get(k)
            if parser:
                log("trying to add %s=%s using %s", k, v, parser)
                try:
                    d[k] = parser(k, v)
                except Exception as e:
                    #an invalid value is dropped, not fatal
                    log.warn("failed to parse value %s for %s using %s: %s", v,
                             k, parser, e)
        return d
コード例 #7
0
ファイル: client_base.py プロジェクト: svn2github/Xpra
class XpraClientBase(object):
    """ Base class for Xpra clients.
        Provides the glue code for:
        * sending packets via Protocol
        * handling packets received via _process_packet
        For an actual implementation, look at:
        * GObjectXpraClient
        * xpra.client.gtk2.client
        * xpra.client.gtk3.client
    """

    def __init__(self):
        """Give every attribute its default value and register the packet handlers."""
        self.exit_code = None
        #connection options (overwritten by init(opts)):
        self.compression_level = 0
        self.display = self.username = None
        self.password_file = None
        self.password_sent = False
        self.encryption = self.encryption_keyfile = None
        self.quality, self.min_quality = -1, 0
        self.speed, self.min_speed = 0, -1
        #protocol stuff:
        self._protocol = None
        self._priority_packets, self._ordinary_packets = [], []
        self._mouse_position = None
        self._aliases, self._reverse_aliases = {}, {}
        #server state and caps:
        self.server_capabilities = None
        for remote_attr in ("machine_id", "uuid", "version", "revision",
                            "platform", "platform_release", "platform_platform",
                            "platform_linux_distribution"):
            setattr(self, "_remote_%s" % remote_attr, None)
        self.uuid = get_user_uuid()
        self.init_packet_handlers()

    def init(self, opts):
        """Copy the connection and encoding options from ``opts`` onto this client."""
        option_names = (
            "compression_level", "display", "username", "password_file",
            "encryption", "encryption_keyfile",
            "quality", "min_quality", "speed", "min_speed",
            )
        for option_name in option_names:
            setattr(self, option_name, getattr(opts, option_name))

    def timeout_add(self, *args):
        """Abstract hook: subclasses must override it.

        Callers in this class pass ``(delay, callback, *callback_args)``.
        Raises NotImplementedError (an Exception subclass, so existing
        ``except Exception`` handlers still catch it).
        """
        raise NotImplementedError("override me!")

    def idle_add(self, *args):
        """Abstract hook: subclasses must override it.

        Raises NotImplementedError (an Exception subclass, so existing
        ``except Exception`` handlers still catch it).
        """
        raise NotImplementedError("override me!")

    def source_remove(self, *args):
        """Abstract hook: subclasses must override it.

        Raises NotImplementedError (an Exception subclass, so existing
        ``except Exception`` handlers still catch it).
        """
        raise NotImplementedError("override me!")


    def install_signal_handlers(self):
        """Install two-stage handlers for SIGINT and SIGTERM.

        The first signal schedules ``self.quit(128 + signum)`` through
        ``timeout_add`` and swaps in the "deadly" handler; a further signal
        then runs ``cleanup()`` and terminates at once with ``os._exit``.
        """
        handled_signals = (signal.SIGINT, signal.SIGTERM)
        def report(template, signum):
            sys.stderr.write(template % SIGNAMES.get(signum, signum))
            sys.stderr.flush()
        def deadly_signal(signum, frame):
            report("\ngot deadly signal %s, exiting\n", signum)
            self.cleanup()
            os._exit(128 + signum)
        def app_signal(signum, frame):
            report("\ngot signal %s, exiting\n", signum)
            #from now on, another signal kills us straight away:
            for sig in handled_signals:
                signal.signal(sig, deadly_signal)
            self.timeout_add(0, self.quit, 128 + signum)
        for sig in handled_signals:
            signal.signal(sig, app_signal)


    def client_type(self):
        """Return the client type string; this base implementation returns "Python".

        Sent to the server as the "client_type" capability.
        Meant to be overridden in subclasses.
        """
        return "Python"

    def get_scheduler(self):
        """Return the scheduler handed to Protocol; subclasses must implement this."""
        raise NotImplementedError()

    def setup_connection(self, conn):
        """Wrap ``conn`` in a Protocol instance and configure it.

        Registers the large packet types, applies the compression level and
        rebinds ``self.have_more`` to the protocol's ``source_has_more``.
        The protocol is not started here (see run()).
        """
        log.debug("setup_connection(%s)", conn)
        self._protocol = protocol = Protocol(self.get_scheduler(), conn, self.process_packet, self.next_packet)
        for packet_type in ("keymap-changed", "server-settings"):
            protocol.large_packets.append(packet_type)
        protocol.set_compression_level(self.compression_level)
        self.have_more = protocol.source_has_more

    def init_packet_handlers(self):
        """Create the two packet-type -> handler tables used by this class.

        "hello" goes in ``_packet_handlers``; everything else registered
        here goes in ``_ui_packet_handlers``.
        """
        self._packet_handlers = {"hello": self._process_hello}
        ui_handlers = [
            ("challenge",               self._process_challenge),
            ("disconnect",              self._process_disconnect),
            ("set_deflate",             self._process_set_deflate),
            (Protocol.CONNECTION_LOST,  self._process_connection_lost),
            (Protocol.GIBBERISH,        self._process_gibberish),
            ]
        self._ui_packet_handlers = dict(ui_handlers)

    def init_aliases(self):
        """Number every known packet type, starting at 1.

        Fills ``_aliases`` (number -> packet type) and ``_reverse_aliases``
        (packet type -> number) from both handler tables.
        """
        packet_types = list(self._packet_handlers) + list(self._ui_packet_handlers)
        for alias, packet_type in enumerate(packet_types, 1):
            self._aliases[alias] = packet_type
            self._reverse_aliases[packet_type] = alias

    def send_hello(self, challenge_response=None, client_salt=None):
        """Build and queue the "hello" packet, then arm the connection timeout.

        challenge_response: answer to a server challenge, if one was received.
        client_salt: salt used for that answer, only sent with a response.

        Without a challenge response and with a password file configured,
        only the base capabilities are sent together with "challenge"=True,
        to ask the server for a challenge first.
        Nothing is sent if make_hello_base() gave up (it returns None after
        calling warn_and_quit() when the encryption key is missing).
        """
        hello = self.make_hello_base()
        if hello is None:
            #make_hello_base() has already requested the exit,
            #carrying on would fail on a None capabilities dict
            return
        if self.password_file and not challenge_response:
            #avoid sending the full hello: tell the server we want
            #a packet challenge first
            hello["challenge"] = True
        else:
            hello.update(self.make_hello())
        if challenge_response:
            assert self.password_file
            hello["challenge_response"] = challenge_response
            if client_salt:
                hello["challenge_client_salt"] = client_salt
        log.debug("send_hello(%s) packet=%s", binascii.hexlify(challenge_response or ""), hello)
        self.send("hello", hello)
        self.timeout_add(DEFAULT_TIMEOUT, self.verify_connected)

    def verify_connected(self):
        """Quit with EXIT_TIMEOUT if no server capabilities have been recorded yet."""
        if self.server_capabilities is not None:
            return
        #server has not said hello yet
        self.warn_and_quit(EXIT_TIMEOUT, "connection timed out")


    def make_hello_base(self):
        """Build the capabilities included in every hello packet.

        Starts from the network capabilities, adds the client identification,
        platform and version information and, when encryption is enabled,
        the cipher parameters (the incoming cipher is also configured on the
        protocol).
        Returns the capabilities dictionary, or None after calling
        warn_and_quit() if the encryption key is None.
        """
        caps = get_network_caps()
        caps.update({
                "encoding.generic"      : True,
                "namespace"             : True,
                "hostname"              : socket.gethostname(),
                "uuid"                  : self.uuid,
                "username"              : self.username,
                "name"                  : get_name(),
                "client_type"           : self.client_type(),
                "python.version"        : sys.version_info[:3],
                "compression_level"     : self.compression_level,
                })
        if self.display:
            caps["display"] = self.display
        caps.update(get_platform_info())
        add_version_info(caps)
        machine_id = get_machine_id()
        if machine_id:
            caps["machine_id"] = machine_id

        if not self.encryption:
            return caps

        assert self.encryption in ENCRYPTION_CIPHERS
        iv = get_hex_uuid()[:16]
        key_salt = get_hex_uuid()+get_hex_uuid()
        iterations = 1000
        caps.update({
                    "cipher"                       : self.encryption,
                    "cipher.iv"                    : iv,
                    "cipher.key_salt"              : key_salt,
                    "cipher.key_stretch_iterations": iterations,
                    })
        key = self.get_encryption_key()
        if key is None:
            self.warn_and_quit(EXIT_ENCRYPTION, "encryption key is missing")
            return
        self._protocol.set_cipher_in(self.encryption, iv, key, key_salt, iterations)
        cipher_caps = [(k,v) for k,v in caps.items() if k.startswith("cipher")]
        log("encryption capabilities: %s", cipher_caps)
        return caps

    def make_hello(self):
        """Return the extra capabilities for a full hello packet.

        "randr_notify" and "windows" are disabled here (only client.py cares
        about them); the packet aliases are added when there are any.
        """
        caps = {"randr_notify": False, "windows": False}
        aliases = self._reverse_aliases
        if aliases:
            caps["aliases"] = aliases
        return caps

    def send(self, *parts):
        """Queue a packet (the arguments, as a tuple) as ordinary, then call have_more()."""
        ordinary = self._ordinary_packets
        ordinary.append(parts)
        self.have_more()

    def send_now(self, *parts):
        """Queue a packet (the arguments, as a tuple) as priority, then call have_more().

        next_packet() hands out priority packets before ordinary ones.
        """
        priority = self._priority_packets
        priority.append(parts)
        self.have_more()

    def send_positional(self, packet):
        """Queue ``packet`` as ordinary and drop any pending mouse position packet."""
        self._mouse_position = None
        self._ordinary_packets.append(packet)
        self.have_more()

    def send_mouse_position(self, packet):
        """Store ``packet`` as the pending mouse position, replacing any previous one."""
        #only the latest position is kept: it is sent once the queues are empty
        self._mouse_position = packet
        self.have_more()

    def have_more(self):
        """Tell the protocol there are packets to send, if it has a source.

        This attribute is rebound to the protocol's own ``source_has_more``
        once a connection has been set up.
        """
        protocol = self._protocol
        if not protocol or not protocol.source:
            return
        protocol.source_has_more()

    def next_packet(self):
        """Return the next packet to send as ``(packet, None, None, has_more)``.

        Priority packets go first, then ordinary packets, then the pending
        mouse position.  ``packet`` is None when there is nothing to send;
        ``has_more`` tells if anything is left after this packet.
        """
        packet = None
        if self._priority_packets:
            packet = self._priority_packets.pop(0)
        elif self._ordinary_packets:
            packet = self._ordinary_packets.pop(0)
        elif self._mouse_position is not None:
            packet, self._mouse_position = self._mouse_position, None
        if packet is None:
            has_more = False
        else:
            queued = bool(self._priority_packets or self._ordinary_packets)
            has_more = queued or self._mouse_position is not None
        return packet, None, None, has_more


    def cleanup(self):
        """Close the protocol, if there is one, and forget about it."""
        protocol = self._protocol
        log("XpraClientBase.cleanup() protocol=%s", protocol)
        if protocol:
            protocol.close()
            self._protocol = None

    def glib_init(self):
        """Initialize glib threading support, when available.

        Best effort: an ImportError (raised anywhere in the block) and a
        missing ``threads_init`` are both ignored.
        """
        try:
            glib = import_glib()
            try:
                glib.threads_init()
            except AttributeError:
                #old versions of glib may not have this method
                pass
        except ImportError:
            #no glib: nothing to initialize
            pass

    def run(self):
        """Start the protocol created by setup_connection()."""
        protocol = self._protocol
        protocol.start()

    def quit(self, exit_code=0):
        """Abstract hook to terminate the client with ``exit_code``.

        Subclasses must override it.
        Raises NotImplementedError (an Exception subclass, so existing
        ``except Exception`` handlers still catch it).
        """
        raise NotImplementedError("override me!")

    def warn_and_quit(self, exit_code, warning):
        """Log ``warning`` at warn level, then call quit() with ``exit_code``."""
        log.warn(warning)
        self.quit(exit_code)

    def _process_disconnect(self, packet):
        """Handle a "disconnect" packet from the server: warn and quit.

        The exit code is EXIT_OK, unless the server never sent its hello.
        """
        #a single argument is shown as is, otherwise show them all:
        info = packet[1] if len(packet)==2 else packet[1:]
        caps = self.server_capabilities
        if caps is None or len(caps)==0:
            #server never sent hello to us - so disconnect is an error
            #(but we don't know which one - the info message may help)
            exit_code = EXIT_FAILURE
        else:
            exit_code = EXIT_OK
        self.warn_and_quit(exit_code, "server requested disconnect: %s" % info)

    def _process_connection_lost(self, packet):
        """Handler for Protocol.CONNECTION_LOST: warn and quit, ``packet`` is not used."""
        self.warn_and_quit(EXIT_CONNECTION_LOST, "Connection lost")

    def _process_challenge(self, packet):
        """Answer the server's authentication challenge with a new hello.

        packet layout, as read below:
          packet[1]: the server salt
          packet[2]: the server cipher capabilities (required with encryption)
          packet[3]: the digest to use, "hmac" or "xor" (optional; its
                     presence also means that the server supports client salt)

        Quits via warn_and_quit() if no password is available, if the
        encryption cannot be set up, or if the digest is unsupported or
        unsafe to use ("xor" without encryption).
        The password and the challenge response are never written to the log.
        """
        log("processing challenge: %s", packet[1:])
        if not self.password_file:
            self.warn_and_quit(EXIT_PASSWORD_REQUIRED, "server requires authentication, please provide a password")
            return
        password = self.load_password()
        if not password:
            self.warn_and_quit(EXIT_PASSWORD_FILE_ERROR, "failed to load password from file %s" % self.password_file)
            return
        salt = packet[1]
        if self.encryption:
            assert len(packet)>=3, "challenge does not contain encryption details to use for the response"
            server_cipher = packet[2]
            key = self.get_encryption_key()
            if key is None:
                self.warn_and_quit(EXIT_ENCRYPTION, "encryption key is missing")
                return
            if not self.set_server_encryption(server_cipher, key):
                return
        digest = "hmac"
        client_can_salt = len(packet)>=4
        client_salt = None
        if client_can_salt:
            #server supports client salt, and tells us which digest to use:
            digest = packet[3]
            client_salt = get_hex_uuid()+get_hex_uuid()
            #TODO: use some key stretching algorithm? (meh)
            salt = xor(salt, client_salt)
        if digest=="hmac":
            import hmac
            challenge_response = hmac.HMAC(password, salt).hexdigest()
        elif digest=="xor":
            #don't send XORed password unencrypted:
            if not self._protocol.cipher_out and not ALLOW_UNENCRYPTED_PASSWORDS:
                self.warn_and_quit(EXIT_ENCRYPTION, "server requested digest %s, cowardly refusing to use it without encryption" % digest)
                return
            challenge_response = xor(password, salt)
        else:
            self.warn_and_quit(EXIT_PASSWORD_REQUIRED, "server requested an unsupported digest: %s" % digest)
            return
        #do not log the password or the response:
        #with the "xor" digest, the salt and the response give the password away
        log("challenge response computed using the %s digest, salt=%s", digest, salt)
        self.password_sent = True
        self.send_hello(challenge_response, client_salt)

    def set_server_encryption(self, capabilities, key):
        """Configure the outgoing cipher from the server's cipher capabilities.

        Returns True on success.  If the server did not supply a cipher and
        iv, or the cipher is not in ENCRYPTION_CIPHERS, warn_and_quit() is
        called and False is returned.
        """
        def get(name, default=None):
            return capabilities.get(strtobytes(name), default)
        cipher, cipher_iv = get("cipher"), get("cipher.iv")
        key_salt, iterations = get("cipher.key_salt"), get("cipher.key_stretch_iterations")
        if not (cipher and cipher_iv):
            self.warn_and_quit(EXIT_ENCRYPTION, "the server does not use or support encryption/password, cannot continue with %s cipher" % self.encryption)
            return False
        if cipher not in ENCRYPTION_CIPHERS:
            allowed = ", ".join(ENCRYPTION_CIPHERS)
            self.warn_and_quit(EXIT_ENCRYPTION, "unsupported server cipher: %s, allowed ciphers: %s" % (cipher, allowed))
            return False
        self._protocol.set_cipher_out(cipher, cipher_iv, key, key_salt, iterations)
        return True


    def get_encryption_key(self):
        """Load the encryption key, stripped of any leading/trailing newlines.

        The key is read from the encryption keyfile; when that yields None
        and a password file is set, the password file is tried instead.
        Raises Exception if no key could be loaded.
        """
        key = load_binary_file(self.encryption_keyfile)
        if key is not None:
            return key.strip("\n\r")
        if self.password_file:
            key = load_binary_file(self.password_file)
            if key:
                log("used password file as encryption key")
        if key is None:
            raise Exception("failed to load encryption keyfile %s" % self.encryption_keyfile)
        return key.strip("\n\r")

    def load_password(self):
        """Return the password file's contents without surrounding newlines.

        Returns None if the file could not be loaded.  Only a mask of
        the same length as the password is written to the log.
        """
        password = load_binary_file(os.path.expanduser(self.password_file))
        if password is None:
            return None
        password = password.strip("\n\r")
        masked = "*" * len(password)
        log("password read from file %s is %s", self.password_file, masked)
        return password

    def _process_hello(self, packet):
        """Handle the server's "hello" packet: record and parse its capabilities.

        Quits with EXIT_NO_AUTHENTICATION if a password file is configured
        but no password was ever sent (the server did not challenge us), and
        with EXIT_FAILURE if the capabilities cannot be processed.
        """
        if not self.password_sent and self.password_file:
            self.warn_and_quit(EXIT_NO_AUTHENTICATION, "the server did not request our password")
            return
        try:
            self.server_capabilities = packet[1]
            log("processing hello from server: %s", self.server_capabilities)
            c = typedict(self.server_capabilities)
            self.parse_server_capabilities(c)
        #"except X as e" is accepted by Python 2.6+ and Python 3,
        #the old "except X, e" form is a syntax error on Python 3
        except Exception as e:
            self.warn_and_quit(EXIT_FAILURE, "error processing hello packet from server: %s" % e)
コード例 #8
0
ファイル: client_base.py プロジェクト: rudresh2319/Xpra
class XpraClientBase(object):
    """ Base class for Xpra clients.
        Provides the glue code for:
        * sending packets via Protocol
        * handling packets received via _process_packet
        For an actual implementation, look at:
        * GObjectXpraClient
        * xpra.client.gtk2.client
        * xpra.client.gtk3.client
    """

    def __init__(self):
        """Run defaults_init(), but only the first time this is called.

        The presence of the "exit_code" attribute marks an instance that
        has already been initialized.
        """
        if hasattr(self, "exit_code"):
            #this may be called more than once,
            #skip doing internal init again
            return
        self.defaults_init()

    def defaults_init(self):
        """Give every attribute its default value and register the packet handlers."""
        self.exit_code = None
        #connection options (overwritten by init(opts)):
        self.compression_level = 0
        self.display = self.username = None
        self.password_file = None
        self.password_sent = False
        self.encryption = self.encryption_keyfile = None
        self.quality, self.min_quality = -1, 0
        self.speed, self.min_speed = 0, -1
        #protocol stuff:
        self._protocol = None
        self._priority_packets, self._ordinary_packets = [], []
        self._mouse_position = None
        self._aliases, self._reverse_aliases = {}, {}
        #server state and caps:
        self.server_capabilities = None
        for remote_attr in ("machine_id", "uuid", "version", "revision",
                            "platform", "platform_release", "platform_platform",
                            "platform_linux_distribution"):
            setattr(self, "_remote_%s" % remote_attr, None)
        self.uuid = get_user_uuid()
        self.init_packet_handlers()

    def init(self, opts):
        """Copy the options from ``opts`` and, if DETECT_LEAKS is set, schedule leak reports."""
        option_names = (
            "compression_level", "display", "username", "password_file",
            "encryption", "encryption_keyfile",
            "quality", "min_quality", "speed", "min_speed",
            )
        for option_name in option_names:
            setattr(self, option_name, getattr(opts, option_name))

        if not DETECT_LEAKS:
            return
        from xpra.util import detect_leaks
        #classes to report in detail can be appended to this list,
        #ie: XShmImageWrapper from xpra.x11.bindings.ximage (ugly direct import)
        detailed = []
        print_leaks = detect_leaks(log, detailed)
        self.timeout_add(10*1000, print_leaks)


    def timeout_add(self, *args):
        """Abstract hook: subclasses must override it.

        Callers in this class pass ``(delay, callback, *callback_args)``.
        Raises NotImplementedError (an Exception subclass, so existing
        ``except Exception`` handlers still catch it).
        """
        raise NotImplementedError("override me!")

    def idle_add(self, *args):
        """Abstract hook: subclasses must override it.

        Callers in this class pass ``(callback, *callback_args)``.
        Raises NotImplementedError (an Exception subclass, so existing
        ``except Exception`` handlers still catch it).
        """
        raise NotImplementedError("override me!")

    def source_remove(self, *args):
        """Abstract hook: subclasses must override it.

        Raises NotImplementedError (an Exception subclass, so existing
        ``except Exception`` handlers still catch it).
        """
        raise NotImplementedError("override me!")


    def install_signal_handlers(self):
        """Install two-stage handlers for SIGINT and SIGTERM.

        The first signal schedules disconnect_and_quit() through
        ``timeout_add`` and swaps in the "deadly" handler; a further signal
        then runs ``cleanup()`` and terminates at once with ``os._exit``.
        """
        handled_signals = (signal.SIGINT, signal.SIGTERM)
        def report(template, signum):
            sys.stderr.write(template % SIGNAMES.get(signum, signum))
            sys.stderr.flush()
        def deadly_signal(signum, frame):
            report("\ngot deadly signal %s, exiting\n", signum)
            self.cleanup()
            os._exit(128 + signum)
        def app_signal(signum, frame):
            report("\ngot signal %s, exiting\n", signum)
            #from now on, another signal kills us straight away:
            for sig in handled_signals:
                signal.signal(sig, deadly_signal)
            reason = "exit on signal %s" % SIGNAMES.get(signum, signum)
            self.timeout_add(0, self.disconnect_and_quit, 128 + signum, reason)
        for sig in handled_signals:
            signal.signal(sig, app_signal)

    def disconnect_and_quit(self, exit_code, reason):
        """Try to send a "disconnect" packet with ``reason`` to the server, then quit.

        Without a usable protocol, quit() is called directly.  Otherwise the
        quit is scheduled both from the protocol's done callback and from a
        1000 timeout_add() fallback.
        """
        #make sure that we set the exit code early,
        #so the protocol shutdown won't set a different one:
        if self.exit_code is None:
            self.exit_code = exit_code
        log("disconnect_and_quit(%s, %s)", exit_code, reason)
        protocol = self._protocol
        if protocol is None or protocol._closed:
            #nobody to tell
            self.quit(exit_code)
            return
        def do_quit():
            log("disconnect_and_quit: do_quit()")
            self.idle_add(self.quit, exit_code)
        if protocol:
            disconnect_packet = ["disconnect", reason]
            protocol.flush_then_close(disconnect_packet, done_callback=do_quit)
        self.timeout_add(1000, do_quit)


    def client_type(self):
        """Return the client type string; this base implementation returns "Python".

        Meant to be overridden in subclasses.
        """
        return "Python"

    def get_scheduler(self):
        """Return the scheduler handed to Protocol; subclasses must implement this."""
        raise NotImplementedError()

    def setup_connection(self, conn):
        """Wrap ``conn`` in a Protocol instance and configure it.

        Registers the large packet types and the receive aliases, applies
        the compression level, enables the default encoder and compressor
        and rebinds ``self.have_more`` to the protocol's ``source_has_more``.
        When the connection has a timeout, verify_connected() is scheduled
        to run after it (plus EXTRA_TIMEOUT).
        """
        log.debug("setup_connection(%s)", conn)
        self._protocol = protocol = Protocol(self.get_scheduler(), conn, self.process_packet, self.next_packet)
        for packet_type in ("keymap-changed", "server-settings"):
            protocol.large_packets.append(packet_type)
        protocol.set_compression_level(self.compression_level)
        protocol.receive_aliases.update(self._aliases)
        protocol.enable_default_encoder()
        protocol.enable_default_compressor()
        self.have_more = protocol.source_has_more
        if conn.timeout>0:
            delay = (conn.timeout + EXTRA_TIMEOUT) * 1000
            self.timeout_add(delay, self.verify_connected)

    def init_packet_handlers(self):
        """Create the two packet-type -> handler tables used by this class.

        "hello" goes in ``_packet_handlers``; everything else registered
        here goes in ``_ui_packet_handlers``.
        """
        self._packet_handlers = {"hello": self._process_hello}
        ui_handlers = [
            ("challenge",               self._process_challenge),
            ("disconnect",              self._process_disconnect),
            ("set_deflate",             self._process_set_deflate),
            ("startup-complete",        self._process_startup_complete),
            (Protocol.CONNECTION_LOST,  self._process_connection_lost),
            (Protocol.GIBBERISH,        self._process_gibberish),
            (Protocol.INVALID,          self._process_invalid),
            ]
        self._ui_packet_handlers = dict(ui_handlers)

    def init_authenticated_packet_handlers(self):
        """Hook for subclasses to override; does nothing in this base class."""
        return None


    def init_aliases(self):
        """Number every known packet type, starting at 1.

        Fills ``_aliases`` (number -> packet type) and ``_reverse_aliases``
        (packet type -> number) from both handler tables.
        """
        packet_types = list(self._packet_handlers) + list(self._ui_packet_handlers)
        for alias, packet_type in enumerate(packet_types, 1):
            self._aliases[alias] = packet_type
            self._reverse_aliases[packet_type] = alias

    def send_hello(self, challenge_response=None, client_salt=None):
        """Build and queue the "hello" packet.

        challenge_response: answer to a server challenge, if one was received.
        client_salt: salt used for that answer, only sent with a response.

        Without a challenge response and with a password file configured,
        only the base capabilities are sent together with "challenge"=True,
        to ask the server for a challenge first.
        Any error while building the capabilities is logged and the client
        quits with EXIT_INTERNAL_ERROR.
        """
        try:
            hello = self.make_hello_base()
            if self.password_file and not challenge_response:
                #avoid sending the full hello: tell the server we want
                #a packet challenge first
                hello["challenge"] = True
            else:
                hello.update(self.make_hello())
        #"except X as e" is accepted by Python 2.6+ and Python 3,
        #the old "except X, e" form is a syntax error on Python 3
        except Exception as e:
            log.error("error preparing connection: %s", e)
            self.quit(EXIT_INTERNAL_ERROR)
            return
        if challenge_response:
            assert self.password_file
            hello["challenge_response"] = challenge_response
            if client_salt:
                hello["challenge_client_salt"] = client_salt
        log.debug("send_hello(%s) packet=%s", binascii.hexlify(b(challenge_response or "")), hello)
        self.send("hello", hello)
コード例 #9
0
ファイル: client_base.py プロジェクト: ljmljz/xpra
class XpraClientBase(FileTransferHandler):
    """ Base class for Xpra clients.
        Provides the glue code for:
        * sending packets via Protocol
        * handling packets received via _process_packet
        For an actual implementation, look at:
        * GObjectXpraClient
        * xpra.client.gtk2.client
        * xpra.client.gtk3.client
    """

    def __init__(self):
        """Initialize the file transfer base class, then run defaults_init() once.

        The presence of the "exit_code" attribute marks an instance whose
        defaults have already been initialized.
        """
        FileTransferHandler.__init__(self)
        if hasattr(self, "exit_code"):
            #this may be called more than once,
            #skip doing internal init again
            return
        self.defaults_init()

    def defaults_init(self):
        """One-off setup: helper initialization, attribute defaults, packet handlers.

        Also logs the whole process environment and runs sanity_checks().
        """
        #skip warning when running the client
        from xpra import child_reaper
        child_reaper.POLL_WARNING = False
        getChildReaper()
        crypto_backend_init()
        log("XpraClientBase.defaults_init() os.environ:")
        for env_name, env_value in os.environ.items():
            log(" %s=%s", env_name, nonl(env_value))
        #client state:
        self.exit_code = None
        self.exit_on_signal = False
        #connection attributes:
        self.compression_level = 0
        self.display = self.username = None
        self.password_file = None
        self.password_sent = False
        self.encryption = self.encryption_keyfile = None
        self.server_padding_options = [DEFAULT_PADDING]
        self.quality, self.min_quality = -1, 0
        self.speed, self.min_speed = 0, -1
        #printing:
        self.printer_attributes = []
        self.send_printers_pending = False
        self.exported_printers = None
        #protocol stuff:
        self._protocol = None
        self._priority_packets, self._ordinary_packets = [], []
        self._mouse_position = None
        self._aliases, self._reverse_aliases = {}, {}
        #server state and caps:
        self.server_capabilities = None
        for remote_attr in ("machine_id", "uuid", "version", "revision",
                            "platform", "platform_release", "platform_platform",
                            "platform_linux_distribution"):
            setattr(self, "_remote_%s" % remote_attr, None)
        self.uuid = get_user_uuid()
        self.init_packet_handlers()
        sanity_checks()

    def init(self, opts):
        """Copy the options from ``opts`` and initialize the file transfer handler.

        The encryption settings fall back to their "tcp_" variants when
        unset.  If DETECT_LEAKS is set, leak reports are also scheduled.
        """
        for option_name in ("compression_level", "display", "username", "password_file"):
            setattr(self, option_name, getattr(opts, option_name))
        self.encryption = opts.encryption or opts.tcp_encryption
        self.encryption_keyfile = opts.encryption_keyfile or opts.tcp_encryption_keyfile
        for option_name in ("quality", "min_quality", "speed", "min_speed"):
            setattr(self, option_name, getattr(opts, option_name))
        #printing and file transfer:
        FileTransferHandler.init(self, opts)

        if not DETECT_LEAKS:
            return
        from xpra.util import detect_leaks
        #classes to report in detail can be appended to this list,
        #ie: XShmImageWrapper from xpra.x11.bindings.ximage (ugly direct import)
        detailed = []
        print_leaks = detect_leaks(log, detailed)
        self.timeout_add(10*1000, print_leaks)


    def timeout_add(self, *args):
        """Abstract hook: subclasses must override it.

        Callers in this class pass ``(delay, callback, *callback_args)``.
        Raises NotImplementedError (an Exception subclass, so existing
        ``except Exception`` handlers still catch it).
        """
        raise NotImplementedError("override me!")

    def idle_add(self, *args):
        """Abstract hook: subclasses must override it.

        Callers in this class pass ``(callback, *callback_args)``.
        Raises NotImplementedError (an Exception subclass, so existing
        ``except Exception`` handlers still catch it).
        """
        raise NotImplementedError("override me!")

    def source_remove(self, *args):
        """Abstract hook: subclasses must override it.

        Raises NotImplementedError (an Exception subclass, so existing
        ``except Exception`` handlers still catch it).
        """
        raise NotImplementedError("override me!")


    def install_signal_handlers(self):
        """Install two-stage signal handlers (SIGTERM, and SIGINT on Python 2 only).

        The first signal runs signal_cleanup(), schedules
        signal_disconnect_and_quit() through ``timeout_add`` and swaps in the
        "deadly" handler for both signals; a further signal then runs
        ``cleanup()`` and terminates at once with ``os._exit``.
        """
        def report(template, signum):
            sys.stderr.write(template % SIGNAMES.get(signum, signum))
            sys.stderr.flush()
        def deadly_signal(signum, frame):
            report("\ngot deadly signal %s, exiting\n", signum)
            self.cleanup()
            os._exit(128 + signum)
        def app_signal(signum, frame):
            report("\ngot signal %s, exiting\n", signum)
            #from now on, another signal kills us straight away:
            for sig in (signal.SIGINT, signal.SIGTERM):
                signal.signal(sig, deadly_signal)
            self.signal_cleanup()
            reason = "exit on signal %s" % SIGNAMES.get(signum, signum)
            self.timeout_add(0, self.signal_disconnect_and_quit, 128 + signum, reason)
        python2 = sys.version_info[0]<3
        if python2:
            #not installed on Python 3: breaks GTK3..
            signal.signal(signal.SIGINT, app_signal)
        signal.signal(signal.SIGTERM, app_signal)

    def signal_disconnect_and_quit(self, exit_code, reason):
        """Exit in response to a signal, more forcefully the second time around.

        First call: sets ``exit_on_signal`` and queues disconnect_and_quit(),
        quit() and exit() through idle_add().
        Later calls: run the same steps directly, then ``os._exit``.
        """
        log("signal_disconnect_and_quit(%s, %s) exit_on_signal=%s", exit_code, reason, self.exit_on_signal)
        if self.exit_on_signal:
            #warning: this will run cleanup code from the signal handler
            self.disconnect_and_quit(exit_code, reason)
            self.quit(exit_code)
            self.exit()
            os._exit(exit_code)
            return
        #if we get another signal, we'll try to exit without idle_add...
        self.exit_on_signal = True
        self.idle_add(self.disconnect_and_quit, exit_code, reason)
        self.idle_add(self.quit, exit_code)
        self.idle_add(self.exit)

    def signal_cleanup(self):
        """Placeholder for cleanup that can run from the signal handler.

        Meant for non UI thread stuff; does nothing in this base class.
        """
        return None

    def disconnect_and_quit(self, exit_code, reason):
        """Try to send a "disconnect" packet with ``reason`` to the server, then quit.

        Without a usable protocol, quit() is called directly.  Otherwise
        quit() is scheduled from the protocol's done callback, and again
        from a 1000 timeout_add() fallback.
        """
        #make sure that we set the exit code early,
        #so the protocol shutdown won't set a different one:
        if self.exit_code is None:
            self.exit_code = exit_code
        log("disconnect_and_quit(%s, %s)", exit_code, reason)
        protocol = self._protocol
        if protocol is None or protocol._closed:
            #nobody to tell
            self.quit(exit_code)
            return
        def protocol_closed():
            log("disconnect_and_quit: protocol_closed()")
            self.idle_add(self.quit, exit_code)
        if protocol:
            disconnect_packet = ["disconnect", reason]
            protocol.flush_then_close(disconnect_packet, done_callback=protocol_closed)
        self.timeout_add(1000, self.quit, exit_code)

    def exit(self):
        """Terminate the interpreter through ``sys.exit()`` (raises SystemExit)."""
        sys.exit()


    def client_type(self):
        """Return the client type string; this base implementation returns "Python".

        Sent to the server as the "client_type" capability.
        Meant to be overridden in subclasses.
        """
        return "Python"

    def get_scheduler(self):
        """Return the scheduler handed to Protocol; subclasses must implement this."""
        raise NotImplementedError()

    def setup_connection(self, conn):
        """Wrap ``conn`` in a Protocol instance and configure it.

        Registers the large packet types and the receive aliases, applies
        the compression level and enables the default encoder and compressor.
        With encryption and ENCRYPT_FIRST_PACKET, the outgoing cipher is set
        up straight away using the default iv, salt and iterations.
        ``self.have_more`` is rebound to the protocol's ``source_has_more``
        and, when the connection has a timeout, verify_connected() is
        scheduled to run after it (plus EXTRA_TIMEOUT).
        """
        netlog("setup_connection(%s) timeout=%s", conn, conn.timeout)
        self._protocol = protocol = Protocol(self.get_scheduler(), conn, self.process_packet, self.next_packet)
        for packet_type in ("keymap-changed", "server-settings", "logging"):
            protocol.large_packets.append(packet_type)
        protocol.set_compression_level(self.compression_level)
        protocol.receive_aliases.update(self._aliases)
        protocol.enable_default_encoder()
        protocol.enable_default_compressor()
        if self.encryption and ENCRYPT_FIRST_PACKET:
            key = self.get_encryption_key()
            protocol.set_cipher_out(self.encryption, DEFAULT_IV, key, DEFAULT_SALT, DEFAULT_ITERATIONS, INITIAL_PADDING)
        self.have_more = protocol.source_has_more
        if conn.timeout>0:
            delay = (conn.timeout + EXTRA_TIMEOUT) * 1000
            self.timeout_add(delay, self.verify_connected)
        netlog("setup_connection(%s) protocol=%s", conn, protocol)


    def remove_packet_handlers(self, *keys):
        """Remove the handlers registered for each of ``keys``, from both tables.

        Keys without a registered handler are ignored.
        """
        for k in keys:
            for handlers in (self._packet_handlers, self._ui_packet_handlers):
                #pop() with a default is a no-op for missing keys and,
                #unlike a bare "except:", does not hide unrelated errors
                handlers.pop(k, None)

    def set_packet_handlers(self, to, defs):
        """ registers the handlers from 'defs' in the table 'to',
            after removing any handler already registered with the same
            key from either table
            (which can be useful for subclasses, not here)
        """
        log("set_packet_handlers(%s, %s)", to, defs)
        self.remove_packet_handlers(*defs.keys())
        to.update(defs)

    def init_packet_handlers(self):
        """Create both handler tables and register the handlers of this class.

        "hello" goes in ``_packet_handlers``; everything else registered
        here goes in ``_ui_packet_handlers``.
        """
        self._packet_handlers = {}
        self._ui_packet_handlers = {}
        plain_handlers = {"hello" : self._process_hello}
        ui_handlers = {
            "challenge":                self._process_challenge,
            "disconnect":               self._process_disconnect,
            "set_deflate":              self._process_set_deflate,
            "startup-complete":         self._process_startup_complete,
            Protocol.CONNECTION_LOST:   self._process_connection_lost,
            Protocol.GIBBERISH:         self._process_gibberish,
            Protocol.INVALID:           self._process_invalid,
            }
        self.set_packet_handlers(self._packet_handlers, plain_handlers)
        self.set_packet_handlers(self._ui_packet_handlers, ui_handlers)

    def init_authenticated_packet_handlers(self):
        """Register the "send-file" handler in ``_packet_handlers``."""
        file_handlers = {"send-file" : self._process_send_file}
        self.set_packet_handlers(self._packet_handlers, file_handlers)


    def init_aliases(self):
        """Number every known packet type, starting at 1.

        Fills ``_aliases`` (number -> packet type) and ``_reverse_aliases``
        (packet type -> number) from both handler tables.
        """
        packet_types = list(self._packet_handlers) + list(self._ui_packet_handlers)
        for alias, packet_type in enumerate(packet_types, 1):
            self._aliases[alias] = packet_type
            self._reverse_aliases[packet_type] = alias

    def has_password(self):
        """Return a true value if a password is available.

        That is the password file if one is set, otherwise the value of the
        XPRA_PASSWORD environment variable (None when it is not defined).
        """
        password_file = self.password_file
        if password_file:
            return password_file
        return os.environ.get('XPRA_PASSWORD')

    def send_hello(self, challenge_response=None, client_salt=None):
        """Build and queue the "hello" packet.

        challenge_response: answer to a server challenge, if one was received.
        client_salt: salt used for that answer, only sent with a response.

        Without a challenge response and with a password available, only
        the base capabilities are sent together with "challenge"=True, to
        ask the server for a challenge first.
        Any error while building the capabilities is logged and the client
        quits with EXIT_INTERNAL_ERROR.
        """
        try:
            hello = self.make_hello_base()
            challenge_wanted = self.has_password() and not challenge_response
            if not challenge_wanted:
                hello.update(self.make_hello())
            else:
                #avoid sending the full hello: tell the server we want
                #a packet challenge first
                hello["challenge"] = True
        except InitExit as e:
            #expected failure: no need for a backtrace
            log.error("error preparing connection:")
            log.error(" %s", e)
            self.quit(EXIT_INTERNAL_ERROR)
            return
        except Exception as e:
            log.error("error preparing connection: %s", e, exc_info=True)
            self.quit(EXIT_INTERNAL_ERROR)
            return
        if challenge_response:
            assert self.has_password(), "got a password challenge response but we don't have a password! (malicious or broken server?)"
            hello["challenge_response"] = challenge_response
            if client_salt:
                hello["challenge_client_salt"] = client_salt
        hex_response = binascii.hexlify(strtobytes(challenge_response or ""))
        log("send_hello(%s) packet=%s", hex_response, hello)
        self.send("hello", hello)

    def verify_connected(self):
        """Quit with EXIT_TIMEOUT if no server capabilities have been recorded yet."""
        if self.server_capabilities is not None:
            return
        #server has not said hello yet
        self.warn_and_quit(EXIT_TIMEOUT, "connection timed out")


    def make_hello_base(self):
        """Build the capabilities that are sent with every hello packet.

        Starts from the flattened network capabilities and adds the client
        identification, the file transfer features, the build information
        and, when encryption is enabled, the "cipher" settings (which are
        also applied to the incoming side of the protocol).

        Returns the capabilities dictionary, or None after calling
        warn_and_quit() if the encryption key is reported as missing.
        """
        capabilities = flatten_dict(get_network_caps())
        identification = {
                "version"               : local_version,
                "encoding.generic"      : True,
                "namespace"             : True,
                "hostname"              : socket.gethostname(),
                "uuid"                  : self.uuid,
                "username"              : self.username,
                "name"                  : get_name(),
                "client_type"           : self.client_type(),
                "python.version"        : sys.version_info[:3],
                "compression_level"     : self.compression_level,
                "argv"                  : sys.argv,
                }
        capabilities.update(identification)
        capabilities.update(self.get_file_transfer_features())
        if self.display:
            capabilities["display"] = self.display
        updict(capabilities, "build", self.get_version_info())
        machine_id = get_machine_id()
        if machine_id:
            capabilities["machine_id"] = machine_id

        if not self.encryption:
            return capabilities

        assert self.encryption in ENCRYPTION_CIPHERS
        iv = get_iv()
        key_salt = get_salt()
        iterations = get_iterations()
        padding = choose_padding(self.server_padding_options)
        cipher_caps = {
                ""                      : self.encryption,
                "iv"                    : iv,
                "key_salt"              : key_salt,
                "key_stretch_iterations": iterations,
                "padding"               : padding,
                "padding.options"       : PADDING_OPTIONS,
                }
        updict(capabilities, "cipher", cipher_caps)
        key = self.get_encryption_key()
        if key is None:
            self.warn_and_quit(EXIT_ENCRYPTION, "encryption key is missing")
            return
        self._protocol.set_cipher_in(self.encryption, iv, key, key_salt, iterations, padding)
        cipher_only = dict((k,v) for k,v in capabilities.items() if k.startswith("cipher"))
        netlog("encryption capabilities: %s", cipher_only)
        return capabilities

    def get_version_info(self):
        """Return the result of the module-level get_version_info() helper."""
        return get_version_info()

    def make_hello(self):
        """Return the extra capabilities sent with a full hello packet.

        "randr_notify" and "windows" are always disabled here (only client.py
        cares about them); the packet aliases are included when there are any.
        """
        capabilities = dict(randr_notify=False, windows=False)
        aliases = self._reverse_aliases
        if aliases:
            capabilities["aliases"] = aliases
        return capabilities

    def compressed_wrapper(self, datatype, data, level=5):
        """Wrap `data` for sending, compressing it when that is worthwhile.

        Compression is attempted only when `level` is positive, the data is
        at least 256 bytes long and at least one of zlib / lz4 / lzo is both
        listed in `self.server_compressors` and enabled locally. The
        compressed form is used only if it is actually smaller; otherwise
        the data is returned in a raw `Compressed` wrapper.
        """
        #FIXME: ugly assumptions here, should pass by name!
        server = self.server_compressors
        zlib = "zlib" in server and compression.use_zlib
        lz4 = "lz4" in server and compression.use_lz4
        lzo = "lzo" in server and compression.use_lzo
        worth_trying = level>0 and len(data)>=256 and (zlib or lz4 or lzo)
        if worth_trying:
            wrapped = compression.compressed_wrapper(datatype, data, level=level, zlib=zlib, lz4=lz4, lzo=lzo, can_inline=False)
            if len(wrapped)<len(data):
                #the compressed version is smaller, use it:
                return wrapped
        #we can't compress, so at least avoid warnings in the protocol layer:
        return compression.Compressed("raw %s" % datatype, data, can_inline=True)


    def send(self, *parts):
        """Queue a packet (given as its parts) as an ordinary packet,
        then signal that there is data waiting to be sent."""
        queue = self._ordinary_packets
        queue.append(parts)
        self.have_more()

    def send_now(self, *parts):
        """Queue a packet (given as its parts) as a priority packet,
        then signal that there is data waiting to be sent."""
        queue = self._priority_packets
        queue.append(parts)
        self.have_more()

    def send_positional(self, packet):
        """Queue `packet` as an ordinary packet and drop any pending
        mouse position packet, then signal that there is data to send."""
        self._mouse_position = None
        self._ordinary_packets.append(packet)
        self.have_more()

    def send_mouse_position(self, packet):
        """Store `packet` as the pending mouse position (replacing any
        previous one), then signal that there is data to send."""
        setattr(self, "_mouse_position", packet)
        self.have_more()

    def have_more(self):
        """Tell the protocol that packets are waiting, provided there is a
        protocol and it has a truthy `source` attribute."""
        #this function is overridden in setup_protocol()
        protocol = self._protocol
        if not protocol or not protocol.source:
            return
        protocol.source_has_more()

    def next_packet(self):
        """Pick the next packet to send.

        Priority packets go first, then ordinary packets, then the pending
        mouse position (which is cleared once taken).

        Returns a tuple ``(packet, None, None, has_more)`` where `packet` is
        None if nothing is waiting and `has_more` tells whether anything is
        still queued after this packet.
        """
        packet = None
        if self._priority_packets:
            packet = self._priority_packets.pop(0)
        elif self._ordinary_packets:
            packet = self._ordinary_packets.pop(0)
        elif self._mouse_position is not None:
            packet, self._mouse_position = self._mouse_position, None
        if packet is None:
            has_more = False
        else:
            has_more = bool(self._priority_packets) \
                or bool(self._ordinary_packets) \
                or self._mouse_position is not None
        return packet, None, None, has_more


    def cleanup(self):
        """Release the client's resources: run the child reaper cleanup,
        close and forget the protocol (if any), clean up printing and
        finally dump all the frames."""
        reaper_cleanup()
        protocol = self._protocol
        log("XpraClientBase.cleanup() protocol=%s", protocol)
        if protocol:
            close = protocol.close
            log("calling %s", close)
            close()
            self._protocol = None
        self.cleanup_printing()
        log("cleanup done")
        dump_all_frames()


    def glib_init(self):
        """Initialize glib threading support when glib is available.

        A missing glib module is silently ignored, and so is a glib without
        the threads_init() method.
        """
        try:
            glib = import_glib()
        except ImportError:
            return
        try:
            glib.threads_init()
        except (AttributeError, ImportError):
            #old versions of glib may not have this method
            pass

    def run(self):
        """Start the protocol (self._protocol must have been set beforehand)."""
        self._protocol.start()

    def quit(self, exit_code=0):
        """Placeholder which always raises: must be overridden to be usable."""
        raise Exception("override me!")

    def warn_and_quit(self, exit_code, message):
        """Log `message` as a warning, then call quit() with `exit_code`."""
        log.warn(message)
        self.quit(exit_code)


    def _process_disconnect(self, packet):
        """Handle a "disconnect" packet from the server.

        The packet carries the reason followed by optional extra information,
        ie: ("disconnect", "version error", "incompatible version").
        Disconnections that arrive before the server's hello, or whose reason
        is classified as an error, end the client with EXIT_FAILURE; any other
        disconnection ends it with EXIT_OK.
        """
        reason = bytestostr(packet[1])
        details = packet[2:]
        summary = nonl(reason)
        if len(details):
            summary += " (%s)" % (", ".join([nonl(bytestostr(x)) for x in details]))
        caps = self.server_capabilities
        if caps is None or len(caps)==0:
            #server never sent hello to us - so disconnect is an error
            #(but we don't know which one - the info message may help)
            log.warn("server failure: disconnected before the session could be established")
        elif disconnect_is_an_error(reason):
            log.warn("server failure: %s", reason)
        else:
            if self.exit_code is None:
                #we're not in the process of exiting already,
                #tell the user why the server is disconnecting us
                log.info("server requested disconnect: %s", summary)
            self.quit(EXIT_OK)
            return
        #both failure cases end up here:
        self.warn_and_quit(EXIT_FAILURE, "server requested disconnect: %s" % summary)

    def _process_connection_lost(self, packet):
        """Handle the loss of the connection.

        If nothing at all was received from the peer, log hints about the
        likely causes. Then quit with EXIT_CONNECTION_LOST unless the exit
        code is already 0.
        """
        protocol = self._protocol
        if protocol and protocol.input_raw_packetcount==0:
            info = protocol.get_info()
            compressor = info.get("compression", "unknown")
            encoder = info.get("encoder", "unknown")
            netlog.warn("failed to receive anything, not an xpra server?")
            netlog.warn("  could also be the wrong username, password or port")
            if not (compressor=="unknown" and encoder=="unknown"):
                netlog.warn("  or maybe this server does not support '%s' compression or '%s' packet encoding?", compressor, encoder)
        if self.exit_code==0:
            return
        self.warn_and_quit(EXIT_CONNECTION_LOST, "Connection lost")

    def _process_challenge(self, packet):
        """Answer an authentication challenge sent by the server.

        Packet layout as read by this method: packet[1] is the server salt,
        packet[2] holds the server cipher capabilities (required when
        encryption is enabled) and packet[3] names the digest to use
        ("hmac" or "xor").

        The password is loaded, the server salt is combined with a freshly
        generated client salt, and the resulting challenge response is sent
        back in a new hello packet. Every failure path tells the server and
        quits via disconnect_and_quit() or warn_and_quit().
        """
        authlog("processing challenge: %s", packet[1:])
        def warn_server_and_exit(code, message, server_message="authentication failed"):
            #log the problem locally, send `server_message` to the server and quit with `code`:
            authlog.error("Error: authentication failed:")
            authlog.error(" %s", message)
            self.disconnect_and_quit(code, server_message)
        if not self.has_password():
            warn_server_and_exit(EXIT_PASSWORD_REQUIRED, "this server requires authentication, please provide a password", "no password available")
            return
        password = self.load_password()
        if not password:
            warn_server_and_exit(EXIT_PASSWORD_FILE_ERROR, "failed to load password from file %s" % self.password_file, "no password available")
            return
        salt = packet[1]
        if self.encryption:
            assert len(packet)>=3, "challenge does not contain encryption details to use for the response"
            server_cipher = typedict(packet[2])
            key = self.get_encryption_key()
            #NOTE(review): get_encryption_key() as defined in this class raises InitExit
            #instead of returning None, so this branch looks unreachable - confirm
            if key is None:
                warn_server_and_exit(EXIT_ENCRYPTION, "the server does not use any encryption", "client requires encryption")
                return
            if not self.set_server_encryption(server_cipher, key):
                return
        #all server versions support a client salt,
        #they also tell us which digest to use:
        digest = packet[3]
        client_salt = get_hex_uuid()+get_hex_uuid()
        #TODO: use some key stretching algorithm? (meh)
        #prefer the compiled xor implementation, any failure falls back to xor():
        try:
            from xpra.codecs.xor.cyxor import xor_str           #@UnresolvedImport
            salt = xor_str(salt, client_salt)
        except:
            salt = xor(salt, client_salt)
        if digest==b"hmac":
            import hmac, hashlib
            password = strtobytes(password)
            salt = strtobytes(salt)
            challenge_response = hmac.HMAC(password, salt, digestmod=hashlib.md5).hexdigest()
        elif digest==b"xor":
            #don't send XORed password unencrypted:
            if not self._protocol.cipher_out and not ALLOW_UNENCRYPTED_PASSWORDS:
                warn_server_and_exit(EXIT_ENCRYPTION, "server requested digest %s, cowardly refusing to use it without encryption" % digest, "invalid digest")
                return
            challenge_response = xor(password, salt)
        else:
            warn_server_and_exit(EXIT_PASSWORD_REQUIRED, "server requested an unsupported digest: %s" % digest, "invalid digest")
            return
        #NOTE(review): digest is always truthy at this point (it matched "hmac" or "xor"),
        #and this debug statement writes the hex-encoded password to the log
        if digest:
            authlog("%s(%s, %s)=%s", digest, binascii.hexlify(password), binascii.hexlify(salt), challenge_response)
        self.password_sent = True
        #a challenge is only answered once:
        self.remove_packet_handlers("challenge")
        self.send_hello(challenge_response, client_salt)

    def set_server_encryption(self, caps, key):
        """Configure the outgoing cipher from the server's capabilities.

        `caps` is queried (strget / intget / strlistget) for the "cipher.*"
        values sent by the server, `key` is the encryption key to use.
        The padding options advertised by the server are recorded in
        `self.server_padding_options` before any validation takes place.

        Returns True once the protocol's outgoing cipher has been set, and
        False if the cipher settings are missing or unsupported (in which
        case warn_and_quit() has been called) or if there is no protocol.
        """
        cipher = caps.strget("cipher")
        cipher_iv = caps.strget("cipher.iv")
        key_salt = caps.strget("cipher.key_salt")
        iterations = caps.intget("cipher.key_stretch_iterations")
        padding = caps.strget("cipher.padding", DEFAULT_PADDING)
        #server may tell us what it supports,
        #either from hello response or from challenge packet:
        self.server_padding_options = caps.strlistget("cipher.padding.options", [DEFAULT_PADDING])
        if not cipher or not cipher_iv:
            self.warn_and_quit(EXIT_ENCRYPTION, "the server does not use or support encryption/password, cannot continue with %s cipher" % self.encryption)
            return False
        if cipher not in ENCRYPTION_CIPHERS:
            self.warn_and_quit(EXIT_ENCRYPTION, "unsupported server cipher: %s, allowed ciphers: %s" % (cipher, ", ".join(ENCRYPTION_CIPHERS)))
            return False
        if padding not in ALL_PADDING_OPTIONS:
            #the list shown here contains padding options, not ciphers:
            self.warn_and_quit(EXIT_ENCRYPTION, "unsupported server cipher padding: %s, allowed paddings: %s" % (padding, ", ".join(ALL_PADDING_OPTIONS)))
            return False
        p = self._protocol
        if not p:
            return False
        p.set_cipher_out(cipher, cipher_iv, key, key_salt, iterations, padding)
        return True


    def get_encryption_key(self):
        """Return the encryption key, stripped of newline characters.

        The key file is tried first, then the XPRA_ENCRYPTION_KEY
        environment variable. Raises InitExit if neither yields a key.
        """
        key = load_binary_file(self.encryption_keyfile) or os.environ.get('XPRA_ENCRYPTION_KEY')
        if key:
            return key.strip("\n\r")
        raise InitExit(1, "no encryption key")

    def load_password(self):
        """Return the password to authenticate with.

        Without a password file, the XPRA_PASSWORD environment variable is
        used. Otherwise the file is read (after expanding "~") and its
        contents returned; only a masked version is written to the log.
        """
        if self.password_file:
            filename = os.path.expanduser(self.password_file)
            password = load_binary_file(filename)
            masked = "*" * len(password or "")
            netlog("password read from file %s is %s", self.password_file, masked)
            return password
        return os.environ.get('XPRA_PASSWORD')

    def _process_hello(self, packet):
        """Handle the server's "hello" packet.

        Quits with EXIT_NO_AUTHENTICATION if we hold a password that the
        server never asked for. Otherwise the capabilities found in
        packet[1] are recorded and parsed; any error while doing so ends
        the client with EXIT_FAILURE.
        """
        password_ignored = not self.password_sent and self.has_password()
        if password_ignored:
            self.warn_and_quit(EXIT_NO_AUTHENTICATION, "the server did not request our password")
            return
        try:
            caps = typedict(packet[1])
            self.server_capabilities = caps
            netlog("processing hello from server: %s", caps)
            self.server_connection_established()
        except Exception as e:
            netlog.info("error in hello packet", exc_info=True)
            self.warn_and_quit(EXIT_FAILURE, "error processing hello packet from server: %s" % e)

    def capsget(self, capabilities, key, default):
        """Look up `key` (converted with strtobytes) in `capabilities`.

        Returns `default` if the key is missing. On Python 3, a bytes value
        is converted to a string with bytestostr() before being returned.
        """
        v = capabilities.get(strtobytes(key), default)
        #compare the version numerically: comparing the sys.version string
        #is lexicographic and would misfire for a major version of 10 or more
        if sys.version_info[0] >= 3 and isinstance(v, bytes):
            v = bytestostr(v)
        return v


    def server_connection_established(self):
        """Parse the server capabilities once its hello has been received.

        The version, server, network and encryption capabilities are parsed
        in that order and the first failure aborts with False. On success,
        the printing and logging capabilities are parsed too, the
        authenticated packet handlers are registered and True is returned.
        """
        netlog("server_connection_established()")
        mandatory_steps = (
            (self.parse_version_capabilities,    "server_connection_established() failed version capabilities"),
            (self.parse_server_capabilities,     "server_connection_established() failed server capabilities"),
            (self.parse_network_capabilities,    "server_connection_established() failed network capabilities"),
            (self.parse_encryption_capabilities, "server_connection_established() failed encryption capabilities"),
            )
        for parse, failure_message in mandatory_steps:
            if not parse():
                netlog(failure_message)
                return False
        self.parse_printing_capabilities()
        self.parse_logging_capabilities()
        netlog("server_connection_established() adding authenticated packet handlers")
        self.init_authenticated_packet_handlers()
        return True

    def parse_logging_capabilities(self):
        """Hook called when the connection is established: does nothing here."""
        pass

    def parse_printing_capabilities(self):
        """Set up printer forwarding if both sides have printing enabled.

        Records the printer attributes requested by the server (defaulting
        to "printer-info" and "device-uri") and schedules init_printing()
        to run after 1000 milliseconds.
        """
        if not self.printing:
            return
        caps = self.server_capabilities
        if not caps.boolget("printing"):
            return
        self.printer_attributes = caps.strlistget("printer.attributes", ["printer-info", "device-uri"])
        self.timeout_add(1000, self.init_printing)

    def init_printing(self):
        """Initialize the platform printing support, then send the printers.

        The platform code is given send_printers as its callback.
        A failure in either step is logged and disables printing.
        """
        try:
            from xpra.platform.printing import init_printing as platform_init_printing
            printlog("init_printing=%s", platform_init_printing)
            platform_init_printing(self.send_printers)
        except Exception as e:
            log.error("Error initializing printing support:")
            log.error(" %s", e)
            self.printing = False
        #the list of printers is sent even if the initialization above failed:
        try:
            self.do_send_printers()
        except Exception:
            log.error("Error sending the list of printers:", exc_info=True)
            self.printing = False

    def cleanup_printing(self):
        """Shut down the platform printing support if printing is enabled.
        Errors are logged as warnings and otherwise ignored."""
        if self.printing:
            try:
                from xpra.platform.printing import cleanup_printing as platform_cleanup_printing
                printlog("cleanup_printing=%s", platform_cleanup_printing)
                platform_cleanup_printing()
            except Exception:
                log.warn("failed to cleanup printing subsystem", exc_info=True)

    def send_printers(self, *args):
        """Schedule do_send_printers() to run in 500 milliseconds, unless a
        run is already pending. Any arguments are ignored."""
        #dbus can fire dozens of times for a single printer change
        #so we wait a bit and fire via a timer to try to batch things together:
        if not self.send_printers_pending:
            self.send_printers_pending = True
            self.timeout_add(500, self.do_send_printers)

    def do_send_printers(self):
        """Send the list of local printers to the server if it has changed.

        The printers reported by the platform code are filtered: printers
        whose device-uri starts with "xpraforwarder" and printers whose
        "printer-state" equals 5 are skipped, and only the attributes listed
        in self.printer_attributes are kept (plus the mimetypes).
        Nothing is sent when the result is identical to what was last sent.
        All exceptions are caught and logged.
        """
        try:
            self.send_printers_pending = False
            from xpra.platform.printing import get_printers, get_mimetypes
            printers = get_printers()
            printlog("do_send_printers() found printers=%s", printers)
            #remove xpra-forwarded printers to avoid loops and multi-forwards,
            #also ignore stopped printers
            #and only keep the attributes that the server cares about (self.printer_attributes)
            exported_printers = {}
            def used_attrs(d):
                #filter attributes so that we only compare things that are actually used
                #NOTE(review): an empty `d` is returned as-is (not copied),
                #so the caller's "mimetypes" assignment would modify it in place
                if not d:
                    return d
                return dict((k,v) for k,v in d.items() if k in self.printer_attributes)
            for k,v in printers.items():
                device_uri = v.get("device-uri", "")
                if device_uri:
                    #this is cups specific.. oh well
                    printlog("do_send_printers() device-uri(%s)=%s", k, device_uri)
                    if device_uri.startswith("xpraforwarder"):
                        printlog("do_send_printers() skipping xpra forwarded printer=%s", k)
                        continue
                state = v.get("printer-state")
                #"3" if the destination is idle,
                #"4" if the destination is printing a job,
                #"5" if the destination is stopped.
                if state==5:
                    printlog("do_send_printers() skipping stopped printer=%s", k)
                    continue
                attrs = used_attrs(v)
                #add mimetypes:
                #NOTE(review): get_mimetypes() is called once per exported printer
                attrs["mimetypes"] = get_mimetypes()
                exported_printers[k.encode("utf8")] = attrs
            if self.exported_printers is None:
                #not been sent yet, ensure we can use the dict below:
                self.exported_printers = {}
            elif exported_printers==self.exported_printers:
                printlog("do_send_printers() exported printers unchanged: %s", self.exported_printers)
                return
            #show summary of what has changed:
            added = [k for k in exported_printers.keys() if k not in self.exported_printers]
            if added:
                printlog("do_send_printers() new printers: %s", added)
            removed = [k for k in self.exported_printers.keys() if k not in exported_printers]
            if removed:
                printlog("do_send_printers() printers removed: %s", removed)
            modified = [k for k,v in exported_printers.items() if self.exported_printers.get(k)!=v and k not in added]
            if modified:
                printlog("do_send_printers() printers modified: %s", modified)
            printlog("do_send_printers() printers=%s", exported_printers.keys())
            printlog("do_send_printers() exported printers=%s", ", ".join(str(x) for x in exported_printers.keys()))
            #remember what was sent so the next call can detect changes:
            self.exported_printers = exported_printers
            self.send("printers", self.exported_printers)
        #NOTE(review): bare except, this also catches SystemExit and KeyboardInterrupt
        except:
            log.error("do_send_printers()", exc_info=True)


    def parse_version_capabilities(self):
        """Record the server's identification and check its version.

        Stores the machine id, uuid, version, revision and platform details
        found in the server capabilities on `self`.
        Returns True if version_compat_check() accepts the remote version,
        otherwise calls warn_and_quit() with EXIT_INCOMPATIBLE_VERSION and
        returns False.
        """
        caps = self.server_capabilities
        self._remote_machine_id = caps.strget("machine_id")
        self._remote_uuid = caps.strget("uuid")
        self._remote_version = caps.strget("build.version", caps.strget("version"))
        self._remote_revision = caps.strget("build.revision", caps.strget("revision"))
        self._remote_platform = caps.strget("platform")
        self._remote_platform_release = caps.strget("platform.release")
        self._remote_platform_platform = caps.strget("platform.platform")
        #linux distribution is a tuple of different types, ie: ('Linux Fedora' , 20, 'Heisenbug')
        distribution = caps.listget("platform.linux_distribution")
        if distribution and len(distribution)==3:
            #integers are kept as they are, everything else is converted to a string:
            self._remote_platform_linux_distribution = [
                part if type(part)==int else bytestostr(part)
                for part in distribution]
        version_error = version_compat_check(self._remote_version)
        if version_error is None:
            return True
        self.warn_and_quit(EXIT_INCOMPATIBLE_VERSION, "incompatible remote version '%s': %s" % (self._remote_version, version_error))
        return False

    def parse_server_capabilities(self):
        """Hook called when the connection is established: always returns True here."""
        return True

    def parse_network_capabilities(self):
        """Enable the packet encoder and compressor matching the server's
        capabilities.

        Returns False if there is no protocol or if no encoder could be
        enabled, True otherwise.
        """
        caps = self.server_capabilities
        protocol = self._protocol
        if protocol and protocol.enable_encoder_from_caps(caps):
            protocol.enable_compressor_from_caps(caps)
            return True
        return False

    def parse_encryption_capabilities(self):
        """Apply the server's packet aliases and cipher settings.

        The "aliases" capability is stored on the protocol. When encryption
        is enabled, the outgoing cipher is reconfigured from the server
        capabilities.
        Returns False if there is no protocol or if the server encryption
        could not be set up, True otherwise.
        """
        caps = self.server_capabilities
        protocol = self._protocol
        if not protocol:
            return False
        protocol.send_aliases = caps.dictget("aliases", {})
        if not self.encryption:
            return True
        #server uses a new cipher after second hello:
        key = self.get_encryption_key()
        assert key, "encryption key is missing"
        if self.set_server_encryption(caps, key):
            return True
        return False

    def _process_set_deflate(self, packet):
        """Handler for "set_deflate" packets: the packet is ignored."""
        #legacy, should not be used for anything
        pass

    def _process_startup_complete(self, packet):
        """Handler for "startup-complete" packets: the packet is ignored."""
        #can be received if we connect with "xpra stop" or other command line client
        #as the server is starting up
        pass


    def _process_gibberish(self, packet):
        """Handle data that could not be parsed as a packet, then quit.

        `packet` is a triplet: (packet type, message, data).
        If this is the very first thing received and it consists only of
        printable characters, it is shown as text: in full (line by line)
        when it contains a Python traceback, ellipsized otherwise.
        Anything else is reported as uninterpretable data.
        The client always ends with EXIT_PACKET_FAILURE.
        """
        (_, message, data) = packet
        p = self._protocol
        show_as_text = p and p.input_packetcount==0 and all(c in string.printable for c in bytestostr(data))
        if show_as_text:
            #looks like the first packet back is just text, print it:
            data = bytestostr(data)
            #str.find() returns -1 (which is truthy) when there is no match
            #and 0 (falsy) for a match at the start, so test for membership instead:
            if "Traceback " in data:
                for x in data.split("\n"):
                    netlog.warn(x.strip("\r"))
            else:
                netlog.warn("Failed to connect, received: %s", repr_ellipsized(data.strip("\n").strip("\r")))
        else:
            netlog.warn("Received uninterpretable nonsense: %s", message)
            if p:
                netlog.warn(" packet no %i data: %s", p.input_packetcount, repr_ellipsized(data))
            else:
                #the protocol is gone, so there is no packet counter to show:
                netlog.warn(" data: %s", repr_ellipsized(data))
        self.quit(EXIT_PACKET_FAILURE)

    def _process_invalid(self, packet):
        """Handle an invalid packet: log the message and an ellipsized
        copy of the data, then quit with EXIT_PACKET_FAILURE.

        `packet` is a triplet: (packet type, message, data).
        """
        _, message, data = packet
        netlog.info("Received invalid packet: %s", message)
        shortened = repr_ellipsized(data)
        netlog(" data: %s", shortened)
        self.quit(EXIT_PACKET_FAILURE)


    def process_packet(self, proto, packet):
        """Dispatch an incoming packet to its handler.

        The packet type is packet[0]; unless it is an integer it is converted
        to a string first. Handlers registered in `self._packet_handlers` are
        called directly, those in `self._ui_packet_handlers` are scheduled
        through idle_add(). Unknown packet types are logged and dropped.
        `proto` is not used.

        Errors raised while processing are logged, never propagated (except
        for KeyboardInterrupt).
        """
        #defined before the try block so the error handler can always use them:
        handler = None
        packet_type = None
        try:
            packet_type = packet[0]
            #integer packet types are used as they are
            #(comparing the value against the `int` type itself was always true):
            if not isinstance(packet_type, int):
                packet_type = bytestostr(packet_type)
            handler = self._packet_handlers.get(packet_type)
            if handler:
                handler(packet)
                return
            handler = self._ui_packet_handlers.get(packet_type)
            if not handler:
                netlog.error("unknown packet type: %s", packet_type)
                return
            self.idle_add(handler, packet)
        except KeyboardInterrupt:
            raise
        except:
            #deliberately broad: a failing handler must not take the caller down
            netlog.error("Unhandled error while processing a '%s' packet from peer using %s", packet_type, handler, exc_info=True)
コード例 #10
0
class XpraClientBase(ServerInfoMixin, FilePrintMixin):
    def __init__(self):
        """Set up the client: the defaults are only initialized once, even
        if this constructor runs more than once, and the table of default
        challenge handlers is (re)built."""
        #this may be called more than once,
        #skip doing internal init again:
        if not hasattr(self, "exit_code"):
            self.defaults_init()
        FilePrintMixin.__init__(self)
        self._init_done = False
        #the insertion order is preserved:
        self.default_challenge_methods = OrderedDict((
            ("uri",         self.process_challenge_uri),
            ("file",        self.process_challenge_file),
            ("env",         self.process_challenge_env),
            ("kerberos",    self.process_challenge_kerberos),
            ("gss",         self.process_challenge_gss),
            ("u2f",         self.process_challenge_u2f),
            ("prompt",      self.process_challenge_prompt),
            ))

    def defaults_init(self):
        """Initialize every attribute of the client to its default value.

        Also disables the child reaper's poll warning, makes sure the child
        reaper exists, logs the environment, registers the initial packet
        handlers and runs the sanity checks.
        """
        #skip warning when running the client
        from xpra import child_reaper
        child_reaper.POLL_WARNING = False
        getChildReaper()
        log("XpraClientBase.defaults_init() os.environ:")
        for k, v in os.environ.items():
            log(" %s=%s", k, nonl(v))
        #client state:
        self.exit_code = None
        self.exit_on_signal = False
        self.display_desc = {}
        #connection attributes:
        self.hello_extra = {}
        self.compression_level = 0
        self.display = None
        self.challenge_handlers = OrderedDict()
        self.username = None
        self.password = None
        self.password_file = ()
        self.password_index = 0
        self.password_sent = False
        self.encryption = None
        self.encryption_keyfile = None
        self.server_padding_options = [DEFAULT_PADDING]
        self.server_client_shutdown = True
        self.server_compressors = []
        #protocol stuff:
        self._protocol = None
        self._priority_packets = []
        self._ordinary_packets = []
        self._mouse_position = None
        self._mouse_position_pending = None
        self._mouse_position_send_time = 0
        self._mouse_position_delay = MOUSE_DELAY
        self._mouse_position_timer = 0
        self._aliases = {}
        self._reverse_aliases = {}
        #server state and caps:
        self.server_capabilities = None
        self.completed_startup = False
        self.uuid = get_user_uuid()
        #creates the packet handler dictionaries:
        self.init_packet_handlers()
        sanity_checks()

    def init(self, opts, _extra_args=()):
        """Configure the client from the parsed command line options.

        Runs only once per instance (further calls return immediately).
        The base classes are initialized first, then the connection settings
        are copied from `opts` and the authentication challenge handlers
        named in `opts.challenge_handlers` (a comma separated string) are
        registered: "all" registers every default handler, unknown names
        are reported with a warning and blank entries are ignored.
        `_extra_args` is not used.
        """
        if self._init_done:
            #the gtk client classes can inherit this method
            #from multiple parents, skip initializing twice
            return
        self._init_done = True
        for c in XpraClientBase.__bases__:
            c.init(self, opts)
        self.compression_level = opts.compression_level
        self.display = opts.display
        self.username = opts.username
        self.password = opts.password
        self.password_file = opts.password_file
        self.encryption = opts.encryption or opts.tcp_encryption
        if self.encryption:
            crypto_backend_init()
        self.encryption_keyfile = opts.encryption_keyfile or opts.tcp_encryption_keyfile
        #register the authentication challenge handlers:
        ch = tuple(x.strip().lower()
                   for x in (opts.challenge_handlers or "").split(","))

        for ch_name in ch:
            if not ch_name:
                #an empty option or a stray comma yields a blank name,
                #which is not worth a warning:
                continue
            if ch_name == "all":
                self.challenge_handlers.update(self.default_challenge_methods)
                break
            method = self.default_challenge_methods.get(ch_name)
            if method:
                self.challenge_handlers[ch_name] = method
                continue
            log.warn("Warning: unknown challenge handler '%s'", ch_name)
        if DETECT_LEAKS:
            from xpra.util import detect_leaks
            print_leaks = detect_leaks()
            self.timeout_add(10 * 1000, print_leaks)

    def timeout_add(self, *args):
        """Placeholder which always raises: must be overridden to be usable."""
        raise Exception("override me!")

    def idle_add(self, *args):
        """Placeholder which always raises: must be overridden to be usable."""
        raise Exception("override me!")

    def source_remove(self, *args):
        """Placeholder which always raises: must be overridden to be usable."""
        raise Exception("override me!")

    def may_notify(self, nid, summary, body, *args, **kwargs):
        """Write a notification to the "notify" logger: the summary first,
        then each line of the body (if any), indented by one space."""
        notifylog = Logger("notify")
        notifylog("may_notify(%s, %s, %s, %s, %s)", nid, summary, body, args,
                  kwargs)
        notifylog.info("%s", summary)
        body_lines = body.splitlines() if body else ()
        for line in body_lines:
            notifylog.info(" %s", line)

    def handle_deadly_signal(self, signum, _frame=None):
        """Signal handler of last resort: report the signal on stderr,
        run cleanup() and terminate the process with code 128+signum."""
        signal_name = SIGNAMES.get(signum, signum)
        sys.stderr.write("\ngot deadly signal %s, exiting\n" % signal_name)
        sys.stderr.flush()
        self.cleanup()
        os._exit(128 + signum)

    def handle_app_signal(self, signum, _frame=None):
        """Signal handler for a first SIGINT / SIGTERM.

        Reports the signal on stderr, switches both signals over to
        handle_deadly_signal(), runs signal_cleanup() and schedules
        signal_disconnect_and_quit() with the exit code 128+signum.
        """
        signal_name = SIGNAMES.get(signum, signum)
        sys.stderr.write("\ngot signal %s, exiting\n" % signal_name)
        sys.stderr.flush()
        #a second signal will no longer be handled gracefully:
        for sig in (signal.SIGINT, signal.SIGTERM):
            signal.signal(sig, self.handle_deadly_signal)
        self.signal_cleanup()
        reason = "exit on signal %s" % signal_name
        self.timeout_add(0, self.signal_disconnect_and_quit, 128 + signum, reason)

    def install_signal_handlers(self):
        """Route SIGINT and SIGTERM to handle_app_signal()."""
        for sig in (signal.SIGINT, signal.SIGTERM):
            signal.signal(sig, self.handle_app_signal)

    def signal_disconnect_and_quit(self, exit_code, reason):
        """Disconnect and exit in response to a signal.

        The first call schedules the disconnection, quit() and exit() via
        idle_add(). If it is called again, those steps run directly and the
        process is then terminated with os._exit().
        """
        log("signal_disconnect_and_quit(%s, %s) exit_on_signal=%s", exit_code,
            reason, self.exit_on_signal)
        if self.exit_on_signal:
            #warning: this will run cleanup code from the signal handler
            self.disconnect_and_quit(exit_code, reason)
            self.quit(exit_code)
            self.exit()
            os._exit(exit_code)
            return
        #if we get another signal, we'll try to exit without idle_add...
        self.exit_on_signal = True
        self.idle_add(self.disconnect_and_quit, exit_code, reason)
        self.idle_add(self.quit, exit_code)
        self.idle_add(self.exit)

    def signal_cleanup(self):
        """Hook called from handle_app_signal(): does nothing here."""
        #placeholder for stuff that can be cleaned up from the signal handler
        #(non UI thread stuff)
        pass

    def disconnect_and_quit(self, exit_code, reason):
        """Tell the server that we are leaving (giving `reason`), then quit.

        Without a usable protocol, quit() is called immediately. Otherwise
        quit() is scheduled once the disconnect packet has been sent, and
        again after 1000 milliseconds.
        """
        #make sure that we set the exit code early,
        #so the protocol shutdown won't set a different one:
        if self.exit_code is None:
            self.exit_code = exit_code
        #try to tell the server we're going, then quit
        log("disconnect_and_quit(%s, %s)", exit_code, reason)
        protocol = self._protocol
        if protocol is None or protocol._closed:
            self.quit(exit_code)
            return

        def protocol_closed():
            log("disconnect_and_quit: protocol_closed()")
            self.idle_add(self.quit, exit_code)

        if protocol:
            protocol.send_disconnect([reason], done_callback=protocol_closed)
        #quit even if the callback above never fires:
        self.timeout_add(1000, self.quit, exit_code)

    def exit(self):
        """Log the call, then terminate via sys.exit() (raises SystemExit)."""
        log("XpraClientBase.exit() calling %s", sys.exit)
        sys.exit()

    def client_type(self):
        """Return the name of this type of client: "Python" here."""
        #overridden in subclasses!
        return "Python"

    def get_scheduler(self):
        """Return the scheduler handed to the protocol in setup_connection().
        Not implemented here: always raises NotImplementedError."""
        raise NotImplementedError()

    def setup_connection(self, conn):
        """Create and configure the protocol object for the connection `conn`.

        A UDPClientProtocol is used when `conn.socktype` is "udp" (together
        with a handler for "udp-control" packets), a regular Protocol
        otherwise. The protocol is then given the compression level, the
        packet aliases, the default encoder and compressor and, optionally,
        the cipher for the first packet.
        If the connection has a positive timeout, verify_connected() is
        scheduled to run once it has expired (plus EXTRA_TIMEOUT seconds).
        """
        netlog("setup_connection(%s) timeout=%s, socktype=%s", conn,
               conn.timeout, conn.socktype)
        if conn.socktype == "udp":
            from xpra.net.udp_protocol import UDPClientProtocol
            self._protocol = UDPClientProtocol(self.get_scheduler(), conn,
                                               self.process_packet,
                                               self.next_packet)
            #use a random uuid:
            import random
            self._protocol.uuid = random.randint(0, 2**64 - 1)
            self.set_packet_handlers(self._packet_handlers, {
                "udp-control": self._process_udp_control,
            })
        else:
            self._protocol = Protocol(self.get_scheduler(), conn,
                                      self.process_packet, self.next_packet)
        #add these packet types to the protocol's list of large packets:
        for x in (b"keymap-changed", b"server-settings", b"logging",
                  b"input-devices"):
            self._protocol.large_packets.append(x)
        self._protocol.set_compression_level(self.compression_level)
        self._protocol.receive_aliases.update(self._aliases)
        self._protocol.enable_default_encoder()
        self._protocol.enable_default_compressor()
        if self.encryption and ENCRYPT_FIRST_PACKET:
            #the first packet is encrypted using the default cipher parameters:
            key = self.get_encryption_key()
            self._protocol.set_cipher_out(self.encryption, DEFAULT_IV, key,
                                          DEFAULT_SALT, DEFAULT_ITERATIONS,
                                          INITIAL_PADDING)
        #from now on, have_more() goes straight to the protocol:
        self.have_more = self._protocol.source_has_more
        if conn.timeout > 0:
            self.timeout_add((conn.timeout + EXTRA_TIMEOUT) * 1000,
                             self.verify_connected)
        process = getattr(conn, "process",
                          None)  #ie: ssh is handled by anotherprocess
        if process:
            #register the connection's process with the child reaper:
            proc, name, command = process
            if proc:
                getChildReaper().add_process(proc,
                                             name,
                                             command,
                                             ignore=True,
                                             forget=False)
        netlog("setup_connection(%s) protocol=%s", conn, self._protocol)

    def _process_udp_control(self, packet):
        #send it back to the protocol object:
        self._protocol.process_control(*packet[1:])

    def remove_packet_handlers(self, *keys):
        """Remove the given packet types from both handler tables.

        Each key is dropped from `self._packet_handlers` and from
        `self._ui_packet_handlers`; keys that are not registered are ignored.

        Args:
            *keys: the packet types to remove.
        """
        for k in keys:
            for d in (self._packet_handlers, self._ui_packet_handlers):
                #a missing key is the expected case here: pop() with a default
                #ignores it without also hiding unrelated errors
                d.pop(k, None)

    def set_packet_handlers(self, to, defs):
        """Register the handlers from `defs` in the handler table `to`.

        Any existing handler registered under one of the same keys is first
        removed from both handler tables (see remove_packet_handlers),
        so a packet type only ever has one handler.

        Args:
            to: the handler dictionary to add the definitions to.
            defs: dictionary of packet type to handler.
        """
        log("set_packet_handlers(%s, %s)", to, defs)
        self.remove_packet_handlers(*defs)
        to.update(defs)

    def init_packet_handlers(self):
        """Reset both handler tables and register the initial handlers.

        "hello" goes in `self._packet_handlers`, everything else
        (challenge, disconnect, connection errors...) in
        `self._ui_packet_handlers`.
        """
        self._packet_handlers = {}
        self._ui_packet_handlers = {}
        hello_handlers = {"hello": self._process_hello}
        self.set_packet_handlers(self._packet_handlers, hello_handlers)
        ui_handlers = {
            "challenge": self._process_challenge,
            "disconnect": self._process_disconnect,
            "set_deflate": self._process_set_deflate,
            "startup-complete": self._process_startup_complete,
            Protocol.CONNECTION_LOST: self._process_connection_lost,
            Protocol.GIBBERISH: self._process_gibberish,
            Protocol.INVALID: self._process_invalid,
        }
        self.set_packet_handlers(self._ui_packet_handlers, ui_handlers)

    def init_authenticated_packet_handlers(self):
        """Register the packet handlers that require an established session.

        This class only delegates to FilePrintMixin;
        server_connection_established() calls this once the server
        capabilities have been parsed successfully.
        """
        FilePrintMixin.init_authenticated_packet_handlers(self)

    def init_aliases(self):
        """Assign a numeric alias to every registered packet type.

        Aliases are numbered from 1, following the key order of
        `self._packet_handlers` and then of `self._ui_packet_handlers`.
        Both `self._aliases` (number to packet type) and
        `self._reverse_aliases` (packet type to number) are updated.
        """
        packet_types = list(self._packet_handlers)
        packet_types.extend(self._ui_packet_handlers)
        for alias, packet_type in enumerate(packet_types, 1):
            self._aliases[alias] = packet_type
            self._reverse_aliases[packet_type] = alias

    def has_password(self):
        """Tell whether a password source is configured.

        Returns:
            the first non-empty value of `self.password` and
            `self.password_file`, otherwise the value of the XPRA_PASSWORD
            environment variable (None when it is not set).
            The result is meant to be used as a boolean.
        """
        for source in (self.password, self.password_file):
            if source:
                return source
        return os.environ.get('XPRA_PASSWORD')

    def send_hello(self, challenge_response=None, client_salt=None):
        """Build the "hello" packet and queue it for sending.

        When a password is available and no challenge response is supplied,
        only the base capabilities are sent together with "challenge": True
        so that the server sends a challenge first. Otherwise the full
        capabilities from make_hello() are included.
        Any error whilst building the capabilities terminates the client
        with EXIT_INTERNAL_ERROR.

        Args:
            challenge_response: the response to a server challenge, if any.
            client_salt: the client salt used to compute that response.
        """
        try:
            hello = self.make_hello_base()
            request_challenge = self.has_password() and not challenge_response
            if not request_challenge:
                hello.update(self.make_hello())
            else:
                #avoid sending the full hello: tell the server we want
                #a packet challenge first
                hello["challenge"] = True
        except InitExit as e:
            log.error("error preparing connection:")
            log.error(" %s", e)
            self.quit(EXIT_INTERNAL_ERROR)
            return
        except Exception as e:
            log.error("error preparing connection: %s", e, exc_info=True)
            self.quit(EXIT_INTERNAL_ERROR)
            return
        if challenge_response:
            hello["challenge_response"] = challenge_response
            #make it harder for a passive attacker to guess the password length
            #by observing packet sizes (only relevant for wss and ssl)
            padding_size = max(32, 512 - len(challenge_response))
            hello["challenge_padding"] = get_salt(padding_size)
            if client_salt:
                hello["challenge_client_salt"] = client_salt
        log("send_hello(%s) packet=%s", hexstr(challenge_response or ""),
            hello)
        self.send("hello", hello)

    def verify_connected(self):
        """Quit with EXIT_TIMEOUT if the server has not sent its hello yet.

        setup_connection() schedules this via timeout_add when the
        connection has a timeout.
        """
        if self.server_capabilities is not None:
            #the server has replied, nothing to do
            return
        #server has not said hello yet
        self.warn_and_quit(EXIT_TIMEOUT, "connection timed out")

    def make_hello_base(self):
        """Build the capabilities that are included in every hello packet.

        This covers the network capabilities, the extra authentication
        digests, the file transfer capabilities, the client's identity,
        build information and, when encryption is enabled, the cipher
        parameters. In that case, the incoming cipher is also configured
        on the protocol object.

        Returns:
            the capabilities dictionary, or None if the encryption key is
            reported missing (after calling warn_and_quit).
        """
        capabilities = flatten_dict(get_network_caps())
        #add "kerberos", "gss" and "u2f" if enabled:
        handlers = self.challenge_handlers
        default_on = "all" in handlers or "auto" in handlers
        for auth in ("kerberos", "gss", "u2f"):
            if default_on or auth in handlers:
                capabilities["digest"].append(auth)
        capabilities.update(FilePrintMixin.get_caps(self))
        client_info = {
            "version": XPRA_VERSION,
            "encoding.generic": True,
            "namespace": True,
            "hostname": socket.gethostname(),
            "uuid": self.uuid,
            "username": self.username,
            "name": get_name(),
            "client_type": self.client_type(),
            "python.version": sys.version_info[:3],
            "python.bits": BITS,
            "compression_level": self.compression_level,
            "argv": sys.argv,
        }
        capabilities.update(client_info)
        capabilities.update(self.get_file_transfer_features())
        if self.display:
            capabilities["display"] = self.display
        updict(capabilities, "build", self.get_version_info())
        machine_id = get_machine_id()
        if machine_id:
            capabilities["machine_id"] = machine_id
        if self.encryption:
            assert self.encryption in ENCRYPTION_CIPHERS
            iv = get_iv()
            key_salt = get_salt()
            iterations = get_iterations()
            padding = choose_padding(self.server_padding_options)
            cipher_caps = {
                "": self.encryption,
                "iv": iv,
                "key_salt": key_salt,
                "key_stretch_iterations": iterations,
                "padding": padding,
                "padding.options": PADDING_OPTIONS,
            }
            updict(capabilities, "cipher", cipher_caps)
            key = self.get_encryption_key()
            if key is None:
                self.warn_and_quit(EXIT_ENCRYPTION,
                                   "encryption key is missing")
                return
            #the server's replies will use the parameters we just advertised:
            self._protocol.set_cipher_in(self.encryption, iv, key, key_salt,
                                         iterations, padding)
            netlog(
                "encryption capabilities: %s",
                {k: v for k, v in capabilities.items()
                 if k.startswith("cipher")})
        capabilities.update(self.hello_extra)
        return capabilities

    def get_version_info(self):
        """Return the version information added to the hello as "build".

        Simply forwards to the module level get_version_info() function
        (inside a method body, the name resolves to the global, not to
        this method).
        """
        return get_version_info()

    def make_hello(self):
        """Return the capabilities only included in a full hello packet.

        Returns:
            a new dictionary which disables "randr_notify" and "windows",
            and includes the packet "aliases" if any have been defined.
        """
        capabilities = {}
        #only client.py cares about these two:
        capabilities["randr_notify"] = False
        capabilities["windows"] = False
        aliases = self._reverse_aliases
        if aliases:
            capabilities["aliases"] = aliases
        return capabilities

    def compressed_wrapper(self, datatype, data, level=5):
        """Wrap `data` for sending, compressing it when that is worthwhile.

        Compression is only attempted with a compressor that both the
        server (`self.server_compressors`) and this client support, for a
        positive `level` and at least 256 bytes of data. The compressed
        form is only used if it is smaller than the original.

        Args:
            datatype: description of the data, passed on to the wrapper.
            data: the payload.
            level: the compression level, 0 disables compression.

        Returns:
            the compressed wrapper, or a "raw" `compression.Compressed`
            wrapper around the unmodified data.
        """
        #FIXME: ugly assumptions here, should pass by name!
        server_compressors = self.server_compressors
        zlib = "zlib" in server_compressors and compression.use_zlib
        lz4 = "lz4" in server_compressors and compression.use_lz4
        lzo = "lzo" in server_compressors and compression.use_lzo
        worth_trying = level > 0 and len(data) >= 256 and (zlib or lz4
                                                           or lzo)
        if worth_trying:
            wrapped = compression.compressed_wrapper(datatype,
                                                     data,
                                                     level=level,
                                                     zlib=zlib,
                                                     lz4=lz4,
                                                     lzo=lzo,
                                                     can_inline=False)
            if len(wrapped) < len(data):
                #the compressed version is smaller, use it:
                return wrapped
        #we can't compress, so at least avoid warnings in the protocol layer:
        raw_datatype = "raw %s" % datatype
        return compression.Compressed(raw_datatype, data, can_inline=True)

    def send(self, *parts):
        """Queue a packet as an ordinary packet and notify the protocol.

        Args:
            *parts: the packet type followed by its arguments.
        """
        queue = self._ordinary_packets
        queue.append(parts)
        self.have_more()

    def send_now(self, *parts):
        """Queue a packet as a priority packet and notify the protocol.

        next_packet() drains the priority queue before the ordinary one.

        Args:
            *parts: the packet type followed by its arguments.
        """
        queue = self._priority_packets
        queue.append(parts)
        self.have_more()

    def send_positional(self, packet):
        """Queue a packet which already carries the mouse position.

        Since the packet includes the position, any mouse position update
        that has not been sent yet is dropped and its timer is cancelled.

        Args:
            packet: the complete packet to queue.
        """
        self._ordinary_packets.append(packet)
        #the pending position packets are now redundant:
        self._mouse_position = self._mouse_position_pending = None
        self.cancel_send_mouse_position_timer()
        self.have_more()

    def send_mouse_position(self, packet):
        """Record a mouse position packet and send it, rate limited.

        The packet always replaces the pending one. If a send is already
        scheduled, nothing else is done. Otherwise the packet is sent
        immediately if at least `self._mouse_position_delay` milliseconds
        have elapsed since the last one, or a timer is started for the
        remaining time.

        Args:
            packet: the mouse position packet.
        """
        self._mouse_position_pending = packet
        if self._mouse_position_timer:
            #a send is already scheduled,
            #do_send_mouse_position() will use the new pending packet
            return
        now = monotonic_time()
        elapsed = int(1000 * (now - self._mouse_position_send_time))
        delay = self._mouse_position_delay
        mouselog("send_mouse_position(%s) elapsed=%i, delay=%i", packet,
                 elapsed, delay)
        if elapsed >= delay:
            self.do_send_mouse_position()
        else:
            self._mouse_position_timer = self.timeout_add(
                delay - elapsed, self.do_send_mouse_position)

    def do_send_mouse_position(self):
        """Make the pending mouse position packet available for sending.

        Clears the timer, records the send time and copies the pending
        packet to `self._mouse_position`, where next_packet() picks it up.
        """
        self._mouse_position_timer = 0
        self._mouse_position_send_time = monotonic_time()
        position = self._mouse_position_pending
        self._mouse_position = position
        mouselog("do_send_mouse_position() position=%s", position)
        self.have_more()

    def cancel_send_mouse_position_timer(self):
        """Cancel the mouse position timer, if there is one."""
        timer = self._mouse_position_timer
        if not timer:
            return
        self._mouse_position_timer = 0
        self.source_remove(timer)

    def have_more(self):
        """Tell the protocol that there are packets waiting to be sent.

        Does nothing without a protocol object or if it has no source.
        """
        #setup_connection() replaces this method on the instance
        #with the protocol's own source_has_more
        protocol = self._protocol
        if not protocol or not protocol.source:
            return
        protocol.source_has_more()

    def next_packet(self):
        """Return the next packet to send, for the protocol object.

        Priority packets come first, then ordinary packets, then the
        latest mouse position (which is the only one sent with
        synchronous=False).

        Returns:
            a tuple (packet, None, None, None, synchronous, has_more)
            where packet is None if nothing is waiting and has_more tells
            if more packets are waiting after this one.
        """
        netlog(
            "next_packet() packets in queues: priority=%i, ordinary=%i, mouse=%s",
            len(self._priority_packets), len(self._ordinary_packets),
            bool(self._mouse_position))
        packet = None
        synchronous = True
        if self._priority_packets:
            packet = self._priority_packets.pop(0)
        elif self._ordinary_packets:
            packet = self._ordinary_packets.pop(0)
        elif self._mouse_position is not None:
            packet = self._mouse_position
            self._mouse_position = None
            synchronous = False
        if packet is None:
            has_more = False
        else:
            has_more = bool(self._priority_packets) \
                    or bool(self._ordinary_packets) \
                    or self._mouse_position is not None
        return packet, None, None, None, synchronous, has_more

    def cleanup(self):
        """Release the resources held by the client.

        Runs the child reaper and FilePrintMixin cleanups, closes and
        forgets the protocol object, cancels the mouse position timer
        and finally calls dump_all_frames().
        """
        reaper_cleanup()
        FilePrintMixin.cleanup(self)
        protocol = self._protocol
        log("XpraClientBase.cleanup() protocol=%s", protocol)
        if protocol:
            close = protocol.close
            log("calling %s", close)
            close()
            self._protocol = None
        log("cleanup done")
        self.cancel_send_mouse_position_timer()
        dump_all_frames()

    def glib_init(self):
        """Call glib's threads_init(), unless it is no longer required.

        With Python 3 and gi version 3.11 or later, nothing needs doing.
        """
        threads_init_required = True
        if PYTHON3:
            import gi
            #no longer need to call threads_init with gi >= 3.11
            threads_init_required = not (gi.version_info >= (3, 11))
        if threads_init_required:
            from xpra.gtk_common.gobject_compat import import_glib
            import_glib().threads_init()

    def run(self):
        """Start the protocol object created by setup_connection()."""
        protocol = self._protocol
        protocol.start()

    def quit(self, exit_code=0):
        """Terminate the client: must be implemented by subclasses.

        Args:
            exit_code: the exit code the client should terminate with.

        Raises:
            NotImplementedError: always, in this base implementation.
                (it is a subclass of Exception, so callers which catch
                Exception continue to work)
        """
        raise NotImplementedError("override me!")

    def warn_and_quit(self, exit_code, message):
        """Log `message` as a warning, then quit with `exit_code`.

        Args:
            exit_code: the exit code passed to quit().
            message: the text to log before quitting.
        """
        log.warn(message)
        self.quit(exit_code)

    def send_shutdown_server(self):
        """Ask the server to shut down by queueing a "shutdown-server" packet.

        Asserts that `self.server_client_shutdown` is set: this flag is
        taken from the server's "client-shutdown" capability
        by parse_server_capabilities().
        """
        assert self.server_client_shutdown
        self.send("shutdown-server")

    def _process_disconnect(self, packet):
        """Handle a "disconnect" packet from the server and quit.

        A disconnection before the server's hello was received, or one
        whose reason is classified as an error by disconnect_is_an_error,
        exits with EXIT_FAILURE. Anything else exits with EXIT_OK.

        Args:
            packet: ("disconnect", reason, *extra_info)
                ie: ("disconnect", "version error", "incompatible version")
        """
        reason = bytestostr(packet[1])
        info = packet[2:]
        description = nonl(reason)
        if info:
            description += " (%s)" % csv(nonl(bytestostr(x)) for x in info)
        caps = self.server_capabilities
        if caps is None or len(caps) == 0:
            #server never sent hello to us - so disconnect is an error
            #(but we don't know which one - the info message may help)
            log.warn(
                "server failure: disconnected before the session could be established"
            )
        elif disconnect_is_an_error(reason):
            log.warn("server failure: %s", reason)
        else:
            if self.exit_code is None:
                #we're not in the process of exiting already,
                #tell the user why the server is disconnecting us
                log.info("server requested disconnect:")
                log.info(" %s", description)
            self.quit(EXIT_OK)
            return
        #both error cases end up here:
        self.warn_and_quit(EXIT_FAILURE,
                           "server requested disconnect: %s" % description)

    def _process_connection_lost(self, _packet):
        """Handle the loss of the connection to the server.

        If nothing was ever received, the likely causes are logged first.
        The client then quits with EXIT_CONNECTION_LOST, unless its exit
        code is already set to zero.
        """
        protocol = self._protocol
        if protocol and protocol.input_raw_packetcount == 0:
            info = protocol.get_info()
            compressor = info.get("compression", "unknown")
            encoder = info.get("encoder", "unknown")
            netlog.error(
                "Error: failed to receive anything, not an xpra server?")
            netlog.error(
                "  could also be the wrong protocol, username, password or port"
            )
            netlog.error("  or the session was not found")
            if not (compressor == "unknown" and encoder == "unknown"):
                netlog.error(
                    "  or maybe this server does not support '%s' compression or '%s' packet encoding?",
                    compressor, encoder)
        if self.exit_code == 0:
            return
        self.warn_and_quit(EXIT_CONNECTION_LOST, "Connection lost")

    ########################################
    # Authentication
    def _process_challenge(self, packet):
        """Answer an authentication challenge from the server.

        After validating the packet, each handler from
        `self.challenge_handlers` is tried in turn until one of them
        reports that it has handled the challenge. Handlers which fail
        are logged and skipped. If no handler deals with the challenge,
        the client quits with EXIT_PASSWORD_REQUIRED.
        """
        authlog("processing challenge: %s", packet[1:])
        if not self.validate_challenge_packet(packet):
            return
        handlers = self.challenge_handlers
        authlog("challenge handlers: %s", handlers)
        for name, handler in handlers.items():
            try:
                authlog("calling challenge handler %s", name)
                handled = handler(packet)
            except Exception as e:
                authlog("%s(%s)", handler, packet, exc_info=True)
                authlog.error("Error in %s challenge handler:", name)
                authlog.error(" %s", str(e) or type(e))
            else:
                if handled:
                    return
        self.quit(EXIT_PASSWORD_REQUIRED)

    def process_challenge_uri(self, packet):
        """Answer the challenge with `self.password`, if there is one.

        The password can only be used once: it is cleared after use.

        Returns:
            True if a reply was attempted, False if no password is set.
        """
        password = self.password
        if not password:
            return False
        self.send_challenge_reply(packet, password)
        #clearing it to allow other modules to process further challenges:
        self.password = None
        return True

    def process_challenge_env(self, packet):
        """Answer the challenge with the XPRA_PASSWORD environment variable.

        Returns:
            True if a reply was attempted,
            False if the variable is not set or empty.
        """
        env_name = "XPRA_PASSWORD"
        password = os.environ.get(env_name)
        authlog("process_challenge_env() %s=%s", env_name, obsc(password))
        if not password:
            return False
        self.send_challenge_reply(packet, password)
        return True

    def process_challenge_file(self, packet):
        """Answer the challenge with the next unused password file.

        Each call consumes one entry of `self.password_file`,
        `self.password_index` keeps track of the next one to use.

        Returns:
            True if a reply was attempted,
            False if all the password files have already been used.
        """
        index = self.password_index
        if index >= len(self.password_file):
            return False
        password_file = self.password_file[index]
        self.password_index = index + 1
        filename = os.path.expanduser(password_file)
        password = load_binary_file(filename)
        authlog("password read from file %i '%s': %s", self.password_index,
                password_file, obsc(password))
        self.send_challenge_reply(packet, password)
        return True

    def process_challenge_prompt(self, packet):
        """Answer the challenge by prompting the user.

        The prompt asks for a "password" by default, or for a
        "gss token" / "kerberos token" with those digests.
        If the packet has a sixth item, it is used as the prompt instead.

        Args:
            packet: the challenge packet, packet[3] is the digest.

        Returns:
            the result of do_process_challenge_prompt().
        """
        prompt = "password"
        #the digest arrives as bytes: convert it before formatting,
        #otherwise python3 shows the prompt as "b'gss' token"
        digest = bytestostr(packet[3])
        if digest.startswith(("gss:", "kerberos:")):
            prompt = "%s token" % (digest.split(":", 1)[0])
        if len(packet) >= 6:
            prompt = std(packet[5])
        return self.do_process_challenge_prompt(packet, prompt)

    def do_process_challenge_prompt(self, packet, prompt="password"):
        """Read the password from the terminal and reply to the challenge.

        Args:
            packet: the challenge packet.
            prompt: what the user is asked to enter.

        Returns:
            True if a reply was attempted, False if use_tty() is not set.
        """
        authlog("do_process_challenge_prompt() use_tty=%s", use_tty())
        if not use_tty():
            return False
        import getpass
        authlog("stdin isatty, using password prompt")
        prompt_text = "%s :" % self.get_challenge_prompt(prompt)
        password = getpass.getpass(prompt_text)
        authlog("password read from tty via getpass: %s", obsc(password))
        self.send_challenge_reply(packet, password)
        return True

    def process_challenge_kerberos(self, packet):
        """Answer a "kerberos:SERVICE" challenge with a kerberos token.

        The service requested must be listed in KERBEROS_SERVICES
        (or KERBEROS_SERVICES must contain the "*" wildcard).

        Args:
            packet: the challenge packet, packet[3] is the digest.

        Returns:
            True if a token was obtained and a reply attempted,
            False if this is not a kerberos challenge, if the kerberos
            module is missing, if the service is not allowed
            or if the kerberos client calls fail.
        """
        digest = packet[3]
        if not digest.startswith(b"kerberos:"):
            authlog("%s is not a kerberos challenge", digest)
            #not a kerberos challenge
            return False
        try:
            if WIN32:
                import winkerberos as kerberos
            else:
                import kerberos
        except ImportError as e:
            authlog("import (win)kerberos", exc_info=True)
            if first_time("no-kerberos"):
                authlog.warn(
                    "Warning: kerberos challenge handler is not supported:")
                authlog.warn(" %s", e)
            return False
        service = bytestostr(digest.split(b":", 1)[1])
        if service not in KERBEROS_SERVICES and "*" not in KERBEROS_SERVICES:
            authlog.warn("Warning: invalid kerberos request for service '%s'",
                         service)
            authlog.warn(" services supported: %s", csv(KERBEROS_SERVICES))
            return False
        authlog("kerberos service=%s", service)

        def log_kerberos_exception(e):
            #log each argument of the exception on its own line,
            #always using the same logger as the error header
            try:
                for x in e.args:
                    if isinstance(x, (list, tuple)):
                        try:
                            authlog.error(" %s", csv(x))
                            continue
                        except Exception:
                            #csv failed: log the raw value below instead
                            pass
                    authlog.error(" %s", x)
            except Exception:
                authlog.error(" %s", e)

        try:
            r, ctx = kerberos.authGSSClientInit(service)
            assert r == 1, "return code %s" % r
        except Exception as e:
            authlog("kerberos.authGSSClientInit(%s)", service, exc_info=True)
            authlog.error("Error: cannot initialize kerberos client:")
            log_kerberos_exception(e)
            return False
        try:
            kerberos.authGSSClientStep(ctx, "")
        except Exception as e:
            authlog("kerberos.authGSSClientStep", exc_info=True)
            authlog.error("Error: kerberos client authentication failure:")
            log_kerberos_exception(e)
            return False
        token = kerberos.authGSSClientResponse(ctx)
        authlog("kerberos token=%s", token)
        self.send_challenge_reply(packet, token)
        return True

    def process_challenge_gss(self, packet):
        """Answer a "gss:SERVICE" challenge with a GSSAPI token.

        The service requested must be listed in GSS_SERVICES
        (or GSS_SERVICES must contain the "*" wildcard).

        Args:
            packet: the challenge packet, packet[3] is the digest.

        Returns:
            True if a token was obtained and a reply attempted,
            False if this is not a gss challenge, if the gssapi module is
            missing, if the service is not allowed
            or if the gssapi calls fail.
        """
        digest = packet[3]
        if not digest.startswith(b"gss:"):
            #not a gss challenge
            authlog("%s is not a gss challenge", digest)
            return False
        try:
            import gssapi
            #NOTE(review): this branch can never run ("and False"),
            #presumably only here so that packaging tools
            #can see these imports - confirm before removing
            if OSX and False:
                from gssapi.raw import (cython_converters, cython_types, oids)
                assert cython_converters and cython_types and oids
        except ImportError as e:
            authlog("import gssapi", exc_info=True)
            #use a gss specific key: sharing the kerberos handler's key
            #would hide this warning once the kerberos one has been shown
            if first_time("no-gss"):
                authlog.warn("Warning: gss authentication not supported:")
                authlog.warn(" %s", e)
            return False
        service = bytestostr(digest.split(b":", 1)[1])
        if service not in GSS_SERVICES and "*" not in GSS_SERVICES:
            authlog.warn("Warning: invalid GSS request for service '%s'",
                         service)
            authlog.warn(" services supported: %s", csv(GSS_SERVICES))
            return False
        authlog("gss service=%s", service)
        service_name = gssapi.Name(service)
        try:
            ctx = gssapi.SecurityContext(name=service_name, usage="initiate")
            token = ctx.step()
        except Exception as e:
            authlog("gssapi failure", exc_info=True)
            authlog.error("Error: gssapi client authentication failure:")
            try:
                #split on colon
                for x in str(e).split(":", 2):
                    authlog.error(" %s", x.lstrip(" "))
            except Exception:
                authlog.error(" %s", e)
            return False
        authlog("gss token=%s", repr(token))
        self.send_challenge_reply(packet, token)
        return True

    def process_challenge_u2f(self, packet):
        """Answer a "u2f:" challenge using a local U2F device.

        The key handle is taken from the XPRA_U2F_KEY_HANDLE environment
        variable or, failing that, from the first "u2f-keyhandle*.hex" file
        found in the user configuration directories. The application id
        comes from XPRA_U2F_APP_ID and defaults to "Xpra".

        Args:
            packet: the challenge packet, packet[1] is the server salt
                and packet[3] is the digest.

        Returns:
            True once the device's response has been sent,
            False if this is not a u2f challenge or no key handle is found.
        """
        digest = packet[3]
        if not digest.startswith(b"u2f:"):
            authlog("%s is not a u2f challenge", digest)
            return False
        import binascii
        import logging
        #raise the pyu2f loggers to INFO unless auth debugging is enabled:
        if not is_debug_enabled("auth"):
            logging.getLogger("pyu2f.hardware").setLevel(logging.INFO)
            logging.getLogger("pyu2f.hidtransport").setLevel(logging.INFO)
        from pyu2f import model
        from pyu2f.u2f import GetLocalU2FInterface
        dev = GetLocalU2FInterface()
        APP_ID = os.environ.get("XPRA_U2F_APP_ID", "Xpra")
        key_handle_str = os.environ.get("XPRA_U2F_KEY_HANDLE")
        authlog("process_challenge_u2f XPRA_U2F_KEY_HANDLE=%s", key_handle_str)
        if not key_handle_str:
            #try to load the key handle from the user conf dir(s):
            from xpra.platform.paths import get_user_conf_dirs
            info = self._protocol.get_info(False)
            key_handle_filenames = []
            #the host specific filenames are tried first, then the generic one:
            for hostinfo in ("-%s" % info.get("host", ""), ""):
                key_handle_filenames += [
                    os.path.join(d, "u2f-keyhandle%s.hex" % hostinfo)
                    for d in get_user_conf_dirs()
                ]
            for filename in key_handle_filenames:
                p = osexpand(filename)
                key_handle_str = load_binary_file(p)
                authlog("key_handle_str(%s)=%s", p, key_handle_str)
                if key_handle_str:
                    #remove trailing whitespace and end of line characters:
                    key_handle_str = key_handle_str.rstrip(b" \n\r")
                    break
            if not key_handle_str:
                authlog.warn("Warning: no U2F key handle found")
                return False
        authlog("process_challenge_u2f key_handle=%s", key_handle_str)
        #the key handle is stored hex encoded:
        key_handle = binascii.unhexlify(key_handle_str)
        key = model.RegisteredKey(key_handle)
        #use server salt as challenge directly
        challenge = packet[1]
        authlog.info("activate your U2F device for authentication")
        response = dev.Authenticate(APP_ID, challenge, [key])
        sig = response.signature_data
        client_data = response.client_data
        authlog("process_challenge_u2f client data=%s, signature=%s",
                client_data, binascii.hexlify(sig))
        #the signature is the challenge response,
        #the client data's origin is sent in place of the client salt:
        self.do_send_challenge_reply(bytes(sig), client_data.origin)
        return True

    def auth_error(self,
                   code,
                   message,
                   server_message="authentication failed"):
        """Log an authentication failure, then disconnect and quit.

        Args:
            code: the exit code passed to disconnect_and_quit().
            message: the details of the failure, logged locally.
            server_message: the reason passed to disconnect_and_quit().
        """
        authlog.error("Error: authentication failed:")
        authlog.error(" %s", message)
        self.disconnect_and_quit(code, server_message)

    def validate_challenge_packet(self, packet):
        """Verify that the challenge can be answered safely.

        An "xor" digest is accepted for local connections when
        ALLOW_LOCALHOST_PASSWORDS is set, otherwise it is refused unless
        the connection is encrypted or ALLOW_UNENCRYPTED_PASSWORDS is set.
        The legacy "xor" and "des" salt digests are refused unless
        LEGACY_SALT_DIGEST is set.

        Args:
            packet: the challenge packet, packet[3] is the digest and
                the optional packet[4] is the salt digest.

        Returns:
            True if the challenge can be processed,
            False if it was rejected via auth_error().
        """
        digest = bytestostr(packet[3])
        #don't send XORed password unencrypted:
        if digest == "xor":
            protocol = self._protocol
            encrypted = protocol.cipher_out or protocol.get_info().get(
                "type") in ("ssl", "wss")
            local = self.display_desc.get("local", False)
            authlog("xor challenge, encrypted=%s, local=%s", encrypted, local)
            if local and ALLOW_LOCALHOST_PASSWORDS:
                return True
            if not (encrypted or ALLOW_UNENCRYPTED_PASSWORDS):
                self.auth_error(
                    EXIT_ENCRYPTION,
                    "server requested '%s' digest, cowardly refusing to use it without encryption"
                    % digest, "invalid digest")
                return False
        #the salt digest is optional and defaults to "xor":
        salt_digest = bytestostr(packet[4]) if len(packet) >= 5 else "xor"
        if salt_digest in ("xor", "des"):
            if not LEGACY_SALT_DIGEST:
                self.auth_error(
                    EXIT_INCOMPATIBLE_VERSION,
                    "server uses legacy salt digest '%s'" % salt_digest)
                return False
            log.warn(
                "Warning: server using legacy support for '%s' salt digest",
                salt_digest)
        return True

    def get_challenge_prompt(self, prompt="password"):
        """Return the text used to prompt the user for authentication.

        The user name and server details are appended when they can be
        retrieved from the connection.

        Args:
            prompt: what the user is asked to enter.

        Returns:
            the prompt text.
        """
        text = "Please enter the %s" % (prompt, )
        try:
            from xpra.net.bytestreams import pretty_socket
            conn = self._protocol._conn
            text += " for user '%s',\n connecting to %s server %s" % (
                self.username, conn.socktype, pretty_socket(conn.remote))
        except Exception:
            #best effort only: the connection details are optional,
            #but don't swallow KeyboardInterrupt or SystemExit
            pass
        return text

    def send_challenge_reply(self, packet, password):
        """Compute the response to a server challenge and send it.

        The client salt is combined with the server salt using the salt
        digest, the result is then used with the password and the digest
        requested by the server to generate the challenge response.
        When encryption is enabled, the outgoing cipher is first configured
        from the cipher attributes found in the challenge packet.
        Failures are reported via auth_error().

        Args:
            packet: the challenge packet: packet[1] is the server salt,
                packet[2] the cipher attributes, packet[3] the digest
                and the optional packet[4] is the salt digest.
            password: the password (or token) to authenticate with.
        """
        if not password:
            if self.password_file:
                self.auth_error(
                    EXIT_PASSWORD_FILE_ERROR,
                    "failed to load password from file%s %s" %
                    (engs(self.password_file), csv(self.password_file)),
                    "no password available")
            else:
                self.auth_error(
                    EXIT_PASSWORD_REQUIRED,
                    "this server requires authentication and no password is available"
                )
            return
        server_salt = bytestostr(packet[1])
        if self.encryption:
            assert len(
                packet
            ) >= 3, "challenge does not contain encryption details to use for the response"
            server_cipher = typedict(packet[2])
            key = self.get_encryption_key()
            if key is None:
                self.auth_error(EXIT_ENCRYPTION,
                                "the server does not use any encryption",
                                "client requires encryption")
                return
            if not self.set_server_encryption(server_cipher, key):
                return
        #all server versions support a client salt,
        #they also tell us which digest to use:
        digest = bytestostr(packet[3])
        actual_digest = digest.split(":", 1)[0]
        salt_length = len(server_salt)
        salt_digest = "xor"
        if len(packet) >= 5:
            salt_digest = bytestostr(packet[4])
        if salt_digest == "xor":
            #with xor, we have to match the size
            assert salt_length >= 16, "server salt is too short: only %i bytes, minimum is 16" % salt_length
            assert salt_length <= 256, "server salt is too long: %i bytes, maximum is 256" % salt_length
        else:
            #other digest, 32 random bytes is enough:
            salt_length = 32
        client_salt = get_salt(salt_length)
        salt = gendigest(salt_digest, client_salt, server_salt)
        authlog("combined %s salt(%s, %s)=%s", salt_digest,
                hexstr(server_salt), hexstr(client_salt), hexstr(salt))

        challenge_response = gendigest(actual_digest, password, salt)
        if not challenge_response:
            #one placeholder for one argument:
            #a second "%s" makes the log formatting fail
            log("invalid digest module '%s'", actual_digest)
            self.auth_error(
                EXIT_UNSUPPORTED,
                "server requested '%s' digest but it is not supported" %
                actual_digest, "invalid digest")
            return
        #obscure the password, even in debug logs,
        #as the other challenge handlers do:
        authlog("%s(%s, %s)=%s", actual_digest, obsc(password), repr(salt),
                repr(challenge_response))
        self.do_send_challenge_reply(challenge_response, client_salt)

    def do_send_challenge_reply(self, challenge_response, client_salt):
        """Send a new hello packet containing the challenge response.

        Also records that authentication data has been sent:
        _process_hello() checks `self.password_sent`.

        Args:
            challenge_response: the response computed for the challenge.
            client_salt: the client salt used to compute it.
        """
        self.password_sent = True
        self.send_hello(challenge_response, client_salt)

    ########################################
    # Encryption
    def set_server_encryption(self, caps, key):
        """Configure the outgoing cipher from the server's capabilities.

        Also records the padding options that the server supports
        in `self.server_padding_options`.

        Args:
            caps: typedict with the server's "cipher*" attributes.
            key: the encryption key.

        Returns:
            True if the protocol's outgoing cipher has been set,
            False if there is no protocol object or if the cipher
            attributes are missing or unsupported
            (in which case warn_and_quit() is called).
        """
        cipher = caps.strget("cipher")
        cipher_iv = caps.strget("cipher.iv")
        key_salt = caps.strget("cipher.key_salt")
        iterations = caps.intget("cipher.key_stretch_iterations")
        padding = caps.strget("cipher.padding", DEFAULT_PADDING)
        #server may tell us what it supports,
        #either from hello response or from challenge packet:
        self.server_padding_options = caps.strlistget("cipher.padding.options",
                                                      [DEFAULT_PADDING])
        if not cipher or not cipher_iv:
            self.warn_and_quit(
                EXIT_ENCRYPTION,
                "the server does not use or support encryption/password, cannot continue with %s cipher"
                % self.encryption)
            return False
        if cipher not in ENCRYPTION_CIPHERS:
            self.warn_and_quit(
                EXIT_ENCRYPTION,
                "unsupported server cipher: %s, allowed ciphers: %s" %
                (cipher, csv(ENCRYPTION_CIPHERS)))
            return False
        if padding not in ALL_PADDING_OPTIONS:
            #this message lists the padding options, not the ciphers:
            self.warn_and_quit(
                EXIT_ENCRYPTION,
                "unsupported server cipher padding: %s, allowed paddings: %s" %
                (padding, csv(ALL_PADDING_OPTIONS)))
            return False
        p = self._protocol
        if not p:
            return False
        p.set_cipher_out(cipher, cipher_iv, key, key_salt, iterations, padding)
        return True

    def get_encryption_key(self):
        """Load the encryption key.

        The key is read from `self.encryption_keyfile` if that file exists,
        otherwise from the XPRA_ENCRYPTION_KEY environment variable.

        Returns:
            the key, without any leading or trailing
            end of line characters.

        Raises:
            InitExit: if no key can be found.
        """
        key = None
        keyfile = self.encryption_keyfile
        if not keyfile or not os.path.exists(keyfile):
            cryptolog("get_encryption_key() file '%s' does not exist",
                      keyfile)
        else:
            key = load_binary_file(keyfile)
            cryptolog("get_encryption_key() loaded %i bytes from '%s'",
                      len(key or ""), keyfile)
        if not key:
            #fallback to the environment:
            env_name = "XPRA_ENCRYPTION_KEY"
            key = strtobytes(os.environ.get(env_name, ''))
            cryptolog(
                "get_encryption_key() got %i bytes from '%s' environment variable",
                len(key or ""), env_name)
        if not key:
            raise InitExit(1, "no encryption key")
        return key.strip(b"\n\r")

    def _process_hello(self, packet):
        """Handle the server's "hello" packet.

        The "challenge" handler is removed as it is no longer needed.
        If a password is configured but the server never asked for it,
        the client quits with EXIT_NO_AUTHENTICATION. Otherwise the server
        capabilities are stored and parsed, any failure to do so quits
        with EXIT_FAILURE.

        Args:
            packet: ("hello", capabilities)
        """
        self.remove_packet_handlers("challenge")
        password_unused = not self.password_sent and self.has_password()
        if password_unused:
            self.warn_and_quit(EXIT_NO_AUTHENTICATION,
                               "the server did not request our password")
            return
        try:
            caps = typedict(packet[1])
            self.server_capabilities = caps
            netlog("processing hello from server: %s", caps)
            established = self.server_connection_established()
            if not established:
                self.warn_and_quit(EXIT_FAILURE,
                                   "failed to establish connection")
        except Exception as e:
            netlog.info("error in hello packet", exc_info=True)
            self.warn_and_quit(
                EXIT_FAILURE,
                "error processing hello packet from server: %s" % e)

    def capsget(self, capabilities, key, default):
        """Look up `key` in a capabilities dictionary which uses bytes keys.

        Args:
            capabilities: the dictionary to query.
            key: the name of the capability, converted with strtobytes.
            default: the value returned if the key is not found.

        Returns:
            the value found or the default,
            with python3, bytes values are converted using bytestostr.
        """
        v = capabilities.get(strtobytes(key), default)
        #isinstance also handles subclasses of bytes:
        if PYTHON3 and isinstance(v, bytes):
            v = bytestostr(v)
        return v

    def server_connection_established(self):
        """Complete the connection setup using the server's capabilities.

        The encryption, server and network capabilities are parsed in
        that order, stopping at the first failure. On success, the maximum
        packet size is raised if file transfers need it, then the
        authenticated packet handlers and the packet aliases are
        initialized.

        Returns:
            True if the connection is established,
            False if any of the capabilities could not be parsed.
        """
        netlog("server_connection_established()")
        parsing_steps = (
            (self.parse_encryption_capabilities,
             "server_connection_established() failed encryption capabilities"),
            (self.parse_server_capabilities,
             "server_connection_established() failed server capabilities"),
            (self.parse_network_capabilities,
             "server_connection_established() failed network capabilities"),
        )
        for parse, failure_message in parsing_steps:
            if not parse():
                netlog(failure_message)
                return False
        #raise packet size if required:
        if self.file_transfer:
            protocol = self._protocol
            file_size_bytes = self.file_size_limit * 1024 * 1024
            protocol.max_packet_size = max(protocol.max_packet_size,
                                           file_size_bytes)
        netlog(
            "server_connection_established() adding authenticated packet handlers"
        )
        self.init_authenticated_packet_handlers()
        self.init_aliases()
        return True

    def parse_server_capabilities(self):
        """Parse the server capabilities, delegating to each base class first.

        Returns ``False`` as soon as one base class of ``XpraClientBase``
        rejects the capabilities, otherwise records the "client-shutdown"
        and "compressors" capabilities and returns ``True``.
        """
        for base in XpraClientBase.__bases__:
            accepted = base.parse_server_capabilities(self)
            if not accepted:
                return False
        caps = self.server_capabilities
        self.server_client_shutdown = caps.boolget("client-shutdown", True)
        self.server_compressors = caps.strlistget("compressors", ["zlib"])
        return True

    def parse_network_capabilities(self):
        """Configure the protocol from the server's network capabilities.

        Returns ``False`` when there is no protocol or when the packet
        encoder cannot be enabled from the capabilities; otherwise enables
        the compressor, accepts the connection, stores the "aliases"
        capability on the protocol and returns ``True``.
        """
        caps = self.server_capabilities
        proto = self._protocol
        if not proto:
            return False
        if not proto.enable_encoder_from_caps(caps):
            return False
        proto.enable_compressor_from_caps(caps)
        proto.accept()
        aliases = caps.dictget("aliases", {})
        proto.send_aliases = aliases
        return True

    def parse_encryption_capabilities(self):
        """Set up the server side encryption if encryption is enabled.

        Returns ``False`` when there is no protocol or when
        ``set_server_encryption`` fails, ``True`` otherwise (including when
        encryption is not in use).
        """
        caps = self.server_capabilities
        proto = self._protocol
        if not proto:
            return False
        if not self.encryption:
            #nothing to configure
            return True
        #server uses a new cipher after second hello:
        key = self.get_encryption_key()
        assert key, "encryption key is missing"
        if not self.set_server_encryption(caps, key):
            return False
        return True

    def _process_set_deflate(self, packet):
        #legacy, should not be used for anything
        pass

    def _process_startup_complete(self, packet):
        #can be received if we connect with "xpra stop" or other command line client
        #as the server is starting up
        self.completed_startup = packet

    def _process_gibberish(self, packet):
        """Report a packet that could not be parsed, then quit.

        ``packet`` must hold exactly three items: the packet type, a message
        describing the problem and the offending data. If this is the first
        packet received and the data is entirely printable it is logged as
        text, otherwise it is reported as uninterpretable data. The client
        always quits with ``EXIT_PACKET_FAILURE``.
        """
        log("process_gibberish(%s)", repr_ellipsized(packet))
        _, message, data = packet
        proto = self._protocol
        first_packet = proto and proto.input_packetcount == 0
        if first_packet and all(ch in string.printable
                                for ch in bytestostr(data)):
            #looks like the first packet back is just text, print it:
            text = bytestostr(data)
            if "\n" in text:
                for line in text.splitlines():
                    netlog.warn(line)
            else:
                netlog.error("Error: failed to connect, received")
                netlog.error(" %s", repr_ellipsized(text))
        else:
            netlog.error("Error: received uninterpretable nonsense: %s",
                         message)
            if proto:
                netlog.error(" packet no %i data: %s",
                             proto.input_packetcount, repr_ellipsized(data))
            else:
                netlog.error(" data: %s", repr_ellipsized(data))
        self.quit(EXIT_PACKET_FAILURE)

    def _process_invalid(self, packet):
        """Log an invalid packet (type, message, data) and quit.

        The client exits with ``EXIT_PACKET_FAILURE``.
        """
        _, reason, payload = packet
        netlog.info("Received invalid packet: %s", reason)
        netlog(" data: %s", repr_ellipsized(payload))
        self.quit(EXIT_PACKET_FAILURE)

    def process_packet(self, _proto, packet):
        """Dispatch an incoming packet to its registered handler.

        The packet type is ``packet[0]``. Integer packet types are used as
        they are; anything else is converted with ``bytestostr`` before the
        lookup. Handlers found in ``_packet_handlers`` are called directly,
        handlers found in ``_ui_packet_handlers`` are scheduled through
        ``idle_add``. Unknown packet types are logged and dropped.

        Errors raised while dispatching are logged rather than propagated,
        except for ``KeyboardInterrupt`` which is re-raised.
        """
        #defined before the try block so that the error handler below
        #can always reference them, even if packet[0] itself fails:
        handler = None
        packet_type = None
        try:
            packet_type = packet[0]
            #the previous test compared the value against the `int` type
            #object itself, which is always unequal, so integer packet types
            #were being turned into strings:
            if not isinstance(packet_type, int):
                packet_type = bytestostr(packet_type)
            handler = self._packet_handlers.get(packet_type)
            if handler:
                handler(packet)
                return
            handler = self._ui_packet_handlers.get(packet_type)
            if not handler:
                netlog.error("unknown packet type: %s", packet_type)
                return
            self.idle_add(handler, packet)
        except KeyboardInterrupt:
            raise
        except:
            #deliberately broad: a failing handler must not propagate
            netlog.error(
                "Unhandled error while processing a '%s' packet from peer using %s",
                packet_type,
                handler,
                exc_info=True)
コード例 #11
0
class ProxyInstanceProcess(Process):

    def __init__(self, uid, gid, env_options, session_options, socket_dir,
                 video_encoder_modules, csc_modules,
                 client_conn, client_state, cipher, encryption_key, server_conn, caps, message_queue):
        """Store the settings for this proxy instance.

        Every argument is kept on the instance; the process is named after
        ``client_conn``. All runtime state (protocols, queues, threads, video
        encoders, control socket) starts out unset.
        """
        Process.__init__(self, name=str(client_conn))
        #identity and options:
        self.uid = uid
        self.gid = gid
        self.env_options = env_options
        self.session_options = session_options
        self.socket_dir = socket_dir
        self.video_encoder_modules = video_encoder_modules
        self.csc_modules = csc_modules
        #connections and client details:
        self.client_conn = client_conn
        self.client_state = client_state
        self.cipher = cipher
        self.encryption_key = encryption_key
        self.server_conn = server_conn
        self.caps = caps
        caps_summary = "%s: %s.." % (type(caps), repr_ellipsized(str(caps)))
        log("ProxyProcess%s", (uid, gid, env_options, session_options, socket_dir,
                               video_encoder_modules, csc_modules,
                               client_conn, repr_ellipsized(str(client_state)), cipher, encryption_key, server_conn,
                               caps_summary, message_queue))
        #protocols and main loop state:
        self.client_protocol = None
        self.server_protocol = None
        self.exit = False
        self.main_queue = None
        self.message_queue = message_queue
        #encoding state:
        self.encode_queue = None            #holds draw packets to encode
        self.encode_thread = None
        self.video_encoding_defs = None
        self.video_encoders = None
        self.video_encoders_last_used_time = None
        self.video_encoder_types = None
        self.video_helper = None
        self.lost_windows = None
        #for handling the local unix domain socket:
        self.control_socket_cleanup = None
        self.control_socket = None
        self.control_socket_thread = None
        self.control_socket_path = None
        self.potential_protocols = []
        self.max_connections = MAX_CONCURRENT_CONNECTIONS

    def server_message_queue(self):
        """Process messages sent by the proxy server on ``message_queue``.

        Loops until a "stop" message is received, which stops this instance.
        A "socket-handover-complete" message switches both connections to
        blocking mode; any other message is logged as an error.
        """
        while True:
            log("waiting for server message on %s", self.message_queue)
            message = self.message_queue.get()
            log("received proxy server message: %s", message)
            if message == "stop":
                self.stop("proxy server request")
                return
            if message != "socket-handover-complete":
                log.error("unexpected proxy server message: %s", message)
                continue
            log("setting sockets to blocking mode: %s", (self.client_conn, self.server_conn))
            #set sockets to blocking mode:
            for conn in (self.client_conn, self.server_conn):
                set_blocking(conn)

    def signal_quit(self, signum, frame):
        """Signal handler: flag the exit and stop the proxy instance.

        SIGINT and SIGTERM are re-bound to ``deadly_signal`` before
        ``stop()`` is called with the signal name as the reason.
        """
        signame = SIGNAMES.get(signum, signum)
        log.info("")
        log.info("proxy process pid %s got signal %s, exiting", os.getpid(), signame)
        self.exit = True
        for sig in (signal.SIGINT, signal.SIGTERM):
            signal.signal(sig, deadly_signal)
        self.stop(signame)

    def idle_add(self, fn, *args, **kwargs):
        """Schedule ``fn(*args, **kwargs)`` by adding it to ``main_queue``.

        This emulates gobject's idle_add using a simple queue.
        """
        task = (fn, args, kwargs)
        self.main_queue.put(task)

    def timeout_add(self, timeout, fn, *args, **kwargs):
        """Call ``fn(*args, **kwargs)`` after ``timeout`` milliseconds.

        This emulates gobject's timeout_add: a ``Timer`` fires after the
        delay and hands the call over to ``idle_add``. When ``fn`` returns
        a true value, the timeout is armed again with the same arguments.
        """
        def run_then_rearm():
            if bool(fn(*args, **kwargs)):
                self.timeout_add(timeout, fn, *args, **kwargs)
            return False

        def on_timer():
            #just run via idle_add:
            self.idle_add(run_then_rearm)

        delay = timeout/1000.0
        timer = Timer(delay, on_timer)
        timer.start()

    def run(self):
        """Entry point of the proxy instance process.

        Switches to the target uid/gid, applies the environment options,
        initializes the video encoders, installs the signal handlers, starts
        the helper threads and the two protocol wrappers (client and server),
        sends the hello packet to the server, then runs the main queue until
        the instance is stopped. Returns early if the control socket cannot
        be created.
        """
        log("ProxyProcess.run() pid=%s, uid=%s, gid=%s", os.getpid(), getuid(), getgid())
        setuidgid(self.uid, self.gid)
        if self.env_options:
            #TODO: whitelist env update?
            os.environ.update(self.env_options)
        self.video_init()

        log.info("new proxy instance started")
        log.info(" for client %s", self.client_conn)
        log.info(" and server %s", self.server_conn)

        for sig in (signal.SIGTERM, signal.SIGINT):
            signal.signal(sig, self.signal_quit)
        log("registered signal handler %s", self.signal_quit)

        start_thread(self.server_message_queue, "server message queue")

        control_socket_ok = self.create_control_socket()
        if not control_socket_ok:
            #TODO: should send a message to the client
            return
        self.control_socket_thread = start_thread(self.control_socket_loop, "control")

        self.main_queue = Queue()
        #setup protocol wrappers:
        self.server_packets = Queue(PROXY_QUEUE_SIZE)
        self.client_packets = Queue(PROXY_QUEUE_SIZE)
        client_proto = Protocol(self, self.client_conn, self.process_client_packet, self.get_client_packet)
        self.client_protocol = client_proto
        client_proto.restore_state(self.client_state)
        server_proto = Protocol(self, self.server_conn, self.process_server_packet, self.get_server_packet)
        self.server_protocol = server_proto
        #server connection tweaks:
        for large_type in ("draw", "window-icon", "keymap-changed", "server-settings"):
            server_proto.large_packets.append(large_type)
        if self.caps.boolget("file-transfer"):
            for proto in (client_proto, server_proto):
                proto.large_packets.append("send-file")
                proto.large_packets.append("send-file-chunk")
        compression_level = self.session_options.get("compression_level", 0)
        server_proto.set_compression_level(compression_level)
        server_proto.enable_default_encoder()

        self.lost_windows = set()
        self.encode_queue = Queue()
        self.encode_thread = start_thread(self.encode_loop, "encode")

        log("starting network threads")
        server_proto.start()
        client_proto.start()

        self.send_hello()
        self.timeout_add(VIDEO_TIMEOUT*1000, self.timeout_video_encoders)

        try:
            self.run_queue()
        except KeyboardInterrupt as e:
            self.stop(str(e))
        finally:
            log("ProxyProcess.run() ending %s", os.getpid())

    def video_init(self):
        """Load the codecs and work out which video encoders we can proxy.

        Populates ``video_encoding_defs`` with the (serializable) properties
        of every encoder spec that accepts plain RGB input, and
        ``video_encoder_types`` with the encoder types found, sorted using
        ``PREFERRED_ENCODER_ORDER`` first.
        """
        enclog("video_init() loading codecs")
        load_codecs(decoders=False)
        enclog("video_init() will try video encoders: %s", csv(self.video_encoder_modules) or "none")
        helper = getVideoHelper()
        self.video_helper = helper
        #only use video encoders (no CSC supported in proxy)
        helper.set_modules(video_encoders=self.video_encoder_modules)
        helper.init()

        self.video_encoding_defs = {}
        self.video_encoders = {}
        self.video_encoders_dst_formats = []
        self.video_encoders_last_used_time = {}
        self.video_encoder_types = []

        #only deal with encoders that can handle plain RGB directly:
        rgb_inputs = ("BGRX", "BGRA", "RGBX", "RGBA")
        #figure out which encoders we want to proxy for (if any):
        found_types = set()
        for encoding in self.video_helper.get_encodings():
            specs_by_colorspace = self.video_helper.get_encoder_specs(encoding)
            for colorspace, especs in specs_by_colorspace.items():
                if colorspace not in rgb_inputs:
                    continue
                for spec in especs:                             #ie: video_spec("x264")
                    props = spec.to_dict()
                    #not serializable!
                    del props["codec_class"]
                    #we want to win scoring so we get used ahead of other encoders
                    props["score_boost"] = 50
                    #limit to 3 video streams we proxy for (we really want 2,
                    # but because of races with garbage collection, we need to allow more)
                    props["max_instances"] = 3
                    #store it in encoding defs:
                    colorspace_defs = self.video_encoding_defs.setdefault(encoding, {})
                    colorspace_defs.setdefault(colorspace, []).append(props)
                    found_types.add(spec.codec_type)

        enclog("encoder types found: %s", tuple(found_types))
        #remove duplicates and use preferred order:
        order = PREFERRED_ENCODER_ORDER[:]
        for encoder_type in list(found_types):
            if encoder_type in order:
                continue
            order.append(encoder_type)
        self.video_encoder_types = [encoder_type for encoder_type in order if encoder_type in found_types]
        enclog.info("proxy video encoders: %s", ", ".join(self.video_encoder_types))


    def create_control_socket(self):
        """Create the unix domain control socket for this proxy instance.

        The socket path is derived from ``socket_dir`` and our pid.
        Returns ``False`` (after logging an error) if a server already
        appears to be using that path or if the socket cannot be created,
        ``True`` once the socket is listening.
        """
        assert self.socket_dir
        dotxpra = DotXpra(self.socket_dir)
        sockpath = dotxpra.socket_path(":proxy-%s" % os.getpid())
        in_use_states = (DotXpra.LIVE, DotXpra.UNKNOWN)
        if dotxpra.get_server_state(sockpath) in in_use_states:
            log.error("Error: you already have a proxy server running at '%s'", sockpath)
            log.error(" the control socket will not be created")
            return False
        log("create_control_socket: socket path='%s', uid=%i, gid=%i", sockpath, getuid(), getgid())
        try:
            sock, cleanup = create_unix_domain_socket(sockpath, None, 0o600)
            self.control_socket_cleanup = cleanup
            sock.listen(5)
        except Exception as e:
            log("create_unix_domain_socket failed for '%s'", sockpath, exc_info=True)
            log.error("Error: failed to setup control socket '%s':", sockpath)
            log.error(" %s", e)
            return False
        else:
            self.control_socket = sock
            self.control_socket_path = sockpath
        log.info("proxy instance now also available using unix domain socket:")
        log.info(" %s", self.control_socket_path)
        return True

    def control_socket_loop(self):
        """Accept connections on the control socket until ``exit`` is set.

        Each accepted connection is handed to ``new_control_connection``.
        """
        while True:
            if self.exit:
                break
            log("waiting for connection on %s", self.control_socket_path)
            sock, address = self.control_socket.accept()
            self.new_control_connection(sock, address)

    def new_control_connection(self, sock, address):
        """Wrap a newly accepted control socket connection in a Protocol.

        The connection is closed straight away if ``max_connections``
        protocols are already pending. Otherwise the new protocol is added
        to ``potential_protocols``, started, and a timeout is armed to verify
        that it gets accepted. Always returns ``True``.
        """
        pending = len(self.potential_protocols)
        if pending>=self.max_connections:
            log.error("too many connections (%s), ignoring new one", len(self.potential_protocols))
            sock.close()
            return True
        try:
            peername = sock.getpeername()
        except:
            #fallback to the address given by accept():
            peername = str(address)
        sockname = sock.getsockname()
        target = peername or sockname
        log("new_control_connection() sock=%s, sockname=%s, address=%s, peername=%s", sock, sockname, address, peername)
        conn = SocketConnection(sock, sockname, address, target, "unix-domain")
        log.info("New proxy instance control connection received: %s", conn)
        control_proto = Protocol(self, conn, self.process_control_packet)
        control_proto.large_packets.append("info-response")
        self.potential_protocols.append(control_proto)
        control_proto.enable_default_encoder()
        control_proto.start()
        self.timeout_add(SOCKET_TIMEOUT*1000, self.verify_connection_accepted, control_proto)
        return True

    def verify_connection_accepted(self, protocol):
        """Disconnect ``protocol`` if it is still open and still pending.

        "Pending" means it is still listed in ``potential_protocols``.
        """
        if protocol._closed:
            return
        if protocol not in self.potential_protocols:
            return
        log.error("connection timedout: %s", protocol)
        self.send_disconnect(protocol, LOGIN_TIMEOUT)

    def process_control_packet(self, proto, packet):
        """Process a control socket packet, disconnecting the peer on error.

        Any exception raised by ``do_process_control_packet`` is logged and
        reported to the peer in a disconnect message.
        """
        try:
            self.do_process_control_packet(proto, packet)
        except Exception as err:
            log.error("error processing control packet", exc_info=True)
            details = str(err)
            self.send_disconnect(proto, CONTROL_COMMAND_ERROR, details)

    def do_process_control_packet(self, proto, packet):
        """Handle one packet received on the control socket.

        A connection-lost packet removes ``proto`` from the pending list.
        A "hello" packet may carry an info, stop or version request; a hello
        asking for a challenge is refused since this socket does not use
        authentication. Anything else gets a disconnect with
        ``CONTROL_COMMAND_ERROR``.
        """
        log("process_control_packet(%s, %s)", proto, packet)
        packet_type = packet[0]
        if packet_type==Protocol.CONNECTION_LOST:
            log.info("Connection lost")
            if proto in self.potential_protocols:
                self.potential_protocols.remove(proto)
            return
        hello_caps = None
        if packet_type=="hello":
            hello_caps = typedict(packet[1])
        if hello_caps is not None:
            if hello_caps.boolget("challenge"):
                self.send_disconnect(proto, AUTHENTICATION_ERROR, "this socket does not use authentication")
                return
            if hello_caps.get("info_request", False):
                proto.send_now(("hello", self.get_proxy_info(proto)))
                self.timeout_add(5*1000, self.send_disconnect, proto, CLIENT_EXIT_TIMEOUT, "info sent")
                return
            if hello_caps.get("stop_request", False):
                self.stop("socket request", None)
                return
            if hello_caps.get("version_request", False):
                from xpra import __version__
                proto.send_now(("hello", {"version" : __version__}))
                self.timeout_add(5*1000, self.send_disconnect, proto, CLIENT_EXIT_TIMEOUT, "version sent")
                return
        #not a request we can handle:
        self.send_disconnect(proto, CONTROL_COMMAND_ERROR, "this socket only handles 'info', 'version' and 'stop' requests")

    def send_disconnect(self, proto, reason, *extra):
        """Send a "disconnect" packet to ``proto`` unless it is closed.

        The packet carries ``reason`` followed by any ``extra`` values, and
        ``force_disconnect`` is scheduled to run one second later.
        """
        log("send_disconnect(%s, %s, %s)", proto, reason, extra)
        if not proto._closed:
            disconnect_packet = ["disconnect", reason]
            disconnect_packet += list(extra)
            proto.send_now(disconnect_packet)
            self.timeout_add(1000, self.force_disconnect, proto)

    def force_disconnect(self, proto):
        """Close ``proto``."""
        close = proto.close
        close()


    def get_proxy_info(self, proto):
        """Return the info dictionary describing this proxy instance.

        The "proxy" entry holds the local version plus the merged server and
        thread information; the "window" entry holds the window information.
        """
        server_and_threads = {}
        for details in (get_server_info(), get_thread_info(proto)):
            server_and_threads.update(details)
        proxy_info = {
                      "version"    : local_version,
                      ""           : server_and_threads,
                      }
        return {"proxy" : proxy_info,
                "window" : self.get_window_info(),
                }

    def send_hello(self, challenge_response=None, client_salt=None):
        """Queue a "hello" packet for the server.

        The packet holds the filtered client capabilities and, when
        ``challenge_response`` is given, the challenge response together
        with the client salt.
        """
        caps = self.filter_client_caps(self.caps)
        if challenge_response:
            caps["challenge_response"] = challenge_response
            caps["challenge_client_salt"] = client_salt
        hello_packet = ("hello", caps)
        self.queue_server_packet(hello_packet)


    def sanitize_session_options(self, options):
        """Return the whitelisted entries of ``options``, parsed.

        Only the compression level and the compressor / packet encoder
        toggles are kept; values that fail to parse are logged and dropped.
        """
        def number(k, v):
            return parse_number(int, k, v)
        parsers = {"compression_level" : number}
        for toggle in ("lz4", "lzo", "zlib", "rencode", "bencode", "yaml"):
            parsers[toggle] = parse_bool
        sanitized = {}
        for name, value in options.items():
            parser = parsers.get(name)
            if not parser:
                #not whitelisted
                continue
            log("trying to add %s=%s using %s", name, value, parser)
            try:
                sanitized[name] = parser(name, value)
            except Exception as e:
                log.warn("failed to parse value %s for %s using %s: %s", value, name, parser, e)
        return sanitized

    def filter_client_caps(self, caps):
        """Return the client capabilities to forward to the server.

        Capabilities the proxy handles itself (encryption, authentication,
        aliases and compression) are removed with ``filter_caps``; the
        sanitized session options are then applied and the proxy's video
        encodings are advertised.
        """
        #"lzo" used to be misspelt "lz0" here, which never matched the
        #"lzo" capability used everywhere else in this class:
        overridden = ("cipher", "challenge", "digest", "aliases", "compression", "lz4", "lzo", "zlib")
        fc = self.filter_caps(caps, overridden)
        #update with options provided via config if any:
        fc.update(self.sanitize_session_options(self.session_options))
        #add video proxies if any:
        fc["encoding.proxy.video"] = len(self.video_encoding_defs)>0
        if self.video_encoding_defs:
            fc["encoding.proxy.video.encodings"] = self.video_encoding_defs
        return fc

    def filter_server_caps(self, caps):
        """Return the server capabilities to forward to the client.

        The server protocol's packet encoder is enabled from ``caps`` first,
        then the "aliases" capabilities are filtered out.
        """
        self.server_protocol.enable_encoder_from_caps(caps)
        dropped_prefixes = ("aliases", )
        filtered = self.filter_caps(caps, dropped_prefixes)
        return filtered

    def filter_caps(self, caps, prefixes):
        """Copy ``caps`` without the keys starting with one of ``prefixes``.

        The result is then updated with the proxy's own network capabilities
        and proxy information ("proxy" is set to True and "proxy.hostname"
        to the local hostname).
        """
        #removes caps that the proxy overrides / does not use:
        kept = {}
        dropped = []
        for key in caps.keys():
            matching = [prefix for prefix in prefixes if key.startswith(prefix)]
            if matching:
                dropped.append(key)
            else:
                kept[key] = caps[key]
        log("filtered out %s matching %s", dropped, prefixes)
        #replace the network caps with the proxy's own:
        network_caps = flatten_dict(get_network_caps())
        kept.update(network_caps)
        #then add the proxy info:
        updict(kept, "proxy", get_server_info(), flatten_dicts=True)
        kept["proxy"] = True
        kept["proxy.hostname"] = socket.gethostname()
        return kept


    def run_queue(self):
        """Main loop: run the callbacks queued by idle_add / timeout_add.

        Each queue item is a ``(fn, args, kwargs)`` tuple; ``None`` is the
        exit marker. A callback returning a true value is queued again so
        that it gets re-run. Once the loop ends, waits up to one second
        for both network protocols to close.
        """
        log("run_queue() queue has %s items already in it", self.main_queue.qsize())
        #process "idle_add"/"timeout_add" events in the main loop:
        while not self.exit:
            log("run_queue() size=%s", self.main_queue.qsize())
            item = self.main_queue.get()
            if item is None:
                log("run_queue() None exit marker")
                break
            fn, args, kwargs = item
            log("run_queue() %s%s%s", fn, args, kwargs)
            try:
                if fn(*args, **kwargs):
                    #re-run it: queue the callback again
                    #(this used to queue the callback's return value instead,
                    # which could not be unpacked on the next iteration)
                    self.main_queue.put(item)
            except Exception:
                log.error("error during main loop callback %s", fn, exc_info=True)
        self.exit = True
        #wait for connections to close down cleanly before we exit
        for i in range(10):
            if self.client_protocol._closed and self.server_protocol._closed:
                break
            if i==0:
                log.info("waiting for network connections to close")
            else:
                log("still waiting %i/10 - client.closed=%s, server.closed=%s", i+1, self.client_protocol._closed, self.server_protocol._closed)
            time.sleep(0.1)
        log.info("proxy instance %s stopped", os.getpid())

    def stop(self, reason="proxy terminating", skip_proto=None):
        """Stop this proxy instance.

        Sets the exit flag, closes the control socket and runs its cleanup
        function, replaces the main and encode queues with queues holding
        only the ``None`` exit marker, then sends a disconnect with
        ``reason`` to the client and server protocols (except
        ``skip_proto``).
        """
        log.info("stop(%s, %s)", reason, skip_proto)
        self.exit = True
        control_socket = self.control_socket
        if control_socket is not None:
            try:
                control_socket.close()
            except Exception:
                #best effort: we are shutting down anyway
                log("error closing control socket %s", control_socket, exc_info=True)
        csc = self.control_socket_cleanup
        if csc:
            self.control_socket_cleanup = None
            csc()
        #stop() can be called (signal handler, proxy server message)
        #before run() has created the main queue:
        main_queue = self.main_queue
        if main_queue is not None:
            main_queue.put(None)
        #empty the main queue:
        q = Queue()
        q.put(None)
        self.main_queue = q
        #empty the encode queue:
        q = Queue()
        q.put(None)
        self.encode_queue = q
        for proto in (self.client_protocol, self.server_protocol):
            if proto and proto!=skip_proto:
                log("sending disconnect to %s", proto)
                proto.flush_then_close(["disconnect", SERVER_SHUTDOWN, reason])


    def queue_client_packet(self, packet):
        """Queue ``packet`` for the client and notify the client protocol."""
        packet_type = packet[0]
        log("queueing client packet: %s", packet_type)
        self.client_packets.put(packet)
        self.client_protocol.source_has_more()

    def get_client_packet(self):
        """Return the next packet to send to the client.

        The result is a 4-item tuple: the packet, two ``None`` placeholders
        and a flag telling if more packets are queued.
        """
        #the client protocol wants a packet
        next_packet = self.client_packets.get()
        log("sending to client: %s", next_packet[0])
        has_more = self.client_packets.qsize()>0
        return next_packet, None, None, has_more

    def process_client_packet(self, proto, packet):
        """Handle a packet received from the client.

        Connection loss and disconnect packets stop the proxy instance,
        "set_deflate" is echoed back to the client and late "hello" packets
        are dropped. Everything else, disconnect packets included, is
        forwarded to the server.
        """
        packet_type = packet[0]
        log("process_client_packet: %s", packet_type)
        if packet_type==Protocol.CONNECTION_LOST:
            self.stop("client connection lost", proto)
            return
        if packet_type=="set_deflate":
            #echo it back to the client:
            self.client_packets.put(packet)
            self.client_protocol.source_has_more()
            return
        if packet_type=="hello":
            log.warn("Warning: invalid hello packet received after initial authentication (dropped)")
            return
        if packet_type=="disconnect":
            log("got disconnect from client: %s", packet[1])
            if not self.exit:
                self.stop("disconnect from client: %s" % packet[1])
            else:
                self.client_protocol.close()
        self.queue_server_packet(packet)


    def queue_server_packet(self, packet):
        """Queue ``packet`` for the server and notify the server protocol."""
        packet_type = packet[0]
        log("queueing server packet: %s", packet_type)
        self.server_packets.put(packet)
        self.server_protocol.source_has_more()

    def get_server_packet(self):
        """Return the next packet to send to the server.

        The result is a 4-item tuple: the packet, two ``None`` placeholders
        and a flag telling if more packets are queued.
        """
        #server wants a packet
        next_packet = self.server_packets.get()
        log("sending to server: %s", next_packet[0])
        has_more = self.server_packets.qsize()>0
        return next_packet, None, None, has_more


    def _packet_recompress(self, packet, index, name):
        if len(packet)>index:
            data = packet[index]
            if len(data)<512:
                packet[index] = str(data)
                return
            #FIXME: this is ugly and not generic!
            zlib = compression.use_zlib and self.caps.boolget("zlib", True)
            lz4 = compression.use_lz4 and self.caps.boolget("lz4", False)
            lzo = compression.use_lzo and self.caps.boolget("lzo", False)
            if zlib or lz4 or lzo:
                packet[index] = compressed_wrapper(name, data, zlib=zlib, lz4=lz4, lzo=lzo, can_inline=False)
            else:
                #prevent warnings about large uncompressed data
                packet[index] = Compressed("raw %s" % name, data, can_inline=True)

    def process_server_packet(self, proto, packet):
        """Handle a packet received from the server.

        Most packets are forwarded to the client, some after being modified:
        * connection loss and disconnect packets stop the proxy instance
        * "hello" has its capabilities filtered and the packet size limits
          of the protocols are raised as needed
        * "info-response" gets the proxy information added
        * "lost-window" and "draw" packets go through the encode thread
          ("draw" packets are forwarded by that thread, not from here)
        * cursor, icon and file data is re-wrapped
        * "challenge" packets are answered from here using the session
          password and are not forwarded
        """
        packet_type = packet[0]
        log("process_server_packet: %s", packet_type)
        if packet_type==Protocol.CONNECTION_LOST:
            self.stop("server connection lost", proto)
            return
        elif packet_type=="disconnect":
            log("got disconnect from server: %s", packet[1])
            if self.exit:
                self.server_protocol.close()
            else:
                self.stop("disconnect from server: %s" % packet[1])
        elif packet_type=="hello":
            c = typedict(packet[1])
            maxw, maxh = c.intpair("max_desktop_size", (4096, 4096))
            caps = self.filter_server_caps(c)
            #add new encryption caps:
            if self.cipher:
                from xpra.net.crypto import crypto_backend_init, new_cipher_caps, DEFAULT_PADDING
                crypto_backend_init()
                padding_options = self.caps.strlistget("cipher.padding.options", [DEFAULT_PADDING])
                auth_caps = new_cipher_caps(self.client_protocol, self.cipher, self.encryption_key, padding_options)
                caps.update(auth_caps)
            #may need to bump packet size:
            proto.max_packet_size = maxw*maxh*4*4
            file_transfer = self.caps.boolget("file-transfer") and c.boolget("file-transfer")
            file_size_limit = max(self.caps.intget("file-size-limit"), c.intget("file-size-limit"))
            #zero unless both ends support file transfers:
            file_max_packet_size = int(file_transfer) * (1024 + file_size_limit*1024*1024)
            self.client_protocol.max_packet_size = max(self.client_protocol.max_packet_size, file_max_packet_size)
            self.server_protocol.max_packet_size = max(self.server_protocol.max_packet_size, file_max_packet_size)
            packet = ("hello", caps)
        elif packet_type=="info-response":
            #adds proxy info:
            #note: this is only seen by the client application
            #"xpra info" is a new connection, which talks to the proxy server...
            info = packet[1]
            info.update(self.get_proxy_info(proto))
        elif packet_type=="lost-window":
            wid = packet[1]
            #mark it as lost so we can drop any current/pending frames
            self.lost_windows.add(wid)
            #queue it so it gets cleaned safely (for video encoders mostly):
            self.encode_queue.put(packet)
            #and fall through so tell the client immediately
        elif packet_type=="draw":
            #use encoder thread:
            self.encode_queue.put(packet)
            #which will queue the packet itself when done:
            return
        #we do want to reformat cursor packets...
        #as they will have been uncompressed by the network layer already:
        elif packet_type=="cursor":
            #packet = ["cursor", x, y, width, height, xhot, yhot, serial, pixels, name]
            #or:
            #packet = ["cursor", "png", x, y, width, height, xhot, yhot, serial, pixels, name]
            #or:
            #packet = ["cursor", ""]
            if len(packet)>=8:
                #hard to distinguish png cursors from normal cursors...
                #only the integer probe is guarded, so that errors raised
                #by the recompression itself are no longer hidden:
                try:
                    int(packet[1])
                    pixels_index = 8
                except (ValueError, TypeError):
                    pixels_index = 9
                self._packet_recompress(packet, pixels_index, "cursor")
        elif packet_type=="window-icon":
            self._packet_recompress(packet, 5, "icon")
        elif packet_type=="send-file":
            if packet[6]:
                packet[6] = Compressed("file-data", packet[6])
        elif packet_type=="send-file-chunk":
            if packet[3]:
                packet[3] = Compressed("file-chunk-data", packet[3])
        elif packet_type=="challenge":
            from xpra.net.crypto import get_salt
            #client may have already responded to the challenge,
            #so we have to handle authentication from this end
            salt = packet[1]
            digest = packet[3]
            client_salt = get_salt(len(salt))
            salt = xor_str(salt, client_salt)
            if digest!=b"hmac":
                #the digest must be formatted into the reason:
                #stop()'s second argument is the protocol to skip
                self.stop("digest mode '%s' not supported" % std(digest))
                return
            password = self.session_options.get("password")
            if not password:
                self.stop("authentication requested by the server, but no password available for this session")
                return
            import hmac, hashlib
            password = strtobytes(password)
            salt = strtobytes(salt)
            challenge_response = hmac.HMAC(password, salt, digestmod=hashlib.md5).hexdigest()
            log.info("sending %s challenge response", digest)
            self.send_hello(challenge_response, client_salt)
            return
        self.queue_client_packet(packet)


    def encode_loop(self):
        """Thread for slower encoding related work.

        Processes the packets found in ``encode_queue`` until ``exit`` is
        set or the ``None`` exit marker is received:
        * "lost-window": forget the window and clean its video encoder
        * "draw": run ``process_draw`` and forward the packet to the client
          if it returns a true value
        * "check-video-timeout": clean the window's video encoder if it has
          been idle for longer than ``VIDEO_TIMEOUT``
        Errors are logged and do not stop the loop.
        """
        while not self.exit:
            packet = self.encode_queue.get()
            if packet is None:
                return
            try:
                packet_type = packet[0]
                if packet_type=="lost-window":
                    wid = packet[1]
                    #discard / pop: the window may already have been dealt with
                    #(ie: the same lost-window packet received twice)
                    self.lost_windows.discard(wid)
                    ve = self.video_encoders.pop(wid, None)
                    self.video_encoders_last_used_time.pop(wid, None)
                    if ve:
                        ve.clean()
                elif packet_type=="draw":
                    #modify the packet with the video encoder:
                    if self.process_draw(packet):
                        #then send it as normal:
                        self.queue_client_packet(packet)
                elif packet_type=="check-video-timeout":
                    #not a real packet, this is added by the timeout check:
                    wid = packet[1]
                    ve = self.video_encoders.get(wid)
                    if ve:
                        #the timestamp may be missing: treat it as never used
                        #instead of failing on the subtraction
                        last_used = self.video_encoders_last_used_time.get(wid, 0)
                        idle_time = time.time()-last_used
                        if idle_time>VIDEO_TIMEOUT:
                            enclog("timing out the video encoder context for window %s", wid)
                            #timeout is confirmed, we are in the encoding thread,
                            #so it is now safe to clean it up:
                            ve.clean()
                            self.video_encoders.pop(wid, None)
                            self.video_encoders_last_used_time.pop(wid, None)
                else:
                    enclog.warn("unexpected encode packet: %s", packet_type)
            except Exception:
                enclog.warn("error encoding packet", exc_info=True)


    def process_draw(self, packet):
        """ Modify a "draw" packet in place before it is sent to the client.

            Fields 1 to 10 of the packet are read as:
            wid, x, y, width, height, encoding, pixels, (unused), rowstride, client_options.
            Depending on the encoding and client_options, the pixel data is:
            * left untouched ("mmap")
            * passed through as "rgb32" (when PASSTHROUGH is set)
            * re-wrapped as Compressed without being re-encoded
            * compressed with a video encoder, or as jpeg via "enc_pillow"
              when no video encoder can be found

            Returns True if the packet should be forwarded,
            False if it should be dropped (the window is in self.lost_windows).
        """
        wid, x, y, width, height, encoding, pixels, _, rowstride, client_options = packet[1:11]
        #never modify mmap packets
        if encoding=="mmap":
            return True

        #we have a proxy video packet:
        rgb_format = client_options.get("rgb_format", "")
        enclog("proxy draw: client_options=%s", client_options)

        def send_updated(encoding, compressed_data, updated_client_options):
            #update the packet with actual encoding data used:
            packet[6] = encoding
            packet[7] = compressed_data
            packet[10] = updated_client_options
            enclog("returning %s bytes from %s, options=%s", len(compressed_data), len(pixels), updated_client_options)
            return (wid not in self.lost_windows)

        def passthrough(strip_alpha=True):
            enclog("proxy draw: %s passthrough (rowstride: %s vs %s, strip alpha=%s)", rgb_format, rowstride, client_options.get("rowstride", 0), strip_alpha)
            if strip_alpha:
                #passthrough as plain RGB:
                Xindex = rgb_format.upper().find("X")
                if Xindex>=0 and len(rgb_format)==4:
                    #force clear alpha (which may be garbage):
                    newdata = bytearray(pixels)
                    #set the "X" byte of every complete 4-byte pixel to 255
                    #(floor division and a bytes value work on python 2 and 3)
                    npixels = len(newdata)//4
                    newdata[Xindex:npixels*4:4] = bytearray(b"\xff")*npixels
                    packet[9] = client_options.get("rowstride", 0)
                    cdata = bytes(newdata)
                else:
                    cdata = pixels
                new_client_options = {"rgb_format" : rgb_format}
            else:
                #preserve
                cdata = pixels
                new_client_options = client_options
            wrapped = Compressed("%s pixels" % encoding, cdata)
            #FIXME: we should not assume that rgb32 is supported here...
            #(we may have to convert to rgb24..)
            return send_updated("rgb32", wrapped, new_client_options)

        proxy_video = client_options.get("proxy", False)
        if PASSTHROUGH and (encoding in ("rgb32", "rgb24") or proxy_video):
            #we are dealing with rgb data, so we can pass it through:
            return passthrough(proxy_video)
        elif not self.video_encoder_types or not client_options or not proxy_video:
            #ensure we don't try to re-compress the pixel data in the network layer:
            #(re-add the "compressed" marker that gets lost when we re-assemble packets)
            packet[7] = Compressed("%s pixels" % encoding, packet[7])
            return True

        #video encoding: find existing encoder
        ve = self.video_encoders.get(wid)
        if ve:
            #lost_windows holds window ids, so test the wid (not the encoder):
            if wid in self.lost_windows:
                #we cannot clean the video encoder here, there may be more frames queue up
                #"lost-window" in encode_loop will take care of it safely
                return False
            #we must verify that the encoder is still valid
            #and scrap it if not (ie: when window is resized)
            if ve.get_width()!=width or ve.get_height()!=height:
                enclog("closing existing video encoder %s because dimensions have changed from %sx%s to %sx%s", ve, ve.get_width(), ve.get_height(), width, height)
                ve.clean()
                ve = None
            elif ve.get_encoding()!=encoding:
                enclog("closing existing video encoder %s because encoding has changed from %s to %s", ve, ve.get_encoding(), encoding)
                ve.clean()
                ve = None
        #scaling and depth are proxy-encoder attributes:
        scaling = client_options.get("scaling", (1, 1))
        depth   = client_options.get("depth", 24)
        rowstride = client_options.get("rowstride", rowstride)
        quality = client_options.get("quality", -1)
        speed   = client_options.get("speed", -1)
        timestamp = client_options.get("timestamp")

        image = ImageWrapper(x, y, width, height, pixels, rgb_format, depth, rowstride, planes=ImageWrapper.PACKED)
        if timestamp is not None:
            image.set_timestamp(timestamp)

        #the encoder options are passed through:
        encoder_options = client_options.get("options", {})
        if not ve:
            #make a new video encoder:
            spec = self._find_video_encoder(encoding, rgb_format)
            if spec is None:
                #no video encoder!
                enc_pillow = get_codec("enc_pillow")
                if not enc_pillow:
                    from xpra.server.picture_encode import warn_encoding_once
                    warn_encoding_once("no-video-no-PIL", "no video encoder found for rgb format %s, sending as plain RGB!" % rgb_format)
                    return passthrough(True)
                enclog("no video encoder available: sending as jpeg")
                coding, compressed_data, client_options, _, _, _, _ = enc_pillow.encode("jpeg", image, quality, speed, False)
                return send_updated(coding, compressed_data, client_options)

            enclog("creating new video encoder %s for window %s", spec, wid)
            ve = spec.make_instance()
            #dst_formats is specified with first frame only:
            dst_formats = client_options.get("dst_formats")
            if dst_formats is not None:
                #save it in case we timeout the video encoder,
                #so we can instantiate it again, even from a frame no>1
                self.video_encoders_dst_formats = dst_formats
            else:
                assert self.video_encoders_dst_formats, "BUG: dst_formats not specified for proxy and we don't have it either"
                dst_formats = self.video_encoders_dst_formats
            ve.init_context(width, height, rgb_format, dst_formats, encoding, quality, speed, scaling, {})
            self.video_encoders[wid] = ve
            self.video_encoders_last_used_time[wid] = time.time()       #just to make sure this is always set
        #actual video compression:
        enclog("proxy compression using %s with quality=%s, speed=%s", ve, quality, speed)
        data, out_options = ve.compress_image(image, quality, speed, encoder_options)
        #pass through some options if we don't have them from the encoder
        #(maybe we should also use the "pts" from the real server?)
        for k in ("timestamp", "rgb_format", "depth", "csc"):
            if k not in out_options and k in client_options:
                out_options[k] = client_options[k]
        self.video_encoders_last_used_time[wid] = time.time()
        return send_updated(ve.get_encoding(), Compressed(encoding, data), out_options)

    def timeout_video_encoders(self):
        """ Timer callback: find video encoders that look idle.

            For every window whose encoder was last used more than VIDEO_TIMEOUT
            seconds ago, a "check-video-timeout" request is added to the encode queue.
            Always returns True ("run again").
        """
        #have to be careful as another thread may come in...
        #so we just ask the encode thread (which deals with encoders already)
        #to do what may need to be done if we find a timeout:
        now = time.time()
        for wid in list(self.video_encoders_last_used_time.keys()):
            last_used = self.video_encoders_last_used_time.get(wid)
            if last_used is None:
                #the entry has been removed since we copied the keys
                continue
            idle_time = int(now-last_used)
            enclog("timeout_video_encoders() wid=%s, idle_time=%s", wid, idle_time)
            if idle_time and idle_time>VIDEO_TIMEOUT:
                self.encode_queue.put(["check-video-timeout", wid])
        return True     #run again

    def _find_video_encoder(self, encoding, rgb_format):
        """ Find an encoder spec that accepts rgb_format as input.

            The requested encoding is tried first, then all the other encodings
            reported by self.video_helper. Only specs whose codec_type is listed
            in self.video_encoder_types are considered.
            Returns the matching spec, or None if there isn't one.
        """
        #try the one specified first, then all the others:
        try_encodings = [encoding] + [x for x in self.video_helper.get_encodings() if x!=encoding]
        for try_encoding in try_encodings:
            colorspace_specs = self.video_helper.get_encoder_specs(try_encoding)
            #the rgb format may not be supported at all for this encoding:
            especs = colorspace_specs.get(rgb_format)
            if not especs:
                continue
            for etype in self.video_encoder_types:
                for spec in especs:
                    if etype==spec.codec_type:
                        enclog("_find_video_encoder(%s, %s)=%s", encoding, rgb_format, spec)
                        return spec
        enclog("_find_video_encoder(%s, %s) not found", encoding, rgb_format)
        return None

    def get_window_info(self):
        """ Return a dictionary of video encoder information, keyed by window id.

            Each value has the form {"proxy" : {"" : encoder type, "encoder" : encoder info}},
            where the encoder info also records the "idle_time" in whole seconds.
        """
        now = time.time()
        last_used = self.video_encoders_last_used_time
        info = {}
        for wid, encoder in self.video_encoders.items():
            einfo = encoder.get_info()
            #windows without a timestamp are treated as last used at time 0:
            einfo["idle_time"] = int(now-last_used.get(wid, 0))
            proxy_info = {
                          ""           : encoder.get_type(),
                          "encoder"    : einfo,
                          }
            info[wid] = {"proxy"    : proxy_info}
        enclog("get_window_info()=%s", info)
        return info
コード例 #12
0
class XpraClientBase(object):
    """ Base class for Xpra clients.
        Provides the glue code for:
        * sending packets via Protocol
        * handling packets received via _process_packet
        For an actual implementation, look at:
        * GObjectXpraClient
        * xpra.client.gtk2.client
        * xpra.client.gtk3.client
    """
    def __init__(self):
        #this may be called more than once,
        #skip doing internal init again:
        if not hasattr(self, "exit_code"):
            self.defaults_init()

    def defaults_init(self):
        """ Set every attribute to its default value,
            then assign the uuid and register the packet handlers.
        """
        self.exit_code = None
        self.compression_level = 0
        self.display = None
        self.username = None
        self.password_file = None
        self.password_sent = False
        self.encryption = None
        self.encryption_keyfile = None
        self.quality = -1
        self.min_quality = 0
        self.speed = 0
        self.min_speed = -1
        #protocol stuff:
        self._protocol = None
        #outgoing packets, see next_packet() for the order in which they are sent:
        self._priority_packets = []
        self._ordinary_packets = []
        self._mouse_position = None
        #populated by init_aliases():
        self._aliases = {}
        self._reverse_aliases = {}
        #server state and caps:
        self.server_capabilities = None
        self._remote_machine_id = None
        self._remote_uuid = None
        self._remote_version = None
        self._remote_revision = None
        self._remote_platform = None
        self._remote_platform_release = None
        self._remote_platform_platform = None
        self._remote_platform_linux_distribution = None
        self.uuid = get_user_uuid()
        self.init_packet_handlers()

    def init(self, opts):
        self.compression_level = opts.compression_level
        self.display = opts.display
        self.username = opts.username
        self.password_file = opts.password_file
        self.encryption = opts.encryption
        self.encryption_keyfile = opts.encryption_keyfile
        self.quality = opts.quality
        self.min_quality = opts.min_quality
        self.speed = opts.speed
        self.min_speed = opts.min_speed

    def timeout_add(self, *args):
        """ Placeholder which must be overridden: this version always raises.
            This class calls it as timeout_add(delay, callback, *callback_args).
        """
        raise Exception("override me!")

    def idle_add(self, *args):
        """ Placeholder which must be overridden: this version always raises. """
        raise Exception("override me!")

    def source_remove(self, *args):
        """ Placeholder which must be overridden: this version always raises. """
        raise Exception("override me!")

    def install_signal_handlers(self):
        """ Install handlers for SIGINT and SIGTERM.

            The first signal schedules self.quit(128+signum) via timeout_add
            and swaps the handlers, so that a second signal
            runs cleanup() and exits the process immediately.
        """
        def deadly_signal(signum, frame):
            sys.stderr.write("\ngot deadly signal %s, exiting\n" %
                             SIGNAMES.get(signum, signum))
            sys.stderr.flush()
            self.cleanup()
            #exit immediately with the exit code 128+signum:
            os._exit(128 + signum)

        def app_signal(signum, frame):
            sys.stderr.write("\ngot signal %s, exiting\n" %
                             SIGNAMES.get(signum, signum))
            sys.stderr.flush()
            #any further signal will be handled by deadly_signal:
            signal.signal(signal.SIGINT, deadly_signal)
            signal.signal(signal.SIGTERM, deadly_signal)
            self.timeout_add(0, self.quit, 128 + signum)

        signal.signal(signal.SIGINT, app_signal)
        signal.signal(signal.SIGTERM, app_signal)

    def client_type(self):
        """ Return the client type string, sent as "client_type" in the hello packet. """
        #overridden in subclasses!
        return "Python"

    def get_scheduler(self):
        """ Must be overridden: the value returned is passed
            as first argument to Protocol in setup_connection().
        """
        raise NotImplementedError()

    def setup_connection(self, conn):
        """ Create and configure the Protocol instance for the connection given.
            Incoming packets go to self.process_packet,
            outgoing packets are obtained from self.next_packet.
        """
        log.debug("setup_connection(%s)", conn)
        self._protocol = Protocol(self.get_scheduler(), conn,
                                  self.process_packet, self.next_packet)
        self._protocol.large_packets.append("keymap-changed")
        self._protocol.large_packets.append("server-settings")
        self._protocol.set_compression_level(self.compression_level)
        self._protocol.receive_aliases.update(self._aliases)
        #from now on, have_more() calls the protocol directly:
        self.have_more = self._protocol.source_has_more

    def init_packet_handlers(self):
        """ Populate the two dictionaries mapping packet types to their handler methods. """
        self._packet_handlers = {
            "hello": self._process_hello,
        }
        self._ui_packet_handlers = {
            "challenge": self._process_challenge,
            "disconnect": self._process_disconnect,
            "set_deflate": self._process_set_deflate,
            Protocol.CONNECTION_LOST: self._process_connection_lost,
            Protocol.GIBBERISH: self._process_gibberish,
        }

    def init_aliases(self):
        packet_types = list(self._packet_handlers.keys())
        packet_types += list(self._ui_packet_handlers.keys())
        i = 1
        for key in packet_types:
            self._aliases[i] = key
            self._reverse_aliases[key] = i
            i += 1

    def send_hello(self, challenge_response=None, client_salt=None):
        """ Send the "hello" packet and schedule the connection timeout check.

            With a password file and no challenge response yet, only the base
            capabilities are sent together with a request for a challenge.
            Otherwise the full capabilities are sent,
            with the challenge response (and client salt) if we have one.
        """
        hello = self.make_hello_base()
        if challenge_response:
            hello.update(self.make_hello())
            assert self.password_file
            hello["challenge_response"] = challenge_response
            if client_salt:
                hello["challenge_client_salt"] = client_salt
        elif self.password_file:
            #avoid sending the full hello: tell the server we want
            #a packet challenge first
            hello["challenge"] = True
        else:
            hello.update(self.make_hello())
        hex_response = binascii.hexlify(challenge_response or "")
        log.debug("send_hello(%s) packet=%s", hex_response, hello)
        self.send("hello", hello)
        self.timeout_add(DEFAULT_TIMEOUT, self.verify_connected)

    def verify_connected(self):
        if self.server_capabilities is None:
            #server has not said hello yet
            self.warn_and_quit(EXIT_TIMEOUT, "connection timed out")

    def make_hello_base(self):
        """ Build the capabilities which are always sent in the hello packet.

            When encryption is enabled, the cipher attributes are added
            and the protocol's incoming cipher is configured with them.
            Returns the capabilities dictionary.
        """
        capabilities = get_network_caps()
        capabilities.update({
            "encoding.generic": True,
            "namespace": True,
            "hostname": socket.gethostname(),
            "uuid": self.uuid,
            "username": self.username,
            "name": get_name(),
            "client_type": self.client_type(),
            "python.version": sys.version_info[:3],
            "compression_level": self.compression_level,
        })
        if self.display:
            capabilities["display"] = self.display
        capabilities.update(get_platform_info())
        add_version_info(capabilities)
        mid = get_machine_id()
        if mid:
            capabilities["machine_id"] = mid

        if self.encryption:
            assert self.encryption in ENCRYPTION_CIPHERS
            iv = get_hex_uuid()[:16]
            key_salt = get_hex_uuid() + get_hex_uuid()
            iterations = 1000
            capabilities.update({
                "cipher": self.encryption,
                "cipher.iv": iv,
                "cipher.key_salt": key_salt,
                "cipher.key_stretch_iterations": iterations,
            })
            key = self.get_encryption_key()
            #NOTE(review): get_encryption_key() in this class raises instead of
            #returning None, so this branch looks unreachable here. If it were
            #reached, the None returned would break send_hello() - confirm.
            if key is None:
                self.warn_and_quit(EXIT_ENCRYPTION,
                                   "encryption key is missing")
                return
            self._protocol.set_cipher_in(self.encryption, iv, key, key_salt,
                                         iterations)
            log("encryption capabilities: %s",
                [(k, v)
                 for k, v in capabilities.items() if k.startswith("cipher")])
        return capabilities

    def make_hello(self):
        capabilities = {
            "randr_notify": False,  #only client.py cares about this
            "windows": False,  #only client.py cares about this
        }
        if self._reverse_aliases:
            capabilities["aliases"] = self._reverse_aliases
        return capabilities

    def send(self, *parts):
        """ Add the packet (the tuple of parts) to the ordinary packets queue. """
        self._ordinary_packets.append(parts)
        self.have_more()

    def send_now(self, *parts):
        """ Add the packet (the tuple of parts) to the priority queue,
            which next_packet() empties before the ordinary queue.
        """
        self._priority_packets.append(parts)
        self.have_more()

    def send_positional(self, packet):
        """ Add the packet to the ordinary packets queue
            and discard any mouse position packet still waiting to be sent.
        """
        self._ordinary_packets.append(packet)
        self._mouse_position = None
        self.have_more()

    def send_mouse_position(self, packet):
        """ Record the mouse position packet to send,
            replacing any previous one which has not been sent yet.
        """
        self._mouse_position = packet
        self.have_more()

    def have_more(self):
        #this function is overridden in setup_protocol()
        p = self._protocol
        if p and p.source:
            p.source_has_more()

    def next_packet(self):
        if self._priority_packets:
            packet = self._priority_packets.pop(0)
        elif self._ordinary_packets:
            packet = self._ordinary_packets.pop(0)
        elif self._mouse_position is not None:
            packet = self._mouse_position
            self._mouse_position = None
        else:
            packet = None
        has_more = packet is not None and \
                (bool(self._priority_packets) or bool(self._ordinary_packets) \
                 or self._mouse_position is not None)
        return packet, None, None, has_more

    def cleanup(self):
        """ Close the protocol, if we have one, and forget about it. """
        log("XpraClientBase.cleanup() protocol=%s", self._protocol)
        if self._protocol:
            self._protocol.close()
            self._protocol = None

    def glib_init(self):
        """ Call glib.threads_init() if glib can be imported and provides it,
            otherwise do nothing.
        """
        try:
            glib = import_glib()
            try:
                glib.threads_init()
            except AttributeError:
                #old versions of glib may not have this method
                pass
        except ImportError:
            #an ImportError raised anywhere in the block above is ignored
            pass

    def run(self):
        """ Start the protocol (which must have been set, see setup_connection). """
        self._protocol.start()

    def quit(self, exit_code=0):
        """ Placeholder which must be overridden: this version always raises. """
        raise Exception("override me!")

    def warn_and_quit(self, exit_code, warning):
        """ Log the warning message, then call quit() with the exit code given. """
        log.warn(warning)
        self.quit(exit_code)

    def _process_disconnect(self, packet):
        """ Handle a "disconnect" packet: warn and quit.
            The exit code is EXIT_OK, unless we had not received
            any server capabilities yet (EXIT_FAILURE).
        """
        #a single argument is shown as-is, otherwise show all of them:
        if len(packet) == 2:
            info = packet[1]
        else:
            info = packet[1:]
        caps = self.server_capabilities
        if caps is None or len(caps) == 0:
            #server never sent hello to us - so disconnect is an error
            #(but we don't know which one - the info message may help)
            exit_code = EXIT_FAILURE
        else:
            exit_code = EXIT_OK
        self.warn_and_quit(exit_code, "server requested disconnect: %s" % info)

    def _process_connection_lost(self, packet):
        """ Handler for Protocol.CONNECTION_LOST: warn and quit with EXIT_CONNECTION_LOST. """
        self.warn_and_quit(EXIT_CONNECTION_LOST, "Connection lost")

    def _process_challenge(self, packet):
        """ Handle a "challenge" packet from the server.

            The packet contains the salt (packet[1]) and optionally
            the server cipher attributes (packet[2]) and the digest to use (packet[3]).
            The password is loaded from the password file, combined with the salt
            using the digest ("hmac" or "xor") and sent back in a new hello packet.
            Any failure ends with a call to warn_and_quit().
        """
        log("processing challenge: %s", packet[1:])
        if not self.password_file:
            self.warn_and_quit(
                EXIT_PASSWORD_REQUIRED,
                "server requires authentication, please provide a password")
            return
        password = self.load_password()
        if not password:
            self.warn_and_quit(
                EXIT_PASSWORD_FILE_ERROR,
                "failed to load password from file %s" % self.password_file)
            return
        salt = packet[1]
        if self.encryption:
            #the packet comes from the network: validate it explicitly
            #(an assert would be skipped when running with -O)
            if len(packet) < 3:
                self.warn_and_quit(
                    EXIT_ENCRYPTION,
                    "challenge does not contain encryption details to use for the response")
                return
            server_cipher = packet[2]
            key = self.get_encryption_key()
            if key is None:
                self.warn_and_quit(EXIT_ENCRYPTION,
                                   "encryption key is missing")
                return
            if not self.set_server_encryption(server_cipher, key):
                return
        digest = "hmac"
        client_can_salt = len(packet) >= 4
        client_salt = None
        if client_can_salt:
            #server supports client salt, and tells us which digest to use:
            digest = packet[3]
            client_salt = get_hex_uuid() + get_hex_uuid()
            #TODO: use some key stretching algorithm? (meh)
            salt = xor(salt, client_salt)
        if digest == "hmac":
            import hmac
            challenge_response = hmac.HMAC(password, salt).hexdigest()
        elif digest == "xor":
            #don't send XORed password unencrypted:
            if not self._protocol.cipher_out and not ALLOW_UNENCRYPTED_PASSWORDS:
                self.warn_and_quit(
                    EXIT_ENCRYPTION,
                    "server requested digest %s, cowardly refusing to use it without encryption"
                    % digest)
                return
            challenge_response = xor(password, salt)
        else:
            self.warn_and_quit(
                EXIT_PASSWORD_REQUIRED,
                "server requested an unsupported digest: %s" % digest)
            return
        #never write the password itself to the log, mask it like load_password() does:
        log("%s(%s, %s)=%s", digest, "*" * len(password), salt,
            challenge_response)
        self.password_sent = True
        self.send_hello(challenge_response, client_salt)

    def set_server_encryption(self, capabilities, key):
        """ Configure the protocol's outgoing cipher from the server's capabilities.

            Quits (via warn_and_quit) and returns False if the cipher or its iv
            is missing, or if the cipher is not one of ENCRYPTION_CIPHERS.
            Returns True once set_cipher_out() has been called.
        """
        def get(key, default=None):
            #the capabilities are looked up using strtobytes(key)
            return capabilities.get(strtobytes(key), default)

        cipher = get("cipher")
        cipher_iv = get("cipher.iv")
        key_salt = get("cipher.key_salt")
        iterations = get("cipher.key_stretch_iterations")
        if not cipher or not cipher_iv:
            self.warn_and_quit(
                EXIT_ENCRYPTION,
                "the server does not use or support encryption/password, cannot continue with %s cipher"
                % self.encryption)
            return False
        if cipher not in ENCRYPTION_CIPHERS:
            self.warn_and_quit(
                EXIT_ENCRYPTION,
                "unsupported server cipher: %s, allowed ciphers: %s" %
                (cipher, ", ".join(ENCRYPTION_CIPHERS)))
            return False
        self._protocol.set_cipher_out(cipher, cipher_iv, key, key_salt,
                                      iterations)
        return True

    def get_encryption_key(self):
        """ Load the encryption key from the keyfile, using the password file
            as fallback, and return it without leading or trailing newline characters.
            Raises an Exception if no key could be loaded.
        """
        key = load_binary_file(self.encryption_keyfile)
        if key is None and self.password_file:
            key = load_binary_file(self.password_file)
            if key:
                log("used password file as encryption key")
        #NOTE(review): the callers in this class test for a None return value,
        #which cannot happen since we raise here - confirm which one is intended
        if key is None:
            raise Exception("failed to load encryption keyfile %s" %
                            self.encryption_keyfile)
        return key.strip("\n\r")

    def load_password(self):
        """ Return the contents of the password file (with "~" expanded),
            without leading or trailing newline characters.
            Returns None if the file could not be loaded.
        """
        password = load_binary_file(os.path.expanduser(self.password_file))
        if password is None:
            return None
        password = password.strip("\n\r")
        #only log one star per character, never the password itself:
        masked = "*" * len(password)
        log("password read from file %s is %s", self.password_file, masked)
        return password

    def _process_hello(self, packet):
        """ Handle the "hello" packet from the server.

            Quits if we have a password file but the server never asked for
            the password. Otherwise the capabilities (packet[1]) are stored
            and given to parse_server_capabilities() wrapped in a typedict;
            any error there makes us quit with EXIT_FAILURE.
        """
        if not self.password_sent and self.password_file:
            self.warn_and_quit(EXIT_NO_AUTHENTICATION,
                               "the server did not request our password")
            return
        try:
            self.server_capabilities = packet[1]
            log("processing hello from server: %s", self.server_capabilities)
            c = typedict(self.server_capabilities)
            self.parse_server_capabilities(c)
        except Exception as e:
            #"as" is valid syntax with python 2.6 onwards and python 3
            self.warn_and_quit(
                EXIT_FAILURE,
                "error processing hello packet from server: %s" % e)
コード例 #13
0
ファイル: client_base.py プロジェクト: rudresh2319/Xpra
class XpraClientBase(object):
    """ Base class for Xpra clients.
        Provides the glue code for:
        * sending packets via Protocol
        * handling packets received via _process_packet
        For an actual implementation, look at:
        * GObjectXpraClient
        * xpra.client.gtk2.client
        * xpra.client.gtk3.client
    """
    def __init__(self):
        """ Set every attribute to its default value,
            then generate the uuid and register the packet handlers.
        """
        self.exit_code = None
        self.compression_level = 0
        self.password = None
        self.password_file = None
        self.password_sent = False
        self.encryption = None
        self.quality = -1
        self.min_quality = 0
        self.speed = 0
        self.min_speed = -1
        #protocol stuff:
        self._protocol = None
        #outgoing packets, see next_packet() for the order in which they are sent:
        self._priority_packets = []
        self._ordinary_packets = []
        self._mouse_position = None
        #populated by init_aliases():
        self._aliases = {}
        self._reverse_aliases = {}
        #server state and caps:
        self.server_capabilities = None
        self._remote_version = None
        self._remote_revision = None
        self._remote_platform = None
        self._remote_platform_release = None
        self._remote_platform_platform = None
        self._remote_platform_linux_distribution = None
        self.make_uuid()
        self.init_packet_handlers()

    def init(self, opts):
        self.compression_level = opts.compression_level
        self.password_file = opts.password_file
        self.encryption = opts.encryption
        self.quality = opts.quality
        self.min_quality = opts.min_quality
        self.speed = opts.speed
        self.min_speed = opts.min_speed

    def timeout_add(self, *args):
        """ Placeholder which must be overridden: this version always raises.
            This class calls it as timeout_add(delay, callback, *callback_args).
        """
        raise Exception("override me!")

    def idle_add(self, *args):
        """ Placeholder which must be overridden: this version always raises. """
        raise Exception("override me!")

    def source_remove(self, *args):
        """ Placeholder which must be overridden: this version always raises. """
        raise Exception("override me!")

    def install_signal_handlers(self):
        """ Install handlers for SIGINT and SIGTERM.

            The first signal schedules self.quit(128+signum) via timeout_add
            and swaps the handlers, so that a second signal
            runs cleanup() and exits the process immediately.
        """
        def deadly_signal(signum, frame):
            sys.stderr.write("\ngot deadly signal %s, exiting\n" %
                             SIGNAMES.get(signum, signum))
            sys.stderr.flush()
            self.cleanup()
            #exit immediately with the exit code 128+signum:
            os._exit(128 + signum)

        def app_signal(signum, frame):
            sys.stderr.write("\ngot signal %s, exiting\n" %
                             SIGNAMES.get(signum, signum))
            sys.stderr.flush()
            #any further signal will be handled by deadly_signal:
            signal.signal(signal.SIGINT, deadly_signal)
            signal.signal(signal.SIGTERM, deadly_signal)
            self.timeout_add(0, self.quit, 128 + signum)

        signal.signal(signal.SIGINT, app_signal)
        signal.signal(signal.SIGTERM, app_signal)

    def client_type(self):
        """ Return the client type string, sent as "client_type" in the hello packet. """
        #overridden in subclasses!
        return "Python"

    def setup_connection(self, conn):
        """ Create and configure the Protocol instance for the connection given.
            Incoming packets go to self.process_packet,
            outgoing packets are obtained from self.next_packet.
        """
        log.debug("setup_connection(%s)", conn)
        self._protocol = Protocol(conn, self.process_packet, self.next_packet)
        self._protocol.large_packets.append("keymap-changed")
        self._protocol.large_packets.append("server-settings")
        self._protocol.set_compression_level(self.compression_level)
        #from now on, have_more() calls the protocol directly:
        self.have_more = self._protocol.source_has_more

    def init_packet_handlers(self):
        """ Populate the two dictionaries mapping packet types to their handler methods. """
        self._packet_handlers = {
            "hello": self._process_hello,
        }
        self._ui_packet_handlers = {
            "challenge": self._process_challenge,
            "disconnect": self._process_disconnect,
            "set_deflate": self._process_set_deflate,
            Protocol.CONNECTION_LOST: self._process_connection_lost,
            Protocol.GIBBERISH: self._process_gibberish,
        }

    def init_aliases(self):
        packet_types = list(self._packet_handlers.keys())
        packet_types += list(self._ui_packet_handlers.keys())
        i = 1
        for key in packet_types:
            self._aliases[i] = key
            self._reverse_aliases[key] = i
            i += 1

    def send_hello(self, challenge_response=None):
        """ Send the "hello" packet built by make_hello(),
            then schedule verify_connected() to run after DEFAULT_TIMEOUT.
        """
        hello = self.make_hello(challenge_response)
        log.debug("send_hello(%s) packet=%s", challenge_response, hello)
        self.send("hello", hello)
        self.timeout_add(DEFAULT_TIMEOUT, self.verify_connected)

    def verify_connected(self):
        if self.server_capabilities is None:
            #server has not said hello yet
            self.warn_and_quit(EXIT_TIMEOUT, "connection timed out")

    def make_hello(self, challenge_response=None):
        """ Build and return the capabilities dictionary sent in the hello packet.

            The challenge response is included if one is given.
            When encryption is enabled, the cipher attributes are added
            and the protocol's incoming cipher is configured with them.
        """
        capabilities = {}
        add_version_info(capabilities)
        capabilities["python.version"] = sys.version_info[:3]
        if challenge_response:
            assert self.password
            capabilities["challenge_response"] = challenge_response
        if self.encryption:
            assert self.encryption in ENCRYPTION_CIPHERS
            capabilities["cipher"] = self.encryption
            iv = get_hex_uuid()[:16]
            capabilities["cipher.iv"] = iv
            key_salt = get_hex_uuid()
            capabilities["cipher.key_salt"] = key_salt
            iterations = 1000
            capabilities["cipher.key_stretch_iterations"] = iterations
            self._protocol.set_cipher_in(self.encryption, iv,
                                         self.get_password(), key_salt,
                                         iterations)
            log("encryption capabilities: %s",
                [(k, v)
                 for k, v in capabilities.items() if k.startswith("cipher")])
        capabilities["platform"] = sys.platform
        capabilities["platform.release"] = python_platform.release()
        capabilities["platform.machine"] = python_platform.machine()
        capabilities["platform.processor"] = python_platform.processor()
        capabilities["client_type"] = self.client_type()
        capabilities["namespace"] = True
        capabilities["raw_packets"] = True
        capabilities["chunked_compression"] = True
        capabilities["bencode"] = True
        capabilities["rencode"] = has_rencode
        if has_rencode:
            capabilities["rencode.version"] = rencode_version
        capabilities["hostname"] = socket.gethostname()
        capabilities["uuid"] = self.uuid
        try:
            from xpra.platform.info import get_username, get_name
            capabilities["username"] = get_username()
            capabilities["name"] = get_name()
        except Exception:
            #the username and name are optional: log the problem and carry on,
            #without also trapping interpreter exit requests
            log.error("failed to get username/name", exc_info=True)
        capabilities["randr_notify"] = False  #only client.py cares about this
        capabilities["windows"] = False  #only client.py cares about this
        if self._reverse_aliases:
            capabilities["aliases"] = self._reverse_aliases
        return capabilities

    def make_uuid(self):
        """ Set self.uuid to the hex SHA1 digest of the machine id,
            followed on posix by "/" + uid + "/" + gid.
        """
        try:
            import hashlib
            u = hashlib.sha1()
        except ImportError:
            #try python2.4 variant:
            import sha
            u = sha.new()

        def uupdate(ustr):
            u.update(ustr.encode("utf-8"))

        #the machine id may be unavailable: hash an empty string rather than fail
        uupdate(get_machine_id() or u"")
        if os.name == "posix":
            uupdate(u"/")
            uupdate(str(os.getuid()))
            uupdate(u"/")
            uupdate(str(os.getgid()))
        self.uuid = u.hexdigest()

    def send(self, *parts):
        """ Add the packet (the tuple of parts) to the ordinary packets queue. """
        self._ordinary_packets.append(parts)
        self.have_more()

    def send_now(self, *parts):
        """ Add the packet (the tuple of parts) to the priority queue,
            which next_packet() empties before the ordinary queue.
        """
        self._priority_packets.append(parts)
        self.have_more()

    def send_positional(self, packet):
        """ Add the packet to the ordinary packets queue
            and discard any mouse position packet still waiting to be sent.
        """
        self._ordinary_packets.append(packet)
        self._mouse_position = None
        self.have_more()

    def send_mouse_position(self, packet):
        """ Record the mouse position packet to send,
            replacing any previous one which has not been sent yet.
        """
        self._mouse_position = packet
        self.have_more()

    def have_more(self):
        #this function is overridden in setup_protocol()
        p = self._protocol
        if p and p.source:
            p.source_has_more()

    def next_packet(self):
        """Hand out the next packet to send.

        Priority packets go first, then ordinary packets, then the pending
        mouse position update (which is cleared once taken).
        Returns the tuple (packet, None, None, has_more): packet is None
        when nothing is waiting, has_more tells if more data is queued.
        """
        priority = self._priority_packets
        ordinary = self._ordinary_packets
        if priority:
            packet = priority.pop(0)
        elif ordinary:
            packet = ordinary.pop(0)
        else:
            packet = self._mouse_position
            if packet is not None:
                self._mouse_position = None
        if packet is None:
            has_more = False
        else:
            has_more = bool(priority) or bool(ordinary) \
                    or self._mouse_position is not None
        return packet, None, None, has_more

    def cleanup(self):
        """Close the protocol, if there is one, and forget about it."""
        proto = self._protocol
        log("XpraClientBase.cleanup() protocol=%s", proto)
        if proto:
            proto.close()
            self._protocol = None

    def glib_init(self):
        """Call threads_init() on the module returned by import_glib().

        An ImportError raised anywhere in the process is ignored, and so is
        an AttributeError raised by the threads_init() call.
        """
        try:
            glib = import_glib()
            try:
                glib.threads_init()
            except AttributeError:
                #old versions of glib may not have this method
                pass
        except ImportError:
            pass

    def run(self):
        """Start the protocol (self._protocol must have been set already)."""
        proto = self._protocol
        proto.start()

    def quit(self, exit_code=0):
        """Terminate the client with `exit_code`: subclasses must override.

        Raises NotImplementedError (a subclass of Exception, so existing
        handlers catching Exception still work) when not overridden.
        """
        raise NotImplementedError("override me!")

    def warn_and_quit(self, exit_code, warning):
        """Log `warning` at warning level, then call quit(exit_code)."""
        log.warn(warning)
        self.quit(exit_code)

    def _process_disconnect(self, packet):
        """Handle the server's "disconnect" packet: warn and quit.

        The message shows packet[1] when the packet has exactly two items,
        all the items after the packet type otherwise.
        The exit code is EXIT_FAILURE if the server capabilities were never
        received (or are empty), EXIT_OK otherwise.
        """
        if len(packet) == 2:
            info = packet[1]
        else:
            info = packet[1:]
        e = EXIT_OK
        if self.server_capabilities is None or len(
                self.server_capabilities) == 0:
            #server never sent hello to us - so disconnect is an error
            #(but we don't know which one - the info message may help)
            e = EXIT_FAILURE
        #wrap info in a tuple: if info is itself a tuple (ie: the slice
        #of a tuple packet), "%s" % info would raise a TypeError
        self.warn_and_quit(e, "server requested disconnect: %s" % (info, ))

    def _process_connection_lost(self, packet):
        """Warn and quit with EXIT_CONNECTION_LOST, `packet` is not used."""
        self.warn_and_quit(EXIT_CONNECTION_LOST, "Connection lost")

    def _process_challenge(self, packet):
        """Answer the server's authentication challenge with a new hello.

        packet[1] is the salt. When encryption is enabled, packet[2] must
        hold the server's cipher settings, which are applied first.
        The response sent is the hex digest of the HMAC of the salt, keyed
        with the password (loaded from the password file if needed).
        Quits instead if no password is available.
        """
        if not self.password_file and not self.password:
            self.warn_and_quit(EXIT_PASSWORD_REQUIRED,
                               "password is required by the server")
            return
        if not self.password:
            if not self.load_password():
                return
            assert self.password
        salt = packet[1]
        if self.encryption:
            assert len(
                packet
            ) >= 3, "challenge does not contain encryption details to use for the response"
            server_cipher = packet[2]
            #set_server_encryption() returns False after it has reported
            #a problem and requested to quit: do not carry on and send
            #the password hash in that case.
            #(compare with "is False": success may be reported as None)
            if self.set_server_encryption(server_cipher) is False:
                return
        import hmac
        challenge_response = hmac.HMAC(self.password, salt)
        password_hash = challenge_response.hexdigest()
        self.password_sent = True
        self.send_hello(password_hash)

    def set_server_encryption(self, capabilities):
        """Configure the outgoing cipher from the server's cipher settings.

        capabilities: dictionary holding the server's "cipher",
            "cipher.iv", "cipher.key_salt" and
            "cipher.key_stretch_iterations" values,
            looked up using the strtobytes() form of each key.
        Returns True once the protocol's output cipher has been set.
        Returns False, after calling warn_and_quit(), when the cipher or
        its iv is missing or when the cipher is not in ENCRYPTION_CIPHERS.
        """
        def get(key, default=None):
            return capabilities.get(strtobytes(key), default)

        cipher = get("cipher")
        cipher_iv = get("cipher.iv")
        key_salt = get("cipher.key_salt")
        iterations = get("cipher.key_stretch_iterations")
        if not cipher or not cipher_iv:
            self.warn_and_quit(
                EXIT_ENCRYPTION,
                "the server does not use or support encryption/password, cannot continue with %s cipher"
                % self.encryption)
            return False
        if cipher not in ENCRYPTION_CIPHERS:
            self.warn_and_quit(
                EXIT_ENCRYPTION,
                "unsupported server cipher: %s, allowed ciphers: %s" %
                (cipher, ", ".join(ENCRYPTION_CIPHERS)))
            return False
        self._protocol.set_cipher_out(cipher, cipher_iv, self.get_password(),
                                      key_salt, iterations)
        #explicit success value, so that callers can test the result:
        return True

    def get_password(self):
        """Return the password, trying to load it first if it is still
        None (the value returned may still be None if loading failed)."""
        if self.password is not None:
            return self.password
        self.load_password()
        return self.password

    def load_password(self):
        """Read the password from self.password_file into self.password.

        The "~" user prefix in the filename is expanded and the trailing
        newline / carriage return characters are removed.
        Returns True on success. On IOError, calls warn_and_quit() with
        EXIT_PASSWORD_FILE_ERROR and returns False.
        """
        import sys
        try:
            filename = os.path.expanduser(self.password_file)
            passwordFile = open(filename, "rU")
            try:
                password = passwordFile.read()
            finally:
                #always release the file handle, even if read() fails
                passwordFile.close()
        except IOError:
            #sys.exc_info() rather than "except IOError, e" (python2 only)
            #or "except IOError as e" (python2.6 onwards only)
            e = sys.exc_info()[1]
            self.warn_and_quit(
                EXIT_PASSWORD_FILE_ERROR,
                "failed to open password file %s: %s" %
                (self.password_file, e))
            return False
        self.password = password.rstrip("\n\r")
        #never write the password itself to the log:
        log("password read from file %s", self.password_file)
        return True
コード例 #14
0
ファイル: client_base.py プロジェクト: svn2github/Xpra
class XpraClientBase(object):
    """ Base class for Xpra clients.
        Provides the glue code for:
        * sending packets via Protocol
        * handling packets received via _process_packet
        For an actual implementation, look at:
        * GObjectXpraClient
        * xpra.client.gtk2.client
        * xpra.client.gtk3.client
    """

    def __init__(self):
        """Set the default option values, create the empty packet queues,
        compute the uuid and register the packet handlers."""
        self.exit_code = None
        self.compression_level = 0
        self.password = None
        self.password_file = None
        self.password_sent = False
        self.encryption = None
        self.quality = -1
        self.min_quality = 0
        self.speed = 0
        self.min_speed = -1
        #protocol stuff:
        self._protocol = None
        self._priority_packets = []
        self._ordinary_packets = []
        self._mouse_position = None
        self._aliases = {}
        self._reverse_aliases = {}
        #server state and caps:
        self.server_capabilities = None
        self._remote_version = None
        self._remote_revision = None
        self._remote_platform = None
        self._remote_platform_release = None
        self._remote_platform_platform = None
        self._remote_platform_linux_distribution = None
        self.make_uuid()
        self.init_packet_handlers()

    def init(self, opts):
        """Copy the settings used by this class from the options object."""
        self.compression_level = opts.compression_level
        self.password_file = opts.password_file
        self.encryption = opts.encryption
        self.quality = opts.quality
        self.min_quality = opts.min_quality
        self.speed = opts.speed
        self.min_speed = opts.min_speed

    def timeout_add(self, *args):
        """Main loop timer hook: subclasses must override."""
        raise NotImplementedError("override me!")

    def idle_add(self, *args):
        """Main loop idle hook: subclasses must override."""
        raise NotImplementedError("override me!")

    def source_remove(self, *args):
        """Main loop source removal hook: subclasses must override."""
        raise NotImplementedError("override me!")


    def install_signal_handlers(self):
        """Install the SIGINT and SIGTERM handlers.

        The first signal schedules quit(128 + signum) via timeout_add()
        and swaps in handlers which, on a second signal, call cleanup()
        and exit the process immediately.
        """
        def deadly_signal(signum, frame):
            sys.stderr.write("\ngot deadly signal %s, exiting\n" % SIGNAMES.get(signum, signum))
            sys.stderr.flush()
            self.cleanup()
            os._exit(128 + signum)
        def app_signal(signum, frame):
            sys.stderr.write("\ngot signal %s, exiting\n" % SIGNAMES.get(signum, signum))
            sys.stderr.flush()
            signal.signal(signal.SIGINT, deadly_signal)
            signal.signal(signal.SIGTERM, deadly_signal)
            self.timeout_add(0, self.quit, 128 + signum)
        signal.signal(signal.SIGINT, app_signal)
        signal.signal(signal.SIGTERM, app_signal)


    def client_type(self):
        """Return the client type string sent in the hello packet."""
        #overridden in subclasses!
        return "Python"

    def setup_connection(self, conn):
        """Create the Protocol for `conn` and configure it.

        Also replaces have_more() on this instance with the protocol's
        source_has_more() method.
        """
        log.debug("setup_connection(%s)", conn)
        self._protocol = Protocol(conn, self.process_packet, self.next_packet)
        self._protocol.large_packets.append("keymap-changed")
        self._protocol.large_packets.append("server-settings")
        self._protocol.set_compression_level(self.compression_level)
        self.have_more = self._protocol.source_has_more

    def init_packet_handlers(self):
        """Build the two dictionaries mapping packet types to handlers."""
        self._packet_handlers = {
            "hello": self._process_hello,
            }
        self._ui_packet_handlers = {
            "challenge": self._process_challenge,
            "disconnect": self._process_disconnect,
            "set_deflate": self._process_set_deflate,
            Protocol.CONNECTION_LOST: self._process_connection_lost,
            Protocol.GIBBERISH: self._process_gibberish,
            }

    def init_aliases(self):
        """Assign a numeric alias, counting from 1, to every packet type
        found in the handler dictionaries (recorded in both directions)."""
        packet_types = list(self._packet_handlers.keys())
        packet_types += list(self._ui_packet_handlers.keys())
        i = 1
        for key in packet_types:
            self._aliases[i] = key
            self._reverse_aliases[key] = i
            i += 1

    def send_hello(self, challenge_response=None):
        """Send the hello packet and schedule the connection check."""
        hello = self.make_hello(challenge_response)
        log.debug("send_hello(%s) packet=%s", challenge_response, hello)
        self.send("hello", hello)
        self.timeout_add(DEFAULT_TIMEOUT, self.verify_connected)

    def verify_connected(self):
        """Quit with EXIT_TIMEOUT if the server capabilities are still
        missing."""
        if self.server_capabilities is None:
            #server has not said hello yet
            self.warn_and_quit(EXIT_TIMEOUT, "connection timed out")


    def make_hello(self, challenge_response=None):
        """Build and return the capabilities dictionary for the hello.

        When encryption is enabled, this also generates the cipher
        parameters and configures the protocol's input cipher with them.
        """
        capabilities = {}
        add_version_info(capabilities)
        capabilities["python.version"] = sys.version_info[:3]
        if challenge_response:
            assert self.password
            capabilities["challenge_response"] = challenge_response
        if self.encryption:
            assert self.encryption in ENCRYPTION_CIPHERS
            capabilities["cipher"] = self.encryption
            iv = get_hex_uuid()[:16]
            capabilities["cipher.iv"] = iv
            key_salt = get_hex_uuid()
            capabilities["cipher.key_salt"] = key_salt
            iterations = 1000
            capabilities["cipher.key_stretch_iterations"] = iterations
            self._protocol.set_cipher_in(self.encryption, iv, self.get_password(), key_salt, iterations)
            log("encryption capabilities: %s", [(k,v) for k,v in capabilities.items() if k.startswith("cipher")])
        capabilities["platform"] = sys.platform
        capabilities["platform.release"] = python_platform.release()
        capabilities["platform.machine"] = python_platform.machine()
        capabilities["platform.processor"] = python_platform.processor()
        capabilities["client_type"] = self.client_type()
        capabilities["namespace"] = True
        capabilities["raw_packets"] = True
        capabilities["chunked_compression"] = True
        capabilities["bencode"] = True
        capabilities["rencode"] = has_rencode
        if has_rencode:
            capabilities["rencode.version"] = rencode_version
        capabilities["hostname"] = socket.gethostname()
        capabilities["uuid"] = self.uuid
        try:
            from xpra.platform.info import get_username, get_name
            capabilities["username"] = get_username()
            capabilities["name"] = get_name()
        except Exception:
            #best effort only, but do not swallow KeyboardInterrupt
            #or SystemExit as the bare except used to
            log.error("failed to get username/name", exc_info=True)
        capabilities["randr_notify"] = False    #only client.py cares about this
        capabilities["windows"] = False         #only client.py cares about this
        if self._reverse_aliases:
            capabilities["aliases"] = self._reverse_aliases
        return capabilities

    def make_uuid(self):
        """Compute this client's identifier and store it in self.uuid.

        The identifier is the SHA-1 hex digest of the machine id followed,
        on posix systems only, by "/<uid>/<gid>" of the current process.
        """
        try:
            import hashlib
            u = hashlib.sha1()
        except ImportError:
            #only a missing hashlib module should trigger the fallback.
            #try python2.4 variant:
            import sha
            u = sha.new()
        def uupdate(ustr):
            u.update(ustr.encode("utf-8"))
        uupdate(get_machine_id())
        if os.name=="posix":
            uupdate(u"/")
            uupdate(str(os.getuid()))
            uupdate(u"/")
            uupdate(str(os.getgid()))
        self.uuid = u.hexdigest()

    def send(self, *parts):
        """Queue an ordinary packet and signal that there is data to send."""
        self._ordinary_packets.append(parts)
        self.have_more()

    def send_now(self, *parts):
        """Queue a priority packet and signal that there is data to send."""
        self._priority_packets.append(parts)
        self.have_more()

    def send_positional(self, packet):
        """Queue an ordinary packet and discard the pending mouse position
        update, if there was one."""
        self._ordinary_packets.append(packet)
        self._mouse_position = None
        self.have_more()

    def send_mouse_position(self, packet):
        """Record the pending mouse position update, replacing any
        previous one."""
        self._mouse_position = packet
        self.have_more()

    def have_more(self):
        """Tell the protocol that packets are waiting, if it has a source."""
        #this method is replaced on the instance by setup_connection()
        p = self._protocol
        if p and p.source:
            p.source_has_more()

    def next_packet(self):
        """Hand out the next packet to send.

        Priority packets go first, then ordinary packets, then the pending
        mouse position update.
        Returns the tuple (packet, None, None, has_more).
        """
        if self._priority_packets:
            packet = self._priority_packets.pop(0)
        elif self._ordinary_packets:
            packet = self._ordinary_packets.pop(0)
        elif self._mouse_position is not None:
            packet = self._mouse_position
            self._mouse_position = None
        else:
            packet = None
        has_more = packet is not None and \
                (bool(self._priority_packets) or bool(self._ordinary_packets) \
                 or self._mouse_position is not None)
        return packet, None, None, has_more


    def cleanup(self):
        """Close the protocol, if there is one, and forget about it."""
        log("XpraClientBase.cleanup() protocol=%s", self._protocol)
        if self._protocol:
            self._protocol.close()
            self._protocol = None

    def glib_init(self):
        """Call threads_init() on the module returned by import_glib(),
        ignoring ImportError and a missing threads_init."""
        try:
            glib = import_glib()
            try:
                glib.threads_init()
            except AttributeError:
                #old versions of glib may not have this method
                pass
        except ImportError:
            pass

    def run(self):
        """Start the protocol."""
        self._protocol.start()

    def quit(self, exit_code=0):
        """Terminate the client with `exit_code`: subclasses must override."""
        raise NotImplementedError("override me!")

    def warn_and_quit(self, exit_code, warning):
        """Log `warning` at warning level, then call quit(exit_code)."""
        log.warn(warning)
        self.quit(exit_code)

    def _process_disconnect(self, packet):
        """Handle the server's "disconnect" packet: warn and quit.

        The exit code is EXIT_FAILURE if the server capabilities were never
        received (or are empty), EXIT_OK otherwise.
        """
        if len(packet)==2:
            info = packet[1]
        else:
            info = packet[1:]
        e = EXIT_OK
        if self.server_capabilities is None or len(self.server_capabilities)==0:
            #server never sent hello to us - so disconnect is an error
            #(but we don't know which one - the info message may help)
            e = EXIT_FAILURE
        #wrap info in a tuple: if info is itself a tuple (ie: the slice
        #of a tuple packet), "%s" % info would raise a TypeError
        self.warn_and_quit(e, "server requested disconnect: %s" % (info, ))

    def _process_connection_lost(self, packet):
        """Warn and quit with EXIT_CONNECTION_LOST."""
        self.warn_and_quit(EXIT_CONNECTION_LOST, "Connection lost")

    def _process_challenge(self, packet):
        """Answer the server's authentication challenge with a new hello.

        packet[1] is the salt. When encryption is enabled, packet[2] must
        hold the server's cipher settings, which are applied first.
        The response sent is the hex digest of the HMAC of the salt, keyed
        with the password (loaded from the password file if needed).
        """
        if not self.password_file and not self.password:
            self.warn_and_quit(EXIT_PASSWORD_REQUIRED, "password is required by the server")
            return
        if not self.password:
            if not self.load_password():
                return
            assert self.password
        salt = packet[1]
        if self.encryption:
            assert len(packet)>=3, "challenge does not contain encryption details to use for the response"
            server_cipher = packet[2]
            if not self.set_server_encryption(server_cipher):
                #the failure has been reported and quit requested:
                #do not carry on and send the password hash
                return
        import hmac
        challenge_response = hmac.HMAC(self.password, salt)
        password_hash = challenge_response.hexdigest()
        self.password_sent = True
        self.send_hello(password_hash)

    def set_server_encryption(self, capabilities):
        """Configure the outgoing cipher from the server's cipher settings.

        Returns True once the protocol's output cipher has been set.
        Returns False, after calling warn_and_quit(), when the cipher or
        its iv is missing or when the cipher is not in ENCRYPTION_CIPHERS.
        """
        def get(key, default=None):
            return capabilities.get(strtobytes(key), default)
        cipher = get("cipher")
        cipher_iv = get("cipher.iv")
        key_salt = get("cipher.key_salt")
        iterations = get("cipher.key_stretch_iterations")
        if not cipher or not cipher_iv:
            self.warn_and_quit(EXIT_ENCRYPTION, "the server does not use or support encryption/password, cannot continue with %s cipher" % self.encryption)
            return False
        if cipher not in ENCRYPTION_CIPHERS:
            self.warn_and_quit(EXIT_ENCRYPTION, "unsupported server cipher: %s, allowed ciphers: %s" % (cipher, ", ".join(ENCRYPTION_CIPHERS)))
            return False
        self._protocol.set_cipher_out(cipher, cipher_iv, self.get_password(), key_salt, iterations)
        return True


    def get_password(self):
        """Return the password, trying to load it first if it is still
        None (the value returned may still be None if loading failed)."""
        if self.password is None:
            self.load_password()
        return self.password

    def load_password(self):
        """Read the password from self.password_file into self.password.

        The trailing newline / carriage return characters are removed.
        Returns True on success. On IOError, calls warn_and_quit() with
        EXIT_PASSWORD_FILE_ERROR and returns False.
        """
        try:
            filename = os.path.expanduser(self.password_file)
            passwordFile = open(filename, "rU")
            try:
                password = passwordFile.read()
            finally:
                #always release the file handle, even if read() fails
                passwordFile.close()
        except IOError:
            #sys.exc_info() rather than "except IOError, e" (python2 only)
            #or "except IOError as e" (python2.6 onwards only)
            e = sys.exc_info()[1]
            self.warn_and_quit(EXIT_PASSWORD_FILE_ERROR, "failed to open password file %s: %s" % (self.password_file, e))
            return False
        self.password = password.rstrip("\n\r")
        #never write the password itself to the log:
        log("password read from file %s", self.password_file)
        return True
コード例 #15
0
ファイル: client_base.py プロジェクト: svn2github/Xpra
class XpraClientBase(object):
    """ Base class for Xpra clients.
        Provides the glue code for:
        * sending packets via Protocol
        * handling packets received via _process_packet
        For an actual implementation, look at:
        * GObjectXpraClient
        * xpra.client.gtk2.client
        * xpra.client.gtk3.client
    """

    def __init__(self):
        #this may be called more than once,
        #skip doing internal init again:
        if not hasattr(self, "exit_code"):
            self.defaults_init()

    def defaults_init(self):
        """Set every attribute to its default value.

        Also obtains the uuid from get_user_uuid(), registers the packet
        handlers and runs sanity_checks().
        """
        #exit_code doubles as the "already initialized" marker tested
        #by __init__:
        self.exit_code = None
        self.compression_level = 0
        self.display = None
        self.username = None
        self.password_file = None
        self.password_sent = False
        self.encryption = None
        self.encryption_keyfile = None
        self.quality = -1
        self.min_quality = 0
        self.speed = 0
        self.min_speed = -1
        self.file_transfer = False
        self.file_size_limit = 10
        self.printing = False
        self.printer_attributes = []
        self.send_printers_pending = False
        self.exported_printers = None
        self.open_files = False
        self.open_command = None
        #protocol stuff:
        self._protocol = None
        self._priority_packets = []
        self._ordinary_packets = []
        self._mouse_position = None
        self._aliases = {}
        self._reverse_aliases = {}
        #server state and caps:
        self.server_capabilities = None
        self._remote_machine_id = None
        self._remote_uuid = None
        self._remote_version = None
        self._remote_revision = None
        self._remote_platform = None
        self._remote_platform_release = None
        self._remote_platform_platform = None
        self._remote_platform_linux_distribution = None
        self.uuid = get_user_uuid()
        self.init_packet_handlers()
        sanity_checks()

    def init(self, opts):
        """Copy the settings used by this class from the options object.

        When DETECT_LEAKS is set, the function returned by detect_leaks()
        is also scheduled using timeout_add() with a 10 second delay.
        """
        option_names = (
            "compression_level", "display", "username", "password_file",
            "encryption", "encryption_keyfile",
            "quality", "min_quality", "speed", "min_speed",
            "file_transfer", "file_size_limit", "printing",
            "open_command", "open_files",
            )
        for name in option_names:
            setattr(self, name, getattr(opts, name))

        if not DETECT_LEAKS:
            return
        from xpra.util import detect_leaks
        #the second argument is the list of "detailed" classes, left empty
        #(presumably classes to report on in more detail - TODO confirm)
        print_leaks = detect_leaks(log, [])
        self.timeout_add(10*1000, print_leaks)


    def timeout_add(self, *args):
        """Main loop timer hook: subclasses must override.

        Raises NotImplementedError (a subclass of Exception), like
        get_scheduler() does.
        """
        raise NotImplementedError("override me!")

    def idle_add(self, *args):
        """Main loop idle hook: subclasses must override.

        Raises NotImplementedError (a subclass of Exception), like
        get_scheduler() does.
        """
        raise NotImplementedError("override me!")

    def source_remove(self, *args):
        """Main loop source removal hook: subclasses must override.

        Raises NotImplementedError (a subclass of Exception), like
        get_scheduler() does.
        """
        raise NotImplementedError("override me!")


    def install_signal_handlers(self):
        """Install the SIGINT and SIGTERM handlers.

        The first signal schedules disconnect_and_quit() via timeout_add()
        and swaps in handlers which, on a second signal, call cleanup()
        and exit the process immediately.
        """
        def deadly_signal(signum, frame):
            #second signal: clean up and exit without going through the
            #main loop
            sys.stderr.write("\ngot deadly signal %s, exiting\n" % SIGNAMES.get(signum, signum))
            sys.stderr.flush()
            self.cleanup()
            os._exit(128 + signum)
        def app_signal(signum, frame):
            #first signal: request a disconnection with exit code
            #128 + signum, any further signal is handled by deadly_signal
            sys.stderr.write("\ngot signal %s, exiting\n" % SIGNAMES.get(signum, signum))
            sys.stderr.flush()
            signal.signal(signal.SIGINT, deadly_signal)
            signal.signal(signal.SIGTERM, deadly_signal)
            self.timeout_add(0, self.disconnect_and_quit, 128 + signum, "exit on signal %s" % SIGNAMES.get(signum, signum))
        signal.signal(signal.SIGINT, app_signal)
        signal.signal(signal.SIGTERM, app_signal)

    def signal_disconnect_and_quit(self, exit_code, reason):
        """Schedule, via idle_add() and in this order: disconnect_and_quit,
        quit and exit."""
        schedule = self.idle_add
        schedule(self.disconnect_and_quit, exit_code, reason)
        schedule(self.quit, exit_code)
        schedule(self.exit)

    def disconnect_and_quit(self, exit_code, reason):
        """Send a "disconnect" packet with `reason` to the server, then quit.

        If there is no protocol, or it is already closed, quit() is called
        directly. Otherwise quit() is scheduled from the protocol's
        "closed" callback, and also after one second via timeout_add().
        """
        #make sure that we set the exit code early,
        #so the protocol shutdown won't set a different one:
        if self.exit_code is None:
            self.exit_code = exit_code
        #try to tell the server we're going, then quit
        log("disconnect_and_quit(%s, %s)", exit_code, reason)
        p = self._protocol
        if p is None or p._closed:
            self.quit(exit_code)
            return
        def protocol_closed():
            log("disconnect_and_quit: protocol_closed()")
            self.idle_add(self.quit, exit_code)
        if p:
            p.flush_then_close(["disconnect", reason], done_callback=protocol_closed)
        #this timer is not conditional on the callback above firing:
        self.timeout_add(1000, self.quit, exit_code)

    def exit(self):
        """Call sys.exit() without an exit status argument."""
        sys.exit()


    def client_type(self):
        """Return the client type string, "Python" for this base class."""
        #overridden in subclasses!
        return "Python"

    def get_scheduler(self):
        """Return the scheduler passed to Protocol by setup_connection().

        Not implemented here: this always raises NotImplementedError.
        """
        raise NotImplementedError()

    def setup_connection(self, conn):
        """Create the Protocol for `conn` and configure it.

        Registers the large packet types, sets the compression level, the
        receive aliases, the default encoder and the default compressor.
        Also replaces have_more() on this instance with the protocol's
        source_has_more() method and, if the connection has a positive
        timeout, schedules verify_connected().
        """
        log("setup_connection(%s)", conn)
        proto = Protocol(self.get_scheduler(), conn, self.process_packet, self.next_packet)
        self._protocol = proto
        for packet_type in ("keymap-changed", "server-settings", "logging"):
            proto.large_packets.append(packet_type)
        proto.set_compression_level(self.compression_level)
        proto.receive_aliases.update(self._aliases)
        proto.enable_default_encoder()
        proto.enable_default_compressor()
        self.have_more = proto.source_has_more
        if conn.timeout>0:
            #the timer delay is in milliseconds:
            delay = (conn.timeout + EXTRA_TIMEOUT) * 1000
            self.timeout_add(delay, self.verify_connected)

    def init_packet_handlers(self):
        """Build the two dictionaries mapping packet types to the methods
        handling them: _packet_handlers and _ui_packet_handlers."""
        self._packet_handlers = {
            "hello"             : self._process_hello,
            }
        self._ui_packet_handlers = {
            "challenge":                self._process_challenge,
            "disconnect":               self._process_disconnect,
            "set_deflate":              self._process_set_deflate,
            "startup-complete":         self._process_startup_complete,
            Protocol.CONNECTION_LOST:   self._process_connection_lost,
            Protocol.GIBBERISH:         self._process_gibberish,
            Protocol.INVALID:           self._process_invalid,
            }

    def init_authenticated_packet_handlers(self):
        """Register the "send-file" packet handler in _packet_handlers."""
        #presumably called once the connection is authenticated,
        #as the name suggests - verify against the caller
        self._packet_handlers["send-file"] = self._process_send_file


    def init_aliases(self):
        """Assign a numeric alias, counting from 1, to every packet type
        found in the handler dictionaries.

        The mapping is recorded in both directions: alias to packet type
        in _aliases, packet type to alias in _reverse_aliases.
        """
        packet_types = list(self._packet_handlers) + list(self._ui_packet_handlers)
        for alias, packet_type in enumerate(packet_types, 1):
            self._aliases[alias] = packet_type
            self._reverse_aliases[packet_type] = alias

    def send_hello(self, challenge_response=None, client_salt=None):
        """Build and send the hello packet.

        challenge_response: the response to the server's challenge, if any.
        client_salt: the salt used to compute it, sent alongside if set.
        When a password is configured (password file or XPRA_PASSWORD
        environment variable) and there is no challenge response yet,
        only the base hello is sent, with the "challenge" flag set.
        If preparing the packet fails, quits with EXIT_INTERNAL_ERROR.
        """
        try:
            hello = self.make_hello_base()
            if hello is None:
                #make_hello_base() has already reported the problem and
                #requested to quit: don't log a second (misleading) error
                #and don't call quit() again
                return
            if (self.password_file or os.environ.get('XPRA_PASSWORD')) and not challenge_response:
                #avoid sending the full hello: tell the server we want
                #a packet challenge first
                hello["challenge"] = True
            else:
                hello.update(self.make_hello())
        except Exception as e:
            log.error("error preparing connection: %s", e, exc_info=True)
            self.quit(EXIT_INTERNAL_ERROR)
            return
        if challenge_response:
            assert self.password_file or os.environ.get('XPRA_PASSWORD')
            hello["challenge_response"] = challenge_response
            if client_salt:
                hello["challenge_client_salt"] = client_salt
        log("send_hello(%s) packet=%s", binascii.hexlify(strtobytes(challenge_response or "")), hello)
        self.send("hello", hello)

    def verify_connected(self):
        """Quit with EXIT_TIMEOUT if the server capabilities are still
        missing."""
        if self.server_capabilities is not None:
            return
        #server has not said hello yet
        self.warn_and_quit(EXIT_TIMEOUT, "connection timed out")


    def make_hello_base(self):
        """Build the capabilities sent in every hello packet.

        Starts from get_network_caps() and adds the version, identity,
        file transfer and printing settings, the platform and build
        information and, when encryption is enabled, the cipher parameters
        (the protocol's input cipher is configured with them here).
        Returns the capabilities dictionary, or None if the encryption key
        is missing (warn_and_quit() has been called in that case).
        """
        capabilities = get_network_caps()
        capabilities.update({
                "version"               : local_version,
                "encoding.generic"      : True,
                "namespace"             : True,
                "file-transfer"         : self.file_transfer,
                "file-size-limit"       : self.file_size_limit,
                "printing"              : self.printing,
                "hostname"              : socket.gethostname(),
                "uuid"                  : self.uuid,
                "username"              : self.username,
                "name"                  : get_name(),
                "client_type"           : self.client_type(),
                "python.version"        : sys.version_info[:3],
                "compression_level"     : self.compression_level,
                })
        if self.display:
            capabilities["display"] = self.display
        def up(prefix, d):
            updict(capabilities, prefix, d)
        up("platform",  get_platform_info())
        up("build",     get_version_info())
        mid = get_machine_id()
        if mid:
            capabilities["machine_id"] = mid

        if self.encryption:
            assert self.encryption in ENCRYPTION_CIPHERS
            #new iv and key salt values are generated on every call:
            iv = get_hex_uuid()[:16]
            key_salt = get_hex_uuid()+get_hex_uuid()
            iterations = 1000
            capabilities.update({
                        "cipher"                       : self.encryption,
                        "cipher.iv"                    : iv,
                        "cipher.key_salt"              : key_salt,
                        "cipher.key_stretch_iterations": iterations,
                        })
            key = self.get_encryption_key()
            if key is None:
                self.warn_and_quit(EXIT_ENCRYPTION, "encryption key is missing")
                #NOTE: this returns None rather than a dictionary
                return
            self._protocol.set_cipher_in(self.encryption, iv, key, key_salt, iterations)
            log("encryption capabilities: %s", [(k,v) for k,v in capabilities.items() if k.startswith("cipher")])
        return capabilities

    def make_hello(self):
        """Return the capabilities added to the base hello when sending
        the full hello packet: two flags, plus the packet aliases if any
        have been defined."""
        capabilities = {}
        #only client.py cares about these two:
        capabilities["randr_notify"] = False
        capabilities["windows"] = False
        aliases = self._reverse_aliases
        if aliases:
            capabilities["aliases"] = aliases
        return capabilities

    def send(self, *parts):
        """Queue the packet made of `parts` as an ordinary packet,
        then signal that there is data to send."""
        queue = self._ordinary_packets
        queue.append(parts)
        self.have_more()

    def send_now(self, *parts):
        """Queue the packet made of `parts` as a priority packet
        (next_packet() hands those out first), then signal that there is
        data to send."""
        queue = self._priority_packets
        queue.append(parts)
        self.have_more()

    def send_positional(self, packet):
        """Queue an already assembled packet as an ordinary packet and
        discard any mouse position update that was still pending."""
        queue = self._ordinary_packets
        queue.append(packet)
        #the pending mouse position (if any) is dropped, not sent:
        self._mouse_position = None
        self.have_more()

    def send_mouse_position(self, packet):
        """Record `packet` as the pending mouse position update, replacing
        any previous one, then signal that there is data to send."""
        self._mouse_position = packet
        self.have_more()

    def have_more(self):
        """Tell the protocol that packets are waiting, if it can take them.

        Nothing happens when there is no protocol or it has no source.
        """
        #setup_connection() replaces this method on the instance
        #with the protocol's source_has_more()
        proto = self._protocol
        if not proto or not proto.source:
            return
        proto.source_has_more()

    def next_packet(self):
        """Hand out the next packet to send.

        Priority packets go first, then ordinary packets, then the pending
        mouse position update (which is cleared once taken).
        Returns the tuple (packet, None, None, has_more): packet is None
        when nothing is waiting, has_more tells if more data is queued.
        """
        priority = self._priority_packets
        ordinary = self._ordinary_packets
        if priority:
            packet = priority.pop(0)
        elif ordinary:
            packet = ordinary.pop(0)
        else:
            packet = self._mouse_position
            if packet is not None:
                self._mouse_position = None
        if packet is None:
            has_more = False
        else:
            has_more = bool(priority) or bool(ordinary) \
                    or self._mouse_position is not None
        return packet, None, None, has_more


    def cleanup(self):
        """Close the protocol, if there is one, and forget about it,
        then call cleanup_printing()."""
        proto = self._protocol
        log("XpraClientBase.cleanup() protocol=%s", proto)
        if proto:
            proto.close()
            self._protocol = None
        self.cleanup_printing()

    def glib_init(self):
        """Call threads_init() on the module returned by import_glib().

        An ImportError raised anywhere in the process is ignored, and so is
        an AttributeError raised by the threads_init() call.
        """
        try:
            glib = import_glib()
            try:
                glib.threads_init()
            except AttributeError:
                #old versions of glib may not have this method
                pass
        except ImportError:
            pass

    def run(self):
        """Start the protocol (self._protocol must have been set already)."""
        proto = self._protocol
        proto.start()

    def quit(self, exit_code=0):
        """Terminate the client with `exit_code`: subclasses must override.

        Raises NotImplementedError (a subclass of Exception), like
        get_scheduler() does.
        """
        raise NotImplementedError("override me!")

    def warn_and_quit(self, exit_code, warning):
        """Log `warning` at warning level, then call quit(exit_code)."""
        log.warn(warning)
        self.quit(exit_code)

    def _process_disconnect(self, packet):
        """Handle the server's "disconnect" packet: warn and quit.

        packet[1] is the reason, any further items are extra information
        appended to the message in brackets.
        ie: ("disconnect", "version error", "incompatible version")
        The exit code is EXIT_FAILURE if the server capabilities were never
        received (or are empty) or if disconnect_is_an_error() says so for
        this reason, EXIT_OK otherwise.
        """
        reason = bytestostr(packet[1])
        message = nonl(reason)
        details = packet[2:]
        if len(details):
            message += " (%s)" % (", ".join(nonl(bytestostr(x)) for x in details))
        caps = self.server_capabilities
        if caps is None or len(caps)==0:
            #server never sent hello to us - so disconnect is an error
            #(but we don't know which one - the info message may help)
            log.warn("server failure: disconnected before the session could be established")
            exit_code = EXIT_FAILURE
        elif disconnect_is_an_error(reason):
            log.warn("server failure: %s", reason)
            exit_code = EXIT_FAILURE
        else:
            exit_code = EXIT_OK
        self.warn_and_quit(exit_code, "server requested disconnect: %s" % message)

    def _process_connection_lost(self, packet):
        """Warn and quit with EXIT_CONNECTION_LOST, `packet` is not used.

        If the protocol's input_raw_packetcount is still zero, extra
        warnings listing the likely causes are logged first.
        """
        p = self._protocol
        if p and p.input_raw_packetcount==0:
            props = p.get_info()
            c = props.get("compression", "unknown")
            e = props.get("encoder", "unknown")
            log.warn("failed to receive anything, not an xpra server?")
            log.warn("  could also be the wrong username, password or port")
            #only mention the compression and packet encoding if at
            #least one of them is known:
            if c!="unknown" or e!="unknown":
                log.warn("  or maybe this server does not support '%s' compression or '%s' packet encoding?", c, e)
        self.warn_and_quit(EXIT_CONNECTION_LOST, "Connection lost")

    def _process_challenge(self, packet):
        """Answer the server's authentication challenge with a new hello.

        packet[1] is the server salt. When encryption is enabled, packet[2]
        must hold the server's cipher settings, which are applied first.
        If present, packet[3] names the digest to use ("hmac" or "xor") and
        a client salt is generated and combined with the server salt.
        The "xor" digest is refused unless the output is encrypted or
        ALLOW_UNENCRYPTED_PASSWORDS is set.
        On success the "challenge" handlers are removed and the response is
        sent using send_hello(), on failure warn_and_quit() is called.
        """
        log("processing challenge: %s", packet[1:])
        if not self.password_file and not os.environ.get('XPRA_PASSWORD'):
            self.warn_and_quit(EXIT_PASSWORD_REQUIRED, "server requires authentication, please provide a password")
            return
        password = self.load_password()
        if not password:
            self.warn_and_quit(EXIT_PASSWORD_FILE_ERROR, "failed to load password from file %s" % self.password_file)
            return
        salt = packet[1]
        if self.encryption:
            assert len(packet)>=3, "challenge does not contain encryption details to use for the response"
            server_cipher = packet[2]
            key = self.get_encryption_key()
            if key is None:
                self.warn_and_quit(EXIT_ENCRYPTION, "encryption key is missing")
                return
            if not self.set_server_encryption(server_cipher, key):
                return
        digest = "hmac"
        client_can_salt = len(packet)>=4
        client_salt = None
        if client_can_salt:
            #server supports client salt, and tells us which digest to use:
            digest = packet[3]
            client_salt = get_hex_uuid()+get_hex_uuid()
            #TODO: use some key stretching algorithm? (meh)
            salt = xor(salt, client_salt)
        if digest=="hmac":
            import hmac
            challenge_response = hmac.HMAC(password, salt).hexdigest()
        elif digest=="xor":
            #don't send XORed password unencrypted:
            if not self._protocol.cipher_out and not ALLOW_UNENCRYPTED_PASSWORDS:
                self.warn_and_quit(EXIT_ENCRYPTION, "server requested digest %s, cowardly refusing to use it without encryption" % digest)
                return
            challenge_response = xor(password, salt)
        else:
            self.warn_and_quit(EXIT_PASSWORD_REQUIRED, "server requested an unsupported digest: %s" % digest)
            return
        #never write the password to the log, nor the salt together with
        #the response (which would reveal the password with "xor"):
        log("challenge response computed using the %s digest", digest)
        self.password_sent = True
        #the challenge must only be handled once:
        for d in (self._packet_handlers, self._ui_packet_handlers):
            d.pop("challenge", None)
        self.send_hello(challenge_response, client_salt)

    def set_server_encryption(self, capabilities, key):
        """Configure the outgoing cipher from the server's capabilities.

        Reads "cipher", "cipher.iv", "cipher.key_salt" and
        "cipher.key_stretch_iterations" from `capabilities`.
        Returns True once the protocol's outgoing cipher has been set,
        False otherwise (missing or unsupported cipher - in which case
        warn_and_quit() is called - or no protocol).
        """
        def capability(name, default=None):
            return capabilities.get(strtobytes(name), default)
        cipher = capability("cipher")
        cipher_iv = capability("cipher.iv")
        key_salt = capability("cipher.key_salt")
        iterations = capability("cipher.key_stretch_iterations")
        if not (cipher and cipher_iv):
            self.warn_and_quit(EXIT_ENCRYPTION, "the server does not use or support encryption/password, cannot continue with %s cipher" % self.encryption)
            return False
        if cipher not in ENCRYPTION_CIPHERS:
            self.warn_and_quit(EXIT_ENCRYPTION, "unsupported server cipher: %s, allowed ciphers: %s" % (cipher, ", ".join(ENCRYPTION_CIPHERS)))
            return False
        protocol = self._protocol
        if protocol:
            protocol.set_cipher_out(cipher, cipher_iv, key, key_salt, iterations)
            return True
        return False


    def get_encryption_key(self):
        """Locate the encryption key and return it without newline characters.

        The sources are tried in this order, stopping at the first
        non-empty value:
        1. the encryption key file
        2. the XPRA_ENCRYPTION_KEY environment variable
        3. the password file (if one is configured)
        4. the XPRA_PASSWORD environment variable

        Raises an Exception if the result is None.
        """
        key = load_binary_file(self.encryption_keyfile) or os.environ.get('XPRA_ENCRYPTION_KEY')
        if not key and self.password_file:
            key = load_binary_file(self.password_file)
            if key:
                log("used password file as encryption key")
        if not key:
            key = os.environ.get('XPRA_PASSWORD')
            if key:
                log("used XPRA_PASSWORD as encryption key")
        if key is not None:
            return key.strip("\n\r")
        raise Exception("no encryption key")

    def load_password(self):
        """Return the password to use for authentication.

        Without a password file, this is the value of the XPRA_PASSWORD
        environment variable (which may be None).
        Otherwise the password file is loaded and newline characters
        are stripped from both ends; None is returned if
        load_binary_file() returned None.
        """
        if self.password_file:
            filename = os.path.expanduser(self.password_file)
            contents = load_binary_file(filename)
            if contents is None:
                return None
            password = contents.strip("\n\r")
            #only log a mask of the same length, never the password:
            log("password read from file %s is %s", self.password_file, "*" * len(password))
            return password
        return os.environ.get('XPRA_PASSWORD')

    def _process_hello(self, packet):
        """Handle the server's hello packet.

        If we have a password (file or XPRA_PASSWORD) but it was never
        sent, the server did not challenge us: quit with
        EXIT_NO_AUTHENTICATION.
        Otherwise packet[1] is stored as the server capabilities
        and server_connection_established() is called; any exception
        raised while doing so makes us quit with EXIT_FAILURE.
        """
        if not self.password_sent:
            if self.password_file or os.environ.get('XPRA_PASSWORD'):
                self.warn_and_quit(EXIT_NO_AUTHENTICATION, "the server did not request our password")
                return
        try:
            caps = typedict(packet[1])
            self.server_capabilities = caps
            log("processing hello from server: %s", self.server_capabilities)
            self.server_connection_established()
        except Exception as e:
            log.info("error in hello packet", exc_info=True)
            self.warn_and_quit(EXIT_FAILURE, "error processing hello packet from server: %s" % e)

    def capsget(self, capabilities, key, default):
        """Look up `key` (converted with strtobytes) in `capabilities`.

        Returns `default` if the key is missing.
        On Python 3, bytes values are converted using bytestostr().
        """
        v = capabilities.get(strtobytes(key), default)
        #compare the version as a number: comparing the sys.version string
        #would give the wrong answer for a two digit major version
        if sys.version_info[0] >= 3 and isinstance(v, bytes):
            v = bytestostr(v)
        return v


    def server_connection_established(self):
        """Parse the server capabilities, one group at a time.

        Stops and returns False as soon as one of the version, server,
        network or encryption parsers fails.
        Otherwise the printing capabilities are parsed, the authenticated
        packet handlers are installed and True is returned.
        """
        log("server_connection_established()")
        stages = (
            ("parse_version_capabilities", "server_connection_established() failed version capabilities"),
            ("parse_server_capabilities", "server_connection_established() failed server capabilities"),
            ("parse_network_capabilities", "server_connection_established() failed network capabilities"),
            ("parse_encryption_capabilities", "server_connection_established() failed encryption capabilities"),
        )
        for parser_name, failure_message in stages:
            #look up each parser only when its turn comes:
            parser = getattr(self, parser_name)
            if not parser():
                log(failure_message)
                return False
        self.parse_printing_capabilities()
        log("server_connection_established() adding authenticated packet handlers")
        self.init_authenticated_packet_handlers()
        return True


    def parse_printing_capabilities(self):
        """Enable printer forwarding if both ends have printing enabled.

        Stores the list of printer attributes the server is interested in
        (default: "printer-info" and "device-uri") and schedules
        init_printing() to run after 1000ms.
        """
        if self.printing and self.server_capabilities.boolget("printing"):
            default_attributes = ["printer-info", "device-uri"]
            self.printer_attributes = self.server_capabilities.strlistget("printer.attributes", default_attributes)
            self.timeout_add(1000, self.init_printing)

    def init_printing(self):
        """Initialize the platform printing code and send our printers.

        Does nothing unless HAS_PRINTING is set.
        send_printers is handed to the platform init function,
        then the printers are sent once immediately.
        Failures are logged as warnings and not propagated.
        """
        if HAS_PRINTING:
            try:
                from xpra.platform.printing import init_printing
                printlog("init_printing=%s", init_printing)
                init_printing(self.send_printers)
                self.do_send_printers()
            except Exception:
                log.warn("failed to send printers", exc_info=True)

    def cleanup_printing(self):
        """Call the platform printing cleanup function.

        Does nothing unless HAS_PRINTING is set.
        Failures are logged as warnings and not propagated.
        """
        if HAS_PRINTING:
            try:
                from xpra.platform.printing import cleanup_printing
                printlog("cleanup_printing=%s", cleanup_printing)
                cleanup_printing()
            except Exception:
                log.warn("failed to cleanup printing subsystem", exc_info=True)

    def send_printers(self, *args):
        """Schedule do_send_printers() to run in 500ms, unless already scheduled.

        dbus can fire dozens of times for a single printer change,
        so we wait a bit and fire via a timer to try to batch things together.
        The arguments are ignored.
        """
        if not self.send_printers_pending:
            self.send_printers_pending = True
            self.timeout_add(500, self.do_send_printers)

    def do_send_printers(self):
        """Send the list of local printers to the server, if it has changed.

        Printers whose "device-uri" starts with "xpraforwarder" are skipped
        (to avoid loops and multi-forwards), and so are stopped printers.
        Only the attributes listed in self.printer_attributes are exported.
        Any error is logged and not propagated.
        """
        try:
            self.send_printers_pending = False
            from xpra.platform.printing import get_printers
            printers = get_printers()
            printlog("do_send_printers() found printers=%s", printers)
            def used_attrs(attrs):
                #filter attributes so that we only compare things that are actually used
                if not attrs:
                    return attrs
                return dict((name, value) for name, value in attrs.items() if name in self.printer_attributes)
            current = {}
            for name, attrs in printers.items():
                device_uri = attrs.get("device-uri", "")
                if device_uri:
                    #this is cups specific.. oh well
                    printlog("do_send_printers() device-uri(%s)=%s", name, device_uri)
                    if device_uri.startswith("xpraforwarder"):
                        printlog("do_send_printers() skipping xpra forwarded printer=%s", name)
                        continue
                #printer-state values:
                #"3" if the destination is idle,
                #"4" if the destination is printing a job,
                #"5" if the destination is stopped.
                if attrs.get("printer-state")==5:
                    printlog("do_send_printers() skipping stopped printer=%s", name)
                    continue
                current[name.encode("utf8")] = used_attrs(attrs)
            if self.exported_printers is None:
                #not been sent yet, ensure we can use the dict below:
                self.exported_printers = {}
            elif current==self.exported_printers:
                printlog("do_send_printers() exported printers unchanged: %s", self.exported_printers)
                return
            previous = self.exported_printers
            #show summary of what has changed:
            added = [k for k in current.keys() if k not in previous]
            if added:
                printlog("do_send_printers() new printers: %s", added)
            removed = [k for k in previous.keys() if k not in current]
            if removed:
                printlog("do_send_printers() printers removed: %s", removed)
            modified = [k for k,v in current.items() if previous.get(k)!=v and k not in added]
            if modified:
                printlog("do_send_printers() printers modified: %s", modified)
            printlog("do_send_printers() printers=%s", current.keys())
            printlog("do_send_printers() exported printers=%s", ", ".join(str(x) for x in current.keys()))
            self.exported_printers = current
            self.send("printers", self.exported_printers)
        except:
            log.error("do_send_printers()", exc_info=True)


    def parse_version_capabilities(self):
        """Record the server's identity and version information.

        The "build.version" and "build.revision" capabilities take
        precedence over "version" and "revision" when present.
        Returns False (after calling warn_and_quit) if
        version_compat_check() reports a problem with the remote version,
        True otherwise.
        """
        caps = self.server_capabilities
        self._remote_machine_id = caps.strget("machine_id")
        self._remote_uuid = caps.strget("uuid")
        self._remote_version = caps.strget("version")
        self._remote_version = caps.strget("build.version", self._remote_version)
        self._remote_revision = caps.strget("revision")
        self._remote_revision = caps.strget("build.revision", self._remote_revision)
        self._remote_platform = caps.strget("platform")
        self._remote_platform_release = caps.strget("platform.release")
        self._remote_platform_platform = caps.strget("platform.platform")
        #linux distribution is a tuple of different types, ie: ('Linux Fedora' , 20, 'Heisenbug')
        distribution = caps.listget("platform.linux_distribution")
        if distribution and len(distribution)==3:
            #integers are kept as they are, everything else becomes a string:
            self._remote_platform_linux_distribution = [x if type(x)==int else bytestostr(x) for x in distribution]
        version_error = version_compat_check(self._remote_version)
        if version_error is None:
            return True
        self.warn_and_quit(EXIT_INCOMPATIBLE_VERSION, "incompatible remote version '%s': %s" % (self._remote_version, version_error))
        return False

    def parse_server_capabilities(self):
        """Nothing to parse here: always returns True."""
        return True

    def parse_network_capabilities(self):
        """Enable the packet encoder and compressor chosen from the server caps.

        Returns False if there is no protocol or if the encoder
        could not be enabled, True otherwise.
        """
        caps = self.server_capabilities
        protocol = self._protocol
        if protocol and protocol.enable_encoder_from_caps(caps):
            protocol.enable_compressor_from_caps(caps)
            return True
        return False

    def parse_encryption_capabilities(self):
        """Apply the server's packet aliases and, if enabled, its cipher.

        Returns False if there is no protocol or if the server
        encryption could not be set up, True otherwise.
        """
        caps = self.server_capabilities
        protocol = self._protocol
        if not protocol:
            return False
        protocol.send_aliases = caps.dictget("aliases", {})
        if not self.encryption:
            return True
        #server uses a new cipher after second hello:
        key = self.get_encryption_key()
        assert key, "encryption key is missing"
        return bool(self.set_server_encryption(caps, key))

    def _process_set_deflate(self, packet):
        #legacy, should not be used for anything
        pass

    def _process_startup_complete(self, packet):
        #can be received if we connect with "xpra stop" or other command line client
        #as the server is starting up
        pass


    def _process_send_file(self, packet):
        """Save a file sent by the server, then optionally print or open it.

        Packet layout, after the packet type:
        basefilename, mimetype, printit, openit, filesize, file_data, options

        The file is rejected (an error is logged and nothing is saved) when:
        * the matching feature (printing or file transfer) is not enabled
        * there is no data, or its size does not match `filesize`
        * it is bigger than self.file_size_limit (in MB)
        * options contains a "sha1" digest which does not match the data

        The file is saved under DOWNLOAD_PATH, using only the base name of
        `basefilename` and adding a numeric suffix if that name is taken.
        """
        from xpra.platform.features import DOWNLOAD_PATH
        #only take the 7 fields we know about,
        #so that extra trailing fields do not break the unpacking:
        basefilename, mimetype, printit, openit, filesize, file_data, options = packet[1:8]
        filelog("received file: %s", [basefilename, mimetype, printit, openit, filesize, "%s bytes" % len(file_data), options])
        options = typedict(options)
        #this data comes from the peer: validate it with explicit checks,
        #assert statements are removed when python runs with -O
        if printit:
            if not self.printing:
                log.error("Error: printing is not enabled, ignoring file '%s'", basefilename)
                return
        elif not self.file_transfer:
            log.error("Error: file transfers are not enabled, ignoring file '%s'", basefilename)
            return
        if not file_data or filesize<=0:
            log.error("Error: no data received for file '%s'", basefilename)
            return
        if len(file_data)!=filesize:
            log.error("Error: invalid data size for file '%s'", basefilename)
            log.error(" received %i bytes, expected %i bytes", len(file_data), filesize)
            return
        if filesize>self.file_size_limit*1024*1024:
            log.error("Error: file '%s' is too large:", basefilename)
            log.error(" %iMB, the file size limit is %iMB", filesize//1024//1024, self.file_size_limit)
            return
        #check digest if present:
        digest = options.get("sha1")
        if digest:
            import hashlib
            u = hashlib.sha1()
            u.update(file_data)
            filelog("sha1 digest: %s - expected: %s", u.hexdigest(), digest)
            if digest!=u.hexdigest():
                log.error("Error: invalid file digest %s (expected %s)", u.hexdigest(), digest)
                return

        #make sure we use a filename that does not exist already:
        wanted_filename = os.path.abspath(os.path.join(os.path.expanduser(DOWNLOAD_PATH), os.path.basename(basefilename)))
        EXTS = {"application/postscript"    : "ps",
                "application/pdf"           : "pdf",
                }
        ext = EXTS.get(mimetype)
        if ext:
            #on some platforms (win32),
            #we want to force an extension
            #so that the file manager can display them properly when you double-click on them
            if not wanted_filename.endswith("."+ext):
                wanted_filename += "."+ext
        filename = wanted_filename
        root, extension = os.path.splitext(wanted_filename)
        base = 0
        while os.path.exists(filename):
            filelog("cannot save file as %s: file already exists", filename)
            base += 1
            filename = root+("-%s" % base)+extension
        #O_EXCL: fail rather than overwrite if the file appears in the meantime
        flags = os.O_CREAT | os.O_RDWR | os.O_EXCL
        #O_BINARY only exists on win32:
        flags |= getattr(os, "O_BINARY", 0)
        fd = os.open(filename, flags)
        try:
            os.write(fd, file_data)
        finally:
            os.close(fd)
        filelog.info("downloaded %s bytes to %s file %s", filesize, mimetype, filename)
        if printit:
            printer = options.strget("printer")
            title   = options.strget("title")
            #the print options may be missing altogether:
            print_options = options.dictget("options") or {}
            #TODO: how do we print multiple copies?
            #copies = options.intget("copies")
            #whitelist of options we can forward:
            safe_print_options = dict((k,v) for k,v in print_options.items() if k in ("PageSize", "Resolution"))
            printlog("safe print options(%s) = %s", options, safe_print_options)
            self._print_file(filename, printer, title, safe_print_options)
            return
        elif openit:
            #run the command in a new thread
            #so we can block waiting for the subprocess to exit
            #(ensures that we do reap the process)
            #the threading module is available on both python 2 and 3,
            #the "thread" module is not
            import threading
            opener = threading.Thread(target=self._open_file, args=(filename, ))
            #like thread.start_new_thread: never prevent the process from exiting
            opener.daemon = True
            opener.start()

    def _print_file(self, filename, printer, title, options):
        """Send `filename` to `printer` and monitor the print job.

        Logs an error and returns if the printer is not one of those
        returned by get_printers().
        Once the job has finished, failed to start, or has been running
        for more than 10 minutes, the file is deleted
        (only if DELETE_PRINTER_FILE is set).
        The job is checked once immediately, then via timeout_add(10000, ...).
        """
        import time
        from xpra.platform.printing import print_files, printing_finished, get_printers
        available = get_printers()
        if printer not in available:
            printlog.error("Error: printer '%s' does not exist!", printer)
            printlog.error(" printers available: %s", available.keys() or "none")
            return
        def delfile():
            if DELETE_PRINTER_FILE:
                try:
                    os.unlink(filename)
                except:
                    printlog("failed to delete print job file '%s'", filename)
                return False
        job = print_files(printer, [filename], title, options)
        printlog("printing %s, job=%s", filename, job)
        if job<=0:
            printlog("printing failed and returned %i", job)
            delfile()
            return
        start = time.time()
        def check_printing_finished():
            #returns True if we need to check again later
            done = printing_finished(job)
            printlog("printing_finished(%s)=%s", job, done)
            if not done:
                if time.time()-start<=10*60:
                    return True #try again..
                printlog.warn("print job %s timed out", job)
            delfile()
            return False
        still_printing = check_printing_finished()
        if still_printing:
            self.timeout_add(10000, check_printing_finished)

    def _open_file(self, filename):
        """Run self.open_command on `filename` and log the outcome.

        Does nothing but log a warning if self.open_files is not set.
        Blocks until the command exits; the first 512 characters of its
        stdout and stderr are logged, as warnings if the exit code is non-zero.
        """
        if not self.open_files:
            log.warn("Warning: opening files automatically is disabled,")
            log.warn(" ignoring uploaded file:")
            log.warn(" '%s'", filename)
            return
        import subprocess
        command = [self.open_command, filename]
        process = subprocess.Popen(command, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        out, err = process.communicate()
        exit_code = process.wait()
        filelog.info("opened file %s with %s, exit code: %s", filename, self.open_command, exit_code)
        output_log = filelog.warn if exit_code!=0 else filelog
        for template, output in (("stdout=%s", out), ("stderr=%s", err)):
            if output:
                output_log(template, nonl(output)[:512])


    def _process_gibberish(self, packet):
        """Handle data which could not be parsed as a packet, then quit.

        The packet is (packet type, message, data).
        If this is the very first thing received and it is printable,
        it is shown as text: tracebacks are logged line by line,
        anything else as a single (ellipsized) line.
        The exit code is EXIT_SSH_FAILURE if the data looks like a
        password or login prompt, EXIT_PACKET_FAILURE otherwise.
        """
        (_, message, data) = packet
        p = self._protocol
        show_as_text = p and p.input_packetcount==0 and all(c in string.printable for c in bytestostr(data))
        if show_as_text:
            #looks like the first packet back is just text, print it:
            data = bytestostr(data)
            #find() returns -1 when there is no match, and 0 for a match
            #at the very start, so it must be compared explicitly:
            if data.find("Traceback ")>=0:
                for x in data.split("\n"):
                    log.warn(x.strip("\r"))
            else:
                log.warn("Failed to connect, received: %s", repr_ellipsized(data.strip("\n").strip("\r")))
        else:
            log.warn("Received uninterpretable nonsense: %s", message)
            if p:
                log.warn(" packet no %i data: %s", p.input_packetcount, repr_ellipsized(data))
            else:
                #the protocol is gone, so there is no packet count to show
                log.warn(" data: %s", repr_ellipsized(data))
        if str(data).find("assword")>0:
            self.warn_and_quit(EXIT_SSH_FAILURE,
                              "Your ssh program appears to be asking for a password."
                             + GOT_PASSWORD_PROMPT_SUGGESTION)
        elif str(data).find("login")>=0:
            self.warn_and_quit(EXIT_SSH_FAILURE,
                             "Your ssh program appears to be asking for a username.\n"
                             "Perhaps try using something like 'ssh:USER@host:display'?")
        else:
            self.quit(EXIT_PACKET_FAILURE)

    def _process_invalid(self, packet):
        """Log the invalid packet we received and quit with EXIT_PACKET_FAILURE.

        The packet is (packet type, message, data).
        """
        _packet_type, message, data = packet
        log.info("Received invalid packet: %s", message)
        log(" data: %s", repr_ellipsized(data))
        self.quit(EXIT_PACKET_FAILURE)


    def process_packet(self, proto, packet):
        """Dispatch a packet to its handler.

        Handlers found in self._packet_handlers are called directly,
        those found in self._ui_packet_handlers are scheduled via idle_add.
        Unknown packet types are logged as errors.
        Errors raised while processing are logged and not propagated,
        except for KeyboardInterrupt.
        The `proto` argument is not used.
        """
        #defined before the try block so that the error handler below
        #can always use them, even if reading packet[0] fails:
        handler = None
        packet_type = None
        try:
            packet_type = packet[0]
            #check the type of the value: comparing the value itself
            #with the int type is always true
            if not isinstance(packet_type, int):
                packet_type = bytestostr(packet_type)
            handler = self._packet_handlers.get(packet_type)
            if handler:
                handler(packet)
                return
            handler = self._ui_packet_handlers.get(packet_type)
            if not handler:
                log.error("unknown packet type: %s", packet_type)
                return
            self.idle_add(handler, packet)
        except KeyboardInterrupt:
            raise
        except:
            log.error("Unhandled error while processing a '%s' packet from peer using %s", packet_type, handler, exc_info=True)
コード例 #16
0
class XpraClientBase(object):
    """ Base class for Xpra clients.
        Provides the glue code for:
        * sending packets via Protocol
        * handling packets received via _process_packet
        For an actual implementation, look at:
        * GObjectXpraClient
        * xpra.client.gtk2.client
        * xpra.client.gtk3.client
    """
    def __init__(self):
        #this may be called more than once,
        #skip doing internal init again:
        if not hasattr(self, "exit_code"):
            self.defaults_init()

    def defaults_init(self):
        """Set every attribute to its default value.

        Also generates the client uuid, installs the initial packet
        handlers and runs sanity_checks().
        Many of these values are later replaced by init(opts).
        """
        #None until an exit code has been chosen (see disconnect_and_quit):
        self.exit_code = None
        self.compression_level = 0
        self.display = None
        self.username = None
        self.password_file = None
        #set to True once we have answered the server's challenge:
        self.password_sent = False
        self.encryption = None
        self.encryption_keyfile = None
        self.quality = -1
        self.min_quality = 0
        self.speed = 0
        self.min_speed = -1
        self.file_transfer = False
        #in MB (see _process_send_file):
        self.file_size_limit = 10
        self.printing = False
        self.printer_attributes = []
        self.send_printers_pending = False
        #None until the printers have been sent once (see do_send_printers):
        self.exported_printers = None
        self.open_files = False
        self.open_command = None
        #protocol stuff:
        self._protocol = None
        #outgoing packet queues, consumed by next_packet():
        self._priority_packets = []
        self._ordinary_packets = []
        self._mouse_position = None
        self._aliases = {}
        self._reverse_aliases = {}
        #server state and caps:
        self.server_capabilities = None
        self._remote_machine_id = None
        self._remote_uuid = None
        self._remote_version = None
        self._remote_revision = None
        self._remote_platform = None
        self._remote_platform_release = None
        self._remote_platform_platform = None
        self._remote_platform_linux_distribution = None
        self.uuid = get_user_uuid()
        self.init_packet_handlers()
        sanity_checks()

    def init(self, opts):
        """Copy the settings we use from the `opts` object.

        If DETECT_LEAKS is set, a leak report function is also
        scheduled via timeout_add (10 seconds).
        """
        copied_options = (
            "compression_level", "display", "username", "password_file",
            "encryption", "encryption_keyfile",
            "quality", "min_quality", "speed", "min_speed",
            "file_transfer", "file_size_limit", "printing",
            "open_command", "open_files",
        )
        for name in copied_options:
            setattr(self, name, getattr(opts, name))

        if not DETECT_LEAKS:
            return
        from xpra.util import detect_leaks
        #classes to report on in detail can be added to this list,
        #example: warning, uses ugly direct import:
        #try:
        #    from xpra.x11.bindings.ximage import XShmImageWrapper       #@UnresolvedImport
        #    detailed.append(XShmImageWrapper)
        #except:
        #    pass
        detailed = []
        print_leaks = detect_leaks(log, detailed)
        self.timeout_add(10 * 1000, print_leaks)

    def timeout_add(self, *args):
        raise Exception("override me!")

    def idle_add(self, *args):
        raise Exception("override me!")

    def source_remove(self, *args):
        raise Exception("override me!")

    def install_signal_handlers(self):
        def deadly_signal(signum, frame):
            sys.stderr.write("\ngot deadly signal %s, exiting\n" %
                             SIGNAMES.get(signum, signum))
            sys.stderr.flush()
            self.cleanup()
            os._exit(128 + signum)

        def app_signal(signum, frame):
            sys.stderr.write("\ngot signal %s, exiting\n" %
                             SIGNAMES.get(signum, signum))
            sys.stderr.flush()
            signal.signal(signal.SIGINT, deadly_signal)
            signal.signal(signal.SIGTERM, deadly_signal)
            self.timeout_add(
                0, self.disconnect_and_quit, 128 + signum,
                "exit on signal %s" % SIGNAMES.get(signum, signum))

        signal.signal(signal.SIGINT, app_signal)
        signal.signal(signal.SIGTERM, app_signal)

    def signal_disconnect_and_quit(self, exit_code, reason):
        self.idle_add(self.disconnect_and_quit, exit_code, reason)
        self.idle_add(self.quit, exit_code)
        self.idle_add(self.exit)

    def disconnect_and_quit(self, exit_code, reason):
        """Try to tell the server that we are going, then quit.

        Without a usable protocol, quit() is called straight away.
        Otherwise a "disconnect" packet is flushed and quit() is scheduled
        from the protocol's done callback, as well as from a
        1000ms timer.
        """
        #make sure that we set the exit code early,
        #so the protocol shutdown won't set a different one:
        if self.exit_code is None:
            self.exit_code = exit_code
        log("disconnect_and_quit(%s, %s)", exit_code, reason)
        protocol = self._protocol
        if protocol is None or protocol._closed:
            self.quit(exit_code)
            return

        def protocol_closed():
            log("disconnect_and_quit: protocol_closed()")
            self.idle_add(self.quit, exit_code)

        if protocol:
            disconnect_packet = ["disconnect", reason]
            protocol.flush_then_close(disconnect_packet,
                                      done_callback=protocol_closed)
        self.timeout_add(1000, self.quit, exit_code)

    def exit(self):
        sys.exit()

    def client_type(self):
        #overriden in subclasses!
        return "Python"

    def get_scheduler(self):
        raise NotImplementedError()

    def setup_connection(self, conn):
        """Create and configure the Protocol instance for `conn`.

        Replaces self.have_more with the protocol's source_has_more.
        If the connection has a timeout, verify_connected() is scheduled
        to run once that timeout (plus EXTRA_TIMEOUT) has elapsed.
        """
        log("setup_connection(%s)", conn)
        protocol = Protocol(self.get_scheduler(), conn,
                            self.process_packet, self.next_packet)
        self._protocol = protocol
        for packet_type in ("keymap-changed", "server-settings", "logging"):
            protocol.large_packets.append(packet_type)
        protocol.set_compression_level(self.compression_level)
        protocol.receive_aliases.update(self._aliases)
        protocol.enable_default_encoder()
        protocol.enable_default_compressor()
        self.have_more = protocol.source_has_more
        if conn.timeout > 0:
            delay = (conn.timeout + EXTRA_TIMEOUT) * 1000
            self.timeout_add(delay, self.verify_connected)

    def init_packet_handlers(self):
        """Install the handlers for the packets we accept before authentication.

        process_packet() calls the handlers from _packet_handlers directly
        and schedules the ones from _ui_packet_handlers via idle_add.
        """
        self._packet_handlers = dict([
            ("hello", self._process_hello),
        ])
        ui_handlers = [
            ("challenge", self._process_challenge),
            ("disconnect", self._process_disconnect),
            ("set_deflate", self._process_set_deflate),
            ("startup-complete", self._process_startup_complete),
            (Protocol.CONNECTION_LOST, self._process_connection_lost),
            (Protocol.GIBBERISH, self._process_gibberish),
            (Protocol.INVALID, self._process_invalid),
        ]
        self._ui_packet_handlers = dict(ui_handlers)

    def init_authenticated_packet_handlers(self):
        self._packet_handlers["send-file"] = self._process_send_file

    def init_aliases(self):
        packet_types = list(self._packet_handlers.keys())
        packet_types += list(self._ui_packet_handlers.keys())
        i = 1
        for key in packet_types:
            self._aliases[i] = key
            self._reverse_aliases[key] = i
            i += 1

    def send_hello(self, challenge_response=None, client_salt=None):
        """Build the hello packet and send it.

        If we have a password but no challenge response yet, only the
        base capabilities are sent, together with a request for a challenge.
        Otherwise the full capabilities are sent, including the challenge
        response (and client salt) if we have one.
        If the capabilities cannot be generated, the error is logged
        and we quit with EXIT_INTERNAL_ERROR.
        """
        try:
            hello = self.make_hello_base()
            need_challenge = (self.password_file or os.environ.get('XPRA_PASSWORD')
                              ) and not challenge_response
            if need_challenge:
                #avoid sending the full hello: tell the server we want
                #a packet challenge first
                hello["challenge"] = True
            else:
                hello.update(self.make_hello())
        except Exception as e:
            log.error("error preparing connection: %s", e, exc_info=True)
            self.quit(EXIT_INTERNAL_ERROR)
            return
        if challenge_response:
            assert self.password_file or os.environ.get('XPRA_PASSWORD')
            hello["challenge_response"] = challenge_response
            if client_salt:
                hello["challenge_client_salt"] = client_salt
        hex_response = binascii.hexlify(strtobytes(challenge_response or ""))
        log("send_hello(%s) packet=%s", hex_response, hello)
        self.send("hello", hello)

    def verify_connected(self):
        if self.server_capabilities is None:
            #server has not said hello yet
            self.warn_and_quit(EXIT_TIMEOUT, "connection timed out")

    def make_hello_base(self):
        """Return the capabilities included in every hello packet.

        These are the network capabilities plus version, identification,
        platform and build information.
        With encryption enabled, the cipher settings are added and
        the protocol's incoming cipher is configured to match them.
        """
        capabilities = get_network_caps()
        base_capabilities = {
            "version": local_version,
            "encoding.generic": True,
            "namespace": True,
            "file-transfer": self.file_transfer,
            "file-size-limit": self.file_size_limit,
            "printing": self.printing,
            "hostname": socket.gethostname(),
            "uuid": self.uuid,
            "username": self.username,
            "name": get_name(),
            "client_type": self.client_type(),
            "python.version": sys.version_info[:3],
            "compression_level": self.compression_level,
        }
        capabilities.update(base_capabilities)
        if self.display:
            capabilities["display"] = self.display

        updict(capabilities, "platform", get_platform_info())
        updict(capabilities, "build", get_version_info())
        machine_id = get_machine_id()
        if machine_id:
            capabilities["machine_id"] = machine_id

        if not self.encryption:
            return capabilities

        assert self.encryption in ENCRYPTION_CIPHERS
        iv = get_hex_uuid()[:16]
        key_salt = get_hex_uuid() + get_hex_uuid()
        iterations = 1000
        cipher_capabilities = {
            "cipher": self.encryption,
            "cipher.iv": iv,
            "cipher.key_salt": key_salt,
            "cipher.key_stretch_iterations": iterations,
        }
        capabilities.update(cipher_capabilities)
        key = self.get_encryption_key()
        if key is None:
            self.warn_and_quit(EXIT_ENCRYPTION,
                               "encryption key is missing")
            return
        self._protocol.set_cipher_in(self.encryption, iv, key, key_salt,
                                     iterations)
        log("encryption capabilities: %s",
            [(k, v)
             for k, v in capabilities.items() if k.startswith("cipher")])
        return capabilities

    def make_hello(self):
        capabilities = {
            "randr_notify": False,  #only client.py cares about this
            "windows": False,  #only client.py cares about this
        }
        if self._reverse_aliases:
            capabilities["aliases"] = self._reverse_aliases
        return capabilities

    def send(self, *parts):
        self._ordinary_packets.append(parts)
        self.have_more()

    def send_now(self, *parts):
        self._priority_packets.append(parts)
        self.have_more()

    def send_positional(self, packet):
        self._ordinary_packets.append(packet)
        self._mouse_position = None
        self.have_more()

    def send_mouse_position(self, packet):
        self._mouse_position = packet
        self.have_more()

    def have_more(self):
        #this function is overridden in setup_protocol()
        p = self._protocol
        if p and p.source:
            p.source_has_more()

    def next_packet(self):
        if self._priority_packets:
            packet = self._priority_packets.pop(0)
        elif self._ordinary_packets:
            packet = self._ordinary_packets.pop(0)
        elif self._mouse_position is not None:
            packet = self._mouse_position
            self._mouse_position = None
        else:
            packet = None
        has_more = packet is not None and \
                (bool(self._priority_packets) or bool(self._ordinary_packets) \
                 or self._mouse_position is not None)
        return packet, None, None, has_more

    def cleanup(self):
        """Close and forget the protocol (if we have one), then clean up printing."""
        log("XpraClientBase.cleanup() protocol=%s", self._protocol)
        protocol = self._protocol
        if protocol:
            protocol.close()
            self._protocol = None
        self.cleanup_printing()

    def glib_init(self):
        """ Calls glib's threads_init(), when available.

            Failing to import glib (ImportError) and a glib module
            without threads_init (AttributeError) are both silently ignored.
        """
        try:
            glib = import_glib()
            try:
                glib.threads_init()
            except AttributeError:
                #old versions of glib may not have this method
                pass
        except ImportError:
            pass

    def run(self):
        """ Starts the protocol (which must have been set up already). """
        self._protocol.start()

    def quit(self, exit_code=0):
        """ Abstract hook: must be overridden to terminate the client
            with the given `exit_code`.

            Raises NotImplementedError (an Exception subclass,
            so existing "except Exception" callers are unaffected).
        """
        raise NotImplementedError("override me!")

    def warn_and_quit(self, exit_code, warning):
        """ Logs `warning` as a warning, then calls quit() with `exit_code`. """
        log.warn(warning)
        self.quit(exit_code)

    def _process_disconnect(self, packet):
        """ Handles a "disconnect" packet from the server.

            packet[1] is the reason, any further items are extra information.
            The exit code is EXIT_FAILURE if the server never sent its
            capabilities or if the reason is classified as an error,
            EXIT_OK otherwise.
        """
        #ie: ("disconnect", "version error", "incompatible version")
        reason = bytestostr(packet[1])
        extra = packet[2:]
        message = nonl(reason)
        if len(extra):
            details = [nonl(bytestostr(x)) for x in extra]
            message += " (%s)" % ", ".join(details)
        caps = self.server_capabilities
        exit_code = EXIT_OK
        if caps is None or len(caps) == 0:
            #server never sent hello to us - so disconnect is an error
            #(but we don't know which one - the info message may help)
            log.warn(
                "server failure: disconnected before the session could be established"
            )
            exit_code = EXIT_FAILURE
        elif disconnect_is_an_error(reason):
            log.warn("server failure: %s", reason)
            exit_code = EXIT_FAILURE
        self.warn_and_quit(exit_code, "server requested disconnect: %s" % message)

    def _process_connection_lost(self, packet):
        """ Handles the loss of the connection: always quits with
            EXIT_CONNECTION_LOST.

            If no raw packet was ever received, some hints about
            the likely causes are logged first.
        """
        protocol = self._protocol
        nothing_received = protocol and protocol.input_raw_packetcount == 0
        if nothing_received:
            info = protocol.get_info()
            compression = info.get("compression", "unknown")
            encoder = info.get("encoder", "unknown")
            log.warn("failed to receive anything, not an xpra server?")
            log.warn("  could also be the wrong username, password or port")
            if (compression, encoder) != ("unknown", "unknown"):
                log.warn(
                    "  or maybe this server does not support '%s' compression or '%s' packet encoding?",
                    compression, encoder)
        self.warn_and_quit(EXIT_CONNECTION_LOST, "Connection lost")

    def _process_challenge(self, packet):
        """ Handles a "challenge" packet and answers it with a new hello
            packet containing the challenge response.

            Packet items used:
            * packet[1]: the server salt
            * packet[2]: the server cipher capabilities,
              required when encryption is enabled
            * packet[3]: the digest to use ("hmac" or "xor"), if present
              then a client salt is generated and combined with the server salt
            All failures are reported through warn_and_quit().
        """
        log("processing challenge: %s", packet[1:])
        if not self.password_file and not os.environ.get('XPRA_PASSWORD'):
            self.warn_and_quit(
                EXIT_PASSWORD_REQUIRED,
                "server requires authentication, please provide a password")
            return
        password = self.load_password()
        if not password:
            self.warn_and_quit(
                EXIT_PASSWORD_FILE_ERROR,
                "failed to load password from file %s" % self.password_file)
            return
        salt = packet[1]
        if self.encryption:
            assert len(
                packet
            ) >= 3, "challenge does not contain encryption details to use for the response"
            server_cipher = packet[2]
            key = self.get_encryption_key()
            if key is None:
                self.warn_and_quit(EXIT_ENCRYPTION,
                                   "encryption key is missing")
                return
            if not self.set_server_encryption(server_cipher, key):
                return
        digest = "hmac"
        client_can_salt = len(packet) >= 4
        client_salt = None
        if client_can_salt:
            #server supports client salt, and tells us which digest to use:
            digest = packet[3]
            client_salt = get_hex_uuid() + get_hex_uuid()
            #TODO: use some key stretching algorithm? (meh)
            salt = xor(salt, client_salt)
        if digest == "hmac":
            import hmac
            challenge_response = hmac.HMAC(password, salt).hexdigest()
        elif digest == "xor":
            #don't send XORed password unencrypted:
            if not self._protocol.cipher_out and not ALLOW_UNENCRYPTED_PASSWORDS:
                self.warn_and_quit(
                    EXIT_ENCRYPTION,
                    "server requested digest %s, cowardly refusing to use it without encryption"
                    % digest)
                return
            challenge_response = xor(password, salt)
        else:
            self.warn_and_quit(
                EXIT_PASSWORD_REQUIRED,
                "server requested an unsupported digest: %s" % digest)
            return
        #never write the password to the log (load_password() masks it too),
        #the "xor" response must also be hidden
        #since it can be reversed using the salt, which is logged:
        masked_password = "*" * len(password)
        if digest == "xor":
            logged_response = "*" * len(challenge_response)
        else:
            logged_response = challenge_response
        log("%s(%s, %s)=%s", digest, masked_password, salt, logged_response)
        self.password_sent = True
        #the challenge has been answered, stop handling this packet type:
        for d in (self._packet_handlers, self._ui_packet_handlers):
            d.pop("challenge", None)
        self.send_hello(challenge_response, client_salt)

    def set_server_encryption(self, capabilities, key):
        """ Configures the outgoing cipher of the protocol from the cipher
            attributes found in `capabilities`, using `key` as encryption key.

            Returns True if the cipher was set, False otherwise.
            Missing or unsupported cipher settings also call warn_and_quit().
        """
        def cap(name, default=None):
            #the capabilities are looked up using the bytes form of the name
            return capabilities.get(strtobytes(name), default)

        cipher, cipher_iv = cap("cipher"), cap("cipher.iv")
        key_salt = cap("cipher.key_salt")
        iterations = cap("cipher.key_stretch_iterations")
        if not (cipher and cipher_iv):
            self.warn_and_quit(
                EXIT_ENCRYPTION,
                "the server does not use or support encryption/password, cannot continue with %s cipher"
                % self.encryption)
            return False
        if cipher not in ENCRYPTION_CIPHERS:
            self.warn_and_quit(
                EXIT_ENCRYPTION,
                "unsupported server cipher: %s, allowed ciphers: %s" %
                (cipher, ", ".join(ENCRYPTION_CIPHERS)))
            return False
        protocol = self._protocol
        if protocol:
            protocol.set_cipher_out(cipher, cipher_iv, key, key_salt, iterations)
            return True
        return False

    def get_encryption_key(self):
        """ Returns the encryption key, stripped of "\\n" and "\\r" characters.

            The first non-empty value is used, tried in this order:
            the encryption keyfile, the XPRA_ENCRYPTION_KEY environment
            variable, the password file, the XPRA_PASSWORD environment variable.
            Raises an Exception if the value found is None.
        """
        key = load_binary_file(self.encryption_keyfile) \
            or os.environ.get('XPRA_ENCRYPTION_KEY')
        if not key and self.password_file:
            key = load_binary_file(self.password_file)
            if key:
                log("used password file as encryption key")
        if not key:
            key = os.environ.get('XPRA_PASSWORD')
            if key:
                log("used XPRA_PASSWORD as encryption key")
        if key is None:
            raise Exception("no encryption key")
        return key.strip("\n\r")

    def load_password(self):
        """ Returns the password to use for authentication.

            Without a password file, this is the value of the XPRA_PASSWORD
            environment variable (which may be None).
            Otherwise the contents of the password file are returned,
            stripped of "\\n" and "\\r" characters,
            or None if the file could not be loaded.
        """
        if self.password_file:
            path = os.path.expanduser(self.password_file)
            password = load_binary_file(path)
            if password is not None:
                password = password.strip("\n\r")
                #only show the length of the password in the log:
                log("password read from file %s is %s", self.password_file,
                    "*" * len(password))
            return password
        return os.environ.get('XPRA_PASSWORD')

    def _process_hello(self, packet):
        """ Handles the "hello" packet from the server.

            Quits with EXIT_NO_AUTHENTICATION if a password is configured
            but the server never asked for it.
            Otherwise the server capabilities (packet[1]) are stored
            and parsed via server_connection_established(),
            any error doing so quits with EXIT_FAILURE.
        """
        password_configured = self.password_file or os.environ.get('XPRA_PASSWORD')
        if password_configured and not self.password_sent:
            self.warn_and_quit(EXIT_NO_AUTHENTICATION,
                               "the server did not request our password")
            return
        try:
            caps = typedict(packet[1])
            self.server_capabilities = caps
            log("processing hello from server: %s", caps)
            self.server_connection_established()
        except Exception as e:
            log.info("error in hello packet", exc_info=True)
            self.warn_and_quit(
                EXIT_FAILURE,
                "error processing hello packet from server: %s" % e)

    def capsget(self, capabilities, key, default):
        """ Looks up `key` (converted to bytes) in `capabilities`,
            returning `default` if it is not found.

            With Python 3, bytes values are converted to strings.
        """
        v = capabilities.get(strtobytes(key), default)
        #compare the version number rather than the version string,
        #and also accept subclasses of bytes:
        if sys.version_info[0] >= 3 and isinstance(v, bytes):
            v = bytestostr(v)
        return v

    def server_connection_established(self):
        """ Parses the server capabilities, one group after the other.

            Stops and returns False as soon as one parser fails.
            Otherwise the printing capabilities are parsed,
            the authenticated packet handlers are installed
            and True is returned.
        """
        log("server_connection_established()")
        parsers = (
            (self.parse_version_capabilities,
             "server_connection_established() failed version capabilities"),
            (self.parse_server_capabilities,
             "server_connection_established() failed server capabilities"),
            (self.parse_network_capabilities,
             "server_connection_established() failed network capabilities"),
            (self.parse_encryption_capabilities,
             "server_connection_established() failed encryption capabilities"),
        )
        for parse, failure_message in parsers:
            if not parse():
                log(failure_message)
                return False
        self.parse_printing_capabilities()
        log("server_connection_established() adding authenticated packet handlers"
            )
        self.init_authenticated_packet_handlers()
        return True

    def parse_printing_capabilities(self):
        """ If printing is enabled here and the server has the "printing"
            capability: records the printer attributes wanted by the server
            and schedules init_printing() to run after 1000ms.
        """
        if not self.printing:
            return
        caps = self.server_capabilities
        if not caps.boolget("printing"):
            return
        self.printer_attributes = caps.strlistget(
            "printer.attributes", ["printer-info", "device-uri"])
        self.timeout_add(1000, self.init_printing)

    def init_printing(self):
        """ Initializes the platform printing code with send_printers()
            as callback, then sends the list of printers.

            Does nothing without printing support (HAS_PRINTING),
            errors are logged as warnings.
        """
        if HAS_PRINTING:
            try:
                from xpra.platform.printing import init_printing
                printlog("init_printing=%s", init_printing)
                init_printing(self.send_printers)
                self.do_send_printers()
            except Exception:
                log.warn("failed to send printers", exc_info=True)

    def cleanup_printing(self):
        """ Calls the cleanup function of the platform printing code.

            Does nothing without printing support (HAS_PRINTING),
            errors are logged as warnings.
        """
        if HAS_PRINTING:
            try:
                from xpra.platform.printing import cleanup_printing
                printlog("cleanup_printing=%s", cleanup_printing)
                cleanup_printing()
            except Exception:
                log.warn("failed to cleanup printing subsystem", exc_info=True)

    def send_printers(self, *args):
        """ Schedules do_send_printers() to run after 500ms,
            unless a call is already pending.
            The arguments are ignored.
        """
        #dbus can fire dozens of times for a single printer change
        #so we wait a bit and fire via a timer to try to batch things together:
        if not self.send_printers_pending:
            self.send_printers_pending = True
            self.timeout_add(500, self.do_send_printers)

    def do_send_printers(self):
        """ Sends the list of local printers to the server in a "printers"
            packet, unless the list is the same as the one sent last.

            Xpra forwarded printers and stopped printers are not exported,
            and only the attributes listed in `self.printer_attributes`
            are kept for each printer.
            Errors are logged and never propagated to the caller.
        """
        try:
            self.send_printers_pending = False
            from xpra.platform.printing import get_printers
            printers = get_printers()
            printlog("do_send_printers() found printers=%s", printers)
            #remove xpra-forwarded printers to avoid loops and multi-forwards,
            #also ignore stopped printers
            #and only keep the attributes that the server cares about (self.printer_attributes)
            exported_printers = {}

            def used_attrs(d):
                #filter attributes so that we only compare things that are actually used
                if not d:
                    return d
                return dict((k, v) for k, v in d.items()
                            if k in self.printer_attributes)

            for k, v in printers.items():
                device_uri = v.get("device-uri", "")
                if device_uri:
                    #this is cups specific.. oh well
                    printlog("do_send_printers() device-uri(%s)=%s", k,
                             device_uri)
                    if device_uri.startswith("xpraforwarder"):
                        printlog(
                            "do_send_printers() skipping xpra forwarded printer=%s",
                            k)
                        continue
                state = v.get("printer-state")
                #"3" if the destination is idle,
                #"4" if the destination is printing a job,
                #"5" if the destination is stopped.
                if state == 5:
                    printlog("do_send_printers() skipping stopped printer=%s",
                             k)
                    continue
                exported_printers[k.encode("utf8")] = used_attrs(v)
            if self.exported_printers is None:
                #not been sent yet, ensure we can use the dict below:
                self.exported_printers = {}
            elif exported_printers == self.exported_printers:
                printlog("do_send_printers() exported printers unchanged: %s",
                         self.exported_printers)
                return
            #show summary of what has changed:
            added = [
                k for k in exported_printers
                if k not in self.exported_printers
            ]
            if added:
                printlog("do_send_printers() new printers: %s", added)
            removed = [
                k for k in self.exported_printers
                if k not in exported_printers
            ]
            if removed:
                printlog("do_send_printers() printers removed: %s", removed)
            modified = [
                k for k, v in exported_printers.items()
                if self.exported_printers.get(k) != v and k not in added
            ]
            if modified:
                printlog("do_send_printers() printers modified: %s", modified)
            printlog("do_send_printers() printers=%s",
                     exported_printers.keys())
            printlog("do_send_printers() exported printers=%s",
                     ", ".join(str(x) for x in exported_printers.keys()))
            self.exported_printers = exported_printers
            self.send("printers", self.exported_printers)
        except Exception:
            #same as init_printing(): log regular errors only,
            #KeyboardInterrupt and SystemExit must not be swallowed here
            log.error("do_send_printers()", exc_info=True)

    def parse_version_capabilities(self):
        """ Records the identity, version and platform attributes
            found in the server capabilities.

            The "build.version" and "build.revision" values take precedence
            over "version" and "revision".
            Returns False (after calling warn_and_quit) if the remote version
            fails the compatibility check, True otherwise.
        """
        caps = self.server_capabilities
        self._remote_machine_id = caps.strget("machine_id")
        self._remote_uuid = caps.strget("uuid")
        self._remote_version = caps.strget("build.version",
                                           caps.strget("version"))
        self._remote_revision = caps.strget("build.revision",
                                            caps.strget("revision"))
        self._remote_platform = caps.strget("platform")
        self._remote_platform_release = caps.strget("platform.release")
        self._remote_platform_platform = caps.strget("platform.platform")
        #linux distribution is a tuple of different types, ie: ('Linux Fedora' , 20, 'Heisenbug')
        distribution = caps.listget("platform.linux_distribution")
        if distribution and len(distribution) == 3:

            def sanitize(value):
                #integers are kept as they are, anything else becomes a string
                return value if type(value) is int else bytestostr(value)

            self._remote_platform_linux_distribution = [
                sanitize(x) for x in distribution
            ]
        version_error = version_compat_check(self._remote_version)
        if version_error is None:
            return True
        self.warn_and_quit(
            EXIT_INCOMPATIBLE_VERSION,
            "incompatible remote version '%s': %s" %
            (self._remote_version, version_error))
        return False

    def parse_server_capabilities(self):
        """ Nothing to parse here: always returns True.
            (server_connection_established() aborts if this returns a false value)
        """
        return True

    def parse_network_capabilities(self):
        """ Enables the packet encoder and the compressor of the protocol
            from the server capabilities.

            Returns False if there is no protocol or if the encoder
            could not be enabled, True otherwise.
        """
        caps = self.server_capabilities
        protocol = self._protocol
        if protocol and protocol.enable_encoder_from_caps(caps):
            protocol.enable_compressor_from_caps(caps)
            return True
        return False

    def parse_encryption_capabilities(self):
        """ Sets the packet aliases to use when sending and,
            if encryption is enabled, the cipher requested by the server.

            Returns False if there is no protocol
            or if the server cipher could not be set, True otherwise.
        """
        caps = self.server_capabilities
        protocol = self._protocol
        if not protocol:
            return False
        protocol.send_aliases = caps.dictget("aliases", {})
        if not self.encryption:
            return True
        #server uses a new cipher after second hello:
        key = self.get_encryption_key()
        assert key, "encryption key is missing"
        if self.set_server_encryption(caps, key):
            return True
        return False

    def _process_set_deflate(self, packet):
        """ Handler for the "set_deflate" packet: the packet is ignored. """
        #legacy, should not be used for anything
        pass

    def _process_startup_complete(self, packet):
        """ Handler for the "startup-complete" packet: the packet is ignored. """
        #can be received if we connect with "xpra stop" or other command line client
        #as the server is starting up
        pass

    def _process_send_file(self, packet):
        """ Handles a "send-file" packet: validates the file data received,
            saves it in the download directory using a filename
            which does not exist yet, then prints it or opens it if requested.

            Requests that fail validation (feature disabled, no data,
            size mismatch, file too large, digest mismatch)
            are logged as errors and ignored.
        """
        #packet: send-file, basefilename, mimetype, printit, openit, filesize, data, options
        from xpra.platform.features import DOWNLOAD_PATH
        #only unpack the 7 items used here,
        #so that extra trailing items cannot break the unpacking:
        basefilename, mimetype, printit, openit, filesize, file_data, options = packet[
            1:8]
        filelog("received file: %s", [
            basefilename, mimetype, printit, openit, filesize,
            "%s bytes" % len(file_data), options
        ])
        options = typedict(options)
        #use explicit checks rather than asserts,
        #as asserts are removed when running with "python -O":
        if printit:
            if not self.printing:
                log.error("Error: printing is disabled, ignoring file '%s'",
                          basefilename)
                return
        elif not self.file_transfer:
            log.error("Error: file transfers are disabled, ignoring file '%s'",
                      basefilename)
            return
        if not (filesize > 0 and file_data):
            log.error("Error: no data received for file '%s'", basefilename)
            return
        if len(file_data) != filesize:
            log.error("Error: invalid data size for file '%s'", basefilename)
            log.error(" received %i bytes, expected %i bytes", len(file_data),
                      filesize)
            return
        if filesize > self.file_size_limit * 1024 * 1024:
            log.error("Error: file '%s' is too large:", basefilename)
            log.error(" %iMB, the file size limit is %iMB",
                      filesize // 1024 // 1024, self.file_size_limit)
            return
        #check digest if present:
        digest = options.get("sha1")
        if digest:
            import hashlib
            u = hashlib.sha1()
            u.update(file_data)
            filelog("sha1 digest: %s - expected: %s", u.hexdigest(), digest)
            if digest != u.hexdigest():
                log.error("Error: invalid file digest %s (expected %s)",
                          u.hexdigest(), digest)
                return

        #make sure we use a filename that does not exist already:
        #(basename() removes any directory component sent by the server)
        wanted_filename = os.path.abspath(
            os.path.join(os.path.expanduser(DOWNLOAD_PATH),
                         os.path.basename(basefilename)))
        EXTS = {
            "application/postscript": "ps",
            "application/pdf": "pdf",
        }
        ext = EXTS.get(mimetype)
        if ext:
            #on some platforms (win32),
            #we want to force an extension
            #so that the file manager can display them properly when you double-click on them
            if not wanted_filename.endswith("." + ext):
                wanted_filename += "." + ext
        filename = wanted_filename
        root, suffix = os.path.splitext(wanted_filename)
        base = 0
        while os.path.exists(filename):
            filelog("cannot save file as %s: file already exists", filename)
            base += 1
            filename = root + ("-%s" % base) + suffix
        #O_EXCL: fail rather than overwrite a file created since the check above
        flags = os.O_CREAT | os.O_RDWR | os.O_EXCL
        #O_BINARY only exists on win32:
        flags |= getattr(os, "O_BINARY", 0)
        #the file object writes all the data,
        #whereas a single os.write() call may stop short:
        with os.fdopen(os.open(filename, flags), "wb") as f:
            f.write(file_data)
        filelog.info("downloaded %s bytes to %s file %s", filesize, mimetype,
                     filename)
        if printit:
            printer = options.strget("printer")
            title = options.strget("title")
            print_options = options.dictget("options") or {}
            #TODO: how do we print multiple copies?
            #copies = options.intget("copies")
            #whitelist of options we can forward:
            safe_print_options = dict((k, v) for k, v in print_options.items()
                                      if k in ("PageSize", "Resolution"))
            printlog("safe print options(%s) = %s", options,
                     safe_print_options)
            self._print_file(filename, printer, title, safe_print_options)
            return
        elif openit:
            #run the command in a new thread
            #so we can block waiting for the subprocess to exit
            #(ensures that we do reap the process)
            #"threading" exists in both python2 and python3 (unlike "thread"),
            #the daemon flag ensures that this thread never blocks the exit:
            from threading import Thread
            t = Thread(target=self._open_file, args=(filename, ))
            t.daemon = True
            t.start()

    def _print_file(self, filename, printer, title, options):
        """ Sends `filename` to `printer` using the platform printing code,
            then polls the print job every 10 seconds until it has finished
            or until 10 minutes have elapsed.

            The file is deleted afterwards if DELETE_PRINTER_FILE is set.
            An unknown printer is logged as an error and nothing is printed.
        """
        import time
        from xpra.platform.printing import print_files, printing_finished, get_printers
        printers = get_printers()
        if printer not in printers:
            printlog.error("Error: printer '%s' does not exist!", printer)
            printlog.error(" printers available: %s",
                           printers.keys() or "none")
            return

        def delfile():
            if not DELETE_PRINTER_FILE:
                return
            try:
                os.unlink(filename)
            except OSError as e:
                #only filesystem errors are expected here, record the reason:
                printlog("failed to delete print job file '%s': %s", filename, e)
            return False

        job = print_files(printer, [filename], title, options)
        printlog("printing %s, job=%s", filename, job)
        if job <= 0:
            printlog("printing failed and returned %i", job)
            delfile()
            return
        start = time.time()

        def check_printing_finished():
            #timer callback: returning True means it will be called again
            done = printing_finished(job)
            printlog("printing_finished(%s)=%s", job, done)
            if done:
                delfile()
                return False
            if time.time() - start > 10 * 60:
                printlog.warn("print job %s timed out", job)
                delfile()
                return False
            return True  #try again..

        #check once immediately, only use the timer if the job is still running:
        if check_printing_finished():
            self.timeout_add(10000, check_printing_finished)

    def _open_file(self, filename):
        """ Opens `filename` with the configured open command
            and waits for the command to terminate.

            The output of the command (up to 512 characters per stream)
            is logged, as warnings if the exit code is not zero.
            Does nothing but log a warning if opening files is disabled.
        """
        if not self.open_files:
            log.warn("Warning: opening files automatically is disabled,")
            log.warn(" ignoring uploaded file:")
            log.warn(" '%s'", filename)
            return
        import subprocess
        command = [self.open_command, filename]
        process = subprocess.Popen(command,
                                   stdin=subprocess.PIPE,
                                   stdout=subprocess.PIPE,
                                   stderr=subprocess.PIPE)
        out, err = process.communicate()
        exit_code = process.wait()
        filelog.info("opened file %s with %s, exit code: %s", filename,
                     self.open_command, exit_code)
        report = filelog.warn if exit_code != 0 else filelog
        if out:
            report("stdout=%s", nonl(out)[:512])
        if err:
            report("stderr=%s", nonl(err)[:512])

    def _process_gibberish(self, packet):
        """ Handles data which could not be parsed as a packet.

            `packet` is (packet type, message, data).
            If this is the very first thing received and it is printable,
            the data is logged as text (line by line for tracebacks).
            The client then quits: with EXIT_SSH_FAILURE if the data looks
            like a password or login prompt, EXIT_PACKET_FAILURE otherwise.
        """
        (_, message, data) = packet
        p = self._protocol
        #the protocol may be unset, don't fail when logging the packet count:
        packet_count = p.input_packetcount if p else 0
        show_as_text = bool(p) and packet_count == 0 and all(
            c in string.printable for c in bytestostr(data))
        if show_as_text:
            #looks like the first packet back is just text, print it:
            data = bytestostr(data)
            #don't use the value of find() as a boolean:
            #it is -1 (true) without a match and 0 (false) for a match at the start
            if "Traceback " in data:
                for x in data.split("\n"):
                    log.warn(x.strip("\r"))
            else:
                log.warn("Failed to connect, received: %s",
                         repr_ellipsized(data.strip("\n").strip("\r")))
        else:
            log.warn("Received uninterpretable nonsense: %s", message)
            log.warn(" packet no %i data: %s", packet_count,
                     repr_ellipsized(data))
        text = str(data)
        #"assword" must come after another character (ie: "Password" or "password"):
        if text.find("assword") > 0:
            self.warn_and_quit(
                EXIT_SSH_FAILURE,
                "Your ssh program appears to be asking for a password." +
                GOT_PASSWORD_PROMPT_SUGGESTION)
        elif "login" in text:
            self.warn_and_quit(
                EXIT_SSH_FAILURE,
                "Your ssh program appears to be asking for a username.\n"
                "Perhaps try using something like 'ssh:USER@host:display'?")
        else:
            self.quit(EXIT_PACKET_FAILURE)

    def _process_invalid(self, packet):
        """ Handles an invalid packet: logs the message and the data,
            then quits with EXIT_PACKET_FAILURE.

            `packet` must have exactly 3 items: (packet type, message, data).
        """
        (_, message, data) = packet
        log.info("Received invalid packet: %s", message)
        log(" data: %s", repr_ellipsized(data))
        self.quit(EXIT_PACKET_FAILURE)

    def process_packet(self, proto, packet):
        """ Dispatches `packet` to the handler registered for its type (packet[0]).

            Handlers found in `_packet_handlers` are called directly,
            those found in `_ui_packet_handlers` are scheduled via idle_add().
            Unknown packet types and errors raised whilst processing
            are logged, only KeyboardInterrupt is propagated.
        """
        #define those first so that they can always be used when logging errors:
        handler = None
        packet_type = None
        try:
            packet_type = packet[0]
            #integer packet types (aliases) must be used as they are,
            #anything else is converted to a string:
            if not isinstance(packet_type, int):
                packet_type = bytestostr(packet_type)
            handler = self._packet_handlers.get(packet_type)
            if handler:
                handler(packet)
                return
            handler = self._ui_packet_handlers.get(packet_type)
            if not handler:
                log.error("unknown packet type: %s", packet_type)
                return
            self.idle_add(handler, packet)
        except KeyboardInterrupt:
            raise
        except:
            #deliberately broad: a failing handler must never
            #propagate its error to the caller
            log.error(
                "Unhandled error while processing a '%s' packet from peer using %s",
                packet_type,
                handler,
                exc_info=True)
コード例 #17
0
class XpraClientBase(FileTransferHandler):
    """ Base class for Xpra clients.
        Provides the glue code for:
        * sending packets via Protocol
        * handling packets received via _process_packet
        For an actual implementation, look at:
        * GObjectXpraClient
        * xpra.client.gtk2.client
        * xpra.client.gtk3.client
    """
    def __init__(self):
        """ Initializes the FileTransferHandler base class,
            then the default attribute values (only the first time).
        """
        FileTransferHandler.__init__(self)
        #this may be called more than once,
        #skip doing internal init again:
        if not hasattr(self, "exit_code"):
            self.defaults_init()

    def defaults_init(self):
        """ Sets the default value of all the client attributes,
            logs the environment, registers the initial packet handlers
            and runs the sanity checks.
        """
        #skip warning when running the client
        from xpra import child_reaper
        child_reaper.POLL_WARNING = False
        getChildReaper()
        log("XpraClientBase.defaults_init() os.environ:")
        for name, value in os.environ.items():
            log(" %s=%s", name, nonl(value))
        #client state:
        self.exit_code = None
        self.exit_on_signal = False
        self.display_desc = {}
        #connection attributes:
        self.hello_extra = {}
        self.compression_level = 0
        self.display = self.username = None
        self.password = self.password_file = None
        self.password_sent = False
        self.encryption = self.encryption_keyfile = None
        self.server_padding_options = [DEFAULT_PADDING]
        self.quality, self.min_quality = -1, 0
        self.speed, self.min_speed = 0, -1
        self.printer_attributes = []
        self.send_printers_timer = None
        self.exported_printers = None
        self.can_shutdown_server = True
        #protocol stuff:
        self._protocol = None
        self._priority_packets = []
        self._ordinary_packets = []
        self._mouse_position = None
        self._aliases = {}
        self._reverse_aliases = {}
        #server state and caps:
        self.server_capabilities = None
        self.completed_startup = False
        for attribute in ("machine_id", "uuid", "version", "revision",
                          "platform", "platform_release",
                          "platform_platform", "platform_linux_distribution"):
            #ie: self._remote_version = None
            setattr(self, "_remote_%s" % attribute, None)
        self.uuid = get_user_uuid()
        self.init_packet_handlers()
        sanity_checks()

    def init(self, opts):
        """ Copies the connection, encryption and encoding settings
            from the `opts` object, and initializes the file transfer options.

            The crypto backend is initialized if encryption is enabled.
            If DETECT_LEAKS is set, a leak detection function is scheduled
            to run after 10 seconds.
        """
        for name in ("compression_level", "display", "username",
                     "password", "password_file"):
            setattr(self, name, getattr(opts, name))
        self.encryption = opts.encryption or opts.tcp_encryption
        if self.encryption:
            crypto_backend_init()
        self.encryption_keyfile = opts.encryption_keyfile or opts.tcp_encryption_keyfile
        for name in ("quality", "min_quality", "speed", "min_speed"):
            setattr(self, name, getattr(opts, name))
        #printing and file transfer:
        FileTransferHandler.init_opts(self, opts)

        if not DETECT_LEAKS:
            return
        from xpra.util import detect_leaks
        detailed = []
        #example: warning, uses ugly direct import:
        #try:
        #    from xpra.x11.bindings.ximage import XShmImageWrapper       #@UnresolvedImport
        #    detailed.append(XShmImageWrapper)
        #except:
        #    pass
        print_leaks = detect_leaks(log, detailed)
        self.timeout_add(10 * 1000, print_leaks)

    def timeout_add(self, *args):
        """ Abstract hook: must be overridden by subclasses.

            Raises NotImplementedError, like get_scheduler() does
            (it is an Exception subclass, so existing callers are unaffected).
        """
        raise NotImplementedError("override me!")

    def idle_add(self, *args):
        """ Abstract hook: must be overridden by subclasses.

            Raises NotImplementedError, like get_scheduler() does
            (it is an Exception subclass, so existing callers are unaffected).
        """
        raise NotImplementedError("override me!")

    def source_remove(self, *args):
        """ Abstract hook: must be overridden by subclasses.

            Raises NotImplementedError, like get_scheduler() does
            (it is an Exception subclass, so existing callers are unaffected).
        """
        raise NotImplementedError("override me!")

    def install_signal_handlers(self):
        """ Installs the handler for SIGTERM (and SIGINT with python2).

            The first signal received schedules a clean disconnection,
            and switches to a handler which exits immediately
            if another signal is received.
        """
        def deadly_signal(signum, frame):
            signame = SIGNAMES.get(signum, signum)
            sys.stderr.write("\ngot deadly signal %s, exiting\n" % signame)
            sys.stderr.flush()
            self.cleanup()
            os._exit(128 + signum)

        def app_signal(signum, frame):
            signame = SIGNAMES.get(signum, signum)
            sys.stderr.write("\ngot signal %s, exiting\n" % signame)
            sys.stderr.flush()
            #the next signal will exit without waiting:
            for sig in (signal.SIGINT, signal.SIGTERM):
                signal.signal(sig, deadly_signal)
            self.signal_cleanup()
            self.timeout_add(
                0, self.signal_disconnect_and_quit, 128 + signum,
                "exit on signal %s" % signame)

        if PYTHON2:
            #breaks GTK3..
            signal.signal(signal.SIGINT, app_signal)
        signal.signal(signal.SIGTERM, app_signal)

    def signal_disconnect_and_quit(self, exit_code, reason):
        """ Disconnects and quits after receiving a signal.

            The first call schedules the disconnection, quit and exit
            calls via idle_add(); any further call runs them directly
            and then terminates the process.
        """
        log("signal_disconnect_and_quit(%s, %s) exit_on_signal=%s", exit_code,
            reason, self.exit_on_signal)
        if self.exit_on_signal:
            #warning: this will run cleanup code from the signal handler
            self.disconnect_and_quit(exit_code, reason)
            self.quit(exit_code)
            self.exit()
            os._exit(exit_code)
            return
        #if we get another signal, we'll try to exit without idle_add...
        self.exit_on_signal = True
        scheduled = (
            (self.disconnect_and_quit, exit_code, reason),
            (self.quit, exit_code),
            (self.exit, ),
        )
        for call in scheduled:
            self.idle_add(*call)

    def signal_cleanup(self):
        """ Called from the signal handler installed by
            install_signal_handlers(), does nothing here.
        """
        #placeholder for stuff that can be cleaned up from the signal handler
        #(non UI thread stuff)
        pass

    def disconnect_and_quit(self, exit_code, reason):
        """ Sends a "disconnect" packet with `reason` to the server
            then quits with `exit_code`.

            Quits immediately if there is no protocol or if it is closed.
            Otherwise quit() is called when the packet has been flushed,
            and again after one second via a timer.
        """
        #make sure that we set the exit code early,
        #so the protocol shutdown won't set a different one:
        if self.exit_code is None:
            self.exit_code = exit_code
        #try to tell the server we're going, then quit
        log("disconnect_and_quit(%s, %s)", exit_code, reason)
        protocol = self._protocol
        if protocol is None or protocol._closed:
            self.quit(exit_code)
            return

        def protocol_closed():
            log("disconnect_and_quit: protocol_closed()")
            self.idle_add(self.quit, exit_code)

        if protocol:
            disconnect_packet = ["disconnect", reason]
            protocol.flush_then_close(disconnect_packet,
                                      done_callback=protocol_closed)
        self.timeout_add(1000, self.quit, exit_code)

    def exit(self):
        """ Terminates the interpreter by calling sys.exit() (raises SystemExit). """
        sys.exit()

    def client_type(self):
        """ Returns the name of this type of client: "Python" here. """
        #overridden in subclasses!
        return "Python"

    def get_scheduler(self):
        """ Abstract hook: must be overridden by subclasses.
            The value returned is given to the Protocol constructor
            by setup_connection().
        """
        raise NotImplementedError()

    def setup_connection(self, conn):
        """ Creates and configures the Protocol instance for `conn`.

            Sets the large packet types, compression level, aliases,
            default encoder and compressor and, if enabled, the cipher
            used for the first packet.
            have_more() is replaced with the protocol's source_has_more(),
            a connection check is scheduled if the connection has a timeout
            and the connection's process (if any) is given to the child reaper.
        """
        netlog("setup_connection(%s) timeout=%s", conn, conn.timeout)
        protocol = Protocol(self.get_scheduler(), conn,
                            self.process_packet, self.next_packet)
        self._protocol = protocol
        for packet_type in ("keymap-changed", "server-settings",
                            "logging", "input-devices"):
            protocol.large_packets.append(packet_type)
        protocol.set_compression_level(self.compression_level)
        protocol.receive_aliases.update(self._aliases)
        protocol.enable_default_encoder()
        protocol.enable_default_compressor()
        if self.encryption and ENCRYPT_FIRST_PACKET:
            key = self.get_encryption_key()
            protocol.set_cipher_out(self.encryption, DEFAULT_IV, key,
                                    DEFAULT_SALT, DEFAULT_ITERATIONS,
                                    INITIAL_PADDING)
        self.have_more = protocol.source_has_more
        if conn.timeout > 0:
            delay = (conn.timeout + EXTRA_TIMEOUT) * 1000
            self.timeout_add(delay, self.verify_connected)
        #ie: ssh is handled by another process
        process = getattr(conn, "process", None)
        if process:
            proc, name, command = process
            reaper = getChildReaper()
            reaper.add_process(proc, name, command, ignore=True, forget=False)
        netlog("setup_connection(%s) protocol=%s", conn, self._protocol)

    def remove_packet_handlers(self, *keys):
        """ Removes the handlers for the given packet types from both
            the regular and the UI packet handler dictionaries.
            Packet types that have no handler are ignored.
        """
        for k in keys:
            for d in (self._packet_handlers, self._ui_packet_handlers):
                #pop with a default ignores missing keys
                #without hiding unrelated errors behind a bare except:
                d.pop(k, None)

    def set_packet_handlers(self, to, defs):
        """ Adds the packet handlers from `defs` to the dictionary `to`,
            after removing any existing handlers with the same keys from
            both handler dictionaries
            (which can be useful for subclasses, not here)
        """
        log("set_packet_handlers(%s, %s)", to, defs)
        self.remove_packet_handlers(*defs)
        to.update(defs)

    def init_packet_handlers(self):
        """ Resets both handler dictionaries and registers the initial
            handlers: "hello" as a regular handler, the others as UI handlers.
        """
        self._packet_handlers = {}
        self._ui_packet_handlers = {}
        direct_handlers = {"hello": self._process_hello}
        self.set_packet_handlers(self._packet_handlers, direct_handlers)
        ui_handlers = {
            "challenge": self._process_challenge,
            "disconnect": self._process_disconnect,
            "set_deflate": self._process_set_deflate,
            "startup-complete": self._process_startup_complete,
            Protocol.CONNECTION_LOST: self._process_connection_lost,
            Protocol.GIBBERISH: self._process_gibberish,
            Protocol.INVALID: self._process_invalid,
        }
        self.set_packet_handlers(self._ui_packet_handlers, ui_handlers)

    def init_authenticated_packet_handlers(self):
        """ Registers the file transfer packet handlers
            (called from `server_connection_established`).
        """
        file_transfer_handlers = {
            "send-file": self._process_send_file,
            "ack-file-chunk": self._process_ack_file_chunk,
            "send-file-chunk": self._process_send_file_chunk,
        }
        self.set_packet_handlers(self._packet_handlers,
                                 file_transfer_handlers)

    def init_aliases(self):
        """ Assigns a numeric alias (starting at 1) to every packet type
            that has a handler: regular handlers first, then UI handlers.
            Fills in both `_aliases` (number to packet type)
            and `_reverse_aliases` (packet type to number).
        """
        packet_types = list(self._packet_handlers) + list(self._ui_packet_handlers)
        for alias, packet_type in enumerate(packet_types, 1):
            self._aliases[alias] = packet_type
            self._reverse_aliases[packet_type] = alias

    def has_password(self):
        """ Returns the first non-empty value of: the password,
            the password file, the XPRA_PASSWORD environment variable.
            The result is meant to be used as a boolean.
        """
        if self.password:
            return self.password
        if self.password_file:
            return self.password_file
        return os.environ.get('XPRA_PASSWORD')

    def send_hello(self, challenge_response=None, client_salt=None):
        """ Builds and sends the "hello" packet.

            When we have a password but no challenge response yet, only the
            base capabilities are sent together with a request for a
            challenge. Otherwise the full capabilities are sent, with the
            challenge response (and client salt) if there is one.
            Any error preparing the packet makes the client quit
            with EXIT_INTERNAL_ERROR.
        """
        try:
            hello = self.make_hello_base()
            request_challenge = self.has_password() and not challenge_response
            if request_challenge:
                #avoid sending the full hello: tell the server we want
                #a packet challenge first
                hello["challenge"] = True
            else:
                hello.update(self.make_hello())
        except InitExit as e:
            log.error("error preparing connection:")
            log.error(" %s", e)
            self.quit(EXIT_INTERNAL_ERROR)
            return
        except Exception as e:
            log.error("error preparing connection: %s", e, exc_info=True)
            self.quit(EXIT_INTERNAL_ERROR)
            return
        if challenge_response:
            assert self.has_password(), \
                "got a password challenge response but we don't have a password! (malicious or broken server?)"
            hello["challenge_response"] = challenge_response
            if client_salt:
                hello["challenge_client_salt"] = client_salt
        hex_response = binascii.hexlify(strtobytes(challenge_response or ""))
        log("send_hello(%s) packet=%s", hex_response, hello)
        self.send("hello", hello)

    def verify_connected(self):
        """ Timer callback scheduled by `setup_connection`:
            quits with EXIT_TIMEOUT if the server capabilities
            have still not been received.
        """
        if self.server_capabilities is not None:
            return
        #server has not said hello yet
        self.warn_and_quit(EXIT_TIMEOUT, "connection timed out")

    def make_hello_base(self):
        """ Returns the dictionary of capabilities which are always sent
            in the hello packet: network capabilities, version, identity
            and file transfer features, plus the cipher attributes
            when encryption is enabled.

            When encryption is enabled, this also configures the incoming
            cipher on the protocol. If the encryption key is None,
            `warn_and_quit` is called and None is returned.
        """
        caps = flatten_dict(get_network_caps())
        import struct
        #size of a pointer, in bits:
        pointer_bits = struct.calcsize("P") * 8
        caps.update({
            "version": XPRA_VERSION,
            "encoding.generic": True,
            "namespace": True,
            "hostname": socket.gethostname(),
            "uuid": self.uuid,
            "username": self.username,
            "name": get_name(),
            "client_type": self.client_type(),
            "python.version": sys.version_info[:3],
            "python.bits": pointer_bits,
            "compression_level": self.compression_level,
            "argv": sys.argv,
        })
        caps.update(self.get_file_transfer_features())
        if self.display:
            caps["display"] = self.display
        updict(caps, "build", self.get_version_info())
        machine_id = get_machine_id()
        if machine_id:
            caps["machine_id"] = machine_id

        if self.encryption:
            assert self.encryption in ENCRYPTION_CIPHERS
            iv = get_iv()
            key_salt = get_salt()
            iterations = get_iterations()
            padding = choose_padding(self.server_padding_options)
            cipher_caps = {
                "": self.encryption,
                "iv": iv,
                "key_salt": key_salt,
                "key_stretch_iterations": iterations,
                "padding": padding,
                "padding.options": PADDING_OPTIONS,
            }
            updict(caps, "cipher", cipher_caps)
            key = self.get_encryption_key()
            if key is None:
                self.warn_and_quit(EXIT_ENCRYPTION,
                                   "encryption key is missing")
                return
            self._protocol.set_cipher_in(self.encryption, iv, key, key_salt,
                                         iterations, padding)
            netlog("encryption capabilities: %s",
                   {k: v for k, v in caps.items() if k.startswith("cipher")})
        caps.update(self.hello_extra)
        return caps

    def get_version_info(self):
        """ Returns the version information from the module-level
            `get_version_info()` function (sent with the "build" prefix
            by `make_hello_base`).
        """
        version_info = get_version_info()
        return version_info

    def make_hello(self):
        """ Returns the extra capabilities for the full hello packet,
            including the packet aliases if there are any.
        """
        #only client.py cares about "randr_notify" and "windows":
        caps = {"randr_notify": False, "windows": False}
        aliases = self._reverse_aliases
        if aliases:
            caps["aliases"] = aliases
        return caps

    def compressed_wrapper(self, datatype, data, level=5):
        """ Wraps `data` for sending, compressed when worthwhile.

            Compression is only attempted when `level` is positive, the data
            is at least 256 bytes long and the server supports a compressor
            which is enabled locally; the compressed form is only used if it
            is smaller. Otherwise the data is wrapped uncompressed.
        """
        #FIXME: ugly assumptions here, should pass by name!
        use_zlib = "zlib" in self.server_compressors and compression.use_zlib
        use_lz4 = "lz4" in self.server_compressors and compression.use_lz4
        use_lzo = "lzo" in self.server_compressors and compression.use_lzo
        worth_trying = level > 0 and len(data) >= 256 and (use_zlib or use_lz4 or use_lzo)
        if worth_trying:
            wrapped = compression.compressed_wrapper(datatype, data,
                                                     level=level,
                                                     zlib=use_zlib,
                                                     lz4=use_lz4,
                                                     lzo=use_lzo,
                                                     can_inline=False)
            if len(wrapped) < len(data):
                #the compressed version is smaller, use it:
                return wrapped
        #we can't compress, so at least avoid warnings in the protocol layer:
        raw_datatype = "raw %s" % datatype
        return compression.Compressed(raw_datatype, data, can_inline=True)

    def send(self, *parts):
        """ Queues the packet (the tuple of `parts`) as an ordinary packet
            and signals that there is more data to send.
        """
        ordinary = self._ordinary_packets
        ordinary.append(parts)
        self.have_more()

    def send_now(self, *parts):
        """ Queues the packet (the tuple of `parts`) as a priority packet,
            (those are returned first by `next_packet`)
            and signals that there is more data to send.
        """
        priority = self._priority_packets
        priority.append(parts)
        self.have_more()

    def send_positional(self, packet):
        """ Queues `packet` as an ordinary packet and discards
            any pending mouse position packet.
        """
        #drop the pending mouse position:
        self._mouse_position = None
        self._ordinary_packets.append(packet)
        self.have_more()

    def send_mouse_position(self, packet):
        """ Stores `packet` as the pending mouse position packet
            (replacing the previous one, if any),
            and signals that there is more data to send.
        """
        self._mouse_position = packet
        self.have_more()

    def have_more(self):
        """ Tells the protocol that there are more packets to send,
            if there is a protocol and it has a source.
            This method is replaced by the protocol's `source_has_more`
            in `setup_connection`.
        """
        proto = self._protocol
        if not proto or not proto.source:
            return
        proto.source_has_more()

    def next_packet(self):
        """ Returns the next packet to send as a tuple:
            `(packet, None, None, has_more)`.

            Priority packets go first, then ordinary packets, then the
            pending mouse position. `packet` is None when there is nothing to
            send; `has_more` is True if a packet is returned and there is
            still something left to send after it.
        """
        #an empty priority list means we fall through to the ordinary one:
        queue = self._priority_packets or self._ordinary_packets
        if queue:
            packet = queue.pop(0)
        else:
            #may be None, in which case there is nothing to send:
            packet = self._mouse_position
            if packet is not None:
                self._mouse_position = None
        if packet is None:
            has_more = False
        else:
            has_more = bool(self._priority_packets) \
                or bool(self._ordinary_packets) \
                or self._mouse_position is not None
        return packet, None, None, has_more

    def cleanup(self):
        """ Releases the resources used by the client: the child reaper,
            printing, file transfers, and finally closes
            and forgets the protocol.
        """
        reaper_cleanup()
        #we must clean printing before FileTransferHandler, which turns the printing flag off!
        self.cleanup_printing()
        FileTransferHandler.cleanup(self)
        proto = self._protocol
        log("XpraClientBase.cleanup() protocol=%s", proto)
        if proto:
            log("calling %s", proto.close)
            proto.close()
            self._protocol = None
        log("cleanup done")
        dump_all_frames()

    def glib_init(self):
        """ Imports glib using `import_glib()`
            and calls its `threads_init()`.
        """
        import_glib().threads_init()

    def run(self):
        """ Starts the protocol (which must have been created already,
            see `setup_connection`).
        """
        proto = self._protocol
        proto.start()

    def quit(self, exit_code=0):
        """ Must be overridden by subclasses:
            this implementation always raises an exception.
        """
        raise Exception("override me!")

    def warn_and_quit(self, exit_code, message):
        """ Logs `message` as a warning then quits with `exit_code`. """
        log.warn(message)
        self.quit(exit_code)

    def send_shutdown_server(self):
        """ Sends the "shutdown-server" packet.
            Asserts that `can_shutdown_server` is set
            (see `parse_server_capabilities`).
        """
        assert self.can_shutdown_server
        self.send("shutdown-server")

    def _process_disconnect(self, packet):
        """ Handles a "disconnect" packet from the server,
            ie: ("disconnect", "version error", "incompatible version")

            Quits with EXIT_FAILURE if we have not received the server
            capabilities yet or if the reason is an error,
            with EXIT_OK otherwise.
        """
        reason = bytestostr(packet[1])
        extra_info = packet[2:]
        description = nonl(reason)
        if len(extra_info):
            details = ", ".join([nonl(bytestostr(x)) for x in extra_info])
            description += " (%s)" % details
        caps = self.server_capabilities
        no_hello = caps is None or len(caps) == 0
        if not no_hello and not disconnect_is_an_error(reason):
            if self.exit_code is None:
                #we're not in the process of exiting already,
                #tell the user why the server is disconnecting us
                log.info("server requested disconnect: %s", description)
            self.quit(EXIT_OK)
            return
        if no_hello:
            #server never sent hello to us - so disconnect is an error
            #(but we don't know which one - the info message may help)
            log.warn(
                "server failure: disconnected before the session could be established"
            )
        else:
            log.warn("server failure: %s", reason)
        self.warn_and_quit(EXIT_FAILURE,
                           "server requested disconnect: %s" % description)

    def _process_connection_lost(self, packet):
        p = self._protocol
        if p and p.input_raw_packetcount == 0:
            props = p.get_info()
            c = props.get("compression", "unknown")
            e = props.get("encoder", "unknown")
            netlog.error(
                "Error: failed to receive anything, not an xpra server?")
            netlog.error(
                "  could also be the wrong protocol, username, password or port"
            )
            if c != "unknown" or e != "unknown":
                netlog.error(
                    "  or maybe this server does not support '%s' compression or '%s' packet encoding?",
                    c, e)
        if self.exit_code != 0:
            self.warn_and_quit(EXIT_CONNECTION_LOST, "Connection lost")

    def _process_challenge(self, packet):
        """ Answers an authentication challenge from the server.

            The packet is used as follows:
            * packet[1]: the server salt
            * packet[2]: the server cipher attributes
              (only used when encryption is enabled)
            * packet[3]: the digest to use for the password ("hmac..." or "xor")
            * packet[4]: the digest used to combine the salts
              (optional, defaults to "xor")

            On success, the challenge response and the client salt are sent
            in a new hello packet and the "challenge" handler is removed.
            On failure, `disconnect_and_quit` or `warn_and_quit` is called
            and nothing is sent.
        """
        authlog("processing challenge: %s", packet[1:])

        def warn_server_and_exit(code,
                                 message,
                                 server_message="authentication failed"):
            authlog.error("Error: authentication failed:")
            authlog.error(" %s", message)
            self.disconnect_and_quit(code, server_message)

        if not self.has_password():
            warn_server_and_exit(
                EXIT_PASSWORD_REQUIRED,
                "this server requires authentication, please provide a password",
                "no password available")
            return
        password = self.load_password()
        if not password:
            warn_server_and_exit(
                EXIT_PASSWORD_FILE_ERROR,
                "failed to load password from file %s" % self.password_file,
                "no password available")
            return
        server_salt = packet[1]
        if self.encryption:
            assert len(
                packet
            ) >= 3, "challenge does not contain encryption details to use for the response"
            server_cipher = typedict(packet[2])
            key = self.get_encryption_key()
            if key is None:
                warn_server_and_exit(EXIT_ENCRYPTION,
                                     "the server does not use any encryption",
                                     "client requires encryption")
                return
            if not self.set_server_encryption(server_cipher, key):
                return
        #all server versions support a client salt,
        #they also tell us which digest to use:
        digest = bytestostr(packet[3])
        l = len(server_salt)
        salt_digest = "xor"
        if len(packet) >= 5:
            salt_digest = bytestostr(packet[4])
        if salt_digest == "xor":
            #with xor, we have to match the size
            assert l >= 16, "server salt is too short: only %i bytes, minimum is 16" % l
            assert l <= 256, "server salt is too long: %i bytes, maximum is 256" % l
        else:
            #other digest, 32 random bytes is enough:
            l = 32
        client_salt = get_salt(l)
        if salt_digest in ("xor", "des"):
            if not LEGACY_SALT_DIGEST:
                warn_server_and_exit(
                    EXIT_INCOMPATIBLE_VERSION,
                    "server uses legacy salt digest '%s'" % salt_digest,
                    "unsupported digest %s" % salt_digest)
                return
            log.warn(
                "Warning: server using legacy support for '%s' salt digest",
                salt_digest)
        salt = gendigest(salt_digest, client_salt, server_salt)
        authlog("combined %s salt(%s, %s)=%s", salt_digest,
                binascii.hexlify(server_salt), binascii.hexlify(client_salt),
                binascii.hexlify(salt))
        #"digest" has been converted with bytestostr() above,
        #so it must be compared with native strings, not bytes literals
        #(str.startswith(bytes) raises a TypeError with python3):
        if digest.startswith("hmac"):
            import hmac
            digestmod = get_digest_module(digest)
            if not digestmod:
                log("invalid digest module '%s'", digest)
                warn_server_and_exit(
                    EXIT_UNSUPPORTED,
                    "server requested digest '%s' but it is not supported" %
                    digest, "invalid digest")
                return
            password = strtobytes(password)
            salt = memoryview_to_bytes(salt)
            challenge_response = hmac.HMAC(password, salt,
                                           digestmod=digestmod).hexdigest()
            authlog("hmac.HMAC(%s, %s)=%s", binascii.hexlify(password),
                    binascii.hexlify(salt), challenge_response)
        elif digest == "xor":
            #don't send XORed password unencrypted:
            encrypted = self._protocol.cipher_out or self._protocol.get_info(
            ).get("type") == "ssl"
            local = self.display_desc.get("local", False)
            authlog("xor challenge, encrypted=%s, local=%s", encrypted, local)
            if local and ALLOW_LOCALHOST_PASSWORDS:
                pass
            elif not encrypted and not ALLOW_UNENCRYPTED_PASSWORDS:
                warn_server_and_exit(
                    EXIT_ENCRYPTION,
                    "server requested digest %s, cowardly refusing to use it without encryption"
                    % digest, "invalid digest")
                return
            #xor requires both values to have the same length:
            salt = salt[:len(password)]
            challenge_response = memoryview_to_bytes(xor(password, salt))
        else:
            warn_server_and_exit(
                EXIT_PASSWORD_REQUIRED,
                "server requested an unsupported digest: %s" % digest,
                "invalid digest")
            return
        if digest:
            #NOTE(review): this writes the password (hex encoded)
            #to the authentication debug log
            authlog("%s(%s, %s)=%s", digest, binascii.hexlify(password),
                    binascii.hexlify(salt),
                    binascii.hexlify(challenge_response))
        self.password_sent = True
        self.remove_packet_handlers("challenge")
        self.send_hello(challenge_response, client_salt)

    def set_server_encryption(self, caps, key):
        """ Configures the outgoing cipher of the protocol using the cipher
            attributes found in `caps` and the encryption `key`.

            Also records the padding options the server says it supports
            in `self.server_padding_options`.
            Returns True on success. If the cipher attributes are missing or
            unsupported, calls `warn_and_quit` with EXIT_ENCRYPTION and
            returns False. Also returns False if there is no protocol.
        """
        cipher = caps.strget("cipher")
        cipher_iv = caps.strget("cipher.iv")
        key_salt = caps.strget("cipher.key_salt")
        iterations = caps.intget("cipher.key_stretch_iterations")
        padding = caps.strget("cipher.padding", DEFAULT_PADDING)
        #server may tell us what it supports,
        #either from hello response or from challenge packet:
        self.server_padding_options = caps.strlistget("cipher.padding.options",
                                                      [DEFAULT_PADDING])
        if not cipher or not cipher_iv:
            self.warn_and_quit(
                EXIT_ENCRYPTION,
                "the server does not use or support encryption/password, cannot continue with %s cipher"
                % self.encryption)
            return False
        if cipher not in ENCRYPTION_CIPHERS:
            self.warn_and_quit(
                EXIT_ENCRYPTION,
                "unsupported server cipher: %s, allowed ciphers: %s" %
                (cipher, ", ".join(ENCRYPTION_CIPHERS)))
            return False
        if padding not in ALL_PADDING_OPTIONS:
            #the values listed are padding options, not ciphers:
            self.warn_and_quit(
                EXIT_ENCRYPTION,
                "unsupported server cipher padding: %s, allowed paddings: %s" %
                (padding, ", ".join(ALL_PADDING_OPTIONS)))
            return False
        p = self._protocol
        if not p:
            return False
        p.set_cipher_out(cipher, cipher_iv, key, key_salt, iterations, padding)
        return True

    def get_encryption_key(self):
        """ Returns the encryption key, stripped of newline characters.

            The key is loaded from the encryption key file, falling back to
            the XPRA_ENCRYPTION_KEY environment variable.
            Raises InitExit if no key is found.
        """
        key = load_binary_file(self.encryption_keyfile) \
            or os.environ.get('XPRA_ENCRYPTION_KEY')
        if key:
            return key.strip("\n\r")
        raise InitExit(1, "no encryption key")

    def load_password(self):
        """ Returns the password to use: the `password` attribute if set,
            otherwise the contents of the password file if one is
            configured, otherwise the XPRA_PASSWORD environment variable.
        """
        if self.password:
            return self.password
        if self.password_file:
            filename = os.path.expanduser(self.password_file)
            password = load_binary_file(filename)
            #only log a mask with the same length as the password:
            mask = "*" * len(password or "")
            netlog("password read from file %s is %s", self.password_file,
                   mask)
            return password
        return os.environ.get('XPRA_PASSWORD')

    def _process_hello(self, packet):
        """ Handles the "hello" packet from the server:
            stores the server capabilities (packet[1])
            and completes the connection setup.

            Quits with EXIT_NO_AUTHENTICATION if we have a password that the
            server never asked for, and with EXIT_FAILURE if the
            capabilities cannot be processed.
        """
        password_ignored = not self.password_sent and self.has_password()
        if password_ignored:
            self.warn_and_quit(EXIT_NO_AUTHENTICATION,
                               "the server did not request our password")
            return
        try:
            self.server_capabilities = typedict(packet[1])
            netlog("processing hello from server: %s",
                   self.server_capabilities)
            established = self.server_connection_established()
            if not established:
                self.warn_and_quit(EXIT_FAILURE,
                                   "failed to establish connection")
        except Exception as e:
            netlog.info("error in hello packet", exc_info=True)
            message = "error processing hello packet from server: %s" % e
            self.warn_and_quit(EXIT_FAILURE, message)

    def capsget(self, capabilities, key, default):
        """ Looks up `key` (converted with `strtobytes`) in `capabilities`,
            returning `default` if it is missing.
            With python3, bytes values are converted with `bytestostr`.
        """
        value = capabilities.get(strtobytes(key), default)
        if PYTHON3 and type(value) == bytes:
            return bytestostr(value)
        return value

    def server_connection_established(self):
        """ Parses the server capabilities once the hello packet
            has been received.

            Returns False as soon as one of the mandatory parsing steps
            fails (version, server, network, encryption), True otherwise.
            On success, the authenticated packet handlers are registered.
        """
        netlog("server_connection_established()")
        mandatory_steps = (
            (self.parse_version_capabilities,
             "server_connection_established() failed version capabilities"),
            (self.parse_server_capabilities,
             "server_connection_established() failed server capabilities"),
            (self.parse_network_capabilities,
             "server_connection_established() failed network capabilities"),
            (self.parse_encryption_capabilities,
             "server_connection_established() failed encryption capabilities"),
        )
        for parse, failure_message in mandatory_steps:
            if not parse():
                netlog(failure_message)
                return False
        self.parse_printing_capabilities()
        self.parse_logging_capabilities()
        self.parse_file_transfer_caps(self.server_capabilities)
        #raise packet size if required:
        if self.file_transfer:
            file_size_limit = self.file_size_limit * 1024 * 1024
            self._protocol.max_packet_size = max(
                self._protocol.max_packet_size, file_size_limit)
        netlog(
            "server_connection_established() adding authenticated packet handlers"
        )
        self.init_authenticated_packet_handlers()
        return True

    def parse_logging_capabilities(self):
        """ Called from `server_connection_established`,
            this implementation does nothing.
        """
        return None

    def parse_printing_capabilities(self):
        """ If printing is enabled on both the client and the server,
            records the printer attributes the server wants
            and schedules `init_printing` to run after one second.
        """
        printlog("parse_printing_capabilities() client printing support=%s",
                 self.printing)
        if not self.printing:
            return
        caps = self.server_capabilities
        server_printing = caps.boolget("printing")
        printlog(
            "parse_printing_capabilities() server printing support=%s",
            server_printing)
        if not server_printing:
            return
        default_attributes = ["printer-info", "device-uri"]
        self.printer_attributes = self.server_capabilities.strlistget(
            "printer.attributes", default_attributes)
        self.timeout_add(1000, self.init_printing)

    def init_printing(self):
        """ Initializes the platform printing support, registering
            `send_printers` as callback, then sends the list of printers.
            The printing flag is turned off if either step fails.
        """
        try:
            from xpra.platform.printing import init_printing as platform_init_printing
            printlog("init_printing=%s", platform_init_printing)
            platform_init_printing(self.send_printers)
        except Exception as e:
            printlog.error("Error initializing printing support:")
            printlog.error(" %s", e)
            self.printing = False
        #the list of printers is sent even if the initialization failed:
        try:
            self.do_send_printers()
        except Exception:
            printlog.error("Error sending the list of printers:",
                           exc_info=True)
            self.printing = False
        printlog("init_printing() enabled=%s", self.printing)

    def cleanup_printing(self):
        """ If printing is enabled: cancels the pending printers timer
            and cleans up the platform printing support.
            Errors are logged as warnings.
        """
        printlog("cleanup_printing() printing=%s", self.printing)
        if self.printing:
            self.cancel_send_printers_timer()
            try:
                from xpra.platform.printing import cleanup_printing as platform_cleanup_printing
                printlog("cleanup_printing=%s", platform_cleanup_printing)
                platform_cleanup_printing()
            except Exception:
                log.warn("failed to cleanup printing subsystem", exc_info=True)

    def send_printers(self, *args):
        """ Schedules `do_send_printers` to run in half a second,
            unless it is already scheduled. The arguments are only logged.
        """
        printlog("send_printers%s timer=%s", args, self.send_printers_timer)
        #dbus can fire dozens of times for a single printer change
        #so we wait a bit and fire via a timer to try to batch things together:
        if not self.send_printers_timer:
            self.send_printers_timer = self.timeout_add(500, self.do_send_printers)

    def cancel_send_printers_timer(self):
        """ Cancels the timer set by `send_printers`, if there is one. """
        timer = self.send_printers_timer
        printlog("cancel_send_printers_timer() send_printers_timer=%s", timer)
        if not timer:
            return
        self.send_printers_timer = None
        self.source_remove(timer)

    def do_send_printers(self):
        """ Sends the list of local printers to the server,
            if it has changed since the last time it was sent.

            Printers forwarded by xpra are skipped to avoid loops, stopped
            printers are skipped if SKIP_STOPPED_PRINTERS is set, and only
            the attributes listed in `self.printer_attributes` are exported
            (plus the mimetypes).
            Errors are logged and never propagated to the caller.
        """
        try:
            self.send_printers_timer = None
            from xpra.platform.printing import get_printers, get_mimetypes
            try:
                printers = get_printers()
            except Exception as e:
                printlog("%s", get_printers, exc_info=True)
                printlog.error("Error: cannot access the list of printers")
                printlog.error(" %s", e)
                return
            printlog("do_send_printers() found printers=%s", printers)
            #remove xpra-forwarded printers to avoid loops and multi-forwards,
            #also ignore stopped printers
            #and only keep the attributes that the server cares about (self.printer_attributes)
            exported_printers = {}

            def used_attrs(d):
                #filter attributes so that we only compare things that are actually used,
                #always returns a new dictionary so that adding the mimetypes
                #never modifies the data returned by get_printers()
                return dict((k, v) for k, v in (d or {}).items()
                            if k in self.printer_attributes)

            for k, v in printers.items():
                device_uri = v.get("device-uri", "")
                if device_uri:
                    #this is cups specific.. oh well
                    printlog("do_send_printers() device-uri(%s)=%s", k,
                             device_uri)
                    if device_uri.startswith("xpraforwarder"):
                        printlog(
                            "do_send_printers() skipping xpra forwarded printer=%s",
                            k)
                        continue
                state = v.get("printer-state")
                #"3" if the destination is idle,
                #"4" if the destination is printing a job,
                #"5" if the destination is stopped.
                if state == 5 and SKIP_STOPPED_PRINTERS:
                    printlog("do_send_printers() skipping stopped printer=%s",
                             k)
                    continue
                attrs = used_attrs(v)
                #add mimetypes:
                attrs["mimetypes"] = get_mimetypes()
                exported_printers[k.encode("utf8")] = attrs
            if self.exported_printers is None:
                #not been sent yet, ensure we can use the dict below:
                self.exported_printers = {}
            elif exported_printers == self.exported_printers:
                printlog("do_send_printers() exported printers unchanged: %s",
                         self.exported_printers)
                return
            #show summary of what has changed:
            added = [
                k for k in exported_printers.keys()
                if k not in self.exported_printers
            ]
            if added:
                printlog("do_send_printers() new printers: %s", added)
            removed = [
                k for k in self.exported_printers.keys()
                if k not in exported_printers
            ]
            if removed:
                printlog("do_send_printers() printers removed: %s", removed)
            modified = [
                k for k, v in exported_printers.items()
                if self.exported_printers.get(k) != v and k not in added
            ]
            if modified:
                printlog("do_send_printers() printers modified: %s", modified)
            printlog("do_send_printers() printers=%s",
                     exported_printers.keys())
            printlog("do_send_printers() exported printers=%s",
                     ", ".join(str(x) for x in exported_printers.keys()))
            self.exported_printers = exported_printers
            self.send("printers", self.exported_printers)
        except Exception:
            #don't use a bare except here:
            #KeyboardInterrupt and SystemExit must not be swallowed
            printlog.error("do_send_printers()", exc_info=True)

    def parse_version_capabilities(self):
        """ Records the identity, version and platform details found in the
            server capabilities, then verifies that the server version is
            compatible with ours.
            Returns True if it is; calls `warn_and_quit` with
            EXIT_INCOMPATIBLE_VERSION and returns False if not.
        """
        caps = self.server_capabilities
        self._remote_machine_id = caps.strget("machine_id")
        self._remote_uuid = caps.strget("uuid")
        #prefer the "build." prefixed values:
        self._remote_version = caps.strget("build.version",
                                           caps.strget("version"))
        self._remote_revision = caps.strget("build.revision",
                                            caps.strget("revision"))
        self._remote_platform = caps.strget("platform")
        self._remote_platform_release = caps.strget("platform.release")
        self._remote_platform_platform = caps.strget("platform.platform")
        #linux distribution is a tuple of different types, ie: ('Linux Fedora' , 20, 'Heisenbug')
        distribution = caps.listget("platform.linux_distribution")
        if distribution and len(distribution) == 3:
            #keep integers as they are, convert everything else to strings:
            self._remote_platform_linux_distribution = [
                x if type(x) == int else bytestostr(x)
                for x in distribution
            ]
        version_error = version_compat_check(self._remote_version)
        if version_error is None:
            return True
        self.warn_and_quit(
            EXIT_INCOMPATIBLE_VERSION,
            "incompatible remote version '%s': %s" %
            (self._remote_version, version_error))
        return False

    def parse_server_capabilities(self):
        """ Records whether the server lets clients shut it down
            (the "client-shutdown" capability, True if unspecified).
            Always returns True.
        """
        caps = self.server_capabilities
        self.can_shutdown_server = caps.boolget("client-shutdown", True)
        return True

    def parse_network_capabilities(self):
        """ Enables the packet encoder and the compressor on the protocol,
            based on the server capabilities.
            Returns False if there is no protocol
            or if the encoder cannot be enabled, True otherwise.
        """
        caps = self.server_capabilities
        proto = self._protocol
        if not proto:
            return False
        if not proto.enable_encoder_from_caps(caps):
            return False
        proto.enable_compressor_from_caps(caps)
        return True

    def parse_encryption_capabilities(self):
        """ Records the packet aliases sent by the server and, if encryption
            is enabled, configures the outgoing cipher from the server
            capabilities.
            Returns False if there is no protocol or if the cipher
            cannot be configured, True otherwise.
        """
        caps = self.server_capabilities
        proto = self._protocol
        if not proto:
            return False
        proto.send_aliases = caps.dictget("aliases", {})
        if not self.encryption:
            return True
        #server uses a new cipher after second hello:
        key = self.get_encryption_key()
        assert key, "encryption key is missing"
        if self.set_server_encryption(caps, key):
            return True
        return False

    def _process_set_deflate(self, packet):
        #legacy, should not be used for anything
        pass

    def _process_startup_complete(self, packet):
        #can be received if we connect with "xpra stop" or other command line client
        #as the server is starting up
        self.completed_startup = packet

    def _process_gibberish(self, packet):
        """ Handles data received from the server which could not be parsed
            as a packet: logs a warning then quits with EXIT_PACKET_FAILURE.

            `packet` must be a 3-tuple: (packet type, message, data).
            If this is the first thing received and it is printable text,
            the text itself is shown since it is likely to be an error
            message (ie: a traceback).
        """
        (_, message, data) = packet
        p = self._protocol
        show_as_text = p and p.input_packetcount == 0 and all(
            c in string.printable for c in bytestostr(data))
        if show_as_text:
            #looks like the first packet back is just text, print it:
            data = bytestostr(data)
            if data.find("Traceback ") >= 0:
                for x in data.split("\n"):
                    netlog.warn(x.strip("\r"))
            else:
                netlog.warn("Failed to connect, received: %s",
                            repr_ellipsized(data.strip("\n").strip("\r")))
        else:
            netlog.warn("Received uninterpretable nonsense: %s", message)
            if p:
                netlog.warn(" packet no %i data: %s", p.input_packetcount,
                            repr_ellipsized(data))
            else:
                #the protocol is gone (see cleanup), so there is no packet
                #counter, but we must still reach quit() below:
                netlog.warn(" data: %s", repr_ellipsized(data))
        self.quit(EXIT_PACKET_FAILURE)

    def _process_invalid(self, packet):
        """ Handles an invalid packet: logs it
            then quits with EXIT_PACKET_FAILURE.
            `packet` must be a 3-tuple: (packet type, message, data).
        """
        _, message, data = packet
        netlog.info("Received invalid packet: %s", message)
        ellipsized = repr_ellipsized(data)
        netlog(" data: %s", ellipsized)
        self.quit(EXIT_PACKET_FAILURE)

    def process_packet(self, proto, packet):
        """ Dispatches a packet received from the server to its handler.

            The packet type is `packet[0]`. Regular handlers are called
            directly, UI handlers are scheduled via `idle_add`.
            Unknown packet types and handler errors are logged,
            only KeyboardInterrupt is propagated to the caller.
        """
        #these must be defined before anything can fail,
        #since they are used by the error handler below:
        handler = None
        packet_type = None
        try:
            packet_type = packet[0]
            #only convert packet types which are not integers:
            #(the previous test compared the value with the type object itself,
            # which is always different)
            if not isinstance(packet_type, int):
                packet_type = bytestostr(packet_type)
            handler = self._packet_handlers.get(packet_type)
            if handler:
                handler(packet)
                return
            handler = self._ui_packet_handlers.get(packet_type)
            if not handler:
                netlog.error("unknown packet type: %s", packet_type)
                return
            self.idle_add(handler, packet)
        except KeyboardInterrupt:
            raise
        except:
            #deliberately broad: a failing handler must not stop
            #the processing of the packets that follow
            netlog.error(
                "Unhandled error while processing a '%s' packet from peer using %s",
                packet_type,
                handler,
                exc_info=True)
コード例 #18
0
class ProxyInstanceProcess(Process):
    def __init__(self, uid, gid, env_options, session_options, socket_dir,
                 video_encoder_modules, csc_modules, client_conn, client_state,
                 cipher, encryption_key, server_conn, caps, message_queue):
        """
        Record the settings of this proxy instance and reset its runtime state.

        The process is named after the client connection.
        The session options are passed through sanitize_session_options()
        before being stored, every other argument is stored unmodified.
        The queues, protocols, threads and video helper attributes
        all start out unset (None).
        """
        Process.__init__(self, name=str(client_conn))
        #settings received from the caller:
        self.uid, self.gid = uid, gid
        self.env_options = env_options
        self.session_options = self.sanitize_session_options(session_options)
        self.socket_dir = socket_dir
        self.video_encoder_modules, self.csc_modules = video_encoder_modules, csc_modules
        self.client_conn, self.client_state = client_conn, client_state
        self.cipher, self.encryption_key = cipher, encryption_key
        self.server_conn = server_conn
        self.caps = caps
        #NOTE(review): the encryption key is part of this log message - confirm this is acceptable
        log("ProxyProcess%s", (uid, gid, client_conn, client_state, cipher,
                               encryption_key, server_conn, "{..}"))
        #runtime state:
        self.client_protocol = self.server_protocol = None
        self.exit = False
        self.main_queue = None
        self.message_queue = message_queue
        #the encode queue holds draw packets to encode:
        self.encode_queue = self.encode_thread = None
        #video encoding state:
        self.video_encoding_defs = None
        self.video_encoders = None
        self.video_encoders_last_used_time = None
        self.video_encoder_types = None
        self.video_helper = None
        self.lost_windows = None
        #local unix domain socket used for control:
        self.control_socket = None
        self.control_socket_thread = None
        self.control_socket_path = None
        self.potential_protocols = []
        self.max_connections = MAX_CONCURRENT_CONNECTIONS

    def server_message_queue(self):
        """
        Consume the messages found on self.message_queue until "stop" arrives.

        Every message is logged; once the "stop" message has been received,
        stop() is called and this method returns.
        """
        stop_requested = False
        while not stop_requested:
            log("waiting for server message on %s", self.message_queue)
            message = self.message_queue.get()
            log.info("proxy server message: %s", message)
            stop_requested = message == "stop"
        self.stop("proxy server request")

    def signal_quit(self, signum, frame):
        """
        Signal handler: flag the process as exiting and stop it.

        SIGINT and SIGTERM are re-bound to deadly_signal before calling stop(),
        the signal name (or the signal number if the name is not
        found in SIGNAMES) is given to stop() as the reason.
        """
        signame = SIGNAMES.get(signum, signum)
        log.info("")
        log.info("proxy process pid %s got signal %s, exiting", os.getpid(),
                 signame)
        self.exit = True
        for sig in (signal.SIGINT, signal.SIGTERM):
            signal.signal(sig, deadly_signal)
        self.stop(signame)

    def idle_add(self, fn, *args, **kwargs):
        """
        Emulate gobject's idle_add using a simple queue:
        the (function, positional arguments, keyword arguments) tuple
        is placed on self.main_queue.
        """
        work_item = (fn, args, kwargs)
        self.main_queue.put(work_item)

    def timeout_add(self, timeout, fn, *args, **kwargs):
        """
        Emulate gobject's timeout_add using idle_add and a Timer.

        After 'timeout' milliseconds, a wrapper for 'fn' is handed to idle_add().
        When the wrapper runs, it calls fn(*args, **kwargs) and schedules
        a new timeout with the same arguments if the return value is true.
        """
        def call_and_rearm():
            #re-schedule ourselves for as long as the callback returns a true value:
            if fn(*args, **kwargs):
                self.timeout_add(timeout, fn, *args, **kwargs)
            return False

        delay_seconds = timeout / 1000.0
        #the timer thread only queues the call, using idle_add:
        timer = Timer(delay_seconds, lambda: self.idle_add(call_and_rearm))
        timer.start()

    def run(self):
        """
        Main entry point of the proxy process.

        In order, this method:
        * switches to the gid and uid this instance was created with,
        * updates the environment with env_options and initializes video,
        * registers the SIGTERM and SIGINT handlers,
        * starts the server message queue thread and,
          if the control socket could be created, the control socket thread,
        * wraps the client and server connections in Protocol instances
          and starts them as well as the encode thread,
        * forwards the (filtered) hello packet to the server,
        * runs the main queue until it returns.
        A KeyboardInterrupt raised whilst running the queue calls stop().
        """
        log("ProxyProcess.run() pid=%s, uid=%s, gid=%s", os.getpid(),
            os.getuid(), os.getgid())
        #change uid and gid:
        #(gid first: setgid() would normally fail once the uid has been dropped)
        #NOTE(review): supplementary groups are not modified here - confirm this is intended
        if os.getgid() != self.gid:
            os.setgid(self.gid)
        if os.getuid() != self.uid:
            os.setuid(self.uid)
        log("ProxyProcess.run() new uid=%s, gid=%s", os.getuid(), os.getgid())

        if self.env_options:
            #TODO: whitelist env update?
            os.environ.update(self.env_options)
        self.video_init()

        log.info("new proxy started for client %s and server %s",
                 self.client_conn, self.server_conn)

        #create the main queue before registering the signal handlers
        #and before starting any thread: idle_add() needs it
        #and must not find it unset if it is called early
        self.main_queue = Queue()

        signal.signal(signal.SIGTERM, self.signal_quit)
        signal.signal(signal.SIGINT, self.signal_quit)
        log("registered signal handler %s", self.signal_quit)

        make_daemon_thread(self.server_message_queue,
                           "server message queue").start()

        if self.create_control_socket():
            self.control_socket_thread = make_daemon_thread(
                self.control_socket_loop, "control")
            self.control_socket_thread.start()

        #setup protocol wrappers:
        self.server_packets = Queue(PROXY_QUEUE_SIZE)
        self.client_packets = Queue(PROXY_QUEUE_SIZE)
        self.client_protocol = Protocol(self, self.client_conn,
                                        self.process_client_packet,
                                        self.get_client_packet)
        self.client_protocol.restore_state(self.client_state)
        self.server_protocol = Protocol(self, self.server_conn,
                                        self.process_server_packet,
                                        self.get_server_packet)
        #server connection tweaks:
        for packet_type in ("draw", "window-icon", "keymap-changed",
                            "server-settings"):
            self.server_protocol.large_packets.append(packet_type)
        self.server_protocol.set_compression_level(
            self.session_options.get("compression_level", 0))

        self.lost_windows = set()
        self.encode_queue = Queue()
        self.encode_thread = make_daemon_thread(self.encode_loop, "encode")
        self.encode_thread.start()

        log("starting network threads")
        self.server_protocol.start()
        self.client_protocol.start()

        #forward the hello packet:
        hello_packet = ("hello", self.filter_client_caps(self.caps))
        self.queue_server_packet(hello_packet)
        self.timeout_add(VIDEO_TIMEOUT * 1000, self.timeout_video_encoders)

        try:
            self.run_queue()
        except KeyboardInterrupt as e:
            #"except X as e" is valid from Python 2.6 onwards and in Python 3
            self.stop(str(e))
        finally:
            log("ProxyProcess.run() ending %s", os.getpid())

    def video_init(self):
        """
        Initialize the video helper and record the encoders we can proxy for.

        Only the video encoder modules are given to the helper.
        For every encoding and every plain RGB colorspace
        (BGRX, BGRA, RGBX or RGBA), the properties of each encoder spec
        are stored in self.video_encoding_defs.
        self.video_encoder_types ends up with the codec types found,
        those listed in PREFERRED_ENCODER_ORDER first.
        """
        log("video_init() will try video encoders: %s",
            self.video_encoder_modules)
        helper = getVideoHelper()
        self.video_helper = helper
        #no CSC support in the proxy, so only video encoders are configured:
        helper.set_modules(video_encoders=self.video_encoder_modules)
        helper.init()

        self.video_encoding_defs = {}
        self.video_encoders = {}
        self.video_encoders_dst_formats = []
        self.video_encoders_last_used_time = {}
        self.video_encoder_types = []

        #work out which encoders we want to proxy for (if any):
        rgb_colorspaces = ("BGRX", "BGRA", "RGBX", "RGBA")
        found_types = set()
        for encoding in self.video_helper.get_encodings():
            specs_by_colorspace = self.video_helper.get_encoder_specs(encoding)
            for colorspace, especs in specs_by_colorspace.items():
                #we only deal with encoders that can handle plain RGB directly
                if colorspace in rgb_colorspaces:
                    for spec in especs:  #ie: codec_spec("x264")
                        props = spec.to_dict()
                        #the codec class is not serializable:
                        del props["codec_class"]
                        #a score boost so that we are used ahead of other encoders:
                        props["score_boost"] = 50
                        #limit the number of video streams we proxy for to 3
                        #(we really want 2, but because of races with
                        # garbage collection we need to allow more)
                        props["max_instances"] = 3
                        #record it in the encoding definitions:
                        colorspace_defs = self.video_encoding_defs.setdefault(
                            encoding, {})
                        colorspace_defs.setdefault(colorspace, []).append(props)
                        found_types.add(spec.codec_type)

        log("encoder types found: %s", found_types)
        #preferred types come first, followed by any other type found:
        ranked = PREFERRED_ENCODER_ORDER[:]
        ranked += [t for t in found_types if t not in ranked]
        self.video_encoder_types = [t for t in ranked if t in found_types]
        log.info("proxy video encoders: %s", self.video_encoder_types)

    def create_control_socket(self):
        """
        Create the unix domain socket used for controlling this proxy instance.

        The socket is named "proxy-<pid>" and its path is obtained from
        DotXpra using self.socket_dir.
        On success, the socket and its path are stored in
        self.control_socket and self.control_socket_path.

        Returns True if the socket was created and is listening,
        False if the server state for this path is LIVE or UNKNOWN
        or if the socket could not be set up (a warning is logged in both cases).
        """
        dotxpra = DotXpra(self.socket_dir)
        name = "proxy-%s" % os.getpid()
        sockpath = dotxpra.norm_make_path(name, dotxpra.sockdir())
        state = dotxpra.get_server_state(sockpath)
        if state in (DotXpra.LIVE, DotXpra.UNKNOWN):
            log.warn(
                "You already have a proxy server running at %s, the control socket will not be created!",
                sockpath)
            return False
        sock = None
        try:
            sock = create_unix_domain_socket(sockpath, None)
            sock.listen(5)
        except Exception as e:
            #"except X as e" is valid from Python 2.6 onwards and in Python 3
            log.warn("failed to setup control socket %s: %s", sockpath, e)
            if sock is not None:
                #listen() failed after the socket was created: don't leak it
                try:
                    sock.close()
                except Exception as close_error:
                    log("error closing control socket %s: %s", sockpath,
                        close_error)
            return False
        self.control_socket = sock
        self.control_socket_path = sockpath
        log.info(
            "proxy instance now also available using unix domain socket: %s",
            self.control_socket_path)
        return True