コード例 #1
0
    async def fetch_peers(self, peers: List[Union[raw.types.User, raw.types.Chat, raw.types.Channel]]) -> bool:
        """Persist the given raw peers in the session storage.

        Users, basic groups (``Chat``/``ChatForbidden``) and channels or
        supergroups (``Channel``/``ChannelForbidden``) are turned into
        ``(peer_id, access_hash, peer_type, username, phone_number)`` rows and
        written with a single ``storage.update_peers`` call. Peers of any other
        type are ignored, and peers flagged ``min`` are skipped.

        Returns:
            ``bool``: True if at least one peer carried the ``min`` flag,
            False otherwise.
        """
        saw_min_peer = False
        rows = []

        for item in peers:
            # Peers flagged "min" are not stored; the caller learns about them
            # through the return value instead.
            if getattr(item, "min", False):
                saw_min_peer = True
                continue

            if isinstance(item, raw.types.User):
                rows.append((
                    item.id,
                    item.access_hash,
                    "bot" if item.bot else "user",
                    (item.username or "").lower() or None,
                    item.phone,
                ))
            elif isinstance(item, (raw.types.Chat, raw.types.ChatForbidden)):
                # Basic groups are stored under their negated id with no hash.
                rows.append((-item.id, 0, "group", None, None))
            elif isinstance(item, (raw.types.Channel, raw.types.ChannelForbidden)):
                rows.append((
                    utils.get_channel_id(item.id),
                    item.access_hash,
                    "channel" if item.broadcast else "supergroup",
                    (getattr(item, "username", None) or "").lower() or None,
                    None,
                ))

        await self.storage.update_peers(rows)

        return saw_min_peer
コード例 #2
0
ファイル: chat.py プロジェクト: uzkabuja/pyrogram
    def _parse_channel_chat(client, channel: raw.types.Channel) -> "Chat":
        """Build a :obj:`Chat` from a raw channel or supergroup object.

        Parameters:
            client: The client instance bound to the resulting chat.
            channel: The raw channel. Its ``megagroup`` flag selects the chat
                type (``"supergroup"`` vs ``"channel"``); optional fields are
                read with ``getattr`` so they may be absent.

        Returns:
            :obj:`Chat`: The parsed chat, whose ``id`` is ``channel.id``
            converted with ``utils.get_channel_id``.
        """
        peer_id = utils.get_channel_id(channel.id)
        # The attribute can be missing or present but ``None``; ``or []``
        # covers both so the comprehension below never iterates over ``None``.
        restriction_reason = getattr(channel, "restriction_reason", None) or []

        return Chat(
            id=peer_id,
            type="supergroup" if channel.megagroup else "channel",
            is_verified=getattr(channel, "verified", None),
            is_restricted=getattr(channel, "restricted", None),
            is_creator=getattr(channel, "creator", None),
            is_scam=getattr(channel, "scam", None),
            is_fake=getattr(channel, "fake", None),
            title=channel.title,
            username=getattr(channel, "username", None),
            photo=types.ChatPhoto._parse(client,
                                         getattr(channel, "photo", None),
                                         peer_id, channel.access_hash),
            restrictions=types.List(
                [types.Restriction._parse(r) for r in restriction_reason])
            or None,
            permissions=types.ChatPermissions._parse(
                getattr(channel, "default_banned_rights", None)),
            members_count=getattr(channel, "participants_count", None),
            # The DC id comes from the photo object, when there is one.
            dc_id=getattr(getattr(channel, "photo", None), "dc_id", None),
            client=client)
コード例 #3
0
    def _parse_channel_chat(client, channel: raw.types.Channel) -> "Chat":
        """Build a :obj:`Chat` from a raw channel or supergroup object.

        Parameters:
            client: The client instance bound to the resulting chat.
            channel: The raw channel. A truthy ``megagroup`` attribute makes the
                chat a ``SUPERGROUP``, otherwise it is a ``CHANNEL``; optional
                fields are read with ``getattr`` so they may be absent.

        Returns:
            :obj:`Chat`: The parsed chat, whose ``id`` is ``channel.id``
            converted with ``utils.get_channel_id``.
        """
        peer_id = utils.get_channel_id(channel.id)
        # The attribute can be missing or present but ``None``; ``or []``
        # covers both so the comprehension below never iterates over ``None``.
        restriction_reason = getattr(channel, "restriction_reason", None) or []

        return Chat(
            id=peer_id,
            type=enums.ChatType.SUPERGROUP
            if getattr(channel, "megagroup", None) else enums.ChatType.CHANNEL,
            is_verified=getattr(channel, "verified", None),
            is_restricted=getattr(channel, "restricted", None),
            is_creator=getattr(channel, "creator", None),
            is_scam=getattr(channel, "scam", None),
            is_fake=getattr(channel, "fake", None),
            title=channel.title,
            username=getattr(channel, "username", None),
            photo=types.ChatPhoto._parse(client,
                                         getattr(channel, "photo",
                                                 None), peer_id,
                                         getattr(channel, "access_hash", 0)),
            restrictions=types.List(
                [types.Restriction._parse(r) for r in restriction_reason])
            or None,
            permissions=types.ChatPermissions._parse(
                getattr(channel, "default_banned_rights", None)),
            members_count=getattr(channel, "participants_count", None),
            # The DC id comes from the photo object, when there is one.
            dc_id=getattr(getattr(channel, "photo", None), "dc_id", None),
            # Telegram's "noforwards" flag is exposed as protected content.
            has_protected_content=getattr(channel, "noforwards", None),
            client=client)
コード例 #4
0
    async def get_chat(
            self,
            chat_id: Union[int,
                           str]) -> Union["types.Chat", "types.ChatPreview"]:
        """Get up to date information about a chat.

        Information include current name of the user for one-on-one conversations, current username of a user, group or
        channel, etc.

        Parameters:
            chat_id (``int`` | ``str``):
                Unique identifier (int) or username (str) of the target chat.
                An invite link matching ``INVITE_LINK_RE`` (e.g. a *t.me/joinchat/* link) is accepted as well.

        Returns:
            :obj:`~pyrogram.types.Chat` | :obj:`~pyrogram.types.ChatPreview`: On success, if you've already joined the chat, a chat object is returned,
            otherwise, a chat preview object is returned.

        Raises:
            Errors raised by :meth:`resolve_peer` or by the underlying API requests are propagated unchanged.

        Example:
            .. code-block:: python

                chat = app.get_chat("pyrogram")
                print(chat)
        """
        match = self.INVITE_LINK_RE.match(str(chat_id))

        if match:
            r = await self.send(
                raw.functions.messages.CheckChatInvite(hash=match.group(1)))

            if isinstance(r, raw.types.ChatInvite):
                # Not a member yet: only a preview is available.
                return types.ChatPreview._parse(self, r)

            # Already a member: store the peer, then continue with its numeric
            # id. Forbidden variants are mapped the same way fetch_peers maps
            # them, so they resolve from storage instead of as a link string.
            await self.fetch_peers([r.chat])

            if isinstance(r.chat, (raw.types.Chat, raw.types.ChatForbidden)):
                chat_id = -r.chat.id
            elif isinstance(r.chat, (raw.types.Channel, raw.types.ChannelForbidden)):
                chat_id = utils.get_channel_id(r.chat.id)

        peer = await self.resolve_peer(chat_id)

        if isinstance(peer, raw.types.InputPeerChannel):
            r = await self.send(
                raw.functions.channels.GetFullChannel(channel=peer))
        elif isinstance(peer,
                        (raw.types.InputPeerUser, raw.types.InputPeerSelf)):
            r = await self.send(raw.functions.users.GetFullUser(id=peer))
        else:
            # Remaining case: a basic group peer, queried by its chat id.
            r = await self.send(
                raw.functions.messages.GetFullChat(chat_id=peer.chat_id))

        return await types.Chat._parse_full(self, r)
コード例 #5
0
    def _parse(client, action: raw.base.MessageAction, users: dict,
               chats: dict):
        """Convert a raw migrate-to action into :obj:`MessageActionChatMigrateTo`.

        ``users`` and ``chats`` are accepted but not used by this parser.

        Returns:
            The parsed action, whose ``channel_id`` is ``action.channel_id``
            converted with ``utils.get_channel_id``, or ``None`` when
            ``action`` is ``None``.
        """
        if action is not None:
            target_channel = utils.get_channel_id(action.channel_id)
            return MessageActionChatMigrateTo(
                client=client,
                channel_id=target_channel,
            )

        return None
コード例 #6
0
async def link_message(client, message):
    """Reply with a ``https://t.me/c/<chat>/<message>`` link.

    The link targets the replied-to message when the command is a reply,
    otherwise the command message itself. The chat part of the link is
    ``message.chat.id`` converted with ``utils.get_channel_id``. In private
    and bot chats no link is sent; the command message is deleted instead.
    """
    # A chat has exactly one type, so the previous
    # ``type == "private" and type == "bot"`` could never be true and this
    # branch was unreachable. Test membership instead.
    if message.chat.type in ("private", "bot"):
        await message.delete()
        return

    if message.reply_to_message:
        target_message_id = message.reply_to_message.message_id
    else:
        target_message_id = message.message_id

    link_chat_id = utils.get_channel_id(message.chat.id)
    await edrep(message, text=f'https://t.me/c/{link_chat_id}/{target_message_id}')
コード例 #7
0
    def _parse_input_channel(client: "pyrogram.Client",
                             input_channel: raw.types.InputChannel):
        """Wrap a raw ``InputChannel`` in a :obj:`Channel` object.

        The raw ``channel_id`` is converted with ``utils.get_channel_id``. The
        access hash is copied when the attribute exists and is ``None``
        otherwise. A ``None`` input yields ``None``.
        """
        if input_channel is not None:
            return Channel(
                client=client,
                id=utils.get_channel_id(input_channel.channel_id),
                access_hash=getattr(input_channel, 'access_hash', None),
            )

        return None
コード例 #8
0
    def _parse_channel_chat(client,
                            channel: raw.types.Channel) -> Optional["Chat"]:
        """Build a lightweight (non-full) :obj:`Chat` around a raw channel.

        The raw channel is also parsed into a :obj:`Channel` and attached as
        ``channel``. The chat type is ``"supergroup"`` when ``megagroup`` is
        set, ``"channel"`` otherwise. A ``None`` input yields ``None``.
        """
        if channel is None:
            return None

        chat_id = utils.get_channel_id(channel.id)
        kind = "supergroup" if channel.megagroup else "channel"

        return Chat(
            client=client,
            id=chat_id,
            type=kind,
            channel=types.Channel._parse(client, channel),
            is_full_type=False,
        )
コード例 #9
0
    async def get_nearby_chats(
        self,
        latitude: float,
        longitude: float
    ) -> List["types.Chat"]:
        """Get nearby chats.

        Parameters:
            latitude (``float``):
                Latitude of the location.

            longitude (``float``):
                Longitude of the location.

        Returns:
            List of :obj:`~pyrogram.types.Chat`: On success, a list of nearby chats is returned.
            Chats that appear as located channel peers get their ``distance`` attribute set.
            An empty list is returned when the response carries no updates.

        Example:
            .. code-block:: python

                chats = app.get_nearby_chats(51.500729, -0.124583)
                print(chats)
        """

        r = await self.send(
            raw.functions.contacts.GetLocated(
                geo_point=raw.types.InputGeoPoint(
                    lat=latitude,
                    long=longitude
                )
            )
        )

        if not r.updates:
            return []

        chats = types.List([types.Chat._parse_chat(self, chat) for chat in r.chats])

        # Index chats by id once instead of scanning the list for every located
        # peer. setdefault keeps the first chat seen for a given id.
        chats_by_id = {}
        for chat in chats:
            chats_by_id.setdefault(chat.id, chat)

        for located in r.updates[0].peers:
            # Some located entries (e.g. a self-location entry) have no
            # ``peer`` attribute; they name no chat to annotate, so skip them
            # instead of raising AttributeError.
            inner_peer = getattr(located, "peer", None)

            if isinstance(inner_peer, raw.types.PeerChannel):
                chat = chats_by_id.get(utils.get_channel_id(inner_peer.channel_id))

                if chat is not None:
                    chat.distance = located.distance

        return chats
コード例 #10
0
    async def _parse_channel_full_chat(client,
                                       chat_full: raw.types.messages.ChatFull,
                                       users: dict, chats: dict) -> "Chat":
        """Build a full :obj:`Chat` from a raw ``messages.ChatFull`` of a channel.

        The matching raw channel is looked up in ``chats`` by the raw full-chat
        id (a missing entry raises ``KeyError``). The full info is parsed with
        ``ChannelFull._parse`` and the raw channel with ``Channel._parse``.
        """
        raw_id = chat_full.full_chat.id
        chat_id = utils.get_channel_id(raw_id)
        raw_channel = chats[raw_id]

        kind = "supergroup" if raw_channel.megagroup else "channel"
        full_info = await types.ChannelFull._parse(
            client, chat_full.full_chat, users, chats)

        return Chat(
            client=client,
            id=chat_id,
            type=kind,
            full_channel=full_info,
            channel=types.Channel._parse(client, raw_channel),
            is_full_type=True,
        )
コード例 #11
0
def get_input_peer(peer):
    """Build a raw ``InputPeer`` from a stored peer record.

    This function is almost blindly copied from pyrogram's SQLite storage.

    Parameters:
        peer: A mapping with ``'id'``, ``'type'`` and ``'access_hash'`` keys,
            where ``'type'`` holds a ``PeerType`` value.

    Returns:
        ``InputPeerUser`` for users and bots, ``InputPeerChat`` (negated id)
        for basic groups, or ``InputPeerChannel`` (id converted with
        ``utils.get_channel_id``) for channels and supergroups.

    Raises:
        ValueError: If the peer type is not one of the above.
    """
    stored_id = peer['id']
    kind = peer['type']
    stored_hash = peer['access_hash']

    user_kinds = {PeerType.USER.value, PeerType.BOT.value}
    channel_kinds = {PeerType.CHANNEL.value, PeerType.SUPERGROUP.value}

    if kind in user_kinds:
        return raw.types.InputPeerUser(user_id=stored_id, access_hash=stored_hash)
    elif kind == PeerType.GROUP.value:
        return raw.types.InputPeerChat(chat_id=-stored_id)
    elif kind in channel_kinds:
        return raw.types.InputPeerChannel(
            channel_id=utils.get_channel_id(stored_id),
            access_hash=stored_hash,
        )
    else:
        raise ValueError(f"Invalid peer type: {peer['type']}")
コード例 #12
0
    async def _parse(client: "pyrogram.Client",
                     message_replies: raw.base.MessageReplies, users: dict,
                     chats: dict):
        """Convert raw ``MessageReplies`` into a :obj:`MessageReplies` object.

        Parameters:
            client: The client instance bound to the result.
            message_replies: The raw replies info; ``None`` yields ``None``.
            users: Raw users keyed by user id, used to resolve repliers.
            chats: Raw chats keyed by chat/channel id, used to resolve repliers.

        Returns:
            :obj:`MessageReplies` | ``None``: ``recent_repliers`` holds the
            repliers that could be resolved from ``users``/``chats`` (or
            ``None`` if none could); ``channel_id`` is converted with
            ``utils.get_channel_id`` when present.
        """
        if message_replies is None:
            return None

        def get_replier(peer: "raw.base.Peer"):
            """Return ``(raw_object, kind)`` for ``peer``, or ``(None, None)``."""
            if isinstance(peer, raw.types.PeerUser):
                return users.get(peer.user_id, None), 'user'
            if isinstance(peer, raw.types.PeerChat):
                return chats.get(peer.chat_id, None), 'group'
            if isinstance(peer, raw.types.PeerChannel):
                return chats.get(peer.channel_id, None), 'channel'
            # None or an unrecognised peer type: always return a pair so the
            # caller's tuple unpacking cannot fail.
            return None, None

        recent_repliers = None
        if message_replies.recent_repliers:
            parsed_peers = []
            for peer in message_replies.recent_repliers:
                _peer, _type = get_replier(peer)
                # Skip repliers whose raw object could not be looked up. The
                # loop variable ``peer`` is never None here, so this must test
                # the looked-up object instead.
                if _peer is None:
                    continue

                if _type == 'user':
                    parsed_peer = types.User._parse(client, _peer)
                else:
                    parsed_peer = await types.Chat._parse_chat(client, _peer)

                if parsed_peer:
                    parsed_peers.append(parsed_peer)

            if parsed_peers:
                recent_repliers = types.List(parsed_peers)

        return MessageReplies(
            client=client,
            comments=getattr(message_replies, 'comments', None),
            replies=getattr(message_replies, 'replies', None),
            replies_pts=getattr(message_replies, 'replies_pts', None),
            recent_repliers=recent_repliers,
            channel_id=utils.get_channel_id(message_replies.channel_id)
            if getattr(message_replies, 'channel_id', None) else None,
            max_id=getattr(message_replies, 'max_id', None),
            read_max_id=getattr(message_replies, 'read_max_id', None),
        )
コード例 #13
0
    def _parse(client, channel):
        """Build a :obj:`Channel` from a raw channel object.

        Apart from ``id``, every field is read with ``getattr`` and a default,
        so raw objects that lack a field still parse. Missing fields become
        ``None`` (the photo's access hash falls back to ``0``).
        ``is_forbidden`` is true for ``ChannelForbidden`` inputs.

        Parameters:
            client: The client instance bound to the result.
            channel: The raw channel; ``None`` yields ``None``.

        Returns:
            :obj:`Channel` | ``None``
        """
        if channel is None:
            return None

        peer_id = utils.get_channel_id(channel.id)
        # The attribute can be missing or present but ``None``; ``or []``
        # covers both so the comprehension below never iterates over ``None``.
        restriction_reason = getattr(channel, 'restriction_reason', None) or []

        return Channel(
            client=client,
            id=peer_id,
            title=getattr(channel, 'title', None),
            is_forbidden=isinstance(channel, raw.types.ChannelForbidden),
            forbidden_until=getattr(channel, 'until_date', None),
            username=getattr(channel, 'username', None),
            photo=types.ChatPhoto._parse(client,
                                         getattr(channel, "photo",
                                                 None), peer_id,
                                         getattr(channel, 'access_hash', 0)),
            is_creator=getattr(channel, 'creator', None),
            left=getattr(channel, 'left', None),
            is_broadcast=getattr(channel, 'broadcast', None),
            is_verified=getattr(channel, 'verified', None),
            is_supergroup=getattr(channel, 'megagroup', None),
            is_restricted=getattr(channel, 'restricted', None),
            signatures_enabled=getattr(channel, 'signatures', None),
            min=getattr(channel, 'min', None),
            is_scam=getattr(channel, 'scam', None),
            is_fake=getattr(channel, "fake", None),
            has_private_join_link=getattr(channel, 'has_link', None),
            has_geo=getattr(channel, 'has_geo', None),
            slow_mode=getattr(channel, 'slowmode_enabled', None),
            access_hash=getattr(channel, 'access_hash', None),
            date=getattr(channel, 'date', None),
            version=getattr(channel, 'version', None),
            restrictions=types.List([
                types.Restriction._parse(r)
                for r in restriction_reason
            ]) or None,
            admin_rights=types.ChatAdminRights._parse(
                getattr(channel, 'admin_rights', None)),
            banned_rights=types.ChatPermissions._parse(
                getattr(channel, 'banned_rights', None)),
            default_banned_rights=types.ChatPermissions._parse(
                getattr(channel, 'default_banned_rights', None)),
            members_count=getattr(channel, 'participants_count', None),
        )
コード例 #14
0
    async def get_location(
        file_id: FileId
    ) -> Union[raw.types.InputPhotoFileLocation,
               raw.types.InputDocumentFileLocation,
               raw.types.InputPeerPhotoFileLocation, ]:
        """Return the raw input file location for a decoded file id.

        - ``CHAT_PHOTO``: an ``InputPeerPhotoFileLocation`` whose peer is a
          user (positive ``chat_id``), a basic group (non-positive id with a
          zero chat access hash) or a channel (non-positive id with a non-zero
          hash, converted with ``utils.get_channel_id``).
        - ``PHOTO``: an ``InputPhotoFileLocation``.
        - anything else: an ``InputDocumentFileLocation``.
        """
        kind = file_id.file_type

        if kind == FileType.CHAT_PHOTO:
            owner_id = file_id.chat_id
            owner_hash = file_id.chat_access_hash

            if owner_id > 0:
                owner = raw.types.InputPeerUser(user_id=owner_id,
                                                access_hash=owner_hash)
            elif owner_hash == 0:
                owner = raw.types.InputPeerChat(chat_id=-owner_id)
            else:
                owner = raw.types.InputPeerChannel(
                    channel_id=utils.get_channel_id(owner_id),
                    access_hash=owner_hash,
                )

            return raw.types.InputPeerPhotoFileLocation(
                peer=owner,
                volume_id=file_id.volume_id,
                local_id=file_id.local_id,
                big=file_id.thumbnail_source == ThumbnailSource.CHAT_PHOTO_BIG,
            )

        # Photos and documents share the same constructor arguments.
        location_type = (raw.types.InputPhotoFileLocation
                         if kind == FileType.PHOTO
                         else raw.types.InputDocumentFileLocation)

        return location_type(
            id=file_id.media_id,
            access_hash=file_id.access_hash,
            file_reference=file_id.file_reference,
            thumb_size=file_id.thumbnail_size,
        )
コード例 #15
0
ファイル: metrics.py プロジェクト: legenhand/Nana-Remix
async def get_inactive(client, message):
    """Build a report of when recent, non-deleted chat members last posted.

    The command's arguments, joined with spaces, are parsed as an ``int``
    member limit (``0`` when no argument is given). For every recent member
    whose account is not deleted, one message found by
    ``search_messages(..., limit=1)`` is collected. The lines are sorted by
    message date, and a final line reports the elapsed time in milliseconds.
    """
    args = message.command
    started = time.time()
    limit = int(" ".join(args[1:])) if len(args) > 1 else 0

    messages = []
    async for member in client.iter_chat_members(
            message.chat.id, limit=limit, filter="recent"):
        if member.user.is_deleted:
            continue
        async for found in client.search_messages(
                message.chat.id, limit=1, from_user=member.user.id):
            messages.append(found)

    elapsed = time.time() - started
    messages.sort(key=lambda k: k["date"])

    lines = [
        "[{}](tg://user?id={}) last [message](https://t.me/c/{}/{}) was {}".format(
            m.from_user.first_name,
            m.from_user.id,
            get_channel_id(m.chat.id),
            m.message_id,
            timeago.format(m.date),
        )
        for m in messages
    ]
    lines.append(f"`{int(elapsed * 1000)}ms`")

    return "\n".join(lines)
コード例 #16
0
    async def resolve_peer(
        self,
        peer_id: Union[int, str]
    ) -> Union[raw.base.InputPeer, raw.base.InputUser, raw.base.InputChannel]:
        """Get the InputPeer of a known peer id.
        Useful whenever an InputPeer type is required.

        .. note::

            This is a utility method intended to be used **only** when working with raw
            :obj:`functions <pyrogram.api.functions>` (i.e: a Telegram API method you wish to use which is not
            available yet in the Client class as an easy-to-use method).

        Lookup order:

        1. The storage, by ``peer_id`` as given.
        2. For strings: ``"self"``/``"me"`` (exact, case-sensitive match), then,
           after lower-casing and stripping ``@``, ``+`` and whitespace, either
           a username lookup (resolving it with ``contacts.ResolveUsername`` on
           a cache miss) or, if the cleaned string parses as ``int``, a phone
           number lookup.
        3. For other ids: the peer is requested from Telegram
           (``users.GetUsers``, ``messages.GetChats`` or ``channels.GetChannels``,
           chosen by ``utils.get_peer_type``) and the storage is queried again.

        Parameters:
            peer_id (``int`` | ``str``):
                The peer id you want to extract the InputPeer from.
                Can be a direct id (int), a username (str) or a phone number (str).

        Returns:
            ``InputPeer``: On success, the resolved peer id is returned in form of an InputPeer object.

        Raises:
            ConnectionError: If the client has not been started.
            PeerIdInvalid: If a phone number is not in storage, or an id is
                still not in storage after fetching it from Telegram.
            KeyError: If a username is still not in storage after
                ``contacts.ResolveUsername`` returned.
        """
        if not self.is_connected:
            raise ConnectionError("Client has not been started yet")

        try:
            return await self.storage.get_peer_by_id(peer_id)
        except KeyError:
            # Every path through this string branch returns or raises, so only
            # non-string ids reach the remote-fetch code below it.
            if isinstance(peer_id, str):
                if peer_id in ("self", "me"):
                    return raw.types.InputPeerSelf()

                peer_id = re.sub(r"[@+\s]", "", peer_id.lower())

                try:
                    int(peer_id)
                except ValueError:
                    # Not numeric: treat it as a username.
                    try:
                        return await self.storage.get_peer_by_username(peer_id)
                    except KeyError:
                        # NOTE(review): the response is discarded; this relies
                        # on self.send storing the returned peers — confirm.
                        await self.send(
                            raw.functions.contacts.ResolveUsername(
                                username=peer_id
                            )
                        )

                        return await self.storage.get_peer_by_username(peer_id)
                else:
                    # Numeric string: treat it as a phone number.
                    try:
                        return await self.storage.get_peer_by_phone_number(peer_id)
                    except KeyError:
                        raise PeerIdInvalid

            peer_type = utils.get_peer_type(peer_id)

            if peer_type == "user":
                # GetUsers returns the users directly, so they are stored
                # explicitly through fetch_peers. An access hash of 0 is sent.
                await self.fetch_peers(
                    await self.send(
                        raw.functions.users.GetUsers(
                            id=[
                                raw.types.InputUser(
                                    user_id=peer_id,
                                    access_hash=0
                                )
                            ]
                        )
                    )
                )
            elif peer_type == "chat":
                # NOTE(review): result discarded; presumably self.send stores
                # the returned chats — confirm.
                await self.send(
                    raw.functions.messages.GetChats(
                        id=[-peer_id]
                    )
                )
            else:
                # Channels: the raw channel id is sent with an access hash of 0.
                await self.send(
                    raw.functions.channels.GetChannels(
                        id=[
                            raw.types.InputChannel(
                                channel_id=utils.get_channel_id(peer_id),
                                access_hash=0
                            )
                        ]
                    )
                )

            try:
                return await self.storage.get_peer_by_id(peer_id)
            except KeyError:
                raise PeerIdInvalid
コード例 #17
0
ファイル: client.py プロジェクト: TSpidermanBoss/pyrogramx
    async def handle_updates(self, updates):
        """Store peers from an incoming update container and queue its updates.

        - ``Updates`` / ``UpdatesCombined``: users and chats are stored with
          ``fetch_peers``, then each update is queued together with id-keyed
          users/chats dicts. For ``UpdateNewChannelMessage`` received while
          ``fetch_peers`` reported a ``min`` peer, ``GetChannelDifference`` is
          used to add the message's users/chats to those dicts.
        - ``UpdateShortMessage`` / ``UpdateShortChatMessage``: ``GetDifference``
          is requested. The first new message is queued as ``UpdateNewMessage``;
          otherwise the first other update is queued, if there is one.
        - ``UpdateShort``: the wrapped update is queued with empty dicts.
        - ``UpdatesTooLong``: only logged.
        """
        if isinstance(updates, (raw.types.Updates, raw.types.UpdatesCombined)):
            # Await both calls unconditionally: combining them with ``or``
            # short-circuits and skips storing ``updates.chats`` whenever a
            # min user is present.
            users_min = await self.fetch_peers(updates.users)
            chats_min = await self.fetch_peers(updates.chats)
            is_min = users_min or chats_min

            users = {u.id: u for u in updates.users}
            chats = {c.id: c for c in updates.chats}

            for update in updates.updates:
                channel_id = getattr(
                    getattr(getattr(update, "message", None), "to_id", None),
                    "channel_id", None) or getattr(update, "channel_id", None)

                pts = getattr(update, "pts", None)
                pts_count = getattr(update, "pts_count", None)

                if isinstance(update, raw.types.UpdateChannelTooLong):
                    log.warning(update)

                # With min peers present, refetch this single message through
                # GetChannelDifference so its users/chats are available.
                if isinstance(update,
                              raw.types.UpdateNewChannelMessage) and is_min:
                    message = update.message

                    if not isinstance(message, raw.types.MessageEmpty):
                        try:
                            diff = await self.send(
                                raw.functions.updates.GetChannelDifference(
                                    channel=await self.resolve_peer(
                                        utils.get_channel_id(channel_id)),
                                    filter=raw.types.ChannelMessagesFilter(
                                        ranges=[
                                            raw.types.MessageRange(
                                                min_id=update.message.id,
                                                max_id=update.message.id)
                                        ]),
                                    pts=pts - pts_count,
                                    limit=pts))
                        except ChannelPrivate:
                            # Skip the refetch; the update is still dispatched
                            # with the peers already known.
                            pass
                        else:
                            if not isinstance(
                                    diff,
                                    raw.types.updates.ChannelDifferenceEmpty):
                                users.update({u.id: u for u in diff.users})
                                chats.update({c.id: c for c in diff.chats})

                self.dispatcher.updates_queue.put_nowait(
                    (update, users, chats))
        elif isinstance(
                updates,
            (raw.types.UpdateShortMessage, raw.types.UpdateShortChatMessage)):
            diff = await self.send(
                raw.functions.updates.GetDifference(pts=updates.pts -
                                                    updates.pts_count,
                                                    date=updates.date,
                                                    qts=-1))

            # Either list may be empty, and some difference types may not have
            # them at all, so read them defensively before indexing.
            new_messages = getattr(diff, "new_messages", None)
            other_updates = getattr(diff, "other_updates", None)

            if new_messages:
                self.dispatcher.updates_queue.put_nowait(
                    (raw.types.UpdateNewMessage(message=new_messages[0],
                                                pts=updates.pts,
                                                pts_count=updates.pts_count),
                     {u.id: u for u in diff.users},
                     {c.id: c for c in diff.chats}))
            elif other_updates:
                self.dispatcher.updates_queue.put_nowait(
                    (other_updates[0], {}, {}))
        elif isinstance(updates, raw.types.UpdateShort):
            self.dispatcher.updates_queue.put_nowait((updates.update, {}, {}))
        elif isinstance(updates, raw.types.UpdatesTooLong):
            log.info(updates)
コード例 #18
0
    async def _parse(client: "pyrogram.Client",
                     channel_full: raw.types.ChannelFull, users, chats):
        """Build a :obj:`ChannelFull` from a raw ``ChannelFull`` object.

        The linked chat and the chat this channel migrated from are looked up
        in ``chats`` and parsed with ``Chat._parse_chat`` when found. When a
        pinned message id is set, the pinned message is fetched with
        ``client.get_messages`` (a network request). Optional fields are read
        with ``getattr`` and a ``None`` default.

        Parameters:
            client: The client instance bound to the result.
            channel_full: The raw full channel info; ``None`` yields ``None``.
            users: Raw users keyed by id (forwarded to nothing here; kept for
                signature compatibility).
            chats: Raw chats keyed by raw id.

        Returns:
            :obj:`ChannelFull` | ``None``
        """
        if channel_full is None:
            return None

        peer_id = utils.get_channel_id(channel_full.id)
        raw_linked_chat = chats.get(
            getattr(channel_full, 'linked_chat_id', None), None)
        if raw_linked_chat:
            linked_chat = await types.Chat._parse_chat(client, raw_linked_chat)
        else:
            linked_chat = None

        migrated_from = None
        # Use a default like every other optional field so a missing attribute
        # reads as "not migrated" instead of raising AttributeError.
        migrated_from_chat_id = getattr(channel_full, 'migrated_from_chat_id',
                                        None)
        if migrated_from_chat_id:
            chat = chats.get(migrated_from_chat_id, None)
            if chat:
                migrated_from = await types.Chat._parse_chat(client, chat)

        exported_invite = getattr(channel_full, 'exported_invite', None)

        return ChannelFull(
            client=client,
            id=peer_id,
            can_view_participants=getattr(channel_full,
                                          'can_view_participants', None),
            can_set_username=getattr(channel_full, 'can_set_username', None),
            can_set_stickers=getattr(channel_full, 'can_set_stickers', None),
            is_prehistory_hidden=getattr(channel_full, 'hidden_prehistory',
                                         None),
            can_set_location=getattr(channel_full, 'can_set_location', None),
            has_scheduled=getattr(channel_full, 'has_scheduled', None),
            can_view_stats=getattr(channel_full, 'can_view_stats', None),
            is_blocked=getattr(channel_full, 'blocked', None),
            about=getattr(channel_full, 'about', None),
            members_count=getattr(channel_full, 'participants_count', None),
            admins_count=getattr(channel_full, 'admins_count', None),
            kicked_count=getattr(channel_full, 'kicked_count', None),
            banned_count=getattr(channel_full, 'banned_count', None),
            online_count=getattr(channel_full, 'online_count', None),
            read_inbox_max_id=getattr(channel_full, 'read_inbox_max_id', None),
            read_outbox_max_id=getattr(channel_full, 'read_outbox_max_id',
                                       None),
            unread_count=getattr(channel_full, 'unread_count', None),
            chat_photo=types.Photo._parse(
                client, getattr(channel_full, "chat_photo", None)),
            notify_settings=types.PeerNotifySettings._parse(
                client, getattr(channel_full, 'notify_settings', None)),
            pinned_message=await client.get_messages(
                peer_id, channel_full.pinned_msg_id) if getattr(
                    channel_full, 'pinned_msg_id', None) else None,
            # Only an exported invite object carries a usable link.
            invite_link=exported_invite.link if isinstance(
                exported_invite, raw.types.ChatInviteExported)
            else None,
            bot_infos=types.List([
                types.BotInfo._parse(client, r)
                for r in getattr(channel_full, 'bot_info', [])
            ]) or None,
            migrated_from=migrated_from,
            migrated_from_max_id=getattr(channel_full, 'migrated_from_max_id',
                                         None),
            stickerset=types.StickerSet._parse(
                client, getattr(channel_full, 'stickerset', None)),
            min_available_message_id=getattr(channel_full, 'available_min_id',
                                             None),
            folder_id=getattr(channel_full, 'folder_id', None),
            linked_chat=linked_chat,
            location=types.ChannelLocation._parse(
                client, getattr(channel_full, 'location', None)),
            slowmode_seconds=getattr(channel_full, 'slowmode_seconds', None),
            slowmode_next_send_date=getattr(channel_full,
                                            'slowmode_next_send_date', None),
            stats_dc=getattr(channel_full, 'stats_dc', None),
            pts=getattr(channel_full, 'pts', None),
        )
コード例 #19
0
    async def generate_file_properties(msg: Message):
        """Decode the file id of the first downloadable media in ``msg``.

        Media attributes are checked in the order of ``available_media``, and
        the first non-``None`` one is used. When it is an object, its file
        name, size, MIME type, date and file reference are copied. When it is
        a bare string, the string itself is the file id and that metadata
        stays ``None``. The decoded file-id fields are then added according
        to the media type.

        Returns:
            ``FileData``: The copied metadata plus the decoded file-id fields.

        Raises:
            ValueError: If ``msg`` has none of the supported media, or the
                decoded media type is not handled.
            FileIdInvalid: If the file id cannot be decoded or unpacked.
        """
        error_message = "This message doesn't contain any downloadable media"
        available_media = ("audio", "document", "photo", "sticker", "animation", "video", "voice", "video_note")

        media_file_name = None
        file_size = None
        mime_type = None
        date = None
        # Must be initialised here as well: a media given as a bare file-id
        # string skips the branch that sets it, and FileData(...) reads it
        # (previously an UnboundLocalError).
        file_ref = None

        for kind in available_media:
            media = getattr(msg, kind, None)

            if media is not None:
                break
        else:
            raise ValueError(error_message)

        if isinstance(media, str):
            file_id_str = media
        else:
            file_id_str = media.file_id
            media_file_name = getattr(media, "file_name", "")
            file_size = getattr(media, "file_size", None)
            mime_type = getattr(media, "mime_type", None)
            date = getattr(media, "date", None)
            file_ref = getattr(media, "file_ref", None)

        data = FileData(
            file_name=media_file_name,
            file_size=file_size,
            mime_type=mime_type,
            date=date,
            file_ref=file_ref
        )

        def get_existing_attributes() -> dict:
            """Return the non-``None`` attributes of the current ``data``."""
            return dict(filter(lambda x: x[1] is not None, data.__dict__.items()))

        try:
            decoded = utils.decode_file_id(file_id_str)
            media_type = decoded[0]

            if media_type == 1:
                # Chat photo layout: the peer kind is encoded in ``x``.
                unpacked = struct.unpack("<iiqqqiiiqi", decoded)
                dc_id, _1, _2, volume_id, size_type, peer_id, x, peer_access_hash, local_id = unpacked[1:]

                if x == 0:
                    peer_type = "user"
                elif x == -1:
                    peer_id = -peer_id
                    peer_type = "chat"
                else:
                    peer_id = utils.get_channel_id(peer_id - 1000727379968)
                    peer_type = "channel"

                data = FileData(
                    **get_existing_attributes(),
                    media_type=media_type,
                    dc_id=dc_id,
                    peer_id=peer_id,
                    peer_type=peer_type,
                    peer_access_hash=peer_access_hash,
                    volume_id=volume_id,
                    local_id=local_id,
                    is_big=size_type == 3
                )
            elif media_type in (0, 2, 14):
                unpacked = struct.unpack("<iiqqqiiii", decoded)
                dc_id, document_id, access_hash, volume_id, _, _, thumb_size, local_id = unpacked[1:]

                data = FileData(
                    **get_existing_attributes(),
                    media_type=media_type,
                    dc_id=dc_id,
                    document_id=document_id,
                    access_hash=access_hash,
                    thumb_size=chr(thumb_size)
                )
            elif media_type in (3, 4, 5, 8, 9, 10, 13):
                unpacked = struct.unpack("<iiqq", decoded)
                dc_id, document_id, access_hash = unpacked[1:]

                data = FileData(
                    **get_existing_attributes(),
                    media_type=media_type,
                    dc_id=dc_id,
                    document_id=document_id,
                    access_hash=access_hash
                )
            else:
                # Not in the caught tuple below, so it propagates as ValueError.
                raise ValueError(f"Unknown media type: {file_id_str}")
            return data
        except (AssertionError, binascii.Error, struct.error):
            raise FileIdInvalid from None
コード例 #20
0
    async def get_file(self,
                       file_id: FileId,
                       file_size: int,
                       progress: callable,
                       progress_args: tuple = ()) -> str:
        """Download a file into a temporary file and return its path.

        A media session for the file's DC is reused from ``self.media_sessions``
        or created on first use. For a DC other than the stored home DC, the
        current authorization is exported and imported into the new session.
        When the server answers with a CDN redirect, the parts are fetched from
        the CDN DC, decrypted with AES-256-CTR and checked against the hashes
        returned by ``upload.GetCdnFileHashes``.

        Parameters:
            file_id: Decoded file identifier (DC, media id, access hash, ...).
            file_size: Expected size in bytes. It is used only to cap the value
                passed to ``progress``; ``0`` means unknown.
            progress: Optional callback invoked after each part with
                ``(current, total, *progress_args)``. Coroutine functions are
                awaited; plain callables run in ``self.executor``.
            progress_args: Extra positional arguments for ``progress``.

        Returns:
            The path of the temporary file holding the downloaded data, or
            ``""`` if any exception (including ``StopTransmission``) occurred
            while downloading. In that case the partial file is removed and
            the error is logged unless it was ``StopTransmission``.

        Raises:
            AuthBytesInvalid: if importing the exported authorization into a
                new foreign-DC session failed three times.
        """
        dc_id = file_id.dc_id

        async with self.media_sessions_lock:
            session = self.media_sessions.get(dc_id, None)

            if session is None:
                if dc_id != await self.storage.dc_id():
                    # Foreign DC: build a session with a key from
                    # Auth(...).create(), then import the authorization
                    # exported from the current session into it.
                    session = Session(
                        self,
                        dc_id,
                        await Auth(self, dc_id, await self.storage.test_mode()).create(),
                        await self.storage.test_mode(),
                        is_media=True)
                    await session.start()

                    for _ in range(3):
                        exported_auth = await self.send(
                            raw.functions.auth.ExportAuthorization(dc_id=dc_id))

                        try:
                            await session.send(
                                raw.functions.auth.ImportAuthorization(
                                    id=exported_auth.id,
                                    bytes=exported_auth.bytes))
                        except AuthBytesInvalid:
                            continue
                        else:
                            break
                    else:
                        await session.stop()
                        raise AuthBytesInvalid
                else:
                    # Home DC: reuse the stored auth key.
                    session = Session(
                        self,
                        dc_id,
                        await self.storage.auth_key(),
                        await self.storage.test_mode(),
                        is_media=True)
                    await session.start()

                self.media_sessions[dc_id] = session

        file_type = file_id.file_type

        if file_type == FileType.CHAT_PHOTO:
            if file_id.chat_id > 0:
                peer = raw.types.InputPeerUser(
                    user_id=file_id.chat_id,
                    access_hash=file_id.chat_access_hash)
            elif file_id.chat_access_hash == 0:
                # Negative id without an access hash: a basic group chat.
                peer = raw.types.InputPeerChat(chat_id=-file_id.chat_id)
            else:
                peer = raw.types.InputPeerChannel(
                    channel_id=utils.get_channel_id(file_id.chat_id),
                    access_hash=file_id.chat_access_hash)

            location = raw.types.InputPeerPhotoFileLocation(
                peer=peer,
                volume_id=file_id.volume_id,
                local_id=file_id.local_id,
                big=file_id.thumbnail_source == ThumbnailSource.CHAT_PHOTO_BIG)
        elif file_type == FileType.PHOTO:
            location = raw.types.InputPhotoFileLocation(
                id=file_id.media_id,
                access_hash=file_id.access_hash,
                file_reference=file_id.file_reference,
                thumb_size=file_id.thumbnail_size)
        else:
            location = raw.types.InputDocumentFileLocation(
                id=file_id.media_id,
                access_hash=file_id.access_hash,
                file_reference=file_id.file_reference,
                thumb_size=file_id.thumbnail_size)

        async def report_progress(position: int) -> None:
            # Cap at the known size (``position`` advances by a whole part even
            # when the last part is shorter); report raw bytes if size unknown.
            func = functools.partial(
                progress,
                min(position, file_size) if file_size != 0 else position,
                file_size,
                *progress_args)

            if inspect.iscoroutinefunction(progress):
                await func()
            else:
                await self.loop.run_in_executor(self.executor, func)

        limit = 1024 * 1024
        offset = 0
        file_name = ""

        try:
            r = await session.send(
                raw.functions.upload.GetFile(
                    location=location, offset=offset, limit=limit),
                sleep_threshold=30)

            if isinstance(r, raw.types.upload.File):
                with tempfile.NamedTemporaryFile("wb", delete=False) as f:
                    file_name = f.name

                    # Keep requesting parts until the server returns an empty one.
                    while True:
                        chunk = r.bytes

                        if not chunk:
                            break

                        f.write(chunk)

                        offset += limit

                        if progress:
                            await report_progress(offset)

                        r = await session.send(
                            raw.functions.upload.GetFile(
                                location=location, offset=offset, limit=limit),
                            sleep_threshold=30)

            elif isinstance(r, raw.types.upload.FileCdnRedirect):
                async with self.media_sessions_lock:
                    cdn_session = self.media_sessions.get(r.dc_id, None)

                    if cdn_session is None:
                        cdn_session = Session(
                            self,
                            r.dc_id,
                            await Auth(self, r.dc_id, await self.storage.test_mode()).create(),
                            await self.storage.test_mode(),
                            is_media=True,
                            is_cdn=True)

                        await cdn_session.start()

                        self.media_sessions[r.dc_id] = cdn_session

                with tempfile.NamedTemporaryFile("wb", delete=False) as f:
                    file_name = f.name

                    while True:
                        r2 = await cdn_session.send(
                            raw.functions.upload.GetCdnFile(
                                file_token=r.file_token,
                                offset=offset,
                                limit=limit))

                        if isinstance(r2, raw.types.upload.CdnFileReuploadNeeded):
                            # Ask the origin DC to push the file to the CDN again,
                            # then retry the same offset.
                            try:
                                await session.send(
                                    raw.functions.upload.ReuploadCdnFile(
                                        file_token=r.file_token,
                                        request_token=r2.request_token))
                            except VolumeLocNotFound:
                                # NOTE(review): this ends the loop normally, so a
                                # possibly partial file is returned as a success.
                                # Confirm this is intended.
                                break
                            else:
                                continue

                        chunk = r2.bytes

                        # https://core.telegram.org/cdn#decrypting-files
                        decrypted_chunk = aes.ctr256_decrypt(
                            chunk, r.encryption_key,
                            bytearray(r.encryption_iv[:-4] +
                                      (offset // 16).to_bytes(4, "big")))

                        hashes = await session.send(
                            raw.functions.upload.GetCdnFileHashes(
                                file_token=r.file_token, offset=offset))

                        # https://core.telegram.org/cdn#verifying-files
                        for i, h in enumerate(hashes):
                            cdn_chunk = decrypted_chunk[h.limit * i:h.limit * (i + 1)]

                            # Explicit check rather than ``assert``: asserts are
                            # stripped under ``python -O`` and unverified CDN data
                            # would then be written to disk.
                            if h.hash != sha256(cdn_chunk).digest():
                                raise ValueError(f"Invalid CDN hash part {i}")

                        f.write(decrypted_chunk)

                        offset += limit

                        if progress:
                            await report_progress(offset)

                        if len(chunk) < limit:
                            break
        except Exception as e:
            if not isinstance(e, pyrogram.StopTransmission):
                log.error(e, exc_info=True)

            # Best effort: remove the partial temporary file (file_name may
            # still be "" if the failure happened before it was created).
            try:
                os.remove(file_name)
            except OSError:
                pass

            return ""
        else:
            return file_name
コード例 #21
0
    async def get_file(
        self,
        file_id: FileId,
        file_size: int = 0,
        limit: int = 0,
        offset: int = 0,
        progress: Callable = None,
        progress_args: tuple = ()
    ) -> Optional[AsyncGenerator[bytes, None]]:
        """Stream a file's content as an async generator of byte chunks.

        Parts are requested in 1 MiB units. A media session for the file's DC
        is reused from ``self.media_sessions`` or created on first use. For a
        DC other than the stored home DC, the current authorization is
        exported and imported into the new session. When the server answers
        with a CDN redirect, parts are fetched from the CDN DC, decrypted with
        AES-256-CTR and checked with ``CDNFileHashMismatch.check`` against
        ``upload.GetCdnFileHashes``.

        Parameters:
            file_id: Decoded file identifier (DC, media id, access hash, ...).
            file_size: Total size in bytes. It is used only to cap the value
                passed to ``progress``; ``0`` means unknown.
            limit: Maximum number of chunks to yield (its absolute value is
                used). ``0`` means ``2**31 - 1``, i.e. no practical limit.
            offset: Number of 1 MiB parts to skip before the first chunk
                (its absolute value is used).
            progress: Optional callback invoked after each chunk with
                ``(current, total, *progress_args)``. Coroutine functions are
                awaited; plain callables run in ``self.executor``.
            progress_args: Extra positional arguments for ``progress``.

        Yields:
            The raw parts, or for CDN-served files the decrypted and verified
            parts. Iteration stops after a part shorter than 1 MiB or once
            ``limit`` parts have been yielded.

        Raises:
            AuthBytesInvalid: if importing the exported authorization into a
                new foreign-DC session failed three times.
            pyrogram.StopTransmission: re-raised unchanged.

        Any other exception raised while downloading is logged and ends the
        generator without propagating.
        """
        dc_id = file_id.dc_id

        async with self.media_sessions_lock:
            session = self.media_sessions.get(dc_id, None)

            if session is None:
                if dc_id != await self.storage.dc_id():
                    # Foreign DC: build a session with a key from
                    # Auth(...).create(), then import the authorization
                    # exported from the current session into it.
                    session = Session(
                        self,
                        dc_id,
                        await Auth(self, dc_id, await self.storage.test_mode()).create(),
                        await self.storage.test_mode(),
                        is_media=True)
                    await session.start()

                    for _ in range(3):
                        exported_auth = await self.invoke(
                            raw.functions.auth.ExportAuthorization(dc_id=dc_id))

                        try:
                            await session.invoke(
                                raw.functions.auth.ImportAuthorization(
                                    id=exported_auth.id,
                                    bytes=exported_auth.bytes))
                        except AuthBytesInvalid:
                            continue
                        else:
                            break
                    else:
                        await session.stop()
                        raise AuthBytesInvalid
                else:
                    # Home DC: reuse the stored auth key.
                    session = Session(
                        self,
                        dc_id,
                        await self.storage.auth_key(),
                        await self.storage.test_mode(),
                        is_media=True)
                    await session.start()

                self.media_sessions[dc_id] = session

        file_type = file_id.file_type

        if file_type == FileType.CHAT_PHOTO:
            if file_id.chat_id > 0:
                peer = raw.types.InputPeerUser(
                    user_id=file_id.chat_id,
                    access_hash=file_id.chat_access_hash)
            elif file_id.chat_access_hash == 0:
                # Negative id without an access hash: a basic group chat.
                peer = raw.types.InputPeerChat(chat_id=-file_id.chat_id)
            else:
                peer = raw.types.InputPeerChannel(
                    channel_id=utils.get_channel_id(file_id.chat_id),
                    access_hash=file_id.chat_access_hash)

            location = raw.types.InputPeerPhotoFileLocation(
                peer=peer,
                photo_id=file_id.media_id,
                big=file_id.thumbnail_source == ThumbnailSource.CHAT_PHOTO_BIG)
        elif file_type == FileType.PHOTO:
            location = raw.types.InputPhotoFileLocation(
                id=file_id.media_id,
                access_hash=file_id.access_hash,
                file_reference=file_id.file_reference,
                thumb_size=file_id.thumbnail_size)
        else:
            location = raw.types.InputDocumentFileLocation(
                id=file_id.media_id,
                access_hash=file_id.access_hash,
                file_reference=file_id.file_reference,
                thumb_size=file_id.thumbnail_size)

        async def report_progress(position: int) -> None:
            # Cap at the known size (``position`` advances by a whole part even
            # when the last part is shorter); report raw bytes if size unknown.
            func = functools.partial(
                progress,
                min(position, file_size) if file_size != 0 else position,
                file_size,
                *progress_args)

            if inspect.iscoroutinefunction(progress):
                await func()
            else:
                await self.loop.run_in_executor(self.executor, func)

        current = 0
        total = abs(limit) or (1 << 31) - 1
        chunk_size = 1024 * 1024
        offset_bytes = abs(offset) * chunk_size

        try:
            r = await session.invoke(
                raw.functions.upload.GetFile(
                    location=location, offset=offset_bytes, limit=chunk_size),
                sleep_threshold=30)

            if isinstance(r, raw.types.upload.File):
                while True:
                    chunk = r.bytes

                    yield chunk

                    current += 1
                    offset_bytes += chunk_size

                    if progress:
                        await report_progress(offset_bytes)

                    # A short part is the last one; also honour ``limit``.
                    if len(chunk) < chunk_size or current >= total:
                        break

                    r = await session.invoke(
                        raw.functions.upload.GetFile(
                            location=location,
                            offset=offset_bytes,
                            limit=chunk_size),
                        sleep_threshold=30)

            elif isinstance(r, raw.types.upload.FileCdnRedirect):
                async with self.media_sessions_lock:
                    cdn_session = self.media_sessions.get(r.dc_id, None)

                    if cdn_session is None:
                        cdn_session = Session(
                            self,
                            r.dc_id,
                            await Auth(self, r.dc_id, await self.storage.test_mode()).create(),
                            await self.storage.test_mode(),
                            is_media=True,
                            is_cdn=True)

                        await cdn_session.start()

                        self.media_sessions[r.dc_id] = cdn_session

                while True:
                    r2 = await cdn_session.invoke(
                        raw.functions.upload.GetCdnFile(
                            file_token=r.file_token,
                            offset=offset_bytes,
                            limit=chunk_size))

                    if isinstance(r2, raw.types.upload.CdnFileReuploadNeeded):
                        # Ask the origin DC to push the file to the CDN again,
                        # then retry the same offset.
                        try:
                            await session.invoke(
                                raw.functions.upload.ReuploadCdnFile(
                                    file_token=r.file_token,
                                    request_token=r2.request_token))
                        except VolumeLocNotFound:
                            # NOTE(review): this ends the stream quietly, possibly
                            # before the whole file was yielded. Confirm intended.
                            break
                        else:
                            continue

                    chunk = r2.bytes

                    # https://core.telegram.org/cdn#decrypting-files
                    decrypted_chunk = aes.ctr256_decrypt(
                        chunk, r.encryption_key,
                        bytearray(r.encryption_iv[:-4] +
                                  (offset_bytes // 16).to_bytes(4, "big")))

                    hashes = await session.invoke(
                        raw.functions.upload.GetCdnFileHashes(
                            file_token=r.file_token, offset=offset_bytes))

                    # https://core.telegram.org/cdn#verifying-files
                    for i, h in enumerate(hashes):
                        cdn_chunk = decrypted_chunk[h.limit * i:h.limit * (i + 1)]
                        CDNFileHashMismatch.check(
                            h.hash == sha256(cdn_chunk).digest())

                    yield decrypted_chunk

                    current += 1
                    offset_bytes += chunk_size

                    if progress:
                        await report_progress(offset_bytes)

                    if len(chunk) < chunk_size or current >= total:
                        break
        except pyrogram.StopTransmission:
            raise
        except Exception as e:
            log.error(e, exc_info=True)
コード例 #22
0
    async def download_media(
        self,
        message: Union["types.Message", str],
        file_ref: str = None,
        file_name: str = DEFAULT_DOWNLOAD_DIR,
        block: bool = True,
        progress: callable = None,
        progress_args: tuple = ()) -> Union[str, None]:
        """Download the media from a message.

        Parameters:
            message (:obj:`~pyrogram.types.Message` | ``str``):
                Pass a Message containing the media, the media itself (message.audio, message.video, ...) or
                the file id as string.

            file_ref (``str``, *optional*):
                A valid file reference obtained by a recently fetched media message.
                To be used in combination with a file id in case a file reference is needed.
                Ignored when a media object is passed: its own file reference is used instead.

            file_name (``str``, *optional*):
                A custom *file_name* to be used instead of the one provided by Telegram.
                By default, all files are downloaded in the *downloads* folder in your working directory.
                You can also specify a path for downloading files in a custom location: paths that end with "/"
                are considered directories. All non-existent folders will be created automatically.
                Only the last path component of a file name provided by Telegram is used.

            block (``bool``, *optional*):
                Blocks the code execution until the file has been downloaded.
                Defaults to True.

            progress (``callable``, *optional*):
                Pass a callback function to view the file transmission progress.
                The function must take *(current, total)* as positional arguments (look at Other Parameters below for a
                detailed description) and will be called back each time a new file chunk has been successfully
                transmitted.

            progress_args (``tuple``, *optional*):
                Extra custom arguments for the progress callback function.
                You can pass anything you need to be available in the progress callback scope; for example, a Message
                object or a Client instance in order to edit the message with the updated progress status.

        Other Parameters:
            current (``int``):
                The amount of bytes transmitted so far.

            total (``int``):
                The total size of the file.

            *args (``tuple``, *optional*):
                Extra custom arguments as defined in the *progress_args* parameter.
                You can either keep *\*args* or add every single extra argument in your function signature.

        Returns:
            ``str`` | ``None``: On success, the absolute path of the downloaded file is returned, otherwise, in case
            the download failed or was deliberately stopped with :meth:`~pyrogram.Client.stop_transmission`, None is
            returned. When *block* is False the download is scheduled in the background and None is returned
            immediately.

        Raises:
            ValueError: if the message doesn't contain any downloadable media, or the file id has an unknown
                media type.
            FileIdInvalid: if the file id cannot be decoded.

        Example:
            .. code-block:: python

                # Download from Message
                app.download_media(message)

                # Download from file id
                app.download_media("CAADBAADzg4AAvLQYAEz_x2EOgdRwBYE")

                # Keep track of the progress while downloading
                def progress(current, total):
                    print(f"{current * 100 / total:.1f}%")

                app.download_media(message, progress=progress)
        """
        error_message = "This message doesn't contain any downloadable media"
        available_media = ("audio", "document", "photo", "sticker",
                           "animation", "video", "voice", "video_note")

        media_file_name = None
        file_size = None
        mime_type = None
        date = None

        if isinstance(message, types.Message):
            # Use the first media attribute that is set on the message.
            for kind in available_media:
                media = getattr(message, kind, None)

                if media is not None:
                    break
            else:
                raise ValueError(error_message)
        else:
            media = message

        if isinstance(media, str):
            file_id_str = media
        else:
            file_id_str = media.file_id
            media_file_name = getattr(media, "file_name", "")
            file_size = getattr(media, "file_size", None)
            mime_type = getattr(media, "mime_type", None)
            date = getattr(media, "date", None)
            # The media object's own reference overrides the file_ref argument.
            file_ref = getattr(media, "file_ref", None)

        data = FileData(file_name=media_file_name,
                        file_size=file_size,
                        mime_type=mime_type,
                        date=date,
                        file_ref=file_ref)

        def get_existing_attributes() -> dict:
            # Carry over only the metadata fields that were actually filled in.
            return dict(
                filter(lambda x: x[1] is not None, data.__dict__.items()))

        try:
            decoded = utils.decode_file_id(file_id_str)
            media_type = decoded[0]

            if media_type == 1:
                # Chat photo: the layout carries the owning peer.
                unpacked = struct.unpack("<iiqqqiiiqi", decoded)
                dc_id, _, _, volume_id, size_type, peer_id, x, peer_access_hash, local_id = unpacked[1:]

                if x == 0:
                    peer_type = "user"
                elif x == -1:
                    peer_id = -peer_id
                    peer_type = "chat"
                else:
                    peer_id = utils.get_channel_id(peer_id - 1000727379968)
                    peer_type = "channel"

                data = FileData(**get_existing_attributes(),
                                media_type=media_type,
                                dc_id=dc_id,
                                peer_id=peer_id,
                                peer_type=peer_type,
                                peer_access_hash=peer_access_hash,
                                volume_id=volume_id,
                                local_id=local_id,
                                is_big=size_type == 3)
            elif media_type in (0, 2, 14):
                unpacked = struct.unpack("<iiqqqiiii", decoded)
                dc_id, document_id, access_hash, volume_id, _, _, thumb_size, local_id = unpacked[1:]

                data = FileData(**get_existing_attributes(),
                                media_type=media_type,
                                dc_id=dc_id,
                                document_id=document_id,
                                access_hash=access_hash,
                                thumb_size=chr(thumb_size))
            elif media_type in (3, 4, 5, 8, 9, 10, 13):
                unpacked = struct.unpack("<iiqq", decoded)
                dc_id, document_id, access_hash = unpacked[1:]

                data = FileData(**get_existing_attributes(),
                                media_type=media_type,
                                dc_id=dc_id,
                                document_id=document_id,
                                access_hash=access_hash)
            else:
                raise ValueError(f"Unknown media type: {file_id_str}")
        except (AssertionError, binascii.Error, struct.error):
            raise FileIdInvalid from None

        directory, file_name = os.path.split(file_name)
        # The remote file name comes from the sender. Keep only its last
        # component so it can never point outside the download directory
        # (e.g. "/etc/x" or "../../x").
        file_name = file_name or os.path.basename(data.file_name or "")

        # Relative directories are resolved against PARENT_DIR. Assumes
        # PARENT_DIR is a pathlib.Path, where joining an absolute directory
        # discards the left side, so absolute directories are kept as given.
        directory = self.PARENT_DIR / (directory or DEFAULT_DOWNLOAD_DIR)

        media_type_str = self.MEDIA_TYPE_ID[data.media_type]

        if not file_name:
            guessed_extension = self.guess_extension(data.mime_type)

            if data.media_type in (0, 1, 2, 14):
                extension = ".jpg"
            elif data.media_type == 3:
                extension = guessed_extension or ".ogg"
            elif data.media_type in (4, 10, 13):
                extension = guessed_extension or ".mp4"
            elif data.media_type == 5:
                extension = guessed_extension or ".zip"
            elif data.media_type == 8:
                extension = guessed_extension or ".webp"
            elif data.media_type == 9:
                extension = guessed_extension or ".mp3"
            else:
                extension = ".unknown"

            file_name = "{}_{}_{}{}".format(
                media_type_str,
                datetime.fromtimestamp(
                    data.date or time.time()).strftime("%Y-%m-%d_%H-%M-%S"),
                self.rnd_id(), extension)

        downloader = self.handle_download(
            (data, directory, file_name, progress, progress_args))

        if block:
            return await downloader
        else:
            # NOTE(review): the task is not referenced anywhere. asyncio only
            # keeps weak references to tasks, so consider storing it.
            asyncio.get_event_loop().create_task(downloader)