예제 #1
0
    async def test_game(
        self,
        cog: EventsCog,
        message: discord.Message,
        add_user: Callable[..., User],
    ) -> None:
        """Starting a game through the events cog seats both mentioned players.

        The invoking admin must stay out of the game; only the two mentioned
        users should end up attached to the started game.
        """
        first, second = add_user(), add_user()
        mocked_users = [mock_discord_object(u) for u in (first, second)]
        mention_list = f"<@{first.xid}><@{second.xid}>"
        legacy_format = cast(int, GameFormat.LEGACY.value)

        with mock_operations(lfg_action, users=mocked_users):
            # The follow-up message returned here is what the game gets posted as.
            lfg_action.safe_followup_channel.return_value = message
            await self.run(cog.game, players=mention_list, format=legacy_format)

        started = DatabaseSession.query(Game).one()
        assert started.status == GameStatus.STARTED.value

        admin_xid = self.interaction.user.id
        admin = DatabaseSession.query(User).get(admin_xid)
        assert admin is not None
        assert admin.game_id is None

        others = DatabaseSession.query(User).filter(User.xid != admin_xid).all()
        assert len(others) == 2
        assert all(player.game_id == started.id for player in others)
예제 #2
0
 async def test_lfg(self, cog: LookingForGameCog, channel: Channel):
     """A lone /lfg creates one game in this channel and seats its author."""
     await self.run(cog.lfg)

     created = DatabaseSession.query(Game).one()
     author = DatabaseSession.query(User).one()
     assert (created.channel_xid, created.guild_xid) == (channel.xid, self.guild.xid)
     assert author.game_id == created.id
예제 #3
0
    async def test_watch_and_unwatch(self, cog: WatchCog,
                                     add_user: Callable[..., User]) -> None:
        """Watching a member stores a Watch row; unwatching removes it again."""
        watched = cast(discord.Member, mock_discord_object(add_user()))

        await self.run(cog.watch, target=watched, note="note")
        self.interaction.response.send_message.assert_called_once_with(
            f"Watching <@{watched.id}>.", ephemeral=True,
        )

        stored = DatabaseSession.query(Watch).one()
        expected = {
            "guild_xid": self.guild.xid,
            "user_xid": watched.id,
            "note": "note",
        }
        assert stored.to_dict() == expected

        # Forget the watch reply so the unwatch reply is checked in isolation.
        self.interaction.response.send_message.reset_mock()
        await self.run(cog.unwatch, target=watched)
        self.interaction.response.send_message.assert_called_once_with(
            f"No longer watching <@{watched.id}>.", ephemeral=True,
        )

        assert not DatabaseSession.query(Watch).one_or_none()
예제 #4
0
    async def test_lfg_with_friend_when_game_wrong_format(
            self, guild: Guild, channel: Channel):
        """An existing game in a different format is not joined.

        Only a TWO_HEADED_GIANT game exists, so upserting a COMMANDER game with
        two friends must create a new game holding all three users.
        """
        author, *friends = [UserFactory.create(xid=xid) for xid in (101, 102, 103)]
        other_format_game = GameFactory.create(
            seats=4,
            channel=channel,
            guild=guild,
            format=GameFormat.TWO_HEADED_GIANT.value,
        )

        created = await GamesService().upsert(
            guild_xid=guild.xid,
            channel_xid=channel.xid,
            author_xid=author.xid,
            friends=[friend.xid for friend in friends],
            seats=4,
            format=GameFormat.COMMANDER.value,
        )
        assert created

        DatabaseSession.expire_all()
        game = (
            DatabaseSession.query(Game)
            .filter(Game.id != other_format_game.id)
            .one()
        )
        assert game.guild_xid == guild.xid
        assert game.channel_xid == channel.xid
        seated_rows = (
            DatabaseSession.query(User.xid)
            .filter(User.game_id == game.id)
            .all()
        )
        assert {row[0] for row in seated_rows} == {101, 102, 103}
예제 #5
0
    async def test_block_and_unblock(self, cog: BlockCog):
        """Blocking persists both users plus a Block row; unblocking deletes it."""
        target = MagicMock()
        target.id = 2
        target.display_name = "target-author-display-name"

        await self.run(cog.block, target=target)

        self.interaction.response.send_message.assert_called_once_with(
            f"<@{target.id}> has been blocked.",
            ephemeral=True,
        )

        # Query.all() already returns a list, so sorted() can take it directly.
        # Ordering by name makes the rows addressable by index; the assertions
        # below expect the target's row to sort first.
        users = sorted(DatabaseSession.query(User).all(), key=lambda u: u.name)
        assert len(users) == 2
        assert users[0].name == target.display_name
        assert users[0].xid == target.id
        assert users[1].name == self.interaction.user.display_name
        assert users[1].xid == self.interaction.user.id

        blocks = DatabaseSession.query(Block).all()
        assert len(blocks) == 1
        assert blocks[0].user_xid == self.interaction.user.id
        assert blocks[0].blocked_user_xid == target.id

        DatabaseSession.expire_all()
        await self.run(cog.unblock, target=target)
        assert not DatabaseSession.query(Block).all()
예제 #6
0
    async def test_users_select(self):
        """select() only succeeds once the user row exists."""
        user_xid = 201
        service = UsersService()
        assert not await service.select(user_xid)

        UserFactory.create(xid=user_xid)
        # Expire cached ORM state before selecting again.
        DatabaseSession.expire_all()
        assert await service.select(user_xid)
예제 #7
0
    async def test_games_make_ready(self, game: Game):
        """make_ready() stores the link and moves the game to STARTED."""
        link = "http://link"
        service = GamesService()
        await service.select(game.id)
        await service.make_ready(link)

        DatabaseSession.expire_all()
        reloaded = DatabaseSession.query(Game).get(game.id)
        assert reloaded
        assert reloaded.spelltable_link == link
        assert reloaded.status == GameStatus.STARTED.value
        assert found.status == GameStatus.STARTED.value
예제 #8
0
    async def test_games_set_message_xid(self, game: Game):
        """set_message_xid() makes the game findable by its message id."""
        message_xid = 12345
        service = GamesService()
        await service.select(game.id)
        await service.set_message_xid(message_xid)

        DatabaseSession.expire_all()
        by_message = (
            DatabaseSession.query(Game)
            .filter_by(message_xid=message_xid)
            .one_or_none()
        )
        assert by_message
        assert by_message.id == game.id
예제 #9
0
    async def test_games_set_voice(self, game: Game):
        """set_voice() stores both the voice channel id and its invite link."""
        voice_xid, invite = 12345, "http://link"
        service = GamesService()
        await service.select(game.id)
        await service.set_voice(voice_xid, invite)

        DatabaseSession.expire_all()
        reloaded = DatabaseSession.query(Game).get(game.id)
        assert reloaded
        assert (reloaded.voice_xid, reloaded.voice_invite_link) == (voice_xid, invite)
예제 #10
0
    async def test_games_add_player(self, game: Game):
        """add_player() attaches the given user to the selected game."""
        newcomer = UserFactory.create()

        service = GamesService()
        await service.select(game.id)
        await service.add_player(newcomer.xid)

        DatabaseSession.expire_all()
        reloaded = DatabaseSession.query(User).get(newcomer.xid)
        assert reloaded
        assert reloaded.game.id == game.id
예제 #11
0
    async def test_games_watch_notes(self, game: Game):
        """watch_notes() maps only the watched users among those given to notes."""
        watched, unwatched_player, outsider = (
            UserFactory.create(game=game),
            UserFactory.create(game=game),
            UserFactory.create(),
        )
        watch = WatchFactory.create(guild_xid=game.guild.xid, user_xid=watched.xid)

        DatabaseSession.expire_all()
        service = GamesService()
        await service.select(game.id)
        queried = [watched.xid, unwatched_player.xid, outsider.xid]
        notes = await service.watch_notes(queried)
        assert notes == {watched.xid: watch.note}
예제 #12
0
    async def test_guilds_award_delete(self, guild: Guild):
        """award_delete() removes only the targeted award; an unknown id is tolerated."""
        doomed = GuildAwardFactory.create(guild=guild)
        survivor = GuildAwardFactory.create(guild=guild)

        service = GuildsService()
        await service.select(guild.xid)
        # Capture the id before deleting; the ORM instance may not be
        # refreshable once its row is gone.
        doomed_id = doomed.id
        await service.award_delete(doomed_id)
        await service.award_delete(404)  # presumably no award has this id

        DatabaseSession.expire_all()
        assert not DatabaseSession.query(GuildAward).get(doomed_id)
        assert DatabaseSession.query(GuildAward).get(survivor.id)
예제 #13
0
    async def test_users_watch_without_note(self, guild: Guild):
        """watch() without a note stores a single Watch whose note is None."""
        watched = UserFactory.create()

        await UsersService().watch(guild_xid=guild.xid, user_xid=watched.xid)

        DatabaseSession.expire_all()
        stored = DatabaseSession.query(Watch).all()
        assert [row.to_dict() for row in stored] == [
            {"guild_xid": guild.xid, "user_xid": watched.xid, "note": None},
        ]
예제 #14
0
    async def test_set_motd(self, cog: AdminCog) -> None:
        """motd sets the guild's message of the day; calling it bare clears it."""
        text = "this is a test"
        await self.run(cog.motd, message=text)
        self.interaction.response.send_message.assert_called_once_with(
            "Message of the day updated.", ephemeral=True,
        )
        assert DatabaseSession.query(Guild).one().motd == text

        # Invoking without a message should leave an empty motd.
        await self.run(cog.motd)
        DatabaseSession.expire_all()
        assert DatabaseSession.query(Guild).one().motd == ""
예제 #15
0
    async def test_users_block(self):
        """block() records exactly one Block from the first user to the second."""
        blocker = UserFactory.create()
        blocked = UserFactory.create()

        await UsersService().block(blocker.xid, blocked.xid)

        DatabaseSession.expire_all()
        stored = DatabaseSession.query(Block).all()
        assert [row.to_dict() for row in stored] == [
            {"user_xid": blocker.xid, "blocked_user_xid": blocked.xid},
        ]
예제 #16
0
    async def test_guilds_set_motd(self):
        """set_motd() persists the message of the day on the selected guild."""
        assert not await GuildsService().select(101)

        guild = GuildFactory.create()
        motd = "message of the day"

        service = GuildsService()
        await service.select(guild.xid)
        await service.set_motd(motd)

        DatabaseSession.expire_all()
        reloaded = DatabaseSession.query(Guild).get(guild.xid)
        assert reloaded
        assert reloaded.motd == motd
예제 #17
0
    async def test_channel_motd(self, cog: AdminCog):
        """channel_motd sets the channel's motd; calling it bare clears it."""
        motd = "this is a channel message of the day"
        await self.run(cog.channel_motd, message=motd)
        expected = f"Message of the day for this channel has been set to: {motd}"
        self.interaction.response.send_message.assert_called_once_with(
            expected, ephemeral=True,
        )
        assert DatabaseSession.query(Channel).one().motd == motd

        await self.run(cog.channel_motd)
        DatabaseSession.expire_all()
        assert DatabaseSession.query(Channel).one().motd == ""
예제 #18
0
    async def test_lfg_alone_when_existing_game(self, game: Game, user: User):
        """A solo upsert joins the existing game in that channel instead of creating one."""
        created = await GamesService().upsert(
            guild_xid=game.guild.xid,
            channel_xid=game.channel.xid,
            author_xid=user.xid,
            friends=[],
            seats=4,
            format=GameFormat.COMMANDER.value,
        )
        assert not created

        DatabaseSession.expire_all()
        assert DatabaseSession.query(User).one().game_id == game.id
예제 #19
0
 async def test_award_add_zero_count(self, cog: AdminCog):
     """award_add rejects a zero game count and creates no award."""
     await self.run(cog.award_add, count=0, role="role", message="message")
     expected_reply = "You can't create an award for zero games played."
     self.interaction.response.send_message.assert_called_once_with(
         expected_reply, ephemeral=True,
     )
     assert not DatabaseSession.query(GuildAward).count()
예제 #20
0
    async def test_concurrent_lfg_requests_same_channel(
        self,
        bot: SpellBot,
        monkeypatch: pytest.MonkeyPatch,
    ):
        """Many concurrent /lfg requests in one channel fill the expected games.

        ``n`` authors run /lfg at once in a single channel. They should produce
        exactly ``n // default_seats`` games, and the requests must actually
        have interleaved. That interleaving is observable as non-sequential
        message ids when the games are ordered by creation time.
        """
        next_message_xid = 1

        def get_next_message(*args: Any, **kwargs: Any):
            """Return a fresh mock message whose id is one more than the last."""
            nonlocal next_message_xid
            message = MagicMock(spec=discord.Message)
            message.id = next_message_xid
            next_message_xid += 1
            return message

        monkeypatch.setattr(lfg_action, "safe_fetch_user", AsyncMock())
        monkeypatch.setattr(
            lfg_action,
            "safe_followup_channel",
            AsyncMock(side_effect=get_next_message),
        )
        monkeypatch.setattr(
            lfg_action,
            "safe_get_partial_message",
            MagicMock(side_effect=get_next_message),
        )
        monkeypatch.setattr(lfg_action, "safe_update_embed_origin",
                            AsyncMock(return_value=True))
        monkeypatch.setattr(lfg_action, "safe_update_embed",
                            AsyncMock(return_value=True))

        cog = LookingForGameCog(bot)
        guild = build_guild()
        channel = build_channel(guild)
        default_seats = 4
        n = default_seats * 25
        interactions = [
            build_interaction(guild, channel, build_author(i))
            for i in range(n)
        ]
        # asyncio.wait() rejects bare coroutines since Python 3.11 (deprecated
        # since 3.8), so schedule each request as a Task explicitly.
        tasks = [
            asyncio.create_task(run_lfg(cog, interaction))
            for interaction in interactions
        ]

        done, pending = await asyncio.wait(tasks)
        assert not pending
        for future in done:
            future.result()  # re-raise any exception a request hit

        games = DatabaseSession.query(Game).order_by(Game.created_at).all()
        # Integer division: the game count is an int, not a float.
        assert len(games) == n // default_seats

        # Since all these lfg requests should be handled concurrently, we should
        # see message_xids OUT of order in the created games (as ordered by
        # created at). At least one consecutive pair must break the +1 sequence.
        messages_out_of_order = any(
            prev.message_xid is not None
            and cur.message_xid != prev.message_xid + 1
            for prev, cur in zip(games, games[1:])
        )
        assert messages_out_of_order
예제 #21
0
    async def test_concurrent_lfg_requests_different_channels(
            self, bot: SpellBot):
        """Concurrent /lfg requests in separate channels each create a game.

        One author per channel means ``n`` requests should yield ``n`` games.
        The requests must actually interleave, which is observable as
        non-sequential message ids when games are ordered by creation time.
        """
        cog = LookingForGameCog(bot)
        guild = build_guild()
        n = 100
        interactions = [
            build_interaction(guild, build_channel(guild, i), build_author(i))
            for i in range(n)
        ]
        # asyncio.wait() rejects bare coroutines since Python 3.11 (deprecated
        # since 3.8), so schedule each request as a Task explicitly.
        tasks = [
            asyncio.create_task(run_lfg(cog, interaction))
            for interaction in interactions
        ]

        done, pending = await asyncio.wait(tasks)
        assert not pending
        for future in done:
            future.result()  # re-raise any exception a request hit

        games = DatabaseSession.query(Game).order_by(Game.created_at).all()
        assert len(games) == n

        # Since all these lfg requests should be handled concurrently, we should
        # see message_xids OUT of order in the created games (as ordered by
        # created at). At least one consecutive pair must break the +1 sequence.
        messages_out_of_order = any(
            prev.message_xid is not None
            and cur.message_xid != prev.message_xid + 1
            for prev, cur in zip(games, games[1:])
        )
        assert messages_out_of_order
예제 #22
0
 async def test_award_add_message_too_long(self, cog: AdminCog):
     """award_add rejects an over-long message and creates no award."""
     too_long = "hippo " * 300
     await self.run(cog.award_add, count=1, role="role", message=too_long)
     expected_reply = "Your message can't be longer than 500 characters."
     self.interaction.response.send_message.assert_called_once_with(
         expected_reply, ephemeral=True,
     )
     assert not DatabaseSession.query(GuildAward).count()
예제 #23
0
    async def test_games_add_points(self, game: Game):
        """add_points() records points on the given player's Play only.

        Both plays start without points. Reading back 5 for user1 proves the
        call actually wrote the value, and user2's play must stay untouched.
        """
        user1 = UserFactory.create(game=game)
        user2 = UserFactory.create(game=game)
        # Seed user1 with a value different from the one written below.
        # Otherwise the assertion would pass even if add_points() did nothing.
        PlayFactory.create(user_xid=user1.xid, game_id=game.id, points=None)
        PlayFactory.create(user_xid=user2.xid, game_id=game.id, points=None)

        games = GamesService()
        await games.select(game.id)
        await games.add_points(user1.xid, 5)

        DatabaseSession.expire_all()
        found = DatabaseSession.query(Play).filter(
            Play.user_xid == user1.xid).one()
        assert found.points == 5
        found = DatabaseSession.query(Play).filter(
            Play.user_xid == user2.xid).one()
        assert found.points is None
예제 #24
0
    async def test_power_level(self, cog: ConfigCog):
        """power stores the requested level in the caller's per-guild config."""
        level = 10
        await self.run(cog.power, level=level)

        config = DatabaseSession.query(Config).one()
        guild = self.interaction.guild
        assert guild is not None
        owner = (config.guild_xid, config.user_xid)
        assert owner == (guild.id, self.interaction.user.id)
        assert config.power_level == level
예제 #25
0
 async def test_award_add_dupe(self, cog: AdminCog):
     """A second award for the same game count is rejected."""
     award = dict(count=10, role="role", message="message")
     await self.run(cog.award_add, **award)
     # Only the reply to the duplicate attempt is of interest.
     self.interaction.response.send_message.reset_mock()
     await self.run(cog.award_add, **award)
     expected_reply = "There's already an award for players who reach that many games."
     self.interaction.response.send_message.assert_called_once_with(
         expected_reply, ephemeral=True,
     )
     assert DatabaseSession.query(GuildAward).count() == 1
예제 #26
0
 async def test_default_seats(self, cog: AdminCog) -> None:
     """default_seats stores a non-default seat count on the channel."""
     # One below the column default, so the update is observable.
     seats = Channel.default_seats.default.arg - 1  # type: ignore
     await self.run(cog.default_seats, seats=seats)
     expected_reply = f"Default seats set to {seats} for this channel."
     self.interaction.response.send_message.assert_called_once_with(
         expected_reply, ephemeral=True,
     )
     assert DatabaseSession.query(Channel).one().default_seats == seats
예제 #27
0
 async def test_unverified_only(self, cog: AdminCog):
     """unverified_only flips the channel flag away from its column default."""
     default_value = Channel.unverified_only.default.arg  # type: ignore
     flipped = not default_value
     await self.run(cog.unverified_only, setting=flipped)
     expected_reply = f"Unverified only set to {flipped} for this channel."
     self.interaction.response.send_message.assert_called_once_with(
         expected_reply, ephemeral=True,
     )
     assert DatabaseSession.query(Channel).one().unverified_only != default_value
예제 #28
0
    async def test_lfg_alone_when_no_game(self, guild: Guild, channel: Channel,
                                          user: User):
        """A solo upsert with no existing game creates one and seats the user."""
        created = await GamesService().upsert(
            guild_xid=guild.xid,
            channel_xid=channel.xid,
            author_xid=user.xid,
            friends=[],
            seats=4,
            format=GameFormat.COMMANDER.value,
        )
        assert created

        DatabaseSession.expire_all()
        only_user = DatabaseSession.query(User).one()
        only_game = DatabaseSession.query(Game).one()
        location = (only_game.guild_xid, only_game.channel_xid)
        assert location == (guild.xid, channel.xid)
        assert only_user.game_id == only_game.id
예제 #29
0
    async def test_lfg_with_friend_when_existing_game(self, game: Game):
        """An upsert with a friend seats both users in the existing game."""
        author = UserFactory.create(xid=101)
        friend = UserFactory.create(xid=102)

        created = await GamesService().upsert(
            guild_xid=game.guild.xid,
            channel_xid=game.channel.xid,
            author_xid=author.xid,
            friends=[friend.xid],
            seats=4,
            format=GameFormat.COMMANDER.value,
        )
        assert not created

        DatabaseSession.expire_all()
        seated = DatabaseSession.query(User.xid).filter(User.game_id == game.id)
        assert {row[0] for row in seated.all()} == {101, 102}
예제 #30
0
 async def test_voice_category(self, cog: AdminCog):
     """voice_category stores exactly the given prefix on the channel."""
     default_value = Channel.voice_category.default.arg  # type: ignore
     new_value = "wotnot" + default_value
     await self.run(cog.voice_category, prefix=new_value)
     self.interaction.response.send_message.assert_called_once_with(
         f"Voice category prefix for this channel has been set to: {new_value}",
         ephemeral=True,
     )
     channel = DatabaseSession.query(Channel).one()
     # Checking only "!= default_value" would accept any wrong stored value;
     # pin the exact prefix the command reported setting.
     assert channel.voice_category == new_value