Пример #1
0
    async def _receive_and_check(self, timeout: typing.Optional[float],
                                 /) -> str:
        """Receive frames until one complete gateway payload is available.

        BINARY frames are appended to a buffer until it ends with the bytes
        ``00 00 ff ff``; the buffer is then decompressed with ``self.zlib`` and
        decoded as UTF-8. A TEXT frame is returned unchanged, provided no
        binary data is currently buffered.

        Parameters
        ----------
        timeout : typing.Optional[builtins.float]
            Forwarded unchanged to every `receive` call.

        Returns
        -------
        builtins.str
            The decoded payload text.

        Raises
        ------
        hikari.errors.GatewayServerClosedConnectionError
            If a CLOSE frame arrives. Any close code below 4000, plus a fixed
            set of `ShardCloseCode` values, is flagged as reconnectable.
        hikari.errors.GatewayError
            If the socket is CLOSING/CLOSED, a non-BINARY frame arrives while
            binary data is buffered, or any other frame type is received.
        """
        pending = bytearray()

        while True:
            message = await self.receive(timeout)
            kind = message.type

            if kind == aiohttp.WSMsgType.CLOSE:
                code = int(message.data)
                why = message.extra
                self.logger.error("connection closed with code %s (%s)",
                                  code, why)
                # Codes below 4000, and these specific gateway codes, are
                # reported to the caller as reconnectable.
                reconnectable = code < 4000 or code in (
                    errors.ShardCloseCode.UNKNOWN_ERROR,
                    errors.ShardCloseCode.DECODE_ERROR,
                    errors.ShardCloseCode.INVALID_SEQ,
                    errors.ShardCloseCode.SESSION_TIMEOUT,
                    errors.ShardCloseCode.RATE_LIMITED,
                )
                raise errors.GatewayServerClosedConnectionError(
                    why, code, reconnectable)

            if kind == aiohttp.WSMsgType.CLOSING or kind == aiohttp.WSMsgType.CLOSED:
                # Either the server shut us down, or (seen on Windows with some
                # network drivers) an EOF was injected after a disconnect.
                raise errors.GatewayError("Socket has closed")

            if pending and kind != aiohttp.WSMsgType.BINARY:
                # A partially buffered compressed payload may only be
                # continued by further BINARY frames.
                raise errors.GatewayError(
                    f"Unexpected message type received {kind.name}, expected BINARY"
                )

            if kind == aiohttp.WSMsgType.BINARY:
                pending.extend(message.data)
                if not pending.endswith(b"\x00\x00\xff\xff"):
                    continue  # more fragments still to come
                return self.zlib.decompress(pending).decode("utf-8")

            if kind == aiohttp.WSMsgType.TEXT:
                return message.data  # type: ignore

            # Any other frame type is presumed to signal a transport error; the
            # traceback is attached to the log only when DEBUG is enabled.
            ex = self.exception()
            trace = ex if self.logger.isEnabledFor(logging.DEBUG) else None
            self.logger.warning("encountered unexpected error: %s", ex,
                                exc_info=trace)
            raise errors.GatewayError(
                "Unexpected websocket exception from gateway") from ex
Пример #2
0
    async def _wait_for_hello(self) -> asyncio.Task[bool]:
        """Wait for the gateway's HELLO payload and start heartbeating.

        Returns
        -------
        asyncio.Task[builtins.bool]
            A task running ``self._heartbeat`` with the heartbeat interval
            taken from the HELLO payload.

        Raises
        ------
        hikari.errors.GatewayError
            If the first payload received is not a HELLO. The websocket is
            closed with ``ShardCloseCode.PROTOCOL_ERROR`` before raising.
        asyncio.CancelledError
            If the closing flag is already set once HELLO has arrived. The
            websocket is closed with ``ShardCloseCode.GOING_AWAY`` before
            raising.
        """
        # Expect HELLO.
        payload = await self._ws.receive_json()  # type: ignore[union-attr]
        if payload[_OP] != _HELLO:
            # Keep these two literals adjacent with NO comma between them: they
            # must form a single format string whose one %s is filled by
            # payload[_OP]. With a comma, the second literal would fill the %s
            # and logging would fail with "not all arguments converted".
            self._logger.debug(
                "expected HELLO opcode, received %s which makes no sense, closing with PROTOCOL ERROR "
                "(_run_once => raise and do not reconnect)",
                payload[_OP],
            )
            await self._ws.send_close(  # type: ignore[union-attr]
                code=errors.ShardCloseCode.PROTOCOL_ERROR,
                message=b"Expected HELLO op",
            )
            raise errors.GatewayError(
                f"Expected opcode {_HELLO}, but received {payload[_OP]}")

        if self._closing.is_set():
            self._logger.debug(
                "closing flag was set before we could handshake, disconnecting with GOING AWAY "
                "(_run_once => do not reconnect)")
            await self._ws.send_close(  # type: ignore[union-attr]
                code=errors.ShardCloseCode.GOING_AWAY,
                message=b"shard disconnecting",
            )
            raise asyncio.CancelledError(
                "closing flag was set before we could handshake")

        # NOTE(review): the division by 1000 suggests the payload value is in
        # milliseconds and _heartbeat takes seconds — confirm.
        heartbeat_interval = float(payload[_D]["heartbeat_interval"]) / 1_000.0
        heartbeat_task = asyncio.create_task(
            self._heartbeat(heartbeat_interval))
        return heartbeat_task
Пример #3
0
    def _handle_other_message(self, message: aiohttp.WSMessage,
                              /) -> typing.NoReturn:
        """Raise the gateway error that matches a non-data websocket message.

        This method never returns normally.

        Parameters
        ----------
        message : aiohttp.WSMessage
            The message that was received.

        Raises
        ------
        hikari.errors.GatewayServerClosedConnectionError
            For a CLOSE message; flagged as reconnectable when the close code
            is below 4000 or is contained in ``_RECONNECTABLE_CLOSE_CODES``.
        hikari.errors.GatewayConnectionError
            For a CLOSING or CLOSED message.
        hikari.errors.GatewayError
            For any other message type, chained from ``self.exception()``.
        """
        kind = message.type

        if kind == aiohttp.WSMsgType.CLOSE:
            code = int(message.data)
            why = message.extra
            self.logger.error("connection closed with code %s (%s)",
                              code, why)
            reconnectable = code < 4000 or code in _RECONNECTABLE_CLOSE_CODES
            raise errors.GatewayServerClosedConnectionError(
                why, code, reconnectable)

        if kind == aiohttp.WSMsgType.CLOSING or kind == aiohttp.WSMsgType.CLOSED:
            # Either the server shut us down, or (observed with some network
            # drivers on Windows) an EOF was injected after a disconnect.
            raise errors.GatewayConnectionError("Socket has closed")

        # Fall back to treating the message as a transport error; attach the
        # traceback to the log only when DEBUG logging is enabled.
        error = self.exception()
        trace = error if self.logger.isEnabledFor(logging.DEBUG) else None
        self.logger.warning("encountered unexpected error: %s", error,
                            exc_info=trace)
        raise errors.GatewayError(
            "Unexpected websocket exception from gateway") from error
Пример #4
0
    async def test_connect_when_gateway_error_after_connecting(
            self, http_settings, proxy_settings):
        """A GatewayError raised inside the connection body propagates, the
        socket is closed with UNEXPECTED_CONDITION, and the post-close sleep
        of 0.25 seconds happens."""
        class MockWS(hikari_test_helpers.AsyncContextManagerMock,
                     shard._GatewayTransport):
            closed = False
            sent_close = False
            send_close = mock.AsyncMock()

            def __init__(self):
                pass

        fake_ws = MockWS()
        fake_session = hikari_test_helpers.AsyncContextManagerMock()
        fake_session.ws_connect = mock.MagicMock(return_value=fake_ws)

        # pytest.raises is listed last so it is the innermost context and
        # swallows the error before the patches are undone.
        with mock.patch.object(asyncio, "sleep") as sleep, \
                mock.patch.object(aiohttp, "ClientSession", return_value=fake_session), \
                mock.patch.object(aiohttp, "TCPConnector"), \
                mock.patch.object(aiohttp, "ClientTimeout"), \
                pytest.raises(errors.GatewayError, match="some reason"):
            async with shard._GatewayTransport.connect(
                    http_settings=http_settings,
                    proxy_settings=proxy_settings,
                    logger=mock.Mock(),
                    url="https://some.url",
                    log_filterer=mock.Mock(),
            ):
                hikari_test_helpers.raiser(errors.GatewayError("some reason"))

        fake_ws.send_close.assert_awaited_once_with(
            code=errors.ShardCloseCode.UNEXPECTED_CONDITION,
            message=b"unexpected fatal client error :-(")

        sleep.assert_awaited_once_with(0.25)
        fake_session.assert_used_once()
        fake_ws.assert_used_once()
Пример #5
0
    async def _receive_and_check_complete_zlib_package(
            self, initial_data: bytes, timeout: typing.Optional[float],
            /) -> str:
        """Finish receiving a zlib-compressed payload and return it decoded.

        Starting from ``initial_data``, BINARY frames are appended until the
        buffer ends with ``_ZLIB_SUFFIX``; the buffer is then decompressed
        with ``self.zlib`` and decoded as UTF-8.

        Parameters
        ----------
        initial_data : builtins.bytes
            The data already received for this payload.
        timeout : typing.Optional[builtins.float]
            Forwarded unchanged to every `receive` call.

        Raises
        ------
        hikari.errors.GatewayError
            If a TEXT frame interrupts the payload, or (via
            `_handle_other_message`) for any other non-BINARY frame.
        """
        payload = bytearray(initial_data)

        while True:
            if payload.endswith(_ZLIB_SUFFIX):
                return self.zlib.decompress(payload).decode("utf-8")

            message = await self.receive(timeout)
            kind = message.type

            if kind == aiohttp.WSMsgType.BINARY:
                payload.extend(message.data)
            elif kind == aiohttp.WSMsgType.TEXT:
                raise errors.GatewayError(
                    "Unexpected message type received TEXT, expected BINARY")
            else:
                self._handle_other_message(message)
Пример #6
0
    async def _start_one_shard(
        self,
        activity: typing.Optional[presences.Activity],
        afk: bool,
        idle_since: typing.Optional[datetime.datetime],
        status: presences.Status,
        large_threshold: int,
        shard_id: int,
        shard_count: int,
        url: str,
    ) -> shard_impl.GatewayShardImpl:
        """Create one gateway shard and wait for it to start.

        The wait ends when either the shard's ``start()`` coroutine or
        ``self._closing_event.wait()`` completes first (via
        ``aio.first_completed``).

        Returns
        -------
        shard_impl.GatewayShardImpl
            The new shard, if it reports ``is_alive`` after the wait.

        Raises
        ------
        hikari.errors.GatewayError
            If the shard is not alive once the wait has finished.
        """
        fresh_shard = shard_impl.GatewayShardImpl(
            event_consumer=self._raw_event_consumer,
            http_settings=self._http_settings,
            initial_activity=activity,
            initial_is_afk=afk,
            initial_idle_since=idle_since,
            initial_status=status,
            large_threshold=large_threshold,
            intents=self._intents,
            proxy_settings=self._proxy_settings,
            shard_id=shard_id,
            shard_count=shard_count,
            token=self._token,
            url=url,
        )

        began = time.monotonic()
        await aio.first_completed(fresh_shard.start(),
                                  self._closing_event.wait())
        elapsed_ms = (time.monotonic() - began) * 1_000

        if not fresh_shard.is_alive:
            raise errors.GatewayError(
                f"shard {shard_id} shut down immediately when starting")

        _LOGGER.debug("Shard %s started successfully in %.1fms", shard_id,
                      elapsed_ms)
        return fresh_shard
Пример #7
0
 def error(self):
     """Return a new `GatewayError` whose message is ``"some reason"``."""
     gateway_error = errors.GatewayError("some reason")
     return gateway_error
Пример #8
0
    async def start(
        self,
        *,
        activity: typing.Optional[presences.Activity] = None,
        afk: bool = False,
        check_for_updates: bool = True,
        idle_since: typing.Optional[datetime.datetime] = None,
        ignore_session_start_limit: bool = False,
        large_threshold: int = 250,
        shard_ids: typing.Optional[typing.Set[int]] = None,
        shard_count: typing.Optional[int] = None,
        status: presences.Status = presences.Status.ONLINE,
    ) -> None:
        """Start the bot, wait for all shards to become ready, and then return.

        Shards are started in windows of
        ``session_start_limit.max_concurrency`` consecutive shard IDs. Before
        each window after the first, already-started shards are watched for
        about 5 seconds; if the bot is asked to close or a shard stops during
        that time, this method returns early without dispatching
        `StartedEvent`.

        Other Parameters
        ----------------
        activity : typing.Optional[hikari.presences.Activity]
            The initial activity to display in the bot user presence, or
            `builtins.None` (default) to not show any.
        afk : builtins.bool
            The initial AFK state to display in the bot user presence, or
            `builtins.False` (default) to not show any.
        check_for_updates : builtins.bool
            Defaults to `builtins.True`. If `builtins.True`, will check for
            newer versions of `hikari` on PyPI and notify if available.
        idle_since : typing.Optional[datetime.datetime]
            The `datetime.datetime` the user should be marked as being idle
            since, or `builtins.None` (default) to not show this.
        ignore_session_start_limit : builtins.bool
            Defaults to `builtins.False`. If `builtins.False`, then attempting
            to start more sessions than you are allowed in a 24 hour window
            will throw a `hikari.errors.GatewayError` rather than going ahead
            and hitting the IDENTIFY limit, which may result in your token
            being reset. Setting to `builtins.True` disables this behavior.
        large_threshold : builtins.int
            Threshold for members in a guild before it is treated as being
            "large" and no longer sending member details in the `GUILD CREATE`
            event. Defaults to `250`.
        shard_ids : typing.Optional[typing.Set[builtins.int]]
            The shard IDs to create shards for. If not `builtins.None`, then
            a non-`None` `shard_count` must ALSO be provided. Defaults to
            `builtins.None`, which starts every shard ID in
            ``range(shard_count)``. IDs outside ``range(shard_count)`` are
            never visited by the startup loop and so are not started.
        shard_count : typing.Optional[builtins.int]
            The number of shards to use in the entire distributed application.
            Defaults to `builtins.None`, in which case the Discord-recommended
            count from the gateway bot endpoint is used.
        status : hikari.presences.Status
            The initial status to show for the user presence on startup.
            Defaults to `hikari.presences.Status.ONLINE`.

        Raises
        ------
        builtins.TypeError
            If `shard_ids` is passed without `shard_count`.
        hikari.errors.GatewayError
            If more sessions would be started than remain in the current
            session start limit window and `ignore_session_start_limit` is
            `builtins.False`.
        """
        if shard_ids is not None and shard_count is None:
            raise TypeError(
                "Must pass shard_count if specifying shard_ids manually")

        self._validate_activity(activity)

        # Dispatch the update checker, the sharding requirements checker, and dispatch
        # the starting event together to save a little time on startup.
        start_time = time.monotonic()

        if check_for_updates:
            # NOTE(review): no reference to this task is kept; the asyncio
            # docs recommend holding one so the task cannot be
            # garbage-collected mid-execution.
            asyncio.create_task(
                ux.check_for_updates(self._http_settings,
                                     self._proxy_settings),
                name="check for package updates",
            )
        requirements_task = asyncio.create_task(
            self._rest.fetch_gateway_bot(),
            name="fetch gateway sharding settings")
        await self.dispatch(lifetime_events.StartingEvent(app=self))
        requirements = await requirements_task

        # Fill in whatever the caller did not specify from the gateway's
        # recommendations.
        if shard_count is None:
            shard_count = requirements.shard_count
        if shard_ids is None:
            shard_ids = set(range(shard_count))

        if requirements.session_start_limit.remaining < len(
                shard_ids) and not ignore_session_start_limit:
            _LOGGER.critical(
                "would have started %s session%s, but you only have %s session%s remaining until %s. Starting more "
                "sessions than you are allowed to start may result in your token being reset. To skip this message, "
                "use bot.run(..., ignore_session_start_limit=True) or bot.start(..., ignore_session_start_limit=True)",
                len(shard_ids),
                "s" if len(shard_ids) != 1 else "",
                requirements.session_start_limit.remaining,
                "s" if requirements.session_start_limit.remaining != 1 else "",
                requirements.session_start_limit.reset_at,
            )
            raise errors.GatewayError(
                "Attempted to start more sessions than were allowed in the given time-window"
            )

        _LOGGER.info(
            "planning to start %s session%s... you can start %s session%s before the next window starts at %s",
            len(shard_ids),
            "s" if len(shard_ids) != 1 else "",
            requirements.session_start_limit.remaining,
            "s" if requirements.session_start_limit.remaining != 1 else "",
            requirements.session_start_limit.reset_at,
        )

        # Walk [0, shard_count) in windows of max_concurrency consecutive IDs,
        # starting only the IDs of each window that were requested.
        for window_start in range(
                0, shard_count,
                requirements.session_start_limit.max_concurrency):
            window = [
                candidate_shard_id for candidate_shard_id in range(
                    window_start, window_start +
                    requirements.session_start_limit.max_concurrency)
                if candidate_shard_id in shard_ids
            ]

            if not window:
                continue
            if self._shards:
                # Between windows, watch the shards started so far.
                close_waiter = asyncio.create_task(self._closing_event.wait())
                shard_joiners = [
                    asyncio.ensure_future(s.join())
                    for s in self._shards.values()
                ]

                try:
                    # Attempt to wait for all started shards, for 5 seconds, along with the close
                    # waiter.
                    # If the close flag is set (i.e. user invoked bot.close), or one or more shards
                    # die in this time, we shut down immediately.
                    # If we time out, the joining tasks get discarded and we spin up the next
                    # block of shards, if applicable.
                    # NOTE(review): whether the joiner tasks are actually
                    # cancelled on timeout depends on aio.all_of — confirm.
                    _LOGGER.info(
                        "the next startup window is in 5 seconds, please wait..."
                    )
                    await aio.first_completed(
                        aio.all_of(*shard_joiners, timeout=5), close_waiter)

                    # Assumes aio.first_completed cancels whichever awaitable
                    # did not finish first: a non-cancelled close_waiter means
                    # the close event fired; otherwise the shards finished.
                    if not close_waiter.cancelled():
                        _LOGGER.info(
                            "requested to shut down during startup of shards")
                    else:
                        _LOGGER.critical(
                            "one or more shards shut down unexpectedly during bot startup"
                        )
                    # Early exit: StartedEvent is deliberately not dispatched.
                    return

                except asyncio.TimeoutError:
                    # If any shards stopped silently, we should close.
                    if any(not s.is_alive for s in self._shards.values()):
                        _LOGGER.info(
                            "one of the shards has been manually shut down (no error), will now shut down"
                        )
                        return
                    # new window starts.

                except Exception as ex:
                    _LOGGER.critical(
                        "an exception occurred in one of the started shards during bot startup: %r",
                        ex)
                    raise

            # Start every shard of this window concurrently. (`window` is
            # already restricted to shard_ids, so the filter below is a no-op.)
            started_shards = await aio.all_of(
                *(self._start_one_shard(
                    activity=activity,
                    afk=afk,
                    idle_since=idle_since,
                    status=status,
                    large_threshold=large_threshold,
                    shard_id=candidate_shard_id,
                    shard_count=shard_count,
                    url=requirements.url,
                ) for candidate_shard_id in window
                  if candidate_shard_id in shard_ids))

            for started_shard in started_shards:
                self._shards[started_shard.id] = started_shard

        await self.dispatch(lifetime_events.StartedEvent(app=self))

        _LOGGER.info("application started successfully in approx %.2f seconds",
                     time.monotonic() - start_time)
Пример #9
0
    async def connect(
        cls,
        *,
        http_settings: config.HTTPSettings,
        logger: logging.Logger,
        proxy_settings: config.ProxySettings,
        log_filterer: typing.Callable[[str], str],
        url: str,
    ) -> typing.AsyncGenerator[_GatewayTransport, None]:
        """Yield a single-use websocket transport connected to ``url``.

        A TCP connector is created with ``dns_cache=False`` and ``limit=1``
        and handed to a one-use aiohttp client session, on which the
        websocket is opened. The transport receives ``logger`` and
        ``log_filterer`` before it is yielded.

        When the consumer's block ends and the websocket is not yet closed:

        * if the block raised an `Exception`, the socket is closed with
          ``ShardCloseCode.UNEXPECTED_CONDITION``;
        * otherwise, unless the transport's ``_closing`` flag is set, it is
          closed with ``_RESUME_CLOSE_CODE``.

        A `hikari.errors.GatewayError` from the block is re-raised as-is; any
        other `Exception` is wrapped in one. ``aiohttp.ClientOSError``,
        ``aiohttp.ClientConnectionError`` and ``aiohttp.WSServerHandshakeError``
        raised while connecting or closing become
        `hikari.errors.GatewayConnectionError`. In every case the session
        contexts are exited and then the generator sleeps 0.25 seconds so
        aiohttp can finish closing SSL transports.
        """
        stack = contextlib.AsyncExitStack()

        try:
            tcp_connector = net.create_tcp_connector(
                http_settings, dns_cache=False, limit=1)
            session = await stack.enter_async_context(
                net.create_client_session(tcp_connector, True, http_settings,
                                          True, proxy_settings.trust_env, cls))

            transport = await stack.enter_async_context(
                session.ws_connect(
                    max_msg_size=0,
                    proxy=proxy_settings.url,
                    proxy_headers=proxy_settings.headers,
                    url=url,
                ))

            assert isinstance(transport, cls)

            failed = False
            try:
                transport.logger = logger
                # Kept so the token can be scrubbed from debug logs, which
                # makes it safe for users to attach logs to bug reports.
                # MyPy flags this assignment as a false positive.
                transport.log_filterer = log_filterer  # type: ignore

                yield transport
            except errors.GatewayError:
                failed = True
                raise
            except Exception as ex:
                failed = True
                raise errors.GatewayError(
                    f"Unexpected {type(ex).__name__}: {ex}") from ex
            finally:
                if not transport.closed:
                    if failed:
                        await transport.send_close(
                            code=errors.ShardCloseCode.UNEXPECTED_CONDITION,
                            message=b"unexpected fatal client error :-(",
                        )
                    elif not transport._closing:
                        # This special close code stops Discord from randomly
                        # invalidating the session (undocumented behaviour).
                        await transport.send_close(
                            code=_RESUME_CLOSE_CODE,
                            message=b"client is shutting down",
                        )
                else:
                    logger.log(ux.TRACE, "ws was already closed")

        except (aiohttp.ClientOSError, aiohttp.ClientConnectionError,
                aiohttp.WSServerHandshakeError) as ex:
            # ClientOSError is sometimes raised on Windows; a failed DNS
            # lookup usually surfaces as ClientConnectionError.
            raise errors.GatewayConnectionError(
                f"Failed to connect to Discord: {ex!r}") from ex

        finally:
            await stack.aclose()

            # Give aiohttp time to close SSL transports; see
            # https://github.com/aio-libs/aiohttp/issues/1925 and
            # https://docs.aiohttp.org/en/stable/client_advanced.html#graceful-shutdown
            await asyncio.sleep(0.25)