def testTryRpcWithMultipleAddressesAndRequests(self):
    """Checks try_rpc with 20 addresses paired with 20 distinct requests.

    The address list alternates between `self._address` and a unix socket
    path, and the status codes are expected to alternate between OK and
    UNAVAILABLE. Even-indexed responses must parse into a message whose
    values are the request values plus one; odd-indexed responses must be
    falsy.
    """
    missing_socket = 'unix:/tmp/this_unix_socket_doesnt_exist_97820348!!@'
    with self.cached_session() as sess:
        # Interleave the two addresses: even slots get self._address, odd
        # slots get the socket path.
        address_list = []
        for _ in range(10):
            address_list.extend([self._address, missing_socket])

        serialized_requests = []
        for start in range(20):
            proto = test_example_pb2.TestCase(
                values=[start, start + 1, start + 2])
            serialized_requests.append(proto.SerializeToString())

        method_name = self.get_method_name('Increment')
        responses, codes, _ = self.try_rpc(
            method=method_name,
            address=address_list,
            request=serialized_requests)
        response_values, code_values = sess.run((responses, codes))

        expected_codes = []
        for _ in range(10):
            expected_codes += [errors.OK, errors.UNAVAILABLE]
        self.assertAllEqual(expected_codes, code_values)

        for idx in range(20):
            payload = response_values[idx]
            if idx % 2 != 0:
                # Odd slots used the socket path: no response payload.
                self.assertFalse(payload)
                continue
            decoded = test_example_pb2.TestCase()
            self.assertTrue(decoded.ParseFromString(payload))
            self.assertAllEqual([idx + 1, idx + 2, idx + 3], decoded.values)
def testScalarHostPortRpc(self):
    """Checks rpc with a single scalar request.

    Sends one serialized TestCase with values [1, 2, 3] to the 'Increment'
    method and expects a scalar-shaped response that parses to [2, 3, 4].
    """
    with self.cached_session() as sess:
        outgoing = test_example_pb2.TestCase(values=[1, 2, 3])
        serialized_request = outgoing.SerializeToString()

        method_name = self.get_method_name('Increment')
        response = self.rpc(
            method=method_name,
            address=self._address,
            request=serialized_request)
        # A scalar request must yield a scalar (rank-0) response tensor.
        self.assertEqual(response.shape, ())

        raw_response = sess.run(response)
        decoded = test_example_pb2.TestCase()
        parse_result = decoded.ParseFromString(raw_response)
        self.assertTrue(parse_result)
        self.assertAllEqual([2, 3, 4], decoded.values)
def testVecHostPortRpc(self):
    """Checks rpc with a vector of 20 requests sent to one address.

    Request i carries values [i, i + 1, i + 2]; both the response tensor and
    the fetched result must have shape (20,), and response i must parse to
    [i + 1, i + 2, i + 3].
    """
    with self.cached_session() as sess:
        serialized_requests = []
        for start in range(20):
            proto = test_example_pb2.TestCase(
                values=[start, start + 1, start + 2])
            serialized_requests.append(proto.SerializeToString())

        method_name = self.get_method_name('Increment')
        responses = self.rpc(
            method=method_name,
            address=self._address,
            request=serialized_requests)
        # Static shape of the op output, checked before running anything.
        self.assertEqual(responses.shape, (20,))

        fetched = sess.run(responses)
        self.assertEqual(fetched.shape, (20,))

        for idx in range(20):
            decoded = test_example_pb2.TestCase()
            parse_result = decoded.ParseFromString(fetched[idx])
            self.assertTrue(parse_result)
            self.assertAllEqual([idx + 1, idx + 2, idx + 3], decoded.values)
def testScalarHostPortTryRpc(self):
    """Checks try_rpc with a single scalar request.

    All three outputs (response, status code, status message) must be
    scalar-shaped. The response must parse to [2, 3, 4], the status code
    must be OK and the status message must be empty bytes.
    """
    with self.cached_session() as sess:
        outgoing = test_example_pb2.TestCase(values=[1, 2, 3])
        serialized_request = outgoing.SerializeToString()

        method_name = self.get_method_name('Increment')
        response, code, message = self.try_rpc(
            method=method_name,
            address=self._address,
            request=serialized_request)
        self.assertEqual(code.shape, ())
        self.assertEqual(message.shape, ())
        self.assertEqual(response.shape, ())

        fetched = sess.run((response, code, message))
        raw_response, code_value, message_value = fetched

        decoded = test_example_pb2.TestCase()
        parse_result = decoded.ParseFromString(raw_response)
        self.assertTrue(parse_result)
        self.assertAllEqual([2, 3, 4], decoded.values)
        # The call is expected to succeed, so no error status comes back.
        self.assertEqual(errors.OK, code_value)
        self.assertEqual(b'', message_value)
def testVecHostPortManyParallelRpcs(self):
    """Checks 10 rpc ops, each with the same 20 requests, in one run call.

    Every op's result must have shape (20,), and response i of every op
    must parse to [i + 1, i + 2, i + 3].
    """
    with self.cached_session() as sess:
        serialized_requests = []
        for start in range(20):
            proto = test_example_pb2.TestCase(
                values=[start, start + 1, start + 2])
            serialized_requests.append(proto.SerializeToString())

        rpc_ops = []
        for _ in range(10):
            rpc_ops.append(
                self.rpc(
                    method=self.get_method_name('Increment'),
                    address=self._address,
                    request=serialized_requests))

        # A single run fetches all 10 ops; each op carries 20 rpc requests.
        all_results = sess.run(rpc_ops)
        self.assertEqual(10, len(all_results))

        for op_result in all_results:
            self.assertEqual(op_result.shape, (20,))
            for idx in range(20):
                decoded = test_example_pb2.TestCase()
                parse_result = decoded.ParseFromString(op_result[idx])
                self.assertTrue(parse_result)
                self.assertAllEqual(
                    [idx + 1, idx + 2, idx + 3], decoded.values)
def testTryRpcWithMultipleMethodsSingleRequest(self):
    """Checks try_rpc with 20 method names and one shared request.

    The method list alternates between the 'Increment' method name and
    'InvalidMethodName', and the status codes are expected to alternate
    between OK and UNIMPLEMENTED. Even-indexed responses must be truthy and
    odd-indexed responses must be falsy.
    """
    with self.cached_session() as sess:
        method_list = []
        for _ in range(10):
            method_list.extend(
                [self.get_method_name('Increment'), 'InvalidMethodName'])

        outgoing = test_example_pb2.TestCase(values=[0, 1, 2])
        serialized_request = outgoing.SerializeToString()

        responses, codes, _ = self.try_rpc(
            method=method_list,
            address=self._address,
            request=serialized_request)
        response_values, code_values = sess.run((responses, codes))

        expected_codes = []
        for _ in range(10):
            expected_codes += [errors.OK, errors.UNIMPLEMENTED]
        self.assertAllEqual(expected_codes, code_values)

        # Walk the results in (valid method, invalid method) pairs.
        for even_idx in range(0, 20, 2):
            self.assertTrue(response_values[even_idx])
            self.assertFalse(response_values[even_idx + 1])
def testTryRpcWithMultipleAddressesSingleRequest(self):
    """Checks try_rpc with 20 addresses and one shared request.

    The address list alternates between `self._address` and a unix socket
    path, and the status codes are expected to alternate between OK and
    UNAVAILABLE. Even-indexed responses must be truthy and odd-indexed
    responses must be falsy.
    """
    missing_socket = 'unix:/tmp/this_unix_socket_doesnt_exist_97820348!!@'
    with self.cached_session() as sess:
        address_list = []
        for _ in range(10):
            address_list.extend([self._address, missing_socket])

        outgoing = test_example_pb2.TestCase(values=[0, 1, 2])
        serialized_request = outgoing.SerializeToString()

        method_name = self.get_method_name('Increment')
        responses, codes, _ = self.try_rpc(
            method=method_name,
            address=address_list,
            request=serialized_request)
        response_values, code_values = sess.run((responses, codes))

        expected_codes = []
        for _ in range(10):
            expected_codes += [errors.OK, errors.UNAVAILABLE]
        self.assertAllEqual(expected_codes, code_values)

        # Walk the results in (self._address, socket path) pairs.
        for even_idx in range(0, 20, 2):
            self.assertTrue(response_values[even_idx])
            self.assertFalse(response_values[even_idx + 1])