Exemple #1
0
    def test_rtp_any_ssrc(self):
        """Round-trip RTP through SRTP sessions that accept any SSRC.

        The outbound session must also reject a non-bytes packet with
        TypeError and a 1500-byte packet with ValueError.
        """
        sender = Session(policy=Policy(key=KEY,
                                       ssrc_type=Policy.SSRC_ANY_OUTBOUND))
        encrypted = sender.protect(RTP)
        self.assertEqual(len(encrypted), 182)

        # non-bytes input is refused
        with self.assertRaises(TypeError) as ctx:
            sender.protect(4567)
        self.assertEqual(str(ctx.exception), 'packet must be bytes')

        # oversized input is refused
        with self.assertRaises(ValueError) as ctx:
            sender.protect(b'0' * 1500)
        self.assertEqual(str(ctx.exception), 'packet is too long')

        # the inbound session recovers the original 172-byte packet
        receiver = Session(policy=Policy(key=KEY,
                                         ssrc_type=Policy.SSRC_ANY_INBOUND))
        decrypted = receiver.unprotect(encrypted)
        self.assertEqual(len(decrypted), 172)
        self.assertEqual(decrypted, RTP)
Exemple #2
0
    async def connect(self):
        """Run the DTLS handshake and set up the SRTP sessions.

        Must be called while the transport is CLOSED. ``SSL_do_handshake`` is
        driven until it succeeds. Outgoing records are flushed after every
        step, and the next datagram is awaited whenever OpenSSL reports
        ``SSL_ERROR_WANT_READ``. Once encrypted:

        - the peer certificate's digest is compared to
          ``self.remote_fingerprint``;
        - SRTP keying material is exported;
        - the rx/tx SRTP sessions are built;
        - the state becomes CONNECTED and the receive pump is started.

        :raises DtlsError: if the handshake fails, the peer sent no
            certificate, or the certificate fingerprint does not match.
        """
        assert self.state == self.State.CLOSED

        self._set_state(self.State.CONNECTING)
        while not self.encrypted:
            result = lib.SSL_do_handshake(self.ssl)
            # flush whatever handshake records OpenSSL queued, even on success
            await self._write_ssl()

            if result > 0:
                self.encrypted = True
                break

            error = lib.SSL_get_error(self.ssl, result)
            if error == lib.SSL_ERROR_WANT_READ:
                await self._recv_next()
            else:
                raise DtlsError('DTLS handshake failed (error %d)' % error)

        # check remote fingerprint
        x509 = lib.SSL_get_peer_certificate(self.ssl)
        if x509 == ffi.NULL:
            raise DtlsError('DTLS remote certificate is missing')
        try:
            remote_fingerprint = certificate_digest(x509)
        finally:
            # SSL_get_peer_certificate() returns a new reference which the
            # caller must release, otherwise the certificate leaks
            lib.X509_free(x509)
        if remote_fingerprint != self.remote_fingerprint.upper():
            raise DtlsError('DTLS fingerprint does not match')

        # generate keying material: two (key, salt) pairs exported with the
        # DTLS-SRTP extractor label
        buf = ffi.new('unsigned char[]', 2 * (SRTP_KEY_LEN + SRTP_SALT_LEN))
        extractor = b'EXTRACTOR-dtls_srtp'
        _openssl_assert(
            lib.SSL_export_keying_material(self.ssl, buf, len(
                buf), extractor, len(extractor), ffi.NULL, 0, 0) == 1)

        # side 0 is the client's material and side 1 the server's, so each
        # peer transmits with its own side and receives with the other one
        view = ffi.buffer(buf)
        if self.is_server:
            srtp_tx_key = get_srtp_key_salt(view, 1)
            srtp_rx_key = get_srtp_key_salt(view, 0)
        else:
            srtp_tx_key = get_srtp_key_salt(view, 0)
            srtp_rx_key = get_srtp_key_salt(view, 1)

        rx_policy = Policy(key=srtp_rx_key, ssrc_type=Policy.SSRC_ANY_INBOUND)
        self._rx_srtp = Session(rx_policy)
        tx_policy = Policy(key=srtp_tx_key, ssrc_type=Policy.SSRC_ANY_OUTBOUND)
        self._tx_srtp = Session(tx_policy)

        # start data pump; keep a reference to the task so it cannot be
        # garbage-collected while it is still running
        logger.debug('%s - DTLS handshake complete', self.role)
        self._set_state(self.State.CONNECTED)
        self._run_task = asyncio.ensure_future(self.__run())
Exemple #3
0
    async def connect(self):
        """Run the DTLS handshake and set up the SRTP sessions.

        ``SSL_do_handshake`` is driven until it succeeds. Pending records are
        flushed to the transport after each step. Whenever OpenSSL reports
        ``SSL_ERROR_WANT_READ``, the next datagram from ``self.transport`` is
        fed into the read BIO. Afterwards:

        - the peer certificate's digest is compared to
          ``self.remote_fingerprint``;
        - SRTP keying material is exported;
        - the rx/tx SRTP sessions are built.

        :raises Exception: if the handshake fails, the peer sent no
            certificate, the fingerprint does not match, or the keying
            material cannot be exported.
        """
        while not self.encrypted:
            result = lib.SSL_do_handshake(self.ssl)
            if result > 0:
                self.encrypted = True
                break

            error = lib.SSL_get_error(self.ssl, result)

            # send any handshake records OpenSSL produced before waiting
            await self._write_ssl()

            if error == lib.SSL_ERROR_WANT_READ:
                data = await self.transport.recv()
                lib.BIO_write(self.read_bio, data, len(data))
            else:
                raise Exception('DTLS handshake failed (error %d)' % error)

        # the successful handshake step may still have queued a final flight
        await self._write_ssl()

        # check remote fingerprint
        x509 = lib.SSL_get_peer_certificate(self.ssl)
        if x509 == ffi.NULL:
            raise Exception('DTLS remote certificate is missing')
        try:
            remote_fingerprint = certificate_digest(x509)
        finally:
            # SSL_get_peer_certificate() returns a new reference which the
            # caller must release, otherwise the certificate leaks
            lib.X509_free(x509)
        if remote_fingerprint != self.remote_fingerprint.upper():
            raise Exception('DTLS fingerprint does not match')

        # generate keying material; SSL_export_keying_material() writes to an
        # `unsigned char *` buffer
        buf = ffi.new("unsigned char[]", 2 * (SRTP_KEY_LEN + SRTP_SALT_LEN))
        extractor = b'EXTRACTOR-dtls_srtp'
        exported = lib.SSL_export_keying_material(self.ssl, buf,
                                                  len(buf), extractor,
                                                  len(extractor), ffi.NULL,
                                                  0, 0)
        # success is exactly 1; failure may be 0 *or* -1, and -1 is truthy
        if exported != 1:
            raise Exception('DTLS could not extract SRTP keying material')

        # side 0 is the client's material and side 1 the server's, so each
        # peer transmits with its own side and receives with the other one
        view = ffi.buffer(buf)
        if self.is_server:
            srtp_tx_key = get_srtp_key_salt(view, 1)
            srtp_rx_key = get_srtp_key_salt(view, 0)
        else:
            srtp_tx_key = get_srtp_key_salt(view, 0)
            srtp_rx_key = get_srtp_key_salt(view, 1)

        logger.info('DTLS handshake complete')
        rx_policy = Policy(key=srtp_rx_key, ssrc_type=Policy.SSRC_ANY_INBOUND)
        self._rx_srtp = Session(rx_policy)
        tx_policy = Policy(key=srtp_tx_key, ssrc_type=Policy.SSRC_ANY_OUTBOUND)
        self._tx_srtp = Session(tx_policy)
Exemple #4
0
    def test_rtcp_any_ssrc(self):
        """Round-trip RTCP through SRTP sessions that accept any SSRC.

        The outbound session must also reject a non-bytes packet with
        TypeError and a 1500-byte packet with ValueError.
        """
        outbound = Session(
            policy=Policy(key=KEY, ssrc_type=Policy.SSRC_ANY_OUTBOUND))
        sealed = outbound.protect_rtcp(RTCP)
        self.assertEqual(len(sealed), 42)

        # non-bytes input is refused
        with self.assertRaises(TypeError) as raised:
            outbound.protect_rtcp(4567)
        self.assertEqual(str(raised.exception), "packet must be bytes")

        # oversized input is refused
        with self.assertRaises(ValueError) as raised:
            outbound.protect_rtcp(b"0" * 1500)
        self.assertEqual(str(raised.exception), "packet is too long")

        # the inbound session recovers the original 28-byte RTCP packet
        inbound = Session(
            policy=Policy(key=KEY, ssrc_type=Policy.SSRC_ANY_INBOUND))
        opened = inbound.unprotect_rtcp(sealed)
        self.assertEqual(len(opened), 28)
        self.assertEqual(opened, RTCP)
Exemple #5
0
    def test_rtp_specific_ssrc(self):
        """Round-trip RTP when both sessions are bound to SSRC 12345."""
        def make_session():
            # sender and receiver use an identical SSRC-specific policy
            return Session(policy=Policy(
                key=KEY, ssrc_type=Policy.SSRC_SPECIFIC, ssrc_value=12345))

        ciphertext = make_session().protect(RTP)
        self.assertEqual(len(ciphertext), 182)

        plaintext = make_session().unprotect(ciphertext)
        self.assertEqual(len(plaintext), 172)
        self.assertEqual(plaintext, RTP)
Exemple #6
0
class RTCDtlsTransport(EventEmitter):
    """
    The :class:`RTCDtlsTransport` object includes information relating to
    Datagram Transport Layer Security (DTLS) transport.

    :param transport: An :class:`RTCIceTransport`.
    :param certificates: A list of :class:`RTCCertificate` (only one is allowed currently).
    """
    def __init__(self, transport, certificates):
        assert len(certificates) == 1
        certificate = certificates[0]

        super().__init__()
        self.encrypted = False
        self._data_receiver = None
        self._role = 'auto'
        self._rtp_header_extensions_map = rtp.HeaderExtensionsMap()
        self._rtp_router = RtpRouter()
        self._state = State.NEW
        self._stats_id = 'transport_' + str(id(self))
        self._task = None
        self._transport = transport

        # counters
        self.__rx_bytes = 0
        self.__rx_packets = 0
        self.__tx_bytes = 0
        self.__tx_packets = 0

        # SRTP sessions, created by start() once keys have been exported
        self._rx_srtp = None
        self._tx_srtp = None

        # SSL init
        self.__ctx = create_ssl_context(certificate)

        ssl = lib.SSL_new(self.__ctx)
        self.ssl = ffi.gc(ssl, lib.SSL_free)

        # memory BIOs: received datagrams are written into read_bio by
        # _recv_next(); outgoing records are drained from write_bio by
        # _write_ssl(). The cdata buffers are 1500-byte scratch areas.
        self.read_bio = lib.BIO_new(lib.BIO_s_mem())
        self.read_cdata = ffi.new('char[]', 1500)
        self.write_bio = lib.BIO_new(lib.BIO_s_mem())
        self.write_cdata = ffi.new('char[]', 1500)
        lib.SSL_set_bio(self.ssl, self.read_bio, self.write_bio)

        self.__local_certificate = certificate

    @property
    def state(self):
        """
        The current state of the DTLS transport.

        One of `'new'`, `'connecting'`, `'connected'`, `'closed'` or `'failed'`.
        """
        return str(self._state)[6:].lower()

    @property
    def transport(self):
        """
        The associated :class:`RTCIceTransport` instance.
        """
        return self._transport

    def getLocalParameters(self):
        """
        Get the local parameters of the DTLS transport.

        :rtype: :class:`RTCDtlsParameters`
        """
        return RTCDtlsParameters(
            fingerprints=self.__local_certificate.getFingerprints())

    async def start(self, remoteParameters):
        """
        Start DTLS transport negotiation with the parameters of the remote
        DTLS transport.

        The method does not raise on negotiation problems. If the handshake
        fails, the connection drops, the peer sends no certificate, or no
        SHA-256 fingerprint in `remoteParameters` matches, the state becomes
        `'failed'` and the method returns.

        :param remoteParameters: An :class:`RTCDtlsParameters`.
        """
        assert self._state == State.NEW
        assert len(remoteParameters.fingerprints)

        # the ICE controlling agent takes the DTLS server role
        if self.transport.role == 'controlling':
            self._role = 'server'
            lib.SSL_set_accept_state(self.ssl)
        else:
            self._role = 'client'
            lib.SSL_set_connect_state(self.ssl)

        self._set_state(State.CONNECTING)
        try:
            while not self.encrypted:
                result = lib.SSL_do_handshake(self.ssl)
                # flush whatever records OpenSSL queued, even on success
                await self._write_ssl()

                if result > 0:
                    self.encrypted = True
                    break

                error = lib.SSL_get_error(self.ssl, result)
                if error == lib.SSL_ERROR_WANT_READ:
                    await self._recv_next()
                else:
                    self.__log_debug('x DTLS handshake failed (error %d)',
                                     error)
                    for info in get_error_queue():
                        self.__log_debug('x %s', ':'.join(info))
                    self._set_state(State.FAILED)
                    return
        except ConnectionError:
            self.__log_debug('x DTLS handshake failed (connection error)')
            self._set_state(State.FAILED)
            return

        # check remote fingerprint
        x509 = lib.SSL_get_peer_certificate(self.ssl)
        if x509 == ffi.NULL:
            self.__log_debug('x DTLS handshake failed (no remote certificate)')
            self._set_state(State.FAILED)
            return
        try:
            remote_fingerprint = certificate_digest(x509)
        finally:
            # SSL_get_peer_certificate() returns a new reference which the
            # caller must release, otherwise the certificate leaks
            lib.X509_free(x509)
        fingerprint_is_valid = any(
            f.algorithm.lower() == 'sha-256' and
            f.value.lower() == remote_fingerprint.lower()
            for f in remoteParameters.fingerprints)
        if not fingerprint_is_valid:
            self.__log_debug('x DTLS handshake failed (fingerprint mismatch)')
            self._set_state(State.FAILED)
            return

        # generate keying material: two (key, salt) pairs exported with the
        # DTLS-SRTP extractor label
        buf = ffi.new('unsigned char[]', 2 * (SRTP_KEY_LEN + SRTP_SALT_LEN))
        extractor = b'EXTRACTOR-dtls_srtp'
        _openssl_assert(
            lib.SSL_export_keying_material(self.ssl, buf, len(
                buf), extractor, len(extractor), ffi.NULL, 0, 0) == 1)

        # side 0 is the client's material and side 1 the server's, so each
        # peer transmits with its own side and receives with the other one
        view = ffi.buffer(buf)
        if self._role == 'server':
            srtp_tx_key = get_srtp_key_salt(view, 1)
            srtp_rx_key = get_srtp_key_salt(view, 0)
        else:
            srtp_tx_key = get_srtp_key_salt(view, 0)
            srtp_rx_key = get_srtp_key_salt(view, 1)

        rx_policy = Policy(key=srtp_rx_key, ssrc_type=Policy.SSRC_ANY_INBOUND)
        rx_policy.allow_repeat_tx = True
        rx_policy.window_size = 1024
        self._rx_srtp = Session(rx_policy)

        tx_policy = Policy(key=srtp_tx_key, ssrc_type=Policy.SSRC_ANY_OUTBOUND)
        tx_policy.allow_repeat_tx = True
        tx_policy.window_size = 1024
        self._tx_srtp = Session(tx_policy)

        # start data pump
        self.__log_debug('- DTLS handshake complete')
        self._set_state(State.CONNECTED)
        self._task = asyncio.ensure_future(self.__run())

    async def stop(self):
        """
        Stop and close the DTLS transport.
        """
        if self._task is not None:
            self._task.cancel()
            self._task = None

        if self._state in [State.CONNECTING, State.CONNECTED]:
            lib.SSL_shutdown(self.ssl)
            try:
                await self._write_ssl()
            except ConnectionError:
                pass
            self.__log_debug('- DTLS shutdown complete')

    async def __run(self):
        """Receive datagrams until the connection ends, then mark CLOSED.

        On ConnectionError every registered RTP receiver is notified through
        ``_handle_disconnect()``. The state is set to CLOSED whichever way
        the loop ends, including cancellation by stop().
        """
        try:
            while True:
                await self._recv_next()
        except ConnectionError:
            for receiver in self._rtp_router.receivers:
                receiver._handle_disconnect()
        finally:
            self._set_state(State.CLOSED)

    def _get_stats(self):
        """Return an RTCStatsReport holding this transport's counters."""
        report = RTCStatsReport()
        report.add(
            RTCTransportStats(
                # RTCStats
                timestamp=clock.current_datetime(),
                type='transport',
                id=self._stats_id,
                # RTCTransportStats,
                packetsSent=self.__tx_packets,
                packetsReceived=self.__rx_packets,
                bytesSent=self.__tx_bytes,
                bytesReceived=self.__rx_bytes,
                iceRole=self.transport.role,
                dtlsState=self.state,
            ))
        return report

    async def _handle_rtcp_data(self, data):
        """Parse decrypted RTCP data and hand each packet to its recipients.

        Unparseable data is logged and dropped.
        """
        try:
            packets = RtcpPacket.parse(data)
        except ValueError as exc:
            self.__log_debug('x RTCP parsing failed: %s', exc)
            return

        for packet in packets:
            # route RTCP packet
            for recipient in self._rtp_router.route_rtcp(packet):
                await recipient._handle_rtcp_packet(packet)

    async def _handle_rtp_data(self, data, arrival_time_ms):
        """Parse decrypted RTP data and hand it to the routed receiver.

        Unparseable data is logged and dropped. Packets with no matching
        receiver are ignored.
        """
        try:
            packet = RtpPacket.parse(data, self._rtp_header_extensions_map)
        except ValueError as exc:
            self.__log_debug('x RTP parsing failed: %s', exc)
            return

        # route RTP packet
        receiver = self._rtp_router.route_rtp(packet)
        if receiver is not None:
            await receiver._handle_rtp_packet(packet,
                                              arrival_time_ms=arrival_time_ms)

    async def _recv_next(self):
        """Receive one datagram from the ICE transport and dispatch it.

        Before the handshake completes, the wait is bounded by OpenSSL's DTLS
        timer (``DTLSv1_get_timeout``). When it expires,
        ``DTLSv1_handle_timeout`` runs and any queued records are flushed.

        :raises ConnectionError: when the remote party shuts DTLS down, or
            when the underlying transport raises it.
        """
        # get timeout
        timeout = None
        if not self.encrypted:
            ptv_sec = ffi.new('time_t *')
            ptv_usec = ffi.new('long *')
            if lib.Cryptography_DTLSv1_get_timeout(self.ssl, ptv_sec,
                                                   ptv_usec):
                timeout = ptv_sec[0] + (ptv_usec[0] / 1000000)

        # receive next datagram
        if timeout is not None:
            try:
                data = await asyncio.wait_for(self.transport._recv(),
                                              timeout=timeout)
            except asyncio.TimeoutError:
                self.__log_debug('x DTLS handling timeout')
                lib.DTLSv1_handle_timeout(self.ssl)
                await self._write_ssl()
                return
        else:
            data = await self.transport._recv()

        self.__rx_bytes += len(data)
        self.__rx_packets += 1

        # demultiplex on the first byte (ranges as in RFC 5764, 5.1.2):
        # 20..63 is DTLS, 128..191 is RTP/RTCP
        first_byte = data[0]
        if first_byte > 19 and first_byte < 64:
            # DTLS
            lib.BIO_write(self.read_bio, data, len(data))
            result = lib.SSL_read(self.ssl, self.read_cdata,
                                  len(self.read_cdata))
            await self._write_ssl()
            if result == 0:
                self.__log_debug('- DTLS shutdown by remote party')
                raise ConnectionError
            elif result > 0 and self._data_receiver:
                data = ffi.buffer(self.read_cdata)[0:result]
                await self._data_receiver._handle_data(data)
        elif first_byte > 127 and first_byte < 192 and self._rx_srtp:
            # SRTP / SRTCP
            arrival_time_ms = clock.current_ms()
            try:
                if is_rtcp(data):
                    data = self._rx_srtp.unprotect_rtcp(data)
                    await self._handle_rtcp_data(data)
                else:
                    data = self._rx_srtp.unprotect(data)
                    await self._handle_rtp_data(
                        data, arrival_time_ms=arrival_time_ms)
            except pylibsrtp.Error as exc:
                self.__log_debug('x SRTP unprotect failed: %s', exc)

    def _register_data_receiver(self, receiver):
        """Attach the single receiver for decrypted application data."""
        assert self._data_receiver is None
        self._data_receiver = receiver

    def _register_rtp_receiver(self, receiver,
                               parameters: RTCRtpReceiveParameters):
        """Route incoming RTP for the given SSRCs/payload types/mid to receiver."""
        ssrcs = {encoding.ssrc for encoding in parameters.encodings}

        self._rtp_header_extensions_map.configure(parameters)
        self._rtp_router.register_receiver(
            receiver,
            ssrcs=list(ssrcs),
            payload_types=[codec.payloadType for codec in parameters.codecs],
            mid=parameters.muxId)

    def _register_rtp_sender(self, sender, parameters: RTCRtpSendParameters):
        """Register sender (by its SSRC) so RTCP addressed to it is routed."""
        self._rtp_header_extensions_map.configure(parameters)
        self._rtp_router.register_sender(sender, ssrc=sender._ssrc)

    async def _send_data(self, data):
        """Encrypt application data with SSL_write and flush it.

        :raises ConnectionError: if the transport is not connected.
        """
        if self._state != State.CONNECTED:
            raise ConnectionError('Cannot send encrypted data, not connected')

        lib.SSL_write(self.ssl, data, len(data))
        await self._write_ssl()

    async def _send_rtp(self, data):
        """Protect an RTP or RTCP packet with the tx SRTP session and send it.

        :raises ConnectionError: if the transport is not connected.
        """
        if self._state != State.CONNECTED:
            raise ConnectionError('Cannot send encrypted RTP, not connected')

        if is_rtcp(data):
            data = self._tx_srtp.protect_rtcp(data)
        else:
            data = self._tx_srtp.protect(data)
        await self.transport._send(data)
        self.__tx_bytes += len(data)
        self.__tx_packets += 1

    def _set_state(self, state):
        """Change state and emit 'statechange' only if it actually changed."""
        if state != self._state:
            self.__log_debug('- %s -> %s', self._state, state)
            self._state = state
            self.emit('statechange')

    def _unregister_data_receiver(self, receiver):
        """Detach receiver if it is the registered data receiver."""
        if self._data_receiver == receiver:
            self._data_receiver = None

    def _unregister_rtp_receiver(self, receiver):
        """Stop routing incoming RTP/RTCP to receiver."""
        self._rtp_router.unregister_receiver(receiver)

    def _unregister_rtp_sender(self, sender):
        """Stop routing incoming RTCP to sender."""
        self._rtp_router.unregister_sender(sender)

    async def _write_ssl(self):
        """
        Flush outgoing data which OpenSSL put in our BIO to the transport.

        NOTE(review): at most len(self.write_cdata) (1500) bytes are read and
        sent per call; anything beyond that stays in the BIO until the next
        call.
        """
        pending = lib.BIO_ctrl_pending(self.write_bio)
        if pending > 0:
            result = lib.BIO_read(self.write_bio, self.write_cdata,
                                  len(self.write_cdata))
            await self.transport._send(ffi.buffer(self.write_cdata)[0:result])
            self.__tx_bytes += result
            self.__tx_packets += 1

    def __log_debug(self, msg, *args):
        """Log a debug message prefixed with the current DTLS role."""
        logger.debug(self._role + ' ' + msg, *args)
Exemple #7
0
    async def start(self, remoteParameters: RTCDtlsParameters) -> None:
        """
        Start DTLS transport negotiation with the parameters of the remote
        DTLS transport.

        The method does not raise on negotiation problems. If the handshake
        fails, the connection drops, the peer sends no certificate, or no
        SHA-256 fingerprint in `remoteParameters` matches, the state becomes
        FAILED and the method returns.

        :param remoteParameters: An :class:`RTCDtlsParameters`.
        """
        assert self._state == State.NEW
        assert len(remoteParameters.fingerprints)

        # the ICE controlling agent takes the DTLS server role
        if self.transport.role == "controlling":
            self._role = "server"
            lib.SSL_set_accept_state(self.ssl)
        else:
            self._role = "client"
            lib.SSL_set_connect_state(self.ssl)

        self._set_state(State.CONNECTING)
        try:
            while not self.encrypted:
                result = lib.SSL_do_handshake(self.ssl)
                # flush whatever records OpenSSL queued, even on success
                await self._write_ssl()

                if result > 0:
                    self.encrypted = True
                    break

                error = lib.SSL_get_error(self.ssl, result)
                if error == lib.SSL_ERROR_WANT_READ:
                    await self._recv_next()
                else:
                    self.__log_debug("x DTLS handshake failed (error %d)",
                                     error)
                    for info in get_error_queue():
                        self.__log_debug("x %s", ":".join(info))
                    self._set_state(State.FAILED)
                    return
        except ConnectionError:
            self.__log_debug("x DTLS handshake failed (connection error)")
            self._set_state(State.FAILED)
            return

        # check remote fingerprint
        x509 = lib.SSL_get_peer_certificate(self.ssl)
        if x509 == ffi.NULL:
            self.__log_debug("x DTLS handshake failed (no remote certificate)")
            self._set_state(State.FAILED)
            return
        try:
            remote_fingerprint = certificate_digest(x509)
        finally:
            # SSL_get_peer_certificate() returns a new reference which the
            # caller must release, otherwise the certificate leaks
            lib.X509_free(x509)
        fingerprint_is_valid = any(
            f.algorithm.lower() == "sha-256"
            and f.value.lower() == remote_fingerprint.lower()
            for f in remoteParameters.fingerprints)
        if not fingerprint_is_valid:
            self.__log_debug("x DTLS handshake failed (fingerprint mismatch)")
            self._set_state(State.FAILED)
            return

        # generate keying material: two (key, salt) pairs exported with the
        # DTLS-SRTP extractor label
        buf = ffi.new("unsigned char[]", 2 * (SRTP_KEY_LEN + SRTP_SALT_LEN))
        extractor = b"EXTRACTOR-dtls_srtp"
        _openssl_assert(
            lib.SSL_export_keying_material(self.ssl, buf, len(
                buf), extractor, len(extractor), ffi.NULL, 0, 0) == 1)

        # side 0 is the client's material and side 1 the server's, so each
        # peer transmits with its own side and receives with the other one
        view = ffi.buffer(buf)
        if self._role == "server":
            srtp_tx_key = get_srtp_key_salt(view, 1)
            srtp_rx_key = get_srtp_key_salt(view, 0)
        else:
            srtp_tx_key = get_srtp_key_salt(view, 0)
            srtp_rx_key = get_srtp_key_salt(view, 1)

        rx_policy = Policy(key=srtp_rx_key, ssrc_type=Policy.SSRC_ANY_INBOUND)
        rx_policy.allow_repeat_tx = True
        rx_policy.window_size = 1024
        self._rx_srtp = Session(rx_policy)

        tx_policy = Policy(key=srtp_tx_key, ssrc_type=Policy.SSRC_ANY_OUTBOUND)
        tx_policy.allow_repeat_tx = True
        tx_policy.window_size = 1024
        self._tx_srtp = Session(tx_policy)

        # start data pump
        self.__log_debug("- DTLS handshake complete")
        self._set_state(State.CONNECTED)
        self._task = asyncio.ensure_future(self.__run())
class srtp_hex2hex_coder(hex_coder):
    '''
    Allow record srtp/rtp stream
    '''
    def __init__(self,
                 override_payload_offset=None,
                 rtp_offset=0,
                 srtpkey=None,
                 resrtpkey=None,
                 verbose=False,
                 payload_type=111,
                 filter_ssrc=None):
        """Configure RTP parsing/filtering and optional SRTP decryption.

        :param override_payload_offset: when truthy, a payload offset
            measured from the start of the dumped packet, used instead of
            the computed RTP header length.
        :param rtp_offset: number of leading hex tokens in each packet that
            come before the RTP header; they are copied through unchanged.
        :param srtpkey: base64-encoded SRTP key; when set, packets are
            decrypted before being written.
        :param resrtpkey: stored only; not used by record_rtp_packet.
        :param verbose: print per-packet diagnostics.
        :param payload_type: keep only packets with this RTP payload type
            (a falsy value disables the filter).
        :param filter_ssrc: keep only packets with this SSRC (a falsy value
            disables the filter).
        """
        hex_coder.__init__(self, verbose=verbose)

        # RTP layout and filtering
        self.override_payload_offset = override_payload_offset
        self.rtp_offset = rtp_offset
        self.payload_type = payload_type
        self.filter_ssrc = filter_ssrc

        # SRTP state; record_rtp_packet rebuilds the session per new SSRC
        self.srtpkey = srtpkey
        self.resrtpkey = resrtpkey
        self.srtp_session = None
        self.ssrc = None

    def record_rtp_packet(self, packet, is_last=False):
        """Validate, filter and optionally SRTP-decrypt one packet, then write it.

        NOTE(review): this class defines ``record_rtp_packet`` twice. The
        later definition is the one bound at runtime, so keep both copies in
        sync (or delete one).

        :param packet: list of hex-string tokens, normally one per byte. The
            first ``self.rtp_offset`` tokens are copied through unchanged,
            and the RTP packet starts right after them.
        :param is_last: accepted for interface compatibility; unused here.
        :returns: True if the packet was passed to ``self.write_udp``. False
            if it was skipped because it was malformed, filtered out, or
            failed to decrypt.
        """
        if not packet:
            print("No packet")
            return False

        if self.verbose:
            print("Dumping packet ", packet)

        # Decode the pass-through prefix together with the RTP bytes. A bad
        # prefix is then reported here, instead of raising from write_udp
        # after the SRTP session state has already been advanced.
        try:
            prefix_raw = bytes.fromhex(''.join(packet[:self.rtp_offset]))
            rtp_raw_full = bytes.fromhex(''.join(packet[self.rtp_offset:]))
        except (TypeError, ValueError):
            print("Failed to get hex", packet)
            return False

        if len(rtp_raw_full) < RTP_FIXED_HEADER:
            print("udp payload is too small")
            return False

        if self.verbose:
            print(self.rtp_offset, packet[self.rtp_offset - 1],
                  packet[self.rtp_offset], packet[self.rtp_offset + 1])

        # decode the fixed RTP header
        rtp_raw_header = rtp_raw_full[:RTP_FIXED_HEADER]

        if self.verbose:
            print(rtp_raw_header.hex())
        rtp = RTP_HEADER._make(struct.unpack(RTP_HEADER_FMT, rtp_raw_header))
        if self.verbose:
            print(rtp)
        if rtp.PAYLOAD_TYPE & 0b10000000:
            # the top bit of this byte is the marker bit; clear it so the
            # payload-type filter below compares only the 7-bit type
            rtp = rtp._replace(PAYLOAD_TYPE=rtp.PAYLOAD_TYPE & 0b01111111)
            if self.verbose:
                print('stripped marker bit', rtp)

        # Filter the RTP with v=2
        if (rtp.FIRST & 0b11000000) != 0b10000000:
            print("Not an RTP")
            return False

        # first header byte carries the extension flag and the CSRC count
        rtp_exten = rtp.FIRST & 0b10000
        rtp_csrc = rtp.FIRST & 0b1111

        # header length = fixed part + CSRC list (+ extension header/body)
        calc_rtp_header_len = RTP_FIXED_HEADER + rtp_csrc * 4
        if rtp_exten:
            exten_start = calc_rtp_header_len
            exten_raw = rtp_raw_full[exten_start:exten_start + 4]
            if len(exten_raw) != 4:
                print("Skipping malformed RTP")
                return False
            _exten_profile, rtp_exten_length = struct.unpack(
                '>HH', exten_raw)
            # the extension length field counts 32-bit words
            calc_rtp_header_len += 4 + rtp_exten_length * 4

        if self.verbose:
            print("calc_rtp_header_len", calc_rtp_header_len)

        if self.override_payload_offset:
            # the override is measured from the start of the dumped packet
            calc_rtp_header_len = self.override_payload_offset - self.rtp_offset

        # Filter opus
        if self.payload_type and rtp.PAYLOAD_TYPE != self.payload_type:
            print(
                "Skipping payload {rtp_payload_type} while {opus_payload_type} expected."
                .format(rtp_payload_type=rtp.PAYLOAD_TYPE,
                        opus_payload_type=self.payload_type))
            return False

        if self.filter_ssrc and self.filter_ssrc != rtp.SSRC:
            print(
                "Skipping ssrc={rtp_ssrc} while {filter_ssrc} expected".format(
                    rtp_ssrc=rtp.SSRC, filter_ssrc=self.filter_ssrc))
            return False

        if len(rtp_raw_full) < (calc_rtp_header_len + 1):
            print("Empty payload")
            return False

        if self.srtpkey:
            if self.ssrc != rtp.SSRC:
                # new SSRC: start a fresh inbound session keyed with srtpkey
                self.srtp_session = Session()
                print("using key [%s]" % self.srtpkey)
                srtpkey = base64.b64decode(self.srtpkey)
                plc = Policy(key=srtpkey,
                             ssrc_value=rtp.SSRC,
                             ssrc_type=Policy.SSRC_ANY_INBOUND)
                print(plc)
                self.srtp_session.add_stream(plc)
                self.ssrc = rtp.SSRC
            try:
                rtp_raw_full = self.srtp_session.unprotect(rtp_raw_full)
            except Exception:
                # any SRTP failure (auth, replay, ...) skips the packet;
                # KeyboardInterrupt/SystemExit are no longer swallowed
                print("decrypt fail seq={sequence}, ssrc={ssrc}".format(
                    sequence=rtp.SEQUENCE_NUMBER, ssrc=rtp.SSRC))
                # TODO: a fallback that re-keys with self.resrtpkey was once
                # sketched here; resrtpkey is currently unused.
                return False

        self.write_udp(prefix_raw + rtp_raw_full)
        return True

    def record_rtp_file(self, infile, outfile, hexstring=False):
        '''
        Convert an RTP hexdump into a recorded file

        :param infile: path of the text input to read.
        :param outfile: passed to ``self.start_file``.
        :param hexstring: if True, each non-blank line is one packet written
            as a contiguous hex string. Otherwise the input is an
            offset-prefixed hexdump: lines matching ``<offset> <byte> ...``
            are accumulated, and any other line ends the current packet.

        Each packet is handed to ``record_rtp_packet``. A summary of written
        versus seen packets is printed, and ``self.end_file()`` is always
        called, even if reading or recording raises.
        '''
        packet_counter = 0
        success_counter = 0
        self.start_file(outfile)
        try:
            with open(infile, 'r') as rtp_fd:
                if hexstring:
                    for xline in rtp_fd:
                        xline = xline.strip()
                        if not xline:
                            # blank lines carry no packet
                            continue
                        packet = [
                            xline[idx:idx + 2]
                            for idx in range(0, len(xline), 2)
                        ]
                        packet_counter += 1
                        if self.record_rtp_packet(packet):
                            success_counter += 1
                else:
                    # "<offset><1-4 whitespace><hex byte>..." continues a packet
                    re_valid_hex = re.compile(
                        r'^[0-9A-Fa-f]{1,8}\s{1,4}[0-9A-Fa-f]{2}')
                    packet = []
                    for xline in rtp_fd:
                        if not xline or not re_valid_hex.match(xline):
                            # any non-dump line terminates the packet
                            if packet:
                                packet_counter += 1
                                if self.record_rtp_packet(packet):
                                    success_counter += 1
                                packet = []
                        else:
                            content = xline.split()
                            content.pop(0)  # skip the segment column
                            packet.extend(content)

                    if packet:
                        packet_counter += 1
                        if self.record_rtp_packet(packet, is_last=True):
                            success_counter += 1

            # previously reported only for hexdump input
            print("Written %d out of %d packets" %
                  (success_counter, packet_counter))
        finally:
            self.end_file()
    def record_rtp_packet(self, packet, is_last=False):
        """Validate, filter and optionally SRTP-decrypt one packet, then write it.

        NOTE(review): this class defines ``record_rtp_packet`` twice. This
        later definition is the one bound at runtime, so keep both copies in
        sync (or delete one).

        :param packet: list of hex-string tokens, normally one per byte. The
            first ``self.rtp_offset`` tokens are copied through unchanged,
            and the RTP packet starts right after them.
        :param is_last: accepted for interface compatibility; unused here.
        :returns: True if the packet was passed to ``self.write_udp``. False
            if it was skipped because it was malformed, filtered out, or
            failed to decrypt.
        """
        if not packet:
            print("No packet")
            return False

        if self.verbose:
            print("Dumping packet ", packet)

        # Decode the pass-through prefix together with the RTP bytes. A bad
        # prefix is then reported here, instead of raising from write_udp
        # after the SRTP session state has already been advanced.
        try:
            prefix_raw = bytes.fromhex(''.join(packet[:self.rtp_offset]))
            rtp_raw_full = bytes.fromhex(''.join(packet[self.rtp_offset:]))
        except (TypeError, ValueError):
            print("Failed to get hex", packet)
            return False

        if len(rtp_raw_full) < RTP_FIXED_HEADER:
            print("udp payload is too small")
            return False

        if self.verbose:
            print(self.rtp_offset, packet[self.rtp_offset - 1],
                  packet[self.rtp_offset], packet[self.rtp_offset + 1])

        # decode the fixed RTP header
        rtp_raw_header = rtp_raw_full[:RTP_FIXED_HEADER]

        if self.verbose:
            print(rtp_raw_header.hex())
        rtp = RTP_HEADER._make(struct.unpack(RTP_HEADER_FMT, rtp_raw_header))
        if self.verbose:
            print(rtp)
        if rtp.PAYLOAD_TYPE & 0b10000000:
            # the top bit of this byte is the marker bit; clear it so the
            # payload-type filter below compares only the 7-bit type
            rtp = rtp._replace(PAYLOAD_TYPE=rtp.PAYLOAD_TYPE & 0b01111111)
            if self.verbose:
                print('stripped marker bit', rtp)

        # Filter the RTP with v=2
        if (rtp.FIRST & 0b11000000) != 0b10000000:
            print("Not an RTP")
            return False

        # first header byte carries the extension flag and the CSRC count
        rtp_exten = rtp.FIRST & 0b10000
        rtp_csrc = rtp.FIRST & 0b1111

        # header length = fixed part + CSRC list (+ extension header/body)
        calc_rtp_header_len = RTP_FIXED_HEADER + rtp_csrc * 4
        if rtp_exten:
            exten_start = calc_rtp_header_len
            exten_raw = rtp_raw_full[exten_start:exten_start + 4]
            if len(exten_raw) != 4:
                print("Skipping malformed RTP")
                return False
            _exten_profile, rtp_exten_length = struct.unpack(
                '>HH', exten_raw)
            # the extension length field counts 32-bit words
            calc_rtp_header_len += 4 + rtp_exten_length * 4

        if self.verbose:
            print("calc_rtp_header_len", calc_rtp_header_len)

        if self.override_payload_offset:
            # the override is measured from the start of the dumped packet
            calc_rtp_header_len = self.override_payload_offset - self.rtp_offset

        # Filter opus
        if self.payload_type and rtp.PAYLOAD_TYPE != self.payload_type:
            print(
                "Skipping payload {rtp_payload_type} while {opus_payload_type} expected."
                .format(rtp_payload_type=rtp.PAYLOAD_TYPE,
                        opus_payload_type=self.payload_type))
            return False

        if self.filter_ssrc and self.filter_ssrc != rtp.SSRC:
            print(
                "Skipping ssrc={rtp_ssrc} while {filter_ssrc} expected".format(
                    rtp_ssrc=rtp.SSRC, filter_ssrc=self.filter_ssrc))
            return False

        if len(rtp_raw_full) < (calc_rtp_header_len + 1):
            print("Empty payload")
            return False

        if self.srtpkey:
            if self.ssrc != rtp.SSRC:
                # new SSRC: start a fresh inbound session keyed with srtpkey
                self.srtp_session = Session()
                print("using key [%s]" % self.srtpkey)
                srtpkey = base64.b64decode(self.srtpkey)
                plc = Policy(key=srtpkey,
                             ssrc_value=rtp.SSRC,
                             ssrc_type=Policy.SSRC_ANY_INBOUND)
                print(plc)
                self.srtp_session.add_stream(plc)
                self.ssrc = rtp.SSRC
            try:
                rtp_raw_full = self.srtp_session.unprotect(rtp_raw_full)
            except Exception:
                # any SRTP failure (auth, replay, ...) skips the packet;
                # KeyboardInterrupt/SystemExit are no longer swallowed
                print("decrypt fail seq={sequence}, ssrc={ssrc}".format(
                    sequence=rtp.SEQUENCE_NUMBER, ssrc=rtp.SSRC))
                # TODO: a fallback that re-keys with self.resrtpkey was once
                # sketched here; resrtpkey is currently unused.
                return False

        self.write_udp(prefix_raw + rtp_raw_full)
        return True
    def playback_as_srtp_stream(self, infile, outfile, starting_sequence=0):
        '''
        Read an ogg file and create rtp/srtp in hex form.

        Page 0 is decoded and printed as the Opus identity header, page 1 is
        only printed, and every later page becomes one RTP packet whose
        timestamp is the page granule position and whose SSRC is the page
        bitstream serial.  When ``self.srtpkey`` (base64) is set, each packet
        is SRTP-protected before being hexdumped into ``outfile``.

        :param infile: path of the ogg/opus file to read
        :param outfile: path of the hexdump text file to write
        :param starting_sequence: RTP sequence number of the first packet;
            the counter wraps to 0 after ``MAX_RTP_SEQUENCE_NUM``
        '''
        # Both files are closed when done (the input handle used to leak).
        with open(infile, 'rb') as ogg_fd, open(outfile, 'w') as hex_fd:
            self.ogg.reset(ogg_fd)
            page_counter = 0
            sequence_counter = starting_sequence
            for header, content in self.ogg:
                print(header)
                if 0 == page_counter:
                    opus_identity_head = OPUS_IDENTITY_HEADER._make(
                        struct.unpack(OPUS_IDENTITY_HEADER_FMT, content))
                    print(opus_identity_head)
                else:
                    print(' '.join([hex(x) for x in content]))
                    if page_counter > 1:
                        # make an RTP packet
                        rtp = RTP_HEADER(0x80, self.payload_type,
                                         sequence_counter, header.GRANULE_POS,
                                         header.BITSTREAM)
                        rtp_raw_full = struct.pack(RTP_HEADER_FMT,
                                                   *rtp) + content

                        if self.srtpkey:
                            if self.ssrc != rtp.SSRC:
                                print("using key [%s]" % self.srtpkey)
                                srtpkey = base64.b64decode(self.srtpkey)
                                plc = Policy(
                                    key=srtpkey,
                                    ssrc_value=rtp.SSRC,
                                    ssrc_type=Policy.SSRC_ANY_OUTBOUND)
                                print(plc)
                                self.srtp_session = Session(policy=plc)
                                self.ssrc = rtp.SSRC
                            try:
                                rtp_raw_full = self.srtp_session.protect(
                                    rtp_raw_full)
                            except Exception:
                                print(
                                    "encrypt fail seq={sequence}, ssrc={ssrc}".
                                    format(sequence=rtp.SEQUENCE_NUMBER,
                                           ssrc=rtp.SSRC))
                                # Never emit the plaintext packet in place of
                                # the SRTP packet that failed to protect.
                                rtp_raw_full = None

                        if rtp_raw_full is not None:
                            hex_fd.write(self.hexdump(rtp_raw_full))
                        # The sequence advances even for a dropped packet, so
                        # a receiver sees the gap as a lost packet.
                        if sequence_counter == MAX_RTP_SEQUENCE_NUM:
                            sequence_counter = 0
                        else:
                            sequence_counter += 1

                page_counter += 1
        print("Read %d pages" % page_counter)

    def record_rtp_packet(self, packet, is_last=False):
        '''
        Decode one hexdumped RTP/SRTP packet and append its payload to the ogg stream.

        :param packet: sequence of hex strings that concatenate to the packet
            bytes (``record_rtp_file`` passes one string per byte); the first
            ``self.rtp_offset`` entries precede the RTP header and are ignored
        :param is_last: passed through to ``self.ogg.write_page``
        :returns: True when an ogg page was written, False when the packet was
            skipped (not hex, too short, not RTP v2, filtered out by payload
            type or SSRC, empty payload, or SRTP decryption failure)
        '''
        assert self.ogg is not None

        if not packet:
            print("No packet")
            return False

        if self.verbose:
            print("Dumping packet ", packet)

        try:
            rtp_raw_full = bytes.fromhex(''.join(packet[self.rtp_offset:]))
        except (TypeError, ValueError):
            print("Failed to get hex", packet)
            return False

        if len(rtp_raw_full) < RTP_FIXED_HEADER:
            print("udp payload is too small")
            return False

        if self.verbose:
            print(self.rtp_offset, packet[self.rtp_offset - 1],
                  packet[self.rtp_offset], packet[self.rtp_offset + 1])

        # decode the fixed RTP header
        rtp_raw_header = rtp_raw_full[:RTP_FIXED_HEADER]
        if self.verbose:
            print(rtp_raw_header.hex())
        rtp = RTP_HEADER._make(struct.unpack(RTP_HEADER_FMT, rtp_raw_header))
        if self.verbose:
            print(rtp)

        # Filter the RTP with v=2
        if (rtp.FIRST & 0b11000000) != 0b10000000:
            print("Not an RTP")
            return False

        has_extension = rtp.FIRST & 0b10000
        csrc_count = rtp.FIRST & 0b1111

        # header length = fixed header + CSRC list (+ extension header/body)
        calc_rtp_header_len = RTP_FIXED_HEADER + csrc_count * 4
        if has_extension:
            exten_raw = rtp_raw_full[calc_rtp_header_len:
                                     calc_rtp_header_len + 4]
            if len(exten_raw) != 4:
                print("Skipping malformed RTP")
                return False
            # 16-bit profile, then the extension length in 32-bit words
            _exten_profile, exten_words = struct.unpack('>HH', exten_raw)
            calc_rtp_header_len += 4 + exten_words * 4

        if self.verbose:
            print("calc_rtp_header_len", calc_rtp_header_len)

        if self.override_payload_offset:
            calc_rtp_header_len = self.override_payload_offset - self.rtp_offset

        # Filter opus.  The second header byte also carries the marker bit in
        # its MSB; compare only the 7-bit payload type so marker-flagged
        # packets are not rejected.
        payload_type = rtp.PAYLOAD_TYPE & 0b01111111
        if self.payload_type and payload_type != self.payload_type:
            print(
                "Skipping payload {rtp_payload_type} while {opus_payload_type} expected."
                .format(rtp_payload_type=payload_type,
                        opus_payload_type=self.payload_type))
            return False

        if self.filter_ssrc and self.filter_ssrc != rtp.SSRC:
            print(
                "Skipping ssrc={rtp_ssrc} while {filter_ssrc} expected".format(
                    rtp_ssrc=rtp.SSRC, filter_ssrc=self.filter_ssrc))
            return False

        if len(rtp_raw_full) < (calc_rtp_header_len + 1):
            print("Empty payload")
            return False

        if self.srtpkey:
            if not self.srtp_session:
                # An SSRC_ANY_INBOUND policy is a template applied to every
                # incoming SSRC, so a single session serves all streams.
                # libsrtp refuses a second ANY_INBOUND template on the same
                # session, so it must not be re-added when the SSRC changes.
                print("using key [%s]" % self.srtpkey)
                srtpkey = base64.b64decode(self.srtpkey)
                plc = Policy(key=srtpkey,
                             ssrc_value=rtp.SSRC,
                             ssrc_type=Policy.SSRC_ANY_INBOUND)
                print(plc)
                self.srtp_session = Session(policy=plc)
            self.ssrc = rtp.SSRC
            try:
                rtp_raw_full = self.srtp_session.unprotect(rtp_raw_full)
            except Exception:
                # NOTE: a fallback key (resrtpkey) is not tried here.
                print("decrypt fail seq={sequence}, ssrc={ssrc}".format(
                    sequence=rtp.SEQUENCE_NUMBER, ssrc=rtp.SSRC))
                return False

        # Add bitstream header
        if self.ogg.get_curr_bitstream() != rtp.SSRC:
            self.write_stream_header(rtp.SSRC)
            self.write_stream_comment('hex_to_opus', [str(rtp)])

        rtp_payload = rtp_raw_full[calc_rtp_header_len:]
        self.ogg.write_page(
            rtp_payload,
            is_data=True,
            is_last=is_last,
            ptime=20,  # By default the ptime=20
            pageno=rtp.SEQUENCE_NUMBER)
        return True
Exemple #12
0
    def test_add_remove_stream(self):
        """A stream added to an empty session decrypts, and can be removed only once."""
        ssrc = 12345

        def specific_policy():
            return Policy(key=KEY,
                          ssrc_type=Policy.SSRC_SPECIFIC,
                          ssrc_value=ssrc)

        # sender: session built directly from the policy
        sender = Session(policy=specific_policy())
        encrypted = sender.protect(RTP)
        self.assertEqual(len(encrypted), 182)

        # receiver: empty session, stream registered afterwards
        receiver = Session()
        receiver.add_stream(specific_policy())
        decrypted = receiver.unprotect(encrypted)
        self.assertEqual(len(decrypted), 172)
        self.assertEqual(decrypted, RTP)

        # the first removal succeeds ...
        receiver.remove_stream(ssrc)

        # ... and a second one finds nothing left to remove
        with self.assertRaises(Error) as ctx:
            receiver.remove_stream(ssrc)
        self.assertEqual(str(ctx.exception), 'no appropriate context found')
Exemple #13
0
    def test_no_key(self):
        """Creating a session from a policy that has no key is rejected."""
        keyless = Policy(ssrc_type=Policy.SSRC_ANY_OUTBOUND)

        with self.assertRaises(Error) as ctx:
            Session(policy=keyless)
        self.assertEqual(str(ctx.exception), 'unsupported parameter')
Exemple #14
0
class DtlsSrtpSession:
    """
    DTLS-SRTP endpoint running over a datagram ``transport``.

    OpenSSL drives DTLS through two memory BIOs: whatever it writes to
    ``write_bio`` is flushed to ``transport``, and datagrams received from
    ``transport`` are written into ``read_bio``.  After the handshake, SRTP
    keys are exported from the DTLS session, and incoming datagrams are
    demultiplexed by first byte into the ``data`` channel (DTLS application
    data) and the ``rtp`` channel (decrypted RTP/RTCP).

    :param context: object whose ``ctx`` attribute is the OpenSSL ``SSL_CTX``
    :param is_server: True to take the DTLS accept (server) role
    :param transport: object providing ``recv()`` and ``send(data)`` coroutines
    """
    def __init__(self, context, is_server, transport):
        self.encrypted = False
        self.is_server = is_server
        self.remote_fingerprint = None
        self.transport = transport

        self.data_queue = asyncio.Queue()
        self.data = Channel(recv=self.data_queue.get, send=self._send_data)

        self.rtp_queue = asyncio.Queue()
        self.rtp = Channel(recv=self.rtp_queue.get, send=self._send_rtp)

        ssl = lib.SSL_new(context.ctx)
        self.ssl = ffi.gc(ssl, lib.SSL_free)

        self.read_bio = lib.BIO_new(lib.BIO_s_mem())
        self.write_bio = lib.BIO_new(lib.BIO_s_mem())
        lib.SSL_set_bio(self.ssl, self.read_bio, self.write_bio)

        if self.is_server:
            lib.SSL_set_accept_state(self.ssl)
        else:
            lib.SSL_set_connect_state(self.ssl)

    async def connect(self):
        """
        Run the DTLS handshake, check the peer fingerprint and set up SRTP.

        ``self.remote_fingerprint`` must be set before calling.

        :raises Exception: on handshake failure, fingerprint mismatch or when
            the SRTP keying material cannot be exported
        """
        while not self.encrypted:
            result = lib.SSL_do_handshake(self.ssl)
            if result > 0:
                self.encrypted = True
                break

            error = lib.SSL_get_error(self.ssl, result)

            await self._write_ssl()

            if error == lib.SSL_ERROR_WANT_READ:
                data = await self.transport.recv()
                lib.BIO_write(self.read_bio, data, len(data))
            else:
                raise Exception('DTLS handshake failed (error %d)' % error)

        await self._write_ssl()

        # check remote fingerprint
        # NOTE(review): SSL_get_peer_certificate returns a new reference that
        # OpenSSL expects to be released with X509_free; confirm whether
        # certificate_digest does so.
        x509 = lib.SSL_get_peer_certificate(self.ssl)
        remote_fingerprint = certificate_digest(x509)
        if remote_fingerprint != self.remote_fingerprint.upper():
            raise Exception('DTLS fingerprint does not match')

        # generate keying material
        buf = ffi.new("char[]", 2 * (SRTP_KEY_LEN + SRTP_SALT_LEN))
        extractor = b'EXTRACTOR-dtls_srtp'
        if not lib.SSL_export_keying_material(self.ssl, buf,
                                              len(buf), extractor,
                                              len(extractor), ffi.NULL, 0, 0):
            raise Exception('DTLS could not extract SRTP keying material')

        # the server transmits with the second key/salt, the client with the first
        view = ffi.buffer(buf)
        if self.is_server:
            srtp_tx_key = get_srtp_key_salt(view, 1)
            srtp_rx_key = get_srtp_key_salt(view, 0)
        else:
            srtp_tx_key = get_srtp_key_salt(view, 0)
            srtp_rx_key = get_srtp_key_salt(view, 1)

        logger.info('DTLS handshake complete')
        rx_policy = Policy(key=srtp_rx_key, ssrc_type=Policy.SSRC_ANY_INBOUND)
        self._rx_srtp = Session(rx_policy)
        tx_policy = Policy(key=srtp_tx_key, ssrc_type=Policy.SSRC_ANY_OUTBOUND)
        self._tx_srtp = Session(tx_policy)

    async def run(self):
        """
        Receive datagrams forever and dispatch them by first byte.

        Bytes 20-63 are DTLS records (application data goes to
        ``data_queue``); bytes 128-191 are SRTP/SRTCP packets (decrypted and
        put on ``rtp_queue``).  Other datagrams are ignored.
        """
        while True:
            data = await self.transport.recv()
            if not data:
                # an empty datagram has no first byte to demultiplex on
                continue
            first_byte = data[0]
            if 19 < first_byte < 64:
                # DTLS
                lib.BIO_write(self.read_bio, data, len(data))
                buf = ffi.new("char[]", 1500)
                result = lib.SSL_read(self.ssl, buf, len(buf))
                # SSL_read returns <= 0 when no application data was produced;
                # slicing with a negative result would queue garbage bytes.
                if result > 0:
                    await self.data_queue.put(ffi.buffer(buf)[0:result])
            elif 127 < first_byte < 192:
                # SRTP / SRTCP
                try:
                    if is_rtcp(data):
                        data = self._rx_srtp.unprotect_rtcp(data)
                    else:
                        data = self._rx_srtp.unprotect(data)
                except Exception as exc:
                    # one undecryptable packet must not stop the receive loop
                    logger.debug('SRTP unprotect failed: %s', exc)
                    continue
                await self.rtp_queue.put(data)

    async def _send_data(self, data):
        """Encrypt application ``data`` with DTLS and send it."""
        lib.SSL_write(self.ssl, data, len(data))
        await self._write_ssl()

    async def _send_rtp(self, data):
        """Protect an RTP or RTCP packet with SRTP and send it."""
        if is_rtcp(data):
            data = self._tx_srtp.protect_rtcp(data)
        else:
            data = self._tx_srtp.protect(data)
        await self.transport.send(data)

    async def _write_ssl(self):
        """Flush whatever OpenSSL wrote into ``write_bio`` to the transport."""
        pending = lib.BIO_ctrl_pending(self.write_bio)
        if pending > 0:
            buf = ffi.new("char[]", pending)
            result = lib.BIO_read(self.write_bio, buf, len(buf))
            # copy only the bytes actually read, in one step
            if result > 0:
                await self.transport.send(ffi.buffer(buf)[0:result])
Exemple #15
0
class RTCDtlsTransport(EventEmitter):
    """
    The :class:`RTCDtlsTransport` object includes information relating to
    Datagram Transport Layer Security (DTLS) transport.

    :param: transport: An :class:`RTCIceTransport`.
    :param: certificates: A list of :class:`RTCCertificate` (only one is allowed currently).
    """
    def __init__(self, transport, certificates):
        assert len(certificates) == 1
        certificate = certificates[0]

        super().__init__()
        self.closed = asyncio.Event()
        self.encrypted = False
        self._role = 'auto'
        self._rtp_mid_header_id = None
        self._rtp_router = RtpRouter()
        self._start = None
        self._state = State.NEW
        self._transport = transport

        self.data_queue = asyncio.Queue()
        self.data = Channel(closed=self.closed, queue=self.data_queue)

        # SSL init
        self.__ctx = create_ssl_context(certificate)

        ssl = lib.SSL_new(self.__ctx)
        self.ssl = ffi.gc(ssl, lib.SSL_free)

        # OpenSSL reads incoming records from read_bio and writes outgoing
        # records to write_bio; the two 1500-byte buffers are reused by
        # _recv_next (SSL_read) and _write_ssl (BIO_read).
        self.read_bio = lib.BIO_new(lib.BIO_s_mem())
        self.read_cdata = ffi.new('char[]', 1500)
        self.write_bio = lib.BIO_new(lib.BIO_s_mem())
        self.write_cdata = ffi.new('char[]', 1500)
        lib.SSL_set_bio(self.ssl, self.read_bio, self.write_bio)

        self.__local_parameters = RTCDtlsParameters(
            fingerprints=certificate.getFingerprints())

    @property
    def state(self):
        """
        The current state of the DTLS transport.
        """
        return str(self._state)[6:].lower()

    @property
    def transport(self):
        """
        The associated :class:`RTCIceTransport` instance.
        """
        return self._transport

    def getLocalParameters(self):
        """
        Get the local parameters of the DTLS transport.

        :rtype: :class:`RTCDtlsParameters`
        """
        return self.__local_parameters

    async def start(self, remoteParameters):
        """
        Start DTLS transport negotiation with the parameters of the remote
        DTLS transport.

        A call made while negotiation is already in progress waits for that
        negotiation to finish, whether it succeeds or fails.

        :param: remoteParameters: An :class:`RTCDtlsParameters`.
        """
        assert self._state not in [State.CLOSED, State.FAILED]
        assert len(remoteParameters.fingerprints)

        # handle the case where start is already in progress
        if self._start is not None:
            return await self._start.wait()
        self._start = asyncio.Event()

        try:
            # the ICE controlling side acts as DTLS server
            if self.transport.role == 'controlling':
                self._role = 'server'
                lib.SSL_set_accept_state(self.ssl)
            else:
                self._role = 'client'
                lib.SSL_set_connect_state(self.ssl)

            self._set_state(State.CONNECTING)
            while not self.encrypted:
                result = lib.SSL_do_handshake(self.ssl)
                await self._write_ssl()

                if result > 0:
                    self.encrypted = True
                    break

                error = lib.SSL_get_error(self.ssl, result)
                if error == lib.SSL_ERROR_WANT_READ:
                    await self._recv_next()
                else:
                    self._set_state(State.FAILED)
                    raise DtlsError('DTLS handshake failed (error %d)' % error)

            # check remote fingerprint
            # NOTE(review): SSL_get_peer_certificate returns a new reference
            # that OpenSSL expects to be released with X509_free; confirm
            # whether certificate_digest does so.
            x509 = lib.SSL_get_peer_certificate(self.ssl)
            remote_fingerprint = certificate_digest(x509)
            fingerprint_is_valid = False
            for f in remoteParameters.fingerprints:
                if f.algorithm == 'sha-256' and f.value.lower(
                ) == remote_fingerprint.lower():
                    fingerprint_is_valid = True
                    break
            if not fingerprint_is_valid:
                self._set_state(State.FAILED)
                raise DtlsError('DTLS fingerprint does not match')

            # generate keying material
            buf = ffi.new('unsigned char[]',
                          2 * (SRTP_KEY_LEN + SRTP_SALT_LEN))
            extractor = b'EXTRACTOR-dtls_srtp'
            _openssl_assert(
                lib.SSL_export_keying_material(self.ssl, buf, len(
                    buf), extractor, len(extractor), ffi.NULL, 0, 0) == 1)

            # the server transmits with the second key/salt, the client with
            # the first
            view = ffi.buffer(buf)
            if self._role == 'server':
                srtp_tx_key = get_srtp_key_salt(view, 1)
                srtp_rx_key = get_srtp_key_salt(view, 0)
            else:
                srtp_tx_key = get_srtp_key_salt(view, 0)
                srtp_rx_key = get_srtp_key_salt(view, 1)

            rx_policy = Policy(key=srtp_rx_key,
                               ssrc_type=Policy.SSRC_ANY_INBOUND)
            rx_policy.allow_repeat_tx = True
            rx_policy.window_size = 1024
            self._rx_srtp = Session(rx_policy)

            tx_policy = Policy(key=srtp_tx_key,
                               ssrc_type=Policy.SSRC_ANY_OUTBOUND)
            tx_policy.allow_repeat_tx = True
            tx_policy.window_size = 1024
            self._tx_srtp = Session(tx_policy)

            # start data pump
            self.__log_debug('- DTLS handshake complete')
            self._set_state(State.CONNECTED)
            asyncio.ensure_future(self.__run())
        finally:
            # Release concurrent start() callers even when negotiation fails;
            # otherwise they would wait on this event forever.
            self._start.set()

    async def stop(self):
        """
        Stop and close the DTLS transport.
        """
        if self._state in [State.CONNECTING, State.CONNECTED]:
            lib.SSL_shutdown(self.ssl)
            try:
                await self._write_ssl()
            except ConnectionError:
                pass
            self.__log_debug('- DTLS shutdown complete')
            self.closed.set()

    async def __run(self):
        # receive loop: runs until the transport is closed
        try:
            while True:
                await self._recv_next()
        except ConnectionError:
            pass
        finally:
            self._set_state(State.CLOSED)
            self.closed.set()

    async def _handle_rtcp_data(self, data):
        """Route each packet of a decrypted RTCP compound to its receiver."""
        packets = RtcpPacket.parse(data)
        for packet in packets:
            receiver = None
            if hasattr(packet, 'ssrc'):
                # SR and RR
                receiver = self._rtp_router.route(packet.ssrc)
            elif getattr(packet, 'chunks', None):
                # SDES
                receiver = self._rtp_router.route(packet.chunks[0].ssrc)
            elif getattr(packet, 'sources', None):
                # BYE
                receiver = self._rtp_router.route(packet.sources[0])
            if receiver is not None:
                await receiver._handle_rtcp_packet(packet)

    async def _handle_rtp_data(self, data):
        """Route a decrypted RTP packet by SSRC and, if present, MID extension."""
        packet = RtpPacket.parse(data)

        # get muxId from RTP header extensions
        mid = None
        for x_id, x_value in get_header_extensions(packet):
            if x_id == self._rtp_mid_header_id:
                mid = x_value.decode('utf8')
                break

        # route RTP packet
        receiver = self._rtp_router.route(packet.ssrc, mid=mid)
        if receiver is not None:
            await receiver._handle_rtp_packet(packet)

    async def _recv_next(self):
        """
        Receive and process one datagram, or handle a DTLS timeout.

        :raises ConnectionError: when the transport is closed locally or the
            remote party shuts DTLS down
        """
        # get timeout
        ptv_sec = ffi.new('time_t *')
        ptv_usec = ffi.new('long *')
        if lib.Cryptography_DTLSv1_get_timeout(self.ssl, ptv_sec, ptv_usec):
            timeout = ptv_sec[0] + (ptv_usec[0] / 1000000)
        else:
            timeout = None

        try:
            data = await first_completed(self.transport._connection.recv(),
                                         self.closed.wait(),
                                         timeout=timeout)
        except TimeoutError:
            self.__log_debug('x DTLS handling timeout')
            lib.DTLSv1_handle_timeout(self.ssl)
            await self._write_ssl()
            return

        if data is True:
            # session was closed
            raise ConnectionError

        if not data:
            # an empty datagram has no first byte to demultiplex on
            return

        first_byte = data[0]
        if first_byte > 19 and first_byte < 64:
            # DTLS
            lib.BIO_write(self.read_bio, data, len(data))
            result = lib.SSL_read(self.ssl, self.read_cdata,
                                  len(self.read_cdata))
            await self._write_ssl()
            if result == 0:
                self.__log_debug('- DTLS shutdown by remote party')
                raise ConnectionError
            elif result > 0:
                await self.data_queue.put(
                    ffi.buffer(self.read_cdata)[0:result])
        elif first_byte > 127 and first_byte < 192:
            # SRTP / SRTCP
            try:
                if is_rtcp(data):
                    data = self._rx_srtp.unprotect_rtcp(data)
                    await self._handle_rtcp_data(data)
                else:
                    data = self._rx_srtp.unprotect(data)
                    await self._handle_rtp_data(data)
            except pylibsrtp.Error as exc:
                self.__log_debug('x SRTP unprotect failed: %s', exc)

    def _register_rtp_receiver(self, receiver, parameters):
        """Register an RTP receiver and remember the MID header extension id."""
        # make note of the RTP header extension used for muxId
        for ext in parameters.headerExtensions:
            if ext.uri == 'urn:ietf:params:rtp-hdrext:sdes:mid':
                self._rtp_mid_header_id = ext.id

        self._rtp_router.register(receiver, parameters)

    async def _send_data(self, data):
        """Send application data over DTLS; requires a connected transport."""
        if self._state != State.CONNECTED:
            raise ConnectionError('Cannot send encrypted data, not connected')

        lib.SSL_write(self.ssl, data, len(data))
        await self._write_ssl()

    async def _send_rtp(self, data):
        """Protect an RTP/RTCP packet with SRTP and send it."""
        if self._state != State.CONNECTED:
            raise ConnectionError('Cannot send encrypted RTP, not connected')

        if is_rtcp(data):
            data = self._tx_srtp.protect_rtcp(data)
        else:
            data = self._tx_srtp.protect(data)
        await self.transport._connection.send(data)

    def _set_state(self, state):
        """Change state, emitting ``statechange`` only on an actual change."""
        if state != self._state:
            self.__log_debug('- %s -> %s', self._state, state)
            self._state = state
            self.emit('statechange')

    async def _write_ssl(self):
        """
        Flush outgoing data which OpenSSL put in our BIO to the transport.
        """
        pending = lib.BIO_ctrl_pending(self.write_bio)
        if pending > 0:
            result = lib.BIO_read(self.write_bio, self.write_cdata,
                                  len(self.write_cdata))
            await self.transport._connection.send(
                ffi.buffer(self.write_cdata)[0:result])

    def __log_debug(self, msg, *args):
        logger.debug(self._role + ' ' + msg, *args)

class srtp_ogg_opus_coder(ogg_opus_coder):
    '''
    Allow record srtp/rtp stream
    '''
    def __init__(self,
                 override_payload_offset=None,
                 rtp_offset=0,
                 srtpkey=None,
                 resrtpkey=None,
                 verbose=False,
                 payload_type=111,
                 filter_ssrc=None):
        '''
        Configure how hexdumped RTP/SRTP packets are located, filtered and decrypted.

        :param override_payload_offset: when set, offset of the RTP payload
            counted from the start of the dumped packet; it replaces the
            header length computed from the RTP header
        :param rtp_offset: number of dumped bytes preceding the RTP header
        :param srtpkey: base64-encoded SRTP key; enables SRTP processing
        :param resrtpkey: alternate base64-encoded key (stored only)
        :param verbose: print debugging output
        :param payload_type: expected Opus RTP payload type; a falsy value
            disables payload-type filtering
        :param filter_ssrc: when set, only packets with this SSRC are accepted
        '''
        ogg_opus_coder.__init__(self, verbose=verbose)

        # RTP framing and filtering
        self.rtp_offset = rtp_offset
        self.override_payload_offset = override_payload_offset
        self.payload_type = payload_type
        self.filter_ssrc = filter_ssrc

        # SRTP configuration; the session and the tracked SSRC start empty
        self.srtpkey, self.resrtpkey = srtpkey, resrtpkey
        self.srtp_session, self.ssrc = None, None

    def record_rtp_packet(self, packet, is_last=False):
        '''
        Decode one hexdumped RTP/SRTP packet and append its payload to the ogg stream.

        :param packet: sequence of hex strings that concatenate to the packet
            bytes (``record_rtp_file`` passes one string per byte); the first
            ``self.rtp_offset`` entries precede the RTP header and are ignored
        :param is_last: passed through to ``self.ogg.write_page``
        :returns: True when an ogg page was written, False when the packet was
            skipped (not hex, too short, not RTP v2, filtered out by payload
            type or SSRC, empty payload, or SRTP decryption failure)
        '''
        assert self.ogg is not None

        if not packet:
            print("No packet")
            return False

        if self.verbose:
            print("Dumping packet ", packet)

        try:
            rtp_raw_full = bytes.fromhex(''.join(packet[self.rtp_offset:]))
        except (TypeError, ValueError):
            print("Failed to get hex", packet)
            return False

        if len(rtp_raw_full) < RTP_FIXED_HEADER:
            print("udp payload is too small")
            return False

        if self.verbose:
            print(self.rtp_offset, packet[self.rtp_offset - 1],
                  packet[self.rtp_offset], packet[self.rtp_offset + 1])

        # decode the fixed RTP header
        rtp_raw_header = rtp_raw_full[:RTP_FIXED_HEADER]
        if self.verbose:
            print(rtp_raw_header.hex())
        rtp = RTP_HEADER._make(struct.unpack(RTP_HEADER_FMT, rtp_raw_header))
        if self.verbose:
            print(rtp)

        # Filter the RTP with v=2
        if (rtp.FIRST & 0b11000000) != 0b10000000:
            print("Not an RTP")
            return False

        has_extension = rtp.FIRST & 0b10000
        csrc_count = rtp.FIRST & 0b1111

        # header length = fixed header + CSRC list (+ extension header/body)
        calc_rtp_header_len = RTP_FIXED_HEADER + csrc_count * 4
        if has_extension:
            exten_raw = rtp_raw_full[calc_rtp_header_len:
                                     calc_rtp_header_len + 4]
            if len(exten_raw) != 4:
                print("Skipping malformed RTP")
                return False
            # 16-bit profile, then the extension length in 32-bit words
            _exten_profile, exten_words = struct.unpack('>HH', exten_raw)
            calc_rtp_header_len += 4 + exten_words * 4

        if self.verbose:
            print("calc_rtp_header_len", calc_rtp_header_len)

        if self.override_payload_offset:
            calc_rtp_header_len = self.override_payload_offset - self.rtp_offset

        # Filter opus.  The second header byte also carries the marker bit in
        # its MSB; compare only the 7-bit payload type so marker-flagged
        # packets are not rejected.
        payload_type = rtp.PAYLOAD_TYPE & 0b01111111
        if self.payload_type and payload_type != self.payload_type:
            print(
                "Skipping payload {rtp_payload_type} while {opus_payload_type} expected."
                .format(rtp_payload_type=payload_type,
                        opus_payload_type=self.payload_type))
            return False

        if self.filter_ssrc and self.filter_ssrc != rtp.SSRC:
            print(
                "Skipping ssrc={rtp_ssrc} while {filter_ssrc} expected".format(
                    rtp_ssrc=rtp.SSRC, filter_ssrc=self.filter_ssrc))
            return False

        if len(rtp_raw_full) < (calc_rtp_header_len + 1):
            print("Empty payload")
            return False

        if self.srtpkey:
            if not self.srtp_session:
                # An SSRC_ANY_INBOUND policy is a template applied to every
                # incoming SSRC, so a single session serves all streams.
                # libsrtp refuses a second ANY_INBOUND template on the same
                # session, so it must not be re-added when the SSRC changes.
                print("using key [%s]" % self.srtpkey)
                srtpkey = base64.b64decode(self.srtpkey)
                plc = Policy(key=srtpkey,
                             ssrc_value=rtp.SSRC,
                             ssrc_type=Policy.SSRC_ANY_INBOUND)
                print(plc)
                self.srtp_session = Session(policy=plc)
            self.ssrc = rtp.SSRC
            try:
                rtp_raw_full = self.srtp_session.unprotect(rtp_raw_full)
            except Exception:
                # NOTE: the alternate key (self.resrtpkey) is not tried here.
                print("decrypt fail seq={sequence}, ssrc={ssrc}".format(
                    sequence=rtp.SEQUENCE_NUMBER, ssrc=rtp.SSRC))
                return False

        # Add bitstream header
        if self.ogg.get_curr_bitstream() != rtp.SSRC:
            self.write_stream_header(rtp.SSRC)
            self.write_stream_comment('hex_to_opus', [str(rtp)])

        rtp_payload = rtp_raw_full[calc_rtp_header_len:]
        self.ogg.write_page(
            rtp_payload,
            is_data=True,
            is_last=is_last,
            ptime=20,  # By default the ptime=20
            pageno=rtp.SEQUENCE_NUMBER)
        return True

    def record_rtp_file(self, infile, outfile):
        '''
        Convert an RTP hexdump into a recorded file
        '''
        self.start_file(outfile)
        re_valid_hex = re.compile(r'^[0-9A-Fa-f]{1,8}\s{1,4}[0-9A-Fa-f]{2}')
        with open(infile, 'r') as rtp_fd:
            packet_counter = 0
            success_counter = 0
            packet = []
            for xline in rtp_fd:
                if not xline or not re_valid_hex.match(xline):
                    if packet:
                        packet_counter += 1
                        if self.record_rtp_packet(packet):
                            success_counter += 1
                        packet = []
                else:
                    content = xline.split()
                    #print(len(content))
                    content.pop(0)  # skip the segment column
                    packet.extend(content)

            if packet:
                self.record_rtp_packet(packet, is_last=True)

            print("Written %d out of %d packets" %
                  (success_counter, packet_counter))
        self.end_file()

    def hexdump(self, content):
        '''
        Format ``content`` as hexdump text in the layout ``record_rtp_file`` parses.

        Each line is a 6-digit hex offset followed by up to 16 space-separated
        hex bytes.  The dump ends with an empty line, which delimits packets.
        A length that is a multiple of 16 no longer produces a trailing
        offset-only line.

        :param content: bytes-like object to dump
        :returns: the hexdump text
        '''
        lines = []
        for start in range(0, len(content), 16):
            chunk = content[start:start + 16]
            lines.append(' '.join(['%06x' % start] +
                                  ['%02x' % byte for byte in chunk]))
        lines.append('\n')
        return '\n'.join(lines)

    def playback_as_srtp_stream(self, infile, outfile, starting_sequence=0):
        '''
        Read an Ogg file and write its pages as RTP (or SRTP) packets in
        hex-dump form.

        Page 0 is decoded and printed as the Opus identification header.
        Page 1 is only printed. Every later page becomes one packet built from
        ``RTP_HEADER(0x80, self.payload_type, sequence, header.GRANULE_POS,
        header.BITSTREAM)`` followed by the page content, so the page's
        bitstream serial is used as the SSRC. When ``self.srtpkey`` (a base64
        key) is set, the packet is protected with an SRTP session that is
        re-created whenever the SSRC changes.

        :param infile: path of the Ogg file to read.
        :param outfile: path of the text file that receives the hex dumps.
        :param starting_sequence: sequence number of the first packet; it
            wraps to 0 after ``MAX_RTP_SEQUENCE_NUM``.
        '''
        # Both files are closed when playback finishes, including on error.
        with open(infile, 'rb') as ogg_fd, open(outfile, 'w') as hex_fd:
            self.ogg.reset(ogg_fd)
            page_counter = 0
            sequence_counter = starting_sequence
            for header, content in self.ogg:
                print(header)
                if 0 == page_counter:
                    # unpack_from ignores bytes beyond the format's fixed
                    # fields (e.g. an Opus channel mapping table), where
                    # struct.unpack would fail on the length mismatch.
                    opus_identity_head = OPUS_IDENTITY_HEADER._make(
                        struct.unpack_from(OPUS_IDENTITY_HEADER_FMT, content))
                    print(opus_identity_head)
                else:
                    print(' '.join([hex(x) for x in content]))
                    if page_counter > 1:
                        # make an RTP packet
                        rtp = RTP_HEADER(0x80, self.payload_type,
                                         sequence_counter, header.GRANULE_POS,
                                         header.BITSTREAM)
                        rtp_raw_full = struct.pack(RTP_HEADER_FMT,
                                                   *rtp) + content

                        if self.srtpkey:
                            if self.ssrc != rtp.SSRC:
                                # new stream: (re)key the SRTP session
                                print("using key [%s]" % self.srtpkey)
                                srtpkey = base64.b64decode(self.srtpkey)
                                plc = Policy(
                                    key=srtpkey,
                                    ssrc_value=rtp.SSRC,
                                    ssrc_type=Policy.SSRC_ANY_OUTBOUND)
                                print(plc)
                                self.srtp_session = Session(policy=plc)
                                self.ssrc = rtp.SSRC
                            try:
                                rtp_raw_full = self.srtp_session.protect(
                                    rtp_raw_full)
                            except Exception:
                                # Best effort: report the failure and go on.
                                # NOTE(review): the unprotected packet is
                                # still written below; confirm this is wanted.
                                print(
                                    "encrypt fail seq={sequence}, ssrc={ssrc}".
                                    format(sequence=rtp.SEQUENCE_NUMBER,
                                           ssrc=rtp.SSRC))

                        hex_fd.write(self.hexdump(rtp_raw_full))
                        if sequence_counter == MAX_RTP_SEQUENCE_NUM:
                            sequence_counter = 0
                        else:
                            sequence_counter += 1

                page_counter += 1
                print("Read %d pages" % page_counter)
Exemple #17
0
    async def start(self, remoteParameters):
        """
        Start DTLS transport negotiation with the parameters of the remote
        DTLS transport.

        The DTLS role follows the ICE transport role: ``'controlling'`` acts
        as the DTLS server (accept state), anything else as the client
        (connect state). After the handshake, the peer certificate digest must
        match one of the ``sha-256`` fingerprints in *remoteParameters*. SRTP
        send/receive sessions are then keyed from the exported keying
        material and the receive loop is started.

        If a start is already in progress (or finished), the call waits for it
        to complete instead of starting a new negotiation.

        :param: remoteParameters: An :class:`RTCDtlsParameters`.
        :raises DtlsError: if the handshake fails or no fingerprint matches.
            Concurrent callers also get ``DtlsError`` if the negotiation they
            waited for did not leave the transport connected.
        """
        assert self._state not in [State.CLOSED, State.FAILED]
        assert len(remoteParameters.fingerprints)

        # handle the case where start is already in progress
        if self._start is not None:
            await self._start.wait()
            if self._state != State.CONNECTED:
                raise DtlsError('DTLS handshake failed')
            return True
        self._start = asyncio.Event()

        try:
            if self.transport.role == 'controlling':
                self._role = 'server'
                lib.SSL_set_accept_state(self.ssl)
            else:
                self._role = 'client'
                lib.SSL_set_connect_state(self.ssl)

            self._set_state(State.CONNECTING)
            while not self.encrypted:
                result = lib.SSL_do_handshake(self.ssl)
                await self._write_ssl()

                if result > 0:
                    self.encrypted = True
                    break

                error = lib.SSL_get_error(self.ssl, result)
                if error == lib.SSL_ERROR_WANT_READ:
                    await self._recv_next()
                else:
                    self._set_state(State.FAILED)
                    raise DtlsError('DTLS handshake failed (error %d)' % error)

            # check remote fingerprint. SSL_get_peer_certificate returns a new
            # reference, so release it with X509_free when collected.
            x509 = lib.SSL_get_peer_certificate(self.ssl)
            if x509 == ffi.NULL:
                self._set_state(State.FAILED)
                raise DtlsError('DTLS fingerprint does not match')
            x509 = ffi.gc(x509, lib.X509_free)
            remote_fingerprint = certificate_digest(x509).lower()
            # hash names and hex digests are compared case-insensitively
            fingerprint_is_valid = any(
                f.algorithm.lower() == 'sha-256' and
                f.value.lower() == remote_fingerprint
                for f in remoteParameters.fingerprints)
            if not fingerprint_is_valid:
                self._set_state(State.FAILED)
                raise DtlsError('DTLS fingerprint does not match')

            # generate keying material
            buf = ffi.new('unsigned char[]',
                          2 * (SRTP_KEY_LEN + SRTP_SALT_LEN))
            extractor = b'EXTRACTOR-dtls_srtp'
            _openssl_assert(
                lib.SSL_export_keying_material(self.ssl, buf, len(
                    buf), extractor, len(extractor), ffi.NULL, 0, 0) == 1)

            # the server sends with the second key/salt and receives with the
            # first; the client does the opposite
            view = ffi.buffer(buf)
            if self._role == 'server':
                srtp_tx_key = get_srtp_key_salt(view, 1)
                srtp_rx_key = get_srtp_key_salt(view, 0)
            else:
                srtp_tx_key = get_srtp_key_salt(view, 0)
                srtp_rx_key = get_srtp_key_salt(view, 1)

            rx_policy = Policy(key=srtp_rx_key,
                               ssrc_type=Policy.SSRC_ANY_INBOUND)
            rx_policy.allow_repeat_tx = True
            rx_policy.window_size = 1024
            self._rx_srtp = Session(rx_policy)

            tx_policy = Policy(key=srtp_tx_key,
                               ssrc_type=Policy.SSRC_ANY_OUTBOUND)
            tx_policy.allow_repeat_tx = True
            tx_policy.window_size = 1024
            self._tx_srtp = Session(tx_policy)

            # start data pump
            self.__log_debug('- DTLS handshake complete')
            self._set_state(State.CONNECTED)
            asyncio.ensure_future(self.__run())
        finally:
            # Wake concurrent callers even when negotiation failed; otherwise
            # they would wait forever on an event that is never set.
            self._start.set()
Exemple #18
0
class DtlsSrtpSession:
    """
    DTLS-SRTP session running over a datagram ``transport``.

    OpenSSL is driven through a pair of memory BIOs. Received datagrams are
    written into ``read_bio``, and whatever OpenSSL queues in ``write_bio`` is
    flushed to the transport by :meth:`_write_ssl`. After :meth:`connect`
    completes, SRTP keys are derived from the DTLS keying material and two
    channels are usable:

    * ``data`` -- application data carried inside DTLS records;
    * ``rtp`` -- RTP/RTCP packets protected with SRTP/SRTCP.
    """

    def __init__(self, context, is_server, transport):
        """
        :param context: object whose ``ctx`` attribute is passed to
            ``SSL_new``.
        :param is_server: if true, act as the DTLS server (accept state);
            otherwise act as the client (connect state).
        :param transport: object providing awaitable ``recv()`` and
            ``send(data)``.
        """
        self.closed = asyncio.Event()
        self.encrypted = False
        self.is_server = is_server
        # expected peer certificate digest; must be set before connect()
        self.remote_fingerprint = None
        self.role = 'server' if self.is_server else 'client'
        self.state = self.State.CLOSED
        self.transport = transport

        self.data_queue = asyncio.Queue()
        self.data = Channel(closed=self.closed,
                            queue=self.data_queue,
                            send=self._send_data)

        self.rtp_queue = asyncio.Queue()
        self.rtp = Channel(closed=self.closed,
                           queue=self.rtp_queue,
                           send=self._send_rtp)

        ssl = lib.SSL_new(context.ctx)
        self.ssl = ffi.gc(ssl, lib.SSL_free)

        # Memory BIOs plus 1500-byte scratch buffers (presumably sized for one
        # MTU-sized datagram). SSL_set_bio hands the BIOs to the SSL object.
        self.read_bio = lib.BIO_new(lib.BIO_s_mem())
        self.read_cdata = ffi.new('char[]', 1500)
        self.write_bio = lib.BIO_new(lib.BIO_s_mem())
        self.write_cdata = ffi.new('char[]', 1500)
        lib.SSL_set_bio(self.ssl, self.read_bio, self.write_bio)

        if self.is_server:
            lib.SSL_set_accept_state(self.ssl)
        else:
            lib.SSL_set_connect_state(self.ssl)

        # local fingerprint (SSL_get_certificate does not add a reference)
        x509 = lib.SSL_get_certificate(self.ssl)
        self.local_fingerprint = certificate_digest(x509)

    async def close(self):
        """
        Shut down the DTLS session (unless already closed) and signal
        ``closed``.
        """
        if self.state != self.State.CLOSED:
            lib.SSL_shutdown(self.ssl)
            await self._write_ssl()
            logger.debug('%s - DTLS shutdown complete', self.role)
            self.closed.set()

    async def connect(self):
        """
        Run the DTLS handshake, verify the peer fingerprint and key SRTP.

        On success the state becomes CONNECTED and a background task starts
        reading from the transport.

        :raises DtlsError: if the handshake fails, or if the peer certificate
            is missing or its digest does not match ``remote_fingerprint``.
        """
        assert self.state == self.State.CLOSED

        self._set_state(self.State.CONNECTING)
        while not self.encrypted:
            result = lib.SSL_do_handshake(self.ssl)
            await self._write_ssl()

            if result > 0:
                self.encrypted = True
                break

            error = lib.SSL_get_error(self.ssl, result)
            if error == lib.SSL_ERROR_WANT_READ:
                await self._recv_next()
            else:
                raise DtlsError('DTLS handshake failed (error %d)' % error)

        # check remote fingerprint. SSL_get_peer_certificate returns a new
        # reference, so release it with X509_free when collected.
        x509 = lib.SSL_get_peer_certificate(self.ssl)
        if x509 == ffi.NULL or self.remote_fingerprint is None:
            raise DtlsError('DTLS fingerprint does not match')
        x509 = ffi.gc(x509, lib.X509_free)
        remote_fingerprint = certificate_digest(x509)
        # hex digests are compared case-insensitively
        if remote_fingerprint.upper() != self.remote_fingerprint.upper():
            raise DtlsError('DTLS fingerprint does not match')

        # generate keying material
        buf = ffi.new('unsigned char[]', 2 * (SRTP_KEY_LEN + SRTP_SALT_LEN))
        extractor = b'EXTRACTOR-dtls_srtp'
        _openssl_assert(
            lib.SSL_export_keying_material(self.ssl, buf, len(
                buf), extractor, len(extractor), ffi.NULL, 0, 0) == 1)

        # the server sends with the second key/salt and receives with the
        # first; the client does the opposite
        view = ffi.buffer(buf)
        if self.is_server:
            srtp_tx_key = get_srtp_key_salt(view, 1)
            srtp_rx_key = get_srtp_key_salt(view, 0)
        else:
            srtp_tx_key = get_srtp_key_salt(view, 0)
            srtp_rx_key = get_srtp_key_salt(view, 1)

        rx_policy = Policy(key=srtp_rx_key, ssrc_type=Policy.SSRC_ANY_INBOUND)
        self._rx_srtp = Session(rx_policy)
        tx_policy = Policy(key=srtp_tx_key, ssrc_type=Policy.SSRC_ANY_OUTBOUND)
        self._tx_srtp = Session(tx_policy)

        # start data pump
        logger.debug('%s - DTLS handshake complete', self.role)
        self._set_state(self.State.CONNECTED)
        asyncio.ensure_future(self.__run())

    async def __run(self):
        """
        Receive loop: process datagrams until the session or transport closes.
        """
        try:
            while True:
                await self._recv_next()
        except ConnectionError:
            pass
        finally:
            self._set_state(self.State.CLOSED)
            self.closed.set()

    async def _recv_next(self):
        """
        Receive one datagram and dispatch it by its first byte.

        Bytes 20-63 are fed to OpenSSL as DTLS, and any decrypted application
        data is queued on ``data_queue``. Bytes 128-191 are SRTP/SRTCP, and
        the unprotected packet is queued on ``rtp_queue``. Other datagrams,
        including empty ones, are ignored.

        :raises ConnectionError: when ``closed`` is set or the peer shut down
            the DTLS session.
        """
        data = await first_completed(self.transport.recv(), self.closed.wait())
        if data is True:
            # session was closed
            raise ConnectionError
        if not data:
            # an empty datagram carries nothing to demultiplex
            return

        first_byte = data[0]
        if 19 < first_byte < 64:
            # DTLS
            lib.BIO_write(self.read_bio, data, len(data))
            result = lib.SSL_read(self.ssl, self.read_cdata,
                                  len(self.read_cdata))
            if result == 0:
                logger.debug('%s - DTLS shutdown by remote party', self.role)
                raise ConnectionError
            elif result > 0:
                await self.data_queue.put(
                    ffi.buffer(self.read_cdata)[0:result])
        elif 127 < first_byte < 192:
            # SRTP / SRTCP
            # NOTE(review): _rx_srtp only exists after connect() keyed SRTP;
            # media arriving during the handshake would raise AttributeError.
            if is_rtcp(data):
                data = self._rx_srtp.unprotect_rtcp(data)
            else:
                data = self._rx_srtp.unprotect(data)
            await self.rtp_queue.put(data)

    async def _send_data(self, data):
        """
        Encrypt *data* as DTLS application data and send it.

        :raises ConnectionError: if the session is not connected.
        """
        if self.state != self.State.CONNECTED:
            raise ConnectionError('Cannot send encrypted data, not connected')

        lib.SSL_write(self.ssl, data, len(data))
        await self._write_ssl()

    async def _send_rtp(self, data):
        """
        Protect an RTP or RTCP packet with SRTP/SRTCP and send it.

        :raises ConnectionError: if the session is not connected.
        """
        if self.state != self.State.CONNECTED:
            raise ConnectionError('Cannot send encrypted RTP, not connected')

        if is_rtcp(data):
            data = self._tx_srtp.protect_rtcp(data)
        else:
            data = self._tx_srtp.protect(data)
        await self.transport.send(data)

    def _set_state(self, state):
        """Change the session state, logging the transition."""
        if state != self.state:
            logger.debug('%s - %s -> %s', self.role, self.state, state)
            self.state = state

    async def _write_ssl(self):
        """
        Flush outgoing data which OpenSSL put in our BIO to the transport.

        The BIO is drained completely, at most ``len(self.write_cdata)``
        bytes per send, so queued output is never left behind until the next
        flush. A non-positive ``BIO_read`` result stops the flush instead of
        sending a bogus slice of the buffer.
        """
        while lib.BIO_ctrl_pending(self.write_bio) > 0:
            result = lib.BIO_read(self.write_bio, self.write_cdata,
                                  len(self.write_cdata))
            if result <= 0:
                break
            await self.transport.send(ffi.buffer(self.write_cdata)[0:result])

    class State(enum.Enum):
        CLOSED = 0
        CONNECTING = 1
        CONNECTED = 2