def test_handshake(self):
    """
    Connecting a client and a server L{TLSMemoryBIOProtocol} over a loopback
    transport drives the TLS handshake to completion.
    """
    # Client half: a plain Protocol behind TLS, using a context factory whose
    # Deferred fires once the handshake has finished.
    clientFactory = ClientFactory()
    clientFactory.protocol = Protocol
    clientContextFactory, handshakeDeferred = (
        HandshakeCallbackContextFactory.factoryAndDeferred())
    sslClientProtocol = TLSMemoryBIOFactory(
        clientContextFactory, True, clientFactory).buildProtocol(None)

    # Server half: a plain Protocol behind TLS, using the test certificate.
    serverFactory = ServerFactory()
    serverFactory.protocol = Protocol
    serverContextFactory = DefaultOpenSSLContextFactory(certPath, certPath)
    sslServerProtocol = TLSMemoryBIOFactory(
        serverContextFactory, False, serverFactory).buildProtocol(None)

    loopbackAsync(sslServerProtocol, sslClientProtocol)

    # The handshake is all this test cares about; whatever happens on the
    # connection afterwards is irrelevant here.
    return handshakeDeferred
def test_loseConnectionAfterHandshake(self):
    """
    L{TLSMemoryBIOProtocol.loseConnection} sends a TLS close alert and shuts
    down the underlying connection.
    """
    clientConnectionLost = Deferred()
    clientFactory = ClientFactory()
    clientFactory.protocol = (
        lambda: ConnectionLostNotifyingProtocol(clientConnectionLost))

    clientContextFactory, handshakeDeferred = (
        HandshakeCallbackContextFactory.factoryAndDeferred())
    wrapperFactory = TLSMemoryBIOFactory(clientContextFactory, True,
                                         clientFactory)
    sslClientProtocol = wrapperFactory.buildProtocol(None)

    serverProtocol = Protocol()
    serverFactory = ServerFactory()
    serverFactory.protocol = lambda: serverProtocol

    serverContextFactory = DefaultOpenSSLContextFactory(certPath, certPath)
    wrapperFactory = TLSMemoryBIOFactory(serverContextFactory, False,
                                         serverFactory)
    sslServerProtocol = wrapperFactory.buildProtocol(None)

    loopbackAsync(sslServerProtocol, sslClientProtocol)

    # Wait for the handshake before dropping the connection.
    def cbHandshake(ignored):
        serverProtocol.transport.loseConnection()

        # Now wait for the client to notice.
        return clientConnectionLost
    handshakeDeferred.addCallback(cbHandshake)

    # Wait for the connection to end, then make sure the client was
    # notified of a clean TLS shutdown (ConnectionDone), not an error.
    def cbConnectionDone(clientProtocol):
        clientProtocol.lostConnectionReason.trap(ConnectionDone)

        # The server should have closed its underlying transport, in
        # addition to whatever it did to shut down the TLS layer.
        self.assertTrue(serverProtocol.transport.q.disconnect)

        # The client should also have closed its underlying transport once
        # it saw the server shut down the TLS layer, so as to avoid relying
        # on the server to close the underlying connection.
        self.assertTrue(clientProtocol.transport.q.disconnect)
    handshakeDeferred.addCallback(cbConnectionDone)
    return handshakeDeferred
def connect_to(self, description: str, peer: Optional[PeerId] = None, use_ssl: Optional[bool] = None) -> None:
    """ Attempt to connect to a peer, even if a connection already exists.
    Usually you should call `connect_to_if_not_connected`.

    If `use_ssl` is True, the connection will be wrapped in TLS. If it is
    None (the default), the manager-wide `self.ssl` setting decides.

    :param description: Peer address description; it is converted by
        `description_to_connection_string` into a connection string and a
        peer id.
    :param peer: Optional peer object, forwarded to both the success
        callback and the failure errback.
    :param use_ssl: Whether to wrap the connection in TLS (defaults to
        `self.ssl`).
    """
    if use_ssl is None:
        use_ssl = self.ssl
    connection_string, peer_id = description_to_connection_string(description)
    # When using twisted endpoints we can't have // in the connection string
    endpoint_url = connection_string.replace('//', '')
    endpoint = endpoints.clientFromString(self.reactor, endpoint_url)

    if use_ssl:
        certificate_options = self.my_peer.get_certificate_options()
        factory = TLSMemoryBIOFactory(certificate_options, True, self.client_factory)
    else:
        factory = self.client_factory

    deferred = endpoint.connect(factory)
    # Remember the in-flight attempt, keyed by its endpoint.
    self.connecting_peers[endpoint] = deferred

    deferred.addCallback(self._connect_to_callback, peer, endpoint, connection_string, peer_id)
    # Added after the callback, so this errback also receives any failure
    # raised inside _connect_to_callback, not just connection failures.
    deferred.addErrback(self.on_connection_failure, peer, endpoint)
    self.log.info('Connecting to: {description}...', description=description)
def test_disorderlyShutdown(self):
    """
    An unexpected loss of the connection underneath a
    L{TLSMemoryBIOProtocol} is reported to the application-level protocol.
    """
    clientConnectionLost = Deferred()

    clientFactory = ClientFactory()
    clientFactory.protocol = (
        lambda: ConnectionLostNotifyingProtocol(clientConnectionLost))
    sslClientProtocol = TLSMemoryBIOFactory(
        HandshakeCallbackContextFactory(), True, clientFactory,
    ).buildProtocol(None)

    # The client speaks first, so the server side needs no TLS at all.
    serverProtocol = Protocol()
    loopbackAsync(serverProtocol, sslClientProtocol)

    # Tear the connection down from underneath the TLS layer.
    serverProtocol.transport.loseConnection()

    def cbDisconnected(clientProtocol):
        # The loss must surface as an OpenSSL error, not a clean close.
        clientProtocol.lostConnectionReason.trap(Error)

    return clientConnectionLost.addCallback(cbDisconnected)
def listen(self, description: str, ssl: bool = True) -> IStreamServerEndpoint:
    """ Start to listen for new connections according to the description.

    If `ssl` is True, accepted connections will be wrapped in TLS using this
    peer's certificate options.

    :Example:

    `manager.listen(description='tcp:8000')`

    :param description: A description of the protocol and its parameters.
    :type description: str

    :param ssl: Whether to wrap accepted connections in TLS.
    :type ssl: bool

    :return: The server endpoint that was asked to listen.
    """
    endpoint = endpoints.serverFromString(self.reactor, description)

    if ssl:
        certificate_options = self.my_peer.get_certificate_options()
        factory = TLSMemoryBIOFactory(certificate_options, False, self.server_factory)
    else:
        factory = self.server_factory

    # NOTE(review): the Deferred returned by endpoint.listen() is discarded,
    # so a bind failure only surfaces as an unhandled Deferred error.
    endpoint.listen(factory)
    self.log.info('Listening to: {description}...', description=description)
    return endpoint
def __init__(self, reactor, cluster_state, configuration_service, endpoint,
             context_factory):
    """
    Set up the control service's bookkeeping and its TLS-wrapped AMP
    listener.

    :param reactor: See ``ControlServiceLocator.__init__``.
    :param ClusterStateService cluster_state: Object that records known
        cluster state.
    :param ConfigurationPersistenceService configuration_service:
        Persistence service for desired cluster configuration.
    :param endpoint: Endpoint to listen on.
    :param context_factory: TLS context factory.
    """
    self._reactor = reactor
    self.cluster_state = cluster_state
    self.configuration_service = configuration_service

    # Connection tracking and pending-broadcast state.
    self._connections = set()
    self._connections_pending_update = set()
    self._current_pending_update_delayed_call = None
    self._current_command = {}

    # Generation tracking for configuration/state updates.
    self._last_received_generation = defaultdict(
        lambda: _ConfigAndStateGeneration())
    self._configuration_generation_tracker = GenerationTracker(100)
    self._state_generation_tracker = GenerationTracker(100)

    amp_factory = ServerFactory.forProtocol(
        lambda: ControlAMP(reactor, self))
    self.endpoint_service = StreamServerEndpointService(
        endpoint, TLSMemoryBIOFactory(context_factory, False, amp_factory))

    # When configuration changes, notify all connected clients:
    self.configuration_service.register(self._schedule_broadcast_update)
def refresh_certificate(hs: "HomeServer") -> None:
    """
    Re-read the TLS certificate from disk and re-wrap every listening TLS
    port so that new connections use the refreshed context factory.

    Does nothing if the server has no TLS listener.
    """
    if not hs.config.server.has_tls_listener():
        return

    hs.config.tls.read_certificate_from_disk()
    hs.tls_server_context_factory = context_factory.ServerContextFactory(
        hs.config)

    if not hs._listening_services:
        return

    logger.info("Updating context factories...")
    for listener in hs._listening_services:
        # listenSSL actually creates a TCP port whose (public but
        # undocumented) ``factory`` attribute is a TLS wrapper around the
        # real request factory; plain TCP listeners are left alone.
        if not isinstance(listener.factory, TLSMemoryBIOFactory):
            continue

        addr = listener.getHost()
        logger.info(
            "Replacing TLS context factory on [%s]:%i", addr.host, addr.port
        )
        # Re-wrap the original inner factory with the new TLS configuration.
        inner_factory = listener.factory.wrappedFactory
        listener.factory = TLSMemoryBIOFactory(
            hs.tls_server_context_factory, False, inner_factory
        )
    logger.info("Context factories updated.")
def __init__(self, reactor, cluster_state, configuration_service, endpoint,
             context_factory):
    """
    Set up the control service and its TLS-wrapped AMP listener.

    :param reactor: See ``ControlServiceLocator.__init__``.
    :param ClusterStateService cluster_state: Object that records known
        cluster state.
    :param ConfigurationPersistenceService configuration_service:
        Persistence service for desired cluster configuration.
    :param endpoint: Endpoint to listen on.
    :param context_factory: TLS context factory.
    """
    self.connections = set()
    self._current_command = {}
    self.cluster_state = cluster_state
    self.configuration_service = configuration_service

    amp_factory = ServerFactory.forProtocol(
        lambda: ControlAMP(reactor, self))
    tls_factory = TLSMemoryBIOFactory(context_factory, False, amp_factory)
    self.endpoint_service = StreamServerEndpointService(endpoint, tls_factory)

    def _on_configuration_change():
        # Push the latest state to every currently connected client.
        return self._send_state_to_connections(self.connections)

    # When configuration changes, notify all connected clients:
    self.configuration_service.register(_on_configuration_change)
def create_api_service(persistence_service, cluster_state_service, endpoint,
                       context_factory, clock=reactor):
    """
    Create a Twisted Service that serves the API on the given endpoint.

    :param ConfigurationPersistenceService persistence_service: Service
        for retrieving and setting desired configuration.

    :param ClusterStateService cluster_state_service: Service that
        knows about the current state of the cluster.

    :param endpoint: Twisted endpoint to listen on.

    :param context_factory: TLS context factory.

    :param IReactorTime clock: The clock to use for time. By default
        global reactor.

    :return: Service that will listen on the endpoint using HTTP API server.
    """
    api_root = Resource()
    user = ConfigurationAPIUserV1(persistence_service, cluster_state_service,
                                  clock)
    # Resource.putChild path segments must be bytes: b'v1' is the same
    # object as 'v1' on Python 2 and is what Twisted requires on Python 3.
    api_root.putChild(b'v1', user.app.resource())
    api_root._v1_user = user  # For unit testing purposes, alas

    return StreamServerEndpointService(
        endpoint,
        TLSMemoryBIOFactory(
            context_factory,
            False,
            Site(api_root)
        )
    )
def listen(self, protocolFactory):  # noqa
    """
    Start listening behind an SNI-aware TLS wrapper, then start an ACME
    issuing service and wait until initial issuing is complete.

    The returned Deferred fires with a ``_WrapperPort`` once the issuing
    service reports that certificates are valid.
    """
    responder = TLSSNI01Responder()
    sni_map = SNIMap(responder.wrap_host_map(self.cert_mapping))
    tls_factory = TLSMemoryBIOFactory(
        contextFactory=sni_map,
        isClient=False,
        wrappedFactory=protocolFactory)

    def _got_port(port):
        self.service = AcmeIssuingService(
            cert_store=self.cert_store,
            client_creator=partial(
                self.client_creator, self.reactor, self.directory),
            clock=self.reactor,
            responders=[responder],
            check_interval=self.check_interval,
            reissue_interval=self.reissue_interval,
            panic_interval=self.panic_interval,
            panic=self._panic,
            generate_key=self._generate_key)
        self.service.startService()
        certs_ready = self.service.when_certs_valid()
        certs_ready.addCallback(
            lambda _: _WrapperPort(port=port, service=self.service))
        return certs_ready

    listening = maybeDeferred(self.sub_endpoint.listen, tls_factory)
    listening.addCallback(_got_port)
    return listening
def startTLS(self, contextFactory, normal=True):
    """
    @see: L{ITLSTransport.startTLS}
    """
    # With normal=True the TLS role follows this transport's default
    # direction; with normal=False it is the opposite role.
    client = self._tlsClientDefault if normal else not self._tlsClientDefault

    tlsProtocol = TLSMemoryBIOProtocol(
        TLSMemoryBIOFactory(contextFactory, client, None),
        self.protocol,
        False,
    )

    # Route application traffic through the TLS layer and expose the
    # ISSLTransport accessors on this transport.
    self.protocol = tlsProtocol
    self.getHandle = tlsProtocol.getHandle
    self.getPeerCertificate = tlsProtocol.getPeerCertificate

    # Mark the transport as secure.  NOTE(review): directlyProvides replaces
    # any interfaces previously declared directly on this object.
    directlyProvides(self, interfaces.ISSLTransport)

    # Lets write and writeSequence send data to the right place.
    self._tls = True

    # Hook the TLS layer up to a transport that bypasses TLS for its own bytes.
    self.protocol.makeConnection(_BypassTLS(self))
def test_writeAfterHandshake(self):
    """
    Bytes written to L{TLSMemoryBIOProtocol} after the handshake is complete
    are received by the protocol on the other side of the connection.
    """
    # Named ``data`` rather than ``bytes`` to avoid shadowing the builtin.
    data = "some bytes"

    clientProtocol = Protocol()
    clientFactory = ClientFactory()
    clientFactory.protocol = lambda: clientProtocol

    clientContextFactory, handshakeDeferred = (
        HandshakeCallbackContextFactory.factoryAndDeferred())
    wrapperFactory = TLSMemoryBIOFactory(clientContextFactory, True,
                                         clientFactory)
    sslClientProtocol = wrapperFactory.buildProtocol(None)

    serverProtocol = AccumulatingProtocol(len(data))
    serverFactory = ServerFactory()
    serverFactory.protocol = lambda: serverProtocol

    serverContextFactory = DefaultOpenSSLContextFactory(certPath, certPath)
    wrapperFactory = TLSMemoryBIOFactory(serverContextFactory, False,
                                         serverFactory)
    sslServerProtocol = wrapperFactory.buildProtocol(None)

    connectionDeferred = loopbackAsync(sslServerProtocol, sslClientProtocol)

    # Wait for the handshake to finish before writing anything.
    def cbHandshook(ignored):
        clientProtocol.transport.write(data)

        # The server will drop the connection once it gets the bytes.
        return connectionDeferred
    handshakeDeferred.addCallback(cbHandshook)

    # Once the connection is lost, make sure the server received the
    # expected bytes.
    def cbDisconnected(ignored):
        self.assertEquals("".join(serverProtocol.received), data)
    handshakeDeferred.addCallback(cbDisconnected)

    return handshakeDeferred
def test_handshakeFailure(self):
    """
    Handshake errors are delivered to the application-level protocol via
    C{connectionLost}, and the underlying transport is disconnected.
    """
    clientConnectionLost = Deferred()
    clientFactory = ClientFactory()
    clientFactory.protocol = (
        lambda: ConnectionLostNotifyingProtocol(clientConnectionLost))
    sslClientProtocol = TLSMemoryBIOFactory(
        HandshakeCallbackContextFactory(), True, clientFactory,
    ).buildProtocol(None)

    serverConnectionLost = Deferred()
    serverFactory = ServerFactory()
    serverFactory.protocol = (
        lambda: ConnectionLostNotifyingProtocol(serverConnectionLost))

    # A server context that demands a client certificate signed by itself;
    # the client above presents none, so the handshake must fail.
    certificate = PrivateCertificate.loadPEM(FilePath(certPath).getContent())
    serverContextFactory = certificate.options(certificate)
    sslServerProtocol = TLSMemoryBIOFactory(
        serverContextFactory, False, serverFactory,
    ).buildProtocol(None)

    connectionDeferred = loopbackAsync(sslServerProtocol, sslClientProtocol)

    def cbConnectionLost(protocol):
        # Each side must see the failure as an OpenSSL error once the
        # connection closes on its own.
        protocol.lostConnectionReason.trap(Error)

    for lostDeferred in (clientConnectionLost, serverConnectionLost):
        lostDeferred.addCallback(cbConnectionLost)

    # Waiting on connectionDeferred as well checks that the underlying
    # transport was told to go away.
    return gatherResults(
        [clientConnectionLost, serverConnectionLost, connectionDeferred])
def addSubprocesses(self, fds, name, factory):
    """
    Delegate to the parent implementation and, for the ``main_web_ssl``
    listener, build a TLS-wrapped factory from the configured certificate
    and key PEM files.
    """
    super(HendrixDeployTLS, self).addSubprocesses(fds, name, factory)
    if name == 'main_web_ssl':
        # Read both PEM files under context managers so the handles are
        # closed deterministically instead of waiting on garbage collection.
        with open(self.options['cert']) as cert_file:
            cert_pem = cert_file.read()
        with open(self.options['key']) as key_file:
            key_pem = key_file.read()
        privateCert = PrivateCertificate.loadPEM(cert_pem + key_pem)
        # NOTE(review): this wrapped factory is only bound to a local name
        # and is never passed on or stored, so it has no effect after this
        # method returns. Confirm whether it was meant to be handed to the
        # super() call above.
        factory = TLSMemoryBIOFactory(privateCert.options(), False, factory)
def listenSSL(self, port, factory, contextFactory, backlog=50, interface=''):
    """
    @see: twisted.internet.interfaces.IReactorSSL.listenSSL
    """
    # TLS is layered as a server-side protocol wrapper over an ordinary
    # TCP listener.
    tlsFactory = TLSMemoryBIOFactory(contextFactory, False, factory)
    return self.listenTCP(port, tlsFactory, backlog, interface)
def connectSSL(self, host, port, factory, contextFactory, timeout=30,
               bindAddress=None):
    """
    @see: twisted.internet.interfaces.IReactorSSL.connectSSL
    """
    # TLS is layered as a client-side protocol wrapper over an ordinary
    # TCP connection.
    tlsFactory = TLSMemoryBIOFactory(contextFactory, True, factory)
    return self.connectTCP(host, port, tlsFactory, timeout, bindAddress)
def startTLS(transport, contextFactory, normal, bypass):
    """
    Add a layer of SSL to a transport.

    @param transport: The transport which will be modified.  This can either
        be a L{FileDescriptor<twisted.internet.abstract.FileDescriptor>} or a
        L{FileHandle<twisted.internet.iocpreactor.abstract.FileHandle>}.  The
        actual requirements of this instance are that it have:

          - a C{_tlsClientDefault} attribute indicating whether the transport
            is a client (C{True}) or a server (C{False})
          - a settable C{TLS} attribute which can be used to mark the fact
            that SSL has been started
          - settable C{getHandle} and C{getPeerCertificate} attributes so
            these L{ISSLTransport} methods can be added to it
          - a C{protocol} attribute referring to the L{IProtocol} currently
            connected to the transport, which can also be set to a new
            L{IProtocol} for the transport to deliver data to

    @param contextFactory: An SSL context factory defining SSL parameters for
        the new SSL layer.
    @type contextFactory: L{twisted.internet.ssl.ContextFactory}

    @param normal: A flag indicating whether SSL will go in the same direction
        as the underlying transport goes.  That is, if the SSL client will be
        the underlying client and the SSL server will be the underlying
        server.  C{True} means it is the same, C{False} means they are
        switched.
    @type normal: L{bool}

    @param bypass: A transport base class to call methods on to bypass the new
        SSL layer (so that the SSL layer itself can send its bytes).
    @type bypass: L{type}
    """
    # Figure out which direction the SSL goes in.  If normal is True,
    # we'll go in the direction indicated by the subclass.  Otherwise,
    # we'll go the other way (client = not normal ^ _tlsClientDefault,
    # in other words).
    if normal:
        client = transport._tlsClientDefault
    else:
        client = not transport._tlsClientDefault

    tlsFactory = TLSMemoryBIOFactory(contextFactory, client, None)
    tlsProtocol = TLSMemoryBIOProtocol(tlsFactory, transport.protocol, False)
    transport.protocol = tlsProtocol

    transport.getHandle = tlsProtocol.getHandle
    transport.getPeerCertificate = tlsProtocol.getPeerCertificate

    # Mark the transport as secure.  NOTE(review): directlyProvides replaces
    # any interfaces previously declared directly on the transport;
    # alsoProvides would add ISSLTransport to them instead.
    directlyProvides(transport, ISSLTransport)

    # Remember we did this so that write and writeSequence can send the
    # data to the right place.
    transport.TLS = True

    # Hook it up
    transport.protocol.makeConnection(_BypassTLS(bypass, transport))
def listen(self, factory):
    """Implement IStreamServerEndpoint.listen to listen on TCP.

    Optionally configuring TLS behind the HAProxy protocol.
    """
    # When a TLS context factory is configured, TLS is the inner layer and
    # self.wrapper_factory (presumably the HAProxy wrapper) is the outer one.
    innerFactory = (
        TLSMemoryBIOFactory(self._ssl_cf, False, factory)
        if self._ssl_cf
        else factory
    )
    proxyFactory = self.wrapper_factory(innerFactory)
    return defer.execute(
        self._listen, self._port, proxyFactory, **self._kwargs)
def startTLS(self, contextFactory, client, bytes=None): """ Add a layer of TLS, with SSL parameters defined by the given contextFactory. If *client* is True, this side of the connection will be an SSL client. Otherwise it will be an SSL server. If extra bytes which may be (or almost certainly are) part of the SSL handshake were received by the protocol running on top of OnionProtocol, they must be passed here as the **bytes** parameter. """ # The newest TLS session is spliced in between the previous # and the application protocol at the tail end of the list. tlsProtocol = TLSMemoryBIOProtocol(None, self._tailProtocol, False) tlsProtocol.factory = TLSMemoryBIOFactory(contextFactory, client, None) if self._currentProtocol is self._tailProtocol: # This is the first and thus outermost TLS session. The # transport is the immutable sentinel that no startTLS or # stopTLS call will move within the linked list stack. # The wrappedProtocol will remain this outermost session # until it's terminated. self.wrappedProtocol = tlsProtocol nextTransport = PopOnDisconnectTransport(original=self.transport, pop=self._pop) # Store the proxied transport as the list's head sentinel # to enable an easy identity check in _pop. self._headTransport = nextTransport else: # This a later TLS session within the stack. The previous # TLS session becomes its transport. nextTransport = PopOnDisconnectTransport( original=self._currentProtocol, pop=self._pop) # Splice the new TLS session into the linked list stack. # wrappedProtocol serves as the link, so the protocol at the # current position takes our new TLS session as its # wrappedProtocol. self._currentProtocol.wrappedProtocol = tlsProtocol # Move down one position in the linked list. self._currentProtocol = tlsProtocol # Expose the new, innermost TLS session as the transport to # the application protocol. self.transport = self._currentProtocol # Connect the new TLS session to the previous transport. 
The # transport attribute also serves as the previous link. tlsProtocol.makeConnection(nextTransport) # Left over bytes are part of the latest handshake. Pass them # on to the innermost TLS session. if bytes is not None: tlsProtocol.dataReceived(bytes)
def test_multipleWrites(self):
    """
    When several separate TLS records arrive from the underlying transport
    in a single chunk, the application-level protocol receives the bytes of
    every one of them.
    """
    chunks = [str(i) for i in range(10)]

    class SimpleSendingProtocol(Protocol):
        def connectionMade(self):
            # One write per chunk, so each becomes its own TLS record.
            for chunk in chunks:
                self.transport.write(chunk)

    clientFactory = ClientFactory()
    clientFactory.protocol = SimpleSendingProtocol
    sslClientProtocol = TLSMemoryBIOFactory(
        HandshakeCallbackContextFactory(), True, clientFactory,
    ).buildProtocol(None)

    serverProtocol = AccumulatingProtocol(sum(len(c) for c in chunks))
    serverFactory = ServerFactory()
    serverFactory.protocol = lambda: serverProtocol
    sslServerProtocol = TLSMemoryBIOFactory(
        DefaultOpenSSLContextFactory(certPath, certPath), False, serverFactory,
    ).buildProtocol(None)

    # collapsingPumpPolicy delivers all pending writes at once, which packs
    # multiple TLS records into one dataReceived call.
    connectionDeferred = loopbackAsync(
        sslServerProtocol, sslClientProtocol, collapsingPumpPolicy)

    def cbConnectionDone(ignored):
        self.assertEquals("".join(serverProtocol.received), "".join(chunks))

    connectionDeferred.addCallback(cbConnectionDone)
    return connectionDeferred
def listenSSL(self, port, factory, contextFactory, backlog=50, interface=""):
    """
    @see: twisted.internet.interfaces.IReactorSSL.listenSSL
    """
    tlsFactory = TLSMemoryBIOFactory(contextFactory, False, factory)
    listeningPort = self.listenTCP(port, tlsFactory, backlog, interface)
    # Label the underlying TCP port as "TLS" (presumably used when the port
    # describes itself, e.g. in logs).
    listeningPort._type = "TLS"
    return listeningPort
def test_writeBeforeHandshake(self):
    """
    Data handed to L{TLSMemoryBIOProtocol} while the handshake is still in
    progress reaches the peer's protocol once the handshake succeeds.
    """
    payload = "some bytes"

    class SimpleSendingProtocol(Protocol):
        def connectionMade(self):
            # Written immediately on connect, i.e. before the handshake.
            self.transport.write(payload)

    clientFactory = ClientFactory()
    clientFactory.protocol = SimpleSendingProtocol
    clientContextFactory, _ = (
        HandshakeCallbackContextFactory.factoryAndDeferred())
    sslClientProtocol = TLSMemoryBIOFactory(
        clientContextFactory, True, clientFactory).buildProtocol(None)

    serverProtocol = AccumulatingProtocol(len(payload))
    serverFactory = ServerFactory()
    serverFactory.protocol = lambda: serverProtocol
    sslServerProtocol = TLSMemoryBIOFactory(
        DefaultOpenSSLContextFactory(certPath, certPath), False, serverFactory,
    ).buildProtocol(None)

    connectionDeferred = loopbackAsync(sslServerProtocol, sslClientProtocol)

    def cbConnectionDone(ignored):
        # After the connection ends, the server must hold the full payload.
        self.assertEquals("".join(serverProtocol.received), payload)

    connectionDeferred.addCallback(cbConnectionDone)
    return connectionDeferred
def test_hugeWrite(self):
    """
    If a very long string is passed to L{TLSMemoryBIOProtocol.write}, any
    trailing part of it which cannot be sent immediately is buffered and
    sent later.
    """
    # Named ``data`` rather than ``bytes`` to avoid shadowing the builtin.
    data = "some bytes"
    factor = 8192

    class SimpleSendingProtocol(Protocol):
        def connectionMade(self):
            self.transport.write(data * factor)

    clientFactory = ClientFactory()
    clientFactory.protocol = SimpleSendingProtocol

    clientContextFactory = HandshakeCallbackContextFactory()
    wrapperFactory = TLSMemoryBIOFactory(clientContextFactory, True,
                                         clientFactory)
    sslClientProtocol = wrapperFactory.buildProtocol(None)

    serverProtocol = AccumulatingProtocol(len(data) * factor)
    serverFactory = ServerFactory()
    serverFactory.protocol = lambda: serverProtocol

    serverContextFactory = DefaultOpenSSLContextFactory(certPath, certPath)
    wrapperFactory = TLSMemoryBIOFactory(serverContextFactory, False,
                                         serverFactory)
    sslServerProtocol = wrapperFactory.buildProtocol(None)

    connectionDeferred = loopbackAsync(sslServerProtocol, sslClientProtocol)

    # Wait for the connection to end, then make sure the server received
    # the bytes sent by the client.
    def cbConnectionDone(ignored):
        self.assertEquals("".join(serverProtocol.received), data * factor)
    connectionDeferred.addCallback(cbConnectionDone)
    return connectionDeferred
def __init__(self, context_factory):
    """
    Build the convergence-loop and cluster-status state machines and the
    TLS-wrapped, reconnecting AMP client factory.

    :param context_factory: TLS context factory for the AMP client.
    """
    MultiService.__init__(self)

    loop_fsm = build_convergence_loop_fsm(self.reactor, self.deployer)
    self.logger = loop_fsm.logger
    self.cluster_status = build_cluster_status_fsm(loop_fsm)

    def build_agent_protocol():
        return AgentAMP(self.reactor, self)

    self.reconnecting_factory = ReconnectingClientFactory.forProtocol(
        build_agent_protocol)
    # Client-side TLS wrapper around the reconnecting factory.
    self.factory = TLSMemoryBIOFactory(
        context_factory, True, self.reconnecting_factory)
def test_getPeerCertificate(self):
    """
    L{TLSMemoryBIOProtocol.getPeerCertificate} returns the
    L{OpenSSL.crypto.X509Type} instance representing the peer's
    certificate.
    """
    # Set up a client and server so there's a certificate to grab.
    clientFactory = ClientFactory()
    clientFactory.protocol = Protocol

    clientContextFactory, handshakeDeferred = (
        HandshakeCallbackContextFactory.factoryAndDeferred())
    wrapperFactory = TLSMemoryBIOFactory(
        clientContextFactory, True, clientFactory)
    sslClientProtocol = wrapperFactory.buildProtocol(None)

    serverFactory = ServerFactory()
    serverFactory.protocol = Protocol

    serverContextFactory = DefaultOpenSSLContextFactory(certPath, certPath)
    wrapperFactory = TLSMemoryBIOFactory(
        serverContextFactory, False, serverFactory)
    sslServerProtocol = wrapperFactory.buildProtocol(None)

    loopbackAsync(sslServerProtocol, sslClientProtocol)

    # Wait for the handshake
    def cbHandshook(ignored):
        # Grab the server's certificate and check it out
        cert = sslClientProtocol.getPeerCertificate()
        self.assertIsInstance(cert, X509Type)
        self.assertEquals(
            cert.digest('md5'),
            '9B:A4:AB:43:10:BE:82:AE:94:3E:6B:91:F2:F3:40:E8')
    handshakeDeferred.addCallback(cbHandshook)
    return handshakeDeferred
def _build_test_server():
    """Construct a test server: an HTTP channel behind a TLS server layer.

    Returns:
        TLSMemoryBIOProtocol
    """
    http_factory = Factory.forProtocol(HTTPChannel)
    # Request.finish expects the factory to have a 'log' method.
    http_factory.log = _log_request

    tls_factory = TLSMemoryBIOFactory(
        ServerTLSContext(), isClient=False, wrappedFactory=http_factory
    )
    return tls_factory.buildProtocol(None)
def getEAPTLSTransport(self, state):
    """
    Build a server-side TLS protocol that wraps an C{eap.EAPTLSProtocol}
    for *state* and is connected to an in-memory string transport.

    Returns the connected TLS protocol.
    """
    serverFactory = protocol.ServerFactory()
    serverFactory.protocol = lambda: eap.EAPTLSProtocol(
        state, self.peap_protocols)

    # Context built from this object's key and certificate files, pinned to
    # TLSv1.
    contextFactory = ssl.DefaultOpenSSLContextFactory(
        self.key, self.cert, sslmethod=SSL.TLSv1_METHOD)

    # Server-side TLS wrapper, attached to a fake in-memory transport.
    tlsProtocol = TLSMemoryBIOFactory(
        contextFactory, False, serverFactory).buildProtocol(None)
    tlsProtocol.makeConnection(proto_utils.StringTransport())
    return tlsProtocol
def test_makeConnection(self):
    """
    Connecting a L{TLSMemoryBIOProtocol} to a transport also connects the
    protocol it wraps to a transport of its own.
    """
    clientProtocol = Protocol()
    clientFactory = ClientFactory()
    clientFactory.protocol = lambda: clientProtocol

    sslProtocol = TLSMemoryBIOFactory(
        ClientContextFactory(), True, clientFactory).buildProtocol(None)

    transport = StringTransport()
    sslProtocol.makeConnection(transport)

    # The wrapped protocol must have been given a transport, and not the raw
    # StringTransport handed to the TLS layer.
    self.assertNotIdentical(clientProtocol.transport, None)
    self.assertNotIdentical(clientProtocol.transport, transport)
def _wrap_server_factory_for_tls(factory, sanlist=None):
    """Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory

    The resulting factory creates a TLS server that presents a certificate
    signed by our test CA and valid for the domains in `sanlist`.

    Args:
        factory (interfaces.IProtocolFactory): protocol factory to wrap
        sanlist (iterable[bytes]): list of domains the cert should be valid
            for; defaults to ``[b"DNS:test.com"]``

    Returns:
        interfaces.IProtocolFactory
    """
    connection_creator = TestServerTLSConnectionFactory(
        sanlist=[b"DNS:test.com"] if sanlist is None else sanlist
    )
    return TLSMemoryBIOFactory(
        connection_creator, isClient=False, wrappedFactory=factory
    )
def proxyClientFactoryClass(self, *args, **kwargs):
    """
    Override of proxyClientFactoryClass that, when TLS is enabled, wraps
    every connection created by ReverseProxyResource's protocol factory in
    a client-side TLS layer.
    """
    client_factory = HTTPProxyClientFactory(*args, **kwargs)
    if not self.__ssl_enabled:
        return client_factory

    # Trust the configured certificate when verifying the upstream host.
    with open(server.config.ssl.certificate) as cert_file:
        cert = ssl.Certificate.loadPEM(cert_file.read())
    tls_options = ssl.optionsForClientTLS(self.host.decode('ascii'), cert)

    # TLSMemoryBIOFactory takes the TLS options plus the wrapped factory and
    # adds TLS to each connection that factory builds.
    return TLSMemoryBIOFactory(tls_options, isClient=True,
                               wrappedFactory=client_factory)