Example 1
class DeviceInboxWorkerStore(SQLBaseStore):
    def __init__(self, database: DatabasePool, db_conn, hs):
        super().__init__(database, db_conn, hs)

        self._instance_name = hs.get_instance_name()

        # Map of (user_id, device_id) to the last stream_id that has been
        # deleted up to. This lets us turn redundant deletions into no-ops.
        self._last_device_delete_cache = ExpiringCache(
            cache_name="last_device_delete_cache",
            clock=self._clock,
            max_len=10000,
            expiry_ms=30 * 60 * 1000,
        )

        if isinstance(database.engine, PostgresEngine):
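            # Several workers may be listed as writers for the to-device stream
            # when running on Postgres, so a MultiWriterIdGenerator is used;
            # the non-Postgres branch below has a single writer and a plain
            # StreamIdGenerator is enough.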
            self._can_write_to_device = (self._instance_name
                                         in hs.config.worker.writers.to_device)

            self._device_inbox_id_gen = MultiWriterIdGenerator(
                db_conn=db_conn,
                db=database,
                stream_name="to_device",
                instance_name=self._instance_name,
                tables=[("device_inbox", "instance_name", "stream_id")],
                sequence_name="device_inbox_sequence",
                writers=hs.config.worker.writers.to_device,
            )
        else:
            self._can_write_to_device = True
            self._device_inbox_id_gen = StreamIdGenerator(
                db_conn, "device_inbox", "stream_id")

        max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
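        # Prefill the stream change caches with the most recent entries so we
        # can cheaply answer "has anything changed for this entity since stream
        # X?" without hitting the database every time.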
        device_inbox_prefill, min_device_inbox_id = self.db_pool.get_cache_dict(
            db_conn,
            "device_inbox",
            entity_column="user_id",
            stream_column="stream_id",
            max_value=max_device_inbox_id,
            limit=1000,
        )
        self._device_inbox_stream_cache = StreamChangeCache(
            "DeviceInboxStreamChangeCache",
            min_device_inbox_id,
            prefilled_cache=device_inbox_prefill,
        )

        # The federation outbox and the local device inbox use the same
        # stream_id generator.
        device_outbox_prefill, min_device_outbox_id = self.db_pool.get_cache_dict(
            db_conn,
            "device_federation_outbox",
            entity_column="destination",
            stream_column="stream_id",
            max_value=max_device_inbox_id,
            limit=1000,
        )
        self._device_federation_outbox_stream_cache = StreamChangeCache(
            "DeviceFederationOutboxStreamChangeCache",
            min_device_outbox_id,
            prefilled_cache=device_outbox_prefill,
        )

    def process_replication_rows(self, stream_name, instance_name, token,
                                 rows):
        if stream_name == ToDeviceStream.NAME:
            self._device_inbox_id_gen.advance(instance_name, token)
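            # Each row's entity is either a local user ID (starting with "@"),
            # which keys the device inbox cache, or a remote destination, which
            # keys the federation outbox cache.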
            for row in rows:
                if row.entity.startswith("@"):
                    self._device_inbox_stream_cache.entity_has_changed(
                        row.entity, token)
                else:
                    self._device_federation_outbox_stream_cache.entity_has_changed(
                        row.entity, token)
        return super().process_replication_rows(stream_name, instance_name,
                                                token, rows)

    def get_to_device_stream_token(self):
        return self._device_inbox_id_gen.get_current_token()

    async def get_new_messages_for_device(
        self,
        user_id: str,
        device_id: str,
        last_stream_id: int,
        current_stream_id: int,
        limit: int = 100,
    ) -> Tuple[List[dict], int]:
        """
        Args:
            user_id: The recipient user_id.
            device_id: The recipient device_id.
            last_stream_id: The last stream ID checked.
            current_stream_id: The current position of the to device
                message stream.
            limit: The maximum number of messages to retrieve.

        Returns:
            A list of messages for the device and where in the stream the messages got to.
        """
        has_changed = self._device_inbox_stream_cache.has_entity_changed(
            user_id, last_stream_id)
        if not has_changed:
            return ([], current_stream_id)

        def get_new_messages_for_device_txn(txn):
            sql = ("SELECT stream_id, message_json FROM device_inbox"
                   " WHERE user_id = ? AND device_id = ?"
                   " AND ? < stream_id AND stream_id <= ?"
                   " ORDER BY stream_id ASC"
                   " LIMIT ?")
            txn.execute(
                sql,
                (user_id, device_id, last_stream_id, current_stream_id, limit))
            messages = []
            for row in txn:
                stream_pos = row[0]
                messages.append(db_to_json(row[1]))
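            # If we returned fewer rows than the limit then we have exhausted
            # the range, so report current_stream_id as the position reached;
            # otherwise the caller should page again from stream_pos.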
            if len(messages) < limit:
                stream_pos = current_stream_id
            return messages, stream_pos

        return await self.db_pool.runInteraction(
            "get_new_messages_for_device", get_new_messages_for_device_txn)

    @trace
    async def delete_messages_for_device(self, user_id: str, device_id: str,
                                         up_to_stream_id: int) -> int:
        """
        Args:
            user_id: The recipient user_id.
            device_id: The recipient device_id.
            up_to_stream_id: Where to delete messages up to.

        Returns:
            The number of messages deleted.
        """
        # If we have cached the last stream id we've deleted up to, we can
        # check if there is likely to be anything that needs deleting
        last_deleted_stream_id = self._last_device_delete_cache.get(
            (user_id, device_id), None)

        set_tag("last_deleted_stream_id", last_deleted_stream_id)

        if last_deleted_stream_id:
            has_changed = self._device_inbox_stream_cache.has_entity_changed(
                user_id, last_deleted_stream_id)
            if not has_changed:
                log_kv({"message": "No changes in cache since last check"})
                return 0

        def delete_messages_for_device_txn(txn):
            sql = ("DELETE FROM device_inbox"
                   " WHERE user_id = ? AND device_id = ?"
                   " AND stream_id <= ?")
            txn.execute(sql, (user_id, device_id, up_to_stream_id))
            return txn.rowcount

        count = await self.db_pool.runInteraction(
            "delete_messages_for_device", delete_messages_for_device_txn)

        log_kv({
            "message": "deleted {} messages for device".format(count),
            "count": count
        })

        # Update the cache, ensuring that we only ever increase the value
        last_deleted_stream_id = self._last_device_delete_cache.get(
            (user_id, device_id), 0)
        self._last_device_delete_cache[(user_id, device_id)] = max(
            last_deleted_stream_id, up_to_stream_id)

        return count

    @trace
    async def get_new_device_msgs_for_remote(self, destination, last_stream_id,
                                             current_stream_id,
                                             limit) -> Tuple[List[dict], int]:
        """
        Args:
            destination(str): The name of the remote server.
            last_stream_id(int|long): The last position of the device message stream
                that the server sent up to.
            current_stream_id(int|long): The current position of the device
                message stream.
        Returns:
            A list of messages for the device and where in the stream the messages got to.
        """

        set_tag("destination", destination)
        set_tag("last_stream_id", last_stream_id)
        set_tag("current_stream_id", current_stream_id)
        set_tag("limit", limit)

        has_changed = self._device_federation_outbox_stream_cache.has_entity_changed(
            destination, last_stream_id)
        if not has_changed or last_stream_id == current_stream_id:
            log_kv({"message": "No new messages in stream"})
            return ([], current_stream_id)

        if limit <= 0:
            # This can happen if we run out of room for EDUs in the transaction.
            return ([], last_stream_id)

        @trace
        def get_new_messages_for_remote_destination_txn(txn):
            sql = (
                "SELECT stream_id, messages_json FROM device_federation_outbox"
                " WHERE destination = ?"
                " AND ? < stream_id AND stream_id <= ?"
                " ORDER BY stream_id ASC"
                " LIMIT ?")
            txn.execute(
                sql, (destination, last_stream_id, current_stream_id, limit))
            messages = []
            for row in txn:
                stream_pos = row[0]
                messages.append(db_to_json(row[1]))
            if len(messages) < limit:
                log_kv({"message": "Set stream position to current position"})
                stream_pos = current_stream_id
            return messages, stream_pos

        return await self.db_pool.runInteraction(
            "get_new_device_msgs_for_remote",
            get_new_messages_for_remote_destination_txn,
        )

    @trace
    async def delete_device_msgs_for_remote(self, destination: str,
                                            up_to_stream_id: int) -> None:
        """Used to delete messages when the remote destination acknowledges
        their receipt.

        Args:
            destination: The destination server_name
            up_to_stream_id: Where to delete messages up to.
        """
        def delete_messages_for_remote_destination_txn(txn):
            sql = ("DELETE FROM device_federation_outbox"
                   " WHERE destination = ?"
                   " AND stream_id <= ?")
            txn.execute(sql, (destination, up_to_stream_id))

        await self.db_pool.runInteraction(
            "delete_device_msgs_for_remote",
            delete_messages_for_remote_destination_txn)

    async def get_all_new_device_messages(
            self, instance_name: str, last_id: int, current_id: int,
            limit: int) -> Tuple[List[Tuple[int, tuple]], int, bool]:
        """Get updates for to device replication stream.

        Args:
            instance_name: The writer we want to fetch updates from. Unused
                here since there is only ever one writer.
            last_id: The token to fetch updates from. Exclusive.
            current_id: The token to fetch updates up to. Inclusive.
            limit: The requested limit for the number of rows to return. The
                function may return more or fewer rows.

        Returns:
            A tuple consisting of: the updates, a token to use to fetch
            subsequent updates, and whether we returned fewer rows than exists
            between the requested tokens due to the limit.

            The token returned can be used in a subsequent call to this
            function to get further updates.

            The updates are a list of 2-tuples of stream ID and the row data
        """

        if last_id == current_id:
            return [], current_id, False

        def get_all_new_device_messages_txn(txn):
            # We limit like this as we might have multiple rows per stream_id, and
            # we want to make sure we always get all entries for any stream_id
            # we return.
            upper_pos = min(current_id, last_id + limit)
            sql = ("SELECT max(stream_id), user_id"
                   " FROM device_inbox"
                   " WHERE ? < stream_id AND stream_id <= ?"
                   " GROUP BY user_id")
            txn.execute(sql, (last_id, upper_pos))
            updates = [(row[0], row[1:]) for row in txn]

            sql = ("SELECT max(stream_id), destination"
                   " FROM device_federation_outbox"
                   " WHERE ? < stream_id AND stream_id <= ?"
                   " GROUP BY destination")
            txn.execute(sql, (last_id, upper_pos))
            updates.extend((row[0], row[1:]) for row in txn)

            # Order by ascending stream ordering
            updates.sort()

            limited = False
            upto_token = current_id
            if len(updates) >= limit:
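                # We may not have covered everything up to upper_pos, so report
                # the stream ID of the last row returned and flag the result as
                # limited.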
                upto_token = updates[-1][0]
                limited = True

            return updates, upto_token, limited

        return await self.db_pool.runInteraction(
            "get_all_new_device_messages", get_all_new_device_messages_txn)

    @trace
    async def add_messages_to_device_inbox(
        self,
        local_messages_by_user_then_device: dict,
        remote_messages_by_destination: dict,
    ) -> int:
        """Used to send messages from this server.

        Args:
            local_messages_by_user_then_device:
                Dictionary of user_id to device_id to message.
            remote_messages_by_destination:
                Dictionary of destination server_name to the EDU JSON to send.

        Returns:
            The new stream_id.
        """

        assert self._can_write_to_device

        def add_messages_txn(txn, now_ms, stream_id):
            # Add the local messages directly to the local inbox.
            self._add_messages_to_local_device_inbox_txn(
                txn, stream_id, local_messages_by_user_then_device)

            # Add the remote messages to the federation outbox.
            # We'll send them to a remote server when we next send a
            # federation transaction to that destination.
            self.db_pool.simple_insert_many_txn(
                txn,
                table="device_federation_outbox",
                values=[{
                    "destination": destination,
                    "stream_id": stream_id,
                    "queued_ts": now_ms,
                    "messages_json": json_encoder.encode(edu),
                    "instance_name": self._instance_name,
                } for destination, edu in
                        remote_messages_by_destination.items()],
            )

        async with self._device_inbox_id_gen.get_next() as stream_id:
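            # The stream ID reserved here is only marked as persisted once this
            # block exits, i.e. after the transaction has committed and the
            # stream change caches below have been poked.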
            now_ms = self._clock.time_msec()
            await self.db_pool.runInteraction("add_messages_to_device_inbox",
                                              add_messages_txn, now_ms,
                                              stream_id)
            for user_id in local_messages_by_user_then_device.keys():
                self._device_inbox_stream_cache.entity_has_changed(
                    user_id, stream_id)
            for destination in remote_messages_by_destination.keys():
                self._device_federation_outbox_stream_cache.entity_has_changed(
                    destination, stream_id)

        return self._device_inbox_id_gen.get_current_token()

    async def add_messages_from_remote_to_device_inbox(
            self, origin: str, message_id: str,
            local_messages_by_user_then_device: dict) -> int:
        assert self._can_write_to_device

        def add_messages_txn(txn, now_ms, stream_id):
            # Check if we've already inserted a matching message_id for that
            # origin. This can happen if the origin doesn't receive our
            # acknowledgement from the first time we received the message.
            already_inserted = self.db_pool.simple_select_one_txn(
                txn,
                table="device_federation_inbox",
                keyvalues={
                    "origin": origin,
                    "message_id": message_id
                },
                retcols=("message_id", ),
                allow_none=True,
            )
            if already_inserted is not None:
                return

            # Add an entry for this message_id so that we know we've processed
            # it.
            self.db_pool.simple_insert_txn(
                txn,
                table="device_federation_inbox",
                values={
                    "origin": origin,
                    "message_id": message_id,
                    "received_ts": now_ms,
                },
            )

            # Add the messages to the appropriate local device inboxes so that
            # they'll be sent to the devices when they next sync.
            self._add_messages_to_local_device_inbox_txn(
                txn, stream_id, local_messages_by_user_then_device)

        async with self._device_inbox_id_gen.get_next() as stream_id:
            now_ms = self._clock.time_msec()
            await self.db_pool.runInteraction(
                "add_messages_from_remote_to_device_inbox",
                add_messages_txn,
                now_ms,
                stream_id,
            )
            for user_id in local_messages_by_user_then_device.keys():
                self._device_inbox_stream_cache.entity_has_changed(
                    user_id, stream_id)

        return stream_id

    def _add_messages_to_local_device_inbox_txn(self, txn, stream_id,
                                                messages_by_user_then_device):
        assert self._can_write_to_device

        local_by_user_then_device = {}
        for user_id, messages_by_device in messages_by_user_then_device.items(
        ):
            messages_json_for_user = {}
            devices = list(messages_by_device.keys())
            if len(devices) == 1 and devices[0] == "*":
                # Handle wildcard device_ids.
                devices = self.db_pool.simple_select_onecol_txn(
                    txn,
                    table="devices",
                    keyvalues={"user_id": user_id},
                    retcol="device_id",
                )

                message_json = json_encoder.encode(messages_by_device["*"])
                for device_id in devices:
                    # Add the message for all devices for this user on this
                    # server.
                    messages_json_for_user[device_id] = message_json
            else:
                if not devices:
                    continue

                rows = self.db_pool.simple_select_many_txn(
                    txn,
                    table="devices",
                    keyvalues={"user_id": user_id},
                    column="device_id",
                    iterable=devices,
                    retcols=("device_id", ),
                )

                for row in rows:
                    # Only insert into the local inbox if the device exists on
                    # this server
                    device_id = row["device_id"]
                    message_json = json_encoder.encode(
                        messages_by_device[device_id])
                    messages_json_for_user[device_id] = message_json

            if messages_json_for_user:
                local_by_user_then_device[user_id] = messages_json_for_user

        if not local_by_user_then_device:
            return

        self.db_pool.simple_insert_many_txn(
            txn,
            table="device_inbox",
            values=[{
                "user_id": user_id,
                "device_id": device_id,
                "stream_id": stream_id,
                "message_json": message_json,
                "instance_name": self._instance_name,
            } for user_id, messages_by_device in
                    local_by_user_then_device.items()
                    for device_id, message_json in messages_by_device.items()],
        )
Example 2
class AccountDataStore(AccountDataWorkerStore):
    def __init__(self, db_conn, hs):
        self._account_data_id_gen = StreamIdGenerator(
            db_conn, "account_data_max_stream_id", "stream_id"
        )

        super(AccountDataStore, self).__init__(db_conn, hs)

    def get_max_account_data_stream_id(self):
        """Get the current max stream id for the private user data stream

        Returns:
            The current maximum stream ID.
        """
        return self._account_data_id_gen.get_current_token()

    @defer.inlineCallbacks
    def add_account_data_to_room(self, user_id, room_id, account_data_type, content):
        """Add some account_data to a room for a user.
        Args:
            user_id(str): The user to add the account_data for.
            room_id(str): The room to add the account_data for.
            account_data_type(str): The type of account_data to add.
            content(dict): A json object to associate with the account_data.
        Returns:
            A deferred which resolves to the current max stream ID once the
            account_data has been added.
        """
        content_json = json.dumps(content)

        with self._account_data_id_gen.get_next() as next_id:
            # no need to lock here as room_account_data has a unique constraint
            # on (user_id, room_id, account_data_type) so _simple_upsert will
            # retry if there is a conflict.
            yield self._simple_upsert(
                desc="add_room_account_data",
                table="room_account_data",
                keyvalues={
                    "user_id": user_id,
                    "room_id": room_id,
                    "account_data_type": account_data_type,
                },
                values={"stream_id": next_id, "content": content_json},
                lock=False,
            )

            # it's theoretically possible for the above to succeed and the
            # below to fail - in which case we might reuse a stream id on
            # restart, and the above update might not get propagated. That
            # doesn't sound any worse than the whole update getting lost,
            # which is what would happen if we combined the two into one
            # transaction.
            yield self._update_max_stream_id(next_id)

            self._account_data_stream_cache.entity_has_changed(user_id, next_id)
            self.get_account_data_for_user.invalidate((user_id,))
            self.get_account_data_for_room.invalidate((user_id, room_id))
            self.get_account_data_for_room_and_type.prefill(
                (user_id, room_id, account_data_type), content
            )

        result = self._account_data_id_gen.get_current_token()
        return result

    @defer.inlineCallbacks
    def add_account_data_for_user(self, user_id, account_data_type, content):
        """Add some account_data to a room for a user.
        Args:
            user_id(str): The user to add a tag for.
            account_data_type(str): The type of account_data to add.
            content(dict): A json object to associate with the tag.
        Returns:
            A deferred that completes once the account_data has been added.
        """
        content_json = json.dumps(content)

        with self._account_data_id_gen.get_next() as next_id:
            # no need to lock here as account_data has a unique constraint on
            # (user_id, account_data_type) so _simple_upsert will retry if
            # there is a conflict.
            yield self._simple_upsert(
                desc="add_user_account_data",
                table="account_data",
                keyvalues={"user_id": user_id, "account_data_type": account_data_type},
                values={"stream_id": next_id, "content": content_json},
                lock=False,
            )

            # it's theoretically possible for the above to succeed and the
            # below to fail - in which case we might reuse a stream id on
            # restart, and the above update might not get propagated. That
            # doesn't sound any worse than the whole update getting lost,
            # which is what would happen if we combined the two into one
            # transaction.
            yield self._update_max_stream_id(next_id)

            self._account_data_stream_cache.entity_has_changed(user_id, next_id)
            self.get_account_data_for_user.invalidate((user_id,))
            self.get_global_account_data_by_type_for_user.invalidate(
                (account_data_type, user_id)
            )

        result = self._account_data_id_gen.get_current_token()
        return result

    def _update_max_stream_id(self, next_id):
        """Update the max stream_id

        Args:
            next_id(int): The revision to advance to.
        """

        def _update(txn):
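            # The WHERE clause ensures we only ever move the stored maximum
            # forwards; a stale writer cannot move it backwards.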
            update_max_id_sql = (
                "UPDATE account_data_max_stream_id"
                " SET stream_id = ?"
                " WHERE stream_id < ?"
            )
            txn.execute(update_max_id_sql, (next_id, next_id))

        return self.runInteraction("update_account_data_max_stream_id", _update)
Example 3
class AccountDataStore(AccountDataWorkerStore):
    def __init__(self, database: DatabasePool, db_conn, hs):
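        # extra_tables makes the generator also take into account stream IDs
        # already used by room_account_data and room_tags_revisions, not just
        # account_data_max_stream_id.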
        self._account_data_id_gen = StreamIdGenerator(
            db_conn,
            "account_data_max_stream_id",
            "stream_id",
            extra_tables=[
                ("room_account_data", "stream_id"),
                ("room_tags_revisions", "stream_id"),
            ],
        )

        super(AccountDataStore, self).__init__(database, db_conn, hs)

    def get_max_account_data_stream_id(self) -> int:
        """Get the current max stream id for the private user data stream

        Returns:
            The maximum stream ID.
        """
        return self._account_data_id_gen.get_current_token()

    async def add_account_data_to_room(self, user_id: str, room_id: str,
                                       account_data_type: str,
                                       content: JsonDict) -> int:
        """Add some account_data to a room for a user.

        Args:
            user_id: The user to add the account_data for.
            room_id: The room to add the account_data for.
            account_data_type: The type of account_data to add.
            content: A json object to associate with the account_data.

        Returns:
            The maximum stream ID.
        """
        content_json = json_encoder.encode(content)

        with self._account_data_id_gen.get_next() as next_id:
            # no need to lock here as room_account_data has a unique constraint
            # on (user_id, room_id, account_data_type) so simple_upsert will
            # retry if there is a conflict.
            await self.db_pool.simple_upsert(
                desc="add_room_account_data",
                table="room_account_data",
                keyvalues={
                    "user_id": user_id,
                    "room_id": room_id,
                    "account_data_type": account_data_type,
                },
                values={
                    "stream_id": next_id,
                    "content": content_json
                },
                lock=False,
            )

            # it's theoretically possible for the above to succeed and the
            # below to fail - in which case we might reuse a stream id on
            # restart, and the above update might not get propagated. That
            # doesn't sound any worse than the whole update getting lost,
            # which is what would happen if we combined the two into one
            # transaction.
            await self._update_max_stream_id(next_id)

            self._account_data_stream_cache.entity_has_changed(
                user_id, next_id)
            self.get_account_data_for_user.invalidate((user_id, ))
            self.get_account_data_for_room.invalidate((user_id, room_id))
            self.get_account_data_for_room_and_type.prefill(
                (user_id, room_id, account_data_type), content)

        return self._account_data_id_gen.get_current_token()

    async def add_account_data_for_user(self, user_id: str,
                                        account_data_type: str,
                                        content: JsonDict) -> int:
        """Add some account_data to a room for a user.

        Args:
            user_id: The user to add a tag for.
            account_data_type: The type of account_data to add.
            content: A json object to associate with the tag.

        Returns:
            The maximum stream ID.
        """
        content_json = json_encoder.encode(content)

        with self._account_data_id_gen.get_next() as next_id:
            # no need to lock here as account_data has a unique constraint on
            # (user_id, account_data_type) so simple_upsert will retry if
            # there is a conflict.
            await self.db_pool.simple_upsert(
                desc="add_user_account_data",
                table="account_data",
                keyvalues={
                    "user_id": user_id,
                    "account_data_type": account_data_type
                },
                values={
                    "stream_id": next_id,
                    "content": content_json
                },
                lock=False,
            )

            # it's theoretically possible for the above to succeed and the
            # below to fail - in which case we might reuse a stream id on
            # restart, and the above update might not get propagated. That
            # doesn't sound any worse than the whole update getting lost,
            # which is what would happen if we combined the two into one
            # transaction.
            #
            # Note: This is only here for backwards compat to allow admins to
            # roll back to a previous Synapse version. Next time we update the
            # database version we can remove this table.
            await self._update_max_stream_id(next_id)

            self._account_data_stream_cache.entity_has_changed(
                user_id, next_id)
            self.get_account_data_for_user.invalidate((user_id, ))
            self.get_global_account_data_by_type_for_user.invalidate(
                (account_data_type, user_id))

        return self._account_data_id_gen.get_current_token()

    def _update_max_stream_id(self, next_id: int):
        """Update the max stream_id

        Args:
            next_id: The revision to advance to.
        """

        # Note: This is only here for backwards compat to allow admins to
        # roll back to a previous Synapse version. Next time we update the
        # database version we can remove this table.

        def _update(txn):
            update_max_id_sql = ("UPDATE account_data_max_stream_id"
                                 " SET stream_id = ?"
                                 " WHERE stream_id < ?")
            txn.execute(update_max_id_sql, (next_id, next_id))

        return self.db_pool.runInteraction("update_account_data_max_stream_id",
                                           _update)
Example 4
class ReceiptsStore(ReceiptsWorkerStore):
    def __init__(self, database: Database, db_conn, hs):
        # We instantiate this first as the ReceiptsWorkerStore constructor
        # needs to be able to call get_max_receipt_stream_id
        self._receipts_id_gen = StreamIdGenerator(db_conn,
                                                  "receipts_linearized",
                                                  "stream_id")

        super(ReceiptsStore, self).__init__(database, db_conn, hs)

    def get_max_receipt_stream_id(self):
        return self._receipts_id_gen.get_current_token()

    def insert_linearized_receipt_txn(self, txn, room_id, receipt_type,
                                      user_id, event_id, data, stream_id):
        """Inserts a read-receipt into the database if it's newer than the current RR

        Returns: int|None
            None if the RR is older than the current RR
            otherwise, the rx timestamp of the event that the RR corresponds to
                (or 0 if the event is unknown)
        """
        res = self.db.simple_select_one_txn(
            txn,
            table="events",
            retcols=["stream_ordering", "received_ts"],
            keyvalues={"event_id": event_id},
            allow_none=True,
        )

        stream_ordering = int(res["stream_ordering"]) if res else None
        rx_ts = res["received_ts"] if res else 0

        # We don't want to clobber receipts for more recent events, so we
        # have to compare orderings of existing receipts
        if stream_ordering is not None:
            sql = (
                "SELECT stream_ordering, event_id FROM events"
                " INNER JOIN receipts_linearized as r USING (event_id, room_id)"
                " WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ?"
            )
            txn.execute(sql, (room_id, receipt_type, user_id))

            for so, eid in txn:
                if int(so) >= stream_ordering:
                    logger.debug(
                        "Ignoring new receipt for %s in favour of existing "
                        "one for later event %s",
                        event_id,
                        eid,
                    )
                    return None

        txn.call_after(self.get_receipts_for_room.invalidate,
                       (room_id, receipt_type))
        txn.call_after(
            self._invalidate_get_users_with_receipts_in_room,
            room_id,
            receipt_type,
            user_id,
        )
        txn.call_after(self.get_receipts_for_user.invalidate,
                       (user_id, receipt_type))
        # FIXME: This shouldn't invalidate the whole cache
        txn.call_after(self._get_linearized_receipts_for_room.invalidate_many,
                       (room_id, ))

        txn.call_after(self._receipts_stream_cache.entity_has_changed, room_id,
                       stream_id)

        txn.call_after(
            self.get_last_receipt_event_id_for_user.invalidate,
            (user_id, room_id, receipt_type),
        )

        self.db.simple_upsert_txn(
            txn,
            table="receipts_linearized",
            keyvalues={
                "room_id": room_id,
                "receipt_type": receipt_type,
                "user_id": user_id,
            },
            values={
                "stream_id": stream_id,
                "event_id": event_id,
                "data": json.dumps(data),
            },
            # receipts_linearized has a unique constraint on
            # (user_id, room_id, receipt_type), so no need to lock
            lock=False,
        )

        if receipt_type == "m.read" and stream_ordering is not None:
            self._remove_old_push_actions_before_txn(
                txn,
                room_id=room_id,
                user_id=user_id,
                stream_ordering=stream_ordering)

        return rx_ts

    @defer.inlineCallbacks
    def insert_receipt(self, room_id, receipt_type, user_id, event_ids, data):
        """Insert a receipt, either from local client or remote server.

        Automatically does conversion between linearized and graph
        representations.
        """
        if not event_ids:
            return

        if len(event_ids) == 1:
            linearized_event_id = event_ids[0]
        else:
            # we need to map the points in the graph to linearized form.
            # TODO: Make this better.
            def graph_to_linear(txn):
                clause, args = make_in_list_sql_clause(self.database_engine,
                                                       "event_id", event_ids)

                sql = """
                    SELECT event_id WHERE room_id = ? AND stream_ordering IN (
                        SELECT max(stream_ordering) WHERE %s
                    )
                """ % (clause, )

                txn.execute(sql, [room_id] + list(args))
                rows = txn.fetchall()
                if rows:
                    return rows[0][0]
                else:
                    raise RuntimeError("Unrecognized event_ids: %r" %
                                       (event_ids, ))

            linearized_event_id = yield self.db.runInteraction(
                "insert_receipt_conv", graph_to_linear)

        stream_id_manager = self._receipts_id_gen.get_next()
        with stream_id_manager as stream_id:
            event_ts = yield self.db.runInteraction(
                "insert_linearized_receipt",
                self.insert_linearized_receipt_txn,
                room_id,
                receipt_type,
                user_id,
                linearized_event_id,
                data,
                stream_id=stream_id,
            )

        if event_ts is None:
            return None

        now = self._clock.time_msec()
        logger.debug(
            "RR for event %s in %s (%i ms old)",
            linearized_event_id,
            room_id,
            now - event_ts,
        )

        yield self.insert_graph_receipt(room_id, receipt_type, user_id,
                                        event_ids, data)

        max_persisted_id = self._receipts_id_gen.get_current_token()

        return stream_id, max_persisted_id

    def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids,
                             data):
        return self.db.runInteraction(
            "insert_graph_receipt",
            self.insert_graph_receipt_txn,
            room_id,
            receipt_type,
            user_id,
            event_ids,
            data,
        )

    def insert_graph_receipt_txn(self, txn, room_id, receipt_type, user_id,
                                 event_ids, data):
        txn.call_after(self.get_receipts_for_room.invalidate,
                       (room_id, receipt_type))
        txn.call_after(
            self._invalidate_get_users_with_receipts_in_room,
            room_id,
            receipt_type,
            user_id,
        )
        txn.call_after(self.get_receipts_for_user.invalidate,
                       (user_id, receipt_type))
        # FIXME: This shouldn't invalidate the whole cache
        txn.call_after(self._get_linearized_receipts_for_room.invalidate_many,
                       (room_id, ))

        self.db.simple_delete_txn(
            txn,
            table="receipts_graph",
            keyvalues={
                "room_id": room_id,
                "receipt_type": receipt_type,
                "user_id": user_id,
            },
        )
        self.db.simple_insert_txn(
            txn,
            table="receipts_graph",
            values={
                "room_id": room_id,
                "receipt_type": receipt_type,
                "user_id": user_id,
                "event_ids": json.dumps(event_ids),
                "data": json.dumps(data),
            },
        )
Example 5
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        super().__init__(database, db_conn, hs)

        self._cross_signing_id_gen = StreamIdGenerator(
            db_conn, "e2e_cross_signing_keys", "stream_id")

    async def set_e2e_device_keys(self, user_id: str, device_id: str,
                                  time_now: int,
                                  device_keys: JsonDict) -> bool:
        """Stores device keys for a device. Returns whether there was a change
        or the keys were already in the database.
        """
        def _set_e2e_device_keys_txn(txn: LoggingTransaction) -> bool:
            set_tag("user_id", user_id)
            set_tag("device_id", device_id)
            set_tag("time_now", time_now)
            set_tag("device_keys", device_keys)

            old_key_json = self.db_pool.simple_select_one_onecol_txn(
                txn,
                table="e2e_device_keys_json",
                keyvalues={
                    "user_id": user_id,
                    "device_id": device_id
                },
                retcol="key_json",
                allow_none=True,
            )

            # In py3 we need old_key_json to match new_key_json type. The DB
            # returns unicode while encode_canonical_json returns bytes.
            new_key_json = encode_canonical_json(device_keys).decode("utf-8")

            if old_key_json == new_key_json:
                log_kv({"Message": "Device key already stored."})
                return False

            self.db_pool.simple_upsert_txn(
                txn,
                table="e2e_device_keys_json",
                keyvalues={
                    "user_id": user_id,
                    "device_id": device_id
                },
                values={
                    "ts_added_ms": time_now,
                    "key_json": new_key_json
                },
            )
            log_kv({"message": "Device keys stored."})
            return True

        return await self.db_pool.runInteraction("set_e2e_device_keys",
                                                 _set_e2e_device_keys_txn)

    async def delete_e2e_keys_by_device(self, user_id: str,
                                        device_id: str) -> None:
        def delete_e2e_keys_by_device_txn(txn: LoggingTransaction) -> None:
            log_kv({
                "message": "Deleting keys for device",
                "device_id": device_id,
                "user_id": user_id,
            })
            self.db_pool.simple_delete_txn(
                txn,
                table="e2e_device_keys_json",
                keyvalues={
                    "user_id": user_id,
                    "device_id": device_id
                },
            )
            self.db_pool.simple_delete_txn(
                txn,
                table="e2e_one_time_keys_json",
                keyvalues={
                    "user_id": user_id,
                    "device_id": device_id
                },
            )
            self._invalidate_cache_and_stream(txn,
                                              self.count_e2e_one_time_keys,
                                              (user_id, device_id))
            self.db_pool.simple_delete_txn(
                txn,
                table="dehydrated_devices",
                keyvalues={
                    "user_id": user_id,
                    "device_id": device_id
                },
            )
            self.db_pool.simple_delete_txn(
                txn,
                table="e2e_fallback_keys_json",
                keyvalues={
                    "user_id": user_id,
                    "device_id": device_id
                },
            )
            self._invalidate_cache_and_stream(
                txn, self.get_e2e_unused_fallback_key_types,
                (user_id, device_id))

        await self.db_pool.runInteraction("delete_e2e_keys_by_device",
                                          delete_e2e_keys_by_device_txn)

    def _set_e2e_cross_signing_key_txn(
        self,
        txn: LoggingTransaction,
        user_id: str,
        key_type: str,
        key: JsonDict,
        stream_id: int,
    ) -> None:
        """Set a user's cross-signing key.

        Args:
            txn: db connection
            user_id: the user to set the signing key for
            key_type: the type of key that is being set: either 'master'
                for a master key, 'self_signing' for a self-signing key, or
                'user_signing' for a user-signing key
            key: the key data
            stream_id: the stream ID to associate with this key update
        """
        # the 'key' dict will look something like:
        # {
        #   "user_id": "@alice:example.com",
        #   "usage": ["self_signing"],
        #   "keys": {
        #     "ed25519:base64+self+signing+public+key": "base64+self+signing+public+key",
        #   },
        #   "signatures": {
        #     "@alice:example.com": {
        #       "ed25519:base64+master+public+key": "base64+signature"
        #     }
        #   }
        # }
        # The "keys" property must only have one entry, which will be the public
        # key, so we just grab the first value in there
        pubkey = next(iter(key["keys"].values()))

        # The cross-signing keys need to occupy the same namespace as devices,
        # since signatures are identified by device ID.  So add an entry to the
        # device table to make sure that we don't have a collision with device
        # IDs.
        # We only need to do this for local users, since remote servers should be
        # responsible for checking this for their own users.
        if self.hs.is_mine_id(user_id):
            self.db_pool.simple_insert_txn(
                txn,
                "devices",
                values={
                    "user_id": user_id,
                    "device_id": pubkey,
                    "display_name": key_type + " signing key",
                    "hidden": True,
                },
            )

        # and finally, store the key itself
        self.db_pool.simple_insert_txn(
            txn,
            "e2e_cross_signing_keys",
            values={
                "user_id": user_id,
                "keytype": key_type,
                "keydata": json_encoder.encode(key),
                "stream_id": stream_id,
            },
        )

        self._invalidate_cache_and_stream(
            txn, self._get_bare_e2e_cross_signing_keys, (user_id, ))

    async def set_e2e_cross_signing_key(self, user_id: str, key_type: str,
                                        key: JsonDict) -> None:
        """Set a user's cross-signing key.

        Args:
            user_id: the user to set the user-signing key for
            key_type: the type of cross-signing key to set
            key: the key data
        """

        async with self._cross_signing_id_gen.get_next() as stream_id:
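            # The allocated stream ID is stored alongside the key in
            # e2e_cross_signing_keys by _set_e2e_cross_signing_key_txn.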
            return await self.db_pool.runInteraction(
                "add_e2e_cross_signing_key",
                self._set_e2e_cross_signing_key_txn,
                user_id,
                key_type,
                key,
                stream_id,
            )

    async def store_e2e_cross_signing_signatures(
            self, user_id: str,
            signatures: "Iterable[SignatureListItem]") -> None:
        """Stores cross-signing signatures.

        Args:
            user_id: the user who made the signatures
            signatures: signatures to add
        """
        await self.db_pool.simple_insert_many(
            "e2e_cross_signing_signatures",
            keys=(
                "user_id",
                "key_id",
                "target_user_id",
                "target_device_id",
                "signature",
            ),
            values=[(
                user_id,
                item.signing_key_id,
                item.target_user_id,
                item.target_device_id,
                item.signature,
            ) for item in signatures],
            desc="add_e2e_signing_key",
        )
Example 6
class AccountDataStore(AccountDataWorkerStore):
    def __init__(self, db_conn, hs):
        self._account_data_id_gen = StreamIdGenerator(
            db_conn, "account_data_max_stream_id", "stream_id"
        )

        super(AccountDataStore, self).__init__(db_conn, hs)

    def get_max_account_data_stream_id(self):
        """Get the current max stream id for the private user data stream

        Returns:
            The current maximum stream ID.
        """
        return self._account_data_id_gen.get_current_token()

    @defer.inlineCallbacks
    def add_account_data_to_room(self, user_id, room_id, account_data_type, content):
        """Add some account_data to a room for a user.
        Args:
            user_id(str): The user to add the account_data for.
            room_id(str): The room to add the account_data for.
            account_data_type(str): The type of account_data to add.
            content(dict): A json object to associate with the account_data.
        Returns:
            A deferred which resolves to the current max stream ID once the
            account_data has been added.
        """
        content_json = json.dumps(content)

        with self._account_data_id_gen.get_next() as next_id:
            # no need to lock here as room_account_data has a unique constraint
            # on (user_id, room_id, account_data_type) so _simple_upsert will
            # retry if there is a conflict.
            yield self._simple_upsert(
                desc="add_room_account_data",
                table="room_account_data",
                keyvalues={
                    "user_id": user_id,
                    "room_id": room_id,
                    "account_data_type": account_data_type,
                },
                values={
                    "stream_id": next_id,
                    "content": content_json,
                },
                lock=False,
            )

            # it's theoretically possible for the above to succeed and the
            # below to fail - in which case we might reuse a stream id on
            # restart, and the above update might not get propagated. That
            # doesn't sound any worse than the whole update getting lost,
            # which is what would happen if we combined the two into one
            # transaction.
            yield self._update_max_stream_id(next_id)

            self._account_data_stream_cache.entity_has_changed(user_id, next_id)
            self.get_account_data_for_user.invalidate((user_id,))
            self.get_account_data_for_room.invalidate((user_id, room_id,))
            self.get_account_data_for_room_and_type.prefill(
                (user_id, room_id, account_data_type,), content,
            )

        result = self._account_data_id_gen.get_current_token()
        defer.returnValue(result)

    @defer.inlineCallbacks
    def add_account_data_for_user(self, user_id, account_data_type, content):
        """Add some account_data to a room for a user.
        Args:
            user_id(str): The user to add a tag for.
            account_data_type(str): The type of account_data to add.
            content(dict): A json object to associate with the tag.
        Returns:
            A deferred that completes once the account_data has been added.
        """
        content_json = json.dumps(content)

        with self._account_data_id_gen.get_next() as next_id:
            # no need to lock here as account_data has a unique constraint on
            # (user_id, account_data_type) so _simple_upsert will retry if
            # there is a conflict.
            yield self._simple_upsert(
                desc="add_user_account_data",
                table="account_data",
                keyvalues={
                    "user_id": user_id,
                    "account_data_type": account_data_type,
                },
                values={
                    "stream_id": next_id,
                    "content": content_json,
                },
                lock=False,
            )

            # it's theoretically possible for the above to succeed and the
            # below to fail - in which case we might reuse a stream id on
            # restart, and the above update might not get propagated. That
            # doesn't sound any worse than the whole update getting lost,
            # which is what would happen if we combined the two into one
            # transaction.
            yield self._update_max_stream_id(next_id)

            self._account_data_stream_cache.entity_has_changed(
                user_id, next_id,
            )
            self.get_account_data_for_user.invalidate((user_id,))
            self.get_global_account_data_by_type_for_user.invalidate(
                (account_data_type, user_id,)
            )

        result = self._account_data_id_gen.get_current_token()
        defer.returnValue(result)

    def _update_max_stream_id(self, next_id):
        """Update the max stream_id

        Args:
            next_id(int): The revision to advance to.
        """
        def _update(txn):
            update_max_id_sql = (
                "UPDATE account_data_max_stream_id"
                " SET stream_id = ?"
                " WHERE stream_id < ?"
            )
            txn.execute(update_max_id_sql, (next_id, next_id))
        return self.runInteraction(
            "update_account_data_max_stream_id",
            _update,
        )
Example 7
class DeviceInboxWorkerStore(SQLBaseStore):
    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        super().__init__(database, db_conn, hs)

        self._instance_name = hs.get_instance_name()

        # Map of (user_id, device_id) to the last stream_id that has been
        # deleted up to. This lets us turn redundant deletions into no-ops.
        self._last_device_delete_cache: ExpiringCache[Tuple[
            str, Optional[str]], int] = ExpiringCache(
                cache_name="last_device_delete_cache",
                clock=self._clock,
                max_len=10000,
                expiry_ms=30 * 60 * 1000,
            )

        if isinstance(database.engine, PostgresEngine):
            self._can_write_to_device = (self._instance_name
                                         in hs.config.worker.writers.to_device)

            self._device_inbox_id_gen: AbstractStreamIdGenerator = (
                MultiWriterIdGenerator(
                    db_conn=db_conn,
                    db=database,
                    stream_name="to_device",
                    instance_name=self._instance_name,
                    tables=[("device_inbox", "instance_name", "stream_id")],
                    sequence_name="device_inbox_sequence",
                    writers=hs.config.worker.writers.to_device,
                ))
        else:
            self._can_write_to_device = True
            self._device_inbox_id_gen = StreamIdGenerator(
                db_conn, "device_inbox", "stream_id")

        max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
        device_inbox_prefill, min_device_inbox_id = self.db_pool.get_cache_dict(
            db_conn,
            "device_inbox",
            entity_column="user_id",
            stream_column="stream_id",
            max_value=max_device_inbox_id,
            limit=1000,
        )
        self._device_inbox_stream_cache = StreamChangeCache(
            "DeviceInboxStreamChangeCache",
            min_device_inbox_id,
            prefilled_cache=device_inbox_prefill,
        )

        # The federation outbox and the local device inbox use the same
        # stream_id generator.
        device_outbox_prefill, min_device_outbox_id = self.db_pool.get_cache_dict(
            db_conn,
            "device_federation_outbox",
            entity_column="destination",
            stream_column="stream_id",
            max_value=max_device_inbox_id,
            limit=1000,
        )
        self._device_federation_outbox_stream_cache = StreamChangeCache(
            "DeviceFederationOutboxStreamChangeCache",
            min_device_outbox_id,
            prefilled_cache=device_outbox_prefill,
        )

    def process_replication_rows(
        self,
        stream_name: str,
        instance_name: str,
        token: int,
        rows: Iterable[ToDeviceStream.ToDeviceStreamRow],
    ) -> None:
        if stream_name == ToDeviceStream.NAME:
            # If replication is happening then postgres must be being used.
            assert isinstance(self._device_inbox_id_gen,
                              MultiWriterIdGenerator)
            self._device_inbox_id_gen.advance(instance_name, token)
            for row in rows:
                if row.entity.startswith("@"):
                    self._device_inbox_stream_cache.entity_has_changed(
                        row.entity, token)
                else:
                    self._device_federation_outbox_stream_cache.entity_has_changed(
                        row.entity, token)
        return super().process_replication_rows(stream_name, instance_name,
                                                token, rows)

    def get_to_device_stream_token(self) -> int:
        return self._device_inbox_id_gen.get_current_token()

    async def get_messages_for_user_devices(
        self,
        user_ids: Collection[str],
        from_stream_id: int,
        to_stream_id: int,
    ) -> Dict[Tuple[str, str], List[JsonDict]]:
        """
        Retrieve to-device messages for a given set of users.

        Only to-device messages with stream ids between the given boundaries
        (from < X <= to) are returned.

        Args:
            user_ids: The users to retrieve to-device messages for.
            from_stream_id: The lower boundary of stream id to filter with (exclusive).
            to_stream_id: The upper boundary of stream id to filter with (inclusive).

        Returns:
            A dictionary of (user id, device id) -> list of to-device messages.
        """
        # We expect the stream ID returned by _get_device_messages to always
        # be to_stream_id. So, no need to return it from this function.
        (
            user_id_device_id_to_messages,
            last_processed_stream_id,
        ) = await self._get_device_messages(
            user_ids=user_ids,
            from_stream_id=from_stream_id,
            to_stream_id=to_stream_id,
        )

        assert (
            last_processed_stream_id == to_stream_id
        ), "Expected _get_device_messages to process all to-device messages up to `to_stream_id`"

        return user_id_device_id_to_messages
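
    # Illustrative sketch (added for this listing, not part of the module): a
    # caller holding this store -- assumed here to be bound to the name
    # `store` -- might fetch pending messages for several local users like so:
    #
    #   messages = await store.get_messages_for_user_devices(
    #       user_ids={"@alice:example.com", "@bob:example.com"},
    #       from_stream_id=previously_seen_stream_id,
    #       to_stream_id=store.get_to_device_stream_token(),
    #   )
    #   for (user_id, device_id), msgs in messages.items():
    #       ...  # deliver msgs to that device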

    async def get_messages_for_device(
        self,
        user_id: str,
        device_id: str,
        from_stream_id: int,
        to_stream_id: int,
        limit: int = 100,
    ) -> Tuple[List[JsonDict], int]:
        """
        Retrieve to-device messages for a single user device.

        Only to-device messages with stream ids between the given boundaries
        (from < X <= to) are returned.

        Args:
            user_id: The ID of the user to retrieve messages for.
            device_id: The ID of the device to retrieve to-device messages for.
            from_stream_id: The lower boundary of stream id to filter with (exclusive).
            to_stream_id: The upper boundary of stream id to filter with (inclusive).
            limit: A limit on the number of to-device messages returned.

        Returns:
            A tuple containing:
                * A list of to-device messages within the given stream id range intended for
                  the given user / device combo.
                * The last-processed stream ID. Subsequent calls of this function with the
                  same device should pass this value as 'from_stream_id'.
        """
        (
            user_id_device_id_to_messages,
            last_processed_stream_id,
        ) = await self._get_device_messages(
            user_ids=[user_id],
            device_id=device_id,
            from_stream_id=from_stream_id,
            to_stream_id=to_stream_id,
            limit=limit,
        )

        if not user_id_device_id_to_messages:
            # There were no messages!
            return [], to_stream_id

        # Extract the messages, no need to return the user and device ID again
        to_device_messages = user_id_device_id_to_messages.get(
            (user_id, device_id), [])

        return to_device_messages, last_processed_stream_id
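
    # Illustrative sketch (assumed caller code): because `limit` may stop the
    # query early, callers are expected to loop, feeding the returned stream
    # position back in as `from_stream_id` until it reaches `to_stream_id`:
    #
    #   from_id = last_acknowledged_stream_id
    #   to_id = store.get_to_device_stream_token()
    #   while True:
    #       msgs, from_id = await store.get_messages_for_device(
    #           user_id, device_id, from_id, to_id, limit=100)
    #       ...  # hand msgs to the syncing client
    #       if from_id == to_id:
    #           break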

    async def _get_device_messages(
        self,
        user_ids: Collection[str],
        from_stream_id: int,
        to_stream_id: int,
        device_id: Optional[str] = None,
        limit: Optional[int] = None,
    ) -> Tuple[Dict[Tuple[str, str], List[JsonDict]], int]:
        """
        Retrieve pending to-device messages for a collection of user devices.

        Only to-device messages with stream ids between the given boundaries
        (from < X <= to) are returned.

        Note that a stream ID can be shared by multiple copies of the same message with
        different recipient devices. Stream IDs are only unique in the context of a single
        user ID / device ID pair. Thus, applying a limit (of messages to return) when working
        with a sliding window of stream IDs is only possible when querying messages of a
        single user device.

        Finally, note that device IDs are not unique across users.

        Args:
            user_ids: The user IDs to filter device messages by.
            from_stream_id: The lower boundary of stream id to filter with (exclusive).
            to_stream_id: The upper boundary of stream id to filter with (inclusive).
            device_id: A device ID to query to-device messages for. If not provided, to-device
                messages from all device IDs for the given user IDs will be queried. May not be
                provided if `user_ids` contains more than one entry.
            limit: The maximum number of to-device messages to return. Can only be used when
                passing a single user ID / device ID tuple.

        Returns:
            A tuple containing:
                * A dict of (user_id, device_id) -> list of to-device messages
                * The last-processed stream ID. If this is less than `to_stream_id`, then
                    there may be more messages to retrieve. If `limit` is not set, then this
                    is always equal to 'to_stream_id'.
        """
        if not user_ids:
            logger.warning("No users provided upon querying for device IDs")
            return {}, to_stream_id

        # Prevent a query for one user's device also retrieving another user's device with
        # the same device ID (device IDs are not unique across users).
        if len(user_ids) > 1 and device_id is not None:
            raise AssertionError(
                "Programming error: 'device_id' cannot be supplied to "
                "_get_device_messages when >1 user_id has been provided")

        # A limit can only be applied when querying for a single user ID / device ID tuple.
        # See the docstring of this function for more details.
        if limit is not None and device_id is None:
            raise AssertionError(
                "Programming error: _get_device_messages was passed 'limit' "
                "without a specific user_id/device_id")

        user_ids_to_query: Set[str] = set()
        device_ids_to_query: Set[str] = set()

        # Note that a device ID could be an empty str
        if device_id is not None:
            # If a device ID was passed, use it to filter results.
            # Otherwise, device IDs will be derived from the given collection of user IDs.
            device_ids_to_query.add(device_id)

        # Determine which users have devices with pending messages
        for user_id in user_ids:
            if self._device_inbox_stream_cache.has_entity_changed(
                    user_id, from_stream_id):
                # This user has new messages sent to them. Query messages for them
                user_ids_to_query.add(user_id)

        if not user_ids_to_query:
            return {}, to_stream_id

        def get_device_messages_txn(
            txn: LoggingTransaction,
        ) -> Tuple[Dict[Tuple[str, str], List[JsonDict]], int]:
            # Build a query to select messages from any of the given devices that
            # are between the given stream id bounds.

            # If a list of device IDs was not provided, retrieve all device IDs
            # for the given users. We explicitly do not query hidden devices, as
            # hidden devices should not receive to-device messages.
            # Note that this is more efficient than just dropping `device_id` from the query,
            # since device_inbox has an index on `(user_id, device_id, stream_id)`
            if not device_ids_to_query:
                user_device_dicts = self.db_pool.simple_select_many_txn(
                    txn,
                    table="devices",
                    column="user_id",
                    iterable=user_ids_to_query,
                    keyvalues={"hidden": False},
                    retcols=("device_id", ),
                )

                device_ids_to_query.update(
                    {row["device_id"]
                     for row in user_device_dicts})

            if not device_ids_to_query:
                # We've ended up with no devices to query.
                return {}, to_stream_id

            # We include both user IDs and device IDs in this query, as we have an index
            # (device_inbox_user_stream_id) for them.
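            # make_in_list_sql_clause builds an engine-appropriate membership
            # test together with its bind arguments (roughly `user_id = ANY(?)`
            # on PostgreSQL, or a parameterised `user_id IN (?, ...)` list
            # otherwise).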
            user_id_many_clause_sql, user_id_many_clause_args = make_in_list_sql_clause(
                self.database_engine, "user_id", user_ids_to_query)
            (
                device_id_many_clause_sql,
                device_id_many_clause_args,
            ) = make_in_list_sql_clause(self.database_engine, "device_id",
                                        device_ids_to_query)

            sql = f"""
                SELECT stream_id, user_id, device_id, message_json FROM device_inbox
                WHERE {user_id_many_clause_sql}
                AND {device_id_many_clause_sql}
                AND ? < stream_id AND stream_id <= ?
                ORDER BY stream_id ASC
            """
            sql_args = (
                *user_id_many_clause_args,
                *device_id_many_clause_args,
                from_stream_id,
                to_stream_id,
            )

            # If a limit was provided, limit the data retrieved from the database
            if limit is not None:
                sql += "LIMIT ?"
                sql_args += (limit, )

            txn.execute(sql, sql_args)

            # Create and fill a dictionary of (user ID, device ID) -> list of messages
            # intended for each device.
            last_processed_stream_pos = to_stream_id
            recipient_device_to_messages: Dict[Tuple[str, str],
                                               List[JsonDict]] = {}
            rowcount = 0
            for row in txn:
                rowcount += 1

                last_processed_stream_pos = row[0]
                recipient_user_id = row[1]
                recipient_device_id = row[2]
                message_dict = db_to_json(row[3])

                # Store the device details
                recipient_device_to_messages.setdefault(
                    (recipient_user_id, recipient_device_id),
                    []).append(message_dict)

            if limit is not None and rowcount == limit:
                # We ended up bumping up against the message limit. There may be more messages
                # to retrieve. Return what we have, as well as the last stream position that
                # was processed.
                #
                # The caller is expected to set this as the lower (exclusive) bound
                # for the next query of this device.
                return recipient_device_to_messages, last_processed_stream_pos

            # The limit was not reached, thus we know that recipient_device_to_messages
            # contains all to-device messages for the given device and stream id range.
            #
            # We return to_stream_id, which the caller should then provide as the lower
            # (exclusive) bound on the next query of this device.
            return recipient_device_to_messages, to_stream_id

        return await self.db_pool.runInteraction("get_device_messages",
                                                 get_device_messages_txn)

    @trace
    async def delete_messages_for_device(self, user_id: str,
                                         device_id: Optional[str],
                                         up_to_stream_id: int) -> int:
        """
        Args:
            user_id: The recipient user_id.
            device_id: The recipient device_id.
            up_to_stream_id: Where to delete messages up to.

        Returns:
            The number of messages deleted.
        """
        # If we have cached the last stream id we've deleted up to, we can
        # check if there is likely to be anything that needs deleting
        last_deleted_stream_id = self._last_device_delete_cache.get(
            (user_id, device_id), None)

        set_tag("last_deleted_stream_id", last_deleted_stream_id)

        if last_deleted_stream_id:
            has_changed = self._device_inbox_stream_cache.has_entity_changed(
                user_id, last_deleted_stream_id)
            if not has_changed:
                log_kv({"message": "No changes in cache since last check"})
                return 0

        def delete_messages_for_device_txn(txn: LoggingTransaction) -> int:
            sql = ("DELETE FROM device_inbox"
                   " WHERE user_id = ? AND device_id = ?"
                   " AND stream_id <= ?")
            txn.execute(sql, (user_id, device_id, up_to_stream_id))
            return txn.rowcount

        count = await self.db_pool.runInteraction(
            "delete_messages_for_device", delete_messages_for_device_txn)

        log_kv({
            "message": f"deleted {count} messages for device",
            "count": count
        })

        # Update the cache, ensuring that we only ever increase the value
        updated_last_deleted_stream_id = self._last_device_delete_cache.get(
            (user_id, device_id), 0)
        self._last_device_delete_cache[(user_id, device_id)] = max(
            updated_last_deleted_stream_id, up_to_stream_id)

        return count
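
    # Illustrative sketch (assumed caller code): once a device has acknowledged
    # delivery up to some stream position (e.g. via its sync token), the
    # delivered copies can be removed so they are not sent again:
    #
    #   deleted = await store.delete_messages_for_device(
    #       user_id, device_id, up_to_stream_id=acknowledged_stream_id)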

    @trace
    async def get_new_device_msgs_for_remote(
            self, destination: str, last_stream_id: int,
            current_stream_id: int, limit: int) -> Tuple[List[JsonDict], int]:
        """
        Args:
            destination: The name of the remote server.
            last_stream_id: The last position of the device message stream
                that the server sent up to.
            current_stream_id: The current position of the device message stream.
            limit: The maximum number of messages to retrieve.
        Returns:
            A list of messages for the device and where in the stream the messages got to.
        """

        set_tag("destination", destination)
        set_tag("last_stream_id", last_stream_id)
        set_tag("current_stream_id", current_stream_id)
        set_tag("limit", limit)

        has_changed = self._device_federation_outbox_stream_cache.has_entity_changed(
            destination, last_stream_id)
        if not has_changed or last_stream_id == current_stream_id:
            log_kv({"message": "No new messages in stream"})
            return [], current_stream_id

        if limit <= 0:
            # This can happen if we run out of room for EDUs in the transaction.
            return [], last_stream_id

        @trace
        def get_new_messages_for_remote_destination_txn(
            txn: LoggingTransaction, ) -> Tuple[List[JsonDict], int]:
            sql = (
                "SELECT stream_id, messages_json FROM device_federation_outbox"
                " WHERE destination = ?"
                " AND ? < stream_id AND stream_id <= ?"
                " ORDER BY stream_id ASC"
                " LIMIT ?")
            txn.execute(
                sql, (destination, last_stream_id, current_stream_id, limit))

            messages = []
            stream_pos = current_stream_id

            for row in txn:
                stream_pos = row[0]
                messages.append(db_to_json(row[1]))

            # If the limit was not reached we know that there's no more data for this
            # destination up to current_stream_id.
            if len(messages) < limit:
                log_kv({"message": "Set stream position to current position"})
                stream_pos = current_stream_id

            return messages, stream_pos

        return await self.db_pool.runInteraction(
            "get_new_device_msgs_for_remote",
            get_new_messages_for_remote_destination_txn,
        )
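
    # Illustrative sketch (assumed caller code): the federation sender would
    # typically fetch the queued EDUs for a destination, bounded by how many
    # still fit in the current transaction:
    #
    #   edus, stream_pos = await store.get_new_device_msgs_for_remote(
    #       destination,
    #       last_sent_stream_id,
    #       store.get_to_device_stream_token(),
    #       limit=remaining_edu_budget,
    #   )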

    @trace
    async def delete_device_msgs_for_remote(self, destination: str,
                                            up_to_stream_id: int) -> None:
        """Used to delete messages when the remote destination acknowledges
        their receipt.

        Args:
            destination: The destination server_name
            up_to_stream_id: Where to delete messages up to.
        """
        def delete_messages_for_remote_destination_txn(
                txn: LoggingTransaction) -> None:
            sql = ("DELETE FROM device_federation_outbox"
                   " WHERE destination = ?"
                   " AND stream_id <= ?")
            txn.execute(sql, (destination, up_to_stream_id))

        await self.db_pool.runInteraction(
            "delete_device_msgs_for_remote",
            delete_messages_for_remote_destination_txn)
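
    # Illustrative sketch (assumed caller code): once the destination has
    # acknowledged the transaction containing those EDUs, the queued copies
    # can be dropped:
    #
    #   await store.delete_device_msgs_for_remote(destination, stream_pos)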

    async def get_all_new_device_messages(
            self, instance_name: str, last_id: int, current_id: int,
            limit: int) -> Tuple[List[Tuple[int, tuple]], int, bool]:
        """Get updates for to device replication stream.

        Args:
            instance_name: The writer we want to fetch updates from. Unused
                here since there is only ever one writer.
            last_id: The token to fetch updates from. Exclusive.
            current_id: The token to fetch updates up to. Inclusive.
            limit: The requested limit for the number of rows to return. The
                function may return more or fewer rows.

        Returns:
            A tuple consisting of: the updates, a token to use to fetch
            subsequent updates, and whether we returned fewer rows than exists
            between the requested tokens due to the limit.

            The token returned can be used in a subsequent call to this
            function to get further updates.

            The updates are a list of 2-tuples of stream ID and the row data
        """

        if last_id == current_id:
            return [], current_id, False

        def get_all_new_device_messages_txn(
            txn: LoggingTransaction,
        ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
            # We limit like this as we might have multiple rows per stream_id, and
            # we want to make sure we always get all entries for any stream_id
            # we return.
            upper_pos = min(current_id, last_id + limit)
            sql = ("SELECT max(stream_id), user_id"
                   " FROM device_inbox"
                   " WHERE ? < stream_id AND stream_id <= ?"
                   " GROUP BY user_id")
            txn.execute(sql, (last_id, upper_pos))
            updates = [(row[0], row[1:]) for row in txn]

            sql = ("SELECT max(stream_id), destination"
                   " FROM device_federation_outbox"
                   " WHERE ? < stream_id AND stream_id <= ?"
                   " GROUP BY destination")
            txn.execute(sql, (last_id, upper_pos))
            updates.extend((row[0], row[1:]) for row in txn)

            # Order by ascending stream ordering
            updates.sort()

            limited = False
            upto_token = current_id
            if len(updates) >= limit:
                upto_token = updates[-1][0]
                limited = True

            return updates, upto_token, limited

        return await self.db_pool.runInteraction(
            "get_all_new_device_messages", get_all_new_device_messages_txn)

    @trace
    async def add_messages_to_device_inbox(
        self,
        local_messages_by_user_then_device: Dict[str, Dict[str, JsonDict]],
        remote_messages_by_destination: Dict[str, JsonDict],
    ) -> int:
        """Used to send messages from this server.

        Args:
            local_messages_by_user_then_device:
                Dictionary of recipient user_id to recipient device_id to message.
            remote_messages_by_destination:
                Dictionary of destination server_name to the EDU JSON to send.

        Returns:
            The new stream_id.
        """

        assert self._can_write_to_device

        def add_messages_txn(txn: LoggingTransaction, now_ms: int,
                             stream_id: int) -> None:
            # Add the local messages directly to the local inbox.
            self._add_messages_to_local_device_inbox_txn(
                txn, stream_id, local_messages_by_user_then_device)

            # Add the remote messages to the federation outbox.
            # We'll send them to a remote server when we next send a
            # federation transaction to that destination.
            self.db_pool.simple_insert_many_txn(
                txn,
                table="device_federation_outbox",
                keys=(
                    "destination",
                    "stream_id",
                    "queued_ts",
                    "messages_json",
                    "instance_name",
                ),
                values=[(
                    destination,
                    stream_id,
                    now_ms,
                    json_encoder.encode(edu),
                    self._instance_name,
                ) for destination, edu in
                        remote_messages_by_destination.items()],
            )

            if remote_messages_by_destination:
                issue9533_logger.debug(
                    "Queued outgoing to-device messages with stream_id %i for %s",
                    stream_id,
                    list(remote_messages_by_destination.keys()),
                )

        async with self._device_inbox_id_gen.get_next() as stream_id:
            now_ms = self._clock.time_msec()
            await self.db_pool.runInteraction("add_messages_to_device_inbox",
                                              add_messages_txn, now_ms,
                                              stream_id)
            for user_id in local_messages_by_user_then_device.keys():
                self._device_inbox_stream_cache.entity_has_changed(
                    user_id, stream_id)
            for destination in remote_messages_by_destination.keys():
                self._device_federation_outbox_stream_cache.entity_has_changed(
                    destination, stream_id)

        return self._device_inbox_id_gen.get_current_token()
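
    # Illustrative sketch (assumed caller code): a handler sending a to-device
    # message from this server fans it out to local inboxes and the federation
    # outbox in one call (message payloads elided):
    #
    #   stream_id = await store.add_messages_to_device_inbox(
    #       local_messages_by_user_then_device={
    #           "@alice:example.com": {"ALICEDEVICE": {...}},
    #       },
    #       remote_messages_by_destination={"other.example.org": {...}},
    #   )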

    async def add_messages_from_remote_to_device_inbox(
        self,
        origin: str,
        message_id: str,
        local_messages_by_user_then_device: Dict[str, Dict[str, JsonDict]],
    ) -> int:
        assert self._can_write_to_device

        def add_messages_txn(txn: LoggingTransaction, now_ms: int,
                             stream_id: int) -> None:
            # Check if we've already inserted a matching message_id for that
            # origin. This can happen if the origin doesn't receive our
            # acknowledgement from the first time we received the message.
            already_inserted = self.db_pool.simple_select_one_txn(
                txn,
                table="device_federation_inbox",
                keyvalues={
                    "origin": origin,
                    "message_id": message_id
                },
                retcols=("message_id", ),
                allow_none=True,
            )
            if already_inserted is not None:
                return

            # Add an entry for this message_id so that we know we've processed
            # it.
            self.db_pool.simple_insert_txn(
                txn,
                table="device_federation_inbox",
                values={
                    "origin": origin,
                    "message_id": message_id,
                    "received_ts": now_ms,
                },
            )

            # Add the messages to the appropriate local device inboxes so that
            # they'll be sent to the devices when they next sync.
            self._add_messages_to_local_device_inbox_txn(
                txn, stream_id, local_messages_by_user_then_device)

        async with self._device_inbox_id_gen.get_next() as stream_id:
            now_ms = self._clock.time_msec()
            await self.db_pool.runInteraction(
                "add_messages_from_remote_to_device_inbox",
                add_messages_txn,
                now_ms,
                stream_id,
            )
            for user_id in local_messages_by_user_then_device.keys():
                self._device_inbox_stream_cache.entity_has_changed(
                    user_id, stream_id)

        return stream_id
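
    # Illustrative sketch (assumed caller code): the federation handler calls
    # this when a remote server delivers a to-device EDU; re-deliveries of the
    # same (origin, message_id) pair are deduplicated via the
    # device_federation_inbox table above:
    #
    #   stream_id = await store.add_messages_from_remote_to_device_inbox(
    #       origin="other.example.org",
    #       message_id=edu_message_id,
    #       local_messages_by_user_then_device={...},
    #   )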

    def _add_messages_to_local_device_inbox_txn(
        self,
        txn: LoggingTransaction,
        stream_id: int,
        messages_by_user_then_device: Dict[str, Dict[str, JsonDict]],
    ) -> None:
        assert self._can_write_to_device

        local_by_user_then_device = {}
        for user_id, messages_by_device in messages_by_user_then_device.items():
            messages_json_for_user = {}
            devices = list(messages_by_device.keys())
            if len(devices) == 1 and devices[0] == "*":
                # Handle wildcard device_ids.
                # We exclude hidden devices (such as cross-signing keys) here as they are
                # not expected to receive to-device messages.
                devices = self.db_pool.simple_select_onecol_txn(
                    txn,
                    table="devices",
                    keyvalues={
                        "user_id": user_id,
                        "hidden": False
                    },
                    retcol="device_id",
                )

                message_json = json_encoder.encode(messages_by_device["*"])
                for device_id in devices:
                    # Add the message for all devices for this user on this
                    # server.
                    messages_json_for_user[device_id] = message_json
            else:
                if not devices:
                    continue

                # We exclude hidden devices (such as cross-signing keys) here as they are
                # not expected to receive to-device messages.
                rows = self.db_pool.simple_select_many_txn(
                    txn,
                    table="devices",
                    keyvalues={
                        "user_id": user_id,
                        "hidden": False
                    },
                    column="device_id",
                    iterable=devices,
                    retcols=("device_id", ),
                )

                for row in rows:
                    # Only insert into the local inbox if the device exists on
                    # this server
                    device_id = row["device_id"]
                    message_json = json_encoder.encode(
                        messages_by_device[device_id])
                    messages_json_for_user[device_id] = message_json

            if messages_json_for_user:
                local_by_user_then_device[user_id] = messages_json_for_user

        if not local_by_user_then_device:
            return

        self.db_pool.simple_insert_many_txn(
            txn,
            table="device_inbox",
            keys=("user_id", "device_id", "stream_id", "message_json",
                  "instance_name"),
            values=[(user_id, device_id, stream_id, message_json,
                     self._instance_name) for user_id, messages_by_device in
                    local_by_user_then_device.items()
                    for device_id, message_json in messages_by_device.items()],
        )

        issue9533_logger.debug(
            "Stored to-device messages with stream_id %i for %s",
            stream_id,
            [(user_id, device_id)
             for (user_id,
                  messages_by_device) in local_by_user_then_device.items()
             for device_id in messages_by_device.keys()],
        )
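
    # Note (added for this listing): a caller can address every device of a
    # local user at once by using "*" as the device ID, e.g.
    #
    #   local_messages_by_user_then_device = {
    #       "@alice:example.com": {"*": message_content}
    #   }
    #
    # in which case the method above expands the message to all of the user's
    # non-hidden devices before inserting into device_inbox.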