def test_create_tuple_reraises_type_error(self, mock_stub):
  """A TypeError raised by the CreateTuple stub surfaces from create_tuple."""
  stub_instance = mock_stub.return_value
  stub_instance.CreateTuple = mock.Mock(side_effect=TypeError)
  event_loop = asyncio.get_event_loop()
  remote_ex = create_remote_executor()
  int_type = computation_types.TensorType(tf.int32)
  # Two independent scalar-int32 remote values to pack into a tuple.
  elements = [
      remote_executor.RemoteValue(executor_pb2.ValueRef(), int_type, remote_ex)
      for _ in range(2)
  ]
  with self.assertRaises(TypeError):
    event_loop.run_until_complete(remote_ex.create_tuple(elements))
def test_create_tuple_reraises_grpc_error(self, mock_stub):
  """A non-retryable gRPC error from CreateTuple is re-raised unchanged.

  The helper `_raise_non_retryable_grpc_error` is expected to raise an
  `RpcError` whose status code is ABORTED; the test checks that code survives.
  """
  stub_instance = mock_stub.return_value
  stub_instance.CreateTuple = mock.Mock(
      side_effect=_raise_non_retryable_grpc_error)
  event_loop = asyncio.get_event_loop()
  remote_ex = create_remote_executor()
  int_type = computation_types.TensorType(tf.int32)
  elements = [
      remote_executor.RemoteValue(executor_pb2.ValueRef(), int_type, remote_ex)
      for _ in range(2)
  ]
  with self.assertRaises(grpc.RpcError) as raised:
    event_loop.run_until_complete(remote_ex.create_tuple(elements))
  # Checked outside the `with` so it runs after the expected exception.
  self.assertEqual(raised.exception.code(), grpc.StatusCode.ABORTED)
def test_create_tuple_returns_remote_value(self, mock_stub):
  """A successful CreateTuple RPC yields a RemoteValue and is called once."""
  canned_response = executor_pb2.CreateTupleResponse()
  stub_instance = mock_stub.return_value
  # A one-element side_effect list: a second call would raise StopIteration.
  stub_instance.CreateTuple = mock.Mock(side_effect=[canned_response])
  event_loop = asyncio.get_event_loop()
  remote_ex = create_remote_executor()
  int_type = computation_types.TensorType(tf.int32)
  elements = [
      remote_executor.RemoteValue(executor_pb2.ValueRef(), int_type, remote_ex)
      for _ in range(2)
  ]
  outcome = event_loop.run_until_complete(remote_ex.create_tuple(elements))
  stub_instance.CreateTuple.assert_called_once()
  self.assertIsInstance(outcome, remote_executor.RemoteValue)
def test_create_selection_reraises_type_error(self, mock_stub):
  """A TypeError raised by the CreateSelection stub surfaces to the caller."""
  stub_instance = mock_stub.return_value
  stub_instance.CreateSelection = mock.Mock(side_effect=TypeError)
  event_loop = asyncio.get_event_loop()
  remote_ex = create_remote_executor()
  pair_type = computation_types.NamedTupleType([tf.int32, tf.int32])
  tuple_value = remote_executor.RemoteValue(
      executor_pb2.ValueRef(), pair_type, remote_ex)
  with self.assertRaises(TypeError):
    event_loop.run_until_complete(
        remote_ex.create_selection(tuple_value, index=0))
def test_create_call_reraises_type_error(self, mock_stub):
  """A TypeError raised by the CreateCall stub surfaces from create_call."""
  stub_instance = mock_stub.return_value
  stub_instance.CreateCall = mock.Mock(side_effect=TypeError)
  event_loop = asyncio.get_event_loop()
  remote_ex = create_remote_executor()
  # A no-argument function returning int32.
  nullary_fn_type = computation_types.FunctionType(None, tf.int32)
  fn_value = remote_executor.RemoteValue(
      executor_pb2.ValueRef(), nullary_fn_type, remote_ex)
  with self.assertRaises(TypeError):
    event_loop.run_until_complete(remote_ex.create_call(fn_value))
def test_compute_raises_retryable_error_on_grpc_error_unavailable(
    self, mock_stub):
  """A gRPC UNAVAILABLE failure in Compute is surfaced as a RetryableError."""
  stub_instance = mock_stub.return_value
  stub_instance.Compute = mock.Mock(side_effect=_raise_grpc_error_unavailable)
  event_loop = asyncio.get_event_loop()
  remote_ex = create_remote_executor()
  nullary_fn_type = computation_types.FunctionType(None, tf.int32)
  target = remote_executor.RemoteValue(
      executor_pb2.ValueRef(), nullary_fn_type, remote_ex)
  with self.assertRaises(execution_context.RetryableError):
    event_loop.run_until_complete(target.compute())
def test_create_selection_returns_remote_value(self, mock_stub):
  """A successful CreateSelection RPC yields a RemoteValue, called once."""
  canned_response = executor_pb2.CreateSelectionResponse()
  stub_instance = mock_stub.return_value
  stub_instance.CreateSelection = mock.Mock(side_effect=[canned_response])
  event_loop = asyncio.get_event_loop()
  remote_ex = create_remote_executor()
  pair_type = computation_types.NamedTupleType([tf.int32, tf.int32])
  tuple_value = remote_executor.RemoteValue(
      executor_pb2.ValueRef(), pair_type, remote_ex)
  outcome = event_loop.run_until_complete(
      remote_ex.create_selection(tuple_value, index=0))
  stub_instance.CreateSelection.assert_called_once()
  self.assertIsInstance(outcome, remote_executor.RemoteValue)
def test_create_call_returns_remote_value(self, mock_stub):
  """A successful CreateCall RPC (no argument) yields a RemoteValue."""
  canned_response = executor_pb2.CreateCallResponse()
  stub_instance = mock_stub.return_value
  stub_instance.CreateCall = mock.Mock(side_effect=[canned_response])
  event_loop = asyncio.get_event_loop()
  remote_ex = create_remote_executor()
  nullary_fn_type = computation_types.FunctionType(None, tf.int32)
  fn_value = remote_executor.RemoteValue(
      executor_pb2.ValueRef(), nullary_fn_type, remote_ex)
  # The argument is passed explicitly as None.
  outcome = event_loop.run_until_complete(remote_ex.create_call(fn_value, None))
  stub_instance.CreateCall.assert_called_once()
  self.assertIsInstance(outcome, remote_executor.RemoteValue)
def test_compute_reraises_grpc_error_deadline_exceeded(self, mock_stub):
  """A DEADLINE_EXCEEDED gRPC error from Compute is re-raised as-is."""
  stub_instance = mock_stub.return_value
  stub_instance.Compute = mock.Mock(
      side_effect=_raise_grpc_error_deadline_exceeded)
  event_loop = asyncio.get_event_loop()
  remote_ex = create_remote_executor()
  nullary_fn_type = computation_types.FunctionType(None, tf.int32)
  target = remote_executor.RemoteValue(
      executor_pb2.ValueRef(), nullary_fn_type, remote_ex)
  with self.assertRaises(grpc.RpcError) as raised:
    event_loop.run_until_complete(target.compute())
  # Checked outside the `with` so it runs after the expected exception.
  self.assertEqual(raised.exception.code(),
                   grpc.StatusCode.DEADLINE_EXCEEDED)
def test_compute_returns_result(self, mock_stub):
  """Compute unpacks the tensor carried in the response into a Python value.

  The stub returns a ComputeResponse whose value is an `Any`-packed tensor
  proto for the scalar 1; the computed result must compare equal to 1.
  """
  packed_tensor = any_pb2.Any()
  packed_tensor.Pack(tf.make_tensor_proto(1))
  canned_response = executor_pb2.ComputeResponse(
      value=executor_pb2.Value(tensor=packed_tensor))
  stub_instance = mock_stub.return_value
  stub_instance.Compute = mock.Mock(side_effect=[canned_response])
  event_loop = asyncio.get_event_loop()
  remote_ex = create_remote_executor()
  nullary_fn_type = computation_types.FunctionType(None, tf.int32)
  target = remote_executor.RemoteValue(
      executor_pb2.ValueRef(), nullary_fn_type, remote_ex)
  outcome = event_loop.run_until_complete(target.compute())
  stub_instance.Compute.assert_called_once()
  self.assertEqual(outcome, 1)