def r(
    txn: LoggingTransaction,
) -> Tuple[Optional[List[str]], List[Tuple], List[Tuple]]:
    """Fetch the next batch of rows in each active direction.

    Returns:
        A tuple of (column headers or None, forward rows, backward rows).
        Headers are only available when at least one query returned rows.
        A direction is switched off (via the shared do_forward/do_backward
        single-element lists) once it yields no rows.
    """
    forward_rows: List[Tuple] = []
    backward_rows: List[Tuple] = []

    if do_forward[0]:
        txn.execute(forward_select, (forward_chunk, self.batch_size))
        forward_rows = txn.fetchall()
        if not forward_rows:
            # Exhausted going forwards; don't query this direction again.
            do_forward[0] = False

    if do_backward[0]:
        txn.execute(backward_select, (backward_chunk, self.batch_size))
        backward_rows = txn.fetchall()
        if not backward_rows:
            # Exhausted going backwards; don't query this direction again.
            do_backward[0] = False

    headers: Optional[List[str]] = None
    if forward_rows or backward_rows:
        # txn.description reflects the last executed query; both selects
        # are assumed to return the same columns.
        headers = [column[0] for column in txn.description]

    return headers, forward_rows, backward_rows
def _make_staging_area(txn: LoggingTransaction) -> None:
    """Create the temporary staging tables and seed them with the rooms
    and users that still need processing."""
    # Staging table of rooms together with their current-state event counts.
    txn.execute(
        "CREATE TABLE IF NOT EXISTS "
        + TEMP_TABLE
        + "_rooms(room_id TEXT NOT NULL, events BIGINT NOT NULL)"
    )
    # Staging table recording how far through the stream we have processed.
    txn.execute(
        "CREATE TABLE IF NOT EXISTS "
        + TEMP_TABLE
        + "_position(position TEXT NOT NULL)"
    )

    # Get rooms we want to process from the database.
    txn.execute(
        """
        SELECT room_id, count(*) FROM current_state_events
        GROUP BY room_id
        """
    )
    rooms = list(txn.fetchall())
    self.db_pool.simple_insert_many_txn(
        txn, TEMP_TABLE + "_rooms", keys=("room_id", "events"), values=rooms
    )
    # Free the potentially-large list before the next fetch.
    del rooms

    # Staging table of users still to be processed.
    txn.execute(
        "CREATE TABLE IF NOT EXISTS " + TEMP_TABLE + "_users(user_id TEXT NOT NULL)"
    )
    txn.execute("SELECT name FROM users")
    users = list(txn.fetchall())
    self.db_pool.simple_insert_many_txn(
        txn, TEMP_TABLE + "_users", keys=("user_id",), values=users
    )
def _graph_to_linear(
    self, txn: LoggingTransaction, room_id: str, event_ids: List[str]
) -> str:
    """
    Generate a linearized event from a list of events (i.e. a list of forward
    extremities in the room).

    This should allow for calculation of the correct read receipt even if
    servers have different event ordering.

    Args:
        txn: The transaction
        room_id: The room ID the events are in.
        event_ids: The list of event IDs to linearize.

    Returns:
        The linearized event ID.

    Raises:
        RuntimeError: if none of the given event IDs are found in the room.
    """
    # TODO: Make this better.
    clause, args = make_in_list_sql_clause(
        self.database_engine, "event_id", event_ids
    )

    # BUG FIX: both the outer and inner SELECT previously had no FROM
    # clause ("SELECT event_id WHERE room_id = ?"), which is invalid SQL
    # and fails at runtime on both SQLite and Postgres. Both now select
    # from the events table.
    sql = """
        SELECT event_id FROM events
        WHERE room_id = ? AND stream_ordering IN (
            SELECT max(stream_ordering) FROM events WHERE %s
        )
    """ % (
        clause,
    )

    txn.execute(sql, [room_id] + list(args))
    rows = txn.fetchall()
    if rows:
        # Pick the event with the highest stream ordering among the given IDs.
        return rows[0][0]
    else:
        raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,))
def f(
    txn: LoggingTransaction,
) -> List[Tuple[str, str, int, int, str, bool, str, int]]:
    """Fetch the user's most recent notifying push actions, newest first,
    optionally restricted to highlights and/or to actions before a given
    stream ordering."""
    # Build the optional filter clause; the parameter list must line up
    # with the placeholders in the final SQL.
    if before:
        before_clause = "AND epa.stream_ordering < ?"
        args = [user_id, before, limit]
    else:
        before_clause = ""
        args = [user_id, limit]

    if only_highlight:
        if before_clause:
            before_clause += " "
        before_clause += "AND epa.highlight = 1"

    # NB. This assumes event_ids are globally unique since
    # it makes the query easier to index
    sql = (
        "SELECT epa.event_id, epa.room_id,"
        " epa.stream_ordering, epa.topological_ordering,"
        " epa.actions, epa.highlight, epa.profile_tag, e.received_ts"
        " FROM event_push_actions epa, events e"
        " WHERE epa.event_id = e.event_id"
        " AND epa.user_id = ? %s"
        " AND epa.notif = 1"
        " ORDER BY epa.stream_ordering DESC"
        " LIMIT ?" % (before_clause,)
    )
    txn.execute(sql, args)
    return cast(
        List[Tuple[str, str, int, int, str, bool, str, int]], txn.fetchall()
    )
def get_after_receipt(
    txn: LoggingTransaction,
) -> List[Tuple[str, str, int, str, bool]]:
    """Return notifying push actions that come after the user's read
    receipt in each room, within the given stream-ordering window."""
    # Join each room's latest read-receipt position (rl) against the
    # user's push actions, keeping only actions past that position.
    sql = (
        "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
        " ep.highlight "
        " FROM ("
        " SELECT room_id,"
        " MAX(stream_ordering) as stream_ordering"
        " FROM events"
        " INNER JOIN receipts_linearized USING (room_id, event_id)"
        " WHERE receipt_type = 'm.read' AND user_id = ?"
        " GROUP BY room_id"
        ") AS rl,"
        " event_push_actions AS ep"
        " WHERE"
        " ep.room_id = rl.room_id"
        " AND ep.stream_ordering > rl.stream_ordering"
        " AND ep.user_id = ?"
        " AND ep.stream_ordering > ?"
        " AND ep.stream_ordering <= ?"
        " AND ep.notif = 1"
        " ORDER BY ep.stream_ordering ASC LIMIT ?"
    )
    query_args = [
        user_id,
        user_id,
        min_stream_ordering,
        max_stream_ordering,
        limit,
    ]
    txn.execute(sql, query_args)
    return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall())
def get_no_receipt(
    txn: LoggingTransaction,
) -> List[Tuple[str, str, int, str, bool, int]]:
    """Return notifying push actions in rooms where the user has no read
    receipt at all, within the given stream-ordering window."""
    sql = (
        "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
        " ep.highlight, e.received_ts"
        " FROM event_push_actions AS ep"
        " INNER JOIN events AS e USING (room_id, event_id)"
        " WHERE"
        " ep.room_id NOT IN ("
        " SELECT room_id FROM receipts_linearized"
        " WHERE receipt_type = 'm.read' AND user_id = ?"
        " GROUP BY room_id"
        " )"
        " AND ep.user_id = ?"
        " AND ep.stream_ordering > ?"
        " AND ep.stream_ordering <= ?"
        " AND ep.notif = 1"
        " ORDER BY ep.stream_ordering DESC LIMIT ?"
    )
    query_args = [
        user_id,
        user_id,
        min_stream_ordering,
        max_stream_ordering,
        limit,
    ]
    txn.execute(sql, query_args)
    return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall())
def _delete_old_ui_auth_sessions_txn(
    self, txn: LoggingTransaction, expiration_time: int
):
    """Delete UI auth sessions created at or before ``expiration_time``,
    together with their associated IP/user-agent and credential rows."""
    # Find the expired sessions.
    txn.execute(
        "SELECT session_id FROM ui_auth_sessions WHERE creation_time <= ?",
        [expiration_time],
    )
    session_ids = [row[0] for row in txn.fetchall()]

    # Delete dependent rows first (IP/user agents, then completed
    # credentials), and finally the sessions themselves.
    for table in (
        "ui_auth_sessions_ips",
        "ui_auth_sessions_credentials",
        "ui_auth_sessions",
    ):
        self.db_pool.simple_delete_many_txn(
            txn,
            table=table,
            column="session_id",
            iterable=session_ids,
            keyvalues={},
        )
def _mark_as_sent_devices_by_remote_txn(
    self, txn: LoggingTransaction, destination: str, stream_id: int
) -> None:
    """Record which users' device-list pokes have been successfully sent
    to the destination, then remove the sent pokes."""
    # We update device_lists_outbound_last_success with the successfully
    # poked users: for each user, the highest poke stream id at or below
    # the given stream_id.
    sql = """
        SELECT user_id, coalesce(max(o.stream_id), 0) FROM device_lists_outbound_pokes as o
        WHERE destination = ? AND o.stream_id <= ?
        GROUP BY user_id
    """
    txn.execute(sql, (destination, stream_id))
    sent_rows = txn.fetchall()

    self.db_pool.simple_upsert_many_txn(
        txn=txn,
        table="device_lists_outbound_last_success",
        key_names=("destination", "user_id"),
        key_values=[(destination, row[0]) for row in sent_rows],
        value_names=("stream_id",),
        value_values=[(row[1],) for row in sent_rows],
    )

    # Delete all sent outbound pokes.
    sql = """
        DELETE FROM device_lists_outbound_pokes
        WHERE destination = ? AND stream_id <= ?
    """
    txn.execute(sql, (destination, stream_id))
def _get_e2e_cross_signing_signatures_for_devices_txn(
    self, txn: LoggingTransaction, device_query: Iterable[Tuple[str, str]]
) -> List[Tuple[str, str, str, str]]:
    """Get cross-signing signatures for a given list of devices

    Returns signatures made by the owners of the devices.

    Args:
        txn: The transaction.
        device_query: An iterable of (user_id, device_id) pairs.

    Returns:
        a list of results; each entry in the list is a tuple of
        (user_id, key_id, target_device_id, signature).
    """
    signature_query_clauses = []
    signature_query_params = []

    for (user_id, device_id) in device_query:
        signature_query_clauses.append(
            "target_user_id = ? AND target_device_id = ? AND user_id = ?"
        )
        signature_query_params.extend([user_id, device_id, user_id])

    # BUG FIX: an empty device_query previously produced "WHERE " with no
    # predicate, which is a SQL syntax error. Short-circuit instead.
    if not signature_query_clauses:
        return []

    signature_sql = """
        SELECT user_id, key_id, target_device_id, signature
        FROM e2e_cross_signing_signatures WHERE %s
    """ % (
        " OR ".join("(" + q + ")" for q in signature_query_clauses)
    )

    txn.execute(signature_sql, signature_query_params)
    return cast(
        List[
            Tuple[
                str,
                str,
                str,
                str,
            ]
        ],
        txn.fetchall(),
    )
def reindex_search_txn(txn: LoggingTransaction) -> Tuple[int, bool]:
    """Backfill stream/origin_server_ts orderings onto a batch of
    event_search rows, working backwards through stream orderings.

    Returns:
        (number of rows updated, whether there is more work to do).
    """
    # Postgres-style UPDATE ... FROM with RETURNING so we can count how
    # many event_search rows were touched in this batch.
    sql = (
        "UPDATE event_search AS es SET stream_ordering = e.stream_ordering,"
        " origin_server_ts = e.origin_server_ts"
        " FROM events AS e"
        " WHERE e.event_id = es.event_id"
        " AND ? <= e.stream_ordering AND e.stream_ordering < ?"
        " RETURNING es.stream_ordering"
    )

    min_stream_id = max_stream_id - batch_size
    txn.execute(sql, (min_stream_id, max_stream_id))
    rows = txn.fetchall()

    if min_stream_id < target_min_stream_id:
        # We've reached the end.
        return len(rows), False

    progress = {
        "target_min_stream_id_inclusive": target_min_stream_id,
        "max_stream_id_exclusive": min_stream_id,
        "rows_inserted": rows_inserted + len(rows),
        "have_added_indexes": True,
    }
    self.db_pool.updates._background_update_progress_txn(
        txn, self.EVENT_SEARCH_ORDER_UPDATE_NAME, progress
    )

    return len(rows), True
def _get_next_batch(
    txn: LoggingTransaction,
) -> Optional[Sequence[Tuple[str, int]]]:
    """Pop the next set of rooms to process from the staging table.

    Returns None when there is nothing left to do. Also updates
    progress["remaining"] with the number of rooms still queued.
    """
    # Only fetch 250 rooms, so we don't fetch too many at once, even
    # if those 250 rooms have less than batch_size state events.
    sql = """
        SELECT room_id, events FROM %s
        ORDER BY events DESC
        LIMIT 250
    """ % (
        TEMP_TABLE + "_rooms",
    )
    txn.execute(sql)
    rooms_to_work_on = cast(List[Tuple[str, int]], txn.fetchall())

    if not rooms_to_work_on:
        return None

    # Get how many are left to process, so we can give status on how
    # far we are in processing.
    txn.execute("SELECT COUNT(*) FROM " + TEMP_TABLE + "_rooms")
    result = txn.fetchone()
    assert result is not None
    progress["remaining"] = result[0]

    return rooms_to_work_on
def get_updated_global_account_data_txn(
    txn: LoggingTransaction,
) -> List[Tuple[int, str, str]]:
    """Return up to `limit` global account-data rows with stream ids in
    the half-open interval (last_id, current_id], oldest first."""
    sql = (
        "SELECT stream_id, user_id, account_data_type"
        " FROM account_data WHERE ? < stream_id AND stream_id <= ?"
        " ORDER BY stream_id ASC LIMIT ?"
    )
    txn.execute(sql, (last_id, current_id, limit))
    return cast(List[Tuple[int, str, str]], txn.fetchall())
def r(txn: LoggingTransaction) -> Tuple[List[str], List[Tuple]]:
    """Run the prepared select and return its headers together with the
    rows whose "ts" column is older than yesterday."""
    txn.execute(select)
    rows = txn.fetchall()
    headers: List[str] = [column[0] for column in txn.description]
    # Filter in Python on the "ts" column so the select stays generic.
    ts_index = headers.index("ts")
    old_rows = [row for row in rows if row[ts_index] < yesterday]
    return headers, old_rows
def _get_applicable_edits_txn(txn: LoggingTransaction) -> Dict[str, str]:
    """Map each queried event ID to the event ID of its applicable edit."""
    clause, args = make_in_list_sql_clause(
        txn.database_engine, "relates_to_id", event_ids
    )
    # The final placeholder is the relation type being searched for.
    args.append(RelationTypes.REPLACE)

    txn.execute(sql % (clause,), args)
    rows = cast(Iterable[Tuple[str, str]], txn.fetchall())
    return dict(rows)
def get_all_updated_tags_txn(
    txn: LoggingTransaction,
) -> List[Tuple[int, str, str]]:
    """Return up to `limit` tag-revision rows with stream ids in the
    half-open interval (last_id, current_id], oldest first."""
    sql = (
        "SELECT stream_id, user_id, room_id"
        " FROM room_tags_revisions as r"
        " WHERE ? < stream_id AND stream_id <= ?"
        " ORDER BY stream_id ASC LIMIT ?"
    )
    txn.execute(sql, (last_id, current_id, limit))
    # mypy doesn't understand what the query is selecting.
    return cast(List[Tuple[int, str, str]], txn.fetchall())
def _get_max_topological_txn(self, txn: LoggingTransaction, room_id: str) -> int:
    """Return the highest topological ordering of any event in the room.

    Args:
        txn: The transaction.
        room_id: The room to query.

    Returns:
        The maximum topological ordering, or 0 if the room has no events.
    """
    txn.execute(
        "SELECT MAX(topological_ordering) FROM events WHERE room_id = ?",
        (room_id,),
    )
    rows = txn.fetchall()
    # BUG FIX: an aggregate query always returns one row, and MAX() over
    # zero matching events yields NULL, so rows[0][0] may be None. The
    # previous `rows[0][0] if rows else 0` would return None in that case
    # despite the declared int return type.
    if rows and rows[0][0] is not None:
        return rows[0][0]
    return 0
def _devices_last_seen_update_txn(txn: LoggingTransaction) -> int:
    """Backfill last_seen/ip/user_agent onto a batch of device rows from
    user_ips, resuming from (last_user_id, last_device_id).

    Returns the number of device rows updated in this batch; 0 means the
    background update is complete.
    """
    # This consists of two queries:
    #
    #   1. The sub-query searches for the next N devices and joins
    #      against user_ips to find the max last_seen associated with
    #      that device.
    #   2. The outer query then joins again against user_ips on
    #      user/device/last_seen. This *should* hopefully only
    #      return one row, but if it does return more than one then
    #      we'll just end up updating the same device row multiple
    #      times, which is fine.

    where_args: List[Union[str, int]]
    # Tuple comparison resumes pagination strictly after the last
    # processed (user_id, device_id) pair.
    where_clause, where_args = make_tuple_comparison_clause(
        [("user_id", last_user_id), ("device_id", last_device_id)],
    )

    sql = """
            SELECT
                last_seen, ip, user_agent, user_id, device_id
            FROM (
                SELECT
                    user_id, device_id, MAX(u.last_seen) AS last_seen
                FROM devices
                INNER JOIN user_ips AS u USING (user_id, device_id)
                WHERE %(where_clause)s
                GROUP BY user_id, device_id
                ORDER BY user_id ASC, device_id ASC
                LIMIT ?
            ) c
            INNER JOIN user_ips AS u USING (user_id, device_id, last_seen)
        """ % {
        "where_clause": where_clause
    }
    txn.execute(sql, where_args + [batch_size])

    rows = cast(List[Tuple[int, str, str, str, str]], txn.fetchall())
    if not rows:
        # Nothing left to process.
        return 0

    # NOTE: the SELECT column order (last_seen, ip, user_agent, user_id,
    # device_id) deliberately matches the UPDATE's parameter order below,
    # so each row tuple can be passed to execute_batch unchanged.
    sql = """
            UPDATE devices
            SET last_seen = ?, ip = ?, user_agent = ?
            WHERE user_id = ? AND device_id = ?
        """
    txn.execute_batch(sql, rows)

    # Record the last processed pair so the next batch resumes after it.
    _, _, _, user_id, device_id = rows[-1]
    self.db_pool.updates._background_update_progress_txn(
        txn,
        "devices_last_seen",
        {
            "last_user_id": user_id,
            "last_device_id": device_id
        },
    )

    return len(rows)
def _remove_hidden_devices_from_device_inbox_txn(
    txn: LoggingTransaction,
) -> int:
    """Delete a batch of device_inbox rows that belong to hidden devices.

    stream_id is not unique, so we need an inclusive `stream_id >= ?`
    clause: we might not have deleted all hidden device messages for the
    stream_id returned from the previous query.

    Rows are then deleted one `(user_id, device_id, stream_id)` tuple at
    a time, to avoid problems of deleting a large number of rows all at
    once due to a single device having lots of device messages.
    """
    last_stream_id = progress.get("stream_id", 0)

    sql = """
        SELECT device_id, user_id, stream_id
        FROM device_inbox
        WHERE
            stream_id >= ?
            AND (device_id, user_id) IN (
                SELECT device_id, user_id FROM devices WHERE hidden = ?
            )
        ORDER BY stream_id
        LIMIT ?
    """
    txn.execute(sql, (last_stream_id, True, batch_size))
    rows = txn.fetchall()

    num_deleted = 0
    for device_id, user_id, stream_id in rows:
        num_deleted += self.db_pool.simple_delete_txn(
            txn,
            "device_inbox",
            {"device_id": device_id, "user_id": user_id, "stream_id": stream_id},
        )

    if rows:
        # We don't just save the `stream_id` in progress as otherwise it
        # can happen in large deployments that no change of status is
        # visible in the log file, as it may be that the stream_id does
        # not change in several runs.
        self.db_pool.updates._background_update_progress_txn(
            txn,
            self.REMOVE_HIDDEN_DEVICES,
            {
                "device_id": rows[-1][0],
                "user_id": rows[-1][1],
                "stream_id": rows[-1][2],
            },
        )

    return num_deleted
def reindex_search_txn(txn: LoggingTransaction) -> int:
    """Backfill origin_server_ts onto a batch of events, working
    backwards through stream orderings.

    Returns the number of event rows updated; 0 means done.
    """
    sql = (
        "SELECT stream_ordering, event_id FROM events"
        " WHERE ? <= stream_ordering AND stream_ordering < ?"
        " ORDER BY stream_ordering DESC"
        " LIMIT ?"
    )

    txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))

    rows = txn.fetchall()
    if not rows:
        return 0

    min_stream_id = rows[-1][0]
    event_ids = [row[1] for row in rows]

    rows_to_update = []

    # Fetch the event JSON in chunks of 100 to keep each query bounded.
    for start in range(0, len(event_ids), 100):
        chunk = event_ids[start : start + 100]
        ev_rows = self.db_pool.simple_select_many_txn(
            txn,
            table="event_json",
            column="event_id",
            iterable=chunk,
            retcols=["event_id", "json"],
            keyvalues={},
        )

        for ev_row in ev_rows:
            event_id = ev_row["event_id"]
            event_json = db_to_json(ev_row["json"])
            try:
                origin_server_ts = event_json["origin_server_ts"]
            except (KeyError, AttributeError):
                # If the event is missing a necessary field then
                # skip over it.
                continue

            rows_to_update.append((origin_server_ts, event_id))

    sql = "UPDATE events SET origin_server_ts = ? WHERE event_id = ?"
    txn.execute_batch(sql, rows_to_update)

    progress = {
        "target_min_stream_id_inclusive": target_min_stream_id,
        "max_stream_id_exclusive": min_stream_id,
        "rows_inserted": rows_inserted + len(rows_to_update),
    }
    self.db_pool.updates._background_update_progress_txn(
        txn, _BackgroundUpdates.EVENT_ORIGIN_SERVER_TS_NAME, progress
    )

    return len(rows_to_update)
def _list_users(txn: LoggingTransaction) -> List[Tuple[str, str]]:
    """Return (appservice id or 'native', user_id) pairs for monthly
    active users matching the prepared where clause."""
    sql = f"""
            SELECT COALESCE(appservice_id, 'native'), user_id
            FROM monthly_active_users
            LEFT JOIN users ON monthly_active_users.user_id=users.name
            {where_clause};
        """
    txn.execute(sql, query_params)
    return cast(List[Tuple[str, str]], txn.fetchall())
def f(txn: LoggingTransaction) -> List[Tuple[str, str, int, int]]:
    """Return the user's linearized receipts joined with the ordering of
    the events they point at.

    Returns:
        A list of (room_id, event_id, topological_ordering,
        stream_ordering) tuples.
    """
    sql = (
        "SELECT rl.room_id, rl.event_id,"
        " e.topological_ordering, e.stream_ordering"
        " FROM receipts_linearized AS rl"
        " INNER JOIN events AS e USING (room_id, event_id)"
        " WHERE rl.room_id = e.room_id"
        " AND rl.event_id = e.event_id"
        " AND user_id = ?"
    )
    txn.execute(sql, (user_id,))
    # CONSISTENCY FIX: cast the untyped fetchall() result like the other
    # query helpers in this file do (runtime no-op, satisfies the
    # declared return type).
    return cast(List[Tuple[str, str, int, int]], txn.fetchall())
def _get_event_relations(
    txn: LoggingTransaction,
) -> Dict[str, Set[Tuple[str, str]]]:
    """Group the relating events of a single event by relation type.

    Returns a map of relation type to the set of (sender, event type)
    pairs that relate to the event with that relation type; every
    requested relation type is present, possibly with an empty set.
    """
    txn.execute(sql, [event_id] + rel_type_args)

    # Seed every requested relation type so absent ones map to empty sets.
    result: Dict[str, Set[Tuple[str, str]]] = {
        rel_type: set() for rel_type in relation_types
    }
    # ("type" renamed to event_type to avoid shadowing the builtin.)
    for rel_type, sender, event_type in txn.fetchall():
        result[rel_type].add((sender, event_type))
    return result
def purged_chain_cover_txn(txn: LoggingTransaction) -> int:
    """Clean up a batch of auth-chain rows that point at purged events.

    Returns the number of rows scanned; 0 means the update is complete.
    """
    # The event ID from events will be null if the chain ID / sequence
    # number points to a purged event.
    sql = """
        SELECT event_id, chain_id, sequence_number, e.event_id IS NOT NULL
        FROM event_auth_chains
        LEFT JOIN events AS e USING (event_id)
        WHERE event_id > ? ORDER BY event_auth_chains.event_id ASC LIMIT ?
    """
    txn.execute(sql, (current_event_id, batch_size))

    rows = txn.fetchall()
    if not rows:
        return 0

    # Collect the event IDs and chain IDs / sequence numbers where the
    # event has been purged.
    unreferenced_event_ids = []
    unreferenced_chain_id_tuples = []
    event_id = ""
    for event_id, chain_id, sequence_number, has_event in rows:
        if has_event:
            continue
        unreferenced_event_ids.append((event_id,))
        unreferenced_chain_id_tuples.append((chain_id, sequence_number))

    # Delete the unreferenced auth chains from event_auth_chain_links and
    # event_auth_chains.
    txn.executemany(
        """
        DELETE FROM event_auth_chains WHERE event_id = ?
        """,
        unreferenced_event_ids,
    )
    # We should also delete matching target_*, but there is no index on
    # target_chain_id. Hopefully any purged events are due to a room
    # being fully purged and they will be removed from the origin_*
    # searches.
    txn.executemany(
        """
        DELETE FROM event_auth_chain_links WHERE
        origin_chain_id = ? AND origin_sequence_number = ?
        """,
        unreferenced_chain_id_tuples,
    )

    # `event_id` now holds the last event ID scanned, which is where the
    # next batch resumes from.
    progress = {
        "current_event_id": event_id,
    }

    self.db_pool.updates._background_update_progress_txn(
        txn, "purged_chain_cover", progress
    )

    return len(rows)
def get_recent(txn: LoggingTransaction) -> List[Tuple[str, str, str, int]]:
    """Return the user's IP activity seen since `since_ts`, most recent
    first, as (access_token, ip, user_agent, last_seen) tuples."""
    sql = """
            SELECT access_token, ip, user_agent, last_seen FROM user_ips
            WHERE last_seen >= ? AND user_id = ?
            ORDER BY last_seen
            DESC
        """
    txn.execute(sql, (since_ts, user_id))
    return cast(List[Tuple[str, str, str, int]], txn.fetchall())
def _count_users_by_service(txn: LoggingTransaction) -> Dict[str, int]:
    """Count monthly active users per appservice; native (non-appservice)
    users are grouped under the key 'native'."""
    sql = """
            SELECT COALESCE(appservice_id, 'native'), COUNT(*)
            FROM monthly_active_users
            LEFT JOIN users ON monthly_active_users.user_id=users.name
            GROUP BY appservice_id;
        """
    txn.execute(sql)
    rows = cast(List[Tuple[str, int]], txn.fetchall())
    return {service: count for service, count in rows}
def get_start_id(txn: LoggingTransaction) -> int:
    """Return the rowid of the first transaction sent since yesterday,
    or 1 when none exists."""
    txn.execute(
        "SELECT rowid FROM sent_transactions WHERE ts >= ?"
        " ORDER BY rowid ASC LIMIT 1",
        (yesterday,),
    )
    rows = txn.fetchall()
    # Fall back to the very first row when nothing is recent enough.
    return rows[0][0] if rows else 1
def fetch(txn: LoggingTransaction) -> List[Tuple[int, int]]:
    """Per room, pair the number of forward extremities with the number
    of current-state events (None when the room has no current state)."""
    txn.execute(
        """
        SELECT t1.c, t2.c
        FROM (
            SELECT room_id, COUNT(*) c FROM event_forward_extremities
            GROUP BY room_id
        ) t1 LEFT JOIN (
            SELECT room_id, COUNT(*) c FROM current_state_events
            GROUP BY room_id
        ) t2 ON t1.room_id = t2.room_id
        """
    )
    return cast(List[Tuple[int, int]], txn.fetchall())
def get_user_id_for_token_txn(txn: LoggingTransaction) -> Optional[str]:
    """Resolve an OpenID token to its user ID, or None when the token is
    unknown or has expired."""
    sql = (
        "SELECT user_id FROM open_id_tokens"
        " WHERE token = ? AND ? <= ts_valid_until_ms"
    )

    txn.execute(sql, (token, ts_now_ms))

    rows = txn.fetchall()
    return rows[0][0] if rows else None
def reindex_txn(txn: LoggingTransaction) -> int:
    """Backfill sender/contains_url onto a batch of events, working
    backwards through stream orderings.

    Returns the number of events scanned; 0 means done.
    """
    sql = (
        "SELECT stream_ordering, event_id, json FROM events"
        " INNER JOIN event_json USING (event_id)"
        " WHERE ? <= stream_ordering AND stream_ordering < ?"
        " ORDER BY stream_ordering DESC"
        " LIMIT ?"
    )

    txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))

    rows = txn.fetchall()
    if not rows:
        return 0

    min_stream_id = rows[-1][0]

    update_rows = []
    for stream_ordering, event_id, json_blob in rows:
        try:
            event_json = db_to_json(json_blob)
            sender = event_json["sender"]
            content = event_json["content"]

            # contains_url is only true when "url" is present AND a string.
            contains_url = "url" in content
            if contains_url:
                contains_url &= isinstance(content["url"], str)
        except (KeyError, AttributeError):
            # If the event is missing a necessary field then
            # skip over it.
            continue

        update_rows.append((sender, contains_url, event_id))

    sql = "UPDATE events SET sender = ?, contains_url = ? WHERE event_id = ?"

    txn.execute_batch(sql, update_rows)

    progress = {
        "target_min_stream_id_inclusive": target_min_stream_id,
        "max_stream_id_exclusive": min_stream_id,
        "rows_inserted": rows_inserted + len(rows),
    }

    self.db_pool.updates._background_update_progress_txn(
        txn, _BackgroundUpdates.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, progress
    )

    return len(rows)
def f(txn: LoggingTransaction) -> Set[str]:
    """Use Postgres ts_headline to extract the words in the given events
    that match the search query, lower-cased."""
    highlight_words: Set[str] = set()
    for event in events:
        # As a hack we simply join values of all possible keys. This is
        # fine since we're only using them to find possible highlights.
        values = []
        for key in ("body", "name", "topic"):
            v = event.content.get(key, None)
            if v:
                v = _clean_value_for_search(v)
                values.append(v)

        if not values:
            continue

        value = " ".join(values)

        # We need to find some values for StartSel and StopSel that
        # aren't in the value so that we can pick results out.
        start_sel = "<"
        stop_sel = ">"

        while start_sel in value:
            start_sel += "<"
        while stop_sel in value:
            stop_sel += ">"

        query = "SELECT ts_headline(?, to_tsquery('english', ?), %s)" % (
            _to_postgres_options(
                {
                    "StartSel": start_sel,
                    "StopSel": stop_sel,
                    "MaxFragments": "50",
                }
            )
        )
        txn.execute(query, (value, search_query))
        (headline,) = txn.fetchall()[0]

        # Now we need to pick the possible highlights out of the headline
        # result.
        matcher_regex = "%s(.*?)%s" % (
            re.escape(start_sel),
            re.escape(stop_sel),
        )

        res = re.findall(matcher_regex, headline)
        highlight_words.update(word.lower() for word in res)

    return highlight_words