Exemplo n.º 1
0
 def terminate_queue_threads(self):
     """
         Ask the queue based threads of this protocol to exit.

         The packet source callback is cleared and the `_source_has_more`
         event is set so that the format thread can exit.
         The write and read queues are then both replaced with one shared
         queue pre-filled with `None` exit markers; each old queue is
         drained and given a final `None` marker, for any thread which
         still holds a reference to it.

         This is best-effort: failures are ignored so that both queues
         are always processed.
     """
     log("terminate_queue_threads()")
     #the format thread will exit:
     self._get_packet_cb = None
     self._source_has_more.set()
     #make all the queue based threads exit by adding the empty marker:
     exit_queue = Queue()
     for _ in range(10):  #just 2 should be enough!
         exit_queue.put(None)

     def swap_and_drain(attr_name):
         #replace the queue attribute with the exit queue,
         #then discard all elements in the old queue and push the None marker
         try:
             old_queue = getattr(self, attr_name)
             setattr(self, attr_name, exit_queue)
             try:
                 while old_queue.qsize() > 0:
                     #non-blocking: another thread may empty the queue first
                     old_queue.get(False)
             except Exception:
                 pass
             old_queue.put_nowait(None)
         except Exception:
             #missing attribute, or a bounded queue that was refilled
             pass

     swap_and_drain("_write_queue")
     swap_and_drain("_read_queue")
     #just in case the read thread is waiting again:
     self._source_has_more.set()
Exemplo n.º 2
0
 def __init__(self, src_type=None, src_options={}, codecs=get_codecs(), codec_options={}, volume=1.0):
     """
         Build and wire up the sound capture pipeline.

         The pipeline elements are: source [, audioconvert], volume,
         encoder, muxer and APPSINK.

         src_type: name of the source plugin. When empty, a pulseaudio
             monitor device is looked up and used with "pulsesrc", or
             "audiotestsrc" is used if no device is found. In both of
             those cases `src_options` is replaced with matching defaults.
         src_options: options for the source plugin.
         codecs: the codecs accepted by the caller; the first entry of
             CODEC_ORDER found both here and in get_codecs() is used.
         codec_options: options for the encoder, ENCODER_DEFAULT_OPTIONS
             is used when this is empty.
         volume: initial value given to the "volume" element.

         Raises InitExit if the source plugin is not valid or if no
         codec matches.
     """
     #NOTE(review): the default arguments are evaluated once, at definition
     #time, and shared between calls - they are only read here, never modified
     if not src_type:
         from xpra.sound.pulseaudio_util import get_pa_device_options
         monitor_devices = get_pa_device_options(True, False)
         log.info("found pulseaudio monitor devices: %s", monitor_devices)
         if not monitor_devices:
             log.warn("could not detect any pulseaudio monitor devices")
             log.warn(" a test source will be used instead")
             src_type = "audiotestsrc"
             default_src_options = {"wave":2, "freq":100, "volume":0.4}
         else:
             #use the key of the first entry,
             #items() cannot be indexed directly with Python 3:
             monitor_device = next(iter(monitor_devices.items()))[0]
             log.info("using pulseaudio source device:")
             log.info(" '%s'", monitor_device)
             src_type = "pulsesrc"
             default_src_options = {"device" : monitor_device}
         src_options = default_src_options
     if src_type not in get_source_plugins():
         raise InitExit(1, "invalid source plugin '%s', valid options are: %s" % (src_type, ",".join(get_source_plugins())))
     matching = [x for x in CODEC_ORDER if (x in codecs and x in get_codecs())]
     log("SoundSource(..) found matching codecs %s", matching)
     if not matching:
         raise InitExit(1, "no matching codecs between arguments '%s' and supported list '%s'" % (csv(codecs), csv(get_codecs().keys())))
     codec = matching[0]
     encoder, fmt = get_encoder_formatter(codec)
     SoundPipeline.__init__(self, codec)
     self.src_type = src_type
     source_str = plugin_str(src_type, src_options)
     #FIXME: this is ugly and relies on the fact that we don't pass any codec options to work!
     encoder_str = plugin_str(encoder, codec_options or ENCODER_DEFAULT_OPTIONS.get(encoder, {}))
     fmt_str = plugin_str(fmt, MUXER_DEFAULT_OPTIONS.get(fmt, {}))
     pipeline_els = [source_str]
     if encoder in ENCODER_NEEDS_AUDIOCONVERT or src_type in SOURCE_NEEDS_AUDIOCONVERT:
         pipeline_els += ["audioconvert"]
     pipeline_els.append("volume name=volume volume=%s" % volume)
     pipeline_els += [encoder_str,
                     fmt_str,
                     APPSINK]
     self.setup_pipeline_and_bus(pipeline_els)
     self.volume = self.pipeline.get_by_name("volume")
     self.sink = self.pipeline.get_by_name("sink")
     try:
         #the property name depends on the gstreamer version:
         if get_gst_version()<(1,0):
             self.sink.set_property("enable-last-buffer", False)
         else:
             self.sink.set_property("enable-last-sample", False)
     except Exception as e:
         log("failed to disable last buffer: %s", e)
     self.caps = None
     self.skipped_caps = set()
     if JITTER>0:
         self.jitter_queue = Queue()
     try:
         #Gst 1.0:
         self.sink.connect("new-sample", self.on_new_sample)
         self.sink.connect("new-preroll", self.on_new_preroll1)
     except Exception:
         #Gst 0.10:
         self.sink.connect("new-buffer", self.on_new_buffer)
         self.sink.connect("new-preroll", self.on_new_preroll0)
Exemplo n.º 3
0
    def __init__(self, core_encodings, encodings, default_encoding,
                 scaling_control, default_quality, default_min_quality,
                 default_speed, default_min_speed):
        """
            Store the server side encoding defaults, initialize the
            per-client encoding state and start the "encode" thread.

            All the arguments are stored unchanged on the instance.
        """
        ctor_args = (core_encodings, encodings, default_encoding, scaling_control,
                     default_quality, default_min_quality, default_speed,
                     default_min_speed)
        log("ServerSource%s", ctor_args)
        self.server_core_encodings = core_encodings
        self.server_encodings = encodings
        self.default_encoding = default_encoding
        self.scaling_control = scaling_control

        #default encoding quality for lossy encodings:
        self.default_quality = default_quality
        #default minimum encoding quality:
        self.default_min_quality = default_min_quality
        #encoding speed (only used by x264):
        self.default_speed = default_speed
        #default minimum encoding speed:
        self.default_min_speed = default_min_speed

        #contains default values, some of which may be supplied by the client:
        self.default_batch_config = DamageBatchConfig()
        #global batch config:
        self.global_batch_config = self.default_batch_config.clone()

        self.vrefresh = -1
        self.supports_transparency = False
        #the default encoding for all windows:
        self.encoding = None
        #all the encodings supported by the client:
        self.encodings = ()
        self.core_encodings = ()
        self.window_icon_encodings = ["premult_argb32"]
        self.rgb_formats = ("RGB", )
        self.encoding_options = typedict()
        self.icons_encoding_options = typedict()
        self.default_encoding_options = typedict()
        self.auto_refresh_delay = 0

        #compressors:
        self.zlib = True
        self.lz4 = use_lz4
        self.lzo = use_lzo

        #for managing the recalculate_delays work:
        self.calculate_window_pixels = {}
        self.calculate_window_ids = set()
        self.calculate_timer = self.calculate_last_time = 0

        #if we "proxy video", we will modify the video helper to add
        #new encoders, so we must make a deep copy to preserve the original
        #which may be used by other clients (other ServerSource instances)
        self.video_helper = getVideoHelper().clone()

        #the queues of damage requests we work through.
        #'encode_work_queue' holds functions to call to compress data
        #(pixels, clipboard); the items placed in this queue are picked off
        #by the "encode" thread, and those functions should add the packets
        #they generate to the 'packet_queue':
        self.encode_work_queue = Queue()
        #'packet_queue' holds actual packets ready for sending (already encoded),
        #these packets are picked off by the "protocol" via 'next_packet()'
        #format: packet, wid, pixels, start_send_cb, end_send_cb
        #(only packet is required - the rest can be 0/None for clipboard packets)
        self.packet_queue = deque()
        self.encode_thread = start_thread(self.encode_loop, "encode")
Exemplo n.º 4
0
    def run(self):
        """
            Entry point of the proxy instance.

            Switches uid/gid, applies the environment options, registers
            the signal handlers, creates the control socket, the queues
            and both protocol wrappers (client and server), starts the
            threads, sends the hello packet and then runs the main queue
            until it ends.
            Returns early if the control socket cannot be created.
        """
        log("ProxyProcess.run() pid=%s, uid=%s, gid=%s", os.getpid(), getuid(), getgid())
        setuidgid(self.uid, self.gid)
        env_overrides = self.env_options
        if env_overrides:
            #TODO: whitelist env update?
            os.environ.update(env_overrides)
        self.video_init()

        log.info("new proxy instance started")
        log.info(" for client %s", self.client_conn)
        log.info(" and server %s", self.server_conn)

        for signum in (signal.SIGTERM, signal.SIGINT):
            signal.signal(signum, self.signal_quit)
        log("registered signal handler %s", self.signal_quit)

        start_thread(self.server_message_queue, "server message queue")

        control_socket_ready = self.create_control_socket()
        if not control_socket_ready:
            #TODO: should send a message to the client
            return
        self.control_socket_thread = start_thread(self.control_socket_loop, "control")

        self.main_queue = Queue()
        #setup protocol wrappers:
        self.server_packets = Queue(PROXY_QUEUE_SIZE)
        self.client_packets = Queue(PROXY_QUEUE_SIZE)
        self.client_protocol = Protocol(self, self.client_conn,
                                        self.process_client_packet,
                                        self.get_client_packet)
        self.client_protocol.restore_state(self.client_state)
        self.server_protocol = Protocol(self, self.server_conn,
                                        self.process_server_packet,
                                        self.get_server_packet)
        #server connection tweaks:
        for packet_type in ("draw", "window-icon", "keymap-changed", "server-settings"):
            self.server_protocol.large_packets.append(packet_type)
        if self.caps.boolget("file-transfer"):
            #file transfers go both ways:
            for protocol in (self.client_protocol, self.server_protocol):
                for packet_type in ("send-file", "send-file-chunk"):
                    protocol.large_packets.append(packet_type)
        compression_level = self.session_options.get("compression_level", 0)
        self.server_protocol.set_compression_level(compression_level)
        self.server_protocol.enable_default_encoder()

        self.lost_windows = set()
        self.encode_queue = Queue()
        self.encode_thread = start_thread(self.encode_loop, "encode")

        log("starting network threads")
        self.server_protocol.start()
        self.client_protocol.start()

        self.send_hello()
        self.timeout_add(VIDEO_TIMEOUT*1000, self.timeout_video_encoders)

        try:
            self.run_queue()
        except KeyboardInterrupt as e:
            self.stop(str(e))
        finally:
            log("ProxyProcess.run() ending %s", os.getpid())
Exemplo n.º 5
0
 def __init__(self,
              scheduler,
              conn,
              auth,
              process_packet_cb,
              get_rfb_pixelformat,
              session_name="Xpra"):
     """
         You must call this constructor and source_has_more() from the main thread.

         Stores the connection, the authenticator and the callbacks,
         resets the packet counters and creates the "read" thread object.
         `scheduler` must provide `timeout_add` and `idle_add`.
     """
     assert scheduler is not None
     assert conn is not None
     #scheduling functions:
     self.timeout_add = scheduler.timeout_add
     self.idle_add = scheduler.idle_add
     #connection and callbacks:
     self._conn = conn
     self._authenticator = auth
     self._process_packet_cb = process_packet_cb
     self._get_rfb_pixelformat = get_rfb_pixelformat
     self.session_name = session_name
     #protocol state:
     self._write_queue = Queue()
     self._buffer = b""
     self._challenge = None
     self.share = False
     #counters:
     self.input_packetcount = self.input_raw_packetcount = 0
     self.output_packetcount = self.output_raw_packetcount = 0
     self._protocol_version = ()
     self._closed = False
     #the first parser used is the one for the handshake:
     self._packet_parser = self._parse_protocol_handshake
     #there is no write thread yet, only the read thread is created here:
     self._write_thread = None
     self._read_thread = make_thread(self._read_thread_loop, "read", daemon=True)
Exemplo n.º 6
0
    def start_tcp_proxy(self, proto, data):
        """
            Forward the connection of `proto` to the TCP server
            configured in `self._tcp_proxy` ("host:port").

            The connection is stolen from the protocol, a new socket is
            connected to the server, then `data` and any buffers read in
            the meantime are written to it before an XpraProxy is run
            in a daemon thread to relay the traffic.

            proto: the protocol instance whose connection is taken over
            data: the initial data already read from the client

            If the server cannot be reached, the new socket is closed
            and the data is reported via `proto.gibberish`.
        """
        log("start_tcp_proxy(%s, %s)", proto, data[:10])
        #any buffers read after we steal the connection will be placed in this temporary queue:
        temp_read_buffer = Queue()
        client_connection = proto.steal_connection(temp_read_buffer.put)
        try:
            self._potential_protocols.remove(proto)
        except Exception:
            pass  #might already have been removed by now
        #parse the address first so that a bad value cannot leak a socket:
        host, port = self._tcp_proxy.split(":", 1)
        #connect to web server:
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.settimeout(10)
        try:
            web_server_connection = _socket_connect(sock, (host, int(port)),
                                                    "web-proxy-for-%s" % proto,
                                                    "tcp")
        except Exception as e:
            log.warn("failed to connect to proxy: %s:%s", host, port)
            log("proxy connection error: %s", e)
            #we created this socket, so we must release it:
            sock.close()
            proto.gibberish("invalid packet header", data)
            return
        log("proxy connected to tcp server at %s:%s : %s", host, port,
            web_server_connection)
        sock.settimeout(self._socket_timeout)

        ioe = proto.wait_for_io_threads_exit(0.5 + self._socket_timeout)
        if not ioe:
            log.warn("proxy failed to stop all existing network threads!")
            #the server connection is abandoned, nothing else references it:
            sock.close()
            self.disconnect_protocol(proto, "internal threading error")
            return
        #now that we own it, we can start it again:
        client_connection.set_active(True)
        #and we can use blocking sockets:
        self.set_socket_timeout(client_connection, None)
        #prevent deadlocks on exit:
        sock.settimeout(1)

        log("pushing initial buffer to its new destination: %s",
            repr_ellipsized(data))
        web_server_connection.write(data)
        while not temp_read_buffer.empty():
            buf = temp_read_buffer.get()
            if buf:
                log("pushing read buffer to its new destination: %s",
                    repr_ellipsized(buf))
                web_server_connection.write(buf)
        p = XpraProxy(client_connection, web_server_connection)
        self._tcp_proxy_clients.append(p)

        def run_proxy():
            p.run()
            log("run_proxy() %s ended", p)
            if p in self._tcp_proxy_clients:
                self._tcp_proxy_clients.remove(p)

        t = make_daemon_thread(run_proxy, "web-proxy-for-%s" % proto)
        t.start()
        log.info("client %s forwarded to proxy server %s:%s",
                 client_connection, host, port)
Exemplo n.º 7
0
    def run(self):
        """
            Entry point of the proxy process.

            Switches to the configured gid and uid (group first),
            applies the environment options, creates the queues and both
            protocol wrappers (client and server), starts the network
            threads, forwards the hello packet to the server and then
            runs the main queue until it ends.
        """
        debug("ProxyProcess.run() pid=%s, uid=%s, gid=%s", os.getpid(),
              os.getuid(), os.getgid())
        #change uid and gid:
        if os.getgid() != self.gid:
            os.setgid(self.gid)
        if os.getuid() != self.uid:
            os.setuid(self.uid)
        debug("ProxyProcess.run() new uid=%s, gid=%s", os.getuid(),
              os.getgid())

        if self.env_options:
            #TODO: whitelist env update?
            os.environ.update(self.env_options)

        log.info("new proxy started for client %s and server %s",
                 self.client_conn, self.server_conn)

        if not USE_THREADING:
            signal.signal(signal.SIGTERM, self.signal_quit)
            signal.signal(signal.SIGINT, self.signal_quit)
            debug("registered signal handler %s", self.signal_quit)

        make_daemon_thread(self.server_message_queue,
                           "server message queue").start()

        self.main_queue = Queue()
        #setup protocol wrappers:
        self.server_packets = Queue(PROXY_QUEUE_SIZE)
        self.client_packets = Queue(PROXY_QUEUE_SIZE)
        self.client_protocol = Protocol(self, self.client_conn,
                                        self.process_client_packet,
                                        self.get_client_packet)
        self.client_protocol.restore_state(self.client_state)
        self.server_protocol = Protocol(self, self.server_conn,
                                        self.process_server_packet,
                                        self.get_server_packet)
        #server connection tweaks:
        self.server_protocol.large_packets.append("draw")
        self.server_protocol.large_packets.append("keymap-changed")
        self.server_protocol.large_packets.append("server-settings")
        self.server_protocol.set_compression_level(
            self.session_options.get("compression_level", 0))

        debug("starting network threads")
        self.server_protocol.start()
        self.client_protocol.start()

        #forward the hello packet:
        hello_packet = ("hello", self.filter_client_caps(self.caps))
        self.queue_server_packet(hello_packet)

        #"except X as e" is valid with both Python 2.6+ and Python 3,
        #unlike the older "except X, e" form:
        try:
            self.run_queue()
        except KeyboardInterrupt as e:
            self.stop(str(e))
        finally:
            debug("ProxyProcess.run() ending %s", os.getpid())
Exemplo n.º 8
0
 def __init__(self, scheduler, conn, process_packet_cb, get_packet_cb=None):
     """
         You must call this constructor and source_has_more() from the main thread.

         Stores the connection and the packet callbacks, resets the
         counters, the encoder, compressor and cipher state, and creates
         the "write" and "read" thread objects.
         `scheduler` must provide `timeout_add` and `idle_add`.
     """
     assert scheduler is not None
     assert conn is not None
     self.timeout_add = scheduler.timeout_add
     self.idle_add = scheduler.idle_add
     self._conn = conn
     #incoming packets go through the fake jitter wrapper when it is enabled:
     packet_handler = process_packet_cb
     if FAKE_JITTER > 0:
         from xpra.net.fake_jitter import FakeJitter
         packet_handler = FakeJitter(self.timeout_add, process_packet_cb).process_packet_cb
     self._process_packet_cb = packet_handler
     #bounded queues:
     self._write_queue = Queue(1)
     self._read_queue = Queue(20)
     self._read_queue_put = self.read_queue_put
     # Invariant: if .source is None, then _source_has_more == False
     self._get_packet_cb = get_packet_cb
     #counters:
     self.input_stats = {}
     self.input_packetcount = self.input_raw_packetcount = 0
     self.output_stats = {}
     self.output_packetcount = self.output_raw_packetcount = 0
     #initial value which may get increased by client/server after handshake:
     self.max_packet_size = 256 * 1024
     self.abs_max_packet_size = 256 * 1024 * 1024
     self.large_packets = ["hello", "window-metadata", "sound-data"]
     self.send_aliases = {}
     self.receive_aliases = {}
     #None here means auto-detect:
     self._log_stats = None
     self._closed = False
     #no encoder and no compression until they get enabled:
     self.encoder = "none"
     self._encoder = self.noencode
     self.compressor = "none"
     self._compress = compression.nocompress
     self.compression_level = 0
     #no encryption in either direction:
     self.cipher_in = self.cipher_in_name = None
     self.cipher_in_block_size = 0
     self.cipher_in_padding = INITIAL_PADDING
     self.cipher_out = self.cipher_out_name = None
     self.cipher_out_block_size = 0
     self.cipher_out_padding = INITIAL_PADDING
     self._write_lock = Lock()
     from xpra.make_thread import make_thread
     self._write_thread = make_thread(self._write_thread_loop, "write", daemon=True)
     self._read_thread = make_thread(self._read_thread_loop, "read", daemon=True)
     #both are started when needed:
     self._read_parser_thread = None
     self._write_format_thread = None
     self._source_has_more = Event()
Exemplo n.º 9
0
 def stop_encode_thread(self):
     """
         Push the `None` marker onto the current encode queue, then
         replace it with a new queue which holds only that marker.
         Does nothing if there is no encode queue.
     """
     current_queue = self.encode_queue
     if not current_queue:
         return
     #marker for whoever is still reading from the current queue:
     current_queue.put_nowait(None)
     #empty the encode queue by swapping it for a new one:
     replacement = Queue()
     replacement.put(None)
     self.encode_queue = replacement
Exemplo n.º 10
0
 def start_queue_encode(self, item):
     """
         Create the encode work queue, queue `item` on it and start
         the "encode" thread.

         The queue holds functions to call to compress data (pixels,
         clipboard); the items placed in it are picked off by the
         "encode" thread, and those functions should add the packets
         they generate to the 'packet_queue'.
         `self.queue_encode` is replaced with the `put` method of the
         new queue.
     """
     work_queue = Queue()
     self.encode_work_queue = work_queue
     self.queue_encode = work_queue.put
     #the item is queued before the thread which consumes it is started:
     self.queue_encode(item)
     self.encode_thread = start_thread(self.encode_loop, "encode")
Exemplo n.º 11
0
 def stop(self, reason="proxy terminating", skip_proto=None):
     """
         Stop the proxy: terminate the main queue and close both
         protocols (client and server) with a "disconnect" packet.

         reason: the text sent with the disconnect packet
         skip_proto: a protocol which must not be closed here
     """
     debug("stop(%s, %s)", reason, skip_proto)
     #exit marker for the current main queue:
     self.main_queue.put(None)
     #empty the main queue by swapping it for one holding only the marker:
     fresh_queue = Queue()
     fresh_queue.put(None)
     self.main_queue = fresh_queue
     for proto in (self.client_protocol, self.server_protocol):
         if not (proto and proto != skip_proto):
             continue
         proto.flush_then_close(["disconnect", reason])
Exemplo n.º 12
0
 def stop(self, reason="proxy terminating", skip_proto=None):
     """
         Stop the proxy: terminate the main queue and close both
         protocols (client and server) with a "disconnect" packet.

         reason: the text sent with the disconnect packet
         skip_proto: a protocol which must not be closed here
     """
     debug("stop(%s, %s)", reason, skip_proto)
     #exit marker for the current main queue:
     self.main_queue.put(None)
     #empty the main queue by swapping it for one holding only the marker:
     fresh_queue = Queue()
     fresh_queue.put(None)
     self.main_queue = fresh_queue
     for proto in (self.client_protocol, self.server_protocol):
         if not (proto and proto!=skip_proto):
             continue
         proto.flush_then_close(["disconnect", reason])
Exemplo n.º 13
0
    def __init__(self, scheduler, conn, process_packet_cb, get_packet_cb=None):
        """
            You must call this constructor and source_has_more() from the main thread.

            Stores the connection and the packet callbacks, resets the
            counters, the encoder, compressor and cipher state, and
            creates the "write", "read", "parse" and "format" thread
            objects.
            `scheduler` must provide `timeout_add` and `idle_add`.
        """
        assert scheduler is not None
        assert conn is not None
        self.timeout_add = scheduler.timeout_add
        self.idle_add = scheduler.idle_add
        self._conn = conn
        # incoming packets go through the fake jitter wrapper when it is enabled:
        packet_handler = process_packet_cb
        if FAKE_JITTER > 0:
            from xpra.net.fake_jitter import FakeJitter

            packet_handler = FakeJitter(self.timeout_add, process_packet_cb).process_packet_cb
        self._process_packet_cb = packet_handler
        # bounded queues:
        self._write_queue = Queue(1)
        self._read_queue = Queue(20)
        self._read_queue_put = self._read_queue.put
        # Invariant: if .source is None, then _source_has_more == False
        self._get_packet_cb = get_packet_cb
        # counters:
        self.input_stats = {}
        self.input_packetcount = self.input_raw_packetcount = 0
        self.output_stats = {}
        self.output_packetcount = self.output_raw_packetcount = 0
        # initial value which may get increased by client/server after handshake:
        self.max_packet_size = 256 * 1024
        self.abs_max_packet_size = 256 * 1024 * 1024
        self.large_packets = ["hello"]
        self.send_aliases = {}
        self.receive_aliases = {}
        # None here means auto-detect:
        self._log_stats = None
        self._closed = False
        # no encoder and no compression until they get enabled:
        self.encoder = "none"
        self._encoder = self.noencode
        self.compressor = "none"
        self._compress = compression.nocompress
        self.compression_level = 0
        # no encryption in either direction:
        self.cipher_in = self.cipher_in_name = None
        self.cipher_in_block_size = 0
        self.cipher_out = self.cipher_out_name = None
        self.cipher_out_block_size = 0
        self._write_lock = Lock()
        from xpra.daemon_thread import make_daemon_thread

        # the thread objects are created in the same order as before:
        write_loop = self._write_thread_loop
        self._write_thread = make_daemon_thread(write_loop, "write")
        read_loop = self._read_thread_loop
        self._read_thread = make_daemon_thread(read_loop, "read")
        parse_loop = self._read_parse_thread_loop
        self._read_parser_thread = make_daemon_thread(parse_loop, "parse")
        format_loop = self._write_format_thread_loop
        self._write_format_thread = make_daemon_thread(format_loop, "format")
        self._source_has_more = threading.Event()
Exemplo n.º 14
0
 def __init__(self, description="wrapper"):
     """
         Initialize the wrapper state and register the default handlers
         for the `Protocol.CONNECTION_LOST` and `Protocol.GIBBERISH`
         packets.

         description: stored on the instance as `self.description`
     """
     self.description = description
     #nothing is running or connected yet:
     self.process = self.protocol = self.command = None
     self.send_queue = Queue()
     self.large_packets = []
     #must be set before the calls to connect() below:
     self.signal_callbacks = {}
     #hook a default packet handlers:
     self.connect(Protocol.CONNECTION_LOST, self.connection_lost)
     self.connect(Protocol.GIBBERISH, self.gibberish)
Exemplo n.º 15
0
 def __init__(self, scheduler, conn, process_packet_cb, get_packet_cb=None):
     """
         You must call this constructor and source_has_more() from the main thread.

         Stores the connection and the packet callbacks, resets the
         counters, the compression and cipher state, creates the
         "write", "read", "parse" and "format" thread objects and then
         enables the default encoder.
         `scheduler` must provide `timeout_add` and `idle_add`.
     """
     assert scheduler is not None
     assert conn is not None
     self.timeout_add = scheduler.timeout_add
     self.idle_add = scheduler.idle_add
     self._conn = conn
     #incoming packets go through the fake jitter wrapper when it is enabled:
     packet_handler = process_packet_cb
     if FAKE_JITTER > 0:
         packet_handler = FakeJitter(self.timeout_add, process_packet_cb).process_packet_cb
     self._process_packet_cb = packet_handler
     #bounded queues:
     self._write_queue = Queue(1)
     self._read_queue = Queue(20)
     # Invariant: if .source is None, then _source_has_more == False
     self._get_packet_cb = get_packet_cb
     #counters:
     self.input_stats = {}
     self.input_packetcount = self.input_raw_packetcount = 0
     self.output_stats = {}
     self.output_packetcount = self.output_raw_packetcount = 0
     #initial value which may get increased by client/server after handshake:
     self.max_packet_size = 32 * 1024
     self.abs_max_packet_size = 256 * 1024 * 1024
     self.large_packets = ["hello"]
     self.send_aliases = {}
     self.receive_aliases = {}
     #None here means auto-detect:
     self._log_stats = None
     self._closed = False
     self._encoder = self.noencode
     self._compress = zcompress
     self.compression_level = 0
     #no encryption in either direction:
     self.cipher_in = self.cipher_in_name = None
     self.cipher_in_block_size = 0
     self.cipher_out = self.cipher_out_name = None
     self.cipher_out_block_size = 0
     self._write_lock = Lock()
     #the thread objects are created in the same order as before:
     write_loop = self._write_thread_loop
     self._write_thread = make_daemon_thread(write_loop, "write")
     read_loop = self._read_thread_loop
     self._read_thread = make_daemon_thread(read_loop, "read")
     parse_loop = self._read_parse_thread_loop
     self._read_parser_thread = make_daemon_thread(parse_loop, "parse")
     format_loop = self._write_format_thread_loop
     self._write_format_thread = make_daemon_thread(format_loop, "format")
     self._source_has_more = threading.Event()
     self.enable_default_encoder()
Exemplo n.º 16
0
    def start_tcp_proxy(self, proto, data):
        """
            Forward the connection of `proto` to the TCP server
            configured in `self._tcp_proxy` ("host:port").

            The connection is stolen from the protocol, a new socket is
            connected to the server, then `data` and any buffers read in
            the meantime are written to it before an XpraProxy is run
            to relay the traffic. The call to `p.run()` is made directly
            from this method.

            proto: the protocol instance whose connection is taken over
            data: the initial data already read from the client

            If the server cannot be reached, the new socket is closed
            and the data is reported via `proto.gibberish`.
        """
        proxylog("start_tcp_proxy(%s, '%s')", proto, repr_ellipsized(data))
        try:
            self._potential_protocols.remove(proto)
        except Exception:
            pass  # might already have been removed by now
        proxylog("start_tcp_proxy: protocol state before stealing: %s", proto.get_info(alias_info=False))
        # any buffers read after we steal the connection will be placed in this temporary queue:
        temp_read_buffer = Queue()
        client_connection = proto.steal_connection(temp_read_buffer.put)
        # parse the address first so that a bad value cannot leak a socket:
        host, port = self._tcp_proxy.split(":", 1)
        # connect to web server:
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.settimeout(10)
        try:
            web_server_connection = _socket_connect(sock, (host, int(port)), "web-proxy-for-%s" % proto, "tcp")
        except Exception as e:
            proxylog.warn("failed to connect to proxy: %s:%s", host, port)
            proxylog("proxy connection error: %s", e)
            # we created this socket, so we must release it:
            sock.close()
            proto.gibberish("invalid packet header", data)
            return
        proxylog("proxy connected to tcp server at %s:%s : %s", host, port, web_server_connection)
        sock.settimeout(self._socket_timeout)

        ioe = proto.wait_for_io_threads_exit(0.5 + self._socket_timeout)
        if not ioe:
            proxylog.warn("proxy failed to stop all existing network threads!")
            # the server connection is abandoned, nothing else references it:
            sock.close()
            self.disconnect_protocol(proto, "internal threading error")
            return
        # now that we own it, we can start it again:
        client_connection.set_active(True)
        # and we can use blocking sockets:
        self.set_socket_timeout(client_connection, None)
        # prevent deadlocks on exit:
        sock.settimeout(1)

        proxylog("pushing initial buffer to its new destination: %s", repr_ellipsized(data))
        web_server_connection.write(data)
        while not temp_read_buffer.empty():
            buf = temp_read_buffer.get()
            if buf:
                proxylog("pushing read buffer to its new destination: %s", repr_ellipsized(buf))
                web_server_connection.write(buf)
        p = XpraProxy(client_connection.target, client_connection, web_server_connection)
        self._tcp_proxy_clients.append(p)
        proxylog.info(
            "client connection from %s forwarded to proxy server on %s:%s", client_connection.target, host, port
        )
        try:
            p.run()
        finally:
            # always unregister the proxy, even if run() raised:
            proxylog("run_proxy() %s ended", p)
            if p in self._tcp_proxy_clients:
                self._tcp_proxy_clients.remove(p)
Exemplo n.º 17
0
class Invoker(QtCore.QObject):
    """
    Queues function calls and asks Qt to run them through the
    `handler` slot, which is invoked via a queued connection.
    """

    def __init__(self):
        super(Invoker, self).__init__()
        # holds the pending calls, one entry per invoke():
        self.queue = Queue()

    def invoke(self, func, *args):
        """Queue `func(*args)` and schedule one call to `handler`."""
        def pending_call():
            return func(*args)

        self.queue.put(pending_call)
        QtCore.QMetaObject.invokeMethod(self, "handler", QtCore.Qt.QueuedConnection)

    @QtCore.pyqtSlot()
    def handler(self):
        """Take the next pending call from the queue and run it."""
        pending_call = self.queue.get()
        pending_call()
Exemplo n.º 18
0
 def __init__(self, input_filename="-", output_filename="-", wrapped_object=None, method_whitelist=None):
     """
         Store the filenames and the wrapped object, register the OS
         signal handler and set up the main loop.

         input_filename, output_filename: stored unchanged, the default
             for both is "-"
         wrapped_object: the gobject instance which is wrapped
         method_whitelist: stored unchanged
     """
     self.name = ""
     #no streams and no protocol yet:
     self._input = self._output = None
     self.protocol = None
     self.input_filename = input_filename
     self.output_filename = output_filename
     self.method_whitelist = method_whitelist
     self.large_packets = []
     #the gobject instance which is wrapped:
     self.wrapped_object = wrapped_object
     self.send_queue = Queue()
     register_os_signals(self.handle_signal)
     self.setup_mainloop()
Exemplo n.º 19
0
 def stop(self, reason="proxy terminating", skip_proto=None):
     """
         Stop the proxy instance: set the exit flag, remove and close
         the control socket, terminate the main and encode queues and
         close both protocols (client and server) with a "disconnect"
         packet.

         reason: the text sent with the disconnect packet
         skip_proto: a protocol which must not be closed here

         The control socket cleanup is best-effort: failures are logged
         and do not prevent the rest of the shutdown.
     """
     log.info("stop(%s, %s)", reason, skip_proto)
     self.exit = True
     socket_path = self.control_socket_path
     if socket_path:
         try:
             os.unlink(socket_path)
         except Exception as e:
             log("failed to remove the control socket '%s': %s", socket_path, e)
         self.control_socket_path = None
     try:
         self.control_socket.close()
     except Exception as e:
         #the control socket may not exist or may already be closed
         log("failed to close the control socket: %s", e)
     self.main_queue.put(None)
     #empty the main queue:
     q = Queue()
     q.put(None)
     self.main_queue = q
     #empty the encode queue:
     q = Queue()
     q.put(None)
     self.encode_queue = q
     for proto in (self.client_protocol, self.server_protocol):
         if proto and proto != skip_proto:
             log("sending disconnect to %s", proto)
             proto.flush_then_close(["disconnect", SERVER_SHUTDOWN, reason])
Exemplo n.º 20
0
 def stop(self, reason="terminating", skip_proto=None):
     """
         Stop the proxy instance: set the exit flag, close the control
         socket and run its cleanup callback, terminate the main and
         encode queues and send a disconnect to both protocols (client
         and server).

         reason: the text sent with the disconnect
         skip_proto: a protocol which must not be disconnected here

         Closing the control socket is best-effort: a failure is logged
         and does not prevent the rest of the shutdown.
     """
     log("stop(%s, %s)", reason, skip_proto)
     log.info("stopping proxy instance: %s", reason)
     self.exit = True
     try:
         self.control_socket.close()
     except Exception as e:
         #the control socket may not exist or may already be closed
         log("failed to close the control socket: %s", e)
     #clear the callback before calling it, so that it can only run once:
     csc = self.control_socket_cleanup
     if csc:
         self.control_socket_cleanup = None
         csc()
     self.main_queue.put(None)
     #empty the main queue:
     q = Queue()
     q.put(None)
     self.main_queue = q
     #empty the encode queue:
     q = Queue()
     q.put(None)
     self.encode_queue = q
     for proto in (self.client_protocol, self.server_protocol):
         if proto and proto != skip_proto:
             log("sending disconnect to %s", proto)
             proto.send_disconnect([SERVER_SHUTDOWN, reason])
Exemplo n.º 21
0
 def stop(self, force=False):
     """
         Ask the worker to stop by queueing the `None` marker.

         force: when set, the exit flag is set and any items still in
             the queue are abandoned: the queue is replaced with a new
             one after a warning. Otherwise the number of items still
             waiting is just logged.

         Does nothing if the exit flag is already set.
     """
     if self.exit:
         return
     if not force:
         if self.items.qsize()>0:
             log.info("waiting for %s items in work queue to complete", self.items.qsize())
     else:
         if self.items.qsize()>0:
             log.warn("Worker stop: %s items in the queue will not be run!", self.items.qsize())
             #marker for the abandoned queue, then start with an empty one:
             abandoned = self.items
             abandoned.put(None)
             self.items = Queue()
         self.exit = True
     debug("Worker_Thread.stop(%s) %s items in work queue", force, self.items)
     self.items.put(None)
Exemplo n.º 22
0
class Invoker(QtCore.QObject):
    """
    Queues function calls and asks Qt to run them through the
    `handler` slot, which is invoked via a queued connection.
    """

    def __init__(self):
        super(Invoker, self).__init__()
        # holds the pending calls, one entry per invoke():
        self.queue = Queue()

    def invoke(self, func, *args):
        """Queue `func(*args)` and schedule one call to `handler`."""
        def pending_call():
            return func(*args)

        self.queue.put(pending_call)
        connection_type = QtCore.Qt.QueuedConnection
        QtCore.QMetaObject.invokeMethod(self, "handler", connection_type)

    @QtCore.pyqtSlot()
    def handler(self):
        """Take the next pending call from the queue and run it."""
        pending_call = self.queue.get()
        pending_call()
Exemplo n.º 23
0
 def __init__(self, input_filename="-", output_filename="-", wrapped_object=None, method_whitelist=None):
     """
         Store the filenames and the wrapped object, register the
         signal handlers and set up the main loop.

         input_filename, output_filename: stored unchanged, the default
             for both is "-"
         wrapped_object: the gobject instance which is wrapped
         method_whitelist: stored unchanged
     """
     self.name = ""
     self.input_filename = input_filename
     self.output_filename = output_filename
     self.method_whitelist = method_whitelist
     self.large_packets = []
     #the gobject instance which is wrapped:
     self.wrapped_object = wrapped_object
     self.send_queue = Queue()
     self.protocol = None
     #SIGINT is only handled when enabled (this breaks gobject3!),
     #SIGTERM is always handled:
     if HANDLE_SIGINT:
         handled_signals = (signal.SIGINT, signal.SIGTERM)
     else:
         handled_signals = (signal.SIGTERM, )
     for signum in handled_signals:
         signal.signal(signum, self.handle_signal)
     self.setup_mainloop()
Exemplo n.º 24
0
 def __init__(self, description="wrapper"):
     """
         Initialize the wrapper state, register the default handlers
         for the `Protocol.CONNECTION_LOST` and `Protocol.GIBBERISH`
         packets and expose the glib scheduling functions on the
         instance.

         description: stored on the instance as `self.description`
     """
     self.description = description
     #nothing is running or connected yet:
     self.process = self.protocol = self.command = None
     self.send_queue = Queue()
     self.large_packets = []
     #must be set before the calls to connect() below:
     self.signal_callbacks = {}
     #hook a default packet handlers:
     self.connect(Protocol.CONNECTION_LOST, self.connection_lost)
     self.connect(Protocol.GIBBERISH, self.gibberish)
     #scheduling functions, taken from glib:
     glib_module = import_glib()
     self.idle_add = glib_module.idle_add
     self.timeout_add = glib_module.timeout_add
     self.source_remove = glib_module.source_remove
Exemplo n.º 25
0
 def stop(self, force=False):
     """
         Ask the worker to stop by queueing the `None` marker.

         force: when set, the exit flag is set and any items still in
             the queue are abandoned: the queue is replaced with a new
             one after a warning. Otherwise the number of items still
             waiting is just logged.

         Does nothing if the exit flag is already set.
     """
     if self.exit:
         return
     #snapshot of the queue contents, without the None markers:
     snapshot = tuple(self.items.queue)
     items = tuple(entry for entry in snapshot if entry is not None)
     log("Worker_Thread.stop(%s) %i items still in work queue: %s", force,
         len(items), items)
     if not force:
         if items:
             log.info("waiting for %s items in work queue to complete",
                      len(items))
     else:
         if items:
             log.warn("Worker stop: %s items in the queue will not be run!",
                      len(items))
             #marker for the abandoned queue, then start with an empty one:
             abandoned = self.items
             abandoned.put(None)
             self.items = Queue()
         self.exit = True
     self.items.put(None)
Exemplo n.º 26
0
class loopback_connection(Connection):
    """ a fake connection which just writes back whatever is sent to it """

    def __init__(self, *args):
        """Initialise the base Connection and create the loopback queue."""
        Connection.__init__(self, *args)
        self.queue = Queue()

    def may_abort(self, action):
        """Always returns False, whatever the action."""
        return False

    def write(self, buf):
        """Queue buf so a later read() returns it; returns len(buf)."""
        self.may_abort("write")
        self.queue.put(buf)
        written = len(buf)
        return written

    def read(self, n):
        """Block until a buffer is available and return it whole."""
        self.may_abort("read")
        #FIXME: we don't handle n...
        data = self.queue.get(block=True)
        return data
Exemplo n.º 27
0
class loopback_connection(Connection):
    """ a fake connection which just writes back whatever is sent to it """

    def __init__(self, *args):
        """Set up the base Connection, then the queue linking write to read."""
        Connection.__init__(self, *args)
        self.queue = Queue()

    def read(self, n):
        """Return the next written buffer, waiting until there is one."""
        self.may_abort("read")
        #FIXME: we don't handle n...
        return self.queue.get(block=True)

    def write(self, buf):
        """Store buf for the next read() and report its length as written."""
        self.may_abort("write")
        loopback = self.queue
        loopback.put(buf)
        return len(buf)

    def may_abort(self, action):
        """This connection never aborts: the answer is always False."""
        return False
Exemplo n.º 28
0
 def __init__(self, conn, process_packet_cb, get_packet_cb=None):
     """
         You must call this constructor and source_has_more() from the main thread.

         Stores the connection and the packet callbacks, then creates the
         read / write queues, the counters, the cipher state and the four
         threads (write, read, parse, format) via make_daemon_thread.
     """
     assert conn is not None
     self._conn = conn
     #optionally route incoming packets through the FakeJitter wrapper:
     packet_handler = process_packet_cb
     if FAKE_JITTER > 0:
         packet_handler = FakeJitter(process_packet_cb).process_packet_cb
     self._process_packet_cb = packet_handler
     self._write_queue = Queue(1)
     self._read_queue = Queue(20)
     # Invariant: if .source is None, then _source_has_more == False
     self._get_packet_cb = get_packet_cb
     #counters:
     self.input_packetcount = self.input_raw_packetcount = 0
     self.output_packetcount = self.output_raw_packetcount = 0
     #initial value which may get increased by client/server after handshake:
     self.max_packet_size = 32 * 1024
     self.large_packets = ["hello"]
     self.aliases = {}
     self.chunked_compression = True
     self._closed = False
     self._encoder = self.bencode
     self._decompressor = decompressobj()
     self._compression_level = 0
     #cipher state starts out unset:
     self.cipher_in = self.cipher_in_name = None
     self.cipher_in_block_size = 0
     self.cipher_out = self.cipher_out_name = None
     self.cipher_out_block_size = 0
     self._write_lock = Lock()
     #one thread per loop method:
     self._write_thread = make_daemon_thread(
         self._write_thread_loop, "write")
     self._read_thread = make_daemon_thread(
         self._read_thread_loop, "read")
     self._read_parser_thread = make_daemon_thread(
         self._read_parse_thread_loop, "parse")
     self._write_format_thread = make_daemon_thread(
         self._write_format_thread_loop, "format")
     self._source_has_more = threading.Event()
Exemplo n.º 29
0
 def terminate_queue_threads(self):
     """
     Make the queue based threads exit.

     The write and read queues are both replaced with a single queue that
     is pre-filled with None exit markers, and one marker is also pushed
     (without blocking) to each of the old queues.
     Pushing the marker is best effort: the failure is logged and ignored.
     """
     log("terminate_queue_threads()")
     # the format thread will exit since closed is set too:
     self._source_has_more.set()
     # make the threads exit by adding the empty marker:
     exit_queue = Queue()
     for _ in range(10):  # just 2 should be enough!
         exit_queue.put(None)
     for attr in ("_write_queue", "_read_queue"):
         try:
             old_queue = getattr(self, attr)
             setattr(self, attr, exit_queue)
             # the old queue may be full, in which case this raises:
             old_queue.put_nowait(None)
         except Exception as e:
             # only swallow regular errors,
             # KeyboardInterrupt and SystemExit must still propagate
             log("terminate_queue_threads() no exit marker added to %s: %s",
                 attr, e)
Exemplo n.º 30
0
 def terminate_queue_threads(self):
     """
     Make the queue based threads exit.

     Both the write queue and the read queue are swapped for one shared
     queue which only contains None exit markers, and a marker is also
     added (non-blocking) to each of the queues being replaced.
     Adding the marker is best effort: errors are logged and ignored.
     """
     log("terminate_queue_threads()")
     #the format thread will exit since closed is set too:
     self._source_has_more.set()
     #make the threads exit by adding the empty marker:
     exit_queue = Queue()
     for _ in range(10):  #just 2 should be enough!
         exit_queue.put(None)
     for queue_attr in ("_write_queue", "_read_queue"):
         try:
             previous = getattr(self, queue_attr)
             setattr(self, queue_attr, exit_queue)
             #raises if the previous queue is already full:
             previous.put_nowait(None)
         except Exception as e:
             #regular errors are ignored here,
             #but KeyboardInterrupt and SystemExit are no longer swallowed
             log("terminate_queue_threads() no exit marker added to %s: %s",
                 queue_attr, e)
Exemplo n.º 31
0
 def __init__(self, description="wrapper"):
     """
     Initialise the wrapper state (no process, protocol or command yet)
     and hook the default packet handlers.
     """
     self.description = description
     self.process = self.protocol = self.command = None
     self.large_packets = []
     self.signal_callbacks = {}
     self.send_queue = Queue()
     #hook the default packet handlers:
     self.connect(Protocol.CONNECTION_LOST, self.connection_lost)
     self.connect(Protocol.GIBBERISH, self.gibberish)
Exemplo n.º 32
0
    def run(self):
        """
        Initialise video, create the packet queues, tune the protocol
        wrappers, start the encode and network threads, then send the hello.
        """
        log.info("started %s", self)
        log.info(" for client %s", self.client_protocol._conn)
        log.info(" and server %s", self.server_protocol._conn)
        self.video_init()

        #setup protocol wrappers:
        self.server_packets = Queue(PROXY_QUEUE_SIZE)
        self.client_packets = Queue(PROXY_QUEUE_SIZE)
        #server connection tweaks:
        server_large = (b"input-devices", b"draw", b"window-icon",
                        b"keymap-changed", b"server-settings")
        self.server_protocol.large_packets.extend(server_large)
        if self.caps.boolget("file-transfer"):
            #file transfer packets are large in both directions:
            transfer_packets = (b"send-file", b"send-file-chunk")
            self.server_protocol.large_packets.extend(transfer_packets)
            self.client_protocol.large_packets.extend(transfer_packets)
        level = self.session_options.get("compression_level", 0)
        self.server_protocol.set_compression_level(level)
        self.server_protocol.enable_default_encoder()

        self.lost_windows = set()
        self.encode_queue = Queue()
        self.encode_thread = start_thread(self.encode_loop, "encode")

        self.start_network_threads()
        wants_ping = self.caps.boolget("ping-echo-sourceid")
        if wants_ping:
            self.schedule_client_ping()

        self.send_hello()
Exemplo n.º 33
0
 def stop(self, reason="proxy terminating", skip_proto=None):
     """
     Flag the instance as exiting, remove and close the control socket,
     replace the main and encode queues with queues holding only the None
     exit marker, then send a disconnect packet to each protocol.

     reason: text included in the disconnect packet.
     skip_proto: a protocol which must not be sent the disconnect packet.
     """
     log.info("stop(%s, %s)", reason, skip_proto)
     self.exit = True
     socket_path = self.control_socket_path
     if socket_path:
         try:
             os.unlink(socket_path)
         except Exception as e:
             #best effort, but don't swallow KeyboardInterrupt / SystemExit
             log("stop: could not remove control socket '%s': %s",
                 socket_path, e)
         self.control_socket_path = None
     try:
         self.control_socket.close()
     except Exception as e:
         #best effort: the control socket may not be usable at this point
         log("stop: could not close the control socket: %s", e)
     #exit marker for whoever is reading the current main queue:
     self.main_queue.put(None)
     #empty the main queue:
     q = Queue()
     q.put(None)
     self.main_queue = q
     #empty the encode queue:
     q = Queue()
     q.put(None)
     self.encode_queue = q
     for proto in (self.client_protocol, self.server_protocol):
         if proto and proto!=skip_proto:
             log("sending disconnect to %s", proto)
             proto.flush_then_close(["disconnect", SERVER_SHUTDOWN, reason])
Exemplo n.º 34
0
 def __init__(self, input_filename="-", output_filename="-", wrapped_object=None, method_whitelist=None):
     """
     Create the main loop, store the input / output filenames, the wrapped
     object and the method whitelist, then install the signal handlers.

     input_filename / output_filename: stored as given (default "-").
     wrapped_object: the gobject instance which is wrapped.
     method_whitelist: stored on the instance; previously this argument
     was silently discarded and the attribute was always None.
     """
     self.mainloop = mainloop()
     self.name = ""
     self.input_filename = input_filename
     self.output_filename = output_filename
     #honour the caller's whitelist instead of always resetting it to None:
     self.method_whitelist = method_whitelist
     self.large_packets = []
     #the gobject instance which is wrapped:
     self.wrapped_object = wrapped_object
     self.send_queue = Queue()
     self.protocol = None
     signal.signal(signal.SIGINT, self.handle_signal)
     signal.signal(signal.SIGTERM, self.handle_signal)
Exemplo n.º 35
0
class WebSocketConnection(SocketConnection):
    """
    Socket connection whose reads and writes go through a websocket
    handler: read() returns one frame at a time, write() sends one frame.
    """

    def __init__(self, socket, local, remote, target, socktype, ws_handler):
        SocketConnection.__init__(self, socket, local, remote, target,
                                  socktype)
        self.protocol_type = "websocket"
        self.ws_handler = ws_handler
        #frames received but not yet returned by read():
        self.pending_read = Queue()

    def close(self):
        """Drop any pending frames, then close the underlying socket."""
        self.pending_read = Queue()
        SocketConnection.close(self)

    def read(self, n):
        """
        Return the next frame as bytes: a pending one if there is any,
        otherwise the first of the frames received from the handler
        (the others are queued for later calls).
        Falls through (returning None) once the connection is not active.
        """
        #FIXME: we should try to honour n
        while self.is_active():
            if self.pending_read.qsize() > 0:
                frame = self.pending_read.get()
                log("read() returning pending read buffer, len=%i", len(frame))
                self.input_bytecount += len(frame)
                return memoryview_to_bytes(frame)
            frames, closed_string = self.ws_handler.recv_frames()
            if closed_string:
                log("read() closed_string: %s",
                    memoryview_to_bytes(closed_string))
                self.active = False
            log("read() got %i ws frames", len(frames))
            if not frames:
                continue
            frame = frames[0]
            #keep the extra frames for the next read() calls:
            for extra in frames[1:]:
                self.pending_read.put(extra)
            self.input_bytecount += len(frame)
            return memoryview_to_bytes(frame)

    def write(self, buf):
        """Send buf as a single frame and return its length."""
        payload = memoryview_to_bytes(buf)
        self.ws_handler.send_frames([payload])
        self.output_bytecount += len(buf)
        return len(buf)
Exemplo n.º 36
0
 def stop(self, force=False):
     """
     Ask the worker to stop by queueing the None exit marker.

     force: when true, any queued items are abandoned (the queue is
     replaced with an empty one) and the exit flag is set;
     otherwise the marker is just added behind the pending items.
     Does nothing if the exit flag is already set.
     """
     if self.exit:
         return
     if not force:
         if self.items.qsize() > 0:
             log.info("waiting for %s items in work queue to complete", self.items.qsize())
     else:
         if self.items.qsize() > 0:
             log.warn("Worker stop: %s items in the queue will not be run!", self.items.qsize())
             #push the marker to the old queue, then swap in a fresh one:
             self.items.put(None)
             self.items = Queue()
         self.exit = True
     debug("Worker_Thread.stop(%s) %s items in work queue", force, self.items)
     self.items.put(None)
Exemplo n.º 37
0
 def __init__(self, description="wrapper"):
     """
     Set up the initial wrapper state, register the default packet
     handlers and store the glib scheduling functions on the instance.
     """
     self.process = self.protocol = self.command = None
     self.description = description
     self.large_packets = []
     self.signal_callbacks = {}
     self.send_queue = Queue()
     #hook the default packet handlers:
     self.connect(Protocol.CONNECTION_LOST, self.connection_lost)
     self.connect(Protocol.GIBBERISH, self.gibberish)
     #scheduling functions come from glib:
     glib = import_glib()
     self.idle_add, self.timeout_add, self.source_remove = (
         glib.idle_add, glib.timeout_add, glib.source_remove)
Exemplo n.º 38
0
class Protocol(object):
    CONNECTION_LOST = "connection-lost"
    GIBBERISH = "gibberish"
    INVALID = "invalid"

    def __init__(self, scheduler, conn, process_packet_cb, get_packet_cb=None):
        """
            You must call this constructor and source_has_more() from the main thread.

            Stores the scheduler functions, the connection and the packet
            callbacks, then sets up the queues, counters, encoder / compressor
            defaults, cipher state and the read and write threads.
            The parser and format threads are only created when needed.
        """
        assert scheduler is not None
        assert conn is not None
        self.timeout_add = scheduler.timeout_add
        self.idle_add = scheduler.idle_add
        self._conn = conn
        #optionally route incoming packets through the FakeJitter wrapper:
        packet_handler = process_packet_cb
        if FAKE_JITTER > 0:
            from xpra.net.fake_jitter import FakeJitter
            jitter = FakeJitter(self.timeout_add, process_packet_cb)
            packet_handler = jitter.process_packet_cb
        self._process_packet_cb = packet_handler
        self._write_queue = Queue(1)
        self._read_queue = Queue(20)
        self._read_queue_put = self.read_queue_put
        # Invariant: if .source is None, then _source_has_more == False
        self._get_packet_cb = get_packet_cb
        #counters:
        self.input_stats = {}
        self.input_packetcount = self.input_raw_packetcount = 0
        self.output_stats = {}
        self.output_packetcount = self.output_raw_packetcount = 0
        #initial value which may get increased by client/server after handshake:
        self.max_packet_size = 256 * 1024
        self.abs_max_packet_size = 256 * 1024 * 1024
        self.large_packets = ["hello", "window-metadata", "sound-data"]
        self.send_aliases = {}
        self.receive_aliases = {}
        self._log_stats = None  #None here means auto-detect
        self._closed = False
        #start without any packet encoding or compression:
        self.encoder = "none"
        self._encoder = self.noencode
        self.compressor = "none"
        self._compress = compression.nocompress
        self.compression_level = 0
        #cipher state starts out unset:
        self.cipher_in = self.cipher_in_name = None
        self.cipher_in_block_size = 0
        self.cipher_in_padding = INITIAL_PADDING
        self.cipher_out = self.cipher_out_name = None
        self.cipher_out_block_size = 0
        self.cipher_out_padding = INITIAL_PADDING
        self._write_lock = Lock()
        from xpra.make_thread import make_thread
        self._write_thread = make_thread(
            self._write_thread_loop, "write", daemon=True)
        self._read_thread = make_thread(
            self._read_thread_loop, "read", daemon=True)
        self._read_parser_thread = None  #started when needed
        self._write_format_thread = None  #started when needed
        self._source_has_more = Event()

    #names of the attributes which save_state() copies into a dictionary
    #and restore_state() sets back on the instance:
    STATE_FIELDS = ("max_packet_size", "large_packets", "send_aliases",
                    "receive_aliases", "cipher_in", "cipher_in_name",
                    "cipher_in_block_size", "cipher_in_padding", "cipher_out",
                    "cipher_out_name", "cipher_out_block_size",
                    "cipher_out_padding", "compression_level", "encoder",
                    "compressor")

    def save_state(self):
        state = {}
        for x in Protocol.STATE_FIELDS:
            state[x] = getattr(self, x)
        return state

    def restore_state(self, state):
        assert state is not None
        for x in Protocol.STATE_FIELDS:
            assert x in state, "field %s is missing" % x
            setattr(self, x, state[x])
        #special handling for compressor / encoder which are named objects:
        self.enable_compressor(self.compressor)
        self.enable_encoder(self.encoder)

    def wait_for_io_threads_exit(self, timeout=None):
        for t in (self._read_thread, self._write_thread):
            if t and t.isAlive():
                t.join(timeout)
        exited = True
        cinfo = self._conn or "cleared connection"
        for t in (self._read_thread, self._write_thread):
            if t and t.isAlive():
                log.warn(
                    "Warning: %s thread of %s is still alive (timeout=%s)",
                    t.name, cinfo, timeout)
                exited = False
        return exited

    def set_packet_source(self, get_packet_cb):
        self._get_packet_cb = get_packet_cb

    def set_cipher_in(self, ciphername, iv, password, key_salt, iterations,
                      padding):
        """
        Configure the cipher used for incoming data: stores the decryptor
        and block size returned by get_decryptor, and the padding.
        The cipher name is logged at info level only when it changes.
        """
        name_changed = self.cipher_in_name != ciphername
        if name_changed:
            cryptolog.info("receiving data using %s encryption", ciphername)
            self.cipher_in_name = ciphername
        log_args = (ciphername, iv, password, key_salt, iterations)
        cryptolog("set_cipher_in%s", log_args)
        decryptor, block_size = get_decryptor(
            ciphername, iv, password, key_salt, iterations)
        self.cipher_in = decryptor
        self.cipher_in_block_size = block_size
        self.cipher_in_padding = padding

    def set_cipher_out(self, ciphername, iv, password, key_salt, iterations,
                       padding):
        """
        Configure the cipher used for outgoing data: stores the encryptor
        and block size returned by get_encryptor, and the padding.
        The cipher name is logged at info level only when it changes.
        """
        name_changed = self.cipher_out_name != ciphername
        if name_changed:
            cryptolog.info("sending data using %s encryption", ciphername)
            self.cipher_out_name = ciphername
        log_args = (ciphername, iv, password, key_salt, iterations, padding)
        cryptolog("set_cipher_out%s", log_args)
        encryptor, block_size = get_encryptor(
            ciphername, iv, password, key_salt, iterations)
        self.cipher_out = encryptor
        self.cipher_out_block_size = block_size
        self.cipher_out_padding = padding

    def __repr__(self):
        return "Protocol(%s)" % self._conn

    def get_threads(self):
        return [
            x for x in [
                self._write_thread, self._read_thread,
                self._read_parser_thread, self._write_format_thread
            ] if x is not None
        ]

    def get_info(self, alias_info=True):
        """
        Return a dictionary describing this protocol instance:
        packet size limits, compression level, input / output counters,
        cipher names and padding, compressor and encoder names,
        the connection's own information and which threads are alive.

        alias_info: when true, also include the send and receive aliases.
        """
        info = {
            "large_packets": self.large_packets,
            "compression_level": self.compression_level,
            "max_packet_size": self.max_packet_size,
            "aliases": USE_ALIASES,
            "input": {
                "buffer-size": READ_BUFFER_SIZE,
                "packetcount": self.input_packetcount,
                "raw_packetcount": self.input_raw_packetcount,
                "count": self.input_stats,
                "cipher": {
                    "": self.cipher_in_name or "",
                    "padding": self.cipher_in_padding,
                },
            },
            "output": {
                "packet-join-size": PACKET_JOIN_SIZE,
                "large-packet-size": LARGE_PACKET_SIZE,
                "inline-size": INLINE_SIZE,
                "min-compress-size": MIN_COMPRESS_SIZE,
                "packetcount": self.output_packetcount,
                "raw_packetcount": self.output_raw_packetcount,
                "count": self.output_stats,
                "cipher": {
                    "": self.cipher_out_name or "",
                    "padding": self.cipher_out_padding
                },
            },
        }
        #work on local snapshots so that the value which was tested
        #is the one which is used:
        c = self._compress
        if c:
            info["compressor"] = compression.get_compressor_name(c)
        e = self._encoder
        if e:
            if e == self.noencode:
                info["encoder"] = "noencode"
            else:
                info["encoder"] = packet_encoding.get_encoder_name(e)
        if alias_info:
            info["send_alias"] = self.send_aliases
            info["receive_alias"] = self.receive_aliases
        #the connection attribute can be cleared (other methods check for it),
        #so only use the local reference from here on:
        conn = self._conn
        if conn:
            try:
                info.update(conn.get_info())
            except Exception:
                #not a bare except: KeyboardInterrupt / SystemExit propagate
                log.error("error collecting connection information on %s",
                          conn,
                          exc_info=True)
        info["has_more"] = self._source_has_more.is_set()
        for t in (self._write_thread, self._read_thread,
                  self._read_parser_thread, self._write_format_thread):
            if t:
                info.setdefault("thread", {})[t.name] = t.is_alive()
        return info

    def start(self):
        def do_start():
            if not self._closed:
                self._write_thread.start()
                self._read_thread.start()

        self.idle_add(do_start)

    def send_now(self, packet):
        """
        Send a single packet by installing a one-shot packet source.

        Does nothing (apart from logging) if the connection is closed.
        Raises AssertionError if a packet source is already set.
        """
        if self._closed:
            log("send_now(%s ...) connection is closed already, not sending",
                packet[0])
            return
        log("send_now(%s ...)", packet[0])
        #identity test: "== None" would invoke the callback's own __eq__
        assert self._get_packet_cb is None, "cannot use send_now when a packet source exists! (set to %s)" % self._get_packet_cb

        def packet_cb():
            #one-shot: remove ourselves before handing out the packet
            self._get_packet_cb = None
            return (packet, )

        self._get_packet_cb = packet_cb
        self.source_has_more()

    def source_has_more(self):
        """
        Flag that the packet source has more data, creating and starting
        the format thread first if it does not exist yet
        (unless the protocol is closed).
        """
        self._source_has_more.set()
        #start the format thread:
        needs_format_thread = not (self._write_format_thread or self._closed)
        if needs_format_thread:
            from xpra.make_thread import make_thread
            self._write_format_thread = make_thread(
                self._write_format_thread_loop, "format", daemon=True)
            self._write_format_thread.start()
        INJECT_FAULT(self)

    def _write_format_thread_loop(self):
        """
        Loop of the "format" thread: each time the source_has_more event is
        set, call the packet source callback and queue what it returns.

        Returns when the protocol is closed or when there is no packet source;
        any other exception is reported through _internal_error
        (unless the protocol is closed, in which case it is ignored).
        """
        log("write_format_thread_loop starting")
        try:
            while not self._closed:
                #block until someone flags that there is data to send:
                self._source_has_more.wait()
                #take a local reference: the attribute may be changed
                #(send_now's callback resets it to None for example)
                gpc = self._get_packet_cb
                if self._closed or not gpc:
                    return
                #clear before calling the source so that a "has_more" flag
                #set whilst adding the packet is not lost:
                self._source_has_more.clear()
                self._add_packet_to_queue(*gpc())
        except Exception as e:
            if self._closed:
                return
            self._internal_error("error in network packet write/format",
                                 e,
                                 exc_info=True)

    def _add_packet_to_queue(self,
                             packet,
                             start_send_cb=None,
                             end_send_cb=None,
                             has_more=False):
        if has_more:
            self._source_has_more.set()
        if packet is None:
            return
        log("add_packet_to_queue(%s ...)", packet[0])
        chunks = self.encode(packet)
        with self._write_lock:
            if self._closed:
                return
            self._add_chunks_to_queue(chunks, start_send_cb, end_send_cb)

    def _add_chunks_to_queue(self,
                             chunks,
                             start_send_cb=None,
                             end_send_cb=None):
        """
        Convert the chunks into write queue items and queue them together.
        The write_lock must be held when calling this function.

        chunks: iterable of (proto_flags, index, level, data) tuples.
        start_send_cb: attached to the first item.
        end_send_cb: attached to the last item of the chunk with index 0.
        When an output cipher is set, each chunk is padded and encrypted.
        """
        counter = 0
        items = []
        for proto_flags, index, level, data in chunks:
            scb, ecb = None, None
            #fire the start_send_callback just before the first packet is processed:
            if counter == 0:
                scb = start_send_cb
            #fire the end_send callback when the last packet (index==0) makes it out:
            if index == 0:
                ecb = end_send_cb
            payload_size = len(data)
            actual_size = payload_size
            if self.cipher_out:
                proto_flags |= FLAGS_CIPHER
                #note: since we are padding: l!=len(data)
                #with a positive block size this is always between 1
                #and the block size, so padding is always added:
                padding_size = self.cipher_out_block_size - (
                    payload_size % self.cipher_out_block_size)
                if padding_size == 0:
                    padded = data
                else:
                    # pad byte value is number of padding bytes added
                    padded = data + pad(self.cipher_out_padding, padding_size)
                    actual_size += padding_size
                #the messages used to have expected and actual values swapped:
                assert len(
                    padded
                ) == actual_size, "expected padded size to be %i, but got %i" % (
                    actual_size, len(padded))
                data = self.cipher_out.encrypt(padded)
                assert len(
                    data
                ) == actual_size, "expected encrypted size to be %i, but got %i" % (
                    actual_size, len(data))
                cryptolog("sending %s bytes %s encrypted with %s padding",
                          payload_size, self.cipher_out_name, padding_size)
            if proto_flags & FLAGS_NOHEADER:
                assert not self.cipher_out
                #for plain/text packets (ie: gibberish response)
                log("sending %s bytes without header", payload_size)
                items.append((data, scb, ecb))
            elif actual_size < PACKET_JOIN_SIZE:
                #small enough: send the header and the data as one buffer
                if type(data) not in JOIN_TYPES:
                    data = bytes(data)
                header_and_data = pack_header(proto_flags, level, index,
                                              payload_size) + data
                items.append((header_and_data, scb, ecb))
            else:
                #header and data are queued as two separate buffers:
                header = pack_header(proto_flags, level, index, payload_size)
                items.append((header, scb, None))
                items.append((strtobytes(data), None, ecb))
            counter += 1
        self._write_queue.put(items)
        self.output_packetcount += 1

    def raw_write(self, contents, start_cb=None, end_cb=None):
        """ Warning: this bypasses the compression and packet encoder! """
        self._write_queue.put(((contents, start_cb, end_cb), ))

    def verify_packet(self, packet):
        """ look for None values which may have caused the packet to fail encoding """
        if type(packet) != list:
            return
        assert len(packet) > 0, "invalid packet: %s" % packet
        tree = ["'%s' packet" % packet[0]]
        self.do_verify_packet(tree, packet)

    def do_verify_packet(self, tree, packet):
        def err(msg):
            log.error("%s in %s", msg, "->".join(tree))

        def new_tree(append):
            nt = tree[:]
            nt.append(append)
            return nt

        if packet is None:
            return err("None value")
        if type(packet) == list:
            for i, x in enumerate(packet):
                self.do_verify_packet(new_tree("[%s]" % i), x)
        elif type(packet) == dict:
            for k, v in packet.items():
                self.do_verify_packet(new_tree("key for value='%s'" % str(v)),
                                      k)
                self.do_verify_packet(new_tree("value for key='%s'" % str(k)),
                                      v)

    def enable_default_encoder(self):
        """Enable the first of the enabled packet encoders."""
        available = packet_encoding.get_enabled_encoders()
        assert len(available) > 0, "no packet encoders available!"
        first = available[0]
        self.enable_encoder(first)

    def enable_encoder_from_caps(self, caps):
        """
        Enable the first encoder (in performance order) which the caps
        report as supported; "bencode" is assumed when caps do not say.
        Returns True if an encoder was enabled, False otherwise.
        """
        candidates = packet_encoding.get_enabled_encoders(
            order=packet_encoding.PERFORMANCE_ORDER)
        log("enable_encoder_from_caps(..) options=%s", candidates)
        for name in candidates:
            if not caps.boolget(name, name == "bencode"):
                continue
            self.enable_encoder(name)
            return True
        log.error("no matching packet encoder found!")
        return False

    def enable_encoder(self, e):
        """Look up the encoder named e and make it the active packet encoder."""
        encoder_fn = packet_encoding.get_encoder(e)
        self._encoder = encoder_fn
        self.encoder = e
        log("enable_encoder(%s): %s", e, self._encoder)

    def enable_default_compressor(self):
        """Enable the first enabled compressor, or "none" if there is none."""
        available = compression.get_enabled_compressors()
        choice = available[0] if len(available) > 0 else "none"
        self.enable_compressor(choice)

    def enable_compressor_from_caps(self, caps):
        """
        Enable the first compressor (in performance order) which the caps
        report as supported.
        Compression is set to "none" when the compression level is zero
        or when no compressor matches.
        """
        if self.compression_level == 0:
            self.enable_compressor("none")
            return
        candidates = compression.get_enabled_compressors(
            order=compression.PERFORMANCE_ORDER)
        log("enable_compressor_from_caps(..) options=%s", candidates)
        for name in candidates:  #ie: [zlib, lz4, lzo]
            if not caps.boolget(name):
                continue
            self.enable_compressor(name)
            return
        log.warn("compression disabled: no matching compressor found")
        self.enable_compressor("none")

    def enable_compressor(self, compressor):
        """Look up the named compressor and make it the active one."""
        compress_fn = compression.get_compressor(compressor)
        self._compress = compress_fn
        self.compressor = compressor
        log("enable_compressor(%s): %s", compressor, self._compress)

    def noencode(self, data):
        """
        Format the items of data as one ": " separated line of text.
        Returns the text (latin1 encoded to bytes on Python 3)
        together with the FLAGS_NOHEADER protocol flag.
        """
        #just send data as a string for clients that don't understand xpra packet format:
        text = ": ".join(str(x) for x in data) + "\n"
        if sys.version_info[0] >= 3:
            import codecs
            if type(text) != bytes:
                text = codecs.latin_1_encode(text)[0]
        return text, FLAGS_NOHEADER

    def encode(self, packet_in):
        """
        Given a packet (tuple or list of items), converts it for the wire.
        This method returns all the binary chunks to send, as a list of:
        (protocol_flags, index, compression_level, binary_data)
        The index, if positive indicates the item to populate in the packet
        whose index is zero.
        ie: ["blah", [large binary data], "hello", 200]
        may get converted to:
        [
            (0,           1, compression_level, [large binary data now zlib compressed]),
            (proto_flags, 0,                 0, bencoded/rencoded(["blah", '', "hello", 200]))
        ]
        Raises TypeError if an item is None or has no length.
        If the encoder fails, an empty list is returned when the protocol
        is closed, otherwise the error is logged and re-raised.
        """
        packets = []
        packet = list(packet_in)
        level = self.compression_level
        size_check = LARGE_PACKET_SIZE
        min_comp_size = MIN_COMPRESS_SIZE
        for i in range(1, len(packet)):
            item = packet[i]
            if item is None:
                raise TypeError("invalid None value in %s packet at index %s" %
                                (packet[0], i))
            ti = type(item)
            if ti in (int, long, bool, dict, list, tuple):
                continue
            try:
                l = len(item)
            except TypeError as e:
                raise TypeError(
                    "invalid type %s in %s packet at index %s: %s" %
                    (ti, packet[0], i, e))
            if ti == LargeStructure:
                #unwrap the data, no further processing for this item:
                item = item.data
                packet[i] = item
                continue
            elif ti == Compressible:
                #this is a marker used to tell us we should compress it now
                #(used by the client for clipboard data)
                item = item.compress()
                packet[i] = item
                ti = type(item)
                #(it may now be a "Compressed" item and be processed further)
                #NOTE(review): "l" is still the length measured before
                #compress() was called - confirm that this is intended
            if ti in (Compressed, LevelCompressed):
                #already compressed data (usually pixels, cursors, etc)
                if not item.can_inline or l > INLINE_SIZE:
                    il = 0
                    if ti == LevelCompressed:
                        #unlike Compressed (usually pixels, decompressed in the paint thread),
                        #LevelCompressed is decompressed by the network layer
                        #so we must tell it how to do that and pass the level flag
                        il = item.level
                    packets.append((0, i, il, item.data))
                    packet[i] = ''
                else:
                    #data is small enough, inline it:
                    packet[i] = item.data
                    min_comp_size += l
                    size_check += l
            elif ti in (str, bytes) and level > 0 and l > LARGE_PACKET_SIZE:
                log.warn(
                    "found a large uncompressed item in packet '%s' at position %s: %s bytes",
                    packet[0], i, len(item))
                #add new binary packet with large item:
                cl, cdata = self._compress(item, level)
                packets.append((0, i, cl, cdata))
                #replace this item with an empty string placeholder:
                packet[i] = ''
            elif ti not in (str, bytes):
                log.warn("unexpected data type %s in %s packet: %s", ti,
                         packet[0], repr_ellipsized(item))
        #now the main packet (or what is left of it):
        packet_type = packet[0]
        self.output_stats[packet_type] = self.output_stats.get(packet_type,
                                                               0) + 1
        if USE_ALIASES and self.send_aliases and packet_type in self.send_aliases:
            #replace the packet type with the alias:
            packet[0] = self.send_aliases[packet_type]
        try:
            main_packet, proto_flags = self._encoder(packet)
        except Exception:
            if self._closed:
                #always return a list of chunks: this used to return
                #the tuple ([], 0) which is not a valid chunk list
                return []
            log.error("failed to encode packet: %s", packet, exc_info=True)
            #make the error a bit nicer to parse: undo aliases:
            packet[0] = packet_type
            self.verify_packet(packet)
            raise
        if len(main_packet
               ) > size_check and packet_in[0] not in self.large_packets:
            log.warn(
                "found large packet (%s bytes): %s, argument types:%s, sizes: %s, packet head=%s",
                len(main_packet), packet_in[0], [type(x) for x in packet[1:]],
                [len(str(x)) for x in packet[1:]], repr_ellipsized(packet))
        #compress, but don't bother for small packets:
        if level > 0 and len(main_packet) > min_comp_size:
            cl, cdata = self._compress(main_packet, level)
            packets.append((proto_flags, 0, cl, cdata))
        else:
            packets.append((proto_flags, 0, 0, main_packet))
        return packets

    def set_compression_level(self, level):
        #this may be used next time encode() is called
        assert level >= 0 and level <= 10, "invalid compression level: %s (must be between 0 and 10" % level
        self.compression_level = level

    def _io_thread_loop(self, name, callback):
        """
        Call callback repeatedly until it returns a false value
        or until the protocol is closed.

        name: used in the log and error messages.
        callback: called without arguments, must return true to continue.
        Exceptions do not propagate: unless the protocol is already closed,
        they are reported as a lost connection or an internal error,
        or logged before closing the protocol.
        """
        try:
            log("io_thread_loop(%s, %s) loop starting", name, callback)
            while not self._closed and callback():
                pass
            log("io_thread_loop(%s, %s) loop ended, closed=%s", name, callback,
                self._closed)
        except ConnectionClosedException as e:
            log("%s closed", self._conn, exc_info=True)
            if not self._closed:
                #ConnectionClosedException means the warning has been logged already
                self._connection_lost("%s connection %s closed" %
                                      (name, self._conn))
        except (OSError, IOError, socket_error) as e:
            if not self._closed:
                #no backtrace when the error code is one of the ABORT values:
                self._internal_error("%s connection %s reset" %
                                     (name, self._conn),
                                     e,
                                     exc_info=e.args[0] not in ABORT)
        except Exception as e:
            #can happen during close(), in which case we just ignore:
            if not self._closed:
                log.error("Error: %s on %s failed: %s",
                          name,
                          self._conn,
                          type(e),
                          exc_info=True)
                self.close()

    def _write_thread_loop(self):
        self._io_thread_loop("write", self._write)

    def _write(self):
        items = self._write_queue.get()
        # Used to signal that we should exit:
        if items is None:
            log("write thread: empty marker, exiting")
            self.close()
            return False
        for buf, start_cb, end_cb in items:
            con = self._conn
            if not con:
                return False
            if start_cb:
                try:
                    start_cb(con.output_bytecount)
                except:
                    if not self._closed:
                        log.error("error on %s", start_cb, exc_info=True)
            while buf and not self._closed:
                written = con.write(buf)
                if written:
                    buf = buf[written:]
                    self.output_raw_packetcount += 1
            if end_cb:
                try:
                    end_cb(self._conn.output_bytecount)
                except:
                    if not self._closed:
                        log.error("error on %s", end_cb, exc_info=True)
        return True

    def _read_thread_loop(self):
        self._io_thread_loop("read", self._read)

    def _read(self):
        """
            Read one chunk of data (up to READ_BUFFER_SIZE) from the connection
            and hand it over to the read queue.

            Returns True if more data should be read,
            False once an empty read (end of file) has been seen.
        """
        data = self._conn.read(READ_BUFFER_SIZE)
        #pass it on to the read queue
        #(or to whatever has taken its place - see steal_connection),
        #an empty chunk is forwarded too:
        self._read_queue_put(data)
        if data:
            self.input_raw_packetcount += 1
            return True
        log("read thread: eof")
        #don't close immediately: leave the parse thread some time
        #to parse and process the last packet received, and to call close itself
        self.timeout_add(1000, self.close)
        return False

    def _internal_error(self, message="", exc=None, exc_info=False):
        #log exception info with last log message
        if self._closed:
            return
        ei = exc_info
        if exc:
            ei = None  #log it separately below
        log.error("Error: %s", message, exc_info=ei)
        if exc:
            log.error(" %s", exc, exc_info=exc_info)
        self.idle_add(self._connection_lost, message)

    def _connection_lost(self, message="", exc_info=False):
        log("connection lost: %s", message, exc_info=exc_info)
        self.close()
        return False

    def invalid(self, msg, data):
        """
            Report invalid data: an INVALID packet carrying 'msg' and 'data'
            is scheduled for delivery to the packet handler via idle_add,
            then the connection is dropped one second later.
        """
        invalid_packet = [Protocol.INVALID, msg, data]
        self.idle_add(self._process_packet_cb, self, invalid_packet)
        # Then hang up:
        self.timeout_add(1000, self._connection_lost, msg)

    def gibberish(self, msg, data):
        """
            Report data that could not be parsed: a GIBBERISH packet carrying
            'msg' and 'data' is scheduled for delivery to the packet handler
            via idle_add, then the connection is dropped one second later.
        """
        gibberish_packet = [Protocol.GIBBERISH, msg, data]
        self.idle_add(self._process_packet_cb, self, gibberish_packet)
        # Then hang up:
        self.timeout_add(1000, self._connection_lost, msg)

    #delegates to invalid_header()
    #(so this can more easily be intercepted and overridden
    # see tcp-proxy)
    def _invalid_header(self, data):
        self.invalid_header(self, data)

    def invalid_header(self, proto, data):
        """
            Build an error message describing the invalid packet header
            (the first 8 bytes of 'data' as hexadecimal, plus an abbreviated
            copy of the buffer if it is longer than one byte)
            and report it via gibberish().
            'proto' is not used here.
        """
        header_hex = binascii.hexlify(data[:8])
        parts = ["invalid packet header: '%s'" % header_hex]
        if len(data) > 1:
            parts.append(" read buffer=%s" % repr_ellipsized(data))
        self.gibberish("".join(parts), data)

    def read_queue_put(self, data):
        """
            Add 'data' to the read queue,
            starting the parse thread first if there isn't one yet
            (unless the protocol is already closed).
        """
        if not (self._read_parser_thread or self._closed):
            #no parse thread yet, create and start it:
            from xpra.make_thread import make_thread
            parser = make_thread(self._read_parse_thread_loop, "parse", daemon=True)
            self._read_parser_thread = parser
            self._read_parser_thread.start()
        self._read_queue.put(data)

    def _read_parse_thread_loop(self):
        log("read_parse_thread_loop starting")
        try:
            self.do_read_parse_thread_loop()
        except Exception as e:
            if self._closed:
                return
            self._internal_error("error in network packet reading/parsing",
                                 e,
                                 exc_info=True)

    def do_read_parse_thread_loop(self):
        """
            Process the individual network packets placed in _read_queue.
            Concatenate the raw packet data, then try to parse it.
            Extract the individual packets from the potentially large buffer,
            saving the rest of the buffer for later, and optionally decompress this data
            and re-construct the one python-object-packet from potentially multiple packets (see packet_index).
            The 8 bytes packet header gives us information on the packet index, packet size and compression.
            The actual processing of the packet is done via the callback process_packet_cb,
            this will be called from this parsing thread so any calls that need to be made
            from the UI thread will need to use a callback (usually via 'idle_add')

            The loop exits when the protocol is closed, when an empty chunk
            is found in the read queue, or when invalid data is received
            (which is reported via invalid() / gibberish() / _internal_error()).
        """
        read_buffer = None
        #state of the packet currently being assembled,
        #a negative payload_size means that we are waiting for a packet header:
        payload_size = -1
        padding_size = 0
        packet_index = 0
        compression_level = False
        packet = None
        #raw chunks (packet_index>0) waiting for their main packet, keyed by index:
        raw_packets = {}
        while not self._closed:
            buf = self._read_queue.get()
            if not buf:
                log("parse thread: empty marker, exiting")
                self.idle_add(self.close)
                return
            if read_buffer:
                read_buffer = read_buffer + buf
            else:
                read_buffer = buf
            bl = len(read_buffer)
            #extract as many packets as the buffer holds:
            while not self._closed:
                packet = None
                bl = len(read_buffer)
                if bl <= 0:
                    break
                if payload_size < 0:
                    #the first byte of the header must be "P"
                    #(compared as a character or as a byte value):
                    if read_buffer[0] not in ("P", ord("P")):
                        self._invalid_header(read_buffer)
                        return
                    if bl < 8:
                        break  #packet still too small
                    #packet format: struct.pack('cBBBL', ...) - 8 bytes
                    _, protocol_flags, compression_level, packet_index, data_size = unpack_header(
                        read_buffer[:8])

                    #sanity check size (will often fail if not an xpra client):
                    if data_size > self.abs_max_packet_size:
                        self._invalid_header(read_buffer)
                        return

                    #number of bytes available after the header:
                    bl = len(read_buffer) - 8
                    if protocol_flags & FLAGS_CIPHER:
                        if self.cipher_in_block_size == 0 or not self.cipher_in_name:
                            cryptolog.warn(
                                "received cipher block but we don't have a cipher to decrypt it with, not an xpra client?"
                            )
                            self._invalid_header(read_buffer)
                            return
                        #encrypted data is padded up to the cipher block size:
                        padding_size = self.cipher_in_block_size - (
                            data_size % self.cipher_in_block_size)
                        payload_size = data_size + padding_size
                    else:
                        #no cipher, no padding:
                        padding_size = 0
                        payload_size = data_size
                    assert payload_size > 0, "invalid payload size: %i" % payload_size
                    read_buffer = read_buffer[8:]

                    if payload_size > self.max_packet_size:
                        #this packet is seemingly too big, but check again from the main UI thread
                        #this gives 'set_max_packet_size' a chance to run from "hello"
                        def check_packet_size(size_to_check, packet_header):
                            if self._closed:
                                return False
                            log("check_packet_size(%s, 0x%s) limit is %s",
                                size_to_check, repr_ellipsized(packet_header),
                                self.max_packet_size)
                            if size_to_check > self.max_packet_size:
                                msg = "packet size requested is %s but maximum allowed is %s" % \
                                              (size_to_check, self.max_packet_size)
                                self.invalid(msg, packet_header)
                            return False

                        self.timeout_add(1000, check_packet_size, payload_size,
                                         read_buffer[:32])

                if bl < payload_size:
                    # incomplete packet, wait for the rest to arrive
                    break

                #chop this packet from the buffer:
                if len(read_buffer) == payload_size:
                    raw_string = read_buffer
                    read_buffer = ''
                else:
                    raw_string = read_buffer[:payload_size]
                    read_buffer = read_buffer[payload_size:]
                #decrypt if needed:
                data = raw_string
                if self.cipher_in and protocol_flags & FLAGS_CIPHER:
                    cryptolog("received %i %s encrypted bytes with %s padding",
                              payload_size, self.cipher_in_name, padding_size)
                    data = self.cipher_in.decrypt(raw_string)
                    if padding_size > 0:

                        def debug_str(s):
                            #hexadecimal if possible, csv of characters otherwise:
                            try:
                                return binascii.hexlify(bytearray(s))
                            except Exception:
                                return csv(list(str(s)))

                        # pad byte value is number of padding bytes added
                        padtext = pad(self.cipher_in_padding, padding_size)
                        if data.endswith(padtext):
                            cryptolog("found %s %s padding",
                                      self.cipher_in_padding,
                                      self.cipher_in_name)
                        else:
                            actual_padding = data[-padding_size:]
                            cryptolog.warn(
                                "Warning: %s decryption failed: invalid padding",
                                self.cipher_in_name)
                            cryptolog(
                                " data does not end with %s padding bytes %s",
                                self.cipher_in_padding, debug_str(padtext))
                            cryptolog(" but with %s (%s)",
                                      debug_str(actual_padding), type(data))
                            cryptolog(" decrypted data: %s",
                                      debug_str(data[:128]))
                            return self._internal_error(
                                "%s encryption padding error - wrong key?" %
                                self.cipher_in_name)
                        data = data[:-padding_size]
                #uncompress if needed:
                if compression_level > 0:
                    try:
                        data = decompress(data, compression_level)
                    except InvalidCompressionException as e:
                        self.invalid("invalid compression: %s" % e, data)
                        return
                    except Exception as e:
                        ctype = compression.get_compression_type(
                            compression_level)
                        log("%s packet decompression failed",
                            ctype,
                            exc_info=True)
                        msg = "%s packet decompression failed" % ctype
                        if self.cipher_in:
                            msg += " (invalid encryption key?)"
                        else:
                            #only include the exception text when not using encryption
                            #as this may leak crypto information:
                            msg += " %s" % e
                        return self.gibberish(msg, data)

                #once a cipher is configured, unencrypted packets are refused:
                if self.cipher_in and not (protocol_flags & FLAGS_CIPHER):
                    self.invalid("unencrypted packet dropped", data)
                    return

                if self._closed:
                    return
                if packet_index > 0:
                    #raw packet, store it and continue:
                    raw_packets[packet_index] = data
                    payload_size = -1
                    packet_index = 0
                    if len(raw_packets) >= 4:
                        self.invalid(
                            "too many raw packets: %s" % len(raw_packets),
                            data)
                        return
                    continue
                #final packet (packet_index==0), decode it:
                try:
                    packet = decode(data, protocol_flags)
                except InvalidPacketEncodingException as e:
                    self.invalid("invalid packet encoding: %s" % e, data)
                    return
                except ValueError as e:
                    etype = packet_encoding.get_packet_encoding_type(
                        protocol_flags)
                    log.error("Error parsing %s packet:", etype)
                    log.error(" %s", e)
                    if self._closed:
                        return
                    log("failed to parse %s packet: %s", etype,
                        binascii.hexlify(data[:128]))
                    log(" %s", e)
                    log(" data: %s", repr_ellipsized(data))
                    log(" packet index=%i, packet size=%i, buffer size=%s",
                        packet_index, payload_size, bl)
                    self.gibberish("failed to parse %s packet" % etype, data)
                    return

                if self._closed:
                    return
                #this packet is complete, the next data must start with a header:
                payload_size = -1
                padding_size = 0
                #add any raw packets back into it:
                if raw_packets:
                    for index, raw_data in raw_packets.items():
                        #replace placeholder with the raw_data packet data:
                        packet[index] = raw_data
                    raw_packets = {}

                packet_type = packet[0]
                #translate numeric aliases back into the real packet type:
                if self.receive_aliases and type(
                        packet_type
                ) == int and packet_type in self.receive_aliases:
                    packet_type = self.receive_aliases.get(packet_type)
                    packet[0] = packet_type
                #count the packets received per type
                #(this must be based on the input counters, not on output_stats):
                self.input_stats[packet_type] = self.input_stats.get(
                    packet_type, 0) + 1

                self.input_packetcount += 1
                log("processing packet %s", packet_type)
                self._process_packet_cb(self, packet)
                packet = None
                INJECT_FAULT(self)

    def flush_then_close(self, last_packet, done_callback=None):
        """ Note: this is best effort only
            the packet may not get sent.

            We try to get the write lock,
            we try to wait for the write queue to flush
            we queue our last packet,
            we wait again for the queue to flush,
            then no matter what, we close the connection and stop the threads.

            'done_callback' (if set) is called without arguments, exactly once,
            when we are done: whether the packet was sent or not.
        """
        log("flush_then_close(%s, %s) closed=%s", last_packet, done_callback,
            self._closed)
        #done() can be reached via more than one code path
        #(ie: the packet was sent and then the safety timer fires),
        #this records that the callback has already been fired:
        done_fired = []

        def done():
            if done_fired:
                return
            done_fired.append(True)
            log("flush_then_close: done, callback=%s", done_callback)
            if done_callback:
                done_callback()

        if self._closed:
            log("flush_then_close: already closed")
            return done()

        def wait_for_queue(timeout=10):
            #IMPORTANT: if we are here, we have the write lock held!
            if not self._write_queue.empty():
                #write queue still has stuff in it..
                if timeout <= 0:
                    log("flush_then_close: queue still busy, closing without sending the last packet"
                        )
                    self._write_lock.release()
                    self.close()
                    done()
                else:
                    log("flush_then_close: still waiting for queue to flush")
                    #try again in 100ms, 'timeout' counts the remaining attempts:
                    self.timeout_add(100, wait_for_queue, timeout - 1)
            else:
                log("flush_then_close: queue is now empty, sending the last packet and closing"
                    )
                chunks = self.encode(last_packet)

                def close_and_release():
                    log("flush_then_close: wait_for_packet_sent() close_and_release()"
                        )
                    self.close()
                    #this can run more than once,
                    #the lock may have been released already:
                    try:
                        self._write_lock.release()
                    except Exception:
                        pass
                    done()

                def wait_for_packet_sent():
                    log(
                        "flush_then_close: wait_for_packet_sent() queue.empty()=%s, closed=%s",
                        self._write_queue.empty(), self._closed)
                    if self._write_queue.empty() or self._closed:
                        #it got sent, we're done!
                        close_and_release()
                        return False
                    return not self._closed  #run until we manage to close (here or via the timeout)

                def packet_queued(*args):
                    #if we're here, we have the lock and the packet is in the write queue
                    log("flush_then_close: packet_queued() closed=%s",
                        self._closed)
                    if wait_for_packet_sent():
                        #check again every 100ms
                        self.timeout_add(100, wait_for_packet_sent)

                self._add_chunks_to_queue(chunks,
                                          start_send_cb=None,
                                          end_send_cb=packet_queued)
                #just in case wait_for_packet_sent never fires:
                self.timeout_add(5 * 1000, close_and_release)

        def wait_for_write_lock(timeout=100):
            #non-blocking attempt, 'timeout' counts the remaining attempts:
            if not self._write_lock.acquire(False):
                if timeout <= 0:
                    log("flush_then_close: timeout waiting for the write lock")
                    self.close()
                    done()
                else:
                    log(
                        "flush_then_close: write lock is busy, will retry %s more times",
                        timeout)
                    self.timeout_add(10, wait_for_write_lock, timeout - 1)
            else:
                log("flush_then_close: acquired the write lock")
                #we have the write lock - we MUST free it!
                wait_for_queue()

        #normal codepath:
        # -> wait_for_write_lock
        # -> wait_for_queue
        # -> _add_chunks_to_queue
        # -> packet_queued
        # -> wait_for_packet_sent
        # -> close_and_release
        log("flush_then_close: wait_for_write_lock()")
        wait_for_write_lock()

    def close(self):
        """
            Close the protocol and its connection (only the first call does anything):
            * mark ourselves as closed
            * schedule a CONNECTION_LOST packet for the packet handler
            * close the connection, logging the traffic statistics if enabled
            * make the queue threads exit
            * schedule clean() to drop the remaining references
        """
        log("Protocol.close() closed=%s, connection=%s", self._closed,
            self._conn)
        if self._closed:
            return
        self._closed = True
        self.idle_add(self._process_packet_cb, self,
                      [Protocol.CONNECTION_LOST])
        #use our own reference to the connection throughout,
        #so that we don't depend on self._conn remaining set:
        c = self._conn
        if c:
            try:
                log("Protocol.close() calling %s", c.close)
                c.close()
                if self._log_stats is None and c.input_bytecount == 0 and c.output_bytecount == 0:
                    #no data sent or received, skip logging of stats:
                    self._log_stats = False
                if self._log_stats:
                    from xpra.simple_stats import std_unit, std_unit_dec
                    log.info(
                        "connection closed after %s packets received (%s bytes) and %s packets sent (%s bytes)",
                        std_unit(self.input_packetcount),
                        std_unit_dec(c.input_bytecount),
                        std_unit(self.output_packetcount),
                        std_unit_dec(c.output_bytecount))
            except Exception:
                log.error("error closing %s", c, exc_info=True)
            self._conn = None
        self.terminate_queue_threads()
        self.idle_add(self.clean)
        log("Protocol.close() done")

    def steal_connection(self, read_callback=None):
        """
            Detach the connection from this protocol and return it,
            so that it can be re-used somewhere else
            (frees all protocol threads and resources).

            Note: this method can only be used with non-blocking sockets,
            and if more than one packet can arrive, the 'read_callback' should be used
            to ensure that no packets get lost: it replaces the function
            used to add incoming data to the read queue.
            The caller must call wait_for_io_threads_exit() to ensure that this
            class is no longer reading from the connection before it can re-use it.
        """
        assert not self._closed, "cannot steal a closed connection"
        if read_callback:
            self._read_queue_put = read_callback
        stolen = self._conn
        #from now on, this protocol instance behaves as a closed one:
        self._closed = True
        self._conn = None
        if stolen:
            #this ensures that we exit the untilConcludes() read/write loop
            stolen.set_active(False)
        self.terminate_queue_threads()
        return stolen

    def clean(self):
        """
            Clear the references to the callbacks, the encoder and the threads,
            to ensure we can get garbage collected quickly.
        """
        for attr in ("_get_packet_cb",
                     "_encoder",
                     "_write_thread",
                     "_read_thread",
                     "_read_parser_thread",
                     "_write_format_thread",
                     "_process_packet_cb"):
            setattr(self, attr, None)

    def terminate_queue_threads(self):
        """
            Make the format thread and all the queue based threads exit.

            The write queue and the read queue are both replaced with a queue
            pre-filled with None (the exit marker),
            the old queues are emptied and also get one exit marker,
            so that a thread still using an old queue will see it too.
            This is best effort: errors are ignored.
        """
        log("terminate_queue_threads()")
        #the format thread will exit:
        self._get_packet_cb = None
        self._source_has_more.set()
        #make all the queue based threads exit by adding the empty marker:
        exit_queue = Queue()
        for _ in range(10):  #just 2 should be enough!
            exit_queue.put(None)

        def swap_queue(attr):
            #replace the queue stored in 'attr' with the exit queue,
            #then flush the old one:
            try:
                old_queue = getattr(self, attr)
                setattr(self, attr, exit_queue)
                #discard all elements in the old queue and push the None marker:
                try:
                    while old_queue.qsize() > 0:
                        old_queue.get(False)
                except Exception:
                    #another thread may have emptied the queue before us
                    pass
                old_queue.put_nowait(None)
            except Exception:
                pass

        swap_queue("_write_queue")
        swap_queue("_read_queue")
        #just in case the read thread is waiting again:
        self._source_has_more.set()
Exemplo n.º 39
0
 def __init__(self):
     """ Initialise the parent class, then create the queue used by this instance. """
     super(Invoker, self).__init__()
     #created without a size argument:
     self.queue = Queue()
Exemplo n.º 40
0
    def run(self):
        """
            Main entry point of the proxy process:
            * switch to the configured group id and user id
            * apply the environment options and initialise video
            * register the signal handlers and start the helper threads
            * wrap the client and server connections with Protocol instances
            * forward the hello packet to the server
            * run the main queue until it ends or we are interrupted
        """
        log("ProxyProcess.run() pid=%s, uid=%s, gid=%s", os.getpid(), os.getuid(), os.getgid())
        #change uid and gid (the group first, then the user):
        if self.gid!=os.getgid():
            os.setgid(self.gid)
        if self.uid!=os.getuid():
            os.setuid(self.uid)
        log("ProxyProcess.run() new uid=%s, gid=%s", os.getuid(), os.getgid())

        if self.env_options:
            #TODO: whitelist env update?
            os.environ.update(self.env_options)
        self.video_init()

        log.info("new proxy started for client %s and server %s", self.client_conn, self.server_conn)

        for signum in (signal.SIGTERM, signal.SIGINT):
            signal.signal(signum, self.signal_quit)
        log("registered signal handler %s", self.signal_quit)

        message_queue_thread = make_daemon_thread(self.server_message_queue, "server message queue")
        message_queue_thread.start()

        if self.create_control_socket():
            self.control_socket_thread = make_daemon_thread(self.control_socket_loop, "control")
            self.control_socket_thread.start()

        self.main_queue = Queue()
        #setup protocol wrappers:
        self.server_packets = Queue(PROXY_QUEUE_SIZE)
        self.client_packets = Queue(PROXY_QUEUE_SIZE)
        self.client_protocol = Protocol(self, self.client_conn, self.process_client_packet, self.get_client_packet)
        self.client_protocol.restore_state(self.client_state)
        self.server_protocol = Protocol(self, self.server_conn, self.process_server_packet, self.get_server_packet)
        #server connection tweaks:
        #(packet types added to the server protocol's list of large packets)
        for packet_type in ("draw", "window-icon", "keymap-changed", "server-settings"):
            self.server_protocol.large_packets.append(packet_type)
        compression_level = self.session_options.get("compression_level", 0)
        self.server_protocol.set_compression_level(compression_level)
        self.server_protocol.enable_default_encoder()

        self.lost_windows = set()
        self.encode_queue = Queue()
        self.encode_thread = make_daemon_thread(self.encode_loop, "encode")
        self.encode_thread.start()

        log("starting network threads")
        self.server_protocol.start()
        self.client_protocol.start()

        #forward the hello packet:
        filtered_caps = self.filter_client_caps(self.caps)
        self.queue_server_packet(("hello", filtered_caps))
        self.timeout_add(VIDEO_TIMEOUT*1000, self.timeout_video_encoders)

        try:
            self.run_queue()
        except KeyboardInterrupt as e:
            self.stop(str(e))
        finally:
            log("ProxyProcess.run() ending %s", os.getpid())
Exemplo n.º 41
0
class subprocess_callee(object):
    """
    This is the callee side, wrapping the gobject we want to interact with.
    All the input received will be converted to method calls on the wrapped object.
    Subclasses should register the signal handlers they want to see exported back to the caller.
    The convenience connect_export(signal-name, *args) can be used to forward signals unmodified.
    You can also call send() to pass packets back to the caller.
    (there is no validation of which signals are valid or not)
    """
    def __init__(self, input_filename="-", output_filename="-", wrapped_object=None, method_whitelist=None):
        """
            'input_filename' and 'output_filename' are the files the protocol
            reads from and writes to, "-" means stdin and stdout respectively.
            'wrapped_object' is the object that the incoming packets are
            converted to method calls on.
            'method_whitelist', if not None, restricts the methods that can be called.
        """
        self.name = ""
        self._input = None
        self._output = None
        self.input_filename = input_filename
        self.output_filename = output_filename
        self.method_whitelist = method_whitelist
        self.large_packets = []
        #the gobject instance which is wrapped:
        self.wrapped_object = wrapped_object
        #packets waiting to be sent to the caller, see send() and get_packet():
        self.send_queue = Queue()
        self.protocol = None
        register_os_signals(self.handle_signal)
        self.setup_mainloop()

    def setup_mainloop(self):
        """ Create the glib main loop and expose the glib scheduling functions on this instance. """
        glib = import_glib()
        self.mainloop = glib.MainLoop()
        self.idle_add = glib.idle_add
        self.timeout_add = glib.timeout_add
        self.source_remove = glib.source_remove

    def connect_export(self, signal_name, *user_data):
        """ gobject style signal registration for the wrapped object,
            the signals will automatically be forwarded to the wrapper process
            using send(signal_name, *signal_args, *user_data)
        """
        log("connect_export%s", [signal_name] + list(user_data))
        #the signal name goes last, so that export() can find it:
        args = list(user_data) + [signal_name]
        self.wrapped_object.connect(signal_name, self.export, *args)

    def export(self, *args):
        """
            Signal handler registered by connect_export():
            forwards the signal to the caller using send().
            The last argument is the signal name,
            the first argument is not forwarded
            (presumably the object emitting the signal - verify against the gobject signal API).
        """
        signal_name = args[-1]
        log("export(%s, ...)", signal_name)
        data = args[1:-1]
        self.send(signal_name, *tuple(data))

    def start(self):
        """
            Create and start the protocol, then run the main loop until it exits.
            Whatever happens, cleanup() is called, the protocol is closed
            and so are the input and output files.

            Returns 0 on normal exit or KeyboardInterrupt,
            1 if the main loop failed with another exception.
        """
        self.protocol = self.make_protocol()
        self.protocol.start()
        try:
            self.run()
            return 0
        except KeyboardInterrupt as e:
            log("start() KeyboardInterrupt %s", e)
            if str(e):
                log.warn("%s", e)
            return 0
        except Exception:
            log.error("error in main loop", exc_info=True)
            return 1
        finally:
            log("run() ended, calling cleanup and protocol close")
            self.cleanup()
            if self.protocol:
                self.protocol.close()
                self.protocol = None
            #close the input, then the output
            #(clearing the attribute before attempting to close):
            for attr in ("_input", "_output"):
                stream = getattr(self, attr)
                if stream:
                    setattr(self, attr, None)
                    try:
                        stream.close()
                    except (OSError, IOError):
                        log("%s.close()", stream, exc_info=True)

    def make_protocol(self):
        """
            Open the input and output files (stdin / stdout for "-"),
            wrap them in a connection and return a new Protocol instance using it.
        """
        #figure out where we read from and write to:
        if self.input_filename=="-":
            #disable stdin buffering:
            self._input = os.fdopen(sys.stdin.fileno(), 'rb', 0)
            setbinarymode(self._input.fileno())
        else:
            self._input = open(self.input_filename, 'rb')
        if self.output_filename=="-":
            #disable stdout buffering:
            self._output = os.fdopen(sys.stdout.fileno(), 'wb', 0)
            setbinarymode(self._output.fileno())
        else:
            self._output = open(self.output_filename, 'wb')
        #stdin and stdout wrapper:
        conn = TwoFileConnection(self._output, self._input,
                                 abort_test=None, target=self.name,
                                 socktype=self.name, close_cb=self.net_stop)
        conn.timeout = 0
        protocol = Protocol(self, conn, self.process_packet, get_packet_cb=self.get_packet)
        setup_fastencoder_nocompression(protocol)
        protocol.large_packets = self.large_packets
        return protocol

    def run(self):
        """ Run the main loop. """
        self.mainloop.run()

    def net_stop(self):
        """ Schedule stop() to run via idle_add. """
        #this is called from the network thread,
        #we use idle add to ensure we clean things up from the main thread
        log("net_stop() will call stop from main thread")
        self.idle_add(self.stop)

    def cleanup(self):
        """ Does nothing here (presumably a hook for subclasses). """
        pass

    def stop(self):
        """ Call cleanup(), close the protocol (if we still have one) and stop the main loop. """
        self.cleanup()
        p = self.protocol
        log("stop() protocol=%s", p)
        if p:
            self.protocol = None
            p.close()
        self.do_stop()

    def do_stop(self):
        """ Stop the main loop. """
        log("stop() stopping mainloop %s", self.mainloop)
        self.mainloop.quit()

    def handle_signal(self, sig):
        """ This is for OS signals SIGINT and SIGTERM """
        #next time, just stop:
        register_os_signals(self.signal_stop)
        signame = SIGNAMES.get(sig, sig)
        log("handle_signal(%s) calling stop from main thread", signame)
        self.send("signal", signame)
        self.timeout_add(0, self.cleanup)
        #give time for the network layer to send the signal message
        self.timeout_add(150, self.stop)

    def signal_stop(self, sig):
        """ This time we really want to exit without waiting """
        signame = SIGNAMES.get(sig, sig)
        log("signal_stop(%s) calling stop", signame)
        self.stop()

    def send(self, *args):
        """
            Queue a packet for sending to the caller:
            the first argument is the packet type,
            the others are the packet's arguments.
            When HEXLIFY_PACKETS is set, the arguments are replaced with
            the hexstr() of the first 32 characters of their string form.
        """
        if HEXLIFY_PACKETS:
            #args is a tuple, so the converted arguments must be a tuple too:
            args = args[:1]+tuple(hexstr(str(x)[:32]) for x in args[1:])
        log("send: adding '%s' message (%s items already in queue)", args[0], self.send_queue.qsize())
        self.send_queue.put(args)
        p = self.protocol
        if p:
            #tell the protocol that get_packet() has something for it:
            p.source_has_more()
        INJECT_FAULT(p)

    def get_packet(self):
        """
            Packet source for the protocol (see make_protocol):
            returns a 4-item tuple made of the next packet from the send queue
            (None if there isn't one), two None values,
            and a boolean telling if more packets are waiting in the queue.
        """
        try:
            item = self.send_queue.get(False)
        except Exception:
            item = None
        return (item, None, None, self.send_queue.qsize()>0)

    def process_packet(self, proto, packet):
        """
            Handle a packet received from the caller.
            The connection-lost, gibberish and "stop" packets make us stop,
            "exit" calls sys.exit(0),
            any other packet type is converted to a method name
            (with "-" replaced by "_") and the method of the wrapped object
            with that name is scheduled via idle_add with the packet's arguments.
            Methods not found in the whitelist (if there is one)
            or not found on the wrapped object are ignored with a warning.
        """
        command = bytestostr(packet[0])
        if command==Protocol.CONNECTION_LOST:
            log("connection-lost: %s, calling stop", packet[1:])
            self.net_stop()
            return
        if command==Protocol.GIBBERISH:
            log.warn("gibberish received:")
            log.warn(" %s", repr_ellipsized(packet[1], limit=80))
            log.warn(" stopping")
            self.net_stop()
            return
        if command=="stop":
            log("received stop message")
            self.net_stop()
            return
        if command=="exit":
            log("received exit message")
            #raises SystemExit, nothing after this line would run:
            sys.exit(0)
        #make it easier to hookup signals to methods:
        attr = command.replace("-", "_")
        if self.method_whitelist is not None and attr not in self.method_whitelist:
            log.warn("invalid command: %s (not in whitelist: %s)", attr, self.method_whitelist)
            return
        wo = self.wrapped_object
        if not wo:
            log("wrapped object is no more, ignoring method call '%s'", attr)
            return
        method = getattr(wo, attr, None)
        if not method:
            log.warn("unknown command: '%s'", attr)
            log.warn(" packet: '%s'", repr_ellipsized(str(packet)))
            return
        if DEBUG_WRAPPER:
            log("calling %s.%s%s", wo, attr, str(tuple(packet[1:]))[:128])
        self.idle_add(method, *packet[1:])
        INJECT_FAULT(proto)
Exemplo n.º 42
0
class Protocol(object):
    # Special packet types which are handed to the process-packet callback:
    # invalid() and gibberish() generate the INVALID and GIBBERISH ones,
    # CONNECTION_LOST is matched by packet handlers
    # (where it is generated is not visible from here)
    CONNECTION_LOST = "connection-lost"
    GIBBERISH = "gibberish"
    INVALID = "invalid"

    def __init__(self, scheduler, conn, process_packet_cb, get_packet_cb=None):
        """
            Initialize the protocol state and create the four worker threads
            (write, read, parse and format), which are only started by start().
            You must call this constructor and source_has_more() from the main thread.

            scheduler must provide 'timeout_add' and 'idle_add',
            conn is the connection object used for reading and writing,
            process_packet_cb is called with (protocol, packet) for each packet received,
            get_packet_cb (optional) is called by the format thread to obtain the next packet to send.
        """
        assert scheduler is not None
        assert conn is not None
        self.timeout_add = scheduler.timeout_add
        self.idle_add = scheduler.idle_add
        self._conn = conn
        self._closed = False
        # optionally route the incoming packets through the jitter simulator:
        packet_handler = process_packet_cb
        if FAKE_JITTER > 0:
            from xpra.net.fake_jitter import FakeJitter

            packet_handler = FakeJitter(self.timeout_add, process_packet_cb).process_packet_cb
        self._process_packet_cb = packet_handler
        self._get_packet_cb = get_packet_cb
        # set whenever the packet source may have more packets for us:
        self._source_has_more = threading.Event()
        # queues (and the lock protecting additions to the write queue):
        self._write_lock = Lock()
        self._write_queue = Queue(1)
        read_queue = Queue(20)
        self._read_queue = read_queue
        self._read_queue_put = read_queue.put
        # counters:
        self.input_stats, self.output_stats = {}, {}
        self.input_packetcount = self.input_raw_packetcount = 0
        self.output_packetcount = self.output_raw_packetcount = 0
        # initial value which may get increased by client/server after handshake:
        self.max_packet_size = 256 * 1024
        self.abs_max_packet_size = 256 * 1024 * 1024
        self.large_packets = ["hello"]
        self.send_aliases, self.receive_aliases = {}, {}
        self._log_stats = None  # None here means auto-detect
        # packet encoding and compression:
        self.encoder, self._encoder = "none", self.noencode
        self.compressor, self._compress = "none", compression.nocompress
        self.compression_level = 0
        # encryption is disabled until set_cipher_in / set_cipher_out are called:
        self.cipher_in = self.cipher_out = None
        self.cipher_in_name = self.cipher_out_name = None
        self.cipher_in_block_size = self.cipher_out_block_size = 0
        # worker threads:
        from xpra.daemon_thread import make_daemon_thread

        self._write_thread = make_daemon_thread(self._write_thread_loop, "write")
        self._read_thread = make_daemon_thread(self._read_thread_loop, "read")
        self._read_parser_thread = make_daemon_thread(self._read_parse_thread_loop, "parse")
        self._write_format_thread = make_daemon_thread(self._write_format_thread_loop, "format")

    # names of the attributes copied by save_state() and re-applied by restore_state():
    STATE_FIELDS = (
        "max_packet_size",
        "large_packets",
        "send_aliases",
        "receive_aliases",
        "cipher_in",
        "cipher_in_name",
        "cipher_in_block_size",
        "cipher_out",
        "cipher_out_name",
        "cipher_out_block_size",
        "compression_level",
        "encoder",
        "compressor",
    )

    def save_state(self):
        state = {}
        for x in Protocol.STATE_FIELDS:
            state[x] = getattr(self, x)
        return state

    def restore_state(self, state):
        assert state is not None
        for x in Protocol.STATE_FIELDS:
            assert x in state, "field %s is missing" % x
            setattr(self, x, state[x])
        # special handling for compressor / encoder which are named objects:
        self.enable_compressor(self.compressor)
        self.enable_encoder(self.encoder)

    def wait_for_io_threads_exit(self, timeout=None):
        for t in (self._read_thread, self._write_thread):
            t.join(timeout)
        exited = True
        for t in (self._read_thread, self._write_thread):
            if t.isAlive():
                log.warn("%s thread of %s has not yet exited (timeout=%s)", t.name, self._conn, timeout)
                exited = False
                break
        return exited

    def set_packet_source(self, get_packet_cb):
        """ Sets the callback which the format thread calls to obtain the next packet to send. """
        self._get_packet_cb = get_packet_cb

    def set_cipher_in(self, ciphername, iv, password, key_salt, iterations):
        """
            Configure the cipher used for decrypting incoming packets.
            The arguments are passed to get_cipher(), which provides the
            cipher object and its block size.
            The change of cipher name is logged once, the password is never
            written to the log.
        """
        if self.cipher_in_name != ciphername:
            log.info("receiving data using %s encryption", ciphername)
            self.cipher_in_name = ciphername
        # never write the secret itself to the debug log:
        masked_password = "********" if password else password
        log("set_cipher_in%s", (ciphername, iv, masked_password, key_salt, iterations))
        self.cipher_in, self.cipher_in_block_size = get_cipher(ciphername, iv, password, key_salt, iterations)

    def set_cipher_out(self, ciphername, iv, password, key_salt, iterations):
        """
            Configure the cipher used for encrypting outgoing packets.
            The arguments are passed to get_cipher(), which provides the
            cipher object and its block size.
            The change of cipher name is logged once, the password is never
            written to the log.
        """
        if self.cipher_out_name != ciphername:
            log.info("sending data using %s encryption", ciphername)
            self.cipher_out_name = ciphername
        # never write the secret itself to the debug log:
        masked_password = "********" if password else password
        log("set_cipher_out%s", (ciphername, iv, masked_password, key_salt, iterations))
        self.cipher_out, self.cipher_out_block_size = get_cipher(ciphername, iv, password, key_salt, iterations)

    def __repr__(self):
        """ Identifies this protocol instance by the connection it wraps. """
        return "Protocol(%s)" % self._conn

    def get_threads(self):
        return [
            x
            for x in [self._write_thread, self._read_thread, self._read_parser_thread, self._write_format_thread]
            if x is not None
        ]

    def get_info(self, alias_info=True):
        """
            Returns a dictionary describing the state of this protocol:
            packet counters, cipher names, size limits, the compressor and
            encoder in use, the connection's own information and the status
            of each worker thread.
            If alias_info is True, the packet aliases are included
            (in both directions: name to alias and alias to name).
        """
        info = {
            "input.packetcount": self.input_packetcount,
            "input.raw_packetcount": self.input_raw_packetcount,
            "input.cipher": self.cipher_in_name or "",
            "output.packetcount": self.output_packetcount,
            "output.raw_packetcount": self.output_raw_packetcount,
            "output.cipher": self.cipher_out_name or "",
            "large_packets": self.large_packets,
            "compression_level": self.compression_level,
            "max_packet_size": self.max_packet_size,
        }
        updict(info, "input.count", self.input_stats)
        updict(info, "output.count", self.output_stats)
        # work on snapshots: other threads may change these attributes
        # between the test and the use
        c = self._compress
        if c:
            info["compressor"] = compression.get_compressor_name(c)
        e = self._encoder
        if e:
            if e == self.noencode:
                info["encoder"] = "noencode"
            else:
                info["encoder"] = packet_encoding.get_encoder_name(e)
        if alias_info:
            for k, v in self.send_aliases.items():
                info["send_alias." + str(k)] = v
                info["send_alias." + str(v)] = k
            for k, v in self.receive_aliases.items():
                info["receive_alias." + str(k)] = v
                info["receive_alias." + str(v)] = k
        conn = self._conn
        if conn:
            try:
                info.update(conn.get_info())
            except Exception:
                # the connection details are optional, report the problem and carry on:
                log.error("error collecting connection information on %s", conn, exc_info=True)
        info["has_more"] = self._source_has_more.is_set()
        for t in (self._write_thread, self._read_thread, self._read_parser_thread, self._write_format_thread):
            if t:
                info["thread.%s" % t.name] = t.is_alive()
        return info

    def start(self):
        def do_start():
            if not self._closed:
                self._write_thread.start()
                self._read_thread.start()
                self._read_parser_thread.start()
                self._write_format_thread.start()

        self.idle_add(do_start)

    def send_now(self, packet):
        """
            Send a single packet when no packet source is installed.
            A temporary packet source is installed which hands out the packet
            once and removes itself, then the format thread is woken up.
            Does nothing (apart from logging) if the protocol is closed.
        """
        if self._closed:
            log("send_now(%s ...) connection is closed already, not sending", packet[0])
            return
        log("send_now(%s ...)", packet[0])
        # identity test: this is what "no packet source" means
        assert self._get_packet_cb is None, (
            "cannot use send_now when a packet source exists! (set to %s)" % self._get_packet_cb
        )

        def packet_cb():
            # one-shot: remove ourselves before handing out the packet
            self._get_packet_cb = None
            return (packet,)

        self._get_packet_cb = packet_cb
        self.source_has_more()

    def source_has_more(self):
        """ Sets the event which the format thread waits on before calling the packet source. """
        self._source_has_more.set()

    def _write_format_thread_loop(self):
        """
            Main loop of the "format" thread:
            wait until the packet source is flagged as having more data,
            fetch one packet from it and pass it to _add_packet_to_queue.
            The loop exits when the protocol is closed or when the packet
            source has been removed, any error is reported via _internal_error.
        """
        log("write_format_thread_loop starting")
        try:
            while not self._closed:
                self._source_has_more.wait()
                # take a snapshot: another thread may remove the packet source
                # (and set the event to wake us up) to make this thread exit
                get_packet_cb = self._get_packet_cb
                if self._closed or get_packet_cb is None:
                    return
                self._source_has_more.clear()
                self._add_packet_to_queue(*get_packet_cb())
        except:
            # top-level boundary of the thread: report everything
            self._internal_error("error in network packet write/format", True)

    def _add_packet_to_queue(self, packet, start_send_cb=None, end_send_cb=None, has_more=False):
        if has_more:
            self._source_has_more.set()
        if packet is None:
            return
        log("add_packet_to_queue(%s ...)", packet[0])
        chunks, proto_flags = self.encode(packet)
        with self._write_lock:
            if self._closed:
                return
            self._add_chunks_to_queue(chunks, proto_flags, start_send_cb, end_send_cb)

    def _add_chunks_to_queue(self, chunks, proto_flags, start_send_cb=None, end_send_cb=None):
        """
            Turn the chunks (index, level, data) into write queue items
            (buffer, start callback, end callback) and queue them as one list.
            Each chunk is encrypted first if an output cipher is set, and gets
            a header unless the protocol flags include FLAGS_NOHEADER.
            The write_lock must be held when calling this function.
        """
        queue_items = []
        for position, (index, level, data) in enumerate(chunks):
            # fire the start_send_callback just before the first chunk is processed:
            first_cb = start_send_cb if position == 0 else None
            # fire the end_send callback when the last chunk (index==0) makes it out:
            last_cb = end_send_cb if index == 0 else None
            payload_size = len(data)
            actual_size = payload_size
            if self.cipher_out:
                proto_flags |= FLAGS_CIPHER
                # pad the data up to the cipher block size,
                # so the size on the wire differs from the payload size:
                block_size = self.cipher_out_block_size
                padding = (block_size - len(data) % block_size) * " "
                padded = data + padding if len(padding) > 0 else data
                actual_size = payload_size + len(padding)
                assert len(padded) == actual_size
                data = self.cipher_out.encrypt(padded)
                assert len(data) == actual_size
                log("sending %s bytes encrypted with %s padding", payload_size, len(padding))
            if proto_flags & FLAGS_NOHEADER:
                # for plain/text packets (ie: gibberish response)
                queue_items.append((data, first_cb, last_cb))
            elif actual_size < PACKET_JOIN_SIZE:
                # small enough: send the header and the data as a single buffer
                if type(data) not in JOIN_TYPES:
                    data = bytes(data)
                joined = pack_header(proto_flags, level, index, payload_size) + data
                queue_items.append((joined, first_cb, last_cb))
            else:
                # large chunk: the header and the data are written separately
                header = pack_header(proto_flags, level, index, payload_size)
                queue_items.append((header, first_cb, None))
                queue_items.append((strtobytes(data), None, last_cb))
        self._write_queue.put(queue_items)
        self.output_packetcount += 1

    def verify_packet(self, packet):
        """ look for None values which may have caused the packet to fail encoding """
        if type(packet) != list:
            return
        assert len(packet) > 0
        tree = ["'%s' packet" % packet[0]]
        self.do_verify_packet(tree, packet)

    def do_verify_packet(self, tree, packet):
        def err(msg):
            log.error("%s in %s", msg, "->".join(tree))

        def new_tree(append):
            nt = tree[:]
            nt.append(append)
            return nt

        if packet is None:
            return err("None value")
        if type(packet) == list:
            for i, x in enumerate(packet):
                self.do_verify_packet(new_tree("[%s]" % i), x)
        elif type(packet) == dict:
            for k, v in packet.items():
                self.do_verify_packet(new_tree("key for value='%s'" % str(v)), k)
                self.do_verify_packet(new_tree("value for key='%s'" % str(k)), v)

    def enable_default_encoder(self):
        """ Enables the first of the enabled packet encoders, at least one must be available. """
        available = packet_encoding.get_enabled_encoders()
        assert len(available) > 0, "no packet encoders available!"
        default_encoder = available[0]
        self.enable_encoder(default_encoder)

    def enable_encoder_from_caps(self, caps):
        """
            Enables the first packet encoder (in performance order)
            which the capabilities flag as supported.
            "bencode" is assumed to be supported unless the capabilities say otherwise.
            Returns True if an encoder was enabled, False if none matched.
        """
        candidates = packet_encoding.get_enabled_encoders(order=packet_encoding.PERFORMANCE_ORDER)
        log("enable_encoder_from_caps(..) options=%s", candidates)
        for name in candidates:
            if not caps.boolget(name, name == "bencode"):
                continue
            self.enable_encoder(name)
            return True
        log.error("no matching packet encoder found!")
        return False

    def enable_encoder(self, e):
        """ Looks up the packet encoder named 'e' and uses it for outgoing packets from now on. """
        encoder_fn = packet_encoding.get_encoder(e)
        self._encoder = encoder_fn
        self.encoder = e
        log("enable_encoder(%s): %s", e, encoder_fn)

    def enable_default_compressor(self):
        """ Enables the first of the enabled compressors, or "none" if there aren't any. """
        available = compression.get_enabled_compressors()
        default_compressor = available[0] if len(available) > 0 else "none"
        self.enable_compressor(default_compressor)

    def enable_compressor_from_caps(self, caps):
        """
            Enables the first compressor (in performance order)
            which the capabilities flag as supported.
            Compression is disabled ("none") when the compression level is zero
            or when no compressor matches, a warning is logged in the latter case.
        """
        chosen = "none"
        if self.compression_level != 0:
            candidates = compression.get_enabled_compressors(order=compression.PERFORMANCE_ORDER)
            log("enable_compressor_from_caps(..) options=%s", candidates)
            for name in candidates:  # ie: [zlib, lz4, lzo]
                if caps.boolget(name):
                    chosen = name
                    break
            else:
                log.warn("compression disabled: no matching compressor found")
        self.enable_compressor(chosen)

    def enable_compressor(self, compressor):
        """ Looks up the compressor with the name given and uses it for outgoing packets from now on. """
        compress_fn = compression.get_compressor(compressor)
        self._compress = compress_fn
        self.compressor = compressor
        log("enable_compressor(%s): %s", compressor, compress_fn)

    def noencode(self, data):
        """
            Fallback "encoder" for clients that don't understand the xpra packet format:
            the items are converted to strings, joined with ": " and terminated
            by a newline.
            Returns the text (as latin1 encoded bytes on python 3) and the FLAGS_NOHEADER flag.
        """
        text = ": ".join(str(x) for x in data) + "\n"
        if sys.version_info[0] >= 3:
            import codecs

            # the joined text is always a unicode string here, convert it to bytes:
            text = codecs.latin_1_encode(text)[0]
        return text, FLAGS_NOHEADER

    def encode(self, packet_in):
        """
        Given a packet (tuple or list of items), converts it for the wire.
        This method returns all the binary packets to send, as an array of:
        (index, compression_level and compression flags, binary_data)
        The index, if positive indicates the item to populate in the packet
        whose index is zero.
        ie: ["blah", [large binary data], "hello", 200]
        may get converted to:
        [
            (1, compression_level, [large binary data now zlib compressed]),
            (0,                 0, bencoded/rencoded(["blah", '', "hello", 200]))
        ]
        The return value is the pair (packets, flags) where flags is the second
        value returned by the packet encoder.
        If encoding fails, the packet is logged and checked for None values
        before the exception is re-raised, unless the protocol is closed
        already in which case ([], 0) is returned instead.
        """
        packets = []
        packet = list(packet_in)
        level = self.compression_level
        # thresholds for the main packet, inlined items raise both:
        size_check = LARGE_PACKET_SIZE
        min_comp_size = 378
        for i in range(1, len(packet)):
            item = packet[i]
            ti = type(item)
            if ti in (int, long, bool, dict, list, tuple):
                continue
            l = len(item)
            if ti == Uncompressed:
                # this is a marker used to tell us we should compress it now
                # (used by the client for clipboard data)
                item = item.compress()
                packet[i] = item
                ti = type(item)
                # (it may now be a "Compressed" item and be processed further)
            if ti in (Compressed, LevelCompressed):
                # already compressed data (usually pixels, cursors, etc)
                if not item.can_inline or l > INLINE_SIZE:
                    il = 0
                    if ti == LevelCompressed:
                        # unlike Compressed (usually pixels, decompressed in the paint thread),
                        # LevelCompressed is decompressed by the network layer
                        # so we must tell it how to do that and pass the level flag
                        il = item.level
                    packets.append((i, il, item.data))
                    packet[i] = ""
                else:
                    # data is small enough, inline it:
                    packet[i] = item.data
                    min_comp_size += l
                    size_check += l
            elif ti in (str, bytes) and level > 0 and l > LARGE_PACKET_SIZE:
                log.warn(
                    "found a large uncompressed item in packet '%s' at position %s: %s bytes", packet[0], i, len(item)
                )
                # add new binary packet with large item:
                cl, cdata = self._compress(item, level)
                packets.append((i, cl, cdata))
                # replace this item with an empty string placeholder:
                packet[i] = ""
            elif ti not in (str, bytes):
                log.warn("unexpected data type %s in %s packet: %s", ti, packet[0], repr_ellipsized(item))
        # now the main packet (or what is left of it):
        packet_type = packet[0]
        self.output_stats[packet_type] = self.output_stats.get(packet_type, 0) + 1
        if USE_ALIASES and self.send_aliases and packet_type in self.send_aliases:
            # replace the packet type with the alias:
            packet[0] = self.send_aliases[packet_type]
        try:
            main_packet, proto_version = self._encoder(packet)
        except Exception:
            if self._closed:
                return [], 0
            log.error("failed to encode packet: %s", packet, exc_info=True)
            # make the error a bit nicer to parse: undo aliases:
            packet[0] = packet_type
            self.verify_packet(packet)
            # bare raise: keeps the traceback of the original failure
            raise
        if len(main_packet) > size_check and packet_in[0] not in self.large_packets:
            log.warn(
                "found large packet (%s bytes): %s, argument types:%s, sizes: %s, packet head=%s",
                len(main_packet),
                packet_in[0],
                [type(x) for x in packet[1:]],
                [len(str(x)) for x in packet[1:]],
                repr_ellipsized(packet),
            )
        # compress, but don't bother for small packets:
        if level > 0 and len(main_packet) > min_comp_size:
            cl, cdata = self._compress(main_packet, level)
            packets.append((0, cl, cdata))
        else:
            packets.append((0, 0, main_packet))
        return packets, proto_version

    def set_compression_level(self, level):
        # this may be used next time encode() is called
        assert level >= 0 and level <= 10, "invalid compression level: %s (must be between 0 and 10" % level
        self.compression_level = level

    def _io_thread_loop(self, name, callback):
        """
            Common loop of the read and write threads:
            keep calling 'callback' until the protocol is closed.
            Errors raised after the protocol has been closed are ignored,
            otherwise connection errors are reported via _internal_error
            and anything else is logged before closing the protocol.
        """
        try:
            log("io_thread_loop(%s, %s) loop starting", name, callback)
            while not self._closed:
                callback()
            log("io_thread_loop(%s, %s) loop ended, closed=%s", name, callback, self._closed)
        except ConnectionClosedException as e:
            if self._closed:
                return
            self._internal_error("%s connection %s closed: %s" % (name, self._conn, e))
        except (OSError, IOError, socket_error) as e:
            if self._closed:
                return
            message = "%s connection %s reset: %s" % (name, self._conn, e)
            # only include the backtrace if the error code is not one of the ABORT ones:
            self._internal_error(message, exc_info=e.args[0] not in ABORT)
        except:
            # can happen during close(), in which case we just ignore:
            if self._closed:
                return
            log.error("%s error on %s", name, self._conn, exc_info=True)
            self.close()

    def _write_thread_loop(self):
        """ Entry point of the "write" thread: runs _write via _io_thread_loop. """
        self._io_thread_loop("write", self._write)

    def _write(self):
        items = self._write_queue.get()
        # Used to signal that we should exit:
        if items is None:
            log("write thread: empty marker, exiting")
            self.close()
            return
        for buf, start_cb, end_cb in items:
            con = self._conn
            if not con:
                return
            if start_cb:
                try:
                    start_cb(con.output_bytecount)
                except:
                    if not self._closed:
                        log.error("error on %s", start_cb, exc_info=True)
            while buf and not self._closed:
                written = con.write(buf)
                if written:
                    buf = buf[written:]
                    self.output_raw_packetcount += 1
            if end_cb:
                try:
                    end_cb(self._conn.output_bytecount)
                except:
                    if not self._closed:
                        log.error("error on %s", end_cb, exc_info=True)

    def _read_thread_loop(self):
        """ Entry point of the "read" thread: runs _read via _io_thread_loop. """
        self._io_thread_loop("read", self._read)

    def _read(self):
        """
            Read one buffer from the connection and hand it to the read queue.
            An empty buffer means the end of the stream: it is still queued
            (the parse thread uses it as exit marker), then the protocol is closed.
        """
        data = self._conn.read(READ_BUFFER_SIZE)
        # log("read thread: got data of size %s: %s", len(data), repr_ellipsized(data))
        # add to the read queue (or whatever takes its place - see steal_connection)
        self._read_queue_put(data)
        if data:
            self.input_raw_packetcount += 1
            return
        log("read thread: eof")
        self.close()

    def _internal_error(self, message="", exc_info=False):
        """
            Log the error (with the exception details if exc_info is set)
            and schedule _connection_lost with the same message via idle_add.
        """
        log.error("internal error: %s", message, exc_info=exc_info)
        self.idle_add(self._connection_lost, message)

    def _connection_lost(self, message="", exc_info=False):
        """
            Log the message and close the protocol.
            Always returns False, this method is used as an idle_add / timeout_add callback.
        """
        log("connection lost: %s", message, exc_info=exc_info)
        self.close()
        return False

    def invalid(self, msg, data):
        self.idle_add(self._process_packet_cb, self, [Protocol.INVALID, msg, data])
        # Then hang up:
        self.timeout_add(1000, self._connection_lost, msg)

    def gibberish(self, msg, data):
        self.idle_add(self._process_packet_cb, self, [Protocol.GIBBERISH, msg, data])
        # Then hang up:
        self.timeout_add(1000, self._connection_lost, msg)

    # delegates to invalid_header()
    # (so this can more easily be intercepted and overridden
    # see tcp-proxy)
    def _invalid_header(self, data):
        """ Called by the parse thread for a bad packet header: passes this protocol and the data to invalid_header. """
        self.invalid_header(self, data)

    def invalid_header(self, proto, data):
        err = "invalid packet header: '%s'" % binascii.hexlify(data[:8])
        if len(data) > 1:
            err += " read buffer=%s" % repr_ellipsized(data)
        self.gibberish(err, data)

    def _read_parse_thread_loop(self):
        """
            Entry point of the "parse" thread: runs do_read_parse_thread_loop
            and reports any error it raises via _internal_error.
        """
        log("read_parse_thread_loop starting")
        try:
            self.do_read_parse_thread_loop()
        except:
            # top-level boundary of the thread: report everything
            self._internal_error("error in network packet reading/parsing", True)

    def do_read_parse_thread_loop(self):
        """
            Process the individual network packets placed in _read_queue.
            Concatenate the raw packet data, then try to parse it.
            Extract the individual packets from the potentially large buffer,
            saving the rest of the buffer for later, and optionally decompress this data
            and re-construct the one python-object-packet from potentially multiple packets (see packet_index).
            The 8 bytes packet header gives us information on the packet index, packet size and compression.
            The actual processing of the packet is done via the callback process_packet_cb,
            this will be called from this parsing thread so any calls that need to be made
            from the UI thread will need to use a callback (usually via 'idle_add')

            The loop ends when the protocol is closed, when an empty buffer
            (the exit marker) is found in the queue, or after reporting
            invalid data via invalid(), gibberish() or _internal_error().
        """
        read_buffer = None
        # a negative payload size means that we are waiting for a packet header:
        payload_size = -1
        padding = None
        packet_index = 0
        compression_level = False
        packet = None
        # chunks with a non-zero index, waiting for their main packet (index zero):
        raw_packets = {}
        while not self._closed:
            buf = self._read_queue.get()
            if not buf:
                log("read thread: empty marker, exiting")
                self.idle_add(self.close)
                return
            if read_buffer:
                read_buffer = read_buffer + buf
            else:
                read_buffer = buf
            bl = len(read_buffer)
            while not self._closed:
                packet = None
                bl = len(read_buffer)
                if bl <= 0:
                    break
                if payload_size < 0:
                    if read_buffer[0] not in ("P", ord("P")):
                        self._invalid_header(read_buffer)
                        return
                    if bl < 8:
                        break  # packet still too small
                    # packet format: struct.pack('cBBBL', ...) - 8 bytes
                    _, protocol_flags, compression_level, packet_index, data_size = unpack_header(read_buffer[:8])

                    # sanity check size (will often fail if not an xpra client):
                    if data_size > self.abs_max_packet_size:
                        self._invalid_header(read_buffer)
                        return

                    bl = len(read_buffer) - 8
                    if protocol_flags & FLAGS_CIPHER:
                        if self.cipher_in_block_size == 0 or not self.cipher_in_name:
                            log.warn(
                                "received cipher block but we don't have a cipher to decrypt it with, not an xpra client?"
                            )
                            self._invalid_header(read_buffer)
                            return
                        padding = (self.cipher_in_block_size - data_size % self.cipher_in_block_size) * " "
                        payload_size = data_size + len(padding)
                    else:
                        # no cipher, no padding:
                        padding = None
                        payload_size = data_size
                    assert payload_size > 0
                    read_buffer = read_buffer[8:]

                    if payload_size > self.max_packet_size:
                        # this packet is seemingly too big, but check again from the main UI thread
                        # this gives 'set_max_packet_size' a chance to run from "hello"
                        def check_packet_size(size_to_check, packet_header):
                            if self._closed:
                                return False
                            log(
                                "check_packet_size(%s, 0x%s) limit is %s",
                                size_to_check,
                                repr_ellipsized(packet_header),
                                self.max_packet_size,
                            )
                            if size_to_check > self.max_packet_size:
                                msg = "packet size requested is %s but maximum allowed is %s" % (
                                    size_to_check,
                                    self.max_packet_size,
                                )
                                self.invalid(msg, packet_header)
                            return False

                        self.timeout_add(1000, check_packet_size, payload_size, read_buffer[:32])

                if bl < payload_size:
                    # incomplete packet, wait for the rest to arrive
                    break

                # chop this packet from the buffer:
                if len(read_buffer) == payload_size:
                    raw_string = read_buffer
                    read_buffer = ""
                else:
                    raw_string = read_buffer[:payload_size]
                    read_buffer = read_buffer[payload_size:]
                # decrypt if needed:
                data = raw_string
                if self.cipher_in and protocol_flags & FLAGS_CIPHER:
                    log("received %s encrypted bytes with %s padding", payload_size, len(padding))
                    data = self.cipher_in.decrypt(raw_string)
                    if padding:

                        def debug_str(s):
                            try:
                                return list(bytearray(s))
                            except:
                                return list(str(s))

                        if not data.endswith(padding):
                            log(
                                "decryption failed: string does not end with '%s': %s (%s) -> %s (%s)",
                                padding,
                                debug_str(raw_string),
                                type(raw_string),
                                debug_str(data),
                                type(data),
                            )
                            self._internal_error("encryption error (wrong key?)")
                            return
                        data = data[: -len(padding)]
                # uncompress if needed:
                if compression_level > 0:
                    try:
                        data = decompress(data, compression_level)
                    except InvalidCompressionException as e:
                        self.invalid("invalid compression: %s" % e, data)
                        return
                    except Exception as e:
                        ctype = compression.get_compression_type(compression_level)
                        log("%s packet decompression failed", ctype, exc_info=True)
                        msg = "%s packet decompression failed" % ctype
                        if self.cipher_in:
                            msg += " (invalid encryption key?)"
                        else:
                            msg += " %s" % e
                        return self.gibberish(msg, data)

                # once a cipher is set, everything must be encrypted:
                if self.cipher_in and not (protocol_flags & FLAGS_CIPHER):
                    self.invalid("unencrypted packet dropped", data)
                    return

                if self._closed:
                    return
                if packet_index > 0:
                    # raw packet, store it and continue:
                    raw_packets[packet_index] = data
                    payload_size = -1
                    packet_index = 0
                    if len(raw_packets) >= 4:
                        self.invalid("too many raw packets: %s" % len(raw_packets), data)
                        return
                    continue
                # final packet (packet_index==0), decode it:
                try:
                    packet = decode(data, protocol_flags)
                except InvalidPacketEncodingException as e:
                    self.invalid("invalid packet encoding: %s" % e, data)
                    return
                except ValueError as e:
                    etype = packet_encoding.get_packet_encoding_type(protocol_flags)
                    log.error("failed to parse %s packet: %s", etype, e, exc_info=not self._closed)
                    if self._closed:
                        return
                    log("failed to parse %s packet: %s", etype, binascii.hexlify(data))
                    msg = "packet index=%s, packet size=%s, buffer size=%s, error=%s" % (
                        packet_index,
                        payload_size,
                        bl,
                        e,
                    )
                    # record the parsing context to help with debugging:
                    log(" %s", msg)
                    self.gibberish("failed to parse %s packet" % etype, data)
                    return

                if self._closed:
                    return
                payload_size = -1
                padding = None
                # add any raw packets back into it:
                if raw_packets:
                    for index, raw_data in raw_packets.items():
                        # replace placeholder with the raw_data packet data:
                        packet[index] = raw_data
                    raw_packets = {}

                packet_type = packet[0]
                if self.receive_aliases and type(packet_type) == int and packet_type in self.receive_aliases:
                    packet_type = self.receive_aliases.get(packet_type)
                    packet[0] = packet_type
                # count the packets received per type (these are the input statistics):
                self.input_stats[packet_type] = self.input_stats.get(packet_type, 0) + 1

                self.input_packetcount += 1
                log("processing packet %s", packet_type)
                self._process_packet_cb(self, packet)
                packet = None

    def flush_then_close(self, last_packet, done_callback=None):
        """ Send `last_packet` if we can, then close the connection.

            Note: this is best effort only
            the packet may not get sent.

            We try to get the write lock,
            we try to wait for the write queue to flush
            we queue our last packet,
            we wait again for the queue to flush,
            then no matter what, we close the connection and stop the threads.

            last_packet: the packet to encode and queue before closing.
            done_callback: optional callable taking no arguments,
                it is called (at most once) when we have finished,
                whether or not the last packet could be sent.
        """
        # close_and_release() is reached twice on the normal codepath:
        # first from wait_for_packet_sent() and then again from the
        # 5 second safety timer, so done() must only fire the callback once.
        done_fired = []

        def done():
            if done_fired:
                return
            done_fired.append(True)
            if done_callback:
                done_callback()

        if self._closed:
            log("flush_then_close: already closed")
            return done()

        def wait_for_queue(timeout=10):
            # IMPORTANT: if we are here, we have the write lock held!
            if not self._write_queue.empty():
                # write queue still has stuff in it..
                if timeout <= 0:
                    log("flush_then_close: queue still busy, closing without sending the last packet")
                    self._write_lock.release()
                    self.close()
                    done()
                else:
                    log("flush_then_close: still waiting for queue to flush")
                    # retry every 100ms, counting down the attempts:
                    self.timeout_add(100, wait_for_queue, timeout - 1)
            else:
                log("flush_then_close: queue is now empty, sending the last packet and closing")
                chunks, proto_flags = self.encode(last_packet)

                def close_and_release():
                    log("flush_then_close: wait_for_packet_sent() close_and_release()")
                    self.close()
                    try:
                        self._write_lock.release()
                    except Exception:
                        # the lock was already released by an earlier call
                        pass
                    done()

                def wait_for_packet_sent():
                    log(
                        "flush_then_close: wait_for_packet_sent() queue.empty()=%s, closed=%s",
                        self._write_queue.empty(),
                        self._closed,
                    )
                    if self._write_queue.empty() or self._closed:
                        # it got sent, we're done!
                        close_and_release()
                        return False
                    return not self._closed  # run until we manage to close (here or via the timeout)

                def packet_queued(*args):
                    # if we're here, we have the lock and the packet is in the write queue
                    log("flush_then_close: packet_queued() closed=%s", self._closed)
                    if wait_for_packet_sent():
                        # check every 100ms
                        self.timeout_add(100, wait_for_packet_sent)

                self._add_chunks_to_queue(chunks, proto_flags, start_send_cb=None, end_send_cb=packet_queued)
                # just in case wait_for_packet_sent never fires:
                self.timeout_add(5 * 1000, close_and_release)

        def wait_for_write_lock(timeout=100):
            if not self._write_lock.acquire(False):
                if timeout <= 0:
                    log("flush_then_close: timeout waiting for the write lock")
                    self.close()
                    done()
                else:
                    log("flush_then_close: write lock is busy, will retry %s more times", timeout)
                    self.timeout_add(10, wait_for_write_lock, timeout - 1)
            else:
                log("flush_then_close: acquired the write lock")
                # we have the write lock - we MUST free it!
                wait_for_queue()

        # normal codepath:
        # -> wait_for_write_lock
        # -> wait_for_queue
        # -> _add_chunks_to_queue
        # -> packet_queued
        # -> wait_for_packet_sent
        # -> close_and_release
        log("flush_then_close: wait_for_write_lock()")
        wait_for_write_lock()

    def close(self):
        """ Close the connection and tear down the protocol threads.

            Does nothing if we are already closed. Otherwise we mark
            ourselves as closed, schedule a CONNECTION_LOST packet for the
            packet handler via idle_add, close the underlying connection
            (logging the traffic statistics if enabled),
            terminate the queue threads and schedule clean().
        """
        log("close() closed=%s", self._closed)
        if self._closed:
            return
        self._closed = True
        self.idle_add(self._process_packet_cb, self, [Protocol.CONNECTION_LOST])
        if self._conn:
            try:
                self._conn.close()
                if self._log_stats is None:
                    if self._conn.input_bytecount == 0 and self._conn.output_bytecount == 0:
                        # no data sent or received, skip logging of stats:
                        self._log_stats = False
                if self._log_stats:
                    from xpra.simple_stats import std_unit, std_unit_dec

                    counters = (
                        std_unit(self.input_packetcount),
                        std_unit_dec(self._conn.input_bytecount),
                        std_unit(self.output_packetcount),
                        std_unit_dec(self._conn.output_bytecount),
                    )
                    log.info(
                        "connection closed after %s packets received (%s bytes) and %s packets sent (%s bytes)",
                        *counters
                    )
            except:
                log.error("error closing %s", self._conn, exc_info=True)
            self._conn = None
        self.terminate_queue_threads()
        self.idle_add(self.clean)

    def steal_connection(self, read_callback=None):
        """ Detach and return the connection so it can be re-used elsewhere
            (frees all protocol threads and resources).

            Note: this method can only be used with non-blocking sockets,
            and if more than one packet can arrive, the read_callback should
            be used to ensure that no packets get lost.
            The caller must call wait_for_io_threads_exit() to ensure that
            this class is no longer reading from the connection
            before it can re-use it.

            Returns the connection object (which may be None).
        """
        assert not self._closed
        if read_callback:
            self._read_queue_put = read_callback
        stolen = self._conn
        self._closed = True
        self._conn = None
        if stolen:
            # this ensures that we exit the untilConcludes() read/write loop
            stolen.set_active(False)
        self.terminate_queue_threads()
        return stolen

    def clean(self):
        """ Clear all references (callbacks, encoder and thread handles)
            to ensure we can get garbage collected quickly.
        """
        for attr in (
            "_get_packet_cb",
            "_encoder",
            "_write_thread",
            "_read_thread",
            "_read_parser_thread",
            "_write_format_thread",
            "_process_packet_cb",
        ):
            setattr(self, attr, None)

    def terminate_queue_threads(self):
        """ Wake up the threads blocked on our queues so that they can exit.

            Both the write queue and the read queue are swapped for a queue
            pre-filled with None exit markers, and a None marker is also
            pushed onto each of the old queues (best effort).
        """
        log("terminate_queue_threads()")
        # the format thread will exit since closed is set too:
        self._source_has_more.set()
        # make the threads exit by adding the empty marker:
        exit_queue = Queue()
        markers = 10  # just 2 should be enough!
        while markers > 0:
            exit_queue.put(None)
            markers -= 1
        for queue_attr in ("_write_queue", "_read_queue"):
            try:
                old_queue = getattr(self, queue_attr)
                setattr(self, queue_attr, exit_queue)
                old_queue.put_nowait(None)
            except:
                pass
Exemplo n.º 43
0
class Worker_Thread(Thread):
    """
        A daemon thread which runs the callables posted to it with add(),
        one at a time, in the order they were queued.
        A None item in the queue is the signal to exit.
    """

    def __init__(self):
        Thread.__init__(self, name="Worker_Thread")
        self.items = Queue()
        self.exit = False
        self.setDaemon(True)

    def stop(self, force=False):
        """ Ask the thread to exit: with force, the exit flag is set
            so that items still in the queue will not be run. """
        if not force:
            if self.items.qsize()>0:
                log.info("waiting for %s items in work queue to complete", self.items.qsize())
        else:
            if self.items.qsize()>0:
                log.warn("Worker_Thread.stop(%s) %s items in work queue will not run!", force, self.items.qsize())
            self.exit = True
        debug("Worker_Thread.stop(%s) %s items in work queue: ", force, self.items)
        #the None marker wakes up run() if it is blocked on the queue:
        self.items.put(None)

    def add(self, item):
        """ Queue a callable for the thread to run. """
        if self.items.qsize()>10:
            log.warn("Worker_Thread.items queue size is %s", self.items.qsize())
        self.items.put(item)

    def run(self):
        """ Run queued items until the exit flag is set or a None marker is found. """
        debug("Worker_Thread.run() starting")
        while True:
            if self.exit:
                break
            job = self.items.get()
            if job is None:
                break
            try:
                debug("Worker_Thread.run() calling %s (queue size=%s)", job, self.items.qsize())
                job()
            except:
                log.error("Worker_Thread.run() error on %s", job, exc_info=True)
        debug("Worker_Thread.run() ended")
Exemplo n.º 44
0
class Worker_Thread(Thread):
    """
        A background thread which calls the functions we post to it.
        The functions are placed in a queue and only called once,
        when this thread gets around to it.
    """
    def __init__(self):
        Thread.__init__(self, name="Worker_Thread")
        self.items = Queue()
        self.exit = False
        self.setDaemon(True)

    def __repr__(self):
        state = (self.items.qsize(), self.exit)
        return "Worker_Thread(items=%s, exit=%s)" % state

    def stop(self, force=False):
        """ Ask the thread to exit (does nothing if it is already exiting):
            with force, pending items are abandoned and will not be run. """
        if self.exit:
            return
        if not force:
            if self.items.qsize() > 0:
                log.info("waiting for %s items in work queue to complete",
                         self.items.qsize())
        else:
            if self.items.qsize() > 0:
                log.warn("Worker stop: %s items in the queue will not be run!",
                         self.items.qsize())
                #wake up run() on the old queue,
                #then swap in an empty queue to drop the pending items:
                self.items.put(None)
                self.items = Queue()
            self.exit = True
        debug("Worker_Thread.stop(%s) %s items in work queue", force,
              self.items)
        self.items.put(None)

    def add(self, item):
        """ Queue a callable for the thread to run. """
        if self.items.qsize() > 10:
            log.warn("Worker_Thread.items queue size is %s",
                     self.items.qsize())
        self.items.put(item)

    def run(self):
        """ Run queued items until the exit flag is set or a None marker is found. """
        debug("Worker_Thread.run() starting")
        while True:
            if self.exit:
                break
            job = self.items.get()
            if job is None:
                debug("Worker_Thread.run() found end of queue marker")
                self.exit = True
                break
            try:
                debug("Worker_Thread.run() calling %s (queue size=%s)", job,
                      self.items.qsize())
                job()
            except:
                log.error("Error in worker thread processing item %s",
                          job,
                          exc_info=True)
        debug("Worker_Thread.run() ended (queue size=%s)", self.items.qsize())
Exemplo n.º 45
0
class Worker_Thread(Thread):
    """
        A background thread which calls the functions we post to it.
        The functions are placed in a queue and only called once,
        when this thread gets around to it.
    """

    def __init__(self):
        Thread.__init__(self, name="Worker_Thread")
        self.items = Queue()
        self.exit = False
        self.setDaemon(True)

    def stop(self, force=False):
        """ Ask the thread to exit (does nothing if it is already exiting):
            with force, the exit flag is set so that pending items will not run. """
        if self.exit:
            return
        if not force:
            if self.items.qsize()>0:
                log.info("waiting for %s items in work queue to complete", self.items.qsize())
        else:
            if self.items.qsize()>0:
                log.warn("Worker_Thread.stop(%s) %s items in work queue will not run!", force, self.items.qsize())
            self.exit = True
        debug("Worker_Thread.stop(%s) %s items in work queue: ", force, self.items)
        #the None marker wakes up run() if it is blocked on the queue:
        self.items.put(None)

    def add(self, item):
        """ Queue a callable for the thread to run. """
        if self.items.qsize()>10:
            log.warn("Worker_Thread.items queue size is %s", self.items.qsize())
        self.items.put(item)

    def run(self):
        """ Run queued items until the exit flag is set or a None marker is found,
            the exit flag is always set by the time we return. """
        debug("Worker_Thread.run() starting")
        while True:
            if self.exit:
                break
            job = self.items.get()
            if job is None:
                break
            try:
                debug("Worker_Thread.run() calling %s (queue size=%s)", job, self.items.qsize())
                job()
            except:
                log.error("Worker_Thread.run() error on %s", job, exc_info=True)
        debug("Worker_Thread.run() ended")
        self.exit = True
Exemplo n.º 46
0
class ProxyInstanceProcess(Process):

    def __init__(self, uid, gid, env_options, session_options, socket_dir,
                 video_encoder_modules, csc_modules,
                 client_conn, client_state, cipher, encryption_key, server_conn, caps, message_queue):
        """ Record the settings and connections for this proxy instance,
            the process is named after the client connection.
            The protocols, queues, threads and video encoder state all start
            out unset, the control socket connection limit is
            MAX_CONCURRENT_CONNECTIONS.
        """
        Process.__init__(self, name=str(client_conn))
        #who we run as, and with what settings:
        self.uid, self.gid = uid, gid
        self.env_options = env_options
        self.session_options = session_options
        self.socket_dir = socket_dir
        self.video_encoder_modules = video_encoder_modules
        self.csc_modules = csc_modules
        #the client side:
        self.client_conn = client_conn
        self.client_state = client_state
        self.cipher = cipher
        self.encryption_key = encryption_key
        #the server side:
        self.server_conn = server_conn
        self.caps = caps
        state_repr = repr_ellipsized(str(client_state))
        caps_repr = "%s: %s.." % (type(caps), repr_ellipsized(str(caps)))
        log("ProxyProcess%s", (uid, gid, env_options, session_options, socket_dir,
                               video_encoder_modules, csc_modules,
                               client_conn, state_repr, cipher, encryption_key, server_conn,
                               caps_repr, message_queue))
        self.client_protocol = self.server_protocol = None
        self.exit = False
        self.main_queue = None
        self.message_queue = message_queue
        #encode_queue holds draw packets to encode:
        self.encode_queue = self.encode_thread = None
        self.video_encoding_defs = None
        self.video_encoders = None
        self.video_encoders_last_used_time = None
        self.video_encoder_types = None
        self.video_helper = None
        self.lost_windows = None
        #for handling the local unix domain socket:
        self.control_socket_cleanup = None
        self.control_socket = None
        self.control_socket_thread = None
        self.control_socket_path = None
        self.potential_protocols = []
        self.max_connections = MAX_CONCURRENT_CONNECTIONS

    def server_message_queue(self):
        """ Handle the messages posted on message_queue, forever:
            "stop" stops this proxy instance and ends the loop,
            "socket-handover-complete" sets both connections to blocking mode,
            anything else is logged as an error.
        """
        while True:
            log("waiting for server message on %s", self.message_queue)
            message = self.message_queue.get()
            log("received proxy server message: %s", message)
            if message=="stop":
                self.stop("proxy server request")
                return
            if message!="socket-handover-complete":
                log.error("unexpected proxy server message: %s", message)
                continue
            log("setting sockets to blocking mode: %s", (self.client_conn, self.server_conn))
            #set sockets to blocking mode:
            for conn in (self.client_conn, self.server_conn):
                set_blocking(conn)

    def signal_quit(self, signum, frame):
        """ Signal handler: flags the exit, swaps the SIGINT and SIGTERM
            handlers for deadly_signal, then stops the proxy using
            the signal's name as the reason.
        """
        signame = SIGNAMES.get(signum, signum)
        log.info("")
        log.info("proxy process pid %s got signal %s, exiting", os.getpid(), signame)
        self.exit = True
        for sig in (signal.SIGINT, signal.SIGTERM):
            signal.signal(sig, deadly_signal)
        self.stop(signame)

    def idle_add(self, fn, *args, **kwargs):
        #we emulate gobject's idle_add using a simple queue
        self.main_queue.put((fn, args, kwargs))

    def timeout_add(self, timeout, fn, *args, **kwargs):
        #emulate gobject's timeout_add using idle add and a Timer
        #using custom functions to cancel() the timer when needed
        def idle_exec():
            v = fn(*args, **kwargs)
            if bool(v):
                self.timeout_add(timeout, fn, *args, **kwargs)
            return False
        def timer_exec():
            #just run via idle_add:
            self.idle_add(idle_exec)
        Timer(timeout/1000.0, timer_exec).start()

    def run(self):
        """ Process entry point: switches to the target uid and gid,
            applies the environment options, initializes the video encoders,
            registers the signal handlers, starts the helper threads and
            both protocols, sends our hello to the server,
            then runs the main queue until we are told to exit.
            Returns early if the control socket cannot be created.
        """
        log("ProxyProcess.run() pid=%s, uid=%s, gid=%s", os.getpid(), getuid(), getgid())
        setuidgid(self.uid, self.gid)
        if self.env_options:
            #TODO: whitelist env update?
            os.environ.update(self.env_options)
        self.video_init()

        log.info("new proxy instance started")
        log.info(" for client %s", self.client_conn)
        log.info(" and server %s", self.server_conn)

        for sig in (signal.SIGTERM, signal.SIGINT):
            signal.signal(sig, self.signal_quit)
        log("registered signal handler %s", self.signal_quit)

        start_thread(self.server_message_queue, "server message queue")

        if not self.create_control_socket():
            #TODO: should send a message to the client
            return
        self.control_socket_thread = start_thread(self.control_socket_loop, "control")

        self.main_queue = Queue()
        #setup protocol wrappers:
        self.server_packets = Queue(PROXY_QUEUE_SIZE)
        self.client_packets = Queue(PROXY_QUEUE_SIZE)
        self.client_protocol = Protocol(self, self.client_conn, self.process_client_packet, self.get_client_packet)
        self.client_protocol.restore_state(self.client_state)
        self.server_protocol = Protocol(self, self.server_conn, self.process_server_packet, self.get_server_packet)
        #server connection tweaks:
        for packet_type in ("draw", "window-icon", "keymap-changed", "server-settings"):
            self.server_protocol.large_packets.append(packet_type)
        if self.caps.boolget("file-transfer"):
            for protocol in (self.client_protocol, self.server_protocol):
                for packet_type in ("send-file", "send-file-chunk"):
                    protocol.large_packets.append(packet_type)
        compression_level = self.session_options.get("compression_level", 0)
        self.server_protocol.set_compression_level(compression_level)
        self.server_protocol.enable_default_encoder()

        self.lost_windows = set()
        self.encode_queue = Queue()
        self.encode_thread = start_thread(self.encode_loop, "encode")

        log("starting network threads")
        self.server_protocol.start()
        self.client_protocol.start()

        self.send_hello()
        self.timeout_add(VIDEO_TIMEOUT*1000, self.timeout_video_encoders)

        try:
            self.run_queue()
        except KeyboardInterrupt as e:
            self.stop(str(e))
        finally:
            log("ProxyProcess.run() ending %s", os.getpid())

    def video_init(self):
        """ Load the codecs and the video helper, then record which video
            encodings we can proxy for in video_encoding_defs
            (encoding -> colorspace -> list of spec properties)
            and the matching encoder types, in order of preference,
            in video_encoder_types.
        """
        enclog("video_init() loading codecs")
        load_codecs(decoders=False)
        enclog("video_init() will try video encoders: %s", csv(self.video_encoder_modules) or "none")
        helper = getVideoHelper()
        self.video_helper = helper
        #only use video encoders (no CSC supported in proxy)
        helper.set_modules(video_encoders=self.video_encoder_modules)
        helper.init()

        self.video_encoding_defs = {}
        self.video_encoders = {}
        self.video_encoders_dst_formats = []
        self.video_encoders_last_used_time = {}
        self.video_encoder_types = []

        #figure out which encoders we want to proxy for (if any):
        #(only deal with encoders that can handle plain RGB directly)
        rgb_modes = ("BGRX", "BGRA", "RGBX", "RGBA")
        found_types = set()
        for encoding in self.video_helper.get_encodings():
            specs_by_colorspace = self.video_helper.get_encoder_specs(encoding)
            for colorspace, especs in specs_by_colorspace.items():
                if colorspace not in rgb_modes:
                    continue
                for spec in especs:                     #ie: video_spec("x264")
                    props = spec.to_dict()
                    #not serializable!
                    del props["codec_class"]
                    #we want to win scoring so we get used ahead of other encoders
                    props["score_boost"] = 50
                    #limit to 3 video streams we proxy for (we really want 2,
                    # but because of races with garbage collection, we need to allow more)
                    props["max_instances"] = 3
                    #store it in encoding defs:
                    colorspace_defs = self.video_encoding_defs.setdefault(encoding, {})
                    colorspace_defs.setdefault(colorspace, []).append(props)
                    found_types.add(spec.codec_type)

        enclog("encoder types found: %s", tuple(found_types))
        #remove duplicates and use preferred order:
        order = PREFERRED_ENCODER_ORDER[:]
        extras = [x for x in found_types if x not in order]
        order.extend(extras)
        self.video_encoder_types = [x for x in order if x in found_types]
        enclog.info("proxy video encoders: %s", ", ".join(self.video_encoder_types))


    def create_control_socket(self):
        """ Create the unix domain control socket for this proxy instance
            (named after our pid) and start listening on it.

            Returns True on success, False if the socket path is already
            in use (or its state is unknown) or if the socket
            could not be created.
        """
        assert self.socket_dir
        dotxpra = DotXpra(self.socket_dir)
        path = dotxpra.socket_path(":proxy-%s" % os.getpid())
        if dotxpra.get_server_state(path) in (DotXpra.LIVE, DotXpra.UNKNOWN):
            log.error("Error: you already have a proxy server running at '%s'", path)
            log.error(" the control socket will not be created")
            return False
        log("create_control_socket: socket path='%s', uid=%i, gid=%i", path, getuid(), getgid())
        try:
            listener, cleanup = create_unix_domain_socket(path, None, 0o600)
            #record the cleanup function before listening,
            #so that it is available even if listen() fails:
            self.control_socket_cleanup = cleanup
            listener.listen(5)
        except Exception as e:
            log("create_unix_domain_socket failed for '%s'", path, exc_info=True)
            log.error("Error: failed to setup control socket '%s':", path)
            log.error(" %s", e)
            return False
        self.control_socket = listener
        self.control_socket_path = path
        log.info("proxy instance now also available using unix domain socket:")
        log.info(" %s", self.control_socket_path)
        return True

    def control_socket_loop(self):
        """ Accept connections on the control socket until we are told to exit,
            each new connection is handed to new_control_connection().

            stop() sets the exit flag and closes the control socket,
            so a socket error from accept() once the exit flag is set just
            ends the loop, any other accept() failure is still raised.
        """
        while not self.exit:
            log("waiting for connection on %s", self.control_socket_path)
            try:
                sock, address = self.control_socket.accept()
            except socket.error:
                if self.exit:
                    #the control socket was closed by stop():
                    log("control socket closed, exiting the accept loop")
                    break
                raise
            self.new_control_connection(sock, address)

    def new_control_connection(self, sock, address):
        """ Wrap a newly accepted control socket in a Protocol and start it,
            unless we already have too many pending connections,
            in which case the socket is closed straight away.
            The connection is checked again after SOCKET_TIMEOUT seconds
            by verify_connection_accepted().
            Always returns True.
        """
        pending = len(self.potential_protocols)
        if pending>=self.max_connections:
            log.error("too many connections (%s), ignoring new one", len(self.potential_protocols))
            sock.close()
            return  True
        try:
            peername = sock.getpeername()
        except:
            #fallback to the address we were given by accept():
            peername = str(address)
        sockname = sock.getsockname()
        target = peername or sockname
        #sock.settimeout(0)
        log("new_control_connection() sock=%s, sockname=%s, address=%s, peername=%s", sock, sockname, address, peername)
        conn = SocketConnection(sock, sockname, address, target, "unix-domain")
        log.info("New proxy instance control connection received: %s", conn)
        control_protocol = Protocol(self, conn, self.process_control_packet)
        control_protocol.large_packets.append("info-response")
        self.potential_protocols.append(control_protocol)
        control_protocol.enable_default_encoder()
        control_protocol.start()
        self.timeout_add(SOCKET_TIMEOUT*1000, self.verify_connection_accepted, control_protocol)
        return True

    def verify_connection_accepted(self, protocol):
        """ Disconnect a control connection with LOGIN_TIMEOUT if it is
            still open and still listed in potential_protocols.
        """
        if protocol._closed or protocol not in self.potential_protocols:
            return
        log.error("connection timedout: %s", protocol)
        self.send_disconnect(protocol, LOGIN_TIMEOUT)

    def process_control_packet(self, proto, packet):
        """ Packet handler for control connections:
            any error raised whilst processing the packet is logged
            and the connection is sent a CONTROL_COMMAND_ERROR disconnect.
        """
        try:
            self.do_process_control_packet(proto, packet)
        except Exception as err:
            log.error("error processing control packet", exc_info=True)
            message = str(err)
            self.send_disconnect(proto, CONTROL_COMMAND_ERROR, message)

    def do_process_control_packet(self, proto, packet):
        """ Handle a packet received on a control connection.

            A lost connection is removed from potential_protocols.
            A "hello" packet may carry an info, stop or version request,
            hello packets asking for a challenge are refused.
            Everything else gets a CONTROL_COMMAND_ERROR disconnect.
        """
        log("process_control_packet(%s, %s)", proto, packet)
        packet_type = packet[0]
        if packet_type==Protocol.CONNECTION_LOST:
            log.info("Connection lost")
            if proto in self.potential_protocols:
                self.potential_protocols.remove(proto)
            return
        if packet_type=="hello":
            caps = typedict(packet[1])
            if caps.boolget("challenge"):
                self.send_disconnect(proto, AUTHENTICATION_ERROR, "this socket does not use authentication")
                return
            if caps.get("info_request", False):
                proto.send_now(("hello", self.get_proxy_info(proto)))
                #give the client 5 seconds to exit before we disconnect it:
                self.timeout_add(5*1000, self.send_disconnect, proto, CLIENT_EXIT_TIMEOUT, "info sent")
                return
            if caps.get("stop_request", False):
                self.stop("socket request", None)
                return
            if caps.get("version_request", False):
                from xpra import __version__
                proto.send_now(("hello", {"version" : __version__}))
                self.timeout_add(5*1000, self.send_disconnect, proto, CLIENT_EXIT_TIMEOUT, "version sent")
                return
        #not a request we know how to handle:
        self.send_disconnect(proto, CONTROL_COMMAND_ERROR, "this socket only handles 'info', 'version' and 'stop' requests")

    def send_disconnect(self, proto, reason, *extra):
        """ Send a "disconnect" packet with the reason (and any extra values)
            unless the protocol is already closed,
            then force the connection closed one second later.
        """
        log("send_disconnect(%s, %s, %s)", proto, reason, extra)
        if proto._closed:
            return
        packet = ["disconnect", reason]
        packet.extend(extra)
        proto.send_now(packet)
        self.timeout_add(1000, self.force_disconnect, proto)

    def force_disconnect(self, proto):
        """ Close the given protocol, this is the delayed callback
            scheduled by send_disconnect(). """
        proto.close()


    def get_proxy_info(self, proto):
        """ Returns the info dictionary sent in reply to info requests:
            "proxy" holds our version and the merged server and thread info
            (under the empty key), "window" holds the window info.
        """
        sinfo = {}
        for source in (get_server_info(), get_thread_info(proto)):
            sinfo.update(source)
        proxy_info = {
                      "version"    : local_version,
                      ""           : sinfo,
                      }
        return {
                "proxy"  : proxy_info,
                "window" : self.get_window_info(),
                }

    def send_hello(self, challenge_response=None, client_salt=None):
        hello = self.filter_client_caps(self.caps)
        if challenge_response:
            hello.update({
                          "challenge_response"      : challenge_response,
                          "challenge_client_salt"   : client_salt,
                          })
        self.queue_server_packet(("hello", hello))


    def sanitize_session_options(self, options):
        """ Returns a new dictionary with only the whitelisted session
            options, each value is converted using the parser for its key.
            Values which fail to parse are logged and left out.
        """
        def number(k, v):
            return parse_number(int, k, v)
        OPTION_WHITELIST = {"compression_level" : number,
                            "lz4"               : parse_bool,
                            "lzo"               : parse_bool,
                            "zlib"              : parse_bool,
                            "rencode"           : parse_bool,
                            "bencode"           : parse_bool,
                            "yaml"              : parse_bool}
        sanitized = {}
        for key, value in options.items():
            parser = OPTION_WHITELIST.get(key)
            if not parser:
                #not whitelisted: ignore it
                continue
            log("trying to add %s=%s using %s", key, value, parser)
            try:
                sanitized[key] = parser(key, value)
            except Exception as e:
                log.warn("failed to parse value %s for %s using %s: %s", value, key, parser, e)
        return sanitized

    def filter_client_caps(self, caps):
        fc = self.filter_caps(caps, ("cipher", "challenge", "digest", "aliases", "compression", "lz4", "lz0", "zlib"))
        #update with options provided via config if any:
        fc.update(self.sanitize_session_options(self.session_options))
        #add video proxies if any:
        fc["encoding.proxy.video"] = len(self.video_encoding_defs)>0
        if self.video_encoding_defs:
            fc["encoding.proxy.video.encodings"] = self.video_encoding_defs
        return fc

    def filter_server_caps(self, caps):
        self.server_protocol.enable_encoder_from_caps(caps)
        return self.filter_caps(caps, ("aliases", ))

    def filter_caps(self, caps, prefixes):
        """ Returns a copy of the capabilities without the keys starting
            with any of the prefixes (the caps that the proxy overrides
            or does not use), updated with the proxy's own network
            capabilities and the proxy information.
        """
        pcaps = {}
        removed = []
        for k in caps.keys():
            if any(k.startswith(prefix) for prefix in prefixes):
                removed.append(k)
            else:
                pcaps[k] = caps[k]
        log("filtered out %s matching %s", removed, prefixes)
        #replace the network caps with the proxy's own:
        network_caps = flatten_dict(get_network_caps())
        pcaps.update(network_caps)
        #then add the proxy info:
        updict(pcaps, "proxy", get_server_info(), flatten_dicts=True)
        pcaps["proxy"] = True
        pcaps["proxy.hostname"] = socket.gethostname()
        return pcaps


    def run_queue(self):
        """ The main loop: runs the "idle_add"/"timeout_add" calls posted
            to main_queue until the exit flag is set or the None exit marker
            is found, then waits up to a second for both network
            connections to close.

            A call which returns a true value is queued to run again.
            Errors raised by a call are logged and the loop carries on.
        """
        log("run_queue() queue has %s items already in it", self.main_queue.qsize())
        #process "idle_add"/"timeout_add" events in the main loop:
        while not self.exit:
            log("run_queue() size=%s", self.main_queue.qsize())
            v = self.main_queue.get()
            if v is None:
                log("run_queue() None exit marker")
                break
            fn, args, kwargs = v
            log("run_queue() %s%s%s", fn, args, kwargs)
            try:
                if fn(*args, **kwargs):
                    #re-run it: we must queue the call itself,
                    #queueing the return value breaks the unpacking above
                    self.main_queue.put((fn, args, kwargs))
            except Exception:
                #KeyboardInterrupt and SystemExit are left to propagate to run()
                log.error("error during main loop callback %s", fn, exc_info=True)
        self.exit = True
        #wait for connections to close down cleanly before we exit
        for i in range(10):
            if self.client_protocol._closed and self.server_protocol._closed:
                break
            if i==0:
                log.info("waiting for network connections to close")
            else:
                log("still waiting %i/10 - client.closed=%s, server.closed=%s", i+1, self.client_protocol._closed, self.server_protocol._closed)
            time.sleep(0.1)
        log.info("proxy instance %s stopped", os.getpid())

    def stop(self, reason="proxy terminating", skip_proto=None):
        """ Stop this proxy instance.

            Sets the exit flag, closes and cleans up the control socket,
            makes the main loop and the encode queue consumer exit,
            then sends a disconnect packet to the client and server
            protocols and closes them.

            reason: text included in the disconnect packets.
            skip_proto: a protocol which should not be sent the disconnect
                (ie: the one whose connection was lost).

            This can be called before run() has created the main queue
            (from the signal handler or the server message thread).
        """
        log.info("stop(%s, %s)", reason, skip_proto)
        self.exit = True
        control_socket = self.control_socket
        if control_socket is not None:
            try:
                control_socket.close()
            except Exception:
                #best effort only
                log("error closing the control socket", exc_info=True)
        csc = self.control_socket_cleanup
        if csc:
            self.control_socket_cleanup = None
            csc()
        main_queue = self.main_queue
        if main_queue is not None:
            #wake up the main loop if it is waiting on the current queue:
            main_queue.put(None)
        #empty the main queue:
        q = Queue()
        q.put(None)
        self.main_queue = q
        #empty the encode queue:
        q = Queue()
        q.put(None)
        self.encode_queue = q
        for proto in (self.client_protocol, self.server_protocol):
            if proto and proto!=skip_proto:
                log("sending disconnect to %s", proto)
                proto.flush_then_close(["disconnect", SERVER_SHUTDOWN, reason])


    def queue_client_packet(self, packet):
        """ Add a packet to the client packet queue
            and tell the client protocol that there is more to send. """
        packet_type = packet[0]
        log("queueing client packet: %s", packet_type)
        self.client_packets.put(packet)
        self.client_protocol.source_has_more()

    def get_client_packet(self):
        """ Packet source given to the client protocol in run():
            takes the next packet from the client packet queue
            (waiting for one if needed) and returns it with two unset values
            and a flag telling if more packets are already queued. """
        packet = self.client_packets.get()
        log("sending to client: %s", packet[0])
        has_more = self.client_packets.qsize()>0
        return packet, None, None, has_more

    def process_client_packet(self, proto, packet):
        """ Handle a packet received from the client.

            A lost connection stops the proxy, "set_deflate" is echoed back
            to the client and late "hello" packets are dropped.
            A "disconnect" closes the client protocol if we are already
            exiting and stops the proxy otherwise, it is then forwarded
            to the server like every other packet.
        """
        packet_type = packet[0]
        log("process_client_packet: %s", packet_type)
        if packet_type==Protocol.CONNECTION_LOST:
            self.stop("client connection lost", proto)
            return
        if packet_type=="set_deflate":
            #echo it back to the client:
            self.client_packets.put(packet)
            self.client_protocol.source_has_more()
            return
        if packet_type=="hello":
            log.warn("Warning: invalid hello packet received after initial authentication (dropped)")
            return
        if packet_type=="disconnect":
            log("got disconnect from client: %s", packet[1])
            if not self.exit:
                self.stop("disconnect from client: %s" % packet[1])
            else:
                self.client_protocol.close()
        self.queue_server_packet(packet)


    def queue_server_packet(self, packet):
        """
        Adds `packet` to the queue of packets destined for the server,
        then tells the server protocol that there is more data to send.
        """
        packet_type = packet[0]
        log("queueing server packet: %s", packet_type)
        self.server_packets.put(packet)
        self.server_protocol.source_has_more()

    def get_server_packet(self):
        """
        Takes the next packet from the server packet queue (blocking until
        one is available) and returns the tuple:
        (packet, None, None, has_more)
        where has_more is True if the queue still holds packets.
        NOTE(review): presumably used as the server protocol's packet source
        callback, the meaning of the two None values is defined by that API.
        """
        packet = self.server_packets.get()
        log("sending to server: %s", packet[0])
        has_more = self.server_packets.qsize()>0
        return packet, None, None, has_more


    def _packet_recompress(self, packet, index, name):
        if len(packet)>index:
            data = packet[index]
            if len(data)<512:
                packet[index] = str(data)
                return
            #FIXME: this is ugly and not generic!
            zlib = compression.use_zlib and self.caps.boolget("zlib", True)
            lz4 = compression.use_lz4 and self.caps.boolget("lz4", False)
            lzo = compression.use_lzo and self.caps.boolget("lzo", False)
            if zlib or lz4 or lzo:
                packet[index] = compressed_wrapper(name, data, zlib=zlib, lz4=lz4, lzo=lzo, can_inline=False)
            else:
                #prevent warnings about large uncompressed data
                packet[index] = Compressed("raw %s" % name, data, can_inline=True)

    def process_server_packet(self, proto, packet):
        """
        Handles a packet received from the server.

        Most packet types are adjusted (if needed) and then queued
        for the client. The exceptions are:
        * connection lost: the proxy is stopped and nothing is forwarded
        * "draw": handed to the encode thread, which queues it for the client itself
        * "challenge": answered from here using the session password,
          it is not forwarded to the client

        proto: the protocol instance the packet was received on.
        packet: the packet, the packet type is the first item.
        """
        packet_type = packet[0]
        log("process_server_packet: %s", packet_type)
        if packet_type==Protocol.CONNECTION_LOST:
            self.stop("server connection lost", proto)
            return
        elif packet_type=="disconnect":
            log("got disconnect from server: %s", packet[1])
            if self.exit:
                self.server_protocol.close()
            else:
                self.stop("disconnect from server: %s" % packet[1])
            #no return: the disconnect packet is still queued for the client below
        elif packet_type=="hello":
            c = typedict(packet[1])
            maxw, maxh = c.intpair("max_desktop_size", (4096, 4096))
            caps = self.filter_server_caps(c)
            #add new encryption caps:
            if self.cipher:
                from xpra.net.crypto import crypto_backend_init, new_cipher_caps, DEFAULT_PADDING
                crypto_backend_init()
                padding_options = self.caps.strlistget("cipher.padding.options", [DEFAULT_PADDING])
                auth_caps = new_cipher_caps(self.client_protocol, self.cipher, self.encryption_key, padding_options)
                caps.update(auth_caps)
            #may need to bump packet size:
            proto.max_packet_size = maxw*maxh*4*4
            #file transfers are only possible if both ends enable them,
            #in which case both protocols must accept the largest file allowed:
            file_transfer = self.caps.boolget("file-transfer") and c.boolget("file-transfer")
            file_size_limit = max(self.caps.intget("file-size-limit"), c.intget("file-size-limit"))
            file_max_packet_size = int(file_transfer) * (1024 + file_size_limit*1024*1024)
            self.client_protocol.max_packet_size = max(self.client_protocol.max_packet_size, file_max_packet_size)
            self.server_protocol.max_packet_size = max(self.server_protocol.max_packet_size, file_max_packet_size)
            #the client is sent the filtered capabilities:
            packet = ("hello", caps)
        elif packet_type=="info-response":
            #adds proxy info:
            #note: this is only seen by the client application
            #"xpra info" is a new connection, which talks to the proxy server...
            info = packet[1]
            info.update(self.get_proxy_info(proto))
        elif packet_type=="lost-window":
            wid = packet[1]
            #mark it as lost so we can drop any current/pending frames
            self.lost_windows.add(wid)
            #queue it so it gets cleaned safely (for video encoders mostly):
            self.encode_queue.put(packet)
            #and fall through so tell the client immediately
        elif packet_type=="draw":
            #use encoder thread:
            self.encode_queue.put(packet)
            #which will queue the packet itself when done:
            return
        #we do want to reformat cursor packets...
        #as they will have been uncompressed by the network layer already:
        elif packet_type=="cursor":
            #packet = ["cursor", x, y, width, height, xhot, yhot, serial, pixels, name]
            #or:
            #packet = ["cursor", "png", x, y, width, height, xhot, yhot, serial, pixels, name]
            #or:
            #packet = ["cursor", ""]
            if len(packet)>=8:
                #hard to distinguish png cursors from normal cursors...
                #if the second item is not a number, assume it is the encoding
                #and that the pixels are found one position further:
                try:
                    int(packet[1])
                    self._packet_recompress(packet, 8, "cursor")
                except:
                    self._packet_recompress(packet, 9, "cursor")
        elif packet_type=="window-icon":
            self._packet_recompress(packet, 5, "icon")
        elif packet_type=="send-file":
            if packet[6]:
                packet[6] = Compressed("file-data", packet[6])
        elif packet_type=="send-file-chunk":
            if packet[3]:
                packet[3] = Compressed("file-chunk-data", packet[3])
        elif packet_type=="challenge":
            from xpra.net.crypto import get_salt
            #client may have already responded to the challenge,
            #so we have to handle authentication from this end
            salt = packet[1]
            digest = packet[3]
            client_salt = get_salt(len(salt))
            salt = xor_str(salt, client_salt)
            if digest!=b"hmac":
                #the message must be formatted here:
                #the second argument of stop() is the protocol to skip, not a format argument
                self.stop("digest mode '%s' not supported" % std(digest))
                return
            password = self.session_options.get("password")
            if not password:
                self.stop("authentication requested by the server, but no password available for this session")
                return
            import hmac, hashlib
            password = strtobytes(password)
            salt = strtobytes(salt)
            challenge_response = hmac.HMAC(password, salt, digestmod=hashlib.md5).hexdigest()
            log.info("sending %s challenge response", digest)
            self.send_hello(challenge_response, client_salt)
            return
        self.queue_client_packet(packet)


    def encode_loop(self):
        """
        Thread for slower encoding related work.

        Reads packets from `self.encode_queue` until `self.exit` is set
        or the `None` marker is received, and handles:
        * "lost-window": forgets the window and cleans its video encoder
        * "draw": runs `process_draw` and queues the packet for the client
          if it returns True
        * "check-video-timeout": (not a real packet, this is added by
          `timeout_video_encoders`) cleans the video encoder if it is still idle
        Errors are logged and the loop continues with the next packet.
        """
        while not self.exit:
            packet = self.encode_queue.get()
            if packet is None:
                return
            try:
                packet_type = packet[0]
                if packet_type=="lost-window":
                    wid = packet[1]
                    #use discard() and pop() so that a window which has already
                    #been removed cannot prevent the encoder from being cleaned:
                    self.lost_windows.discard(wid)
                    ve = self.video_encoders.pop(wid, None)
                    self.video_encoders_last_used_time.pop(wid, None)
                    if ve:
                        ve.clean()
                elif packet_type=="draw":
                    #modify the packet with the video encoder:
                    if self.process_draw(packet):
                        #then send it as normal:
                        self.queue_client_packet(packet)
                elif packet_type=="check-video-timeout":
                    #not a real packet, this is added by the timeout check:
                    wid = packet[1]
                    ve = self.video_encoders.get(wid)
                    last_used = self.video_encoders_last_used_time.get(wid)
                    if not ve or last_used is None:
                        #the encoder has been removed since the check was queued
                        #(ie: by "lost-window"), nothing left to time out
                        continue
                    idle_time = time.time()-last_used
                    if idle_time>VIDEO_TIMEOUT:
                        enclog("timing out the video encoder context for window %s", wid)
                        #timeout is confirmed, we are in the encoding thread,
                        #so it is now safe to clean it up:
                        ve.clean()
                        self.video_encoders.pop(wid, None)
                        self.video_encoders_last_used_time.pop(wid, None)
                else:
                    enclog.warn("unexpected encode packet: %s", packet_type)
            except Exception:
                enclog.warn("error encoding packet", exc_info=True)


    def process_draw(self, packet):
        """
        Processes a "draw" packet received from the server, modifying it in place.
        (called from encode_loop, which forwards the packet to the client
        only if this method returns True)

        * "mmap" packets are never modified
        * rgb pixel data (or any proxy video packet when PASSTHROUGH is set)
          is forwarded as "rgb32"
        * packets which are not marked for proxy encoding just have their
          pixel data wrapped so it does not get re-compressed
        * everything else is compressed using a video encoder, one per window,
          falling back to jpeg (or plain rgb) if no video encoder can be found

        Returns True if the packet should be sent to the client,
        False if it should be dropped (the window has been lost).
        """
        wid, x, y, width, height, encoding, pixels, _, rowstride, client_options = packet[1:11]
        #never modify mmap packets
        if encoding=="mmap":
            return True

        #we have a proxy video packet:
        rgb_format = client_options.get("rgb_format", "")
        enclog("proxy draw: client_options=%s", client_options)

        def send_updated(encoding, compressed_data, updated_client_options):
            #update the packet with actual encoding data used:
            packet[6] = encoding
            packet[7] = compressed_data
            packet[10] = updated_client_options
            enclog("returning %s bytes from %s, options=%s", len(compressed_data), len(pixels), updated_client_options)
            return (wid not in self.lost_windows)

        def passthrough(strip_alpha=True):
            enclog("proxy draw: %s passthrough (rowstride: %s vs %s, strip alpha=%s)", rgb_format, rowstride, client_options.get("rowstride", 0), strip_alpha)
            if strip_alpha:
                #passthrough as plain RGB:
                Xindex = rgb_format.upper().find("X")
                if Xindex>=0 and len(rgb_format)==4:
                    #force clear alpha (which may be garbage):
                    #set the padding byte of every complete 4-byte pixel to 255,
                    #integer division and byte values work with python2 and python3
                    newdata = bytearray(pixels)
                    npixels = len(pixels)//4
                    newdata[Xindex:npixels*4:4] = bytearray(b"\xff")*npixels
                    packet[9] = client_options.get("rowstride", 0)
                    cdata = bytes(newdata)
                else:
                    cdata = pixels
                new_client_options = {"rgb_format" : rgb_format}
            else:
                #preserve
                cdata = pixels
                new_client_options = client_options
            wrapped = Compressed("%s pixels" % encoding, cdata)
            #FIXME: we should not assume that rgb32 is supported here...
            #(we may have to convert to rgb24..)
            return send_updated("rgb32", wrapped, new_client_options)

        proxy_video = client_options.get("proxy", False)
        if PASSTHROUGH and (encoding in ("rgb32", "rgb24") or proxy_video):
            #we are dealing with rgb data, so we can pass it through:
            return passthrough(proxy_video)
        elif not self.video_encoder_types or not client_options or not proxy_video:
            #ensure we don't try to re-compress the pixel data in the network layer:
            #(re-add the "compressed" marker that gets lost when we re-assemble packets)
            packet[7] = Compressed("%s pixels" % encoding, packet[7])
            return True

        #video encoding: find existing encoder
        ve = self.video_encoders.get(wid)
        if ve:
            #lost_windows holds window ids, so this must test the wid (not the encoder):
            if wid in self.lost_windows:
                #we cannot clean the video encoder here, there may be more frames queue up
                #"lost-window" in encode_loop will take care of it safely
                return False
            #we must verify that the encoder is still valid
            #and scrap it if not (ie: when window is resized)
            stale = False
            if ve.get_width()!=width or ve.get_height()!=height:
                enclog("closing existing video encoder %s because dimensions have changed from %sx%s to %sx%s", ve, ve.get_width(), ve.get_height(), width, height)
                stale = True
            elif ve.get_encoding()!=encoding:
                enclog("closing existing video encoder %s because encoding has changed from %s to %s", ve, ve.get_encoding(), encoding)
                stale = True
            if stale:
                ve.clean()
                ve = None
                #don't leave the cleaned encoder registered:
                #we may return below without creating a replacement
                self.video_encoders.pop(wid, None)
                self.video_encoders_last_used_time.pop(wid, None)
        #scaling and depth are proxy-encoder attributes:
        scaling = client_options.get("scaling", (1, 1))
        depth   = client_options.get("depth", 24)
        rowstride = client_options.get("rowstride", rowstride)
        quality = client_options.get("quality", -1)
        speed   = client_options.get("speed", -1)
        timestamp = client_options.get("timestamp")

        image = ImageWrapper(x, y, width, height, pixels, rgb_format, depth, rowstride, planes=ImageWrapper.PACKED)
        if timestamp is not None:
            image.set_timestamp(timestamp)

        #the encoder options are passed through:
        encoder_options = client_options.get("options", {})
        if not ve:
            #make a new video encoder:
            spec = self._find_video_encoder(encoding, rgb_format)
            if spec is None:
                #no video encoder!
                enc_pillow = get_codec("enc_pillow")
                if not enc_pillow:
                    from xpra.server.picture_encode import warn_encoding_once
                    warn_encoding_once("no-video-no-PIL", "no video encoder found for rgb format %s, sending as plain RGB!" % rgb_format)
                    return passthrough(True)
                enclog("no video encoder available: sending as jpeg")
                coding, compressed_data, pillow_options, _, _, _, _ = enc_pillow.encode("jpeg", image, quality, speed, False)
                return send_updated(coding, compressed_data, pillow_options)

            enclog("creating new video encoder %s for window %s", spec, wid)
            ve = spec.make_instance()
            #dst_formats is specified with first frame only:
            dst_formats = client_options.get("dst_formats")
            if dst_formats is not None:
                #save it in case we timeout the video encoder,
                #so we can instantiate it again, even from a frame no>1
                self.video_encoders_dst_formats = dst_formats
            else:
                assert self.video_encoders_dst_formats, "BUG: dst_formats not specified for proxy and we don't have it either"
                dst_formats = self.video_encoders_dst_formats
            ve.init_context(width, height, rgb_format, dst_formats, encoding, quality, speed, scaling, {})
            self.video_encoders[wid] = ve
            self.video_encoders_last_used_time[wid] = time.time()       #just to make sure this is always set
        #actual video compression:
        enclog("proxy compression using %s with quality=%s, speed=%s", ve, quality, speed)
        data, out_options = ve.compress_image(image, quality, speed, encoder_options)
        #pass through some options if we don't have them from the encoder
        #(maybe we should also use the "pts" from the real server?)
        for k in ("timestamp", "rgb_format", "depth", "csc"):
            if k not in out_options and k in client_options:
                out_options[k] = client_options[k]
        self.video_encoders_last_used_time[wid] = time.time()
        return send_updated(ve.get_encoding(), Compressed(encoding, data), out_options)

    def timeout_video_encoders(self):
        """
        Looks for video encoders which have not been used for more than
        VIDEO_TIMEOUT and asks the encode thread to verify and clean them
        by queueing a "check-video-timeout" pseudo packet.

        Always returns True ("run again").
        """
        #have to be careful as another thread may come in...
        #so we just ask the encode thread (which deals with encoders already)
        #to do what may need to be done if we find a timeout:
        now = time.time()
        for wid in list(self.video_encoders_last_used_time.keys()):
            last_used = self.video_encoders_last_used_time.get(wid)
            if last_used is None:
                #the entry has been removed since we copied the keys
                continue
            idle_time = int(now-last_used)
            enclog("timeout_video_encoders() wid=%s, idle_time=%s", wid, idle_time)
            if idle_time and idle_time>VIDEO_TIMEOUT:
                self.encode_queue.put(["check-video-timeout", wid])
        return True     #run again

    def _find_video_encoder(self, encoding, rgb_format):
        """
        Finds a video encoder spec which can take `rgb_format` as input.

        The `encoding` requested is tried first, then all the other encodings
        the video helper knows about. For each one, the specs are matched
        against `self.video_encoder_types`, in that order.
        Returns the first matching spec, or None if there isn't one.
        """
        #try the one specified first, then all the others:
        try_encodings = [encoding] + [x for x in self.video_helper.get_encodings() if x!=encoding]
        for try_encoding in try_encodings:
            colorspace_specs = self.video_helper.get_encoder_specs(try_encoding)
            #there may be no entry at all for this rgb format:
            especs = colorspace_specs.get(rgb_format)
            if not especs:
                continue
            for etype in self.video_encoder_types:
                for spec in especs:
                    if etype==spec.codec_type:
                        enclog("_find_video_encoder(%s, %s)=%s", try_encoding, rgb_format, spec)
                        return spec
        enclog("_find_video_encoder(%s, %s) not found", encoding, rgb_format)
        return None

    def get_window_info(self):
        """
        Returns a dictionary describing the video encoder of each window:
        {wid : {"proxy" : {"" : encoder type, "encoder" : encoder info}}}
        The encoder info also includes "idle_time": the whole number of
        seconds since the encoder was last used.
        """
        now = time.time()
        info = {}
        for wid, encoder in self.video_encoders.items():
            encoder_info = encoder.get_info()
            last_used = self.video_encoders_last_used_time.get(wid, 0)
            encoder_info["idle_time"] = int(now-last_used)
            proxy_info = {
                          ""           : encoder.get_type(),
                          "encoder"    : encoder_info,
                          }
            info[wid] = {"proxy" : proxy_info}
        enclog("get_window_info()=%s", info)
        return info
Exemplo n.º 47
0
 def __init__(self, src_type=None, src_options={}, codecs=get_codecs(), codec_options={}, volume=1.0):
     """
     Builds the sound source pipeline:
     source [queue] [audioconvert] volume encoder muxer appsink

     src_type: name of the source plugin. If unset, the first pulseaudio
         monitor device found is used ("pulsesrc"), or a test source
         ("audiotestsrc") if there are none. In both of those cases
         `src_options` is replaced with default options.
     src_options: options for the source plugin (the dictionary is not modified).
     codecs: the codecs we are allowed to use, the first one found
         in CODEC_ORDER which is also supported is selected.
     codec_options: encoder options, the encoder defaults are used if empty.
     volume: initial value for the volume element.

     Raises InitExit if the source plugin is not valid or if no codec matches.
     """
     if not src_type:
         try:
             from xpra.sound.pulseaudio.pulseaudio_util import get_pa_device_options
             monitor_devices = get_pa_device_options(True, False)
             log.info("found pulseaudio monitor devices: %s", monitor_devices)
         except ImportError as e:
             log.warn("Warning: pulseaudio is not available!")
             log.warn(" %s", e)
             monitor_devices = []
         if not monitor_devices:
             log.warn("could not detect any pulseaudio monitor devices")
             log.warn(" a test source will be used instead")
             src_type = "audiotestsrc"
             default_src_options = {"wave":2, "freq":100, "volume":0.4}
         else:
             #use the first device,
             #(items() cannot be indexed directly with python3)
             monitor_device = list(monitor_devices.items())[0][0]
             log.info("using pulseaudio source device:")
             log.info(" '%s'", monitor_device)
             src_type = "pulsesrc"
             default_src_options = {"device" : monitor_device}
         src_options = default_src_options
     if src_type not in get_source_plugins():
         raise InitExit(1, "invalid source plugin '%s', valid options are: %s" % (src_type, ",".join(get_source_plugins())))
     matching = [x for x in CODEC_ORDER if (x in codecs and x in get_codecs())]
     log("SoundSource(..) found matching codecs %s", matching)
     if not matching:
         raise InitExit(1, "no matching codecs between arguments '%s' and supported list '%s'" % (csv(codecs), csv(get_codecs().keys())))
     codec = matching[0]
     encoder, fmt = get_encoder_formatter(codec)
     self.queue = None
     self.caps = None
     self.volume = None
     self.sink = None
     self.src = None
     self.src_type = src_type
     self.buffer_latency = False
     self.jitter_queue = None
     self.file = None
     #set this early so it exists even if the pipeline setup fails below:
     self.skipped_caps = set()
     SoundPipeline.__init__(self, codec)
     #work on a copy: we must not modify the caller's dictionary,
     #or the default argument which is shared between all the calls
     src_options = dict(src_options or {})
     src_options["name"] = "src"
     source_str = plugin_str(src_type, src_options)
     #FIXME: this is ugly and relies on the fact that we don't pass any codec options to work!
     encoder_str = plugin_str(encoder, codec_options or get_encoder_default_options(encoder))
     fmt_str = plugin_str(fmt, MUXER_DEFAULT_OPTIONS.get(fmt, {}))
     pipeline_els = [source_str]
     if SOURCE_QUEUE_TIME>0:
         queue_el = ["queue",
                     "name=queue",
                     "min-threshold-time=0",
                     "max-size-buffers=0",
                     "max-size-bytes=0",
                     "max-size-time=%s" % (SOURCE_QUEUE_TIME*MS_TO_NS),
                     "leaky=%s" % GST_QUEUE_LEAK_DOWNSTREAM]
         pipeline_els += [" ".join(queue_el)]
     if encoder in ENCODER_NEEDS_AUDIOCONVERT or src_type in SOURCE_NEEDS_AUDIOCONVERT:
         pipeline_els += ["audioconvert"]
     pipeline_els.append("volume name=volume volume=%s" % volume)
     pipeline_els += [encoder_str,
                     fmt_str,
                     APPSINK]
     if not self.setup_pipeline_and_bus(pipeline_els):
         return
     self.volume = self.pipeline.get_by_name("volume")
     self.sink = self.pipeline.get_by_name("sink")
     if SOURCE_QUEUE_TIME>0:
         self.queue  = self.pipeline.get_by_name("queue")
     if self.queue:
         try:
             self.queue.set_property("silent", True)
         except Exception as e:
             log("cannot make queue silent: %s", e)
     try:
         if get_gst_version()<(1,0):
             self.sink.set_property("enable-last-buffer", False)
         else:
             self.sink.set_property("enable-last-sample", False)
     except Exception as e:
         log("failed to disable last buffer: %s", e)
     if JITTER>0:
         self.jitter_queue = Queue()
     try:
         #Gst 1.0:
         self.sink.connect("new-sample", self.on_new_sample)
         self.sink.connect("new-preroll", self.on_new_preroll1)
     except Exception:
         #Gst 0.10:
         self.sink.connect("new-buffer", self.on_new_buffer)
         self.sink.connect("new-preroll", self.on_new_preroll0)
     self.src = self.pipeline.get_by_name("src")
     try:
         for x in ("actual-buffer-time", "actual-latency-time"):
             #don't comment this out, it is used to verify the attributes are present:
             gstlog("initial %s: %s", x, self.src.get_property(x))
         self.buffer_latency = True
     except Exception as e:
         log.info("source %s does not support 'buffer-time' or 'latency-time':", self.src_type)
         log.info(" %s", e)
     else:
         #if the env vars have been set, try to honour the settings:
         if BUFFER_TIME>0:
             if BUFFER_TIME<LATENCY_TIME:
                 log.warn("Warning: latency (%ims) must be lower than the buffer time (%ims)", LATENCY_TIME, BUFFER_TIME)
             else:
                 log("latency tuning for %s, will try to set buffer-time=%i, latency-time=%i", src_type, BUFFER_TIME, LATENCY_TIME)
                 def settime(attr, v):
                     #logs the current value, then overrides it
                     #(the property value is 1000 times the value we are given)
                     try:
                         cval = self.src.get_property(attr)
                         gstlog("default: %s=%i", attr, cval//1000)
                         if v>=0:
                             self.src.set_property(attr, v*1000)
                             gstlog("overriding with: %s=%i", attr, v)
                     except Exception as e:
                         log.warn("source %s does not support '%s': %s", self.src_type, attr, e)
                 settime("buffer-time", BUFFER_TIME)
                 settime("latency-time", LATENCY_TIME)
     gen = generation.increase()
     if SAVE_TO_FILE is not None:
         #also save a copy of the stream to a file (one per pipeline generation),
         #the file is closed by cleanup()
         parts = codec.split("+")
         if len(parts)>1:
             filename = SAVE_TO_FILE+str(gen)+"-"+parts[0]+".%s" % parts[1]
         else:
             filename = SAVE_TO_FILE+str(gen)+".%s" % codec
         self.file = open(filename, 'wb')
         log.info("saving %s stream to %s", codec, filename)
Exemplo n.º 48
0
class SoundSource(SoundPipeline):
    """
    Sound pipeline which captures audio from a source element, encodes it
    and emits the resulting data using the "new-buffer" signal,
    with two arguments: the data and a metadata dictionary.
    """

    __gsignals__ = SoundPipeline.__generic_signals__.copy()
    __gsignals__.update({
        "new-buffer"    : n_arg_signal(2),
        })

    def __init__(self, src_type=None, src_options={}, codecs=get_codecs(), codec_options={}, volume=1.0):
        """
        Builds the sound source pipeline:
        source [queue] [audioconvert] volume encoder muxer appsink

        src_type: name of the source plugin. If unset, the first pulseaudio
            monitor device found is used ("pulsesrc"), or a test source
            ("audiotestsrc") if there are none. In both of those cases
            `src_options` is replaced with default options.
        src_options: options for the source plugin (the dictionary is not modified).
        codecs: the codecs we are allowed to use, the first one found
            in CODEC_ORDER which is also supported is selected.
        codec_options: encoder options, the encoder defaults are used if empty.
        volume: initial value for the volume element.

        Raises InitExit if the source plugin is not valid or if no codec matches.
        """
        if not src_type:
            try:
                from xpra.sound.pulseaudio.pulseaudio_util import get_pa_device_options
                monitor_devices = get_pa_device_options(True, False)
                log.info("found pulseaudio monitor devices: %s", monitor_devices)
            except ImportError as e:
                log.warn("Warning: pulseaudio is not available!")
                log.warn(" %s", e)
                monitor_devices = []
            if not monitor_devices:
                log.warn("could not detect any pulseaudio monitor devices")
                log.warn(" a test source will be used instead")
                src_type = "audiotestsrc"
                default_src_options = {"wave":2, "freq":100, "volume":0.4}
            else:
                #use the first device,
                #(items() cannot be indexed directly with python3)
                monitor_device = list(monitor_devices.items())[0][0]
                log.info("using pulseaudio source device:")
                log.info(" '%s'", monitor_device)
                src_type = "pulsesrc"
                default_src_options = {"device" : monitor_device}
            src_options = default_src_options
        if src_type not in get_source_plugins():
            raise InitExit(1, "invalid source plugin '%s', valid options are: %s" % (src_type, ",".join(get_source_plugins())))
        matching = [x for x in CODEC_ORDER if (x in codecs and x in get_codecs())]
        log("SoundSource(..) found matching codecs %s", matching)
        if not matching:
            raise InitExit(1, "no matching codecs between arguments '%s' and supported list '%s'" % (csv(codecs), csv(get_codecs().keys())))
        codec = matching[0]
        encoder, fmt = get_encoder_formatter(codec)
        self.queue = None
        self.caps = None
        self.volume = None
        self.sink = None
        self.src = None
        self.src_type = src_type
        self.buffer_latency = False
        self.jitter_queue = None
        self.file = None
        #set this early so it exists even if the pipeline setup fails below:
        self.skipped_caps = set()
        SoundPipeline.__init__(self, codec)
        #work on a copy: we must not modify the caller's dictionary,
        #or the default argument which is shared between all the calls
        src_options = dict(src_options or {})
        src_options["name"] = "src"
        source_str = plugin_str(src_type, src_options)
        #FIXME: this is ugly and relies on the fact that we don't pass any codec options to work!
        encoder_str = plugin_str(encoder, codec_options or get_encoder_default_options(encoder))
        fmt_str = plugin_str(fmt, MUXER_DEFAULT_OPTIONS.get(fmt, {}))
        pipeline_els = [source_str]
        if SOURCE_QUEUE_TIME>0:
            queue_el = ["queue",
                        "name=queue",
                        "min-threshold-time=0",
                        "max-size-buffers=0",
                        "max-size-bytes=0",
                        "max-size-time=%s" % (SOURCE_QUEUE_TIME*MS_TO_NS),
                        "leaky=%s" % GST_QUEUE_LEAK_DOWNSTREAM]
            pipeline_els += [" ".join(queue_el)]
        if encoder in ENCODER_NEEDS_AUDIOCONVERT or src_type in SOURCE_NEEDS_AUDIOCONVERT:
            pipeline_els += ["audioconvert"]
        pipeline_els.append("volume name=volume volume=%s" % volume)
        pipeline_els += [encoder_str,
                        fmt_str,
                        APPSINK]
        if not self.setup_pipeline_and_bus(pipeline_els):
            return
        self.volume = self.pipeline.get_by_name("volume")
        self.sink = self.pipeline.get_by_name("sink")
        if SOURCE_QUEUE_TIME>0:
            self.queue  = self.pipeline.get_by_name("queue")
        if self.queue:
            try:
                self.queue.set_property("silent", True)
            except Exception as e:
                log("cannot make queue silent: %s", e)
        try:
            if get_gst_version()<(1,0):
                self.sink.set_property("enable-last-buffer", False)
            else:
                self.sink.set_property("enable-last-sample", False)
        except Exception as e:
            log("failed to disable last buffer: %s", e)
        if JITTER>0:
            self.jitter_queue = Queue()
        try:
            #Gst 1.0:
            self.sink.connect("new-sample", self.on_new_sample)
            self.sink.connect("new-preroll", self.on_new_preroll1)
        except Exception:
            #Gst 0.10:
            self.sink.connect("new-buffer", self.on_new_buffer)
            self.sink.connect("new-preroll", self.on_new_preroll0)
        self.src = self.pipeline.get_by_name("src")
        try:
            for x in ("actual-buffer-time", "actual-latency-time"):
                #don't comment this out, it is used to verify the attributes are present:
                gstlog("initial %s: %s", x, self.src.get_property(x))
            self.buffer_latency = True
        except Exception as e:
            log.info("source %s does not support 'buffer-time' or 'latency-time':", self.src_type)
            log.info(" %s", e)
        else:
            #if the env vars have been set, try to honour the settings:
            if BUFFER_TIME>0:
                if BUFFER_TIME<LATENCY_TIME:
                    log.warn("Warning: latency (%ims) must be lower than the buffer time (%ims)", LATENCY_TIME, BUFFER_TIME)
                else:
                    log("latency tuning for %s, will try to set buffer-time=%i, latency-time=%i", src_type, BUFFER_TIME, LATENCY_TIME)
                    def settime(attr, v):
                        #logs the current value, then overrides it
                        #(the property value is 1000 times the value we are given)
                        try:
                            cval = self.src.get_property(attr)
                            gstlog("default: %s=%i", attr, cval//1000)
                            if v>=0:
                                self.src.set_property(attr, v*1000)
                                gstlog("overriding with: %s=%i", attr, v)
                        except Exception as e:
                            log.warn("source %s does not support '%s': %s", self.src_type, attr, e)
                    settime("buffer-time", BUFFER_TIME)
                    settime("latency-time", LATENCY_TIME)
        gen = generation.increase()
        if SAVE_TO_FILE is not None:
            #also save a copy of the stream to a file (one per pipeline generation),
            #the file is closed by cleanup()
            parts = codec.split("+")
            if len(parts)>1:
                filename = SAVE_TO_FILE+str(gen)+"-"+parts[0]+".%s" % parts[1]
            else:
                filename = SAVE_TO_FILE+str(gen)+".%s" % codec
            self.file = open(filename, 'wb')
            log.info("saving %s stream to %s", codec, filename)


    def __repr__(self):
        return "SoundSource('%s' - %s)" % (self.pipeline_str, self.state)

    def cleanup(self):
        """ Cleans up the pipeline, drops our references and closes the save file (if any) """
        SoundPipeline.cleanup(self)
        self.src_type = ""
        self.sink = None
        self.caps = None
        f = self.file
        if f:
            self.file = None
            f.close()

    def get_info(self):
        """
        Returns the pipeline info with the current queue level added
        and, if the source supports them, the actual buffer and latency times.
        """
        info = SoundPipeline.get_info(self)
        if self.queue:
            info["queue"] = {"cur" : self.queue.get_property("current-level-time")//MS_TO_NS}
        if self.buffer_latency:
            for x in ("actual-buffer-time", "actual-latency-time"):
                v = self.src.get_property(x)
                if v>=0:
                    info[x] = v
        return info


    def on_new_preroll1(self, appsink):
        #Gst 1.0
        sample = appsink.emit('pull-preroll')
        gstlog('new preroll1: %s', sample)
        return self.emit_buffer1(sample)

    def on_new_sample(self, bus):
        #Gst 1.0
        sample = self.sink.emit("pull-sample")
        return self.emit_buffer1(sample)

    def emit_buffer1(self, sample):
        """ Gst 1.0: extracts the data and timing metadata from the sample and emits it """
        buf = sample.get_buffer()
        #info = sample.get_info()
        size = buf.get_size()
        extract_dup = getattr(buf, "extract_dup", None)
        if extract_dup:
            data = extract_dup(0, size)
        else:
            #crappy gi bindings detected, using workaround:
            from xpra.sound.gst_hacks import map_gst_buffer
            with map_gst_buffer(buf) as a:
                data = bytes(a[:])
        return self.emit_buffer(data, {"timestamp"  : normv(buf.pts),
                                   "duration"   : normv(buf.duration),
                                   })


    def on_new_preroll0(self, appsink):
        #pygst 0.10
        buf = appsink.emit('pull-preroll')
        gstlog('new preroll0: %s bytes', len(buf))
        return self.emit_buffer0(buf)

    def on_new_buffer(self, bus):
        #pygst 0.10
        buf = self.sink.emit("pull-buffer")
        return self.emit_buffer0(buf)


    def caps_to_dict(self, caps):
        """
        Converts the caps into a dictionary: {cap name : {key : value}},
        keeping only the values which are plain strings or integers.
        The keys skipped are only logged the first time they are seen.
        Parsing errors are logged, what was parsed up to that point is returned.
        """
        if not caps:
            return {}
        d = {}
        try:
            for cap in caps:
                name = cap.get_name()
                capd = {}
                for k in cap.keys():
                    v = cap[k]
                    if type(v) in (str, int):
                        capd[k] = cap[k]
                    elif k not in self.skipped_caps:
                        log("skipping %s cap key %s=%s of type %s", name, k, v, type(v))
                        #record it so that we don't log it again with every buffer:
                        self.skipped_caps.add(k)
                d[name] = capd
        except Exception as e:
            log.error("Error parsing '%s':", caps)
            log.error(" %s", e)
        return d

    def emit_buffer0(self, buf):
        """ convert pygst structure into something more generic for the wire """
        #none of the metadata is really needed at present, but it may be in the future:
        #metadata = {"caps"      : buf.get_caps().to_string(),
        #            "size"      : buf.size,
        #            "timestamp" : buf.timestamp,
        #            "duration"  : buf.duration,
        #            "offset"    : buf.offset,
        #            "offset_end": buf.offset_end}
        log("emit buffer: %s bytes, timestamp=%s", len(buf.data), buf.timestamp//MS_TO_NS)
        metadata = {
                   "timestamp" : normv(buf.timestamp),
                   "duration"  : normv(buf.duration)
                   }
        d = self.caps_to_dict(buf.get_caps())
        if not self.caps or self.caps!=d:
            #only send the caps when they change:
            self.caps = d
            self.info["caps"] = self.caps
            metadata["caps"] = self.caps
        return self.emit_buffer(buf.data, metadata)

    def emit_buffer(self, data, metadata=None):
        """
        Saves the data to the save file (if there is one),
        then emits it: immediately, or via the jitter queue if JITTER is set.
        Nothing is emitted once the pipeline state is "stopped".
        Always returns 0.
        """
        if metadata is None:
            #use a new dictionary every time: do_emit_buffer adds to it
            metadata = {}
        f = self.file
        if f and data:
            #use the local reference: cleanup() may clear self.file at any time
            f.write(data)
            f.flush()
        if self.state=="stopped":
            #don't bother
            return 0
        if JITTER>0:
            #will actually emit the buffer after a random delay
            if self.jitter_queue.empty():
                #queue was empty, schedule a timer to flush it
                from random import randint
                jitter = randint(1, JITTER)
                self.timeout_add(jitter, self.flush_jitter_queue)
                log("emit_buffer: will flush jitter queue in %ims", jitter)
            self.jitter_queue.put((data, metadata))
            return 0
        log("emit_buffer data=%s, len=%i, metadata=%s", type(data), len(data), metadata)
        return self.do_emit_buffer(data, metadata)

    def flush_jitter_queue(self):
        """ Emits all the buffers held in the jitter queue """
        while not self.jitter_queue.empty():
            d,m = self.jitter_queue.get(False)
            self.do_emit_buffer(d, m)

    def do_emit_buffer(self, data, metadata=None):
        """
        Updates the buffer and byte counters, adds the current time
        (in milliseconds) to the metadata as "time"
        and emits the "new-buffer" signal. Always returns 0.
        """
        if metadata is None:
            #never write to a dictionary shared between calls
            metadata = {}
        self.inc_buffer_count()
        self.inc_byte_count(len(data))
        metadata["time"] = int(time.time()*1000)
        self.idle_emit("new-buffer", data, metadata)
        self.emit_info()
        return 0
Exemplo n.º 49
0
class subprocess_caller(object):
    """
    This is the caller side, wrapping the subprocess.
    You can call send() to pass packets to it
     which will get converted to method calls on the receiving end,
    You can register for signals, in which case your callbacks will be called
     when those signals are forwarded back.
    (there is no validation of which signals are valid or not)
    """

    def __init__(self, description="wrapper"):
        """
        Initialize the wrapper state, nothing is started until start() is called.
        The description is used as the connection target / info
        and as the log prefix given to the subprocess.
        """
        self.process = None
        self.protocol = None
        #subclasses are expected to set the command before calling start():
        self.command = None
        self.description = description
        self.send_queue = Queue()
        self.signal_callbacks = {}
        self.large_packets = []
        #hook a default packet handlers:
        self.connect(Protocol.CONNECTION_LOST, self.connection_lost)
        self.connect(Protocol.GIBBERISH, self.gibberish)


    def connect(self, signal, cb, *args):
        """ gobject style signal registration """
        self.signal_callbacks.setdefault(signal, []).append((cb, list(args)))


    def subprocess_exit(self, *args):
        """ Fires the "exit" callbacks, used as exit / close callback for the subprocess. """
        #beware: this may fire more than once!
        log("subprocess_exit%s command=%s", args, self.command)
        self._fire_callback("exit")

    def start(self):
        """ Executes the subprocess, then creates and starts the protocol talking to it. """
        self.process = self.exec_subprocess()
        self.protocol = self.make_protocol()
        self.protocol.start()

    def make_protocol(self):
        """ Returns a new Protocol instance connected to the subprocess' stdin and stdout. """
        #make a connection using the process stdin / stdout
        conn = TwoFileConnection(self.process.stdin, self.process.stdout, abort_test=None, target=self.description, info=self.description, close_cb=self.subprocess_exit)
        conn.timeout = 0
        protocol = Protocol(gobject, conn, self.process_packet, get_packet_cb=self.get_packet)
        #we assume the other end has the same encoders (which is reasonable):
        #TODO: fallback to bencoder
        try:
            protocol.enable_encoder("rencode")
        except Exception as e:
            log.warn("failed to enable rencode: %s", e)
            protocol.enable_encoder("bencode")
        #we assume this is local, so no compression:
        protocol.enable_compressor("none")
        protocol.large_packets = self.large_packets
        return protocol


    def exec_subprocess(self):
        """
        Starts self.command with pipes for stdin and stdout (stderr is shared with this process),
        registers the new process with the child reaper and returns it.
        """
        kwargs = self.exec_kwargs()
        log("exec_subprocess() command=%s, kwargs=%s", self.command, kwargs)
        proc = subprocess.Popen(self.command, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=sys.stderr.fileno(), env=self.get_env(), **kwargs)
        getChildReaper().add_process(proc, self.description, self.command, True, True, callback=self.subprocess_exit)
        return proc

    def get_env(self):
        """ Returns the environment for the subprocess: a copy of os.environ plus our xpra settings. """
        env = os.environ.copy()
        env["XPRA_SKIP_UI"] = "1"
        env["XPRA_LOG_PREFIX"] = "%s " % self.description
        #let's make things more complicated than they should be:
        #on win32, the environment can end up containing unicode, and subprocess chokes on it
        #(iterate over a copy of the items since we assign to the dict as we go)
        for k,v in list(env.items()):
            try:
                env[k] = bytestostr(v.encode("utf8"))
            except Exception:
                #the value could not be encoded, convert it as it is:
                env[k] = bytestostr(v)
        return env

    def exec_kwargs(self):
        """ Returns the platform specific keyword arguments for subprocess.Popen. """
        if os.name=="posix":
            return {"close_fds" : True}
        elif sys.platform.startswith("win"):
            if not WIN32_SHOWWINDOW:
                startupinfo = subprocess.STARTUPINFO()
                startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW
                return {"startupinfo" : startupinfo}
        return {}


    def cleanup(self):
        """ Same as stop() """
        self.stop()

    def stop(self):
        """ Asks the subprocess to terminate, then schedules the closing of the protocol. """
        self.stop_process()
        #call via idle_add to prevent deadlocks on win32!
        gobject.idle_add(self.stop_protocol)

    def stop_process(self):
        """
        Terminates the subprocess if it is still running.
        The process reference is only cleared if terminate() did not fail.
        """
        log("stop() sending stop request to %s", self.description)
        proc = self.process
        if proc and proc.poll() is None:
            try:
                proc.terminate()
                self.process = None
            except Exception as e:
                log.warn("failed to stop the wrapped subprocess %s: %s", proc, e)

    def stop_protocol(self):
        """ Closes the protocol (if there is one) and clears our reference to it. """
        p = self.protocol
        if p:
            self.protocol = None
            log("%s.stop() calling %s", self, p.close)
            try:
                p.close()
            except Exception as e:
                #one placeholder per argument, otherwise the message cannot be formatted:
                log.warn("failed to close the subprocess connection %s: %s", p, e)


    def connection_lost(self, *args):
        """ Default handler for Protocol.CONNECTION_LOST: just stop. """
        log("connection_lost%s", args)
        self.stop()

    def gibberish(self, *args):
        """ Default handler for Protocol.GIBBERISH: just stop. """
        log("gibberish%s", args)
        self.stop()


    def get_packet(self):
        """
        Packet source callback given to the protocol:
        returns a 4-tuple with the next queued packet (None if the queue is empty)
        as first item and a "more packets are queued" boolean as last item.
        """
        try:
            item = self.send_queue.get(False)
        except Exception:
            #nothing to send (the non-blocking get failed):
            item = None
        return (item, None, None, self.send_queue.qsize()>0)

    def send(self, *packet_data):
        """ Queues the packet and tells the protocol (if we have one) that there is data to send. """
        self.send_queue.put(packet_data)
        p = self.protocol
        if p:
            p.source_has_more()

    def process_packet(self, proto, packet):
        """ Packet handler given to the protocol: the packet type is used as signal name. """
        if DEBUG_WRAPPER:
            log("process_packet(%s, %s)", proto, [str(x)[:32] for x in packet])
        signal_name = bytestostr(packet[0])
        self._fire_callback(signal_name, packet[1:])

    def _fire_callback(self, signal_name, extra_args=()):
        """
        Schedules (via idle_add) all the callbacks registered for this signal name.
        Each callback is called with this wrapper as first argument,
        followed by the arguments given to connect(), then by extra_args.
        """
        callbacks = self.signal_callbacks.get(signal_name)
        log("firing callback for %s: %s", signal_name, callbacks)
        if callbacks:
            for cb, args in callbacks:
                try:
                    #extra_args may be any sequence (list or tuple):
                    all_args = list(args) + list(extra_args)
                    gobject.idle_add(cb, self, *all_args)
                except Exception:
                    log.error("error processing callback %s for %s packet", cb, signal_name, exc_info=True)
Exemplo n.º 50
0
 def __init__(self, *args):
     """ Initializes the base Connection with the arguments given, then adds a new Queue as self.queue """
     Connection.__init__(self, *args)
     self.queue = Queue()
Exemplo n.º 51
0
 def __init__(self):
     """ Initializes the thread as a daemon thread named "Worker_Thread", with an empty queue of items. """
     Thread.__init__(self, name="Worker_Thread")
     self.items = Queue()
     self.exit = False
     #use the daemon attribute: the setDaemon() method is deprecated (since python 3.10)
     self.daemon = True
Exemplo n.º 52
0
class subprocess_caller(object):
    """
    This is the caller side, wrapping the subprocess.
    You can call send() to pass packets to it
     which will get converted to method calls on the receiving end,
    You can register for signals, in which case your callbacks will be called
     when those signals are forwarded back.
    (there is no validation of which signals are valid or not)
    """

    def __init__(self, description="wrapper"):
        """
        Initialize the wrapper state, nothing is started until start() is called.
        The description is used as the connection target / socktype
        and as the log prefix given to the subprocess.
        """
        self.process = None
        self.protocol = None
        #subclasses are expected to set the command before calling start():
        self.command = None
        self.description = description
        self.send_queue = Queue()
        self.signal_callbacks = {}
        self.large_packets = []
        #hook a default packet handlers:
        self.connect(Protocol.CONNECTION_LOST, self.connection_lost)
        self.connect(Protocol.GIBBERISH, self.gibberish)
        #expose the glib scheduling functions as methods
        #(this instance is what we give to the Protocol as scheduler):
        glib = import_glib()
        self.idle_add = glib.idle_add
        self.timeout_add = glib.timeout_add
        self.source_remove = glib.source_remove


    def connect(self, signal, cb, *args):
        """ gobject style signal registration """
        self.signal_callbacks.setdefault(signal, []).append((cb, list(args)))


    def subprocess_exit(self, *args):
        """ Fires the "exit" callbacks, used as exit / close callback for the subprocess. """
        #beware: this may fire more than once!
        log("subprocess_exit%s command=%s", args, self.command)
        self._fire_callback("exit")

    def start(self):
        """
        Executes the subprocess, then creates and starts the protocol talking to it.
        This can only be called once per instance.
        """
        #any further call will hit fail_start instead:
        self.start = self.fail_start
        self.process = self.exec_subprocess()
        self.protocol = self.make_protocol()
        self.protocol.start()

    def fail_start(self):
        """ Replaces start() once it has been called, always raises an Exception. """
        raise Exception("this wrapper has already been started")

    def abort_test(self, action):
        """
        Given to the connection as abort test:
        raises ConnectionClosedException if the subprocess is gone.
        """
        p = self.process
        #NOTE(review): poll() returns the exit code, so a subprocess which has exited
        #with status 0 is not treated as terminated by this test - confirm this is intended
        if p is None or p.poll():
            raise ConnectionClosedException("cannot %s: subprocess has terminated" % action)

    def make_protocol(self):
        """ Returns a new Protocol instance connected to the subprocess' stdin and stdout. """
        #make a connection using the process stdin / stdout
        conn = TwoFileConnection(self.process.stdin, self.process.stdout, abort_test=self.abort_test, target=self.description, socktype=self.description, close_cb=self.subprocess_exit)
        conn.timeout = 0
        protocol = Protocol(self, conn, self.process_packet, get_packet_cb=self.get_packet)
        setup_fastencoder_nocompression(protocol)
        protocol.large_packets = self.large_packets
        return protocol


    def exec_subprocess(self):
        """
        Starts self.command with pipes for stdin and stdout (stderr is shared with this process),
        registers the new process with the child reaper and returns it.
        """
        kwargs = exec_kwargs()
        env = self.get_env()
        log("exec_subprocess() command=%s, env=%s, kwargs=%s", self.command, env, kwargs)
        proc = subprocess.Popen(self.command, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=sys.stderr.fileno(), env=env, **kwargs)
        getChildReaper().add_process(proc, self.description, self.command, True, True, callback=self.subprocess_exit)
        return proc

    def get_env(self):
        """ Returns the environment for the subprocess: exec_env() plus our log prefix. """
        env = exec_env()
        env["XPRA_LOG_PREFIX"] = "%s " % self.description
        return env

    def cleanup(self):
        """ Same as stop() """
        self.stop()

    def stop(self):
        """ Asks the subprocess to terminate, then closes the protocol. """
        self.stop_process()
        self.stop_protocol()

    def stop_process(self):
        """
        Terminates the subprocess if it is still running.
        The process reference is only cleared if terminate() did not fail.
        """
        log("%s.stop_process() sending stop request to %s", self, self.description)
        proc = self.process
        if proc and proc.poll() is None:
            try:
                proc.terminate()
                self.process = None
            except Exception as e:
                log.warn("failed to stop the wrapped subprocess %s: %s", proc, e)

    def stop_protocol(self):
        """ Closes the protocol (if there is one) and clears our reference to it. """
        p = self.protocol
        if p:
            self.protocol = None
            log("%s.stop_protocol() calling %s", self, p.close)
            try:
                p.close()
            except Exception as e:
                #one placeholder per argument, otherwise the message cannot be formatted:
                log.warn("failed to close the subprocess connection %s: %s", p, e)


    def connection_lost(self, *args):
        """ Default handler for Protocol.CONNECTION_LOST: just stop. """
        log("connection_lost%s", args)
        self.stop()

    def gibberish(self, *args):
        """ Default handler for Protocol.GIBBERISH: log the start of the data (args[1]) and stop. """
        log.warn("%s stopping on gibberish:", self.description)
        log.warn(" %s", repr_ellipsized(args[1], limit=80))
        self.stop()


    def get_packet(self):
        """
        Packet source callback given to the protocol:
        returns a 4-tuple with the next queued packet (None if the queue is empty)
        as first item and a "more packets are queued" boolean as last item.
        """
        try:
            item = self.send_queue.get(False)
        except Exception:
            #nothing to send (the non-blocking get failed):
            item = None
        return (item, None, None, self.send_queue.qsize()>0)

    def send(self, *packet_data):
        """ Queues the packet and tells the protocol (if we have one) that there is data to send. """
        self.send_queue.put(packet_data)
        p = self.protocol
        if p:
            p.source_has_more()
        INJECT_FAULT(p)

    def process_packet(self, proto, packet):
        """ Packet handler given to the protocol: the packet type is used as signal name. """
        if DEBUG_WRAPPER:
            log("process_packet(%s, %s)", proto, [str(x)[:32] for x in packet])
        signal_name = bytestostr(packet[0])
        self._fire_callback(signal_name, packet[1:])
        INJECT_FAULT(proto)

    def _fire_callback(self, signal_name, extra_args=()):
        """
        Schedules (via idle_add) all the callbacks registered for this signal name.
        Each callback is called with this wrapper as first argument,
        followed by the arguments given to connect(), then by extra_args.
        """
        callbacks = self.signal_callbacks.get(signal_name)
        log("firing callback for '%s': %s", signal_name, callbacks)
        if callbacks:
            for cb, args in callbacks:
                try:
                    #extra_args may be any sequence (list or tuple):
                    all_args = list(args) + list(extra_args)
                    self.idle_add(cb, self, *all_args)
                except Exception:
                    log.error("error processing callback %s for %s packet", cb, signal_name, exc_info=True)
Exemplo n.º 53
0
class subprocess_caller(object):
    """
    This is the caller side, wrapping the subprocess.
    You can call send() to pass packets to it
     which will get converted to method calls on the receiving end,
    You can register for signals, in which case your callbacks will be called
     when those signals are forwarded back.
    (there is no validation of which signals are valid or not)
    """

    def __init__(self, description="wrapper"):
        """
        Initialize the wrapper state, nothing is started until start() is called.
        The description is used as the connection target / socktype
        and as the log prefix given to the subprocess.
        """
        self.process = None
        self.protocol = None
        #subclasses are expected to set the command before calling start():
        self.command = None
        self.description = description
        self.send_queue = Queue()
        self.signal_callbacks = {}
        self.large_packets = []
        #hook a default packet handlers:
        self.connect(Protocol.CONNECTION_LOST, self.connection_lost)
        self.connect(Protocol.GIBBERISH, self.gibberish)
        #expose the glib scheduling functions as methods
        #(this instance is what we give to the Protocol as scheduler):
        glib = import_glib()
        self.idle_add = glib.idle_add
        self.timeout_add = glib.timeout_add
        self.source_remove = glib.source_remove


    def connect(self, signal, cb, *args):
        """ gobject style signal registration """
        self.signal_callbacks.setdefault(signal, []).append((cb, list(args)))


    def subprocess_exit(self, *args):
        """ Fires the "exit" callbacks, used as exit / close callback for the subprocess. """
        #beware: this may fire more than once!
        log("subprocess_exit%s command=%s", args, self.command)
        self._fire_callback("exit")

    def start(self):
        """
        Executes the subprocess, then creates and starts the protocol talking to it.
        This must not be called whilst we still hold a subprocess.
        """
        assert self.process is None, "already started"
        self.process = self.exec_subprocess()
        self.protocol = self.make_protocol()
        self.protocol.start()

    def abort_test(self, action):
        """
        Given to the connection as abort test:
        raises ConnectionClosedException if the subprocess is gone.
        """
        p = self.process
        #NOTE(review): poll() returns the exit code, so a subprocess which has exited
        #with status 0 is not treated as terminated by this test - confirm this is intended
        if p is None or p.poll():
            raise ConnectionClosedException("cannot %s: subprocess has terminated" % action)

    def make_protocol(self):
        """ Returns a new Protocol instance connected to the subprocess' stdin and stdout. """
        #make a connection using the process stdin / stdout
        conn = TwoFileConnection(self.process.stdin, self.process.stdout,
                                 abort_test=self.abort_test, target=self.description,
                                 socktype=self.description, close_cb=self.subprocess_exit)
        conn.timeout = 0
        protocol = Protocol(self, conn, self.process_packet, get_packet_cb=self.get_packet)
        setup_fastencoder_nocompression(protocol)
        protocol.large_packets = self.large_packets
        return protocol


    def exec_subprocess(self):
        """
        Starts self.command with pipes for stdin and stdout,
        registers the new process with the child reaper and returns it.
        """
        kwargs = exec_kwargs()
        env = self.get_env()
        log("exec_subprocess() command=%s, env=%s, kwargs=%s", self.command, env, kwargs)
        proc = subprocess.Popen(self.command, stdin=subprocess.PIPE, stdout=subprocess.PIPE, env=env, **kwargs)
        getChildReaper().add_process(proc, self.description, self.command, True, True, callback=self.subprocess_exit)
        return proc

    def get_env(self):
        """ Returns the environment for the subprocess: exec_env() plus our xpra settings. """
        env = exec_env()
        env["XPRA_LOG_PREFIX"] = "%s " % self.description
        env["XPRA_FIX_UNICODE_OUT"] = "0"
        return env

    def cleanup(self):
        """ Same as stop() """
        self.stop()

    def stop(self):
        """ Asks the subprocess to terminate, then closes the protocol. """
        self.stop_process()
        self.stop_protocol()

    def stop_process(self):
        """
        Terminates the subprocess if it is still running.
        The process reference is only cleared if terminate() did not fail.
        """
        log("%s.stop_process() sending stop request to %s", self, self.description)
        proc = self.process
        if proc and proc.poll() is None:
            try:
                proc.terminate()
                self.process = None
            except Exception as e:
                log.warn("failed to stop the wrapped subprocess %s: %s", proc, e)

    def stop_protocol(self):
        """ Closes the protocol (if there is one) and clears our reference to it. """
        p = self.protocol
        if p:
            self.protocol = None
            log("%s.stop_protocol() calling %s", self, p.close)
            try:
                p.close()
            except Exception as e:
                #one placeholder per argument, otherwise the message cannot be formatted:
                log.warn("failed to close the subprocess connection %s: %s", p, e)


    def connection_lost(self, *args):
        """ Default handler for Protocol.CONNECTION_LOST: just stop. """
        log("connection_lost%s", args)
        self.stop()

    def gibberish(self, *args):
        """ Default handler for Protocol.GIBBERISH: log the start of the data (args[1]) and stop. """
        log.warn("%s stopping on gibberish:", self.description)
        log.warn(" %s", repr_ellipsized(args[1], limit=80))
        self.stop()


    def get_packet(self):
        """
        Packet source callback given to the protocol:
        returns a 6-tuple with the next queued packet (None if the queue is empty)
        as first item and a "more packets are queued" boolean as last item.
        """
        try:
            item = self.send_queue.get(False)
        except Exception:
            #nothing to send (the non-blocking get failed):
            item = None
        return (item, None, None, None, False, self.send_queue.qsize()>0)

    def send(self, *packet_data):
        """ Queues the packet and tells the protocol (if we have one) that there is data to send. """
        self.send_queue.put(packet_data)
        p = self.protocol
        if p:
            p.source_has_more()
        INJECT_FAULT(p)

    def process_packet(self, proto, packet):
        """ Packet handler given to the protocol: the packet type is used as signal name. """
        if DEBUG_WRAPPER:
            log("process_packet(%s, %s)", proto, [str(x)[:32] for x in packet])
        signal_name = bytestostr(packet[0])
        self._fire_callback(signal_name, packet[1:])
        INJECT_FAULT(proto)

    def _fire_callback(self, signal_name, extra_args=()):
        """
        Schedules (via idle_add) all the callbacks registered for this signal name.
        Each callback is called with this wrapper as first argument,
        followed by the arguments given to connect(), then by extra_args.
        """
        callbacks = self.signal_callbacks.get(signal_name)
        log("firing callback for '%s': %s", signal_name, callbacks)
        if callbacks:
            for cb, args in callbacks:
                try:
                    all_args = list(args) + list(extra_args)
                    self.idle_add(cb, self, *all_args)
                except Exception:
                    log.error("error processing callback %s for %s packet", cb, signal_name, exc_info=True)
Exemplo n.º 54
0
class subprocess_callee(object):
    """
    This is the callee side, wrapping the gobject we want to interact with.
    All the input received will be converted to method calls on the wrapped object.
    Subclasses should register the signal handlers they want to see exported back to the caller.
    The convenience connect_export(signal-name, *args) can be used to forward signals unmodified.
    You can also call send() to pass packets back to the caller.
    (there is no validation of which signals are valid or not)
    """
    def __init__(self, input_filename="-", output_filename="-", wrapped_object=None, method_whitelist=None):
        """
        input_filename / output_filename: where we read packets from and write them to,
         "-" means stdin / stdout.
        wrapped_object: the object the incoming packets are turned into method calls on.
        method_whitelist: if not None, only the method names it contains can be called.
        Note: this installs OS signal handlers and sets up the main loop.
        """
        self.name = ""
        self.input_filename = input_filename
        self.output_filename = output_filename
        self.method_whitelist = method_whitelist
        self.large_packets = []
        #the gobject instance which is wrapped:
        self.wrapped_object = wrapped_object
        self.send_queue = Queue()
        self.protocol = None
        if HANDLE_SIGINT:
            #this breaks gobject3!
            signal.signal(signal.SIGINT, self.handle_signal)
        signal.signal(signal.SIGTERM, self.handle_signal)
        self.setup_mainloop()

    def setup_mainloop(self):
        """ Creates the glib main loop and exposes the glib scheduling functions as methods. """
        glib = import_glib()
        self.mainloop = glib.MainLoop()
        self.idle_add = glib.idle_add
        self.timeout_add = glib.timeout_add
        self.source_remove = glib.source_remove


    def connect_export(self, signal_name, *user_data):
        """ gobject style signal registration for the wrapped object,
            the signals will automatically be forwarded to the wrapper process
            using send(signal_name, *signal_args, *user_data)
        """
        log("connect_export%s", [signal_name] + list(user_data))
        #the signal name goes last so that export() can find it:
        args = list(user_data) + [signal_name]
        self.wrapped_object.connect(signal_name, self.export, *args)

    def export(self, *args):
        """
        Signal handler registered by connect_export():
        the last argument is the signal name, the first one is dropped,
        everything in between is sent with the signal name.
        """
        signal_name = args[-1]
        log("export(%s, ...)", signal_name)
        data = args[1:-1]
        self.send(signal_name, *list(data))


    def start(self):
        """
        Starts the protocol and runs the main loop until it exits.
        Returns 0 on normal exit or KeyboardInterrupt, 1 if the main loop raised another exception.
        The protocol and the stdin / stdout wrappers are always closed before returning.
        """
        self.protocol = self.make_protocol()
        self.protocol.start()
        try:
            self.run()
            return 0
        except KeyboardInterrupt as e:
            if str(e):
                log.warn("%s", e)
            return 0
        except Exception:
            log.error("error in main loop", exc_info=True)
            return 1
        finally:
            self.cleanup()
            if self.protocol:
                self.protocol.close()
                self.protocol = None
            #closing is best effort, but don't swallow SystemExit or KeyboardInterrupt:
            if self.input_filename=="-":
                try:
                    self._input.close()
                except Exception:
                    pass
            if self.output_filename=="-":
                try:
                    self._output.close()
                except Exception:
                    pass

    def make_protocol(self):
        """ Opens the input and output, and returns a new Protocol instance connected to them. """
        #figure out where we read from and write to:
        if self.input_filename=="-":
            #disable stdin buffering:
            self._input = os.fdopen(sys.stdin.fileno(), 'rb', 0)
            setbinarymode(self._input.fileno())
        else:
            self._input = open(self.input_filename, 'rb')
        if self.output_filename=="-":
            #disable stdout buffering:
            self._output = os.fdopen(sys.stdout.fileno(), 'wb', 0)
            setbinarymode(self._output.fileno())
        else:
            self._output = open(self.output_filename, 'wb')
        #stdin and stdout wrapper:
        conn = TwoFileConnection(self._output, self._input, abort_test=None, target=self.name, socktype=self.name, close_cb=self.net_stop)
        conn.timeout = 0
        protocol = Protocol(self, conn, self.process_packet, get_packet_cb=self.get_packet)
        setup_fastencoder_nocompression(protocol)
        protocol.large_packets = self.large_packets
        return protocol


    def run(self):
        """ Runs the main loop. """
        self.mainloop.run()


    def net_stop(self):
        """ Schedules a call to stop() via idle_add. """
        #this is called from the network thread,
        #we use idle add to ensure we clean things up from the main thread
        log("net_stop() will call stop from main thread")
        self.idle_add(self.stop)


    def cleanup(self):
        """ Does nothing here, called when stopping. """
        pass

    def stop(self):
        """ Calls cleanup(), closes the protocol (if there is one) and stops the main loop. """
        self.cleanup()
        p = self.protocol
        log("stop() protocol=%s", p)
        if p:
            self.protocol = None
            p.close()
        self.do_stop()

    def do_stop(self):
        """ Quits the main loop. """
        log("stop() stopping mainloop %s", self.mainloop)
        self.mainloop.quit()

    def handle_signal(self, sig, frame):
        """ This is for OS signals SIGINT and SIGTERM """
        #next time, just stop:
        signal.signal(signal.SIGINT, self.signal_stop)
        signal.signal(signal.SIGTERM, self.signal_stop)
        signame = SIGNAMES.get(sig, sig)
        try:
            log("handle_signal(%s, %s) calling stop from main thread", signame, frame)
        except Exception:
            pass        #may fail if we were doing IO logging when the signal was received
        self.send("signal", signame)
        self.timeout_add(0, self.cleanup)
        #give time for the network layer to send the signal message
        self.timeout_add(150, self.stop)

    def signal_stop(self, sig, frame):
        """ This time we really want to exit without waiting """
        signame = SIGNAMES.get(sig, sig)
        log("signal_stop(%s, %s) calling stop", signame, frame)
        self.stop()


    def send(self, *args):
        """ Queues the packet and tells the protocol (if we have one) that there is data to send. """
        if HEXLIFY_PACKETS:
            #args is a tuple, so the hexlified items must be added as a tuple
            #(adding a list to a tuple raises a TypeError)
            args = args[:1]+tuple(binascii.hexlify(str(x)[:32]) for x in args[1:])
        log("send: adding '%s' message (%s items already in queue)", args[0], self.send_queue.qsize())
        self.send_queue.put(args)
        p = self.protocol
        if p:
            p.source_has_more()
        INJECT_FAULT(p)

    def get_packet(self):
        """
        Packet source callback given to the protocol:
        returns a 4-tuple with the next queued packet (None if the queue is empty)
        as first item and a "more packets are queued" boolean as last item.
        """
        try:
            item = self.send_queue.get(False)
        except Exception:
            #nothing to send (the non-blocking get failed):
            item = None
        return (item, None, None, self.send_queue.qsize()>0)

    def process_packet(self, proto, packet):
        """
        Packet handler given to the protocol.
        The connection-lost, gibberish, "stop" and "exit" packets are handled here,
        any other packet type is used as the name of a method of the wrapped object
        (with "-" replaced by "_"), which is scheduled via idle_add
        with the rest of the packet as arguments.
        """
        command = bytestostr(packet[0])
        if command==Protocol.CONNECTION_LOST:
            log("connection-lost: %s, calling stop", packet[1:])
            self.net_stop()
            return
        elif command==Protocol.GIBBERISH:
            log.warn("gibberish received:")
            log.warn(" %s", repr_ellipsized(packet[1], limit=80))
            log.warn(" stopping")
            self.net_stop()
            return
        elif command=="stop":
            log("received stop message")
            self.net_stop()
            return
        elif command=="exit":
            log("received exit message")
            #raises SystemExit, nothing below can run:
            sys.exit(0)
        #make it easier to hookup signals to methods:
        attr = command.replace("-", "_")
        if self.method_whitelist is not None and attr not in self.method_whitelist:
            log.warn("invalid command: %s (not in whitelist: %s)", attr, self.method_whitelist)
            return
        wo = self.wrapped_object
        if not wo:
            log("wrapped object is no more, ignoring method call '%s'", attr)
            return
        method = getattr(wo, attr, None)
        if not method:
            log.warn("unknown command: '%s'", attr)
            log.warn(" packet: '%s'", repr_ellipsized(str(packet)))
            return
        if DEBUG_WRAPPER:
            log("calling %s.%s%s", wo, attr, str(tuple(packet[1:]))[:128])
        self.idle_add(method, *packet[1:])
        INJECT_FAULT(proto)
Exemplo n.º 55
0
 def __init__(self):
     """ Initializes the thread as a daemon thread named "Worker_Thread", with an empty queue of items. """
     Thread.__init__(self, name="Worker_Thread")
     self.items = Queue()
     self.exit = False
     #use the daemon attribute: the setDaemon() method is deprecated (since python 3.10)
     self.daemon = True