Ejemplo n.º 1
0
    def send_with_dtlslib(self, address):
        """Send one greeting over DTLS 1.2 and print the peer's reply.

        Patches ``ssl`` with PyDTLS (``dtls.do_patch``), wraps a UDP socket in
        a client-side ``DtlsSocket`` that uses ``certs/ca-cert_ec.pem`` (next
        to this file) as ``ca_certs``, sends ``'Hi there'`` encoded to bytes,
        prints the decoded response and shuts the connection down.

        :param address: address passed to ``connect``, e.g. ``(host, port)``.
        """
        import os
        import socket
        import ssl
        from logging import basicConfig, DEBUG
        basicConfig(level=DEBUG)  # set now for dtls import code
        from dtls import do_patch
        from dtls.wrapper import DtlsSocket

        do_patch()
        issuer_certfile_ec = os.path.join(
            os.path.dirname(__file__) or os.curdir, "certs", "ca-cert_ec.pem")
        sock = DtlsSocket(
            socket.socket(socket.AF_INET, socket.SOCK_DGRAM),
            keyfile=None,
            certfile=None,
            # NOTE(review): cert_reqs is not passed, so the library default
            # applies — confirm whether ssl.CERT_REQUIRED was intended.
            ssl_version=ssl.PROTOCOL_DTLSv1_2,
            ca_certs=issuer_certfile_ec,
            ciphers='ECDHE:EECDH',
            curves='prime256v1',
            sigalgs=None,
            user_mtu=None)
        try:
            sock.connect(address)
            sock.send('Hi there'.encode())
            print(sock.recv().decode())
            # Tear the DTLS layer down; the object unwrap() returns is the
            # one that still needs closing.
            sock = sock.unwrap()
        finally:
            # Close whichever socket we hold, even if the exchange failed.
            sock.close()
Ejemplo n.º 2
0
    def __init__(self,
                 certificate,
                 ssl_version=None,
                 certreqs=None,
                 cacerts=None,
                 ciphers=None,
                 curves=None,
                 sigalgs=None,
                 mtu=None,
                 server_key_exchange_curve=None,
                 server_cert_options=None,
                 chatty=True):
        """Wrap a UDP socket in a server-side DtlsSocket bound to HOST.

        :param certificate: PEM file passed to DtlsSocket as both ``keyfile``
            and ``certfile``.
        :param ssl_version: protocol constant; ``None`` means
            ``ssl.PROTOCOL_DTLSv1``.
        :param certreqs: client-certificate requirement; ``None`` means
            ``ssl.CERT_NONE``.
        :param cacerts: CA file passed as ``ca_certs``.
        :param ciphers: passed through to DtlsSocket.
        :param curves: passed through to DtlsSocket.
        :param sigalgs: passed through to DtlsSocket.
        :param mtu: passed to DtlsSocket as ``user_mtu``.
        :param server_key_exchange_curve: passed through to DtlsSocket.
        :param server_cert_options: passed through to DtlsSocket.
        :param chatty: when true, progress messages are written to stdout.
        :raises Exception: whatever ``bind``/``getsockname`` raise; the
            wrapped socket is closed before the error propagates.
        """
        if ssl_version is None:
            ssl_version = ssl.PROTOCOL_DTLSv1
        if certreqs is None:
            certreqs = ssl.CERT_NONE

        self.certificate = certificate
        self.protocol = ssl_version
        self.certreqs = certreqs
        self.cacerts = cacerts
        self.ciphers = ciphers
        self.curves = curves
        self.sigalgs = sigalgs
        self.mtu = mtu
        self.server_key_exchange_curve = server_key_exchange_curve
        self.server_cert_options = server_cert_options
        self.chatty = chatty

        self.flag = None
        # Defined up front so reading ``starter`` before the thread has been
        # started yields None instead of raising AttributeError.
        self.starter = None

        self.sock = DtlsSocket(
            socket.socket(socket.AF_INET, socket.SOCK_DGRAM),
            keyfile=self.certificate,
            certfile=self.certificate,
            server_side=True,
            cert_reqs=self.certreqs,
            ssl_version=self.protocol,
            ca_certs=self.cacerts,
            ciphers=self.ciphers,
            curves=self.curves,
            sigalgs=self.sigalgs,
            user_mtu=self.mtu,
            server_key_exchange_curve=self.server_key_exchange_curve,
            server_cert_options=self.server_cert_options)

        if self.chatty:
            sys.stdout.write(' server:  wrapped server socket as %s\n' %
                             str(self.sock))
        try:
            # Port 0 lets the OS choose a free port; remember which one.
            self.sock.bind((HOST, 0))
            self.port = self.sock.getsockname()[1]
        except Exception:
            # Don't leak the wrapped socket when binding fails.
            self.sock.close()
            raise
        self.active = False
        threading.Thread.__init__(self)
        self.daemon = True
Ejemplo n.º 3
0
    def __init__(self, certificate, ssl_version=None, certreqs=None, cacerts=None,
                 ciphers=None, curves=None, sigalgs=None,
                 mtu=None, server_key_exchange_curve=None, server_cert_options=None,
                 chatty=True):
        """Wrap a UDP socket in a server-side DtlsSocket bound to HOST.

        :param certificate: PEM file used as both ``keyfile`` and ``certfile``.
        :param ssl_version: protocol constant; ``None`` means ``ssl.PROTOCOL_DTLSv1``.
        :param certreqs: client-certificate requirement; ``None`` means ``ssl.CERT_NONE``.
        :param cacerts: CA file passed as ``ca_certs``.
        :param ciphers: passed through to DtlsSocket.
        :param curves: passed through to DtlsSocket.
        :param sigalgs: passed through to DtlsSocket.
        :param mtu: passed to DtlsSocket as ``user_mtu``.
        :param server_key_exchange_curve: passed through to DtlsSocket.
        :param server_cert_options: passed through to DtlsSocket.
        :param chatty: when true, progress messages are written to stdout.
        :raises Exception: whatever ``bind``/``getsockname`` raise; the wrapped
            socket is closed before the error propagates.
        """
        if ssl_version is None:
            ssl_version = ssl.PROTOCOL_DTLSv1
        if certreqs is None:
            certreqs = ssl.CERT_NONE

        self.certificate = certificate
        self.protocol = ssl_version
        self.certreqs = certreqs
        self.cacerts = cacerts
        self.ciphers = ciphers
        self.curves = curves
        self.sigalgs = sigalgs
        self.mtu = mtu
        self.server_key_exchange_curve = server_key_exchange_curve
        self.server_cert_options = server_cert_options
        self.chatty = chatty

        self.flag = None
        # Defined up front so reading ``starter`` before the thread has been
        # started yields None instead of raising AttributeError.
        self.starter = None

        self.sock = DtlsSocket(socket.socket(socket.AF_INET, socket.SOCK_DGRAM),
                               keyfile=self.certificate,
                               certfile=self.certificate,
                               server_side=True,
                               cert_reqs=self.certreqs,
                               ssl_version=self.protocol,
                               ca_certs=self.cacerts,
                               ciphers=self.ciphers,
                               curves=self.curves,
                               sigalgs=self.sigalgs,
                               user_mtu=self.mtu,
                               server_key_exchange_curve=self.server_key_exchange_curve,
                               server_cert_options=self.server_cert_options)

        if self.chatty:
            sys.stdout.write(' server:  wrapped server socket as %s\n' % str(self.sock))
        try:
            # Port 0 lets the OS choose a free port; remember which one.
            self.sock.bind((HOST, 0))
            self.port = self.sock.getsockname()[1]
        except Exception:
            # Don't leak the wrapped socket when binding fails.
            self.sock.close()
            raise
        self.active = False
        threading.Thread.__init__(self)
        self.daemon = True
Ejemplo n.º 4
0
    def test_set_ecdh_curve(self):
        """Exercise ECDH curve selection between server and client.

        Each subcase starts a ThreadedEchoServer with ``server_curve`` as its
        key-exchange curve and connects a verifying DTLS 1.2 client limited to
        ``client_curve``. When ``result`` is True the echo round trip must
        succeed; when it is False a client-side failure is tolerated.
        """
        steps = {
            # server, client, result
            'all auto': (None, None, True),  # Auto
            'client restricted': (None, "secp256k1:prime256v1",
                                  True),  # client can handle key curve
            'client too restricted':
            (None, "secp256k1", False),  # client _cannot_ handle key curve
            'client minimum':
            (None, "prime256v1", True),  # client can only handle key curve
            'server restricted':
            ("secp384r1", None, True),  # client can handle key curve
            'server one, client two': ("secp384r1", "prime256v1:secp384r1",
                                       True),  # client can handle key curve
            'server one, client one':
            ("secp384r1", "secp384r1",
             False),  # client _cannot_ handle key curve
        }

        chatty, connectionchatty = CHATTY, CHATTY_CLIENT
        indata = 'FOO'

        if chatty or connectionchatty:
            sys.stdout.write("\nTestcase: test_ecdh_curve\n")
        # dict.items() works on Python 2 and 3; iteritems() is Python 2 only.
        for step, (server_curve, client_curve, result) in steps.items():
            if chatty or connectionchatty:
                sys.stdout.write("\n Subcase: %s\n" % step)
            server = ThreadedEchoServer(certificate=CERTFILE_EC,
                                        ssl_version=ssl.PROTOCOL_DTLSv1_2,
                                        certreqs=ssl.CERT_NONE,
                                        cacerts=ISSUER_CERTFILE_EC,
                                        ciphers=None,
                                        curves=None,
                                        sigalgs=None,
                                        mtu=None,
                                        server_key_exchange_curve=server_curve,
                                        server_cert_options=None,
                                        chatty=chatty)
            flag = threading.Event()
            server.start(flag)
            # wait for it to start
            flag.wait()
            s = None
            try:
                s = DtlsSocket(socket.socket(socket.AF_INET,
                                             socket.SOCK_DGRAM),
                               keyfile=None,
                               certfile=None,
                               cert_reqs=ssl.CERT_REQUIRED,
                               ssl_version=ssl.PROTOCOL_DTLSv1_2,
                               ca_certs=ISSUER_CERTFILE_EC,
                               ciphers=None,
                               curves=client_curve,
                               sigalgs=None,
                               user_mtu=None)
                s.connect((HOST, server.port))
                if connectionchatty:
                    sys.stdout.write(" client:  sending %s...\n" %
                                     (repr(indata)))
                s.write(indata)
                outdata = s.read()
                if connectionchatty:
                    sys.stdout.write(" client:  read %s\n" % repr(outdata))
                if outdata != indata.lower():
                    raise AssertionError(
                        "bad data <<%s>> (%d) received; expected <<%s>> (%d)\n"
                        % (outdata[:min(len(outdata), 20)], len(outdata),
                           indata[:min(len(indata), 20)].lower(), len(indata)))
                if connectionchatty:
                    sys.stdout.write(" client:  closing connection.\n")
                try:
                    s.close()
                except Exception as e:
                    if connectionchatty:
                        sys.stdout.write(
                            " client:  error closing connection %s...\n" %
                            (repr(e)))
                s = None  # close already attempted; skip cleanup below
                # NOTE(review): a False ``result`` only tolerates failure; an
                # unexpectedly successful connection is not reported.
            except Exception as e:
                if connectionchatty:
                    sys.stdout.write(
                        " client:  aborting with exception %s...\n" %
                        (repr(e)))
                if result:
                    raise
            finally:
                if s is not None:
                    # Best-effort release of a client socket left open by an
                    # early failure; the original error is what matters.
                    try:
                        s.close()
                    except Exception:
                        pass
                server.stop()
Ejemplo n.º 5
0
    def test_build_cert_chain(self):
        """Compare the server's certificate chain with and without the root.

        For every SSL_CTX_build_cert_chain flag, a ThreadedEchoServer is
        started with that flag as ``server_cert_options``. A verifying DTLS 1.2
        client then does one echo round trip and records
        ``getpeercertchain()``. The chain built with the NO_ROOT flag must
        differ and be exactly one certificate shorter.
        """
        build_flags = [
            ssl.SSL_BUILD_CHAIN_FLAG_NONE, ssl.SSL_BUILD_CHAIN_FLAG_NO_ROOT
        ]
        chatty = CHATTY
        connectionchatty = CHATTY_CLIENT
        payload = 'FOO'
        chains = {}

        if chatty or connectionchatty:
            sys.stdout.write("\nTestcase: test_build_cert_chain\n")
        for build_flag in build_flags:
            echo_server = ThreadedEchoServer(certificate=CERTFILE,
                                             ssl_version=ssl.PROTOCOL_DTLSv1_2,
                                             certreqs=ssl.CERT_NONE,
                                             cacerts=ISSUER_CERTFILE,
                                             ciphers=None,
                                             curves=None,
                                             sigalgs=None,
                                             mtu=None,
                                             server_key_exchange_curve=None,
                                             server_cert_options=build_flag,
                                             chatty=chatty)
            ready = threading.Event()
            echo_server.start(ready)
            ready.wait()  # the server thread signals once it is listening
            try:
                client = DtlsSocket(socket.socket(socket.AF_INET,
                                                  socket.SOCK_DGRAM),
                                    keyfile=None,
                                    certfile=None,
                                    cert_reqs=ssl.CERT_REQUIRED,
                                    ssl_version=ssl.PROTOCOL_DTLSv1_2,
                                    ca_certs=ISSUER_CERTFILE,
                                    ciphers=None,
                                    curves=None,
                                    sigalgs=None,
                                    user_mtu=None)
                client.connect((HOST, echo_server.port))
                if connectionchatty:
                    sys.stdout.write(" client:  sending %s...\n" %
                                     repr(payload))
                client.write(payload)
                reply = client.read()
                if connectionchatty:
                    sys.stdout.write(" client:  read %s\n" % repr(reply))
                if reply != payload.lower():
                    # Slicing clamps to the length, so [:20] is the 20-char prefix.
                    details = (reply[:20], len(reply),
                               payload[:20].lower(), len(payload))
                    raise AssertionError(
                        "bad data <<%s>> (%d) received; expected <<%s>> (%d)\n"
                        % details)
                chains[build_flag] = client.getpeercertchain()
                if connectionchatty:
                    sys.stdout.write(" client:  closing connection.\n")
                try:
                    client.close()
                except Exception as close_error:
                    if connectionchatty:
                        sys.stdout.write(
                            " client:  error closing connection %s...\n" %
                            repr(close_error))
            except Exception as failure:
                if connectionchatty:
                    sys.stdout.write(
                        " client:  aborting with exception %s...\n" %
                        repr(failure))
                raise
            finally:
                echo_server.stop()

        if chatty:
            sys.stdout.write("certs:\n")
            for build_flag in build_flags:
                sys.stdout.write("SSL_CTX_build_cert_chain: %s\n%s\n" %
                                 (build_flag,
                                  pprint.pformat(chains[build_flag])))
        with_root = chains[build_flags[0]]
        without_root = chains[build_flags[1]]
        self.assertNotEqual(with_root, without_root)
        self.assertEqual(len(with_root) - len(without_root), 1)
Ejemplo n.º 6
0
def params_test(start_server,
                certfile,
                protocol,
                certreqs,
                cacertsfile,
                client_certfile=None,
                client_protocol=None,
                client_certreqs=None,
                client_cacertsfile=None,
                ciphers=None,
                curves=None,
                sigalgs=None,
                client_ciphers=None,
                client_curves=None,
                client_sigalgs=None,
                mtu=None,
                server_key_exchange_curve=None,
                server_cert_options=None,
                indata="FOO\n",
                chatty=False,
                connectionchatty=False):
    """
    Launch a server, connect a client to it and try various reads
    and writes.

    The server is a ThreadedEchoServer built from the server-side arguments.
    ``client_protocol``, ``client_ciphers``, ``client_curves`` and
    ``client_sigalgs`` fall back to their server counterparts when None.
    ``client_certreqs`` and ``client_cacertsfile`` have no fallback. When
    ``start_server`` is false the server socket is closed without running,
    presumably so the connection attempt exercises the failure path.

    Returns ``(True, None)`` when the echo round trip succeeds. Otherwise it
    returns ``(False, exc)`` with the exception raised on the client side.
    """
    server = ThreadedEchoServer(
        certfile,
        ssl_version=protocol,
        certreqs=certreqs,
        cacerts=cacertsfile,
        ciphers=ciphers,
        curves=curves,
        sigalgs=sigalgs,
        mtu=mtu,
        server_key_exchange_curve=server_key_exchange_curve,
        server_cert_options=server_cert_options,
        chatty=chatty)
    # should we really run the server?
    if start_server:
        flag = threading.Event()
        server.start(flag)
        # wait for it to start
        flag.wait()
    else:
        server.sock.close()
    # try to connect
    if client_protocol is None:
        client_protocol = protocol
    if client_ciphers is None:
        client_ciphers = ciphers
    if client_curves is None:
        client_curves = curves
    if client_sigalgs is None:
        client_sigalgs = sigalgs
    s = None
    try:
        s = DtlsSocket(socket.socket(socket.AF_INET, socket.SOCK_DGRAM),
                       keyfile=client_certfile,
                       certfile=client_certfile,
                       cert_reqs=client_certreqs,
                       ssl_version=client_protocol,
                       ca_certs=client_cacertsfile,
                       ciphers=client_ciphers,
                       curves=client_curves,
                       sigalgs=client_sigalgs,
                       user_mtu=mtu)
        s.connect((HOST, server.port))
        if connectionchatty:
            sys.stdout.write(" client:  sending %s...\n" % (repr(indata)))
        s.write(indata)
        outdata = s.read()
        if connectionchatty:
            sys.stdout.write(" client:  read %s\n" % repr(outdata))
        if outdata != indata.lower():
            raise AssertionError(
                "bad data <<%s>> (%d) received; expected <<%s>> (%d)\n" %
                (outdata[:min(len(outdata), 20)], len(outdata),
                 indata[:min(len(indata), 20)].lower(), len(indata)))
        # Queried unconditionally (not only for logging), as in the original.
        cert = s.getpeercert()
        cipher = s.cipher()
        if connectionchatty:
            sys.stdout.write("cert:\n" + pprint.pformat(cert) + "\n")
            sys.stdout.write("cipher:\n" + pprint.pformat(cipher) + "\n")
            sys.stdout.write(" client:  closing connection.\n")
        try:
            s.close()
        except Exception as e:
            if connectionchatty:
                sys.stdout.write(" client:  error closing connection %s...\n" %
                                 (repr(e)))
        s = None  # close already attempted; skip cleanup in finally
    except Exception as e:
        if connectionchatty:
            sys.stdout.write(" client:  aborting with exception %s...\n" %
                             (repr(e)))
        return False, e
    finally:
        if s is not None:
            # Best-effort release of a client socket left open by a failure;
            # the error being reported is the one caught above.
            try:
                s.close()
            except Exception:
                pass
        if start_server:
            server.stop()
    return True, None
Ejemplo n.º 7
0
class ThreadedEchoServer(threading.Thread):
    """Daemon thread running a DTLS server that echoes datagrams lower-cased.

    The server socket is created and bound in ``__init__``, so ``port`` is
    available before ``start()`` is called.
    """

    def __init__(self,
                 certificate,
                 ssl_version=None,
                 certreqs=None,
                 cacerts=None,
                 ciphers=None,
                 curves=None,
                 sigalgs=None,
                 mtu=None,
                 server_key_exchange_curve=None,
                 server_cert_options=None,
                 chatty=True):
        """Wrap a UDP socket in a server-side DtlsSocket bound to HOST.

        :param certificate: PEM file passed to DtlsSocket as both ``keyfile``
            and ``certfile``.
        :param ssl_version: protocol constant; ``None`` means
            ``ssl.PROTOCOL_DTLSv1``.
        :param certreqs: client-certificate requirement; ``None`` means
            ``ssl.CERT_NONE``.
        :param cacerts: CA file passed as ``ca_certs``.
        :param ciphers: passed through to DtlsSocket.
        :param curves: passed through to DtlsSocket.
        :param sigalgs: passed through to DtlsSocket.
        :param mtu: passed to DtlsSocket as ``user_mtu``.
        :param server_key_exchange_curve: passed through to DtlsSocket.
        :param server_cert_options: passed through to DtlsSocket.
        :param chatty: when true, progress messages are written to stdout.
        :raises Exception: whatever ``bind``/``getsockname`` raise; the
            wrapped socket is closed before the error propagates.
        """
        if ssl_version is None:
            ssl_version = ssl.PROTOCOL_DTLSv1
        if certreqs is None:
            certreqs = ssl.CERT_NONE

        self.certificate = certificate
        self.protocol = ssl_version
        self.certreqs = certreqs
        self.cacerts = cacerts
        self.ciphers = ciphers
        self.curves = curves
        self.sigalgs = sigalgs
        self.mtu = mtu
        self.server_key_exchange_curve = server_key_exchange_curve
        self.server_cert_options = server_cert_options
        self.chatty = chatty

        self.flag = None
        # Set here so stop() on a server that was never started returns
        # instead of raising AttributeError; start() overwrites it.
        self.starter = None

        self.sock = DtlsSocket(
            socket.socket(socket.AF_INET, socket.SOCK_DGRAM),
            keyfile=self.certificate,
            certfile=self.certificate,
            server_side=True,
            cert_reqs=self.certreqs,
            ssl_version=self.protocol,
            ca_certs=self.cacerts,
            ciphers=self.ciphers,
            curves=self.curves,
            sigalgs=self.sigalgs,
            user_mtu=self.mtu,
            server_key_exchange_curve=self.server_key_exchange_curve,
            server_cert_options=self.server_cert_options)

        if self.chatty:
            sys.stdout.write(' server:  wrapped server socket as %s\n' %
                             str(self.sock))
        try:
            # Port 0 lets the OS choose a free port; remember which one.
            self.sock.bind((HOST, 0))
            self.port = self.sock.getsockname()[1]
        except Exception:
            # Don't leak the wrapped socket when binding fails.
            self.sock.close()
            raise
        self.active = False
        threading.Thread.__init__(self)
        self.daemon = True

    def start(self, flag=None):
        """Start the thread; ``flag`` (an Event) is set once run() listens.

        Records the calling thread so stop() knows who may join.
        """
        self.flag = flag
        self.starter = threading.current_thread().ident
        threading.Thread.start(self)

    def run(self):
        """Echo each received datagram back lower-cased until stopped.

        The socket is closed when the loop exits.
        """
        # Short timeout so the loop re-checks ``self.active`` every 50 ms.
        self.sock.settimeout(0.05)
        self.sock.listen(0)
        self.active = True
        if self.flag:
            # signal an event
            self.flag.set()
        while self.active:
            try:
                acc_ret = self.sock.recvfrom(4096)
                # NOTE(review): recvfrom can apparently return a falsy value
                # (presumably when no application data arrived) — confirm.
                if acc_ret:
                    newdata, connaddr = acc_ret
                    if self.chatty:
                        sys.stdout.write(' server:  new data from ' +
                                         str(connaddr) + '\n')
                    self.sock.sendto(newdata.lower(), connaddr)
            except socket.timeout:
                pass
            except KeyboardInterrupt:
                self.stop()
            except Exception as e:
                # Best-effort server: report and keep serving.
                if self.chatty:
                    sys.stdout.write(' server:  error ' + str(e) + '\n')
        if self.chatty:
            sys.stdout.write(' server:  closing socket as %s\n' %
                             str(self.sock))
        self.sock.close()

    def stop(self):
        """Ask run() to exit.

        Only the thread that called start() waits for the server thread to
        finish. Any other caller, including the server thread itself, just
        clears the flag and returns.
        """
        self.active = False
        if self.starter != threading.current_thread().ident:
            return
        self.join()  # wait for run() to leave its loop and close the socket
Ejemplo n.º 8
0
    def test_set_ecdh_curve(self):
        """Exercise ECDH curve selection between server and client.

        Each subcase starts a ThreadedEchoServer with ``server_curve`` as its key-exchange
        curve and connects a verifying DTLS 1.2 client limited to ``client_curve``.
        When ``result`` is True the echo round trip must succeed; when it is False a
        client-side failure is tolerated.
        """
        steps = {
            # server, client, result
            'all auto':                 (None, None,                            True),      # Auto
            'client restricted':        (None, "secp256k1:prime256v1",          True),      # client can handle key curve
            'client too restricted':    (None, "secp256k1",                     False),     # client _cannot_ handle key curve
            'client minimum':           (None, "prime256v1",                    True),      # client can only handle key curve
            'server restricted':        ("secp384r1", None,                     True),      # client can handle key curve
            'server one, client two':   ("secp384r1", "prime256v1:secp384r1",   True),      # client can handle key curve
            'server one, client one':   ("secp384r1", "secp384r1",              False),     # client _cannot_ handle key curve
        }

        chatty, connectionchatty = CHATTY, CHATTY_CLIENT
        indata = 'FOO'

        if chatty or connectionchatty:
            sys.stdout.write("\nTestcase: test_ecdh_curve\n")
        # dict.items() works on Python 2 and 3; iteritems() is Python 2 only.
        for step, (server_curve, client_curve, result) in steps.items():
            if chatty or connectionchatty:
                sys.stdout.write("\n Subcase: %s\n" % step)
            server = ThreadedEchoServer(certificate=CERTFILE_EC,
                                        ssl_version=ssl.PROTOCOL_DTLSv1_2,
                                        certreqs=ssl.CERT_NONE,
                                        cacerts=ISSUER_CERTFILE_EC,
                                        ciphers=None,
                                        curves=None,
                                        sigalgs=None,
                                        mtu=None,
                                        server_key_exchange_curve=server_curve,
                                        server_cert_options=None,
                                        chatty=chatty)
            flag = threading.Event()
            server.start(flag)
            # wait for it to start
            flag.wait()
            s = None
            try:
                s = DtlsSocket(socket.socket(socket.AF_INET, socket.SOCK_DGRAM),
                               keyfile=None,
                               certfile=None,
                               cert_reqs=ssl.CERT_REQUIRED,
                               ssl_version=ssl.PROTOCOL_DTLSv1_2,
                               ca_certs=ISSUER_CERTFILE_EC,
                               ciphers=None,
                               curves=client_curve,
                               sigalgs=None,
                               user_mtu=None)
                s.connect((HOST, server.port))
                if connectionchatty:
                    sys.stdout.write(" client:  sending %s...\n" % (repr(indata)))
                s.write(indata)
                outdata = s.read()
                if connectionchatty:
                    sys.stdout.write(" client:  read %s\n" % repr(outdata))
                if outdata != indata.lower():
                    raise AssertionError("bad data <<%s>> (%d) received; expected <<%s>> (%d)\n"
                                         % (outdata[:min(len(outdata), 20)], len(outdata),
                                            indata[:min(len(indata), 20)].lower(), len(indata)))
                if connectionchatty:
                    sys.stdout.write(" client:  closing connection.\n")
                try:
                    s.close()
                except Exception as e:
                    if connectionchatty:
                        sys.stdout.write(" client:  error closing connection %s...\n" % (repr(e)))
                s = None  # close already attempted; skip cleanup below
                # NOTE(review): a False ``result`` only tolerates failure; an unexpectedly
                # successful connection is not reported.
            except Exception as e:
                if connectionchatty:
                    sys.stdout.write(" client:  aborting with exception %s...\n" % (repr(e)))
                if result:
                    raise
            finally:
                if s is not None:
                    # Best-effort release of a client socket left open by an early failure;
                    # the original error is what matters.
                    try:
                        s.close()
                    except Exception:
                        pass
                server.stop()
Ejemplo n.º 9
0
    def test_build_cert_chain(self):
        """Compare the server's certificate chain with and without the root.

        For every SSL_CTX_build_cert_chain flag, a ThreadedEchoServer is started with
        that flag as ``server_cert_options``. A verifying DTLS 1.2 client does one echo
        round trip and records ``getpeercertchain()``. The NO_ROOT chain must differ
        and be exactly one certificate shorter.
        """
        build_flags = [ssl.SSL_BUILD_CHAIN_FLAG_NONE, ssl.SSL_BUILD_CHAIN_FLAG_NO_ROOT]
        chatty = CHATTY
        connectionchatty = CHATTY_CLIENT
        payload = 'FOO'
        chains = {}

        if chatty or connectionchatty:
            sys.stdout.write("\nTestcase: test_build_cert_chain\n")
        for build_flag in build_flags:
            echo_server = ThreadedEchoServer(certificate=CERTFILE,
                                             ssl_version=ssl.PROTOCOL_DTLSv1_2,
                                             certreqs=ssl.CERT_NONE,
                                             cacerts=ISSUER_CERTFILE,
                                             ciphers=None,
                                             curves=None,
                                             sigalgs=None,
                                             mtu=None,
                                             server_key_exchange_curve=None,
                                             server_cert_options=build_flag,
                                             chatty=chatty)
            ready = threading.Event()
            echo_server.start(ready)
            ready.wait()  # the server thread signals once it is listening
            try:
                client = DtlsSocket(socket.socket(socket.AF_INET, socket.SOCK_DGRAM),
                                    keyfile=None,
                                    certfile=None,
                                    cert_reqs=ssl.CERT_REQUIRED,
                                    ssl_version=ssl.PROTOCOL_DTLSv1_2,
                                    ca_certs=ISSUER_CERTFILE,
                                    ciphers=None,
                                    curves=None,
                                    sigalgs=None,
                                    user_mtu=None)
                client.connect((HOST, echo_server.port))
                if connectionchatty:
                    sys.stdout.write(" client:  sending %s...\n" % repr(payload))
                client.write(payload)
                reply = client.read()
                if connectionchatty:
                    sys.stdout.write(" client:  read %s\n" % repr(reply))
                if reply != payload.lower():
                    # Slicing clamps to the length, so [:20] is the 20-char prefix.
                    details = (reply[:20], len(reply), payload[:20].lower(), len(payload))
                    raise AssertionError("bad data <<%s>> (%d) received; expected <<%s>> (%d)\n"
                                         % details)
                chains[build_flag] = client.getpeercertchain()
                if connectionchatty:
                    sys.stdout.write(" client:  closing connection.\n")
                try:
                    client.close()
                except Exception as close_error:
                    if connectionchatty:
                        sys.stdout.write(" client:  error closing connection %s...\n" % repr(close_error))
            except Exception as failure:
                if connectionchatty:
                    sys.stdout.write(" client:  aborting with exception %s...\n" % repr(failure))
                raise
            finally:
                echo_server.stop()

        if chatty:
            sys.stdout.write("certs:\n")
            for build_flag in build_flags:
                sys.stdout.write("SSL_CTX_build_cert_chain: %s\n%s\n"
                                 % (build_flag, pprint.pformat(chains[build_flag])))
        with_root, without_root = chains[build_flags[0]], chains[build_flags[1]]
        self.assertNotEqual(with_root, without_root)
        self.assertEqual(len(with_root) - len(without_root), 1)
Ejemplo n.º 10
0
def params_test(start_server, certfile, protocol, certreqs, cacertsfile,
                client_certfile=None, client_protocol=None, client_certreqs=None, client_cacertsfile=None,
                ciphers=None, curves=None, sigalgs=None,
                client_ciphers=None, client_curves=None, client_sigalgs=None,
                mtu=None, server_key_exchange_curve=None, server_cert_options=None,
                indata="FOO\n", chatty=False, connectionchatty=False):
    """
    Launch a server, connect a client to it and try various reads
    and writes.

    The server is a ThreadedEchoServer built from the server-side arguments.
    ``client_protocol``, ``client_ciphers``, ``client_curves`` and ``client_sigalgs``
    fall back to their server counterparts when None. ``client_certreqs`` and
    ``client_cacertsfile`` have no fallback. When ``start_server`` is false the server
    socket is closed without running, presumably so the connection attempt exercises
    the failure path.

    Returns ``(True, None)`` when the echo round trip succeeds. Otherwise it returns
    ``(False, exc)`` with the exception raised on the client side.
    """
    server = ThreadedEchoServer(certfile,
                                ssl_version=protocol,
                                certreqs=certreqs,
                                cacerts=cacertsfile,
                                ciphers=ciphers,
                                curves=curves,
                                sigalgs=sigalgs,
                                mtu=mtu,
                                server_key_exchange_curve=server_key_exchange_curve,
                                server_cert_options=server_cert_options,
                                chatty=chatty)
    # should we really run the server?
    if start_server:
        flag = threading.Event()
        server.start(flag)
        # wait for it to start
        flag.wait()
    else:
        server.sock.close()
    # try to connect
    if client_protocol is None:
        client_protocol = protocol
    if client_ciphers is None:
        client_ciphers = ciphers
    if client_curves is None:
        client_curves = curves
    if client_sigalgs is None:
        client_sigalgs = sigalgs
    s = None
    try:
        s = DtlsSocket(socket.socket(socket.AF_INET, socket.SOCK_DGRAM),
                       keyfile=client_certfile,
                       certfile=client_certfile,
                       cert_reqs=client_certreqs,
                       ssl_version=client_protocol,
                       ca_certs=client_cacertsfile,
                       ciphers=client_ciphers,
                       curves=client_curves,
                       sigalgs=client_sigalgs,
                       user_mtu=mtu)
        s.connect((HOST, server.port))
        if connectionchatty:
            sys.stdout.write(" client:  sending %s...\n" % (repr(indata)))
        s.write(indata)
        outdata = s.read()
        if connectionchatty:
            sys.stdout.write(" client:  read %s\n" % repr(outdata))
        if outdata != indata.lower():
            raise AssertionError("bad data <<%s>> (%d) received; expected <<%s>> (%d)\n"
                                 % (outdata[:min(len(outdata), 20)], len(outdata),
                                    indata[:min(len(indata), 20)].lower(), len(indata)))
        # Queried unconditionally (not only for logging), as in the original.
        cert = s.getpeercert()
        cipher = s.cipher()
        if connectionchatty:
            sys.stdout.write("cert:\n" + pprint.pformat(cert) + "\n")
            sys.stdout.write("cipher:\n" + pprint.pformat(cipher) + "\n")
            sys.stdout.write(" client:  closing connection.\n")
        try:
            s.close()
        except Exception as e:
            if connectionchatty:
                sys.stdout.write(" client:  error closing connection %s...\n" % (repr(e)))
        s = None  # close already attempted; skip cleanup in finally
    except Exception as e:
        if connectionchatty:
            sys.stdout.write(" client:  aborting with exception %s...\n" % (repr(e)))
        return False, e
    finally:
        if s is not None:
            # Best-effort release of a client socket left open by a failure;
            # the error being reported is the one caught above.
            try:
                s.close()
            except Exception:
                pass
        if start_server:
            server.stop()
    return True, None
Ejemplo n.º 11
0
class ThreadedEchoServer(threading.Thread):
    """Daemon thread running a DTLS server that echoes datagrams lower-cased.

    The server socket is created and bound in ``__init__``, so ``port`` is available
    before ``start()`` is called.
    """

    def __init__(self, certificate, ssl_version=None, certreqs=None, cacerts=None,
                 ciphers=None, curves=None, sigalgs=None,
                 mtu=None, server_key_exchange_curve=None, server_cert_options=None,
                 chatty=True):
        """Wrap a UDP socket in a server-side DtlsSocket bound to HOST.

        :param certificate: PEM file used as both ``keyfile`` and ``certfile``.
        :param ssl_version: protocol constant; ``None`` means ``ssl.PROTOCOL_DTLSv1``.
        :param certreqs: client-certificate requirement; ``None`` means ``ssl.CERT_NONE``.
        :param cacerts: CA file passed as ``ca_certs``.
        :param ciphers: passed through to DtlsSocket.
        :param curves: passed through to DtlsSocket.
        :param sigalgs: passed through to DtlsSocket.
        :param mtu: passed to DtlsSocket as ``user_mtu``.
        :param server_key_exchange_curve: passed through to DtlsSocket.
        :param server_cert_options: passed through to DtlsSocket.
        :param chatty: when true, progress messages are written to stdout.
        :raises Exception: whatever ``bind``/``getsockname`` raise; the wrapped
            socket is closed before the error propagates.
        """
        if ssl_version is None:
            ssl_version = ssl.PROTOCOL_DTLSv1
        if certreqs is None:
            certreqs = ssl.CERT_NONE

        self.certificate = certificate
        self.protocol = ssl_version
        self.certreqs = certreqs
        self.cacerts = cacerts
        self.ciphers = ciphers
        self.curves = curves
        self.sigalgs = sigalgs
        self.mtu = mtu
        self.server_key_exchange_curve = server_key_exchange_curve
        self.server_cert_options = server_cert_options
        self.chatty = chatty

        self.flag = None
        # Set here so stop() on a never-started server returns instead of raising
        # AttributeError; start() overwrites it.
        self.starter = None

        self.sock = DtlsSocket(socket.socket(socket.AF_INET, socket.SOCK_DGRAM),
                               keyfile=self.certificate,
                               certfile=self.certificate,
                               server_side=True,
                               cert_reqs=self.certreqs,
                               ssl_version=self.protocol,
                               ca_certs=self.cacerts,
                               ciphers=self.ciphers,
                               curves=self.curves,
                               sigalgs=self.sigalgs,
                               user_mtu=self.mtu,
                               server_key_exchange_curve=self.server_key_exchange_curve,
                               server_cert_options=self.server_cert_options)

        if self.chatty:
            sys.stdout.write(' server:  wrapped server socket as %s\n' % str(self.sock))
        try:
            # Port 0 lets the OS choose a free port; remember which one.
            self.sock.bind((HOST, 0))
            self.port = self.sock.getsockname()[1]
        except Exception:
            # Don't leak the wrapped socket when binding fails.
            self.sock.close()
            raise
        self.active = False
        threading.Thread.__init__(self)
        self.daemon = True

    def start(self, flag=None):
        """Start the thread; ``flag`` (an Event) is set once run() listens.

        Records the calling thread so stop() knows who may join.
        """
        self.flag = flag
        self.starter = threading.current_thread().ident
        threading.Thread.start(self)

    def run(self):
        """Echo each received datagram back lower-cased until stopped, then close."""
        # Short timeout so the loop re-checks ``self.active`` every 50 ms.
        self.sock.settimeout(0.05)
        self.sock.listen(0)
        self.active = True
        if self.flag:
            # signal an event
            self.flag.set()
        while self.active:
            try:
                acc_ret = self.sock.recvfrom(4096)
                # NOTE(review): recvfrom can apparently return a falsy value
                # (presumably when no application data arrived) — confirm.
                if acc_ret:
                    newdata, connaddr = acc_ret
                    if self.chatty:
                        sys.stdout.write(' server:  new data from ' + str(connaddr) + '\n')
                    self.sock.sendto(newdata.lower(), connaddr)
            except socket.timeout:
                pass
            except KeyboardInterrupt:
                self.stop()
            except Exception as e:
                # Best-effort server: report and keep serving.
                if self.chatty:
                    sys.stdout.write(' server:  error ' + str(e) + '\n')
        if self.chatty:
            sys.stdout.write(' server:  closing socket as %s\n' % str(self.sock))
        self.sock.close()

    def stop(self):
        """Ask run() to exit.

        Only the thread that called start() waits for the server thread to finish.
        Any other caller, including the server thread itself, just clears the flag.
        """
        self.active = False
        if self.starter != threading.current_thread().ident:
            return
        self.join()  # wait for run() to leave its loop and close the socket