예제 #1
0
    def wrapper(*args, **kwargs):
        """Drive the async fixture ``func`` synchronously through ``run``.

        The backend name and options are taken from the injected fixture
        arguments. Each one is removed from ``kwargs`` again when it is listed
        in ``strip_argnames``, i.e. when it was only added for this wrapper.
        """
        run_kwargs = dict(
            backend=kwargs['anyio_backend_name'],
            backend_options=kwargs['anyio_backend_options'],
        )
        for injected in ('anyio_backend_name', 'anyio_backend_options'):
            if injected in strip_argnames:
                del kwargs[injected]

        if not isasyncgenfunction(func):
            # Coroutine fixture: run it to completion and hand over the result.
            yield run(partial(func, *args, **kwargs), **run_kwargs)
            return

        agen = func(*args, **kwargs)
        try:
            first_value = run(agen.__anext__, **run_kwargs)
        except StopAsyncIteration:
            raise RuntimeError('Async generator did not yield')

        yield first_value

        # Teardown: the generator has to be exhausted after its single yield.
        try:
            run(agen.__anext__, **run_kwargs)
        except StopAsyncIteration:
            return
        run(agen.aclose, **run_kwargs)
        raise RuntimeError('Async generator fixture did not stop')
예제 #2
0
def sniff_options(obj):
    """Collect decorator-like properties of *obj* along its ``__wrapped__`` chain.

    Every link of the chain is inspected, and the recognised traits
    ("abstractmethod", "classmethod", "staticmethod", "property", "async")
    are accumulated into the returned set.

    If any link sniffs as an async generator function, "async" is dropped
    again: an async_generator-style async generator wraps a coroutine
    function, but is not itself a coroutine function.
    """
    found = set()
    saw_async_gen = False
    # isinstance() targets checked on every link, with the label they add.
    kind_labels = (
        (classmethod, "classmethod"),
        (staticmethod, "staticmethod"),
        (property, "property"),
    )
    link = obj
    while True:
        if getattr(link, "__isabstractmethod__", False):
            found.add("abstractmethod")
        for kind, label in kind_labels:
            if isinstance(link, kind):
                found.add(label)
        if inspect.iscoroutinefunction(link):
            found.add("async")
        if async_generator.isasyncgenfunction(link):
            saw_async_gen = True
        if not hasattr(link, "__wrapped__"):
            break
        link = link.__wrapped__

    if saw_async_gen:
        found.discard("async")
    return found
예제 #3
0
def pytest_fixture_setup(fixturedef, request):
    """Hook wrapper that runs coroutine / async generator fixtures with anyio.

    If the fixture function is an async generator function or a coroutine
    function, ``fixturedef.func`` is replaced by a synchronous generator that
    drives it with ``anyio.run`` on the backend provided by the
    ``anyio_backend`` fixture. ``anyio_backend`` is appended to the fixture's
    argnames when missing. In that case it is removed from the keyword
    arguments again before the original function is called.

    The wrapper raises :exc:`RuntimeError` if an async generator fixture does
    not yield, or yields more than once.
    """
    def wrapper(*args, **kwargs):
        backend = kwargs['anyio_backend']
        if strip_backend:
            del kwargs['anyio_backend']

        if isasyncgenfunction(func):
            gen = func(*args, **kwargs)
            try:
                value = anyio.run(gen.__anext__, backend=backend)
            except StopAsyncIteration:
                raise RuntimeError('Async generator did not yield')

            yield value

            try:
                anyio.run(gen.__anext__, backend=backend)
            except StopAsyncIteration:
                pass
            else:
                # Close the generator on the same backend it ran on; without
                # ``backend`` anyio.run would use its default backend instead.
                anyio.run(gen.aclose, backend=backend)
                raise RuntimeError('Async generator fixture did not stop')
        else:
            yield anyio.run(partial(func, *args, **kwargs), backend=backend)

    func = fixturedef.func
    if isasyncgenfunction(func) or iscoroutinefunction(func):
        # Request anyio_backend even if the fixture does not ask for it, and
        # remember to hide it from the original function in that case.
        strip_backend = False
        if 'anyio_backend' not in fixturedef.argnames:
            fixturedef.argnames += ('anyio_backend', )
            strip_backend = True

        fixturedef.func = wrapper

    yield
예제 #4
0
def _syncify(*types, loop, thread_ident):
    """Patch the public async methods of every class in *types* for sync use.

    Coroutine functions are wrapped with ``_syncify_wrap``'s default wrapper,
    and async generator functions with ``_SyncGen``. ``__call__`` is the only
    underscore-prefixed name that is considered. If an attribute carries a
    ``'__tl.sync'`` attribute, that object is inspected instead.
    """
    for cls in types:
        # __enter__ and __exit__ need special care (VERY dirty hack).
        #
        # Normally they raise while the loop is running, because the user
        # cannot await there and must use ``async with`` instead. That check
        # also fires with full_sync enabled (the loop *is* running), so the
        # sync names are patched with the async variants.
        if hasattr(cls, '__aenter__'):
            for async_name, sync_name in (('__aenter__', '__enter__'),
                                          ('__aexit__', '__exit__')):
                _syncify_wrap(cls, async_name, loop, thread_ident,
                              rename=sync_name)

        for attr in dir(cls):
            if attr.startswith('_') and attr != '__call__':
                continue
            target = getattr(cls, attr)
            target = getattr(target, '__tl.sync', target)
            if inspect.iscoroutinefunction(target):
                _syncify_wrap(cls, attr, loop, thread_ident)
            elif isasyncgenfunction(target):
                _syncify_wrap(cls, attr, loop, thread_ident, _SyncGen)
예제 #5
0
파일: context.py 프로젝트: vdt/asphalt
def context_teardown(func: Callable):
    """
    Wrap an async generator function to execute the rest of the function at context teardown.

    This function returns an async function, which, when called, starts the wrapped async
    generator. The wrapped async function is run until the first ``yield`` statement
    (``await async_generator.yield_()`` on Python 3.5). When the context is being torn down, the
    exception that ended the context, if any, is sent to the generator.

    For example::

        class SomeComponent(Component):
            @context_teardown
            async def start(self, ctx: Context):
                service = SomeService()
                ctx.add_resource(service)
                exception = yield
                service.stop()

    :param func: an async generator function, or a coroutine function using
        ``async_generator.yield_()`` (converted with ``async_generator``)
    :return: an async function
    :raises TypeError: if ``func`` is neither an async generator function nor a coroutine
        function

    The returned async function raises :exc:`RuntimeError` if none of its first two positional
    arguments is a :class:`Context`, or if the generator finishes without yielding.

    """
    @wraps(func)
    async def wrapper(*args, **kwargs) -> None:
        # Teardown phase: resume the generator with the exception that ended the
        # context (or None), then close it however that resumption ended.
        async def teardown_callback(exception: Optional[Exception]):
            try:
                await generator.asend(exception)
            except StopAsyncIteration:
                pass
            finally:
                await generator.aclose()

        # The Context is searched among the first two positional arguments so that
        # methods such as ``start(self, ctx)`` work as well as plain functions.
        try:
            ctx = next(arg for arg in args[:2] if isinstance(arg, Context))
        except StopIteration:
            raise RuntimeError(
                'the first positional argument to {}() has to be a Context '
                'instance'.format(callable_name(func))) from None

        # Setup phase: run the generator up to its first yield.
        generator = func(*args, **kwargs)
        try:
            await generator.asend(None)
        except StopAsyncIteration:
            raise RuntimeError('{} did not do "await yield_()"'.format(
                callable_name(func))) from None
        except BaseException:
            # Setup failed: close the generator before propagating the error.
            await generator.aclose()
            raise
        else:
            # NOTE(review): the second argument presumably asks the context to pass the
            # teardown exception to the callback — confirm against add_teardown_callback.
            ctx.add_teardown_callback(teardown_callback, True)

    # ``wrapper`` looks ``func`` up in this enclosing scope at call time, so rebinding it
    # here (after @wraps has copied the original function's metadata) makes calls use
    # the converted async generator function.
    if iscoroutinefunction(func):
        func = async_generator(func)
    elif not isasyncgenfunction(func):
        raise TypeError('{} must be an async generator function'.format(
            callable_name(func)))

    return wrapper
예제 #6
0
async def test_agen_protection():
    """Check KI-protection decorators on ``@async_generator``-style async generators.

    ``enable_ki_protection`` and ``disable_ki_protection`` are each applied once
    outside and once inside ``@async_generator``. Each generator body asserts the
    expected protection state before and after its single ``yield_()``. The
    consuming code asserts that it is never protected itself.
    """
    # Protection decorator applied on top of @async_generator:
    @_core.enable_ki_protection
    @async_generator
    async def agen_protected1():
        assert _core.currently_ki_protected()
        try:
            await yield_()
        finally:
            assert _core.currently_ki_protected()

    @_core.disable_ki_protection
    @async_generator
    async def agen_unprotected1():
        assert not _core.currently_ki_protected()
        try:
            await yield_()
        finally:
            assert not _core.currently_ki_protected()

    # Swap the order of the decorators:
    @async_generator
    @_core.enable_ki_protection
    async def agen_protected2():
        assert _core.currently_ki_protected()
        try:
            await yield_()
        finally:
            assert _core.currently_ki_protected()

    @async_generator
    @_core.disable_ki_protection
    async def agen_unprotected2():
        assert not _core.currently_ki_protected()
        try:
            await yield_()
        finally:
            assert not _core.currently_ki_protected()

    for agen_fn in [
        agen_protected1,
        agen_protected2,
        agen_unprotected1,
        agen_unprotected2,
    ]:
        # Iterating runs each body through its single yield; the consumer's own
        # frame must stay unprotected regardless of the generator's setting.
        async for _ in agen_fn():
            assert not _core.currently_ki_protected()

        # asynccontextmanager insists that the function passed must itself be an
        # async gen function, not a wrapper around one
        if isasyncgenfunction(agen_fn):
            async with asynccontextmanager(agen_fn)():
                assert not _core.currently_ki_protected()

            # Another case that's tricky due to:
            #   https://bugs.python.org/issue29590
            with pytest.raises(KeyError):
                async with asynccontextmanager(agen_fn)():
                    raise KeyError
예제 #7
0
async def test_agen_protection():
    """Check KI-protection decorators on ``@async_generator``-style async generators.

    ``enable_ki_protection`` and ``disable_ki_protection`` are each applied once
    outside and once inside ``@async_generator``. Each generator body asserts the
    expected protection state before and after its single ``yield_()``. The
    consuming code asserts that it is never protected itself.
    """
    # Protection decorator applied on top of @async_generator:
    @_core.enable_ki_protection
    @async_generator
    async def agen_protected1():
        assert _core.currently_ki_protected()
        try:
            await yield_()
        finally:
            assert _core.currently_ki_protected()

    @_core.disable_ki_protection
    @async_generator
    async def agen_unprotected1():
        assert not _core.currently_ki_protected()
        try:
            await yield_()
        finally:
            assert not _core.currently_ki_protected()

    # Swap the order of the decorators:
    @async_generator
    @_core.enable_ki_protection
    async def agen_protected2():
        assert _core.currently_ki_protected()
        try:
            await yield_()
        finally:
            assert _core.currently_ki_protected()

    @async_generator
    @_core.disable_ki_protection
    async def agen_unprotected2():
        assert not _core.currently_ki_protected()
        try:
            await yield_()
        finally:
            assert not _core.currently_ki_protected()

    for agen_fn in [
            agen_protected1,
            agen_protected2,
            agen_unprotected1,
            agen_unprotected2,
    ]:
        # Iterating runs each body through its single yield; the consumer's own
        # frame must stay unprotected regardless of the generator's setting.
        async for _ in agen_fn():  # noqa
            assert not _core.currently_ki_protected()

        # asynccontextmanager insists that the function passed must itself be an
        # async gen function, not a wrapper around one
        if isasyncgenfunction(agen_fn):
            async with asynccontextmanager(agen_fn)():
                assert not _core.currently_ki_protected()

            # Another case that's tricky due to:
            #   https://bugs.python.org/issue29590
            with pytest.raises(KeyError):
                async with asynccontextmanager(agen_fn)():
                    raise KeyError
예제 #8
0
def _syncify(*types, loop, thread_ident):
    """Wrap the public async methods of each class in *types* for sync use.

    Coroutine functions are wrapped with ``_sync_result`` and async generator
    functions with ``_SyncGen``, both through ``_syncify_wrap``. ``__call__``
    is the only underscore-prefixed name that is considered.
    """
    for cls in types:
        for method_name in dir(cls):
            is_public = not method_name.startswith('_')
            if not (is_public or method_name == '__call__'):
                continue
            member = getattr(cls, method_name)
            if inspect.iscoroutinefunction(member):
                _syncify_wrap(cls, method_name, loop, thread_ident,
                              _sync_result)
            elif isasyncgenfunction(member):
                _syncify_wrap(cls, method_name, loop, thread_ident, _SyncGen)
예제 #9
0
def _is_trio_fixture(func, coerce_async, kwargs):
    """Decide whether *func* has to be handled as a Trio fixture.

    True when *func* is explicitly flagged with ``_force_trio_fixture``, when
    async coercion is on and *func* is a coroutine or async generator
    function, or when any of the already-resolved arguments in *kwargs* is a
    ``TrioFixture``.
    """
    if getattr(func, "_force_trio_fixture", False):
        return True
    is_coerced_async = coerce_async and (
        iscoroutinefunction(func) or isasyncgenfunction(func)
    )
    if is_coerced_async:
        return True
    return any(isinstance(dep, TrioFixture) for dep in kwargs.values())
예제 #10
0
def context_teardown(func: Callable):
    """
    Wrap an async generator function to execute the rest of the function at context teardown.

    This function returns an async function, which, when called, starts the wrapped async
    generator. The wrapped async function is run until the first ``yield`` statement
    (``await async_generator.yield_()`` on Python 3.5). When the context is being torn down, the
    exception that ended the context, if any, is sent to the generator.

    For example::

        class SomeComponent(Component):
            @context_teardown
            async def start(self, ctx: Context):
                service = SomeService()
                ctx.add_resource(service)
                exception = yield
                service.stop()

    :param func: an async generator function, or a coroutine function using
        ``async_generator.yield_()`` (converted with ``async_generator``)
    :return: an async function
    :raises TypeError: if ``func`` is neither an async generator function nor a coroutine
        function

    The returned async function raises :exc:`RuntimeError` if none of its first two positional
    arguments is a :class:`Context`. A generator that finishes without yielding gets no
    teardown callback.

    """
    @wraps(func)
    async def wrapper(*args, **kwargs) -> None:
        # Teardown phase: resume the generator with the exception that ended the
        # context (or None), then close it however that resumption ended.
        async def teardown_callback(exception: Optional[Exception]):
            try:
                await generator.asend(exception)
            except StopAsyncIteration:
                pass
            finally:
                await generator.aclose()

        # The Context is searched among the first two positional arguments so that
        # methods such as ``start(self, ctx)`` work as well as plain functions.
        try:
            ctx = next(arg for arg in args[:2] if isinstance(arg, Context))
        except StopIteration:
            raise RuntimeError('the first positional argument to {}() has to be a Context '
                               'instance'.format(callable_name(func))) from None

        # Setup phase: run the generator up to its first yield.
        generator = func(*args, **kwargs)
        try:
            await generator.asend(None)
        except StopAsyncIteration:
            # Finished without yielding: nothing left to run at teardown.
            pass
        except BaseException:
            # Setup failed: close the generator before propagating the error.
            await generator.aclose()
            raise
        else:
            # NOTE(review): the second argument presumably asks the context to pass the
            # teardown exception to the callback — confirm against add_teardown_callback.
            ctx.add_teardown_callback(teardown_callback, True)

    # ``wrapper`` looks ``func`` up in this enclosing scope at call time, so rebinding it
    # here (after @wraps has copied the original function's metadata) makes calls use
    # the converted async generator function.
    if iscoroutinefunction(func):
        func = async_generator(func)
    elif not isasyncgenfunction(func):
        raise TypeError('{} must be an async generator function'.format(callable_name(func)))

    return wrapper
예제 #11
0
def _syncify(*types, loop, thread_ident):
    """Wrap the public async methods of each class in *types* for sync use.

    Coroutine functions get ``_syncify_wrap``'s default wrapper, and async
    generator functions are wrapped with ``_SyncGen``. ``__call__`` is the
    only underscore-prefixed name that is considered.
    """
    def _is_candidate(attr_name):
        return not attr_name.startswith('_') or attr_name == '__call__'

    for cls in types:
        for attr_name in filter(_is_candidate, dir(cls)):
            candidate = getattr(cls, attr_name)
            # Prefer the object stored under '__tl.sync', if any (presumably the
            # async original behind an earlier wrapper — TODO confirm).
            candidate = getattr(candidate, '__tl.sync', candidate)
            if inspect.iscoroutinefunction(candidate):
                _syncify_wrap(cls, attr_name, loop, thread_ident)
            elif isasyncgenfunction(candidate):
                _syncify_wrap(cls, attr_name, loop, thread_ident, _SyncGen)
예제 #12
0
파일: _util.py 프로젝트: zed/trio
def acontextmanager(func):
    """Like @contextmanager, but async.

    :param func: an async generator function (native, or created with
        ``async_generator``)
    :return: a function that returns an ``_AsyncGeneratorContextManager``
        for ``func`` and the given arguments
    :raises TypeError: if ``func`` is not an async generator function
    """
    if not async_generator.isasyncgenfunction(func):
        raise TypeError(
            "must be an async generator (native or from async_generator; "
            "if using @async_generator then @acontextmanager must be on top).")
    @wraps(func)
    def helper(*args, **kwds):
        return _AsyncGeneratorContextManager(func, args, kwds)
    return helper
예제 #13
0
    def decorator(fn):
        """Wrap *fn* so that a frame carries the KI-protection flag.

        ``LOCALS_KEY_KI_PROTECTION_ENABLED`` is set to ``enabled`` (from the
        enclosing scope). For coroutine, generator and async generator
        functions it goes into the frame of the object *fn* returns. For
        anything else it goes into the wrapper's own frame.
        """
        # In some versions of Python, isgeneratorfunction returns true for
        # coroutine functions, so we have to check for coroutine functions
        # first.
        if inspect.iscoroutinefunction(fn):

            @wraps(fn)
            def wrapper(*args, **kwargs):
                # See the comment for regular generators below
                coro = fn(*args, **kwargs)
                coro.cr_frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED
                                       ] = enabled
                return coro

            return wrapper
        elif inspect.isgeneratorfunction(fn):

            @wraps(fn)
            def wrapper(*args, **kwargs):
                # It's important that we inject this directly into the
                # generator's locals, as opposed to setting it here and then
                # doing 'yield from'. The reason is, if a generator is
                # throw()n into, then it may magically pop to the top of the
                # stack. And @contextmanager generators in particular are a
                # case where we often want KI protection, and which are often
                # thrown into! See:
                #     https://bugs.python.org/issue29590
                gen = fn(*args, **kwargs)
                gen.gi_frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED
                                      ] = enabled
                return gen

            return wrapper
        elif async_generator.isasyncgenfunction(fn):
            # Native async generators or ones built with the async_generator
            # library; the flag goes into the generator object's ag_frame.

            @wraps(fn)
            def wrapper(*args, **kwargs):
                # See the comment for regular generators above
                agen = fn(*args, **kwargs)
                agen.ag_frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED
                                       ] = enabled
                return agen

            return wrapper
        else:

            @wraps(fn)
            def wrapper(*args, **kwargs):
                # Plain callable: mark this wrapper's frame, which is on the
                # stack for the whole call to fn.
                # NOTE(review): relies on locals() returning the frame's own
                # f_locals dict (CPython < 3.13). Under PEP 667 (3.13+) locals()
                # in a function returns an independent snapshot, so this write
                # may not be visible through frame.f_locals — confirm.
                locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = enabled
                return fn(*args, **kwargs)

            return wrapper
예제 #14
0
    def decorator(fn):
        """Wrap *fn* so that a frame carries the KI-protection flag.

        ``LOCALS_KEY_KI_PROTECTION_ENABLED`` is set to ``enabled`` (from the
        enclosing scope). For coroutine, generator and async generator
        functions it goes into the frame of the object *fn* returns. For
        anything else it goes into the wrapper's own frame.
        """
        # In some versions of Python, isgeneratorfunction returns true for
        # coroutine functions, so we have to check for coroutine functions
        # first.
        if inspect.iscoroutinefunction(fn):

            @wraps(fn)
            def wrapper(*args, **kwargs):
                # See the comment for regular generators below
                coro = fn(*args, **kwargs)
                coro.cr_frame.f_locals[
                    LOCALS_KEY_KI_PROTECTION_ENABLED] = enabled
                return coro

            return wrapper
        elif inspect.isgeneratorfunction(fn):

            @wraps(fn)
            def wrapper(*args, **kwargs):
                # It's important that we inject this directly into the
                # generator's locals, as opposed to setting it here and then
                # doing 'yield from'. The reason is, if a generator is
                # throw()n into, then it may magically pop to the top of the
                # stack. And @contextmanager generators in particular are a
                # case where we often want KI protection, and which are often
                # thrown into! See:
                #     https://bugs.python.org/issue29590
                gen = fn(*args, **kwargs)
                gen.gi_frame.f_locals[
                    LOCALS_KEY_KI_PROTECTION_ENABLED] = enabled
                return gen

            return wrapper
        elif async_generator.isasyncgenfunction(fn):
            # Native async generators or ones built with the async_generator
            # library; the flag goes into the generator object's ag_frame.

            @wraps(fn)
            def wrapper(*args, **kwargs):
                # See the comment for regular generators above
                agen = fn(*args, **kwargs)
                agen.ag_frame.f_locals[
                    LOCALS_KEY_KI_PROTECTION_ENABLED] = enabled
                return agen

            return wrapper
        else:

            @wraps(fn)
            def wrapper(*args, **kwargs):
                # Plain callable: mark this wrapper's frame, which is on the
                # stack for the whole call to fn.
                # NOTE(review): relies on locals() returning the frame's own
                # f_locals dict (CPython < 3.13). Under PEP 667 (3.13+) locals()
                # in a function returns an independent snapshot, so this write
                # may not be visible through frame.f_locals — confirm.
                locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = enabled
                return fn(*args, **kwargs)

            return wrapper
예제 #15
0
파일: sync.py 프로젝트: mrt-kousha/Telethon
def syncify(*types):
    """
    Converts all the methods in the given types (class definitions)
    into synchronous, which return either the coroutine or the result
    based on whether ``asyncio's`` event loop is running.

    Only public names (plus ``__call__``) are considered. Coroutine functions
    are handled by ``_syncify_coro`` and async generator functions by
    ``_syncify_gen``.
    """
    for cls in types:
        candidates = [name for name in dir(cls)
                      if not name.startswith('_') or name == '__call__']
        for name in candidates:
            if inspect.iscoroutinefunction(getattr(cls, name)):
                _syncify_coro(cls, name)
            elif isasyncgenfunction(getattr(cls, name)):
                _syncify_gen(cls, name)
예제 #16
0
def syncify(*types):
    """
    Converts all the methods in the given types (class definitions)
    into synchronous, which return either the coroutine or the result
    based on whether ``asyncio's`` event loop is running.

    The event loop is fetched once, up front. Coroutine methods are wrapped
    with its ``run_until_complete``, and async generator methods with a
    ``_SyncGen`` bound to that loop. Only public names (plus ``__call__``)
    are considered.
    """
    loop = asyncio.get_event_loop()
    for cls in types:
        for name in dir(cls):
            if name.startswith('_') and name != '__call__':
                continue
            member = getattr(cls, name)
            if inspect.iscoroutinefunction(member):
                _syncify_wrap(cls, name, loop.run_until_complete)
            elif isasyncgenfunction(member):
                _syncify_wrap(cls, name, functools.partial(_SyncGen, loop))
def async_contextmanager(func: Callable[..., generator_types]) -> Callable:
    """
    Transform a coroutine function into something that works with ``async with``.

    This is an asynchronous counterpart to :func:`~contextlib.contextmanager`.
    The wrapped function can either be a native async generator function (``async def`` with
    ``yield``) or, if your code needs to be compatible with Python 3.5, you can use
    :func:`~async_generator.yield_` instead of the native ``yield`` statement.

    The generator must yield *exactly once*, just like with :func:`~contextlib.contextmanager`.

    Usage in Python 3.5 and earlier::

        @async_contextmanager
        async def mycontextmanager(arg):
            context = await setup_remote_context(arg)
            await yield_(context)
            await context.teardown()

        async def frobnicate(arg):
            async with mycontextmanager(arg) as context:
                do_something_with(context)

    The same context manager function in Python 3.6+::

        @async_contextmanager
        async def mycontextmanager(arg):
            context = await setup_remote_context(arg)
            yield context
            await context.teardown()

    :param func: an async generator function or a coroutine function using
        :func:`~async_generator.yield_`
    :return: a callable that can be used with ``async with``
    :raises TypeError: if ``func`` is neither an async generator function nor a coroutine
        function

    """
    if not isasyncgenfunction(func):
        if iscoroutinefunction(func):
            # Python 3.5 style: a coroutine function using ``await yield_()``.
            func = async_generator(func)
        else:
            # Reject anything else here instead of silently accepting it.
            raise TypeError('"func" must be an async generator function or a coroutine function')

    @wraps(func)
    def wrapper(*args, **kwargs):
        generator = func(*args, **kwargs)
        return _AsyncContextManager(generator)

    return wrapper
예제 #18
0
def async_contextmanager(func: Callable[..., generator_types]) -> Callable:
    """
    Transform a coroutine function into something that works with ``async with``.

    This is an asynchronous counterpart to :func:`~contextlib.contextmanager`.
    The wrapped function can either be a native async generator function (``async def`` with
    ``yield``) or, if your code needs to be compatible with Python 3.5, you can use
    :func:`~async_generator.yield_` instead of the native ``yield`` statement.

    The generator must yield *exactly once*, just like with :func:`~contextlib.contextmanager`.

    Usage in Python 3.5 and earlier::

        @async_contextmanager
        async def mycontextmanager(arg):
            context = await setup_remote_context(arg)
            await yield_(context)
            await context.teardown()

        async def frobnicate(arg):
            async with mycontextmanager(arg) as context:
                do_something_with(context)

    The same context manager function in Python 3.6+::

        @async_contextmanager
        async def mycontextmanager(arg):
            context = await setup_remote_context(arg)
            yield context
            await context.teardown()

    :param func: an async generator function or a coroutine function using
        :func:`~async_generator.yield_`
    :return: a callable that can be used with ``async with``
    :raises TypeError: if ``func`` is neither an async generator function nor a coroutine
        function

    """
    if not isasyncgenfunction(func):
        if iscoroutinefunction(func):
            # Python 3.5 style: a coroutine function using ``await yield_()``.
            func = async_generator(func)
        else:
            # Reject anything else here instead of silently accepting it.
            raise TypeError('"func" must be an async generator function or a coroutine function')

    @wraps(func)
    def wrapper(*args, **kwargs):
        generator = func(*args, **kwargs)
        return _AsyncContextManager(generator)

    return wrapper
예제 #19
0
def _install_async_fixture_if_needed(fixturedef, request):
    """Create and cache an async-aware fixture object for *fixturedef* if needed.

    All of the fixture's dependencies are resolved through *request* first.
    The fixture object created depends on the fixture function:

    * coroutine function: ``AsyncFixture``;
    * async generator function: ``AsyncYieldFixture``;
    * a synchronous function with at least one ``BaseAsyncFixture``
      dependency: ``SyncYieldFixtureWithAsyncDeps`` for generator functions,
      ``SyncFixtureWithAsyncDeps`` otherwise.

    A created object is stored in ``fixturedef.cached_result``.

    :return: the created fixture object, or ``None`` if none was needed
    """
    asyncfix = None
    deps = {dep: request.getfixturevalue(dep) for dep in fixturedef.argnames}
    if iscoroutinefunction(fixturedef.func):
        asyncfix = AsyncFixture(fixturedef, deps)
    elif isasyncgenfunction(fixturedef.func):
        asyncfix = AsyncYieldFixture(fixturedef, deps)
    elif any(isinstance(dep, BaseAsyncFixture) for dep in deps.values()):
        # Test the isinstance() result itself rather than the truthiness of the
        # matching dependency objects, so a falsy async dependency still counts.
        if isgeneratorfunction(fixturedef.func):
            asyncfix = SyncYieldFixtureWithAsyncDeps(fixturedef, deps)
        else:
            asyncfix = SyncFixtureWithAsyncDeps(fixturedef, deps)
    # Compare with None: a fixture object that happens to be falsy must still
    # be cached and returned.
    if asyncfix is not None:
        fixturedef.cached_result = (asyncfix, request.param_index, None)
        return asyncfix
    return None
예제 #20
0
def trio2aio(proc):
    """Adapt the callable *proc* so that it is invoked through ``trio_asyncio``.

    For an async generator function, the result is a plain function returning
    ``trio_asyncio.wrap_generator(...)``. For anything else, it is an async
    function awaiting ``trio_asyncio.run_asyncio(...)``. Keyword arguments
    are bound with ``functools.partial`` first; only positional arguments are
    handed to the trio_asyncio helpers directly.
    """
    def _with_kwargs(kwargs):
        # Bind keyword arguments up front (no partial needed when there are none).
        return partial(proc, **kwargs) if kwargs else proc

    if isasyncgenfunction(proc):
        @wraps(proc)
        def call(*args, **kwargs):
            return trio_asyncio.wrap_generator(_with_kwargs(kwargs), *args)

        return call

    @wraps(proc)
    async def call(*args, **kwargs):
        return await trio_asyncio.run_asyncio(_with_kwargs(kwargs), *args)

    return call
예제 #21
0
def sniff_options(obj):
    """Collect the usage options of *obj* by walking its wrapper chain.

    The walk follows ``__wrapped__`` links, or ``__func__`` (staticmethod /
    classmethod) when there is no ``__wrapped__``. "abstractmethod",
    "classmethod" and "staticmethod" are recorded wherever they are seen.
    The calling-style options ("async", "for", "async-for", "with",
    "async-with") are only sniffed on links visited while ``options`` shares
    nothing with ``EXCLUSIVE_OPTIONS``.
    """
    options = set()
    node = obj
    while True:
        if getattr(node, "__isabstractmethod__", False):
            options.add("abstractmethod")
        if isinstance(node, classmethod):
            options.add("classmethod")
        if isinstance(node, staticmethod):
            options.add("staticmethod")

        if not (options & EXCLUSIVE_OPTIONS):
            # Coroutine check first: on some Python versions
            # isgeneratorfunction() is also true for coroutine functions.
            if inspect.iscoroutinefunction(node):
                options.add("async")
            elif inspect.isgeneratorfunction(node):
                options.add("for")
            if isasyncgenfunction(node):
                options.add("async-for")
            # Heuristics for callables that return (async) context managers.
            code = getattr(node, "__code__", None)
            if code in CM_CODES or getattr(node, "__returns_contextmanager__", False):
                options.add("with")
            if code in ACM_CODES or getattr(node, "__returns_acontextmanager__", False):
                options.add("async-with")

        if hasattr(node, "__wrapped__"):
            node = node.__wrapped__
        elif hasattr(node, "__func__"):  # for staticmethod & classmethod
            node = node.__func__
        else:
            return options
예제 #22
0
    def wrapper(*args, **kwargs):
        """Drive the async fixture ``func`` synchronously on ``anyio_backend``.

        ``anyio_backend`` is removed from ``kwargs`` again when
        ``strip_backend`` says it was only injected for this wrapper.
        """
        backend = kwargs['anyio_backend']
        if strip_backend:
            kwargs.pop('anyio_backend')

        if not isasyncgenfunction(func):
            # Coroutine fixture: run it to completion and hand over the result.
            yield run(partial(func, *args, **kwargs), backend=backend)
            return

        agen = func(*args, **kwargs)
        try:
            fixture_value = run(agen.__anext__, backend=backend)
        except StopAsyncIteration:
            raise RuntimeError('Async generator did not yield')

        yield fixture_value

        # Teardown: the generator has to be exhausted after its single yield.
        try:
            run(agen.__anext__, backend=backend)
        except StopAsyncIteration:
            return
        run(agen.aclose, backend=backend)
        raise RuntimeError('Async generator fixture did not stop')
예제 #23
0
def pytest_fixture_setup(fixturedef, request):
    """Run coroutine / async generator fixtures of ``anyio`` tests through ``run``.

    Applies only when the fixture function is async and ``'anyio'`` is among
    the request's keywords. ``fixturedef.func`` is then replaced by a
    synchronous generator wrapper that reads the backend name and options
    from the ``anyio_backend_name`` / ``anyio_backend_options`` fixtures.
    Whichever of those two the fixture did not already request are appended
    to its argnames and hidden from the original function again.
    """
    injected = ('anyio_backend_name', 'anyio_backend_options')

    def wrapper(*args, **kwargs):
        run_kwargs = {
            'backend': kwargs['anyio_backend_name'],
            'backend_options': kwargs['anyio_backend_options'],
        }
        for argname in injected:
            if argname in strip_argnames:
                del kwargs[argname]

        if not isasyncgenfunction(func):
            yield run(partial(func, *args, **kwargs), **run_kwargs)
            return

        gen = func(*args, **kwargs)
        try:
            value = run(gen.__anext__, **run_kwargs)
        except StopAsyncIteration:
            raise RuntimeError('Async generator did not yield')

        yield value

        # The fixture must be exhausted after its single yield.
        try:
            run(gen.__anext__, **run_kwargs)
        except StopAsyncIteration:
            return
        run(gen.aclose, **run_kwargs)
        raise RuntimeError('Async generator fixture did not stop')

    func = fixturedef.func
    is_async_fixture = isasyncgenfunction(func) or iscoroutinefunction(func)
    if not (is_async_fixture and 'anyio' in request.keywords):
        return

    strip_argnames = [name for name in injected if name not in fixturedef.argnames]
    if strip_argnames:
        fixturedef.argnames += tuple(strip_argnames)
    fixturedef.func = wrapper
예제 #24
0
파일: plugin.py 프로젝트: eocanha/webkit
def pytest_fixture_setup(fixturedef, request):
    """Adjust the event loop policy when an event loop is produced.

    For the ``event_loop`` fixture, the produced loop is installed as the
    policy's current event loop. Otherwise, async fixture functions are
    wrapped so synchronous pytest can use them:

    * an async generator function runs up to its first ``yield`` on the
      event loop, and a finalizer registered on the request resumes it once
      more, raising :exc:`ValueError` if it yields again;
    * a coroutine function runs to completion on the event loop.

    The event loop and request are obtained through ``FixtureStripper``.
    NOTE(review): it presumably adds those fixtures to the fixture's
    arguments and strips them before the original function is called —
    confirm against FixtureStripper.
    """
    if fixturedef.argname == "event_loop":
        outcome = yield
        loop = outcome.get_result()
        policy = asyncio.get_event_loop_policy()
        policy.set_event_loop(loop)
        return

    if isasyncgenfunction(fixturedef.func):
        # This is an async generator function. Wrap it accordingly.
        generator = fixturedef.func

        fixture_stripper = FixtureStripper(fixturedef)
        fixture_stripper.add(FixtureStripper.EVENT_LOOP)
        fixture_stripper.add(FixtureStripper.REQUEST)

        def wrapper(*args, **kwargs):
            loop = fixture_stripper.get_and_strip_from(FixtureStripper.EVENT_LOOP, kwargs)
            request = fixture_stripper.get_and_strip_from(FixtureStripper.REQUEST, kwargs)

            gen_obj = generator(*args, **kwargs)

            async def setup():
                res = await gen_obj.__anext__()
                return res

            def finalizer():
                """Yield again, to finalize."""
                async def async_finalizer():
                    try:
                        await gen_obj.__anext__()
                    except StopAsyncIteration:
                        pass
                    else:
                        # A single literal keeps the two sentences separated
                        # (the concatenated form produced "stop.Yield").
                        msg = "Async generator fixture didn't stop. Yield only once."
                        raise ValueError(msg)
                loop.run_until_complete(async_finalizer())

            request.addfinalizer(finalizer)
            return loop.run_until_complete(setup())

        fixturedef.func = wrapper
    elif inspect.iscoroutinefunction(fixturedef.func):
        coro = fixturedef.func

        fixture_stripper = FixtureStripper(fixturedef)
        fixture_stripper.add(FixtureStripper.EVENT_LOOP)

        def wrapper(*args, **kwargs):
            loop = fixture_stripper.get_and_strip_from(FixtureStripper.EVENT_LOOP, kwargs)

            async def setup():
                res = await coro(*args, **kwargs)
                return res

            return loop.run_until_complete(setup())

        fixturedef.func = wrapper
    yield
예제 #25
0
def pytest_fixture_setup(fixturedef, request):
    """Adjust the event loop policy when an event loop is produced.

    Before the fixture runs, async fixture functions are wrapped so that
    synchronous pytest can use them:

    * an async generator function runs up to its first ``yield`` on the
      ``event_loop`` fixture. A finalizer registered through the ``request``
      fixture then resumes it once more and raises :exc:`ValueError` if it
      yields again;
    * a coroutine function runs to completion on the ``event_loop`` fixture.

    ``event_loop`` / ``request`` are appended to the fixture's argnames when
    missing, and removed from the keyword arguments again before the original
    function is called.

    After the fixture has run, if it is the ``event_loop`` fixture and the
    request carries the ``asyncio`` keyword, the produced loop is installed on
    the event loop policy for every keyword of ``_markers_2_fixtures`` present
    on the request. A finalizer restores the previously current loop (or
    ``None``).
    """
    if isasyncgenfunction(fixturedef.func):
        # This is an async generator function. Wrap it accordingly.
        f = fixturedef.func

        strip_event_loop = False
        if 'event_loop' not in fixturedef.argnames:
            fixturedef.argnames += ('event_loop', )
            strip_event_loop = True
        strip_request = False
        if 'request' not in fixturedef.argnames:
            fixturedef.argnames += ('request', )
            strip_request = True

        def wrapper(*args, **kwargs):
            loop = kwargs['event_loop']
            request = kwargs['request']
            if strip_event_loop:
                del kwargs['event_loop']
            if strip_request:
                del kwargs['request']

            gen_obj = f(*args, **kwargs)

            async def setup():
                res = await gen_obj.__anext__()
                return res

            def finalizer():
                """Yield again, to finalize."""
                async def async_finalizer():
                    try:
                        await gen_obj.__anext__()
                    except StopAsyncIteration:
                        pass
                    else:
                        # A single literal keeps the two sentences separated
                        # (the concatenated form produced "stop.Yield").
                        msg = "Async generator fixture didn't stop. Yield only once."
                        raise ValueError(msg)

                loop.run_until_complete(async_finalizer())

            request.addfinalizer(finalizer)

            return loop.run_until_complete(setup())

        fixturedef.func = wrapper

    elif inspect.iscoroutinefunction(fixturedef.func):
        # Just a coroutine, not an async generator.
        f = fixturedef.func

        strip_event_loop = False
        if 'event_loop' not in fixturedef.argnames:
            fixturedef.argnames += ('event_loop', )
            strip_event_loop = True

        def wrapper(*args, **kwargs):
            loop = kwargs['event_loop']
            if strip_event_loop:
                del kwargs['event_loop']

            async def setup():
                res = await f(*args, **kwargs)
                return res

            return loop.run_until_complete(setup())

        fixturedef.func = wrapper

    outcome = yield

    if fixturedef.argname == "event_loop" and 'asyncio' in request.keywords:
        loop = outcome.get_result()
        for kw in _markers_2_fixtures.keys():
            if kw not in request.keywords:
                continue
            policy = asyncio.get_event_loop_policy()
            try:
                old_loop = policy.get_event_loop()
            except RuntimeError as exc:
                # No current loop yet: restore "no loop" (None) at teardown.
                if 'no current event loop' not in str(exc):
                    raise
                old_loop = None
            policy.set_event_loop(loop)
            # Bind this iteration's values as defaults. A plain closure would read
            # policy/old_loop only when the finalizer runs, so with several matching
            # markers every finalizer would restore the last iteration's old_loop.
            fixturedef.addfinalizer(
                lambda policy=policy, old_loop=old_loop: policy.set_event_loop(old_loop))
예제 #26
0
def isasyncgenerator(func):
    """Classify *func* as an async generator function or a coroutine function.

    :return: ``True`` for an async generator function, ``False`` for a
        coroutine function (as judged by ``asyncio.iscoroutinefunction``),
        and ``None`` when it is neither.
    """
    if isasyncgenfunction(func):
        return True
    if asyncio.iscoroutinefunction(func):
        return False
    # Neither kind: no verdict (falsy, but distinguishable from False).
    return None
예제 #27
0
def pytest_fixture_setup(fixturedef):
    """
    Allow fixtures to be coroutines. Run coroutine fixtures in an event loop.

    Async generator and coroutine fixture functions are replaced by wrappers
    that take the event loop from the ``loop`` fixture (looked up through the
    ``request`` fixture, which is added to the argnames when missing). An
    async generator is advanced once for its value and once more in a
    finalizer. Any fixture function that is neither is left untouched.
    """
    func = fixturedef.func
    is_async_gen = isasyncgenfunction(func)
    if not is_async_gen and not asyncio.iscoroutinefunction(func):
        return

    # 'request' is needed to reach the 'loop' fixture; hide it again from the
    # original function if it did not ask for it.
    strip_request = 'request' not in fixturedef.argnames
    if strip_request:
        fixturedef.argnames += ('request', )

    def _take_loop(kwargs):
        """Extract 'request' from kwargs and return (loop fixture value, request)."""
        request = kwargs['request']
        if strip_request:
            del kwargs['request']
        if 'loop' not in request.fixturenames:
            raise Exception(
                "Asynchronous fixtures must depend on the 'loop' fixture or "
                "be used in tests depending from it.")
        return request.getfixturevalue('loop'), request

    if is_async_gen:
        def wrapper(*args, **kwargs):
            loop, request = _take_loop(kwargs)
            # for async generators, we need to advance the generator once,
            # then advance it again in a finalizer
            gen = func(*args, **kwargs)

            def finalizer():
                try:
                    return loop.run_until_complete(gen.__anext__())
                except StopAsyncIteration:  # NOQA
                    pass

            request.addfinalizer(finalizer)
            return loop.run_until_complete(gen.__anext__())
    else:
        def wrapper(*args, **kwargs):
            loop, _ = _take_loop(kwargs)
            return loop.run_until_complete(func(*args, **kwargs))

    fixturedef.func = wrapper