def testVecHostPortManyParallelRpcs(self):
  """Evaluates 10 vectorized 'Increment' RPC ops in one `sess.run` call.

  Every op sends the same 20 serialized `TestCase` requests, where request k
  holds values [k, k + 1, k + 2]. The test checks three things: 10 result
  arrays come back, each has shape (20,), and response k of every array parses
  to values [k + 1, k + 2, k + 3].
  """
  with self.cached_session() as sess:
    serialized_requests = [
        test_example_pb2.TestCase(values=[k, k + 1, k + 2]).SerializeToString()
        for k in range(20)
    ]

    # 10 independent RPC ops, each carrying the full 20-request vector; they
    # are all fetched together by the single sess.run below.
    parallel_ops = []
    for _ in range(10):
      parallel_ops.append(
          self.rpc(
              method=self.get_method_name('Increment'),
              address=self._address,
              request=serialized_requests))

    per_op_results = sess.run(parallel_ops)
    self.assertEqual(10, len(per_op_results))

    for op_result in per_op_results:
      self.assertEqual(op_result.shape, (20,))
      for k, serialized_response in enumerate(op_result):
        parsed = test_example_pb2.TestCase()
        self.assertTrue(parsed.ParseFromString(serialized_response))
        self.assertAllEqual([k + 1, k + 2, k + 3], parsed.values)
def testTryRpcWithMultipleMethodsSingleRequest(self):
  """Checks `try_rpc` with a vector of methods and a single request.

  The method vector alternates the valid 'IncrementTestShapes' method with
  'InvalidMethodName', 10 times, for 20 entries in total. One scalar request
  is passed for all of them; the test asserts 20 status codes, so it expects
  the request to be paired with every method.

  Even-indexed entries must report errors.OK with a truthy response.
  Odd-indexed entries must report errors.UNIMPLEMENTED with a falsy response.
  """
  # cached_session() replaces test_session(), which tf.test.TestCase
  # deprecates.
  with self.cached_session() as sess:
    # [valid, invalid] repeated 10 times -> 20 alternating method names.
    methods = [
        self.get_method_name('IncrementTestShapes'), 'InvalidMethodName'
    ] * 10
    request = test_example_pb2.TestCase(shape=[0, 1, 2]).SerializeToString()
    response_tensors, status_code, _ = self.try_rpc(
        method=methods, address=self._address, request=request)
    response_tensors_values, status_code_values = sess.run(
        (response_tensors, status_code))
    self.assertAllEqual([errors.OK, errors.UNIMPLEMENTED] * 10,
                        status_code_values)
    for i in range(10):
      self.assertTrue(response_tensors_values[2 * i])
      self.assertFalse(response_tensors_values[2 * i + 1])
def testTryRpcWithMultipleAddressesSingleRequest(self):
  """Checks `try_rpc` with a vector of addresses and a single request.

  The address vector alternates the fixture's server address (`self._address`)
  with a unix-socket path that should not exist, 10 times, for 20 entries in
  total. One method ('IncrementTestShapes') and one scalar request are used
  for all of them; the test asserts 20 status codes.

  Even-indexed entries must report errors.OK with a truthy response.
  Odd-indexed entries must report errors.UNAVAILABLE with a falsy response.
  """
  # cached_session() replaces test_session(), which tf.test.TestCase
  # deprecates.
  with self.cached_session() as sess:
    # [reachable, unreachable] repeated 10 times -> 20 alternating addresses.
    addresses = [
        self._address, 'unix:/tmp/this_unix_socket_doesnt_exist_97820348!!@'
    ] * 10
    request = test_example_pb2.TestCase(shape=[0, 1, 2]).SerializeToString()
    response_tensors, status_code, _ = self.try_rpc(
        method=self.get_method_name('IncrementTestShapes'),
        address=addresses,
        request=request)
    response_tensors_values, status_code_values = sess.run(
        (response_tensors, status_code))
    self.assertAllEqual([errors.OK, errors.UNAVAILABLE] * 10,
                        status_code_values)
    for i in range(10):
      self.assertTrue(response_tensors_values[2 * i])
      self.assertFalse(response_tensors_values[2 * i + 1])