async def _heartbeat(self, heartbeat_interval: float) -> bool:
    # Return True if zombied or should reconnect, False if time to die forever.

    # Prevent immediately zombie-ing.
    self._last_heartbeat_ack_received = date.monotonic()
    self._logger.debug("starting heartbeat with interval %ss", heartbeat_interval)

    while not self._closing.is_set() and not self._closed.is_set():
        if self._last_heartbeat_ack_received <= self._last_heartbeat_sent:
            # Gateway is zombie, close and request reconnect.
            self._logger.warning(
                "connection has not received a HEARTBEAT_ACK for approx %.1fs and is being disconnected, "
                "expect a reconnect shortly",
                date.monotonic() - self._last_heartbeat_ack_received,
            )
            return True

        self._logger.log(
            ux.TRACE, "preparing to send HEARTBEAT [s:%s, interval:%ss]", self._seq, heartbeat_interval
        )
        await self._send_heartbeat()

        try:
            await asyncio.wait_for(self._closing.wait(), timeout=heartbeat_interval)
            # We are closing.
            break
        except asyncio.TimeoutError:
            # We should continue.
            continue

    self._logger.debug("heartbeat task is finishing now")
    return False
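# Illustrative timing sketch (not from the library source): the zombie check above relies on the
# invariant that a healthy gateway always ACKs a heartbeat before the next one is sent, so the
# ack timestamp only trails the send timestamp when an ACK was missed. Assuming a typical
# 41.25s heartbeat interval:
#
#   t=0.00   _send_heartbeat()      -> _last_heartbeat_sent         = 0.00
#   t=0.12   HEARTBEAT_ACK arrives  -> _last_heartbeat_ack_received = 0.12   (ack > sent, healthy)
#   t=41.25  loop wakes, still healthy, sends again -> _last_heartbeat_sent = 41.25
#   t=82.50  loop wakes, no ACK arrived in between  -> ack (0.12) <= sent (41.25), returns True (zombie)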
async def throttle(self) -> None:
    """Perform the throttling rate limiter logic.

    Iterates repeatedly while the queue is not empty, adhering to any
    rate limits that occur in the meantime.

    !!! note
        You should usually not need to invoke this directly, but if you do,
        ensure to call it using `asyncio.create_task`, and store the task
        immediately in `throttle_task`.

        When this coroutine function completes, it will set the
        `throttle_task` to `builtins.None`. This means you can check if
        throttling is occurring by checking if `throttle_task` is not
        `builtins.None`.
    """
    _LOGGER.debug(
        "you are being rate limited on bucket %s, backing off for %ss",
        self.name,
        self.get_time_until_reset(date.monotonic()),
    )

    while self.queue:
        sleep_for = self.get_time_until_reset(date.monotonic())
        await asyncio.sleep(sleep_for)

        while self.remaining > 0 and self.queue:
            self.drip()
            self.queue.pop(0).set_result(None)

    self.throttle_task = None
def acquire(self) -> asyncio.Future[typing.Any]:
    """Acquire time on this rate limiter.

    Returns
    -------
    asyncio.Future[typing.Any]
        A future that should be immediately awaited. Once the await
        completes, you are able to proceed with the operation that is
        under this rate limit.
    """
    loop = asyncio.get_running_loop()
    future = loop.create_future()

    # If we are rate limited, delegate invoking this to the throttler and spin it up
    # if it hasn't started. Likewise, if the throttle task is still running, we should
    # delegate releasing the future to the throttler task so that we still process
    # first-come-first-serve.
    if self.throttle_task is not None or self.is_rate_limited(date.monotonic()):
        self.queue.append(future)
        if self.throttle_task is None:
            self.throttle_task = loop.create_task(self.throttle())
    else:
        self.drip()
        future.set_result(None)

    return future
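# Minimal usage sketch (not from the library source): `limiter` stands for any rate limiter
# exposing the `acquire`/`throttle` pair above, and `do_request` is a hypothetical coroutine
# performing the rate limited operation. Per the docstring, the future returned by `acquire`
# is awaited immediately; the throttle task is spun up internally when a limit is hit.
async def _call_with_limit(limiter, do_request):
    await limiter.acquire()  # resolves immediately, or once the throttler drips this future
    return await do_request()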
def update_rate_limits(
    self,
    compiled_route: routes.CompiledRoute,
    bucket_header: str,
    remaining_header: int,
    limit_header: int,
    date_header: datetime.datetime,
    reset_at_header: datetime.datetime,
) -> None:
    """Update the rate limits for a bucket using info from a response.

    Parameters
    ----------
    compiled_route : hikari.utilities.routes.CompiledRoute
        The compiled route to get the bucket for.
    bucket_header : builtins.str
        The `X-RateLimit-Bucket` header that was provided in the response.
    remaining_header : builtins.int
        The `X-RateLimit-Remaining` header cast to a `builtins.int`.
    limit_header : builtins.int
        The `X-RateLimit-Limit` header cast to a `builtins.int`.
    date_header : datetime.datetime
        The `Date` header value as a `datetime.datetime`.
    reset_at_header : datetime.datetime
        The `X-RateLimit-Reset` header value as a `datetime.datetime`.
    """
    self.routes_to_hashes[compiled_route.route] = bucket_header
    real_bucket_hash = compiled_route.create_real_bucket_hash(bucket_header)

    reset_after = (reset_at_header - date_header).total_seconds()
    reset_at_monotonic = date.monotonic() + reset_after

    if real_bucket_hash in self.real_hashes_to_buckets:
        bucket = self.real_hashes_to_buckets[real_bucket_hash]
        _LOGGER.debug(
            "updating %s with bucket %s [reset-after:%ss, limit:%s, remaining:%s]",
            compiled_route,
            real_bucket_hash,
            reset_after,
            limit_header,
            remaining_header,
        )
    else:
        bucket = RESTBucket(real_bucket_hash, compiled_route)
        self.real_hashes_to_buckets[real_bucket_hash] = bucket
        _LOGGER.debug(
            "remapping %s with bucket %s [reset-after:%ss, limit:%s, remaining:%s]",
            compiled_route,
            real_bucket_hash,
            reset_after,
            limit_header,
            remaining_header,
        )

    bucket.update_rate_limit(remaining_header, limit_header, reset_at_monotonic)
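# Hypothetical caller sketch (not from the library source): shows how the header values named in
# the docstring could be converted into the argument types this method expects. The helper name
# `_rate_limit_args_from_headers` is illustrative only; parsing uses only the standard library.
import datetime
import email.utils
import typing


def _rate_limit_args_from_headers(headers: typing.Mapping[str, str]) -> typing.Dict[str, typing.Any]:
    return {
        "bucket_header": headers["X-RateLimit-Bucket"],
        "remaining_header": int(headers["X-RateLimit-Remaining"]),
        "limit_header": int(headers["X-RateLimit-Limit"]),
        # The `Date` header is an RFC 7231 HTTP date; `X-RateLimit-Reset` is a UNIX epoch in seconds.
        "date_header": email.utils.parsedate_to_datetime(headers["Date"]),
        "reset_at_header": datetime.datetime.fromtimestamp(
            float(headers["X-RateLimit-Reset"]), tz=datetime.timezone.utc
        ),
    }

# A caller could then invoke the method above as:
#   manager.update_rate_limits(compiled_route, **_rate_limit_args_from_headers(headers))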
def do_gc_pass(self, expire_after: float) -> None:
    """Perform a single garbage collection pass.

    This will assess any routes stored in the internal mappings of this
    object and remove any that are deemed to be inactive or dead in order
    to save memory.

    If the removed routes are used again in the future, they will be
    re-cached automatically.

    Parameters
    ----------
    expire_after : builtins.float
        Time after which the last `reset_at` was hit for a bucket to
        remove it. Defaults to `reset_at` + 20 seconds. Higher values will
        retain unneeded ratelimit info for longer, but may produce more
        effective ratelimiting logic as a result.

    !!! warning
        You generally have no need to invoke this directly. Use
        `RESTBucketManager.start` and `RESTBucketManager.close` to control
        this instead.
    """
    buckets_to_purge = []

    now = date.monotonic()

    # We have three main states that a bucket can be in:
    # 1. active - the bucket is active and is not at risk of deallocation
    # 2. survival - the bucket is inactive but is still fresh enough to be kept alive.
    # 3. death - the bucket has been inactive for too long.
    active = 0

    # Discover and purge.
    bucket_pairs = self.real_hashes_to_buckets.items()

    for full_hash, bucket in bucket_pairs:
        if bucket.is_empty and bucket.reset_at + expire_after < now:
            # If it is still running a throttle and is in memory, it will remain in memory
            # but we will not know about it.
            buckets_to_purge.append(full_hash)

        if bucket.reset_at >= now:
            active += 1

    dead = len(buckets_to_purge)
    total = len(bucket_pairs)
    survival = total - active - dead

    for full_hash in buckets_to_purge:
        self.real_hashes_to_buckets[full_hash].close()
        del self.real_hashes_to_buckets[full_hash]

    _LOGGER.log(ux.TRACE, "purged %s stale buckets, %s remain in survival, %s active", dead, survival, active)
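# Rough sketch (assumed, not taken from the library source) of the kind of periodic loop that
# `RESTBucketManager.start` would drive to invoke `do_gc_pass`. The name `_gc_loop` and both
# default values are illustrative only.
import asyncio


async def _gc_loop(manager, poll_period: float = 20.0, expire_after: float = 10.0) -> None:
    while True:
        await asyncio.sleep(poll_period)
        manager.do_gc_pass(expire_after)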
async def _start_one_shard(
    self,
    activity: typing.Optional[presences.Activity],
    afk: bool,
    idle_since: typing.Optional[datetime.datetime],
    status: presences.Status,
    large_threshold: int,
    shard_id: int,
    shard_count: int,
    url: str,
) -> shard_impl.GatewayShardImpl:
    new_shard = shard_impl.GatewayShardImpl(
        event_consumer=self._raw_event_consumer,
        http_settings=self._http_settings,
        initial_activity=activity,
        initial_is_afk=afk,
        initial_idle_since=idle_since,
        initial_status=status,
        large_threshold=large_threshold,
        intents=self._intents,
        proxy_settings=self._proxy_settings,
        shard_id=shard_id,
        shard_count=shard_count,
        token=self._token,
        url=url,
    )

    start = date.monotonic()
    await aio.first_completed(new_shard.start(), self._closing_event.wait())
    end = date.monotonic()

    if new_shard.is_alive:
        _LOGGER.debug("Shard %s started successfully in %.1fms", shard_id, (end - start) * 1_000)
        return new_shard

    raise errors.GatewayError(f"Shard {shard_id} shut down immediately when starting")
async def _poll_events(self) -> typing.Optional[bool]:
    payload = await self._ws.receive_json(timeout=5)  # type: ignore[union-attr]

    op = payload[_OP]  # opcode int
    d = payload[_D]  # data/payload. Usually a dict or a bool for INVALID_SESSION

    if op == _DISPATCH:
        t = payload[_T]  # event name str
        s = payload[_S]  # seq int
        self._logger.log(ux.TRACE, "dispatching %s with seq %s", t, s)
        self._dispatch(t, s, d)
    elif op == _HEARTBEAT:
        await self._send_heartbeat_ack()
        self._logger.log(ux.TRACE, "sent HEARTBEAT")
    elif op == _HEARTBEAT_ACK:
        now = date.monotonic()
        self._last_heartbeat_ack_received = now
        self._heartbeat_latency = now - self._last_heartbeat_sent
        self._logger.log(ux.TRACE, "received HEARTBEAT ACK in %.1fms", self._heartbeat_latency * 1_000)
    elif op == _RECONNECT:
        # We should be able to resume...
        self._logger.info("received instruction to reconnect, will resume existing session")
        return True
    elif op == _INVALID_SESSION:
        # We can resume if the payload was `true`.
        if not d:
            self._logger.info("received invalid session, will need to start a new session")
            self._seq = None
            self._session_id = None
        else:
            self._logger.info("received invalid session, will resume existing session")
        return True
    else:
        self._logger.log(ux.TRACE, "unknown opcode %s received, it will be ignored...", op)

    return None
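# For reference (numeric values from the public Discord gateway documentation, not from this
# module's source): the opcode constants dispatched on above correspond to these opcodes. The
# constant names belong to this module; only the numbers and payload shapes are asserted here.
#
#   _DISPATCH        = 0   # an event; carries `t` (event name), `s` (sequence) and `d` (data)
#   _HEARTBEAT       = 1   # server asks the client to heartbeat immediately
#   _RECONNECT       = 7   # server asks the client to disconnect and resume
#   _INVALID_SESSION = 9   # `d` is a bool: true => session may be resumed, false => re-identify
#   _HEARTBEAT_ACK   = 11  # acknowledges a heartbeat the client sent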
def update_rate_limit(self, remaining: int, limit: int, reset_at: float) -> None:
    """Amend the rate limit.

    Parameters
    ----------
    remaining : builtins.int
        The calls remaining in this time window.
    limit : builtins.int
        The total calls allowed in this time window.
    reset_at : builtins.float
        The epoch at which to reset the limit.

    !!! note
        The `reset_at` epoch is expected to be a `date.monotonic_timestamp`
        monotonic epoch, rather than a `time.time` date-based epoch.
    """
    self.remaining = remaining
    self.limit = limit
    self.reset_at = reset_at
    self.period = max(0.0, self.reset_at - date.monotonic())
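# Worked example (illustrative values): if a response reports 0 calls remaining out of 5, with the
# reset 2.5 seconds away, the manager converts that offset to the monotonic clock (exactly as
# `update_rate_limits` above does) before calling this method, and `period` then reflects the
# remaining window length.
#
#   reset_at = date.monotonic() + 2.5
#   bucket.update_rate_limit(remaining=0, limit=5, reset_at=reset_at)
#   # => bucket.remaining == 0, bucket.limit == 5, bucket.period ≈ 2.5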
async def start(
    self,
    *,
    activity: typing.Optional[presences.Activity] = None,
    afk: bool = False,
    check_for_updates: bool = True,
    idle_since: typing.Optional[datetime.datetime] = None,
    ignore_session_start_limit: bool = False,
    large_threshold: int = 250,
    shard_ids: typing.Optional[typing.Set[int]] = None,
    shard_count: typing.Optional[int] = None,
    status: presences.Status = presences.Status.ONLINE,
) -> None:
    """Start the bot, wait for all shards to become ready, and then return.

    Other Parameters
    ----------------
    activity : typing.Optional[hikari.presences.Activity]
        The initial activity to display in the bot user presence, or
        `builtins.None` (default) to not show any.
    afk : builtins.bool
        The initial AFK state to display in the bot user presence, or
        `builtins.False` (default) to not show any.
    check_for_updates : builtins.bool
        Defaults to `builtins.True`. If `builtins.True`, will check for
        newer versions of `hikari` on PyPI and notify if available.
    idle_since : typing.Optional[datetime.datetime]
        The `datetime.datetime` the user should be marked as being idle
        since, or `builtins.None` (default) to not show this.
    ignore_session_start_limit : builtins.bool
        Defaults to `builtins.False`. If `builtins.False`, then attempting
        to start more sessions than you are allowed in a 24 hour window
        will throw a `hikari.errors.GatewayError` rather than going ahead
        and hitting the IDENTIFY limit, which may result in your token
        being reset. Setting to `builtins.True` disables this behavior.
    large_threshold : builtins.int
        Threshold for members in a guild before it is treated as being
        "large" and no longer sending member details in the `GUILD CREATE`
        event. Defaults to `250`.
    shard_ids : typing.Optional[typing.Set[builtins.int]]
        The shard IDs to create shards for. If not `builtins.None`, then
        a non-`None` `shard_count` must ALSO be provided. Defaults to
        `builtins.None`, which means the Discord-recommended count is used
        for your application instead.
    shard_count : typing.Optional[builtins.int]
        The number of shards to use in the entire distributed application.
        Defaults to `builtins.None` which results in the count being
        determined dynamically on startup.
    status : hikari.presences.Status
        The initial status to show for the user presence on startup.
        Defaults to `hikari.presences.Status.ONLINE`.
    """
    if shard_ids is not None and shard_count is None:
        raise TypeError("Must pass shard_count if specifying shard_ids manually")

    # Dispatch the update checker, the sharding requirements checker, and dispatch
    # the starting event together to save a little time on startup.
    start_time = date.monotonic()

    if check_for_updates:
        asyncio.create_task(
            ux.check_for_updates(self._http_settings, self._proxy_settings),
            name="check for package updates",
        )
    requirements_task = asyncio.create_task(
        self._rest.fetch_gateway_bot(), name="fetch gateway sharding settings"
    )
    await self.dispatch(lifetime_events.StartingEvent(app=self))
    requirements = await requirements_task

    if shard_count is None:
        shard_count = requirements.shard_count
    if shard_ids is None:
        shard_ids = set(range(shard_count))

    if requirements.session_start_limit.remaining < len(shard_ids) and not ignore_session_start_limit:
        _LOGGER.critical(
            "would have started %s session%s, but you only have %s session%s remaining until %s. Starting more "
            "sessions than you are allowed to start may result in your token being reset. To skip this message, "
            "use bot.run(..., ignore_session_start_limit=True) or bot.start(..., ignore_session_start_limit=True)",
            len(shard_ids),
            "s" if len(shard_ids) != 1 else "",
            requirements.session_start_limit.remaining,
            "s" if requirements.session_start_limit.remaining != 1 else "",
            requirements.session_start_limit.reset_at,
        )
        raise errors.GatewayError("Attempted to start more sessions than were allowed in the given time-window")

    _LOGGER.info(
        "planning to start %s session%s... you can start %s session%s before the next window starts at %s",
        len(shard_ids),
        "s" if len(shard_ids) != 1 else "",
        requirements.session_start_limit.remaining,
        "s" if requirements.session_start_limit.remaining != 1 else "",
        requirements.session_start_limit.reset_at,
    )

    for window_start in range(0, shard_count, requirements.session_start_limit.max_concurrency):
        window = [
            candidate_shard_id
            for candidate_shard_id in range(
                window_start, window_start + requirements.session_start_limit.max_concurrency
            )
            if candidate_shard_id in shard_ids
        ]
        if not window:
            continue

        if self._shards:
            close_waiter = asyncio.create_task(self._closing_event.wait())
            shard_joiners = [asyncio.ensure_future(s.join()) for s in self._shards.values()]

            try:
                # Attempt to wait for all started shards, for 5 seconds, along with the close
                # waiter.
                # If the close flag is set (i.e. user invoked bot.close), or one or more shards
                # die in this time, we shut down immediately.
                # If we time out, the joining tasks get discarded and we spin up the next
                # block of shards, if applicable.
                _LOGGER.info("the next startup window is in 5 seconds, please wait...")
                await aio.first_completed(aio.all_of(*shard_joiners, timeout=5), close_waiter)

                if not close_waiter.cancelled():
                    _LOGGER.info("requested to shut down during startup of shards")
                else:
                    _LOGGER.critical("one or more shards shut down unexpectedly during bot startup")
                return

            except asyncio.TimeoutError:
                # If any shards stopped silently, we should close.
                if any(not s.is_alive for s in self._shards.values()):
                    _LOGGER.info("one of the shards has been manually shut down (no error), will now shut down")
                    return

                # New window starts.

            except Exception as ex:
                _LOGGER.critical("an exception occurred in one of the started shards during bot startup: %r", ex)
                raise

        started_shards = await aio.all_of(
            *(
                self._start_one_shard(
                    activity=activity,
                    afk=afk,
                    idle_since=idle_since,
                    status=status,
                    large_threshold=large_threshold,
                    shard_id=candidate_shard_id,
                    shard_count=shard_count,
                    url=requirements.url,
                )
                for candidate_shard_id in window
                if candidate_shard_id in shard_ids
            )
        )

        for started_shard in started_shards:
            self._shards[started_shard.id] = started_shard

    await self.dispatch(lifetime_events.StartedEvent(app=self))

    _LOGGER.info("application started successfully in approx %.0f seconds", date.monotonic() - start_time)
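# Minimal usage sketch (assumptions: `bot` is an instance of the application class exposing the
# `start` method above, and the call runs inside an event loop). The keyword arguments are the
# ones documented in the docstring; manual shard assignment needs both `shard_ids` and
# `shard_count`, else a TypeError is raised.
#
#   await bot.start(
#       status=presences.Status.IDLE,
#       activity=presences.Activity(name="starting up..."),
#       shard_ids={0, 1},
#       shard_count=2,
#   )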
async def _send_heartbeat(self) -> None:
    await self._ws.send_json({_OP: _HEARTBEAT, _D: self._seq})  # type: ignore[union-attr]
    self._last_heartbeat_sent = date.monotonic()
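# Wire-format note (from the public Discord gateway documentation, not this module's source):
# assuming _OP and _D map to the literal keys "op" and "d", the frame sent above looks like
#
#   {"op": 1, "d": 42}   # 42 = last dispatch sequence number seen, or null before any DISPATCH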
async def _run_once(self) -> bool:
    self._handshake_completed.clear()
    dispatch_disconnect = False

    exit_stack = contextlib.AsyncExitStack()

    self._ws = await exit_stack.enter_async_context(
        _V6GatewayTransport.connect(
            http_settings=self._http_settings,
            log_filterer=_log_filterer(self._token),
            logger=self._logger,
            proxy_settings=self._proxy_settings,
            url=self._url,
        )
    )

    try:
        # Dispatch CONNECTED synthetic event.
        self._event_consumer(self, "CONNECTED", {})
        dispatch_disconnect = True

        heartbeat_task = await self._wait_for_hello()

        try:
            if self._seq is not None:
                self._logger.debug("resuming session %s", self._session_id)
                await self._resume()
            else:
                self._logger.debug("identifying with new session")
                await self._identify()

            if self._closing.is_set():
                self._logger.debug(
                    "closing flag was set during handshake, disconnecting with GOING AWAY "
                    "(_run_once => do not reconnect)"
                )
                await self._ws.send_close(  # type: ignore[union-attr]
                    code=errors.ShardCloseCode.GOING_AWAY, message=b"shard disconnecting"
                )
                return False

            # Event polling.
            while not self._closing.is_set() and not heartbeat_task.done() and not heartbeat_task.cancelled():
                try:
                    result = await self._poll_events()
                    if result is not None:
                        return result
                except asyncio.TimeoutError:
                    # We should check if the shard is still alive and then poll again after.
                    pass

            # If the heartbeat died due to an error, it should be raised here.
            # This will currently allow us to try to resume if that happens.
            # We return True if zombied.
            if await heartbeat_task:
                now = date.monotonic()
                self._logger.error(
                    "connection is a zombie, last heartbeat sent %.2fs ago",
                    now - self._last_heartbeat_sent,
                )
                self._logger.debug("will attempt to reconnect (_run_once => reconnect)")
                return True

            self._logger.debug(
                "shard has requested graceful termination, so will not attempt to reconnect "
                "(_run_once => do not reconnect)"
            )
            await self._ws.send_close(  # type: ignore[union-attr]
                code=errors.ShardCloseCode.GOING_AWAY,
                message=b"shard disconnecting",
            )
            return False

        finally:
            heartbeat_task.cancel()

    finally:
        ws = self._ws
        self._ws = None
        await exit_stack.aclose()

        if dispatch_disconnect:
            # If we managed to connect, we must always send the DISCONNECT event
            # afterwards.
            self._event_consumer(self, "DISCONNECTED", {})

        # Check if we made the socket close or handled it. If we didn't, we should always try to
        # reconnect, as aiohttp is probably closing it internally without telling us properly.
        if not ws.sent_close:  # type: ignore[union-attr]
            return True
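# Summary of the contract implemented above (derived from the surrounding code, phrased here for
# clarity): `_run_once` returns True when the caller should reconnect (zombied heartbeat, a
# RECONNECT/INVALID_SESSION instruction surfaced through `_poll_events`, or aiohttp closing the
# socket without `sent_close`), and False when the shard asked for a graceful, permanent shutdown.
# `_run` below consumes this flag to decide whether to loop again.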
async def _run(self) -> None:
    self._closed.clear()
    self._closing.clear()

    last_started_at = -float("inf")

    backoff = rate_limits.ExponentialBackOff(
        base=_BACKOFF_BASE,
        maximum=_BACKOFF_CAP,
        initial_increment=_BACKOFF_INCREMENT_START,
    )

    try:
        while not self._closing.is_set() and not self._closed.is_set():
            if date.monotonic() - last_started_at < _BACKOFF_WINDOW:
                time = next(backoff)
                self._logger.info("backing off reconnecting for %.2fs", time)

                try:
                    await asyncio.wait_for(self._closing.wait(), timeout=time)
                    # We were told to close.
                    return
                except asyncio.TimeoutError:
                    # We are going to run once.
                    pass

            try:
                last_started_at = date.monotonic()
                should_restart = await self._run_once()

                if not should_restart:
                    self._logger.info("shard has disconnected and shut down normally")
                    return

            except errors.GatewayConnectionError as ex:
                self._logger.error(
                    "failed to communicate with server, reason was: %s. Will retry shortly",
                    ex.__cause__,
                )

            except errors.GatewayServerClosedConnectionError as ex:
                if not ex.can_reconnect:
                    raise

                self._logger.info(
                    "server has closed connection, will reconnect if possible [code:%s, reason:%s]",
                    ex.code,
                    ex.reason,
                )

                # We don't want to back off from this. If Discord keeps closing the connection, it is their issue.
                # If we back off here, we'll find a mass outage will prevent shards from becoming healthy on
                # reconnect in large sharded bots for a very long period of time.
                backoff.reset()

            except errors.GatewayError as ex:
                self._logger.error("encountered generic gateway error", exc_info=ex)
                raise

            except Exception as ex:
                self._logger.error("encountered some unhandled error", exc_info=ex)
                raise
    finally:
        self._closing.set()
        self._closed.set()
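# Simplified stand-in (an assumption, not the library implementation) for the backoff iterator
# consumed via `next(backoff)` above: each restart attempt inside the backoff window waits roughly
# base ** n seconds, capped at a maximum. The real `rate_limits.ExponentialBackOff` additionally
# takes an initial increment and can be reset, as the code above does after server-side closes.
def _simple_exponential_backoff(base: float = 2.0, maximum: float = 64.0):
    n = 0
    while True:
        yield min(base ** n, maximum)
        n += 1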