def test_process_with_bad_extension(self):
    """An extension checker that is not callable must make process() raise
    ValueError."""
    expected_exts = {
        ExtensionType.renegotiation_info: None,
        ExtensionType.alpn: 'BAD_EXTENSION',
    }
    exp = ExpectServerHello(extensions=expected_exts)

    state = ConnectionState()
    hello = ClientHello()
    hello.cipher_suites = [4]
    state.handshake_messages.append(hello)
    state.msg_sock = mock.MagicMock()

    server_exts = [
        RenegotiationInfoExtension().create(None),
        ALPNExtension().create([bytearray(b'http/1.1')]),
    ]
    msg = ServerHello().create(version=(3, 3),
                               random=bytearray(32),
                               session_id=bytearray(0),
                               cipher_suite=4,
                               extensions=server_exts)

    self.assertTrue(exp.is_match(msg))
    with self.assertRaises(ValueError):
        exp.process(state, msg)
def test_parse_with_extensions_length_long_by_one(self):
    """Parsing must fail when the extensions length overstates the
    extension data by one byte."""
    # The handshake type byte (0x02) is handled by the hello protocol
    # layer, so the buffer starts directly at the 3-byte length field.
    renegotiation_info = b'\xff\x01' + b'\x00\x01' + b'\x00'  # supported (0)
    session_ticket = b'\x00\x23' + b'\x00\x00'                # empty
    heartbeat = b'\x00\x0f' + b'\x00\x01' + b'\x01'           # peer may send
    ext_data = renegotiation_info + session_ticket + heartbeat  # 14 bytes

    raw = bytearray(
        b'\x00\x00\x36' +              # length - 54 bytes
        b'\x03\x03' +                  # version - TLS 1.2
        b'\x01' * 31 + b'\x02' +       # random
        b'\x00' +                      # session id length
        b'\x00\x9d' +                  # cipher suite
        b'\x01' +                      # compression method (1)
        b'\x00\x0f' +                  # extensions length - 15 bytes (!)
        ext_data)
    parser = Parser(raw)

    hello = ServerHello()
    with self.assertRaises(SyntaxError) as ctx:
        hello.parse(parser)

    # TODO the message could be more descriptive...
    self.assertIsNone(ctx.exception.msg)
def test_parse(self):
    """Parse a TLS 1.2 ServerHello that carries three extensions."""
    # The handshake type byte (0x02) is handled by the hello protocol
    # layer, so the parsed data begins with the 3-byte length field.
    ext_data = (
        b'\xff\x01' + b'\x00\x01' + b'\x00' +   # renegotiation_info, value 0
        b'\x00\x23' + b'\x00\x00' +             # session ticket (35), empty
        b'\x00\x0f' + b'\x00\x01' + b'\x01')    # heartbeat (15), value 1
    raw = bytearray(
        b'\x00\x00\x36' +              # length - 54 bytes
        b'\x03\x03' +                  # version - TLS 1.2
        b'\x01' * 31 + b'\x02' +       # random
        b'\x00' +                      # session id length
        b'\x00\x9d' +                  # cipher suite
        b'\x01' +                      # compression method (1)
        b'\x00\x0e' +                  # extensions length - 14 bytes
        ext_data)

    parsed = ServerHello().parse(Parser(raw))

    self.assertEqual((3, 3), parsed.server_version)
    self.assertEqual(bytearray(b'\x01' * 31 + b'\x02'), parsed.random)
    self.assertEqual(bytearray(0), parsed.session_id)
    self.assertEqual(157, parsed.cipher_suite)
    # XXX not sent by server!
    self.assertEqual(CertificateType.x509, parsed.certificate_type)
    self.assertEqual(1, parsed.compression_method)
    self.assertEqual(None, parsed.tackExt)
    self.assertEqual(None, parsed.next_protos_advertised)
def test_process_with_rcf7919_groups_required_not_provided(self):
    """process() must reject a DHE ServerKeyExchange when the expectation
    requires group 256 and the server was not configured with it."""
    suite = CipherSuite.TLS_DHE_RSA_WITH_AES_128_CBC_SHA
    exp = ExpectServerKeyExchange(valid_groups=[256])

    state = ConnectionState()
    state.cipher = suite

    chain = X509CertChain([X509().parse(srv_raw_certificate)])
    cert = Certificate(CertificateType.x509).create(chain)
    private_key = parsePEMKey(srv_raw_key, private=True)

    client_hello = ClientHello()
    client_hello.client_version = (3, 3)
    client_hello.random = bytearray(32)
    client_hello.extensions = [SupportedGroupsExtension().create([256])]
    state.client_random = client_hello.random
    state.handshake_messages.append(client_hello)

    server_hello = ServerHello()
    server_hello.server_version = (3, 3)
    server_hello.random = bytearray(32)
    state.server_random = server_hello.random

    # the ServerHello itself does not have to be part of the transcript
    state.handshake_messages.append(cert)

    key_exchange = DHE_RSAKeyExchange(suite, client_hello, server_hello,
                                      private_key, dhGroups=None)
    msg = key_exchange.makeServerKeyExchange('sha1')

    with self.assertRaises(AssertionError):
        exp.process(state, msg)
def test_process_with_ECDHE_RSA(self):
    """A properly signed ECDHE_RSA ServerKeyExchange is accepted."""
    suite = CipherSuite.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA
    exp = ExpectServerKeyExchange()

    state = ConnectionState()
    state.cipher = suite

    chain = X509CertChain([X509().parse(srv_raw_certificate)])
    cert = Certificate(CertificateType.x509).create(chain)
    private_key = parsePEMKey(srv_raw_key, private=True)

    client_hello = ClientHello()
    client_hello.client_version = (3, 3)
    client_hello.random = bytearray(32)
    sig_algs = [(HashAlgorithm.sha256, SignatureAlgorithm.rsa)]
    client_hello.extensions = [
        SignatureAlgorithmsExtension().create(sig_algs),
        SupportedGroupsExtension().create([GroupName.secp256r1]),
    ]
    state.client_random = client_hello.random
    state.handshake_messages.append(client_hello)

    server_hello = ServerHello()
    server_hello.server_version = (3, 3)
    server_hello.random = bytearray(32)
    state.server_random = server_hello.random

    # the ServerHello itself does not have to be part of the transcript
    state.handshake_messages.append(cert)

    key_exchange = ECDHE_RSAKeyExchange(suite, client_hello, server_hello,
                                        private_key, [GroupName.secp256r1])
    msg = key_exchange.makeServerKeyExchange('sha256')

    exp.process(state, msg)
def test_process_with_not_matching_signature_algorithms(self):
    """A SHA-1 signed ServerKeyExchange is refused when only SHA-256+RSA
    is an accepted signature algorithm."""
    suite = CipherSuite.TLS_DHE_RSA_WITH_AES_128_CBC_SHA
    accepted = [(HashAlgorithm.sha256, SignatureAlgorithm.rsa)]
    exp = ExpectServerKeyExchange(valid_sig_algs=accepted)

    state = ConnectionState()
    state.cipher = suite

    chain = X509CertChain([X509().parse(srv_raw_certificate)])
    cert = Certificate(CertificateType.x509).create(chain)
    private_key = parsePEMKey(srv_raw_key, private=True)

    client_hello = ClientHello()
    client_hello.client_version = (3, 3)
    client_hello.random = bytearray(32)
    state.client_random = client_hello.random
    state.handshake_messages.append(client_hello)

    server_hello = ServerHello()
    server_hello.server_version = (3, 3)
    server_hello.random = bytearray(32)
    state.server_random = server_hello.random

    # the ServerHello itself does not have to be part of the transcript
    state.handshake_messages.append(cert)

    key_exchange = DHE_RSAKeyExchange(suite, client_hello, server_hello,
                                      private_key)
    msg = key_exchange.makeServerKeyExchange('sha1')

    with self.assertRaises(TLSIllegalParameterException):
        exp.process(state, msg)
def test_process_with_unknown_key_exchange(self):
    """process() raises AssertionError for an SRP-signed ServerKeyExchange."""
    exp = ExpectServerKeyExchange()

    state = ConnectionState()
    state.cipher = CipherSuite.TLS_SRP_SHA_RSA_WITH_AES_128_CBC_SHA

    chain = X509CertChain([X509().parse(srv_raw_certificate)])
    cert = Certificate(CertificateType.x509).create(chain)
    private_key = parsePEMKey(srv_raw_key, private=True)

    client_hello = ClientHello()
    client_hello.client_version = (3, 3)
    client_hello.random = bytearray(32)
    sig_algs = [(HashAlgorithm.sha256, SignatureAlgorithm.rsa)]
    client_hello.extensions = [SignatureAlgorithmsExtension().create(sig_algs)]
    state.client_random = client_hello.random
    state.handshake_messages.append(client_hello)

    server_hello = ServerHello()
    server_hello.server_version = (3, 3)
    state.version = server_hello.server_version
    server_hello.random = bytearray(32)
    state.server_random = server_hello.random

    state.handshake_messages.append(cert)

    # hand-build an SRP ServerKeyExchange and sign it with SHA-256/RSA
    msg = ServerKeyExchange(state.cipher, state.version)
    msg.createSRP(1, 2, bytearray(3), 5)
    msg.signAlg = SignatureAlgorithm.rsa
    msg.hashAlg = HashAlgorithm.sha256
    digest = msg.hash(client_hello.random, server_hello.random)
    prefixed = private_key.addPKCS1Prefix(digest, 'sha256')
    msg.signature = private_key.sign(prefixed)

    with self.assertRaises(AssertionError):
        exp.process(state, msg)
def setUp(self):
    """Prepare an SRP_SHA_RSA key exchange for a user known to the
    verifier database.

    The ClientHello must announce the same SRP username that is stored in
    the VerifierDB passed to SRPKeyExchange. Previously the hello carried
    '******' while the database only held 'user', so the two did not
    match.
    """
    self.srv_private_key = parsePEMKey(srv_raw_key, private=True)
    srv_chain = X509CertChain([X509().parse(srv_raw_certificate)])
    self.srv_pub_key = srv_chain.getEndEntityPublicKey()
    self.cipher_suite = CipherSuite.TLS_SRP_SHA_RSA_WITH_AES_128_CBC_SHA

    # single source of truth for the username used on both sides
    username = 'user'

    self.client_hello = ClientHello().create((3, 3),
                                             bytearray(32),
                                             bytearray(0),
                                             [],
                                             srpUsername=username)
    self.server_hello = ServerHello().create((3, 3),
                                             bytearray(32),
                                             bytearray(0),
                                             self.cipher_suite)

    verifierDB = VerifierDB()
    verifierDB.create()
    entry = verifierDB.makeVerifier(username, 'password', 2048)
    verifierDB[username] = entry

    self.keyExchange = SRPKeyExchange(self.cipher_suite,
                                      self.client_hello,
                                      self.server_hello,
                                      self.srv_private_key,
                                      verifierDB)
def test_write_with_next_protos(self):
    """write() must serialise the advertised protocols as an NPN extension."""
    server_hello = ServerHello().create(
        version=(1, 1),
        random=bytearray(b'\x00' * 31 + b'\x02'),
        session_id=bytearray(0),
        cipher_suite=4,
        certificate_type=0,
        tackExt=None,
        next_protos_advertised=[b'spdy/3', b'http/1.1'])

    # each protocol name is prefixed with its one-byte length
    npn_entries = (b'\x06' + b'spdy/3' +       # 6-byte entry
                   b'\x08' + b'http/1.1')      # 8-byte entry (utf-8 encoding)

    expected = bytearray(
        b'\x02' +                       # type of message - server_hello
        b'\x00\x00\x3c' +               # length
        b'\x01\x01' +                   # proto version
        b'\x00' * 31 + b'\x02' +        # random
        b'\x00' +                       # session id length
        b'\x00\x04' +                   # cipher suite
        b'\x00' +                       # compression method
        b'\x00\x14' +                   # extensions length
        b'\x33\x74' +                   # ext type - NPN (13172)
        b'\x00\x10' +                   # ext length - 16 bytes
        npn_entries)

    self.assertEqual(list(expected), list(server_hello.write()))
def test_client_with_server_responing_with_SHA256_on_TLSv1_1(self):
    """The client must abort with illegal_parameter when a TLS 1.1 server
    selects a SHA-256 cipher suite."""
    # record layer used only to serialise the faked server reply
    faux_sock = MockSocket(bytearray(0))
    faux_layer = RecordLayer(faux_sock)
    faux_layer.version = (3, 2)

    reply = ServerHello().create(
        version=(3, 2),
        random=bytearray(32),
        session_id=bytearray(0),
        cipher_suite=CipherSuite.TLS_RSA_WITH_AES_128_CBC_SHA256,
        certificate_type=None,
        tackExt=None,
        next_protos_advertised=None)

    for status in faux_layer.sendRecord(reply):
        if status not in (0, 1):
            break
        self.assertTrue(False, "Blocking socket")

    # test proper: replay the captured bytes to a real client
    conn = TLSConnection(MockSocket(faux_sock.sent[0]))
    with self.assertRaises(TLSLocalAlert) as err:
        conn.handshakeClientCert()

    self.assertEqual(err.exception.description,
                     AlertDescription.illegal_parameter)
def test___str__(self):
    """str() gives a one-line summary of the ServerHello fields."""
    server_hello = ServerHello().create((3, 0),
                                        bytearray(32),
                                        bytearray(b'\x01\x20'),
                                        34500,
                                        0,
                                        None,
                                        None)
    expected = ("server_hello,length(40),version(3.0),random(...),"
                "session ID(bytearray(b'\\x01 ')),cipher(0x86c4),"
                "compression method(0)")
    self.assertEqual(expected, str(server_hello))
def test___init__(self):
    """A freshly constructed ServerHello exposes default field values."""
    server_hello = ServerHello()
    defaults = (
        ('server_version', (0, 0)),
        ('random', bytearray(32)),
        ('session_id', bytearray(0)),
        ('cipher_suite', 0),
        ('certificate_type', CertificateType.x509),
        ('compression_method', 0),
        ('tackExt', None),
        ('next_protos_advertised', None),
        ('next_protos', None),
    )
    for attr_name, expected in defaults:
        self.assertEqual(expected, getattr(server_hello, attr_name))
def test_parse_with_cert_type_extension(self):
    """A one-byte cert_type extension sets certificate_type on parse."""
    cert_type_ext = (b'\x00\x09' +     # ext type - cert_type (9)
                     b'\x00\x01' +     # ext length - 1 byte
                     b'\x01')          # value - OpenPGP (1)
    raw = bytearray(
        b'\x00\x00\x2d' +              # length - 45 bytes
        b'\x03\x03' +                  # version - TLS 1.2
        b'\x01' * 31 + b'\x02' +       # random
        b'\x00' +                      # session id length
        b'\x00\x9d' +                  # cipher suite
        b'\x00' +                      # compression method (none)
        b'\x00\x05' +                  # extensions length - 5 bytes
        cert_type_ext)

    parsed = ServerHello().parse(Parser(raw))

    self.assertEqual(1, parsed.certificate_type)
def test_parse_with_bad_cert_type_extension(self):
    """Parsing must fail with SyntaxError when the cert_type extension
    carries a two-byte value.

    A well-formed one-byte cert_type value is accepted by parse (see the
    single-byte case); two values are expected to be rejected here.
    """
    p = Parser(
        bytearray(b'\x00\x00\x2e' +            # length - 46 bytes
                  b'\x03\x03' +                # version - TLS 1.2
                  b'\x01' * 31 + b'\x02' +     # random
                  b'\x00' +                    # session id length
                  b'\x00\x9d' +                # cipher suite
                  b'\x00' +                    # compression method (none)
                  b'\x00\x06' +                # extensions length - 6 bytes
                  b'\x00\x09' +                # ext type - cert_type (9)
                  b'\x00\x02' +                # ext length - 2 bytes
                  b'\x00\x01'                  # value - X.509 (0), OpenPGP (1)
                  ))
    server_hello = ServerHello()
    with self.assertRaises(SyntaxError):
        server_hello.parse(p)
def test_process_with_incorrect_cipher(self):
    """process() must fail when the server picks a cipher suite (4) other
    than the one the expectation names (5)."""
    exp = ExpectServerHello(cipher=5)
    state = ConnectionState()
    state.msg_sock = mock.MagicMock()

    # no extensions are needed for this check; the previously built but
    # unused RenegotiationInfoExtension has been dropped
    msg = ServerHello().create(version=(3, 3),
                               random=bytearray(32),
                               session_id=bytearray(0),
                               cipher_suite=4)

    self.assertTrue(exp.is_match(msg))
    with self.assertRaises(AssertionError):
        exp.process(state, msg)
def setUp(self):
    """Prepare a TLS 1.2 DHE_RSA key exchange using the test server key."""
    self.srv_private_key = parsePEMKey(srv_raw_key, private=True)
    self.srv_pub_key = X509CertChain(
        [X509().parse(srv_raw_certificate)]).getEndEntityPublicKey()

    suite = CipherSuite.TLS_DHE_RSA_WITH_AES_128_CBC_SHA
    self.cipher_suite = suite

    self.client_hello = ClientHello().create(
        (3, 3), bytearray(32), bytearray(0), [])
    self.server_hello = ServerHello().create(
        (3, 3), bytearray(32), bytearray(0), suite)

    self.keyExchange = DHE_RSAKeyExchange(
        suite, self.client_hello, self.server_hello, self.srv_private_key)
def test_signServerKeyExchange_in_TLS1_1(self):
    """Signing a DH ServerKeyExchange at TLS 1.1 yields the reference
    encoding stored in expected_tls1_1_SKE."""
    key = parsePEMKey(srv_raw_key, private=True)
    suite = CipherSuite.TLS_DHE_RSA_WITH_AES_128_CBC_SHA

    client_hello = ClientHello()
    server_hello = ServerHello().create(
        (3, 2), bytearray(32), bytearray(0), suite)
    kex = KeyExchange(suite, client_hello, server_hello, key)

    ske = ServerKeyExchange(suite, (3, 2)).createDH(5, 2, 3)
    kex.signServerKeyExchange(ske)

    self.assertEqual(ske.write(), self.expected_tls1_1_SKE)
def test___repr__(self):
    """repr() lists every ServerHello field, including the raw random."""
    server_hello = ServerHello().create((3, 0),
                                       bytearray(b'\x00' * 32),
                                       bytearray(0),
                                       34500,
                                       0,
                                       None,
                                       None,
                                       extensions=[])
    self.maxDiff = None

    # 32 escaped zero bytes, written as literal backslash sequences
    random_repr = "\\x00" * 32
    expected = ("ServerHello(server_version=(3.0), "
                "random=bytearray(b'" + random_repr + "'), "
                "session_id=bytearray(b''), "
                "cipher_suite=34500, compression_method=0, _tack_ext=None, "
                "extensions=[])")
    self.assertEqual(expected, repr(server_hello))
def test_parse_with_length_long_by_one(self):
    """Parsing must fail when the declared message length is one byte
    larger than the body (39 declared, 38 present)."""
    # The handshake type byte (0x02) is handled by the hello protocol
    # layer, so the buffer starts at the 3-byte length field.
    raw = bytearray(
        b'\x00\x00\x27' +              # length - 39 bytes (one long)
        b'\x03\x03' +                  # version - TLS 1.2
        b'\x01' * 31 + b'\x02' +       # random
        b'\x00' +                      # session id length
        b'\x00\x9d' +                  # cipher suite
        b'\x01')                       # compression method (1)
    parser = Parser(raw)

    hello = ServerHello()
    with self.assertRaises(SyntaxError) as ctx:
        hello.parse(parser)

    # TODO the message probably could be more descriptive...
    self.assertIsNone(ctx.exception.msg)
def process(self, state, msg):
    """
    Process the message and update state accordingly

    Parses the ServerHello carried in msg, stores the negotiated cipher
    suite, protocol version and server random in state, records the
    message in the handshake transcript, and then validates the
    extensions against the expectation (when one was given).

    @type state: ConnectionState
    @param state: overall state of TLS connection

    @type msg: Message
    @param msg: TLS Message read from socket

    @raise AssertionError: when msg is not a handshake ServerHello, when an
        expected extension is missing, or when the server sent an
        extension not listed in the expectation
    @raise ValueError: when the checker configured for an extension is
        neither None nor callable
    """
    assert msg.contentType == ContentType.handshake

    parser = Parser(msg.write())
    hs_type = parser.get(1)
    assert hs_type == HandshakeType.server_hello

    srv_hello = ServerHello()
    srv_hello.parse(parser)

    # extract important info
    state.cipher = srv_hello.cipher_suite
    state.version = srv_hello.server_version
    state.server_random = srv_hello.random

    # update the state of connection
    state.msg_sock.version = srv_hello.server_version

    state.handshake_messages.append(srv_hello)
    state.handshake_hashes.update(msg.write())

    # no expectation given: extensions are not checked at all (previously
    # the unexpected-extension loop below ran anyway and evaluated
    # "ext_id in None", raising TypeError)
    if self.extensions is None:
        return

    # every expected extension must be present
    for ext_id in self.extensions:
        ext = srv_hello.getExtension(ext_id)
        assert ext is not None
        checker = self.extensions[ext_id]
        # None means "presence is enough"
        if checker is None:
            continue
        if not callable(checker):
            raise ValueError("Bad extension, id: {0}".format(ext_id))
        # run extension-specific checker
        checker(state, ext)

    # not supporting any extensions is valid; otherwise every extension
    # the server sent must have been expected
    if srv_hello.extensions is not None:
        for ext in srv_hello.extensions:
            assert ext.extType in self.extensions
def test_parse_with_NPN_extension(self):
    """An NPN extension is decoded into the list of protocol names."""
    # each name is prefixed with its one-byte length
    protocol_list = (b'\x08' + b'http/1.1' +
                     b'\x06' + b'spdy/3')
    raw = bytearray(
        b'\x00\x00\x3c' +              # length - 60 bytes
        b'\x03\x03' +                  # version - TLS 1.2
        b'\x01' * 31 + b'\x02' +       # random
        b'\x00' +                      # session id length
        b'\x00\x9d' +                  # cipher suite
        b'\x00' +                      # compression method (none)
        b'\x00\x14' +                  # extensions length - 20 bytes
        b'\x33\x74' +                  # ext type - npn
        b'\x00\x10' +                  # ext length - 16 bytes
        protocol_list)

    parsed = ServerHello().parse(Parser(raw))

    self.assertEqual([bytearray(b'http/1.1'), bytearray(b'spdy/3')],
                     parsed.next_protos)
def test_process_with_udefined_cipher(self):
    """process() must fail when the server selects a cipher suite (0xfff0)
    the client never offered (client offered only suite 4)."""
    exp = ExpectServerHello()
    state = ConnectionState()
    client_hello = ClientHello()
    client_hello.cipher_suites = [4]
    state.handshake_messages.append(client_hello)
    state.msg_sock = mock.MagicMock()

    # no extensions are needed for this check; the previously built but
    # unused RenegotiationInfoExtension has been dropped
    msg = ServerHello().create(version=(3, 3),
                               random=bytearray(32),
                               session_id=bytearray(0),
                               cipher_suite=0xfff0)

    self.assertTrue(exp.is_match(msg))
    with self.assertRaises(AssertionError):
        exp.process(state, msg)
def test_process_with_extended_master_secret(self):
    """An extended_master_secret extension in ServerHello turns on
    state.extended_master_secret."""
    ems_id = ExtensionType.extended_master_secret
    exp = ExpectServerHello(extensions={ems_id: None})

    state = ConnectionState()
    state.msg_sock = mock.MagicMock()
    self.assertFalse(state.extended_master_secret)

    msg = ServerHello().create(version=(3, 3),
                               random=bytearray(32),
                               session_id=bytearray(0),
                               cipher_suite=4,
                               extensions=[TLSExtension(extType=ems_id)])

    self.assertTrue(exp.is_match(msg))
    exp.process(state, msg)
    self.assertTrue(state.extended_master_secret)
def test_process_with_mandatory_resumption_but_wrong_id(self):
    """With resume=True, a ServerHello echoing a different session ID
    must be rejected."""
    exp = ExpectServerHello(resume=True)

    state = ConnectionState()
    state.msg_sock = mock.MagicMock()
    state.session_id = bytearray(b'\xaa' * 3)
    state.cipher = 4
    self.assertFalse(state.resuming)

    msg = ServerHello()
    msg.create(version=(3, 3),
               random=bytearray(32),
               session_id=bytearray(b'\xbb' * 3),
               cipher_suite=4)

    self.assertTrue(exp.is_match(msg))
    with self.assertRaises(AssertionError):
        exp.process(state, msg)
def test_process_with_unexpected_extensions(self):
    """An SNI extension the expectation did not list must be rejected."""
    allowed = {ExtensionType.renegotiation_info: None}
    exp = ExpectServerHello(extensions=allowed)

    state = ConnectionState()
    state.msg_sock = mock.MagicMock()

    server_exts = [
        RenegotiationInfoExtension().create(),
        SNIExtension().create(),
    ]
    msg = ServerHello().create(version=(3, 3),
                               random=bytearray(32),
                               session_id=bytearray(0),
                               cipher_suite=4,
                               extensions=server_exts)

    self.assertTrue(exp.is_match(msg))
    with self.assertRaises(AssertionError):
        exp.process(state, msg)
def test_create(self):
    """create() stores the given fields; certificate type 1 maps to
    OpenPGP and compression defaults to 0."""
    server_hello = ServerHello().create(
        version=(1, 1),
        random=bytearray(b'\x00' * 31 + b'\x01'),
        session_id=bytearray(0),
        cipher_suite=4,
        certificate_type=1,
        tackExt=None,
        next_protos_advertised=None)

    expectations = (
        ((1, 1), 'server_version'),
        (bytearray(b'\x00' * 31 + b'\x01'), 'random'),
        (bytearray(0), 'session_id'),
        (4, 'cipher_suite'),
        (CertificateType.openpgp, 'certificate_type'),
        (0, 'compression_method'),
        (None, 'tackExt'),
        (None, 'next_protos_advertised'),
    )
    for expected, attr_name in expectations:
        self.assertEqual(expected, getattr(server_hello, attr_name))
def setUp(self):
    """Prepare a TLS 1.2 ECDHE_RSA key exchange limited to secp256r1."""
    self.srv_private_key = parsePEMKey(srv_raw_key, private=True)
    chain = X509CertChain([X509().parse(srv_raw_certificate)])
    self.srv_pub_key = chain.getEndEntityPublicKey()
    self.cipher_suite = CipherSuite.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA

    groups_ext = SupportedGroupsExtension().create([GroupName.secp256r1])
    self.client_hello = ClientHello().create(
        (3, 3), bytearray(32), bytearray(0), [], extensions=[groups_ext])
    self.server_hello = ServerHello().create(
        (3, 3), bytearray(32), bytearray(0), self.cipher_suite)

    self.keyExchange = ECDHE_RSAKeyExchange(
        self.cipher_suite,
        self.client_hello,
        self.server_hello,
        self.srv_private_key,
        [GroupName.secp256r1])
def test_process_with_resumption(self):
    """A ServerHello echoing the cached session ID marks the connection as
    resuming."""
    exp = ExpectServerHello()

    state = ConnectionState()
    state.msg_sock = mock.MagicMock()
    state.session_id = bytearray(b'\xaa' * 3)
    state.cipher = 4
    self.assertFalse(state.resuming)

    msg = ServerHello()
    msg.create(version=(3, 3),
               random=bytearray(32),
               session_id=bytearray(b'\xaa' * 3),
               cipher_suite=4)

    self.assertTrue(exp.is_match(msg))
    exp.process(state, msg)
    self.assertTrue(state.resuming)
def test_process_with_extensions(self):
    """The per-extension checker is called exactly once with the state and
    the received extension."""
    checker = mock.MagicMock()
    exp = ExpectServerHello(
        extensions={ExtensionType.renegotiation_info: checker})

    state = ConnectionState()
    state.msg_sock = mock.MagicMock()

    reneg = RenegotiationInfoExtension().create()
    msg = ServerHello().create(version=(3, 3),
                               random=bytearray(32),
                               session_id=bytearray(0),
                               cipher_suite=4,
                               extensions=[reneg])

    self.assertTrue(exp.is_match(msg))
    exp.process(state, msg)
    checker.assert_called_once_with(state, reneg)
def test_write(self):
    """write() of an extension-less ServerHello produces the bare body."""
    server_hello = ServerHello().create(
        version=(1, 1),
        random=bytearray(b'\x00' * 31 + b'\x02'),
        session_id=bytearray(0),
        cipher_suite=4,
        certificate_type=None,
        tackExt=None,
        next_protos_advertised=None)

    expected = bytearray(
        b'\x02' +                      # type of message - server_hello
        b'\x00\x00\x26' +              # length
        b'\x01\x01' +                  # proto version
        b'\x00' * 31 + b'\x02' +       # random
        b'\x00' +                      # session id length
        b'\x00\x04' +                  # cipher suite
        b'\x00')                       # compression method

    self.assertEqual(list(expected), list(server_hello.write()))