Example #1
0
    def test_add_static_size_with_already_defined_type(self):
        """Registering a static size twice for one type must fail."""
        defrag = Defragmenter()
        defrag.add_static_size(10, 255)

        # a second registration for type 10 is rejected, even with a
        # different size
        self.assertRaises(ValueError, defrag.add_static_size, 10, 2)
    def test_add_static_size_with_multiple_types(self):
        """Messages come out in the order their types were registered."""
        defrag = Defragmenter()

        # registration order sets priority: type 10 first, then type 8
        defrag.add_static_size(10, 2)
        defrag.add_static_size(8, 4)

        # data for the lower-priority type 8 is added before type 10
        defrag.add_data(8, bytearray(b'\x08'*4))
        defrag.add_data(10, bytearray(b'\x10'*2))

        expected_messages = [(10, bytearray(b'\x10'*2)),
                             (8, bytearray(b'\x08'*4))]
        for want_type, want_data in expected_messages:
            message = defrag.get_message()
            self.assertIsNotNone(message)
            got_type, got_data = message
            self.assertEqual(want_type, got_type)
            self.assertEqual(want_data, got_data)

        # both buffers have been drained
        self.assertIsNone(defrag.get_message())
    def test_recvMessage_with_blocking_socket(self):
        """recvMessage() reports blocking (0/1) and then yields the message.

        The mock socket is created with blockEveryOther=True and maxRet=1
        (presumably: block on alternate reads and return at most one byte
        per read), so the generator is expected to signal blocking at least
        once before the complete alert record is available.
        """
        defragmenter = Defragmenter()
        defragmenter.add_static_size(21, 2)

        sock = MockSocket(
            bytearray(b'\x15' +  # message type
                      b'\x03\x03' +  # TLS version
                      b'\x00\x02' +  # payload length
                      b'\xff\xff'  # message
                      ),
            blockEveryOther=True,
            maxRet=1)

        msgSock = MessageSocket(sock, defragmenter)

        # initialise so that an empty generator produces an assertion
        # failure below instead of a NameError
        res = None
        gotBlocked = False
        for res in msgSock.recvMessage():
            if res in (0, 1):
                gotBlocked = True
            else:
                break

        self.assertTrue(gotBlocked)
        self.assertIsNotNone(res)
        # the generator must not have finished while still blocked,
        # otherwise the unpacking below would fail with an obscure error
        self.assertNotIn(res, (0, 1))

        header, parser = res

        self.assertEqual(header.type, 21)
        self.assertEqual(header.version, (3, 3))
        self.assertEqual(parser.bytes, bytearray(b'\xff\xff'))
    def test_recvMessage_with_unfragmentable_type(self):
        """A record of a type with no registered size is passed through.

        Content type 23 has no size registered with the defragmenter, so the
        whole record, including its original header length, is expected back
        unchanged.
        """
        defragmenter = Defragmenter()
        defragmenter.add_static_size(21, 2)

        sock = MockSocket(
            bytearray(b'\x17' +  # message type
                      b'\x03\x03' +  # TLS version
                      b'\x00\x06' +  # payload length
                      b'\x00\x04' + b'\xff' * 4))

        msgSock = MessageSocket(sock, defragmenter)

        # initialise so that an empty generator produces an assertion
        # failure below instead of a NameError
        res = None
        for res in msgSock.recvMessage():
            if res in (0, 1):
                self.fail("Blocking read")
            else:
                break

        self.assertIsNotNone(res)

        header, parser = res

        self.assertEqual(header.type, 23)
        self.assertEqual(header.version, (3, 3))
        self.assertEqual(header.length, 6)
        self.assertEqual(parser.bytes, bytearray(b'\x00\x04' + b'\xff' * 4))
    def test_add_static_size_with_already_defined_type(self):
        """A type may only be given a static size once."""
        defragmenter = Defragmenter()
        defragmenter.add_static_size(10, 255)

        with self.assertRaises(ValueError):
            # re-registering type 10 is an error regardless of the size
            defragmenter.add_static_size(10, 2)
Example #6
0
    def test_add_static_size_with_multiple_types(self):
        """Types registered earlier take priority when several are ready."""
        defrag = Defragmenter()

        # priority follows registration order: 10 before 8
        defrag.add_static_size(10, 2)
        defrag.add_static_size(8, 4)

        defrag.add_data(8, bytearray(b'\x08' * 4))
        defrag.add_data(10, bytearray(b'\x10' * 2))

        # the higher-priority type 10 comes out first...
        first = defrag.get_message()
        self.assertIsNotNone(first)
        first_type, first_data = first
        self.assertEqual(10, first_type)
        self.assertEqual(bytearray(b'\x10' * 2), first_data)

        # ...followed by type 8
        second = defrag.get_message()
        self.assertIsNotNone(second)
        second_type, second_data = second
        self.assertEqual(8, second_type)
        self.assertEqual(bytearray(b'\x08' * 4), second_data)

        # and then nothing is left
        self.assertIsNone(defrag.get_message())
    def test_clear_buffers(self):
        """clear_buffers() discards data that was queued but not read."""
        defrag = Defragmenter()
        defrag.add_static_size(10, 2)

        # ten zero bytes: enough for five complete 2-byte type-10 messages
        defrag.add_data(10, bytearray(10))
        defrag.clear_buffers()

        # nothing may survive the clear
        self.assertIsNone(defrag.get_message())
Example #8
0
    def test_clear_buffers(self):
        """After clear_buffers() no buffered message is returned."""
        defragmenter = Defragmenter()
        defragmenter.add_static_size(10, 2)
        defragmenter.add_data(10, bytearray(10))

        defragmenter.clear_buffers()

        leftover = defragmenter.get_message()
        self.assertIsNone(leftover)
    def test_add_static_size(self):
        """A complete static-size message is returned by get_message()."""
        defrag = Defragmenter()
        defrag.add_static_size(10, 2)

        # exactly one full 2-byte message
        defrag.add_data(10, bytearray(b'\x03'*2))

        message = defrag.get_message()
        self.assertIsNotNone(message)
        kind, payload = message
        self.assertEqual(10, kind)
        self.assertEqual(bytearray(b'\x03'*2), payload)
Example #10
0
    def test_add_static_size(self):
        """Data matching a registered static size forms one message."""
        defragmenter = Defragmenter()
        defragmenter.add_static_size(10, 2)
        defragmenter.add_data(10, bytearray(b'\x03' * 2))

        result = defragmenter.get_message()

        self.assertIsNotNone(result)
        result_type, result_bytes = result
        self.assertEqual(10, result_type)
        self.assertEqual(bytearray(b'\x03' * 2), result_bytes)
    def test_recvMessage(self):
        """Two alerts packed in one record are returned as two messages.

        Each defragmented message is expected to carry a header whose
        length is 0 rather than the length of the original record.
        """
        defragmenter = Defragmenter()
        defragmenter.add_static_size(21, 2)

        sock = MockSocket(
            bytearray(b'\x15' +  # message type
                      b'\x03\x03' +  # TLS version
                      b'\x00\x04' +  # payload length
                      b'\xff\xff' +  # first message
                      b'\xbb\xbb'  # second message
                      ))

        msgSock = MessageSocket(sock, defragmenter)

        # initialise (as for the second read below) so an empty generator
        # gives an assertion failure instead of a NameError
        res = None

        for res in msgSock.recvMessage():
            if res in (0, 1):
                self.fail("Blocking read")
            else:
                break

        self.assertIsNotNone(res)

        header, parser = res

        self.assertEqual(header.type, 21)
        self.assertEqual(header.version, (3, 3))
        self.assertEqual(header.length, 0)
        self.assertEqual(parser.bytes, bytearray(b'\xff\xff'))

        res = None

        for res in msgSock.recvMessage():
            if res in (0, 1):
                self.fail("Blocking read")
            else:
                break

        self.assertIsNotNone(res)

        header, parser = res

        self.assertEqual(header.type, 21)
        self.assertEqual(header.version, (3, 3))
        self.assertEqual(header.length, 0)
        self.assertEqual(parser.bytes, bytearray(b'\xbb\xbb'))
Example #12
0
    def process(self, state):
        """Connect to a server.

        Opens an IPv4 TCP connection to ``(self.hostname, self.port)`` with
        a 5 second timeout and Nagle's algorithm disabled. It wraps the
        socket in a BufferedSocket and stores a MessageSocket built on it
        in ``state.msg_sock``, with its version set to ``self.version``.

        If creating the connection fails, the raw socket is closed before
        the exception propagates, so the file descriptor is not leaked.
        """
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            sock.settimeout(5)
            sock.connect((self.hostname, self.port))
            # disable Nagle - we handle buffering and flushing ourselves
            sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        except BaseException:
            # don't leak the descriptor when connect() fails or times out
            sock.close()
            raise

        # allow for later buffering of writes to the socket
        sock = BufferedSocket(sock)

        defragmenter = Defragmenter()
        # alerts are 2-byte and change_cipher_spec 1-byte messages
        defragmenter.add_static_size(ContentType.alert, 2)
        defragmenter.add_static_size(ContentType.change_cipher_spec, 1)
        # presumably: length field at offset 1, 3 bytes wide (TLS handshake
        # header) - confirm against Defragmenter.add_dynamic_size
        defragmenter.add_dynamic_size(ContentType.handshake, 1, 3)

        state.msg_sock = MessageSocket(sock, defragmenter)

        state.msg_sock.version = self.version
    def test_add_static_size_with_uncomplete_message(self):
        """A message split across two add_data() calls is reassembled."""
        defrag = Defragmenter()
        defrag.add_static_size(10, 2)

        # first byte only: not enough for a 2-byte message yet
        defrag.add_data(10, bytearray(b'\x10'))
        self.assertIsNone(defrag.get_message())

        # the second byte completes it
        defrag.add_data(10, bytearray(b'\x11'))
        message = defrag.get_message()
        self.assertIsNotNone(message)
        msg_kind, payload = message
        self.assertEqual(10, msg_kind)
        self.assertEqual(bytearray(b'\x10\x11'), payload)

        # the buffer is empty again
        self.assertIsNone(defrag.get_message())
Example #14
0
    def test_add_static_size_with_uncomplete_message(self):
        """Partial data is held back until the static size is reached."""
        defragmenter = Defragmenter()
        defragmenter.add_static_size(10, 2)

        chunks = (bytearray(b'\x10'), bytearray(b'\x11'))

        defragmenter.add_data(10, chunks[0])
        pending = defragmenter.get_message()
        self.assertIsNone(pending)

        defragmenter.add_data(10, chunks[1])
        complete = defragmenter.get_message()
        self.assertIsNotNone(complete)
        complete_type, complete_data = complete
        self.assertEqual(10, complete_type)
        self.assertEqual(bytearray(b'\x10\x11'), complete_data)

        after = defragmenter.get_message()
        self.assertIsNone(after)
Example #15
0
    def test_add_static_size_with_multiple_uncompleted_messages(self):
        """Only the type whose buffer fills up produces a message."""
        defrag = Defragmenter()
        defrag.add_static_size(10, 2)
        defrag.add_static_size(8, 4)

        # both types are one byte short of a full message
        defrag.add_data(8, bytearray(b'\x08' * 3))
        defrag.add_data(10, bytearray(b'\x10'))
        self.assertIsNone(defrag.get_message())

        # completing type 8 makes it available even though type 10 was
        # registered first
        defrag.add_data(8, bytearray(b'\x09'))
        message = defrag.get_message()
        self.assertIsNotNone(message)
        msg_kind, payload = message
        self.assertEqual(8, msg_kind)
        self.assertEqual(bytearray(b'\x08' * 3 + b'\x09'), payload)

        # type 10 is still incomplete
        self.assertIsNone(defrag.get_message())
    def test_add_static_size_with_multiple_uncompleted_messages(self):
        """An incomplete higher-priority type does not block a ready one."""
        defragmenter = Defragmenter()
        for kind, size in ((10, 2), (8, 4)):
            defragmenter.add_static_size(kind, size)

        defragmenter.add_data(8, bytearray(b'\x08'*3))
        defragmenter.add_data(10, bytearray(b'\x10'))

        nothing_yet = defragmenter.get_message()
        self.assertIsNone(nothing_yet)

        defragmenter.add_data(8, bytearray(b'\x09'))

        ready = defragmenter.get_message()
        self.assertIsNotNone(ready)
        ready_type, ready_data = ready
        self.assertEqual(8, ready_type)
        self.assertEqual(bytearray(b'\x08'*3 + b'\x09'), ready_data)

        still_nothing = defragmenter.get_message()
        self.assertIsNone(still_nothing)
    def test_recvMessageBlocking(self):
        """recvMessageBlocking() returns the alert despite a blocking socket."""
        defrag = Defragmenter()
        defrag.add_static_size(21, 2)

        record = bytearray(
            b'\x15'          # message type (21)
            b'\x03\x03'      # TLS version
            b'\x00\x02'      # payload length
            b'\xff\xff')     # message
        mock = MockSocket(record, blockEveryOther=True, maxRet=1)

        message_socket = MessageSocket(mock, defrag)

        result = message_socket.recvMessageBlocking()
        self.assertIsNotNone(result)

        hdr, body = result
        self.assertEqual(hdr.type, 21)
        self.assertEqual(body.bytes, bytearray(b'\xff\xff'))
Example #18
0
    def test_add_static_size_with_invalid_size(self):
        """Negative static sizes are rejected with ValueError."""
        defrag = Defragmenter()
        self.assertRaises(ValueError, defrag.add_static_size, 10, -10)
    def test_add_static_size_with_invalid_size(self):
        """add_static_size() raises ValueError for a negative size."""
        defragmenter = Defragmenter()

        with self.assertRaises(ValueError):
            defragmenter.add_static_size(10, -10)  # size below zero