示例#1
0
文件: utils.py 项目: SSmJaE/PepperBot
def meet_command_exit(chain: MessageChain, command_config: CommandConfig):
    """Tell whether the message text matches one of the command's exit patterns.

    Every entry of ``command_config.exit_patterns`` is tried against
    ``chain.pure_text`` with ``re.search``.

    Returns:
        True as soon as one pattern matches, False when none of them does.
    """
    exit_patterns = command_config.exit_patterns
    logger.debug(pformat(exit_patterns))

    for exit_pattern in exit_patterns:
        found = re.search(exit_pattern, chain.pure_text)
        logger.debug(pformat(found))
        if found:
            return True

    return False
示例#2
0
async def websocket_receiver(request, ws, protocol: T_BotProtocol):
    """Receive events from a websocket forever and dispatch each one.

    Each received message is decoded as JSON, logged at debug level and passed
    to ``handle_event``. Failures are logged and never stop the loop.

    NOTE(review): every ``Exception`` raised inside the loop, including one
    coming from ``ws.recv()``, is logged and the loop continues. Confirm that
    a closed connection ends this coroutine some other way.
    """
    while True:
        try:
            payload = json.loads(await ws.recv())

            logger.debug(pformat(payload))
            await handle_event(protocol, payload)

        except EventHandleError as handle_error:
            # Expected routing failures: message only, no traceback.
            logger.error(handle_error)

        except Exception:
            logger.exception("事件处理异常")
示例#3
0
文件: utils.py 项目: SSmJaE/PepperBot
def meet_text_prefix(
    chain: MessageChain,
    command_name: str,
    command_config: CommandConfig,
) -> Tuple[bool, str]:
    """Check whether the message text starts with a ``prefix + alias`` of the command.

    The candidate names are the configured aliases, plus ``command_name`` when
    ``include_class_name`` is set. The candidate prefixes are the configured
    prefixes when ``need_prefix`` is set, otherwise just the empty string.

    The concatenated ``prefix + alias`` is interpolated into the regular
    expression as-is (it is not escaped).

    Returns:
        ``(True, matched_prefix_plus_alias)`` for the first combination that
        matches at the start of ``chain.pure_text``, otherwise ``(False, "")``.
    """
    candidate_names: Set[str] = set(command_config.aliases)
    if command_config.include_class_name:
        candidate_names.add(command_name)

    # The single empty prefix keeps the inner loop running at least once
    # when no prefix is required.
    candidate_prefixes: Iterable[str] = (
        set(command_config.prefixes) if command_config.need_prefix else [""]
    )

    for debug_value in (command_name, candidate_prefixes, candidate_names):
        logger.debug(pformat(debug_value))

    for name in candidate_names:
        for lead in candidate_prefixes:
            combined = lead + name
            logger.debug(pformat(combined))
            if re.search(f"^{combined}", chain.pure_text):
                return True, combined

    return False, ""
示例#4
0
async def http_receiver(request, protocol: T_BotProtocol):
    """Handle one event delivered over HTTP and answer with an empty text response.

    The JSON body of ``request`` is logged at debug level and passed to
    ``handle_event``. Any ``Exception`` raised while doing so is logged and
    swallowed, so the caller still receives the empty response.

    Returns:
        The result of ``text("")``.
    """
    try:
        raw_event = request.json

        logger.debug(pformat(raw_event))
        await handle_event(protocol, raw_event)

    except EventHandleError as e:
        # Expected routing failures: message only, no traceback.
        logger.error(e)

    except Exception:
        logger.exception("事件处理异常")

    # Returned after the try statement rather than from a ``finally`` clause:
    # a ``return`` inside ``finally`` would also discard BaseException
    # subclasses such as asyncio.CancelledError or KeyboardInterrupt.
    return text("")
示例#5
0
def merge_text_of_segments(
        segments: List[T_SegmentInstance]) -> T_CompressedSegments:
    """Merge every run of adjacent Text segments into a single segment.

    Consecutive ``Text`` instances are collected and combined through
    ``merge_multi_text_with_space`` (space separated, which makes matching
    with regular expressions easier). Non-Text segments are kept as they are
    and in their original position.

    Args:
        segments: The message segments, in order.

    Returns:
        ``segments`` itself when it holds at most one element, otherwise a new
        list with each run of Text segments replaced by its merged segment.
    """
    logger.debug(pformat(segments))

    if len(segments) <= 1:
        return segments

    compressed_segments: T_CompressedSegments = []
    text_buffer: List[Text] = []

    for segment in segments:
        # isinstance alone decides; comparing the previous segment's exact
        # class against Text would lose the buffer for Text subclasses.
        if isinstance(segment, Text):
            text_buffer.append(segment)
            continue

        # A non-Text segment ends the current run of texts, if there is one.
        if text_buffer:
            compressed_segments.append(
                merge_multi_text_with_space(*text_buffer))
            text_buffer = []

        compressed_segments.append(segment)

    # Flush a run of texts that reaches the end of the message.
    if text_buffer:
        compressed_segments.append(merge_multi_text_with_space(*text_buffer))

    return compressed_segments
示例#6
0
async def run_class_handlers(
    protocol: T_BotProtocol,
    mode: T_RouteMode,
    source_id: str,
    raw_event: Dict,
    event_name: str,
    class_handler_names: Set[str],
):
    """Run the ``event_name`` handler of every listed class handler that defines one.

    The keyword arguments are built once through ``get_kwargs``. For each
    handler they are narrowed with ``fit_kwargs`` before the call, and the
    handler is invoked through ``await_or_sync``.
    """
    available_kwargs = await get_kwargs(
        protocol, mode, source_id, event_name, raw_event
    )
    logger.debug(pformat(available_kwargs))

    for handler_name in class_handler_names:
        cached = class_handler_mapping[handler_name]
        logger.debug(pformat(cached))

        callback = cached.event_handlers.get(event_name)
        if not callback:
            # This class handler does not respond to the event.
            continue

        logger.info(f"开始执行 {handler_name} 的 {event_name} 事件响应")
        await await_or_sync(callback, **fit_kwargs(callback, available_kwargs))
示例#7
0
async def run_command_method(method_name, method, all_locals: Dict) -> Any:
    """Call a command method with the arguments it is able to accept.

    The base arguments (``raw_event``, ``chain``, ``sender``, ``history``) are
    read from ``all_locals``. The method named ``"cache"`` additionally gets
    ``exception``. Any other method whose name is not one of the values of
    ``COMMAND_LIFECYCLE_EXCEPTIONS`` gets the parsed ``patterns`` merged in.

    NOTE(review): only the name ``"cache"`` receives the exception; confirm
    that this spelling is the intended one.

    Returns:
        Whatever the method returns, obtained through ``await_or_sync``.
    """
    logger.debug(pformat(all_locals))

    injected = {
        "raw_event": all_locals["raw_event"],
        "chain": all_locals["chain"],
        "sender": all_locals["sender"],
        "history": all_locals["status"].history,
    }

    if method_name == "cache":
        injected["exception"] = all_locals["exception"]
    elif method_name not in COMMAND_LIFECYCLE_EXCEPTIONS.values():
        injected = {**injected, **all_locals["patterns"]}

    # Same arguments as the group_message / friend_message event handlers.
    logger.debug(f"将被注入 {method_name} 的参数\n{pformat(injected)}")
    outcome = await await_or_sync(method, **fit_kwargs(method, injected))
    logger.debug(pformat(outcome))
    return outcome
示例#8
0
async def parse_pattern(
    chain: MessageChain,
    sender: "CommandSender",
    method_name,
    cache: CommandMethodCache,
    prefix: str,
    context,
):
    """
    Parse the message into the arguments declared through PatternArg in the
    method signature, so that they can be injected into the method call.

    When the message satisfies the patterns the parsed results are returned.
    When it does not, the user is told the expected format and the error is
    re-raised, which intercepts the method call.

    Args:
        chain: The incoming message; its segments are merged before parsing.
        sender: Used to send the format hint back when parsing fails.
        method_name: Name of the command method; ``"initial"`` additionally
            gets ``prefix`` stripped from its leading text segment.
        cache: Holds ``patterns`` and ``compressed_patterns`` of the method.
        prefix: The matched prefix, used as a regular expression anchored at
            the start of the first text segment.
        context: Unused by this function.

    Returns:
        ``{}`` when the method declares no patterns, otherwise the value
        returned by ``get_pattern_results``.

    Raises:
        PatternFormotError: When the number of merged segments differs from
            the number of compressed patterns, when an ``initial`` message
            does not start with a Text segment, or when it is raised while
            computing the results.
    """

    if not cache.compressed_patterns:
        return {}

    compressed_patterns = cache.compressed_patterns

    # todo does pydantic offer this natively
    # todo List (unpacked), Any, Union, List[Union/Any]

    formot_hint = (
        "请按照 "
        + "".join(
            f"<{arg_name} : {arg_type.__name__}> "
            for arg_name, arg_type in cache.patterns
        )
        + "的格式输入\n不需要输入<或者>,:右侧是该参数的类型"
    )

    try:
        compressed_segments = merge_text_of_segments(chain.segments)
        logger.debug(pformat(compressed_segments))

        if len(compressed_segments) != len(compressed_patterns):
            raise PatternFormotError(f"未提供足够参数,应为{len(cache.patterns)}个," +
                                     f"获得{len(chain.segments)}个")

        # Patterns applied to `initial` support a prefix; only text prefixes
        # are supported for now.
        if method_name == "initial":
            if not isinstance(compressed_segments[0], Text):
                # Must be raised, not returned: returning would hand the
                # exception object to the caller as if it were the results
                # and skip the error reply below.
                raise PatternFormotError("目前仅支持文字前缀")

            # Note: this rewrites the content of the first segment in place.
            with_prefix = compressed_segments[0].content
            without_prefix = re.sub(f"^{prefix}", "", with_prefix)
            compressed_segments[0].content = without_prefix

        results = get_pattern_results(compressed_patterns, compressed_segments)

    except PatternFormotError as e:
        # todo let a user-supplied on_format_error callback build the reply
        await sender.send_message(Text(f"{e}\n{formot_hint}"))

        logger.exception("指令解析失败")
        raise

    # todo maxSize of the pattern results
    logger.debug(pformat(results))
    return results
示例#9
0
async def handle_event(protocol: T_BotProtocol, raw_event: Dict):
    """Route one raw protocol event to the registered commands and class handlers.

    Steps, in order:
    1. On the first event only, fetch the bot meta info.
    2. Resolve the event name through the protocol adapter.
    3. Classify the event as group or private and resolve its source id.
    4. When the event can trigger commands, run every command registered
       globally, for this source, or through validators.
    5. Run the class handlers for the universal event name (when the event
       has one) and then for the protocol-specific event name.

    Raises:
        EventHandleError: When the event is neither a known group event nor a
            known private event.
    """
    logger.info("*" * 50)

    # The bot meta info is fetched lazily, once.
    if not route_mapping.has_initial:
        await initial_bot_info()
        logger.info("成功获取bot元信息")
        route_mapping.has_initial = True

    raw_event_name = get_adapter(protocol).get_event_name(raw_event)
    protocol_event_name: str = f"{protocol}_" + raw_event_name

    if protocol == "onebot":
        if (
            not onebot_event_meta.has_skip_buffered_event
            and not skip_current_onebot_event(raw_event, raw_event_name)
        ):
            return

        logger.info("onebot 心跳")

    logger.info(f"{protocol}事件 {raw_event_name}")

    mode: T_RouteMode
    command_trigger_events: List[str]

    if protocol_event_name in ALL_GROUP_EVENTS:
        mode, command_trigger_events = "group", GROUP_COMMAND_TRIGGER_EVENTS
    elif protocol_event_name in ALL_PRIVATE_EVENTS:
        mode, command_trigger_events = "private", PRIVATE_COMMAND_TRIGGER_EVENTS
    else:
        raise EventHandleError(f"无效/尚未实现的事件{protocol_event_name}")

    source_id: str = get_source_id(protocol, mode, raw_event)

    validator_handlers, validator_commands = await with_validators(mode, source_id)

    if protocol_event_name in command_trigger_events:
        # A set, because a command registered more than once (global_commands,
        # mapping, validators) should run only once per event for the same
        # message object.
        class_command_names: Set[str] = set()
        class_command_names |= route_mapping.global_commands[protocol][mode]
        class_command_names |= route_mapping.mapping[protocol][mode][source_id][
            "commands"
        ]
        class_command_names |= validator_commands

        if class_command_names:
            await run_class_commands(
                protocol, mode, source_id, raw_event, class_command_names
            )

        # todo lock_user: run the commands of every lock_user rule whose
        #      protocol entry contains source_id
        # todo lock_source (for the "group" / "channel" modes)

    # A set, because the same class handler should be called only once for
    # the same message source.
    class_handler_names: Set[str] = set()
    class_handler_names |= route_mapping.global_handlers[protocol][mode]
    class_handler_names |= route_mapping.mapping[protocol][mode][source_id][
        "class_handlers"
    ]
    class_handler_names |= validator_handlers

    logger.debug(pformat(class_handler_names))

    # Events such as group_message implement a universal event: when the user
    # defines both group_message and onebot_group_message, both should run.
    universal_events = UNIVERSAL_PROTOCOL_EVENT_MAPPING.get(raw_event_name)
    if universal_events and protocol_event_name in universal_events:
        await run_class_handlers(
            protocol, mode, source_id, raw_event, raw_event_name, class_handler_names
        )

    await run_class_handlers(
        protocol, mode, source_id, raw_event, protocol_event_name, class_handler_names
    )