def __init__(self, database: Database, db_conn, hs):
    # We instantiate this first as the ReceiptsWorkerStore constructor
    # needs to be able to call get_max_receipt_stream_id
    self._receipts_id_gen = StreamIdGenerator(
        db_conn, "receipts_linearized", "stream_id"
    )

    super(ReceiptsStore, self).__init__(database, db_conn, hs)
def __init__(
    self,
    database: DatabasePool,
    db_conn: LoggingDatabaseConnection,
    hs: "HomeServer",
):
    super().__init__(database, db_conn, hs)

    self._cross_signing_id_gen = StreamIdGenerator(
        db_conn, "e2e_cross_signing_keys", "stream_id"
    )
def __init__(self, database: Database, db_conn, hs):
    self._account_data_id_gen = StreamIdGenerator(
        db_conn,
        "account_data_max_stream_id",
        "stream_id",
        extra_tables=[
            ("room_account_data", "stream_id"),
            ("room_tags_revisions", "stream_id"),
        ],
    )

    super(AccountDataStore, self).__init__(database, db_conn, hs)
def __init__(
    self,
    database: DatabasePool,
    db_conn: LoggingDatabaseConnection,
    hs: "HomeServer",
):
    super().__init__(database, db_conn, hs)

    if hs.config.worker.worker_app is None:
        self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
            db_conn, "push_rules_stream", "stream_id"
        )
    else:
        self._push_rules_stream_id_gen = SlavedIdTracker(
            db_conn, "push_rules_stream", "stream_id"
        )

    push_rules_prefill, push_rules_id = self.db_pool.get_cache_dict(
        db_conn,
        "push_rules_stream",
        entity_column="user_id",
        stream_column="stream_id",
        max_value=self.get_max_push_rules_stream_id(),
    )

    self.push_rules_stream_cache = StreamChangeCache(
        "PushRulesStreamChangeCache",
        push_rules_id,
        prefilled_cache=push_rules_prefill,
    )
def __init__(self, database: DatabasePool, db_conn, hs):
    super().__init__(database, db_conn, hs)

    if hs.config.worker.worker_app is None:
        self._push_rules_stream_id_gen = StreamIdGenerator(
            db_conn, "push_rules_stream", "stream_id"
        )  # type: Union[StreamIdGenerator, SlavedIdTracker]
    else:
        self._push_rules_stream_id_gen = SlavedIdTracker(
            db_conn, "push_rules_stream", "stream_id"
        )

    push_rules_prefill, push_rules_id = self.db_pool.get_cache_dict(
        db_conn,
        "push_rules_stream",
        entity_column="user_id",
        stream_column="stream_id",
        max_value=self.get_max_push_rules_stream_id(),
    )

    self.push_rules_stream_cache = StreamChangeCache(
        "PushRulesStreamChangeCache",
        push_rules_id,
        prefilled_cache=push_rules_prefill,
    )

    self._users_new_default_push_rules = hs.config.users_new_default_push_rules
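As a rough illustration of how the prefilled push rules cache set up above is typically consulted (so most requests can skip the database), here is a minimal sketch. The method name and the `_load_push_rules_from_db` helper are hypothetical; only `has_entity_changed` and `get_max_push_rules_stream_id` are taken from the surrounding code.

async def _sketch_get_push_rules_if_changed(self, user_id: str, since_token: int):
    # If the stream change cache says nothing changed for this user since
    # `since_token`, the caller can keep reusing whatever it already has.
    if not self.push_rules_stream_cache.has_entity_changed(user_id, since_token):
        return None

    # Otherwise fall back to the database (hypothetical helper) and report
    # the stream position the rules are valid up to.
    rules = await self._load_push_rules_from_db(user_id)  # hypothetical
    return rules, self.get_max_push_rules_stream_id()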
def __init__(
    self,
    database: DatabasePool,
    db_conn: LoggingDatabaseConnection,
    hs: "HomeServer",
):
    super().__init__(database, db_conn, hs)
    self._pushers_id_gen = StreamIdGenerator(
        db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
    )

    self.db_pool.updates.register_background_update_handler(
        "remove_deactivated_pushers",
        self._remove_deactivated_pushers,
    )

    self.db_pool.updates.register_background_update_handler(
        "remove_stale_pushers",
        self._remove_stale_pushers,
    )

    self.db_pool.updates.register_background_update_handler(
        "remove_deleted_email_pushers",
        self._remove_deleted_email_pushers,
    )
def __init__(
    self,
    database: DatabasePool,
    db_conn: LoggingDatabaseConnection,
    hs: "HomeServer",
):
    self._instance_name = hs.get_instance_name()
    self._receipts_id_gen: AbstractStreamIdTracker

    if isinstance(database.engine, PostgresEngine):
        self._can_write_to_receipts = (
            self._instance_name in hs.config.worker.writers.receipts
        )

        self._receipts_id_gen = MultiWriterIdGenerator(
            db_conn=db_conn,
            db=database,
            stream_name="receipts",
            instance_name=self._instance_name,
            tables=[("receipts_linearized", "instance_name", "stream_id")],
            sequence_name="receipts_sequence",
            writers=hs.config.worker.writers.receipts,
        )
    else:
        self._can_write_to_receipts = True

        # We shouldn't be running in worker mode with SQLite, but it's useful
        # to support it for unit tests.
        #
        # If this process is the writer then we need to use
        # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets
        # updated over replication. (Multiple writers are not supported for
        # SQLite.)
        if hs.get_instance_name() in hs.config.worker.writers.receipts:
            self._receipts_id_gen = StreamIdGenerator(
                db_conn, "receipts_linearized", "stream_id"
            )
        else:
            self._receipts_id_gen = SlavedIdTracker(
                db_conn, "receipts_linearized", "stream_id"
            )

    super().__init__(database, db_conn, hs)

    max_receipts_stream_id = self.get_max_receipt_stream_id()
    receipts_stream_prefill, min_receipts_stream_id = self.db_pool.get_cache_dict(
        db_conn,
        "receipts_linearized",
        entity_column="room_id",
        stream_column="stream_id",
        max_value=max_receipts_stream_id,
        limit=10000,
    )
    self._receipts_stream_cache = StreamChangeCache(
        "ReceiptsRoomChangeCache",
        min_receipts_stream_id,
        prefilled_cache=receipts_stream_prefill,
    )
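For context, a minimal sketch of how a receipts writer would typically use the id generator and change cache set up above. The method and its `_sketch_insert_receipt_txn` helper are hypothetical; `get_next`, `runInteraction` and `entity_has_changed` follow the patterns used elsewhere in this file.

async def _sketch_insert_receipt(self, room_id: str, receipt_json: str) -> int:
    # Only the configured receipts writer may allocate new stream ids.
    assert self._can_write_to_receipts

    async with self._receipts_id_gen.get_next() as stream_id:
        await self.db_pool.runInteraction(
            "sketch_insert_receipt",
            self._sketch_insert_receipt_txn,  # hypothetical txn helper
            room_id,
            receipt_json,
            stream_id,
        )

    # Mark the room as changed at this position so readers consulting
    # `_receipts_stream_cache` know there is new data to fetch.
    self._receipts_stream_cache.entity_has_changed(room_id, stream_id)
    return stream_id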
def __init__(
    self,
    database: DatabasePool,
    db_conn: LoggingDatabaseConnection,
    hs: "HomeServer",
):
    super().__init__(database, db_conn, hs)

    # `_can_write_to_account_data` indicates whether the current worker is allowed
    # to write account data. A value of `True` implies that `_account_data_id_gen`
    # is an `AbstractStreamIdGenerator` and not just a tracker.
    self._account_data_id_gen: AbstractStreamIdTracker

    if isinstance(database.engine, PostgresEngine):
        self._can_write_to_account_data = (
            self._instance_name in hs.config.worker.writers.account_data
        )

        self._account_data_id_gen = MultiWriterIdGenerator(
            db_conn=db_conn,
            db=database,
            stream_name="account_data",
            instance_name=self._instance_name,
            tables=[
                ("room_account_data", "instance_name", "stream_id"),
                ("room_tags_revisions", "instance_name", "stream_id"),
                ("account_data", "instance_name", "stream_id"),
            ],
            sequence_name="account_data_sequence",
            writers=hs.config.worker.writers.account_data,
        )
    else:
        # We shouldn't be running in worker mode with SQLite, but it's useful
        # to support it for unit tests.
        #
        # If this process is the writer then we need to use
        # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets
        # updated over replication. (Multiple writers are not supported for
        # SQLite.)
        if self._instance_name in hs.config.worker.writers.account_data:
            self._can_write_to_account_data = True
            self._account_data_id_gen = StreamIdGenerator(
                db_conn,
                "room_account_data",
                "stream_id",
                extra_tables=[("room_tags_revisions", "stream_id")],
            )
        else:
            self._account_data_id_gen = SlavedIdTracker(
                db_conn,
                "room_account_data",
                "stream_id",
                extra_tables=[("room_tags_revisions", "stream_id")],
            )

    account_max = self.get_max_account_data_stream_id()
    self._account_data_stream_cache = StreamChangeCache(
        "AccountDataAndTagsChangeCache", account_max
    )
def __init__(
    self,
    database: DatabasePool,
    db_conn: LoggingDatabaseConnection,
    hs: "HomeServer",
) -> None:
    super().__init__(database, db_conn, hs)

    self._instance_name = hs.get_instance_name()
    self._presence_id_gen: AbstractStreamIdGenerator

    self._can_persist_presence = (
        self._instance_name in hs.config.worker.writers.presence
    )

    if isinstance(database.engine, PostgresEngine):
        self._presence_id_gen = MultiWriterIdGenerator(
            db_conn=db_conn,
            db=database,
            stream_name="presence_stream",
            instance_name=self._instance_name,
            tables=[("presence_stream", "instance_name", "stream_id")],
            sequence_name="presence_stream_sequence",
            writers=hs.config.worker.writers.presence,
        )
    else:
        self._presence_id_gen = StreamIdGenerator(
            db_conn, "presence_stream", "stream_id"
        )

    self.hs = hs
    self._presence_on_startup = self._get_active_presence(db_conn)

    presence_cache_prefill, min_presence_val = self.db_pool.get_cache_dict(
        db_conn,
        "presence_stream",
        entity_column="user_id",
        stream_column="stream_id",
        max_value=self._presence_id_gen.get_current_token(),
    )
    self.presence_stream_cache = StreamChangeCache(
        "PresenceStreamChangeCache",
        min_presence_val,
        prefilled_cache=presence_cache_prefill,
    )
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
    super().__init__(database, db_conn, hs)
    self._pushers_id_gen = StreamIdGenerator(
        db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
    )
def __init__(self, database: DatabasePool, db_conn, hs):
    self._instance_name = hs.get_instance_name()

    if isinstance(database.engine, PostgresEngine):
        self._can_write_to_account_data = (
            self._instance_name in hs.config.worker.writers.account_data
        )

        self._account_data_id_gen = MultiWriterIdGenerator(
            db_conn=db_conn,
            db=database,
            stream_name="account_data",
            instance_name=self._instance_name,
            tables=[
                ("room_account_data", "instance_name", "stream_id"),
                ("room_tags_revisions", "instance_name", "stream_id"),
                ("account_data", "instance_name", "stream_id"),
            ],
            sequence_name="account_data_sequence",
            writers=hs.config.worker.writers.account_data,
        )
    else:
        self._can_write_to_account_data = True

        # We shouldn't be running in worker mode with SQLite, but it's useful
        # to support it for unit tests.
        #
        # If this process is the writer then we need to use
        # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets
        # updated over replication. (Multiple writers are not supported for
        # SQLite.)
        if hs.get_instance_name() in hs.config.worker.writers.account_data:
            self._account_data_id_gen = StreamIdGenerator(
                db_conn,
                "room_account_data",
                "stream_id",
                extra_tables=[("room_tags_revisions", "stream_id")],
            )
        else:
            self._account_data_id_gen = SlavedIdTracker(
                db_conn,
                "room_account_data",
                "stream_id",
                extra_tables=[("room_tags_revisions", "stream_id")],
            )

    account_max = self.get_max_account_data_stream_id()
    self._account_data_stream_cache = StreamChangeCache(
        "AccountDataAndTagsChangeCache", account_max
    )

    super().__init__(database, db_conn, hs)
def __init__(self, database: DatabasePool, db_conn, hs):
    super(EventsWorkerStore, self).__init__(database, db_conn, hs)

    if hs.config.worker.writers.events == hs.get_instance_name():
        # We are the process in charge of generating stream ids for events,
        # so instantiate ID generators based on the database
        self._stream_id_gen = StreamIdGenerator(
            db_conn,
            "events",
            "stream_ordering",
        )
        self._backfill_id_gen = StreamIdGenerator(
            db_conn,
            "events",
            "stream_ordering",
            step=-1,
            extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
        )
    else:
        # Another process is in charge of persisting events and generating
        # stream IDs: rely on the replication streams to let us know which
        # IDs we can process.
        self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering")
        self._backfill_id_gen = SlavedIdTracker(
            db_conn, "events", "stream_ordering", step=-1
        )

    self._get_event_cache = Cache(
        "*getEvent*",
        keylen=3,
        max_entries=hs.config.caches.event_cache_size,
        apply_cache_factor_from_config=False,
    )

    self._event_fetch_lock = threading.Condition()
    self._event_fetch_list = []
    self._event_fetch_ongoing = 0
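A short sketch of why `_backfill_id_gen` is built with `step=-1`: backfilled events take stream orderings that count downwards, so they sort before anything allocated by the live `_stream_id_gen`. The persist method and its txn helper are hypothetical; `get_next` is the same context-manager pattern used by the other generators in this file.

async def _sketch_persist_backfilled_event(self, event) -> int:
    # Live events take increasing stream orderings from `_stream_id_gen`;
    # backfilled events take decreasing ones from `_backfill_id_gen`,
    # keeping the two ranges disjoint.
    async with self._backfill_id_gen.get_next() as stream_ordering:
        await self.db_pool.runInteraction(
            "sketch_persist_backfilled_event",
            self._sketch_persist_event_txn,  # hypothetical txn helper
            event,
            stream_ordering,
        )
    return stream_ordering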
class DataStore( EventsBackgroundUpdatesStore, RoomMemberStore, RoomStore, RegistrationStore, StreamStore, ProfileStore, PresenceStore, TransactionStore, DirectoryStore, KeyStore, StateStore, SignatureStore, ApplicationServiceStore, PurgeEventsStore, EventFederationStore, MediaRepositoryStore, RejectionsStore, FilteringStore, PusherStore, PushRuleStore, ApplicationServiceTransactionStore, ReceiptsStore, EndToEndKeyStore, EndToEndRoomKeyStore, SearchStore, TagsStore, AccountDataStore, EventPushActionsStore, OpenIdStore, ClientIpStore, DeviceStore, DeviceInboxStore, UserDirectoryStore, GroupServerStore, UserErasureStore, MonthlyActiveUsersStore, StatsStore, RelationsStore, CensorEventsStore, UIAuthStore, CacheInvalidationWorkerStore, ServerMetricsStore, NewsWorkerStore, AccessWorkerStore ): def __init__(self, database: DatabasePool, db_conn, hs): self.hs = hs self._clock = hs.get_clock() self.database_engine = database.engine self._presence_id_gen = StreamIdGenerator( db_conn, "presence_stream", "stream_id" ) self._device_inbox_id_gen = StreamIdGenerator( db_conn, "device_inbox", "stream_id" ) self._public_room_id_gen = StreamIdGenerator( db_conn, "public_room_list_stream", "stream_id" ) self._device_list_id_gen = StreamIdGenerator( db_conn, "device_lists_stream", "stream_id", extra_tables=[ ("user_signature_stream", "stream_id"), ("device_lists_outbound_pokes", "stream_id"), ], ) self._cross_signing_id_gen = StreamIdGenerator( db_conn, "e2e_cross_signing_keys", "stream_id" ) self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id") self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id") self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id") self._group_updates_id_gen = StreamIdGenerator( db_conn, "local_group_updates", "stream_id" ) if isinstance(self.database_engine, PostgresEngine): # We set the `writers` to an empty list here as we don't care about # missing updates over restarts, as we'll not have anything in our # caches to invalidate. (This reduces the amount of writes to the DB # that happen). self._cache_id_gen = MultiWriterIdGenerator( db_conn, database, stream_name="caches", instance_name=hs.get_instance_name(), table="cache_invalidation_stream_by_instance", instance_column="instance_name", id_column="stream_id", sequence_name="cache_invalidation_stream_seq", writers=[], ) else: self._cache_id_gen = None super().__init__(database, db_conn, hs) self._presence_on_startup = self._get_active_presence(db_conn) presence_cache_prefill, min_presence_val = self.db_pool.get_cache_dict( db_conn, "presence_stream", entity_column="user_id", stream_column="stream_id", max_value=self._presence_id_gen.get_current_token(), ) self.presence_stream_cache = StreamChangeCache( "PresenceStreamChangeCache", min_presence_val, prefilled_cache=presence_cache_prefill, ) max_device_inbox_id = self._device_inbox_id_gen.get_current_token() device_inbox_prefill, min_device_inbox_id = self.db_pool.get_cache_dict( db_conn, "device_inbox", entity_column="user_id", stream_column="stream_id", max_value=max_device_inbox_id, limit=1000, ) self._device_inbox_stream_cache = StreamChangeCache( "DeviceInboxStreamChangeCache", min_device_inbox_id, prefilled_cache=device_inbox_prefill, ) # The federation outbox and the local device inbox uses the same # stream_id generator. 
device_outbox_prefill, min_device_outbox_id = self.db_pool.get_cache_dict( db_conn, "device_federation_outbox", entity_column="destination", stream_column="stream_id", max_value=max_device_inbox_id, limit=1000, ) self._device_federation_outbox_stream_cache = StreamChangeCache( "DeviceFederationOutboxStreamChangeCache", min_device_outbox_id, prefilled_cache=device_outbox_prefill, ) device_list_max = self._device_list_id_gen.get_current_token() self._device_list_stream_cache = StreamChangeCache( "DeviceListStreamChangeCache", device_list_max ) self._user_signature_stream_cache = StreamChangeCache( "UserSignatureStreamChangeCache", device_list_max ) self._device_list_federation_stream_cache = StreamChangeCache( "DeviceListFederationStreamChangeCache", device_list_max ) events_max = self._stream_id_gen.get_current_token() curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict( db_conn, "current_state_delta_stream", entity_column="room_id", stream_column="stream_id", max_value=events_max, # As we share the stream id with events token limit=1000, ) self._curr_state_delta_stream_cache = StreamChangeCache( "_curr_state_delta_stream_cache", min_curr_state_delta_id, prefilled_cache=curr_state_delta_prefill, ) _group_updates_prefill, min_group_updates_id = self.db_pool.get_cache_dict( db_conn, "local_group_updates", entity_column="user_id", stream_column="stream_id", max_value=self._group_updates_id_gen.get_current_token(), limit=1000, ) self._group_updates_stream_cache = StreamChangeCache( "_group_updates_stream_cache", min_group_updates_id, prefilled_cache=_group_updates_prefill, ) self._stream_order_on_start = self.get_room_max_stream_ordering() self._min_stream_order_on_start = self.get_room_min_stream_ordering() def get_device_stream_token(self) -> int: return self._device_list_id_gen.get_current_token() def take_presence_startup_info(self): active_on_startup = self._presence_on_startup self._presence_on_startup = None return active_on_startup def _get_active_presence(self, db_conn): """Fetch non-offline presence from the database so that we can register the appropriate time outs. """ sql = ( "SELECT user_id, state, last_active_ts, last_federation_update_ts," " last_user_sync_ts, status_msg, currently_active FROM presence_stream" " WHERE state != ?" ) txn = db_conn.cursor() txn.execute(sql, (PresenceState.OFFLINE,)) rows = self.db_pool.cursor_to_dict(txn) txn.close() for row in rows: row["currently_active"] = bool(row["currently_active"]) return [UserPresenceState(**row) for row in rows] async def get_users(self) -> List[Dict[str, Any]]: """Function to retrieve a list of users in users table. Returns: A list of dictionaries representing users. """ return await self.db_pool.simple_select_list( table="users", keyvalues={}, retcols=[ "name", "password_hash", "is_guest", "admin", "user_type", "deactivated", ], desc="get_users", ) async def get_users_paginate( self, start: int, limit: int, user_id: Optional[str] = None, name: Optional[str] = None, guests: bool = True, deactivated: bool = False, ) -> Tuple[List[Dict[str, Any]], int]: """Function to retrieve a paginated list of users from users list. This will return a json list of users and the total number of users matching the filter criteria. Args: start: start number to begin the query from limit: number of rows to retrieve user_id: search for user_id. 
ignored if name is not None name: search for local part of user_id or display name guests: whether to in include guest users deactivated: whether to include deactivated users Returns: A tuple of a list of mappings from user to information and a count of total users. """ def get_users_paginate_txn(txn): filters = [] args = [self.hs.config.server_name] # `name` is in database already in lower case if name: filters.append("(name LIKE ? OR LOWER(displayname) LIKE ?)") args.extend(["@%" + name.lower() + "%:%", "%" + name.lower() + "%"]) elif user_id: filters.append("name LIKE ?") args.extend(["%" + user_id.lower() + "%"]) if not guests: filters.append("is_guest = 0") if not deactivated: filters.append("deactivated = 0") where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else "" sql_base = """ FROM users as u LEFT JOIN profiles AS p ON u.name = '@' || p.user_id || ':' || ? {} """.format( where_clause ) sql = "SELECT COUNT(*) as total_users " + sql_base txn.execute(sql, args) count = txn.fetchone()[0] sql = ( "SELECT name, user_type, is_guest, admin, deactivated, displayname, avatar_url " + sql_base + " ORDER BY u.name LIMIT ? OFFSET ?" ) args += [limit, start] txn.execute(sql, args) users = self.db_pool.cursor_to_dict(txn) return users, count return await self.db_pool.runInteraction( "get_users_paginate_txn", get_users_paginate_txn ) async def search_users(self, term: str) -> Optional[List[Dict[str, Any]]]: """Function to search users list for one or more users with the matched term. Args: term: search term Returns: A list of dictionaries or None. """ return await self.db_pool.simple_search_list( table="users", term=term, col="name", retcols=["name", "password_hash", "is_guest", "admin", "user_type"], desc="search_users", )
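A brief usage sketch for the pagination helper above, assuming `store` is an instance of this `DataStore`; the argument values are illustrative only.

async def _sketch_list_first_admin_page(store) -> None:
    # Fetch the first 10 non-guest, non-deactivated users whose localpart or
    # display name contains "alice", plus the total number of matches.
    users, total = await store.get_users_paginate(
        start=0,
        limit=10,
        name="alice",
        guests=False,
        deactivated=False,
    )
    print(f"showing {len(users)} of {total} matching users")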
def __init__(self, database: DatabasePool, db_conn, hs): self.hs = hs self._clock = hs.get_clock() self.database_engine = database.engine self._presence_id_gen = StreamIdGenerator( db_conn, "presence_stream", "stream_id" ) self._device_inbox_id_gen = StreamIdGenerator( db_conn, "device_inbox", "stream_id" ) self._public_room_id_gen = StreamIdGenerator( db_conn, "public_room_list_stream", "stream_id" ) self._device_list_id_gen = StreamIdGenerator( db_conn, "device_lists_stream", "stream_id", extra_tables=[ ("user_signature_stream", "stream_id"), ("device_lists_outbound_pokes", "stream_id"), ], ) self._cross_signing_id_gen = StreamIdGenerator( db_conn, "e2e_cross_signing_keys", "stream_id" ) self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id") self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id") self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id") self._group_updates_id_gen = StreamIdGenerator( db_conn, "local_group_updates", "stream_id" ) if isinstance(self.database_engine, PostgresEngine): # We set the `writers` to an empty list here as we don't care about # missing updates over restarts, as we'll not have anything in our # caches to invalidate. (This reduces the amount of writes to the DB # that happen). self._cache_id_gen = MultiWriterIdGenerator( db_conn, database, stream_name="caches", instance_name=hs.get_instance_name(), table="cache_invalidation_stream_by_instance", instance_column="instance_name", id_column="stream_id", sequence_name="cache_invalidation_stream_seq", writers=[], ) else: self._cache_id_gen = None super().__init__(database, db_conn, hs) self._presence_on_startup = self._get_active_presence(db_conn) presence_cache_prefill, min_presence_val = self.db_pool.get_cache_dict( db_conn, "presence_stream", entity_column="user_id", stream_column="stream_id", max_value=self._presence_id_gen.get_current_token(), ) self.presence_stream_cache = StreamChangeCache( "PresenceStreamChangeCache", min_presence_val, prefilled_cache=presence_cache_prefill, ) max_device_inbox_id = self._device_inbox_id_gen.get_current_token() device_inbox_prefill, min_device_inbox_id = self.db_pool.get_cache_dict( db_conn, "device_inbox", entity_column="user_id", stream_column="stream_id", max_value=max_device_inbox_id, limit=1000, ) self._device_inbox_stream_cache = StreamChangeCache( "DeviceInboxStreamChangeCache", min_device_inbox_id, prefilled_cache=device_inbox_prefill, ) # The federation outbox and the local device inbox uses the same # stream_id generator. 
device_outbox_prefill, min_device_outbox_id = self.db_pool.get_cache_dict( db_conn, "device_federation_outbox", entity_column="destination", stream_column="stream_id", max_value=max_device_inbox_id, limit=1000, ) self._device_federation_outbox_stream_cache = StreamChangeCache( "DeviceFederationOutboxStreamChangeCache", min_device_outbox_id, prefilled_cache=device_outbox_prefill, ) device_list_max = self._device_list_id_gen.get_current_token() self._device_list_stream_cache = StreamChangeCache( "DeviceListStreamChangeCache", device_list_max ) self._user_signature_stream_cache = StreamChangeCache( "UserSignatureStreamChangeCache", device_list_max ) self._device_list_federation_stream_cache = StreamChangeCache( "DeviceListFederationStreamChangeCache", device_list_max ) events_max = self._stream_id_gen.get_current_token() curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict( db_conn, "current_state_delta_stream", entity_column="room_id", stream_column="stream_id", max_value=events_max, # As we share the stream id with events token limit=1000, ) self._curr_state_delta_stream_cache = StreamChangeCache( "_curr_state_delta_stream_cache", min_curr_state_delta_id, prefilled_cache=curr_state_delta_prefill, ) _group_updates_prefill, min_group_updates_id = self.db_pool.get_cache_dict( db_conn, "local_group_updates", entity_column="user_id", stream_column="stream_id", max_value=self._group_updates_id_gen.get_current_token(), limit=1000, ) self._group_updates_stream_cache = StreamChangeCache( "_group_updates_stream_cache", min_group_updates_id, prefilled_cache=_group_updates_prefill, ) self._stream_order_on_start = self.get_room_max_stream_ordering() self._min_stream_order_on_start = self.get_room_min_stream_ordering()
class DataStore( EventsBackgroundUpdatesStore, RoomMemberStore, RoomStore, RegistrationStore, StreamStore, ProfileStore, PresenceStore, TransactionStore, DirectoryStore, KeyStore, StateStore, SignatureStore, ApplicationServiceStore, EventsStore, EventFederationStore, MediaRepositoryStore, RejectionsStore, FilteringStore, PusherStore, PushRuleStore, ApplicationServiceTransactionStore, ReceiptsStore, EndToEndKeyStore, EndToEndRoomKeyStore, SearchStore, TagsStore, AccountDataStore, EventPushActionsStore, OpenIdStore, ClientIpStore, DeviceStore, DeviceInboxStore, UserDirectoryStore, GroupServerStore, UserErasureStore, MonthlyActiveUsersStore, StatsStore, RelationsStore, CacheInvalidationStore, UIAuthStore, ): def __init__(self, database: Database, db_conn, hs): self.hs = hs self._clock = hs.get_clock() self.database_engine = database.engine self._stream_id_gen = StreamIdGenerator( db_conn, "events", "stream_ordering", extra_tables=[("local_invites", "stream_id")], ) self._backfill_id_gen = StreamIdGenerator( db_conn, "events", "stream_ordering", step=-1, extra_tables=[("ex_outlier_stream", "event_stream_ordering")], ) self._presence_id_gen = StreamIdGenerator(db_conn, "presence_stream", "stream_id") self._device_inbox_id_gen = StreamIdGenerator(db_conn, "device_max_stream_id", "stream_id") self._public_room_id_gen = StreamIdGenerator( db_conn, "public_room_list_stream", "stream_id") self._device_list_id_gen = StreamIdGenerator( db_conn, "device_lists_stream", "stream_id", extra_tables=[ ("user_signature_stream", "stream_id"), ("device_lists_outbound_pokes", "stream_id"), ], ) self._cross_signing_id_gen = StreamIdGenerator( db_conn, "e2e_cross_signing_keys", "stream_id") self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id") self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id") self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id") self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id") self._push_rules_stream_id_gen = ChainedIdGenerator( self._stream_id_gen, db_conn, "push_rules_stream", "stream_id") self._pushers_id_gen = StreamIdGenerator(db_conn, "pushers", "id", extra_tables=[ ("deleted_pushers", "stream_id") ]) self._group_updates_id_gen = StreamIdGenerator(db_conn, "local_group_updates", "stream_id") if isinstance(self.database_engine, PostgresEngine): self._cache_id_gen = StreamIdGenerator( db_conn, "cache_invalidation_stream", "stream_id") else: self._cache_id_gen = None super(DataStore, self).__init__(database, db_conn, hs) self._presence_on_startup = self._get_active_presence(db_conn) presence_cache_prefill, min_presence_val = self.db.get_cache_dict( db_conn, "presence_stream", entity_column="user_id", stream_column="stream_id", max_value=self._presence_id_gen.get_current_token(), ) self.presence_stream_cache = StreamChangeCache( "PresenceStreamChangeCache", min_presence_val, prefilled_cache=presence_cache_prefill, ) max_device_inbox_id = self._device_inbox_id_gen.get_current_token() device_inbox_prefill, min_device_inbox_id = self.db.get_cache_dict( db_conn, "device_inbox", entity_column="user_id", stream_column="stream_id", max_value=max_device_inbox_id, limit=1000, ) self._device_inbox_stream_cache = StreamChangeCache( "DeviceInboxStreamChangeCache", min_device_inbox_id, prefilled_cache=device_inbox_prefill, ) # The federation outbox and the local device inbox uses the same # stream_id generator. 
device_outbox_prefill, min_device_outbox_id = self.db.get_cache_dict( db_conn, "device_federation_outbox", entity_column="destination", stream_column="stream_id", max_value=max_device_inbox_id, limit=1000, ) self._device_federation_outbox_stream_cache = StreamChangeCache( "DeviceFederationOutboxStreamChangeCache", min_device_outbox_id, prefilled_cache=device_outbox_prefill, ) device_list_max = self._device_list_id_gen.get_current_token() self._device_list_stream_cache = StreamChangeCache( "DeviceListStreamChangeCache", device_list_max) self._user_signature_stream_cache = StreamChangeCache( "UserSignatureStreamChangeCache", device_list_max) self._device_list_federation_stream_cache = StreamChangeCache( "DeviceListFederationStreamChangeCache", device_list_max) events_max = self._stream_id_gen.get_current_token() curr_state_delta_prefill, min_curr_state_delta_id = self.db.get_cache_dict( db_conn, "current_state_delta_stream", entity_column="room_id", stream_column="stream_id", max_value=events_max, # As we share the stream id with events token limit=1000, ) self._curr_state_delta_stream_cache = StreamChangeCache( "_curr_state_delta_stream_cache", min_curr_state_delta_id, prefilled_cache=curr_state_delta_prefill, ) _group_updates_prefill, min_group_updates_id = self.db.get_cache_dict( db_conn, "local_group_updates", entity_column="user_id", stream_column="stream_id", max_value=self._group_updates_id_gen.get_current_token(), limit=1000, ) self._group_updates_stream_cache = StreamChangeCache( "_group_updates_stream_cache", min_group_updates_id, prefilled_cache=_group_updates_prefill, ) self._stream_order_on_start = self.get_room_max_stream_ordering() self._min_stream_order_on_start = self.get_room_min_stream_ordering() # Used in _generate_user_daily_visits to keep track of progress self._last_user_visit_update = self._get_start_of_day() def take_presence_startup_info(self): active_on_startup = self._presence_on_startup self._presence_on_startup = None return active_on_startup def _get_active_presence(self, db_conn): """Fetch non-offline presence from the database so that we can register the appropriate time outs. """ sql = ( "SELECT user_id, state, last_active_ts, last_federation_update_ts," " last_user_sync_ts, status_msg, currently_active FROM presence_stream" " WHERE state != ?") sql = self.database_engine.convert_param_style(sql) txn = db_conn.cursor() txn.execute(sql, (PresenceState.OFFLINE, )) rows = self.db.cursor_to_dict(txn) txn.close() for row in rows: row["currently_active"] = bool(row["currently_active"]) return [UserPresenceState(**row) for row in rows] def count_daily_users(self): """ Counts the number of users who used this homeserver in the last 24 hours. """ yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24) return self.db.runInteraction("count_daily_users", self._count_users, yesterday) def count_monthly_users(self): """ Counts the number of users who used this homeserver in the last 30 days. Note this method is intended for phonehome metrics only and is different from the mau figure in synapse.storage.monthly_active_users which, amongst other things, includes a 3 day grace period before a user counts. 
""" thirty_days_ago = int( self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) return self.db.runInteraction("count_monthly_users", self._count_users, thirty_days_ago) def _count_users(self, txn, time_from): """ Returns number of users seen in the past time_from period """ sql = """ SELECT COALESCE(count(*), 0) FROM ( SELECT user_id FROM user_ips WHERE last_seen > ? GROUP BY user_id ) u """ txn.execute(sql, (time_from, )) (count, ) = txn.fetchone() return count def count_r30_users(self): """ Counts the number of 30 day retained users, defined as:- * Users who have created their accounts more than 30 days ago * Where last seen at most 30 days ago * Where account creation and last_seen are > 30 days apart Returns counts globaly for a given user as well as breaking by platform """ def _count_r30_users(txn): thirty_days_in_secs = 86400 * 30 now = int(self._clock.time()) thirty_days_ago_in_secs = now - thirty_days_in_secs sql = """ SELECT platform, COALESCE(count(*), 0) FROM ( SELECT users.name, platform, users.creation_ts * 1000, MAX(uip.last_seen) FROM users INNER JOIN ( SELECT user_id, last_seen, CASE WHEN user_agent LIKE '%%Android%%' THEN 'android' WHEN user_agent LIKE '%%iOS%%' THEN 'ios' WHEN user_agent LIKE '%%Electron%%' THEN 'electron' WHEN user_agent LIKE '%%Mozilla%%' THEN 'web' WHEN user_agent LIKE '%%Gecko%%' THEN 'web' ELSE 'unknown' END AS platform FROM user_ips ) uip ON users.name = uip.user_id AND users.appservice_id is NULL AND users.creation_ts < ? AND uip.last_seen/1000 > ? AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30 GROUP BY users.name, platform, users.creation_ts ) u GROUP BY platform """ results = {} txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs)) for row in txn: if row[0] == "unknown": pass results[row[0]] = row[1] sql = """ SELECT COALESCE(count(*), 0) FROM ( SELECT users.name, users.creation_ts * 1000, MAX(uip.last_seen) FROM users INNER JOIN ( SELECT user_id, last_seen FROM user_ips ) uip ON users.name = uip.user_id AND appservice_id is NULL AND users.creation_ts < ? AND uip.last_seen/1000 > ? AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30 GROUP BY users.name, users.creation_ts ) u """ txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs)) (count, ) = txn.fetchone() results["all"] = count return results return self.db.runInteraction("count_r30_users", _count_r30_users) def _get_start_of_day(self): """ Returns millisecond unixtime for start of UTC day. """ now = time.gmtime() today_start = calendar.timegm( (now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0)) return today_start * 1000 def generate_user_daily_visits(self): """ Generates daily visit data for use in cohort/ retention analysis """ def _generate_user_daily_visits(txn): logger.info("Calling _generate_user_daily_visits") today_start = self._get_start_of_day() a_day_in_milliseconds = 24 * 60 * 60 * 1000 now = self.clock.time_msec() sql = """ INSERT INTO user_daily_visits (user_id, device_id, timestamp) SELECT u.user_id, u.device_id, ? FROM user_ips AS u LEFT JOIN ( SELECT user_id, device_id, timestamp FROM user_daily_visits WHERE timestamp = ? ) udv ON u.user_id = udv.user_id AND u.device_id=udv.device_id INNER JOIN users ON users.name=u.user_id WHERE last_seen > ? AND last_seen <= ? AND udv.timestamp IS NULL AND users.is_guest=0 AND users.appservice_id IS NULL GROUP BY u.user_id, u.device_id """ # This means that the day has rolled over but there could still # be entries from the previous day. 
There is an edge case # where if the user logs in at 23:59 and overwrites their # last_seen at 00:01 then they will not be counted in the # previous day's stats - it is important that the query is run # often to minimise this case. if today_start > self._last_user_visit_update: yesterday_start = today_start - a_day_in_milliseconds txn.execute( sql, ( yesterday_start, yesterday_start, self._last_user_visit_update, today_start, ), ) self._last_user_visit_update = today_start txn.execute( sql, (today_start, today_start, self._last_user_visit_update, now)) # Update _last_user_visit_update to now. The reason to do this # rather just clamping to the beginning of the day is to limit # the size of the join - meaning that the query can be run more # frequently self._last_user_visit_update = now return self.db.runInteraction("generate_user_daily_visits", _generate_user_daily_visits) def get_users(self): """Function to retrieve a list of users in users table. Args: Returns: defer.Deferred: resolves to list[dict[str, Any]] """ return self.db.simple_select_list( table="users", keyvalues={}, retcols=[ "name", "password_hash", "is_guest", "admin", "user_type", "deactivated", ], desc="get_users", ) def get_users_paginate(self, start, limit, name=None, guests=True, deactivated=False): """Function to retrieve a paginated list of users from users list. This will return a json list of users and the total number of users matching the filter criteria. Args: start (int): start number to begin the query from limit (int): number of rows to retrieve name (string): filter for user names guests (bool): whether to in include guest users deactivated (bool): whether to include deactivated users Returns: defer.Deferred: resolves to list[dict[str, Any]], int """ def get_users_paginate_txn(txn): filters = [] args = [] if name: filters.append("name LIKE ?") args.append("%" + name + "%") if not guests: filters.append("is_guest = 0") if not deactivated: filters.append("deactivated = 0") where_clause = "WHERE " + " AND ".join(filters) if len( filters) > 0 else "" sql = "SELECT COUNT(*) as total_users FROM users %s" % ( where_clause) txn.execute(sql, args) count = txn.fetchone()[0] args = [self.hs.config.server_name] + args + [limit, start] sql = """ SELECT name, user_type, is_guest, admin, deactivated, displayname, avatar_url FROM users as u LEFT JOIN profiles AS p ON u.name = '@' || p.user_id || ':' || ? {} ORDER BY u.name LIMIT ? OFFSET ? """.format(where_clause) txn.execute(sql, args) users = self.db.cursor_to_dict(txn) return users, count return self.db.runInteraction("get_users_paginate_txn", get_users_paginate_txn) def search_users(self, term): """Function to search users list for one or more users with the matched term. Args: term (str): search term col (str): column to query term should be matched to Returns: defer.Deferred: resolves to list[dict[str, Any]] """ return self.db.simple_search_list( table="users", term=term, col="name", retcols=[ "name", "password_hash", "is_guest", "admin", "user_type" ], desc="search_users", )
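To make the visit-window arithmetic in `_generate_user_daily_visits` above concrete, here is a small self-contained sketch of how `today_start` and `yesterday_start` are derived, mirroring `_get_start_of_day`; the function name is illustrative.

import calendar
import time


def _sketch_day_boundaries_ms():
    # Millisecond unixtime for the start of the current UTC day (as in
    # `_get_start_of_day`) and for the start of the previous UTC day.
    now = time.gmtime()
    today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0)) * 1000
    a_day_in_milliseconds = 24 * 60 * 60 * 1000
    yesterday_start = today_start - a_day_in_milliseconds
    return today_start, yesterday_start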
def __init__(self, database: Database, db_conn, hs): self.hs = hs self._clock = hs.get_clock() self.database_engine = database.engine self._stream_id_gen = StreamIdGenerator( db_conn, "events", "stream_ordering", extra_tables=[("local_invites", "stream_id")], ) self._backfill_id_gen = StreamIdGenerator( db_conn, "events", "stream_ordering", step=-1, extra_tables=[("ex_outlier_stream", "event_stream_ordering")], ) self._presence_id_gen = StreamIdGenerator(db_conn, "presence_stream", "stream_id") self._device_inbox_id_gen = StreamIdGenerator(db_conn, "device_max_stream_id", "stream_id") self._public_room_id_gen = StreamIdGenerator( db_conn, "public_room_list_stream", "stream_id") self._device_list_id_gen = StreamIdGenerator( db_conn, "device_lists_stream", "stream_id", extra_tables=[ ("user_signature_stream", "stream_id"), ("device_lists_outbound_pokes", "stream_id"), ], ) self._cross_signing_id_gen = StreamIdGenerator( db_conn, "e2e_cross_signing_keys", "stream_id") self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id") self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id") self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id") self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id") self._push_rules_stream_id_gen = ChainedIdGenerator( self._stream_id_gen, db_conn, "push_rules_stream", "stream_id") self._pushers_id_gen = StreamIdGenerator(db_conn, "pushers", "id", extra_tables=[ ("deleted_pushers", "stream_id") ]) self._group_updates_id_gen = StreamIdGenerator(db_conn, "local_group_updates", "stream_id") if isinstance(self.database_engine, PostgresEngine): self._cache_id_gen = StreamIdGenerator( db_conn, "cache_invalidation_stream", "stream_id") else: self._cache_id_gen = None super(DataStore, self).__init__(database, db_conn, hs) self._presence_on_startup = self._get_active_presence(db_conn) presence_cache_prefill, min_presence_val = self.db.get_cache_dict( db_conn, "presence_stream", entity_column="user_id", stream_column="stream_id", max_value=self._presence_id_gen.get_current_token(), ) self.presence_stream_cache = StreamChangeCache( "PresenceStreamChangeCache", min_presence_val, prefilled_cache=presence_cache_prefill, ) max_device_inbox_id = self._device_inbox_id_gen.get_current_token() device_inbox_prefill, min_device_inbox_id = self.db.get_cache_dict( db_conn, "device_inbox", entity_column="user_id", stream_column="stream_id", max_value=max_device_inbox_id, limit=1000, ) self._device_inbox_stream_cache = StreamChangeCache( "DeviceInboxStreamChangeCache", min_device_inbox_id, prefilled_cache=device_inbox_prefill, ) # The federation outbox and the local device inbox uses the same # stream_id generator. 
device_outbox_prefill, min_device_outbox_id = self.db.get_cache_dict( db_conn, "device_federation_outbox", entity_column="destination", stream_column="stream_id", max_value=max_device_inbox_id, limit=1000, ) self._device_federation_outbox_stream_cache = StreamChangeCache( "DeviceFederationOutboxStreamChangeCache", min_device_outbox_id, prefilled_cache=device_outbox_prefill, ) device_list_max = self._device_list_id_gen.get_current_token() self._device_list_stream_cache = StreamChangeCache( "DeviceListStreamChangeCache", device_list_max) self._user_signature_stream_cache = StreamChangeCache( "UserSignatureStreamChangeCache", device_list_max) self._device_list_federation_stream_cache = StreamChangeCache( "DeviceListFederationStreamChangeCache", device_list_max) events_max = self._stream_id_gen.get_current_token() curr_state_delta_prefill, min_curr_state_delta_id = self.db.get_cache_dict( db_conn, "current_state_delta_stream", entity_column="room_id", stream_column="stream_id", max_value=events_max, # As we share the stream id with events token limit=1000, ) self._curr_state_delta_stream_cache = StreamChangeCache( "_curr_state_delta_stream_cache", min_curr_state_delta_id, prefilled_cache=curr_state_delta_prefill, ) _group_updates_prefill, min_group_updates_id = self.db.get_cache_dict( db_conn, "local_group_updates", entity_column="user_id", stream_column="stream_id", max_value=self._group_updates_id_gen.get_current_token(), limit=1000, ) self._group_updates_stream_cache = StreamChangeCache( "_group_updates_stream_cache", min_group_updates_id, prefilled_cache=_group_updates_prefill, ) self._stream_order_on_start = self.get_room_max_stream_ordering() self._min_stream_order_on_start = self.get_room_min_stream_ordering() # Used in _generate_user_daily_visits to keep track of progress self._last_user_visit_update = self._get_start_of_day()
class DeviceInboxWorkerStore(SQLBaseStore): def __init__( self, database: DatabasePool, db_conn: LoggingDatabaseConnection, hs: "HomeServer", ): super().__init__(database, db_conn, hs) self._instance_name = hs.get_instance_name() # Map of (user_id, device_id) to the last stream_id that has been # deleted up to. This is so that we can no op deletions. self._last_device_delete_cache: ExpiringCache[Tuple[ str, Optional[str]], int] = ExpiringCache( cache_name="last_device_delete_cache", clock=self._clock, max_len=10000, expiry_ms=30 * 60 * 1000, ) if isinstance(database.engine, PostgresEngine): self._can_write_to_device = (self._instance_name in hs.config.worker.writers.to_device) self._device_inbox_id_gen: AbstractStreamIdGenerator = ( MultiWriterIdGenerator( db_conn=db_conn, db=database, stream_name="to_device", instance_name=self._instance_name, tables=[("device_inbox", "instance_name", "stream_id")], sequence_name="device_inbox_sequence", writers=hs.config.worker.writers.to_device, )) else: self._can_write_to_device = True self._device_inbox_id_gen = StreamIdGenerator( db_conn, "device_inbox", "stream_id") max_device_inbox_id = self._device_inbox_id_gen.get_current_token() device_inbox_prefill, min_device_inbox_id = self.db_pool.get_cache_dict( db_conn, "device_inbox", entity_column="user_id", stream_column="stream_id", max_value=max_device_inbox_id, limit=1000, ) self._device_inbox_stream_cache = StreamChangeCache( "DeviceInboxStreamChangeCache", min_device_inbox_id, prefilled_cache=device_inbox_prefill, ) # The federation outbox and the local device inbox uses the same # stream_id generator. device_outbox_prefill, min_device_outbox_id = self.db_pool.get_cache_dict( db_conn, "device_federation_outbox", entity_column="destination", stream_column="stream_id", max_value=max_device_inbox_id, limit=1000, ) self._device_federation_outbox_stream_cache = StreamChangeCache( "DeviceFederationOutboxStreamChangeCache", min_device_outbox_id, prefilled_cache=device_outbox_prefill, ) def process_replication_rows( self, stream_name: str, instance_name: str, token: int, rows: Iterable[ToDeviceStream.ToDeviceStreamRow], ) -> None: if stream_name == ToDeviceStream.NAME: # If replication is happening than postgres must be being used. assert isinstance(self._device_inbox_id_gen, MultiWriterIdGenerator) self._device_inbox_id_gen.advance(instance_name, token) for row in rows: if row.entity.startswith("@"): self._device_inbox_stream_cache.entity_has_changed( row.entity, token) else: self._device_federation_outbox_stream_cache.entity_has_changed( row.entity, token) return super().process_replication_rows(stream_name, instance_name, token, rows) def get_to_device_stream_token(self) -> int: return self._device_inbox_id_gen.get_current_token() async def get_messages_for_user_devices( self, user_ids: Collection[str], from_stream_id: int, to_stream_id: int, ) -> Dict[Tuple[str, str], List[JsonDict]]: """ Retrieve to-device messages for a given set of users. Only to-device messages with stream ids between the given boundaries (from < X <= to) are returned. Args: user_ids: The users to retrieve to-device messages for. from_stream_id: The lower boundary of stream id to filter with (exclusive). to_stream_id: The upper boundary of stream id to filter with (inclusive). Returns: A dictionary of (user id, device id) -> list of to-device messages. """ # We expect the stream ID returned by _get_device_messages to always # be to_stream_id. So, no need to return it from this function. 
( user_id_device_id_to_messages, last_processed_stream_id, ) = await self._get_device_messages( user_ids=user_ids, from_stream_id=from_stream_id, to_stream_id=to_stream_id, ) assert ( last_processed_stream_id == to_stream_id ), "Expected _get_device_messages to process all to-device messages up to `to_stream_id`" return user_id_device_id_to_messages async def get_messages_for_device( self, user_id: str, device_id: str, from_stream_id: int, to_stream_id: int, limit: int = 100, ) -> Tuple[List[JsonDict], int]: """ Retrieve to-device messages for a single user device. Only to-device messages with stream ids between the given boundaries (from < X <= to) are returned. Args: user_id: The ID of the user to retrieve messages for. device_id: The ID of the device to retrieve to-device messages for. from_stream_id: The lower boundary of stream id to filter with (exclusive). to_stream_id: The upper boundary of stream id to filter with (inclusive). limit: A limit on the number of to-device messages returned. Returns: A tuple containing: * A list of to-device messages within the given stream id range intended for the given user / device combo. * The last-processed stream ID. Subsequent calls of this function with the same device should pass this value as 'from_stream_id'. """ ( user_id_device_id_to_messages, last_processed_stream_id, ) = await self._get_device_messages( user_ids=[user_id], device_id=device_id, from_stream_id=from_stream_id, to_stream_id=to_stream_id, limit=limit, ) if not user_id_device_id_to_messages: # There were no messages! return [], to_stream_id # Extract the messages, no need to return the user and device ID again to_device_messages = user_id_device_id_to_messages.get( (user_id, device_id), []) return to_device_messages, last_processed_stream_id async def _get_device_messages( self, user_ids: Collection[str], from_stream_id: int, to_stream_id: int, device_id: Optional[str] = None, limit: Optional[int] = None, ) -> Tuple[Dict[Tuple[str, str], List[JsonDict]], int]: """ Retrieve pending to-device messages for a collection of user devices. Only to-device messages with stream ids between the given boundaries (from < X <= to) are returned. Note that a stream ID can be shared by multiple copies of the same message with different recipient devices. Stream IDs are only unique in the context of a single user ID / device ID pair. Thus, applying a limit (of messages to return) when working with a sliding window of stream IDs is only possible when querying messages of a single user device. Finally, note that device IDs are not unique across users. Args: user_ids: The user IDs to filter device messages by. from_stream_id: The lower boundary of stream id to filter with (exclusive). to_stream_id: The upper boundary of stream id to filter with (inclusive). device_id: A device ID to query to-device messages for. If not provided, to-device messages from all device IDs for the given user IDs will be queried. May not be provided if `user_ids` contains more than one entry. limit: The maximum number of to-device messages to return. Can only be used when passing a single user ID / device ID tuple. Returns: A tuple containing: * A dict of (user_id, device_id) -> list of to-device messages * The last-processed stream ID. If this is less than `to_stream_id`, then there may be more messages to retrieve. If `limit` is not set, then this is always equal to 'to_stream_id'. 
""" if not user_ids: logger.warning("No users provided upon querying for device IDs") return {}, to_stream_id # Prevent a query for one user's device also retrieving another user's device with # the same device ID (device IDs are not unique across users). if len(user_ids) > 1 and device_id is not None: raise AssertionError( "Programming error: 'device_id' cannot be supplied to " "_get_device_messages when >1 user_id has been provided") # A limit can only be applied when querying for a single user ID / device ID tuple. # See the docstring of this function for more details. if limit is not None and device_id is None: raise AssertionError( "Programming error: _get_device_messages was passed 'limit' " "without a specific user_id/device_id") user_ids_to_query: Set[str] = set() device_ids_to_query: Set[str] = set() # Note that a device ID could be an empty str if device_id is not None: # If a device ID was passed, use it to filter results. # Otherwise, device IDs will be derived from the given collection of user IDs. device_ids_to_query.add(device_id) # Determine which users have devices with pending messages for user_id in user_ids: if self._device_inbox_stream_cache.has_entity_changed( user_id, from_stream_id): # This user has new messages sent to them. Query messages for them user_ids_to_query.add(user_id) if not user_ids_to_query: return {}, to_stream_id def get_device_messages_txn( txn: LoggingTransaction, ) -> Tuple[Dict[Tuple[str, str], List[JsonDict]], int]: # Build a query to select messages from any of the given devices that # are between the given stream id bounds. # If a list of device IDs was not provided, retrieve all devices IDs # for the given users. We explicitly do not query hidden devices, as # hidden devices should not receive to-device messages. # Note that this is more efficient than just dropping `device_id` from the query, # since device_inbox has an index on `(user_id, device_id, stream_id)` if not device_ids_to_query: user_device_dicts = self.db_pool.simple_select_many_txn( txn, table="devices", column="user_id", iterable=user_ids_to_query, keyvalues={ "user_id": user_id, "hidden": False }, retcols=("device_id", ), ) device_ids_to_query.update( {row["device_id"] for row in user_device_dicts}) if not device_ids_to_query: # We've ended up with no devices to query. return {}, to_stream_id # We include both user IDs and device IDs in this query, as we have an index # (device_inbox_user_stream_id) for them. user_id_many_clause_sql, user_id_many_clause_args = make_in_list_sql_clause( self.database_engine, "user_id", user_ids_to_query) ( device_id_many_clause_sql, device_id_many_clause_args, ) = make_in_list_sql_clause(self.database_engine, "device_id", device_ids_to_query) sql = f""" SELECT stream_id, user_id, device_id, message_json FROM device_inbox WHERE {user_id_many_clause_sql} AND {device_id_many_clause_sql} AND ? < stream_id AND stream_id <= ? ORDER BY stream_id ASC """ sql_args = ( *user_id_many_clause_args, *device_id_many_clause_args, from_stream_id, to_stream_id, ) # If a limit was provided, limit the data retrieved from the database if limit is not None: sql += "LIMIT ?" sql_args += (limit, ) txn.execute(sql, sql_args) # Create and fill a dictionary of (user ID, device ID) -> list of messages # intended for each device. 
last_processed_stream_pos = to_stream_id recipient_device_to_messages: Dict[Tuple[str, str], List[JsonDict]] = {} rowcount = 0 for row in txn: rowcount += 1 last_processed_stream_pos = row[0] recipient_user_id = row[1] recipient_device_id = row[2] message_dict = db_to_json(row[3]) # Store the device details recipient_device_to_messages.setdefault( (recipient_user_id, recipient_device_id), []).append(message_dict) if limit is not None and rowcount == limit: # We ended up bumping up against the message limit. There may be more messages # to retrieve. Return what we have, as well as the last stream position that # was processed. # # The caller is expected to set this as the lower (exclusive) bound # for the next query of this device. return recipient_device_to_messages, last_processed_stream_pos # The limit was not reached, thus we know that recipient_device_to_messages # contains all to-device messages for the given device and stream id range. # # We return to_stream_id, which the caller should then provide as the lower # (exclusive) bound on the next query of this device. return recipient_device_to_messages, to_stream_id return await self.db_pool.runInteraction("get_device_messages", get_device_messages_txn) @trace async def delete_messages_for_device(self, user_id: str, device_id: Optional[str], up_to_stream_id: int) -> int: """ Args: user_id: The recipient user_id. device_id: The recipient device_id. up_to_stream_id: Where to delete messages up to. Returns: The number of messages deleted. """ # If we have cached the last stream id we've deleted up to, we can # check if there is likely to be anything that needs deleting last_deleted_stream_id = self._last_device_delete_cache.get( (user_id, device_id), None) set_tag("last_deleted_stream_id", last_deleted_stream_id) if last_deleted_stream_id: has_changed = self._device_inbox_stream_cache.has_entity_changed( user_id, last_deleted_stream_id) if not has_changed: log_kv({"message": "No changes in cache since last check"}) return 0 def delete_messages_for_device_txn(txn: LoggingTransaction) -> int: sql = ("DELETE FROM device_inbox" " WHERE user_id = ? AND device_id = ?" " AND stream_id <= ?") txn.execute(sql, (user_id, device_id, up_to_stream_id)) return txn.rowcount count = await self.db_pool.runInteraction( "delete_messages_for_device", delete_messages_for_device_txn) log_kv({ "message": f"deleted {count} messages for device", "count": count }) # Update the cache, ensuring that we only ever increase the value updated_last_deleted_stream_id = self._last_device_delete_cache.get( (user_id, device_id), 0) self._last_device_delete_cache[(user_id, device_id)] = max( updated_last_deleted_stream_id, up_to_stream_id) return count @trace async def get_new_device_msgs_for_remote( self, destination: str, last_stream_id: int, current_stream_id: int, limit: int) -> Tuple[List[JsonDict], int]: """ Args: destination: The name of the remote server. last_stream_id: The last position of the device message stream that the server sent up to. current_stream_id: The current position of the device message stream. Returns: A list of messages for the device and where in the stream the messages got to. 
""" set_tag("destination", destination) set_tag("last_stream_id", last_stream_id) set_tag("current_stream_id", current_stream_id) set_tag("limit", limit) has_changed = self._device_federation_outbox_stream_cache.has_entity_changed( destination, last_stream_id) if not has_changed or last_stream_id == current_stream_id: log_kv({"message": "No new messages in stream"}) return [], current_stream_id if limit <= 0: # This can happen if we run out of room for EDUs in the transaction. return [], last_stream_id @trace def get_new_messages_for_remote_destination_txn( txn: LoggingTransaction, ) -> Tuple[List[JsonDict], int]: sql = ( "SELECT stream_id, messages_json FROM device_federation_outbox" " WHERE destination = ?" " AND ? < stream_id AND stream_id <= ?" " ORDER BY stream_id ASC" " LIMIT ?") txn.execute( sql, (destination, last_stream_id, current_stream_id, limit)) messages = [] stream_pos = current_stream_id for row in txn: stream_pos = row[0] messages.append(db_to_json(row[1])) # If the limit was not reached we know that there's no more data for this # user/device pair up to current_stream_id. if len(messages) < limit: log_kv({"message": "Set stream position to current position"}) stream_pos = current_stream_id return messages, stream_pos return await self.db_pool.runInteraction( "get_new_device_msgs_for_remote", get_new_messages_for_remote_destination_txn, ) @trace async def delete_device_msgs_for_remote(self, destination: str, up_to_stream_id: int) -> None: """Used to delete messages when the remote destination acknowledges their receipt. Args: destination: The destination server_name up_to_stream_id: Where to delete messages up to. """ def delete_messages_for_remote_destination_txn( txn: LoggingTransaction) -> None: sql = ("DELETE FROM device_federation_outbox" " WHERE destination = ?" " AND stream_id <= ?") txn.execute(sql, (destination, up_to_stream_id)) await self.db_pool.runInteraction( "delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn) async def get_all_new_device_messages( self, instance_name: str, last_id: int, current_id: int, limit: int) -> Tuple[List[Tuple[int, tuple]], int, bool]: """Get updates for to device replication stream. Args: instance_name: The writer we want to fetch updates from. Unused here since there is only ever one writer. last_id: The token to fetch updates from. Exclusive. current_id: The token to fetch updates up to. Inclusive. limit: The requested limit for the number of rows to return. The function may return more or fewer rows. Returns: A tuple consisting of: the updates, a token to use to fetch subsequent updates, and whether we returned fewer rows than exists between the requested tokens due to the limit. The token returned can be used in a subsequent call to this function to get further updatees. The updates are a list of 2-tuples of stream ID and the row data """ if last_id == current_id: return [], current_id, False def get_all_new_device_messages_txn( txn: LoggingTransaction, ) -> Tuple[List[Tuple[int, tuple]], int, bool]: # We limit like this as we might have multiple rows per stream_id, and # we want to make sure we always get all entries for any stream_id # we return. upper_pos = min(current_id, last_id + limit) sql = ("SELECT max(stream_id), user_id" " FROM device_inbox" " WHERE ? < stream_id AND stream_id <= ?" " GROUP BY user_id") txn.execute(sql, (last_id, upper_pos)) updates = [(row[0], row[1:]) for row in txn] sql = ("SELECT max(stream_id), destination" " FROM device_federation_outbox" " WHERE ? 
< stream_id AND stream_id <= ?" " GROUP BY destination") txn.execute(sql, (last_id, upper_pos)) updates.extend((row[0], row[1:]) for row in txn) # Order by ascending stream ordering updates.sort() limited = False upto_token = current_id if len(updates) >= limit: upto_token = updates[-1][0] limited = True return updates, upto_token, limited return await self.db_pool.runInteraction( "get_all_new_device_messages", get_all_new_device_messages_txn) @trace async def add_messages_to_device_inbox( self, local_messages_by_user_then_device: Dict[str, Dict[str, JsonDict]], remote_messages_by_destination: Dict[str, JsonDict], ) -> int: """Used to send messages from this server. Args: local_messages_by_user_then_device: Dictionary of recipient user_id to recipient device_id to message. remote_messages_by_destination: Dictionary of destination server_name to the EDU JSON to send. Returns: The new stream_id. """ assert self._can_write_to_device def add_messages_txn(txn: LoggingTransaction, now_ms: int, stream_id: int) -> None: # Add the local messages directly to the local inbox. self._add_messages_to_local_device_inbox_txn( txn, stream_id, local_messages_by_user_then_device) # Add the remote messages to the federation outbox. # We'll send them to a remote server when we next send a # federation transaction to that destination. self.db_pool.simple_insert_many_txn( txn, table="device_federation_outbox", keys=( "destination", "stream_id", "queued_ts", "messages_json", "instance_name", ), values=[( destination, stream_id, now_ms, json_encoder.encode(edu), self._instance_name, ) for destination, edu in remote_messages_by_destination.items()], ) if remote_messages_by_destination: issue9533_logger.debug( "Queued outgoing to-device messages with stream_id %i for %s", stream_id, list(remote_messages_by_destination.keys()), ) async with self._device_inbox_id_gen.get_next() as stream_id: now_ms = self._clock.time_msec() await self.db_pool.runInteraction("add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id) for user_id in local_messages_by_user_then_device.keys(): self._device_inbox_stream_cache.entity_has_changed( user_id, stream_id) for destination in remote_messages_by_destination.keys(): self._device_federation_outbox_stream_cache.entity_has_changed( destination, stream_id) return self._device_inbox_id_gen.get_current_token() async def add_messages_from_remote_to_device_inbox( self, origin: str, message_id: str, local_messages_by_user_then_device: Dict[str, Dict[str, JsonDict]], ) -> int: assert self._can_write_to_device def add_messages_txn(txn: LoggingTransaction, now_ms: int, stream_id: int) -> None: # Check if we've already inserted a matching message_id for that # origin. This can happen if the origin doesn't receive our # acknowledgement from the first time we received the message. already_inserted = self.db_pool.simple_select_one_txn( txn, table="device_federation_inbox", keyvalues={ "origin": origin, "message_id": message_id }, retcols=("message_id", ), allow_none=True, ) if already_inserted is not None: return # Add an entry for this message_id so that we know we've processed # it. self.db_pool.simple_insert_txn( txn, table="device_federation_inbox", values={ "origin": origin, "message_id": message_id, "received_ts": now_ms, }, ) # Add the messages to the appropriate local device inboxes so that # they'll be sent to the devices when they next sync. 
self._add_messages_to_local_device_inbox_txn( txn, stream_id, local_messages_by_user_then_device) async with self._device_inbox_id_gen.get_next() as stream_id: now_ms = self._clock.time_msec() await self.db_pool.runInteraction( "add_messages_from_remote_to_device_inbox", add_messages_txn, now_ms, stream_id, ) for user_id in local_messages_by_user_then_device.keys(): self._device_inbox_stream_cache.entity_has_changed( user_id, stream_id) return stream_id def _add_messages_to_local_device_inbox_txn( self, txn: LoggingTransaction, stream_id: int, messages_by_user_then_device: Dict[str, Dict[str, JsonDict]], ) -> None: assert self._can_write_to_device local_by_user_then_device = {} for user_id, messages_by_device in messages_by_user_then_device.items( ): messages_json_for_user = {} devices = list(messages_by_device.keys()) if len(devices) == 1 and devices[0] == "*": # Handle wildcard device_ids. # We exclude hidden devices (such as cross-signing keys) here as they are # not expected to receive to-device messages. devices = self.db_pool.simple_select_onecol_txn( txn, table="devices", keyvalues={ "user_id": user_id, "hidden": False }, retcol="device_id", ) message_json = json_encoder.encode(messages_by_device["*"]) for device_id in devices: # Add the message for all devices for this user on this # server. messages_json_for_user[device_id] = message_json else: if not devices: continue # We exclude hidden devices (such as cross-signing keys) here as they are # not expected to receive to-device messages. rows = self.db_pool.simple_select_many_txn( txn, table="devices", keyvalues={ "user_id": user_id, "hidden": False }, column="device_id", iterable=devices, retcols=("device_id", ), ) for row in rows: # Only insert into the local inbox if the device exists on # this server device_id = row["device_id"] message_json = json_encoder.encode( messages_by_device[device_id]) messages_json_for_user[device_id] = message_json if messages_json_for_user: local_by_user_then_device[user_id] = messages_json_for_user if not local_by_user_then_device: return self.db_pool.simple_insert_many_txn( txn, table="device_inbox", keys=("user_id", "device_id", "stream_id", "message_json", "instance_name"), values=[(user_id, device_id, stream_id, message_json, self._instance_name) for user_id, messages_by_device in local_by_user_then_device.items() for device_id, message_json in messages_by_device.items()], ) issue9533_logger.debug( "Stored to-device messages with stream_id %i for %s", stream_id, [(user_id, device_id) for (user_id, messages_by_device) in local_by_user_then_device.items() for device_id in messages_by_device.keys()], )
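# Usage sketch (illustrative only; not part of the store above). How a federation
# handler might hand an incoming to-device EDU to add_messages_from_remote_to_device_inbox.
# The handler name, the `store` parameter and the EDU field names are assumptions; the
# store method and its deduplication on (origin, message_id) via device_federation_inbox
# are as implemented above.
from typing import Any, Dict

async def handle_incoming_to_device_edu(store: Any, origin: str, edu_content: Dict[str, Any]) -> int:
    # The sending server supplies a message_id, so a redelivered EDU (e.g. if our
    # acknowledgement was lost) is detected and ignored by the store.
    message_id = edu_content["message_id"]
    local_messages = edu_content["messages"]  # assumed shape: {user_id: {device_id: message_json}}
    # Returns the stream_id at which the messages were written; callers typically use
    # it to wake up /sync for the affected local users.
    return await store.add_messages_from_remote_to_device_inbox(
        origin, message_id, local_messages
    )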
class AccountDataStore(AccountDataWorkerStore): def __init__(self, db_conn, hs): self._account_data_id_gen = StreamIdGenerator( db_conn, "account_data_max_stream_id", "stream_id" ) super(AccountDataStore, self).__init__(db_conn, hs) def get_max_account_data_stream_id(self): """Get the current max stream id for the private user data stream Returns: A deferred int. """ return self._account_data_id_gen.get_current_token() @defer.inlineCallbacks def add_account_data_to_room(self, user_id, room_id, account_data_type, content): """Add some account_data to a room for a user. Args: user_id(str): The user to add a tag for. room_id(str): The room to add a tag for. account_data_type(str): The type of account_data to add. content(dict): A json object to associate with the tag. Returns: A deferred that completes once the account_data has been added. """ content_json = json.dumps(content) with self._account_data_id_gen.get_next() as next_id: # no need to lock here as room_account_data has a unique constraint # on (user_id, room_id, account_data_type) so _simple_upsert will # retry if there is a conflict. yield self._simple_upsert( desc="add_room_account_data", table="room_account_data", keyvalues={ "user_id": user_id, "room_id": room_id, "account_data_type": account_data_type, }, values={ "stream_id": next_id, "content": content_json, }, lock=False, ) # it's theoretically possible for the above to succeed and the # below to fail - in which case we might reuse a stream id on # restart, and the above update might not get propagated. That # doesn't sound any worse than the whole update getting lost, # which is what would happen if we combined the two into one # transaction. yield self._update_max_stream_id(next_id) self._account_data_stream_cache.entity_has_changed(user_id, next_id) self.get_account_data_for_user.invalidate((user_id,)) self.get_account_data_for_room.invalidate((user_id, room_id,)) self.get_account_data_for_room_and_type.prefill( (user_id, room_id, account_data_type,), content, ) result = self._account_data_id_gen.get_current_token() defer.returnValue(result) @defer.inlineCallbacks def add_account_data_for_user(self, user_id, account_data_type, content): """Add some account_data to a room for a user. Args: user_id(str): The user to add a tag for. account_data_type(str): The type of account_data to add. content(dict): A json object to associate with the tag. Returns: A deferred that completes once the account_data has been added. """ content_json = json.dumps(content) with self._account_data_id_gen.get_next() as next_id: # no need to lock here as account_data has a unique constraint on # (user_id, account_data_type) so _simple_upsert will retry if # there is a conflict. yield self._simple_upsert( desc="add_user_account_data", table="account_data", keyvalues={ "user_id": user_id, "account_data_type": account_data_type, }, values={ "stream_id": next_id, "content": content_json, }, lock=False, ) # it's theoretically possible for the above to succeed and the # below to fail - in which case we might reuse a stream id on # restart, and the above update might not get propagated. That # doesn't sound any worse than the whole update getting lost, # which is what would happen if we combined the two into one # transaction. 
yield self._update_max_stream_id(next_id) self._account_data_stream_cache.entity_has_changed( user_id, next_id, ) self.get_account_data_for_user.invalidate((user_id,)) self.get_global_account_data_by_type_for_user.invalidate( (account_data_type, user_id,) ) result = self._account_data_id_gen.get_current_token() defer.returnValue(result) def _update_max_stream_id(self, next_id): """Update the max stream_id Args: next_id(int): The revision to advance to. """ def _update(txn): update_max_id_sql = ( "UPDATE account_data_max_stream_id" " SET stream_id = ?" " WHERE stream_id < ?" ) txn.execute(update_max_id_sql, (next_id, next_id)) return self.runInteraction( "update_account_data_max_stream_id", _update, )
class AccountDataStore(AccountDataWorkerStore): def __init__(self, database: DatabasePool, db_conn, hs): self._account_data_id_gen = StreamIdGenerator( db_conn, "account_data_max_stream_id", "stream_id", extra_tables=[ ("room_account_data", "stream_id"), ("room_tags_revisions", "stream_id"), ], ) super(AccountDataStore, self).__init__(database, db_conn, hs) def get_max_account_data_stream_id(self) -> int: """Get the current max stream id for the private user data stream Returns: The maximum stream ID. """ return self._account_data_id_gen.get_current_token() async def add_account_data_to_room(self, user_id: str, room_id: str, account_data_type: str, content: JsonDict) -> int: """Add some account_data to a room for a user. Args: user_id: The user to add a tag for. room_id: The room to add a tag for. account_data_type: The type of account_data to add. content: A json object to associate with the tag. Returns: The maximum stream ID. """ content_json = json_encoder.encode(content) with self._account_data_id_gen.get_next() as next_id: # no need to lock here as room_account_data has a unique constraint # on (user_id, room_id, account_data_type) so simple_upsert will # retry if there is a conflict. await self.db_pool.simple_upsert( desc="add_room_account_data", table="room_account_data", keyvalues={ "user_id": user_id, "room_id": room_id, "account_data_type": account_data_type, }, values={ "stream_id": next_id, "content": content_json }, lock=False, ) # it's theoretically possible for the above to succeed and the # below to fail - in which case we might reuse a stream id on # restart, and the above update might not get propagated. That # doesn't sound any worse than the whole update getting lost, # which is what would happen if we combined the two into one # transaction. await self._update_max_stream_id(next_id) self._account_data_stream_cache.entity_has_changed( user_id, next_id) self.get_account_data_for_user.invalidate((user_id, )) self.get_account_data_for_room.invalidate((user_id, room_id)) self.get_account_data_for_room_and_type.prefill( (user_id, room_id, account_data_type), content) return self._account_data_id_gen.get_current_token() async def add_account_data_for_user(self, user_id: str, account_data_type: str, content: JsonDict) -> int: """Add some account_data to a room for a user. Args: user_id: The user to add a tag for. account_data_type: The type of account_data to add. content: A json object to associate with the tag. Returns: The maximum stream ID. """ content_json = json_encoder.encode(content) with self._account_data_id_gen.get_next() as next_id: # no need to lock here as account_data has a unique constraint on # (user_id, account_data_type) so simple_upsert will retry if # there is a conflict. await self.db_pool.simple_upsert( desc="add_user_account_data", table="account_data", keyvalues={ "user_id": user_id, "account_data_type": account_data_type }, values={ "stream_id": next_id, "content": content_json }, lock=False, ) # it's theoretically possible for the above to succeed and the # below to fail - in which case we might reuse a stream id on # restart, and the above update might not get propagated. That # doesn't sound any worse than the whole update getting lost, # which is what would happen if we combined the two into one # transaction. # # Note: This is only here for backwards compat to allow admins to # roll back to a previous Synapse version. Next time we update the # database version we can remove this table. 
await self._update_max_stream_id(next_id) self._account_data_stream_cache.entity_has_changed( user_id, next_id) self.get_account_data_for_user.invalidate((user_id, )) self.get_global_account_data_by_type_for_user.invalidate( (account_data_type, user_id)) return self._account_data_id_gen.get_current_token() def _update_max_stream_id(self, next_id: int): """Update the max stream_id Args: next_id: The revision to advance to. """ # Note: This is only here for backwards compat to allow admins to # roll back to a previous Synapse version. Next time we update the # database version we can remove this table. def _update(txn): update_max_id_sql = ("UPDATE account_data_max_stream_id" " SET stream_id = ?" " WHERE stream_id < ?") txn.execute(update_max_id_sql, (next_id, next_id)) return self.db_pool.runInteraction("update_account_data_max_stream_id", _update)
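# Usage sketch (illustrative only). Writing per-room and global account data through
# the store above. The identifiers, data types and contents are made-up examples; the
# method calls and their return value (the new maximum account data stream ID) follow
# the definitions above.
from typing import Any

async def save_example_account_data(store: Any) -> int:
    user_id = "@alice:example.com"   # assumed example user
    room_id = "!room:example.com"    # assumed example room

    # Per-room account data, e.g. a client-specific setting scoped to one room.
    await store.add_account_data_to_room(
        user_id, room_id, "org.example.room_setting", {"colour": "blue"}
    )

    # Global (room-independent) account data for the same user.
    max_stream_id = await store.add_account_data_for_user(
        user_id, "org.example.global_setting", {"enabled": True}
    )
    return max_stream_id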
def __init__(self, db_conn, hs): self._account_data_id_gen = StreamIdGenerator( db_conn, "account_data_max_stream_id", "stream_id" ) super(AccountDataStore, self).__init__(db_conn, hs)
class AccountDataStore(AccountDataWorkerStore): def __init__(self, db_conn, hs): self._account_data_id_gen = StreamIdGenerator( db_conn, "account_data_max_stream_id", "stream_id" ) super(AccountDataStore, self).__init__(db_conn, hs) def get_max_account_data_stream_id(self): """Get the current max stream id for the private user data stream Returns: A deferred int. """ return self._account_data_id_gen.get_current_token() @defer.inlineCallbacks def add_account_data_to_room(self, user_id, room_id, account_data_type, content): """Add some account_data to a room for a user. Args: user_id(str): The user to add a tag for. room_id(str): The room to add a tag for. account_data_type(str): The type of account_data to add. content(dict): A json object to associate with the tag. Returns: A deferred that completes once the account_data has been added. """ content_json = json.dumps(content) with self._account_data_id_gen.get_next() as next_id: # no need to lock here as room_account_data has a unique constraint # on (user_id, room_id, account_data_type) so _simple_upsert will # retry if there is a conflict. yield self._simple_upsert( desc="add_room_account_data", table="room_account_data", keyvalues={ "user_id": user_id, "room_id": room_id, "account_data_type": account_data_type, }, values={"stream_id": next_id, "content": content_json}, lock=False, ) # it's theoretically possible for the above to succeed and the # below to fail - in which case we might reuse a stream id on # restart, and the above update might not get propagated. That # doesn't sound any worse than the whole update getting lost, # which is what would happen if we combined the two into one # transaction. yield self._update_max_stream_id(next_id) self._account_data_stream_cache.entity_has_changed(user_id, next_id) self.get_account_data_for_user.invalidate((user_id,)) self.get_account_data_for_room.invalidate((user_id, room_id)) self.get_account_data_for_room_and_type.prefill( (user_id, room_id, account_data_type), content ) result = self._account_data_id_gen.get_current_token() return result @defer.inlineCallbacks def add_account_data_for_user(self, user_id, account_data_type, content): """Add some account_data to a room for a user. Args: user_id(str): The user to add a tag for. account_data_type(str): The type of account_data to add. content(dict): A json object to associate with the tag. Returns: A deferred that completes once the account_data has been added. """ content_json = json.dumps(content) with self._account_data_id_gen.get_next() as next_id: # no need to lock here as account_data has a unique constraint on # (user_id, account_data_type) so _simple_upsert will retry if # there is a conflict. yield self._simple_upsert( desc="add_user_account_data", table="account_data", keyvalues={"user_id": user_id, "account_data_type": account_data_type}, values={"stream_id": next_id, "content": content_json}, lock=False, ) # it's theoretically possible for the above to succeed and the # below to fail - in which case we might reuse a stream id on # restart, and the above update might not get propagated. That # doesn't sound any worse than the whole update getting lost, # which is what would happen if we combined the two into one # transaction. 
yield self._update_max_stream_id(next_id) self._account_data_stream_cache.entity_has_changed(user_id, next_id) self.get_account_data_for_user.invalidate((user_id,)) self.get_global_account_data_by_type_for_user.invalidate( (account_data_type, user_id) ) result = self._account_data_id_gen.get_current_token() return result def _update_max_stream_id(self, next_id): """Update the max stream_id Args: next_id(int): The revision to advance to. """ def _update(txn): update_max_id_sql = ( "UPDATE account_data_max_stream_id" " SET stream_id = ?" " WHERE stream_id < ?" ) txn.execute(update_max_id_sql, (next_id, next_id)) return self.runInteraction("update_account_data_max_stream_id", _update)
def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) self._instance_name = hs.get_instance_name() # Map of (user_id, device_id) to the last stream_id that has been # deleted up to. This is so that we can no op deletions. self._last_device_delete_cache = ExpiringCache( cache_name="last_device_delete_cache", clock=self._clock, max_len=10000, expiry_ms=30 * 60 * 1000, ) if isinstance(database.engine, PostgresEngine): self._can_write_to_device = (self._instance_name in hs.config.worker.writers.to_device) self._device_inbox_id_gen = MultiWriterIdGenerator( db_conn=db_conn, db=database, stream_name="to_device", instance_name=self._instance_name, tables=[("device_inbox", "instance_name", "stream_id")], sequence_name="device_inbox_sequence", writers=hs.config.worker.writers.to_device, ) else: self._can_write_to_device = True self._device_inbox_id_gen = StreamIdGenerator( db_conn, "device_inbox", "stream_id") max_device_inbox_id = self._device_inbox_id_gen.get_current_token() device_inbox_prefill, min_device_inbox_id = self.db_pool.get_cache_dict( db_conn, "device_inbox", entity_column="user_id", stream_column="stream_id", max_value=max_device_inbox_id, limit=1000, ) self._device_inbox_stream_cache = StreamChangeCache( "DeviceInboxStreamChangeCache", min_device_inbox_id, prefilled_cache=device_inbox_prefill, ) # The federation outbox and the local device inbox uses the same # stream_id generator. device_outbox_prefill, min_device_outbox_id = self.db_pool.get_cache_dict( db_conn, "device_federation_outbox", entity_column="destination", stream_column="stream_id", max_value=max_device_inbox_id, limit=1000, ) self._device_federation_outbox_stream_cache = StreamChangeCache( "DeviceFederationOutboxStreamChangeCache", min_device_outbox_id, prefilled_cache=device_outbox_prefill, )
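# Illustrative sketch (not from the source). A much-simplified stand-in for the
# StreamChangeCache instances built above, showing why the constructor prefetches a
# {entity: latest_stream_id} dict with get_cache_dict: the cache can then answer
# "has this entity changed since position X?" without a database query, which is how
# the delete and read paths in this store short-circuit.
from typing import Dict

class TinyStreamChangeCache:
    def __init__(self, earliest_known_stream_pos: int, prefilled: Dict[str, int]) -> None:
        # Positions older than this are unknown, so we must assume a change happened.
        self._earliest = earliest_known_stream_pos
        self._latest_change: Dict[str, int] = dict(prefilled)

    def entity_has_changed(self, entity: str, stream_pos: int) -> None:
        # Record that `entity` changed at `stream_pos`, keeping only the newest position.
        if stream_pos > self._latest_change.get(entity, 0):
            self._latest_change[entity] = stream_pos

    def has_entity_changed(self, entity: str, stream_pos: int) -> bool:
        # Callers asking about positions before our horizon must assume a change.
        if stream_pos < self._earliest:
            return True
        return self._latest_change.get(entity, self._earliest) > stream_pos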
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): def __init__( self, database: DatabasePool, db_conn: LoggingDatabaseConnection, hs: "HomeServer", ): super().__init__(database, db_conn, hs) self._cross_signing_id_gen = StreamIdGenerator( db_conn, "e2e_cross_signing_keys", "stream_id") async def set_e2e_device_keys(self, user_id: str, device_id: str, time_now: int, device_keys: JsonDict) -> bool: """Stores device keys for a device. Returns whether there was a change or the keys were already in the database. """ def _set_e2e_device_keys_txn(txn: LoggingTransaction) -> bool: set_tag("user_id", user_id) set_tag("device_id", device_id) set_tag("time_now", time_now) set_tag("device_keys", device_keys) old_key_json = self.db_pool.simple_select_one_onecol_txn( txn, table="e2e_device_keys_json", keyvalues={ "user_id": user_id, "device_id": device_id }, retcol="key_json", allow_none=True, ) # In py3 we need old_key_json to match new_key_json type. The DB # returns unicode while encode_canonical_json returns bytes. new_key_json = encode_canonical_json(device_keys).decode("utf-8") if old_key_json == new_key_json: log_kv({"Message": "Device key already stored."}) return False self.db_pool.simple_upsert_txn( txn, table="e2e_device_keys_json", keyvalues={ "user_id": user_id, "device_id": device_id }, values={ "ts_added_ms": time_now, "key_json": new_key_json }, ) log_kv({"message": "Device keys stored."}) return True return await self.db_pool.runInteraction("set_e2e_device_keys", _set_e2e_device_keys_txn) async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None: def delete_e2e_keys_by_device_txn(txn: LoggingTransaction) -> None: log_kv({ "message": "Deleting keys for device", "device_id": device_id, "user_id": user_id, }) self.db_pool.simple_delete_txn( txn, table="e2e_device_keys_json", keyvalues={ "user_id": user_id, "device_id": device_id }, ) self.db_pool.simple_delete_txn( txn, table="e2e_one_time_keys_json", keyvalues={ "user_id": user_id, "device_id": device_id }, ) self._invalidate_cache_and_stream(txn, self.count_e2e_one_time_keys, (user_id, device_id)) self.db_pool.simple_delete_txn( txn, table="dehydrated_devices", keyvalues={ "user_id": user_id, "device_id": device_id }, ) self.db_pool.simple_delete_txn( txn, table="e2e_fallback_keys_json", keyvalues={ "user_id": user_id, "device_id": device_id }, ) self._invalidate_cache_and_stream( txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id)) await self.db_pool.runInteraction("delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn) def _set_e2e_cross_signing_key_txn( self, txn: LoggingTransaction, user_id: str, key_type: str, key: JsonDict, stream_id: int, ) -> None: """Set a user's cross-signing key. 
Args: txn: db connection user_id: the user to set the signing key for key_type: the type of key that is being set: either 'master' for a master key, 'self_signing' for a self-signing key, or 'user_signing' for a user-signing key key: the key data stream_id """ # the 'key' dict will look something like: # { # "user_id": "@alice:example.com", # "usage": ["self_signing"], # "keys": { # "ed25519:base64+self+signing+public+key": "base64+self+signing+public+key", # }, # "signatures": { # "@alice:example.com": { # "ed25519:base64+master+public+key": "base64+signature" # } # } # } # The "keys" property must only have one entry, which will be the public # key, so we just grab the first value in there pubkey = next(iter(key["keys"].values())) # The cross-signing keys need to occupy the same namespace as devices, # since signatures are identified by device ID. So add an entry to the # device table to make sure that we don't have a collision with device # IDs. # We only need to do this for local users, since remote servers should be # responsible for checking this for their own users. if self.hs.is_mine_id(user_id): self.db_pool.simple_insert_txn( txn, "devices", values={ "user_id": user_id, "device_id": pubkey, "display_name": key_type + " signing key", "hidden": True, }, ) # and finally, store the key itself self.db_pool.simple_insert_txn( txn, "e2e_cross_signing_keys", values={ "user_id": user_id, "keytype": key_type, "keydata": json_encoder.encode(key), "stream_id": stream_id, }, ) self._invalidate_cache_and_stream( txn, self._get_bare_e2e_cross_signing_keys, (user_id, )) async def set_e2e_cross_signing_key(self, user_id: str, key_type: str, key: JsonDict) -> None: """Set a user's cross-signing key. Args: user_id: the user to set the user-signing key for key_type: the type of cross-signing key to set key: the key data """ async with self._cross_signing_id_gen.get_next() as stream_id: return await self.db_pool.runInteraction( "add_e2e_cross_signing_key", self._set_e2e_cross_signing_key_txn, user_id, key_type, key, stream_id, ) async def store_e2e_cross_signing_signatures( self, user_id: str, signatures: "Iterable[SignatureListItem]") -> None: """Stores cross-signing signatures. Args: user_id: the user who made the signatures signatures: signatures to add """ await self.db_pool.simple_insert_many( "e2e_cross_signing_signatures", keys=( "user_id", "key_id", "target_user_id", "target_device_id", "signature", ), values=[( user_id, item.signing_key_id, item.target_user_id, item.target_device_id, item.signature, ) for item in signatures], desc="add_e2e_signing_key", )
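# Usage sketch (illustrative only). Storing a self-signing cross-signing key via the
# store above. The key material is fake placeholder data shaped like the example in
# _set_e2e_cross_signing_key_txn's comment; only the method call itself reflects the
# API defined above.
from typing import Any

async def store_example_self_signing_key(store: Any) -> None:
    user_id = "@alice:example.com"  # assumed example user
    key = {
        "user_id": user_id,
        "usage": ["self_signing"],
        "keys": {
            # The single entry's value is the public key; the txn also uses it as the
            # hidden device_id so signatures can reference it.
            "ed25519:base64+self+signing+public+key": "base64+self+signing+public+key",
        },
        "signatures": {
            user_id: {"ed25519:base64+master+public+key": "base64+signature"},
        },
    }
    await store.set_e2e_cross_signing_key(user_id, "self_signing", key)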
class DeviceInboxWorkerStore(SQLBaseStore): def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) self._instance_name = hs.get_instance_name() # Map of (user_id, device_id) to the last stream_id that has been # deleted up to. This is so that we can no op deletions. self._last_device_delete_cache = ExpiringCache( cache_name="last_device_delete_cache", clock=self._clock, max_len=10000, expiry_ms=30 * 60 * 1000, ) if isinstance(database.engine, PostgresEngine): self._can_write_to_device = (self._instance_name in hs.config.worker.writers.to_device) self._device_inbox_id_gen = MultiWriterIdGenerator( db_conn=db_conn, db=database, stream_name="to_device", instance_name=self._instance_name, tables=[("device_inbox", "instance_name", "stream_id")], sequence_name="device_inbox_sequence", writers=hs.config.worker.writers.to_device, ) else: self._can_write_to_device = True self._device_inbox_id_gen = StreamIdGenerator( db_conn, "device_inbox", "stream_id") max_device_inbox_id = self._device_inbox_id_gen.get_current_token() device_inbox_prefill, min_device_inbox_id = self.db_pool.get_cache_dict( db_conn, "device_inbox", entity_column="user_id", stream_column="stream_id", max_value=max_device_inbox_id, limit=1000, ) self._device_inbox_stream_cache = StreamChangeCache( "DeviceInboxStreamChangeCache", min_device_inbox_id, prefilled_cache=device_inbox_prefill, ) # The federation outbox and the local device inbox uses the same # stream_id generator. device_outbox_prefill, min_device_outbox_id = self.db_pool.get_cache_dict( db_conn, "device_federation_outbox", entity_column="destination", stream_column="stream_id", max_value=max_device_inbox_id, limit=1000, ) self._device_federation_outbox_stream_cache = StreamChangeCache( "DeviceFederationOutboxStreamChangeCache", min_device_outbox_id, prefilled_cache=device_outbox_prefill, ) def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == ToDeviceStream.NAME: self._device_inbox_id_gen.advance(instance_name, token) for row in rows: if row.entity.startswith("@"): self._device_inbox_stream_cache.entity_has_changed( row.entity, token) else: self._device_federation_outbox_stream_cache.entity_has_changed( row.entity, token) return super().process_replication_rows(stream_name, instance_name, token, rows) def get_to_device_stream_token(self): return self._device_inbox_id_gen.get_current_token() async def get_new_messages_for_device( self, user_id: str, device_id: str, last_stream_id: int, current_stream_id: int, limit: int = 100, ) -> Tuple[List[dict], int]: """ Args: user_id: The recipient user_id. device_id: The recipient device_id. last_stream_id: The last stream ID checked. current_stream_id: The current position of the to device message stream. limit: The maximum number of messages to retrieve. Returns: A list of messages for the device and where in the stream the messages got to. """ has_changed = self._device_inbox_stream_cache.has_entity_changed( user_id, last_stream_id) if not has_changed: return ([], current_stream_id) def get_new_messages_for_device_txn(txn): sql = ("SELECT stream_id, message_json FROM device_inbox" " WHERE user_id = ? AND device_id = ?" " AND ? < stream_id AND stream_id <= ?" 
" ORDER BY stream_id ASC" " LIMIT ?") txn.execute( sql, (user_id, device_id, last_stream_id, current_stream_id, limit)) messages = [] for row in txn: stream_pos = row[0] messages.append(db_to_json(row[1])) if len(messages) < limit: stream_pos = current_stream_id return messages, stream_pos return await self.db_pool.runInteraction( "get_new_messages_for_device", get_new_messages_for_device_txn) @trace async def delete_messages_for_device(self, user_id: str, device_id: str, up_to_stream_id: int) -> int: """ Args: user_id: The recipient user_id. device_id: The recipient device_id. up_to_stream_id: Where to delete messages up to. Returns: The number of messages deleted. """ # If we have cached the last stream id we've deleted up to, we can # check if there is likely to be anything that needs deleting last_deleted_stream_id = self._last_device_delete_cache.get( (user_id, device_id), None) set_tag("last_deleted_stream_id", last_deleted_stream_id) if last_deleted_stream_id: has_changed = self._device_inbox_stream_cache.has_entity_changed( user_id, last_deleted_stream_id) if not has_changed: log_kv({"message": "No changes in cache since last check"}) return 0 def delete_messages_for_device_txn(txn): sql = ("DELETE FROM device_inbox" " WHERE user_id = ? AND device_id = ?" " AND stream_id <= ?") txn.execute(sql, (user_id, device_id, up_to_stream_id)) return txn.rowcount count = await self.db_pool.runInteraction( "delete_messages_for_device", delete_messages_for_device_txn) log_kv({ "message": "deleted {} messages for device".format(count), "count": count }) # Update the cache, ensuring that we only ever increase the value last_deleted_stream_id = self._last_device_delete_cache.get( (user_id, device_id), 0) self._last_device_delete_cache[(user_id, device_id)] = max( last_deleted_stream_id, up_to_stream_id) return count @trace async def get_new_device_msgs_for_remote(self, destination, last_stream_id, current_stream_id, limit) -> Tuple[List[dict], int]: """ Args: destination(str): The name of the remote server. last_stream_id(int|long): The last position of the device message stream that the server sent up to. current_stream_id(int|long): The current position of the device message stream. Returns: A list of messages for the device and where in the stream the messages got to. """ set_tag("destination", destination) set_tag("last_stream_id", last_stream_id) set_tag("current_stream_id", current_stream_id) set_tag("limit", limit) has_changed = self._device_federation_outbox_stream_cache.has_entity_changed( destination, last_stream_id) if not has_changed or last_stream_id == current_stream_id: log_kv({"message": "No new messages in stream"}) return ([], current_stream_id) if limit <= 0: # This can happen if we run out of room for EDUs in the transaction. return ([], last_stream_id) @trace def get_new_messages_for_remote_destination_txn(txn): sql = ( "SELECT stream_id, messages_json FROM device_federation_outbox" " WHERE destination = ?" " AND ? < stream_id AND stream_id <= ?" 
" ORDER BY stream_id ASC" " LIMIT ?") txn.execute( sql, (destination, last_stream_id, current_stream_id, limit)) messages = [] for row in txn: stream_pos = row[0] messages.append(db_to_json(row[1])) if len(messages) < limit: log_kv({"message": "Set stream position to current position"}) stream_pos = current_stream_id return messages, stream_pos return await self.db_pool.runInteraction( "get_new_device_msgs_for_remote", get_new_messages_for_remote_destination_txn, ) @trace async def delete_device_msgs_for_remote(self, destination: str, up_to_stream_id: int) -> None: """Used to delete messages when the remote destination acknowledges their receipt. Args: destination: The destination server_name up_to_stream_id: Where to delete messages up to. """ def delete_messages_for_remote_destination_txn(txn): sql = ("DELETE FROM device_federation_outbox" " WHERE destination = ?" " AND stream_id <= ?") txn.execute(sql, (destination, up_to_stream_id)) await self.db_pool.runInteraction( "delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn) async def get_all_new_device_messages( self, instance_name: str, last_id: int, current_id: int, limit: int) -> Tuple[List[Tuple[int, tuple]], int, bool]: """Get updates for to device replication stream. Args: instance_name: The writer we want to fetch updates from. Unused here since there is only ever one writer. last_id: The token to fetch updates from. Exclusive. current_id: The token to fetch updates up to. Inclusive. limit: The requested limit for the number of rows to return. The function may return more or fewer rows. Returns: A tuple consisting of: the updates, a token to use to fetch subsequent updates, and whether we returned fewer rows than exists between the requested tokens due to the limit. The token returned can be used in a subsequent call to this function to get further updatees. The updates are a list of 2-tuples of stream ID and the row data """ if last_id == current_id: return [], current_id, False def get_all_new_device_messages_txn(txn): # We limit like this as we might have multiple rows per stream_id, and # we want to make sure we always get all entries for any stream_id # we return. upper_pos = min(current_id, last_id + limit) sql = ("SELECT max(stream_id), user_id" " FROM device_inbox" " WHERE ? < stream_id AND stream_id <= ?" " GROUP BY user_id") txn.execute(sql, (last_id, upper_pos)) updates = [(row[0], row[1:]) for row in txn] sql = ("SELECT max(stream_id), destination" " FROM device_federation_outbox" " WHERE ? < stream_id AND stream_id <= ?" " GROUP BY destination") txn.execute(sql, (last_id, upper_pos)) updates.extend((row[0], row[1:]) for row in txn) # Order by ascending stream ordering updates.sort() limited = False upto_token = current_id if len(updates) >= limit: upto_token = updates[-1][0] limited = True return updates, upto_token, limited return await self.db_pool.runInteraction( "get_all_new_device_messages", get_all_new_device_messages_txn) @trace async def add_messages_to_device_inbox( self, local_messages_by_user_then_device: dict, remote_messages_by_destination: dict, ) -> int: """Used to send messages from this server. Args: local_messages_by_user_and_device: Dictionary of user_id to device_id to message. remote_messages_by_destination: Dictionary of destination server_name to the EDU JSON to send. Returns: The new stream_id. """ assert self._can_write_to_device def add_messages_txn(txn, now_ms, stream_id): # Add the local messages directly to the local inbox. 
self._add_messages_to_local_device_inbox_txn( txn, stream_id, local_messages_by_user_then_device) # Add the remote messages to the federation outbox. # We'll send them to a remote server when we next send a # federation transaction to that destination. self.db_pool.simple_insert_many_txn( txn, table="device_federation_outbox", values=[{ "destination": destination, "stream_id": stream_id, "queued_ts": now_ms, "messages_json": json_encoder.encode(edu), "instance_name": self._instance_name, } for destination, edu in remote_messages_by_destination.items()], ) async with self._device_inbox_id_gen.get_next() as stream_id: now_ms = self.clock.time_msec() await self.db_pool.runInteraction("add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id) for user_id in local_messages_by_user_then_device.keys(): self._device_inbox_stream_cache.entity_has_changed( user_id, stream_id) for destination in remote_messages_by_destination.keys(): self._device_federation_outbox_stream_cache.entity_has_changed( destination, stream_id) return self._device_inbox_id_gen.get_current_token() async def add_messages_from_remote_to_device_inbox( self, origin: str, message_id: str, local_messages_by_user_then_device: dict) -> int: assert self._can_write_to_device def add_messages_txn(txn, now_ms, stream_id): # Check if we've already inserted a matching message_id for that # origin. This can happen if the origin doesn't receive our # acknowledgement from the first time we received the message. already_inserted = self.db_pool.simple_select_one_txn( txn, table="device_federation_inbox", keyvalues={ "origin": origin, "message_id": message_id }, retcols=("message_id", ), allow_none=True, ) if already_inserted is not None: return # Add an entry for this message_id so that we know we've processed # it. self.db_pool.simple_insert_txn( txn, table="device_federation_inbox", values={ "origin": origin, "message_id": message_id, "received_ts": now_ms, }, ) # Add the messages to the approriate local device inboxes so that # they'll be sent to the devices when they next sync. self._add_messages_to_local_device_inbox_txn( txn, stream_id, local_messages_by_user_then_device) async with self._device_inbox_id_gen.get_next() as stream_id: now_ms = self.clock.time_msec() await self.db_pool.runInteraction( "add_messages_from_remote_to_device_inbox", add_messages_txn, now_ms, stream_id, ) for user_id in local_messages_by_user_then_device.keys(): self._device_inbox_stream_cache.entity_has_changed( user_id, stream_id) return stream_id def _add_messages_to_local_device_inbox_txn(self, txn, stream_id, messages_by_user_then_device): assert self._can_write_to_device local_by_user_then_device = {} for user_id, messages_by_device in messages_by_user_then_device.items( ): messages_json_for_user = {} devices = list(messages_by_device.keys()) if len(devices) == 1 and devices[0] == "*": # Handle wildcard device_ids. devices = self.db_pool.simple_select_onecol_txn( txn, table="devices", keyvalues={"user_id": user_id}, retcol="device_id", ) message_json = json_encoder.encode(messages_by_device["*"]) for device_id in devices: # Add the message for all devices for this user on this # server. 
messages_json_for_user[device_id] = message_json else: if not devices: continue rows = self.db_pool.simple_select_many_txn( txn, table="devices", keyvalues={"user_id": user_id}, column="device_id", iterable=devices, retcols=("device_id", ), ) for row in rows: # Only insert into the local inbox if the device exists on # this server device_id = row["device_id"] message_json = json_encoder.encode( messages_by_device[device_id]) messages_json_for_user[device_id] = message_json if messages_json_for_user: local_by_user_then_device[user_id] = messages_json_for_user if not local_by_user_then_device: return self.db_pool.simple_insert_many_txn( txn, table="device_inbox", values=[{ "user_id": user_id, "device_id": device_id, "stream_id": stream_id, "message_json": message_json, "instance_name": self._instance_name, } for user_id, messages_by_device in local_by_user_then_device.items() for device_id, message_json in messages_by_device.items()], )
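# Usage sketch (illustrative only). The nested dict shape expected by
# add_messages_to_device_inbox above, including the "*" wildcard that
# _add_messages_to_local_device_inbox_txn expands to every device of the target user.
# User ids, device ids, the destination and message contents are made-up examples.
from typing import Any

async def send_example_to_device_messages(store: Any) -> int:
    local_messages = {
        "@alice:example.com": {
            # "*" fans the message out to all of alice's devices on this server.
            "*": {"example_key": "sent to every device"},
        },
        "@bob:example.com": {
            "BOBSDEVICE": {"example_key": "sent to one device"},
        },
    }
    remote_messages = {
        # Queued in device_federation_outbox until the next transaction to that server.
        "remote.example.com": {"example": "EDU JSON"},
    }
    return await store.add_messages_to_device_inbox(local_messages, remote_messages)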
def __init__( self, database: DatabasePool, db_conn: LoggingDatabaseConnection, hs: "HomeServer", ): self.hs = hs self._clock = hs.get_clock() self.database_engine = database.engine self._device_list_id_gen = StreamIdGenerator( db_conn, "device_lists_stream", "stream_id", extra_tables=[ ("user_signature_stream", "stream_id"), ("device_lists_outbound_pokes", "stream_id"), ("device_lists_changes_in_room", "stream_id"), ], ) self._cache_id_gen: Optional[MultiWriterIdGenerator] if isinstance(self.database_engine, PostgresEngine): # We set the `writers` to an empty list here as we don't care about # missing updates over restarts, as we'll not have anything in our # caches to invalidate. (This reduces the amount of writes to the DB # that happen). self._cache_id_gen = MultiWriterIdGenerator( db_conn, database, stream_name="caches", instance_name=hs.get_instance_name(), tables=[( "cache_invalidation_stream_by_instance", "instance_name", "stream_id", )], sequence_name="cache_invalidation_stream_seq", writers=[], ) else: self._cache_id_gen = None super().__init__(database, db_conn, hs) events_max = self._stream_id_gen.get_current_token() curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict( db_conn, "current_state_delta_stream", entity_column="room_id", stream_column="stream_id", max_value=events_max, # As we share the stream id with events token limit=1000, ) self._curr_state_delta_stream_cache = StreamChangeCache( "_curr_state_delta_stream_cache", min_curr_state_delta_id, prefilled_cache=curr_state_delta_prefill, ) self._stream_order_on_start = self.get_room_max_stream_ordering() self._min_stream_order_on_start = self.get_room_min_stream_ordering()
class PresenceStore(PresenceBackgroundUpdateStore): def __init__( self, database: DatabasePool, db_conn: LoggingDatabaseConnection, hs: "HomeServer", ): super().__init__(database, db_conn, hs) self._can_persist_presence = (hs.get_instance_name() in hs.config.worker.writers.presence) if isinstance(database.engine, PostgresEngine): self._presence_id_gen = MultiWriterIdGenerator( db_conn=db_conn, db=database, stream_name="presence_stream", instance_name=self._instance_name, tables=[("presence_stream", "instance_name", "stream_id")], sequence_name="presence_stream_sequence", writers=hs.config.worker.writers.presence, ) else: self._presence_id_gen = StreamIdGenerator(db_conn, "presence_stream", "stream_id") self.hs = hs self._presence_on_startup = self._get_active_presence(db_conn) presence_cache_prefill, min_presence_val = self.db_pool.get_cache_dict( db_conn, "presence_stream", entity_column="user_id", stream_column="stream_id", max_value=self._presence_id_gen.get_current_token(), ) self.presence_stream_cache = StreamChangeCache( "PresenceStreamChangeCache", min_presence_val, prefilled_cache=presence_cache_prefill, ) async def update_presence(self, presence_states) -> Tuple[int, int]: assert self._can_persist_presence stream_ordering_manager = self._presence_id_gen.get_next_mult( len(presence_states)) async with stream_ordering_manager as stream_orderings: await self.db_pool.runInteraction( "update_presence", self._update_presence_txn, stream_orderings, presence_states, ) return stream_orderings[-1], self._presence_id_gen.get_current_token() def _update_presence_txn(self, txn, stream_orderings, presence_states): for stream_id, state in zip(stream_orderings, presence_states): txn.call_after(self.presence_stream_cache.entity_has_changed, state.user_id, stream_id) txn.call_after(self._get_presence_for_user.invalidate, (state.user_id, )) # Delete old rows to stop database from getting really big sql = "DELETE FROM presence_stream WHERE stream_id < ? AND " for states in batch_iter(presence_states, 50): clause, args = make_in_list_sql_clause(self.database_engine, "user_id", [s.user_id for s in states]) txn.execute(sql + clause, [stream_id] + list(args)) # Actually insert new rows self.db_pool.simple_insert_many_txn( txn, table="presence_stream", keys=( "stream_id", "user_id", "state", "last_active_ts", "last_federation_update_ts", "last_user_sync_ts", "status_msg", "currently_active", "instance_name", ), values=[( stream_id, state.user_id, state.state, state.last_active_ts, state.last_federation_update_ts, state.last_user_sync_ts, state.status_msg, state.currently_active, self._instance_name, ) for stream_id, state in zip(stream_orderings, presence_states)], ) async def get_all_presence_updates( self, instance_name: str, last_id: int, current_id: int, limit: int) -> Tuple[List[Tuple[int, list]], int, bool]: """Get updates for presence replication stream. Args: instance_name: The writer we want to fetch updates from. Unused here since there is only ever one writer. last_id: The token to fetch updates from. Exclusive. current_id: The token to fetch updates up to. Inclusive. limit: The requested limit for the number of rows to return. The function may return more or fewer rows. Returns: A tuple consisting of: the updates, a token to use to fetch subsequent updates, and whether we returned fewer rows than exists between the requested tokens due to the limit. The token returned can be used in a subsequent call to this function to get further updatees. 
The updates are a list of 2-tuples of stream ID and the row data """ if last_id == current_id: return [], current_id, False def get_all_presence_updates_txn(txn): sql = """ SELECT stream_id, user_id, state, last_active_ts, last_federation_update_ts, last_user_sync_ts, status_msg, currently_active FROM presence_stream WHERE ? < stream_id AND stream_id <= ? ORDER BY stream_id ASC LIMIT ? """ txn.execute(sql, (last_id, current_id, limit)) updates = [(row[0], row[1:]) for row in txn] upper_bound = current_id limited = False if len(updates) >= limit: upper_bound = updates[-1][0] limited = True return updates, upper_bound, limited return await self.db_pool.runInteraction("get_all_presence_updates", get_all_presence_updates_txn) @cached() def _get_presence_for_user(self, user_id): raise NotImplementedError() @cachedList( cached_method_name="_get_presence_for_user", list_name="user_ids", num_args=1, ) async def get_presence_for_users(self, user_ids): rows = await self.db_pool.simple_select_many_batch( table="presence_stream", column="user_id", iterable=user_ids, keyvalues={}, retcols=( "user_id", "state", "last_active_ts", "last_federation_update_ts", "last_user_sync_ts", "status_msg", "currently_active", ), desc="get_presence_for_users", ) for row in rows: row["currently_active"] = bool(row["currently_active"]) return {row["user_id"]: UserPresenceState(**row) for row in rows} async def should_user_receive_full_presence_with_token( self, user_id: str, from_token: int, ) -> bool: """Check whether the given user should receive full presence using the stream token they're updating from. Args: user_id: The ID of the user to check. from_token: The stream token included in their /sync token. Returns: True if the user should have full presence sent to them, False otherwise. """ def _should_user_receive_full_presence_with_token_txn(txn): sql = """ SELECT 1 FROM users_to_send_full_presence_to WHERE user_id = ? AND presence_stream_id >= ? """ txn.execute(sql, (user_id, from_token)) return bool(txn.fetchone()) return await self.db_pool.runInteraction( "should_user_receive_full_presence_with_token", _should_user_receive_full_presence_with_token_txn, ) async def add_users_to_send_full_presence_to(self, user_ids: Iterable[str]): """Adds to the list of users who should receive a full snapshot of presence upon their next sync. Args: user_ids: An iterable of user IDs. """ # Add user entries to the table, updating the presence_stream_id column if the user already # exists in the table. presence_stream_id = self._presence_id_gen.get_current_token() await self.db_pool.simple_upsert_many( table="users_to_send_full_presence_to", key_names=("user_id", ), key_values=[(user_id, ) for user_id in user_ids], value_names=("presence_stream_id", ), # We save the current presence stream ID token along with the user ID entry so # that when a user /sync's, even if they syncing multiple times across separate # devices at different times, each device will receive full presence once - when # the presence stream ID in their sync token is less than the one in the table # for their user ID. value_values=[(presence_stream_id, ) for _ in user_ids], desc="add_users_to_send_full_presence_to", ) async def get_presence_for_all_users( self, include_offline: bool = True, ) -> Dict[str, UserPresenceState]: """Retrieve the current presence state for all users. Note that the presence_stream table is culled frequently, so it should only contain the latest presence state for each user. 
Args: include_offline: Whether to include offline presence states Returns: A dict of user IDs to their current UserPresenceState. """ users_to_state = {} exclude_keyvalues = None if not include_offline: # Exclude offline presence state exclude_keyvalues = {"state": "offline"} # This may be a very heavy database query. # We paginate in order to not block a database connection. limit = 100 offset = 0 while True: rows = await self.db_pool.runInteraction( "get_presence_for_all_users", self.db_pool.simple_select_list_paginate_txn, "presence_stream", orderby="stream_id", start=offset, limit=limit, exclude_keyvalues=exclude_keyvalues, retcols=( "user_id", "state", "last_active_ts", "last_federation_update_ts", "last_user_sync_ts", "status_msg", "currently_active", ), order_direction="ASC", ) for row in rows: users_to_state[row["user_id"]] = UserPresenceState(**row) # We've run out of updates to query if len(rows) < limit: break offset += limit return users_to_state def get_current_presence_token(self): return self._presence_id_gen.get_current_token() def _get_active_presence(self, db_conn: Connection): """Fetch non-offline presence from the database so that we can register the appropriate time outs. """ # The `presence_stream_state_not_offline_idx` index should be used for this # query. sql = ( "SELECT user_id, state, last_active_ts, last_federation_update_ts," " last_user_sync_ts, status_msg, currently_active FROM presence_stream" " WHERE state != ?") txn = db_conn.cursor() txn.execute(sql, (PresenceState.OFFLINE, )) rows = self.db_pool.cursor_to_dict(txn) txn.close() for row in rows: row["currently_active"] = bool(row["currently_active"]) return [UserPresenceState(**row) for row in rows] def take_presence_startup_info(self): active_on_startup = self._presence_on_startup self._presence_on_startup = None return active_on_startup def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == PresenceStream.NAME: self._presence_id_gen.advance(instance_name, token) for row in rows: self.presence_stream_cache.entity_has_changed( row.user_id, token) self._get_presence_for_user.invalidate((row.user_id, )) return super().process_replication_rows(stream_name, instance_name, token, rows)
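# Usage sketch (illustrative only). How a presence handler might use the two
# "full presence" methods above. The helper names and call sites are assumptions; the
# store calls and their semantics (record the current presence stream token, then
# compare it against the token in a /sync request) follow the definitions above.
from typing import Any, Iterable

async def mark_users_for_full_presence(store: Any, user_ids: Iterable[str]) -> None:
    # Records the current presence stream token for each user; they will receive a
    # full presence snapshot until they sync past that token.
    await store.add_users_to_send_full_presence_to(user_ids)

async def needs_full_presence(store: Any, user_id: str, sync_presence_token: int) -> bool:
    # True while the presence token in the user's sync request is still at or behind
    # the token recorded for them.
    return await store.should_user_receive_full_presence_with_token(
        user_id, sync_presence_token
    )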
class DataStore( EventsBackgroundUpdatesStore, RoomMemberStore, RoomStore, RoomBatchStore, RegistrationStore, StreamWorkerStore, ProfileStore, PresenceStore, TransactionWorkerStore, DirectoryStore, KeyStore, StateStore, SignatureStore, ApplicationServiceStore, PurgeEventsStore, EventFederationStore, MediaRepositoryStore, RejectionsStore, FilteringStore, PusherStore, PushRuleStore, ApplicationServiceTransactionStore, ReceiptsStore, EndToEndKeyStore, EndToEndRoomKeyStore, SearchStore, TagsStore, AccountDataStore, EventPushActionsStore, OpenIdStore, ClientIpWorkerStore, DeviceStore, DeviceInboxStore, UserDirectoryStore, GroupServerStore, UserErasureStore, MonthlyActiveUsersWorkerStore, StatsStore, RelationsStore, CensorEventsStore, UIAuthStore, EventForwardExtremitiesStore, CacheInvalidationWorkerStore, ServerMetricsStore, LockStore, SessionStore, ): def __init__( self, database: DatabasePool, db_conn: LoggingDatabaseConnection, hs: "HomeServer", ): self.hs = hs self._clock = hs.get_clock() self.database_engine = database.engine self._device_list_id_gen = StreamIdGenerator( db_conn, "device_lists_stream", "stream_id", extra_tables=[ ("user_signature_stream", "stream_id"), ("device_lists_outbound_pokes", "stream_id"), ("device_lists_changes_in_room", "stream_id"), ], ) self._cache_id_gen: Optional[MultiWriterIdGenerator] if isinstance(self.database_engine, PostgresEngine): # We set the `writers` to an empty list here as we don't care about # missing updates over restarts, as we'll not have anything in our # caches to invalidate. (This reduces the amount of writes to the DB # that happen). self._cache_id_gen = MultiWriterIdGenerator( db_conn, database, stream_name="caches", instance_name=hs.get_instance_name(), tables=[( "cache_invalidation_stream_by_instance", "instance_name", "stream_id", )], sequence_name="cache_invalidation_stream_seq", writers=[], ) else: self._cache_id_gen = None super().__init__(database, db_conn, hs) events_max = self._stream_id_gen.get_current_token() curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict( db_conn, "current_state_delta_stream", entity_column="room_id", stream_column="stream_id", max_value=events_max, # As we share the stream id with events token limit=1000, ) self._curr_state_delta_stream_cache = StreamChangeCache( "_curr_state_delta_stream_cache", min_curr_state_delta_id, prefilled_cache=curr_state_delta_prefill, ) self._stream_order_on_start = self.get_room_max_stream_ordering() self._min_stream_order_on_start = self.get_room_min_stream_ordering() def get_device_stream_token(self) -> int: return self._device_list_id_gen.get_current_token() async def get_users(self) -> List[JsonDict]: """Function to retrieve a list of users in users table. Returns: A list of dictionaries representing users. """ return await self.db_pool.simple_select_list( table="users", keyvalues={}, retcols=[ "name", "password_hash", "is_guest", "admin", "user_type", "deactivated", ], desc="get_users", ) async def get_users_paginate( self, start: int, limit: int, user_id: Optional[str] = None, name: Optional[str] = None, guests: bool = True, deactivated: bool = False, order_by: str = UserSortOrder.USER_ID.value, direction: str = "f", ) -> Tuple[List[JsonDict], int]: """Function to retrieve a paginated list of users from users list. This will return a json list of users and the total number of users matching the filter criteria. Args: start: start number to begin the query from limit: number of rows to retrieve user_id: search for user_id. 
ignored if name is not None name: search for local part of user_id or display name guests: whether to in include guest users deactivated: whether to include deactivated users order_by: the sort order of the returned list direction: sort ascending or descending Returns: A tuple of a list of mappings from user to information and a count of total users. """ def get_users_paginate_txn( txn: LoggingTransaction, ) -> Tuple[List[JsonDict], int]: filters = [] args = [self.hs.config.server.server_name] # Set ordering order_by_column = UserSortOrder(order_by).value if direction == "b": order = "DESC" else: order = "ASC" # `name` is in database already in lower case if name: filters.append("(name LIKE ? OR LOWER(displayname) LIKE ?)") args.extend( ["@%" + name.lower() + "%:%", "%" + name.lower() + "%"]) elif user_id: filters.append("name LIKE ?") args.extend(["%" + user_id.lower() + "%"]) if not guests: filters.append("is_guest = 0") if not deactivated: filters.append("deactivated = 0") where_clause = "WHERE " + " AND ".join(filters) if len( filters) > 0 else "" sql_base = f""" FROM users as u LEFT JOIN profiles AS p ON u.name = '@' || p.user_id || ':' || ? {where_clause} """ sql = "SELECT COUNT(*) as total_users " + sql_base txn.execute(sql, args) count = cast(Tuple[int], txn.fetchone())[0] sql = f""" SELECT name, user_type, is_guest, admin, deactivated, shadow_banned, displayname, avatar_url, creation_ts * 1000 as creation_ts {sql_base} ORDER BY {order_by_column} {order}, u.name ASC LIMIT ? OFFSET ? """ args += [limit, start] txn.execute(sql, args) users = self.db_pool.cursor_to_dict(txn) return users, count return await self.db_pool.runInteraction("get_users_paginate_txn", get_users_paginate_txn) async def search_users(self, term: str) -> Optional[List[JsonDict]]: """Function to search users list for one or more users with the matched term. Args: term: search term Returns: A list of dictionaries or None. """ return await self.db_pool.simple_search_list( table="users", term=term, col="name", retcols=[ "name", "password_hash", "is_guest", "admin", "user_type" ], desc="search_users", )
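# Usage sketch (illustrative only). Paging through the user list with
# get_users_paginate above. The page size and search term are arbitrary; the keyword
# arguments and the (rows, total_count) return shape are as defined above.
from typing import Any, Dict, List

async def collect_matching_users(store: Any, search_term: str) -> List[Dict[str, Any]]:
    page_size = 100
    start = 0
    found: List[Dict[str, Any]] = []
    while True:
        rows, total = await store.get_users_paginate(
            start, page_size, name=search_term, guests=False, deactivated=False
        )
        found.extend(rows)
        start += page_size
        if not rows or start >= total:
            break
    return found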
class ReceiptsStore(ReceiptsWorkerStore): def __init__(self, database: Database, db_conn, hs): # We instantiate this first as the ReceiptsWorkerStore constructor # needs to be able to call get_max_receipt_stream_id self._receipts_id_gen = StreamIdGenerator(db_conn, "receipts_linearized", "stream_id") super(ReceiptsStore, self).__init__(database, db_conn, hs) def get_max_receipt_stream_id(self): return self._receipts_id_gen.get_current_token() def insert_linearized_receipt_txn(self, txn, room_id, receipt_type, user_id, event_id, data, stream_id): """Inserts a read-receipt into the database if it's newer than the current RR Returns: int|None None if the RR is older than the current RR otherwise, the rx timestamp of the event that the RR corresponds to (or 0 if the event is unknown) """ res = self.db.simple_select_one_txn( txn, table="events", retcols=["stream_ordering", "received_ts"], keyvalues={"event_id": event_id}, allow_none=True, ) stream_ordering = int(res["stream_ordering"]) if res else None rx_ts = res["received_ts"] if res else 0 # We don't want to clobber receipts for more recent events, so we # have to compare orderings of existing receipts if stream_ordering is not None: sql = ( "SELECT stream_ordering, event_id FROM events" " INNER JOIN receipts_linearized as r USING (event_id, room_id)" " WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ?" ) txn.execute(sql, (room_id, receipt_type, user_id)) for so, eid in txn: if int(so) >= stream_ordering: logger.debug( "Ignoring new receipt for %s in favour of existing " "one for later event %s", event_id, eid, ) return None txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type)) txn.call_after( self._invalidate_get_users_with_receipts_in_room, room_id, receipt_type, user_id, ) txn.call_after(self.get_receipts_for_user.invalidate, (user_id, receipt_type)) # FIXME: This shouldn't invalidate the whole cache txn.call_after(self._get_linearized_receipts_for_room.invalidate_many, (room_id, )) txn.call_after(self._receipts_stream_cache.entity_has_changed, room_id, stream_id) txn.call_after( self.get_last_receipt_event_id_for_user.invalidate, (user_id, room_id, receipt_type), ) self.db.simple_upsert_txn( txn, table="receipts_linearized", keyvalues={ "room_id": room_id, "receipt_type": receipt_type, "user_id": user_id, }, values={ "stream_id": stream_id, "event_id": event_id, "data": json.dumps(data), }, # receipts_linearized has a unique constraint on # (user_id, room_id, receipt_type), so no need to lock lock=False, ) if receipt_type == "m.read" and stream_ordering is not None: self._remove_old_push_actions_before_txn( txn, room_id=room_id, user_id=user_id, stream_ordering=stream_ordering) return rx_ts @defer.inlineCallbacks def insert_receipt(self, room_id, receipt_type, user_id, event_ids, data): """Insert a receipt, either from local client or remote server. Automatically does conversion between linearized and graph representations. """ if not event_ids: return if len(event_ids) == 1: linearized_event_id = event_ids[0] else: # we need to points in graph -> linearized form. # TODO: Make this better. def graph_to_linear(txn): clause, args = make_in_list_sql_clause(self.database_engine, "event_id", event_ids) sql = """ SELECT event_id WHERE room_id = ? 
AND stream_ordering IN ( SELECT max(stream_ordering) WHERE %s ) """ % (clause, ) txn.execute(sql, [room_id] + list(args)) rows = txn.fetchall() if rows: return rows[0][0] else: raise RuntimeError("Unrecognized event_ids: %r" % (event_ids, )) linearized_event_id = yield self.db.runInteraction( "insert_receipt_conv", graph_to_linear) stream_id_manager = self._receipts_id_gen.get_next() with stream_id_manager as stream_id: event_ts = yield self.db.runInteraction( "insert_linearized_receipt", self.insert_linearized_receipt_txn, room_id, receipt_type, user_id, linearized_event_id, data, stream_id=stream_id, ) if event_ts is None: return None now = self._clock.time_msec() logger.debug( "RR for event %s in %s (%i ms old)", linearized_event_id, room_id, now - event_ts, ) yield self.insert_graph_receipt(room_id, receipt_type, user_id, event_ids, data) max_persisted_id = self._receipts_id_gen.get_current_token() return stream_id, max_persisted_id def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids, data): return self.db.runInteraction( "insert_graph_receipt", self.insert_graph_receipt_txn, room_id, receipt_type, user_id, event_ids, data, ) def insert_graph_receipt_txn(self, txn, room_id, receipt_type, user_id, event_ids, data): txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type)) txn.call_after( self._invalidate_get_users_with_receipts_in_room, room_id, receipt_type, user_id, ) txn.call_after(self.get_receipts_for_user.invalidate, (user_id, receipt_type)) # FIXME: This shouldn't invalidate the whole cache txn.call_after(self._get_linearized_receipts_for_room.invalidate_many, (room_id, )) self.db.simple_delete_txn( txn, table="receipts_graph", keyvalues={ "room_id": room_id, "receipt_type": receipt_type, "user_id": user_id, }, ) self.db.simple_insert_txn( txn, table="receipts_graph", values={ "room_id": room_id, "receipt_type": receipt_type, "user_id": user_id, "event_ids": json.dumps(event_ids), "data": json.dumps(data), }, )
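# Usage sketch (illustrative only, matching the older Twisted-style code above). How a
# receipts handler might record a local read receipt. The handler and identifiers are
# made up; insert_receipt returns None when the receipt is for an event older than the
# one already recorded, otherwise (stream_id, max_persisted_id).
from twisted.internet import defer

@defer.inlineCallbacks
def handle_local_read_receipt(store, room_id, user_id, event_id, receipt_ts):
    result = yield store.insert_receipt(
        room_id, "m.read", user_id, [event_id], {"ts": receipt_ts}
    )
    if result is None:
        # Superseded by a receipt for a later event; nothing new to persist or notify.
        return None
    stream_id, max_persisted_id = result
    return stream_id, max_persisted_id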