def dispatch(self, event: event_manager.EventT_inv) -> asyncio.Future[typing.Any]:
    """Dispatch an event to its registered listeners and pending waiters.

    Every class in the event's MRO up to and including `base_events.Event`
    is looked up in both the listener registry and the waiter registry.

    Parameters
    ----------
    event : EventT_inv
        The event instance to dispatch. It must be an instance of
        `base_events.Event`.

    Returns
    -------
    asyncio.Future[typing.Any]
        A future gathering every listener invocation, or an already
        completed future when no listener matched any of the event's types.

    Raises
    ------
    builtins.TypeError
        If `event` is not an instance of `base_events.Event`.
    """
    if not isinstance(event, base_events.Event):
        raise TypeError(
            f"Events must be subclasses of {base_events.Event.__name__}, not {type(event).__name__}"
        )

    # We only need to iterate through the MRO until we hit Event, as
    # anything after that is random garbage we don't care about, as they do
    # not describe event types. This improves efficiency as well.
    mro = type(event).mro()

    tasks: typing.List[typing.Coroutine[None, typing.Any, None]] = []

    for cls in mro[: mro.index(base_events.Event) + 1]:
        if cls in self._listeners:
            for callback in self._listeners[cls]:
                tasks.append(self._invoke_callback(callback, event))

        if cls in self._waiters:
            waiter_set = self._waiters[cls]
            # Iterate over a snapshot so entries can be removed while looping.
            for predicate, future in tuple(waiter_set):
                # The future may already be done (for example, if it was
                # cancelled). Calling set_result/set_exception on a done future
                # raises asyncio.InvalidStateError, which would abort the whole
                # dispatch. So only resolve live futures, but still drop the
                # stale entry below.
                if not future.done():
                    try:
                        if not predicate(event):
                            # Predicate rejected this event; keep waiting.
                            continue
                    except Exception as ex:
                        # A failing predicate is reported to the waiter itself.
                        future.set_exception(ex)
                    else:
                        future.set_result(event)

                waiter_set.remove((predicate, future))

    return asyncio.gather(*tasks) if tasks else aio.completed_future()
def acquire(self, max_rate_limit: float = float("inf")) -> asyncio.Future[None]:
    """Acquire time on this rate limiter.

    !!! note
        You should afterwards invoke `RESTBucket.update_rate_limit` to update
        any rate limit information you are made aware of.

    Parameters
    ----------
    max_rate_limit : builtins.float
        The max number of seconds to backoff for when rate limited. Anything
        greater than this will instead raise an error. The default is an
        infinite value, which will thus never time out.

    Returns
    -------
    asyncio.Future[builtins.None]
        A future that should be awaited immediately. Once the future
        completes, you are allowed to proceed with your operation. While the
        bucket is unknown, this is an already-completed future.

        If the reset-after time for the bucket is greater than
        `max_rate_limit`, then this will contain `RateLimitTooLongError` as an
        exception.
    """
    # NOTE(review): `max_rate_limit` is neither checked here nor forwarded to
    # `super().acquire()`, so the limit documented above is not enforced by
    # this method as written -- confirm where it is meant to be applied.
    if self.is_unknown:
        # No waiting is done for an unknown bucket; hand back a finished future.
        return aio.completed_future(None)
    return super().acquire()
def acquire(self) -> asyncio.Future[None]:
    """Acquire time on this rate limiter.

    !!! note
        You should afterwards invoke `RESTBucket.update_rate_limit` to update
        any rate limit information you are made aware of.

    Returns
    -------
    asyncio.Future[builtins.None]
        A future that should be awaited immediately. Once the future
        completes, you are allowed to proceed with your operation. While the
        bucket is unknown, this is an already-completed future; otherwise the
        parent rate limiter's `acquire` future is returned.
    """
    if self.is_unknown:
        # No waiting is done for an unknown bucket; hand back a finished future.
        return aio.completed_future(None)
    return super().acquire()
async def test_non_default_result(self):
    # An explicitly supplied result (here the Ellipsis singleton) must be
    # returned unchanged by the completed future.
    future = aio.completed_future(...)
    assert future.result() is ...
async def test_default_result_is_none(self):
    # Without an argument, the completed future's result defaults to None.
    future = aio.completed_future()
    assert future.result() is None
async def test_is_completed(self, args):
    # For the given positional arguments, the returned future must already be done.
    assert aio.completed_future(*args).done()
class TestGatewayShardImpl:
    """Unit tests for `shard.GatewayShardImpl`."""

    @pytest.fixture()
    def client_session(self):
        stub = client_session_stub.ClientSessionStub()
        with mock.patch.object(aiohttp, "ClientSession", new=stub):
            yield stub

    @pytest.fixture(scope="module")
    def unslotted_client_type(self):
        return hikari_test_helpers.mock_class_namespace(shard.GatewayShardImpl, slots_=False)

    @pytest.fixture()
    def client(self, http_settings, proxy_settings, unslotted_client_type):
        return unslotted_client_type(
            url="wss://gateway.discord.gg",
            intents=intents.Intents.ALL,
            token="lol",
            event_consumer=mock.Mock(),
            http_settings=http_settings,
            proxy_settings=proxy_settings,
        )

    @pytest.mark.parametrize(
        ("compression", "expect"),
        [
            (None, f"v={shard._VERSION}&encoding=json"),
            ("payload_zlib_stream", f"v={shard._VERSION}&encoding=json&compress=zlib-stream"),
        ],
    )
    def test__init__sets_url_is_correct_json(self, compression, expect, http_settings, proxy_settings):
        g = shard.GatewayShardImpl(
            event_consumer=mock.Mock(),
            http_settings=http_settings,
            proxy_settings=proxy_settings,
            intents=intents.Intents.ALL,
            url="wss://gaytewhuy.discord.meh",
            data_format="json",
            compression=compression,
            token="12345",
        )

        assert g._url == f"wss://gaytewhuy.discord.meh?{expect}"

    def test_using_etf_is_unsupported(self, http_settings, proxy_settings):
        with pytest.raises(NotImplementedError, match="Unsupported gateway data format: etf"):
            shard.GatewayShardImpl(
                event_consumer=mock.Mock(),
                http_settings=http_settings,
                proxy_settings=proxy_settings,
                token=mock.Mock(),
                url="wss://erlpack-is-broken-lol.discord.meh",
                intents=intents.Intents.ALL,
                data_format="etf",
                compression=True,
            )

    def test_heartbeat_latency_property(self, client):
        client._heartbeat_latency = 420
        assert client.heartbeat_latency == 420

    def test_id_property(self, client):
        client._shard_id = 101
        assert client.id == 101

    def test_intents_property(self, client):
        # Named so it does not shadow the imported `intents` module.
        mock_intents = object()
        client._intents = mock_intents
        assert client.intents is mock_intents

    @pytest.mark.parametrize(
        ("run_task", "expected"),
        [
            (None, False),
            (asyncio.get_event_loop().create_future(), True),
            (aio.completed_future(), False),
        ],
    )
    def test_is_alive_property(self, run_task, expected, client):
        client._run_task = run_task
        assert client.is_alive is expected

    def test_shard_count_property(self, client):
        client._shard_count = 69
        assert client.shard_count == 69

    async def test_close_when_closing_set(self, client):
        client._closing = mock.Mock(is_set=mock.Mock(return_value=True))
        client._ws = mock.Mock()
        client._chunking_rate_limit = mock.Mock()
        client._total_rate_limit = mock.Mock()

        await client.close()

        client._closing.set.assert_not_called()
        client._ws.close.assert_not_called()
        client._chunking_rate_limit.close.assert_not_called()
        client._total_rate_limit.close.assert_not_called()

    async def test_close_when_closing_not_set(self, client):
        client._closing = mock.Mock(is_set=mock.Mock(return_value=False))
        client._ws = mock.Mock(close=mock.AsyncMock())
        client._chunking_rate_limit = mock.Mock()
        client._total_rate_limit = mock.Mock()

        await client.close()

        client._closing.set.assert_called_once_with()
        client._ws.close.assert_awaited_once_with(
            code=errors.ShardCloseCode.GOING_AWAY, message=b"shard disconnecting"
        )
        client._chunking_rate_limit.close.assert_called_once_with()
        client._total_rate_limit.close.assert_called_once_with()

    async def test_close_when_closing_not_set_and_ws_is_None(self, client):
        client._closing = mock.Mock(is_set=mock.Mock(return_value=False))
        client._ws = None
        client._chunking_rate_limit = mock.Mock()
        client._total_rate_limit = mock.Mock()

        await client.close()

        client._closing.set.assert_called_once_with()
        client._chunking_rate_limit.close.assert_called_once_with()
        client._total_rate_limit.close.assert_called_once_with()

    async def test_when__user_id_is_None(self, client):
        client._handshake_completed = mock.Mock(wait=mock.AsyncMock())
        client._user_id = None
        with pytest.raises(RuntimeError):
            assert await client.get_user_id()

    async def test_when__user_id_is_not_None(self, client):
        client._handshake_completed = mock.Mock(wait=mock.AsyncMock())
        client._user_id = 123
        assert await client.get_user_id() == 123

    async def test_join(self, client):
        client._closed = mock.Mock(wait=mock.AsyncMock())

        await client.join()

        client._closed.wait.assert_awaited_once_with()

    async def test_request_guild_members_when_no_query_and_no_limit_and_GUILD_MEMBERS_not_enabled(self, client):
        client._intents = intents.Intents.GUILD_INTEGRATIONS

        with pytest.raises(errors.MissingIntentError):
            await client.request_guild_members(123, query="", limit=0)

    async def test_request_guild_members_when_presences_and_GUILD_PRESENCES_not_enabled(self, client):
        client._intents = intents.Intents.GUILD_INTEGRATIONS

        with pytest.raises(errors.MissingIntentError):
            await client.request_guild_members(123, query="test", limit=1, include_presences=True)

    async def test_request_guild_members_when_presences_false_and_GUILD_PRESENCES_not_enabled(self, client):
        client._intents = intents.Intents.GUILD_INTEGRATIONS
        client._ws = mock.Mock(send_json=mock.AsyncMock())

        await client.request_guild_members(123, query="test", limit=1, include_presences=False)

        client._ws.send_json.assert_awaited_once_with(
            {
                "op": 8,
                "d": {"guild_id": "123", "query": "test", "presences": False, "limit": 1},
            }
        )

    @pytest.mark.parametrize("kwargs", [{"query": "some query"}, {"limit": 1}])
    async def test_request_guild_members_when_specifiying_users_with_limit_or_query(self, client, kwargs):
        client._intents = intents.Intents.GUILD_INTEGRATIONS

        with pytest.raises(ValueError, match="Cannot specify limit/query with users"):
            await client.request_guild_members(123, users=[], **kwargs)

    @pytest.mark.parametrize("limit", [-1, 101])
    async def test_request_guild_members_when_limit_under_0_or_over_100(self, client, limit):
        client._intents = intents.Intents.ALL

        with pytest.raises(ValueError, match="'limit' must be between 0 and 100, both inclusive"):
            await client.request_guild_members(123, limit=limit)

    async def test_request_guild_members_when_users_over_100(self, client):
        client._intents = intents.Intents.ALL

        with pytest.raises(ValueError, match="'users' is limited to 100 users"):
            await client.request_guild_members(123, users=range(101))

    async def test_request_guild_members_when_nonce_over_32_chars(self, client):
        client._intents = intents.Intents.ALL

        with pytest.raises(ValueError, match="'nonce' can be no longer than 32 byte characters long."):
            await client.request_guild_members(123, nonce="x" * 33)

    @pytest.mark.parametrize("include_presences", [True, False])
    async def test_request_guild_members(self, client, include_presences):
        client._intents = intents.Intents.ALL
        client._ws = mock.Mock(send_json=mock.AsyncMock())

        await client.request_guild_members(123, include_presences=include_presences)

        client._ws.send_json.assert_awaited_once_with(
            {
                "op": 8,
                "d": {"guild_id": "123", "query": "", "presences": include_presences, "limit": 0},
            }
        )

    async def test_start_when_already_running(self, client):
        client._run_task = object()

        with pytest.raises(RuntimeError, match="Cannot run more than one instance of one shard concurrently"):
            await client.start()

    async def test_start_when_shard_closed_before_starting(self, client):
        client._run_task = None
        client._shard_id = 20
        client._run = mock.Mock()
        client._handshake_completed = mock.Mock(wait=mock.Mock())
        run_task = mock.Mock()
        waiter = mock.Mock()

        stack = contextlib.ExitStack()
        create_task = stack.enter_context(
            mock.patch.object(asyncio, "create_task", side_effect=[run_task, waiter])
        )
        wait = stack.enter_context(mock.patch.object(asyncio, "wait", return_value=([run_task], [waiter])))
        stack.enter_context(
            pytest.raises(asyncio.CancelledError, match="shard 20 was closed before it could start successfully")
        )

        with stack:
            await client.start()

        assert client._run_task is None
        assert create_task.call_count == 2
        # `has_call` is not a Mock assertion method (it silently creates a
        # child mock), so use `assert_has_calls` to actually check the calls.
        # Order matches the side_effect order above: run task first, then waiter.
        create_task.assert_has_calls(
            [
                mock.call(client._run(), name="run shard 20"),
                mock.call(client._handshake_completed.wait(), name="wait for shard 20 to start"),
            ]
        )
        run_task.result.assert_called_once_with()
        waiter.cancel.assert_called_once_with()
        wait.assert_awaited_once_with((waiter, run_task), return_when=asyncio.FIRST_COMPLETED)

    async def test_start(self, client):
        client._run_task = None
        client._shard_id = 20
        client._run = mock.Mock()
        client._handshake_completed = mock.Mock(wait=mock.Mock())
        run_task = mock.Mock()
        waiter = mock.Mock()

        with mock.patch.object(asyncio, "create_task", side_effect=[run_task, waiter]) as create_task:
            with mock.patch.object(asyncio, "wait", return_value=([waiter], [run_task])) as wait:
                await client.start()

        assert client._run_task == run_task
        assert create_task.call_count == 2
        # `has_call` is not a Mock assertion method (it silently creates a
        # child mock), so use `assert_has_calls` to actually check the calls.
        create_task.assert_has_calls(
            [
                mock.call(client._run(), name="run shard 20"),
                mock.call(client._handshake_completed.wait(), name="wait for shard 20 to start"),
            ]
        )
        run_task.result.assert_not_called()
        waiter.cancel.assert_called_once_with()
        wait.assert_awaited_once_with((waiter, run_task), return_when=asyncio.FIRST_COMPLETED)

    async def test_update_presence(self, client):
        presence_payload = object()
        client._ws = mock.Mock(send_json=mock.AsyncMock())
        client._serialize_and_store_presence_payload = mock.Mock(return_value=presence_payload)
        client._send_json = mock.AsyncMock()

        await client.update_presence(
            idle_since=datetime.datetime.now(),
            afk=True,
            status=presences.Status.IDLE,
            activity=None,
        )

        client._ws.send_json.assert_awaited_once_with({"op": 3, "d": presence_payload})

    @pytest.mark.parametrize("channel", [12345, None])
    @pytest.mark.parametrize("self_deaf", [True, False])
    @pytest.mark.parametrize("self_mute", [True, False])
    async def test_update_voice_state(self, client, channel, self_deaf, self_mute):
        client._ws = mock.Mock(send_json=mock.AsyncMock())
        payload = {
            "channel_id": str(channel) if channel is not None else None,
            "guild_id": "6969420",
            "deaf": self_deaf,
            "mute": self_mute,
        }

        await client.update_voice_state("6969420", channel, self_mute=self_mute, self_deaf=self_deaf)

        client._ws.send_json.assert_awaited_once_with({"op": 4, "d": payload})

    def test_dispatch_when_READY(self, client):
        client._seq = 0
        client._session_id = 0
        client._user_id = 0
        client._logger = mock.Mock()
        client._handshake_completed = mock.Mock()
        client._event_consumer = mock.Mock()

        pl = {
            "session_id": 101,
            # The username must agree with "hikari#5863" in the log assertion below.
            "user": {"id": 123, "username": "hikari", "discriminator": "5863"},
            "guilds": [
                {"id": "123"},
                {"id": "456"},
                {"id": "789"},
            ],
            "v": 8,
        }

        client._dispatch(
            "READY",
            10,
            pl,
        )

        assert client._seq == 10
        assert client._session_id == 101
        assert client._user_id == 123
        client._logger.info.assert_called_once_with(
            "shard is ready: %s guilds, %s (%s), session %r on v%s gateway",
            3,
            "hikari#5863",
            123,
            101,
            8,
        )
        client._handshake_completed.set.assert_called_once_with()
        client._event_consumer.assert_called_once_with(
            client,
            "READY",
            pl,
        )

    def test__dipatch_when_RESUME(self, client):
        client._seq = 0
        client._session_id = 123
        client._logger = mock.Mock()
        client._handshake_completed = mock.Mock()
        client._event_consumer = mock.Mock()

        client._dispatch("RESUME", 10, {})

        assert client._seq == 10
        client._logger.info.assert_called_once_with("shard has resumed [session:%s, seq:%s]", 123, 10)
        client._handshake_completed.set.assert_called_once_with()
        client._event_consumer.assert_called_once_with(client, "RESUME", {})

    def test__dipatch(self, client):
        client._logger = mock.Mock()
        client._handshake_completed = mock.Mock()
        client._event_consumer = mock.Mock()

        client._dispatch("EVENT NAME", 10, {"payload": None})

        client._logger.info.assert_not_called()
        client._handshake_completed.set.assert_not_called()
        client._event_consumer.assert_called_once_with(client, "EVENT NAME", {"payload": None})

    async def test__identify(self, client):
        client._token = "token"
        client._intents = intents.Intents.ALL
        client._large_threshold = 123
        client._shard_id = 0
        client._shard_count = 1
        client._serialize_and_store_presence_payload = mock.Mock(return_value={"presence": "payload"})
        client._ws = mock.Mock(send_json=mock.AsyncMock())
        stack = contextlib.ExitStack()
        stack.enter_context(mock.patch.object(platform, "system", return_value="Potato PC"))
        stack.enter_context(mock.patch.object(platform, "architecture", return_value=["ARM64"]))
        stack.enter_context(mock.patch.object(aiohttp, "__version__", new="v0.0.1"))
        stack.enter_context(mock.patch.object(_about, "__version__", new="v1.0.0"))

        with stack:
            await client._identify()

        expected_json = {
            "op": 2,
            "d": {
                "token": "token",
                "compress": False,
                "large_threshold": 123,
                "properties": {
                    "$os": "Potato PC ARM64",
                    "$browser": "aiohttp v0.0.1",
                    "$device": "hikari v1.0.0",
                },
                "shard": [0, 1],
                "intents": 32767,
                "presence": {"presence": "payload"},
            },
        }
        client._ws.send_json.assert_awaited_once_with(expected_json)

    @hikari_test_helpers.timeout()
    async def test__heartbeat(self, client):
        client._last_heartbeat_sent = 5
        client._logger = mock.Mock()
        client._closing = mock.Mock(is_set=mock.Mock(return_value=False))
        client._closed = mock.Mock(is_set=mock.Mock(return_value=False))
        client._send_heartbeat = mock.AsyncMock()

        with mock.patch.object(time, "monotonic", return_value=10):
            with mock.patch.object(asyncio, "wait_for", side_effect=[asyncio.TimeoutError, None]) as wait_for:
                assert await client._heartbeat(20) is False

        wait_for.assert_awaited_with(client._closing.wait(), timeout=20)

    @hikari_test_helpers.timeout()
    async def test__heartbeat_when_zombie(self, client):
        client._last_heartbeat_sent = 10
        client._logger = mock.Mock()

        with mock.patch.object(time, "monotonic", return_value=5):
            with mock.patch.object(asyncio, "wait_for") as wait_for:
                assert await client._heartbeat(20) is True

        wait_for.assert_not_called()

    async def test__resume(self, client):
        client._token = "token"
        client._seq = 123
        client._session_id = 456
        client._ws = mock.Mock(send_json=mock.AsyncMock())

        await client._resume()

        expected_json = {
            "op": 6,
            "d": {"token": "token", "seq": 123, "session_id": 456},
        }
        client._ws.send_json.assert_awaited_once_with(expected_json)

    @pytest.mark.skip("TODO")
    async def test__run(self, client):
        ...

    @pytest.mark.skip("TODO")
    async def test__run_once(self, client):
        ...

    async def test__send_heartbeat(self, client):
        client._ws = mock.Mock(send_json=mock.AsyncMock())
        client._last_heartbeat_sent = 0
        client._seq = 10

        with mock.patch.object(time, "monotonic", return_value=200):
            await client._send_heartbeat()

        client._ws.send_json.assert_awaited_once_with({"op": 1, "d": 10})
        assert client._last_heartbeat_sent == 200

    async def test__send_heartbeat_ack(self, client):
        client._ws = mock.Mock(send_json=mock.AsyncMock())

        await client._send_heartbeat_ack()

        client._ws.send_json.assert_awaited_once_with({"op": 11, "d": None})

    def test__serialize_activity_when_activity_is_None(self, client):
        assert client._serialize_activity(None) is None

    def test__serialize_activity_when_activity_is_not_None(self, client):
        activity = mock.Mock(type="0", url="https://some.url")
        # Set separately: passing `name=` to the Mock constructor would set the
        # mock's own repr name instead of a `name` attribute.
        activity.name = "some name"
        assert client._serialize_activity(activity) == {"name": "some name", "type": 0, "url": "https://some.url"}

    @pytest.mark.parametrize("idle_since", [datetime.datetime.now(), None])
    @pytest.mark.parametrize("afk", [True, False])
    @pytest.mark.parametrize(
        "status",
        [
            presences.Status.DO_NOT_DISTURB,
            presences.Status.IDLE,
            presences.Status.ONLINE,
            presences.Status.OFFLINE,
        ],
    )
    @pytest.mark.parametrize("activity", [presences.Activity(name="foo"), None])
    def test__serialize_and_store_presence_payload_when_all_args_undefined(
        self, client, idle_since, afk, status, activity
    ):
        client._activity = activity
        client._idle_since = idle_since
        client._is_afk = afk
        client._status = status

        actual_result = client._serialize_and_store_presence_payload()

        if activity is not undefined.UNDEFINED and activity is not None:
            expected_activity = {
                "name": activity.name,
                "type": activity.type,
                "url": activity.url,
            }
        else:
            expected_activity = None

        if status == presences.Status.OFFLINE:
            expected_status = "invisible"
        else:
            expected_status = status.value

        expected_result = {
            "game": expected_activity,
            "since": int(idle_since.timestamp() * 1_000) if idle_since is not None else None,
            "afk": afk if afk is not undefined.UNDEFINED else False,
            "status": expected_status,
        }

        # Without this the test computed the expectation but never checked it.
        assert actual_result == expected_result
if not future.done(): try: if predicate and not predicate(event): continue except Exception as ex: future.set_exception(ex) else: future.set_result(event) waiter_set.remove(waiter) if not waiter_set: del self._waiters[cls] self._increment_waiter_group_count(cls, -1) return asyncio.gather(*tasks) if tasks else aio.completed_future() def stream( self, event_type: typing.Type[base_events.EventT], /, timeout: typing.Union[float, int, None], limit: typing.Optional[int] = None, ) -> event_manager_.EventStream[base_events.EventT]: self._check_event(event_type, 1) return EventStream(self, event_type, timeout=timeout, limit=limit) async def wait_for( self, event_type: typing.Type[base_events.EventT], /,