Example #1
            def r(
                txn: LoggingTransaction,
            ) -> Tuple[Optional[List[str]], List[Tuple], List[Tuple]]:
                forward_rows = []
                backward_rows = []
                if do_forward[0]:
                    txn.execute(forward_select,
                                (forward_chunk, self.batch_size))
                    forward_rows = txn.fetchall()
                    if not forward_rows:
                        do_forward[0] = False

                if do_backward[0]:
                    txn.execute(backward_select,
                                (backward_chunk, self.batch_size))
                    backward_rows = txn.fetchall()
                    if not backward_rows:
                        do_backward[0] = False

                if forward_rows or backward_rows:
                    headers = [column[0] for column in txn.description]
                else:
                    headers = None

                return headers, forward_rows, backward_rows
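The headers trick above leans on the DB-API cursor attribute `description`, whose first element per column is the column name. A minimal, self-contained sketch of the same pattern, using plain sqlite3 and a made-up table, for reference:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE t (room_id TEXT, stream_ordering INTEGER)")
    conn.execute("INSERT INTO t VALUES ('!room:example.org', 1)")

    cur = conn.execute("SELECT room_id, stream_ordering FROM t")
    rows = cur.fetchall()
    # cursor.description holds one 7-tuple per column; index 0 is the name.
    headers = [column[0] for column in cur.description] if rows else None
    print(headers, rows)  # ['room_id', 'stream_ordering'] [('!room:example.org', 1)]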
Example #2
        def _make_staging_area(txn: LoggingTransaction) -> None:
            sql = ("CREATE TABLE IF NOT EXISTS " + TEMP_TABLE +
                   "_rooms(room_id TEXT NOT NULL, events BIGINT NOT NULL)")
            txn.execute(sql)

            sql = ("CREATE TABLE IF NOT EXISTS " + TEMP_TABLE +
                   "_position(position TEXT NOT NULL)")
            txn.execute(sql)

            # Get rooms we want to process from the database
            sql = """
                SELECT room_id, count(*) FROM current_state_events
                GROUP BY room_id
            """
            txn.execute(sql)
            rooms = list(txn.fetchall())
            self.db_pool.simple_insert_many_txn(txn,
                                                TEMP_TABLE + "_rooms",
                                                keys=("room_id", "events"),
                                                values=rooms)
            del rooms

            sql = ("CREATE TABLE IF NOT EXISTS " + TEMP_TABLE +
                   "_users(user_id TEXT NOT NULL)")
            txn.execute(sql)

            txn.execute("SELECT name FROM users")
            users = list(txn.fetchall())

            self.db_pool.simple_insert_many_txn(txn,
                                                TEMP_TABLE + "_users",
                                                keys=("user_id", ),
                                                values=users)
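The staging-area idea (snapshot the work list into TEMP_TABLE + "_rooms", then drain it in batches) can be exercised standalone. The sketch below uses plain sqlite3 and illustrative data rather than Synapse's db_pool helpers; the TEMP_TABLE value is made up:

    import sqlite3

    TEMP_TABLE = "temp_user_directory"

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE current_state_events (room_id TEXT)")
    conn.executemany("INSERT INTO current_state_events VALUES (?)",
                     [("!a:x", ), ("!a:x", ), ("!b:x", )])

    # Snapshot each room and its state-event count into the staging table.
    conn.execute("CREATE TABLE IF NOT EXISTS " + TEMP_TABLE +
                 "_rooms(room_id TEXT NOT NULL, events BIGINT NOT NULL)")
    rooms = conn.execute("SELECT room_id, count(*) FROM current_state_events"
                         " GROUP BY room_id").fetchall()
    conn.executemany("INSERT INTO " + TEMP_TABLE + "_rooms VALUES (?, ?)", rooms)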
Example #3
    def _graph_to_linear(self, txn: LoggingTransaction, room_id: str,
                         event_ids: List[str]) -> str:
        """
        Generate a linearized event from a list of events (i.e. a list of forward
        extremities in the room).

        This should allow for calculation of the correct read receipt even if
        servers have different event ordering.

        Args:
            txn: The transaction
            room_id: The room ID the events are in.
            event_ids: The list of event IDs to linearize.

        Returns:
            The linearized event ID.
        """
        # TODO: Make this better.
        clause, args = make_in_list_sql_clause(self.database_engine,
                                               "event_id", event_ids)

        sql = """
            SELECT event_id WHERE room_id = ? AND stream_ordering IN (
                SELECT max(stream_ordering) WHERE %s
            )
        """ % (clause, )

        txn.execute(sql, [room_id] + list(args))
        rows = txn.fetchall()
        if rows:
            return rows[0][0]
        else:
            raise RuntimeError("Unrecognized event_ids: %r" % (event_ids, ))
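The %s above is filled in by make_in_list_sql_clause, which expands to an "event_id IN (?, ?, ...)"-shaped clause plus its parameters on most engines. A rough, hypothetical stand-in (not Synapse's actual implementation) showing the shape of that expansion:

    from typing import Iterable, List, Tuple

    def in_list_clause(column: str, values: Iterable[str]) -> Tuple[str, List[str]]:
        # Build "column IN (?, ?, ?)" plus the matching argument list.
        args = list(values)
        placeholders = ", ".join("?" for _ in args)
        return "%s IN (%s)" % (column, placeholders), args

    clause, args = in_list_clause("event_id", ["$a", "$b"])
    print(clause, args)  # event_id IN (?, ?) ['$a', '$b']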
Example #4
        def f(
            txn: LoggingTransaction,
        ) -> List[Tuple[str, str, int, int, str, bool, str, int]]:
            before_clause = ""
            if before:
                before_clause = "AND epa.stream_ordering < ?"
                args = [user_id, before, limit]
            else:
                args = [user_id, limit]

            if only_highlight:
                if len(before_clause) > 0:
                    before_clause += " "
                before_clause += "AND epa.highlight = 1"

            # NB. This assumes event_ids are globally unique since
            # it makes the query easier to index
            sql = (
                "SELECT epa.event_id, epa.room_id,"
                " epa.stream_ordering, epa.topological_ordering,"
                " epa.actions, epa.highlight, epa.profile_tag, e.received_ts"
                " FROM event_push_actions epa, events e"
                " WHERE epa.event_id = e.event_id"
                " AND epa.user_id = ? %s"
                " AND epa.notif = 1"
                " ORDER BY epa.stream_ordering DESC"
                " LIMIT ?" % (before_clause, ))
            txn.execute(sql, args)
            return cast(List[Tuple[str, str, int, int, str, bool, str, int]],
                        txn.fetchall())
Example #5
    def get_after_receipt(
        txn: LoggingTransaction,
    ) -> List[Tuple[str, str, int, str, bool]]:
        # Find rooms that have a read receipt in them and return the next
        # push actions.
        sql = (
            "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
            "   ep.highlight "
            " FROM ("
            "   SELECT room_id,"
            "       MAX(stream_ordering) as stream_ordering"
            "   FROM events"
            "   INNER JOIN receipts_linearized USING (room_id, event_id)"
            "   WHERE receipt_type = 'm.read' AND user_id = ?"
            "   GROUP BY room_id"
            ") AS rl,"
            " event_push_actions AS ep"
            " WHERE"
            "   ep.room_id = rl.room_id"
            "   AND ep.stream_ordering > rl.stream_ordering"
            "   AND ep.user_id = ?"
            "   AND ep.stream_ordering > ?"
            "   AND ep.stream_ordering <= ?"
            "   AND ep.notif = 1"
            " ORDER BY ep.stream_ordering ASC LIMIT ?")
        args = [
            user_id, user_id, min_stream_ordering, max_stream_ordering,
            limit
        ]
        txn.execute(sql, args)
        return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall())
Example #6
    def get_no_receipt(
        txn: LoggingTransaction,
    ) -> List[Tuple[str, str, int, str, bool, int]]:
        sql = (
            "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
            "   ep.highlight, e.received_ts"
            " FROM event_push_actions AS ep"
            " INNER JOIN events AS e USING (room_id, event_id)"
            " WHERE"
            "   ep.room_id NOT IN ("
            "     SELECT room_id FROM receipts_linearized"
            "       WHERE receipt_type = 'm.read' AND user_id = ?"
            "       GROUP BY room_id"
            "   )"
            "   AND ep.user_id = ?"
            "   AND ep.stream_ordering > ?"
            "   AND ep.stream_ordering <= ?"
            "   AND ep.notif = 1"
            " ORDER BY ep.stream_ordering DESC LIMIT ?")
        args = [
            user_id, user_id, min_stream_ordering, max_stream_ordering,
            limit
        ]
        txn.execute(sql, args)
        return cast(List[Tuple[str, str, int, str, bool, int]],
                    txn.fetchall())
Example #7
    def _delete_old_ui_auth_sessions_txn(self, txn: LoggingTransaction,
                                         expiration_time: int):
        # Get the expired sessions.
        sql = "SELECT session_id FROM ui_auth_sessions WHERE creation_time <= ?"
        txn.execute(sql, [expiration_time])
        session_ids = [r[0] for r in txn.fetchall()]

        # Delete the corresponding IP/user agents.
        self.db_pool.simple_delete_many_txn(
            txn,
            table="ui_auth_sessions_ips",
            column="session_id",
            iterable=session_ids,
            keyvalues={},
        )

        # Delete the corresponding completed credentials.
        self.db_pool.simple_delete_many_txn(
            txn,
            table="ui_auth_sessions_credentials",
            column="session_id",
            iterable=session_ids,
            keyvalues={},
        )

        # Finally, delete the sessions.
        self.db_pool.simple_delete_many_txn(
            txn,
            table="ui_auth_sessions",
            column="session_id",
            iterable=session_ids,
            keyvalues={},
        )
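simple_delete_many_txn boils down to a single IN-clause DELETE per table. A hedged DB-API sketch of that shape (the helper name here is ours, not Synapse's):

    from typing import Iterable

    def delete_many(txn, table: str, column: str, ids: Iterable[str]) -> int:
        # One "DELETE ... WHERE col IN (?, ?, ...)" statement per call.
        ids = list(ids)
        if not ids:
            return 0
        placeholders = ", ".join("?" for _ in ids)
        txn.execute("DELETE FROM %s WHERE %s IN (%s)" % (table, column, placeholders),
                    ids)
        return txn.rowcount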
Example #8
    def _mark_as_sent_devices_by_remote_txn(self, txn: LoggingTransaction,
                                            destination: str,
                                            stream_id: int) -> None:
        # We update the device_lists_outbound_last_success with the successfully
        # poked users.
        sql = """
            SELECT user_id, coalesce(max(o.stream_id), 0)
            FROM device_lists_outbound_pokes as o
            WHERE destination = ? AND o.stream_id <= ?
            GROUP BY user_id
        """
        txn.execute(sql, (destination, stream_id))
        rows = txn.fetchall()

        self.db_pool.simple_upsert_many_txn(
            txn=txn,
            table="device_lists_outbound_last_success",
            key_names=("destination", "user_id"),
            key_values=((destination, user_id) for user_id, _ in rows),
            value_names=("stream_id", ),
            value_values=((stream_id, ) for _, stream_id in rows),
        )

        # Delete all sent outbound pokes
        sql = """
            DELETE FROM device_lists_outbound_pokes
            WHERE destination = ? AND stream_id <= ?
        """
        txn.execute(sql, (destination, stream_id))
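simple_upsert_many_txn inserts or updates each (destination, user_id) key. On engines with native upserts the same effect can be sketched with INSERT ... ON CONFLICT ... DO UPDATE (SQLite 3.24+/Postgres syntax; the schema below is a guess at the relevant columns):

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE device_lists_outbound_last_success ("
                 " destination TEXT, user_id TEXT, stream_id BIGINT,"
                 " UNIQUE (destination, user_id))")
    rows = [("@alice:example.org", 3), ("@bob:example.org", 5)]
    conn.executemany(
        "INSERT INTO device_lists_outbound_last_success"
        " (destination, user_id, stream_id) VALUES (?, ?, ?)"
        " ON CONFLICT (destination, user_id)"
        " DO UPDATE SET stream_id = excluded.stream_id",
        [("remote.example.com", user_id, stream_id) for user_id, stream_id in rows],
    )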
Example #9
    def _get_e2e_cross_signing_signatures_for_devices_txn(
        self, txn: LoggingTransaction, device_query: Iterable[Tuple[str, str]]
    ) -> List[Tuple[str, str, str, str]]:
        """Get cross-signing signatures for a given list of devices

        Returns signatures made by the owners of the devices.

        Returns: a list of results; each entry in the list is a tuple of
            (user_id, key_id, target_device_id, signature).
        """
        signature_query_clauses = []
        signature_query_params = []

        for (user_id, device_id) in device_query:
            signature_query_clauses.append(
                "target_user_id = ? AND target_device_id = ? AND user_id = ?")
            signature_query_params.extend([user_id, device_id, user_id])

        signature_sql = """
            SELECT user_id, key_id, target_device_id, signature
            FROM e2e_cross_signing_signatures WHERE %s
            """ % (" OR ".join("(" + q + ")" for q in signature_query_clauses))

        txn.execute(signature_sql, signature_query_params)
        return cast(
            List[Tuple[str, str, str, str]],
            txn.fetchall(),
        )
Example #10
        def reindex_search_txn(txn: LoggingTransaction) -> Tuple[int, bool]:
            sql = (
                "UPDATE event_search AS es SET stream_ordering = e.stream_ordering,"
                " origin_server_ts = e.origin_server_ts"
                " FROM events AS e"
                " WHERE e.event_id = es.event_id"
                " AND ? <= e.stream_ordering AND e.stream_ordering < ?"
                " RETURNING es.stream_ordering")

            min_stream_id = max_stream_id - batch_size
            txn.execute(sql, (min_stream_id, max_stream_id))
            rows = txn.fetchall()

            if min_stream_id < target_min_stream_id:
                # We've reached the end.
                return len(rows), False

            progress = {
                "target_min_stream_id_inclusive": target_min_stream_id,
                "max_stream_id_exclusive": min_stream_id,
                "rows_inserted": rows_inserted + len(rows),
                "have_added_indexes": True,
            }

            self.db_pool.updates._background_update_progress_txn(
                txn, self.EVENT_SEARCH_ORDER_UPDATE_NAME, progress)

            return len(rows), True
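Callbacks in this shape return (rows_handled, more_to_do) and are meant to be invoked repeatedly until the flag goes false. A minimal driver loop under that assumption (the loop itself is hypothetical, not Synapse's background-update scheduler):

    from typing import Callable, Tuple

    def drain(run_batch: Callable[[], Tuple[int, bool]]) -> int:
        # Keep running batches until the callback reports it has finished.
        total, more = 0, True
        while more:
            handled, more = run_batch()
            total += handled
        return total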
Example #11
        def _get_next_batch(
            txn: LoggingTransaction,
        ) -> Optional[Sequence[Tuple[str, int]]]:
            # Only fetch 250 rooms, so we don't fetch too many at once, even
            # if those 250 rooms have fewer than batch_size state events.
            sql = """
                SELECT room_id, events FROM %s
                ORDER BY events DESC
                LIMIT 250
            """ % (
                TEMP_TABLE + "_rooms",
            )
            txn.execute(sql)
            rooms_to_work_on = cast(List[Tuple[str, int]], txn.fetchall())

            if not rooms_to_work_on:
                return None

            # Get how many are left to process, so we can give status on how
            # far we are in processing
            txn.execute("SELECT COUNT(*) FROM " + TEMP_TABLE + "_rooms")
            result = txn.fetchone()
            assert result is not None
            progress["remaining"] = result[0]

            return rooms_to_work_on
Example #12
    def get_updated_global_account_data_txn(
        txn: LoggingTransaction,
    ) -> List[Tuple[int, str, str]]:
        sql = ("SELECT stream_id, user_id, account_data_type"
               " FROM account_data WHERE ? < stream_id AND stream_id <= ?"
               " ORDER BY stream_id ASC LIMIT ?")
        txn.execute(sql, (last_id, current_id, limit))
        return cast(List[Tuple[int, str, str]], txn.fetchall())
Example #13
        def r(txn: LoggingTransaction) -> Tuple[List[str], List[Tuple]]:
            txn.execute(select)
            rows = txn.fetchall()
            headers: List[str] = [column[0] for column in txn.description]

            ts_ind = headers.index("ts")

            return headers, [r for r in rows if r[ts_ind] < yesterday]
Example #14
        def _get_applicable_edits_txn(
                txn: LoggingTransaction) -> Dict[str, str]:
            clause, args = make_in_list_sql_clause(txn.database_engine,
                                                   "relates_to_id", event_ids)
            args.append(RelationTypes.REPLACE)

            txn.execute(sql % (clause, ), args)
            return dict(cast(Iterable[Tuple[str, str]], txn.fetchall()))
Example #15
    def get_all_updated_tags_txn(
        txn: LoggingTransaction,
    ) -> List[Tuple[int, str, str]]:
        sql = ("SELECT stream_id, user_id, room_id"
               " FROM room_tags_revisions as r"
               " WHERE ? < stream_id AND stream_id <= ?"
               " ORDER BY stream_id ASC LIMIT ?")
        txn.execute(sql, (last_id, current_id, limit))
        # mypy doesn't understand what the query is selecting.
        return cast(List[Tuple[int, str, str]], txn.fetchall())
Example #16
    def _get_max_topological_txn(self, txn: LoggingTransaction,
                                 room_id: str) -> int:
        txn.execute(
            "SELECT MAX(topological_ordering) FROM events WHERE room_id = ?",
            (room_id, ),
        )

        rows = txn.fetchall()
        return rows[0][0] if rows else 0
Example #17
        def _devices_last_seen_update_txn(txn: LoggingTransaction) -> int:
            # This consists of two queries:
            #
            #   1. The sub-query searches for the next N devices and joins
            #      against user_ips to find the max last_seen associated with
            #      that device.
            #   2. The outer query then joins again against user_ips on
            #      user/device/last_seen. This *should* hopefully only
            #      return one row, but if it does return more than one then
            #      we'll just end up updating the same device row multiple
            #      times, which is fine.

            where_args: List[Union[str, int]]
            where_clause, where_args = make_tuple_comparison_clause(
                [("user_id", last_user_id), ("device_id", last_device_id)]
            )

            sql = """
                SELECT
                    last_seen, ip, user_agent, user_id, device_id
                FROM (
                    SELECT
                        user_id, device_id, MAX(u.last_seen) AS last_seen
                    FROM devices
                    INNER JOIN user_ips AS u USING (user_id, device_id)
                    WHERE %(where_clause)s
                    GROUP BY user_id, device_id
                    ORDER BY user_id ASC, device_id ASC
                    LIMIT ?
                ) c
                INNER JOIN user_ips AS u USING (user_id, device_id, last_seen)
            """ % {
                "where_clause": where_clause
            }
            txn.execute(sql, where_args + [batch_size])

            rows = cast(List[Tuple[int, str, str, str, str]], txn.fetchall())
            if not rows:
                return 0

            sql = """
                UPDATE devices
                SET last_seen = ?, ip = ?, user_agent = ?
                WHERE user_id = ? AND device_id = ?
            """
            txn.execute_batch(sql, rows)

            _, _, _, user_id, device_id = rows[-1]
            self.db_pool.updates._background_update_progress_txn(
                txn,
                "devices_last_seen",
                {
                    "last_user_id": user_id,
                    "last_device_id": device_id
                },
            )

            return len(rows)
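make_tuple_comparison_clause builds the keyset-pagination condition over the compound (user_id, device_id) ordering. A hypothetical stand-in showing the row-comparison form it produces on engines that support it:

    from typing import Any, List, Tuple

    def tuple_comparison(keys: List[Tuple[str, Any]]) -> Tuple[str, List[Any]]:
        # Build "(a, b) > (?, ?)" plus the boundary values, so the next
        # batch resumes strictly after the last (a, b) already processed.
        names = ", ".join(name for name, _ in keys)
        marks = ", ".join("?" for _ in keys)
        return "(%s) > (%s)" % (names, marks), [value for _, value in keys]

    clause, args = tuple_comparison([("user_id", "@a:x"), ("device_id", "DEV")])
    print(clause, args)  # (user_id, device_id) > (?, ?) ['@a:x', 'DEV']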
Example #18
        def _remove_hidden_devices_from_device_inbox_txn(
            txn: LoggingTransaction,
        ) -> int:
            """Since stream_id is not unique, we need to use an inclusive
            `stream_id >= ?` clause: we might not have deleted all hidden
            device messages for the stream_id returned by the previous query.

            We then delete only rows matching the `(user_id, device_id, stream_id)`
            tuple, to avoid deleting a large number of rows all at once
            when a single device has lots of device messages.
            """

            last_stream_id = progress.get("stream_id", 0)

            sql = """
                SELECT device_id, user_id, stream_id
                FROM device_inbox
                WHERE
                    stream_id >= ?
                    AND (device_id, user_id) IN (
                        SELECT device_id, user_id FROM devices WHERE hidden = ?
                    )
                ORDER BY stream_id
                LIMIT ?
            """

            txn.execute(sql, (last_stream_id, True, batch_size))
            rows = txn.fetchall()

            num_deleted = 0
            for row in rows:
                num_deleted += self.db_pool.simple_delete_txn(
                    txn,
                    "device_inbox",
                    {
                        "device_id": row[0],
                        "user_id": row[1],
                        "stream_id": row[2]
                    },
                )

            if rows:
                # We don't save just the `stream_id` as progress: in large
                # deployments the stream_id may not change over several runs,
                # and then no change of status would be visible in the log
                # file.
                self.db_pool.updates._background_update_progress_txn(
                    txn,
                    self.REMOVE_HIDDEN_DEVICES,
                    {
                        "device_id": rows[-1][0],
                        "user_id": rows[-1][1],
                        "stream_id": rows[-1][2],
                    },
                )

            return num_deleted
Example #19
        def reindex_search_txn(txn: LoggingTransaction) -> int:
            sql = ("SELECT stream_ordering, event_id FROM events"
                   " WHERE ? <= stream_ordering AND stream_ordering < ?"
                   " ORDER BY stream_ordering DESC"
                   " LIMIT ?")

            txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))

            rows = txn.fetchall()
            if not rows:
                return 0

            min_stream_id = rows[-1][0]
            event_ids = [row[1] for row in rows]

            rows_to_update = []

            chunks = [
                event_ids[i:i + 100] for i in range(0, len(event_ids), 100)
            ]
            for chunk in chunks:
                ev_rows = self.db_pool.simple_select_many_txn(
                    txn,
                    table="event_json",
                    column="event_id",
                    iterable=chunk,
                    retcols=["event_id", "json"],
                    keyvalues={},
                )

                for row in ev_rows:
                    event_id = row["event_id"]
                    event_json = db_to_json(row["json"])
                    try:
                        origin_server_ts = event_json["origin_server_ts"]
                    except (KeyError, AttributeError):
                        # If the event is missing a necessary field then
                        # skip over it.
                        continue

                    rows_to_update.append((origin_server_ts, event_id))

            sql = "UPDATE events SET origin_server_ts = ? WHERE event_id = ?"

            txn.execute_batch(sql, rows_to_update)

            progress = {
                "target_min_stream_id_inclusive": target_min_stream_id,
                "max_stream_id_exclusive": min_stream_id,
                "rows_inserted": rows_inserted + len(rows_to_update),
            }

            self.db_pool.updates._background_update_progress_txn(
                txn, _BackgroundUpdates.EVENT_ORIGIN_SERVER_TS_NAME, progress)

            return len(rows_to_update)
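The slicing idiom above splits event_ids into fixed-size chunks so each simple_select_many_txn call stays bounded; the same one-liner works on any list:

    event_ids = ["$ev%d" % i for i in range(7)]
    chunks = [event_ids[i:i + 3] for i in range(0, len(event_ids), 3)]
    print(chunks)  # [['$ev0', '$ev1', '$ev2'], ['$ev3', '$ev4', '$ev5'], ['$ev6']]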
Example #20
        def _list_users(txn: LoggingTransaction) -> List[Tuple[str, str]]:
            sql = f"""
                    SELECT COALESCE(appservice_id, 'native'), user_id
                    FROM monthly_active_users
                    LEFT JOIN users ON monthly_active_users.user_id=users.name
                    {where_clause};
                """

            txn.execute(sql, query_params)
            return cast(List[Tuple[str, str]], txn.fetchall())
Example #21
    def f(txn: LoggingTransaction) -> List[Tuple[str, str, int, int]]:
        sql = ("SELECT rl.room_id, rl.event_id,"
               " e.topological_ordering, e.stream_ordering"
               " FROM receipts_linearized AS rl"
               " INNER JOIN events AS e USING (room_id, event_id)"
               " WHERE rl.room_id = e.room_id"
               " AND rl.event_id = e.event_id"
               " AND user_id = ?")
        txn.execute(sql, (user_id, ))
        # mypy can't tell what the query selects; cast like the other examples.
        return cast(List[Tuple[str, str, int, int]], txn.fetchall())
Example #22
    def _get_event_relations(
        txn: LoggingTransaction,
    ) -> Dict[str, Set[Tuple[str, str]]]:
        txn.execute(sql, [event_id] + rel_type_args)
        result: Dict[str, Set[Tuple[str, str]]] = {
            rel_type: set()
            for rel_type in relation_types
        }
        for rel_type, sender, type in txn.fetchall():
            result[rel_type].add((sender, type))
        return result
Example #23
        def purged_chain_cover_txn(txn: LoggingTransaction) -> int:
            # The event ID from events will be null if the chain ID / sequence
            # number points to a purged event.
            sql = """
                SELECT event_id, chain_id, sequence_number, e.event_id IS NOT NULL
                FROM event_auth_chains
                LEFT JOIN events AS e USING (event_id)
                WHERE event_id > ? ORDER BY event_auth_chains.event_id ASC LIMIT ?
            """
            txn.execute(sql, (current_event_id, batch_size))

            rows = txn.fetchall()
            if not rows:
                return 0

            # The event IDs and chain IDs / sequence numbers where the event has
            # been purged.
            unreferenced_event_ids = []
            unreferenced_chain_id_tuples = []
            event_id = ""
            for event_id, chain_id, sequence_number, has_event in rows:
                if not has_event:
                    unreferenced_event_ids.append((event_id, ))
                    unreferenced_chain_id_tuples.append(
                        (chain_id, sequence_number))

            # Delete the unreferenced auth chains from event_auth_chain_links and
            # event_auth_chains.
            txn.executemany(
                """
                DELETE FROM event_auth_chains WHERE event_id = ?
                """,
                unreferenced_event_ids,
            )
            # We should also delete matching target_*, but there is no index on
            # target_chain_id. Hopefully any purged events are due to a room
            # being fully purged and they will be removed from the origin_*
            # searches.
            txn.executemany(
                """
                DELETE FROM event_auth_chain_links WHERE
                origin_chain_id = ? AND origin_sequence_number = ?
                """,
                unreferenced_chain_id_tuples,
            )

            progress = {
                "current_event_id": event_id,
            }

            self.db_pool.updates._background_update_progress_txn(
                txn, "purged_chain_cover", progress)

            return len(rows)
Example #24
    def get_recent(txn: LoggingTransaction) -> List[Tuple[str, str, str, int]]:
        txn.execute(
            """
            SELECT access_token, ip, user_agent, last_seen FROM user_ips
            WHERE last_seen >= ? AND user_id = ?
            ORDER BY last_seen DESC
            """,
            (since_ts, user_id),
        )
        return cast(List[Tuple[str, str, str, int]], txn.fetchall())
Example #25
        def _count_users_by_service(txn: LoggingTransaction) -> Dict[str, int]:
            sql = """
                SELECT COALESCE(appservice_id, 'native'), COUNT(*)
                FROM monthly_active_users
                LEFT JOIN users ON monthly_active_users.user_id=users.name
                GROUP BY appservice_id;
            """

            txn.execute(sql)
            result = cast(List[Tuple[str, int]], txn.fetchall())
            return dict(result)
Example #26
        def get_start_id(txn: LoggingTransaction) -> int:
            txn.execute(
                "SELECT rowid FROM sent_transactions WHERE ts >= ?"
                " ORDER BY rowid ASC LIMIT 1",
                (yesterday, ),
            )

            rows = txn.fetchall()
            if rows:
                return rows[0][0]
            else:
                return 1
Example #27
    def fetch(txn: LoggingTransaction) -> List[Tuple[int, int]]:
        txn.execute("""
            SELECT t1.c, t2.c
            FROM (
                SELECT room_id, COUNT(*) c FROM event_forward_extremities
                GROUP BY room_id
            ) t1 LEFT JOIN (
                SELECT room_id, COUNT(*) c FROM current_state_events
                GROUP BY room_id
            ) t2 ON t1.room_id = t2.room_id
            """)
        return cast(List[Tuple[int, int]], txn.fetchall())
Example #28
        def get_user_id_for_token_txn(txn: LoggingTransaction) -> Optional[str]:
            sql = (
                "SELECT user_id FROM open_id_tokens"
                " WHERE token = ? AND ? <= ts_valid_until_ms"
            )

            txn.execute(sql, (token, ts_now_ms))

            rows = txn.fetchall()
            if not rows:
                return None
            else:
                return rows[0][0]
Example #29
        def reindex_txn(txn: LoggingTransaction) -> int:
            sql = ("SELECT stream_ordering, event_id, json FROM events"
                   " INNER JOIN event_json USING (event_id)"
                   " WHERE ? <= stream_ordering AND stream_ordering < ?"
                   " ORDER BY stream_ordering DESC"
                   " LIMIT ?")

            txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))

            rows = txn.fetchall()
            if not rows:
                return 0

            min_stream_id = rows[-1][0]

            update_rows = []
            for row in rows:
                try:
                    event_id = row[1]
                    event_json = db_to_json(row[2])
                    sender = event_json["sender"]
                    content = event_json["content"]

                    contains_url = "url" in content
                    if contains_url:
                        contains_url &= isinstance(content["url"], str)
                except (KeyError, AttributeError):
                    # If the event is missing a necessary field then
                    # skip over it.
                    continue

                update_rows.append((sender, contains_url, event_id))

            sql = "UPDATE events SET sender = ?, contains_url = ? WHERE event_id = ?"

            txn.execute_batch(sql, update_rows)

            progress = {
                "target_min_stream_id_inclusive": target_min_stream_id,
                "max_stream_id_exclusive": min_stream_id,
                "rows_inserted": rows_inserted + len(rows),
            }

            self.db_pool.updates._background_update_progress_txn(
                txn, _BackgroundUpdates.EVENT_FIELDS_SENDER_URL_UPDATE_NAME,
                progress)

            return len(rows)
Example #30
        def f(txn: LoggingTransaction) -> Set[str]:
            highlight_words = set()
            for event in events:
                # As a hack we simply join values of all possible keys. This is
                # fine since we're only using them to find possible highlights.
                values = []
                for key in ("body", "name", "topic"):
                    v = event.content.get(key, None)
                    if v:
                        v = _clean_value_for_search(v)
                        values.append(v)

                if not values:
                    continue

                value = " ".join(values)

                # We need to find some values for StartSel and StopSel that
                # aren't in the value so that we can pick results out.
                start_sel = "<"
                stop_sel = ">"

                while start_sel in value:
                    start_sel += "<"
                while stop_sel in value:
                    stop_sel += ">"

                query = "SELECT ts_headline(?, to_tsquery('english', ?), %s)" % (
                    _to_postgres_options({
                        "StartSel": start_sel,
                        "StopSel": stop_sel,
                        "MaxFragments": "50",
                    }))
                txn.execute(query, (value, search_query))
                (headline, ) = txn.fetchall()[0]

                # Now we need to pick the possible highlights out of the
                # headline result.
                matcher_regex = "%s(.*?)%s" % (
                    re.escape(start_sel),
                    re.escape(stop_sel),
                )

                res = re.findall(matcher_regex, headline)
                highlight_words.update([r.lower() for r in res])

            return highlight_words
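The sentinel-growing and regex steps at the end can be exercised on their own; the headline string below stands in for a ts_headline result:

    import re

    start_sel, stop_sel = "<", ">"
    headline = "the <quick> brown <fox> jumps"
    matcher_regex = "%s(.*?)%s" % (re.escape(start_sel), re.escape(stop_sel))
    print([m.lower() for m in re.findall(matcher_regex, headline)])  # ['quick', 'fox']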