def test_with_entry(self):
        """A single KeyShareEntry given to key_share_ext_gen becomes the only
        client share of the generated ClientKeyShareExtension."""
        share_entry = KeyShareEntry().create(1313, bytearray(b'something'))
        make_ext = key_share_ext_gen([share_entry])

        generated = make_ext(None)

        self.assertIsInstance(generated, ClientKeyShareExtension)
        shares = generated.client_shares
        self.assertEqual(len(shares), 1)
        first = shares[0]
        self.assertEqual(first.group, 1313)
        self.assertEqual(first.key_exchange, b'something')
Beispiel #2
0
def key_share_gen(group, version=(3, 4)):
    """
    Generate a fresh key share, with a random private key, for ``group``.

    :param int group: TLS numerical ID from GroupName identifying the group
    :param tuple version: TLS protocol version as a tuple, as encoded on the
        wire
    :return: KeyShareEntry created from the group ID, the computed public
        value and the private key it was derived from
    """
    key_exchange = kex_for_group(group, version)
    private_key = key_exchange.get_random_private_key()
    public_value = key_exchange.calc_public_value(private_key)
    entry = KeyShareEntry()
    return entry.create(group, public_value, private_key)
Beispiel #3
0
def key_share_gen(group, version=(3, 4)):
    """
    Generate a fresh key share, with a random private key, for a group.

    @type group: int
    @param group: TLS numerical ID from GroupName identifying the group
    @type version: tuple
    @param version: TLS protocol version as a tuple, as encoded on the
        wire
    @return: L{KeyShareEntry<tlslite.extensions.KeyShareEntry>} created from
        the group ID, the computed public value and its private key
    """
    key_exchange = kex_for_group(group, version)
    private_key = key_exchange.get_random_private_key()
    public_value = key_exchange.calc_public_value(private_key)
    entry = KeyShareEntry()
    return entry.create(group, public_value, private_key)
Beispiel #4
0
    def test_client_with_server_responing_with_wrong_session_id_in_TLS1_3(
            self):
        """A TLS 1.3 ServerHello carrying an unexpected session_id (b"test")
        must make the client abort with an illegal_parameter alert."""
        # socket to generate the faux response
        gen_sock = MockSocket(bytearray(0))

        gen_record_layer = RecordLayer(gen_sock)
        gen_record_layer.version = (3, 3)

        srv_ext = []
        srv_ext.append(SrvSupportedVersionsExtension().create((3, 4)))
        srv_ext.append(ServerKeyShareExtension().create(KeyShareEntry().create(
            GroupName.secp256r1, bytearray(b'\x03' + b'\x01' * 32))))

        server_hello = ServerHello().create(
            version=(3, 3),
            random=bytearray(32),
            session_id=bytearray(b"test"),
            cipher_suite=CipherSuite.TLS_AES_128_GCM_SHA256,
            certificate_type=None,
            tackExt=None,
            next_protos_advertised=None,
            extensions=srv_ext)

        # Serialise the ServerHello into gen_sock; values 0/1 from the
        # generator mean the (mock) socket would block, which must not happen.
        for res in gen_record_layer.sendRecord(server_hello):
            if res in (0, 1):
                self.fail("Blocking socket")
            else:
                break

        # test proper: feed the recorded bytes to a client connection
        sock = MockSocket(gen_sock.sent[0])

        conn = TLSConnection(sock)

        with self.assertRaises(TLSLocalAlert) as err:
            conn.handshakeClientCert()

        self.assertEqual(err.exception.description,
                         AlertDescription.illegal_parameter)
Beispiel #5
0
def _ecdhe_tls13_ciphers():
    """Return a new list with the cipher suites offered in every ClientHello."""
    return [
        CipherSuite.TLS_AES_128_GCM_SHA256,
        CipherSuite.TLS_EMPTY_RENEGOTIATION_INFO_SCSV
    ]


def _ecdhe_tls13_extensions(key_share, groups, sig_algs):
    """
    Build the ClientHello extension dict used by every conversation.

    :param key_share: value stored under ExtensionType.key_share -- either a
        ClientKeyShareExtension instance or the callable returned by
        key_share_ext_gen()
    :param list groups: group IDs advertised in supported_groups
    :param list sig_algs: signature schemes for signature_algorithms
    :return: dict mapping ExtensionType to extension. Insertion order is
        key_share, supported_versions, supported_groups, signature_algorithms,
        signature_algorithms_cert (kept stable in case the ClientHello
        generator emits extensions in dict order)
    """
    ext = {}
    ext[ExtensionType.key_share] = key_share
    ext[ExtensionType.supported_versions] = SupportedVersionsExtension()\
        .create([TLS_1_3_DRAFT, (3, 3)])
    ext[ExtensionType.supported_groups] = SupportedGroupsExtension()\
        .create(groups)
    ext[ExtensionType.signature_algorithms] = SignatureAlgorithmsExtension()\
        .create(sig_algs)
    ext[ExtensionType.signature_algorithms_cert] = \
        SignatureAlgorithmsCertExtension().create(SIG_ALL)
    return ext


def _ecdhe_tls13_good_conversation(host, port, ext):
    """
    Build a conversation expecting a complete TLS 1.3 handshake.

    Sends a ClientHello with ``ext``, expects the full server flight, sends
    Finished and an HTTP GET, accepts any number of NewSessionTicket messages,
    then expects application data and closes with close_notify.
    """
    conversation = Connect(host, port)
    node = conversation
    node = node.add_child(ClientHelloGenerator(_ecdhe_tls13_ciphers(),
                                               extensions=ext))
    node = node.add_child(ExpectServerHello())
    node = node.add_child(ExpectChangeCipherSpec())
    node = node.add_child(ExpectEncryptedExtensions())
    node = node.add_child(ExpectCertificate())
    node = node.add_child(ExpectCertificateVerify())
    node = node.add_child(ExpectFinished())
    node = node.add_child(FinishedGenerator())
    node = node.add_child(
        ApplicationDataGenerator(bytearray(b"GET / HTTP/1.0\r\n\r\n")))

    # This message is optional and may show up 0 to many times
    cycle = ExpectNewSessionTicket()
    node = node.add_child(cycle)
    node.add_child(cycle)

    node.next_sibling = ExpectApplicationData()
    node = node.next_sibling.add_child(
        AlertGenerator(AlertLevel.warning, AlertDescription.close_notify))

    node = node.add_child(ExpectAlert())
    node.next_sibling = ExpectClose()
    return conversation


def _ecdhe_tls13_rejected_conversation(host, port, key_share, group,
                                       sig_algs):
    """
    Build a conversation whose ClientHello carries one (malformed) key share.

    :param key_share: the KeyShareEntry to send as the only client share
    :param int group: the single group advertised in supported_groups
    :param list sig_algs: signature schemes for signature_algorithms
    The server is expected to answer with a fatal illegal_parameter alert and
    close the connection.
    """
    conversation = Connect(host, port)
    ext = _ecdhe_tls13_extensions(
        ClientKeyShareExtension().create([key_share]), [group], sig_algs)
    node = conversation.add_child(
        ClientHelloGenerator(_ecdhe_tls13_ciphers(), extensions=ext))
    node = node.add_child(
        ExpectAlert(AlertLevel.fatal, AlertDescription.illegal_parameter))
    node.add_child(ExpectClose())
    return conversation


def main():
    """
    Check that a TLS 1.3 server rejects invalid ECDHE key shares.

    For each of secp256r1, secp384r1, secp521r1, x25519 and x448 this builds a
    per-group sanity handshake plus conversations sending a right 0-padded,
    right-truncated, other-curve and single-zero-byte key share; for the
    non-X curves also an off-curve point and an all-zero point. All malformed
    variants expect a fatal illegal_parameter alert.

    Options: -h host, -p port, -e NAME (exclude test), -x NAME (expected
    failure), -X MSG (expected error message for the preceding -x),
    -n NUM (number of non-sanity tests to sample), --help. Positional
    arguments restrict the run to the named tests. Exits with status 1 when
    any test failed.
    """
    host = "localhost"
    port = 4433
    num_limit = None
    run_exclude = set()
    expected_failures = {}
    last_exp_tmp = None

    argv = sys.argv[1:]
    opts, args = getopt.getopt(argv, "h:p:e:x:X:n:", ["help"])
    for opt, arg in opts:
        if opt == '-h':
            host = arg
        elif opt == '-p':
            port = int(arg)
        elif opt == '-e':
            run_exclude.add(arg)
        elif opt == '-x':
            expected_failures[arg] = None
            last_exp_tmp = str(arg)
        elif opt == '-X':
            if not last_exp_tmp:
                raise ValueError("-x has to be specified before -X")
            expected_failures[last_exp_tmp] = str(arg)
        elif opt == '-n':
            num_limit = int(arg)
        elif opt == '--help':
            help_msg()
            sys.exit(0)
        else:
            raise ValueError("Unknown option: {0}".format(opt))

    if args:
        run_only = set(args)
    else:
        run_only = None

    conversations = {}

    groups = [
        GroupName.secp256r1, GroupName.secp384r1, GroupName.secp521r1,
        GroupName.x25519, GroupName.x448
    ]
    sig_algs = [
        SignatureScheme.rsa_pss_rsae_sha256, SignatureScheme.rsa_pss_pss_sha256
    ]
    sig_algs += ECDSA_SIG_TLS1_3_ALL

    # sanity: key shares for all groups in a single ClientHello
    ext = _ecdhe_tls13_extensions(key_share_ext_gen(groups), groups, sig_algs)
    conversations["sanity"] = _ecdhe_tls13_good_conversation(host, port, ext)

    for group in groups:
        group_name = GroupName.toRepr(group)

        # sanity for just this group
        ext = _ecdhe_tls13_extensions(key_share_ext_gen([group]), [group],
                                      sig_algs)
        conversations["sanity - {0}".format(group_name)] = \
            _ecdhe_tls13_good_conversation(host, port, ext)

        # padded representation
        key_share = key_share_gen(group)
        key_share.key_exchange += bytearray(b'\x00')
        conversations["{0} - right 0-padded key_share".format(group_name)] = \
            _ecdhe_tls13_rejected_conversation(host, port, key_share, group,
                                               sig_algs)

        # truncated representation
        key_share = key_share_gen(group)
        key_share.key_exchange.pop()
        conversations["{0} - right-truncated key_share".format(group_name)] = \
            _ecdhe_tls13_rejected_conversation(host, port, key_share, group,
                                               sig_algs)

        # key share from wrong curve: labelled with ``group`` but carrying the
        # public value generated for a different curve
        key_share = key_share_gen(group)
        if group == GroupName.secp256r1:
            other_share = key_share_gen(GroupName.secp384r1)
        else:
            other_share = key_share_gen(GroupName.secp256r1)
        key_share.key_exchange = other_share.key_exchange
        conversations["{0} - key share from other curve".format(group_name)] = \
            _ecdhe_tls13_rejected_conversation(host, port, key_share, group,
                                               sig_algs)

        # 0-point: a single zero byte as the public value
        key_share = KeyShareEntry().create(group, bytearray(b'\x00'))
        conversations["{0} - point at infinity".format(group_name)] = \
            _ecdhe_tls13_rejected_conversation(host, port, key_share, group,
                                               sig_algs)

        if group not in (GroupName.x25519, GroupName.x448):
            # points not on curve: flip all bits of the last byte
            key_share = key_share_gen(group)
            key_share.key_exchange[-1] ^= 0xff
            conversations["{0} - point outside curve".format(group_name)] = \
                _ecdhe_tls13_rejected_conversation(host, port, key_share,
                                                   group, sig_algs)

            # all-zero point
            key_share = key_share_gen(group)
            key_share.key_exchange = bytearray(len(key_share.key_exchange))
            key_share.key_exchange[0] = 0x04  # SEC1 uncompressed point encoding
            conversations["{0} - x=0, y=0".format(group_name)] = \
                _ecdhe_tls13_rejected_conversation(host, port, key_share,
                                                   group, sig_algs)

    # run the conversation
    good = 0
    bad = 0
    xfail = 0
    xpass = 0
    failed = []
    xpassed = []
    if not num_limit:
        num_limit = len(conversations)

    # make sure that sanity test is run first and last
    # to verify that server was running and kept running throughout
    sanity_tests = [('sanity', conversations['sanity'])]
    regular_tests = [(k, v) for k, v in conversations.items() if k != 'sanity']
    sampled_tests = sample(regular_tests, min(num_limit, len(regular_tests)))
    ordered_tests = chain(sanity_tests, sampled_tests, sanity_tests)

    for c_name, c_test in ordered_tests:
        if run_only and c_name not in run_only or c_name in run_exclude:
            continue
        print("{0} ...".format(c_name))

        runner = Runner(c_test)

        res = True
        exception = None
        try:
            runner.run()
        except Exception as exp:
            exception = exp
            print("Error while processing")
            print(traceback.format_exc())
            res = False

        if c_name in expected_failures:
            if res:
                xpass += 1
                xpassed.append(c_name)
                print("XPASS: expected failure but test passed\n")
            else:
                if expected_failures[c_name] is not None and  \
                    expected_failures[c_name] not in str(exception):
                    bad += 1
                    failed.append(c_name)
                    print("Expected error message: {0}\n".format(
                        expected_failures[c_name]))
                else:
                    xfail += 1
                    print("OK-expected failure\n")
        else:
            if res:
                good += 1
                print("OK\n")
            else:
                bad += 1
                failed.append(c_name)

    print("Basic ECDHE curve tests in TLS 1.3")
    print("Check if invalid, malformed and incompatible curve key_shares are")
    print("rejected by server")
    print("See also: test-tls13-crfg-curves.py")
    print("version: {0}\n".format(version))

    print("Test end")
    print(20 * '=')
    print("TOTAL: {0}".format(len(sampled_tests) + 2 * len(sanity_tests)))
    print("SKIP: {0}".format(
        len(run_exclude.intersection(conversations.keys()))))
    print("PASS: {0}".format(good))
    print("XFAIL: {0}".format(xfail))
    print("FAIL: {0}".format(bad))
    print("XPASS: {0}".format(xpass))
    print(20 * '=')
    sort = sorted(xpassed, key=natural_sort_keys)
    if len(sort):
        print("XPASSED:\n\t{0}".format('\n\t'.join(repr(i) for i in sort)))
    sort = sorted(failed, key=natural_sort_keys)
    if len(sort):
        print("FAILED:\n\t{0}".format('\n\t'.join(repr(i) for i in sort)))

    if bad > 0:
        sys.exit(1)
Beispiel #6
0
def main():
    """
    Check that a TLS 1.3 server refuses to negotiate obsolete groups.

    A secp256r1 sanity handshake runs first and last. In between, for every
    group ID listed in RFC 8446 Appendix B.3.1.4 (0x0001-0x0016,
    0x001A-0x001C, 0xFF01-0xFF02), a ClientHello offering only that group is
    sent and a fatal handshake_failure alert is expected.

    Options: -h host, -p port, -e NAME (exclude test), -n NUM (number of
    non-sanity tests to sample), --help. Positional arguments restrict the
    run to the named tests. Exits with status 1 when any test failed.
    """
    host = "localhost"
    port = 4433
    num_limit = None
    run_exclude = set()

    argv = sys.argv[1:]
    opts, args = getopt.getopt(argv, "h:p:e:n:", ["help"])
    for opt, arg in opts:
        if opt == '-h':
            host = arg
        elif opt == '-p':
            port = int(arg)
        elif opt == '-e':
            run_exclude.add(arg)
        elif opt == '-n':
            num_limit = int(arg)
        elif opt == '--help':
            help_msg()
            sys.exit(0)
        else:
            raise ValueError("Unknown option: {0}".format(opt))

    if args:
        run_only = set(args)
    else:
        run_only = None

    conversations = {}

    # the same signature_algorithms offer is used in every ClientHello
    sig_algs = [
        SignatureScheme.rsa_pss_rsae_sha256, SignatureScheme.rsa_pss_pss_sha256
    ]

    # sanity check: a regular handshake using a non-obsolete group
    conversation = Connect(host, port)
    node = conversation
    ciphers = [
        CipherSuite.TLS_AES_128_GCM_SHA256,
        CipherSuite.TLS_EMPTY_RENEGOTIATION_INFO_SCSV
    ]
    ext = {}
    groups = [GroupName.secp256r1]
    key_shares = [key_share_gen(group) for group in groups]
    ext[ExtensionType.key_share] = ClientKeyShareExtension().create(key_shares)
    ext[ExtensionType.supported_versions] = SupportedVersionsExtension()\
        .create([TLS_1_3_DRAFT, (3, 3)])
    ext[ExtensionType.supported_groups] = SupportedGroupsExtension()\
        .create(groups)
    ext[ExtensionType.signature_algorithms] = SignatureAlgorithmsExtension()\
        .create(sig_algs)
    ext[ExtensionType.signature_algorithms_cert] = SignatureAlgorithmsCertExtension()\
        .create(RSA_SIG_ALL)
    node = node.add_child(ClientHelloGenerator(ciphers, extensions=ext))
    node = node.add_child(ExpectServerHello())
    node = node.add_child(ExpectChangeCipherSpec())
    node = node.add_child(ExpectEncryptedExtensions())
    node = node.add_child(ExpectCertificate())
    node = node.add_child(ExpectCertificateVerify())
    node = node.add_child(ExpectFinished())
    node = node.add_child(FinishedGenerator())
    node = node.add_child(
        ApplicationDataGenerator(bytearray(b"GET / HTTP/1.0\r\n\r\n")))

    # This message is optional and may show up 0 to many times
    cycle = ExpectNewSessionTicket()
    node = node.add_child(cycle)
    node.add_child(cycle)

    node.next_sibling = ExpectApplicationData()
    node = node.next_sibling.add_child(
        AlertGenerator(AlertLevel.warning, AlertDescription.close_notify))

    node = node.add_child(ExpectAlert())
    node.next_sibling = ExpectClose()
    conversations["sanity"] = conversation

    # https://tools.ietf.org/html/rfc8446#appendix-B.3.1.4
    obsolete_groups = chain(range(0x0001, 0x0016 + 1),
                            range(0x001A, 0x001C + 1),
                            range(0xFF01, 0xFF02 + 1))

    for obsolete_group in obsolete_groups:
        obsolete_group_name = (GroupName.toRepr(obsolete_group)
                               or "unknown ({0})".format(obsolete_group))
        conversation = Connect(host, port)
        node = conversation
        ciphers = [
            CipherSuite.TLS_AES_128_GCM_SHA256,
            CipherSuite.TLS_EMPTY_RENEGOTIATION_INFO_SCSV
        ]
        ext = {}
        try:
            key_shares = [key_share_gen(obsolete_group)]
        except ValueError:
            # key_share_gen() could not produce a share for this group
            # (presumably no implementation for it), so send a bogus 32-byte
            # value instead. If the server tries to parse it and objects, it
            # will not reply with handshake_failure and the test fails.
            key_shares = [
                KeyShareEntry().create(obsolete_group, bytearray(b'\xab' * 32))
            ]
        ext[ExtensionType.key_share] = ClientKeyShareExtension().create(
            key_shares)
        ext[ExtensionType.supported_versions] = SupportedVersionsExtension()\
            .create([TLS_1_3_DRAFT, (3, 3)])
        ext[ExtensionType.supported_groups] = SupportedGroupsExtension()\
            .create([obsolete_group])
        ext[ExtensionType.signature_algorithms] = SignatureAlgorithmsExtension()\
            .create(sig_algs)
        ext[ExtensionType.signature_algorithms_cert] = SignatureAlgorithmsCertExtension()\
            .create(RSA_SIG_ALL)
        node = node.add_child(ClientHelloGenerator(ciphers, extensions=ext))
        node = node.add_child(
            ExpectAlert(AlertLevel.fatal, AlertDescription.handshake_failure))
        node.add_child(ExpectClose())
        conversation_name = "{0} should be handshake_failed in TLS 1.3".format(
            obsolete_group_name)
        conversations[conversation_name] = conversation

    # run the conversation
    good = 0
    bad = 0
    failed = []
    if not num_limit:
        num_limit = len(conversations)

    # make sure that sanity test is run first and last
    # to verify that server was running and kept running throughout
    sanity_tests = [('sanity', conversations['sanity'])]
    regular_tests = [(k, v) for k, v in conversations.items() if k != 'sanity']
    sampled_tests = sample(regular_tests, min(num_limit, len(regular_tests)))
    ordered_tests = chain(sanity_tests, sampled_tests, sanity_tests)

    for c_name, c_test in ordered_tests:
        if run_only and c_name not in run_only or c_name in run_exclude:
            continue
        print("{0} ...".format(c_name))

        runner = Runner(c_test)

        res = True
        try:
            runner.run()
        except Exception:
            print("Error while processing")
            print(traceback.format_exc())
            res = False

        if res:
            good += 1
            print("OK\n")
        else:
            bad += 1
            failed.append(c_name)

    print("Negotiating obsolete curves with TLS 1.3 server")
    print("Check that TLS 1.3 server will not use obsolete curves and")
    print("will reject the connection with handshake_failure alert.")
    print("Reproduces https://github.com/openssl/openssl/issues/8369\n")
    print("version: {0}\n".format(version))

    print("Test end")
    print("successful: {0}".format(good))
    print("failed: {0}".format(bad))
    failed_sorted = sorted(failed, key=natural_sort_keys)
    # only list failures when there are any (avoids printing an empty line)
    if failed_sorted:
        print("  {0}".format('\n  '.join(repr(i) for i in failed_sorted)))

    if bad > 0:
        sys.exit(1)
Beispiel #7
0
    def test(self):
        """Check an X25519 TLS 1.3 exchange against fixed vectors.

        Rebuilds the ClientHello and compares its encoding byte-for-byte with
        client_hello_ciphertext (skipping its first 5 bytes, presumably the
        record header), parses the ServerHello read from
        server_hello_ciphertext, then derives the early, handshake and master
        secrets plus the client/server handshake traffic secrets, comparing
        each one with a fixed expected value.
        """

        sock = MockSocket(server_hello_ciphertext)

        record_layer = RecordLayer(sock)

        # The extension list, its order and the cipher list below must
        # reproduce client_hello_ciphertext exactly (see the write() check).
        ext = [SNIExtension().create(bytearray(b'server')),
               TLSExtension(extType=ExtensionType.renegotiation_info)
               .create(bytearray(b'\x00')),
               SupportedGroupsExtension().create([GroupName.x25519,
                                                  GroupName.secp256r1,
                                                  GroupName.secp384r1,
                                                  GroupName.secp521r1,
                                                  GroupName.ffdhe2048,
                                                  GroupName.ffdhe3072,
                                                  GroupName.ffdhe4096,
                                                  GroupName.ffdhe6144,
                                                  GroupName.ffdhe8192]),
               ECPointFormatsExtension().create([ECPointFormat.uncompressed]),
               TLSExtension(extType=35),
               ClientKeyShareExtension().create([KeyShareEntry().create(GroupName.x25519,
                                                client_key_public,
                                                client_key_private)]),
               SupportedVersionsExtension().create([TLS_1_3_DRAFT,
                                                    (3, 3), (3, 2)]),
               SignatureAlgorithmsExtension().create([(HashAlgorithm.sha256,
                                                       SignatureAlgorithm.ecdsa),
                                                      (HashAlgorithm.sha384,
                                                       SignatureAlgorithm.ecdsa),
                                                      (HashAlgorithm.sha512,
                                                       SignatureAlgorithm.ecdsa),
                                                      (HashAlgorithm.sha1,
                                                       SignatureAlgorithm.ecdsa),
                                                      SignatureScheme.rsa_pss_sha256,
                                                      SignatureScheme.rsa_pss_sha384,
                                                      SignatureScheme.rsa_pss_sha512,
                                                      SignatureScheme.rsa_pkcs1_sha256,
                                                      SignatureScheme.rsa_pkcs1_sha384,
                                                      SignatureScheme.rsa_pkcs1_sha512,
                                                      SignatureScheme.rsa_pkcs1_sha1,
                                                      (HashAlgorithm.sha256,
                                                       SignatureAlgorithm.dsa),
                                                      (HashAlgorithm.sha384,
                                                       SignatureAlgorithm.dsa),
                                                      (HashAlgorithm.sha512,
                                                       SignatureAlgorithm.dsa),
                                                      (HashAlgorithm.sha1,
                                                       SignatureAlgorithm.dsa)]),
                TLSExtension(extType=45).create(bytearray(b'\x01\x01')),
                TLSExtension(extType=ExtensionType.client_hello_padding)
                .create(bytearray(252))
               ]
        client_hello = ClientHello()
        client_hello.create((3, 3),
                            bytearray(b'\xaf!\x15k\x04\xdbc\x9ef\x15J\x1f\xe5'
                                      b'\xad\xfa\xea\xdf\x9eA4\x16\x00\rW\xb8'
                                      b'\xe1\x12mM\x11\x9a\x8b'),
                            bytearray(b''),
                            [CipherSuite.TLS_AES_128_GCM_SHA256,
                             CipherSuite.TLS_CHACHA20_POLY1305_SHA256,
                             CipherSuite.TLS_AES_256_GCM_SHA384,
                             CipherSuite.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
                             CipherSuite.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
                             0xCCA9,
                             CipherSuite.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
                             CipherSuite.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
                             CipherSuite.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
                             CipherSuite.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
                             CipherSuite.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256,
                             CipherSuite.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256,
                             CipherSuite.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
                             CipherSuite.TLS_DHE_RSA_WITH_AES_128_GCM_SHA256,
                             CipherSuite.TLS_DHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
                             CipherSuite.TLS_DHE_RSA_WITH_AES_128_CBC_SHA,
                             0x0032,
                             CipherSuite.TLS_DHE_RSA_WITH_AES_128_CBC_SHA256,
                             CipherSuite.TLS_DHE_RSA_WITH_AES_256_CBC_SHA,
                             0x0038,
                             CipherSuite.TLS_DHE_RSA_WITH_AES_256_CBC_SHA256,
                             CipherSuite.TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA,
                             0x0013,
                             CipherSuite.TLS_RSA_WITH_AES_128_GCM_SHA256,
                             CipherSuite.TLS_RSA_WITH_AES_128_CBC_SHA,
                             CipherSuite.TLS_RSA_WITH_AES_128_CBC_SHA256,
                             CipherSuite.TLS_RSA_WITH_AES_256_CBC_SHA,
                             CipherSuite.TLS_RSA_WITH_AES_256_CBC_SHA256,
                             CipherSuite.TLS_RSA_WITH_3DES_EDE_CBC_SHA,
                             CipherSuite.TLS_RSA_WITH_RC4_128_SHA,
                             CipherSuite.TLS_RSA_WITH_RC4_128_MD5],
                            extensions=ext)

        self.assertEqual(client_hello.write(), client_hello_ciphertext[5:])

        # recvRecord() is a generator; 0/1 values would signal a blocking
        # socket, which the mock socket must never produce.
        for result in record_layer.recvRecord():
            # check if non-blocking
            self.assertNotIn(result, (0, 1))
        # the last value yielded is the (header, parser) pair of the record
        header, parser = result
        hs_type = parser.get(1)
        self.assertEqual(hs_type, HandshakeType.server_hello)
        server_hello = ServerHello().parse(parser)

        self.assertEqual(server_hello.server_version, TLS_1_3_DRAFT)
        self.assertEqual(server_hello.cipher_suite, CipherSuite.TLS_AES_128_GCM_SHA256)

        server_key_share = server_hello.getExtension(ExtensionType.key_share)
        server_key_share = server_key_share.server_share

        self.assertEqual(server_key_share.group, GroupName.x25519)

        # for TLS_AES_128_GCM_SHA256:
        prf_name = 'sha256'
        prf_size = 256 // 8
        # both the starting secret and the PSK are prf_size zero bytes
        secret = bytearray(prf_size)
        psk = bytearray(prf_size)

        # early secret (HMAC over the zero psk keyed with the zero secret --
        # presumably the HKDF-Extract step of the key schedule)
        secret = secureHMAC(secret, psk, prf_name)

        self.assertEqual(secret,
                         str_to_bytearray(
                             "33ad0a1c607ec03b 09e6cd9893680ce2"
                             "10adf300aa1f2660 e1b22e10f170f92a"))

        # derive secret for handshake
        # (None: no handshake transcript is passed for the "derived" step)
        secret = derive_secret(secret, b"derived", None, prf_name)

        self.assertEqual(secret,
                         str_to_bytearray(
                             "6f2615a108c702c5 678f54fc9dbab697"
                             "16c076189c48250c ebeac3576c3611ba"))

        # extract secret "handshake"
        # Z is the X25519 shared secret of the client private key and the
        # server's public key share
        Z = x25519(client_key_private, server_key_share.key_exchange)

        self.assertEqual(Z,
                         str_to_bytearray(
                             "f677c3cdac26a755 455b130efa9b1a3f"
                             "3cafb153544ca46a ddf670df199d996e"))

        secret = secureHMAC(secret, Z, prf_name)

        self.assertEqual(secret,
                         str_to_bytearray(
                             "0cefce00d5d29fd0 9f5de36c86fc8e72"
                             "99b4ad11ba4211c6 7063c2cc539fc4f9"))

        # transcript so far: ClientHello followed by ServerHello
        handshake_hashes = HandshakeHashes()
        handshake_hashes.update(client_hello_plaintext)
        handshake_hashes.update(server_hello_payload)

        # derive "tls13 c hs traffic"
        c_hs_traffic = derive_secret(secret,
                                     bytearray(b'c hs traffic'),
                                     handshake_hashes,
                                     prf_name)
        self.assertEqual(c_hs_traffic,
                         str_to_bytearray(
                             "5a63db760b817b1b da96e72832333aec"
                             "6a177deeadb5b407 501ac10c17dac0a4"))
        # derive "tls13 s hs traffic"
        s_hs_traffic = derive_secret(secret,
                                     bytearray(b's hs traffic'),
                                     handshake_hashes,
                                     prf_name)
        self.assertEqual(s_hs_traffic,
                         str_to_bytearray(
                             "3aa72a3c77b791e8 f4de243f9ccce172"
                             "941f8392aeb05429 320f4b572ccfe744"))

        # derive master secret
        secret = derive_secret(secret, b"derived", None, prf_name)

        self.assertEqual(secret,
                         str_to_bytearray(
                             "32cadf38f3089048 5c54bf4f1184eaa5"
                             "569eeef15a43f3c7 6ab33965a47c9ff6"))

        # extract secret "master" (input keying material is prf_size zeros)
        secret = secureHMAC(secret, bytearray(prf_size), prf_name)

        self.assertEqual(secret,
                         str_to_bytearray(
                             "6c6d4b3e7c925460 82d7b7a32f6ce219"
                             "3804f1bb930fed74 5c6b93c71397f424"))
def _tls13_ciphers():
    """Return a fresh list of the cipher suites offered by the TLS 1.3 tests."""
    return [CipherSuite.TLS_AES_128_GCM_SHA256,
            CipherSuite.TLS_EMPTY_RENEGOTIATION_INFO_SCSV]


def _fake_psk_identity():
    """
    Return a random PSK identity with a random obfuscated ticket age.

    The identity is 320 random bytes and the age is drawn with
    ``getRandomNumber(2**30, 2**32)``; the server is not expected to
    recognise it.
    """
    return PskIdentity().create(getRandomBytes(320),
                                getRandomNumber(2**30, 2**32))


def _tls13_ch_extensions(groups, key_shares, versions=None, early_data=False,
                         cookie_ext=None, psk_identity=None):
    """
    Build the ClientHello extensions shared by the conversations in main().

    Extensions are inserted into an OrderedDict in a fixed order. When
    ``psk_identity`` is given, pre_shared_key is inserted last, because
    TLS 1.3 requires it to be the final extension of the ClientHello.

    :param list groups: group IDs advertised in supported_groups
    :param list key_shares: KeyShareEntry objects for the key_share extension
    :param list versions: supported_versions payload, defaults to
        ``[TLS_1_3_DRAFT, (3, 3)]``
    :param bool early_data: include an empty early_data extension
    :param cookie_ext: value stored under the cookie extension type (e.g. a
        cookie handler); the extension is omitted when None
    :param psk_identity: PskIdentity to offer with a random 32-byte binder;
        when None, neither psk_key_exchange_modes nor pre_shared_key is added
    :return: OrderedDict mapping extension type to extension (or handler)
    """
    if versions is None:
        versions = [TLS_1_3_DRAFT, (3, 3)]
    sig_algs = [SignatureScheme.rsa_pss_rsae_sha256,
                SignatureScheme.rsa_pss_pss_sha256]

    ext = OrderedDict()
    ext[ExtensionType.key_share] = ClientKeyShareExtension().create(key_shares)
    ext[ExtensionType.supported_versions] = SupportedVersionsExtension()\
        .create(versions)
    ext[ExtensionType.supported_groups] = SupportedGroupsExtension()\
        .create(groups)
    ext[ExtensionType.signature_algorithms] = SignatureAlgorithmsExtension()\
        .create(sig_algs)
    ext[ExtensionType.signature_algorithms_cert] = \
        SignatureAlgorithmsCertExtension().create(RSA_SIG_ALL)
    if early_data:
        ext[ExtensionType.early_data] = \
            TLSExtension(extType=ExtensionType.early_data)
    if cookie_ext is not None:
        ext[ExtensionType.cookie] = cookie_ext
    if psk_identity is not None:
        ext[ExtensionType.psk_key_exchange_modes] = \
            PskKeyExchangeModesExtension().create(
                [PskKeyExchangeMode.psk_dhe_ke, PskKeyExchangeMode.psk_ke])
        # pre_shared_key MUST be the last extension, so it is added last
        ext[ExtensionType.pre_shared_key] = PreSharedKeyExtension().create(
            [psk_identity], [getRandomBytes(32)])
    return ext


def _p256_ch_extensions(**kwargs):
    """
    Build ClientHello extensions offering only secp256r1.

    A fresh secp256r1 key share is generated on every call; keyword
    arguments are forwarded to _tls13_ch_extensions().
    """
    groups = [GroupName.secp256r1]
    key_shares = [key_share_gen(group) for group in groups]
    return _tls13_ch_extensions(groups, key_shares, **kwargs)


def _p256_0rtt_ch_extensions(versions=None):
    """Build secp256r1-only extensions with early_data and a bogus PSK."""
    return _p256_ch_extensions(versions=versions, early_data=True,
                               psk_identity=_fake_psk_identity())


def _add_0rtt_client_hello(node, ciphers, ext, early_data):
    """
    Send a ClientHello immediately followed by fake early data.

    The ClientHello and every node in ``early_data`` are added between
    TCPBufferingEnable and TCPBufferingDisable/TCPBufferingFlush
    (presumably so they reach the server in a single write). The record
    version is set to (3, 3) before the early data nodes.

    :param node: node to attach the sequence to
    :param list ciphers: cipher suites for the ClientHello
    :param ext: ClientHello extensions
    :param list early_data: nodes (data/CCS generators) sent after the hello
    :return: the last node added
    """
    node = node.add_child(TCPBufferingEnable())
    node = node.add_child(ClientHelloGenerator(ciphers, extensions=ext))
    node = node.add_child(SetRecordVersion((3, 3)))
    for early_node in early_data:
        node = node.add_child(early_node)
    node = node.add_child(TCPBufferingDisable())
    return node.add_child(TCPBufferingFlush())


def _expect_hrr(node, cookie):
    """
    Expect a HelloRetryRequest followed by a ChangeCipherSpec.

    The HRR expectation lists key_share and supported_versions, plus cookie
    when ``cookie`` is true.

    :return: the ExpectChangeCipherSpec node
    """
    hrr_ext = {}
    if cookie:
        hrr_ext[ExtensionType.cookie] = None
    hrr_ext[ExtensionType.key_share] = None
    hrr_ext[ExtensionType.supported_versions] = None
    node = node.add_child(ExpectHelloRetryRequest(extensions=hrr_ext))
    return node.add_child(ExpectChangeCipherSpec())


def _expect_server_flight(node, ccs=True):
    """
    Expect ServerHello, an optional CCS, and the encrypted server flight.

    :param bool ccs: expect a ChangeCipherSpec right after the ServerHello
        (False when it already followed a HelloRetryRequest)
    :return: the ExpectFinished node
    """
    node = node.add_child(ExpectServerHello())
    if ccs:
        node = node.add_child(ExpectChangeCipherSpec())
    node = node.add_child(ExpectEncryptedExtensions())
    node = node.add_child(ExpectCertificate())
    node = node.add_child(ExpectCertificateVerify())
    return node.add_child(ExpectFinished())


def _finish_and_close(node):
    """
    Complete the handshake, exchange application data and close cleanly.

    Sends Finished and an HTTP GET, tolerates NewSessionTicket messages,
    expects application data back, sends a close_notify warning alert and
    expects an alert followed by the connection being closed.
    """
    node = node.add_child(FinishedGenerator())
    node = node.add_child(
        ApplicationDataGenerator(bytearray(b"GET / HTTP/1.0\r\n\r\n")))

    # This message is optional and may show up 0 to many times
    cycle = ExpectNewSessionTicket()
    node = node.add_child(cycle)
    node.add_child(cycle)

    node.next_sibling = ExpectApplicationData()
    node = node.next_sibling.add_child(
        AlertGenerator(AlertLevel.warning, AlertDescription.close_notify))

    node = node.add_child(ExpectAlert())
    node.next_sibling = ExpectClose()


def _expect_bad_record_mac(node):
    """
    Expect a fatal bad_record_mac alert followed by connection close.

    NewSessionTicket messages are tolerated before the alert.
    """
    # This message is optional and may show up 0 to many times
    cycle = ExpectNewSessionTicket()
    node = node.add_child(cycle)
    node.add_child(cycle)

    node.next_sibling = ExpectAlert(AlertLevel.fatal,
                                    AlertDescription.bad_record_mac)
    node.next_sibling.add_child(ExpectClose())


def _hrr_0rtt_conversation(host, port, num_bytes, cookie):
    """
    Start a fake 0-RTT handshake that the server must answer with HRR.

    The first ClientHello offers group 0x1300 (with a dummy key share) and
    secp256r1, early_data and a bogus PSK, followed by ``num_bytes`` of fake
    early data. After the expected HelloRetryRequest and CCS, a second
    ClientHello is sent with a secp256r1 key share, the same PSK identity,
    no early_data, and a cookie when ``cookie`` is true. The encrypted
    server flight (without a further CCS) is then expected.

    :return: tuple of (conversation, ExpectFinished node)
    """
    conversation = Connect(host, port)
    groups = [0x1300, GroupName.secp256r1]
    iden = _fake_psk_identity()

    key_shares = [KeyShareEntry().create(0x1300, bytearray(b'\xab' * 32))]
    ext = _tls13_ch_extensions(groups, key_shares, early_data=True,
                               psk_identity=iden)
    node = _add_0rtt_client_hello(
        conversation, _tls13_ciphers(), ext,
        [ApplicationDataGenerator(getRandomBytes(num_bytes))])

    node = _expect_hrr(node, cookie)

    ext = _tls13_ch_extensions(
        groups, [key_share_gen(GroupName.secp256r1)],
        cookie_ext=ch_cookie_handler if cookie else None,
        psk_identity=iden)
    node = node.add_child(
        ClientHelloGenerator(_tls13_ciphers(), extensions=ext))
    node = _expect_server_flight(node, ccs=False)
    return conversation, node


def main():
    """
    Check that a TLS 1.3 server copes with (fake) 0-RTT handshakes.

    Parses the command line, builds the test conversations, runs them with
    the "sanity" conversation first and last, prints a summary, and exits
    with status 1 when any conversation failed.
    """
    host = "localhost"
    port = 4433
    num_limit = None
    run_exclude = set()
    num_bytes = 2**14
    cookie = False

    argv = sys.argv[1:]
    opts, args = getopt.getopt(argv, "h:p:e:n:",
                               ["help", "num-bytes=", "cookie"])
    for opt, arg in opts:
        if opt == '-h':
            host = arg
        elif opt == '-p':
            port = int(arg)
        elif opt == '-e':
            run_exclude.add(arg)
        elif opt == '-n':
            num_limit = int(arg)
        elif opt == '--help':
            help_msg()
            sys.exit(0)
        elif opt == '--num-bytes':
            num_bytes = int(arg)
        elif opt == '--cookie':
            cookie = True
        else:
            raise ValueError("Unknown option: {0}".format(opt))

    if args:
        run_only = set(args)
    else:
        run_only = None

    conversations = {}

    # sanity check
    conversation = Connect(host, port)
    node = conversation.add_child(
        ClientHelloGenerator(_tls13_ciphers(),
                             extensions=_p256_ch_extensions()))
    node = _expect_server_flight(node)
    _finish_and_close(node)
    conversations["sanity"] = conversation

    # sanity check with PSK binders
    conversation = Connect(host, port)
    ext = _p256_ch_extensions(
        psk_identity=PskIdentity().create(getRandomBytes(320), 0))
    node = conversation.add_child(
        ClientHelloGenerator(_tls13_ciphers(), extensions=ext))
    node = _expect_server_flight(node)
    _finish_and_close(node)
    conversations["handshake with invalid PSK"] = conversation

    # fake 0-RTT resumption with HRR and early data after second client hello
    conversation, node = _hrr_0rtt_conversation(host, port, num_bytes, cookie)
    node = node.add_child(
        PlaintextMessageGenerator(ContentType.application_data,
                                  getRandomBytes(64)))
    _expect_bad_record_mac(node)
    conversations["handshake with 0-RTT, HRR and early data after 2nd Client Hello"]\
        = conversation

    # fake 0-RTT resumption with HRR
    conversation, node = _hrr_0rtt_conversation(host, port, num_bytes, cookie)
    _finish_and_close(node)
    conversations["handshake with invalid 0-RTT and HRR"] = conversation

    # fake 0-RTT resumption with fragmented early data
    conversation = Connect(host, port)
    node = _add_0rtt_client_hello(
        conversation, _tls13_ciphers(), _p256_0rtt_ch_extensions(),
        [ApplicationDataGenerator(getRandomBytes(num_bytes // 2)),
         ApplicationDataGenerator(getRandomBytes(num_bytes // 2))])
    node = _expect_server_flight(node)
    _finish_and_close(node)
    conversations["handshake with invalid 0-RTT with fragmented early data"]\
        = conversation

    # fake 0-RTT and early data spliced into the Finished message
    conversation = Connect(host, port)
    node = _add_0rtt_client_hello(
        conversation, _tls13_ciphers(), _p256_0rtt_ch_extensions(),
        [ApplicationDataGenerator(getRandomBytes(num_bytes))])
    node = _expect_server_flight(node)
    # the fragments collected in finished_fragments are not sent afterwards
    # by this conversation
    finished_fragments = []
    node = node.add_child(
        split_message(FinishedGenerator(), finished_fragments, 16))
    # early data spliced into the Finished message
    node = node.add_child(
        PlaintextMessageGenerator(ContentType.application_data,
                                  getRandomBytes(64)))
    _expect_bad_record_mac(node)
    conversations["undecryptable record later in handshake together with early_data"]\
        = conversation

    # fake 0-RTT resumption and CCS between fake early data
    conversation = Connect(host, port)
    node = _add_0rtt_client_hello(
        conversation, _tls13_ciphers(), _p256_0rtt_ch_extensions(),
        [ApplicationDataGenerator(getRandomBytes(num_bytes // 2)),
         ChangeCipherSpecGenerator(fake=True),
         ApplicationDataGenerator(getRandomBytes(num_bytes // 2))])
    node = _expect_server_flight(node)
    _finish_and_close(node)
    conversations["handshake with invalid 0-RTT and CCS between early data records"]\
        = conversation

    # fake 0-RTT resumption and CCS
    conversation = Connect(host, port)
    node = _add_0rtt_client_hello(
        conversation, _tls13_ciphers(), _p256_0rtt_ch_extensions(),
        [ChangeCipherSpecGenerator(fake=True),
         ApplicationDataGenerator(getRandomBytes(num_bytes))])
    node = _expect_server_flight(node)
    _finish_and_close(node)
    conversations["handshake with invalid 0-RTT and CCS"] = conversation

    # fake 0-RTT resumption with unknown version
    conversation = Connect(host, port)
    ciphers = [
        CipherSuite.TLS_AES_128_GCM_SHA256,
        CipherSuite.TLS_RSA_WITH_AES_128_CBC_SHA,
        CipherSuite.TLS_EMPTY_RENEGOTIATION_INFO_SCSV
    ]
    node = _add_0rtt_client_hello(
        conversation, ciphers,
        _p256_0rtt_ch_extensions(versions=[(3, 5), (3, 3)]),
        [ApplicationDataGenerator(getRandomBytes(num_bytes))])
    node = node.add_child(ExpectServerHello(version=(3, 3)))
    node = node.add_child(ExpectCertificate())
    node = node.add_child(ExpectServerHelloDone())
    # section D.3 of draft 28 states that client that receives TLS 1.2
    # ServerHello as a reply to 0-RTT Client Hello MUST fail a connection
    # consequently, the server does not need to be able to ignore early data
    # in TLS 1.2 mode
    node = node.add_child(
        ExpectAlert(AlertLevel.fatal, AlertDescription.unexpected_message))
    node.add_child(ExpectClose())
    conversations[
        "handshake with invalid 0-RTT and unknown version (downgrade to TLS 1.2)"] = conversation

    # fake 0-RTT resumption
    conversation = Connect(host, port)
    node = _add_0rtt_client_hello(
        conversation, _tls13_ciphers(), _p256_0rtt_ch_extensions(),
        [ApplicationDataGenerator(getRandomBytes(num_bytes))])
    node = _expect_server_flight(node)
    _finish_and_close(node)
    conversations["handshake with invalid 0-RTT"] = conversation

    # run the conversation
    good = 0
    bad = 0
    failed = []
    if not num_limit:
        num_limit = len(conversations)

    # make sure that sanity test is run first and last
    # to verify that server was running and kept running throughout
    sanity_test = ('sanity', conversations['sanity'])
    ordered_tests = chain([sanity_test],
                          islice(
                              filter(lambda x: x[0] != 'sanity',
                                     conversations.items()), num_limit),
                          [sanity_test])

    for c_name, c_test in ordered_tests:
        if run_only and c_name not in run_only or c_name in run_exclude:
            continue
        print("{0} ...".format(c_name))

        runner = Runner(c_test)

        res = True
        try:
            runner.run()
        except Exception:
            print("Error while processing")
            print(traceback.format_exc())
            res = False

        if res:
            good += 1
            print("OK\n")
        else:
            bad += 1
            failed.append(c_name)

    print("Basic check if TLS 1.3 server can handle 0-RTT handshake")
    print("Verify that the server can handle a 0-RTT handshake from client")
    print("even if (or rather, especially if) it doesn't support 0-RTT.\n")
    print("version: {0}\n".format(version))

    print("Test end")
    print("successful: {0}".format(good))
    print("failed: {0}".format(bad))
    failed_sorted = sorted(failed, key=natural_sort_keys)
    print("  {0}".format('\n  '.join(repr(i) for i in failed_sorted)))

    if bad > 0:
        sys.exit(1)
def main():
    """Run the TLS 1.3 FFDHE key_share tests against a server.

    Builds a "sanity" handshake offering every group in ``GroupName.allFF``.
    Then, for each of those groups, it builds a single-group sanity handshake
    plus a set of conversations sending duplicated, padded, truncated,
    foreign, degenerate (0, 1, all bits set, p-1, p) or empty key shares.
    The server must reject each of those with a fatal ``illegal_parameter``
    alert.

    The conversations are run with the sanity test first and last. A summary
    is printed, and the process exits with status 1 if any test failed.

    Command line (parsed from ``sys.argv``):
        -h HOST   server hostname (default "localhost")
        -p PORT   server port (default 4433)
        -e NAME   exclude the named test (may be repeated)
        -x NAME   mark the named test as an expected failure
        -X MSG    expected error message for the preceding -x test
        -n NUM    run at most NUM randomly sampled tests
        --help    print help and exit
    Positional arguments name the only tests to run.
    """
    host = "localhost"
    port = 4433
    num_limit = None
    run_exclude = set()
    expected_failures = {}
    last_exp_tmp = None

    argv = sys.argv[1:]
    opts, args = getopt.getopt(argv, "h:p:e:x:X:n:", ["help"])
    for opt, arg in opts:
        if opt == '-h':
            host = arg
        elif opt == '-p':
            port = int(arg)
        elif opt == '-e':
            run_exclude.add(arg)
        elif opt == '-x':
            expected_failures[arg] = None
            last_exp_tmp = str(arg)
        elif opt == '-X':
            if not last_exp_tmp:
                raise ValueError("-x has to be specified before -X")
            expected_failures[last_exp_tmp] = str(arg)
        elif opt == '-n':
            num_limit = int(arg)
        elif opt == '--help':
            help_msg()
            sys.exit(0)
        else:
            raise ValueError("Unknown option: {0}".format(opt))

    if args:
        run_only = set(args)
    else:
        run_only = None

    sig_algs = [SignatureScheme.rsa_pss_rsae_sha256,
                SignatureScheme.rsa_pss_pss_sha256]
    sig_algs += ECDSA_SIG_TLS1_3_ALL

    # Group parameters for each FFDHE group under test; index 1 is used as
    # the group prime p below. A dict lookup raises KeyError for an
    # unexpected group instead of relying on an assert that `python -O`
    # would strip.
    ffdhe_params = {
        GroupName.ffdhe2048: FFDHE2048,
        GroupName.ffdhe3072: FFDHE3072,
        GroupName.ffdhe4096: FFDHE4096,
        GroupName.ffdhe6144: FFDHE6144,
        GroupName.ffdhe8192: FFDHE8192,
    }

    def build_extensions(key_share, supported_groups):
        """Return the ClientHello extensions shared by every conversation.

        ``key_share`` is stored unchanged under ExtensionType.key_share. It
        may be a ClientKeyShareExtension or the generator returned by
        key_share_ext_gen(). Extensions are always inserted in the same
        order.
        """
        ext = {}
        ext[ExtensionType.key_share] = key_share
        ext[ExtensionType.supported_versions] = SupportedVersionsExtension()\
            .create([TLS_1_3_DRAFT, (3, 3)])
        ext[ExtensionType.supported_groups] = SupportedGroupsExtension()\
            .create(supported_groups)
        ext[ExtensionType.signature_algorithms] = \
            SignatureAlgorithmsExtension().create(sig_algs)
        ext[ExtensionType.signature_algorithms_cert] = \
            SignatureAlgorithmsCertExtension().create(SIG_ALL)
        return ext

    def start_conversation(key_share, supported_groups):
        """Return (conversation root, ClientHello node) for a new connection."""
        conversation = Connect(host, port)
        ciphers = [CipherSuite.TLS_AES_128_GCM_SHA256,
                   CipherSuite.TLS_EMPTY_RENEGOTIATION_INFO_SCSV]
        node = conversation.add_child(ClientHelloGenerator(
            ciphers, extensions=build_extensions(key_share, supported_groups)))
        return conversation, node

    def full_handshake(key_share, supported_groups):
        """Conversation expected to complete the handshake and close cleanly."""
        conversation, node = start_conversation(key_share, supported_groups)
        node = node.add_child(ExpectServerHello())
        node = node.add_child(ExpectChangeCipherSpec())
        node = node.add_child(ExpectEncryptedExtensions())
        node = node.add_child(ExpectCertificate())
        node = node.add_child(ExpectCertificateVerify())
        node = node.add_child(ExpectFinished())
        node = node.add_child(FinishedGenerator())
        node = node.add_child(ApplicationDataGenerator(
            bytearray(b"GET / HTTP/1.0\r\n\r\n")))

        # This message is optional and may show up 0 to many times
        cycle = ExpectNewSessionTicket()
        node = node.add_child(cycle)
        node.add_child(cycle)

        node.next_sibling = ExpectApplicationData()
        node = node.next_sibling.add_child(
            AlertGenerator(AlertLevel.warning, AlertDescription.close_notify))

        node = node.add_child(ExpectAlert())
        node.next_sibling = ExpectClose()
        return conversation

    def rejected_handshake(group, key_shares):
        """Conversation sending ``key_shares`` for ``group`` only.

        The server is expected to answer with a fatal illegal_parameter
        alert and then close the connection.
        """
        conversation, node = start_conversation(
            ClientKeyShareExtension().create(key_shares), [group])
        node = node.add_child(ExpectAlert(AlertLevel.fatal,
                                          AlertDescription.illegal_parameter))
        node.add_child(ExpectClose())
        return conversation

    conversations = {}

    groups = GroupName.allFF
    conversations["sanity"] = full_handshake(key_share_ext_gen(groups), groups)

    for group in groups:
        group_name = GroupName.toRepr(group)

        conversations["sanity - {0}".format(group_name)] = \
            full_handshake(key_share_ext_gen([group]), [group])

        # padded representation
        padded = key_share_gen(group)
        padded.key_exchange += bytearray(b'\x00')

        # truncated representation (given that all groups use safe primes,
        # any integer between 1 and p-1 is a valid key share, it's just that
        # after truncation we don't know the private key that generates it)
        truncated = key_share_gen(group)
        truncated.key_exchange.pop()

        # key share value generated for a different FFDHE group
        if group == GroupName.ffdhe2048:
            other_group = GroupName.ffdhe3072
        else:
            other_group = GroupName.ffdhe2048
        foreign = key_share_gen(group)
        foreign.key_exchange = key_share_gen(other_group).key_exchange

        # correctly sized encodings of 0, 1 and all-bits-set values
        zero = key_share_gen(group)
        zero.key_exchange = bytearray(len(zero.key_exchange))

        one = key_share_gen(group)
        one.key_exchange = bytearray(len(one.key_exchange))
        one.key_exchange[-1] = 0x01

        all_bits_set = key_share_gen(group)
        all_bits_set.key_exchange = bytearray(
            [0xff] * len(all_bits_set.key_exchange))

        prime = ffdhe_params[group][1]

        # (test name suffix, key share entries) in the original test order
        invalid_key_shares = [
            ("duplicated key share entry",
             [key_share_gen(group), key_share_gen(group)]),
            ("right 0-padded key_share", [padded]),
            ("right-truncated key_share", [truncated]),
            ("key share from other group", [foreign]),
            ("0 as one byte",
             [KeyShareEntry().create(group, bytearray(b'\x00'))]),
            ("0 as key share", [zero]),
            ("1 as key share", [one]),
            ("all bits set key share", [all_bits_set]),
            ("p-1 as key share",
             [KeyShareEntry().create(group, numberToByteArray(prime - 1))]),
            ("p as key share",
             [KeyShareEntry().create(group, numberToByteArray(prime))]),
            ("empty key share entry in key_share ext",
             [KeyShareEntry().create(group, bytearray())]),
        ]
        for description, key_shares in invalid_key_shares:
            conversations["{0} - {1}".format(group_name, description)] = \
                rejected_handshake(group, key_shares)

    # run the conversation
    good = 0
    bad = 0
    xfail = 0
    xpass = 0
    failed = []
    xpassed = []
    if not num_limit:
        num_limit = len(conversations)

    # make sure that sanity test is run first and last
    # to verify that server was running and kept running throughout
    sanity_tests = [('sanity', conversations['sanity'])]
    if run_only:
        if num_limit > len(run_only):
            num_limit = len(run_only)
        regular_tests = [(k, v) for k, v in conversations.items() if
                         k in run_only]
    else:
        regular_tests = [(k, v) for k, v in conversations.items() if
                         (k != 'sanity') and k not in run_exclude]
    sampled_tests = sample(regular_tests, min(num_limit, len(regular_tests)))
    ordered_tests = chain(sanity_tests, sampled_tests, sanity_tests)

    for c_name, c_test in ordered_tests:
        if run_only and c_name not in run_only or c_name in run_exclude:
            continue
        print("{0} ...".format(c_name))

        runner = Runner(c_test)

        res = True
        exception = None
        try:
            runner.run()
        except Exception as exp:
            exception = exp
            print("Error while processing")
            print(traceback.format_exc())
            res = False

        if c_name in expected_failures:
            expected_msg = expected_failures[c_name]
            if res:
                xpass += 1
                xpassed.append(c_name)
                print("XPASS-expected failure but test passed\n")
            elif expected_msg is not None and \
                    expected_msg not in str(exception):
                # failed, but not with the error the user told us to expect
                bad += 1
                failed.append(c_name)
                print("Expected error message: {0}\n".format(expected_msg))
            else:
                xfail += 1
                print("OK-expected failure\n")
        elif res:
            good += 1
            print("OK\n")
        else:
            bad += 1
            failed.append(c_name)

    print("Basic FFDHE group tests in TLS 1.3")
    print("Check if invalid, malformed and incompatible group key_shares are")
    print("rejected by server")

    print("Test end")
    print(20 * '=')
    print("version: {0}".format(version))
    print(20 * '=')
    print("TOTAL: {0}".format(len(sampled_tests) + 2*len(sanity_tests)))
    print("SKIP: {0}".format(len(run_exclude.intersection(conversations.keys()))))
    print("PASS: {0}".format(good))
    print("XFAIL: {0}".format(xfail))
    print("FAIL: {0}".format(bad))
    print("XPASS: {0}".format(xpass))
    print(20 * '=')
    xpassed_sorted = sorted(xpassed, key=natural_sort_keys)
    if xpassed_sorted:
        print("XPASSED:\n\t{0}".format(
            '\n\t'.join(repr(i) for i in xpassed_sorted)))
    failed_sorted = sorted(failed, key=natural_sort_keys)
    if failed_sorted:
        print("FAILED:\n\t{0}".format(
            '\n\t'.join(repr(i) for i in failed_sorted)))

    if bad > 0:
        sys.exit(1)
Beispiel #10
0
    def test(self):
        """Check a TLS 1.3 x25519 / TLS_AES_128_GCM_SHA256 key schedule.

        Verifies that the ClientHello built here serialises to the recorded
        ``client_hello_ciphertext`` record payload and parses the recorded
        ServerHello. It then recomputes the key schedule step by step and
        compares each value with the expected bytes:
        - early, handshake and master secrets;
        - client/server handshake and application traffic secrets;
        - the server Finished key and MAC;
        - the exporter master secret;
        - the traffic keys and IVs.
        (The vectors look like the RFC 8448 "Simple 1-RTT Handshake" trace;
        TODO confirm.)
        """
        sock = MockSocket(server_hello_ciphertext)

        record_layer = RecordLayer(sock)

        ext = [
            SNIExtension().create(bytearray(b'server')),
            TLSExtension(extType=ExtensionType.renegotiation_info).create(
                bytearray(b'\x00')),
            SupportedGroupsExtension().create([
                GroupName.x25519, GroupName.secp256r1, GroupName.secp384r1,
                GroupName.secp521r1, GroupName.ffdhe2048, GroupName.ffdhe3072,
                GroupName.ffdhe4096, GroupName.ffdhe6144, GroupName.ffdhe8192
            ]),
            TLSExtension(extType=35),
            ClientKeyShareExtension().create([
                KeyShareEntry().create(GroupName.x25519, client_key_public,
                                       client_key_private)
            ]),
            SupportedVersionsExtension().create([(3, 4)]),
            SignatureAlgorithmsExtension().create([
                SignatureScheme.ecdsa_secp256r1_sha256,
                SignatureScheme.ecdsa_secp384r1_sha384,
                SignatureScheme.ecdsa_secp521r1_sha512,
                (HashAlgorithm.sha1, SignatureAlgorithm.ecdsa),
                SignatureScheme.rsa_pss_rsae_sha256,
                SignatureScheme.rsa_pss_rsae_sha384,
                SignatureScheme.rsa_pss_rsae_sha512,
                SignatureScheme.rsa_pkcs1_sha256,
                SignatureScheme.rsa_pkcs1_sha384,
                SignatureScheme.rsa_pkcs1_sha512,
                SignatureScheme.rsa_pkcs1_sha1,
                (HashAlgorithm.sha256, SignatureAlgorithm.dsa),
                (HashAlgorithm.sha384, SignatureAlgorithm.dsa),
                (HashAlgorithm.sha512, SignatureAlgorithm.dsa),
                (HashAlgorithm.sha1, SignatureAlgorithm.dsa)
            ]),
            TLSExtension(extType=45).create(bytearray(b'\x01\x01')),
            RecordSizeLimitExtension().create(16385)
        ]
        client_hello = ClientHello()
        client_hello.create((3, 3),
                            bytearray(b'\xcb4\xec\xb1\xe7\x81c'
                                      b'\xba\x1c8\xc6\xda\xcb'
                                      b'\x19jm\xff\xa2\x1a\x8d'
                                      b'\x99\x12\xec\x18\xa2'
                                      b'\xefb\x83\x02M\xec\xe7'),
                            bytearray(b''), [
                                CipherSuite.TLS_AES_128_GCM_SHA256,
                                CipherSuite.TLS_CHACHA20_POLY1305_SHA256,
                                CipherSuite.TLS_AES_256_GCM_SHA384
                            ],
                            extensions=ext)

        # skip the 5-byte record header of the recorded ClientHello record
        self.assertEqual(client_hello.write(), client_hello_ciphertext[5:])

        # only the first record (carrying the ServerHello) is inspected
        for result in record_layer.recvRecord():
            # check if non-blocking
            self.assertNotIn(result, (0, 1))
            break

        header, parser = result
        hs_type = parser.get(1)
        self.assertEqual(hs_type, HandshakeType.server_hello)
        server_hello = ServerHello().parse(parser)

        self.assertEqual(server_hello.server_version, (3, 3))
        self.assertEqual(server_hello.cipher_suite,
                         CipherSuite.TLS_AES_128_GCM_SHA256)

        server_key_share = server_hello.getExtension(ExtensionType.key_share)
        server_key_share = server_key_share.server_share

        self.assertEqual(server_key_share.group, GroupName.x25519)

        # for TLS_AES_128_GCM_SHA256:
        prf_name = 'sha256'
        prf_size = 256 // 8
        secret = bytearray(prf_size)
        psk = bytearray(prf_size)

        # early secret (no PSK: all-zero salt and input keying material)
        secret = secureHMAC(secret, psk, prf_name)

        self.assertEqual(
            secret,
            clean("""
                         33 ad 0a 1c 60 7e c0 3b 09 e6 cd 98 93 68 0c
                         e2 10 ad f3 00 aa 1f 26 60 e1 b2 2e 10 f1 70 f9 2a
                         """))

        # derive secret for handshake
        secret = derive_secret(secret, b"derived", None, prf_name)

        self.assertEqual(
            secret,
            clean("""
                         6f 26 15 a1 08 c7 02 c5 67 8f 54 fc 9d ba
                         b6 97 16 c0 76 18 9c 48 25 0c eb ea c3 57 6c 36 11 ba
                         """))

        # extract secret "handshake" from the x25519 shared secret
        Z = x25519(client_key_private, server_key_share.key_exchange)

        self.assertEqual(
            Z,
            clean("""
                         8b d4 05 4f b5 5b 9d 63 fd fb ac f9 f0 4b 9f 0d
                         35 e6 d6 3f 53 75 63 ef d4 62 72 90 0f 89 49 2d
                         """))

        secret = secureHMAC(secret, Z, prf_name)

        self.assertEqual(
            secret,
            clean("""
                         1d c8 26 e9 36 06 aa 6f dc 0a ad c1 2f 74 1b
                         01 04 6a a6 b9 9f 69 1e d2 21 a9 f0 ca 04 3f be ac
                         """))

        handshake_hashes = HandshakeHashes()
        handshake_hashes.update(client_hello_plaintext)
        handshake_hashes.update(server_hello_payload)

        # derive "tls13 c hs traffic" and "tls13 s hs traffic"
        c_hs_traffic = derive_secret(secret, b"c hs traffic",
                                     handshake_hashes, prf_name)
        self.assertEqual(
            c_hs_traffic,
            clean("""
                         b3 ed db 12 6e 06 7f 35 a7 80 b3 ab f4 5e
                         2d 8f 3b 1a 95 07 38 f5 2e 96 00 74 6a 0e 27 a5 5a 21
                         """))
        s_hs_traffic = derive_secret(secret, b"s hs traffic",
                                     handshake_hashes, prf_name)
        self.assertEqual(
            s_hs_traffic,
            clean("""
                         b6 7b 7d 69 0c c1 6c 4e 75 e5 42 13 cb 2d
                         37 b4 e9 c9 12 bc de d9 10 5d 42 be fd 59 d3 91 ad 38
                         """))

        # derive master secret
        secret = derive_secret(secret, b"derived", None, prf_name)

        self.assertEqual(
            secret,
            clean("""
                         43 de 77 e0 c7 77 13 85 9a 94 4d b9 db 25
                         90 b5 31 90 a6 5b 3e e2 e4 f1 2d d7 a0 bb 7c e2 54 b4
                         """))

        # extract secret "master" (all-zero input keying material)
        secret = secureHMAC(secret, bytearray(prf_size), prf_name)

        self.assertEqual(
            secret,
            clean("""
                         18 df 06 84 3d 13 a0 8b f2 a4 49 84 4c 5f 8a
                         47 80 01 bc 4d 4c 62 79 84 d5 a4 1d a8 d0 40 29 19
                         """))

        # derive the server's write key and IV for handshake data
        server_hs_write_traffic_key = HKDF_expand_label(s_hs_traffic, b"key",
                                                        b"", 16, prf_name)

        self.assertEqual(
            server_hs_write_traffic_key,
            clean("""
                         3f ce 51 60 09 c2 17 27 d0 f2 e4 e8 6e
                         e4 03 bc
                         """))

        server_hs_write_traffic_iv = HKDF_expand_label(s_hs_traffic, b"iv",
                                                       b"", 12, prf_name)

        self.assertEqual(
            server_hs_write_traffic_iv,
            clean("""
                         5d 31 3e b2 67 12 76 ee 13 00 0b 30
                         """))

        # derive key for the server's Finished message
        server_finished_key = HKDF_expand_label(s_hs_traffic, b"finished", b"",
                                                prf_size, prf_name)
        self.assertEqual(
            server_finished_key,
            clean("""
                         00 8d 3b 66 f8 16 ea 55 9f 96 b5 37 e8 85
                         c3 1f c0 68 bf 49 2c 65 2f 01 f2 88 a1 d8 cd c1 9f c8
                         """))

        # Update the handshake transcript
        handshake_hashes.update(server_encrypted_extensions)
        handshake_hashes.update(server_certificate_message)
        handshake_hashes.update(server_certificateverify_message)
        hs_transcript = handshake_hashes.digest(prf_name)

        server_finished = secureHMAC(server_finished_key, hs_transcript,
                                     prf_name)

        self.assertEqual(
            server_finished,
            clean("""
                         9b 9b 14 1d 90 63 37 fb d2 cb dc e7 1d f4
                         de da 4a b4 2c 30 95 72 cb 7f ff ee 54 54 b7 8f 07 18
                         """))

        server_finished_message = Finished((3, 4)).create(server_finished)
        server_finished_payload = server_finished_message.write()

        # update handshake transcript to include Finished payload
        handshake_hashes.update(server_finished_payload)

        # derive client application traffic secret
        c_ap_traffic = derive_secret(secret, b"c ap traffic", handshake_hashes,
                                     prf_name)

        self.assertEqual(
            c_ap_traffic,
            clean("""
                         9e 40 64 6c e7 9a 7f 9d c0 5a f8 88 9b ce
                         65 52 87 5a fa 0b 06 df 00 87 f7 92 eb b7 c1 75 04 a5
                         """))

        # derive server application traffic secret
        s_ap_traffic = derive_secret(secret, b"s ap traffic", handshake_hashes,
                                     prf_name)

        self.assertEqual(
            s_ap_traffic,
            clean("""
                         a1 1a f9 f0 55 31 f8 56 ad 47 11 6b 45 a9
                         50 32 82 04 b4 f4 4b fb 6b 3a 4b 4f 1f 3f cb 63 16 43
                         """))

        # derive exporter master secret
        exp_master = derive_secret(secret, b"exp master", handshake_hashes,
                                   prf_name)

        self.assertEqual(
            exp_master,
            clean("""
                         fe 22 f8 81 17 6e da 18 eb 8f 44 52 9e 67
                         92 c5 0c 9a 3f 89 45 2f 68 d8 ae 31 1b 43 09 d3 cf 50
                         """))

        # derive the server's write key and IV for application data
        server_write_traffic_key = HKDF_expand_label(s_ap_traffic, b"key", b"",
                                                     16, prf_name)

        self.assertEqual(
            server_write_traffic_key,
            clean("""
                         9f 02 28 3b 6c 9c 07 ef c2 6b b9 f2 ac
                         92 e3 56
                         """))

        server_write_traffic_iv = HKDF_expand_label(s_ap_traffic, b"iv", b"",
                                                    12, prf_name)

        self.assertEqual(
            server_write_traffic_iv,
            clean("""
                         cf 78 2b 88 dd 83 54 9a ad f1 e9 84
                         """))

        # derive the server's read (i.e. client write) key and IV for
        # handshake data, from the client handshake traffic secret
        server_read_hs_key = HKDF_expand_label(c_hs_traffic, b"key", b"", 16,
                                               prf_name)

        self.assertEqual(
            server_read_hs_key,
            clean("""
                         db fa a6 93 d1 76 2c 5b 66 6a f5 d9 50
                         25 8d 01
                         """))

        server_read_hs_iv = HKDF_expand_label(c_hs_traffic, b"iv", b"", 12,
                                              prf_name)

        self.assertEqual(
            server_read_hs_iv,
            clean("""
                         5b d3 c7 1b 83 6e 0b 76 bb 73 26 5f
                         """))
def main():
    """
    Check that a TLS 1.3 server negotiates x25519/x448 and rejects bad shares.

    Command line options are read from sys.argv:
      -h HOST   server hostname (default "localhost")
      -p PORT   server port (default 4433)
      -e NAME   exclude the named conversation
      -x NAME   the named conversation is expected to fail
      -X MSG    the failure of the preceding -x conversation must
                contain MSG
      -n NUM    run at most NUM conversations besides "sanity"
      --help    print help_msg() and exit
    Positional arguments restrict the run to the named conversations.

    The "sanity" conversation is run before and after the sampled ones to
    verify that the server was running and kept running throughout.
    Exits with status 1 if any conversation failed unexpectedly.

    :raises ValueError: when -X is given before any -x
    """
    host = "localhost"
    port = 4433
    num_limit = None
    run_exclude = set()
    expected_failures = {}
    last_exp_tmp = None

    argv = sys.argv[1:]
    opts, args = getopt.getopt(argv, "h:p:e:x:X:n:", ["help"])
    for opt, arg in opts:
        if opt == '-h':
            host = arg
        elif opt == '-p':
            port = int(arg)
        elif opt == '-e':
            run_exclude.add(arg)
        elif opt == '-x':
            expected_failures[arg] = None
            last_exp_tmp = str(arg)
        elif opt == '-X':
            if not last_exp_tmp:
                raise ValueError("-x has to be specified before -X")
            expected_failures[last_exp_tmp] = str(arg)
        elif opt == '-n':
            num_limit = int(arg)
        elif opt == '--help':
            help_msg()
            sys.exit(0)
        else:
            raise ValueError("Unknown option: {0}".format(opt))

    if args:
        run_only = set(args)
    else:
        run_only = None

    def client_hello(ext):
        """Return a ClientHelloGenerator offering AES-128-GCM and the SCSV."""
        ciphers = [
            CipherSuite.TLS_AES_128_GCM_SHA256,
            CipherSuite.TLS_EMPTY_RENEGOTIATION_INFO_SCSV
        ]
        return ClientHelloGenerator(ciphers, extensions=ext)

    def add_tls13_extensions(ext, groups, key_shares):
        """
        Add the TLS 1.3 ClientHello extensions common to all tests to ext.

        :param dict ext: extensions dictionary, updated in place on top of
            any entries it already holds
        :param list groups: group IDs advertised in supported_groups
        :param list key_shares: KeyShareEntry objects sent in key_share
        :return: ext
        """
        ext[ExtensionType.key_share] = ClientKeyShareExtension().create(
            key_shares)
        ext[ExtensionType.supported_versions] = SupportedVersionsExtension()\
            .create([TLS_1_3_DRAFT, (3, 3)])
        ext[ExtensionType.supported_groups] = SupportedGroupsExtension()\
            .create(groups)
        sig_algs = [
            SignatureScheme.rsa_pss_rsae_sha256,
            SignatureScheme.rsa_pss_pss_sha256
        ]
        ext[ExtensionType.signature_algorithms] = \
            SignatureAlgorithmsExtension().create(sig_algs)
        ext[ExtensionType.signature_algorithms_cert] = \
            SignatureAlgorithmsCertExtension().create(RSA_SIG_ALL)
        return ext

    def successful_conversation(ext):
        """
        Build a conversation in which the server completes the handshake.

        After the handshake an HTTP GET is sent, any number of
        NewSessionTicket messages is accepted, and the connection is closed
        with a close_notify exchange.

        :param dict ext: ClientHello extensions
        :return: the root Connect node
        """
        conversation = Connect(host, port)
        node = conversation.add_child(client_hello(ext))
        node = node.add_child(ExpectServerHello())
        node = node.add_child(ExpectChangeCipherSpec())
        node = node.add_child(ExpectEncryptedExtensions())
        node = node.add_child(ExpectCertificate())
        node = node.add_child(ExpectCertificateVerify())
        node = node.add_child(ExpectFinished())
        node = node.add_child(FinishedGenerator())
        node = node.add_child(
            ApplicationDataGenerator(bytearray(b"GET / HTTP/1.0\r\n\r\n")))

        # This message is optional and may show up 0 to many times
        cycle = ExpectNewSessionTicket()
        node = node.add_child(cycle)
        node.add_child(cycle)

        node.next_sibling = ExpectApplicationData()
        node = node.next_sibling.add_child(
            AlertGenerator(AlertLevel.warning, AlertDescription.close_notify))

        node = node.add_child(ExpectAlert())
        node.next_sibling = ExpectClose()
        return conversation

    def rejected_share_conversation(group, share):
        """
        Build a conversation sending share as the only key share for group.

        The server is expected to abort with a fatal illegal_parameter alert.

        :param int group: TLS group ID advertised and used for the share
        :param bytearray share: raw key_exchange value to send
        :return: the root Connect node
        """
        conversation = Connect(host, port)
        ext = add_tls13_extensions({}, [group],
                                   [KeyShareEntry().create(group, share)])
        node = conversation.add_child(client_hello(ext))
        node = node.add_child(
            ExpectAlert(AlertLevel.fatal, AlertDescription.illegal_parameter))
        node.add_child(ExpectClose())
        return conversation

    conversations = {}

    groups = [GroupName.secp256r1]
    ext = add_tls13_extensions({}, groups,
                               [key_share_gen(group) for group in groups])
    conversations["sanity"] = successful_conversation(ext)

    test_groups = {
        GroupName.x25519: X25519_ORDER_SIZE,
        GroupName.x448: X448_ORDER_SIZE,
    }

    for test_group, group_size in test_groups.items():
        group_name = GroupName.toRepr(test_group)

        # check if server will negotiate x25519/x448 - sanity check
        for compression_format in [
                ECPointFormat.ansiX962_compressed_prime,
                ECPointFormat.ansiX962_compressed_char2,
                ECPointFormat.uncompressed
        ]:
            ext = {}
            ext[ExtensionType.ec_point_formats] = ECPointFormatsExtension()\
                .create([compression_format])
            add_tls13_extensions(ext, [test_group],
                                 [key_share_gen(test_group)])
            conversations["sanity {0} with compression {1}".format(
                group_name,
                ECPointFormat.toRepr(compression_format))] = \
                successful_conversation(ext)

        # check if server will reject malformed x25519/x448 key shares
        bad_shares = [
            # all-zero key share (it should result in all-zero shared secret)
            ("all zero {0} key share", bytearray(group_size)),
            # key share of 1 (it should result in all-zero shared secret)
            ("{0} key share of \"1\"",
             numberToByteArray(1, group_size, "little")),
            # one byte too few in the key share
            ("too small {0} key share", bytearray([55] * (group_size - 1))),
            # no key exchange data at all
            ("empty {0} key share", bytearray()),
            # one byte too many in the key share
            ("too big {0} key share", bytearray([55] * (group_size + 1))),
        ]
        for name_fmt, share in bad_shares:
            conversations[name_fmt.format(group_name)] = \
                rejected_share_conversation(test_group, share)

    # run the conversation
    good = 0
    bad = 0
    xfail = 0
    xpass = 0
    failed = []
    xpassed = []
    if not num_limit:
        num_limit = len(conversations)

    # make sure that sanity test is run first and last
    # to verify that server was running and kept running throughout
    sanity_tests = [('sanity', conversations['sanity'])]
    # filter before sampling so that -n picks only tests that will run
    regular_tests = [(k, v) for k, v in conversations.items()
                     if k != 'sanity' and k not in run_exclude and
                     (not run_only or k in run_only)]
    sampled_tests = sample(regular_tests, min(num_limit, len(regular_tests)))
    ordered_tests = chain(sanity_tests, sampled_tests, sanity_tests)

    for c_name, c_test in ordered_tests:
        # still needed for the sanity runs
        if run_only and c_name not in run_only or c_name in run_exclude:
            continue
        print("{0} ...".format(c_name))

        runner = Runner(c_test)

        res = True
        exception = None
        try:
            runner.run()
        except Exception as exp:
            exception = exp
            print("Error while processing")
            print(traceback.format_exc())
            res = False

        if c_name in expected_failures:
            if res:
                xpass += 1
                xpassed.append(c_name)
                print("XPASS: expected failure but test passed\n")
            else:
                if expected_failures[c_name] is not None and  \
                    expected_failures[c_name] not in str(exception):
                    bad += 1
                    failed.append(c_name)
                    print("Expected error message: {0}\n".format(
                        expected_failures[c_name]))
                else:
                    xfail += 1
                    print("OK-expected failure\n")
        else:
            if res:
                good += 1
                print("OK\n")
            else:
                bad += 1
                failed.append(c_name)

    print("Basic test to verify that server selects same ECDHE parameters")
    print("and ciphersuites when x25519 or x448 curve is an option\n")
    print("version: {0}\n".format(version))

    print("Test end")
    print(20 * '=')
    print("TOTAL: {0}".format(len(sampled_tests) + 2 * len(sanity_tests)))
    print("SKIP: {0}".format(
        len(run_exclude.intersection(conversations.keys()))))
    print("PASS: {0}".format(good))
    print("XFAIL: {0}".format(xfail))
    print("FAIL: {0}".format(bad))
    print("XPASS: {0}".format(xpass))
    print(20 * '=')
    sort = sorted(xpassed, key=natural_sort_keys)
    if len(sort):
        print("XPASSED:\n\t{0}".format('\n\t'.join(repr(i) for i in sort)))
    sort = sorted(failed, key=natural_sort_keys)
    if len(sort):
        print("FAILED:\n\t{0}".format('\n\t'.join(repr(i) for i in sort)))

    if bad > 0:
        sys.exit(1)
def main():
    """
    Check that a TLS 1.3 server will not use or accept obsolete groups.

    Every group listed as obsolete in RFC 8446 Appendix B.3.1.4 is
    advertised in supported_groups and/or key_share, alone or together with
    secp256r1, and each such conversation is built with negative_test().

    Command line options are read from sys.argv:
      -h HOST   server hostname (default "localhost")
      -p PORT   server port (default 4433)
      -e NAME   exclude the named conversation
      -x NAME   the named conversation is expected to fail
      -X MSG    the failure of the preceding -x conversation must
                contain MSG
      -n NUM    run at most NUM conversations besides "sanity"
      -a ALERT  alert expected for obsolete groups, as a number or an
                AlertDescription attribute name (default illegal_parameter)
      --cookie  expect a cookie extension in HelloRetryRequest and send it
                back with ch_cookie_handler
      --help    print help_msg() and exit
    Positional arguments restrict the run to the named conversations.

    Exits with status 1 if any conversation failed unexpectedly.

    :raises ValueError: when -X is given before any -x
    """
    host = "localhost"
    port = 4433
    num_limit = None
    run_exclude = set()
    expected_failures = {}
    last_exp_tmp = None
    alert_desc = AlertDescription.illegal_parameter
    cookie = False

    argv = sys.argv[1:]
    opts, args = getopt.getopt(argv, "h:p:e:x:X:n:a:", ["help", "cookie"])
    for opt, arg in opts:
        if opt == '-h':
            host = arg
        elif opt == '-p':
            port = int(arg)
        elif opt == '-e':
            run_exclude.add(arg)
        elif opt == '-x':
            expected_failures[arg] = None
            last_exp_tmp = str(arg)
        elif opt == '-X':
            if not last_exp_tmp:
                raise ValueError("-x has to be specified before -X")
            expected_failures[last_exp_tmp] = str(arg)
        elif opt == '-n':
            num_limit = int(arg)
        elif opt == "-a":
            try:
                alert_desc = int(arg)
            except ValueError:
                alert_desc = getattr(AlertDescription, arg)
        elif opt == "--cookie":
            cookie = True
        elif opt == '--help':
            help_msg()
            sys.exit(0)
        else:
            raise ValueError("Unknown option: {0}".format(opt))

    if args:
        run_only = set(args)
    else:
        run_only = None

    def client_hello(ext):
        """Return a ClientHelloGenerator offering AES-128-GCM and the SCSV."""
        ciphers = [
            CipherSuite.TLS_AES_128_GCM_SHA256,
            CipherSuite.TLS_EMPTY_RENEGOTIATION_INFO_SCSV
        ]
        return ClientHelloGenerator(ciphers, extensions=ext)

    def add_tls13_extensions(ext, key_shares):
        """
        Add the TLS 1.3 extensions of the sanity ClientHellos to ext.

        supported_groups advertises only secp256r1.

        :param dict ext: extensions dictionary, updated in place on top of
            any entries it already holds
        :param list key_shares: KeyShareEntry objects sent in key_share
        :return: ext
        """
        groups = [GroupName.secp256r1]
        ext[ExtensionType.key_share] = ClientKeyShareExtension().create(
            key_shares)
        ext[ExtensionType.supported_versions] = SupportedVersionsExtension()\
            .create([TLS_1_3_DRAFT, (3, 3)])
        ext[ExtensionType.supported_groups] = SupportedGroupsExtension()\
            .create(groups)
        sig_algs = [
            SignatureScheme.rsa_pss_rsae_sha256,
            SignatureScheme.rsa_pss_pss_sha256,
            SignatureScheme.ecdsa_secp256r1_sha256
        ]
        ext[ExtensionType.signature_algorithms] = \
            SignatureAlgorithmsExtension().create(sig_algs)
        ext[ExtensionType.signature_algorithms_cert] = \
            SignatureAlgorithmsCertExtension().create(SIG_ALL)
        return ext

    def finish_handshake(node):
        """
        Append the rest of a successful handshake to node.

        Expects EncryptedExtensions through Finished, sends Finished and an
        HTTP GET, accepts any number of NewSessionTicket messages, and closes
        with a close_notify exchange.
        """
        node = node.add_child(ExpectEncryptedExtensions())
        node = node.add_child(ExpectCertificate())
        node = node.add_child(ExpectCertificateVerify())
        node = node.add_child(ExpectFinished())
        node = node.add_child(FinishedGenerator())
        node = node.add_child(
            ApplicationDataGenerator(bytearray(b"GET / HTTP/1.0\r\n\r\n")))

        # This message is optional and may show up 0 to many times
        cycle = ExpectNewSessionTicket()
        node = node.add_child(cycle)
        node.add_child(cycle)

        node.next_sibling = ExpectApplicationData()
        node = node.next_sibling.add_child(
            AlertGenerator(AlertLevel.warning, AlertDescription.close_notify))

        node = node.add_child(ExpectAlert())
        node.next_sibling = ExpectClose()

    conversations = {}

    # sanity: plain secp256r1 handshake
    conversation = Connect(host, port)
    ext = add_tls13_extensions({}, [key_share_gen(GroupName.secp256r1)])
    node = conversation.add_child(client_hello(ext))
    node = node.add_child(ExpectServerHello())
    node = node.add_child(ExpectChangeCipherSpec())
    finish_handshake(node)
    conversations["sanity"] = conversation

    # verify that server supports HRR: the first ClientHello carries an
    # empty key_share, so a HelloRetryRequest is expected in reply
    conversation = Connect(host, port)
    ext = add_tls13_extensions(OrderedDict(), [])
    node = conversation.add_child(client_hello(ext))

    ext = OrderedDict()
    if cookie:
        ext[ExtensionType.cookie] = None
    ext[ExtensionType.key_share] = None
    ext[ExtensionType.supported_versions] = None
    node = node.add_child(ExpectHelloRetryRequest(extensions=ext))

    ext = OrderedDict()
    if cookie:
        ext[ExtensionType.cookie] = ch_cookie_handler
    add_tls13_extensions(ext, [key_share_gen(GroupName.secp256r1)])
    node = node.add_child(client_hello(ext))
    # the compatibility CCS follows the server's first handshake message
    # (the HRR here, see RFC 8446 D.4), so it is read before ServerHello
    node = node.add_child(ExpectChangeCipherSpec())
    node = node.add_child(ExpectServerHello())
    finish_handshake(node)
    conversations["sanity - HRR support"] = conversation

    def key_share_or_bogus(group):
        """
        Return a KeyShareEntry for group.

        When key_share_gen() raises ValueError for the group (presumably
        because no key exchange is implemented for it), a share of 32 0xab
        bytes is used instead.
        """
        try:
            return key_share_gen(group)
        except ValueError:
            # bogus value to move on, if it makes problems, these won't
            # result in handshake_failure
            return KeyShareEntry().create(group, bytearray(b'\xab' * 32))

    def add_negative(name, groups, key_share_ext, alert):
        """
        Register a negative_test() conversation under name.

        The ClientHello carries key_share_ext and a supported_groups
        extension listing groups; alert is passed to negative_test().
        """
        ext = {}
        ext[ExtensionType.key_share] = key_share_ext
        ext[ExtensionType.supported_groups] = SupportedGroupsExtension()\
            .create(groups)
        conversations[name] = negative_test(host, port, ext, alert)

    secp256r1 = GroupName.secp256r1

    # https://tools.ietf.org/html/rfc8446#appendix-B.3.1.4
    obsolete_groups = chain(range(0x0001, 0x0016 + 1),
                            range(0x001A, 0x001C + 1),
                            range(0xFF01, 0xFF02 + 1))

    for obsolete_group in obsolete_groups:
        obsolete_group_name = (GroupName.toRepr(obsolete_group)
                               or "unknown ({0})".format(obsolete_group))

        # obsolete group both advertised and shared
        add_negative(
            "{0} in supported_groups and key_share"
            .format(obsolete_group_name),
            [obsolete_group],
            ClientKeyShareExtension().create(
                [key_share_or_bogus(obsolete_group)]),
            alert_desc)

        # check if it's rejected when it's just the one advertised,
        # but not shared
        add_negative(
            "{0} in supported_groups and empty key_share"
            .format(obsolete_group_name),
            [obsolete_group],
            ClientKeyShareExtension().create([]),
            alert_desc)

        # check invalid group advertised together with valid in key share
        add_negative(
            "{0} and secp256r1 in supported_groups and "
            "secp256r1 in key_share".format(obsolete_group_name),
            [obsolete_group, secp256r1],
            key_share_ext_gen([secp256r1]),
            alert_desc)

        # check with both valid and invalid in key share and supported groups
        groups = [obsolete_group, secp256r1]
        add_negative(
            "{0} and secp256r1 in supported_groups and key_share"
            .format(obsolete_group_name),
            groups,
            ClientKeyShareExtension().create(
                [key_share_or_bogus(group) for group in groups]),
            alert_desc)

        # also check with inverted order
        groups = [secp256r1, obsolete_group]
        add_negative(
            "secp256r1 and {0} in supported_groups and key_share"
            .format(obsolete_group_name),
            groups,
            ClientKeyShareExtension().create(
                [key_share_or_bogus(group) for group in groups]),
            alert_desc)

        # check inconsistent key_share with supported_groups; this is
        # always an illegal_parameter case, independent of -a
        add_negative(
            "{0} in key_share and secp256r1 in "
            "supported_groups (inconsistent extensions)"
            .format(obsolete_group_name),
            [secp256r1],
            ClientKeyShareExtension().create(
                [key_share_or_bogus(obsolete_group)]),
            AlertDescription.illegal_parameter)

    # run the conversation
    good = 0
    bad = 0
    xfail = 0
    xpass = 0
    failed = []
    xpassed = []
    if not num_limit:
        num_limit = len(conversations)

    # make sure that sanity test is run first and last
    # to verify that server was running and kept running throughout
    sanity_tests = [('sanity', conversations['sanity'])]
    # filter before sampling so that -n picks only tests that will run;
    # 'sanity' is always excluded here as it is already run twice
    regular_tests = [(k, v) for k, v in conversations.items()
                     if k != 'sanity' and k not in run_exclude and
                     (not run_only or k in run_only)]
    sampled_tests = sample(regular_tests, min(num_limit, len(regular_tests)))
    ordered_tests = chain(sanity_tests, sampled_tests, sanity_tests)

    for c_name, c_test in ordered_tests:
        # still needed for the sanity runs
        if run_only and c_name not in run_only or c_name in run_exclude:
            continue
        print("{0} ...".format(c_name))

        runner = Runner(c_test)

        res = True
        exception = None
        try:
            runner.run()
        except Exception as exp:
            exception = exp
            print("Error while processing")
            print(traceback.format_exc())
            res = False

        if c_name in expected_failures:
            if res:
                xpass += 1
                xpassed.append(c_name)
                print("XPASS-expected failure but test passed\n")
            else:
                if expected_failures[c_name] is not None and  \
                    expected_failures[c_name] not in str(exception):
                    bad += 1
                    failed.append(c_name)
                    print("Expected error message: {0}\n".format(
                        expected_failures[c_name]))
                else:
                    xfail += 1
                    print("OK-expected failure\n")
        else:
            if res:
                good += 1
                print("OK\n")
            else:
                bad += 1
                failed.append(c_name)

    print("Negotiating obsolete curves with TLS 1.3 server")
    print("Check that TLS 1.3 server will not use or accept obsolete curves")
    print("in TLS 1.3.")
    print("Reproduces https://github.com/openssl/openssl/issues/8369\n")

    print("Test end")
    print(20 * '=')
    print("version: {0}".format(version))
    print(20 * '=')
    print("TOTAL: {0}".format(len(sampled_tests) + 2 * len(sanity_tests)))
    print("SKIP: {0}".format(
        len(run_exclude.intersection(conversations.keys()))))
    print("PASS: {0}".format(good))
    print("XFAIL: {0}".format(xfail))
    print("FAIL: {0}".format(bad))
    print("XPASS: {0}".format(xpass))
    print(20 * '=')
    sort = sorted(xpassed, key=natural_sort_keys)
    if len(sort):
        print("XPASSED:\n\t{0}".format('\n\t'.join(repr(i) for i in sort)))
    sort = sorted(failed, key=natural_sort_keys)
    if len(sort):
        print("FAILED:\n\t{0}".format('\n\t'.join(repr(i) for i in sort)))

    if bad > 0:
        sys.exit(1)
def _tls13_ciphers():
    """Return a fresh cipher list: one TLS 1.3 suite plus the reneg SCSV."""
    return [CipherSuite.TLS_AES_128_GCM_SHA256,
            CipherSuite.TLS_EMPTY_RENEGOTIATION_INFO_SCSV]


def _tls13_ch_extensions(groups, key_shares, cookie_handler=None):
    """
    Build the ClientHello extensions shared by every conversation here.

    :param list groups: group IDs advertised in supported_groups
    :param list key_shares: KeyShareEntry objects for the key_share extension
    :param cookie_handler: when not None, it is placed as the cookie
        extension, between supported_groups and signature_algorithms
    :return: OrderedDict mapping ExtensionType to extension object/handler
    """
    ext = OrderedDict()
    ext[ExtensionType.key_share] = ClientKeyShareExtension()\
        .create(key_shares)
    ext[ExtensionType.supported_versions] = SupportedVersionsExtension()\
        .create([TLS_1_3_DRAFT, (3, 3)])
    ext[ExtensionType.supported_groups] = SupportedGroupsExtension()\
        .create(groups)
    if cookie_handler is not None:
        ext[ExtensionType.cookie] = cookie_handler
    sig_algs = [SignatureScheme.rsa_pss_rsae_sha256,
                SignatureScheme.rsa_pss_pss_sha256]
    ext[ExtensionType.signature_algorithms] = SignatureAlgorithmsExtension()\
        .create(sig_algs)
    ext[ExtensionType.signature_algorithms_cert] = \
        SignatureAlgorithmsCertExtension().create(RSA_SIG_ALL)
    return ext


def _dummy_key_shares(group_ids, size):
    """Return one KeyShareEntry per group ID, each holding size 0xab bytes."""
    return [KeyShareEntry().create(group_id, bytearray(b'\xab' * size))
            for group_id in group_ids]


def _add_full_handshake(node, expect_ccs=True):
    """
    Append the rest of a successful TLS 1.3 exchange after a ClientHello.

    Expects ServerHello, an optional ChangeCipherSpec, EncryptedExtensions,
    Certificate, CertificateVerify and Finished; replies with Finished and an
    HTTP GET; tolerates any number of NewSessionTicket messages; then expects
    application data, sends close_notify and expects the alert and close.

    :param node: the ClientHelloGenerator node to extend
    :param bool expect_ccs: whether a ChangeCipherSpec follows ServerHello;
        False after a HelloRetryRequest, where the caller already added the
        ChangeCipherSpec expectation right after the HRR
    """
    node = node.add_child(ExpectServerHello())
    if expect_ccs:
        node = node.add_child(ExpectChangeCipherSpec())
    node = node.add_child(ExpectEncryptedExtensions())
    node = node.add_child(ExpectCertificate())
    node = node.add_child(ExpectCertificateVerify())
    node = node.add_child(ExpectFinished())
    node = node.add_child(FinishedGenerator())
    node = node.add_child(ApplicationDataGenerator(
        bytearray(b"GET / HTTP/1.0\r\n\r\n")))

    # This message is optional and may show up 0 to many times
    cycle = ExpectNewSessionTicket()
    node = node.add_child(cycle)
    node.add_child(cycle)

    node.next_sibling = ExpectApplicationData()
    node = node.next_sibling.add_child(AlertGenerator(AlertLevel.warning,
                                       AlertDescription.close_notify))

    node = node.add_child(ExpectAlert())
    node.next_sibling = ExpectClose()


def _sanity_conversation(host, port):
    """Return a plain secp256r1 TLS 1.3 handshake used as a liveness check."""
    conversation = Connect(host, port)
    groups = [GroupName.secp256r1]
    key_shares = [key_share_gen(group) for group in groups]
    ext = _tls13_ch_extensions(groups, key_shares)
    node = conversation.add_child(ClientHelloGenerator(_tls13_ciphers(),
                                                       extensions=ext))
    _add_full_handshake(node)
    return conversation


def _hrr_conversation(host, port, known_group, unknown_group, size, cookie):
    """
    Return a conversation offering key shares only for unknown groups.

    supported_groups lists known_group first followed by unknown_group, so a
    HelloRetryRequest is expected; the second ClientHello then carries a real
    key share for known_group (and echoes the cookie when cookie is True).
    """
    conversation = Connect(host, port)
    ciphers = _tls13_ciphers()

    groups = [known_group] + unknown_group
    ext = _tls13_ch_extensions(groups, _dummy_key_shares(unknown_group, size))
    node = conversation.add_child(ClientHelloGenerator(ciphers,
                                                       extensions=ext))

    ext = OrderedDict()
    if cookie:
        ext[ExtensionType.cookie] = None
    ext[ExtensionType.key_share] = None
    ext[ExtensionType.supported_versions] = None
    node = node.add_child(ExpectHelloRetryRequest(extensions=ext))
    node = node.add_child(ExpectChangeCipherSpec())

    groups = [known_group] + unknown_group
    key_shares = [key_share_gen(known_group)]
    ext = _tls13_ch_extensions(groups, key_shares,
                               ch_cookie_handler if cookie else None)
    node = node.add_child(ClientHelloGenerator(ciphers, extensions=ext))
    _add_full_handshake(node, expect_ccs=False)
    return conversation


def _known_group_last_conversation(host, port, known_group, unknown_group,
                                   size):
    """
    Return a conversation listing unknown groups before one known group.

    Key shares are sent for all groups (filler bytes for unknown ones, a real
    share for known_group), so a full handshake without HRR is expected.
    """
    conversation = Connect(host, port)
    groups = unknown_group + [known_group]
    key_shares = _dummy_key_shares(unknown_group, size)
    key_shares.append(key_share_gen(known_group))
    ext = _tls13_ch_extensions(groups, key_shares)
    node = conversation.add_child(ClientHelloGenerator(_tls13_ciphers(),
                                                       extensions=ext))
    _add_full_handshake(node)
    return conversation


def _only_unknown_conversation(host, port, unknown_group, size):
    """
    Return a conversation advertising only unknown groups.

    A fatal handshake_failure alert followed by connection close is expected.
    """
    conversation = Connect(host, port)
    ext = _tls13_ch_extensions(unknown_group,
                               _dummy_key_shares(unknown_group, size))
    node = conversation.add_child(ClientHelloGenerator(_tls13_ciphers(),
                                                       extensions=ext))
    node = node.add_child(ExpectAlert(AlertLevel.fatal,
                                      AlertDescription.handshake_failure))
    node.add_child(ExpectClose())
    return conversation


def main():
    """
    Run the "unrecognised groups in TLS 1.3" conversations.

    Options: -h HOST, -p PORT, -e NAME (exclude a conversation, repeatable),
    -n NUM (limit how many non-sanity conversations run), --cookie (expect a
    cookie in HelloRetryRequest and echo it), --help. Positional arguments
    select conversations by name. Exits with status 1 if any failed.
    """
    host = "localhost"
    port = 4433
    num_limit = None
    run_exclude = set()
    cookie = False

    argv = sys.argv[1:]
    opts, args = getopt.getopt(argv, "h:p:e:n:", ["help", "cookie"])
    for opt, arg in opts:
        if opt == '-h':
            host = arg
        elif opt == '-p':
            port = int(arg)
        elif opt == '-e':
            run_exclude.add(arg)
        elif opt == '-n':
            num_limit = int(arg)
        elif opt == '--cookie':
            cookie = True
        elif opt == '--help':
            help_msg()
            sys.exit(0)
        else:
            raise ValueError("Unknown option: {0}".format(opt))

    if args:
        run_only = set(args)
    else:
        run_only = None

    conversations = {}
    conversations["sanity"] = _sanity_conversation(host, port)

    unknown_groups = {
        'EC': list(range(34, 255)),  # Unassigned groups from EC range
        'FFDHE': list(range(261, 507)),  # Unassigned groups from FFDHE range
    }
    known_groups = [GroupName.secp256r1, GroupName.ffdhe2048]

    for group_name, unknown_group in unknown_groups.items():
        for size in [64, 128, 256]:
            for known_group in known_groups:
                # Unknown key_shares, one known group and range of unknown
                # groups in supported_groups
                name = ("only unknown key_share from {0} range, key_share of "
                        "size {1} + {2} in supported_groups").format(
                            group_name, size, GroupName.toRepr(known_group))
                conversations[name] = _hrr_conversation(
                    host, port, known_group, unknown_group, size, cookie)

                # One known group and list of unknown groups, unknown ones
                # are listed first
                name = ("known group {0} and unknown groups from {1} range, "
                        "key_share of size {2}").format(
                            GroupName.toRepr(known_group), group_name, size)
                conversations[name] = _known_group_last_conversation(
                    host, port, known_group, unknown_group, size)

            # Unknown supported_groups and key_shares
            name = ("only unknown supported_groups from {0} range, "
                    "key_share of size {1}").format(group_name, size)
            conversations[name] = _only_unknown_conversation(
                host, port, unknown_group, size)

    # run the conversation
    good = 0
    bad = 0
    failed = []
    if not num_limit:
        num_limit = len(conversations)

    # make sure that sanity test is run first and last
    # to verify that server was running and kept running throughout
    sanity_test = ('sanity', conversations['sanity'])
    ordered_tests = chain([sanity_test],
                          islice(filter(lambda x: x[0] != 'sanity',
                                        conversations.items()), num_limit),
                          [sanity_test])

    for c_name, c_test in ordered_tests:
        if (run_only and c_name not in run_only) or c_name in run_exclude:
            continue
        print("{0} ...".format(c_name))

        runner = Runner(c_test)

        res = True
        try:
            runner.run()
        except Exception:
            # top-level boundary: report and continue with the next test
            print("Error while processing")
            print(traceback.format_exc())
            res = False

        if res:
            good += 1
            print("OK\n")
        else:
            bad += 1
            failed.append(c_name)

    print("Unrecognised groups in TLS 1.3")
    print("Check that server replies with HRR, aborts the connection")
    print("with handshake_failure or chooses a known group from client list.")
    print("Groups with IDs from FFDHE and ECDH range.\n")
    print("version: {0}\n".format(version))

    print("Test end")
    print("successful: {0}".format(good))
    print("failed: {0}".format(bad))
    failed_sorted = sorted(failed, key=natural_sort_keys)
    # only list failures when there are some (avoids a stray blank line)
    if failed_sorted:
        print("  {0}".format('\n  '.join(repr(i) for i in failed_sorted)))

    if bad > 0:
        sys.exit(1)