Example #1
0
        def _do_txn(txn: LoggingTransaction) -> int:
            """Store the filter for the user, returning its filter ID.

            If an identical filter already exists for the user, its ID is
            reused instead of inserting a duplicate row.
            """
            # Reuse an existing, identical filter if one is stored.
            txn.execute(
                "SELECT filter_id FROM user_filters "
                "WHERE user_id = ? AND filter_json = ?",
                (user_localpart, bytearray(def_json)),
            )
            existing = txn.fetchone()
            if existing is not None:
                return existing[0]

            # Otherwise allocate the next filter_id for this user.
            txn.execute(
                "SELECT MAX(filter_id) FROM user_filters WHERE user_id = ?",
                (user_localpart,),
            )
            max_id = cast(Tuple[Optional[int]], txn.fetchone())[0]
            new_filter_id = 0 if max_id is None else max_id + 1

            txn.execute(
                "INSERT INTO user_filters (user_id, filter_id, filter_json)"
                "VALUES(?, ?, ?)",
                (user_localpart, new_filter_id, bytearray(def_json)),
            )
            return new_filter_id
Example #2
0
        def _get_thread_summary_txn(
            txn: LoggingTransaction,
        ) -> Tuple[int, Optional[str]]:
            """Return (thread length, latest thread event ID) for the event."""
            # Find the most recent event in the thread, if there is one.
            # TODO Should this only allow m.room.message events.
            sql = """
                SELECT event_id
                FROM event_relations
                INNER JOIN events USING (event_id)
                WHERE
                    relates_to_id = ?
                    AND relation_type = ?
                ORDER BY topological_ordering DESC, stream_ordering DESC
                LIMIT 1
            """

            txn.execute(sql, (event_id, RelationTypes.THREAD))
            latest = txn.fetchone()
            if latest is None:
                # No thread exists for this event.
                return 0, None

            # Count every event that participates in the thread.
            sql = """
                SELECT COALESCE(COUNT(event_id), 0)
                FROM event_relations
                WHERE
                    relates_to_id = ?
                    AND relation_type = ?
            """
            txn.execute(sql, (event_id, RelationTypes.THREAD))
            thread_count = txn.fetchone()[0]  # type: ignore[index]

            return thread_count, latest[0]
        def get_destination_rooms_paginate_txn(
            txn: LoggingTransaction,
        ) -> Tuple[List[JsonDict], int]:
            """Fetch one page of rooms for `destination`, plus the total count.

            Returns:
                A tuple of the page of room rows (room_id, stream_ordering)
                and the total number of rooms for this destination.
            """
            order = "DESC" if direction == "b" else "ASC"

            sql = """
                SELECT COUNT(*) as total_rooms
                FROM destination_rooms
                WHERE destination = ?
                """
            txn.execute(sql, [destination])
            count = cast(Tuple[int], txn.fetchone())[0]

            rooms = self.db_pool.simple_select_list_paginate_txn(
                txn=txn,
                table="destination_rooms",
                # Bug fix: scope the listing to the same destination as the
                # count above; previously the page mixed in rooms belonging to
                # every destination while the count covered only this one.
                keyvalues={"destination": destination},
                orderby="room_id",
                start=start,
                limit=limit,
                retcols=("room_id", "stream_ordering"),
                order_direction=order,
            )
            return rooms, count
Example #4
0
 def get_sent_table_size(txn: LoggingTransaction) -> int:
     """Count sent_transactions rows with ts at or after `yesterday`."""
     txn.execute(
         "SELECT count(*) FROM sent_transactions"
         " WHERE ts >= ?",
         (yesterday,),
     )
     row = txn.fetchone()
     # COUNT(*) with no GROUP BY always yields exactly one row.
     assert row is not None
     return int(row[0])
        def _claim_e2e_one_time_key_returning(
                txn: LoggingTransaction, user_id: str, device_id: str,
                algorithm: str) -> Optional[Tuple[str, str]]:
            """Claim OTK for device for DBs that support RETURNING.

            Returns:
                A tuple of key name (algorithm + key ID) and key JSON, if an
                OTK was found.
            """

            # RETURNING lets us fetch and DELETE in a single step.
            sql = """
                DELETE FROM e2e_one_time_keys_json
                WHERE user_id = ? AND device_id = ? AND algorithm = ?
                    AND key_id IN (
                        SELECT key_id FROM e2e_one_time_keys_json
                        WHERE user_id = ? AND device_id = ? AND algorithm = ?
                        LIMIT 1
                    )
                RETURNING key_id, key_json
            """

            # The same triple is bound twice: once for the DELETE predicate,
            # once for the inner key-selection subquery.
            params = (user_id, device_id, algorithm)
            txn.execute(sql, params + params)
            claimed = txn.fetchone()
            if claimed is None:
                # No one-time key available for this algorithm.
                return None

            self._invalidate_cache_and_stream(
                txn, self.count_e2e_one_time_keys, (user_id, device_id)
            )

            key_id, key_json = claimed
            return f"{algorithm}:{key_id}", key_json
Example #6
0
    def _delete_room_alias_txn(
        self, txn: LoggingTransaction, room_alias: RoomAlias
    ) -> Optional[str]:
        """Delete `room_alias`, returning the room it pointed at, or None if
        the alias did not exist."""
        alias_str = room_alias.to_string()

        txn.execute(
            "SELECT room_id FROM room_aliases WHERE room_alias = ?",
            (alias_str,),
        )
        row = txn.fetchone()
        if not row:
            # Alias does not exist; nothing to delete.
            return None
        room_id = row[0]

        txn.execute(
            "DELETE FROM room_aliases WHERE room_alias = ?", (alias_str,)
        )
        txn.execute(
            "DELETE FROM room_alias_servers WHERE room_alias = ?",
            (alias_str,),
        )

        self._invalidate_cache_and_stream(txn, self.get_aliases_for_room, (room_id,))

        return room_id
        def get_destinations_paginate_txn(
            txn: LoggingTransaction,
        ) -> Tuple[List[JsonDict], int]:
            """Return one page of destinations plus the total match count."""
            order_by_column = DestinationSortOrder(order_by).value
            order = "DESC" if direction == "b" else "ASC"

            args: List[object] = []
            where_statement = ""
            if destination:
                # Case-insensitive substring match on the destination name.
                args.append("%" + destination.lower() + "%")
                where_statement = "WHERE LOWER(destination) LIKE ?"

            sql_base = f"FROM destinations {where_statement} "

            txn.execute(f"SELECT COUNT(*) as total_destinations {sql_base}", args)
            count = cast(Tuple[int], txn.fetchone())[0]

            sql = f"""
                SELECT destination, retry_last_ts, retry_interval, failure_ts,
                last_successful_stream_ordering
                {sql_base}
                ORDER BY {order_by_column} {order}, destination ASC
                LIMIT ? OFFSET ?
            """
            txn.execute(sql, args + [limit, start])
            destinations = self.db_pool.cursor_to_dict(txn)
            return destinations, count
Example #8
0
        def _get_next_batch(
            txn: LoggingTransaction,
        ) -> Optional[Sequence[Tuple[str, int]]]:
            """Grab the next batch of rooms to process (largest first), or
            None when there is nothing left to do."""
            # Only fetch 250 rooms, so we don't fetch too many at once, even
            # if those 250 rooms have less than batch_size state events.
            sql = """
                SELECT room_id, events FROM %s
                ORDER BY events DESC
                LIMIT 250
            """ % (
                TEMP_TABLE + "_rooms",
            )
            txn.execute(sql)
            batch = cast(List[Tuple[str, int]], txn.fetchall())

            if not batch:
                return None

            # Record how many rooms are left so progress can be reported.
            txn.execute("SELECT COUNT(*) FROM " + TEMP_TABLE + "_rooms")
            remaining = txn.fetchone()
            assert remaining is not None
            progress["remaining"] = remaining[0]

            return batch
Example #9
0
        def _fetch_current_state_stats(
            txn: LoggingTransaction,
        ) -> Tuple[List[str], Dict[str, int], int, List[str], int]:
            """Read the current-state statistics for a single room.

            Returns a tuple of:
                - event IDs of the "interesting" current state events (create,
                  join rules, history visibility, encryption, name, topic,
                  avatar, canonical alias),
                - a map of membership -> number of current members in that
                  membership state,
                - the total number of current state events in the room,
                - the users currently in the room,
                - the stream position these stats were read at.
            """
            # Snapshot the stream position first so callers can tie the stats
            # to a point in time.
            pos = self.get_room_max_stream_ordering(
            )  # type: ignore[attr-defined]

            # Fetch the state events whose content feeds the room stats.
            rows = self.db_pool.simple_select_many_txn(
                txn,
                table="current_state_events",
                column="type",
                iterable=[
                    EventTypes.Create,
                    EventTypes.JoinRules,
                    EventTypes.RoomHistoryVisibility,
                    EventTypes.RoomEncryption,
                    EventTypes.Name,
                    EventTypes.Topic,
                    EventTypes.RoomAvatar,
                    EventTypes.CanonicalAlias,
                ],
                keyvalues={
                    "room_id": room_id,
                    "state_key": ""
                },
                retcols=["event_id"],
            )

            event_ids = cast(List[str], [row["event_id"] for row in rows])

            # Count current members grouped by membership state (join, leave,
            # invite, ...).
            txn.execute(
                """
                    SELECT membership, count(*) FROM current_state_events
                    WHERE room_id = ? AND type = 'm.room.member'
                    GROUP BY membership
                """,
                (room_id, ),
            )
            membership_counts = {membership: cnt for membership, cnt in txn}

            # Total size of the room's current state.
            txn.execute(
                """
                    SELECT COUNT(*) FROM current_state_events
                    WHERE room_id = ?
                """,
                (room_id, ),
            )

            current_state_events_count = cast(Tuple[int], txn.fetchone())[0]

            users_in_room = self.get_users_in_room_txn(
                txn, room_id)  # type: ignore[attr-defined]

            return (
                event_ids,
                membership_counts,
                current_state_events_count,
                users_in_room,
                pos,
            )
Example #10
0
 def _count(txn: LoggingTransaction) -> int:
     """Count distinct rooms with an m.room.message newer than a day ago."""
     sql = """
         SELECT COUNT(DISTINCT room_id) FROM events
         WHERE type = 'm.room.message'
         AND stream_ordering > ?
     """
     txn.execute(sql, (self.stream_ordering_day_ago,))
     # COUNT(...) with no GROUP BY always returns exactly one row.
     row = cast(Tuple[int], txn.fetchone())
     return row[0]
    def _get_unread_counts_by_pos_txn(self, txn: LoggingTransaction,
                                      room_id: str, user_id: str,
                                      stream_ordering: int) -> NotifCounts:
        """Compute the notify/highlight/unread counts for a user in a room,
        considering only events after `stream_ordering`."""
        # Live counts from the un-summarised push actions.
        sql = ("SELECT"
               "   COUNT(CASE WHEN notif = 1 THEN 1 END),"
               "   COUNT(CASE WHEN highlight = 1 THEN 1 END),"
               "   COUNT(CASE WHEN unread = 1 THEN 1 END)"
               " FROM event_push_actions ea"
               " WHERE user_id = ?"
               "   AND room_id = ?"
               "   AND stream_ordering > ?")
        txn.execute(sql, (user_id, room_id, stream_ordering))
        counts = txn.fetchone()

        notif_count = highlight_count = unread_count = 0
        if counts:
            notif_count, highlight_count, unread_count = counts

        # Add in any counts already rolled up into event_push_summary.
        txn.execute(
            """
                SELECT notif_count, unread_count FROM event_push_summary
                WHERE room_id = ? AND user_id = ? AND stream_ordering > ?
            """,
            (room_id, user_id, stream_ordering),
        )
        summary = txn.fetchone()

        if summary:
            notif_count += summary[0]

            # The unread_count column of event_push_summary is NULLable, so we
            # need to make sure we don't try increasing the unread counts if
            # it's NULL for this row.
            if summary[1] is not None:
                unread_count += summary[1]

        return NotifCounts(
            notify_count=notif_count,
            unread_count=unread_count,
            highlight_count=highlight_count,
        )
        def _get_if_maybe_push_in_range_for_user_txn(
                txn: LoggingTransaction) -> bool:
            """True if the user has any notifying push action after the given
            minimum stream ordering."""
            sql = """
                SELECT 1 FROM event_push_actions
                WHERE user_id = ? AND stream_ordering > ? AND notif = 1
                LIMIT 1
            """

            txn.execute(sql, (user_id, min_stream_ordering))
            # Any row at all means there may be something to push.
            return txn.fetchone() is not None
 def f(txn: LoggingTransaction) -> Optional[Tuple[int]]:
     """Fetch received_ts of the earliest notifying push action after the
     given stream ordering, if any."""
     sql = (
         "SELECT e.received_ts"
         " FROM event_push_actions AS ep"
         " JOIN events e ON ep.room_id = e.room_id AND ep.event_id = e.event_id"
         " WHERE ep.stream_ordering > ? AND notif = 1"
         " ORDER BY ep.stream_ordering ASC"
         " LIMIT 1")
     txn.execute(sql, (stream_ordering,))
     row = txn.fetchone()
     return cast(Optional[Tuple[int]], row)
Example #14
0
    def _set_push_rule_enabled_txn(
        self,
        txn: LoggingTransaction,
        stream_id: int,
        event_stream_ordering: int,
        user_id: str,
        rule_id: str,
        enabled: bool,
        is_default_rule: bool,
    ) -> None:
        """Enable or disable a push rule for `user_id` and record the change
        in the push rules stream.

        Raises:
            RuleNotFoundException: if a non-default rule does not exist.
        """
        new_id = self._push_rules_enable_id_gen.get_next()

        if not is_default_rule:
            # first check it exists; we need to lock for key share so that a
            # transaction that deletes the push rule will conflict with this one.
            # We also need a push_rule_enable row to exist for every push_rules
            # row, otherwise it is possible to simultaneously delete a push rule
            # (that has no _enable row) and enable it, resulting in a dangling
            # _enable row. To solve this: we either need to use SERIALISABLE or
            # ensure we always have a push_rule_enable row for every push_rule
            # row. We chose the latter.
            if isinstance(self.database_engine, PostgresEngine):
                for_key_share = "FOR KEY SHARE"
            else:
                # FOR KEY SHARE is not applicable/available on SQLite
                for_key_share = ""
            sql = ("""
                SELECT 1 FROM push_rules
                WHERE user_name = ? AND rule_id = ?
                %s
            """ % for_key_share)
            txn.execute(sql, (user_id, rule_id))
            if txn.fetchone() is None:
                raise RuleNotFoundException("Push rule does not exist.")

        self.db_pool.simple_upsert_txn(
            txn,
            "push_rules_enable",
            {"user_name": user_id, "rule_id": rule_id},
            {"enabled": 1 if enabled else 0},
            {"id": new_id},
        )

        self._insert_push_rules_update_txn(
            txn,
            stream_id,
            event_stream_ordering,
            user_id,
            rule_id,
            op="ENABLE" if enabled else "DISABLE",
        )
Example #15
0
 def _count_users(txn: LoggingTransaction) -> int:
     # Exclude app service users
     sql = """
         SELECT COUNT(*)
         FROM monthly_active_users
             LEFT JOIN users
             ON monthly_active_users.user_id=users.name
         WHERE (users.appservice_id IS NULL OR users.appservice_id = '');
     """
     txn.execute(sql)
     (count, ) = cast(Tuple[int], txn.fetchone())
     return count
Example #16
0
 def _get_current_version(txn: LoggingTransaction, user_id: str) -> int:
     txn.execute(
         "SELECT MAX(version) FROM e2e_room_keys_versions "
         "WHERE user_id=? AND deleted=0",
         (user_id, ),
     )
     # `SELECT MAX() FROM ...` will always return 1 row. The value in that row will
     # be `NULL` when there are no available versions.
     row = cast(Tuple[Optional[int]], txn.fetchone())
     if row[0] is None:
         raise StoreError(404, "No current backup version")
     return row[0]
        def _get_if_user_has_annotated_event(txn: LoggingTransaction) -> bool:
            """True if the sender already has this annotation on the event."""
            txn.execute(
                sql,
                (
                    parent_id,
                    RelationTypes.ANNOTATION,
                    event_type,
                    sender,
                    aggregation_key,
                ),
            )

            # Any row at all means the annotation already exists.
            return txn.fetchone() is not None
        def get_url_cache_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:
            """Look up the cached URL preview nearest to the timestamp `ts`."""
            # Prefer the most recently cached result at or before the given ts.
            sql = (
                "SELECT response_code, etag, expires_ts, og, media_id, download_ts"
                " FROM local_media_repository_url_cache"
                " WHERE url = ? AND download_ts <= ?"
                " ORDER BY download_ts DESC LIMIT 1"
            )
            txn.execute(sql, (url, ts))
            row = txn.fetchone()

            if not row:
                # ...or if we've requested a timestamp older than the oldest
                # copy in the cache, return the oldest copy (if any)
                sql = (
                    "SELECT response_code, etag, expires_ts, og, media_id, download_ts"
                    " FROM local_media_repository_url_cache"
                    " WHERE url = ? AND download_ts > ?"
                    " ORDER BY download_ts ASC LIMIT 1"
                )
                txn.execute(sql, (url, ts))
                row = txn.fetchone()

            if not row:
                return None

            # Re-shape the row into a dict keyed by column name.
            columns = (
                "response_code",
                "etag",
                "expires_ts",
                "og",
                "media_id",
                "download_ts",
            )
            return dict(zip(columns, row))
Example #19
0
        def f(txn: LoggingTransaction) -> None:
            """Mark the user as erased, if they are not already marked."""
            # first check if they are already in the list
            txn.execute("SELECT 1 FROM erased_users WHERE user_id = ?",
                        (user_id,))
            if txn.fetchone() is not None:
                return

            # they are not already there: do the insert.
            txn.execute("INSERT INTO erased_users (user_id) VALUES (?)",
                        (user_id,))

            self._invalidate_cache_and_stream(txn, self.is_user_erased,
                                              (user_id,))
Example #20
0
        def get_users_paginate_txn(
            txn: LoggingTransaction,
        ) -> Tuple[List[JsonDict], int]:
            """Return one page of users plus the total number of matches."""
            filters = []
            args = [self.hs.config.server.server_name]

            # Set ordering
            order_by_column = UserSortOrder(order_by).value
            order = "DESC" if direction == "b" else "ASC"

            # `name` is in database already in lower case
            if name:
                filters.append("(name LIKE ? OR LOWER(displayname) LIKE ?)")
                args.extend(
                    ["@%" + name.lower() + "%:%", "%" + name.lower() + "%"])
            elif user_id:
                filters.append("name LIKE ?")
                args.extend(["%" + user_id.lower() + "%"])

            if not guests:
                filters.append("is_guest = 0")
            if not deactivated:
                filters.append("deactivated = 0")

            if filters:
                where_clause = "WHERE " + " AND ".join(filters)
            else:
                where_clause = ""

            sql_base = f"""
                FROM users as u
                LEFT JOIN profiles AS p ON u.name = '@' || p.user_id || ':' || ?
                {where_clause}
                """
            txn.execute("SELECT COUNT(*) as total_users " + sql_base, args)
            count = cast(Tuple[int], txn.fetchone())[0]

            sql = f"""
                SELECT name, user_type, is_guest, admin, deactivated, shadow_banned,
                displayname, avatar_url, creation_ts * 1000 as creation_ts
                {sql_base}
                ORDER BY {order_by_column} {order}, u.name ASC
                LIMIT ? OFFSET ?
            """
            txn.execute(sql, args + [limit, start])
            users = self.db_pool.cursor_to_dict(txn)
            return users, count
Example #21
0
        def f(txn: LoggingTransaction) -> None:
            """Un-mark the user as erased, if they are currently marked."""
            # first check if they are already in the list
            txn.execute("SELECT 1 FROM erased_users WHERE user_id = ?",
                        (user_id,))
            if txn.fetchone() is None:
                return

            # They are there, delete them.
            self.db_pool.simple_delete_one_txn(
                txn, "erased_users", keyvalues={"user_id": user_id})

            self._invalidate_cache_and_stream(txn, self.is_user_erased,
                                              (user_id,))
Example #22
0
        def _calculate_and_set_initial_state_for_user_txn(
            txn: LoggingTransaction,
        ) -> Tuple[int, int]:
            """Count the rooms the user has joined, plus the stream position
            the count was taken at."""
            pos = self._get_max_stream_id_in_current_state_deltas_txn(txn)

            txn.execute(
                """
                SELECT COUNT(distinct room_id) FROM current_state_events
                    WHERE type = 'm.room.member' AND state_key = ?
                        AND membership = 'join'
                """,
                (user_id,),
            )
            joined_rooms = cast(Tuple[int], txn.fetchone())[0]
            return joined_rooms, pos
Example #23
0
        def _remove_dead_devices_from_device_inbox_txn(
            txn: LoggingTransaction, ) -> Tuple[int, bool]:
            """Delete one batch of device_inbox rows that belong to unknown or
            hidden devices.

            Returns:
                A tuple of the stream ID to resume from next time and whether
                the background update has finished.
            """
            if "max_stream_id" in progress:
                max_stream_id = progress["max_stream_id"]
            else:
                txn.execute("SELECT max(stream_id) FROM device_inbox")
                # There's a type mismatch here between how we want to type the row and
                # what fetchone says it returns, but we silence it because we know that
                # res can't be None.
                res: Tuple[
                    Optional[int]] = txn.fetchone()  # type: ignore[assignment]
                if res[0] is None:
                    # this can only happen if the `device_inbox` table is empty, in which
                    # case we have no work to do.
                    return 0, True
                else:
                    max_stream_id = res[0]

            start = progress.get("stream_id", 0)
            stop = start + batch_size

            # delete rows in `device_inbox` which do *not* correspond to a known,
            # unhidden device.
            sql = """
                DELETE FROM device_inbox
                WHERE
                    stream_id >= ? AND stream_id < ?
                    AND NOT EXISTS (
                        SELECT * FROM devices d
                        WHERE
                            d.device_id=device_inbox.device_id
                            AND d.user_id=device_inbox.user_id
                            AND NOT hidden
                    )
                """

            txn.execute(sql, (start, stop))

            self.db_pool.updates._background_update_progress_txn(
                txn,
                self.REMOVE_DEAD_DEVICES_FROM_INBOX,
                {
                    "stream_id": stop,
                    "max_stream_id": max_stream_id,
                },
            )

            # Bug fix: this previously returned a bare bool, which neither
            # matches the annotated Tuple[int, bool] return type nor the early
            # `return 0, True` above. Return the resume point as well.
            return stop, stop > max_stream_id
Example #24
0
        def _count_messages(txn: LoggingTransaction) -> int:
            """Count messages sent by local users over the last day."""
            # This is good enough as if you have silly characters in your own
            # hostname then that's your own fault.
            like_clause = "%:" + self.hs.hostname

            sql = """
                SELECT COUNT(*) FROM events
                WHERE type = 'm.room.message'
                    AND sender LIKE ?
                AND stream_ordering > ?
            """

            txn.execute(sql, (like_clause, self.stream_ordering_day_ago))
            # COUNT(*) with no GROUP BY always returns exactly one row.
            row = cast(Tuple[int], txn.fetchone())
            return row[0]
        def _get_thread_summary_txn(txn: LoggingTransaction) -> bool:
            """True if the user has sent an event into this event's thread."""
            sql = """
                SELECT 1
                FROM event_relations
                INNER JOIN events USING (event_id)
                WHERE
                    relates_to_id = ?
                    AND room_id = ?
                    AND relation_type = ?
                    AND sender = ?
            """

            txn.execute(sql, (event_id, room_id, RelationTypes.THREAD, user_id))
            # Any row at all means the user has participated.
            return txn.fetchone() is not None
Example #26
0
        def _get_session(txn: LoggingTransaction, session_type: str,
                         session_id: str, ts: int) -> JsonDict:
            """Fetch an unexpired session's value, raising 404 if absent.

            Raises:
                StoreError: (404) if no matching, unexpired session exists.
            """
            # This includes the expiry time since items are only periodically
            # deleted, not upon expiry.
            select_sql = """
            SELECT value FROM sessions WHERE
            session_type = ? AND session_id = ? AND expiry_time_ms > ?
            """
            txn.execute(select_sql, [session_type, session_id, ts])
            row = txn.fetchone()
            if row is None:
                raise StoreError(404, "No session")
            return db_to_json(row[0])
Example #27
0
 def get_type_stream_id_for_appservice_txn(
         txn: LoggingTransaction) -> int:
     """Return the appservice's last stream ID for the given stream type."""
     stream_id_type = "%s_stream_id" % type
     txn.execute(
         # We do NOT want to escape `stream_id_type`.
         "SELECT %s FROM application_services_state WHERE as_id=?" %
         stream_id_type,
         (service.id,),
     )
     row = txn.fetchone()
     if row is None or row[0] is None:
         # no row exists. Stream tokens always start from 1, to avoid foot
         # guns around `0` being falsey.
         return 1
     return int(row[0])
Example #28
0
 def get_last_seen(txn: LoggingTransaction) -> Optional[int]:
     """Return the last_seen value `batch_size` rows past the start point,
     or None when there are fewer rows than that."""
     txn.execute(
         """
         SELECT last_seen FROM user_ips
         WHERE last_seen > ?
         ORDER BY last_seen
         LIMIT 1
         OFFSET ?
         """,
         (begin_last_seen, batch_size),
     )
     row = cast(Optional[Tuple[int]], txn.fetchone())
     return row[0] if row else None
Example #29
0
 def _count_users(self, txn: LoggingTransaction, time_from: int) -> int:
     """
     Returns number of users seen in the past time_from period
     """
     sql = """
         SELECT COUNT(*) FROM (
             SELECT user_id FROM user_ips
             WHERE last_seen > ?
             GROUP BY user_id
         ) u
     """
     txn.execute(sql, (time_from, ))
     # Mypy knows that fetchone() might return None if there are no rows.
     # We know better: "SELECT COUNT(...) FROM ..." without any GROUP BY always
     # returns exactly one row.
     (count, ) = cast(Tuple[int], txn.fetchone())
     return count
        def get_local_media_by_user_paginate_txn(
            txn: LoggingTransaction,
        ) -> Tuple[List[Dict[str, Any]], int]:
            """Return one page of the user's local media plus the total count."""
            # Set ordering
            order_by_column = MediaSortOrder(order_by).value
            order = "DESC" if direction == "b" else "ASC"

            args: List[Union[str, int]] = [user_id]
            sql = """
                SELECT COUNT(*) as total_media
                FROM local_media_repository
                WHERE user_id = ?
            """
            txn.execute(sql, args)
            count = cast(Tuple[int], txn.fetchone())[0]

            sql = """
                SELECT
                    "media_id",
                    "media_type",
                    "media_length",
                    "upload_name",
                    "created_ts",
                    "last_access_ts",
                    "quarantined_by",
                    "safe_from_quarantine"
                FROM local_media_repository
                WHERE user_id = ?
                ORDER BY {order_by_column} {order}, media_id ASC
                LIMIT ? OFFSET ?
            """.format(
                order_by_column=order_by_column,
                order=order,
            )
            txn.execute(sql, args + [limit, start])
            media = self.db_pool.cursor_to_dict(txn)
            return media, count