class Auth:
    """Negotiate a new MTProto authorization key with a data center.

    https://core.telegram.org/mtproto/auth_key
    """

    MAX_RETRIES = 5

    def __init__(self, client: "pyrogram.Client", dc_id: int):
        # Only the connection settings are copied from the client; create()
        # builds a fresh Connection for every attempt.
        self.dc_id = dc_id
        self.test_mode = client.storage.test_mode()
        self.ipv6 = client.ipv6
        self.proxy = client.proxy

        self.connection = None

    @staticmethod
    def pack(data: TLObject) -> bytes:
        """Frame *data* as a plaintext MTProto message.

        Layout: auth_key_id (8 zero bytes) + message_id (8) + message_length (4) + body.
        """
        body = data.write()

        return bytes(8) + Long(MsgId()) + Int(len(body)) + body

    @staticmethod
    def unpack(b: BytesIO):
        """Decode the TL object carried by a plaintext MTProto message."""
        b.seek(20)  # Skip auth_key_id (8), message_id (8) and message_length (4)
        return TLObject.read(b)

    def send(self, data: TLObject):
        """Send *data* unencrypted over the current connection and return the decoded reply."""
        data = self.pack(data)
        self.connection.send(data)
        response = BytesIO(self.connection.recv())

        return self.unpack(response)

    @staticmethod
    def _check(condition: bool, description: str):
        """Raise ``AssertionError(description)`` when *condition* is false.

        Used instead of ``assert`` so the security checks still run under ``python -O``.
        """
        if not condition:
            raise AssertionError(description)

    def create(self):
        """Run the MTProto auth key exchange and return the 256-byte auth key.

        https://core.telegram.org/mtproto/auth_key
        https://core.telegram.org/mtproto/samples-auth_key

        Any exception during an attempt triggers a retry after one second; after
        MAX_RETRIES retries (MAX_RETRIES + 1 attempts in total) the last
        exception is re-raised.
        """
        retries_left = self.MAX_RETRIES

        # The server may close the connection at any time, causing the auth key creation to fail.
        # If that happens, just try again up to MAX_RETRIES times.
        while True:
            self.connection = Connection(self.dc_id, self.test_mode, self.ipv6, self.proxy)

            try:
                log.info("Start creating a new auth key on DC{}".format(self.dc_id))

                self.connection.connect()

                # Step 1; Step 2
                nonce = int.from_bytes(urandom(16), "little", signed=True)
                log.debug("Send req_pq: {}".format(nonce))
                res_pq = self.send(functions.ReqPqMulti(nonce=nonce))
                log.debug("Got ResPq: {}".format(res_pq.server_nonce))
                log.debug("Server public key fingerprints: {}".format(res_pq.server_public_key_fingerprints))

                # Pick the first server key we have a local RSA public key for
                for fingerprint in res_pq.server_public_key_fingerprints:
                    if fingerprint in RSA.server_public_keys:
                        log.debug("Using fingerprint: {}".format(fingerprint))
                        public_key_fingerprint = fingerprint
                        break
                    else:
                        log.debug("Fingerprint unknown: {}".format(fingerprint))
                else:
                    raise Exception("Public key not found")

                # Step 3
                pq = int.from_bytes(res_pq.pq, "big")
                log.debug("Start PQ factorization: {}".format(pq))
                start = time.time()
                factor = Prime.decompose(pq)
                p, q = sorted((factor, pq // factor))  # p < q
                log.debug("Done PQ factorization ({}s): {} {}".format(round(time.time() - start, 3), p, q))

                # Step 4
                server_nonce = res_pq.server_nonce
                new_nonce = int.from_bytes(urandom(32), "little", signed=True)

                data = types.PQInnerData(
                    pq=res_pq.pq,
                    p=p.to_bytes(4, "big"),
                    q=q.to_bytes(4, "big"),
                    nonce=nonce,
                    server_nonce=server_nonce,
                    new_nonce=new_nonce,
                ).write()

                sha = sha1(data).digest()
                # Pad SHA1 + data up to a multiple of 255 bytes before RSA encryption
                padding = urandom(-(len(data) + len(sha)) % 255)
                data_with_hash = sha + data + padding
                encrypted_data = RSA.encrypt(data_with_hash, public_key_fingerprint)

                log.debug("Done encrypt data with RSA")

                # Step 5. TODO: Handle "server_DH_params_fail". Code assumes response is ok
                log.debug("Send req_DH_params")
                server_dh_params = self.send(
                    functions.ReqDHParams(
                        nonce=nonce,
                        server_nonce=server_nonce,
                        p=p.to_bytes(4, "big"),
                        q=q.to_bytes(4, "big"),
                        public_key_fingerprint=public_key_fingerprint,
                        encrypted_data=encrypted_data
                    )
                )

                encrypted_answer = server_dh_params.encrypted_answer

                # Byte forms of the nonces used for key derivation; the TL calls keep using the ints
                server_nonce_bytes = server_nonce.to_bytes(16, "little", signed=True)
                new_nonce_bytes = new_nonce.to_bytes(32, "little", signed=True)

                tmp_aes_key = (
                    sha1(new_nonce_bytes + server_nonce_bytes).digest()
                    + sha1(server_nonce_bytes + new_nonce_bytes).digest()[:12]
                )

                tmp_aes_iv = (
                    sha1(server_nonce_bytes + new_nonce_bytes).digest()[12:]
                    + sha1(new_nonce_bytes + new_nonce_bytes).digest()
                    + new_nonce_bytes[:4]
                )

                answer_with_hash = AES.ige256_decrypt(encrypted_answer, tmp_aes_key, tmp_aes_iv)
                answer = answer_with_hash[20:]  # First 20 bytes are the SHA1 of the answer

                server_dh_inner_data = TLObject.read(BytesIO(answer))

                log.debug("Done decrypting answer")

                dh_prime = int.from_bytes(server_dh_inner_data.dh_prime, "big")
                delta_time = server_dh_inner_data.server_time - time.time()

                log.debug("Delta time: {}".format(round(delta_time, 3)))

                # Step 6
                g = server_dh_inner_data.g
                b = int.from_bytes(urandom(256), "big")
                g_b_int = pow(g, b, dh_prime)
                g_b = g_b_int.to_bytes(256, "big")

                retry_id = 0

                data = types.ClientDHInnerData(
                    nonce=nonce,
                    server_nonce=server_nonce,
                    retry_id=retry_id,
                    g_b=g_b
                ).write()

                sha = sha1(data).digest()
                padding = urandom(-(len(data) + len(sha)) % 16)  # AES block alignment
                data_with_hash = sha + data + padding
                encrypted_data = AES.ige256_encrypt(data_with_hash, tmp_aes_key, tmp_aes_iv)

                log.debug("Send set_client_DH_params")
                set_client_dh_params_answer = self.send(
                    functions.SetClientDHParams(
                        nonce=nonce,
                        server_nonce=server_nonce,
                        encrypted_data=encrypted_data
                    )
                )

                # TODO: Handle "auth_key_aux_hash" if the previous step fails
                # NOTE(review): the answer's type (e.g. dh_gen_ok vs dh_gen_retry/dh_gen_fail in
                # the MTProto spec) is only logged below, and new_nonce_hash is not verified.

                # Step 7; Step 8
                g_a = int.from_bytes(server_dh_inner_data.g_a, "big")
                auth_key = pow(g_a, b, dh_prime).to_bytes(256, "big")

                #######################
                # Security checks
                #######################

                self._check(dh_prime == Prime.CURRENT_DH_PRIME, "Unexpected DH prime")
                log.debug("DH parameters check: OK")

                # https://core.telegram.org/mtproto/security_guidelines#g-a-and-g-b-validation
                self._check(1 < g < dh_prime - 1, "g is out of range")
                self._check(1 < g_a < dh_prime - 1, "g_a is out of range")
                self._check(1 < g_b_int < dh_prime - 1, "g_b is out of range")
                self._check(
                    2 ** (2048 - 64) < g_a < dh_prime - 2 ** (2048 - 64),
                    "g_a is too close to the range bounds"
                )
                self._check(
                    2 ** (2048 - 64) < g_b_int < dh_prime - 2 ** (2048 - 64),
                    "g_b is too close to the range bounds"
                )
                log.debug("g_a and g_b validation: OK")

                # https://core.telegram.org/mtproto/security_guidelines#checking-sha1-hash-values
                # Re-serialize with .write() to drop the random padding before hashing
                answer = server_dh_inner_data.write()
                self._check(answer_with_hash[:20] == sha1(answer).digest(), "SHA1 hash mismatch")
                log.debug("SHA1 hash values check: OK")

                # https://core.telegram.org/mtproto/security_guidelines#checking-nonce-server-nonce-and-new-nonce-fields
                # 1st message
                self._check(nonce == res_pq.nonce, "nonce mismatch in res_pq")
                # 2nd message
                self._check(nonce == server_dh_params.nonce, "nonce mismatch in server_DH_params")
                self._check(
                    server_nonce == server_dh_params.server_nonce,
                    "server_nonce mismatch in server_DH_params"
                )
                # 3rd message
                self._check(
                    nonce == set_client_dh_params_answer.nonce,
                    "nonce mismatch in set_client_DH_params answer"
                )
                self._check(
                    server_nonce == set_client_dh_params_answer.server_nonce,
                    "server_nonce mismatch in set_client_DH_params answer"
                )
                log.debug("Nonce fields check: OK")

                # Step 9
                server_salt = AES.xor(new_nonce_bytes[:8], server_nonce_bytes[:8])

                log.debug("Server salt: {}".format(int.from_bytes(server_salt, "little")))

                log.info("Done auth key exchange: {}".format(set_client_dh_params_answer.__class__.__name__))
            except Exception as e:
                if retries_left:
                    retries_left -= 1
                else:
                    raise

                log.warning("Auth key creation failed. Let's try again: {}".format(repr(e)))
                time.sleep(1)
                continue
            else:
                return auth_key
            finally:
                self.connection.close()
class Auth:
    """Negotiate a new MTProto authorization key with a data center.

    https://core.telegram.org/mtproto/auth_key
    """

    MAX_RETRIES = 5

    # The only DH prime accepted from the server (compared in create()'s security checks)
    CURRENT_DH_PRIME = int(
        "C71CAEB9C6B1C9048E6C522F70F13F73980D40238E3E21C14934D037563D930F"
        "48198A0AA7C14058229493D22530F4DBFA336F6E0AC925139543AED44CCE7C37"
        "20FD51F69458705AC68CD4FE6B6B13ABDC9746512969328454F18FAF8C595F64"
        "2477FE96BB2A941D5BCD1D4AC8CC49880708FA9B378E3C4F3A9060BEE67CF9A4"
        "A4A695811051907E162753B56B0F6B410DBA74D8A84B2A14B3144E0EF1284754"
        "FD17ED950D5965B4B9DD46582DB1178D169C6BC465B0D6FF9CA3928FEF5B9AE4"
        "E418FC15E83EBEA0F87FA9FF5EED70050DED2849F47BF959D956850CE929851F"
        "0D8115F635B105EE2E4E15D04B2454BF6F4FADF034B10403119CD8E3B92FCC5B",
        16
    )

    def __init__(self, dc_id: int, test_mode: bool, proxy: type):
        self.dc_id = dc_id
        self.test_mode = test_mode

        # A single Connection object is reused (connect/close) across retries
        self.connection = Connection(DataCenter(dc_id, test_mode), proxy)

    @staticmethod
    def pack(data: Object) -> bytes:
        """Frame *data* as a plaintext MTProto message.

        Layout: auth_key_id (8 zero bytes) + message_id (8) + message_length (4) + body.
        """
        body = data.write()

        return bytes(8) + Long(MsgId()) + Int(len(body)) + body

    @staticmethod
    def unpack(b: BytesIO):
        """Decode the TL object carried by a plaintext MTProto message."""
        b.seek(20)  # Skip auth_key_id (8), message_id (8) and message_length (4)
        return Object.read(b)

    def send(self, data: Object):
        """Send *data* unencrypted and return the decoded reply."""
        data = self.pack(data)
        self.connection.send(data)
        response = BytesIO(self.connection.recv())

        return self.unpack(response)

    @staticmethod
    def _check(condition: bool, description: str):
        """Raise ``AssertionError(description)`` when *condition* is false.

        Used instead of ``assert`` so the security checks still run under ``python -O``.
        """
        if not condition:
            raise AssertionError(description)

    def create(self):
        """Run the MTProto auth key exchange and return the 256-byte auth key.

        https://core.telegram.org/mtproto/auth_key
        https://core.telegram.org/mtproto/samples-auth_key

        Any exception during an attempt triggers a retry after one second; after
        MAX_RETRIES retries (MAX_RETRIES + 1 attempts in total) the last
        exception is re-raised.
        """
        retries_left = self.MAX_RETRIES

        # The server may close the connection at any time, causing the auth key creation to fail.
        # If that happens, just try again up to MAX_RETRIES times.
        while True:
            try:
                log.info("Start creating a new auth key on DC{}".format(self.dc_id))

                self.connection.connect()

                # Step 1; Step 2
                nonce = int.from_bytes(urandom(16), "little", signed=True)
                log.debug("Send req_pq: {}".format(nonce))
                res_pq = self.send(functions.ReqPqMulti(nonce))
                log.debug("Got ResPq: {}".format(res_pq.server_nonce))
                log.debug("Server public key fingerprints: {}".format(res_pq.server_public_key_fingerprints))

                # Pick the first server key we have a local RSA public key for
                for fingerprint in res_pq.server_public_key_fingerprints:
                    if fingerprint in RSA.server_public_keys:
                        log.debug("Using fingerprint: {}".format(fingerprint))
                        public_key_fingerprint = fingerprint
                        break
                    else:
                        log.debug("Fingerprint unknown: {}".format(fingerprint))
                else:
                    raise Exception("Public key not found")

                # Step 3
                pq = int.from_bytes(res_pq.pq, "big")
                log.debug("Start PQ factorization: {}".format(pq))
                start = time.time()
                factor = Prime.decompose(pq)
                p, q = sorted((factor, pq // factor))  # p < q
                log.debug("Done PQ factorization ({}s): {} {}".format(round(time.time() - start, 3), p, q))

                # Step 4
                server_nonce = res_pq.server_nonce
                new_nonce = int.from_bytes(urandom(32), "little", signed=True)

                data = types.PQInnerData(
                    res_pq.pq,
                    int.to_bytes(p, 4, "big"),
                    int.to_bytes(q, 4, "big"),
                    nonce,
                    server_nonce,
                    new_nonce,
                ).write()

                sha = sha1(data).digest()
                # Pad SHA1 + data up to a multiple of 255 bytes before RSA encryption
                padding = urandom(- (len(data) + len(sha)) % 255)
                data_with_hash = sha + data + padding
                encrypted_data = RSA.encrypt(data_with_hash, public_key_fingerprint)

                log.debug("Done encrypt data with RSA")

                # Step 5. TODO: Handle "server_DH_params_fail". Code assumes response is ok
                log.debug("Send req_DH_params")
                server_dh_params = self.send(
                    functions.ReqDHParams(
                        nonce,
                        server_nonce,
                        int.to_bytes(p, 4, "big"),
                        int.to_bytes(q, 4, "big"),
                        public_key_fingerprint,
                        encrypted_data
                    )
                )

                encrypted_answer = server_dh_params.encrypted_answer

                # Byte forms of the nonces used for key derivation; the TL calls keep using the ints
                server_nonce_bytes = int.to_bytes(server_nonce, 16, "little", signed=True)
                new_nonce_bytes = int.to_bytes(new_nonce, 32, "little", signed=True)

                tmp_aes_key = (
                    sha1(new_nonce_bytes + server_nonce_bytes).digest()
                    + sha1(server_nonce_bytes + new_nonce_bytes).digest()[:12]
                )

                tmp_aes_iv = (
                    sha1(server_nonce_bytes + new_nonce_bytes).digest()[12:]
                    + sha1(new_nonce_bytes + new_nonce_bytes).digest()
                    + new_nonce_bytes[:4]
                )

                answer_with_hash = AES.ige_decrypt(encrypted_answer, tmp_aes_key, tmp_aes_iv)
                answer = answer_with_hash[20:]  # First 20 bytes are the SHA1 of the answer

                server_dh_inner_data = Object.read(BytesIO(answer))

                log.debug("Done decrypting answer")

                dh_prime = int.from_bytes(server_dh_inner_data.dh_prime, "big")
                delta_time = server_dh_inner_data.server_time - time.time()

                log.debug("Delta time: {}".format(round(delta_time, 3)))

                # Step 6
                g = server_dh_inner_data.g
                b = int.from_bytes(urandom(256), "big")
                g_b_int = pow(g, b, dh_prime)
                g_b = int.to_bytes(g_b_int, 256, "big")

                retry_id = 0

                data = types.ClientDHInnerData(
                    nonce,
                    server_nonce,
                    retry_id,
                    g_b
                ).write()

                sha = sha1(data).digest()
                padding = urandom(- (len(data) + len(sha)) % 16)  # AES block alignment
                data_with_hash = sha + data + padding
                encrypted_data = AES.ige_encrypt(data_with_hash, tmp_aes_key, tmp_aes_iv)

                log.debug("Send set_client_DH_params")
                set_client_dh_params_answer = self.send(
                    functions.SetClientDHParams(
                        nonce,
                        server_nonce,
                        encrypted_data
                    )
                )

                # TODO: Handle "auth_key_aux_hash" if the previous step fails
                # NOTE(review): the answer's type (e.g. dh_gen_ok vs dh_gen_retry/dh_gen_fail in
                # the MTProto spec) is only logged below, and new_nonce_hash is not verified.

                # Step 7; Step 8
                g_a = int.from_bytes(server_dh_inner_data.g_a, "big")
                auth_key = int.to_bytes(pow(g_a, b, dh_prime), 256, "big")

                #######################
                # Security checks
                #######################

                self._check(dh_prime == self.CURRENT_DH_PRIME, "Unexpected DH prime")
                log.debug("DH parameters check: OK")

                # https://core.telegram.org/mtproto/security_guidelines#g-a-and-g-b-validation
                self._check(1 < g < dh_prime - 1, "g is out of range")
                self._check(1 < g_a < dh_prime - 1, "g_a is out of range")
                self._check(1 < g_b_int < dh_prime - 1, "g_b is out of range")
                self._check(
                    2 ** (2048 - 64) < g_a < dh_prime - 2 ** (2048 - 64),
                    "g_a is too close to the range bounds"
                )
                self._check(
                    2 ** (2048 - 64) < g_b_int < dh_prime - 2 ** (2048 - 64),
                    "g_b is too close to the range bounds"
                )
                log.debug("g_a and g_b validation: OK")

                # https://core.telegram.org/mtproto/security_guidelines#checking-sha1-hash-values
                # Re-serialize with .write() to drop the random padding before hashing
                answer = server_dh_inner_data.write()
                self._check(answer_with_hash[:20] == sha1(answer).digest(), "SHA1 hash mismatch")
                log.debug("SHA1 hash values check: OK")

                # https://core.telegram.org/mtproto/security_guidelines#checking-nonce-server-nonce-and-new-nonce-fields
                # 1st message
                self._check(nonce == res_pq.nonce, "nonce mismatch in res_pq")
                # 2nd message
                self._check(nonce == server_dh_params.nonce, "nonce mismatch in server_DH_params")
                self._check(
                    server_nonce == server_dh_params.server_nonce,
                    "server_nonce mismatch in server_DH_params"
                )
                # 3rd message
                self._check(
                    nonce == set_client_dh_params_answer.nonce,
                    "nonce mismatch in set_client_DH_params answer"
                )
                self._check(
                    server_nonce == set_client_dh_params_answer.server_nonce,
                    "server_nonce mismatch in set_client_DH_params answer"
                )
                log.debug("Nonce fields check: OK")

                # Step 9
                server_salt = AES.xor(new_nonce_bytes[:8], server_nonce_bytes[:8])

                log.debug("Server salt: {}".format(int.from_bytes(server_salt, "little")))

                log.info(
                    "Done auth key exchange: {}".format(
                        set_client_dh_params_answer.__class__.__name__
                    )
                )
            except Exception as e:
                if retries_left:
                    retries_left -= 1
                else:
                    raise

                log.warning("Auth key creation failed. Let's try again: {}".format(repr(e)))
                time.sleep(1)
                continue
            else:
                return auth_key
            finally:
                self.connection.close()
class Session:
    """MTProto 2.0 encrypted session over one data-center connection.

    A RecvThread reads raw packets into ``recv_queue``; WORKERS worker threads
    decrypt and dispatch them. While connected, a PingThread and a
    NextSaltThread run in the background.
    """

    VERSION = __version__
    APP_VERSION = "Pyrogram \U0001f525 {}".format(VERSION)

    DEVICE_MODEL = "{} {}".format(platform.python_implementation(), platform.python_version())
    SYSTEM_VERSION = "{} {}".format(platform.system(), platform.release())

    # Placeholder salt for the very first Ping; start() replaces it with the
    # new_server_salt carried by the reply.
    INITIAL_SALT = 0x616e67656c696361

    WORKERS = 4
    WAIT_TIMEOUT = 10
    MAX_RETRIES = 5
    ACKS_THRESHOLD = 8
    PING_INTERVAL = 5

    notice_displayed = False

    def __init__(self, dc_id: int, test_mode: bool, proxy: type, auth_key: bytes, api_id: str, is_cdn: bool = False):
        if not Session.notice_displayed:
            print("Pyrogram v{}, {}".format(__version__, __copyright__))
            print("Licensed under the terms of the " + __license__, end="\n\n")
            Session.notice_displayed = True

        self.is_cdn = is_cdn
        self.connection = Connection(DataCenter(dc_id, test_mode), proxy)
        self.api_id = api_id
        self.auth_key = auth_key
        # Last 8 bytes of SHA1(auth_key) identify the key in every encrypted packet
        self.auth_key_id = sha1(auth_key).digest()[-8:]
        self.msg_id = MsgId()
        self.session_id = Long(self.msg_id())
        self.msg_factory = MsgFactory(self.msg_id)
        self.current_salt = None
        self.pending_acks = set()
        self.recv_queue = Queue()
        self.results = {}  # msg_id -> Result awaiting a reply
        self.worker_list = []  # Worker threads started by the current start()
        self.ping_thread = None
        self.ping_thread_event = Event()
        self.next_salt_thread = None
        self.next_salt_thread_event = Event()
        self.is_connected = Event()
        self.update_handler = None
        self.total_connections = 0
        self.total_messages = 0
        self.total_bytes = 0

    def start(self):
        """Connect and initialise the session, retrying on network errors.

        Returns the terms-of-service text from InitConnection, or None for CDN
        sessions. OSError/TimeoutError trigger a stop() and a new attempt; any
        other exception stops the session's threads and propagates.
        """
        terms = None

        while True:
            try:
                self.connection.connect()

                for i in range(self.WORKERS):
                    worker = Thread(target=self.worker, name="Worker#{}".format(i + 1))
                    self.worker_list.append(worker)
                    worker.start()

                Thread(target=self.recv, name="RecvThread").start()

                self.current_salt = FutureSalt(0, 0, self.INITIAL_SALT)
                self.current_salt = FutureSalt(0, 0, self._send(functions.Ping(0)).new_server_salt)
                self.current_salt = self._send(functions.GetFutureSalts(1)).salts[0]

                self.next_salt_thread = Thread(target=self.next_salt, name="NextSaltThread")
                self.next_salt_thread.start()

                if not self.is_cdn:
                    terms = self._send(
                        functions.InvokeWithLayer(
                            layer,
                            functions.InitConnection(
                                self.api_id,
                                self.DEVICE_MODEL,
                                self.SYSTEM_VERSION,
                                self.APP_VERSION,
                                "en",
                                "",
                                "en",
                                functions.help.GetTermsOfService(),
                            )
                        )
                    ).text

                self.ping_thread = Thread(target=self.ping, name="PingThread")
                self.ping_thread.start()

                log.info("Connection inited: Layer {}".format(layer))
            except (OSError, TimeoutError):
                self.stop()
            except Exception:
                # Don't leave worker/recv/salt threads running behind an unexpected error
                self.stop()
                raise
            else:
                break

        self.is_connected.set()
        self.total_connections += 1

        log.debug("Session started")

        return terms

    def stop(self):
        """Stop the background threads and close the connection."""
        self.is_connected.clear()

        self.ping_thread_event.set()
        self.next_salt_thread_event.set()

        if self.ping_thread is not None:
            self.ping_thread.join()

        if self.next_salt_thread is not None:
            self.next_salt_thread.join()

        self.ping_thread_event.clear()
        self.next_salt_thread_event.clear()

        self.connection.close()

        # One sentinel per worker actually started; wait for them to exit and drop
        # anything left, so stale sentinels can't kill the next start()'s workers.
        for _ in self.worker_list:
            self.recv_queue.put(None)

        current = threading.current_thread()

        for worker in self.worker_list:
            if worker is not current:  # A worker cannot join itself
                worker.join()

        self.worker_list.clear()
        self.recv_queue.queue.clear()

        log.debug("Session stopped")

    def restart(self):
        self.stop()
        self.start()

    def pack(self, message: Message):
        """Encrypt *message*: auth_key_id + msg_key + AES-IGE(salt + session_id + message + padding)."""
        data = Long(self.current_salt.salt) + self.session_id + message.write()
        # MTProto 2.0 requires 12..1024 random padding bytes and a total length divisible by 16
        # (the AES block size); this uses the smallest such padding (12..27 bytes).
        padding = urandom(-(len(data) + 12) % 16 + 12)

        # 88 = 88 + 0 (outgoing message)
        msg_key_large = sha256(self.auth_key[88:88 + 32] + data + padding).digest()
        msg_key = msg_key_large[8:24]
        aes_key, aes_iv = KDF(self.auth_key, msg_key, True)

        return self.auth_key_id + msg_key + IGE.encrypt(data + padding, aes_key, aes_iv)

    def unpack(self, b: BytesIO) -> Message:
        """Decrypt and verify an incoming packet.

        Raises AssertionError when a security check fails. The checks are explicit
        (not ``assert``) because the reads they perform must also happen under ``python -O``.
        """
        if b.read(8) != self.auth_key_id:
            raise AssertionError(b.getvalue())

        msg_key = b.read(16)
        aes_key, aes_iv = KDF(self.auth_key, msg_key, False)
        data = BytesIO(IGE.decrypt(b.read(), aes_key, aes_iv))
        data.read(8)  # Skip the salt (plaintext layout mirrors pack(): salt + session_id + message)

        # https://core.telegram.org/mtproto/security_guidelines#checking-session-id
        if data.read(8) != self.session_id:
            raise AssertionError("Session ID mismatch")

        message = Message.read(data)

        # https://core.telegram.org/mtproto/security_guidelines#checking-sha256-hash-value-of-msg-key
        # https://core.telegram.org/mtproto/security_guidelines#checking-message-length
        # 96 = 88 + 8 (incoming message)
        if msg_key != sha256(self.auth_key[96:96 + 32] + data.getvalue()).digest()[8:24]:
            raise AssertionError("msg_key mismatch")

        # https://core.telegram.org/mtproto/security_guidelines#checking-msg-id
        # TODO: check for lower msg_ids
        if message.msg_id % 2 == 0:
            raise AssertionError("Server msg_id must be odd")

        return message

    def worker(self):
        """Process queued packets until a None sentinel arrives; errors are logged, not raised."""
        name = threading.current_thread().name
        log.debug("{} started".format(name))

        while True:
            packet = self.recv_queue.get()

            if packet is None:
                break

            try:
                self.unpack_dispatch_and_ack(packet)
            except Exception as e:
                log.error(e, exc_info=True)

        log.debug("{} stopped".format(name))

    def unpack_dispatch_and_ack(self, packet: bytes):
        """Decrypt *packet*, route replies to waiting callers and updates to update_handler, and send acks."""
        # TODO: A better dispatcher
        data = self.unpack(BytesIO(packet))

        messages = (
            data.body.messages
            if isinstance(data.body, MsgContainer)
            else [data]
        )

        log.debug(data)

        self.total_bytes += len(packet)
        self.total_messages += len(messages)

        for i in messages:
            # Odd seq_no marks a message that needs acknowledging; skip ones already pending
            if i.seq_no % 2 != 0:
                if i.msg_id in self.pending_acks:
                    continue
                else:
                    self.pending_acks.add(i.msg_id)

            if isinstance(i.body, (types.MsgDetailedInfo, types.MsgNewDetailedInfo)):
                self.pending_acks.add(i.body.answer_msg_id)
                continue

            if isinstance(i.body, types.NewSessionCreated):
                continue

            msg_id = None

            if isinstance(i.body, (types.BadMsgNotification, types.BadServerSalt)):
                msg_id = i.body.bad_msg_id
            elif isinstance(i.body, (core.FutureSalts, types.RpcResult)):
                msg_id = i.body.req_msg_id
            elif isinstance(i.body, types.Pong):
                msg_id = i.body.msg_id
            else:
                if self.update_handler:
                    self.update_handler(i.body)

            # .get(): _send may pop the slot concurrently after its wait times out
            result = self.results.get(msg_id)

            if result is not None:
                result.value = getattr(i.body, "result", i.body)
                result.event.set()

        if len(self.pending_acks) >= self.ACKS_THRESHOLD:
            log.info("Send {} acks".format(len(self.pending_acks)))

            try:
                self._send(types.MsgsAck(list(self.pending_acks)), False)
            except (OSError, TimeoutError):
                pass
            else:
                self.pending_acks.clear()

    def ping(self):
        """Send a Ping every PING_INTERVAL seconds until ping_thread_event is set."""
        log.debug("PingThread started")

        while True:
            self.ping_thread_event.wait(self.PING_INTERVAL)

            if self.ping_thread_event.is_set():
                break

            try:
                self._send(functions.Ping(0), False)
            except (OSError, TimeoutError):
                pass

        log.debug("PingThread stopped")

    def next_salt(self):
        """Refresh current_salt shortly before it expires; closes the connection if that fails."""
        log.debug("NextSaltThread started")

        while True:
            now = datetime.now()

            # Seconds to wait until middle-overlap, which is
            # 15 minutes before/after the current/next salt end/start time
            dt = (self.current_salt.valid_until - now).total_seconds() - 900

            log.debug("Current salt: {} | Next salt in {:.0f}m {:.0f}s ({})".format(
                self.current_salt.salt,
                dt // 60,
                dt % 60,
                now + timedelta(seconds=dt)
            ))

            self.next_salt_thread_event.wait(dt)

            if self.next_salt_thread_event.is_set():
                break

            try:
                self.current_salt = self._send(functions.GetFutureSalts(1)).salts[0]
            except (OSError, TimeoutError):
                self.connection.close()
                break

        log.debug("NextSaltThread stopped")

    def recv(self):
        """Read packets into recv_queue; restart the session on EOF or a -404 transport error."""
        log.debug("RecvThread started")

        while True:
            packet = self.connection.recv()

            if packet is None or (len(packet) == 4 and Int.read(BytesIO(packet)) == -404):
                if self.is_connected.is_set():
                    Thread(target=self.restart, name="RestartThread").start()
                break

            self.recv_queue.put(packet)

        log.debug("RecvThread stopped")

    def _send(self, data: Object, wait_response: bool = True):
        """Encrypt and send *data*; optionally wait up to WAIT_TIMEOUT seconds for the reply.

        Raises TimeoutError when no reply arrives; RpcError replies are raised via Error.raise_it.
        """
        message = self.msg_factory(data)
        msg_id = message.msg_id
        # Encrypt first, so a failure here cannot leave an orphaned slot in self.results
        payload = self.pack(message)

        if wait_response:
            # Register before sending so the dispatcher can match even a very fast reply
            self.results[msg_id] = Result()

        try:
            self.connection.send(payload)
        except OSError:
            self.results.pop(msg_id, None)
            raise

        if wait_response:
            self.results[msg_id].event.wait(self.WAIT_TIMEOUT)
            result = self.results.pop(msg_id).value

            if result is None:
                raise TimeoutError
            elif isinstance(result, types.RpcError):
                Error.raise_it(result, type(data))
            else:
                return result

    def send(self, data: Object):
        """Send *data* once connected, retrying up to MAX_RETRIES times on OSError/TimeoutError.

        Returns the reply, or None if every attempt failed.
        """
        for _ in range(self.MAX_RETRIES):
            self.is_connected.wait()

            try:
                return self._send(data)
            except (OSError, TimeoutError):
                log.warning("Retrying {}".format(type(data)))
                continue

        return None
class Session:
    """MTProto 2.0 encrypted session over one data-center connection.

    A RecvThread reads raw packets into ``recv_queue``; NET_WORKERS threads
    decrypt and dispatch them. While connected, a PingThread and a
    NextSaltThread run in the background.
    """

    VERSION = __version__
    APP_VERSION = "Pyrogram \U0001f525 {}".format(VERSION)

    DEVICE_MODEL = "{} {}".format(
        platform.python_implementation(),
        platform.python_version()
    )

    SYSTEM_VERSION = "{} {}".format(
        platform.system(),
        platform.release()
    )

    # Placeholder salt for the very first Ping; start() replaces it with the
    # new_server_salt carried by the reply.
    INITIAL_SALT = 0x616e67656c696361

    NET_WORKERS = 1
    WAIT_TIMEOUT = 30
    MAX_RETRIES = 5
    ACKS_THRESHOLD = 8
    PING_INTERVAL = 5

    notice_displayed = False

    # Human-readable texts for BadMsgNotification error codes, used by _send
    BAD_MSG_DESCRIPTION = {
        16: "[16] msg_id too low, the client time has to be synchronized",
        17: "[17] msg_id too high, the client time has to be synchronized",
        18: "[18] incorrect two lower order msg_id bits, the server expects client message msg_id to be divisible by 4",
        19: "[19] container msg_id is the same as msg_id of a previously received message",
        20: "[20] message too old, it cannot be verified by the server",
        32: "[32] msg_seqno too low",
        33: "[33] msg_seqno too high",
        34: "[34] an even msg_seqno expected, but odd received",
        35: "[35] odd msg_seqno expected, but even received",
        48: "[48] incorrect server salt",
        64: "[64] invalid container"
    }

    def __init__(self,
                 dc_id: int,
                 test_mode: bool,
                 proxy: type,
                 auth_key: bytes,
                 api_id: str,
                 is_cdn: bool = False,
                 client: pyrogram = None):
        if not Session.notice_displayed:
            print("Pyrogram v{}, {}".format(__version__, __copyright__))
            print("Licensed under the terms of the " + __license__, end="\n\n")
            Session.notice_displayed = True

        self.connection = Connection(DataCenter(dc_id, test_mode), proxy)
        self.api_id = api_id
        self.is_cdn = is_cdn
        self.client = client  # Receives non-reply updates via client.updates_queue when set

        self.auth_key = auth_key
        # Last 8 bytes of SHA1(auth_key) identify the key in every encrypted packet
        self.auth_key_id = sha1(auth_key).digest()[-8:]

        self.session_id = Long(MsgId())
        self.msg_factory = MsgFactory()

        self.current_salt = None

        self.pending_acks = set()

        self.recv_queue = Queue()
        self.results = {}  # msg_id -> Result awaiting a reply
        self.net_worker_list = []  # Worker threads started by the current start()

        self.ping_thread = None
        self.ping_thread_event = Event()

        self.next_salt_thread = None
        self.next_salt_thread_event = Event()

        self.is_connected = Event()

    def start(self):
        """Connect and initialise the session, retrying on network and RPC errors.

        OSError, TimeoutError and Error trigger a stop() and a new attempt; any
        other exception stops the session's threads and propagates.
        """
        while True:
            try:
                self.connection.connect()

                for i in range(self.NET_WORKERS):
                    worker = Thread(target=self.net_worker, name="NetWorker#{}".format(i + 1))
                    self.net_worker_list.append(worker)
                    worker.start()

                Thread(target=self.recv, name="RecvThread").start()

                self.current_salt = FutureSalt(0, 0, self.INITIAL_SALT)
                self.current_salt = FutureSalt(0, 0, self._send(functions.Ping(0)).new_server_salt)
                self.current_salt = self._send(functions.GetFutureSalts(1)).salts[0]

                self.next_salt_thread = Thread(target=self.next_salt, name="NextSaltThread")
                self.next_salt_thread.start()

                if not self.is_cdn:
                    self._send(
                        functions.InvokeWithLayer(
                            layer,
                            functions.InitConnection(
                                self.api_id,
                                self.DEVICE_MODEL,
                                self.SYSTEM_VERSION,
                                self.APP_VERSION,
                                "en",
                                "",
                                "en",
                                functions.help.GetConfig(),
                            )
                        )
                    )

                self.ping_thread = Thread(target=self.ping, name="PingThread")
                self.ping_thread.start()

                log.info("Connection inited: Layer {}".format(layer))
            except (OSError, TimeoutError, Error):
                self.stop()
            except Exception:
                # Don't leave worker/recv/salt threads running behind an unexpected error
                self.stop()
                raise
            else:
                break

        self.is_connected.set()

        log.debug("Session started")

    def stop(self):
        """Stop the background threads, close the connection and wake every waiting _send."""
        self.is_connected.clear()

        self.ping_thread_event.set()
        self.next_salt_thread_event.set()

        if self.ping_thread is not None:
            self.ping_thread.join()

        if self.next_salt_thread is not None:
            self.next_salt_thread.join()

        self.ping_thread_event.clear()
        self.next_salt_thread_event.clear()

        self.connection.close()

        # One sentinel per worker actually started; wait for them to exit and drop
        # anything left, so stale sentinels can't kill the next start()'s workers.
        for _ in self.net_worker_list:
            self.recv_queue.put(None)

        current = threading.current_thread()

        for worker in self.net_worker_list:
            if worker is not current:  # A worker cannot join itself
                worker.join()

        self.net_worker_list.clear()
        self.recv_queue.queue.clear()

        # Snapshot: other threads may add/pop results while we iterate. A woken
        # _send sees value None and raises TimeoutError.
        for result in list(self.results.values()):
            result.event.set()

        log.debug("Session stopped")

    def restart(self):
        self.stop()
        self.start()

    def pack(self, message: Message):
        """Encrypt *message*: auth_key_id + msg_key + AES-IGE(salt + session_id + message + padding)."""
        data = Long(self.current_salt.salt) + self.session_id + message.write()
        # 12..27 random bytes: at least 12, total length a multiple of 16 (AES block size)
        padding = urandom(-(len(data) + 12) % 16 + 12)

        # 88 = 88 + 0 (outgoing message)
        msg_key_large = sha256(self.auth_key[88: 88 + 32] + data + padding).digest()
        msg_key = msg_key_large[8:24]
        aes_key, aes_iv = KDF(self.auth_key, msg_key, True)

        return self.auth_key_id + msg_key + AES.ige_encrypt(data + padding, aes_key, aes_iv)

    def unpack(self, b: BytesIO) -> Message:
        """Decrypt and verify an incoming packet.

        Raises AssertionError when a security check fails. The checks are explicit
        (not ``assert``) because the reads they perform must also happen under ``python -O``.
        """
        if b.read(8) != self.auth_key_id:
            raise AssertionError(b.getvalue())

        msg_key = b.read(16)
        aes_key, aes_iv = KDF(self.auth_key, msg_key, False)
        data = BytesIO(AES.ige_decrypt(b.read(), aes_key, aes_iv))
        data.read(8)  # Skip the salt (plaintext layout mirrors pack(): salt + session_id + message)

        # https://core.telegram.org/mtproto/security_guidelines#checking-session-id
        if data.read(8) != self.session_id:
            raise AssertionError("Session ID mismatch")

        message = Message.read(data)

        # https://core.telegram.org/mtproto/security_guidelines#checking-sha256-hash-value-of-msg-key
        # https://core.telegram.org/mtproto/security_guidelines#checking-message-length
        # 96 = 88 + 8 (incoming message)
        if msg_key != sha256(self.auth_key[96:96 + 32] + data.getvalue()).digest()[8:24]:
            raise AssertionError("msg_key mismatch")

        # https://core.telegram.org/mtproto/security_guidelines#checking-msg-id
        # TODO: check for lower msg_ids
        if message.msg_id % 2 == 0:
            raise AssertionError("Server msg_id must be odd")

        return message

    def net_worker(self):
        """Process queued packets until a None sentinel arrives; errors are logged, not raised."""
        name = threading.current_thread().name
        log.debug("{} started".format(name))

        while True:
            packet = self.recv_queue.get()

            if packet is None:
                break

            try:
                self.unpack_dispatch_and_ack(packet)
            except Exception as e:
                log.error(e, exc_info=True)

        log.debug("{} stopped".format(name))

    def unpack_dispatch_and_ack(self, packet: bytes):
        """Decrypt *packet*, route replies to waiting callers and updates to the client, and send acks."""
        data = self.unpack(BytesIO(packet))

        messages = (
            data.body.messages
            if isinstance(data.body, MsgContainer)
            else [data]
        )

        log.debug(data)

        for msg in messages:
            # Odd seq_no marks a message that needs acknowledging; skip ones already pending
            if msg.seq_no % 2 != 0:
                if msg.msg_id in self.pending_acks:
                    continue
                else:
                    self.pending_acks.add(msg.msg_id)

            if isinstance(msg.body, (types.MsgDetailedInfo, types.MsgNewDetailedInfo)):
                self.pending_acks.add(msg.body.answer_msg_id)
                continue

            if isinstance(msg.body, types.NewSessionCreated):
                continue

            msg_id = None

            if isinstance(msg.body, (types.BadMsgNotification, types.BadServerSalt)):
                msg_id = msg.body.bad_msg_id
            elif isinstance(msg.body, (core.FutureSalts, types.RpcResult)):
                msg_id = msg.body.req_msg_id
            elif isinstance(msg.body, types.Pong):
                msg_id = msg.body.msg_id
            else:
                if self.client is not None:
                    self.client.updates_queue.put(msg.body)

            # .get(): _send may pop the slot concurrently after its wait times out
            result = self.results.get(msg_id)

            if result is not None:
                result.value = getattr(msg.body, "result", msg.body)
                result.event.set()

        if len(self.pending_acks) >= self.ACKS_THRESHOLD:
            log.info("Send {} acks".format(len(self.pending_acks)))

            try:
                self._send(types.MsgsAck(list(self.pending_acks)), False)
            except (OSError, TimeoutError):
                pass
            else:
                self.pending_acks.clear()

    def ping(self):
        """Send PingDelayDisconnect every PING_INTERVAL seconds until ping_thread_event is set."""
        log.debug("PingThread started")

        while True:
            self.ping_thread_event.wait(self.PING_INTERVAL)

            if self.ping_thread_event.is_set():
                break

            try:
                self._send(functions.PingDelayDisconnect(0, self.PING_INTERVAL + 15), False)
            except (OSError, TimeoutError):
                pass

        log.debug("PingThread stopped")

    def next_salt(self):
        """Refresh current_salt shortly before it expires; closes the connection if that fails."""
        log.debug("NextSaltThread started")

        while True:
            now = datetime.now()

            # Seconds to wait until middle-overlap, which is
            # 15 minutes before/after the current/next salt end/start time
            dt = (self.current_salt.valid_until - now).total_seconds() - 900

            log.debug("Current salt: {} | Next salt in {:.0f}m {:.0f}s ({})".format(
                self.current_salt.salt,
                dt // 60,
                dt % 60,
                now + timedelta(seconds=dt)
            ))

            self.next_salt_thread_event.wait(dt)

            if self.next_salt_thread_event.is_set():
                break

            try:
                self.current_salt = self._send(functions.GetFutureSalts(1)).salts[0]
            except (OSError, TimeoutError):
                self.connection.close()
                break

        log.debug("NextSaltThread stopped")

    def recv(self):
        """Read packets into recv_queue; restart on EOF or on any 4-byte transport error packet."""
        log.debug("RecvThread started")

        while True:
            packet = self.connection.recv()

            if packet is None or len(packet) == 4:
                if packet:
                    log.warning("Server sent \"{}\"".format(Int.read(BytesIO(packet))))

                if self.is_connected.is_set():
                    Thread(target=self.restart, name="RestartThread").start()
                break

            self.recv_queue.put(packet)

        log.debug("RecvThread stopped")

    def _send(self, data: Object, wait_response: bool = True):
        """Encrypt and send *data*; optionally wait up to WAIT_TIMEOUT seconds for the reply.

        Raises TimeoutError when no reply arrives (or the session was stopped),
        RpcError replies via Error.raise_it, and a generic Exception carrying
        the BAD_MSG_DESCRIPTION text for BadMsgNotification replies.
        """
        message = self.msg_factory(data)
        msg_id = message.msg_id
        # Encrypt first, so a failure here cannot leave an orphaned slot in self.results
        payload = self.pack(message)

        if wait_response:
            # Register before sending so the dispatcher can match even a very fast reply
            self.results[msg_id] = Result()

        try:
            self.connection.send(payload)
        except OSError:
            self.results.pop(msg_id, None)
            raise

        if wait_response:
            self.results[msg_id].event.wait(self.WAIT_TIMEOUT)
            result = self.results.pop(msg_id).value

            if result is None:
                raise TimeoutError
            elif isinstance(result, types.RpcError):
                Error.raise_it(result, type(data))
            elif isinstance(result, types.BadMsgNotification):
                raise Exception(self.BAD_MSG_DESCRIPTION.get(
                    result.error_code,
                    "Error code {}".format(result.error_code)
                ))
            else:
                return result

    def send(self, data: Object):
        """Send *data* once connected, retrying up to MAX_RETRIES times on OSError/TimeoutError.

        Returns the reply, or None if every attempt failed.
        """
        for i in range(self.MAX_RETRIES):
            self.is_connected.wait()

            try:
                return self._send(data)
            except (OSError, TimeoutError):
                # First retry is logged at info level, later ones as warnings
                (log.warning if i > 0 else log.info)("{}: {} Retrying {}".format(i, datetime.now(), type(data)))
                continue

        return None
class Session:
    """An MTProto session bound to one DC and one auth key.

    Owns the connection plus a set of background threads: ``NET_WORKERS``
    packet workers, a receive thread, a ping thread and a salt-refresh thread.
    Requests go out through ``send``/``_send``, and responses are matched back
    to their ``msg_id`` via ``self.results``.
    """

    INITIAL_SALT = 0x616e67656c696361
    NET_WORKERS = 1
    START_TIMEOUT = 1
    WAIT_TIMEOUT = 15
    SLEEP_THRESHOLD = 60
    MAX_RETRIES = 5
    ACKS_THRESHOLD = 8
    PING_INTERVAL = 5

    # Printed once per process, on the first Session construction.
    notice_displayed = False

    BAD_MSG_DESCRIPTION = {
        16: "[16] msg_id too low, the client time has to be synchronized",
        17: "[17] msg_id too high, the client time has to be synchronized",
        18: "[18] incorrect two lower order msg_id bits, the server expects client message msg_id to be divisible by 4",
        19: "[19] container msg_id is the same as msg_id of a previously received message",
        20: "[20] message too old, it cannot be verified by the server",
        32: "[32] msg_seqno too low",
        33: "[33] msg_seqno too high",
        34: "[34] an even msg_seqno expected, but odd received",
        35: "[35] odd msg_seqno expected, but even received",
        48: "[48] incorrect server salt",
        64: "[64] invalid container"
    }

    def __init__(self, client: "pyrogram.Client", dc_id: int, auth_key: bytes, is_media: bool = False,
                 is_cdn: bool = False):
        """Set up session state. No network activity happens until ``start``.

        Args:
            client: Owning client; its storage, proxy, device info and
                ``updates_queue`` are read by this session.
            dc_id: Data center to connect to.
            auth_key: Authorization key; its SHA-1 tail is the auth_key_id.
            is_media: Media sessions do not invoke ``client.disconnect_handler`` on stop.
            is_cdn: CDN sessions skip the ``InitConnection`` handshake.
        """
        if not Session.notice_displayed:
            print("Pyrogram v{}, {}".format(__version__, __copyright__))
            print("Licensed under the terms of the " + __license__, end="\n\n")
            Session.notice_displayed = True

        self.client = client
        self.dc_id = dc_id
        self.auth_key = auth_key
        self.is_media = is_media
        self.is_cdn = is_cdn

        self.connection = None

        self.auth_key_id = sha1(auth_key).digest()[-8:]

        self.session_id = Long(MsgId())
        self.msg_factory = MsgFactory()

        self.current_salt = None

        self.pending_acks = set()

        self.recv_queue = Queue()
        # msg_id -> Result, for requests that are awaiting a response.
        self.results = {}

        self.ping_thread = None
        self.ping_thread_event = Event()

        self.next_salt_thread = None
        self.next_salt_thread_event = Event()

        self.net_worker_list = []

        self.is_connected = Event()

    def start(self):
        """Connect, spawn the worker threads and perform the initial handshake.

        Retries indefinitely on OSError, TimeoutError and RPCError.
        Re-raises AuthKeyDuplicated and any other exception after stopping.
        """
        while True:
            self.connection = Connection(self.dc_id, self.client.storage.test_mode(),
                                         self.client.ipv6, self.client.proxy)

            try:
                self.connection.connect()

                for i in range(self.NET_WORKERS):
                    self.net_worker_list.append(
                        Thread(target=self.net_worker, name="NetWorker#{}".format(i + 1)))
                    self.net_worker_list[-1].start()

                Thread(target=self.recv, name="RecvThread").start()

                # Send a Ping with a placeholder salt. The reply that net_worker
                # routes back for this msg_id exposes new_server_salt, which is
                # then used to fetch a proper future salt.
                self.current_salt = FutureSalt(0, 0, self.INITIAL_SALT)
                self.current_salt = FutureSalt(
                    0, 0,
                    self._send(functions.Ping(ping_id=0), timeout=self.START_TIMEOUT).new_server_salt)
                self.current_salt = self._send(
                    functions.GetFutureSalts(num=1), timeout=self.START_TIMEOUT).salts[0]

                self.next_salt_thread = Thread(target=self.next_salt, name="NextSaltThread")
                self.next_salt_thread.start()

                if not self.is_cdn:
                    self._send(
                        functions.InvokeWithLayer(
                            layer=layer,
                            query=functions.InitConnection(
                                api_id=self.client.api_id,
                                app_version=self.client.app_version,
                                device_model=self.client.device_model,
                                system_version=self.client.system_version,
                                system_lang_code=self.client.lang_code,
                                lang_code=self.client.lang_code,
                                lang_pack="",
                                query=functions.help.GetConfig(),
                            )),
                        timeout=self.START_TIMEOUT)

                self.ping_thread = Thread(target=self.ping, name="PingThread")
                self.ping_thread.start()

                log.info("Session initialized: Layer {}".format(layer))
                log.info("Device: {} - {}".format(self.client.device_model, self.client.app_version))
                log.info("System: {} ({})".format(
                    self.client.system_version, self.client.lang_code.upper()))

            except AuthKeyDuplicated as e:
                self.stop()
                raise e
            except (OSError, TimeoutError, RPCError):
                # Transient failure: tear everything down and try again.
                self.stop()
            except Exception as e:
                self.stop()
                raise e
            else:
                break

        self.is_connected.set()

        log.debug("Session started")

    def stop(self):
        """Stop all background threads, close the connection and wake any waiters."""
        self.is_connected.clear()

        self.ping_thread_event.set()
        self.next_salt_thread_event.set()

        if self.ping_thread is not None:
            self.ping_thread.join()

        if self.next_salt_thread is not None:
            self.next_salt_thread.join()

        self.ping_thread_event.clear()
        self.next_salt_thread_event.clear()

        self.connection.close()

        # One None sentinel per worker makes each net_worker loop exit.
        for _ in range(self.NET_WORKERS):
            self.recv_queue.put(None)

        for worker in self.net_worker_list:
            worker.join()

        self.net_worker_list.clear()

        self.recv_queue.queue.clear()

        # Iterate over a snapshot. Setting an event wakes a _send waiter, which
        # then pops from self.results; mutating the live dict mid-iteration
        # could raise "dictionary changed size during iteration".
        for result in list(self.results.values()):
            result.event.set()

        if not self.is_media and callable(self.client.disconnect_handler):
            try:
                self.client.disconnect_handler(self.client)
            except Exception as e:
                log.error(e, exc_info=True)

        log.debug("Session stopped")

    def restart(self):
        """Stop and start the session again."""
        self.stop()
        self.start()

    def pack(self, message: Message):
        """Serialize and encrypt an outgoing message.

        Returns:
            auth_key_id + msg_key + AES-IGE(salt + session_id + message + padding).
        """
        data = Long(self.current_salt.salt) + self.session_id + message.write()
        # 12..27 random bytes: at least 12, and enough to reach a multiple of 16.
        padding = urandom(-(len(data) + 12) % 16 + 12)

        # 88 = 88 + 0 (outgoing message)
        msg_key_large = sha256(self.auth_key[88: 88 + 32] + data + padding).digest()
        msg_key = msg_key_large[8:24]
        aes_key, aes_iv = KDF(self.auth_key, msg_key, True)

        return self.auth_key_id + msg_key + AES.ige256_encrypt(data + padding, aes_key, aes_iv)

    def unpack(self, b: BytesIO) -> Message:
        """Decrypt and validate an incoming packet.

        Raises:
            AssertionError: if auth_key_id, session_id, msg_key or msg_id
                parity checks fail.
        """
        # Explicit raises instead of ``assert``: ``python -O`` strips asserts
        # together with their side effects. That would skip the read() calls
        # below, desynchronize the stream and drop mandatory security checks.
        if b.read(8) != self.auth_key_id:
            raise AssertionError(b.getvalue())

        msg_key = b.read(16)
        aes_key, aes_iv = KDF(self.auth_key, msg_key, False)
        data = BytesIO(AES.ige256_decrypt(b.read(), aes_key, aes_iv))
        data.read(8)  # Salt field (mirrors pack's layout); not validated here.

        # https://core.telegram.org/mtproto/security_guidelines#checking-session-id
        if data.read(8) != self.session_id:
            raise AssertionError("session_id mismatch")

        message = Message.read(data)

        # https://core.telegram.org/mtproto/security_guidelines#checking-sha256-hash-value-of-msg-key
        # https://core.telegram.org/mtproto/security_guidelines#checking-message-length
        # 96 = 88 + 8 (incoming message)
        if msg_key != sha256(self.auth_key[96:96 + 32] + data.getvalue()).digest()[8:24]:
            raise AssertionError("msg_key mismatch")

        # https://core.telegram.org/mtproto/security_guidelines#checking-msg-id
        # TODO: check for lower msg_ids
        if message.msg_id % 2 == 0:
            raise AssertionError("server msg_id must be odd")

        return message

    def net_worker(self):
        """Decrypt queued packets and dispatch their messages.

        Replies are routed to ``self.results`` waiters. Other bodies go to
        ``client.updates_queue``. Acks are batched in ``pending_acks``.
        """
        name = threading.current_thread().name
        log.debug("{} started".format(name))

        while True:
            packet = self.recv_queue.get()

            if packet is None:
                break

            try:
                data = self.unpack(BytesIO(packet))

                messages = (data.body.messages if isinstance(data.body, MsgContainer) else [data])

                log.debug("Received:\n{}".format(data))

                for msg in messages:
                    # Odd seq_no marks a message that must be acknowledged; a
                    # msg_id already pending ack is treated as a duplicate.
                    if msg.seq_no % 2 != 0:
                        if msg.msg_id in self.pending_acks:
                            continue
                        else:
                            self.pending_acks.add(msg.msg_id)

                    if isinstance(msg.body, (types.MsgDetailedInfo, types.MsgNewDetailedInfo)):
                        self.pending_acks.add(msg.body.answer_msg_id)
                        continue

                    if isinstance(msg.body, types.NewSessionCreated):
                        continue

                    msg_id = None

                    if isinstance(msg.body, (types.BadMsgNotification, types.BadServerSalt)):
                        msg_id = msg.body.bad_msg_id
                    elif isinstance(msg.body, (core.FutureSalts, types.RpcResult)):
                        msg_id = msg.body.req_msg_id
                    elif isinstance(msg.body, types.Pong):
                        msg_id = msg.body.msg_id
                    else:
                        if self.client is not None:
                            self.client.updates_queue.put(msg.body)

                    # Single lookup: _send may pop the entry concurrently after
                    # its wait times out, so "in" followed by [] could KeyError.
                    result = self.results.get(msg_id)

                    if result is not None:
                        result.value = getattr(msg.body, "result", msg.body)
                        result.event.set()

                if len(self.pending_acks) >= self.ACKS_THRESHOLD:
                    log.info("Send {} acks".format(len(self.pending_acks)))

                    try:
                        self._send(types.MsgsAck(msg_ids=list(self.pending_acks)), False)
                    except (OSError, TimeoutError):
                        pass
                    else:
                        self.pending_acks.clear()
            except Exception as e:
                log.error(e, exc_info=True)

        log.debug("{} stopped".format(name))

    def ping(self):
        """Every PING_INTERVAL seconds, send PingDelayDisconnect until stopped."""
        log.debug("PingThread started")

        while True:
            self.ping_thread_event.wait(self.PING_INTERVAL)

            if self.ping_thread_event.is_set():
                break

            try:
                self._send(
                    functions.PingDelayDisconnect(
                        ping_id=0, disconnect_delay=self.WAIT_TIMEOUT + 10),
                    False)
            except (OSError, TimeoutError, RPCError):
                pass

        log.debug("PingThread stopped")

    def next_salt(self):
        """Refresh ``current_salt`` shortly before it expires.

        On failure, closes the connection and exits.
        """
        log.debug("NextSaltThread started")

        while True:
            now = datetime.now()

            # Seconds to wait until middle-overlap, which is
            # 15 minutes before/after the current/next salt end/start time
            valid_until = datetime.fromtimestamp(self.current_salt.valid_until)
            dt = (valid_until - now).total_seconds() - 900

            log.debug("Current salt: {} | Next salt in {:.0f}m {:.0f}s ({})".format(
                self.current_salt.salt, dt // 60, dt % 60, now + timedelta(seconds=dt)))

            self.next_salt_thread_event.wait(dt)

            if self.next_salt_thread_event.is_set():
                break

            try:
                self.current_salt = self._send(functions.GetFutureSalts(num=1)).salts[0]
            except (OSError, TimeoutError, RPCError):
                self.connection.close()
                break

        log.debug("NextSaltThread stopped")

    def recv(self):
        """Read packets from the connection into ``recv_queue``.

        A None packet or a 4-byte packet ends the loop. If the session was
        connected at that point, a restart is triggered on a separate thread.
        """
        log.debug("RecvThread started")

        while True:
            packet = self.connection.recv()

            if packet is None or len(packet) == 4:
                if packet:
                    log.warning("Server sent \"{}\"".format(Int.read(BytesIO(packet))))

                if self.is_connected.is_set():
                    Thread(target=self.restart, name="RestartThread").start()
                break

            self.recv_queue.put(packet)

        log.debug("RecvThread stopped")

    def _send(self, data: TLObject, wait_response: bool = True, timeout: float = WAIT_TIMEOUT):
        """Send one request and optionally wait up to ``timeout`` seconds for its reply.

        Raises:
            TimeoutError: no response value was stored before the wait ended.
            RPCError: via ``RPCError.raise_it`` for RpcError responses.
            Exception: for BadMsgNotification responses.
            OSError: if the connection send fails.
        """
        message = self.msg_factory(data)
        msg_id = message.msg_id

        # Pack before registering the waiter so that a failing pack() cannot
        # leave an orphaned entry in self.results. Registration still happens
        # before the network send, so the reply cannot be missed.
        payload = self.pack(message)

        if wait_response:
            self.results[msg_id] = Result()

        log.debug("Sent:\n{}".format(message))

        try:
            self.connection.send(payload)
        except OSError as e:
            self.results.pop(msg_id, None)
            raise e

        if wait_response:
            self.results[msg_id].event.wait(timeout)
            result = self.results.pop(msg_id).value

            if result is None:
                raise TimeoutError
            elif isinstance(result, types.RpcError):
                if isinstance(data, (functions.InvokeWithoutUpdates, functions.InvokeWithTakeout)):
                    data = data.query

                RPCError.raise_it(result, type(data))
            elif isinstance(result, types.BadMsgNotification):
                raise Exception(self.BAD_MSG_DESCRIPTION.get(
                    result.error_code,
                    "Error code {}".format(result.error_code)))
            else:
                return result

    def send(self, data: TLObject, retries: int = MAX_RETRIES, timeout: float = WAIT_TIMEOUT,
             sleep_threshold: float = SLEEP_THRESHOLD):
        """Send a request with FloodWait handling and transient-error retries.

        Args:
            data: The request to invoke.
            retries: Remaining retries for OSError, TimeoutError and InternalServerError.
            timeout: Per-attempt response timeout, in seconds.
            sleep_threshold: FloodWait waits up to this many seconds are slept
                through; longer ones are re-raised.
        """
        self.is_connected.wait(self.WAIT_TIMEOUT)

        if isinstance(data, (functions.InvokeWithoutUpdates, functions.InvokeWithTakeout)):
            query = data.query
        else:
            query = data

        query = ".".join(query.QUALNAME.split(".")[1:])

        while True:
            try:
                return self._send(data, timeout=timeout)
            except FloodWait as e:
                amount = e.x

                if amount > sleep_threshold:
                    raise

                log.warning('[{}] Sleeping for {}s (required by "{}")'.format(
                    self.client.session_name, amount, query))

                time.sleep(amount)
            except (OSError, TimeoutError, InternalServerError) as e:
                if retries == 0:
                    raise e from None

                (log.warning if retries < 2 else log.info)(
                    '[{}] Retrying "{}" due to {}'.format(
                        Session.MAX_RETRIES - retries + 1,
                        query, e))

                time.sleep(0.5)

                # Forward sleep_threshold so the caller's FloodWait policy
                # still applies on retried attempts.
                return self.send(data, retries - 1, timeout, sleep_threshold)