def trace_stack_push(trace_stack_var: ContextVar, node: Any) -> None:
    """Append ``node`` to the trace stack held by ``trace_stack_var``.

    When the variable currently holds ``None`` a new list is created and
    stored in the variable before the node is appended to it.
    """
    current = trace_stack_var.get()
    if current is not None:
        current.append(node)
        return

    # No stack yet in this context: install a fresh list, then push onto it.
    fresh_stack = []
    trace_stack_var.set(fresh_stack)
    fresh_stack.append(node)
def test_retry(kernel):
    """Run ``retry(tick)`` through ``kernel.run`` for several ``max_retry`` settings."""
    counter = ContextVar('counter_test_retry', default=5)

    async def tick():
        # Each call consumes one unit of the counter; the outcome depends
        # on the value observed before the decrement.
        seen = counter.get()
        counter.set(seen - 1)
        if seen == 3:
            raise RuntimeError('3')
        if seen <= 0:
            return SUCCESS
        return FAILURE

    first_outcome = kernel.run(retry(tick))
    assert not first_outcome
    assert isinstance(first_outcome, ExceptionDecorator)

    second_outcome = kernel.run(retry(tick))
    assert second_outcome

    counter.set(10)
    assert kernel.run(retry(tick, max_retry=11))

    counter.set(100)
    assert kernel.run(retry(tick, max_retry=-1))

    # negative
    with pytest.raises(AssertionError):
        retry(tick, max_retry=-2)
async def test_retry():
    """Await ``retry``-wrapped ticks, with and without ``ignore_exception``."""
    counter = ContextVar('counter_test_retry', default=5)

    async def tick():
        # Each call consumes one unit of the counter; the outcome depends
        # on the value observed before the decrement.
        seen = counter.get()
        counter.set(seen - 1)
        print(f"value: {seen}")
        if seen == 3:
            raise RuntimeError('3')
        if seen <= 0:
            return SUCCESS
        return FAILURE

    guarded_retry = retry(ignore_exception(tick))
    outcome = await guarded_retry()  # counter: 5, 4, 3
    assert not outcome
    assert isinstance(outcome, ControlFlowException)

    plain_retry = retry(tick)
    assert await plain_retry()  # counter: 2, 1, 0

    counter.set(10)
    assert await retry(ignore_exception(tick), max_retry=11)()

    counter.set(100)
    assert await retry(ignore_exception(tick), max_retry=-1)()

    # negative
    with pytest.raises(AssertionError):
        retry(tick, max_retry=-2)
def _test_context(self, propagate):
    """Drive four greenlets through ``self._increment`` and check the counts.

    When ``propagate`` is true each greenlet target is wrapped in
    ``copy_context().run`` and is told to expect the value 0; otherwise the
    bare ``self._increment`` is used and the expectation is ``None``.
    """
    id_var = ContextVar("id", default=None)
    id_var.set(0)

    callback = getcurrent().switch
    counts = {key: 0 for key in range(5)}
    expected = 0 if propagate else None

    lets = []
    for greenlet_id in range(1, 5):
        if propagate:
            # A fresh context copy per greenlet, taken at construction time.
            target = partial(copy_context().run, self._increment)
        else:
            target = self._increment
        bound = partial(
            target,
            greenlet_id=greenlet_id,
            ctx_var=id_var,
            callback=callback,
            counts=counts,
            expect=expected,
        )
        lets.append(greenlet(bound))

    for _ in range(2):
        counts[id_var.get()] += 1
        for let in lets:
            let.switch()

    self.assertEqual(set(counts.values()), {2})
def _test_context(self, propagate_by):
    """Drive four greenlets through ``self._increment`` under one propagation mode.

    ``propagate_by`` is compared against ``"run"`` (target wrapped in
    ``copy_context().run``), ``"set"`` (``gr_context`` assigned a context
    copy) and ``"share"`` (``gr_context`` assigned the current greenlet's
    context); any other value leaves the greenlets untouched.
    """
    id_var = ContextVar("id", default=None)
    id_var.set(0)

    callback = getcurrent().switch
    counts = {key: 0 for key in range(5)}
    uses_run = propagate_by == "run"

    lets = []
    for greenlet_id in range(1, 5):
        if uses_run:
            target = partial(copy_context().run, self._increment)
        else:
            target = self._increment

        if propagate_by == "share":
            expected = greenlet_id - 1
        elif propagate_by in ("set", "run"):
            expected = 0
        else:
            expected = None

        lets.append(greenlet(partial(
            target,
            greenlet_id=greenlet_id,
            ctx_var=id_var,
            callback=callback,
            counts=counts,
            expect=expected,
        )))

    # Contexts are assigned only after every greenlet has been created.
    for let in lets:
        if propagate_by == "set":
            let.gr_context = copy_context()
        elif propagate_by == "share":
            let.gr_context = getcurrent().gr_context

    for _ in range(2):
        counts[id_var.get()] += 1
        for let in lets:
            let.switch()

    # Must leave each context.run() in reverse order of entry; without
    # context.run() it is fine to exit in any order.
    exit_order = reversed(lets) if uses_run else lets
    for let in exit_order:
        let.switch()

    for let in lets:
        self.assertTrue(let.dead)
        # When using run(), we leave the run() as the greenlet dies,
        # and there's no context "underneath". When not using run(),
        # gr_context still reflects the context the greenlet was
        # running in.
        self.assertEqual(let.gr_context is None, uses_run)

    if propagate_by == "share":
        self.assertEqual(counts, {0: 1, 1: 1, 2: 1, 3: 1, 4: 6})
    else:
        self.assertEqual(set(counts.values()), {2})
async def test_to_thread_contextvars(self):
    """Check that a value set on a ContextVar is visible inside ``asyncio.to_thread``."""
    test_ctx = ContextVar('test_ctx')
    test_ctx.set('parrot')

    def get_ctx():
        # Executed by asyncio.to_thread; reads the variable set above.
        observed = test_ctx.get()
        return observed

    from_thread = await asyncio.to_thread(get_ctx)
    self.assertEqual(from_thread, 'parrot')
class ControlMixin:
    """Mixin tracking a set of disabled command names in a per-instance ContextVar.

    The constructor consumes either a ``disable`` keyword (``True``, ``False``
    or an iterable of names) or, when ``disable`` is absent, an ``enable``
    keyword (default ``True``); remaining arguments go to ``super().__init__``.
    Name comparisons are case-insensitive in both ``is_disable`` and
    ``enable``.
    """

    def __init__(self, *args, **kwargs):
        self.__disable = ContextVar(str(id(self)), default=())
        if "disable" in kwargs:
            self._set_disable(kwargs.pop("disable"))
        else:
            self._set_disable(not kwargs.pop("enable", True))
        super().__init__(*args, **kwargs)

    @property
    def _disable(self):
        """Return a fresh list copy of the currently disabled names."""
        return list(self.__disable.get(()))

    def _set_disable(self, value):
        """Store ``value`` as the disabled names.

        ``True`` is stored as ``[_ALL]`` and ``False`` as an empty list; any
        other value is treated as an iterable of names.
        """
        if value is True:
            value = [_ALL]
        elif value is False:
            value = []
        self.__disable.set(tuple(value))

    def is_disable(self, *cmds: str) -> bool:
        """Return whether anything (no ``cmds``) or any of ``cmds`` is disabled."""
        disabled = self._disable
        if not cmds:
            return bool(disabled)
        # Lower-case the stored names once instead of once per command.
        lowered = {name.lower() for name in disabled}
        return any(cmd.lower() in lowered for cmd in cmds)

    def is_enable(self, *cmds):
        """Return the negation of ``is_disable(*cmds)``."""
        return not self.is_disable(*cmds)

    def disable(self, *cmds: str):
        """Disable ``cmds``; with no arguments replace the list with ``[_ALL]``."""
        disabled = self._disable
        if not cmds:
            disabled = [_ALL]
        disabled.extend(cmds)
        self._set_disable(disabled)

    def enable(self, *cmds: str):
        """Re-enable ``cmds``; with no arguments clear the disabled list.

        Every stored entry matching a command case-insensitively is removed,
        so a name disabled more than once or in another case is enabled too.
        """
        if not cmds:
            self._set_disable([])
            return
        targets = {cmd.lower() for cmd in cmds}
        remaining = [name for name in self._disable if name.lower() not in targets]
        self._set_disable(remaining)
def test_contextvar_propagation_sync(
        self, anyio_backend_name: str,
        anyio_backend_options: Dict[str, Any]) -> None:
    """Check that a ContextVar value set here is seen by ``portal.call``."""
    on_asyncio = anyio_backend_name == "asyncio"
    if on_asyncio and sys.version_info < (3, 7):
        pytest.skip("Asyncio does not propagate context before Python 3.7")

    var = ContextVar("var", default=1)
    var.set(6)

    portal_cm = start_blocking_portal(anyio_backend_name, anyio_backend_options)
    with portal_cm as portal:
        # var.get is executed through the portal; 6 means the value set
        # above was propagated rather than the default of 1.
        propagated_value = portal.call(var.get)

    assert propagated_value == 6
async def test_retry_until_failed():
    """Await ``retry_until_failed(tick)`` with the counter starting at 100."""
    counter = ContextVar('counter_test_retry_until_failed', default=5)

    async def tick():
        # Each call consumes one unit of the counter; the outcome depends
        # on the value observed before the decrement.
        seen = counter.get()
        counter.set(seen - 1)
        if seen == 3:
            raise RuntimeError('3')
        if seen <= 0:
            return SUCCESS
        return FAILURE

    counter.set(100)
    wrapped = retry_until_failed(tick)
    assert await wrapped()
async def test_retry_until_success():
    """Await ``retry_until_success`` over an exception-ignoring tick starting at 100."""
    counter = ContextVar('counter_test_retry_until_success', default=5)

    async def tick():
        # Each call consumes one unit of the counter; the outcome depends
        # on the value observed before the decrement.
        seen = counter.get()
        counter.set(seen - 1)
        if seen == 3:
            raise RuntimeError('3')
        if seen <= 0:
            return SUCCESS
        return FAILURE

    counter.set(100)
    wrapped = retry_until_success(ignore_exception(tick))
    assert await wrapped()
def test_retry_until_success(kernel):
    """Run ``retry_until_success(tick)`` through ``kernel.run`` starting at 100."""
    counter = ContextVar('counter_test_retry_until_success', default=5)

    async def tick():
        # Each call consumes one unit of the counter; the outcome depends
        # on the value observed before the decrement.
        seen = counter.get()
        counter.set(seen - 1)
        if seen == 3:
            raise RuntimeError('3')
        if seen <= 0:
            return SUCCESS
        return FAILURE

    counter.set(100)
    outcome = kernel.run(retry_until_success(tick))
    assert outcome
def test_contextvars(self):
    """Check that a value set in an awaited coroutine is visible after a sleep."""
    from contextvars import ContextVar

    var = ContextVar('var')
    var.set(0)

    async def set_val():
        var.set(42)

    async def coro():
        # set_val is awaited directly, so it shares this coroutine's context.
        await set_val()
        await asyncio.sleep(0.01)
        observed = var.get()
        return observed

    outcome = self.loop.run_until_complete(coro())
    self.assertEqual(outcome, 42)
def get_counter(variable: ContextVar, fn: Callable):
    """Yield a ``Counter`` bound to ``variable`` for the duration of the generator.

    A DEBUG-level logger named after ``fn``'s module and function name is
    created and handed to the counter. The counter is stored in ``variable``
    before yielding, and the previous value is restored afterwards.

    Args:
        variable: The context variable that should hold the counter.
        fn: The callable whose ``__module__``/``__name__`` name the logger.

    Yields:
        The newly created counter.
    """
    logger = logging.getLogger(fn.__module__).getChild(fn.__name__)
    logger.setLevel(logging.DEBUG)

    counter = Counter(logger=logger)
    token = variable.set(counter)
    try:
        yield counter
    finally:
        # Restore the previous value even when the consumer raises into the
        # generator; without this the counter would stay set in the variable.
        variable.reset(token)
class ContextVarsRuntimeContext(_RuntimeContext):
    """Runtime-context implementation that keeps the current ``Context`` in a ``ContextVar``.

    The variable is named by ``_CONTEXT_KEY`` and defaults to a ``Context()``
    instance created when this object is initialised.
    """

    _CONTEXT_KEY = "current_context"

    def __init__(self) -> None:
        initial = Context()
        self._current_context = ContextVar(self._CONTEXT_KEY, default=initial)

    def attach(self, context: Context) -> object:
        """Make ``context`` the current ``Context``.

        Args:
            context: The Context to set.

        Returns:
            A token that ``detach`` accepts to restore the previous value.
        """
        token = self._current_context.set(context)
        return token

    def get_current(self) -> Context:
        """Return the ``Context`` currently stored in the variable."""
        current = self._current_context.get()
        return current

    def detach(self, token: object) -> None:
        """Restore the ``Context`` that was current before ``attach``.

        Args:
            token: The token returned by the matching ``attach`` call.
        """
        self._current_context.reset(token)  # type: ignore
def test_make_target(self):
    """Check ``make_target`` with ``_config.PASS_CONTEXTVARS`` off and on.

    The checks run inside worker threads. An exception raised in a thread
    does not propagate to the thread that started it, so each check records
    its failure and the failure is re-raised here after ``join()``.
    """
    from contextvars import ContextVar

    # Context setup.
    foo = ContextVar("foo")
    token = foo.set("foo")
    failures = []

    def run_in_thread(check):
        def guarded():
            try:
                check()
            except BaseException as exc:  # also catches pytest's own failure type
                failures.append(exc)

        worker = Thread(target=make_target(guarded))
        worker.start()
        worker.join()
        if failures:
            raise failures[0]

    def target_without_context():
        with pytest.raises(LookupError):
            foo.get()

    def target_with_context():
        assert foo.get() == "foo"

    try:
        # ``_config.PASS_CONTEXTVARS = False`` (the default).
        assert _config.PASS_CONTEXTVARS is False
        run_in_thread(target_without_context)

        # ``_config.PASS_CONTEXTVARS = True``.
        _config.PASS_CONTEXTVARS = True
        run_in_thread(target_with_context)
    finally:
        # Teardown: restore global state even when a check failed.
        _config.PASS_CONTEXTVARS = False
        foo.reset(token)
class WeakContextVar:
    """Context variable that holds its value through a weak reference.

    Instances are cached by name: constructing ``WeakContextVar(name)`` twice
    returns the same object, and the underlying ``ContextVar`` is created
    only once so previously stored values survive repeated construction.
    """

    _instances = {}

    def __new__(cls, name):
        try:
            return cls._instances[name]
        except KeyError:
            inst = cls._instances[name] = super().__new__(cls)
            return inst

    def __init__(self, name):
        # __init__ runs on every WeakContextVar(name) call, including when
        # __new__ returned the cached instance; never replace an existing var.
        if "_context_var" not in self.__dict__:
            self._context_var = ContextVar(name, default=None)

    def get(self):
        """Return the referenced object, or ``None`` if unset or already collected."""
        ref = self._context_var.get()
        return None if ref is None else ref()

    def set(self, value):
        """Store a weak reference to ``value`` and return the ``ContextVar`` token.

        Passing ``None`` clears the variable (``weakref.ref(None)`` is not
        possible, so ``None`` is stored directly).
        """
        ref = None if value is None else weakref.ref(value)
        return self._context_var.set(ref)
def _tristate_armed(context_var: ContextVar, enabled=True): """Assumes "enabled" if `enabled` flag is None.""" resetter = context_var.set(enabled if enabled is None else bool(enabled)) try: yield finally: context_var.reset(resetter)
class Urls:
    """Wrap a ``Map`` and keep the currently bound adapter in a ``ContextVar``.

    ``build`` and ``match`` treat an adapter whose ``url_scheme`` is
    ``'file'`` and whose ``server_name`` is ``'.'`` specially: paths are made
    relative to the directory of ``path_info`` and ``index.html`` is added or
    stripped.
    """

    def __init__(self, rules):
        self.map = Map(rules)
        self.urls = ContextVar("urls", default=self.map.bind(''))

    @staticmethod
    def _is_relative_file(adapter):
        # The combination the original code checks inline in build()/match().
        return adapter.url_scheme == 'file' and adapter.server_name == '.'

    def iter_rules(self):
        """Delegate to ``Map.iter_rules``."""
        return self.map.iter_rules()

    @contextmanager
    def _bind(self, urls):
        """Make ``urls`` the current adapter for the duration of the ``with`` block."""
        token = self.urls.set(urls)
        try:
            yield
        finally:
            self.urls.reset(token)

    def bind(self, *args, **kwargs):
        """Bind the map with the given arguments and return the binding context manager."""
        adapter = self.map.bind(*args, **kwargs)
        return self._bind(adapter)

    def bind_to_environ(self, *args, **kwargs):
        """Like ``bind`` but using ``Map.bind_to_environ``."""
        adapter = self.map.bind_to_environ(*args, **kwargs)
        return self._bind(adapter)

    def build(self,
              endpoint,
              values=None,
              method=None,
              force_external=False,
              append_unknown=True):
        """Build a URL with the current adapter; relative file paths in file mode."""
        adapter = self.urls.get()
        path = adapter.build(endpoint, values, method, force_external,
                             append_unknown)
        if not self._is_relative_file(adapter):
            return path

        assert not force_external
        if path.endswith('/'):
            path += 'index.html'
        base_dir = os.path.dirname(adapter.path_info)
        return os.path.relpath(path, base_dir)

    def match(self, path=None, return_rule=False):
        """Match ``path`` (or the adapter's own path when ``None``) against the map."""
        adapter = self.urls.get()
        if path is None:
            return adapter.match(return_rule=return_rule)

        if self._is_relative_file(adapter):
            base_dir = os.path.dirname(adapter.path_info)
            path = os.path.normpath(os.path.join(base_dir, path))
            if path.endswith("/index.html"):
                path = path[:-len("/index.html")]
        else:
            script_name = adapter.script_name.rstrip("/")
            assert path.startswith(script_name)
            path = path[len(script_name):]
        return adapter.match(path, return_rule=return_rule)

    def dispatch(self, *args, **kwargs):
        """Delegate to the current adapter's ``dispatch``."""
        adapter = self.urls.get()
        return adapter.dispatch(*args, **kwargs)

    def current_url(self):
        """Dispatch with a view that rebuilds the URL from the matched endpoint and values."""
        def rebuild(endpoint, values):
            return self.build(endpoint, values)

        return self.dispatch(rebuild)
class Environment(object):
    """Re-entrant context manager that counts its nesting depth in a ContextVar.

    The instance is truthy while at least one ``with`` block on it is active
    in the current context.
    """

    __slots__ = ["level"]

    def __init__(self):
        self.level = ContextVar("%s_level" % self.__class__.__name__,
                                default=0)

    def __repr__(self):
        return "<%s, level %s>" % (self.__class__.__name__, self.level.get())

    def __bool__(self):
        return self.level.get() != 0

    def __enter__(self):
        """Increase the nesting level and return this environment."""
        self.level.set(self.level.get() + 1)
        # Return self so ``with env as e`` binds the environment, not None.
        return self

    def __exit__(self, exc_type=None, exc_val=None, exc_tb=None):
        """Decrease the nesting level; exceptions are not suppressed."""
        self.level.set(self.level.get() - 1)
class RenderVar:
    """Descriptor whose value is produced lazily by a factory and kept in a ContextVar.

    ``RenderVar(name, factory, **kwds)`` creates the descriptor directly.
    ``RenderVar(name, **kwds)`` with no positional factory returns a callable
    that takes the factory and builds the descriptor, so it can be used as a
    decorator. Every instance registers itself in ``_REGISTRY`` under
    ``name``.
    """

    def __new__(cls, name, *args, **kwds):
        if args:
            return super().__new__(cls)
        # No factory yet: hand back a one-argument builder instead.
        return lambda factory: cls(name, factory, **kwds)

    def __init__(self, name, factory, **kwds):
        if not callable(factory):
            # A plain value: wrap it so it can be "called" like a factory.
            def constant():
                return factory

            self._factory = constant
        else:
            self._factory = partial(factory, **kwds)
        self._var = ContextVar(name, default=_MISSING)
        _REGISTRY[name] = self

    def __get__(self, instance=None, owner=None):
        """Return the stored value, creating and storing it on first access."""
        value = self._var.get()
        if value is _MISSING:
            value = self._factory()
            self._var.set(value)
        return value
class Slot(base_context.BaseRuntimeContext.Slot):
    """Slot whose value lives in a ``ContextVar`` and is lazily defaulted."""

    def __init__(self, name: str, default: "object"):
        # pylint: disable=super-init-not-called
        self.name = name
        self.contextvar = ContextVar(name)  # type: ContextVar[object]
        wrapped_default = base_context.wrap_callable(default)
        self.default = wrapped_default  # type: typing.Callable[..., object]

    def clear(self) -> None:
        """Overwrite the stored value with a freshly produced default."""
        fresh = self.default()
        self.contextvar.set(fresh)

    def get(self) -> "object":
        """Return the stored value, storing and returning a default when unset."""
        try:
            current = self.contextvar.get()
        except LookupError:
            # Nothing set in this context yet: materialise the default.
            current = self.default()
            self.set(current)
        return current

    def set(self, value: "object") -> None:
        """Store ``value`` in the underlying context variable."""
        self.contextvar.set(value)
class ContextService:
    """Keep a dictionary of named values in a ``ContextVar`` with undo support.

    Every ``set_context_variable`` call stores a new dictionary (the previous
    one is copied, never mutated) and records the ``ContextVar`` token so
    ``undo`` can step back one change at a time.
    """

    def __init__(self):
        self._context = ContextVar(f"{self}", default={})
        self._undo_tokens = deque()

    def set_context_variable(self, var_name, var_value):
        """Store ``var_value`` under ``var_name`` in a new copy of the dictionary."""
        context = self._get_context().copy()
        context[var_name] = var_value
        self._set_context(context)

    def get_context_variable(self, var_name):
        """Return the value stored under ``var_name``, or ``None`` if absent."""
        return self._get_context().get(var_name)

    def undo(self):
        """Revert the most recent ``set_context_variable``; no-op when none is left."""
        undo_token = self._pop_undo_token()
        if undo_token is None:
            return
        self._context.reset(undo_token)

    async def run_in_context(self, func, *args, **kwargs):
        """Call ``func`` via ``copy_context().run`` and await what it returns.

        Raises:
            BbpypValueError: If ``func`` is not callable.
        """
        if not callable(func):
            raise BbpypValueError("func", func, "must be callable")
        copied_context = copy_context()
        await copied_context.run(func, *args, **kwargs)

    def get_context_key_value_pairs(self):
        """Return a copy of this service's dictionary of context values.

        The value is read from this service's own variable rather than by
        position from the whole copied context, so an empty context yields
        ``{}`` and unrelated context variables cannot be returned by mistake.
        """
        return dict(self._get_context())

    def set_context_key_value_pairs(self, **kwargs):
        """Store every keyword argument with ``set_context_variable``."""
        for key, value in kwargs.items():
            self.set_context_variable(key, value)

    def _get_context(self):
        return self._context.get({})

    def _set_context(self, context_value):
        undo_token = self._context.set(context_value)
        self._push_undo_token(undo_token)

    def _pop_undo_token(self):
        return self._undo_tokens.pop() if self._undo_tokens else None

    def _push_undo_token(self, undo_token):
        self._undo_tokens.append(undo_token)
class Ctx(Generic[T]):
    """Thin typed wrapper around a single ``ContextVar``."""

    current_ctx: ContextVar[T]

    def __init__(self, name: str) -> None:
        self.current_ctx = ContextVar(name)

    def get(self, default: Union[D, T] = None) -> Union[T, D]:
        """Return the current value, or ``default`` when the variable is unset."""
        return self.current_ctx.get(default)

    def set(self, value: T):
        """Set the value and return the ``Token`` for a later ``reset``."""
        return self.current_ctx.set(value)

    def reset(self, token: Token):
        """Restore the value the variable had before the ``set`` that produced ``token``."""
        return self.current_ctx.reset(token)

    @contextmanager
    def use(self, value: T):
        """Set ``value`` for the duration of the ``with`` block."""
        token = self.set(value)
        try:
            yield
        finally:
            # Restore the previous value even when the with-body raises.
            self.reset(token)
class ContextVarsRuntimeContext(RuntimeContext):
    """Runtime-context implementation that keeps the current ``Context`` in a ``ContextVar``.

    The variable is named by ``_CONTEXT_KEY`` and defaults to a ``Context()``
    instance created when this object is initialised.
    """

    _CONTEXT_KEY = "current_context"

    def __init__(self) -> None:
        initial = Context()
        self._current_context = ContextVar(self._CONTEXT_KEY, default=initial)

    def attach(self, context: Context) -> object:
        """See `opentelemetry.context.RuntimeContext.attach`."""
        token = self._current_context.set(context)
        return token

    def get_current(self) -> Context:
        """See `opentelemetry.context.RuntimeContext.get_current`."""
        current = self._current_context.get()
        return current

    def detach(self, token: object) -> None:
        """See `opentelemetry.context.RuntimeContext.detach`."""
        self._current_context.reset(token)  # type: ignore
class ContextManager:
    """Context-local value with ``enter``/``leave`` scoping and item/attribute access.

    The underlying ``ContextVar`` stores a two-item list ``[token, value]``;
    the token is the one needed to undo the ``enter`` that stored the value.
    ``manager[key]`` and ``manager.key`` both index into the current value.
    """

    def __init__(self, name, namespace='', default=None):
        qualified_name = '.'.join(([namespace] if namespace else []) + [name])
        self.context = ContextVar(qualified_name, default=[None, default])
        self.default = default

    @property
    def name(self):
        """The (optionally namespaced) name of the underlying ``ContextVar``."""
        return self.context.name

    def enter(self, value):
        """Make ``value`` current and return ``self`` for use in a ``with`` statement."""
        values = [None, value]
        # The token is stored next to the value so leave() can find it again.
        values[0] = self.context.set(values)
        return self

    def leave(self):
        """Undo the most recent ``enter``; a no-op when nothing was entered."""
        token, _ = self.context.get()
        if token is None:
            # Only the default is active; there is nothing to reset.
            return
        self.context.reset(token)

    def get(self):
        """Return the current value."""
        return self.context.get()[-1]

    __call__ = get

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.leave()

    def __getitem__(self, item):
        """Index into the current value."""
        _, value = self.context.get()
        return value[item]

    def __getattr__(self, item):
        """Attribute-style access to items of the current value.

        Lookup failures are reported as ``AttributeError`` so ``hasattr`` and
        ``getattr(..., default)`` behave normally.
        """
        if item == 'context':
            # Reached only before __init__ has run; avoid infinite recursion.
            raise AttributeError(item)
        try:
            return self[item]
        except (KeyError, IndexError, TypeError) as exc:
            raise AttributeError(item) from exc
class StackVar:
    """ContextVar that represents a stack.

    The stack is a chain of ``(value, previous)`` tuples; the empty stack is
    ``(None, None)``.
    """

    def __init__(self, name):
        """Initialize a StackVar."""
        self.var = ContextVar(name, default=(None, None))
        self.var.set((None, None))

    def push(self, x):
        """Push a new value on the stack."""
        self.var.set((x, self.var.get()))

    def pop(self):
        """Remove the top element of the stack and return it.

        Raises:
            AssertionError: If the stack is empty.
        """
        curr, prev = self.var.get()
        if prev is None:
            # Raised explicitly so the guard survives ``python -O``; a bare
            # ``assert`` would be stripped and ``None`` stored as the stack.
            raise AssertionError("pop from an empty StackVar")
        self.var.set(prev)
        return curr

    def top(self):
        """Return the top element of the stack."""
        return self.var.get()[0]
from contextvars import copy_context

# Scratch cells experimenting with ContextVar / copy_context.
# NOTE(review): ``valset`` is not defined in this chunk - presumably a pandas
# DataFrame created earlier; confirm before running.
t = []
for i, column in enumerate(valset.columns):
    if valset.iloc[:, i].isna().any():
        t = column
        var: ContextVar[str] = ContextVar(column, default=42)
        ctx: t = copy_context()
        print(type(t))
#%%
t = None
if t is None:
    var = ContextVar('None')
    token = var.set('new value')
    # code that uses 'var'; var.get() returns 'new value'.
    token.var.name
#%%
tcname = 1
#%%
var: ContextVar[int] = ContextVar('var', default=42)
token = var.set('new value')
#%%
token.old_value
#%%
from contextvars import copy_context

ctx: tcname = copy_context()
class Database:
    """Entry point for talking to a database through a pluggable backend.

    The backend class is looked up in ``SUPPORTED_BACKENDS`` by the URL's
    ``scheme``. Connections are kept in a ``ContextVar``; when
    ``force_rollback`` is set, a single global connection wrapped in an
    always-rolled-back transaction is used instead.
    """

    SUPPORTED_BACKENDS = {
        "postgresql": "databases.backends.postgres:PostgresBackend",
        "postgresql+aiopg": "databases.backends.aiopg:AiopgBackend",
        "postgres": "databases.backends.postgres:PostgresBackend",
        "mysql": "databases.backends.mysql:MySQLBackend",
        "sqlite": "databases.backends.sqlite:SQLiteBackend",
    }

    def __init__(
        self,
        url: typing.Union[str, "DatabaseURL"],
        *,
        force_rollback: bool = False,
        **options: typing.Any,
    ):
        self.url = DatabaseURL(url)
        self.options = options
        self.is_connected = False

        self._force_rollback = force_rollback

        backend_path = self.SUPPORTED_BACKENDS[self.url.scheme]
        backend_type = import_from_string(backend_path)
        assert issubclass(backend_type, DatabaseBackend)
        self._backend = backend_type(self.url, **self.options)

        # Connections are stored as task-local state.
        self._connection_context = ContextVar(
            "connection_context")  # type: ContextVar

        # When `force_rollback=True` is used, we use a single global
        # connection, within a transaction that always rolls back.
        self._global_connection = None  # type: typing.Optional[Connection]
        self._global_transaction = None  # type: typing.Optional[Transaction]

    async def connect(self) -> None:
        """
        Establish the connection pool.
        """
        assert not self.is_connected, "Already connected."

        await self._backend.connect()
        logger.info("Connected to database %s",
                    self.url.obscure_password,
                    extra=CONNECT_EXTRA)
        self.is_connected = True

        if not self._force_rollback:
            return

        # Force-rollback mode: open the shared connection and enter its
        # always-rolled-back transaction.
        assert self._global_connection is None
        assert self._global_transaction is None

        self._global_connection = Connection(self._backend)
        self._global_transaction = self._global_connection.transaction(
            force_rollback=True)

        await self._global_transaction.__aenter__()

    async def disconnect(self) -> None:
        """
        Close all connections in the connection pool.
        """
        assert self.is_connected, "Already disconnected."

        if self._force_rollback:
            assert self._global_connection is not None
            assert self._global_transaction is not None

            await self._global_transaction.__aexit__()

            self._global_transaction = None
            self._global_connection = None

        await self._backend.disconnect()
        logger.info(
            "Disconnected from database %s",
            self.url.obscure_password,
            extra=DISCONNECT_EXTRA,
        )
        self.is_connected = False

    async def __aenter__(self) -> "Database":
        await self.connect()
        return self

    async def __aexit__(
        self,
        exc_type: typing.Type[BaseException] = None,
        exc_value: BaseException = None,
        traceback: TracebackType = None,
    ) -> None:
        await self.disconnect()

    async def fetch_all(self,
                        query: typing.Union[ClauseElement, str],
                        values: dict = None) -> typing.List[typing.Mapping]:
        """Run ``query`` on the current connection and return all rows."""
        async with self.connection() as conn:
            rows = await conn.fetch_all(query, values)
            return rows

    async def fetch_one(
            self,
            query: typing.Union[ClauseElement, str],
            values: dict = None) -> typing.Optional[typing.Mapping]:
        """Run ``query`` on the current connection and return one row."""
        async with self.connection() as conn:
            row = await conn.fetch_one(query, values)
            return row

    async def fetch_val(
        self,
        query: typing.Union[ClauseElement, str],
        values: dict = None,
        column: typing.Any = 0,
    ) -> typing.Any:
        """Run ``query`` on the current connection and return one column value."""
        async with self.connection() as conn:
            value = await conn.fetch_val(query, values, column=column)
            return value

    async def execute(self,
                      query: typing.Union[ClauseElement, str],
                      values: dict = None) -> typing.Any:
        """Execute ``query`` on the current connection."""
        async with self.connection() as conn:
            outcome = await conn.execute(query, values)
            return outcome

    async def execute_many(self, query: typing.Union[ClauseElement, str],
                           values: list) -> None:
        """Execute ``query`` once per entry of ``values`` on the current connection."""
        async with self.connection() as conn:
            outcome = await conn.execute_many(query, values)
            return outcome

    async def iterate(self,
                      query: typing.Union[ClauseElement, str],
                      values: dict = None
                      ) -> typing.AsyncGenerator[typing.Mapping, None]:
        """Yield the rows of ``query`` one at a time from the current connection."""
        async with self.connection() as conn:
            async for row in conn.iterate(query, values):
                yield row

    def connection(self) -> "Connection":
        """Return the global connection if set, else the one stored in the ContextVar.

        A new ``Connection`` is created and stored when the variable is unset.
        """
        if self._global_connection is not None:
            return self._global_connection

        try:
            current = self._connection_context.get()
        except LookupError:
            current = Connection(self._backend)
            self._connection_context.set(current)
        return current

    def transaction(self,
                    *,
                    force_rollback: bool = False,
                    **kwargs: typing.Any) -> "Transaction":
        """Create a ``Transaction`` that obtains its connection from ``self.connection``."""
        return Transaction(self.connection,
                           force_rollback=force_rollback,
                           **kwargs)

    @contextlib.contextmanager
    def force_rollback(self) -> typing.Iterator[None]:
        """Temporarily switch the force-rollback flag on, restoring it afterwards."""
        previous = self._force_rollback
        self._force_rollback = True
        try:
            yield
        finally:
            self._force_rollback = previous
class Database:
    """Entry point for talking to a database through a pluggable backend.

    The backend class is looked up in ``SUPPORTED_BACKENDS`` by the URL's
    ``dialect``. Connections are kept in a ``ContextVar``; when
    ``force_rollback`` is set, a single global connection wrapped in an
    always-rolled-back transaction is created up front and used instead.
    """

    SUPPORTED_BACKENDS = {
        "postgresql": "databases.backends.postgres:PostgresBackend",
        "mysql": "databases.backends.mysql:MySQLBackend",
        "sqlite": "databases.backends.sqlite:SQLiteBackend",
    }

    def __init__(self,
                 url: typing.Union[str, "DatabaseURL"],
                 *,
                 force_rollback: bool = False):
        self._url = DatabaseURL(url)
        self._force_rollback = force_rollback
        self.is_connected = False

        backend_path = self.SUPPORTED_BACKENDS[self._url.dialect]
        backend_type = import_from_string(backend_path)
        assert issubclass(backend_type, DatabaseBackend)
        self._backend = backend_type(self._url)

        # Connections are stored as task-local state.
        self._connection_context = ContextVar(
            "connection_context")  # type: ContextVar

        # When `force_rollback=True` is used, we use a single global
        # connection, within a transaction that always rolls back.
        self._global_connection = None  # type: typing.Optional[Connection]
        self._global_transaction = None  # type: typing.Optional[Transaction]

        if self._force_rollback:
            shared = Connection(self._backend)
            self._global_connection = shared
            self._global_transaction = shared.transaction(force_rollback=True)

    async def connect(self) -> None:
        """
        Establish the connection pool.
        """
        assert not self.is_connected, "Already connected."

        await self._backend.connect()
        self.is_connected = True

        if not self._force_rollback:
            return

        # Force-rollback mode: enter the always-rolled-back transaction.
        assert self._global_transaction is not None
        await self._global_transaction.__aenter__()

    async def disconnect(self) -> None:
        """
        Close all connections in the connection pool.
        """
        assert self.is_connected, "Already disconnected."

        if self._force_rollback:
            assert self._global_transaction is not None
            await self._global_transaction.__aexit__()

        await self._backend.disconnect()
        self.is_connected = False

    async def __aenter__(self) -> "Database":
        await self.connect()
        return self

    async def __aexit__(
        self,
        exc_type: typing.Type[BaseException] = None,
        exc_value: BaseException = None,
        traceback: TracebackType = None,
    ) -> None:
        await self.disconnect()

    async def fetch_all(self, query: ClauseElement) -> typing.List[RowProxy]:
        """Run ``query`` on the current connection and return all rows."""
        async with self.connection() as conn:
            rows = await conn.fetch_all(query=query)
            return rows

    async def fetch_one(self, query: ClauseElement) -> RowProxy:
        """Run ``query`` on the current connection and return one row."""
        async with self.connection() as conn:
            row = await conn.fetch_one(query=query)
            return row

    async def execute(self, query: ClauseElement,
                      values: dict = None) -> typing.Any:
        """Execute ``query`` on the current connection."""
        async with self.connection() as conn:
            outcome = await conn.execute(query=query, values=values)
            return outcome

    async def execute_many(self, query: ClauseElement, values: list) -> None:
        """Execute ``query`` once per entry of ``values`` on the current connection."""
        async with self.connection() as conn:
            outcome = await conn.execute_many(query=query, values=values)
            return outcome

    async def iterate(
            self,
            query: ClauseElement) -> typing.AsyncGenerator[RowProxy, None]:
        """Yield the rows of ``query`` one at a time from the current connection."""
        async with self.connection() as conn:
            async for row in conn.iterate(query):
                yield row

    def connection(self) -> "Connection":
        """Return the global connection if set, else the one stored in the ContextVar.

        A new ``Connection`` is created and stored when the variable is unset.
        """
        if self._global_connection is not None:
            return self._global_connection

        try:
            current = self._connection_context.get()
        except LookupError:
            current = Connection(self._backend)
            self._connection_context.set(current)
        return current

    def transaction(self, *, force_rollback: bool = False) -> "Transaction":
        """Start a transaction object on the current connection."""
        current = self.connection()
        return current.transaction(force_rollback=force_rollback)
class Kanata(BaseDispatcher):
    """Kanata: a dispatcher that parses a message chain against a signature chain.

    ``beforeDispatch`` matches the incoming message against
    ``signature_list`` and stores the captured parameters in a
    ``ContextVar``; ``catch`` serves them by parameter name;
    ``afterDispatch`` resets the variable.
    """

    always = True  # Compatibility with the refactored bcc.

    signature_list: List[Union[NormalMatch, PatternReceiver]]
    stop_exec_if_fail: bool = True

    parsed_items: ContextVar[Dict[str, MessageChain]]

    allow_quote: bool
    skip_one_at_in_quote: bool

    content_token: Optional[Token] = None

    def __init__(
        self,
        signature_list: List[Union[NormalMatch, PatternReceiver]],
        stop_exec_if_fail: bool = True,
        allow_quote: bool = True,
        skip_one_at_in_quote: bool = False,
    ) -> None:
        """Instantiate this argument parser.

        Args:
            signature_list (List[Union[NormalMatch, PatternReceiver]]): The chain of match signatures.
            stop_exec_if_fail (bool, optional): Whether to stop listener execution when no match
                is available. Defaults to True.
            allow_quote (bool, optional): Whether Kanata may process the user-input part of a
                reply (quote) message. Defaults to True.
            skip_one_at_in_quote (bool, optional): Whether, when processing the user-input part
                of a reply message, Kanata automatically removes the At possibly added by the
                QQ client together with a space held in a separate Plain element.
                Defaults to False.
        """
        self.signature_list = signature_list
        self.stop_exec_if_fail = stop_exec_if_fail
        self.parsed_items = ContextVar("kanata_parsed_items")
        self.allow_quote = allow_quote
        self.skip_one_at_in_quote = skip_one_at_in_quote

    @staticmethod
    def detect_index(
        signature_chain: Tuple[Union[NormalMatch, PatternReceiver]],
        message_chain: MessageChain,
    ) -> Optional[Dict[str, Tuple[MessageIndex, MessageIndex]]]:
        """Locate, for every argument group, the message span it captures.

        Returns a mapping from ``Arguments`` to ``(start, stop)`` message
        indexes (start inclusive, stop exclusive), or ``None`` when the
        message does not match the signature chain.

        Raises:
            TypeError: When two ``Arguments`` groups conflict.
        """
        merged_chain = merge_signature_chain(signature_chain)
        message_chain = message_chain.asMerged()
        element_num = len(message_chain.__root__)
        end_index: MessageIndex = (
            element_num - 1,
            len(message_chain.__root__[-1].text) if element_num != 0
            and message_chain.__root__[-1].__class__ is Plain else None,
        )

        reached_message_index: MessageIndex = (0, None)
        # [0] => real_index
        # [1] => text_index(optional)

        start_index: MessageIndex = (0, None)

        match_result: Dict[Arguments, Tuple[
            MessageIndex, MessageIndex],  # start(include)  # stop(exclude)
        ] = {}

        signature_iterable = InsertGenerator(enumerate(merged_chain))
        latest_index = None
        matching_recevier: Optional[Arguments] = None

        for signature_index, signature in signature_iterable:
            if isinstance(signature, (Arguments, PatternReceiver)):
                if matching_recevier:  # A receiver is already selected...
                    if isinstance(signature, Arguments):
                        if latest_index == signature_index:
                            matching_recevier.content.extend(signature.content)
                            continue
                        else:
                            raise TypeError(
                                "a unexpected case: match conflict")
                    if isinstance(signature, PatternReceiver):
                        matching_recevier.content.append(signature)
                        continue
                else:
                    if isinstance(signature, PatternReceiver):
                        signature = Arguments([signature])
                matching_recevier = signature
                start_index = reached_message_index
            elif isinstance(signature, NormalMatch):
                if not matching_recevier:
                    # No argument capture pending: match starting from the
                    # current position (reached_message_index).
                    current_chain = message_chain.subchain(
                        slice(reached_message_index, None, None))
                    if not current_chain.__root__:  # index out of range
                        return
                    if not isinstance(current_chain.__root__[0], Plain):
                        # The first element after slicing is **not** a Plain.
                        return
                    re_match_result = re.match(signature.operator(),
                                               current_chain.__root__[0].text)
                    if not re_match_result:
                        # No match.
                        return
                    # Advance the current progress.
                    plain_text_length = len(current_chain.__root__[0].text)
                    pattern_length = re_match_result.end(
                    ) - re_match_result.start()
                    if (pattern_length + 1) > plain_text_length:
                        # Advancing text_index could cause an error here;
                        # advance element_index instead.
                        reached_message_index = (reached_message_index[0] + 1,
                                                 None)
                    else:
                        # Advance the progress to just past the matched part.
                        reached_message_index = (
                            reached_message_index[0],
                            origin_or_zero(reached_message_index[1]) +
                            re_match_result.start() + pattern_length,
                        )
                else:
                    # An argument needs to be captured (greedy mode decides
                    # whether to search from the end backwards).
                    greed = matching_recevier.isGreed
                    for element_index, element in enumerate(
                            message_chain.subchain(
                                slice(reached_message_index, None,
                                      None)).__root__):
                        if isinstance(element, Plain):
                            current_text: str = element.text
                            # Apply the greedy decision.
                            text_find_result_list = list(
                                re.finditer(signature.operator(),
                                            current_text))
                            if not text_find_result_list:
                                continue
                            text_find_result = text_find_result_list[-int(greed
                                                                          )]
                            if not text_find_result:
                                continue
                            text_find_index = text_find_result.start()

                            # Found! Besides advancing the progress, record
                            # the end position of the argument being captured
                            # and clear it.
                            stop_index = (
                                reached_message_index[0] + element_index +
                                int(element_index == 0),
                                origin_or_zero(reached_message_index[1]) +
                                text_find_index,
                            )
                            match_result[matching_recevier] = (
                                copy.copy(start_index),
                                stop_index,
                            )

                            start_index = (0, None)
                            matching_recevier = None

                            pattern_length = (text_find_result.end() -
                                              text_find_result.start())
                            if (current_text == text_find_result.string[slice(
                                    *text_find_result.span())]):
                                # Advancing text_index here would blow up,
                                # so advance element_index rather than
                                # text_index.
                                reached_message_index = (
                                    reached_message_index[0] + element_index +
                                    int(element_index != 0),
                                    None,
                                )
                            else:
                                reached_message_index = (
                                    reached_message_index[0] + element_index,
                                    origin_or_zero(reached_message_index[1]) +
                                    text_find_index + pattern_length,
                                )
                            break
                    else:
                        # Searched everything without a match.
                        return
                latest_index = signature_index
        else:
            if matching_recevier:  # Reached the end, but there is still work to do.
                # Compute the end coordinates.
                text_index = None

                latest_element = message_chain.__root__[-1]
                if isinstance(latest_element, Plain):
                    text_index = len(latest_element.text)

                stop_index = (len(message_chain.__root__), text_index)
                match_result[matching_recevier] = (start_index, stop_index)
            else:
                # No further capture is needed, but the signature can no
                # longer describe the message's shape: the match is invalid.
                if reached_message_index < end_index:
                    return

        return match_result

    @staticmethod
    def detect_and_mapping(
        signature_chain: Tuple[Union[NormalMatch, PatternReceiver]],
        message_chain: MessageChain,
    ) -> Optional[Dict[Arguments, MessageChain]]:
        """Slice ``message_chain`` per argument group using ``detect_index``.

        Returns ``None`` (implicitly) when ``detect_index`` finds no match.
        """
        match_result = Kanata.detect_index(signature_chain, message_chain)
        if match_result is not None:
            return {
                k: message_chain[v[0]:(
                    v[1][0],
                    (v[1][1] - (origin_or_zero(v[0][1]) if
                                (v[1][0] <= v[0][0] <= v[1][0]) else 0)
                     ) if v[1][1] is not None else None,
                )]
                for k, v in match_result.items()
            }

    @staticmethod
    def allocation(
        mapping: Dict[Arguments, MessageChain]
    ) -> Optional[Dict[str, MessageChain]]:
        """Assign each captured chain to the first receiver of its argument group.

        Returns ``None`` when ``mapping`` is ``None`` or a required parameter
        captured an empty chain.

        Raises:
            ConflictItem: When a receiver name is already present in the result.
        """
        if mapping is None:
            return None
        result = {}
        for arguemnt_set, message_chain in mapping.items():
            length = len(arguemnt_set.content)
            for index, receiver in enumerate(arguemnt_set.content):
                if receiver.name in result:
                    raise ConflictItem(
                        "{0} is defined repeatedly".format(receiver))
                if isinstance(receiver, RequireParam):
                    if not message_chain.__root__:
                        return
                    result[receiver.name] = message_chain
                elif isinstance(receiver, OptionalParam):
                    if not message_chain.__root__:
                        result[receiver.name] = None
                    else:
                        result[receiver.name] = message_chain
                break  # Length matching has not been implemented yet...
        return result

    async def catch_argument_names(self) -> List[str]:
        """Return the names of every ``PatternReceiver`` in ``signature_list``.

        Not wrapped in ``lru_cache``: caching an ``async def`` stores the
        coroutine object, which can only be awaited once.
        """
        return [
            i.name for i in self.signature_list
            if isinstance(i, PatternReceiver)
        ]

    async def beforeDispatch(self, interface: DispatcherInterface):
        """Parse the incoming message and store the captured parameters.

        Raises:
            ExecutionStop: When the message contains a blocking element, or
                when nothing matched and ``stop_exec_if_fail`` is set.
        """
        message_chain: MessageChain = (await interface.lookup_param(
            "__kanata_messagechain__", MessageChain, None)).exclude(Source)
        if set([i.__class__ for i in message_chain.__root__
                ]).intersection(BLOCKING_ELEMENTS):
            raise ExecutionStop()
        if self.allow_quote and message_chain.has(Quote):
            # Automatically skip the first At after the Quote.
            # 0: Quote, 1: At, 2: Plain (a single space; later mirai versions
            # may handle it themselves, for now it is handled here).
            message_chain = message_chain[(3, None):]
            if self.skip_one_at_in_quote and message_chain.__root__:
                if message_chain.__root__[0].__class__ is At:
                    # MessageIndex makes this feature quick to implement.
                    message_chain = message_chain[(1, 1):]
        mapping_result = self.detect_and_mapping(self.signature_list,
                                                 message_chain)
        if mapping_result is not None:
            self.content_token = self.parsed_items.set(
                self.allocation(mapping_result))
        else:
            if self.stop_exec_if_fail:
                raise ExecutionStop()

    async def catch(self, interface: DispatcherInterface):
        """Return the captured value for ``interface.name`` wrapped in ``Force``.

        Returns ``None`` when nothing was parsed for this dispatch or the
        name was not captured.

        Raises:
            ExecutionStop: When the stored result is ``None`` and
                ``stop_exec_if_fail`` is set.
        """
        if not self.content_token:
            return
        # A fresh float object serves as the "not found" sentinel below.
        random_id = random.random()
        current_item = self.parsed_items.get()
        if current_item is not None:
            result = current_item.get(interface.name, random_id)
            return Force(result) if result is not random_id else None
        else:
            if self.stop_exec_if_fail:
                raise ExecutionStop()

    async def afterDispatch(
        self,
        interface: "DispatcherInterface",
        exception: Optional[Exception] = None,
        tb: Optional[TracebackType] = None,
    ):
        """Reset the parsed-items variable once dispatching has finished."""
        if self.content_token:
            self.parsed_items.reset(self.content_token)
            # NOTE(review): this stores the plain function (no bound self) on
            # the instance - confirm that is what callers of ``catch`` expect.
            self.catch = Kanata.catch
            self.content_token = None