def test_write_with_next_protos(self):
    """ServerHello.write() must serialise the advertised next protocols (NPN)."""
    hello = ServerHello().create(
        (1, 1),                             # server version
        bytearray(b'\x00' * 31 + b'\x02'),  # random
        bytearray(0),                       # session id
        4,                                  # cipher suite
        0,                                  # certificate type
        None,                               # TACK ext
        [b'spdy/3', b'http/1.1'])           # next protos advertised

    expected = b''.join([
        b'\x02',                 # type of message - server_hello
        b'\x00\x00\x3c',         # length - 60 bytes
        b'\x01\x01',             # proto version
        b'\x00' * 31 + b'\x02',  # random
        b'\x00',                 # session id length
        b'\x00\x04',             # cipher suite
        b'\x00',                 # compression method
        b'\x00\x14',             # extensions length - 20 bytes
        b'\x33\x74',             # ext type - NPN (13172)
        b'\x00\x10',             # ext length - 16 bytes
        b'\x06', b'spdy/3',      # first entry: length 6, then its encoding
        b'\x08', b'http/1.1',    # second entry: length 8, then its encoding
    ])

    self.assertEqual(list(bytearray(expected)), list(hello.write()))
예제 #2
0
def filter(packetNo, data, source, target):
    """Print the hello parameters found in packet 0, then forward the packet.

    When ``packetNo`` is 0 the payload is decoded as a ClientHello if
    ``str(source)`` contains 'Client2Server', and as a ServerHello otherwise.
    The offered (client) or selected (server) TLS version and cipher suite(s)
    are then printed. Every packet, including packet 0, is written to
    ``target`` unchanged and returned.

    :param packetNo: packet sequence number; only packet 0 is decoded
    :param data: raw packet contents, as accepted by stringToBytes()
    :param source: object whose str() identifies the traffic direction
    :param target: any object with a ``write`` method
    :returns: ``data``, unmodified
    """
    if packetNo == 0:
        # Convert only when needed: every other packet is forwarded as-is.
        # Skip the first 5 bytes (presumably the TLS record header) and the
        # 1-byte handshake type that precedes the hello body.
        raw = stringToBytes(data)
        parser = Parser(raw[5:])
        parser.get(1)

        if 'Client2Server' in str(source):
            client_hello = ClientHello()
            client_hello.parse(parser)
            # single-argument print(...) behaves the same on Python 2 and 3
            print(bcolors.OKGREEN + "Client supports TLS version: %s" %
                  str(client_hello.client_version))
            print("Client supports ciphersuites: %s" %
                  str([CIPHER_MAP.get(i, i)
                       for i in client_hello.cipher_suites]) +
                  bcolors.ENDC)
        else:
            server_hello = ServerHello()
            server_hello.parse(parser)
            print(bcolors.OKGREEN + "Server selected TLS version: %s" %
                  str(server_hello.server_version))
            print("Server selected ciphersuite: %s" %
                  str(CIPHER_MAP.get(server_hello.cipher_suite,
                                     server_hello.cipher_suite)) +
                  bcolors.ENDC)

    target.write(data)
    return data
    def test_parse_with_extensions_length_long_by_one(self):
        """An extensions length that overstates the data by one byte must
        make ServerHello.parse() raise SyntaxError."""
        # The handshake type byte (b'\x02', server_hello) is omitted: the
        # hello protocol layer consumes it before parse() is called.
        encoded = b''.join([
            b'\x00\x00\x36',         # length - 54 bytes
            b'\x03\x03',             # version - TLS 1.2
            b'\x01' * 31 + b'\x02',  # random
            b'\x00',                 # session id length
            b'\x00\x9d',             # cipher suite
            b'\x01',                 # compression method (zlib)
            b'\x00\x0f',             # extensions length - 15 bytes (!)
            b'\xff\x01',             # ext type - renegotiation_info
            b'\x00\x01',             # ext length - 1 byte
            b'\x00',                 # value - supported (0)
            b'\x00\x23',             # ext type - session ticket (35)
            b'\x00\x00',             # ext length - 0 bytes
            b'\x00\x0f',             # ext type - heartbeat (15)
            b'\x00\x01',             # ext length - 1 byte
            b'\x01',                 # peer allowed to send requests (1)
        ])
        parser = Parser(bytearray(encoded))
        hello = ServerHello()

        with self.assertRaises(SyntaxError) as caught:
            hello.parse(parser)

        # TODO the message could be more descriptive...
        self.assertIsNone(caught.exception.msg)
    def test_parse(self):
        """ServerHello.parse() must decode the fields of a TLS 1.2 hello."""
        # The handshake type byte (b'\x02', server_hello) is omitted: the
        # hello protocol layer consumes it before parse() is called.
        encoded = b''.join([
            b'\x00\x00\x36',         # length - 54 bytes
            b'\x03\x03',             # version - TLS 1.2
            b'\x01' * 31 + b'\x02',  # random
            b'\x00',                 # session id length
            b'\x00\x9d',             # cipher suite
            b'\x01',                 # compression method (zlib)
            b'\x00\x0e',             # extensions length - 14 bytes
            b'\xff\x01',             # ext type - renegotiation_info
            b'\x00\x01',             # ext length - 1 byte
            b'\x00',                 # value - supported (0)
            b'\x00\x23',             # ext type - session ticket (35)
            b'\x00\x00',             # ext length - 0 bytes
            b'\x00\x0f',             # ext type - heartbeat (15)
            b'\x00\x01',             # ext length - 1 byte
            b'\x01',                 # peer allowed to send requests (1)
        ])
        parser = Parser(bytearray(encoded))
        hello = ServerHello().parse(parser)

        expectations = [
            ((3, 3), 'server_version'),
            (bytearray(b'\x01' * 31 + b'\x02'), 'random'),
            (bytearray(0), 'session_id'),
            (157, 'cipher_suite'),
            # XXX the certificate type is not sent by the server!
            (CertificateType.x509, 'certificate_type'),
            (1, 'compression_method'),
            (None, 'tackExt'),
            (None, 'next_protos_advertised'),
        ]
        for expected, attribute in expectations:
            self.assertEqual(expected, getattr(hello, attribute))
    def test_process_with_not_matching_signature_algorithms(self):
        """A ServerKeyExchange made with 'sha1' must be rejected when only
        (sha256, rsa) is listed as a valid signature algorithm."""
        exp = ExpectServerKeyExchange(valid_sig_algs=[(HashAlgorithm.sha256,
                                                       SignatureAlgorithm.rsa)])

        state = ConnectionState()
        state.cipher = CipherSuite.TLS_DHE_RSA_WITH_AES_128_CBC_SHA

        chain = X509CertChain([X509().parse(srv_raw_certificate)])
        cert_msg = Certificate(CertificateType.x509).create(chain)
        key = parsePEMKey(srv_raw_key, private=True)

        client_hello = ClientHello()
        client_hello.client_version = (3, 3)
        client_hello.random = bytearray(32)

        server_hello = ServerHello()
        server_hello.server_version = (3, 3)
        server_hello.random = bytearray(32)

        state.client_random = client_hello.random
        state.server_random = server_hello.random
        # The ServerHello is deliberately left out of the handshake messages;
        # the test does not need it.
        state.handshake_messages.append(client_hello)
        state.handshake_messages.append(cert_msg)

        key_exchange = DHE_RSAKeyExchange(
            CipherSuite.TLS_DHE_RSA_WITH_AES_128_CBC_SHA,
            client_hello,
            server_hello,
            key)
        ske = key_exchange.makeServerKeyExchange('sha1')

        self.assertRaises(TLSIllegalParameterException,
                          exp.process, state, ske)
예제 #6
0
    def process(self, state, msg):
        """
        Process the message and update state accordingly

        @type state: ConnectionState
        @param state: overall state of TLS connection

        @type msg: Message
        @param msg: TLS Message read from socket

        @raise AssertionError: when the message is not a ServerHello, does
            not match the expected version or resumption parameters, lacks
            an expected extension, or carries an extension that was not
            expected
        """
        assert msg.contentType == ContentType.handshake

        parser = Parser(msg.write())
        hs_type = parser.get(1)
        assert hs_type == HandshakeType.server_hello

        srv_hello = ServerHello()
        srv_hello.parse(parser)

        # extract important info
        state.server_random = srv_hello.random

        # check for session_id based session resumption
        if self.resume:
            assert state.session_id == srv_hello.session_id
        # an echoed, non-empty session ID marks the connection as resuming,
        # so the previously negotiated cipher and version must be kept
        if (state.session_id == srv_hello.session_id and
                srv_hello.session_id != bytearray(0)):
            state.resuming = True
            assert state.cipher == srv_hello.cipher_suite
            assert state.version == srv_hello.server_version
        state.session_id = srv_hello.session_id

        if self.version is not None:
            assert self.version == srv_hello.server_version

        state.cipher = srv_hello.cipher_suite
        state.version = srv_hello.server_version

        # update the state of connection
        state.msg_sock.version = srv_hello.server_version

        state.handshake_messages.append(srv_hello)
        state.handshake_hashes.update(msg.write())

        # check if the message has expected values
        if self.extensions is not None:
            for ext_id in self.extensions:
                ext = srv_hello.getExtension(ext_id)
                # explicit raise so the failure names the missing extension
                if ext is None:
                    raise AssertionError(
                        "Required extension {0} missing".format(ext_id))
                # run extension-specific checker if present
                if self.extensions[ext_id] is not None:
                    self.extensions[ext_id](state, ext)
            # not supporting any extensions is valid
            if srv_hello.extensions is not None:
                for received in srv_hello.extensions:
                    if received.extType not in self.extensions:
                        raise AssertionError(
                            "unexpected extension: {0}".format(
                                received.extType))
    def test___str__(self):
        """str() of a ServerHello gives the compact one-line summary."""
        hello = ServerHello().create(
            (3, 0),
            bytearray(b'\x00' * 32),
            bytearray(b'\x01\x20'),
            34500,
            0,
            None,
            None)

        expected = ("server_hello,length(40),version(3.0),random(...),"
                    "session ID(bytearray(b'\\x01 ')),cipher(0x86c4),"
                    "compression method(0)")
        self.assertEqual(expected, str(hello))
    def test_parse_with_bad_cert_type_extension(self):
        """ServerHello.parse() must raise SyntaxError for a cert_type
        extension whose value lists two certificate types."""
        p = Parser(bytearray(
            b'\x00\x00\x2e' +               # length - 46 bytes
            b'\x03\x03' +                   # version - TLS 1.2
            b'\x01'*31 + b'\x02' +          # random
            b'\x00' +                       # session id length
            b'\x00\x9d' +                   # cipher suite
            b'\x00' +                       # compression method (none)
            b'\x00\x06' +                   # extensions length - 6 bytes
            b'\x00\x09' +                   # ext type - cert_type (9)
            b'\x00\x02' +                   # ext length - 2 bytes
            b'\x00\x01'                     # value - X.509 (0), OpenPGP (1)
            ))

        server_hello = ServerHello()
        with self.assertRaises(SyntaxError):
            server_hello.parse(p)
    def test_parse_with_length_long_by_one(self):
        """A declared body length one byte longer than the data must make
        ServerHello.parse() raise SyntaxError."""
        # The handshake type byte (b'\x02', server_hello) is omitted: the
        # hello protocol layer consumes it before parse() is called.
        fields = [
            b'\x00\x00\x27',         # length - 39 bytes (one long)
            b'\x03\x03',             # version - TLS 1.2
            b'\x01' * 31 + b'\x02',  # random
            b'\x00',                 # session id length
            b'\x00\x9d',             # cipher suite
            b'\x01',                 # compression method (zlib)
        ]
        parser = Parser(bytearray(b''.join(fields)))
        hello = ServerHello()

        with self.assertRaises(SyntaxError) as caught:
            hello.parse(parser)

        # TODO the message probably could be more descriptive...
        self.assertIsNone(caught.exception.msg)
    def test_write(self):
        """A ServerHello without extensions serialises to a 42-byte message."""
        hello = ServerHello().create(
            (1, 1),                             # server version
            bytearray(b'\x00' * 31 + b'\x02'),  # random
            bytearray(0),                       # session id
            4,                                  # cipher suite
            None,                               # certificate type
            None,                               # TACK ext
            None)                               # next protos advertised

        expected = bytearray(b''.join([
            b'\x02',                 # type of message - server_hello
            b'\x00\x00\x26',         # length - 38 bytes
            b'\x01\x01',             # proto version
            b'\x00' * 31 + b'\x02',  # random
            b'\x00',                 # session id length
            b'\x00\x04',             # cipher suite
            b'\x00',                 # compression method
        ]))
        self.assertEqual(list(expected), list(hello.write()))
 def test___repr__(self):
     """repr() of a ServerHello must match the expected field listing."""
     self.maxDiff = None
     hello = ServerHello().create((3, 0),
                                  bytearray(b'\x00' * 32),
                                  bytearray(0),
                                  34500,
                                  0,
                                  None,
                                  None,
                                  extensions=[])

     # the 32-byte all-zero random, as escaped in a bytearray repr
     zero_random = "\\x00" * 32
     expected = ("ServerHello(server_version=(3.0), " +
                 "random=bytearray(b'" + zero_random + "'), " +
                 "session_id=bytearray(b''), " +
                 "cipher_suite=34500, compression_method=0, _tack_ext=None, " +
                 "extensions=[])")
     self.assertEqual(expected, repr(hello))
    def test_process_with_mandatory_resumption_but_wrong_id(self):
        """With resume=True, a ServerHello carrying a session ID other than
        the one in the connection state must fail processing."""
        exp = ExpectServerHello(resume=True)

        state = ConnectionState()
        state.msg_sock = mock.MagicMock()
        state.session_id = bytearray(b'\xaa\xaa\xaa')
        state.cipher = 4
        self.assertFalse(state.resuming)

        msg = ServerHello()
        msg.create(version=(3, 3),
                   random=bytearray(32),
                   session_id=bytearray(b'\xbb\xbb\xbb'),
                   cipher_suite=4)
        self.assertTrue(exp.is_match(msg))

        self.assertRaises(AssertionError, exp.process, state, msg)
    def test_process_with_resumption(self):
        """A ServerHello echoing the stored non-empty session ID marks the
        connection as resuming."""
        session_id = b'\xaa\xaa\xaa'
        exp = ExpectServerHello()

        state = ConnectionState()
        state.msg_sock = mock.MagicMock()
        state.session_id = bytearray(session_id)
        state.cipher = 4
        self.assertFalse(state.resuming)

        msg = ServerHello()
        msg.create(version=(3, 3),
                   random=bytearray(32),
                   session_id=bytearray(session_id),
                   cipher_suite=4)
        self.assertTrue(exp.is_match(msg))

        exp.process(state, msg)

        self.assertTrue(state.resuming)
예제 #14
0
    def test_process_with_incorrect_cipher(self):
        """ExpectServerHello(cipher=5) must reject a ServerHello that selects
        cipher suite 4."""
        exp = ExpectServerHello(cipher=5)

        state = ConnectionState()
        state.msg_sock = mock.MagicMock()

        msg = ServerHello().create(version=(3, 3),
                                   random=bytearray(32),
                                   session_id=bytearray(0),
                                   cipher_suite=4)

        self.assertTrue(exp.is_match(msg))

        with self.assertRaises(AssertionError):
            exp.process(state, msg)
예제 #15
0
    def setUp(self):
        """Build a DHE_RSA key exchange around TLS 1.2 client/server hellos."""
        self.cipher_suite = CipherSuite.TLS_DHE_RSA_WITH_AES_128_CBC_SHA
        self.srv_private_key = parsePEMKey(srv_raw_key, private=True)
        end_entity = X509().parse(srv_raw_certificate)
        self.srv_pub_key = \
            X509CertChain([end_entity]).getEndEntityPublicKey()

        tls12 = (3, 3)
        self.client_hello = ClientHello().create(tls12,
                                                 bytearray(32),
                                                 bytearray(0),
                                                 [])
        self.server_hello = ServerHello().create(tls12,
                                                 bytearray(32),
                                                 bytearray(0),
                                                 self.cipher_suite)

        self.keyExchange = DHE_RSAKeyExchange(self.cipher_suite,
                                              self.client_hello,
                                              self.server_hello,
                                              self.srv_private_key)
예제 #16
0
    def test_signServerKeyExchange_in_TLS1_1(self):
        """Signing a DH ServerKeyExchange at TLS 1.1 gives the known bytes."""
        tls11 = (3, 2)
        suite = CipherSuite.TLS_DHE_RSA_WITH_AES_128_CBC_SHA
        key = parsePEMKey(srv_raw_key, private=True)
        server_hello = ServerHello().create(tls11,
                                            bytearray(32),
                                            bytearray(0),
                                            suite)
        key_exchange = KeyExchange(suite, ClientHello(), server_hello, key)
        ske = ServerKeyExchange(suite, tls11).createDH(5, 2, 3)

        key_exchange.signServerKeyExchange(ske)

        self.assertEqual(ske.write(), self.expected_tls1_1_SKE)
예제 #17
0
    def test_process_with_udefined_cipher(self):
        """A ServerHello selecting cipher 0xfff0, which the client never
        offered (it offered only 4), must fail processing."""
        exp = ExpectServerHello()

        state = ConnectionState()
        client_hello = ClientHello()
        client_hello.cipher_suites = [4]
        state.handshake_messages.append(client_hello)
        state.msg_sock = mock.MagicMock()

        msg = ServerHello().create(version=(3, 3),
                                   random=bytearray(32),
                                   session_id=bytearray(0),
                                   cipher_suite=0xfff0)

        self.assertTrue(exp.is_match(msg))

        with self.assertRaises(AssertionError):
            exp.process(state, msg)
예제 #18
0
    def test_process_with_unexpected_extensions(self):
        """An extension the expectation does not list (SNI here) must make
        processing fail."""
        exp = ExpectServerHello(
            extensions={ExtensionType.renegotiation_info: None})

        state = ConnectionState()
        state.msg_sock = mock.MagicMock()

        msg = ServerHello().create(
            version=(3, 3),
            random=bytearray(32),
            session_id=bytearray(0),
            cipher_suite=4,
            extensions=[RenegotiationInfoExtension().create(),
                        SNIExtension().create()])
        self.assertTrue(exp.is_match(msg))

        self.assertRaises(AssertionError, exp.process, state, msg)
예제 #19
0
    def test_process_with_extended_master_secret(self):
        """An expected extended_master_secret extension in the ServerHello
        switches state.extended_master_secret on."""
        ems_id = ExtensionType.extended_master_secret
        exp = ExpectServerHello(extensions={ems_id: None})

        state = ConnectionState()
        state.msg_sock = mock.MagicMock()
        self.assertFalse(state.extended_master_secret)

        msg = ServerHello().create(version=(3, 3),
                                   random=bytearray(32),
                                   session_id=bytearray(0),
                                   cipher_suite=4,
                                   extensions=[TLSExtension(extType=ems_id)])
        self.assertTrue(exp.is_match(msg))

        exp.process(state, msg)

        self.assertTrue(state.extended_master_secret)
예제 #20
0
    def test_parse_with_NPN_extension(self):
        """Protocol names in an NPN extension are parsed, in order, into
        next_protos."""
        body = b''.join([
            b'\x00\x00\x3c',         # length - 60 bytes
            b'\x03\x03',             # version - TLS 1.2
            b'\x01' * 31 + b'\x02',  # random
            b'\x00',                 # session id length
            b'\x00\x9d',             # cipher suite
            b'\x00',                 # compression method (none)
            b'\x00\x14',             # extensions length - 20 bytes
            b'\x33\x74',             # ext type - npn
            b'\x00\x10',             # ext length - 16 bytes
            b'\x08', b'http/1.1',    # first name, 8 bytes long
            b'\x06', b'spdy/3',      # second name, 6 bytes long
        ])

        hello = ServerHello().parse(Parser(bytearray(body)))

        expected = [bytearray(b'http/1.1'), bytearray(b'spdy/3')]
        self.assertEqual(expected, hello.next_protos)
예제 #21
0
    def test_create(self):
        """create() stores its positional arguments in the matching fields."""
        hello = ServerHello().create(
            (1, 1),                             # server version
            bytearray(b'\x00' * 31 + b'\x01'),  # random
            bytearray(0),                       # session id
            4,                                  # cipher suite
            1,                                  # certificate type
            None,                               # TACK ext
            None)                               # next protos advertised

        expectations = [
            ((1, 1), 'server_version'),
            (bytearray(b'\x00' * 31 + b'\x01'), 'random'),
            (bytearray(0), 'session_id'),
            (4, 'cipher_suite'),
            (CertificateType.openpgp, 'certificate_type'),
            (0, 'compression_method'),
            (None, 'tackExt'),
            (None, 'next_protos_advertised'),
        ]
        for expected, attribute in expectations:
            self.assertEqual(expected, getattr(hello, attribute))
예제 #22
0
    def setUp(self):
        """Build an ECDHE_RSA key exchange limited to the secp256r1 group."""
        self.cipher_suite = CipherSuite.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA
        self.srv_private_key = parsePEMKey(srv_raw_key, private=True)
        end_entity = X509().parse(srv_raw_certificate)
        self.srv_pub_key = \
            X509CertChain([end_entity]).getEndEntityPublicKey()

        tls12 = (3, 3)
        client_exts = [
            SupportedGroupsExtension().create([GroupName.secp256r1])]
        self.client_hello = ClientHello().create(tls12,
                                                 bytearray(32),
                                                 bytearray(0),
                                                 [],
                                                 extensions=client_exts)
        self.server_hello = ServerHello().create(tls12,
                                                 bytearray(32),
                                                 bytearray(0),
                                                 self.cipher_suite)

        self.keyExchange = ECDHE_RSAKeyExchange(self.cipher_suite,
                                                self.client_hello,
                                                self.server_hello,
                                                self.srv_private_key,
                                                [GroupName.secp256r1])
예제 #23
0
    def test_process_with_extensions(self):
        """A callable registered for an extension is invoked once with the
        connection state and the received extension."""
        checker = mock.MagicMock()
        exp = ExpectServerHello(
            extensions={ExtensionType.renegotiation_info: checker})

        state = ConnectionState()
        state.msg_sock = mock.MagicMock()

        reneg = RenegotiationInfoExtension().create()
        msg = ServerHello().create(version=(3, 3),
                                   random=bytearray(32),
                                   session_id=bytearray(0),
                                   cipher_suite=4,
                                   extensions=[reneg])
        self.assertTrue(exp.is_match(msg))

        exp.process(state, msg)

        checker.assert_called_once_with(state, reneg)
예제 #24
0
    def test_client_with_server_responing_with_wrong_session_id_in_TLS1_3(
            self):
        """A TLS 1.3 ServerHello echoing an unexpected session ID ("test")
        must make the client abort with an illegal_parameter alert."""
        # socket to generate the faux response
        gen_sock = MockSocket(bytearray(0))

        gen_record_layer = RecordLayer(gen_sock)
        gen_record_layer.version = (3, 3)

        srv_ext = []
        srv_ext.append(SrvSupportedVersionsExtension().create((3, 4)))
        srv_ext.append(ServerKeyShareExtension().create(KeyShareEntry().create(
            GroupName.secp256r1, bytearray(b'\x03' + b'\x01' * 32))))

        server_hello = ServerHello().create(
            version=(3, 3),
            random=bytearray(32),
            session_id=bytearray(b"test"),
            cipher_suite=CipherSuite.TLS_AES_128_GCM_SHA256,
            certificate_type=None,
            tackExt=None,
            next_protos_advertised=None,
            extensions=srv_ext)

        for res in gen_record_layer.sendRecord(server_hello):
            # 0 and 1 mean the write would block; the mock socket never should
            self.assertNotIn(res, (0, 1), "Blocking socket")
            break

        # test proper
        sock = MockSocket(gen_sock.sent[0])

        conn = TLSConnection(sock)

        with self.assertRaises(TLSLocalAlert) as err:
            conn.handshakeClientCert()

        self.assertEqual(err.exception.description,
                         AlertDescription.illegal_parameter)
예제 #25
0
    def test_process_with_bad_extension(self):
        """An extension expectation that is not None, not callable and not a
        TLSExtension (here a string) must raise ValueError."""
        exp = ExpectServerHello(extensions={
            ExtensionType.renegotiation_info: None,
            ExtensionType.alpn: 'BAD_EXTENSION',
        })

        state = ConnectionState()
        state.msg_sock = mock.MagicMock()

        msg = ServerHello().create(
            version=(3, 3),
            random=bytearray(32),
            session_id=bytearray(0),
            cipher_suite=4,
            extensions=[RenegotiationInfoExtension().create(None),
                        ALPNExtension().create([bytearray(b'http/1.1')])])
        self.assertTrue(exp.is_match(msg))

        self.assertRaises(ValueError, exp.process, state, msg)
예제 #26
0
    def test_process_with_matching_extension(self):
        """A received extension equal to the expected TLSExtension is accepted
        and the ServerHello is recorded in the handshake messages."""
        exp = ExpectServerHello(extensions={
            ExtensionType.renegotiation_info: None,
            ExtensionType.alpn:
                ALPNExtension().create([bytearray(b'http/1.1')]),
        })

        state = ConnectionState()
        state.msg_sock = mock.MagicMock()

        msg = ServerHello().create(
            version=(3, 3),
            random=bytearray(32),
            session_id=bytearray(0),
            cipher_suite=4,
            extensions=[RenegotiationInfoExtension().create(None),
                        ALPNExtension().create([bytearray(b'http/1.1')])])
        self.assertTrue(exp.is_match(msg))

        exp.process(state, msg)

        first_message = state.handshake_messages[0]
        self.assertIsInstance(first_message, ServerHello)
예제 #27
0
파일: expect.py 프로젝트: miradam/tlsfuzzer
    def process(self, state, msg):
        """
        Process the message and update state accordingly

        Parses the ServerHello, checks it against the expectations configured
        on this object (resumption, version, cipher, extensions), records it
        in the handshake transcript and updates the connection state
        (random, session ID, cipher, version, EMS/EtM flags).

        @type state: ConnectionState
        @param state: overall state of TLS connection

        @type msg: Message
        @param msg: TLS Message read from socket

        @raise AssertionError: when the message does not match expectations:
            wrong version, cipher or resumption parameters, a required
            extension missing, an expected extension not equal to the
            received one, or an unexpected extension present
        @raise ValueError: when an extension expectation is neither None,
            a callable, nor a TLSExtension instance
        """
        assert msg.contentType == ContentType.handshake

        parser = Parser(msg.write())
        # first byte of the handshake message is its type
        hs_type = parser.get(1)
        assert hs_type == HandshakeType.server_hello

        srv_hello = ServerHello()
        srv_hello.parse(parser)

        # extract important info
        state.server_random = srv_hello.random

        # check for session_id based session resumption
        if self.resume:
            assert state.session_id == srv_hello.session_id
        # a non-empty session ID equal to the stored one means the server is
        # resuming, so cipher and version must match the previous session
        if (state.session_id == srv_hello.session_id
                and srv_hello.session_id != bytearray(0)):
            state.resuming = True
            assert state.cipher == srv_hello.cipher_suite
            assert state.version == srv_hello.server_version
        state.session_id = srv_hello.session_id

        if self.version is not None:
            assert self.version == srv_hello.server_version

        if self.cipher is not None:
            assert self.cipher == srv_hello.cipher_suite

        state.cipher = srv_hello.cipher_suite
        state.version = srv_hello.server_version

        # update the state of connection
        state.msg_sock.version = srv_hello.server_version

        state.handshake_messages.append(srv_hello)
        state.handshake_hashes.update(msg.write())

        # Reset value of the session-wide settings
        # NOTE(review): the flags below are set to True again only inside the
        # `self.extensions is not None` branch, so when no extension
        # expectations are configured they stay False even if the server
        # sent the extension - confirm this is intended.
        state.extended_master_secret = False
        state.encrypt_then_mac = False

        # check if the message has expected values
        if self.extensions is not None:
            for ext_id in self.extensions:
                ext = srv_hello.getExtension(ext_id)
                if ext is None:
                    raise AssertionError(
                        "Required extension {0} missing".format(
                            ExtensionType.toStr(ext_id)))
                # run extension-specific checker if present
                if self.extensions[ext_id] is not None:
                    # a callable expectation is invoked as checker(state, ext)
                    if callable(self.extensions[ext_id]):
                        self.extensions[ext_id](state, ext)
                    # a TLSExtension expectation must compare equal to the
                    # received extension; `not ==` is used rather than `!=`,
                    # presumably because TLSExtension may not define __ne__
                    # (Python 2 would then compare identity) - confirm
                    elif isinstance(self.extensions[ext_id], TLSExtension):
                        if not self.extensions[ext_id] == ext:
                            raise AssertionError(
                                "Expected extension "
                                "not matched, received: {0}".format(ext))
                    else:
                        raise ValueError(
                            "Bad extension, id: {0}".format(ext_id))
                    # NOTE(review): this `continue` skips the EMS/EtM flag
                    # updates below, so an extended_master_secret or
                    # encrypt_then_mac expectation given as a checker or
                    # TLSExtension leaves the flag False - verify intent.
                    continue
                # expected with no checker: record negotiated features
                if ext_id == ExtensionType.extended_master_secret:
                    state.extended_master_secret = True
                if ext_id == ExtensionType.encrypt_then_mac:
                    state.encrypt_then_mac = True
            # not supporting any extensions is valid
            if srv_hello.extensions is not None:
                # every extension the server sent must have been expected
                for ext_id in (ext.extType for ext in srv_hello.extensions):
                    if ext_id not in self.extensions:
                        raise AssertionError(
                            "unexpected extension: {0}".format(
                                ExtensionType.toStr(ext_id)))
예제 #28
0
    def test(self):
        """Check the TLS 1.3 key schedule against a recorded handshake.

        Rebuilds the ClientHello and checks that it serialises to
        ``client_hello_ciphertext`` minus its first 5 bytes (presumably the
        record-layer header), reads the recorded ServerHello from a mock
        socket, and then walks the key schedule step by step - early secret,
        handshake secret, master secret and the traffic secrets, keys, IVs
        and Finished value derived from them - asserting every intermediate
        value against a hard-coded expectation.

        NOTE(review): the client random, key share and expected secrets look
        like the RFC 8448 "Simple 1-RTT Handshake" trace - confirm.
        """

        sock = MockSocket(server_hello_ciphertext)

        record_layer = RecordLayer(sock)

        # ClientHello extensions; their order matters because the serialised
        # message is compared byte-for-byte with the recorded one below.
        ext = [
            SNIExtension().create(bytearray(b'server')),
            TLSExtension(extType=ExtensionType.renegotiation_info).create(
                bytearray(b'\x00')),
            SupportedGroupsExtension().create([
                GroupName.x25519, GroupName.secp256r1, GroupName.secp384r1,
                GroupName.secp521r1, GroupName.ffdhe2048, GroupName.ffdhe3072,
                GroupName.ffdhe4096, GroupName.ffdhe6144, GroupName.ffdhe8192
            ]),
            TLSExtension(extType=35),
            ClientKeyShareExtension().create([
                KeyShareEntry().create(GroupName.x25519, client_key_public,
                                       client_key_private)
            ]),
            SupportedVersionsExtension().create([(3, 4)]),
            SignatureAlgorithmsExtension().create([
                SignatureScheme.ecdsa_secp256r1_sha256,
                SignatureScheme.ecdsa_secp384r1_sha384,
                SignatureScheme.ecdsa_secp521r1_sha512,
                (HashAlgorithm.sha1, SignatureAlgorithm.ecdsa),
                SignatureScheme.rsa_pss_rsae_sha256,
                SignatureScheme.rsa_pss_rsae_sha384,
                SignatureScheme.rsa_pss_rsae_sha512,
                SignatureScheme.rsa_pkcs1_sha256,
                SignatureScheme.rsa_pkcs1_sha384,
                SignatureScheme.rsa_pkcs1_sha512,
                SignatureScheme.rsa_pkcs1_sha1,
                (HashAlgorithm.sha256, SignatureAlgorithm.dsa),
                (HashAlgorithm.sha384, SignatureAlgorithm.dsa),
                (HashAlgorithm.sha512, SignatureAlgorithm.dsa),
                (HashAlgorithm.sha1, SignatureAlgorithm.dsa)
            ]),
            # NOTE(review): type 45 is presumably psk_key_exchange_modes,
            # sent as raw bytes - confirm
            TLSExtension(extType=45).create(bytearray(b'\x01\x01')),
            RecordSizeLimitExtension().create(16385)
        ]
        client_hello = ClientHello()
        client_hello.create((3, 3),
                            bytearray(b'\xcb4\xec\xb1\xe7\x81c'
                                      b'\xba\x1c8\xc6\xda\xcb'
                                      b'\x19jm\xff\xa2\x1a\x8d'
                                      b'\x99\x12\xec\x18\xa2'
                                      b'\xefb\x83\x02M\xec\xe7'),
                            bytearray(b''), [
                                CipherSuite.TLS_AES_128_GCM_SHA256,
                                CipherSuite.TLS_CHACHA20_POLY1305_SHA256,
                                CipherSuite.TLS_AES_256_GCM_SHA384
                            ],
                            extensions=ext)

        # the rebuilt ClientHello must match the recorded bytes exactly
        self.assertEqual(client_hello.write(), client_hello_ciphertext[5:])

        # read only the first record from the mock socket; it must be
        # delivered without a 0/1 "would block" status
        for result in record_layer.recvRecord():
            # check if non-blocking
            self.assertNotIn(result, (0, 1))
            break

        header, parser = result
        hs_type = parser.get(1)
        self.assertEqual(hs_type, HandshakeType.server_hello)
        server_hello = ServerHello().parse(parser)

        self.assertEqual(server_hello.server_version, (3, 3))
        self.assertEqual(server_hello.cipher_suite,
                         CipherSuite.TLS_AES_128_GCM_SHA256)

        server_key_share = server_hello.getExtension(ExtensionType.key_share)
        server_key_share = server_key_share.server_share

        self.assertEqual(server_key_share.group, GroupName.x25519)

        # for TLS_AES_128_GCM_SHA256:
        prf_name = 'sha256'
        prf_size = 256 // 8
        secret = bytearray(prf_size)
        # no PSK is in use, so an all-zero value of hash length is used
        psk = bytearray(prf_size)

        # early secret
        secret = secureHMAC(secret, psk, prf_name)

        self.assertEqual(
            secret,
            clean("""
                         33 ad 0a 1c 60 7e c0 3b 09 e6 cd 98 93 68 0c
                         e2 10 ad f3 00 aa 1f 26 60 e1 b2 2e 10 f1 70 f9 2a
                         """))

        # derive secret for handshake
        secret = derive_secret(secret, b"derived", None, prf_name)

        self.assertEqual(
            secret,
            clean("""
                         6f 26 15 a1 08 c7 02 c5 67 8f 54 fc 9d ba
                         b6 97 16 c0 76 18 9c 48 25 0c eb ea c3 57 6c 36 11 ba
                         """))

        # x25519 shared secret; it is the input keying material for the
        # "handshake" secret extraction below
        Z = x25519(client_key_private, server_key_share.key_exchange)

        self.assertEqual(
            Z,
            clean("""
                         8b d4 05 4f b5 5b 9d 63 fd fb ac f9 f0 4b 9f 0d
                         35 e6 d6 3f 53 75 63 ef d4 62 72 90 0f 89 49 2d
                         """))

        # extract secret "handshake"
        secret = secureHMAC(secret, Z, prf_name)

        self.assertEqual(
            secret,
            clean("""
                         1d c8 26 e9 36 06 aa 6f dc 0a ad c1 2f 74 1b
                         01 04 6a a6 b9 9f 69 1e d2 21 a9 f0 ca 04 3f be ac
                         """))

        # transcript so far: ClientHello + ServerHello
        handshake_hashes = HandshakeHashes()
        handshake_hashes.update(client_hello_plaintext)
        handshake_hashes.update(server_hello_payload)

        # derive "tls13 c hs traffic"
        c_hs_traffic = derive_secret(secret, bytearray(b'c hs traffic'),
                                     handshake_hashes, prf_name)
        self.assertEqual(
            c_hs_traffic,
            clean("""
                         b3 ed db 12 6e 06 7f 35 a7 80 b3 ab f4 5e
                         2d 8f 3b 1a 95 07 38 f5 2e 96 00 74 6a 0e 27 a5 5a 21
                         """))
        # derive "tls13 s hs traffic"
        s_hs_traffic = derive_secret(secret, bytearray(b's hs traffic'),
                                     handshake_hashes, prf_name)
        self.assertEqual(
            s_hs_traffic,
            clean("""
                         b6 7b 7d 69 0c c1 6c 4e 75 e5 42 13 cb 2d
                         37 b4 e9 c9 12 bc de d9 10 5d 42 be fd 59 d3 91 ad 38
                         """))

        # derive the "derived" salt used to extract the master secret
        secret = derive_secret(secret, b"derived", None, prf_name)

        self.assertEqual(
            secret,
            clean("""
                         43 de 77 e0 c7 77 13 85 9a 94 4d b9 db 25
                         90 b5 31 90 a6 5b 3e e2 e4 f1 2d d7 a0 bb 7c e2 54 b4
                         """))

        # extract secret "master" (input keying material is all zeros)
        secret = secureHMAC(secret, bytearray(prf_size), prf_name)

        self.assertEqual(
            secret,
            clean("""
                         18 df 06 84 3d 13 a0 8b f2 a4 49 84 4c 5f 8a
                         47 80 01 bc 4d 4c 62 79 84 d5 a4 1d a8 d0 40 29 19
                         """))

        # derive write keys for handshake data (server -> client direction)
        server_hs_write_trafic_key = HKDF_expand_label(s_hs_traffic, b"key",
                                                       b"", 16, prf_name)

        self.assertEqual(
            server_hs_write_trafic_key,
            clean("""
                         3f ce 51 60 09 c2 17 27 d0 f2 e4 e8 6e
                         e4 03 bc
                         """))

        server_hs_write_trafic_iv = HKDF_expand_label(s_hs_traffic, b"iv", b"",
                                                      12, prf_name)

        self.assertEqual(
            server_hs_write_trafic_iv,
            clean("""
                         5d 31 3e b2 67 12 76 ee 13 00 0b 30
                         """))

        # derive key for the server's Finished message
        server_finished_key = HKDF_expand_label(s_hs_traffic, b"finished", b"",
                                                prf_size, prf_name)
        self.assertEqual(
            server_finished_key,
            clean("""
                         00 8d 3b 66 f8 16 ea 55 9f 96 b5 37 e8 85
                         c3 1f c0 68 bf 49 2c 65 2f 01 f2 88 a1 d8 cd c1 9f c8
                         """))

        # Update the handshake transcript with the rest of the server's
        # flight (everything up to, but not including, its Finished)
        handshake_hashes.update(server_encrypted_extensions)
        handshake_hashes.update(server_certificate_message)
        handshake_hashes.update(server_certificateverify_message)
        hs_transcript = handshake_hashes.digest(prf_name)

        # Finished verify_data: HMAC of the transcript hash
        server_finished = secureHMAC(server_finished_key, hs_transcript,
                                     prf_name)

        self.assertEqual(
            server_finished,
            clean("""
                         9b 9b 14 1d 90 63 37 fb d2 cb dc e7 1d f4
                         de da 4a b4 2c 30 95 72 cb 7f ff ee 54 54 b7 8f 07 18
                         """))

        server_finished_message = Finished((3, 4)).create(server_finished)
        server_finished_payload = server_finished_message.write()

        # update handshake transcript to include Finished payload
        handshake_hashes.update(server_finished_payload)

        # derive keys for client application traffic
        c_ap_traffic = derive_secret(secret, b"c ap traffic", handshake_hashes,
                                     prf_name)

        self.assertEqual(
            c_ap_traffic,
            clean("""
                         9e 40 64 6c e7 9a 7f 9d c0 5a f8 88 9b ce
                         65 52 87 5a fa 0b 06 df 00 87 f7 92 eb b7 c1 75 04 a5
                         """))

        # derive keys for server application traffic
        s_ap_traffic = derive_secret(secret, b"s ap traffic", handshake_hashes,
                                     prf_name)

        self.assertEqual(
            s_ap_traffic,
            clean("""
                         a1 1a f9 f0 55 31 f8 56 ad 47 11 6b 45 a9
                         50 32 82 04 b4 f4 4b fb 6b 3a 4b 4f 1f 3f cb 63 16 43
                         """))

        # derive exporter master secret
        exp_master = derive_secret(secret, b"exp master", handshake_hashes,
                                   prf_name)

        self.assertEqual(
            exp_master,
            clean("""
                         fe 22 f8 81 17 6e da 18 eb 8f 44 52 9e 67
                         92 c5 0c 9a 3f 89 45 2f 68 d8 ae 31 1b 43 09 d3 cf 50
                         """))

        # derive write traffic keys for app data (server -> client)
        server_write_traffic_key = HKDF_expand_label(s_ap_traffic, b"key", b"",
                                                     16, prf_name)

        self.assertEqual(
            server_write_traffic_key,
            clean("""
                         9f 02 28 3b 6c 9c 07 ef c2 6b b9 f2 ac
                         92 e3 56
                         """))

        server_write_traffic_iv = HKDF_expand_label(s_ap_traffic, b"iv", b"",
                                                    12, prf_name)

        self.assertEqual(
            server_write_traffic_iv,
            clean("""
                         cf 78 2b 88 dd 83 54 9a ad f1 e9 84
                         """))

        # derive the server's read key and IV for *handshake* data, i.e. the
        # client handshake traffic key/IV (expanded from c_hs_traffic)
        server_read_hs_key = HKDF_expand_label(c_hs_traffic, b"key", b"", 16,
                                               prf_name)

        self.assertEqual(
            server_read_hs_key,
            clean("""
                         db fa a6 93 d1 76 2c 5b 66 6a f5 d9 50
                         25 8d 01
                         """))

        server_read_hs_iv = HKDF_expand_label(c_hs_traffic, b"iv", b"", 12,
                                              prf_name)

        self.assertEqual(
            server_read_hs_iv,
            clean("""
                         5b d3 c7 1b 83 6e 0b 76 bb 73 26 5f
                         """))
    def test_full_connection_with_RSA_kex(self):
        """Run a complete TLS 1.2 RSA-key-exchange handshake over a socketpair.

        Both endpoints are driven step by step from this single test: the
        client (``record_layer``) and the server (``srv_record_layer``)
        alternately send and receive handshake messages, each derives the
        master secret on its own, each verifies the peer's Finished message,
        and finally a short piece of application data is sent from the client
        to the server.
        """

        clnt_sock, srv_sock = socket.socketpair()
        # Release both ends of the socket pair even when an assertion fails
        # half-way through the handshake; closing an already closed socket
        # is harmless, so this also works after the close() calls at the end.
        self.addCleanup(clnt_sock.close)
        self.addCleanup(srv_sock.close)

        def drive(generator):
            """Run a record-layer generator and return its first result.

            The sockets are blocking, so a 0/1 "would block" status is never
            expected and is treated as a failure. Returns None if the
            generator finishes without yielding anything.
            """
            for result in generator:
                if result in (0, 1):
                    raise Exception("blocking socket")
                return result
            return None

        #
        # client part
        #
        record_layer = TLSRecordLayer(clnt_sock)

        record_layer._handshakeStart(client=True)
        record_layer.version = (3, 3)

        client_hello = ClientHello()
        client_hello = client_hello.create(
            (3, 3), bytearray(32), bytearray(0),
            [CipherSuite.TLS_RSA_WITH_AES_128_CBC_SHA], None, None, False,
            False, None)

        drive(record_layer._sendMsg(client_hello))

        #
        # server part
        #

        srv_record_layer = TLSRecordLayer(srv_sock)

        srv_raw_certificate = str(
            "-----BEGIN CERTIFICATE-----\n"
            "MIIB9jCCAV+gAwIBAgIJAMyn9DpsTG55MA0GCSqGSIb3DQEBCwUAMBQxEjAQBgNV\n"
            "BAMMCWxvY2FsaG9zdDAeFw0xNTAxMjExNDQzMDFaFw0xNTAyMjAxNDQzMDFaMBQx\n"
            "EjAQBgNVBAMMCWxvY2FsaG9zdDCBnzANBgkqhkiG9w0BAQEFAAOBjQAwgYkCgYEA\n"
            "0QkEeakSyV/LMtTeARdRtX5pdbzVuUuqOIdz3lg7YOyRJ/oyLTPzWXpKxr//t4FP\n"
            "QvYsSJiVOlPk895FNu6sNF/uJQyQGfFWYKkE6fzFifQ6s9kssskFlL1DVI/dD/Zn\n"
            "7sgzua2P1SyLJHQTTs1MtMb170/fX2EBPkDz+2kYKN0CAwEAAaNQME4wHQYDVR0O\n"
            "BBYEFJtvXbRmxRFXYVMOPH/29pXCpGmLMB8GA1UdIwQYMBaAFJtvXbRmxRFXYVMO\n"
            "PH/29pXCpGmLMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQELBQADgYEAkOgC7LP/\n"
            "Rd6uJXY28HlD2K+/hMh1C3SRT855ggiCMiwstTHACGgNM+AZNqt6k8nSfXc6k1gw\n"
            "5a7SGjzkWzMaZC3ChBeCzt/vIAGlMyXeqTRhjTCdc/ygRv3NPrhUKKsxUYyXRk5v\n"
            "g/g6MwxzXfQP3IyFu3a9Jia/P89Z1rQCNRY=\n"
            "-----END CERTIFICATE-----\n"
            )

        srv_raw_key = str(
            "-----BEGIN RSA PRIVATE KEY-----\n"
            "MIICXQIBAAKBgQDRCQR5qRLJX8sy1N4BF1G1fml1vNW5S6o4h3PeWDtg7JEn+jIt\n"
            "M/NZekrGv/+3gU9C9ixImJU6U+Tz3kU27qw0X+4lDJAZ8VZgqQTp/MWJ9Dqz2Syy\n"
            "yQWUvUNUj90P9mfuyDO5rY/VLIskdBNOzUy0xvXvT99fYQE+QPP7aRgo3QIDAQAB\n"
            "AoGAVSLbE8HsyN+fHwDbuo4I1Wa7BRz33xQWLBfe9TvyUzOGm0WnkgmKn3LTacdh\n"
            "GxgrdBZXSun6PVtV8I0im5DxyVaNdi33sp+PIkZU386f1VUqcnYnmgsnsUQEBJQu\n"
            "fUZmgNM+bfR+Rfli4Mew8lQ0sorZ+d2/5fsM0g80Qhi5M3ECQQDvXeCyrcy0u/HZ\n"
            "FNjIloyXaAIvavZ6Lc6gfznCSfHc5YwplOY7dIWp8FRRJcyXkA370l5dJ0EXj5Gx\n"
            "udV9QQ43AkEA34+RxjRk4DT7Zo+tbM/Fkoi7jh1/0hFkU5NDHweJeH/mJseiHtsH\n"
            "KOcPGtEGBBqT2KNPWVz4Fj19LiUmmjWXiwJBAIBs49O5/+ywMdAAqVblv0S0nweF\n"
            "4fwne4cM+5ZMSiH0XsEojGY13EkTEon/N8fRmE8VzV85YmkbtFWgmPR85P0CQQCs\n"
            "elWbN10EZZv3+q1wH7RsYzVgZX3yEhz3JcxJKkVzRCnKjYaUi6MweWN76vvbOq4K\n"
            "G6Tiawm0Duh/K4ZmvyYVAkBppE5RRQqXiv1KF9bArcAJHvLm0vnHPpf1yIQr5bW6\n"
            "njBuL4qcxlaKJVGRXT7yFtj2fj0gv3914jY2suWqp8XJ\n"
            "-----END RSA PRIVATE KEY-----\n"
            )

        srv_private_key = parsePEMKey(srv_raw_key, private=True)
        srv_cert_chain = X509CertChain([X509().parse(srv_raw_certificate)])

        srv_record_layer._handshakeStart(client=False)

        srv_record_layer.version = (3, 3)

        srv_client_hello = drive(srv_record_layer._getMsg(
            ContentType.handshake, HandshakeType.client_hello))
        self.assertEqual(ClientHello, type(srv_client_hello))

        srv_cipher_suite = CipherSuite.TLS_RSA_WITH_AES_128_CBC_SHA
        srv_session_id = bytearray(0)

        srv_server_hello = ServerHello().create(
            (3, 3), bytearray(32), srv_session_id, srv_cipher_suite,
            CertificateType.x509, None, None)

        # server's first flight: ServerHello, Certificate, ServerHelloDone
        srv_msgs = [
            srv_server_hello,
            Certificate(CertificateType.x509).create(srv_cert_chain),
            ServerHelloDone(),
        ]
        drive(srv_record_layer._sendMsgs(srv_msgs))
        srv_record_layer._versionCheck = True

        #
        # client part
        #

        server_hello = drive(record_layer._getMsg(
            ContentType.handshake, HandshakeType.server_hello))
        self.assertEqual(ServerHello, type(server_hello))

        server_certificate = drive(record_layer._getMsg(
            ContentType.handshake, HandshakeType.certificate,
            CertificateType.x509))
        self.assertEqual(Certificate, type(server_certificate))

        server_hello_done = drive(record_layer._getMsg(
            ContentType.handshake, HandshakeType.server_hello_done))
        self.assertEqual(ServerHelloDone, type(server_hello_done))

        public_key = server_certificate.cert_chain.getEndEntityPublicKey()

        # premaster secret: the two version bytes of the negotiated protocol
        # followed by 46 bytes left as zeros to keep the test deterministic
        premasterSecret = bytearray(48)
        premasterSecret[0] = 3  # (3, 3) because we negotiated TLSv1.2
        premasterSecret[1] = 3

        encryptedPreMasterSecret = public_key.encrypt(premasterSecret)

        client_key_exchange = ClientKeyExchange(
            CipherSuite.TLS_RSA_WITH_AES_128_CBC_SHA, (3, 3))
        client_key_exchange.createRSA(encryptedPreMasterSecret)

        drive(record_layer._sendMsg(client_key_exchange))

        master_secret = calc_key((3, 3),
                                 premasterSecret,
                                 CipherSuite.TLS_RSA_WITH_AES_128_CBC_SHA,
                                 b"master secret",
                                 client_random=client_hello.random,
                                 server_random=server_hello.random,
                                 output_length=48)

        record_layer._calcPendingStates(
            CipherSuite.TLS_RSA_WITH_AES_128_CBC_SHA, master_secret,
            client_hello.random, server_hello.random, None)

        drive(record_layer._sendMsg(ChangeCipherSpec()))

        record_layer._changeWriteState()

        # client Finished covers the transcript up to this point
        handshake_hashes = record_layer._handshake_hash.digest('sha256')
        verify_data = PRF_1_2(master_secret, b'client finished',
                              handshake_hashes, 12)

        finished = Finished((3, 3)).create(verify_data)
        drive(record_layer._sendMsg(finished))

        #
        # server part
        #

        srv_client_key_exchange = drive(srv_record_layer._getMsg(
            ContentType.handshake, HandshakeType.client_key_exchange,
            srv_cipher_suite))

        srv_premaster_secret = srv_private_key.decrypt(
            srv_client_key_exchange.encryptedPreMasterSecret)

        self.assertEqual(bytearray(b'\x03\x03' + b'\x00' * 46),
                         srv_premaster_secret)

        srv_master_secret = calc_key(srv_record_layer.version,
                                     srv_premaster_secret,
                                     CipherSuite.TLS_RSA_WITH_AES_128_CBC_SHA,
                                     b"master secret",
                                     client_random=srv_client_hello.random,
                                     server_random=srv_server_hello.random,
                                     output_length=48)

        srv_record_layer._calcPendingStates(srv_cipher_suite,
                                            srv_master_secret,
                                            srv_client_hello.random,
                                            srv_server_hello.random, None)

        srv_change_cipher_spec = drive(
            srv_record_layer._getMsg(ContentType.change_cipher_spec))
        self.assertEqual(ChangeCipherSpec, type(srv_change_cipher_spec))

        srv_record_layer._changeReadState()

        # expected client verify_data, computed before the client Finished
        # is read so the transcript does not yet include it
        srv_handshakeHashes = srv_record_layer._handshake_hash.digest('sha256')
        srv_verify_data = PRF_1_2(srv_master_secret, b"client finished",
                                  srv_handshakeHashes, 12)

        srv_finished = drive(srv_record_layer._getMsg(
            ContentType.handshake, HandshakeType.finished))
        self.assertEqual(Finished, type(srv_finished))
        self.assertEqual(srv_verify_data, srv_finished.verify_data)

        drive(srv_record_layer._sendMsg(ChangeCipherSpec()))

        srv_record_layer._changeWriteState()

        srv_handshakeHashes = srv_record_layer._handshake_hash.digest('sha256')
        srv_verify_data = PRF_1_2(srv_master_secret, b"server finished",
                                  srv_handshakeHashes, 12)

        drive(srv_record_layer._sendMsg(
            Finished((3, 3)).create(srv_verify_data)))

        srv_record_layer._handshakeDone(resumed=False)

        #
        # client part
        #

        change_cipher_spec = drive(
            record_layer._getMsg(ContentType.change_cipher_spec))
        self.assertEqual(ChangeCipherSpec, type(change_cipher_spec))

        record_layer._changeReadState()

        # expected server verify_data, computed before reading its Finished
        handshake_hashes = record_layer._handshake_hash.digest('sha256')
        server_verify_data = PRF_1_2(master_secret, b'server finished',
                                     handshake_hashes, 12)

        server_finished = drive(record_layer._getMsg(
            ContentType.handshake, HandshakeType.finished))
        self.assertEqual(Finished, type(server_finished))
        self.assertEqual(server_verify_data, server_finished.verify_data)

        record_layer._handshakeDone(resumed=False)

        # try sending data
        record_layer.write(bytearray(b'text\n'))

        # try receiving data
        data = srv_record_layer.read(10)
        self.assertEqual(data, bytearray(b'text\n'))

        record_layer.close()
        srv_record_layer.close()
예제 #30
0
    def process(self, state, msg):
        """
        Parse the received ServerHello and fold it into the connection state.

        The message is checked against the expectations configured on this
        object (resumption, protocol version, cipher) and against the last
        ClientHello recorded in the state; the negotiated parameters are then
        stored in the state and the record layer is switched to them.

        @type state: ConnectionState
        @param state: overall state of TLS connection

        @type msg: Message
        @param msg: TLS Message read from socket

        @rtype: ServerHello
        @return: the parsed message
        """
        assert msg.contentType == ContentType.handshake

        reader = Parser(msg.write())
        # read the type byte outside of the assert so it is consumed even
        # when assertions are disabled
        message_type = reader.get(1)
        assert message_type == HandshakeType.server_hello

        server_hello = ServerHello()
        server_hello.parse(reader)

        state.server_random = server_hello.random

        # session ID based resumption (TLS 1.3 resumes through PSKs instead)
        if self.resume:
            assert state.session_id == server_hello.session_id
        resumed_by_id = (state.session_id == server_hello.session_id and
                         server_hello.session_id != bytearray(0) and
                         self._extract_version(server_hello) < (3, 4))
        if resumed_by_id:
            state.resuming = True
            assert state.cipher == server_hello.cipher_suite
            assert state.version == self._extract_version(server_hello)
        state.session_id = server_hello.session_id

        if self.version is not None:
            assert self.version == server_hello.server_version

        if self.cipher is not None:
            assert self.cipher == server_hello.cipher_suite

        # the selected cipher has to be one the client actually offered
        client_hello = state.get_last_message_of_type(ClientHello)
        chosen = server_hello.cipher_suite
        if chosen not in client_hello.cipher_suites:
            if chosen in CipherSuite.ietfNames:
                label = "{0} ({1:#06x})".format(
                    CipherSuite.ietfNames[chosen], chosen)
            else:
                label = "{0:#06x}".format(chosen)
            raise AssertionError("Server responded with cipher we did"
                                 " not advertise: {0}".format(label))

        state.cipher = chosen
        state.version = self._extract_version(server_hello)

        # switch the record layer to the negotiated protocol
        state.msg_sock.version = state.version
        state.msg_sock.tls13record = state.version > (3, 3)

        self._check_against_hrr(state, server_hello)

        state.handshake_messages.append(server_hello)
        state.handshake_hashes.update(msg.write())

        # session-wide flags start cleared; extension processing sets them
        state.extended_master_secret = False
        state.encrypt_then_mac = False

        if server_hello.extensions:
            self._process_extensions(state, client_hello, server_hello)

        self._compare_extensions(server_hello)

        if state.version > (3, 3):
            self._setup_tls13_handshake_keys(state)
        return server_hello
예제 #31
0
    def process(self, state, msg):
        """
        Process the message and update state accordingly

        Parses the ServerHello, records the negotiated session ID, cipher and
        version in the state, adds the message to the handshake transcript
        and, when extension expectations are configured, checks that exactly
        the expected extensions were sent.

        @type state: ConnectionState
        @param state: overall state of TLS connection

        @type msg: Message
        @param msg: TLS Message read from socket

        @raise AssertionError: when the message does not match expectations
            (an expected extension is missing or an unexpected one is present;
            other checks are plain asserts)
        """
        assert msg.contentType == ContentType.handshake

        parser = Parser(msg.write())
        hs_type = parser.get(1)
        assert hs_type == HandshakeType.server_hello

        srv_hello = ServerHello()
        srv_hello.parse(parser)

        # extract important info
        state.server_random = srv_hello.random

        # check for session_id based session resumption
        if self.resume:
            assert state.session_id == srv_hello.session_id
        if (state.session_id == srv_hello.session_id and
                srv_hello.session_id != bytearray(0)):
            state.resuming = True
            assert state.cipher == srv_hello.cipher_suite
            assert state.version == srv_hello.server_version
        state.session_id = srv_hello.session_id

        if self.version is not None:
            assert self.version == srv_hello.server_version

        state.cipher = srv_hello.cipher_suite
        state.version = srv_hello.server_version

        # update the state of connection
        state.msg_sock.version = srv_hello.server_version

        state.handshake_messages.append(srv_hello)
        state.handshake_hashes.update(msg.write())

        # Reset value of the session-wide settings
        state.extended_master_secret = False

        # check if the message has expected values; raise explicitly (rather
        # than with a bare assert) so the failure names the extension and the
        # check is not stripped when running under "python -O"
        if self.extensions is not None:
            for ext_id in self.extensions:
                ext = srv_hello.getExtension(ext_id)
                if ext is None:
                    raise AssertionError(
                        "expected extension not present: {0}".format(ext_id))
                # run extension-specific checker if present
                checker = self.extensions[ext_id]
                if checker is not None:
                    checker(state, ext)
                if ext_id == ExtensionType.extended_master_secret:
                    state.extended_master_secret = True
            # not supporting any extensions is valid
            if srv_hello.extensions is not None:
                for ext in srv_hello.extensions:
                    if ext.extType not in self.extensions:
                        raise AssertionError(
                            "unexpected extension: {0}".format(ext.extType))
예제 #32
0
 def test_makeClientKeyExchange(self):
     """makeClientKeyExchange() of the base KeyExchange builds a message."""
     server_hello = ServerHello().create((3, 3), bytearray(32),
                                         bytearray(0), 0)
     key_exchange = KeyExchange(0, None, server_hello, None)
     message = key_exchange.makeClientKeyExchange()
     self.assertIsInstance(message, ClientKeyExchange)
예제 #33
0
    def test(self):

        sock = MockSocket(server_hello_ciphertext)

        record_layer = RecordLayer(sock)

        ext = [SNIExtension().create(bytearray(b'server')),
               TLSExtension(extType=ExtensionType.renegotiation_info)
               .create(bytearray(b'\x00')),
               SupportedGroupsExtension().create([GroupName.x25519,
                                                  GroupName.secp256r1,
                                                  GroupName.secp384r1,
                                                  GroupName.secp521r1,
                                                  GroupName.ffdhe2048,
                                                  GroupName.ffdhe3072,
                                                  GroupName.ffdhe4096,
                                                  GroupName.ffdhe6144,
                                                  GroupName.ffdhe8192]),
               ECPointFormatsExtension().create([ECPointFormat.uncompressed]),
               TLSExtension(extType=35),
               ClientKeyShareExtension().create([KeyShareEntry().create(GroupName.x25519,
                                                client_key_public,
                                                client_key_private)]),
               SupportedVersionsExtension().create([TLS_1_3_DRAFT,
                                                    (3, 3), (3, 2)]),
               SignatureAlgorithmsExtension().create([(HashAlgorithm.sha256,
                                                       SignatureAlgorithm.ecdsa),
                                                      (HashAlgorithm.sha384,
                                                       SignatureAlgorithm.ecdsa),
                                                      (HashAlgorithm.sha512,
                                                       SignatureAlgorithm.ecdsa),
                                                      (HashAlgorithm.sha1,
                                                       SignatureAlgorithm.ecdsa),
                                                      SignatureScheme.rsa_pss_sha256,
                                                      SignatureScheme.rsa_pss_sha384,
                                                      SignatureScheme.rsa_pss_sha512,
                                                      SignatureScheme.rsa_pkcs1_sha256,
                                                      SignatureScheme.rsa_pkcs1_sha384,
                                                      SignatureScheme.rsa_pkcs1_sha512,
                                                      SignatureScheme.rsa_pkcs1_sha1,
                                                      (HashAlgorithm.sha256,
                                                       SignatureAlgorithm.dsa),
                                                      (HashAlgorithm.sha384,
                                                       SignatureAlgorithm.dsa),
                                                      (HashAlgorithm.sha512,
                                                       SignatureAlgorithm.dsa),
                                                      (HashAlgorithm.sha1,
                                                       SignatureAlgorithm.dsa)]),
                TLSExtension(extType=45).create(bytearray(b'\x01\x01')),
                TLSExtension(extType=ExtensionType.client_hello_padding)
                .create(bytearray(252))
               ]
        client_hello = ClientHello()
        client_hello.create((3, 3),
                            bytearray(b'\xaf!\x15k\x04\xdbc\x9ef\x15J\x1f\xe5'
                                      b'\xad\xfa\xea\xdf\x9eA4\x16\x00\rW\xb8'
                                      b'\xe1\x12mM\x11\x9a\x8b'),
                            bytearray(b''),
                            [CipherSuite.TLS_AES_128_GCM_SHA256,
                             CipherSuite.TLS_CHACHA20_POLY1305_SHA256,
                             CipherSuite.TLS_AES_256_GCM_SHA384,
                             CipherSuite.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
                             CipherSuite.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
                             0xCCA9,
                             CipherSuite.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
                             CipherSuite.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
                             CipherSuite.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
                             CipherSuite.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
                             CipherSuite.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256,
                             CipherSuite.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256,
                             CipherSuite.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
                             CipherSuite.TLS_DHE_RSA_WITH_AES_128_GCM_SHA256,
                             CipherSuite.TLS_DHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
                             CipherSuite.TLS_DHE_RSA_WITH_AES_128_CBC_SHA,
                             0x0032,
                             CipherSuite.TLS_DHE_RSA_WITH_AES_128_CBC_SHA256,
                             CipherSuite.TLS_DHE_RSA_WITH_AES_256_CBC_SHA,
                             0x0038,
                             CipherSuite.TLS_DHE_RSA_WITH_AES_256_CBC_SHA256,
                             CipherSuite.TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA,
                             0x0013,
                             CipherSuite.TLS_RSA_WITH_AES_128_GCM_SHA256,
                             CipherSuite.TLS_RSA_WITH_AES_128_CBC_SHA,
                             CipherSuite.TLS_RSA_WITH_AES_128_CBC_SHA256,
                             CipherSuite.TLS_RSA_WITH_AES_256_CBC_SHA,
                             CipherSuite.TLS_RSA_WITH_AES_256_CBC_SHA256,
                             CipherSuite.TLS_RSA_WITH_3DES_EDE_CBC_SHA,
                             CipherSuite.TLS_RSA_WITH_RC4_128_SHA,
                             CipherSuite.TLS_RSA_WITH_RC4_128_MD5],
                            extensions=ext)

        self.assertEqual(client_hello.write(), client_hello_ciphertext[5:])

        for result in record_layer.recvRecord():
            # check if non-blocking
            self.assertNotIn(result, (0, 1))
        header, parser = result
        hs_type = parser.get(1)
        self.assertEqual(hs_type, HandshakeType.server_hello)
        server_hello = ServerHello().parse(parser)

        self.assertEqual(server_hello.server_version, TLS_1_3_DRAFT)
        self.assertEqual(server_hello.cipher_suite, CipherSuite.TLS_AES_128_GCM_SHA256)

        server_key_share = server_hello.getExtension(ExtensionType.key_share)
        server_key_share = server_key_share.server_share

        self.assertEqual(server_key_share.group, GroupName.x25519)

        # for TLS_AES_128_GCM_SHA256:
        prf_name = 'sha256'
        prf_size = 256 // 8
        secret = bytearray(prf_size)
        psk = bytearray(prf_size)

        # early secret
        secret = secureHMAC(secret, psk, prf_name)

        self.assertEqual(secret,
                         str_to_bytearray(
                             "33ad0a1c607ec03b 09e6cd9893680ce2"
                             "10adf300aa1f2660 e1b22e10f170f92a"))

        # derive secret for handshake
        secret = derive_secret(secret, b"derived", None, prf_name)

        self.assertEqual(secret,
                         str_to_bytearray(
                             "6f2615a108c702c5 678f54fc9dbab697"
                             "16c076189c48250c ebeac3576c3611ba"))

        # extract secret "handshake"
        Z = x25519(client_key_private, server_key_share.key_exchange)

        self.assertEqual(Z,
                         str_to_bytearray(
                             "f677c3cdac26a755 455b130efa9b1a3f"
                             "3cafb153544ca46a ddf670df199d996e"))

        secret = secureHMAC(secret, Z, prf_name)

        self.assertEqual(secret,
                         str_to_bytearray(
                             "0cefce00d5d29fd0 9f5de36c86fc8e72"
                             "99b4ad11ba4211c6 7063c2cc539fc4f9"))

        handshake_hashes = HandshakeHashes()
        handshake_hashes.update(client_hello_plaintext)
        handshake_hashes.update(server_hello_payload)

        # derive "tls13 c hs traffic"
        c_hs_traffic = derive_secret(secret,
                                     bytearray(b'c hs traffic'),
                                     handshake_hashes,
                                     prf_name)
        self.assertEqual(c_hs_traffic,
                         str_to_bytearray(
                             "5a63db760b817b1b da96e72832333aec"
                             "6a177deeadb5b407 501ac10c17dac0a4"))
        s_hs_traffic = derive_secret(secret,
                                     bytearray(b's hs traffic'),
                                     handshake_hashes,
                                     prf_name)
        self.assertEqual(s_hs_traffic,
                         str_to_bytearray(
                             "3aa72a3c77b791e8 f4de243f9ccce172"
                             "941f8392aeb05429 320f4b572ccfe744"))

        # derive master secret
        secret = derive_secret(secret, b"derived", None, prf_name)

        self.assertEqual(secret,
                         str_to_bytearray(
                             "32cadf38f3089048 5c54bf4f1184eaa5"
                             "569eeef15a43f3c7 6ab33965a47c9ff6"))

        # extract secret "master
        secret = secureHMAC(secret, bytearray(prf_size), prf_name)

        self.assertEqual(secret,
                         str_to_bytearray(
                             "6c6d4b3e7c925460 82d7b7a32f6ce219"
                             "3804f1bb930fed74 5c6b93c71397f424"))
예제 #34
0
    def test(self):
        """Replay a captured TLS 1.3 (draft) handshake and check every step.

        The test does four things:

        1. Rebuilds the ClientHello from its individual fields and checks
           that it serialises to ``client_hello_ciphertext[5:]``.
        2. Reads the server's first record from a ``RecordLayer`` backed by
           a ``MockSocket`` that holds ``server_hello_ciphertext``.
        3. Parses that record as a ServerHello.
        4. Recomputes the key schedule one step at a time and asserts each
           intermediate value against hard-coded hex. The steps are: early
           secret, "derived", handshake secret, client/server handshake
           traffic secrets, "derived", master secret.

        NOTE(review): the expected hex values presumably come from a
        published TLS 1.3 example trace (e.g. draft-ietf-tls-tls13-vectors).
        Confirm the source before changing any of them.
        """

        # The mock socket is loaded with the server's bytes, so the record
        # layer built on it reads the server's first flight below.
        sock = MockSocket(server_hello_ciphertext)

        record_layer = RecordLayer(sock)

        # Extensions in the same order as in the captured ClientHello.
        # Order matters because the serialised ClientHello is compared
        # byte-for-byte against the fixture below.
        # NOTE(review): extType=35 is presumably session_ticket and
        # extType=45 presumably psk_key_exchange_modes (IANA numbers).
        # They are given as raw TLSExtension objects here; TODO confirm.
        ext = [SNIExtension().create(bytearray(b'server')),
               TLSExtension(extType=ExtensionType.renegotiation_info)
               .create(bytearray(b'\x00')),
               SupportedGroupsExtension().create([GroupName.x25519,
                                                  GroupName.secp256r1,
                                                  GroupName.secp384r1,
                                                  GroupName.secp521r1,
                                                  GroupName.ffdhe2048,
                                                  GroupName.ffdhe3072,
                                                  GroupName.ffdhe4096,
                                                  GroupName.ffdhe6144,
                                                  GroupName.ffdhe8192]),
               ECPointFormatsExtension().create([ECPointFormat.uncompressed]),
               TLSExtension(extType=35),
               ClientKeyShareExtension().create([KeyShareEntry().create(GroupName.x25519,
                                                 client_key_public,
                                                 client_key_private)]),
               SupportedVersionsExtension().create([TLS_1_3_DRAFT,
                                                    (3, 3), (3, 2)]),
               SignatureAlgorithmsExtension().create([(HashAlgorithm.sha256,
                                                       SignatureAlgorithm.ecdsa),
                                                      (HashAlgorithm.sha384,
                                                       SignatureAlgorithm.ecdsa),
                                                      (HashAlgorithm.sha512,
                                                       SignatureAlgorithm.ecdsa),
                                                      (HashAlgorithm.sha1,
                                                       SignatureAlgorithm.ecdsa),
                                                      SignatureScheme.rsa_pss_sha256,
                                                      SignatureScheme.rsa_pss_sha384,
                                                      SignatureScheme.rsa_pss_sha512,
                                                      SignatureScheme.rsa_pkcs1_sha256,
                                                      SignatureScheme.rsa_pkcs1_sha384,
                                                      SignatureScheme.rsa_pkcs1_sha512,
                                                      SignatureScheme.rsa_pkcs1_sha1,
                                                      (HashAlgorithm.sha256,
                                                       SignatureAlgorithm.dsa),
                                                      (HashAlgorithm.sha384,
                                                       SignatureAlgorithm.dsa),
                                                      (HashAlgorithm.sha512,
                                                       SignatureAlgorithm.dsa),
                                                      (HashAlgorithm.sha1,
                                                       SignatureAlgorithm.dsa)]),
                TLSExtension(extType=45).create(bytearray(b'\x01\x01')),
                TLSExtension(extType=ExtensionType.client_hello_padding)
                .create(bytearray(252))
               ]
        # Rebuild the ClientHello. The arguments are: legacy version (3, 3),
        # the captured 32-byte client random, an empty session id, and the
        # captured cipher-suite list in its original order.
        # NOTE(review): the raw numeric suites presumably have no CipherSuite
        # constant. Likely meanings: 0xCCA9 = ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
        # 0x0032 / 0x0038 / 0x0013 = DHE_DSS suites. Verify against the IANA
        # registry.
        client_hello = ClientHello()
        client_hello.create((3, 3),
                            bytearray(b'\xaf!\x15k\x04\xdbc\x9ef\x15J\x1f\xe5'
                                      b'\xad\xfa\xea\xdf\x9eA4\x16\x00\rW\xb8'
                                      b'\xe1\x12mM\x11\x9a\x8b'),
                            bytearray(b''),
                            [CipherSuite.TLS_AES_128_GCM_SHA256,
                             CipherSuite.TLS_CHACHA20_POLY1305_SHA256,
                             CipherSuite.TLS_AES_256_GCM_SHA384,
                             CipherSuite.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
                             CipherSuite.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
                             0xCCA9,
                             CipherSuite.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
                             CipherSuite.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
                             CipherSuite.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
                             CipherSuite.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
                             CipherSuite.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256,
                             CipherSuite.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256,
                             CipherSuite.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
                             CipherSuite.TLS_DHE_RSA_WITH_AES_128_GCM_SHA256,
                             CipherSuite.TLS_DHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
                             CipherSuite.TLS_DHE_RSA_WITH_AES_128_CBC_SHA,
                             0x0032,
                             CipherSuite.TLS_DHE_RSA_WITH_AES_128_CBC_SHA256,
                             CipherSuite.TLS_DHE_RSA_WITH_AES_256_CBC_SHA,
                             0x0038,
                             CipherSuite.TLS_DHE_RSA_WITH_AES_256_CBC_SHA256,
                             CipherSuite.TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA,
                             0x0013,
                             CipherSuite.TLS_RSA_WITH_AES_128_GCM_SHA256,
                             CipherSuite.TLS_RSA_WITH_AES_128_CBC_SHA,
                             CipherSuite.TLS_RSA_WITH_AES_128_CBC_SHA256,
                             CipherSuite.TLS_RSA_WITH_AES_256_CBC_SHA,
                             CipherSuite.TLS_RSA_WITH_AES_256_CBC_SHA256,
                             CipherSuite.TLS_RSA_WITH_3DES_EDE_CBC_SHA,
                             CipherSuite.TLS_RSA_WITH_RC4_128_SHA,
                             CipherSuite.TLS_RSA_WITH_RC4_128_MD5],
                            extensions=ext)

        # The first 5 bytes of the fixture are skipped, presumably the TLS
        # record header (content type, version, length). That leaves only
        # the handshake message to compare against our serialisation.
        self.assertEqual(client_hello.write(), client_hello_ciphertext[5:])

        # recvRecord() is consumed as a generator. Per the original comment,
        # 0/1 values would signal a non-blocking "not ready yet" result; none
        # are expected from the mock socket. After the loop, `result` holds
        # the last item yielded, which is unpacked as (header, parser).
        for result in record_layer.recvRecord():
            # check if non-blocking
            self.assertNotIn(result, (0, 1))
        header, parser = result
        # The first byte of the record payload is the handshake message type.
        hs_type = parser.get(1)
        self.assertEqual(hs_type, HandshakeType.server_hello)
        server_hello = ServerHello().parse(parser)

        self.assertEqual(server_hello.server_version, TLS_1_3_DRAFT)
        self.assertEqual(server_hello.cipher_suite, CipherSuite.TLS_AES_128_GCM_SHA256)

        # The server's key_share extension carries its single chosen share.
        server_key_share = server_hello.getExtension(ExtensionType.key_share)
        server_key_share = server_key_share.server_share

        self.assertEqual(server_key_share.group, GroupName.x25519)

        # Key schedule.
        # NOTE(review): secureHMAC(key, data, prf) appears to serve as
        # HKDF-Extract(salt=key, IKM=data), and derive_secret(secret, label,
        # hashes, prf) as the TLS 1.3 Derive-Secret. Confirm against the
        # library implementation.

        # for TLS_AES_128_GCM_SHA256:
        prf_name = 'sha256'
        prf_size = 256 // 8
        # No PSK is used: both the starting secret and the PSK input are
        # all-zero strings of hash length.
        secret = bytearray(prf_size)
        psk = bytearray(prf_size)

        # early secret
        secret = secureHMAC(secret, psk, prf_name)

        self.assertEqual(secret,
                         str_to_bytearray(
                             "33ad0a1c607ec03b 09e6cd9893680ce2"
                             "10adf300aa1f2660 e1b22e10f170f92a"))

        # derive secret for handshake
        # (None instead of handshake hashes presumably means "empty
        # transcript"; TODO confirm in derive_secret)
        secret = derive_secret(secret, b"derived", None, prf_name)

        self.assertEqual(secret,
                         str_to_bytearray(
                             "6f2615a108c702c5 678f54fc9dbab697"
                             "16c076189c48250c ebeac3576c3611ba"))

        # extract secret "handshake"
        # Z is the X25519 shared secret, computed from the client's private
        # key and the server's key-share public value.
        Z = x25519(client_key_private, server_key_share.key_exchange)

        self.assertEqual(Z,
                         str_to_bytearray(
                             "f677c3cdac26a755 455b130efa9b1a3f"
                             "3cafb153544ca46a ddf670df199d996e"))

        secret = secureHMAC(secret, Z, prf_name)

        self.assertEqual(secret,
                         str_to_bytearray(
                             "0cefce00d5d29fd0 9f5de36c86fc8e72"
                             "99b4ad11ba4211c6 7063c2cc539fc4f9"))

        # Transcript so far: ClientHello followed by ServerHello.
        handshake_hashes = HandshakeHashes()
        handshake_hashes.update(client_hello_plaintext)
        handshake_hashes.update(server_hello_payload)

        # derive "tls13 c hs traffic"
        # (the label passed here omits the "tls13 " prefix; presumably
        # derive_secret adds it, TODO confirm)
        c_hs_traffic = derive_secret(secret,
                                     bytearray(b'c hs traffic'),
                                     handshake_hashes,
                                     prf_name)
        self.assertEqual(c_hs_traffic,
                         str_to_bytearray(
                             "5a63db760b817b1b da96e72832333aec"
                             "6a177deeadb5b407 501ac10c17dac0a4"))
        # derive "tls13 s hs traffic" over the same transcript
        s_hs_traffic = derive_secret(secret,
                                     bytearray(b's hs traffic'),
                                     handshake_hashes,
                                     prf_name)
        self.assertEqual(s_hs_traffic,
                         str_to_bytearray(
                             "3aa72a3c77b791e8 f4de243f9ccce172"
                             "941f8392aeb05429 320f4b572ccfe744"))

        # derive master secret
        secret = derive_secret(secret, b"derived", None, prf_name)

        self.assertEqual(secret,
                         str_to_bytearray(
                             "32cadf38f3089048 5c54bf4f1184eaa5"
                             "569eeef15a43f3c7 6ab33965a47c9ff6"))

        # extract secret "master" (input keying material is all zeros)
        secret = secureHMAC(secret, bytearray(prf_size), prf_name)

        self.assertEqual(secret,
                         str_to_bytearray(
                             "6c6d4b3e7c925460 82d7b7a32f6ce219"
                             "3804f1bb930fed74 5c6b93c71397f424"))