def _do_txn(txn: LoggingTransaction) -> int:
    """Return the filter_id for ``def_json``, inserting a new row if needed.

    Reuses an existing ``user_filters`` row when the user already has an
    identical filter; otherwise allocates the next filter_id for the user.
    """
    # Reuse an identical filter if one already exists for this user.
    txn.execute(
        "SELECT filter_id FROM user_filters "
        "WHERE user_id = ? AND filter_json = ?",
        (user_localpart, bytearray(def_json)),
    )
    existing = txn.fetchone()
    if existing is not None:
        return existing[0]

    # No identical filter: allocate the next filter_id for this user.
    txn.execute(
        "SELECT MAX(filter_id) FROM user_filters WHERE user_id = ?",
        (user_localpart,),
    )
    max_id = cast(Tuple[Optional[int]], txn.fetchone())[0]
    filter_id = 0 if max_id is None else max_id + 1

    txn.execute(
        "INSERT INTO user_filters (user_id, filter_id, filter_json)"
        "VALUES(?, ?, ?)",
        (user_localpart, filter_id, bytearray(def_json)),
    )
    return filter_id
def _get_thread_summary_txn(
    txn: LoggingTransaction,
) -> Tuple[int, Optional[str]]:
    """Return (number of threaded events, latest thread event ID or None)."""
    # Fetch the count of threaded events and the latest event ID.
    # TODO Should this only allow m.room.message events.
    latest_sql = """
        SELECT event_id
        FROM event_relations
        INNER JOIN events USING (event_id)
        WHERE
            relates_to_id = ?
            AND relation_type = ?
        ORDER BY topological_ordering DESC, stream_ordering DESC
        LIMIT 1
    """
    txn.execute(latest_sql, (event_id, RelationTypes.THREAD))
    latest = txn.fetchone()
    if latest is None:
        # No thread relations at all: empty summary.
        return 0, None
    latest_event_id = latest[0]

    count_sql = """
        SELECT COALESCE(COUNT(event_id), 0)
        FROM event_relations
        WHERE
            relates_to_id = ?
            AND relation_type = ?
    """
    txn.execute(count_sql, (event_id, RelationTypes.THREAD))
    count = txn.fetchone()[0]  # type: ignore[index]

    return count, latest_event_id
def get_destination_rooms_paginate_txn(
    txn: LoggingTransaction,
) -> Tuple[List[JsonDict], int]:
    """Return one page of rooms for the destination plus the total count."""
    order = "DESC" if direction == "b" else "ASC"

    # Total number of destination_rooms rows for this destination.
    txn.execute(
        """
        SELECT COUNT(*) as total_rooms
        FROM destination_rooms
        WHERE destination = ?
        """,
        [destination],
    )
    count = cast(Tuple[int], txn.fetchone())[0]

    # The requested page itself.
    rooms = self.db_pool.simple_select_list_paginate_txn(
        txn=txn,
        table="destination_rooms",
        orderby="room_id",
        start=start,
        limit=limit,
        retcols=("room_id", "stream_ordering"),
        order_direction=order,
    )
    return rooms, count
def get_sent_table_size(txn: LoggingTransaction) -> int:
    """Count sent_transactions rows with ts >= ``yesterday`` (closure var)."""
    txn.execute(
        "SELECT count(*) FROM sent_transactions WHERE ts >= ?",
        (yesterday,),
    )
    row = txn.fetchone()
    assert row is not None
    return int(row[0])
def _claim_e2e_one_time_key_returning(
    txn: LoggingTransaction, user_id: str, device_id: str, algorithm: str
) -> Optional[Tuple[str, str]]:
    """Claim OTK for device for DBs that support RETURNING.

    Returns:
        A tuple of key name (algorithm + key ID) and key JSON, if an OTK
        was found.
    """
    # We can use RETURNING to do the fetch and DELETE in once step.
    sql = """
        DELETE FROM e2e_one_time_keys_json
        WHERE user_id = ? AND device_id = ? AND algorithm = ?
            AND key_id IN (
                SELECT key_id FROM e2e_one_time_keys_json
                WHERE user_id = ? AND device_id = ? AND algorithm = ?
                LIMIT 1
            )
        RETURNING key_id, key_json
    """
    device_key = (user_id, device_id, algorithm)
    txn.execute(sql, device_key + device_key)

    claimed = txn.fetchone()
    if claimed is None:
        return None

    # A key was deleted, so the cached OTK counts are now stale.
    self._invalidate_cache_and_stream(
        txn, self.count_e2e_one_time_keys, (user_id, device_id)
    )

    key_id, key_json = claimed
    return f"{algorithm}:{key_id}", key_json
def _delete_room_alias_txn(
    self, txn: LoggingTransaction, room_alias: RoomAlias
) -> Optional[str]:
    """Delete a room alias and its server list.

    Returns the room ID the alias pointed at, or None if the alias did
    not exist.
    """
    alias_str = room_alias.to_string()

    txn.execute(
        "SELECT room_id FROM room_aliases WHERE room_alias = ?",
        (alias_str,),
    )
    row = txn.fetchone()
    if not row:
        return None
    room_id = row[0]

    txn.execute(
        "DELETE FROM room_aliases WHERE room_alias = ?", (alias_str,)
    )

    txn.execute(
        "DELETE FROM room_alias_servers WHERE room_alias = ?",
        (alias_str,),
    )

    # The room's alias list changed, so drop the cached value.
    self._invalidate_cache_and_stream(txn, self.get_aliases_for_room, (room_id,))

    return room_id
def get_destinations_paginate_txn(
    txn: LoggingTransaction,
) -> Tuple[List[JsonDict], int]:
    """Return one page of destinations plus the total count of matches."""
    order_by_column = DestinationSortOrder(order_by).value
    order = "DESC" if direction == "b" else "ASC"

    args: List[object] = []
    where_statement = ""
    if destination:
        # Case-insensitive substring match on the destination name.
        args.extend(["%" + destination.lower() + "%"])
        where_statement = "WHERE LOWER(destination) LIKE ?"

    sql_base = f"FROM destinations {where_statement} "

    # Total number of matching destinations.
    txn.execute(f"SELECT COUNT(*) as total_destinations {sql_base}", args)
    count = cast(Tuple[int], txn.fetchone())[0]

    # The requested page.
    page_sql = f"""
        SELECT destination, retry_last_ts, retry_interval, failure_ts,
        last_successful_stream_ordering
        {sql_base}
        ORDER BY {order_by_column} {order}, destination ASC
        LIMIT ? OFFSET ?
    """
    txn.execute(page_sql, args + [limit, start])
    destinations = self.db_pool.cursor_to_dict(txn)
    return destinations, count
def _get_next_batch(
    txn: LoggingTransaction,
) -> Optional[Sequence[Tuple[str, int]]]:
    """Fetch the next chunk of rooms to process, or None when done.

    Also records how many rooms remain in ``progress["remaining"]``.
    """
    # Only fetch 250 rooms, so we don't fetch too many at once, even
    # if those 250 rooms have less than batch_size state events.
    txn.execute(
        """
        SELECT room_id, events FROM %s
        ORDER BY events DESC
        LIMIT 250
        """
        % (TEMP_TABLE + "_rooms",)
    )
    rooms_to_work_on = cast(List[Tuple[str, int]], txn.fetchall())

    if not rooms_to_work_on:
        return None

    # Get how many are left to process, so we can give status on how
    # far we are in processing
    txn.execute("SELECT COUNT(*) FROM " + TEMP_TABLE + "_rooms")
    remaining_row = txn.fetchone()
    assert remaining_row is not None
    progress["remaining"] = remaining_row[0]

    return rooms_to_work_on
def _fetch_current_state_stats( txn: LoggingTransaction, ) -> Tuple[List[str], Dict[str, int], int, List[str], int]: pos = self.get_room_max_stream_ordering( ) # type: ignore[attr-defined] rows = self.db_pool.simple_select_many_txn( txn, table="current_state_events", column="type", iterable=[ EventTypes.Create, EventTypes.JoinRules, EventTypes.RoomHistoryVisibility, EventTypes.RoomEncryption, EventTypes.Name, EventTypes.Topic, EventTypes.RoomAvatar, EventTypes.CanonicalAlias, ], keyvalues={ "room_id": room_id, "state_key": "" }, retcols=["event_id"], ) event_ids = cast(List[str], [row["event_id"] for row in rows]) txn.execute( """ SELECT membership, count(*) FROM current_state_events WHERE room_id = ? AND type = 'm.room.member' GROUP BY membership """, (room_id, ), ) membership_counts = {membership: cnt for membership, cnt in txn} txn.execute( """ SELECT COUNT(*) FROM current_state_events WHERE room_id = ? """, (room_id, ), ) current_state_events_count = cast(Tuple[int], txn.fetchone())[0] users_in_room = self.get_users_in_room_txn( txn, room_id) # type: ignore[attr-defined] return ( event_ids, membership_counts, current_state_events_count, users_in_room, pos, )
def _count(txn: LoggingTransaction) -> int:
    """Count distinct rooms with an m.room.message newer than the
    day-ago stream position."""
    txn.execute(
        """
        SELECT COUNT(DISTINCT room_id) FROM events
        WHERE type = 'm.room.message'
        AND stream_ordering > ?
        """,
        (self.stream_ordering_day_ago,),
    )
    (count,) = cast(Tuple[int], txn.fetchone())
    return count
def _get_unread_counts_by_pos_txn(
    self,
    txn: LoggingTransaction,
    room_id: str,
    user_id: str,
    stream_ordering: int,
) -> NotifCounts:
    """Count notif/highlight/unread push actions for a user in a room
    after the given stream ordering, then fold in the pre-aggregated
    counts from event_push_summary.
    """
    # Counts from individual (not yet summarised) push actions.
    txn.execute(
        "SELECT"
        " COUNT(CASE WHEN notif = 1 THEN 1 END),"
        " COUNT(CASE WHEN highlight = 1 THEN 1 END),"
        " COUNT(CASE WHEN unread = 1 THEN 1 END)"
        " FROM event_push_actions ea"
        " WHERE user_id = ?"
        " AND room_id = ?"
        " AND stream_ordering > ?",
        (user_id, room_id, stream_ordering),
    )
    row = txn.fetchone()
    notif_count, highlight_count, unread_count = row if row else (0, 0, 0)

    # Add in the counts that have already been rolled up into the
    # summary table (highlights are not summarised there).
    txn.execute(
        """
        SELECT notif_count, unread_count FROM event_push_summary
        WHERE room_id = ? AND user_id = ? AND stream_ordering > ?
        """,
        (room_id, user_id, stream_ordering),
    )
    summary = txn.fetchone()
    if summary:
        notif_count += summary[0]

        if summary[1] is not None:
            # The unread_count column of event_push_summary is NULLable, so we need
            # to make sure we don't try increasing the unread counts if it's NULL
            # for this row.
            unread_count += summary[1]

    return NotifCounts(
        notify_count=notif_count,
        unread_count=unread_count,
        highlight_count=highlight_count,
    )
def _get_if_maybe_push_in_range_for_user_txn(
    txn: LoggingTransaction,
) -> bool:
    """Return True if the user has any notifying push action after
    the minimum stream ordering."""
    txn.execute(
        """
        SELECT 1 FROM event_push_actions
        WHERE user_id = ? AND stream_ordering > ? AND notif = 1
        LIMIT 1
        """,
        (user_id, min_stream_ordering),
    )
    return txn.fetchone() is not None
def f(txn: LoggingTransaction) -> Optional[Tuple[int]]:
    """Return the received_ts of the earliest notifying push action
    after ``stream_ordering``, or None if there is none."""
    txn.execute(
        "SELECT e.received_ts"
        " FROM event_push_actions AS ep"
        " JOIN events e ON ep.room_id = e.room_id AND ep.event_id = e.event_id"
        " WHERE ep.stream_ordering > ? AND notif = 1"
        " ORDER BY ep.stream_ordering ASC"
        " LIMIT 1",
        (stream_ordering,),
    )
    return cast(Optional[Tuple[int]], txn.fetchone())
def _set_push_rule_enabled_txn(
    self,
    txn: LoggingTransaction,
    stream_id: int,
    event_stream_ordering: int,
    user_id: str,
    rule_id: str,
    enabled: bool,
    is_default_rule: bool,
) -> None:
    """Enable or disable a push rule, recording the change in the
    push-rules stream.

    Raises:
        RuleNotFoundException: if a non-default rule does not exist.
    """
    new_id = self._push_rules_enable_id_gen.get_next()

    if not is_default_rule:
        # first check it exists; we need to lock for key share so that a
        # transaction that deletes the push rule will conflict with this one.
        # We also need a push_rule_enable row to exist for every push_rules
        # row, otherwise it is possible to simultaneously delete a push rule
        # (that has no _enable row) and enable it, resulting in a dangling
        # _enable row. To solve this: we either need to use SERIALISABLE or
        # ensure we always have a push_rule_enable row for every push_rule
        # row. We chose the latter.
        for_key_share = "FOR KEY SHARE"
        if not isinstance(self.database_engine, PostgresEngine):
            # For key share is not applicable/available on SQLite
            for_key_share = ""
        sql = (
            """
            SELECT 1 FROM push_rules
            WHERE user_name = ? AND rule_id = ?
            %s
            """
            % for_key_share
        )
        txn.execute(sql, (user_id, rule_id))
        if txn.fetchone() is None:
            raise RuleNotFoundException("Push rule does not exist.")

    # Upsert the enabled flag (inserting a fresh row with new_id if
    # none exists yet).
    self.db_pool.simple_upsert_txn(
        txn,
        "push_rules_enable",
        {"user_name": user_id, "rule_id": rule_id},
        {"enabled": 1 if enabled else 0},
        {"id": new_id},
    )

    # Record the change on the push-rules stream so other workers /
    # clients can pick it up.
    self._insert_push_rules_update_txn(
        txn,
        stream_id,
        event_stream_ordering,
        user_id,
        rule_id,
        op="ENABLE" if enabled else "DISABLE",
    )
def _count_users(txn: LoggingTransaction) -> int:
    """Count monthly active users, excluding app service users."""
    # Exclude app service users
    txn.execute(
        """
        SELECT COUNT(*) FROM monthly_active_users
            LEFT JOIN users
            ON monthly_active_users.user_id=users.name
        WHERE (users.appservice_id IS NULL OR users.appservice_id = '');
        """
    )
    (count,) = cast(Tuple[int], txn.fetchone())
    return count
def _get_current_version(txn: LoggingTransaction, user_id: str) -> int:
    """Return the user's highest non-deleted e2e room-keys backup version.

    Raises:
        StoreError: (404) if the user has no current backup version.
    """
    txn.execute(
        "SELECT MAX(version) FROM e2e_room_keys_versions "
        "WHERE user_id=? AND deleted=0",
        (user_id,),
    )
    # `SELECT MAX() FROM ...` will always return 1 row. The value in that row will
    # be `NULL` when there are no available versions.
    max_version = cast(Tuple[Optional[int]], txn.fetchone())[0]
    if max_version is None:
        raise StoreError(404, "No current backup version")
    return max_version
def _get_if_user_has_annotated_event(txn: LoggingTransaction) -> bool:
    """Return True if the sender already annotated the event with this
    aggregation key (parameters come from the enclosing scope)."""
    txn.execute(
        sql,
        (
            parent_id,
            RelationTypes.ANNOTATION,
            event_type,
            sender,
            aggregation_key,
        ),
    )
    return txn.fetchone() is not None
def get_url_cache_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:
    """Return the cached URL-preview row best matching ``url`` at ``ts``,
    or None if nothing is cached for the URL."""
    columns = (
        "response_code",
        "etag",
        "expires_ts",
        "og",
        "media_id",
        "download_ts",
    )

    # get the most recently cached result (relative to the given ts)
    txn.execute(
        "SELECT response_code, etag, expires_ts, og, media_id, download_ts"
        " FROM local_media_repository_url_cache"
        " WHERE url = ? AND download_ts <= ?"
        " ORDER BY download_ts DESC LIMIT 1",
        (url, ts),
    )
    row = txn.fetchone()

    if not row:
        # ...or if we've requested a timestamp older than the oldest
        # copy in the cache, return the oldest copy (if any)
        txn.execute(
            "SELECT response_code, etag, expires_ts, og, media_id, download_ts"
            " FROM local_media_repository_url_cache"
            " WHERE url = ? AND download_ts > ?"
            " ORDER BY download_ts ASC LIMIT 1",
            (url, ts),
        )
        row = txn.fetchone()

    if not row:
        return None

    return dict(zip(columns, row))
def f(txn: LoggingTransaction) -> None:
    """Mark the user as erased, unless they already are."""
    # first check if they are already in the list
    txn.execute("SELECT 1 FROM erased_users WHERE user_id = ?", (user_id,))
    if txn.fetchone() is not None:
        return

    # they are not already there: do the insert.
    txn.execute("INSERT INTO erased_users (user_id) VALUES (?)", (user_id,))

    self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,))
def get_users_paginate_txn(
    txn: LoggingTransaction,
) -> Tuple[List[JsonDict], int]:
    """Return one page of user rows plus the total count of matches."""
    filters = []
    args = [self.hs.config.server.server_name]

    # Set ordering
    order_by_column = UserSortOrder(order_by).value
    order = "DESC" if direction == "b" else "ASC"

    # `name` is in database already in lower case
    if name:
        filters.append("(name LIKE ? OR LOWER(displayname) LIKE ?)")
        args.extend(["@%" + name.lower() + "%:%", "%" + name.lower() + "%"])
    elif user_id:
        filters.append("name LIKE ?")
        args.extend(["%" + user_id.lower() + "%"])

    if not guests:
        filters.append("is_guest = 0")

    if not deactivated:
        filters.append("deactivated = 0")

    where_clause = "WHERE " + " AND ".join(filters) if filters else ""

    sql_base = f"""
        FROM users as u
        LEFT JOIN profiles AS p ON u.name = '@' || p.user_id || ':' || ?
        {where_clause}
        """

    # Total count of matching users.
    txn.execute("SELECT COUNT(*) as total_users " + sql_base, args)
    count = cast(Tuple[int], txn.fetchone())[0]

    # The requested page.
    txn.execute(
        f"""
        SELECT name, user_type, is_guest, admin, deactivated, shadow_banned,
        displayname, avatar_url, creation_ts * 1000 as creation_ts
        {sql_base}
        ORDER BY {order_by_column} {order}, u.name ASC
        LIMIT ? OFFSET ?
        """,
        args + [limit, start],
    )
    users = self.db_pool.cursor_to_dict(txn)
    return users, count
def f(txn: LoggingTransaction) -> None:
    """Remove the user from the erased list, if they are on it."""
    # first check if they are already in the list
    txn.execute("SELECT 1 FROM erased_users WHERE user_id = ?", (user_id,))
    if txn.fetchone() is None:
        return

    # They are there, delete them.
    self.db_pool.simple_delete_one_txn(
        txn, "erased_users", keyvalues={"user_id": user_id}
    )

    self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,))
def _calculate_and_set_initial_state_for_user_txn(
    txn: LoggingTransaction,
) -> Tuple[int, int]:
    """Return (number of rooms the user has joined, current stream pos)."""
    pos = self._get_max_stream_id_in_current_state_deltas_txn(txn)

    txn.execute(
        """
        SELECT COUNT(distinct room_id) FROM current_state_events
            WHERE type = 'm.room.member' AND state_key = ?
                AND membership = 'join'
        """,
        (user_id,),
    )
    (joined_rooms,) = cast(Tuple[int], txn.fetchone())
    return joined_rooms, pos
def _remove_dead_devices_from_device_inbox_txn(
    txn: LoggingTransaction,
) -> Tuple[int, bool]:
    """Delete one batch of device_inbox rows belonging to dead devices.

    Returns:
        A tuple of (upper bound of the stream-ID range processed, whether
        the background update has finished).
    """
    if "max_stream_id" in progress:
        max_stream_id = progress["max_stream_id"]
    else:
        txn.execute("SELECT max(stream_id) FROM device_inbox")
        # There's a type mismatch here between how we want to type the row and
        # what fetchone says it returns, but we silence it because we know that
        # res can't be None.
        res: Tuple[Optional[int]] = txn.fetchone()  # type: ignore[assignment]
        if res[0] is None:
            # this can only happen if the `device_inbox` table is empty, in which
            # case we have no work to do.
            return 0, True
        else:
            max_stream_id = res[0]

    start = progress.get("stream_id", 0)
    stop = start + batch_size

    # delete rows in `device_inbox` which do *not* correspond to a known,
    # unhidden device.
    sql = """
        DELETE FROM device_inbox
        WHERE
            stream_id >= ? AND stream_id < ?
            AND NOT EXISTS (
                SELECT * FROM devices d
                WHERE
                    d.device_id=device_inbox.device_id
                    AND d.user_id=device_inbox.user_id
                    AND NOT hidden
            )
        """
    txn.execute(sql, (start, stop))

    self.db_pool.updates._background_update_progress_txn(
        txn,
        self.REMOVE_DEAD_DEVICES_FROM_INBOX,
        {
            "stream_id": stop,
            "max_stream_id": max_stream_id,
        },
    )

    # Fix: the function is declared -> Tuple[int, bool] (and the early
    # exit above returns a (0, True) pair), but this path previously
    # returned a bare bool. Return the processed upper bound together
    # with the "finished" flag so both exits agree with the annotation.
    return stop, stop > max_stream_id
def _count_messages(txn: LoggingTransaction) -> int:
    """Count m.room.message events from senders on our hostname newer
    than the day-ago stream position."""
    # This is good enough as if you have silly characters in your own
    # hostname then that's your own fault.
    like_clause = "%:" + self.hs.hostname

    txn.execute(
        """
        SELECT COUNT(*) FROM events
        WHERE type = 'm.room.message'
            AND sender LIKE ?
        AND stream_ordering > ?
        """,
        (like_clause, self.stream_ordering_day_ago),
    )
    (count,) = cast(Tuple[int], txn.fetchone())
    return count
def _get_thread_summary_txn(txn: LoggingTransaction) -> bool:
    """Return True if ``user_id`` has sent an event in the thread."""
    # Fetch whether the requester has participated or not.
    txn.execute(
        """
        SELECT 1
        FROM event_relations
        INNER JOIN events USING (event_id)
        WHERE
            relates_to_id = ?
            AND room_id = ?
            AND relation_type = ?
            AND sender = ?
        """,
        (event_id, room_id, RelationTypes.THREAD, user_id),
    )
    return txn.fetchone() is not None
def _get_session(
    txn: LoggingTransaction, session_type: str, session_id: str, ts: int
) -> JsonDict:
    """Look up an unexpired session and return its deserialised value.

    Raises:
        StoreError: (404) if no matching, unexpired session exists.
    """
    # This includes the expiry time since items are only periodically
    # deleted, not upon expiry.
    txn.execute(
        """
        SELECT value FROM sessions WHERE
        session_type = ? AND session_id = ? AND expiry_time_ms > ?
        """,
        [session_type, session_id, ts],
    )
    row = txn.fetchone()
    if row is None:
        raise StoreError(404, "No session")

    return db_to_json(row[0])
def get_type_stream_id_for_appservice_txn(
    txn: LoggingTransaction,
) -> int:
    """Return the stored stream ID of the given type for the appservice,
    defaulting to 1 when no value has been stored yet."""
    stream_id_type = "%s_stream_id" % type
    txn.execute(
        # We do NOT want to escape `stream_id_type`.
        "SELECT %s FROM application_services_state WHERE as_id=?"
        % stream_id_type,
        (service.id,),
    )
    row = txn.fetchone()
    if row is None or row[0] is None:  # no row exists
        # Stream tokens always start from 1, to avoid foot guns around `0` being falsey.
        return 1
    return int(row[0])
def get_last_seen(txn: LoggingTransaction) -> Optional[int]:
    """Return the last_seen value ``batch_size`` rows past
    ``begin_last_seen``, or None if there are fewer rows than that."""
    txn.execute(
        """
        SELECT last_seen FROM user_ips
        WHERE last_seen > ?
        ORDER BY last_seen
        LIMIT 1
        OFFSET ?
        """,
        (begin_last_seen, batch_size),
    )
    row = cast(Optional[Tuple[int]], txn.fetchone())
    return row[0] if row else None
def _count_users(self, txn: LoggingTransaction, time_from: int) -> int:
    """
    Returns number of users seen in the past time_from period
    """
    sql = """
        SELECT COUNT(*) FROM (
            SELECT user_id FROM user_ips
            WHERE last_seen > ?
            GROUP BY user_id
        ) u
    """
    txn.execute(sql, (time_from,))
    # Mypy knows that fetchone() might return None if there are no rows.
    # We know better: "SELECT COUNT(...) FROM ..." without any GROUP BY always
    # returns exactly one row.
    row = cast(Tuple[int], txn.fetchone())
    return row[0]
def get_local_media_by_user_paginate_txn(
    txn: LoggingTransaction,
) -> Tuple[List[Dict[str, Any]], int]:
    """Return one page of the user's local media plus their total count."""
    # Set ordering
    order_by_column = MediaSortOrder(order_by).value
    order = "DESC" if direction == "b" else "ASC"

    args: List[Union[str, int]] = [user_id]

    # Total number of media items for this user.
    txn.execute(
        """
        SELECT COUNT(*) as total_media
        FROM local_media_repository
        WHERE user_id = ?
        """,
        args,
    )
    count = cast(Tuple[int], txn.fetchone())[0]

    # The requested page.
    page_sql = """
        SELECT
            "media_id",
            "media_type",
            "media_length",
            "upload_name",
            "created_ts",
            "last_access_ts",
            "quarantined_by",
            "safe_from_quarantine"
        FROM local_media_repository
        WHERE user_id = ?
        ORDER BY {order_by_column} {order}, media_id ASC
        LIMIT ? OFFSET ?
        """.format(
        order_by_column=order_by_column,
        order=order,
    )
    args += [limit, start]
    txn.execute(page_sql, args)
    media = self.db_pool.cursor_to_dict(txn)
    return media, count