Example #1
        def _get_threaded_messages_per_user_txn(
            txn: LoggingTransaction, ) -> Dict[Tuple[str, str], int]:
            users_sql, users_args = make_in_list_sql_clause(
                self.database_engine, "child.sender", users)
            events_clause, events_args = make_in_list_sql_clause(
                txn.database_engine, "relates_to_id", event_ids)

            txn.execute(
                sql % (users_sql, events_clause),
                [RelationTypes.THREAD] + users_args + events_args,
            )
            return {(row[0], row[1]): row[2] for row in txn}
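For context: every example in this collection leans on make_in_list_sql_clause, which returns a SQL fragment plus the matching parameter list. On SQLite it produces a "col IN (?,?,...)" clause with one argument per value; on PostgreSQL, a "col = ANY(?)" clause whose single argument is the whole list. A minimal, self-contained sketch of that contract, using a hypothetical stand-in helper rather than the real Synapse function:

    from typing import Iterable, List, Tuple

    def in_list_clause_sketch(
        column: str, values: Iterable[str], postgres: bool
    ) -> Tuple[str, List]:
        """Illustrative stand-in for make_in_list_sql_clause."""
        vals = list(values)
        if postgres:
            # Postgres compares the column against a single array parameter.
            return f"{column} = ANY(?)", [vals]
        # SQLite needs one placeholder per value.
        return f"{column} IN ({','.join('?' for _ in vals)})", vals

    clause, args = in_list_clause_sketch("child.sender", ["@a:hs", "@b:hs"], False)
    print(clause, args)  # child.sender IN (?,?) ['@a:hs', '@b:hs']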
Example #2
    def _reset_federation_positions_txn(self, txn) -> None:
        """Fiddles with the `federation_stream_position` table to make it match
        the configured federation sender instances during start up.
        """

        # The federation sender instances may have changed, so we need to
        # massage the `federation_stream_position` table to have a row per type
        # per instance sending federation. If there is a mismatch we update the
        # table with the correct rows using the *minimum* stream ID seen. This
        # may result in resending of events/EDUs to remote servers, but that is
        # preferable to dropping them.

        if not self._send_federation:
            return

        # Pull out the configured instances. If we don't have a shard config then
        # we assume that we're the only instance sending.
        configured_instances = self._federation_shard_config.instances
        if not configured_instances:
            configured_instances = [self._instance_name]
        elif self._instance_name not in configured_instances:
            return

        instances_in_table = self.db_pool.simple_select_onecol_txn(
            txn,
            table="federation_stream_position",
            keyvalues={},
            retcol="instance_name",
        )

        if set(instances_in_table) == set(configured_instances):
            # Nothing to do
            return

        sql = """
            SELECT type, MIN(stream_id) FROM federation_stream_position
            GROUP BY type
        """
        txn.execute(sql)
        min_positions = dict(txn)  # Map from type -> min position

        # Ensure we do actually have some values here
        assert set(min_positions) == {"federation", "events"}

        sql = """
            DELETE FROM federation_stream_position
            WHERE NOT (%s)
        """
        clause, args = make_in_list_sql_clause(
            txn.database_engine, "instance_name", configured_instances
        )
        txn.execute(sql % (clause,), args)

        for typ, stream_id in min_positions.items():
            self.db_pool.simple_upsert_txn(
                txn,
                table="federation_stream_position",
                keyvalues={"type": typ, "instance_name": self._instance_name},
                values={"stream_id": stream_id},
            )
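The important detail above is that mismatches are resolved with the *minimum* stream ID per type, so a changed instance set causes re-sending rather than dropping. A trivial sketch of what the GROUP BY query (and the dict(txn) call) ends up holding, with invented row values:

    # Hypothetical (type, stream_id) rows from federation_stream_position.
    rows = [("federation", 10), ("events", 7), ("federation", 4)]

    min_positions = {}
    for typ, stream_id in rows:
        # Mirrors SELECT type, MIN(stream_id) ... GROUP BY type.
        min_positions[typ] = min(stream_id, min_positions.get(typ, stream_id))

    print(min_positions)  # {'federation': 4, 'events': 7}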
Example #3
        def _get_bulk_e2e_unused_fallback_keys_txn(
            txn: LoggingTransaction, ) -> TransactionUnusedFallbackKeys:
            user_in_where_clause, user_parameters = make_in_list_sql_clause(
                self.database_engine, "devices.user_id", user_ids)
            # We can't use USING here because we require the `.used` condition
            # to be part of the JOIN condition so that we generate empty lists
            # when all keys are used (as opposed to just when there are no keys at all).
            sql = f"""
                SELECT devices.user_id, devices.device_id, algorithm
                FROM devices
                LEFT JOIN e2e_fallback_keys_json AS fallback_keys
                    ON devices.user_id = fallback_keys.user_id
                    AND devices.device_id = fallback_keys.device_id
                    AND NOT fallback_keys.used
                WHERE
                    {user_in_where_clause}
            """
            txn.execute(sql, user_parameters)

            result: TransactionUnusedFallbackKeys = {}

            for user_id, device_id, algorithm in txn:
                # We deliberately construct empty dictionaries and lists for
                # users and devices without any unused fallback keys.
                # We *could* omit these empty dicts if there have been no
                # changes since the last transaction, but we currently don't
                # do any change tracking!
                device_unused_keys = result.setdefault(user_id, {}).setdefault(
                    device_id, [])
                if algorithm is not None:
                    # algorithm will be None if this device has no keys.
                    device_unused_keys.append(algorithm)

            return result
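The comment about USING is worth demonstrating: if the NOT fallback_keys.used test moved from the JOIN condition into the WHERE clause, a device whose only fallback key has been used would vanish from the results instead of appearing with an empty key list. A runnable toy version, with an invented two-table schema rather than the real one:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.executescript("""
        CREATE TABLE devices (user_id TEXT, device_id TEXT);
        CREATE TABLE fallback_keys (
            user_id TEXT, device_id TEXT, algorithm TEXT, used BOOLEAN);
        INSERT INTO devices VALUES ('@u:hs', 'DEV1');
        INSERT INTO fallback_keys VALUES ('@u:hs', 'DEV1', 'alg1', 1);
    """)

    # Filter inside the JOIN: the device survives with a NULL algorithm,
    # so the caller can still build an empty key list for it.
    in_join = conn.execute("""
        SELECT devices.device_id, fallback_keys.algorithm FROM devices
        LEFT JOIN fallback_keys
            ON devices.user_id = fallback_keys.user_id
            AND devices.device_id = fallback_keys.device_id
            AND NOT fallback_keys.used
    """).fetchall()

    # Filter in WHERE: the joined row fails NOT used, so the device
    # disappears from the result entirely.
    in_where = conn.execute("""
        SELECT devices.device_id, fallback_keys.algorithm FROM devices
        LEFT JOIN fallback_keys
            ON devices.user_id = fallback_keys.user_id
            AND devices.device_id = fallback_keys.device_id
        WHERE NOT fallback_keys.used
    """).fetchall()

    print(in_join)   # [('DEV1', None)]
    print(in_where)  # []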
Example #4
        def _count_bulk_e2e_one_time_keys_txn(
            txn: LoggingTransaction, ) -> TransactionOneTimeKeyCounts:
            user_in_where_clause, user_parameters = make_in_list_sql_clause(
                self.database_engine, "user_id", user_ids)
            sql = f"""
                SELECT user_id, device_id, algorithm, COUNT(key_id)
                FROM devices
                LEFT JOIN e2e_one_time_keys_json USING (user_id, device_id)
                WHERE {user_in_where_clause}
                GROUP BY user_id, device_id, algorithm
            """
            txn.execute(sql, user_parameters)

            result: TransactionOneTimeKeyCounts = {}

            for user_id, device_id, algorithm, count in txn:
                # We deliberately construct empty dictionaries for
                # users and devices without any unused one-time keys.
                # We *could* omit these empty dicts if there have been no
                # changes since the last transaction, but we currently don't
                # do any change tracking!
                device_count_by_algo = result.setdefault(user_id, {}).setdefault(
                    device_id, {}
                )
                if algorithm is not None:
                    # algorithm will be None if this device has no keys.
                    device_count_by_algo[algorithm] = count

            return result
Example #5
        def get_metadata_for_events_txn(
            txn: LoggingTransaction,
            batch_ids: Collection[str],
        ) -> Dict[str, EventMetadata]:
            clause, args = make_in_list_sql_clause(self.database_engine,
                                                   "e.event_id", batch_ids)

            sql = f"""
                SELECT e.event_id, e.room_id, e.type, se.state_key, r.reason
                FROM events AS e
                LEFT JOIN state_events se USING (event_id)
                LEFT JOIN rejections r USING (event_id)
                WHERE {clause}
            """

            txn.execute(sql, args)
            return {
                event_id: EventMetadata(
                    room_id=room_id,
                    event_type=event_type,
                    state_key=state_key,
                    rejection_reason=rejection_reason,
                )
                for event_id, room_id, event_type, state_key, rejection_reason
                in txn
            }
Example #6
        def _get_applicable_edits_txn(
                txn: LoggingTransaction) -> Dict[str, str]:
            clause, args = make_in_list_sql_clause(txn.database_engine,
                                                   "relates_to_id", event_ids)
            args.append(RelationTypes.REPLACE)

            txn.execute(sql % (clause, ), args)
            return dict(cast(Iterable[Tuple[str, str]], txn.fetchall()))
Example #7
        def _get_if_events_have_relations(txn) -> List[str]:
            clauses: List[str] = []
            clause, args = make_in_list_sql_clause(txn.database_engine,
                                                   "relates_to_id", parent_ids)
            clauses.append(clause)

            if relation_senders:
                clause, temp_args = make_in_list_sql_clause(
                    txn.database_engine, "sender", relation_senders)
                clauses.append(clause)
                args.extend(temp_args)
            if relation_types:
                clause, temp_args = make_in_list_sql_clause(
                    txn.database_engine, "relation_type", relation_types)
                clauses.append(clause)
                args.extend(temp_args)

            txn.execute(sql % " AND ".join(clauses), args)

            return [row[0] for row in txn]
Example #8
    async def get_aggregation_groups_for_users(
            self,
            event_id: str,
            room_id: str,
            limit: int,
            users: FrozenSet[str] = frozenset(),
    ) -> Dict[Tuple[str, str], int]:
        """Fetch the partial aggregations for an event for specific users.

        This is used, in conjunction with get_aggregation_groups_for_event, to
        remove information from the results for ignored users.

        Args:
            event_id: Fetch events that relate to this event ID.
            room_id: The room the event belongs to.
            limit: Only fetch the `limit` groups.
            users: The users to fetch information for.

        Returns:
            A map of (event type, aggregation key) to a count of users.
        """

        if not users:
            return {}

        args: List[Union[str, int]] = [
            event_id,
            room_id,
            RelationTypes.ANNOTATION,
        ]

        users_sql, users_args = make_in_list_sql_clause(
            self.database_engine, "sender", users)
        args.extend(users_args)

        sql = f"""
            SELECT type, aggregation_key, COUNT(DISTINCT sender)
            FROM event_relations
            INNER JOIN events USING (event_id)
            WHERE relates_to_id = ? AND room_id = ? AND relation_type = ? AND {users_sql}
            GROUP BY relation_type, type, aggregation_key
            ORDER BY COUNT(*) DESC
            LIMIT ?
        """

        def _get_aggregation_groups_for_users_txn(
            txn: LoggingTransaction, ) -> Dict[Tuple[str, str], int]:
            txn.execute(sql, args + [limit])

            return {(row[0], row[1]): row[2] for row in txn}

        return await self.db_pool.runInteraction(
            "get_aggregation_groups_for_users",
            _get_aggregation_groups_for_users_txn)
Example #9
    def _get_bare_e2e_cross_signing_keys_bulk_txn(
        self, txn: Connection, user_ids: List[str],
    ) -> Dict[str, Dict[str, dict]]:
        """Returns the cross-signing keys for a set of users.  The output of this
        function should be passed to _get_e2e_cross_signing_signatures_txn if
        the signatures for the calling user need to be fetched.

        Args:
            txn (twisted.enterprise.adbapi.Connection): db connection
            user_ids (list[str]): the users whose keys are being requested

        Returns:
            dict[str, dict[str, dict]]: mapping from user ID to key type to key
                data.  If a user's cross-signing keys were not found, their user
                ID will not be in the dict.

        """
        result = {}

        for user_chunk in batch_iter(user_ids, 100):
            clause, params = make_in_list_sql_clause(
                txn.database_engine, "k.user_id", user_chunk
            )
            sql = (
                """
                SELECT k.user_id, k.keytype, k.keydata, k.stream_id
                  FROM e2e_cross_signing_keys k
                  INNER JOIN (SELECT user_id, keytype, MAX(stream_id) AS stream_id
                                FROM e2e_cross_signing_keys
                               GROUP BY user_id, keytype) s
                 USING (user_id, stream_id, keytype)
                 WHERE
            """
                + clause
            )

            txn.execute(sql, params)
            rows = self.db.cursor_to_dict(txn)

            for row in rows:
                user_id = row["user_id"]
                key_type = row["keytype"]
                key = json.loads(row["keydata"])
                user_info = result.setdefault(user_id, {})
                user_info[key_type] = key

        return result
Example #10
        def _get_threads_participated_txn(txn: LoggingTransaction) -> Set[str]:
            # Fetch whether the requester has participated or not.
            sql = """
                SELECT DISTINCT relates_to_id
                FROM events AS child
                INNER JOIN event_relations USING (event_id)
                INNER JOIN events AS parent ON
                    parent.event_id = relates_to_id
                    AND parent.room_id = child.room_id
                WHERE
                    %s
                    AND relation_type = ?
                    AND child.sender = ?
            """

            clause, args = make_in_list_sql_clause(txn.database_engine,
                                                   "relates_to_id", event_ids)
            args.extend([RelationTypes.THREAD, user_id])

            txn.execute(sql % (clause, ), args)
            return {row[0] for row in txn.fetchall()}
Example #11
    async def get_mutual_event_relations(
            self, event_id: str, relation_types: Collection[str]
    ) -> Dict[str, Set[Tuple[str, str]]]:
        """
        Fetch event metadata for events which relate to the same event as the given event.

        If the given event has no relation information, returns an empty dictionary.

        Args:
            event_id: The event ID which is targeted by relations.
            relation_types: The relation types to check for mutual relations.

        Returns:
            A dictionary of relation type to:
                A set of tuples of:
                    The sender
                    The event type
        """
        rel_type_sql, rel_type_args = make_in_list_sql_clause(
            self.database_engine, "relation_type", relation_types)

        sql = f"""
            SELECT DISTINCT relation_type, sender, type FROM event_relations
            INNER JOIN events USING (event_id)
            WHERE relates_to_id = ? AND {rel_type_sql}
        """

        def _get_event_relations(
            txn: LoggingTransaction, ) -> Dict[str, Set[Tuple[str, str]]]:
            txn.execute(sql, [event_id] + rel_type_args)
            result: Dict[str, Set[Tuple[str, str]]] = {
                rel_type: set()
                for rel_type in relation_types
            }
            for rel_type, sender, type in txn.fetchall():
                result[rel_type].add((sender, type))
            return result

        return await self.db_pool.runInteraction("get_event_relations",
                                                 _get_event_relations)
Example #12
        def _reap_users(txn, reserved_users):
            """
            Args:
                reserved_users (tuple): reserved users to preserve
            """

            thirty_days_ago = int(
                self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)

            in_clause, in_clause_args = make_in_list_sql_clause(
                self.database_engine, "user_id", reserved_users)

            txn.execute(
                "DELETE FROM monthly_active_users WHERE timestamp < ? AND NOT %s"
                % (in_clause, ),
                [thirty_days_ago] + in_clause_args,
            )

            if self._limit_usage_by_mau:
                # If MAU user count still exceeds the MAU threshold, then delete on
                # a least recently active basis.
                # Note it is not possible to write this query using OFFSET due to
                # incompatibilities in how sqlite and postgres support the feature.
                # Sqlite requires 'LIMIT -1 OFFSET ?', the LIMIT must be present,
                # while Postgres does not require 'LIMIT', but also does not support
                # negative LIMIT values. So there is no way to write the
                # query in a form that both engines support.

                # Limit must be >= 0 for postgres
                num_of_non_reserved_users_to_remove = max(
                    self._max_mau_value - len(reserved_users), 0)

                # It is important to filter reserved users twice to guard
                # against the case where the reserved user is present in the
                # SELECT, meaning that a legitimate mau is deleted.
                sql = """
                    DELETE FROM monthly_active_users
                    WHERE user_id NOT IN (
                        SELECT user_id FROM monthly_active_users
                        WHERE NOT %s
                        ORDER BY timestamp DESC
                        LIMIT ?
                    )
                    AND NOT %s
                """ % (
                    in_clause,
                    in_clause,
                )

                query_args = (in_clause_args +
                              [num_of_non_reserved_users_to_remove] +
                              in_clause_args)
                txn.execute(sql, query_args)

            # It seems poor to invalidate the whole cache. Postgres supports
            # RETURNING, which would allow us to invalidate only the specific
            # users, but sqlite has no way to do this; we would instead need
            # a SELECT followed by a DELETE, which is racy without locking.
            # We have resolved to invalidate the whole cache for now, and to
            # revisit this if and when the performance impact becomes
            # significant.
            self._invalidate_all_cache_and_stream(
                txn, self.user_last_seen_monthly_active)
            self._invalidate_cache_and_stream(
                txn, self.get_monthly_active_count, ())
Example #13
    def _get_bare_e2e_cross_signing_keys_bulk_txn(
        self,
        txn: Connection,
        user_ids: List[str],
    ) -> Dict[str, Dict[str, dict]]:
        """Returns the cross-signing keys for a set of users.  The output of this
        function should be passed to _get_e2e_cross_signing_signatures_txn if
        the signatures for the calling user need to be fetched.

        Args:
            txn (twisted.enterprise.adbapi.Connection): db connection
            user_ids (list[str]): the users whose keys are being requested

        Returns:
            dict[str, dict[str, dict]]: mapping from user ID to key type to key
                data.  If a user's cross-signing keys were not found, their user
                ID will not be in the dict.

        """
        result = {}

        for user_chunk in batch_iter(user_ids, 100):
            clause, params = make_in_list_sql_clause(txn.database_engine,
                                                     "user_id", user_chunk)

            # Fetch the latest key for each type per user.
            if isinstance(self.database_engine, PostgresEngine):
                # The `DISTINCT ON` clause will pick the *first* row it
                # encounters, so ordering by stream ID desc will ensure we get
                # the latest key.
                sql = """
                    SELECT DISTINCT ON (user_id, keytype) user_id, keytype, keydata, stream_id
                        FROM e2e_cross_signing_keys
                        WHERE %(clause)s
                        ORDER BY user_id, keytype, stream_id DESC
                """ % {
                    "clause": clause
                }
            else:
                # SQLite has special handling for bare columns when using
                # MIN/MAX with a `GROUP BY` clause where it picks the value from
                # a row that matches the MIN/MAX.
                sql = """
                    SELECT user_id, keytype, keydata, MAX(stream_id)
                        FROM e2e_cross_signing_keys
                        WHERE %(clause)s
                        GROUP BY user_id, keytype
                """ % {
                    "clause": clause
                }

            txn.execute(sql, params)
            rows = self.db_pool.cursor_to_dict(txn)

            for row in rows:
                user_id = row["user_id"]
                key_type = row["keytype"]
                key = db_to_json(row["keydata"])
                user_info = result.setdefault(user_id, {})
                user_info[key_type] = key

        return result
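The SQLite branch in this version leans on a documented quirk: when MAX() is used with GROUP BY, SQLite fills the other "bare" selected columns from the row that held the maximum. A toy check with an invented table:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.executescript("""
        CREATE TABLE keys (user_id TEXT, keytype TEXT, keydata TEXT, stream_id INT);
        INSERT INTO keys VALUES ('@u:hs', 'master', 'old', 1);
        INSERT INTO keys VALUES ('@u:hs', 'master', 'new', 2);
    """)

    rows = conn.execute("""
        SELECT user_id, keytype, keydata, MAX(stream_id)
          FROM keys GROUP BY user_id, keytype
    """).fetchall()

    print(rows)  # [('@u:hs', 'master', 'new', 2)] -- keydata from the MAX row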
Example #14
        def get_device_messages_txn(
            txn: LoggingTransaction,
        ) -> Tuple[Dict[Tuple[str, str], List[JsonDict]], int]:
            # Build a query to select messages from any of the given devices that
            # are between the given stream id bounds.

            # If a list of device IDs was not provided, retrieve all device IDs
            # for the given users. We explicitly do not query hidden devices, as
            # hidden devices should not receive to-device messages.
            # Note that this is more efficient than just dropping `device_id` from the query,
            # since device_inbox has an index on `(user_id, device_id, stream_id)`
            if not device_ids_to_query:
                user_device_dicts = self.db_pool.simple_select_many_txn(
                    txn,
                    table="devices",
                    column="user_id",
                    iterable=user_ids_to_query,
                    # The user constraint comes from column/iterable above;
                    # here we only filter out hidden devices.
                    keyvalues={"hidden": False},
                    retcols=("device_id", ),
                )

                device_ids_to_query.update(
                    {row["device_id"]
                     for row in user_device_dicts})

            if not device_ids_to_query:
                # We've ended up with no devices to query.
                return {}, to_stream_id

            # We include both user IDs and device IDs in this query, as we have an index
            # (device_inbox_user_stream_id) for them.
            user_id_many_clause_sql, user_id_many_clause_args = make_in_list_sql_clause(
                self.database_engine, "user_id", user_ids_to_query)
            (
                device_id_many_clause_sql,
                device_id_many_clause_args,
            ) = make_in_list_sql_clause(self.database_engine, "device_id",
                                        device_ids_to_query)

            sql = f"""
                SELECT stream_id, user_id, device_id, message_json FROM device_inbox
                WHERE {user_id_many_clause_sql}
                AND {device_id_many_clause_sql}
                AND ? < stream_id AND stream_id <= ?
                ORDER BY stream_id ASC
            """
            sql_args = (
                *user_id_many_clause_args,
                *device_id_many_clause_args,
                from_stream_id,
                to_stream_id,
            )

            # If a limit was provided, limit the data retrieved from the database
            if limit is not None:
                sql += "LIMIT ?"
                sql_args += (limit, )

            txn.execute(sql, sql_args)

            # Create and fill a dictionary of (user ID, device ID) -> list of messages
            # intended for each device.
            last_processed_stream_pos = to_stream_id
            recipient_device_to_messages: Dict[Tuple[str, str],
                                               List[JsonDict]] = {}
            rowcount = 0
            for row in txn:
                rowcount += 1

                last_processed_stream_pos = row[0]
                recipient_user_id = row[1]
                recipient_device_id = row[2]
                message_dict = db_to_json(row[3])

                # Store the device details
                recipient_device_to_messages.setdefault(
                    (recipient_user_id, recipient_device_id),
                    []).append(message_dict)

            if limit is not None and rowcount == limit:
                # We ended up bumping up against the message limit. There may be more messages
                # to retrieve. Return what we have, as well as the last stream position that
                # was processed.
                #
                # The caller is expected to set this as the lower (exclusive) bound
                # for the next query of this device.
                return recipient_device_to_messages, last_processed_stream_pos

            # The limit was not reached, thus we know that recipient_device_to_messages
            # contains all to-device messages for the given device and stream id range.
            #
            # We return to_stream_id, which the caller should then provide as the lower
            # (exclusive) bound on the next query of this device.
            return recipient_device_to_messages, to_stream_id
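The second element of the returned tuple doubles as a resumption token. A hedged sketch of a caller-side loop built on that contract, where fetch_page is a hypothetical callable standing in for one run of this transaction and stream positions are assumed to advance between pages:

    from typing import Callable, Dict, List, Tuple

    PageResult = Tuple[Dict[Tuple[str, str], List[dict]], int]

    def drain_device_messages(
        fetch_page: Callable[[int, int, int], PageResult],
        from_pos: int,
        to_pos: int,
        limit: int = 100,
    ) -> Dict[Tuple[str, str], List[dict]]:
        """Page through to-device messages, feeding the returned stream
        position back in as the next exclusive lower bound."""
        all_msgs: Dict[Tuple[str, str], List[dict]] = {}
        while True:
            page, last_pos = fetch_page(from_pos, to_pos, limit)
            for key, msgs in page.items():
                all_msgs.setdefault(key, []).extend(msgs)
            if last_pos >= to_pos:
                return all_msgs
            from_pos = last_pos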
Example #15
        def _get_thread_summaries_txn(
            txn: LoggingTransaction,
        ) -> Tuple[Dict[str, int], Dict[str, str]]:
            # Fetch the count of threaded events and the latest event ID.
            # TODO: Should this only allow m.room.message events?
            if isinstance(self.database_engine, PostgresEngine):
                # The `DISTINCT ON` clause will pick the *first* row it encounters,
                # so ordering by topological ordering + stream ordering desc will
                # ensure we get the latest event in the thread.
                sql = """
                    SELECT DISTINCT ON (parent.event_id) parent.event_id, child.event_id FROM events AS child
                    INNER JOIN event_relations USING (event_id)
                    INNER JOIN events AS parent ON
                        parent.event_id = relates_to_id
                        AND parent.room_id = child.room_id
                    WHERE
                        %s
                        AND relation_type = ?
                    ORDER BY parent.event_id, child.topological_ordering DESC, child.stream_ordering DESC
                """
            else:
                # SQLite uses a simplified query which returns all entries for a
                # thread. The first result for each thread is used, and
                # subsequent results for that thread are ignored.
                sql = """
                    SELECT parent.event_id, child.event_id FROM events AS child
                    INNER JOIN event_relations USING (event_id)
                    INNER JOIN events AS parent ON
                        parent.event_id = relates_to_id
                        AND parent.room_id = child.room_id
                    WHERE
                        %s
                        AND relation_type = ?
                    ORDER BY child.topological_ordering DESC, child.stream_ordering DESC
                """

            clause, args = make_in_list_sql_clause(txn.database_engine,
                                                   "relates_to_id", event_ids)
            args.append(RelationTypes.THREAD)

            txn.execute(sql % (clause, ), args)
            latest_event_ids = {}
            for parent_event_id, child_event_id in txn:
                # Only consider the latest threaded reply (by topological ordering).
                if parent_event_id not in latest_event_ids:
                    latest_event_ids[parent_event_id] = child_event_id

            # If no threads were found, bail.
            if not latest_event_ids:
                return {}, latest_event_ids

            # Fetch the number of threaded replies.
            sql = """
                SELECT parent.event_id, COUNT(child.event_id) FROM events AS child
                INNER JOIN event_relations USING (event_id)
                INNER JOIN events AS parent ON
                    parent.event_id = relates_to_id
                    AND parent.room_id = child.room_id
                WHERE
                    %s
                    AND relation_type = ?
                GROUP BY parent.event_id
            """

            # Regenerate the arguments since only threads found above could
            # possibly have any replies.
            clause, args = make_in_list_sql_clause(txn.database_engine,
                                                   "relates_to_id",
                                                   latest_event_ids.keys())
            args.append(RelationTypes.THREAD)

            txn.execute(sql % (clause, ), args)
            counts = dict(cast(List[Tuple[str, int]], txn.fetchall()))

            return counts, latest_event_ids