Example #1
    def __init__(self, db_conn, hs):
        super(SlavedEventStore, self).__init__(db_conn, hs)
        self._stream_id_gen = SlavedIdTracker(
            db_conn, "events", "stream_ordering",
        )
        self._backfill_id_gen = SlavedIdTracker(
            db_conn, "events", "stream_ordering", step=-1
        )
        events_max = self._stream_id_gen.get_current_token()
        event_cache_prefill, min_event_val = self._get_cache_dict(
            db_conn, "events",
            entity_column="room_id",
            stream_column="stream_ordering",
            max_value=events_max,
        )
        self._events_stream_cache = StreamChangeCache(
            "EventsRoomStreamChangeCache", min_event_val,
            prefilled_cache=event_cache_prefill,
        )
        self._membership_stream_cache = StreamChangeCache(
            "MembershipStreamChangeCache", events_max,
        )

        self.stream_ordering_month_ago = 0
        self._stream_order_on_start = self.get_room_max_stream_ordering()
Example #2
 def test_prefilled_cache(self):
     """
     Providing a prefilled cache to StreamChangeCache results in a cache
     populated with those prefilled entries.
     """
     cache = StreamChangeCache("#test", 1, prefilled_cache={"*****@*****.**": 2})
     self.assertTrue(cache.has_entity_changed("*****@*****.**", 1))
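
A rough mental model of what these tests exercise: the real StreamChangeCache in synapse.util.caches.stream_change_cache also enforces a maximum size and keeps positions in sorted order, but a simplified stand-in (illustration only, not Synapse's implementation) behaves like this:

class TinyStreamChangeCache:
    """Illustrative stand-in for StreamChangeCache (simplified)."""

    def __init__(self, name, current_stream_pos, prefilled_cache=None):
        self.name = name
        # Positions older than this are outside the cache's knowledge.
        self._earliest_known_stream_pos = current_stream_pos
        # Maps entity -> latest stream position at which it changed.
        self._entity_to_key = dict(prefilled_cache or {})

    def entity_has_changed(self, entity, stream_pos):
        # Record a change, keeping only the most recent position per entity.
        if stream_pos > self._entity_to_key.get(entity, 0):
            self._entity_to_key[entity] = stream_pos

    def has_entity_changed(self, entity, stream_pos):
        # Anything older than the cache window must be assumed changed.
        if stream_pos < self._earliest_known_stream_pos:
            return True
        latest_change = self._entity_to_key.get(entity)
        return latest_change is not None and stream_pos < latest_change

cache = TinyStreamChangeCache("#test", 1, prefilled_cache={"user@host": 2})
assert cache.has_entity_changed("user@host", 1)       # prefilled entry counts
assert not cache.has_entity_changed("other@host", 1)  # nothing recorded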
Example #3
class SlavedGroupServerStore(BaseSlavedStore):
    def __init__(self, db_conn, hs):
        super(SlavedGroupServerStore, self).__init__(db_conn, hs)

        self.hs = hs

        self._group_updates_id_gen = SlavedIdTracker(
            db_conn, "local_group_updates", "stream_id",
        )
        self._group_updates_stream_cache = StreamChangeCache(
            "_group_updates_stream_cache", self._group_updates_id_gen.get_current_token(),
        )

    get_groups_changes_for_user = DataStore.get_groups_changes_for_user.__func__
    get_group_stream_token = DataStore.get_group_stream_token.__func__
    get_all_groups_for_user = DataStore.get_all_groups_for_user.__func__

    def stream_positions(self):
        result = super(SlavedGroupServerStore, self).stream_positions()
        result["groups"] = self._group_updates_id_gen.get_current_token()
        return result

    def process_replication_rows(self, stream_name, token, rows):
        if stream_name == "groups":
            self._group_updates_id_gen.advance(token)
            for row in rows:
                self._group_updates_stream_cache.entity_has_changed(
                    row.user_id, token
                )

        return super(SlavedGroupServerStore, self).process_replication_rows(
            stream_name, token, rows
        )
Example #4
class SlavedReceiptsStore(BaseSlavedStore):

    def __init__(self, db_conn, hs):
        super(SlavedReceiptsStore, self).__init__(db_conn, hs)

        self._receipts_id_gen = SlavedIdTracker(
            db_conn, "receipts_linearized", "stream_id"
        )

        self._receipts_stream_cache = StreamChangeCache(
            "ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token()
        )

    get_receipts_for_user = ReceiptsStore.__dict__["get_receipts_for_user"]
    get_linearized_receipts_for_room = (
        ReceiptsStore.__dict__["get_linearized_receipts_for_room"]
    )
    _get_linearized_receipts_for_rooms = (
        ReceiptsStore.__dict__["_get_linearized_receipts_for_rooms"]
    )
    get_last_receipt_event_id_for_user = (
        ReceiptsStore.__dict__["get_last_receipt_event_id_for_user"]
    )

    get_max_receipt_stream_id = DataStore.get_max_receipt_stream_id.__func__
    get_all_updated_receipts = DataStore.get_all_updated_receipts.__func__

    get_linearized_receipts_for_rooms = (
        DataStore.get_linearized_receipts_for_rooms.__func__
    )

    def stream_positions(self):
        result = super(SlavedReceiptsStore, self).stream_positions()
        result["receipts"] = self._receipts_id_gen.get_current_token()
        return result

    def process_replication(self, result):
        stream = result.get("receipts")
        if stream:
            self._receipts_id_gen.advance(int(stream["position"]))
            for row in stream["rows"]:
                position, room_id, receipt_type, user_id = row[:4]
                self.invalidate_caches_for_receipt(room_id, receipt_type, user_id)
                self._receipts_stream_cache.entity_has_changed(room_id, position)

        return super(SlavedReceiptsStore, self).process_replication(result)

    def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
        self.get_receipts_for_user.invalidate((user_id, receipt_type))
        self.get_linearized_receipts_for_room.invalidate_many((room_id,))
        self.get_last_receipt_event_id_for_user.invalidate(
            (user_id, room_id, receipt_type)
        )
Example #5
    def __init__(self, db_conn, hs):
        super(SlavedDeviceStore, self).__init__(db_conn, hs)

        self.hs = hs

        self._device_list_id_gen = SlavedIdTracker(
            db_conn, "device_lists_stream", "stream_id",
        )
        device_list_max = self._device_list_id_gen.get_current_token()
        self._device_list_stream_cache = StreamChangeCache(
            "DeviceListStreamChangeCache", device_list_max,
        )
        self._device_list_federation_stream_cache = StreamChangeCache(
            "DeviceListFederationStreamChangeCache", device_list_max,
        )
Example #6
class SlavedDeviceInboxStore(BaseSlavedStore):
    def __init__(self, db_conn, hs):
        super(SlavedDeviceInboxStore, self).__init__(db_conn, hs)
        self._device_inbox_id_gen = SlavedIdTracker(
            db_conn, "device_max_stream_id", "stream_id",
        )
        self._device_inbox_stream_cache = StreamChangeCache(
            "DeviceInboxStreamChangeCache",
            self._device_inbox_id_gen.get_current_token()
        )
        self._device_federation_outbox_stream_cache = StreamChangeCache(
            "DeviceFederationOutboxStreamChangeCache",
            self._device_inbox_id_gen.get_current_token()
        )

        self._last_device_delete_cache = ExpiringCache(
            cache_name="last_device_delete_cache",
            clock=self._clock,
            max_len=10000,
            expiry_ms=30 * 60 * 1000,
        )

    get_to_device_stream_token = DataStore.get_to_device_stream_token.__func__
    get_new_messages_for_device = DataStore.get_new_messages_for_device.__func__
    get_new_device_msgs_for_remote = DataStore.get_new_device_msgs_for_remote.__func__
    delete_messages_for_device = DataStore.delete_messages_for_device.__func__
    delete_device_msgs_for_remote = DataStore.delete_device_msgs_for_remote.__func__

    def stream_positions(self):
        result = super(SlavedDeviceInboxStore, self).stream_positions()
        result["to_device"] = self._device_inbox_id_gen.get_current_token()
        return result

    def process_replication_rows(self, stream_name, token, rows):
        if stream_name == "to_device":
            self._device_inbox_id_gen.advance(token)
            for row in rows:
                if row.entity.startswith("@"):
                    self._device_inbox_stream_cache.entity_has_changed(
                        row.entity, token
                    )
                else:
                    self._device_federation_outbox_stream_cache.entity_has_changed(
                        row.entity, token
                    )
        return super(SlavedDeviceInboxStore, self).process_replication_rows(
            stream_name, token, rows
        )
Example #7
class SlavedDeviceStore(BaseSlavedStore):
    def __init__(self, db_conn, hs):
        super(SlavedDeviceStore, self).__init__(db_conn, hs)

        self.hs = hs

        self._device_list_id_gen = SlavedIdTracker(
            db_conn, "device_lists_stream", "stream_id",
        )
        device_list_max = self._device_list_id_gen.get_current_token()
        self._device_list_stream_cache = StreamChangeCache(
            "DeviceListStreamChangeCache", device_list_max,
        )
        self._device_list_federation_stream_cache = StreamChangeCache(
            "DeviceListFederationStreamChangeCache", device_list_max,
        )

    get_device_stream_token = __func__(DataStore.get_device_stream_token)
    get_user_whose_devices_changed = __func__(DataStore.get_user_whose_devices_changed)
    get_devices_by_remote = __func__(DataStore.get_devices_by_remote)
    _get_devices_by_remote_txn = __func__(DataStore._get_devices_by_remote_txn)
    _get_e2e_device_keys_txn = __func__(DataStore._get_e2e_device_keys_txn)
    mark_as_sent_devices_by_remote = __func__(DataStore.mark_as_sent_devices_by_remote)
    _mark_as_sent_devices_by_remote_txn = (
        __func__(DataStore._mark_as_sent_devices_by_remote_txn)
    )
    count_e2e_one_time_keys = EndToEndKeyStore.__dict__["count_e2e_one_time_keys"]

    def stream_positions(self):
        result = super(SlavedDeviceStore, self).stream_positions()
        result["device_lists"] = self._device_list_id_gen.get_current_token()
        return result

    def process_replication_rows(self, stream_name, token, rows):
        if stream_name == "device_lists":
            self._device_list_id_gen.advance(token)
            for row in rows:
                self._device_list_stream_cache.entity_has_changed(
                    row.user_id, token
                )

                if row.destination:
                    self._device_list_federation_stream_cache.entity_has_changed(
                        row.destination, token
                    )
        return super(SlavedDeviceStore, self).process_replication_rows(
            stream_name, token, rows
        )
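
The `__func__(...)` calls above (e.g. `get_device_stream_token = __func__(DataStore.get_device_stream_token)`) are not attribute lookups but calls to a small Python 2/3 compatibility helper defined elsewhere in Synapse's slaved-storage code: on Python 2, `DataStore.some_method` is an unbound method whose underlying function must be extracted before it can be attached to another class, while on Python 3 it is already a plain function. A plausible sketch of such a helper (an assumption, not quoted from the source):

import six

def __func__(inp):
    # On Python 3 a reference like DataStore.some_method is already a plain
    # function and can be reused directly; on Python 2 it is an unbound
    # method, so unwrap the underlying function from its __func__ attribute.
    if six.PY3:
        return inp
    return inp.__func__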
Example #8
    def __init__(self, db_conn, hs):
        account_max = self.get_max_account_data_stream_id()
        self._account_data_stream_cache = StreamChangeCache(
            "AccountDataAndTagsChangeCache", account_max,
        )

        super(AccountDataWorkerStore, self).__init__(db_conn, hs)
Example #9
    def __init__(self, hs):
        self.store = hs.get_datastore()
        self.server_name = hs.config.server_name
        self.auth = hs.get_auth()
        self.is_mine_id = hs.is_mine_id
        self.notifier = hs.get_notifier()
        self.state = hs.get_state_handler()

        self.hs = hs

        self.clock = hs.get_clock()
        self.wheel_timer = WheelTimer(bucket_size=5000)

        self.federation = hs.get_federation_sender()

        hs.get_federation_registry().register_edu_handler("m.typing", self._recv_edu)

        hs.get_distributor().observe("user_left_room", self.user_left_room)

        self._member_typing_until = {}  # clock time we expect to stop
        self._member_last_federation_poke = {}

        self._latest_room_serial = 0
        self._reset()

        # caches which room_ids changed at which serials
        self._typing_stream_change_cache = StreamChangeCache(
            "TypingStreamChangeCache", self._latest_room_serial,
        )

        self.clock.looping_call(
            self._handle_timeouts,
            5000,
        )
Example #10
class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore):
    def __init__(self, db_conn, hs):
        super(SlavedDeviceStore, self).__init__(db_conn, hs)

        self.hs = hs

        self._device_list_id_gen = SlavedIdTracker(
            db_conn, "device_lists_stream", "stream_id",
        )
        device_list_max = self._device_list_id_gen.get_current_token()
        self._device_list_stream_cache = StreamChangeCache(
            "DeviceListStreamChangeCache", device_list_max,
        )
        self._device_list_federation_stream_cache = StreamChangeCache(
            "DeviceListFederationStreamChangeCache", device_list_max,
        )

    def stream_positions(self):
        result = super(SlavedDeviceStore, self).stream_positions()
        result["device_lists"] = self._device_list_id_gen.get_current_token()
        return result

    def process_replication_rows(self, stream_name, token, rows):
        if stream_name == "device_lists":
            self._device_list_id_gen.advance(token)
            for row in rows:
                self._invalidate_caches_for_devices(
                    token, row.user_id, row.destination,
                )
        return super(SlavedDeviceStore, self).process_replication_rows(
            stream_name, token, rows
        )

    def _invalidate_caches_for_devices(self, token, user_id, destination):
        self._device_list_stream_cache.entity_has_changed(
            user_id, token
        )

        if destination:
            self._device_list_federation_stream_cache.entity_has_changed(
                destination, token
            )

        self._get_cached_devices_for_user.invalidate((user_id,))
        self._get_cached_user_device.invalidate_many((user_id,))
        self.get_device_list_last_stream_id_for_remote.invalidate((user_id,))
Example #11
class SlavedPushRuleStore(SlavedEventStore):
    def __init__(self, db_conn, hs):
        super(SlavedPushRuleStore, self).__init__(db_conn, hs)
        self._push_rules_stream_id_gen = SlavedIdTracker(
            db_conn, "push_rules_stream", "stream_id",
        )
        self.push_rules_stream_cache = StreamChangeCache(
            "PushRulesStreamChangeCache",
            self._push_rules_stream_id_gen.get_current_token(),
        )

    get_push_rules_for_user = PushRuleStore.__dict__["get_push_rules_for_user"]
    get_push_rules_enabled_for_user = (
        PushRuleStore.__dict__["get_push_rules_enabled_for_user"]
    )
    have_push_rules_changed_for_user = (
        DataStore.have_push_rules_changed_for_user.__func__
    )

    def get_push_rules_stream_token(self):
        return (
            self._push_rules_stream_id_gen.get_current_token(),
            self._stream_id_gen.get_current_token(),
        )

    def stream_positions(self):
        result = super(SlavedPushRuleStore, self).stream_positions()
        result["push_rules"] = self._push_rules_stream_id_gen.get_current_token()
        return result

    def process_replication(self, result):
        stream = result.get("push_rules")
        if stream:
            for row in stream["rows"]:
                position = row[0]
                user_id = row[2]
                self.get_push_rules_for_user.invalidate((user_id,))
                self.get_push_rules_enabled_for_user.invalidate((user_id,))
                self.push_rules_stream_cache.entity_has_changed(
                    user_id, position
                )

            self._push_rules_stream_id_gen.advance(int(stream["position"]))

        return super(SlavedPushRuleStore, self).process_replication(result)
Example #12
    def test_has_any_entity_changed(self):
        """
        StreamChangeCache.has_any_entity_changed will return True if any
        entities have been changed since the provided stream position, and
        False if they have not.  If the provided stream position predates the
        cache's earliest known position, it returns True when the cache has
        entries and False when it is empty.
        """
        cache = StreamChangeCache("#test", 1)

        # With no entities, it returns False for the past, present, and future.
        self.assertFalse(cache.has_any_entity_changed(0))
        self.assertFalse(cache.has_any_entity_changed(1))
        self.assertFalse(cache.has_any_entity_changed(2))

        # We add an entity
        cache.entity_has_changed("*****@*****.**", 2)

        # With an entity, it returns True for the past, the stream start
        # position, and False for the stream position the entity was changed
        # on and ones after it.
        self.assertTrue(cache.has_any_entity_changed(0))
        self.assertTrue(cache.has_any_entity_changed(1))
        self.assertFalse(cache.has_any_entity_changed(2))
        self.assertFalse(cache.has_any_entity_changed(3))
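
The reason the stores in these examples maintain such caches is so read paths can skip database work. A hedged sketch of that pattern, using only the StreamChangeCache methods shown in these examples; `query_rooms_changed_since` and `query_events_in_rooms_since` are hypothetical stand-ins for real database queries:

def get_rooms_changed_since(store, events_stream_cache, from_token):
    if not events_stream_cache.has_any_entity_changed(from_token):
        # The cache proves nothing has changed anywhere: skip the database.
        return []

    changed_rooms = events_stream_cache.get_all_entities_changed(from_token)
    if changed_rooms is None:
        # from_token predates the cache window, so fall back to a full query.
        return store.query_rooms_changed_since(from_token)

    # Only query the rooms the cache says actually changed.
    return store.query_events_in_rooms_since(changed_rooms, from_token)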
Example #13
 def __init__(self, db_conn, hs):
     super(SlavedPushRuleStore, self).__init__(db_conn, hs)
     self._push_rules_stream_id_gen = SlavedIdTracker(
         db_conn, "push_rules_stream", "stream_id",
     )
     self.push_rules_stream_cache = StreamChangeCache(
         "PushRulesStreamChangeCache",
         self._push_rules_stream_id_gen.get_current_token(),
     )
Example #14
 def __init__(self, db_conn, hs):
     super(SlavedAccountDataStore, self).__init__(db_conn, hs)
     self._account_data_id_gen = SlavedIdTracker(
         db_conn, "account_data_max_stream_id", "stream_id",
     )
     self._account_data_stream_cache = StreamChangeCache(
         "AccountDataAndTagsChangeCache",
         self._account_data_id_gen.get_current_token(),
     )
Example #15
    def __init__(self, db_conn, hs):
        super(SlavedPresenceStore, self).__init__(db_conn, hs)
        self._presence_id_gen = SlavedIdTracker(db_conn, "presence_stream",
                                                "stream_id")

        self._presence_on_startup = self._get_active_presence(db_conn)

        self.presence_stream_cache = StreamChangeCache(
            "PresenceStreamChangeCache",
            self._presence_id_gen.get_current_token())
Example #16
    def __init__(self, db_conn, hs):
        super(SlavedReceiptsStore, self).__init__(db_conn, hs)

        self._receipts_id_gen = SlavedIdTracker(
            db_conn, "receipts_linearized", "stream_id"
        )

        self._receipts_stream_cache = StreamChangeCache(
            "ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token()
        )
Example #17
    def __init__(self, db_conn, hs):
        super(SlavedDeviceStore, self).__init__(db_conn, hs)

        self.hs = hs

        self._device_list_id_gen = SlavedIdTracker(
            db_conn,
            "device_lists_stream",
            "stream_id",
        )
        device_list_max = self._device_list_id_gen.get_current_token()
        self._device_list_stream_cache = StreamChangeCache(
            "DeviceListStreamChangeCache",
            device_list_max,
        )
        self._device_list_federation_stream_cache = StreamChangeCache(
            "DeviceListFederationStreamChangeCache",
            device_list_max,
        )
Example #18
    def __init__(self, db_conn, hs):
        super(UserDirectorySlaveStore, self).__init__(db_conn, hs)

        events_max = self._stream_id_gen.get_current_token()
        curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict(
            db_conn,
            "current_state_delta_stream",
            entity_column="room_id",
            stream_column="stream_id",
            max_value=events_max,  # As we share the stream id with events token
            limit=1000,
        )
        self._curr_state_delta_stream_cache = StreamChangeCache(
            "_curr_state_delta_stream_cache",
            min_curr_state_delta_id,
            prefilled_cache=curr_state_delta_prefill,
        )

        self._current_state_delta_pos = events_max
Example #19
    def __init__(self, database: DatabasePool, db_conn, hs):
        super(SlavedDeviceInboxStore, self).__init__(database, db_conn, hs)
        self._device_inbox_id_gen = SlavedIdTracker(db_conn, "device_inbox",
                                                    "stream_id")
        self._device_inbox_stream_cache = StreamChangeCache(
            "DeviceInboxStreamChangeCache",
            self._device_inbox_id_gen.get_current_token(),
        )
        self._device_federation_outbox_stream_cache = StreamChangeCache(
            "DeviceFederationOutboxStreamChangeCache",
            self._device_inbox_id_gen.get_current_token(),
        )

        self._last_device_delete_cache = ExpiringCache(
            cache_name="last_device_delete_cache",
            clock=self._clock,
            max_len=10000,
            expiry_ms=30 * 60 * 1000,
        )
Example #20
    def __init__(self, db_conn, hs):
        super(StreamWorkerStore, self).__init__(db_conn, hs)

        events_max = self.get_room_max_stream_ordering()
        event_cache_prefill, min_event_val = self._get_cache_dict(
            db_conn, "events",
            entity_column="room_id",
            stream_column="stream_ordering",
            max_value=events_max,
        )
        self._events_stream_cache = StreamChangeCache(
            "EventsRoomStreamChangeCache", min_event_val,
            prefilled_cache=event_cache_prefill,
        )
        self._membership_stream_cache = StreamChangeCache(
            "MembershipStreamChangeCache", events_max,
        )

        self._stream_order_on_start = self.get_room_max_stream_ordering()
Example #21
    def __init__(self, db_conn, hs):
        super(StreamWorkerStore, self).__init__(db_conn, hs)

        events_max = self.get_room_max_stream_ordering()
        event_cache_prefill, min_event_val = self._get_cache_dict(
            db_conn, "events",
            entity_column="room_id",
            stream_column="stream_ordering",
            max_value=events_max,
        )
        self._events_stream_cache = StreamChangeCache(
            "EventsRoomStreamChangeCache", min_event_val,
            prefilled_cache=event_cache_prefill,
        )
        self._membership_stream_cache = StreamChangeCache(
            "MembershipStreamChangeCache", events_max,
        )

        self._stream_order_on_start = self.get_room_max_stream_ordering()
Example #22
class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
    def __init__(self, database: Database, db_conn, hs):
        super(SlavedDeviceInboxStore, self).__init__(database, db_conn, hs)
        self._device_inbox_id_gen = SlavedIdTracker(
            db_conn, "device_max_stream_id", "stream_id"
        )
        self._device_inbox_stream_cache = StreamChangeCache(
            "DeviceInboxStreamChangeCache",
            self._device_inbox_id_gen.get_current_token(),
        )
        self._device_federation_outbox_stream_cache = StreamChangeCache(
            "DeviceFederationOutboxStreamChangeCache",
            self._device_inbox_id_gen.get_current_token(),
        )

        self._last_device_delete_cache = ExpiringCache(
            cache_name="last_device_delete_cache",
            clock=self._clock,
            max_len=10000,
            expiry_ms=30 * 60 * 1000,
        )

    def stream_positions(self):
        result = super(SlavedDeviceInboxStore, self).stream_positions()
        result["to_device"] = self._device_inbox_id_gen.get_current_token()
        return result

    def process_replication_rows(self, stream_name, token, rows):
        if stream_name == "to_device":
            self._device_inbox_id_gen.advance(token)
            for row in rows:
                if row.entity.startswith("@"):
                    self._device_inbox_stream_cache.entity_has_changed(
                        row.entity, token
                    )
                else:
                    self._device_federation_outbox_stream_cache.entity_has_changed(
                        row.entity, token
                    )
        return super(SlavedDeviceInboxStore, self).process_replication_rows(
            stream_name, token, rows
        )
Example #23
    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        self._instance_name = hs.get_instance_name()

        if isinstance(database.engine, PostgresEngine):
            self._can_write_to_receipts = (
                self._instance_name in hs.config.worker.writers.receipts)

            self._receipts_id_gen = MultiWriterIdGenerator(
                db_conn=db_conn,
                db=database,
                stream_name="receipts",
                instance_name=self._instance_name,
                tables=[("receipts_linearized", "instance_name", "stream_id")],
                sequence_name="receipts_sequence",
                writers=hs.config.worker.writers.receipts,
            )
        else:
            self._can_write_to_receipts = True

            # We shouldn't be running in worker mode with SQLite, but it's
            # useful to support it for unit tests.
            #
            # If this process is the writer then we need to use
            # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets
            # updated over replication. (Multiple writers are not supported for
            # SQLite).
            if hs.get_instance_name() in hs.config.worker.writers.receipts:
                self._receipts_id_gen = StreamIdGenerator(
                    db_conn, "receipts_linearized", "stream_id")
            else:
                self._receipts_id_gen = SlavedIdTracker(
                    db_conn, "receipts_linearized", "stream_id")

        super().__init__(database, db_conn, hs)

        self._receipts_stream_cache = StreamChangeCache(
            "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id())
Example #24
    def __init__(
        self,
        database: DatabasePool,
        db_conn: Connection,
        hs: "HomeServer",
    ):
        super().__init__(database, db_conn, hs)

        self._can_persist_presence = (
            hs.get_instance_name() in hs.config.worker.writers.presence
        )

        if isinstance(database.engine, PostgresEngine):
            self._presence_id_gen = MultiWriterIdGenerator(
                db_conn=db_conn,
                db=database,
                stream_name="presence_stream",
                instance_name=self._instance_name,
                tables=[("presence_stream", "instance_name", "stream_id")],
                sequence_name="presence_stream_sequence",
                writers=hs.config.worker.writers.to_device,
            )
        else:
            self._presence_id_gen = StreamIdGenerator(
                db_conn, "presence_stream", "stream_id"
            )

        self.hs = hs
        self._presence_on_startup = self._get_active_presence(db_conn)

        presence_cache_prefill, min_presence_val = self.db_pool.get_cache_dict(
            db_conn,
            "presence_stream",
            entity_column="user_id",
            stream_column="stream_id",
            max_value=self._presence_id_gen.get_current_token(),
        )
        self.presence_stream_cache = StreamChangeCache(
            "PresenceStreamChangeCache",
            min_presence_val,
            prefilled_cache=presence_cache_prefill,
        )
Example #25
class SlavedPushRuleStore(SlavedEventStore):
    def __init__(self, db_conn, hs):
        super(SlavedPushRuleStore, self).__init__(db_conn, hs)
        self._push_rules_stream_id_gen = SlavedIdTracker(
            db_conn,
            "push_rules_stream",
            "stream_id",
        )
        self.push_rules_stream_cache = StreamChangeCache(
            "PushRulesStreamChangeCache",
            self._push_rules_stream_id_gen.get_current_token(),
        )

    get_push_rules_for_user = PushRuleStore.__dict__["get_push_rules_for_user"]
    get_push_rules_enabled_for_user = (
        PushRuleStore.__dict__["get_push_rules_enabled_for_user"])
    have_push_rules_changed_for_user = (
        DataStore.have_push_rules_changed_for_user.__func__)

    def get_push_rules_stream_token(self):
        return (
            self._push_rules_stream_id_gen.get_current_token(),
            self._stream_id_gen.get_current_token(),
        )

    def stream_positions(self):
        result = super(SlavedPushRuleStore, self).stream_positions()
        result[
            "push_rules"] = self._push_rules_stream_id_gen.get_current_token()
        return result

    def process_replication_rows(self, stream_name, token, rows):
        if stream_name == "push_rules":
            self._push_rules_stream_id_gen.advance(token)
            for row in rows:
                self.get_push_rules_for_user.invalidate((row.user_id, ))
                self.get_push_rules_enabled_for_user.invalidate(
                    (row.user_id, ))
                self.push_rules_stream_cache.entity_has_changed(
                    row.user_id, token)
        return super(SlavedPushRuleStore,
                     self).process_replication_rows(stream_name, token, rows)
Example #26
class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore,
                        BaseSlavedStore):
    def __init__(self, db_conn, hs):
        super(SlavedDeviceStore, self).__init__(db_conn, hs)

        self.hs = hs

        self._device_list_id_gen = SlavedIdTracker(db_conn,
                                                   "device_lists_stream",
                                                   "stream_id")
        device_list_max = self._device_list_id_gen.get_current_token()
        self._device_list_stream_cache = StreamChangeCache(
            "DeviceListStreamChangeCache", device_list_max)
        self._user_signature_stream_cache = StreamChangeCache(
            "UserSignatureStreamChangeCache", device_list_max)
        self._device_list_federation_stream_cache = StreamChangeCache(
            "DeviceListFederationStreamChangeCache", device_list_max)

    def stream_positions(self):
        result = super(SlavedDeviceStore, self).stream_positions()
        result["device_lists"] = self._device_list_id_gen.get_current_token()
        return result

    def process_replication_rows(self, stream_name, token, rows):
        if stream_name == "device_lists":
            self._device_list_id_gen.advance(token)
            for row in rows:
                self._invalidate_caches_for_devices(token, row.user_id,
                                                    row.destination)
        return super(SlavedDeviceStore,
                     self).process_replication_rows(stream_name, token, rows)

    def _invalidate_caches_for_devices(self, token, user_id, destination):
        self._device_list_stream_cache.entity_has_changed(user_id, token)

        if destination:
            self._device_list_federation_stream_cache.entity_has_changed(
                destination, token)

        self._get_cached_devices_for_user.invalidate((user_id, ))
        self._get_cached_user_device.invalidate_many((user_id, ))
        self.get_device_list_last_stream_id_for_remote.invalidate((user_id, ))
Example #27
class UserDirectorySlaveStore(
    SlavedEventStore,
    SlavedApplicationServiceStore,
    SlavedRegistrationStore,
    SlavedClientIpStore,
    UserDirectoryStore,
    BaseSlavedStore,
):
    def __init__(self, db_conn, hs):
        super(UserDirectorySlaveStore, self).__init__(db_conn, hs)

        events_max = self._stream_id_gen.get_current_token()
        curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict(
            db_conn,
            "current_state_delta_stream",
            entity_column="room_id",
            stream_column="stream_id",
            max_value=events_max,  # As we share the stream id with events token
            limit=1000,
        )
        self._curr_state_delta_stream_cache = StreamChangeCache(
            "_curr_state_delta_stream_cache",
            min_curr_state_delta_id,
            prefilled_cache=curr_state_delta_prefill,
        )

    def stream_positions(self):
        result = super(UserDirectorySlaveStore, self).stream_positions()
        return result

    def process_replication_rows(self, stream_name, token, rows):
        if stream_name == EventsStream.NAME:
            self._stream_id_gen.advance(token)
            for row in rows:
                if row.type != EventsStreamCurrentStateRow.TypeId:
                    continue
                self._curr_state_delta_stream_cache.entity_has_changed(
                    row.data.room_id, token
                )
        return super(UserDirectorySlaveStore, self).process_replication_rows(
            stream_name, token, rows
        )
Example #28
class SlavedDeviceInboxStore(BaseSlavedStore):
    def __init__(self, db_conn, hs):
        super(SlavedDeviceInboxStore, self).__init__(db_conn, hs)
        self._device_inbox_id_gen = SlavedIdTracker(
            db_conn,
            "device_max_stream_id",
            "stream_id",
        )
        self._device_inbox_stream_cache = StreamChangeCache(
            "DeviceInboxStreamChangeCache",
            self._device_inbox_id_gen.get_current_token())
        self._device_federation_outbox_stream_cache = StreamChangeCache(
            "DeviceFederationOutboxStreamChangeCache",
            self._device_inbox_id_gen.get_current_token())

    get_to_device_stream_token = DataStore.get_to_device_stream_token.__func__
    get_new_messages_for_device = DataStore.get_new_messages_for_device.__func__
    get_new_device_msgs_for_remote = DataStore.get_new_device_msgs_for_remote.__func__
    delete_messages_for_device = DataStore.delete_messages_for_device.__func__
    delete_device_msgs_for_remote = DataStore.delete_device_msgs_for_remote.__func__

    def stream_positions(self):
        result = super(SlavedDeviceInboxStore, self).stream_positions()
        result["to_device"] = self._device_inbox_id_gen.get_current_token()
        return result

    def process_replication(self, result):
        stream = result.get("to_device")
        if stream:
            self._device_inbox_id_gen.advance(int(stream["position"]))
            for row in stream["rows"]:
                stream_id = row[0]
                entity = row[1]

                if entity.startswith("@"):
                    self._device_inbox_stream_cache.entity_has_changed(
                        entity, stream_id)
                else:
                    self._device_federation_outbox_stream_cache.entity_has_changed(
                        entity, stream_id)

        return super(SlavedDeviceInboxStore, self).process_replication(result)
Example #29
    def __init__(self, db_conn, hs):
        super(SlavedGroupServerStore, self).__init__(db_conn, hs)

        self.hs = hs

        self._group_updates_id_gen = SlavedIdTracker(
            db_conn, "local_group_updates", "stream_id",
        )
        self._group_updates_stream_cache = StreamChangeCache(
            "_group_updates_stream_cache", self._group_updates_id_gen.get_current_token(),
        )
Example #30
    def __init__(self, hs):
        super().__init__(hs)

        assert hs.config.worker.writers.typing == hs.get_instance_name()

        self.auth = hs.get_auth()
        self.notifier = hs.get_notifier()

        self.hs = hs

        hs.get_federation_registry().register_edu_handler(
            "m.typing", self._recv_edu)

        hs.get_distributor().observe("user_left_room", self.user_left_room)

        self._member_typing_until = {}  # clock time we expect to stop

        # caches which room_ids changed at which serials
        self._typing_stream_change_cache = StreamChangeCache(
            "TypingStreamChangeCache", self._latest_room_serial)
Example #31
    def __init__(self, db_conn, hs):
        super(SlavedDeviceInboxStore, self).__init__(db_conn, hs)
        self._device_inbox_id_gen = SlavedIdTracker(
            db_conn, "device_max_stream_id", "stream_id",
        )
        self._device_inbox_stream_cache = StreamChangeCache(
            "DeviceInboxStreamChangeCache",
            self._device_inbox_id_gen.get_current_token()
        )
        self._device_federation_outbox_stream_cache = StreamChangeCache(
            "DeviceFederationOutboxStreamChangeCache",
            self._device_inbox_id_gen.get_current_token()
        )

        self._last_device_delete_cache = ExpiringCache(
            cache_name="last_device_delete_cache",
            clock=self._clock,
            max_len=10000,
            expiry_ms=30 * 60 * 1000,
        )
Example #32
class UserDirectorySlaveStore(
    SlavedEventStore,
    SlavedApplicationServiceStore,
    SlavedRegistrationStore,
    SlavedClientIpStore,
    UserDirectoryStore,
    BaseSlavedStore,
):
    def __init__(self, db_conn, hs):
        super(UserDirectorySlaveStore, self).__init__(db_conn, hs)

        events_max = self._stream_id_gen.get_current_token()
        curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict(
            db_conn, "current_state_delta_stream",
            entity_column="room_id",
            stream_column="stream_id",
            max_value=events_max,  # As we share the stream id with events token
            limit=1000,
        )
        self._curr_state_delta_stream_cache = StreamChangeCache(
            "_curr_state_delta_stream_cache", min_curr_state_delta_id,
            prefilled_cache=curr_state_delta_prefill,
        )

        self._current_state_delta_pos = events_max

    def stream_positions(self):
        result = super(UserDirectorySlaveStore, self).stream_positions()
        result["current_state_deltas"] = self._current_state_delta_pos
        return result

    def process_replication_rows(self, stream_name, token, rows):
        if stream_name == "current_state_deltas":
            self._current_state_delta_pos = token
            for row in rows:
                self._curr_state_delta_stream_cache.entity_has_changed(
                    row.room_id, token
                )
        return super(UserDirectorySlaveStore, self).process_replication_rows(
            stream_name, token, rows
        )
Example #33
    def test_has_entity_changed_pops_off_start(self):
        """
        StreamChangeCache.entity_has_changed will respect the max size and
        purge the oldest items upon reaching that max size.
        """
        cache = StreamChangeCache("#test", 1, max_size=2)

        cache.entity_has_changed("*****@*****.**", 2)
        cache.entity_has_changed("*****@*****.**", 3)
        cache.entity_has_changed("*****@*****.**", 4)

        # The cache is at the max size, 2
        self.assertEqual(len(cache._cache), 2)

        # The oldest item has been popped off
        self.assertTrue("*****@*****.**" not in cache._entity_to_key)

        # If we update an existing entity, it keeps the two existing entities
        cache.entity_has_changed("*****@*****.**", 5)
        self.assertEqual({"*****@*****.**", "*****@*****.**"},
                         set(cache._entity_to_key))
Example #34
    def test_get_all_entities_changed(self):
        """
        StreamChangeCache.get_all_entities_changed will return all changed
        entities since the given position.  If the position is before the start
        of the known stream, it returns None instead.
        """
        cache = StreamChangeCache("#test", 1)

        cache.entity_has_changed("*****@*****.**", 2)
        cache.entity_has_changed("*****@*****.**", 3)
        cache.entity_has_changed("*****@*****.**", 4)

        self.assertEqual(
            cache.get_all_entities_changed(1),
            ["*****@*****.**", "*****@*****.**", "*****@*****.**"],
        )
        self.assertEqual(
            cache.get_all_entities_changed(2), ["*****@*****.**", "*****@*****.**"]
        )
        self.assertEqual(cache.get_all_entities_changed(3), ["*****@*****.**"])
        self.assertEqual(cache.get_all_entities_changed(0), None)
Example #35
    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        super().__init__(database, db_conn, hs)

        self.hs = hs

        self._device_list_id_gen = SlavedIdTracker(
            db_conn,
            "device_lists_stream",
            "stream_id",
            extra_tables=[
                ("user_signature_stream", "stream_id"),
                ("device_lists_outbound_pokes", "stream_id"),
            ],
        )
        device_list_max = self._device_list_id_gen.get_current_token()
        self._device_list_stream_cache = StreamChangeCache(
            "DeviceListStreamChangeCache", device_list_max)
        self._user_signature_stream_cache = StreamChangeCache(
            "UserSignatureStreamChangeCache", device_list_max)
        self._device_list_federation_stream_cache = StreamChangeCache(
            "DeviceListFederationStreamChangeCache", device_list_max)
Example #36
    def __init__(self, database: Database, db_conn, hs):
        self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering")
        self._backfill_id_gen = SlavedIdTracker(
            db_conn, "events", "stream_ordering", step=-1
        )

        super(SlavedEventStore, self).__init__(database, db_conn, hs)

        events_max = self._stream_id_gen.get_current_token()
        curr_state_delta_prefill, min_curr_state_delta_id = self.db.get_cache_dict(
            db_conn,
            "current_state_delta_stream",
            entity_column="room_id",
            stream_column="stream_id",
            max_value=events_max,  # As we share the stream id with events token
            limit=1000,
        )
        self._curr_state_delta_stream_cache = StreamChangeCache(
            "_curr_state_delta_stream_cache",
            min_curr_state_delta_id,
            prefilled_cache=curr_state_delta_prefill,
        )
Example #37
    def __init__(self, hs):
        self.store = hs.get_datastore()
        self.server_name = hs.config.server_name
        self.auth = hs.get_auth()
        self.is_mine_id = hs.is_mine_id
        self.notifier = hs.get_notifier()
        self.state = hs.get_state_handler()

        self.hs = hs

        self.clock = hs.get_clock()
        self.wheel_timer = WheelTimer(bucket_size=5000)

        self.federation = hs.get_federation_sender()

        hs.get_federation_registry().register_edu_handler(
            "m.typing", self._recv_edu)

        hs.get_distributor().observe("user_left_room", self.user_left_room)

        self._member_typing_until = {}  # clock time we expect to stop
        self._member_last_federation_poke = {}

        # map room IDs to serial numbers
        self._room_serials = {}
        self._latest_room_serial = 0
        # map room IDs to sets of users currently typing
        self._room_typing = {}

        # caches which room_ids changed at which serials
        self._typing_stream_change_cache = StreamChangeCache(
            "TypingStreamChangeCache",
            self._latest_room_serial,
        )

        self.clock.looping_call(
            self._handle_timeouts,
            5000,
        )
Example #38
class SlavedPresenceStore(BaseSlavedStore):
    def __init__(self, database: Database, db_conn, hs):
        super(SlavedPresenceStore, self).__init__(database, db_conn, hs)
        self._presence_id_gen = SlavedIdTracker(db_conn, "presence_stream",
                                                "stream_id")

        self._presence_on_startup = self._get_active_presence(db_conn)

        self.presence_stream_cache = StreamChangeCache(
            "PresenceStreamChangeCache",
            self._presence_id_gen.get_current_token())

    _get_active_presence = __func__(DataStore._get_active_presence)
    take_presence_startup_info = __func__(DataStore.take_presence_startup_info)
    _get_presence_for_user = PresenceStore.__dict__["_get_presence_for_user"]
    get_presence_for_users = PresenceStore.__dict__["get_presence_for_users"]

    def get_current_presence_token(self):
        return self._presence_id_gen.get_current_token()

    def stream_positions(self):
        result = super(SlavedPresenceStore, self).stream_positions()

        if self.hs.config.use_presence:
            position = self._presence_id_gen.get_current_token()
            result["presence"] = position

        return result

    def process_replication_rows(self, stream_name, token, rows):
        if stream_name == "presence":
            self._presence_id_gen.advance(token)
            for row in rows:
                self.presence_stream_cache.entity_has_changed(
                    row.user_id, token)
                self._get_presence_for_user.invalidate((row.user_id, ))
        return super(SlavedPresenceStore,
                     self).process_replication_rows(stream_name, token, rows)
Example #39
    def __init__(self, db_conn, hs):
        super(PushRulesWorkerStore, self).__init__(db_conn, hs)

        push_rules_prefill, push_rules_id = self._get_cache_dict(
            db_conn, "push_rules_stream",
            entity_column="user_id",
            stream_column="stream_id",
            max_value=self.get_max_push_rules_stream_id(),
        )

        self.push_rules_stream_cache = StreamChangeCache(
            "PushRulesStreamChangeCache", push_rules_id,
            prefilled_cache=push_rules_prefill,
        )
Example #40
class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
    def __init__(self, database: Database, db_conn, hs):
        super(SlavedGroupServerStore, self).__init__(database, db_conn, hs)

        self.hs = hs

        self._group_updates_id_gen = SlavedIdTracker(
            db_conn, "local_group_updates", "stream_id"
        )
        self._group_updates_stream_cache = StreamChangeCache(
            "_group_updates_stream_cache",
            self._group_updates_id_gen.get_current_token(),
        )

    def get_group_stream_token(self):
        return self._group_updates_id_gen.get_current_token()

    def process_replication_rows(self, stream_name, instance_name, token, rows):
        if stream_name == GroupServerStream.NAME:
            self._group_updates_id_gen.advance(token)
            for row in rows:
                self._group_updates_stream_cache.entity_has_changed(row.user_id, token)

        return super().process_replication_rows(stream_name, instance_name, token, rows)
Example #41
    def __init__(self, hs: "HomeServer"):
        super().__init__(hs)

        assert hs.get_instance_name() in hs.config.worker.writers.typing

        self.auth = hs.get_auth()
        self.notifier = hs.get_notifier()
        self.event_auth_handler = hs.get_event_auth_handler()

        self.hs = hs

        hs.get_federation_registry().register_edu_handler(
            EduTypes.TYPING, self._recv_edu
        )

        hs.get_distributor().observe("user_left_room", self.user_left_room)

        # clock time we expect to stop
        self._member_typing_until: Dict[RoomMember, int] = {}

        # caches which room_ids changed at which serials
        self._typing_stream_change_cache = StreamChangeCache(
            "TypingStreamChangeCache", self._latest_room_serial
        )
Example #42
    def __init__(self, database: Database, db_conn, hs):
        super(PushRulesWorkerStore, self).__init__(database, db_conn, hs)

        if hs.config.worker.worker_app is None:
            self._push_rules_stream_id_gen = ChainedIdGenerator(
                self._stream_id_gen, db_conn, "push_rules_stream", "stream_id"
            )  # type: Union[ChainedIdGenerator, SlavedIdTracker]
        else:
            self._push_rules_stream_id_gen = SlavedIdTracker(
                db_conn, "push_rules_stream", "stream_id")

        push_rules_prefill, push_rules_id = self.db.get_cache_dict(
            db_conn,
            "push_rules_stream",
            entity_column="user_id",
            stream_column="stream_id",
            max_value=self.get_max_push_rules_stream_id(),
        )

        self.push_rules_stream_cache = StreamChangeCache(
            "PushRulesStreamChangeCache",
            push_rules_id,
            prefilled_cache=push_rules_prefill,
        )
Example #43
    def __init__(self, db_conn, hs):
        super(UserDirectorySlaveStore, self).__init__(db_conn, hs)

        events_max = self._stream_id_gen.get_current_token()
        curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict(
            db_conn, "current_state_delta_stream",
            entity_column="room_id",
            stream_column="stream_id",
            max_value=events_max,  # As we share the stream id with events token
            limit=1000,
        )
        self._curr_state_delta_stream_cache = StreamChangeCache(
            "_curr_state_delta_stream_cache", min_curr_state_delta_id,
            prefilled_cache=curr_state_delta_prefill,
        )
Example #44
class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore,
                        BaseSlavedStore):
    def __init__(self, database: DatabasePool, db_conn, hs):
        super().__init__(database, db_conn, hs)

        self.hs = hs

        self._device_list_id_gen = SlavedIdTracker(
            db_conn,
            "device_lists_stream",
            "stream_id",
            extra_tables=[
                ("user_signature_stream", "stream_id"),
                ("device_lists_outbound_pokes", "stream_id"),
            ],
        )
        device_list_max = self._device_list_id_gen.get_current_token()
        self._device_list_stream_cache = StreamChangeCache(
            "DeviceListStreamChangeCache", device_list_max)
        self._user_signature_stream_cache = StreamChangeCache(
            "UserSignatureStreamChangeCache", device_list_max)
        self._device_list_federation_stream_cache = StreamChangeCache(
            "DeviceListFederationStreamChangeCache", device_list_max)

    def get_device_stream_token(self) -> int:
        return self._device_list_id_gen.get_current_token()

    def process_replication_rows(self, stream_name, instance_name, token,
                                 rows):
        if stream_name == DeviceListsStream.NAME:
            self._device_list_id_gen.advance(instance_name, token)
            self._invalidate_caches_for_devices(token, rows)
        elif stream_name == UserSignatureStream.NAME:
            self._device_list_id_gen.advance(instance_name, token)
            for row in rows:
                self._user_signature_stream_cache.entity_has_changed(
                    row.user_id, token)
        return super().process_replication_rows(stream_name, instance_name,
                                                token, rows)

    def _invalidate_caches_for_devices(self, token, rows):
        for row in rows:
            # The entities are either user IDs (starting with '@') whose devices
            # have changed, or remote servers that we need to tell about
            # changes.
            if row.entity.startswith("@"):
                self._device_list_stream_cache.entity_has_changed(
                    row.entity, token)
                self.get_cached_devices_for_user.invalidate((row.entity, ))
                self._get_cached_user_device.invalidate_many((row.entity, ))
                self.get_device_list_last_stream_id_for_remote.invalidate(
                    (row.entity, ))

            else:
                self._device_list_federation_stream_cache.entity_has_changed(
                    row.entity, token)
Example #45
    def test_has_entity_changed_pops_off_start(self):
        """
        StreamChangeCache.entity_has_changed will respect the max size and
        purge the oldest items upon reaching that max size.
        """
        cache = StreamChangeCache("#test", 1, max_size=2)

        cache.entity_has_changed("*****@*****.**", 2)
        cache.entity_has_changed("*****@*****.**", 3)
        cache.entity_has_changed("*****@*****.**", 4)

        # The cache is at the max size, 2
        self.assertEqual(len(cache._cache), 2)

        # The oldest item has been popped off
        self.assertTrue("*****@*****.**" not in cache._entity_to_key)

        # If we update an existing entity, it keeps the two existing entities
        cache.entity_has_changed("*****@*****.**", 5)
        self.assertEqual(
            set(["*****@*****.**", "*****@*****.**"]), set(cache._entity_to_key)
        )
Example #46
    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
        super().__init__(database, db_conn, hs)

        events_max = self._stream_id_gen.get_current_token()
        curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict(
            db_conn,
            "current_state_delta_stream",
            entity_column="room_id",
            stream_column="stream_id",
            max_value=events_max,  # As we share the stream id with events token
            limit=1000,
        )
        self._curr_state_delta_stream_cache = StreamChangeCache(
            "_curr_state_delta_stream_cache",
            min_curr_state_delta_id,
            prefilled_cache=curr_state_delta_prefill,
        )
Example #47
class AccountDataWorkerStore(SQLBaseStore):
    """This is an abstract base class where subclasses must implement
    `get_max_account_data_stream_id` which can be called in the initializer.
    """

    # This ABCMeta metaclass ensures that we cannot be instantiated without
    # the abstract methods being implemented.
    __metaclass__ = abc.ABCMeta

    def __init__(self, db_conn, hs):
        account_max = self.get_max_account_data_stream_id()
        self._account_data_stream_cache = StreamChangeCache(
            "AccountDataAndTagsChangeCache", account_max
        )

        super(AccountDataWorkerStore, self).__init__(db_conn, hs)

    @abc.abstractmethod
    def get_max_account_data_stream_id(self):
        """Get the current max stream ID for account data stream

        Returns:
            int
        """
        raise NotImplementedError()

    @cached()
    def get_account_data_for_user(self, user_id):
        """Get all the client account_data for a user.

        Args:
            user_id(str): The user to get the account_data for.
        Returns:
            A deferred pair of a dict of global account_data and a dict
            mapping from room_id string to per room account_data dicts.
        """

        def get_account_data_for_user_txn(txn):
            rows = self._simple_select_list_txn(
                txn,
                "account_data",
                {"user_id": user_id},
                ["account_data_type", "content"],
            )

            global_account_data = {
                row["account_data_type"]: json.loads(row["content"]) for row in rows
            }

            rows = self._simple_select_list_txn(
                txn,
                "room_account_data",
                {"user_id": user_id},
                ["room_id", "account_data_type", "content"],
            )

            by_room = {}
            for row in rows:
                room_data = by_room.setdefault(row["room_id"], {})
                room_data[row["account_data_type"]] = json.loads(row["content"])

            return global_account_data, by_room

        return self.runInteraction(
            "get_account_data_for_user", get_account_data_for_user_txn
        )

    @cachedInlineCallbacks(num_args=2, max_entries=5000)
    def get_global_account_data_by_type_for_user(self, data_type, user_id):
        """
        Returns:
            Deferred: A dict
        """
        result = yield self._simple_select_one_onecol(
            table="account_data",
            keyvalues={"user_id": user_id, "account_data_type": data_type},
            retcol="content",
            desc="get_global_account_data_by_type_for_user",
            allow_none=True,
        )

        if result:
            return json.loads(result)
        else:
            return None

    @cached(num_args=2)
    def get_account_data_for_room(self, user_id, room_id):
        """Get all the client account_data for a user for a room.

        Args:
            user_id(str): The user to get the account_data for.
            room_id(str): The room to get the account_data for.
        Returns:
            A deferred dict of the room account_data
        """

        def get_account_data_for_room_txn(txn):
            rows = self._simple_select_list_txn(
                txn,
                "room_account_data",
                {"user_id": user_id, "room_id": room_id},
                ["account_data_type", "content"],
            )

            return {
                row["account_data_type"]: json.loads(row["content"]) for row in rows
            }

        return self.runInteraction(
            "get_account_data_for_room", get_account_data_for_room_txn
        )

    @cached(num_args=3, max_entries=5000)
    def get_account_data_for_room_and_type(self, user_id, room_id, account_data_type):
        """Get the client account_data of given type for a user for a room.

        Args:
            user_id(str): The user to get the account_data for.
            room_id(str): The room to get the account_data for.
            account_data_type (str): The account data type to get.
        Returns:
            A deferred of the room account_data for that type, or None if
            there isn't any set.
        """

        def get_account_data_for_room_and_type_txn(txn):
            content_json = self._simple_select_one_onecol_txn(
                txn,
                table="room_account_data",
                keyvalues={
                    "user_id": user_id,
                    "room_id": room_id,
                    "account_data_type": account_data_type,
                },
                retcol="content",
                allow_none=True,
            )

            return json.loads(content_json) if content_json else None

        return self.runInteraction(
            "get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn
        )

    def get_all_updated_account_data(
        self, last_global_id, last_room_id, current_id, limit
    ):
        """Get all the client account_data that has changed on the server
        Args:
            last_global_id(int): The position to fetch from for top level data
            last_room_id(int): The position to fetch from for per room data
            current_id(int): The position to fetch up to.
        Returns:
            A deferred pair of lists of tuples of stream_id int, user_id string,
            room_id string, and type string.
        """
        if last_room_id == current_id and last_global_id == current_id:
            return defer.succeed(([], []))

        def get_updated_account_data_txn(txn):
            sql = (
                "SELECT stream_id, user_id, account_data_type"
                " FROM account_data WHERE ? < stream_id AND stream_id <= ?"
                " ORDER BY stream_id ASC LIMIT ?"
            )
            txn.execute(sql, (last_global_id, current_id, limit))
            global_results = txn.fetchall()

            sql = (
                "SELECT stream_id, user_id, room_id, account_data_type"
                " FROM room_account_data WHERE ? < stream_id AND stream_id <= ?"
                " ORDER BY stream_id ASC LIMIT ?"
            )
            txn.execute(sql, (last_room_id, current_id, limit))
            room_results = txn.fetchall()
            return global_results, room_results

        return self.runInteraction(
            "get_all_updated_account_data_txn", get_updated_account_data_txn
        )

    def get_updated_account_data_for_user(self, user_id, stream_id):
        """Get all the client account_data for a that's changed for a user

        Args:
            user_id(str): The user to get the account_data for.
            stream_id(int): The point in the stream since which to get updates
        Returns:
            A deferred pair of a dict of global account_data and a dict
            mapping from room_id string to per room account_data dicts.
        """

        def get_updated_account_data_for_user_txn(txn):
            sql = (
                "SELECT account_data_type, content FROM account_data"
                " WHERE user_id = ? AND stream_id > ?"
            )

            txn.execute(sql, (user_id, stream_id))

            global_account_data = {row[0]: json.loads(row[1]) for row in txn}

            sql = (
                "SELECT room_id, account_data_type, content FROM room_account_data"
                " WHERE user_id = ? AND stream_id > ?"
            )

            txn.execute(sql, (user_id, stream_id))

            account_data_by_room = {}
            for row in txn:
                room_account_data = account_data_by_room.setdefault(row[0], {})
                room_account_data[row[1]] = json.loads(row[2])

            return global_account_data, account_data_by_room

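        # Cheap short-circuit: the stream change cache records the highest
        # stream_id at which each user's account data changed, so if nothing
        # for this user has changed after `stream_id` we can skip the database
        # query entirely.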
        changed = self._account_data_stream_cache.has_entity_changed(
            user_id, int(stream_id)
        )
        if not changed:
            return defer.succeed(({}, {}))

        return self.runInteraction(
            "get_updated_account_data_for_user", get_updated_account_data_for_user_txn
        )

    @cachedInlineCallbacks(num_args=2, cache_context=True, max_entries=5000)
    def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context):
        ignored_account_data = yield self.get_global_account_data_by_type_for_user(
            "m.ignored_user_list",
            ignorer_user_id,
            on_invalidate=cache_context.invalidate,
        )
        if not ignored_account_data:
            return False

        return ignored_user_id in ignored_account_data.get("ignored_users", {})
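

# A minimal, self-contained sketch of the check performed by is_ignored_by
# above. The "m.ignored_user_list" account data type and its "ignored_users"
# map are taken from the code above; the user IDs are made up.
def is_ignored_by_sketch(ignored_user_id, ignorer_account_data):
    """Return True if `ignorer_account_data` marks `ignored_user_id` as ignored."""
    ignored = ignorer_account_data.get("m.ignored_user_list")
    if not ignored:
        return False
    # The content maps ignored user IDs to (currently empty) metadata dicts.
    return ignored_user_id in ignored.get("ignored_users", {})


account_data = {
    "m.ignored_user_list": {"ignored_users": {"@spammer:example.org": {}}}
}
assert is_ignored_by_sketch("@spammer:example.org", account_data)
assert not is_ignored_by_sketch("@friend:example.org", account_data)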
Ejemplo n.º 48
0
class PresenceStore(PresenceBackgroundUpdateStore):
    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        super().__init__(database, db_conn, hs)

        self._can_persist_presence = (hs.get_instance_name()
                                      in hs.config.worker.writers.presence)

        if isinstance(database.engine, PostgresEngine):
            self._presence_id_gen = MultiWriterIdGenerator(
                db_conn=db_conn,
                db=database,
                stream_name="presence_stream",
                instance_name=self._instance_name,
                tables=[("presence_stream", "instance_name", "stream_id")],
                sequence_name="presence_stream_sequence",
                writers=hs.config.worker.writers.presence,
            )
        else:
            self._presence_id_gen = StreamIdGenerator(db_conn,
                                                      "presence_stream",
                                                      "stream_id")

        self.hs = hs
        self._presence_on_startup = self._get_active_presence(db_conn)

        presence_cache_prefill, min_presence_val = self.db_pool.get_cache_dict(
            db_conn,
            "presence_stream",
            entity_column="user_id",
            stream_column="stream_id",
            max_value=self._presence_id_gen.get_current_token(),
        )
        self.presence_stream_cache = StreamChangeCache(
            "PresenceStreamChangeCache",
            min_presence_val,
            prefilled_cache=presence_cache_prefill,
        )

    async def update_presence(self, presence_states) -> Tuple[int, int]:
        assert self._can_persist_presence

        stream_ordering_manager = self._presence_id_gen.get_next_mult(
            len(presence_states))

        async with stream_ordering_manager as stream_orderings:
            await self.db_pool.runInteraction(
                "update_presence",
                self._update_presence_txn,
                stream_orderings,
                presence_states,
            )

        return stream_orderings[-1], self._presence_id_gen.get_current_token()

    def _update_presence_txn(self, txn, stream_orderings, presence_states):
        for stream_id, state in zip(stream_orderings, presence_states):
            txn.call_after(self.presence_stream_cache.entity_has_changed,
                           state.user_id, stream_id)
            txn.call_after(self._get_presence_for_user.invalidate,
                           (state.user_id, ))

        # Delete old rows to stop database from getting really big
        sql = "DELETE FROM presence_stream WHERE stream_id < ? AND "

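        # `stream_id` still holds the last (largest) of the new stream
        # orderings assigned above, so every existing presence row older than
        # that is removed for these users before the new rows are inserted.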
        for states in batch_iter(presence_states, 50):
            clause, args = make_in_list_sql_clause(self.database_engine,
                                                   "user_id",
                                                   [s.user_id for s in states])
            txn.execute(sql + clause, [stream_id] + list(args))

        # Actually insert new rows
        self.db_pool.simple_insert_many_txn(
            txn,
            table="presence_stream",
            keys=(
                "stream_id",
                "user_id",
                "state",
                "last_active_ts",
                "last_federation_update_ts",
                "last_user_sync_ts",
                "status_msg",
                "currently_active",
                "instance_name",
            ),
            values=[(
                stream_id,
                state.user_id,
                state.state,
                state.last_active_ts,
                state.last_federation_update_ts,
                state.last_user_sync_ts,
                state.status_msg,
                state.currently_active,
                self._instance_name,
            ) for stream_id, state in zip(stream_orderings, presence_states)],
        )

    async def get_all_presence_updates(
            self, instance_name: str, last_id: int, current_id: int,
            limit: int) -> Tuple[List[Tuple[int, list]], int, bool]:
        """Get updates for presence replication stream.

        Args:
            instance_name: The writer we want to fetch updates from. Unused
                here since there is only ever one writer.
            last_id: The token to fetch updates from. Exclusive.
            current_id: The token to fetch updates up to. Inclusive.
            limit: The requested limit for the number of rows to return. The
                function may return more or fewer rows.

        Returns:
            A tuple consisting of: the updates, a token to use to fetch
            subsequent updates, and whether we returned fewer rows than exists
            between the requested tokens due to the limit.

            The token returned can be used in a subsequent call to this
            function to get further updates.

            The updates are a list of 2-tuples of stream ID and the row data
        """

        if last_id == current_id:
            return [], current_id, False

        def get_all_presence_updates_txn(txn):
            sql = """
                SELECT stream_id, user_id, state, last_active_ts,
                    last_federation_update_ts, last_user_sync_ts,
                    status_msg, currently_active
                FROM presence_stream
                WHERE ? < stream_id AND stream_id <= ?
                ORDER BY stream_id ASC
                LIMIT ?
            """
            txn.execute(sql, (last_id, current_id, limit))
            updates = [(row[0], row[1:]) for row in txn]

            upper_bound = current_id
            limited = False
            if len(updates) >= limit:
                upper_bound = updates[-1][0]
                limited = True

            return updates, upper_bound, limited

        return await self.db_pool.runInteraction("get_all_presence_updates",
                                                 get_all_presence_updates_txn)
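
    # Illustrative sketch of how a replication consumer is expected to use the
    # (updates, upper_bound, limited) contract above. `fetch` stands in for
    # get_all_presence_updates and is an assumption for illustration only.
    #
    #   token = last_seen_token
    #   while True:
    #       updates, token, limited = await fetch("master", token, current_token, 100)
    #       handle(updates)
    #       if not limited:
    #           break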

    @cached()
    def _get_presence_for_user(self, user_id):
        raise NotImplementedError()

    @cachedList(
        cached_method_name="_get_presence_for_user",
        list_name="user_ids",
        num_args=1,
    )
    async def get_presence_for_users(self, user_ids):
        rows = await self.db_pool.simple_select_many_batch(
            table="presence_stream",
            column="user_id",
            iterable=user_ids,
            keyvalues={},
            retcols=(
                "user_id",
                "state",
                "last_active_ts",
                "last_federation_update_ts",
                "last_user_sync_ts",
                "status_msg",
                "currently_active",
            ),
            desc="get_presence_for_users",
        )

        for row in rows:
            row["currently_active"] = bool(row["currently_active"])

        return {row["user_id"]: UserPresenceState(**row) for row in rows}

    async def should_user_receive_full_presence_with_token(
        self,
        user_id: str,
        from_token: int,
    ) -> bool:
        """Check whether the given user should receive full presence using the stream token
        they're updating from.

        Args:
            user_id: The ID of the user to check.
            from_token: The stream token included in their /sync token.

        Returns:
            True if the user should have full presence sent to them, False otherwise.
        """
        def _should_user_receive_full_presence_with_token_txn(txn):
            sql = """
                SELECT 1 FROM users_to_send_full_presence_to
                WHERE user_id = ?
                AND presence_stream_id >= ?
            """
            txn.execute(sql, (user_id, from_token))
            return bool(txn.fetchone())

        return await self.db_pool.runInteraction(
            "should_user_receive_full_presence_with_token",
            _should_user_receive_full_presence_with_token_txn,
        )

    async def add_users_to_send_full_presence_to(self,
                                                 user_ids: Iterable[str]):
        """Adds to the list of users who should receive a full snapshot of presence
        upon their next sync.

        Args:
            user_ids: An iterable of user IDs.
        """
        # Add user entries to the table, updating the presence_stream_id column if the user already
        # exists in the table.
        presence_stream_id = self._presence_id_gen.get_current_token()
        await self.db_pool.simple_upsert_many(
            table="users_to_send_full_presence_to",
            key_names=("user_id", ),
            key_values=[(user_id, ) for user_id in user_ids],
            value_names=("presence_stream_id", ),
            # We save the current presence stream ID token along with the user ID entry so
            # that when a user /sync's, even if they sync multiple times across separate
            # devices at different times, each device will receive full presence once: when
            # the presence stream ID in their sync token is less than the one in the table
            # for their user ID.
            value_values=[(presence_stream_id, ) for _ in user_ids],
            desc="add_users_to_send_full_presence_to",
        )
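
    # Rough sketch of the handshake between this method and
    # should_user_receive_full_presence_with_token above: the current presence
    # stream token is recorded for each user here, and a later /sync whose
    # presence token is at or below that recorded value triggers a full
    # presence snapshot. (`store` names an instance of this class and is used
    # for illustration only.)
    #
    #   await store.add_users_to_send_full_presence_to(["@alice:example.org"])
    #   # later, while handling /sync with presence token `from_token`:
    #   if await store.should_user_receive_full_presence_with_token(
    #       "@alice:example.org", from_token
    #   ):
    #       ...  # send full presence rather than incremental updates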

    async def get_presence_for_all_users(
        self,
        include_offline: bool = True,
    ) -> Dict[str, UserPresenceState]:
        """Retrieve the current presence state for all users.

        Note that the presence_stream table is culled frequently, so it should only
        contain the latest presence state for each user.

        Args:
            include_offline: Whether to include offline presence states

        Returns:
            A dict of user IDs to their current UserPresenceState.
        """
        users_to_state = {}

        exclude_keyvalues = None
        if not include_offline:
            # Exclude offline presence state
            exclude_keyvalues = {"state": "offline"}

        # This may be a very heavy database query.
        # We paginate in order to not block a database connection.
        limit = 100
        offset = 0
        while True:
            rows = await self.db_pool.runInteraction(
                "get_presence_for_all_users",
                self.db_pool.simple_select_list_paginate_txn,
                "presence_stream",
                orderby="stream_id",
                start=offset,
                limit=limit,
                exclude_keyvalues=exclude_keyvalues,
                retcols=(
                    "user_id",
                    "state",
                    "last_active_ts",
                    "last_federation_update_ts",
                    "last_user_sync_ts",
                    "status_msg",
                    "currently_active",
                ),
                order_direction="ASC",
            )

            for row in rows:
                users_to_state[row["user_id"]] = UserPresenceState(**row)

            # We've run out of updates to query
            if len(rows) < limit:
                break

            offset += limit

        return users_to_state

    def get_current_presence_token(self):
        return self._presence_id_gen.get_current_token()

    def _get_active_presence(self, db_conn: Connection):
        """Fetch non-offline presence from the database so that we can register
        the appropriate time outs.
        """

        # The `presence_stream_state_not_offline_idx` index should be used for this
        # query.
        sql = (
            "SELECT user_id, state, last_active_ts, last_federation_update_ts,"
            " last_user_sync_ts, status_msg, currently_active FROM presence_stream"
            " WHERE state != ?")

        txn = db_conn.cursor()
        txn.execute(sql, (PresenceState.OFFLINE, ))
        rows = self.db_pool.cursor_to_dict(txn)
        txn.close()

        for row in rows:
            row["currently_active"] = bool(row["currently_active"])

        return [UserPresenceState(**row) for row in rows]

    def take_presence_startup_info(self):
        active_on_startup = self._presence_on_startup
        self._presence_on_startup = None
        return active_on_startup

    def process_replication_rows(self, stream_name, instance_name, token,
                                 rows):
        if stream_name == PresenceStream.NAME:
            self._presence_id_gen.advance(instance_name, token)
            for row in rows:
                self.presence_stream_cache.entity_has_changed(
                    row.user_id, token)
                self._get_presence_for_user.invalidate((row.user_id, ))
        return super().process_replication_rows(stream_name, instance_name,
                                                token, rows)
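

# A minimal, self-contained model of the StreamChangeCache behaviour these
# stores rely on (presence, receipts, device inbox). It is a sketch of the
# idea only, not Synapse's actual implementation: anything at or below the
# cache's earliest known position must conservatively be treated as changed.
class StreamChangeCacheSketch:
    def __init__(self, earliest_known_stream_pos, prefilled_cache=None):
        self._earliest_known = earliest_known_stream_pos
        self._entity_to_pos = dict(prefilled_cache or {})

    def entity_has_changed(self, entity, stream_pos):
        # Record the latest position at which this entity changed.
        current = self._entity_to_pos.get(entity, self._earliest_known)
        self._entity_to_pos[entity] = max(current, stream_pos)

    def has_entity_changed(self, entity, stream_pos):
        if stream_pos < self._earliest_known:
            # We cannot know what happened before the cache's horizon.
            return True
        return self._entity_to_pos.get(entity, self._earliest_known) > stream_pos


_cache = StreamChangeCacheSketch(10, prefilled_cache={"@alice:example.org": 12})
assert _cache.has_entity_changed("@alice:example.org", 11)
assert not _cache.has_entity_changed("@bob:example.org", 11)
assert _cache.has_entity_changed("@bob:example.org", 5)  # before the horizon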
Ejemplo n.º 49
0
class ReceiptsStore(SQLBaseStore):
    def __init__(self, hs):
        super(ReceiptsStore, self).__init__(hs)

        self._receipts_stream_cache = StreamChangeCache(
            "ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token()
        )

    @cachedInlineCallbacks()
    def get_users_with_read_receipts_in_room(self, room_id):
        receipts = yield self.get_receipts_for_room(room_id, "m.read")
        defer.returnValue(set(r['user_id'] for r in receipts))

    def _invalidate_get_users_with_receipts_in_room(self, room_id, receipt_type,
                                                    user_id):
        if receipt_type != "m.read":
            return

        # Returns an ObservableDeferred
        res = self.get_users_with_read_receipts_in_room.cache.get((room_id,), None)

        if res and res.called and user_id in res.result:
            # We'd only be adding to the set, so no point invalidating if the
            # user is already there
            return

        self.get_users_with_read_receipts_in_room.invalidate((room_id,))

    @cached(num_args=2)
    def get_receipts_for_room(self, room_id, receipt_type):
        return self._simple_select_list(
            table="receipts_linearized",
            keyvalues={
                "room_id": room_id,
                "receipt_type": receipt_type,
            },
            retcols=("user_id", "event_id"),
            desc="get_receipts_for_room",
        )

    @cached(num_args=3)
    def get_last_receipt_event_id_for_user(self, user_id, room_id, receipt_type):
        return self._simple_select_one_onecol(
            table="receipts_linearized",
            keyvalues={
                "room_id": room_id,
                "receipt_type": receipt_type,
                "user_id": user_id
            },
            retcol="event_id",
            desc="get_own_receipt_for_user",
            allow_none=True,
        )

    @cachedInlineCallbacks(num_args=2)
    def get_receipts_for_user(self, user_id, receipt_type):
        rows = yield self._simple_select_list(
            table="receipts_linearized",
            keyvalues={
                "user_id": user_id,
                "receipt_type": receipt_type,
            },
            retcols=("room_id", "event_id"),
            desc="get_receipts_for_user",
        )

        defer.returnValue({row["room_id"]: row["event_id"] for row in rows})

    @defer.inlineCallbacks
    def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
        """Get receipts for multiple rooms for sending to clients.

        Args:
            room_ids (list): List of room_ids.
            to_key (int): Max stream id to fetch receipts up to.
            from_key (int): Min stream id to fetch receipts from. None fetches
                from the start.

        Returns:
            list: A list of receipts.
        """
        room_ids = set(room_ids)

        if from_key:
            room_ids = yield self._receipts_stream_cache.get_entities_changed(
                room_ids, from_key
            )

        results = yield self._get_linearized_receipts_for_rooms(
            room_ids, to_key, from_key=from_key
        )

        defer.returnValue([ev for res in results.values() for ev in res])

    @cachedInlineCallbacks(num_args=3, max_entries=5000, lru=True, tree=True)
    def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
        """Get receipts for a single room for sending to clients.

        Args:
            room_id (str): The room ID.
            to_key (int): Max stream id to fetch receipts up to.
            from_key (int): Min stream id to fetch receipts from. None fetches
                from the start.

        Returns:
            list: A list of receipts.
        """
        def f(txn):
            if from_key:
                sql = (
                    "SELECT * FROM receipts_linearized WHERE"
                    " room_id = ? AND stream_id > ? AND stream_id <= ?"
                )

                txn.execute(
                    sql,
                    (room_id, from_key, to_key)
                )
            else:
                sql = (
                    "SELECT * FROM receipts_linearized WHERE"
                    " room_id = ? AND stream_id <= ?"
                )

                txn.execute(
                    sql,
                    (room_id, to_key)
                )

            rows = self.cursor_to_dict(txn)

            return rows

        rows = yield self.runInteraction(
            "get_linearized_receipts_for_room", f
        )

        if not rows:
            defer.returnValue([])

        content = {}
        for row in rows:
            content.setdefault(
                row["event_id"], {}
            ).setdefault(
                row["receipt_type"], {}
            )[row["user_id"]] = json.loads(row["data"])

        defer.returnValue([{
            "type": "m.receipt",
            "room_id": room_id,
            "content": content,
        }])

    @cachedList(cached_method_name="get_linearized_receipts_for_room",
                list_name="room_ids", num_args=3, inlineCallbacks=True)
    def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
        if not room_ids:
            defer.returnValue({})

        def f(txn):
            if from_key:
                sql = (
                    "SELECT * FROM receipts_linearized WHERE"
                    " room_id IN (%s) AND stream_id > ? AND stream_id <= ?"
                ) % (
                    ",".join(["?"] * len(room_ids))
                )
                args = list(room_ids)
                args.extend([from_key, to_key])

                txn.execute(sql, args)
            else:
                sql = (
                    "SELECT * FROM receipts_linearized WHERE"
                    " room_id IN (%s) AND stream_id <= ?"
                ) % (
                    ",".join(["?"] * len(room_ids))
                )

                args = list(room_ids)
                args.append(to_key)

                txn.execute(sql, args)

            return self.cursor_to_dict(txn)

        txn_results = yield self.runInteraction(
            "_get_linearized_receipts_for_rooms", f
        )

        results = {}
        for row in txn_results:
            # We want a single event per room, since we want to batch the
            # receipts by room, event and type.
            room_event = results.setdefault(row["room_id"], {
                "type": "m.receipt",
                "room_id": row["room_id"],
                "content": {},
            })

            # The content is of the form:
            # {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. }
            event_entry = room_event["content"].setdefault(row["event_id"], {})
            receipt_type = event_entry.setdefault(row["receipt_type"], {})

            receipt_type[row["user_id"]] = json.loads(row["data"])

        results = {
            room_id: [results[room_id]] if room_id in results else []
            for room_id in room_ids
        }
        defer.returnValue(results)

    def get_max_receipt_stream_id(self):
        return self._receipts_id_gen.get_current_token()

    def insert_linearized_receipt_txn(self, txn, room_id, receipt_type,
                                      user_id, event_id, data, stream_id):
        txn.call_after(
            self.get_receipts_for_room.invalidate, (room_id, receipt_type)
        )
        txn.call_after(
            self._invalidate_get_users_with_receipts_in_room,
            room_id, receipt_type, user_id,
        )
        txn.call_after(
            self.get_receipts_for_user.invalidate, (user_id, receipt_type)
        )
        # FIXME: This shouldn't invalidate the whole cache
        txn.call_after(self.get_linearized_receipts_for_room.invalidate_many, (room_id,))

        txn.call_after(
            self._receipts_stream_cache.entity_has_changed,
            room_id, stream_id
        )

        txn.call_after(
            self.get_last_receipt_event_id_for_user.invalidate,
            (user_id, room_id, receipt_type)
        )

        res = self._simple_select_one_txn(
            txn,
            table="events",
            retcols=["topological_ordering", "stream_ordering"],
            keyvalues={"event_id": event_id},
            allow_none=True
        )

        topological_ordering = int(res["topological_ordering"]) if res else None
        stream_ordering = int(res["stream_ordering"]) if res else None

        # We don't want to clobber receipts for more recent events, so we
        # have to compare orderings of existing receipts
        sql = (
            "SELECT topological_ordering, stream_ordering, event_id FROM events"
            " INNER JOIN receipts_linearized as r USING (event_id, room_id)"
            " WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ?"
        )

        txn.execute(sql, (room_id, receipt_type, user_id))
        results = txn.fetchall()

        if results and topological_ordering:
            for to, so, _ in results:
                if int(to) > topological_ordering:
                    return False
                elif int(to) == topological_ordering and int(so) >= stream_ordering:
                    return False

        self._simple_delete_txn(
            txn,
            table="receipts_linearized",
            keyvalues={
                "room_id": room_id,
                "receipt_type": receipt_type,
                "user_id": user_id,
            }
        )

        self._simple_insert_txn(
            txn,
            table="receipts_linearized",
            values={
                "stream_id": stream_id,
                "room_id": room_id,
                "receipt_type": receipt_type,
                "user_id": user_id,
                "event_id": event_id,
                "data": json.dumps(data),
            }
        )

        if receipt_type == "m.read" and topological_ordering:
            self._remove_old_push_actions_before_txn(
                txn,
                room_id=room_id,
                user_id=user_id,
                topological_ordering=topological_ordering,
            )

        return True

    @defer.inlineCallbacks
    def insert_receipt(self, room_id, receipt_type, user_id, event_ids, data):
        """Insert a receipt, either from local client or remote server.

        Automatically does conversion between linearized and graph
        representations.
        """
        if not event_ids:
            return

        if len(event_ids) == 1:
            linearized_event_id = event_ids[0]
        else:
            # we need to convert the given points in the graph into
            # linearized form.
            # TODO: Make this better.
            def graph_to_linear(txn):
                query = (
                    "SELECT event_id FROM events"
                    " WHERE room_id = ? AND stream_ordering IN ("
                    "  SELECT max(stream_ordering) FROM events"
                    "  WHERE event_id IN (%s)"
                    ")"
                ) % (",".join(["?"] * len(event_ids)))

                txn.execute(query, [room_id] + event_ids)
                rows = txn.fetchall()
                if rows:
                    return rows[0][0]
                else:
                    raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,))

            linearized_event_id = yield self.runInteraction(
                "insert_receipt_conv", graph_to_linear
            )

        stream_id_manager = self._receipts_id_gen.get_next()
        with stream_id_manager as stream_id:
            have_persisted = yield self.runInteraction(
                "insert_linearized_receipt",
                self.insert_linearized_receipt_txn,
                room_id, receipt_type, user_id, linearized_event_id,
                data,
                stream_id=stream_id,
            )

            if not have_persisted:
                defer.returnValue(None)

        yield self.insert_graph_receipt(
            room_id, receipt_type, user_id, event_ids, data
        )

        max_persisted_id = self._receipts_id_gen.get_current_token()

        defer.returnValue((stream_id, max_persisted_id))
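
    # Sketch of the graph -> linearized conversion performed above: when a
    # federated receipt names several event_ids, the one appearing latest in
    # the room's linearized ordering (the max stream_ordering) is chosen and
    # the receipt is stored against that single event. Roughly:
    #
    #   orderings = {"$a": 3, "$b": 7}   # event_id -> stream_ordering (made up)
    #   linearized_event_id = max(orderings, key=orderings.get)   # "$b"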

    def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids,
                             data):
        return self.runInteraction(
            "insert_graph_receipt",
            self.insert_graph_receipt_txn,
            room_id, receipt_type, user_id, event_ids, data
        )

    def insert_graph_receipt_txn(self, txn, room_id, receipt_type,
                                 user_id, event_ids, data):
        txn.call_after(
            self.get_receipts_for_room.invalidate, (room_id, receipt_type)
        )
        txn.call_after(
            self._invalidate_get_users_with_receipts_in_room,
            room_id, receipt_type, user_id,
        )
        txn.call_after(
            self.get_receipts_for_user.invalidate, (user_id, receipt_type)
        )
        # FIXME: This shouldn't invalidate the whole cache
        txn.call_after(self.get_linearized_receipts_for_room.invalidate_many, (room_id,))

        self._simple_delete_txn(
            txn,
            table="receipts_graph",
            keyvalues={
                "room_id": room_id,
                "receipt_type": receipt_type,
                "user_id": user_id,
            }
        )
        self._simple_insert_txn(
            txn,
            table="receipts_graph",
            values={
                "room_id": room_id,
                "receipt_type": receipt_type,
                "user_id": user_id,
                "event_ids": json.dumps(event_ids),
                "data": json.dumps(data),
            }
        )

    def get_all_updated_receipts(self, last_id, current_id, limit=None):
        if last_id == current_id:
            return defer.succeed([])

        def get_all_updated_receipts_txn(txn):
            sql = (
                "SELECT stream_id, room_id, receipt_type, user_id, event_id, data"
                " FROM receipts_linearized"
                " WHERE ? < stream_id AND stream_id <= ?"
                " ORDER BY stream_id ASC"
            )
            args = [last_id, current_id]
            if limit is not None:
                sql += " LIMIT ?"
                args.append(limit)
            txn.execute(sql, args)

            return txn.fetchall()
        return self.runInteraction(
            "get_all_updated_receipts", get_all_updated_receipts_txn
        )
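

# A self-contained sketch of the batching done in
# _get_linearized_receipts_for_rooms above: flat receipt rows are folded into
# one m.receipt event per room, keyed by event, receipt type and user. The
# rows are made-up examples; in the store the `data` column holds JSON text
# and is decoded with json.loads first.
def batch_receipt_rows(rows):
    results = {}
    for row in rows:
        room_event = results.setdefault(
            row["room_id"],
            {"type": "m.receipt", "room_id": row["room_id"], "content": {}},
        )
        event_entry = room_event["content"].setdefault(row["event_id"], {})
        event_entry.setdefault(row["receipt_type"], {})[row["user_id"]] = row["data"]
    return results


_rows = [
    {"room_id": "!r:example.org", "event_id": "$e1", "receipt_type": "m.read",
     "user_id": "@alice:example.org", "data": {"ts": 1}},
    {"room_id": "!r:example.org", "event_id": "$e1", "receipt_type": "m.read",
     "user_id": "@bob:example.org", "data": {"ts": 2}},
]
_batched = batch_receipt_rows(_rows)
assert _batched["!r:example.org"]["content"]["$e1"]["m.read"]["@alice:example.org"] == {"ts": 1}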
Ejemplo n.º 50
0
    def __init__(self, hs):
        super(ReceiptsStore, self).__init__(hs)

        self._receipts_stream_cache = StreamChangeCache(
            "ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token()
        )
Ejemplo n.º 51
0
class SlavedEventStore(BaseSlavedStore):

    def __init__(self, db_conn, hs):
        super(SlavedEventStore, self).__init__(db_conn, hs)
        self._stream_id_gen = SlavedIdTracker(
            db_conn, "events", "stream_ordering",
        )
        self._backfill_id_gen = SlavedIdTracker(
            db_conn, "events", "stream_ordering", step=-1
        )
        events_max = self._stream_id_gen.get_current_token()
        event_cache_prefill, min_event_val = self._get_cache_dict(
            db_conn, "events",
            entity_column="room_id",
            stream_column="stream_ordering",
            max_value=events_max,
        )
        self._events_stream_cache = StreamChangeCache(
            "EventsRoomStreamChangeCache", min_event_val,
            prefilled_cache=event_cache_prefill,
        )
        self._membership_stream_cache = StreamChangeCache(
            "MembershipStreamChangeCache", events_max,
        )

        self.stream_ordering_month_ago = 0
        self._stream_order_on_start = self.get_room_max_stream_ordering()

    # Cached functions can't be accessed through a class instance so we need
    # to reach inside the __dict__ to extract them.
    get_rooms_for_user = RoomMemberStore.__dict__["get_rooms_for_user"]
    get_users_in_room = RoomMemberStore.__dict__["get_users_in_room"]
    get_latest_event_ids_in_room = EventFederationStore.__dict__[
        "get_latest_event_ids_in_room"
    ]
    _get_current_state_for_key = StateStore.__dict__[
        "_get_current_state_for_key"
    ]
    get_invited_rooms_for_user = RoomMemberStore.__dict__[
        "get_invited_rooms_for_user"
    ]
    get_unread_event_push_actions_by_room_for_user = (
        EventPushActionsStore.__dict__["get_unread_event_push_actions_by_room_for_user"]
    )
    _get_state_group_for_events = (
        StateStore.__dict__["_get_state_group_for_events"]
    )
    _get_state_group_for_event = (
        StateStore.__dict__["_get_state_group_for_event"]
    )
    _get_state_groups_from_groups = (
        StateStore.__dict__["_get_state_groups_from_groups"]
    )
    _get_state_groups_from_groups_txn = (
        DataStore._get_state_groups_from_groups_txn.__func__
    )
    _get_state_group_from_group = (
        StateStore.__dict__["_get_state_group_from_group"]
    )
    get_recent_event_ids_for_room = (
        StreamStore.__dict__["get_recent_event_ids_for_room"]
    )

    get_unread_push_actions_for_user_in_range_for_http = (
        DataStore.get_unread_push_actions_for_user_in_range_for_http.__func__
    )
    get_unread_push_actions_for_user_in_range_for_email = (
        DataStore.get_unread_push_actions_for_user_in_range_for_email.__func__
    )
    get_push_action_users_in_range = (
        DataStore.get_push_action_users_in_range.__func__
    )
    get_event = DataStore.get_event.__func__
    get_events = DataStore.get_events.__func__
    get_current_state = DataStore.get_current_state.__func__
    get_current_state_for_key = DataStore.get_current_state_for_key.__func__
    get_rooms_for_user_where_membership_is = (
        DataStore.get_rooms_for_user_where_membership_is.__func__
    )
    get_membership_changes_for_user = (
        DataStore.get_membership_changes_for_user.__func__
    )
    get_room_events_max_id = DataStore.get_room_events_max_id.__func__
    get_room_events_stream_for_room = (
        DataStore.get_room_events_stream_for_room.__func__
    )
    get_events_around = DataStore.get_events_around.__func__
    get_state_for_event = DataStore.get_state_for_event.__func__
    get_state_for_events = DataStore.get_state_for_events.__func__
    get_state_groups = DataStore.get_state_groups.__func__
    get_state_groups_ids = DataStore.get_state_groups_ids.__func__
    get_state_ids_for_event = DataStore.get_state_ids_for_event.__func__
    get_state_ids_for_events = DataStore.get_state_ids_for_events.__func__
    get_joined_users_from_state = DataStore.get_joined_users_from_state.__func__
    get_joined_users_from_context = DataStore.get_joined_users_from_context.__func__
    _get_joined_users_from_context = (
        RoomMemberStore.__dict__["_get_joined_users_from_context"]
    )

    get_recent_events_for_room = DataStore.get_recent_events_for_room.__func__
    get_room_events_stream_for_rooms = (
        DataStore.get_room_events_stream_for_rooms.__func__
    )
    is_host_joined = DataStore.is_host_joined.__func__
    _is_host_joined = RoomMemberStore.__dict__["_is_host_joined"]
    get_stream_token_for_event = DataStore.get_stream_token_for_event.__func__

    _set_before_and_after = staticmethod(DataStore._set_before_and_after)

    _get_events = DataStore._get_events.__func__
    _get_events_from_cache = DataStore._get_events_from_cache.__func__

    _invalidate_get_event_cache = DataStore._invalidate_get_event_cache.__func__
    _enqueue_events = DataStore._enqueue_events.__func__
    _do_fetch = DataStore._do_fetch.__func__
    _fetch_event_rows = DataStore._fetch_event_rows.__func__
    _get_event_from_row = DataStore._get_event_from_row.__func__
    _get_rooms_for_user_where_membership_is_txn = (
        DataStore._get_rooms_for_user_where_membership_is_txn.__func__
    )
    _get_members_rows_txn = DataStore._get_members_rows_txn.__func__
    _get_state_for_groups = DataStore._get_state_for_groups.__func__
    _get_all_state_from_cache = DataStore._get_all_state_from_cache.__func__
    _get_events_around_txn = DataStore._get_events_around_txn.__func__
    _get_some_state_from_cache = DataStore._get_some_state_from_cache.__func__

    get_backfill_events = DataStore.get_backfill_events.__func__
    _get_backfill_events = DataStore._get_backfill_events.__func__
    get_missing_events = DataStore.get_missing_events.__func__
    _get_missing_events = DataStore._get_missing_events.__func__

    get_auth_chain = DataStore.get_auth_chain.__func__
    get_auth_chain_ids = DataStore.get_auth_chain_ids.__func__
    _get_auth_chain_ids_txn = DataStore._get_auth_chain_ids_txn.__func__

    get_room_max_stream_ordering = DataStore.get_room_max_stream_ordering.__func__

    get_forward_extremeties_for_room = (
        DataStore.get_forward_extremeties_for_room.__func__
    )
    _get_forward_extremeties_for_room = (
        EventFederationStore.__dict__["_get_forward_extremeties_for_room"]
    )

    def stream_positions(self):
        result = super(SlavedEventStore, self).stream_positions()
        result["events"] = self._stream_id_gen.get_current_token()
        result["backfill"] = -self._backfill_id_gen.get_current_token()
        return result

    def process_replication(self, result):
        state_resets = set(
            r[0] for r in result.get("state_resets", {"rows": []})["rows"]
        )

        stream = result.get("events")
        if stream:
            self._stream_id_gen.advance(int(stream["position"]))
            for row in stream["rows"]:
                self._process_replication_row(
                    row, backfilled=False, state_resets=state_resets
                )

        stream = result.get("backfill")
        if stream:
            self._backfill_id_gen.advance(-int(stream["position"]))
            for row in stream["rows"]:
                self._process_replication_row(
                    row, backfilled=True, state_resets=state_resets
                )

        stream = result.get("forward_ex_outliers")
        if stream:
            self._stream_id_gen.advance(int(stream["position"]))
            for row in stream["rows"]:
                event_id = row[1]
                self._invalidate_get_event_cache(event_id)

        stream = result.get("backward_ex_outliers")
        if stream:
            self._backfill_id_gen.advance(-int(stream["position"]))
            for row in stream["rows"]:
                event_id = row[1]
                self._invalidate_get_event_cache(event_id)

        return super(SlavedEventStore, self).process_replication(result)

    def _process_replication_row(self, row, backfilled, state_resets):
        position = row[0]
        internal = json.loads(row[1])
        event_json = json.loads(row[2])
        event = FrozenEvent(event_json, internal_metadata_dict=internal)
        self.invalidate_caches_for_event(
            event, backfilled, reset_state=position in state_resets
        )

    def invalidate_caches_for_event(self, event, backfilled, reset_state):
        if reset_state:
            self._get_current_state_for_key.invalidate_all()
            self.get_rooms_for_user.invalidate_all()
            self.get_users_in_room.invalidate((event.room_id,))

        self._invalidate_get_event_cache(event.event_id)

        self.get_latest_event_ids_in_room.invalidate((event.room_id,))

        self.get_unread_event_push_actions_by_room_for_user.invalidate_many(
            (event.room_id,)
        )

        if not backfilled:
            self._events_stream_cache.entity_has_changed(
                event.room_id, event.internal_metadata.stream_ordering
            )

        # self.get_unread_event_push_actions_by_room_for_user.invalidate_many(
        #     (event.room_id,)
        # )

        if event.type == EventTypes.Redaction:
            self._invalidate_get_event_cache(event.redacts)

        if event.type == EventTypes.Member:
            self.get_rooms_for_user.invalidate((event.state_key,))
            self.get_users_in_room.invalidate((event.room_id,))
            self._membership_stream_cache.entity_has_changed(
                event.state_key, event.internal_metadata.stream_ordering
            )
            self.get_invited_rooms_for_user.invalidate((event.state_key,))

        if not event.is_state():
            return

        if backfilled:
            return

        if (not event.internal_metadata.is_invite_from_remote()
                and event.internal_metadata.is_outlier()):
            return

        self._get_current_state_for_key.invalidate((
            event.room_id, event.type, event.state_key
        ))
Ejemplo n.º 52
0
    def __init__(self, database: DatabasePool, db_conn, hs):
        super().__init__(database, db_conn, hs)

        self._instance_name = hs.get_instance_name()

        # Map of (user_id, device_id) to the last stream_id that has been
        # deleted up to. This is so that we can no op deletions.
        self._last_device_delete_cache = ExpiringCache(
            cache_name="last_device_delete_cache",
            clock=self._clock,
            max_len=10000,
            expiry_ms=30 * 60 * 1000,
        )

        if isinstance(database.engine, PostgresEngine):
            self._can_write_to_device = (self._instance_name
                                         in hs.config.worker.writers.to_device)

            self._device_inbox_id_gen = MultiWriterIdGenerator(
                db_conn=db_conn,
                db=database,
                stream_name="to_device",
                instance_name=self._instance_name,
                tables=[("device_inbox", "instance_name", "stream_id")],
                sequence_name="device_inbox_sequence",
                writers=hs.config.worker.writers.to_device,
            )
        else:
            self._can_write_to_device = True
            self._device_inbox_id_gen = StreamIdGenerator(
                db_conn, "device_inbox", "stream_id")

        max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
        device_inbox_prefill, min_device_inbox_id = self.db_pool.get_cache_dict(
            db_conn,
            "device_inbox",
            entity_column="user_id",
            stream_column="stream_id",
            max_value=max_device_inbox_id,
            limit=1000,
        )
        self._device_inbox_stream_cache = StreamChangeCache(
            "DeviceInboxStreamChangeCache",
            min_device_inbox_id,
            prefilled_cache=device_inbox_prefill,
        )

        # The federation outbox and the local device inbox uses the same
        # stream_id generator.
        device_outbox_prefill, min_device_outbox_id = self.db_pool.get_cache_dict(
            db_conn,
            "device_federation_outbox",
            entity_column="destination",
            stream_column="stream_id",
            max_value=max_device_inbox_id,
            limit=1000,
        )
        self._device_federation_outbox_stream_cache = StreamChangeCache(
            "DeviceFederationOutboxStreamChangeCache",
            min_device_outbox_id,
            prefilled_cache=device_outbox_prefill,
        )
Ejemplo n.º 53
0
class DeviceInboxWorkerStore(SQLBaseStore):
    def __init__(self, database: DatabasePool, db_conn, hs):
        super().__init__(database, db_conn, hs)

        self._instance_name = hs.get_instance_name()

        # Map of (user_id, device_id) to the last stream_id that has been
        # deleted up to. This is so that we can no op deletions.
        self._last_device_delete_cache = ExpiringCache(
            cache_name="last_device_delete_cache",
            clock=self._clock,
            max_len=10000,
            expiry_ms=30 * 60 * 1000,
        )

        if isinstance(database.engine, PostgresEngine):
            self._can_write_to_device = (self._instance_name
                                         in hs.config.worker.writers.to_device)

            self._device_inbox_id_gen = MultiWriterIdGenerator(
                db_conn=db_conn,
                db=database,
                stream_name="to_device",
                instance_name=self._instance_name,
                tables=[("device_inbox", "instance_name", "stream_id")],
                sequence_name="device_inbox_sequence",
                writers=hs.config.worker.writers.to_device,
            )
        else:
            self._can_write_to_device = True
            self._device_inbox_id_gen = StreamIdGenerator(
                db_conn, "device_inbox", "stream_id")

        max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
        device_inbox_prefill, min_device_inbox_id = self.db_pool.get_cache_dict(
            db_conn,
            "device_inbox",
            entity_column="user_id",
            stream_column="stream_id",
            max_value=max_device_inbox_id,
            limit=1000,
        )
        self._device_inbox_stream_cache = StreamChangeCache(
            "DeviceInboxStreamChangeCache",
            min_device_inbox_id,
            prefilled_cache=device_inbox_prefill,
        )

        # The federation outbox and the local device inbox uses the same
        # stream_id generator.
        device_outbox_prefill, min_device_outbox_id = self.db_pool.get_cache_dict(
            db_conn,
            "device_federation_outbox",
            entity_column="destination",
            stream_column="stream_id",
            max_value=max_device_inbox_id,
            limit=1000,
        )
        self._device_federation_outbox_stream_cache = StreamChangeCache(
            "DeviceFederationOutboxStreamChangeCache",
            min_device_outbox_id,
            prefilled_cache=device_outbox_prefill,
        )

    def process_replication_rows(self, stream_name, instance_name, token,
                                 rows):
        if stream_name == ToDeviceStream.NAME:
            self._device_inbox_id_gen.advance(instance_name, token)
            for row in rows:
                if row.entity.startswith("@"):
                    self._device_inbox_stream_cache.entity_has_changed(
                        row.entity, token)
                else:
                    self._device_federation_outbox_stream_cache.entity_has_changed(
                        row.entity, token)
        return super().process_replication_rows(stream_name, instance_name,
                                                token, rows)

    def get_to_device_stream_token(self):
        return self._device_inbox_id_gen.get_current_token()

    async def get_new_messages_for_device(
        self,
        user_id: str,
        device_id: str,
        last_stream_id: int,
        current_stream_id: int,
        limit: int = 100,
    ) -> Tuple[List[dict], int]:
        """
        Args:
            user_id: The recipient user_id.
            device_id: The recipient device_id.
            last_stream_id: The last stream ID checked.
            current_stream_id: The current position of the to device
                message stream.
            limit: The maximum number of messages to retrieve.

        Returns:
            A list of messages for the device and where in the stream the messages got to.
        """
        has_changed = self._device_inbox_stream_cache.has_entity_changed(
            user_id, last_stream_id)
        if not has_changed:
            return ([], current_stream_id)

        def get_new_messages_for_device_txn(txn):
            sql = ("SELECT stream_id, message_json FROM device_inbox"
                   " WHERE user_id = ? AND device_id = ?"
                   " AND ? < stream_id AND stream_id <= ?"
                   " ORDER BY stream_id ASC"
                   " LIMIT ?")
            txn.execute(
                sql,
                (user_id, device_id, last_stream_id, current_stream_id, limit))
            messages = []
            for row in txn:
                stream_pos = row[0]
                messages.append(db_to_json(row[1]))
            if len(messages) < limit:
                stream_pos = current_stream_id
            return messages, stream_pos

        return await self.db_pool.runInteraction(
            "get_new_messages_for_device", get_new_messages_for_device_txn)

    @trace
    async def delete_messages_for_device(self, user_id: str, device_id: str,
                                         up_to_stream_id: int) -> int:
        """
        Args:
            user_id: The recipient user_id.
            device_id: The recipient device_id.
            up_to_stream_id: Where to delete messages up to.

        Returns:
            The number of messages deleted.
        """
        # If we have cached the last stream id we've deleted up to, we can
        # check if there is likely to be anything that needs deleting
        last_deleted_stream_id = self._last_device_delete_cache.get(
            (user_id, device_id), None)

        set_tag("last_deleted_stream_id", last_deleted_stream_id)

        if last_deleted_stream_id:
            has_changed = self._device_inbox_stream_cache.has_entity_changed(
                user_id, last_deleted_stream_id)
            if not has_changed:
                log_kv({"message": "No changes in cache since last check"})
                return 0

        def delete_messages_for_device_txn(txn):
            sql = ("DELETE FROM device_inbox"
                   " WHERE user_id = ? AND device_id = ?"
                   " AND stream_id <= ?")
            txn.execute(sql, (user_id, device_id, up_to_stream_id))
            return txn.rowcount

        count = await self.db_pool.runInteraction(
            "delete_messages_for_device", delete_messages_for_device_txn)

        log_kv({
            "message": "deleted {} messages for device".format(count),
            "count": count
        })

        # Update the cache, ensuring that we only ever increase the value
        last_deleted_stream_id = self._last_device_delete_cache.get(
            (user_id, device_id), 0)
        self._last_device_delete_cache[(user_id, device_id)] = max(
            last_deleted_stream_id, up_to_stream_id)

        return count

    @trace
    async def get_new_device_msgs_for_remote(self, destination, last_stream_id,
                                             current_stream_id,
                                             limit) -> Tuple[List[dict], int]:
        """
        Args:
            destination(str): The name of the remote server.
            last_stream_id(int|long): The last position of the device message stream
                that the server sent up to.
            current_stream_id(int|long): The current position of the device
                message stream.
        Returns:
            A list of messages for the device and where in the stream the messages got to.
        """

        set_tag("destination", destination)
        set_tag("last_stream_id", last_stream_id)
        set_tag("current_stream_id", current_stream_id)
        set_tag("limit", limit)

        has_changed = self._device_federation_outbox_stream_cache.has_entity_changed(
            destination, last_stream_id)
        if not has_changed or last_stream_id == current_stream_id:
            log_kv({"message": "No new messages in stream"})
            return ([], current_stream_id)

        if limit <= 0:
            # This can happen if we run out of room for EDUs in the transaction.
            return ([], last_stream_id)

        @trace
        def get_new_messages_for_remote_destination_txn(txn):
            sql = (
                "SELECT stream_id, messages_json FROM device_federation_outbox"
                " WHERE destination = ?"
                " AND ? < stream_id AND stream_id <= ?"
                " ORDER BY stream_id ASC"
                " LIMIT ?")
            txn.execute(
                sql, (destination, last_stream_id, current_stream_id, limit))
            messages = []
            for row in txn:
                stream_pos = row[0]
                messages.append(db_to_json(row[1]))
            if len(messages) < limit:
                log_kv({"message": "Set stream position to current position"})
                stream_pos = current_stream_id
            return messages, stream_pos

        return await self.db_pool.runInteraction(
            "get_new_device_msgs_for_remote",
            get_new_messages_for_remote_destination_txn,
        )

    @trace
    async def delete_device_msgs_for_remote(self, destination: str,
                                            up_to_stream_id: int) -> None:
        """Used to delete messages when the remote destination acknowledges
        their receipt.

        Args:
            destination: The destination server_name
            up_to_stream_id: Where to delete messages up to.
        """
        def delete_messages_for_remote_destination_txn(txn):
            sql = ("DELETE FROM device_federation_outbox"
                   " WHERE destination = ?"
                   " AND stream_id <= ?")
            txn.execute(sql, (destination, up_to_stream_id))

        await self.db_pool.runInteraction(
            "delete_device_msgs_for_remote",
            delete_messages_for_remote_destination_txn)

    async def get_all_new_device_messages(
            self, instance_name: str, last_id: int, current_id: int,
            limit: int) -> Tuple[List[Tuple[int, tuple]], int, bool]:
        """Get updates for to device replication stream.

        Args:
            instance_name: The writer we want to fetch updates from. Unused
                here since there is only ever one writer.
            last_id: The token to fetch updates from. Exclusive.
            current_id: The token to fetch updates up to. Inclusive.
            limit: The requested limit for the number of rows to return. The
                function may return more or fewer rows.

        Returns:
            A tuple consisting of: the updates, a token to use to fetch
            subsequent updates, and whether we returned fewer rows than exists
            between the requested tokens due to the limit.

            The token returned can be used in a subsequent call to this
            function to get further updates.

            The updates are a list of 2-tuples of stream ID and the row data
        """

        if last_id == current_id:
            return [], current_id, False

        def get_all_new_device_messages_txn(txn):
            # We limit like this as we might have multiple rows per stream_id, and
            # we want to make sure we always get all entries for any stream_id
            # we return.
            upper_pos = min(current_id, last_id + limit)
            sql = ("SELECT max(stream_id), user_id"
                   " FROM device_inbox"
                   " WHERE ? < stream_id AND stream_id <= ?"
                   " GROUP BY user_id")
            txn.execute(sql, (last_id, upper_pos))
            updates = [(row[0], row[1:]) for row in txn]

            sql = ("SELECT max(stream_id), destination"
                   " FROM device_federation_outbox"
                   " WHERE ? < stream_id AND stream_id <= ?"
                   " GROUP BY destination")
            txn.execute(sql, (last_id, upper_pos))
            updates.extend((row[0], row[1:]) for row in txn)

            # Order by ascending stream ordering
            updates.sort()

            limited = False
            upto_token = current_id
            if len(updates) >= limit:
                upto_token = updates[-1][0]
                limited = True

            return updates, upto_token, limited

        return await self.db_pool.runInteraction(
            "get_all_new_device_messages", get_all_new_device_messages_txn)

    @trace
    async def add_messages_to_device_inbox(
        self,
        local_messages_by_user_then_device: dict,
        remote_messages_by_destination: dict,
    ) -> int:
        """Used to send messages from this server.

        Args:
            local_messages_by_user_then_device:
                Dictionary of user_id to device_id to message.
            remote_messages_by_destination:
                Dictionary of destination server_name to the EDU JSON to send.

        Returns:
            The new stream_id.
        """

        assert self._can_write_to_device

        def add_messages_txn(txn, now_ms, stream_id):
            # Add the local messages directly to the local inbox.
            self._add_messages_to_local_device_inbox_txn(
                txn, stream_id, local_messages_by_user_then_device)

            # Add the remote messages to the federation outbox.
            # We'll send them to a remote server when we next send a
            # federation transaction to that destination.
            self.db_pool.simple_insert_many_txn(
                txn,
                table="device_federation_outbox",
                values=[{
                    "destination": destination,
                    "stream_id": stream_id,
                    "queued_ts": now_ms,
                    "messages_json": json_encoder.encode(edu),
                    "instance_name": self._instance_name,
                } for destination, edu in
                        remote_messages_by_destination.items()],
            )

        async with self._device_inbox_id_gen.get_next() as stream_id:
            now_ms = self.clock.time_msec()
            await self.db_pool.runInteraction("add_messages_to_device_inbox",
                                              add_messages_txn, now_ms,
                                              stream_id)
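            # Only advance the stream-change caches once the transaction above
            # has committed, so readers never see a cache position that is
            # ahead of the database.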
            for user_id in local_messages_by_user_then_device.keys():
                self._device_inbox_stream_cache.entity_has_changed(
                    user_id, stream_id)
            for destination in remote_messages_by_destination.keys():
                self._device_federation_outbox_stream_cache.entity_has_changed(
                    destination, stream_id)

        return self._device_inbox_id_gen.get_current_token()

    async def add_messages_from_remote_to_device_inbox(
            self, origin: str, message_id: str,
            local_messages_by_user_then_device: dict) -> int:
        assert self._can_write_to_device

        def add_messages_txn(txn, now_ms, stream_id):
            # Check if we've already inserted a matching message_id for that
            # origin. This can happen if the origin doesn't receive our
            # acknowledgement from the first time we received the message.
            already_inserted = self.db_pool.simple_select_one_txn(
                txn,
                table="device_federation_inbox",
                keyvalues={
                    "origin": origin,
                    "message_id": message_id
                },
                retcols=("message_id", ),
                allow_none=True,
            )
            if already_inserted is not None:
                return

            # Add an entry for this message_id so that we know we've processed
            # it.
            self.db_pool.simple_insert_txn(
                txn,
                table="device_federation_inbox",
                values={
                    "origin": origin,
                    "message_id": message_id,
                    "received_ts": now_ms,
                },
            )

            # Add the messages to the appropriate local device inboxes so that
            # they'll be sent to the devices when they next sync.
            self._add_messages_to_local_device_inbox_txn(
                txn, stream_id, local_messages_by_user_then_device)

        async with self._device_inbox_id_gen.get_next() as stream_id:
            now_ms = self.clock.time_msec()
            await self.db_pool.runInteraction(
                "add_messages_from_remote_to_device_inbox",
                add_messages_txn,
                now_ms,
                stream_id,
            )
            for user_id in local_messages_by_user_then_device.keys():
                self._device_inbox_stream_cache.entity_has_changed(
                    user_id, stream_id)

        return stream_id

    def _add_messages_to_local_device_inbox_txn(self, txn, stream_id,
                                                messages_by_user_then_device):
        assert self._can_write_to_device

        local_by_user_then_device = {}
        for user_id, messages_by_device in messages_by_user_then_device.items():
            messages_json_for_user = {}
            devices = list(messages_by_device.keys())
            if len(devices) == 1 and devices[0] == "*":
                # Handle wildcard device_ids.
                devices = self.db_pool.simple_select_onecol_txn(
                    txn,
                    table="devices",
                    keyvalues={"user_id": user_id},
                    retcol="device_id",
                )

                message_json = json_encoder.encode(messages_by_device["*"])
                for device_id in devices:
                    # Add the message for all devices for this user on this
                    # server.
                    messages_json_for_user[device_id] = message_json
            else:
                if not devices:
                    continue

                rows = self.db_pool.simple_select_many_txn(
                    txn,
                    table="devices",
                    keyvalues={"user_id": user_id},
                    column="device_id",
                    iterable=devices,
                    retcols=("device_id", ),
                )

                for row in rows:
                    # Only insert into the local inbox if the device exists on
                    # this server
                    device_id = row["device_id"]
                    message_json = json_encoder.encode(
                        messages_by_device[device_id])
                    messages_json_for_user[device_id] = message_json

            if messages_json_for_user:
                local_by_user_then_device[user_id] = messages_json_for_user

        if not local_by_user_then_device:
            return

        self.db_pool.simple_insert_many_txn(
            txn,
            table="device_inbox",
            values=[{
                "user_id": user_id,
                "device_id": device_id,
                "stream_id": stream_id,
                "message_json": message_json,
                "instance_name": self._instance_name,
            } for user_id, messages_by_device in
                    local_by_user_then_device.items()
                    for device_id, message_json in messages_by_device.items()],
        )
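
A minimal usage sketch for the store above; the caller name and message
contents are assumptions for illustration. The nested dict shapes (user_id ->
device_id -> message content, and destination -> EDU content) are what
add_messages_to_device_inbox and _add_messages_to_local_device_inbox_txn
expect, including the "*" wildcard device id.

# Hypothetical caller, not part of the store above.
async def send_example_to_device_messages(store):
    local_messages = {
        "@alice:example.org": {
            # "*" fans the message out to every device the user owns.
            "*": {"example_key": "example_value"},
        },
    }
    remote_messages = {
        # The EDU content is JSON-encoded verbatim into the federation outbox.
        "remote.example.org": {"example_key": "example_value"},
    }
    # Returns the new stream_id once both inboxes have been written.
    return await store.add_messages_to_device_inbox(
        local_messages, remote_messages,
    )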
Example No. 54
0
class AccountDataWorkerStore(SQLBaseStore):
    """This is an abstract base class where subclasses must implement
    `get_max_account_data_stream_id` which can be called in the initializer.
    """

    # This ABCMeta metaclass ensures that we cannot be instantiated without
    # the abstract methods being implemented.
    __metaclass__ = abc.ABCMeta

    def __init__(self, db_conn, hs):
        account_max = self.get_max_account_data_stream_id()
        self._account_data_stream_cache = StreamChangeCache(
            "AccountDataAndTagsChangeCache", account_max,
        )

        super(AccountDataWorkerStore, self).__init__(db_conn, hs)

    @abc.abstractmethod
    def get_max_account_data_stream_id(self):
        """Get the current max stream ID for account data stream

        Returns:
            int
        """
        raise NotImplementedError()

    @cached()
    def get_account_data_for_user(self, user_id):
        """Get all the client account_data for a user.

        Args:
            user_id(str): The user to get the account_data for.
        Returns:
            A deferred pair of a dict of global account_data and a dict
            mapping from room_id string to per room account_data dicts.
        """

        def get_account_data_for_user_txn(txn):
            rows = self._simple_select_list_txn(
                txn, "account_data", {"user_id": user_id},
                ["account_data_type", "content"]
            )

            global_account_data = {
                row["account_data_type"]: json.loads(row["content"]) for row in rows
            }

            rows = self._simple_select_list_txn(
                txn, "room_account_data", {"user_id": user_id},
                ["room_id", "account_data_type", "content"]
            )

            by_room = {}
            for row in rows:
                room_data = by_room.setdefault(row["room_id"], {})
                room_data[row["account_data_type"]] = json.loads(row["content"])

            return (global_account_data, by_room)

        return self.runInteraction(
            "get_account_data_for_user", get_account_data_for_user_txn
        )

    @cachedInlineCallbacks(num_args=2, max_entries=5000)
    def get_global_account_data_by_type_for_user(self, data_type, user_id):
        """
        Returns:
            Deferred: A dict of the account_data content for the given type,
            or None if the user has no account_data of that type.
        """
        result = yield self._simple_select_one_onecol(
            table="account_data",
            keyvalues={
                "user_id": user_id,
                "account_data_type": data_type,
            },
            retcol="content",
            desc="get_global_account_data_by_type_for_user",
            allow_none=True,
        )

        if result:
            defer.returnValue(json.loads(result))
        else:
            defer.returnValue(None)

    @cached(num_args=2)
    def get_account_data_for_room(self, user_id, room_id):
        """Get all the client account_data for a user for a room.

        Args:
            user_id(str): The user to get the account_data for.
            room_id(str): The room to get the account_data for.
        Returns:
            A deferred dict of the room account_data
        """
        def get_account_data_for_room_txn(txn):
            rows = self._simple_select_list_txn(
                txn, "room_account_data", {"user_id": user_id, "room_id": room_id},
                ["account_data_type", "content"]
            )

            return {
                row["account_data_type"]: json.loads(row["content"]) for row in rows
            }

        return self.runInteraction(
            "get_account_data_for_room", get_account_data_for_room_txn
        )

    @cached(num_args=3, max_entries=5000)
    def get_account_data_for_room_and_type(self, user_id, room_id, account_data_type):
        """Get the client account_data of given type for a user for a room.

        Args:
            user_id(str): The user to get the account_data for.
            room_id(str): The room to get the account_data for.
            account_data_type (str): The account data type to get.
        Returns:
            A deferred of the room account_data for that type, or None if
            there isn't any set.
        """
        def get_account_data_for_room_and_type_txn(txn):
            content_json = self._simple_select_one_onecol_txn(
                txn,
                table="room_account_data",
                keyvalues={
                    "user_id": user_id,
                    "room_id": room_id,
                    "account_data_type": account_data_type,
                },
                retcol="content",
                allow_none=True
            )

            return json.loads(content_json) if content_json else None

        return self.runInteraction(
            "get_account_data_for_room_and_type",
            get_account_data_for_room_and_type_txn,
        )

    def get_all_updated_account_data(self, last_global_id, last_room_id,
                                     current_id, limit):
        """Get all the client account_data that has changed on the server
        Args:
            last_global_id(int): The position to fetch from for top level data
            last_room_id(int): The position to fetch from for per room data
            current_id(int): The position to fetch up to.
            limit(int): The maximum number of rows to fetch from each stream.
        Returns:
            A deferred pair of lists of tuples of stream_id int, user_id string,
            room_id string, type string, and content string.
        """
        if last_room_id == current_id and last_global_id == current_id:
            return defer.succeed(([], []))

        def get_updated_account_data_txn(txn):
            sql = (
                "SELECT stream_id, user_id, account_data_type, content"
                " FROM account_data WHERE ? < stream_id AND stream_id <= ?"
                " ORDER BY stream_id ASC LIMIT ?"
            )
            txn.execute(sql, (last_global_id, current_id, limit))
            global_results = txn.fetchall()

            sql = (
                "SELECT stream_id, user_id, room_id, account_data_type, content"
                " FROM room_account_data WHERE ? < stream_id AND stream_id <= ?"
                " ORDER BY stream_id ASC LIMIT ?"
            )
            txn.execute(sql, (last_room_id, current_id, limit))
            room_results = txn.fetchall()
            return (global_results, room_results)
        return self.runInteraction(
            "get_all_updated_account_data_txn", get_updated_account_data_txn
        )

    def get_updated_account_data_for_user(self, user_id, stream_id):
        """Get all the client account_data for a that's changed for a user

        Args:
            user_id(str): The user to get the account_data for.
            stream_id(int): The point in the stream since which to get updates
        Returns:
            A deferred pair of a dict of global account_data and a dict
            mapping from room_id string to per room account_data dicts.
        """

        def get_updated_account_data_for_user_txn(txn):
            sql = (
                "SELECT account_data_type, content FROM account_data"
                " WHERE user_id = ? AND stream_id > ?"
            )

            txn.execute(sql, (user_id, stream_id))

            global_account_data = {
                row[0]: json.loads(row[1]) for row in txn
            }

            sql = (
                "SELECT room_id, account_data_type, content FROM room_account_data"
                " WHERE user_id = ? AND stream_id > ?"
            )

            txn.execute(sql, (user_id, stream_id))

            account_data_by_room = {}
            for row in txn:
                room_account_data = account_data_by_room.setdefault(row[0], {})
                room_account_data[row[1]] = json.loads(row[2])

            return (global_account_data, account_data_by_room)

        changed = self._account_data_stream_cache.has_entity_changed(
            user_id, int(stream_id)
        )
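        # If the stream-change cache is certain nothing changed for this user
        # since `stream_id`, skip the database query entirely.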
        if not changed:
            return ({}, {})

        return self.runInteraction(
            "get_updated_account_data_for_user", get_updated_account_data_for_user_txn
        )

    @cachedInlineCallbacks(num_args=2, cache_context=True, max_entries=5000)
    def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context):
        ignored_account_data = yield self.get_global_account_data_by_type_for_user(
            "m.ignored_user_list", ignorer_user_id,
            on_invalidate=cache_context.invalidate,
        )
        if not ignored_account_data:
            defer.returnValue(False)

        defer.returnValue(
            ignored_user_id in ignored_account_data.get("ignored_users", {})
        )
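
A short, hypothetical caller illustrating the return shapes of the account data
methods above; the function name and summary keys are assumptions, while
"m.ignored_user_list" is the account data type already used by is_ignored_by.

from twisted.internet import defer

# Hypothetical helper, not part of the store above.
@defer.inlineCallbacks
def summarise_account_data(store, user_id):
    # (global account_data dict, room_id -> account_data dict)
    global_data, by_room = yield store.get_account_data_for_user(user_id)
    # Content dict for a single type, or None if the user has none set.
    ignored = yield store.get_global_account_data_by_type_for_user(
        "m.ignored_user_list", user_id,
    )
    defer.returnValue({
        "global_types": sorted(global_data),
        "rooms_with_account_data": len(by_room),
        "ignored_users": sorted((ignored or {}).get("ignored_users", {})),
    })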
Example No. 55
0
class PushRulesWorkerStore(
        ApplicationServiceWorkerStore,
        ReceiptsWorkerStore,
        PusherWorkerStore,
        RoomMemberWorkerStore,
        EventsWorkerStore,
        SQLBaseStore,
        metaclass=abc.ABCMeta,
):
    """This is an abstract base class where subclasses must implement
    `get_max_push_rules_stream_id` which can be called in the initializer.
    """
    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        super().__init__(database, db_conn, hs)

        if hs.config.worker.worker_app is None:
            self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
                db_conn, "push_rules_stream", "stream_id")
        else:
            self._push_rules_stream_id_gen = SlavedIdTracker(
                db_conn, "push_rules_stream", "stream_id")

        push_rules_prefill, push_rules_id = self.db_pool.get_cache_dict(
            db_conn,
            "push_rules_stream",
            entity_column="user_id",
            stream_column="stream_id",
            max_value=self.get_max_push_rules_stream_id(),
        )

        self.push_rules_stream_cache = StreamChangeCache(
            "PushRulesStreamChangeCache",
            push_rules_id,
            prefilled_cache=push_rules_prefill,
        )

    @abc.abstractmethod
    def get_max_push_rules_stream_id(self) -> int:
        """Get the position of the push rules stream.

        Returns:
            int
        """
        raise NotImplementedError()

    @cached(max_entries=5000)
    async def get_push_rules_for_user(self, user_id: str) -> List[JsonDict]:
        rows = await self.db_pool.simple_select_list(
            table="push_rules",
            keyvalues={"user_name": user_id},
            retcols=(
                "user_name",
                "rule_id",
                "priority_class",
                "priority",
                "conditions",
                "actions",
            ),
            desc="get_push_rules_for_user",
        )

        rows.sort(key=lambda row:
                  (-int(row["priority_class"]), -int(row["priority"])))

        enabled_map = await self.get_push_rules_enabled_for_user(user_id)

        return _load_rules(rows, enabled_map, self.hs.config.experimental)

    @cached(max_entries=5000)
    async def get_push_rules_enabled_for_user(self,
                                              user_id: str) -> Dict[str, bool]:
        results = await self.db_pool.simple_select_list(
            table="push_rules_enable",
            keyvalues={"user_name": user_id},
            retcols=("rule_id", "enabled"),
            desc="get_push_rules_enabled_for_user",
        )
        return {r["rule_id"]: bool(r["enabled"]) for r in results}

    async def have_push_rules_changed_for_user(self, user_id: str,
                                               last_id: int) -> bool:
        if not self.push_rules_stream_cache.has_entity_changed(
                user_id, last_id):
            return False
        else:

            def have_push_rules_changed_txn(txn: LoggingTransaction) -> bool:
                sql = ("SELECT COUNT(stream_id) FROM push_rules_stream"
                       " WHERE user_id = ? AND ? < stream_id")
                txn.execute(sql, (user_id, last_id))
                (count, ) = cast(Tuple[int], txn.fetchone())
                return bool(count)

            return await self.db_pool.runInteraction(
                "have_push_rules_changed", have_push_rules_changed_txn)

    @cachedList(cached_method_name="get_push_rules_for_user",
                list_name="user_ids")
    async def bulk_get_push_rules(
            self, user_ids: Collection[str]) -> Dict[str, List[JsonDict]]:
        if not user_ids:
            return {}

        results: Dict[str, List[JsonDict]] = {
            user_id: [] for user_id in user_ids
        }

        rows = await self.db_pool.simple_select_many_batch(
            table="push_rules",
            column="user_name",
            iterable=user_ids,
            retcols=("*", ),
            desc="bulk_get_push_rules",
        )

        rows.sort(key=lambda row:
                  (-int(row["priority_class"]), -int(row["priority"])))

        for row in rows:
            results.setdefault(row["user_name"], []).append(row)

        enabled_map_by_user = await self.bulk_get_push_rules_enabled(user_ids)

        for user_id, rules in results.items():
            results[user_id] = _load_rules(
                rules, enabled_map_by_user.get(user_id, {}),
                self.hs.config.experimental)

        return results

    @cachedList(cached_method_name="get_push_rules_enabled_for_user",
                list_name="user_ids")
    async def bulk_get_push_rules_enabled(
            self, user_ids: Collection[str]) -> Dict[str, Dict[str, bool]]:
        if not user_ids:
            return {}

        results: Dict[str, Dict[str, bool]] = {
            user_id: {} for user_id in user_ids
        }

        rows = await self.db_pool.simple_select_many_batch(
            table="push_rules_enable",
            column="user_name",
            iterable=user_ids,
            retcols=("user_name", "rule_id", "enabled"),
            desc="bulk_get_push_rules_enabled",
        )
        for row in rows:
            enabled = bool(row["enabled"])
            results.setdefault(row["user_name"], {})[row["rule_id"]] = enabled
        return results

    async def get_all_push_rule_updates(
            self, instance_name: str, last_id: int, current_id: int,
            limit: int) -> Tuple[List[Tuple[int, Tuple[str]]], int, bool]:
        """Get updates for push_rules replication stream.

        Args:
            instance_name: The writer we want to fetch updates from. Unused
                here since there is only ever one writer.
            last_id: The token to fetch updates from. Exclusive.
            current_id: The token to fetch updates up to. Inclusive.
            limit: The requested limit for the number of rows to return. The
                function may return more or fewer rows.

        Returns:
            A tuple consisting of: the updates, a token to use to fetch
            subsequent updates, and whether we returned fewer rows than exist
            between the requested tokens due to the limit.

            The token returned can be used in a subsequent call to this
            function to get further updates.

            The updates are a list of 2-tuples of stream ID and the row data
        """

        if last_id == current_id:
            return [], current_id, False

        def get_all_push_rule_updates_txn(
            txn: LoggingTransaction,
        ) -> Tuple[List[Tuple[int, Tuple[str]]], int, bool]:
            sql = """
                SELECT stream_id, user_id
                FROM push_rules_stream
                WHERE ? < stream_id AND stream_id <= ?
                ORDER BY stream_id ASC
                LIMIT ?
            """
            txn.execute(sql, (last_id, current_id, limit))
            updates = cast(
                List[Tuple[int, Tuple[str]]],
                [(stream_id, (user_id, )) for stream_id, user_id in txn],
            )

            limited = False
            upper_bound = current_id
            if len(updates) == limit:
                limited = True
                upper_bound = updates[-1][0]

            return updates, upper_bound, limited

        return await self.db_pool.runInteraction(
            "get_all_push_rule_updates", get_all_push_rule_updates_txn)
Example No. 56
0
class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
    """This is an abstract base class where subclasses must implement
    `get_room_max_stream_ordering` and `get_room_min_stream_ordering`
    which can be called in the initializer.
    """

    __metaclass__ = abc.ABCMeta

    def __init__(self, db_conn, hs):
        super(StreamWorkerStore, self).__init__(db_conn, hs)

        events_max = self.get_room_max_stream_ordering()
        event_cache_prefill, min_event_val = self._get_cache_dict(
            db_conn,
            "events",
            entity_column="room_id",
            stream_column="stream_ordering",
            max_value=events_max,
        )
        self._events_stream_cache = StreamChangeCache(
            "EventsRoomStreamChangeCache",
            min_event_val,
            prefilled_cache=event_cache_prefill,
        )
        self._membership_stream_cache = StreamChangeCache(
            "MembershipStreamChangeCache", events_max
        )

        self._stream_order_on_start = self.get_room_max_stream_ordering()

    @abc.abstractmethod
    def get_room_max_stream_ordering(self):
        raise NotImplementedError()

    @abc.abstractmethod
    def get_room_min_stream_ordering(self):
        raise NotImplementedError()

    @defer.inlineCallbacks
    def get_room_events_stream_for_rooms(
        self, room_ids, from_key, to_key, limit=0, order='DESC'
    ):
        """Get new room events in stream ordering since `from_key`.

        Args:
            room_ids (list[str])
            from_key (str): Token from which no events are returned before
            to_key (str): Token from which no events are returned after. (This
                is typically the current stream token)
            limit (int): Maximum number of events to return
            order (str): Either "DESC" or "ASC". Determines which events are
                returned when the result is limited. If "DESC" then the most
                recent `limit` events are returned, otherwise returns the
                oldest `limit` events.

        Returns:
            Deferred[dict[str,tuple[list[FrozenEvent], str]]]
                A map from room id to a tuple containing:
                    - list of recent events in the room
                    - stream ordering key for the start of the chunk of events returned.
        """
        from_id = RoomStreamToken.parse_stream_token(from_key).stream

        room_ids = yield self._events_stream_cache.get_entities_changed(
            room_ids, from_id
        )
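        # The stream-change cache filters out rooms it knows have had no events
        # since `from_id`, so only the remaining rooms are queried below.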

        if not room_ids:
            defer.returnValue({})

        results = {}
        room_ids = list(room_ids)
        for rm_ids in (room_ids[i : i + 20] for i in range(0, len(room_ids), 20)):
            res = yield make_deferred_yieldable(
                defer.gatherResults(
                    [
                        run_in_background(
                            self.get_room_events_stream_for_room,
                            room_id,
                            from_key,
                            to_key,
                            limit,
                            order=order,
                        )
                        for room_id in rm_ids
                    ],
                    consumeErrors=True,
                )
            )
            results.update(dict(zip(rm_ids, res)))

        defer.returnValue(results)

    def get_rooms_that_changed(self, room_ids, from_key):
        """Given a list of rooms and a token, return rooms where there may have
        been changes.

        Args:
            room_ids (list)
            from_key (str): The room_key portion of a StreamToken
        """
        from_key = RoomStreamToken.parse_stream_token(from_key).stream
        return set(
            room_id
            for room_id in room_ids
            if self._events_stream_cache.has_entity_changed(room_id, from_key)
        )

    @defer.inlineCallbacks
    def get_room_events_stream_for_room(
        self, room_id, from_key, to_key, limit=0, order='DESC'
    ):

        """Get new room events in stream ordering since `from_key`.

        Args:
            room_id (str)
            from_key (str): Token from which no events are returned before
            to_key (str): Token from which no events are returned after. (This
                is typically the current stream token)
            limit (int): Maximum number of events to return
            order (str): Either "DESC" or "ASC". Determines which events are
                returned when the result is limited. If "DESC" then the most
                recent `limit` events are returned, otherwise returns the
                oldest `limit` events.

        Returns:
            Deferred[tuple[list[FrozenEvent], str]]: Returns the list of
            events (in ascending order) and the token from the start of
            the chunk of events returned.
        """
        if from_key == to_key:
            defer.returnValue(([], from_key))

        from_id = RoomStreamToken.parse_stream_token(from_key).stream
        to_id = RoomStreamToken.parse_stream_token(to_key).stream

        has_changed = yield self._events_stream_cache.has_entity_changed(
            room_id, from_id
        )

        if not has_changed:
            defer.returnValue(([], from_key))

        def f(txn):
            sql = (
                "SELECT event_id, stream_ordering FROM events WHERE"
                " room_id = ?"
                " AND not outlier"
                " AND stream_ordering > ? AND stream_ordering <= ?"
                " ORDER BY stream_ordering %s LIMIT ?"
            ) % (order,)
            txn.execute(sql, (room_id, from_id, to_id, limit))

            rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
            return rows

        rows = yield self.runInteraction("get_room_events_stream_for_room", f)

        ret = yield self.get_events_as_list(
            [r.event_id for r in rows], get_prev_content=True,
        )

        self._set_before_and_after(ret, rows, topo_order=from_id is None)

        if order.lower() == "desc":
            ret.reverse()

        if rows:
            key = "s%d" % min(r.stream_ordering for r in rows)
        else:
            # Assume we didn't get anything because there was nothing to
            # get.
            key = from_key

        defer.returnValue((ret, key))

    @defer.inlineCallbacks
    def get_membership_changes_for_user(self, user_id, from_key, to_key):
        from_id = RoomStreamToken.parse_stream_token(from_key).stream
        to_id = RoomStreamToken.parse_stream_token(to_key).stream

        if from_key == to_key:
            defer.returnValue([])

        if from_id:
            has_changed = self._membership_stream_cache.has_entity_changed(
                user_id, int(from_id)
            )
            if not has_changed:
                defer.returnValue([])

        def f(txn):
            sql = (
                "SELECT m.event_id, stream_ordering FROM events AS e,"
                " room_memberships AS m"
                " WHERE e.event_id = m.event_id"
                " AND m.user_id = ?"
                " AND e.stream_ordering > ? AND e.stream_ordering <= ?"
                " ORDER BY e.stream_ordering ASC"
            )
            txn.execute(sql, (user_id, from_id, to_id))

            rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]

            return rows

        rows = yield self.runInteraction("get_membership_changes_for_user", f)

        ret = yield self.get_events_as_list(
            [r.event_id for r in rows], get_prev_content=True,
        )

        self._set_before_and_after(ret, rows, topo_order=False)

        defer.returnValue(ret)

    @defer.inlineCallbacks
    def get_recent_events_for_room(self, room_id, limit, end_token):
        """Get the most recent events in the room in topological ordering.

        Args:
            room_id (str)
            limit (int)
            end_token (str): The stream token representing now.

        Returns:
            Deferred[tuple[list[FrozenEvent], str]]: Returns a list of
            events and a token pointing to the start of the returned
            events.
            The events returned are in ascending order.
        """

        rows, token = yield self.get_recent_event_ids_for_room(
            room_id, limit, end_token
        )

        logger.debug("stream before")
        events = yield self.get_events_as_list(
            [r.event_id for r in rows], get_prev_content=True
        )
        logger.debug("stream after")

        self._set_before_and_after(events, rows)

        defer.returnValue((events, token))

    @defer.inlineCallbacks
    def get_recent_event_ids_for_room(self, room_id, limit, end_token):
        """Get the most recent events in the room in topological ordering.

        Args:
            room_id (str)
            limit (int)
            end_token (str): The stream token representing now.

        Returns:
            Deferred[tuple[list[_EventDictReturn], str]]: Returns a list of
            _EventDictReturn and a token pointing to the start of the returned
            events.
            The events returned are in ascending order.
        """
        # Allow a zero limit here, and no-op.
        if limit == 0:
            defer.returnValue(([], end_token))

        end_token = RoomStreamToken.parse(end_token)

        rows, token = yield self.runInteraction(
            "get_recent_event_ids_for_room",
            self._paginate_room_events_txn,
            room_id,
            from_token=end_token,
            limit=limit,
        )

        # We want to return the results in ascending order.
        rows.reverse()

        defer.returnValue((rows, token))

    def get_room_event_after_stream_ordering(self, room_id, stream_ordering):
        """Gets details of the first event in a room at or after a stream ordering

        Args:
            room_id (str):
            stream_ordering (int):

        Returns:
            Deferred[(int, int, str)]:
                (stream ordering, topological ordering, event_id)
        """

        def _f(txn):
            sql = (
                "SELECT stream_ordering, topological_ordering, event_id"
                " FROM events"
                " WHERE room_id = ? AND stream_ordering >= ?"
                " AND NOT outlier"
                " ORDER BY stream_ordering"
                " LIMIT 1"
            )
            txn.execute(sql, (room_id, stream_ordering))
            return txn.fetchone()

        return self.runInteraction("get_room_event_after_stream_ordering", _f)

    @defer.inlineCallbacks
    def get_room_events_max_id(self, room_id=None):
        """Returns the current token for rooms stream.

        By default, it returns the current global stream token. Specifying a
        `room_id` causes it to return the current room specific topological
        token.
        """
        token = yield self.get_room_max_stream_ordering()
        if room_id is None:
            defer.returnValue("s%d" % (token,))
        else:
            topo = yield self.runInteraction(
                "_get_max_topological_txn", self._get_max_topological_txn, room_id
            )
            defer.returnValue("t%d-%d" % (topo, token))

    def get_stream_token_for_event(self, event_id):
        """The stream token for an event
        Args:
            event_id(str): The id of the event to look up a stream token for.
        Raises:
            StoreError if the event wasn't in the database.
        Returns:
            A deferred "s%d" stream token.
        """
        return self._simple_select_one_onecol(
            table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering"
        ).addCallback(lambda row: "s%d" % (row,))

    def get_topological_token_for_event(self, event_id):
        """The stream token for an event
        Args:
            event_id(str): The id of the event to look up a stream token for.
        Raises:
            StoreError if the event wasn't in the database.
        Returns:
            A deferred "t%d-%d" topological token.
        """
        return self._simple_select_one(
            table="events",
            keyvalues={"event_id": event_id},
            retcols=("stream_ordering", "topological_ordering"),
            desc="get_topological_token_for_event",
        ).addCallback(
            lambda row: "t%d-%d" % (row["topological_ordering"], row["stream_ordering"])
        )

    def get_max_topological_token(self, room_id, stream_key):
        """Get the max topological token in a room before the given stream
        ordering.

        Args:
            room_id (str)
            stream_key (int)

        Returns:
            Deferred[int]
        """
        sql = (
            "SELECT coalesce(max(topological_ordering), 0) FROM events"
            " WHERE room_id = ? AND stream_ordering < ?"
        )
        return self._execute(
            "get_max_topological_token", None, sql, room_id, stream_key
        ).addCallback(lambda r: r[0][0] if r else 0)

    def _get_max_topological_txn(self, txn, room_id):
        txn.execute(
            "SELECT MAX(topological_ordering) FROM events" " WHERE room_id = ?",
            (room_id,),
        )

        rows = txn.fetchall()
        return rows[0][0] if rows else 0

    @staticmethod
    def _set_before_and_after(events, rows, topo_order=True):
        """Inserts ordering information to events' internal metadata from
        the DB rows.

        Args:
            events (list[FrozenEvent])
            rows (list[_EventDictReturn])
            topo_order (bool): Whether the events were ordered topologically
                or by stream ordering. If true then all rows should have a non
                null topological_ordering.
        """
        for event, row in zip(events, rows):
            stream = row.stream_ordering
            if topo_order and row.topological_ordering:
                topo = row.topological_ordering
            else:
                topo = None
            internal = event.internal_metadata
            internal.before = str(RoomStreamToken(topo, stream - 1))
            internal.after = str(RoomStreamToken(topo, stream))
            internal.order = (int(topo) if topo else 0, int(stream))

    @defer.inlineCallbacks
    def get_events_around(
        self, room_id, event_id, before_limit, after_limit, event_filter=None
    ):
        """Retrieve events and pagination tokens around a given event in a
        room.

        Args:
            room_id (str)
            event_id (str)
            before_limit (int)
            after_limit (int)
            event_filter (Filter|None)

        Returns:
            dict
        """

        results = yield self.runInteraction(
            "get_events_around",
            self._get_events_around_txn,
            room_id,
            event_id,
            before_limit,
            after_limit,
            event_filter,
        )

        events_before = yield self.get_events_as_list(
            [e for e in results["before"]["event_ids"]], get_prev_content=True
        )

        events_after = yield self.get_events_as_list(
            [e for e in results["after"]["event_ids"]], get_prev_content=True
        )

        defer.returnValue(
            {
                "events_before": events_before,
                "events_after": events_after,
                "start": results["before"]["token"],
                "end": results["after"]["token"],
            }
        )

    def _get_events_around_txn(
        self, txn, room_id, event_id, before_limit, after_limit, event_filter
    ):
        """Retrieves event_ids and pagination tokens around a given event in a
        room.

        Args:
            room_id (str)
            event_id (str)
            before_limit (int)
            after_limit (int)
            event_filter (Filter|None)

        Returns:
            dict
        """

        results = self._simple_select_one_txn(
            txn,
            "events",
            keyvalues={"event_id": event_id, "room_id": room_id},
            retcols=["stream_ordering", "topological_ordering"],
        )

        # Paginating backwards includes the event at the token, but paginating
        # forward doesn't.
        before_token = RoomStreamToken(
            results["topological_ordering"] - 1, results["stream_ordering"]
        )

        after_token = RoomStreamToken(
            results["topological_ordering"], results["stream_ordering"]
        )

        rows, start_token = self._paginate_room_events_txn(
            txn,
            room_id,
            before_token,
            direction='b',
            limit=before_limit,
            event_filter=event_filter,
        )
        events_before = [r.event_id for r in rows]

        rows, end_token = self._paginate_room_events_txn(
            txn,
            room_id,
            after_token,
            direction='f',
            limit=after_limit,
            event_filter=event_filter,
        )
        events_after = [r.event_id for r in rows]

        return {
            "before": {"event_ids": events_before, "token": start_token},
            "after": {"event_ids": events_after, "token": end_token},
        }

    @defer.inlineCallbacks
    def get_all_new_events_stream(self, from_id, current_id, limit):
        """Get all new events

         Returns all events with from_id < stream_ordering <= current_id.

         Args:
             from_id (int):  the stream_ordering of the last event we processed
             current_id (int):  the stream_ordering of the most recently processed event
             limit (int): the maximum number of events to return

         Returns:
             Deferred[Tuple[int, list[FrozenEvent]]]: A tuple of (next_id, events), where
             `next_id` is the next value to pass as `from_id` (it will either be the
             stream_ordering of the last returned event, or, if fewer than `limit` events
             were found, `current_id`).
         """

        def get_all_new_events_stream_txn(txn):
            sql = (
                "SELECT e.stream_ordering, e.event_id"
                " FROM events AS e"
                " WHERE"
                " ? < e.stream_ordering AND e.stream_ordering <= ?"
                " ORDER BY e.stream_ordering ASC"
                " LIMIT ?"
            )

            txn.execute(sql, (from_id, current_id, limit))
            rows = txn.fetchall()

            upper_bound = current_id
            if len(rows) == limit:
                upper_bound = rows[-1][0]

            return upper_bound, [row[1] for row in rows]

        upper_bound, event_ids = yield self.runInteraction(
            "get_all_new_events_stream", get_all_new_events_stream_txn
        )

        events = yield self.get_events_as_list(event_ids)

        defer.returnValue((upper_bound, events))

    def get_federation_out_pos(self, typ):
        return self._simple_select_one_onecol(
            table="federation_stream_position",
            retcol="stream_id",
            keyvalues={"type": typ},
            desc="get_federation_out_pos",
        )

    def update_federation_out_pos(self, typ, stream_id):
        return self._simple_update_one(
            table="federation_stream_position",
            keyvalues={"type": typ},
            updatevalues={"stream_id": stream_id},
            desc="update_federation_out_pos",
        )

    def has_room_changed_since(self, room_id, stream_id):
        return self._events_stream_cache.has_entity_changed(room_id, stream_id)

    def _paginate_room_events_txn(
        self,
        txn,
        room_id,
        from_token,
        to_token=None,
        direction='b',
        limit=-1,
        event_filter=None,
    ):
        """Returns list of events before or after a given token.

        Args:
            txn
            room_id (str)
            from_token (RoomStreamToken): The token used to stream from
            to_token (RoomStreamToken|None): A token which if given limits the
                results to only those before
            direction(char): Either 'b' or 'f' to indicate whether we are
                paginating forwards or backwards from `from_token`.
            limit (int): The maximum number of events to return.
            event_filter (Filter|None): If provided filters the events to
                those that match the filter.

        Returns:
            tuple[list[_EventDictReturn], str]: Returns the results
            as a list of _EventDictReturn and a token that points to the end
            of the result set.
        """

        assert int(limit) >= 0

        # Tokens really represent positions between elements, but we use
        # the convention of pointing to the event before the gap. Hence
        # we have a bit of asymmetry when it comes to equalities.
        args = [False, room_id]
        if direction == 'b':
            order = "DESC"
        else:
            order = "ASC"

        bounds = generate_pagination_where_clause(
            direction=direction,
            column_names=("topological_ordering", "stream_ordering"),
            from_token=from_token,
            to_token=to_token,
            engine=self.database_engine,
        )

        filter_clause, filter_args = filter_to_clause(event_filter)

        if filter_clause:
            bounds += " AND " + filter_clause
            args.extend(filter_args)

        args.append(int(limit))

        sql = (
            "SELECT event_id, topological_ordering, stream_ordering"
            " FROM events"
            " WHERE outlier = ? AND room_id = ? AND %(bounds)s"
            " ORDER BY topological_ordering %(order)s,"
            " stream_ordering %(order)s LIMIT ?"
        ) % {"bounds": bounds, "order": order}

        txn.execute(sql, args)

        rows = [_EventDictReturn(row[0], row[1], row[2]) for row in txn]

        if rows:
            topo = rows[-1].topological_ordering
            toke = rows[-1].stream_ordering
            if direction == 'b':
                # Tokens are positions between events.
                # This token points *after* the last event in the chunk.
                # We need it to point to the event before it in the chunk
                # when we are going backwards so we subtract one from the
                # stream part.
                toke -= 1
            next_token = RoomStreamToken(topo, toke)
        else:
            # TODO (erikj): We should work out what to do here instead.
            next_token = to_token if to_token else from_token

        return rows, str(next_token)

    @defer.inlineCallbacks
    def paginate_room_events(
        self, room_id, from_key, to_key=None, direction='b', limit=-1, event_filter=None
    ):
        """Returns list of events before or after a given token.

        Args:
            room_id (str)
            from_key (str): The token used to stream from
            to_key (str|None): A token which if given limits the results to
                only those before
            direction(char): Either 'b' or 'f' to indicate whether we are
                paginating forwards or backwards from `from_key`.
            limit (int): The maximum number of events to return. Zero or less
                means no limit.
            event_filter (Filter|None): If provided filters the events to
                those that match the filter.

        Returns:
            Deferred[tuple[list[FrozenEvent], str]]: Returns the results as a
            list of events and a token that points to the end of the result
            set.
        """

        from_key = RoomStreamToken.parse(from_key)
        if to_key:
            to_key = RoomStreamToken.parse(to_key)

        rows, token = yield self.runInteraction(
            "paginate_room_events",
            self._paginate_room_events_txn,
            room_id,
            from_key,
            to_key,
            direction,
            limit,
            event_filter,
        )

        events = yield self.get_events_as_list(
            [r.event_id for r in rows], get_prev_content=True
        )

        self._set_before_and_after(events, rows)

        defer.returnValue((events, token))
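
Illustrative only: a hypothetical helper built on the pagination methods above.
Stream tokens have the form "s<stream_ordering>" and topological tokens the
form "t<topological_ordering>-<stream_ordering>", both of which
RoomStreamToken.parse accepts.

from twisted.internet import defer

# Hypothetical helper, not part of the store above.
@defer.inlineCallbacks
def get_last_page_of_room(store, room_id, page_size=50):
    # Topological "t<topo>-<stream>" token for the room's current position.
    from_key = yield store.get_room_events_max_id(room_id)
    # Walk backwards from that position; returns (events, next_token).
    events, next_key = yield store.paginate_room_events(
        room_id, from_key, direction='b', limit=page_size,
    )
    defer.returnValue((events, next_key))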
Example No. 57
0
class TypingWriterHandler(FollowerTypingHandler):
    def __init__(self, hs: "HomeServer"):
        super().__init__(hs)

        assert hs.config.worker.writers.typing == hs.get_instance_name()

        self.auth = hs.get_auth()
        self.notifier = hs.get_notifier()

        self.hs = hs

        hs.get_federation_registry().register_edu_handler(
            "m.typing", self._recv_edu)

        hs.get_distributor().observe("user_left_room", self.user_left_room)

        # clock time we expect to stop
        self._member_typing_until = {}  # type: Dict[RoomMember, int]

        # caches which room_ids changed at which serials
        self._typing_stream_change_cache = StreamChangeCache(
            "TypingStreamChangeCache", self._latest_room_serial)

    def _handle_timeout_for_member(self, now: int, member: RoomMember) -> None:
        super()._handle_timeout_for_member(now, member)

        if not self.is_typing(member):
            # Nothing to do if they're no longer typing
            return

        until = self._member_typing_until.get(member, None)
        if not until or until <= now:
            logger.info("Timing out typing for: %s", member.user_id)
            self._stopped_typing(member)
            return

    async def started_typing(self, target_user: UserID, requester: Requester,
                             room_id: str, timeout: int) -> None:
        target_user_id = target_user.to_string()
        auth_user_id = requester.user.to_string()

        if not self.is_mine_id(target_user_id):
            raise SynapseError(400, "User is not hosted on this homeserver")

        if target_user_id != auth_user_id:
            raise AuthError(400, "Cannot set another user's typing state")

        if requester.shadow_banned:
            # We randomly sleep a bit just to annoy the requester.
            await self.clock.sleep(random.randint(1, 10))
            raise ShadowBanError()

        await self.auth.check_user_in_room(room_id, target_user_id)

        logger.debug("%s has started typing in %s", target_user_id, room_id)

        member = RoomMember(room_id=room_id, user_id=target_user_id)

        was_present = member.user_id in self._room_typing.get(room_id, set())

        now = self.clock.time_msec()
        self._member_typing_until[member] = now + timeout

        self.wheel_timer.insert(now=now, obj=member, then=now + timeout)

        if was_present:
            # No point sending another notification
            return

        self._push_update(member=member, typing=True)

    async def stopped_typing(self, target_user: UserID, requester: Requester,
                             room_id: str) -> None:
        target_user_id = target_user.to_string()
        auth_user_id = requester.user.to_string()

        if not self.is_mine_id(target_user_id):
            raise SynapseError(400, "User is not hosted on this homeserver")

        if target_user_id != auth_user_id:
            raise AuthError(400, "Cannot set another user's typing state")

        if requester.shadow_banned:
            # We randomly sleep a bit just to annoy the requester.
            await self.clock.sleep(random.randint(1, 10))
            raise ShadowBanError()

        await self.auth.check_user_in_room(room_id, target_user_id)

        logger.debug("%s has stopped typing in %s", target_user_id, room_id)

        member = RoomMember(room_id=room_id, user_id=target_user_id)

        self._stopped_typing(member)

    def user_left_room(self, user: UserID, room_id: str) -> None:
        user_id = user.to_string()
        if self.is_mine_id(user_id):
            member = RoomMember(room_id=room_id, user_id=user_id)
            self._stopped_typing(member)

    def _stopped_typing(self, member: RoomMember) -> None:
        if member.user_id not in self._room_typing.get(member.room_id, set()):
            # No point
            return

        self._member_typing_until.pop(member, None)
        self._member_last_federation_poke.pop(member, None)

        self._push_update(member=member, typing=False)

    def _push_update(self, member: RoomMember, typing: bool) -> None:
        if self.hs.is_mine_id(member.user_id):
            # Only send updates for changes to our own users.
            run_as_background_process("typing._push_remote", self._push_remote,
                                      member, typing)

        self._push_update_local(member=member, typing=typing)

    async def _recv_edu(self, origin: str, content: JsonDict) -> None:
        room_id = content["room_id"]
        user_id = content["user_id"]

        member = RoomMember(user_id=user_id, room_id=room_id)

        # Check that the string is a valid user id
        user = UserID.from_string(user_id)

        if user.domain != origin:
            logger.info("Got typing update from %r with bad 'user_id': %r",
                        origin, user_id)
            return

        users = await self.store.get_users_in_room(room_id)
        domains = {get_domain_from_id(u) for u in users}

        if self.server_name in domains:
            logger.info("Got typing update from %s: %r", user_id, content)
            now = self.clock.time_msec()
            self._member_typing_until[member] = now + FEDERATION_TIMEOUT
            self.wheel_timer.insert(now=now,
                                    obj=member,
                                    then=now + FEDERATION_TIMEOUT)
            self._push_update_local(member=member, typing=content["typing"])

    def _push_update_local(self, member: RoomMember, typing: bool) -> None:
        room_set = self._room_typing.setdefault(member.room_id, set())
        if typing:
            room_set.add(member.user_id)
        else:
            room_set.discard(member.user_id)

        self._latest_room_serial += 1
        self._room_serials[member.room_id] = self._latest_room_serial
        self._typing_stream_change_cache.entity_has_changed(
            member.room_id, self._latest_room_serial)

        self.notifier.on_new_event("typing_key",
                                   self._latest_room_serial,
                                   rooms=[member.room_id])

    async def get_all_typing_updates(
            self, instance_name: str, last_id: int, current_id: int,
            limit: int) -> Tuple[List[Tuple[int, list]], int, bool]:
        """Get updates for typing replication stream.

        Args:
            instance_name: The writer we want to fetch updates from. Unused
                here since there is only ever one writer.
            last_id: The token to fetch updates from. Exclusive.
            current_id: The token to fetch updates up to. Inclusive.
            limit: The requested limit for the number of rows to return. The
                function may return more or fewer rows.

        Returns:
            A tuple consisting of: the updates, a token to use to fetch
            subsequent updates, and whether we returned fewer rows than exist
            between the requested tokens due to the limit.

            The token returned can be used in a subsequent call to this
            function to get further updates.

            The updates are a list of 2-tuples of stream ID and the row data
        """

        if last_id == current_id:
            return [], current_id, False

        changed_rooms = self._typing_stream_change_cache.get_all_entities_changed(
            last_id)  # type: Optional[Iterable[str]]

        if changed_rooms is None:
            changed_rooms = self._room_serials

        rows = []
        for room_id in changed_rooms:
            serial = self._room_serials[room_id]
            if last_id < serial <= current_id:
                typing = self._room_typing[room_id]
                rows.append((serial, [room_id, list(typing)]))
        rows.sort()

        limited = False
        # We, unusually, use a strict limit here as we have all the rows in
        # memory rather than pulling them out of the database with a `LIMIT ?`
        # clause.
        if len(rows) > limit:
            rows = rows[:limit]
            current_id = rows[-1][0]
            limited = True

        return rows, current_id, limited

    def process_replication_rows(
            self, token: int,
            rows: List[TypingStream.TypingStreamRow]) -> None:
        # The writing process should never get updates from replication.
        raise Exception(
            "Typing writer instance got typing info over replication")
Example No. 58
0
class PushRulesWorkerStore(
    ApplicationServiceWorkerStore,
    ReceiptsWorkerStore,
    PusherWorkerStore,
    RoomMemberWorkerStore,
    SQLBaseStore,
):
    """This is an abstract base class where subclasses must implement
    `get_max_push_rules_stream_id` which can be called in the initializer.
    """

    # This ABCMeta metaclass ensures that we cannot be instantiated without
    # the abstract methods being implemented.
    __metaclass__ = abc.ABCMeta

    def __init__(self, db_conn, hs):
        super(PushRulesWorkerStore, self).__init__(db_conn, hs)

        push_rules_prefill, push_rules_id = self._get_cache_dict(
            db_conn,
            "push_rules_stream",
            entity_column="user_id",
            stream_column="stream_id",
            max_value=self.get_max_push_rules_stream_id(),
        )

        self.push_rules_stream_cache = StreamChangeCache(
            "PushRulesStreamChangeCache",
            push_rules_id,
            prefilled_cache=push_rules_prefill,
        )

    @abc.abstractmethod
    def get_max_push_rules_stream_id(self):
        """Get the position of the push rules stream.

        Returns:
            int
        """
        raise NotImplementedError()

    @cachedInlineCallbacks(max_entries=5000)
    def get_push_rules_for_user(self, user_id):
        rows = yield self._simple_select_list(
            table="push_rules",
            keyvalues={"user_name": user_id},
            retcols=(
                "user_name",
                "rule_id",
                "priority_class",
                "priority",
                "conditions",
                "actions",
            ),
            desc="get_push_rules_enabled_for_user",
        )

        rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))

        enabled_map = yield self.get_push_rules_enabled_for_user(user_id)

        rules = _load_rules(rows, enabled_map)

        defer.returnValue(rules)

    @cachedInlineCallbacks(max_entries=5000)
    def get_push_rules_enabled_for_user(self, user_id):
        results = yield self._simple_select_list(
            table="push_rules_enable",
            keyvalues={'user_name': user_id},
            retcols=("user_name", "rule_id", "enabled"),
            desc="get_push_rules_enabled_for_user",
        )
        defer.returnValue(
            {r['rule_id']: False if r['enabled'] == 0 else True for r in results}
        )

    def have_push_rules_changed_for_user(self, user_id, last_id):
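        # Cheap check against the in-memory stream change cache first; only
        # hit the database if the cache says the user may have changed.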
        if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
            return defer.succeed(False)
        else:

            def have_push_rules_changed_txn(txn):
                sql = (
                    "SELECT COUNT(stream_id) FROM push_rules_stream"
                    " WHERE user_id = ? AND ? < stream_id"
                )
                txn.execute(sql, (user_id, last_id))
                count, = txn.fetchone()
                return bool(count)

            return self.runInteraction(
                "have_push_rules_changed", have_push_rules_changed_txn
            )

    @cachedList(
        cached_method_name="get_push_rules_for_user",
        list_name="user_ids",
        num_args=1,
        inlineCallbacks=True,
    )
    def bulk_get_push_rules(self, user_ids):
        if not user_ids:
            defer.returnValue({})

        results = {user_id: [] for user_id in user_ids}

        rows = yield self._simple_select_many_batch(
            table="push_rules",
            column="user_name",
            iterable=user_ids,
            retcols=("*",),
            desc="bulk_get_push_rules",
        )

        rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))

        for row in rows:
            results.setdefault(row['user_name'], []).append(row)

        enabled_map_by_user = yield self.bulk_get_push_rules_enabled(user_ids)

        for user_id, rules in results.items():
            results[user_id] = _load_rules(rules, enabled_map_by_user.get(user_id, {}))

        defer.returnValue(results)

    @defer.inlineCallbacks
    def move_push_rule_from_room_to_room(self, new_room_id, user_id, rule):
        """Move a single push rule from one room to another for a specific user.

        Args:
            new_room_id (str): ID of the new room.
            user_id (str): ID of user the push rule belongs to.
            rule (Dict): A push rule.
        """
        # Create new rule id
        rule_id_scope = '/'.join(rule["rule_id"].split('/')[:-1])
        new_rule_id = rule_id_scope + "/" + new_room_id
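        # For example (hypothetical IDs), a rule_id of
        # "global/override/!old:example.com" with a new_room_id of
        # "!new:example.com" becomes "global/override/!new:example.com".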

        # Change room id in each condition
        for condition in rule.get("conditions", []):
            if condition.get("key") == "room_id":
                condition["pattern"] = new_room_id

        # Add the rule for the new room
        yield self.add_push_rule(
            user_id=user_id,
            rule_id=new_rule_id,
            priority_class=rule["priority_class"],
            conditions=rule["conditions"],
            actions=rule["actions"],
        )

        # Delete push rule for the old room
        yield self.delete_push_rule(user_id, rule["rule_id"])

    @defer.inlineCallbacks
    def move_push_rules_from_room_to_room_for_user(
        self, old_room_id, new_room_id, user_id
    ):
        """Move all of the push rules from one room to another for a specific
        user.

        Args:
            old_room_id (str): ID of the old room.
            new_room_id (str): ID of the new room.
            user_id (str): ID of user to copy push rules for.
        """
        # Retrieve push rules for this user
        user_push_rules = yield self.get_push_rules_for_user(user_id)

        # Get rules relating to the old room, move them to the new room, then
        # delete them from the old room
        for rule in user_push_rules:
            conditions = rule.get("conditions", [])
            if any(
                (c.get("key") == "room_id" and c.get("pattern") == old_room_id)
                for c in conditions
            ):
                yield self.move_push_rule_from_room_to_room(new_room_id, user_id, rule)

    @defer.inlineCallbacks
    def bulk_get_push_rules_for_room(self, event, context):
        state_group = context.state_group
        if not state_group:
            # If state_group is None it means it has yet to be assigned a
            # state group, i.e. we need to make sure that calls with a state_group
            # of None don't hit previous cached calls with a None state_group.
            # To do this we set the state_group to a new object as object() != object()
            state_group = object()
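            # Each call to object() returns a fresh instance that compares
            # unequal to every other object, so events without an assigned
            # state group can never collide in the cache.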

        current_state_ids = yield context.get_current_state_ids(self)
        result = yield self._bulk_get_push_rules_for_room(
            event.room_id, state_group, current_state_ids, event=event
        )
        defer.returnValue(result)

    @cachedInlineCallbacks(num_args=2, cache_context=True)
    def _bulk_get_push_rules_for_room(
        self, room_id, state_group, current_state_ids, cache_context, event=None
    ):
        # We don't use `state_group`; it's there so that we can cache based
        # on it. However, it's important that it's never None, since two
        # current_state_ids with a state_group of None are likely to be
        # different.
        # See bulk_get_push_rules_for_room for how we work around this.
        assert state_group is not None

        # We also want to generate notifications for other people in the room
        # so that their unread counts are correct in the event stream, but to
        # avoid generating them for bot / AS users etc, we only do so for
        # people who've sent a read receipt into the room.

        users_in_room = yield self._get_joined_users_from_context(
            room_id,
            state_group,
            current_state_ids,
            on_invalidate=cache_context.invalidate,
            event=event,
        )

        # We ignore app service users for now. This is so that we don't fill
        # up the `get_if_users_have_pushers` cache with AS entries that we
        # know don't have pushers, nor even read receipts.
        local_users_in_room = set(
            u
            for u in users_in_room
            if self.hs.is_mine_id(u)
            and not self.get_if_app_services_interested_in_user(u)
        )

        # users in the room who have pushers need to get push rules run because
        # that's how their pushers work
        if_users_with_pushers = yield self.get_if_users_have_pushers(
            local_users_in_room, on_invalidate=cache_context.invalidate
        )
        user_ids = set(
            uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
        )

        users_with_receipts = yield self.get_users_with_read_receipts_in_room(
            room_id, on_invalidate=cache_context.invalidate
        )

        # also run push rules for local users who have sent a read receipt
        # into the room
        for uid in users_with_receipts:
            if uid in local_users_in_room:
                user_ids.add(uid)

        rules_by_user = yield self.bulk_get_push_rules(
            user_ids, on_invalidate=cache_context.invalidate
        )

        rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None}

        defer.returnValue(rules_by_user)

    @cachedList(
        cached_method_name="get_push_rules_enabled_for_user",
        list_name="user_ids",
        num_args=1,
        inlineCallbacks=True,
    )
    def bulk_get_push_rules_enabled(self, user_ids):
        if not user_ids:
            defer.returnValue({})

        results = {user_id: {} for user_id in user_ids}

        rows = yield self._simple_select_many_batch(
            table="push_rules_enable",
            column="user_name",
            iterable=user_ids,
            retcols=("user_name", "rule_id", "enabled"),
            desc="bulk_get_push_rules_enabled",
        )
        for row in rows:
            enabled = bool(row['enabled'])
            results.setdefault(row['user_name'], {})[row['rule_id']] = enabled
        defer.returnValue(results)
Ejemplo n.º 59
0
    def __init__(self, db_conn, hs):
        super(ReceiptsWorkerStore, self).__init__(db_conn, hs)

        self._receipts_stream_cache = StreamChangeCache(
            "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id()
        )
Ejemplo n.º 60
0
class SlavedAccountDataStore(BaseSlavedStore):

    def __init__(self, db_conn, hs):
        super(SlavedAccountDataStore, self).__init__(db_conn, hs)
        self._account_data_id_gen = SlavedIdTracker(
            db_conn, "account_data_max_stream_id", "stream_id",
        )
        self._account_data_stream_cache = StreamChangeCache(
            "AccountDataAndTagsChangeCache",
            self._account_data_id_gen.get_current_token(),
        )

    get_account_data_for_user = (
        AccountDataStore.__dict__["get_account_data_for_user"]
    )

    get_global_account_data_by_type_for_users = (
        AccountDataStore.__dict__["get_global_account_data_by_type_for_users"]
    )

    get_global_account_data_by_type_for_user = (
        AccountDataStore.__dict__["get_global_account_data_by_type_for_user"]
    )

    get_tags_for_user = TagsStore.__dict__["get_tags_for_user"]

    get_updated_tags = DataStore.get_updated_tags.__func__
    get_updated_account_data_for_user = (
        DataStore.get_updated_account_data_for_user.__func__
    )

    def get_max_account_data_stream_id(self):
        return self._account_data_id_gen.get_current_token()

    def stream_positions(self):
        result = super(SlavedAccountDataStore, self).stream_positions()
        position = self._account_data_id_gen.get_current_token()
        result["user_account_data"] = position
        result["room_account_data"] = position
        result["tag_account_data"] = position
        return result

    def process_replication(self, result):
        stream = result.get("user_account_data")
        if stream:
            self._account_data_id_gen.advance(int(stream["position"]))
            for row in stream["rows"]:
                position, user_id, data_type = row[:3]
                self.get_global_account_data_by_type_for_user.invalidate(
                    (data_type, user_id,)
                )
                self.get_account_data_for_user.invalidate((user_id,))
                self._account_data_stream_cache.entity_has_changed(
                    user_id, position
                )

        stream = result.get("room_account_data")
        if stream:
            self._account_data_id_gen.advance(int(stream["position"]))
            for row in stream["rows"]:
                position, user_id = row[:2]
                self.get_account_data_for_user.invalidate((user_id,))
                self._account_data_stream_cache.entity_has_changed(
                    user_id, position
                )

        stream = result.get("tag_account_data")
        if stream:
            self._account_data_id_gen.advance(int(stream["position"]))
            for row in stream["rows"]:
                position, user_id = row[:2]
                self.get_tags_for_user.invalidate((user_id,))
                self._account_data_stream_cache.entity_has_changed(
                    user_id, position
                )

        return super(SlavedAccountDataStore, self).process_replication(result)