Example 1
    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        super().__init__(database, db_conn, hs)

        self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
        self._push_rules_enable_id_gen = IdGenerator(db_conn,
                                                     "push_rules_enable", "id")
Example 2
class PushRuleStore(PushRulesWorkerStore):
    # Because we have write access, this will be a StreamIdGenerator
    # (see PushRulesWorkerStore.__init__)
    _push_rules_stream_id_gen: AbstractStreamIdGenerator

    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        super().__init__(database, db_conn, hs)

        self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
        self._push_rules_enable_id_gen = IdGenerator(db_conn,
                                                     "push_rules_enable", "id")

    async def add_push_rule(
        self,
        user_id: str,
        rule_id: str,
        priority_class: int,
        conditions: List[Dict[str, str]],
        actions: List[Union[JsonDict, str]],
        before: Optional[str] = None,
        after: Optional[str] = None,
    ) -> None:
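        """Add a push rule for the given user.

        If `before` or `after` is supplied, the rule is inserted relative to
        that existing rule; otherwise it becomes the highest-priority rule in
        its priority class.
        """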
        conditions_json = json_encoder.encode(conditions)
        actions_json = json_encoder.encode(actions)
        async with self._push_rules_stream_id_gen.get_next() as stream_id:
            event_stream_ordering = self._stream_id_gen.get_current_token()

            if before or after:
                await self.db_pool.runInteraction(
                    "_add_push_rule_relative_txn",
                    self._add_push_rule_relative_txn,
                    stream_id,
                    event_stream_ordering,
                    user_id,
                    rule_id,
                    priority_class,
                    conditions_json,
                    actions_json,
                    before,
                    after,
                )
            else:
                await self.db_pool.runInteraction(
                    "_add_push_rule_highest_priority_txn",
                    self._add_push_rule_highest_priority_txn,
                    stream_id,
                    event_stream_ordering,
                    user_id,
                    rule_id,
                    priority_class,
                    conditions_json,
                    actions_json,
                )

    def _add_push_rule_relative_txn(
        self,
        txn: LoggingTransaction,
        stream_id: int,
        event_stream_ordering: int,
        user_id: str,
        rule_id: str,
        priority_class: int,
        conditions_json: str,
        actions_json: str,
        before: str,
        after: str,
    ) -> None:
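        """Insert the rule relative to the rule named in `before` or `after`,
        shifting up the priorities of the existing rules at or above the new
        rule's slot to make room for it."""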
        # Lock the table since otherwise we'll have annoying races between the
        # SELECT here and the UPSERT below.
        self.database_engine.lock_table(txn, "push_rules")

        relative_to_rule = before or after

        res = self.db_pool.simple_select_one_txn(
            txn,
            table="push_rules",
            keyvalues={
                "user_name": user_id,
                "rule_id": relative_to_rule
            },
            retcols=["priority_class", "priority"],
            allow_none=True,
        )

        if not res:
            raise RuleNotFoundException("before/after rule not found: %s" %
                                        (relative_to_rule, ))

        base_priority_class = res["priority_class"]
        base_rule_priority = res["priority"]

        if base_priority_class != priority_class:
            raise InconsistentRuleException(
                "Given priority class does not match class of relative rule")

        if before:
            # Higher priority rules are executed first, so adding a rule before
            # an existing rule means giving it a higher priority than that rule.
            new_rule_priority = base_rule_priority + 1
        else:
            # We increment the priority of the existing rules to make space for
            # the new rule. Therefore if we want this rule to appear after
            # an existing rule we give it the priority of the existing rule,
            # and then increment the priority of the existing rule.
            new_rule_priority = base_rule_priority

        sql = ("UPDATE push_rules SET priority = priority + 1"
               " WHERE user_name = ? AND priority_class = ? AND priority >= ?")

        txn.execute(sql, (user_id, priority_class, new_rule_priority))

        self._upsert_push_rule_txn(
            txn,
            stream_id,
            event_stream_ordering,
            user_id,
            rule_id,
            priority_class,
            new_rule_priority,
            conditions_json,
            actions_json,
        )

    def _add_push_rule_highest_priority_txn(
        self,
        txn: LoggingTransaction,
        stream_id: int,
        event_stream_ordering: int,
        user_id: str,
        rule_id: str,
        priority_class: int,
        conditions_json: str,
        actions_json: str,
    ) -> None:
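        """Insert the rule with a priority one above the current maximum in
        its priority class (or priority 0 if the class has no rules yet)."""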
        # Lock the table since otherwise we'll have annoying races between the
        # SELECT here and the UPSERT below.
        self.database_engine.lock_table(txn, "push_rules")

        # find the current highest priority in that class
        sql = ("SELECT COUNT(*), MAX(priority) FROM push_rules"
               " WHERE user_name = ? and priority_class = ?")
        txn.execute(sql, (user_id, priority_class))
        res = txn.fetchall()
        (how_many, highest_prio) = res[0]

        new_prio = 0
        if how_many > 0:
            new_prio = highest_prio + 1

        self._upsert_push_rule_txn(
            txn,
            stream_id,
            event_stream_ordering,
            user_id,
            rule_id,
            priority_class,
            new_prio,
            conditions_json,
            actions_json,
        )

    def _upsert_push_rule_txn(
        self,
        txn: LoggingTransaction,
        stream_id: int,
        event_stream_ordering: int,
        user_id: str,
        rule_id: str,
        priority_class: int,
        priority: int,
        conditions_json: str,
        actions_json: str,
        update_stream: bool = True,
    ) -> None:
        """Specialised version of simple_upsert_txn that picks a push_rule_id
        using the _push_rule_id_gen if it needs to insert the rule. It assumes
        that the "push_rules" table is locked"""

        sql = (
            "UPDATE push_rules"
            " SET priority_class = ?, priority = ?, conditions = ?, actions = ?"
            " WHERE user_name = ? AND rule_id = ?")

        txn.execute(
            sql,
            (priority_class, priority, conditions_json, actions_json, user_id,
             rule_id),
        )

        if txn.rowcount == 0:
            # We didn't update a row with the given rule_id so insert one
            push_rule_id = self._push_rule_id_gen.get_next()

            self.db_pool.simple_insert_txn(
                txn,
                table="push_rules",
                values={
                    "id": push_rule_id,
                    "user_name": user_id,
                    "rule_id": rule_id,
                    "priority_class": priority_class,
                    "priority": priority,
                    "conditions": conditions_json,
                    "actions": actions_json,
                },
            )

        if update_stream:
            self._insert_push_rules_update_txn(
                txn,
                stream_id,
                event_stream_ordering,
                user_id,
                rule_id,
                op="ADD",
                data={
                    "priority_class": priority_class,
                    "priority": priority,
                    "conditions": conditions_json,
                    "actions": actions_json,
                },
            )

        # ensure we have a push_rules_enable row
        # enabledness defaults to true
        if isinstance(self.database_engine, PostgresEngine):
            sql = """
                INSERT INTO push_rules_enable (id, user_name, rule_id, enabled)
                VALUES (?, ?, ?, ?)
                ON CONFLICT DO NOTHING
            """
        elif isinstance(self.database_engine, Sqlite3Engine):
            sql = """
                INSERT OR IGNORE INTO push_rules_enable (id, user_name, rule_id, enabled)
                VALUES (?, ?, ?, ?)
            """
        else:
            raise RuntimeError("Unknown database engine")

        new_enable_id = self._push_rules_enable_id_gen.get_next()
        txn.execute(sql, (new_enable_id, user_id, rule_id, 1))

    async def delete_push_rule(self, user_id: str, rule_id: str) -> None:
        """
        Delete a push rule. Args specify the row to be deleted and can be
        any of the columns in the push_rules table, but below are the
        standard ones.

        Args:
            user_id: The matrix ID of the push rule owner
            rule_id: The rule_id of the rule to be deleted
        """
        def delete_push_rule_txn(
            txn: LoggingTransaction,
            stream_id: int,
            event_stream_ordering: int,
        ) -> None:
            # we don't use simple_delete_one_txn because that would fail if the
            # user did not have a push_rule_enable row.
            self.db_pool.simple_delete_txn(txn, "push_rules_enable", {
                "user_name": user_id,
                "rule_id": rule_id
            })

            self.db_pool.simple_delete_one_txn(txn, "push_rules", {
                "user_name": user_id,
                "rule_id": rule_id
            })

            self._insert_push_rules_update_txn(txn,
                                               stream_id,
                                               event_stream_ordering,
                                               user_id,
                                               rule_id,
                                               op="DELETE")

        async with self._push_rules_stream_id_gen.get_next() as stream_id:
            event_stream_ordering = self._stream_id_gen.get_current_token()

            await self.db_pool.runInteraction(
                "delete_push_rule",
                delete_push_rule_txn,
                stream_id,
                event_stream_ordering,
            )

    async def set_push_rule_enabled(self, user_id: str, rule_id: str,
                                    enabled: bool,
                                    is_default_rule: bool) -> None:
        """
        Sets the `enabled` state of a push rule.

        Args:
            user_id: the user ID of the user who wishes to enable/disable the rule
                e.g. '@tina:example.org'
            rule_id: the full rule ID of the rule to be enabled/disabled
                e.g. 'global/override/.m.rule.roomnotif'
                  or 'global/override/myCustomRule'
            enabled: True if the rule is to be enabled, False if it is to be
                disabled
            is_default_rule: True if and only if this is a server-default rule.
                This skips the check for existence (as only user-created rules
                are stored in the database `push_rules` table).

        Raises:
            RuleNotFoundException if the rule does not exist.
        """
        async with self._push_rules_stream_id_gen.get_next() as stream_id:
            event_stream_ordering = self._stream_id_gen.get_current_token()
            await self.db_pool.runInteraction(
                "_set_push_rule_enabled_txn",
                self._set_push_rule_enabled_txn,
                stream_id,
                event_stream_ordering,
                user_id,
                rule_id,
                enabled,
                is_default_rule,
            )

    def _set_push_rule_enabled_txn(
        self,
        txn: LoggingTransaction,
        stream_id: int,
        event_stream_ordering: int,
        user_id: str,
        rule_id: str,
        enabled: bool,
        is_default_rule: bool,
    ) -> None:
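        """Upsert the rule's push_rules_enable row and record the change on
        the push rules stream, first checking (for non-default rules) that
        the rule actually exists."""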
        new_id = self._push_rules_enable_id_gen.get_next()

        if not is_default_rule:
            # first check it exists; we need to lock for key share so that a
            # transaction that deletes the push rule will conflict with this one.
            # We also need a push_rule_enable row to exist for every push_rules
            # row, otherwise it is possible to simultaneously delete a push rule
            # (that has no _enable row) and enable it, resulting in a dangling
            # _enable row. To solve this: we either need to use SERIALISABLE or
            # ensure we always have a push_rule_enable row for every push_rule
            # row. We chose the latter.
            for_key_share = "FOR KEY SHARE"
            if not isinstance(self.database_engine, PostgresEngine):
                # FOR KEY SHARE is not applicable/available on SQLite
                for_key_share = ""
            sql = ("""
                SELECT 1 FROM push_rules
                WHERE user_name = ? AND rule_id = ?
                %s
            """ % for_key_share)
            txn.execute(sql, (user_id, rule_id))
            if txn.fetchone() is None:
                raise RuleNotFoundException("Push rule does not exist.")

        self.db_pool.simple_upsert_txn(
            txn,
            "push_rules_enable",
            {
                "user_name": user_id,
                "rule_id": rule_id
            },
            {"enabled": 1 if enabled else 0},
            {"id": new_id},
        )

        self._insert_push_rules_update_txn(
            txn,
            stream_id,
            event_stream_ordering,
            user_id,
            rule_id,
            op="ENABLE" if enabled else "DISABLE",
        )

    async def set_push_rule_actions(
        self,
        user_id: str,
        rule_id: str,
        actions: List[Union[dict, str]],
        is_default_rule: bool,
    ) -> None:
        """
        Sets the `actions` state of a push rule.

        Args:
            user_id: the user ID of the user who wishes to enable/disable the rule
                e.g. '@tina:example.org'
            rule_id: the full rule ID of the rule to be enabled/disabled
                e.g. 'global/override/.m.rule.roomnotif'
                  or 'global/override/myCustomRule'
            actions: A list of actions (each action being a dict or string),
                e.g. ["notify", {"set_tweak": "highlight", "value": false}]
            is_default_rule: True if and only if this is a server-default rule.
                This skips the check for existence (as only user-created rules
                are stored in the database `push_rules` table).

        Raises:
            RuleNotFoundException if the rule does not exist.
        """
        actions_json = json_encoder.encode(actions)

        def set_push_rule_actions_txn(
            txn: LoggingTransaction,
            stream_id: int,
            event_stream_ordering: int,
        ) -> None:
            if is_default_rule:
                # Add a dummy rule to the rules table with the user specified
                # actions.
                priority_class = -1
                priority = 1
                self._upsert_push_rule_txn(
                    txn,
                    stream_id,
                    event_stream_ordering,
                    user_id,
                    rule_id,
                    priority_class,
                    priority,
                    "[]",
                    actions_json,
                    update_stream=False,
                )
            else:
                try:
                    self.db_pool.simple_update_one_txn(
                        txn,
                        "push_rules",
                        {
                            "user_name": user_id,
                            "rule_id": rule_id
                        },
                        {"actions": actions_json},
                    )
                except StoreError as serr:
                    if serr.code == 404:
                        # this sets the NOT_FOUND error code
                        raise RuleNotFoundException("Push rule does not exist")
                    else:
                        raise

            self._insert_push_rules_update_txn(
                txn,
                stream_id,
                event_stream_ordering,
                user_id,
                rule_id,
                op="ACTIONS",
                data={"actions": actions_json},
            )

        async with self._push_rules_stream_id_gen.get_next() as stream_id:
            event_stream_ordering = self._stream_id_gen.get_current_token()

            await self.db_pool.runInteraction(
                "set_push_rule_actions",
                set_push_rule_actions_txn,
                stream_id,
                event_stream_ordering,
            )

    def _insert_push_rules_update_txn(
        self,
        txn: LoggingTransaction,
        stream_id: int,
        event_stream_ordering: int,
        user_id: str,
        rule_id: str,
        op: str,
        data: Optional[JsonDict] = None,
    ) -> None:
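        """Record a push rule change in the push_rules_stream table and, once
        the transaction commits, invalidate the user's push rule caches and
        mark the user as changed in the push rules stream cache."""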
        values = {
            "stream_id": stream_id,
            "event_stream_ordering": event_stream_ordering,
            "user_id": user_id,
            "rule_id": rule_id,
            "op": op,
        }
        if data is not None:
            values.update(data)

        self.db_pool.simple_insert_txn(txn, "push_rules_stream", values=values)

        txn.call_after(self.get_push_rules_for_user.invalidate, (user_id, ))
        txn.call_after(self.get_push_rules_enabled_for_user.invalidate,
                       (user_id, ))
        txn.call_after(self.push_rules_stream_cache.entity_has_changed,
                       user_id, stream_id)

    def get_max_push_rules_stream_id(self) -> int:
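        """Get the current maximum stream id for push rule changes."""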
        return self._push_rules_stream_id_gen.get_current_token()

    async def copy_push_rule_from_room_to_room(self, new_room_id: str,
                                               user_id: str,
                                               rule: dict) -> None:
        """Copy a single push rule from one room to another for a specific user.

        Args:
            new_room_id: ID of the new room.
            user_id: ID of the user the push rule belongs to.
            rule: A push rule.
        """
        # Create new rule id
        rule_id_scope = "/".join(rule["rule_id"].split("/")[:-1])
        new_rule_id = rule_id_scope + "/" + new_room_id

        # Change room id in each condition
        for condition in rule.get("conditions", []):
            if condition.get("key") == "room_id":
                condition["pattern"] = new_room_id

        # Add the rule for the new room
        await self.add_push_rule(
            user_id=user_id,
            rule_id=new_rule_id,
            priority_class=rule["priority_class"],
            conditions=rule["conditions"],
            actions=rule["actions"],
        )

    async def copy_push_rules_from_room_to_room_for_user(
            self, old_room_id: str, new_room_id: str, user_id: str) -> None:
        """Copy all of the push rules from one room to another for a specific
        user.

        Args:
            old_room_id: ID of the old room.
            new_room_id: ID of the new room.
            user_id: ID of user to copy push rules for.
        """
        # Retrieve push rules for this user
        user_push_rules = await self.get_push_rules_for_user(user_id)

        # Get rules relating to the old room and copy them to the new room
        for rule in user_push_rules:
            conditions = rule.get("conditions", [])
            if any(
                (c.get("key") == "room_id" and c.get("pattern") == old_room_id)
                    for c in conditions):
                await self.copy_push_rule_from_room_to_room(
                    new_room_id, user_id, rule)
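
To make the priority bookkeeping in _add_push_rule_relative_txn concrete, here is a small self-contained sketch of the same idea applied to an in-memory mapping of rule_id to priority. insert_relative is a hypothetical helper for illustration, not part of the store above.

from typing import Dict


def insert_relative(
    rules: Dict[str, int], new_rule: str, relative_to: str, before: bool
) -> None:
    base = rules[relative_to]
    # Higher priorities run first, so "before" means one slot above the base
    # rule, while "after" takes the base rule's slot and pushes it (and
    # everything above it) up by one.
    new_priority = base + 1 if before else base
    for rule_id, priority in rules.items():
        if priority >= new_priority:
            rules[rule_id] = priority + 1
    rules[new_rule] = new_priority


rules = {"a": 0, "b": 1, "c": 2}
insert_relative(rules, "x", "b", before=False)
# rules is now {"a": 0, "b": 2, "c": 3, "x": 1}: "x" runs immediately after "b".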
Example 3
    def __init__(self, database: Database, db_conn, hs):
        self.hs = hs
        self._clock = hs.get_clock()
        self.database_engine = database.engine

        self._stream_id_gen = StreamIdGenerator(
            db_conn,
            "events",
            "stream_ordering",
            extra_tables=[("local_invites", "stream_id")],
        )
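        # The backfill stream counts downwards (step=-1), handing out
        # decreasing stream orderings for backfilled events so they never
        # collide with the forwards events stream above.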
        self._backfill_id_gen = StreamIdGenerator(
            db_conn,
            "events",
            "stream_ordering",
            step=-1,
            extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
        )
        self._presence_id_gen = StreamIdGenerator(db_conn, "presence_stream",
                                                  "stream_id")
        self._device_inbox_id_gen = StreamIdGenerator(db_conn,
                                                      "device_max_stream_id",
                                                      "stream_id")
        self._public_room_id_gen = StreamIdGenerator(
            db_conn, "public_room_list_stream", "stream_id")
        self._device_list_id_gen = StreamIdGenerator(
            db_conn,
            "device_lists_stream",
            "stream_id",
            extra_tables=[
                ("user_signature_stream", "stream_id"),
                ("device_lists_outbound_pokes", "stream_id"),
            ],
        )
        self._cross_signing_id_gen = StreamIdGenerator(
            db_conn, "e2e_cross_signing_keys", "stream_id")

        self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens",
                                                 "id")
        self._event_reports_id_gen = IdGenerator(db_conn, "event_reports",
                                                 "id")
        self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
        self._push_rules_enable_id_gen = IdGenerator(db_conn,
                                                     "push_rules_enable", "id")
        self._push_rules_stream_id_gen = ChainedIdGenerator(
            self._stream_id_gen, db_conn, "push_rules_stream", "stream_id")
        self._pushers_id_gen = StreamIdGenerator(db_conn,
                                                 "pushers",
                                                 "id",
                                                 extra_tables=[
                                                     ("deleted_pushers",
                                                      "stream_id")
                                                 ])
        self._group_updates_id_gen = StreamIdGenerator(db_conn,
                                                       "local_group_updates",
                                                       "stream_id")

        if isinstance(self.database_engine, PostgresEngine):
            self._cache_id_gen = StreamIdGenerator(
                db_conn, "cache_invalidation_stream", "stream_id")
        else:
            self._cache_id_gen = None

        super(DataStore, self).__init__(database, db_conn, hs)

        self._presence_on_startup = self._get_active_presence(db_conn)

        presence_cache_prefill, min_presence_val = self.db.get_cache_dict(
            db_conn,
            "presence_stream",
            entity_column="user_id",
            stream_column="stream_id",
            max_value=self._presence_id_gen.get_current_token(),
        )
        self.presence_stream_cache = StreamChangeCache(
            "PresenceStreamChangeCache",
            min_presence_val,
            prefilled_cache=presence_cache_prefill,
        )

        max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
        device_inbox_prefill, min_device_inbox_id = self.db.get_cache_dict(
            db_conn,
            "device_inbox",
            entity_column="user_id",
            stream_column="stream_id",
            max_value=max_device_inbox_id,
            limit=1000,
        )
        self._device_inbox_stream_cache = StreamChangeCache(
            "DeviceInboxStreamChangeCache",
            min_device_inbox_id,
            prefilled_cache=device_inbox_prefill,
        )
        # The federation outbox and the local device inbox use the same
        # stream_id generator.
        device_outbox_prefill, min_device_outbox_id = self.db.get_cache_dict(
            db_conn,
            "device_federation_outbox",
            entity_column="destination",
            stream_column="stream_id",
            max_value=max_device_inbox_id,
            limit=1000,
        )
        self._device_federation_outbox_stream_cache = StreamChangeCache(
            "DeviceFederationOutboxStreamChangeCache",
            min_device_outbox_id,
            prefilled_cache=device_outbox_prefill,
        )

        device_list_max = self._device_list_id_gen.get_current_token()
        self._device_list_stream_cache = StreamChangeCache(
            "DeviceListStreamChangeCache", device_list_max)
        self._user_signature_stream_cache = StreamChangeCache(
            "UserSignatureStreamChangeCache", device_list_max)
        self._device_list_federation_stream_cache = StreamChangeCache(
            "DeviceListFederationStreamChangeCache", device_list_max)

        events_max = self._stream_id_gen.get_current_token()
        curr_state_delta_prefill, min_curr_state_delta_id = self.db.get_cache_dict(
            db_conn,
            "current_state_delta_stream",
            entity_column="room_id",
            stream_column="stream_id",
            max_value=events_max,  # As we share the stream id with events token
            limit=1000,
        )
        self._curr_state_delta_stream_cache = StreamChangeCache(
            "_curr_state_delta_stream_cache",
            min_curr_state_delta_id,
            prefilled_cache=curr_state_delta_prefill,
        )

        _group_updates_prefill, min_group_updates_id = self.db.get_cache_dict(
            db_conn,
            "local_group_updates",
            entity_column="user_id",
            stream_column="stream_id",
            max_value=self._group_updates_id_gen.get_current_token(),
            limit=1000,
        )
        self._group_updates_stream_cache = StreamChangeCache(
            "_group_updates_stream_cache",
            min_group_updates_id,
            prefilled_cache=_group_updates_prefill,
        )

        self._stream_order_on_start = self.get_room_max_stream_ordering()
        self._min_stream_order_on_start = self.get_room_min_stream_ordering()

        # Used in _generate_user_daily_visits to keep track of progress
        self._last_user_visit_update = self._get_start_of_day()
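
The recurring pattern above, where get_cache_dict prefills a StreamChangeCache, stores the most recent stream id at which each entity changed, so "has X changed since token T?" can be answered without a database query. A rough sketch of that idea, with hypothetical names and not the actual StreamChangeCache used above:

from typing import Dict, Optional


class MiniStreamChangeCache:
    """Illustrative only: remembers the last stream position at which each
    entity changed, so change checks become dictionary lookups."""

    def __init__(
        self, min_known_position: int, prefilled: Optional[Dict[str, int]] = None
    ):
        # Below this position we have no information, so we must assume
        # everything has changed.
        self._earliest_known = min_known_position
        self._last_changed: Dict[str, int] = dict(prefilled or {})

    def entity_has_changed(self, entity: str, stream_pos: int) -> None:
        current = self._last_changed.get(entity, self._earliest_known)
        self._last_changed[entity] = max(current, stream_pos)

    def has_entity_changed(self, entity: str, since: int) -> bool:
        if since < self._earliest_known:
            return True  # can't tell, so be conservative
        return self._last_changed.get(entity, self._earliest_known) > since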
Example 4
    def __init__(self, database: DatabasePool, db_conn, hs):
        self.hs = hs
        self._clock = hs.get_clock()
        self.database_engine = database.engine

        self._presence_id_gen = StreamIdGenerator(
            db_conn, "presence_stream", "stream_id"
        )
        self._device_inbox_id_gen = StreamIdGenerator(
            db_conn, "device_inbox", "stream_id"
        )
        self._public_room_id_gen = StreamIdGenerator(
            db_conn, "public_room_list_stream", "stream_id"
        )
        self._device_list_id_gen = StreamIdGenerator(
            db_conn,
            "device_lists_stream",
            "stream_id",
            extra_tables=[
                ("user_signature_stream", "stream_id"),
                ("device_lists_outbound_pokes", "stream_id"),
            ],
        )
        self._cross_signing_id_gen = StreamIdGenerator(
            db_conn, "e2e_cross_signing_keys", "stream_id"
        )

        self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
        self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
        self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
        self._group_updates_id_gen = StreamIdGenerator(
            db_conn, "local_group_updates", "stream_id"
        )

        if isinstance(self.database_engine, PostgresEngine):
            # We set the `writers` to an empty list here as we don't care about
            # missing updates over restarts, as we'll not have anything in our
            # caches to invalidate. (This reduces the number of writes to the DB
            # that happen.)
            self._cache_id_gen = MultiWriterIdGenerator(
                db_conn,
                database,
                stream_name="caches",
                instance_name=hs.get_instance_name(),
                table="cache_invalidation_stream_by_instance",
                instance_column="instance_name",
                id_column="stream_id",
                sequence_name="cache_invalidation_stream_seq",
                writers=[],
            )
        else:
            self._cache_id_gen = None

        super().__init__(database, db_conn, hs)

        self._presence_on_startup = self._get_active_presence(db_conn)

        presence_cache_prefill, min_presence_val = self.db_pool.get_cache_dict(
            db_conn,
            "presence_stream",
            entity_column="user_id",
            stream_column="stream_id",
            max_value=self._presence_id_gen.get_current_token(),
        )
        self.presence_stream_cache = StreamChangeCache(
            "PresenceStreamChangeCache",
            min_presence_val,
            prefilled_cache=presence_cache_prefill,
        )

        max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
        device_inbox_prefill, min_device_inbox_id = self.db_pool.get_cache_dict(
            db_conn,
            "device_inbox",
            entity_column="user_id",
            stream_column="stream_id",
            max_value=max_device_inbox_id,
            limit=1000,
        )
        self._device_inbox_stream_cache = StreamChangeCache(
            "DeviceInboxStreamChangeCache",
            min_device_inbox_id,
            prefilled_cache=device_inbox_prefill,
        )
        # The federation outbox and the local device inbox use the same
        # stream_id generator.
        device_outbox_prefill, min_device_outbox_id = self.db_pool.get_cache_dict(
            db_conn,
            "device_federation_outbox",
            entity_column="destination",
            stream_column="stream_id",
            max_value=max_device_inbox_id,
            limit=1000,
        )
        self._device_federation_outbox_stream_cache = StreamChangeCache(
            "DeviceFederationOutboxStreamChangeCache",
            min_device_outbox_id,
            prefilled_cache=device_outbox_prefill,
        )

        device_list_max = self._device_list_id_gen.get_current_token()
        self._device_list_stream_cache = StreamChangeCache(
            "DeviceListStreamChangeCache", device_list_max
        )
        self._user_signature_stream_cache = StreamChangeCache(
            "UserSignatureStreamChangeCache", device_list_max
        )
        self._device_list_federation_stream_cache = StreamChangeCache(
            "DeviceListFederationStreamChangeCache", device_list_max
        )

        events_max = self._stream_id_gen.get_current_token()
        curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict(
            db_conn,
            "current_state_delta_stream",
            entity_column="room_id",
            stream_column="stream_id",
            max_value=events_max,  # As we share the stream id with events token
            limit=1000,
        )
        self._curr_state_delta_stream_cache = StreamChangeCache(
            "_curr_state_delta_stream_cache",
            min_curr_state_delta_id,
            prefilled_cache=curr_state_delta_prefill,
        )

        _group_updates_prefill, min_group_updates_id = self.db_pool.get_cache_dict(
            db_conn,
            "local_group_updates",
            entity_column="user_id",
            stream_column="stream_id",
            max_value=self._group_updates_id_gen.get_current_token(),
            limit=1000,
        )
        self._group_updates_stream_cache = StreamChangeCache(
            "_group_updates_stream_cache",
            min_group_updates_id,
            prefilled_cache=_group_updates_prefill,
        )

        self._stream_order_on_start = self.get_room_max_stream_ordering()
        self._min_stream_order_on_start = self.get_room_min_stream_ordering()
Example 5
    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        self.hs = hs
        self._clock = hs.get_clock()
        self.database_engine = database.engine

        self._device_list_id_gen = StreamIdGenerator(
            db_conn,
            "device_lists_stream",
            "stream_id",
            extra_tables=[
                ("user_signature_stream", "stream_id"),
                ("device_lists_outbound_pokes", "stream_id"),
            ],
        )

        self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
        self._push_rules_enable_id_gen = IdGenerator(db_conn,
                                                     "push_rules_enable", "id")
        self._group_updates_id_gen = StreamIdGenerator(db_conn,
                                                       "local_group_updates",
                                                       "stream_id")

        self._cache_id_gen: Optional[MultiWriterIdGenerator]
        if isinstance(self.database_engine, PostgresEngine):
            # We set the `writers` to an empty list here as we don't care about
            # missing updates over restarts, as we'll not have anything in our
            # caches to invalidate. (This reduces the number of writes to the DB
            # that happen.)
            self._cache_id_gen = MultiWriterIdGenerator(
                db_conn,
                database,
                stream_name="caches",
                instance_name=hs.get_instance_name(),
                tables=[(
                    "cache_invalidation_stream_by_instance",
                    "instance_name",
                    "stream_id",
                )],
                sequence_name="cache_invalidation_stream_seq",
                writers=[],
            )

        else:
            self._cache_id_gen = None

        super().__init__(database, db_conn, hs)

        device_list_max = self._device_list_id_gen.get_current_token()
        self._device_list_stream_cache = StreamChangeCache(
            "DeviceListStreamChangeCache", device_list_max)
        self._user_signature_stream_cache = StreamChangeCache(
            "UserSignatureStreamChangeCache", device_list_max)
        self._device_list_federation_stream_cache = StreamChangeCache(
            "DeviceListFederationStreamChangeCache", device_list_max)

        events_max = self._stream_id_gen.get_current_token()
        curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict(
            db_conn,
            "current_state_delta_stream",
            entity_column="room_id",
            stream_column="stream_id",
            max_value=events_max,  # As we share the stream id with events token
            limit=1000,
        )
        self._curr_state_delta_stream_cache = StreamChangeCache(
            "_curr_state_delta_stream_cache",
            min_curr_state_delta_id,
            prefilled_cache=curr_state_delta_prefill,
        )

        _group_updates_prefill, min_group_updates_id = self.db_pool.get_cache_dict(
            db_conn,
            "local_group_updates",
            entity_column="user_id",
            stream_column="stream_id",
            max_value=self._group_updates_id_gen.get_current_token(),
            limit=1000,
        )
        self._group_updates_stream_cache = StreamChangeCache(
            "_group_updates_stream_cache",
            min_group_updates_id,
            prefilled_cache=_group_updates_prefill,
        )

        self._stream_order_on_start = self.get_room_max_stream_ordering()
        self._min_stream_order_on_start = self.get_room_min_stream_ordering()