def _get_users_server_still_shares_room_with_txn(txn):
    sql = """
        SELECT state_key FROM current_state_events
        WHERE
            type = 'm.room.member'
            AND membership = 'join'
            AND %s
        GROUP BY state_key
    """

    clause, args = make_in_list_sql_clause(
        self.database_engine, "state_key", user_ids
    )

    txn.execute(sql % (clause,), args)

    return {row[0] for row in txn}
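# Every snippet in this section builds its WHERE clause with
# `make_in_list_sql_clause`. A minimal sketch of that helper's contract,
# assuming the ANY-array optimisation Synapse uses on PostgreSQL; the body
# below is an approximation of the real helper, not a verbatim copy:
from typing import Any, Iterable, List, Tuple

def make_in_list_sql_clause(
    database_engine, column: str, iterable: Iterable[Any]
) -> Tuple[str, List[Any]]:
    """Returns an SQL clause checking that `column` is in `iterable`, plus
    the bound arguments for it."""
    if database_engine.supports_using_any_list:
        # PostgreSQL: a single array parameter, e.g. "event_id = ANY(?)"
        return "%s = ANY(?)" % (column,), [list(iterable)]
    else:
        # SQLite: one placeholder per value, e.g. "event_id IN (?,?,?)"
        values = list(iterable)
        return "%s IN (%s)" % (column, ",".join("?" for _ in values)), values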
def _get_users_whose_devices_changed_txn(txn):
    changes = set()

    sql = """
        SELECT DISTINCT user_id FROM device_lists_stream
        WHERE stream_id > ?
        AND
    """

    for chunk in batch_iter(to_check, 100):
        clause, args = make_in_list_sql_clause(
            txn.database_engine, "user_id", chunk
        )
        txn.execute(sql + clause, (from_key,) + tuple(args))
        changes.update(user_id for user_id, in txn)

    return changes
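# `batch_iter`, used above and throughout this section, chunks an iterable so
# a single query never exceeds the database's parameter limit. Roughly
# equivalent to the real helper:
from itertools import islice

def batch_iter(iterable, size):
    """Yields tuples of up to `size` items from `iterable` until exhausted."""
    sourceiter = iter(iterable)
    # iter(callable, sentinel) keeps calling until the callable returns ()
    return iter(lambda: tuple(islice(sourceiter, size)), ())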
def graph_to_linear(txn):
    clause, args = make_in_list_sql_clause(
        self.database_engine, "event_id", event_ids
    )

    sql = """
        SELECT event_id FROM events
        WHERE room_id = ? AND stream_ordering IN (
            SELECT max(stream_ordering) FROM events WHERE %s
        )
    """ % (
        clause,
    )

    txn.execute(sql, [room_id] + list(args))
    rows = txn.fetchall()
    if rows:
        return rows[0][0]
    else:
        raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,))
def _update_presence_txn(self, txn, stream_orderings, presence_states):
    for stream_id, state in zip(stream_orderings, presence_states):
        txn.call_after(
            self.presence_stream_cache.entity_has_changed, state.user_id, stream_id
        )
        txn.call_after(self._get_presence_for_user.invalidate, (state.user_id,))

    # Delete old rows to stop database from getting really big.
    # Note: `stream_id` here is left over from the loop above, i.e. the
    # last (newest) stream ID in this batch.
    sql = "DELETE FROM presence_stream WHERE stream_id < ? AND "

    for states in batch_iter(presence_states, 50):
        clause, args = make_in_list_sql_clause(
            self.database_engine, "user_id", [s.user_id for s in states]
        )
        txn.execute(sql + clause, [stream_id] + list(args))

    # Actually insert new rows
    self.db_pool.simple_insert_many_txn(
        txn,
        table="presence_stream",
        keys=(
            "stream_id",
            "user_id",
            "state",
            "last_active_ts",
            "last_federation_update_ts",
            "last_user_sync_ts",
            "status_msg",
            "currently_active",
            "instance_name",
        ),
        values=[
            (
                stream_id,
                state.user_id,
                state.state,
                state.last_active_ts,
                state.last_federation_update_ts,
                state.last_user_sync_ts,
                state.status_msg,
                state.currently_active,
                self._instance_name,
            )
            for stream_id, state in zip(stream_orderings, presence_states)
        ],
    )
def get_last_receipt_for_user_txn(
    self,
    txn: LoggingTransaction,
    user_id: str,
    room_id: str,
    receipt_types: Collection[str],
) -> Optional[Tuple[str, int]]:
    """
    Fetch the event ID and stream_ordering for the latest receipt in a room
    with one of the given receipt types.

    Args:
        user_id: The user to fetch receipts for.
        room_id: The room ID to fetch the receipt for.
        receipt_types: The receipt types to fetch.

    Returns:
        The latest receipt, if one exists.
    """
    clause, args = make_in_list_sql_clause(
        self.database_engine, "receipt_type", receipt_types
    )

    sql = f"""
        SELECT event_id, stream_ordering
        FROM receipts_linearized
        INNER JOIN events USING (room_id, event_id)
        WHERE {clause}
        AND user_id = ?
        AND room_id = ?
        ORDER BY stream_ordering DESC
        LIMIT 1
    """

    args.extend((user_id, room_id))
    txn.execute(sql, args)

    return cast(Optional[Tuple[str, int]], txn.fetchone())
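# Hypothetical call site: `*_txn` methods like the one above don't run on
# their own; the database pool supplies the LoggingTransaction and forwards
# the remaining arguments. Something along these lines (the store variable
# and user/room IDs are made up):
async def example_get_last_read_receipt(store):
    return await store.db_pool.runInteraction(
        "get_last_receipt_for_user",
        store.get_last_receipt_for_user_txn,
        "@alice:example.com",
        "!room:example.com",
        ("m.read", "m.read.private"),
    )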
def _graph_to_linear(
    self, txn: LoggingTransaction, room_id: str, event_ids: List[str]
) -> str:
    """
    Generate a linearized event from a list of events (i.e. a list of forward
    extremities in the room).

    This should allow for calculation of the correct read receipt even if
    servers have different event ordering.

    Args:
        txn: The transaction
        room_id: The room ID the events are in.
        event_ids: The list of event IDs to linearize.

    Returns:
        The linearized event ID.
    """
    # TODO: Make this better.
    clause, args = make_in_list_sql_clause(
        self.database_engine, "event_id", event_ids
    )

    sql = """
        SELECT event_id FROM events
        WHERE room_id = ? AND stream_ordering IN (
            SELECT max(stream_ordering) FROM events WHERE %s
        )
    """ % (
        clause,
    )

    txn.execute(sql, [room_id] + list(args))
    rows = txn.fetchall()
    if rows:
        return rows[0][0]
    else:
        raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,))
def _get_rooms_for_local_user_where_membership_is_txn(
    self, txn, user_id, membership_list
):
    # Paranoia check.
    if not self.hs.is_mine_id(user_id):
        raise Exception(
            "Cannot call 'get_rooms_for_local_user_where_membership_is' "
            "on non-local user %r" % (user_id,),
        )

    clause, args = make_in_list_sql_clause(
        self.database_engine, "c.membership", membership_list
    )

    sql = """
        SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering
        FROM local_current_membership AS c
        INNER JOIN events AS e USING (room_id, event_id)
        WHERE
            user_id = ?
            AND %s
    """ % (
        clause,
    )

    txn.execute(sql, (user_id, *args))
    results = [RoomsForUser(**r) for r in self.db.cursor_to_dict(txn)]

    return results
def _get_auth_chain_difference_txn(
    self, txn, state_sets: List[Set[str]]
) -> Set[str]:
    # Algorithm Description
    # ~~~~~~~~~~~~~~~~~~~~~
    #
    # The idea here is to basically walk the auth graph of each state set in
    # tandem, keeping track of which auth events are reachable by each state
    # set. If we reach an auth event we've already visited (via a different
    # state set) then we mark that auth event and all ancestors as reachable
    # by the state set. This requires that we keep track of the auth chains
    # in memory.
    #
    # Doing it in such a way means that we can stop early if all auth
    # events we're currently walking are reachable by all state sets.
    #
    # *Note*: We can't stop walking an event's auth chain if it is reachable
    # by all state sets. This is because other auth chains we're walking
    # might be reachable only via the original auth chain. For example,
    # given the following auth chain:
    #
    #       A -> C -> D -> E
    #           /         /
    #       B -´---------´
    #
    # and state sets {A} and {B} then walking the auth chains of A and B
    # would immediately show that C is reachable by both. However, if we
    # stopped at C then we'd only reach E via the auth chain of B and so E
    # would erroneously get included in the returned difference.
    #
    # The other thing that we do is limit the number of auth chains we walk
    # at once, due to practical limits (i.e. we can only query the database
    # with a limited set of parameters). We pick the auth chains we walk
    # each iteration based on their depth, in the hope that events with a
    # lower depth are likely reachable by those with higher depths.
    #
    # We could use any ordering that we believe would give a rough
    # topological ordering, e.g. origin server timestamp. If the ordering
    # chosen is not topological then the algorithm still produces the right
    # result, but perhaps a bit more inefficiently. This is why it is safe
    # to use "depth" here.

    initial_events = set(state_sets[0]).union(*state_sets[1:])

    # Dict from events in auth chains to which sets *cannot* reach them.
    # I.e. if the set is empty then all sets can reach the event.
    event_to_missing_sets = {
        event_id: {i for i, a in enumerate(state_sets) if event_id not in a}
        for event_id in initial_events
    }

    # The sorted list of events whose auth chains we should walk.
    search = []  # type: List[Tuple[int, str]]

    # We need to get the depth of the initial events for sorting purposes.
    sql = """
        SELECT depth, event_id FROM events
        WHERE %s
    """
    # the list can be huge, so let's avoid looking them all up in one massive
    # query.
    for batch in batch_iter(initial_events, 1000):
        clause, args = make_in_list_sql_clause(
            txn.database_engine, "event_id", batch
        )
        txn.execute(sql % (clause,), args)

        # I think building a temporary list with fetchall is more efficient
        # than just `search.extend(txn)`, but this is unconfirmed
        search.extend(txn.fetchall())

    # sort by depth
    search.sort()

    # Map from event to its auth events
    event_to_auth_events = {}  # type: Dict[str, Set[str]]

    base_sql = """
        SELECT a.event_id, auth_id, depth
        FROM event_auth AS a
        INNER JOIN events AS e ON (e.event_id = a.auth_id)
        WHERE
    """

    while search:
        # Check whether all our current walks are reachable by all state
        # sets. If so we can bail.
        if all(not event_to_missing_sets[eid] for _, eid in search):
            break

        # Fetch the auth events and their depths of the N last events we're
        # currently walking, either from cache or DB.
        search, chunk = search[:-100], search[-100:]

        found = []  # Results found  # type: List[Tuple[str, str, int]]
        to_fetch = []  # Event IDs to fetch from DB  # type: List[str]

        for _, event_id in chunk:
            res = self._event_auth_cache.get(event_id)
            if res is None:
                to_fetch.append(event_id)
            else:
                found.extend(
                    (event_id, auth_id, depth) for auth_id, depth in res
                )

        if to_fetch:
            clause, args = make_in_list_sql_clause(
                txn.database_engine, "a.event_id", to_fetch
            )
            txn.execute(base_sql + clause, args)

            # We parse the results and add them to the `found` list and the
            # cache (note we need to batch up the results by event ID before
            # adding to the cache).
            to_cache = {}
            for event_id, auth_event_id, auth_event_depth in txn:
                to_cache.setdefault(event_id, []).append(
                    (auth_event_id, auth_event_depth)
                )
                found.append((event_id, auth_event_id, auth_event_depth))

            for event_id, auth_events in to_cache.items():
                self._event_auth_cache.set(event_id, auth_events)

        for event_id, auth_event_id, auth_event_depth in found:
            event_to_auth_events.setdefault(event_id, set()).add(auth_event_id)

            sets = event_to_missing_sets.get(auth_event_id)
            if sets is None:
                # First time we're seeing this event, so we add it to the
                # queue of things to fetch.
                search.append((auth_event_depth, auth_event_id))

                # Assume that this event is unreachable from any of the
                # state sets until proven otherwise
                sets = event_to_missing_sets[auth_event_id] = set(
                    range(len(state_sets))
                )
            else:
                # We've previously seen this event, so look up its auth
                # events and recursively mark all ancestors as reachable
                # by the current event's state set.
                a_ids = event_to_auth_events.get(auth_event_id)
                while a_ids:
                    new_aids = set()
                    for a_id in a_ids:
                        event_to_missing_sets[a_id].intersection_update(
                            event_to_missing_sets[event_id]
                        )

                        b = event_to_auth_events.get(a_id)
                        if b:
                            new_aids.update(b)

                    a_ids = new_aids

            # Mark that the auth event is reachable by the appropriate sets.
            sets.intersection_update(event_to_missing_sets[event_id])

        search.sort()

    # Return all events where not all sets can reach them.
    return {eid for eid, n in event_to_missing_sets.items() if n}
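# A self-contained toy model of the calculation above, using the A..E graph
# from the algorithm comment. Here an in-memory dict stands in for the DB
# walk, and the early-stopping and batching are omitted; `auth_graph` and the
# helper names are illustrative only.
auth_graph = {
    "A": {"C"},
    "B": {"C", "E"},
    "C": {"D"},
    "D": {"E"},
    "E": set(),
}

def auth_chain(event_id):
    # All ancestors of event_id in the auth graph (excluding itself).
    chain, stack = set(), list(auth_graph[event_id])
    while stack:
        eid = stack.pop()
        if eid not in chain:
            chain.add(eid)
            stack.extend(auth_graph[eid])
    return chain

def auth_chain_difference(state_sets):
    # Union of the per-set auth chains minus their intersection.
    chains = [set().union(*(auth_chain(e) for e in s)) for s in state_sets]
    return set.union(*chains) - set.intersection(*chains)

# Both {A} and {B} reach C, D and E, so the difference is empty; stopping
# the walk at C would wrongly report E as part of the difference.
assert auth_chain_difference([{"A"}, {"B"}]) == set()
# A non-empty example: only {A} reaches C and D.
assert auth_chain_difference([{"A"}, {"D"}]) == {"C", "D"}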
def _cleanup_extremities_bg_update_txn(txn):
    # The set of extremity event IDs that we're checking this round
    original_set = set()

    # A dict[str, set[str]] of event ID to their prev events.
    graph = {}

    # The set of descendants of the original set that are not rejected
    # nor soft-failed. Ancestors of these events should be removed
    # from the forward extremities table.
    non_rejected_leaves = set()

    # Set of event IDs that have been soft failed, and for which we
    # should check if they have descendants which haven't been soft
    # failed.
    soft_failed_events_to_lookup = set()

    # First, we get `batch_size` events from the table, pulling out
    # their successor events, if any, and the successor events'
    # rejection status.
    txn.execute(
        """SELECT prev_event_id, event_id, internal_metadata,
            rejections.event_id IS NOT NULL, events.outlier
        FROM (
            SELECT event_id AS prev_event_id
            FROM _extremities_to_check
            LIMIT ?
        ) AS f
        LEFT JOIN event_edges USING (prev_event_id)
        LEFT JOIN events USING (event_id)
        LEFT JOIN event_json USING (event_id)
        LEFT JOIN rejections USING (event_id)
        """,
        (batch_size,),
    )

    for prev_event_id, event_id, metadata, rejected, outlier in txn:
        original_set.add(prev_event_id)

        if not event_id or outlier:
            # Common case where the forward extremity doesn't have any
            # descendants.
            continue

        graph.setdefault(event_id, set()).add(prev_event_id)

        soft_failed = False
        if metadata:
            soft_failed = db_to_json(metadata).get("soft_failed")

        if soft_failed or rejected:
            soft_failed_events_to_lookup.add(event_id)
        else:
            non_rejected_leaves.add(event_id)

    # Now we recursively check all the soft-failed descendants we
    # found above in the same way, until we have nothing left to
    # check.
    while soft_failed_events_to_lookup:
        # We only want to do 100 at a time, so we split given list
        # into two.
        batch = list(soft_failed_events_to_lookup)
        to_check, to_defer = batch[:100], batch[100:]
        soft_failed_events_to_lookup = set(to_defer)

        sql = """SELECT prev_event_id, event_id, internal_metadata,
            rejections.event_id IS NOT NULL
            FROM event_edges
            INNER JOIN events USING (event_id)
            INNER JOIN event_json USING (event_id)
            LEFT JOIN rejections USING (event_id)
            WHERE
                NOT events.outlier
                AND
        """
        clause, args = make_in_list_sql_clause(
            self.database_engine, "prev_event_id", to_check
        )
        txn.execute(sql + clause, list(args))

        for prev_event_id, event_id, metadata, rejected in txn:
            if event_id in graph:
                # Already handled this event previously, but we still
                # want to record the edge.
                graph[event_id].add(prev_event_id)
                continue

            graph[event_id] = {prev_event_id}

            soft_failed = db_to_json(metadata).get("soft_failed")
            if soft_failed or rejected:
                soft_failed_events_to_lookup.add(event_id)
            else:
                non_rejected_leaves.add(event_id)

    # We have a set of non-soft-failed descendants, so we recurse up
    # the graph to find all ancestors and add them to the set of event
    # IDs that we can delete from forward extremities table.
    to_delete = set()
    while non_rejected_leaves:
        event_id = non_rejected_leaves.pop()
        prev_event_ids = graph.get(event_id, set())
        non_rejected_leaves.update(prev_event_ids)
        to_delete.update(prev_event_ids)

    to_delete.intersection_update(original_set)

    deleted = self.db_pool.simple_delete_many_txn(
        txn=txn,
        table="event_forward_extremities",
        column="event_id",
        iterable=to_delete,
        keyvalues={},
    )

    logger.info(
        "Deleted %d forward extremities of %d checked, to clean up #5269",
        deleted,
        len(original_set),
    )

    if deleted:
        # We now need to invalidate the caches of these rooms
        rows = self.db_pool.simple_select_many_txn(
            txn,
            table="events",
            column="event_id",
            iterable=to_delete,
            keyvalues={},
            retcols=("room_id",),
        )
        room_ids = {row["room_id"] for row in rows}
        for room_id in room_ids:
            txn.call_after(
                self.get_latest_event_ids_in_room.invalidate, (room_id,)
            )

    self.db_pool.simple_delete_many_txn(
        txn=txn,
        table="_extremities_to_check",
        column="event_id",
        iterable=original_set,
        keyvalues={},
    )

    return len(original_set)
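# In the real store this txn function is a closure inside a background-update
# handler. A sketch of the driving wrapper; the update name matches the
# "#5269" cleanup this belongs to, but the wrapper body itself is an
# assumption (the real one also drops the scratch table when finished):
async def _cleanup_extremities_bg_update(self, progress, batch_size):
    # Run one batch inside a transaction; the txn function above reads
    # `batch_size` from this scope.
    num_handled = await self.db_pool.runInteraction(
        "_cleanup_extremities_bg_update", _cleanup_extremities_bg_update_txn
    )

    if not num_handled:
        # Nothing left in _extremities_to_check, so mark the update done.
        await self.db_pool.updates._end_background_update(
            "delete_soft_failed_extremities"
        )

    return num_handled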
def search_rooms(self, room_ids, search_term, keys, limit, pagination_token=None):
    """Performs a full text search over events with given keys.

    Args:
        room_ids (list): The room_ids to search in
        search_term (str): Search term to search for
        keys (list): List of keys to search in, currently supports
            "content.body", "content.name", "content.topic"
        limit (int): The maximum number of results to return
        pagination_token (str): A pagination token previously returned

    Returns:
        list of dicts
    """
    clauses = []

    search_query = _parse_query(self.database_engine, search_term)

    args = []

    # Make sure we don't explode because the person is in too many rooms.
    # We filter the results below regardless.
    if len(room_ids) < 500:
        clause, args = make_in_list_sql_clause(
            self.database_engine, "room_id", room_ids
        )
        clauses = [clause]

    local_clauses = []
    for key in keys:
        local_clauses.append("key = ?")
        args.append(key)

    clauses.append("(%s)" % (" OR ".join(local_clauses),))

    # take copies of the current args and clauses lists, before adding
    # pagination clauses to main query.
    count_args = list(args)
    count_clauses = list(clauses)

    if pagination_token:
        try:
            origin_server_ts, stream = pagination_token.split(",")
            origin_server_ts = int(origin_server_ts)
            stream = int(stream)
        except Exception:
            raise SynapseError(400, "Invalid pagination token")

        clauses.append(
            "(origin_server_ts < ?"
            " OR (origin_server_ts = ? AND stream_ordering < ?))"
        )
        args.extend([origin_server_ts, origin_server_ts, stream])

    if isinstance(self.database_engine, PostgresEngine):
        sql = (
            "SELECT ts_rank_cd(vector, to_tsquery('english', ?)) as rank,"
            " origin_server_ts, stream_ordering, room_id, event_id"
            " FROM event_search"
            " WHERE vector @@ to_tsquery('english', ?) AND "
        )
        args = [search_query, search_query] + args

        count_sql = (
            "SELECT room_id, count(*) as count FROM event_search"
            " WHERE vector @@ to_tsquery('english', ?) AND "
        )
        count_args = [search_query] + count_args
    elif isinstance(self.database_engine, Sqlite3Engine):
        # We use CROSS JOIN here to ensure we use the right indexes.
        # https://sqlite.org/optoverview.html#crossjoin
        #
        # We want to use the full text search index on event_search to
        # extract all possible matches first, then lookup those matches
        # in the events table to get the topological ordering. We need
        # to use the indexes in this order because sqlite refuses to
        # MATCH unless it uses the full text search index
        sql = (
            "SELECT rank(matchinfo) as rank, room_id, event_id,"
            " origin_server_ts, stream_ordering"
            " FROM (SELECT key, event_id, matchinfo(event_search) as matchinfo"
            " FROM event_search"
            " WHERE value MATCH ?"
            " )"
            " CROSS JOIN events USING (event_id)"
            " WHERE "
        )
        args = [search_query] + args

        count_sql = (
            "SELECT room_id, count(*) as count FROM event_search"
            " WHERE value MATCH ? AND "
        )
        count_args = [search_query] + count_args
    else:
        # This should be unreachable.
        raise Exception("Unrecognized database engine")

    sql += " AND ".join(clauses)
    count_sql += " AND ".join(count_clauses)

    # We add an arbitrary limit here to ensure we don't try to pull the
    # entire table from the database.
    if isinstance(self.database_engine, PostgresEngine):
        sql += (
            " ORDER BY origin_server_ts DESC NULLS LAST,"
            " stream_ordering DESC NULLS LAST LIMIT ?"
        )
    elif isinstance(self.database_engine, Sqlite3Engine):
        sql += " ORDER BY origin_server_ts DESC, stream_ordering DESC LIMIT ?"
    else:
        raise Exception("Unrecognized database engine")

    args.append(limit)

    results = yield self.db.execute(
        "search_rooms", self.db.cursor_to_dict, sql, *args
    )

    results = list(filter(lambda row: row["room_id"] in room_ids, results))

    # We set redact_behaviour to BLOCK here to prevent redacted events being
    # returned in search results (which is a data leak)
    events = yield self.get_events_as_list(
        [r["event_id"] for r in results],
        redact_behaviour=EventRedactBehaviour.BLOCK,
    )

    event_map = {ev.event_id: ev for ev in events}

    highlights = None
    if isinstance(self.database_engine, PostgresEngine):
        highlights = yield self._find_highlights_in_postgres(search_query, events)

    count_sql += " GROUP BY room_id"

    count_results = yield self.db.execute(
        "search_rooms_count", self.db.cursor_to_dict, count_sql, *count_args
    )

    count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)

    return {
        "results": [
            {
                "event": event_map[r["event_id"]],
                "rank": r["rank"],
                "pagination_token": "%s,%s"
                % (r["origin_server_ts"], r["stream_ordering"]),
            }
            for r in results
            if r["event_id"] in event_map
        ],
        "highlights": highlights,
        "count": count,
    }
def search_msgs(self, room_ids, search_term, keys):
    """Performs a full text search over events with given keys.

    Args:
        room_ids (list): List of room ids to search in
        search_term (str): Search term to search for
        keys (list): List of keys to search in, currently supports
            "content.body", "content.name", "content.topic"

    Returns:
        list of dicts
    """
    clauses = []

    search_query = _parse_query(self.database_engine, search_term)

    args = []

    # Make sure we don't explode because the person is in too many rooms.
    # We filter the results below regardless.
    if len(room_ids) < 500:
        clause, args = make_in_list_sql_clause(
            self.database_engine, "room_id", room_ids
        )
        clauses = [clause]

    local_clauses = []
    for key in keys:
        local_clauses.append("key = ?")
        args.append(key)

    clauses.append("(%s)" % (" OR ".join(local_clauses),))

    count_args = args
    count_clauses = clauses

    if isinstance(self.database_engine, PostgresEngine):
        sql = (
            "SELECT ts_rank_cd(vector, to_tsquery('english', ?)) AS rank,"
            " room_id, event_id"
            " FROM event_search"
            " WHERE vector @@ to_tsquery('english', ?)"
        )
        args = [search_query, search_query] + args

        count_sql = (
            "SELECT room_id, count(*) as count FROM event_search"
            " WHERE vector @@ to_tsquery('english', ?)"
        )
        count_args = [search_query] + count_args
    elif isinstance(self.database_engine, Sqlite3Engine):
        sql = (
            "SELECT rank(matchinfo(event_search)) as rank, room_id, event_id"
            " FROM event_search"
            " WHERE value MATCH ?"
        )
        args = [search_query] + args

        count_sql = (
            "SELECT room_id, count(*) as count FROM event_search"
            " WHERE value MATCH ?"
        )
        count_args = [search_query] + count_args
    else:
        # This should be unreachable.
        raise Exception("Unrecognized database engine")

    for clause in clauses:
        sql += " AND " + clause

    for clause in count_clauses:
        count_sql += " AND " + clause

    # We add an arbitrary limit here to ensure we don't try to pull the
    # entire table from the database.
    sql += " ORDER BY rank DESC LIMIT 500"

    results = yield self.db.execute(
        "search_msgs", self.db.cursor_to_dict, sql, *args
    )

    results = list(filter(lambda row: row["room_id"] in room_ids, results))

    # We set redact_behaviour to BLOCK here to prevent redacted events being
    # returned in search results (which is a data leak)
    events = yield self.get_events_as_list(
        [r["event_id"] for r in results],
        redact_behaviour=EventRedactBehaviour.BLOCK,
    )

    event_map = {ev.event_id: ev for ev in events}

    highlights = None
    if isinstance(self.database_engine, PostgresEngine):
        highlights = yield self._find_highlights_in_postgres(search_query, events)

    count_sql += " GROUP BY room_id"

    count_results = yield self.db.execute(
        "search_rooms_count", self.db.cursor_to_dict, count_sql, *count_args
    )

    count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)

    return {
        "results": [
            {"event": event_map[r["event_id"]], "rank": r["rank"]}
            for r in results
            if r["event_id"] in event_map
        ],
        "highlights": highlights,
        "count": count,
    }
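# Both search functions above rely on `_parse_query` to turn the raw search
# term into an engine-specific full-text query. A rough sketch of what it
# produces, prefix-matching each word; the exact tokenisation here is a
# simplification:
import re

def _parse_query(database_engine, search_term):
    # Pull out the individual words, discarding any non-word characters.
    results = re.findall(r"([\w\-]+)", search_term, re.UNICODE)

    if isinstance(database_engine, PostgresEngine):
        # to_tsquery syntax, e.g. "foo:* & bar:*"
        return " & ".join(result + ":*" for result in results)
    elif isinstance(database_engine, Sqlite3Engine):
        # SQLite FTS MATCH syntax, e.g. "foo* & bar*"
        return " & ".join(result + "*" for result in results)
    else:
        raise Exception("Unrecognized database engine")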
def _add_messages_to_local_device_inbox_txn(
    self, txn, stream_id, messages_by_user_then_device
):
    # Compatible method of performing an upsert
    sql = "SELECT stream_id FROM device_max_stream_id"

    txn.execute(sql)
    rows = txn.fetchone()
    if rows:
        db_stream_id = rows[0]
        if db_stream_id < stream_id:
            # Bump the existing stream_id up to the new one.
            txn.execute(
                "UPDATE device_max_stream_id SET stream_id = ?", (stream_id,)
            )
    else:
        # No rows, perform an insert
        txn.execute(
            "INSERT INTO device_max_stream_id (stream_id) VALUES (?)",
            (stream_id,),
        )

    local_by_user_then_device = {}
    for user_id, messages_by_device in messages_by_user_then_device.items():
        messages_json_for_user = {}
        devices = list(messages_by_device.keys())
        if len(devices) == 1 and devices[0] == "*":
            # Handle wildcard device_ids.
            sql = "SELECT device_id FROM devices WHERE user_id = ?"
            txn.execute(sql, (user_id,))
            message_json = json.dumps(messages_by_device["*"])
            for row in txn:
                # Add the message for all devices for this user on this
                # server.
                device = row[0]
                messages_json_for_user[device] = message_json
        else:
            if not devices:
                continue

            clause, args = make_in_list_sql_clause(
                txn.database_engine, "device_id", devices
            )
            sql = "SELECT device_id FROM devices WHERE user_id = ? AND " + clause

            # TODO: Maybe this needs to be done in batches if there are
            # too many local devices for a given user.
            txn.execute(sql, [user_id] + list(args))
            for row in txn:
                # Only insert into the local inbox if the device exists on
                # this server
                device = row[0]
                message_json = json.dumps(messages_by_device[device])
                messages_json_for_user[device] = message_json

        if messages_json_for_user:
            local_by_user_then_device[user_id] = messages_json_for_user

    if not local_by_user_then_device:
        return

    sql = (
        "INSERT INTO device_inbox"
        " (user_id, device_id, stream_id, message_json)"
        " VALUES (?,?,?,?)"
    )
    rows = []
    for user_id, messages_by_device in local_by_user_then_device.items():
        for device_id, message_json in messages_by_device.items():
            rows.append((user_id, device_id, stream_id, message_json))

    txn.executemany(sql, rows)
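# For reference, the expected shape of `messages_by_user_then_device` in the
# function above: per-user, per-device message payloads, where a "*" device
# ID fans the message out to all of the user's local devices. The user and
# device names here are illustrative only:
messages_by_user_then_device = {
    "@alice:example.com": {
        "*": {"type": "m.room_key_request", "content": {}},
    },
    "@bob:example.com": {
        "DEVICEID1": {"type": "m.room_key_request", "content": {}},
    },
}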
def _get_auth_chain_difference_using_cover_index_txn(
    self, txn: Cursor, room_id: str, state_sets: List[Set[str]]
) -> Set[str]:
    """Calculates the auth chain difference using the chain index.

    See docs/auth_chain_difference_algorithm.md for details
    """

    # First we look up the chain ID/sequence numbers for all the events, and
    # work out the chain/sequence numbers reachable from each state set.

    initial_events = set(state_sets[0]).union(*state_sets[1:])

    # Map from event_id -> (chain ID, seq no)
    chain_info = {}  # type: Dict[str, Tuple[int, int]]

    # Map from chain ID -> seq no -> event Id
    chain_to_event = {}  # type: Dict[int, Dict[int, str]]

    # All the chains that we've found that are reachable from the state
    # sets.
    seen_chains = set()  # type: Set[int]

    sql = """
        SELECT event_id, chain_id, sequence_number
        FROM event_auth_chains
        WHERE %s
    """
    for batch in batch_iter(initial_events, 1000):
        clause, args = make_in_list_sql_clause(
            txn.database_engine, "event_id", batch
        )
        txn.execute(sql % (clause,), args)

        for event_id, chain_id, sequence_number in txn:
            chain_info[event_id] = (chain_id, sequence_number)
            seen_chains.add(chain_id)
            chain_to_event.setdefault(chain_id, {})[sequence_number] = event_id

    # Check that we actually have a chain ID for all the events.
    events_missing_chain_info = initial_events.difference(chain_info)
    if events_missing_chain_info:
        # This can happen due to e.g. downgrade/upgrade of the server. We
        # raise an exception and fall back to the previous algorithm.
        logger.info(
            "Unexpectedly found that events don't have chain IDs in room %s: %s",
            room_id,
            events_missing_chain_info,
        )
        raise _NoChainCoverIndex(room_id)

    # Corresponds to `state_sets`, except as a map from chain ID to max
    # sequence number reachable from the state set.
    set_to_chain = []  # type: List[Dict[int, int]]
    for state_set in state_sets:
        chains = {}  # type: Dict[int, int]
        set_to_chain.append(chains)

        for event_id in state_set:
            chain_id, seq_no = chain_info[event_id]

            chains[chain_id] = max(seq_no, chains.get(chain_id, 0))

    # Now we look up all links for the chains we have, adding chains to
    # set_to_chain that are reachable from each set.
    sql = """
        SELECT
            origin_chain_id, origin_sequence_number,
            target_chain_id, target_sequence_number
        FROM event_auth_chain_links
        WHERE %s
    """

    # (We need to take a copy of `seen_chains` as we want to mutate it in
    # the loop)
    for batch in batch_iter(set(seen_chains), 1000):
        clause, args = make_in_list_sql_clause(
            txn.database_engine, "origin_chain_id", batch
        )
        txn.execute(sql % (clause,), args)

        for (
            origin_chain_id,
            origin_sequence_number,
            target_chain_id,
            target_sequence_number,
        ) in txn:
            for chains in set_to_chain:
                # chains are only reachable if the origin sequence number of
                # the link is less than the max sequence number in the
                # origin chain.
                if origin_sequence_number <= chains.get(origin_chain_id, 0):
                    chains[target_chain_id] = max(
                        target_sequence_number,
                        chains.get(target_chain_id, 0),
                    )

            seen_chains.add(target_chain_id)

    # Now for each chain we figure out the maximum sequence number reachable
    # from *any* state set and the minimum sequence number reachable from
    # *all* state sets. Events in that range are in the auth chain
    # difference.
    result = set()

    # Mapping from chain ID to the range of sequence numbers that should be
    # pulled from the database.
    chain_to_gap = {}  # type: Dict[int, Tuple[int, int]]

    for chain_id in seen_chains:
        min_seq_no = min(chains.get(chain_id, 0) for chains in set_to_chain)
        max_seq_no = max(chains.get(chain_id, 0) for chains in set_to_chain)

        if min_seq_no < max_seq_no:
            # We have a non empty gap, try and fill it from the events that
            # we have, otherwise add them to the list of gaps to pull out
            # from the DB.
            for seq_no in range(min_seq_no + 1, max_seq_no + 1):
                event_id = chain_to_event.get(chain_id, {}).get(seq_no)
                if event_id:
                    result.add(event_id)
                else:
                    chain_to_gap[chain_id] = (min_seq_no, max_seq_no)
                    break

    if not chain_to_gap:
        # If there are no gaps to fetch, we're done!
        return result

    if isinstance(self.database_engine, PostgresEngine):
        # We can use `execute_values` to efficiently fetch the gaps when
        # using postgres.
        sql = """
            SELECT event_id
            FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, min_seq, max_seq)
            WHERE
                c.chain_id = l.chain_id
                AND min_seq < sequence_number
                AND sequence_number <= max_seq
        """

        args = [
            (chain_id, min_no, max_no)
            for chain_id, (min_no, max_no) in chain_to_gap.items()
        ]

        rows = txn.execute_values(sql, args)
        result.update(r for r, in rows)
    else:
        # For SQLite we just fall back to doing a noddy for loop.
        sql = """
            SELECT event_id FROM event_auth_chains
            WHERE chain_id = ? AND ? < sequence_number AND sequence_number <= ?
        """
        for chain_id, (min_no, max_no) in chain_to_gap.items():
            txn.execute(sql, (chain_id, min_no, max_no))
            result.update(r for r, in txn)

    return result
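# Worked example of the min/max gap rule above: if one state set reaches
# chain 7 up to sequence number 2 and another up to 5, then the events at
# sequence numbers 3..5 on chain 7 are reachable from some but not all sets,
# so they belong in the difference. Toy values for illustration:
set_to_chain = [{7: 2}, {7: 5}]
min_seq_no = min(c.get(7, 0) for c in set_to_chain)  # 2
max_seq_no = max(c.get(7, 0) for c in set_to_chain)  # 5
assert (min_seq_no, max_seq_no) == (2, 5)  # fetch seq numbers 3..5 of chain 7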
def _fetch_event_rows(self, txn, event_ids):
    """Fetch event rows from the database

    Events which are not found are omitted from the result.

    The returned per-event dicts contain the following keys:

     * event_id (str)

     * json (str): json-encoded event structure

     * internal_metadata (str): json-encoded internal metadata dict

     * format_version (int|None): The format of the event. Hopefully one
       of EventFormatVersions. 'None' means the event predates
       EventFormatVersions (so the event is format V1).

     * rejected_reason (str|None): if the event was rejected, the reason
       why.

     * redactions (List[str]): a list of event-ids which (claim to) redact
       this event.

    Args:
        txn (twisted.enterprise.adbapi.Connection):
        event_ids (Iterable[str]): event IDs to fetch

    Returns:
        Dict[str, Dict]: a map from event id to event info.
    """
    event_dict = {}
    for evs in batch_iter(event_ids, 200):
        sql = (
            "SELECT "
            " e.event_id, "
            " e.internal_metadata,"
            " e.json,"
            " e.format_version, "
            " rej.reason "
            " FROM event_json as e"
            " LEFT JOIN rejections as rej USING (event_id)"
            " WHERE "
        )

        clause, args = make_in_list_sql_clause(
            txn.database_engine, "e.event_id", evs
        )

        txn.execute(sql + clause, args)

        for row in txn:
            event_id = row[0]
            event_dict[event_id] = {
                "event_id": event_id,
                "internal_metadata": row[1],
                "json": row[2],
                "format_version": row[3],
                "rejected_reason": row[4],
                "redactions": [],
            }

        # check for redactions
        redactions_sql = "SELECT event_id, redacts FROM redactions WHERE "

        clause, args = make_in_list_sql_clause(txn.database_engine, "redacts", evs)

        txn.execute(redactions_sql + clause, args)

        for (redacter, redacted) in txn:
            d = event_dict.get(redacted)
            if d:
                d["redactions"].append(redacter)

    return event_dict
def _get_rooms_for_user_where_membership_is_txn(
    self, txn, user_id, membership_list
):
    do_invite = Membership.INVITE in membership_list
    membership_list = [m for m in membership_list if m != Membership.INVITE]

    results = []
    if membership_list:
        if self._current_state_events_membership_up_to_date:
            clause, args = make_in_list_sql_clause(
                self.database_engine, "c.membership", membership_list
            )
            sql = """
                SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering
                FROM current_state_events AS c
                INNER JOIN events AS e USING (room_id, event_id)
                WHERE
                    c.type = 'm.room.member'
                    AND state_key = ?
                    AND %s
            """ % (
                clause,
            )
        else:
            clause, args = make_in_list_sql_clause(
                self.database_engine, "m.membership", membership_list
            )
            sql = """
                SELECT room_id, e.sender, m.membership, event_id, e.stream_ordering
                FROM current_state_events AS c
                INNER JOIN room_memberships AS m USING (room_id, event_id)
                INNER JOIN events AS e USING (room_id, event_id)
                WHERE
                    c.type = 'm.room.member'
                    AND state_key = ?
                    AND %s
            """ % (
                clause,
            )

        txn.execute(sql, (user_id, *args))
        results = [RoomsForUser(**r) for r in self.db.cursor_to_dict(txn)]

    if do_invite:
        sql = (
            "SELECT i.room_id, inviter, i.event_id, e.stream_ordering"
            " FROM local_invites as i"
            " INNER JOIN events as e USING (event_id)"
            " WHERE invitee = ? AND locally_rejected is NULL"
            " AND replaced_by is NULL"
        )

        txn.execute(sql, (user_id,))
        results.extend(
            RoomsForUser(
                room_id=r["room_id"],
                sender=r["inviter"],
                event_id=r["event_id"],
                stream_ordering=r["stream_ordering"],
                membership=Membership.INVITE,
            )
            for r in self.db.cursor_to_dict(txn)
        )

    return results
def _get_auth_chain_ids_using_cover_index_txn(
    self, txn: Cursor, room_id: str, event_ids: Collection[str], include_given: bool
) -> List[str]:
    """Calculates the auth chain IDs using the chain index."""

    # First we look up the chain ID/sequence numbers for the given events.

    initial_events = set(event_ids)

    # All the events that we've found that are reachable from the events.
    seen_events = set()  # type: Set[str]

    # A map from chain ID to max sequence number of the given events.
    event_chains = {}  # type: Dict[int, int]

    sql = """
        SELECT event_id, chain_id, sequence_number
        FROM event_auth_chains
        WHERE %s
    """
    for batch in batch_iter(initial_events, 1000):
        clause, args = make_in_list_sql_clause(
            txn.database_engine, "event_id", batch
        )
        txn.execute(sql % (clause,), args)

        for event_id, chain_id, sequence_number in txn:
            seen_events.add(event_id)
            event_chains[chain_id] = max(
                sequence_number, event_chains.get(chain_id, 0)
            )

    # Check that we actually have a chain ID for all the events.
    events_missing_chain_info = initial_events.difference(seen_events)
    if events_missing_chain_info:
        # This can happen due to e.g. downgrade/upgrade of the server. We
        # raise an exception and fall back to the previous algorithm.
        logger.info(
            "Unexpectedly found that events don't have chain IDs in room %s: %s",
            room_id,
            events_missing_chain_info,
        )
        raise _NoChainCoverIndex(room_id)

    # Now we look up all links for the chains we have, adding chains that
    # are reachable from any event.
    sql = """
        SELECT
            origin_chain_id, origin_sequence_number,
            target_chain_id, target_sequence_number
        FROM event_auth_chain_links
        WHERE %s
    """

    # A map from chain ID to max sequence number *reachable* from any event ID.
    chains = {}  # type: Dict[int, int]

    # Add all linked chains reachable from initial set of chains.
    for batch in batch_iter(event_chains, 1000):
        clause, args = make_in_list_sql_clause(
            txn.database_engine, "origin_chain_id", batch
        )
        txn.execute(sql % (clause,), args)

        for (
            origin_chain_id,
            origin_sequence_number,
            target_chain_id,
            target_sequence_number,
        ) in txn:
            # chains are only reachable if the origin sequence number of
            # the link is less than the max sequence number in the
            # origin chain.
            if origin_sequence_number <= event_chains.get(origin_chain_id, 0):
                chains[target_chain_id] = max(
                    target_sequence_number,
                    chains.get(target_chain_id, 0),
                )

    # Add the initial set of chains, excluding the sequence corresponding to
    # initial event.
    for chain_id, seq_no in event_chains.items():
        chains[chain_id] = max(seq_no - 1, chains.get(chain_id, 0))

    # Now for each chain we figure out the maximum sequence number reachable
    # from *any* event ID. Events with a sequence less than that are in the
    # auth chain.
    if include_given:
        results = initial_events
    else:
        results = set()

    if isinstance(self.database_engine, PostgresEngine):
        # We can use `execute_values` to efficiently fetch the gaps when
        # using postgres.
        sql = """
            SELECT event_id
            FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, max_seq)
            WHERE
                c.chain_id = l.chain_id
                AND sequence_number <= max_seq
        """

        rows = txn.execute_values(sql, chains.items())
        results.update(r for r, in rows)
    else:
        # For SQLite we just fall back to doing a noddy for loop.
        sql = """
            SELECT event_id FROM event_auth_chains
            WHERE chain_id = ? AND sequence_number <= ?
        """
        for chain_id, max_no in chains.items():
            txn.execute(sql, (chain_id, max_no))
            results.update(r for r, in txn)

    return list(results)
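# The cover-index invariant exploited above, in miniature: once links have
# been resolved, an event at (chain_id, seq) is in the auth chain exactly
# when the computed per-chain maximum is at least its sequence number. Toy
# values for illustration:
chains = {1: 3, 2: 1}  # chain 1 reachable up to seq 3, chain 2 up to seq 1
assert chains.get(1, 0) >= 2       # (1, 2) is in the auth chain
assert not chains.get(2, 0) >= 4   # (2, 4) is not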