Ejemplo n.º 1
0
def get_counter(variable: ContextVar, fn: Callable):
    """Yield a ``Counter`` installed in ``variable`` for the generator's lifetime.

    A child logger named after ``fn.__module__`` / ``fn.__name__`` is set to
    DEBUG and handed to the counter. ``variable`` is restored to its previous
    value when the generator finishes, is closed, or has an exception thrown
    into it (presumably wrapped with ``contextlib.contextmanager`` elsewhere —
    TODO confirm).

    Args:
        variable: context variable that holds the active counter.
        fn: callable whose module and name identify the logger.

    Yields:
        The newly created counter.
    """
    logger = logging.getLogger(fn.__module__).getChild(fn.__name__)
    logger.setLevel(logging.DEBUG)
    counter = Counter(logger=logger)
    token = variable.set(counter)
    try:
        yield counter
    finally:
        # Reset even when the consumer's block raises; otherwise the counter
        # would stay installed in this context.
        variable.reset(token)
Ejemplo n.º 2
0
def _tristate_armed(context_var: ContextVar, enabled=True):
    """Assumes "enabled" if `enabled` flag is None."""
    resetter = context_var.set(enabled if enabled is None else bool(enabled))
    try:
        yield
    finally:
        context_var.reset(resetter)
Ejemplo n.º 3
0
    def test_make_target(self):
        """``make_target`` passes context variables into the new thread only
        when ``_config.PASS_CONTEXTVARS`` is enabled."""
        from contextvars import ContextVar

        # Context setup.
        foo = ContextVar("foo")
        token = foo.set("foo")
        # Sentinel returned by ``foo.get`` when ``foo`` is unset in the thread.
        missing = object()
        try:
            # ``_config.PASS_CONTEXTVARS = False`` (the default): the worker
            # thread must not see ``foo``.
            assert _config.PASS_CONTEXTVARS is False
            seen_without_context = []

            def target_without_context():
                seen_without_context.append(foo.get(missing))

            thread_without_context = Thread(
                target=make_target(target_without_context))
            thread_without_context.start()
            thread_without_context.join()
            # Assert in the main thread: an exception raised inside a worker
            # thread only goes to ``threading.excepthook`` and does not fail
            # the test. Checking the list also proves the target ran.
            assert seen_without_context == [missing]

            # ``_config.PASS_CONTEXTVARS = True``: the caller's value of
            # ``foo`` must be visible inside the worker thread.
            _config.PASS_CONTEXTVARS = True
            try:
                seen_with_context = []

                def target_with_context():
                    seen_with_context.append(foo.get(missing))

                thread_with_context = Thread(
                    target=make_target(target_with_context))
                thread_with_context.start()
                thread_with_context.join()
                assert seen_with_context == ["foo"]
            finally:
                # Restore the global default even if an assertion fails, so
                # other tests are not affected.
                _config.PASS_CONTEXTVARS = False
        finally:
            # Context teardown.
            foo.reset(token)
class ContextVarsRuntimeContext(_RuntimeContext):
    """RuntimeContext implementation backed by a single ``ContextVar``.

    This is the preferred implementation on Python 3.7+, where the
    ``contextvars`` module is available.
    """

    _CONTEXT_KEY = "current_context"

    def __init__(self) -> None:
        empty_context = Context()
        self._current_context = ContextVar(
            self._CONTEXT_KEY, default=empty_context
        )

    def attach(self, context: Context) -> object:
        """Make ``context`` the current `Context`.

        Args:
            context: The Context to set.

        Returns:
            A token to pass to ``detach`` to restore the previous `Context`.
        """
        token = self._current_context.set(context)
        return token

    def get_current(self) -> Context:
        """Return the current `Context` (the empty default if none is attached)."""
        current = self._current_context.get()
        return current

    def detach(self, token: object) -> None:
        """Restore the `Context` that was current before the matching ``attach``.

        Args:
            token: The value returned by ``attach``.
        """
        self._current_context.reset(token)  # type: ignore
Ejemplo n.º 5
0
class Urls:
    """Context-local wrapper around a werkzeug-style URL ``Map``.

    The active map adapter lives in a ``ContextVar``, so ``bind`` /
    ``bind_to_environ`` can switch it temporarily. ``build`` and ``match``
    also support a static "file export" mode (adapter bound with
    ``url_scheme == 'file'`` and ``server_name == '.'``). In that mode URLs
    are relative to the current page and directory URLs map to
    ``.../index.html``.
    """

    def __init__(self, rules):
        self.map = Map(rules)
        # Adapter used when no ``bind`` is active in the current context.
        self.urls = ContextVar("urls", default=self.map.bind(''))

    def iter_rules(self):
        """Return an iterator over the rules of the underlying map."""
        return self.map.iter_rules()

    @contextmanager
    def _bind(self, urls):
        """Make ``urls`` the active adapter until the ``with`` block exits."""
        token = self.urls.set(urls)
        try:
            yield
        finally:
            self.urls.reset(token)

    def bind(self, *args, **kwargs):
        """Context manager that activates ``self.map.bind(*args, **kwargs)``."""
        return self._bind(self.map.bind(*args, **kwargs))

    def bind_to_environ(self, *args, **kwargs):
        """Context manager that activates ``self.map.bind_to_environ(...)``."""
        return self._bind(self.map.bind_to_environ(*args, **kwargs))

    def build(self,
              endpoint,
              values=None,
              method=None,
              force_external=False,
              append_unknown=True):
        """Build a URL for ``endpoint`` with the active adapter.

        In file-export mode the result is relative to the directory of the
        current ``path_info``, and URLs ending in ``/`` get ``index.html``
        appended.
        """
        urls = self.urls.get()
        path = urls.build(endpoint, values, method, force_external,
                          append_unknown)
        if urls.url_scheme == 'file' and urls.server_name == '.':
            assert not force_external
            if path.endswith('/'):
                path = path + 'index.html'
            return os.path.relpath(path, os.path.dirname(urls.path_info))
        return path

    def match(self, path=None, return_rule=False):
        """Match ``path`` (or the bound ``path_info`` when ``path`` is None).

        In file-export mode ``path`` is resolved against the directory of the
        current ``path_info``, and a trailing ``index.html`` is removed, which
        inverts ``build``. Otherwise the adapter's ``script_name`` prefix is
        removed before matching.
        """
        urls = self.urls.get()
        if path is None:
            return urls.match(return_rule=return_rule)
        if urls.url_scheme == 'file' and urls.server_name == '.':
            path = os.path.normpath(
                os.path.join(os.path.dirname(urls.path_info), path))
            if path.endswith("/index.html"):
                # Remove only "index.html" and keep the slash, giving back the
                # directory URL that ``build`` started from (e.g. "/a/b/").
                # Removing "/index.html" would give "/a/b", which does not match
                # a "/a/b/" rule directly (werkzeug's strict_slashes answers
                # it with a redirect).
                path = path[:-len("index.html")]
        else:
            script_name = urls.script_name.rstrip("/")
            assert path.startswith(script_name)
            path = path[len(script_name):]

        return urls.match(path, return_rule=return_rule)

    def dispatch(self, *args, **kwargs):
        """Delegate to the active adapter's ``dispatch``."""
        return self.urls.get().dispatch(*args, **kwargs)

    def current_url(self):
        """Rebuild the URL of the current request from its endpoint and values."""
        return self.dispatch(lambda e, v: self.build(e, v))
Ejemplo n.º 6
0
class ContextService:
    """Stores a dict of key/value pairs in a ``ContextVar``.

    Every change stores a new dict (the current one is copied before
    writing), so contexts copied earlier never see later changes. ``undo``
    reverts the most recent change made through this service.
    """

    def __init__(self):
        self._context = ContextVar(f"{self}", default={})
        # NOTE(review): this stack is per instance, not per context. ``undo``
        # may therefore pop a token created in a different context (e.g.
        # another task), and ``ContextVar.reset`` then raises ValueError.
        self._undo_tokens = deque()

    def set_context_variable(self, var_name, var_value):
        """Set ``var_name`` to ``var_value`` in the current context."""
        context = self._get_context().copy()
        context[var_name] = var_value
        self._set_context(context)

    def get_context_variable(self, var_name):
        """Return the value of ``var_name``, or ``None`` if it is not set."""
        return self._get_context().get(var_name)

    def undo(self):
        """Revert the most recent change; do nothing if there is none."""
        undo_token = self._pop_undo_token()
        if undo_token is None:
            return

        self._context.reset(undo_token)

    async def run_in_context(self, func, *args, **kwargs):
        """Await ``func(*args, **kwargs)`` in a copy of the current context.

        Context-variable changes made by ``func`` are not visible to the
        caller afterwards. ``func`` must return an awaitable.

        Raises:
            BbpypValueError: if ``func`` is not callable.
        """
        import asyncio

        if not callable(func):
            raise BbpypValueError("func", func, "must be callable")

        # ``Context.run`` only covers the synchronous call. A coroutine's body
        # runs later, when awaited, in the awaiting task's context. Wrapping it
        # in a task created while the copy is current makes the whole body run
        # in (a copy of) that context.
        def _start():
            return asyncio.ensure_future(func(*args, **kwargs))

        await copy_context().run(_start)

    def get_context_key_value_pairs(self):
        """Return the dict of key/value pairs stored by this service."""
        # Read this service's own variable rather than guessing which
        # contextvar in ``copy_context()`` holds a dict. Guessing could return
        # another variable's value, or raise IndexError on an empty context.
        return self._get_context()

    def set_context_key_value_pairs(self, **kwargs):
        """Set every keyword argument as a context variable (one undo step each)."""
        for key, value in kwargs.items():
            self.set_context_variable(key, value)

    def _get_context(self):
        # An explicit default returns a fresh empty dict when the variable is
        # unset in this context.
        return self._context.get({})

    def _set_context(self, context_value):
        undo_token = self._context.set(context_value)
        self._push_undo_token(undo_token)

    def _pop_undo_token(self):
        return self._undo_tokens.pop() if len(self._undo_tokens) > 0 else None

    def _push_undo_token(self, undo_token):
        self._undo_tokens.append(undo_token)
class ContextVarsRuntimeContext(RuntimeContext):
    """``RuntimeContext`` built on one ``ContextVar`` holding the current `Context`.

    Preferred on Python 3.7+, where the ``contextvars`` module is available.
    """

    _CONTEXT_KEY = "current_context"

    def __init__(self) -> None:
        initial = Context()
        self._current_context = ContextVar(self._CONTEXT_KEY, default=initial)

    def attach(self, context: Context) -> object:
        """Set ``context`` as current and return the token for ``detach``.

        See `opentelemetry.context.RuntimeContext.attach`.
        """
        reset_token = self._current_context.set(context)
        return reset_token

    def get_current(self) -> Context:
        """Return the current `Context`.

        See `opentelemetry.context.RuntimeContext.get_current`.
        """
        active = self._current_context.get()
        return active

    def detach(self, token: object) -> None:
        """Restore the `Context` that ``token`` replaced.

        See `opentelemetry.context.RuntimeContext.detach`.
        """
        self._current_context.reset(token)  # type: ignore
Ejemplo n.º 8
0
class ContextManager:
    """A named, context-local value with a stack-like ``enter``/``leave`` API.

    The ``ContextVar`` holds a ``[token, value]`` pair. ``token`` is the one
    returned when that pair was set (``None`` for the default), so ``leave``
    can restore whatever was current before the matching ``enter``.

    Usage::

        with manager.enter(value):
            manager.get()  # -> value
    """

    def __init__(self, name, namespace='', default=None):
        # Variable name is "namespace.name", or just "name" without a namespace.
        self.context = ContextVar('.'.join(([namespace] if namespace else []) +
                                           [name]),
                                  default=[None, default])
        self.default = default

    @property
    def name(self):
        """The underlying ``ContextVar``'s name."""
        return self.context.name

    def enter(self, value):
        """Make ``value`` current; returns ``self`` so it can be used in ``with``."""
        values = [None, value]
        token = self.context.set(values)
        # Keep the token next to the value so ``leave`` can find it.
        values[0] = token
        return self

    def leave(self):
        """Restore the value that was current before the latest ``enter``."""
        token, _ = self.context.get()
        self.context.reset(token)

    def get(self):
        """Return the current value."""
        return self.context.get()[-1]

    __call__ = get

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.leave()

    def __getitem__(self, item):
        """Index into (or look up a key of) the current value."""
        _, value = self.context.get()
        return value[item]

    def __getattr__(self, item):
        """Expose keys of the current value as attributes (``manager.key``).

        Python calls this only for names not found normally. Dunder names and
        ``context`` are refused with ``AttributeError``. Otherwise protocol
        probes such as ``hasattr(obj, '__setstate__')`` made by ``copy`` and
        ``pickle`` would be sent to the value. On an instance without
        ``self.context`` (e.g. one created by ``copy`` before its state is
        restored), that lookup would recurse back into this method forever.
        """
        if item == 'context' or (item.startswith('__') and item.endswith('__')):
            raise AttributeError(item)
        return self[item]
Ejemplo n.º 9
0
class Ctx(Generic[T]):
    """Typed convenience wrapper around a single ``ContextVar``."""

    current_ctx: ContextVar[T]

    def __init__(self, name: str) -> None:
        self.current_ctx = ContextVar(name)

    def get(self, default: Union[D, T] = None) -> Union[T, D]:
        """Return the current value, or ``default`` (``None`` unless given) if unset."""
        return self.current_ctx.get(default)

    def set(self, value: T):
        """Set the value and return the ``Token`` that ``reset`` needs."""
        return self.current_ctx.set(value)

    def reset(self, token: Token):
        """Restore the value that was current before ``set`` returned ``token``."""
        return self.current_ctx.reset(token)

    @contextmanager
    def use(self, value: T):
        """Make ``value`` current for the duration of a ``with`` block.

        The previous value is restored even when the block raises.
        """
        token = self.set(value)
        try:
            yield
        finally:
            self.reset(token)
Ejemplo n.º 10
0
class Kanata(BaseDispatcher):
    """Kanata ("beyond"): a dispatcher that parses listener arguments from a message.

    The incoming message chain is matched against ``signature_list``. What
    each ``PatternReceiver`` captures is supplied to the listener parameter
    of the same name.
    """

    always = True  # compatibility with the refactored bcc.

    signature_list: List[Union[NormalMatch, PatternReceiver]]
    stop_exec_if_fail: bool = True

    # Parse result for the current context: argument name -> captured chain.
    parsed_items: ContextVar[Dict[str, MessageChain]]

    def __init__(self,
                 signature_list: List[Union[NormalMatch, PatternReceiver]],
                 stop_exec_if_fail: bool = True) -> None:
        """Create the argument parser.

        Args:
            signature_list (List[Union[NormalMatch, PatternReceiver]]): the chain of match signatures.
            stop_exec_if_fail (bool, optional): whether to stop the listener from running when nothing matches. Defaults to True.
        """
        self.signature_list = signature_list
        self.stop_exec_if_fail = stop_exec_if_fail
        self.parsed_items = ContextVar("kanata_parsed_items")

    @staticmethod
    def detect_index(
        signature_chain: Tuple[Union[NormalMatch, PatternReceiver]],
        message_chain: MessageChain
    ) -> Optional[Dict[str, Tuple[MessageIndex, MessageIndex]]]:
        """Locate the span of ``message_chain`` captured by each argument group.

        Returns a dict mapping each ``Arguments`` group to ``(start, stop)``
        message indexes (start inclusive, stop exclusive). A message index is
        ``(element_index, text_index_or_None)``. Returns ``None`` when the
        message does not match the signature chain.
        """
        merged_chain = merge_signature_chain(signature_chain)
        message_chain = message_chain.asMerged()

        reached_message_index: MessageIndex = (0, None)
        # [0] => real_index
        # [1] => text_index(optional)

        start_index: MessageIndex = (0, None)

        match_result: Dict[Arguments, Tuple[MessageIndex,  # start(include)
                                            MessageIndex  # stop(exclude)
                                            ]] = {}

        signature_iterable = InsertGenerator(enumerate(merged_chain))
        latest_index = None
        matching_recevier: Optional[Arguments] = None

        for signature_index, signature in signature_iterable:
            if isinstance(signature, (Arguments, PatternReceiver)):
                if matching_recevier:  # a receiver group is already being collected...
                    if isinstance(signature, Arguments):
                        if latest_index == signature_index:
                            matching_recevier.content.extend(signature.content)
                            continue
                        else:
                            raise TypeError(
                                "a unexpected case: match conflict")
                    if isinstance(signature, PatternReceiver):
                        matching_recevier.content.append(signature)
                        continue
                else:
                    if isinstance(signature, PatternReceiver):
                        signature = Arguments([signature])
                matching_recevier = signature
                start_index = reached_message_index
            elif isinstance(signature, NormalMatch):
                if not matching_recevier:
                    # No argument pending: match the FullMatch starting at the current position (reached_message_index).
                    current_chain = message_chain.subchain(
                        slice(reached_message_index, None, None))
                    if not current_chain.__root__:  # index out of range
                        return
                    if not isinstance(current_chain.__root__[0], Plain):
                        # The first element after slicing is **not** Plain.
                        return
                    re_match_result = re.match(signature.operator(),
                                               current_chain.__root__[0].text)
                    if not re_match_result:
                        # No match.
                        return
                    # Advance the current position.
                    plain_text_length = len(current_chain.__root__[0].text)
                    pattern_length = re_match_result.end(
                    ) - re_match_result.start()
                    if (pattern_length + 1) > plain_text_length:
                        # Don't advance text_index; advance element_index instead.
                        reached_message_index = (reached_message_index[0] + 1,
                                                 None)
                    else:
                        # Advance text_index to just past the matched text.
                        reached_message_index = (
                            reached_message_index[0],
                            origin_or_zero(reached_message_index[1]) +
                            re_match_result.start() + pattern_length)
                else:
                    # An argument is pending (greedy mode takes the last match, i.e. searches back to front).
                    greed = matching_recevier.isGreed
                    for element_index, element in \
                            enumerate(message_chain.subchain(slice(reached_message_index, None, None)).__root__):
                        if isinstance(element, Plain):
                            current_text: str = element.text
                            # Apply greediness: greedy -> last match ([-1]), otherwise first ([0]).
                            text_find_result_list = list(
                                re.finditer(signature.operator(),
                                            current_text))
                            if not text_find_result_list:
                                continue
                            text_find_result = text_find_result_list[-int(greed
                                                                          )]
                            if not text_find_result:
                                continue
                            text_find_index = text_find_result.start()

                            # Found! Advance the position, and also record the end
                            # position of the pending argument and clear it.
                            stop_index = (
                                reached_message_index[0] + element_index +
                                int(element_index == 0),
                                origin_or_zero(reached_message_index[1]) +
                                text_find_index)
                            match_result[matching_recevier] = (
                                copy.copy(start_index), stop_index)

                            start_index = (0, None)
                            matching_recevier = None

                            pattern_length = text_find_result.end(
                            ) - text_find_result.start()
                            if current_text == text_find_result.string[slice(
                                    *text_find_result.span())]:
                                # Advancing text_index here would run past the element....
                                # so advance element_index instead of text_index.
                                reached_message_index = (
                                    reached_message_index[0] + element_index +
                                    int(element_index != 0), None)
                            else:
                                reached_message_index = (
                                    reached_message_index[0] + element_index,
                                    origin_or_zero(reached_message_index[1]) +
                                    text_find_index + pattern_length)
                            break
                    else:
                        # Searched every element without a match.
                        return
            latest_index = signature_index
        else:
            if matching_recevier:  # Reached the end with a receiver still pending.
                # Compute the end coordinates.
                text_index = None

                latest_element = message_chain.__root__[-1]
                if isinstance(latest_element, Plain):
                    text_index = len(latest_element.text)

                stop_index = (len(message_chain.__root__), text_index)
                match_result[matching_recevier] = (start_index, stop_index)
        return match_result

    @staticmethod
    def detect_and_mapping(
        signature_chain: Tuple[Union[NormalMatch, PatternReceiver]],
        message_chain: MessageChain
    ) -> Optional[Dict[Arguments, MessageChain]]:
        """Slice ``message_chain`` by the spans from ``detect_index``.

        Returns ``None`` when the message does not match.
        """
        match_result = Kanata.detect_index(signature_chain, message_chain)
        if match_result is not None:
            # When start and stop are in the same or adjacent elements, the stop
            # text index is made relative to the start text index.
            return {
                k: message_chain[v[0]:(v[1][0], (
                    v[1][1] - (origin_or_zero(v[0][1]) if any([
                        v[0][0] + 1 == v[1][0], v[0][0] == v[1][0], v[0][0] -
                        1 == v[1][0]
                    ]) else 0)) if v[1][1] is not None else None)]
                for k, v in match_result.items()
            }

    @staticmethod
    def allocation(
        mapping: Dict[Arguments, MessageChain]
    ) -> Optional[Dict[str, MessageChain]]:
        """Map receiver names to their captured chains.

        Returns ``None`` if ``mapping`` is None or a ``RequireParam`` captured
        an empty chain. An empty ``OptionalParam`` maps to ``None``.

        Raises:
            ConflictItem: if a receiver name occurs more than once.
        """
        if mapping is None:
            return None
        result = {}
        for argument_set, message_chain in mapping.items():
            for receiver in argument_set.content:
                if receiver.name in result:
                    raise ConflictItem(
                        '{0} is defined repeatedly'.format(receiver))
                if isinstance(receiver, RequireParam):
                    if not message_chain.__root__:
                        return
                    result[receiver.name] = message_chain
                elif isinstance(receiver, OptionalParam):
                    if not message_chain.__root__:
                        result[receiver.name] = None
                    else:
                        result[receiver.name] = message_chain
                # Only the first receiver of each group is used; length
                # matching across several receivers is not implemented yet.
                break
        return result

    def catch_argument_names(self) -> List[str]:
        """Return the names of the ``PatternReceiver``s in ``signature_list``."""
        # Previously decorated with ``@lru_cache(None)``, which kept every
        # instance alive in a process-wide cache. The list is cheap to rebuild.
        return [
            i.name for i in self.signature_list
            if isinstance(i, PatternReceiver)
        ]

    async def catch(self, interface: DispatcherInterface):
        """Supply the parsed argument named ``interface.name``, if there is one.

        This is an async generator. The first ``yield`` hands the value (or
        nothing) to the dispatcher machinery. Resuming or closing the
        generator afterwards resets ``self.parsed_items`` if this call was
        the one that set it.

        Raises:
            ExecutionStop: when the message does not match (or a required
                parameter captured nothing) and ``stop_exec_if_fail`` is set.
        """
        token = None
        if self.parsed_items.get(None) is None:
            # Parse once per context; later calls in the same context reuse it.
            message_chain = (await interface.execute_with(
                "__kanata_messagechain_origin__", MessageChain,
                None)).exclude(Source, Quote, Xml, Json, App, Poke)
            mapping_result = self.detect_and_mapping(self.signature_list,
                                                     message_chain)
            # ``allocation`` returns None when a RequireParam captured
            # nothing. Storing None would break the ``.get`` below with an
            # AttributeError, so treat it like a failed match.
            allocated = (self.allocation(mapping_result)
                         if mapping_result is not None else None)
            if allocated is not None:
                token = self.parsed_items.set(allocated)
            elif self.stop_exec_if_fail:
                raise ExecutionStop()

        missing = object()
        try:
            result = self.parsed_items.get({}).get(interface.name, missing)
            if result is missing:
                # Skip. (The Executor should also guard against unpredictable
                # behavior of ``default``.)
                yield
            else:
                yield Force(result)
        finally:
            # Reset in ``finally`` so the parse result also goes away when the
            # generator is closed instead of resumed.
            if token is not None:
                self.parsed_items.reset(token)
Ejemplo n.º 11
0
async def limited_gather(iteration_coroutine: Callable[[Any], Any],
                         iterable: Iterable[Any],
                         result_callback: Callable[[Task, Any], Any],
                         num_concurrent: int = 5) -> None:
    """
    Run coroutines concurrently with a maximum limit.

    A task running ``iteration_coroutine(value)`` is started for each value in
    ``iterable``, with at most ``num_concurrent`` running at once. When a task
    finishes, ``result_callback(task, value)`` is called. If it returns a
    coroutine, that coroutine is run as a task too. The function returns once
    every result callback (and every coroutine one returned) has finished.
    Exceptions raised by result callbacks are logged, not propagated.

    :param iteration_coroutine: A coroutine to call with the values in the iterable.
    :param iterable: An iterable with values to pass to the coroutine
    :param result_callback: A callback which to call with the result of a call to the coroutine.
    :param num_concurrent: The maximum number of concurrent invocations of the coroutine.
    :return: None
    """

    if not iterable:
        return

    limiting_semaphore = Semaphore(num_concurrent)
    all_finished_event = Event()

    num_started = 0
    num_finished = 0

    # The event loop keeps only weak references to tasks. Hold strong ones
    # until each task is done so none can be garbage-collected mid-run.
    running_tasks = set()

    # Passes each task's iteration value to its done callback, through the
    # context copied when the callback is registered.
    passed_iteration_value_context_var = ContextVar('passed_iteration_value')

    def signal_callback_finished(*_, **__):
        nonlocal num_finished
        num_finished += 1
        all_finished_event.set()

    def callback_task_done(callback_task: Task) -> None:
        running_tasks.discard(callback_task)
        # Retrieve and log the exception (as is done for synchronous callback
        # errors) so it is not reported as "never retrieved".
        # NOTE(review): assumes LOG is a stdlib-compatible logger.
        if not callback_task.cancelled() and callback_task.exception() is not None:
            LOG.error('Unexpected exception in result callback.',
                      exc_info=callback_task.exception())
        signal_callback_finished()

    def task_done_callback(finished_task: Task) -> None:
        running_tasks.discard(finished_task)
        limiting_semaphore.release()

        response: Optional[Any] = None

        try:
            response = result_callback(
                finished_task, passed_iteration_value_context_var.get())
        except Exception:
            LOG.exception('Unexpected exception in result callback.')
        finally:
            # Always count this item as finished, unless completion is deferred
            # to a coroutine returned by the callback. Otherwise the wait loop
            # below could block forever.
            if not iscoroutine(response):
                signal_callback_finished()

        if iscoroutine(response):
            callback_task = Task(coro=response)
            running_tasks.add(callback_task)
            callback_task.add_done_callback(callback_task_done)

    for iteration_value in iterable:
        await limiting_semaphore.acquire()
        num_started += 1

        passed_iteration_value_token: Token = passed_iteration_value_context_var.set(
            iteration_value)

        task = Task(coro=iteration_coroutine(iteration_value))
        running_tasks.add(task)
        task.add_done_callback(task_done_callback, context=copy_context())

        passed_iteration_value_context_var.reset(passed_iteration_value_token)

    while num_finished < num_started:
        await all_finished_event.wait()
        all_finished_event.clear()
Ejemplo n.º 12
0
def context(contextvar: contextvars.ContextVar, value):
    """Generator that sets ``contextvar`` to ``value`` while it is suspended.

    It yields once. The previous value is restored when the generator is
    resumed, closed, or has an exception thrown into it (presumably wrapped
    with ``contextlib.contextmanager`` — TODO confirm).
    """
    token = contextvar.set(value)
    try:
        yield
    finally:
        # ``finally`` so the reset also happens when the ``with`` body raises.
        contextvar.reset(token)
Ejemplo n.º 13
0
class TimeoutIdleChecker(BaseIdleChecker):
    """
    Checks the idleness of a session by the elapsed time since last used.
    The usage means processing of any computation requests, such as
    query/batch-mode code execution and having active service-port connections.

    Each session's last-access time is kept in Redis under
    ``session.{session_id}.last_access``. The value ``"0"`` means the timeout
    is currently disabled, e.g. while code is executing.
    """

    name: ClassVar[str] = "timeout"

    _config_iv = t.Dict({
        t.Key('threshold', default="10m"): tx.TimeDuration(),
    }).allow_extra('*')

    idle_timeout: timedelta
    # Keypair resource policies cached by access key. Only set inside
    # ``_do_idle_check``, which creates and resets it around each check.
    _policy_cache: ContextVar[Dict[AccessKey, Optional[Mapping[str, Any]]]]

    async def __ainit__(self) -> None:
        await super().__ainit__()
        self._policy_cache = ContextVar('_policy_cache')
        self._evh_session_started = \
            self._event_dispatcher.consume("session_started", None, self._session_started_cb)
        self._evh_execution_started = \
            self._event_dispatcher.consume("execution_started", None, self._execution_started_cb)
        self._evh_execution_finished = \
            self._event_dispatcher.consume("execution_finished", None, self._execution_exited_cb)
        self._evh_execution_timeout = \
            self._event_dispatcher.consume("execution_timeout", None, self._execution_exited_cb)
        self._evh_execution_cancelled = \
            self._event_dispatcher.consume("execution_cancelled", None, self._execution_exited_cb)

    async def aclose(self) -> None:
        self._event_dispatcher.unconsume("session_started",
                                         self._evh_session_started)
        self._event_dispatcher.unconsume("execution_started",
                                         self._evh_execution_started)
        self._event_dispatcher.unconsume("execution_finished",
                                         self._evh_execution_finished)
        self._event_dispatcher.unconsume("execution_timeout",
                                         self._evh_execution_timeout)
        self._event_dispatcher.unconsume("execution_cancelled",
                                         self._evh_execution_cancelled)
        await super().aclose()

    async def populate_config(self, raw_config: Mapping[str, Any]) -> None:
        """Validate ``raw_config`` and store its ``threshold`` as ``idle_timeout``."""
        config = self._config_iv.check(raw_config)
        self.idle_timeout = config['threshold']
        log.info(
            'TimeoutIdleChecker: default idle_timeout = {0:,} seconds',
            self.idle_timeout.total_seconds(),
        )

    async def update_app_streaming_status(
        self,
        session_id: SessionId,
        status: AppStreamingStatus,
    ) -> None:
        """Pause the timeout while app connections are active; resume otherwise."""
        if status == AppStreamingStatus.HAS_ACTIVE_CONNECTIONS:
            await self._disable_timeout(session_id)
        elif status == AppStreamingStatus.NO_ACTIVE_CONNECTIONS:
            await self._update_timeout(session_id)

    async def _disable_timeout(self, session_id: SessionId) -> None:
        """Mark the session's last access as "0" (only if the key exists)."""
        log.debug(f"TimeoutIdleChecker._disable_timeout({session_id})")
        await self._redis.set(f"session.{session_id}.last_access",
                              "0",
                              exist=self._redis.SET_IF_EXIST)

    async def _update_timeout(self, session_id: SessionId) -> None:
        """Record the current Redis time as the session's last access."""
        log.debug(f"TimeoutIdleChecker._update_timeout({session_id})")
        now = await self._redis.time()
        # Keep the key for at least a day, or twice the idle timeout if that is
        # longer. Redis's EX option takes whole seconds, but ``total_seconds()``
        # is a float (and ``max`` returns it once the timeout exceeds 12 hours),
        # so convert explicitly.
        expire = int(max(86400, self.idle_timeout.total_seconds() * 2))
        await self._redis.set(
            f"session.{session_id}.last_access",
            f"{now:.06f}",
            expire=expire,
        )

    async def _session_started_cb(
        self,
        context: Any,
        agent_id: AgentId,
        event_name: str,
        session_id: SessionId,
        creation_id: str,
    ) -> None:
        await self._update_timeout(session_id)

    async def _execution_started_cb(
        self,
        context: Any,
        agent_id: AgentId,
        event_name: str,
        session_id: SessionId,
    ) -> None:
        await self._disable_timeout(session_id)

    async def _execution_exited_cb(
        self,
        context: Any,
        agent_id: AgentId,
        event_name: str,
        session_id: SessionId,
    ) -> None:
        await self._update_timeout(session_id)

    async def _do_idle_check(self, context: Any, agent_id: AgentId,
                             event_name: str) -> None:
        """Run the base idle check with a fresh policy cache for this pass."""
        cache_token = self._policy_cache.set(dict())
        try:
            return await super()._do_idle_check(context, agent_id, event_name)
        finally:
            self._policy_cache.reset(cache_token)

    async def check_session(self, session: RowProxy,
                            dbconn: SAConnection) -> bool:
        """Return True if the session should be kept (not idle-timed-out)."""
        session_id = session['id']
        active_streams = await self._redis.zcount(
            f"session.{session_id}.active_app_connections")
        if active_streams is not None and active_streams > 0:
            return True
        now = await self._redis.time()
        raw_last_access = await self._redis.get(
            f"session.{session_id}.last_access")
        if raw_last_access is None or raw_last_access == "0":
            return True
        last_access = float(raw_last_access)
        # serves as the default fallback if keypair resource policy's idle_timeout is "undefined"
        idle_timeout = self.idle_timeout.total_seconds()
        policy_cache = self._policy_cache.get()
        policy = policy_cache.get(session['access_key'], None)
        if policy is None:
            query = (sa.select([keypair_resource_policies]).select_from(
                sa.join(
                    keypairs,
                    keypair_resource_policies,
                    (keypair_resource_policies.c.name
                     == keypairs.c.resource_policy),
                )).where(keypairs.c.access_key == session['access_key']))
            result = await dbconn.execute(query)
            policy = await result.first()
            assert policy is not None
            policy_cache[session['access_key']] = policy
        # setting idle_timeout:
        # - zero/inf means "infinite"
        # - negative means "undefined"
        if policy['idle_timeout'] >= 0:
            idle_timeout = float(policy['idle_timeout'])
        if ((idle_timeout <= 0)
                or (math.isinf(idle_timeout) and idle_timeout > 0)
                or (now - last_access <= idle_timeout)):
            return True
        return False
Ejemplo n.º 14
0
class Backendpy:
    """The Backendpy ASGI handler.

    Loads the project configuration and the active apps, merges their routes,
    hooks, errors and template directories, and serves ASGI ``http`` and
    ``lifespan`` scopes (``websocket`` is not implemented yet).
    """

    def __new__(cls, *args, **kwargs):
        """Create the instance and pass it through the active middlewares.

        The value returned by ``run_process_application`` is what the caller
        receives, so middlewares may wrap or replace the application object.
        """
        config = get_config(project_path=cls._get_project_path())
        cls._add_project_sys_path(config['environment']['project_path'])
        return MiddlewareProcessor(paths=parse_list(config['middlewares']['active'])) \
            .run_process_application(application=super().__new__(cls))

    def __init__(self):
        """Initialize Backendpy class instance.

        Loads the config, builds the router, hook runner and middleware
        processor, and merges routes/hooks/errors/template dirs of every
        successfully imported project app.
        """
        self.config = get_config(project_path=self._get_project_path(), error_logs=True)
        self.context = dict()
        self._request_context_var = ContextVar('request')
        self._hook_runner = HookRunner()
        self._router = Router()
        self._middleware_processor = MiddlewareProcessor(
            paths=parse_list(self.config['middlewares']['active']))
        self.errors = base_errors
        self._project_apps = self._get_project_apps()
        for app_data in self._project_apps:
            if app_data['app'].routes:
                for i in app_data['app'].routes:
                    self._router.routes.merge(i)
            if app_data['app'].hooks:
                for i in app_data['app'].hooks:
                    self._hook_runner.hooks.merge(i)
            if app_data['app'].errors:
                for i in app_data['app'].errors:
                    self.errors.merge(i)
            if app_data['app'].template_dirs:
                Template.template_dirs[app_data['path']] = \
                    [Path(app_data['path']).joinpath(p) for p in app_data['app'].template_dirs]
        self._lifespan_startup = False

    async def __call__(self, scope, receive, send):
        """Receive the requests and return the responses.

        For ``http`` scopes:

        - run the ``startup`` event lazily if lifespan startup has not completed;
        - read the whole body and build a ``Request``;
        - bind the request to the request context variable for the duration of
          the request, then send the response.

        ``lifespan`` scopes are delegated to ``_handle_lifespan``.
        """
        if scope['type'] == 'http':
            if not self._lifespan_startup:
                try:
                    await self.execute_event('startup')
                except Exception as e:
                    LOGGER.exception(e)
                else:
                    self._lifespan_startup = True

            try:
                body: bytes = await self._get_request_body(receive)
            except Exception as e:
                LOGGER.exception(f'Request data receive error: {e}')
                response = Error(1000)
                await self._send_response(send, *await response(None))
                return

            try:
                request = Request(app=self, scope=scope, body=body)
            except Exception as e:
                LOGGER.exception(f'Request instance creation error: {e}')
                response = Error(1000)
                await self._send_response(send, *await response(None))
                return

            token = self._request_context_var.set(request)
            try:
                await self.execute_event('request_start')
                await self._send_response(send, *await self._get_response(request))
                await self.execute_event('request_end')
            finally:
                # Restore the previous binding even if a hook or ``send``
                # raises, so a stale request is never left in the context.
                self._request_context_var.reset(token)

        elif scope['type'] == 'websocket':
            # TODO
            raise NotImplementedError

        elif scope['type'] == 'lifespan':
            await self._handle_lifespan(receive, send)

    def get_current_request(self):
        """Return the current request object.

        Raises ``LookupError`` when called outside of a request.
        """
        return self._request_context_var.get()

    def event(self, name: str) -> callable:
        """Register an event hook with python decorator.

        .. seealso:: :func:`~backendpy.hook.Hooks.event`
        """
        return self._hook_runner.hooks.event(name)

    async def execute_event(self, name: str, args: Optional[Mapping[str, Any]] = None) -> None:
        """Trigger all hooks related to the event.

        :param name: The name of an event
        :param args: A dictionary-like object containing arguments passed to the hook function.
        """
        return await self._hook_runner.trigger(name, args)

    async def _handle_lifespan(self, receive, send):
        """Serve an ASGI ``lifespan`` scope until shutdown completes successfully."""
        while True:
            message = await receive()
            if message['type'] == 'lifespan.startup':
                try:
                    await self.execute_event('startup')
                except Exception as e:
                    LOGGER.exception(e)
                    await send({'type': 'lifespan.startup.failed',
                                'message': str(e)})
                else:
                    self._lifespan_startup = True
                    await send({'type': 'lifespan.startup.complete'})

            elif message['type'] == 'lifespan.shutdown':
                try:
                    await self.execute_event('shutdown')
                except Exception as e:
                    LOGGER.exception(e)
                    await send({'type': 'lifespan.shutdown.failed',
                                'message': str(e)})
                else:
                    await send({'type': 'lifespan.shutdown.complete'})
                    return

    async def _get_response(self, request):
        """Run middlewares, routing, data validation and the handler.

        Returns the ``(body, status, headers, ...)`` tuple produced by awaiting
        the final response object. Each stage converts failures into an error
        response instead of raising.
        """
        try:
            request = await self._middleware_processor.run_process_request(request=request)
        except ExceptionResponse as e:
            return await e(request)
        except Exception as e:
            LOGGER.exception(f'Middleware error: {e}')
            response = Error(1000)
            return await response(request)

        try:
            handler, data_handler_cls, request.url_vars = \
                self._router.match(request.path, request.method, request.scheme)
            if not handler:
                response = Error(1001)
                return await response(request)
        except Exception as e:
            LOGGER.exception(e)
            response = Error(1000)
            return await response(request)

        try:
            handler = await self._middleware_processor.run_process_handler(
                request=request,
                handler=handler)
        except ExceptionResponse as e:
            return await e(request)
        except Exception as e:
            LOGGER.exception(f'Middleware error: {e}')
            response = Error(1000)
            return await response(request)

        try:
            data_errors = None
            if data_handler_cls:
                request.cleaned_data, data_errors = \
                    await data_handler_cls(request=request).get_cleaned_data()
            if data_errors:
                response = Error(1002, data=data_errors)
                return await response(request)
        except Exception as e:
            LOGGER.exception(f'Data handler error: {e}')
            response = Error(1000)
            return await response(request)

        try:
            response = await handler(request=request)
        except ExceptionResponse as e:
            return await e(request)
        except Exception as e:
            LOGGER.exception(f'Handler error: {e}')
            response = Error(1000)
            return await response(request)

        try:
            response = await self._middleware_processor.run_process_response(
                request=request,
                response=response)
        except ExceptionResponse as e:
            return await e(request)
        except Exception as e:
            LOGGER.exception(f'Middleware error: {e}')
            response = Error(1000)
            return await response(request)

        return await response(request)

    @staticmethod
    async def _send_response(send, body, status, headers, stream=False):
        """Send ``http.response.start`` followed by the body.

        When *stream* is true, *body* is iterated (async or sync) and each chunk
        is sent with ``more_body=True``, followed by an empty final message.
        """
        await send({
            'type': 'http.response.start',
            'status': status,
            'headers': headers})

        if stream:
            if hasattr(body, '__aiter__'):
                async for chunk in body:
                    await send({
                        'type': 'http.response.body',
                        'body': to_bytes(chunk),
                        'more_body': True})
            else:
                for chunk in body:
                    await send({
                        'type': 'http.response.body',
                        'body': to_bytes(chunk),
                        'more_body': True})
            await send({
                'type': 'http.response.body',
                'body': b''})
        else:
            await send({
                'type': 'http.response.body',
                'body': to_bytes(body)})

    @staticmethod
    async def _get_request_body(receive) -> bytes:
        """Read messages from *receive* until ``more_body`` is false; return the joined body.

        Chunks are collected in a list and joined once. This avoids the
        quadratic cost of repeated ``bytes`` concatenation. The whole body is
        still buffered in memory.
        """
        chunks = []
        more_body = True
        while more_body:
            message = await receive()
            chunks.append(message.get('body', b''))
            more_body = message.get('more_body', False)
        return b''.join(chunks)

    @staticmethod
    async def _get_request_body_generator(receive):
        """Yield each body chunk from *receive* until ``more_body`` is false."""
        more_body = True
        while more_body:
            message = await receive()
            yield message.get('body', b'')
            more_body = message.get('more_body', False)

    def _get_project_apps(self):
        """Import ``<package>.main`` for every active app and collect its ``app`` object.

        Packages that fail to import, lack ``app``, or whose ``app`` is not an
        ``App`` instance are logged and skipped.
        """
        apps: list[dict] = list()
        for package_name in parse_list(self.config['apps']['active']):
            try:
                module = importlib.import_module(f'{package_name}.main')
                app = getattr(module, 'app')
                if isinstance(app, App):
                    apps.append(dict(
                        package_name=package_name,
                        path=os.path.dirname(os.path.abspath(module.__file__)),
                        app=app))
                else:
                    LOGGER.error(f'"{package_name}" app instance error')
            except (ImportError, AttributeError):
                LOGGER.error(f'"{package_name}" app instance import error')
        return apps

    @staticmethod
    def _get_project_path():
        """Return the directory of the file two frames up the call stack.

        NOTE(review): relies on being called directly from ``__new__`` /
        ``__init__``, so that frame 2 is the module creating the application.
        """
        return os.path.dirname(os.path.realpath(inspect.stack()[2].filename))

    @staticmethod
    def _add_project_sys_path(project_path):
        """Prepend the parent directory of *project_path* to ``sys.path``."""
        sys.path.insert(0, os.path.dirname(project_path))
Ejemplo n.º 15
0
class PurityClient:
    """Async client for the Purity REST API.

    Use as ``async with client:``. Entering logs in with the API token and
    binds the returned session auth token to the ``auth_token`` context
    variable. Exiting restores the previous binding. ``aclose`` closes the
    underlying aiohttp session.
    """

    endpoint: URL
    api_token: str
    api_version: str
    auth_token: ContextVar[str]

    _session: aiohttp.ClientSession
    _auth_token_cvtoken: Token

    def __init__(
        self,
        endpoint: str,
        api_token: str,
        *,
        api_version: str = '1.8',
    ) -> None:
        self.endpoint = URL(endpoint)
        self.api_token = api_token
        self.api_version = api_version
        self.auth_token = ContextVar('auth_token')
        self._session = aiohttp.ClientSession()

    async def aclose(self) -> None:
        """Close the underlying aiohttp session."""
        await self._session.close()

    async def __aenter__(self) -> PurityClient:
        """Log in and bind the ``x-auth-token`` response header to ``auth_token``."""
        async with self._session.post(
                self.endpoint / 'api' / 'login',
                headers={'api-token': self.api_token},
                ssl=False,
                raise_for_status=True,
        ) as resp:
            auth_token = resp.headers['x-auth-token']
            self._auth_token_cvtoken = self.auth_token.set(auth_token)
            _ = await resp.json()
        return self

    async def __aexit__(self, *exc_info) -> None:
        """Restore the ``auth_token`` binding that existed before ``__aenter__``."""
        self.auth_token.reset(self._auth_token_cvtoken)

    async def get_nfs_metric(
            self, fs_name: str) -> AsyncGenerator[Mapping[str, Any], None]:
        """Yield NFS performance records for the file system *fs_name*.

        Pages through ``file-systems/performance``, 10 items per request,
        following ``pagination_info.continuation_token`` until it is None.

        Raises:
            RuntimeError: if no auth token is bound in the current context
                (i.e. the client is used outside ``async with``).
        """
        # ``self.auth_token`` is a ContextVar object and is never None itself;
        # what matters is whether it holds a value in the current context.
        if self.auth_token.get(None) is None:
            raise RuntimeError(
                'The auth token for Purity API is not initialized.')
        pagination_token = ''
        while True:
            async with self._session.get(
                (self.endpoint / 'api' / self.api_version / 'file-systems' /
                 'performance'),
                headers={'x-auth-token': self.auth_token.get()},
                params={
                    'names': fs_name,
                    'protocol': 'NFS',
                    'items_returned': 10,
                    'token': pagination_token,
                },
                ssl=False,
                raise_for_status=True,
            ) as resp:
                data = await resp.json()
                for item in data['items']:
                    yield item
                pagination_token = data['pagination_info'][
                    'continuation_token']
                if pagination_token is None:
                    break
Ejemplo n.º 16
0
def ctxtvar_redefined(var: contextvars.ContextVar, value):
    """Temporarily rebind *var* to *value*, yielding *value*.

    The binding that was in effect before is restored once the generator
    resumes, closes, or has an exception thrown into it.
    """
    previous = var.set(value)
    try:
        yield value
    finally:
        var.reset(previous)
Ejemplo n.º 17
0
def context(var: ContextVar, value: Any):
    """Bind *var* to *value* while the generator is suspended at its ``yield``.

    The previous binding is restored when the generator resumes. This also
    happens when an exception is thrown into it or it is closed, so a failure
    in the managed block cannot leak the temporary value.
    """
    token = var.set(value)
    try:
        yield
    finally:
        var.reset(token)
Ejemplo n.º 18
0
class Database(LimitInstances):
    """Represents a database.

    Holds the connection URL, the lazily created asyncpg pool, and the schemas
    bound to it. Execution can be disabled (``disable_execution``), in which
    case ``execute`` returns the generated SQL instead of running it.
    """

    __instances__: dict[str, Database]

    def __init__(self, name: str):
        self.name = name
        self.user: t.Optional[str] = None
        self.url: t.Optional[str] = None

        self.pool: t.Optional[asyncpg.Pool] = None

        self.type = types
        self.schemas: t.Set[Schema] = set()

        self._mock = False
        self._prepared = False
        # Per-context list of (sql, args) pairs, active only inside ``stmt_tracking``.
        self._tracking = ContextVar(f"stmt_tracking:{name}")

    @classmethod
    def connect(cls, name: str, user: str, password: str, *, host: str = "localhost", port: int = 5432) -> Database:
        """Establish the connection URL and name for the database, returning the instance representing it.

        When exactly one instance is registered, it is assumed to be the
        ``"__default__"`` one; it is also registered under *name* and renamed.
        """
        if len(cls.__instances__) == 1:
            db = cls.__instances__["__default__"]
            cls.__instances__[name] = db
            db.name = name
        else:
            db = Database(name)
        db.user = user
        db.url = f"postgres://{user}:{password}@{host}:{port}/{name}"
        return db

    @property
    def public_schema(self):
        """Return the bound ``public`` schema of this database."""
        return self.Schema("public")

    def __call__(self, name: str) -> Database:
        """Return the instance representing the given database name."""
        return Database(name)

    def __str__(self):
        """Return the URL representation of the given database instance, if set."""
        return self.url or self.name

    def __repr__(self):
        status = " disabled" if self._mock else ""
        if self.user:
            return f"<Database '{self.name}' user='******'{status}>"
        else:
            return f"<Database '{self.name}'{status}>"

    def __hash__(self):
        return hash(str(self))

    def __eq__(self, other: t.Any):
        if isinstance(other, Database):
            return str(self) == str(other)
        return False

    def __getitem__(self, name: str) -> Database:
        """Retrieve an existing database with the given name."""
        return self.__instances__[name]

    def __delitem__(self, name: str):
        """Delete a database instance by its name."""
        del self.__instances__[name]

    @classmethod
    def get_default(cls):
        """Return the instance registered as ``"__default__"``."""
        return cls.__instances__["__default__"]

    async def create_pool(self):
        """Create the asyncpg connection pool for this database connection to use.

        Any existing pool is closed first. ``asyncpg.Pool.close`` is a
        coroutine and must be awaited for its connections to be released.

        Raises:
            DBError: if no connection URL was set via ``Database.connect``.
        """
        if self.pool:
            await self.pool.close()
            # Drop the closed pool so a failure below can't leave it in use.
            self.pool = None
        if not self.url:
            raise DBError("Please define a connection with Database.connect.")
        self.pool = await asyncpg.create_pool(self.url, init=self._enable_json)  # pragma: no cover

    @staticmethod
    async def _enable_json(conn: asyncpg.Connection):  # pragma: no cover
        """Register ``json``/``jsonb`` codecs on *conn* using ``json.dumps``/``json.loads``."""
        await conn.set_type_codec("jsonb", encoder=json.dumps, decoder=json.loads, schema="pg_catalog")
        await conn.set_type_codec("json", encoder=json.dumps, decoder=json.loads, schema="pg_catalog")

    async def prepare(self):
        """Prepare all child objects for this database."""
        for schema in self.schemas:
            await schema.prepare()
        self._prepared = True

    def disable_execution(self):
        """Return generated SQL without executing when Database.execute is used."""
        self._mock = True

    def enable_execution(self):
        """Sets Database.execute to its normal execution behaviour."""
        self._mock = False

    @contextlib.contextmanager
    def stmt_tracking(self):
        """Collects raw executed statements until exit when execution is disabled.

        While active, each mock ``execute`` appends ``(sql, args)`` to a fresh
        list bound to ``self._tracking``. The previous binding is restored on exit.
        """
        ctx_token = self._tracking.set([])
        try:
            yield self
        finally:
            self._tracking.reset(ctx_token)

    async def close(self):  # pragma: no cover
        """Close the asyncpg connection pool for this database."""
        if self.pool:
            await self.pool.close()

    async def execute(self, sql: str, *args, timeout: t.Optional[float] = None) -> t.Union[str, tuple[str, t.Any]]:
        """Execute an SQL statement.

        With execution disabled, the statement is recorded (if tracking is
        active) and returned as ``sql`` or ``(sql, *args)`` instead. Otherwise
        the pool is created on demand and the statement is executed.
        """
        if self._mock:
            try:
                stmt_list = self._tracking.get()
                stmt_list.append((sql, args))
            except LookupError:
                # Not inside ``stmt_tracking``: nothing to record.
                pass
            if not args:
                return sql
            else:
                return sql, *args

        if not self.pool:  # pragma: no cover
            await self.create_pool()
        return await self.pool.execute(sql, *args, timeout=timeout)  # pragma: no cover

    def Schema(self, name: str) -> Schema:
        """Return a bound Schema for this database."""
        s = Schema(name, self)
        self.schemas.add(s)
        return s

    def Table(self, name: str) -> Table:
        """Return a bound Table for the public schema on this database."""
        return Table(name, self)
Ejemplo n.º 19
0
def increase_counter(contextvar: ContextVar) -> Generator:
    """Bind *contextvar* to its current value plus one while suspended at ``yield``.

    The prior binding is restored afterwards, including on exceptions or when
    the generator is closed. A ``LookupError`` propagates before the ``yield``
    if *contextvar* has neither a value nor a default.
    """
    incremented = contextvar.get() + 1
    saved = contextvar.set(incremented)
    try:
        yield
    finally:
        contextvar.reset(saved)
Ejemplo n.º 20
0
Archivo: proxy.py Proyecto: ddqof/proxy
class ProxyServer:
    """Asyncio forward proxy for HTTP and HTTPS (CONNECT) traffic.

    Optionally blocks image requests. It also tracks the data spent per
    resource listed in the config's ``"limited"`` and ``"black-list"`` entries.
    """

    def __init__(self, port: int = 8080, block_images: bool = False, cfg=None):
        """Configure the proxy.

        Args:
            port: Port to listen on.
            block_images: Whether image requests are reset instead of forwarded.
            cfg: Optional dict with ``"limited"`` and ``"black-list"`` iterables.

        Raises:
            ValueError: if *cfg* is given but is not a dict.
        """
        self.connection = ContextVar("connection")
        self.block_images = block_images
        self.port = port
        self._spent_data = {}
        # Always define ``_cfg`` so ``_handle_connection`` doesn't hit an
        # AttributeError when no config was supplied.
        # NOTE(review): confirm ProxyRequest accepts a ``None`` config.
        self._cfg = None
        if cfg is not None:
            if isinstance(cfg, dict):
                self._cfg = cfg
            else:
                raise ValueError(f"Config should be {dict.__name__} object")
            for rsc in chain(cfg["limited"], cfg["black-list"]):
                self._spent_data[rsc] = 0
        # Kept for backward compatibility: token of the most recently started
        # connection. Per-connection cleanup uses a task-local token instead.
        self.context_token = None

    async def run(self):
        """
        Launch async proxy-server at specified host and port.
        """
        srv = await asyncio.start_server(self._handle_connection, LOCALHOST,
                                         self.port)

        addr = srv.sockets[0].getsockname()
        LOGGER.info(START_SERVER_MSG.format(app_address=addr))

        async with srv:
            await srv.serve_forever()

    async def _handle_connection(self, client_reader: StreamReader,
                                 client_writer: StreamWriter) -> None:
        """
        Handle every client response.
        Called whenever a new connection is established.

        The ``Connection`` is bound to ``self.connection`` for this task. The
        binding is always reset on exit using a token local to this call,
        because the instance is shared by all concurrent connections.
        """
        pr = None
        token = None
        try:
            raw_request = await client_reader.read(CHUNK_SIZE)
            LOGGER.debug("Raw request: %r", raw_request)
            await client_writer.drain()
            if not raw_request:
                client_writer.close()
                return
            pr = ProxyRequest(raw_request, self._cfg)
            LOGGER.info(f"{pr.method:<{len('CONNECT')}} " f"{pr.abs_url}")
            try:
                server_reader, server_writer = await asyncio.open_connection(
                    pr.hostname, pr.port)
            except OSError:
                LOGGER.info(
                    CONNECTION_REFUSED_MSG.format(method=pr.method,
                                                  url=pr.abs_url))
                client_writer.close()
                return
            client_endpoint = Endpoint(client_reader, client_writer)
            server_endpoint = Endpoint(server_reader, server_writer)
            conn = Connection(client_endpoint, server_endpoint, pr,
                              self.block_images)
            token = self.connection.set(conn)
            self.context_token = token
            if self.block_images and pr.is_image_request:
                await conn.reset()
                return
            if pr.scheme is HTTPScheme.HTTPS:
                await self._handle_https()
            else:
                await self._handle_http()
        except ConnectionResetError:
            # ``pr`` is unset if the reset happened before the request was parsed.
            url = pr.abs_url if pr is not None else "<unknown>"
            LOGGER.info(CONNECTION_CLOSED_MSG.format(url=url))
        except Exception as e:
            LOGGER.exception(e)
            asyncio.get_event_loop().stop()
        finally:
            if token is not None:
                self.connection.reset(token)

    async def _handle_http(self) -> None:
        """
        Send HTTP request and then forwards the following HTTP requests.
        """
        conn = self.connection.get()
        LOGGER.debug(
            HANDLING_HTTP_REQUEST_MSG.format(method=conn.pr.method,
                                             url=conn.pr.abs_url))
        await conn.server.write_and_drain(conn.pr.raw)
        await asyncio.gather(conn.forward_to_client(self._spent_data),
                             conn.forward_to_server())

    async def _handle_https(self) -> None:
        """
        Handles https connection by making HTTP tunnel.

        If the request has a restriction whose initiator already reached its
        data limit, the connection is reset instead of tunnelled.
        """
        conn = self.connection.get()
        hostname = conn.pr.hostname
        LOGGER.debug(HANDLING_HTTPS_CONNECTION_MSG.format(url=hostname))
        rsc = conn.pr.restriction
        if rsc:
            if self._spent_data[rsc.initiator] >= rsc.data_limit:
                await conn.reset()
                return
        await conn.client.write_and_drain(CONNECTION_ESTABLISHED_HTTP_MSG)
        LOGGER.debug(CONNECTION_ESTABLISHED_MSG.format(url=conn.pr.abs_url))
        await asyncio.gather(conn.forward_to_server(),
                             conn.forward_to_client(self._spent_data))
Ejemplo n.º 21
0
class Kanata(BaseDispatcher):
    """Kanata: a dispatcher that parses message chains against a signature chain."""

    always = True  # compatibility with the refactored version of bcc.

    signature_list: List[Union[NormalMatch, PatternReceiver]]
    stop_exec_if_fail: bool = True

    parsed_items: ContextVar[Dict[str, MessageChain]]

    allow_quote: bool
    skip_one_at_in_quote: bool

    # NOTE(review): stored on the instance, so it is shared by concurrent
    # dispatches of the same Kanata object.
    content_token: Optional[Token] = None

    def __init__(
        self,
        signature_list: List[Union[NormalMatch, PatternReceiver]],
        stop_exec_if_fail: bool = True,
        allow_quote: bool = True,
        skip_one_at_in_quote: bool = False,
    ) -> None:
        """Instantiate the argument parser.

        Args:
            signature_list (List[Union[NormalMatch, PatternReceiver]]): the chain of match signatures.
            stop_exec_if_fail (bool, optional): whether to stop the listener when nothing matches. Defaults to True.
            allow_quote (bool, optional): whether Kanata processes the user-input part of reply (quote) messages. Defaults to True.
            skip_one_at_in_quote (bool, optional): whether, when processing the user-input part of a reply message,
                Kanata automatically drops the At (and the single-space Plain element) the QQ client may add. Defaults to False.
        """
        self.signature_list = signature_list
        self.stop_exec_if_fail = stop_exec_if_fail
        self.parsed_items = ContextVar("kanata_parsed_items")
        self.allow_quote = allow_quote
        self.skip_one_at_in_quote = skip_one_at_in_quote

    @staticmethod
    def detect_index(
        signature_chain: Tuple[Union[NormalMatch, PatternReceiver]],
        message_chain: MessageChain,
    ) -> Optional[Dict[str, Tuple[MessageIndex, MessageIndex]]]:
        """Locate the span of each argument receiver within *message_chain*.

        Walks the merged *signature_chain*, matching ``NormalMatch`` regexes
        against ``Plain`` text. For each pending ``Arguments`` receiver, it
        records a ``(start, stop)`` ``MessageIndex`` pair (start inclusive,
        stop exclusive).

        Returns:
            A dict mapping receivers to index pairs, or None if the message does
            not fit the signature.
        """
        merged_chain = merge_signature_chain(signature_chain)
        message_chain = message_chain.asMerged()
        element_num = len(message_chain.__root__)
        end_index: MessageIndex = (
            element_num - 1,
            len(message_chain.__root__[-1].text) if element_num != 0
            and message_chain.__root__[-1].__class__ is Plain else None,
        )

        reached_message_index: MessageIndex = (0, None)
        # [0] => real_index
        # [1] => text_index(optional)

        start_index: MessageIndex = (0, None)

        match_result: Dict[Arguments, Tuple[
            MessageIndex, MessageIndex],  # start(include)  # stop(exclude)
                           ] = {}

        signature_iterable = InsertGenerator(enumerate(merged_chain))
        latest_index = None
        matching_recevier: Optional[Arguments] = None

        for signature_index, signature in signature_iterable:
            if isinstance(signature, (Arguments, PatternReceiver)):
                if matching_recevier:  # a receiver is already selected...
                    if isinstance(signature, Arguments):
                        if latest_index == signature_index:
                            matching_recevier.content.extend(signature.content)
                            continue
                        else:
                            raise TypeError(
                                "a unexpected case: match conflict")
                    if isinstance(signature, PatternReceiver):
                        matching_recevier.content.append(signature)
                        continue
                else:
                    if isinstance(signature, PatternReceiver):
                        signature = Arguments([signature])
                matching_recevier = signature
                start_index = reached_message_index
            elif isinstance(signature, NormalMatch):
                if not matching_recevier:
                    # No argument is pending: match the NormalMatch starting at
                    # the current position (reached_message_index).
                    current_chain = message_chain.subchain(
                        slice(reached_message_index, None, None))
                    if not current_chain.__root__:  # index out of range
                        return
                    if not isinstance(current_chain.__root__[0], Plain):
                        # The first element after slicing is **not** Plain.
                        return
                    re_match_result = re.match(signature.operator(),
                                               current_chain.__root__[0].text)
                    if not re_match_result:
                        # No match.
                        return
                    # Advance the current position.
                    plain_text_length = len(current_chain.__root__[0].text)
                    pattern_length = re_match_result.end(
                    ) - re_match_result.start()
                    if (pattern_length + 1) > plain_text_length:  # advancing the text index could overrun
                        # Don't advance text_index; advance element_index instead.
                        reached_message_index = (reached_message_index[0] + 1,
                                                 None)
                    else:
                        # Advance to just past the matched text in this element.
                        reached_message_index = (
                            reached_message_index[0],
                            origin_or_zero(reached_message_index[1]) +
                            re_match_result.start() + pattern_length,
                        )
                else:
                    # An argument is pending. Greedy mode takes the last match
                    # (searching from the back) instead of the first.
                    greed = matching_recevier.isGreed
                    for element_index, element in enumerate(
                            message_chain.subchain(
                                slice(reached_message_index, None,
                                      None)).__root__):
                        if isinstance(element, Plain):
                            current_text: str = element.text
                            # Apply greediness.
                            text_find_result_list = list(
                                re.finditer(signature.operator(),
                                            current_text))
                            if not text_find_result_list:
                                continue
                            text_find_result = text_find_result_list[-int(greed
                                                                          )]
                            if not text_find_result:
                                continue
                            text_find_index = text_find_result.start()

                            # Found! Besides advancing the position, record the
                            # end position of the pending argument and clear it.
                            stop_index = (
                                reached_message_index[0] + element_index +
                                int(element_index == 0),
                                origin_or_zero(reached_message_index[1]) +
                                text_find_index,
                            )
                            match_result[matching_recevier] = (
                                copy.copy(start_index),
                                stop_index,
                            )

                            start_index = (0, None)
                            matching_recevier = None

                            pattern_length = (text_find_result.end() -
                                              text_find_result.start())
                            if (current_text == text_find_result.string[slice(
                                    *text_find_result.span())]):
                                # Advancing text_index here would overrun the element....
                                # Advance element_index instead of text_index.
                                reached_message_index = (
                                    reached_message_index[0] + element_index +
                                    int(element_index != 0),
                                    None,
                                )
                            else:
                                reached_message_index = (
                                    reached_message_index[0] + element_index,
                                    origin_or_zero(reached_message_index[1]) +
                                    text_find_index + pattern_length,
                                )
                            break
                    else:
                        # Searched everything without a match.
                        return
            latest_index = signature_index
        else:
            if matching_recevier:  # reached the end while still capturing an argument.
                # Compute the end coordinates.
                text_index = None

                latest_element = message_chain.__root__[-1]
                if isinstance(latest_element, Plain):
                    text_index = len(latest_element.text)

                stop_index = (len(message_chain.__root__), text_index)
                match_result[matching_recevier] = (start_index, stop_index)
            else:  # Nothing is being captured, but the signature no longer covers the message: the match is invalid.
                if reached_message_index < end_index:
                    return

        return match_result

    @staticmethod
    def detect_and_mapping(
        signature_chain: Tuple[Union[NormalMatch, PatternReceiver]],
        message_chain: MessageChain,
    ) -> Optional[Dict[Arguments, MessageChain]]:
        """Slice *message_chain* at the spans found by ``detect_index``.

        Returns None if ``detect_index`` found no match.
        """
        match_result = Kanata.detect_index(signature_chain, message_chain)
        if match_result is not None:
            return {
                k: message_chain[v[0]:(
                    v[1][0],
                    (v[1][1] - (origin_or_zero(v[0][1]) if
                                (v[1][0] <= v[0][0] <= v[1][0]) else 0)
                     ) if v[1][1] is not None else None,
                )]
                for k, v in match_result.items()
            }

    @staticmethod
    def allocation(
        mapping: Dict[Arguments, MessageChain]
    ) -> Optional[Dict[str, MessageChain]]:
        """Assign each argument set's captured chain to its first receiver's name.

        Returns None if *mapping* is None or a required parameter captured an
        empty chain. An empty optional parameter maps to None.

        Raises:
            ConflictItem: if a receiver name appears more than once.
        """
        if mapping is None:
            return None
        result = {}
        for arguemnt_set, message_chain in mapping.items():
            for receiver in arguemnt_set.content:
                if receiver.name in result:
                    raise ConflictItem(
                        "{0} is defined repeatedly".format(receiver))
                if isinstance(receiver, RequireParam):
                    if not message_chain.__root__:
                        return
                    result[receiver.name] = message_chain
                elif isinstance(receiver, OptionalParam):
                    if not message_chain.__root__:
                        result[receiver.name] = None
                    else:
                        result[receiver.name] = message_chain
                break  # length matching is not implemented yet...
        return result

    async def catch_argument_names(self) -> List[str]:
        """Return the names of the ``PatternReceiver`` entries in ``signature_list``.

        Deliberately uncached. ``functools.lru_cache`` on a coroutine function
        caches the coroutine object, so a second ``await`` raises
        RuntimeError, and the cache also keeps ``self`` alive.
        """
        return [
            i.name for i in self.signature_list
            if isinstance(i, PatternReceiver)
        ]

    async def beforeDispatch(self, interface: DispatcherInterface):
        """Parse the message chain and bind the allocated arguments to ``parsed_items``.

        Raises:
            ExecutionStop: if the chain contains blocking elements, or no match
                was found and ``stop_exec_if_fail`` is set.
        """
        message_chain: MessageChain = (await interface.lookup_param(
            "__kanata_messagechain__", MessageChain, None)).exclude(Source)
        if set([i.__class__ for i in message_chain.__root__
                ]).intersection(BLOCKING_ELEMENTS):
            raise ExecutionStop()
        if self.allow_quote and message_chain.has(Quote):
            # Automatically skip the first At after the Quote.
            # 0: Quote, 1: At, 2: Plain (a single space; later mirai versions may handle it, handled here for now.)
            message_chain = message_chain[(3, None):]
            if self.skip_one_at_in_quote and message_chain.__root__:
                if message_chain.__root__[0].__class__ is At:
                    message_chain = message_chain[(
                        1, 1):]  # MessageIndex slicing makes this concise.
        mapping_result = self.detect_and_mapping(self.signature_list,
                                                 message_chain)
        if mapping_result is not None:
            self.content_token = self.parsed_items.set(
                self.allocation(mapping_result))
        else:
            if self.stop_exec_if_fail:
                raise ExecutionStop()

    async def catch(self, interface: DispatcherInterface):
        """Provide the parsed argument named ``interface.name``, wrapped in ``Force``.

        Returns None when nothing was parsed for this name.
        """
        if not self.content_token:
            return
        missing = object()  # unique sentinel: distinguishes "absent" from a stored None
        current_item = self.parsed_items.get()
        if current_item is not None:
            result = current_item.get(interface.name, missing)
            return Force(result) if result is not missing else None
        else:
            if self.stop_exec_if_fail:
                raise ExecutionStop()

    async def afterDispatch(
        self,
        interface: "DispatcherInterface",
        exception: Optional[Exception] = None,
        tb: Optional[TracebackType] = None,
    ):
        """Reset ``parsed_items`` and clear the stored token after dispatch."""
        if self.content_token:
            self.parsed_items.reset(self.content_token)
            # Drop any per-instance ``catch`` override so the class method applies
            # again. Assigning the unbound ``Kanata.catch`` here would make a
            # later ``self.catch(interface)`` receive ``interface`` as ``self``.
            try:
                del self.catch
            except AttributeError:
                pass
            self.content_token = None
Ejemplo n.º 22
0
class PurityClient:

    endpoint: URL
    api_token: str
    api_version: str
    auth_token: ContextVar[str]

    _session: aiohttp.ClientSession
    _auth_token_cvtoken: Token

    def __init__(
        self,
        endpoint: str,
        api_token: str,
        *,
        api_version: str = "1.8",
    ) -> None:
        """Create a client for the Purity REST API served at ``endpoint``.

        Args:
            endpoint: Base URL of the appliance; stored as a ``URL``.
            api_token: API token sent to ``/api/login`` when entering
                ``async with client:``.
            api_version: Version path segment used by the query methods.

        No request is made here. The aiohttp session created below must be
        closed with ``aclose()``.
        """
        self.endpoint = URL(endpoint)
        self.api_token, self.api_version = api_token, api_version
        # NOTE(review): the Python docs advise creating ContextVars at module
        # level; Context objects keep strong references to every variable set
        # in them, so a per-instance ContextVar can outlive its client.
        self.auth_token = ContextVar("auth_token")
        self._session = aiohttp.ClientSession()

    async def aclose(self) -> None:
        await self._session.close()

    async def __aenter__(self) -> PurityClient:
        async with self._session.post(
            self.endpoint / "api" / "login",
            headers={"api-token": self.api_token},
            ssl=False,
            raise_for_status=True,
        ) as resp:
            auth_token = resp.headers["x-auth-token"]
            self._auth_token_cvtoken = self.auth_token.set(auth_token)
            _ = await resp.json()
        return self

    async def __aexit__(self, *exc_info) -> None:
        self.auth_token.reset(self._auth_token_cvtoken)

    # For the concrete API reference, check out:
    # https://purity-fb.readthedocs.io/en/latest/

    async def get_metadata(self) -> Mapping[str, Any]:
        if self.auth_token is None:
            raise RuntimeError("The auth token for Purity API is not initialized.")
        items = []
        pagination_token = ""
        while True:
            async with self._session.get(
                (self.endpoint / "api" / self.api_version / "arrays"),
                headers={"x-auth-token": self.auth_token.get()},
                params={
                    "items_returned": 10,
                    "token": pagination_token,
                },
                ssl=False,
                raise_for_status=True,
            ) as resp:
                data = await resp.json()
                for item in data["items"]:
                    items.append(item)
                pagination_token = data["pagination_info"]["continuation_token"]
                if pagination_token is None:
                    break
        if not items:
            return {}
        first = items[0]
        return {
            "id": first["id"],
            "name": first["name"],
            "os": first["os"],
            "revision": first["revision"],
            "version": first["version"],
            "blade_count": str(len(items)),
            "console_url": str(self.endpoint),
        }

    async def get_nfs_metric(
        self,
        fs_name: str,
    ) -> AsyncGenerator[Mapping[str, Any], None]:
        if self.auth_token is None:
            raise RuntimeError("The auth token for Purity API is not initialized.")
        pagination_token = ""
        while True:
            async with self._session.get(
                (
                    self.endpoint
                    / "api"
                    / self.api_version
                    / "file-systems"
                    / "performance"
                ),
                headers={"x-auth-token": self.auth_token.get()},
                params={
                    "names": fs_name,
                    "protocol": "NFS",
                    "items_returned": 10,
                    "token": pagination_token,
                },
                ssl=False,
                raise_for_status=True,
            ) as resp:
                data = await resp.json()
                for item in data["items"]:
                    yield item
                pagination_token = data["pagination_info"]["continuation_token"]
                if pagination_token is None:
                    break

    async def get_usage(self, fs_name: str) -> Mapping[str, Any]:
        if self.auth_token is None:
            raise RuntimeError("The auth token for Purity API is not initialized.")
        items = []
        pagination_token = ""
        while True:
            async with self._session.get(
                (self.endpoint / "api" / self.api_version / "file-systems"),
                headers={"x-auth-token": self.auth_token.get()},
                params={
                    "names": fs_name,
                    "items_returned": 10,
                    "token": pagination_token,
                },
                ssl=False,
                raise_for_status=True,
            ) as resp:
                data = await resp.json()
                for item in data["items"]:
                    items.append(item)
                pagination_token = data["pagination_info"]["continuation_token"]
                if pagination_token is None:
                    break
        if not items:
            return {}
        first = items[0]
        return {
            "capacity_bytes": data["total"]["provisioned"],
            "used_bytes": first["space"]["total_physical"],
        }