def testRegister(self):
    """Register three data servers with a master and cycle one of them.

    A master is created on port 7000 and servers on ports 7001-7003 are
    registered one by one. Each returned entry must echo the address and
    port it was registered with and carry a 1-based index in registration
    order. AllRegistered() must stay false until the third registration.
    Registering port 7004 must yield a falsy result, and deregistering and
    re-registering the first server must toggle AllRegistered().
    """
    localhost = "127.0.0.1"
    data_master = master.DataMaster(7000, self.mock_service)
    self.assertNotEqual(data_master, None)

    registered = []
    for expected_index, port in enumerate((7001, 7002, 7003), 1):
        # Until this registration happens at least one server is missing.
        self.assertFalse(data_master.AllRegistered())
        entry = data_master.RegisterServer(localhost, port)
        self.assertNotEqual(entry, None)
        self.assertEqual(entry.Address(), localhost)
        self.assertEqual(entry.Port(), port)
        self.assertEqual(entry.Index(), expected_index)
        registered.append(entry)
    self.assertTrue(data_master.AllRegistered())

    # Try to register something that does not exist.
    self.assertFalse(data_master.RegisterServer(localhost, 7004))

    # Deregister a server.
    first = registered[0]
    data_master.DeregisterServer(first)
    self.assertFalse(data_master.AllRegistered())

    # Register again.
    data_master.RegisterServer(first.Address(), first.Port())
    self.assertTrue(data_master.AllRegistered())
def InitMasterServer(cls, port):
    """Set up class-level state for running as the data master.

    Creates the DataMaster for `port`, stores it and its loaded mapping on
    the class, and installs client credentials read from the configuration
    into the class nonce store.

    Args:
      port: Port passed to the DataMaster constructor and reported in the
        startup log line.
    """
    data_master = master.DataMaster(port, cls.SERVICE)
    cls.MASTER = data_master
    cls.MAPPING = data_master.LoadMapping()

    # Master is the only data server that knows about the client
    # credentials. The credentials will be sent to other data servers once
    # they login.
    client_credentials = auth.ClientCredentials()
    client_credentials.InitializeFromConfig()
    cls.NONCE_STORE.SetClientCredentials(client_credentials)

    logging.info("Starting Data Master/Server on port %d ...", port)
def testMapping(self):
    """Check that the mapping is valid.

    Registers three servers with a master on port 7000 and verifies that
    the loaded mapping lists four servers (master at index 0), that each
    registered server's entry matches the object returned by
    RegisterServer(), that the servers' intervals are contiguous, equally
    wide and end at constants.MAX_RANGE, and that
    utils._FindServerInMapping() resolves probe keys to the expected
    server index.
    """
    m = master.DataMaster(7000, self.mock_service)
    self.assertIsNotNone(m)
    server1 = m.RegisterServer("127.0.0.1", 7001)
    server2 = m.RegisterServer("127.0.0.1", 7002)
    server3 = m.RegisterServer("127.0.0.1", 7003)
    self.assertTrue(m.AllRegistered())

    mapping = m.LoadMapping()
    self.assertIsNotNone(mapping)
    self.assertEqual(mapping.num_servers, 4)
    self.assertEqual(len(mapping.servers), 4)

    # Check server information. Slot 0 is the master itself.
    self.assertEqual(mapping.servers[0].address, "127.0.0.1")
    self.assertEqual(mapping.servers[0].port, 7000)
    self.assertEqual(mapping.servers[0].index, 0)
    for idx, server in enumerate((server1, server2, server3), 1):
        self.assertEqual(mapping.servers[idx].address, server.Address())
        self.assertEqual(mapping.servers[idx].port, server.Port())
        self.assertEqual(mapping.servers[idx].index, server.Index())
        self.assertTrue(server.IsRegistered())

    # Check intervals: adjacent servers must share a boundary.
    interval1 = server1.Interval()
    interval2 = server2.Interval()
    interval3 = server3.Interval()
    self.assertEqual(interval1.end, interval2.start)
    self.assertEqual(interval2.end, interval3.start)
    # Every server must own a range of the same width. The width is
    # end - start; subtracting end from end would always give 0 and make
    # the comparison vacuous.
    width = interval1.end - interval1.start
    self.assertEqual(width, interval2.end - interval2.start)
    self.assertEqual(width, interval3.end - interval3.start)
    self.assertEqual(interval3.end, constants.MAX_RANGE)

    # Check that mapping to a server works. Floor division keeps the probe
    # keys integral regardless of the interpreter's `/` semantics.
    max_range = constants.MAX_RANGE
    half_fifth = max_range // 2 + max_range // 5
    probes = [
        (0x0, 0),
        (max_range // 4, 1),
        (max_range // 4 + 1, 1),
        (max_range // 2, 2),
        (half_fifth, 2),
        (max_range // 4 * 3, 3),
        (max_range, 3),
    ]
    for key, expected_index in probes:
        self.assertEqual(
            utils._FindServerInMapping(mapping, key), expected_index)
def testRegister(self):
    """Create master and register other servers.

    The first half registers every port in self.ports except the first
    (used for the master) and checks the returned entries, the rejection
    of an unconfigured port, and a deregister/re-register cycle.

    The second half stubs urllib3's HTTPConnectionPool and, for every port
    and every rejection status, checks that StandardDataServer._DoRegister
    raises DataServerError after issuing exactly two POST requests: one to
    /server/handshake and one to /server/register with a parseable
    DataStoreRegistrationRequest body.
    """
    m = master.DataMaster(self.ports[0], self.mock_service)
    self.assertIsNotNone(m)
    self.assertFalse(m.AllRegistered())

    # Index 0 is a placeholder so list positions line up with server
    # indices (the master occupies index 0).
    servers = [None]
    for i, port in enumerate(self.ports):
        if i == 0:
            # Skip master server.
            continue
        self.assertFalse(m.AllRegistered())
        server = m.RegisterServer(self.host, port)
        servers.append(server)
        self.assertIsNotNone(server)
        self.assertEqual(server.Address(), self.host)
        self.assertEqual(server.Port(), port)
        self.assertEqual(server.Index(), i)
    self.assertTrue(m.AllRegistered())

    # Try to register something that does not exist. Start from 7004 but
    # step past any port that is actually part of the test's server list,
    # so the negative check stays valid whatever self.ports contains.
    unknown_port = 7004
    while unknown_port in self.ports:
        unknown_port += 1
    self.assertFalse(m.RegisterServer(self.host, unknown_port))

    # Deregister a server.
    m.DeregisterServer(servers[1])
    self.assertFalse(m.AllRegistered())

    # Register again.
    m.RegisterServer(servers[1].Address(), servers[1].Port())
    self.assertTrue(m.AllRegistered())

    # Each sequence is a successful first response followed by a rejection.
    # NOTE(review): presumably the mock pool serves these in order, i.e.
    # handshake succeeds and register is refused - confirm against
    # GetMockHTTPConnectionPoolClass.
    response_sequences = [
        [constants.RESPONSE_OK, constants.RESPONSE_SERVER_NOT_AUTHORIZED],
        [constants.RESPONSE_OK, constants.RESPONSE_SERVER_NOT_ALLOWED],
        [constants.RESPONSE_OK, constants.RESPONSE_NOT_MASTER_SERVER],
    ]
    for port in self.ports:
        for response_sequence in response_sequences:
            response_mocks = [
                MockResponse(status) for status in response_sequence
            ]
            pool_class = GetMockHTTPConnectionPoolClass(response_mocks)

            with libutils.Stubber(urllib3.connectionpool,
                                  "HTTPConnectionPool", pool_class):
                # A distinct name keeps the DataMaster in `m` from being
                # shadowed by the data server under test.
                data_srv = data_server.StandardDataServer(
                    port, data_server.DataServerHandler)
                data_srv.handler_cls.NONCE_STORE = auth.NonceStore()

                self.assertRaises(errors.DataServerError,
                                  data_srv._DoRegister)

                # Ensure two requests have been made.
                self.assertEqual(len(pool_class.requests), 2)

                # Ensure the register body is non-empty.
                self.assertTrue(pool_class.requests[1]["body"])

                # Ensure that the register body is a valid rdfvalue.
                request_cls = rdf_data_server.DataStoreRegistrationRequest
                request_cls.FromSerializedString(
                    pool_class.requests[1]["body"])

                # Ensure the requests are POST requests.
                self.assertEqual(pool_class.requests[0]["method"], "POST")
                self.assertEqual(pool_class.requests[1]["method"], "POST")

                # Ensure the correct URLs are hit according to the API.
                self.assertEqual(pool_class.requests[0]["url"],
                                 "/server/handshake")
                self.assertEqual(pool_class.requests[1]["url"],
                                 "/server/register")