예제 #1
0
    def test_process_with_bad_extension(self):
        """Reject an extension expectation that is not a valid checker.

        The ALPN entry of the expectation is the plain string
        'BAD_EXTENSION'; processing a ServerHello that does carry ALPN is
        expected to raise ValueError.
        """
        expected_exts = {
            ExtensionType.renegotiation_info: None,
            ExtensionType.alpn: 'BAD_EXTENSION',
        }
        exp = ExpectServerHello(extensions=expected_exts)

        hello = ClientHello()
        hello.cipher_suites = [4]

        state = ConnectionState()
        state.handshake_messages.append(hello)
        state.msg_sock = mock.MagicMock()

        server_exts = [
            RenegotiationInfoExtension().create(None),
            ALPNExtension().create([bytearray(b'http/1.1')]),
        ]
        msg = ServerHello().create(version=(3, 3),
                                   random=bytearray(32),
                                   session_id=bytearray(0),
                                   cipher_suite=4,
                                   extensions=server_exts)

        self.assertTrue(exp.is_match(msg))

        with self.assertRaises(ValueError):
            exp.process(state, msg)
예제 #2
0
    def test_parse_with_extensions_length_long_by_one(self):
        """Parsing must fail when the extensions length overstates the data.

        The extensions that follow occupy 14 bytes but their length field
        claims 15, so ServerHello.parse() is expected to raise SyntaxError.
        """
        # the handshake message type byte (b'\x02', server_hello) is left
        # out: it is handled by the handshake protocol layer
        extensions = b''.join([
            b'\xff\x01', b'\x00\x01', b'\x00',  # renegotiation_info: supported (0)
            b'\x00\x23', b'\x00\x00',           # session ticket (35): empty
            b'\x00\x0f', b'\x00\x01', b'\x01',  # heartbeat (15): peer may send (1)
        ])
        raw = b''.join([
            b'\x00\x00\x36',         # length - 54 bytes
            b'\x03\x03',             # version - TLS 1.2
            b'\x01' * 31 + b'\x02',  # random
            b'\x00',                 # session id length
            b'\x00\x9d',             # cipher suite
            b'\x01',                 # compression method (zlib)
            b'\x00\x0f',             # extensions length - 15 bytes (one too many)
            extensions,
        ])
        parser = Parser(bytearray(raw))
        hello = ServerHello()

        with self.assertRaises(SyntaxError) as caught:
            hello.parse(parser)

        # TODO: the error message could be more descriptive
        self.assertIsNone(caught.exception.msg)
예제 #3
0
    def test_parse(self):
        """Parse a TLS 1.2 ServerHello that carries three extensions."""
        # the handshake message type byte (b'\x02', server_hello) is left
        # out: it is handled by the handshake protocol layer
        extensions = b''.join([
            b'\xff\x01', b'\x00\x01', b'\x00',  # renegotiation_info: supported (0)
            b'\x00\x23', b'\x00\x00',           # session ticket (35): empty
            b'\x00\x0f', b'\x00\x01', b'\x01',  # heartbeat (15): peer may send (1)
        ])
        raw = b''.join([
            b'\x00\x00\x36',         # length - 54 bytes
            b'\x03\x03',             # version - TLS 1.2
            b'\x01' * 31 + b'\x02',  # random
            b'\x00',                 # session id length
            b'\x00\x9d',             # cipher suite
            b'\x01',                 # compression method (zlib)
            b'\x00\x0e',             # extensions length - 14 bytes
            extensions,
        ])

        parsed = ServerHello().parse(Parser(bytearray(raw)))

        expected_random = bytearray(b'\x01' * 31 + b'\x02')
        self.assertEqual((3, 3), parsed.server_version)
        self.assertEqual(expected_random, parsed.random)
        self.assertEqual(bytearray(0), parsed.session_id)
        self.assertEqual(157, parsed.cipher_suite)
        # XXX: the certificate type was not sent by the server
        self.assertEqual(CertificateType.x509, parsed.certificate_type)
        self.assertEqual(1, parsed.compression_method)
        self.assertEqual(None, parsed.tackExt)
        self.assertEqual(None, parsed.next_protos_advertised)
        self.assertEqual(None, server_hello.next_protos_advertised)
예제 #4
0
    def test_process_with_rcf7919_groups_required_not_provided(self):
        """Reject a DHE ServerKeyExchange whose group is not allowed.

        The expectation accepts only group ID 256 (an RFC 7919 group, going
        by the test name). The key exchange is built with dhGroups=None,
        presumably yielding parameters outside that set, so processing is
        expected to raise AssertionError.
        """
        exp = ExpectServerKeyExchange(valid_groups=[256])

        suite = CipherSuite.TLS_DHE_RSA_WITH_AES_128_CBC_SHA
        state = ConnectionState()
        state.cipher = suite

        chain = X509CertChain([X509().parse(srv_raw_certificate)])
        cert = Certificate(CertificateType.x509).create(chain)
        private_key = parsePEMKey(srv_raw_key, private=True)

        client_hello = ClientHello()
        client_hello.client_version = (3, 3)
        client_hello.random = bytearray(32)
        client_hello.extensions = [SupportedGroupsExtension().create([256])]

        server_hello = ServerHello()
        server_hello.server_version = (3, 3)
        server_hello.random = bytearray(32)

        state.client_random = client_hello.random
        state.server_random = server_hello.random
        # the ServerHello itself does not have to be in the transcript
        state.handshake_messages.append(client_hello)
        state.handshake_messages.append(cert)

        key_exchange = DHE_RSAKeyExchange(suite,
                                          client_hello,
                                          server_hello,
                                          private_key,
                                          dhGroups=None)
        msg = key_exchange.makeServerKeyExchange('sha1')

        with self.assertRaises(AssertionError):
            exp.process(state, msg)
예제 #5
0
    def test_process_with_ECDHE_RSA(self):
        """An ECDHE_RSA ServerKeyExchange over secp256r1 is accepted."""
        exp = ExpectServerKeyExchange()

        suite = CipherSuite.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA
        state = ConnectionState()
        state.cipher = suite

        chain = X509CertChain([X509().parse(srv_raw_certificate)])
        cert = Certificate(CertificateType.x509).create(chain)
        private_key = parsePEMKey(srv_raw_key, private=True)

        client_hello = ClientHello()
        client_hello.client_version = (3, 3)
        client_hello.random = bytearray(32)
        sig_algs_ext = SignatureAlgorithmsExtension().create(
            [(HashAlgorithm.sha256, SignatureAlgorithm.rsa)])
        groups_ext = SupportedGroupsExtension().create([GroupName.secp256r1])
        client_hello.extensions = [sig_algs_ext, groups_ext]

        server_hello = ServerHello()
        server_hello.server_version = (3, 3)
        server_hello.random = bytearray(32)

        state.client_random = client_hello.random
        state.server_random = server_hello.random
        # the ServerHello itself does not have to be in the transcript
        state.handshake_messages.append(client_hello)
        state.handshake_messages.append(cert)

        key_exchange = ECDHE_RSAKeyExchange(suite,
                                            client_hello,
                                            server_hello,
                                            private_key,
                                            [GroupName.secp256r1])
        msg = key_exchange.makeServerKeyExchange('sha256')

        # must complete without raising
        exp.process(state, msg)
예제 #6
0
    def test_process_with_not_matching_signature_algorithms(self):
        """Reject a ServerKeyExchange signed with a non-allowed algorithm.

        The expectation only allows SHA-256 with RSA while the message is
        signed using 'sha1', so processing is expected to raise
        TLSIllegalParameterException.
        """
        allowed = [(HashAlgorithm.sha256, SignatureAlgorithm.rsa)]
        exp = ExpectServerKeyExchange(valid_sig_algs=allowed)

        suite = CipherSuite.TLS_DHE_RSA_WITH_AES_128_CBC_SHA
        state = ConnectionState()
        state.cipher = suite

        chain = X509CertChain([X509().parse(srv_raw_certificate)])
        cert = Certificate(CertificateType.x509).create(chain)
        private_key = parsePEMKey(srv_raw_key, private=True)

        client_hello = ClientHello()
        client_hello.client_version = (3, 3)
        client_hello.random = bytearray(32)

        server_hello = ServerHello()
        server_hello.server_version = (3, 3)
        server_hello.random = bytearray(32)

        state.client_random = client_hello.random
        state.server_random = server_hello.random
        # the ServerHello itself does not have to be in the transcript
        state.handshake_messages.append(client_hello)
        state.handshake_messages.append(cert)

        key_exchange = DHE_RSAKeyExchange(suite,
                                          client_hello,
                                          server_hello,
                                          private_key)
        msg = key_exchange.makeServerKeyExchange('sha1')

        with self.assertRaises(TLSIllegalParameterException):
            exp.process(state, msg)
예제 #7
0
    def test_process_with_unknown_key_exchange(self):
        """ExpectServerKeyExchange rejects an SRP key exchange.

        The connection negotiates an SRP_SHA_RSA suite and the SRP
        ServerKeyExchange is signed with the server key; going by the test
        name, the key exchange type is treated as unknown, and processing
        is expected to raise AssertionError.
        """
        exp = ExpectServerKeyExchange()

        state = ConnectionState()
        state.cipher = CipherSuite.TLS_SRP_SHA_RSA_WITH_AES_128_CBC_SHA
        chain = X509CertChain([X509().parse(srv_raw_certificate)])
        cert = Certificate(CertificateType.x509).create(chain)
        private_key = parsePEMKey(srv_raw_key, private=True)

        client_hello = ClientHello()
        client_hello.client_version = (3, 3)
        client_hello.random = bytearray(32)
        sha256_rsa = (HashAlgorithm.sha256, SignatureAlgorithm.rsa)
        client_hello.extensions = [
            SignatureAlgorithmsExtension().create([sha256_rsa])]
        state.client_random = client_hello.random
        state.handshake_messages.append(client_hello)

        server_hello = ServerHello()
        server_hello.server_version = (3, 3)
        server_hello.random = bytearray(32)
        state.version = server_hello.server_version
        state.server_random = server_hello.random
        state.handshake_messages.append(cert)

        # build and sign the SRP ServerKeyExchange by hand
        msg = ServerKeyExchange(state.cipher, state.version)
        msg.createSRP(1, 2, bytearray(3), 5)
        msg.signAlg = SignatureAlgorithm.rsa
        msg.hashAlg = HashAlgorithm.sha256
        digest = private_key.addPKCS1Prefix(
            msg.hash(client_hello.random, server_hello.random), 'sha256')
        msg.signature = private_key.sign(digest)

        with self.assertRaises(AssertionError):
            exp.process(state, msg)
예제 #8
0
    def setUp(self):
        """Build an SRP key exchange backed by an in-memory verifier DB.

        The ClientHello's SRP username is 'user', matching the only entry
        stored in the verifier database below. The previous value was a
        masked '******', which names no entry, so presumably any lookup of
        the client's username in verifierDB would fail.
        """
        self.srv_private_key = parsePEMKey(srv_raw_key, private=True)
        srv_chain = X509CertChain([X509().parse(srv_raw_certificate)])
        self.srv_pub_key = srv_chain.getEndEntityPublicKey()
        self.cipher_suite = CipherSuite.TLS_SRP_SHA_RSA_WITH_AES_128_CBC_SHA
        # the username must match the verifierDB entry created below
        self.client_hello = ClientHello().create((3, 3),
                                                 bytearray(32),
                                                 bytearray(0),
                                                 [],
                                                 srpUsername='user')
        self.server_hello = ServerHello().create((3, 3),
                                                 bytearray(32),
                                                 bytearray(0),
                                                 self.cipher_suite)

        verifierDB = VerifierDB()
        verifierDB.create()
        entry = verifierDB.makeVerifier('user', 'password', 2048)
        verifierDB['user'] = entry

        self.keyExchange = SRPKeyExchange(self.cipher_suite,
                                          self.client_hello,
                                          self.server_hello,
                                          self.srv_private_key,
                                          verifierDB)
예제 #9
0
    def test_write_with_next_protos(self):
        """ServerHello.write() must encode the advertised NPN protocols.

        The expected bytes are the full handshake message (type byte
        included) with a single NPN extension that lists 'spdy/3' and
        'http/1.1' as length-prefixed strings.
        """
        hello = ServerHello().create(
            (1, 1),                              # server version
            bytearray(b'\x00' * 31 + b'\x02'),   # random
            bytearray(0),                        # session id
            4,                                   # cipher suite
            0,                                   # certificate type
            None,                                # TACK ext
            [b'spdy/3', b'http/1.1'])            # next protos advertised

        # NPN extension body: every protocol name prefixed by its length
        npn_body = b''.join([
            b'\x06', b'spdy/3',      # first entry - 6 bytes
            b'\x08', b'http/1.1',    # second entry - 8 bytes
        ])
        expected = b''.join([
            b'\x02',                 # type of message - server_hello
            b'\x00\x00\x3c',         # length - 60 bytes
            b'\x01\x01',             # protocol version
            b'\x00' * 31 + b'\x02',  # random
            b'\x00',                 # session id length
            b'\x00\x04',             # cipher suite
            b'\x00',                 # compression method
            b'\x00\x14',             # extensions length - 20 bytes
            b'\x33\x74',             # ext type - NPN (13172)
            b'\x00\x10',             # ext length - 16 bytes
            npn_body,
        ])

        self.assertEqual(list(bytearray(expected)), list(hello.write()))
예제 #10
0
    def test_client_with_server_responing_with_SHA256_on_TLSv1_1(self):
        """Client must reject a SHA-256 cipher suite selected in TLS 1.1.

        A ServerHello with version (3, 2) and
        TLS_RSA_WITH_AES_128_CBC_SHA256 is serialised through a RecordLayer
        into a mock socket. The captured bytes are then replayed to a
        client TLSConnection, whose handshake is expected to abort with an
        illegal_parameter alert.
        """
        # socket to generate the faux response
        gen_sock = MockSocket(bytearray(0))

        gen_record_layer = RecordLayer(gen_sock)
        gen_record_layer.version = (3, 2)

        server_hello = ServerHello().create(
            version=(3, 2),
            random=bytearray(32),
            session_id=bytearray(0),
            cipher_suite=CipherSuite.TLS_RSA_WITH_AES_128_CBC_SHA256,
            certificate_type=None,
            tackExt=None,
            next_protos_advertised=None)

        for res in gen_record_layer.sendRecord(server_hello):
            # values 0 and 1 indicate the write would block; that must not
            # happen with the mock socket
            if res in (0, 1):
                self.fail("Blocking socket")
            else:
                break

        # test proper
        sock = MockSocket(gen_sock.sent[0])

        conn = TLSConnection(sock)

        with self.assertRaises(TLSLocalAlert) as err:
            conn.handshakeClientCert()

        self.assertEqual(err.exception.description,
                         AlertDescription.illegal_parameter)
예제 #11
0
    def test___str__(self):
        """str() of a ServerHello gives a compact one-line summary."""
        hello = ServerHello().create((3, 0),
                                     bytearray(b'\x00' * 32),
                                     bytearray(b'\x01\x20'),
                                     34500,  # 0x86c4
                                     0,
                                     None,
                                     None)
        expected = ("server_hello,length(40),version(3.0),random(...),"
                    "session ID(bytearray(b'\\x01 ')),cipher(0x86c4),"
                    "compression method(0)")

        self.assertEqual(expected, str(hello))
예제 #12
0
    def test___init__(self):
        """A freshly constructed ServerHello has the default field values."""
        hello = ServerHello()

        # (attribute name, expected default) in the order they are checked
        expected_defaults = [
            ('server_version', (0, 0)),
            ('random', bytearray(32)),
            ('session_id', bytearray(0)),
            ('cipher_suite', 0),
            ('certificate_type', CertificateType.x509),
            ('compression_method', 0),
            ('tackExt', None),
            ('next_protos_advertised', None),
            ('next_protos', None),
        ]
        for attr_name, default in expected_defaults:
            self.assertEqual(default, getattr(hello, attr_name))
예제 #13
0
    def test_parse_with_cert_type_extension(self):
        """A cert_type extension in ServerHello sets certificate_type."""
        raw = b''.join([
            b'\x00\x00\x2d',         # length - 45 bytes
            b'\x03\x03',             # version - TLS 1.2
            b'\x01' * 31 + b'\x02',  # random
            b'\x00',                 # session id length
            b'\x00\x9d',             # cipher suite
            b'\x00',                 # compression method (none)
            b'\x00\x05',             # extensions length - 5 bytes
            b'\x00\x09',             # ext type - cert_type (9)
            b'\x00\x01',             # ext length - 1 byte
            b'\x01',                 # value - OpenPGP (1)
        ])

        parsed = ServerHello().parse(Parser(bytearray(raw)))

        self.assertEqual(1, parsed.certificate_type)
예제 #14
0
    def test_parse_with_bad_cert_type_extension(self):
        """Parsing must fail for a cert_type extension with two values.

        The server's cert_type extension below lists two certificate types
        (X.509 and OpenPGP), presumably where only a single selected type
        is acceptable, so ServerHello.parse() is expected to raise
        SyntaxError.
        """
        p = Parser(
            bytearray(b'\x00\x00\x2e' +  # length - 46 bytes
                      b'\x03\x03' +  # version - TLS 1.2
                      b'\x01' * 31 + b'\x02' +  # random
                      b'\x00' +  # session id length
                      b'\x00\x9d' +  # cipher suite
                      b'\x00' +  # compression method (none)
                      b'\x00\x06' +  # extensions length - 6 bytes
                      b'\x00\x09' +  # ext type - cert_type (9)
                      b'\x00\x02' +  # ext length - 2 bytes
                      b'\x00\x01'  # value - X.509 (0), OpenPGP (1)
                      ))

        server_hello = ServerHello()
        with self.assertRaises(SyntaxError):
            server_hello.parse(p)
예제 #15
0
    def test_process_with_incorrect_cipher(self):
        """A ServerHello selecting an unexpected cipher is rejected.

        The expectation requires cipher 5 while the ServerHello selects 4,
        so processing is expected to raise AssertionError.
        """
        exp = ExpectServerHello(cipher=5)

        state = ConnectionState()
        state.msg_sock = mock.MagicMock()

        msg = ServerHello().create(version=(3, 3),
                                   random=bytearray(32),
                                   session_id=bytearray(0),
                                   cipher_suite=4)

        self.assertTrue(exp.is_match(msg))

        with self.assertRaises(AssertionError):
            exp.process(state, msg)
예제 #16
0
    def setUp(self):
        """Build a TLS 1.2 DHE_RSA key exchange using the test server key."""
        self.cipher_suite = CipherSuite.TLS_DHE_RSA_WITH_AES_128_CBC_SHA
        self.srv_private_key = parsePEMKey(srv_raw_key, private=True)
        cert_chain = X509CertChain([X509().parse(srv_raw_certificate)])
        self.srv_pub_key = cert_chain.getEndEntityPublicKey()

        self.client_hello = ClientHello().create(
            (3, 3), bytearray(32), bytearray(0), [])
        self.server_hello = ServerHello().create(
            (3, 3), bytearray(32), bytearray(0), self.cipher_suite)

        self.keyExchange = DHE_RSAKeyExchange(
            self.cipher_suite,
            self.client_hello,
            self.server_hello,
            self.srv_private_key)
예제 #17
0
    def test_signServerKeyExchange_in_TLS1_1(self):
        """Sign a DH ServerKeyExchange in TLS 1.1 and compare to the fixture."""
        suite = CipherSuite.TLS_DHE_RSA_WITH_AES_128_CBC_SHA
        key = parsePEMKey(srv_raw_key, private=True)
        srv_hello = ServerHello().create((3, 2),
                                         bytearray(32),
                                         bytearray(0),
                                         suite)
        key_exchange = KeyExchange(suite, ClientHello(), srv_hello, key)

        ske = ServerKeyExchange(suite, (3, 2)).createDH(5, 2, 3)
        key_exchange.signServerKeyExchange(ske)

        self.assertEqual(ske.write(), self.expected_tls1_1_SKE)
예제 #18
0
 def test___repr__(self):
     """repr() of a ServerHello lists every field, including extensions."""
     hello = ServerHello().create((3, 0),
                                  bytearray(b'\x00' * 32),
                                  bytearray(0),
                                  34500,
                                  0,
                                  None,
                                  None,
                                  extensions=[])
     self.maxDiff = None
     expected = ("ServerHello(server_version=(3.0), "
                 "random=bytearray(b'"
                 "\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00"
                 "\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00"
                 "\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00"
                 "\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00'), "
                 "session_id=bytearray(b''), "
                 "cipher_suite=34500, compression_method=0, _tack_ext=None, "
                 "extensions=[])")
     self.assertEqual(expected, repr(hello))
예제 #19
0
    def test_parse_with_length_long_by_one(self):
        """Parsing must fail when the message length claims an extra byte.

        The length field says 39 bytes but only 38 follow, so
        ServerHello.parse() is expected to raise SyntaxError.
        """
        # the handshake message type byte (b'\x02', server_hello) is left
        # out: it is handled by the handshake protocol layer
        raw = b''.join([
            b'\x00\x00\x27',         # length - 39 bytes (one too many)
            b'\x03\x03',             # version - TLS 1.2
            b'\x01' * 31 + b'\x02',  # random
            b'\x00',                 # session id length
            b'\x00\x9d',             # cipher suite
            b'\x01',                 # compression method (zlib)
        ])
        parser = Parser(bytearray(raw))
        hello = ServerHello()

        with self.assertRaises(SyntaxError) as caught:
            hello.parse(parser)

        # TODO: the error message could be more descriptive
        self.assertIsNone(caught.exception.msg)
예제 #20
0
    def process(self, state, msg):
        """
        Process the message and update state accordingly

        Parses the ServerHello carried in msg and records its cipher suite,
        protocol version and server random in state. Appends the parsed
        message to the handshake messages and feeds its bytes to the
        handshake hashes. When self.extensions is set, also verifies the
        server's extensions against it.

        @type state: ConnectionState
        @param state: overall state of TLS connection

        @type msg: Message
        @param msg: TLS Message read from socket

        @raise AssertionError: when msg is not a server_hello handshake
            message, when an expected extension is missing, or when the
            server sent an extension not listed in self.extensions
        @raise ValueError: when a value in self.extensions is neither None
            nor callable
        """
        # assertions are the failure mechanism callers of this expectation
        # rely on (they catch AssertionError), so they are kept as-is
        assert msg.contentType == ContentType.handshake

        # serialise once: the same bytes are parsed and hashed below
        msg_bytes = msg.write()
        parser = Parser(msg_bytes)
        hs_type = parser.get(1)
        assert hs_type == HandshakeType.server_hello

        srv_hello = ServerHello()
        srv_hello.parse(parser)

        # extract important info
        state.cipher = srv_hello.cipher_suite
        state.version = srv_hello.server_version
        state.server_random = srv_hello.random

        # update the state of connection
        state.msg_sock.version = srv_hello.server_version

        state.handshake_messages.append(srv_hello)
        state.handshake_hashes.update(msg_bytes)

        # check if the message has expected values
        if self.extensions is not None:
            for ext_id, checker in self.extensions.items():
                ext = srv_hello.getExtension(ext_id)
                assert ext is not None
                if checker is None:
                    # presence of the extension is all that is required
                    continue
                if not callable(checker):
                    raise ValueError("Bad extension checker for extension "
                                     "id {0}: {1!r}".format(ext_id, checker))
                # run extension-specific checker
                checker(state, ext)
            # not supporting any extensions is valid
            if srv_hello.extensions is not None:
                for srv_ext in srv_hello.extensions:
                    assert srv_ext.extType in self.extensions
예제 #21
0
    def test_parse_with_NPN_extension(self):
        """An NPN extension is parsed into ServerHello.next_protos."""
        raw = b''.join([
            b'\x00\x00\x3c',         # length - 60 bytes
            b'\x03\x03',             # version - TLS 1.2
            b'\x01' * 31 + b'\x02',  # random
            b'\x00',                 # session id length
            b'\x00\x9d',             # cipher suite
            b'\x00',                 # compression method (none)
            b'\x00\x14',             # extensions length - 20 bytes
            b'\x33\x74',             # ext type - npn (13172)
            b'\x00\x10',             # ext length - 16 bytes
            b'\x08', b'http/1.1',    # first name - 8 bytes
            b'\x06', b'spdy/3',      # second name - 6 bytes
        ])

        parsed = ServerHello().parse(Parser(bytearray(raw)))

        expected = [bytearray(b'http/1.1'), bytearray(b'spdy/3')]
        self.assertEqual(expected, parsed.next_protos)
예제 #22
0
    def test_process_with_udefined_cipher(self):
        """A cipher suite the client never offered must be rejected.

        The ClientHello in the connection state offers only suite 4, while
        the ServerHello selects 0xfff0, so processing is expected to raise
        AssertionError.
        """
        exp = ExpectServerHello()

        state = ConnectionState()
        client_hello = ClientHello()
        client_hello.cipher_suites = [4]
        state.handshake_messages.append(client_hello)
        state.msg_sock = mock.MagicMock()

        msg = ServerHello().create(version=(3, 3),
                                   random=bytearray(32),
                                   session_id=bytearray(0),
                                   cipher_suite=0xfff0)

        self.assertTrue(exp.is_match(msg))

        with self.assertRaises(AssertionError):
            exp.process(state, msg)
예제 #23
0
    def test_process_with_extended_master_secret(self):
        """An echoed extended_master_secret extension enables EMS in state."""
        ems_id = ExtensionType.extended_master_secret
        exp = ExpectServerHello(extensions={ems_id: None})

        state = ConnectionState()
        state.msg_sock = mock.MagicMock()
        self.assertFalse(state.extended_master_secret)

        ems_ext = TLSExtension(extType=ems_id)
        msg = ServerHello().create(version=(3, 3),
                                   random=bytearray(32),
                                   session_id=bytearray(0),
                                   cipher_suite=4,
                                   extensions=[ems_ext])

        self.assertTrue(exp.is_match(msg))
        exp.process(state, msg)
        self.assertTrue(state.extended_master_secret)
예제 #24
0
    def test_process_with_mandatory_resumption_but_wrong_id(self):
        """resume=True must reject a ServerHello with another session ID.

        The state holds session ID aa:aa:aa while the ServerHello carries
        bb:bb:bb, so processing is expected to raise AssertionError.
        """
        exp = ExpectServerHello(resume=True)

        state = ConnectionState()
        state.msg_sock = mock.MagicMock()
        state.session_id = bytearray(b'\xaa' * 3)
        state.cipher = 4

        self.assertFalse(state.resuming)

        msg = ServerHello()
        msg.create(version=(3, 3),
                   random=bytearray(32),
                   session_id=bytearray(b'\xbb' * 3),
                   cipher_suite=4)

        self.assertTrue(exp.is_match(msg))

        with self.assertRaises(AssertionError):
            exp.process(state, msg)
예제 #25
0
    def test_process_with_unexpected_extensions(self):
        """An extension the expectation does not list must be rejected.

        Only renegotiation_info is expected, but the ServerHello also
        carries SNI, so processing is expected to raise AssertionError.
        """
        exp = ExpectServerHello(
            extensions={ExtensionType.renegotiation_info: None})

        state = ConnectionState()
        state.msg_sock = mock.MagicMock()

        server_exts = [RenegotiationInfoExtension().create(),
                       SNIExtension().create()]  # SNI is not expected
        msg = ServerHello().create(version=(3, 3),
                                   random=bytearray(32),
                                   session_id=bytearray(0),
                                   cipher_suite=4,
                                   extensions=server_exts)

        self.assertTrue(exp.is_match(msg))

        with self.assertRaises(AssertionError):
            exp.process(state, msg)
예제 #26
0
    def test_create(self):
        """ServerHello.create() stores the supplied field values."""
        hello = ServerHello().create(
            (1, 1),                              # server version
            bytearray(b'\x00' * 31 + b'\x01'),   # random
            bytearray(0),                        # session id
            4,                                   # cipher suite
            1,                                   # certificate type
            None,                                # TACK ext
            None)                                # next protos advertised

        expected_random = bytearray(b'\x00' * 31 + b'\x01')
        self.assertEqual((1, 1), hello.server_version)
        self.assertEqual(expected_random, hello.random)
        self.assertEqual(bytearray(0), hello.session_id)
        self.assertEqual(4, hello.cipher_suite)
        self.assertEqual(CertificateType.openpgp, hello.certificate_type)
        self.assertEqual(0, hello.compression_method)
        self.assertEqual(None, hello.tackExt)
        self.assertEqual(None, hello.next_protos_advertised)
예제 #27
0
    def setUp(self):
        """Build a TLS 1.2 ECDHE_RSA key exchange over secp256r1."""
        self.cipher_suite = CipherSuite.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA
        self.srv_private_key = parsePEMKey(srv_raw_key, private=True)
        cert_chain = X509CertChain([X509().parse(srv_raw_certificate)])
        self.srv_pub_key = cert_chain.getEndEntityPublicKey()

        groups_ext = SupportedGroupsExtension().create([GroupName.secp256r1])
        self.client_hello = ClientHello().create(
            (3, 3), bytearray(32), bytearray(0), [],
            extensions=[groups_ext])
        self.server_hello = ServerHello().create(
            (3, 3), bytearray(32), bytearray(0), self.cipher_suite)

        self.keyExchange = ECDHE_RSAKeyExchange(
            self.cipher_suite,
            self.client_hello,
            self.server_hello,
            self.srv_private_key,
            [GroupName.secp256r1])
예제 #28
0
    def test_process_with_resumption(self):
        """A ServerHello echoing the cached session ID marks resumption."""
        exp = ExpectServerHello()

        state = ConnectionState()
        state.msg_sock = mock.MagicMock()
        state.session_id = bytearray(b'\xaa' * 3)
        state.cipher = 4

        self.assertFalse(state.resuming)

        msg = ServerHello()
        msg.create(version=(3, 3),
                   random=bytearray(32),
                   session_id=bytearray(b'\xaa' * 3),  # same as state's
                   cipher_suite=4)

        self.assertTrue(exp.is_match(msg))
        exp.process(state, msg)
        self.assertTrue(state.resuming)
예제 #29
0
    def test_process_with_extensions(self):
        """A registered extension checker is called with (state, ext)."""
        checker = mock.MagicMock()
        exp = ExpectServerHello(
            extensions={ExtensionType.renegotiation_info: checker})

        state = ConnectionState()
        state.msg_sock = mock.MagicMock()

        reneg_ext = RenegotiationInfoExtension().create()
        msg = ServerHello().create(version=(3, 3),
                                   random=bytearray(32),
                                   session_id=bytearray(0),
                                   cipher_suite=4,
                                   extensions=[reneg_ext])

        self.assertTrue(exp.is_match(msg))
        exp.process(state, msg)
        checker.assert_called_once_with(state, reneg_ext)
예제 #30
0
    def test_write(self):
        """Serialise a minimal ServerHello that has no extensions."""
        hello = ServerHello().create(
            (1, 1),                              # server version
            bytearray(b'\x00' * 31 + b'\x02'),   # random
            bytearray(0),                        # session id
            4,                                   # cipher suite
            None,                                # certificate type
            None,                                # TACK ext
            None)                                # next protos advertised

        expected = b''.join([
            b'\x02',                 # type of message - server_hello
            b'\x00\x00\x26',         # length - 38 bytes
            b'\x01\x01',             # protocol version
            b'\x00' * 31 + b'\x02',  # random
            b'\x00',                 # session id length
            b'\x00\x04',             # cipher suite
            b'\x00',                 # compression method
        ])

        self.assertEqual(list(bytearray(expected)), list(hello.write()))