def test_prefilled_cache(self):
    """
    Providing a prefilled cache to StreamChangeCache will result in a cache
    with the prefilled-cache entered in.
    """
    cache = StreamChangeCache("#test", 1, prefilled_cache={"user@foo.com": 2})
    self.assertTrue(cache.has_entity_changed("user@foo.com", 1))
def test_has_entity_changed(self):
    """
    StreamChangeCache.entity_has_changed will mark entities as changed, and
    has_entity_changed will observe the changed entities.
    """
    cache = StreamChangeCache("#test", 3)

    cache.entity_has_changed("user@foo.com", 6)
    cache.entity_has_changed("bar@baz.net", 7)

    # If it's been changed after that stream position, return True
    self.assertTrue(cache.has_entity_changed("user@foo.com", 4))
    self.assertTrue(cache.has_entity_changed("bar@baz.net", 4))

    # If it's been changed at that stream position, return False
    self.assertFalse(cache.has_entity_changed("user@foo.com", 6))

    # If there are no changes after that stream position, return False
    self.assertFalse(cache.has_entity_changed("user@foo.com", 7))

    # If the entity does not exist, return False.
    self.assertFalse(cache.has_entity_changed("not@here.website", 7))

    # If we request before the stream cache's earliest known position,
    # return True, whether it's a known entity or not.
    self.assertTrue(cache.has_entity_changed("user@foo.com", 0))
    self.assertTrue(cache.has_entity_changed("not@here.website", 0))
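
# A minimal usage sketch of the behaviour the tests above exercise, assuming
# the real import path `synapse.util.caches.stream_change_cache`; the entity
# names are illustrative only. The key property: for positions older than the
# cache's earliest known position it must answer True, because it cannot
# prove that nothing changed before the window it tracks.

from synapse.util.caches.stream_change_cache import StreamChangeCache


def demo_stream_change_cache() -> None:
    # The cache knows nothing about positions at or before 5.
    cache = StreamChangeCache("DemoCache", 5)
    cache.entity_has_changed("@alice:example.com", 7)

    # Changed at 7, which is after position 6.
    assert cache.has_entity_changed("@alice:example.com", 6)
    # No changes recorded after position 7.
    assert not cache.has_entity_changed("@alice:example.com", 7)
    # Position 3 predates the cache window, so it must assume a change.
    assert cache.has_entity_changed("@bob:example.com", 3)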
class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
    """This is an abstract base class where subclasses must implement
    `get_room_max_stream_ordering` and `get_room_min_stream_ordering`
    which can be called in the initializer.
    """

    __metaclass__ = abc.ABCMeta

    def __init__(self, database: DatabasePool, db_conn, hs):
        super(StreamWorkerStore, self).__init__(database, db_conn, hs)

        self._instance_name = hs.get_instance_name()
        self._send_federation = hs.should_send_federation()
        self._federation_shard_config = hs.config.worker.federation_shard_config

        # If we're a process that sends federation we may need to reset the
        # `federation_stream_position` table to match the current sharding
        # config. We don't do this now as otherwise two processes could conflict
        # during startup which would cause one to die.
        self._need_to_reset_federation_stream_positions = self._send_federation

        events_max = self.get_room_max_stream_ordering()
        event_cache_prefill, min_event_val = self.db_pool.get_cache_dict(
            db_conn,
            "events",
            entity_column="room_id",
            stream_column="stream_ordering",
            max_value=events_max,
        )
        self._events_stream_cache = StreamChangeCache(
            "EventsRoomStreamChangeCache",
            min_event_val,
            prefilled_cache=event_cache_prefill,
        )
        self._membership_stream_cache = StreamChangeCache(
            "MembershipStreamChangeCache", events_max
        )

        self._stream_order_on_start = self.get_room_max_stream_ordering()

    @abc.abstractmethod
    def get_room_max_stream_ordering(self):
        raise NotImplementedError()

    @abc.abstractmethod
    def get_room_min_stream_ordering(self):
        raise NotImplementedError()

    async def get_room_events_stream_for_rooms(
        self,
        room_ids: Iterable[str],
        from_key: str,
        to_key: str,
        limit: int = 0,
        order: str = "DESC",
    ) -> Dict[str, Tuple[List[EventBase], str]]:
        """Get new room events in stream ordering since `from_key`.

        Args:
            room_ids
            from_key: Token from which no events are returned before
            to_key: Token from which no events are returned after. (This
                is typically the current stream token)
            limit: Maximum number of events to return
            order: Either "DESC" or "ASC". Determines which events are
                returned when the result is limited. If "DESC" then the most
                recent `limit` events are returned, otherwise returns the
                oldest `limit` events.

        Returns:
            A map from room id to a tuple containing:
                - list of recent events in the room
                - stream ordering key for the start of the chunk of events
                  returned.
        """
        from_id = RoomStreamToken.parse_stream_token(from_key).stream

        room_ids = self._events_stream_cache.get_entities_changed(room_ids, from_id)

        if not room_ids:
            return {}

        results = {}
        room_ids = list(room_ids)
        for rm_ids in (room_ids[i : i + 20] for i in range(0, len(room_ids), 20)):
            res = await make_deferred_yieldable(
                defer.gatherResults(
                    [
                        run_in_background(
                            self.get_room_events_stream_for_room,
                            room_id,
                            from_key,
                            to_key,
                            limit,
                            order=order,
                        )
                        for room_id in rm_ids
                    ],
                    consumeErrors=True,
                )
            )
            results.update(dict(zip(rm_ids, res)))

        return results

    def get_rooms_that_changed(self, room_ids, from_key):
        """Given a list of rooms and a token, return rooms where there may have
        been changes.

        Args:
            room_ids (list)
            from_key (str): The room_key portion of a StreamToken
        """
        from_key = RoomStreamToken.parse_stream_token(from_key).stream
        return {
            room_id
            for room_id in room_ids
            if self._events_stream_cache.has_entity_changed(room_id, from_key)
        }

    async def get_room_events_stream_for_room(
        self,
        room_id: str,
        from_key: str,
        to_key: str,
        limit: int = 0,
        order: str = "DESC",
    ) -> Tuple[List[EventBase], str]:
        """Get new room events in stream ordering since `from_key`.

        Args:
            room_id
            from_key: Token from which no events are returned before
            to_key: Token from which no events are returned after. (This
                is typically the current stream token)
            limit: Maximum number of events to return
            order: Either "DESC" or "ASC". Determines which events are
                returned when the result is limited. If "DESC" then the most
                recent `limit` events are returned, otherwise returns the
                oldest `limit` events.

        Returns:
            The list of events (in ascending order) and the token from the
            start of the chunk of events returned.
        """
        if from_key == to_key:
            return [], from_key

        from_id = RoomStreamToken.parse_stream_token(from_key).stream
        to_id = RoomStreamToken.parse_stream_token(to_key).stream

        has_changed = self._events_stream_cache.has_entity_changed(room_id, from_id)

        if not has_changed:
            return [], from_key

        def f(txn):
            sql = (
                "SELECT event_id, stream_ordering FROM events WHERE"
                " room_id = ?"
                " AND not outlier"
                " AND stream_ordering > ? AND stream_ordering <= ?"
                " ORDER BY stream_ordering %s LIMIT ?"
            ) % (order,)
            txn.execute(sql, (room_id, from_id, to_id, limit))

            rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
            return rows

        rows = await self.db_pool.runInteraction("get_room_events_stream_for_room", f)

        ret = await self.get_events_as_list(
            [r.event_id for r in rows], get_prev_content=True
        )

        self._set_before_and_after(ret, rows, topo_order=from_id is None)

        if order.lower() == "desc":
            ret.reverse()

        if rows:
            key = "s%d" % min(r.stream_ordering for r in rows)
        else:
            # Assume we didn't get anything because there was nothing to
            # get.
            key = from_key

        return ret, key

    async def get_membership_changes_for_user(self, user_id, from_key, to_key):
        from_id = RoomStreamToken.parse_stream_token(from_key).stream
        to_id = RoomStreamToken.parse_stream_token(to_key).stream

        if from_key == to_key:
            return []

        if from_id:
            has_changed = self._membership_stream_cache.has_entity_changed(
                user_id, int(from_id)
            )
            if not has_changed:
                return []

        def f(txn):
            sql = (
                "SELECT m.event_id, stream_ordering FROM events AS e,"
                " room_memberships AS m"
                " WHERE e.event_id = m.event_id"
                " AND m.user_id = ?"
                " AND e.stream_ordering > ? AND e.stream_ordering <= ?"
                " ORDER BY e.stream_ordering ASC"
            )
            txn.execute(sql, (user_id, from_id, to_id))

            rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]

            return rows

        rows = await self.db_pool.runInteraction("get_membership_changes_for_user", f)

        ret = await self.get_events_as_list(
            [r.event_id for r in rows], get_prev_content=True
        )

        self._set_before_and_after(ret, rows, topo_order=False)

        return ret

    async def get_recent_events_for_room(
        self, room_id: str, limit: int, end_token: str
    ) -> Tuple[List[EventBase], str]:
        """Get the most recent events in the room in topological ordering.

        Args:
            room_id
            limit
            end_token: The stream token representing now.

        Returns:
            A list of events and a token pointing to the start of the returned
            events. The events returned are in ascending order.
        """
        rows, token = await self.get_recent_event_ids_for_room(
            room_id, limit, end_token
        )

        events = await self.get_events_as_list(
            [r.event_id for r in rows], get_prev_content=True
        )

        self._set_before_and_after(events, rows)

        return (events, token)

    async def get_recent_event_ids_for_room(
        self, room_id: str, limit: int, end_token: str
    ) -> Tuple[List[_EventDictReturn], str]:
        """Get the most recent events in the room in topological ordering.

        Args:
            room_id
            limit
            end_token: The stream token representing now.

        Returns:
            A list of _EventDictReturn and a token pointing to the start of the
            returned events.
            The events returned are in ascending order.
        """
        # Allow a zero limit here, and no-op.
        if limit == 0:
            return [], end_token

        end_token = RoomStreamToken.parse(end_token)

        rows, token = await self.db_pool.runInteraction(
            "get_recent_event_ids_for_room",
            self._paginate_room_events_txn,
            room_id,
            from_token=end_token,
            limit=limit,
        )

        # We want to return the results in ascending order.
        rows.reverse()

        return rows, token

    def get_room_event_before_stream_ordering(self, room_id: str, stream_ordering: int):
        """Gets details of the first event in a room at or before a stream ordering

        Args:
            room_id:
            stream_ordering:

        Returns:
            Deferred[(int, int, str)]:
                (stream ordering, topological ordering, event_id)
        """

        def _f(txn):
            sql = (
                "SELECT stream_ordering, topological_ordering, event_id"
                " FROM events"
                " WHERE room_id = ? AND stream_ordering <= ?"
                " AND NOT outlier"
                " ORDER BY stream_ordering DESC"
                " LIMIT 1"
            )
            txn.execute(sql, (room_id, stream_ordering))
            return txn.fetchone()

        return self.db_pool.runInteraction("get_room_event_before_stream_ordering", _f)

    async def get_room_events_max_id(self, room_id: Optional[str] = None) -> str:
        """Returns the current token for rooms stream.

        By default, it returns the current global stream token. Specifying a
        `room_id` causes it to return the current room specific topological
        token.
        """
        token = self.get_room_max_stream_ordering()
        if room_id is None:
            return "s%d" % (token,)
        else:
            topo = await self.db_pool.runInteraction(
                "_get_max_topological_txn", self._get_max_topological_txn, room_id
            )
            return "t%d-%d" % (topo, token)

    async def get_stream_token_for_event(self, event_id: str) -> str:
        """The stream token for an event

        Args:
            event_id: The id of the event to look up a stream token for.

        Raises:
            StoreError if the event wasn't in the database.

        Returns:
            A "s%d" stream token.
        """
        row = await self.db_pool.simple_select_one_onecol(
            table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering"
        )
        return "s%d" % (row,)

    async def get_topological_token_for_event(self, event_id: str) -> str:
        """The stream token for an event

        Args:
            event_id: The id of the event to look up a stream token for.

        Raises:
            StoreError if the event wasn't in the database.

        Returns:
            A "t%d-%d" topological token.
        """
        row = await self.db_pool.simple_select_one(
            table="events",
            keyvalues={"event_id": event_id},
            retcols=("stream_ordering", "topological_ordering"),
            desc="get_topological_token_for_event",
        )
        return "t%d-%d" % (row["topological_ordering"], row["stream_ordering"])

    async def get_max_topological_token(self, room_id: str, stream_key: int) -> int:
        """Get the max topological token in a room before the given stream ordering.

        Args:
            room_id
            stream_key

        Returns:
            The maximum topological token.
        """
        sql = (
            "SELECT coalesce(max(topological_ordering), 0) FROM events"
            " WHERE room_id = ? AND stream_ordering < ?"
        )
        row = await self.db_pool.execute(
            "get_max_topological_token", None, sql, room_id, stream_key
        )
        return row[0][0] if row else 0

    def _get_max_topological_txn(self, txn, room_id):
        txn.execute(
            "SELECT MAX(topological_ordering) FROM events WHERE room_id = ?",
            (room_id,),
        )

        rows = txn.fetchall()
        return rows[0][0] if rows else 0

    @staticmethod
    def _set_before_and_after(
        events: List[EventBase], rows: List[_EventDictReturn], topo_order: bool = True
    ):
        """Inserts ordering information into events' internal metadata from
        the DB rows.

        Args:
            events
            rows
            topo_order: Whether the events were ordered topologically or by
                stream ordering. If true then all rows should have a non null
                topological_ordering.
""" for event, row in zip(events, rows): stream = row.stream_ordering if topo_order and row.topological_ordering: topo = row.topological_ordering else: topo = None internal = event.internal_metadata internal.before = str(RoomStreamToken(topo, stream - 1)) internal.after = str(RoomStreamToken(topo, stream)) internal.order = (int(topo) if topo else 0, int(stream)) async def get_events_around( self, room_id: str, event_id: str, before_limit: int, after_limit: int, event_filter: Optional[Filter] = None, ) -> dict: """Retrieve events and pagination tokens around a given event in a room. """ results = await self.db_pool.runInteraction( "get_events_around", self._get_events_around_txn, room_id, event_id, before_limit, after_limit, event_filter, ) events_before = await self.get_events_as_list( list(results["before"]["event_ids"]), get_prev_content=True ) events_after = await self.get_events_as_list( list(results["after"]["event_ids"]), get_prev_content=True ) return { "events_before": events_before, "events_after": events_after, "start": results["before"]["token"], "end": results["after"]["token"], } def _get_events_around_txn( self, txn, room_id: str, event_id: str, before_limit: int, after_limit: int, event_filter: Optional[Filter], ) -> dict: """Retrieves event_ids and pagination tokens around a given event in a room. Args: room_id event_id before_limit after_limit event_filter Returns: dict """ results = self.db_pool.simple_select_one_txn( txn, "events", keyvalues={"event_id": event_id, "room_id": room_id}, retcols=["stream_ordering", "topological_ordering"], ) # Paginating backwards includes the event at the token, but paginating # forward doesn't. before_token = RoomStreamToken( results["topological_ordering"] - 1, results["stream_ordering"] ) after_token = RoomStreamToken( results["topological_ordering"], results["stream_ordering"] ) rows, start_token = self._paginate_room_events_txn( txn, room_id, before_token, direction="b", limit=before_limit, event_filter=event_filter, ) events_before = [r.event_id for r in rows] rows, end_token = self._paginate_room_events_txn( txn, room_id, after_token, direction="f", limit=after_limit, event_filter=event_filter, ) events_after = [r.event_id for r in rows] return { "before": {"event_ids": events_before, "token": start_token}, "after": {"event_ids": events_after, "token": end_token}, } async def get_all_new_events_stream( self, from_id: int, current_id: int, limit: int ) -> Tuple[int, List[EventBase]]: """Get all new events Returns all events with from_id < stream_ordering <= current_id. Args: from_id: the stream_ordering of the last event we processed current_id: the stream_ordering of the most recently processed event limit: the maximum number of events to return Returns: A tuple of (next_id, events), where `next_id` is the next value to pass as `from_id` (it will either be the stream_ordering of the last returned event, or, if fewer than `limit` events were found, the `current_id`). """ def get_all_new_events_stream_txn(txn): sql = ( "SELECT e.stream_ordering, e.event_id" " FROM events AS e" " WHERE" " ? < e.stream_ordering AND e.stream_ordering <= ?" " ORDER BY e.stream_ordering ASC" " LIMIT ?" 
            )
            txn.execute(sql, (from_id, current_id, limit))
            rows = txn.fetchall()

            upper_bound = current_id
            if len(rows) == limit:
                upper_bound = rows[-1][0]

            return upper_bound, [row[1] for row in rows]

        upper_bound, event_ids = await self.db_pool.runInteraction(
            "get_all_new_events_stream", get_all_new_events_stream_txn
        )

        events = await self.get_events_as_list(event_ids)

        return upper_bound, events

    async def get_federation_out_pos(self, typ: str) -> int:
        if self._need_to_reset_federation_stream_positions:
            await self.db_pool.runInteraction(
                "_reset_federation_positions_txn", self._reset_federation_positions_txn
            )
            self._need_to_reset_federation_stream_positions = False

        return await self.db_pool.simple_select_one_onecol(
            table="federation_stream_position",
            retcol="stream_id",
            keyvalues={"type": typ, "instance_name": self._instance_name},
            desc="get_federation_out_pos",
        )

    async def update_federation_out_pos(self, typ: str, stream_id: int) -> None:
        if self._need_to_reset_federation_stream_positions:
            await self.db_pool.runInteraction(
                "_reset_federation_positions_txn", self._reset_federation_positions_txn
            )
            self._need_to_reset_federation_stream_positions = False

        await self.db_pool.simple_update_one(
            table="federation_stream_position",
            keyvalues={"type": typ, "instance_name": self._instance_name},
            updatevalues={"stream_id": stream_id},
            desc="update_federation_out_pos",
        )

    def _reset_federation_positions_txn(self, txn) -> None:
        """Fiddles with the `federation_stream_position` table to make it match
        the configured federation sender instances during start up.
        """

        # The federation sender instances may have changed, so we need to
        # massage the `federation_stream_position` table to have a row per type
        # per instance sending federation. If there is a mismatch we update the
        # table with the correct rows using the *minimum* stream ID seen. This
        # may result in resending of events/EDUs to remote servers, but that is
        # preferable to dropping them.

        if not self._send_federation:
            return

        # Pull out the configured instances. If we don't have a shard config then
        # we assume that we're the only instance sending.
        configured_instances = self._federation_shard_config.instances
        if not configured_instances:
            configured_instances = [self._instance_name]
        elif self._instance_name not in configured_instances:
            return

        instances_in_table = self.db_pool.simple_select_onecol_txn(
            txn,
            table="federation_stream_position",
            keyvalues={},
            retcol="instance_name",
        )
        if set(instances_in_table) == set(configured_instances):
            # Nothing to do
            return

        sql = """
            SELECT type, MIN(stream_id) FROM federation_stream_position
            GROUP BY type
        """
        txn.execute(sql)
        min_positions = dict(txn)  # Map from type -> min position

        # Ensure we do actually have some values here
        assert set(min_positions) == {"federation", "events"}

        sql = """
            DELETE FROM federation_stream_position
            WHERE NOT (%s)
        """
        clause, args = make_in_list_sql_clause(
            txn.database_engine, "instance_name", configured_instances
        )
        txn.execute(sql % (clause,), args)

        for typ, stream_id in min_positions.items():
            self.db_pool.simple_upsert_txn(
                txn,
                table="federation_stream_position",
                keyvalues={"type": typ, "instance_name": self._instance_name},
                values={"stream_id": stream_id},
            )

    def has_room_changed_since(self, room_id: str, stream_id: int) -> bool:
        return self._events_stream_cache.has_entity_changed(room_id, stream_id)

    def _paginate_room_events_txn(
        self,
        txn,
        room_id: str,
        from_token: RoomStreamToken,
        to_token: Optional[RoomStreamToken] = None,
        direction: str = "b",
        limit: int = -1,
        event_filter: Optional[Filter] = None,
    ) -> Tuple[List[_EventDictReturn], str]:
        """Returns list of events before or after a given token.

        Args:
            txn
            room_id
            from_token: The token used to stream from
            to_token: A token which if given limits the results to only those
                before
            direction: Either 'b' or 'f' to indicate whether we are paginating
                forwards or backwards from `from_key`.
            limit: The maximum number of events to return.
            event_filter: If provided filters the events to those that match
                the filter.

        Returns:
            A list of _EventDictReturn and a token that points to the end of
            the result set. If no events are returned then the end of the
            stream has been reached (i.e. there are no events between
            `from_token` and `to_token`), or `limit` is zero.
        """

        assert int(limit) >= 0

        # Tokens really represent positions between elements, but we use
        # the convention of pointing to the event before the gap. Hence
        # we have a bit of asymmetry when it comes to equalities.
        args = [False, room_id]
        if direction == "b":
            order = "DESC"
        else:
            order = "ASC"

        bounds = generate_pagination_where_clause(
            direction=direction,
            column_names=("topological_ordering", "stream_ordering"),
            from_token=from_token,
            to_token=to_token,
            engine=self.database_engine,
        )

        filter_clause, filter_args = filter_to_clause(event_filter)

        if filter_clause:
            bounds += " AND " + filter_clause
            args.extend(filter_args)

        args.append(int(limit))

        select_keywords = "SELECT"
        join_clause = ""
        if event_filter and event_filter.labels:
            # If we're not filtering on a label, then joining on event_labels will
            # return as many rows for a single event as the number of labels it
            # has. To avoid this, only join if we're filtering on at least one
            # label.
            join_clause = """
                LEFT JOIN event_labels
                USING (event_id, room_id, topological_ordering)
            """
            if len(event_filter.labels) > 1:
                # Using DISTINCT in this SELECT query is quite expensive, because it
                # requires the engine to sort on the entire (not limited) result set,
                # i.e. the entire events table.
                # We only need to use it when we're filtering on more than one
                # label, because that's the only scenario in which the same
                # event ID can appear multiple times in the results.
                select_keywords += " DISTINCT"

        sql = """
            %(select_keywords)s event_id, topological_ordering, stream_ordering
            FROM events
            %(join_clause)s
            WHERE outlier = ? AND room_id = ? AND %(bounds)s
            ORDER BY topological_ordering %(order)s,
            stream_ordering %(order)s LIMIT ?
        """ % {
            "select_keywords": select_keywords,
            "join_clause": join_clause,
            "bounds": bounds,
            "order": order,
        }

        txn.execute(sql, args)

        rows = [_EventDictReturn(row[0], row[1], row[2]) for row in txn]

        if rows:
            topo = rows[-1].topological_ordering
            toke = rows[-1].stream_ordering
            if direction == "b":
                # Tokens are positions between events.
                # This token points *after* the last event in the chunk.
                # We need it to point to the event before it in the chunk
                # when we are going backwards so we subtract one from the
                # stream part.
                toke -= 1
            next_token = RoomStreamToken(topo, toke)
        else:
            # TODO (erikj): We should work out what to do here instead.
            next_token = to_token if to_token else from_token

        return rows, str(next_token)

    async def paginate_room_events(
        self,
        room_id: str,
        from_key: str,
        to_key: Optional[str] = None,
        direction: str = "b",
        limit: int = -1,
        event_filter: Optional[Filter] = None,
    ) -> Tuple[List[EventBase], str]:
        """Returns list of events before or after a given token.

        Args:
            room_id
            from_key: The token used to stream from
            to_key: A token which if given limits the results to only those
                before
            direction: Either 'b' or 'f' to indicate whether we are paginating
                forwards or backwards from `from_key`.
            limit: The maximum number of events to return.
            event_filter: If provided filters the events to those that match
                the filter.

        Returns:
            The results as a list of events and a token that points to the end
            of the result set. If no events are returned then the end of the
            stream has been reached (i.e. there are no events between `from_key`
            and `to_key`).
        """

        from_key = RoomStreamToken.parse(from_key)
        if to_key:
            to_key = RoomStreamToken.parse(to_key)

        rows, token = await self.db_pool.runInteraction(
            "paginate_room_events",
            self._paginate_room_events_txn,
            room_id,
            from_key,
            to_key,
            direction,
            limit,
            event_filter,
        )

        events = await self.get_events_as_list(
            [r.event_id for r in rows], get_prev_content=True
        )

        self._set_before_and_after(events, rows)

        return (events, token)
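
# The "s%d" / "t%d-%d" strings produced above are the two textual stream token
# forms this store emits. A minimal standalone sketch of the round-trip, under
# the assumption (consistent with RoomStreamToken.parse above) that "s" tokens
# carry only a stream ordering and "t" tokens carry a topological ordering as
# well; `parse_token` is a hypothetical helper, not Synapse's API:

from typing import Optional, Tuple


def parse_token(token: str) -> Tuple[Optional[int], int]:
    """Return (topological_ordering, stream_ordering) for a room stream token."""
    if token.startswith("s"):
        return None, int(token[1:])
    if token.startswith("t"):
        topo, stream = token[1:].split("-", 1)
        return int(topo), int(stream)
    raise ValueError("Unrecognised stream token: %r" % (token,))


assert parse_token("s42") == (None, 42)
assert parse_token("t7-42") == (7, 42)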
class PushRulesWorkerStore(
    ApplicationServiceWorkerStore,
    ReceiptsWorkerStore,
    PusherWorkerStore,
    RoomMemberWorkerStore,
    SQLBaseStore,
):
    """This is an abstract base class where subclasses must implement
    `get_max_push_rules_stream_id` which can be called in the initializer.
    """

    # This ABCMeta metaclass ensures that we cannot be instantiated without
    # the abstract methods being implemented.
    __metaclass__ = abc.ABCMeta

    def __init__(self, db_conn, hs):
        super(PushRulesWorkerStore, self).__init__(db_conn, hs)

        push_rules_prefill, push_rules_id = self._get_cache_dict(
            db_conn,
            "push_rules_stream",
            entity_column="user_id",
            stream_column="stream_id",
            max_value=self.get_max_push_rules_stream_id(),
        )

        self.push_rules_stream_cache = StreamChangeCache(
            "PushRulesStreamChangeCache",
            push_rules_id,
            prefilled_cache=push_rules_prefill,
        )

    @abc.abstractmethod
    def get_max_push_rules_stream_id(self):
        """Get the position of the push rules stream.

        Returns:
            int
        """
        raise NotImplementedError()

    @cachedInlineCallbacks(max_entries=5000)
    def get_push_rules_for_user(self, user_id):
        rows = yield self._simple_select_list(
            table="push_rules",
            keyvalues={"user_name": user_id},
            retcols=(
                "user_name",
                "rule_id",
                "priority_class",
                "priority",
                "conditions",
                "actions",
            ),
            desc="get_push_rules_for_user",
        )

        rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))

        enabled_map = yield self.get_push_rules_enabled_for_user(user_id)

        rules = _load_rules(rows, enabled_map)

        defer.returnValue(rules)

    @cachedInlineCallbacks(max_entries=5000)
    def get_push_rules_enabled_for_user(self, user_id):
        results = yield self._simple_select_list(
            table="push_rules_enable",
            keyvalues={"user_name": user_id},
            retcols=("user_name", "rule_id", "enabled"),
            desc="get_push_rules_enabled_for_user",
        )
        defer.returnValue(
            {r["rule_id"]: False if r["enabled"] == 0 else True for r in results}
        )

    def have_push_rules_changed_for_user(self, user_id, last_id):
        if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
            return defer.succeed(False)
        else:

            def have_push_rules_changed_txn(txn):
                sql = (
                    "SELECT COUNT(stream_id) FROM push_rules_stream"
                    " WHERE user_id = ? AND ? < stream_id"
                )
                txn.execute(sql, (user_id, last_id))
                (count,) = txn.fetchone()
                return bool(count)

            return self.runInteraction(
                "have_push_rules_changed", have_push_rules_changed_txn
            )

    @cachedList(
        cached_method_name="get_push_rules_for_user",
        list_name="user_ids",
        num_args=1,
        inlineCallbacks=True,
    )
    def bulk_get_push_rules(self, user_ids):
        if not user_ids:
            defer.returnValue({})

        results = {user_id: [] for user_id in user_ids}

        rows = yield self._simple_select_many_batch(
            table="push_rules",
            column="user_name",
            iterable=user_ids,
            retcols=("*",),
            desc="bulk_get_push_rules",
        )

        rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))

        for row in rows:
            results.setdefault(row["user_name"], []).append(row)

        enabled_map_by_user = yield self.bulk_get_push_rules_enabled(user_ids)

        for user_id, rules in results.items():
            results[user_id] = _load_rules(rules, enabled_map_by_user.get(user_id, {}))

        defer.returnValue(results)

    @defer.inlineCallbacks
    def move_push_rule_from_room_to_room(self, new_room_id, user_id, rule):
        """Move a single push rule from one room to another for a specific user.

        Args:
            new_room_id (str): ID of the new room.
            user_id (str): ID of user the push rule belongs to.
            rule (Dict): A push rule.
""" # Create new rule id rule_id_scope = '/'.join(rule["rule_id"].split('/')[:-1]) new_rule_id = rule_id_scope + "/" + new_room_id # Change room id in each condition for condition in rule.get("conditions", []): if condition.get("key") == "room_id": condition["pattern"] = new_room_id # Add the rule for the new room yield self.add_push_rule( user_id=user_id, rule_id=new_rule_id, priority_class=rule["priority_class"], conditions=rule["conditions"], actions=rule["actions"], ) # Delete push rule for the old room yield self.delete_push_rule(user_id, rule["rule_id"]) @defer.inlineCallbacks def move_push_rules_from_room_to_room_for_user( self, old_room_id, new_room_id, user_id ): """Move all of the push rules from one room to another for a specific user. Args: old_room_id (str): ID of the old room. new_room_id (str): ID of the new room. user_id (str): ID of user to copy push rules for. """ # Retrieve push rules for this user user_push_rules = yield self.get_push_rules_for_user(user_id) # Get rules relating to the old room, move them to the new room, then # delete them from the old room for rule in user_push_rules: conditions = rule.get("conditions", []) if any( (c.get("key") == "room_id" and c.get("pattern") == old_room_id) for c in conditions ): self.move_push_rule_from_room_to_room(new_room_id, user_id, rule) @defer.inlineCallbacks def bulk_get_push_rules_for_room(self, event, context): state_group = context.state_group if not state_group: # If state_group is None it means it has yet to be assigned a # state group, i.e. we need to make sure that calls with a state_group # of None don't hit previous cached calls with a None state_group. # To do this we set the state_group to a new object as object() != object() state_group = object() current_state_ids = yield context.get_current_state_ids(self) result = yield self._bulk_get_push_rules_for_room( event.room_id, state_group, current_state_ids, event=event ) defer.returnValue(result) @cachedInlineCallbacks(num_args=2, cache_context=True) def _bulk_get_push_rules_for_room( self, room_id, state_group, current_state_ids, cache_context, event=None ): # We don't use `state_group`, its there so that we can cache based # on it. However, its important that its never None, since two current_state's # with a state_group of None are likely to be different. # See bulk_get_push_rules_for_room for how we work around this. assert state_group is not None # We also will want to generate notifs for other people in the room so # their unread countss are correct in the event stream, but to avoid # generating them for bot / AS users etc, we only do so for people who've # sent a read receipt into the room. users_in_room = yield self._get_joined_users_from_context( room_id, state_group, current_state_ids, on_invalidate=cache_context.invalidate, event=event, ) # We ignore app service users for now. This is so that we don't fill # up the `get_if_users_have_pushers` cache with AS entries that we # know don't have pushers, nor even read receipts. 
        local_users_in_room = set(
            u
            for u in users_in_room
            if self.hs.is_mine_id(u)
            and not self.get_if_app_services_interested_in_user(u)
        )

        # users in the room who have pushers need to get push rules run because
        # that's how their pushers work
        if_users_with_pushers = yield self.get_if_users_have_pushers(
            local_users_in_room, on_invalidate=cache_context.invalidate
        )
        user_ids = set(
            uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
        )

        users_with_receipts = yield self.get_users_with_read_receipts_in_room(
            room_id, on_invalidate=cache_context.invalidate
        )

        # any users with pushers must be ours: they have pushers
        for uid in users_with_receipts:
            if uid in local_users_in_room:
                user_ids.add(uid)

        rules_by_user = yield self.bulk_get_push_rules(
            user_ids, on_invalidate=cache_context.invalidate
        )

        rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None}

        defer.returnValue(rules_by_user)

    @cachedList(
        cached_method_name="get_push_rules_enabled_for_user",
        list_name="user_ids",
        num_args=1,
        inlineCallbacks=True,
    )
    def bulk_get_push_rules_enabled(self, user_ids):
        if not user_ids:
            defer.returnValue({})

        results = {user_id: {} for user_id in user_ids}

        rows = yield self._simple_select_many_batch(
            table="push_rules_enable",
            column="user_name",
            iterable=user_ids,
            retcols=("user_name", "rule_id", "enabled"),
            desc="bulk_get_push_rules_enabled",
        )
        for row in rows:
            enabled = bool(row["enabled"])
            results.setdefault(row["user_name"], {})[row["rule_id"]] = enabled
        defer.returnValue(results)
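
# The `state_group = object()` trick in bulk_get_push_rules_for_room above
# works because every bare `object()` has a unique identity, so two calls that
# both lack a state group can never collide in the cache. A standalone
# illustration of the pattern with a hypothetical dict-backed cache (not
# Synapse's @cached machinery):

from typing import Any, Dict, Tuple

_demo_cache: Dict[Tuple[str, Any], str] = {}


def demo_cached_lookup(room_id: str, state_group: Any) -> str:
    if state_group is None:
        # Fresh sentinel: never equal to any other key, so callers passing
        # None never see each other's cached results.
        state_group = object()
    key = (room_id, state_group)
    if key not in _demo_cache:
        _demo_cache[key] = "rules-for-%s" % (room_id,)
    return _demo_cache[key]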
class DeviceInboxWorkerStore(SQLBaseStore):
    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        super().__init__(database, db_conn, hs)

        self._instance_name = hs.get_instance_name()

        # Map of (user_id, device_id) to the last stream_id that has been
        # deleted up to. This is so that we can no-op deletions.
        self._last_device_delete_cache: ExpiringCache[
            Tuple[str, Optional[str]], int
        ] = ExpiringCache(
            cache_name="last_device_delete_cache",
            clock=self._clock,
            max_len=10000,
            expiry_ms=30 * 60 * 1000,
        )

        if isinstance(database.engine, PostgresEngine):
            self._can_write_to_device = (
                self._instance_name in hs.config.worker.writers.to_device
            )

            self._device_inbox_id_gen: AbstractStreamIdGenerator = (
                MultiWriterIdGenerator(
                    db_conn=db_conn,
                    db=database,
                    stream_name="to_device",
                    instance_name=self._instance_name,
                    tables=[("device_inbox", "instance_name", "stream_id")],
                    sequence_name="device_inbox_sequence",
                    writers=hs.config.worker.writers.to_device,
                )
            )
        else:
            self._can_write_to_device = True
            self._device_inbox_id_gen = StreamIdGenerator(
                db_conn, "device_inbox", "stream_id"
            )

        max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
        device_inbox_prefill, min_device_inbox_id = self.db_pool.get_cache_dict(
            db_conn,
            "device_inbox",
            entity_column="user_id",
            stream_column="stream_id",
            max_value=max_device_inbox_id,
            limit=1000,
        )
        self._device_inbox_stream_cache = StreamChangeCache(
            "DeviceInboxStreamChangeCache",
            min_device_inbox_id,
            prefilled_cache=device_inbox_prefill,
        )

        # The federation outbox and the local device inbox use the same
        # stream_id generator.
        device_outbox_prefill, min_device_outbox_id = self.db_pool.get_cache_dict(
            db_conn,
            "device_federation_outbox",
            entity_column="destination",
            stream_column="stream_id",
            max_value=max_device_inbox_id,
            limit=1000,
        )
        self._device_federation_outbox_stream_cache = StreamChangeCache(
            "DeviceFederationOutboxStreamChangeCache",
            min_device_outbox_id,
            prefilled_cache=device_outbox_prefill,
        )

    def process_replication_rows(
        self,
        stream_name: str,
        instance_name: str,
        token: int,
        rows: Iterable[ToDeviceStream.ToDeviceStreamRow],
    ) -> None:
        if stream_name == ToDeviceStream.NAME:
            # If replication is happening then postgres must be being used.
            assert isinstance(self._device_inbox_id_gen, MultiWriterIdGenerator)
            self._device_inbox_id_gen.advance(instance_name, token)
            for row in rows:
                if row.entity.startswith("@"):
                    self._device_inbox_stream_cache.entity_has_changed(
                        row.entity, token
                    )
                else:
                    self._device_federation_outbox_stream_cache.entity_has_changed(
                        row.entity, token
                    )
        return super().process_replication_rows(stream_name, instance_name, token, rows)

    def get_to_device_stream_token(self) -> int:
        return self._device_inbox_id_gen.get_current_token()

    async def get_messages_for_user_devices(
        self,
        user_ids: Collection[str],
        from_stream_id: int,
        to_stream_id: int,
    ) -> Dict[Tuple[str, str], List[JsonDict]]:
        """
        Retrieve to-device messages for a given set of users.

        Only to-device messages with stream ids between the given boundaries
        (from < X <= to) are returned.

        Args:
            user_ids: The users to retrieve to-device messages for.
            from_stream_id: The lower boundary of stream id to filter with (exclusive).
            to_stream_id: The upper boundary of stream id to filter with (inclusive).

        Returns:
            A dictionary of (user id, device id) -> list of to-device messages.
        """
        # We expect the stream ID returned by _get_device_messages to always
        # be to_stream_id. So, no need to return it from this function.
        (
            user_id_device_id_to_messages,
            last_processed_stream_id,
        ) = await self._get_device_messages(
            user_ids=user_ids,
            from_stream_id=from_stream_id,
            to_stream_id=to_stream_id,
        )

        assert (
            last_processed_stream_id == to_stream_id
        ), "Expected _get_device_messages to process all to-device messages up to `to_stream_id`"

        return user_id_device_id_to_messages

    async def get_messages_for_device(
        self,
        user_id: str,
        device_id: str,
        from_stream_id: int,
        to_stream_id: int,
        limit: int = 100,
    ) -> Tuple[List[JsonDict], int]:
        """
        Retrieve to-device messages for a single user device.

        Only to-device messages with stream ids between the given boundaries
        (from < X <= to) are returned.

        Args:
            user_id: The ID of the user to retrieve messages for.
            device_id: The ID of the device to retrieve to-device messages for.
            from_stream_id: The lower boundary of stream id to filter with (exclusive).
            to_stream_id: The upper boundary of stream id to filter with (inclusive).
            limit: A limit on the number of to-device messages returned.

        Returns:
            A tuple containing:
                * A list of to-device messages within the given stream id range
                  intended for the given user / device combo.
                * The last-processed stream ID. Subsequent calls of this function
                  with the same device should pass this value as 'from_stream_id'.
        """
        (
            user_id_device_id_to_messages,
            last_processed_stream_id,
        ) = await self._get_device_messages(
            user_ids=[user_id],
            device_id=device_id,
            from_stream_id=from_stream_id,
            to_stream_id=to_stream_id,
            limit=limit,
        )

        if not user_id_device_id_to_messages:
            # There were no messages!
            return [], to_stream_id

        # Extract the messages, no need to return the user and device ID again
        to_device_messages = user_id_device_id_to_messages.get((user_id, device_id), [])

        return to_device_messages, last_processed_stream_id

    async def _get_device_messages(
        self,
        user_ids: Collection[str],
        from_stream_id: int,
        to_stream_id: int,
        device_id: Optional[str] = None,
        limit: Optional[int] = None,
    ) -> Tuple[Dict[Tuple[str, str], List[JsonDict]], int]:
        """
        Retrieve pending to-device messages for a collection of user devices.

        Only to-device messages with stream ids between the given boundaries
        (from < X <= to) are returned.

        Note that a stream ID can be shared by multiple copies of the same message
        with different recipient devices. Stream IDs are only unique in the context
        of a single user ID / device ID pair. Thus, applying a limit (of messages
        to return) when working with a sliding window of stream IDs is only
        possible when querying messages of a single user device.

        Finally, note that device IDs are not unique across users.

        Args:
            user_ids: The user IDs to filter device messages by.
            from_stream_id: The lower boundary of stream id to filter with (exclusive).
            to_stream_id: The upper boundary of stream id to filter with (inclusive).
            device_id: A device ID to query to-device messages for. If not provided,
                to-device messages from all device IDs for the given user IDs will
                be queried. May not be provided if `user_ids` contains more than one
                entry.
            limit: The maximum number of to-device messages to return. Can only be
                used when passing a single user ID / device ID tuple.

        Returns:
            A tuple containing:
                * A dict of (user_id, device_id) -> list of to-device messages
                * The last-processed stream ID. If this is less than `to_stream_id`,
                  then there may be more messages to retrieve. If `limit` is not set,
                  then this is always equal to 'to_stream_id'.
""" if not user_ids: logger.warning("No users provided upon querying for device IDs") return {}, to_stream_id # Prevent a query for one user's device also retrieving another user's device with # the same device ID (device IDs are not unique across users). if len(user_ids) > 1 and device_id is not None: raise AssertionError( "Programming error: 'device_id' cannot be supplied to " "_get_device_messages when >1 user_id has been provided") # A limit can only be applied when querying for a single user ID / device ID tuple. # See the docstring of this function for more details. if limit is not None and device_id is None: raise AssertionError( "Programming error: _get_device_messages was passed 'limit' " "without a specific user_id/device_id") user_ids_to_query: Set[str] = set() device_ids_to_query: Set[str] = set() # Note that a device ID could be an empty str if device_id is not None: # If a device ID was passed, use it to filter results. # Otherwise, device IDs will be derived from the given collection of user IDs. device_ids_to_query.add(device_id) # Determine which users have devices with pending messages for user_id in user_ids: if self._device_inbox_stream_cache.has_entity_changed( user_id, from_stream_id): # This user has new messages sent to them. Query messages for them user_ids_to_query.add(user_id) if not user_ids_to_query: return {}, to_stream_id def get_device_messages_txn( txn: LoggingTransaction, ) -> Tuple[Dict[Tuple[str, str], List[JsonDict]], int]: # Build a query to select messages from any of the given devices that # are between the given stream id bounds. # If a list of device IDs was not provided, retrieve all devices IDs # for the given users. We explicitly do not query hidden devices, as # hidden devices should not receive to-device messages. # Note that this is more efficient than just dropping `device_id` from the query, # since device_inbox has an index on `(user_id, device_id, stream_id)` if not device_ids_to_query: user_device_dicts = self.db_pool.simple_select_many_txn( txn, table="devices", column="user_id", iterable=user_ids_to_query, keyvalues={ "user_id": user_id, "hidden": False }, retcols=("device_id", ), ) device_ids_to_query.update( {row["device_id"] for row in user_device_dicts}) if not device_ids_to_query: # We've ended up with no devices to query. return {}, to_stream_id # We include both user IDs and device IDs in this query, as we have an index # (device_inbox_user_stream_id) for them. user_id_many_clause_sql, user_id_many_clause_args = make_in_list_sql_clause( self.database_engine, "user_id", user_ids_to_query) ( device_id_many_clause_sql, device_id_many_clause_args, ) = make_in_list_sql_clause(self.database_engine, "device_id", device_ids_to_query) sql = f""" SELECT stream_id, user_id, device_id, message_json FROM device_inbox WHERE {user_id_many_clause_sql} AND {device_id_many_clause_sql} AND ? < stream_id AND stream_id <= ? ORDER BY stream_id ASC """ sql_args = ( *user_id_many_clause_args, *device_id_many_clause_args, from_stream_id, to_stream_id, ) # If a limit was provided, limit the data retrieved from the database if limit is not None: sql += "LIMIT ?" sql_args += (limit, ) txn.execute(sql, sql_args) # Create and fill a dictionary of (user ID, device ID) -> list of messages # intended for each device. 
            last_processed_stream_pos = to_stream_id
            recipient_device_to_messages: Dict[Tuple[str, str], List[JsonDict]] = {}
            rowcount = 0
            for row in txn:
                rowcount += 1

                last_processed_stream_pos = row[0]
                recipient_user_id = row[1]
                recipient_device_id = row[2]
                message_dict = db_to_json(row[3])

                # Store the device details
                recipient_device_to_messages.setdefault(
                    (recipient_user_id, recipient_device_id), []
                ).append(message_dict)

            if limit is not None and rowcount == limit:
                # We ended up bumping up against the message limit. There may be more
                # messages to retrieve. Return what we have, as well as the last
                # stream position that was processed.
                #
                # The caller is expected to set this as the lower (exclusive) bound
                # for the next query of this device.
                return recipient_device_to_messages, last_processed_stream_pos

            # The limit was not reached, thus we know that recipient_device_to_messages
            # contains all to-device messages for the given device and stream id range.
            #
            # We return to_stream_id, which the caller should then provide as the lower
            # (exclusive) bound on the next query of this device.
            return recipient_device_to_messages, to_stream_id

        return await self.db_pool.runInteraction(
            "get_device_messages", get_device_messages_txn
        )

    @trace
    async def delete_messages_for_device(
        self, user_id: str, device_id: Optional[str], up_to_stream_id: int
    ) -> int:
        """
        Args:
            user_id: The recipient user_id.
            device_id: The recipient device_id.
            up_to_stream_id: Where to delete messages up to.

        Returns:
            The number of messages deleted.
        """
        # If we have cached the last stream id we've deleted up to, we can
        # check if there is likely to be anything that needs deleting
        last_deleted_stream_id = self._last_device_delete_cache.get(
            (user_id, device_id), None
        )

        set_tag("last_deleted_stream_id", last_deleted_stream_id)

        if last_deleted_stream_id:
            has_changed = self._device_inbox_stream_cache.has_entity_changed(
                user_id, last_deleted_stream_id
            )
            if not has_changed:
                log_kv({"message": "No changes in cache since last check"})
                return 0

        def delete_messages_for_device_txn(txn: LoggingTransaction) -> int:
            sql = (
                "DELETE FROM device_inbox"
                " WHERE user_id = ? AND device_id = ?"
                " AND stream_id <= ?"
            )
            txn.execute(sql, (user_id, device_id, up_to_stream_id))
            return txn.rowcount

        count = await self.db_pool.runInteraction(
            "delete_messages_for_device", delete_messages_for_device_txn
        )

        log_kv({"message": f"deleted {count} messages for device", "count": count})

        # Update the cache, ensuring that we only ever increase the value
        updated_last_deleted_stream_id = self._last_device_delete_cache.get(
            (user_id, device_id), 0
        )
        self._last_device_delete_cache[(user_id, device_id)] = max(
            updated_last_deleted_stream_id, up_to_stream_id
        )

        return count

    @trace
    async def get_new_device_msgs_for_remote(
        self, destination: str, last_stream_id: int, current_stream_id: int, limit: int
    ) -> Tuple[List[JsonDict], int]:
        """
        Args:
            destination: The name of the remote server.
            last_stream_id: The last position of the device message stream
                that the server sent up to.
            current_stream_id: The current position of the device message stream.

        Returns:
            A list of messages for the device and where in the stream the messages
            got to.
""" set_tag("destination", destination) set_tag("last_stream_id", last_stream_id) set_tag("current_stream_id", current_stream_id) set_tag("limit", limit) has_changed = self._device_federation_outbox_stream_cache.has_entity_changed( destination, last_stream_id) if not has_changed or last_stream_id == current_stream_id: log_kv({"message": "No new messages in stream"}) return [], current_stream_id if limit <= 0: # This can happen if we run out of room for EDUs in the transaction. return [], last_stream_id @trace def get_new_messages_for_remote_destination_txn( txn: LoggingTransaction, ) -> Tuple[List[JsonDict], int]: sql = ( "SELECT stream_id, messages_json FROM device_federation_outbox" " WHERE destination = ?" " AND ? < stream_id AND stream_id <= ?" " ORDER BY stream_id ASC" " LIMIT ?") txn.execute( sql, (destination, last_stream_id, current_stream_id, limit)) messages = [] stream_pos = current_stream_id for row in txn: stream_pos = row[0] messages.append(db_to_json(row[1])) # If the limit was not reached we know that there's no more data for this # user/device pair up to current_stream_id. if len(messages) < limit: log_kv({"message": "Set stream position to current position"}) stream_pos = current_stream_id return messages, stream_pos return await self.db_pool.runInteraction( "get_new_device_msgs_for_remote", get_new_messages_for_remote_destination_txn, ) @trace async def delete_device_msgs_for_remote(self, destination: str, up_to_stream_id: int) -> None: """Used to delete messages when the remote destination acknowledges their receipt. Args: destination: The destination server_name up_to_stream_id: Where to delete messages up to. """ def delete_messages_for_remote_destination_txn( txn: LoggingTransaction) -> None: sql = ("DELETE FROM device_federation_outbox" " WHERE destination = ?" " AND stream_id <= ?") txn.execute(sql, (destination, up_to_stream_id)) await self.db_pool.runInteraction( "delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn) async def get_all_new_device_messages( self, instance_name: str, last_id: int, current_id: int, limit: int) -> Tuple[List[Tuple[int, tuple]], int, bool]: """Get updates for to device replication stream. Args: instance_name: The writer we want to fetch updates from. Unused here since there is only ever one writer. last_id: The token to fetch updates from. Exclusive. current_id: The token to fetch updates up to. Inclusive. limit: The requested limit for the number of rows to return. The function may return more or fewer rows. Returns: A tuple consisting of: the updates, a token to use to fetch subsequent updates, and whether we returned fewer rows than exists between the requested tokens due to the limit. The token returned can be used in a subsequent call to this function to get further updatees. The updates are a list of 2-tuples of stream ID and the row data """ if last_id == current_id: return [], current_id, False def get_all_new_device_messages_txn( txn: LoggingTransaction, ) -> Tuple[List[Tuple[int, tuple]], int, bool]: # We limit like this as we might have multiple rows per stream_id, and # we want to make sure we always get all entries for any stream_id # we return. upper_pos = min(current_id, last_id + limit) sql = ("SELECT max(stream_id), user_id" " FROM device_inbox" " WHERE ? < stream_id AND stream_id <= ?" " GROUP BY user_id") txn.execute(sql, (last_id, upper_pos)) updates = [(row[0], row[1:]) for row in txn] sql = ("SELECT max(stream_id), destination" " FROM device_federation_outbox" " WHERE ? 
                " WHERE ? < stream_id AND stream_id <= ?"
                " GROUP BY destination"
            )
            txn.execute(sql, (last_id, upper_pos))
            updates.extend((row[0], row[1:]) for row in txn)

            # Order by ascending stream ordering
            updates.sort()

            limited = False
            upto_token = current_id
            if len(updates) >= limit:
                upto_token = updates[-1][0]
                limited = True

            return updates, upto_token, limited

        return await self.db_pool.runInteraction(
            "get_all_new_device_messages", get_all_new_device_messages_txn
        )

    @trace
    async def add_messages_to_device_inbox(
        self,
        local_messages_by_user_then_device: Dict[str, Dict[str, JsonDict]],
        remote_messages_by_destination: Dict[str, JsonDict],
    ) -> int:
        """Used to send messages from this server.

        Args:
            local_messages_by_user_then_device:
                Dictionary of recipient user_id to recipient device_id to message.
            remote_messages_by_destination:
                Dictionary of destination server_name to the EDU JSON to send.

        Returns:
            The new stream_id.
        """
        assert self._can_write_to_device

        def add_messages_txn(
            txn: LoggingTransaction, now_ms: int, stream_id: int
        ) -> None:
            # Add the local messages directly to the local inbox.
            self._add_messages_to_local_device_inbox_txn(
                txn, stream_id, local_messages_by_user_then_device
            )

            # Add the remote messages to the federation outbox.
            # We'll send them to a remote server when we next send a
            # federation transaction to that destination.
            self.db_pool.simple_insert_many_txn(
                txn,
                table="device_federation_outbox",
                keys=(
                    "destination",
                    "stream_id",
                    "queued_ts",
                    "messages_json",
                    "instance_name",
                ),
                values=[
                    (
                        destination,
                        stream_id,
                        now_ms,
                        json_encoder.encode(edu),
                        self._instance_name,
                    )
                    for destination, edu in remote_messages_by_destination.items()
                ],
            )

            if remote_messages_by_destination:
                issue9533_logger.debug(
                    "Queued outgoing to-device messages with stream_id %i for %s",
                    stream_id,
                    list(remote_messages_by_destination.keys()),
                )

        async with self._device_inbox_id_gen.get_next() as stream_id:
            now_ms = self._clock.time_msec()
            await self.db_pool.runInteraction(
                "add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
            )
            for user_id in local_messages_by_user_then_device.keys():
                self._device_inbox_stream_cache.entity_has_changed(user_id, stream_id)
            for destination in remote_messages_by_destination.keys():
                self._device_federation_outbox_stream_cache.entity_has_changed(
                    destination, stream_id
                )

        return self._device_inbox_id_gen.get_current_token()

    async def add_messages_from_remote_to_device_inbox(
        self,
        origin: str,
        message_id: str,
        local_messages_by_user_then_device: Dict[str, Dict[str, JsonDict]],
    ) -> int:
        assert self._can_write_to_device

        def add_messages_txn(
            txn: LoggingTransaction, now_ms: int, stream_id: int
        ) -> None:
            # Check if we've already inserted a matching message_id for that
            # origin. This can happen if the origin doesn't receive our
            # acknowledgement from the first time we received the message.
            already_inserted = self.db_pool.simple_select_one_txn(
                txn,
                table="device_federation_inbox",
                keyvalues={"origin": origin, "message_id": message_id},
                retcols=("message_id",),
                allow_none=True,
            )
            if already_inserted is not None:
                return

            # Add an entry for this message_id so that we know we've processed
            # it.
            self.db_pool.simple_insert_txn(
                txn,
                table="device_federation_inbox",
                values={
                    "origin": origin,
                    "message_id": message_id,
                    "received_ts": now_ms,
                },
            )

            # Add the messages to the appropriate local device inboxes so that
            # they'll be sent to the devices when they next sync.
            self._add_messages_to_local_device_inbox_txn(
                txn, stream_id, local_messages_by_user_then_device
            )

        async with self._device_inbox_id_gen.get_next() as stream_id:
            now_ms = self._clock.time_msec()
            await self.db_pool.runInteraction(
                "add_messages_from_remote_to_device_inbox",
                add_messages_txn,
                now_ms,
                stream_id,
            )
            for user_id in local_messages_by_user_then_device.keys():
                self._device_inbox_stream_cache.entity_has_changed(user_id, stream_id)

        return stream_id

    def _add_messages_to_local_device_inbox_txn(
        self,
        txn: LoggingTransaction,
        stream_id: int,
        messages_by_user_then_device: Dict[str, Dict[str, JsonDict]],
    ) -> None:
        assert self._can_write_to_device

        local_by_user_then_device = {}
        for user_id, messages_by_device in messages_by_user_then_device.items():
            messages_json_for_user = {}
            devices = list(messages_by_device.keys())
            if len(devices) == 1 and devices[0] == "*":
                # Handle wildcard device_ids.
                # We exclude hidden devices (such as cross-signing keys) here as
                # they are not expected to receive to-device messages.
                devices = self.db_pool.simple_select_onecol_txn(
                    txn,
                    table="devices",
                    keyvalues={"user_id": user_id, "hidden": False},
                    retcol="device_id",
                )

                message_json = json_encoder.encode(messages_by_device["*"])
                for device_id in devices:
                    # Add the message for all devices for this user on this
                    # server.
                    messages_json_for_user[device_id] = message_json
            else:
                if not devices:
                    continue

                # We exclude hidden devices (such as cross-signing keys) here as
                # they are not expected to receive to-device messages.
                rows = self.db_pool.simple_select_many_txn(
                    txn,
                    table="devices",
                    keyvalues={"user_id": user_id, "hidden": False},
                    column="device_id",
                    iterable=devices,
                    retcols=("device_id",),
                )

                for row in rows:
                    # Only insert into the local inbox if the device exists on
                    # this server
                    device_id = row["device_id"]
                    message_json = json_encoder.encode(messages_by_device[device_id])
                    messages_json_for_user[device_id] = message_json

            if messages_json_for_user:
                local_by_user_then_device[user_id] = messages_json_for_user

        if not local_by_user_then_device:
            return

        self.db_pool.simple_insert_many_txn(
            txn,
            table="device_inbox",
            keys=("user_id", "device_id", "stream_id", "message_json", "instance_name"),
            values=[
                (user_id, device_id, stream_id, message_json, self._instance_name)
                for user_id, messages_by_device in local_by_user_then_device.items()
                for device_id, message_json in messages_by_device.items()
            ],
        )

        issue9533_logger.debug(
            "Stored to-device messages with stream_id %i for %s",
            stream_id,
            [
                (user_id, device_id)
                for (user_id, messages_by_device) in local_by_user_then_device.items()
                for device_id in messages_by_device.keys()
            ],
        )
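
# A minimal sketch of how a caller is expected to consume the
# (from < stream_id <= to) paging contract documented above. `store` is
# assumed to be a DeviceInboxWorkerStore; the user and device IDs are
# illustrative only. Passing the returned position back as `from_stream_id`
# means the loop neither skips nor re-delivers messages.

async def drain_device_inbox(store: "DeviceInboxWorkerStore") -> None:
    from_id = 0
    to_id = store.get_to_device_stream_token()
    while True:
        messages, from_id = await store.get_messages_for_device(
            "@alice:example.com", "DEVICE1", from_id, to_id, limit=100
        )
        # ... deliver `messages` to the client ...
        if from_id == to_id:
            break  # caught up: every stream_id in (0, to_id] has been processed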
class DeviceInboxWorkerStore(SQLBaseStore): def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) self._instance_name = hs.get_instance_name() # Map of (user_id, device_id) to the last stream_id that has been # deleted up to. This is so that we can no op deletions. self._last_device_delete_cache = ExpiringCache( cache_name="last_device_delete_cache", clock=self._clock, max_len=10000, expiry_ms=30 * 60 * 1000, ) if isinstance(database.engine, PostgresEngine): self._can_write_to_device = (self._instance_name in hs.config.worker.writers.to_device) self._device_inbox_id_gen = MultiWriterIdGenerator( db_conn=db_conn, db=database, stream_name="to_device", instance_name=self._instance_name, tables=[("device_inbox", "instance_name", "stream_id")], sequence_name="device_inbox_sequence", writers=hs.config.worker.writers.to_device, ) else: self._can_write_to_device = True self._device_inbox_id_gen = StreamIdGenerator( db_conn, "device_inbox", "stream_id") max_device_inbox_id = self._device_inbox_id_gen.get_current_token() device_inbox_prefill, min_device_inbox_id = self.db_pool.get_cache_dict( db_conn, "device_inbox", entity_column="user_id", stream_column="stream_id", max_value=max_device_inbox_id, limit=1000, ) self._device_inbox_stream_cache = StreamChangeCache( "DeviceInboxStreamChangeCache", min_device_inbox_id, prefilled_cache=device_inbox_prefill, ) # The federation outbox and the local device inbox uses the same # stream_id generator. device_outbox_prefill, min_device_outbox_id = self.db_pool.get_cache_dict( db_conn, "device_federation_outbox", entity_column="destination", stream_column="stream_id", max_value=max_device_inbox_id, limit=1000, ) self._device_federation_outbox_stream_cache = StreamChangeCache( "DeviceFederationOutboxStreamChangeCache", min_device_outbox_id, prefilled_cache=device_outbox_prefill, ) def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == ToDeviceStream.NAME: self._device_inbox_id_gen.advance(instance_name, token) for row in rows: if row.entity.startswith("@"): self._device_inbox_stream_cache.entity_has_changed( row.entity, token) else: self._device_federation_outbox_stream_cache.entity_has_changed( row.entity, token) return super().process_replication_rows(stream_name, instance_name, token, rows) def get_to_device_stream_token(self): return self._device_inbox_id_gen.get_current_token() async def get_new_messages_for_device( self, user_id: str, device_id: str, last_stream_id: int, current_stream_id: int, limit: int = 100, ) -> Tuple[List[dict], int]: """ Args: user_id: The recipient user_id. device_id: The recipient device_id. last_stream_id: The last stream ID checked. current_stream_id: The current position of the to device message stream. limit: The maximum number of messages to retrieve. Returns: A list of messages for the device and where in the stream the messages got to. """ has_changed = self._device_inbox_stream_cache.has_entity_changed( user_id, last_stream_id) if not has_changed: return ([], current_stream_id) def get_new_messages_for_device_txn(txn): sql = ("SELECT stream_id, message_json FROM device_inbox" " WHERE user_id = ? AND device_id = ?" " AND ? < stream_id AND stream_id <= ?" 
" ORDER BY stream_id ASC" " LIMIT ?") txn.execute( sql, (user_id, device_id, last_stream_id, current_stream_id, limit)) messages = [] for row in txn: stream_pos = row[0] messages.append(db_to_json(row[1])) if len(messages) < limit: stream_pos = current_stream_id return messages, stream_pos return await self.db_pool.runInteraction( "get_new_messages_for_device", get_new_messages_for_device_txn) @trace async def delete_messages_for_device(self, user_id: str, device_id: str, up_to_stream_id: int) -> int: """ Args: user_id: The recipient user_id. device_id: The recipient device_id. up_to_stream_id: Where to delete messages up to. Returns: The number of messages deleted. """ # If we have cached the last stream id we've deleted up to, we can # check if there is likely to be anything that needs deleting last_deleted_stream_id = self._last_device_delete_cache.get( (user_id, device_id), None) set_tag("last_deleted_stream_id", last_deleted_stream_id) if last_deleted_stream_id: has_changed = self._device_inbox_stream_cache.has_entity_changed( user_id, last_deleted_stream_id) if not has_changed: log_kv({"message": "No changes in cache since last check"}) return 0 def delete_messages_for_device_txn(txn): sql = ("DELETE FROM device_inbox" " WHERE user_id = ? AND device_id = ?" " AND stream_id <= ?") txn.execute(sql, (user_id, device_id, up_to_stream_id)) return txn.rowcount count = await self.db_pool.runInteraction( "delete_messages_for_device", delete_messages_for_device_txn) log_kv({ "message": "deleted {} messages for device".format(count), "count": count }) # Update the cache, ensuring that we only ever increase the value last_deleted_stream_id = self._last_device_delete_cache.get( (user_id, device_id), 0) self._last_device_delete_cache[(user_id, device_id)] = max( last_deleted_stream_id, up_to_stream_id) return count @trace async def get_new_device_msgs_for_remote(self, destination, last_stream_id, current_stream_id, limit) -> Tuple[List[dict], int]: """ Args: destination(str): The name of the remote server. last_stream_id(int|long): The last position of the device message stream that the server sent up to. current_stream_id(int|long): The current position of the device message stream. Returns: A list of messages for the device and where in the stream the messages got to. """ set_tag("destination", destination) set_tag("last_stream_id", last_stream_id) set_tag("current_stream_id", current_stream_id) set_tag("limit", limit) has_changed = self._device_federation_outbox_stream_cache.has_entity_changed( destination, last_stream_id) if not has_changed or last_stream_id == current_stream_id: log_kv({"message": "No new messages in stream"}) return ([], current_stream_id) if limit <= 0: # This can happen if we run out of room for EDUs in the transaction. return ([], last_stream_id) @trace def get_new_messages_for_remote_destination_txn(txn): sql = ( "SELECT stream_id, messages_json FROM device_federation_outbox" " WHERE destination = ?" " AND ? < stream_id AND stream_id <= ?" 
" ORDER BY stream_id ASC" " LIMIT ?") txn.execute( sql, (destination, last_stream_id, current_stream_id, limit)) messages = [] for row in txn: stream_pos = row[0] messages.append(db_to_json(row[1])) if len(messages) < limit: log_kv({"message": "Set stream position to current position"}) stream_pos = current_stream_id return messages, stream_pos return await self.db_pool.runInteraction( "get_new_device_msgs_for_remote", get_new_messages_for_remote_destination_txn, ) @trace async def delete_device_msgs_for_remote(self, destination: str, up_to_stream_id: int) -> None: """Used to delete messages when the remote destination acknowledges their receipt. Args: destination: The destination server_name up_to_stream_id: Where to delete messages up to. """ def delete_messages_for_remote_destination_txn(txn): sql = ("DELETE FROM device_federation_outbox" " WHERE destination = ?" " AND stream_id <= ?") txn.execute(sql, (destination, up_to_stream_id)) await self.db_pool.runInteraction( "delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn) async def get_all_new_device_messages( self, instance_name: str, last_id: int, current_id: int, limit: int) -> Tuple[List[Tuple[int, tuple]], int, bool]: """Get updates for to device replication stream. Args: instance_name: The writer we want to fetch updates from. Unused here since there is only ever one writer. last_id: The token to fetch updates from. Exclusive. current_id: The token to fetch updates up to. Inclusive. limit: The requested limit for the number of rows to return. The function may return more or fewer rows. Returns: A tuple consisting of: the updates, a token to use to fetch subsequent updates, and whether we returned fewer rows than exists between the requested tokens due to the limit. The token returned can be used in a subsequent call to this function to get further updatees. The updates are a list of 2-tuples of stream ID and the row data """ if last_id == current_id: return [], current_id, False def get_all_new_device_messages_txn(txn): # We limit like this as we might have multiple rows per stream_id, and # we want to make sure we always get all entries for any stream_id # we return. upper_pos = min(current_id, last_id + limit) sql = ("SELECT max(stream_id), user_id" " FROM device_inbox" " WHERE ? < stream_id AND stream_id <= ?" " GROUP BY user_id") txn.execute(sql, (last_id, upper_pos)) updates = [(row[0], row[1:]) for row in txn] sql = ("SELECT max(stream_id), destination" " FROM device_federation_outbox" " WHERE ? < stream_id AND stream_id <= ?" " GROUP BY destination") txn.execute(sql, (last_id, upper_pos)) updates.extend((row[0], row[1:]) for row in txn) # Order by ascending stream ordering updates.sort() limited = False upto_token = current_id if len(updates) >= limit: upto_token = updates[-1][0] limited = True return updates, upto_token, limited return await self.db_pool.runInteraction( "get_all_new_device_messages", get_all_new_device_messages_txn) @trace async def add_messages_to_device_inbox( self, local_messages_by_user_then_device: dict, remote_messages_by_destination: dict, ) -> int: """Used to send messages from this server. Args: local_messages_by_user_and_device: Dictionary of user_id to device_id to message. remote_messages_by_destination: Dictionary of destination server_name to the EDU JSON to send. Returns: The new stream_id. """ assert self._can_write_to_device def add_messages_txn(txn, now_ms, stream_id): # Add the local messages directly to the local inbox. 
            self._add_messages_to_local_device_inbox_txn(
                txn, stream_id, local_messages_by_user_then_device
            )

            # Add the remote messages to the federation outbox.
            # We'll send them to a remote server when we next send a
            # federation transaction to that destination.
            self.db_pool.simple_insert_many_txn(
                txn,
                table="device_federation_outbox",
                values=[
                    {
                        "destination": destination,
                        "stream_id": stream_id,
                        "queued_ts": now_ms,
                        "messages_json": json_encoder.encode(edu),
                        "instance_name": self._instance_name,
                    }
                    for destination, edu in remote_messages_by_destination.items()
                ],
            )

        async with self._device_inbox_id_gen.get_next() as stream_id:
            now_ms = self._clock.time_msec()
            await self.db_pool.runInteraction(
                "add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
            )
            for user_id in local_messages_by_user_then_device.keys():
                self._device_inbox_stream_cache.entity_has_changed(user_id, stream_id)
            for destination in remote_messages_by_destination.keys():
                self._device_federation_outbox_stream_cache.entity_has_changed(
                    destination, stream_id
                )

        return self._device_inbox_id_gen.get_current_token()

    async def add_messages_from_remote_to_device_inbox(
        self, origin: str, message_id: str, local_messages_by_user_then_device: dict
    ) -> int:
        assert self._can_write_to_device

        def add_messages_txn(txn, now_ms, stream_id):
            # Check if we've already inserted a matching message_id for that
            # origin. This can happen if the origin doesn't receive our
            # acknowledgement from the first time we received the message.
            already_inserted = self.db_pool.simple_select_one_txn(
                txn,
                table="device_federation_inbox",
                keyvalues={"origin": origin, "message_id": message_id},
                retcols=("message_id",),
                allow_none=True,
            )
            if already_inserted is not None:
                return

            # Add an entry for this message_id so that we know we've processed
            # it.
            self.db_pool.simple_insert_txn(
                txn,
                table="device_federation_inbox",
                values={
                    "origin": origin,
                    "message_id": message_id,
                    "received_ts": now_ms,
                },
            )

            # Add the messages to the appropriate local device inboxes so that
            # they'll be sent to the devices when they next sync.
            self._add_messages_to_local_device_inbox_txn(
                txn, stream_id, local_messages_by_user_then_device
            )

        async with self._device_inbox_id_gen.get_next() as stream_id:
            now_ms = self._clock.time_msec()
            await self.db_pool.runInteraction(
                "add_messages_from_remote_to_device_inbox",
                add_messages_txn,
                now_ms,
                stream_id,
            )
            for user_id in local_messages_by_user_then_device.keys():
                self._device_inbox_stream_cache.entity_has_changed(user_id, stream_id)

        return stream_id

    def _add_messages_to_local_device_inbox_txn(
        self, txn, stream_id, messages_by_user_then_device
    ):
        assert self._can_write_to_device

        local_by_user_then_device = {}
        for user_id, messages_by_device in messages_by_user_then_device.items():
            messages_json_for_user = {}
            devices = list(messages_by_device.keys())
            if len(devices) == 1 and devices[0] == "*":
                # Handle wildcard device_ids.
                devices = self.db_pool.simple_select_onecol_txn(
                    txn,
                    table="devices",
                    keyvalues={"user_id": user_id},
                    retcol="device_id",
                )

                message_json = json_encoder.encode(messages_by_device["*"])
                for device_id in devices:
                    # Add the message for all devices for this user on this
                    # server.
messages_json_for_user[device_id] = message_json else: if not devices: continue rows = self.db_pool.simple_select_many_txn( txn, table="devices", keyvalues={"user_id": user_id}, column="device_id", iterable=devices, retcols=("device_id", ), ) for row in rows: # Only insert into the local inbox if the device exists on # this server device_id = row["device_id"] message_json = json_encoder.encode( messages_by_device[device_id]) messages_json_for_user[device_id] = message_json if messages_json_for_user: local_by_user_then_device[user_id] = messages_json_for_user if not local_by_user_then_device: return self.db_pool.simple_insert_many_txn( txn, table="device_inbox", values=[{ "user_id": user_id, "device_id": device_id, "stream_id": stream_id, "message_json": message_json, "instance_name": self._instance_name, } for user_id, messages_by_device in local_by_user_then_device.items() for device_id, message_json in messages_by_device.items()], )
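# A small sketch (hypothetical class, not Synapse code) of the idempotency
# trick used by add_messages_from_remote_to_device_inbox above: the
# (origin, message_id) pair is recorded in device_federation_inbox, so a
# federation retry of the same EDU is acknowledged without inserting the
# messages a second time.
import time


class RemoteInboxDeduper:
    def __init__(self):
        # Stands in for the device_federation_inbox table.
        self._seen = {}  # (origin, message_id) -> received_ts

    def should_process(self, origin: str, message_id: str) -> bool:
        key = (origin, message_id)
        if key in self._seen:
            return False  # duplicate delivery: ack it, but don't re-insert
        self._seen[key] = int(time.time() * 1000)
        return True


deduper = RemoteInboxDeduper()
assert deduper.should_process("remote.example.com", "m1")
assert not deduper.should_process("remote.example.com", "m1")  # retry is a no-op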
class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): """This is an abstract base class where subclasses must implement `get_room_max_stream_ordering` and `get_room_min_stream_ordering` which can be called in the initializer. """ __metaclass__ = abc.ABCMeta def __init__(self, db_conn, hs): super(StreamWorkerStore, self).__init__(db_conn, hs) events_max = self.get_room_max_stream_ordering() event_cache_prefill, min_event_val = self._get_cache_dict( db_conn, "events", entity_column="room_id", stream_column="stream_ordering", max_value=events_max, ) self._events_stream_cache = StreamChangeCache( "EventsRoomStreamChangeCache", min_event_val, prefilled_cache=event_cache_prefill, ) self._membership_stream_cache = StreamChangeCache( "MembershipStreamChangeCache", events_max ) self._stream_order_on_start = self.get_room_max_stream_ordering() @abc.abstractmethod def get_room_max_stream_ordering(self): raise NotImplementedError() @abc.abstractmethod def get_room_min_stream_ordering(self): raise NotImplementedError() @defer.inlineCallbacks def get_room_events_stream_for_rooms( self, room_ids, from_key, to_key, limit=0, order='DESC' ): """Get new room events in stream ordering since `from_key`. Args: room_id (str) from_key (str): Token from which no events are returned before to_key (str): Token from which no events are returned after. (This is typically the current stream token) limit (int): Maximum number of events to return order (str): Either "DESC" or "ASC". Determines which events are returned when the result is limited. If "DESC" then the most recent `limit` events are returned, otherwise returns the oldest `limit` events. Returns: Deferred[dict[str,tuple[list[FrozenEvent], str]]] A map from room id to a tuple containing: - list of recent events in the room - stream ordering key for the start of the chunk of events returned. """ from_id = RoomStreamToken.parse_stream_token(from_key).stream room_ids = yield self._events_stream_cache.get_entities_changed( room_ids, from_id ) if not room_ids: defer.returnValue({}) results = {} room_ids = list(room_ids) for rm_ids in (room_ids[i : i + 20] for i in range(0, len(room_ids), 20)): res = yield make_deferred_yieldable( defer.gatherResults( [ run_in_background( self.get_room_events_stream_for_room, room_id, from_key, to_key, limit, order=order, ) for room_id in rm_ids ], consumeErrors=True, ) ) results.update(dict(zip(rm_ids, res))) defer.returnValue(results) def get_rooms_that_changed(self, room_ids, from_key): """Given a list of rooms and a token, return rooms where there may have been changes. Args: room_ids (list) from_key (str): The room_key portion of a StreamToken """ from_key = RoomStreamToken.parse_stream_token(from_key).stream return set( room_id for room_id in room_ids if self._events_stream_cache.has_entity_changed(room_id, from_key) ) @defer.inlineCallbacks def get_room_events_stream_for_room( self, room_id, from_key, to_key, limit=0, order='DESC' ): """Get new room events in stream ordering since `from_key`. Args: room_id (str) from_key (str): Token from which no events are returned before to_key (str): Token from which no events are returned after. (This is typically the current stream token) limit (int): Maximum number of events to return order (str): Either "DESC" or "ASC". Determines which events are returned when the result is limited. If "DESC" then the most recent `limit` events are returned, otherwise returns the oldest `limit` events. 
Returns: Deferred[tuple[list[FrozenEvent], str]]: Returns the list of events (in ascending order) and the token from the start of the chunk of events returned. """ if from_key == to_key: defer.returnValue(([], from_key)) from_id = RoomStreamToken.parse_stream_token(from_key).stream to_id = RoomStreamToken.parse_stream_token(to_key).stream has_changed = yield self._events_stream_cache.has_entity_changed( room_id, from_id ) if not has_changed: defer.returnValue(([], from_key)) def f(txn): sql = ( "SELECT event_id, stream_ordering FROM events WHERE" " room_id = ?" " AND not outlier" " AND stream_ordering > ? AND stream_ordering <= ?" " ORDER BY stream_ordering %s LIMIT ?" ) % (order,) txn.execute(sql, (room_id, from_id, to_id, limit)) rows = [_EventDictReturn(row[0], None, row[1]) for row in txn] return rows rows = yield self.runInteraction("get_room_events_stream_for_room", f) ret = yield self.get_events_as_list([ r.event_id for r in rows], get_prev_content=True, ) self._set_before_and_after(ret, rows, topo_order=from_id is None) if order.lower() == "desc": ret.reverse() if rows: key = "s%d" % min(r.stream_ordering for r in rows) else: # Assume we didn't get anything because there was nothing to # get. key = from_key defer.returnValue((ret, key)) @defer.inlineCallbacks def get_membership_changes_for_user(self, user_id, from_key, to_key): from_id = RoomStreamToken.parse_stream_token(from_key).stream to_id = RoomStreamToken.parse_stream_token(to_key).stream if from_key == to_key: defer.returnValue([]) if from_id: has_changed = self._membership_stream_cache.has_entity_changed( user_id, int(from_id) ) if not has_changed: defer.returnValue([]) def f(txn): sql = ( "SELECT m.event_id, stream_ordering FROM events AS e," " room_memberships AS m" " WHERE e.event_id = m.event_id" " AND m.user_id = ?" " AND e.stream_ordering > ? AND e.stream_ordering <= ?" " ORDER BY e.stream_ordering ASC" ) txn.execute(sql, (user_id, from_id, to_id)) rows = [_EventDictReturn(row[0], None, row[1]) for row in txn] return rows rows = yield self.runInteraction("get_membership_changes_for_user", f) ret = yield self.get_events_as_list( [r.event_id for r in rows], get_prev_content=True, ) self._set_before_and_after(ret, rows, topo_order=False) defer.returnValue(ret) @defer.inlineCallbacks def get_recent_events_for_room(self, room_id, limit, end_token): """Get the most recent events in the room in topological ordering. Args: room_id (str) limit (int) end_token (str): The stream token representing now. Returns: Deferred[tuple[list[FrozenEvent], str]]: Returns a list of events and a token pointing to the start of the returned events. The events returned are in ascending order. """ rows, token = yield self.get_recent_event_ids_for_room( room_id, limit, end_token ) logger.debug("stream before") events = yield self.get_events_as_list( [r.event_id for r in rows], get_prev_content=True ) logger.debug("stream after") self._set_before_and_after(events, rows) defer.returnValue((events, token)) @defer.inlineCallbacks def get_recent_event_ids_for_room(self, room_id, limit, end_token): """Get the most recent events in the room in topological ordering. Args: room_id (str) limit (int) end_token (str): The stream token representing now. Returns: Deferred[tuple[list[_EventDictReturn], str]]: Returns a list of _EventDictReturn and a token pointing to the start of the returned events. The events returned are in ascending order. """ # Allow a zero limit here, and no-op. 
if limit == 0: defer.returnValue(([], end_token)) end_token = RoomStreamToken.parse(end_token) rows, token = yield self.runInteraction( "get_recent_event_ids_for_room", self._paginate_room_events_txn, room_id, from_token=end_token, limit=limit, ) # We want to return the results in ascending order. rows.reverse() defer.returnValue((rows, token)) def get_room_event_after_stream_ordering(self, room_id, stream_ordering): """Gets details of the first event in a room at or after a stream ordering Args: room_id (str): stream_ordering (int): Returns: Deferred[(int, int, str)]: (stream ordering, topological ordering, event_id) """ def _f(txn): sql = ( "SELECT stream_ordering, topological_ordering, event_id" " FROM events" " WHERE room_id = ? AND stream_ordering >= ?" " AND NOT outlier" " ORDER BY stream_ordering" " LIMIT 1" ) txn.execute(sql, (room_id, stream_ordering)) return txn.fetchone() return self.runInteraction("get_room_event_after_stream_ordering", _f) @defer.inlineCallbacks def get_room_events_max_id(self, room_id=None): """Returns the current token for rooms stream. By default, it returns the current global stream token. Specifying a `room_id` causes it to return the current room specific topological token. """ token = yield self.get_room_max_stream_ordering() if room_id is None: defer.returnValue("s%d" % (token,)) else: topo = yield self.runInteraction( "_get_max_topological_txn", self._get_max_topological_txn, room_id ) defer.returnValue("t%d-%d" % (topo, token)) def get_stream_token_for_event(self, event_id): """The stream token for an event Args: event_id(str): The id of the event to look up a stream token for. Raises: StoreError if the event wasn't in the database. Returns: A deferred "s%d" stream token. """ return self._simple_select_one_onecol( table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering" ).addCallback(lambda row: "s%d" % (row,)) def get_topological_token_for_event(self, event_id): """The stream token for an event Args: event_id(str): The id of the event to look up a stream token for. Raises: StoreError if the event wasn't in the database. Returns: A deferred "t%d-%d" topological token. """ return self._simple_select_one( table="events", keyvalues={"event_id": event_id}, retcols=("stream_ordering", "topological_ordering"), desc="get_topological_token_for_event", ).addCallback( lambda row: "t%d-%d" % (row["topological_ordering"], row["stream_ordering"]) ) def get_max_topological_token(self, room_id, stream_key): """Get the max topological token in a room before the given stream ordering. Args: room_id (str) stream_key (int) Returns: Deferred[int] """ sql = ( "SELECT coalesce(max(topological_ordering), 0) FROM events" " WHERE room_id = ? AND stream_ordering < ?" ) return self._execute( "get_max_topological_token", None, sql, room_id, stream_key ).addCallback(lambda r: r[0][0] if r else 0) def _get_max_topological_txn(self, txn, room_id): txn.execute( "SELECT MAX(topological_ordering) FROM events" " WHERE room_id = ?", (room_id,), ) rows = txn.fetchall() return rows[0][0] if rows else 0 @staticmethod def _set_before_and_after(events, rows, topo_order=True): """Inserts ordering information to events' internal metadata from the DB rows. Args: events (list[FrozenEvent]) rows (list[_EventDictReturn]) topo_order (bool): Whether the events were ordered topologically or by stream ordering. If true then all rows should have a non null topological_ordering. 
""" for event, row in zip(events, rows): stream = row.stream_ordering if topo_order and row.topological_ordering: topo = row.topological_ordering else: topo = None internal = event.internal_metadata internal.before = str(RoomStreamToken(topo, stream - 1)) internal.after = str(RoomStreamToken(topo, stream)) internal.order = (int(topo) if topo else 0, int(stream)) @defer.inlineCallbacks def get_events_around( self, room_id, event_id, before_limit, after_limit, event_filter=None ): """Retrieve events and pagination tokens around a given event in a room. Args: room_id (str) event_id (str) before_limit (int) after_limit (int) event_filter (Filter|None) Returns: dict """ results = yield self.runInteraction( "get_events_around", self._get_events_around_txn, room_id, event_id, before_limit, after_limit, event_filter, ) events_before = yield self.get_events_as_list( [e for e in results["before"]["event_ids"]], get_prev_content=True ) events_after = yield self.get_events_as_list( [e for e in results["after"]["event_ids"]], get_prev_content=True ) defer.returnValue( { "events_before": events_before, "events_after": events_after, "start": results["before"]["token"], "end": results["after"]["token"], } ) def _get_events_around_txn( self, txn, room_id, event_id, before_limit, after_limit, event_filter ): """Retrieves event_ids and pagination tokens around a given event in a room. Args: room_id (str) event_id (str) before_limit (int) after_limit (int) event_filter (Filter|None) Returns: dict """ results = self._simple_select_one_txn( txn, "events", keyvalues={"event_id": event_id, "room_id": room_id}, retcols=["stream_ordering", "topological_ordering"], ) # Paginating backwards includes the event at the token, but paginating # forward doesn't. before_token = RoomStreamToken( results["topological_ordering"] - 1, results["stream_ordering"] ) after_token = RoomStreamToken( results["topological_ordering"], results["stream_ordering"] ) rows, start_token = self._paginate_room_events_txn( txn, room_id, before_token, direction='b', limit=before_limit, event_filter=event_filter, ) events_before = [r.event_id for r in rows] rows, end_token = self._paginate_room_events_txn( txn, room_id, after_token, direction='f', limit=after_limit, event_filter=event_filter, ) events_after = [r.event_id for r in rows] return { "before": {"event_ids": events_before, "token": start_token}, "after": {"event_ids": events_after, "token": end_token}, } @defer.inlineCallbacks def get_all_new_events_stream(self, from_id, current_id, limit): """Get all new events Returns all events with from_id < stream_ordering <= current_id. Args: from_id (int): the stream_ordering of the last event we processed current_id (int): the stream_ordering of the most recently processed event limit (int): the maximum number of events to return Returns: Deferred[Tuple[int, list[FrozenEvent]]]: A tuple of (next_id, events), where `next_id` is the next value to pass as `from_id` (it will either be the stream_ordering of the last returned event, or, if fewer than `limit` events were found, `current_id`. """ def get_all_new_events_stream_txn(txn): sql = ( "SELECT e.stream_ordering, e.event_id" " FROM events AS e" " WHERE" " ? < e.stream_ordering AND e.stream_ordering <= ?" " ORDER BY e.stream_ordering ASC" " LIMIT ?" 
) txn.execute(sql, (from_id, current_id, limit)) rows = txn.fetchall() upper_bound = current_id if len(rows) == limit: upper_bound = rows[-1][0] return upper_bound, [row[1] for row in rows] upper_bound, event_ids = yield self.runInteraction( "get_all_new_events_stream", get_all_new_events_stream_txn ) events = yield self.get_events_as_list(event_ids) defer.returnValue((upper_bound, events)) def get_federation_out_pos(self, typ): return self._simple_select_one_onecol( table="federation_stream_position", retcol="stream_id", keyvalues={"type": typ}, desc="get_federation_out_pos", ) def update_federation_out_pos(self, typ, stream_id): return self._simple_update_one( table="federation_stream_position", keyvalues={"type": typ}, updatevalues={"stream_id": stream_id}, desc="update_federation_out_pos", ) def has_room_changed_since(self, room_id, stream_id): return self._events_stream_cache.has_entity_changed(room_id, stream_id) def _paginate_room_events_txn( self, txn, room_id, from_token, to_token=None, direction='b', limit=-1, event_filter=None, ): """Returns list of events before or after a given token. Args: txn room_id (str) from_token (RoomStreamToken): The token used to stream from to_token (RoomStreamToken|None): A token which if given limits the results to only those before direction(char): Either 'b' or 'f' to indicate whether we are paginating forwards or backwards from `from_key`. limit (int): The maximum number of events to return. event_filter (Filter|None): If provided filters the events to those that match the filter. Returns: Deferred[tuple[list[_EventDictReturn], str]]: Returns the results as a list of _EventDictReturn and a token that points to the end of the result set. """ assert int(limit) >= 0 # Tokens really represent positions between elements, but we use # the convention of pointing to the event before the gap. Hence # we have a bit of asymmetry when it comes to equalities. args = [False, room_id] if direction == 'b': order = "DESC" else: order = "ASC" bounds = generate_pagination_where_clause( direction=direction, column_names=("topological_ordering", "stream_ordering"), from_token=from_token, to_token=to_token, engine=self.database_engine, ) filter_clause, filter_args = filter_to_clause(event_filter) if filter_clause: bounds += " AND " + filter_clause args.extend(filter_args) args.append(int(limit)) sql = ( "SELECT event_id, topological_ordering, stream_ordering" " FROM events" " WHERE outlier = ? AND room_id = ? AND %(bounds)s" " ORDER BY topological_ordering %(order)s," " stream_ordering %(order)s LIMIT ?" ) % {"bounds": bounds, "order": order} txn.execute(sql, args) rows = [_EventDictReturn(row[0], row[1], row[2]) for row in txn] if rows: topo = rows[-1].topological_ordering toke = rows[-1].stream_ordering if direction == 'b': # Tokens are positions between events. # This token points *after* the last event in the chunk. # We need it to point to the event before it in the chunk # when we are going backwards so we subtract one from the # stream part. toke -= 1 next_token = RoomStreamToken(topo, toke) else: # TODO (erikj): We should work out what to do here instead. next_token = to_token if to_token else from_token return rows, str(next_token) @defer.inlineCallbacks def paginate_room_events( self, room_id, from_key, to_key=None, direction='b', limit=-1, event_filter=None ): """Returns list of events before or after a given token. 
        Args:
            room_id (str)
            from_key (str): The token used to stream from
            to_key (str|None): A token which if given limits the results to
                only those before
            direction(char): Either 'b' or 'f' to indicate whether we are
                paginating forwards or backwards from `from_key`.
            limit (int): The maximum number of events to return. Zero or less
                means no limit.
            event_filter (Filter|None): If provided filters the events to
                those that match the filter.

        Returns:
            tuple[list[dict], str]: Returns the results as a list of dicts and
            a token that points to the end of the result set. The dicts have
            the keys "event_id", "topological_ordering" and "stream_ordering".
        """
        from_key = RoomStreamToken.parse(from_key)
        if to_key:
            to_key = RoomStreamToken.parse(to_key)

        rows, token = yield self.runInteraction(
            "paginate_room_events",
            self._paginate_room_events_txn,
            room_id,
            from_key,
            to_key,
            direction,
            limit,
            event_filter,
        )

        events = yield self.get_events_as_list(
            [r.event_id for r in rows], get_prev_content=True
        )

        self._set_before_and_after(events, rows)

        defer.returnValue((events, token))
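# The store above round-trips two token shapes: "s<stream>" for pure stream
# positions and "t<topological>-<stream>" for topological positions (see
# get_room_events_max_id and get_topological_token_for_event). This is a
# simplified stand-in for RoomStreamToken.parse, just to make the formats
# concrete; the real class lives in synapse.types and does more validation.
from typing import NamedTuple, Optional


class SimpleRoomStreamToken(NamedTuple):
    topological: Optional[int]
    stream: int

    @classmethod
    def parse(cls, string: str) -> "SimpleRoomStreamToken":
        if string.startswith("s"):
            return cls(None, int(string[1:]))
        if string.startswith("t"):
            topo, stream = string[1:].split("-", 1)
            return cls(int(topo), int(stream))
        raise ValueError("Invalid token %r" % (string,))

    def __str__(self) -> str:
        if self.topological is not None:
            return "t%d-%d" % (self.topological, self.stream)
        return "s%d" % (self.stream,)


assert str(SimpleRoomStreamToken.parse("s42")) == "s42"
assert SimpleRoomStreamToken.parse("t5-42") == SimpleRoomStreamToken(5, 42)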
class PushRulesWorkerStore( ApplicationServiceWorkerStore, ReceiptsWorkerStore, PusherWorkerStore, RoomMemberWorkerStore, SQLBaseStore, ): """This is an abstract base class where subclasses must implement `get_max_push_rules_stream_id` which can be called in the initializer. """ # This ABCMeta metaclass ensures that we cannot be instantiated without # the abstract methods being implemented. __metaclass__ = abc.ABCMeta def __init__(self, db_conn, hs): super(PushRulesWorkerStore, self).__init__(db_conn, hs) push_rules_prefill, push_rules_id = self._get_cache_dict( db_conn, "push_rules_stream", entity_column="user_id", stream_column="stream_id", max_value=self.get_max_push_rules_stream_id(), ) self.push_rules_stream_cache = StreamChangeCache( "PushRulesStreamChangeCache", push_rules_id, prefilled_cache=push_rules_prefill, ) @abc.abstractmethod def get_max_push_rules_stream_id(self): """Get the position of the push rules stream. Returns: int """ raise NotImplementedError() @cachedInlineCallbacks(max_entries=5000) def get_push_rules_for_user(self, user_id): rows = yield self._simple_select_list( table="push_rules", keyvalues={"user_name": user_id}, retcols=( "user_name", "rule_id", "priority_class", "priority", "conditions", "actions", ), desc="get_push_rules_enabled_for_user", ) rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"]))) enabled_map = yield self.get_push_rules_enabled_for_user(user_id) rules = _load_rules(rows, enabled_map) return rules @cachedInlineCallbacks(max_entries=5000) def get_push_rules_enabled_for_user(self, user_id): results = yield self._simple_select_list( table="push_rules_enable", keyvalues={"user_name": user_id}, retcols=("user_name", "rule_id", "enabled"), desc="get_push_rules_enabled_for_user", ) return {r["rule_id"]: False if r["enabled"] == 0 else True for r in results} def have_push_rules_changed_for_user(self, user_id, last_id): if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id): return defer.succeed(False) else: def have_push_rules_changed_txn(txn): sql = ( "SELECT COUNT(stream_id) FROM push_rules_stream" " WHERE user_id = ? AND ? < stream_id" ) txn.execute(sql, (user_id, last_id)) (count,) = txn.fetchone() return bool(count) return self.runInteraction( "have_push_rules_changed", have_push_rules_changed_txn ) @cachedList( cached_method_name="get_push_rules_for_user", list_name="user_ids", num_args=1, inlineCallbacks=True, ) def bulk_get_push_rules(self, user_ids): if not user_ids: return {} results = {user_id: [] for user_id in user_ids} rows = yield self._simple_select_many_batch( table="push_rules", column="user_name", iterable=user_ids, retcols=("*",), desc="bulk_get_push_rules", ) rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"]))) for row in rows: results.setdefault(row["user_name"], []).append(row) enabled_map_by_user = yield self.bulk_get_push_rules_enabled(user_ids) for user_id, rules in results.items(): results[user_id] = _load_rules(rules, enabled_map_by_user.get(user_id, {})) return results @defer.inlineCallbacks def copy_push_rule_from_room_to_room(self, new_room_id, user_id, rule): """Copy a single push rule from one room to another for a specific user. Args: new_room_id (str): ID of the new room. user_id (str): ID of user the push rule belongs to. rule (Dict): A push rule. 
""" # Create new rule id rule_id_scope = "/".join(rule["rule_id"].split("/")[:-1]) new_rule_id = rule_id_scope + "/" + new_room_id # Change room id in each condition for condition in rule.get("conditions", []): if condition.get("key") == "room_id": condition["pattern"] = new_room_id # Add the rule for the new room yield self.add_push_rule( user_id=user_id, rule_id=new_rule_id, priority_class=rule["priority_class"], conditions=rule["conditions"], actions=rule["actions"], ) @defer.inlineCallbacks def copy_push_rules_from_room_to_room_for_user( self, old_room_id, new_room_id, user_id ): """Copy all of the push rules from one room to another for a specific user. Args: old_room_id (str): ID of the old room. new_room_id (str): ID of the new room. user_id (str): ID of user to copy push rules for. """ # Retrieve push rules for this user user_push_rules = yield self.get_push_rules_for_user(user_id) # Get rules relating to the old room and copy them to the new room for rule in user_push_rules: conditions = rule.get("conditions", []) if any( (c.get("key") == "room_id" and c.get("pattern") == old_room_id) for c in conditions ): yield self.copy_push_rule_from_room_to_room(new_room_id, user_id, rule) @defer.inlineCallbacks def bulk_get_push_rules_for_room(self, event, context): state_group = context.state_group if not state_group: # If state_group is None it means it has yet to be assigned a # state group, i.e. we need to make sure that calls with a state_group # of None don't hit previous cached calls with a None state_group. # To do this we set the state_group to a new object as object() != object() state_group = object() current_state_ids = yield context.get_current_state_ids(self) result = yield self._bulk_get_push_rules_for_room( event.room_id, state_group, current_state_ids, event=event ) return result @cachedInlineCallbacks(num_args=2, cache_context=True) def _bulk_get_push_rules_for_room( self, room_id, state_group, current_state_ids, cache_context, event=None ): # We don't use `state_group`, its there so that we can cache based # on it. However, its important that its never None, since two current_state's # with a state_group of None are likely to be different. # See bulk_get_push_rules_for_room for how we work around this. assert state_group is not None # We also will want to generate notifs for other people in the room so # their unread countss are correct in the event stream, but to avoid # generating them for bot / AS users etc, we only do so for people who've # sent a read receipt into the room. users_in_room = yield self._get_joined_users_from_context( room_id, state_group, current_state_ids, on_invalidate=cache_context.invalidate, event=event, ) # We ignore app service users for now. This is so that we don't fill # up the `get_if_users_have_pushers` cache with AS entries that we # know don't have pushers, nor even read receipts. 
local_users_in_room = set( u for u in users_in_room if self.hs.is_mine_id(u) and not self.get_if_app_services_interested_in_user(u) ) # users in the room who have pushers need to get push rules run because # that's how their pushers work if_users_with_pushers = yield self.get_if_users_have_pushers( local_users_in_room, on_invalidate=cache_context.invalidate ) user_ids = set( uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher ) users_with_receipts = yield self.get_users_with_read_receipts_in_room( room_id, on_invalidate=cache_context.invalidate ) # any users with pushers must be ours: they have pushers for uid in users_with_receipts: if uid in local_users_in_room: user_ids.add(uid) rules_by_user = yield self.bulk_get_push_rules( user_ids, on_invalidate=cache_context.invalidate ) rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None} return rules_by_user @cachedList( cached_method_name="get_push_rules_enabled_for_user", list_name="user_ids", num_args=1, inlineCallbacks=True, ) def bulk_get_push_rules_enabled(self, user_ids): if not user_ids: return {} results = {user_id: {} for user_id in user_ids} rows = yield self._simple_select_many_batch( table="push_rules_enable", column="user_name", iterable=user_ids, retcols=("user_name", "rule_id", "enabled"), desc="bulk_get_push_rules_enabled", ) for row in rows: enabled = bool(row["enabled"]) results.setdefault(row["user_name"], {})[row["rule_id"]] = enabled return results
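# How the push_rules_stream_cache gates have_push_rules_changed_for_user: a
# StreamChangeCache gives a definite "no" cheaply, so pollers only fall
# through to the SQL COUNT when a change is possible. Behaviour shown with
# the real cache class; the user ids and stream positions here are made up.
from synapse.util.caches.stream_change_cache import StreamChangeCache

cache = StreamChangeCache("PushRulesStreamChangeCache", 10)
cache.entity_has_changed("@alice:hs", 12)

# A poller that last saw stream id 12 can skip the database entirely.
assert not cache.has_entity_changed("@alice:hs", 12)
# One that last saw 11 must query, since Alice's rules changed at 12.
assert cache.has_entity_changed("@alice:hs", 11)
# Positions before the cache's earliest known point always report "changed".
assert cache.has_entity_changed("@bob:hs", 5)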
class PushRulesWorkerStore(ApplicationServiceWorkerStore, ReceiptsWorkerStore, PusherWorkerStore, RoomMemberWorkerStore, SQLBaseStore): """This is an abstract base class where subclasses must implement `get_max_push_rules_stream_id` which can be called in the initializer. """ # This ABCMeta metaclass ensures that we cannot be instantiated without # the abstract methods being implemented. __metaclass__ = abc.ABCMeta def __init__(self, db_conn, hs): super(PushRulesWorkerStore, self).__init__(db_conn, hs) push_rules_prefill, push_rules_id = self._get_cache_dict( db_conn, "push_rules_stream", entity_column="user_id", stream_column="stream_id", max_value=self.get_max_push_rules_stream_id(), ) self.push_rules_stream_cache = StreamChangeCache( "PushRulesStreamChangeCache", push_rules_id, prefilled_cache=push_rules_prefill, ) @abc.abstractmethod def get_max_push_rules_stream_id(self): """Get the position of the push rules stream. Returns: int """ raise NotImplementedError() @cachedInlineCallbacks(max_entries=5000) def get_push_rules_for_user(self, user_id): rows = yield self._simple_select_list( table="push_rules", keyvalues={ "user_name": user_id, }, retcols=( "user_name", "rule_id", "priority_class", "priority", "conditions", "actions", ), desc="get_push_rules_enabled_for_user", ) rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"]))) enabled_map = yield self.get_push_rules_enabled_for_user(user_id) rules = _load_rules(rows, enabled_map) defer.returnValue(rules) @cachedInlineCallbacks(max_entries=5000) def get_push_rules_enabled_for_user(self, user_id): results = yield self._simple_select_list( table="push_rules_enable", keyvalues={'user_name': user_id}, retcols=( "user_name", "rule_id", "enabled", ), desc="get_push_rules_enabled_for_user", ) defer.returnValue({ r['rule_id']: False if r['enabled'] == 0 else True for r in results }) def have_push_rules_changed_for_user(self, user_id, last_id): if not self.push_rules_stream_cache.has_entity_changed( user_id, last_id): return defer.succeed(False) else: def have_push_rules_changed_txn(txn): sql = ("SELECT COUNT(stream_id) FROM push_rules_stream" " WHERE user_id = ? AND ? < stream_id") txn.execute(sql, (user_id, last_id)) count, = txn.fetchone() return bool(count) return self.runInteraction("have_push_rules_changed", have_push_rules_changed_txn) @cachedList(cached_method_name="get_push_rules_for_user", list_name="user_ids", num_args=1, inlineCallbacks=True) def bulk_get_push_rules(self, user_ids): if not user_ids: defer.returnValue({}) results = {user_id: [] for user_id in user_ids} rows = yield self._simple_select_many_batch( table="push_rules", column="user_name", iterable=user_ids, retcols=("*", ), desc="bulk_get_push_rules", ) rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"]))) for row in rows: results.setdefault(row['user_name'], []).append(row) enabled_map_by_user = yield self.bulk_get_push_rules_enabled(user_ids) for user_id, rules in results.items(): results[user_id] = _load_rules( rules, enabled_map_by_user.get(user_id, {})) defer.returnValue(results) def bulk_get_push_rules_for_room(self, event, context): state_group = context.state_group if not state_group: # If state_group is None it means it has yet to be assigned a # state group, i.e. we need to make sure that calls with a state_group # of None don't hit previous cached calls with a None state_group. 
            # To do this we set the state_group to a new object as object() != object()
            state_group = object()

        return self._bulk_get_push_rules_for_room(
            event.room_id, state_group, context.current_state_ids, event=event
        )

    @cachedInlineCallbacks(num_args=2, cache_context=True)
    def _bulk_get_push_rules_for_room(self, room_id, state_group, current_state_ids,
                                      cache_context, event=None):
        # We don't use `state_group`, it's there so that we can cache based
        # on it. However, it's important that it's never None, since two
        # current_states with a state_group of None are likely to be different.
        # See bulk_get_push_rules_for_room for how we work around this.
        assert state_group is not None

        # We also will want to generate notifs for other people in the room so
        # their unread counts are correct in the event stream, but to avoid
        # generating them for bot / AS users etc, we only do so for people who've
        # sent a read receipt into the room.

        users_in_room = yield self._get_joined_users_from_context(
            room_id, state_group, current_state_ids,
            on_invalidate=cache_context.invalidate,
            event=event,
        )

        # We ignore app service users for now. This is so that we don't fill
        # up the `get_if_users_have_pushers` cache with AS entries that we
        # know don't have pushers, nor even read receipts.
        local_users_in_room = set(
            u for u in users_in_room
            if self.hs.is_mine_id(u)
            and not self.get_if_app_services_interested_in_user(u)
        )

        # users in the room who have pushers need to get push rules run because
        # that's how their pushers work
        if_users_with_pushers = yield self.get_if_users_have_pushers(
            local_users_in_room,
            on_invalidate=cache_context.invalidate,
        )

        user_ids = set(
            uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
        )

        users_with_receipts = yield self.get_users_with_read_receipts_in_room(
            room_id,
            on_invalidate=cache_context.invalidate,
        )

        # any users with pushers must be ours: they have pushers
        for uid in users_with_receipts:
            if uid in local_users_in_room:
                user_ids.add(uid)

        forgotten = yield self.who_forgot_in_room(
            event.room_id,
            on_invalidate=cache_context.invalidate,
        )

        for row in forgotten:
            user_id = row["user_id"]
            event_id = row["event_id"]

            mem_id = current_state_ids.get((EventTypes.Member, user_id), None)
            if event_id == mem_id:
                user_ids.discard(user_id)

        rules_by_user = yield self.bulk_get_push_rules(
            user_ids,
            on_invalidate=cache_context.invalidate,
        )

        rules_by_user = {
            k: v for k, v in rules_by_user.items() if v is not None
        }

        defer.returnValue(rules_by_user)

    @cachedList(cached_method_name="get_push_rules_enabled_for_user",
                list_name="user_ids", num_args=1, inlineCallbacks=True)
    def bulk_get_push_rules_enabled(self, user_ids):
        if not user_ids:
            defer.returnValue({})

        results = {user_id: {} for user_id in user_ids}

        rows = yield self._simple_select_many_batch(
            table="push_rules_enable",
            column="user_name",
            iterable=user_ids,
            retcols=(
                "user_name", "rule_id", "enabled",
            ),
            desc="bulk_get_push_rules_enabled",
        )
        for row in rows:
            enabled = bool(row['enabled'])
            results.setdefault(row['user_name'], {})[row['rule_id']] = enabled
        defer.returnValue(results)
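# Sketch of the "forgotten" filter in the older _bulk_get_push_rules_for_room
# above: a user is dropped from push processing when the membership event they
# forgot the room at is still their current membership event. The helper is
# hypothetical; "m.room.member" stands in for EventTypes.Member.
def filter_forgotten(user_ids, forgotten_rows, current_state_ids):
    user_ids = set(user_ids)
    for row in forgotten_rows:
        mem_id = current_state_ids.get(("m.room.member", row["user_id"]))
        if row["event_id"] == mem_id:
            user_ids.discard(row["user_id"])
    return user_ids


remaining = filter_forgotten(
    {"@a:hs", "@b:hs"},
    [{"user_id": "@a:hs", "event_id": "$m1"}],
    {("m.room.member", "@a:hs"): "$m1"},  # @a:hs forgot at their current event
)
assert remaining == {"@b:hs"}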
class ReceiptsWorkerStore(SQLBaseStore):
    """This is an abstract base class where subclasses must implement
    `get_max_receipt_stream_id` which can be called in the initializer.
    """

    # This ABCMeta metaclass ensures that we cannot be instantiated without
    # the abstract methods being implemented.
    __metaclass__ = abc.ABCMeta

    def __init__(self, db_conn, hs):
        super(ReceiptsWorkerStore, self).__init__(db_conn, hs)

        self._receipts_stream_cache = StreamChangeCache(
            "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id()
        )

    @abc.abstractmethod
    def get_max_receipt_stream_id(self):
        """Get the current max stream ID for receipts stream

        Returns:
            int
        """
        raise NotImplementedError()

    @cachedInlineCallbacks()
    def get_users_with_read_receipts_in_room(self, room_id):
        receipts = yield self.get_receipts_for_room(room_id, "m.read")
        defer.returnValue(set(r['user_id'] for r in receipts))

    @cached(num_args=2)
    def get_receipts_for_room(self, room_id, receipt_type):
        return self._simple_select_list(
            table="receipts_linearized",
            keyvalues={
                "room_id": room_id,
                "receipt_type": receipt_type,
            },
            retcols=("user_id", "event_id"),
            desc="get_receipts_for_room",
        )

    @cached(num_args=3)
    def get_last_receipt_event_id_for_user(self, user_id, room_id, receipt_type):
        return self._simple_select_one_onecol(
            table="receipts_linearized",
            keyvalues={
                "room_id": room_id,
                "receipt_type": receipt_type,
                "user_id": user_id,
            },
            retcol="event_id",
            desc="get_own_receipt_for_user",
            allow_none=True,
        )

    @cachedInlineCallbacks(num_args=2)
    def get_receipts_for_user(self, user_id, receipt_type):
        rows = yield self._simple_select_list(
            table="receipts_linearized",
            keyvalues={
                "user_id": user_id,
                "receipt_type": receipt_type,
            },
            retcols=("room_id", "event_id"),
            desc="get_receipts_for_user",
        )

        defer.returnValue({row["room_id"]: row["event_id"] for row in rows})

    @defer.inlineCallbacks
    def get_receipts_for_user_with_orderings(self, user_id, receipt_type):
        def f(txn):
            sql = (
                "SELECT rl.room_id, rl.event_id,"
                " e.topological_ordering, e.stream_ordering"
                " FROM receipts_linearized AS rl"
                " INNER JOIN events AS e USING (room_id, event_id)"
                " WHERE rl.room_id = e.room_id"
                " AND rl.event_id = e.event_id"
                " AND user_id = ?"
            )
            txn.execute(sql, (user_id,))
            return txn.fetchall()

        rows = yield self.runInteraction(
            "get_receipts_for_user_with_orderings", f
        )
        defer.returnValue({
            row[0]: {
                "event_id": row[1],
                "topological_ordering": row[2],
                "stream_ordering": row[3],
            } for row in rows
        })

    @defer.inlineCallbacks
    def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
        """Get receipts for multiple rooms for sending to clients.

        Args:
            room_ids (list): List of room_ids.
            to_key (int): Max stream id to fetch receipts up to.
            from_key (int): Min stream id to fetch receipts from. None fetches
                from the start.

        Returns:
            list: A list of receipts.
        """
        room_ids = set(room_ids)

        if from_key is not None:
            # Only ask the database about rooms where there have been new
            # receipts added since `from_key`
            room_ids = yield self._receipts_stream_cache.get_entities_changed(
                room_ids, from_key
            )

        results = yield self._get_linearized_receipts_for_rooms(
            room_ids, to_key, from_key=from_key
        )

        defer.returnValue([ev for res in results.values() for ev in res])

    def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
        """Get receipts for a single room for sending to clients.

        Args:
            room_id (str): The room id.
            to_key (int): Max stream id to fetch receipts up to.
            from_key (int): Min stream id to fetch receipts from. None fetches
                from the start.
        Returns:
            Deferred[list]: A list of receipts.
        """
        if from_key is not None:
            # Check the cache first to see if any new receipts have been added
            # since `from_key`. If not we can no-op.
            if not self._receipts_stream_cache.has_entity_changed(room_id, from_key):
                return defer.succeed([])

        return self._get_linearized_receipts_for_room(room_id, to_key, from_key)

    @cachedInlineCallbacks(num_args=3, tree=True)
    def _get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
        """See get_linearized_receipts_for_room
        """
        def f(txn):
            if from_key:
                sql = (
                    "SELECT * FROM receipts_linearized WHERE"
                    " room_id = ? AND stream_id > ? AND stream_id <= ?"
                )

                txn.execute(
                    sql,
                    (room_id, from_key, to_key)
                )
            else:
                sql = (
                    "SELECT * FROM receipts_linearized WHERE"
                    " room_id = ? AND stream_id <= ?"
                )

                txn.execute(
                    sql,
                    (room_id, to_key)
                )

            rows = self.cursor_to_dict(txn)

            return rows

        rows = yield self.runInteraction(
            "get_linearized_receipts_for_room", f
        )

        if not rows:
            defer.returnValue([])

        content = {}
        for row in rows:
            content.setdefault(
                row["event_id"], {}
            ).setdefault(
                row["receipt_type"], {}
            )[row["user_id"]] = json.loads(row["data"])

        defer.returnValue([{
            "type": "m.receipt",
            "room_id": room_id,
            "content": content,
        }])

    @cachedList(cached_method_name="_get_linearized_receipts_for_room",
                list_name="room_ids", num_args=3, inlineCallbacks=True)
    def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
        if not room_ids:
            defer.returnValue({})

        def f(txn):
            if from_key:
                sql = (
                    "SELECT * FROM receipts_linearized WHERE"
                    " room_id IN (%s) AND stream_id > ? AND stream_id <= ?"
                ) % (
                    ",".join(["?"] * len(room_ids))
                )
                args = list(room_ids)
                args.extend([from_key, to_key])

                txn.execute(sql, args)
            else:
                sql = (
                    "SELECT * FROM receipts_linearized WHERE"
                    " room_id IN (%s) AND stream_id <= ?"
                ) % (
                    ",".join(["?"] * len(room_ids))
                )

                args = list(room_ids)
                args.append(to_key)

                txn.execute(sql, args)

            return self.cursor_to_dict(txn)

        txn_results = yield self.runInteraction(
            "_get_linearized_receipts_for_rooms", f
        )

        results = {}
        for row in txn_results:
            # We want a single event per room, since we want to batch the
            # receipts by room, event and type.
            room_event = results.setdefault(row["room_id"], {
                "type": "m.receipt",
                "room_id": row["room_id"],
                "content": {},
            })

            # The content is of the form:
            # {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. }
            event_entry = room_event["content"].setdefault(row["event_id"], {})
            receipt_type = event_entry.setdefault(row["receipt_type"], {})

            receipt_type[row["user_id"]] = json.loads(row["data"])

        results = {
            room_id: [results[room_id]] if room_id in results else []
            for room_id in room_ids
        }
        defer.returnValue(results)

    def get_all_updated_receipts(self, last_id, current_id, limit=None):
        if last_id == current_id:
            return defer.succeed([])

        def get_all_updated_receipts_txn(txn):
            sql = (
                "SELECT stream_id, room_id, receipt_type, user_id, event_id, data"
                " FROM receipts_linearized"
                " WHERE ? < stream_id AND stream_id <= ?"
                " ORDER BY stream_id ASC"
            )
            args = [last_id, current_id]
            if limit is not None:
                sql += " LIMIT ?"
args.append(limit) txn.execute(sql, args) return txn.fetchall() return self.runInteraction( "get_all_updated_receipts", get_all_updated_receipts_txn ) def _invalidate_get_users_with_receipts_in_room(self, room_id, receipt_type, user_id): if receipt_type != "m.read": return # Returns either an ObservableDeferred or the raw result res = self.get_users_with_read_receipts_in_room.cache.get( room_id, None, update_metrics=False, ) # first handle the Deferred case if isinstance(res, defer.Deferred): if res.called: res = res.result else: res = None if res and user_id in res: # We'd only be adding to the set, so no point invalidating if the # user is already there return self.get_users_with_read_receipts_in_room.invalidate((room_id,))
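# Sketch (hypothetical helper) of the optimisation in
# _invalidate_get_users_with_receipts_in_room above: a read receipt from a
# user already in the cached set cannot change the set, so the warm cache
# entry is kept instead of being invalidated.
def maybe_invalidate(cached_users, receipt_type, user_id, invalidate):
    if receipt_type != "m.read":
        return
    if cached_users is not None and user_id in cached_users:
        return  # the set would be unchanged; keep the cache entry
    invalidate()


calls = []
maybe_invalidate({"@a:hs"}, "m.read", "@a:hs", lambda: calls.append(1))
assert calls == []  # already a known reader: no invalidation
maybe_invalidate({"@a:hs"}, "m.read", "@b:hs", lambda: calls.append(1))
assert calls == [1]  # new reader: invalidate so the next read recomputes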
class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): """This is an abstract base class where subclasses must implement `get_room_max_stream_ordering` and `get_room_min_stream_ordering` which can be called in the initializer. """ __metaclass__ = abc.ABCMeta def __init__(self, db_conn, hs): super(StreamWorkerStore, self).__init__(db_conn, hs) events_max = self.get_room_max_stream_ordering() event_cache_prefill, min_event_val = self._get_cache_dict( db_conn, "events", entity_column="room_id", stream_column="stream_ordering", max_value=events_max, ) self._events_stream_cache = StreamChangeCache( "EventsRoomStreamChangeCache", min_event_val, prefilled_cache=event_cache_prefill, ) self._membership_stream_cache = StreamChangeCache( "MembershipStreamChangeCache", events_max, ) self._stream_order_on_start = self.get_room_max_stream_ordering() @abc.abstractmethod def get_room_max_stream_ordering(self): raise NotImplementedError() @abc.abstractmethod def get_room_min_stream_ordering(self): raise NotImplementedError() @defer.inlineCallbacks def get_room_events_stream_for_rooms(self, room_ids, from_key, to_key, limit=0, order='DESC'): from_id = RoomStreamToken.parse_stream_token(from_key).stream room_ids = yield self._events_stream_cache.get_entities_changed( room_ids, from_id ) if not room_ids: defer.returnValue({}) results = {} room_ids = list(room_ids) for rm_ids in (room_ids[i:i + 20] for i in range(0, len(room_ids), 20)): res = yield make_deferred_yieldable(defer.gatherResults([ run_in_background( self.get_room_events_stream_for_room, room_id, from_key, to_key, limit, order=order, ) for room_id in rm_ids ], consumeErrors=True)) results.update(dict(zip(rm_ids, res))) defer.returnValue(results) def get_rooms_that_changed(self, room_ids, from_key): """Given a list of rooms and a token, return rooms where there may have been changes. Args: room_ids (list) from_key (str): The room_key portion of a StreamToken """ from_key = RoomStreamToken.parse_stream_token(from_key).stream return set( room_id for room_id in room_ids if self._events_stream_cache.has_entity_changed(room_id, from_key) ) @defer.inlineCallbacks def get_room_events_stream_for_room(self, room_id, from_key, to_key, limit=0, order='DESC'): """Get new room events in stream ordering since `from_key`. Args: room_id (str) from_key (str): Token from which no events are returned before to_key (str): Token from which no events are returned after. (This is typically the current stream token) limit (int): Maximum number of events to return order (str): Either "DESC" or "ASC". Determines which events are returned when the result is limited. If "DESC" then the most recent `limit` events are returned, otherwise returns the oldest `limit` events. Returns: Deferred[tuple[list[FrozenEvent], str]]: Returns the list of events (in ascending order) and the token from the start of the chunk of events returned. """ if from_key == to_key: defer.returnValue(([], from_key)) from_id = RoomStreamToken.parse_stream_token(from_key).stream to_id = RoomStreamToken.parse_stream_token(to_key).stream has_changed = yield self._events_stream_cache.has_entity_changed( room_id, from_id ) if not has_changed: defer.returnValue(([], from_key)) def f(txn): sql = ( "SELECT event_id, stream_ordering FROM events WHERE" " room_id = ?" " AND not outlier" " AND stream_ordering > ? AND stream_ordering <= ?" " ORDER BY stream_ordering %s LIMIT ?" 
) % (order,) txn.execute(sql, (room_id, from_id, to_id, limit)) rows = [_EventDictReturn(row[0], None, row[1]) for row in txn] return rows rows = yield self.runInteraction("get_room_events_stream_for_room", f) ret = yield self._get_events( [r.event_id for r in rows], get_prev_content=True ) self._set_before_and_after(ret, rows, topo_order=from_id is None) if order.lower() == "desc": ret.reverse() if rows: key = "s%d" % min(r.stream_ordering for r in rows) else: # Assume we didn't get anything because there was nothing to # get. key = from_key defer.returnValue((ret, key)) @defer.inlineCallbacks def get_membership_changes_for_user(self, user_id, from_key, to_key): from_id = RoomStreamToken.parse_stream_token(from_key).stream to_id = RoomStreamToken.parse_stream_token(to_key).stream if from_key == to_key: defer.returnValue([]) if from_id: has_changed = self._membership_stream_cache.has_entity_changed( user_id, int(from_id) ) if not has_changed: defer.returnValue([]) def f(txn): sql = ( "SELECT m.event_id, stream_ordering FROM events AS e," " room_memberships AS m" " WHERE e.event_id = m.event_id" " AND m.user_id = ?" " AND e.stream_ordering > ? AND e.stream_ordering <= ?" " ORDER BY e.stream_ordering ASC" ) txn.execute(sql, (user_id, from_id, to_id,)) rows = [_EventDictReturn(row[0], None, row[1]) for row in txn] return rows rows = yield self.runInteraction("get_membership_changes_for_user", f) ret = yield self._get_events( [r.event_id for r in rows], get_prev_content=True ) self._set_before_and_after(ret, rows, topo_order=False) defer.returnValue(ret) @defer.inlineCallbacks def get_recent_events_for_room(self, room_id, limit, end_token): """Get the most recent events in the room in topological ordering. Args: room_id (str) limit (int) end_token (str): The stream token representing now. Returns: Deferred[tuple[list[FrozenEvent], str]]: Returns a list of events and a token pointing to the start of the returned events. The events returned are in ascending order. """ rows, token = yield self.get_recent_event_ids_for_room( room_id, limit, end_token, ) logger.debug("stream before") events = yield self._get_events( [r.event_id for r in rows], get_prev_content=True ) logger.debug("stream after") self._set_before_and_after(events, rows) defer.returnValue((events, token)) @defer.inlineCallbacks def get_recent_event_ids_for_room(self, room_id, limit, end_token): """Get the most recent events in the room in topological ordering. Args: room_id (str) limit (int) end_token (str): The stream token representing now. Returns: Deferred[tuple[list[_EventDictReturn], str]]: Returns a list of _EventDictReturn and a token pointing to the start of the returned events. The events returned are in ascending order. """ # Allow a zero limit here, and no-op. if limit == 0: defer.returnValue(([], end_token)) end_token = RoomStreamToken.parse(end_token) rows, token = yield self.runInteraction( "get_recent_event_ids_for_room", self._paginate_room_events_txn, room_id, from_token=end_token, limit=limit, ) # We want to return the results in ascending order. rows.reverse() defer.returnValue((rows, token)) def get_room_event_after_stream_ordering(self, room_id, stream_ordering): """Gets details of the first event in a room at or after a stream ordering Args: room_id (str): stream_ordering (int): Returns: Deferred[(int, int, str)]: (stream ordering, topological ordering, event_id) """ def _f(txn): sql = ( "SELECT stream_ordering, topological_ordering, event_id" " FROM events" " WHERE room_id = ? AND stream_ordering >= ?" 
" AND NOT outlier" " ORDER BY stream_ordering" " LIMIT 1" ) txn.execute(sql, (room_id, stream_ordering, )) return txn.fetchone() return self.runInteraction( "get_room_event_after_stream_ordering", _f, ) @defer.inlineCallbacks def get_room_events_max_id(self, room_id=None): """Returns the current token for rooms stream. By default, it returns the current global stream token. Specifying a `room_id` causes it to return the current room specific topological token. """ token = yield self.get_room_max_stream_ordering() if room_id is None: defer.returnValue("s%d" % (token,)) else: topo = yield self.runInteraction( "_get_max_topological_txn", self._get_max_topological_txn, room_id, ) defer.returnValue("t%d-%d" % (topo, token)) def get_stream_token_for_event(self, event_id): """The stream token for an event Args: event_id(str): The id of the event to look up a stream token for. Raises: StoreError if the event wasn't in the database. Returns: A deferred "s%d" stream token. """ return self._simple_select_one_onecol( table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering", ).addCallback(lambda row: "s%d" % (row,)) def get_topological_token_for_event(self, event_id): """The stream token for an event Args: event_id(str): The id of the event to look up a stream token for. Raises: StoreError if the event wasn't in the database. Returns: A deferred "t%d-%d" topological token. """ return self._simple_select_one( table="events", keyvalues={"event_id": event_id}, retcols=("stream_ordering", "topological_ordering"), desc="get_topological_token_for_event", ).addCallback(lambda row: "t%d-%d" % ( row["topological_ordering"], row["stream_ordering"],) ) def get_max_topological_token(self, room_id, stream_key): sql = ( "SELECT max(topological_ordering) FROM events" " WHERE room_id = ? AND stream_ordering < ?" ) return self._execute( "get_max_topological_token", None, sql, room_id, stream_key, ).addCallback( lambda r: r[0][0] if r else 0 ) def _get_max_topological_txn(self, txn, room_id): txn.execute( "SELECT MAX(topological_ordering) FROM events" " WHERE room_id = ?", (room_id,) ) rows = txn.fetchall() return rows[0][0] if rows else 0 @staticmethod def _set_before_and_after(events, rows, topo_order=True): """Inserts ordering information to events' internal metadata from the DB rows. Args: events (list[FrozenEvent]) rows (list[_EventDictReturn]) topo_order (bool): Whether the events were ordered topologically or by stream ordering. If true then all rows should have a non null topological_ordering. """ for event, row in zip(events, rows): stream = row.stream_ordering if topo_order and row.topological_ordering: topo = row.topological_ordering else: topo = None internal = event.internal_metadata internal.before = str(RoomStreamToken(topo, stream - 1)) internal.after = str(RoomStreamToken(topo, stream)) internal.order = ( int(topo) if topo else 0, int(stream), ) @defer.inlineCallbacks def get_events_around( self, room_id, event_id, before_limit, after_limit, event_filter=None, ): """Retrieve events and pagination tokens around a given event in a room. 
Args: room_id (str) event_id (str) before_limit (int) after_limit (int) event_filter (Filter|None) Returns: dict """ results = yield self.runInteraction( "get_events_around", self._get_events_around_txn, room_id, event_id, before_limit, after_limit, event_filter, ) events_before = yield self._get_events( [e for e in results["before"]["event_ids"]], get_prev_content=True ) events_after = yield self._get_events( [e for e in results["after"]["event_ids"]], get_prev_content=True ) defer.returnValue({ "events_before": events_before, "events_after": events_after, "start": results["before"]["token"], "end": results["after"]["token"], }) def _get_events_around_txn( self, txn, room_id, event_id, before_limit, after_limit, event_filter, ): """Retrieves event_ids and pagination tokens around a given event in a room. Args: room_id (str) event_id (str) before_limit (int) after_limit (int) event_filter (Filter|None) Returns: dict """ results = self._simple_select_one_txn( txn, "events", keyvalues={ "event_id": event_id, "room_id": room_id, }, retcols=["stream_ordering", "topological_ordering"], ) # Paginating backwards includes the event at the token, but paginating # forward doesn't. before_token = RoomStreamToken( results["topological_ordering"] - 1, results["stream_ordering"], ) after_token = RoomStreamToken( results["topological_ordering"], results["stream_ordering"], ) rows, start_token = self._paginate_room_events_txn( txn, room_id, before_token, direction='b', limit=before_limit, event_filter=event_filter, ) events_before = [r.event_id for r in rows] rows, end_token = self._paginate_room_events_txn( txn, room_id, after_token, direction='f', limit=after_limit, event_filter=event_filter, ) events_after = [r.event_id for r in rows] return { "before": { "event_ids": events_before, "token": start_token, }, "after": { "event_ids": events_after, "token": end_token, }, } @defer.inlineCallbacks def get_all_new_events_stream(self, from_id, current_id, limit): """Get all new events Returns all events with from_id < stream_ordering <= current_id. Args: from_id (int): the stream_ordering of the last event we processed current_id (int): the stream_ordering of the most recently processed event limit (int): the maximum number of events to return Returns: Deferred[Tuple[int, list[FrozenEvent]]]: A tuple of (next_id, events), where `next_id` is the next value to pass as `from_id` (it will either be the stream_ordering of the last returned event, or, if fewer than `limit` events were found, `current_id`. """ def get_all_new_events_stream_txn(txn): sql = ( "SELECT e.stream_ordering, e.event_id" " FROM events AS e" " WHERE" " ? < e.stream_ordering AND e.stream_ordering <= ?" " ORDER BY e.stream_ordering ASC" " LIMIT ?" 
) txn.execute(sql, (from_id, current_id, limit)) rows = txn.fetchall() upper_bound = current_id if len(rows) == limit: upper_bound = rows[-1][0] return upper_bound, [row[1] for row in rows] upper_bound, event_ids = yield self.runInteraction( "get_all_new_events_stream", get_all_new_events_stream_txn, ) events = yield self._get_events(event_ids) defer.returnValue((upper_bound, events)) def get_federation_out_pos(self, typ): return self._simple_select_one_onecol( table="federation_stream_position", retcol="stream_id", keyvalues={"type": typ}, desc="get_federation_out_pos" ) def update_federation_out_pos(self, typ, stream_id): return self._simple_update_one( table="federation_stream_position", keyvalues={"type": typ}, updatevalues={"stream_id": stream_id}, desc="update_federation_out_pos", ) def has_room_changed_since(self, room_id, stream_id): return self._events_stream_cache.has_entity_changed(room_id, stream_id) def _paginate_room_events_txn(self, txn, room_id, from_token, to_token=None, direction='b', limit=-1, event_filter=None): """Returns list of events before or after a given token. Args: txn room_id (str) from_token (RoomStreamToken): The token used to stream from to_token (RoomStreamToken|None): A token which if given limits the results to only those before direction (char): Either 'b' or 'f' to indicate whether we are paginating forwards or backwards from `from_key`. limit (int): The maximum number of events to return. event_filter (Filter|None): If provided filters the events to those that match the filter. Returns: tuple[list[_EventDictReturn], str]: Returns the results as a list of _EventDictReturn and a token that points to the end of the result set. """ assert int(limit) >= 0 # Tokens really represent positions between elements, but we use # the convention of pointing to the event before the gap. Hence # we have a bit of asymmetry when it comes to equalities. args = [False, room_id] if direction == 'b': order = "DESC" bounds = upper_bound( from_token, self.database_engine ) if to_token: bounds = "%s AND %s" % (bounds, lower_bound( to_token, self.database_engine )) else: order = "ASC" bounds = lower_bound( from_token, self.database_engine ) if to_token: bounds = "%s AND %s" % (bounds, upper_bound( to_token, self.database_engine )) filter_clause, filter_args = filter_to_clause(event_filter) if filter_clause: bounds += " AND " + filter_clause args.extend(filter_args) args.append(int(limit)) sql = ( "SELECT event_id, topological_ordering, stream_ordering" " FROM events" " WHERE outlier = ? AND room_id = ? AND %(bounds)s" " ORDER BY topological_ordering %(order)s," " stream_ordering %(order)s LIMIT ?" ) % { "bounds": bounds, "order": order, } txn.execute(sql, args) rows = [_EventDictReturn(row[0], row[1], row[2]) for row in txn] if rows: topo = rows[-1].topological_ordering toke = rows[-1].stream_ordering if direction == 'b': # Tokens are positions between events. # This token points *after* the last event in the chunk. # We need it to point to the event before it in the chunk # when we are going backwards so we subtract one from the # stream part. toke -= 1 next_token = RoomStreamToken(topo, toke) else: # TODO (erikj): We should work out what to do here instead. next_token = to_token if to_token else from_token return rows, str(next_token) @defer.inlineCallbacks def paginate_room_events(self, room_id, from_key, to_key=None, direction='b', limit=-1, event_filter=None): """Returns list of events before or after a given token.
Args: room_id (str) from_key (str): The token used to stream from to_key (str|None): A token which if given limits the results to only those before direction (char): Either 'b' or 'f' to indicate whether we are paginating forwards or backwards from `from_key`. limit (int): The maximum number of events to return. Zero or less means no limit. event_filter (Filter|None): If provided filters the events to those that match the filter. Returns: tuple[list[FrozenEvent], str]: Returns the results as a list of events and a token that points to the end of the result set. """ from_key = RoomStreamToken.parse(from_key) if to_key: to_key = RoomStreamToken.parse(to_key) rows, token = yield self.runInteraction( "paginate_room_events", self._paginate_room_events_txn, room_id, from_key, to_key, direction, limit, event_filter, ) events = yield self._get_events( [r.event_id for r in rows], get_prev_content=True ) self._set_before_and_after(events, rows) defer.returnValue((events, token))
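# ---------------------------------------------------------------------------
# Illustration (a hedged sketch, not part of the store classes above): the
# token wire formats this module serialises. Stream tokens render as
# "s<stream_ordering>" and topological tokens as
# "t<topological_ordering>-<stream_ordering>", per the "s%d" / "t%d-%d"
# format strings in the methods above. `parse_room_token_example` is a
# hypothetical helper added purely for illustration.
from typing import Optional, Tuple


def parse_room_token_example(token: str) -> Tuple[Optional[int], int]:
    """Split an "s%d" or "t%d-%d" token into (topological, stream) parts."""
    if token.startswith("t"):
        topo_part, _, stream_part = token[1:].partition("-")
        return int(topo_part), int(stream_part)
    if token.startswith("s"):
        return None, int(token[1:])
    raise ValueError("unrecognised room token: %r" % (token,))


# e.g. parse_room_token_example("t3-42") == (3, 42)
#      parse_room_token_example("s42") == (None, 42)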
class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): """This is an abstract base class where subclasses must implement `get_room_max_stream_ordering` and `get_room_min_stream_ordering` which can be called in the initializer. """ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) self._instance_name = hs.get_instance_name() self._send_federation = hs.should_send_federation() self._federation_shard_config = hs.config.worker.federation_shard_config # If we're a process that sends federation we may need to reset the # `federation_stream_position` table to match the current sharding # config. We don't do this now as otherwise two processes could conflict # during startup which would cause one to die. self._need_to_reset_federation_stream_positions = self._send_federation events_max = self.get_room_max_stream_ordering() event_cache_prefill, min_event_val = self.db_pool.get_cache_dict( db_conn, "events", entity_column="room_id", stream_column="stream_ordering", max_value=events_max, ) self._events_stream_cache = StreamChangeCache( "EventsRoomStreamChangeCache", min_event_val, prefilled_cache=event_cache_prefill, ) self._membership_stream_cache = StreamChangeCache( "MembershipStreamChangeCache", events_max) self._stream_order_on_start = self.get_room_max_stream_ordering() @abc.abstractmethod def get_room_max_stream_ordering(self) -> int: raise NotImplementedError() @abc.abstractmethod def get_room_min_stream_ordering(self) -> int: raise NotImplementedError() def get_room_max_token(self) -> RoomStreamToken: """Get a `RoomStreamToken` that marks the current maximum persisted position of the events stream. Useful to get a token that represents "now". The token returned is a "live" token that may have an instance_map component. """ min_pos = self._stream_id_gen.get_current_token() positions = {} if isinstance(self._stream_id_gen, MultiWriterIdGenerator): # The `min_pos` is the minimum position that we know all instances # have finished persisting to, so we only care about instances whose # positions are ahead of that. (Instance positions can be behind the # min position as there are times we can work out that the minimum # position is ahead of the naive minimum across all current # positions. See MultiWriterIdGenerator for details) positions = { i: p for i, p in self._stream_id_gen.get_positions().items() if p > min_pos } return RoomStreamToken(None, min_pos, frozendict(positions)) async def get_room_events_stream_for_rooms( self, room_ids: Collection[str], from_key: RoomStreamToken, to_key: RoomStreamToken, limit: int = 0, order: str = "DESC", ) -> Dict[str, Tuple[List[EventBase], RoomStreamToken]]: """Get new room events in stream ordering since `from_key`. Args: room_ids from_key: Token from which no events are returned before to_key: Token from which no events are returned after. (This is typically the current stream token) limit: Maximum number of events to return order: Either "DESC" or "ASC". Determines which events are returned when the result is limited. If "DESC" then the most recent `limit` events are returned, otherwise returns the oldest `limit` events. Returns: A map from room id to a tuple containing: - list of recent events in the room - stream ordering key for the start of the chunk of events returned. 
""" room_ids = self._events_stream_cache.get_entities_changed( room_ids, from_key.stream) if not room_ids: return {} results = {} room_ids = list(room_ids) for rm_ids in (room_ids[i:i + 20] for i in range(0, len(room_ids), 20)): res = await make_deferred_yieldable( defer.gatherResults( [ run_in_background( self.get_room_events_stream_for_room, room_id, from_key, to_key, limit, order=order, ) for room_id in rm_ids ], consumeErrors=True, )) results.update(dict(zip(rm_ids, res))) return results def get_rooms_that_changed(self, room_ids: Collection[str], from_key: RoomStreamToken) -> Set[str]: """Given a list of rooms and a token, return rooms where there may have been changes. """ from_id = from_key.stream return { room_id for room_id in room_ids if self._events_stream_cache.has_entity_changed(room_id, from_id) } async def get_room_events_stream_for_room( self, room_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken, limit: int = 0, order: str = "DESC", ) -> Tuple[List[EventBase], RoomStreamToken]: """Get new room events in stream ordering since `from_key`. Args: room_id from_key: Token from which no events are returned before to_key: Token from which no events are returned after. (This is typically the current stream token) limit: Maximum number of events to return order: Either "DESC" or "ASC". Determines which events are returned when the result is limited. If "DESC" then the most recent `limit` events are returned, otherwise returns the oldest `limit` events. Returns: The list of events (in ascending stream order) and the token from the start of the chunk of events returned. """ if from_key == to_key: return [], from_key has_changed = self._events_stream_cache.has_entity_changed( room_id, from_key.stream) if not has_changed: return [], from_key def f(txn: LoggingTransaction) -> List[_EventDictReturn]: # To handle tokens with a non-empty instance_map we fetch more # results than necessary and then filter down min_from_id = from_key.stream max_to_id = to_key.get_max_stream_pos() sql = """ SELECT event_id, instance_name, topological_ordering, stream_ordering FROM events WHERE room_id = ? AND not outlier AND stream_ordering > ? AND stream_ordering <= ? ORDER BY stream_ordering %s LIMIT ? """ % (order, ) txn.execute(sql, (room_id, min_from_id, max_to_id, 2 * limit)) rows = [ _EventDictReturn(event_id, None, stream_ordering) for event_id, instance_name, topological_ordering, stream_ordering in txn if _filter_results( from_key, to_key, instance_name, topological_ordering, stream_ordering, ) ][:limit] return rows rows = await self.db_pool.runInteraction( "get_room_events_stream_for_room", f) ret = await self.get_events_as_list([r.event_id for r in rows], get_prev_content=True) self._set_before_and_after(ret, rows, topo_order=False) if order.lower() == "desc": ret.reverse() if rows: key = RoomStreamToken(None, min(r.stream_ordering for r in rows)) else: # Assume we didn't get anything because there was nothing to # get. key = from_key return ret, key async def get_membership_changes_for_user( self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken) -> List[EventBase]: """Fetch membership events for a given user. All such events whose stream ordering `s` lies in the range `from_key < s <= to_key` are returned. Events are ordered by ascending stream order. """ # Start by ruling out cases where a DB query is not necessary. 
if from_key == to_key: return [] if from_key: has_changed = self._membership_stream_cache.has_entity_changed( user_id, int(from_key.stream)) if not has_changed: return [] def f(txn: LoggingTransaction) -> List[_EventDictReturn]: # To handle tokens with a non-empty instance_map we fetch more # results than necessary and then filter down min_from_id = from_key.stream max_to_id = to_key.get_max_stream_pos() sql = """ SELECT m.event_id, instance_name, topological_ordering, stream_ordering FROM events AS e, room_memberships AS m WHERE e.event_id = m.event_id AND m.user_id = ? AND e.stream_ordering > ? AND e.stream_ordering <= ? ORDER BY e.stream_ordering ASC """ txn.execute( sql, ( user_id, min_from_id, max_to_id, ), ) rows = [ _EventDictReturn(event_id, None, stream_ordering) for event_id, instance_name, topological_ordering, stream_ordering in txn if _filter_results( from_key, to_key, instance_name, topological_ordering, stream_ordering, ) ] return rows rows = await self.db_pool.runInteraction( "get_membership_changes_for_user", f) ret = await self.get_events_as_list([r.event_id for r in rows], get_prev_content=True) self._set_before_and_after(ret, rows, topo_order=False) return ret async def get_recent_events_for_room( self, room_id: str, limit: int, end_token: RoomStreamToken ) -> Tuple[List[EventBase], RoomStreamToken]: """Get the most recent events in the room in topological ordering. Args: room_id limit end_token: The stream token representing now. Returns: A list of events and a token pointing to the start of the returned events. The events returned are in ascending topological order. """ rows, token = await self.get_recent_event_ids_for_room( room_id, limit, end_token) events = await self.get_events_as_list([r.event_id for r in rows], get_prev_content=True) self._set_before_and_after(events, rows) return events, token async def get_recent_event_ids_for_room( self, room_id: str, limit: int, end_token: RoomStreamToken ) -> Tuple[List[_EventDictReturn], RoomStreamToken]: """Get the most recent events in the room in topological ordering. Args: room_id limit end_token: The stream token representing now. Returns: A list of _EventDictReturn and a token pointing to the start of the returned events. The events returned are in ascending order. """ # Allow a zero limit here, and no-op. if limit == 0: return [], end_token rows, token = await self.db_pool.runInteraction( "get_recent_event_ids_for_room", self._paginate_room_events_txn, room_id, from_token=end_token, limit=limit, ) # We want to return the results in ascending order. rows.reverse() return rows, token async def get_room_event_before_stream_ordering( self, room_id: str, stream_ordering: int) -> Optional[Tuple[int, int, str]]: """Gets details of the first event in a room at or before a stream ordering Args: room_id: stream_ordering: Returns: A tuple of (stream ordering, topological ordering, event_id) """ def _f(txn): sql = ("SELECT stream_ordering, topological_ordering, event_id" " FROM events" " WHERE room_id = ? AND stream_ordering <= ?" " AND NOT outlier" " ORDER BY stream_ordering DESC" " LIMIT 1") txn.execute(sql, (room_id, stream_ordering)) return txn.fetchone() return await self.db_pool.runInteraction( "get_room_event_before_stream_ordering", _f) async def get_room_events_max_id(self, room_id: Optional[str] = None) -> str: """Returns the current token for rooms stream. By default, it returns the current global stream token. Specifying a `room_id` causes it to return the current room specific topological token. 
""" token = self.get_room_max_stream_ordering() if room_id is None: return "s%d" % (token, ) else: topo = await self.db_pool.runInteraction( "_get_max_topological_txn", self._get_max_topological_txn, room_id) return "t%d-%d" % (topo, token) def get_stream_id_for_event_txn( self, txn: LoggingTransaction, event_id: str, allow_none=False, ) -> int: return self.db_pool.simple_select_one_onecol_txn( txn=txn, table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering", allow_none=allow_none, ) async def get_position_for_event(self, event_id: str) -> PersistedEventPosition: """Get the persisted position for an event""" row = await self.db_pool.simple_select_one( table="events", keyvalues={"event_id": event_id}, retcols=("stream_ordering", "instance_name"), desc="get_position_for_event", ) return PersistedEventPosition(row["instance_name"] or "master", row["stream_ordering"]) async def get_topological_token_for_event( self, event_id: str) -> RoomStreamToken: """The stream token for an event Args: event_id: The id of the event to look up a stream token for. Raises: StoreError if the event wasn't in the database. Returns: A `RoomStreamToken` topological token. """ row = await self.db_pool.simple_select_one( table="events", keyvalues={"event_id": event_id}, retcols=("stream_ordering", "topological_ordering"), desc="get_topological_token_for_event", ) return RoomStreamToken(row["topological_ordering"], row["stream_ordering"]) async def get_current_topological_token(self, room_id: str, stream_key: int) -> int: """Gets the topological token in a room after or at the given stream ordering. Args: room_id stream_key """ sql = ("SELECT coalesce(MIN(topological_ordering), 0) FROM events" " WHERE room_id = ? AND stream_ordering >= ?") row = await self.db_pool.execute("get_current_topological_token", None, sql, room_id, stream_key) return row[0][0] if row else 0 def _get_max_topological_txn(self, txn: LoggingTransaction, room_id: str) -> int: txn.execute( "SELECT MAX(topological_ordering) FROM events WHERE room_id = ?", (room_id, ), ) rows = txn.fetchall() return rows[0][0] if rows else 0 @staticmethod def _set_before_and_after(events: List[EventBase], rows: List[_EventDictReturn], topo_order: bool = True): """Inserts ordering information to events' internal metadata from the DB rows. Args: events rows topo_order: Whether the events were ordered topologically or by stream ordering. If true then all rows should have a non null topological_ordering. """ for event, row in zip(events, rows): stream = row.stream_ordering if topo_order and row.topological_ordering: topo = row.topological_ordering else: topo = None internal = event.internal_metadata internal.before = RoomStreamToken(topo, stream - 1) internal.after = RoomStreamToken(topo, stream) internal.order = (int(topo) if topo else 0, int(stream)) async def get_events_around( self, room_id: str, event_id: str, before_limit: int, after_limit: int, event_filter: Optional[Filter] = None, ) -> dict: """Retrieve events and pagination tokens around a given event in a room. 
""" results = await self.db_pool.runInteraction( "get_events_around", self._get_events_around_txn, room_id, event_id, before_limit, after_limit, event_filter, ) events_before = await self.get_events_as_list(list( results["before"]["event_ids"]), get_prev_content=True) events_after = await self.get_events_as_list(list( results["after"]["event_ids"]), get_prev_content=True) return { "events_before": events_before, "events_after": events_after, "start": results["before"]["token"], "end": results["after"]["token"], } def _get_events_around_txn( self, txn: LoggingTransaction, room_id: str, event_id: str, before_limit: int, after_limit: int, event_filter: Optional[Filter], ) -> dict: """Retrieves event_ids and pagination tokens around a given event in a room. Args: room_id event_id before_limit after_limit event_filter Returns: dict """ results = self.db_pool.simple_select_one_txn( txn, "events", keyvalues={ "event_id": event_id, "room_id": room_id }, retcols=["stream_ordering", "topological_ordering"], ) # This cannot happen as `allow_none=False`. assert results is not None # Paginating backwards includes the event at the token, but paginating # forward doesn't. before_token = RoomStreamToken(results["topological_ordering"] - 1, results["stream_ordering"]) after_token = RoomStreamToken(results["topological_ordering"], results["stream_ordering"]) rows, start_token = self._paginate_room_events_txn( txn, room_id, before_token, direction="b", limit=before_limit, event_filter=event_filter, ) events_before = [r.event_id for r in rows] rows, end_token = self._paginate_room_events_txn( txn, room_id, after_token, direction="f", limit=after_limit, event_filter=event_filter, ) events_after = [r.event_id for r in rows] return { "before": { "event_ids": events_before, "token": start_token }, "after": { "event_ids": events_after, "token": end_token }, } async def get_all_new_events_stream( self, from_id: int, current_id: int, limit: int) -> Tuple[int, List[EventBase]]: """Get all new events Returns all events with from_id < stream_ordering <= current_id. Args: from_id: the stream_ordering of the last event we processed current_id: the stream_ordering of the most recently processed event limit: the maximum number of events to return Returns: A tuple of (next_id, events), where `next_id` is the next value to pass as `from_id` (it will either be the stream_ordering of the last returned event, or, if fewer than `limit` events were found, the `current_id`). """ def get_all_new_events_stream_txn(txn): sql = ("SELECT e.stream_ordering, e.event_id" " FROM events AS e" " WHERE" " ? < e.stream_ordering AND e.stream_ordering <= ?" 
" ORDER BY e.stream_ordering ASC" " LIMIT ?") txn.execute(sql, (from_id, current_id, limit)) rows = txn.fetchall() upper_bound = current_id if len(rows) == limit: upper_bound = rows[-1][0] return upper_bound, [row[1] for row in rows] upper_bound, event_ids = await self.db_pool.runInteraction( "get_all_new_events_stream", get_all_new_events_stream_txn) events = await self.get_events_as_list(event_ids) return upper_bound, events async def get_federation_out_pos(self, typ: str) -> int: if self._need_to_reset_federation_stream_positions: await self.db_pool.runInteraction( "_reset_federation_positions_txn", self._reset_federation_positions_txn) self._need_to_reset_federation_stream_positions = False return await self.db_pool.simple_select_one_onecol( table="federation_stream_position", retcol="stream_id", keyvalues={ "type": typ, "instance_name": self._instance_name }, desc="get_federation_out_pos", ) async def update_federation_out_pos(self, typ: str, stream_id: int) -> None: if self._need_to_reset_federation_stream_positions: await self.db_pool.runInteraction( "_reset_federation_positions_txn", self._reset_federation_positions_txn) self._need_to_reset_federation_stream_positions = False await self.db_pool.simple_update_one( table="federation_stream_position", keyvalues={ "type": typ, "instance_name": self._instance_name }, updatevalues={"stream_id": stream_id}, desc="update_federation_out_pos", ) def _reset_federation_positions_txn(self, txn: LoggingTransaction) -> None: """Fiddles with the `federation_stream_position` table to make it match the configured federation sender instances during start up. """ # The federation sender instances may have changed, so we need to # massage the `federation_stream_position` table to have a row per type # per instance sending federation. If there is a mismatch we update the # table with the correct rows using the *minimum* stream ID seen. This # may result in resending of events/EDUs to remote servers, but that is # preferable to dropping them. if not self._send_federation: return # Pull out the configured instances. If we don't have a shard config then # we assume that we're the only instance sending. 
configured_instances = self._federation_shard_config.instances if not configured_instances: configured_instances = [self._instance_name] elif self._instance_name not in configured_instances: return instances_in_table = self.db_pool.simple_select_onecol_txn( txn, table="federation_stream_position", keyvalues={}, retcol="instance_name", ) if set(instances_in_table) == set(configured_instances): # Nothing to do return sql = """ SELECT type, MIN(stream_id) FROM federation_stream_position GROUP BY type """ txn.execute(sql) min_positions = {typ: pos for typ, pos in txn} # Map from type -> min position # Ensure we do actually have some values here assert set(min_positions) == {"federation", "events"} sql = """ DELETE FROM federation_stream_position WHERE NOT (%s) """ clause, args = make_in_list_sql_clause(txn.database_engine, "instance_name", configured_instances) txn.execute(sql % (clause, ), args) for typ, stream_id in min_positions.items(): self.db_pool.simple_upsert_txn( txn, table="federation_stream_position", keyvalues={ "type": typ, "instance_name": self._instance_name }, values={"stream_id": stream_id}, ) def has_room_changed_since(self, room_id: str, stream_id: int) -> bool: return self._events_stream_cache.has_entity_changed(room_id, stream_id) def _paginate_room_events_txn( self, txn: LoggingTransaction, room_id: str, from_token: RoomStreamToken, to_token: Optional[RoomStreamToken] = None, direction: str = "b", limit: int = -1, event_filter: Optional[Filter] = None, ) -> Tuple[List[_EventDictReturn], RoomStreamToken]: """Returns list of events before or after a given token. Args: txn room_id from_token: The token used to stream from to_token: A token which if given limits the results to only those before direction: Either 'b' or 'f' to indicate whether we are paginating forwards or backwards from `from_key`. limit: The maximum number of events to return. event_filter: If provided filters the events to those that match the filter. Returns: A list of _EventDictReturn and a token that points to the end of the result set. If no events are returned then the end of the stream has been reached (i.e. there are no events between `from_token` and `to_token`), or `limit` is zero. """ assert int(limit) >= 0 # Tokens really represent positions between elements, but we use # the convention of pointing to the event before the gap. Hence # we have a bit of asymmetry when it comes to equalities. args = [False, room_id] if direction == "b": order = "DESC" else: order = "ASC" # The bounds for the stream tokens are complicated by the fact # that we need to handle the instance_map part of the tokens. We do this # by fetching all events between the min stream token and the maximum # stream token (as returned by `RoomStreamToken.get_max_stream_pos`) and # then filtering the results. 
if from_token.topological is not None: from_bound: Tuple[Optional[int], int] = from_token.as_historical_tuple() elif direction == "b": from_bound = ( None, from_token.get_max_stream_pos(), ) else: from_bound = ( None, from_token.stream, ) to_bound: Optional[Tuple[Optional[int], int]] = None if to_token: if to_token.topological is not None: to_bound = to_token.as_historical_tuple() elif direction == "b": to_bound = ( None, to_token.stream, ) else: to_bound = ( None, to_token.get_max_stream_pos(), ) bounds = generate_pagination_where_clause( direction=direction, column_names=("event.topological_ordering", "event.stream_ordering"), from_token=from_bound, to_token=to_bound, engine=self.database_engine, ) filter_clause, filter_args = filter_to_clause(event_filter) if filter_clause: bounds += " AND " + filter_clause args.extend(filter_args) # We fetch more events as we'll filter the result set args.append(int(limit) * 2) select_keywords = "SELECT" join_clause = "" # Using DISTINCT in this SELECT query is quite expensive, because it # requires the engine to sort on the entire (not limited) result set, # i.e. the entire events table. Only use it in scenarios that could result # in the same event ID occurring multiple times in the results. needs_distinct = False if event_filter and event_filter.labels: # If we're not filtering on a label, then joining on event_labels will # return as many rows for a single event as the number of labels it has. To # avoid this, only join if we're filtering on at least one label. join_clause += """ LEFT JOIN event_labels USING (event_id, room_id, topological_ordering) """ if len(event_filter.labels) > 1: # Multiple labels could cause the same event to appear multiple times. needs_distinct = True # If there is a filter on relation_senders and relation_types join to the # relations table. if event_filter and (event_filter.relation_senders or event_filter.relation_types): # Filtering by relations could cause the same event to appear multiple # times (since there's no limit on the number of relations to an event). needs_distinct = True join_clause += """ LEFT JOIN event_relations AS relation ON (event.event_id = relation.relates_to_id) """ if event_filter.relation_senders: join_clause += """ LEFT JOIN events AS related_event ON (relation.event_id = related_event.event_id) """ if needs_distinct: select_keywords += " DISTINCT" sql = """ %(select_keywords)s event.event_id, event.instance_name, event.topological_ordering, event.stream_ordering FROM events AS event %(join_clause)s WHERE event.outlier = ? AND event.room_id = ? AND %(bounds)s ORDER BY event.topological_ordering %(order)s, event.stream_ordering %(order)s LIMIT ? """ % { "select_keywords": select_keywords, "join_clause": join_clause, "bounds": bounds, "order": order, } txn.execute(sql, args) # Filter the result set. rows = [ _EventDictReturn(event_id, topological_ordering, stream_ordering) for event_id, instance_name, topological_ordering, stream_ordering in txn if _filter_results( lower_token=to_token if direction == "b" else from_token, upper_token=from_token if direction == "b" else to_token, instance_name=instance_name, topological_ordering=topological_ordering, stream_ordering=stream_ordering, ) ][:limit] if rows: topo = rows[-1].topological_ordering toke = rows[-1].stream_ordering if direction == "b": # Tokens are positions between events. # This token points *after* the last event in the chunk.
# We need it to point to the event before it in the chunk # when we are going backwards so we subtract one from the # stream part. toke -= 1 next_token = RoomStreamToken(topo, toke) else: # TODO (erikj): We should work out what to do here instead. next_token = to_token if to_token else from_token return rows, next_token async def paginate_room_events( self, room_id: str, from_key: RoomStreamToken, to_key: Optional[RoomStreamToken] = None, direction: str = "b", limit: int = -1, event_filter: Optional[Filter] = None, ) -> Tuple[List[EventBase], RoomStreamToken]: """Returns list of events before or after a given token. Args: room_id from_key: The token used to stream from to_key: A token which if given limits the results to only those before direction: Either 'b' or 'f' to indicate whether we are paginating forwards or backwards from `from_key`. limit: The maximum number of events to return. event_filter: If provided filters the events to those that match the filter. Returns: The results as a list of events and a token that points to the end of the result set. If no events are returned then the end of the stream has been reached (i.e. there are no events between `from_key` and `to_key`). """ rows, token = await self.db_pool.runInteraction( "paginate_room_events", self._paginate_room_events_txn, room_id, from_key, to_key, direction, limit, event_filter, ) events = await self.get_events_as_list([r.event_id for r in rows], get_prev_content=True) self._set_before_and_after(events, rows) return events, token @cached() async def get_id_for_instance(self, instance_name: str) -> int: """Get a unique, immutable ID that corresponds to the given Synapse worker instance.""" def _get_id_for_instance_txn(txn): instance_id = self.db_pool.simple_select_one_onecol_txn( txn, table="instance_map", keyvalues={"instance_name": instance_name}, retcol="instance_id", allow_none=True, ) if instance_id is not None: return instance_id # If we don't have an entry upsert one. # # We could do this before the first check, and rely on the cache for # efficiency, but each UPSERT causes the next ID to increment which # can quickly bloat the size of the generated IDs for new instances. self.db_pool.simple_upsert_txn( txn, table="instance_map", keyvalues={"instance_name": instance_name}, values={}, ) return self.db_pool.simple_select_one_onecol_txn( txn, table="instance_map", keyvalues={"instance_name": instance_name}, retcol="instance_id", ) return await self.db_pool.runInteraction("get_id_for_instance", _get_id_for_instance_txn) @cached() async def get_name_from_instance_id(self, instance_id: int) -> str: """Get the instance name from an ID previously returned by `get_id_for_instance`. """ return await self.db_pool.simple_select_one_onecol( table="instance_map", keyvalues={"instance_id": instance_id}, retcol="instance_name", desc="get_name_from_instance_id", )
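# ---------------------------------------------------------------------------
# Illustration (a hedged sketch, not part of this module): the batching
# pattern used by `get_room_events_stream_for_rooms` above, which slices the
# room list into chunks of 20 so that at most 20 per-room queries run
# concurrently per batch. `chunked_example` is a hypothetical helper.
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def chunked_example(items: Iterable[T], size: int = 20) -> Iterator[List[T]]:
    """Yield consecutive slices of at most `size` items, preserving order."""
    buf = list(items)
    for i in range(0, len(buf), size):
        yield buf[i : i + size]


# e.g. list(chunked_example(range(45))) yields three chunks of 20, 20 and 5
# items, mirroring the `room_ids[i : i + 20]` slicing in the method above.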
class AccountDataWorkerStore(SQLBaseStore): """This is an abstract base class where subclasses must implement `get_max_account_data_stream_id` which can be called in the initializer. """ # This ABCMeta metaclass ensures that we cannot be instantiated without # the abstract methods being implemented. __metaclass__ = abc.ABCMeta def __init__(self, database: DatabasePool, db_conn, hs): account_max = self.get_max_account_data_stream_id() self._account_data_stream_cache = StreamChangeCache( "AccountDataAndTagsChangeCache", account_max ) super(AccountDataWorkerStore, self).__init__(database, db_conn, hs) @abc.abstractmethod def get_max_account_data_stream_id(self): """Get the current max stream ID for account data stream Returns: int """ raise NotImplementedError() @cached() async def get_account_data_for_user( self, user_id: str ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]: """Get all the client account_data for a user. Args: user_id: The user to get the account_data for. Returns: A 2-tuple of a dict of global account_data and a dict mapping from room_id string to per room account_data dicts. """ def get_account_data_for_user_txn(txn): rows = self.db_pool.simple_select_list_txn( txn, "account_data", {"user_id": user_id}, ["account_data_type", "content"], ) global_account_data = { row["account_data_type"]: db_to_json(row["content"]) for row in rows } rows = self.db_pool.simple_select_list_txn( txn, "room_account_data", {"user_id": user_id}, ["room_id", "account_data_type", "content"], ) by_room = {} for row in rows: room_data = by_room.setdefault(row["room_id"], {}) room_data[row["account_data_type"]] = db_to_json(row["content"]) return global_account_data, by_room return await self.db_pool.runInteraction( "get_account_data_for_user", get_account_data_for_user_txn ) @cached(num_args=2, max_entries=5000) async def get_global_account_data_by_type_for_user( self, data_type: str, user_id: str ) -> Optional[JsonDict]: """ Returns: The account data. """ result = await self.db_pool.simple_select_one_onecol( table="account_data", keyvalues={"user_id": user_id, "account_data_type": data_type}, retcol="content", desc="get_global_account_data_by_type_for_user", allow_none=True, ) if result: return db_to_json(result) else: return None @cached(num_args=2) async def get_account_data_for_room( self, user_id: str, room_id: str ) -> Dict[str, JsonDict]: """Get all the client account_data for a user for a room. Args: user_id: The user to get the account_data for. room_id: The room to get the account_data for. Returns: A dict of the room account_data """ def get_account_data_for_room_txn(txn): rows = self.db_pool.simple_select_list_txn( txn, "room_account_data", {"user_id": user_id, "room_id": room_id}, ["account_data_type", "content"], ) return { row["account_data_type"]: db_to_json(row["content"]) for row in rows } return await self.db_pool.runInteraction( "get_account_data_for_room", get_account_data_for_room_txn ) @cached(num_args=3, max_entries=5000) async def get_account_data_for_room_and_type( self, user_id: str, room_id: str, account_data_type: str ) -> Optional[JsonDict]: """Get the client account_data of given type for a user for a room. Args: user_id: The user to get the account_data for. room_id: The room to get the account_data for. account_data_type: The account data type to get. Returns: The room account_data for that type, or None if there isn't any set. 
""" def get_account_data_for_room_and_type_txn(txn): content_json = self.db_pool.simple_select_one_onecol_txn( txn, table="room_account_data", keyvalues={ "user_id": user_id, "room_id": room_id, "account_data_type": account_data_type, }, retcol="content", allow_none=True, ) return db_to_json(content_json) if content_json else None return await self.db_pool.runInteraction( "get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn ) async def get_updated_global_account_data( self, last_id: int, current_id: int, limit: int ) -> List[Tuple[int, str, str]]: """Get the global account_data that has changed, for the account_data stream Args: last_id: the last stream_id from the previous batch. current_id: the maximum stream_id to return up to limit: the maximum number of rows to return Returns: A list of tuples of stream_id int, user_id string, and type string. """ if last_id == current_id: return [] def get_updated_global_account_data_txn(txn): sql = ( "SELECT stream_id, user_id, account_data_type" " FROM account_data WHERE ? < stream_id AND stream_id <= ?" " ORDER BY stream_id ASC LIMIT ?" ) txn.execute(sql, (last_id, current_id, limit)) return txn.fetchall() return await self.db_pool.runInteraction( "get_updated_global_account_data", get_updated_global_account_data_txn ) async def get_updated_room_account_data( self, last_id: int, current_id: int, limit: int ) -> List[Tuple[int, str, str, str]]: """Get the global account_data that has changed, for the account_data stream Args: last_id: the last stream_id from the previous batch. current_id: the maximum stream_id to return up to limit: the maximum number of rows to return Returns: A list of tuples of stream_id int, user_id string, room_id string and type string. """ if last_id == current_id: return [] def get_updated_room_account_data_txn(txn): sql = ( "SELECT stream_id, user_id, room_id, account_data_type" " FROM room_account_data WHERE ? < stream_id AND stream_id <= ?" " ORDER BY stream_id ASC LIMIT ?" ) txn.execute(sql, (last_id, current_id, limit)) return txn.fetchall() return await self.db_pool.runInteraction( "get_updated_room_account_data", get_updated_room_account_data_txn ) async def get_updated_account_data_for_user( self, user_id: str, stream_id: int ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]: """Get all the client account_data for a that's changed for a user Args: user_id: The user to get the account_data for. stream_id: The point in the stream since which to get updates Returns: A deferred pair of a dict of global account_data and a dict mapping from room_id string to per room account_data dicts. """ def get_updated_account_data_for_user_txn(txn): sql = ( "SELECT account_data_type, content FROM account_data" " WHERE user_id = ? AND stream_id > ?" ) txn.execute(sql, (user_id, stream_id)) global_account_data = {row[0]: db_to_json(row[1]) for row in txn} sql = ( "SELECT room_id, account_data_type, content FROM room_account_data" " WHERE user_id = ? AND stream_id > ?" 
) txn.execute(sql, (user_id, stream_id)) account_data_by_room = {} for row in txn: room_account_data = account_data_by_room.setdefault(row[0], {}) room_account_data[row[1]] = db_to_json(row[2]) return global_account_data, account_data_by_room changed = self._account_data_stream_cache.has_entity_changed( user_id, int(stream_id) ) if not changed: return ({}, {}) return await self.db_pool.runInteraction( "get_updated_account_data_for_user", get_updated_account_data_for_user_txn ) @cached(num_args=2, cache_context=True, max_entries=5000) async def is_ignored_by( self, ignored_user_id: str, ignorer_user_id: str, cache_context: _CacheContext ) -> bool: ignored_account_data = await self.get_global_account_data_by_type_for_user( "m.ignored_user_list", ignorer_user_id, on_invalidate=cache_context.invalidate, ) if not ignored_account_data: return False return ignored_user_id in ignored_account_data.get("ignored_users", {})
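# ---------------------------------------------------------------------------
# Illustration (a hedged sketch, not part of the store above): the shape of
# the `m.ignored_user_list` account data consulted by `is_ignored_by`, which
# keys ignored user IDs under "ignored_users". `is_ignored_example` is a
# hypothetical helper mirroring that membership check.
from typing import Optional


def is_ignored_example(
    ignored_account_data: Optional[dict], ignored_user_id: str
) -> bool:
    """Return True if `ignored_user_id` appears in the ignore-list content."""
    if not ignored_account_data:
        return False
    return ignored_user_id in ignored_account_data.get("ignored_users", {})


# e.g. is_ignored_example({"ignored_users": {"@spam:example.com": {}}},
#                         "@spam:example.com") is True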
class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): """This is an abstract base class where subclasses must implement `get_room_max_stream_ordering` and `get_room_min_stream_ordering` which can be called in the initializer. """ __metaclass__ = abc.ABCMeta def __init__(self, db_conn, hs): super(StreamWorkerStore, self).__init__(db_conn, hs) events_max = self.get_room_max_stream_ordering() event_cache_prefill, min_event_val = self._get_cache_dict( db_conn, "events", entity_column="room_id", stream_column="stream_ordering", max_value=events_max, ) self._events_stream_cache = StreamChangeCache( "EventsRoomStreamChangeCache", min_event_val, prefilled_cache=event_cache_prefill, ) self._membership_stream_cache = StreamChangeCache( "MembershipStreamChangeCache", events_max, ) self._stream_order_on_start = self.get_room_max_stream_ordering() @abc.abstractmethod def get_room_max_stream_ordering(self): raise NotImplementedError() @abc.abstractmethod def get_room_min_stream_ordering(self): raise NotImplementedError() @defer.inlineCallbacks def get_room_events_stream_for_rooms(self, room_ids, from_key, to_key, limit=0, order='DESC'): from_id = RoomStreamToken.parse_stream_token(from_key).stream room_ids = yield self._events_stream_cache.get_entities_changed( room_ids, from_id ) if not room_ids: defer.returnValue({}) results = {} room_ids = list(room_ids) for rm_ids in (room_ids[i:i + 20] for i in xrange(0, len(room_ids), 20)): res = yield make_deferred_yieldable(defer.gatherResults([ preserve_fn(self.get_room_events_stream_for_room)( room_id, from_key, to_key, limit, order=order, ) for room_id in rm_ids ])) results.update(dict(zip(rm_ids, res))) defer.returnValue(results) def get_rooms_that_changed(self, room_ids, from_key): """Given a list of rooms and a token, return rooms where there may have been changes. Args: room_ids (list) from_key (str): The room_key portion of a StreamToken """ from_key = RoomStreamToken.parse_stream_token(from_key).stream return set( room_id for room_id in room_ids if self._events_stream_cache.has_entity_changed(room_id, from_key) ) @defer.inlineCallbacks def get_room_events_stream_for_room(self, room_id, from_key, to_key, limit=0, order='DESC'): # Note: If from_key is None then we return in topological order. This # is because in that case we're using this as a "get the last few messages # in a room" function, rather than "get new messages since last sync" if from_key is not None: from_id = RoomStreamToken.parse_stream_token(from_key).stream else: from_id = None to_id = RoomStreamToken.parse_stream_token(to_key).stream if from_key == to_key: defer.returnValue(([], from_key)) if from_id: has_changed = yield self._events_stream_cache.has_entity_changed( room_id, from_id ) if not has_changed: defer.returnValue(([], from_key)) def f(txn): if from_id is not None: sql = ( "SELECT event_id, stream_ordering FROM events WHERE" " room_id = ?" " AND not outlier" " AND stream_ordering > ? AND stream_ordering <= ?" " ORDER BY stream_ordering %s LIMIT ?" ) % (order,) txn.execute(sql, (room_id, from_id, to_id, limit)) else: sql = ( "SELECT event_id, stream_ordering FROM events WHERE" " room_id = ?" " AND not outlier" " AND stream_ordering <= ?" " ORDER BY topological_ordering %s, stream_ordering %s LIMIT ?" 
) % (order, order,) txn.execute(sql, (room_id, to_id, limit)) rows = self.cursor_to_dict(txn) return rows rows = yield self.runInteraction("get_room_events_stream_for_room", f) ret = yield self._get_events( [r["event_id"] for r in rows], get_prev_content=True ) self._set_before_and_after(ret, rows, topo_order=from_id is None) if order.lower() == "desc": ret.reverse() if rows: key = "s%d" % min(r["stream_ordering"] for r in rows) else: # Assume we didn't get anything because there was nothing to # get. key = from_key defer.returnValue((ret, key)) @defer.inlineCallbacks def get_membership_changes_for_user(self, user_id, from_key, to_key): if from_key is not None: from_id = RoomStreamToken.parse_stream_token(from_key).stream else: from_id = None to_id = RoomStreamToken.parse_stream_token(to_key).stream if from_key == to_key: defer.returnValue([]) if from_id: has_changed = self._membership_stream_cache.has_entity_changed( user_id, int(from_id) ) if not has_changed: defer.returnValue([]) def f(txn): if from_id is not None: sql = ( "SELECT m.event_id, stream_ordering FROM events AS e," " room_memberships AS m" " WHERE e.event_id = m.event_id" " AND m.user_id = ?" " AND e.stream_ordering > ? AND e.stream_ordering <= ?" " ORDER BY e.stream_ordering ASC" ) txn.execute(sql, (user_id, from_id, to_id,)) else: sql = ( "SELECT m.event_id, stream_ordering FROM events AS e," " room_memberships AS m" " WHERE e.event_id = m.event_id" " AND m.user_id = ?" " AND stream_ordering <= ?" " ORDER BY stream_ordering ASC" ) txn.execute(sql, (user_id, to_id,)) rows = self.cursor_to_dict(txn) return rows rows = yield self.runInteraction("get_membership_changes_for_user", f) ret = yield self._get_events( [r["event_id"] for r in rows], get_prev_content=True ) self._set_before_and_after(ret, rows, topo_order=False) defer.returnValue(ret) @defer.inlineCallbacks def get_recent_events_for_room(self, room_id, limit, end_token, from_token=None): rows, token = yield self.get_recent_event_ids_for_room( room_id, limit, end_token, from_token ) logger.debug("stream before") events = yield self._get_events( [r["event_id"] for r in rows], get_prev_content=True ) logger.debug("stream after") self._set_before_and_after(events, rows) defer.returnValue((events, token)) @cached(num_args=4) def get_recent_event_ids_for_room(self, room_id, limit, end_token, from_token=None): end_token = RoomStreamToken.parse_stream_token(end_token) if from_token is None: sql = ( "SELECT stream_ordering, topological_ordering, event_id" " FROM events" " WHERE room_id = ? AND stream_ordering <= ? AND outlier = ?" " ORDER BY topological_ordering DESC, stream_ordering DESC" " LIMIT ?" ) else: from_token = RoomStreamToken.parse_stream_token(from_token) sql = ( "SELECT stream_ordering, topological_ordering, event_id" " FROM events" " WHERE room_id = ? AND stream_ordering > ?" " AND stream_ordering <= ? AND outlier = ?" " ORDER BY topological_ordering DESC, stream_ordering DESC" " LIMIT ?" ) def get_recent_events_for_room_txn(txn): if from_token is None: txn.execute(sql, (room_id, end_token.stream, False, limit,)) else: txn.execute(sql, ( room_id, from_token.stream, end_token.stream, False, limit )) rows = self.cursor_to_dict(txn) rows.reverse() # As we selected with reverse ordering if rows: # Tokens are positions between events. # This token points *after* the last event in the chunk. # We need it to point to the event before it in the chunk # since we are going backwards so we subtract one from the # stream part. 
topo = rows[0]["topological_ordering"] toke = rows[0]["stream_ordering"] - 1 start_token = str(RoomStreamToken(topo, toke)) token = (start_token, str(end_token)) else: token = (str(end_token), str(end_token)) return rows, token return self.runInteraction( "get_recent_events_for_room", get_recent_events_for_room_txn ) def get_room_event_after_stream_ordering(self, room_id, stream_ordering): """Gets details of the first event in a room at or after a stream ordering Args: room_id (str): stream_ordering (int): Returns: Deferred[(int, int, str)]: (stream ordering, topological ordering, event_id) """ def _f(txn): sql = ( "SELECT stream_ordering, topological_ordering, event_id" " FROM events" " WHERE room_id = ? AND stream_ordering >= ?" " AND NOT outlier" " ORDER BY stream_ordering" " LIMIT 1" ) txn.execute(sql, (room_id, stream_ordering, )) return txn.fetchone() return self.runInteraction( "get_room_event_after_stream_ordering", _f, ) @defer.inlineCallbacks def get_room_events_max_id(self, room_id=None): """Returns the current token for rooms stream. By default, it returns the current global stream token. Specifying a `room_id` causes it to return the current room specific topological token. """ token = yield self.get_room_max_stream_ordering() if room_id is None: defer.returnValue("s%d" % (token,)) else: topo = yield self.runInteraction( "_get_max_topological_txn", self._get_max_topological_txn, room_id, ) defer.returnValue("t%d-%d" % (topo, token)) def get_stream_token_for_event(self, event_id): """The stream token for an event Args: event_id(str): The id of the event to look up a stream token for. Raises: StoreError if the event wasn't in the database. Returns: A deferred "s%d" stream token. """ return self._simple_select_one_onecol( table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering", ).addCallback(lambda row: "s%d" % (row,)) def get_topological_token_for_event(self, event_id): """The stream token for an event Args: event_id(str): The id of the event to look up a stream token for. Raises: StoreError if the event wasn't in the database. Returns: A deferred "t%d-%d" topological token. """ return self._simple_select_one( table="events", keyvalues={"event_id": event_id}, retcols=("stream_ordering", "topological_ordering"), desc="get_topological_token_for_event", ).addCallback(lambda row: "t%d-%d" % ( row["topological_ordering"], row["stream_ordering"],) ) def get_max_topological_token(self, room_id, stream_key): sql = ( "SELECT max(topological_ordering) FROM events" " WHERE room_id = ? AND stream_ordering < ?" ) return self._execute( "get_max_topological_token", None, sql, room_id, stream_key, ).addCallback( lambda r: r[0][0] if r else 0 ) def _get_max_topological_txn(self, txn, room_id): txn.execute( "SELECT MAX(topological_ordering) FROM events" " WHERE room_id = ?", (room_id,) ) rows = txn.fetchall() return rows[0][0] if rows else 0 @staticmethod def _set_before_and_after(events, rows, topo_order=True): for event, row in zip(events, rows): stream = row["stream_ordering"] if topo_order: topo = event.depth else: topo = None internal = event.internal_metadata internal.before = str(RoomStreamToken(topo, stream - 1)) internal.after = str(RoomStreamToken(topo, stream)) internal.order = ( int(topo) if topo else 0, int(stream), ) @defer.inlineCallbacks def get_events_around(self, room_id, event_id, before_limit, after_limit): """Retrieve events and pagination tokens around a given event in a room. 
Args: room_id (str) event_id (str) before_limit (int) after_limit (int) Returns: dict """ results = yield self.runInteraction( "get_events_around", self._get_events_around_txn, room_id, event_id, before_limit, after_limit ) events_before = yield self._get_events( [e for e in results["before"]["event_ids"]], get_prev_content=True ) events_after = yield self._get_events( [e for e in results["after"]["event_ids"]], get_prev_content=True ) defer.returnValue({ "events_before": events_before, "events_after": events_after, "start": results["before"]["token"], "end": results["after"]["token"], }) def _get_events_around_txn(self, txn, room_id, event_id, before_limit, after_limit): """Retrieves event_ids and pagination tokens around a given event in a room. Args: room_id (str) event_id (str) before_limit (int) after_limit (int) Returns: dict """ results = self._simple_select_one_txn( txn, "events", keyvalues={ "event_id": event_id, "room_id": room_id, }, retcols=["stream_ordering", "topological_ordering"], ) token = RoomStreamToken( results["topological_ordering"], results["stream_ordering"], ) if isinstance(self.database_engine, Sqlite3Engine): # SQLite3 doesn't optimise ``(x < a) OR (x = a AND y < b)`` # So we pass it to SQLite3 as the UNION ALL of the two queries. query_before = ( "SELECT topological_ordering, stream_ordering, event_id FROM events" " WHERE room_id = ? AND topological_ordering < ?" " UNION ALL" " SELECT topological_ordering, stream_ordering, event_id FROM events" " WHERE room_id = ? AND topological_ordering = ? AND stream_ordering < ?" " ORDER BY topological_ordering DESC, stream_ordering DESC LIMIT ?" ) before_args = ( room_id, token.topological, room_id, token.topological, token.stream, before_limit, ) query_after = ( "SELECT topological_ordering, stream_ordering, event_id FROM events" " WHERE room_id = ? AND topological_ordering > ?" " UNION ALL" " SELECT topological_ordering, stream_ordering, event_id FROM events" " WHERE room_id = ? AND topological_ordering = ? AND stream_ordering > ?" " ORDER BY topological_ordering ASC, stream_ordering ASC LIMIT ?" ) after_args = ( room_id, token.topological, room_id, token.topological, token.stream, after_limit, ) else: query_before = ( "SELECT topological_ordering, stream_ordering, event_id FROM events" " WHERE room_id = ? AND %s" " ORDER BY topological_ordering DESC, stream_ordering DESC LIMIT ?" ) % (upper_bound(token, self.database_engine, inclusive=False),) before_args = (room_id, before_limit) query_after = ( "SELECT topological_ordering, stream_ordering, event_id FROM events" " WHERE room_id = ? AND %s" " ORDER BY topological_ordering ASC, stream_ordering ASC LIMIT ?"
) % (lower_bound(token, self.database_engine, inclusive=False),) after_args = (room_id, after_limit) txn.execute(query_before, before_args) rows = self.cursor_to_dict(txn) events_before = [r["event_id"] for r in rows] if rows: start_token = str(RoomStreamToken( rows[0]["topological_ordering"], rows[0]["stream_ordering"] - 1, )) else: start_token = str(RoomStreamToken( token.topological, token.stream - 1, )) txn.execute(query_after, after_args) rows = self.cursor_to_dict(txn) events_after = [r["event_id"] for r in rows] if rows: end_token = str(RoomStreamToken( rows[-1]["topological_ordering"], rows[-1]["stream_ordering"], )) else: end_token = str(token) return { "before": { "event_ids": events_before, "token": start_token, }, "after": { "event_ids": events_after, "token": end_token, }, } @defer.inlineCallbacks def get_all_new_events_stream(self, from_id, current_id, limit): """Get all new events""" def get_all_new_events_stream_txn(txn): sql = ( "SELECT e.stream_ordering, e.event_id" " FROM events AS e" " WHERE" " ? < e.stream_ordering AND e.stream_ordering <= ?" " ORDER BY e.stream_ordering ASC" " LIMIT ?" ) txn.execute(sql, (from_id, current_id, limit)) rows = txn.fetchall() upper_bound = current_id if len(rows) == limit: upper_bound = rows[-1][0] return upper_bound, [row[1] for row in rows] upper_bound, event_ids = yield self.runInteraction( "get_all_new_events_stream", get_all_new_events_stream_txn, ) events = yield self._get_events(event_ids) defer.returnValue((upper_bound, events)) def get_federation_out_pos(self, typ): return self._simple_select_one_onecol( table="federation_stream_position", retcol="stream_id", keyvalues={"type": typ}, desc="get_federation_out_pos" ) def update_federation_out_pos(self, typ, stream_id): return self._simple_update_one( table="federation_stream_position", keyvalues={"type": typ}, updatevalues={"stream_id": stream_id}, desc="update_federation_out_pos", ) def has_room_changed_since(self, room_id, stream_id): return self._events_stream_cache.has_entity_changed(room_id, stream_id)
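# A minimal sketch (names illustrative, not part of the store API) of the
# tuple comparison that the SQLite UNION ALL queries above emulate: events
# are ordered by the pair (topological_ordering, stream_ordering), and
# "before the token" means the pair sorts strictly below the token's pair.
def _sorts_before(row_topo: int, row_stream: int, topo: int, stream: int) -> bool:
    # Equivalent to the two UNION ALL branches:
    #   row_topo < topo
    #   OR (row_topo == topo AND row_stream < stream)
    return (row_topo, row_stream) < (topo, stream)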
class AccountDataWorkerStore(SQLBaseStore): """This is an abstract base class where subclasses must implement `get_max_account_data_stream_id` which can be called in the initializer. """ def __init__(self, database: DatabasePool, db_conn, hs): self._instance_name = hs.get_instance_name() if isinstance(database.engine, PostgresEngine): self._can_write_to_account_data = ( self._instance_name in hs.config.worker.writers.account_data) self._account_data_id_gen = MultiWriterIdGenerator( db_conn=db_conn, db=database, stream_name="account_data", instance_name=self._instance_name, tables=[ ("room_account_data", "instance_name", "stream_id"), ("room_tags_revisions", "instance_name", "stream_id"), ("account_data", "instance_name", "stream_id"), ], sequence_name="account_data_sequence", writers=hs.config.worker.writers.account_data, ) else: self._can_write_to_account_data = True # We shouldn't be running in worker mode with SQLite, but its useful # to support it for unit tests. # # If this process is the writer than we need to use # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets # updated over replication. (Multiple writers are not supported for # SQLite). if hs.get_instance_name() in hs.config.worker.writers.account_data: self._account_data_id_gen = StreamIdGenerator( db_conn, "room_account_data", "stream_id", extra_tables=[("room_tags_revisions", "stream_id")], ) else: self._account_data_id_gen = SlavedIdTracker( db_conn, "room_account_data", "stream_id", extra_tables=[("room_tags_revisions", "stream_id")], ) account_max = self.get_max_account_data_stream_id() self._account_data_stream_cache = StreamChangeCache( "AccountDataAndTagsChangeCache", account_max) super().__init__(database, db_conn, hs) def get_max_account_data_stream_id(self) -> int: """Get the current max stream ID for account data stream Returns: int """ return self._account_data_id_gen.get_current_token() @cached() async def get_account_data_for_user( self, user_id: str ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]: """Get all the client account_data for a user. Args: user_id: The user to get the account_data for. Returns: A 2-tuple of a dict of global account_data and a dict mapping from room_id string to per room account_data dicts. """ def get_account_data_for_user_txn(txn): rows = self.db_pool.simple_select_list_txn( txn, "account_data", {"user_id": user_id}, ["account_data_type", "content"], ) global_account_data = { row["account_data_type"]: db_to_json(row["content"]) for row in rows } rows = self.db_pool.simple_select_list_txn( txn, "room_account_data", {"user_id": user_id}, ["room_id", "account_data_type", "content"], ) by_room = {} for row in rows: room_data = by_room.setdefault(row["room_id"], {}) room_data[row["account_data_type"]] = db_to_json( row["content"]) return global_account_data, by_room return await self.db_pool.runInteraction( "get_account_data_for_user", get_account_data_for_user_txn) @cached(num_args=2, max_entries=5000) async def get_global_account_data_by_type_for_user( self, data_type: str, user_id: str) -> Optional[JsonDict]: """ Returns: The account data. 
""" result = await self.db_pool.simple_select_one_onecol( table="account_data", keyvalues={ "user_id": user_id, "account_data_type": data_type }, retcol="content", desc="get_global_account_data_by_type_for_user", allow_none=True, ) if result: return db_to_json(result) else: return None @cached(num_args=2) async def get_account_data_for_room(self, user_id: str, room_id: str) -> Dict[str, JsonDict]: """Get all the client account_data for a user for a room. Args: user_id: The user to get the account_data for. room_id: The room to get the account_data for. Returns: A dict of the room account_data """ def get_account_data_for_room_txn(txn): rows = self.db_pool.simple_select_list_txn( txn, "room_account_data", { "user_id": user_id, "room_id": room_id }, ["account_data_type", "content"], ) return { row["account_data_type"]: db_to_json(row["content"]) for row in rows } return await self.db_pool.runInteraction( "get_account_data_for_room", get_account_data_for_room_txn) @cached(num_args=3, max_entries=5000) async def get_account_data_for_room_and_type( self, user_id: str, room_id: str, account_data_type: str) -> Optional[JsonDict]: """Get the client account_data of given type for a user for a room. Args: user_id: The user to get the account_data for. room_id: The room to get the account_data for. account_data_type: The account data type to get. Returns: The room account_data for that type, or None if there isn't any set. """ def get_account_data_for_room_and_type_txn(txn): content_json = self.db_pool.simple_select_one_onecol_txn( txn, table="room_account_data", keyvalues={ "user_id": user_id, "room_id": room_id, "account_data_type": account_data_type, }, retcol="content", allow_none=True, ) return db_to_json(content_json) if content_json else None return await self.db_pool.runInteraction( "get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn) async def get_updated_global_account_data( self, last_id: int, current_id: int, limit: int) -> List[Tuple[int, str, str]]: """Get the global account_data that has changed, for the account_data stream Args: last_id: the last stream_id from the previous batch. current_id: the maximum stream_id to return up to limit: the maximum number of rows to return Returns: A list of tuples of stream_id int, user_id string, and type string. """ if last_id == current_id: return [] def get_updated_global_account_data_txn(txn): sql = ("SELECT stream_id, user_id, account_data_type" " FROM account_data WHERE ? < stream_id AND stream_id <= ?" " ORDER BY stream_id ASC LIMIT ?") txn.execute(sql, (last_id, current_id, limit)) return txn.fetchall() return await self.db_pool.runInteraction( "get_updated_global_account_data", get_updated_global_account_data_txn) async def get_updated_room_account_data( self, last_id: int, current_id: int, limit: int) -> List[Tuple[int, str, str, str]]: """Get the global account_data that has changed, for the account_data stream Args: last_id: the last stream_id from the previous batch. current_id: the maximum stream_id to return up to limit: the maximum number of rows to return Returns: A list of tuples of stream_id int, user_id string, room_id string and type string. """ if last_id == current_id: return [] def get_updated_room_account_data_txn(txn): sql = ( "SELECT stream_id, user_id, room_id, account_data_type" " FROM room_account_data WHERE ? < stream_id AND stream_id <= ?" 
" ORDER BY stream_id ASC LIMIT ?") txn.execute(sql, (last_id, current_id, limit)) return txn.fetchall() return await self.db_pool.runInteraction( "get_updated_room_account_data", get_updated_room_account_data_txn) async def get_updated_account_data_for_user( self, user_id: str, stream_id: int ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]: """Get all the client account_data for a that's changed for a user Args: user_id: The user to get the account_data for. stream_id: The point in the stream since which to get updates Returns: A deferred pair of a dict of global account_data and a dict mapping from room_id string to per room account_data dicts. """ def get_updated_account_data_for_user_txn(txn): sql = ("SELECT account_data_type, content FROM account_data" " WHERE user_id = ? AND stream_id > ?") txn.execute(sql, (user_id, stream_id)) global_account_data = {row[0]: db_to_json(row[1]) for row in txn} sql = ( "SELECT room_id, account_data_type, content FROM room_account_data" " WHERE user_id = ? AND stream_id > ?") txn.execute(sql, (user_id, stream_id)) account_data_by_room = {} for row in txn: room_account_data = account_data_by_room.setdefault(row[0], {}) room_account_data[row[1]] = db_to_json(row[2]) return global_account_data, account_data_by_room changed = self._account_data_stream_cache.has_entity_changed( user_id, int(stream_id)) if not changed: return ({}, {}) return await self.db_pool.runInteraction( "get_updated_account_data_for_user", get_updated_account_data_for_user_txn) @cached(max_entries=5000, iterable=True) async def ignored_by(self, user_id: str) -> Set[str]: """ Get users which ignore the given user. Params: user_id: The user ID which might be ignored. Return: The user IDs which ignore the given user. """ return set(await self.db_pool.simple_select_onecol( table="ignored_users", keyvalues={"ignored_user_id": user_id}, retcol="ignorer_user_id", desc="ignored_by", )) def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == TagAccountDataStream.NAME: self._account_data_id_gen.advance(instance_name, token) for row in rows: self.get_tags_for_user.invalidate((row.user_id, )) self._account_data_stream_cache.entity_has_changed( row.user_id, token) elif stream_name == AccountDataStream.NAME: self._account_data_id_gen.advance(instance_name, token) for row in rows: if not row.room_id: self.get_global_account_data_by_type_for_user.invalidate( (row.data_type, row.user_id)) self.get_account_data_for_user.invalidate((row.user_id, )) self.get_account_data_for_room.invalidate( (row.user_id, row.room_id)) self.get_account_data_for_room_and_type.invalidate( (row.user_id, row.room_id, row.data_type)) self._account_data_stream_cache.entity_has_changed( row.user_id, token) return super().process_replication_rows(stream_name, instance_name, token, rows) async def add_account_data_to_room(self, user_id: str, room_id: str, account_data_type: str, content: JsonDict) -> int: """Add some account_data to a room for a user. Args: user_id: The user to add a tag for. room_id: The room to add a tag for. account_data_type: The type of account_data to add. content: A json object to associate with the tag. Returns: The maximum stream ID. """ assert self._can_write_to_account_data content_json = json_encoder.encode(content) async with self._account_data_id_gen.get_next() as next_id: # no need to lock here as room_account_data has a unique constraint # on (user_id, room_id, account_data_type) so simple_upsert will # retry if there is a conflict. 
await self.db_pool.simple_upsert( desc="add_room_account_data", table="room_account_data", keyvalues={ "user_id": user_id, "room_id": room_id, "account_data_type": account_data_type, }, values={ "stream_id": next_id, "content": content_json }, lock=False, ) self._account_data_stream_cache.entity_has_changed( user_id, next_id) self.get_account_data_for_user.invalidate((user_id, )) self.get_account_data_for_room.invalidate((user_id, room_id)) self.get_account_data_for_room_and_type.prefill( (user_id, room_id, account_data_type), content) return self._account_data_id_gen.get_current_token() async def add_account_data_for_user(self, user_id: str, account_data_type: str, content: JsonDict) -> int: """Add some account_data to a room for a user. Args: user_id: The user to add a tag for. account_data_type: The type of account_data to add. content: A json object to associate with the tag. Returns: The maximum stream ID. """ assert self._can_write_to_account_data async with self._account_data_id_gen.get_next() as next_id: await self.db_pool.runInteraction( "add_user_account_data", self._add_account_data_for_user, next_id, user_id, account_data_type, content, ) self._account_data_stream_cache.entity_has_changed( user_id, next_id) self.get_account_data_for_user.invalidate((user_id, )) self.get_global_account_data_by_type_for_user.invalidate( (account_data_type, user_id)) return self._account_data_id_gen.get_current_token() def _add_account_data_for_user( self, txn, next_id: int, user_id: str, account_data_type: str, content: JsonDict, ) -> None: content_json = json_encoder.encode(content) # no need to lock here as account_data has a unique constraint on # (user_id, account_data_type) so simple_upsert will retry if # there is a conflict. self.db_pool.simple_upsert_txn( txn, table="account_data", keyvalues={ "user_id": user_id, "account_data_type": account_data_type }, values={ "stream_id": next_id, "content": content_json }, lock=False, ) # Ignored users get denormalized into a separate table as an optimisation. if account_data_type != AccountDataTypes.IGNORED_USER_LIST: return # Insert / delete to sync the list of ignored users. previously_ignored_users = set( self.db_pool.simple_select_onecol_txn( txn, table="ignored_users", keyvalues={"ignorer_user_id": user_id}, retcol="ignored_user_id", )) # If the data is invalid, no one is ignored. ignored_users_content = content.get("ignored_users", {}) if isinstance(ignored_users_content, dict): currently_ignored_users = set(ignored_users_content) else: currently_ignored_users = set() # Delete entries which are no longer ignored. self.db_pool.simple_delete_many_txn( txn, table="ignored_users", column="ignored_user_id", iterable=previously_ignored_users - currently_ignored_users, keyvalues={"ignorer_user_id": user_id}, ) # Add entries which are newly ignored. self.db_pool.simple_insert_many_txn( txn, table="ignored_users", values=[{ "ignorer_user_id": user_id, "ignored_user_id": u } for u in currently_ignored_users - previously_ignored_users], ) # Invalidate the cache for any ignored users which were added or removed. for ignored_user_id in previously_ignored_users ^ currently_ignored_users: self._invalidate_cache_and_stream(txn, self.ignored_by, (ignored_user_id, ))
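# A self-contained sketch of the set arithmetic used by
# _add_account_data_for_user above to keep the denormalised ignored_users
# table in sync (hypothetical helper, not part of the store):
def _diff_ignored_users(previously_ignored: set, currently_ignored: set):
    to_delete = previously_ignored - currently_ignored  # no longer ignored
    to_insert = currently_ignored - previously_ignored  # newly ignored
    # Symmetric difference: every user whose ignored_by cache entry must be
    # invalidated, whether they were added or removed.
    to_invalidate = previously_ignored ^ currently_ignored
    return to_delete, to_insert, to_invalidate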
class ReceiptsWorkerStore(SQLBaseStore): def __init__( self, database: DatabasePool, db_conn: LoggingDatabaseConnection, hs: "HomeServer", ): self._instance_name = hs.get_instance_name() self._receipts_id_gen: AbstractStreamIdTracker if isinstance(database.engine, PostgresEngine): self._can_write_to_receipts = ( self._instance_name in hs.config.worker.writers.receipts ) self._receipts_id_gen = MultiWriterIdGenerator( db_conn=db_conn, db=database, stream_name="receipts", instance_name=self._instance_name, tables=[("receipts_linearized", "instance_name", "stream_id")], sequence_name="receipts_sequence", writers=hs.config.worker.writers.receipts, ) else: self._can_write_to_receipts = True # We shouldn't be running in worker mode with SQLite, but its useful # to support it for unit tests. # # If this process is the writer than we need to use # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets # updated over replication. (Multiple writers are not supported for # SQLite). if hs.get_instance_name() in hs.config.worker.writers.receipts: self._receipts_id_gen = StreamIdGenerator( db_conn, "receipts_linearized", "stream_id" ) else: self._receipts_id_gen = SlavedIdTracker( db_conn, "receipts_linearized", "stream_id" ) super().__init__(database, db_conn, hs) max_receipts_stream_id = self.get_max_receipt_stream_id() receipts_stream_prefill, min_receipts_stream_id = self.db_pool.get_cache_dict( db_conn, "receipts_linearized", entity_column="room_id", stream_column="stream_id", max_value=max_receipts_stream_id, limit=10000, ) self._receipts_stream_cache = StreamChangeCache( "ReceiptsRoomChangeCache", min_receipts_stream_id, prefilled_cache=receipts_stream_prefill, ) def get_max_receipt_stream_id(self) -> int: """Get the current max stream ID for receipts stream""" return self._receipts_id_gen.get_current_token() async def get_last_receipt_event_id_for_user( self, user_id: str, room_id: str, receipt_types: Collection[str] ) -> Optional[str]: """ Fetch the event ID for the latest receipt in a room with one of the given receipt types. Args: user_id: The user to fetch receipts for. room_id: The room ID to fetch the receipt for. receipt_type: The receipt types to fetch. Returns: The latest receipt, if one exists. """ result = await self.db_pool.runInteraction( "get_last_receipt_event_id_for_user", self.get_last_receipt_for_user_txn, user_id, room_id, receipt_types, ) if not result: return None event_id, _ = result return event_id def get_last_receipt_for_user_txn( self, txn: LoggingTransaction, user_id: str, room_id: str, receipt_types: Collection[str], ) -> Optional[Tuple[str, int]]: """ Fetch the event ID and stream_ordering for the latest receipt in a room with one of the given receipt types. Args: user_id: The user to fetch receipts for. room_id: The room ID to fetch the receipt for. receipt_type: The receipt types to fetch. Returns: The latest receipt, if one exists. """ clause, args = make_in_list_sql_clause( self.database_engine, "receipt_type", receipt_types ) sql = f""" SELECT event_id, stream_ordering FROM receipts_linearized INNER JOIN events USING (room_id, event_id) WHERE {clause} AND user_id = ? AND room_id = ? ORDER BY stream_ordering DESC LIMIT 1 """ args.extend((user_id, room_id)) txn.execute(sql, args) return cast(Optional[Tuple[str, int]], txn.fetchone()) async def get_receipts_for_user( self, user_id: str, receipt_types: Iterable[str] ) -> Dict[str, str]: """ Fetch the event IDs for the latest receipts sent by the given user. Args: user_id: The user to fetch receipts for. 
receipt_types: The receipt types to check. Returns: A map of room ID to the event ID of the latest receipt for that room. If the user has not sent a receipt to a room then it will not appear in the returned dictionary. """ results = await self.get_receipts_for_user_with_orderings( user_id, receipt_types ) # Reduce the result to room ID -> event ID. return { room_id: room_result["event_id"] for room_id, room_result in results.items() } async def get_receipts_for_user_with_orderings( self, user_id: str, receipt_types: Iterable[str] ) -> JsonDict: """ Fetch receipts for all rooms that the given user is joined to. Args: user_id: The user to fetch receipts for. receipt_types: The receipt types to fetch. Earlier receipt types are given priority if multiple receipts point to the same event. Returns: A map of room ID to the latest receipt (for the given types). """ results: JsonDict = {} for receipt_type in receipt_types: partial_result = await self._get_receipts_for_user_with_orderings( user_id, receipt_type ) for room_id, room_result in partial_result.items(): # If the room has not yet been seen, or the receipt is newer, # use it. if ( room_id not in results or results[room_id]["stream_ordering"] < room_result["stream_ordering"] ): results[room_id] = room_result return results @cached() async def _get_receipts_for_user_with_orderings( self, user_id: str, receipt_type: str ) -> JsonDict: """ Fetch receipts for all rooms that the given user is joined to. Args: user_id: The user to fetch receipts for. receipt_type: The receipt type to fetch. Returns: A map of room ID to the latest receipt information. """ def f(txn: LoggingTransaction) -> List[Tuple[str, str, int, int]]: sql = ( "SELECT rl.room_id, rl.event_id," " e.topological_ordering, e.stream_ordering" " FROM receipts_linearized AS rl" " INNER JOIN events AS e USING (room_id, event_id)" " WHERE rl.room_id = e.room_id" " AND rl.event_id = e.event_id" " AND user_id = ?" " AND receipt_type = ?" ) txn.execute(sql, (user_id, receipt_type)) return cast(List[Tuple[str, str, int, int]], txn.fetchall()) rows = await self.db_pool.runInteraction( "get_receipts_for_user_with_orderings", f ) return { row[0]: { "event_id": row[1], "topological_ordering": row[2], "stream_ordering": row[3], } for row in rows } async def get_linearized_receipts_for_rooms( self, room_ids: Iterable[str], to_key: int, from_key: Optional[int] = None ) -> List[dict]: """Get receipts for multiple rooms for sending to clients. Args: room_id: The room IDs to fetch receipts of. to_key: Max stream id to fetch receipts up to. from_key: Min stream id to fetch receipts from. None fetches from the start. Returns: A list of receipts. """ room_ids = set(room_ids) if from_key is not None: # Only ask the database about rooms where there have been new # receipts added since `from_key` room_ids = self._receipts_stream_cache.get_entities_changed( room_ids, from_key ) results = await self._get_linearized_receipts_for_rooms( room_ids, to_key, from_key=from_key ) return [ev for res in results.values() for ev in res] async def get_linearized_receipts_for_room( self, room_id: str, to_key: int, from_key: Optional[int] = None ) -> List[dict]: """Get receipts for a single room for sending to clients. Args: room_ids: The room id. to_key: Max stream id to fetch receipts up to. from_key: Min stream id to fetch receipts from. None fetches from the start. Returns: A list of receipts. """ if from_key is not None: # Check the cache first to see if any new receipts have been added # since`from_key`. 
If not we can no-op. if not self._receipts_stream_cache.has_entity_changed(room_id, from_key): return [] return await self._get_linearized_receipts_for_room(room_id, to_key, from_key) @cached(tree=True) async def _get_linearized_receipts_for_room( self, room_id: str, to_key: int, from_key: Optional[int] = None ) -> List[JsonDict]: """See get_linearized_receipts_for_room""" def f(txn: LoggingTransaction) -> List[Dict[str, Any]]: if from_key: sql = ( "SELECT * FROM receipts_linearized WHERE" " room_id = ? AND stream_id > ? AND stream_id <= ?" ) txn.execute(sql, (room_id, from_key, to_key)) else: sql = ( "SELECT * FROM receipts_linearized WHERE" " room_id = ? AND stream_id <= ?" ) txn.execute(sql, (room_id, to_key)) rows = self.db_pool.cursor_to_dict(txn) return rows rows = await self.db_pool.runInteraction("get_linearized_receipts_for_room", f) if not rows: return [] content: JsonDict = {} for row in rows: content.setdefault(row["event_id"], {}).setdefault(row["receipt_type"], {})[ row["user_id"] ] = db_to_json(row["data"]) return [{"type": EduTypes.RECEIPT, "room_id": room_id, "content": content}] @cachedList( cached_method_name="_get_linearized_receipts_for_room", list_name="room_ids", num_args=3, ) async def _get_linearized_receipts_for_rooms( self, room_ids: Collection[str], to_key: int, from_key: Optional[int] = None ) -> Dict[str, List[JsonDict]]: if not room_ids: return {} def f(txn: LoggingTransaction) -> List[Dict[str, Any]]: if from_key: sql = """ SELECT * FROM receipts_linearized WHERE stream_id > ? AND stream_id <= ? AND """ clause, args = make_in_list_sql_clause( self.database_engine, "room_id", room_ids ) txn.execute(sql + clause, [from_key, to_key] + list(args)) else: sql = """ SELECT * FROM receipts_linearized WHERE stream_id <= ? AND """ clause, args = make_in_list_sql_clause( self.database_engine, "room_id", room_ids ) txn.execute(sql + clause, [to_key] + list(args)) return self.db_pool.cursor_to_dict(txn) txn_results = await self.db_pool.runInteraction( "_get_linearized_receipts_for_rooms", f ) results: JsonDict = {} for row in txn_results: # We want a single event per room, since we want to batch the # receipts by room, event and type. room_event = results.setdefault( row["room_id"], {"type": EduTypes.RECEIPT, "room_id": row["room_id"], "content": {}}, ) # The content is of the form: # {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. } event_entry = room_event["content"].setdefault(row["event_id"], {}) receipt_type = event_entry.setdefault(row["receipt_type"], {}) receipt_type[row["user_id"]] = db_to_json(row["data"]) results = { room_id: [results[room_id]] if room_id in results else [] for room_id in room_ids } return results @cached( num_args=2, ) async def get_linearized_receipts_for_all_rooms( self, to_key: int, from_key: Optional[int] = None ) -> Dict[str, JsonDict]: """Get receipts for all rooms between two stream_ids, up to a limit of the latest 100 read receipts. Args: to_key: Max stream id to fetch receipts up to. from_key: Min stream id to fetch receipts from. None fetches from the start. Returns: A dictionary of roomids to a list of receipts. """ def f(txn: LoggingTransaction) -> List[Dict[str, Any]]: if from_key: sql = """ SELECT * FROM receipts_linearized WHERE stream_id > ? AND stream_id <= ? ORDER BY stream_id DESC LIMIT 100 """ txn.execute(sql, [from_key, to_key]) else: sql = """ SELECT * FROM receipts_linearized WHERE stream_id <= ? 
ORDER BY stream_id DESC LIMIT 100 """ txn.execute(sql, [to_key]) return self.db_pool.cursor_to_dict(txn) txn_results = await self.db_pool.runInteraction( "get_linearized_receipts_for_all_rooms", f ) results: JsonDict = {} for row in txn_results: # We want a single event per room, since we want to batch the # receipts by room, event and type. room_event = results.setdefault( row["room_id"], {"type": EduTypes.RECEIPT, "room_id": row["room_id"], "content": {}}, ) # The content is of the form: # {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. } event_entry = room_event["content"].setdefault(row["event_id"], {}) receipt_type = event_entry.setdefault(row["receipt_type"], {}) receipt_type[row["user_id"]] = db_to_json(row["data"]) return results async def get_users_sent_receipts_between( self, last_id: int, current_id: int ) -> List[str]: """Get all users who sent receipts between `last_id` exclusive and `current_id` inclusive. Returns: The list of users. """ if last_id == current_id: return [] def _get_users_sent_receipts_between_txn(txn: LoggingTransaction) -> List[str]: sql = """ SELECT DISTINCT user_id FROM receipts_linearized WHERE ? < stream_id AND stream_id <= ? """ txn.execute(sql, (last_id, current_id)) return [r[0] for r in txn] return await self.db_pool.runInteraction( "get_users_sent_receipts_between", _get_users_sent_receipts_between_txn ) async def get_all_updated_receipts( self, instance_name: str, last_id: int, current_id: int, limit: int ) -> Tuple[List[Tuple[int, list]], int, bool]: """Get updates for receipts replication stream. Args: instance_name: The writer we want to fetch updates from. Unused here since there is only ever one writer. last_id: The token to fetch updates from. Exclusive. current_id: The token to fetch updates up to. Inclusive. limit: The requested limit for the number of rows to return. The function may return more or fewer rows. Returns: A tuple consisting of: the updates, a token to use to fetch subsequent updates, and whether we returned fewer rows than exists between the requested tokens due to the limit. The token returned can be used in a subsequent call to this function to get further updatees. The updates are a list of 2-tuples of stream ID and the row data """ if last_id == current_id: return [], current_id, False def get_all_updated_receipts_txn( txn: LoggingTransaction, ) -> Tuple[List[Tuple[int, list]], int, bool]: sql = """ SELECT stream_id, room_id, receipt_type, user_id, event_id, data FROM receipts_linearized WHERE ? < stream_id AND stream_id <= ? ORDER BY stream_id ASC LIMIT ? """ txn.execute(sql, (last_id, current_id, limit)) updates = cast( List[Tuple[int, list]], [(r[0], r[1:5] + (db_to_json(r[5]),)) for r in txn], ) limited = False upper_bound = current_id if len(updates) == limit: limited = True upper_bound = updates[-1][0] return updates, upper_bound, limited return await self.db_pool.runInteraction( "get_all_updated_receipts", get_all_updated_receipts_txn ) def invalidate_caches_for_receipt( self, room_id: str, receipt_type: str, user_id: str ) -> None: self._get_receipts_for_user_with_orderings.invalidate((user_id, receipt_type)) self._get_linearized_receipts_for_room.invalidate((room_id,)) # We use this method to invalidate so that we don't end up with circular # dependencies between the receipts and push action stores. 
self._attempt_to_invalidate_cache( "get_unread_event_push_actions_by_room_for_user", (room_id,) ) def process_replication_rows( self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any], ) -> None: if stream_name == ReceiptsStream.NAME: self._receipts_id_gen.advance(instance_name, token) for row in rows: self.invalidate_caches_for_receipt( row.room_id, row.receipt_type, row.user_id ) self._receipts_stream_cache.entity_has_changed(row.room_id, token) return super().process_replication_rows(stream_name, instance_name, token, rows) def _insert_linearized_receipt_txn( self, txn: LoggingTransaction, room_id: str, receipt_type: str, user_id: str, event_id: str, data: JsonDict, stream_id: int, ) -> Optional[int]: """Inserts a receipt into the database if it's newer than the current one. Returns: None if the receipt is older than the current receipt otherwise, the rx timestamp of the event that the receipt corresponds to (or 0 if the event is unknown) """ assert self._can_write_to_receipts res = self.db_pool.simple_select_one_txn( txn, table="events", retcols=["stream_ordering", "received_ts"], keyvalues={"event_id": event_id}, allow_none=True, ) stream_ordering = int(res["stream_ordering"]) if res else None rx_ts = res["received_ts"] if res else 0 # We don't want to clobber receipts for more recent events, so we # have to compare orderings of existing receipts if stream_ordering is not None: sql = ( "SELECT stream_ordering, event_id FROM events" " INNER JOIN receipts_linearized AS r USING (event_id, room_id)" " WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ?" ) txn.execute(sql, (room_id, receipt_type, user_id)) for so, eid in txn: if int(so) >= stream_ordering: logger.debug( "Ignoring new receipt for %s in favour of existing " "one for later event %s", event_id, eid, ) return None txn.call_after( self.invalidate_caches_for_receipt, room_id, receipt_type, user_id ) txn.call_after( self._receipts_stream_cache.entity_has_changed, room_id, stream_id ) self.db_pool.simple_upsert_txn( txn, table="receipts_linearized", keyvalues={ "room_id": room_id, "receipt_type": receipt_type, "user_id": user_id, }, values={ "stream_id": stream_id, "event_id": event_id, "data": json_encoder.encode(data), }, # receipts_linearized has a unique constraint on # (user_id, room_id, receipt_type), so no need to lock lock=False, ) return rx_ts def _graph_to_linear( self, txn: LoggingTransaction, room_id: str, event_ids: List[str] ) -> str: """ Generate a linearized event from a list of events (i.e. a list of forward extremities in the room). This should allow for calculation of the correct read receipt even if servers have different event ordering. Args: txn: The transaction room_id: The room ID the events are in. event_ids: The list of event IDs to linearize. Returns: The linearized event ID. """ # TODO: Make this better. clause, args = make_in_list_sql_clause( self.database_engine, "event_id", event_ids ) sql = """ SELECT event_id WHERE room_id = ? AND stream_ordering IN ( SELECT max(stream_ordering) WHERE %s ) """ % ( clause, ) txn.execute(sql, [room_id] + list(args)) rows = txn.fetchall() if rows: return rows[0][0] else: raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,)) async def insert_receipt( self, room_id: str, receipt_type: str, user_id: str, event_ids: List[str], data: dict, ) -> Optional[Tuple[int, int]]: """Insert a receipt, either from local client or remote server. Automatically does conversion between linearized and graph representations. 
Returns:
            The new receipts stream ID and token, if the receipt is newer than
            what was previously persisted. None, otherwise.
        """
        assert self._can_write_to_receipts

        if not event_ids:
            return None

        if len(event_ids) == 1:
            linearized_event_id = event_ids[0]
        else:
            # we need to convert the points in the graph to linearized form.
            linearized_event_id = await self.db_pool.runInteraction(
                "insert_receipt_conv", self._graph_to_linear, room_id, event_ids
            )

        async with self._receipts_id_gen.get_next() as stream_id:  # type: ignore[attr-defined]
            event_ts = await self.db_pool.runInteraction(
                "insert_linearized_receipt",
                self._insert_linearized_receipt_txn,
                room_id,
                receipt_type,
                user_id,
                linearized_event_id,
                data,
                stream_id=stream_id,
                # Read committed is actually beneficial here because we check for a receipt with
                # greater stream order, and checking the very latest data at select time is better
                # than the data at transaction start time.
                isolation_level=IsolationLevel.READ_COMMITTED,
            )

        # If the receipt was older than the currently persisted one, nothing to do.
        if event_ts is None:
            return None

        now = self._clock.time_msec()
        logger.debug(
            "RR for event %s in %s (%i ms old)",
            linearized_event_id,
            room_id,
            now - event_ts,
        )

        await self.db_pool.runInteraction(
            "insert_graph_receipt",
            self._insert_graph_receipt_txn,
            room_id,
            receipt_type,
            user_id,
            event_ids,
            data,
        )

        max_persisted_id = self._receipts_id_gen.get_current_token()

        return stream_id, max_persisted_id

    def _insert_graph_receipt_txn(
        self,
        txn: LoggingTransaction,
        room_id: str,
        receipt_type: str,
        user_id: str,
        event_ids: List[str],
        data: JsonDict,
    ) -> None:
        assert self._can_write_to_receipts

        txn.call_after(
            self._get_receipts_for_user_with_orderings.invalidate,
            (user_id, receipt_type),
        )
        # FIXME: This shouldn't invalidate the whole cache
        txn.call_after(self._get_linearized_receipts_for_room.invalidate, (room_id,))

        self.db_pool.simple_delete_txn(
            txn,
            table="receipts_graph",
            keyvalues={
                "room_id": room_id,
                "receipt_type": receipt_type,
                "user_id": user_id,
            },
        )
        self.db_pool.simple_insert_txn(
            txn,
            table="receipts_graph",
            values={
                "room_id": room_id,
                "receipt_type": receipt_type,
                "user_id": user_id,
                "event_ids": json_encoder.encode(event_ids),
                "data": json_encoder.encode(data),
            },
        )
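# Sketch of the "don't clobber newer receipts" rule that
# _insert_linearized_receipt_txn above enforces: a new receipt is persisted
# only if its event's stream ordering is strictly greater than that of every
# event the user's existing receipt (of the same type, in the same room)
# already points at. Illustrative helper only, not part of the store:
def _should_persist_receipt(new_stream_ordering: int, existing_stream_orderings) -> bool:
    return all(so < new_stream_ordering for so in existing_stream_orderings)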
class ReceiptsWorkerStore(SQLBaseStore): """This is an abstract base class where subclasses must implement `get_max_receipt_stream_id` which can be called in the initializer. """ # This ABCMeta metaclass ensures that we cannot be instantiated without # the abstract methods being implemented. __metaclass__ = abc.ABCMeta def __init__(self, db_conn, hs): super(ReceiptsWorkerStore, self).__init__(db_conn, hs) self._receipts_stream_cache = StreamChangeCache( "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id() ) @abc.abstractmethod def get_max_receipt_stream_id(self): """Get the current max stream ID for receipts stream Returns: int """ raise NotImplementedError() @cachedInlineCallbacks() def get_users_with_read_receipts_in_room(self, room_id): receipts = yield self.get_receipts_for_room(room_id, "m.read") return set(r["user_id"] for r in receipts) @cached(num_args=2) def get_receipts_for_room(self, room_id, receipt_type): return self._simple_select_list( table="receipts_linearized", keyvalues={"room_id": room_id, "receipt_type": receipt_type}, retcols=("user_id", "event_id"), desc="get_receipts_for_room", ) @cached(num_args=3) def get_last_receipt_event_id_for_user(self, user_id, room_id, receipt_type): return self._simple_select_one_onecol( table="receipts_linearized", keyvalues={ "room_id": room_id, "receipt_type": receipt_type, "user_id": user_id, }, retcol="event_id", desc="get_own_receipt_for_user", allow_none=True, ) @cachedInlineCallbacks(num_args=2) def get_receipts_for_user(self, user_id, receipt_type): rows = yield self._simple_select_list( table="receipts_linearized", keyvalues={"user_id": user_id, "receipt_type": receipt_type}, retcols=("room_id", "event_id"), desc="get_receipts_for_user", ) return {row["room_id"]: row["event_id"] for row in rows} @defer.inlineCallbacks def get_receipts_for_user_with_orderings(self, user_id, receipt_type): def f(txn): sql = ( "SELECT rl.room_id, rl.event_id," " e.topological_ordering, e.stream_ordering" " FROM receipts_linearized AS rl" " INNER JOIN events AS e USING (room_id, event_id)" " WHERE rl.room_id = e.room_id" " AND rl.event_id = e.event_id" " AND user_id = ?" ) txn.execute(sql, (user_id,)) return txn.fetchall() rows = yield self.runInteraction("get_receipts_for_user_with_orderings", f) return { row[0]: { "event_id": row[1], "topological_ordering": row[2], "stream_ordering": row[3], } for row in rows } @defer.inlineCallbacks def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None): """Get receipts for multiple rooms for sending to clients. Args: room_ids (list): List of room_ids. to_key (int): Max stream id to fetch receipts upto. from_key (int): Min stream id to fetch receipts from. None fetches from the start. Returns: list: A list of receipts. """ room_ids = set(room_ids) if from_key is not None: # Only ask the database about rooms where there have been new # receipts added since `from_key` room_ids = yield self._receipts_stream_cache.get_entities_changed( room_ids, from_key ) results = yield self._get_linearized_receipts_for_rooms( room_ids, to_key, from_key=from_key ) return [ev for res in results.values() for ev in res] def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None): """Get receipts for a single room for sending to clients. Args: room_ids (str): The room id. to_key (int): Max stream id to fetch receipts upto. from_key (int): Min stream id to fetch receipts from. None fetches from the start. Returns: Deferred[list]: A list of receipts. 
""" if from_key is not None: # Check the cache first to see if any new receipts have been added # since`from_key`. If not we can no-op. if not self._receipts_stream_cache.has_entity_changed(room_id, from_key): defer.succeed([]) return self._get_linearized_receipts_for_room(room_id, to_key, from_key) @cachedInlineCallbacks(num_args=3, tree=True) def _get_linearized_receipts_for_room(self, room_id, to_key, from_key=None): """See get_linearized_receipts_for_room """ def f(txn): if from_key: sql = ( "SELECT * FROM receipts_linearized WHERE" " room_id = ? AND stream_id > ? AND stream_id <= ?" ) txn.execute(sql, (room_id, from_key, to_key)) else: sql = ( "SELECT * FROM receipts_linearized WHERE" " room_id = ? AND stream_id <= ?" ) txn.execute(sql, (room_id, to_key)) rows = self.cursor_to_dict(txn) return rows rows = yield self.runInteraction("get_linearized_receipts_for_room", f) if not rows: return [] content = {} for row in rows: content.setdefault(row["event_id"], {}).setdefault(row["receipt_type"], {})[ row["user_id"] ] = json.loads(row["data"]) return [{"type": "m.receipt", "room_id": room_id, "content": content}] @cachedList( cached_method_name="_get_linearized_receipts_for_room", list_name="room_ids", num_args=3, inlineCallbacks=True, ) def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None): if not room_ids: return {} def f(txn): if from_key: sql = ( "SELECT * FROM receipts_linearized WHERE" " room_id IN (%s) AND stream_id > ? AND stream_id <= ?" ) % (",".join(["?"] * len(room_ids))) args = list(room_ids) args.extend([from_key, to_key]) txn.execute(sql, args) else: sql = ( "SELECT * FROM receipts_linearized WHERE" " room_id IN (%s) AND stream_id <= ?" ) % (",".join(["?"] * len(room_ids))) args = list(room_ids) args.append(to_key) txn.execute(sql, args) return self.cursor_to_dict(txn) txn_results = yield self.runInteraction("_get_linearized_receipts_for_rooms", f) results = {} for row in txn_results: # We want a single event per room, since we want to batch the # receipts by room, event and type. room_event = results.setdefault( row["room_id"], {"type": "m.receipt", "room_id": row["room_id"], "content": {}}, ) # The content is of the form: # {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. } event_entry = room_event["content"].setdefault(row["event_id"], {}) receipt_type = event_entry.setdefault(row["receipt_type"], {}) receipt_type[row["user_id"]] = json.loads(row["data"]) results = { room_id: [results[room_id]] if room_id in results else [] for room_id in room_ids } return results def get_all_updated_receipts(self, last_id, current_id, limit=None): if last_id == current_id: return defer.succeed([]) def get_all_updated_receipts_txn(txn): sql = ( "SELECT stream_id, room_id, receipt_type, user_id, event_id, data" " FROM receipts_linearized" " WHERE ? < stream_id AND stream_id <= ?" " ORDER BY stream_id ASC" ) args = [last_id, current_id] if limit is not None: sql += " LIMIT ?" 
                args.append(limit)
            txn.execute(sql, args)
            # Materialise the rows before the transaction ends; returning a
            # generator here would read from a closed cursor.
            return [r[0:5] + (json.loads(r[5]),) for r in txn]

        return self.runInteraction(
            "get_all_updated_receipts", get_all_updated_receipts_txn
        )

    def _invalidate_get_users_with_receipts_in_room(
        self, room_id, receipt_type, user_id
    ):
        if receipt_type != "m.read":
            return

        # Returns either an ObservableDeferred or the raw result
        res = self.get_users_with_read_receipts_in_room.cache.get(
            room_id, None, update_metrics=False
        )

        # first handle the Deferred case
        if isinstance(res, defer.Deferred):
            if res.called:
                res = res.result
            else:
                res = None

        if res and user_id in res:
            # We'd only be adding to the set, so no point invalidating if the
            # user is already there
            return

        self.get_users_with_read_receipts_in_room.invalidate((room_id,))
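# The invalidation helper above peeks at the cache without updating metrics
# and may get back either a plain value or an unresolved Deferred. A hedged
# sketch of that unwrapping step (hypothetical helper, not part of the store):
def _unwrap_cached(res):
    # Return the cached value if it is already available, else None.
    if isinstance(res, defer.Deferred):
        return res.result if res.called else None
    return res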
class AccountDataWorkerStore(SQLBaseStore): """This is an abstract base class where subclasses must implement `get_max_account_data_stream_id` which can be called in the initializer. """ # This ABCMeta metaclass ensures that we cannot be instantiated without # the abstract methods being implemented. __metaclass__ = abc.ABCMeta def __init__(self, db_conn, hs): account_max = self.get_max_account_data_stream_id() self._account_data_stream_cache = StreamChangeCache( "AccountDataAndTagsChangeCache", account_max, ) super(AccountDataWorkerStore, self).__init__(db_conn, hs) @abc.abstractmethod def get_max_account_data_stream_id(self): """Get the current max stream ID for account data stream Returns: int """ raise NotImplementedError() @cached() def get_account_data_for_user(self, user_id): """Get all the client account_data for a user. Args: user_id(str): The user to get the account_data for. Returns: A deferred pair of a dict of global account_data and a dict mapping from room_id string to per room account_data dicts. """ def get_account_data_for_user_txn(txn): rows = self._simple_select_list_txn( txn, "account_data", {"user_id": user_id}, ["account_data_type", "content"] ) global_account_data = { row["account_data_type"]: json.loads(row["content"]) for row in rows } rows = self._simple_select_list_txn( txn, "room_account_data", {"user_id": user_id}, ["room_id", "account_data_type", "content"] ) by_room = {} for row in rows: room_data = by_room.setdefault(row["room_id"], {}) room_data[row["account_data_type"]] = json.loads(row["content"]) return (global_account_data, by_room) return self.runInteraction( "get_account_data_for_user", get_account_data_for_user_txn ) @cachedInlineCallbacks(num_args=2, max_entries=5000) def get_global_account_data_by_type_for_user(self, data_type, user_id): """ Returns: Deferred: A dict """ result = yield self._simple_select_one_onecol( table="account_data", keyvalues={ "user_id": user_id, "account_data_type": data_type, }, retcol="content", desc="get_global_account_data_by_type_for_user", allow_none=True, ) if result: defer.returnValue(json.loads(result)) else: defer.returnValue(None) @cached(num_args=2) def get_account_data_for_room(self, user_id, room_id): """Get all the client account_data for a user for a room. Args: user_id(str): The user to get the account_data for. room_id(str): The room to get the account_data for. Returns: A deferred dict of the room account_data """ def get_account_data_for_room_txn(txn): rows = self._simple_select_list_txn( txn, "room_account_data", {"user_id": user_id, "room_id": room_id}, ["account_data_type", "content"] ) return { row["account_data_type"]: json.loads(row["content"]) for row in rows } return self.runInteraction( "get_account_data_for_room", get_account_data_for_room_txn ) @cached(num_args=3, max_entries=5000) def get_account_data_for_room_and_type(self, user_id, room_id, account_data_type): """Get the client account_data of given type for a user for a room. Args: user_id(str): The user to get the account_data for. room_id(str): The room to get the account_data for. account_data_type (str): The account data type to get. Returns: A deferred of the room account_data for that type, or None if there isn't any set. 
""" def get_account_data_for_room_and_type_txn(txn): content_json = self._simple_select_one_onecol_txn( txn, table="room_account_data", keyvalues={ "user_id": user_id, "room_id": room_id, "account_data_type": account_data_type, }, retcol="content", allow_none=True ) return json.loads(content_json) if content_json else None return self.runInteraction( "get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn, ) def get_all_updated_account_data(self, last_global_id, last_room_id, current_id, limit): """Get all the client account_data that has changed on the server Args: last_global_id(int): The position to fetch from for top level data last_room_id(int): The position to fetch from for per room data current_id(int): The position to fetch up to. Returns: A deferred pair of lists of tuples of stream_id int, user_id string, room_id string, type string, and content string. """ if last_room_id == current_id and last_global_id == current_id: return defer.succeed(([], [])) def get_updated_account_data_txn(txn): sql = ( "SELECT stream_id, user_id, account_data_type, content" " FROM account_data WHERE ? < stream_id AND stream_id <= ?" " ORDER BY stream_id ASC LIMIT ?" ) txn.execute(sql, (last_global_id, current_id, limit)) global_results = txn.fetchall() sql = ( "SELECT stream_id, user_id, room_id, account_data_type, content" " FROM room_account_data WHERE ? < stream_id AND stream_id <= ?" " ORDER BY stream_id ASC LIMIT ?" ) txn.execute(sql, (last_room_id, current_id, limit)) room_results = txn.fetchall() return (global_results, room_results) return self.runInteraction( "get_all_updated_account_data_txn", get_updated_account_data_txn ) def get_updated_account_data_for_user(self, user_id, stream_id): """Get all the client account_data for a that's changed for a user Args: user_id(str): The user to get the account_data for. stream_id(int): The point in the stream since which to get updates Returns: A deferred pair of a dict of global account_data and a dict mapping from room_id string to per room account_data dicts. """ def get_updated_account_data_for_user_txn(txn): sql = ( "SELECT account_data_type, content FROM account_data" " WHERE user_id = ? AND stream_id > ?" ) txn.execute(sql, (user_id, stream_id)) global_account_data = { row[0]: json.loads(row[1]) for row in txn } sql = ( "SELECT room_id, account_data_type, content FROM room_account_data" " WHERE user_id = ? AND stream_id > ?" ) txn.execute(sql, (user_id, stream_id)) account_data_by_room = {} for row in txn: room_account_data = account_data_by_room.setdefault(row[0], {}) room_account_data[row[1]] = json.loads(row[2]) return (global_account_data, account_data_by_room) changed = self._account_data_stream_cache.has_entity_changed( user_id, int(stream_id) ) if not changed: return ({}, {}) return self.runInteraction( "get_updated_account_data_for_user", get_updated_account_data_for_user_txn ) @cachedInlineCallbacks(num_args=2, cache_context=True, max_entries=5000) def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context): ignored_account_data = yield self.get_global_account_data_by_type_for_user( "m.ignored_user_list", ignorer_user_id, on_invalidate=cache_context.invalidate, ) if not ignored_account_data: defer.returnValue(False) defer.returnValue( ignored_user_id in ignored_account_data.get("ignored_users", {}) )
class PushRulesWorkerStore( ApplicationServiceWorkerStore, ReceiptsWorkerStore, PusherWorkerStore, RoomMemberWorkerStore, EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta, ): """This is an abstract base class where subclasses must implement `get_max_push_rules_stream_id` which can be called in the initializer. """ def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) if hs.config.worker.worker_app is None: self._push_rules_stream_id_gen = StreamIdGenerator( db_conn, "push_rules_stream", "stream_id") # type: Union[StreamIdGenerator, SlavedIdTracker] else: self._push_rules_stream_id_gen = SlavedIdTracker( db_conn, "push_rules_stream", "stream_id") push_rules_prefill, push_rules_id = self.db_pool.get_cache_dict( db_conn, "push_rules_stream", entity_column="user_id", stream_column="stream_id", max_value=self.get_max_push_rules_stream_id(), ) self.push_rules_stream_cache = StreamChangeCache( "PushRulesStreamChangeCache", push_rules_id, prefilled_cache=push_rules_prefill, ) self._users_new_default_push_rules = hs.config.users_new_default_push_rules @abc.abstractmethod def get_max_push_rules_stream_id(self): """Get the position of the push rules stream. Returns: int """ raise NotImplementedError() @cached(max_entries=5000) async def get_push_rules_for_user(self, user_id): rows = await self.db_pool.simple_select_list( table="push_rules", keyvalues={"user_name": user_id}, retcols=( "user_name", "rule_id", "priority_class", "priority", "conditions", "actions", ), desc="get_push_rules_enabled_for_user", ) rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"]))) enabled_map = await self.get_push_rules_enabled_for_user(user_id) use_new_defaults = user_id in self._users_new_default_push_rules return _load_rules(rows, enabled_map, use_new_defaults) @cached(max_entries=5000) async def get_push_rules_enabled_for_user(self, user_id): results = await self.db_pool.simple_select_list( table="push_rules_enable", keyvalues={"user_name": user_id}, retcols=("user_name", "rule_id", "enabled"), desc="get_push_rules_enabled_for_user", ) return { r["rule_id"]: False if r["enabled"] == 0 else True for r in results } async def have_push_rules_changed_for_user(self, user_id: str, last_id: int) -> bool: if not self.push_rules_stream_cache.has_entity_changed( user_id, last_id): return False else: def have_push_rules_changed_txn(txn): sql = ("SELECT COUNT(stream_id) FROM push_rules_stream" " WHERE user_id = ? AND ? 
< stream_id") txn.execute(sql, (user_id, last_id)) (count, ) = txn.fetchone() return bool(count) return await self.db_pool.runInteraction( "have_push_rules_changed", have_push_rules_changed_txn) @cachedList( cached_method_name="get_push_rules_for_user", list_name="user_ids", num_args=1, ) async def bulk_get_push_rules(self, user_ids): if not user_ids: return {} results = {user_id: [] for user_id in user_ids} rows = await self.db_pool.simple_select_many_batch( table="push_rules", column="user_name", iterable=user_ids, retcols=("*", ), desc="bulk_get_push_rules", ) rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"]))) for row in rows: results.setdefault(row["user_name"], []).append(row) enabled_map_by_user = await self.bulk_get_push_rules_enabled(user_ids) for user_id, rules in results.items(): use_new_defaults = user_id in self._users_new_default_push_rules results[user_id] = _load_rules( rules, enabled_map_by_user.get(user_id, {}), use_new_defaults, ) return results async def copy_push_rule_from_room_to_room(self, new_room_id: str, user_id: str, rule: dict) -> None: """Copy a single push rule from one room to another for a specific user. Args: new_room_id: ID of the new room. user_id : ID of user the push rule belongs to. rule: A push rule. """ # Create new rule id rule_id_scope = "/".join(rule["rule_id"].split("/")[:-1]) new_rule_id = rule_id_scope + "/" + new_room_id # Change room id in each condition for condition in rule.get("conditions", []): if condition.get("key") == "room_id": condition["pattern"] = new_room_id # Add the rule for the new room await self.add_push_rule( user_id=user_id, rule_id=new_rule_id, priority_class=rule["priority_class"], conditions=rule["conditions"], actions=rule["actions"], ) async def copy_push_rules_from_room_to_room_for_user( self, old_room_id: str, new_room_id: str, user_id: str) -> None: """Copy all of the push rules from one room to another for a specific user. Args: old_room_id: ID of the old room. new_room_id: ID of the new room. user_id: ID of user to copy push rules for. """ # Retrieve push rules for this user user_push_rules = await self.get_push_rules_for_user(user_id) # Get rules relating to the old room and copy them to the new room for rule in user_push_rules: conditions = rule.get("conditions", []) if any( (c.get("key") == "room_id" and c.get("pattern") == old_room_id) for c in conditions): await self.copy_push_rule_from_room_to_room( new_room_id, user_id, rule) @cachedList( cached_method_name="get_push_rules_enabled_for_user", list_name="user_ids", num_args=1, ) async def bulk_get_push_rules_enabled(self, user_ids): if not user_ids: return {} results = {user_id: {} for user_id in user_ids} rows = await self.db_pool.simple_select_many_batch( table="push_rules_enable", column="user_name", iterable=user_ids, retcols=("user_name", "rule_id", "enabled"), desc="bulk_get_push_rules_enabled", ) for row in rows: enabled = bool(row["enabled"]) results.setdefault(row["user_name"], {})[row["rule_id"]] = enabled return results async def get_all_push_rule_updates( self, instance_name: str, last_id: int, current_id: int, limit: int) -> Tuple[List[Tuple[int, tuple]], int, bool]: """Get updates for push_rules replication stream. Args: instance_name: The writer we want to fetch updates from. Unused here since there is only ever one writer. last_id: The token to fetch updates from. Exclusive. current_id: The token to fetch updates up to. Inclusive. limit: The requested limit for the number of rows to return. 
            The function may return more or fewer rows.

        Returns:
            A tuple consisting of: the updates, a token to use to fetch
            subsequent updates, and whether we returned fewer rows than exist
            between the requested tokens due to the limit.

            The token returned can be used in a subsequent call to this
            function to get further updates.

            The updates are a list of 2-tuples of stream ID and the row data
        """
        if last_id == current_id:
            return [], current_id, False

        def get_all_push_rule_updates_txn(txn):
            sql = """
                SELECT stream_id, user_id
                FROM push_rules_stream
                WHERE ? < stream_id AND stream_id <= ?
                ORDER BY stream_id ASC
                LIMIT ?
            """
            txn.execute(sql, (last_id, current_id, limit))
            updates = [(stream_id, (user_id,)) for stream_id, user_id in txn]

            limited = False
            upper_bound = current_id
            if len(updates) == limit:
                limited = True
                upper_bound = updates[-1][0]

            return updates, upper_bound, limited

        return await self.db_pool.runInteraction(
            "get_all_push_rule_updates", get_all_push_rule_updates_txn
        )
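# A hedged, generic sketch of how a caller might page through a replication
# stream using the (updates, upper_bound, limited) contract returned by
# get_all_push_rule_updates above (and the receipts equivalent). The helper
# name and batch size are illustrative:
async def _drain_push_rule_updates(store, instance_name, from_token, to_token, batch=100):
    updates = []
    while True:
        rows, from_token, limited = await store.get_all_push_rule_updates(
            instance_name, from_token, to_token, batch
        )
        updates.extend(rows)
        if not limited:
            # Everything up to to_token has been fetched.
            return updates, from_token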
class ReceiptsWorkerStore(SQLBaseStore): def __init__(self, database: DatabasePool, db_conn, hs): self._instance_name = hs.get_instance_name() if isinstance(database.engine, PostgresEngine): self._can_write_to_receipts = ( self._instance_name in hs.config.worker.writers.receipts) self._receipts_id_gen = MultiWriterIdGenerator( db_conn=db_conn, db=database, stream_name="receipts", instance_name=self._instance_name, tables=[("receipts_linearized", "instance_name", "stream_id")], sequence_name="receipts_sequence", writers=hs.config.worker.writers.receipts, ) else: self._can_write_to_receipts = True # We shouldn't be running in worker mode with SQLite, but its useful # to support it for unit tests. # # If this process is the writer than we need to use # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets # updated over replication. (Multiple writers are not supported for # SQLite). if hs.get_instance_name() in hs.config.worker.writers.receipts: self._receipts_id_gen = StreamIdGenerator( db_conn, "receipts_linearized", "stream_id") else: self._receipts_id_gen = SlavedIdTracker( db_conn, "receipts_linearized", "stream_id") super().__init__(database, db_conn, hs) self._receipts_stream_cache = StreamChangeCache( "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id()) def get_max_receipt_stream_id(self): """Get the current max stream ID for receipts stream Returns: int """ return self._receipts_id_gen.get_current_token() @cached() async def get_users_with_read_receipts_in_room(self, room_id): receipts = await self.get_receipts_for_room(room_id, "m.read") return {r["user_id"] for r in receipts} @cached(num_args=2) async def get_receipts_for_room(self, room_id: str, receipt_type: str) -> List[Dict[str, Any]]: return await self.db_pool.simple_select_list( table="receipts_linearized", keyvalues={ "room_id": room_id, "receipt_type": receipt_type }, retcols=("user_id", "event_id"), desc="get_receipts_for_room", ) @cached(num_args=3) async def get_last_receipt_event_id_for_user( self, user_id: str, room_id: str, receipt_type: str) -> Optional[str]: return await self.db_pool.simple_select_one_onecol( table="receipts_linearized", keyvalues={ "room_id": room_id, "receipt_type": receipt_type, "user_id": user_id, }, retcol="event_id", desc="get_own_receipt_for_user", allow_none=True, ) @cached(num_args=2) async def get_receipts_for_user(self, user_id, receipt_type): rows = await self.db_pool.simple_select_list( table="receipts_linearized", keyvalues={ "user_id": user_id, "receipt_type": receipt_type }, retcols=("room_id", "event_id"), desc="get_receipts_for_user", ) return {row["room_id"]: row["event_id"] for row in rows} async def get_receipts_for_user_with_orderings(self, user_id, receipt_type): def f(txn): sql = ("SELECT rl.room_id, rl.event_id," " e.topological_ordering, e.stream_ordering" " FROM receipts_linearized AS rl" " INNER JOIN events AS e USING (room_id, event_id)" " WHERE rl.room_id = e.room_id" " AND rl.event_id = e.event_id" " AND user_id = ?") txn.execute(sql, (user_id, )) return txn.fetchall() rows = await self.db_pool.runInteraction( "get_receipts_for_user_with_orderings", f) return { row[0]: { "event_id": row[1], "topological_ordering": row[2], "stream_ordering": row[3], } for row in rows } async def get_linearized_receipts_for_rooms( self, room_ids: List[str], to_key: int, from_key: Optional[int] = None) -> List[dict]: """Get receipts for multiple rooms for sending to clients. Args: room_id: List of room_ids. to_key: Max stream id to fetch receipts up to. 
            from_key: Min stream id to fetch receipts from. None fetches
                from the start.

        Returns:
            A list of receipts.
        """
        room_ids = set(room_ids)

        if from_key is not None:
            # Only ask the database about rooms where there have been new
            # receipts added since `from_key`
            room_ids = self._receipts_stream_cache.get_entities_changed(
                room_ids, from_key
            )

        results = await self._get_linearized_receipts_for_rooms(
            room_ids, to_key, from_key=from_key
        )

        return [ev for res in results.values() for ev in res]

    async def get_linearized_receipts_for_room(
        self, room_id: str, to_key: int, from_key: Optional[int] = None
    ) -> List[dict]:
        """Get receipts for a single room for sending to clients.

        Args:
            room_id: The room id.
            to_key: Max stream id to fetch receipts up to.
            from_key: Min stream id to fetch receipts from. None fetches
                from the start.

        Returns:
            A list of receipts.
        """
        if from_key is not None:
            # Check the cache first to see if any new receipts have been added
            # since `from_key`. If not we can no-op.
            if not self._receipts_stream_cache.has_entity_changed(room_id, from_key):
                return []

        return await self._get_linearized_receipts_for_room(room_id, to_key, from_key)

    @cached(num_args=3, tree=True)
    async def _get_linearized_receipts_for_room(
        self, room_id: str, to_key: int, from_key: Optional[int] = None
    ) -> List[dict]:
        """See get_linearized_receipts_for_room"""

        def f(txn):
            if from_key:
                sql = (
                    "SELECT * FROM receipts_linearized WHERE"
                    " room_id = ? AND stream_id > ? AND stream_id <= ?"
                )
                txn.execute(sql, (room_id, from_key, to_key))
            else:
                sql = (
                    "SELECT * FROM receipts_linearized WHERE"
                    " room_id = ? AND stream_id <= ?"
                )
                txn.execute(sql, (room_id, to_key))

            rows = self.db_pool.cursor_to_dict(txn)

            return rows

        rows = await self.db_pool.runInteraction("get_linearized_receipts_for_room", f)

        if not rows:
            return []

        content = {}
        for row in rows:
            content.setdefault(row["event_id"], {}).setdefault(row["receipt_type"], {})[
                row["user_id"]
            ] = db_to_json(row["data"])

        return [{"type": "m.receipt", "room_id": room_id, "content": content}]

    @cachedList(
        cached_method_name="_get_linearized_receipts_for_room",
        list_name="room_ids",
        num_args=3,
    )
    async def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
        if not room_ids:
            return {}

        def f(txn):
            if from_key:
                sql = """
                    SELECT * FROM receipts_linearized WHERE
                    stream_id > ? AND stream_id <= ? AND
                """
                clause, args = make_in_list_sql_clause(
                    self.database_engine, "room_id", room_ids
                )
                txn.execute(sql + clause, [from_key, to_key] + list(args))
            else:
                sql = """
                    SELECT * FROM receipts_linearized WHERE
                    stream_id <= ? AND
                """
                clause, args = make_in_list_sql_clause(
                    self.database_engine, "room_id", room_ids
                )
                txn.execute(sql + clause, [to_key] + list(args))

            return self.db_pool.cursor_to_dict(txn)

        txn_results = await self.db_pool.runInteraction(
            "_get_linearized_receipts_for_rooms", f
        )

        results = {}
        for row in txn_results:
            # We want a single event per room, since we want to batch the
            # receipts by room, event and type.
            room_event = results.setdefault(
                row["room_id"],
                {"type": "m.receipt", "room_id": row["room_id"], "content": {}},
            )

            # The content is of the form:
            # {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. }
            event_entry = room_event["content"].setdefault(row["event_id"], {})
            receipt_type = event_entry.setdefault(row["receipt_type"], {})

            receipt_type[row["user_id"]] = db_to_json(row["data"])

        results = {
            room_id: [results[room_id]] if room_id in results else []
            for room_id in room_ids
        }
        return results

    @cached(num_args=2)
    async def get_linearized_receipts_for_all_rooms(
        self, to_key: int, from_key: Optional[int] = None
    ) -> Dict[str, JsonDict]:
        """Get receipts for all rooms between two stream_ids, up to a limit of
        the latest 100 read receipts.

        Args:
            to_key: Max stream id to fetch receipts up to.
            from_key: Min stream id to fetch receipts from. None fetches
                from the start.

        Returns:
            A dictionary of room IDs to a list of receipts.
        """

        def f(txn):
            if from_key:
                sql = """
                    SELECT * FROM receipts_linearized WHERE
                    stream_id > ? AND stream_id <= ?
                    ORDER BY stream_id DESC
                    LIMIT 100
                """
                txn.execute(sql, [from_key, to_key])
            else:
                sql = """
                    SELECT * FROM receipts_linearized WHERE
                    stream_id <= ?
                    ORDER BY stream_id DESC
                    LIMIT 100
                """
                txn.execute(sql, [to_key])

            return self.db_pool.cursor_to_dict(txn)

        txn_results = await self.db_pool.runInteraction(
            "get_linearized_receipts_for_all_rooms", f
        )

        results = {}
        for row in txn_results:
            # We want a single event per room, since we want to batch the
            # receipts by room, event and type.
            room_event = results.setdefault(
                row["room_id"],
                {"type": "m.receipt", "room_id": row["room_id"], "content": {}},
            )

            # The content is of the form:
            # {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. }
            event_entry = room_event["content"].setdefault(row["event_id"], {})
            receipt_type = event_entry.setdefault(row["receipt_type"], {})

            receipt_type[row["user_id"]] = db_to_json(row["data"])

        return results

    async def get_users_sent_receipts_between(
        self, last_id: int, current_id: int
    ) -> List[str]:
        """Get all users who sent receipts between `last_id` exclusive and
        `current_id` inclusive.

        Returns:
            The list of users.
        """

        if last_id == current_id:
            # This is an async function, so return a plain empty list rather
            # than a Deferred wrapping one.
            return []

        def _get_users_sent_receipts_between_txn(txn):
            sql = """
                SELECT DISTINCT user_id FROM receipts_linearized
                WHERE ? < stream_id AND stream_id <= ?
            """
            txn.execute(sql, (last_id, current_id))

            return [r[0] for r in txn]

        return await self.db_pool.runInteraction(
            "get_users_sent_receipts_between", _get_users_sent_receipts_between_txn
        )

    async def get_all_updated_receipts(
        self, instance_name: str, last_id: int, current_id: int, limit: int
    ) -> Tuple[List[Tuple[int, list]], int, bool]:
        """Get updates for receipts replication stream.

        Args:
            instance_name: The writer we want to fetch updates from. Unused
                here since there is only ever one writer.
            last_id: The token to fetch updates from. Exclusive.
            current_id: The token to fetch updates up to. Inclusive.
            limit: The requested limit for the number of rows to return. The
                function may return more or fewer rows.

        Returns:
            A tuple consisting of: the updates, a token to use to fetch
            subsequent updates, and whether we returned fewer rows than exist
            between the requested tokens due to the limit.

            The token returned can be used in a subsequent call to this
            function to get further updates.

            The updates are a list of 2-tuples of stream ID and the row data
        """

        if last_id == current_id:
            return [], current_id, False

        def get_all_updated_receipts_txn(txn):
            sql = """
                SELECT stream_id, room_id, receipt_type, user_id, event_id, data
                FROM receipts_linearized
                WHERE ? < stream_id AND stream_id <= ?
                ORDER BY stream_id ASC
                LIMIT ?
""" txn.execute(sql, (last_id, current_id, limit)) updates = [(r[0], r[1:5] + (db_to_json(r[5]), )) for r in txn] limited = False upper_bound = current_id if len(updates) == limit: limited = True upper_bound = updates[-1][0] return updates, upper_bound, limited return await self.db_pool.runInteraction("get_all_updated_receipts", get_all_updated_receipts_txn) def _invalidate_get_users_with_receipts_in_room(self, room_id: str, receipt_type: str, user_id: str): if receipt_type != "m.read": return res = self.get_users_with_read_receipts_in_room.cache.get_immediate( room_id, None, update_metrics=False) if res and user_id in res: # We'd only be adding to the set, so no point invalidating if the # user is already there return self.get_users_with_read_receipts_in_room.invalidate((room_id, )) def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id): self.get_receipts_for_user.invalidate((user_id, receipt_type)) self._get_linearized_receipts_for_room.invalidate_many((room_id, )) self.get_last_receipt_event_id_for_user.invalidate( (user_id, room_id, receipt_type)) self._invalidate_get_users_with_receipts_in_room( room_id, receipt_type, user_id) self.get_receipts_for_room.invalidate((room_id, receipt_type)) def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == ReceiptsStream.NAME: self._receipts_id_gen.advance(instance_name, token) for row in rows: self.invalidate_caches_for_receipt(row.room_id, row.receipt_type, row.user_id) self._receipts_stream_cache.entity_has_changed( row.room_id, token) return super().process_replication_rows(stream_name, instance_name, token, rows) def insert_linearized_receipt_txn(self, txn, room_id, receipt_type, user_id, event_id, data, stream_id): """Inserts a read-receipt into the database if it's newer than the current RR Returns: int|None None if the RR is older than the current RR otherwise, the rx timestamp of the event that the RR corresponds to (or 0 if the event is unknown) """ assert self._can_write_to_receipts res = self.db_pool.simple_select_one_txn( txn, table="events", retcols=["stream_ordering", "received_ts"], keyvalues={"event_id": event_id}, allow_none=True, ) stream_ordering = int(res["stream_ordering"]) if res else None rx_ts = res["received_ts"] if res else 0 # We don't want to clobber receipts for more recent events, so we # have to compare orderings of existing receipts if stream_ordering is not None: sql = ( "SELECT stream_ordering, event_id FROM events" " INNER JOIN receipts_linearized as r USING (event_id, room_id)" " WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ?" 
            )
            txn.execute(sql, (room_id, receipt_type, user_id))

            for so, eid in txn:
                if int(so) >= stream_ordering:
                    logger.debug(
                        "Ignoring new receipt for %s in favour of existing "
                        "one for later event %s",
                        event_id,
                        eid,
                    )
                    return None

        txn.call_after(
            self.invalidate_caches_for_receipt, room_id, receipt_type, user_id
        )

        txn.call_after(
            self._receipts_stream_cache.entity_has_changed, room_id, stream_id
        )

        self.db_pool.simple_upsert_txn(
            txn,
            table="receipts_linearized",
            keyvalues={
                "room_id": room_id,
                "receipt_type": receipt_type,
                "user_id": user_id,
            },
            values={
                "stream_id": stream_id,
                "event_id": event_id,
                "data": json_encoder.encode(data),
            },
            # receipts_linearized has a unique constraint on
            # (user_id, room_id, receipt_type), so no need to lock
            lock=False,
        )

        if receipt_type == "m.read" and stream_ordering is not None:
            self._remove_old_push_actions_before_txn(
                txn, room_id=room_id, user_id=user_id, stream_ordering=stream_ordering
            )

        return rx_ts

    async def insert_receipt(
        self,
        room_id: str,
        receipt_type: str,
        user_id: str,
        event_ids: List[str],
        data: dict,
    ) -> Optional[Tuple[int, int]]:
        """Insert a receipt, either from local client or remote server.

        Automatically does conversion between linearized and graph
        representations.
        """
        assert self._can_write_to_receipts

        if not event_ids:
            return None

        if len(event_ids) == 1:
            linearized_event_id = event_ids[0]
        else:
            # we need to map the event IDs in the graph to linearized form.
            # TODO: Make this better.
            def graph_to_linear(txn):
                clause, args = make_in_list_sql_clause(
                    self.database_engine, "event_id", event_ids
                )

                # Pick the given event with the highest stream ordering in
                # the room. (Note: `FROM events` is required here; the query
                # is not valid SQL without it.)
                sql = """
                    SELECT event_id FROM events
                    WHERE room_id = ? AND stream_ordering IN (
                        SELECT max(stream_ordering) FROM events WHERE %s
                    )
                """ % (
                    clause,
                )

                txn.execute(sql, [room_id] + list(args))
                rows = txn.fetchall()
                if rows:
                    return rows[0][0]
                else:
                    raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,))

            linearized_event_id = await self.db_pool.runInteraction(
                "insert_receipt_conv", graph_to_linear
            )

        async with self._receipts_id_gen.get_next() as stream_id:
            event_ts = await self.db_pool.runInteraction(
                "insert_linearized_receipt",
                self.insert_linearized_receipt_txn,
                room_id,
                receipt_type,
                user_id,
                linearized_event_id,
                data,
                stream_id=stream_id,
            )

        if event_ts is None:
            return None

        now = self._clock.time_msec()
        logger.debug(
            "RR for event %s in %s (%i ms old)",
            linearized_event_id,
            room_id,
            now - event_ts,
        )

        await self.insert_graph_receipt(room_id, receipt_type, user_id, event_ids, data)

        max_persisted_id = self._receipts_id_gen.get_current_token()

        return stream_id, max_persisted_id

    async def insert_graph_receipt(
        self, room_id, receipt_type, user_id, event_ids, data
    ):
        assert self._can_write_to_receipts

        return await self.db_pool.runInteraction(
            "insert_graph_receipt",
            self.insert_graph_receipt_txn,
            room_id,
            receipt_type,
            user_id,
            event_ids,
            data,
        )

    def insert_graph_receipt_txn(
        self, txn, room_id, receipt_type, user_id, event_ids, data
    ):
        assert self._can_write_to_receipts

        txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type))
        txn.call_after(
            self._invalidate_get_users_with_receipts_in_room,
            room_id,
            receipt_type,
            user_id,
        )
        txn.call_after(self.get_receipts_for_user.invalidate, (user_id, receipt_type))
        # FIXME: This shouldn't invalidate the whole cache
        txn.call_after(
            self._get_linearized_receipts_for_room.invalidate_many, (room_id,)
        )

        self.db_pool.simple_delete_txn(
            txn,
            table="receipts_graph",
            keyvalues={
                "room_id": room_id,
                "receipt_type": receipt_type,
                "user_id": user_id,
            },
        )
        self.db_pool.simple_insert_txn(
            txn,
            table="receipts_graph",
            values={
                "room_id": room_id,
"receipt_type": receipt_type, "user_id": user_id, "event_ids": json_encoder.encode(event_ids), "data": json_encoder.encode(data), }, )
class PushRulesWorkerStore( ApplicationServiceWorkerStore, ReceiptsWorkerStore, PusherWorkerStore, RoomMemberWorkerStore, EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta, ): """This is an abstract base class where subclasses must implement `get_max_push_rules_stream_id` which can be called in the initializer. """ def __init__( self, database: DatabasePool, db_conn: LoggingDatabaseConnection, hs: "HomeServer", ): super().__init__(database, db_conn, hs) if hs.config.worker.worker_app is None: self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator( db_conn, "push_rules_stream", "stream_id") else: self._push_rules_stream_id_gen = SlavedIdTracker( db_conn, "push_rules_stream", "stream_id") push_rules_prefill, push_rules_id = self.db_pool.get_cache_dict( db_conn, "push_rules_stream", entity_column="user_id", stream_column="stream_id", max_value=self.get_max_push_rules_stream_id(), ) self.push_rules_stream_cache = StreamChangeCache( "PushRulesStreamChangeCache", push_rules_id, prefilled_cache=push_rules_prefill, ) @abc.abstractmethod def get_max_push_rules_stream_id(self) -> int: """Get the position of the push rules stream. Returns: int """ raise NotImplementedError() @cached(max_entries=5000) async def get_push_rules_for_user(self, user_id: str) -> List[JsonDict]: rows = await self.db_pool.simple_select_list( table="push_rules", keyvalues={"user_name": user_id}, retcols=( "user_name", "rule_id", "priority_class", "priority", "conditions", "actions", ), desc="get_push_rules_for_user", ) rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"]))) enabled_map = await self.get_push_rules_enabled_for_user(user_id) return _load_rules(rows, enabled_map, self.hs.config.experimental) @cached(max_entries=5000) async def get_push_rules_enabled_for_user(self, user_id: str) -> Dict[str, bool]: results = await self.db_pool.simple_select_list( table="push_rules_enable", keyvalues={"user_name": user_id}, retcols=("rule_id", "enabled"), desc="get_push_rules_enabled_for_user", ) return {r["rule_id"]: bool(r["enabled"]) for r in results} async def have_push_rules_changed_for_user(self, user_id: str, last_id: int) -> bool: if not self.push_rules_stream_cache.has_entity_changed( user_id, last_id): return False else: def have_push_rules_changed_txn(txn: LoggingTransaction) -> bool: sql = ("SELECT COUNT(stream_id) FROM push_rules_stream" " WHERE user_id = ? AND ? 
< stream_id") txn.execute(sql, (user_id, last_id)) (count, ) = cast(Tuple[int], txn.fetchone()) return bool(count) return await self.db_pool.runInteraction( "have_push_rules_changed", have_push_rules_changed_txn) @cachedList(cached_method_name="get_push_rules_for_user", list_name="user_ids") async def bulk_get_push_rules( self, user_ids: Collection[str]) -> Dict[str, List[JsonDict]]: if not user_ids: return {} results: Dict[str, List[JsonDict]] = {user_id: [] for user_id in user_ids} rows = await self.db_pool.simple_select_many_batch( table="push_rules", column="user_name", iterable=user_ids, retcols=("*", ), desc="bulk_get_push_rules", ) rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"]))) for row in rows: results.setdefault(row["user_name"], []).append(row) enabled_map_by_user = await self.bulk_get_push_rules_enabled(user_ids) for user_id, rules in results.items(): results[user_id] = _load_rules( rules, enabled_map_by_user.get(user_id, {}), self.hs.config.experimental) return results @cachedList(cached_method_name="get_push_rules_enabled_for_user", list_name="user_ids") async def bulk_get_push_rules_enabled( self, user_ids: Collection[str]) -> Dict[str, Dict[str, bool]]: if not user_ids: return {} results: Dict[str, Dict[str, bool]] = {user_id: {} for user_id in user_ids} rows = await self.db_pool.simple_select_many_batch( table="push_rules_enable", column="user_name", iterable=user_ids, retcols=("user_name", "rule_id", "enabled"), desc="bulk_get_push_rules_enabled", ) for row in rows: enabled = bool(row["enabled"]) results.setdefault(row["user_name"], {})[row["rule_id"]] = enabled return results async def get_all_push_rule_updates( self, instance_name: str, last_id: int, current_id: int, limit: int) -> Tuple[List[Tuple[int, Tuple[str]]], int, bool]: """Get updates for push_rules replication stream. Args: instance_name: The writer we want to fetch updates from. Unused here since there is only ever one writer. last_id: The token to fetch updates from. Exclusive. current_id: The token to fetch updates up to. Inclusive. limit: The requested limit for the number of rows to return. The function may return more or fewer rows. Returns: A tuple consisting of: the updates, a token to use to fetch subsequent updates, and whether we returned fewer rows than exists between the requested tokens due to the limit. The token returned can be used in a subsequent call to this function to get further updatees. The updates are a list of 2-tuples of stream ID and the row data """ if last_id == current_id: return [], current_id, False def get_all_push_rule_updates_txn( txn: LoggingTransaction, ) -> Tuple[List[Tuple[int, Tuple[str]]], int, bool]: sql = """ SELECT stream_id, user_id FROM push_rules_stream WHERE ? < stream_id AND stream_id <= ? ORDER BY stream_id ASC LIMIT ? """ txn.execute(sql, (last_id, current_id, limit)) updates = cast( List[Tuple[int, Tuple[str]]], [(stream_id, (user_id, )) for stream_id, user_id in txn], ) limited = False upper_bound = current_id if len(updates) == limit: limited = True upper_bound = updates[-1][0] return updates, upper_bound, limited return await self.db_pool.runInteraction( "get_all_push_rule_updates", get_all_push_rule_updates_txn)
class ReceiptsWorkerStore(SQLBaseStore):
    """This is an abstract base class where subclasses must implement
    `get_max_receipt_stream_id` which can be called in the initializer.
    """

    # This ABCMeta metaclass ensures that we cannot be instantiated without
    # the abstract methods being implemented.
    __metaclass__ = abc.ABCMeta

    def __init__(self, database: DatabasePool, db_conn, hs):
        super(ReceiptsWorkerStore, self).__init__(database, db_conn, hs)

        self._receipts_stream_cache = StreamChangeCache(
            "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id()
        )

    @abc.abstractmethod
    def get_max_receipt_stream_id(self):
        """Get the current max stream ID for receipts stream

        Returns:
            int
        """
        raise NotImplementedError()

    @cached()
    async def get_users_with_read_receipts_in_room(self, room_id):
        receipts = await self.get_receipts_for_room(room_id, "m.read")
        return {r["user_id"] for r in receipts}

    @cached(num_args=2)
    async def get_receipts_for_room(
        self, room_id: str, receipt_type: str
    ) -> List[Dict[str, Any]]:
        return await self.db_pool.simple_select_list(
            table="receipts_linearized",
            keyvalues={"room_id": room_id, "receipt_type": receipt_type},
            retcols=("user_id", "event_id"),
            desc="get_receipts_for_room",
        )

    @cached(num_args=3)
    async def get_last_receipt_event_id_for_user(
        self, user_id: str, room_id: str, receipt_type: str
    ) -> Optional[str]:
        return await self.db_pool.simple_select_one_onecol(
            table="receipts_linearized",
            keyvalues={
                "room_id": room_id,
                "receipt_type": receipt_type,
                "user_id": user_id,
            },
            retcol="event_id",
            desc="get_own_receipt_for_user",
            allow_none=True,
        )

    @cached(num_args=2)
    async def get_receipts_for_user(self, user_id, receipt_type):
        rows = await self.db_pool.simple_select_list(
            table="receipts_linearized",
            keyvalues={"user_id": user_id, "receipt_type": receipt_type},
            retcols=("room_id", "event_id"),
            desc="get_receipts_for_user",
        )

        return {row["room_id"]: row["event_id"] for row in rows}

    async def get_receipts_for_user_with_orderings(self, user_id, receipt_type):
        def f(txn):
            sql = (
                "SELECT rl.room_id, rl.event_id,"
                " e.topological_ordering, e.stream_ordering"
                " FROM receipts_linearized AS rl"
                " INNER JOIN events AS e USING (room_id, event_id)"
                " WHERE rl.room_id = e.room_id"
                " AND rl.event_id = e.event_id"
                " AND user_id = ?"
            )
            txn.execute(sql, (user_id,))
            return txn.fetchall()

        rows = await self.db_pool.runInteraction(
            "get_receipts_for_user_with_orderings", f
        )
        return {
            row[0]: {
                "event_id": row[1],
                "topological_ordering": row[2],
                "stream_ordering": row[3],
            }
            for row in rows
        }

    async def get_linearized_receipts_for_rooms(
        self, room_ids: List[str], to_key: int, from_key: Optional[int] = None
    ) -> List[dict]:
        """Get receipts for multiple rooms for sending to clients.

        Args:
            room_ids: List of room_ids.
            to_key: Max stream id to fetch receipts up to.
            from_key: Min stream id to fetch receipts from. None fetches
                from the start.

        Returns:
            A list of receipts.
        """
        room_ids = set(room_ids)

        if from_key is not None:
            # Only ask the database about rooms where there have been new
            # receipts added since `from_key`
            room_ids = self._receipts_stream_cache.get_entities_changed(
                room_ids, from_key
            )

        results = await self._get_linearized_receipts_for_rooms(
            room_ids, to_key, from_key=from_key
        )

        return [ev for res in results.values() for ev in res]

    async def get_linearized_receipts_for_room(
        self, room_id: str, to_key: int, from_key: Optional[int] = None
    ) -> List[dict]:
        """Get receipts for a single room for sending to clients.

        Args:
            room_id: The room id.
            to_key: Max stream id to fetch receipts up to.
            from_key: Min stream id to fetch receipts from. None fetches
                from the start.

        Returns:
            A list of receipts.
        """
        if from_key is not None:
            # Check the cache first to see if any new receipts have been added
            # since `from_key`. If not we can no-op.
            if not self._receipts_stream_cache.has_entity_changed(room_id, from_key):
                return []

        return await self._get_linearized_receipts_for_room(room_id, to_key, from_key)

    @cached(num_args=3, tree=True)
    async def _get_linearized_receipts_for_room(
        self, room_id: str, to_key: int, from_key: Optional[int] = None
    ) -> List[dict]:
        """See get_linearized_receipts_for_room"""

        def f(txn):
            if from_key:
                sql = (
                    "SELECT * FROM receipts_linearized WHERE"
                    " room_id = ? AND stream_id > ? AND stream_id <= ?"
                )
                txn.execute(sql, (room_id, from_key, to_key))
            else:
                sql = (
                    "SELECT * FROM receipts_linearized WHERE"
                    " room_id = ? AND stream_id <= ?"
                )
                txn.execute(sql, (room_id, to_key))

            rows = self.db_pool.cursor_to_dict(txn)

            return rows

        rows = await self.db_pool.runInteraction("get_linearized_receipts_for_room", f)

        if not rows:
            return []

        content = {}
        for row in rows:
            content.setdefault(row["event_id"], {}).setdefault(row["receipt_type"], {})[
                row["user_id"]
            ] = db_to_json(row["data"])

        return [{"type": "m.receipt", "room_id": room_id, "content": content}]

    @cachedList(
        cached_method_name="_get_linearized_receipts_for_room",
        list_name="room_ids",
        num_args=3,
    )
    async def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
        if not room_ids:
            return {}

        def f(txn):
            if from_key:
                sql = """
                    SELECT * FROM receipts_linearized WHERE
                    stream_id > ? AND stream_id <= ? AND
                """
                clause, args = make_in_list_sql_clause(
                    self.database_engine, "room_id", room_ids
                )
                txn.execute(sql + clause, [from_key, to_key] + list(args))
            else:
                sql = """
                    SELECT * FROM receipts_linearized WHERE
                    stream_id <= ? AND
                """
                clause, args = make_in_list_sql_clause(
                    self.database_engine, "room_id", room_ids
                )
                txn.execute(sql + clause, [to_key] + list(args))

            return self.db_pool.cursor_to_dict(txn)

        txn_results = await self.db_pool.runInteraction(
            "_get_linearized_receipts_for_rooms", f
        )

        results = {}
        for row in txn_results:
            # We want a single event per room, since we want to batch the
            # receipts by room, event and type.
            room_event = results.setdefault(
                row["room_id"],
                {"type": "m.receipt", "room_id": row["room_id"], "content": {}},
            )

            # The content is of the form:
            # {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. }
            event_entry = room_event["content"].setdefault(row["event_id"], {})
            receipt_type = event_entry.setdefault(row["receipt_type"], {})

            receipt_type[row["user_id"]] = db_to_json(row["data"])

        results = {
            room_id: [results[room_id]] if room_id in results else []
            for room_id in room_ids
        }
        return results

    async def get_users_sent_receipts_between(
        self, last_id: int, current_id: int
    ) -> List[str]:
        """Get all users who sent receipts between `last_id` exclusive and
        `current_id` inclusive.

        Returns:
            The list of users.
        """

        if last_id == current_id:
            # This is an async function, so return a plain empty list rather
            # than a Deferred wrapping one.
            return []

        def _get_users_sent_receipts_between_txn(txn):
            sql = """
                SELECT DISTINCT user_id FROM receipts_linearized
                WHERE ? < stream_id AND stream_id <= ?
            """
            txn.execute(sql, (last_id, current_id))

            return [r[0] for r in txn]

        return await self.db_pool.runInteraction(
            "get_users_sent_receipts_between", _get_users_sent_receipts_between_txn
        )

    async def get_all_updated_receipts(
        self, instance_name: str, last_id: int, current_id: int, limit: int
    ) -> Tuple[List[Tuple[int, list]], int, bool]:
        """Get updates for receipts replication stream.
        Args:
            instance_name: The writer we want to fetch updates from. Unused
                here since there is only ever one writer.
            last_id: The token to fetch updates from. Exclusive.
            current_id: The token to fetch updates up to. Inclusive.
            limit: The requested limit for the number of rows to return. The
                function may return more or fewer rows.

        Returns:
            A tuple consisting of: the updates, a token to use to fetch
            subsequent updates, and whether we returned fewer rows than exist
            between the requested tokens due to the limit.

            The token returned can be used in a subsequent call to this
            function to get further updates.

            The updates are a list of 2-tuples of stream ID and the row data
        """

        if last_id == current_id:
            return [], current_id, False

        def get_all_updated_receipts_txn(txn):
            sql = """
                SELECT stream_id, room_id, receipt_type, user_id, event_id, data
                FROM receipts_linearized
                WHERE ? < stream_id AND stream_id <= ?
                ORDER BY stream_id ASC
                LIMIT ?
            """
            txn.execute(sql, (last_id, current_id, limit))

            updates = [(r[0], r[1:5] + (db_to_json(r[5]),)) for r in txn]

            limited = False
            upper_bound = current_id

            if len(updates) == limit:
                limited = True
                upper_bound = updates[-1][0]

            return updates, upper_bound, limited

        return await self.db_pool.runInteraction(
            "get_all_updated_receipts", get_all_updated_receipts_txn
        )

    def _invalidate_get_users_with_receipts_in_room(
        self, room_id: str, receipt_type: str, user_id: str
    ):
        if receipt_type != "m.read":
            return

        # Returns either an ObservableDeferred or the raw result
        res = self.get_users_with_read_receipts_in_room.cache.get(
            room_id, None, update_metrics=False
        )

        # first handle the ObservableDeferred case
        if isinstance(res, ObservableDeferred):
            if res.has_called():
                res = res.get_result()
            else:
                res = None

        if res and user_id in res:
            # We'd only be adding to the set, so no point invalidating if the
            # user is already there
            return

        self.get_users_with_read_receipts_in_room.invalidate((room_id,))
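# Illustrative sketch of the "peek before invalidating" optimisation used by
# `_invalidate_get_users_with_receipts_in_room` above: a new read receipt can
# only *add* its sender to the cached set, so if the user is already present
# the cached value is still correct and invalidation can be skipped. The same
# idea in plain-dict terms (the dict is a hypothetical stand-in for the cache):
def _maybe_invalidate(cached_users_by_room, room_id, user_id):
    cached = cached_users_by_room.get(room_id)
    if cached is not None and user_id in cached:
        # The cached set already contains this user; nothing to do.
        return
    # Otherwise drop the entry so the next read recomputes it.
    cached_users_by_room.pop(room_id, None)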
class AccountDataWorkerStore(SQLBaseStore): """This is an abstract base class where subclasses must implement `get_max_account_data_stream_id` which can be called in the initializer. """ # This ABCMeta metaclass ensures that we cannot be instantiated without # the abstract methods being implemented. __metaclass__ = abc.ABCMeta def __init__(self, db_conn, hs): account_max = self.get_max_account_data_stream_id() self._account_data_stream_cache = StreamChangeCache( "AccountDataAndTagsChangeCache", account_max ) super(AccountDataWorkerStore, self).__init__(db_conn, hs) @abc.abstractmethod def get_max_account_data_stream_id(self): """Get the current max stream ID for account data stream Returns: int """ raise NotImplementedError() @cached() def get_account_data_for_user(self, user_id): """Get all the client account_data for a user. Args: user_id(str): The user to get the account_data for. Returns: A deferred pair of a dict of global account_data and a dict mapping from room_id string to per room account_data dicts. """ def get_account_data_for_user_txn(txn): rows = self._simple_select_list_txn( txn, "account_data", {"user_id": user_id}, ["account_data_type", "content"], ) global_account_data = { row["account_data_type"]: json.loads(row["content"]) for row in rows } rows = self._simple_select_list_txn( txn, "room_account_data", {"user_id": user_id}, ["room_id", "account_data_type", "content"], ) by_room = {} for row in rows: room_data = by_room.setdefault(row["room_id"], {}) room_data[row["account_data_type"]] = json.loads(row["content"]) return global_account_data, by_room return self.runInteraction( "get_account_data_for_user", get_account_data_for_user_txn ) @cachedInlineCallbacks(num_args=2, max_entries=5000) def get_global_account_data_by_type_for_user(self, data_type, user_id): """ Returns: Deferred: A dict """ result = yield self._simple_select_one_onecol( table="account_data", keyvalues={"user_id": user_id, "account_data_type": data_type}, retcol="content", desc="get_global_account_data_by_type_for_user", allow_none=True, ) if result: return json.loads(result) else: return None @cached(num_args=2) def get_account_data_for_room(self, user_id, room_id): """Get all the client account_data for a user for a room. Args: user_id(str): The user to get the account_data for. room_id(str): The room to get the account_data for. Returns: A deferred dict of the room account_data """ def get_account_data_for_room_txn(txn): rows = self._simple_select_list_txn( txn, "room_account_data", {"user_id": user_id, "room_id": room_id}, ["account_data_type", "content"], ) return { row["account_data_type"]: json.loads(row["content"]) for row in rows } return self.runInteraction( "get_account_data_for_room", get_account_data_for_room_txn ) @cached(num_args=3, max_entries=5000) def get_account_data_for_room_and_type(self, user_id, room_id, account_data_type): """Get the client account_data of given type for a user for a room. Args: user_id(str): The user to get the account_data for. room_id(str): The room to get the account_data for. account_data_type (str): The account data type to get. Returns: A deferred of the room account_data for that type, or None if there isn't any set. 
""" def get_account_data_for_room_and_type_txn(txn): content_json = self._simple_select_one_onecol_txn( txn, table="room_account_data", keyvalues={ "user_id": user_id, "room_id": room_id, "account_data_type": account_data_type, }, retcol="content", allow_none=True, ) return json.loads(content_json) if content_json else None return self.runInteraction( "get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn ) def get_all_updated_account_data( self, last_global_id, last_room_id, current_id, limit ): """Get all the client account_data that has changed on the server Args: last_global_id(int): The position to fetch from for top level data last_room_id(int): The position to fetch from for per room data current_id(int): The position to fetch up to. Returns: A deferred pair of lists of tuples of stream_id int, user_id string, room_id string, and type string. """ if last_room_id == current_id and last_global_id == current_id: return defer.succeed(([], [])) def get_updated_account_data_txn(txn): sql = ( "SELECT stream_id, user_id, account_data_type" " FROM account_data WHERE ? < stream_id AND stream_id <= ?" " ORDER BY stream_id ASC LIMIT ?" ) txn.execute(sql, (last_global_id, current_id, limit)) global_results = txn.fetchall() sql = ( "SELECT stream_id, user_id, room_id, account_data_type" " FROM room_account_data WHERE ? < stream_id AND stream_id <= ?" " ORDER BY stream_id ASC LIMIT ?" ) txn.execute(sql, (last_room_id, current_id, limit)) room_results = txn.fetchall() return global_results, room_results return self.runInteraction( "get_all_updated_account_data_txn", get_updated_account_data_txn ) def get_updated_account_data_for_user(self, user_id, stream_id): """Get all the client account_data for a that's changed for a user Args: user_id(str): The user to get the account_data for. stream_id(int): The point in the stream since which to get updates Returns: A deferred pair of a dict of global account_data and a dict mapping from room_id string to per room account_data dicts. """ def get_updated_account_data_for_user_txn(txn): sql = ( "SELECT account_data_type, content FROM account_data" " WHERE user_id = ? AND stream_id > ?" ) txn.execute(sql, (user_id, stream_id)) global_account_data = {row[0]: json.loads(row[1]) for row in txn} sql = ( "SELECT room_id, account_data_type, content FROM room_account_data" " WHERE user_id = ? AND stream_id > ?" ) txn.execute(sql, (user_id, stream_id)) account_data_by_room = {} for row in txn: room_account_data = account_data_by_room.setdefault(row[0], {}) room_account_data[row[1]] = json.loads(row[2]) return global_account_data, account_data_by_room changed = self._account_data_stream_cache.has_entity_changed( user_id, int(stream_id) ) if not changed: return {}, {} return self.runInteraction( "get_updated_account_data_for_user", get_updated_account_data_for_user_txn ) @cachedInlineCallbacks(num_args=2, cache_context=True, max_entries=5000) def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context): ignored_account_data = yield self.get_global_account_data_by_type_for_user( "m.ignored_user_list", ignorer_user_id, on_invalidate=cache_context.invalidate, ) if not ignored_account_data: return False return ignored_user_id in ignored_account_data.get("ignored_users", {})
class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore):
    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        super().__init__(database, db_conn, hs)

        # `_can_write_to_account_data` indicates whether the current worker is allowed
        # to write account data. A value of `True` implies that `_account_data_id_gen`
        # is an `AbstractStreamIdGenerator` and not just a tracker.
        self._account_data_id_gen: AbstractStreamIdTracker

        if isinstance(database.engine, PostgresEngine):
            self._can_write_to_account_data = (
                self._instance_name in hs.config.worker.writers.account_data
            )

            self._account_data_id_gen = MultiWriterIdGenerator(
                db_conn=db_conn,
                db=database,
                stream_name="account_data",
                instance_name=self._instance_name,
                tables=[
                    ("room_account_data", "instance_name", "stream_id"),
                    ("room_tags_revisions", "instance_name", "stream_id"),
                    ("account_data", "instance_name", "stream_id"),
                ],
                sequence_name="account_data_sequence",
                writers=hs.config.worker.writers.account_data,
            )
        else:
            # We shouldn't be running in worker mode with SQLite, but it's
            # useful to support it for unit tests.
            #
            # If this process is the writer then we need to use
            # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which
            # gets updated over replication. (Multiple writers are not
            # supported for SQLite).
            if self._instance_name in hs.config.worker.writers.account_data:
                self._can_write_to_account_data = True
                self._account_data_id_gen = StreamIdGenerator(
                    db_conn,
                    "room_account_data",
                    "stream_id",
                    extra_tables=[("room_tags_revisions", "stream_id")],
                )
            else:
                self._account_data_id_gen = SlavedIdTracker(
                    db_conn,
                    "room_account_data",
                    "stream_id",
                    extra_tables=[("room_tags_revisions", "stream_id")],
                )

        account_max = self.get_max_account_data_stream_id()
        self._account_data_stream_cache = StreamChangeCache(
            "AccountDataAndTagsChangeCache", account_max
        )

        self.db_pool.updates.register_background_update_handler(
            "delete_account_data_for_deactivated_users",
            self._delete_account_data_for_deactivated_users,
        )

    def get_max_account_data_stream_id(self) -> int:
        """Get the current max stream ID for account data stream

        Returns:
            int
        """
        return self._account_data_id_gen.get_current_token()

    @cached()
    async def get_account_data_for_user(
        self, user_id: str
    ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
        """Get all the client account_data for a user.

        Args:
            user_id: The user to get the account_data for.

        Returns:
            A 2-tuple of a dict of global account_data and a dict
            mapping from room_id string to per room account_data dicts.
""" def get_account_data_for_user_txn( txn: LoggingTransaction, ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]: rows = self.db_pool.simple_select_list_txn( txn, "account_data", {"user_id": user_id}, ["account_data_type", "content"], ) global_account_data = { row["account_data_type"]: db_to_json(row["content"]) for row in rows } rows = self.db_pool.simple_select_list_txn( txn, "room_account_data", {"user_id": user_id}, ["room_id", "account_data_type", "content"], ) by_room: Dict[str, Dict[str, JsonDict]] = {} for row in rows: room_data = by_room.setdefault(row["room_id"], {}) room_data[row["account_data_type"]] = db_to_json( row["content"]) return global_account_data, by_room return await self.db_pool.runInteraction( "get_account_data_for_user", get_account_data_for_user_txn) @cached(num_args=2, max_entries=5000, tree=True) async def get_global_account_data_by_type_for_user( self, user_id: str, data_type: str) -> Optional[JsonDict]: """ Returns: The account data. """ result = await self.db_pool.simple_select_one_onecol( table="account_data", keyvalues={ "user_id": user_id, "account_data_type": data_type }, retcol="content", desc="get_global_account_data_by_type_for_user", allow_none=True, ) if result: return db_to_json(result) else: return None @cached(num_args=2, tree=True) async def get_account_data_for_room(self, user_id: str, room_id: str) -> Dict[str, JsonDict]: """Get all the client account_data for a user for a room. Args: user_id: The user to get the account_data for. room_id: The room to get the account_data for. Returns: A dict of the room account_data """ def get_account_data_for_room_txn( txn: LoggingTransaction, ) -> Dict[str, JsonDict]: rows = self.db_pool.simple_select_list_txn( txn, "room_account_data", { "user_id": user_id, "room_id": room_id }, ["account_data_type", "content"], ) return { row["account_data_type"]: db_to_json(row["content"]) for row in rows } return await self.db_pool.runInteraction( "get_account_data_for_room", get_account_data_for_room_txn) @cached(num_args=3, max_entries=5000, tree=True) async def get_account_data_for_room_and_type( self, user_id: str, room_id: str, account_data_type: str) -> Optional[JsonDict]: """Get the client account_data of given type for a user for a room. Args: user_id: The user to get the account_data for. room_id: The room to get the account_data for. account_data_type: The account data type to get. Returns: The room account_data for that type, or None if there isn't any set. """ def get_account_data_for_room_and_type_txn( txn: LoggingTransaction, ) -> Optional[JsonDict]: content_json = self.db_pool.simple_select_one_onecol_txn( txn, table="room_account_data", keyvalues={ "user_id": user_id, "room_id": room_id, "account_data_type": account_data_type, }, retcol="content", allow_none=True, ) return db_to_json(content_json) if content_json else None return await self.db_pool.runInteraction( "get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn) async def get_updated_global_account_data( self, last_id: int, current_id: int, limit: int) -> List[Tuple[int, str, str]]: """Get the global account_data that has changed, for the account_data stream Args: last_id: the last stream_id from the previous batch. current_id: the maximum stream_id to return up to limit: the maximum number of rows to return Returns: A list of tuples of stream_id int, user_id string, and type string. 
""" if last_id == current_id: return [] def get_updated_global_account_data_txn( txn: LoggingTransaction, ) -> List[Tuple[int, str, str]]: sql = ("SELECT stream_id, user_id, account_data_type" " FROM account_data WHERE ? < stream_id AND stream_id <= ?" " ORDER BY stream_id ASC LIMIT ?") txn.execute(sql, (last_id, current_id, limit)) return cast(List[Tuple[int, str, str]], txn.fetchall()) return await self.db_pool.runInteraction( "get_updated_global_account_data", get_updated_global_account_data_txn) async def get_updated_room_account_data( self, last_id: int, current_id: int, limit: int) -> List[Tuple[int, str, str, str]]: """Get the global account_data that has changed, for the account_data stream Args: last_id: the last stream_id from the previous batch. current_id: the maximum stream_id to return up to limit: the maximum number of rows to return Returns: A list of tuples of stream_id int, user_id string, room_id string and type string. """ if last_id == current_id: return [] def get_updated_room_account_data_txn( txn: LoggingTransaction, ) -> List[Tuple[int, str, str, str]]: sql = ( "SELECT stream_id, user_id, room_id, account_data_type" " FROM room_account_data WHERE ? < stream_id AND stream_id <= ?" " ORDER BY stream_id ASC LIMIT ?") txn.execute(sql, (last_id, current_id, limit)) return cast(List[Tuple[int, str, str, str]], txn.fetchall()) return await self.db_pool.runInteraction( "get_updated_room_account_data", get_updated_room_account_data_txn) async def get_updated_account_data_for_user( self, user_id: str, stream_id: int ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]: """Get all the client account_data for a that's changed for a user Args: user_id: The user to get the account_data for. stream_id: The point in the stream since which to get updates Returns: A deferred pair of a dict of global account_data and a dict mapping from room_id string to per room account_data dicts. """ def get_updated_account_data_for_user_txn( txn: LoggingTransaction, ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]: sql = ("SELECT account_data_type, content FROM account_data" " WHERE user_id = ? AND stream_id > ?") txn.execute(sql, (user_id, stream_id)) global_account_data = {row[0]: db_to_json(row[1]) for row in txn} sql = ( "SELECT room_id, account_data_type, content FROM room_account_data" " WHERE user_id = ? AND stream_id > ?") txn.execute(sql, (user_id, stream_id)) account_data_by_room: Dict[str, Dict[str, JsonDict]] = {} for row in txn: room_account_data = account_data_by_room.setdefault(row[0], {}) room_account_data[row[1]] = db_to_json(row[2]) return global_account_data, account_data_by_room changed = self._account_data_stream_cache.has_entity_changed( user_id, int(stream_id)) if not changed: return {}, {} return await self.db_pool.runInteraction( "get_updated_account_data_for_user", get_updated_account_data_for_user_txn) @cached(max_entries=5000, iterable=True) async def ignored_by(self, user_id: str) -> FrozenSet[str]: """ Get users which ignore the given user. Params: user_id: The user ID which might be ignored. Return: The user IDs which ignore the given user. """ return frozenset(await self.db_pool.simple_select_onecol( table="ignored_users", keyvalues={"ignored_user_id": user_id}, retcol="ignorer_user_id", desc="ignored_by", )) @cached(max_entries=5000, iterable=True) async def ignored_users(self, user_id: str) -> FrozenSet[str]: """ Get users which the given user ignores. Params: user_id: The user ID which is making the request. 
        Return:
            The user IDs which are ignored by the given user.
        """
        return frozenset(
            await self.db_pool.simple_select_onecol(
                table="ignored_users",
                keyvalues={"ignorer_user_id": user_id},
                retcol="ignored_user_id",
                desc="ignored_users",
            )
        )

    def process_replication_rows(
        self,
        stream_name: str,
        instance_name: str,
        token: int,
        rows: Iterable[Any],
    ) -> None:
        if stream_name == TagAccountDataStream.NAME:
            self._account_data_id_gen.advance(instance_name, token)
        elif stream_name == AccountDataStream.NAME:
            self._account_data_id_gen.advance(instance_name, token)
            for row in rows:
                if not row.room_id:
                    self.get_global_account_data_by_type_for_user.invalidate(
                        (row.user_id, row.data_type)
                    )
                self.get_account_data_for_user.invalidate((row.user_id,))
                self.get_account_data_for_room.invalidate((row.user_id, row.room_id))
                self.get_account_data_for_room_and_type.invalidate(
                    (row.user_id, row.room_id, row.data_type)
                )
                self._account_data_stream_cache.entity_has_changed(row.user_id, token)

        super().process_replication_rows(stream_name, instance_name, token, rows)

    async def add_account_data_to_room(
        self, user_id: str, room_id: str, account_data_type: str, content: JsonDict
    ) -> int:
        """Add some account_data to a room for a user.

        Args:
            user_id: The user to add the account data for.
            room_id: The room to add the account data to.
            account_data_type: The type of account_data to add.
            content: A json object to associate with the account data.

        Returns:
            The maximum stream ID.
        """
        assert self._can_write_to_account_data
        assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)

        content_json = json_encoder.encode(content)

        async with self._account_data_id_gen.get_next() as next_id:
            # no need to lock here as room_account_data has a unique constraint
            # on (user_id, room_id, account_data_type) so simple_upsert will
            # retry if there is a conflict.
            await self.db_pool.simple_upsert(
                desc="add_room_account_data",
                table="room_account_data",
                keyvalues={
                    "user_id": user_id,
                    "room_id": room_id,
                    "account_data_type": account_data_type,
                },
                values={"stream_id": next_id, "content": content_json},
                lock=False,
            )

            self._account_data_stream_cache.entity_has_changed(user_id, next_id)
            self.get_account_data_for_user.invalidate((user_id,))
            self.get_account_data_for_room.invalidate((user_id, room_id))
            self.get_account_data_for_room_and_type.prefill(
                (user_id, room_id, account_data_type), content
            )

        return self._account_data_id_gen.get_current_token()

    async def add_account_data_for_user(
        self, user_id: str, account_data_type: str, content: JsonDict
    ) -> int:
        """Add some global account_data for a user.

        Args:
            user_id: The user to add the account data for.
            account_data_type: The type of account_data to add.
            content: A json object to associate with the account data.

        Returns:
            The maximum stream ID.
""" assert self._can_write_to_account_data assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator) async with self._account_data_id_gen.get_next() as next_id: await self.db_pool.runInteraction( "add_user_account_data", self._add_account_data_for_user, next_id, user_id, account_data_type, content, ) self._account_data_stream_cache.entity_has_changed( user_id, next_id) self.get_account_data_for_user.invalidate((user_id, )) self.get_global_account_data_by_type_for_user.invalidate( (user_id, account_data_type)) return self._account_data_id_gen.get_current_token() def _add_account_data_for_user( self, txn: LoggingTransaction, next_id: int, user_id: str, account_data_type: str, content: JsonDict, ) -> None: content_json = json_encoder.encode(content) # no need to lock here as account_data has a unique constraint on # (user_id, account_data_type) so simple_upsert will retry if # there is a conflict. self.db_pool.simple_upsert_txn( txn, table="account_data", keyvalues={ "user_id": user_id, "account_data_type": account_data_type }, values={ "stream_id": next_id, "content": content_json }, lock=False, ) # Ignored users get denormalized into a separate table as an optimisation. if account_data_type != AccountDataTypes.IGNORED_USER_LIST: return # Insert / delete to sync the list of ignored users. previously_ignored_users = set( self.db_pool.simple_select_onecol_txn( txn, table="ignored_users", keyvalues={"ignorer_user_id": user_id}, retcol="ignored_user_id", )) # If the data is invalid, no one is ignored. ignored_users_content = content.get("ignored_users", {}) if isinstance(ignored_users_content, dict): currently_ignored_users = set(ignored_users_content) else: currently_ignored_users = set() # If the data has not changed, nothing to do. if previously_ignored_users == currently_ignored_users: return # Delete entries which are no longer ignored. self.db_pool.simple_delete_many_txn( txn, table="ignored_users", column="ignored_user_id", values=previously_ignored_users - currently_ignored_users, keyvalues={"ignorer_user_id": user_id}, ) # Add entries which are newly ignored. self.db_pool.simple_insert_many_txn( txn, table="ignored_users", keys=("ignorer_user_id", "ignored_user_id"), values=[(user_id, u) for u in currently_ignored_users - previously_ignored_users ], ) # Invalidate the cache for any ignored users which were added or removed. for ignored_user_id in previously_ignored_users ^ currently_ignored_users: self._invalidate_cache_and_stream(txn, self.ignored_by, (ignored_user_id, )) self._invalidate_cache_and_stream(txn, self.ignored_users, (user_id, )) async def purge_account_data_for_user(self, user_id: str) -> None: """ Removes ALL the account data for a user. Intended to be used upon user deactivation. Also purges the user from the ignored_users cache table and the push_rules cache tables. """ await self.db_pool.runInteraction( "purge_account_data_for_user_txn", self._purge_account_data_for_user_txn, user_id, ) def _purge_account_data_for_user_txn(self, txn: LoggingTransaction, user_id: str) -> None: """ See `purge_account_data_for_user`. """ # Purge from the primary account_data tables. self.db_pool.simple_delete_txn(txn, table="account_data", keyvalues={"user_id": user_id}) self.db_pool.simple_delete_txn(txn, table="room_account_data", keyvalues={"user_id": user_id}) # Purge from ignored_users where this user is the ignorer. # N.B. We don't purge where this user is the ignoree, because that # interferes with other users' account data. 
# It's also not this user's data to delete! self.db_pool.simple_delete_txn(txn, table="ignored_users", keyvalues={"ignorer_user_id": user_id}) # Remove the push rules self.db_pool.simple_delete_txn(txn, table="push_rules", keyvalues={"user_name": user_id}) self.db_pool.simple_delete_txn(txn, table="push_rules_enable", keyvalues={"user_name": user_id}) self.db_pool.simple_delete_txn(txn, table="push_rules_stream", keyvalues={"user_id": user_id}) # Invalidate caches as appropriate self._invalidate_cache_and_stream( txn, self.get_account_data_for_room_and_type, (user_id, )) self._invalidate_cache_and_stream(txn, self.get_account_data_for_user, (user_id, )) self._invalidate_cache_and_stream( txn, self.get_global_account_data_by_type_for_user, (user_id, )) self._invalidate_cache_and_stream(txn, self.get_account_data_for_room, (user_id, )) self._invalidate_cache_and_stream(txn, self.get_push_rules_for_user, (user_id, )) self._invalidate_cache_and_stream(txn, self.get_push_rules_enabled_for_user, (user_id, )) # This user might be contained in the ignored_by cache for other users, # so we have to invalidate it all. self._invalidate_all_cache_and_stream(txn, self.ignored_by) async def _delete_account_data_for_deactivated_users( self, progress: dict, batch_size: int) -> int: """ Retroactively purges account data for users that have already been deactivated. Gets run as a background update caused by a schema delta. """ last_user: str = progress.get("last_user", "") def _delete_account_data_for_deactivated_users_txn( txn: LoggingTransaction, ) -> int: sql = """ SELECT name FROM users WHERE deactivated = ? and name > ? ORDER BY name ASC LIMIT ? """ txn.execute(sql, (1, last_user, batch_size)) users = [row[0] for row in txn] for user in users: self._purge_account_data_for_user_txn(txn, user_id=user) if users: self.db_pool.updates._background_update_progress_txn( txn, "delete_account_data_for_deactivated_users", {"last_user": users[-1]}, ) return len(users) number_deleted = await self.db_pool.runInteraction( "_delete_account_data_for_deactivated_users", _delete_account_data_for_deactivated_users_txn, ) if number_deleted < batch_size: await self.db_pool.updates._end_background_update( "delete_account_data_for_deactivated_users") return number_deleted
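# A minimal, standalone sketch of the batching pattern used by
# `_delete_account_data_for_deactivated_users` above: process rows in key
# order, persist the last processed key as progress, and treat a short batch
# as the signal that the update is complete. `all_names` and `process` are
# hypothetical stand-ins for the users table and the per-user purge.
def run_in_batches(all_names, process, batch_size=100):
    last_user = ""
    while True:
        batch = sorted(n for n in all_names if n > last_user)[:batch_size]
        for name in batch:
            process(name)
        if len(batch) < batch_size:
            # Fewer rows than batch_size: caught up, so the background
            # update would be ended here.
            return
        # Otherwise record progress and continue from the last name seen.
        last_user = batch[-1]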