示例#1
0
    def test_close(self):
        """Closing a connection twice must record only one 'closed' state."""
        peer = RTCPeerConnection()
        observed = track_states(peer)

        # Close the connection two times in a row; the expected state list
        # below contains a single 'closed' entry.
        for _ in range(2):
            run(peer.close())

        self.assertEqual(observed['signalingState'], ['stable', 'closed'])
示例#2
0
 def webrtc_handler(self):
     """Run the WebRTC session on a private event loop and always clean up.

     ``self.run(pc, signaling)`` is driven to completion on a newly created
     loop. A ``KeyboardInterrupt`` ends the session quietly. In every case
     the peer connection and the signaling object are closed on that same
     loop, and the loop itself is closed afterwards.
     """
     signaling = ApprtcSignaling(self.room)
     pc = RTCPeerConnection()
     loop = asyncio.new_event_loop()
     try:
         loop.run_until_complete(self.run(pc, signaling))
     except KeyboardInterrupt:
         pass
     finally:
         try:
             loop.run_until_complete(pc.close())
             loop.run_until_complete(signaling.close())
         finally:
             # The loop was created in this function, so it has to be closed
             # here too, even when one of the close() coroutines fails.
             loop.close()
示例#3
0
    def run(self):
        """Drive the offer or answer coroutine on a dedicated event loop.

        The coroutine is chosen from ``self._conn_type``: ``_set_offer`` for
        ``WebRTCConnection.OFFER`` and ``_run_answer`` otherwise. A
        ``KeyboardInterrupt`` ends the run quietly. The peer connection and
        the signaling object are then closed on the same loop that ran the
        coroutine, and the loop is closed afterwards.
        """
        signaling = CopyAndPasteSignaling()
        pc = RTCPeerConnection()

        if self._conn_type == WebRTCConnection.OFFER:
            func = self._set_offer
        else:
            func = self._run_answer

        # Own the loop explicitly instead of using asyncio.run(): asyncio.run()
        # closes its loop and clears the current loop on exit, which would
        # leave no loop to run the close() coroutines on.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            loop.run_until_complete(func(pc, signaling))
        except KeyboardInterrupt:
            pass
        finally:
            try:
                loop.run_until_complete(pc.close())
                loop.run_until_complete(signaling.close())
            finally:
                # Leave no reference to the closed loop as the current loop.
                asyncio.set_event_loop(None)
                loop.close()
示例#4
0
def aio_loop(args):
    """Run the caller or callee coroutine on an aiogevent loop until interrupted.

    ``args.role == "caller"`` selects ``run_caller_rtc``; any other value
    selects ``run_callee_rtc``. The loop runs forever, a ``KeyboardInterrupt``
    is swallowed, and the peer connection is closed on the way out.
    """
    peer = RTCPeerConnection()

    # The coroutine object is built before the loop policy is replaced.
    role_coro = (
        run_caller_rtc(peer) if args.role == "caller" else run_callee_rtc(peer)
    )

    asyncio.set_event_loop_policy(aiogevent.EventLoopPolicy())
    event_loop = asyncio.get_event_loop()
    try:
        asyncio.ensure_future(role_coro)
        event_loop.run_forever()
    except KeyboardInterrupt:
        pass
    finally:
        event_loop.run_until_complete(peer.close())
示例#5
0
async def setup_test(connection: RTCPeerConnection, offer) -> RTCSessionDescription:
    """Validate an offer, answer it, and arm the question/answer test.

    The offer must describe exactly one media section that carries exactly
    one ICE candidate and whose candidate list is marked complete.

    Args:
        connection: Peer connection that answers the offer.
        offer: Mapping providing the ``"sdp"`` and ``"type"`` entries.

    Returns:
        ``connection.localDescription`` after the answer has been applied.

    Raises:
        ValueError: If the offer fails one of the checks described above.
    """
    remote_description = RTCSessionDescription(offer["sdp"], offer["type"])

    parsed_offer = SessionDescription.parse(offer["sdp"])
    if len(parsed_offer.media) != 1:
        raise ValueError("Only one media channel accepted.")

    media = parsed_offer.media[0]
    if len(media.ice_candidates) != 1:
        raise ValueError("Only one ICE candidate accepted.")
    if not media.ice_candidates_complete:
        raise ValueError("ICE candidates must be completed")

    candidate = media.ice_candidates[0]

    await connection.setRemoteDescription(remote_description)
    answer = await connection.createAnswer()
    await connection.setLocalDescription(answer)

    logger.debug(f"[{candidate}] Beginning test with this candidate…")

    @connection.on("datachannel")
    def handle_datachannel(channel: RTCDataChannel):
        logger.debug(f"[{candidate}] Established an RTCDataChannel")

        @channel.on("message")
        def handle_message(message):
            logger.debug(f"[{candidate}] Received a message")

            # The expected question is answered; anything else ends the test.
            if message == MAGIC_QUESTION:
                channel.send(MAGIC_ANSWER)
                return

            channel.close()
            asyncio.ensure_future(connection.close())

    # Schedule a close of the connection after TIMEOUT_TIME so the test
    # cannot stay open indefinitely.
    event_loop = asyncio.get_event_loop()
    event_loop.call_later(TIMEOUT_TIME, asyncio.ensure_future, connection.close())

    # The local description is the answer handed back to the caller.
    return connection.localDescription
示例#6
0
def webrtc_worker(offer: RTCSessionDescription, answer_queue: queue.Queue):
    """Process ``offer`` on a fresh event loop and publish the local description.

    A task running ``process_offer(pc, offer)`` is scheduled on a new loop.
    When that task finishes, the ``localDescription`` of its result is put on
    ``answer_queue``. The loop then runs until it is stopped, after which the
    peer connection is closed and the loop is shut down and closed.

    Args:
        offer: Remote offer handed to ``process_offer``.
        answer_queue: Queue that receives the resulting local description.
    """
    peer = RTCPeerConnection()
    event_loop = asyncio.new_event_loop()

    def publish_answer(finished: asyncio.Task):
        # The task result is treated as the peer connection holding the answer.
        answered_pc = finished.result()
        answer_queue.put(answered_pc.localDescription)

    offer_task = event_loop.create_task(process_offer(peer, offer))
    offer_task.add_done_callback(publish_answer)

    try:
        event_loop.run_forever()
    finally:
        logger.debug("Event loop %s has stopped.", event_loop)
        event_loop.run_until_complete(peer.close())
        event_loop.run_until_complete(event_loop.shutdown_asyncgens())
        event_loop.close()
        logger.debug("Event loop %s cleaned up.", event_loop)
示例#7
0
class RTCConnection(SubscriptionProducerConsumer):
    """
    Wraps an ``RTCPeerConnection`` together with a "default" data channel and
    audio/video handlers.

    The "default" data channel is wired to this object's own ``_get`` and
    ``_put_nowait``, so the connection itself acts as the endpoint of that
    channel. Additional channels are tracked by name in ``_dataChannels``.
    """

    _log = logging.getLogger("rtcbot.RTCConnection")

    def __init__(
        self,
        defaultChannelOrdered=True,
        loop=None,
        # NOTE: this default object is created once, when the class is defined,
        # and is therefore shared by every instance that relies on the default.
        rtcConfiguration=RTCConfiguration(
            [RTCIceServer(urls="stun:stun.l.google.com:19302")]),
    ):
        """
        Sets up the underlying peer connection and its event handlers.

        :param defaultChannelOrdered: passed as ``ordered`` when the "default"
            data channel is created by :meth:`getLocalDescription`.
        :param loop: event loop used by :meth:`close`; defaults to
            ``asyncio.get_event_loop()``.
        :param rtcConfiguration: configuration handed to ``RTCPeerConnection``.
        """
        super().__init__(
            directPutSubscriptionType=asyncio.Queue,
            defaultSubscriptionType=asyncio.Queue,
            logger=self._log,
        )
        self._loop = loop
        if self._loop is None:
            self._loop = asyncio.get_event_loop()

        # Every data channel of this connection, keyed by channel name.
        self._dataChannels = {}

        # These allow us to easily signal when the given events happen
        self._dataChannelSubscriber = SubscriptionProducer(
            logger=self._log.getChild("dataChannelSubscriber"))
        self._rtc = RTCPeerConnection(configuration=rtcConfiguration)
        self._rtc.on("datachannel", self._onDatachannel)
        # self._rtc.on("iceconnectionstatechange", self._onIceConnectionStateChange)
        self._rtc.on("track", self._onTrack)

        self._hasRemoteDescription = False
        self._defaultChannelOrdered = defaultChannelOrdered

        self._videoHandler = ConnectionVideoHandler(self._rtc)
        self._audioHandler = ConnectionAudioHandler(self._rtc)

    async def getLocalDescription(self, description=None):
        """
        Gets the description to send on. Creates an initial description (an
        offer) if no remote description is known, and creates a response (an
        answer) if a remote description was given or was set earlier.

        :param description: optional remote description, as a dict accepted by
            :meth:`setRemoteDescription`.
        :return: dict with the ``"sdp"`` and ``"type"`` of the local description.
        :raises AttributeError: re-raised from ``createAnswer`` after logging a
            hint about missing audio/video setup in the offer.
        """
        if self._hasRemoteDescription or description is not None:
            # This means that we received an offer - either the remote description
            # was already set, or we passed in a description. In either case,
            # instead of initializing a new connection, we prepare a response
            if not self._hasRemoteDescription:
                await self.setRemoteDescription(description)
            self._log.debug("Creating response to connection offer")
            try:
                answer = await self._rtc.createAnswer()
            except AttributeError:
                self._log.exception(
                    "\n>>> Looks like the offer didn't include the necessary info to set up audio/video. See RTCConnection.video.offerToReceive(). <<<\n\n"
                )
                raise
            await self._rtc.setLocalDescription(answer)
            return {
                "sdp": self._rtc.localDescription.sdp,
                "type": self._rtc.localDescription.type,
            }

        # There was no remote description, which means that we are initializing the
        # connection.

        # Before starting init, we create a default data channel for the connection
        self._log.debug("Setting up default data channel")
        channel = DataChannel(
            self._rtc.createDataChannel("default",
                                        ordered=self._defaultChannelOrdered))
        # Subscribe the default channel directly to our own inputs and outputs.
        # We have it listen to our own self._get, and write to our self._put_nowait
        channel.putSubscription(NoClosedSubscription(self._get))
        channel.subscribe(self._put_nowait)
        channel.onReady(lambda: self._setReady(channel.ready))
        self._dataChannels[channel.name] = channel

        # Make sure we offer to receive video and audio if it isn't set up yet
        # with all the receiving transceivers
        if len(self.video._senders) < self.video._offerToReceive:
            self._log.debug("Offering to receive video")
            for _ in range(self.video._offerToReceive -
                           len(self.video._senders)):
                self._rtc.addTransceiver("video", "recvonly")
        if len(self.audio._senders) < self.audio._offerToReceive:
            self._log.debug("Offering to receive audio")
            for _ in range(self.audio._offerToReceive -
                           len(self.audio._senders)):
                self._rtc.addTransceiver("audio", "recvonly")

        self._log.debug("Creating new connection offer")
        offer = await self._rtc.createOffer()
        await self._rtc.setLocalDescription(offer)
        return {
            "sdp": self._rtc.localDescription.sdp,
            "type": self._rtc.localDescription.type,
        }

    async def setRemoteDescription(self, description):
        """
        Applies a remote description to the connection.

        :param description: dict whose items are passed as keyword arguments to
            ``RTCSessionDescription``.
        """
        self._log.debug("Setting remote connection description")
        await self._rtc.setRemoteDescription(
            RTCSessionDescription(**description))
        self._hasRemoteDescription = True

    def _onDatachannel(self, channel):
        """
        When a data channel comes in, adds it to the data channels, and sets up
        its messaging. The channel named "default" is wired to this object's
        own input/output; any other channel is announced to the subscribers of
        :meth:`onDataChannel`.
        """
        channel = DataChannel(channel)
        self._log.debug("Got channel: %s", channel.name)
        if channel.name == "default":
            # Subscribe the default channel directly to our own inputs and outputs.
            # We have it listen to our own self._get, and write to our self._put_nowait
            channel.putSubscription(NoClosedSubscription(self._get))
            channel.subscribe(self._put_nowait)
            channel.onReady(lambda: self._setReady(channel.ready))

            # Set the default channel
            self._defaultChannel = channel

        else:
            self._dataChannelSubscriber.put_nowait(channel)
        self._dataChannels[channel.name] = channel

    def _onTrack(self, track):
        """
        Routes an incoming track to the audio or video handler based on
        ``track.kind``; other kinds are only logged.
        """
        self._log.debug("Received %s track from connection", track.kind)
        if track.kind == "audio":
            self._audioHandler._onTrack(track)
        elif track.kind == "video":
            self._videoHandler._onTrack(track)

    def onDataChannel(self, callback=None):
        """
        Acts as a subscriber to incoming (non-default) data channels: forwards
        ``callback`` to the internal data channel subscription producer and
        returns whatever its ``subscribe`` returns.
        """
        return self._dataChannelSubscriber.subscribe(callback)

    def addDataChannel(self, name, ordered=True):
        """
        Adds a data channel to the connection. Note that the RTCConnection adds a "default" channel
        automatically, which you can subscribe to directly.

        :param name: name of the new channel; must not be "default" or a name
            that is already in use.
        :param ordered: passed through to ``createDataChannel``.
        :return: the new ``DataChannel``.
        :raises KeyError: if the name is "default" or already exists.
        """
        self._log.debug("Adding data channel to connection")

        if name in self._dataChannels or name == "default":
            # Format the message eagerly: exceptions do not apply %-style
            # arguments the way logging calls do.
            raise KeyError("Data channel %s already exists" % name)

        dc = DataChannel(self._rtc.createDataChannel(name, ordered=ordered))
        self._dataChannels[name] = dc
        return dc

    def getDataChannel(self, name):
        """
        Returns the data channel with the given name. Please note that the "default" channel is considered special,
        and is not returned.

        :raises KeyError: if ``name`` is "default" or no such channel exists.
        """
        if name == "default":
            raise KeyError(
                "Default channel not available for 'get'. Use the RTCConnection's subscribe and put_nowait methods for access to it."
            )
        return self._dataChannels[name]

    @property
    def video(self):
        """
        Convenience function - you can subscribe to it to get video frames once they show up
        """
        return self._videoHandler

    @property
    def audio(self):
        """
        Convenience function - you can subscribe to it to get audio once a stream is received
        """
        return self._audioHandler

    def close(self):
        """
        Closes the handlers, every data channel and the peer connection.

        If the loop is running, returns a future that will close the connection. Otherwise, runs
        the loop temporarily to complete closing and returns ``None``.
        """
        super().close()
        # And closes all tracks
        self.video.close()
        self.audio.close()

        for dc in self._dataChannels.values():
            dc.close()

        self._dataChannelSubscriber.close()

        if self._loop.is_running():
            self._log.debug("Loop is running - close will return a future!")
            return asyncio.ensure_future(self._rtc.close())
        else:
            self._loop.run_until_complete(self._rtc.close())
        return None

    def send(self, msg):
        """
        Send is an alias for put_nowait - makes it easier for people new to rtcbot to understand
        what is going on
        """
        self.put_nowait(msg)
示例#8
0
class WebRtcWorker:
    """Runs an ``RTCPeerConnection`` on an event loop in a background thread.

    ``process_offer`` starts the thread and waits for the local description
    (or an error) that the thread reports through an internal queue;
    ``stop`` asks the loop to stop and joins the thread.
    """

    _thread: Union[threading.Thread, None]
    _loop: Union[AbstractEventLoop, None]
    _answer_queue: queue.Queue
    _video_transformer: Optional[VideoTransformerBase]
    _video_receiver: Optional[VideoReceiver]

    @property
    def video_transformer(self) -> Optional[VideoTransformerBase]:
        """Transformer created by the worker thread, or ``None`` if not set yet."""
        return self._video_transformer

    @property
    def video_receiver(self) -> Optional[VideoReceiver]:
        """Receiver created by ``process_offer`` in SENDONLY mode, else ``None``."""
        return self._video_receiver

    def __init__(
        self,
        mode: WebRtcMode,
        player_factory: Optional[MediaPlayerFactory] = None,
        in_recorder_factory: Optional[MediaRecorderFactory] = None,
        out_recorder_factory: Optional[MediaRecorderFactory] = None,
        video_transformer_factory: Optional[VideoTransformerFactory] = None,
        async_transform: bool = True,
    ) -> None:
        """Store the configuration and create the peer connection.

        No thread or event loop is started here; that happens in
        ``process_offer``.
        """
        self._thread = None
        self._loop = None
        self.pc = RTCPeerConnection()
        self._answer_queue = queue.Queue()
        self._stop_requested = False

        self.mode = mode
        self.player_factory = player_factory
        self.in_recorder_factory = in_recorder_factory
        self.out_recorder_factory = out_recorder_factory
        self.video_transformer_factory = video_transformer_factory
        self.async_transform = async_transform

        self._video_transformer = None
        self._video_receiver = None

    def _run_webrtc_thread(
        self,
        sdp: str,
        type_: str,
        in_recorder_factory: Optional[MediaRecorderFactory],
        out_recorder_factory: Optional[MediaRecorderFactory],
        player_factory: Optional[MediaPlayerFactory],
        video_transformer_factory: Optional[VideoTransformerFactory],
        video_receiver: Optional[VideoReceiver],
        async_transform: bool,
    ):
        """Thread target: run ``_webrtc_thread`` and report any error.

        On an exception the event loop (if one exists and is still open) is
        cleaned up, and the exception object is always put on the answer
        queue so that ``process_offer`` can re-raise it.
        """
        try:
            self._webrtc_thread(
                sdp=sdp,
                type_=type_,
                player_factory=player_factory,
                in_recorder_factory=in_recorder_factory,
                out_recorder_factory=out_recorder_factory,
                video_transformer_factory=video_transformer_factory,
                video_receiver=video_receiver,
                async_transform=async_transform,
            )
        except Exception as e:
            logger.warning(
                "An error occurred in the WebRTC worker thread: %s", e)

            try:
                loop = self._loop
                # _webrtc_thread closes the loop in its own ``finally``;
                # running anything on a closed loop would raise and hide the
                # original error, so only clean up a loop that is still open.
                if loop is not None and not loop.is_closed():
                    logger.warning("An event loop exists. Clean up it.")
                    loop.run_until_complete(self.pc.close())
                    loop.run_until_complete(loop.shutdown_asyncgens())
                    loop.close()
                    logger.warning("Event loop %s cleaned up.", loop)
            finally:
                # Send the error object to the main thread even if the
                # cleanup above failed, so the waiter does not time out.
                self._answer_queue.put(e)

    def _webrtc_thread(
        self,
        sdp: str,
        type_: str,
        player_factory: Optional[MediaPlayerFactory],
        in_recorder_factory: Optional[MediaRecorderFactory],
        out_recorder_factory: Optional[MediaRecorderFactory],
        video_transformer_factory: Optional[Callable[[],
                                                     VideoTransformerBase]],
        video_receiver: Optional[VideoReceiver],
        async_transform: bool,
    ):
        """Create the event loop, schedule offer processing and run forever.

        The local description is delivered to the answer queue through the
        ``callback`` passed to ``_process_offer``. When the loop stops, the
        peer connection is closed and the loop is shut down and closed.
        """
        logger.debug(
            "_webrtc_thread(player_factory=%s, video_transformer_factory=%s)",
            player_factory,
            video_transformer_factory,
        )

        loop = asyncio.new_event_loop()
        self._loop = loop

        offer = RTCSessionDescription(sdp, type_)

        def callback(localDescription):
            self._answer_queue.put(localDescription)

        video_transformer = None
        if video_transformer_factory:
            video_transformer = video_transformer_factory()

        if self.mode == WebRtcMode.SENDRECV:
            if video_transformer is None:
                logger.info("mode is set as sendrecv, "
                            "but video_transformer_factory is not specified. "
                            "A simple loopback transformer is used.")
                video_transformer = NoOpVideoTransformer()

        self._video_transformer = video_transformer

        loop.create_task(
            _process_offer(
                self.mode,
                self.pc,
                offer,
                player_factory=player_factory,
                in_recorder_factory=in_recorder_factory,
                out_recorder_factory=out_recorder_factory,
                video_transformer=video_transformer,
                video_receiver=video_receiver,
                async_transform=async_transform,
                callback=callback,
            ))

        try:
            loop.run_forever()
        finally:
            logger.debug("Event loop %s has stopped.", loop)
            loop.run_until_complete(self.pc.close())
            loop.run_until_complete(loop.shutdown_asyncgens())
            loop.close()
            logger.debug("Event loop %s cleaned up.", loop)

    def process_offer(
            self,
            sdp,
            type_,
            timeout: Union[float, None] = 10.0) -> RTCSessionDescription:
        """Start the worker thread for an offer and wait for its answer.

        Args:
            sdp: SDP text of the remote offer.
            type_: Description type of the remote offer.
            timeout: Seconds to wait for the worker's result; ``None`` waits
                without a limit.

        Returns:
            The object the worker thread put on the answer queue (the local
            description on success).

        Raises:
            TimeoutError: If no result arrived within ``timeout``; the worker
                is stopped first.
            Exception: Any exception object reported by the worker thread.
        """
        if self.mode == WebRtcMode.SENDONLY:
            self._video_receiver = VideoReceiver(queue_maxsize=1)

        self._thread = threading.Thread(
            target=self._run_webrtc_thread,
            kwargs={
                "sdp": sdp,
                "type_": type_,
                "player_factory": self.player_factory,
                "in_recorder_factory": self.in_recorder_factory,
                "out_recorder_factory": self.out_recorder_factory,
                "video_transformer_factory": self.video_transformer_factory,
                "video_receiver": self._video_receiver,
                "async_transform": self.async_transform,
            },
            daemon=True,
        )
        self._thread.start()

        try:
            result = self._answer_queue.get(block=True, timeout=timeout)
        except queue.Empty:
            self.stop(timeout=1)
            raise TimeoutError("Processing offer and initializing the worker "
                               f"has not finished in {timeout} seconds")

        if isinstance(result, Exception):
            raise result

        return result

    def stop(self, timeout: Union[float, None] = 1.0):
        """Ask the worker's event loop to stop and join the worker thread.

        Args:
            timeout: Seconds to wait for the thread to finish; ``None`` waits
                without a limit.
        """
        loop = self._loop
        if loop is not None and not loop.is_closed():
            try:
                # The loop runs in the worker thread. Event loops are not
                # thread-safe, and a plain loop.stop() from here would not
                # wake a loop that is blocked waiting for I/O.
                loop.call_soon_threadsafe(loop.stop)
            except RuntimeError:
                # The worker closed the loop between the check and the call;
                # there is nothing left to stop.
                pass
        if self._thread:
            self._thread.join(timeout=timeout)
示例#9
0
    def test_connect_audio_mid_changes(self):
        """Offer/answer with a custom audio MID on the offering side.

        pc1's transceiver MID is overridden before the offer is created; the
        test checks that pc2 adopts that MID, that both descriptions carry
        the expected SDP attributes, and that both peers go through the
        expected ICE and signaling state sequences.
        """
        pc1 = RTCPeerConnection()
        pc1_states = track_states(pc1)

        pc2 = RTCPeerConnection()
        pc2_states = track_states(pc2)

        self.assertEqual(pc1.iceConnectionState, 'new')
        self.assertEqual(pc1.iceGatheringState, 'new')
        self.assertIsNone(pc1.localDescription)
        self.assertIsNone(pc1.remoteDescription)

        self.assertEqual(pc2.iceConnectionState, 'new')
        self.assertEqual(pc2.iceGatheringState, 'new')
        self.assertIsNone(pc2.localDescription)
        self.assertIsNone(pc2.remoteDescription)

        # add audio tracks immediately
        pc1.addTrack(AudioStreamTrack())
        pc1.getTransceivers()[0].mid = 'sdparta_0'  # pretend we're Firefox!
        self.assertEqual(mids(pc1), ['sdparta_0'])

        pc2.addTrack(AudioStreamTrack())
        self.assertEqual(mids(pc2), ['audio'])

        # create offer
        offer = run(pc1.createOffer())
        self.assertEqual(offer.type, 'offer')
        self.assertIn('m=audio ', offer.sdp)
        self.assertNotIn('a=candidate:', offer.sdp)

        run(pc1.setLocalDescription(offer))
        self.assertEqual(pc1.iceConnectionState, 'new')
        self.assertEqual(pc1.iceGatheringState, 'complete')
        self.assertIn('m=audio ', pc1.localDescription.sdp)
        self.assertIn('a=candidate:', pc1.localDescription.sdp)
        self.assertIn('a=sendrecv', pc1.localDescription.sdp)
        self.assertIn('a=fingerprint:sha-256', pc1.localDescription.sdp)
        self.assertIn('a=setup:actpass', pc1.localDescription.sdp)
        self.assertIn('a=mid:sdparta_0', pc1.localDescription.sdp)

        # handle offer
        run(pc2.setRemoteDescription(pc1.localDescription))
        self.assertEqual(pc2.remoteDescription, pc1.localDescription)
        self.assertEqual(len(pc2.getReceivers()), 1)
        self.assertEqual(len(pc2.getSenders()), 1)
        self.assertEqual(mids(pc2), ['sdparta_0'])

        # create answer
        answer = run(pc2.createAnswer())
        self.assertEqual(answer.type, 'answer')
        self.assertIn('m=audio ', answer.sdp)
        self.assertNotIn('a=candidate:', answer.sdp)

        run(pc2.setLocalDescription(answer))
        self.assertEqual(pc2.iceConnectionState, 'checking')
        self.assertEqual(pc2.iceGatheringState, 'complete')
        self.assertIn('m=audio ', pc2.localDescription.sdp)
        self.assertIn('a=candidate:', pc2.localDescription.sdp)
        # The direction is checked on the answer (pc2), like every other
        # assertion in this section.
        self.assertIn('a=sendrecv', pc2.localDescription.sdp)
        self.assertIn('a=fingerprint:sha-256', pc2.localDescription.sdp)
        self.assertIn('a=setup:active', pc2.localDescription.sdp)
        self.assertIn('a=mid:sdparta_0', pc2.localDescription.sdp)

        # handle answer
        run(pc1.setRemoteDescription(pc2.localDescription))
        self.assertEqual(pc1.remoteDescription, pc2.localDescription)
        self.assertEqual(pc1.iceConnectionState, 'checking')

        # check outcome
        run(asyncio.sleep(1))
        self.assertEqual(pc1.iceConnectionState, 'completed')
        self.assertEqual(pc2.iceConnectionState, 'completed')

        # close
        run(pc1.close())
        run(pc2.close())
        self.assertEqual(pc1.iceConnectionState, 'closed')
        self.assertEqual(pc2.iceConnectionState, 'closed')

        # check state changes
        self.assertEqual(pc1_states['iceConnectionState'],
                         ['new', 'checking', 'completed', 'closed'])
        self.assertEqual(pc1_states['iceGatheringState'],
                         ['new', 'gathering', 'complete'])
        self.assertEqual(pc1_states['signalingState'],
                         ['stable', 'have-local-offer', 'stable', 'closed'])

        self.assertEqual(pc2_states['iceConnectionState'],
                         ['new', 'checking', 'completed', 'closed'])
        self.assertEqual(pc2_states['iceGatheringState'],
                         ['new', 'gathering', 'complete'])
        self.assertEqual(pc2_states['signalingState'],
                         ['stable', 'have-remote-offer', 'stable', 'closed'])
示例#10
0
 def test_createOffer_closed(self):
     """createOffer() on a closed connection must raise InvalidStateError."""
     peer = RTCPeerConnection()
     run(peer.close())

     with self.assertRaises(InvalidStateError) as caught:
         run(peer.createOffer())

     error_text = str(caught.exception)
     self.assertEqual(error_text, 'RTCPeerConnection is closed')
class WebRTCConnection(threading.Thread, BaseWorker):
    """Thread that carries syft messages to one peer over a WebRTC data channel.

    Outgoing requests are queued in ``_request_pool`` and sent by ``send``;
    incoming messages are dispatched by ``process_msg``. Signaling payloads
    are handed in through ``set_msg`` and consumed by ``consume_signaling``.
    """

    # Connection roles.
    OFFER = 1
    ANSWER = 2
    # Two-byte prefixes distinguishing requests from responses on the wire.
    HOST_REQUEST = b"01"
    REMOTE_REQUEST = b"02"

    def __init__(self, grid_descriptor, worker, destination, connections,
                 conn_type):
        """Store the connection settings; nothing is started here.

        Args:
            grid_descriptor: Object whose ``send`` forwards signaling payloads.
            worker: Local worker; its ``id`` is used as the origin and its
                ``tensor_requests`` list is reset.
            destination: Id of the remote peer (also used as this worker's id).
            connections: Mapping from which ``disconnect`` removes this peer.
            conn_type: ``WebRTCConnection.OFFER`` or ``WebRTCConnection.ANSWER``.
        """
        threading.Thread.__init__(self)
        BaseWorker.__init__(self, hook=hook, id=destination)
        self._conn_type = conn_type
        self._origin = worker.id
        self._worker = worker
        self._worker.tensor_requests = []
        self._destination = destination
        self._grid = grid_descriptor
        self._msg = ""
        self._request_pool = queue.Queue()
        self._response_pool = queue.Queue()
        self.channel = None
        self.available = True
        self.connections = connections

    # Add a new operation on request_pool
    async def _send_msg(self, message, location=None):
        """Queue ``message`` as a host request and wait for the response.

        ``location`` is accepted but not used.
        """
        self._request_pool.put(WebRTCConnection.HOST_REQUEST + message)

        # Wait
        # PySyft is a sync library and should wait for this response.
        while self._response_pool.empty():
            await asyncio.sleep(0)
        return self._response_pool.get()

    # Client side
    # Called when someone call syft function locally eg. tensor.send(node)
    def _recv_msg(self, message):
        """ Called when something is received locally and has to be sent to
            the remote worker. A return value is required after sending.
        """
        if self.available:
            return asyncio.run(self._send_msg(message))
        else:  # PySyft's GC delete commands
            return self._worker._recv_msg(message)

    # Running async all time
    async def send(self, channel):
        """Forward queued requests over ``channel`` while the connection is available."""
        while self.available:
            if not self._request_pool.empty():
                channel.send(self._request_pool.get())
            await asyncio.sleep(0)

    # Running async all time
    def process_msg(self, message, channel):
        """Handle one message received on ``channel``.

        A message prefixed with ``HOST_REQUEST`` is executed on the local
        worker and its result is sent back prefixed with ``REMOTE_REQUEST``.
        Any other message is treated as a response: its payload (without the
        two-byte prefix) is put on the response pool.
        """
        if message[:2] == WebRTCConnection.HOST_REQUEST:
            try:
                decoded_response = self._worker._recv_msg(message[2:])
            except GetNotPermittedError as e:
                # Remember the refused request and report the error instead.
                message = sy.serde.deserialize(message[2:],
                                               worker=self._worker)
                self._worker.tensor_requests.append(message)
                decoded_response = sy.serde.serialize(e)

            channel.send(WebRTCConnection.REMOTE_REQUEST + decoded_response)
        else:
            self._response_pool.put(message[2:])

    def search(self, query):
        """Send a ``SearchMessage`` for ``query`` and return the deserialized response."""
        message = SearchMessage(query)
        serialized_message = sy.serde.serialize(message)
        response = asyncio.run(self._send_msg(serialized_message))
        return sy.serde.deserialize(response)

    # Main
    def run(self):
        """Thread entry point: run the offer or answer coroutine, then clean up.

        The coroutine runs on a new event loop. Whatever way it ends, the
        peer connection is closed and the loop is shut down and closed.
        """
        self.signaling = CopyAndPasteSignaling()
        self.pc = RTCPeerConnection()

        if self._conn_type == WebRTCConnection.OFFER:
            func = self._set_offer
        else:
            func = self._run_answer

        self.loop = asyncio.new_event_loop()
        try:
            self.loop.run_until_complete(func(self.pc, self.signaling))
        except Exception:
            # consume_signaling ends by raising, so an exception here is the
            # regular way for the connection to finish.
            pass
        finally:
            # Clean up on every exit path, not only after an Exception.
            self.loop.run_until_complete(self.pc.close())

            # Stop loop:
            self.loop.run_until_complete(self.loop.shutdown_asyncgens())
            self.loop.close()

    # SERVER
    async def _set_offer(self, pc, signaling):
        """Create the data channel and the offer, forward it, then consume signaling."""
        await signaling.connect()
        channel = pc.createDataChannel("chat")

        self.channel = channel

        @channel.on("open")
        def on_open():
            asyncio.ensure_future(self.send(channel))

        @channel.on("message")
        def on_message(message):
            self.process_msg(message, channel)

        await pc.setLocalDescription(await pc.createOffer())
        local_description = object_to_string(pc.localDescription)

        response = {
            MSG_FIELD.TYPE: NODE_EVENTS.WEBRTC_OFFER,
            MSG_FIELD.PAYLOAD: local_description,
            MSG_FIELD.FROM: self._origin,
        }

        forward_payload = {
            MSG_FIELD.TYPE: GRID_EVENTS.FORWARD,
            MSG_FIELD.DESTINATION: self._destination,
            MSG_FIELD.CONTENT: response,
        }

        self._grid.send(json.dumps(forward_payload))
        await self.consume_signaling(pc, signaling)

    # CLIENT
    async def _run_answer(self, pc, signaling):
        """Wait for the remote data channel and consume signaling messages."""
        await signaling.connect()

        @pc.on("datachannel")
        def on_datachannel(channel):
            asyncio.ensure_future(self.send(channel))

            self.channel = channel

            @channel.on("message")
            def on_message(message):
                self.process_msg(message, channel)

        await self.consume_signaling(pc, signaling)

    async def consume_signaling(self, pc, signaling):
        """Poll ``self._msg`` for signaling payloads while the connection is available.

        A received session description is applied as the remote description;
        if it is an offer, an answer is created and forwarded through the
        grid. Once ``self.available`` becomes false the method raises
        ``Exception``, which ``run`` uses as the signal to clean up.
        """

        # Async keep-alive connection thread
        while self.available:
            if self._msg == "":
                # Nothing to process yet: yield to the other tasks and retry.
                await asyncio.sleep(0)
                continue

            obj = object_from_string(self._msg)

            if isinstance(obj, RTCSessionDescription):
                await pc.setRemoteDescription(obj)
                if obj.type == "offer":
                    # send answer
                    await pc.setLocalDescription(await pc.createAnswer())
                    local_description = object_to_string(pc.localDescription)

                    response = {
                        MSG_FIELD.TYPE: NODE_EVENTS.WEBRTC_ANSWER,
                        MSG_FIELD.FROM: self._origin,
                        MSG_FIELD.PAYLOAD: local_description,
                    }

                    forward_payload = {
                        MSG_FIELD.TYPE: GRID_EVENTS.FORWARD,
                        MSG_FIELD.DESTINATION: self._destination,
                        MSG_FIELD.CONTENT: response,
                    }
                    self._grid.send(json.dumps(forward_payload))
            self._msg = ""
        raise Exception

    def disconnect(self):
        """Mark the connection unavailable and remove it from ``connections``."""
        self.available = False
        del self.connections[self._destination]

    def set_msg(self, content: str):
        """Store a signaling payload for ``consume_signaling`` to pick up."""
        self._msg = content
示例#12
0
class WebRTCConnection(threading.Thread, BaseWorker):
    """WebRTC peer link that runs in its own thread and acts as a syft worker.

    Outgoing syft messages are queued in `_request_pool`, shipped over the
    data channel by the `send` coroutine, and their replies are handed back
    through `_response_pool`. Signaling payloads are injected with `set_msg`
    and consumed by `consume_signaling`.
    """

    # Role of this peer during connection set-up.
    OFFER = 1
    ANSWER = 2
    # Two-byte prefixes put in front of every data-channel payload:
    # HOST_REQUEST marks a request the receiver must execute,
    # REMOTE_REQUEST marks the reply to such a request (see process_msg).
    HOST_REQUEST = b"01"
    REMOTE_REQUEST = b"02"

    def __init__(self, grid_descriptor, worker, destination, connections,
                 conn_type):
        """ Create a new webrtc peer connection.

            Args:
                grid_descriptor: Grid network's websocket descriptor to forward webrtc connection request.
                worker: Virtual Worker that represents this peer.
                destination: Destination Peer ID.
                connections: Peer connection descriptors.
                conn_type: Connection responsibilities this peer should provide. (offer, answer)
        """
        threading.Thread.__init__(self)
        BaseWorker.__init__(self, hook=sy.hook, id=destination)
        self._conn_type = conn_type
        self._origin = worker.id
        self._worker = worker
        self._worker.tensor_requests = []
        self._destination = destination
        self._grid = grid_descriptor
        self._msg = ""
        self._request_pool = queue.Queue()
        self._response_pool = queue.Queue()
        self.channel = None
        self.available = True
        self.connections = connections

    def _send_msg(self, message: bytes, location=None):
        """ Add a new syft operation on the request_pool to be processed asynchronously.

            Args:
                message : Binary Syft message.
                location : peer location (This parameter should be preserved to keep the BaseWorker compatibility, but we do not use it.)

            Returns:
                response_message: Binary Syft response message.
        """
        self._request_pool.put(WebRTCConnection.HOST_REQUEST + message)

        # PySyft is a sync library and should wait for this response.
        # Queue.get() blocks until a reply is available, so no polling loop
        # is needed to wait for it.
        return self._response_pool.get()

    def _recv_msg(self, message: bytes):
        """ Called when someone call syft function locally eg. tensor.send(node)

            PS: This method should be synchronized to keep the compatibility with Syft internal operations.
            Args:
                message: Binary Syft message.

            Returns:
                response_message : Binary syft response message.
        """
        if self.available:
            return self._send_msg(message)
        else:  # PySyft's GC delete commands
            return self._worker._recv_msg(message)

    # Running async all time
    async def send(self, channel):
        """ Async method that drains the request_pool queue and sends each queued message over the channel.

            Args:
                channel: Connection channel used by the peers.
        """
        while self.available:
            if not self._request_pool.empty():
                channel.send(self._request_pool.get())
            # Yield to the event loop between polls.
            await asyncio.sleep(0)

    # Running async all time
    def process_msg(self, message, channel):
        """ Process syft messages forwarding them to the peer virtual worker and put the response into the response_pool queue to be delivered async.

            Args:
                message: Binary syft message.
                channel: Connection channel used by the peers.
        """
        if message[:2] == WebRTCConnection.HOST_REQUEST:
            try:
                decoded_response = self._worker._recv_msg(message[2:])
            except GetNotPermittedError as e:
                # Remember the refused request and send the error back as
                # the reply.
                message = sy.serde.deserialize(message[2:],
                                               worker=self._worker)
                self._worker.tensor_requests.append(message)
                decoded_response = sy.serde.serialize(e)

            channel.send(WebRTCConnection.REMOTE_REQUEST + decoded_response)
        else:
            # Anything else is treated as the reply to one of our requests.
            self._response_pool.put(message[2:])

    def search(self, query):
        """ Node's dataset search method overwrite.

            Args:
                query: Query used to search by the desired dataset tag.
            Returns:
                query_response: Return the peer's response.
        """
        message = SearchMessage(query)
        serialized_message = sy.serde.serialize(message)
        response = self._send_msg(serialized_message)
        return sy.serde.deserialize(response)

    # Main
    def run(self):
        """ Main thread method used to set up the connection and manage all the process."""
        self.signaling = CopyAndPasteSignaling()
        self.pc = RTCPeerConnection()

        if self._conn_type == WebRTCConnection.OFFER:
            func = self._set_offer
        else:
            func = self._run_answer

        self.loop = asyncio.new_event_loop()
        try:
            self.loop.run_until_complete(func(self.pc, self.signaling))
        except Exception:
            # consume_signaling ends by raising; that is the normal way this
            # thread is told to finish, so the exception is not propagated.
            pass
        finally:
            # Release the peer connection and the loop on every exit path,
            # not only when an Exception was raised.
            self.loop.run_until_complete(self.pc.close())

            # Stop loop:
            self.loop.run_until_complete(self.loop.shutdown_asyncgens())
            self.loop.close()

    # OFFER
    async def _set_offer(self, pc, signaling):
        """ Private method used to set up an offer to establish a new webrtc connection.

            Args:
                pc: Peer Connection  descriptor
                signaling: Webrtc signaling instance.
        """
        await signaling.connect()
        channel = pc.createDataChannel("chat")

        self.channel = channel

        @channel.on("open")
        def on_open():
            asyncio.ensure_future(self.send(channel))

        @channel.on("message")
        def on_message(message):
            self.process_msg(message, channel)

        await pc.setLocalDescription(await pc.createOffer())
        local_description = object_to_string(pc.localDescription)

        response = {
            MSG_FIELD.TYPE: NODE_EVENTS.WEBRTC_OFFER,
            MSG_FIELD.PAYLOAD: local_description,
            MSG_FIELD.FROM: self._origin,
        }

        forward_payload = {
            MSG_FIELD.TYPE: GRID_EVENTS.FORWARD,
            MSG_FIELD.DESTINATION: self._destination,
            MSG_FIELD.CONTENT: response,
        }

        self._grid.send(json.dumps(forward_payload))
        await self.consume_signaling(pc, signaling)

    # ANSWER
    async def _run_answer(self, pc, signaling):
        """ Private method used to set up an answer to establish a new webrtc connection.

            Args:
                pc: Peer connection.
                signaling: Webrtc signaling instance.
        """
        await signaling.connect()

        @pc.on("datachannel")
        def on_datachannel(channel):
            asyncio.ensure_future(self.send(channel))

            self.channel = channel

            @channel.on("message")
            def on_message(message):
                self.process_msg(message, channel)

        await self.consume_signaling(pc, signaling)

    async def consume_signaling(self, pc, signaling):
        """ Consume signaling to go through all the webrtc connection protocol.

            Args:
                pc: Peer Connection.
                signaling: Webrtc signaling instance.
            Raises:
                Exception: always raised once `self.available` is false; used to finish this connection and close this thread.
        """
        # Async keep-alive connection thread
        while self.available:
            if self._msg == "":
                # Nothing pending: yield to the event loop and poll again.
                await asyncio.sleep(0)
                continue

            obj = object_from_string(self._msg)

            if isinstance(obj, RTCSessionDescription):
                await pc.setRemoteDescription(obj)
                if obj.type == "offer":
                    # send answer
                    await pc.setLocalDescription(await pc.createAnswer())
                    local_description = object_to_string(pc.localDescription)

                    response = {
                        MSG_FIELD.TYPE: NODE_EVENTS.WEBRTC_ANSWER,
                        MSG_FIELD.FROM: self._origin,
                        MSG_FIELD.PAYLOAD: local_description,
                    }

                    forward_payload = {
                        MSG_FIELD.TYPE: GRID_EVENTS.FORWARD,
                        MSG_FIELD.DESTINATION: self._destination,
                        MSG_FIELD.CONTENT: response,
                    }
                    self._grid.send(json.dumps(forward_payload))
            # Mark the message as consumed.
            self._msg = ""
        raise Exception("WebRTC connection closed")

    def disconnect(self):
        """ Disconnect from the peer and finish this thread. """
        self.available = False
        # pop() rather than `del` so a repeated disconnect does not raise
        # KeyError.
        self.connections.pop(self._destination, None)

    def set_msg(self, content: str):
        """ Store a signaling payload to be picked up by consume_signaling. """
        self._msg = content
示例#13
0
class WebRTCConnection(BidirectionalConnection):
    """Full-duplex connection between two nodes over a WebRTC data channel.

    Outbound messages are queued in `producer_pool` and sent by the
    `producer` task; inbound messages are handled by `consumer`, which either
    routes them to `self.node` or, for replies addressed to this client,
    puts them in `consumer_pool`.
    """

    loop: Any

    def __init__(self, node: AbstractNode) -> None:
        """Set up the event loop, message queues and peer connection.

        :param node: node instance that will process the requests sent by
            the other peer (needed because the connection is full-duplex).
        """
        # Local import: only needed here, to pick the Queue signature.
        import sys

        # All the requests messages will be forwarded to this node.
        self.node = node

        # Handle of the producer task; stays None until the data channel
        # opens, so _finish_coroutines can tell whether there is a task
        # to cancel.
        self.__producer_task = None

        # EventLoop that manages async tasks (producer/consumer)
        # This structure is global and needs to be
        # defined beforehand.
        try:
            self.loop = get_running_loop()
            log = "♫♫♫ > ...using a running event loop..."
            logger.debug(log)
            print(log)
        except RuntimeError as e:
            self.loop = None
            log = f"♫♫♫ > ...error getting a running event Loop... {e}"
            logger.error(log)
            print(log)

        if self.loop is None:
            log = "♫♫♫ > ...creating a new event loop..."
            print(log)
            logger.debug(log)
            self.loop = asyncio.new_event_loop()

        # Message pool (High Priority)
        # These queues will be used to manage
        # async  messages.
        try:
            # asyncio.Queue lost its `loop` argument in Python 3.10 (passing
            # it raises TypeError); older versions still get the explicit
            # loop so their behaviour is unchanged.
            if sys.version_info >= (3, 10):
                queue_kwargs = {}
            else:
                queue_kwargs = {"loop": self.loop}

            self.producer_pool: asyncio.Queue = asyncio.Queue(
                **queue_kwargs)  # Request Messages / Request Responses
            self.consumer_pool: asyncio.Queue = asyncio.Queue(
                **queue_kwargs)  # Request Responses

            # Initialize a PeerConnection structure
            self.peer_connection = RTCPeerConnection()

            # Set channel descriptor as None
            # This attribute will be used for external classes
            # in order to verify if the connection channel
            # was established.
            self.channel: Optional[RTCDataChannel] = None
            self._client_address: Optional[Address] = None

            # asyncio.ensure_future(self.heartbeat())

        except Exception as e:
            log = f"Got an exception in WebRTCConnection __init__. {e}"
            logger.error(log)
            raise

    @syft_decorator(typechecking=True)
    async def _set_offer(self) -> str:
        """Initialize a Real Time Communication Data Channel,
        set datachannel callbacks/tasks, and send offer payload
        message.

        :return: returns a signaling offer payload containing local description.
        :rtype: str
        """
        try:
            # Use the Peer Connection structure to
            # set the channel as a RTCDataChannel.
            self.channel = self.peer_connection.createDataChannel(
                "datachannel")

            # Callback invoked by the aioRTC lib when the
            # connection opens: start the producer task.
            @self.channel.on("open")
            async def on_open() -> None:  # type : ignore
                self.__producer_task = asyncio.ensure_future(self.producer())

            # This method is the aioRTC "consumer" task
            # and will be running as long as connection remains.
            # At this point we're just setting the method behavior
            # It'll start running after the connection opens.
            @self.channel.on("message")
            async def on_message(
                    message: Union[bin, str]) -> None:  # type: ignore
                # Forward all received messages to our own consumer method.
                await self.consumer(msg=message)

            # Set peer_connection to generate an offer message type.
            await self.peer_connection.setLocalDescription(
                await self.peer_connection.createOffer())

            # Generates the local description structure
            # and serialize it to string afterwards.
            local_description = object_to_string(
                self.peer_connection.localDescription)

            # Return the Offer local_description payload.
            return local_description
        except Exception as e:
            log = f"Got an exception in WebRTCConnection _set_offer. {e}"
            logger.error(log)
            raise

    @syft_decorator(typechecking=True)
    async def _set_answer(self, payload: str) -> str:
        """Receives a signaling offer payload, initialize/set
        datachannel callbacks/tasks, updates remote local description
        using offer's payload message and returns a
        signaling answer payload.

        :return: returns a signaling answer payload containing local description.
        :rtype: str
        """

        try:

            @self.peer_connection.on("datachannel")
            def on_datachannel(channel: RTCDataChannel) -> None:
                self.channel = channel

                self.__producer_task = asyncio.ensure_future(self.producer())

                @self.channel.on("message")
                async def on_message(
                        message: Union[bin, str]) -> None:  # type: ignore
                    await self.consumer(msg=message)

            return await self._process_answer(payload=payload)
        except Exception as e:
            log = f"Got an exception in WebRTCConnection _set_answer. {e}"
            logger.error(log)
            raise

    @syft_decorator(typechecking=True)
    async def _process_answer(self, payload: str) -> Union[str, None]:
        """Apply a signaling payload received from the other peer.

        :return: the serialized local description when `payload` was an
            offer (i.e. our answer); None otherwise.
        :rtype: Union[str, None]
        """
        # Converts payload received by
        # the other peer in aioRTC Object
        # instance.
        try:
            msg = object_from_string(payload)

            # Check if Object instance is a
            # description of RTC Session.
            if isinstance(msg, RTCSessionDescription):

                # Use the target's network address/metadata
                # to set the remote description of this peer.
                # This will basically say to this peer how to find/connect
                # with to other peer.
                await self.peer_connection.setRemoteDescription(msg)

                # If it's an offer message type,
                # generates your own local description
                # and send it back in order to tell
                # to the other peer how to find you.
                if msg.type == "offer":
                    # Set peer_connection to generate an answer message type.
                    await self.peer_connection.setLocalDescription(
                        await self.peer_connection.createAnswer())

                    # Generates the local description structure
                    # and serialize it to string afterwards.
                    local_description = object_to_string(
                        self.peer_connection.localDescription)

                    # Returns the answer peer's local description
                    return local_description
        except Exception as e:
            log = f"Got an exception in WebRTCConnection _process_answer. {e}"
            logger.error(log)
            raise
        return None

    @syft_decorator(typechecking=True)
    async def producer(self) -> None:
        """Forever take messages from `producer_pool` and send them over the channel."""
        # Async task to send messages to the other side.
        # These messages will be enqueued by PySyft Node Clients
        # by using PySyft routes and ClientConnection's inheritance.
        try:
            while True:
                # Waits (without blocking other tasks) until a message
                # is available in self.producer_pool.
                msg = await self.producer_pool.get()

                await asyncio.sleep(message_cooldown)
                # Send the message as binary using the RTCDataChannel.
                # logger.critical(f"> Sending MSG {msg.message} ID: {msg.id}")
                self.channel.send(msg.to_bytes())  # type: ignore
        except Exception as e:
            log = f"Got an exception in WebRTCConnection producer. {e}"
            logger.error(log)
            raise

    def close(self) -> None:
        """Tell the other peer we are leaving, then shut down local tasks."""
        try:
            # Build Close Message to warn the other peer
            bye_msg = CloseConnectionMessage(address=Address())

            self.channel.send(bye_msg.to_bytes())  # type: ignore

            # Finish async tasks related with this connection
            self._finish_coroutines()
        except Exception as e:
            log = f"Got an exception in WebRTCConnection close. {e}"
            logger.error(log)
            raise

    def _finish_coroutines(self) -> None:
        """Close the peer connection and cancel the producer task, if any."""
        try:
            # NOTE(review): asyncio.run() refuses to start inside an already
            # running loop, and this method is also reached from the
            # `consumer` coroutine -- confirm the runtime permits nested
            # loops, otherwise this needs to be awaited instead.
            asyncio.run(self.peer_connection.close())
            # The task only exists once the data channel has opened.
            if self.__producer_task is not None:
                self.__producer_task.cancel()
        except Exception as e:
            log = f"Got an exception in WebRTCConnection _finish_coroutines. {e}"
            logger.error(log)
            raise

    @syft_decorator(typechecking=True)
    async def consumer(self, msg: bin) -> None:  # type: ignore
        """Deserialize one inbound message and route it.

        Messages not addressed to this client are service requests and are
        dispatched to `self.node`; messages addressed to this client are
        replies and go to `consumer_pool`.
        """
        try:
            # Async task to receive/process messages sent by the other side.
            # These messages will be sent by the other peer
            # as a service requests or responses for requests made by
            # this connection previously (ImmediateSyftMessageWithReply).

            # Deserialize the received message
            _msg = _deserialize(blob=msg, from_bytes=True)

            # Check if it's NOT  a response generated by a previous request
            # made by the client instance that uses this connection as a route.
            # PS: The "_client_address" attribute will be defined during
            # Node Client initialization.
            if _msg.address != self._client_address:
                # If it's a new service request, route it properly
                # using the node instance owned by this connection.

                # Immediate message with reply
                if isinstance(_msg, SignedImmediateSyftMessageWithReply):
                    reply = self.recv_immediate_msg_with_reply(msg=_msg)
                    await self.producer_pool.put(reply)

                # Immediate message without reply
                elif isinstance(_msg, SignedImmediateSyftMessageWithoutReply):
                    self.recv_immediate_msg_without_reply(msg=_msg)

                elif isinstance(_msg, CloseConnectionMessage):
                    # Just finish async tasks related with this connection
                    self._finish_coroutines()

                # Eventual message without reply
                else:
                    self.recv_eventual_msg_without_reply(msg=_msg)

            # If it's true, the message will have the client's address as destination.
            else:
                await self.consumer_pool.put(_msg)
        except Exception as e:
            log = f"Got an exception in WebRTCConnection consumer. {e}"
            logger.error(log)
            raise

    @syft_decorator(typechecking=True)
    def recv_immediate_msg_with_reply(
        self, msg: SignedImmediateSyftMessageWithReply
    ) -> SignedImmediateSyftMessageWithoutReply:
        """Executes/Replies requests instantly.

        :return: the reply produced by the node for this request.
        :rtype: SignedImmediateSyftMessageWithoutReply
        """
        # Execute node services now
        try:
            # Random tag used only to pair the before/after debug lines.
            r = random.randint(0, 100000)
            logger.debug(
                f"> Before recv_immediate_msg_with_reply {r} {msg.message} {type(msg.message)}"
            )
            reply = self.node.recv_immediate_msg_with_reply(msg=msg)
            logger.debug(
                f"> After recv_immediate_msg_with_reply {r} {msg.message} {type(msg.message)}"
            )
            return reply
        except Exception as e:
            log = f"Got an exception in WebRTCConnection recv_immediate_msg_with_reply. {e}"
            logger.error(log)
            raise

    @syft_decorator(typechecking=True)
    def recv_immediate_msg_without_reply(
            self, msg: SignedImmediateSyftMessageWithoutReply) -> None:
        """ Executes requests instantly. """
        try:
            # Random tag used only to pair the before/after debug lines.
            r = random.randint(0, 100000)
            logger.debug(
                f"> Before recv_immediate_msg_without_reply {r} {msg.message} {type(msg.message)}"
            )
            self.node.recv_immediate_msg_without_reply(msg=msg)
            logger.debug(
                f"> After recv_immediate_msg_without_reply {r} {msg.message} {type(msg.message)}"
            )
        except Exception as e:
            log = f"Got an exception in WebRTCConnection recv_immediate_msg_without_reply. {e}"
            logger.error(log)
            raise

    @syft_decorator(typechecking=True)
    def recv_eventual_msg_without_reply(
            self, msg: SignedEventualSyftMessageWithoutReply) -> None:
        """ Executes requests eventually. """
        try:
            self.node.recv_eventual_msg_without_reply(msg=msg)
        except Exception as e:
            log = f"Got an exception in WebRTCConnection recv_eventual_msg_without_reply. {e}"
            logger.error(log)
            raise

    @syft_decorator(typechecking=False)
    def send_immediate_msg_with_reply(
        self, msg: SignedImmediateSyftMessageWithReply
    ) -> SignedImmediateSyftMessageWithReply:
        """Sends high priority messages and wait for their responses.

        :return: whatever `send_sync_message` returns for this request.
        :rtype: SignedImmediateSyftMessageWithReply
        """
        try:
            return asyncio.run(self.send_sync_message(msg=msg))
        except Exception as e:
            log = f"Got an exception in WebRTCConnection send_immediate_msg_with_reply. {e}"
            logger.error(log)
            raise

    @syft_decorator(typechecking=True)
    def send_immediate_msg_without_reply(
            self, msg: SignedImmediateSyftMessageWithoutReply) -> None:
        """ Sends high priority messages without waiting for their reply. """
        try:
            # asyncio.run(self.producer_pool.put_nowait(msg))
            self.producer_pool.put_nowait(msg)
            time.sleep(message_cooldown)
        except Exception as e:
            log = f"Got an exception in WebRTCConnection send_immediate_msg_without_reply. {e}"
            logger.error(log)
            raise

    @syft_decorator(typechecking=True)
    def send_eventual_msg_without_reply(
            self, msg: SignedEventualSyftMessageWithoutReply) -> None:
        """ Sends low priority messages without waiting for their reply. """
        try:
            asyncio.run(self.producer_pool.put(msg))
            time.sleep(message_cooldown)
        except Exception as e:
            log = f"Got an exception in WebRTCConnection send_eventual_msg_without_reply. {e}"
            logger.error(log)
            raise

    @syft_decorator(typechecking=True)
    async def send_sync_message(
        self, msg: SignedImmediateSyftMessageWithReply
    ) -> SignedImmediateSyftMessageWithoutReply:
        """Send sync messages generically.

        :return: returns an instance of SignedImmediateSyftMessageWithoutReply.
        :rtype: SignedImmediateSyftMessageWithoutReply
        """
        try:
            # To ensure the sequence of sending / receiving messages
            # it's necessary to keep only a unique reference for reading
            # inputs (producer) and outputs (consumer).
            # Random tag used only to pair the before/after debug lines.
            r = random.randint(0, 100000)
            # To be able to perform this method synchronously (waiting for the reply)
            # without blocking async methods, we need to use queues.

            # Enqueue the message to be sent to the target.
            logger.debug(
                f"> Before send_sync_message producer_pool.put blocking {r}")
            # self.producer_pool.put_nowait(msg)
            await self.producer_pool.put(msg)
            logger.debug(
                f"> After send_sync_message producer_pool.put blocking {r}")

            # Wait for the response checking the consumer queue.
            logger.debug(
                f"> Before send_sync_message consumer_pool.get blocking {r} {msg}"
            )
            logger.debug(
                f"> Before send_sync_message consumer_pool.get blocking {r} {msg.message}"
            )
            # before = time.time()
            # timeout_secs = 15

            response = await self.consumer_pool.get()

            #  asyncio.run()
            # self.async_check(before=before, timeout_secs=timeout_secs, r=r)

            logger.debug(
                f"> After send_sync_message consumer_pool.get blocking {r}")
            return response
        except Exception as e:
            # Name this method in the log (it previously reported
            # send_eventual_msg_without_reply by mistake).
            log = f"Got an exception in WebRTCConnection send_sync_message. {e}"
            logger.error(log)
            raise

    async def async_check(self, before: float, timeout_secs: int,
                          r: float) -> SignedImmediateSyftMessageWithoutReply:
        """Poll `consumer_pool` for a reply until one arrives or a timeout elapses.

        :param before: `time.time()` timestamp the timeout is measured from.
        :param timeout_secs: seconds after `before` at which polling gives up.
        :param r: tag included in the log lines.
        :return: the first message obtained from `consumer_pool`.
        :raises Exception: when no message arrived within `timeout_secs`.
        """
        while True:
            await asyncio.sleep(message_cooldown)
            try:
                response = self.consumer_pool.get_nowait()
                return response
            except Exception as e:
                now = time.time()
                logger.debug(
                    f"> During send_sync_message consumer_pool.get blocking {r}. {e}"
                )
                if now - before > timeout_secs:
                    log = f"send_sync_message timeout {timeout_secs} {r}"
                    logger.critical(log)
                    raise Exception(log)
示例#14
0
    def test_connect_audio_and_video_and_data_channel(self):
        """Negotiate audio, video and a data channel between two local peers.

        Walks through offer/answer step by step, asserting the ICE and
        description state of both peers along the way, then closes both and
        checks the complete recorded state history.
        """
        media_sections = ('m=audio ', 'm=video ', 'm=application ')

        pc1 = RTCPeerConnection()
        pc1_states = track_states(pc1)

        pc2 = RTCPeerConnection()
        pc2_states = track_states(pc2)

        # both peers start out untouched
        for peer in (pc1, pc2):
            self.assertEqual(peer.iceConnectionState, 'new')
            self.assertEqual(peer.iceGatheringState, 'new')
            self.assertIsNone(peer.localDescription)
            self.assertIsNone(peer.remoteDescription)

        # create offer
        pc1.addTrack(AudioStreamTrack())
        pc1.addTrack(VideoStreamTrack())
        pc1.createDataChannel('chat', protocol='bob')
        offer = run(pc1.createOffer())
        self.assertEqual(offer.type, 'offer')
        for section in media_sections:
            self.assertIn(section, offer.sdp)

        run(pc1.setLocalDescription(offer))
        self.assertEqual(pc1.iceConnectionState, 'new')
        self.assertEqual(pc1.iceGatheringState, 'complete')

        # handle offer
        run(pc2.setRemoteDescription(pc1.localDescription))
        self.assertEqual(pc2.remoteDescription, pc1.localDescription)
        self.assertEqual(len(pc2.getSenders()), 2)
        self.assertEqual(len(pc2.getReceivers()), 2)

        # create answer
        pc2.addTrack(AudioStreamTrack())
        pc2.addTrack(VideoStreamTrack())
        answer = run(pc2.createAnswer())
        self.assertEqual(answer.type, 'answer')
        for section in media_sections:
            self.assertIn(section, answer.sdp)

        run(pc2.setLocalDescription(answer))
        self.assertEqual(pc2.iceConnectionState, 'checking')
        self.assertEqual(pc2.iceGatheringState, 'complete')
        for section in media_sections:
            self.assertIn(section, pc2.localDescription.sdp)

        # handle answer
        run(pc1.setRemoteDescription(pc2.localDescription))
        self.assertEqual(pc1.remoteDescription, pc2.localDescription)
        self.assertEqual(pc1.iceConnectionState, 'checking')

        # check outcome (give ICE a moment to finish)
        run(asyncio.sleep(1))
        for peer in (pc1, pc2):
            self.assertEqual(peer.iceConnectionState, 'completed')

        # check a single transport is used
        for peer in (pc1, pc2):
            self.assertBundled(peer)

        # close
        run(pc1.close())
        run(pc2.close())
        for peer in (pc1, pc2):
            self.assertEqual(peer.iceConnectionState, 'closed')

        # check state changes
        ice_history = ['new', 'checking', 'completed', 'closed']

        self.assertEqual(pc1_states['iceConnectionState'], ice_history)
        self.assertEqual(
            pc1_states['iceGatheringState'],
            ['new', 'gathering'] * 3 + ['complete'])
        self.assertEqual(pc1_states['signalingState'],
                         ['stable', 'have-local-offer', 'stable', 'closed'])

        self.assertEqual(pc2_states['iceConnectionState'], ice_history)
        self.assertEqual(pc2_states['iceGatheringState'],
                         ['new', 'gathering', 'complete'])
        self.assertEqual(pc2_states['signalingState'],
                         ['stable', 'have-remote-offer', 'stable', 'closed'])
示例#15
0
class WebRTCVPN:
    """Bridges a TUN interface and a WebRTC data channel.

    Holds one peer connection (`pc`) and, once negotiated, one data channel
    (`channel`). `create_tuntap` wires a TUN device to a channel in both
    directions.
    """

    def __init__(self):
        self.pc = RTCPeerConnection()
        self.channel = None

    async def create_offer(self):
        """Create the data channel and return the serialized local offer."""
        channel = self.pc.createDataChannel("chat")
        self.channel = channel

        await self.pc.setLocalDescription(await self.pc.createOffer())

        return object_to_string(self.pc.localDescription)

    async def create_answer(self, offer):
        """Apply a serialized remote offer and return the serialized answer.

        Also installs the handler that stores the data channel opened by the
        remote peer in `self.channel`.
        """
        offer = object_from_string(offer)

        await self.pc.setRemoteDescription(offer)
        await self.pc.setLocalDescription(await self.pc.createAnswer())

        @self.pc.on("datachannel")
        def on_datachannel(channel):
            self.channel = channel

        return object_to_string(self.pc.localDescription)

    async def set_answer(self, answer):
        """Apply the serialized answer received from the remote peer."""
        answer = object_from_string(answer)
        await self.pc.setRemoteDescription(answer)

    def get_channel(self):
        """Return the current data channel, or None if there is none yet."""
        return self.channel

    def get_pc(self):
        """Return the underlying peer connection."""
        return self.pc

    def create_tuntap(self, name, address, mtu, channel):
        """Open TUN device `name` and relay traffic between it and `channel`.

        Args:
            name: Interface name for the TUN device.
            address: Address assigned to the interface (with a /24 mask).
            mtu: MTU configured on the interface.
            channel: Data channel whose messages are written to the device
                and which receives the data read from it.
        """
        self.tap = tuntap.Tun(name=name)
        self.tap.open()

        #channel.on("message")(self.tap.fd.write)
        @channel.on("message")
        def on_message(message):
            self.tap.fd.write(message)

        def tun_reader():
            data = self.tap.fd.read(self.tap.mtu)
            # NOTE(review): the state is read from self.channel while the
            # data is sent on the `channel` argument -- presumably the same
            # object; confirm with callers.
            channel_state = self.channel.transport.transport.state

            if data and channel_state == "connected":
                channel.send(data)

        loop = asyncio.get_event_loop()
        loop.add_reader(self.tap.fd, tun_reader)

        self.tap.up()

        ip = IPRoute()
        try:
            index = ip.link_lookup(ifname=name)[0]
            ip.addr('add', index=index, address=address, mask=24)
            ip.link("set", index=index, mtu=mtu)
        finally:
            # Release the handle once the interface is configured.
            ip.close()

    def set_route(self, dst, gateway):
        """Add a route to `dst` via `gateway`."""
        ip = IPRoute()
        try:
            ip.route('add', dst=dst, gateway=gateway)
        finally:
            ip.close()

    async def input(self):
        """Read one line from stdin without blocking the event loop."""
        loop = asyncio.get_event_loop()

        reader = asyncio.StreamReader(loop=loop)
        read_pipe = sys.stdin
        read_transport, _ = await loop.connect_read_pipe(
            lambda: asyncio.StreamReaderProtocol(reader), read_pipe
        )

        data = await reader.readline()

        return data.decode(read_pipe.encoding)

    async def hold(self):
        """Await a line from a reader that is never connected to any source.

        NOTE(review): nothing feeds this reader, so this looks like a
        deliberate "wait forever" -- confirm intent.
        """
        loop = asyncio.get_event_loop()
        reader = asyncio.StreamReader(loop=loop)
        data = await reader.readline()

        return data

    async def monitor(self):
        """Poll the channel once per second; clean up when it is closed.

        Returns:
            The string "closed" after the peer connection and the TUN device
            have been closed.
        """
        while True:
            if self.channel is not None:
                if self.channel.transport.transport.state == "closed":
                    # We are already running inside the event loop, so the
                    # close coroutine must be awaited; run_until_complete()
                    # on a running loop raises RuntimeError.
                    await self.pc.close()
                    self.tap.close()
                    break
            await asyncio.sleep(1)

        return "closed"
示例#16
0
class KivyRTCApp(App):
    """
    A simple RTC client for Kivy built on aiortc.

    ``connect_room`` starts a worker thread running ``run_server``, which
    drives an asyncio event loop hosting the Janus session and the peer
    connections; ``leave_room`` asks that loop to wind down.
    """
    _thread = None
    # Event loop owned by the worker thread while a call is running.
    _loop = None
    _update_cam = None
    _device = None
    # [camera index, requested frame width, requested frame height]
    _device_info = [0, 1920, 1080]
    is_running = False

    def __init__(self, app_name, user_data_dir, **kwargs):
        """Store the application name and data directory, set title and icon."""
        super(KivyRTCApp, self).__init__(**kwargs)

        self._app_name = app_name
        self._user_data_dir = user_data_dir
        # Per-instance copy: on_start() mutates the camera index in place and
        # must not modify the list shared through the class attribute.
        self._device_info = list(self._device_info)
        self.title = app_name
        self.icon = Config.get('kivy', 'window_icon')

    def build(self):
        """Load and return the root widget from the kv layout file."""
        root = Builder.load_file('kivyrtc/main-layout.kv')

        return root

    def on_start(self):
        """Probe camera indices 0-9 and start the preview refresh clock.

        Raises:
            RuntimeError: if none of the probed cameras delivers a frame.
        """
        # Display FPS of app
        # from .tools.show_fps import ShowFPS
        # ShowFPS(self.root)

        for i in range(10):
            try:
                self._create_camera(i)
                status, _ = self._device.read()
            except Exception as err:
                Logger.warning(f'Camera: can not probe device {i}: {err}')
                status = False

            if status:
                self._device_info[0] = i
                break
            # Do not keep handles of unusable devices open while probing.
            self._release_camera()
        else:
            raise RuntimeError("Can't start camera")

        self.root.ids.room.text = '1234567'
        self.root.ids.server.text = JANUS_URL
        self._update_cam = Clock.schedule_interval(self._update, 1.0 / 30)

    def on_stop(self):
        """Leave the room, stop the preview clock and release the camera."""
        # Keep a local reference: run_server() resets self._thread when done.
        thread = self._thread
        if thread and thread.is_alive():
            self.leave_room()
            thread.join()

        if self._update_cam is not None:
            self._update_cam.cancel()
            self._update_cam = None

        self._release_camera()

    def _create_camera(self, id):
        """Open capture device ``id`` with the requested frame size."""
        self._device = cv2.VideoCapture(id)
        self._device.set(cv2.CAP_PROP_FRAME_WIDTH, self._device_info[1])
        self._device.set(cv2.CAP_PROP_FRAME_HEIGHT, self._device_info[2])

    def _release_camera(self):
        """Release the current capture device, if there is one."""
        if self._device is not None:
            self._device.release()
            self._device = None

    def _update(self, dt):
        """Clock callback: grab a frame and show it in the preview widget."""
        if self._device is None:
            self._create_camera(self._device_info[0])

        status, img = self._device.read()
        if not status:
            # No frame was delivered; keep the previous image instead of
            # publishing an invalid one.
            return
        self.np_img = img

        img_wg = self.root.ids.user_camera
        img_wg.refresh_widget(img, 'bgr')

    def connect_room(self, room):
        """Start the worker thread for ``room`` unless one is already alive."""
        if self._thread and self._thread.is_alive():
            return

        self.is_running = True
        self._thread = threading.Thread(target=self.run_server, args=[room])
        self._thread.start()

    def leave_room(self):
        """Ask the worker loop to leave the call.

        Clears ``is_running`` and enqueues a ``None`` sentinel on the session
        queue so the loop in ``__run`` wakes up and exits.
        """
        if not self._thread:
            return

        self.is_running = False
        loop = self._loop
        session = getattr(self, 'session', None)
        if loop is None or session is None:
            # The worker has not set up its loop/session (or is already done).
            return
        try:
            # The queue belongs to the worker's event loop and asyncio queues
            # are not thread-safe, so hand the put over to that loop.
            loop.call_soon_threadsafe(session._queue.put_nowait, None)
        except RuntimeError:
            # The loop was closed in the meantime: nothing left to wake up.
            pass

    def run_server(self, room):
        """Thread target: run one call in ``room`` and clean up afterwards."""
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            loop = None
        if loop is None or loop.is_closed():
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
        self._loop = loop

        url = self.root.ids.server.text
        # create signaling and peer connection
        self.session = JanusSession(url)
        self.pc = RTCPeerConnection()
        self.pcs = {}

        # create media source
        video = VideoImageTrack(self)

        audio = AudioTrack(loop)

        # create media sink
        self.recorder = MediaStreamer(self.root.ids.network_camera)

        def out_cb(outdata: np.ndarray, frame_count, time_info, status):
            """Audio output callback: combine queued chunks into ``outdata``."""
            q_out = None
            for context in self.recorder.get_context():
                try:
                    if q_out is None:
                        q_out = context.queue.get_nowait()
                    else:
                        q_out |= context.queue.get_nowait()
                except Exception:
                    # Nothing usable queued for this context right now.
                    pass

            if q_out is None:
                # NOTE(review): assuming OutputStream is sounddevice's, the
                # buffer arrives uninitialized, so emit silence explicitly.
                outdata.fill(0)
                return

            if q_out.shape == outdata.shape:
                outdata[:] = q_out
            elif q_out.size == outdata.size:
                outdata[:] = q_out.reshape(outdata.shape)
            else:
                outdata[:] = np.zeros(outdata.shape, dtype=outdata.dtype)
                Logger.warning(
                    f'Audio: wrong size, got {q_out.shape}, should {outdata.shape}'
                )

        # run event loop
        try:
            out_stream = OutputStream(
                blocksize=1920,
                callback=out_cb,
                dtype='int16',
                channels=1,
            )

            with out_stream, audio:
                loop.run_until_complete(
                    self.__run(
                        room=int(room),
                        video=video,
                        audio=audio,
                    ))
        finally:
            # cleanup: every step is attempted even when an earlier one fails
            cleanups = [
                self.recorder.stop,
                self.pc.close,
                self.session.destroy,
            ]
            cleanups.extend(pc.close for pc in self.pcs.values())
            for cleanup in cleanups:
                try:
                    loop.run_until_complete(cleanup())
                except Exception:
                    Logger.exception('Loop: cleanup step failed')
            self.pcs = {}

            # asyncio.Task.all_tasks() was removed in Python 3.9
            for task in asyncio.all_tasks(loop):
                task.cancel()
            loop.close()
            self._loop = None

            Logger.info('Loop: closed all')

            self._thread = None

    async def __run(self, room, video, audio):
        """Join ``room`` as a publisher, then process session events.

        Returns when joining fails or when a falsy item is received from the
        session queue while ``is_running`` is False.
        """
        await self.session.create()

        # configure media
        media = {
            # "audio": True,
            "video": True,
            "videocodec": "vp8",
            'audiocodec': 'opus',
        }

        self.pc.addTrack(video)
        self.pc.addTrack(audio)

        try:
            self.plugin = await self.session.attach("janus.plugin.videoroom")

            await self.create_roon(room)

            # join video room
            response = await self.plugin.send({
                "body": {
                    "display": "aiortc",
                    "ptype": "publisher",
                    "request": "join",
                    "room": room,
                }
            })
            publishers = response['plugindata']['data']['publishers']

            # send offer
            await self.pc.setLocalDescription(await self.pc.createOffer())
            request = {"request": "publish"}
            request.update(media)
            response = await self.plugin.send({
                "body": request,
                "jsep": {
                    "sdp": self.pc.localDescription.sdp,
                    "trickle": False,
                    "type": self.pc.localDescription.type,
                },
            })

            # apply answer
            answer = RTCSessionDescription(sdp=response["jsep"]["sdp"],
                                           type=response["jsep"]["type"])
            await self.pc.setRemoteDescription(answer)

            # subscribe to everyone who was already publishing
            for publisher in publishers:
                await self.subscribe(publisher["id"], room)
        except Exception:
            Logger.exception('Join room: fail')
            return

        Logger.info('AppRTC: Start call')
        while True:
            res = await self.session._queue.get()

            if not res:
                # A falsy item is the wake-up sentinel sent by leave_room().
                if not self.is_running:
                    break
                continue

            if res.get('plugindata'):
                await self.new_connect(res, room)

            elif res['janus'] == 'hangup':
                sender = res['sender']
                # A hangup may refer to a handle we hold no subscriber
                # connection for; that must not abort the call.
                pc = self.pcs.pop(sender, None)
                if pc is not None:
                    await pc.close()
                    await self.recorder.remove_track(sender)
                Logger.info(f'AppRTC: {res["sender"]} leave room')

            elif res['janus'] == 'slowlink':
                Logger.info(f"AppRTC: slowlink {res['uplink']} "
                            f"{res['nacks']}")
                tar_plugin = self.session._plugins.get(res["sender"])
                if tar_plugin is None:
                    Logger.warning(
                        f'AppRTC: slowlink from unknown handle {res["sender"]}')
                else:
                    bitrate = 64000 if res['uplink'] else 128000
                    await tar_plugin.send({
                        "body": {
                            "request": "configure",
                            "bitrate": bitrate
                        },
                    })

            elif res['janus'] == 'webrtcup':
                if res['sender'] == self.plugin.plugin_id:
                    Logger.info('AppRTC: You are streaming')
                elif res['sender'] in self.pcs:
                    Logger.info(
                        f'AppRTC: You are receiving from {res["sender"]}')
                else:
                    Logger.info(f'AppRTC: Receiving from {res["sender"]} '
                                f'list {list(self.pcs.keys())}')
            elif res['janus'] == 'media':
                if res['sender'] == self.plugin.plugin_id:
                    if res['receiving']:
                        Logger.info(f'AppRTC: {res["type"]} is OK')
                    else:
                        Logger.warning(f'AppRTC: {res["type"]} fail')
            elif res['janus'] == 'keepalive':
                pass
            else:
                # Unknown event: dump it for inspection.
                print('-' * 40)
                for key, value in res.items():
                    print(key, value)
                print('-' * 40)

        Logger.info('AppRTC: Leave call')

    async def subscribe(self, sub_id, room):
        """Subscribe to feed ``sub_id`` in ``room`` on a new peer connection."""
        plugin = await self.session.attach("janus.plugin.videoroom")
        pc = RTCPeerConnection()
        self.pcs[plugin.plugin_id] = pc

        @pc.on("track")
        async def on_track(track):
            await self.recorder.addTrack(track, plugin.plugin_id)

        request = {
            "request": "join",
            "ptype": "subscriber",
            "room": room,
            "feed": sub_id,
            # "private_id" : ''
        }
        response = await plugin.send({
            "body": request,
        })
        if response['plugindata']['data'].get("error"):
            # The join was rejected: do not keep the unused connection around.
            self.pcs.pop(plugin.plugin_id, None)
            await pc.close()
            return
        await pc.setRemoteDescription(
            RTCSessionDescription(sdp=response["jsep"]["sdp"],
                                  type=response["jsep"]["type"]))

        answer = await pc.createAnswer()
        await pc.setLocalDescription(answer)

        await plugin.send({
            "body": {
                "request": "start"
            },
            "jsep": {
                "sdp": pc.localDescription.sdp,
                "trickle": False,
                "type": pc.localDescription.type,
            }
        })
        self.recorder.start()

    async def new_connect(self, data, room):
        """Subscribe to every publisher announced in a videoroom event."""
        plugindata = data['plugindata']
        if plugindata['plugin'] != 'janus.plugin.videoroom':
            return
        publishers = plugindata['data'].get('publishers')
        if not publishers:
            return
        for publisher in publishers:
            await self.subscribe(publisher["id"], room)

    async def create_roon(self, room):
        """Create video room ``room`` unless the plugin reports it exists."""
        res = await self.plugin.send(
            {'body': {
                "request": "exists",
                "room": room
            }})
        if res['plugindata']['data']["exists"]:
            return
        params = {
            "request": "create",
            "room": room,
            #<unique numeric ID, optional, chosen by plugin if missing>
            "permanent": False,
            #<true|false, whether the room should be saved in the config file, default=false>
            "description": "",
            # This is my awesome room
            'is_private': False,
            # true|false (private rooms don't appear when you do a 'list' request)
            # 'secret' : '',
            # <optional password needed for manipulating (e.g. destroying) the room>
            # 'pin' : '',
            # <optional password needed for joining the room>
            # 'require_pvtid' : True,
            # true|false (whether subscriptions are required to provide
            #  a valid private_id to associate with a publisher, default=false)
            'publishers': 5,
            # <max number of concurrent senders> (e.g., 6 for a video
            #  conference or 1 for a webinar, default=3)
            # 'bitrate' : '',
            # <max video bitrate for senders> (e.g., 128000)
            # 'fir_freq' : '',
            # <send a FIR to publishers every fir_freq seconds> (0=disable)
            'audiocodec': 'opus',
            # opus|g722|pcmu|pcma|isac32|isac16 (audio codec to force on publishers, default=opus
            # can be a comma separated list in order of preference, e.g., opus,pcmu)
            'videocodec': 'vp8',
            # vp8|vp9|h264 (video codec to force on publishers, default=vp8
            # can be a comma separated list in order of preference, e.g., vp9,vp8,h264)
            # 'opus_fec' : True,
            # true|false (whether inband FEC must be negotiated; only works for Opus, default=false)
            # 'video_svc' : True,
            # true|false (whether SVC support must be enabled; only works for VP9, default=false)
            # 'audiolevel_ext' : False,
            # true|false (whether the ssrc-audio-level RTP extension must be
            # negotiated/used or not for new publishers, default=true)
            # 'audiolevel_event' : True,
            # true|false (whether to emit event to other users or not)
            # 'audio_active_packets' : '' ,
            # 100 (number of packets with audio level, default=100, 2 seconds)
            # 'audio_level_average' : '' ,
            # 25 (average value of audio level, 127=muted, 0='too loud', default=25)
            # 'videoorient_ext' : False,
            # true|false (whether the video-orientation RTP extension must be
            # negotiated/used or not for new publishers, default=true)
            # 'playoutdelay_ext' : False,
            # true|false (whether the playout-delay RTP extension must be
            # negotiated/used or not for new publishers, default=true)
            # 'transport_wide_cc_ext' : True,
            # true|false (whether the transport wide CC RTP extension must be
            # negotiated/used or not for new publishers, default=false)
            # 'record' : True,
            # true|false (whether this room should be recorded, default=false)
            # 'rec_dir' : '' ,
            # <folder where recordings should be stored, when enabled>
            # 'notify_joining' : True,
            # true|false (optional, whether to notify all participants when a new
            # participant joins the room. The Videoroom plugin by design only notifies
            # new feeds (publishers), and enabling this may result extra notification
            # traffic. This flag is particularly useful when enabled with \c require_pvtid
            # for admin to manage listening only participants. default=false)
        }

        await self.plugin.send({
            "body": params,
        })
示例#17
0
class WebRtcWorker(Generic[VideoProcessorT, AudioProcessorT]):
    """Owns one ``RTCPeerConnection`` and the objects attached to it.

    ``process_offer`` starts a thread that schedules the offer handling on the
    loop returned by ``get_server_event_loop()`` and then waits for the answer
    (or an error) on a thread-safe queue.
    """
    _process_offer_thread: Union[threading.Thread, None]
    _answer_queue: queue.Queue
    _session_shutdown_observer: SessionShutdownObserver
    _video_processor: Optional[VideoProcessorT]
    _audio_processor: Optional[AudioProcessorT]
    _video_receiver: Optional[VideoReceiver]
    _audio_receiver: Optional[AudioReceiver]
    _input_video_track: Optional[MediaStreamTrack]
    _input_audio_track: Optional[MediaStreamTrack]
    _output_video_track: Optional[MediaStreamTrack]
    _output_audio_track: Optional[MediaStreamTrack]
    _player: Optional[MediaPlayer]

    @property
    def video_processor(self) -> Optional[VideoProcessorT]:
        """Video processor built for the current session, if any."""
        return self._video_processor

    @property
    def audio_processor(self) -> Optional[AudioProcessorT]:
        """Audio processor built for the current session, if any."""
        return self._audio_processor

    @property
    def video_receiver(self) -> Optional[VideoReceiver]:
        """Video receiver; only created in ``SENDONLY`` mode."""
        return self._video_receiver

    @property
    def audio_receiver(self) -> Optional[AudioReceiver]:
        """Audio receiver; only created in ``SENDONLY`` mode."""
        return self._audio_receiver

    @property
    def input_video_track(self) -> Optional[MediaStreamTrack]:
        """Track reported as ``"input:video"`` by the offer processing."""
        return self._input_video_track

    @property
    def input_audio_track(self) -> Optional[MediaStreamTrack]:
        """Track reported as ``"input:audio"`` by the offer processing."""
        return self._input_audio_track

    @property
    def output_video_track(self) -> Optional[MediaStreamTrack]:
        """Track reported as ``"output:video"`` by the offer processing."""
        return self._output_video_track

    @property
    def output_audio_track(self) -> Optional[MediaStreamTrack]:
        """Track reported as ``"output:audio"`` by the offer processing."""
        return self._output_audio_track

    def __init__(
        self,
        mode: WebRtcMode,
        source_video_track: Optional[MediaStreamTrack] = None,
        source_audio_track: Optional[MediaStreamTrack] = None,
        player_factory: Optional[MediaPlayerFactory] = None,
        in_recorder_factory: Optional[MediaRecorderFactory] = None,
        out_recorder_factory: Optional[MediaRecorderFactory] = None,
        video_processor_factory: Optional[
            VideoProcessorFactory[VideoProcessorT]] = None,
        audio_processor_factory: Optional[
            AudioProcessorFactory[AudioProcessorT]] = None,
        async_processing: bool = True,
        video_receiver_size: int = 4,
        audio_receiver_size: int = 4,
        sendback_video: bool = True,
        sendback_audio: bool = True,
    ) -> None:
        """Store the configuration and create the peer connection.

        No media object is built here; the factories are only invoked once an
        offer is processed.
        """
        self._process_offer_thread = None
        self.pc = RTCPeerConnection()
        self._answer_queue = queue.Queue()

        self.mode = mode
        self.source_video_track = source_video_track
        self.source_audio_track = source_audio_track
        self.player_factory = player_factory
        self.in_recorder_factory = in_recorder_factory
        self.out_recorder_factory = out_recorder_factory
        self.video_processor_factory = video_processor_factory
        self.audio_processor_factory = audio_processor_factory
        self.async_processing = async_processing
        self.video_receiver_size = video_receiver_size
        self.audio_receiver_size = audio_receiver_size
        self.sendback_video = sendback_video
        self.sendback_audio = sendback_audio

        self._video_processor = None
        self._audio_processor = None
        self._video_receiver = None
        self._audio_receiver = None
        self._input_video_track = None
        self._input_audio_track = None
        self._output_video_track = None
        self._output_audio_track = None
        self._player = None

        self._session_shutdown_observer = SessionShutdownObserver(self.stop)

    def _run_process_offer_thread(
        self,
        sdp: str,
        type_: str,
    ):
        """Thread target: run the offer setup and forward any error."""
        try:
            self._process_offer_thread_impl(
                sdp=sdp,
                type_=type_,
            )
        except Exception as e:
            logger.warning(
                "An error occurred in the WebRTC worker thread: %s", e)
            self._answer_queue.put(
                e)  # Send the error object to the main thread

    def _process_offer_thread_impl(
        self,
        sdp: str,
        type_: str,
    ):
        """Build the media objects and schedule the offer processing.

        The resulting local description, or the exception raised while
        computing it, is put on ``self._answer_queue``.
        """
        logger.debug("_process_offer_thread_impl starts", )

        loop = get_server_event_loop()
        asyncio.set_event_loop(loop)

        offer = RTCSessionDescription(sdp, type_)

        def on_track_created(track_type: TrackType, track: MediaStreamTrack):
            if track_type == "input:video":
                self._input_video_track = track
            elif track_type == "input:audio":
                self._input_audio_track = track
            elif track_type == "output:video":
                self._output_video_track = track
            elif track_type == "output:audio":
                self._output_audio_track = track

        video_processor = None
        if self.video_processor_factory:
            video_processor = self.video_processor_factory()

        audio_processor = None
        if self.audio_processor_factory:
            audio_processor = self.audio_processor_factory()

        in_recorder = None
        if self.in_recorder_factory:
            in_recorder = self.in_recorder_factory()

        out_recorder = None
        if self.out_recorder_factory:
            out_recorder = self.out_recorder_factory()

        video_receiver = None
        audio_receiver = None
        if self.mode == WebRtcMode.SENDONLY:
            video_receiver = VideoReceiver(
                queue_maxsize=self.video_receiver_size)
            audio_receiver = AudioReceiver(
                queue_maxsize=self.audio_receiver_size)

        self._video_processor = video_processor
        self._audio_processor = audio_processor
        self._video_receiver = video_receiver
        self._audio_receiver = audio_receiver

        relay = get_global_relay()

        source_audio_track = None
        source_video_track = None
        if self.player_factory:
            player = self.player_factory()
            self._player = player
            if player.audio:
                source_audio_track = relay.subscribe(player.audio)
            if player.video:
                source_video_track = relay.subscribe(player.video)
        else:
            if self.source_video_track:
                source_video_track = relay.subscribe(self.source_video_track)
            if self.source_audio_track:
                source_audio_track = relay.subscribe(self.source_audio_track)

        @self.pc.on("iceconnectionstatechange")
        async def on_iceconnectionstatechange():
            logger.info("ICE connection state is %s",
                        self.pc.iceConnectionState)
            iceConnectionState = self.pc.iceConnectionState
            if iceConnectionState == "closed" or iceConnectionState == "failed":
                self._unset_processors()
            if self.pc.iceConnectionState == "failed":
                await self.pc.close()

        # This method runs in the process-offer thread, not in the loop's own
        # thread, so the coroutine has to be handed over thread-safely;
        # loop.create_task() must only be called from the loop's thread.
        process_offer_future = asyncio.run_coroutine_threadsafe(
            _process_offer_coro(
                self.mode,
                self.pc,
                offer,
                relay=relay,
                source_video_track=source_video_track,
                source_audio_track=source_audio_track,
                in_recorder=in_recorder,
                out_recorder=out_recorder,
                video_processor=video_processor,
                audio_processor=audio_processor,
                video_receiver=video_receiver,
                audio_receiver=audio_receiver,
                async_processing=self.async_processing,
                sendback_video=self.sendback_video,
                sendback_audio=self.sendback_audio,
                on_track_created=on_track_created,
            ),
            loop,
        )

        def callback(done_future):
            # Report a cancellation instead of letting process_offer() run
            # into its timeout.
            if done_future.cancelled():
                self._answer_queue.put(
                    RuntimeError("process_offer was cancelled"))
                return

            e = done_future.exception()
            if e:
                logger.debug("Error occurred in process_offer")
                logger.debug(e)
                self._answer_queue.put(e)
                return

            localDescription = done_future.result()
            self._answer_queue.put(localDescription)

        process_offer_future.add_done_callback(callback)

    def process_offer(
            self,
            sdp,
            type_,
            timeout: Union[float, None] = 10.0) -> RTCSessionDescription:
        """Process a remote offer and return the local description.

        Args:
            sdp: SDP text of the offer.
            type_: type of the offer's session description.
            timeout: seconds to wait for the answer; ``None`` waits forever.

        Raises:
            TimeoutError: if no result arrives within ``timeout`` seconds.
            Exception: any error raised while processing the offer.
        """
        self._process_offer_thread = threading.Thread(
            target=self._run_process_offer_thread,
            kwargs={
                "sdp": sdp,
                "type_": type_,
            },
            daemon=True,
            name=f"process_offer_{next(process_offer_thread_id_generator)}",
        )
        self._process_offer_thread.start()

        try:
            result = self._answer_queue.get(block=True, timeout=timeout)
        except queue.Empty:
            self.stop(timeout=1)
            raise TimeoutError("Processing offer and initializing the worker "
                               f"has not finished in {timeout} seconds")

        if isinstance(result, Exception):
            raise result

        return result

    def _unset_processors(self):
        """Drop the processors and stop the receivers and the player."""
        self._video_processor = None
        self._audio_processor = None

        if self._video_receiver:
            self._video_receiver.stop()
        self._video_receiver = None

        if self._audio_receiver:
            self._audio_receiver.stop()
        self._audio_receiver = None

        # The player tracks are not automatically stopped when the WebRTC session ends
        # because these tracks are connected to the consumer via `MediaRelay` proxies
        # so `stop()` on the consumer is not delegated to the source tracks.
        # So the player is stopped manually here when the worker stops.
        if self._player:
            if self._player.video:
                self._player.video.stop()
            if self._player.audio:
                self._player.audio.stop()
        self._player = None

    def stop(self, timeout: Union[float, None] = 1.0):
        """Release the media objects and close the peer connection.

        Args:
            timeout: seconds to wait for the process-offer thread to finish.
        """
        self._unset_processors()
        if self._process_offer_thread:
            self._process_offer_thread.join(timeout=timeout)
            self._process_offer_thread = None

        if self.pc and self.pc.connectionState != "closed":
            loop = get_server_event_loop()
            if loop.is_running():
                # stop() may be called from a thread other than the loop's,
                # so schedule the close thread-safely (fire and forget).
                asyncio.run_coroutine_threadsafe(self.pc.close(), loop)
            else:
                loop.run_until_complete(self.pc.close())

        self._session_shutdown_observer.stop()
示例#18
0
class RTCConnection(SubscriptionProducerConsumer):
    """Wraps an ``RTCPeerConnection`` with data-channel and media helpers.

    Messages put into this object go out on the ``"default"`` data channel
    and messages arriving on that channel come out of it; further channels
    and the audio/video handlers are reachable through the methods below.
    """
    _log = logging.getLogger("rtcbot.RTCConnection")

    def __init__(self, defaultChannelOrdered=True, loop=None):
        """
        Args:
            defaultChannelOrdered: ``ordered`` flag used when the default
                data channel is created by this side.
            loop: event loop to use; defaults to ``asyncio.get_event_loop()``.
        """
        super().__init__(
            directPutSubscriptionType=asyncio.Queue,
            defaultSubscriptionType=asyncio.Queue,
            logger=self._log,
        )
        self._loop = loop
        if self._loop is None:
            self._loop = asyncio.get_event_loop()

        self._dataChannels = {}

        # These allow us to easily signal when the given events happen
        self._dataChannelSubscriber = SubscriptionProducer(
            logger=self._log.getChild("dataChannelSubscriber")
        )

        self._rtc = RTCPeerConnection()
        self._rtc.on("datachannel", self._onDatachannel)
        # self._rtc.on("iceconnectionstatechange", self._onIceConnectionStateChange)
        self._rtc.on("track", self._onTrack)

        self._hasRemoteDescription = False
        self._defaultChannelOrdered = defaultChannelOrdered

        self._videoHandler = ConnectionVideoHandler(self._rtc)
        self._audioHandler = ConnectionAudioHandler(self._rtc)

    async def getLocalDescription(self, description=None):
        """
        Gets the description to send on. Creates an initial description
        if no remote description was passed, and creates a response if
        a remote was given.

        Returns:
            A dict with the ``"sdp"`` and ``"type"`` of the local description.
        """
        if self._hasRemoteDescription or description is not None:
            # This means that we received an offer - either the remote description
            # was already set, or we passed in a description. In either case,
            # instead of initializing a new connection, we prepare a response
            if not self._hasRemoteDescription:
                await self.setRemoteDescription(description)
            self._log.debug("Creating response to connection offer")
            try:
                answer = await self._rtc.createAnswer()
            except AttributeError:
                self._log.exception(
                    "\n>>> Looks like the offer didn't include the necessary info to set up audio/video. See RTCConnection.video.offerToReceive(). <<<\n\n"
                )
                raise
            await self._rtc.setLocalDescription(answer)
            return {
                "sdp": self._rtc.localDescription.sdp,
                "type": self._rtc.localDescription.type,
            }

        # There was no remote description, which means that we are initializing the
        # connection.

        # Before starting init, we create a default data channel for the connection
        self._log.debug("Setting up default data channel")
        channel = DataChannel(
            self._rtc.createDataChannel("default", ordered=self._defaultChannelOrdered)
        )
        # Subscribe the default channel directly to our own inputs and outputs.
        # We have it listen to our own self._get, and write to our self._put_nowait
        channel.putSubscription(NoClosedSubscription(self._get))
        channel.subscribe(self._put_nowait)
        self._dataChannels[channel.name] = channel

        # Make sure we offer to receive video and audio if it isn't set up yet
        if len(self.video._senders) == 0 and self.video._offerToReceive:
            self._log.debug("Offering to receive video")
            self._rtc.addTransceiver("video", "recvonly")
        if len(self.audio._senders) == 0 and self.audio._offerToReceive:
            self._log.debug("Offering to receive audio")
            self._rtc.addTransceiver("audio", "recvonly")

        self._log.debug("Creating new connection offer")
        offer = await self._rtc.createOffer()
        await self._rtc.setLocalDescription(offer)
        return {
            "sdp": self._rtc.localDescription.sdp,
            "type": self._rtc.localDescription.type,
        }

    async def setRemoteDescription(self, description):
        """Apply a remote description given as keyword arguments in a dict."""
        self._log.debug("Setting remote connection description")
        await self._rtc.setRemoteDescription(RTCSessionDescription(**description))
        self._hasRemoteDescription = True

    def _onDatachannel(self, channel):
        """
        When a data channel comes in, adds it to the data channels, and sets up its messaging and stuff.

        """
        channel = DataChannel(channel)
        self._log.debug("Got channel: %s", channel.name)
        if channel.name == "default":
            # Subscribe the default channel directly to our own inputs and outputs.
            # We have it listen to our own self._get, and write to our self._put_nowait
            channel.putSubscription(NoClosedSubscription(self._get))
            channel.subscribe(self._put_nowait)

            # Set the default channel
            self._defaultChannel = channel

        else:
            self._dataChannelSubscriber.put_nowait(channel)
        self._dataChannels[channel.name] = channel

    def _onTrack(self, track):
        """Route an incoming track to the handler matching its kind."""
        self._log.debug("Received %s track from connection", track.kind)
        if track.kind == "audio":
            self._audioHandler._onTrack(track)
        elif track.kind == "video":
            self._videoHandler._onTrack(track)

    def onDataChannel(self, callback=None):
        """
        Acts as a subscriber...
        """
        return self._dataChannelSubscriber.subscribe(callback)

    def addDataChannel(self, name, ordered=True):
        """
        Adds a data channel to the connection. Note that the RTCConnection adds a "default" channel
        automatically, which you can subscribe to directly.

        Raises:
            KeyError: if ``name`` is ``"default"`` or is already in use.
        """
        self._log.debug("Adding data channel to connection")

        if name in self._dataChannels or name == "default":
            # Format the message here: exceptions do not apply %-arguments.
            raise KeyError("Data channel %s already exists" % name)

        dc = DataChannel(self._rtc.createDataChannel(name, ordered=ordered))
        self._dataChannels[name] = dc
        return dc

    def getDataChannel(self, name):
        """
        Returns the data channel with the given name. Please note that the "default" channel is considered special,
        and is not returned.
        """
        if name == "default":
            raise KeyError(
                "Default channel not available for 'get'. Use the RTCConnection's subscribe and put_nowait methods for access to it."
            )
        return self._dataChannels[name]

    @property
    def video(self):
        """
        Convenience function - you can subscribe to it to get video frames once they show up
        """
        return self._videoHandler

    @property
    def audio(self):
        """
        Convenience function - you can subscribe to it to get audio frames once they show up
        """
        return self._audioHandler

    def close(self):
        """
        If the loop is running, returns a future that will close the connection. Otherwise, runs
        the loop temporarily to complete closing.
        """
        super().close()
        # And closes all tracks
        self.video.close()
        self.audio.close()

        for dc in self._dataChannels.values():
            dc.close()

        self._dataChannelSubscriber.close()

        if self._loop.is_running():
            self._log.debug("Loop is running - close will return a future!")
            # Schedule on the loop this connection was created with, not on
            # whichever loop happens to be current for the calling thread.
            return asyncio.ensure_future(self._rtc.close(), loop=self._loop)
        else:
            self._loop.run_until_complete(self._rtc.close())
        return None

    def send(self, msg):
        """
        Send is an alias for put_nowait - makes it easier for people new to rtcbot to understand
        what is going on
        """
        self.put_nowait(msg)
示例#19
0
 def test_addTrack_closed(self):
     """addTrack() on a closed peer connection raises InvalidStateError."""
     connection = RTCPeerConnection()
     run(connection.close())

     with self.assertRaises(InvalidStateError) as raised:
         connection.addTrack(AudioStreamTrack())

     message = str(raised.exception)
     self.assertEqual(message, 'RTCPeerConnection is closed')
示例#20
0
class WebRtcWorker:
    """
    Drives a single RTCPeerConnection on a dedicated thread that owns its own
    asyncio event loop.

    ``process_offer`` starts the thread and waits for the result posted on the
    answer queue; ``stop`` stops the loop and joins the thread.
    """

    _thread: Union[threading.Thread, None]
    _loop: Union[AbstractEventLoop, None]
    _answer_queue: queue.Queue
    _video_transformer: Optional[VideoTransformerBase]
    _video_receiver: Optional[VideoReceiver]

    @property
    def video_transformer(self) -> Optional[VideoTransformerBase]:
        """The transformer selected by the worker thread, or None if none is set."""
        return self._video_transformer

    @property
    def video_receiver(self) -> Optional[VideoReceiver]:
        """The receiver created by ``process_offer`` in SENDONLY mode, else None."""
        return self._video_receiver

    def __init__(
        self,
        mode: WebRtcMode,
        player_factory: Optional[MediaPlayerFactory] = None,
        video_transformer_factory: Optional[VideoTransformerFactory] = None,
        async_transform: bool = True,
    ) -> None:
        """
        Store the configuration and create the peer connection.

        No thread or event loop is created here; that happens in
        ``process_offer``.
        """
        self._thread = None
        self._loop = None
        self.pc = RTCPeerConnection()
        self._answer_queue = queue.Queue()
        self._stop_requested = False

        self.mode = mode
        self.player_factory = player_factory
        self.video_transformer_factory = video_transformer_factory
        self.async_transform = async_transform

        self._video_transformer = None
        self._video_receiver = None

    def _run_webrtc_thread(
        self,
        sdp: str,
        type_: str,
        player_factory: Optional[MediaPlayerFactory],
        video_transformer_factory: Optional[VideoTransformerFactory],
        video_receiver: Optional[VideoReceiver],
        async_transform: bool,
    ):
        """
        Thread entry point: run ``_webrtc_thread`` and log the full traceback
        of any exception line by line before re-raising it.
        """
        try:
            self._webrtc_thread(
                sdp=sdp,
                type_=type_,
                player_factory=player_factory,
                video_transformer_factory=video_transformer_factory,
                video_receiver=video_receiver,
                async_transform=async_transform,
            )
        except Exception:
            logger.error("Error occurred in the WebRTC thread:")

            exc_type, exc_value, exc_traceback = sys.exc_info()
            for tb in traceback.format_exception(exc_type, exc_value,
                                                 exc_traceback):
                for tbline in tb.rstrip().splitlines():
                    logger.error(tbline.rstrip())

            # TODO shutdown this thread!
            raise

    def _webrtc_thread(
        self,
        sdp: str,
        type_: str,
        player_factory: Optional[MediaPlayerFactory],
        video_transformer_factory: Optional[Callable[[],
                                                     VideoTransformerBase]],
        video_receiver: Optional[VideoReceiver],
        async_transform: bool,
    ):
        """
        Create the event loop, schedule offer processing on it and run the
        loop until it is stopped; afterwards close the peer connection and
        the loop.

        The local description produced by ``_process_offer`` is delivered to
        ``process_offer`` through the answer queue via ``callback``.
        """
        logger.debug(
            "_webrtc_thread(player_factory=%s, video_transformer_factory=%s)",
            player_factory,
            video_transformer_factory,
        )

        loop = asyncio.new_event_loop()
        self._loop = loop

        offer = RTCSessionDescription(sdp, type_)

        def callback(localDescription):
            self._answer_queue.put(localDescription)

        video_transformer = None
        if video_transformer_factory:
            video_transformer = video_transformer_factory()

        if self.mode == WebRtcMode.SENDRECV:
            if video_transformer is None:
                logger.info("mode is set as sendrecv, "
                            "but video_transformer_factory is not specified. "
                            "A simple loopback transformer is used.")
                video_transformer = NoOpVideoTransformer()

        self._video_transformer = video_transformer

        loop.create_task(
            _process_offer(
                self.mode,
                self.pc,
                offer,
                player_factory,
                video_transformer=video_transformer,
                video_receiver=video_receiver,
                async_transform=async_transform,
                callback=callback,
            ))

        try:
            loop.run_forever()
        finally:
            # Reached once stop() has stopped the loop, or run_forever raised.
            logger.debug("Event loop %s has stopped.", loop)
            loop.run_until_complete(self.pc.close())
            loop.run_until_complete(loop.shutdown_asyncgens())
            loop.close()
            logger.debug("Event loop %s cleaned up.", loop)

    def process_offer(self, sdp, type_, timeout=10.0) -> RTCSessionDescription:
        """
        Start the worker thread for the given offer and wait for its result.

        :param sdp: SDP text of the remote offer.
        :param type_: type of the remote session description.
        :param timeout: seconds to wait for the result; None waits forever.
        :return: the object posted on the answer queue by the worker.
        :raises queue.Empty: if nothing is posted within ``timeout`` seconds.
        :raises Exception: the posted object itself, if it is an exception.
        """
        if self.mode == WebRtcMode.SENDONLY:
            self._video_receiver = VideoReceiver(queue_maxsize=1)

        self._thread = threading.Thread(
            target=self._run_webrtc_thread,
            kwargs={
                "sdp": sdp,
                "type_": type_,
                "player_factory": self.player_factory,
                "video_transformer_factory": self.video_transformer_factory,
                "video_receiver": self._video_receiver,
                "async_transform": self.async_transform,
            },
            daemon=True,
        )
        self._thread.start()

        # The timeout has to be passed by keyword: the first positional
        # parameter of Queue.get is ``block``, so a positional timeout would be
        # silently ignored and the call would wait forever.
        result = self._answer_queue.get(block=True, timeout=timeout)
        if isinstance(result, Exception):
            raise result

        return result

    def stop(self):
        """
        Stop the worker's event loop and wait for the worker thread to end.

        The loop runs on the worker thread, so the stop request is handed over
        with ``call_soon_threadsafe``; that also wakes a loop that is idle.
        """
        loop = self._loop
        if loop is not None and not loop.is_closed():
            try:
                loop.call_soon_threadsafe(loop.stop)
            except RuntimeError:
                # The worker closed the loop between the check and the call;
                # there is nothing left to stop.
                pass
        if self._thread:
            self._thread.join()
示例#21
0
    def test_connect_video_bidirectional(self):
        """
        Two peer connections each add a video track, complete an offer/answer
        exchange, reach the 'completed' ICE state and are closed; the recorded
        state transitions of both sides are checked at the end.
        """
        pc1 = RTCPeerConnection()
        pc1_states = track_states(pc1)

        pc2 = RTCPeerConnection()
        pc2_states = track_states(pc2)

        self.assertEqual(pc1.iceConnectionState, 'new')
        self.assertEqual(pc1.iceGatheringState, 'new')
        self.assertIsNone(pc1.localDescription)
        self.assertIsNone(pc1.remoteDescription)

        self.assertEqual(pc2.iceConnectionState, 'new')
        self.assertEqual(pc2.iceGatheringState, 'new')
        self.assertIsNone(pc2.localDescription)
        self.assertIsNone(pc2.remoteDescription)

        # create offer
        pc1.addTrack(VideoStreamTrack())
        offer = run(pc1.createOffer())
        self.assertEqual(offer.type, 'offer')
        self.assertIn('m=video ', offer.sdp)
        self.assertNotIn('a=candidate:', offer.sdp)

        run(pc1.setLocalDescription(offer))
        self.assertEqual(pc1.iceConnectionState, 'new')
        self.assertEqual(pc1.iceGatheringState, 'complete')
        self.assertIn('m=video ', pc1.localDescription.sdp)
        self.assertIn('a=candidate:', pc1.localDescription.sdp)
        self.assertIn('a=sendrecv', pc1.localDescription.sdp)
        self.assertIn('a=fingerprint:sha-256', pc1.localDescription.sdp)
        self.assertIn('a=setup:actpass', pc1.localDescription.sdp)

        # handle offer
        run(pc2.setRemoteDescription(pc1.localDescription))
        self.assertEqual(pc2.remoteDescription, pc1.localDescription)
        self.assertEqual(len(pc2.getReceivers()), 1)

        # create answer
        pc2.addTrack(VideoStreamTrack())
        answer = run(pc2.createAnswer())
        self.assertEqual(answer.type, 'answer')
        self.assertIn('m=video ', answer.sdp)
        self.assertNotIn('a=candidate:', answer.sdp)

        run(pc2.setLocalDescription(answer))
        self.assertEqual(pc2.iceConnectionState, 'checking')
        self.assertEqual(pc2.iceGatheringState, 'complete')
        self.assertIn('m=video ', pc2.localDescription.sdp)
        self.assertIn('a=candidate:', pc2.localDescription.sdp)
        # The answer's direction must be checked on pc2; checking pc1 here
        # would only repeat the offer assertion above.
        self.assertIn('a=sendrecv', pc2.localDescription.sdp)
        self.assertIn('a=fingerprint:sha-256', pc2.localDescription.sdp)
        self.assertIn('a=setup:active', pc2.localDescription.sdp)

        # handle answer
        run(pc1.setRemoteDescription(pc2.localDescription))
        self.assertEqual(pc1.remoteDescription, pc2.localDescription)
        self.assertEqual(pc1.iceConnectionState, 'checking')

        # check outcome
        run(asyncio.sleep(1))
        self.assertEqual(pc1.iceConnectionState, 'completed')
        self.assertEqual(pc2.iceConnectionState, 'completed')

        # close
        run(pc1.close())
        run(pc2.close())
        self.assertEqual(pc1.iceConnectionState, 'closed')
        self.assertEqual(pc2.iceConnectionState, 'closed')

        # check state changes
        self.assertEqual(pc1_states['iceConnectionState'],
                         ['new', 'checking', 'completed', 'closed'])
        self.assertEqual(pc1_states['iceGatheringState'],
                         ['new', 'gathering', 'complete'])
        self.assertEqual(pc1_states['signalingState'],
                         ['stable', 'have-local-offer', 'stable', 'closed'])

        self.assertEqual(pc2_states['iceConnectionState'],
                         ['new', 'checking', 'completed', 'closed'])
        self.assertEqual(pc2_states['iceGatheringState'],
                         ['new', 'gathering', 'complete'])
        self.assertEqual(pc2_states['signalingState'],
                         ['stable', 'have-remote-offer', 'stable', 'closed'])
示例#22
0
class WebRtcWorker(Generic[VideoTransformerT]):
    """
    Drives a single RTCPeerConnection on a dedicated thread that owns its own
    asyncio event loop.

    ``process_offer`` starts the thread and waits for the result (an answer or
    an exception) posted on the answer queue; ``stop`` stops the loop and
    joins the thread.
    """

    _webrtc_thread: Union[threading.Thread, None]
    _loop: Union[AbstractEventLoop, None]
    _answer_queue: queue.Queue
    _video_transformer: Optional[VideoTransformerT]
    _video_receiver: Optional[VideoReceiver]

    @property
    def video_transformer(self) -> Optional[VideoTransformerT]:
        """The transformer created by the worker thread, or None if unset."""
        return self._video_transformer

    @property
    def video_receiver(self) -> Optional[VideoReceiver]:
        """The receiver created by the worker thread in SENDONLY mode, else None."""
        return self._video_receiver

    def __init__(
        self,
        mode: WebRtcMode,
        player_factory: Optional[MediaPlayerFactory] = None,
        in_recorder_factory: Optional[MediaRecorderFactory] = None,
        out_recorder_factory: Optional[MediaRecorderFactory] = None,
        video_transformer_factory: Optional[
            VideoTransformerFactory[VideoTransformerT]
        ] = None,
        async_transform: bool = True,
    ) -> None:
        """
        Store the configuration and create the peer connection.

        No thread or event loop is created here; that happens in
        ``process_offer``.
        """
        self._webrtc_thread = None
        self._loop = None
        self.pc = RTCPeerConnection()
        self._answer_queue = queue.Queue()

        self.mode = mode
        self.player_factory = player_factory
        self.in_recorder_factory = in_recorder_factory
        self.out_recorder_factory = out_recorder_factory
        self.video_transformer_factory = video_transformer_factory
        self.async_transform = async_transform

        self._video_transformer = None
        self._video_receiver = None

    def _run_webrtc_thread(
        self,
        sdp: str,
        type_: str,
    ):
        """
        Thread entry point: run ``_webrtc_thread_impl`` and, if it raises,
        clean up the event loop (when it is still open) and forward the error
        to the thread blocked in ``process_offer``.
        """
        try:
            self._webrtc_thread_impl(
                sdp=sdp,
                type_=type_,
            )
        except Exception as e:
            logger.warning("An error occurred in the WebRTC worker thread: %s", e)

            try:
                loop = self._loop
                # _webrtc_thread_impl's own ``finally`` may already have closed
                # the loop; running anything on a closed loop raises
                # RuntimeError, so only clean up a loop that is still open.
                if loop is not None and not loop.is_closed():
                    logger.warning("An event loop exists. Clean up it.")
                    loop.run_until_complete(self.pc.close())
                    loop.run_until_complete(loop.shutdown_asyncgens())
                    loop.close()
                    logger.warning("Event loop %s cleaned up.", loop)
            finally:
                # Always hand the error over, even if the cleanup itself fails.
                self._answer_queue.put(e)  # Send the error object to the main thread

    def _webrtc_thread_impl(
        self,
        sdp: str,
        type_: str,
    ):
        """
        Create the event loop, schedule offer processing on it and run the
        loop until it is stopped; afterwards close the peer connection and
        the loop.

        The local description produced by ``_process_offer`` is delivered to
        ``process_offer`` through the answer queue via ``callback``.
        """
        logger.debug(
            "_webrtc_thread_impl starts",
        )

        loop = asyncio.new_event_loop()
        self._loop = loop

        offer = RTCSessionDescription(sdp, type_)

        def callback(localDescription):
            self._answer_queue.put(localDescription)

        video_transformer = None
        if self.video_transformer_factory:
            video_transformer = self.video_transformer_factory()

        video_receiver = None
        if self.mode == WebRtcMode.SENDONLY:
            video_receiver = VideoReceiver(queue_maxsize=1)

        self._video_transformer = video_transformer
        self._video_receiver = video_receiver

        @self.pc.on("iceconnectionstatechange")
        async def on_iceconnectionstatechange():
            # Drop the transformer/receiver references once the connection is gone.
            iceConnectionState = self.pc.iceConnectionState
            if iceConnectionState == "closed" or iceConnectionState == "failed":
                self._unset_transformers()

        loop.create_task(
            _process_offer(
                self.mode,
                self.pc,
                offer,
                player_factory=self.player_factory,
                in_recorder_factory=self.in_recorder_factory,
                out_recorder_factory=self.out_recorder_factory,
                video_transformer=video_transformer,
                video_receiver=video_receiver,
                async_transform=self.async_transform,
                callback=callback,
            )
        )

        try:
            loop.run_forever()
        finally:
            # Reached once stop() has stopped the loop, or run_forever raised.
            logger.debug("Event loop %s has stopped.", loop)
            loop.run_until_complete(self.pc.close())
            loop.run_until_complete(loop.shutdown_asyncgens())
            loop.close()
            logger.debug("Event loop %s cleaned up.", loop)

    def process_offer(
        self, sdp, type_, timeout: Union[float, None] = 10.0
    ) -> RTCSessionDescription:
        """
        Start the worker thread for the given offer and wait for its result.

        :param sdp: SDP text of the remote offer.
        :param type_: type of the remote session description.
        :param timeout: seconds to wait for the result; None waits forever.
        :return: the object posted on the answer queue by the worker.
        :raises TimeoutError: if nothing is posted within ``timeout`` seconds;
            the worker is stopped first.
        :raises Exception: the posted object itself, if it is an exception.
        """
        self._webrtc_thread = threading.Thread(
            target=self._run_webrtc_thread,
            kwargs={
                "sdp": sdp,
                "type_": type_,
            },
            daemon=True,
            name=f"webrtc_worker_{next(webrtc_thread_id_generator)}",
        )
        self._webrtc_thread.start()

        try:
            result = self._answer_queue.get(block=True, timeout=timeout)
        except queue.Empty:
            self.stop(timeout=1)
            raise TimeoutError(
                "Processing offer and initializing the worker "
                f"has not finished in {timeout} seconds"
            )

        if isinstance(result, Exception):
            raise result

        return result

    def _unset_transformers(self):
        """Forget the current video transformer and video receiver."""
        self._video_transformer = None
        self._video_receiver = None

    def stop(self, timeout: Union[float, None] = 1.0):
        """
        Stop the worker's event loop and wait up to ``timeout`` seconds for
        the worker thread to end.

        The loop runs on the worker thread, so the stop request is handed over
        with ``call_soon_threadsafe``; that also wakes a loop that is idle.
        """
        self._unset_transformers()
        loop = self._loop
        if loop is not None and not loop.is_closed():
            try:
                loop.call_soon_threadsafe(loop.stop)
            except RuntimeError:
                # The worker closed the loop between the check and the call;
                # there is nothing left to stop.
                pass
        if self._webrtc_thread:
            self._webrtc_thread.join(timeout=timeout)
示例#23
0
    def test_connect_datachannel(self):
        """
        pc1 opens a data channel and queues messages before negotiating; pc2
        echoes every message back. After the offer/answer exchange the test
        checks the received messages on both sides, closes everything and
        verifies the recorded state transitions.
        """
        pc1 = RTCPeerConnection()
        pc1_data_messages = []
        pc1_states = track_states(pc1)

        pc2 = RTCPeerConnection()
        pc2_data_channels = []
        pc2_data_messages = []
        pc2_states = track_states(pc2)

        @pc2.on('datachannel')
        def on_datachannel(channel):
            self.assertEqual(channel.readyState, 'open')
            pc2_data_channels.append(channel)

            @channel.on('message')
            def on_message(message):
                # record the message, then echo it back with a type-specific prefix
                pc2_data_messages.append(message)
                if isinstance(message, str):
                    channel.send('string-echo: ' + message)
                else:
                    channel.send(b'binary-echo: ' + message)

        # create data channel
        dc = pc1.createDataChannel('chat', protocol='bob')
        self.assertEqual(dc.label, 'chat')
        self.assertEqual(dc.protocol, 'bob')
        self.assertEqual(dc.readyState, 'connecting')

        # send messages
        dc.send('hello')
        dc.send('')
        dc.send(b'\x00\x01\x02\x03')
        dc.send(b'')
        dc.send(LONG_DATA)
        with self.assertRaises(ValueError) as cm:
            dc.send(1234)
        self.assertEqual(str(cm.exception),
                         "Cannot send unsupported data type: <class 'int'>")

        @dc.on('message')
        def on_message(message):
            pc1_data_messages.append(message)

        # create offer
        offer = run(pc1.createOffer())
        self.assertEqual(offer.type, 'offer')
        self.assertIn('m=application ', offer.sdp)
        self.assertNotIn('a=candidate:', offer.sdp)

        run(pc1.setLocalDescription(offer))
        self.assertEqual(pc1.iceConnectionState, 'new')
        self.assertEqual(pc1.iceGatheringState, 'complete')
        self.assertIn('m=application ', pc1.localDescription.sdp)
        self.assertIn('a=candidate:', pc1.localDescription.sdp)
        self.assertIn('a=sctpmap:5000 webrtc-datachannel 65535',
                      pc1.localDescription.sdp)
        self.assertIn('a=fingerprint:sha-256', pc1.localDescription.sdp)
        self.assertIn('a=setup:actpass', pc1.localDescription.sdp)

        # handle offer
        run(pc2.setRemoteDescription(pc1.localDescription))
        self.assertEqual(pc2.remoteDescription, pc1.localDescription)
        self.assertEqual(len(pc2.getReceivers()), 0)
        self.assertEqual(len(pc2.getSenders()), 0)

        # create answer
        answer = run(pc2.createAnswer())
        self.assertEqual(answer.type, 'answer')
        self.assertIn('m=application ', answer.sdp)
        self.assertNotIn('a=candidate:', answer.sdp)

        run(pc2.setLocalDescription(answer))
        self.assertEqual(pc2.iceConnectionState, 'checking')
        self.assertEqual(pc2.iceGatheringState, 'complete')
        self.assertIn('m=application ', pc2.localDescription.sdp)
        self.assertIn('a=candidate:', pc2.localDescription.sdp)
        self.assertIn('a=sctpmap:5000 webrtc-datachannel 65535',
                      pc2.localDescription.sdp)
        self.assertIn('a=fingerprint:sha-256', pc2.localDescription.sdp)
        self.assertIn('a=setup:active', pc2.localDescription.sdp)

        # handle answer
        run(pc1.setRemoteDescription(pc2.localDescription))
        self.assertEqual(pc1.remoteDescription, pc2.localDescription)
        self.assertEqual(pc1.iceConnectionState, 'checking')

        # check outcome
        run(asyncio.sleep(1))
        self.assertEqual(pc1.iceConnectionState, 'completed')
        self.assertEqual(pc2.iceConnectionState, 'completed')
        self.assertEqual(dc.readyState, 'open')

        # check pc2 got a datachannel
        self.assertEqual(len(pc2_data_channels), 1)
        self.assertEqual(pc2_data_channels[0].label, 'chat')
        self.assertEqual(pc2_data_channels[0].protocol, 'bob')

        # check pc2 got messages
        run(asyncio.sleep(1))
        self.assertEqual(pc2_data_messages, [
            'hello',
            '',
            b'\x00\x01\x02\x03',
            b'',
            LONG_DATA,
        ])

        # check pc1 got replies
        self.assertEqual(pc1_data_messages, [
            'string-echo: hello',
            'string-echo: ',
            b'binary-echo: \x00\x01\x02\x03',
            b'binary-echo: ',
            b'binary-echo: ' + LONG_DATA,
        ])

        # close data channel
        dc.close()
        self.assertEqual(dc.readyState, 'closed')

        # close
        run(pc1.close())
        run(pc2.close())
        self.assertEqual(pc1.iceConnectionState, 'closed')
        self.assertEqual(pc2.iceConnectionState, 'closed')

        # check state changes
        self.assertEqual(pc1_states['iceConnectionState'],
                         ['new', 'checking', 'completed', 'closed'])
        self.assertEqual(pc1_states['iceGatheringState'],
                         ['new', 'gathering', 'complete'])
        self.assertEqual(pc1_states['signalingState'],
                         ['stable', 'have-local-offer', 'stable', 'closed'])

        self.assertEqual(pc2_states['iceConnectionState'],
                         ['new', 'checking', 'completed', 'closed'])
        self.assertEqual(pc2_states['iceGatheringState'],
                         ['new', 'gathering', 'complete'])
        self.assertEqual(pc2_states['signalingState'],
                         ['stable', 'have-remote-offer', 'stable', 'closed'])
示例#24
0
if __name__ == '__main__':
    # command line: role, file to transfer, verbosity and signaling options
    parser = argparse.ArgumentParser(description='Data channel file transfer')
    parser.add_argument('role', choices=['send', 'receive'])
    parser.add_argument('filename')
    parser.add_argument('--verbose', '-v', action='count')
    add_signaling_arguments(parser)
    args = parser.parse_args()

    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)

    signaling = create_signaling(args)
    pc = RTCPeerConnection()

    # the sender reads the file and makes the offer, the receiver writes it and answers
    is_sender = args.role == 'send'
    fp = open(args.filename, 'rb' if is_sender else 'wb')
    coro = (run_offer if is_sender else run_answer)(pc, signaling, fp)

    # run event loop
    loop = asyncio.get_event_loop()
    try:
        loop.run_until_complete(coro)
    except KeyboardInterrupt:
        pass
    finally:
        # release the file first, then shut down the connection and the signaling
        fp.close()
        for pending in (pc.close(), signaling.close()):
            loop.run_until_complete(pending)
示例#25
0
class WebRTCConnection(BidirectionalConnection):
    loop: Any

    def __init__(self, node: AbstractNode) -> None:
        """
        Set up the state of a WebRTC connection.

        :param node: node instance kept on the connection so that requests sent
            by the other peer can be forwarded to it.
        """
        # WebRTC Connection representation

        # As we have a full-duplex connection,
        # it's necessary to use a node instance
        # inside of this connection, in order to
        # be able to process requests sent by
        # the other peer.
        # All the request messages will be forwarded
        # to this node.
        self.node = node

        # EventLoop that manages async tasks (producer/consumer).
        # ``loop`` is not a parameter: it is looked up in the enclosing
        # (module) scope and therefore has to be defined beforehand.

        self.loop = loop
        # Message pools (High Priority)
        # These queues are used to manage
        # async messages.
        try:
            self.producer_pool: asyncio.Queue = asyncio.Queue(
                loop=self.loop,
            )  # Request Messages / Request Responses
            self.consumer_pool: asyncio.Queue = asyncio.Queue(
                loop=self.loop,
            )  # Request Responses

            # Initialize a PeerConnection structure
            self.peer_connection = RTCPeerConnection()

            # The channel is only declared here (annotation without a value);
            # it is assigned later, once a data channel exists.
            # External classes use this attribute to verify whether the
            # connection channel was established.
            self.channel: RTCDataChannel
            self._client_address: Optional[Address] = None

        except Exception as e:
            # Any failure above is handed to traceback_and_raise.
            traceback_and_raise(e)

    async def _set_offer(self) -> str:
        """
        Open the data channel on the offering side, register its "open" and
        "message" callbacks, and build the signaling offer.

        :return: the serialized local description to send as the offer payload.
        :rtype: str
        """
        try:
            # The offering peer is the one that creates the RTCDataChannel.
            self.channel = self.peer_connection.createDataChannel(
                "datachannel",
            )
            # Keep send buffer busy with chunks
            self.channel.bufferedAmountLowThreshold = 4 * DC_MAX_CHUNK_SIZE

            # Called back by the aioRTC lib when the channel opens:
            # start the producer task.
            @self.channel.on("open")
            async def on_open() -> None:  # type : ignore
                self.__producer_task = asyncio.ensure_future(self.producer())

            # Reassembly state shared across calls of the message callback.
            parts = []
            remaining = 0

            # The "consumer" side of the channel: registered now, it is invoked
            # for every message once the connection is open.
            @self.channel.on("message")
            async def on_message(raw: bytes) -> None:
                nonlocal parts, remaining

                chunk = OrderedChunk.load(raw)
                payload = chunk.data

                # A start sign announces a chunked message; idx carries the
                # number of chunks to expect.
                if payload == DC_CHUNK_START_SIGN:
                    remaining = chunk.idx
                    parts = [b""] * remaining
                    return

                # No chunked message in flight: forward the message as it is.
                if not remaining:
                    await self.consumer(msg=payload)
                    return

                # Store the chunk in its slot; only a slot filled for the
                # first time counts towards completion.
                if parts[chunk.idx] == b"":
                    remaining -= 1
                parts[chunk.idx] = payload
                if remaining == 0:
                    await self.consumer(msg=b"".join(parts))

            # Generate the offer and install it as the local description.
            offer = await self.peer_connection.createOffer()
            await self.peer_connection.setLocalDescription(offer)

            # Serialize the resulting local description: that is the payload.
            return object_to_string(self.peer_connection.localDescription)
        except Exception as e:
            traceback_and_raise(e)

    async def _set_answer(self, payload: str) -> str:
        """
        Handle a signaling offer on the answering side: register the
        "datachannel" callback (which in turn wires up the producer task and
        the message callback), then process the offer payload.

        :param payload: the serialized offer received from the other peer.
        :return: the signaling answer payload containing the local description.
        :rtype: str
        """

        try:

            @self.peer_connection.on("datachannel")
            def on_datachannel(channel: RTCDataChannel) -> None:
                # The answering peer receives its channel from the other side.
                self.channel = channel

                self.__producer_task = asyncio.ensure_future(self.producer())

                # Reassembly state shared across calls of the message callback.
                parts = []
                remaining = 0

                @self.channel.on("message")
                async def on_message(raw: bytes) -> None:
                    nonlocal parts, remaining

                    chunk = OrderedChunk.load(raw)
                    payload_bytes = chunk.data

                    # A start sign announces a chunked message; idx carries
                    # the number of chunks to expect.
                    if payload_bytes == DC_CHUNK_START_SIGN:
                        remaining = chunk.idx
                        parts = [b""] * remaining
                        return

                    # No chunked message in flight: forward the message as it is.
                    if not remaining:
                        await self.consumer(msg=payload_bytes)
                        return

                    # Only a slot filled for the first time counts towards
                    # completion.
                    if parts[chunk.idx] == b"":
                        remaining -= 1
                    parts[chunk.idx] = payload_bytes
                    if remaining == 0:
                        await self.consumer(msg=b"".join(parts))

            answer = await self._process_answer(payload=payload)
            return validate_type(answer, str)

        except Exception as e:
            traceback_and_raise(e)
            raise Exception("mypy workaound: should not get here")

    async def _process_answer(self, payload: str) -> Union[str, None]:
        """
        Apply a signaling payload received from the other peer.

        The payload is decoded with object_from_string. A session description
        is installed as the remote description; when it is an offer, an answer
        is created, installed as the local description and returned in
        serialized form.

        :return: the serialized local description when answering an offer,
            otherwise None.
        :rtype: Union[str, None]
        """
        try:
            description = object_from_string(payload)

            # Anything that is not an RTC session description is ignored.
            if not isinstance(description, RTCSessionDescription):
                return None

            # The remote description tells this peer how to find/connect
            # to the other peer.
            await self.peer_connection.setRemoteDescription(description)

            # Only an offer calls for a local description in return.
            if description.type != "offer":
                return None

            answer = await self.peer_connection.createAnswer()
            await self.peer_connection.setLocalDescription(answer)

            # Serialized local description: tells the other peer how to find us.
            return object_to_string(self.peer_connection.localDescription)
        except Exception as e:
            traceback_and_raise(e)
        return None

    async def producer(self) -> None:
        """
        Async task to send messages to the other side.
        These messages will be enqueued by PySyft Node Clients
        by using PySyft routes and ClientConnection's inheritance.

        Runs forever: each message taken from ``producer_pool`` is serialized
        and sent over ``self.channel``, either as a single chunk or, when
        chunking is enabled and the data exceeds DC_MAX_CHUNK_SIZE, as a start
        marker followed by numbered chunks. Any exception is handed to
        traceback_and_raise.
        """
        try:
            while True:
                # If self.producer_pool is empty, give up task queue priority
                # and give computing time to the next task.
                msg = await self.producer_pool.get()

                # If self.producer_pool.get() returns a message
                # send it as a binary using the RTCDataChannel.
                data = serialize(msg, to_bytes=True)
                data_len = len(data)

                if DC_CHUNKING_ENABLED and data_len > DC_MAX_CHUNK_SIZE:
                    # State shared with send_data_chunks through ``nonlocal``:
                    # index of the next chunk, completion flag, and a future
                    # that is resolved once the last chunk has been sent.
                    chunk_num = 0
                    done = False
                    sent: asyncio.Future = asyncio.Future(loop=self.loop)

                    def send_data_chunks() -> None:
                        nonlocal chunk_num, data_len, done, sent
                        # Send chunks until buffered amount is big or we're done
                        while (
                            self.channel.bufferedAmount <= DC_MAX_BUFSIZE and not done
                        ):
                            # Slice out the next chunk; the last one may be shorter.
                            start_offset = chunk_num * DC_MAX_CHUNK_SIZE
                            end_offset = min(
                                (chunk_num + 1) * DC_MAX_CHUNK_SIZE, data_len
                            )
                            chunk = data[start_offset:end_offset]
                            self.channel.send(OrderedChunk(chunk_num, chunk).save())
                            chunk_num += 1
                            # All bytes covered: mark completion and wake the producer.
                            if chunk_num * DC_MAX_CHUNK_SIZE >= data_len:
                                done = True
                                sent.set_result(True)

                        if not done:
                            # Set listener for next round of sending when buffer is empty
                            self.channel.once("bufferedamountlow", send_data_chunks)

                    # Announce the chunked message first: the start sign is sent
                    # together with the total number of chunks.
                    chunk_count = math.ceil(data_len / DC_MAX_CHUNK_SIZE)
                    self.channel.send(
                        OrderedChunk(chunk_count, DC_CHUNK_START_SIGN).save()
                    )
                    send_data_chunks()
                    # Wait until all chunks are dispatched
                    await sent
                else:
                    # Small enough (or chunking disabled): send as one chunk with index 0.
                    self.channel.send(OrderedChunk(0, data).save())
        except Exception as e:
            traceback_and_raise(e)

    def close(self) -> None:
        """
        Tell the other peer that this connection is closing, then finish the
        async tasks that belong to it. Any exception is handed to
        traceback_and_raise.
        """
        try:
            # Farewell message that warns the other peer
            farewell = CloseConnectionMessage(address=Address())
            frame = OrderedChunk(0, serialize(farewell, to_bytes=True)).save()
            self.channel.send(frame)

            # Local teardown: peer connection and producer task
            self._finish_coroutines()
        except Exception as e:
            traceback_and_raise(e)

    def _finish_coroutines(self) -> None:
        """
        Close the peer connection and cancel the producer task.

        Any exception is handed to traceback_and_raise.
        """
        try:
            # NOTE(review): asyncio.run() normally refuses to start inside an
            # already running event loop, and consumer() calls this method from
            # a coroutine - confirm this path works in the deployed environment
            # (e.g. with a re-entrant event loop patch applied).
            asyncio.run(self.peer_connection.close())
            # Reached only if closing the peer connection did not raise.
            self.__producer_task.cancel()
        except Exception as e:
            traceback_and_raise(e)

    async def consumer(self, msg: bytes) -> None:
        """
        Async task that processes one message received from the other side.

        A message addressed to this connection's client is a response to an
        earlier request and is queued on ``consumer_pool``. Everything else is
        a request from the other peer and is routed through the node owned by
        this connection.
        """
        try:
            # Deserialize the received message
            _msg = _deserialize(blob=msg, from_bytes=True)

            # PS: The "_client_address" attribute will be defined during
            # Node Client initialization.
            is_new_request = _msg.address != self._client_address
            if not is_new_request:
                # The message has the client's address as destination:
                # it answers a request made through this connection.
                await self.consumer_pool.put(_msg)
                return

            if isinstance(_msg, SignedImmediateSyftMessageWithReply):
                # Immediate message with reply: queue the reply for sending
                reply = self.recv_immediate_msg_with_reply(msg=_msg)
                await self.producer_pool.put(reply)
            elif isinstance(_msg, SignedImmediateSyftMessageWithoutReply):
                # Immediate message without reply
                self.recv_immediate_msg_without_reply(msg=_msg)
            elif isinstance(_msg, CloseConnectionMessage):
                # Just finish async tasks related with this connection
                self._finish_coroutines()
            else:
                # Eventual message without reply
                self.recv_eventual_msg_without_reply(msg=_msg)

        except Exception as e:
            traceback_and_raise(e)

    def recv_immediate_msg_with_reply(
        self, msg: SignedImmediateSyftMessageWithReply
    ) -> SignedImmediateSyftMessageWithoutReply:
        """
        Pass a request to the node for immediate execution and return its reply.

        A debug line is emitted before and after the node call; both carry the
        same random number so the pair can be told apart from other calls.

        :return: the value produced by the node's recv_immediate_msg_with_reply.
        :rtype: SignedImmediateSyftMessageWithoutReply
        """
        try:
            r = secrets.randbelow(100_000)
            debug(
                f"> Before recv_immediate_msg_with_reply {r} {msg.message} {type(msg.message)}"
            )
            # Execute node services now
            reply = self.node.recv_immediate_msg_with_reply(msg=msg)
            debug(
                f"> After recv_immediate_msg_with_reply {r} {msg.message} {type(msg.message)}"
            )
        except Exception as e:
            traceback_and_raise(e)
        else:
            return reply

    def recv_immediate_msg_without_reply(
        self, msg: SignedImmediateSyftMessageWithoutReply
    ) -> None:
        """
        Pass a request to the node for immediate execution; nothing is returned.

        A debug line is emitted before and after the node call; both carry the
        same random number so the pair can be told apart from other calls.
        Any exception is forwarded to traceback_and_raise.
        """
        try:
            r = secrets.randbelow(100_000)
            debug(
                f"> Before recv_immediate_msg_without_reply {r} {msg.message} {type(msg.message)}"
            )
            # Look the handler up only after the "before" line has been logged.
            handler = self.node.recv_immediate_msg_without_reply
            handler(msg=msg)
            debug(
                f"> After recv_immediate_msg_without_reply {r} {msg.message} {type(msg.message)}"
            )
        except Exception as e:
            traceback_and_raise(e)

    def recv_eventual_msg_without_reply(
        self, msg: SignedEventualSyftMessageWithoutReply
    ) -> None:
        """
        Pass an eventual (no-reply) request to the node owned by this connection.

        Any exception is forwarded to traceback_and_raise.
        """
        try:
            handler = self.node.recv_eventual_msg_without_reply
            handler(msg=msg)
        except Exception as e:
            traceback_and_raise(e)
            # Only reached if traceback_and_raise returns; kept for mypy.
            raise Exception("mypy workaound: should not get here")

    # TODO: fix this mypy madness
    def send_immediate_msg_with_reply(  # type: ignore
        self, msg: SignedImmediateSyftMessageWithReply
    ) -> SignedImmediateSyftMessageWithReply:
        """
        Send a message and block until its response is available.

        The async send_sync_message coroutine is driven to completion with
        asyncio.run() and its result is passed through validate_type.

        NOTE(review): send_sync_message is annotated as returning
        SignedImmediateSyftMessageWithoutReply while this method declares
        SignedImmediateSyftMessageWithReply -- confirm which one is intended.

        :return: the response obtained from send_sync_message.
        :rtype: SignedImmediateSyftMessageWithReply (as declared)
        """
        try:
            reply = asyncio.run(self.send_sync_message(msg=msg))
            # properly fix this! The check is made against `object`.
            return validate_type(reply, object)
        except Exception as e:
            traceback_and_raise(e)
            # Only reached if traceback_and_raise returns; kept for mypy.
            raise Exception("mypy workaound: should not get here")

    def send_immediate_msg_without_reply(
        self, msg: SignedImmediateSyftMessageWithoutReply
    ) -> None:
        """
        Put a message on the producer pool without waiting for any reply.

        The message is enqueued with put_nowait(), so no event loop is entered
        here. Any exception is forwarded to traceback_and_raise.
        """
        try:
            outgoing = self.producer_pool
            outgoing.put_nowait(msg)
        except Exception as e:
            traceback_and_raise(e)

    def send_eventual_msg_without_reply(
        self, msg: SignedEventualSyftMessageWithoutReply
    ) -> None:
        """
        Put an eventual message on the producer pool without waiting for a reply.

        The queue's put() coroutine is driven to completion with asyncio.run().
        Any exception is forwarded to traceback_and_raise.
        """
        try:
            pending_put = self.producer_pool.put(msg)
            asyncio.run(pending_put)
        except Exception as e:
            traceback_and_raise(e)

    async def send_sync_message(
        self, msg: SignedImmediateSyftMessageWithReply
    ) -> SignedImmediateSyftMessageWithoutReply:
        """
        Enqueue a message for sending and wait for the next response.

        Queues are used so that a caller can wait for the reply without
        blocking the other async methods: the outgoing message goes onto the
        producer pool, and the response is whatever is next taken from the
        consumer pool. Debug lines tagged with one random number are emitted
        around each queue operation.

        :return: the item retrieved from the consumer pool.
        :rtype: SignedImmediateSyftMessageWithoutReply
        """
        try:
            # One random number shared by every debug line of this call.
            r = secrets.randbelow(100_000)

            # Outgoing side: hand the message to the producer pool.
            debug(f"> Before send_sync_message producer_pool.put blocking {r}")
            await self.producer_pool.put(msg)
            debug(f"> After send_sync_message producer_pool.put blocking {r}")

            # Incoming side: wait until something shows up on the consumer pool.
            debug(f"> Before send_sync_message consumer_pool.get blocking {r} {msg}")
            debug(
                f"> Before send_sync_message consumer_pool.get blocking {r} {msg.message}"
            )
            reply = await self.consumer_pool.get()

            debug(f"> After send_sync_message consumer_pool.get blocking {r}")
        except Exception as e:
            traceback_and_raise(e)
        else:
            return reply

    async def async_check(
        self, before: float, timeout_secs: int, r: float
    ) -> SignedImmediateSyftMessageWithoutReply:
        while True:
            try:
                response = self.consumer_pool.get_nowait()
                return response
            except Exception as e:
                now = time.time()
                debug(f"> During send_sync_message consumer_pool.get blocking {r}. {e}")
                if now - before > timeout_secs:
                    traceback_and_raise(
                        Exception(f"send_sync_message timeout {timeout_secs} {r}")
                    )