예제 #1
0
    def dispatch(
            self,
            event: event_manager.EventT_inv) -> asyncio.Future[typing.Any]:
        """Dispatch an event to its listeners and resolve any matching waiters.

        Every class in the event's MRO, up to and including
        `base_events.Event`, is looked up. Listeners and waiters registered for
        a base event type therefore also receive instances of its subclasses.

        Parameters
        ----------
        event : event_manager.EventT_inv
            The event to dispatch. Must be an instance of `base_events.Event`.

        Returns
        -------
        asyncio.Future[typing.Any]
            `asyncio.gather` over every listener invocation, or
            `aio.completed_future()` when no listener was registered.

        Raises
        ------
        TypeError
            If `event` is not an instance of `base_events.Event`.
        """
        if not isinstance(event, base_events.Event):
            raise TypeError(
                f"Events must be subclasses of {base_events.Event.__name__}, not {type(event).__name__}"
            )

        # We only need to iterate through the MRO until we hit Event, as
        # anything after that is random garbage we don't care about, as they do
        # not describe event types. This improves efficiency as well.
        mro = type(event).mro()

        tasks: typing.List[typing.Coroutine[None, typing.Any, None]] = []

        for cls in mro[:mro.index(base_events.Event) + 1]:
            if cls in self._listeners:
                for callback in self._listeners[cls]:
                    tasks.append(self._invoke_callback(callback, event))

            if cls in self._waiters:
                waiter_set = self._waiters[cls]
                # Iterate over a snapshot so entries can be removed in-loop.
                for predicate, future in tuple(waiter_set):
                    # A future that is already done (for example, cancelled)
                    # must not be resolved again: set_result/set_exception on a
                    # done future raises InvalidStateError, which would abort
                    # this whole dispatch. Such waiters are simply dropped.
                    if not future.done():
                        try:
                            if not predicate(event):
                                # Not a match; keep the waiter for later events.
                                continue
                        except Exception as ex:
                            # Surface predicate failures to the waiting side.
                            future.set_exception(ex)
                        else:
                            future.set_result(event)

                    waiter_set.remove((predicate, future))

        return asyncio.gather(*tasks) if tasks else aio.completed_future()
예제 #2
0
    def acquire(
        self, max_rate_limit: float = float("inf")) -> asyncio.Future[None]:
        """Acquire time on this rate limiter.

        !!! note
            You should afterwards invoke `RESTBucket.update_rate_limit` to
            update any rate limit information you are made aware of.

        Parameters
        ----------
        max_rate_limit : builtins.float
            The max number of seconds to backoff for when rate limited.

            The default is an infinite value, which will thus never time out.

            !!! warning
                NOTE(review): this value is currently unused. It is neither
                compared against anything in this method nor passed to
                `super().acquire()`, so this method itself never raises
                `RateLimitTooLongError` for an over-long reset time. Confirm
                whether the limit is meant to be enforced here.

        Returns
        -------
        asyncio.Future[builtins.None]
            A future that should be awaited immediately. Once the future completes,
            you are allowed to proceed with your operation. While the bucket
            is unknown, this is `aio.completed_future(None)`; otherwise it is
            whatever the parent `acquire()` returns.
        """
        return aio.completed_future(
            None) if self.is_unknown else super().acquire()
예제 #3
0
    def acquire(self) -> asyncio.Future[None]:
        """Acquire time on this rate limiter.

        While the bucket is unknown, an already-completed future is handed back
        straight away. Otherwise, acquisition is delegated to the parent rate
        limiter.

        !!! note
            You should afterwards invoke `RESTBucket.update_rate_limit` to
            update any rate limit information you are made aware of.

        Returns
        -------
        asyncio.Future[builtins.None]
            A future that should be awaited immediately. Once the future completes,
            you are allowed to proceed with your operation.
        """
        if not self.is_unknown:
            return super().acquire()
        # Unknown bucket: nothing to wait on, so return a resolved future.
        return aio.completed_future(None)
예제 #4
0
 async def test_non_default_result(self):
     """`completed_future` resolves to the value it was given."""
     future = aio.completed_future(...)
     assert future.result() is ...
예제 #5
0
 async def test_default_result_is_none(self):
     """Without an argument, `completed_future` resolves to ``None``."""
     result = aio.completed_future().result()
     assert result is None
예제 #6
0
 async def test_is_completed(self, args):
     """The future is already done when `completed_future` returns it."""
     assert aio.completed_future(*args).done()
예제 #7
0
class TestGatewayShardImpl:
    """Unit tests for `shard.GatewayShardImpl`."""

    @pytest.fixture()
    def client_session(self):
        stub = client_session_stub.ClientSessionStub()
        with mock.patch.object(aiohttp, "ClientSession", new=stub):
            yield stub

    @pytest.fixture(scope="module")
    def unslotted_client_type(self):
        # slots_=False presumably yields a variant of the class without
        # __slots__, so the tests below can assign mocks onto arbitrary
        # attributes of the instance.
        return hikari_test_helpers.mock_class_namespace(shard.GatewayShardImpl,
                                                        slots_=False)

    @pytest.fixture()
    def client(self, http_settings, proxy_settings, unslotted_client_type):
        return unslotted_client_type(
            url="wss://gateway.discord.gg",
            intents=intents.Intents.ALL,
            token="lol",
            event_consumer=mock.Mock(),
            http_settings=http_settings,
            proxy_settings=proxy_settings,
        )

    @pytest.mark.parametrize(
        ("compression", "expect"),
        [
            (None, f"v={shard._VERSION}&encoding=json"),
            ("payload_zlib_stream",
             f"v={shard._VERSION}&encoding=json&compress=zlib-stream"),
        ],
    )
    def test__init__sets_url_is_correct_json(self, compression, expect,
                                             http_settings, proxy_settings):
        g = shard.GatewayShardImpl(
            event_consumer=mock.Mock(),
            http_settings=http_settings,
            proxy_settings=proxy_settings,
            intents=intents.Intents.ALL,
            url="wss://gaytewhuy.discord.meh",
            data_format="json",
            compression=compression,
            token="12345",
        )

        assert g._url == f"wss://gaytewhuy.discord.meh?{expect}"

    def test_using_etf_is_unsupported(self, http_settings, proxy_settings):
        with pytest.raises(NotImplementedError,
                           match="Unsupported gateway data format: etf"):
            shard.GatewayShardImpl(
                event_consumer=mock.Mock(),
                http_settings=http_settings,
                proxy_settings=proxy_settings,
                token=mock.Mock(),
                url="wss://erlpack-is-broken-lol.discord.meh",
                intents=intents.Intents.ALL,
                data_format="etf",
                compression=True,
            )

    def test_heartbeat_latency_property(self, client):
        client._heartbeat_latency = 420
        assert client.heartbeat_latency == 420

    def test_id_property(self, client):
        client._shard_id = 101
        assert client.id == 101

    def test_intents_property(self, client):
        # Named so it does not shadow the module-level `intents` import.
        expected_intents = object()
        client._intents = expected_intents
        assert client.intents is expected_intents

    @pytest.mark.parametrize(
        ("run_task", "expected"),
        [
            (None, False),
            (asyncio.get_event_loop().create_future(), True),
            (aio.completed_future(), False),
        ],
    )
    def test_is_alive_property(self, run_task, expected, client):
        client._run_task = run_task
        assert client.is_alive is expected

    def test_shard_count_property(self, client):
        client._shard_count = 69
        assert client.shard_count == 69

    async def test_close_when_closing_set(self, client):
        client._closing = mock.Mock(is_set=mock.Mock(return_value=True))
        client._ws = mock.Mock()
        client._chunking_rate_limit = mock.Mock()
        client._total_rate_limit = mock.Mock()

        await client.close()

        client._closing.set.assert_not_called()
        client._ws.close.assert_not_called()
        client._chunking_rate_limit.close.assert_not_called()
        client._total_rate_limit.close.assert_not_called()

    async def test_close_when_closing_not_set(self, client):
        client._closing = mock.Mock(is_set=mock.Mock(return_value=False))
        client._ws = mock.Mock(close=mock.AsyncMock())
        client._chunking_rate_limit = mock.Mock()
        client._total_rate_limit = mock.Mock()

        await client.close()

        client._closing.set.assert_called_once_with()
        client._ws.close.assert_awaited_once_with(
            code=errors.ShardCloseCode.GOING_AWAY,
            message=b"shard disconnecting")
        client._chunking_rate_limit.close.assert_called_once_with()
        client._total_rate_limit.close.assert_called_once_with()

    async def test_close_when_closing_not_set_and_ws_is_None(self, client):
        client._closing = mock.Mock(is_set=mock.Mock(return_value=False))
        client._ws = None
        client._chunking_rate_limit = mock.Mock()
        client._total_rate_limit = mock.Mock()

        await client.close()

        client._closing.set.assert_called_once_with()
        client._chunking_rate_limit.close.assert_called_once_with()
        client._total_rate_limit.close.assert_called_once_with()

    async def test_when__user_id_is_None(self, client):
        client._handshake_completed = mock.Mock(wait=mock.AsyncMock())
        client._user_id = None
        with pytest.raises(RuntimeError):
            await client.get_user_id()

    async def test_when__user_id_is_not_None(self, client):
        client._handshake_completed = mock.Mock(wait=mock.AsyncMock())
        client._user_id = 123
        assert await client.get_user_id() == 123

    async def test_join(self, client):
        client._closed = mock.Mock(wait=mock.AsyncMock())

        await client.join()

        client._closed.wait.assert_awaited_once_with()

    async def test_request_guild_members_when_no_query_and_no_limit_and_GUILD_MEMBERS_not_enabled(
            self, client):
        client._intents = intents.Intents.GUILD_INTEGRATIONS

        with pytest.raises(errors.MissingIntentError):
            await client.request_guild_members(123, query="", limit=0)

    async def test_request_guild_members_when_presences_and_GUILD_PRESENCES_not_enabled(
            self, client):
        client._intents = intents.Intents.GUILD_INTEGRATIONS

        with pytest.raises(errors.MissingIntentError):
            await client.request_guild_members(123,
                                               query="test",
                                               limit=1,
                                               include_presences=True)

    async def test_request_guild_members_when_presences_false_and_GUILD_PRESENCES_not_enabled(
            self, client):
        client._intents = intents.Intents.GUILD_INTEGRATIONS
        client._ws = mock.Mock(send_json=mock.AsyncMock())

        await client.request_guild_members(123,
                                           query="test",
                                           limit=1,
                                           include_presences=False)

        client._ws.send_json.assert_awaited_once_with({
            "op": 8,
            "d": {
                "guild_id": "123",
                "query": "test",
                "presences": False,
                "limit": 1
            },
        })

    @pytest.mark.parametrize("kwargs", [{"query": "some query"}, {"limit": 1}])
    async def test_request_guild_members_when_specifiying_users_with_limit_or_query(
            self, client, kwargs):
        client._intents = intents.Intents.GUILD_INTEGRATIONS

        with pytest.raises(ValueError,
                           match="Cannot specify limit/query with users"):
            await client.request_guild_members(123, users=[], **kwargs)

    @pytest.mark.parametrize("limit", [-1, 101])
    async def test_request_guild_members_when_limit_under_0_or_over_100(
            self, client, limit):
        client._intents = intents.Intents.ALL

        with pytest.raises(
                ValueError,
                match="'limit' must be between 0 and 100, both inclusive"):
            await client.request_guild_members(123, limit=limit)

    async def test_request_guild_members_when_users_over_100(self, client):
        client._intents = intents.Intents.ALL

        with pytest.raises(ValueError,
                           match="'users' is limited to 100 users"):
            await client.request_guild_members(123, users=range(101))

    async def test_request_guild_members_when_nonce_over_32_chars(
            self, client):
        client._intents = intents.Intents.ALL

        with pytest.raises(
                ValueError,
                match="'nonce' can be no longer than 32 byte characters long."
        ):
            await client.request_guild_members(123, nonce="x" * 33)

    @pytest.mark.parametrize("include_presences", [True, False])
    async def test_request_guild_members(self, client, include_presences):
        client._intents = intents.Intents.ALL
        client._ws = mock.Mock(send_json=mock.AsyncMock())

        await client.request_guild_members(123,
                                           include_presences=include_presences)

        client._ws.send_json.assert_awaited_once_with({
            "op": 8,
            "d": {
                "guild_id": "123",
                "query": "",
                "presences": include_presences,
                "limit": 0
            },
        })

    async def test_start_when_already_running(self, client):
        client._run_task = object()

        with pytest.raises(
                RuntimeError,
                match=
                "Cannot run more than one instance of one shard concurrently"):
            await client.start()

    async def test_start_when_shard_closed_before_starting(self, client):
        client._run_task = None
        client._shard_id = 20
        client._run = mock.Mock()
        client._handshake_completed = mock.Mock(wait=mock.Mock())
        run_task = mock.Mock()
        waiter = mock.Mock()

        stack = contextlib.ExitStack()
        create_task = stack.enter_context(
            mock.patch.object(asyncio,
                              "create_task",
                              side_effect=[run_task, waiter]))
        wait = stack.enter_context(
            mock.patch.object(asyncio,
                              "wait",
                              return_value=([run_task], [waiter])))
        stack.enter_context(
            pytest.raises(
                asyncio.CancelledError,
                match="shard 20 was closed before it could start successfully")
        )

        with stack:
            await client.start()

        assert client._run_task is None

        assert create_task.call_count == 2
        # `has_call` is not a Mock assertion (it silently creates a child
        # mock), so a real assertion method must be used here.
        create_task.assert_has_calls([
            mock.call(client._run(), name="run shard 20"),
            mock.call(client._handshake_completed.wait(),
                      name="wait for shard 20 to start"),
        ])

        run_task.result.assert_called_once_with()
        waiter.cancel.assert_called_once_with()
        wait.assert_awaited_once_with((waiter, run_task),
                                      return_when=asyncio.FIRST_COMPLETED)

    async def test_start(self, client):
        client._run_task = None
        client._shard_id = 20
        client._run = mock.Mock()
        client._handshake_completed = mock.Mock(wait=mock.Mock())
        run_task = mock.Mock()
        waiter = mock.Mock()

        with mock.patch.object(asyncio,
                               "create_task",
                               side_effect=[run_task, waiter]) as create_task:
            with mock.patch.object(asyncio,
                                   "wait",
                                   return_value=([waiter], [run_task
                                                            ])) as wait:
                await client.start()

        assert client._run_task == run_task

        assert create_task.call_count == 2
        # `has_call` is not a Mock assertion (it silently creates a child
        # mock), so a real assertion method must be used here.
        create_task.assert_has_calls([
            mock.call(client._run(), name="run shard 20"),
            mock.call(client._handshake_completed.wait(),
                      name="wait for shard 20 to start"),
        ])

        run_task.result.assert_not_called()
        waiter.cancel.assert_called_once_with()
        wait.assert_awaited_once_with((waiter, run_task),
                                      return_when=asyncio.FIRST_COMPLETED)

    async def test_update_presence(self, client):
        presence_payload = object()
        client._ws = mock.Mock(send_json=mock.AsyncMock())
        client._serialize_and_store_presence_payload = mock.Mock(
            return_value=presence_payload)
        client._send_json = mock.AsyncMock()

        await client.update_presence(
            idle_since=datetime.datetime.now(),
            afk=True,
            status=presences.Status.IDLE,
            activity=None,
        )

        client._ws.send_json.assert_awaited_once_with({
            "op": 3,
            "d": presence_payload
        })

    @pytest.mark.parametrize("channel", [12345, None])
    @pytest.mark.parametrize("self_deaf", [True, False])
    @pytest.mark.parametrize("self_mute", [True, False])
    async def test_update_voice_state(self, client, channel, self_deaf,
                                      self_mute):
        client._ws = mock.Mock(send_json=mock.AsyncMock())
        payload = {
            "channel_id": str(channel) if channel is not None else None,
            "guild_id": "6969420",
            "deaf": self_deaf,
            "mute": self_mute,
        }

        await client.update_voice_state("6969420",
                                        channel,
                                        self_mute=self_mute,
                                        self_deaf=self_deaf)

        client._ws.send_json.assert_awaited_once_with({"op": 4, "d": payload})

    def test_dispatch_when_READY(self, client):
        client._seq = 0
        client._session_id = 0
        client._user_id = 0
        client._logger = mock.Mock()
        client._handshake_completed = mock.Mock()
        client._event_consumer = mock.Mock()

        pl = {
            "session_id": 101,
            "user": {
                "id": 123,
                # Must agree with the "hikari#5863" log assertion below.
                "username": "hikari",
                "discriminator": "5863"
            },
            "guilds": [
                {
                    "id": "123"
                },
                {
                    "id": "456"
                },
                {
                    "id": "789"
                },
            ],
            "v": 8,
        }

        client._dispatch(
            "READY",
            10,
            pl,
        )

        assert client._seq == 10
        assert client._session_id == 101
        assert client._user_id == 123
        client._logger.info.assert_called_once_with(
            "shard is ready: %s guilds, %s (%s), session %r on v%s gateway",
            3,
            "hikari#5863",
            123,
            101,
            8,
        )
        client._handshake_completed.set.assert_called_once_with()
        client._event_consumer.assert_called_once_with(
            client,
            "READY",
            pl,
        )

    def test__dispatch_when_RESUME(self, client):
        client._seq = 0
        client._session_id = 123
        client._logger = mock.Mock()
        client._handshake_completed = mock.Mock()
        client._event_consumer = mock.Mock()

        client._dispatch("RESUME", 10, {})

        assert client._seq == 10
        client._logger.info.assert_called_once_with(
            "shard has resumed [session:%s, seq:%s]", 123, 10)
        client._handshake_completed.set.assert_called_once_with()
        client._event_consumer.assert_called_once_with(client, "RESUME", {})

    def test__dispatch(self, client):
        client._logger = mock.Mock()
        client._handshake_completed = mock.Mock()
        client._event_consumer = mock.Mock()

        client._dispatch("EVENT NAME", 10, {"payload": None})

        client._logger.info.assert_not_called()
        client._handshake_completed.set.assert_not_called()
        client._event_consumer.assert_called_once_with(client, "EVENT NAME",
                                                       {"payload": None})

    async def test__identify(self, client):
        client._token = "token"
        client._intents = intents.Intents.ALL
        client._large_threshold = 123
        client._shard_id = 0
        client._shard_count = 1
        client._serialize_and_store_presence_payload = mock.Mock(
            return_value={"presence": "payload"})
        client._ws = mock.Mock(send_json=mock.AsyncMock())
        stack = contextlib.ExitStack()
        stack.enter_context(
            mock.patch.object(platform, "system", return_value="Potato PC"))
        stack.enter_context(
            mock.patch.object(platform, "architecture",
                              return_value=["ARM64"]))
        stack.enter_context(
            mock.patch.object(aiohttp, "__version__", new="v0.0.1"))
        stack.enter_context(
            mock.patch.object(_about, "__version__", new="v1.0.0"))

        with stack:
            await client._identify()

        expected_json = {
            "op": 2,
            "d": {
                "token": "token",
                "compress": False,
                "large_threshold": 123,
                "properties": {
                    "$os": "Potato PC ARM64",
                    "$browser": "aiohttp v0.0.1",
                    "$device": "hikari v1.0.0",
                },
                "shard": [0, 1],
                "intents": 32767,
                "presence": {
                    "presence": "payload"
                },
            },
        }
        client._ws.send_json.assert_awaited_once_with(expected_json)

    @hikari_test_helpers.timeout()
    async def test__heartbeat(self, client):
        client._last_heartbeat_sent = 5
        client._logger = mock.Mock()
        client._closing = mock.Mock(is_set=mock.Mock(return_value=False))
        client._closed = mock.Mock(is_set=mock.Mock(return_value=False))
        client._send_heartbeat = mock.AsyncMock()

        with mock.patch.object(time, "monotonic", return_value=10):
            with mock.patch.object(asyncio,
                                   "wait_for",
                                   side_effect=[asyncio.TimeoutError,
                                                None]) as wait_for:
                assert await client._heartbeat(20) is False

        wait_for.assert_awaited_with(client._closing.wait(), timeout=20)

    @hikari_test_helpers.timeout()
    async def test__heartbeat_when_zombie(self, client):
        client._last_heartbeat_sent = 10
        client._logger = mock.Mock()

        with mock.patch.object(time, "monotonic", return_value=5):
            with mock.patch.object(asyncio, "wait_for") as wait_for:
                assert await client._heartbeat(20) is True

        wait_for.assert_not_called()

    async def test__resume(self, client):
        client._token = "token"
        client._seq = 123
        client._session_id = 456
        client._ws = mock.Mock(send_json=mock.AsyncMock())

        await client._resume()

        expected_json = {
            "op": 6,
            "d": {
                "token": "token",
                "seq": 123,
                "session_id": 456
            },
        }
        client._ws.send_json.assert_awaited_once_with(expected_json)

    @pytest.mark.skip("TODO")
    async def test__run(self, client):
        ...

    @pytest.mark.skip("TODO")
    async def test__run_once(self, client):
        ...

    async def test__send_heartbeat(self, client):
        client._ws = mock.Mock(send_json=mock.AsyncMock())
        client._last_heartbeat_sent = 0
        client._seq = 10

        with mock.patch.object(time, "monotonic", return_value=200):
            await client._send_heartbeat()

        client._ws.send_json.assert_awaited_once_with({"op": 1, "d": 10})
        assert client._last_heartbeat_sent == 200

    async def test__send_heartbeat_ack(self, client):
        client._ws = mock.Mock(send_json=mock.AsyncMock())

        await client._send_heartbeat_ack()

        client._ws.send_json.assert_awaited_once_with({"op": 11, "d": None})

    def test__serialize_activity_when_activity_is_None(self, client):
        assert client._serialize_activity(None) is None

    def test__serialize_activity_when_activity_is_not_None(self, client):
        activity = mock.Mock(type="0", url="https://some.url")
        activity.name = "some name"  # This has to be set separately because otherwise it is taken as the mock's name
        assert client._serialize_activity(activity) == {
            "name": "some name",
            "type": 0,
            "url": "https://some.url"
        }

    @pytest.mark.parametrize("idle_since", [datetime.datetime.now(), None])
    @pytest.mark.parametrize("afk", [True, False])
    @pytest.mark.parametrize(
        "status",
        [
            presences.Status.DO_NOT_DISTURB, presences.Status.IDLE,
            presences.Status.ONLINE, presences.Status.OFFLINE
        ],
    )
    @pytest.mark.parametrize("activity",
                             [presences.Activity(name="foo"), None])
    def test__serialize_and_store_presence_payload_when_all_args_undefined(
            self, client, idle_since, afk, status, activity):
        client._activity = activity
        client._idle_since = idle_since
        client._is_afk = afk
        client._status = status

        actual_result = client._serialize_and_store_presence_payload()

        if activity is not undefined.UNDEFINED and activity is not None:
            expected_activity = {
                "name": activity.name,
                "type": activity.type,
                "url": activity.url,
            }
        else:
            expected_activity = None

        if status == presences.Status.OFFLINE:
            expected_status = "invisible"
        else:
            expected_status = status.value

        expected_result = {
            "game":
            expected_activity,
            "since":
            int(idle_since.timestamp() *
                1_000) if idle_since is not None else None,
            "afk":
            afk if afk is not undefined.UNDEFINED else False,
            "status":
            expected_status,
        }

        # Without this comparison the test computed the expectation but
        # never checked it.
        assert actual_result == expected_result
예제 #8
0
                if not future.done():
                    try:
                        if predicate and not predicate(event):
                            continue
                    except Exception as ex:
                        future.set_exception(ex)
                    else:
                        future.set_result(event)

                waiter_set.remove(waiter)

            if not waiter_set:
                del self._waiters[cls]
                self._increment_waiter_group_count(cls, -1)

        return asyncio.gather(*tasks) if tasks else aio.completed_future()

    def stream(
        self,
        event_type: typing.Type[base_events.EventT],
        /,
        timeout: typing.Union[float, int, None],
        limit: typing.Optional[int] = None,
    ) -> event_manager_.EventStream[base_events.EventT]:
        """Build an `EventStream` over events of ``event_type``.

        Parameters
        ----------
        event_type : typing.Type[base_events.EventT]
            The event class to stream (positional-only). It is validated with
            ``self._check_event`` before the stream is created.
        timeout : typing.Union[float, int, None]
            Forwarded unchanged to `EventStream`.
        limit : typing.Optional[int]
            Forwarded unchanged to `EventStream`; defaults to ``None``.

        Returns
        -------
        event_manager_.EventStream[base_events.EventT]
            A new stream bound to this manager.
        """
        self._check_event(event_type, 1)
        new_stream = EventStream(self, event_type, timeout=timeout, limit=limit)
        return new_stream

    async def wait_for(
        self,
        event_type: typing.Type[base_events.EventT],
        /,