def _get_threaded_messages_per_user_txn(
    txn: LoggingTransaction,
) -> Dict[Tuple[str, str], int]:
    # `sql`, `users` and `event_ids` are closed over from the enclosing method.
    users_sql, users_args = make_in_list_sql_clause(
        self.database_engine, "child.sender", users
    )
    events_clause, events_args = make_in_list_sql_clause(
        txn.database_engine, "relates_to_id", event_ids
    )

    txn.execute(
        sql % (users_sql, events_clause),
        [RelationTypes.THREAD] + users_args + events_args,
    )
    return {(row[0], row[1]): row[2] for row in txn}

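Across all of these snippets the helper returns a `(clause_text, parameters)` pair: the clause is spliced into the SQL text while the values stay bound. A minimal sqlite3 sketch of that contract, using a simplified local stand-in for the helper (the real one can also emit `column = ANY(?)` on engines with array-parameter support); table and data are invented:

import sqlite3
from typing import Any, Collection, List, Tuple

def make_in_list_sql_clause_sqlite(
    column: str, iterable: Collection[Any]
) -> Tuple[str, List[Any]]:
    # Simplified stand-in: render "col IN (?,?,...)" plus the matching
    # parameter list.
    return "%s IN (%s)" % (column, ",".join("?" for _ in iterable)), list(iterable)

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE events (event_id TEXT, sender TEXT)")
conn.executemany(
    "INSERT INTO events VALUES (?, ?)",
    [("$a", "@alice:test"), ("$b", "@bob:test"), ("$c", "@carol:test")],
)

clause, args = make_in_list_sql_clause_sqlite("sender", ["@alice:test", "@bob:test"])
# The clause text is interpolated into the SQL; the values stay parameterised,
# which is why every snippet passes `args` separately to txn.execute().
rows = conn.execute(f"SELECT event_id FROM events WHERE {clause}", args).fetchall()
print(rows)  # [('$a',), ('$b',)]
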
def _reset_federation_positions_txn(self, txn) -> None:
    """Fiddles with the `federation_stream_position` table to make it match
    the configured federation sender instances during start up.
    """

    # The federation sender instances may have changed, so we need to
    # massage the `federation_stream_position` table to have a row per type
    # per instance sending federation. If there is a mismatch we update the
    # table with the correct rows using the *minimum* stream ID seen. This
    # may result in resending of events/EDUs to remote servers, but that is
    # preferable to dropping them.

    if not self._send_federation:
        return

    # Pull out the configured instances. If we don't have a shard config then
    # we assume that we're the only instance sending.
    configured_instances = self._federation_shard_config.instances
    if not configured_instances:
        configured_instances = [self._instance_name]
    elif self._instance_name not in configured_instances:
        return

    instances_in_table = self.db_pool.simple_select_onecol_txn(
        txn,
        table="federation_stream_position",
        keyvalues={},
        retcol="instance_name",
    )
    if set(instances_in_table) == set(configured_instances):
        # Nothing to do
        return

    sql = """
        SELECT type, MIN(stream_id) FROM federation_stream_position
        GROUP BY type
    """
    txn.execute(sql)
    min_positions = dict(txn)  # Map from type -> min position

    # Ensure we do actually have some values here
    assert set(min_positions) == {"federation", "events"}

    sql = """
        DELETE FROM federation_stream_position
        WHERE NOT (%s)
    """
    clause, args = make_in_list_sql_clause(
        txn.database_engine, "instance_name", configured_instances
    )
    txn.execute(sql % (clause,), args)

    for typ, stream_id in min_positions.items():
        self.db_pool.simple_upsert_txn(
            txn,
            table="federation_stream_position",
            keyvalues={"type": typ, "instance_name": self._instance_name},
            values={"stream_id": stream_id},
        )

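A toy sqlite3 walk-through of the same reset, with invented instance names and stream IDs; Synapse's `simple_*_txn` helpers are replaced with plain SQL, and the upsert becomes a plain INSERT since the stale rows were just deleted:

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE federation_stream_position "
    "(type TEXT, instance_name TEXT, stream_id BIGINT)"
)
# Rows left over from a previous, differently-sharded deployment.
conn.executemany(
    "INSERT INTO federation_stream_position VALUES (?, ?, ?)",
    [("federation", "old_sender", 10), ("events", "old_sender", 7)],
)

configured_instances = ["sender1"]  # invented config
my_instance = "sender1"

# Take the *minimum* stream ID per type, as above: re-sending some traffic
# is preferable to skipping any of it.
min_positions = dict(
    conn.execute(
        "SELECT type, MIN(stream_id) FROM federation_stream_position GROUP BY type"
    )
)

# Drop rows for instances that are no longer configured...
placeholders = ",".join("?" for _ in configured_instances)
conn.execute(
    f"DELETE FROM federation_stream_position WHERE instance_name NOT IN ({placeholders})",
    configured_instances,
)
# ...and write this instance's row per type at the minimum position.
for typ, stream_id in min_positions.items():
    conn.execute(
        "INSERT INTO federation_stream_position VALUES (?, ?, ?)",
        (typ, my_instance, stream_id),
    )

print(conn.execute("SELECT * FROM federation_stream_position ORDER BY type").fetchall())
# [('events', 'sender1', 7), ('federation', 'sender1', 10)]
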
def _get_bulk_e2e_unused_fallback_keys_txn(
    txn: LoggingTransaction,
) -> TransactionUnusedFallbackKeys:
    user_in_where_clause, user_parameters = make_in_list_sql_clause(
        self.database_engine, "devices.user_id", user_ids
    )
    # We can't use USING here because we require the `.used` condition
    # to be part of the JOIN condition so that we generate empty lists
    # when all keys are used (as opposed to just when there are no keys at all).
    sql = f"""
        SELECT devices.user_id, devices.device_id, algorithm
        FROM devices
        LEFT JOIN e2e_fallback_keys_json AS fallback_keys
            ON devices.user_id = fallback_keys.user_id
            AND devices.device_id = fallback_keys.device_id
            AND NOT fallback_keys.used
        WHERE {user_in_where_clause}
    """
    txn.execute(sql, user_parameters)

    result: TransactionUnusedFallbackKeys = {}
    for user_id, device_id, algorithm in txn:
        # We deliberately construct empty dictionaries and lists for
        # users and devices without any unused fallback keys.
        # We *could* omit these empty dicts if there have been no
        # changes since the last transaction, but we currently don't
        # do any change tracking!
        device_unused_keys = result.setdefault(user_id, {}).setdefault(device_id, [])
        if algorithm is not None:
            # algorithm will be None if this device has no keys.
            device_unused_keys.append(algorithm)

    return result

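The comment about JOIN-condition placement is worth seeing concretely. A sketch with toy tables and invented data: filtering `NOT used` inside the JOIN keeps a NULL row for a device whose only key is used, whereas the same filter in the WHERE clause drops the device entirely:

import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE devices (user_id TEXT, device_id TEXT);
    CREATE TABLE fallback_keys (user_id TEXT, device_id TEXT, algorithm TEXT, used BOOLEAN);
    INSERT INTO devices VALUES ('@u:test', 'DEV1');
    INSERT INTO fallback_keys VALUES ('@u:test', 'DEV1', 'signed_curve25519', 1);
    """
)

# Condition inside the JOIN: DEV1 still appears, with algorithm NULL, so the
# caller can emit an empty list for it.
in_join = conn.execute(
    """
    SELECT devices.device_id, algorithm FROM devices
    LEFT JOIN fallback_keys
        ON devices.user_id = fallback_keys.user_id
        AND devices.device_id = fallback_keys.device_id
        AND NOT fallback_keys.used
    """
).fetchall()
print(in_join)  # [('DEV1', None)]

# Same condition in the WHERE clause: the row is joined, then filtered out,
# and DEV1 vanishes from the result set altogether.
in_where = conn.execute(
    """
    SELECT devices.device_id, algorithm FROM devices
    LEFT JOIN fallback_keys
        ON devices.user_id = fallback_keys.user_id
        AND devices.device_id = fallback_keys.device_id
    WHERE NOT fallback_keys.used
    """
).fetchall()
print(in_where)  # []
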
def _count_bulk_e2e_one_time_keys_txn(
    txn: LoggingTransaction,
) -> TransactionOneTimeKeyCounts:
    user_in_where_clause, user_parameters = make_in_list_sql_clause(
        self.database_engine, "user_id", user_ids
    )
    sql = f"""
        SELECT user_id, device_id, algorithm, COUNT(key_id)
        FROM devices
        LEFT JOIN e2e_one_time_keys_json USING (user_id, device_id)
        WHERE {user_in_where_clause}
        GROUP BY user_id, device_id, algorithm
    """
    txn.execute(sql, user_parameters)

    result: TransactionOneTimeKeyCounts = {}
    for user_id, device_id, algorithm, count in txn:
        # We deliberately construct empty dictionaries for
        # users and devices without any unused one-time keys.
        # We *could* omit these empty dicts if there have been no
        # changes since the last transaction, but we currently don't
        # do any change tracking!
        device_count_by_algo = result.setdefault(user_id, {}).setdefault(device_id, {})
        if algorithm is not None:
            # algorithm will be None if this device has no keys.
            device_count_by_algo[algorithm] = count

    return result

def get_metadata_for_events_txn(
    txn: LoggingTransaction,
    batch_ids: Collection[str],
) -> Dict[str, EventMetadata]:
    clause, args = make_in_list_sql_clause(
        self.database_engine, "e.event_id", batch_ids
    )

    sql = f"""
        SELECT e.event_id, e.room_id, e.type, se.state_key, r.reason
        FROM events AS e
        LEFT JOIN state_events se USING (event_id)
        LEFT JOIN rejections r USING (event_id)
        WHERE {clause}
    """

    txn.execute(sql, args)
    return {
        event_id: EventMetadata(
            room_id=room_id,
            event_type=event_type,
            state_key=state_key,
            rejection_reason=rejection_reason,
        )
        for event_id, room_id, event_type, state_key, rejection_reason in txn
    }

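Callers typically feed a helper like this in bounded batches rather than one giant IN list. A hedged sketch of that pattern, treating the txn helper above as if it were callable standalone and assuming Synapse's `batch_iter` utility; the wrapper name, `store` parameter, and batch size are illustrative:

from synapse.util.iterutils import batch_iter

async def get_metadata_for_events(store, event_ids):
    # Illustrative wrapper: run the txn helper over chunks of 500 IDs so the
    # generated IN (...) clause stays a reasonable size, merging the results.
    metadata = {}
    for batch in batch_iter(event_ids, 500):
        metadata.update(
            await store.db_pool.runInteraction(
                "get_metadata_for_events",
                get_metadata_for_events_txn,
                batch,
            )
        )
    return metadata
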
def _get_applicable_edits_txn(txn: LoggingTransaction) -> Dict[str, str]:
    clause, args = make_in_list_sql_clause(
        txn.database_engine, "relates_to_id", event_ids
    )
    args.append(RelationTypes.REPLACE)

    txn.execute(sql % (clause,), args)
    return dict(cast(Iterable[Tuple[str, str]], txn.fetchall()))

def _get_if_events_have_relations(txn) -> List[str]:
    clauses: List[str] = []
    clause, args = make_in_list_sql_clause(
        txn.database_engine, "relates_to_id", parent_ids
    )
    clauses.append(clause)

    if relation_senders:
        clause, temp_args = make_in_list_sql_clause(
            txn.database_engine, "sender", relation_senders
        )
        clauses.append(clause)
        args.extend(temp_args)

    if relation_types:
        clause, temp_args = make_in_list_sql_clause(
            txn.database_engine, "relation_type", relation_types
        )
        clauses.append(clause)
        args.extend(temp_args)

    txn.execute(sql % " AND ".join(clauses), args)
    return [row[0] for row in txn]

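The clause-list pattern above composes optional filters while keeping every value parameterised. A minimal sqlite3 sketch of the same idea, with an invented table, a simplified local `in_clause` helper, and illustrative filter values:

import sqlite3
from typing import Any, List

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE relations (parent_id TEXT, sender TEXT, relation_type TEXT)")
conn.executemany(
    "INSERT INTO relations VALUES (?, ?, ?)",
    [("$p1", "@a:test", "m.thread"), ("$p2", "@b:test", "m.annotation")],
)

def in_clause(column: str, values: List[Any]):
    return "%s IN (%s)" % (column, ",".join("?" for _ in values)), list(values)

# Mandatory filter first, then each optional filter contributes a clause
# string plus its own parameters; join the clauses with AND at the end.
clauses, args = [], []
clause, clause_args = in_clause("parent_id", ["$p1", "$p2"])
clauses.append(clause)
args.extend(clause_args)

senders = ["@a:test"]  # optional filter; may be empty
if senders:
    clause, clause_args = in_clause("sender", senders)
    clauses.append(clause)
    args.extend(clause_args)

sql = "SELECT DISTINCT parent_id FROM relations WHERE " + " AND ".join(clauses)
print(conn.execute(sql, args).fetchall())  # [('$p1',)]
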
async def get_aggregation_groups_for_users(
    self,
    event_id: str,
    room_id: str,
    limit: int,
    users: FrozenSet[str] = frozenset(),
) -> Dict[Tuple[str, str], int]:
    """Fetch the partial aggregations for an event for specific users.

    This is used, in conjunction with get_aggregation_groups_for_event, to
    remove information from the results for ignored users.

    Args:
        event_id: Fetch events that relate to this event ID.
        room_id: The room the event belongs to.
        limit: Only fetch the `limit` groups.
        users: The users to fetch information for.

    Returns:
        A map of (event type, aggregation key) to a count of users.
    """

    if not users:
        return {}

    args: List[Union[str, int]] = [
        event_id,
        room_id,
        RelationTypes.ANNOTATION,
    ]

    users_sql, users_args = make_in_list_sql_clause(
        self.database_engine, "sender", users
    )
    args.extend(users_args)

    sql = f"""
        SELECT type, aggregation_key, COUNT(DISTINCT sender)
        FROM event_relations
        INNER JOIN events USING (event_id)
        WHERE relates_to_id = ? AND room_id = ? AND relation_type = ? AND {users_sql}
        GROUP BY relation_type, type, aggregation_key
        ORDER BY COUNT(*) DESC
        LIMIT ?
    """

    def _get_aggregation_groups_for_users_txn(
        txn: LoggingTransaction,
    ) -> Dict[Tuple[str, str], int]:
        txn.execute(sql, args + [limit])

        return {(row[0], row[1]): row[2] for row in txn}

    return await self.db_pool.runInteraction(
        "get_aggregation_groups_for_users", _get_aggregation_groups_for_users_txn
    )

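Per the docstring, these per-user counts exist to be subtracted from the event's full aggregations. A hedged sketch of that subtraction; the wrapper, the `limit` value, and the `{"type", "key", "count"}` shape assumed for `get_aggregation_groups_for_event` are illustrative rather than the confirmed API:

async def annotations_without_ignored_users(store, event_id, room_id, ignored_users):
    # Illustrative only: fetch the full aggregation groups, then subtract the
    # contribution of ignored senders using the method above.
    full_counts = await store.get_aggregation_groups_for_event(event_id, room_id)
    ignored_counts = await store.get_aggregation_groups_for_users(
        event_id, room_id, limit=5, users=frozenset(ignored_users)
    )

    results = []
    for entry in full_counts:  # assumed shape: {"type": ..., "key": ..., "count": ...}
        remaining = entry["count"] - ignored_counts.get((entry["type"], entry["key"]), 0)
        if remaining > 0:
            results.append({"type": entry["type"], "key": entry["key"], "count": remaining})
    return results
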
def _get_bare_e2e_cross_signing_keys_bulk_txn(
    self,
    txn: Connection,
    user_ids: List[str],
) -> Dict[str, Dict[str, dict]]:
    """Returns the cross-signing keys for a set of users. The output of this
    function should be passed to _get_e2e_cross_signing_signatures_txn if
    the signatures for the calling user need to be fetched.

    Args:
        txn (twisted.enterprise.adbapi.Connection): db connection
        user_ids (list[str]): the users whose keys are being requested

    Returns:
        dict[str, dict[str, dict]]: mapping from user ID to key type to key
            data. If a user's cross-signing keys were not found, their user
            ID will not be in the dict.
    """
    result = {}

    for user_chunk in batch_iter(user_ids, 100):
        clause, params = make_in_list_sql_clause(
            txn.database_engine, "k.user_id", user_chunk
        )
        sql = (
            """
            SELECT k.user_id, k.keytype, k.keydata, k.stream_id
            FROM e2e_cross_signing_keys k
            INNER JOIN (SELECT user_id, keytype, MAX(stream_id) AS stream_id
                        FROM e2e_cross_signing_keys
                        GROUP BY user_id, keytype) s
            USING (user_id, stream_id, keytype)
            WHERE
            """
            + clause
        )
        txn.execute(sql, params)
        rows = self.db.cursor_to_dict(txn)

        for row in rows:
            user_id = row["user_id"]
            key_type = row["keytype"]
            key = json.loads(row["keydata"])
            user_info = result.setdefault(user_id, {})
            user_info[key_type] = key

    return result

def _get_threads_participated_txn(txn: LoggingTransaction) -> Set[str]:
    # Fetch whether the requester has participated or not.
    sql = """
        SELECT DISTINCT relates_to_id
        FROM events AS child
        INNER JOIN event_relations USING (event_id)
        INNER JOIN events AS parent ON
            parent.event_id = relates_to_id
            AND parent.room_id = child.room_id
        WHERE
            %s
            AND relation_type = ?
            AND child.sender = ?
    """

    clause, args = make_in_list_sql_clause(
        txn.database_engine, "relates_to_id", event_ids
    )
    args.extend([RelationTypes.THREAD, user_id])

    txn.execute(sql % (clause,), args)
    return {row[0] for row in txn.fetchall()}

async def get_mutual_event_relations(
    self, event_id: str, relation_types: Collection[str]
) -> Dict[str, Set[Tuple[str, str]]]:
    """
    Fetch event metadata for events which relate to the same event as the given event.

    If the given event has no relation information, returns an empty dictionary.

    Args:
        event_id: The event ID which is targeted by relations.
        relation_types: The relation types to check for mutual relations.

    Returns:
        A dictionary of relation type to:
            A set of tuples of:
                The sender
                The event type
    """
    rel_type_sql, rel_type_args = make_in_list_sql_clause(
        self.database_engine, "relation_type", relation_types
    )

    sql = f"""
        SELECT DISTINCT relation_type, sender, type FROM event_relations
        INNER JOIN events USING (event_id)
        WHERE relates_to_id = ? AND {rel_type_sql}
    """

    def _get_event_relations(
        txn: LoggingTransaction,
    ) -> Dict[str, Set[Tuple[str, str]]]:
        txn.execute(sql, [event_id] + rel_type_args)
        result: Dict[str, Set[Tuple[str, str]]] = {
            rel_type: set() for rel_type in relation_types
        }
        for rel_type, sender, type in txn.fetchall():
            result[rel_type].add((sender, type))
        return result

    return await self.db_pool.runInteraction(
        "get_event_relations", _get_event_relations
    )

def _reap_users(txn, reserved_users):
    """
    Args:
        reserved_users (tuple): reserved users to preserve
    """

    thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)

    in_clause, in_clause_args = make_in_list_sql_clause(
        self.database_engine, "user_id", reserved_users
    )

    txn.execute(
        "DELETE FROM monthly_active_users WHERE timestamp < ? AND NOT %s"
        % (in_clause,),
        [thirty_days_ago] + in_clause_args,
    )

    if self._limit_usage_by_mau:
        # If MAU user count still exceeds the MAU threshold, then delete on
        # a least recently active basis.
        # Note it is not possible to write this query using OFFSET due to
        # incompatibilities in how sqlite and postgres support the feature.
        # Sqlite requires 'LIMIT -1 OFFSET ?', the LIMIT must be present,
        # while Postgres does not require 'LIMIT', but also does not support
        # negative LIMIT values. So there is no way to write it that both can
        # support.

        # Limit must be >= 0 for postgres
        num_of_non_reserved_users_to_remove = max(
            self._max_mau_value - len(reserved_users), 0
        )

        # It is important to filter reserved users twice to guard
        # against the case where the reserved user is present in the
        # SELECT, meaning that a legitimate mau is deleted.
        sql = """
            DELETE FROM monthly_active_users
            WHERE user_id NOT IN (
                SELECT user_id FROM monthly_active_users
                WHERE NOT %s
                ORDER BY timestamp DESC
                LIMIT ?
            )
            AND NOT %s
        """ % (
            in_clause,
            in_clause,
        )

        query_args = (
            in_clause_args + [num_of_non_reserved_users_to_remove] + in_clause_args
        )
        txn.execute(sql, query_args)

    # It seems poor to invalidate the whole cache. Postgres supports
    # 'Returning' which would allow me to invalidate only the
    # specific users, but sqlite has no way to do this and instead
    # I would need to SELECT and the DELETE which without locking
    # is racy.
    # Have resolved to invalidate the whole cache for now and do
    # something about it if and when the perf becomes significant
    self._invalidate_all_cache_and_stream(txn, self.user_last_seen_monthly_active)
    self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ())

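A toy sqlite3 run of the reaping query above, with invented users, showing why the reserved filter has to appear both inside the SELECT and in the outer DELETE:

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE monthly_active_users (user_id TEXT, timestamp BIGINT)")
conn.executemany(
    "INSERT INTO monthly_active_users VALUES (?, ?)",
    [("@res:test", 100), ("@a:test", 90), ("@b:test", 80), ("@c:test", 70)],
)

reserved = ["@res:test"]
max_mau_value = 3
in_clause = "user_id IN (%s)" % ",".join("?" for _ in reserved)
to_keep = max(max_mau_value - len(reserved), 0)  # 2 non-reserved slots

# The inner SELECT excludes reserved users so they cannot occupy one of the
# `LIMIT ?` slots; otherwise @res:test (the newest row) would displace
# @b:test from the keep-list and a legitimate MAU entry would be deleted.
conn.execute(
    """
    DELETE FROM monthly_active_users
    WHERE user_id NOT IN (
        SELECT user_id FROM monthly_active_users
        WHERE NOT %s
        ORDER BY timestamp DESC
        LIMIT ?
    )
    AND NOT %s
    """ % (in_clause, in_clause),
    reserved + [to_keep] + reserved,
)
print(conn.execute("SELECT user_id FROM monthly_active_users").fetchall())
# [('@res:test',), ('@a:test',), ('@b:test',)] -- only @c:test is reaped
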
def _get_bare_e2e_cross_signing_keys_bulk_txn(
    self,
    txn: Connection,
    user_ids: List[str],
) -> Dict[str, Dict[str, dict]]:
    """Returns the cross-signing keys for a set of users. The output of this
    function should be passed to _get_e2e_cross_signing_signatures_txn if
    the signatures for the calling user need to be fetched.

    Args:
        txn (twisted.enterprise.adbapi.Connection): db connection
        user_ids (list[str]): the users whose keys are being requested

    Returns:
        dict[str, dict[str, dict]]: mapping from user ID to key type to key
            data. If a user's cross-signing keys were not found, their user
            ID will not be in the dict.
    """
    result = {}

    for user_chunk in batch_iter(user_ids, 100):
        clause, params = make_in_list_sql_clause(
            txn.database_engine, "user_id", user_chunk
        )

        # Fetch the latest key for each type per user.
        if isinstance(self.database_engine, PostgresEngine):
            # The `DISTINCT ON` clause will pick the *first* row it
            # encounters, so ordering by stream ID desc will ensure we get
            # the latest key.
            sql = """
                SELECT DISTINCT ON (user_id, keytype) user_id, keytype, keydata, stream_id
                FROM e2e_cross_signing_keys
                WHERE %(clause)s
                ORDER BY user_id, keytype, stream_id DESC
            """ % {
                "clause": clause
            }
        else:
            # SQLite has special handling for bare columns when using
            # MIN/MAX with a `GROUP BY` clause where it picks the value from
            # a row that matches the MIN/MAX.
            sql = """
                SELECT user_id, keytype, keydata, MAX(stream_id)
                FROM e2e_cross_signing_keys
                WHERE %(clause)s
                GROUP BY user_id, keytype
            """ % {
                "clause": clause
            }

        txn.execute(sql, params)
        rows = self.db_pool.cursor_to_dict(txn)

        for row in rows:
            user_id = row["user_id"]
            key_type = row["keytype"]
            key = db_to_json(row["keydata"])
            user_info = result.setdefault(user_id, {})
            user_info[key_type] = key

    return result

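The SQLite bare-column behaviour that the comment relies on is easy to verify in isolation: when a query uses a lone `MIN()`/`MAX()` aggregate, SQLite takes the other selected columns from the row that produced the extreme value. A sketch with invented toy data:

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE keys (user_id TEXT, keytype TEXT, keydata TEXT, stream_id INTEGER)"
)
conn.executemany(
    "INSERT INTO keys VALUES (?, ?, ?, ?)",
    [
        ("@u:test", "master", "old-key", 1),
        ("@u:test", "master", "new-key", 2),
    ],
)

# SQLite's documented special case: the bare `keydata` column is taken from
# the row holding MAX(stream_id), so we get 'new-key' rather than an
# arbitrary row's value.
row = conn.execute(
    "SELECT user_id, keytype, keydata, MAX(stream_id) FROM keys GROUP BY user_id, keytype"
).fetchone()
print(row)  # ('@u:test', 'master', 'new-key', 2)
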
def get_device_messages_txn(
    txn: LoggingTransaction,
) -> Tuple[Dict[Tuple[str, str], List[JsonDict]], int]:
    # Build a query to select messages from any of the given devices that
    # are between the given stream id bounds.

    # If a list of device IDs was not provided, retrieve all device IDs
    # for the given users. We explicitly do not query hidden devices, as
    # hidden devices should not receive to-device messages.
    # Note that this is more efficient than just dropping `device_id` from the query,
    # since device_inbox has an index on `(user_id, device_id, stream_id)`
    if not device_ids_to_query:
        user_device_dicts = self.db_pool.simple_select_many_txn(
            txn,
            table="devices",
            column="user_id",
            iterable=user_ids_to_query,
            keyvalues={"hidden": False},
            retcols=("device_id",),
        )

        device_ids_to_query.update(
            {row["device_id"] for row in user_device_dicts}
        )

    if not device_ids_to_query:
        # We've ended up with no devices to query.
        return {}, to_stream_id

    # We include both user IDs and device IDs in this query, as we have an index
    # (device_inbox_user_stream_id) for them.
    user_id_many_clause_sql, user_id_many_clause_args = make_in_list_sql_clause(
        self.database_engine, "user_id", user_ids_to_query
    )
    (
        device_id_many_clause_sql,
        device_id_many_clause_args,
    ) = make_in_list_sql_clause(
        self.database_engine, "device_id", device_ids_to_query
    )

    sql = f"""
        SELECT stream_id, user_id, device_id, message_json FROM device_inbox
        WHERE {user_id_many_clause_sql}
        AND {device_id_many_clause_sql}
        AND ? < stream_id AND stream_id <= ?
        ORDER BY stream_id ASC
    """
    sql_args = (
        *user_id_many_clause_args,
        *device_id_many_clause_args,
        from_stream_id,
        to_stream_id,
    )

    # If a limit was provided, limit the data retrieved from the database
    if limit is not None:
        sql += " LIMIT ?"
        sql_args += (limit,)

    txn.execute(sql, sql_args)

    # Create and fill a dictionary of (user ID, device ID) -> list of messages
    # intended for each device.
    last_processed_stream_pos = to_stream_id
    recipient_device_to_messages: Dict[Tuple[str, str], List[JsonDict]] = {}
    rowcount = 0
    for row in txn:
        rowcount += 1

        last_processed_stream_pos = row[0]
        recipient_user_id = row[1]
        recipient_device_id = row[2]
        message_dict = db_to_json(row[3])

        # Store the device details
        recipient_device_to_messages.setdefault(
            (recipient_user_id, recipient_device_id), []
        ).append(message_dict)

    if limit is not None and rowcount == limit:
        # We ended up bumping up against the message limit. There may be more
        # messages to retrieve. Return what we have, as well as the last stream
        # position that was processed.
        #
        # The caller is expected to set this as the lower (exclusive) bound
        # for the next query of this device.
        return recipient_device_to_messages, last_processed_stream_pos

    # The limit was not reached, thus we know that recipient_device_to_messages
    # contains all to-device messages for the given device and stream id range.
    #
    # We return to_stream_id, which the caller should then provide as the lower
    # (exclusive) bound on the next query of this device.
    return recipient_device_to_messages, to_stream_id

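A hedged sketch of driving the `(messages, stream_position)` contract described in the comments above; the `get_device_messages` wrapper and its keyword names are assumptions modelled on the variables in the snippet, not a confirmed API:

async def drain_device_messages(store, user_id, device_id, from_pos, to_pos, batch=100):
    # Invented pagination loop: keep querying until the returned stream
    # position reaches the upper bound, feeding each returned position back
    # in as the new lower (exclusive) bound.
    all_messages = []
    lower = from_pos
    while True:
        messages, lower = await store.get_device_messages(
            user_ids_to_query={user_id},
            device_ids_to_query={device_id},
            from_stream_id=lower,
            to_stream_id=to_pos,
            limit=batch,
        )
        all_messages.extend(messages.get((user_id, device_id), []))
        if lower == to_pos:  # nothing left in the range
            return all_messages
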
def _get_thread_summaries_txn(
    txn: LoggingTransaction,
) -> Tuple[Dict[str, int], Dict[str, str]]:
    # Fetch the count of threaded events and the latest event ID.
    # TODO Should this only allow m.room.message events?
    if isinstance(self.database_engine, PostgresEngine):
        # The `DISTINCT ON` clause will pick the *first* row it encounters,
        # so ordering by topological ordering + stream ordering desc will
        # ensure we get the latest event in the thread.
        sql = """
            SELECT DISTINCT ON (parent.event_id) parent.event_id, child.event_id
            FROM events AS child
            INNER JOIN event_relations USING (event_id)
            INNER JOIN events AS parent ON
                parent.event_id = relates_to_id
                AND parent.room_id = child.room_id
            WHERE
                %s
                AND relation_type = ?
            ORDER BY parent.event_id, child.topological_ordering DESC, child.stream_ordering DESC
        """
    else:
        # SQLite uses a simplified query which returns all entries for a
        # thread. The first result for each thread is chosen, and subsequent
        # results for that thread are ignored.
        sql = """
            SELECT parent.event_id, child.event_id
            FROM events AS child
            INNER JOIN event_relations USING (event_id)
            INNER JOIN events AS parent ON
                parent.event_id = relates_to_id
                AND parent.room_id = child.room_id
            WHERE
                %s
                AND relation_type = ?
            ORDER BY child.topological_ordering DESC, child.stream_ordering DESC
        """

    clause, args = make_in_list_sql_clause(
        txn.database_engine, "relates_to_id", event_ids
    )
    args.append(RelationTypes.THREAD)

    txn.execute(sql % (clause,), args)
    latest_event_ids = {}
    for parent_event_id, child_event_id in txn:
        # Only consider the latest threaded reply (by topological ordering).
        if parent_event_id not in latest_event_ids:
            latest_event_ids[parent_event_id] = child_event_id

    # If no threads were found, bail.
    if not latest_event_ids:
        return {}, latest_event_ids

    # Fetch the number of threaded replies.
    sql = """
        SELECT parent.event_id, COUNT(child.event_id)
        FROM events AS child
        INNER JOIN event_relations USING (event_id)
        INNER JOIN events AS parent ON
            parent.event_id = relates_to_id
            AND parent.room_id = child.room_id
        WHERE
            %s
            AND relation_type = ?
        GROUP BY parent.event_id
    """

    # Regenerate the arguments since only threads found above could
    # possibly have any replies.
    clause, args = make_in_list_sql_clause(
        txn.database_engine, "relates_to_id", latest_event_ids.keys()
    )
    args.append(RelationTypes.THREAD)

    txn.execute(sql % (clause,), args)
    counts = dict(cast(List[Tuple[str, int]], txn.fetchall()))

    return counts, latest_event_ids

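To see why the SQLite branch still yields the latest reply per thread, here is the first-row-wins reduction in isolation, with invented rows pre-sorted newest-first (mirroring the query's ORDER BY); the kept rows match what Postgres's `DISTINCT ON (parent.event_id)` would return:

from typing import Dict, List, Tuple

# Rows as the SQLite query would return them: every thread reply, ordered
# newest-first by topological + stream ordering.
rows: List[Tuple[str, str]] = [
    ("$thread1", "$reply3"),  # newest reply in thread1
    ("$thread1", "$reply2"),
    ("$thread2", "$reply9"),  # newest reply in thread2
    ("$thread1", "$reply1"),
]

# First-row-wins: keep only the first child seen for each parent.
latest_event_ids: Dict[str, str] = {}
for parent_event_id, child_event_id in rows:
    if parent_event_id not in latest_event_ids:
        latest_event_ids[parent_event_id] = child_event_id

print(latest_event_ids)  # {'$thread1': '$reply3', '$thread2': '$reply9'}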