Example #1
0
def trace_stack_push(trace_stack_var: ContextVar, node: Any) -> None:
    """Push an element to the top of a trace stack.

    The stack is a plain ``list`` held by ``trace_stack_var``. If the variable
    currently holds ``None``, or has no value and no default, a new empty list
    is stored first.

    Args:
        trace_stack_var: Context variable holding the stack list (or ``None``).
        node: Element appended to the top of the stack.

    Note:
        The list is mutated in place, so any context that holds the same list
        object (for example one created with ``copy_context()``) sees the push.
    """
    try:
        trace_stack = trace_stack_var.get()
    except LookupError:
        # A variable declared without a default is treated like one holding
        # None. get() is not given a default argument, because that would
        # override a default declared on the variable itself.
        trace_stack = None
    if trace_stack is None:
        trace_stack = []
        trace_stack_var.set(trace_stack)
    trace_stack.append(node)
def test_retry(kernel):
    # Fresh counter for this test; tick() decrements it on every call.
    counter = ContextVar('counter_test_retry', default=5)

    async def tick():
        # SUCCESS once the count is exhausted, a RuntimeError exactly at 3,
        # FAILURE for every other positive count.
        current = counter.get()
        counter.set(current - 1)
        if current > 0:
            if current == 3:
                raise RuntimeError('3')
            return FAILURE
        return SUCCESS

    # Starting at 5 the RuntimeError at 3 is hit, so this retry must fail
    # with an ExceptionDecorator result.
    outcome = kernel.run(retry(tick))
    assert not outcome
    assert isinstance(outcome, ExceptionDecorator)
    # The counter continues from 2 and reaches 0.
    assert kernel.run(retry(tick))

    for start, limit in ((10, 11), (100, -1)):
        counter.set(start)
        assert kernel.run(retry(tick, max_retry=limit))

    # max_retry=-2 is expected to be rejected with AssertionError.
    with pytest.raises(AssertionError):
        retry(tick, max_retry=-2)
Example #3
0
async def test_retry():

    counter = ContextVar('counter_test_retry', default=5)

    async def tick():
        # Decrement on every call: SUCCESS once exhausted, RuntimeError
        # exactly at 3, FAILURE otherwise.
        value = counter.get()
        counter.set(value - 1)
        if value <= 0:
            return SUCCESS
        if value == 3:
            raise RuntimeError('3')
        return FAILURE

    # The RuntimeError at 3 surfaces as a falsy ControlFlowException result.
    result = await retry(ignore_exception(tick))()  # counter: 5, 4, 3
    assert not result
    assert isinstance(result, ControlFlowException)

    assert await retry(tick)()  # counter: 2, 1, 0

    counter.set(10)
    assert await retry(ignore_exception(tick), max_retry=11)()

    counter.set(100)
    assert await retry(ignore_exception(tick), max_retry=-1)()

    # max_retry=-2 is expected to be rejected with AssertionError.
    with pytest.raises(AssertionError):
        retry(tick, max_retry=-2)
Example #4
0
        def _test_context(self, propagate):
            """Drive four greenlets and check every ``counts`` entry ends at 2.

            With ``propagate`` true, each greenlet runs ``self._increment``
            through ``copy_context().run`` using a copy taken here, after
            ``id`` was set to 0, and is passed ``expect=0``. Otherwise it runs
            ``self._increment`` directly and is passed ``expect=None``.
            """
            id_var = ContextVar("id", default=None)
            id_var.set(0)

            back_to_main = getcurrent().switch
            counts = {n: 0 for n in range(5)}

            def make_let(greenlet_id):
                # One context copy per greenlet, taken now (while id == 0).
                if propagate:
                    body = partial(copy_context().run, self._increment)
                else:
                    body = self._increment
                return greenlet(partial(
                    body,
                    greenlet_id=greenlet_id,
                    ctx_var=id_var,
                    callback=back_to_main,
                    counts=counts,
                    expect=0 if propagate else None,
                ))

            lets = [make_let(n) for n in range(1, 5)]

            for _ in range(2):
                counts[id_var.get()] += 1
                for let in lets:
                    let.switch()

            self.assertEqual(set(counts.values()), {2})
Example #5
0
    def _test_context(self, propagate_by):
        """Check how a ContextVar value reaches child greenlets.

        ``propagate_by`` selects how the main greenlet's context (where ``id``
        is 0) is handed to four child greenlets:

        * ``"run"``   - each child runs ``self._increment`` inside
          ``copy_context().run``;
        * ``"set"``   - each child's ``gr_context`` is set to a fresh
          ``copy_context()``;
        * ``"share"`` - every child's ``gr_context`` is the main greenlet's own
          ``gr_context`` object;
        * anything else - no propagation.

        ``expect`` is passed to ``_increment`` (defined elsewhere). It is
        ``i - 1`` for "share", ``0`` for "set"/"run", and ``None`` otherwise.
        Presumably it is the ``id`` value the child should read; confirm in
        ``_increment``.
        """
        id_var = ContextVar("id", default=None)
        # Value visible in the main greenlet's context.
        id_var.set(0)

        # Handed to the children; presumably they switch back here with it.
        callback = getcurrent().switch
        # Keyed by the id value observed; the main greenlet bumps
        # counts[id_var.get()] below.
        counts = dict((i, 0) for i in range(5))

        lets = [
            greenlet(partial(
                partial(
                    copy_context().run,
                    self._increment
                ) if propagate_by == "run" else self._increment,
                greenlet_id=i,
                ctx_var=id_var,
                callback=callback,
                counts=counts,
                expect=(
                    i - 1 if propagate_by == "share" else
                    0 if propagate_by in ("set", "run") else None
                )
            ))
            for i in range(1, 5)
        ]

        # "set" and "share" install the context before the first switch.
        for let in lets:
            if propagate_by == "set":
                let.gr_context = copy_context()
            elif propagate_by == "share":
                let.gr_context = getcurrent().gr_context

        # Two rounds: the main greenlet records what it sees, then runs
        # every child once.
        for i in range(2):
            counts[id_var.get()] += 1
            for let in lets:
                let.switch()

        if propagate_by == "run":
            # Must leave each context.run() in reverse order of entry
            for let in reversed(lets):
                let.switch()
        else:
            # No context.run(), so fine to exit in any order.
            for let in lets:
                let.switch()

        for let in lets:
            self.assertTrue(let.dead)
            # When using run(), we leave the run() as the greenlet dies,
            # and there's no context "underneath". When not using run(),
            # gr_context still reflects the context the greenlet was
            # running in.
            self.assertEqual(let.gr_context is None, propagate_by == "run")

        if propagate_by == "share":
            # One shared context: writes made by the children are seen by
            # the main greenlet too, so the observed id values drift.
            self.assertEqual(counts, {0: 1, 1: 1, 2: 1, 3: 1, 4: 6})
        else:
            self.assertEqual(set(counts.values()), set([2]))
Example #6
0
    async def test_to_thread_contextvars(self):
        test_ctx = ContextVar('test_ctx')

        def get_ctx():
            return test_ctx.get()

        test_ctx.set('parrot')
        result = await asyncio.to_thread(get_ctx)

        self.assertEqual(result, 'parrot')
Example #7
0
class ControlMixin:
    """Mixin tracking a per-instance collection of disabled command names.

    The names are stored as a tuple in a per-instance ``ContextVar``, so
    :meth:`disable` / :meth:`enable` affect the current context only.
    Name comparisons are case-insensitive.

    Keyword arguments consumed by ``__init__`` (all others are forwarded to
    ``super().__init__``):
        disable: ``True`` (disable everything via ``_ALL``), ``False``
            (nothing), or an iterable of command names. Wins over ``enable``.
        enable: bool, default ``True``; ``enable=False`` disables everything.
    """

    def __init__(self, *args, **kwargs):
        self.__disable = ContextVar(str(id(self)), default=())
        # Always consume ``enable`` so it is never forwarded to the next
        # class in the MRO, even when ``disable`` takes precedence.
        enable = kwargs.pop("enable", True)
        if "disable" in kwargs:
            self._set_disable(kwargs.pop("disable"))
        else:
            self._set_disable(not enable)
        super().__init__(*args, **kwargs)

    @property
    def _disable(self):
        """A new list copy of the currently disabled names."""
        return list(self.__disable.get(()))

    def _set_disable(self, value):
        """Store ``value`` as the disabled names.

        ``True`` becomes ``[_ALL]`` and ``False`` becomes ``[]``. Anything else
        is converted with ``tuple()``.
        """
        if value is True:
            value = [
                _ALL,
            ]
        elif value is False:
            value = []
        self.__disable.set(tuple(value))

    def is_disable(self, *cmds: str) -> bool:
        """With no ``cmds``, return True if anything is disabled. Otherwise
        return True if any of ``cmds`` is disabled (case-insensitive)."""
        _disable = self._disable
        if not cmds:
            return bool(_disable)
        # Lower-case the disabled names once, not once per queried command.
        disabled = {c.lower() for c in _disable}
        return any(cmd.lower() in disabled for cmd in cmds)

    def is_enable(self, *cmds):
        """Inverse of :meth:`is_disable`."""
        return not self.is_disable(*cmds)

    def disable(self, *cmds: str):
        """Disable ``cmds``; with no arguments disable everything (``_ALL``)."""
        if not cmds:
            self._set_disable([_ALL])
            return
        self._set_disable(self._disable + list(cmds))

    def enable(self, *cmds: str):
        """Re-enable ``cmds``; with no arguments re-enable everything.

        Matching is case-insensitive, like :meth:`is_disable`. Every
        occurrence is removed, so a name disabled twice becomes enabled again.
        """
        if not cmds:
            self._set_disable([])
            return
        wanted = {cmd.lower() for cmd in cmds}
        remaining = [
            c for c in self._disable
            if not (isinstance(c, str) and c.lower() in wanted)
        ]
        self._set_disable(remaining)
Example #8
0
    def test_contextvar_propagation_sync(
            self, anyio_backend_name: str,
            anyio_backend_options: Dict[str, Any]) -> None:
        """A value set before starting the portal is visible to portal.call()."""
        too_old = sys.version_info < (3, 7)
        if too_old and anyio_backend_name == "asyncio":
            pytest.skip("Asyncio does not propagate context before Python 3.7")

        var = ContextVar("var", default=1)
        var.set(6)
        portal_cm = start_blocking_portal(anyio_backend_name,
                                          anyio_backend_options)
        with portal_cm as portal:
            seen = portal.call(var.get)

        assert seen == 6
Example #9
0
async def test_retry_until_failed():
    counter = ContextVar('counter_test_retry_until_failed', default=5)

    async def tick():
        # Counter drops by one per call: SUCCESS when exhausted, RuntimeError
        # exactly at 3, FAILURE for other positive values.
        current = counter.get()
        counter.set(current - 1)
        if current > 0:
            if current == 3:
                raise RuntimeError('3')
            return FAILURE
        return SUCCESS

    counter.set(100)
    node = retry_until_failed(tick)
    assert await node()
Example #10
0
async def test_retry_until_success():
    counter = ContextVar('counter_test_retry_until_success', default=5)

    async def tick():
        # Counter drops by one per call: SUCCESS when exhausted, RuntimeError
        # exactly at 3, FAILURE for other positive values.
        current = counter.get()
        counter.set(current - 1)
        if current > 0:
            if current == 3:
                raise RuntimeError('3')
            return FAILURE
        return SUCCESS

    counter.set(100)
    # ignore_exception presumably keeps the RuntimeError at 3 from aborting
    # the retry loop.
    node = retry_until_success(ignore_exception(tick))
    assert await node()
def test_retry_until_success(kernel):
    counter = ContextVar('counter_test_retry_until_success', default=5)

    async def tick():
        # Counter drops by one per call: SUCCESS when exhausted, RuntimeError
        # exactly at 3, FAILURE for other positive values.
        current = counter.get()
        counter.set(current - 1)
        if current > 0:
            if current == 3:
                raise RuntimeError('3')
            return FAILURE
        return SUCCESS

    counter.set(100)
    node = retry_until_success(tick)
    assert kernel.run(node)
Example #12
0
    def test_contextvars(self):
        """A value set in an awaited coroutine is still visible to the caller
        after it yields to the event loop."""
        from contextvars import ContextVar
        answer_var = ContextVar('var')
        answer_var.set(0)

        async def store_answer():
            answer_var.set(42)

        async def read_back():
            # store_answer() is awaited directly rather than scheduled as a
            # task, so it runs in the same context as read_back().
            await store_answer()
            await asyncio.sleep(0.01)
            return answer_var.get()

        outcome = self.loop.run_until_complete(read_back())
        self.assertEqual(outcome, 42)
Example #13
0
def get_counter(variable: ContextVar, fn: Callable):
    """Yield a ``Counter`` installed in ``variable`` while the generator is suspended.

    The counter gets a DEBUG-level logger named
    ``<fn.__module__>.<fn.__name__>``. ``variable`` is restored to its previous
    value when the generator finishes, is closed, or has an exception thrown
    into it. (Presumably wrapped as a context manager or fixture elsewhere.)
    """
    logger = logging.getLogger(fn.__module__).getChild(fn.__name__)
    logger.setLevel(logging.DEBUG)
    counter = Counter(logger=logger)
    token = variable.set(counter)
    try:
        yield counter
    finally:
        # Without ``finally``, an exception raised in the consuming block
        # (thrown into this generator) would skip the reset and leave
        # ``variable`` pointing at this counter.
        variable.reset(token)
class ContextVarsRuntimeContext(_RuntimeContext):
    """Runtime-context implementation that keeps the current ``Context`` in a
    ``ContextVar``.

    This is the preferred implementation wherever the ``contextvars`` module
    is available.
    """

    _CONTEXT_KEY = "current_context"

    def __init__(self) -> None:
        # get_current() returns this Context() until something is attached.
        initial = Context()
        self._current_context = ContextVar(self._CONTEXT_KEY, default=initial)

    def attach(self, context: Context) -> object:
        """Make ``context`` current.

        Args:
            context: The Context to make current.

        Returns:
            An opaque token that :meth:`detach` uses to restore the previous
            ``Context``.
        """
        token = self._current_context.set(context)
        return token

    def get_current(self) -> Context:
        """Return the ``Context`` that is current in this execution context."""
        current = self._current_context.get()
        return current

    def detach(self, token: object) -> None:
        """Restore the ``Context`` that was current before the matching attach.

        Args:
            token: Value returned by :meth:`attach`.
        """
        self._current_context.reset(token)  # type: ignore
Example #15
0
    def test_make_target(self):
        """``make_target`` passes contextvars to a thread only when
        ``_config.PASS_CONTEXTVARS`` is enabled."""
        from contextvars import ContextVar

        foo = ContextVar("foo")
        token = foo.set("foo")
        saved_flag = _config.PASS_CONTEXTVARS
        errors = []

        def checked(target):
            # A failure raised in a worker thread (assert, pytest.raises)
            # never reaches this thread on its own, so record it and re-raise
            # it after join(); otherwise the test could never fail.
            def wrapper():
                try:
                    target()
                except BaseException as exc:
                    errors.append(exc)
            return wrapper

        def run_in_thread(target):
            # make_target() is called here, in the test thread, so any context
            # capture happens before the thread starts.
            thread = Thread(target=make_target(checked(target)))
            thread.start()
            thread.join()
            if errors:
                raise errors.pop()

        try:
            # PASS_CONTEXTVARS = False (the default): foo is not visible.
            assert _config.PASS_CONTEXTVARS is False

            def target_without_context():
                with pytest.raises(LookupError):
                    foo.get()

            run_in_thread(target_without_context)

            # PASS_CONTEXTVARS = True: the thread sees the caller's value.
            _config.PASS_CONTEXTVARS = True

            def target_with_context():
                assert foo.get() == "foo"

            run_in_thread(target_with_context)
        finally:
            # Restore global config and the ContextVar even if a check failed.
            _config.PASS_CONTEXTVARS = saved_flag
            foo.reset(token)
Example #16
0
class WeakContextVar:
    """Context-local resource holder that keeps only a weak reference.

    Instances are shared per ``name``: ``WeakContextVar(name)`` returns the
    same object, backed by the same ``ContextVar``, every time. ``get()``
    returns ``None`` when nothing is set or the referent has been collected.
    """

    # name -> shared instance (strong references; instances live forever).
    _instances = {}

    def __new__(cls, name):

        inst = cls._instances.get(name)
        if inst is None:
            inst = cls._instances[name] = super().__new__(cls)

        return inst

    def __init__(self, name):

        # __init__ runs again every time __new__ returns the shared instance.
        # Create the ContextVar only once; otherwise each lookup would replace
        # it and silently drop values that were already set.
        if "_context_var" not in self.__dict__:
            self._context_var = ContextVar(name, default=None)

    def get(self):
        """Return the referenced object, or ``None`` if unset or collected."""

        ref = self._context_var.get()

        return None if ref is None else ref()

    def set(self, value):
        """Store a weak reference to ``value``; ``None`` clears the variable.

        Returns:
            The ``contextvars.Token`` from the underlying ``ContextVar.set``.

        Raises:
            TypeError: if ``value`` does not support weak references.
        """

        ref = None if value is None else weakref.ref(value)
        return self._context_var.set(ref)
Example #17
0
def _tristate_armed(context_var: ContextVar, enabled=True):
    """Assumes "enabled" if `enabled` flag is None."""
    resetter = context_var.set(enabled if enabled is None else bool(enabled))
    try:
        yield
    finally:
        context_var.reset(resetter)
Example #18
0
class Urls:
    """Context-local URL building and matching over a routing ``Map``.

    The adapter in use lives in the ``urls`` ContextVar. :meth:`bind` and
    :meth:`bind_to_environ` return context managers that install the adapter
    returned by ``Map.bind`` / ``Map.bind_to_environ`` and restore the previous
    one on exit.

    When the active adapter has ``url_scheme == 'file'`` and
    ``server_name == '.'``, URLs are built and matched as paths relative to
    the directory of the current ``path_info``. Directory URLs map to
    ``index.html``.
    """

    def __init__(self, rules):
        self.map = Map(rules)
        # Fallback adapter used when nothing has been bound in this context.
        self.urls = ContextVar("urls", default=self.map.bind(''))

    def iter_rules(self):
        return self.map.iter_rules()

    @contextmanager
    def _bind(self, urls):
        # Install ``urls`` only for the duration of the with-block.
        token = self.urls.set(urls)
        try:
            yield
        finally:
            self.urls.reset(token)

    def bind(self, *args, **kwargs):
        return self._bind(self.map.bind(*args, **kwargs))

    def bind_to_environ(self, *args, **kwargs):
        return self._bind(self.map.bind_to_environ(*args, **kwargs))

    def build(self,
              endpoint,
              values=None,
              method=None,
              force_external=False,
              append_unknown=True):
        """Build the URL for ``endpoint`` with this context's adapter.

        In relative-file mode the result is relative to the directory of
        ``path_info`` and a trailing ``/`` becomes ``/index.html``.
        ``force_external`` is not allowed in that mode.
        """
        # URL paths are always '/'-separated; os.path would produce '\\'
        # separators on Windows and break the generated links.
        import posixpath

        urls = self.urls.get()
        path = urls.build(endpoint, values, method, force_external,
                          append_unknown)
        if urls.url_scheme == 'file' and urls.server_name == '.':
            assert not force_external
            if path.endswith('/'):
                path = path + 'index.html'
            return posixpath.relpath(path, posixpath.dirname(urls.path_info))
        return path

    def match(self, path=None, return_rule=False):
        """Match ``path`` (or the bound request path when ``None``).

        In relative-file mode ``path`` is resolved against the directory of
        ``path_info`` and a trailing ``/index.html`` is stripped. Otherwise the
        adapter's ``script_name`` prefix is removed first.
        """
        import posixpath

        urls = self.urls.get()
        if path is None:
            return urls.match(return_rule=return_rule)
        if urls.url_scheme == 'file' and urls.server_name == '.':
            path = posixpath.normpath(
                posixpath.join(posixpath.dirname(urls.path_info), path))
            if path.endswith("/index.html"):
                path = path[:-len("/index.html")]
        else:
            script_name = urls.script_name.rstrip("/")
            assert path.startswith(script_name)
            path = path[len(script_name):]

        return urls.match(path, return_rule=return_rule)

    def dispatch(self, *args, **kwargs):
        return self.urls.get().dispatch(*args, **kwargs)

    def current_url(self):
        """Dispatch to :meth:`build`, i.e. (assuming werkzeug-style dispatch)
        rebuild the URL of the currently matched endpoint."""
        return self.dispatch(lambda e, v: self.build(e, v))
Example #19
0
class Environment(object):
    """Re-entrant, context-local nesting counter.

    Each ``with env:`` increments ``level`` in the current context and the
    matching exit decrements it. The instance is truthy while ``level != 0``.
    """
    __slots__ = ["level"]

    def __init__(self):
        self.level = ContextVar("%s_level" % self.__class__.__name__,
                                default=0)

    def __repr__(self):
        return "<%s, level %s>" % (self.__class__.__name__, self.level.get())

    def __bool__(self):
        return self.level.get() != 0

    def __enter__(self):
        self.level.set(self.level.get() + 1)
        # Return self so ``with env as e:`` binds the environment, not None.
        return self

    def __exit__(self, exc_type=None, exc_val=None, exc_tb=None):
        # Returns None, so exceptions raised in the block still propagate.
        self.level.set(self.level.get() - 1)
Example #20
0
class RenderVar:
    """Lazily initialised, context-local value exposed as a descriptor.

    Two ways to build one: ``RenderVar(name, factory, **kwds)``, or the
    decorator form ``@RenderVar(name, **kwds)`` applied to the factory. Each
    instance is registered in ``_REGISTRY`` under ``name``. A non-callable
    ``factory`` is used as a constant value.
    """

    def __new__(cls, name, *args, **kwds):
        if args:
            return super().__new__(cls)
        # Decorator form: build the real instance once the factory arrives.
        return lambda factory: cls(name, factory, **kwds)

    def __init__(self, name, factory, **kwds):
        self._factory = (partial(factory, **kwds) if callable(factory)
                         else (lambda: factory))
        self._var = ContextVar(name, default=_MISSING)
        _REGISTRY[name] = self

    def __get__(self, instance=None, owner=None):
        value = self._var.get()
        if value is _MISSING:
            # First access in this context: create the value and cache it.
            value = self._factory()
            self._var.set(value)
        return value
    class Slot(base_context.BaseRuntimeContext.Slot):
        """Named runtime-context slot backed by a ``ContextVar``.

        ``default`` is passed through ``base_context.wrap_callable`` and the
        result is called lazily: on the first ``get()`` in a context that has
        no value, and on every ``clear()``.
        """

        def __init__(self, name: str, default: "object"):
            # pylint: disable=super-init-not-called
            self.name = name
            self.contextvar = ContextVar(name)  # type: ContextVar[object]
            self.default = base_context.wrap_callable(
                default)  # type: typing.Callable[..., object]

        def clear(self) -> None:
            fresh = self.default()
            self.contextvar.set(fresh)

        def get(self) -> "object":
            try:
                return self.contextvar.get()
            except LookupError:
                pass
            # Nothing stored in this context yet: create, store, return.
            fresh = self.default()
            self.set(fresh)
            return fresh

        def set(self, value: "object") -> None:
            self.contextvar.set(value)
Example #22
0
class ContextService:
    """Key/value store held in a per-instance ``ContextVar``, with undo.

    Every write stores a new dict (copy, then set), so a dict already stored
    is never mutated by this class. Each write's ``Token`` is pushed onto an
    undo stack that :meth:`undo` pops.

    NOTE(review): the undo stack is per instance, not per context.
    ``ContextVar.reset`` raises ``ValueError`` for a token created in a
    different context or one that was already used.
    """

    def __init__(self):
        self._context = ContextVar(f"{self}", default={})
        self._undo_tokens = deque()

    def set_context_variable(self, var_name, var_value):
        """Store ``var_name = var_value`` in a new copy of the current dict."""
        # Copy first so dicts seen by other contexts or held by earlier tokens
        # are left untouched.
        context = self._get_context().copy()
        context[var_name] = var_value
        self._set_context(context)

    def get_context_variable(self, var_name):
        """Return the value stored under ``var_name``, or ``None`` if absent."""
        return self._get_context().get(var_name)

    def undo(self):
        """Revert the most recent write; does nothing if no write is left."""
        undo_token = self._pop_undo_token()
        if undo_token is None:
            return

        self._context.reset(undo_token)

    async def run_in_context(self, func, *args, **kwargs):
        """Call ``func(*args, **kwargs)`` inside a copy of the current context
        and await the result.

        Raises:
            BbpypValueError: if ``func`` is not callable.

        NOTE(review): ``Context.run`` covers only the synchronous call. For a
        coroutine function that call just creates the coroutine; its body runs
        when awaited below, in the caller's context, so its ContextVar writes
        are not isolated in the copy. Confirm this is intended.
        """
        if not callable(func):
            raise BbpypValueError("func", func, "must be callable")

        copied_context = copy_context()
        await copied_context.run(func, *args, **kwargs)

    def get_context_key_value_pairs(self):
        """Return this service's key/value dict for the current context."""
        # Read our own ContextVar. Scanning copy_context().values() for "the
        # first dict" is unreliable: another variable (e.g. a second
        # ContextService) may hold a dict, and the context may hold no values.
        return self._get_context()

    def set_context_key_value_pairs(self, **kwargs):
        """Store each keyword argument as a separate, undoable write."""
        for key, value in kwargs.items():
            self.set_context_variable(key, value)

    def _get_context(self):
        # The explicit {} overrides the ContextVar default when unset.
        return self._context.get({})

    def _set_context(self, context_value):
        undo_token = self._context.set(context_value)
        self._push_undo_token(undo_token)

    def _pop_undo_token(self):
        return self._undo_tokens.pop() if self._undo_tokens else None

    def _push_undo_token(self, undo_token):
        self._undo_tokens.append(undo_token)
Example #23
0
class Ctx(Generic[T]):
    """Thin typed wrapper around a single ``ContextVar``."""

    current_ctx: ContextVar[T]

    def __init__(self, name: str) -> None:
        self.current_ctx = ContextVar(name)

    def get(self, default: Union[D, T] = None) -> Union[T, D]:
        """Return the current value, or ``default`` (``None``) when unset."""
        return self.current_ctx.get(default)

    def set(self, value: T):
        """Set the value and return the ``Token`` for :meth:`reset`."""
        return self.current_ctx.set(value)

    def reset(self, token: Token):
        """Restore the value that was current before ``token``'s ``set``."""
        return self.current_ctx.reset(token)

    @contextmanager
    def use(self, value: T):
        """Make ``value`` current for the duration of a ``with`` block.

        The previous value is restored in ``finally``, so an exception raised
        inside the block does not leave ``value`` in place.
        """
        token = self.set(value)
        try:
            yield
        finally:
            self.reset(token)
class ContextVarsRuntimeContext(RuntimeContext):
    """``RuntimeContext`` that keeps the current ``Context`` in a ``ContextVar``.

    This is the preferred implementation wherever the ``contextvars`` module
    is available.
    """

    _CONTEXT_KEY = "current_context"

    def __init__(self) -> None:
        # get_current() returns this Context() until something is attached.
        initial = Context()
        self._current_context = ContextVar(self._CONTEXT_KEY, default=initial)

    def attach(self, context: Context) -> object:
        """Make ``context`` current and return the token :meth:`detach` needs.

        See `opentelemetry.context.RuntimeContext.attach`.
        """
        token = self._current_context.set(context)
        return token

    def get_current(self) -> Context:
        """Return the current ``Context``.

        See `opentelemetry.context.RuntimeContext.get_current`.
        """
        current = self._current_context.get()
        return current

    def detach(self, token: object) -> None:
        """Restore the ``Context`` that was current before the matching attach.

        See `opentelemetry.context.RuntimeContext.detach`.
        """
        self._current_context.reset(token)  # type: ignore
class ContextManager:
    """Context-local value with explicit ``enter``/``leave`` scoping.

    The ``ContextVar`` holds a ``[token, value]`` pair. ``token`` is the one
    returned when that pair was set, so :meth:`leave` can restore the previous
    pair. The variable is named ``"<namespace>.<name>"``, or just ``name``
    when ``namespace`` is empty.

    Typical use: ``with cm.enter(value): ...``. ``enter`` sets the value and
    returns ``self``; ``__exit__`` calls :meth:`leave`.
    """

    def __init__(self, name, namespace='', default=None):
        full_name = '.'.join(([namespace] if namespace else []) + [name])
        self.context = ContextVar(full_name, default=[None, default])
        self.default = default

    @property
    def name(self):
        return self.context.name

    def enter(self, value):
        """Make ``value`` current and return ``self`` (for use in ``with``)."""
        values = [None, value]
        token = self.context.set(values)
        # Keep the token inside the pair so leave() can find it later.
        values[0] = token
        return self

    def leave(self):
        """Restore the pair that was current before the matching ``enter``.

        Raises ``TypeError`` if nothing was entered, because the default
        pair's token slot is ``None``.
        """
        token, _value = self.context.get()
        self.context.reset(token)

    def get(self):
        """Return the current value (the default when nothing was entered)."""
        return self.context.get()[-1]

    __call__ = get

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.leave()

    def __getitem__(self, item):
        _token, value = self.context.get()
        return value[item]

    def __getattr__(self, item):
        """Expose items of the current value as attributes.

        Missing items raise ``AttributeError`` (not ``KeyError``/``TypeError``)
        so ``hasattr`` and ``getattr(obj, name, default)`` work. Dunder names
        and ``context`` itself are never forwarded. Without that guard,
        protocol lookups such as ``__setstate__`` on a half-built instance
        (e.g. during ``copy.copy``) recurse forever through ``self.context``.
        """
        if item == 'context' or (item.startswith('__') and item.endswith('__')):
            raise AttributeError(item)
        try:
            return self[item]
        except (KeyError, IndexError, TypeError) as exc:
            raise AttributeError(item) from exc
Example #26
0
class StackVar:
    """Context-local stack stored as nested ``(top, rest)`` pairs.

    The empty stack is ``(None, None)``; pushing ``x`` stores
    ``(x, previous_pair)``. Each push and pop stores a new tuple, so a copied
    context keeps its own view of the stack.
    """

    def __init__(self, name):
        """Create the variable and set the empty stack in the current context."""
        empty = (None, None)
        self.var = ContextVar(name, default=empty)
        self.var.set(empty)

    def push(self, x):
        """Place ``x`` on top of the stack."""
        below = self.var.get()
        self.var.set((x, below))

    def pop(self):
        """Remove and return the top element (asserts the stack is non-empty)."""
        head, rest = self.var.get()
        assert rest is not None
        self.var.set(rest)
        return head

    def top(self):
        """Return the top element, or ``None`` for an empty stack."""
        head = self.var.get()[0]
        return head
Example #27
0
# Scratch notebook cells (``#%%`` separators) experimenting with contextvars.
# NOTE(review): relies on names not defined in this snippet: ``valset`` (used
# via ``.columns``/``.iloc``/``.isna`` - looks like a pandas DataFrame) and
# ``ContextVar`` (only ``copy_context`` is imported here). Confirm both are
# defined earlier in the session.
from contextvars import copy_context

t = []
for i in range(len(valset.columns)) : 
    # NOTE(review): ``t`` starts as a list but is overwritten with a single
    # column label, so after the loop it holds only the LAST column that
    # contains NaN (or stays ``[]``). If the goal is to collect every such
    # column, this should probably be ``t.append(valset.columns[i])``.
    if valset.iloc[:, i].isna().any() == True :
        t = valset.columns[i]

    # NOTE(review): annotated ContextVar[str] but the default is the int 42.
    var : ContextVar[str] = ContextVar(valset.columns[i], default=42)
    # The annotation ``t`` is a value, not a type. Annotations do not change
    # what is assigned, so ``ctx`` is simply a copy of the current context.
    ctx: t = copy_context()
    print(type(t))

#%%
t = None
# ``t`` was just set to None, so this branch always runs.
if t == None :
    var = ContextVar('None')
token = var.set('new value')
# code that uses 'var'; var.get() returns 'new value'.

# Bare expression: only useful as notebook cell output.
token.var.name

#%%
tcname = 1
#%%
# NOTE(review): annotated ContextVar[int] (default 42) but a str is set next.
var: ContextVar[int] = ContextVar('var', default=42)
token = var.set('new value')
#%%
# Value ``var`` held before the set() above, or ``Token.MISSING`` if it had
# none (the variable's default does not count as a set value).
token.old_value
#%%
from contextvars import copy_context

# As above, the annotation ``tcname`` has no effect on the assigned value.
ctx: tcname = copy_context()
Example #28
0
class Database:
    """Async database facade over a backend selected by URL scheme.

    Queries go through :meth:`connection`. It returns one ``Connection`` per
    context (cached in ``_connection_context``), or a single shared global
    connection while one exists. With ``force_rollback``, :meth:`connect`
    enters a ``force_rollback=True`` transaction on that shared connection and
    :meth:`disconnect` exits it. Presumably this rolls back all work done while
    connected (e.g. for tests); confirm in ``Transaction``.
    """

    # URL scheme -> "module.path:ClassName" of the backend class, imported with
    # import_from_string() in __init__.
    SUPPORTED_BACKENDS = {
        "postgresql": "databases.backends.postgres:PostgresBackend",
        "postgresql+aiopg": "databases.backends.aiopg:AiopgBackend",
        "postgres": "databases.backends.postgres:PostgresBackend",
        "mysql": "databases.backends.mysql:MySQLBackend",
        "sqlite": "databases.backends.sqlite:SQLiteBackend",
    }

    def __init__(
        self,
        url: typing.Union[str, "DatabaseURL"],
        *,
        force_rollback: bool = False,
        **options: typing.Any,
    ):
        """Resolve and construct the backend for ``url`` (not yet connected).

        Args:
            url: Database URL, as a string or ``DatabaseURL``.
            force_rollback: Use one shared connection inside a rollback-only
                transaction between connect() and disconnect().
            **options: Passed unchanged to the backend constructor.

        Raises:
            KeyError: If the URL scheme is not in ``SUPPORTED_BACKENDS``.
        """
        self.url = DatabaseURL(url)
        self.options = options
        self.is_connected = False

        self._force_rollback = force_rollback

        backend_str = self.SUPPORTED_BACKENDS[self.url.scheme]
        backend_cls = import_from_string(backend_str)
        assert issubclass(backend_cls, DatabaseBackend)
        self._backend = backend_cls(self.url, **self.options)

        # Connections are stored as task-local state.
        self._connection_context = ContextVar(
            "connection_context")  # type: ContextVar

        # When `force_rollback=True` is used, we use a single global
        # connection, within a transaction that always rolls back.
        self._global_connection = None  # type: typing.Optional[Connection]
        self._global_transaction = None  # type: typing.Optional[Transaction]

    async def connect(self) -> None:
        """
        Establish the connection pool.

        When ``force_rollback`` is enabled, also create the shared global
        connection and enter its ``force_rollback=True`` transaction.
        """
        assert not self.is_connected, "Already connected."

        await self._backend.connect()
        logger.info("Connected to database %s",
                    self.url.obscure_password,
                    extra=CONNECT_EXTRA)
        self.is_connected = True

        if self._force_rollback:
            assert self._global_connection is None
            assert self._global_transaction is None

            self._global_connection = Connection(self._backend)
            self._global_transaction = self._global_connection.transaction(
                force_rollback=True)

            await self._global_transaction.__aenter__()

    async def disconnect(self) -> None:
        """
        Close all connections in the connection pool.

        When ``force_rollback`` is enabled, first exit the global transaction
        and drop the global connection, so a later connect() recreates them.
        """
        assert self.is_connected, "Already disconnected."

        if self._force_rollback:
            assert self._global_connection is not None
            assert self._global_transaction is not None

            await self._global_transaction.__aexit__()

            self._global_transaction = None
            self._global_connection = None

        await self._backend.disconnect()
        logger.info(
            "Disconnected from database %s",
            self.url.obscure_password,
            extra=DISCONNECT_EXTRA,
        )
        self.is_connected = False

    async def __aenter__(self) -> "Database":
        await self.connect()
        return self

    async def __aexit__(
        self,
        exc_type: typing.Type[BaseException] = None,
        exc_value: BaseException = None,
        traceback: TracebackType = None,
    ) -> None:
        await self.disconnect()

    async def fetch_all(self,
                        query: typing.Union[ClauseElement, str],
                        values: dict = None) -> typing.List[typing.Mapping]:
        """Delegate to ``Connection.fetch_all`` on :meth:`connection`."""
        async with self.connection() as connection:
            return await connection.fetch_all(query, values)

    async def fetch_one(
            self,
            query: typing.Union[ClauseElement, str],
            values: dict = None) -> typing.Optional[typing.Mapping]:
        """Delegate to ``Connection.fetch_one`` on :meth:`connection`."""
        async with self.connection() as connection:
            return await connection.fetch_one(query, values)

    async def fetch_val(
        self,
        query: typing.Union[ClauseElement, str],
        values: dict = None,
        column: typing.Any = 0,
    ) -> typing.Any:
        """Delegate to ``Connection.fetch_val`` on :meth:`connection`."""
        async with self.connection() as connection:
            return await connection.fetch_val(query, values, column=column)

    async def execute(self,
                      query: typing.Union[ClauseElement, str],
                      values: dict = None) -> typing.Any:
        """Delegate to ``Connection.execute`` on :meth:`connection`."""
        async with self.connection() as connection:
            return await connection.execute(query, values)

    async def execute_many(self, query: typing.Union[ClauseElement, str],
                           values: list) -> None:
        """Delegate to ``Connection.execute_many`` on :meth:`connection`."""
        # NOTE(review): annotated -> None but returns the connection's result.
        async with self.connection() as connection:
            return await connection.execute_many(query, values)

    async def iterate(self,
                      query: typing.Union[ClauseElement, str],
                      values: dict = None
                      ) -> typing.AsyncGenerator[typing.Mapping, None]:
        """Yield rows of ``query`` one at a time.

        The connection's ``async with`` block stays open until iteration ends.
        """
        async with self.connection() as connection:
            async for record in connection.iterate(query, values):
                yield record

    def connection(self) -> "Connection":
        """Return the shared global connection if one exists; otherwise the
        per-context connection, created and cached on first use."""
        if self._global_connection is not None:
            return self._global_connection

        try:
            return self._connection_context.get()
        except LookupError:
            connection = Connection(self._backend)
            self._connection_context.set(connection)
            return connection

    def transaction(self,
                    *,
                    force_rollback: bool = False,
                    **kwargs: typing.Any) -> "Transaction":
        """Return a ``Transaction`` built from the bound ``self.connection``.

        The method itself is passed, not a connection; presumably so the
        transaction resolves the connection lazily. Confirm in ``Transaction``.
        """
        return Transaction(self.connection,
                           force_rollback=force_rollback,
                           **kwargs)

    @contextlib.contextmanager
    def force_rollback(self) -> typing.Iterator[None]:
        """Temporarily enable ``_force_rollback``; the old value is restored on exit.

        NOTE(review): the flag is only read by connect()/disconnect(), so this
        presumably has to wrap them. Toggling it while already connected
        without a global connection would make disconnect() fail its
        ``_global_connection is not None`` assertion.
        """
        initial = self._force_rollback
        self._force_rollback = True
        try:
            yield
        finally:
            self._force_rollback = initial
Example #29
0
class Database:
    """Async database facade over a backend selected by URL dialect.

    Queries go through :meth:`connection`, which returns one ``Connection``
    per context (cached in ``_connection_context``) or, with
    ``force_rollback``, a single global connection created in ``__init__``.
    connect()/disconnect() enter and exit that connection's
    ``force_rollback=True`` transaction.
    """

    # URL dialect -> "module.path:ClassName" of the backend class, imported
    # with import_from_string() in __init__.
    SUPPORTED_BACKENDS = {
        "postgresql": "databases.backends.postgres:PostgresBackend",
        "mysql": "databases.backends.mysql:MySQLBackend",
        "sqlite": "databases.backends.sqlite:SQLiteBackend",
    }

    def __init__(self,
                 url: typing.Union[str, "DatabaseURL"],
                 *,
                 force_rollback: bool = False):
        """Resolve and construct the backend for ``url`` (not yet connected).

        Raises:
            KeyError: If the URL dialect is not in ``SUPPORTED_BACKENDS``.
        """
        self._url = DatabaseURL(url)
        self._force_rollback = force_rollback
        self.is_connected = False

        backend_str = self.SUPPORTED_BACKENDS[self._url.dialect]
        backend_cls = import_from_string(backend_str)
        assert issubclass(backend_cls, DatabaseBackend)
        self._backend = backend_cls(self._url)

        # Connections are stored as task-local state.
        self._connection_context = ContextVar(
            "connection_context")  # type: ContextVar

        # When `force_rollback=True` is used, we use a single global
        # connection, within a transaction that always rolls back.
        self._global_connection = None  # type: typing.Optional[Connection]
        self._global_transaction = None  # type: typing.Optional[Transaction]

        # NOTE(review): created once here and never cleared by disconnect(),
        # so a connect() after disconnect() re-enters the same Transaction
        # object. Confirm that Transaction supports being entered again.
        if self._force_rollback:
            self._global_connection = Connection(self._backend)
            self._global_transaction = self._global_connection.transaction(
                force_rollback=True)

    async def connect(self) -> None:
        """
        Establish the connection pool.

        With ``force_rollback``, also enter the global transaction.
        """
        assert not self.is_connected, "Already connected."

        await self._backend.connect()
        self.is_connected = True

        if self._force_rollback:
            assert self._global_transaction is not None
            await self._global_transaction.__aenter__()

    async def disconnect(self) -> None:
        """
        Close all connections in the connection pool.

        With ``force_rollback``, first exit the global transaction.
        """
        assert self.is_connected, "Already disconnected."

        if self._force_rollback:
            assert self._global_transaction is not None
            await self._global_transaction.__aexit__()

        await self._backend.disconnect()
        self.is_connected = False

    async def __aenter__(self) -> "Database":
        await self.connect()
        return self

    async def __aexit__(
        self,
        exc_type: typing.Type[BaseException] = None,
        exc_value: BaseException = None,
        traceback: TracebackType = None,
    ) -> None:
        await self.disconnect()

    async def fetch_all(self, query: ClauseElement) -> typing.List[RowProxy]:
        """Delegate to ``Connection.fetch_all`` on :meth:`connection`."""
        async with self.connection() as connection:
            return await connection.fetch_all(query=query)

    async def fetch_one(self, query: ClauseElement) -> RowProxy:
        """Delegate to ``Connection.fetch_one`` on :meth:`connection`."""
        async with self.connection() as connection:
            return await connection.fetch_one(query=query)

    async def execute(self,
                      query: ClauseElement,
                      values: dict = None) -> typing.Any:
        """Delegate to ``Connection.execute`` on :meth:`connection`."""
        async with self.connection() as connection:
            return await connection.execute(query=query, values=values)

    async def execute_many(self, query: ClauseElement, values: list) -> None:
        """Delegate to ``Connection.execute_many`` on :meth:`connection`."""
        # NOTE(review): annotated -> None but returns the connection's result.
        async with self.connection() as connection:
            return await connection.execute_many(query=query, values=values)

    async def iterate(
            self,
            query: ClauseElement) -> typing.AsyncGenerator[RowProxy, None]:
        """Yield rows of ``query`` one at a time.

        The connection's ``async with`` block stays open until iteration ends.
        """
        async with self.connection() as connection:
            async for record in connection.iterate(query):
                yield record

    def connection(self) -> "Connection":
        """Return the global connection if one exists; otherwise the
        per-context connection, created and cached on first use."""
        if self._global_connection is not None:
            return self._global_connection

        try:
            return self._connection_context.get()
        except LookupError:
            connection = Connection(self._backend)
            self._connection_context.set(connection)
            return connection

    def transaction(self, *, force_rollback: bool = False) -> "Transaction":
        """Start a transaction on the connection returned by :meth:`connection`."""
        return self.connection().transaction(force_rollback=force_rollback)
Example #30
0
class Kanata(BaseDispatcher):
    """Kanata (彼方): an argument-parsing dispatcher for message chains.

    It matches an incoming ``MessageChain`` against ``signature_list`` and
    exposes the captured pieces to listener parameters by name.
    """

    always = True  # For compatibility with the refactored bcc.

    signature_list: List[Union[NormalMatch, PatternReceiver]]
    stop_exec_if_fail: bool = True

    # Per-context result of ``allocation`` (receiver name -> captured chain).
    parsed_items: ContextVar[Dict[str, MessageChain]]

    allow_quote: bool
    skip_one_at_in_quote: bool

    # Token returned by the last ``parsed_items.set`` in ``beforeDispatch``;
    # used by ``afterDispatch`` to reset the ContextVar.
    content_token: Optional[Token] = None

    def __init__(
        self,
        signature_list: List[Union[NormalMatch, PatternReceiver]],
        stop_exec_if_fail: bool = True,
        allow_quote: bool = True,
        skip_one_at_in_quote: bool = False,
    ) -> None:
        """Instantiate this argument parser.

        Args:
            signature_list (List[Union[NormalMatch, PatternReceiver]]): the chain
                of match signatures.
            stop_exec_if_fail (bool, optional): whether to stop the listener from
                running when nothing matches. Defaults to True.
            allow_quote (bool, optional): whether Kanata processes the user-input
                part of a reply (Quote) message. Defaults to True.
            skip_one_at_in_quote (bool, optional): when processing the user-input
                part of a reply message, whether Kanata automatically drops the At
                the QQ client may add plus a space held in a separate Plain
                element. Defaults to False.
        """
        self.signature_list = signature_list
        self.stop_exec_if_fail = stop_exec_if_fail
        self.parsed_items = ContextVar("kanata_parsed_items")
        self.allow_quote = allow_quote
        self.skip_one_at_in_quote = skip_one_at_in_quote

    @staticmethod
    def detect_index(
        signature_chain: Tuple[Union[NormalMatch, PatternReceiver]],
        message_chain: MessageChain,
    ) -> Optional[Dict[Arguments, Tuple[MessageIndex, MessageIndex]]]:
        """Locate each argument group of *signature_chain* inside *message_chain*.

        Returns:
            A dict mapping each ``Arguments`` group to a ``(start, stop)`` pair of
            ``MessageIndex`` values (start inclusive, stop exclusive), or ``None``
            when the message does not fit the signature chain.

        Raises:
            TypeError: on an unexpected match conflict. This happens when an
                ``Arguments`` signature arrives while another group is still
                being captured and it does not carry the previous step's index.
        """
        merged_chain = merge_signature_chain(signature_chain)
        message_chain = message_chain.asMerged()
        element_num = len(message_chain.__root__)
        # Position of the end of the message: the last element, plus the
        # length of its text when that element is Plain.
        end_index: MessageIndex = (
            element_num - 1,
            len(message_chain.__root__[-1].text)
            if element_num != 0 and message_chain.__root__[-1].__class__ is Plain
            else None,
        )

        # MessageIndex layout:
        # [0] => real element index
        # [1] => text offset inside that element (optional)
        reached_message_index: MessageIndex = (0, None)

        start_index: MessageIndex = (0, None)

        # Arguments group -> (start (inclusive), stop (exclusive))
        match_result: Dict[Arguments, Tuple[MessageIndex, MessageIndex]] = {}

        signature_iterable = InsertGenerator(enumerate(merged_chain))
        latest_index = None
        matching_receiver: Optional[Arguments] = None

        for signature_index, signature in signature_iterable:
            if isinstance(signature, (Arguments, PatternReceiver)):
                if matching_receiver:  # An argument group is already selected.
                    if isinstance(signature, Arguments):
                        if latest_index == signature_index:
                            matching_receiver.content.extend(signature.content)
                            continue
                        else:
                            raise TypeError(
                                "a unexpected case: match conflict")
                    if isinstance(signature, PatternReceiver):
                        matching_receiver.content.append(signature)
                        continue
                else:
                    if isinstance(signature, PatternReceiver):
                        signature = Arguments([signature])
                matching_receiver = signature
                start_index = reached_message_index
            elif isinstance(signature, NormalMatch):
                if not matching_receiver:
                    # No argument is being captured: the NormalMatch must match
                    # right at the current position (reached_message_index).
                    current_chain = message_chain.subchain(
                        slice(reached_message_index, None, None))
                    if not current_chain.__root__:  # Index out of range.
                        return
                    if not isinstance(current_chain.__root__[0], Plain):
                        # The first element after slicing is **not** Plain.
                        return
                    re_match_result = re.match(signature.operator(),
                                               current_chain.__root__[0].text)
                    if not re_match_result:
                        return  # No match.
                    # Advance the current position.
                    plain_text_length = len(current_chain.__root__[0].text)
                    pattern_length = (re_match_result.end()
                                      - re_match_result.start())
                    if (pattern_length + 1) > plain_text_length:
                        # re.match is anchored at the start, so the match
                        # consumed all of this element's text; advancing the
                        # text offset would run past it. Move on to the next
                        # element instead.
                        reached_message_index = (reached_message_index[0] + 1,
                                                 None)
                    else:
                        # Stay in this element; advance the text offset past
                        # the matched text.
                        reached_message_index = (
                            reached_message_index[0],
                            origin_or_zero(reached_message_index[1])
                            + re_match_result.start() + pattern_length,
                        )
                else:
                    # An argument is being captured: search forward for this
                    # NormalMatch. Greedy groups use the last occurrence
                    # inside an element, others the first.
                    greed = matching_receiver.isGreed
                    for element_index, element in enumerate(
                            message_chain.subchain(
                                slice(reached_message_index, None,
                                      None)).__root__):
                        if isinstance(element, Plain):
                            current_text: str = element.text
                            text_find_result_list = list(
                                re.finditer(signature.operator(),
                                            current_text))
                            if not text_find_result_list:
                                continue
                            # Greedy selection (assuming isGreed is a bool):
                            # True -> [-1] (last match), False -> [0] (first).
                            text_find_result = text_find_result_list[-int(greed)]
                            if not text_find_result:
                                continue
                            text_find_index = text_find_result.start()

                            # Found it. Record where the captured argument
                            # stops, reset the capture state, then advance.
                            stop_index = (
                                reached_message_index[0] + element_index
                                + int(element_index == 0),
                                origin_or_zero(reached_message_index[1])
                                + text_find_index,
                            )
                            match_result[matching_receiver] = (
                                copy.copy(start_index),
                                stop_index,
                            )

                            start_index = (0, None)
                            matching_receiver = None

                            pattern_length = (text_find_result.end()
                                              - text_find_result.start())
                            if current_text == text_find_result.group(0):
                                # The match spans this element's whole text, so
                                # advancing the text offset would overshoot it.
                                # Advance the element index instead.
                                reached_message_index = (
                                    reached_message_index[0] + element_index
                                    + int(element_index != 0),
                                    None,
                                )
                            else:
                                reached_message_index = (
                                    reached_message_index[0] + element_index,
                                    origin_or_zero(reached_message_index[1])
                                    + text_find_index + pattern_length,
                                )
                            break
                    else:
                        # Searched every remaining element without a match.
                        return
            latest_index = signature_index
        else:
            if matching_receiver:
                # The signature ended while still capturing: the argument runs
                # to the end of the message. Compute that end position.
                text_index = None

                # Guard against an empty message (e.g. a reply whose only
                # content was the stripped Quote/At/space). Indexing [-1] on
                # it would raise IndexError.
                if message_chain.__root__:
                    latest_element = message_chain.__root__[-1]
                    if isinstance(latest_element, Plain):
                        text_index = len(latest_element.text)

                stop_index = (len(message_chain.__root__), text_index)
                match_result[matching_receiver] = (start_index, stop_index)
            else:
                # Nothing left to capture, but the signature is exhausted
                # before the message is: reject this match.
                # NOTE(review): when both tuples share the element index and
                # only one text offset is None, this comparison raises
                # TypeError (None vs int). Confirm whether it is reachable.
                if reached_message_index < end_index:
                    return

        return match_result

    @staticmethod
    def detect_and_mapping(
        signature_chain: Tuple[Union[NormalMatch, PatternReceiver]],
        message_chain: MessageChain,
    ) -> Optional[Dict[Arguments, MessageChain]]:
        """Slice *message_chain* into the pieces captured by each argument group.

        Returns ``None`` when ``detect_index`` finds no match.
        """
        match_result = Kanata.detect_index(signature_chain, message_chain)
        if match_result is None:
            return None
        mapping = {}
        for receiver, (start, stop) in match_result.items():
            stop_text = stop[1]
            if stop_text is not None and start[0] == stop[0]:
                # Start and stop are in the same element. Shift the stop text
                # offset left by the start offset (presumably MessageChain
                # slicing treats it as relative to the start; confirm).
                stop_text -= origin_or_zero(start[1])
            mapping[receiver] = message_chain[start:(stop[0], stop_text)]
        return mapping

    @staticmethod
    def allocation(
        mapping: Dict[Arguments, MessageChain]
    ) -> Optional[Dict[str, MessageChain]]:
        """Turn the per-group mapping into a ``{receiver name: chain}`` dict.

        Only the first receiver of each ``Arguments`` group is filled;
        length-based splitting across several receivers is not implemented.

        Returns:
            ``None`` when *mapping* is ``None`` or a ``RequireParam`` captured an
            empty chain; otherwise the name -> chain dict. An ``OptionalParam``
            with an empty capture maps to ``None``.

        Raises:
            ConflictItem: if a receiver name is already present in the result.
        """
        if mapping is None:
            return None
        result = {}
        for argument_set, message_chain in mapping.items():
            for receiver in argument_set.content:
                if receiver.name in result:
                    raise ConflictItem(
                        "{0} is defined repeatedly".format(receiver))
                if isinstance(receiver, RequireParam):
                    if not message_chain.__root__:
                        return None
                    result[receiver.name] = message_chain
                elif isinstance(receiver, OptionalParam):
                    result[receiver.name] = (
                        message_chain if message_chain.__root__ else None)
                # Length matching is not implemented yet, so only the first
                # receiver of each group is considered.
                break
        return result

    async def catch_argument_names(self) -> List[str]:
        """Return the names of the ``PatternReceiver`` entries in ``signature_list``.

        Deliberately not cached: ``functools.lru_cache`` on an ``async def``
        caches the coroutine object, so a repeated call would return an
        already-awaited coroutine and fail with ``RuntimeError``.
        """
        return [
            receiver.name
            for receiver in self.signature_list
            if isinstance(receiver, PatternReceiver)
        ]

    async def beforeDispatch(self, interface: DispatcherInterface):
        """Parse the incoming message chain and store the captured arguments.

        The chain is looked up as ``__kanata_messagechain__`` with ``Source``
        excluded.

        Raises:
            ExecutionStop: if the chain contains a type from
                ``BLOCKING_ELEMENTS``, or if nothing matches and
                ``stop_exec_if_fail`` is set.
        """
        message_chain: MessageChain = (await interface.lookup_param(
            "__kanata_messagechain__", MessageChain, None)).exclude(Source)
        element_types = {element.__class__ for element in message_chain.__root__}
        if element_types.intersection(BLOCKING_ELEMENTS):
            raise ExecutionStop()
        if self.allow_quote and message_chain.has(Quote):
            # Automatically ignore the first At after the Quote. Expected
            # layout: 0: Quote, 1: At, 2: Plain (a single space; a later mirai
            # version may handle it itself, but it is stripped here for now).
            message_chain = message_chain[(3, None):]
            if self.skip_one_at_in_quote and message_chain.__root__:
                if message_chain.__root__[0].__class__ is At:
                    # MessageIndex (1, 1): element 1, text offset 1. This
                    # skips the At and one character of the following element.
                    message_chain = message_chain[(1, 1):]
        mapping_result = self.detect_and_mapping(self.signature_list,
                                                 message_chain)
        if mapping_result is not None:
            self.content_token = self.parsed_items.set(
                self.allocation(mapping_result))
        elif self.stop_exec_if_fail:
            raise ExecutionStop()

    async def catch(self, interface: DispatcherInterface):
        """Provide the captured value for the parameter named ``interface.name``.

        Returns ``Force(value)`` when the name was captured (``value`` may be
        ``None`` for an empty ``OptionalParam``). Returns ``None`` otherwise.
        """
        if not self.content_token:
            return None
        missing = object()  # Identity sentinel for "name not captured".
        current_item = self.parsed_items.get()
        if current_item is None:
            # ``allocation`` rejected the match (e.g. empty RequireParam).
            if self.stop_exec_if_fail:
                raise ExecutionStop()
            return None
        result = current_item.get(interface.name, missing)
        return Force(result) if result is not missing else None

    async def afterDispatch(
        self,
        interface: "DispatcherInterface",
        exception: Optional[Exception] = None,
        tb: Optional[TracebackType] = None,
    ):
        """Undo the ``parsed_items`` assignment made in ``beforeDispatch``."""
        # NOTE(review): ``content_token`` lives on the instance, not in a
        # ContextVar. Concurrent dispatches through the same Kanata may
        # overwrite each other's token; confirm this cannot happen.
        if self.content_token:
            self.parsed_items.reset(self.content_token)
            # NOTE(review): this stores the plain function ``Kanata.catch`` as
            # an instance attribute. A later ``self.catch(interface)`` would
            # not bind ``self``; confirm how the framework invokes ``catch``.
            self.catch = Kanata.catch
            self.content_token = None