예제 #1
0
    def dispatch(
            self,
            event: event_dispatcher.EventT_inv) -> asyncio.Future[typing.Any]:
        """Dispatch an event to every matching listener and waiter.

        Parameters
        ----------
        event
            The event instance to dispatch. It must be an instance of
            `base_events.Event`.

        Returns
        -------
        asyncio.Future[typing.Any]
            The result of `asyncio.gather` over every invoked listener
            callback, or an already-completed future if no listener matched.

        Raises
        ------
        TypeError
            If `event` is not an instance of `base_events.Event`.
        """
        if not isinstance(event, base_events.Event):
            raise TypeError(
                f"Events must be subclasses of {base_events.Event.__name__}, not {type(event).__name__}"
            )

        # We only need to iterate through the MRO until we hit Event, as
        # anything after that does not describe an event type. This also
        # avoids pointless lookups.
        mro = type(event).mro()

        tasks: typing.List[typing.Coroutine[None, typing.Any, None]] = []

        for cls in mro[:mro.index(base_events.Event) + 1]:
            if cls in self._listeners:
                for callback in self._listeners[cls]:
                    tasks.append(self._invoke_callback(callback, event))

            if cls in self._waiters:
                waiter_set = self._waiters[cls]
                # Iterate over a snapshot, as entries are removed while looping.
                for waiter in tuple(waiter_set):
                    predicate, future = waiter

                    # A waiter may already be finished (for example cancelled
                    # by whoever was awaiting it). Setting a result or an
                    # exception on a done future raises InvalidStateError,
                    # which would abort the dispatch, so only resolve pending
                    # futures. Finished waiters are still removed below.
                    if not future.done():
                        try:
                            if not predicate(event):
                                # Not the event this waiter wants; keep it.
                                continue
                        except Exception as ex:
                            future.set_exception(ex)
                        else:
                            future.set_result(event)

                    waiter_set.remove(waiter)

        return asyncio.gather(*tasks) if tasks else aio.completed_future()
예제 #2
0
파일: buckets.py 프로젝트: tomxey/hikari
    def acquire(self) -> asyncio.Future[None]:
        """Acquire time on this rate limiter.

        If this bucket is flagged as unknown (`self.is_unknown`), an
        already-completed future is returned; otherwise the call is delegated
        to the parent class' `acquire`.

        !!! note
            You should afterwards invoke `RESTBucket.update_rate_limit` to
            update any rate limit information you are made aware of.

        Returns
        -------
        asyncio.Future[builtins.None]
            A future that should be awaited immediately. Once the future
            completes, you are allowed to proceed with your operation.
        """
        if self.is_unknown:
            return aio.completed_future(None)

        return super().acquire()
예제 #3
0
파일: test_shard.py 프로젝트: tomxey/hikari
class TestGatewayShardImpl:
    """Unit tests for `shard.GatewayShardImpl`."""

    @pytest.fixture()
    def client_session(self):
        """Patch `aiohttp.ClientSession` with a stub for the test's duration."""
        stub = client_session_stub.ClientSessionStub()
        with mock.patch.object(aiohttp, "ClientSession", new=stub):
            yield stub

    @pytest.fixture(scope="module")
    def unslotted_client_type(self):
        """Return an unslotted variant of the shard class, built once per module."""
        return hikari_test_helpers.unslot_class(shard.GatewayShardImpl)

    @pytest.fixture()
    def client(self, http_settings, proxy_settings, unslotted_client_type):
        """Return a fresh shard instance whose attributes tests may overwrite."""
        return unslotted_client_type(
            url="wss://gateway.discord.gg",
            token="lol",
            event_consumer=mock.Mock(),
            http_settings=http_settings,
            proxy_settings=proxy_settings,
        )

    @pytest.mark.parametrize(
        ("compression", "expect"),
        [
            (None, "v=6&encoding=json"),
            ("payload_zlib_stream", "v=6&encoding=json&compress=zlib-stream"),
        ],
    )
    def test__init__sets_url_is_correct_json(self, compression, expect, http_settings, proxy_settings):
        g = shard.GatewayShardImpl(
            event_consumer=mock.Mock(),
            http_settings=http_settings,
            proxy_settings=proxy_settings,
            url="wss://gaytewhuy.discord.meh",
            data_format="json",
            compression=compression,
            token="12345",
        )

        assert g._url == f"wss://gaytewhuy.discord.meh?{expect}"

    def test_using_etf_is_unsupported(self, http_settings, proxy_settings):
        with pytest.raises(NotImplementedError, match="Unsupported gateway data format: etf"):
            shard.GatewayShardImpl(
                event_consumer=mock.Mock(),
                http_settings=http_settings,
                proxy_settings=proxy_settings,
                token=mock.Mock(),
                url="wss://erlpack-is-broken-lol.discord.meh",
                data_format="etf",
                compression=True,
            )

    def test_heartbeat_latency_property(self, client):
        client._heartbeat_latency = 420
        assert client.heartbeat_latency == 420

    def test_id_property(self, client):
        client._shard_id = 101
        assert client.id == 101

    def test_intents_property(self, client):
        # Named so that it does not shadow the `intents` module used by other
        # tests in this class.
        expected_intents = object()
        client._intents = expected_intents
        assert client.intents is expected_intents

    # NOTE(review): the pending future below is created on the event loop at
    # collection time (decorator arguments are evaluated when the class body
    # runs) - confirm this is acceptable on the targeted Python versions.
    @pytest.mark.parametrize(
        ("run_task", "expected"),
        [
            (None, False),
            (asyncio.get_event_loop().create_future(), True),
            (aio.completed_future(), False),
        ],
    )
    def test_is_alive_property(self, run_task, expected, client):
        client._run_task = run_task
        assert client.is_alive is expected

    def test_shard_count_property(self, client):
        client._shard_count = 69
        assert client.shard_count == 69

    async def test_close_when_closing_set(self, client):
        client._closing = mock.Mock(is_set=mock.Mock(return_value=True))
        client._ws = mock.Mock()
        client._chunking_rate_limit = mock.Mock()
        client._total_rate_limit = mock.Mock()

        await client.close()

        client._closing.set.assert_not_called()
        client._ws.close.assert_not_called()
        client._chunking_rate_limit.close.assert_not_called()
        client._total_rate_limit.close.assert_not_called()

    async def test_close_when_closing_not_set(self, client):
        client._closing = mock.Mock(is_set=mock.Mock(return_value=False))
        client._ws = mock.Mock(close=mock.AsyncMock())
        client._chunking_rate_limit = mock.Mock()
        client._total_rate_limit = mock.Mock()

        await client.close()

        client._closing.set.assert_called_once_with()
        client._ws.close.assert_awaited_once_with(code=errors.ShardCloseCode.GOING_AWAY, message=b"shard disconnecting")
        client._chunking_rate_limit.close.assert_called_once_with()
        client._total_rate_limit.close.assert_called_once_with()

    async def test_close_when_closing_not_set_and_ws_is_None(self, client):
        client._closing = mock.Mock(is_set=mock.Mock(return_value=False))
        client._ws = None
        client._chunking_rate_limit = mock.Mock()
        client._total_rate_limit = mock.Mock()

        await client.close()

        client._closing.set.assert_called_once_with()
        client._chunking_rate_limit.close.assert_called_once_with()
        client._total_rate_limit.close.assert_called_once_with()

    async def test_when__user_id_is_None(self, client):
        client._handshake_completed = mock.Mock(wait=mock.AsyncMock())
        client._user_id = None
        # The call itself is expected to raise, so there is no value to assert on.
        with pytest.raises(RuntimeError):
            await client.get_user_id()

    async def test_when__user_id_is_not_None(self, client):
        client._handshake_completed = mock.Mock(wait=mock.AsyncMock())
        client._user_id = 123
        assert await client.get_user_id() == 123

    async def test_join(self, client):
        client._closed = mock.Mock(wait=mock.AsyncMock())

        await client.join()

        client._closed.wait.assert_awaited_once_with()

    async def test_request_guild_members_when_no_query_and_no_limit_and_GUILD_MEMBERS_not_enabled(self, client):
        client._intents = intents.Intents.GUILD_INTEGRATIONS

        with pytest.raises(errors.MissingIntentError):
            await client.request_guild_members(123, query="", limit=0)

    async def test_request_guild_members_when_presences_and_GUILD_PRESENCES_not_enabled(self, client):
        client._intents = intents.Intents.GUILD_INTEGRATIONS

        with pytest.raises(errors.MissingIntentError):
            await client.request_guild_members(123, query="test", limit=1, include_presences=True)

    @pytest.mark.parametrize("kwargs", [{"query": "some query"}, {"limit": 1}])
    async def test_request_guild_members_when_specifiying_users_with_limit_or_query(self, client, kwargs):
        client._intents = intents.Intents.GUILD_INTEGRATIONS

        with pytest.raises(ValueError, match="Cannot specify limit/query with users"):
            await client.request_guild_members(123, users=[], **kwargs)

    @pytest.mark.parametrize("limit", [-1, 101])
    async def test_request_guild_members_when_limit_under_0_or_over_100(self, client, limit):
        client._intents = None

        with pytest.raises(ValueError, match="'limit' must be between 0 and 100, both inclusive"):
            await client.request_guild_members(123, limit=limit)

    async def test_request_guild_members_when_users_over_100(self, client):
        client._intents = None

        with pytest.raises(ValueError, match="'users' is limited to 100 users"):
            await client.request_guild_members(123, users=range(101))

    async def test_request_guild_members_when_nonce_over_32_chars(self, client):
        client._intents = None

        with pytest.raises(ValueError, match="'nonce' can be no longer than 32 byte characters long."):
            await client.request_guild_members(123, nonce="x" * 33)

    @pytest.mark.parametrize("include_presences", [True, False])
    async def test_request_guild_members(self, client, include_presences):
        client._intents = None
        client._ws = mock.Mock(send_json=mock.AsyncMock())

        await client.request_guild_members(123, include_presences=include_presences)

        client._ws.send_json.assert_awaited_once_with(
            {
                "op": 8,
                "d": {"guild_id": "123", "query": "", "presences": include_presences, "limit": 0},
            }
        )

    async def test_start_when_already_running(self, client):
        client._run_task = object()

        with pytest.raises(RuntimeError, match="Cannot run more than one instance of one shard concurrently"):
            await client.start()

    async def test_start_when_shard_closed_before_starting(self, client):
        client._run_task = None
        client._shard_id = 20
        client._run = mock.Mock()
        client._handshake_completed = mock.Mock(wait=mock.Mock())
        run_task = mock.Mock()
        waiter = mock.Mock()

        stack = contextlib.ExitStack()
        create_task = stack.enter_context(mock.patch.object(asyncio, "create_task", side_effect=[run_task, waiter]))
        wait = stack.enter_context(mock.patch.object(asyncio, "wait", return_value=([run_task], [waiter])))
        stack.enter_context(
            pytest.raises(asyncio.CancelledError, match="Shard 20 was closed before it could start successfully")
        )

        with stack:
            await client.start()

        assert client._run_task is None

        assert create_task.call_count == 2
        # `has_call` is not a Mock assertion (it silently creates a child mock
        # and always passes), so use a real assertion helper instead.
        create_task.assert_any_call(client._run(), name="run shard 20")
        create_task.assert_any_call(client._handshake_completed.wait(), name="wait for shard 20 to start")

        run_task.result.assert_called_once_with()
        waiter.cancel.assert_called_once_with()
        wait.assert_awaited_once_with((waiter, run_task), return_when=asyncio.FIRST_COMPLETED)

    async def test_start(self, client):
        client._run_task = None
        client._shard_id = 20
        client._run = mock.Mock()
        client._handshake_completed = mock.Mock(wait=mock.Mock())
        run_task = mock.Mock()
        waiter = mock.Mock()

        with mock.patch.object(asyncio, "create_task", side_effect=[run_task, waiter]) as create_task:
            with mock.patch.object(asyncio, "wait", return_value=([waiter], [run_task])) as wait:
                await client.start()

        assert client._run_task == run_task

        assert create_task.call_count == 2
        # `has_call` is not a Mock assertion (it silently creates a child mock
        # and always passes), so use a real assertion helper instead.
        create_task.assert_any_call(client._run(), name="run shard 20")
        create_task.assert_any_call(client._handshake_completed.wait(), name="wait for shard 20 to start")

        run_task.result.assert_not_called()
        waiter.cancel.assert_called_once_with()
        wait.assert_awaited_once_with((waiter, run_task), return_when=asyncio.FIRST_COMPLETED)

    async def test_update_presence(self, client):
        presence_payload = object()
        client._ws = mock.Mock(send_json=mock.AsyncMock())
        client._serialize_and_store_presence_payload = mock.Mock(return_value=presence_payload)
        client._send_json = mock.AsyncMock()

        await client.update_presence(
            idle_since=datetime.datetime.now(),
            afk=True,
            status=presences.Status.IDLE,
            activity=None,
        )

        client._ws.send_json.assert_awaited_once_with({"op": 3, "d": presence_payload})

    @pytest.mark.parametrize("channel", [12345, None])
    @pytest.mark.parametrize("self_deaf", [True, False])
    @pytest.mark.parametrize("self_mute", [True, False])
    async def test_update_voice_state(self, client, channel, self_deaf, self_mute):
        client._ws = mock.Mock(send_json=mock.AsyncMock())
        payload = {
            "channel_id": str(channel) if channel is not None else None,
            "guild_id": "6969420",
            "deaf": self_deaf,
            "mute": self_mute,
        }

        await client.update_voice_state("6969420", channel, self_mute=self_mute, self_deaf=self_deaf)

        client._ws.send_json.assert_awaited_once_with({"op": 4, "d": payload})

    def test_dispatch_when_READY(self, client):
        client._seq = 0
        client._session_id = 0
        client._user_id = 0
        client._logger = mock.Mock()
        client._handshake_completed = mock.Mock()
        client._event_consumer = mock.Mock()

        # The username must agree with the "hikari#5863" tag expected in the
        # log call below (assumes the tag is username#discriminator).
        client._dispatch(
            "READY",
            10,
            {
                "session_id": 101,
                "user": {"id": 123, "username": "hikari", "discriminator": "5863"},
                "guilds": [
                    {"id": "123"},
                    {"id": "456"},
                    {"id": "789"},
                ],
            },
        )

        assert client._seq == 10
        assert client._session_id == 101
        assert client._user_id == 123
        client._logger.info.assert_called_once_with(
            "shard is ready [session:%s, user_id:%s, tag:%s, guilds:%s]", 101, 123, "hikari#5863", 3
        )
        client._handshake_completed.set.assert_called_once_with()
        client._event_consumer.assert_called_once_with(
            client,
            "READY",
            {
                "session_id": 101,
                "user": {"id": 123, "username": "hikari", "discriminator": "5863"},
                "guilds": [
                    {"id": "123"},
                    {"id": "456"},
                    {"id": "789"},
                ],
            },
        )

    def test__dipatch_when_RESUME(self, client):
        client._seq = 0
        client._session_id = 123
        client._logger = mock.Mock()
        client._handshake_completed = mock.Mock()
        client._event_consumer = mock.Mock()

        client._dispatch("RESUME", 10, {})

        assert client._seq == 10
        client._logger.info.assert_called_once_with("shard has resumed [session:%s, seq:%s]", 123, 10)
        client._handshake_completed.set.assert_called_once_with()
        client._event_consumer.assert_called_once_with(client, "RESUME", {})

    def test__dipatch(self, client):
        client._logger = mock.Mock()
        client._handshake_completed = mock.Mock()
        client._event_consumer = mock.Mock()

        client._dispatch("EVENT NAME", 10, {"payload": None})

        client._logger.info.assert_not_called()
        client._handshake_completed.set.assert_not_called()
        client._event_consumer.assert_called_once_with(client, "EVENT NAME", {"payload": None})

    async def test__identify_when_no_intents(self, client):
        client._token = "token"
        client._intents = None
        client._large_threshold = 123
        client._shard_id = 0
        client._shard_count = 1
        client._serialize_and_store_presence_payload = mock.Mock(return_value={"presence": "payload"})
        client._ws = mock.Mock(send_json=mock.AsyncMock())
        stack = contextlib.ExitStack()
        stack.enter_context(mock.patch.object(platform, "system", return_value="Potato PC"))
        stack.enter_context(mock.patch.object(platform, "architecture", return_value=["ARM64"]))
        stack.enter_context(mock.patch.object(aiohttp, "__version__", new="v0.0.1"))
        stack.enter_context(mock.patch.object(_about, "__version__", new="v1.0.0"))

        with stack:
            await client._identify()

        expected_json = {
            "op": 2,
            "d": {
                "token": "token",
                "compress": False,
                "large_threshold": 123,
                "properties": {
                    "$os": "Potato PC ARM64",
                    "$browser": "aiohttp v0.0.1",
                    "$device": "hikari v1.0.0",
                },
                "shard": [0, 1],
                "presence": {"presence": "payload"},
            },
        }
        client._ws.send_json.assert_awaited_once_with(expected_json)

    async def test__identify_when_intents(self, client):
        client._token = "token"
        client._intents = intents.Intents.ALL
        client._large_threshold = 123
        client._shard_id = 0
        client._shard_count = 1
        client._serialize_and_store_presence_payload = mock.Mock(return_value={"presence": "payload"})
        client._ws = mock.Mock(send_json=mock.AsyncMock())
        stack = contextlib.ExitStack()
        stack.enter_context(mock.patch.object(platform, "system", return_value="Potato PC"))
        stack.enter_context(mock.patch.object(platform, "architecture", return_value=["ARM64"]))
        stack.enter_context(mock.patch.object(aiohttp, "__version__", new="v0.0.1"))
        stack.enter_context(mock.patch.object(_about, "__version__", new="v1.0.0"))

        with stack:
            await client._identify()

        expected_json = {
            "op": 2,
            "d": {
                "token": "token",
                "compress": False,
                "large_threshold": 123,
                "properties": {
                    "$os": "Potato PC ARM64",
                    "$browser": "aiohttp v0.0.1",
                    "$device": "hikari v1.0.0",
                },
                "shard": [0, 1],
                "intents": 32767,
                "presence": {"presence": "payload"},
            },
        }
        client._ws.send_json.assert_awaited_once_with(expected_json)

    @hikari_test_helpers.timeout()
    async def test__heartbeat(self, client):
        client._last_heartbeat_sent = 5
        client._logger = mock.Mock()
        client._closing = mock.Mock(is_set=mock.Mock(return_value=False))
        client._closed = mock.Mock(is_set=mock.Mock(return_value=False))
        client._send_heartbeat = mock.AsyncMock()

        with mock.patch.object(date, "monotonic", return_value=10):
            # First wait times out, the second one returns normally.
            with mock.patch.object(asyncio, "wait_for", side_effect=[asyncio.TimeoutError, None]) as wait_for:
                assert await client._heartbeat(20) is False

        wait_for.assert_awaited_with(client._closing.wait(), timeout=20)

    @hikari_test_helpers.timeout()
    async def test__heartbeat_when_zombie(self, client):
        client._last_heartbeat_sent = 10
        client._logger = mock.Mock()

        with mock.patch.object(date, "monotonic", return_value=5):
            with mock.patch.object(asyncio, "wait_for") as wait_for:
                assert await client._heartbeat(20) is True

        wait_for.assert_not_called()

    async def test__resume(self, client):
        client._token = "token"
        client._seq = 123
        client._session_id = 456
        client._ws = mock.Mock(send_json=mock.AsyncMock())

        await client._resume()

        expected_json = {
            "op": 6,
            "d": {"token": "token", "seq": 123, "session_id": 456},
        }
        client._ws.send_json.assert_awaited_once_with(expected_json)

    @pytest.mark.skip("TODO")
    async def test__run(self, client):
        ...

    @pytest.mark.skip("TODO")
    async def test__run_once(self, client):
        ...

    async def test__send_heartbeat(self, client):
        client._ws = mock.Mock(send_json=mock.AsyncMock())
        client._last_heartbeat_sent = 0
        client._seq = 10

        with mock.patch.object(date, "monotonic", return_value=200):
            await client._send_heartbeat()

        client._ws.send_json.assert_awaited_once_with({"op": 1, "d": 10})
        assert client._last_heartbeat_sent == 200

    async def test__send_heartbeat_ack(self, client):
        client._ws = mock.Mock(send_json=mock.AsyncMock())

        await client._send_heartbeat_ack()

        client._ws.send_json.assert_awaited_once_with({"op": 11, "d": None})

    def test__serialize_activity_when_activity_is_None(self, client):
        assert client._serialize_activity(None) is None

    def test__serialize_activity_when_activity_is_not_None(self, client):
        activity = mock.Mock(type="0", url="https://some.url")
        # `name` must be set separately: passed to the Mock constructor it
        # would name the mock itself rather than set an attribute.
        activity.name = "some name"
        assert client._serialize_activity(activity) == {"name": "some name", "type": 0, "url": "https://some.url"}

    @pytest.mark.parametrize("idle_since", [datetime.datetime.now(), None])
    @pytest.mark.parametrize("afk", [True, False])
    @pytest.mark.parametrize(
        "status",
        [presences.Status.DO_NOT_DISTURB, presences.Status.IDLE, presences.Status.ONLINE, presences.Status.OFFLINE],
    )
    @pytest.mark.parametrize("activity", [presences.Activity(name="foo"), None])
    def test__serialize_and_store_presence_payload_when_all_args_undefined(
        self, client, idle_since, afk, status, activity
    ):
        client._activity = activity
        client._idle_since = idle_since
        client._is_afk = afk
        client._status = status

        actual_result = client._serialize_and_store_presence_payload()

        if activity is not undefined.UNDEFINED and activity is not None:
            expected_activity = {
                "name": activity.name,
                "type": activity.type,
                "url": activity.url,
            }
        else:
            expected_activity = None

        if status == presences.Status.OFFLINE:
            expected_status = "invisible"
        else:
            expected_status = status.value

        expected_result = {
            "game": expected_activity,
            "since": int(idle_since.timestamp() * 1_000) if idle_since is not None else None,
            "afk": afk if afk is not undefined.UNDEFINED else False,
            "status": expected_status,
        }

        # Without this comparison the test would build the expectation and
        # then verify nothing.
        assert actual_result == expected_result
예제 #4
0
파일: test_aio.py 프로젝트: tomxey/hikari
 async def test_non_default_result(self):
     """A value passed to `aio.completed_future` becomes the future's result."""
     future = aio.completed_future(...)
     assert future.result() is Ellipsis
예제 #5
0
파일: test_aio.py 프로젝트: tomxey/hikari
 async def test_default_result_is_none(self):
     """Calling `aio.completed_future` with no argument yields a `None` result."""
     future = aio.completed_future()
     assert future.result() is None
예제 #6
0
파일: test_aio.py 프로젝트: tomxey/hikari
 async def test_is_completed(self, args):
     """The future returned by `aio.completed_future` is already done."""
     assert aio.completed_future(*args).done()