def testPreProc(self):
  """Tests the per-node pre-processing function of a call definition.

  The fake request handler echoes the posted body back as the result, so
  each node's payload is expected to equal what C{_PreProc} returned for
  that node (the single argument concatenated with the node name).

  """
  def _VerifyRequest(req):
    # Echo the request body back as a successful result
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson((True, req.post_data))

  resolver = rpc._StaticResolver([
    "192.0.2.30",
    "192.0.2.35",
    ])

  nodes = [
    "node30.example.com",
    "node35.example.com",
    ]

  def _PreProc(node, data):
    # The call definition declares exactly one argument
    self.assertEqual(len(data), 1)
    return data[0] + node

  cdef = ("test_call", NotImplemented, None, constants.RPC_TMO_NORMAL, [
    ("arg0", None, NotImplemented),
    ], _PreProc, None, NotImplemented)

  http_proc = _FakeRequestProcessor(_VerifyRequest)
  client = rpc._RpcClientBase(resolver, NotImplemented,
                              _req_process_fn=http_proc)

  for prefix in ["foo", "bar", "baz"]:
    result = client._Call(cdef, nodes, [prefix])
    self.assertEqual(len(result), len(nodes))
    for (node, res) in result.items():
      self.assertFalse(res.fail_msg)
      self.assertEqual(serializer.LoadJson(res.payload), prefix + node)
def testNoHosts(self):
  """Tests a call with no target nodes and a wrong argument count.

  An empty node list must produce an empty result dictionary; passing more
  arguments than the call definition declares must raise
  L{errors.ProgrammerError}.

  """
  call_def = ("test_call", NotImplemented, None, constants.RPC_TMO_SLOW, [],
              None, None, NotImplemented)
  fake_proc = _FakeRequestProcessor(NotImplemented)
  empty_resolver = rpc._StaticResolver([])
  client = rpc._RpcClientBase(empty_resolver, NotImplemented,
                              _req_process_fn=fake_proc)

  self.assertEqual(client._Call(call_def, [], []), {})

  # No arguments are declared above, so three passed values are too many
  self.assertRaises(errors.ProgrammerError, client._Call, call_def, [],
                    [0, 1, 2])
def testVersionSuccess(self):
  """Tests a successful single-node "version" call.

  Expects a single result for "localhost" with payload 123, no failure
  message, and exactly one request sent.

  """
  resolver = rpc._StaticResolver(["127.0.0.1"])
  http_proc = _FakeRequestProcessor(self._GetVersionResponse)
  proc = rpc._RpcProcessor(resolver, 24094)
  result = proc(["localhost"], "version", {"localhost": ""}, 60,
                NotImplemented, _req_process_fn=http_proc)
  # list() keeps the comparison valid where dict.keys() returns a view
  self.assertEqual(list(result.keys()), ["localhost"])
  lhresp = result["localhost"]
  self.assertFalse(lhresp.offline)
  self.assertEqual(lhresp.node, "localhost")
  self.assertFalse(lhresp.fail_msg)
  self.assertEqual(lhresp.payload, 123)
  self.assertEqual(lhresp.call, "version")
  lhresp.Raise("should not raise")
  self.assertEqual(http_proc.reqcount, 1)
def testReadTimeout(self):
  """Tests a "version" call answered by C{_ReadTimeoutResponse}.

  The result is expected to be a success with payload -1, and exactly one
  request is expected to be sent.

  """
  resolver = rpc._StaticResolver(["192.0.2.13"])
  http_proc = _FakeRequestProcessor(self._ReadTimeoutResponse)
  proc = rpc._RpcProcessor(resolver, 19176)
  host = "node31856"
  body = {host: ""}
  result = proc([host], "version", body, 12356, NotImplemented,
                _req_process_fn=http_proc)
  # list() keeps the comparison valid where dict.keys() returns a view
  self.assertEqual(list(result.keys()), [host])
  lhresp = result[host]
  self.assertFalse(lhresp.offline)
  self.assertEqual(lhresp.node, host)
  self.assertFalse(lhresp.fail_msg)
  self.assertEqual(lhresp.payload, -1)
  self.assertEqual(lhresp.call, "version")
  lhresp.Raise("should not raise")
  self.assertEqual(http_proc.reqcount, 1)
def testInvalidResponse(self):
  """Tests that invalid server responses are reported as failed results.

  For each of the C{_GetInvalidResponse*} handlers, the result must carry a
  failure message and an empty payload, and C{Raise} must raise
  L{errors.OpExecError}.

  """
  resolver = rpc._StaticResolver(["oqo7lanhly.example.com"])
  proc = rpc._RpcProcessor(resolver, 19978)

  for fn in [self._GetInvalidResponseA, self._GetInvalidResponseB]:
    http_proc = _FakeRequestProcessor(fn)

    host = "oqo7lanhly.example.com"
    body = {host: ""}

    result = proc([host], "version", body, 60, NotImplemented,
                  _req_process_fn=http_proc)
    # list() keeps the comparison valid where dict.keys() returns a view
    self.assertEqual(list(result.keys()), [host])
    lhresp = result[host]
    self.assertFalse(lhresp.offline)
    self.assertEqual(lhresp.node, host)
    self.assertTrue(lhresp.fail_msg)
    self.assertFalse(lhresp.payload)
    self.assertEqual(lhresp.call, "version")
    self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
    self.assertEqual(http_proc.reqcount, 1)
def testVersionFailure(self):
  """Tests "version" calls answered by C{_GetVersionResponseFail}.

  The handler is run once with C{errinfo=None} and once with a message; in
  both cases the result must carry a failure message and an empty payload,
  and C{Raise} must raise L{errors.OpExecError}.

  """
  resolver = rpc._StaticResolver(["aef9ur4i.example.com"])
  proc = rpc._RpcProcessor(resolver, 5903)

  for errinfo in [None, "Unknown error"]:
    http_proc = _FakeRequestProcessor(
      compat.partial(self._GetVersionResponseFail, errinfo))

    host = "aef9ur4i.example.com"
    body = {host: ""}
    result = proc(body.keys(), "version", body, 60, NotImplemented,
                  _req_process_fn=http_proc)
    # list() keeps the comparison valid where dict.keys() returns a view
    self.assertEqual(list(result.keys()), [host])
    lhresp = result[host]
    self.assertFalse(lhresp.offline)
    self.assertEqual(lhresp.node, host)
    self.assertTrue(lhresp.fail_msg)
    self.assertFalse(lhresp.payload)
    self.assertEqual(lhresp.call, "version")
    self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
    self.assertEqual(http_proc.reqcount, 1)
def testMultiVersionSuccess(self):
  """Tests a successful "version" call sent to fifty nodes at once.

  Every node must get its own successful result with payload 987, and one
  request per node must have been processed.

  """
  node_names = ["node%s" % num for num in range(50)]
  request_body = dict.fromkeys(node_names, "")
  resolver = rpc._StaticResolver(node_names)
  http_proc = _FakeRequestProcessor(self._GetMultiVersionResponse)
  proc = rpc._RpcProcessor(resolver, 23245)

  result = proc(node_names, "version", request_body, 60, NotImplemented,
                _req_process_fn=http_proc)
  self.assertEqual(sorted(result), sorted(node_names))

  for name in node_names:
    resp = result[name]
    self.assertFalse(resp.offline)
    self.assertEqual(resp.node, name)
    self.assertFalse(resp.fail_msg)
    self.assertEqual(resp.payload, 987)
    self.assertEqual(resp.call, "version")
    resp.Raise("should not raise")

  self.assertEqual(http_proc.reqcount, len(node_names))
def testHttpError(self):
  """Tests a multi-node call mixing HTTP errors, failures and successes.

  Nodes in C{httperrnodes} and C{failnodes} are handed to
  C{_GetHttpErrorResponse}. The former are expected to raise
  L{errors.OpExecError} from C{Raise}; the latter are expected to honour
  C{prereq}/C{ecode} and raise L{errors.OpPrereqError}. All remaining nodes
  must succeed with a payload of C{hash(name)}.

  """
  nodes = ["uaf6pbbv%s" % i for i in range(50)]
  body = dict((n, "") for n in nodes)
  resolver = rpc._StaticResolver(nodes)

  # Every 7th node, starting at index 1
  httperrnodes = set(nodes[1::7])
  self.assertEqual(len(httperrnodes), 7)

  # Every 3rd node, starting at index 2, minus those already in httperrnodes
  failnodes = set(nodes[2::3]) - httperrnodes
  self.assertEqual(len(failnodes), 14)

  self.assertEqual(len(set(nodes) - failnodes - httperrnodes), 29)

  proc = rpc._RpcProcessor(resolver, 15165)
  http_proc = _FakeRequestProcessor(
    compat.partial(self._GetHttpErrorResponse, httperrnodes, failnodes))
  result = proc(nodes, "vg_list", body, constants.RPC_TMO_URGENT,
                NotImplemented, _req_process_fn=http_proc)
  self.assertEqual(sorted(result.keys()), sorted(nodes))

  for name in nodes:
    lhresp = result[name]
    self.assertFalse(lhresp.offline)
    self.assertEqual(lhresp.node, name)
    self.assertEqual(lhresp.call, "vg_list")

    if name in httperrnodes:
      self.assertTrue(lhresp.fail_msg)
      self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
    elif name in failnodes:
      self.assertTrue(lhresp.fail_msg)
      self.assertRaises(errors.OpPrereqError, lhresp.Raise, "failed",
                        prereq=True, ecode=errors.ECODE_INVAL)
    else:
      self.assertFalse(lhresp.fail_msg)
      self.assertEqual(lhresp.payload, hash(name))
      lhresp.Raise("should not raise")

  self.assertEqual(http_proc.reqcount, len(nodes))
def testArgumentEncoder(self):
  """Tests that typed call arguments are passed through their encoder.

  The client receives C{encoders.get} as its encoder lookup. Arguments
  declared with kind C{AT1} are expected to come back C{hex}-encoded, those
  with C{AT2} C{hash}-encoded, and untyped ones unchanged. The fake handler
  echoes the posted data back.

  """
  (AT1, AT2) = range(1, 3)

  resolver = rpc._StaticResolver([
    "192.0.2.5",
    "192.0.2.6",
    ])

  nodes = [
    "node5.example.com",
    "node6.example.com",
    ]

  encoders = {
    AT1: hex,
    AT2: hash,
    }

  cdef = ("test_call", NotImplemented, None, constants.RPC_TMO_NORMAL, [
    ("arg0", None, NotImplemented),
    ("arg1", AT1, NotImplemented),
    ("arg2", AT2, NotImplemented),
    ], None, None, NotImplemented)

  def _VerifyRequest(req):
    # Echo the (encoded) request body back as a successful result
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson((True, req.post_data))

  http_proc = _FakeRequestProcessor(_VerifyRequest)
  client = rpc._RpcClientBase(resolver, encoders.get,
                              _req_process_fn=http_proc)

  for num in [0, 3796, 9032119]:
    result = client._Call(cdef, nodes, ["foo", num, "Hello%s" % num])
    self.assertEqual(len(result), len(nodes))
    for res in result.values():
      self.assertFalse(res.fail_msg)
      self.assertEqual(serializer.LoadJson(res.payload),
                       ["foo", hex(num), hash("Hello%s" % num)])
def testTimeout(self):
  """Tests fixed and argument-derived read timeouts.

  A call definition's timeout is either a number used as-is or a callable
  applied to the call arguments. The fake handler asserts that the request
  carries the expected read timeout and echoes it back in hex.

  """
  def _CalcTimeout(args):
    # Unpack in the body: tuple parameters are invalid syntax in Python 3
    (arg1, arg2) = args
    return arg1 + arg2

  def _VerifyRequest(exp_timeout, req):
    self.assertEqual(req.read_timeout, exp_timeout)

    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson((True, hex(req.read_timeout)))

  resolver = rpc._StaticResolver([
    "192.0.2.1",
    "192.0.2.2",
    ])

  nodes = [
    "node1.example.com",
    "node2.example.com",
    ]

  # (timeout, first argument, expected read timeout); the second argument
  # is always 300, so _CalcTimeout yields first argument + 300
  tests = [(100, None, 100), (30, None, 30)]
  tests.extend((_CalcTimeout, i, i + 300)
               for i in [0, 5, 16485, 30516])

  for timeout, arg1, exp_timeout in tests:
    cdef = ("test_call", NotImplemented, None, timeout, [
      ("arg1", None, NotImplemented),
      ("arg2", None, NotImplemented),
      ], None, None, NotImplemented)

    http_proc = _FakeRequestProcessor(compat.partial(_VerifyRequest,
                                                     exp_timeout))
    client = rpc._RpcClientBase(resolver, NotImplemented,
                                _req_process_fn=http_proc)
    result = client._Call(cdef, nodes, [arg1, 300])
    self.assertEqual(len(result), len(nodes))
    self.assertTrue(compat.all(not res.fail_msg and
                               res.payload == hex(exp_timeout)
                               for res in result.values()))
def testResponseBody(self):
  """Tests an "upload_file" call whose body is checked by the handler.

  The serialized C{test_data} is sent as the node's body and handed to
  C{_GetBodyTestResponse}. The result is expected to succeed with a C{None}
  payload after exactly one request.

  """
  test_data = {
    "Hello": "World",
    # list(): range() is lazy in Python 3 and not JSON-serialisable
    "xyz": list(range(10)),
    }
  resolver = rpc._StaticResolver(["192.0.2.84"])
  http_proc = _FakeRequestProcessor(compat.partial(self._GetBodyTestResponse,
                                                   test_data))
  proc = rpc._RpcProcessor(resolver, 18700)
  host = "node19759"
  body = {host: serializer.DumpJson(test_data)}
  result = proc([host], "upload_file", body, 30, NotImplemented,
                _req_process_fn=http_proc)
  # list() keeps the comparison valid where dict.keys() returns a view
  self.assertEqual(list(result.keys()), [host])
  lhresp = result[host]
  self.assertFalse(lhresp.offline)
  self.assertEqual(lhresp.node, host)
  self.assertFalse(lhresp.fail_msg)
  self.assertEqual(lhresp.payload, None)
  self.assertEqual(lhresp.call, "upload_file")
  lhresp.Raise("should not raise")
  self.assertEqual(http_proc.reqcount, 1)
def testOfflineNode(self):
  """Tests a call to a node the resolver maps to C{rpc._OFFLINE}.

  The result must be flagged offline with a failure message, C{Raise} must
  raise L{errors.OpExecError} with and without a message, and no request
  may be sent (the request count stays at 0).

  """
  resolver = rpc._StaticResolver([rpc._OFFLINE])

  http_proc = _FakeRequestProcessor(NotImplemented)
  proc = rpc._RpcProcessor(resolver, 30668)
  host = "n17296"
  body = {host: ""}
  result = proc([host], "version", body, 60, NotImplemented,
                _req_process_fn=http_proc)
  # list() keeps the comparison valid where dict.keys() returns a view
  self.assertEqual(list(result.keys()), [host])
  lhresp = result[host]
  self.assertTrue(lhresp.offline)
  self.assertEqual(lhresp.node, host)
  self.assertTrue(lhresp.fail_msg)
  self.assertFalse(lhresp.payload)
  self.assertEqual(lhresp.call, "version")

  # With a message
  self.assertRaises(errors.OpExecError, lhresp.Raise, "should raise")

  # No message
  self.assertRaises(errors.OpExecError, lhresp.Raise, None)

  self.assertEqual(http_proc.reqcount, 0)
def testPostProc(self):
  """Tests the per-result post-processing function of a call definition.

  The fake handler returns a list of random numbers; C{_PostProc} replaces
  each payload with the sum of that list, which the test then checks.

  """
  def _VerifyRequest(nums, req):
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson((True, nums))

  resolver = rpc._StaticResolver(["192.0.2.90", "192.0.2.95"])
  nodes = ["node90.example.com", "node95.example.com"]

  def _PostProc(res):
    self.assertFalse(res.fail_msg)
    res.payload = sum(res.payload)
    return res

  cdef = ("test_call", NotImplemented, None, constants.RPC_TMO_NORMAL, [],
          None, _PostProc, NotImplemented)

  # A fixed seed keeps the generated numbers reproducible between runs
  rng = random.Random(20299)

  for count in [0, 4, 74, 1391]:
    nums = [rng.randint(0, 1000) for _ in range(count)]
    handler = compat.partial(_VerifyRequest, nums)
    http_proc = _FakeRequestProcessor(handler)
    client = rpc._RpcClientBase(resolver, NotImplemented,
                                _req_process_fn=http_proc)

    result = client._Call(cdef, nodes, [])
    self.assertEqual(len(result), len(nodes))

    expected = sum(nums)
    for res in result.values():
      self.assertFalse(res.fail_msg)
      self.assertEqual(res.payload, expected)
def testWrongLength(self):
  """A host count differing from the address count raises AssertionError."""
  no_addresses = rpc._StaticResolver([])
  self.assertRaises(AssertionError, no_addresses, ["abc"], NotImplemented)
def test(self):
  """Tests that the static resolver pairs each host with its address.

  The result is expected to equal the hosts zipped, in order, with the
  configured addresses.

  """
  addresses = ["192.0.2.%d" % n for n in range(0, 123, 7)]
  nodes = ["node%s.example.com" % n for n in range(0, 123, 7)]
  res = rpc._StaticResolver(addresses)
  # list(): zip() returns a lazy iterator in Python 3
  self.assertEqual(res(nodes, NotImplemented), list(zip(nodes, addresses)))