def test_connectionMade_drops_connections_if_authentication_errors(self):
    """A failed authentication in connectionMade drops the connection.

    The protocol must not be registered with the service, its transport
    must be told to lose the connection, and a hint about the shared
    secret must be logged.
    """
    logger = self.useFixture(TwistedLoggerFixture())
    service = RegionService(sentinel.ipcWorker)
    service.running = True  # Pretend the service is running.
    service.factory.protocol = HandshakingRegionServer
    # buildProtocol ignores its address argument.
    proto = service.factory.buildProtocol(addr=None)
    proto.transport = MagicMock()
    auth_error_type = factory.make_exception_type()
    self.patch_authenticate_for_error(proto, auth_error_type())
    self.assertDictEqual({}, service.connections)

    wait_for_reactor(proto.connectionMade)()

    # Nothing was registered with the service...
    self.assertDictEqual({}, service.connections)
    # ...the transport was asked to drop the connection...
    self.assertThat(proto.transport.loseConnection, MockCalledOnceWith())
    # ...and the failure was logged.
    self.assertDocTestMatches(
        """\
        Rack controller '...' could not be authenticated; dropping connection.
        Check that /var/lib/maas/secret...""",
        logger.dump())
def test_start_up_binds_first_of_real_endpoint_options(self):
    """With real TCP endpoints, only the first option in a group is bound."""
    service = RegionService(sentinel.ipcWorker)
    # Port 0 lets the OS pick any free (high-numbered) port.
    ephemeral = TCP4ServerEndpoint(reactor, 0)
    # Port 1 needs root (or explicit capabilities), so listening on it
    # would fail. We assume the tests do not run as root, and the bound
    # port number is checked below as well.
    privileged = TCP4ServerEndpoint(reactor, 1)
    service.endpoints = [[ephemeral, privileged]]

    yield service.startService()
    self.addCleanup(wait_for_reactor(service.stopService))

    # Exactly one TCP port has been bound...
    self.assertThat(
        service.ports,
        MatchesAll(HasLength(1), AllMatch(IsInstance(tcp.Port))))
    # ...and it is not port 1, which belt-and-braces confirms that the
    # privileged endpoint was never used.
    (bound,) = service.ports
    self.assertThat(bound.getHost().port, Not(Equals(1)))
def test_removeConnectionFor_is_okay_if_connection_is_not_there(self):
    """Removing a connection that was never added does not raise.

    Afterwards the UUID key is present with an empty set (presumably
    because ``connections`` defaults missing keys to ``set()``).
    """
    service = RegionService(sentinel.ipcWorker)
    node_uuid = factory.make_UUID()
    stranger = DummyConnection()
    service._removeConnectionFor(node_uuid, stranger)
    self.assertEqual({node_uuid: set()}, service.connections)
def test_stopping_logs_errors_when_closing_connections(self):
    """A failure while closing one connection is logged, not propagated."""
    service = RegionService(sentinel.ipcWorker)
    service.starting = Deferred()
    # Tolerate only a cancellation of the pending start-up (presumably
    # triggered by stopService()); any other error still surfaces.
    service.starting.addErrback(
        lambda failure: failure.trap(CancelledError))
    service.factory.protocol = HandshakingRegionServer
    build = service.factory.buildProtocol
    protocols = {build(None), build(None)}
    for proto in protocols:
        broken_transport = self.patch(proto, "transport")
        broken_transport.loseConnection.side_effect = OSError("broken")
        # Register it as if it had already connected.
        service.connections[proto.ident].add(proto)

    logger = self.useFixture(TwistedLoggerFixture())

    # stopService() still completes without an error...
    yield service.stopService()

    # ...while each per-connection failure ends up in the log.
    self.assertDocTestMatches(
        """\
        Failure when closing RPC connection.
        Traceback (most recent call last):
        ...
        builtins.OSError: broken
        ...
        Failure when closing RPC connection.
        Traceback (most recent call last):
        ...
        builtins.OSError: broken
        """,
        logger.dump())
def test_getClientFor_errors_when_no_connections_for_cluster(self):
    """getClientFor fails with NoConnectionsAvailable for an empty cluster."""
    service = RegionService(sentinel.ipcWorker)
    cluster_uuid = factory.make_UUID()
    # The cluster's key exists but its connection set is empty.
    service.connections[cluster_uuid].clear()
    pending_client = service.getClientFor(cluster_uuid, timeout=0)
    return assert_fails_with(
        pending_client, exceptions.NoConnectionsAvailable)
def test_connectionMade_does_not_update_services_connection_set(self):
    """connectionMade() alone does not register the protocol."""
    service = RegionService(sentinel.ipcWorker)
    service.running = True  # Pretend the service is running.
    service.factory.protocol = HandshakingRegionServer
    # buildProtocol ignores its address argument.
    proto = service.factory.buildProtocol(addr=None)
    empty = {}
    self.assertDictEqual(empty, service.connections)
    proto.connectionMade()
    # Still empty: presumably registration only happens later, e.g.
    # after the handshake completes.
    self.assertDictEqual(empty, service.connections)
def test_removeConnectionFor_fires_disconnected_event(self):
    """_removeConnectionFor fires ``events.disconnected`` once with the UUID."""
    service = RegionService(sentinel.ipcWorker)
    fire = self.patch(service.events.disconnected, "fire")
    node_uuid = factory.make_UUID()
    service._removeConnectionFor(node_uuid, DummyConnection())
    self.assertThat(fire, MockCalledOnceWith(node_uuid))
def test_addConnectionFor_adds_connection(self):
    """Every connection added for a UUID is recorded under that UUID."""
    service = RegionService(sentinel.ipcWorker)
    node_uuid = factory.make_UUID()
    added = [DummyConnection(), DummyConnection()]
    for conn in added:
        service._addConnectionFor(node_uuid, conn)
    self.assertEqual({node_uuid: set(added)}, service.connections)
def test_stopping_when_start_up_failed(self):
    """stopService() succeeds even when startService() failed."""
    service = RegionService(sentinel.ipcWorker)
    # Make listen() on the only endpoint fail with an obvious error.
    [[endpoint]] = self.patch(service, "endpoints", [[Mock()]])
    endpoint.listen.return_value = fail(
        ValueError("This is a very naughty boy."))
    service.startService()
    # The test passes as long as this Deferred does not errback.
    return service.stopService()
def test_connectionMade_drops_connection_if_service_not_running(self):
    """A service that is not running refuses new connections."""
    service = RegionService(sentinel.ipcWorker)
    service.running = False  # Pretend the service is not running.
    service.factory.protocol = HandshakingRegionServer
    # buildProtocol ignores its address argument.
    proto = service.factory.buildProtocol(addr=None)
    fake_transport = self.patch(proto, "transport")
    self.assertDictEqual({}, service.connections)

    proto.connectionMade()

    # The protocol is not registered...
    self.assertDictEqual({}, service.connections)
    # ...and its transport is told to drop the connection.
    self.assertThat(fake_transport.loseConnection, MockCalledOnceWith())
def test_start_up_binds_first_of_endpoint_options(self):
    """When every option would work, only the first one is bound."""
    service = RegionService(sentinel.ipcWorker)
    first, second = Mock(), Mock()
    first.listen.return_value = succeed(sentinel.port1)
    second.listen.return_value = succeed(sentinel.port2)
    service.endpoints = [[first, second]]

    yield service.startService()

    self.assertThat(service.ports, Equals([sentinel.port1]))
def test_start_up_binds_first_successful_of_endpoint_options(self):
    """A failing option is skipped in favour of the next one that works."""
    service = RegionService(sentinel.ipcWorker)
    failing, working = Mock(), Mock()
    failing.listen.return_value = fail(factory.make_exception())
    working.listen.return_value = succeed(sentinel.port)
    service.endpoints = [[failing, working]]

    yield service.startService()

    self.assertThat(service.ports, Equals([sentinel.port]))
def test_worker_registers_and_unregister_rpc_connection(self):
    """An RPC connection registered via the IPC worker is persisted and removed.

    The test starts an IPC master, an IPC worker and a RegionService. It
    registers a rack RPC connection through the worker and checks that a
    RegionRackRPCConnection row exists and that the master tracks it.
    It then unregisters the connection and checks the row is gone.
    """
    yield deferToDatabase(load_builtin_scripts)
    # Fix the worker PID so database rows and master state can be looked
    # up by it.
    pid = random.randint(1, 512)
    self.patch(os, "getpid").return_value = pid
    (
        master,
        connected,
        disconnected,
    ) = self.make_IPCMasterService_with_wrap()
    # Lets the test wait until the master has handled registerWorkerRPC.
    rpc_started = self.wrap_async_method(master, "registerWorkerRPC")
    yield master.startService()

    worker = IPCWorkerService(reactor, socket_path=self.ipc_path)
    rpc = RegionService(worker)
    yield worker.startService()
    yield rpc.startService()

    # Wait for the worker to connect to the master and register its RPC.
    yield connected.get(timeout=2)
    yield rpc_started.get(timeout=2)

    rackd = yield deferToDatabase(factory.make_RackController)
    connid = str(uuid.uuid4())
    address = factory.make_ipv4_address()
    port = random.randint(1000, 5000)
    yield worker.rpcRegisterConnection(connid, rackd.system_id, address, port)

    def getConnection():
        # Runs in a database thread (via deferToDatabase). Returns the
        # matching RegionRackRPCConnection, or None if there is none.
        region = RegionController.objects.get_running_controller()
        process = RegionControllerProcess.objects.get(region=region, pid=pid)
        endpoint = RegionControllerProcessEndpoint.objects.get(
            process=process, address=address, port=port)
        return RegionRackRPCConnection.objects.filter(
            endpoint=endpoint, rack_controller=rackd).first()

    connection = yield deferToDatabase(getConnection)
    self.assertIsNotNone(connection)
    # The master also tracks the connection in memory, keyed by worker PID.
    self.assertEqual(
        {connid: (rackd.system_id, address, port)},
        master.connections[pid]["rpc"]["connections"],
    )

    yield worker.rpcUnregisterConnection(connid)
    connection = yield deferToDatabase(getConnection)
    self.assertIsNone(connection)

    # Shut down in reverse order of start-up. Note: none of this runs if an
    # assertion above fails.
    yield rpc.stopService()
    yield worker.stopService()
    yield disconnected.get(timeout=2)
    yield master.stopService()
def test_getConnectionFor_cancels_waiter_when_it_times_out(self):
    """A waiter that times out is cancelled and then unregistered."""
    service = RegionService(sentinel.ipcWorker)
    node_uuid = factory.make_UUID()
    waiter = service._getConnectionFor(node_uuid, 1)
    # The pending Deferred is registered as a waiter for this UUID.
    self.assertEqual({node_uuid: {waiter}}, service.waiters)

    def check_unregistered(_):
        # Once it has timed out, the waiter has been removed again.
        self.assertEqual({node_uuid: set()}, service.waiters)

    cancelled = assert_fails_with(waiter, CancelledError)
    return cancelled.addCallback(check_unregistered)
def test_connectionLost_updates_services_connection_set(self):
    """connectionLost removes the protocol from the service's connections."""
    service = RegionService(sentinel.ipcWorker)
    service.running = True  # Pretend the service is running.
    service.factory.protocol = HandshakingRegionServer
    # buildProtocol ignores its address argument.
    proto = service.factory.buildProtocol(addr=None)
    proto.ident = factory.make_name("node")
    super_connectionLost = self.patch(amp.AMP, "connectionLost")
    service.connections[proto.ident] = {proto}

    proto.connectionLost(reason=None)

    # The protocol is gone from the set, but the key itself is kept.
    self.assertDictEqual({proto.ident: set()}, service.connections)
    # The AMP base-class implementation is still invoked.
    self.assertThat(super_connectionLost, MockCalledOnceWith(None))
def test_getConnectionFor_returns_existing_connection(self):
    """_getConnectionFor returns an already-registered connection.

    No waiter is registered, because a connection is available at once.
    """
    service = RegionService(sentinel.ipcWorker)
    uuid = factory.make_UUID()
    conn = DummyConnection()
    service._addConnectionFor(uuid, conn)

    d = service._getConnectionFor(uuid, 1)
    # No waiter is added because a connection is available.
    self.assertEqual({uuid: set()}, service.waiters)

    def check(conn_returned):
        self.assertEqual(conn, conn_returned)

    return d.addCallback(check)
def test_startService_returns_Deferred(self):
    """startService() returns the service's ``starting`` Deferred."""
    service = RegionService(sentinel.ipcWorker)
    # With no endpoints configured, nothing gets bound.
    self.patch(service, "endpoints", [])
    starting = service.startService()
    self.assertThat(starting, IsInstance(Deferred))
    # It is the very same object as ``service.starting``.
    self.assertIs(service.starting, starting)
    return starting.addCallback(lambda _: service.stopService())
def make_Region(self, ipcWorker=None):
    """Build a RegionServer whose factory carries a RegionService.

    :param ipcWorker: IPC worker for the RegionService; defaults to
        ``sentinel.ipcWorker`` when omitted or None.
    :return: the RegionServer instance.
    """
    region = RegionServer()
    region.factory = Factory.forProtocol(RegionServer)
    region.factory.service = RegionService(
        sentinel.ipcWorker if ipcWorker is None else ipcWorker)
    return region
def test_start_up_errors_are_logged(self):
    """An endpoint that fails to listen has its failure logged."""
    service = RegionService(MagicMock())
    # Make listen() on the only endpoint fail with an obvious error.
    error = ValueError("This is not the messiah.")
    [[endpoint]] = self.patch(service, "endpoints", [[Mock()]])
    endpoint.listen.return_value = fail(error)

    with TwistedLoggerFixture() as logger:
        yield service.startService()

    # Exactly that one failure was logged.
    self.assertThat(
        logger.failures,
        MatchesListwise([
            AfterPreprocessing(lambda failure: failure.value, Is(error)),
        ]))
def test_start_up_can_be_cancelled(self):
    """Cancelling start-up leaves no ports bound and fires with None."""
    service = RegionService(sentinel.ipcWorker)
    # listen() hands back a Deferred that never fires on its own.
    [[endpoint]] = self.patch(service, "endpoints", [[Mock()]])
    endpoint.listen.return_value = Deferred()

    service.startService()
    self.assertThat(service.starting, IsInstance(Deferred))
    service.starting.cancel()

    def after_cancel(port):
        self.assertThat(port, Is(None))
        self.assertThat(service.ports, HasLength(0))
        return service.stopService()

    return service.starting.addCallback(after_cancel)
def test_addConnectionFor_notifies_waiters(self):
    """Registered waiters are called back with each new connection."""
    service = RegionService(sentinel.ipcWorker)
    node_uuid = factory.make_UUID()
    c1, c2 = DummyConnection(), DummyConnection()
    waiter1, waiter2 = Mock(), Mock()
    for waiter in (waiter1, waiter2):
        service.waiters[node_uuid].add(waiter)

    for conn in (c1, c2):
        service._addConnectionFor(node_uuid, conn)

    self.assertEqual({node_uuid: {c1, c2}}, service.connections)
    # A mock waiter stays registered and so sees both connections. A real
    # waiter would be called only once, because it unregisters itself as
    # soon as it is called.
    self.assertThat(waiter1.callback, MockCallsMatch(call(c1), call(c2)))
    self.assertThat(waiter2.callback, MockCallsMatch(call(c1), call(c2)))
def test_getConnectionFor_waits_for_connection(self):
    """_getConnectionFor waits until a connection for the UUID appears."""
    service = RegionService(sentinel.ipcWorker)
    node_uuid = factory.make_UUID()
    conn = DummyConnection()
    # Schedule the connection to be added later. We are in the reactor
    # thread, so this cannot happen until after this method returns.
    reactor.callLater(0, service._addConnectionFor, node_uuid, conn)

    pending = service._getConnectionFor(node_uuid, 1)
    # A waiter is registered for the UUID we care about.
    self.assertEqual({node_uuid: {pending}}, service.waiters)

    def check(conn_returned):
        self.assertEqual(conn, conn_returned)
        # The waiter has been unregistered again.
        self.assertEqual({node_uuid: set()}, service.waiters)

    return pending.addCallback(check)
def test_stopping_closes_connections_cleanly(self):
    """stopService() closes the transport of every registered connection."""
    service = RegionService(sentinel.ipcWorker)
    service.starting = Deferred()
    # Tolerate only a cancellation of the pending start-up.
    service.starting.addErrback(
        lambda failure: failure.trap(CancelledError))
    service.factory.protocol = HandshakingRegionServer
    build = service.factory.buildProtocol
    protocols = {build(None), build(None)}
    for proto in protocols:
        # Register it as if it had already connected.
        service.connections[proto.ident].add(proto)
    transports = {self.patch(proto, "transport") for proto in protocols}

    yield service.stopService()

    # Every transport had loseConnection() called exactly once.
    closed_once = AfterPreprocessing(
        attrgetter("loseConnection"), MockCalledOnceWith())
    self.assertThat(transports, AllMatch(closed_once))
def test_init_sets_appropriate_instance_attributes(self):
    """A fresh RegionService has its connections, endpoints, factory and events set up."""
    service = RegionService(sentinel.ipcWorker)
    self.assertThat(service, IsInstance(Service))

    # connections is a defaultdict whose missing keys default to set().
    connections = service.connections
    self.assertThat(connections, IsInstance(defaultdict))
    self.assertThat(connections.default_factory, Is(set))

    # endpoints is a list of option groups of stream server endpoints.
    self.assertThat(
        service.endpoints,
        AllMatch(AllMatch(Provides(IStreamServerEndpoint))))

    self.assertThat(service.factory, IsInstance(Factory))
    self.assertThat(service.factory.protocol, Equals(RegionServer))

    for event in (service.events.connected, service.events.disconnected):
        self.assertThat(event, IsInstance(events.Event))
def test_connectionLost_uses_ipcWorker_to_unregister(self):
    """connectionLost unregisters the connection through the IPC worker."""
    ipcWorker = MagicMock()
    service = RegionService(ipcWorker)
    service.running = True  # Pretend the service is running.
    service.factory.protocol = HandshakingRegionServer
    # buildProtocol ignores its address argument.
    proto = service.factory.buildProtocol(addr=None)
    proto.ident = factory.make_name("node")
    proto.host = Mock(host=sentinel.host, port=sentinel.port)
    proto.hostIsRemote = True
    super_connectionLost = self.patch(amp.AMP, "connectionLost")
    service.connections[proto.ident] = {proto}

    proto.connectionLost(reason=None)

    self.assertThat(
        ipcWorker.rpcUnregisterConnection,
        MockCalledOnceWith(proto.connid))
    # The protocol is gone from the set, but the key itself is kept.
    self.assertDictEqual({proto.ident: set()}, service.connections)
    # The AMP base-class implementation is still invoked.
    self.assertThat(super_connectionLost, MockCalledOnceWith(None))
def test_getClientFor_returns_random_connection(self):
    """getClientFor wraps a randomly chosen connection for the UUID."""
    c1, c2 = DummyConnection(), DummyConnection()
    chosen = DummyConnection()
    service = RegionService(sentinel.ipcWorker)
    node_uuid = factory.make_UUID()
    candidates = service.connections[node_uuid]
    candidates.update({c1, c2})

    def fake_choice(choices):
        # The choice must be made from exactly this UUID's connections.
        self.assertItemsEqual(choices, candidates)
        return chosen

    self.patch(random, "choice", fake_choice)

    def check(client):
        self.assertThat(client, Equals(RackClient(chosen, {})))
        # The client uses the service's cache for that connection.
        self.assertIs(client.cache, service.connectionsCache[client._conn])

    return service.getClientFor(node_uuid).addCallback(check)
def test_getAllClients(self):
    """getAllClients yields one client per UUID, using one of its connections."""
    service = RegionService(sentinel.ipcWorker)
    uuid1 = factory.make_UUID()
    c1, c2 = DummyConnection(), DummyConnection()
    service.connections[uuid1].update({c1, c2})
    uuid2 = factory.make_UUID()
    c3, c4 = DummyConnection(), DummyConnection()
    service.connections[uuid2].update({c3, c4})

    clients = service.getAllClients()

    def pair(first, second):
        # Matches a collection holding exactly these two clients.
        return MatchesSetwise(
            Equals(RackClient(first, {})),
            Equals(RackClient(second, {})))

    # Any combination of one connection from each UUID is acceptable.
    self.assertThat(
        list(clients),
        MatchesAny(*[
            pair(first, second)
            for first in (c1, c2)
            for second in (c3, c4)
        ]))
def test_start_up_logs_failure_if_all_endpoint_options_fail(self):
    """When every option in a group fails, the failure is logged.

    The expected log shows the exception type of the last option.
    """
    service = RegionService(sentinel.ipcWorker)
    first_error_type = factory.make_exception_type()
    last_error_type = factory.make_exception_type()
    first, last = Mock(), Mock()
    first.listen.return_value = fail(first_error_type())
    last.listen.return_value = fail(last_error_type())
    service.endpoints = [[first, last]]

    with TwistedLoggerFixture() as logger:
        yield service.startService()

    expected = """\
        RegionServer endpoint failed to listen.
        Traceback (most recent call last):
        ...
        %s:
        """ % fullyQualifiedName(last_error_type)
    self.assertDocTestMatches(expected, logger.output)
def test_getConnectionFor_with_concurrent_waiters(self):
    """Concurrent waiters for one UUID all receive the same connection."""
    service = RegionService(sentinel.ipcWorker)
    node_uuid = factory.make_UUID()
    conn = DummyConnection()
    # Schedule the connection to be added later. We are in the reactor
    # thread, so this cannot happen until after this method returns.
    reactor.callLater(0, service._addConnectionFor, node_uuid, conn)

    waiters = [service._getConnectionFor(node_uuid, 1) for _ in range(2)]
    # Each _getConnectionFor() call registers its own waiter.
    self.assertEqual({node_uuid: set(waiters)}, service.waiters)

    def check(results):
        self.assertEqual([(True, conn)] * 2, results)
        # Both waiters have been unregistered again.
        self.assertEqual({node_uuid: set()}, service.waiters)

    return DeferredList(waiters).addCallback(check)
def test_worker_registers_rpc_endpoints(self):
    """Starting a worker's RegionService records its RPC endpoints.

    The test starts an IPC master, an IPC worker and a RegionService. It
    then checks that the RegionControllerProcessEndpoint rows for the
    worker's process match the listen addresses the master derives from
    the worker's RPC port.
    """
    yield deferToDatabase(load_builtin_scripts)
    # Fix the worker PID so database rows and master state can be looked
    # up by it.
    pid = random.randint(1, 512)
    self.patch(os, "getpid").return_value = pid
    (
        master,
        connected,
        disconnected,
    ) = self.make_IPCMasterService_with_wrap()
    # Lets the test wait until the master has handled registerWorkerRPC.
    rpc_started = self.wrap_async_method(master, "registerWorkerRPC")
    yield master.startService()

    worker = IPCWorkerService(reactor, socket_path=self.ipc_path)
    rpc = RegionService(worker)
    yield worker.startService()
    yield rpc.startService()

    # Wait for the worker to connect to the master and register its RPC.
    yield connected.get(timeout=2)
    yield rpc_started.get(timeout=2)

    def getEndpoints():
        # Runs in a database thread (via deferToDatabase). Returns the
        # process's endpoints as a set of (address, port) pairs.
        region = RegionController.objects.get_running_controller()
        process = RegionControllerProcess.objects.get(region=region, pid=pid)
        return {
            (endpoint.address, endpoint.port)
            for endpoint in RegionControllerProcessEndpoint.objects.filter(
                process=process)
        }

    endpoints = yield deferToDatabase(getEndpoints)
    self.assertEqual(
        master._getListenAddresses(master.connections[pid]["rpc"]["port"]),
        endpoints,
    )

    # Shut down in reverse order of start-up.
    yield rpc.stopService()
    yield worker.stopService()
    yield disconnected.get(timeout=2)
    yield master.stopService()