def __init__(self, db_conn, hs):
    """Initialise the slaved event store's stream trackers and caches."""
    super(SlavedEventStore, self).__init__(db_conn, hs)
    # Forward stream ordering of persisted events.
    self._stream_id_gen = SlavedIdTracker(
        db_conn, "events", "stream_ordering",
    )
    # Backfilled events walk the same column backwards, hence step=-1.
    self._backfill_id_gen = SlavedIdTracker(
        db_conn, "events", "stream_ordering", step=-1
    )
    events_max = self._stream_id_gen.get_current_token()
    # Prefill the per-room change cache from recent rows so a fresh
    # process doesn't treat every room as "maybe changed".
    event_cache_prefill, min_event_val = self._get_cache_dict(
        db_conn, "events",
        entity_column="room_id",
        stream_column="stream_ordering",
        max_value=events_max,
    )
    self._events_stream_cache = StreamChangeCache(
        "EventsRoomStreamChangeCache", min_event_val,
        prefilled_cache=event_cache_prefill,
    )
    self._membership_stream_cache = StreamChangeCache(
        "MembershipStreamChangeCache", events_max,
    )
    # NOTE(review): starts at 0 here; presumably updated elsewhere — confirm.
    self.stream_ordering_month_ago = 0
    self._stream_order_on_start = self.get_room_max_stream_ordering()
def test_prefilled_cache(self):
    """A prefilled cache handed to StreamChangeCache should be seeded
    with exactly the supplied entries.
    """
    prefill = {"*****@*****.**": 2}
    cache = StreamChangeCache("#test", 1, prefilled_cache=prefill)

    # The prefilled entity must be reported as changed since position 1.
    self.assertTrue(cache.has_entity_changed("*****@*****.**", 1))
class SlavedGroupServerStore(BaseSlavedStore):
    """Read-only replica of the group-server storage, updated over the
    "groups" replication stream."""

    def __init__(self, db_conn, hs):
        super(SlavedGroupServerStore, self).__init__(db_conn, hs)
        self.hs = hs
        self._group_updates_id_gen = SlavedIdTracker(
            db_conn, "local_group_updates", "stream_id",
        )
        # Tracks which users saw group updates at which stream positions.
        self._group_updates_stream_cache = StreamChangeCache(
            "_group_updates_stream_cache", self._group_updates_id_gen.get_current_token(),
        )

    # Borrow the plain (uncached) functions from DataStore.
    get_groups_changes_for_user = DataStore.get_groups_changes_for_user.__func__
    get_group_stream_token = DataStore.get_group_stream_token.__func__
    get_all_groups_for_user = DataStore.get_all_groups_for_user.__func__

    def stream_positions(self):
        """Report the current "groups" stream position to the master."""
        result = super(SlavedGroupServerStore, self).stream_positions()
        result["groups"] = self._group_updates_id_gen.get_current_token()
        return result

    def process_replication_rows(self, stream_name, token, rows):
        """Advance the local token and record per-user changes for
        incoming "groups" replication rows."""
        if stream_name == "groups":
            self._group_updates_id_gen.advance(token)
            for row in rows:
                self._group_updates_stream_cache.entity_has_changed(
                    row.user_id, token
                )
        return super(SlavedGroupServerStore, self).process_replication_rows(
            stream_name, token, rows
        )
class SlavedReceiptsStore(BaseSlavedStore):
    """Read-only receipts store for workers, updated via the legacy
    replication API (process_replication)."""

    def __init__(self, db_conn, hs):
        super(SlavedReceiptsStore, self).__init__(db_conn, hs)
        self._receipts_id_gen = SlavedIdTracker(
            db_conn, "receipts_linearized", "stream_id"
        )
        # Tracks which rooms had receipt changes at which stream positions.
        self._receipts_stream_cache = StreamChangeCache(
            "ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token()
        )

    # Reuse the cached descriptors from ReceiptsStore so invalidation
    # works on this class.
    get_receipts_for_user = ReceiptsStore.__dict__["get_receipts_for_user"]
    get_linearized_receipts_for_room = (
        ReceiptsStore.__dict__["get_linearized_receipts_for_room"]
    )
    _get_linearized_receipts_for_rooms = (
        ReceiptsStore.__dict__["_get_linearized_receipts_for_rooms"]
    )
    get_last_receipt_event_id_for_user = (
        ReceiptsStore.__dict__["get_last_receipt_event_id_for_user"]
    )

    # Plain (uncached) functions borrowed from DataStore.
    get_max_receipt_stream_id = DataStore.get_max_receipt_stream_id.__func__
    get_all_updated_receipts = DataStore.get_all_updated_receipts.__func__
    get_linearized_receipts_for_rooms = (
        DataStore.get_linearized_receipts_for_rooms.__func__
    )

    def stream_positions(self):
        """Report the current "receipts" stream position."""
        result = super(SlavedReceiptsStore, self).stream_positions()
        result["receipts"] = self._receipts_id_gen.get_current_token()
        return result

    def process_replication(self, result):
        """Apply receipt rows from replication, advancing the local token
        and invalidating the affected caches."""
        stream = result.get("receipts")
        if stream:
            self._receipts_id_gen.advance(int(stream["position"]))
            for row in stream["rows"]:
                # Row layout (first four columns): position, room_id,
                # receipt_type, user_id.
                position, room_id, receipt_type, user_id = row[:4]
                self.invalidate_caches_for_receipt(room_id, receipt_type, user_id)
                self._receipts_stream_cache.entity_has_changed(room_id, position)
        return super(SlavedReceiptsStore, self).process_replication(result)

    def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
        """Drop every cached lookup that may be stale after a receipt."""
        self.get_receipts_for_user.invalidate((user_id, receipt_type))
        self.get_linearized_receipts_for_room.invalidate_many((room_id,))
        self.get_last_receipt_event_id_for_user.invalidate(
            (user_id, room_id, receipt_type)
        )
def __init__(self, db_conn, hs):
    """Set up the device-lists stream tracker and its change caches."""
    super(SlavedDeviceStore, self).__init__(db_conn, hs)
    self.hs = hs

    self._device_list_id_gen = SlavedIdTracker(
        db_conn, "device_lists_stream", "stream_id",
    )

    current_token = self._device_list_id_gen.get_current_token()
    # One cache for device-list changes keyed by user, one for changes
    # keyed by federation destination; both start at the same position.
    self._device_list_stream_cache = StreamChangeCache(
        "DeviceListStreamChangeCache", current_token,
    )
    self._device_list_federation_stream_cache = StreamChangeCache(
        "DeviceListFederationStreamChangeCache", current_token,
    )
class SlavedDeviceInboxStore(BaseSlavedStore):
    """Read replica of the device inbox (to-device messages), fed by the
    "to_device" replication stream."""

    def __init__(self, db_conn, hs):
        super(SlavedDeviceInboxStore, self).__init__(db_conn, hs)
        self._device_inbox_id_gen = SlavedIdTracker(
            db_conn, "device_max_stream_id", "stream_id",
        )
        # Changes destined for local users (keyed by user ID).
        self._device_inbox_stream_cache = StreamChangeCache(
            "DeviceInboxStreamChangeCache",
            self._device_inbox_id_gen.get_current_token()
        )
        # Changes queued for remote servers (keyed by destination).
        self._device_federation_outbox_stream_cache = StreamChangeCache(
            "DeviceFederationOutboxStreamChangeCache",
            self._device_inbox_id_gen.get_current_token()
        )

        self._last_device_delete_cache = ExpiringCache(
            cache_name="last_device_delete_cache",
            clock=self._clock,
            max_len=10000,
            expiry_ms=30 * 60 * 1000,
        )

    # Borrow the plain (uncached) functions from DataStore.
    get_to_device_stream_token = DataStore.get_to_device_stream_token.__func__
    get_new_messages_for_device = DataStore.get_new_messages_for_device.__func__
    get_new_device_msgs_for_remote = DataStore.get_new_device_msgs_for_remote.__func__
    delete_messages_for_device = DataStore.delete_messages_for_device.__func__
    delete_device_msgs_for_remote = DataStore.delete_device_msgs_for_remote.__func__

    def stream_positions(self):
        """Report the current "to_device" stream position."""
        result = super(SlavedDeviceInboxStore, self).stream_positions()
        result["to_device"] = self._device_inbox_id_gen.get_current_token()
        return result

    def process_replication_rows(self, stream_name, token, rows):
        """Route incoming "to_device" rows: entities starting with "@"
        are local user IDs, anything else is a remote destination."""
        if stream_name == "to_device":
            self._device_inbox_id_gen.advance(token)
            for row in rows:
                if row.entity.startswith("@"):
                    self._device_inbox_stream_cache.entity_has_changed(
                        row.entity, token
                    )
                else:
                    self._device_federation_outbox_stream_cache.entity_has_changed(
                        row.entity, token
                    )
        return super(SlavedDeviceInboxStore, self).process_replication_rows(
            stream_name, token, rows
        )
class SlavedDeviceStore(BaseSlavedStore):
    """Read replica of the device store, fed by the "device_lists"
    replication stream."""

    def __init__(self, db_conn, hs):
        super(SlavedDeviceStore, self).__init__(db_conn, hs)

        self.hs = hs

        self._device_list_id_gen = SlavedIdTracker(
            db_conn, "device_lists_stream", "stream_id",
        )
        device_list_max = self._device_list_id_gen.get_current_token()
        # Changes keyed by local user ID.
        self._device_list_stream_cache = StreamChangeCache(
            "DeviceListStreamChangeCache", device_list_max,
        )
        # Changes keyed by federation destination.
        self._device_list_federation_stream_cache = StreamChangeCache(
            "DeviceListFederationStreamChangeCache", device_list_max,
        )

    # Borrow plain functions from DataStore via the __func__ helper.
    get_device_stream_token = __func__(DataStore.get_device_stream_token)
    get_user_whose_devices_changed = __func__(DataStore.get_user_whose_devices_changed)
    get_devices_by_remote = __func__(DataStore.get_devices_by_remote)
    _get_devices_by_remote_txn = __func__(DataStore._get_devices_by_remote_txn)
    _get_e2e_device_keys_txn = __func__(DataStore._get_e2e_device_keys_txn)
    mark_as_sent_devices_by_remote = __func__(DataStore.mark_as_sent_devices_by_remote)
    _mark_as_sent_devices_by_remote_txn = (
        __func__(DataStore._mark_as_sent_devices_by_remote_txn)
    )
    # Cached descriptor reused from EndToEndKeyStore.
    count_e2e_one_time_keys = EndToEndKeyStore.__dict__["count_e2e_one_time_keys"]

    def stream_positions(self):
        """Report the current "device_lists" stream position."""
        result = super(SlavedDeviceStore, self).stream_positions()
        result["device_lists"] = self._device_list_id_gen.get_current_token()
        return result

    def process_replication_rows(self, stream_name, token, rows):
        """Advance the token and record changed users/destinations."""
        if stream_name == "device_lists":
            self._device_list_id_gen.advance(token)
            for row in rows:
                self._device_list_stream_cache.entity_has_changed(
                    row.user_id, token
                )

                if row.destination:
                    self._device_list_federation_stream_cache.entity_has_changed(
                        row.destination, token
                    )
        return super(SlavedDeviceStore, self).process_replication_rows(
            stream_name, token, rows
        )
def __init__(self, db_conn, hs):
    """Create the account-data change cache, then run the base-class
    initialiser.

    NOTE: the cache is created *before* super().__init__ — preserve this
    ordering (presumably the parent initialiser relies on it; confirm).
    """
    max_stream_id = self.get_max_account_data_stream_id()
    self._account_data_stream_cache = StreamChangeCache(
        "AccountDataAndTagsChangeCache", max_stream_id,
    )

    super(AccountDataWorkerStore, self).__init__(db_conn, hs)
def __init__(self, hs):
    """Wire the typing handler into the homeserver: storage, auth,
    federation, EDU handling and the timeout loop."""
    self.store = hs.get_datastore()
    self.server_name = hs.config.server_name
    self.auth = hs.get_auth()
    self.is_mine_id = hs.is_mine_id
    self.notifier = hs.get_notifier()
    self.state = hs.get_state_handler()

    self.hs = hs

    self.clock = hs.get_clock()
    self.wheel_timer = WheelTimer(bucket_size=5000)

    self.federation = hs.get_federation_sender()

    # Receive remote typing EDUs and learn when users leave rooms.
    hs.get_federation_registry().register_edu_handler("m.typing", self._recv_edu)
    hs.get_distributor().observe("user_left_room", self.user_left_room)

    self._member_typing_until = {}  # clock time we expect to stop
    self._member_last_federation_poke = {}

    self._latest_room_serial = 0
    self._reset()

    # caches which room_ids changed at which serials
    self._typing_stream_change_cache = StreamChangeCache(
        "TypingStreamChangeCache", self._latest_room_serial,
    )

    # Check for typing timeouts every 5 seconds.
    self.clock.looping_call(
        self._handle_timeouts,
        5000,
    )
class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore):
    """Worker-side device store, kept current via the "device_lists"
    replication stream."""

    def __init__(self, db_conn, hs):
        super(SlavedDeviceStore, self).__init__(db_conn, hs)

        self.hs = hs

        self._device_list_id_gen = SlavedIdTracker(
            db_conn, "device_lists_stream", "stream_id",
        )
        device_list_max = self._device_list_id_gen.get_current_token()
        self._device_list_stream_cache = StreamChangeCache(
            "DeviceListStreamChangeCache", device_list_max,
        )
        self._device_list_federation_stream_cache = StreamChangeCache(
            "DeviceListFederationStreamChangeCache", device_list_max,
        )

    def stream_positions(self):
        """Report the current "device_lists" stream position."""
        result = super(SlavedDeviceStore, self).stream_positions()
        result["device_lists"] = self._device_list_id_gen.get_current_token()
        return result

    def process_replication_rows(self, stream_name, token, rows):
        """Advance the local token and invalidate caches for each row."""
        if stream_name == "device_lists":
            self._device_list_id_gen.advance(token)
            for row in rows:
                self._invalidate_caches_for_devices(
                    token, row.user_id, row.destination,
                )
        return super(SlavedDeviceStore, self).process_replication_rows(
            stream_name, token, rows
        )

    def _invalidate_caches_for_devices(self, token, user_id, destination):
        """Record the change for the user (and destination, if any) and
        drop the associated cached device lookups."""
        self._device_list_stream_cache.entity_has_changed(
            user_id, token
        )

        if destination:
            self._device_list_federation_stream_cache.entity_has_changed(
                destination, token
            )

        self._get_cached_devices_for_user.invalidate((user_id,))
        self._get_cached_user_device.invalidate_many((user_id,))
        self.get_device_list_last_stream_id_for_remote.invalidate((user_id,))
class SlavedPushRuleStore(SlavedEventStore):
    """Read replica of push rules, updated via the legacy replication
    API (process_replication)."""

    def __init__(self, db_conn, hs):
        super(SlavedPushRuleStore, self).__init__(db_conn, hs)
        self._push_rules_stream_id_gen = SlavedIdTracker(
            db_conn, "push_rules_stream", "stream_id",
        )
        # Tracks which users' push rules changed at which positions.
        self.push_rules_stream_cache = StreamChangeCache(
            "PushRulesStreamChangeCache",
            self._push_rules_stream_id_gen.get_current_token(),
        )

    # Cached descriptors reused so invalidation works on this class.
    get_push_rules_for_user = PushRuleStore.__dict__["get_push_rules_for_user"]
    get_push_rules_enabled_for_user = (
        PushRuleStore.__dict__["get_push_rules_enabled_for_user"]
    )
    have_push_rules_changed_for_user = (
        DataStore.have_push_rules_changed_for_user.__func__
    )

    def get_push_rules_stream_token(self):
        """Return (push-rules stream token, event stream token)."""
        return (
            self._push_rules_stream_id_gen.get_current_token(),
            self._stream_id_gen.get_current_token(),
        )

    def stream_positions(self):
        """Report the current "push_rules" stream position."""
        result = super(SlavedPushRuleStore, self).stream_positions()
        result["push_rules"] = self._push_rules_stream_id_gen.get_current_token()
        return result

    def process_replication(self, result):
        """Apply push-rule rows from replication and invalidate the
        per-user caches."""
        stream = result.get("push_rules")
        if stream:
            for row in stream["rows"]:
                # Row layout: position at index 0, user_id at index 2.
                position = row[0]
                user_id = row[2]
                self.get_push_rules_for_user.invalidate((user_id,))
                self.get_push_rules_enabled_for_user.invalidate((user_id,))
                self.push_rules_stream_cache.entity_has_changed(
                    user_id, position
                )

            self._push_rules_stream_id_gen.advance(int(stream["position"]))

        return super(SlavedPushRuleStore, self).process_replication(result)
def test_has_any_entity_changed(self):
    """
    StreamChangeCache.has_any_entity_changed will return True if any
    entities have been changed since the provided stream position, and
    False if they have not.  If the cache has entries and the provided
    stream position is before it, it will return True, otherwise False
    if the cache has no entries.
    """
    cache = StreamChangeCache("#test", 1)

    # An empty cache reports no changes for the past, present or future.
    for position in (0, 1, 2):
        self.assertFalse(cache.has_any_entity_changed(position))

    # Record a single change.
    cache.entity_has_changed("*****@*****.**", 2)

    # Positions strictly before the change report True; the change
    # position itself and later positions report False.
    self.assertTrue(cache.has_any_entity_changed(0))
    self.assertTrue(cache.has_any_entity_changed(1))
    self.assertFalse(cache.has_any_entity_changed(2))
    self.assertFalse(cache.has_any_entity_changed(3))
def __init__(self, db_conn, hs):
    """Track the push-rules stream and cache per-user change positions."""
    super(SlavedPushRuleStore, self).__init__(db_conn, hs)

    self._push_rules_stream_id_gen = SlavedIdTracker(
        db_conn, "push_rules_stream", "stream_id",
    )
    current_token = self._push_rules_stream_id_gen.get_current_token()
    self.push_rules_stream_cache = StreamChangeCache(
        "PushRulesStreamChangeCache", current_token,
    )
def __init__(self, db_conn, hs):
    """Track the account-data stream and build its change cache."""
    super(SlavedAccountDataStore, self).__init__(db_conn, hs)

    self._account_data_id_gen = SlavedIdTracker(
        db_conn, "account_data_max_stream_id", "stream_id",
    )
    current_position = self._account_data_id_gen.get_current_token()
    self._account_data_stream_cache = StreamChangeCache(
        "AccountDataAndTagsChangeCache", current_position,
    )
def __init__(self, db_conn, hs):
    """Set up presence stream tracking for the slaved store."""
    super(SlavedPresenceStore, self).__init__(db_conn, hs)
    self._presence_id_gen = SlavedIdTracker(db_conn, "presence_stream", "stream_id")

    self._presence_on_startup = self._get_active_presence(db_conn)

    # Fix: the attribute was previously assigned twice in one statement
    # ("self.presence_stream_cache = self.presence_stream_cache = ...");
    # a single assignment is sufficient and equivalent.
    self.presence_stream_cache = StreamChangeCache(
        "PresenceStreamChangeCache", self._presence_id_gen.get_current_token()
    )
def __init__(self, db_conn, hs):
    """Track the receipts stream and build its per-room change cache."""
    super(SlavedReceiptsStore, self).__init__(db_conn, hs)

    self._receipts_id_gen = SlavedIdTracker(
        db_conn, "receipts_linearized", "stream_id"
    )
    current_position = self._receipts_id_gen.get_current_token()
    self._receipts_stream_cache = StreamChangeCache(
        "ReceiptsRoomChangeCache", current_position
    )
def __init__(self, db_conn, hs):
    """Track the device-lists stream and build its change caches."""
    super(SlavedDeviceStore, self).__init__(db_conn, hs)
    self.hs = hs

    self._device_list_id_gen = SlavedIdTracker(
        db_conn, "device_lists_stream", "stream_id",
    )
    max_stream_id = self._device_list_id_gen.get_current_token()

    # Per-user changes and per-destination federation changes both start
    # from the same stream position.
    self._device_list_stream_cache = StreamChangeCache(
        "DeviceListStreamChangeCache", max_stream_id,
    )
    self._device_list_federation_stream_cache = StreamChangeCache(
        "DeviceListFederationStreamChangeCache", max_stream_id,
    )
def __init__(self, db_conn, hs):
    """Build the current-state-delta change cache, prefilled from the
    most recent rows in the database."""
    super(UserDirectorySlaveStore, self).__init__(db_conn, hs)

    events_max = self._stream_id_gen.get_current_token()
    curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict(
        db_conn, "current_state_delta_stream",
        entity_column="room_id",
        stream_column="stream_id",
        max_value=events_max,  # As we share the stream id with events token
        limit=1000,
    )
    self._curr_state_delta_stream_cache = StreamChangeCache(
        "_curr_state_delta_stream_cache", min_curr_state_delta_id,
        prefilled_cache=curr_state_delta_prefill,
    )

    # Current position in the shared stream; advanced over replication.
    self._current_state_delta_pos = events_max
def __init__(self, database: DatabasePool, db_conn, hs):
    """Initialise device-inbox stream tracking and its change caches."""
    super(SlavedDeviceInboxStore, self).__init__(database, db_conn, hs)

    self._device_inbox_id_gen = SlavedIdTracker(db_conn, "device_inbox", "stream_id")
    current_position = self._device_inbox_id_gen.get_current_token()

    # Messages destined for local users.
    self._device_inbox_stream_cache = StreamChangeCache(
        "DeviceInboxStreamChangeCache",
        current_position,
    )
    # Messages queued for remote (federation) destinations.
    self._device_federation_outbox_stream_cache = StreamChangeCache(
        "DeviceFederationOutboxStreamChangeCache",
        current_position,
    )

    # Short-lived cache of recent device deletions; expires after 30 min.
    self._last_device_delete_cache = ExpiringCache(
        cache_name="last_device_delete_cache",
        clock=self._clock,
        max_len=10000,
        expiry_ms=30 * 60 * 1000,
    )
def __init__(self, db_conn, hs):
    """Set up the event-stream change caches, prefilled from recent rows."""
    super(StreamWorkerStore, self).__init__(db_conn, hs)

    events_max = self.get_room_max_stream_ordering()
    # Prefill so a fresh process doesn't treat every room as changed.
    event_cache_prefill, min_event_val = self._get_cache_dict(
        db_conn, "events",
        entity_column="room_id",
        stream_column="stream_ordering",
        max_value=events_max,
    )
    self._events_stream_cache = StreamChangeCache(
        "EventsRoomStreamChangeCache", min_event_val,
        prefilled_cache=event_cache_prefill,
    )
    self._membership_stream_cache = StreamChangeCache(
        "MembershipStreamChangeCache", events_max,
    )

    # Remember where the stream was when this process started.
    self._stream_order_on_start = self.get_room_max_stream_ordering()
def __init__(self, db_conn, hs):
    """Build the per-room and membership change caches for the event
    stream, prefilling the former from recent database rows."""
    super(StreamWorkerStore, self).__init__(db_conn, hs)

    max_stream_ordering = self.get_room_max_stream_ordering()
    prefill, earliest_known = self._get_cache_dict(
        db_conn,
        "events",
        entity_column="room_id",
        stream_column="stream_ordering",
        max_value=max_stream_ordering,
    )
    self._events_stream_cache = StreamChangeCache(
        "EventsRoomStreamChangeCache",
        earliest_known,
        prefilled_cache=prefill,
    )
    self._membership_stream_cache = StreamChangeCache(
        "MembershipStreamChangeCache",
        max_stream_ordering,
    )

    # Snapshot of the stream position at process start.
    self._stream_order_on_start = self.get_room_max_stream_ordering()
class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
    """Worker-side device inbox store fed by the "to_device" stream."""

    def __init__(self, database: Database, db_conn, hs):
        super(SlavedDeviceInboxStore, self).__init__(database, db_conn, hs)
        self._device_inbox_id_gen = SlavedIdTracker(
            db_conn, "device_max_stream_id", "stream_id"
        )
        # Local user deliveries, keyed by user ID.
        self._device_inbox_stream_cache = StreamChangeCache(
            "DeviceInboxStreamChangeCache",
            self._device_inbox_id_gen.get_current_token(),
        )
        # Outbound federation queue, keyed by destination server.
        self._device_federation_outbox_stream_cache = StreamChangeCache(
            "DeviceFederationOutboxStreamChangeCache",
            self._device_inbox_id_gen.get_current_token(),
        )

        self._last_device_delete_cache = ExpiringCache(
            cache_name="last_device_delete_cache",
            clock=self._clock,
            max_len=10000,
            expiry_ms=30 * 60 * 1000,
        )

    def stream_positions(self):
        """Report the current "to_device" stream position."""
        result = super(SlavedDeviceInboxStore, self).stream_positions()
        result["to_device"] = self._device_inbox_id_gen.get_current_token()
        return result

    def process_replication_rows(self, stream_name, token, rows):
        """Route "to_device" rows: "@"-prefixed entities are local user
        IDs; anything else is a remote destination."""
        if stream_name == "to_device":
            self._device_inbox_id_gen.advance(token)
            for row in rows:
                if row.entity.startswith("@"):
                    self._device_inbox_stream_cache.entity_has_changed(
                        row.entity, token
                    )
                else:
                    self._device_federation_outbox_stream_cache.entity_has_changed(
                        row.entity, token
                    )
        return super(SlavedDeviceInboxStore, self).process_replication_rows(
            stream_name, token, rows
        )
def __init__(
    self,
    database: DatabasePool,
    db_conn: LoggingDatabaseConnection,
    hs: "HomeServer",
):
    """Choose a receipts stream ID generator appropriate to the database
    engine and worker role, then build the receipts change cache."""
    self._instance_name = hs.get_instance_name()

    if isinstance(database.engine, PostgresEngine):
        # On Postgres several workers may write receipts concurrently.
        self._can_write_to_receipts = (
            self._instance_name in hs.config.worker.writers.receipts)

        self._receipts_id_gen = MultiWriterIdGenerator(
            db_conn=db_conn,
            db=database,
            stream_name="receipts",
            instance_name=self._instance_name,
            tables=[("receipts_linearized", "instance_name", "stream_id")],
            sequence_name="receipts_sequence",
            writers=hs.config.worker.writers.receipts,
        )
    else:
        self._can_write_to_receipts = True

        # We shouldn't be running in worker mode with SQLite, but it's
        # useful to support it for unit tests.
        #
        # If this process is the writer then we need to use
        # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which
        # gets updated over replication. (Multiple writers are not
        # supported for SQLite.)
        if hs.get_instance_name() in hs.config.worker.writers.receipts:
            self._receipts_id_gen = StreamIdGenerator(
                db_conn, "receipts_linearized", "stream_id")
        else:
            self._receipts_id_gen = SlavedIdTracker(
                db_conn, "receipts_linearized", "stream_id")

    super().__init__(database, db_conn, hs)

    # Tracks which rooms had receipt changes at which stream positions.
    self._receipts_stream_cache = StreamChangeCache(
        "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id())
def __init__(
    self,
    database: DatabasePool,
    db_conn: Connection,
    hs: "HomeServer",
):
    """Set up the presence stream ID generator and change cache.

    On Postgres a MultiWriterIdGenerator is used so multiple workers can
    write to the presence stream; on SQLite a plain StreamIdGenerator is
    used (SQLite only supports a single writer).
    """
    super().__init__(database, db_conn, hs)

    self._can_persist_presence = (
        hs.get_instance_name() in hs.config.worker.writers.presence
    )

    if isinstance(database.engine, PostgresEngine):
        self._presence_id_gen = MultiWriterIdGenerator(
            db_conn=db_conn,
            db=database,
            stream_name="presence_stream",
            instance_name=self._instance_name,
            tables=[("presence_stream", "instance_name", "stream_id")],
            sequence_name="presence_stream_sequence",
            # Fix: this previously passed hs.config.worker.writers.to_device
            # (a copy/paste from the to-device store); the presence stream's
            # ID generator must be driven by the *presence* writers, matching
            # the _can_persist_presence check above.
            writers=hs.config.worker.writers.presence,
        )
    else:
        self._presence_id_gen = StreamIdGenerator(
            db_conn, "presence_stream", "stream_id"
        )

    self.hs = hs
    self._presence_on_startup = self._get_active_presence(db_conn)

    # Prefill the per-user change cache from recent rows.
    presence_cache_prefill, min_presence_val = self.db_pool.get_cache_dict(
        db_conn,
        "presence_stream",
        entity_column="user_id",
        stream_column="stream_id",
        max_value=self._presence_id_gen.get_current_token(),
    )
    self.presence_stream_cache = StreamChangeCache(
        "PresenceStreamChangeCache",
        min_presence_val,
        prefilled_cache=presence_cache_prefill,
    )
class SlavedPushRuleStore(SlavedEventStore):
    """Worker-side push rules store, fed by the "push_rules" stream."""

    def __init__(self, db_conn, hs):
        super(SlavedPushRuleStore, self).__init__(db_conn, hs)
        self._push_rules_stream_id_gen = SlavedIdTracker(
            db_conn, "push_rules_stream", "stream_id",
        )
        # Tracks which users' push rules changed at which positions.
        self.push_rules_stream_cache = StreamChangeCache(
            "PushRulesStreamChangeCache",
            self._push_rules_stream_id_gen.get_current_token(),
        )

    # Cached descriptors reused so invalidation works on this class.
    get_push_rules_for_user = PushRuleStore.__dict__["get_push_rules_for_user"]
    get_push_rules_enabled_for_user = (
        PushRuleStore.__dict__["get_push_rules_enabled_for_user"])
    have_push_rules_changed_for_user = (
        DataStore.have_push_rules_changed_for_user.__func__)

    def get_push_rules_stream_token(self):
        """Return (push-rules stream token, event stream token)."""
        return (
            self._push_rules_stream_id_gen.get_current_token(),
            self._stream_id_gen.get_current_token(),
        )

    def stream_positions(self):
        """Report the current "push_rules" stream position."""
        result = super(SlavedPushRuleStore, self).stream_positions()
        result[
            "push_rules"] = self._push_rules_stream_id_gen.get_current_token()
        return result

    def process_replication_rows(self, stream_name, token, rows):
        """Advance the token and invalidate per-user push-rule caches."""
        if stream_name == "push_rules":
            self._push_rules_stream_id_gen.advance(token)
            for row in rows:
                self.get_push_rules_for_user.invalidate((row.user_id, ))
                self.get_push_rules_enabled_for_user.invalidate(
                    (row.user_id, ))
                self.push_rules_stream_cache.entity_has_changed(
                    row.user_id, token)
        return super(SlavedPushRuleStore,
                     self).process_replication_rows(stream_name, token, rows)
class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore):
    """Worker-side device store, fed by the "device_lists" stream."""

    def __init__(self, db_conn, hs):
        super(SlavedDeviceStore, self).__init__(db_conn, hs)

        self.hs = hs

        self._device_list_id_gen = SlavedIdTracker(db_conn, "device_lists_stream",
                                                   "stream_id")
        device_list_max = self._device_list_id_gen.get_current_token()
        # Three change caches share the same starting position: per-user
        # device lists, user signatures, and federation destinations.
        self._device_list_stream_cache = StreamChangeCache(
            "DeviceListStreamChangeCache", device_list_max)
        self._user_signature_stream_cache = StreamChangeCache(
            "UserSignatureStreamChangeCache", device_list_max)
        self._device_list_federation_stream_cache = StreamChangeCache(
            "DeviceListFederationStreamChangeCache", device_list_max)

    def stream_positions(self):
        """Report the current "device_lists" stream position."""
        result = super(SlavedDeviceStore, self).stream_positions()
        result["device_lists"] = self._device_list_id_gen.get_current_token()
        return result

    def process_replication_rows(self, stream_name, token, rows):
        """Advance the token and invalidate device caches for each row."""
        if stream_name == "device_lists":
            self._device_list_id_gen.advance(token)
            for row in rows:
                self._invalidate_caches_for_devices(token, row.user_id,
                                                    row.destination)
        return super(SlavedDeviceStore,
                     self).process_replication_rows(stream_name, token, rows)

    def _invalidate_caches_for_devices(self, token, user_id, destination):
        """Record the change and drop cached device lookups for the user
        (and destination, when one is given)."""
        self._device_list_stream_cache.entity_has_changed(user_id, token)
        if destination:
            self._device_list_federation_stream_cache.entity_has_changed(
                destination, token)

        self._get_cached_devices_for_user.invalidate((user_id, ))
        self._get_cached_user_device.invalidate_many((user_id, ))
        self.get_device_list_last_stream_id_for_remote.invalidate((user_id, ))
class UserDirectorySlaveStore(
    SlavedEventStore,
    SlavedApplicationServiceStore,
    SlavedRegistrationStore,
    SlavedClientIpStore,
    UserDirectoryStore,
    BaseSlavedStore,
):
    """User-directory worker store driven by the events stream."""

    def __init__(self, db_conn, hs):
        super(UserDirectorySlaveStore, self).__init__(db_conn, hs)

        events_max = self._stream_id_gen.get_current_token()
        curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict(
            db_conn, "current_state_delta_stream",
            entity_column="room_id",
            stream_column="stream_id",
            max_value=events_max,  # As we share the stream id with events token
            limit=1000,
        )
        self._curr_state_delta_stream_cache = StreamChangeCache(
            "_curr_state_delta_stream_cache", min_curr_state_delta_id,
            prefilled_cache=curr_state_delta_prefill,
        )

    def stream_positions(self):
        # No extra streams beyond those of the parent stores.
        result = super(UserDirectorySlaveStore, self).stream_positions()
        return result

    def process_replication_rows(self, stream_name, token, rows):
        """Track current-state changes carried on the events stream."""
        if stream_name == EventsStream.NAME:
            self._stream_id_gen.advance(token)
            for row in rows:
                # Only current-state rows affect the user directory.
                if row.type != EventsStreamCurrentStateRow.TypeId:
                    continue

                self._curr_state_delta_stream_cache.entity_has_changed(
                    row.data.room_id, token
                )
        return super(UserDirectorySlaveStore, self).process_replication_rows(
            stream_name, token, rows
        )
class SlavedDeviceInboxStore(BaseSlavedStore):
    """Read replica of the device inbox, updated via the legacy
    replication API (process_replication)."""

    def __init__(self, db_conn, hs):
        super(SlavedDeviceInboxStore, self).__init__(db_conn, hs)
        self._device_inbox_id_gen = SlavedIdTracker(
            db_conn, "device_max_stream_id", "stream_id",
        )
        # Local user deliveries, keyed by user ID.
        self._device_inbox_stream_cache = StreamChangeCache(
            "DeviceInboxStreamChangeCache",
            self._device_inbox_id_gen.get_current_token())
        # Outbound federation queue, keyed by destination server.
        self._device_federation_outbox_stream_cache = StreamChangeCache(
            "DeviceFederationOutboxStreamChangeCache",
            self._device_inbox_id_gen.get_current_token())

    # Borrow the plain (uncached) functions from DataStore.
    get_to_device_stream_token = DataStore.get_to_device_stream_token.__func__
    get_new_messages_for_device = DataStore.get_new_messages_for_device.__func__
    get_new_device_msgs_for_remote = DataStore.get_new_device_msgs_for_remote.__func__
    delete_messages_for_device = DataStore.delete_messages_for_device.__func__
    delete_device_msgs_for_remote = DataStore.delete_device_msgs_for_remote.__func__

    def stream_positions(self):
        """Report the current "to_device" stream position."""
        result = super(SlavedDeviceInboxStore, self).stream_positions()
        result["to_device"] = self._device_inbox_id_gen.get_current_token()
        return result

    def process_replication(self, result):
        """Apply "to_device" rows: "@"-prefixed entities are local user
        IDs, others are remote destinations."""
        stream = result.get("to_device")
        if stream:
            self._device_inbox_id_gen.advance(int(stream["position"]))
            for row in stream["rows"]:
                # Row layout: stream_id at index 0, entity at index 1.
                stream_id = row[0]
                entity = row[1]

                if entity.startswith("@"):
                    self._device_inbox_stream_cache.entity_has_changed(
                        entity, stream_id)
                else:
                    self._device_federation_outbox_stream_cache.entity_has_changed(
                        entity, stream_id)

        return super(SlavedDeviceInboxStore, self).process_replication(result)
def __init__(self, db_conn, hs):
    """Track the group-updates stream and build its per-user change cache."""
    super(SlavedGroupServerStore, self).__init__(db_conn, hs)
    self.hs = hs

    self._group_updates_id_gen = SlavedIdTracker(
        db_conn, "local_group_updates", "stream_id",
    )
    current_position = self._group_updates_id_gen.get_current_token()
    self._group_updates_stream_cache = StreamChangeCache(
        "_group_updates_stream_cache", current_position,
    )
def __init__(self, hs):
    """Initialise the typing-stream writer.

    Must only run on the instance configured to write the typing stream;
    the assert below enforces that.
    """
    super().__init__(hs)

    assert hs.config.worker.writers.typing == hs.get_instance_name()

    self.auth = hs.get_auth()
    self.notifier = hs.get_notifier()

    self.hs = hs

    # Receive remote typing EDUs and learn when users leave rooms.
    hs.get_federation_registry().register_edu_handler(
        "m.typing", self._recv_edu)

    hs.get_distributor().observe("user_left_room", self.user_left_room)

    self._member_typing_until = {}  # clock time we expect to stop

    # caches which room_ids changed at which serials
    # NOTE(review): _latest_room_serial is presumably initialised by the
    # base class — confirm.
    self._typing_stream_change_cache = StreamChangeCache(
        "TypingStreamChangeCache", self._latest_room_serial)
def __init__(self, db_conn, hs):
    """Track the to-device stream and build inbox/outbox change caches."""
    super(SlavedDeviceInboxStore, self).__init__(db_conn, hs)

    self._device_inbox_id_gen = SlavedIdTracker(
        db_conn, "device_max_stream_id", "stream_id",
    )
    current_position = self._device_inbox_id_gen.get_current_token()

    # Messages destined for local users.
    self._device_inbox_stream_cache = StreamChangeCache(
        "DeviceInboxStreamChangeCache", current_position
    )
    # Messages queued for delivery to remote servers.
    self._device_federation_outbox_stream_cache = StreamChangeCache(
        "DeviceFederationOutboxStreamChangeCache", current_position
    )

    # Short-lived cache of recent device deletions; entries expire after
    # 30 minutes.
    self._last_device_delete_cache = ExpiringCache(
        cache_name="last_device_delete_cache",
        clock=self._clock,
        max_len=10000,
        expiry_ms=30 * 60 * 1000,
    )
class UserDirectorySlaveStore(
    SlavedEventStore,
    SlavedApplicationServiceStore,
    SlavedRegistrationStore,
    SlavedClientIpStore,
    UserDirectoryStore,
    BaseSlavedStore,
):
    """User-directory worker store fed by the "current_state_deltas"
    replication stream."""

    def __init__(self, db_conn, hs):
        super(UserDirectorySlaveStore, self).__init__(db_conn, hs)

        events_max = self._stream_id_gen.get_current_token()
        curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict(
            db_conn, "current_state_delta_stream",
            entity_column="room_id",
            stream_column="stream_id",
            max_value=events_max,  # As we share the stream id with events token
            limit=1000,
        )
        self._curr_state_delta_stream_cache = StreamChangeCache(
            "_curr_state_delta_stream_cache", min_curr_state_delta_id,
            prefilled_cache=curr_state_delta_prefill,
        )

        # Current position in the shared stream; advanced over replication.
        self._current_state_delta_pos = events_max

    def stream_positions(self):
        """Report the current "current_state_deltas" stream position."""
        result = super(UserDirectorySlaveStore, self).stream_positions()
        result["current_state_deltas"] = self._current_state_delta_pos
        return result

    def process_replication_rows(self, stream_name, token, rows):
        """Track room state changes from the "current_state_deltas" stream."""
        if stream_name == "current_state_deltas":
            self._current_state_delta_pos = token
            for row in rows:
                self._curr_state_delta_stream_cache.entity_has_changed(
                    row.room_id, token
                )
        return super(UserDirectorySlaveStore, self).process_replication_rows(
            stream_name, token, rows
        )
def test_has_entity_changed_pops_off_start(self):
    """
    StreamChangeCache.entity_has_changed will respect the max size and
    purge the oldest items upon reaching that max size.
    """
    cache = StreamChangeCache("#test", 1, max_size=2)

    cache.entity_has_changed("*****@*****.**", 2)
    cache.entity_has_changed("*****@*****.**", 3)
    cache.entity_has_changed("*****@*****.**", 4)

    # The cache should have been trimmed down to max_size entries.
    self.assertEqual(len(cache._cache), 2)

    # The earliest entry should have been evicted.
    self.assertTrue("*****@*****.**" not in cache._entity_to_key)

    # Updating an already-present entity keeps the existing entries
    # rather than evicting anything further.
    cache.entity_has_changed("*****@*****.**", 5)
    self.assertEqual(
        {"*****@*****.**", "*****@*****.**"}, set(cache._entity_to_key)
    )
def test_get_all_entities_changed(self):
    """
    StreamChangeCache.get_all_entities_changed will return all changed
    entities since the given position.  If the position is before the
    start of the known stream, it returns None instead.
    """
    cache = StreamChangeCache("#test", 1)

    cache.entity_has_changed("*****@*****.**", 2)
    cache.entity_has_changed("*****@*****.**", 3)
    cache.entity_has_changed("*****@*****.**", 4)

    # Everything recorded after position 1 is reported.
    self.assertEqual(
        cache.get_all_entities_changed(1),
        ["*****@*****.**", "*****@*****.**", "*****@*****.**"],
    )
    # Later positions see progressively fewer changes.
    self.assertEqual(
        cache.get_all_entities_changed(2),
        ["*****@*****.**", "*****@*****.**"],
    )
    self.assertEqual(cache.get_all_entities_changed(3), ["*****@*****.**"])
    # Position 0 predates the cache's knowledge, so the answer is unknown.
    self.assertEqual(cache.get_all_entities_changed(0), None)
def __init__(
    self,
    database: DatabasePool,
    db_conn: LoggingDatabaseConnection,
    hs: "HomeServer",
):
    """Initialise device-list stream tracking and the associated caches."""
    super().__init__(database, db_conn, hs)
    self.hs = hs

    # The device-list stream shares its ID space with the user-signature
    # and outbound-poke tables, so the tracker watches all three.
    self._device_list_id_gen = SlavedIdTracker(
        db_conn,
        "device_lists_stream",
        "stream_id",
        extra_tables=[
            ("user_signature_stream", "stream_id"),
            ("device_lists_outbound_pokes", "stream_id"),
        ],
    )

    current_token = self._device_list_id_gen.get_current_token()
    self._device_list_stream_cache = StreamChangeCache(
        "DeviceListStreamChangeCache", current_token)
    self._user_signature_stream_cache = StreamChangeCache(
        "UserSignatureStreamChangeCache", current_token)
    self._device_list_federation_stream_cache = StreamChangeCache(
        "DeviceListFederationStreamChangeCache", current_token)
def __init__(self, database: Database, db_conn, hs):
    """Set up event-stream trackers and the current-state-delta cache.

    NOTE: the ID trackers are created *before* super().__init__ runs —
    preserve this ordering (the parent initialiser presumably relies on
    them; confirm).
    """
    self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering")
    # Backfilled events walk the same column backwards, hence step=-1.
    self._backfill_id_gen = SlavedIdTracker(
        db_conn, "events", "stream_ordering", step=-1
    )

    super(SlavedEventStore, self).__init__(database, db_conn, hs)

    events_max = self._stream_id_gen.get_current_token()
    curr_state_delta_prefill, min_curr_state_delta_id = self.db.get_cache_dict(
        db_conn, "current_state_delta_stream",
        entity_column="room_id",
        stream_column="stream_id",
        max_value=events_max,  # As we share the stream id with events token
        limit=1000,
    )
    self._curr_state_delta_stream_cache = StreamChangeCache(
        "_curr_state_delta_stream_cache", min_curr_state_delta_id,
        prefilled_cache=curr_state_delta_prefill,
    )
def __init__(self, hs):
    """Wire the typing handler up to the homeserver's services and register
    the federation / distributor callbacks it depends on."""
    self.store = hs.get_datastore()
    self.server_name = hs.config.server_name
    self.auth = hs.get_auth()
    self.is_mine_id = hs.is_mine_id
    self.notifier = hs.get_notifier()
    self.state = hs.get_state_handler()
    self.hs = hs
    self.clock = hs.get_clock()
    self.wheel_timer = WheelTimer(bucket_size=5000)
    self.federation = hs.get_federation_sender()
    # Receive remote typing notifications and learn when users leave rooms.
    hs.get_federation_registry().register_edu_handler(
        "m.typing", self._recv_edu)
    hs.get_distributor().observe("user_left_room", self.user_left_room)
    self._member_typing_until = {}  # clock time we expect to stop
    # Last time we pushed each member's typing state over federation.
    self._member_last_federation_poke = {}
    # map room IDs to serial numbers
    self._room_serials = {}
    self._latest_room_serial = 0
    # map room IDs to sets of users currently typing
    self._room_typing = {}
    # caches which room_ids changed at which serials
    self._typing_stream_change_cache = StreamChangeCache(
        "TypingStreamChangeCache", self._latest_room_serial,
    )
    # Periodically expire stale typing state (interval presumably in ms,
    # i.e. every 5s — TODO confirm against clock.looping_call).
    self.clock.looping_call(
        self._handle_timeouts,
        5000,
    )
class SlavedPresenceStore(BaseSlavedStore):
    """Read-only presence store for worker processes, kept up to date via
    the "presence" replication stream."""

    def __init__(self, database: Database, db_conn, hs):
        super(SlavedPresenceStore, self).__init__(database, db_conn, hs)
        self._presence_id_gen = SlavedIdTracker(db_conn, "presence_stream", "stream_id")
        # Non-offline presence loaded at startup; handed over exactly once
        # via take_presence_startup_info.
        self._presence_on_startup = self._get_active_presence(db_conn)
        self.presence_stream_cache = StreamChangeCache(
            "PresenceStreamChangeCache", self._presence_id_gen.get_current_token())

    # Borrow implementations from the master DataStore / PresenceStore.
    # Cached functions can't be accessed through a class instance, hence
    # reaching into __dict__ for the cached ones.
    _get_active_presence = __func__(DataStore._get_active_presence)
    take_presence_startup_info = __func__(DataStore.take_presence_startup_info)
    _get_presence_for_user = PresenceStore.__dict__["_get_presence_for_user"]
    get_presence_for_users = PresenceStore.__dict__["get_presence_for_users"]

    def get_current_presence_token(self):
        """Return the latest presence stream position seen via replication."""
        return self._presence_id_gen.get_current_token()

    def stream_positions(self):
        result = super(SlavedPresenceStore, self).stream_positions()
        # Only advertise a presence position when presence is enabled.
        if self.hs.config.use_presence:
            position = self._presence_id_gen.get_current_token()
            result["presence"] = position
        return result

    def process_replication_rows(self, stream_name, token, rows):
        if stream_name == "presence":
            self._presence_id_gen.advance(token)
            for row in rows:
                # Mark the user as changed at this position and drop their
                # cached presence state.
                self.presence_stream_cache.entity_has_changed(
                    row.user_id, token)
                self._get_presence_for_user.invalidate((row.user_id, ))
        return super(SlavedPresenceStore, self).process_replication_rows(stream_name, token, rows)
def __init__(self, db_conn, hs):
    """Prefill the push-rules stream-change cache from the database."""
    super(PushRulesWorkerStore, self).__init__(db_conn, hs)

    prefill, stream_pos = self._get_cache_dict(
        db_conn,
        "push_rules_stream",
        entity_column="user_id",
        stream_column="stream_id",
        max_value=self.get_max_push_rules_stream_id(),
    )
    self.push_rules_stream_cache = StreamChangeCache(
        "PushRulesStreamChangeCache",
        stream_pos,
        prefilled_cache=prefill,
    )
class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
    """Worker-side group server store, kept up to date via the group server
    replication stream."""

    def __init__(self, database: Database, db_conn, hs):
        super(SlavedGroupServerStore, self).__init__(database, db_conn, hs)
        self.hs = hs
        self._group_updates_id_gen = SlavedIdTracker(
            db_conn, "local_group_updates", "stream_id"
        )
        self._group_updates_stream_cache = StreamChangeCache(
            "_group_updates_stream_cache",
            self._group_updates_id_gen.get_current_token(),
        )

    def get_group_stream_token(self):
        """Return the latest group-updates stream position seen so far."""
        return self._group_updates_id_gen.get_current_token()

    def process_replication_rows(self, stream_name, instance_name, token, rows):
        if stream_name == GroupServerStream.NAME:
            self._group_updates_id_gen.advance(token)
            for row in rows:
                # Record that this user has a group update at `token`.
                self._group_updates_stream_cache.entity_has_changed(row.user_id, token)
        return super().process_replication_rows(stream_name, instance_name, token, rows)
def __init__(self, hs: "HomeServer"):
    """Set up the typing writer handler.

    Only the worker instance configured as the typing-stream writer may
    instantiate this class.
    """
    super().__init__(hs)
    # Startup invariant: we must be the configured typing writer.
    assert hs.get_instance_name() in hs.config.worker.writers.typing
    self.auth = hs.get_auth()
    self.notifier = hs.get_notifier()
    self.event_auth_handler = hs.get_event_auth_handler()
    self.hs = hs
    # Receive remote typing notifications and learn when users leave rooms.
    hs.get_federation_registry().register_edu_handler(
        EduTypes.TYPING, self._recv_edu
    )
    hs.get_distributor().observe("user_left_room", self.user_left_room)
    # clock time we expect to stop
    self._member_typing_until: Dict[RoomMember, int] = {}
    # caches which room_ids changed at which serials
    # (_latest_room_serial is presumably initialised by super().__init__ —
    # TODO confirm against the base class.)
    self._typing_stream_change_cache = StreamChangeCache(
        "TypingStreamChangeCache", self._latest_room_serial
    )
def __init__(self, database: Database, db_conn, hs):
    """Initialise the push-rules stream ID generator — a writer on the
    master process, a read-only tracker on workers — then prefill the
    push-rules stream-change cache."""
    super(PushRulesWorkerStore, self).__init__(database, db_conn, hs)
    if hs.config.worker.worker_app is None:
        # Master process: can allocate new stream IDs, chained off the
        # events stream ID generator.
        self._push_rules_stream_id_gen = ChainedIdGenerator(
            self._stream_id_gen, db_conn, "push_rules_stream", "stream_id"
        )  # type: Union[ChainedIdGenerator, SlavedIdTracker]
    else:
        # Worker process: just follow the stream position.
        self._push_rules_stream_id_gen = SlavedIdTracker(
            db_conn, "push_rules_stream", "stream_id")
    push_rules_prefill, push_rules_id = self.db.get_cache_dict(
        db_conn,
        "push_rules_stream",
        entity_column="user_id",
        stream_column="stream_id",
        max_value=self.get_max_push_rules_stream_id(),
    )
    self.push_rules_stream_cache = StreamChangeCache(
        "PushRulesStreamChangeCache",
        push_rules_id,
        prefilled_cache=push_rules_prefill,
    )
def __init__(self, db_conn, hs):
    """Prefill the current-state-delta change cache up to the newest event
    stream position."""
    super(UserDirectorySlaveStore, self).__init__(db_conn, hs)

    newest_stream_id = self._stream_id_gen.get_current_token()
    prefill, min_delta_id = self._get_cache_dict(
        db_conn,
        "current_state_delta_stream",
        entity_column="room_id",
        stream_column="stream_id",
        # As we share the stream id with events token
        max_value=newest_stream_id,
        limit=1000,
    )
    self._curr_state_delta_stream_cache = StreamChangeCache(
        "_curr_state_delta_stream_cache",
        min_delta_id,
        prefilled_cache=prefill,
    )
class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore):
    """Worker-side device store, kept up to date via the device-lists and
    user-signature replication streams."""

    def __init__(self, database: DatabasePool, db_conn, hs):
        super().__init__(database, db_conn, hs)
        self.hs = hs
        # The device-list stream position also advances on the two extra
        # tables, which share the same stream of IDs.
        self._device_list_id_gen = SlavedIdTracker(
            db_conn,
            "device_lists_stream",
            "stream_id",
            extra_tables=[
                ("user_signature_stream", "stream_id"),
                ("device_lists_outbound_pokes", "stream_id"),
            ],
        )
        device_list_max = self._device_list_id_gen.get_current_token()
        self._device_list_stream_cache = StreamChangeCache(
            "DeviceListStreamChangeCache", device_list_max)
        self._user_signature_stream_cache = StreamChangeCache(
            "UserSignatureStreamChangeCache", device_list_max)
        self._device_list_federation_stream_cache = StreamChangeCache(
            "DeviceListFederationStreamChangeCache", device_list_max)

    def get_device_stream_token(self) -> int:
        """Return the latest device-list stream position seen so far."""
        return self._device_list_id_gen.get_current_token()

    def process_replication_rows(self, stream_name, instance_name, token, rows):
        if stream_name == DeviceListsStream.NAME:
            self._device_list_id_gen.advance(instance_name, token)
            self._invalidate_caches_for_devices(token, rows)
        elif stream_name == UserSignatureStream.NAME:
            self._device_list_id_gen.advance(instance_name, token)
            for row in rows:
                self._user_signature_stream_cache.entity_has_changed(
                    row.user_id, token)
        return super().process_replication_rows(stream_name, instance_name, token, rows)

    def _invalidate_caches_for_devices(self, token, rows):
        """Mark changed entities and drop stale per-device caches for each
        replicated device-list row."""
        for row in rows:
            # The entities are either user IDs (starting with '@') whose devices
            # have changed, or remote servers that we need to tell about
            # changes.
            if row.entity.startswith("@"):
                self._device_list_stream_cache.entity_has_changed(
                    row.entity, token)
                self.get_cached_devices_for_user.invalidate((row.entity, ))
                self._get_cached_user_device.invalidate_many((row.entity, ))
                self.get_device_list_last_stream_id_for_remote.invalidate(
                    (row.entity, ))
            else:
                self._device_list_federation_stream_cache.entity_has_changed(
                    row.entity, token)
def test_has_entity_changed_pops_off_start(self):
    """
    StreamChangeCache.entity_has_changed will respect the max size and
    purge the oldest items upon reaching that max size.
    """
    cache = StreamChangeCache("#test", 1, max_size=2)

    # Push three changes into a cache that only holds two.
    for position in (2, 3, 4):
        cache.entity_has_changed("*****@*****.**", position)

    # The cache is at the max size, 2
    self.assertEqual(len(cache._cache), 2)
    # The oldest item has been popped off
    self.assertTrue("*****@*****.**" not in cache._entity_to_key)

    # If we update an existing entity, it keeps the two existing entities
    cache.entity_has_changed("*****@*****.**", 5)
    self.assertEqual(
        {"*****@*****.**", "*****@*****.**"}, set(cache._entity_to_key)
    )
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
    """Prefill the current-state-delta change cache up to the newest event
    stream position."""
    super().__init__(database, db_conn, hs)

    newest_stream_id = self._stream_id_gen.get_current_token()
    prefill, min_delta_id = self.db_pool.get_cache_dict(
        db_conn,
        "current_state_delta_stream",
        entity_column="room_id",
        stream_column="stream_id",
        # As we share the stream id with events token
        max_value=newest_stream_id,
        limit=1000,
    )
    self._curr_state_delta_stream_cache = StreamChangeCache(
        "_curr_state_delta_stream_cache",
        min_delta_id,
        prefilled_cache=prefill,
    )
class AccountDataWorkerStore(SQLBaseStore):
    """This is an abstract base class where subclasses must implement
    `get_max_account_data_stream_id` which can be called in the initializer.
    """

    # This ABCMeta metaclass ensures that we cannot be instantiated without
    # the abstract methods being implemented.
    __metaclass__ = abc.ABCMeta

    def __init__(self, db_conn, hs):
        # The change cache is created before super().__init__ — order
        # preserved; presumably the base initialiser may touch it.
        account_max = self.get_max_account_data_stream_id()
        self._account_data_stream_cache = StreamChangeCache(
            "AccountDataAndTagsChangeCache", account_max
        )
        super(AccountDataWorkerStore, self).__init__(db_conn, hs)

    @abc.abstractmethod
    def get_max_account_data_stream_id(self):
        """Get the current max stream ID for account data stream

        Returns:
            int
        """
        raise NotImplementedError()

    @cached()
    def get_account_data_for_user(self, user_id):
        """Get all the client account_data for a user.

        Args:
            user_id(str): The user to get the account_data for.
        Returns:
            A deferred pair of a dict of global account_data and a dict
            mapping from room_id string to per room account_data dicts.
        """

        def get_account_data_for_user_txn(txn):
            # Global (non-room) account data for the user.
            rows = self._simple_select_list_txn(
                txn, "account_data", {"user_id": user_id},
                ["account_data_type", "content"],
            )
            global_account_data = {
                row["account_data_type"]: json.loads(row["content"]) for row in rows
            }
            # Per-room account data, grouped by room ID.
            rows = self._simple_select_list_txn(
                txn, "room_account_data", {"user_id": user_id},
                ["room_id", "account_data_type", "content"],
            )
            by_room = {}
            for row in rows:
                room_data = by_room.setdefault(row["room_id"], {})
                room_data[row["account_data_type"]] = json.loads(row["content"])
            return global_account_data, by_room

        return self.runInteraction(
            "get_account_data_for_user", get_account_data_for_user_txn
        )

    @cachedInlineCallbacks(num_args=2, max_entries=5000)
    def get_global_account_data_by_type_for_user(self, data_type, user_id):
        """Get the user's global account data of the given type.

        Returns:
            Deferred: A dict, or None if no matching row exists.
        """
        result = yield self._simple_select_one_onecol(
            table="account_data",
            keyvalues={"user_id": user_id, "account_data_type": data_type},
            retcol="content",
            desc="get_global_account_data_by_type_for_user",
            allow_none=True,
        )
        if result:
            return json.loads(result)
        else:
            return None

    @cached(num_args=2)
    def get_account_data_for_room(self, user_id, room_id):
        """Get all the client account_data for a user for a room.

        Args:
            user_id(str): The user to get the account_data for.
            room_id(str): The room to get the account_data for.
        Returns:
            A deferred dict of the room account_data
        """

        def get_account_data_for_room_txn(txn):
            rows = self._simple_select_list_txn(
                txn, "room_account_data", {"user_id": user_id, "room_id": room_id},
                ["account_data_type", "content"],
            )
            return {
                row["account_data_type"]: json.loads(row["content"]) for row in rows
            }

        return self.runInteraction(
            "get_account_data_for_room", get_account_data_for_room_txn
        )

    @cached(num_args=3, max_entries=5000)
    def get_account_data_for_room_and_type(self, user_id, room_id, account_data_type):
        """Get the client account_data of given type for a user for a room.

        Args:
            user_id(str): The user to get the account_data for.
            room_id(str): The room to get the account_data for.
            account_data_type (str): The account data type to get.
        Returns:
            A deferred of the room account_data for that type, or None if
            there isn't any set.
        """

        def get_account_data_for_room_and_type_txn(txn):
            content_json = self._simple_select_one_onecol_txn(
                txn,
                table="room_account_data",
                keyvalues={
                    "user_id": user_id,
                    "room_id": room_id,
                    "account_data_type": account_data_type,
                },
                retcol="content",
                allow_none=True,
            )
            return json.loads(content_json) if content_json else None

        return self.runInteraction(
            "get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn
        )

    def get_all_updated_account_data(
        self, last_global_id, last_room_id, current_id, limit
    ):
        """Get all the client account_data that has changed on the server

        Args:
            last_global_id(int): The position to fetch from for top level data
            last_room_id(int): The position to fetch from for per room data
            current_id(int): The position to fetch up to.
        Returns:
            A deferred pair of lists of tuples of stream_id int, user_id
            string, room_id string, and type string.
        """
        # Fast path: nothing has changed since the caller's positions.
        if last_room_id == current_id and last_global_id == current_id:
            return defer.succeed(([], []))

        def get_updated_account_data_txn(txn):
            sql = (
                "SELECT stream_id, user_id, account_data_type"
                " FROM account_data WHERE ? < stream_id AND stream_id <= ?"
                " ORDER BY stream_id ASC LIMIT ?"
            )
            txn.execute(sql, (last_global_id, current_id, limit))
            global_results = txn.fetchall()
            sql = (
                "SELECT stream_id, user_id, room_id, account_data_type"
                " FROM room_account_data WHERE ? < stream_id AND stream_id <= ?"
                " ORDER BY stream_id ASC LIMIT ?"
            )
            txn.execute(sql, (last_room_id, current_id, limit))
            room_results = txn.fetchall()
            return global_results, room_results

        return self.runInteraction(
            "get_all_updated_account_data_txn", get_updated_account_data_txn
        )

    def get_updated_account_data_for_user(self, user_id, stream_id):
        """Get all the client account_data for a that's changed for a user

        Args:
            user_id(str): The user to get the account_data for.
            stream_id(int): The point in the stream since which to get updates
        Returns:
            A deferred pair of a dict of global account_data and a dict
            mapping from room_id string to per room account_data dicts.
        """

        def get_updated_account_data_for_user_txn(txn):
            sql = (
                "SELECT account_data_type, content FROM account_data"
                " WHERE user_id = ? AND stream_id > ?"
            )
            txn.execute(sql, (user_id, stream_id))
            global_account_data = {row[0]: json.loads(row[1]) for row in txn}
            sql = (
                "SELECT room_id, account_data_type, content FROM room_account_data"
                " WHERE user_id = ? AND stream_id > ?"
            )
            txn.execute(sql, (user_id, stream_id))
            account_data_by_room = {}
            for row in txn:
                room_account_data = account_data_by_room.setdefault(row[0], {})
                room_account_data[row[1]] = json.loads(row[2])
            return global_account_data, account_data_by_room

        # Consult the stream-change cache to avoid hitting the database when
        # nothing for this user has changed since `stream_id`.
        changed = self._account_data_stream_cache.has_entity_changed(
            user_id, int(stream_id)
        )
        if not changed:
            # NOTE(review): this branch returns a plain pair while the branch
            # below returns a Deferred — callers presumably handle both
            # (e.g. via maybeDeferred / yield). TODO confirm.
            return {}, {}

        return self.runInteraction(
            "get_updated_account_data_for_user", get_updated_account_data_for_user_txn
        )

    @cachedInlineCallbacks(num_args=2, cache_context=True, max_entries=5000)
    def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context):
        """Check whether `ignorer_user_id` has `ignored_user_id` listed in
        their m.ignored_user_list account data."""
        ignored_account_data = yield self.get_global_account_data_by_type_for_user(
            "m.ignored_user_list",
            ignorer_user_id,
            on_invalidate=cache_context.invalidate,
        )
        if not ignored_account_data:
            return False
        return ignored_user_id in ignored_account_data.get("ignored_users", {})
class PresenceStore(PresenceBackgroundUpdateStore):
    """Storage for user presence state and the presence replication stream."""

    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        super().__init__(database, db_conn, hs)
        # Only instances configured as presence writers may persist presence.
        self._can_persist_presence = (
            hs.get_instance_name() in hs.config.worker.writers.presence
        )
        if isinstance(database.engine, PostgresEngine):
            # Postgres supports multiple concurrent writers to the stream.
            self._presence_id_gen = MultiWriterIdGenerator(
                db_conn=db_conn,
                db=database,
                stream_name="presence_stream",
                instance_name=self._instance_name,
                tables=[("presence_stream", "instance_name", "stream_id")],
                sequence_name="presence_stream_sequence",
                writers=hs.config.worker.writers.presence,
            )
        else:
            self._presence_id_gen = StreamIdGenerator(
                db_conn, "presence_stream", "stream_id"
            )
        self.hs = hs
        self._presence_on_startup = self._get_active_presence(db_conn)
        presence_cache_prefill, min_presence_val = self.db_pool.get_cache_dict(
            db_conn,
            "presence_stream",
            entity_column="user_id",
            stream_column="stream_id",
            max_value=self._presence_id_gen.get_current_token(),
        )
        self.presence_stream_cache = StreamChangeCache(
            "PresenceStreamChangeCache",
            min_presence_val,
            prefilled_cache=presence_cache_prefill,
        )

    async def update_presence(self, presence_states) -> Tuple[int, int]:
        """Persist a batch of presence states, allocating one stream ID per
        state. Returns the last allocated stream ID and the current max."""
        assert self._can_persist_presence
        stream_ordering_manager = self._presence_id_gen.get_next_mult(
            len(presence_states)
        )
        async with stream_ordering_manager as stream_orderings:
            await self.db_pool.runInteraction(
                "update_presence",
                self._update_presence_txn,
                stream_orderings,
                presence_states,
            )
        return stream_orderings[-1], self._presence_id_gen.get_current_token()

    def _update_presence_txn(self, txn, stream_orderings, presence_states):
        """Transaction body for update_presence: invalidate caches, prune
        superseded rows, then insert the new rows."""
        for stream_id, state in zip(stream_orderings, presence_states):
            txn.call_after(
                self.presence_stream_cache.entity_has_changed,
                state.user_id, stream_id,
            )
            txn.call_after(self._get_presence_for_user.invalidate, (state.user_id, ))

        # Delete old rows to stop database from getting really big
        # NOTE: `stream_id` below is the last value left over from the loop
        # above (the newest allocated ID); rows older than it are pruned for
        # the affected users.
        sql = "DELETE FROM presence_stream WHERE stream_id < ? AND "
        for states in batch_iter(presence_states, 50):
            clause, args = make_in_list_sql_clause(
                self.database_engine, "user_id", [s.user_id for s in states]
            )
            txn.execute(sql + clause, [stream_id] + list(args))

        # Actually insert new rows
        self.db_pool.simple_insert_many_txn(
            txn,
            table="presence_stream",
            keys=(
                "stream_id",
                "user_id",
                "state",
                "last_active_ts",
                "last_federation_update_ts",
                "last_user_sync_ts",
                "status_msg",
                "currently_active",
                "instance_name",
            ),
            values=[(
                stream_id,
                state.user_id,
                state.state,
                state.last_active_ts,
                state.last_federation_update_ts,
                state.last_user_sync_ts,
                state.status_msg,
                state.currently_active,
                self._instance_name,
            ) for stream_id, state in zip(stream_orderings, presence_states)],
        )

    async def get_all_presence_updates(
        self, instance_name: str, last_id: int, current_id: int, limit: int
    ) -> Tuple[List[Tuple[int, list]], int, bool]:
        """Get updates for presence replication stream.

        Args:
            instance_name: The writer we want to fetch updates from. Unused
                here since there is only ever one writer.
            last_id: The token to fetch updates from. Exclusive.
            current_id: The token to fetch updates up to. Inclusive.
            limit: The requested limit for the number of rows to return. The
                function may return more or fewer rows.

        Returns:
            A tuple consisting of: the updates, a token to use to fetch
            subsequent updates, and whether we returned fewer rows than exists
            between the requested tokens due to the limit.

            The token returned can be used in a subsequent call to this
            function to get further updatees.

            The updates are a list of 2-tuples of stream ID and the row data
        """
        if last_id == current_id:
            return [], current_id, False

        def get_all_presence_updates_txn(txn):
            sql = """
                SELECT stream_id, user_id, state, last_active_ts,
                    last_federation_update_ts, last_user_sync_ts, status_msg,
                    currently_active
                FROM presence_stream
                WHERE ? < stream_id AND stream_id <= ?
                ORDER BY stream_id ASC
                LIMIT ?
            """
            txn.execute(sql, (last_id, current_id, limit))
            updates = [(row[0], row[1:]) for row in txn]
            upper_bound = current_id
            limited = False
            # If we hit the limit, report the last row's token so the caller
            # can page from there.
            if len(updates) >= limit:
                upper_bound = updates[-1][0]
                limited = True
            return updates, upper_bound, limited

        return await self.db_pool.runInteraction(
            "get_all_presence_updates", get_all_presence_updates_txn
        )

    @cached()
    def _get_presence_for_user(self, user_id):
        # Only ever called through the cachedList wrapper below.
        raise NotImplementedError()

    @cachedList(
        cached_method_name="_get_presence_for_user",
        list_name="user_ids",
        num_args=1,
    )
    async def get_presence_for_users(self, user_ids):
        """Fetch the latest presence state for each of the given users."""
        rows = await self.db_pool.simple_select_many_batch(
            table="presence_stream",
            column="user_id",
            iterable=user_ids,
            keyvalues={},
            retcols=(
                "user_id",
                "state",
                "last_active_ts",
                "last_federation_update_ts",
                "last_user_sync_ts",
                "status_msg",
                "currently_active",
            ),
            desc="get_presence_for_users",
        )
        for row in rows:
            # The DB stores 0/1; normalise to bool before building the state.
            row["currently_active"] = bool(row["currently_active"])
        return {row["user_id"]: UserPresenceState(**row) for row in rows}

    async def should_user_receive_full_presence_with_token(
        self,
        user_id: str,
        from_token: int,
    ) -> bool:
        """Check whether the given user should receive full presence using
        the stream token they're updating from.

        Args:
            user_id: The ID of the user to check.
            from_token: The stream token included in their /sync token.

        Returns:
            True if the user should have full presence sent to them, False otherwise.
        """

        def _should_user_receive_full_presence_with_token_txn(txn):
            sql = """
                SELECT 1 FROM users_to_send_full_presence_to
                WHERE user_id = ? AND presence_stream_id >= ?
            """
            txn.execute(sql, (user_id, from_token))
            return bool(txn.fetchone())

        return await self.db_pool.runInteraction(
            "should_user_receive_full_presence_with_token",
            _should_user_receive_full_presence_with_token_txn,
        )

    async def add_users_to_send_full_presence_to(self, user_ids: Iterable[str]):
        """Adds to the list of users who should receive a full snapshot of presence
        upon their next sync.

        Args:
            user_ids: An iterable of user IDs.
        """
        # Add user entries to the table, updating the presence_stream_id column if the user already
        # exists in the table.
        presence_stream_id = self._presence_id_gen.get_current_token()
        await self.db_pool.simple_upsert_many(
            table="users_to_send_full_presence_to",
            key_names=("user_id", ),
            key_values=[(user_id, ) for user_id in user_ids],
            value_names=("presence_stream_id", ),
            # We save the current presence stream ID token along with the user ID entry so
            # that when a user /sync's, even if they syncing multiple times across separate
            # devices at different times, each device will receive full presence once - when
            # the presence stream ID in their sync token is less than the one in the table
            # for their user ID.
            value_values=[(presence_stream_id, ) for _ in user_ids],
            desc="add_users_to_send_full_presence_to",
        )

    async def get_presence_for_all_users(
        self,
        include_offline: bool = True,
    ) -> Dict[str, UserPresenceState]:
        """Retrieve the current presence state for all users.

        Note that the presence_stream table is culled frequently, so it should only
        contain the latest presence state for each user.

        Args:
            include_offline: Whether to include offline presence states

        Returns:
            A dict of user IDs to their current UserPresenceState.
        """
        users_to_state = {}
        exclude_keyvalues = None
        if not include_offline:
            # Exclude offline presence state
            exclude_keyvalues = {"state": "offline"}
        # This may be a very heavy database query.
        # We paginate in order to not block a database connection.
        limit = 100
        offset = 0
        while True:
            rows = await self.db_pool.runInteraction(
                "get_presence_for_all_users",
                self.db_pool.simple_select_list_paginate_txn,
                "presence_stream",
                orderby="stream_id",
                start=offset,
                limit=limit,
                exclude_keyvalues=exclude_keyvalues,
                retcols=(
                    "user_id",
                    "state",
                    "last_active_ts",
                    "last_federation_update_ts",
                    "last_user_sync_ts",
                    "status_msg",
                    "currently_active",
                ),
                order_direction="ASC",
            )
            for row in rows:
                users_to_state[row["user_id"]] = UserPresenceState(**row)
            # We've run out of updates to query
            if len(rows) < limit:
                break
            offset += limit
        return users_to_state

    def get_current_presence_token(self):
        """Return the current max presence stream position."""
        return self._presence_id_gen.get_current_token()

    def _get_active_presence(self, db_conn: Connection):
        """Fetch non-offline presence from the database so that we can register
        the appropriate time outs.
        """
        # The `presence_stream_state_not_offline_idx` index should be used for this
        # query.
        sql = (
            "SELECT user_id, state, last_active_ts, last_federation_update_ts,"
            " last_user_sync_ts, status_msg, currently_active FROM presence_stream"
            " WHERE state != ?")
        txn = db_conn.cursor()
        txn.execute(sql, (PresenceState.OFFLINE, ))
        rows = self.db_pool.cursor_to_dict(txn)
        txn.close()
        for row in rows:
            row["currently_active"] = bool(row["currently_active"])
        return [UserPresenceState(**row) for row in rows]

    def take_presence_startup_info(self):
        """Hand over the presence loaded at startup; subsequent calls get None."""
        active_on_startup = self._presence_on_startup
        self._presence_on_startup = None
        return active_on_startup

    def process_replication_rows(self, stream_name, instance_name, token, rows):
        if stream_name == PresenceStream.NAME:
            self._presence_id_gen.advance(instance_name, token)
            for row in rows:
                # Mark the user as changed and drop their cached presence.
                self.presence_stream_cache.entity_has_changed(
                    row.user_id, token)
                self._get_presence_for_user.invalidate((row.user_id, ))
        return super().process_replication_rows(stream_name, instance_name, token, rows)
class ReceiptsStore(SQLBaseStore):
    """Storage for read receipts, in both linearized and graph forms."""

    def __init__(self, hs):
        super(ReceiptsStore, self).__init__(hs)
        # Tracks which rooms gained receipts at which stream positions.
        self._receipts_stream_cache = StreamChangeCache(
            "ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token()
        )

    @cachedInlineCallbacks()
    def get_users_with_read_receipts_in_room(self, room_id):
        """Return the set of user IDs with an m.read receipt in the room."""
        receipts = yield self.get_receipts_for_room(room_id, "m.read")
        defer.returnValue(set(r['user_id'] for r in receipts))

    def _invalidate_get_users_with_receipts_in_room(self, room_id, receipt_type, user_id):
        """Invalidate the cached read-receipt user set for a room, but only
        when the new receipt would actually change it."""
        if receipt_type != "m.read":
            return
        # Returns an ObservableDeferred
        res = self.get_users_with_read_receipts_in_room.cache.get((room_id,), None)
        if res and res.called and user_id in res.result:
            # We'd only be adding to the set, so no point invalidating if the
            # user is already there
            return
        self.get_users_with_read_receipts_in_room.invalidate((room_id,))

    @cached(num_args=2)
    def get_receipts_for_room(self, room_id, receipt_type):
        """Fetch (user_id, event_id) receipt rows for a room and type."""
        return self._simple_select_list(
            table="receipts_linearized",
            keyvalues={
                "room_id": room_id,
                "receipt_type": receipt_type,
            },
            retcols=("user_id", "event_id"),
            desc="get_receipts_for_room",
        )

    @cached(num_args=3)
    def get_last_receipt_event_id_for_user(self, user_id, room_id, receipt_type):
        """Fetch the event ID of the user's receipt in the room, if any."""
        return self._simple_select_one_onecol(
            table="receipts_linearized",
            keyvalues={
                "room_id": room_id,
                "receipt_type": receipt_type,
                "user_id": user_id
            },
            retcol="event_id",
            desc="get_own_receipt_for_user",
            allow_none=True,
        )

    @cachedInlineCallbacks(num_args=2)
    def get_receipts_for_user(self, user_id, receipt_type):
        """Map each room to the event the user's receipt points at."""
        rows = yield self._simple_select_list(
            table="receipts_linearized",
            keyvalues={
                "user_id": user_id,
                "receipt_type": receipt_type,
            },
            retcols=("room_id", "event_id"),
            desc="get_receipts_for_user",
        )
        defer.returnValue({row["room_id"]: row["event_id"] for row in rows})

    @defer.inlineCallbacks
    def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
        """Get receipts for multiple rooms for sending to clients.

        Args:
            room_ids (list): List of room_ids.
            to_key (int): Max stream id to fetch receipts upto.
            from_key (int): Min stream id to fetch receipts from. None fetches
                from the start.
        Returns:
            list: A list of receipts.
        """
        room_ids = set(room_ids)
        if from_key:
            # Narrow down to rooms that actually changed since from_key.
            room_ids = yield self._receipts_stream_cache.get_entities_changed(
                room_ids, from_key
            )
        results = yield self._get_linearized_receipts_for_rooms(
            room_ids, to_key, from_key=from_key
        )
        defer.returnValue([ev for res in results.values() for ev in res])

    @cachedInlineCallbacks(num_args=3, max_entries=5000, lru=True, tree=True)
    def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
        """Get receipts for a single room for sending to clients.

        Args:
            room_ids (str): The room id.
            to_key (int): Max stream id to fetch receipts upto.
            from_key (int): Min stream id to fetch receipts from. None fetches
                from the start.
        Returns:
            list: A list of receipts.
        """

        def f(txn):
            if from_key:
                sql = (
                    "SELECT * FROM receipts_linearized WHERE"
                    " room_id = ? AND stream_id > ? AND stream_id <= ?"
                )
                txn.execute(
                    sql,
                    (room_id, from_key, to_key)
                )
            else:
                sql = (
                    "SELECT * FROM receipts_linearized WHERE"
                    " room_id = ? AND stream_id <= ?"
                )
                txn.execute(
                    sql,
                    (room_id, to_key)
                )
            rows = self.cursor_to_dict(txn)
            return rows

        rows = yield self.runInteraction(
            "get_linearized_receipts_for_room", f
        )
        if not rows:
            defer.returnValue([])
        # Batch all rows into one m.receipt event:
        # content[event_id][receipt_type][user_id] = receipt data
        content = {}
        for row in rows:
            content.setdefault(
                row["event_id"], {}
            ).setdefault(
                row["receipt_type"], {}
            )[row["user_id"]] = json.loads(row["data"])
        defer.returnValue([{
            "type": "m.receipt",
            "room_id": room_id,
            "content": content,
        }])

    @cachedList(cached_method_name="get_linearized_receipts_for_room",
                list_name="room_ids", num_args=3, inlineCallbacks=True)
    def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
        """Batched variant of get_linearized_receipts_for_room: returns a dict
        mapping each requested room_id to a (possibly empty) receipt list."""
        if not room_ids:
            defer.returnValue({})

        def f(txn):
            if from_key:
                sql = (
                    "SELECT * FROM receipts_linearized WHERE"
                    " room_id IN (%s) AND stream_id > ? AND stream_id <= ?"
                ) % (
                    ",".join(["?"] * len(room_ids))
                )
                args = list(room_ids)
                args.extend([from_key, to_key])
                txn.execute(sql, args)
            else:
                sql = (
                    "SELECT * FROM receipts_linearized WHERE"
                    " room_id IN (%s) AND stream_id <= ?"
                ) % (
                    ",".join(["?"] * len(room_ids))
                )
                args = list(room_ids)
                args.append(to_key)
                txn.execute(sql, args)
            return self.cursor_to_dict(txn)

        txn_results = yield self.runInteraction(
            "_get_linearized_receipts_for_rooms", f
        )
        results = {}
        for row in txn_results:
            # We want a single event per room, since we want to batch the
            # receipts by room, event and type.
            room_event = results.setdefault(row["room_id"], {
                "type": "m.receipt",
                "room_id": row["room_id"],
                "content": {},
            })
            # The content is of the form:
            # {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. }
            event_entry = room_event["content"].setdefault(row["event_id"], {})
            receipt_type = event_entry.setdefault(row["receipt_type"], {})
            receipt_type[row["user_id"]] = json.loads(row["data"])
        # Every requested room gets an entry, empty if it had no receipts.
        results = {
            room_id: [results[room_id]] if room_id in results else []
            for room_id in room_ids
        }
        defer.returnValue(results)

    def get_max_receipt_stream_id(self):
        """Return the current max receipts stream position."""
        return self._receipts_id_gen.get_current_token()

    def insert_linearized_receipt_txn(self, txn, room_id, receipt_type,
                                      user_id, event_id, data, stream_id):
        """Insert/replace a linearized receipt inside a transaction.

        Returns False (without writing) when an existing receipt for the same
        user already points at a newer or equal event; True otherwise.
        """
        # Queue cache invalidations to run once the transaction commits.
        txn.call_after(
            self.get_receipts_for_room.invalidate, (room_id, receipt_type)
        )
        txn.call_after(
            self._invalidate_get_users_with_receipts_in_room,
            room_id, receipt_type, user_id,
        )
        txn.call_after(
            self.get_receipts_for_user.invalidate, (user_id, receipt_type)
        )
        # FIXME: This shouldn't invalidate the whole cache
        txn.call_after(self.get_linearized_receipts_for_room.invalidate_many, (room_id,))
        txn.call_after(
            self._receipts_stream_cache.entity_has_changed,
            room_id, stream_id
        )
        txn.call_after(
            self.get_last_receipt_event_id_for_user.invalidate,
            (user_id, room_id, receipt_type)
        )
        res = self._simple_select_one_txn(
            txn,
            table="events",
            retcols=["topological_ordering", "stream_ordering"],
            keyvalues={"event_id": event_id},
            allow_none=True
        )
        topological_ordering = int(res["topological_ordering"]) if res else None
        stream_ordering = int(res["stream_ordering"]) if res else None
        # We don't want to clobber receipts for more recent events, so we
        # have to compare orderings of existing receipts
        sql = (
            "SELECT topological_ordering, stream_ordering, event_id FROM events"
            " INNER JOIN receipts_linearized as r USING (event_id, room_id)"
            " WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ?"
        )
        txn.execute(sql, (room_id, receipt_type, user_id))
        results = txn.fetchall()
        if results and topological_ordering:
            for to, so, _ in results:
                if int(to) > topological_ordering:
                    return False
                elif int(to) == topological_ordering and int(so) >= stream_ordering:
                    return False
        # Replace any existing receipt row for this (room, type, user).
        self._simple_delete_txn(
            txn,
            table="receipts_linearized",
            keyvalues={
                "room_id": room_id,
                "receipt_type": receipt_type,
                "user_id": user_id,
            }
        )
        self._simple_insert_txn(
            txn,
            table="receipts_linearized",
            values={
                "stream_id": stream_id,
                "room_id": room_id,
                "receipt_type": receipt_type,
                "user_id": user_id,
                "event_id": event_id,
                "data": json.dumps(data),
            }
        )
        if receipt_type == "m.read" and topological_ordering:
            # A read receipt clears older push actions for this user/room.
            self._remove_old_push_actions_before_txn(
                txn,
                room_id=room_id,
                user_id=user_id,
                topological_ordering=topological_ordering,
            )
        return True

    @defer.inlineCallbacks
    def insert_receipt(self, room_id, receipt_type, user_id, event_ids, data):
        """Insert a receipt, either from local client or remote server.

        Automatically does conversion between linearized and graph
        representations.
        """
        if not event_ids:
            return
        if len(event_ids) == 1:
            linearized_event_id = event_ids[0]
        else:
            # we need to points in graph -> linearized form.
            # TODO: Make this better.
            def graph_to_linear(txn):
                # NOTE(review): this SQL has no FROM clauses — it looks like
                # it cannot resolve room_id/event_id as written. Verify
                # whether this path is ever exercised before relying on it.
                query = (
                    "SELECT event_id WHERE room_id = ? AND stream_ordering IN ("
                    " SELECT max(stream_ordering) WHERE event_id IN (%s)"
                    ")"
                ) % (",".join(["?"] * len(event_ids)))
                txn.execute(query, [room_id] + event_ids)
                rows = txn.fetchall()
                if rows:
                    return rows[0][0]
                else:
                    raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,))

            linearized_event_id = yield self.runInteraction(
                "insert_receipt_conv", graph_to_linear
            )
        stream_id_manager = self._receipts_id_gen.get_next()
        with stream_id_manager as stream_id:
            have_persisted = yield self.runInteraction(
                "insert_linearized_receipt",
                self.insert_linearized_receipt_txn,
                room_id, receipt_type, user_id, linearized_event_id,
                data,
                stream_id=stream_id,
            )
            if not have_persisted:
                defer.returnValue(None)
        yield self.insert_graph_receipt(
            room_id, receipt_type, user_id, event_ids, data
        )
        max_persisted_id = self._stream_id_gen.get_current_token()
        defer.returnValue((stream_id, max_persisted_id))

    def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids, data):
        """Persist the graph (multi-event) form of a receipt."""
        return self.runInteraction(
            "insert_graph_receipt",
            self.insert_graph_receipt_txn,
            room_id, receipt_type, user_id, event_ids, data
        )

    def insert_graph_receipt_txn(self, txn, room_id, receipt_type,
                                 user_id, event_ids, data):
        """Transaction body for insert_graph_receipt: replace the row and
        queue the related cache invalidations."""
        txn.call_after(
            self.get_receipts_for_room.invalidate, (room_id, receipt_type)
        )
        txn.call_after(
            self._invalidate_get_users_with_receipts_in_room,
            room_id, receipt_type, user_id,
        )
        txn.call_after(
            self.get_receipts_for_user.invalidate, (user_id, receipt_type)
        )
        # FIXME: This shouldn't invalidate the whole cache
        txn.call_after(self.get_linearized_receipts_for_room.invalidate_many, (room_id,))
        self._simple_delete_txn(
            txn,
            table="receipts_graph",
            keyvalues={
                "room_id": room_id,
                "receipt_type": receipt_type,
                "user_id": user_id,
            }
        )
        self._simple_insert_txn(
            txn,
            table="receipts_graph",
            values={
                "room_id": room_id,
                "receipt_type": receipt_type,
                "user_id": user_id,
                "event_ids": json.dumps(event_ids),
                "data": json.dumps(data),
            }
        )

    def get_all_updated_receipts(self, last_id, current_id, limit=None):
        """Fetch receipt rows with last_id < stream_id <= current_id, oldest
        first, optionally capped at `limit` rows."""
        if last_id == current_id:
            return defer.succeed([])

        def get_all_updated_receipts_txn(txn):
            sql = (
                "SELECT stream_id, room_id, receipt_type, user_id, event_id, data"
                " FROM receipts_linearized"
                " WHERE ? < stream_id AND stream_id <= ?"
                " ORDER BY stream_id ASC"
            )
            args = [last_id, current_id]
            if limit is not None:
                sql += " LIMIT ?"
                args.append(limit)
            txn.execute(sql, args)
            return txn.fetchall()

        return self.runInteraction(
            "get_all_updated_receipts", get_all_updated_receipts_txn
        )
def __init__(self, hs):
    """Initialise the receipts store and prime its per-room change cache."""
    super(ReceiptsStore, self).__init__(hs)

    # Seed the change cache at the current receipts stream position so that
    # queries for anything older than this point fall through to the DB.
    current_token = self._receipts_id_gen.get_current_token()
    self._receipts_stream_cache = StreamChangeCache(
        "ReceiptsRoomChangeCache", current_token,
    )
class SlavedEventStore(BaseSlavedStore):
    """Worker-side (read-only) view of the event store.

    Tracks the master's ``events`` and ``backfill`` streams via
    :class:`SlavedIdTracker` and keeps local caches consistent by
    invalidating them as replication rows arrive.  Read paths are borrowed
    directly from the master store classes (see the ``__dict__`` /
    ``__func__`` extractions below).
    """

    def __init__(self, db_conn, hs):
        super(SlavedEventStore, self).__init__(db_conn, hs)
        # Forward stream: stream_ordering increases.
        self._stream_id_gen = SlavedIdTracker(
            db_conn, "events", "stream_ordering",
        )
        # Backfill stream: same column, but tracked in the negative
        # direction (step=-1), hence the negations in stream_positions()
        # and process_replication().
        self._backfill_id_gen = SlavedIdTracker(
            db_conn, "events", "stream_ordering", step=-1
        )
        events_max = self._stream_id_gen.get_current_token()
        # Prefill the per-room change cache from the events table so we do
        # not have to hit the DB for rooms that changed recently.
        event_cache_prefill, min_event_val = self._get_cache_dict(
            db_conn, "events",
            entity_column="room_id",
            stream_column="stream_ordering",
            max_value=events_max,
        )
        self._events_stream_cache = StreamChangeCache(
            "EventsRoomStreamChangeCache", min_event_val,
            prefilled_cache=event_cache_prefill,
        )
        self._membership_stream_cache = StreamChangeCache(
            "MembershipStreamChangeCache", events_max,
        )

        # NOTE(review): fixed at 0 here; on the master this is computed from
        # the DB — confirm nothing on the worker relies on a real value.
        self.stream_ordering_month_ago = 0
        self._stream_order_on_start = self.get_room_max_stream_ordering()

    # Cached functions can't be accessed through a class instance so we need
    # to reach inside the __dict__ to extract them.
    get_rooms_for_user = RoomMemberStore.__dict__["get_rooms_for_user"]
    get_users_in_room = RoomMemberStore.__dict__["get_users_in_room"]
    get_latest_event_ids_in_room = EventFederationStore.__dict__[
        "get_latest_event_ids_in_room"
    ]
    _get_current_state_for_key = StateStore.__dict__[
        "_get_current_state_for_key"
    ]
    get_invited_rooms_for_user = RoomMemberStore.__dict__[
        "get_invited_rooms_for_user"
    ]
    get_unread_event_push_actions_by_room_for_user = (
        EventPushActionsStore.__dict__["get_unread_event_push_actions_by_room_for_user"]
    )
    _get_state_group_for_events = (
        StateStore.__dict__["_get_state_group_for_events"]
    )
    _get_state_group_for_event = (
        StateStore.__dict__["_get_state_group_for_event"]
    )
    _get_state_groups_from_groups = (
        StateStore.__dict__["_get_state_groups_from_groups"]
    )
    _get_state_groups_from_groups_txn = (
        DataStore._get_state_groups_from_groups_txn.__func__
    )
    _get_state_group_from_group = (
        StateStore.__dict__["_get_state_group_from_group"]
    )
    get_recent_event_ids_for_room = (
        StreamStore.__dict__["get_recent_event_ids_for_room"]
    )
    get_unread_push_actions_for_user_in_range_for_http = (
        DataStore.get_unread_push_actions_for_user_in_range_for_http.__func__
    )
    get_unread_push_actions_for_user_in_range_for_email = (
        DataStore.get_unread_push_actions_for_user_in_range_for_email.__func__
    )
    get_push_action_users_in_range = (
        DataStore.get_push_action_users_in_range.__func__
    )
    # Plain (uncached) methods are borrowed from the master DataStore via
    # ``.__func__`` so they bind to this worker instance.
    get_event = DataStore.get_event.__func__
    get_events = DataStore.get_events.__func__
    get_current_state = DataStore.get_current_state.__func__
    get_current_state_for_key = DataStore.get_current_state_for_key.__func__
    get_rooms_for_user_where_membership_is = (
        DataStore.get_rooms_for_user_where_membership_is.__func__
    )
    get_membership_changes_for_user = (
        DataStore.get_membership_changes_for_user.__func__
    )
    get_room_events_max_id = DataStore.get_room_events_max_id.__func__
    get_room_events_stream_for_room = (
        DataStore.get_room_events_stream_for_room.__func__
    )
    get_events_around = DataStore.get_events_around.__func__
    get_state_for_event = DataStore.get_state_for_event.__func__
    get_state_for_events = DataStore.get_state_for_events.__func__
    get_state_groups = DataStore.get_state_groups.__func__
    get_state_groups_ids = DataStore.get_state_groups_ids.__func__
    get_state_ids_for_event = DataStore.get_state_ids_for_event.__func__
    get_state_ids_for_events = DataStore.get_state_ids_for_events.__func__
    get_joined_users_from_state = DataStore.get_joined_users_from_state.__func__
    get_joined_users_from_context = DataStore.get_joined_users_from_context.__func__
    _get_joined_users_from_context = (
        RoomMemberStore.__dict__["_get_joined_users_from_context"]
    )
    get_recent_events_for_room = DataStore.get_recent_events_for_room.__func__
    get_room_events_stream_for_rooms = (
        DataStore.get_room_events_stream_for_rooms.__func__
    )
    is_host_joined = DataStore.is_host_joined.__func__
    _is_host_joined = RoomMemberStore.__dict__["_is_host_joined"]
    get_stream_token_for_event = DataStore.get_stream_token_for_event.__func__

    _set_before_and_after = staticmethod(DataStore._set_before_and_after)

    _get_events = DataStore._get_events.__func__
    _get_events_from_cache = DataStore._get_events_from_cache.__func__
    _invalidate_get_event_cache = DataStore._invalidate_get_event_cache.__func__
    _enqueue_events = DataStore._enqueue_events.__func__
    _do_fetch = DataStore._do_fetch.__func__
    _fetch_event_rows = DataStore._fetch_event_rows.__func__
    _get_event_from_row = DataStore._get_event_from_row.__func__
    _get_rooms_for_user_where_membership_is_txn = (
        DataStore._get_rooms_for_user_where_membership_is_txn.__func__
    )
    _get_members_rows_txn = DataStore._get_members_rows_txn.__func__
    _get_state_for_groups = DataStore._get_state_for_groups.__func__
    _get_all_state_from_cache = DataStore._get_all_state_from_cache.__func__
    _get_events_around_txn = DataStore._get_events_around_txn.__func__
    _get_some_state_from_cache = DataStore._get_some_state_from_cache.__func__
    get_backfill_events = DataStore.get_backfill_events.__func__
    _get_backfill_events = DataStore._get_backfill_events.__func__
    get_missing_events = DataStore.get_missing_events.__func__
    _get_missing_events = DataStore._get_missing_events.__func__
    get_auth_chain = DataStore.get_auth_chain.__func__
    get_auth_chain_ids = DataStore.get_auth_chain_ids.__func__
    _get_auth_chain_ids_txn = DataStore._get_auth_chain_ids_txn.__func__
    get_room_max_stream_ordering = DataStore.get_room_max_stream_ordering.__func__
    get_forward_extremeties_for_room = (
        DataStore.get_forward_extremeties_for_room.__func__
    )
    _get_forward_extremeties_for_room = (
        EventFederationStore.__dict__["_get_forward_extremeties_for_room"]
    )

    def stream_positions(self):
        """Report how far this worker has processed each stream.

        The backfill token is negated on the wire because the backfill
        tracker runs with step=-1 (see __init__).
        """
        result = super(SlavedEventStore, self).stream_positions()
        result["events"] = self._stream_id_gen.get_current_token()
        result["backfill"] = -self._backfill_id_gen.get_current_token()
        return result

    def process_replication(self, result):
        """Apply a batch of replication data: advance the stream trackers
        and invalidate any caches affected by the new rows.
        """
        # Positions at which a state reset occurred; rows at these positions
        # trigger a full state-cache invalidation below.
        state_resets = set(
            r[0] for r in result.get("state_resets", {"rows": []})["rows"]
        )

        stream = result.get("events")
        if stream:
            self._stream_id_gen.advance(int(stream["position"]))
            for row in stream["rows"]:
                self._process_replication_row(
                    row, backfilled=False, state_resets=state_resets
                )

        stream = result.get("backfill")
        if stream:
            # Backfill positions arrive negated; undo that for the tracker.
            self._backfill_id_gen.advance(-int(stream["position"]))
            for row in stream["rows"]:
                self._process_replication_row(
                    row, backfilled=True, state_resets=state_resets
                )

        # Outlier de-outliering only needs the event cache dropped.
        stream = result.get("forward_ex_outliers")
        if stream:
            self._stream_id_gen.advance(int(stream["position"]))
            for row in stream["rows"]:
                event_id = row[1]
                self._invalidate_get_event_cache(event_id)

        stream = result.get("backward_ex_outliers")
        if stream:
            self._backfill_id_gen.advance(-int(stream["position"]))
            for row in stream["rows"]:
                event_id = row[1]
                self._invalidate_get_event_cache(event_id)

        return super(SlavedEventStore, self).process_replication(result)

    def _process_replication_row(self, row, backfilled, state_resets):
        """Decode one replication row (position, internal-metadata JSON,
        event JSON) and invalidate the caches it affects.
        """
        position = row[0]
        internal = json.loads(row[1])
        event_json = json.loads(row[2])
        event = FrozenEvent(event_json, internal_metadata_dict=internal)
        self.invalidate_caches_for_event(
            event, backfilled, reset_state=position in state_resets
        )

    def invalidate_caches_for_event(self, event, backfilled, reset_state):
        """Invalidate every cache that could be stale after ``event``.

        Args:
            event (FrozenEvent): the replicated event.
            backfilled (bool): True for backfill-stream rows; stream-change
                caches are not advanced for backfilled events.
            reset_state (bool): True if a state reset happened at this
                position, forcing a full state/membership cache flush.
        """
        if reset_state:
            self._get_current_state_for_key.invalidate_all()
            self.get_rooms_for_user.invalidate_all()
            self.get_users_in_room.invalidate((event.room_id,))

        self._invalidate_get_event_cache(event.event_id)

        self.get_latest_event_ids_in_room.invalidate((event.room_id,))

        self.get_unread_event_push_actions_by_room_for_user.invalidate_many(
            (event.room_id,)
        )

        if not backfilled:
            self._events_stream_cache.entity_has_changed(
                event.room_id, event.internal_metadata.stream_ordering
            )

        # self.get_unread_event_push_actions_by_room_for_user.invalidate_many(
        #     (event.room_id,)
        # )

        if event.type == EventTypes.Redaction:
            # The redacted event's cached copy is now stale too.
            self._invalidate_get_event_cache(event.redacts)

        if event.type == EventTypes.Member:
            self.get_rooms_for_user.invalidate((event.state_key,))
            self.get_users_in_room.invalidate((event.room_id,))
            self._membership_stream_cache.entity_has_changed(
                event.state_key, event.internal_metadata.stream_ordering
            )
            self.get_invited_rooms_for_user.invalidate((event.state_key,))

        # Only state events affect the current-state cache.
        if not event.is_state():
            return

        if backfilled:
            return

        # Outliers don't change current state, except remote invites which
        # are delivered as outliers.
        if (not event.internal_metadata.is_invite_from_remote()
                and event.internal_metadata.is_outlier()):
            return

        self._get_current_state_for_key.invalidate((
            event.room_id, event.type, event.state_key
        ))
def __init__(self, database: DatabasePool, db_conn, hs):
    """Set up the to-device stream ID generator and its change caches.

    Construction order matters: the id generator must exist before its
    current token is read, and that token bounds both cache prefills.
    """
    super().__init__(database, db_conn, hs)

    self._instance_name = hs.get_instance_name()

    # Map of (user_id, device_id) to the last stream_id that has been
    # deleted up to. This is so that we can no op deletions.
    self._last_device_delete_cache = ExpiringCache(
        cache_name="last_device_delete_cache",
        clock=self._clock,
        max_len=10000,
        expiry_ms=30 * 60 * 1000,
    )

    if isinstance(database.engine, PostgresEngine):
        # On Postgres multiple workers may write to this stream; only the
        # configured writers are allowed to.
        self._can_write_to_device = (
            self._instance_name in hs.config.worker.writers.to_device
        )

        self._device_inbox_id_gen = MultiWriterIdGenerator(
            db_conn=db_conn,
            db=database,
            stream_name="to_device",
            instance_name=self._instance_name,
            tables=[("device_inbox", "instance_name", "stream_id")],
            sequence_name="device_inbox_sequence",
            writers=hs.config.worker.writers.to_device,
        )
    else:
        # SQLite: single process, so this instance can always write.
        self._can_write_to_device = True
        self._device_inbox_id_gen = StreamIdGenerator(
            db_conn, "device_inbox", "stream_id")

    max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
    # Prefill the per-user change cache from recent device_inbox rows.
    device_inbox_prefill, min_device_inbox_id = self.db_pool.get_cache_dict(
        db_conn,
        "device_inbox",
        entity_column="user_id",
        stream_column="stream_id",
        max_value=max_device_inbox_id,
        limit=1000,
    )
    self._device_inbox_stream_cache = StreamChangeCache(
        "DeviceInboxStreamChangeCache",
        min_device_inbox_id,
        prefilled_cache=device_inbox_prefill,
    )

    # The federation outbox and the local device inbox uses the same
    # stream_id generator.
    device_outbox_prefill, min_device_outbox_id = self.db_pool.get_cache_dict(
        db_conn,
        "device_federation_outbox",
        entity_column="destination",
        stream_column="stream_id",
        max_value=max_device_inbox_id,
        limit=1000,
    )
    self._device_federation_outbox_stream_cache = StreamChangeCache(
        "DeviceFederationOutboxStreamChangeCache",
        min_device_outbox_id,
        prefilled_cache=device_outbox_prefill,
    )
class DeviceInboxWorkerStore(SQLBaseStore):
    """Storage for to-device messages: the local ``device_inbox`` and the
    federation ``device_federation_outbox`` tables, which share one
    stream ID generator.
    """

    def __init__(self, database: DatabasePool, db_conn, hs):
        super().__init__(database, db_conn, hs)

        self._instance_name = hs.get_instance_name()

        # Map of (user_id, device_id) to the last stream_id that has been
        # deleted up to. This is so that we can no op deletions.
        self._last_device_delete_cache = ExpiringCache(
            cache_name="last_device_delete_cache",
            clock=self._clock,
            max_len=10000,
            expiry_ms=30 * 60 * 1000,
        )

        if isinstance(database.engine, PostgresEngine):
            # On Postgres several workers may write to this stream; only
            # the configured writers are allowed to.
            self._can_write_to_device = (
                self._instance_name in hs.config.worker.writers.to_device
            )

            self._device_inbox_id_gen = MultiWriterIdGenerator(
                db_conn=db_conn,
                db=database,
                stream_name="to_device",
                instance_name=self._instance_name,
                tables=[("device_inbox", "instance_name", "stream_id")],
                sequence_name="device_inbox_sequence",
                writers=hs.config.worker.writers.to_device,
            )
        else:
            # SQLite: single process, so this instance can always write.
            self._can_write_to_device = True
            self._device_inbox_id_gen = StreamIdGenerator(
                db_conn, "device_inbox", "stream_id")

        max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
        # Prefill the per-user change cache from recent device_inbox rows.
        device_inbox_prefill, min_device_inbox_id = self.db_pool.get_cache_dict(
            db_conn,
            "device_inbox",
            entity_column="user_id",
            stream_column="stream_id",
            max_value=max_device_inbox_id,
            limit=1000,
        )
        self._device_inbox_stream_cache = StreamChangeCache(
            "DeviceInboxStreamChangeCache",
            min_device_inbox_id,
            prefilled_cache=device_inbox_prefill,
        )

        # The federation outbox and the local device inbox uses the same
        # stream_id generator.
        device_outbox_prefill, min_device_outbox_id = self.db_pool.get_cache_dict(
            db_conn,
            "device_federation_outbox",
            entity_column="destination",
            stream_column="stream_id",
            max_value=max_device_inbox_id,
            limit=1000,
        )
        self._device_federation_outbox_stream_cache = StreamChangeCache(
            "DeviceFederationOutboxStreamChangeCache",
            min_device_outbox_id,
            prefilled_cache=device_outbox_prefill,
        )

    def process_replication_rows(self, stream_name, instance_name, token, rows):
        """Advance the stream position and mark changed entities when
        to-device replication rows arrive from a writer.
        """
        if stream_name == ToDeviceStream.NAME:
            self._device_inbox_id_gen.advance(instance_name, token)
            for row in rows:
                # Entities starting with "@" are local user IDs (inbox);
                # anything else is a remote destination (federation outbox).
                if row.entity.startswith("@"):
                    self._device_inbox_stream_cache.entity_has_changed(
                        row.entity, token)
                else:
                    self._device_federation_outbox_stream_cache.entity_has_changed(
                        row.entity, token)
        return super().process_replication_rows(stream_name, instance_name, token, rows)

    def get_to_device_stream_token(self):
        """Return the current position of the to-device stream."""
        return self._device_inbox_id_gen.get_current_token()

    async def get_new_messages_for_device(
        self,
        user_id: str,
        device_id: str,
        last_stream_id: int,
        current_stream_id: int,
        limit: int = 100,
    ) -> Tuple[List[dict], int]:
        """
        Args:
            user_id: The recipient user_id.
            device_id: The recipient device_id.
            last_stream_id: The last stream ID checked.
            current_stream_id: The current position of the to device
                message stream.
            limit: The maximum number of messages to retrieve.

        Returns:
            A list of messages for the device and where in the stream the
            messages got to.
        """
        # Cheap pre-check: skip the DB entirely if nothing changed for this
        # user since last_stream_id.
        has_changed = self._device_inbox_stream_cache.has_entity_changed(
            user_id, last_stream_id)
        if not has_changed:
            return ([], current_stream_id)

        def get_new_messages_for_device_txn(txn):
            sql = ("SELECT stream_id, message_json FROM device_inbox"
                   " WHERE user_id = ? AND device_id = ?"
                   " AND ? < stream_id AND stream_id <= ?"
                   " ORDER BY stream_id ASC"
                   " LIMIT ?")
            txn.execute(
                sql, (user_id, device_id, last_stream_id, current_stream_id, limit))
            messages = []
            for row in txn:
                stream_pos = row[0]
                messages.append(db_to_json(row[1]))
            # If we didn't hit the limit we've seen everything up to
            # current_stream_id, so report that as the new position.
            # NOTE(review): if the query returns zero rows AND limit == 0,
            # stream_pos would be unbound here — callers appear to always
            # pass limit > 0 (default 100); confirm.
            if len(messages) < limit:
                stream_pos = current_stream_id
            return messages, stream_pos

        return await self.db_pool.runInteraction(
            "get_new_messages_for_device", get_new_messages_for_device_txn)

    @trace
    async def delete_messages_for_device(self, user_id: str, device_id: str,
                                         up_to_stream_id: int) -> int:
        """
        Args:
            user_id: The recipient user_id.
            device_id: The recipient device_id.
            up_to_stream_id: Where to delete messages up to.

        Returns:
            The number of messages deleted.
        """
        # If we have cached the last stream id we've deleted up to, we can
        # check if there is likely to be anything that needs deleting
        last_deleted_stream_id = self._last_device_delete_cache.get(
            (user_id, device_id), None)

        set_tag("last_deleted_stream_id", last_deleted_stream_id)

        if last_deleted_stream_id:
            has_changed = self._device_inbox_stream_cache.has_entity_changed(
                user_id, last_deleted_stream_id)
            if not has_changed:
                log_kv({"message": "No changes in cache since last check"})
                return 0

        def delete_messages_for_device_txn(txn):
            sql = ("DELETE FROM device_inbox"
                   " WHERE user_id = ? AND device_id = ?"
                   " AND stream_id <= ?")
            txn.execute(sql, (user_id, device_id, up_to_stream_id))
            return txn.rowcount

        count = await self.db_pool.runInteraction(
            "delete_messages_for_device", delete_messages_for_device_txn)

        log_kv({
            "message": "deleted {} messages for device".format(count),
            "count": count
        })

        # Update the cache, ensuring that we only ever increase the value
        last_deleted_stream_id = self._last_device_delete_cache.get(
            (user_id, device_id), 0)
        self._last_device_delete_cache[(user_id, device_id)] = max(
            last_deleted_stream_id, up_to_stream_id)

        return count

    @trace
    async def get_new_device_msgs_for_remote(
            self, destination, last_stream_id, current_stream_id,
            limit) -> Tuple[List[dict], int]:
        """
        Args:
            destination(str): The name of the remote server.
            last_stream_id(int|long): The last position of the device message
                stream that the server sent up to.
            current_stream_id(int|long): The current position of the device
                message stream.
        Returns:
            A list of messages for the device and where in the stream the
            messages got to.
        """
        set_tag("destination", destination)
        set_tag("last_stream_id", last_stream_id)
        set_tag("current_stream_id", current_stream_id)
        set_tag("limit", limit)

        has_changed = self._device_federation_outbox_stream_cache.has_entity_changed(
            destination, last_stream_id)
        if not has_changed or last_stream_id == current_stream_id:
            log_kv({"message": "No new messages in stream"})
            return ([], current_stream_id)

        if limit <= 0:
            # This can happen if we run out of room for EDUs in the transaction.
            return ([], last_stream_id)

        @trace
        def get_new_messages_for_remote_destination_txn(txn):
            sql = (
                "SELECT stream_id, messages_json FROM device_federation_outbox"
                " WHERE destination = ?"
                " AND ? < stream_id AND stream_id <= ?"
                " ORDER BY stream_id ASC"
                " LIMIT ?")
            txn.execute(
                sql, (destination, last_stream_id, current_stream_id, limit))
            messages = []
            for row in txn:
                stream_pos = row[0]
                messages.append(db_to_json(row[1]))
            # Fewer rows than the limit means we've drained the stream up to
            # current_stream_id.  (limit > 0 is guaranteed by the guard above,
            # so stream_pos is always bound on this path.)
            if len(messages) < limit:
                log_kv({"message": "Set stream position to current position"})
                stream_pos = current_stream_id
            return messages, stream_pos

        return await self.db_pool.runInteraction(
            "get_new_device_msgs_for_remote",
            get_new_messages_for_remote_destination_txn,
        )

    @trace
    async def delete_device_msgs_for_remote(self, destination: str,
                                            up_to_stream_id: int) -> None:
        """Used to delete messages when the remote destination acknowledges
        their receipt.

        Args:
            destination: The destination server_name
            up_to_stream_id: Where to delete messages up to.
        """

        def delete_messages_for_remote_destination_txn(txn):
            sql = ("DELETE FROM device_federation_outbox"
                   " WHERE destination = ?"
                   " AND stream_id <= ?")
            txn.execute(sql, (destination, up_to_stream_id))

        await self.db_pool.runInteraction(
            "delete_device_msgs_for_remote",
            delete_messages_for_remote_destination_txn)

    async def get_all_new_device_messages(
            self, instance_name: str, last_id: int, current_id: int,
            limit: int) -> Tuple[List[Tuple[int, tuple]], int, bool]:
        """Get updates for to device replication stream.

        Args:
            instance_name: The writer we want to fetch updates from. Unused
                here since there is only ever one writer.
            last_id: The token to fetch updates from. Exclusive.
            current_id: The token to fetch updates up to. Inclusive.
            limit: The requested limit for the number of rows to return. The
                function may return more or fewer rows.

        Returns:
            A tuple consisting of: the updates, a token to use to fetch
            subsequent updates, and whether we returned fewer rows than exists
            between the requested tokens due to the limit.

            The token returned can be used in a subsequent call to this
            function to get further updatees.

            The updates are a list of 2-tuples of stream ID and the row data
        """

        if last_id == current_id:
            return [], current_id, False

        def get_all_new_device_messages_txn(txn):
            # We limit like this as we might have multiple rows per stream_id, and
            # we want to make sure we always get all entries for any stream_id
            # we return.
            upper_pos = min(current_id, last_id + limit)
            sql = ("SELECT max(stream_id), user_id"
                   " FROM device_inbox"
                   " WHERE ? < stream_id AND stream_id <= ?"
                   " GROUP BY user_id")
            txn.execute(sql, (last_id, upper_pos))
            updates = [(row[0], row[1:]) for row in txn]

            sql = ("SELECT max(stream_id), destination"
                   " FROM device_federation_outbox"
                   " WHERE ? < stream_id AND stream_id <= ?"
                   " GROUP BY destination")
            txn.execute(sql, (last_id, upper_pos))
            updates.extend((row[0], row[1:]) for row in txn)

            # Order by ascending stream ordering
            updates.sort()

            limited = False
            upto_token = current_id
            if len(updates) >= limit:
                # We may not have reached current_id; tell the caller to
                # fetch again from the last stream ID we did return.
                upto_token = updates[-1][0]
                limited = True

            return updates, upto_token, limited

        return await self.db_pool.runInteraction(
            "get_all_new_device_messages", get_all_new_device_messages_txn)

    @trace
    async def add_messages_to_device_inbox(
        self,
        local_messages_by_user_then_device: dict,
        remote_messages_by_destination: dict,
    ) -> int:
        """Used to send messages from this server.

        Args:
            local_messages_by_user_and_device:
                Dictionary of user_id to device_id to message.
            remote_messages_by_destination:
                Dictionary of destination server_name to the EDU JSON to send.

        Returns:
            The new stream_id.
        """

        # Only configured to-device writers may allocate stream IDs.
        assert self._can_write_to_device

        def add_messages_txn(txn, now_ms, stream_id):
            # Add the local messages directly to the local inbox.
            self._add_messages_to_local_device_inbox_txn(
                txn, stream_id, local_messages_by_user_then_device)

            # Add the remote messages to the federation outbox.
            # We'll send them to a remote server when we next send a
            # federation transaction to that destination.
            self.db_pool.simple_insert_many_txn(
                txn,
                table="device_federation_outbox",
                values=[{
                    "destination": destination,
                    "stream_id": stream_id,
                    "queued_ts": now_ms,
                    "messages_json": json_encoder.encode(edu),
                    "instance_name": self._instance_name,
                } for destination, edu in remote_messages_by_destination.items()],
            )

        async with self._device_inbox_id_gen.get_next() as stream_id:
            now_ms = self.clock.time_msec()
            await self.db_pool.runInteraction(
                "add_messages_to_device_inbox",
                add_messages_txn, now_ms, stream_id)
            # Only mark the caches changed after the transaction commits.
            for user_id in local_messages_by_user_then_device.keys():
                self._device_inbox_stream_cache.entity_has_changed(
                    user_id, stream_id)
            for destination in remote_messages_by_destination.keys():
                self._device_federation_outbox_stream_cache.entity_has_changed(
                    destination, stream_id)

        return self._device_inbox_id_gen.get_current_token()

    async def add_messages_from_remote_to_device_inbox(
            self, origin: str, message_id: str,
            local_messages_by_user_then_device: dict) -> int:
        """Deliver to-device messages received over federation to local
        device inboxes, deduplicating on (origin, message_id).

        Returns:
            The stream_id the messages were written at.
        """
        assert self._can_write_to_device

        def add_messages_txn(txn, now_ms, stream_id):
            # Check if we've already inserted a matching message_id for that
            # origin. This can happen if the origin doesn't receive our
            # acknowledgement from the first time we received the message.
            already_inserted = self.db_pool.simple_select_one_txn(
                txn,
                table="device_federation_inbox",
                keyvalues={
                    "origin": origin,
                    "message_id": message_id
                },
                retcols=("message_id", ),
                allow_none=True,
            )
            if already_inserted is not None:
                return

            # Add an entry for this message_id so that we know we've processed
            # it.
            self.db_pool.simple_insert_txn(
                txn,
                table="device_federation_inbox",
                values={
                    "origin": origin,
                    "message_id": message_id,
                    "received_ts": now_ms,
                },
            )

            # Add the messages to the appropriate local device inboxes so that
            # they'll be sent to the devices when they next sync.
            self._add_messages_to_local_device_inbox_txn(
                txn, stream_id, local_messages_by_user_then_device)

        async with self._device_inbox_id_gen.get_next() as stream_id:
            now_ms = self.clock.time_msec()
            await self.db_pool.runInteraction(
                "add_messages_from_remote_to_device_inbox",
                add_messages_txn,
                now_ms,
                stream_id,
            )
            for user_id in local_messages_by_user_then_device.keys():
                self._device_inbox_stream_cache.entity_has_changed(
                    user_id, stream_id)

        return stream_id

    def _add_messages_to_local_device_inbox_txn(self, txn, stream_id,
                                                messages_by_user_then_device):
        """Insert messages into ``device_inbox`` for devices that exist on
        this server, expanding the "*" wildcard device_id to all of a
        user's devices.
        """
        assert self._can_write_to_device

        local_by_user_then_device = {}
        for user_id, messages_by_device in messages_by_user_then_device.items(
        ):
            messages_json_for_user = {}
            devices = list(messages_by_device.keys())
            if len(devices) == 1 and devices[0] == "*":
                # Handle wildcard device_ids.
                devices = self.db_pool.simple_select_onecol_txn(
                    txn,
                    table="devices",
                    keyvalues={"user_id": user_id},
                    retcol="device_id",
                )

                message_json = json_encoder.encode(messages_by_device["*"])
                for device_id in devices:
                    # Add the message for all devices for this user on this
                    # server.
                    messages_json_for_user[device_id] = message_json
            else:
                if not devices:
                    continue

                rows = self.db_pool.simple_select_many_txn(
                    txn,
                    table="devices",
                    keyvalues={"user_id": user_id},
                    column="device_id",
                    iterable=devices,
                    retcols=("device_id", ),
                )

                for row in rows:
                    # Only insert into the local inbox if the device exists on
                    # this server
                    device_id = row["device_id"]
                    message_json = json_encoder.encode(
                        messages_by_device[device_id])
                    messages_json_for_user[device_id] = message_json

            if messages_json_for_user:
                local_by_user_then_device[user_id] = messages_json_for_user

        if not local_by_user_then_device:
            return

        self.db_pool.simple_insert_many_txn(
            txn,
            table="device_inbox",
            values=[{
                "user_id": user_id,
                "device_id": device_id,
                "stream_id": stream_id,
                "message_json": message_json,
                "instance_name": self._instance_name,
            } for user_id, messages_by_device in
             local_by_user_then_device.items()
             for device_id, message_json in messages_by_device.items()],
        )
class AccountDataWorkerStore(SQLBaseStore):
    """This is an abstract base class where subclasses must implement
    `get_max_account_data_stream_id` which can be called in the initializer.
    """

    # This ABCMeta metaclass ensures that we cannot be instantiated without
    # the abstract methods being implemented.
    # NOTE(review): `__metaclass__` only has an effect on Python 2; under
    # Python 3 this would need `metaclass=abc.ABCMeta` in the class
    # statement — confirm the supported interpreter versions.
    __metaclass__ = abc.ABCMeta

    def __init__(self, db_conn, hs):
        # The abstract method is deliberately called BEFORE super().__init__
        # so the change cache exists as early as possible; subclasses must
        # make get_max_account_data_stream_id() callable at this point.
        account_max = self.get_max_account_data_stream_id()
        self._account_data_stream_cache = StreamChangeCache(
            "AccountDataAndTagsChangeCache", account_max,
        )

        super(AccountDataWorkerStore, self).__init__(db_conn, hs)

    @abc.abstractmethod
    def get_max_account_data_stream_id(self):
        """Get the current max stream ID for account data stream

        Returns:
            int
        """
        raise NotImplementedError()

    @cached()
    def get_account_data_for_user(self, user_id):
        """Get all the client account_data for a user.

        Args:
            user_id(str): The user to get the account_data for.
        Returns:
            A deferred pair of a dict of global account_data and a dict
            mapping from room_id string to per room account_data dicts.
        """

        def get_account_data_for_user_txn(txn):
            rows = self._simple_select_list_txn(
                txn, "account_data", {"user_id": user_id},
                ["account_data_type", "content"]
            )

            global_account_data = {
                row["account_data_type"]: json.loads(row["content"])
                for row in rows
            }

            rows = self._simple_select_list_txn(
                txn, "room_account_data", {"user_id": user_id},
                ["room_id", "account_data_type", "content"]
            )

            by_room = {}
            for row in rows:
                room_data = by_room.setdefault(row["room_id"], {})
                room_data[row["account_data_type"]] = json.loads(row["content"])

            return (global_account_data, by_room)

        return self.runInteraction(
            "get_account_data_for_user", get_account_data_for_user_txn
        )

    @cachedInlineCallbacks(num_args=2, max_entries=5000)
    def get_global_account_data_by_type_for_user(self, data_type, user_id):
        """Fetch one global account_data entry by type for a user.

        Returns:
            Deferred: A dict (the parsed content), or None if no entry of
            that type exists.
        """
        result = yield self._simple_select_one_onecol(
            table="account_data",
            keyvalues={
                "user_id": user_id,
                "account_data_type": data_type,
            },
            retcol="content",
            desc="get_global_account_data_by_type_for_user",
            allow_none=True,
        )

        if result:
            defer.returnValue(json.loads(result))
        else:
            defer.returnValue(None)

    @cached(num_args=2)
    def get_account_data_for_room(self, user_id, room_id):
        """Get all the client account_data for a user for a room.

        Args:
            user_id(str): The user to get the account_data for.
            room_id(str): The room to get the account_data for.
        Returns:
            A deferred dict of the room account_data
        """

        def get_account_data_for_room_txn(txn):
            rows = self._simple_select_list_txn(
                txn, "room_account_data", {"user_id": user_id, "room_id": room_id},
                ["account_data_type", "content"]
            )

            return {
                row["account_data_type"]: json.loads(row["content"])
                for row in rows
            }

        return self.runInteraction(
            "get_account_data_for_room", get_account_data_for_room_txn
        )

    @cached(num_args=3, max_entries=5000)
    def get_account_data_for_room_and_type(self, user_id, room_id,
                                           account_data_type):
        """Get the client account_data of given type for a user for a room.

        Args:
            user_id(str): The user to get the account_data for.
            room_id(str): The room to get the account_data for.
            account_data_type (str): The account data type to get.
        Returns:
            A deferred of the room account_data for that type, or None if
            there isn't any set.
        """

        def get_account_data_for_room_and_type_txn(txn):
            content_json = self._simple_select_one_onecol_txn(
                txn,
                table="room_account_data",
                keyvalues={
                    "user_id": user_id,
                    "room_id": room_id,
                    "account_data_type": account_data_type,
                },
                retcol="content",
                allow_none=True
            )

            return json.loads(content_json) if content_json else None

        return self.runInteraction(
            "get_account_data_for_room_and_type",
            get_account_data_for_room_and_type_txn,
        )

    def get_all_updated_account_data(self, last_global_id, last_room_id,
                                     current_id, limit):
        """Get all the client account_data that has changed on the server

        Args:
            last_global_id(int): The position to fetch from for top level data
            last_room_id(int): The position to fetch from for per room data
            current_id(int): The position to fetch up to.
        Returns:
            A deferred pair of lists of tuples of stream_id int, user_id
            string, room_id string, type string, and content string.
        """
        # Fast path: nothing has changed on either stream.
        if last_room_id == current_id and last_global_id == current_id:
            return defer.succeed(([], []))

        def get_updated_account_data_txn(txn):
            sql = (
                "SELECT stream_id, user_id, account_data_type, content"
                " FROM account_data WHERE ? < stream_id AND stream_id <= ?"
                " ORDER BY stream_id ASC LIMIT ?"
            )
            txn.execute(sql, (last_global_id, current_id, limit))
            global_results = txn.fetchall()

            sql = (
                "SELECT stream_id, user_id, room_id, account_data_type, content"
                " FROM room_account_data WHERE ? < stream_id AND stream_id <= ?"
                " ORDER BY stream_id ASC LIMIT ?"
            )
            txn.execute(sql, (last_room_id, current_id, limit))
            room_results = txn.fetchall()
            return (global_results, room_results)

        return self.runInteraction(
            "get_all_updated_account_data_txn", get_updated_account_data_txn
        )

    def get_updated_account_data_for_user(self, user_id, stream_id):
        """Get all the client account_data for a that's changed for a user

        Args:
            user_id(str): The user to get the account_data for.
            stream_id(int): The point in the stream since which to get updates
        Returns:
            A deferred pair of a dict of global account_data and a dict
            mapping from room_id string to per room account_data dicts.
        """

        def get_updated_account_data_for_user_txn(txn):
            sql = (
                "SELECT account_data_type, content FROM account_data"
                " WHERE user_id = ? AND stream_id > ?"
            )

            txn.execute(sql, (user_id, stream_id))

            global_account_data = {
                row[0]: json.loads(row[1]) for row in txn
            }

            sql = (
                "SELECT room_id, account_data_type, content FROM room_account_data"
                " WHERE user_id = ? AND stream_id > ?"
            )

            txn.execute(sql, (user_id, stream_id))

            account_data_by_room = {}
            for row in txn:
                room_account_data = account_data_by_room.setdefault(row[0], {})
                room_account_data[row[1]] = json.loads(row[2])

            return (global_account_data, account_data_by_room)

        # Consult the stream-change cache first so unchanged users avoid a
        # DB round trip entirely.
        changed = self._account_data_stream_cache.has_entity_changed(
            user_id, int(stream_id)
        )
        if not changed:
            return ({}, {})

        return self.runInteraction(
            "get_updated_account_data_for_user",
            get_updated_account_data_for_user_txn
        )

    @cachedInlineCallbacks(num_args=2, cache_context=True, max_entries=5000)
    def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context):
        """Check whether `ignorer_user_id` has `ignored_user_id` in their
        m.ignored_user_list account data.

        Returns:
            Deferred[bool]
        """
        ignored_account_data = yield self.get_global_account_data_by_type_for_user(
            "m.ignored_user_list", ignorer_user_id,
            on_invalidate=cache_context.invalidate,
        )
        if not ignored_account_data:
            defer.returnValue(False)

        defer.returnValue(
            ignored_user_id in ignored_account_data.get("ignored_users", {})
        )
class PushRulesWorkerStore(
    ApplicationServiceWorkerStore,
    ReceiptsWorkerStore,
    PusherWorkerStore,
    RoomMemberWorkerStore,
    EventsWorkerStore,
    SQLBaseStore,
    metaclass=abc.ABCMeta,
):
    """This is an abstract base class where subclasses must implement
    `get_max_push_rules_stream_id` which can be called in the initializer.
    """

    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        super().__init__(database, db_conn, hs)

        if hs.config.worker.worker_app is None:
            # Main process: we write to the push rules stream, so we need a
            # full stream ID generator.
            self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
                db_conn, "push_rules_stream", "stream_id")
        else:
            # Worker process: we only follow the stream, so a read-only
            # position tracker suffices.
            self._push_rules_stream_id_gen = SlavedIdTracker(
                db_conn, "push_rules_stream", "stream_id")

        # Prefill the stream change cache with recent user_id -> stream_id
        # changes so cold-start lookups can avoid hitting the database.
        push_rules_prefill, push_rules_id = self.db_pool.get_cache_dict(
            db_conn,
            "push_rules_stream",
            entity_column="user_id",
            stream_column="stream_id",
            max_value=self.get_max_push_rules_stream_id(),
        )

        self.push_rules_stream_cache = StreamChangeCache(
            "PushRulesStreamChangeCache",
            push_rules_id,
            prefilled_cache=push_rules_prefill,
        )

    @abc.abstractmethod
    def get_max_push_rules_stream_id(self) -> int:
        """Get the position of the push rules stream.

        Returns:
            int
        """
        raise NotImplementedError()

    @cached(max_entries=5000)
    async def get_push_rules_for_user(self, user_id: str) -> List[JsonDict]:
        """Fetch the user's push rules, sorted by priority and merged with
        their per-rule enabled state."""
        rows = await self.db_pool.simple_select_list(
            table="push_rules",
            keyvalues={"user_name": user_id},
            retcols=(
                "user_name",
                "rule_id",
                "priority_class",
                "priority",
                "conditions",
                "actions",
            ),
            desc="get_push_rules_for_user",
        )

        # Highest priority_class first, then highest priority within a class.
        rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))

        enabled_map = await self.get_push_rules_enabled_for_user(user_id)

        return _load_rules(rows, enabled_map, self.hs.config.experimental)

    @cached(max_entries=5000)
    async def get_push_rules_enabled_for_user(self, user_id: str) -> Dict[str, bool]:
        """Map rule_id -> enabled flag for rules the user has explicitly
        enabled or disabled."""
        results = await self.db_pool.simple_select_list(
            table="push_rules_enable",
            keyvalues={"user_name": user_id},
            retcols=("rule_id", "enabled"),
            desc="get_push_rules_enabled_for_user",
        )
        return {r["rule_id"]: bool(r["enabled"]) for r in results}

    async def have_push_rules_changed_for_user(
        self, user_id: str, last_id: int
    ) -> bool:
        """Return whether the user's push rules have changed since `last_id`."""
        if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
            # The stream change cache is authoritative for "no change", so we
            # can skip the database entirely.
            return False
        else:

            def have_push_rules_changed_txn(txn: LoggingTransaction) -> bool:
                sql = (
                    "SELECT COUNT(stream_id) FROM push_rules_stream"
                    " WHERE user_id = ? AND ? < stream_id"
                )
                txn.execute(sql, (user_id, last_id))
                (count,) = cast(Tuple[int], txn.fetchone())
                return bool(count)

            return await self.db_pool.runInteraction(
                "have_push_rules_changed", have_push_rules_changed_txn
            )

    @cachedList(cached_method_name="get_push_rules_for_user", list_name="user_ids")
    async def bulk_get_push_rules(
        self, user_ids: Collection[str]
    ) -> Dict[str, List[JsonDict]]:
        """Batch variant of `get_push_rules_for_user`."""
        if not user_ids:
            return {}

        # Ensure every requested user appears in the result, even if they
        # have no rules at all.
        results: Dict[str, List[JsonDict]] = {user_id: [] for user_id in user_ids}

        rows = await self.db_pool.simple_select_many_batch(
            table="push_rules",
            column="user_name",
            iterable=user_ids,
            retcols=("*",),
            desc="bulk_get_push_rules",
        )

        # Same priority ordering as get_push_rules_for_user.
        rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))

        for row in rows:
            results.setdefault(row["user_name"], []).append(row)

        enabled_map_by_user = await self.bulk_get_push_rules_enabled(user_ids)

        for user_id, rules in results.items():
            results[user_id] = _load_rules(
                rules, enabled_map_by_user.get(user_id, {}), self.hs.config.experimental
            )

        return results

    @cachedList(
        cached_method_name="get_push_rules_enabled_for_user", list_name="user_ids"
    )
    async def bulk_get_push_rules_enabled(
        self, user_ids: Collection[str]
    ) -> Dict[str, Dict[str, bool]]:
        """Batch variant of `get_push_rules_enabled_for_user`."""
        if not user_ids:
            return {}

        results: Dict[str, Dict[str, bool]] = {user_id: {} for user_id in user_ids}

        rows = await self.db_pool.simple_select_many_batch(
            table="push_rules_enable",
            column="user_name",
            iterable=user_ids,
            retcols=("user_name", "rule_id", "enabled"),
            desc="bulk_get_push_rules_enabled",
        )
        for row in rows:
            enabled = bool(row["enabled"])
            results.setdefault(row["user_name"], {})[row["rule_id"]] = enabled
        return results

    async def get_all_push_rule_updates(
        self, instance_name: str, last_id: int, current_id: int, limit: int
    ) -> Tuple[List[Tuple[int, Tuple[str]]], int, bool]:
        """Get updates for push_rules replication stream.

        Args:
            instance_name: The writer we want to fetch updates from. Unused
                here since there is only ever one writer.
            last_id: The token to fetch updates from. Exclusive.
            current_id: The token to fetch updates up to. Inclusive.
            limit: The requested limit for the number of rows to return. The
                function may return more or fewer rows.

        Returns:
            A tuple consisting of: the updates, a token to use to fetch
            subsequent updates, and whether we returned fewer rows than exists
            between the requested tokens due to the limit.

            The token returned can be used in a subsequent call to this
            function to get further updates.

            The updates are a list of 2-tuples of stream ID and the row data
        """
        if last_id == current_id:
            return [], current_id, False

        def get_all_push_rule_updates_txn(
            txn: LoggingTransaction,
        ) -> Tuple[List[Tuple[int, Tuple[str]]], int, bool]:
            sql = """
                SELECT stream_id, user_id
                FROM push_rules_stream
                WHERE ? < stream_id AND stream_id <= ?
                ORDER BY stream_id ASC
                LIMIT ?
            """
            txn.execute(sql, (last_id, current_id, limit))
            updates = cast(
                List[Tuple[int, Tuple[str]]],
                [(stream_id, (user_id,)) for stream_id, user_id in txn],
            )

            # If we filled the limit there may be more rows; report the last
            # stream ID seen so the caller can page from there.
            limited = False
            upper_bound = current_id
            if len(updates) == limit:
                limited = True
                upper_bound = updates[-1][0]

            return updates, upper_bound, limited

        return await self.db_pool.runInteraction(
            "get_all_push_rule_updates", get_all_push_rule_updates_txn
        )
class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
    """This is an abstract base class where subclasses must implement
    `get_room_max_stream_ordering` and `get_room_min_stream_ordering`
    which can be called in the initializer.
    """

    # NOTE(review): py2-style declaration; on Python 3 this attribute has no
    # metaclass effect.
    __metaclass__ = abc.ABCMeta

    def __init__(self, db_conn, hs):
        super(StreamWorkerStore, self).__init__(db_conn, hs)

        events_max = self.get_room_max_stream_ordering()
        # Prefill the per-room stream change cache from the events table so a
        # freshly started process doesn't treat every room as "changed".
        event_cache_prefill, min_event_val = self._get_cache_dict(
            db_conn,
            "events",
            entity_column="room_id",
            stream_column="stream_ordering",
            max_value=events_max,
        )
        self._events_stream_cache = StreamChangeCache(
            "EventsRoomStreamChangeCache",
            min_event_val,
            prefilled_cache=event_cache_prefill,
        )
        self._membership_stream_cache = StreamChangeCache(
            "MembershipStreamChangeCache", events_max
        )

        # The stream position at process start.
        self._stream_order_on_start = self.get_room_max_stream_ordering()

    @abc.abstractmethod
    def get_room_max_stream_ordering(self):
        """Return the current maximum stream ordering."""
        raise NotImplementedError()

    @abc.abstractmethod
    def get_room_min_stream_ordering(self):
        """Return the current minimum stream ordering."""
        raise NotImplementedError()

    @defer.inlineCallbacks
    def get_room_events_stream_for_rooms(
        self, room_ids, from_key, to_key, limit=0, order='DESC'
    ):
        """Get new room events in stream ordering since `from_key`.

        Args:
            room_ids (list[str])
            from_key (str): Token from which no events are returned before
            to_key (str): Token from which no events are returned after. (This
                is typically the current stream token)
            limit (int): Maximum number of events to return
            order (str): Either "DESC" or "ASC". Determines which events are
                returned when the result is limited. If "DESC" then the most
                recent `limit` events are returned, otherwise returns the
                oldest `limit` events.

        Returns:
            Deferred[dict[str,tuple[list[FrozenEvent], str]]]
                A map from room id to a tuple containing:
                    - list of recent events in the room
                    - stream ordering key for the start of the chunk of events
                      returned.
        """
        from_id = RoomStreamToken.parse_stream_token(from_key).stream

        # Drop rooms the change cache knows have had no events since from_id.
        room_ids = yield self._events_stream_cache.get_entities_changed(
            room_ids, from_id
        )

        if not room_ids:
            defer.returnValue({})

        results = {}
        room_ids = list(room_ids)
        # Query in batches of 20 rooms, fetching each batch in parallel.
        for rm_ids in (room_ids[i : i + 20] for i in range(0, len(room_ids), 20)):
            res = yield make_deferred_yieldable(
                defer.gatherResults(
                    [
                        run_in_background(
                            self.get_room_events_stream_for_room,
                            room_id,
                            from_key,
                            to_key,
                            limit,
                            order=order,
                        )
                        for room_id in rm_ids
                    ],
                    consumeErrors=True,
                )
            )
            results.update(dict(zip(rm_ids, res)))

        defer.returnValue(results)

    def get_rooms_that_changed(self, room_ids, from_key):
        """Given a list of rooms and a token, return rooms where there may have
        been changes.

        Args:
            room_ids (list)
            from_key (str): The room_key portion of a StreamToken
        """
        from_key = RoomStreamToken.parse_stream_token(from_key).stream
        # May over-report (the cache errs on the side of "changed"), never
        # under-reports.
        return set(
            room_id
            for room_id in room_ids
            if self._events_stream_cache.has_entity_changed(room_id, from_key)
        )

    @defer.inlineCallbacks
    def get_room_events_stream_for_room(
        self, room_id, from_key, to_key, limit=0, order='DESC'
    ):
        """Get new room events in stream ordering since `from_key`.

        Args:
            room_id (str)
            from_key (str): Token from which no events are returned before
            to_key (str): Token from which no events are returned after. (This
                is typically the current stream token)
            limit (int): Maximum number of events to return
            order (str): Either "DESC" or "ASC". Determines which events are
                returned when the result is limited. If "DESC" then the most
                recent `limit` events are returned, otherwise returns the
                oldest `limit` events.

        Returns:
            Deferred[tuple[list[FrozenEvent], str]]: Returns the list of
            events (in ascending order) and the token from the start of
            the chunk of events returned.
        """
        if from_key == to_key:
            # Identical tokens bound an empty range.
            defer.returnValue(([], from_key))

        from_id = RoomStreamToken.parse_stream_token(from_key).stream
        to_id = RoomStreamToken.parse_stream_token(to_key).stream

        has_changed = yield self._events_stream_cache.has_entity_changed(
            room_id, from_id
        )

        if not has_changed:
            defer.returnValue(([], from_key))

        def f(txn):
            # NOTE(review): `order` is %-interpolated into the SQL; callers in
            # this class only ever pass "DESC"/"ASC" — confirm before exposing
            # this parameter more widely.
            sql = (
                "SELECT event_id, stream_ordering FROM events WHERE"
                " room_id = ?"
                " AND not outlier"
                " AND stream_ordering > ? AND stream_ordering <= ?"
                " ORDER BY stream_ordering %s LIMIT ?"
            ) % (order,)
            txn.execute(sql, (room_id, from_id, to_id, limit))

            rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
            return rows

        rows = yield self.runInteraction("get_room_events_stream_for_room", f)

        ret = yield self.get_events_as_list(
            [r.event_id for r in rows], get_prev_content=True,
        )

        # from_id is always set at this point, so topo_order is effectively
        # False here (rows were fetched in stream order).
        self._set_before_and_after(ret, rows, topo_order=from_id is None)

        if order.lower() == "desc":
            # The query returned newest-first; callers expect ascending order.
            ret.reverse()

        if rows:
            key = "s%d" % min(r.stream_ordering for r in rows)
        else:
            # Assume we didn't get anything because there was nothing to
            # get.
            key = from_key

        defer.returnValue((ret, key))

    @defer.inlineCallbacks
    def get_membership_changes_for_user(self, user_id, from_key, to_key):
        """Fetch membership events for `user_id` between the two stream tokens.

        Returns:
            Deferred[list[FrozenEvent]]: membership events in ascending stream
            order.
        """
        from_id = RoomStreamToken.parse_stream_token(from_key).stream
        to_id = RoomStreamToken.parse_stream_token(to_key).stream

        if from_key == to_key:
            defer.returnValue([])

        if from_id:
            has_changed = self._membership_stream_cache.has_entity_changed(
                user_id, int(from_id)
            )
            if not has_changed:
                defer.returnValue([])

        def f(txn):
            sql = (
                "SELECT m.event_id, stream_ordering FROM events AS e,"
                " room_memberships AS m"
                " WHERE e.event_id = m.event_id"
                " AND m.user_id = ?"
                " AND e.stream_ordering > ? AND e.stream_ordering <= ?"
                " ORDER BY e.stream_ordering ASC"
            )
            txn.execute(sql, (user_id, from_id, to_id))

            rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]

            return rows

        rows = yield self.runInteraction("get_membership_changes_for_user", f)

        ret = yield self.get_events_as_list(
            [r.event_id for r in rows], get_prev_content=True,
        )

        self._set_before_and_after(ret, rows, topo_order=False)

        defer.returnValue(ret)

    @defer.inlineCallbacks
    def get_recent_events_for_room(self, room_id, limit, end_token):
        """Get the most recent events in the room in topological ordering.

        Args:
            room_id (str)
            limit (int)
            end_token (str): The stream token representing now.

        Returns:
            Deferred[tuple[list[FrozenEvent], str]]: Returns a list of
            events and a token pointing to the start of the returned
            events.
            The events returned are in ascending order.
        """
        rows, token = yield self.get_recent_event_ids_for_room(
            room_id, limit, end_token
        )

        logger.debug("stream before")
        events = yield self.get_events_as_list(
            [r.event_id for r in rows], get_prev_content=True
        )
        logger.debug("stream after")

        self._set_before_and_after(events, rows)

        defer.returnValue((events, token))

    @defer.inlineCallbacks
    def get_recent_event_ids_for_room(self, room_id, limit, end_token):
        """Get the most recent events in the room in topological ordering.

        Args:
            room_id (str)
            limit (int)
            end_token (str): The stream token representing now.

        Returns:
            Deferred[tuple[list[_EventDictReturn], str]]: Returns a list of
            _EventDictReturn and a token pointing to the start of the returned
            events.
            The events returned are in ascending order.
        """
        # Allow a zero limit here, and no-op.
        if limit == 0:
            defer.returnValue(([], end_token))

        end_token = RoomStreamToken.parse(end_token)

        rows, token = yield self.runInteraction(
            "get_recent_event_ids_for_room",
            self._paginate_room_events_txn,
            room_id,
            from_token=end_token,
            limit=limit,
        )

        # We want to return the results in ascending order.
        rows.reverse()

        defer.returnValue((rows, token))

    def get_room_event_after_stream_ordering(self, room_id, stream_ordering):
        """Gets details of the first event in a room at or after a stream ordering

        Args:
            room_id (str):
            stream_ordering (int):

        Returns:
            Deferred[(int, int, str)]:
                (stream ordering, topological ordering, event_id)
        """

        def _f(txn):
            sql = (
                "SELECT stream_ordering, topological_ordering, event_id"
                " FROM events"
                " WHERE room_id = ? AND stream_ordering >= ?"
                " AND NOT outlier"
                " ORDER BY stream_ordering"
                " LIMIT 1"
            )
            txn.execute(sql, (room_id, stream_ordering))
            return txn.fetchone()

        return self.runInteraction("get_room_event_after_stream_ordering", _f)

    @defer.inlineCallbacks
    def get_room_events_max_id(self, room_id=None):
        """Returns the current token for rooms stream.

        By default, it returns the current global stream token. Specifying a
        `room_id` causes it to return the current room specific topological
        token.
        """
        token = yield self.get_room_max_stream_ordering()
        if room_id is None:
            defer.returnValue("s%d" % (token,))
        else:
            topo = yield self.runInteraction(
                "_get_max_topological_txn", self._get_max_topological_txn, room_id
            )
            defer.returnValue("t%d-%d" % (topo, token))

    def get_stream_token_for_event(self, event_id):
        """The stream token for an event
        Args:
            event_id(str): The id of the event to look up a stream token for.
        Raises:
            StoreError if the event wasn't in the database.
        Returns:
            A deferred "s%d" stream token.
        """
        return self._simple_select_one_onecol(
            table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering"
        ).addCallback(lambda row: "s%d" % (row,))

    def get_topological_token_for_event(self, event_id):
        """The stream token for an event
        Args:
            event_id(str): The id of the event to look up a stream token for.
        Raises:
            StoreError if the event wasn't in the database.
        Returns:
            A deferred "t%d-%d" topological token.
        """
        return self._simple_select_one(
            table="events",
            keyvalues={"event_id": event_id},
            retcols=("stream_ordering", "topological_ordering"),
            desc="get_topological_token_for_event",
        ).addCallback(
            lambda row: "t%d-%d" % (row["topological_ordering"], row["stream_ordering"])
        )

    def get_max_topological_token(self, room_id, stream_key):
        """Get the max topological token in a room before the given stream
        ordering.

        Args:
            room_id (str)
            stream_key (int)

        Returns:
            Deferred[int]
        """
        sql = (
            "SELECT coalesce(max(topological_ordering), 0) FROM events"
            " WHERE room_id = ? AND stream_ordering < ?"
        )
        return self._execute(
            "get_max_topological_token", None, sql, room_id, stream_key
        ).addCallback(lambda r: r[0][0] if r else 0)

    def _get_max_topological_txn(self, txn, room_id):
        # Returns 0 for rooms with no events (MAX over empty set is NULL,
        # fetchall still yields one row).
        txn.execute(
            "SELECT MAX(topological_ordering) FROM events" " WHERE room_id = ?",
            (room_id,),
        )

        rows = txn.fetchall()
        return rows[0][0] if rows else 0

    @staticmethod
    def _set_before_and_after(events, rows, topo_order=True):
        """Inserts ordering information to events' internal metadata from
        the DB rows.

        Args:
            events (list[FrozenEvent])
            rows (list[_EventDictReturn])
            topo_order (bool): Whether the events were ordered topologically
                or by stream ordering. If true then all rows should have a non
                null topological_ordering.
        """
        for event, row in zip(events, rows):
            stream = row.stream_ordering
            if topo_order and row.topological_ordering:
                topo = row.topological_ordering
            else:
                topo = None
            internal = event.internal_metadata
            # Tokens point between events: `before` sits just before this
            # event, `after` just after it.
            internal.before = str(RoomStreamToken(topo, stream - 1))
            internal.after = str(RoomStreamToken(topo, stream))
            internal.order = (int(topo) if topo else 0, int(stream))

    @defer.inlineCallbacks
    def get_events_around(
        self, room_id, event_id, before_limit, after_limit, event_filter=None
    ):
        """Retrieve events and pagination tokens around a given event in a
        room.

        Args:
            room_id (str)
            event_id (str)
            before_limit (int)
            after_limit (int)
            event_filter (Filter|None)

        Returns:
            dict
        """
        results = yield self.runInteraction(
            "get_events_around",
            self._get_events_around_txn,
            room_id,
            event_id,
            before_limit,
            after_limit,
            event_filter,
        )

        events_before = yield self.get_events_as_list(
            [e for e in results["before"]["event_ids"]], get_prev_content=True
        )

        events_after = yield self.get_events_as_list(
            [e for e in results["after"]["event_ids"]], get_prev_content=True
        )

        defer.returnValue(
            {
                "events_before": events_before,
                "events_after": events_after,
                "start": results["before"]["token"],
                "end": results["after"]["token"],
            }
        )

    def _get_events_around_txn(
        self, txn, room_id, event_id, before_limit, after_limit, event_filter
    ):
        """Retrieves event_ids and pagination tokens around a given event in a
        room.

        Args:
            room_id (str)
            event_id (str)
            before_limit (int)
            after_limit (int)
            event_filter (Filter|None)

        Returns:
            dict
        """
        results = self._simple_select_one_txn(
            txn,
            "events",
            keyvalues={"event_id": event_id, "room_id": room_id},
            retcols=["stream_ordering", "topological_ordering"],
        )

        # Paginating backwards includes the event at the token, but paginating
        # forward doesn't.
        before_token = RoomStreamToken(
            results["topological_ordering"] - 1, results["stream_ordering"]
        )

        after_token = RoomStreamToken(
            results["topological_ordering"], results["stream_ordering"]
        )

        rows, start_token = self._paginate_room_events_txn(
            txn,
            room_id,
            before_token,
            direction='b',
            limit=before_limit,
            event_filter=event_filter,
        )
        events_before = [r.event_id for r in rows]

        rows, end_token = self._paginate_room_events_txn(
            txn,
            room_id,
            after_token,
            direction='f',
            limit=after_limit,
            event_filter=event_filter,
        )
        events_after = [r.event_id for r in rows]

        return {
            "before": {"event_ids": events_before, "token": start_token},
            "after": {"event_ids": events_after, "token": end_token},
        }

    @defer.inlineCallbacks
    def get_all_new_events_stream(self, from_id, current_id, limit):
        """Get all new events

        Returns all events with from_id < stream_ordering <= current_id.

        Args:
            from_id (int):  the stream_ordering of the last event we processed
            current_id (int):  the stream_ordering of the most recently
                processed event
            limit (int): the maximum number of events to return

        Returns:
            Deferred[Tuple[int, list[FrozenEvent]]]: A tuple of (next_id,
            events), where `next_id` is the next value to pass as `from_id`
            (it will either be the stream_ordering of the last returned event,
            or, if fewer than `limit` events were found, `current_id`.
        """

        def get_all_new_events_stream_txn(txn):
            sql = (
                "SELECT e.stream_ordering, e.event_id"
                " FROM events AS e"
                " WHERE"
                " ? < e.stream_ordering AND e.stream_ordering <= ?"
                " ORDER BY e.stream_ordering ASC"
                " LIMIT ?"
            )

            txn.execute(sql, (from_id, current_id, limit))
            rows = txn.fetchall()

            # If we hit the limit, resume from the last row we actually saw.
            upper_bound = current_id
            if len(rows) == limit:
                upper_bound = rows[-1][0]

            return upper_bound, [row[1] for row in rows]

        upper_bound, event_ids = yield self.runInteraction(
            "get_all_new_events_stream", get_all_new_events_stream_txn
        )

        events = yield self.get_events_as_list(event_ids)

        defer.returnValue((upper_bound, events))

    def get_federation_out_pos(self, typ):
        """Fetch the stored federation stream position for `typ`."""
        return self._simple_select_one_onecol(
            table="federation_stream_position",
            retcol="stream_id",
            keyvalues={"type": typ},
            desc="get_federation_out_pos",
        )

    def update_federation_out_pos(self, typ, stream_id):
        """Persist the federation stream position for `typ`."""
        return self._simple_update_one(
            table="federation_stream_position",
            keyvalues={"type": typ},
            updatevalues={"stream_id": stream_id},
            desc="update_federation_out_pos",
        )

    def has_room_changed_since(self, room_id, stream_id):
        """Whether the room may have had events since `stream_id` (cache-based,
        may over-report)."""
        return self._events_stream_cache.has_entity_changed(room_id, stream_id)

    def _paginate_room_events_txn(
        self,
        txn,
        room_id,
        from_token,
        to_token=None,
        direction='b',
        limit=-1,
        event_filter=None,
    ):
        """Returns list of events before or after a given token.

        Args:
            txn
            room_id (str)
            from_token (RoomStreamToken): The token used to stream from
            to_token (RoomStreamToken|None): A token which if given limits the
                results to only those before
            direction(char): Either 'b' or 'f' to indicate whether we are
                paginating forwards or backwards from `from_key`.
            limit (int): The maximum number of events to return.
            event_filter (Filter|None): If provided filters the events to
                those that match the filter.

        Returns:
            tuple[list[_EventDictReturn], str]: Returns the results as a list
            of _EventDictReturn and a token that points to the end of the
            result set.
        """
        # NOTE(review): the default limit=-1 would fail this assertion;
        # callers evidently always pass an explicit non-negative limit.
        assert int(limit) >= 0

        # Tokens really represent positions between elements, but we use
        # the convention of pointing to the event before the gap. Hence
        # we have a bit of asymmetry when it comes to equalities.
        args = [False, room_id]
        if direction == 'b':
            order = "DESC"
        else:
            order = "ASC"

        bounds = generate_pagination_where_clause(
            direction=direction,
            column_names=("topological_ordering", "stream_ordering"),
            from_token=from_token,
            to_token=to_token,
            engine=self.database_engine,
        )

        filter_clause, filter_args = filter_to_clause(event_filter)

        if filter_clause:
            bounds += " AND " + filter_clause
            args.extend(filter_args)

        args.append(int(limit))

        sql = (
            "SELECT event_id, topological_ordering, stream_ordering"
            " FROM events"
            " WHERE outlier = ? AND room_id = ? AND %(bounds)s"
            " ORDER BY topological_ordering %(order)s,"
            " stream_ordering %(order)s LIMIT ?"
        ) % {"bounds": bounds, "order": order}

        txn.execute(sql, args)

        rows = [_EventDictReturn(row[0], row[1], row[2]) for row in txn]

        if rows:
            topo = rows[-1].topological_ordering
            toke = rows[-1].stream_ordering
            if direction == 'b':
                # Tokens are positions between events.
                # This token points *after* the last event in the chunk.
                # We need it to point to the event before it in the chunk
                # when we are going backwards so we subtract one from the
                # stream part.
                toke -= 1
            next_token = RoomStreamToken(topo, toke)
        else:
            # TODO (erikj): We should work out what to do here instead.
            next_token = to_token if to_token else from_token

        return rows, str(next_token)

    @defer.inlineCallbacks
    def paginate_room_events(
        self, room_id, from_key, to_key=None, direction='b', limit=-1, event_filter=None
    ):
        """Returns list of events before or after a given token.

        Args:
            room_id (str)
            from_key (str): The token used to stream from
            to_key (str|None): A token which if given limits the results to
                only those before
            direction(char): Either 'b' or 'f' to indicate whether we are
                paginating forwards or backwards from `from_key`.
            limit (int): The maximum number of events to return. Zero or less
                means no limit.
            event_filter (Filter|None): If provided filters the events to
                those that match the filter.

        Returns:
            tuple[list[dict], str]: Returns the results as a list of dicts and
            a token that points to the end of the result set. The dicts have
            the keys "event_id", "topological_ordering" and "stream_ordering".
        """
        from_key = RoomStreamToken.parse(from_key)
        if to_key:
            to_key = RoomStreamToken.parse(to_key)

        rows, token = yield self.runInteraction(
            "paginate_room_events",
            self._paginate_room_events_txn,
            room_id,
            from_key,
            to_key,
            direction,
            limit,
            event_filter,
        )

        events = yield self.get_events_as_list(
            [r.event_id for r in rows], get_prev_content=True
        )

        self._set_before_and_after(events, rows)

        defer.returnValue((events, token))
class TypingWriterHandler(FollowerTypingHandler):
    """Typing handler for the instance that owns the typing stream.

    Processes local start/stop typing requests and incoming federation EDUs,
    maintains the in-memory typing state, and notifies local listeners (and,
    via a background task, remote servers) of changes.
    """

    def __init__(self, hs: "HomeServer"):
        super().__init__(hs)

        # This class may only be instantiated on the designated typing writer.
        assert hs.config.worker.writers.typing == hs.get_instance_name()

        self.auth = hs.get_auth()
        self.notifier = hs.get_notifier()
        self.hs = hs

        hs.get_federation_registry().register_edu_handler("m.typing", self._recv_edu)

        hs.get_distributor().observe("user_left_room", self.user_left_room)

        # clock time we expect to stop
        self._member_typing_until = {}  # type: Dict[RoomMember, int]

        # caches which room_ids changed at which serials
        self._typing_stream_change_cache = StreamChangeCache(
            "TypingStreamChangeCache", self._latest_room_serial
        )

    def _handle_timeout_for_member(self, now: int, member: RoomMember) -> None:
        """Expire a member's typing state once their timeout has passed."""
        super()._handle_timeout_for_member(now, member)

        if not self.is_typing(member):
            # Nothing to do if they're no longer typing
            return

        until = self._member_typing_until.get(member, None)
        if not until or until <= now:
            logger.info("Timing out typing for: %s", member.user_id)
            self._stopped_typing(member)
            return

    async def started_typing(
        self, target_user: UserID, requester: Requester, room_id: str, timeout: int
    ) -> None:
        """Mark `target_user` as typing in `room_id` for `timeout` ms.

        Raises:
            SynapseError: if the target user is not hosted here.
            AuthError: if the requester is not the target user.
            ShadowBanError: if the requester is shadow-banned.
        """
        target_user_id = target_user.to_string()
        auth_user_id = requester.user.to_string()

        if not self.is_mine_id(target_user_id):
            raise SynapseError(400, "User is not hosted on this homeserver")

        if target_user_id != auth_user_id:
            raise AuthError(400, "Cannot set another user's typing state")

        if requester.shadow_banned:
            # We randomly sleep a bit just to annoy the requester.
            await self.clock.sleep(random.randint(1, 10))
            raise ShadowBanError()

        await self.auth.check_user_in_room(room_id, target_user_id)

        logger.debug("%s has started typing in %s", target_user_id, room_id)

        member = RoomMember(room_id=room_id, user_id=target_user_id)

        was_present = member.user_id in self._room_typing.get(room_id, set())

        now = self.clock.time_msec()
        self._member_typing_until[member] = now + timeout

        self.wheel_timer.insert(now=now, obj=member, then=now + timeout)

        if was_present:
            # No point sending another notification
            return

        self._push_update(member=member, typing=True)

    async def stopped_typing(
        self, target_user: UserID, requester: Requester, room_id: str
    ) -> None:
        """Mark `target_user` as no longer typing in `room_id`.

        Raises the same errors as `started_typing`.
        """
        target_user_id = target_user.to_string()
        auth_user_id = requester.user.to_string()

        if not self.is_mine_id(target_user_id):
            raise SynapseError(400, "User is not hosted on this homeserver")

        if target_user_id != auth_user_id:
            raise AuthError(400, "Cannot set another user's typing state")

        if requester.shadow_banned:
            # We randomly sleep a bit just to annoy the requester.
            await self.clock.sleep(random.randint(1, 10))
            raise ShadowBanError()

        await self.auth.check_user_in_room(room_id, target_user_id)

        logger.debug("%s has stopped typing in %s", target_user_id, room_id)

        member = RoomMember(room_id=room_id, user_id=target_user_id)

        self._stopped_typing(member)

    def user_left_room(self, user: UserID, room_id: str) -> None:
        """Distributor callback: clear typing state when a local user leaves."""
        user_id = user.to_string()
        if self.is_mine_id(user_id):
            member = RoomMember(room_id=room_id, user_id=user_id)
            self._stopped_typing(member)

    def _stopped_typing(self, member: RoomMember) -> None:
        """Clear a member's typing state and push an update if they were typing."""
        if member.user_id not in self._room_typing.get(member.room_id, set()):
            # No point
            return

        self._member_typing_until.pop(member, None)
        self._member_last_federation_poke.pop(member, None)

        self._push_update(member=member, typing=False)

    def _push_update(self, member: RoomMember, typing: bool) -> None:
        """Push a typing change locally and, for our own users, remotely."""
        if self.hs.is_mine_id(member.user_id):
            # Only send updates for changes to our own users.
            run_as_background_process(
                "typing._push_remote", self._push_remote, member, typing
            )

        self._push_update_local(member=member, typing=typing)

    async def _recv_edu(self, origin: str, content: JsonDict) -> None:
        """Handle an incoming m.typing EDU from federation."""
        room_id = content["room_id"]
        user_id = content["user_id"]

        member = RoomMember(user_id=user_id, room_id=room_id)

        # Check that the string is a valid user id
        user = UserID.from_string(user_id)

        if user.domain != origin:
            # A server may only assert typing state for its own users.
            logger.info(
                "Got typing update from %r with bad 'user_id': %r", origin, user_id
            )
            return

        users = await self.store.get_users_in_room(room_id)
        domains = {get_domain_from_id(u) for u in users}

        # Ignore updates for rooms we don't share with the sender's users.
        if self.server_name in domains:
            logger.info("Got typing update from %s: %r", user_id, content)
            now = self.clock.time_msec()
            self._member_typing_until[member] = now + FEDERATION_TIMEOUT
            self.wheel_timer.insert(now=now, obj=member, then=now + FEDERATION_TIMEOUT)
            self._push_update_local(member=member, typing=content["typing"])

    def _push_update_local(self, member: RoomMember, typing: bool) -> None:
        """Apply a typing change to local state and notify local listeners."""
        room_set = self._room_typing.setdefault(member.room_id, set())
        if typing:
            room_set.add(member.user_id)
        else:
            room_set.discard(member.user_id)

        # Advance the typing stream serial and record the change so the
        # replication stream can pick it up.
        self._latest_room_serial += 1
        self._room_serials[member.room_id] = self._latest_room_serial
        self._typing_stream_change_cache.entity_has_changed(
            member.room_id, self._latest_room_serial
        )

        self.notifier.on_new_event(
            "typing_key", self._latest_room_serial, rooms=[member.room_id]
        )

    async def get_all_typing_updates(
        self, instance_name: str, last_id: int, current_id: int, limit: int
    ) -> Tuple[List[Tuple[int, list]], int, bool]:
        """Get updates for typing replication stream.

        Args:
            instance_name: The writer we want to fetch updates from. Unused
                here since there is only ever one writer.
            last_id: The token to fetch updates from. Exclusive.
            current_id: The token to fetch updates up to. Inclusive.
            limit: The requested limit for the number of rows to return. The
                function may return more or fewer rows.

        Returns:
            A tuple consisting of: the updates, a token to use to fetch
            subsequent updates, and whether we returned fewer rows than exists
            between the requested tokens due to the limit.

            The token returned can be used in a subsequent call to this
            function to get further updates.

            The updates are a list of 2-tuples of stream ID and the row data
        """

        if last_id == current_id:
            return [], current_id, False

        changed_rooms = self._typing_stream_change_cache.get_all_entities_changed(
            last_id
        )  # type: Optional[Iterable[str]]
        if changed_rooms is None:
            # The change cache doesn't cover last_id; fall back to scanning
            # every room we know about.
            changed_rooms = self._room_serials

        rows = []
        for room_id in changed_rooms:
            serial = self._room_serials[room_id]
            if last_id < serial <= current_id:
                typing = self._room_typing[room_id]
                rows.append((serial, [room_id, list(typing)]))
        rows.sort()

        limited = False
        # We, unusually, use a strict limit here as we have all the rows in
        # memory rather than pulling them out of the database with a `LIMIT ?`
        # clause.
        if len(rows) > limit:
            rows = rows[:limit]
            current_id = rows[-1][0]
            limited = True

        return rows, current_id, limited

    def process_replication_rows(
        self, token: int, rows: List[TypingStream.TypingStreamRow]
    ) -> None:
        # The writing process should never get updates from replication.
        raise Exception("Typing writer instance got typing info over replication")
class PushRulesWorkerStore(
    ApplicationServiceWorkerStore,
    ReceiptsWorkerStore,
    PusherWorkerStore,
    RoomMemberWorkerStore,
    SQLBaseStore,
):
    """This is an abstract base class where subclasses must implement
    `get_max_push_rules_stream_id` which can be called in the initializer.
    """

    # This ABCMeta metaclass ensures that we cannot be instantiated without
    # the abstract methods being implemented.
    __metaclass__ = abc.ABCMeta

    def __init__(self, db_conn, hs):
        super(PushRulesWorkerStore, self).__init__(db_conn, hs)

        # Prefill the stream change cache with the users whose rules changed
        # most recently, so has_entity_changed() can usually answer from
        # memory instead of hitting the database.
        push_rules_prefill, push_rules_id = self._get_cache_dict(
            db_conn, "push_rules_stream",
            entity_column="user_id",
            stream_column="stream_id",
            max_value=self.get_max_push_rules_stream_id(),
        )

        self.push_rules_stream_cache = StreamChangeCache(
            "PushRulesStreamChangeCache", push_rules_id,
            prefilled_cache=push_rules_prefill,
        )

    @abc.abstractmethod
    def get_max_push_rules_stream_id(self):
        """Get the position of the push rules stream.

        Returns:
            int
        """
        raise NotImplementedError()

    @cachedInlineCallbacks(max_entries=5000)
    def get_push_rules_for_user(self, user_id):
        """Fetch one user's push rules, ordered by descending priority and
        annotated with their enabled state.

        Args:
            user_id (str): ID of the user whose rules to fetch.
        Returns:
            Deferred[list]: the rules as produced by `_load_rules`.
        """
        rows = yield self._simple_select_list(
            table="push_rules",
            keyvalues={"user_name": user_id},
            retcols=(
                "user_name",
                "rule_id",
                "priority_class",
                "priority",
                "conditions",
                "actions",
            ),
            # Bug fix: was "get_push_rules_enabled_for_user" (copy-paste),
            # which mis-attributed these transactions in metrics/logs.
            desc="get_push_rules_for_user",
        )

        # Highest priority_class first, then highest priority within a class.
        rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))

        enabled_map = yield self.get_push_rules_enabled_for_user(user_id)

        rules = _load_rules(rows, enabled_map)

        defer.returnValue(rules)

    @cachedInlineCallbacks(max_entries=5000)
    def get_push_rules_enabled_for_user(self, user_id):
        """Fetch the explicit enable/disable overrides for one user's rules.

        Returns:
            Deferred[dict]: maps rule_id -> bool (True means enabled).
        """
        results = yield self._simple_select_list(
            table="push_rules_enable",
            keyvalues={'user_name': user_id},
            retcols=("user_name", "rule_id", "enabled"),
            desc="get_push_rules_enabled_for_user",
        )
        defer.returnValue(
            {r['rule_id']: False if r['enabled'] == 0 else True for r in results}
        )

    def have_push_rules_changed_for_user(self, user_id, last_id):
        """Check whether the user's push rules have changed since `last_id`.

        Returns:
            Deferred[bool]
        """
        if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
            # The stream cache is authoritative for "definitely unchanged",
            # so we can skip the database entirely.
            return defer.succeed(False)
        else:
            def have_push_rules_changed_txn(txn):
                sql = (
                    "SELECT COUNT(stream_id) FROM push_rules_stream"
                    " WHERE user_id = ? AND ? < stream_id"
                )
                txn.execute(sql, (user_id, last_id))
                count, = txn.fetchone()
                return bool(count)
            return self.runInteraction(
                "have_push_rules_changed", have_push_rules_changed_txn
            )

    @cachedList(
        cached_method_name="get_push_rules_for_user",
        list_name="user_ids",
        num_args=1,
        inlineCallbacks=True,
    )
    def bulk_get_push_rules(self, user_ids):
        """Fetch push rules for many users at once.

        Args:
            user_ids (iterable[str]): users to fetch rules for.
        Returns:
            Deferred[dict]: maps user_id -> that user's loaded rules; every
            requested user_id is present (possibly with an empty list).
        """
        if not user_ids:
            defer.returnValue({})

        results = {user_id: [] for user_id in user_ids}

        rows = yield self._simple_select_many_batch(
            table="push_rules",
            column="user_name",
            iterable=user_ids,
            retcols=("*",),
            desc="bulk_get_push_rules",
        )

        # Highest priority_class first, then highest priority within a class.
        rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))

        for row in rows:
            results.setdefault(row['user_name'], []).append(row)

        enabled_map_by_user = yield self.bulk_get_push_rules_enabled(user_ids)

        for user_id, rules in results.items():
            results[user_id] = _load_rules(rules, enabled_map_by_user.get(user_id, {}))

        defer.returnValue(results)

    @defer.inlineCallbacks
    def move_push_rule_from_room_to_room(self, new_room_id, user_id, rule):
        """Move a single push rule from one room to another for a specific user.

        Args:
            new_room_id (str): ID of the new room.
            user_id (str): ID of user the push rule belongs to.
            rule (Dict): A push rule.
        """
        # Create new rule id: keep the scope prefix, swap in the new room id.
        rule_id_scope = '/'.join(rule["rule_id"].split('/')[:-1])
        new_rule_id = rule_id_scope + "/" + new_room_id

        # Change room id in each condition
        for condition in rule.get("conditions", []):
            if condition.get("key") == "room_id":
                condition["pattern"] = new_room_id

        # Add the rule for the new room
        yield self.add_push_rule(
            user_id=user_id,
            rule_id=new_rule_id,
            priority_class=rule["priority_class"],
            conditions=rule["conditions"],
            actions=rule["actions"],
        )

        # Delete push rule for the old room
        yield self.delete_push_rule(user_id, rule["rule_id"])

    @defer.inlineCallbacks
    def move_push_rules_from_room_to_room_for_user(
        self, old_room_id, new_room_id, user_id
    ):
        """Move all of the push rules from one room to another for a specific
        user.

        Args:
            old_room_id (str): ID of the old room.
            new_room_id (str): ID of the new room.
            user_id (str): ID of user to copy push rules for.
        """
        # Retrieve push rules for this user
        user_push_rules = yield self.get_push_rules_for_user(user_id)

        # Get rules relating to the old room, move them to the new room, then
        # delete them from the old room
        for rule in user_push_rules:
            conditions = rule.get("conditions", [])
            if any(
                (c.get("key") == "room_id" and c.get("pattern") == old_room_id)
                for c in conditions
            ):
                # Bug fix: this call was previously not yielded, so the
                # move ran fire-and-forget and any failure (e.g. in
                # add_push_rule/delete_push_rule) was silently dropped.
                yield self.move_push_rule_from_room_to_room(
                    new_room_id, user_id, rule
                )

    @defer.inlineCallbacks
    def bulk_get_push_rules_for_room(self, event, context):
        """Fetch the push rules of everyone whose rules should run against
        `event`, using `context`'s room state.
        """
        state_group = context.state_group
        if not state_group:
            # If state_group is None it means it has yet to be assigned a
            # state group, i.e. we need to make sure that calls with a state_group
            # of None don't hit previous cached calls with a None state_group.
            # To do this we set the state_group to a new object as object() != object()
            state_group = object()

        current_state_ids = yield context.get_current_state_ids(self)
        result = yield self._bulk_get_push_rules_for_room(
            event.room_id, state_group, current_state_ids, event=event
        )
        defer.returnValue(result)

    @cachedInlineCallbacks(num_args=2, cache_context=True)
    def _bulk_get_push_rules_for_room(
        self, room_id, state_group, current_state_ids, cache_context, event=None
    ):
        # We don't use `state_group`, its there so that we can cache based
        # on it. However, its important that its never None, since two current_state's
        # with a state_group of None are likely to be different.
        # See bulk_get_push_rules_for_room for how we work around this.
        assert state_group is not None

        # We also will want to generate notifs for other people in the room so
        # their unread counts are correct in the event stream, but to avoid
        # generating them for bot / AS users etc, we only do so for people who've
        # sent a read receipt into the room.

        users_in_room = yield self._get_joined_users_from_context(
            room_id, state_group, current_state_ids,
            on_invalidate=cache_context.invalidate,
            event=event,
        )

        # We ignore app service users for now. This is so that we don't fill
        # up the `get_if_users_have_pushers` cache with AS entries that we
        # know don't have pushers, nor even read receipts.
        local_users_in_room = set(
            u for u in users_in_room
            if self.hs.is_mine_id(u)
            and not self.get_if_app_services_interested_in_user(u)
        )

        # users in the room who have pushers need to get push rules run because
        # that's how their pushers work
        if_users_with_pushers = yield self.get_if_users_have_pushers(
            local_users_in_room, on_invalidate=cache_context.invalidate
        )
        user_ids = set(
            uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
        )

        users_with_receipts = yield self.get_users_with_read_receipts_in_room(
            room_id, on_invalidate=cache_context.invalidate
        )

        # Local users who have sent a read receipt also get their rules run,
        # so their unread counts stay correct.
        for uid in users_with_receipts:
            if uid in local_users_in_room:
                user_ids.add(uid)

        rules_by_user = yield self.bulk_get_push_rules(
            user_ids, on_invalidate=cache_context.invalidate
        )

        rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None}

        defer.returnValue(rules_by_user)

    @cachedList(
        cached_method_name="get_push_rules_enabled_for_user",
        list_name="user_ids",
        num_args=1,
        inlineCallbacks=True,
    )
    def bulk_get_push_rules_enabled(self, user_ids):
        """Fetch enable/disable overrides for many users at once.

        Returns:
            Deferred[dict]: maps user_id -> {rule_id: bool}; every requested
            user_id is present (possibly with an empty dict).
        """
        if not user_ids:
            defer.returnValue({})

        results = {user_id: {} for user_id in user_ids}

        rows = yield self._simple_select_many_batch(
            table="push_rules_enable",
            column="user_name",
            iterable=user_ids,
            retcols=("user_name", "rule_id", "enabled"),
            desc="bulk_get_push_rules_enabled",
        )
        for row in rows:
            enabled = bool(row['enabled'])
            results.setdefault(row['user_name'], {})[row['rule_id']] = enabled
        defer.returnValue(results)
def __init__(self, db_conn, hs): super(ReceiptsWorkerStore, self).__init__(db_conn, hs) self._receipts_stream_cache = StreamChangeCache( "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id() )
class SlavedAccountDataStore(BaseSlavedStore):
    """Read-only replica of the account data / tags stores.

    Tracks the master's account data stream position and keeps the local
    caches consistent by invalidating entries as replication rows arrive.
    """

    def __init__(self, db_conn, hs):
        super(SlavedAccountDataStore, self).__init__(db_conn, hs)
        self._account_data_id_gen = SlavedIdTracker(
            db_conn, "account_data_max_stream_id", "stream_id",
        )
        self._account_data_stream_cache = StreamChangeCache(
            "AccountDataAndTagsChangeCache",
            self._account_data_id_gen.get_current_token(),
        )

    # Borrow the cached read-side implementations from the master stores.
    get_account_data_for_user = (
        AccountDataStore.__dict__["get_account_data_for_user"]
    )

    get_global_account_data_by_type_for_users = (
        AccountDataStore.__dict__["get_global_account_data_by_type_for_users"]
    )

    get_global_account_data_by_type_for_user = (
        AccountDataStore.__dict__["get_global_account_data_by_type_for_user"]
    )

    get_tags_for_user = TagsStore.__dict__["get_tags_for_user"]

    get_updated_tags = DataStore.get_updated_tags.__func__
    get_updated_account_data_for_user = (
        DataStore.get_updated_account_data_for_user.__func__
    )

    def get_max_account_data_stream_id(self):
        """Return the current (int) position of the account data stream."""
        return self._account_data_id_gen.get_current_token()

    def stream_positions(self):
        """Report our positions; all three account-data streams share the
        same token.
        """
        result = super(SlavedAccountDataStore, self).stream_positions()
        position = self._account_data_id_gen.get_current_token()
        for stream_key in ("user_account_data", "room_account_data", "tag_account_data"):
            result[stream_key] = position
        return result

    def process_replication(self, result):
        """Apply replicated account-data updates: advance the id tracker and
        invalidate the caches touched by each row.
        """
        user_stream = result.get("user_account_data")
        if user_stream:
            self._account_data_id_gen.advance(int(user_stream["position"]))
            for row in user_stream["rows"]:
                position, user_id, data_type = row[:3]
                # Global account data rows carry a type as well as a user.
                self.get_global_account_data_by_type_for_user.invalidate(
                    (data_type, user_id,)
                )
                self.get_account_data_for_user.invalidate((user_id,))
                self._account_data_stream_cache.entity_has_changed(
                    user_id, position
                )

        room_stream = result.get("room_account_data")
        if room_stream:
            self._account_data_id_gen.advance(int(room_stream["position"]))
            for row in room_stream["rows"]:
                position, user_id = row[:2]
                self.get_account_data_for_user.invalidate((user_id,))
                self._account_data_stream_cache.entity_has_changed(
                    user_id, position
                )

        tag_stream = result.get("tag_account_data")
        if tag_stream:
            self._account_data_id_gen.advance(int(tag_stream["position"]))
            for row in tag_stream["rows"]:
                position, user_id = row[:2]
                self.get_tags_for_user.invalidate((user_id,))
                self._account_data_stream_cache.entity_has_changed(
                    user_id, position
                )

        return super(SlavedAccountDataStore, self).process_replication(result)