Example #1
    def get_next_txn(self, txn: LoggingTransaction) -> int:
        """
        Usage:

            stream_id = stream_id_gen.get_next(txn)
            # ... persist event ...
        """

        next_id = self._load_next_id_txn(txn)

        with self._lock:
            self._unfinished_ids.add(next_id)

        txn.call_after(self._mark_id_as_finished, next_id)
        txn.call_on_exception(self._mark_id_as_finished, next_id)

        # Update the `stream_positions` table with newly updated stream
        # ID (unless self._writers is not set in which case we don't
        # bother, as nothing will read it).
        #
        # We only do this on the success path so that the persisted current
        # position points to a persisted row with the correct instance name.
        if self._writers:
            txn.call_after(
                run_as_background_process,
                "MultiWriterIdGenerator._update_table",
                self._db.runInteraction,
                "MultiWriterIdGenerator._update_table",
                self._update_stream_positions_table_txn,
            )

        return self._return_factor * next_id
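
Note: the call_after/call_on_exception pair used throughout these examples queues callbacks on the transaction object; the database layer then runs the after-callbacks once the transaction commits, and the exception-callbacks if it ultimately fails. A minimal sketch of that bookkeeping (a hypothetical simplified class, not Synapse's actual LoggingTransaction):

from typing import Any, Callable, List, Tuple

class _TransactionSketch:
    """Hypothetical stand-in for LoggingTransaction's callback queues."""

    def __init__(self) -> None:
        self.after_callbacks: List[Tuple[Callable, tuple, dict]] = []
        self.exception_callbacks: List[Tuple[Callable, tuple, dict]] = []

    def call_after(self, callback: Callable, *args: Any, **kwargs: Any) -> None:
        # Queued here; run by the database layer after the transaction commits.
        self.after_callbacks.append((callback, args, kwargs))

    def call_on_exception(
        self, callback: Callable, *args: Any, **kwargs: Any
    ) -> None:
        # Queued here; run if the transaction ultimately fails.
        self.exception_callbacks.append((callback, args, kwargs))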
Example #2
        def delete_expired_event_txn(txn: LoggingTransaction) -> None:
            # Delete the expiry timestamp associated with this event from the database.
            self._delete_event_expiry_txn(txn, event_id)

            if not event:
                # If we can't find the event, log a warning. The expiry date
                # was already deleted above, so we won't try to expire it
                # again in the future.
                logger.warning(
                    "Can't expire event %s because we don't have it.", event_id
                )
                return

            # Prune the event's dict then convert it to JSON.
            pruned_json = json_encoder.encode(
                prune_event_dict(event.room_version, event.get_dict())
            )

            # Update the event_json table to replace the event's JSON with the pruned
            # JSON.
            self._censor_event_txn(txn, event.event_id, pruned_json)

            # We need to invalidate the event cache entry for this event because we
            # changed its content in the database. We can't call
            # self._invalidate_cache_and_stream because self.get_event_cache isn't of the
            # right type.
            txn.call_after(self._get_event_cache.invalidate, (event.event_id,))
            # Send that invalidation to replication so that other workers also invalidate
            # the event cache.
            self._send_invalidation_to_replication(
                txn, "_get_event_cache", (event.event_id,)
            )
Example #3
    def set_expiration_date_for_user_txn(
        self,
        txn: LoggingTransaction,
        user_id: str,
    ) -> None:
        """Sets an expiration date to the account with the given user ID.

        Args:
            user_id: User ID to set an expiration date for.
        """
        now_ms = self._api.current_time_ms()
        expiration_ts = now_ms + self._period

        sql = """
        INSERT INTO email_account_validity (user_id, expiration_ts_ms, email_sent)
        VALUES (?, ?, ?)
        ON CONFLICT (user_id) DO
            UPDATE SET
                expiration_ts_ms = EXCLUDED.expiration_ts_ms,
                email_sent = EXCLUDED.email_sent
        """

        txn.execute(sql, (user_id, expiration_ts, False))

        txn.call_after(self.get_expiration_ts_for_user.invalidate, (user_id, ))
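
A typical call site for a _txn helper like the one above (a hedged sketch; it assumes self._api is a Synapse ModuleApi, as the use of current_time_ms() above suggests) hands it to run_db_interaction, so the call_after fires only once the transaction commits:

    async def set_expiration_date_for_user(self, user_id: str) -> None:
        # Hypothetical wrapper: run the _txn helper inside a single database
        # transaction via the module API.
        await self._api.run_db_interaction(
            "set_expiration_date_for_user",
            self.set_expiration_date_for_user_txn,
            user_id,
        )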
Example #4
    def _update_state_for_partial_state_event_txn(
        self,
        txn: LoggingTransaction,
        event: EventBase,
        context: EventContext,
    ) -> None:
        # we shouldn't have any outliers here
        assert not event.internal_metadata.is_outlier()

        # anything that was rejected should have the same state as its
        # predecessor.
        if context.rejected:
            assert context.state_group == context.state_group_before_event

        self.db_pool.simple_update_txn(
            txn,
            table="event_to_state_groups",
            keyvalues={"event_id": event.event_id},
            updatevalues={"state_group": context.state_group},
        )

        self.db_pool.simple_delete_one_txn(
            txn,
            table="partial_state_events",
            keyvalues={"event_id": event.event_id},
        )

        # TODO(faster_joins): need to do something about workers here
        txn.call_after(self.is_partial_state_event.invalidate, (event.event_id,))
        txn.call_after(
            self._get_state_group_for_event.prefill,
            (event.event_id, ),
            context.state_group,
        )
Example #5
    def get_next_txn(self, txn: LoggingTransaction) -> int:
        """
        Usage:

            stream_id = stream_id_gen.get_next(txn)
            # ... persist event ...
        """

        # If we have a list of instances that are allowed to write to this
        # stream, make sure we're in it.
        if self._writers and self._instance_name not in self._writers:
            raise Exception("Tried to allocate stream ID on non-writer")

        next_id = self._load_next_id_txn(txn)

        txn.call_after(self._mark_id_as_finished, next_id)
        txn.call_on_exception(self._mark_id_as_finished, next_id)

        # Update the `stream_positions` table with newly updated stream
        # ID (unless self._writers is not set in which case we don't
        # bother, as nothing will read it).
        #
        # We only do this on the success path so that the persisted current
        # position points to a persisted row with the correct instance name.
        if self._writers:
            txn.call_after(
                run_as_background_process,
                "MultiWriterIdGenerator._update_table",
                self._db.runInteraction,
                "MultiWriterIdGenerator._update_table",
                self._update_stream_positions_table_txn,
            )

        return self._return_factor * next_id
Example #6
    def _invalidate_state_caches_and_stream(
        self, txn: LoggingTransaction, room_id: str, members_changed: Collection[str]
    ) -> None:
        """Special case invalidation of caches based on current state.

        We special case this so that we can batch the cache invalidations into a
        single replication poke.

        Args:
            txn
            room_id: Room where state changed
            members_changed: The user_ids of members that have changed
        """
        txn.call_after(self._invalidate_state_caches, room_id, members_changed)

        if members_changed:
            # We need to be careful that the size of the `members_changed` list
            # isn't so large that it causes problems sending over replication, so we
            # send them in chunks.
            # Max line length is 16K, and max user ID length is 255, so 50 should
            # be safe.
            for chunk in batch_iter(members_changed, 50):
                keys = itertools.chain([room_id], chunk)
                self._send_invalidation_to_replication(
                    txn, CURRENT_STATE_CACHE_NAME, keys
                )
        else:
            # If no members changed, we still need to invalidate the other caches.
            self._send_invalidation_to_replication(
                txn, CURRENT_STATE_CACHE_NAME, [room_id]
            )
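
The chunking above relies on batch_iter; a minimal sketch of a compatible helper (this is the shape implied by the usage here, not necessarily Synapse's exact implementation):

from itertools import islice
from typing import Iterable, Iterator, Tuple, TypeVar

T = TypeVar("T")

def batch_iter(iterable: Iterable[T], size: int) -> Iterator[Tuple[T, ...]]:
    # Yield successive tuples of at most `size` items from `iterable`.
    it = iter(iterable)
    return iter(lambda: tuple(islice(it, size)), ())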
Example #7
    def _add_device_change_to_stream_txn(
        self,
        txn: LoggingTransaction,
        user_id: str,
        device_ids: Collection[str],
        stream_ids: List[int],
    ) -> None:
        txn.call_after(
            self._device_list_stream_cache.entity_has_changed, user_id, stream_ids[-1],
        )

        min_stream_id = stream_ids[0]

        # Delete older entries in the table, as we really only care about
        # when the latest change happened.
        txn.executemany(
            """
            DELETE FROM device_lists_stream
            WHERE user_id = ? AND device_id = ? AND stream_id < ?
            """,
            [(user_id, device_id, min_stream_id) for device_id in device_ids],
        )

        self.db.simple_insert_many_txn(
            txn,
            table="device_lists_stream",
            values=[
                {"stream_id": stream_id, "user_id": user_id, "device_id": device_id}
                for stream_id, device_id in zip(stream_ids, device_ids)
            ],
        )
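
The entity_has_changed calls above (and in several later examples) feed a stream-change cache. A hypothetical minimal model of what such a call records, to make the pattern concrete (not Synapse's actual StreamChangeCache):

from typing import Dict

class StreamChangeCacheSketch:
    def __init__(self) -> None:
        # Latest stream position at which each entity changed.
        self._last_change: Dict[str, int] = {}

    def entity_has_changed(self, entity: str, stream_pos: int) -> None:
        prev = self._last_change.get(entity, 0)
        self._last_change[entity] = max(prev, stream_pos)

    def has_entity_changed(self, entity: str, stream_pos: int) -> bool:
        last = self._last_change.get(entity)
        if last is None:
            # Unknown entities must be assumed to have changed.
            return True
        return last > stream_pos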
Example #8
    def _invalidate_all_cache_and_stream(self, txn: LoggingTransaction,
                                         cache_func: _CachedFunction) -> None:
        """Invalidates the entire cache and adds it to the cache stream so slaves
        will know to invalidate their caches.
        """

        txn.call_after(cache_func.invalidate_all)
        self._send_invalidation_to_replication(txn, cache_func.__name__, None)
Example #9
    def _invalidate_cache_and_stream(
        self,
        txn: LoggingTransaction,
        cache_func: _CachedFunction,
        keys: Tuple[Any, ...],
    ) -> None:
        """Invalidates the cache and adds it to the cache stream so slaves
        will know to invalidate their caches.

        This should only be used to invalidate caches where slaves won't
        otherwise know from other replication streams that the cache should
        be invalidated.
        """
        txn.call_after(cache_func.invalidate, keys)
        self._send_invalidation_to_replication(txn, cache_func.__name__, keys)
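
A hypothetical call site for this helper, sketching the usual pattern of updating a row and invalidating the matching cache entry in the same transaction (table, cache, and column names here are illustrative, not taken from the examples):

    def _set_user_deactivated_status_txn(
        self, txn: LoggingTransaction, user_id: str, deactivated: bool
    ) -> None:
        self.db_pool.simple_update_one_txn(
            txn,
            table="users",
            keyvalues={"name": user_id},
            updatevalues={"deactivated": 1 if deactivated else 0},
        )
        # Invalidate locally and poke replication so other workers follow suit.
        self._invalidate_cache_and_stream(
            txn, self.get_user_deactivated_status, (user_id,)
        )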
Example #10
    def _update_presence_txn(
        self,
        txn: LoggingTransaction,
        stream_orderings: List[int],
        presence_states: List[UserPresenceState],
    ) -> None:
        for stream_id, state in zip(stream_orderings, presence_states):
            txn.call_after(
                self.presence_stream_cache.entity_has_changed, state.user_id, stream_id
            )
            txn.call_after(self._get_presence_for_user.invalidate, (state.user_id,))

        # Delete old rows to stop database from getting really big
        sql = "DELETE FROM presence_stream WHERE stream_id < ? AND "

        for states in batch_iter(presence_states, 50):
            clause, args = make_in_list_sql_clause(
                self.database_engine, "user_id", [s.user_id for s in states]
            )
            # `stream_id` here is the last (largest) ordering from the loop above.
            txn.execute(sql + clause, [stream_id] + list(args))

        # Actually insert new rows
        self.db_pool.simple_insert_many_txn(
            txn,
            table="presence_stream",
            keys=(
                "stream_id",
                "user_id",
                "state",
                "last_active_ts",
                "last_federation_update_ts",
                "last_user_sync_ts",
                "status_msg",
                "currently_active",
                "instance_name",
            ),
            values=[(
                stream_id,
                state.user_id,
                state.state,
                state.last_active_ts,
                state.last_federation_update_ts,
                state.last_user_sync_ts,
                state.status_msg,
                state.currently_active,
                self._instance_name,
            ) for stream_id, state in zip(stream_orderings, presence_states)],
        )
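
make_in_list_sql_clause builds the `user_id IN (...)` fragment used above; a sketch consistent with that usage (it assumes an engine flag such as supports_using_any_list, letting Postgres bind one array parameter while SQLite falls back to one placeholder per value):

from typing import Any, Iterable, List, Tuple

def make_in_list_sql_clause(
    database_engine: Any, column: str, iterable: Iterable[Any]
) -> Tuple[str, List[Any]]:
    values = list(iterable)
    if getattr(database_engine, "supports_using_any_list", False):
        # Postgres: bind the whole list as a single array-valued parameter.
        return "%s = ANY(?)" % (column,), [values]
    # SQLite: one placeholder per value.
    return "%s IN (%s)" % (column, ",".join("?" for _ in values)), values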
Example #11
        def _update_profile_in_user_dir_txn(txn: LoggingTransaction) -> None:
            self.db_pool.simple_upsert_txn(
                txn,
                table="user_directory",
                keyvalues={"user_id": user_id},
                values={
                    "display_name": display_name,
                    "avatar_url": avatar_url
                },
                lock=False,  # We're the only inserter
            )

            if isinstance(self.database_engine, PostgresEngine):
                # We weight the localpart most highly, then display name and finally
                # server name
                sql = """
                        INSERT INTO user_directory_search(user_id, vector)
                        VALUES (?,
                            setweight(to_tsvector('simple', ?), 'A')
                            || setweight(to_tsvector('simple', ?), 'D')
                            || setweight(to_tsvector('simple', COALESCE(?, '')), 'B')
                        ) ON CONFLICT (user_id) DO UPDATE SET vector=EXCLUDED.vector
                    """
                txn.execute(
                    sql,
                    (
                        user_id,
                        get_localpart_from_id(user_id),
                        get_domain_from_id(user_id),
                        display_name,
                    ),
                )
            elif isinstance(self.database_engine, Sqlite3Engine):
                value = "%s %s" % (user_id,
                                   display_name) if display_name else user_id
                self.db_pool.simple_upsert_txn(
                    txn,
                    table="user_directory_search",
                    keyvalues={"user_id": user_id},
                    values={"value": value},
                    lock=False,  # We're the only inserter
                )
            else:
                # This should be unreachable.
                raise Exception("Unrecognized database engine")

            txn.call_after(self.get_user_in_directory.invalidate, (user_id, ))
Example #12
    def _remove_old_push_actions_before_txn(
        self,
        txn: LoggingTransaction,
        room_id: str,
        user_id: str,
        stream_ordering: int,
    ) -> None:
        """
        Purges old push actions for a user and room before a given
        stream_ordering.

        We do, however, keep a month's worth of highlighted notifications, so
        that users can still get a list of recent highlights.

        Args:
            txn: The transaction
            room_id: Room ID to delete from
            user_id: user ID to delete for
            stream_ordering: The lowest stream ordering which will
                                  not be deleted.
        """
        txn.call_after(
            self.get_unread_event_push_actions_by_room_for_user.invalidate,
            (room_id, user_id),
        )

        # We need to join on the events table to get the received_ts for
        # event_push_actions and sqlite won't let us use a join in a delete so
        # we can't just delete where received_ts < x. Furthermore we can
        # only identify event_push_actions by a tuple of room_id, event_id,
        # so we can't use a subquery.
        # Instead, we look up the stream ordering for the last event in that
        # room received before the threshold time and delete event_push_actions
        # in the room with a stream_ordering before that.
        txn.execute(
            "DELETE FROM event_push_actions "
            " WHERE user_id = ? AND room_id = ? AND "
            " stream_ordering <= ?"
            " AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)",
            (user_id, room_id, stream_ordering,
             self.stream_ordering_month_ago),
        )

        txn.execute(
            """
            DELETE FROM event_push_summary
            WHERE room_id = ? AND user_id = ? AND stream_ordering <= ?
        """,
            (room_id, user_id, stream_ordering),
        )
Example #13
    def get_next_txn(self, txn: LoggingTransaction) -> int:
        """
        Usage:

            stream_id = stream_id_gen.get_next(txn)
            # ... persist event ...
        """

        next_id = self._load_next_id_txn(txn)

        with self._lock:
            self._unfinished_ids.add(next_id)

        txn.call_after(self._mark_id_as_finished, next_id)
        txn.call_on_exception(self._mark_id_as_finished, next_id)

        return next_id
Example #14
    def _update_remote_device_list_cache_entry_txn(
        self,
        txn: LoggingTransaction,
        user_id: str,
        device_id: str,
        content: JsonDict,
        stream_id: int,
    ) -> None:
        if content.get("deleted"):
            self.db_pool.simple_delete_txn(
                txn,
                table="device_lists_remote_cache",
                keyvalues={
                    "user_id": user_id,
                    "device_id": device_id
                },
            )

            txn.call_after(self.device_id_exists_cache.invalidate,
                           (user_id, device_id))
        else:
            self.db_pool.simple_upsert_txn(
                txn,
                table="device_lists_remote_cache",
                keyvalues={
                    "user_id": user_id,
                    "device_id": device_id
                },
                values={"content": json_encoder.encode(content)},
                # we don't need to lock, because we assume we are the only thread
                # updating this user's devices.
                lock=False,
            )

        txn.call_after(self._get_cached_user_device.invalidate,
                       (user_id, device_id))
        txn.call_after(self.get_cached_devices_for_user.invalidate,
                       (user_id, ))
        txn.call_after(
            self.get_device_list_last_stream_id_for_remote.invalidate,
            (user_id, ))

        self.db_pool.simple_upsert_txn(
            txn,
            table="device_lists_remote_extremeties",
            keyvalues={"user_id": user_id},
            values={"stream_id": stream_id},
            # again, we can assume we are the only thread updating this user's
            # extremity.
            lock=False,
        )
Example #15
        def _remove_from_user_dir_txn(txn: LoggingTransaction) -> None:
            self.db_pool.simple_delete_txn(
                txn, table="user_directory", keyvalues={"user_id": user_id}
            )
            self.db_pool.simple_delete_txn(
                txn, table="user_directory_search", keyvalues={"user_id": user_id}
            )
            self.db_pool.simple_delete_txn(
                txn, table="users_in_public_rooms", keyvalues={"user_id": user_id}
            )
            self.db_pool.simple_delete_txn(
                txn,
                table="users_who_share_private_rooms",
                keyvalues={"user_id": user_id},
            )
            self.db_pool.simple_delete_txn(
                txn,
                table="users_who_share_private_rooms",
                keyvalues={"other_user_id": user_id},
            )
            txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
Example #16
    def _add_user_signature_change_txn(
        self,
        txn: LoggingTransaction,
        from_user_id: str,
        user_ids: List[str],
        stream_id: int,
    ) -> None:
        txn.call_after(
            self._user_signature_stream_cache.entity_has_changed,
            from_user_id,
            stream_id,
        )
        self.db_pool.simple_insert_txn(
            txn,
            "user_signature_stream",
            values={
                "stream_id": stream_id,
                "from_user_id": from_user_id,
                "user_ids": json_encoder.encode(user_ids),
            },
        )
Example #17
    def _add_device_outbound_poke_to_stream_txn(
        self,
        txn: LoggingTransaction,
        user_id: str,
        device_ids: Collection[str],
        hosts: List[str],
        stream_ids: List[int],
        context: Dict[str, str],
    ) -> None:
        for host in hosts:
            txn.call_after(
                self._device_list_federation_stream_cache.entity_has_changed,
                host,
                stream_ids[-1],
            )

        now = self._clock.time_msec()
        next_stream_id = iter(stream_ids)

        self.db_pool.simple_insert_many_txn(
            txn,
            table="device_lists_outbound_pokes",
            values=[
                {
                    "destination": destination,
                    "stream_id": next(next_stream_id),
                    "user_id": user_id,
                    "device_id": device_id,
                    "sent": False,
                    "ts": now,
                    "opentracing_context": json_encoder.encode(context)
                    if whitelisted_homeserver(destination)
                    else "{}",
                }
                for destination in hosts
                for device_id in device_ids
            ],
        )
Example #18
    def _insert_push_rules_update_txn(
        self,
        txn: LoggingTransaction,
        stream_id: int,
        event_stream_ordering: int,
        user_id: str,
        rule_id: str,
        op: str,
        data: Optional[JsonDict] = None,
    ) -> None:
        values = {
            "stream_id": stream_id,
            "event_stream_ordering": event_stream_ordering,
            "user_id": user_id,
            "rule_id": rule_id,
            "op": op,
        }
        if data is not None:
            values.update(data)

        self.db_pool.simple_insert_txn(txn, "push_rules_stream", values=values)

        txn.call_after(self.get_push_rules_for_user.invalidate, (user_id, ))
        txn.call_after(self.get_push_rules_enabled_for_user.invalidate,
                       (user_id, ))
        txn.call_after(self.push_rules_stream_cache.entity_has_changed,
                       user_id, stream_id)
Example #19
    def _send_invalidation_to_replication(
        self, txn: LoggingTransaction, cache_name: str, keys: Optional[Iterable[Any]]
    ) -> None:
        """Notifies replication that given cache has been invalidated.

        Note that this does *not* invalidate the cache locally.

        Args:
            txn
            cache_name
            keys: Entry to invalidate. If None will invalidate all.
        """

        if cache_name == CURRENT_STATE_CACHE_NAME and keys is None:
            raise Exception(
                "Can't stream invalidate all with magic current state cache")

        if isinstance(self.database_engine, PostgresEngine):
            # get_next() returns a context manager which is designed to wrap
            # the transaction. However, we want an ID immediately, so we use
            # get_next_txn(), which ties the ID's bookkeeping to the success
            # or failure of this transaction instead.
            stream_id = self._cache_id_gen.get_next_txn(txn)
            txn.call_after(self.hs.get_notifier().on_new_replication_data)

            if keys is not None:
                keys = list(keys)

            self.db_pool.simple_insert_txn(
                txn,
                table="cache_invalidation_stream_by_instance",
                values={
                    "stream_id": stream_id,
                    "instance_name": self._instance_name,
                    "cache_func": cache_name,
                    "keys": keys,
                    "invalidation_ts": self._clock.time_msec(),
                },
            )
Example #20
        def set_account_validity_for_user_txn(txn: LoggingTransaction) -> None:
            txn.execute(
                """
                INSERT INTO email_account_validity (
                    user_id,
                    expiration_ts_ms,
                    email_sent,
                    renewal_token,
                    token_used_ts_ms
                )
                VALUES (?, ?, ?, ?, ?)
                ON CONFLICT (user_id) DO UPDATE
                SET
                    expiration_ts_ms = EXCLUDED.expiration_ts_ms,
                    email_sent = EXCLUDED.email_sent,
                    renewal_token = EXCLUDED.renewal_token,
                    token_used_ts_ms = EXCLUDED.token_used_ts_ms
                """, (user_id, expiration_ts, email_sent, renewal_token,
                      token_used_ts))

            txn.call_after(self.get_expiration_ts_for_user.invalidate, (user_id,))
Example #21
    def _insert_graph_receipt_txn(
        self,
        txn: LoggingTransaction,
        room_id: str,
        receipt_type: str,
        user_id: str,
        event_ids: List[str],
        data: JsonDict,
    ) -> None:
        assert self._can_write_to_receipts

        txn.call_after(
            self._get_receipts_for_user_with_orderings.invalidate,
            (user_id, receipt_type),
        )
        # FIXME: This shouldn't invalidate the whole cache
        txn.call_after(self._get_linearized_receipts_for_room.invalidate,
                       (room_id, ))

        self.db_pool.simple_delete_txn(
            txn,
            table="receipts_graph",
            keyvalues={
                "room_id": room_id,
                "receipt_type": receipt_type,
                "user_id": user_id,
            },
        )
        self.db_pool.simple_insert_txn(
            txn,
            table="receipts_graph",
            values={
                "room_id": room_id,
                "receipt_type": receipt_type,
                "user_id": user_id,
                "event_ids": json_encoder.encode(event_ids),
                "data": json_encoder.encode(data),
            },
        )
Example #22
    def _update_remote_device_list_cache_txn(
        self, txn: LoggingTransaction, user_id: str, devices: List[dict], stream_id: int
    ) -> None:
        self.db_pool.simple_delete_txn(
            txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id}
        )

        self.db_pool.simple_insert_many_txn(
            txn,
            table="device_lists_remote_cache",
            values=[{
                "user_id": user_id,
                "device_id": content["device_id"],
                "content": json_encoder.encode(content),
            } for content in devices],
        )

        txn.call_after(self.get_cached_devices_for_user.invalidate,
                       (user_id, ))
        txn.call_after(self._get_cached_user_device.invalidate_many,
                       (user_id, ))
        txn.call_after(
            self.get_device_list_last_stream_id_for_remote.invalidate,
            (user_id, ))

        self.db_pool.simple_upsert_txn(
            txn,
            table="device_lists_remote_extremeties",
            keyvalues={"user_id": user_id},
            values={"stream_id": stream_id},
            # we don't need to lock, because we can assume we are the only thread
            # updating this user's extremity.
            lock=False,
        )

        # If we're replacing the remote user's device list cache presumably
        # we've done a full resync, so we remove the entry that says we need
        # to resync
        self.db_pool.simple_delete_txn(
            txn,
            table="device_lists_remote_resync",
            keyvalues={"user_id": user_id},
        )
Example #23
        def _test_txn(txn: LoggingTransaction) -> None:
            txn.call_after(after_callback, 123, 456, extra=789)
            txn.call_on_exception(exception_callback, 987, 654, extra=321)
            d.cancel()
            # Simulate a retryable failure on every attempt.
            raise self.db_pool.engine.module.OperationalError()
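
What this test exercises, sketched as a simplified driver loop (hypothetical, not the real runInteraction): after-callbacks fire only on success, and exception-callbacks fire once the retries give up. The after_callbacks/exception_callbacks queues are the ones from the sketch after Example #1:

import sqlite3

def run_interaction_sketch(txn, txn_func, max_attempts=5):
    for attempt in range(max_attempts):
        try:
            result = txn_func(txn)
        except sqlite3.OperationalError:
            if attempt == max_attempts - 1:
                # Retries exhausted: run the exception callbacks, once.
                for cb, args, kwargs in txn.exception_callbacks:
                    cb(*args, **kwargs)
                raise
            continue  # retryable failure: try again
        # Success: run the after callbacks.
        for cb, args, kwargs in txn.after_callbacks:
            cb(*args, **kwargs)
        return result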
Example #24
        def _delete_all_from_user_dir_txn(txn: LoggingTransaction) -> None:
            txn.execute("DELETE FROM user_directory")
            txn.execute("DELETE FROM user_directory_search")
            txn.execute("DELETE FROM users_in_public_rooms")
            txn.execute("DELETE FROM users_who_share_private_rooms")
            txn.call_after(self.get_user_in_directory.invalidate_all)
Example #25
        def _cleanup_extremities_bg_update_txn(txn: LoggingTransaction) -> int:
            # The set of extremity event IDs that we're checking this round
            original_set = set()

            # A dict[str, Set[str]] of event ID to their prev events.
            graph: Dict[str, Set[str]] = {}

            # The set of descendants of the original set that are not rejected
            # nor soft-failed. Ancestors of these events should be removed
            # from the forward extremities table.
            non_rejected_leaves = set()

            # Set of event IDs that have been soft failed, and for which we
            # should check if they have descendants which haven't been soft
            # failed.
            soft_failed_events_to_lookup = set()

            # First, we get `batch_size` events from the table, pulling out
            # their successor events, if any, and the successor events'
            # rejection status.
            txn.execute(
                """SELECT prev_event_id, event_id, internal_metadata,
                    rejections.event_id IS NOT NULL, events.outlier
                FROM (
                    SELECT event_id AS prev_event_id
                    FROM _extremities_to_check
                    LIMIT ?
                ) AS f
                LEFT JOIN event_edges USING (prev_event_id)
                LEFT JOIN events USING (event_id)
                LEFT JOIN event_json USING (event_id)
                LEFT JOIN rejections USING (event_id)
                """,
                (batch_size, ),
            )

            for prev_event_id, event_id, metadata, rejected, outlier in txn:
                original_set.add(prev_event_id)

                if not event_id or outlier:
                    # Common case where the forward extremity doesn't have any
                    # descendants.
                    continue

                graph.setdefault(event_id, set()).add(prev_event_id)

                soft_failed = False
                if metadata:
                    soft_failed = db_to_json(metadata).get("soft_failed")

                if soft_failed or rejected:
                    soft_failed_events_to_lookup.add(event_id)
                else:
                    non_rejected_leaves.add(event_id)

            # Now we recursively check all the soft-failed descendants we
            # found above in the same way, until we have nothing left to
            # check.
            while soft_failed_events_to_lookup:
                # We only want to do 100 at a time, so we split the given list
                # into two batches.
                batch = list(soft_failed_events_to_lookup)
                to_check, to_defer = batch[:100], batch[100:]
                soft_failed_events_to_lookup = set(to_defer)

                sql = """SELECT prev_event_id, event_id, internal_metadata,
                    rejections.event_id IS NOT NULL
                    FROM event_edges
                    INNER JOIN events USING (event_id)
                    INNER JOIN event_json USING (event_id)
                    LEFT JOIN rejections USING (event_id)
                    WHERE
                        NOT events.outlier
                        AND
                """
                clause, args = make_in_list_sql_clause(self.database_engine,
                                                       "prev_event_id",
                                                       to_check)
                txn.execute(sql + clause, list(args))

                for prev_event_id, event_id, metadata, rejected in txn:
                    if event_id in graph:
                        # Already handled this event previously, but we still
                        # want to record the edge.
                        graph[event_id].add(prev_event_id)
                        continue

                    graph[event_id] = {prev_event_id}

                    soft_failed = db_to_json(metadata).get("soft_failed")
                    if soft_failed or rejected:
                        soft_failed_events_to_lookup.add(event_id)
                    else:
                        non_rejected_leaves.add(event_id)

            # We have a set of non-soft-failed descendants, so we recurse up
            # the graph to find all ancestors and add them to the set of event
            # IDs that we can delete from forward extremities table.
            to_delete = set()
            while non_rejected_leaves:
                event_id = non_rejected_leaves.pop()
                prev_event_ids = graph.get(event_id, set())
                non_rejected_leaves.update(prev_event_ids)
                to_delete.update(prev_event_ids)

            to_delete.intersection_update(original_set)

            deleted = self.db_pool.simple_delete_many_txn(
                txn=txn,
                table="event_forward_extremities",
                column="event_id",
                values=to_delete,
                keyvalues={},
            )

            logger.info(
                "Deleted %d forward extremities of %d checked, to clean up #5269",
                deleted,
                len(original_set),
            )

            if deleted:
                # We now need to invalidate the caches of these rooms
                rows = self.db_pool.simple_select_many_txn(
                    txn,
                    table="events",
                    column="event_id",
                    iterable=to_delete,
                    keyvalues={},
                    retcols=("room_id", ),
                )
                room_ids = {row["room_id"] for row in rows}
                for room_id in room_ids:
                    txn.call_after(
                        self.get_latest_event_ids_in_room.invalidate,
                        (room_id, )  # type: ignore[attr-defined]
                    )

            self.db_pool.simple_delete_many_txn(
                txn=txn,
                table="_extremities_to_check",
                column="event_id",
                values=original_set,
                keyvalues={},
            )

            return len(original_set)
Example #26
    def _insert_linearized_receipt_txn(
        self,
        txn: LoggingTransaction,
        room_id: str,
        receipt_type: str,
        user_id: str,
        event_id: str,
        data: JsonDict,
        stream_id: int,
    ) -> Optional[int]:
        """Inserts a receipt into the database if it's newer than the current one.

        Returns:
            None if the receipt is older than the current receipt
            otherwise, the rx timestamp of the event that the receipt corresponds to
                (or 0 if the event is unknown)
        """
        assert self._can_write_to_receipts

        res = self.db_pool.simple_select_one_txn(
            txn,
            table="events",
            retcols=["stream_ordering", "received_ts"],
            keyvalues={"event_id": event_id},
            allow_none=True,
        )

        stream_ordering = int(res["stream_ordering"]) if res else None
        rx_ts = res["received_ts"] if res else 0

        # We don't want to clobber receipts for more recent events, so we
        # have to compare orderings of existing receipts
        if stream_ordering is not None:
            sql = (
                "SELECT stream_ordering, event_id FROM events"
                " INNER JOIN receipts_linearized AS r USING (event_id, room_id)"
                " WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ?"
            )
            txn.execute(sql, (room_id, receipt_type, user_id))

            for so, eid in txn:
                if int(so) >= stream_ordering:
                    logger.debug(
                        "Ignoring new receipt for %s in favour of existing "
                        "one for later event %s",
                        event_id,
                        eid,
                    )
                    return None

        txn.call_after(
            self.invalidate_caches_for_receipt, room_id, receipt_type, user_id
        )

        txn.call_after(
            self._receipts_stream_cache.entity_has_changed, room_id, stream_id
        )

        self.db_pool.simple_upsert_txn(
            txn,
            table="receipts_linearized",
            keyvalues={
                "room_id": room_id,
                "receipt_type": receipt_type,
                "user_id": user_id,
            },
            values={
                "stream_id": stream_id,
                "event_id": event_id,
                "data": json_encoder.encode(data),
            },
            # receipts_linearized has a unique constraint on
            # (user_id, room_id, receipt_type), so no need to lock
            lock=False,
        )

        return rx_ts
Example #27
        def _test_txn(txn: LoggingTransaction) -> None:
            txn.call_after(after_callback, 123, 456, extra=789)
            txn.call_on_exception(exception_callback, 987, 654, extra=321)
            func(txn)
Example #28
        def _store_state_group_txn(txn: LoggingTransaction) -> int:
            if current_state_ids is None:
                # AFAIK, this can never happen
                raise Exception("current_state_ids cannot be None")

            state_group = self._state_group_seq_gen.get_next_id_txn(txn)

            self.db_pool.simple_insert_txn(
                txn,
                table="state_groups",
                values={
                    "id": state_group,
                    "room_id": room_id,
                    "event_id": event_id
                },
            )

            # We persist as a delta if we can, while also ensuring the chain
            # of deltas isn't too long, as otherwise read performance degrades.
            if prev_group:
                is_in_db = self.db_pool.simple_select_one_onecol_txn(
                    txn,
                    table="state_groups",
                    keyvalues={"id": prev_group},
                    retcol="id",
                    allow_none=True,
                )
                if not is_in_db:
                    raise Exception(
                        "Trying to persist state with unpersisted prev_group: %r"
                        % (prev_group, ))

                potential_hops = self._count_state_group_hops_txn(
                    txn, prev_group)
            if prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
                assert delta_ids is not None

                self.db_pool.simple_insert_txn(
                    txn,
                    table="state_group_edges",
                    values={
                        "state_group": state_group,
                        "prev_state_group": prev_group
                    },
                )

                self.db_pool.simple_insert_many_txn(
                    txn,
                    table="state_groups_state",
                    values=[{
                        "state_group": state_group,
                        "room_id": room_id,
                        "type": key[0],
                        "state_key": key[1],
                        "event_id": state_id,
                    } for key, state_id in delta_ids.items()],
                )
            else:
                self.db_pool.simple_insert_many_txn(
                    txn,
                    table="state_groups_state",
                    values=[{
                        "state_group": state_group,
                        "room_id": room_id,
                        "type": key[0],
                        "state_key": key[1],
                        "event_id": state_id,
                    } for key, state_id in current_state_ids.items()],
                )

            # Prefill the state group caches with this group.
            # It's fine to use the sequence like this as the state group map
            # is immutable. (If the map wasn't immutable then this prefill could
            # race with another update)

            current_member_state_ids = {
                s: ev
                for (s, ev) in current_state_ids.items()
                if s[0] == EventTypes.Member
            }
            txn.call_after(
                self._state_group_members_cache.update,
                self._state_group_members_cache.sequence,
                key=state_group,
                value=dict(current_member_state_ids),
            )

            current_non_member_state_ids = {
                s: ev
                for (s, ev) in current_state_ids.items()
                if s[0] != EventTypes.Member
            }
            txn.call_after(
                self._state_group_cache.update,
                self._state_group_cache.sequence,
                key=state_group,
                value=dict(current_non_member_state_ids),
            )

            return state_group