async def test_wrap_stream_errors_aiter_called_multiple_times():
    """Calling ``__aiter__`` twice on a wrapped stream yields equal iterators."""
    stream_call = mock.Mock(aio.StreamStreamCall, autospec=True)
    factory = mock.Mock(return_value=stream_call)

    wrapped = await grpc_helpers_async._wrap_stream_errors(factory)()

    first_iter = wrapped.__aiter__()
    second_iter = wrapped.__aiter__()
    assert first_iter == second_iter
async def test_wrap_stream_errors_type_error():
    """A plain Mock call (not spec'd as an aio stream call) raises TypeError."""
    plain_call = mock.Mock()
    factory = mock.Mock(return_value=plain_call)
    wrapper = grpc_helpers_async._wrap_stream_errors(factory)

    with pytest.raises(TypeError):
        await wrapper()
async def test_wrap_stream_errors_unary_stream():
    """Check unary-stream calls.

    Arguments reach the multicallable unchanged, and the connection is
    awaited exactly once.
    """
    unary_stream_call = mock.Mock(aio.UnaryStreamCall, autospec=True)
    factory = mock.Mock(return_value=unary_stream_call)
    wrapper = grpc_helpers_async._wrap_stream_errors(factory)

    await wrapper(1, 2, three="four")

    factory.assert_called_once_with(1, 2, three="four")
    assert unary_stream_call.wait_for_connection.call_count == 1
async def test_wrap_stream_errors_raised():
    """An INVALID_ARGUMENT error from wait_for_connection raises InvalidArgument."""
    rpc_error = RpcErrorImpl(grpc.StatusCode.INVALID_ARGUMENT)
    stream_call = mock.Mock(aio.StreamStreamCall, autospec=True)
    stream_call.wait_for_connection = mock.AsyncMock(side_effect=[rpc_error])
    factory = mock.Mock(return_value=stream_call)
    wrapper = grpc_helpers_async._wrap_stream_errors(factory)

    with pytest.raises(exceptions.InvalidArgument):
        await wrapper()

    assert stream_call.wait_for_connection.call_count == 1
async def test_wrap_stream_errors_read():
    """An RpcError raised by read() surfaces as InvalidArgument.

    The raised exception keeps the original gRPC error as its ``response``.
    """
    rpc_error = RpcErrorImpl(grpc.StatusCode.INVALID_ARGUMENT)

    stream_call = mock.Mock(aio.StreamStreamCall, autospec=True)
    stream_call.read = mock.AsyncMock(side_effect=rpc_error)
    factory = mock.Mock(return_value=stream_call)
    wrapper = grpc_helpers_async._wrap_stream_errors(factory)

    wrapped = await wrapper(1, 2, three="four")
    factory.assert_called_once_with(1, 2, three="four")
    assert stream_call.wait_for_connection.call_count == 1

    with pytest.raises(exceptions.InvalidArgument) as raised:
        await wrapped.read()

    assert raised.value.response == rpc_error
async def test_wrap_stream_errors_aiter_non_rpc_error():
    """Non-RpcError exceptions raised mid-iteration propagate unchanged.

    The underlying async iterator yields one response and then raises a
    plain ``TypeError``. The wrapped call must deliver that response, then
    re-raise the very same exception object rather than mapping it to an
    API exception.
    """
    non_grpc_error = TypeError("Not a gRPC error")

    mock_call = mock.Mock(aio.StreamStreamCall, autospec=True)
    mocked_aiter = mock.Mock(spec=["__anext__"])
    mocked_aiter.__anext__ = mock.AsyncMock(
        side_effect=[mock.sentinel.response, non_grpc_error]
    )
    mock_call.__aiter__ = mock.Mock(return_value=mocked_aiter)
    multicallable = mock.Mock(return_value=mock_call)

    wrapped_callable = grpc_helpers_async._wrap_stream_errors(multicallable)
    wrapped_call = await wrapped_callable()

    # Collect responses and check them after the loop. An assert inside the
    # loop body would be silently skipped if the wrapper raised before
    # yielding anything, letting a broken wrapper pass.
    responses = []
    with pytest.raises(TypeError) as exc_info:
        async for response in wrapped_call:
            responses.append(response)

    assert responses == [mock.sentinel.response]
    assert exc_info.value is non_grpc_error
async def test_wrap_stream_errors_aiter():
    """An RpcError raised mid-iteration surfaces as InvalidArgument.

    The underlying async iterator yields one response and then raises an
    INVALID_ARGUMENT gRPC error. The wrapped call must deliver that response
    first. It must then raise ``exceptions.InvalidArgument`` whose
    ``response`` is the original gRPC error.
    """
    grpc_error = RpcErrorImpl(grpc.StatusCode.INVALID_ARGUMENT)

    mock_call = mock.Mock(aio.StreamStreamCall, autospec=True)
    mocked_aiter = mock.Mock(spec=["__anext__"])
    mocked_aiter.__anext__ = mock.AsyncMock(
        side_effect=[mock.sentinel.response, grpc_error]
    )
    mock_call.__aiter__ = mock.Mock(return_value=mocked_aiter)
    multicallable = mock.Mock(return_value=mock_call)

    wrapped_callable = grpc_helpers_async._wrap_stream_errors(multicallable)
    wrapped_call = await wrapped_callable()

    # Collect responses and check them after the loop, so the test fails if
    # the wrapper raises before yielding the first response.
    responses = []
    with pytest.raises(exceptions.InvalidArgument) as exc_info:
        async for response in wrapped_call:
            responses.append(response)

    assert responses == [mock.sentinel.response]
    assert exc_info.value.response == grpc_error
async def test_wrap_stream_errors_write():
    """RpcErrors from write() and done_writing() surface as InvalidArgument."""
    rpc_error = RpcErrorImpl(grpc.StatusCode.INVALID_ARGUMENT)

    stream_call = mock.Mock(aio.StreamStreamCall, autospec=True)
    # Each coroutine succeeds on its first call and fails on its second.
    stream_call.write = mock.AsyncMock(side_effect=[None, rpc_error])
    stream_call.done_writing = mock.AsyncMock(side_effect=[None, rpc_error])
    factory = mock.Mock(return_value=stream_call)
    wrapper = grpc_helpers_async._wrap_stream_errors(factory)

    wrapped = await wrapper()

    # write(): the first call passes through, the second raises the mapped error.
    await wrapped.write(mock.sentinel.request)
    with pytest.raises(exceptions.InvalidArgument) as write_exc:
        await wrapped.write(mock.sentinel.request)
    assert stream_call.write.call_count == 2
    assert write_exc.value.response == rpc_error

    # done_writing(): same success-then-failure pattern.
    await wrapped.done_writing()
    with pytest.raises(exceptions.InvalidArgument) as done_exc:
        await wrapped.done_writing()
    assert stream_call.done_writing.call_count == 2
    assert done_exc.value.response == rpc_error