def test_recvMessage_with_blocking_socket(self):
        """recvMessage() yields blocking markers, then the complete message.

        The mock socket is built with ``blockEveryOther=True`` and
        ``maxRet=1`` (presumably: block on every other call and return at
        most one byte per read), so the generator is expected to yield at
        least one 0/1 "blocked" marker before the ``(header, parser)`` pair.
        """
        defragmenter = Defragmenter()
        defragmenter.addStaticSize(21, 2)

        sock = MockSocket(bytearray(
            b'\x15' +           # message type
            b'\x03\x03' +       # TLS version
            b'\x00\x02' +       # payload length
            b'\xff\xff'         # message
            ),
            blockEveryOther=True,
            maxRet=1)

        msgSock = MessageSocket(sock, defragmenter)

        gotBlocked = False
        # pre-bind so an empty generator gives a clear failure, not NameError
        res = None
        for res in msgSock.recvMessage():
            if res in (0, 1):
                gotBlocked = True
            else:
                break
        else:
            # loop ended without `break`: no message was ever produced, so
            # `res` is None or a blocking marker and can't be unpacked
            self.fail("recvMessage() ended without returning a message")

        self.assertTrue(gotBlocked)
        self.assertIsNotNone(res)

        header, parser = res

        self.assertEqual(header.type, 21)
        self.assertEqual(header.version, (3, 3))
        self.assertEqual(parser.bytes, bytearray(b'\xff\xff'))
    def test_recvMessage_with_unfragmentable_type(self):
        """A record of a type with no registered size is returned as-is.

        Only type 21 is registered with the defragmenter. The record here
        is type 23 (0x17), and it must come back in one piece, keeping its
        record header (including the real length of 6) and the whole
        payload.
        """
        defragmenter = Defragmenter()
        defragmenter.addStaticSize(21, 2)

        sock = MockSocket(bytearray(
            b'\x17' +       # message type
            b'\x03\x03' +   # TLS version
            b'\x00\x06' +   # payload length
            b'\x00\x04' +
            b'\xff'*4
            ))

        msgSock = MessageSocket(sock, defragmenter)

        # pre-bind so an empty generator gives a clear failure, not NameError
        res = None
        for res in msgSock.recvMessage():
            if res in (0, 1):
                self.fail("Blocking read")
            else:
                break
        else:
            # loop ended without `break`: no message was ever produced
            self.fail("recvMessage() ended without returning a message")

        self.assertIsNotNone(res)

        header, parser = res

        self.assertEqual(header.type, 23)
        self.assertEqual(header.version, (3, 3))
        self.assertEqual(header.length, 6)
        self.assertEqual(parser.bytes, bytearray(b'\x00\x04' + b'\xff'*4))
    def test_add_static_size_with_already_defined_type(self):
        """Registering a static size twice for the same type is rejected."""
        defrag = Defragmenter()
        defrag.add_static_size(10, 255)

        # a second registration for type 10 must fail, even with another size
        self.assertRaises(ValueError, defrag.add_static_size, 10, 2)
    def test_add_dynamic_size(self):
        """A dynamic-size type splits its buffer into length-prefixed messages.

        Type 10 is registered with a 2-byte length field at offset 2. Each
        message is therefore 2 header bytes, the 2-byte length, and then
        that many payload bytes. Two back-to-back messages must come out
        separately and in order.
        """
        defrag = Defragmenter()
        defrag.add_dynamic_size(10, 2, 2)

        # nothing has been buffered yet
        self.assertIsNone(defrag.get_message())

        first = b'\xee\xee' + b'\x00\x00'              # header, empty payload
        second = b'\xff\xff' + b'\x00\x01' + b'\xf0'   # header, 1-byte payload
        defrag.add_data(10, bytearray(first + second))

        for expected in (first, second):
            ret = defrag.get_message()
            self.assertIsNotNone(ret)
            msg_type, data = ret
            self.assertEqual(10, msg_type)
            self.assertEqual(bytearray(expected), data)

        # both messages consumed
        self.assertIsNone(defrag.get_message())
    def test_clear_buffers(self):
        """clear_buffers() discards data that would form complete messages."""
        defrag = Defragmenter()
        defrag.add_static_size(10, 2)

        # ten zero bytes: enough for five 2-byte messages of type 10
        defrag.add_data(10, bytearray(10))
        defrag.clear_buffers()

        leftover = defrag.get_message()
        self.assertIsNone(leftover)
    def test_clearBuffers(self):
        """clearBuffers() (camelCase API) drops all buffered message data."""
        defrag = Defragmenter()
        defrag.addStaticSize(10, 2)

        # ten zero bytes: enough for five 2-byte messages of type 10
        defrag.addData(10, bytearray(10))
        defrag.clearBuffers()

        leftover = defrag.getMessage()
        self.assertIsNone(leftover)
    def test_recvMessage(self):
        """Two messages packed into one record are returned one at a time.

        Type 21 has a static size of 2, so the 4-byte record payload holds
        two messages. Each one comes back with type 21, version (3, 3) and a
        header ``length`` of 0, as asserted below.
        """
        defragmenter = Defragmenter()
        defragmenter.addStaticSize(21, 2)

        sock = MockSocket(bytearray(
            b'\x15' +           # message type
            b'\x03\x03' +       # TLS version
            b'\x00\x04' +       # payload length
            b'\xff\xff' +       # first message
            b'\xbb\xbb'         # second message
            ))

        msgSock = MessageSocket(sock, defragmenter)

        # pre-bind so an empty generator gives a clear failure, not NameError
        res = None
        for res in msgSock.recvMessage():
            if res in (0, 1):
                self.fail("Blocking read")
            else:
                break
        else:
            # loop ended without `break`: no message was ever produced
            self.fail("recvMessage() ended without returning a message")

        self.assertIsNotNone(res)

        header, parser = res

        self.assertEqual(header.type, 21)
        self.assertEqual(header.version, (3, 3))
        self.assertEqual(header.length, 0)
        self.assertEqual(parser.bytes, bytearray(b'\xff\xff'))

        # a new recvMessage() generator must deliver the second message
        # (presumably still held by the defragmenter from the same record)
        res = None

        for res in msgSock.recvMessage():
            if res in (0, 1):
                self.fail("Blocking read")
            else:
                break
        else:
            self.fail("recvMessage() ended without returning a message")

        self.assertIsNotNone(res)

        header, parser = res

        self.assertEqual(header.type, 21)
        self.assertEqual(header.version, (3, 3))
        self.assertEqual(header.length, 0)
        self.assertEqual(parser.bytes, bytearray(b'\xbb\xbb'))
    def test_add_dynamic_size_with_incomplete_payload(self):
        """A dynamic-size message is held back until its payload is complete."""
        defrag = Defragmenter()
        defrag.add_dynamic_size(10, 2, 2)

        # header announces a 1-byte payload that hasn't arrived yet
        defrag.add_data(10, bytearray(b'\xee\xee\x00\x01'))
        self.assertIsNone(defrag.get_message())

        # the missing payload byte completes the message
        defrag.add_data(10, bytearray(b'\x99'))
        completed = defrag.get_message()
        msg_type, payload = completed

        self.assertEqual(10, msg_type)
        self.assertEqual(bytearray(b'\xee\xee\x00\x01\x99'), payload)
    def test_add_static_size_with_uncomplete_message(self):
        """A static-size message is returned only once all its bytes arrive."""
        defrag = Defragmenter()
        defrag.add_static_size(10, 2)

        # first byte of a 2-byte message: not deliverable yet
        defrag.add_data(10, bytearray(b'\x10'))
        self.assertIsNone(defrag.get_message())

        # second byte completes it
        defrag.add_data(10, bytearray(b'\x11'))
        complete = defrag.get_message()
        self.assertIsNotNone(complete)
        msg_type, payload = complete
        self.assertEqual(10, msg_type)
        self.assertEqual(bytearray(b'\x10\x11'), payload)

        # nothing left afterwards
        self.assertIsNone(defrag.get_message())
    def test_recvMessageBlocking(self):
        """recvMessageBlocking() returns a complete message.

        The mock socket is configured with ``blockEveryOther=True`` and
        ``maxRet=1``. The blocking call must still hand back the full
        ``(header, parser)`` pair rather than a blocking marker.
        """
        wire = bytearray(
            b'\x15'         # message type
            b'\x03\x03'     # TLS version
            b'\x00\x02'     # payload length
            b'\xff\xff')    # message

        defrag = Defragmenter()
        defrag.addStaticSize(21, 2)

        blocking_sock = MockSocket(wire, blockEveryOther=True, maxRet=1)
        msg_sock = MessageSocket(blocking_sock, defrag)

        received = msg_sock.recvMessageBlocking()
        self.assertIsNotNone(received)

        header, parser = received
        self.assertEqual(header.type, 21)
        self.assertEqual(parser.bytes, bytearray(b'\xff\xff'))
Example #11
0
    def process(self, state):
        """Connect to a server.

        Opens an IPv4 TCP connection to ``(self.hostname, self.port)`` with a
        5 second timeout. It then stores a ``MessageSocket`` for that
        connection in ``state.msg_sock``. The defragmenter handles:

        * alert messages (2 bytes);
        * change_cipher_spec messages (1 byte);
        * handshake messages (1 type byte, then a 3-byte length).

        The message socket's ``version`` is set to ``self.version``.

        :param state: connection state; its ``msg_sock`` attribute is replaced
        :raises socket.error: if the connection cannot be established; the
            socket is closed before the exception propagates
        """
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            sock.settimeout(5)
            sock.connect((self.hostname, self.port))
        except Exception:
            # don't leak the file descriptor when the connection fails
            sock.close()
            raise

        defragmenter = Defragmenter()
        defragmenter.addStaticSize(ContentType.alert, 2)
        defragmenter.addStaticSize(ContentType.change_cipher_spec, 1)
        defragmenter.addDynamicSize(ContentType.handshake, 1, 3)

        state.msg_sock = MessageSocket(sock, defragmenter)

        state.msg_sock.version = self.version
    def test_add_static_size(self):
        """A buffer holding exactly one static-size message yields it."""
        defrag = Defragmenter()
        defrag.add_static_size(10, 2)
        defrag.add_data(10, bytearray(b'\x03' * 2))

        result = defrag.get_message()
        self.assertIsNotNone(result)

        got_type, got_data = result
        self.assertEqual(10, got_type)
        self.assertEqual(bytearray(b'\x03\x03'), got_data)
    def test_addStaticSize_with_zero_size(self):
        """A zero-byte static message size is rejected with ValueError."""
        defrag = Defragmenter()
        self.assertRaises(ValueError, defrag.addStaticSize, 10, 0)
    def test_add_data_with_undefined_type(self):
        """Adding data for a type with no registered size raises ValueError."""
        defrag = Defragmenter()
        self.assertRaises(ValueError, defrag.add_data, 1, bytearray(10))
    def test_add_dynamic_size_with_invalid_offset(self):
        """A negative length-field offset is rejected with ValueError."""
        defrag = Defragmenter()
        self.assertRaises(ValueError, defrag.add_dynamic_size, 1, -1, 2)
    def test_add_dynamic_size_with_double_type(self):
        """Registering a dynamic size twice for one type raises ValueError."""
        defrag = Defragmenter()
        defrag.add_dynamic_size(1, 0, 1)

        # second registration of type 1, with different parameters
        self.assertRaises(ValueError, defrag.add_dynamic_size, 1, 2, 2)
    def test_add_static_size_with_invalid_size(self):
        """A negative static message size is rejected with ValueError."""
        defrag = Defragmenter()
        self.assertRaises(ValueError, defrag.add_static_size, 10, -10)
    def test_add_static_size_with_multiple_uncompleted_messages(self):
        """Only the type whose message becomes complete is returned.

        Types 10 (2 bytes) and 8 (4 bytes) both start with partial data.
        Completing type 8 must release exactly that message, while type 10,
        still incomplete, yields nothing.
        """
        defrag = Defragmenter()
        for msg_type, size in ((10, 2), (8, 4)):
            defrag.add_static_size(msg_type, size)

        defrag.add_data(8, bytearray(b'\x08' * 3))
        defrag.add_data(10, bytearray(b'\x10'))
        self.assertIsNone(defrag.get_message())

        # fourth byte completes the type 8 message
        defrag.add_data(8, bytearray(b'\x09'))

        result = defrag.get_message()
        self.assertIsNotNone(result)
        got_type, got_data = result
        self.assertEqual(8, got_type)
        self.assertEqual(bytearray(b'\x08\x08\x08\x09'), got_data)

        self.assertIsNone(defrag.get_message())
    def test_add_dynamic_size_with_incomplete_header(self):
        """A dynamic-size message needs its complete length field first.

        The 4-byte message (2 header bytes + a 2-byte length of zero) is fed
        one byte at a time. Nothing may be returned until the final byte
        arrives.
        """
        defrag = Defragmenter()
        defrag.add_dynamic_size(10, 2, 2)

        # header bytes and the first half of the length field
        for byte in bytearray(b'\xee\xee\x00'):
            defrag.add_data(10, bytearray([byte]))
            self.assertIsNone(defrag.get_message())

        # last byte completes the length field (value 0: no payload)
        defrag.add_data(10, bytearray(b'\x00'))

        result = defrag.get_message()
        self.assertIsNotNone(result)
        got_type, got_data = result
        self.assertEqual(10, got_type)
        self.assertEqual(bytearray(b'\xee\xee\x00\x00'), got_data)
    def test_addDynamicSize_with_invalid_size(self):
        """A zero-byte length field is rejected with ValueError."""
        defrag = Defragmenter()
        self.assertRaises(ValueError, defrag.addDynamicSize, 1, 2, 0)
    def test_add_static_size_with_multiple_types(self):
        """Complete messages come out in type-registration (priority) order."""
        defrag = Defragmenter()

        # registration order sets priority: type 10 is added before type 8,
        # so with both messages complete, type 10 must be returned first
        defrag.add_static_size(10, 2)
        defrag.add_static_size(8, 4)

        defrag.add_data(8, bytearray(b'\x08' * 4))
        defrag.add_data(10, bytearray(b'\x10' * 2))

        expected = ((10, b'\x10' * 2), (8, b'\x08' * 4))
        for want_type, want_data in expected:
            ret = defrag.get_message()
            self.assertIsNotNone(ret)
            msg_type, data = ret
            self.assertEqual(want_type, msg_type)
            self.assertEqual(bytearray(want_data), data)

        self.assertIsNone(defrag.get_message())
    def test_add_dynamic_size_with_two_streams(self):
        """Dynamic-size messages of two types are reassembled independently."""
        defrag = Defragmenter()
        defrag.add_dynamic_size(9, 0, 3)    # 3-byte length at offset 0
        defrag.add_dynamic_size(10, 2, 2)   # 2-byte length at offset 2

        # headers only: each announces a payload that hasn't arrived yet
        defrag.add_data(10, bytearray(b'\x44\x44\x00\x04'))
        defrag.add_data(9, bytearray(b'\x00\x00\x02'))
        self.assertIsNone(defrag.get_message())

        # payloads complete both messages
        defrag.add_data(9, bytearray(b'\x09' * 2))
        defrag.add_data(10, bytearray(b'\x10' * 4))

        first_type, first_data = defrag.get_message()
        self.assertEqual(first_type, 9)
        self.assertEqual(first_data, bytearray(b'\x00\x00\x02\x09\x09'))

        second_type, second_data = defrag.get_message()
        self.assertEqual(second_type, 10)
        self.assertEqual(second_data,
                         bytearray(b'\x44\x44\x00\x04' + b'\x10' * 4))
    def test_get_message(self):
        """A fresh Defragmenter returns None, also on repeated calls."""
        fresh = Defragmenter()

        for _ in range(2):
            self.assertIsNone(fresh.get_message())
Example #24
0
    def scan(self):
        """Perform a scan on server.

        Connects to ``(self.host, self.port)`` with a 5 second timeout and
        sends the ClientHello built by ``self.hello_gen``. It then collects
        the server's alerts and handshake messages until ServerHelloDone
        arrives or an error ends the exchange.

        :returns: a list holding the sent ClientHello, followed by the parsed
            server messages. When a handled error ends the exchange early,
            the caught exception is the last element. If the TCP connection
            itself cannot be made, the list contains only that socket error.
        """
        defragger = Defragmenter()
        defragger.addStaticSize(ContentType.change_cipher_spec, 1)
        defragger.addStaticSize(ContentType.alert, 2)
        defragger.addDynamicSize(ContentType.handshake, 1, 3)

        try:
            raw_sock = socket.create_connection((self.host, self.port), 5)
        except socket.error as e:
            return [e]

        # Everything after a successful connect runs inside this try/finally,
        # so the connection is released on every exit path. That covers the
        # early return after a failed send, which previously leaked the
        # socket, and any unexpected exception, e.g. from hello generation.
        try:
            sock = MessageSocket(raw_sock, defragger)

            if self.hostname is not None:
                client_hello = self.hello_gen(bytearray(self.hostname,
                                                        'utf-8'))
            else:
                client_hello = self.hello_gen(None)

            # record layer version - TLSv1.x
            # use the version from configuration, if present, or default to
            # the RFC recommended (3, 1) for TLS and (3, 0) for SSLv3
            if hasattr(client_hello, 'record_version'):
                sock.version = client_hello.record_version
            elif hasattr(self.hello_gen, 'record_version'):
                sock.version = self.hello_gen.record_version
            elif client_hello.client_version > (3, 1):  # TLS1.0
                sock.version = (3, 1)
            else:
                sock.version = client_hello.client_version

            # we don't want to send invalid messages (SSLv2 hello in SSL
            # record layer), so set the record layer version to SSLv2 if the
            # hello is of SSLv2 format
            if client_hello.ssl2:
                sock.version = (0, 2)

            # save the record version used in the end for later analysis
            client_hello.record_version = sock.version

            messages = [client_hello]

            handshake_parser = HandshakeParser()

            try:
                sock.sendMessageBlocking(client_hello)
            except (socket.error, TLSAbruptCloseError) as e:
                messages.append(e)
                return messages

            # get all the server messages that affect connection, abort as
            # soon as they've been read
            try:
                while True:
                    header, parser = sock.recvMessageBlocking()

                    if header.type == ContentType.alert:
                        alert = Alert()
                        alert.parse(parser)
                        alert.record_version = header.version
                        messages.append(alert)
                    elif header.type == ContentType.handshake:
                        msg = handshake_parser.parse(parser)
                        msg.record_version = header.version
                        messages.append(msg)
                        if isinstance(msg, ServerHelloDone):
                            return messages
                    else:
                        raise TypeError("Unknown content type: {0}"
                                        .format(header.type))
            except (TLSAbruptCloseError, TLSIllegalParameterException,
                    ValueError, TypeError, socket.error, SyntaxError) as e:
                messages.append(e)
                return messages
        finally:
            # best-effort close: a failure here must not mask the result
            try:
                raw_sock.close()
            except (socket.error, OSError):
                pass