예제 #1
0
파일: auth.py 프로젝트: bsharp1001/telefy
class Auth:
    """Create a new MTProto authorization key via the DH key exchange.

    The exchange follows https://core.telegram.org/mtproto/auth_key. Every
    attempt opens a fresh ``Connection`` to the target data center and closes
    it when the attempt ends, whether it succeeded or not.
    """

    # Number of extra attempts made after the first failed one.
    MAX_RETRIES = 5

    def __init__(self, client: "pyrogram.Client", dc_id: int):
        """
        Args:
            client: Owning client; its storage test-mode flag, ``ipv6`` and
                ``proxy`` settings are copied for the connections opened here.
            dc_id: Id of the data center the auth key is created on.
        """
        self.dc_id = dc_id
        self.test_mode = client.storage.test_mode()
        self.ipv6 = client.ipv6
        self.proxy = client.proxy

        # Created anew for every attempt inside ``create``.
        self.connection = None

    @staticmethod
    def pack(data: "TLObject") -> bytes:
        """Wrap ``data`` in an unencrypted MTProto message.

        Layout: auth_key_id (8 zero bytes), message_id (8), message_length (4),
        then the serialized body.
        """
        body = data.write()  # serialize once; used for both length and payload
        return bytes(8) + Long(MsgId()) + Int(len(body)) + body

    @staticmethod
    def unpack(b: BytesIO):
        """Parse the TL object carried by an unencrypted MTProto message."""
        # Skip auth_key_id (8), message_id (8) and message_length (4)
        b.seek(20)
        return TLObject.read(b)

    def send(self, data: "TLObject"):
        """Send ``data`` unencrypted and return the parsed response object."""
        data = self.pack(data)
        self.connection.send(data)
        response = BytesIO(self.connection.recv())

        return self.unpack(response)

    @staticmethod
    def _security_check(condition: bool, description: str):
        """Raise ``AssertionError`` unless ``condition`` holds.

        Used instead of ``assert`` so the checks still run under ``python -O``
        (which strips assert statements). ``AssertionError`` is kept so the
        exception type seen by callers is the same as before.
        """
        if not condition:
            raise AssertionError(
                "Security check failed: {}".format(description))

    def create(self):
        """Run the auth key exchange and return the new 256-byte auth key.

        https://core.telegram.org/mtproto/auth_key
        https://core.telegram.org/mtproto/samples-auth_key

        Any exception raised during an attempt (network errors, failed
        security checks, ...) triggers a retry after one second, for up to
        ``MAX_RETRIES`` extra attempts; after that the exception is re-raised.
        """
        retries_left = self.MAX_RETRIES

        # The server may close the connection at any time, causing the auth key creation to fail.
        # If that happens, just try again up to MAX_RETRIES times.
        while True:
            # Fresh connection per attempt; closed in ``finally`` below.
            self.connection = Connection(self.dc_id, self.test_mode, self.ipv6,
                                         self.proxy)

            try:
                log.info("Start creating a new auth key on DC{}".format(
                    self.dc_id))

                self.connection.connect()

                # Step 1; Step 2
                nonce = int.from_bytes(urandom(16), "little", signed=True)
                log.debug("Send req_pq: {}".format(nonce))
                res_pq = self.send(functions.ReqPqMulti(nonce=nonce))
                log.debug("Got ResPq: {}".format(res_pq.server_nonce))
                log.debug("Server public key fingerprints: {}".format(
                    res_pq.server_public_key_fingerprints))

                # Use the first server fingerprint we hold an RSA key for.
                for i in res_pq.server_public_key_fingerprints:
                    if i in RSA.server_public_keys:
                        log.debug("Using fingerprint: {}".format(i))
                        public_key_fingerprint = i
                        break
                    else:
                        log.debug("Fingerprint unknown: {}".format(i))
                else:
                    raise Exception("Public key not found")

                # Step 3: factorize pq into p < q
                pq = int.from_bytes(res_pq.pq, "big")
                log.debug("Start PQ factorization: {}".format(pq))
                start = time.time()
                g = Prime.decompose(pq)
                p, q = sorted((g, pq // g))  # p < q
                log.debug("Done PQ factorization ({}s): {} {}".format(
                    round(time.time() - start, 3), p, q))

                # Step 4
                server_nonce = res_pq.server_nonce
                new_nonce = int.from_bytes(urandom(32), "little", signed=True)

                data = types.PQInnerData(
                    pq=res_pq.pq,
                    p=p.to_bytes(4, "big"),
                    q=q.to_bytes(4, "big"),
                    nonce=nonce,
                    server_nonce=server_nonce,
                    new_nonce=new_nonce,
                ).write()

                # SHA1(data) + data + random padding so the total length is a
                # multiple of 255, then RSA-encrypt with the chosen server key.
                sha = sha1(data).digest()
                padding = urandom(-(len(data) + len(sha)) % 255)
                data_with_hash = sha + data + padding
                encrypted_data = RSA.encrypt(data_with_hash,
                                             public_key_fingerprint)

                log.debug("Done encrypt data with RSA")

                # Step 5. TODO: Handle "server_DH_params_fail". Code assumes response is ok
                log.debug("Send req_DH_params")
                server_dh_params = self.send(
                    functions.ReqDHParams(
                        nonce=nonce,
                        server_nonce=server_nonce,
                        p=p.to_bytes(4, "big"),
                        q=q.to_bytes(4, "big"),
                        public_key_fingerprint=public_key_fingerprint,
                        encrypted_data=encrypted_data))

                encrypted_answer = server_dh_params.encrypted_answer

                # The temporary AES key/IV are derived from the byte forms of
                # the nonces; server_nonce is converted back to int afterwards.
                server_nonce = server_nonce.to_bytes(16, "little", signed=True)
                new_nonce = new_nonce.to_bytes(32, "little", signed=True)

                tmp_aes_key = (sha1(new_nonce + server_nonce).digest() +
                               sha1(server_nonce + new_nonce).digest()[:12])

                tmp_aes_iv = (sha1(server_nonce + new_nonce).digest()[12:] +
                              sha1(new_nonce + new_nonce).digest() +
                              new_nonce[:4])

                server_nonce = int.from_bytes(server_nonce,
                                              "little",
                                              signed=True)

                # Decrypted answer = SHA1(answer) (20 bytes) + answer + padding
                answer_with_hash = AES.ige256_decrypt(encrypted_answer,
                                                      tmp_aes_key, tmp_aes_iv)
                answer = answer_with_hash[20:]

                server_dh_inner_data = TLObject.read(BytesIO(answer))

                log.debug("Done decrypting answer")

                dh_prime = int.from_bytes(server_dh_inner_data.dh_prime, "big")
                delta_time = server_dh_inner_data.server_time - time.time()

                log.debug("Delta time: {}".format(round(delta_time, 3)))

                # Step 6: random 256-byte exponent b; send g^b mod dh_prime
                g = server_dh_inner_data.g
                b = int.from_bytes(urandom(256), "big")
                g_b = pow(g, b, dh_prime).to_bytes(256, "big")

                retry_id = 0

                data = types.ClientDHInnerData(nonce=nonce,
                                               server_nonce=server_nonce,
                                               retry_id=retry_id,
                                               g_b=g_b).write()

                sha = sha1(data).digest()
                padding = urandom(-(len(data) + len(sha)) % 16)
                data_with_hash = sha + data + padding
                encrypted_data = AES.ige256_encrypt(data_with_hash,
                                                    tmp_aes_key, tmp_aes_iv)

                log.debug("Send set_client_DH_params")
                set_client_dh_params_answer = self.send(
                    functions.SetClientDHParams(nonce=nonce,
                                                server_nonce=server_nonce,
                                                encrypted_data=encrypted_data))

                # TODO: Handle "auth_key_aux_hash" if the previous step fails

                # Step 7; Step 8: auth_key = g_a^b mod dh_prime (256 bytes)
                g_a = int.from_bytes(server_dh_inner_data.g_a, "big")
                auth_key = pow(g_a, b, dh_prime).to_bytes(256, "big")
                server_nonce = server_nonce.to_bytes(16, "little", signed=True)

                # TODO: Handle errors

                #######################
                # Security checks
                #######################
                # Explicit checks rather than ``assert`` so they are not stripped
                # under ``python -O``. A failure raises AssertionError, which the
                # handler below treats like any other failed attempt.
                # NOTE(review): these run after g_b has already been sent; the
                # security guidelines suggest validating server values before
                # using them — consider moving the dh_prime/g_a/SHA1 checks up.
                check = self._security_check

                check(dh_prime == Prime.CURRENT_DH_PRIME,
                      "dh_prime == Prime.CURRENT_DH_PRIME")
                log.debug("DH parameters check: OK")

                # https://core.telegram.org/mtproto/security_guidelines#g-a-and-g-b-validation
                g_b = int.from_bytes(g_b, "big")
                check(1 < g < dh_prime - 1, "1 < g < dh_prime - 1")
                check(1 < g_a < dh_prime - 1, "1 < g_a < dh_prime - 1")
                check(1 < g_b < dh_prime - 1, "1 < g_b < dh_prime - 1")
                check(2**(2048 - 64) < g_a < dh_prime - 2**(2048 - 64),
                      "2^(2048-64) < g_a < dh_prime - 2^(2048-64)")
                check(2**(2048 - 64) < g_b < dh_prime - 2**(2048 - 64),
                      "2^(2048-64) < g_b < dh_prime - 2^(2048-64)")
                log.debug("g_a and g_b validation: OK")

                # https://core.telegram.org/mtproto/security_guidelines#checking-sha1-hash-values
                answer = server_dh_inner_data.write()  # Call .write() to remove padding
                check(answer_with_hash[:20] == sha1(answer).digest(),
                      "SHA1 hash of server_DH_inner_data")
                log.debug("SHA1 hash values check: OK")

                # https://core.telegram.org/mtproto/security_guidelines#checking-nonce-server-nonce-and-new-nonce-fields
                # 1st message
                check(nonce == res_pq.nonce, "res_pq.nonce")
                # 2nd message
                server_nonce = int.from_bytes(server_nonce,
                                              "little",
                                              signed=True)
                check(nonce == server_dh_params.nonce,
                      "server_dh_params.nonce")
                check(server_nonce == server_dh_params.server_nonce,
                      "server_dh_params.server_nonce")
                # 3rd message
                check(nonce == set_client_dh_params_answer.nonce,
                      "set_client_dh_params_answer.nonce")
                check(server_nonce == set_client_dh_params_answer.server_nonce,
                      "set_client_dh_params_answer.server_nonce")
                server_nonce = server_nonce.to_bytes(16, "little", signed=True)
                log.debug("Nonce fields check: OK")

                # Step 9
                server_salt = AES.xor(new_nonce[:8], server_nonce[:8])

                log.debug("Server salt: {}".format(
                    int.from_bytes(server_salt, "little")))

                log.info("Done auth key exchange: {}".format(
                    set_client_dh_params_answer.__class__.__name__))
            except Exception as e:
                if retries_left:
                    retries_left -= 1
                else:
                    raise

                log.warning("Auth key creation failed. Let's try again: {}".format(repr(e)))
                time.sleep(1)
                continue
            else:
                return auth_key
            finally:
                self.connection.close()
예제 #2
0
파일: auth.py 프로젝트: zolemar/pyrogram
class Auth:
    """Create a new MTProto authorization key via the DH key exchange.

    The exchange follows https://core.telegram.org/mtproto/auth_key.
    """

    # Number of extra attempts made after the first failed one.
    MAX_RETRIES = 5

    # The 2048-bit DH prime the server is expected to send; compared against
    # server_DH_inner_data.dh_prime during the security checks in ``create``.
    CURRENT_DH_PRIME = int(
        "C71CAEB9C6B1C9048E6C522F70F13F73980D40238E3E21C14934D037563D930F"
        "48198A0AA7C14058229493D22530F4DBFA336F6E0AC925139543AED44CCE7C37"
        "20FD51F69458705AC68CD4FE6B6B13ABDC9746512969328454F18FAF8C595F64"
        "2477FE96BB2A941D5BCD1D4AC8CC49880708FA9B378E3C4F3A9060BEE67CF9A4"
        "A4A695811051907E162753B56B0F6B410DBA74D8A84B2A14B3144E0EF1284754"
        "FD17ED950D5965B4B9DD46582DB1178D169C6BC465B0D6FF9CA3928FEF5B9AE4"
        "E418FC15E83EBEA0F87FA9FF5EED70050DED2849F47BF959D956850CE929851F"
        "0D8115F635B105EE2E4E15D04B2454BF6F4FADF034B10403119CD8E3B92FCC5B",
        16
    )

    def __init__(self, dc_id: int, test_mode: bool, proxy: type):
        """
        Args:
            dc_id: Id of the data center the auth key is created on.
            test_mode: Whether the test data centers are used.
            proxy: Proxy settings forwarded to ``Connection``.
        """
        self.dc_id = dc_id
        self.test_mode = test_mode

        # One connection object reused across attempts: ``create`` connects at
        # the start of each attempt and closes it in ``finally``.
        self.connection = Connection(DataCenter(dc_id, test_mode), proxy)

    @staticmethod
    def pack(data: "Object") -> bytes:
        """Wrap ``data`` in an unencrypted MTProto message.

        Layout: auth_key_id (8 zero bytes), message_id (8), message_length (4),
        then the serialized body.
        """
        body = data.write()  # serialize once; used for both length and payload
        return (
            bytes(8)
            + Long(MsgId())
            + Int(len(body))
            + body
        )

    @staticmethod
    def unpack(b: BytesIO):
        """Parse the TL object carried by an unencrypted MTProto message."""
        b.seek(20)  # Skip auth_key_id (8), message_id (8) and message_length (4)
        return Object.read(b)

    def send(self, data: "Object"):
        """Send ``data`` unencrypted and return the parsed response object."""
        data = self.pack(data)
        self.connection.send(data)
        response = BytesIO(self.connection.recv())

        return self.unpack(response)

    @staticmethod
    def _security_check(condition: bool, description: str):
        """Raise ``AssertionError`` unless ``condition`` holds.

        Used instead of ``assert`` so the checks still run under ``python -O``
        (which strips assert statements). ``AssertionError`` is kept so the
        exception type seen by callers is the same as before.
        """
        if not condition:
            raise AssertionError(
                "Security check failed: {}".format(description))

    def create(self):
        """Run the auth key exchange and return the new 256-byte auth key.

        https://core.telegram.org/mtproto/auth_key
        https://core.telegram.org/mtproto/samples-auth_key

        Any exception raised during an attempt (network errors, failed
        security checks, ...) triggers a retry after one second, for up to
        ``MAX_RETRIES`` extra attempts; after that the exception is re-raised.
        """
        retries_left = self.MAX_RETRIES

        # The server may close the connection at any time, causing the auth key creation to fail.
        # If that happens, just try again up to MAX_RETRIES times.
        while True:
            try:
                log.info("Start creating a new auth key on DC{}".format(self.dc_id))

                self.connection.connect()

                # Step 1; Step 2
                nonce = int.from_bytes(urandom(16), "little", signed=True)
                log.debug("Send req_pq: {}".format(nonce))
                res_pq = self.send(functions.ReqPqMulti(nonce))
                log.debug("Got ResPq: {}".format(res_pq.server_nonce))
                log.debug("Server public key fingerprints: {}".format(res_pq.server_public_key_fingerprints))

                # Use the first server fingerprint we hold an RSA key for.
                for i in res_pq.server_public_key_fingerprints:
                    if i in RSA.server_public_keys:
                        log.debug("Using fingerprint: {}".format(i))
                        public_key_fingerprint = i
                        break
                    else:
                        log.debug("Fingerprint unknown: {}".format(i))
                else:
                    raise Exception("Public key not found")

                # Step 3: factorize pq into p < q
                pq = int.from_bytes(res_pq.pq, "big")
                log.debug("Start PQ factorization: {}".format(pq))
                start = time.time()
                g = Prime.decompose(pq)
                p, q = sorted((g, pq // g))  # p < q
                log.debug("Done PQ factorization ({}s): {} {}".format(round(time.time() - start, 3), p, q))

                # Step 4
                server_nonce = res_pq.server_nonce
                new_nonce = int.from_bytes(urandom(32), "little", signed=True)

                data = types.PQInnerData(
                    res_pq.pq,
                    int.to_bytes(p, 4, "big"),
                    int.to_bytes(q, 4, "big"),
                    nonce,
                    server_nonce,
                    new_nonce,
                ).write()

                # SHA1(data) + data + random padding so the total length is a
                # multiple of 255, then RSA-encrypt with the chosen server key.
                sha = sha1(data).digest()
                padding = urandom(- (len(data) + len(sha)) % 255)
                data_with_hash = sha + data + padding
                encrypted_data = RSA.encrypt(data_with_hash, public_key_fingerprint)

                log.debug("Done encrypt data with RSA")

                # Step 5. TODO: Handle "server_DH_params_fail". Code assumes response is ok
                log.debug("Send req_DH_params")
                server_dh_params = self.send(
                    functions.ReqDHParams(
                        nonce,
                        server_nonce,
                        int.to_bytes(p, 4, "big"),
                        int.to_bytes(q, 4, "big"),
                        public_key_fingerprint,
                        encrypted_data
                    )
                )

                encrypted_answer = server_dh_params.encrypted_answer

                # The temporary AES key/IV are derived from the byte forms of
                # the nonces; server_nonce is converted back to int afterwards.
                server_nonce = int.to_bytes(server_nonce, 16, "little", signed=True)
                new_nonce = int.to_bytes(new_nonce, 32, "little", signed=True)

                tmp_aes_key = (
                    sha1(new_nonce + server_nonce).digest()
                    + sha1(server_nonce + new_nonce).digest()[:12]
                )

                tmp_aes_iv = (
                    sha1(server_nonce + new_nonce).digest()[12:]
                    + sha1(new_nonce + new_nonce).digest() + new_nonce[:4]
                )

                server_nonce = int.from_bytes(server_nonce, "little", signed=True)

                # Decrypted answer = SHA1(answer) (20 bytes) + answer + padding
                answer_with_hash = AES.ige_decrypt(encrypted_answer, tmp_aes_key, tmp_aes_iv)
                answer = answer_with_hash[20:]

                server_dh_inner_data = Object.read(BytesIO(answer))

                log.debug("Done decrypting answer")

                dh_prime = int.from_bytes(server_dh_inner_data.dh_prime, "big")
                delta_time = server_dh_inner_data.server_time - time.time()

                log.debug("Delta time: {}".format(round(delta_time, 3)))

                # Step 6: random 256-byte exponent b; send g^b mod dh_prime
                g = server_dh_inner_data.g
                b = int.from_bytes(urandom(256), "big")
                g_b = int.to_bytes(pow(g, b, dh_prime), 256, "big")

                retry_id = 0

                data = types.ClientDHInnerData(
                    nonce,
                    server_nonce,
                    retry_id,
                    g_b
                ).write()

                sha = sha1(data).digest()
                padding = urandom(- (len(data) + len(sha)) % 16)
                data_with_hash = sha + data + padding
                encrypted_data = AES.ige_encrypt(data_with_hash, tmp_aes_key, tmp_aes_iv)

                log.debug("Send set_client_DH_params")
                set_client_dh_params_answer = self.send(
                    functions.SetClientDHParams(
                        nonce,
                        server_nonce,
                        encrypted_data
                    )
                )

                # TODO: Handle "auth_key_aux_hash" if the previous step fails

                # Step 7; Step 8: auth_key = g_a^b mod dh_prime (256 bytes)
                g_a = int.from_bytes(server_dh_inner_data.g_a, "big")
                auth_key = int.to_bytes(pow(g_a, b, dh_prime), 256, "big")
                server_nonce = int.to_bytes(server_nonce, 16, "little", signed=True)

                # TODO: Handle errors

                #######################
                # Security checks
                #######################
                # Explicit checks rather than ``assert`` so they are not stripped
                # under ``python -O``. A failure raises AssertionError, which the
                # handler below treats like any other failed attempt.
                # NOTE(review): these run after g_b has already been sent; the
                # security guidelines suggest validating server values before
                # using them — consider moving the dh_prime/g_a/SHA1 checks up.
                check = self._security_check

                check(dh_prime == self.CURRENT_DH_PRIME,
                      "dh_prime == CURRENT_DH_PRIME")
                log.debug("DH parameters check: OK")

                # https://core.telegram.org/mtproto/security_guidelines#g-a-and-g-b-validation
                g_b = int.from_bytes(g_b, "big")
                check(1 < g < dh_prime - 1, "1 < g < dh_prime - 1")
                check(1 < g_a < dh_prime - 1, "1 < g_a < dh_prime - 1")
                check(1 < g_b < dh_prime - 1, "1 < g_b < dh_prime - 1")
                check(2 ** (2048 - 64) < g_a < dh_prime - 2 ** (2048 - 64),
                      "2^(2048-64) < g_a < dh_prime - 2^(2048-64)")
                check(2 ** (2048 - 64) < g_b < dh_prime - 2 ** (2048 - 64),
                      "2^(2048-64) < g_b < dh_prime - 2^(2048-64)")
                log.debug("g_a and g_b validation: OK")

                # https://core.telegram.org/mtproto/security_guidelines#checking-sha1-hash-values
                answer = server_dh_inner_data.write()  # Call .write() to remove padding
                check(answer_with_hash[:20] == sha1(answer).digest(),
                      "SHA1 hash of server_DH_inner_data")
                log.debug("SHA1 hash values check: OK")

                # https://core.telegram.org/mtproto/security_guidelines#checking-nonce-server-nonce-and-new-nonce-fields
                # 1st message
                check(nonce == res_pq.nonce, "res_pq.nonce")
                # 2nd message
                server_nonce = int.from_bytes(server_nonce, "little", signed=True)
                check(nonce == server_dh_params.nonce,
                      "server_dh_params.nonce")
                check(server_nonce == server_dh_params.server_nonce,
                      "server_dh_params.server_nonce")
                # 3rd message
                check(nonce == set_client_dh_params_answer.nonce,
                      "set_client_dh_params_answer.nonce")
                check(server_nonce == set_client_dh_params_answer.server_nonce,
                      "set_client_dh_params_answer.server_nonce")
                server_nonce = int.to_bytes(server_nonce, 16, "little", signed=True)
                log.debug("Nonce fields check: OK")

                # Step 9
                server_salt = AES.xor(new_nonce[:8], server_nonce[:8])

                log.debug("Server salt: {}".format(int.from_bytes(server_salt, "little")))

                log.info(
                    "Done auth key exchange: {}".format(
                        set_client_dh_params_answer.__class__.__name__
                    )
                )
            except Exception as e:
                if retries_left:
                    retries_left -= 1
                else:
                    raise

                log.warning("Auth key creation failed. Let's try again: {}".format(repr(e)))
                time.sleep(1)
                continue
            else:
                return auth_key
            finally:
                self.connection.close()
예제 #3
0
class Session:
    """Encrypted MTProto session bound to one data center and auth key.

    ``start`` connects and spawns the worker, receive, salt-refresh and ping
    threads; ``send`` invokes a method and waits for its result.
    """

    VERSION = __version__
    APP_VERSION = "Pyrogram \U0001f525 {}".format(VERSION)

    DEVICE_MODEL = "{} {}".format(platform.python_implementation(),
                                  platform.python_version())

    SYSTEM_VERSION = "{} {}".format(platform.system(), platform.release())

    # Salt used for the first Ping, before a real salt has been obtained.
    INITIAL_SALT = 0x616e67656c696361

    WORKERS = 4  # threads unpacking/dispatching incoming packets
    WAIT_TIMEOUT = 10  # seconds _send waits for a response
    MAX_RETRIES = 5  # attempts made by send()
    ACKS_THRESHOLD = 8  # pending acks that trigger a MsgsAck
    PING_INTERVAL = 5  # seconds between pings

    # Ensures the version/license banner is printed only once per process.
    notice_displayed = False

    def __init__(self,
                 dc_id: int,
                 test_mode: bool,
                 proxy: type,
                 auth_key: bytes,
                 api_id: str,
                 is_cdn: bool = False):
        """
        Args:
            dc_id: Data center to connect to.
            test_mode: Whether the test data centers are used.
            proxy: Proxy settings forwarded to ``Connection``.
            auth_key: Auth key used to encrypt/decrypt messages.
            api_id: Application id sent in ``InitConnection``.
            is_cdn: If true, ``start`` skips ``InitConnection``.
        """
        if not Session.notice_displayed:
            print("Pyrogram v{}, {}".format(__version__, __copyright__))
            print("Licensed under the terms of the " + __license__, end="\n\n")
            Session.notice_displayed = True

        self.is_cdn = is_cdn

        self.connection = Connection(DataCenter(dc_id, test_mode), proxy)

        self.api_id = api_id

        self.auth_key = auth_key
        # Last 8 bytes of SHA1(auth_key); prefixes every outgoing packet.
        self.auth_key_id = sha1(auth_key).digest()[-8:]

        self.msg_id = MsgId()
        self.session_id = Long(self.msg_id())
        self.msg_factory = MsgFactory(self.msg_id)

        self.current_salt = None

        # msg_ids still to be acknowledged to the server.
        self.pending_acks = set()

        self.recv_queue = Queue()
        # msg_id -> Result for requests awaiting a response.
        self.results = {}

        self.ping_thread = None
        self.ping_thread_event = Event()

        self.next_salt_thread = None
        self.next_salt_thread_event = Event()

        self.is_connected = Event()

        self.update_handler = None

        self.total_connections = 0
        self.total_messages = 0
        self.total_bytes = 0

    def start(self):
        """Connect and initialize the session, retrying until it succeeds.

        Returns:
            The terms-of-service text obtained during ``InitConnection``, or
            ``None`` for CDN sessions.
        """
        terms = None

        while True:
            try:
                self.connection.connect()

                for i in range(self.WORKERS):
                    Thread(target=self.worker,
                           name="Worker#{}".format(i + 1)).start()

                Thread(target=self.recv, name="RecvThread").start()

                # Bootstrap the salt: a Ping with the initial salt returns the
                # server salt, which is then used to request a proper one.
                self.current_salt = FutureSalt(0, 0, self.INITIAL_SALT)
                self.current_salt = FutureSalt(
                    0, 0,
                    self._send(functions.Ping(0)).new_server_salt)
                self.current_salt = self._send(
                    functions.GetFutureSalts(1)).salts[0]

                self.next_salt_thread = Thread(target=self.next_salt,
                                               name="NextSaltThread")
                self.next_salt_thread.start()

                if not self.is_cdn:
                    terms = self._send(
                        functions.InvokeWithLayer(
                            layer,
                            functions.InitConnection(
                                self.api_id,
                                self.DEVICE_MODEL,
                                self.SYSTEM_VERSION,
                                self.APP_VERSION,
                                "en",
                                "",
                                "en",
                                functions.help.GetTermsOfService(),
                            ))).text

                self.ping_thread = Thread(target=self.ping, name="PingThread")
                self.ping_thread.start()

                log.info("Connection inited: Layer {}".format(layer))
            except (OSError, TimeoutError):
                self.stop()
            else:
                break

        self.is_connected.set()
        self.total_connections += 1

        log.debug("Session started")

        return terms

    def stop(self):
        """Stop the background threads and close the connection."""
        self.is_connected.clear()

        self.ping_thread_event.set()
        self.next_salt_thread_event.set()

        if self.ping_thread is not None:
            self.ping_thread.join()

        if self.next_salt_thread is not None:
            self.next_salt_thread.join()

        self.ping_thread_event.clear()
        self.next_salt_thread_event.clear()

        self.connection.close()

        # One ``None`` sentinel per worker makes each worker loop exit.
        for _ in range(self.WORKERS):
            self.recv_queue.put(None)

        log.debug("Session stopped")

    def restart(self):
        """Stop the session and start it again."""
        self.stop()
        self.start()

    def pack(self, message: Message):
        """Encrypt ``message`` into an outgoing MTProto 2.0 packet."""
        data = Long(self.current_salt.salt) + self.session_id + message.write()
        # MTProto 2.0 requires at least 12 padding bytes (the spec allows up to
        # 1024); this uses the minimum 12 plus the 0..15 extra bytes needed to
        # make the total a multiple of 16 for AES-IGE.
        padding = urandom(-(len(data) + 12) % 16 + 12)

        # 88 = 88 + 0 (outgoing message)
        msg_key_large = sha256(self.auth_key[88:88 + 32] + data +
                               padding).digest()
        msg_key = msg_key_large[8:24]
        aes_key, aes_iv = KDF(self.auth_key, msg_key, True)

        return self.auth_key_id + msg_key + IGE.encrypt(
            data + padding, aes_key, aes_iv)

    def unpack(self, b: BytesIO) -> Message:
        """Decrypt and verify an incoming packet, returning its message.

        The checks raise ``AssertionError`` explicitly instead of using
        ``assert``: two of them consume bytes from the stream, so stripping
        them under ``python -O`` would also misalign every later read.

        Raises:
            AssertionError: if any of the security checks fails.
        """
        if b.read(8) != self.auth_key_id:
            raise AssertionError(b.getvalue())

        msg_key = b.read(16)
        aes_key, aes_iv = KDF(self.auth_key, msg_key, False)
        data = BytesIO(IGE.decrypt(b.read(), aes_key, aes_iv))
        data.read(8)  # skip the 8-byte salt (same layout as built in pack)

        # https://core.telegram.org/mtproto/security_guidelines#checking-session-id
        if data.read(8) != self.session_id:
            raise AssertionError("session_id mismatch")

        message = Message.read(data)

        # https://core.telegram.org/mtproto/security_guidelines#checking-sha256-hash-value-of-msg-key
        # https://core.telegram.org/mtproto/security_guidelines#checking-message-length
        # 96 = 88 + 8 (incoming message)
        expected_msg_key = sha256(self.auth_key[96:96 + 32] +
                                  data.getvalue()).digest()[8:24]
        if msg_key != expected_msg_key:
            raise AssertionError("msg_key mismatch")

        # https://core.telegram.org/mtproto/security_guidelines#checking-msg-id
        # TODO: check for lower msg_ids
        if message.msg_id % 2 == 0:
            raise AssertionError("server msg_id must be odd")

        return message

    def worker(self):
        """Unpack and dispatch queued packets until a ``None`` sentinel."""
        name = threading.current_thread().name
        log.debug("{} started".format(name))

        while True:
            packet = self.recv_queue.get()

            if packet is None:
                break

            try:
                self.unpack_dispatch_and_ack(packet)
            except Exception as e:
                log.error(e, exc_info=True)

        log.debug("{} stopped".format(name))

    def unpack_dispatch_and_ack(self, packet: bytes):
        """Route each message in ``packet`` and send acks once enough pile up."""
        # TODO: A better dispatcher
        data = self.unpack(BytesIO(packet))

        messages = (data.body.messages
                    if isinstance(data.body, MsgContainer) else [data])

        log.debug(data)

        self.total_bytes += len(packet)
        self.total_messages += len(messages)

        for i in messages:
            # Messages with an odd seq_no need an ack; an already-pending
            # msg_id is treated as a duplicate and skipped.
            if i.seq_no % 2 != 0:
                if i.msg_id in self.pending_acks:
                    continue
                else:
                    self.pending_acks.add(i.msg_id)

            if isinstance(i.body,
                          (types.MsgDetailedInfo, types.MsgNewDetailedInfo)):
                self.pending_acks.add(i.body.answer_msg_id)
                continue

            if isinstance(i.body, types.NewSessionCreated):
                continue

            msg_id = None

            if isinstance(i.body,
                          (types.BadMsgNotification, types.BadServerSalt)):
                msg_id = i.body.bad_msg_id
            elif isinstance(i.body, (core.FutureSalts, types.RpcResult)):
                msg_id = i.body.req_msg_id
            elif isinstance(i.body, types.Pong):
                msg_id = i.body.msg_id
            else:
                if self.update_handler:
                    self.update_handler(i.body)

            # Single lookup: _send may pop the entry concurrently after its
            # wait times out, so a separate ``in`` check + index could KeyError.
            result = self.results.get(msg_id)

            if result is not None:
                result.value = getattr(i.body, "result", i.body)
                result.event.set()

        if len(self.pending_acks) >= self.ACKS_THRESHOLD:
            log.info("Send {} acks".format(len(self.pending_acks)))

            try:
                self._send(types.MsgsAck(list(self.pending_acks)), False)
            except (OSError, TimeoutError):
                pass  # best effort: acks stay pending and are retried later
            else:
                self.pending_acks.clear()

    def ping(self):
        """Send a Ping every ``PING_INTERVAL`` seconds until stopped."""
        log.debug("PingThread started")

        while True:
            self.ping_thread_event.wait(self.PING_INTERVAL)

            if self.ping_thread_event.is_set():
                break

            try:
                self._send(functions.Ping(0), False)
            except (OSError, TimeoutError):
                pass  # best effort: the next interval tries again

        log.debug("PingThread stopped")

    def next_salt(self):
        """Refresh ``current_salt`` shortly before the current one expires."""
        log.debug("NextSaltThread started")

        while True:
            now = datetime.now()

            # Seconds to wait until middle-overlap, which is
            # 15 minutes before/after the current/next salt end/start time
            dt = (self.current_salt.valid_until - now).total_seconds() - 900

            log.debug(
                "Current salt: {} | Next salt in {:.0f}m {:.0f}s ({})".format(
                    self.current_salt.salt, dt // 60, dt % 60,
                    now + timedelta(seconds=dt)))

            self.next_salt_thread_event.wait(dt)

            if self.next_salt_thread_event.is_set():
                break

            try:
                self.current_salt = self._send(
                    functions.GetFutureSalts(1)).salts[0]
            except (OSError, TimeoutError):
                self.connection.close()
                break

        log.debug("NextSaltThread stopped")

    def recv(self):
        """Read packets into ``recv_queue``; restart on a dropped connection."""
        log.debug("RecvThread started")

        while True:
            packet = self.connection.recv()

            # ``None`` or a bare 4-byte -404 is treated as a lost connection.
            if packet is None or (len(packet) == 4
                                  and Int.read(BytesIO(packet)) == -404):
                if self.is_connected.is_set():
                    Thread(target=self.restart, name="RestartThread").start()
                break

            self.recv_queue.put(packet)

        log.debug("RecvThread stopped")

    def _send(self, data: Object, wait_response: bool = True):
        """Send ``data`` once and optionally wait for its result.

        Raises:
            OSError: if the connection fails while sending.
            TimeoutError: if no response arrives within ``WAIT_TIMEOUT``.
        """
        message = self.msg_factory(data)
        msg_id = message.msg_id

        if wait_response:
            self.results[msg_id] = Result()

        payload = self.pack(message)

        try:
            self.connection.send(payload)
        except OSError:
            self.results.pop(msg_id, None)
            raise

        if wait_response:
            self.results[msg_id].event.wait(self.WAIT_TIMEOUT)
            result = self.results.pop(msg_id).value

            if result is None:
                raise TimeoutError
            elif isinstance(result, types.RpcError):
                Error.raise_it(result, type(data))
            else:
                return result

    def send(self, data: Object):
        """Send ``data``, retrying on network errors or timeouts.

        Waits for the session to be connected before each attempt. Returns the
        result, or ``None`` if all ``MAX_RETRIES`` attempts failed.
        """
        for _ in range(self.MAX_RETRIES):
            self.is_connected.wait()

            try:
                return self._send(data)
            except (OSError, TimeoutError):
                log.warning("Retrying {}".format(type(data)))

        return None
예제 #4
0
파일: session.py 프로젝트: zolemar/pyrogram
class Session:
    """An encrypted MTProto session with a single data center.

    While running, the session owns a ``Connection`` and these threads:
    ``NetWorker#N`` (decrypt and dispatch queued packets), ``RecvThread``
    (read raw packets off the connection), ``PingThread`` (periodic
    keep-alive) and ``NextSaltThread`` (server-salt rotation).
    """

    VERSION = __version__
    APP_VERSION = "Pyrogram \U0001f525 {}".format(VERSION)

    DEVICE_MODEL = "{} {}".format(
        platform.python_implementation(),
        platform.python_version()
    )

    SYSTEM_VERSION = "{} {}".format(
        platform.system(),
        platform.release()
    )

    # Placeholder salt for the very first request; start() replaces it with
    # the salt reported by the server.
    INITIAL_SALT = 0x616e67656c696361
    NET_WORKERS = 1  # number of NetWorker threads draining recv_queue
    WAIT_TIMEOUT = 30  # seconds _send() waits for a reply
    MAX_RETRIES = 5  # attempts made by send() before giving up
    ACKS_THRESHOLD = 8  # pending acks collected before a MsgsAck is sent
    PING_INTERVAL = 5  # seconds between keep-alive pings

    # Stored on the class so the version/license banner is printed once per process.
    notice_displayed = False

    # Human-readable text for bad_msg_notification error codes, used by _send().
    BAD_MSG_DESCRIPTION = {
        16: "[16] msg_id too low, the client time has to be synchronized",
        17: "[17] msg_id too high, the client time has to be synchronized",
        18: "[18] incorrect two lower order msg_id bits, the server expects client message msg_id to be divisible by 4",
        19: "[19] container msg_id is the same as msg_id of a previously received message",
        20: "[20] message too old, it cannot be verified by the server",
        32: "[32] msg_seqno too low",
        33: "[33] msg_seqno too high",
        34: "[34] an even msg_seqno expected, but odd received",
        35: "[35] odd msg_seqno expected, but even received",
        48: "[48] incorrect server salt",
        64: "[64] invalid container"
    }

    def __init__(self,
                 dc_id: int,
                 test_mode: bool,
                 proxy: type,
                 auth_key: bytes,
                 api_id: str,
                 is_cdn: bool = False,
                 client: "pyrogram.Client" = None):
        """Prepare a session. No network activity happens until ``start()``.

        Args:
            dc_id: Data center id, passed to ``DataCenter`` with ``test_mode``.
            test_mode: Passed to ``DataCenter`` alongside ``dc_id``.
            proxy: Handed unchanged to ``Connection``.
                NOTE(review): annotated as ``type``; confirm the real type.
            auth_key: Authorization key. The last 8 bytes of its SHA-1 digest
                form the auth_key_id that prefixes every packet.
            api_id: Sent to the server inside ``InitConnection``.
            is_cdn: When true, ``start()`` skips ``InitConnection``.
            client: Optional client; non-RPC messages are put on its
                ``updates_queue``.
        """
        if not Session.notice_displayed:
            print("Pyrogram v{}, {}".format(__version__, __copyright__))
            print("Licensed under the terms of the " + __license__, end="\n\n")
            Session.notice_displayed = True

        self.connection = Connection(DataCenter(dc_id, test_mode), proxy)
        self.api_id = api_id
        self.is_cdn = is_cdn
        self.client = client

        self.auth_key = auth_key
        self.auth_key_id = sha1(auth_key).digest()[-8:]

        self.session_id = Long(MsgId())
        self.msg_factory = MsgFactory()

        self.current_salt = None

        # msg_ids that still need to be acknowledged to the server.
        self.pending_acks = set()

        self.recv_queue = Queue()
        # msg_id -> Result slot awaited by _send().
        self.results = {}

        self.ping_thread = None
        self.ping_thread_event = Event()

        self.next_salt_thread = None
        self.next_salt_thread_event = Event()

        self.is_connected = Event()

    def start(self):
        """Connect and initialize the session, retrying until it succeeds.

        Each attempt connects, launches the NetWorker and receive threads,
        obtains a valid server salt, starts salt rotation, sends
        ``InitConnection`` wrapped in ``InvokeWithLayer`` (skipped for CDN
        sessions) and starts the ping thread.

        Only ``OSError``, ``TimeoutError`` and ``Error`` cause a retry. The
        failed attempt is torn down with ``stop()`` and the next attempt
        starts immediately, with no delay and no attempt limit. Any other
        exception propagates without calling ``stop()``. ``is_connected`` is
        set only after an attempt fully succeeds.
        """
        while True:
            try:
                self.connection.connect()

                for i in range(self.NET_WORKERS):
                    Thread(target=self.net_worker, name="NetWorker#{}".format(i + 1)).start()

                Thread(target=self.recv, name="RecvThread").start()

                # The first request goes out with INITIAL_SALT; its reply is
                # expected to carry the real salt as new_server_salt.
                self.current_salt = FutureSalt(0, 0, self.INITIAL_SALT)
                self.current_salt = FutureSalt(0, 0, self._send(functions.Ping(0)).new_server_salt)
                self.current_salt = self._send(functions.GetFutureSalts(1)).salts[0]

                self.next_salt_thread = Thread(target=self.next_salt, name="NextSaltThread")
                self.next_salt_thread.start()

                if not self.is_cdn:
                    self._send(
                        functions.InvokeWithLayer(
                            layer,
                            functions.InitConnection(
                                self.api_id,
                                self.DEVICE_MODEL,
                                self.SYSTEM_VERSION,
                                self.APP_VERSION,
                                "en", "", "en",
                                functions.help.GetConfig(),
                            )
                        )
                    )

                self.ping_thread = Thread(target=self.ping, name="PingThread")
                self.ping_thread.start()

                log.info("Connection inited: Layer {}".format(layer))
            except (OSError, TimeoutError, Error):
                self.stop()
            else:
                break

        self.is_connected.set()

        log.debug("Session started")

    def stop(self):
        """Tear the session down.

        Steps:
        - Clear ``is_connected``.
        - Stop and join the ping and salt threads.
        - Close the connection.
        - Queue one ``None`` sentinel per NetWorker.
        - Wake every request still waiting in ``_send()``. Such a request sees
          no value and raises ``TimeoutError``.
        """
        self.is_connected.clear()

        self.ping_thread_event.set()
        self.next_salt_thread_event.set()

        if self.ping_thread is not None:
            self.ping_thread.join()

        if self.next_salt_thread is not None:
            self.next_salt_thread.join()

        # Reset the events so a later start() can reuse them.
        self.ping_thread_event.clear()
        self.next_salt_thread_event.clear()

        self.connection.close()

        for _ in range(self.NET_WORKERS):
            self.recv_queue.put(None)

        for result in self.results.values():
            result.event.set()

        log.debug("Session stopped")

    def restart(self):
        """Stop, then start again. ``recv()`` runs this on a RestartThread."""
        self.stop()
        self.start()

    def pack(self, message: Message):
        """Encrypt ``message`` into a wire packet.

        Layout: auth_key_id (8) + msg_key (16) + AES-IGE ciphertext of
        (salt + session_id + message + padding). The padding is 12-27 random
        bytes, chosen so the plaintext length is a multiple of 16.
        """
        data = Long(self.current_salt.salt) + self.session_id + message.write()
        padding = urandom(-(len(data) + 12) % 16 + 12)

        # 88 = 88 + 0 (outgoing message)
        msg_key_large = sha256(self.auth_key[88: 88 + 32] + data + padding).digest()
        msg_key = msg_key_large[8:24]
        aes_key, aes_iv = KDF(self.auth_key, msg_key, True)

        return self.auth_key_id + msg_key + AES.ige_encrypt(data + padding, aes_key, aes_iv)

    def unpack(self, b: BytesIO) -> Message:
        """Decrypt and validate one incoming packet.

        Checks auth_key_id, msg_key (SHA-256 over the whole decrypted
        payload), session_id and the msg_id parity.

        Raises:
            AssertionError: if any check fails. The checks use explicit
                ``raise`` rather than ``assert``. ``assert`` statements are
                removed under ``python -O``, and the stream reads placed inside
                them would then be skipped as well, corrupting decryption.
        """
        if b.read(8) != self.auth_key_id:
            raise AssertionError(b.getvalue())

        msg_key = b.read(16)
        aes_key, aes_iv = KDF(self.auth_key, msg_key, False)
        data = BytesIO(AES.ige_decrypt(b.read(), aes_key, aes_iv))

        # https://core.telegram.org/mtproto/security_guidelines#checking-sha256-hash-value-of-msg-key
        # https://core.telegram.org/mtproto/security_guidelines#checking-message-length
        # 96 = 88 + 8 (incoming message). Verified before anything is parsed,
        # so unauthenticated plaintext never reaches Message.read.
        # NOTE: the padding length is not bounded separately here.
        if msg_key != sha256(self.auth_key[96:96 + 32] + data.getvalue()).digest()[8:24]:
            raise AssertionError("msg_key mismatch")

        data.read(8)  # salt; pack() writes salt + session_id ahead of the message

        # https://core.telegram.org/mtproto/security_guidelines#checking-session-id
        if data.read(8) != self.session_id:
            raise AssertionError("session_id mismatch")

        message = Message.read(data)

        # https://core.telegram.org/mtproto/security_guidelines#checking-msg-id
        # TODO: check for lower msg_ids
        if message.msg_id % 2 == 0:
            raise AssertionError("incoming msg_id must be odd")

        return message

    def net_worker(self):
        """Process packets from ``recv_queue`` until a ``None`` sentinel arrives.

        An exception while handling one packet is logged and does not stop
        the worker.
        """
        name = threading.current_thread().name
        log.debug("{} started".format(name))

        while True:
            packet = self.recv_queue.get()

            if packet is None:
                break

            try:
                self.unpack_dispatch_and_ack(packet)
            except Exception as e:
                log.error(e, exc_info=True)

        log.debug("{} stopped".format(name))

    def unpack_dispatch_and_ack(self, packet: bytes):
        """Decrypt one packet, route its messages and flush pending acks.

        Replies (bad-msg/salt notices, future salts, RPC results, pongs) fill
        the matching ``results`` slot and wake its waiter. Other message
        bodies are forwarded to ``client.updates_queue`` when a client is set.
        Once ``ACKS_THRESHOLD`` acks are pending, they are sent in one
        ``MsgsAck``. The pending set is kept if that send fails.
        """
        data = self.unpack(BytesIO(packet))

        messages = (
            data.body.messages
            if isinstance(data.body, MsgContainer)
            else [data]
        )

        log.debug(data)

        for msg in messages:
            # Odd seq_no: this message must be acked. A msg_id that is already
            # pending is treated as a duplicate and skipped entirely.
            if msg.seq_no % 2 != 0:
                if msg.msg_id in self.pending_acks:
                    continue
                else:
                    self.pending_acks.add(msg.msg_id)

            if isinstance(msg.body, (types.MsgDetailedInfo, types.MsgNewDetailedInfo)):
                self.pending_acks.add(msg.body.answer_msg_id)
                continue

            if isinstance(msg.body, types.NewSessionCreated):
                continue

            msg_id = None

            if isinstance(msg.body, (types.BadMsgNotification, types.BadServerSalt)):
                msg_id = msg.body.bad_msg_id
            elif isinstance(msg.body, (core.FutureSalts, types.RpcResult)):
                msg_id = msg.body.req_msg_id
            elif isinstance(msg.body, types.Pong):
                msg_id = msg.body.msg_id
            else:
                if self.client is not None:
                    self.client.updates_queue.put(msg.body)

            if msg_id in self.results:
                # RpcResult carries its payload in .result; others are stored as-is.
                self.results[msg_id].value = getattr(msg.body, "result", msg.body)
                self.results[msg_id].event.set()

        if len(self.pending_acks) >= self.ACKS_THRESHOLD:
            log.info("Send {} acks".format(len(self.pending_acks)))

            try:
                self._send(types.MsgsAck(list(self.pending_acks)), False)
            except (OSError, TimeoutError):
                pass  # best effort: the acks stay pending and go out next time
            else:
                self.pending_acks.clear()

    def ping(self):
        """Send a ``PingDelayDisconnect`` every ``PING_INTERVAL`` seconds.

        Exits when ``ping_thread_event`` is set. Send failures are ignored.
        The second argument is presumably the disconnect delay in seconds
        (``PING_INTERVAL + 15``). TODO: confirm against the TL schema.
        """
        log.debug("PingThread started")

        while True:
            self.ping_thread_event.wait(self.PING_INTERVAL)

            if self.ping_thread_event.is_set():
                break

            try:
                self._send(functions.PingDelayDisconnect(0, self.PING_INTERVAL + 15), False)
            except (OSError, TimeoutError):
                pass

        log.debug("PingThread stopped")

    def next_salt(self):
        """Fetch a new server salt 15 minutes before the current one expires.

        Exits when ``next_salt_thread_event`` is set. If fetching fails, it
        closes the connection and exits. ``valid_until`` is subtracted from
        ``datetime.now()``, so it is presumably a datetime here.
        """
        log.debug("NextSaltThread started")

        while True:
            now = datetime.now()

            # Seconds to wait until middle-overlap, which is
            # 15 minutes before/after the current/next salt end/start time
            dt = (self.current_salt.valid_until - now).total_seconds() - 900

            log.debug("Current salt: {} | Next salt in {:.0f}m {:.0f}s ({})".format(
                self.current_salt.salt,
                dt // 60,
                dt % 60,
                now + timedelta(seconds=dt)
            ))

            self.next_salt_thread_event.wait(dt)

            if self.next_salt_thread_event.is_set():
                break

            try:
                self.current_salt = self._send(functions.GetFutureSalts(1)).salts[0]
            except (OSError, TimeoutError):
                self.connection.close()
                break

        log.debug("NextSaltThread stopped")

    def recv(self):
        """Read raw packets and queue them for the NetWorkers.

        Ends on ``None`` or a 4-byte packet, which is a bare server error
        code and is logged. If the session is still marked connected, a
        ``RestartThread`` is spawned first.
        """
        log.debug("RecvThread started")

        while True:
            packet = self.connection.recv()

            if packet is None or len(packet) == 4:
                if packet:
                    log.warning("Server sent \"{}\"".format(Int.read(BytesIO(packet))))

                if self.is_connected.is_set():
                    Thread(target=self.restart, name="RestartThread").start()
                break

            self.recv_queue.put(packet)

        log.debug("RecvThread stopped")

    def _send(self, data: Object, wait_response: bool = True):
        """Encrypt and transmit ``data``, optionally waiting for the reply.

        Returns:
            The reply value, or ``None`` when ``wait_response`` is false.

        Raises:
            OSError: if the connection fails to send.
            TimeoutError: if no reply arrived within ``WAIT_TIMEOUT`` seconds
                (also the outcome when ``stop()`` wakes the waiter).
            The error produced by ``Error.raise_it`` when the reply is an
            ``RpcError``.
            Exception: with a ``BAD_MSG_DESCRIPTION`` text when the reply is a
                ``BadMsgNotification``.
        """
        message = self.msg_factory(data)
        msg_id = message.msg_id

        # Register before sending so a fast reply cannot be missed.
        if wait_response:
            self.results[msg_id] = Result()

        payload = self.pack(message)

        try:
            self.connection.send(payload)
        except OSError as e:
            self.results.pop(msg_id, None)
            raise e

        if wait_response:
            self.results[msg_id].event.wait(self.WAIT_TIMEOUT)
            result = self.results.pop(msg_id).value

            if result is None:
                raise TimeoutError
            elif isinstance(result, types.RpcError):
                Error.raise_it(result, type(data))
            elif isinstance(result, types.BadMsgNotification):
                raise Exception(self.BAD_MSG_DESCRIPTION.get(
                    result.error_code,
                    "Error code {}".format(result.error_code)
                ))
            else:
                return result

    def send(self, data: Object):
        """Send ``data`` with up to ``MAX_RETRIES`` attempts and return the reply.

        Each attempt first blocks until ``is_connected`` is set. ``OSError``
        and ``TimeoutError`` trigger a retry. Other exceptions propagate.

        Returns:
            The reply from ``_send``. Returns ``None`` if every attempt failed;
            that outcome is logged at error level instead of passing silently.
        """
        for i in range(self.MAX_RETRIES):
            self.is_connected.wait()

            try:
                return self._send(data)
            except (OSError, TimeoutError):
                (log.warning if i > 0 else log.info)("{}: {} Retrying {}".format(i, datetime.now(), type(data)))

        log.error("Failed to send {} after {} attempts".format(type(data), self.MAX_RETRIES))
        return None
예제 #5
0
class Session:
    """An encrypted MTProto session with one data center, owned by a client.

    While running it keeps ``NET_WORKERS`` ``NetWorker#N`` threads, a
    ``RecvThread``, a ``PingThread`` and a ``NextSaltThread``. Connection
    settings (test mode, IPv6, proxy, API id, device info) are read from
    ``client``.
    """

    # Placeholder salt for the first request; start() replaces it with the
    # salt reported by the server.
    INITIAL_SALT = 0x616e67656c696361
    NET_WORKERS = 1  # number of NetWorker threads draining recv_queue
    START_TIMEOUT = 1  # seconds each handshake request in start() may take
    WAIT_TIMEOUT = 15  # default seconds to wait for a reply
    SLEEP_THRESHOLD = 60  # longest FloodWait (seconds) send() sleeps through by default
    MAX_RETRIES = 5  # default number of retries in send()
    ACKS_THRESHOLD = 8  # pending acks collected before a MsgsAck is sent
    PING_INTERVAL = 5  # seconds between keep-alive pings

    # Stored on the class so the version/license banner is printed once per process.
    notice_displayed = False

    # Human-readable text for bad_msg_notification error codes, used by _send().
    BAD_MSG_DESCRIPTION = {
        16: "[16] msg_id too low, the client time has to be synchronized",
        17: "[17] msg_id too high, the client time has to be synchronized",
        18:
        "[18] incorrect two lower order msg_id bits, the server expects client message msg_id to be divisible by 4",
        19:
        "[19] container msg_id is the same as msg_id of a previously received message",
        20: "[20] message too old, it cannot be verified by the server",
        32: "[32] msg_seqno too low",
        33: "[33] msg_seqno too high",
        34: "[34] an even msg_seqno expected, but odd received",
        35: "[35] odd msg_seqno expected, but even received",
        48: "[48] incorrect server salt",
        64: "[64] invalid container"
    }

    def __init__(self,
                 client: "pyrogram.Client",
                 dc_id: int,
                 auth_key: bytes,
                 is_media: bool = False,
                 is_cdn: bool = False):
        """Prepare a session. No connection is made until ``start()``.

        Args:
            client: Owning client. Supplies connection settings, receives
                updates on ``updates_queue`` and, for non-media sessions, has
                its ``disconnect_handler`` called by ``stop()``.
            dc_id: Data center id used when building each ``Connection``.
            auth_key: Authorization key. The last 8 bytes of its SHA-1 digest
                form the auth_key_id that prefixes every packet.
            is_media: When true, ``stop()`` does not call the client's
                ``disconnect_handler``.
            is_cdn: When true, ``start()`` skips ``InitConnection``.
        """
        if not Session.notice_displayed:
            print("Pyrogram v{}, {}".format(__version__, __copyright__))
            print("Licensed under the terms of the " + __license__, end="\n\n")
            Session.notice_displayed = True

        self.client = client
        self.dc_id = dc_id
        self.auth_key = auth_key
        self.is_media = is_media
        self.is_cdn = is_cdn

        self.connection = None

        self.auth_key_id = sha1(auth_key).digest()[-8:]

        self.session_id = Long(MsgId())
        self.msg_factory = MsgFactory()

        self.current_salt = None

        # msg_ids that still need to be acknowledged to the server.
        self.pending_acks = set()

        self.recv_queue = Queue()
        # msg_id -> Result slot awaited by _send().
        self.results = {}

        self.ping_thread = None
        self.ping_thread_event = Event()

        self.next_salt_thread = None
        self.next_salt_thread_event = Event()

        self.net_worker_list = []

        self.is_connected = Event()

    def start(self):
        """Connect and initialize the session, retrying until it succeeds.

        Each attempt:
        - builds a fresh ``Connection`` and connects it;
        - launches the NetWorker and receive threads;
        - obtains a valid server salt;
        - starts salt rotation;
        - sends ``InitConnection`` wrapped in ``InvokeWithLayer`` (skipped for
          CDN sessions);
        - starts the ping thread.

        Handshake requests use ``START_TIMEOUT``. ``OSError``,
        ``TimeoutError`` and ``RPCError`` cause ``stop()`` and an immediate
        retry. ``is_connected`` is set only after an attempt fully succeeds.

        Raises:
            AuthKeyDuplicated: re-raised after ``stop()``; never retried.
            Exception: any other unexpected error, re-raised after ``stop()``.
        """
        while True:
            self.connection = Connection(self.dc_id,
                                         self.client.storage.test_mode(),
                                         self.client.ipv6, self.client.proxy)

            try:
                self.connection.connect()

                for i in range(self.NET_WORKERS):
                    self.net_worker_list.append(
                        Thread(target=self.net_worker,
                               name="NetWorker#{}".format(i + 1)))

                    self.net_worker_list[-1].start()

                Thread(target=self.recv, name="RecvThread").start()

                # The first request goes out with INITIAL_SALT; its reply is
                # expected to carry the real salt as new_server_salt.
                self.current_salt = FutureSalt(0, 0, self.INITIAL_SALT)
                self.current_salt = FutureSalt(
                    0, 0,
                    self._send(functions.Ping(ping_id=0),
                               timeout=self.START_TIMEOUT).new_server_salt)
                self.current_salt = self._send(
                    functions.GetFutureSalts(num=1),
                    timeout=self.START_TIMEOUT).salts[0]

                self.next_salt_thread = Thread(target=self.next_salt,
                                               name="NextSaltThread")
                self.next_salt_thread.start()

                if not self.is_cdn:
                    self._send(functions.InvokeWithLayer(
                        layer=layer,
                        query=functions.InitConnection(
                            api_id=self.client.api_id,
                            app_version=self.client.app_version,
                            device_model=self.client.device_model,
                            system_version=self.client.system_version,
                            system_lang_code=self.client.lang_code,
                            lang_code=self.client.lang_code,
                            lang_pack="",
                            query=functions.help.GetConfig(),
                        )),
                               timeout=self.START_TIMEOUT)

                self.ping_thread = Thread(target=self.ping, name="PingThread")
                self.ping_thread.start()

                log.info("Session initialized: Layer {}".format(layer))
                log.info("Device: {} - {}".format(self.client.device_model,
                                                  self.client.app_version))
                log.info("System: {} ({})".format(
                    self.client.system_version, self.client.lang_code.upper()))

            except AuthKeyDuplicated as e:
                # Kept ahead of the RPCError clause, presumably because
                # AuthKeyDuplicated is an RPCError that must not be retried.
                # TODO: confirm the exception hierarchy.
                self.stop()
                raise e
            except (OSError, TimeoutError, RPCError):
                self.stop()
            except Exception as e:
                self.stop()
                raise e
            else:
                break

        self.is_connected.set()

        log.debug("Session started")

    def stop(self):
        """Tear the session down.

        Steps:
        - Clear ``is_connected``.
        - Stop and join the ping, salt and NetWorker threads.
        - Close the connection and drop any still-queued packets.
        - Wake every request waiting in ``_send()``. Such a request sees no
          value and raises ``TimeoutError``.
        - For non-media sessions, call the client's ``disconnect_handler``.
          Errors from the handler are logged.
        """
        self.is_connected.clear()

        self.ping_thread_event.set()
        self.next_salt_thread_event.set()

        if self.ping_thread is not None:
            self.ping_thread.join()

        if self.next_salt_thread is not None:
            self.next_salt_thread.join()

        # Reset the events so a later start() can reuse them.
        self.ping_thread_event.clear()
        self.next_salt_thread_event.clear()

        self.connection.close()

        # One None sentinel per NetWorker makes each of them exit its loop.
        for _ in range(self.NET_WORKERS):
            self.recv_queue.put(None)

        for worker in self.net_worker_list:
            worker.join()

        self.net_worker_list.clear()
        self.recv_queue.queue.clear()

        for result in self.results.values():
            result.event.set()

        if not self.is_media and callable(self.client.disconnect_handler):
            try:
                self.client.disconnect_handler(self.client)
            except Exception as e:
                log.error(e, exc_info=True)

        log.debug("Session stopped")

    def restart(self):
        """Stop, then start again. ``recv()`` runs this on a RestartThread."""
        self.stop()
        self.start()

    def pack(self, message: Message):
        """Encrypt ``message`` into a wire packet.

        Layout: auth_key_id (8) + msg_key (16) + AES-256-IGE ciphertext of
        (salt + session_id + message + padding). The padding is 12-27 random
        bytes, chosen so the plaintext length is a multiple of 16.
        """
        data = Long(self.current_salt.salt) + self.session_id + message.write()
        padding = urandom(-(len(data) + 12) % 16 + 12)

        # 88 = 88 + 0 (outgoing message)
        msg_key_large = sha256(self.auth_key[88:88 + 32] + data +
                               padding).digest()
        msg_key = msg_key_large[8:24]
        aes_key, aes_iv = KDF(self.auth_key, msg_key, True)

        return self.auth_key_id + msg_key + AES.ige256_encrypt(
            data + padding, aes_key, aes_iv)

    def unpack(self, b: BytesIO) -> Message:
        """Decrypt and validate one incoming packet.

        Checks auth_key_id, msg_key (SHA-256 over the whole decrypted
        payload), session_id and the msg_id parity.

        Raises:
            AssertionError: if any check fails. The checks use explicit
                ``raise`` rather than ``assert``. ``assert`` statements are
                removed under ``python -O``, and the stream reads placed inside
                them would then be skipped as well, corrupting decryption.
        """
        if b.read(8) != self.auth_key_id:
            raise AssertionError(b.getvalue())

        msg_key = b.read(16)
        aes_key, aes_iv = KDF(self.auth_key, msg_key, False)
        data = BytesIO(AES.ige256_decrypt(b.read(), aes_key, aes_iv))

        # https://core.telegram.org/mtproto/security_guidelines#checking-sha256-hash-value-of-msg-key
        # https://core.telegram.org/mtproto/security_guidelines#checking-message-length
        # 96 = 88 + 8 (incoming message). Verified before anything is parsed,
        # so unauthenticated plaintext never reaches Message.read.
        # NOTE: the padding length is not bounded separately here.
        if msg_key != sha256(self.auth_key[96:96 + 32] +
                             data.getvalue()).digest()[8:24]:
            raise AssertionError("msg_key mismatch")

        data.read(8)  # salt; pack() writes salt + session_id ahead of the message

        # https://core.telegram.org/mtproto/security_guidelines#checking-session-id
        if data.read(8) != self.session_id:
            raise AssertionError("session_id mismatch")

        message = Message.read(data)

        # https://core.telegram.org/mtproto/security_guidelines#checking-msg-id
        # TODO: check for lower msg_ids
        if message.msg_id % 2 == 0:
            raise AssertionError("incoming msg_id must be odd")

        return message

    def net_worker(self):
        """Process packets from ``recv_queue`` until a ``None`` sentinel arrives.

        An exception while handling one packet is logged and does not stop
        the worker.
        """
        name = threading.current_thread().name
        log.debug("{} started".format(name))

        while True:
            packet = self.recv_queue.get()

            if packet is None:
                break

            try:
                self._unpack_dispatch_and_ack(packet)
            except Exception as e:
                log.error(e, exc_info=True)

        log.debug("{} stopped".format(name))

    def _unpack_dispatch_and_ack(self, packet: bytes):
        """Decrypt one packet, route its messages and flush pending acks.

        Replies (bad-msg/salt notices, future salts, RPC results, pongs) fill
        the matching ``results`` slot and wake its waiter. Other message
        bodies are forwarded to ``client.updates_queue``. Once
        ``ACKS_THRESHOLD`` acks are pending, they are sent in one ``MsgsAck``.
        The pending set is kept if that send fails.
        """
        data = self.unpack(BytesIO(packet))

        messages = (data.body.messages
                    if isinstance(data.body, MsgContainer)
                    else [data])

        # Lazy %-formatting: str(data) is built only if DEBUG is enabled.
        log.debug("Received:\n%s", data)

        for msg in messages:
            # Odd seq_no: this message must be acked. A msg_id that is already
            # pending is treated as a duplicate and skipped entirely.
            if msg.seq_no % 2 != 0:
                if msg.msg_id in self.pending_acks:
                    continue
                else:
                    self.pending_acks.add(msg.msg_id)

            if isinstance(msg.body,
                          (types.MsgDetailedInfo, types.MsgNewDetailedInfo)):
                self.pending_acks.add(msg.body.answer_msg_id)
                continue

            if isinstance(msg.body, types.NewSessionCreated):
                continue

            msg_id = None

            if isinstance(msg.body,
                          (types.BadMsgNotification, types.BadServerSalt)):
                msg_id = msg.body.bad_msg_id
            elif isinstance(msg.body, (core.FutureSalts, types.RpcResult)):
                msg_id = msg.body.req_msg_id
            elif isinstance(msg.body, types.Pong):
                msg_id = msg.body.msg_id
            else:
                if self.client is not None:
                    self.client.updates_queue.put(msg.body)

            if msg_id in self.results:
                # RpcResult carries its payload in .result; others are stored as-is.
                self.results[msg_id].value = getattr(msg.body, "result",
                                                     msg.body)
                self.results[msg_id].event.set()

        if len(self.pending_acks) >= self.ACKS_THRESHOLD:
            log.info("Send {} acks".format(len(self.pending_acks)))

            try:
                self._send(types.MsgsAck(msg_ids=list(self.pending_acks)),
                           False)
            except (OSError, TimeoutError):
                pass  # best effort: the acks stay pending and go out next time
            else:
                self.pending_acks.clear()

    def ping(self):
        """Send ``PingDelayDisconnect`` every ``PING_INTERVAL`` seconds.

        Asks for a disconnect delay of ``WAIT_TIMEOUT + 10`` seconds. Send
        and RPC failures are ignored. Exits when ``ping_thread_event`` is set.
        """
        log.debug("PingThread started")

        while True:
            self.ping_thread_event.wait(self.PING_INTERVAL)

            if self.ping_thread_event.is_set():
                break

            try:
                self._send(
                    functions.PingDelayDisconnect(
                        ping_id=0, disconnect_delay=self.WAIT_TIMEOUT + 10),
                    False)
            except (OSError, TimeoutError, RPCError):
                pass

        log.debug("PingThread stopped")

    def next_salt(self):
        """Fetch a new server salt 15 minutes before the current one expires.

        ``valid_until`` is converted with ``datetime.fromtimestamp``. Exits
        when ``next_salt_thread_event`` is set. If fetching fails, it closes
        the connection and exits.
        """
        log.debug("NextSaltThread started")

        while True:
            now = datetime.now()

            # Seconds to wait until middle-overlap, which is
            # 15 minutes before/after the current/next salt end/start time
            valid_until = datetime.fromtimestamp(self.current_salt.valid_until)
            dt = (valid_until - now).total_seconds() - 900

            log.debug(
                "Current salt: {} | Next salt in {:.0f}m {:.0f}s ({})".format(
                    self.current_salt.salt, dt // 60, dt % 60,
                    now + timedelta(seconds=dt)))

            self.next_salt_thread_event.wait(dt)

            if self.next_salt_thread_event.is_set():
                break

            try:
                self.current_salt = self._send(
                    functions.GetFutureSalts(num=1)).salts[0]
            except (OSError, TimeoutError, RPCError):
                self.connection.close()
                break

        log.debug("NextSaltThread stopped")

    def recv(self):
        """Read raw packets and queue them for the NetWorkers.

        Ends on ``None`` or a 4-byte packet, which is a bare server error
        code and is logged. If the session is still marked connected, a
        ``RestartThread`` is spawned first.
        """
        log.debug("RecvThread started")

        while True:
            packet = self.connection.recv()

            if packet is None or len(packet) == 4:
                if packet:
                    log.warning("Server sent \"{}\"".format(
                        Int.read(BytesIO(packet))))

                if self.is_connected.is_set():
                    Thread(target=self.restart, name="RestartThread").start()
                break

            self.recv_queue.put(packet)

        log.debug("RecvThread stopped")

    def _send(self,
              data: TLObject,
              wait_response: bool = True,
              timeout: float = WAIT_TIMEOUT):
        """Encrypt and transmit ``data``, optionally waiting for the reply.

        Args:
            data: The request object.
            wait_response: When true, wait for the reply to this message.
            timeout: Seconds to wait for the reply.

        Returns:
            The reply value, or ``None`` when ``wait_response`` is false.

        Raises:
            OSError: if the connection fails to send.
            TimeoutError: if no reply arrived in time (also the outcome when
                ``stop()`` wakes the waiter).
            RPCError: produced by ``RPCError.raise_it`` when the reply is an
                ``RpcError``.
            Exception: with a ``BAD_MSG_DESCRIPTION`` text when the reply is a
                ``BadMsgNotification``.
        """
        message = self.msg_factory(data)
        msg_id = message.msg_id

        # Register before sending so a fast reply cannot be missed.
        if wait_response:
            self.results[msg_id] = Result()

        # Lazy %-formatting: str(message) is built only if DEBUG is enabled.
        log.debug("Sent:\n%s", message)

        payload = self.pack(message)

        try:
            self.connection.send(payload)
        except OSError as e:
            self.results.pop(msg_id, None)
            raise e

        if wait_response:
            self.results[msg_id].event.wait(timeout)
            result = self.results.pop(msg_id).value

            if result is None:
                raise TimeoutError
            elif isinstance(result, types.RpcError):
                # Report the error against the wrapped query, not the wrapper.
                if isinstance(data, (functions.InvokeWithoutUpdates,
                                     functions.InvokeWithTakeout)):
                    data = data.query

                RPCError.raise_it(result, type(data))
            elif isinstance(result, types.BadMsgNotification):
                raise Exception(
                    self.BAD_MSG_DESCRIPTION.get(
                        result.error_code,
                        "Error code {}".format(result.error_code)))
            else:
                return result

    def send(self,
             data: TLObject,
             retries: int = MAX_RETRIES,
             timeout: float = WAIT_TIMEOUT,
             sleep_threshold: float = SLEEP_THRESHOLD):
        """Send ``data`` and return its result, handling flood waits and retries.

        Args:
            data: The request. ``InvokeWithoutUpdates``/``InvokeWithTakeout``
                wrappers are unwrapped only to name the query in log messages.
            retries: Remaining retries for ``OSError``, ``TimeoutError`` and
                ``InternalServerError``. Each retry waits 0.5 s first.
            timeout: Seconds to wait for each reply.
            sleep_threshold: A ``FloodWait`` of at most this many seconds is
                slept through and the request is resent. Longer waits are
                raised. This value applies to every retry of this call.

        Raises:
            FloodWait: if the required wait exceeds ``sleep_threshold``.
            OSError, TimeoutError, InternalServerError: once retries run out.
            Other RPC errors from ``_send`` propagate unchanged.
        """
        # The result is not checked: after WAIT_TIMEOUT the request is
        # attempted even if the session is still not connected.
        self.is_connected.wait(self.WAIT_TIMEOUT)

        if isinstance(
                data,
            (functions.InvokeWithoutUpdates, functions.InvokeWithTakeout)):
            query = data.query
        else:
            query = data

        query = ".".join(query.QUALNAME.split(".")[1:])

        while True:
            try:
                return self._send(data, timeout=timeout)
            except FloodWait as e:
                amount = e.x

                if amount > sleep_threshold:
                    raise

                log.warning('[{}] Sleeping for {}s (required by "{}")'.format(
                    self.client.session_name, amount, query))

                time.sleep(amount)
            except (OSError, TimeoutError, InternalServerError) as e:
                if retries == 0:
                    raise e from None

                (log.warning if retries < 2 else log.info)(
                    '[{}] Retrying "{}" due to {}'.format(
                        Session.MAX_RETRIES - retries + 1, query, e))

                time.sleep(0.5)

                # Forward sleep_threshold as well; previously it silently
                # reset to SLEEP_THRESHOLD on every retry.
                return self.send(data, retries - 1, timeout, sleep_threshold)