예제 #1
0
  def testPreProc(self):
    """Checks that the C{_PreProc} callable of a call definition is applied.

    The fake request handler echoes the posted data back as the payload, so
    every node's decoded payload must equal the call argument with that
    node's name appended, which is what C{_PreProc} returns.

    """
    def _VerifyRequest(req):
      # Answer successfully and echo the request body back to the caller
      req.success = True
      req.resp_status_code = http.HTTP_OK
      req.resp_body = serializer.DumpJson((True, req.post_data))

    resolver = rpc._StaticResolver([
      "192.0.2.30",
      "192.0.2.35",
      ])

    nodes = [
      "node30.example.com",
      "node35.example.com",
      ]

    def _PreProc(node, data):
      # Exactly one argument is passed to the call below
      self.assertEqual(len(data), 1)
      return data[0] + node

    cdef = ("test_call", NotImplemented, None, constants.RPC_TMO_NORMAL, [
      ("arg0", None, NotImplemented),
      ], _PreProc, None, NotImplemented)

    http_proc = _FakeRequestProcessor(_VerifyRequest)
    client = rpc._RpcClientBase(resolver, NotImplemented,
                                _req_process_fn=http_proc)

    for prefix in ["foo", "bar", "baz"]:
      result = client._Call(cdef, nodes, [prefix])
      self.assertEqual(len(result), len(nodes))
      # Iterate the items directly; a positional index is not needed here
      for (node, res) in result.items():
        self.assertFalse(res.fail_msg)
        self.assertEqual(serializer.LoadJson(res.payload), prefix + node)
예제 #2
0
  def testNoHosts(self):
    """Checks calls made with an empty node list.

    Such a call must return an empty dictionary, and a call whose argument
    count does not match the definition must raise
    C{errors.ProgrammerError}.

    """
    empty_resolver = rpc._StaticResolver([])
    call_def = ("test_call", NotImplemented, None, constants.RPC_TMO_SLOW, [],
                None, None, NotImplemented)
    fake_http = _FakeRequestProcessor(NotImplemented)
    rpc_client = rpc._RpcClientBase(empty_resolver, NotImplemented,
                                    _req_process_fn=fake_http)

    no_node_result = rpc_client._Call(call_def, [], [])
    self.assertEqual(no_node_result, {})

    # The definition declares no arguments, so passing three is an error
    self.assertRaises(errors.ProgrammerError, rpc_client._Call,
                      call_def, [], [0, 1, 2])
예제 #3
0
 def testVersionSuccess(self):
   """Checks a successful "version" call against a single node.

   The response produced by C{self._GetVersionResponse} is expected to yield
   a non-offline result without failure message and with payload 123, and
   exactly one request must have been processed.

   """
   resolver = rpc._StaticResolver(["127.0.0.1"])
   http_proc = _FakeRequestProcessor(self._GetVersionResponse)
   proc = rpc._RpcProcessor(resolver, 24094)
   result = proc(["localhost"], "version", {"localhost": ""}, 60,
                 NotImplemented, _req_process_fn=http_proc)
   # keys() may be a view object rather than a list; normalise before
   # comparing against a list literal
   self.assertEqual(list(result.keys()), ["localhost"])
   lhresp = result["localhost"]
   self.assertFalse(lhresp.offline)
   self.assertEqual(lhresp.node, "localhost")
   self.assertFalse(lhresp.fail_msg)
   self.assertEqual(lhresp.payload, 123)
   self.assertEqual(lhresp.call, "version")
   # Raise() must be a no-op for a successful result
   lhresp.Raise("should not raise")
   self.assertEqual(http_proc.reqcount, 1)
예제 #4
0
 def testReadTimeout(self):
   """Checks a "version" call issued with a read timeout of 12356.

   The response produced by C{self._ReadTimeoutResponse} is expected to
   yield a successful result with payload -1 after exactly one request.

   """
   resolver = rpc._StaticResolver(["192.0.2.13"])
   http_proc = _FakeRequestProcessor(self._ReadTimeoutResponse)
   proc = rpc._RpcProcessor(resolver, 19176)
   host = "node31856"
   body = {host: ""}
   result = proc([host], "version", body, 12356, NotImplemented,
                 _req_process_fn=http_proc)
   # keys() may be a view object rather than a list; normalise before
   # comparing against a list literal
   self.assertEqual(list(result.keys()), [host])
   lhresp = result[host]
   self.assertFalse(lhresp.offline)
   self.assertEqual(lhresp.node, host)
   self.assertFalse(lhresp.fail_msg)
   self.assertEqual(lhresp.payload, -1)
   self.assertEqual(lhresp.call, "version")
   # Raise() must be a no-op for a successful result
   lhresp.Raise("should not raise")
   self.assertEqual(http_proc.reqcount, 1)
예제 #5
0
  def testInvalidResponse(self):
    """Checks handling of the two invalid response variants.

    For each of C{self._GetInvalidResponseA} and C{self._GetInvalidResponseB}
    the result must carry a failure message, have no payload, and raise
    C{errors.OpExecError} from C{Raise}.

    """
    resolver = rpc._StaticResolver(["oqo7lanhly.example.com"])
    proc = rpc._RpcProcessor(resolver, 19978)

    for fn in [self._GetInvalidResponseA, self._GetInvalidResponseB]:
      # A fresh fake processor per variant keeps the request count at one
      http_proc = _FakeRequestProcessor(fn)
      host = "oqo7lanhly.example.com"
      body = {host: ""}
      result = proc([host], "version", body, 60, NotImplemented,
                    _req_process_fn=http_proc)
      # keys() may be a view object rather than a list; normalise before
      # comparing against a list literal
      self.assertEqual(list(result.keys()), [host])
      lhresp = result[host]
      self.assertFalse(lhresp.offline)
      self.assertEqual(lhresp.node, host)
      # assertTrue replaces the deprecated assert_ alias
      self.assertTrue(lhresp.fail_msg)
      self.assertFalse(lhresp.payload)
      self.assertEqual(lhresp.call, "version")
      self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
      self.assertEqual(http_proc.reqcount, 1)
예제 #6
0
 def testVersionFailure(self):
   """Checks a failing "version" call, with and without error information.

   C{self._GetVersionResponseFail} is bound once to C{None} and once to
   "Unknown error"; in both cases the result must carry a failure message,
   have no payload, and raise C{errors.OpExecError} from C{Raise}.

   """
   resolver = rpc._StaticResolver(["aef9ur4i.example.com"])
   proc = rpc._RpcProcessor(resolver, 5903)
   for errinfo in [None, "Unknown error"]:
     # A fresh fake processor per variant keeps the request count at one
     http_proc = \
       _FakeRequestProcessor(compat.partial(self._GetVersionResponseFail,
                                            errinfo))
     host = "aef9ur4i.example.com"
     body = {host: ""}
     # Pass the node names as a real list; keys() may be a view object
     result = proc(list(body.keys()), "version", body, 60, NotImplemented,
                   _req_process_fn=http_proc)
     self.assertEqual(list(result.keys()), [host])
     lhresp = result[host]
     self.assertFalse(lhresp.offline)
     self.assertEqual(lhresp.node, host)
     # assertTrue replaces the deprecated assert_ alias
     self.assertTrue(lhresp.fail_msg)
     self.assertFalse(lhresp.payload)
     self.assertEqual(lhresp.call, "version")
     self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
     self.assertEqual(http_proc.reqcount, 1)
예제 #7
0
  def testMultiVersionSuccess(self):
    """Checks a successful "version" call fanned out to fifty nodes.

    Every node must report a non-offline, failure-free result with payload
    987, and one request per node must have been processed.

    """
    nodes = ["node%s" % idx for idx in range(50)]
    # Every node is sent the same empty request body
    body = dict.fromkeys(nodes, "")

    fake_http = _FakeRequestProcessor(self._GetMultiVersionResponse)
    processor = rpc._RpcProcessor(rpc._StaticResolver(nodes), 23245)
    result = processor(nodes, "version", body, 60, NotImplemented,
                       _req_process_fn=fake_http)

    # The result order is not relied upon, hence the sorting
    self.assertEqual(sorted(result.keys()), sorted(nodes))

    for node_name in nodes:
      node_result = result[node_name]
      self.assertFalse(node_result.offline)
      self.assertEqual(node_result.node, node_name)
      self.assertFalse(node_result.fail_msg)
      self.assertEqual(node_result.payload, 987)
      self.assertEqual(node_result.call, "version")
      node_result.Raise("should not raise")

    self.assertEqual(fake_http.reqcount, len(nodes))
예제 #8
0
  def testHttpError(self):
    """Checks a "vg_list" call where some nodes fail in different ways.

    The fifty nodes are split into three disjoint groups: seven nodes handed
    to C{self._GetHttpErrorResponse} as HTTP-error nodes, fourteen handed to
    it as failing nodes, and the remaining twenty-nine. The first group must
    raise C{errors.OpExecError}, the second C{errors.OpPrereqError} when
    C{Raise} is called with C{prereq=True}, and the rest must succeed with
    C{hash(name)} as payload.

    """
    nodes = ["uaf6pbbv%s" % i for i in range(50)]
    body = dict((n, "") for n in nodes)
    resolver = rpc._StaticResolver(nodes)

    # Every seventh node, starting with the second one
    httperrnodes = set(nodes[1::7])
    self.assertEqual(len(httperrnodes), 7)

    # Every third node, starting with the third one, unless already chosen
    failnodes = set(nodes[2::3]) - httperrnodes
    self.assertEqual(len(failnodes), 14)

    self.assertEqual(len(set(nodes) - failnodes - httperrnodes), 29)

    proc = rpc._RpcProcessor(resolver, 15165)
    http_proc = \
      _FakeRequestProcessor(compat.partial(self._GetHttpErrorResponse,
                                           httperrnodes, failnodes))
    result = proc(nodes, "vg_list", body,
                  constants.RPC_TMO_URGENT, NotImplemented,
                  _req_process_fn=http_proc)
    self.assertEqual(sorted(result.keys()), sorted(nodes))

    for name in nodes:
      lhresp = result[name]
      self.assertFalse(lhresp.offline)
      self.assertEqual(lhresp.node, name)
      self.assertEqual(lhresp.call, "vg_list")

      # assertTrue replaces the deprecated assert_ alias
      if name in httperrnodes:
        self.assertTrue(lhresp.fail_msg)
        self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
      elif name in failnodes:
        self.assertTrue(lhresp.fail_msg)
        self.assertRaises(errors.OpPrereqError, lhresp.Raise, "failed",
                          prereq=True, ecode=errors.ECODE_INVAL)
      else:
        self.assertFalse(lhresp.fail_msg)
        self.assertEqual(lhresp.payload, hash(name))
        lhresp.Raise("should not raise")

    self.assertEqual(http_proc.reqcount, len(nodes))
예제 #9
0
  def testArgumentEncoder(self):
    """Checks that call arguments are passed through the encoder lookup.

    Two argument kinds are mapped to C{hex} and C{hash}; the fake request
    handler echoes the posted data, so the decoded payload must hold the
    untouched first argument followed by the two encoded values.

    """
    (AT1, AT2) = range(1, 3)

    # Encoder function per argument kind
    encoders = {
      AT1: hex,
      AT2: hash,
      }

    node_names = [
      "node5.example.com",
      "node6.example.com",
      ]

    resolver = rpc._StaticResolver([
      "192.0.2.5",
      "192.0.2.6",
      ])

    def _VerifyRequest(req):
      # Answer successfully and echo the request body back to the caller
      req.success = True
      req.resp_status_code = http.HTTP_OK
      req.resp_body = serializer.DumpJson((True, req.post_data))

    call_def = ("test_call", NotImplemented, None, constants.RPC_TMO_NORMAL, [
      ("arg0", None, NotImplemented),
      ("arg1", AT1, NotImplemented),
      ("arg1", AT2, NotImplemented),
      ], None, None, NotImplemented)

    fake_http = _FakeRequestProcessor(_VerifyRequest)

    for num in [0, 3796, 9032119]:
      rpc_client = rpc._RpcClientBase(resolver, encoders.get,
                                      _req_process_fn=fake_http)
      text_arg = "Hello%s" % num
      result = rpc_client._Call(call_def, node_names, ["foo", num, text_arg])
      self.assertEqual(len(result), len(node_names))

      expected = ["foo", hex(num), hash(text_arg)]
      for node_result in result.values():
        self.assertFalse(node_result.fail_msg)
        self.assertEqual(serializer.LoadJson(node_result.payload), expected)
예제 #10
0
  def testTimeout(self):
    """Checks that the read timeout of a call definition reaches the request.

    The timeout entry of the definition is either a plain number, which is
    expected on the request unchanged, or the C{_CalcTimeout} callable,
    which is expected to yield the sum of the two call arguments.

    """
    def _CalcTimeout(args):
      # Unpack in the body; tuple parameters in a signature are not valid
      # syntax on Python 3
      (arg1, arg2) = args
      return arg1 + arg2

    def _VerifyRequest(exp_timeout, req):
      self.assertEqual(req.read_timeout, exp_timeout)

      # Report the timeout seen by the request back as the payload
      req.success = True
      req.resp_status_code = http.HTTP_OK
      req.resp_body = serializer.DumpJson((True, hex(req.read_timeout)))

    resolver = rpc._StaticResolver([
      "192.0.2.1",
      "192.0.2.2",
      ])

    nodes = [
      "node1.example.com",
      "node2.example.com",
      ]

    # Tuples of (timeout definition, first call argument, expected timeout);
    # the second call argument is always 300
    tests = [(100, None, 100), (30, None, 30)]
    tests.extend((_CalcTimeout, i, i + 300)
                 for i in [0, 5, 16485, 30516])

    for timeout, arg1, exp_timeout in tests:
      cdef = ("test_call", NotImplemented, None, timeout, [
        ("arg1", None, NotImplemented),
        ("arg2", None, NotImplemented),
        ], None, None, NotImplemented)

      http_proc = _FakeRequestProcessor(compat.partial(_VerifyRequest,
                                                       exp_timeout))
      client = rpc._RpcClientBase(resolver, NotImplemented,
                                  _req_process_fn=http_proc)
      result = client._Call(cdef, nodes, [arg1, 300])
      self.assertEqual(len(result), len(nodes))
      self.assertTrue(compat.all(not res.fail_msg and
                                 res.payload == hex(exp_timeout)
                                 for res in result.values()))
예제 #11
0
 def testResponseBody(self):
   """Checks an "upload_file" call carrying a JSON-encoded request body.

   The test data is handed both to the call (serialised) and to
   C{self._GetBodyTestResponse}; the result is expected to be successful
   with a payload of C{None} after exactly one request.

   """
   test_data = {
     "Hello": "World",
     # An explicit list, as a range object is not a list on every Python
     # version
     "xyz": list(range(10)),
     }
   resolver = rpc._StaticResolver(["192.0.2.84"])
   http_proc = _FakeRequestProcessor(compat.partial(self._GetBodyTestResponse,
                                                    test_data))
   proc = rpc._RpcProcessor(resolver, 18700)
   host = "node19759"
   body = {host: serializer.DumpJson(test_data)}
   result = proc([host], "upload_file", body, 30, NotImplemented,
                 _req_process_fn=http_proc)
   # keys() may be a view object rather than a list; normalise before
   # comparing against a list literal
   self.assertEqual(list(result.keys()), [host])
   lhresp = result[host]
   self.assertFalse(lhresp.offline)
   self.assertEqual(lhresp.node, host)
   self.assertFalse(lhresp.fail_msg)
   self.assertEqual(lhresp.payload, None)
   self.assertEqual(lhresp.call, "upload_file")
   # Raise() must be a no-op for a successful result
   lhresp.Raise("should not raise")
   self.assertEqual(http_proc.reqcount, 1)
예제 #12
0
  def testOfflineNode(self):
    """Checks a call to a node that resolves to C{rpc._OFFLINE}.

    The result must be flagged offline, carry a failure message, have no
    payload, and raise C{errors.OpExecError} from C{Raise} with or without
    a message. No request may have been processed.

    """
    resolver = rpc._StaticResolver([rpc._OFFLINE])
    http_proc = _FakeRequestProcessor(NotImplemented)
    proc = rpc._RpcProcessor(resolver, 30668)
    host = "n17296"
    body = {host: ""}
    result = proc([host], "version", body, 60, NotImplemented,
                  _req_process_fn=http_proc)
    # keys() may be a view object rather than a list; normalise before
    # comparing against a list literal
    self.assertEqual(list(result.keys()), [host])
    lhresp = result[host]
    self.assertTrue(lhresp.offline)
    self.assertEqual(lhresp.node, host)
    self.assertTrue(lhresp.fail_msg)
    self.assertFalse(lhresp.payload)
    self.assertEqual(lhresp.call, "version")

    # With a message
    self.assertRaises(errors.OpExecError, lhresp.Raise, "should raise")

    # No message
    self.assertRaises(errors.OpExecError, lhresp.Raise, None)

    self.assertEqual(http_proc.reqcount, 0)
예제 #13
0
  def testPostProc(self):
    """Checks that the C{_PostProc} callable of a call definition is applied.

    The fake request handler answers with a list of numbers; C{_PostProc}
    replaces the payload with their sum, which is what every node's result
    must finally contain.

    """
    def _VerifyRequest(nums, req):
      # Answer successfully with the given list of numbers
      req.success = True
      req.resp_status_code = http.HTTP_OK
      req.resp_body = serializer.DumpJson((True, nums))

    def _PostProc(res):
      self.assertFalse(res.fail_msg)
      res.payload = sum(res.payload)
      return res

    node_names = [
      "node90.example.com",
      "node95.example.com",
      ]

    resolver = rpc._StaticResolver([
      "192.0.2.90",
      "192.0.2.95",
      ])

    call_def = ("test_call", NotImplemented, None, constants.RPC_TMO_NORMAL,
                [], None, _PostProc, NotImplemented)

    # Fixed seed, so the generated numbers are the same on every run
    rng = random.Random(20299)

    for count in [0, 4, 74, 1391]:
      numbers = [rng.randint(0, 1000) for _ in range(count)]
      fake_http = _FakeRequestProcessor(compat.partial(_VerifyRequest,
                                                       numbers))
      rpc_client = rpc._RpcClientBase(resolver, NotImplemented,
                                      _req_process_fn=fake_http)
      result = rpc_client._Call(call_def, node_names, [])
      self.assertEqual(len(result), len(node_names))
      for node_result in result.values():
        self.assertFalse(node_result.fail_msg)
        self.assertEqual(node_result.payload, sum(numbers))
예제 #14
0
 def testWrongLength(self):
   """Checks that a resolver with no addresses rejects a one-node lookup."""
   empty_resolver = rpc._StaticResolver([])
   single_node = ["abc"]
   # One node name against zero addresses must trip an assertion
   self.assertRaises(AssertionError, empty_resolver, single_node,
                     NotImplemented)
예제 #15
0
 def test(self):
   """Checks that the static resolver pairs node names with addresses.

   Calling the resolver with the node names must return the
   C{(node, address)} pairs in order.

   """
   addresses = ["192.0.2.%d" % n for n in range(0, 123, 7)]
   nodes = ["node%s.example.com" % n for n in range(0, 123, 7)]
   res = rpc._StaticResolver(addresses)
   # zip() may return an iterator rather than a list; materialise it so the
   # comparison is made against actual pairs
   self.assertEqual(res(nodes, NotImplemented), list(zip(nodes, addresses)))