Esempio n. 1
0
 def testVecHostPortManyParallelRpcs(self):
   """Runs 10 vectorized RPC ops in one `sess.run` and checks every reply.

   Each op is given the same 20 serialized `TestCase` requests, where request
   `i` carries the values `[i, i + 1, i + 2]`. The test expects the
   'Increment' method to answer request `i` with `[i + 1, i + 2, i + 3]`.
   """
   num_requests = 20
   num_parallel_ops = 10
   with self.cached_session() as sess:
     serialized_requests = []
     for base in range(num_requests):
       message = test_example_pb2.TestCase(values=[base, base + 1, base + 2])
       serialized_requests.append(message.SerializeToString())

     rpc_ops = []
     for _ in range(num_parallel_ops):
       rpc_ops.append(
           self.rpc(
               method=self.get_method_name('Increment'),
               address=self._address,
               request=serialized_requests))

     # One run() call launches all of the RpcOps together; each of them
     # carries the full batch of requests.
     all_responses = sess.run(rpc_ops)

   self.assertEqual(num_parallel_ops, len(all_responses))
   for op_responses in all_responses:
     self.assertEqual(op_responses.shape, (num_requests,))
     for index in range(num_requests):
       parsed = test_example_pb2.TestCase()
       self.assertTrue(parsed.ParseFromString(op_responses[index]))
       self.assertAllEqual([index + 1, index + 2, index + 3], parsed.values)
Esempio n. 2
0
 def testTryRpcWithMultipleMethodsSingleRequest(self):
   """Sends one request to 20 methods that alternate between valid and invalid.

   The method list interleaves the 'IncrementTestShapes' method name with the
   literal 'InvalidMethodName', ten times over. The test expects the status
   codes to alternate `errors.OK` / `errors.UNIMPLEMENTED`, with a non-empty
   response at every even index and an empty one at every odd index.
   """

   def flatten(nested):
     return list(itertools.chain.from_iterable(nested))

   # `test_session` is deprecated in favour of `cached_session`.
   with self.cached_session() as sess:
     methods = flatten(
         [[self.get_method_name('IncrementTestShapes'), 'InvalidMethodName']
          for _ in range(10)])
     request = test_example_pb2.TestCase(shape=[0, 1, 2]).SerializeToString()
     response_tensors, status_code, _ = self.try_rpc(
         method=methods, address=self._address, request=request)
     response_tensors_values, status_code_values = sess.run((response_tensors,
                                                             status_code))
     self.assertAllEqual(
         flatten([errors.OK, errors.UNIMPLEMENTED] for _ in range(10)),
         status_code_values)
     for i in range(10):
       # Even slots used the valid method name, odd slots the invalid one.
       self.assertTrue(response_tensors_values[2 * i])
       self.assertFalse(response_tensors_values[2 * i + 1])
Esempio n. 3
0
 def testTryRpcWithMultipleAddressesSingleRequest(self):
   """Sends one request to 20 addresses that alternate between good and bad.

   The address list interleaves `self._address` with a unix socket path that
   is not expected to exist, ten times over. The test expects the status
   codes to alternate `errors.OK` / `errors.UNAVAILABLE`, with a non-empty
   response at every even index and an empty one at every odd index.
   """

   def flatten(nested):
     return list(itertools.chain.from_iterable(nested))

   # `test_session` is deprecated in favour of `cached_session`.
   with self.cached_session() as sess:
     addresses = flatten([[
         self._address, 'unix:/tmp/this_unix_socket_doesnt_exist_97820348!!@'
     ] for _ in range(10)])
     request = test_example_pb2.TestCase(shape=[0, 1, 2]).SerializeToString()
     response_tensors, status_code, _ = self.try_rpc(
         method=self.get_method_name('IncrementTestShapes'),
         address=addresses,
         request=request)
     response_tensors_values, status_code_values = sess.run((response_tensors,
                                                             status_code))
     self.assertAllEqual(
         flatten([errors.OK, errors.UNAVAILABLE] for _ in range(10)),
         status_code_values)
     for i in range(10):
       # Even slots used the real server address, odd slots the bogus socket.
       self.assertTrue(response_tensors_values[2 * i])
       self.assertFalse(response_tensors_values[2 * i + 1])