Example 1
    def __init__(self, database: Database, db_conn, hs):
        # We instantiate this first as the ReceiptsWorkerStore constructor
        # needs to be able to call get_max_receipt_stream_id
        self._receipts_id_gen = StreamIdGenerator(db_conn,
                                                  "receipts_linearized",
                                                  "stream_id")

        super(ReceiptsStore, self).__init__(database, db_conn, hs)
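The comment in Example 1 explains the ordering: the parent constructor calls get_max_receipt_stream_id, which simply reads the generator's current position, so the generator has to exist before super().__init__ runs. A minimal sketch of that accessor, assuming the generator exposes get_current_token():

    def get_max_receipt_stream_id(self) -> int:
        # The maximum receipt stream id is just the generator's current
        # position, which is why the generator must be created before
        # super().__init__ runs anything that consults it.
        return self._receipts_id_gen.get_current_token()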
Example 2
    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        super().__init__(database, db_conn, hs)

        self._cross_signing_id_gen = StreamIdGenerator(
            db_conn, "e2e_cross_signing_keys", "stream_id")
Example 3
    def __init__(self, database: Database, db_conn, hs):
        self._account_data_id_gen = StreamIdGenerator(
            db_conn,
            "account_data_max_stream_id",
            "stream_id",
            extra_tables=[
                ("room_account_data", "stream_id"),
                ("room_tags_revisions", "stream_id"),
            ],
        )

        super(AccountDataStore, self).__init__(database, db_conn, hs)
Example 4
    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        super().__init__(database, db_conn, hs)

        if hs.config.worker.worker_app is None:
            self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
                db_conn, "push_rules_stream", "stream_id")
        else:
            self._push_rules_stream_id_gen = SlavedIdTracker(
                db_conn, "push_rules_stream", "stream_id")

        push_rules_prefill, push_rules_id = self.db_pool.get_cache_dict(
            db_conn,
            "push_rules_stream",
            entity_column="user_id",
            stream_column="stream_id",
            max_value=self.get_max_push_rules_stream_id(),
        )

        self.push_rules_stream_cache = StreamChangeCache(
            "PushRulesStreamChangeCache",
            push_rules_id,
            prefilled_cache=push_rules_prefill,
        )
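The prefilled StreamChangeCache is what lets read paths skip the database: callers hand it a stream position and a set of entities, and it returns only those that may have changed since then. A rough sketch of such a helper, assuming the cache's get_entities_changed method (the helper name itself is illustrative, not Synapse's actual API):

    def _users_with_changed_push_rules(
        self, user_ids: Collection[str], from_stream_id: int
    ) -> Collection[str]:
        # Narrow a set of users down to those whose push rules may have
        # changed since `from_stream_id`, without querying the database.
        return self.push_rules_stream_cache.get_entities_changed(
            user_ids, from_stream_id
        )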
Example 5
    def __init__(self, database: DatabasePool, db_conn, hs):
        super().__init__(database, db_conn, hs)

        if hs.config.worker.worker_app is None:
            self._push_rules_stream_id_gen = StreamIdGenerator(
                db_conn, "push_rules_stream",
                "stream_id")  # type: Union[StreamIdGenerator, SlavedIdTracker]
        else:
            self._push_rules_stream_id_gen = SlavedIdTracker(
                db_conn, "push_rules_stream", "stream_id")

        push_rules_prefill, push_rules_id = self.db_pool.get_cache_dict(
            db_conn,
            "push_rules_stream",
            entity_column="user_id",
            stream_column="stream_id",
            max_value=self.get_max_push_rules_stream_id(),
        )

        self.push_rules_stream_cache = StreamChangeCache(
            "PushRulesStreamChangeCache",
            push_rules_id,
            prefilled_cache=push_rules_prefill,
        )

        self._users_new_default_push_rules = hs.config.users_new_default_push_rules
Example 6
    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        super().__init__(database, db_conn, hs)
        self._pushers_id_gen = StreamIdGenerator(db_conn,
                                                 "pushers",
                                                 "id",
                                                 extra_tables=[
                                                     ("deleted_pushers",
                                                      "stream_id")
                                                 ])

        self.db_pool.updates.register_background_update_handler(
            "remove_deactivated_pushers",
            self._remove_deactivated_pushers,
        )

        self.db_pool.updates.register_background_update_handler(
            "remove_stale_pushers",
            self._remove_stale_pushers,
        )

        self.db_pool.updates.register_background_update_handler(
            "remove_deleted_email_pushers",
            self._remove_deleted_email_pushers,
        )
Example 7
    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        self._instance_name = hs.get_instance_name()
        self._receipts_id_gen: AbstractStreamIdTracker

        if isinstance(database.engine, PostgresEngine):
            self._can_write_to_receipts = (
                self._instance_name in hs.config.worker.writers.receipts
            )

            self._receipts_id_gen = MultiWriterIdGenerator(
                db_conn=db_conn,
                db=database,
                stream_name="receipts",
                instance_name=self._instance_name,
                tables=[("receipts_linearized", "instance_name", "stream_id")],
                sequence_name="receipts_sequence",
                writers=hs.config.worker.writers.receipts,
            )
        else:
            self._can_write_to_receipts = True

            # We shouldn't be running in worker mode with SQLite, but it's useful
            # to support it for unit tests.
            #
            # If this process is the writer then we need to use
            # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets
            # updated over replication. (Multiple writers are not supported for
            # SQLite).
            if hs.get_instance_name() in hs.config.worker.writers.receipts:
                self._receipts_id_gen = StreamIdGenerator(
                    db_conn, "receipts_linearized", "stream_id"
                )
            else:
                self._receipts_id_gen = SlavedIdTracker(
                    db_conn, "receipts_linearized", "stream_id"
                )

        super().__init__(database, db_conn, hs)

        max_receipts_stream_id = self.get_max_receipt_stream_id()
        receipts_stream_prefill, min_receipts_stream_id = self.db_pool.get_cache_dict(
            db_conn,
            "receipts_linearized",
            entity_column="room_id",
            stream_column="stream_id",
            max_value=max_receipts_stream_id,
            limit=10000,
        )
        self._receipts_stream_cache = StreamChangeCache(
            "ReceiptsRoomChangeCache",
            min_receipts_stream_id,
            prefilled_cache=receipts_stream_prefill,
        )
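Here `_can_write_to_receipts` mirrors the generator choice: only a configured writer gets an id generator that can allocate new stream ids, while other workers get a tracker that is merely advanced over replication. A hedged sketch of how a write path might guard on that flag, assuming the generator's async get_next() interface; the method and the _insert_receipt_txn helper are hypothetical, not Synapse's exact API:

    async def insert_receipt(
        self, room_id: str, user_id: str, event_id: str
    ) -> int:
        # Workers that only track the stream must never allocate receipt ids.
        assert self._can_write_to_receipts

        async with self._receipts_id_gen.get_next() as stream_id:
            await self.db_pool.runInteraction(
                "insert_receipt",
                self._insert_receipt_txn,  # hypothetical transaction helper
                room_id,
                user_id,
                event_id,
                stream_id,
            )
        return stream_id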
Example 8
    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        super().__init__(database, db_conn, hs)

        # `_can_write_to_account_data` indicates whether the current worker is allowed
        # to write account data. A value of `True` implies that `_account_data_id_gen`
        # is an `AbstractStreamIdGenerator` and not just a tracker.
        self._account_data_id_gen: AbstractStreamIdTracker

        if isinstance(database.engine, PostgresEngine):
            self._can_write_to_account_data = (
                self._instance_name in hs.config.worker.writers.account_data)

            self._account_data_id_gen = MultiWriterIdGenerator(
                db_conn=db_conn,
                db=database,
                stream_name="account_data",
                instance_name=self._instance_name,
                tables=[
                    ("room_account_data", "instance_name", "stream_id"),
                    ("room_tags_revisions", "instance_name", "stream_id"),
                    ("account_data", "instance_name", "stream_id"),
                ],
                sequence_name="account_data_sequence",
                writers=hs.config.worker.writers.account_data,
            )
        else:
            # We shouldn't be running in worker mode with SQLite, but it's useful
            # to support it for unit tests.
            #
            # If this process is the writer then we need to use
            # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets
            # updated over replication. (Multiple writers are not supported for
            # SQLite).
            if self._instance_name in hs.config.worker.writers.account_data:
                self._can_write_to_account_data = True
                self._account_data_id_gen = StreamIdGenerator(
                    db_conn,
                    "room_account_data",
                    "stream_id",
                    extra_tables=[("room_tags_revisions", "stream_id")],
                )
            else:
                self._account_data_id_gen = SlavedIdTracker(
                    db_conn,
                    "room_account_data",
                    "stream_id",
                    extra_tables=[("room_tags_revisions", "stream_id")],
                )

        account_max = self.get_max_account_data_stream_id()
        self._account_data_stream_cache = StreamChangeCache(
            "AccountDataAndTagsChangeCache", account_max)
Example 9
    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ) -> None:
        super().__init__(database, db_conn, hs)

        self._instance_name = hs.get_instance_name()
        self._presence_id_gen: AbstractStreamIdGenerator

        self._can_persist_presence = (self._instance_name
                                      in hs.config.worker.writers.presence)

        if isinstance(database.engine, PostgresEngine):
            self._presence_id_gen = MultiWriterIdGenerator(
                db_conn=db_conn,
                db=database,
                stream_name="presence_stream",
                instance_name=self._instance_name,
                tables=[("presence_stream", "instance_name", "stream_id")],
                sequence_name="presence_stream_sequence",
                writers=hs.config.worker.writers.presence,
            )
        else:
            self._presence_id_gen = StreamIdGenerator(db_conn,
                                                      "presence_stream",
                                                      "stream_id")

        self.hs = hs
        self._presence_on_startup = self._get_active_presence(db_conn)

        presence_cache_prefill, min_presence_val = self.db_pool.get_cache_dict(
            db_conn,
            "presence_stream",
            entity_column="user_id",
            stream_column="stream_id",
            max_value=self._presence_id_gen.get_current_token(),
        )
        self.presence_stream_cache = StreamChangeCache(
            "PresenceStreamChangeCache",
            min_presence_val,
            prefilled_cache=presence_cache_prefill,
        )
Example 10
    def __init__(self, database: DatabasePool, db_conn: Connection,
                 hs: "HomeServer"):
        super().__init__(database, db_conn, hs)
        self._pushers_id_gen = StreamIdGenerator(db_conn,
                                                 "pushers",
                                                 "id",
                                                 extra_tables=[
                                                     ("deleted_pushers",
                                                      "stream_id")
                                                 ])
Example 11
    def __init__(self, database: DatabasePool, db_conn, hs):
        self._instance_name = hs.get_instance_name()

        if isinstance(database.engine, PostgresEngine):
            self._can_write_to_account_data = (
                self._instance_name in hs.config.worker.writers.account_data
            )

            self._account_data_id_gen = MultiWriterIdGenerator(
                db_conn=db_conn,
                db=database,
                stream_name="account_data",
                instance_name=self._instance_name,
                tables=[
                    ("room_account_data", "instance_name", "stream_id"),
                    ("room_tags_revisions", "instance_name", "stream_id"),
                    ("account_data", "instance_name", "stream_id"),
                ],
                sequence_name="account_data_sequence",
                writers=hs.config.worker.writers.account_data,
            )
        else:
            self._can_write_to_account_data = True

            # We shouldn't be running in worker mode with SQLite, but it's useful
            # to support it for unit tests.
            #
            # If this process is the writer then we need to use
            # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets
            # updated over replication. (Multiple writers are not supported for
            # SQLite).
            if hs.get_instance_name() in hs.config.worker.writers.account_data:
                self._account_data_id_gen = StreamIdGenerator(
                    db_conn,
                    "room_account_data",
                    "stream_id",
                    extra_tables=[("room_tags_revisions", "stream_id")],
                )
            else:
                self._account_data_id_gen = SlavedIdTracker(
                    db_conn,
                    "room_account_data",
                    "stream_id",
                    extra_tables=[("room_tags_revisions", "stream_id")],
                )

        account_max = self.get_max_account_data_stream_id()
        self._account_data_stream_cache = StreamChangeCache(
            "AccountDataAndTagsChangeCache", account_max
        )

        super().__init__(database, db_conn, hs)
Example 12
    def __init__(self, database: DatabasePool, db_conn, hs):
        super(EventsWorkerStore, self).__init__(database, db_conn, hs)

        if hs.config.worker.writers.events == hs.get_instance_name():
            # We are the process in charge of generating stream ids for events,
            # so instantiate ID generators based on the database
            self._stream_id_gen = StreamIdGenerator(
                db_conn,
                "events",
                "stream_ordering",
            )
            self._backfill_id_gen = StreamIdGenerator(
                db_conn,
                "events",
                "stream_ordering",
                step=-1,
                extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
            )
        else:
            # Another process is in charge of persisting events and generating
            # stream IDs: rely on the replication streams to let us know which
            # IDs we can process.
            self._stream_id_gen = SlavedIdTracker(db_conn, "events",
                                                  "stream_ordering")
            self._backfill_id_gen = SlavedIdTracker(db_conn,
                                                    "events",
                                                    "stream_ordering",
                                                    step=-1)

        self._get_event_cache = Cache(
            "*getEvent*",
            keylen=3,
            max_entries=hs.config.caches.event_cache_size,
            apply_cache_factor_from_config=False,
        )

        self._event_fetch_lock = threading.Condition()
        self._event_fetch_list = []
        self._event_fetch_ongoing = 0
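On the non-writer branch above, the trackers only move forward when replication reports what the event writer has persisted. A rough sketch of that hook, assuming a process_replication_rows override and the tracker's advance() method (exact signatures differ between Synapse versions):

    def process_replication_rows(self, stream_name, instance_name, token, rows):
        # Advance the local position to what the writer has persisted; only
        # then is it safe to serve events up to `token`.
        if stream_name == "events":
            self._stream_id_gen.advance(instance_name, token)
        return super().process_replication_rows(
            stream_name, instance_name, token, rows
        )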
Example 13
class DataStore(
    EventsBackgroundUpdatesStore,
    RoomMemberStore,
    RoomStore,
    RegistrationStore,
    StreamStore,
    ProfileStore,
    PresenceStore,
    TransactionStore,
    DirectoryStore,
    KeyStore,
    StateStore,
    SignatureStore,
    ApplicationServiceStore,
    PurgeEventsStore,
    EventFederationStore,
    MediaRepositoryStore,
    RejectionsStore,
    FilteringStore,
    PusherStore,
    PushRuleStore,
    ApplicationServiceTransactionStore,
    ReceiptsStore,
    EndToEndKeyStore,
    EndToEndRoomKeyStore,
    SearchStore,
    TagsStore,
    AccountDataStore,
    EventPushActionsStore,
    OpenIdStore,
    ClientIpStore,
    DeviceStore,
    DeviceInboxStore,
    UserDirectoryStore,
    GroupServerStore,
    UserErasureStore,
    MonthlyActiveUsersStore,
    StatsStore,
    RelationsStore,
    CensorEventsStore,
    UIAuthStore,
    CacheInvalidationWorkerStore,
    ServerMetricsStore,
    NewsWorkerStore,
    AccessWorkerStore
):
    def __init__(self, database: DatabasePool, db_conn, hs):
        self.hs = hs
        self._clock = hs.get_clock()
        self.database_engine = database.engine

        self._presence_id_gen = StreamIdGenerator(
            db_conn, "presence_stream", "stream_id"
        )
        self._device_inbox_id_gen = StreamIdGenerator(
            db_conn, "device_inbox", "stream_id"
        )
        self._public_room_id_gen = StreamIdGenerator(
            db_conn, "public_room_list_stream", "stream_id"
        )
        self._device_list_id_gen = StreamIdGenerator(
            db_conn,
            "device_lists_stream",
            "stream_id",
            extra_tables=[
                ("user_signature_stream", "stream_id"),
                ("device_lists_outbound_pokes", "stream_id"),
            ],
        )
        self._cross_signing_id_gen = StreamIdGenerator(
            db_conn, "e2e_cross_signing_keys", "stream_id"
        )

        self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
        self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
        self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
        self._group_updates_id_gen = StreamIdGenerator(
            db_conn, "local_group_updates", "stream_id"
        )

        if isinstance(self.database_engine, PostgresEngine):
            # We set the `writers` to an empty list here as we don't care about
            # missing updates over restarts, as we'll not have anything in our
            # caches to invalidate. (This reduces the amount of writes to the DB
            # that happen).
            self._cache_id_gen = MultiWriterIdGenerator(
                db_conn,
                database,
                stream_name="caches",
                instance_name=hs.get_instance_name(),
                table="cache_invalidation_stream_by_instance",
                instance_column="instance_name",
                id_column="stream_id",
                sequence_name="cache_invalidation_stream_seq",
                writers=[],
            )
        else:
            self._cache_id_gen = None

        super().__init__(database, db_conn, hs)

        self._presence_on_startup = self._get_active_presence(db_conn)

        presence_cache_prefill, min_presence_val = self.db_pool.get_cache_dict(
            db_conn,
            "presence_stream",
            entity_column="user_id",
            stream_column="stream_id",
            max_value=self._presence_id_gen.get_current_token(),
        )
        self.presence_stream_cache = StreamChangeCache(
            "PresenceStreamChangeCache",
            min_presence_val,
            prefilled_cache=presence_cache_prefill,
        )

        max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
        device_inbox_prefill, min_device_inbox_id = self.db_pool.get_cache_dict(
            db_conn,
            "device_inbox",
            entity_column="user_id",
            stream_column="stream_id",
            max_value=max_device_inbox_id,
            limit=1000,
        )
        self._device_inbox_stream_cache = StreamChangeCache(
            "DeviceInboxStreamChangeCache",
            min_device_inbox_id,
            prefilled_cache=device_inbox_prefill,
        )
        # The federation outbox and the local device inbox use the same
        # stream_id generator.
        device_outbox_prefill, min_device_outbox_id = self.db_pool.get_cache_dict(
            db_conn,
            "device_federation_outbox",
            entity_column="destination",
            stream_column="stream_id",
            max_value=max_device_inbox_id,
            limit=1000,
        )
        self._device_federation_outbox_stream_cache = StreamChangeCache(
            "DeviceFederationOutboxStreamChangeCache",
            min_device_outbox_id,
            prefilled_cache=device_outbox_prefill,
        )

        device_list_max = self._device_list_id_gen.get_current_token()
        self._device_list_stream_cache = StreamChangeCache(
            "DeviceListStreamChangeCache", device_list_max
        )
        self._user_signature_stream_cache = StreamChangeCache(
            "UserSignatureStreamChangeCache", device_list_max
        )
        self._device_list_federation_stream_cache = StreamChangeCache(
            "DeviceListFederationStreamChangeCache", device_list_max
        )

        events_max = self._stream_id_gen.get_current_token()
        curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict(
            db_conn,
            "current_state_delta_stream",
            entity_column="room_id",
            stream_column="stream_id",
            max_value=events_max,  # As we share the stream id with events token
            limit=1000,
        )
        self._curr_state_delta_stream_cache = StreamChangeCache(
            "_curr_state_delta_stream_cache",
            min_curr_state_delta_id,
            prefilled_cache=curr_state_delta_prefill,
        )

        _group_updates_prefill, min_group_updates_id = self.db_pool.get_cache_dict(
            db_conn,
            "local_group_updates",
            entity_column="user_id",
            stream_column="stream_id",
            max_value=self._group_updates_id_gen.get_current_token(),
            limit=1000,
        )
        self._group_updates_stream_cache = StreamChangeCache(
            "_group_updates_stream_cache",
            min_group_updates_id,
            prefilled_cache=_group_updates_prefill,
        )

        self._stream_order_on_start = self.get_room_max_stream_ordering()
        self._min_stream_order_on_start = self.get_room_min_stream_ordering()

    def get_device_stream_token(self) -> int:
        return self._device_list_id_gen.get_current_token()

    def take_presence_startup_info(self):
        active_on_startup = self._presence_on_startup
        self._presence_on_startup = None
        return active_on_startup

    def _get_active_presence(self, db_conn):
        """Fetch non-offline presence from the database so that we can register
        the appropriate timeouts.
        """

        sql = (
            "SELECT user_id, state, last_active_ts, last_federation_update_ts,"
            " last_user_sync_ts, status_msg, currently_active FROM presence_stream"
            " WHERE state != ?"
        )

        txn = db_conn.cursor()
        txn.execute(sql, (PresenceState.OFFLINE,))
        rows = self.db_pool.cursor_to_dict(txn)
        txn.close()

        for row in rows:
            row["currently_active"] = bool(row["currently_active"])

        return [UserPresenceState(**row) for row in rows]

    async def get_users(self) -> List[Dict[str, Any]]:
        """Function to retrieve a list of users in users table.

        Returns:
            A list of dictionaries representing users.
        """
        return await self.db_pool.simple_select_list(
            table="users",
            keyvalues={},
            retcols=[
                "name",
                "password_hash",
                "is_guest",
                "admin",
                "user_type",
                "deactivated",
            ],
            desc="get_users",
        )

    async def get_users_paginate(
        self,
        start: int,
        limit: int,
        user_id: Optional[str] = None,
        name: Optional[str] = None,
        guests: bool = True,
        deactivated: bool = False,
    ) -> Tuple[List[Dict[str, Any]], int]:
        """Function to retrieve a paginated list of users from
        users list. This will return a json list of users and the
        total number of users matching the filter criteria.

        Args:
            start: start number to begin the query from
            limit: number of rows to retrieve
            user_id: search for user_id. ignored if name is not None
            name: search for local part of user_id or display name
            guests: whether to include guest users
            deactivated: whether to include deactivated users
        Returns:
            A tuple of a list of mappings from user to information and a count of total users.
        """

        def get_users_paginate_txn(txn):
            filters = []
            args = [self.hs.config.server_name]

            # `name` is in database already in lower case
            if name:
                filters.append("(name LIKE ? OR LOWER(displayname) LIKE ?)")
                args.extend(["@%" + name.lower() + "%:%", "%" + name.lower() + "%"])
            elif user_id:
                filters.append("name LIKE ?")
                args.extend(["%" + user_id.lower() + "%"])

            if not guests:
                filters.append("is_guest = 0")

            if not deactivated:
                filters.append("deactivated = 0")

            where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else ""

            sql_base = """
                FROM users as u
                LEFT JOIN profiles AS p ON u.name = '@' || p.user_id || ':' || ?
                {}
                """.format(
                where_clause
            )
            sql = "SELECT COUNT(*) as total_users " + sql_base
            txn.execute(sql, args)
            count = txn.fetchone()[0]

            sql = (
                "SELECT name, user_type, is_guest, admin, deactivated, displayname, avatar_url "
                + sql_base
                + " ORDER BY u.name LIMIT ? OFFSET ?"
            )
            args += [limit, start]
            txn.execute(sql, args)
            users = self.db_pool.cursor_to_dict(txn)
            return users, count

        return await self.db_pool.runInteraction(
            "get_users_paginate_txn", get_users_paginate_txn
        )

    async def search_users(self, term: str) -> Optional[List[Dict[str, Any]]]:
        """Function to search users list for one or more users with
        the matched term.

        Args:
            term: search term

        Returns:
            A list of dictionaries or None.
        """
        return await self.db_pool.simple_search_list(
            table="users",
            term=term,
            col="name",
            retcols=["name", "password_hash", "is_guest", "admin", "user_type"],
            desc="search_users",
        )
Example 14
    def __init__(self, database: DatabasePool, db_conn, hs):
        self.hs = hs
        self._clock = hs.get_clock()
        self.database_engine = database.engine

        self._presence_id_gen = StreamIdGenerator(
            db_conn, "presence_stream", "stream_id"
        )
        self._device_inbox_id_gen = StreamIdGenerator(
            db_conn, "device_inbox", "stream_id"
        )
        self._public_room_id_gen = StreamIdGenerator(
            db_conn, "public_room_list_stream", "stream_id"
        )
        self._device_list_id_gen = StreamIdGenerator(
            db_conn,
            "device_lists_stream",
            "stream_id",
            extra_tables=[
                ("user_signature_stream", "stream_id"),
                ("device_lists_outbound_pokes", "stream_id"),
            ],
        )
        self._cross_signing_id_gen = StreamIdGenerator(
            db_conn, "e2e_cross_signing_keys", "stream_id"
        )

        self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
        self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
        self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
        self._group_updates_id_gen = StreamIdGenerator(
            db_conn, "local_group_updates", "stream_id"
        )

        if isinstance(self.database_engine, PostgresEngine):
            # We set the `writers` to an empty list here as we don't care about
            # missing updates over restarts, as we'll not have anything in our
            # caches to invalidate. (This reduces the amount of writes to the DB
            # that happen).
            self._cache_id_gen = MultiWriterIdGenerator(
                db_conn,
                database,
                stream_name="caches",
                instance_name=hs.get_instance_name(),
                table="cache_invalidation_stream_by_instance",
                instance_column="instance_name",
                id_column="stream_id",
                sequence_name="cache_invalidation_stream_seq",
                writers=[],
            )
        else:
            self._cache_id_gen = None

        super().__init__(database, db_conn, hs)

        self._presence_on_startup = self._get_active_presence(db_conn)

        presence_cache_prefill, min_presence_val = self.db_pool.get_cache_dict(
            db_conn,
            "presence_stream",
            entity_column="user_id",
            stream_column="stream_id",
            max_value=self._presence_id_gen.get_current_token(),
        )
        self.presence_stream_cache = StreamChangeCache(
            "PresenceStreamChangeCache",
            min_presence_val,
            prefilled_cache=presence_cache_prefill,
        )

        max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
        device_inbox_prefill, min_device_inbox_id = self.db_pool.get_cache_dict(
            db_conn,
            "device_inbox",
            entity_column="user_id",
            stream_column="stream_id",
            max_value=max_device_inbox_id,
            limit=1000,
        )
        self._device_inbox_stream_cache = StreamChangeCache(
            "DeviceInboxStreamChangeCache",
            min_device_inbox_id,
            prefilled_cache=device_inbox_prefill,
        )
        # The federation outbox and the local device inbox use the same
        # stream_id generator.
        device_outbox_prefill, min_device_outbox_id = self.db_pool.get_cache_dict(
            db_conn,
            "device_federation_outbox",
            entity_column="destination",
            stream_column="stream_id",
            max_value=max_device_inbox_id,
            limit=1000,
        )
        self._device_federation_outbox_stream_cache = StreamChangeCache(
            "DeviceFederationOutboxStreamChangeCache",
            min_device_outbox_id,
            prefilled_cache=device_outbox_prefill,
        )

        device_list_max = self._device_list_id_gen.get_current_token()
        self._device_list_stream_cache = StreamChangeCache(
            "DeviceListStreamChangeCache", device_list_max
        )
        self._user_signature_stream_cache = StreamChangeCache(
            "UserSignatureStreamChangeCache", device_list_max
        )
        self._device_list_federation_stream_cache = StreamChangeCache(
            "DeviceListFederationStreamChangeCache", device_list_max
        )

        events_max = self._stream_id_gen.get_current_token()
        curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict(
            db_conn,
            "current_state_delta_stream",
            entity_column="room_id",
            stream_column="stream_id",
            max_value=events_max,  # As we share the stream id with events token
            limit=1000,
        )
        self._curr_state_delta_stream_cache = StreamChangeCache(
            "_curr_state_delta_stream_cache",
            min_curr_state_delta_id,
            prefilled_cache=curr_state_delta_prefill,
        )

        _group_updates_prefill, min_group_updates_id = self.db_pool.get_cache_dict(
            db_conn,
            "local_group_updates",
            entity_column="user_id",
            stream_column="stream_id",
            max_value=self._group_updates_id_gen.get_current_token(),
            limit=1000,
        )
        self._group_updates_stream_cache = StreamChangeCache(
            "_group_updates_stream_cache",
            min_group_updates_id,
            prefilled_cache=_group_updates_prefill,
        )

        self._stream_order_on_start = self.get_room_max_stream_ordering()
        self._min_stream_order_on_start = self.get_room_min_stream_ordering()
Example 15
class DataStore(
        EventsBackgroundUpdatesStore,
        RoomMemberStore,
        RoomStore,
        RegistrationStore,
        StreamStore,
        ProfileStore,
        PresenceStore,
        TransactionStore,
        DirectoryStore,
        KeyStore,
        StateStore,
        SignatureStore,
        ApplicationServiceStore,
        EventsStore,
        EventFederationStore,
        MediaRepositoryStore,
        RejectionsStore,
        FilteringStore,
        PusherStore,
        PushRuleStore,
        ApplicationServiceTransactionStore,
        ReceiptsStore,
        EndToEndKeyStore,
        EndToEndRoomKeyStore,
        SearchStore,
        TagsStore,
        AccountDataStore,
        EventPushActionsStore,
        OpenIdStore,
        ClientIpStore,
        DeviceStore,
        DeviceInboxStore,
        UserDirectoryStore,
        GroupServerStore,
        UserErasureStore,
        MonthlyActiveUsersStore,
        StatsStore,
        RelationsStore,
        CacheInvalidationStore,
        UIAuthStore,
):
    def __init__(self, database: Database, db_conn, hs):
        self.hs = hs
        self._clock = hs.get_clock()
        self.database_engine = database.engine

        self._stream_id_gen = StreamIdGenerator(
            db_conn,
            "events",
            "stream_ordering",
            extra_tables=[("local_invites", "stream_id")],
        )
        self._backfill_id_gen = StreamIdGenerator(
            db_conn,
            "events",
            "stream_ordering",
            step=-1,
            extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
        )
        self._presence_id_gen = StreamIdGenerator(db_conn, "presence_stream",
                                                  "stream_id")
        self._device_inbox_id_gen = StreamIdGenerator(db_conn,
                                                      "device_max_stream_id",
                                                      "stream_id")
        self._public_room_id_gen = StreamIdGenerator(
            db_conn, "public_room_list_stream", "stream_id")
        self._device_list_id_gen = StreamIdGenerator(
            db_conn,
            "device_lists_stream",
            "stream_id",
            extra_tables=[
                ("user_signature_stream", "stream_id"),
                ("device_lists_outbound_pokes", "stream_id"),
            ],
        )
        self._cross_signing_id_gen = StreamIdGenerator(
            db_conn, "e2e_cross_signing_keys", "stream_id")

        self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens",
                                                 "id")
        self._event_reports_id_gen = IdGenerator(db_conn, "event_reports",
                                                 "id")
        self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
        self._push_rules_enable_id_gen = IdGenerator(db_conn,
                                                     "push_rules_enable", "id")
        self._push_rules_stream_id_gen = ChainedIdGenerator(
            self._stream_id_gen, db_conn, "push_rules_stream", "stream_id")
        self._pushers_id_gen = StreamIdGenerator(db_conn,
                                                 "pushers",
                                                 "id",
                                                 extra_tables=[
                                                     ("deleted_pushers",
                                                      "stream_id")
                                                 ])
        self._group_updates_id_gen = StreamIdGenerator(db_conn,
                                                       "local_group_updates",
                                                       "stream_id")

        if isinstance(self.database_engine, PostgresEngine):
            self._cache_id_gen = StreamIdGenerator(
                db_conn, "cache_invalidation_stream", "stream_id")
        else:
            self._cache_id_gen = None

        super(DataStore, self).__init__(database, db_conn, hs)

        self._presence_on_startup = self._get_active_presence(db_conn)

        presence_cache_prefill, min_presence_val = self.db.get_cache_dict(
            db_conn,
            "presence_stream",
            entity_column="user_id",
            stream_column="stream_id",
            max_value=self._presence_id_gen.get_current_token(),
        )
        self.presence_stream_cache = StreamChangeCache(
            "PresenceStreamChangeCache",
            min_presence_val,
            prefilled_cache=presence_cache_prefill,
        )

        max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
        device_inbox_prefill, min_device_inbox_id = self.db.get_cache_dict(
            db_conn,
            "device_inbox",
            entity_column="user_id",
            stream_column="stream_id",
            max_value=max_device_inbox_id,
            limit=1000,
        )
        self._device_inbox_stream_cache = StreamChangeCache(
            "DeviceInboxStreamChangeCache",
            min_device_inbox_id,
            prefilled_cache=device_inbox_prefill,
        )
        # The federation outbox and the local device inbox use the same
        # stream_id generator.
        device_outbox_prefill, min_device_outbox_id = self.db.get_cache_dict(
            db_conn,
            "device_federation_outbox",
            entity_column="destination",
            stream_column="stream_id",
            max_value=max_device_inbox_id,
            limit=1000,
        )
        self._device_federation_outbox_stream_cache = StreamChangeCache(
            "DeviceFederationOutboxStreamChangeCache",
            min_device_outbox_id,
            prefilled_cache=device_outbox_prefill,
        )

        device_list_max = self._device_list_id_gen.get_current_token()
        self._device_list_stream_cache = StreamChangeCache(
            "DeviceListStreamChangeCache", device_list_max)
        self._user_signature_stream_cache = StreamChangeCache(
            "UserSignatureStreamChangeCache", device_list_max)
        self._device_list_federation_stream_cache = StreamChangeCache(
            "DeviceListFederationStreamChangeCache", device_list_max)

        events_max = self._stream_id_gen.get_current_token()
        curr_state_delta_prefill, min_curr_state_delta_id = self.db.get_cache_dict(
            db_conn,
            "current_state_delta_stream",
            entity_column="room_id",
            stream_column="stream_id",
            max_value=events_max,  # As we share the stream id with events token
            limit=1000,
        )
        self._curr_state_delta_stream_cache = StreamChangeCache(
            "_curr_state_delta_stream_cache",
            min_curr_state_delta_id,
            prefilled_cache=curr_state_delta_prefill,
        )

        _group_updates_prefill, min_group_updates_id = self.db.get_cache_dict(
            db_conn,
            "local_group_updates",
            entity_column="user_id",
            stream_column="stream_id",
            max_value=self._group_updates_id_gen.get_current_token(),
            limit=1000,
        )
        self._group_updates_stream_cache = StreamChangeCache(
            "_group_updates_stream_cache",
            min_group_updates_id,
            prefilled_cache=_group_updates_prefill,
        )

        self._stream_order_on_start = self.get_room_max_stream_ordering()
        self._min_stream_order_on_start = self.get_room_min_stream_ordering()

        # Used in _generate_user_daily_visits to keep track of progress
        self._last_user_visit_update = self._get_start_of_day()

    def take_presence_startup_info(self):
        active_on_startup = self._presence_on_startup
        self._presence_on_startup = None
        return active_on_startup

    def _get_active_presence(self, db_conn):
        """Fetch non-offline presence from the database so that we can register
        the appropriate timeouts.
        """

        sql = (
            "SELECT user_id, state, last_active_ts, last_federation_update_ts,"
            " last_user_sync_ts, status_msg, currently_active FROM presence_stream"
            " WHERE state != ?")
        sql = self.database_engine.convert_param_style(sql)

        txn = db_conn.cursor()
        txn.execute(sql, (PresenceState.OFFLINE, ))
        rows = self.db.cursor_to_dict(txn)
        txn.close()

        for row in rows:
            row["currently_active"] = bool(row["currently_active"])

        return [UserPresenceState(**row) for row in rows]

    def count_daily_users(self):
        """
        Counts the number of users who used this homeserver in the last 24 hours.
        """
        yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24)
        return self.db.runInteraction("count_daily_users", self._count_users,
                                      yesterday)

    def count_monthly_users(self):
        """
        Counts the number of users who used this homeserver in the last 30 days.
        Note this method is intended for phonehome metrics only and is different
        from the mau figure in synapse.storage.monthly_active_users which,
        amongst other things, includes a 3 day grace period before a user counts.
        """
        thirty_days_ago = int(
            self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
        return self.db.runInteraction("count_monthly_users", self._count_users,
                                      thirty_days_ago)

    def _count_users(self, txn, time_from):
        """
        Returns number of users seen in the past time_from period
        """
        sql = """
            SELECT COALESCE(count(*), 0) FROM (
                SELECT user_id FROM user_ips
                WHERE last_seen > ?
                GROUP BY user_id
            ) u
        """
        txn.execute(sql, (time_from, ))
        (count, ) = txn.fetchone()
        return count

    def count_r30_users(self):
        """
        Counts the number of 30 day retained users, defined as:-
         * Users who have created their accounts more than 30 days ago
         * Where last seen at most 30 days ago
         * Where account creation and last_seen are > 30 days apart

         Returns counts globally for a given user as well as broken down
         by platform
        """
        def _count_r30_users(txn):
            thirty_days_in_secs = 86400 * 30
            now = int(self._clock.time())
            thirty_days_ago_in_secs = now - thirty_days_in_secs

            sql = """
                SELECT platform, COALESCE(count(*), 0) FROM (
                     SELECT
                        users.name, platform, users.creation_ts * 1000,
                        MAX(uip.last_seen)
                     FROM users
                     INNER JOIN (
                         SELECT
                         user_id,
                         last_seen,
                         CASE
                             WHEN user_agent LIKE '%%Android%%' THEN 'android'
                             WHEN user_agent LIKE '%%iOS%%' THEN 'ios'
                             WHEN user_agent LIKE '%%Electron%%' THEN 'electron'
                             WHEN user_agent LIKE '%%Mozilla%%' THEN 'web'
                             WHEN user_agent LIKE '%%Gecko%%' THEN 'web'
                             ELSE 'unknown'
                         END
                         AS platform
                         FROM user_ips
                     ) uip
                     ON users.name = uip.user_id
                     AND users.appservice_id is NULL
                     AND users.creation_ts < ?
                     AND uip.last_seen/1000 > ?
                     AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30
                     GROUP BY users.name, platform, users.creation_ts
                ) u GROUP BY platform
            """

            results = {}
            txn.execute(sql,
                        (thirty_days_ago_in_secs, thirty_days_ago_in_secs))

            for row in txn:
                if row[0] == "unknown":
                    pass
                results[row[0]] = row[1]

            sql = """
                SELECT COALESCE(count(*), 0) FROM (
                    SELECT users.name, users.creation_ts * 1000,
                                                        MAX(uip.last_seen)
                    FROM users
                    INNER JOIN (
                        SELECT
                        user_id,
                        last_seen
                        FROM user_ips
                    ) uip
                    ON users.name = uip.user_id
                    AND appservice_id is NULL
                    AND users.creation_ts < ?
                    AND uip.last_seen/1000 > ?
                    AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30
                    GROUP BY users.name, users.creation_ts
                ) u
            """

            txn.execute(sql,
                        (thirty_days_ago_in_secs, thirty_days_ago_in_secs))

            (count, ) = txn.fetchone()
            results["all"] = count

            return results

        return self.db.runInteraction("count_r30_users", _count_r30_users)

    def _get_start_of_day(self):
        """
        Returns millisecond unixtime for start of UTC day.
        """
        now = time.gmtime()
        today_start = calendar.timegm(
            (now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0))
        return today_start * 1000

    def generate_user_daily_visits(self):
        """
        Generates daily visit data for use in cohort/retention analysis
        """
        def _generate_user_daily_visits(txn):
            logger.info("Calling _generate_user_daily_visits")
            today_start = self._get_start_of_day()
            a_day_in_milliseconds = 24 * 60 * 60 * 1000
            now = self.clock.time_msec()

            sql = """
                INSERT INTO user_daily_visits (user_id, device_id, timestamp)
                    SELECT u.user_id, u.device_id, ?
                    FROM user_ips AS u
                    LEFT JOIN (
                      SELECT user_id, device_id, timestamp FROM user_daily_visits
                      WHERE timestamp = ?
                    ) udv
                    ON u.user_id = udv.user_id AND u.device_id=udv.device_id
                    INNER JOIN users ON users.name=u.user_id
                    WHERE last_seen > ? AND last_seen <= ?
                    AND udv.timestamp IS NULL AND users.is_guest=0
                    AND users.appservice_id IS NULL
                    GROUP BY u.user_id, u.device_id
            """

            # This means that the day has rolled over but there could still
            # be entries from the previous day. There is an edge case
            # where if the user logs in at 23:59 and overwrites their
            # last_seen at 00:01 then they will not be counted in the
            # previous day's stats - it is important that the query is run
            # often to minimise this case.
            if today_start > self._last_user_visit_update:
                yesterday_start = today_start - a_day_in_milliseconds
                txn.execute(
                    sql,
                    (
                        yesterday_start,
                        yesterday_start,
                        self._last_user_visit_update,
                        today_start,
                    ),
                )
                self._last_user_visit_update = today_start

            txn.execute(
                sql,
                (today_start, today_start, self._last_user_visit_update, now))
            # Update _last_user_visit_update to now. The reason to do this
            # rather than just clamping to the beginning of the day is to limit
            # the size of the join - meaning that the query can be run more
            # frequently
            self._last_user_visit_update = now

        return self.db.runInteraction("generate_user_daily_visits",
                                      _generate_user_daily_visits)

    def get_users(self):
        """Function to retrieve a list of users in users table.

        Args:
        Returns:
            defer.Deferred: resolves to list[dict[str, Any]]
        """
        return self.db.simple_select_list(
            table="users",
            keyvalues={},
            retcols=[
                "name",
                "password_hash",
                "is_guest",
                "admin",
                "user_type",
                "deactivated",
            ],
            desc="get_users",
        )

    def get_users_paginate(self,
                           start,
                           limit,
                           name=None,
                           guests=True,
                           deactivated=False):
        """Function to retrieve a paginated list of users from
        users list. This will return a json list of users and the
        total number of users matching the filter criteria.

        Args:
            start (int): start number to begin the query from
            limit (int): number of rows to retrieve
            name (string): filter for user names
            guests (bool): whether to include guest users
            deactivated (bool): whether to include deactivated users
        Returns:
            defer.Deferred: resolves to list[dict[str, Any]], int
        """
        def get_users_paginate_txn(txn):
            filters = []
            args = []

            if name:
                filters.append("name LIKE ?")
                args.append("%" + name + "%")

            if not guests:
                filters.append("is_guest = 0")

            if not deactivated:
                filters.append("deactivated = 0")

            where_clause = "WHERE " + " AND ".join(filters) if len(
                filters) > 0 else ""

            sql = "SELECT COUNT(*) as total_users FROM users %s" % (
                where_clause)
            txn.execute(sql, args)
            count = txn.fetchone()[0]

            args = [self.hs.config.server_name] + args + [limit, start]
            sql = """
                SELECT name, user_type, is_guest, admin, deactivated, displayname, avatar_url
                FROM users as u
                LEFT JOIN profiles AS p ON u.name = '@' || p.user_id || ':' || ?
                {}
                ORDER BY u.name LIMIT ? OFFSET ?
                """.format(where_clause)
            txn.execute(sql, args)
            users = self.db.cursor_to_dict(txn)
            return users, count

        return self.db.runInteraction("get_users_paginate_txn",
                                      get_users_paginate_txn)

    def search_users(self, term):
        """Function to search users list for one or more users with
        the matched term.

        Args:
            term (str): search term
        Returns:
            defer.Deferred: resolves to list[dict[str, Any]]
        """
        return self.db.simple_search_list(
            table="users",
            term=term,
            col="name",
            retcols=[
                "name", "password_hash", "is_guest", "admin", "user_type"
            ],
            desc="search_users",
        )
Example 16
    def __init__(self, database: Database, db_conn, hs):
        self.hs = hs
        self._clock = hs.get_clock()
        self.database_engine = database.engine

        self._stream_id_gen = StreamIdGenerator(
            db_conn,
            "events",
            "stream_ordering",
            extra_tables=[("local_invites", "stream_id")],
        )
        self._backfill_id_gen = StreamIdGenerator(
            db_conn,
            "events",
            "stream_ordering",
            step=-1,
            extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
        )
        self._presence_id_gen = StreamIdGenerator(db_conn, "presence_stream",
                                                  "stream_id")
        self._device_inbox_id_gen = StreamIdGenerator(db_conn,
                                                      "device_max_stream_id",
                                                      "stream_id")
        self._public_room_id_gen = StreamIdGenerator(
            db_conn, "public_room_list_stream", "stream_id")
        self._device_list_id_gen = StreamIdGenerator(
            db_conn,
            "device_lists_stream",
            "stream_id",
            extra_tables=[
                ("user_signature_stream", "stream_id"),
                ("device_lists_outbound_pokes", "stream_id"),
            ],
        )
        self._cross_signing_id_gen = StreamIdGenerator(
            db_conn, "e2e_cross_signing_keys", "stream_id")

        self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens",
                                                 "id")
        self._event_reports_id_gen = IdGenerator(db_conn, "event_reports",
                                                 "id")
        self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
        self._push_rules_enable_id_gen = IdGenerator(db_conn,
                                                     "push_rules_enable", "id")
        self._push_rules_stream_id_gen = ChainedIdGenerator(
            self._stream_id_gen, db_conn, "push_rules_stream", "stream_id")
        self._pushers_id_gen = StreamIdGenerator(db_conn,
                                                 "pushers",
                                                 "id",
                                                 extra_tables=[
                                                     ("deleted_pushers",
                                                      "stream_id")
                                                 ])
        self._group_updates_id_gen = StreamIdGenerator(db_conn,
                                                       "local_group_updates",
                                                       "stream_id")

        if isinstance(self.database_engine, PostgresEngine):
            self._cache_id_gen = StreamIdGenerator(
                db_conn, "cache_invalidation_stream", "stream_id")
        else:
            self._cache_id_gen = None

        super(DataStore, self).__init__(database, db_conn, hs)

        self._presence_on_startup = self._get_active_presence(db_conn)

        presence_cache_prefill, min_presence_val = self.db.get_cache_dict(
            db_conn,
            "presence_stream",
            entity_column="user_id",
            stream_column="stream_id",
            max_value=self._presence_id_gen.get_current_token(),
        )
        self.presence_stream_cache = StreamChangeCache(
            "PresenceStreamChangeCache",
            min_presence_val,
            prefilled_cache=presence_cache_prefill,
        )

        max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
        device_inbox_prefill, min_device_inbox_id = self.db.get_cache_dict(
            db_conn,
            "device_inbox",
            entity_column="user_id",
            stream_column="stream_id",
            max_value=max_device_inbox_id,
            limit=1000,
        )
        self._device_inbox_stream_cache = StreamChangeCache(
            "DeviceInboxStreamChangeCache",
            min_device_inbox_id,
            prefilled_cache=device_inbox_prefill,
        )
        # The federation outbox and the local device inbox use the same
        # stream_id generator.
        device_outbox_prefill, min_device_outbox_id = self.db.get_cache_dict(
            db_conn,
            "device_federation_outbox",
            entity_column="destination",
            stream_column="stream_id",
            max_value=max_device_inbox_id,
            limit=1000,
        )
        self._device_federation_outbox_stream_cache = StreamChangeCache(
            "DeviceFederationOutboxStreamChangeCache",
            min_device_outbox_id,
            prefilled_cache=device_outbox_prefill,
        )

        device_list_max = self._device_list_id_gen.get_current_token()
        self._device_list_stream_cache = StreamChangeCache(
            "DeviceListStreamChangeCache", device_list_max)
        self._user_signature_stream_cache = StreamChangeCache(
            "UserSignatureStreamChangeCache", device_list_max)
        self._device_list_federation_stream_cache = StreamChangeCache(
            "DeviceListFederationStreamChangeCache", device_list_max)

        events_max = self._stream_id_gen.get_current_token()
        curr_state_delta_prefill, min_curr_state_delta_id = self.db.get_cache_dict(
            db_conn,
            "current_state_delta_stream",
            entity_column="room_id",
            stream_column="stream_id",
            max_value=events_max,  # current_state_delta_stream shares its stream ids with the events stream
            limit=1000,
        )
        self._curr_state_delta_stream_cache = StreamChangeCache(
            "_curr_state_delta_stream_cache",
            min_curr_state_delta_id,
            prefilled_cache=curr_state_delta_prefill,
        )

        _group_updates_prefill, min_group_updates_id = self.db.get_cache_dict(
            db_conn,
            "local_group_updates",
            entity_column="user_id",
            stream_column="stream_id",
            max_value=self._group_updates_id_gen.get_current_token(),
            limit=1000,
        )
        self._group_updates_stream_cache = StreamChangeCache(
            "_group_updates_stream_cache",
            min_group_updates_id,
            prefilled_cache=_group_updates_prefill,
        )

        self._stream_order_on_start = self.get_room_max_stream_ordering()
        self._min_stream_order_on_start = self.get_room_min_stream_ordering()

        # Used in _generate_user_daily_visits to keep track of progress
        self._last_user_visit_update = self._get_start_of_day()
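
A quick illustration of the pattern these prefilled caches support. This is not part of the original store; it is a minimal sketch that assumes only the StreamChangeCache behaviour already visible in these examples (a constructor taking a name, a minimum stream position and a prefilled_cache dict, plus entity_has_changed / has_entity_changed) and the usual Synapse import path:

from synapse.util.caches.stream_change_cache import StreamChangeCache

# Earliest position the cache knows about is 3; "@alice:test" last changed at 5.
cache = StreamChangeCache(
    "ExampleStreamChangeCache",
    3,
    prefilled_cache={"@alice:test": 5},
)

# "Has alice possibly changed since position 4?" -> yes, she changed at 5.
assert cache.has_entity_changed("@alice:test", 4)
# "Since position 5?" -> no, 5 is her latest known change.
assert not cache.has_entity_changed("@alice:test", 5)

# Recording a later change moves the answer forward.
cache.entity_has_changed("@alice:test", 7)
assert cache.has_entity_changed("@alice:test", 6)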
Ejemplo n.º 17
0
class DeviceInboxWorkerStore(SQLBaseStore):
    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        super().__init__(database, db_conn, hs)

        self._instance_name = hs.get_instance_name()

        # Map of (user_id, device_id) to the last stream_id that has been
        # deleted up to. This is so that we can no-op deletions.
        self._last_device_delete_cache: ExpiringCache[Tuple[
            str, Optional[str]], int] = ExpiringCache(
                cache_name="last_device_delete_cache",
                clock=self._clock,
                max_len=10000,
                expiry_ms=30 * 60 * 1000,
            )

        if isinstance(database.engine, PostgresEngine):
            self._can_write_to_device = (self._instance_name
                                         in hs.config.worker.writers.to_device)

            self._device_inbox_id_gen: AbstractStreamIdGenerator = (
                MultiWriterIdGenerator(
                    db_conn=db_conn,
                    db=database,
                    stream_name="to_device",
                    instance_name=self._instance_name,
                    tables=[("device_inbox", "instance_name", "stream_id")],
                    sequence_name="device_inbox_sequence",
                    writers=hs.config.worker.writers.to_device,
                ))
        else:
            self._can_write_to_device = True
            self._device_inbox_id_gen = StreamIdGenerator(
                db_conn, "device_inbox", "stream_id")

        max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
        device_inbox_prefill, min_device_inbox_id = self.db_pool.get_cache_dict(
            db_conn,
            "device_inbox",
            entity_column="user_id",
            stream_column="stream_id",
            max_value=max_device_inbox_id,
            limit=1000,
        )
        self._device_inbox_stream_cache = StreamChangeCache(
            "DeviceInboxStreamChangeCache",
            min_device_inbox_id,
            prefilled_cache=device_inbox_prefill,
        )

        # The federation outbox and the local device inbox use the same
        # stream_id generator.
        device_outbox_prefill, min_device_outbox_id = self.db_pool.get_cache_dict(
            db_conn,
            "device_federation_outbox",
            entity_column="destination",
            stream_column="stream_id",
            max_value=max_device_inbox_id,
            limit=1000,
        )
        self._device_federation_outbox_stream_cache = StreamChangeCache(
            "DeviceFederationOutboxStreamChangeCache",
            min_device_outbox_id,
            prefilled_cache=device_outbox_prefill,
        )
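
    # Added commentary (not part of the original class): on Postgres the
    # to-device stream may have multiple writers, so stream ids come from a
    # MultiWriterIdGenerator backed by device_inbox_sequence and this worker
    # may only write if it is listed in the configured to_device writers; on
    # SQLite there is a single process, so a plain StreamIdGenerator over
    # device_inbox is used and writing is always allowed.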

    def process_replication_rows(
        self,
        stream_name: str,
        instance_name: str,
        token: int,
        rows: Iterable[ToDeviceStream.ToDeviceStreamRow],
    ) -> None:
        if stream_name == ToDeviceStream.NAME:
            # If replication is happening then Postgres must be in use.
            assert isinstance(self._device_inbox_id_gen,
                              MultiWriterIdGenerator)
            self._device_inbox_id_gen.advance(instance_name, token)
            for row in rows:
                if row.entity.startswith("@"):
                    self._device_inbox_stream_cache.entity_has_changed(
                        row.entity, token)
                else:
                    self._device_federation_outbox_stream_cache.entity_has_changed(
                        row.entity, token)
        return super().process_replication_rows(stream_name, instance_name,
                                                token, rows)

    def get_to_device_stream_token(self) -> int:
        return self._device_inbox_id_gen.get_current_token()

    async def get_messages_for_user_devices(
        self,
        user_ids: Collection[str],
        from_stream_id: int,
        to_stream_id: int,
    ) -> Dict[Tuple[str, str], List[JsonDict]]:
        """
        Retrieve to-device messages for a given set of users.

        Only to-device messages with stream ids between the given boundaries
        (from < X <= to) are returned.

        Args:
            user_ids: The users to retrieve to-device messages for.
            from_stream_id: The lower boundary of stream id to filter with (exclusive).
            to_stream_id: The upper boundary of stream id to filter with (inclusive).

        Returns:
            A dictionary of (user id, device id) -> list of to-device messages.
        """
        # We expect the stream ID returned by _get_device_messages to always
        # be to_stream_id. So, no need to return it from this function.
        (
            user_id_device_id_to_messages,
            last_processed_stream_id,
        ) = await self._get_device_messages(
            user_ids=user_ids,
            from_stream_id=from_stream_id,
            to_stream_id=to_stream_id,
        )

        assert (
            last_processed_stream_id == to_stream_id
        ), "Expected _get_device_messages to process all to-device messages up to `to_stream_id`"

        return user_id_device_id_to_messages

    async def get_messages_for_device(
        self,
        user_id: str,
        device_id: str,
        from_stream_id: int,
        to_stream_id: int,
        limit: int = 100,
    ) -> Tuple[List[JsonDict], int]:
        """
        Retrieve to-device messages for a single user device.

        Only to-device messages with stream ids between the given boundaries
        (from < X <= to) are returned.

        Args:
            user_id: The ID of the user to retrieve messages for.
            device_id: The ID of the device to retrieve to-device messages for.
            from_stream_id: The lower boundary of stream id to filter with (exclusive).
            to_stream_id: The upper boundary of stream id to filter with (inclusive).
            limit: A limit on the number of to-device messages returned.

        Returns:
            A tuple containing:
                * A list of to-device messages within the given stream id range intended for
                  the given user / device combo.
                * The last-processed stream ID. Subsequent calls of this function with the
                  same device should pass this value as 'from_stream_id'.
        """
        (
            user_id_device_id_to_messages,
            last_processed_stream_id,
        ) = await self._get_device_messages(
            user_ids=[user_id],
            device_id=device_id,
            from_stream_id=from_stream_id,
            to_stream_id=to_stream_id,
            limit=limit,
        )

        if not user_id_device_id_to_messages:
            # There were no messages!
            return [], to_stream_id

        # Extract the messages, no need to return the user and device ID again
        to_device_messages = user_id_device_id_to_messages.get(
            (user_id, device_id), [])

        return to_device_messages, last_processed_stream_id
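
    # Illustrative paging loop (hypothetical caller, not part of this class):
    # feed the returned last-processed stream id back in as the next lower
    # bound until the upper bound is reached.
    #
    #     from_id = 0
    #     while True:
    #         msgs, from_id = await store.get_messages_for_device(
    #             user_id, device_id, from_id, to_id, limit=100
    #         )
    #         ...  # handle msgs
    #         if from_id == to_id:
    #             break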

    async def _get_device_messages(
        self,
        user_ids: Collection[str],
        from_stream_id: int,
        to_stream_id: int,
        device_id: Optional[str] = None,
        limit: Optional[int] = None,
    ) -> Tuple[Dict[Tuple[str, str], List[JsonDict]], int]:
        """
        Retrieve pending to-device messages for a collection of user devices.

        Only to-device messages with stream ids between the given boundaries
        (from < X <= to) are returned.

        Note that a stream ID can be shared by multiple copies of the same message with
        different recipient devices. Stream IDs are only unique in the context of a single
        user ID / device ID pair. Thus, applying a limit (of messages to return) when working
        with a sliding window of stream IDs is only possible when querying messages of a
        single user device.

        Finally, note that device IDs are not unique across users.

        Args:
            user_ids: The user IDs to filter device messages by.
            from_stream_id: The lower boundary of stream id to filter with (exclusive).
            to_stream_id: The upper boundary of stream id to filter with (inclusive).
            device_id: A device ID to query to-device messages for. If not provided, to-device
                messages from all device IDs for the given user IDs will be queried. May not be
                provided if `user_ids` contains more than one entry.
            limit: The maximum number of to-device messages to return. Can only be used when
                passing a single user ID / device ID tuple.

        Returns:
            A tuple containing:
                * A dict of (user_id, device_id) -> list of to-device messages
                * The last-processed stream ID. If this is less than `to_stream_id`, then
                    there may be more messages to retrieve. If `limit` is not set, then this
                    is always equal to 'to_stream_id'.
        """
        if not user_ids:
            logger.warning("No users provided upon querying for device IDs")
            return {}, to_stream_id

        # Prevent a query for one user's device also retrieving another user's device with
        # the same device ID (device IDs are not unique across users).
        if len(user_ids) > 1 and device_id is not None:
            raise AssertionError(
                "Programming error: 'device_id' cannot be supplied to "
                "_get_device_messages when >1 user_id has been provided")

        # A limit can only be applied when querying for a single user ID / device ID tuple.
        # See the docstring of this function for more details.
        if limit is not None and device_id is None:
            raise AssertionError(
                "Programming error: _get_device_messages was passed 'limit' "
                "without a specific user_id/device_id")

        user_ids_to_query: Set[str] = set()
        device_ids_to_query: Set[str] = set()

        # Note that a device ID could be an empty str
        if device_id is not None:
            # If a device ID was passed, use it to filter results.
            # Otherwise, device IDs will be derived from the given collection of user IDs.
            device_ids_to_query.add(device_id)

        # Determine which users have devices with pending messages
        for user_id in user_ids:
            if self._device_inbox_stream_cache.has_entity_changed(
                    user_id, from_stream_id):
                # This user has new messages sent to them. Query messages for them
                user_ids_to_query.add(user_id)

        if not user_ids_to_query:
            return {}, to_stream_id

        def get_device_messages_txn(
            txn: LoggingTransaction,
        ) -> Tuple[Dict[Tuple[str, str], List[JsonDict]], int]:
            # Build a query to select messages from any of the given devices that
            # are between the given stream id bounds.

            # If a list of device IDs was not provided, retrieve all device IDs
            # for the given users. We explicitly do not query hidden devices, as
            # hidden devices should not receive to-device messages.
            # Note that this is more efficient than just dropping `device_id` from the query,
            # since device_inbox has an index on `(user_id, device_id, stream_id)`
            if not device_ids_to_query:
                user_device_dicts = self.db_pool.simple_select_many_txn(
                    txn,
                    table="devices",
                    column="user_id",
                    iterable=user_ids_to_query,
                    # Filtering on user_id is done via `column`/`iterable`
                    # above; do not additionally constrain on a single user_id.
                    keyvalues={"hidden": False},
                    retcols=("device_id", ),
                )

                device_ids_to_query.update(
                    {row["device_id"]
                     for row in user_device_dicts})

            if not device_ids_to_query:
                # We've ended up with no devices to query.
                return {}, to_stream_id

            # We include both user IDs and device IDs in this query, as we have an index
            # (device_inbox_user_stream_id) for them.
            user_id_many_clause_sql, user_id_many_clause_args = make_in_list_sql_clause(
                self.database_engine, "user_id", user_ids_to_query)
            (
                device_id_many_clause_sql,
                device_id_many_clause_args,
            ) = make_in_list_sql_clause(self.database_engine, "device_id",
                                        device_ids_to_query)

            sql = f"""
                SELECT stream_id, user_id, device_id, message_json FROM device_inbox
                WHERE {user_id_many_clause_sql}
                AND {device_id_many_clause_sql}
                AND ? < stream_id AND stream_id <= ?
                ORDER BY stream_id ASC
            """
            sql_args = (
                *user_id_many_clause_args,
                *device_id_many_clause_args,
                from_stream_id,
                to_stream_id,
            )

            # If a limit was provided, limit the data retrieved from the database
            if limit is not None:
                sql += "LIMIT ?"
                sql_args += (limit, )

            txn.execute(sql, sql_args)

            # Create and fill a dictionary of (user ID, device ID) -> list of messages
            # intended for each device.
            last_processed_stream_pos = to_stream_id
            recipient_device_to_messages: Dict[Tuple[str, str],
                                               List[JsonDict]] = {}
            rowcount = 0
            for row in txn:
                rowcount += 1

                last_processed_stream_pos = row[0]
                recipient_user_id = row[1]
                recipient_device_id = row[2]
                message_dict = db_to_json(row[3])

                # Store the device details
                recipient_device_to_messages.setdefault(
                    (recipient_user_id, recipient_device_id),
                    []).append(message_dict)

            if limit is not None and rowcount == limit:
                # We ended up bumping up against the message limit. There may be more messages
                # to retrieve. Return what we have, as well as the last stream position that
                # was processed.
                #
                # The caller is expected to set this as the lower (exclusive) bound
                # for the next query of this device.
                return recipient_device_to_messages, last_processed_stream_pos

            # The limit was not reached, thus we know that recipient_device_to_messages
            # contains all to-device messages for the given device and stream id range.
            #
            # We return to_stream_id, which the caller should then provide as the lower
            # (exclusive) bound on the next query of this device.
            return recipient_device_to_messages, to_stream_id

        return await self.db_pool.runInteraction("get_device_messages",
                                                 get_device_messages_txn)

    @trace
    async def delete_messages_for_device(self, user_id: str,
                                         device_id: Optional[str],
                                         up_to_stream_id: int) -> int:
        """
        Args:
            user_id: The recipient user_id.
            device_id: The recipient device_id.
            up_to_stream_id: Where to delete messages up to.

        Returns:
            The number of messages deleted.
        """
        # If we have cached the last stream id we've deleted up to, we can
        # check if there is likely to be anything that needs deleting
        last_deleted_stream_id = self._last_device_delete_cache.get(
            (user_id, device_id), None)

        set_tag("last_deleted_stream_id", last_deleted_stream_id)

        if last_deleted_stream_id:
            has_changed = self._device_inbox_stream_cache.has_entity_changed(
                user_id, last_deleted_stream_id)
            if not has_changed:
                log_kv({"message": "No changes in cache since last check"})
                return 0

        def delete_messages_for_device_txn(txn: LoggingTransaction) -> int:
            sql = ("DELETE FROM device_inbox"
                   " WHERE user_id = ? AND device_id = ?"
                   " AND stream_id <= ?")
            txn.execute(sql, (user_id, device_id, up_to_stream_id))
            return txn.rowcount

        count = await self.db_pool.runInteraction(
            "delete_messages_for_device", delete_messages_for_device_txn)

        log_kv({
            "message": f"deleted {count} messages for device",
            "count": count
        })

        # Update the cache, ensuring that we only ever increase the value
        updated_last_deleted_stream_id = self._last_device_delete_cache.get(
            (user_id, device_id), 0)
        self._last_device_delete_cache[(user_id, device_id)] = max(
            updated_last_deleted_stream_id, up_to_stream_id)

        return count

    @trace
    async def get_new_device_msgs_for_remote(
            self, destination: str, last_stream_id: int,
            current_stream_id: int, limit: int) -> Tuple[List[JsonDict], int]:
        """
        Args:
            destination: The name of the remote server.
            last_stream_id: The last position of the device message stream
                that the server sent up to.
            current_stream_id: The current position of the device message stream.
        Returns:
            A list of messages for the device and where in the stream the messages got to.
        """

        set_tag("destination", destination)
        set_tag("last_stream_id", last_stream_id)
        set_tag("current_stream_id", current_stream_id)
        set_tag("limit", limit)

        has_changed = self._device_federation_outbox_stream_cache.has_entity_changed(
            destination, last_stream_id)
        if not has_changed or last_stream_id == current_stream_id:
            log_kv({"message": "No new messages in stream"})
            return [], current_stream_id

        if limit <= 0:
            # This can happen if we run out of room for EDUs in the transaction.
            return [], last_stream_id

        @trace
        def get_new_messages_for_remote_destination_txn(
            txn: LoggingTransaction, ) -> Tuple[List[JsonDict], int]:
            sql = (
                "SELECT stream_id, messages_json FROM device_federation_outbox"
                " WHERE destination = ?"
                " AND ? < stream_id AND stream_id <= ?"
                " ORDER BY stream_id ASC"
                " LIMIT ?")
            txn.execute(
                sql, (destination, last_stream_id, current_stream_id, limit))

            messages = []
            stream_pos = current_stream_id

            for row in txn:
                stream_pos = row[0]
                messages.append(db_to_json(row[1]))

            # If the limit was not reached we know that there's no more data for this
            # user/device pair up to current_stream_id.
            if len(messages) < limit:
                log_kv({"message": "Set stream position to current position"})
                stream_pos = current_stream_id

            return messages, stream_pos

        return await self.db_pool.runInteraction(
            "get_new_device_msgs_for_remote",
            get_new_messages_for_remote_destination_txn,
        )
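
    # Added commentary (not in the original source): the stream position
    # returned above becomes the federation sender's new last_stream_id. If
    # the LIMIT is not reached the position jumps straight to
    # current_stream_id; if it is reached it stays at the last row returned,
    # e.g. with limit=2 and pending rows at stream ids 4, 5 and 6 the first
    # call returns rows 4 and 5 with position 5, and the next call picks up
    # row 6.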

    @trace
    async def delete_device_msgs_for_remote(self, destination: str,
                                            up_to_stream_id: int) -> None:
        """Used to delete messages when the remote destination acknowledges
        their receipt.

        Args:
            destination: The destination server_name
            up_to_stream_id: Where to delete messages up to.
        """
        def delete_messages_for_remote_destination_txn(
                txn: LoggingTransaction) -> None:
            sql = ("DELETE FROM device_federation_outbox"
                   " WHERE destination = ?"
                   " AND stream_id <= ?")
            txn.execute(sql, (destination, up_to_stream_id))

        await self.db_pool.runInteraction(
            "delete_device_msgs_for_remote",
            delete_messages_for_remote_destination_txn)

    async def get_all_new_device_messages(
            self, instance_name: str, last_id: int, current_id: int,
            limit: int) -> Tuple[List[Tuple[int, tuple]], int, bool]:
        """Get updates for to device replication stream.

        Args:
            instance_name: The writer we want to fetch updates from. Unused
                here since there is only ever one writer.
            last_id: The token to fetch updates from. Exclusive.
            current_id: The token to fetch updates up to. Inclusive.
            limit: The requested limit for the number of rows to return. The
                function may return more or fewer rows.

        Returns:
            A tuple consisting of: the updates, a token to use to fetch
            subsequent updates, and whether we returned fewer rows than exist
            between the requested tokens due to the limit.

            The token returned can be used in a subsequent call to this
            function to get further updates.

            The updates are a list of 2-tuples of stream ID and the row data
        """

        if last_id == current_id:
            return [], current_id, False

        def get_all_new_device_messages_txn(
            txn: LoggingTransaction,
        ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
            # We limit like this as we might have multiple rows per stream_id, and
            # we want to make sure we always get all entries for any stream_id
            # we return.
            upper_pos = min(current_id, last_id + limit)
            sql = ("SELECT max(stream_id), user_id"
                   " FROM device_inbox"
                   " WHERE ? < stream_id AND stream_id <= ?"
                   " GROUP BY user_id")
            txn.execute(sql, (last_id, upper_pos))
            updates = [(row[0], row[1:]) for row in txn]

            sql = ("SELECT max(stream_id), destination"
                   " FROM device_federation_outbox"
                   " WHERE ? < stream_id AND stream_id <= ?"
                   " GROUP BY destination")
            txn.execute(sql, (last_id, upper_pos))
            updates.extend((row[0], row[1:]) for row in txn)

            # Order by ascending stream ordering
            updates.sort()

            limited = False
            upto_token = current_id
            if len(updates) >= limit:
                upto_token = updates[-1][0]
                limited = True

            return updates, upto_token, limited

        return await self.db_pool.runInteraction(
            "get_all_new_device_messages", get_all_new_device_messages_txn)

    @trace
    async def add_messages_to_device_inbox(
        self,
        local_messages_by_user_then_device: Dict[str, Dict[str, JsonDict]],
        remote_messages_by_destination: Dict[str, JsonDict],
    ) -> int:
        """Used to send messages from this server.

        Args:
            local_messages_by_user_then_device:
                Dictionary of recipient user_id to recipient device_id to message.
            remote_messages_by_destination:
                Dictionary of destination server_name to the EDU JSON to send.

        Returns:
            The new stream_id.
        """

        assert self._can_write_to_device

        def add_messages_txn(txn: LoggingTransaction, now_ms: int,
                             stream_id: int) -> None:
            # Add the local messages directly to the local inbox.
            self._add_messages_to_local_device_inbox_txn(
                txn, stream_id, local_messages_by_user_then_device)

            # Add the remote messages to the federation outbox.
            # We'll send them to a remote server when we next send a
            # federation transaction to that destination.
            self.db_pool.simple_insert_many_txn(
                txn,
                table="device_federation_outbox",
                keys=(
                    "destination",
                    "stream_id",
                    "queued_ts",
                    "messages_json",
                    "instance_name",
                ),
                values=[(
                    destination,
                    stream_id,
                    now_ms,
                    json_encoder.encode(edu),
                    self._instance_name,
                ) for destination, edu in
                        remote_messages_by_destination.items()],
            )

            if remote_messages_by_destination:
                issue9533_logger.debug(
                    "Queued outgoing to-device messages with stream_id %i for %s",
                    stream_id,
                    list(remote_messages_by_destination.keys()),
                )

        async with self._device_inbox_id_gen.get_next() as stream_id:
            now_ms = self._clock.time_msec()
            await self.db_pool.runInteraction("add_messages_to_device_inbox",
                                              add_messages_txn, now_ms,
                                              stream_id)
            for user_id in local_messages_by_user_then_device.keys():
                self._device_inbox_stream_cache.entity_has_changed(
                    user_id, stream_id)
            for destination in remote_messages_by_destination.keys():
                self._device_federation_outbox_stream_cache.entity_has_changed(
                    destination, stream_id)

        return self._device_inbox_id_gen.get_current_token()

    async def add_messages_from_remote_to_device_inbox(
        self,
        origin: str,
        message_id: str,
        local_messages_by_user_then_device: Dict[str, Dict[str, JsonDict]],
    ) -> int:
        assert self._can_write_to_device

        def add_messages_txn(txn: LoggingTransaction, now_ms: int,
                             stream_id: int) -> None:
            # Check if we've already inserted a matching message_id for that
            # origin. This can happen if the origin doesn't receive our
            # acknowledgement from the first time we received the message.
            already_inserted = self.db_pool.simple_select_one_txn(
                txn,
                table="device_federation_inbox",
                keyvalues={
                    "origin": origin,
                    "message_id": message_id
                },
                retcols=("message_id", ),
                allow_none=True,
            )
            if already_inserted is not None:
                return

            # Add an entry for this message_id so that we know we've processed
            # it.
            self.db_pool.simple_insert_txn(
                txn,
                table="device_federation_inbox",
                values={
                    "origin": origin,
                    "message_id": message_id,
                    "received_ts": now_ms,
                },
            )

            # Add the messages to the appropriate local device inboxes so that
            # they'll be sent to the devices when they next sync.
            self._add_messages_to_local_device_inbox_txn(
                txn, stream_id, local_messages_by_user_then_device)

        async with self._device_inbox_id_gen.get_next() as stream_id:
            now_ms = self._clock.time_msec()
            await self.db_pool.runInteraction(
                "add_messages_from_remote_to_device_inbox",
                add_messages_txn,
                now_ms,
                stream_id,
            )
            for user_id in local_messages_by_user_then_device.keys():
                self._device_inbox_stream_cache.entity_has_changed(
                    user_id, stream_id)

        return stream_id

    def _add_messages_to_local_device_inbox_txn(
        self,
        txn: LoggingTransaction,
        stream_id: int,
        messages_by_user_then_device: Dict[str, Dict[str, JsonDict]],
    ) -> None:
        assert self._can_write_to_device

        local_by_user_then_device = {}
        for user_id, messages_by_device in messages_by_user_then_device.items():
            messages_json_for_user = {}
            devices = list(messages_by_device.keys())
            if len(devices) == 1 and devices[0] == "*":
                # Handle wildcard device_ids.
                # We exclude hidden devices (such as cross-signing keys) here as they are
                # not expected to receive to-device messages.
                devices = self.db_pool.simple_select_onecol_txn(
                    txn,
                    table="devices",
                    keyvalues={
                        "user_id": user_id,
                        "hidden": False
                    },
                    retcol="device_id",
                )

                message_json = json_encoder.encode(messages_by_device["*"])
                for device_id in devices:
                    # Add the message for all devices for this user on this
                    # server.
                    messages_json_for_user[device_id] = message_json
            else:
                if not devices:
                    continue

                # We exclude hidden devices (such as cross-signing keys) here as they are
                # not expected to receive to-device messages.
                rows = self.db_pool.simple_select_many_txn(
                    txn,
                    table="devices",
                    keyvalues={
                        "user_id": user_id,
                        "hidden": False
                    },
                    column="device_id",
                    iterable=devices,
                    retcols=("device_id", ),
                )

                for row in rows:
                    # Only insert into the local inbox if the device exists on
                    # this server
                    device_id = row["device_id"]
                    message_json = json_encoder.encode(
                        messages_by_device[device_id])
                    messages_json_for_user[device_id] = message_json

            if messages_json_for_user:
                local_by_user_then_device[user_id] = messages_json_for_user

        if not local_by_user_then_device:
            return

        self.db_pool.simple_insert_many_txn(
            txn,
            table="device_inbox",
            keys=("user_id", "device_id", "stream_id", "message_json",
                  "instance_name"),
            values=[(user_id, device_id, stream_id, message_json,
                     self._instance_name) for user_id, messages_by_device in
                    local_by_user_then_device.items()
                    for device_id, message_json in messages_by_device.items()],
        )

        issue9533_logger.debug(
            "Stored to-device messages with stream_id %i for %s",
            stream_id,
            [(user_id, device_id)
             for (user_id,
                  messages_by_device) in local_by_user_then_device.items()
             for device_id in messages_by_device.keys()],
        )
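
A self-contained sketch of the wildcard handling in _add_messages_to_local_device_inbox_txn above. The helper name and data are made up for illustration; the real method additionally resolves the device list from the devices table and skips hidden devices:

from typing import Dict, List

def fan_out_wildcard(
    messages_by_device: Dict[str, dict], known_device_ids: List[str]
) -> Dict[str, dict]:
    """Expand a {"*": message} payload to one copy per known device."""
    devices = list(messages_by_device.keys())
    if devices == ["*"]:
        message = messages_by_device["*"]
        return {device_id: message for device_id in known_device_ids}
    # Otherwise only keep messages addressed to devices that actually exist.
    return {
        device_id: message
        for device_id, message in messages_by_device.items()
        if device_id in known_device_ids
    }

assert fan_out_wildcard({"*": {"body": "hi"}}, ["D1", "D2"]) == {
    "D1": {"body": "hi"},
    "D2": {"body": "hi"},
}
assert fan_out_wildcard({"D1": {"body": "hi"}, "D9": {}}, ["D1", "D2"]) == {
    "D1": {"body": "hi"}
}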
Ejemplo n.º 18
0
class AccountDataStore(AccountDataWorkerStore):
    def __init__(self, db_conn, hs):
        self._account_data_id_gen = StreamIdGenerator(
            db_conn, "account_data_max_stream_id", "stream_id"
        )

        super(AccountDataStore, self).__init__(db_conn, hs)

    def get_max_account_data_stream_id(self):
        """Get the current max stream id for the private user data stream

        Returns:
            The maximum stream ID.
        """
        return self._account_data_id_gen.get_current_token()

    @defer.inlineCallbacks
    def add_account_data_to_room(self, user_id, room_id, account_data_type, content):
        """Add some account_data to a room for a user.
        Args:
            user_id(str): The user to add the account_data for.
            room_id(str): The room to add the account_data to.
            account_data_type(str): The type of account_data to add.
            content(dict): A json object to associate with the account_data.
        Returns:
            A deferred that completes once the account_data has been added.
        """
        content_json = json.dumps(content)

        with self._account_data_id_gen.get_next() as next_id:
            # no need to lock here as room_account_data has a unique constraint
            # on (user_id, room_id, account_data_type) so _simple_upsert will
            # retry if there is a conflict.
            yield self._simple_upsert(
                desc="add_room_account_data",
                table="room_account_data",
                keyvalues={
                    "user_id": user_id,
                    "room_id": room_id,
                    "account_data_type": account_data_type,
                },
                values={
                    "stream_id": next_id,
                    "content": content_json,
                },
                lock=False,
            )

            # it's theoretically possible for the above to succeed and the
            # below to fail - in which case we might reuse a stream id on
            # restart, and the above update might not get propagated. That
            # doesn't sound any worse than the whole update getting lost,
            # which is what would happen if we combined the two into one
            # transaction.
            yield self._update_max_stream_id(next_id)

            self._account_data_stream_cache.entity_has_changed(user_id, next_id)
            self.get_account_data_for_user.invalidate((user_id,))
            self.get_account_data_for_room.invalidate((user_id, room_id,))
            self.get_account_data_for_room_and_type.prefill(
                (user_id, room_id, account_data_type,), content,
            )

        result = self._account_data_id_gen.get_current_token()
        defer.returnValue(result)

    @defer.inlineCallbacks
    def add_account_data_for_user(self, user_id, account_data_type, content):
        """Add some account_data to a room for a user.
        Args:
            user_id(str): The user to add a tag for.
            account_data_type(str): The type of account_data to add.
            content(dict): A json object to associate with the tag.
        Returns:
            A deferred that completes once the account_data has been added.
        """
        content_json = json.dumps(content)

        with self._account_data_id_gen.get_next() as next_id:
            # no need to lock here as account_data has a unique constraint on
            # (user_id, account_data_type) so _simple_upsert will retry if
            # there is a conflict.
            yield self._simple_upsert(
                desc="add_user_account_data",
                table="account_data",
                keyvalues={
                    "user_id": user_id,
                    "account_data_type": account_data_type,
                },
                values={
                    "stream_id": next_id,
                    "content": content_json,
                },
                lock=False,
            )

            # it's theoretically possible for the above to succeed and the
            # below to fail - in which case we might reuse a stream id on
            # restart, and the above update might not get propagated. That
            # doesn't sound any worse than the whole update getting lost,
            # which is what would happen if we combined the two into one
            # transaction.
            yield self._update_max_stream_id(next_id)

            self._account_data_stream_cache.entity_has_changed(
                user_id, next_id,
            )
            self.get_account_data_for_user.invalidate((user_id,))
            self.get_global_account_data_by_type_for_user.invalidate(
                (account_data_type, user_id,)
            )

        result = self._account_data_id_gen.get_current_token()
        defer.returnValue(result)

    def _update_max_stream_id(self, next_id):
        """Update the max stream_id

        Args:
            next_id(int): The revision to advance to.
        """
        def _update(txn):
            update_max_id_sql = (
                "UPDATE account_data_max_stream_id"
                " SET stream_id = ?"
                " WHERE stream_id < ?"
            )
            txn.execute(update_max_id_sql, (next_id, next_id))
        return self.runInteraction(
            "update_account_data_max_stream_id",
            _update,
        )
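
A standalone sketch of the monotonic-advance trick used by _update_max_stream_id above, run against an in-memory SQLite table (the schema here is illustrative): the UPDATE only takes effect while the stored value is still behind, so replayed or out-of-order calls can never move the pointer backwards.

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE account_data_max_stream_id (stream_id BIGINT NOT NULL)")
conn.execute("INSERT INTO account_data_max_stream_id VALUES (5)")

def advance(next_id: int) -> None:
    # Mirrors "SET stream_id = ? WHERE stream_id < ?": a no-op unless next_id
    # is ahead of the stored position.
    conn.execute(
        "UPDATE account_data_max_stream_id SET stream_id = ? WHERE stream_id < ?",
        (next_id, next_id),
    )

advance(7)  # 5 < 7, so the pointer moves to 7
advance(3)  # 7 is already past 3, so nothing changes
assert conn.execute(
    "SELECT stream_id FROM account_data_max_stream_id"
).fetchone() == (7,)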
Ejemplo n.º 19
0
class AccountDataStore(AccountDataWorkerStore):
    def __init__(self, database: DatabasePool, db_conn, hs):
        self._account_data_id_gen = StreamIdGenerator(
            db_conn,
            "account_data_max_stream_id",
            "stream_id",
            extra_tables=[
                ("room_account_data", "stream_id"),
                ("room_tags_revisions", "stream_id"),
            ],
        )

        super(AccountDataStore, self).__init__(database, db_conn, hs)

    def get_max_account_data_stream_id(self) -> int:
        """Get the current max stream id for the private user data stream

        Returns:
            The maximum stream ID.
        """
        return self._account_data_id_gen.get_current_token()

    async def add_account_data_to_room(self, user_id: str, room_id: str,
                                       account_data_type: str,
                                       content: JsonDict) -> int:
        """Add some account_data to a room for a user.

        Args:
            user_id: The user to add the account_data for.
            room_id: The room to add the account_data to.
            account_data_type: The type of account_data to add.
            content: A json object to associate with the account_data.

        Returns:
            The maximum stream ID.
        """
        content_json = json_encoder.encode(content)

        with self._account_data_id_gen.get_next() as next_id:
            # no need to lock here as room_account_data has a unique constraint
            # on (user_id, room_id, account_data_type) so simple_upsert will
            # retry if there is a conflict.
            await self.db_pool.simple_upsert(
                desc="add_room_account_data",
                table="room_account_data",
                keyvalues={
                    "user_id": user_id,
                    "room_id": room_id,
                    "account_data_type": account_data_type,
                },
                values={
                    "stream_id": next_id,
                    "content": content_json
                },
                lock=False,
            )

            # it's theoretically possible for the above to succeed and the
            # below to fail - in which case we might reuse a stream id on
            # restart, and the above update might not get propagated. That
            # doesn't sound any worse than the whole update getting lost,
            # which is what would happen if we combined the two into one
            # transaction.
            await self._update_max_stream_id(next_id)

            self._account_data_stream_cache.entity_has_changed(
                user_id, next_id)
            self.get_account_data_for_user.invalidate((user_id, ))
            self.get_account_data_for_room.invalidate((user_id, room_id))
            self.get_account_data_for_room_and_type.prefill(
                (user_id, room_id, account_data_type), content)

        return self._account_data_id_gen.get_current_token()

    async def add_account_data_for_user(self, user_id: str,
                                        account_data_type: str,
                                        content: JsonDict) -> int:
        """Add some account_data to a room for a user.

        Args:
            user_id: The user to add a tag for.
            account_data_type: The type of account_data to add.
            content: A json object to associate with the tag.

        Returns:
            The maximum stream ID.
        """
        content_json = json_encoder.encode(content)

        with self._account_data_id_gen.get_next() as next_id:
            # no need to lock here as account_data has a unique constraint on
            # (user_id, account_data_type) so simple_upsert will retry if
            # there is a conflict.
            await self.db_pool.simple_upsert(
                desc="add_user_account_data",
                table="account_data",
                keyvalues={
                    "user_id": user_id,
                    "account_data_type": account_data_type
                },
                values={
                    "stream_id": next_id,
                    "content": content_json
                },
                lock=False,
            )

            # it's theoretically possible for the above to succeed and the
            # below to fail - in which case we might reuse a stream id on
            # restart, and the above update might not get propagated. That
            # doesn't sound any worse than the whole update getting lost,
            # which is what would happen if we combined the two into one
            # transaction.
            #
            # Note: This is only here for backwards compat to allow admins to
            # roll back to a previous Synapse version. Next time we update the
            # database version we can remove this table.
            await self._update_max_stream_id(next_id)

            self._account_data_stream_cache.entity_has_changed(
                user_id, next_id)
            self.get_account_data_for_user.invalidate((user_id, ))
            self.get_global_account_data_by_type_for_user.invalidate(
                (account_data_type, user_id))

        return self._account_data_id_gen.get_current_token()

    def _update_max_stream_id(self, next_id: int):
        """Update the max stream_id

        Args:
            next_id: The revision to advance to.
        """

        # Note: This is only here for backwards compat to allow admins to
        # roll back to a previous Synapse version. Next time we update the
        # database version we can remove this table.

        def _update(txn):
            update_max_id_sql = ("UPDATE account_data_max_stream_id"
                                 " SET stream_id = ?"
                                 " WHERE stream_id < ?")
            txn.execute(update_max_id_sql, (next_id, next_id))

        return self.db_pool.runInteraction("update_account_data_max_stream_id",
                                           _update)
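
The comments above lean on the unique constraint on (user_id, room_id, account_data_type) to make simple_upsert safe without locking. A minimal illustration of that behaviour with plain sqlite3 (illustrative schema; ON CONFLICT ... DO UPDATE needs SQLite 3.24 or newer): a second write for the same key replaces the existing row instead of inserting a new one.

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    """
    CREATE TABLE room_account_data (
        user_id TEXT, room_id TEXT, account_data_type TEXT,
        stream_id BIGINT, content TEXT,
        UNIQUE (user_id, room_id, account_data_type)
    )
    """
)

def upsert(user_id, room_id, data_type, stream_id, content):
    # Insert, or overwrite the existing row for the same (user, room, type).
    conn.execute(
        """
        INSERT INTO room_account_data VALUES (?, ?, ?, ?, ?)
        ON CONFLICT (user_id, room_id, account_data_type)
        DO UPDATE SET stream_id = excluded.stream_id, content = excluded.content
        """,
        (user_id, room_id, data_type, stream_id, content),
    )

upsert("@alice:test", "!room:test", "m.tag", 1, "{}")
upsert("@alice:test", "!room:test", "m.tag", 2, '{"a": 1}')
assert conn.execute(
    "SELECT count(*), max(stream_id) FROM room_account_data"
).fetchone() == (1, 2)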
Ejemplo n.º 20
0
    def __init__(self, db_conn, hs):
        self._account_data_id_gen = StreamIdGenerator(
            db_conn, "account_data_max_stream_id", "stream_id"
        )

        super(AccountDataStore, self).__init__(db_conn, hs)
Ejemplo n.º 21
0
class AccountDataStore(AccountDataWorkerStore):
    def __init__(self, db_conn, hs):
        self._account_data_id_gen = StreamIdGenerator(
            db_conn, "account_data_max_stream_id", "stream_id"
        )

        super(AccountDataStore, self).__init__(db_conn, hs)

    def get_max_account_data_stream_id(self):
        """Get the current max stream id for the private user data stream

        Returns:
            The maximum stream ID.
        """
        return self._account_data_id_gen.get_current_token()

    @defer.inlineCallbacks
    def add_account_data_to_room(self, user_id, room_id, account_data_type, content):
        """Add some account_data to a room for a user.
        Args:
            user_id(str): The user to add the account_data for.
            room_id(str): The room to add the account_data to.
            account_data_type(str): The type of account_data to add.
            content(dict): A json object to associate with the account_data.
        Returns:
            A deferred that completes once the account_data has been added.
        """
        content_json = json.dumps(content)

        with self._account_data_id_gen.get_next() as next_id:
            # no need to lock here as room_account_data has a unique constraint
            # on (user_id, room_id, account_data_type) so _simple_upsert will
            # retry if there is a conflict.
            yield self._simple_upsert(
                desc="add_room_account_data",
                table="room_account_data",
                keyvalues={
                    "user_id": user_id,
                    "room_id": room_id,
                    "account_data_type": account_data_type,
                },
                values={"stream_id": next_id, "content": content_json},
                lock=False,
            )

            # it's theoretically possible for the above to succeed and the
            # below to fail - in which case we might reuse a stream id on
            # restart, and the above update might not get propagated. That
            # doesn't sound any worse than the whole update getting lost,
            # which is what would happen if we combined the two into one
            # transaction.
            yield self._update_max_stream_id(next_id)

            self._account_data_stream_cache.entity_has_changed(user_id, next_id)
            self.get_account_data_for_user.invalidate((user_id,))
            self.get_account_data_for_room.invalidate((user_id, room_id))
            self.get_account_data_for_room_and_type.prefill(
                (user_id, room_id, account_data_type), content
            )

        result = self._account_data_id_gen.get_current_token()
        return result

    @defer.inlineCallbacks
    def add_account_data_for_user(self, user_id, account_data_type, content):
        """Add some account_data to a room for a user.
        Args:
            user_id(str): The user to add a tag for.
            account_data_type(str): The type of account_data to add.
            content(dict): A json object to associate with the tag.
        Returns:
            A deferred that completes once the account_data has been added.
        """
        content_json = json.dumps(content)

        with self._account_data_id_gen.get_next() as next_id:
            # no need to lock here as account_data has a unique constraint on
            # (user_id, account_data_type) so _simple_upsert will retry if
            # there is a conflict.
            yield self._simple_upsert(
                desc="add_user_account_data",
                table="account_data",
                keyvalues={"user_id": user_id, "account_data_type": account_data_type},
                values={"stream_id": next_id, "content": content_json},
                lock=False,
            )

            # it's theoretically possible for the above to succeed and the
            # below to fail - in which case we might reuse a stream id on
            # restart, and the above update might not get propagated. That
            # doesn't sound any worse than the whole update getting lost,
            # which is what would happen if we combined the two into one
            # transaction.
            yield self._update_max_stream_id(next_id)

            self._account_data_stream_cache.entity_has_changed(user_id, next_id)
            self.get_account_data_for_user.invalidate((user_id,))
            self.get_global_account_data_by_type_for_user.invalidate(
                (account_data_type, user_id)
            )

        result = self._account_data_id_gen.get_current_token()
        return result

    def _update_max_stream_id(self, next_id):
        """Update the max stream_id

        Args:
            next_id(int): The revision to advance to.
        """

        def _update(txn):
            update_max_id_sql = (
                "UPDATE account_data_max_stream_id"
                " SET stream_id = ?"
                " WHERE stream_id < ?"
            )
            txn.execute(update_max_id_sql, (next_id, next_id))

        return self.runInteraction("update_account_data_max_stream_id", _update)
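
The `with self._account_data_id_gen.get_next() as next_id:` blocks above rely on the generator handing out increasing ids and only exposing an id via get_current_token once its block has finished. A deliberately simplified toy version (not Synapse's StreamIdGenerator, which also has to cope with concurrent, out-of-order completions) to show the shape of that contract:

from contextlib import contextmanager
from typing import Iterator

class ToyStreamIdGenerator:
    """Illustrative only: single-threaded stand-in for a stream id generator."""

    def __init__(self, start: int = 0) -> None:
        self._next = start
        self._current = start

    @contextmanager
    def get_next(self) -> Iterator[int]:
        self._next += 1
        next_id = self._next
        try:
            yield next_id
        finally:
            # Simplification: only advance once the write inside the block is done.
            self._current = max(self._current, next_id)

    def get_current_token(self) -> int:
        return self._current

gen = ToyStreamIdGenerator(start=10)
with gen.get_next() as next_id:
    assert next_id == 11
    assert gen.get_current_token() == 10  # not visible until the block completes
assert gen.get_current_token() == 11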
Ejemplo n.º 22
0
    def __init__(self, database: DatabasePool, db_conn, hs):
        super().__init__(database, db_conn, hs)

        self._instance_name = hs.get_instance_name()

        # Map of (user_id, device_id) to the last stream_id that has been
        # deleted up to. This is so that we can no-op deletions.
        self._last_device_delete_cache = ExpiringCache(
            cache_name="last_device_delete_cache",
            clock=self._clock,
            max_len=10000,
            expiry_ms=30 * 60 * 1000,
        )

        if isinstance(database.engine, PostgresEngine):
            self._can_write_to_device = (self._instance_name
                                         in hs.config.worker.writers.to_device)

            self._device_inbox_id_gen = MultiWriterIdGenerator(
                db_conn=db_conn,
                db=database,
                stream_name="to_device",
                instance_name=self._instance_name,
                tables=[("device_inbox", "instance_name", "stream_id")],
                sequence_name="device_inbox_sequence",
                writers=hs.config.worker.writers.to_device,
            )
        else:
            self._can_write_to_device = True
            self._device_inbox_id_gen = StreamIdGenerator(
                db_conn, "device_inbox", "stream_id")

        max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
        device_inbox_prefill, min_device_inbox_id = self.db_pool.get_cache_dict(
            db_conn,
            "device_inbox",
            entity_column="user_id",
            stream_column="stream_id",
            max_value=max_device_inbox_id,
            limit=1000,
        )
        self._device_inbox_stream_cache = StreamChangeCache(
            "DeviceInboxStreamChangeCache",
            min_device_inbox_id,
            prefilled_cache=device_inbox_prefill,
        )

        # The federation outbox and the local device inbox use the same
        # stream_id generator.
        device_outbox_prefill, min_device_outbox_id = self.db_pool.get_cache_dict(
            db_conn,
            "device_federation_outbox",
            entity_column="destination",
            stream_column="stream_id",
            max_value=max_device_inbox_id,
            limit=1000,
        )
        self._device_federation_outbox_stream_cache = StreamChangeCache(
            "DeviceFederationOutboxStreamChangeCache",
            min_device_outbox_id,
            prefilled_cache=device_outbox_prefill,
        )
Ejemplo n.º 23
0
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        super().__init__(database, db_conn, hs)

        self._cross_signing_id_gen = StreamIdGenerator(
            db_conn, "e2e_cross_signing_keys", "stream_id")

    async def set_e2e_device_keys(self, user_id: str, device_id: str,
                                  time_now: int,
                                  device_keys: JsonDict) -> bool:
        """Stores device keys for a device. Returns whether there was a change
        or the keys were already in the database.
        """
        def _set_e2e_device_keys_txn(txn: LoggingTransaction) -> bool:
            set_tag("user_id", user_id)
            set_tag("device_id", device_id)
            set_tag("time_now", time_now)
            set_tag("device_keys", device_keys)

            old_key_json = self.db_pool.simple_select_one_onecol_txn(
                txn,
                table="e2e_device_keys_json",
                keyvalues={
                    "user_id": user_id,
                    "device_id": device_id
                },
                retcol="key_json",
                allow_none=True,
            )

            # In py3 we need old_key_json to match new_key_json type. The DB
            # returns unicode while encode_canonical_json returns bytes.
            new_key_json = encode_canonical_json(device_keys).decode("utf-8")

            if old_key_json == new_key_json:
                log_kv({"Message": "Device key already stored."})
                return False

            self.db_pool.simple_upsert_txn(
                txn,
                table="e2e_device_keys_json",
                keyvalues={
                    "user_id": user_id,
                    "device_id": device_id
                },
                values={
                    "ts_added_ms": time_now,
                    "key_json": new_key_json
                },
            )
            log_kv({"message": "Device keys stored."})
            return True

        return await self.db_pool.runInteraction("set_e2e_device_keys",
                                                 _set_e2e_device_keys_txn)

    async def delete_e2e_keys_by_device(self, user_id: str,
                                        device_id: str) -> None:
        def delete_e2e_keys_by_device_txn(txn: LoggingTransaction) -> None:
            log_kv({
                "message": "Deleting keys for device",
                "device_id": device_id,
                "user_id": user_id,
            })
            self.db_pool.simple_delete_txn(
                txn,
                table="e2e_device_keys_json",
                keyvalues={
                    "user_id": user_id,
                    "device_id": device_id
                },
            )
            self.db_pool.simple_delete_txn(
                txn,
                table="e2e_one_time_keys_json",
                keyvalues={
                    "user_id": user_id,
                    "device_id": device_id
                },
            )
            self._invalidate_cache_and_stream(txn,
                                              self.count_e2e_one_time_keys,
                                              (user_id, device_id))
            self.db_pool.simple_delete_txn(
                txn,
                table="dehydrated_devices",
                keyvalues={
                    "user_id": user_id,
                    "device_id": device_id
                },
            )
            self.db_pool.simple_delete_txn(
                txn,
                table="e2e_fallback_keys_json",
                keyvalues={
                    "user_id": user_id,
                    "device_id": device_id
                },
            )
            self._invalidate_cache_and_stream(
                txn, self.get_e2e_unused_fallback_key_types,
                (user_id, device_id))

        await self.db_pool.runInteraction("delete_e2e_keys_by_device",
                                          delete_e2e_keys_by_device_txn)

    def _set_e2e_cross_signing_key_txn(
        self,
        txn: LoggingTransaction,
        user_id: str,
        key_type: str,
        key: JsonDict,
        stream_id: int,
    ) -> None:
        """Set a user's cross-signing key.

        Args:
            txn: db connection
            user_id: the user to set the signing key for
            key_type: the type of key that is being set: either 'master'
                for a master key, 'self_signing' for a self-signing key, or
                'user_signing' for a user-signing key
            key: the key data
            stream_id: the stream ID of this change
        """
        # the 'key' dict will look something like:
        # {
        #   "user_id": "@alice:example.com",
        #   "usage": ["self_signing"],
        #   "keys": {
        #     "ed25519:base64+self+signing+public+key": "base64+self+signing+public+key",
        #   },
        #   "signatures": {
        #     "@alice:example.com": {
        #       "ed25519:base64+master+public+key": "base64+signature"
        #     }
        #   }
        # }
        # The "keys" property must only have one entry, which will be the public
        # key, so we just grab the first value in there
        pubkey = next(iter(key["keys"].values()))

        # The cross-signing keys need to occupy the same namespace as devices,
        # since signatures are identified by device ID.  So add an entry to the
        # device table to make sure that we don't have a collision with device
        # IDs.
        # We only need to do this for local users, since remote servers should be
        # responsible for checking this for their own users.
        if self.hs.is_mine_id(user_id):
            self.db_pool.simple_insert_txn(
                txn,
                "devices",
                values={
                    "user_id": user_id,
                    "device_id": pubkey,
                    "display_name": key_type + " signing key",
                    "hidden": True,
                },
            )

        # and finally, store the key itself
        self.db_pool.simple_insert_txn(
            txn,
            "e2e_cross_signing_keys",
            values={
                "user_id": user_id,
                "keytype": key_type,
                "keydata": json_encoder.encode(key),
                "stream_id": stream_id,
            },
        )

        self._invalidate_cache_and_stream(
            txn, self._get_bare_e2e_cross_signing_keys, (user_id, ))

    async def set_e2e_cross_signing_key(self, user_id: str, key_type: str,
                                        key: JsonDict) -> None:
        """Set a user's cross-signing key.

        Args:
            user_id: the user to set the user-signing key for
            key_type: the type of cross-signing key to set
            key: the key data
        """

        async with self._cross_signing_id_gen.get_next() as stream_id:
            return await self.db_pool.runInteraction(
                "add_e2e_cross_signing_key",
                self._set_e2e_cross_signing_key_txn,
                user_id,
                key_type,
                key,
                stream_id,
            )

    async def store_e2e_cross_signing_signatures(
            self, user_id: str,
            signatures: "Iterable[SignatureListItem]") -> None:
        """Stores cross-signing signatures.

        Args:
            user_id: the user who made the signatures
            signatures: signatures to add
        """
        await self.db_pool.simple_insert_many(
            "e2e_cross_signing_signatures",
            keys=(
                "user_id",
                "key_id",
                "target_user_id",
                "target_device_id",
                "signature",
            ),
            values=[(
                user_id,
                item.signing_key_id,
                item.target_user_id,
                item.target_device_id,
                item.signature,
            ) for item in signatures],
            desc="add_e2e_signing_key",
        )
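
A minimal usage sketch for the cross-signing and device-key helpers above; `store` is assumed to be an already-initialised instance of this end-to-end key store, and the user and device identifiers are placeholders.

from typing import Any, Dict


async def rotate_master_key(store: Any, user_id: str,
                            master_key: Dict[str, Any]) -> None:
    # Persist a new master cross-signing key; a stream id is allocated
    # internally via the store's _cross_signing_id_gen.
    await store.set_e2e_cross_signing_key(user_id, "master", master_key)


async def remove_device_keys(store: Any, user_id: str, device_id: str) -> None:
    # Drop the device's identity keys, one-time keys and fallback keys and
    # invalidate the related caches, as implemented above.
    await store.delete_e2e_keys_by_device(user_id, device_id)
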
Ejemplo n.º 24
0
class DeviceInboxWorkerStore(SQLBaseStore):
    def __init__(self, database: DatabasePool, db_conn, hs):
        super().__init__(database, db_conn, hs)

        self._instance_name = hs.get_instance_name()

        # Map of (user_id, device_id) to the last stream_id that has been
        # deleted up to, so that we can no-op deletions that are already covered.
        self._last_device_delete_cache = ExpiringCache(
            cache_name="last_device_delete_cache",
            clock=self._clock,
            max_len=10000,
            expiry_ms=30 * 60 * 1000,
        )

        if isinstance(database.engine, PostgresEngine):
            self._can_write_to_device = (self._instance_name
                                         in hs.config.worker.writers.to_device)

            self._device_inbox_id_gen = MultiWriterIdGenerator(
                db_conn=db_conn,
                db=database,
                stream_name="to_device",
                instance_name=self._instance_name,
                tables=[("device_inbox", "instance_name", "stream_id")],
                sequence_name="device_inbox_sequence",
                writers=hs.config.worker.writers.to_device,
            )
        else:
            self._can_write_to_device = True
            self._device_inbox_id_gen = StreamIdGenerator(
                db_conn, "device_inbox", "stream_id")

        max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
        device_inbox_prefill, min_device_inbox_id = self.db_pool.get_cache_dict(
            db_conn,
            "device_inbox",
            entity_column="user_id",
            stream_column="stream_id",
            max_value=max_device_inbox_id,
            limit=1000,
        )
        self._device_inbox_stream_cache = StreamChangeCache(
            "DeviceInboxStreamChangeCache",
            min_device_inbox_id,
            prefilled_cache=device_inbox_prefill,
        )

        # The federation outbox and the local device inbox use the same
        # stream_id generator.
        device_outbox_prefill, min_device_outbox_id = self.db_pool.get_cache_dict(
            db_conn,
            "device_federation_outbox",
            entity_column="destination",
            stream_column="stream_id",
            max_value=max_device_inbox_id,
            limit=1000,
        )
        self._device_federation_outbox_stream_cache = StreamChangeCache(
            "DeviceFederationOutboxStreamChangeCache",
            min_device_outbox_id,
            prefilled_cache=device_outbox_prefill,
        )

    def process_replication_rows(self, stream_name, instance_name, token,
                                 rows):
        if stream_name == ToDeviceStream.NAME:
            self._device_inbox_id_gen.advance(instance_name, token)
            for row in rows:
                if row.entity.startswith("@"):
                    self._device_inbox_stream_cache.entity_has_changed(
                        row.entity, token)
                else:
                    self._device_federation_outbox_stream_cache.entity_has_changed(
                        row.entity, token)
        return super().process_replication_rows(stream_name, instance_name,
                                                token, rows)

    def get_to_device_stream_token(self):
        return self._device_inbox_id_gen.get_current_token()

    async def get_new_messages_for_device(
        self,
        user_id: str,
        device_id: str,
        last_stream_id: int,
        current_stream_id: int,
        limit: int = 100,
    ) -> Tuple[List[dict], int]:
        """
        Args:
            user_id: The recipient user_id.
            device_id: The recipient device_id.
            last_stream_id: The last stream ID checked.
            current_stream_id: The current position of the to device
                message stream.
            limit: The maximum number of messages to retrieve.

        Returns:
            A tuple of a list of messages for the device and the stream
            position the messages were fetched up to.
        """
        has_changed = self._device_inbox_stream_cache.has_entity_changed(
            user_id, last_stream_id)
        if not has_changed:
            return ([], current_stream_id)

        def get_new_messages_for_device_txn(txn):
            sql = ("SELECT stream_id, message_json FROM device_inbox"
                   " WHERE user_id = ? AND device_id = ?"
                   " AND ? < stream_id AND stream_id <= ?"
                   " ORDER BY stream_id ASC"
                   " LIMIT ?")
            txn.execute(
                sql,
                (user_id, device_id, last_stream_id, current_stream_id, limit))
            messages = []
            for row in txn:
                stream_pos = row[0]
                messages.append(db_to_json(row[1]))
            if len(messages) < limit:
                stream_pos = current_stream_id
            return messages, stream_pos

        return await self.db_pool.runInteraction(
            "get_new_messages_for_device", get_new_messages_for_device_txn)

    @trace
    async def delete_messages_for_device(self, user_id: str, device_id: str,
                                         up_to_stream_id: int) -> int:
        """
        Args:
            user_id: The recipient user_id.
            device_id: The recipient device_id.
            up_to_stream_id: Where to delete messages up to.

        Returns:
            The number of messages deleted.
        """
        # If we have cached the last stream id we've deleted up to, we can
        # check if there is likely to be anything that needs deleting
        last_deleted_stream_id = self._last_device_delete_cache.get(
            (user_id, device_id), None)

        set_tag("last_deleted_stream_id", last_deleted_stream_id)

        if last_deleted_stream_id:
            has_changed = self._device_inbox_stream_cache.has_entity_changed(
                user_id, last_deleted_stream_id)
            if not has_changed:
                log_kv({"message": "No changes in cache since last check"})
                return 0

        def delete_messages_for_device_txn(txn):
            sql = ("DELETE FROM device_inbox"
                   " WHERE user_id = ? AND device_id = ?"
                   " AND stream_id <= ?")
            txn.execute(sql, (user_id, device_id, up_to_stream_id))
            return txn.rowcount

        count = await self.db_pool.runInteraction(
            "delete_messages_for_device", delete_messages_for_device_txn)

        log_kv({
            "message": "deleted {} messages for device".format(count),
            "count": count
        })

        # Update the cache, ensuring that we only ever increase the value
        last_deleted_stream_id = self._last_device_delete_cache.get(
            (user_id, device_id), 0)
        self._last_device_delete_cache[(user_id, device_id)] = max(
            last_deleted_stream_id, up_to_stream_id)

        return count

    @trace
    async def get_new_device_msgs_for_remote(self, destination, last_stream_id,
                                             current_stream_id,
                                             limit) -> Tuple[List[dict], int]:
        """
        Args:
            destination(str): The name of the remote server.
            last_stream_id(int|long): The last position of the device message stream
                that the server sent up to.
            current_stream_id(int|long): The current position of the device
                message stream.
            limit(int): The maximum number of messages to retrieve.
        Returns:
            A tuple of a list of messages for the destination and the stream
            position the messages were fetched up to.
        """

        set_tag("destination", destination)
        set_tag("last_stream_id", last_stream_id)
        set_tag("current_stream_id", current_stream_id)
        set_tag("limit", limit)

        has_changed = self._device_federation_outbox_stream_cache.has_entity_changed(
            destination, last_stream_id)
        if not has_changed or last_stream_id == current_stream_id:
            log_kv({"message": "No new messages in stream"})
            return ([], current_stream_id)

        if limit <= 0:
            # This can happen if we run out of room for EDUs in the transaction.
            return ([], last_stream_id)

        @trace
        def get_new_messages_for_remote_destination_txn(txn):
            sql = (
                "SELECT stream_id, messages_json FROM device_federation_outbox"
                " WHERE destination = ?"
                " AND ? < stream_id AND stream_id <= ?"
                " ORDER BY stream_id ASC"
                " LIMIT ?")
            txn.execute(
                sql, (destination, last_stream_id, current_stream_id, limit))
            messages = []
            for row in txn:
                stream_pos = row[0]
                messages.append(db_to_json(row[1]))
            if len(messages) < limit:
                log_kv({"message": "Set stream position to current position"})
                stream_pos = current_stream_id
            return messages, stream_pos

        return await self.db_pool.runInteraction(
            "get_new_device_msgs_for_remote",
            get_new_messages_for_remote_destination_txn,
        )

    @trace
    async def delete_device_msgs_for_remote(self, destination: str,
                                            up_to_stream_id: int) -> None:
        """Used to delete messages when the remote destination acknowledges
        their receipt.

        Args:
            destination: The destination server_name
            up_to_stream_id: Where to delete messages up to.
        """
        def delete_messages_for_remote_destination_txn(txn):
            sql = ("DELETE FROM device_federation_outbox"
                   " WHERE destination = ?"
                   " AND stream_id <= ?")
            txn.execute(sql, (destination, up_to_stream_id))

        await self.db_pool.runInteraction(
            "delete_device_msgs_for_remote",
            delete_messages_for_remote_destination_txn)

    async def get_all_new_device_messages(
            self, instance_name: str, last_id: int, current_id: int,
            limit: int) -> Tuple[List[Tuple[int, tuple]], int, bool]:
        """Get updates for to device replication stream.

        Args:
            instance_name: The writer we want to fetch updates from. Unused
                here since there is only ever one writer.
            last_id: The token to fetch updates from. Exclusive.
            current_id: The token to fetch updates up to. Inclusive.
            limit: The requested limit for the number of rows to return. The
                function may return more or fewer rows.

        Returns:
            A tuple consisting of: the updates, a token to use to fetch
            subsequent updates, and whether we returned fewer rows than exist
            between the requested tokens due to the limit.

            The token returned can be used in a subsequent call to this
            function to get further updates.

            The updates are a list of 2-tuples of stream ID and the row data
        """

        if last_id == current_id:
            return [], current_id, False

        def get_all_new_device_messages_txn(txn):
            # We limit like this as we might have multiple rows per stream_id, and
            # we want to make sure we always get all entries for any stream_id
            # we return.
            upper_pos = min(current_id, last_id + limit)
            sql = ("SELECT max(stream_id), user_id"
                   " FROM device_inbox"
                   " WHERE ? < stream_id AND stream_id <= ?"
                   " GROUP BY user_id")
            txn.execute(sql, (last_id, upper_pos))
            updates = [(row[0], row[1:]) for row in txn]

            sql = ("SELECT max(stream_id), destination"
                   " FROM device_federation_outbox"
                   " WHERE ? < stream_id AND stream_id <= ?"
                   " GROUP BY destination")
            txn.execute(sql, (last_id, upper_pos))
            updates.extend((row[0], row[1:]) for row in txn)

            # Order by ascending stream ordering
            updates.sort()

            limited = False
            upto_token = current_id
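            # If we hit the limit, only advance the token as far as the last
            # row actually returned so the next call resumes from there.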
            if len(updates) >= limit:
                upto_token = updates[-1][0]
                limited = True

            return updates, upto_token, limited

        return await self.db_pool.runInteraction(
            "get_all_new_device_messages", get_all_new_device_messages_txn)

    @trace
    async def add_messages_to_device_inbox(
        self,
        local_messages_by_user_then_device: dict,
        remote_messages_by_destination: dict,
    ) -> int:
        """Used to send messages from this server.

        Args:
            local_messages_by_user_then_device:
                Dictionary of user_id to device_id to message.
            remote_messages_by_destination:
                Dictionary of destination server_name to the EDU JSON to send.

        Returns:
            The new stream_id.
        """

        assert self._can_write_to_device

        def add_messages_txn(txn, now_ms, stream_id):
            # Add the local messages directly to the local inbox.
            self._add_messages_to_local_device_inbox_txn(
                txn, stream_id, local_messages_by_user_then_device)

            # Add the remote messages to the federation outbox.
            # We'll send them to a remote server when we next send a
            # federation transaction to that destination.
            self.db_pool.simple_insert_many_txn(
                txn,
                table="device_federation_outbox",
                values=[{
                    "destination": destination,
                    "stream_id": stream_id,
                    "queued_ts": now_ms,
                    "messages_json": json_encoder.encode(edu),
                    "instance_name": self._instance_name,
                } for destination, edu in
                        remote_messages_by_destination.items()],
            )

        async with self._device_inbox_id_gen.get_next() as stream_id:
            now_ms = self._clock.time_msec()
            await self.db_pool.runInteraction("add_messages_to_device_inbox",
                                              add_messages_txn, now_ms,
                                              stream_id)
            for user_id in local_messages_by_user_then_device.keys():
                self._device_inbox_stream_cache.entity_has_changed(
                    user_id, stream_id)
            for destination in remote_messages_by_destination.keys():
                self._device_federation_outbox_stream_cache.entity_has_changed(
                    destination, stream_id)

        return self._device_inbox_id_gen.get_current_token()

    async def add_messages_from_remote_to_device_inbox(
            self, origin: str, message_id: str,
            local_messages_by_user_then_device: dict) -> int:
        assert self._can_write_to_device

        def add_messages_txn(txn, now_ms, stream_id):
            # Check if we've already inserted a matching message_id for that
            # origin. This can happen if the origin doesn't receive our
            # acknowledgement from the first time we received the message.
            already_inserted = self.db_pool.simple_select_one_txn(
                txn,
                table="device_federation_inbox",
                keyvalues={
                    "origin": origin,
                    "message_id": message_id
                },
                retcols=("message_id", ),
                allow_none=True,
            )
            if already_inserted is not None:
                return

            # Add an entry for this message_id so that we know we've processed
            # it.
            self.db_pool.simple_insert_txn(
                txn,
                table="device_federation_inbox",
                values={
                    "origin": origin,
                    "message_id": message_id,
                    "received_ts": now_ms,
                },
            )

            # Add the messages to the appropriate local device inboxes so that
            # they'll be sent to the devices when they next sync.
            self._add_messages_to_local_device_inbox_txn(
                txn, stream_id, local_messages_by_user_then_device)

        async with self._device_inbox_id_gen.get_next() as stream_id:
            now_ms = self._clock.time_msec()
            await self.db_pool.runInteraction(
                "add_messages_from_remote_to_device_inbox",
                add_messages_txn,
                now_ms,
                stream_id,
            )
            for user_id in local_messages_by_user_then_device.keys():
                self._device_inbox_stream_cache.entity_has_changed(
                    user_id, stream_id)

        return stream_id

    def _add_messages_to_local_device_inbox_txn(self, txn, stream_id,
                                                messages_by_user_then_device):
        assert self._can_write_to_device

        local_by_user_then_device = {}
        for user_id, messages_by_device in messages_by_user_then_device.items():
            messages_json_for_user = {}
            devices = list(messages_by_device.keys())
            if len(devices) == 1 and devices[0] == "*":
                # Handle wildcard device_ids.
                devices = self.db_pool.simple_select_onecol_txn(
                    txn,
                    table="devices",
                    keyvalues={"user_id": user_id},
                    retcol="device_id",
                )

                message_json = json_encoder.encode(messages_by_device["*"])
                for device_id in devices:
                    # Add the message for all devices for this user on this
                    # server.
                    messages_json_for_user[device_id] = message_json
            else:
                if not devices:
                    continue

                rows = self.db_pool.simple_select_many_txn(
                    txn,
                    table="devices",
                    keyvalues={"user_id": user_id},
                    column="device_id",
                    iterable=devices,
                    retcols=("device_id", ),
                )

                for row in rows:
                    # Only insert into the local inbox if the device exists on
                    # this server
                    device_id = row["device_id"]
                    message_json = json_encoder.encode(
                        messages_by_device[device_id])
                    messages_json_for_user[device_id] = message_json

            if messages_json_for_user:
                local_by_user_then_device[user_id] = messages_json_for_user

        if not local_by_user_then_device:
            return

        self.db_pool.simple_insert_many_txn(
            txn,
            table="device_inbox",
            values=[{
                "user_id": user_id,
                "device_id": device_id,
                "stream_id": stream_id,
                "message_json": message_json,
                "instance_name": self._instance_name,
            } for user_id, messages_by_device in
                    local_by_user_then_device.items()
                    for device_id, message_json in messages_by_device.items()],
        )
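
A hedged end-to-end sketch of the to-device helpers in the example above; `store` is assumed to be a constructed DeviceInboxWorkerStore running on the to-device writer, and the identifiers are placeholders.

from typing import Any


async def send_and_drain(store: Any, user_id: str, device_id: str) -> None:
    # Queue a local to-device message; the new max stream id is returned.
    stream_id = await store.add_messages_to_device_inbox(
        {user_id: {device_id: {"body": "hello"}}},
        {},  # no remote destinations in this sketch
    )

    # Fetch everything for the device up to that position...
    messages, pos = await store.get_new_messages_for_device(
        user_id, device_id, last_stream_id=0, current_stream_id=stream_id)

    # ...then acknowledge what we handled by deleting it.
    deleted = await store.delete_messages_for_device(user_id, device_id, pos)
    print(f"delivered {len(messages)} message(s), deleted {deleted}")
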
Ejemplo n.º 25
0
    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        self.hs = hs
        self._clock = hs.get_clock()
        self.database_engine = database.engine

        self._device_list_id_gen = StreamIdGenerator(
            db_conn,
            "device_lists_stream",
            "stream_id",
            extra_tables=[
                ("user_signature_stream", "stream_id"),
                ("device_lists_outbound_pokes", "stream_id"),
                ("device_lists_changes_in_room", "stream_id"),
            ],
        )

        self._cache_id_gen: Optional[MultiWriterIdGenerator]
        if isinstance(self.database_engine, PostgresEngine):
            # We set the `writers` to an empty list here as we don't care about
            # missing updates over restarts, as we'll not have anything in our
            # caches to invalidate. (This reduces the amount of writes to the DB
            # that happen).
            self._cache_id_gen = MultiWriterIdGenerator(
                db_conn,
                database,
                stream_name="caches",
                instance_name=hs.get_instance_name(),
                tables=[(
                    "cache_invalidation_stream_by_instance",
                    "instance_name",
                    "stream_id",
                )],
                sequence_name="cache_invalidation_stream_seq",
                writers=[],
            )

        else:
            self._cache_id_gen = None

        super().__init__(database, db_conn, hs)

        events_max = self._stream_id_gen.get_current_token()
        curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict(
            db_conn,
            "current_state_delta_stream",
            entity_column="room_id",
            stream_column="stream_id",
            max_value=events_max,  # As we share the stream id with events token
            limit=1000,
        )
        self._curr_state_delta_stream_cache = StreamChangeCache(
            "_curr_state_delta_stream_cache",
            min_curr_state_delta_id,
            prefilled_cache=curr_state_delta_prefill,
        )

        self._stream_order_on_start = self.get_room_max_stream_ordering()
        self._min_stream_order_on_start = self.get_room_min_stream_ordering()
Ejemplo n.º 26
0
class PresenceStore(PresenceBackgroundUpdateStore):
    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        super().__init__(database, db_conn, hs)

        self._can_persist_presence = (hs.get_instance_name()
                                      in hs.config.worker.writers.presence)

        if isinstance(database.engine, PostgresEngine):
            self._presence_id_gen = MultiWriterIdGenerator(
                db_conn=db_conn,
                db=database,
                stream_name="presence_stream",
                instance_name=self._instance_name,
                tables=[("presence_stream", "instance_name", "stream_id")],
                sequence_name="presence_stream_sequence",
                writers=hs.config.worker.writers.presence,
            )
        else:
            self._presence_id_gen = StreamIdGenerator(db_conn,
                                                      "presence_stream",
                                                      "stream_id")

        self.hs = hs
        self._presence_on_startup = self._get_active_presence(db_conn)

        presence_cache_prefill, min_presence_val = self.db_pool.get_cache_dict(
            db_conn,
            "presence_stream",
            entity_column="user_id",
            stream_column="stream_id",
            max_value=self._presence_id_gen.get_current_token(),
        )
        self.presence_stream_cache = StreamChangeCache(
            "PresenceStreamChangeCache",
            min_presence_val,
            prefilled_cache=presence_cache_prefill,
        )

    async def update_presence(self, presence_states) -> Tuple[int, int]:
        assert self._can_persist_presence

        stream_ordering_manager = self._presence_id_gen.get_next_mult(
            len(presence_states))

        async with stream_ordering_manager as stream_orderings:
            await self.db_pool.runInteraction(
                "update_presence",
                self._update_presence_txn,
                stream_orderings,
                presence_states,
            )

        return stream_orderings[-1], self._presence_id_gen.get_current_token()

    def _update_presence_txn(self, txn, stream_orderings, presence_states):
        for stream_id, state in zip(stream_orderings, presence_states):
            txn.call_after(self.presence_stream_cache.entity_has_changed,
                           state.user_id, stream_id)
            txn.call_after(self._get_presence_for_user.invalidate,
                           (state.user_id, ))

        # Delete old rows to stop the database from getting really big.
        sql = "DELETE FROM presence_stream WHERE stream_id < ? AND "

        for states in batch_iter(presence_states, 50):
            clause, args = make_in_list_sql_clause(self.database_engine,
                                                   "user_id",
                                                   [s.user_id for s in states])
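            # Note: `stream_id` here is the last value bound by the zip() loop
            # above, i.e. the newest stream ordering in this batch, so older
            # presence rows for these users are pruned.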
            txn.execute(sql + clause, [stream_id] + list(args))

        # Actually insert new rows
        self.db_pool.simple_insert_many_txn(
            txn,
            table="presence_stream",
            keys=(
                "stream_id",
                "user_id",
                "state",
                "last_active_ts",
                "last_federation_update_ts",
                "last_user_sync_ts",
                "status_msg",
                "currently_active",
                "instance_name",
            ),
            values=[(
                stream_id,
                state.user_id,
                state.state,
                state.last_active_ts,
                state.last_federation_update_ts,
                state.last_user_sync_ts,
                state.status_msg,
                state.currently_active,
                self._instance_name,
            ) for stream_id, state in zip(stream_orderings, presence_states)],
        )

    async def get_all_presence_updates(
            self, instance_name: str, last_id: int, current_id: int,
            limit: int) -> Tuple[List[Tuple[int, list]], int, bool]:
        """Get updates for presence replication stream.

        Args:
            instance_name: The writer we want to fetch updates from. Unused
                here since there is only ever one writer.
            last_id: The token to fetch updates from. Exclusive.
            current_id: The token to fetch updates up to. Inclusive.
            limit: The requested limit for the number of rows to return. The
                function may return more or fewer rows.

        Returns:
            A tuple consisting of: the updates, a token to use to fetch
            subsequent updates, and whether we returned fewer rows than exist
            between the requested tokens due to the limit.

            The token returned can be used in a subsequent call to this
            function to get further updates.

            The updates are a list of 2-tuples of stream ID and the row data
        """

        if last_id == current_id:
            return [], current_id, False

        def get_all_presence_updates_txn(txn):
            sql = """
                SELECT stream_id, user_id, state, last_active_ts,
                    last_federation_update_ts, last_user_sync_ts,
                    status_msg,
                currently_active
                FROM presence_stream
                WHERE ? < stream_id AND stream_id <= ?
                ORDER BY stream_id ASC
                LIMIT ?
            """
            txn.execute(sql, (last_id, current_id, limit))
            updates = [(row[0], row[1:]) for row in txn]

            upper_bound = current_id
            limited = False
            if len(updates) >= limit:
                upper_bound = updates[-1][0]
                limited = True

            return updates, upper_bound, limited

        return await self.db_pool.runInteraction("get_all_presence_updates",
                                                 get_all_presence_updates_txn)

    @cached()
    def _get_presence_for_user(self, user_id):
        raise NotImplementedError()

    @cachedList(
        cached_method_name="_get_presence_for_user",
        list_name="user_ids",
        num_args=1,
    )
    async def get_presence_for_users(self, user_ids):
        rows = await self.db_pool.simple_select_many_batch(
            table="presence_stream",
            column="user_id",
            iterable=user_ids,
            keyvalues={},
            retcols=(
                "user_id",
                "state",
                "last_active_ts",
                "last_federation_update_ts",
                "last_user_sync_ts",
                "status_msg",
                "currently_active",
            ),
            desc="get_presence_for_users",
        )

        for row in rows:
            row["currently_active"] = bool(row["currently_active"])

        return {row["user_id"]: UserPresenceState(**row) for row in rows}

    async def should_user_receive_full_presence_with_token(
        self,
        user_id: str,
        from_token: int,
    ) -> bool:
        """Check whether the given user should receive full presence using the stream token
        they're updating from.

        Args:
            user_id: The ID of the user to check.
            from_token: The stream token included in their /sync token.

        Returns:
            True if the user should have full presence sent to them, False otherwise.
        """
        def _should_user_receive_full_presence_with_token_txn(txn):
            sql = """
                SELECT 1 FROM users_to_send_full_presence_to
                WHERE user_id = ?
                AND presence_stream_id >= ?
            """
            txn.execute(sql, (user_id, from_token))
            return bool(txn.fetchone())

        return await self.db_pool.runInteraction(
            "should_user_receive_full_presence_with_token",
            _should_user_receive_full_presence_with_token_txn,
        )

    async def add_users_to_send_full_presence_to(self,
                                                 user_ids: Iterable[str]):
        """Adds to the list of users who should receive a full snapshot of presence
        upon their next sync.

        Args:
            user_ids: An iterable of user IDs.
        """
        # Add user entries to the table, updating the presence_stream_id column if the user already
        # exists in the table.
        presence_stream_id = self._presence_id_gen.get_current_token()
        await self.db_pool.simple_upsert_many(
            table="users_to_send_full_presence_to",
            key_names=("user_id", ),
            key_values=[(user_id, ) for user_id in user_ids],
            value_names=("presence_stream_id", ),
            # We save the current presence stream ID token along with the user ID entry so
            # that when a user /sync's, even if they are syncing multiple times across separate
            # devices at different times, each device will receive full presence once - when
            # the presence stream ID in their sync token is less than the one in the table
            # for their user ID.
            value_values=[(presence_stream_id, ) for _ in user_ids],
            desc="add_users_to_send_full_presence_to",
        )

    async def get_presence_for_all_users(
        self,
        include_offline: bool = True,
    ) -> Dict[str, UserPresenceState]:
        """Retrieve the current presence state for all users.

        Note that the presence_stream table is culled frequently, so it should only
        contain the latest presence state for each user.

        Args:
            include_offline: Whether to include offline presence states

        Returns:
            A dict of user IDs to their current UserPresenceState.
        """
        users_to_state = {}

        exclude_keyvalues = None
        if not include_offline:
            # Exclude offline presence state
            exclude_keyvalues = {"state": "offline"}

        # This may be a very heavy database query.
        # We paginate in order to not block a database connection.
        limit = 100
        offset = 0
        while True:
            rows = await self.db_pool.runInteraction(
                "get_presence_for_all_users",
                self.db_pool.simple_select_list_paginate_txn,
                "presence_stream",
                orderby="stream_id",
                start=offset,
                limit=limit,
                exclude_keyvalues=exclude_keyvalues,
                retcols=(
                    "user_id",
                    "state",
                    "last_active_ts",
                    "last_federation_update_ts",
                    "last_user_sync_ts",
                    "status_msg",
                    "currently_active",
                ),
                order_direction="ASC",
            )

            for row in rows:
                users_to_state[row["user_id"]] = UserPresenceState(**row)

            # We've run out of updates to query
            if len(rows) < limit:
                break

            offset += limit

        return users_to_state

    def get_current_presence_token(self):
        return self._presence_id_gen.get_current_token()

    def _get_active_presence(self, db_conn: Connection):
        """Fetch non-offline presence from the database so that we can register
        the appropriate timeouts.
        """

        # The `presence_stream_state_not_offline_idx` index should be used for this
        # query.
        sql = (
            "SELECT user_id, state, last_active_ts, last_federation_update_ts,"
            " last_user_sync_ts, status_msg, currently_active FROM presence_stream"
            " WHERE state != ?")

        txn = db_conn.cursor()
        txn.execute(sql, (PresenceState.OFFLINE, ))
        rows = self.db_pool.cursor_to_dict(txn)
        txn.close()

        for row in rows:
            row["currently_active"] = bool(row["currently_active"])

        return [UserPresenceState(**row) for row in rows]

    def take_presence_startup_info(self):
        active_on_startup = self._presence_on_startup
        self._presence_on_startup = None
        return active_on_startup

    def process_replication_rows(self, stream_name, instance_name, token,
                                 rows):
        if stream_name == PresenceStream.NAME:
            self._presence_id_gen.advance(instance_name, token)
            for row in rows:
                self.presence_stream_cache.entity_has_changed(
                    row.user_id, token)
                self._get_presence_for_user.invalidate((row.user_id, ))
        return super().process_replication_rows(stream_name, instance_name,
                                                token, rows)
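
A small read-side sketch against the PresenceStore above, assuming `store` is an initialised instance; building UserPresenceState values for update_presence is elided, and the user IDs are placeholders.

from typing import Any, List


async def nudge_full_presence(store: Any, user_ids: List[str]) -> None:
    # Mark these users as needing a full presence snapshot on their next sync.
    await store.add_users_to_send_full_presence_to(user_ids)

    token = store.get_current_presence_token()
    for user_id in user_ids:
        # Because the rows were stamped with the current stream id, this check
        # is expected to return True until the users have synced past it.
        needs_full = await store.should_user_receive_full_presence_with_token(
            user_id, from_token=token)
        print(user_id, needs_full)
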
Ejemplo n.º 27
0
class DataStore(
        EventsBackgroundUpdatesStore,
        RoomMemberStore,
        RoomStore,
        RoomBatchStore,
        RegistrationStore,
        StreamWorkerStore,
        ProfileStore,
        PresenceStore,
        TransactionWorkerStore,
        DirectoryStore,
        KeyStore,
        StateStore,
        SignatureStore,
        ApplicationServiceStore,
        PurgeEventsStore,
        EventFederationStore,
        MediaRepositoryStore,
        RejectionsStore,
        FilteringStore,
        PusherStore,
        PushRuleStore,
        ApplicationServiceTransactionStore,
        ReceiptsStore,
        EndToEndKeyStore,
        EndToEndRoomKeyStore,
        SearchStore,
        TagsStore,
        AccountDataStore,
        EventPushActionsStore,
        OpenIdStore,
        ClientIpWorkerStore,
        DeviceStore,
        DeviceInboxStore,
        UserDirectoryStore,
        GroupServerStore,
        UserErasureStore,
        MonthlyActiveUsersWorkerStore,
        StatsStore,
        RelationsStore,
        CensorEventsStore,
        UIAuthStore,
        EventForwardExtremitiesStore,
        CacheInvalidationWorkerStore,
        ServerMetricsStore,
        LockStore,
        SessionStore,
):
    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        self.hs = hs
        self._clock = hs.get_clock()
        self.database_engine = database.engine

        self._device_list_id_gen = StreamIdGenerator(
            db_conn,
            "device_lists_stream",
            "stream_id",
            extra_tables=[
                ("user_signature_stream", "stream_id"),
                ("device_lists_outbound_pokes", "stream_id"),
                ("device_lists_changes_in_room", "stream_id"),
            ],
        )

        self._cache_id_gen: Optional[MultiWriterIdGenerator]
        if isinstance(self.database_engine, PostgresEngine):
            # We set the `writers` to an empty list here as we don't care about
            # missing updates over restarts, as we'll not have anything in our
            # caches to invalidate. (This reduces the amount of writes to the DB
            # that happen).
            self._cache_id_gen = MultiWriterIdGenerator(
                db_conn,
                database,
                stream_name="caches",
                instance_name=hs.get_instance_name(),
                tables=[(
                    "cache_invalidation_stream_by_instance",
                    "instance_name",
                    "stream_id",
                )],
                sequence_name="cache_invalidation_stream_seq",
                writers=[],
            )

        else:
            self._cache_id_gen = None

        super().__init__(database, db_conn, hs)

        events_max = self._stream_id_gen.get_current_token()
        curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict(
            db_conn,
            "current_state_delta_stream",
            entity_column="room_id",
            stream_column="stream_id",
            max_value=events_max,  # As we share the stream id with events token
            limit=1000,
        )
        self._curr_state_delta_stream_cache = StreamChangeCache(
            "_curr_state_delta_stream_cache",
            min_curr_state_delta_id,
            prefilled_cache=curr_state_delta_prefill,
        )

        self._stream_order_on_start = self.get_room_max_stream_ordering()
        self._min_stream_order_on_start = self.get_room_min_stream_ordering()

    def get_device_stream_token(self) -> int:
        return self._device_list_id_gen.get_current_token()

    async def get_users(self) -> List[JsonDict]:
        """Function to retrieve a list of users in users table.

        Returns:
            A list of dictionaries representing users.
        """
        return await self.db_pool.simple_select_list(
            table="users",
            keyvalues={},
            retcols=[
                "name",
                "password_hash",
                "is_guest",
                "admin",
                "user_type",
                "deactivated",
            ],
            desc="get_users",
        )

    async def get_users_paginate(
        self,
        start: int,
        limit: int,
        user_id: Optional[str] = None,
        name: Optional[str] = None,
        guests: bool = True,
        deactivated: bool = False,
        order_by: str = UserSortOrder.USER_ID.value,
        direction: str = "f",
    ) -> Tuple[List[JsonDict], int]:
        """Function to retrieve a paginated list of users from
        users list. This will return a json list of users and the
        total number of users matching the filter criteria.

        Args:
            start: start number to begin the query from
            limit: number of rows to retrieve
            user_id: search for user_id. Ignored if name is not None.
            name: search for local part of user_id or display name
            guests: whether to include guest users
            deactivated: whether to include deactivated users
            order_by: the sort order of the returned list
            direction: sort ascending or descending
        Returns:
            A tuple of a list of mappings from user to information and a count of total users.
        """
        def get_users_paginate_txn(
            txn: LoggingTransaction, ) -> Tuple[List[JsonDict], int]:
            filters = []
            args = [self.hs.config.server.server_name]

            # Set ordering
            order_by_column = UserSortOrder(order_by).value

            if direction == "b":
                order = "DESC"
            else:
                order = "ASC"

            # `name` is in database already in lower case
            if name:
                filters.append("(name LIKE ? OR LOWER(displayname) LIKE ?)")
                args.extend(
                    ["@%" + name.lower() + "%:%", "%" + name.lower() + "%"])
            elif user_id:
                filters.append("name LIKE ?")
                args.extend(["%" + user_id.lower() + "%"])

            if not guests:
                filters.append("is_guest = 0")

            if not deactivated:
                filters.append("deactivated = 0")

            where_clause = "WHERE " + " AND ".join(filters) if len(
                filters) > 0 else ""

            sql_base = f"""
                FROM users as u
                LEFT JOIN profiles AS p ON u.name = '@' || p.user_id || ':' || ?
                {where_clause}
                """
            sql = "SELECT COUNT(*) as total_users " + sql_base
            txn.execute(sql, args)
            count = cast(Tuple[int], txn.fetchone())[0]

            sql = f"""
                SELECT name, user_type, is_guest, admin, deactivated, shadow_banned,
                displayname, avatar_url, creation_ts * 1000 as creation_ts
                {sql_base}
                ORDER BY {order_by_column} {order}, u.name ASC
                LIMIT ? OFFSET ?
            """
            args += [limit, start]
            txn.execute(sql, args)
            users = self.db_pool.cursor_to_dict(txn)
            return users, count

        return await self.db_pool.runInteraction("get_users_paginate_txn",
                                                 get_users_paginate_txn)

    async def search_users(self, term: str) -> Optional[List[JsonDict]]:
        """Function to search users list for one or more users with
        the matched term.

        Args:
            term: search term

        Returns:
            A list of dictionaries or None.
        """
        return await self.db_pool.simple_search_list(
            table="users",
            term=term,
            col="name",
            retcols=[
                "name", "password_hash", "is_guest", "admin", "user_type"
            ],
            desc="search_users",
        )
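
A hedged sketch of calling the admin-style queries above; `store` is assumed to be a constructed DataStore and the parameter values are illustrative.

from typing import Any


async def list_and_search(store: Any) -> None:
    # Page through users ten at a time with the default ordering, including
    # guests but excluding deactivated accounts.
    users, total = await store.get_users_paginate(start=0, limit=10)
    print(f"showing {len(users)} of {total} users")

    # Simple substring search against the name column.
    matches = await store.search_users("alice")
    print(matches or [])
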
Ejemplo n.º 28
0
    def __init__(self, db_conn, hs):
        self._account_data_id_gen = StreamIdGenerator(
            db_conn, "account_data_max_stream_id", "stream_id"
        )

        super(AccountDataStore, self).__init__(db_conn, hs)
Ejemplo n.º 29
0
class ReceiptsStore(ReceiptsWorkerStore):
    def __init__(self, database: Database, db_conn, hs):
        # We instantiate this first as the ReceiptsWorkerStore constructor
        # needs to be able to call get_max_receipt_stream_id
        self._receipts_id_gen = StreamIdGenerator(db_conn,
                                                  "receipts_linearized",
                                                  "stream_id")

        super(ReceiptsStore, self).__init__(database, db_conn, hs)

    def get_max_receipt_stream_id(self):
        return self._receipts_id_gen.get_current_token()

    def insert_linearized_receipt_txn(self, txn, room_id, receipt_type,
                                      user_id, event_id, data, stream_id):
        """Inserts a read-receipt into the database if it's newer than the current RR

        Returns: int|None
            None if the RR is older than the current RR
            otherwise, the rx timestamp of the event that the RR corresponds to
                (or 0 if the event is unknown)
        """
        res = self.db.simple_select_one_txn(
            txn,
            table="events",
            retcols=["stream_ordering", "received_ts"],
            keyvalues={"event_id": event_id},
            allow_none=True,
        )

        stream_ordering = int(res["stream_ordering"]) if res else None
        rx_ts = res["received_ts"] if res else 0

        # We don't want to clobber receipts for more recent events, so we
        # have to compare orderings of existing receipts
        if stream_ordering is not None:
            sql = (
                "SELECT stream_ordering, event_id FROM events"
                " INNER JOIN receipts_linearized as r USING (event_id, room_id)"
                " WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ?"
            )
            txn.execute(sql, (room_id, receipt_type, user_id))

            for so, eid in txn:
                if int(so) >= stream_ordering:
                    logger.debug(
                        "Ignoring new receipt for %s in favour of existing "
                        "one for later event %s",
                        event_id,
                        eid,
                    )
                    return None

        txn.call_after(self.get_receipts_for_room.invalidate,
                       (room_id, receipt_type))
        txn.call_after(
            self._invalidate_get_users_with_receipts_in_room,
            room_id,
            receipt_type,
            user_id,
        )
        txn.call_after(self.get_receipts_for_user.invalidate,
                       (user_id, receipt_type))
        # FIXME: This shouldn't invalidate the whole cache
        txn.call_after(self._get_linearized_receipts_for_room.invalidate_many,
                       (room_id, ))

        txn.call_after(self._receipts_stream_cache.entity_has_changed, room_id,
                       stream_id)

        txn.call_after(
            self.get_last_receipt_event_id_for_user.invalidate,
            (user_id, room_id, receipt_type),
        )

        self.db.simple_upsert_txn(
            txn,
            table="receipts_linearized",
            keyvalues={
                "room_id": room_id,
                "receipt_type": receipt_type,
                "user_id": user_id,
            },
            values={
                "stream_id": stream_id,
                "event_id": event_id,
                "data": json.dumps(data),
            },
            # receipts_linearized has a unique constraint on
            # (user_id, room_id, receipt_type), so no need to lock
            lock=False,
        )

        if receipt_type == "m.read" and stream_ordering is not None:
            self._remove_old_push_actions_before_txn(
                txn,
                room_id=room_id,
                user_id=user_id,
                stream_ordering=stream_ordering)

        return rx_ts

    @defer.inlineCallbacks
    def insert_receipt(self, room_id, receipt_type, user_id, event_ids, data):
        """Insert a receipt, either from local client or remote server.

        Automatically does conversion between linearized and graph
        representations.
        """
        if not event_ids:
            return

        if len(event_ids) == 1:
            linearized_event_id = event_ids[0]
        else:
            # we need to convert the points in the receipt graph into linearized form.
            # TODO: Make this better.
            def graph_to_linear(txn):
                clause, args = make_in_list_sql_clause(self.database_engine,
                                                       "event_id", event_ids)

                sql = """
                    SELECT event_id WHERE room_id = ? AND stream_ordering IN (
                        SELECT max(stream_ordering) WHERE %s
                    )
                """ % (clause, )

                txn.execute(sql, [room_id] + list(args))
                rows = txn.fetchall()
                if rows:
                    return rows[0][0]
                else:
                    raise RuntimeError("Unrecognized event_ids: %r" %
                                       (event_ids, ))

            linearized_event_id = yield self.db.runInteraction(
                "insert_receipt_conv", graph_to_linear)

        stream_id_manager = self._receipts_id_gen.get_next()
        with stream_id_manager as stream_id:
            event_ts = yield self.db.runInteraction(
                "insert_linearized_receipt",
                self.insert_linearized_receipt_txn,
                room_id,
                receipt_type,
                user_id,
                linearized_event_id,
                data,
                stream_id=stream_id,
            )

        if event_ts is None:
            return None

        now = self._clock.time_msec()
        logger.debug(
            "RR for event %s in %s (%i ms old)",
            linearized_event_id,
            room_id,
            now - event_ts,
        )

        yield self.insert_graph_receipt(room_id, receipt_type, user_id,
                                        event_ids, data)

        max_persisted_id = self._receipts_id_gen.get_current_token()

        return stream_id, max_persisted_id

    def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids,
                             data):
        return self.db.runInteraction(
            "insert_graph_receipt",
            self.insert_graph_receipt_txn,
            room_id,
            receipt_type,
            user_id,
            event_ids,
            data,
        )

    def insert_graph_receipt_txn(self, txn, room_id, receipt_type, user_id,
                                 event_ids, data):
        txn.call_after(self.get_receipts_for_room.invalidate,
                       (room_id, receipt_type))
        txn.call_after(
            self._invalidate_get_users_with_receipts_in_room,
            room_id,
            receipt_type,
            user_id,
        )
        txn.call_after(self.get_receipts_for_user.invalidate,
                       (user_id, receipt_type))
        # FIXME: This shouldn't invalidate the whole cache
        txn.call_after(self._get_linearized_receipts_for_room.invalidate_many,
                       (room_id, ))

        self.db.simple_delete_txn(
            txn,
            table="receipts_graph",
            keyvalues={
                "room_id": room_id,
                "receipt_type": receipt_type,
                "user_id": user_id,
            },
        )
        self.db.simple_insert_txn(
            txn,
            table="receipts_graph",
            values={
                "room_id": room_id,
                "receipt_type": receipt_type,
                "user_id": user_id,
                "event_ids": json.dumps(event_ids),
                "data": json.dumps(data),
            },
        )
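
This last example still uses Twisted's inlineCallbacks style, so a caller is sketched the same way; `store` is assumed to be an initialised ReceiptsStore and the room, user and event IDs are placeholders.

from twisted.internet import defer


@defer.inlineCallbacks
def mark_read(store, room_id, user_id, event_id):
    # insert_receipt handles both the linearized and graph representations and
    # returns (stream_id, max_persisted_id), or None if the receipt was older
    # than one we already hold.
    result = yield store.insert_receipt(
        room_id, "m.read", user_id, [event_id], data={"ts": 1234567890})
    return result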