示例#1
0
 def testTryRpcWithMultipleAddressesAndRequests(self):
   """Vector requests routed to a vector of alternating good/bad addresses.

   Request ``n`` is sent to address ``n``; the addresses alternate between
   the test server and a unix socket path that should not exist. Even slots
   must succeed and decode to incremented values; odd slots must report
   UNAVAILABLE and come back with an empty response.
   """
   with self.cached_session() as sess:
     bad_address = 'unix:/tmp/this_unix_socket_doesnt_exist_97820348!!@'
     addresses = list(
         itertools.chain.from_iterable(
             [self._address, bad_address] for _ in range(10)))
     requests = [
         test_example_pb2.TestCase(
             values=[n, n + 1, n + 2]).SerializeToString()
         for n in range(20)
     ]
     response_tensors, status_code, _ = self.try_rpc(
         method=self.get_method_name('Increment'),
         address=addresses,
         request=requests)
     fetched_responses, fetched_codes = sess.run(
         (response_tensors, status_code))
     self.assertAllEqual([errors.OK, errors.UNAVAILABLE] * 10, fetched_codes)
     for n in range(20):
       raw = fetched_responses[n]
       if n % 2:
         # Odd slots targeted the missing socket: no payload expected.
         self.assertFalse(raw)
         continue
       decoded = test_example_pb2.TestCase()
       self.assertTrue(decoded.ParseFromString(raw))
       self.assertAllEqual([n + 1, n + 2, n + 3], decoded.values)
示例#2
0
 def testScalarHostPortRpc(self):
   """A scalar request to a single address yields a scalar response."""
   with self.cached_session() as sess:
     request_bytes = test_example_pb2.TestCase(
         values=[1, 2, 3]).SerializeToString()
     rpc_output = self.rpc(
         method=self.get_method_name('Increment'),
         address=self._address,
         request=request_bytes)
     # A scalar request must produce a scalar (shape ()) response tensor.
     self.assertEqual(rpc_output.shape, ())
     raw_response = sess.run(rpc_output)
   decoded = test_example_pb2.TestCase()
   self.assertTrue(decoded.ParseFromString(raw_response))
   self.assertAllEqual([2, 3, 4], decoded.values)
示例#3
0
 def testVecHostPortRpc(self):
   """A vector of 20 requests to one address yields a length-20 response."""
   with self.cached_session() as sess:
     payloads = [
         test_example_pb2.TestCase(
             values=[n, n + 1, n + 2]).SerializeToString()
         for n in range(20)
     ]
     rpc_output = self.rpc(
         method=self.get_method_name('Increment'),
         address=self._address,
         request=payloads)
     self.assertEqual(rpc_output.shape, (20,))
     fetched = sess.run(rpc_output)
   self.assertEqual(fetched.shape, (20,))
   # The shape check above guarantees exactly 20 elements to walk.
   for n, raw in enumerate(fetched):
     decoded = test_example_pb2.TestCase()
     self.assertTrue(decoded.ParseFromString(raw))
     self.assertAllEqual([n + 1, n + 2, n + 3], decoded.values)
示例#4
0
 def testScalarHostPortTryRpc(self):
   """TryRpc on a valid scalar request returns the response plus OK status.

   Unlike ``self.rpc``, ``self.try_rpc`` returns three tensors: the response,
   a status code and a status message. All three must be scalars for a
   scalar request.
   """
   with self.cached_session() as sess:
     request_tensors = (
         test_example_pb2.TestCase(values=[1, 2, 3]).SerializeToString())
     response_tensors, status_code, status_message = self.try_rpc(
         method=self.get_method_name('Increment'),
         address=self._address,
         request=request_tensors)
     self.assertEqual(status_code.shape, ())
     self.assertEqual(status_message.shape, ())
     self.assertEqual(response_tensors.shape, ())
     response_values, status_code_values, status_message_values = (
         sess.run((response_tensors, status_code, status_message)))
   response_message = test_example_pb2.TestCase()
   self.assertTrue(response_message.ParseFromString(response_values))
   self.assertAllEqual([2, 3, 4], response_message.values)
   # The call targets a valid address and method, so TryRpc must report
   # success: an OK status code and an empty status message.
   self.assertEqual(errors.OK, status_code_values)
   self.assertEqual(b'', status_message_values)
示例#5
0
 def testVecHostPortManyParallelRpcs(self):
   """Ten RPC ops, each carrying the same 20 requests, fetched together."""
   with self.cached_session() as sess:
     batch = [
         test_example_pb2.TestCase(
             values=[n, n + 1, n + 2]).SerializeToString()
         for n in range(20)
     ]
     rpc_ops = [
         self.rpc(
             method=self.get_method_name('Increment'),
             address=self._address,
             request=batch) for _ in range(10)
     ]
     # Fetch all ten RpcOps (20 requests each) in a single run call so they
     # are launched in parallel.
     all_results = sess.run(rpc_ops)
   self.assertEqual(10, len(all_results))
   for results in all_results:
     self.assertEqual(results.shape, (20,))
     # The shape check above guarantees exactly 20 elements to walk.
     for n, raw in enumerate(results):
       decoded = test_example_pb2.TestCase()
       self.assertTrue(decoded.ParseFromString(raw))
       self.assertAllEqual([n + 1, n + 2, n + 3], decoded.values)
示例#6
0
 def testTryRpcWithMultipleMethodsSingleRequest(self):
   """One scalar request fanned out over a vector of alternating methods.

   The same request is sent to 20 methods alternating between 'Increment'
   and a name the server reports as UNIMPLEMENTED. Even slots must succeed
   and decode to the incremented values; odd slots must fail and return an
   empty response.
   """
   flatten = lambda x: list(itertools.chain.from_iterable(x))
   with self.cached_session() as sess:
     methods = flatten(
         [[self.get_method_name('Increment'), 'InvalidMethodName']
          for _ in range(10)])
     request = test_example_pb2.TestCase(values=[0, 1, 2]).SerializeToString()
     response_tensors, status_code, _ = self.try_rpc(
         method=methods, address=self._address, request=request)
     response_tensors_values, status_code_values = sess.run((response_tensors,
                                                             status_code))
     self.assertAllEqual(
         flatten([errors.OK, errors.UNIMPLEMENTED] for _ in range(10)),
         status_code_values)
     for i in range(10):
       # A successful Increment must carry a well-formed, incremented
       # message, not merely a non-empty payload.
       response_message = test_example_pb2.TestCase()
       self.assertTrue(
           response_message.ParseFromString(response_tensors_values[2 * i]))
       self.assertAllEqual([1, 2, 3], response_message.values)
       self.assertFalse(response_tensors_values[2 * i + 1])
示例#7
0
 def testTryRpcWithMultipleAddressesSingleRequest(self):
   """One scalar request fanned out over alternating good/bad addresses.

   The same request is sent to 20 addresses alternating between the test
   server and a unix socket path that should not exist. Even slots must
   succeed and decode to the incremented values; odd slots must report
   UNAVAILABLE and return an empty response.
   """
   flatten = lambda x: list(itertools.chain.from_iterable(x))
   with self.cached_session() as sess:
     addresses = flatten([[
         self._address, 'unix:/tmp/this_unix_socket_doesnt_exist_97820348!!@'
     ] for _ in range(10)])
     request = test_example_pb2.TestCase(values=[0, 1, 2]).SerializeToString()
     response_tensors, status_code, _ = self.try_rpc(
         method=self.get_method_name('Increment'),
         address=addresses,
         request=request)
     response_tensors_values, status_code_values = sess.run((response_tensors,
                                                             status_code))
     self.assertAllEqual(
         flatten([errors.OK, errors.UNAVAILABLE] for _ in range(10)),
         status_code_values)
     for i in range(10):
       # A successful Increment must carry a well-formed, incremented
       # message, not merely a non-empty payload.
       response_message = test_example_pb2.TestCase()
       self.assertTrue(
           response_message.ParseFromString(response_tensors_values[2 * i]))
       self.assertAllEqual([1, 2, 3], response_message.values)
       self.assertFalse(response_tensors_values[2 * i + 1])