def test_garbage(self):
    """Arbitrary garbage bytes received by a listening server leave it CLOSED."""
    client_dtls, server_dtls = dummy_dtls_transport_pair()
    server = RTCSctpTransport(server_dtls)
    server.start(RTCSctpCapabilities(maxMessageSize=65536), 5000)

    # inject a datagram that is not a valid SCTP packet
    asyncio.ensure_future(client_dtls.send(b'garbage'))

    # give the server time to process it, then check it gave up
    run(asyncio.sleep(0.5))
    self.assertEqual(server.state, RTCSctpTransport.State.CLOSED)

    # shutdown
    run(server.stop())
def test_receive_sack_discard(self):
    """A SACK whose cumulative TSN is behind the current sack point is ignored."""
    dtls, _ = dummy_dtls_transport_pair()
    client = RTCSctpTransport(dtls)
    client._last_received_tsn = 0

    # build a SACK that acknowledges one TSN *before* the current sack point
    previous_sack_point = client._last_sacked_tsn
    stale_sack = SackChunk()
    stale_sack.cumulative_tsn = tsn_minus_one(previous_sack_point)
    run(client._receive_chunk(stale_sack))

    # the sack point must not change
    self.assertEqual(client._last_sacked_tsn, previous_sack_point)
def test_mark_received(self):
    """_mark_received tracks the cumulative TSN and the set of out-of-order TSNs."""
    client = RTCSctpTransport(DummyDtlsTransport())
    client._last_received_tsn = 0

    # (tsn received, expected cumulative tsn, expected misordered tsns)
    steps = [
        (1, 1, set()),
        (3, 1, {3}),
        (4, 1, {3, 4}),
        (6, 1, {3, 4, 6}),
        # filling the gap at 2 advances the cumulative TSN through 3 and 4
        (2, 4, {6}),
    ]
    for tsn, expected_last, expected_misordered in steps:
        self.assertFalse(client._mark_received(tsn))
        self.assertEqual(client._last_received_tsn, expected_last)
        self.assertEqual(client._sack_misordered, expected_misordered)
def test_bad_verification_tag(self):
    """A packet carrying an unexpected verification tag leaves the server CLOSED."""
    # verification tag is 12345 instead of 0
    packet_bytes = load('sctp_init_bad_verification.bin')
    client_dtls, server_dtls = dummy_dtls_transport_pair()
    server = RTCSctpTransport(server_dtls)
    server.start(RTCSctpCapabilities(maxMessageSize=65536), 5000)
    asyncio.ensure_future(client_dtls.send(packet_bytes))

    # check outcome
    run(asyncio.sleep(0.5))
    self.assertEqual(server.state, RTCSctpTransport.State.CLOSED)

    # shutdown
    run(server.stop())
def test_send_sack(self):
    """With nothing misordered or duplicated, the SACK carries only the cumulative TSN."""
    captured = []

    async def capture_chunk(c):
        captured.append(c)

    client = RTCSctpTransport(DummyDtlsTransport())
    client._last_received_tsn = 123
    client._send_chunk = capture_chunk
    run(client._send_sack())

    # the last chunk handed to _send_chunk is the SACK under test
    sack = captured[-1] if captured else None
    self.assertIsNotNone(sack)
    self.assertEqual(sack.duplicates, [])
    self.assertEqual(sack.gaps, [])
    self.assertEqual(sack.cumulative_tsn, 123)
def test_t3_expired(self):
    """On T3 expiry the timer is cleared and the queue position resets to 0."""
    async def swallow_chunk(chunk):
        pass

    async def skip_transmit():
        pass

    client = RTCSctpTransport(DummyDtlsTransport())
    client._send_chunk = swallow_chunk

    # queue exactly one full-size chunk
    run(client._send(123, 456, b'M' * USERDATA_MAX_LENGTH))
    self.assertIsNotNone(client._t3_handle)
    self.assertEqual(len(client._outbound_queue), 1)
    self.assertEqual(client._outbound_queue_pos, 1)

    # expire T3 by hand, with the follow-up transmission stubbed out;
    # the chunk stays queued but the send position is rewound
    client._transmit = skip_transmit
    client._t3_expired()
    self.assertIsNone(client._t3_handle)
    self.assertEqual(len(client._outbound_queue), 1)
    self.assertEqual(client._outbound_queue_pos, 0)

    # let async code complete
    run(asyncio.sleep(0))
def test_send_data(self):
    """_transmit is a no-op on an empty queue; _send queues a chunk and starts T3."""
    async def swallow_chunk(chunk):
        pass

    client = RTCSctpTransport(DummyDtlsTransport())
    client._send_chunk = swallow_chunk

    # nothing queued: no timer, position untouched
    run(client._transmit())
    self.assertIsNone(client._t3_handle)
    self.assertEqual(client._outbound_queue_pos, 0)

    # one full-size chunk
    run(client._send(123, 456, b'M' * USERDATA_MAX_LENGTH))
    self.assertIsNotNone(client._t3_handle)
    self.assertEqual(len(client._outbound_queue), 1)
    self.assertEqual(client._outbound_queue_pos, 1)
def test_receive_heartbeat(self):
    """A received HEARTBEAT is answered with a single HEARTBEAT ACK echoing its params."""
    client_transport, server_transport = dummy_dtls_transport_pair()
    client = RTCSctpTransport(client_transport)
    client._last_received_tsn = 0
    client._remote_port = 5000

    # receive heartbeat
    chunk = HeartbeatChunk()
    chunk.params.append((1, b'\x01\x02\x03\x04'))
    # NOTE(review): HEARTBEAT chunks do not appear to carry a TSN; this
    # assignment looks vestigial — confirm and drop
    chunk.tsn = 1
    run(client._receive_chunk(chunk))

    # check response: exactly one HEARTBEAT ACK with the same params
    data = run(server_transport.recv())
    packet = Packet.parse(data)
    self.assertEqual(len(packet.chunks), 1)
    # assertIsInstance reports the actual type on failure, unlike
    # assertTrue(isinstance(...)) which only says "False is not true"
    self.assertIsInstance(packet.chunks[0], HeartbeatAckChunk)
    self.assertEqual(packet.chunks[0].params, [(1, b'\x01\x02\x03\x04')])
def test_send_data_over_cwnd(self):
    """Chunks beyond the congestion window are held back until SACKs arrive."""
    async def swallow_chunk(chunk):
        pass

    client = RTCSctpTransport(DummyDtlsTransport())
    client._send_chunk = swallow_chunk
    client._ssthresh = 131072

    def acknowledge_through_second_chunk():
        # SACK everything up to and including the second queued chunk
        sack = SackChunk()
        sack.cumulative_tsn = client._outbound_queue[1].tsn
        run(client._receive_chunk(sack))

    # STEP 1 - queue 4 chunks, but cwnd only allows 3
    run(client._send(123, 456, b'M' * USERDATA_MAX_LENGTH * 4))
    # T3 timer was started
    self.assertIsNotNone(client._t3_handle)
    self.assertEqual(len(client._outbound_queue), 4)
    self.assertEqual(client._outbound_queue_pos, 3)

    # STEP 2 - sack comes in acknowledging 2 chunks
    first_timer = client._t3_handle
    acknowledge_through_second_chunk()
    # T3 timer was restarted
    self.assertIsNotNone(client._t3_handle)
    self.assertNotEqual(client._t3_handle, first_timer)
    self.assertEqual(len(client._outbound_queue), 2)
    self.assertEqual(client._outbound_queue_pos, 2)

    # STEP 3 - sack comes in acknowledging 2 more chunks
    acknowledge_through_second_chunk()
    # T3 timer was stopped
    self.assertIsNone(client._t3_handle)
    self.assertEqual(len(client._outbound_queue), 0)
    self.assertEqual(client._outbound_queue_pos, 0)
def test_receive_shutdown(self):
    """SHUTDOWN moves to SHUTDOWN_ACK_SENT; SHUTDOWN COMPLETE then closes."""
    async def swallow_chunk(chunk):
        pass

    dtls, _ = dummy_dtls_transport_pair()
    client = RTCSctpTransport(dtls)
    client._last_received_tsn = 0
    client._send_chunk = swallow_chunk
    client.state = RTCSctpTransport.State.ESTABLISHED

    # receive shutdown
    shutdown = ShutdownChunk()
    shutdown.cumulative_tsn = tsn_minus_one(client._last_sacked_tsn)
    run(client._receive_chunk(shutdown))
    self.assertEqual(client.state, RTCSctpTransport.State.SHUTDOWN_ACK_SENT)

    # receive shutdown complete
    run(client._receive_chunk(ShutdownCompleteChunk()))
    self.assertEqual(client.state, RTCSctpTransport.State.CLOSED)
def test_send_sack_with_gaps(self):
    """Out-of-order TSNs are reported as gap blocks offset from the cumulative TSN."""
    sack = None

    async def mock_send_chunk(c):
        nonlocal sack
        sack = c

    client_transport = DummyDtlsTransport()
    client = RTCSctpTransport(client_transport)
    client._last_received_tsn = 12
    # _sack_misordered is compared against sets elsewhere in this suite, so
    # seed it with a set rather than a list. A set does not guarantee
    # ascending iteration order, so this also checks that the gap
    # computation does not rely on insertion order.
    client._sack_misordered = {14, 15, 17}
    client._send_chunk = mock_send_chunk
    run(client._send_sack())

    self.assertIsNotNone(sack)
    self.assertEqual(sack.duplicates, [])
    # 14, 15 -> offsets 2..3 from cumulative TSN 12; 17 -> offset 5
    self.assertEqual(sack.gaps, [(2, 3), (5, 5)])
    self.assertEqual(sack.cumulative_tsn, 12)
def test_receive_data(self):
    """A DATA chunk advances the cumulative TSN; receiving it again records a duplicate."""
    dtls, _ = dummy_dtls_transport_pair()
    client = RTCSctpTransport(dtls)
    client._last_received_tsn = 0

    # single unfragmented message: both FIRST and LAST flags set
    message_chunk = DataChunk(flags=(SCTP_DATA_FIRST_FRAG | SCTP_DATA_LAST_FRAG))
    message_chunk.user_data = b'foo'
    message_chunk.tsn = 1

    # first delivery
    run(client._receive_chunk(message_chunk))
    self.assertEqual(client._sack_needed, True)
    self.assertEqual(client._sack_duplicates, [])
    self.assertEqual(client._last_received_tsn, 1)
    client._sack_needed = False

    # same TSN delivered again
    run(client._receive_chunk(message_chunk))
    self.assertEqual(client._sack_needed, True)
    self.assertEqual(client._sack_duplicates, [1])
    self.assertEqual(client._last_received_tsn, 1)
def test_stale_cookie(self):
    """When the server clock jumps between calls (a stale cookie), both ends end up CLOSED."""
    call_count = 0

    def fake_timestamp():
        # the first call returns 0 and every later call returns 61; presumably
        # the first call is made when the cookie is created — TODO confirm
        nonlocal call_count
        call_count += 1
        return 0 if call_count == 1 else 61

    client_dtls, server_dtls = dummy_dtls_transport_pair()
    client = RTCSctpTransport(client_dtls)
    server = RTCSctpTransport(server_dtls)
    server._get_timestamp = fake_timestamp
    server.start(client.getCapabilities(), client.port)
    client.start(server.getCapabilities(), server.port)

    def assert_both_closed():
        self.assertEqual(client.state, RTCSctpTransport.State.CLOSED)
        self.assertEqual(server.state, RTCSctpTransport.State.CLOSED)

    # check outcome
    run(asyncio.sleep(0.5))
    assert_both_closed()

    # shutdown
    run(client.stop())
    run(server.stop())
    assert_both_closed()
def test_bad_cookie(self):
    """With a corrupted COOKIE ECHO, the server ends CLOSED and the client stays COOKIE_ECHOED."""
    client_dtls, server_dtls = dummy_dtls_transport_pair()
    client = RTCSctpTransport(client_dtls)
    server = RTCSctpTransport(server_dtls)

    # wrap the client's sender so any COOKIE ECHO goes out with a bogus body
    original_send_chunk = client._send_chunk

    async def corrupting_send_chunk(chunk):
        if isinstance(chunk, CookieEchoChunk):
            chunk.body = b'garbage'
        return await original_send_chunk(chunk)

    client._send_chunk = corrupting_send_chunk

    server.start(client.getCapabilities(), client.port)
    client.start(server.getCapabilities(), server.port)

    # check outcome
    run(asyncio.sleep(0.5))
    self.assertEqual(client.state, RTCSctpTransport.State.COOKIE_ECHOED)
    self.assertEqual(server.state, RTCSctpTransport.State.CLOSED)

    # shutdown
    run(client.stop())
    run(server.stop())
    self.assertEqual(client.state, RTCSctpTransport.State.CLOSED)
    self.assertEqual(server.state, RTCSctpTransport.State.CLOSED)
def test_abort(self):
    """Aborting an established association from the client closes both endpoints."""
    client_dtls, server_dtls = dummy_dtls_transport_pair()
    client = RTCSctpTransport(client_dtls)
    server = RTCSctpTransport(server_dtls)

    def assert_states(expected):
        self.assertEqual(client.state, expected)
        self.assertEqual(server.state, expected)

    # connect
    server.start(client.getCapabilities(), client.port)
    client.start(server.getCapabilities(), server.port)
    run(wait_for_outcome(client, server))
    assert_states(RTCSctpTransport.State.ESTABLISHED)

    # abort from the client side and give the peer time to react
    run(client._abort())
    run(asyncio.sleep(0.5))
    assert_states(RTCSctpTransport.State.CLOSED)
def test_construct(self):
    """A new transport exposes the DTLS transport it wraps and uses port 5000."""
    dtls_transport, _ = dummy_dtls_transport_pair()
    sctp_transport = RTCSctpTransport(dtls_transport)
    self.assertEqual(sctp_transport.transport, dtls_transport)
    self.assertEqual(sctp_transport.port, 5000)
def test_abrupt_disconnect(self):
    """
    Closing the underlying DTLS transport makes the matching SCTP endpoint
    exit its __run() loop and become CLOSED.
    """
    client_dtls, server_dtls = dummy_dtls_transport_pair()
    client = RTCSctpTransport(client_dtls)
    server = RTCSctpTransport(server_dtls)

    # connect
    server.start(client.getCapabilities(), client.port)
    client.start(server.getCapabilities(), server.port)
    run(wait_for_outcome(client, server))
    self.assertEqual(client.state, RTCSctpTransport.State.ESTABLISHED)
    self.assertEqual(server.state, RTCSctpTransport.State.ESTABLISHED)

    # break each connection in turn, client first
    for dtls, sctp in ((client_dtls, client), (server_dtls, server)):
        run(dtls.close())
        run(asyncio.sleep(0.1))
        self.assertEqual(sctp.state, RTCSctpTransport.State.CLOSED)

    # stopping already-closed transports must still succeed
    run(client.stop())
    run(server.stop())
def test_connect_then_server_creates_data_channel(self):
    """A data channel the server opens after connecting shows up on the client."""
    client_transport, server_transport = dummy_dtls_transport_pair()
    client = RTCSctpTransport(client_transport)
    self.assertFalse(client.is_server)
    server = RTCSctpTransport(server_transport)
    self.assertTrue(server.is_server)

    client_channels = track_channels(client)
    server_channels = track_channels(server)

    # connect
    server.start(client.getCapabilities(), client.port)
    client.start(server.getCapabilities(), server.port)

    # check outcome
    run(wait_for_outcome(client, server))
    self.assertEqual(client.state, RTCSctpTransport.State.ESTABLISHED)
    self.assertEqual(client._remote_extensions, [130])
    self.assertEqual(server.state, RTCSctpTransport.State.ESTABLISHED)
    self.assertEqual(server._remote_extensions, [130])

    # create data channel; no id has been assigned at construction time.
    # assertIsNone compares by identity, as PEP 8 recommends for None.
    channel = RTCDataChannel(server, RTCDataChannelParameters(label='chat'))
    self.assertIsNone(channel.id)
    self.assertEqual(channel.label, 'chat')

    run(asyncio.sleep(0.5))
    # the client sees the new channel with id 0; server_channels stays empty
    self.assertEqual(len(client_channels), 1)
    self.assertEqual(client_channels[0].id, 0)
    self.assertEqual(client_channels[0].label, 'chat')
    self.assertEqual(len(server_channels), 0)

    # shutdown
    run(client.stop())
    run(server.stop())
    self.assertEqual(client.state, RTCSctpTransport.State.CLOSED)
    self.assertEqual(server.state, RTCSctpTransport.State.CLOSED)
def test_connect_server_limits_streams(self):
    """Stream limits configured on the server are reflected on both ends after connecting."""
    client_dtls, server_dtls = dummy_dtls_transport_pair()
    client = RTCSctpTransport(client_dtls)
    self.assertFalse(client.is_server)
    server = RTCSctpTransport(server_dtls)
    server.inbound_streams_max = 2048
    server.outbound_streams = 256
    self.assertTrue(server.is_server)

    # connect
    server.start(client.getCapabilities(), client.port)
    client.start(server.getCapabilities(), server.port)
    run(wait_for_outcome(client, server))

    # each side's inbound count mirrors the other side's outbound count
    for endpoint, inbound, outbound in ((client, 256, 2048), (server, 2048, 256)):
        self.assertEqual(endpoint.state, RTCSctpTransport.State.ESTABLISHED)
        self.assertEqual(endpoint.inbound_streams, inbound)
        self.assertEqual(endpoint.outbound_streams, outbound)
        self.assertEqual(endpoint._remote_extensions, [130])

    run(asyncio.sleep(0.5))

    # shutdown
    run(client.stop())
    run(server.stop())
    self.assertEqual(client.state, RTCSctpTransport.State.CLOSED)
    self.assertEqual(server.state, RTCSctpTransport.State.CLOSED)
def test_connect_lossy_transport(self):
    """
    Over a transport that drops 25% of packets, the endpoints still connect
    and every message is eventually delivered.
    """
    client_dtls, server_dtls = dummy_dtls_transport_pair(
        loss=[True, False, False, False])
    client = RTCSctpTransport(client_dtls)
    client._rto = 0.1
    self.assertFalse(client.is_server)
    server = RTCSctpTransport(server_dtls)
    server._rto = 0.1
    self.assertTrue(server.is_server)

    # connect
    server.start(client.getCapabilities(), client.port)
    client.start(server.getCapabilities(), server.port)

    # check outcome
    run(wait_for_outcome(client, server))
    self.assertEqual(client.state, RTCSctpTransport.State.ESTABLISHED)
    self.assertEqual(server.state, RTCSctpTransport.State.ESTABLISHED)

    # capture whatever the server hands to _receive
    received_messages = asyncio.Queue()

    async def capture_receive(*args):
        await received_messages.put(args)

    server._receive = capture_receive

    # transmit data
    for seq in range(20):
        message = (123, seq, b'ping')
        run(client._send(*message))
        self.assertEqual(run(received_messages.get()), message)

    # shutdown
    run(client.stop())
    run(server.stop())
    self.assertEqual(client.state, RTCSctpTransport.State.CLOSED)
    self.assertEqual(server.state, RTCSctpTransport.State.CLOSED)
def test_connect_broken_transport(self):
    """
    Over a transport that drops every packet, the association never forms
    and both endpoints end up CLOSED.
    """
    client_dtls, server_dtls = dummy_dtls_transport_pair(loss=[True])
    client = RTCSctpTransport(client_dtls)
    client._rto = 0.1
    self.assertFalse(client.is_server)
    server = RTCSctpTransport(server_dtls)
    server._rto = 0.1
    self.assertTrue(server.is_server)

    def assert_both_closed():
        self.assertEqual(client.state, RTCSctpTransport.State.CLOSED)
        self.assertEqual(server.state, RTCSctpTransport.State.CLOSED)

    # connect
    server.start(client.getCapabilities(), client.port)
    client.start(server.getCapabilities(), server.port)

    # check outcome
    run(wait_for_outcome(client, server))
    assert_both_closed()

    # shutdown
    run(client.stop())
    run(server.stop())
    assert_both_closed()
def test_construct_invalid_dtls_transport_state(self):
    """Wrapping a DTLS transport in state 'closed' raises InvalidStateError."""
    closed_dtls = DummyDtlsTransport(state='closed')
    self.assertRaises(InvalidStateError, RTCSctpTransport, closed_dtls)
def test_receive_data_out_of_order(self):
    """Out-of-order fragments are tracked as misordered until the gap is filled."""
    dtls, _ = dummy_dtls_transport_pair()
    client = RTCSctpTransport(dtls)
    client._last_received_tsn = 0

    # a three-fragment message: FIRST, middle (default flags), LAST
    first = DataChunk(flags=SCTP_DATA_FIRST_FRAG)
    first.user_data = b'foo'
    first.tsn = 1
    middle = DataChunk()
    middle.user_data = b'bar'
    middle.tsn = 2
    last = DataChunk(flags=SCTP_DATA_LAST_FRAG)
    last.user_data = b'baz'
    last.tsn = 3

    # (chunk received, expected duplicates, expected misordered, expected cumulative tsn)
    steps = [
        (first, [], set(), 1),
        # last fragment arrives early and is parked as misordered
        (last, [], {3}, 1),
        # middle fragment fills the gap, cumulative TSN jumps to 3
        (middle, [], set(), 3),
        # last fragment again is recorded as a duplicate
        (last, [3], set(), 3),
    ]
    for chunk, duplicates, misordered, last_tsn in steps:
        run(client._receive_chunk(chunk))
        self.assertEqual(client._sack_needed, True)
        self.assertEqual(client._sack_duplicates, duplicates)
        self.assertEqual(client._sack_misordered, misordered)
        self.assertEqual(client._last_received_tsn, last_tsn)
        client._sack_needed = False
def test_abrupt_disconnect_2(self):
    """
    Stopping after both DTLS transports have been closed must not raise,
    even though the ABORT chunk can no longer be sent.
    """
    client_dtls, server_dtls = dummy_dtls_transport_pair()
    client = RTCSctpTransport(client_dtls)
    server = RTCSctpTransport(server_dtls)

    # connect
    server.start(client.getCapabilities(), client.port)
    client.start(server.getCapabilities(), server.port)
    run(wait_for_outcome(client, server))
    for sctp in (client, server):
        self.assertEqual(sctp.state, RTCSctpTransport.State.ESTABLISHED)

    # pull the transport out from under both endpoints
    for dtls in (client_dtls, server_dtls):
        run(dtls.close())

    # stop
    for sctp in (client, server):
        run(sctp.stop())
def test_connect_client_limits_streams(self):
    """Client-side stream limits apply after connecting, and a reconfig request can add streams."""
    client_dtls, server_dtls = dummy_dtls_transport_pair()
    client = RTCSctpTransport(client_dtls)
    client.inbound_streams_max = 2048
    client.outbound_streams = 256
    self.assertFalse(client.is_server)
    server = RTCSctpTransport(server_dtls)
    self.assertTrue(server.is_server)

    # connect
    server.start(client.getCapabilities(), client.port)
    client.start(server.getCapabilities(), server.port)
    run(wait_for_outcome(client, server))

    # each side's inbound count mirrors the other side's outbound count
    for endpoint, inbound, outbound in ((client, 2048, 256), (server, 256, 2048)):
        self.assertEqual(endpoint.state, RTCSctpTransport.State.ESTABLISHED)
        self.assertEqual(endpoint.inbound_streams, inbound)
        self.assertEqual(endpoint.outbound_streams, outbound)
        self.assertEqual(endpoint._remote_extensions, [130])

    # client asks for 16 more outgoing streams: server inbound goes 256 -> 272
    add_streams = StreamAddOutgoingParam(
        request_sequence=client._reconfig_request_seq, new_streams=16)
    run(client._send_reconfig_param(add_streams))
    run(asyncio.sleep(0.5))
    self.assertEqual(server.inbound_streams, 272)
    self.assertEqual(server.outbound_streams, 2048)

    # shutdown
    run(client.stop())
    run(server.stop())
    self.assertEqual(client.state, RTCSctpTransport.State.CLOSED)
    self.assertEqual(server.state, RTCSctpTransport.State.CLOSED)
def test_t2_expired_when_shutdown_ack_sent(self):
    """In SHUTDOWN_ACK_SENT, T2 keeps re-arming up to 10 failures and closes on the 11th."""
    async def swallow_chunk(chunk):
        pass

    client = RTCSctpTransport(DummyDtlsTransport())
    client._last_received_tsn = 0
    client._send_chunk = swallow_chunk
    shutdown_ack = ShutdownAckChunk()

    def expire_and_check(failures, timer_running, state):
        # fire T2 by hand and verify failure count, timer and state
        client._t2_expired()
        self.assertEqual(client._t2_failures, failures)
        if timer_running:
            self.assertIsNotNone(client._t2_handle)
        else:
            self.assertIsNone(client._t2_handle)
        self.assertEqual(client.state, state)

    # fails once
    client.state = RTCSctpTransport.State.SHUTDOWN_ACK_SENT
    client._t2_start(shutdown_ack)
    expire_and_check(1, True, RTCSctpTransport.State.SHUTDOWN_ACK_SENT)

    # fails 10 times
    client._t2_failures = 9
    expire_and_check(10, True, RTCSctpTransport.State.SHUTDOWN_ACK_SENT)

    # fails 11 times
    expire_and_check(11, False, RTCSctpTransport.State.CLOSED)

    # let async code complete
    run(asyncio.sleep(0))