Example #1
    async def persist_event(
        self,
        event: EventBase,
        context: EventContext,
        backfilled: bool = False
    ) -> Tuple[EventBase, PersistedEventPosition, RoomStreamToken]:
        """
        Returns:
            The event, stream ordering of `event`, and the stream ordering of the
            latest persisted event. The returned event may not match the given
            event if it was deduplicated due to an existing event matching the
            transaction ID.
        """
        # add_to_queue returns a map from event ID to existing event ID if the
        # event was deduplicated. (The dict may also include other entries if
        # the event was persisted in a batch with other events.)
        replaced_events = await self._event_persist_queue.add_to_queue(
            event.room_id, [(event, context)], backfilled=backfilled)
        replaced_event = replaced_events.get(event.event_id)
        if replaced_event:
            event = await self.main_store.get_event(replaced_event)

        event_stream_id = event.internal_metadata.stream_ordering
        # stream ordering should have been assigned by now
        assert event_stream_id

        pos = PersistedEventPosition(self._instance_name, event_stream_id)
        return event, pos, self.main_store.get_room_max_token()
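
For context, here is a minimal, self-contained sketch of the value these helpers build. The PersistedEventPosition shape follows the upstream Synapse definition, but RoomStreamTokenStub is this sketch's own simplified stand-in for RoomStreamToken, and method names may differ between Synapse versions:

    import attr
    from typing import Dict

    @attr.s(slots=True, frozen=True, auto_attribs=True)
    class RoomStreamTokenStub:
        # Simplified stand-in: a global stream position plus optional
        # per-writer positions.
        stream: int
        instance_map: Dict[str, int] = attr.Factory(dict)

        def get_stream_pos_for_instance(self, instance_name: str) -> int:
            return self.instance_map.get(instance_name, self.stream)

    @attr.s(slots=True, frozen=True, auto_attribs=True)
    class PersistedEventPosition:
        instance_name: str
        stream: int

        def persisted_after(self, token: RoomStreamTokenStub) -> bool:
            # The event is newer than the token if the token's position for
            # the instance that persisted it is below our stream ordering.
            return token.get_stream_pos_for_instance(self.instance_name) < self.stream

    pos = PersistedEventPosition("master", 42)
    assert pos.persisted_after(RoomStreamTokenStub(stream=41))
    assert not pos.persisted_after(RoomStreamTokenStub(stream=42))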
Example #2
    async def get_position_for_event(self,
                                     event_id: str) -> PersistedEventPosition:
        """Get the persisted position for an event"""
        row = await self.db_pool.simple_select_one(
            table="events",
            keyvalues={"event_id": event_id},
            retcols=("stream_ordering", "instance_name"),
            desc="get_position_for_event",
        )

        return PersistedEventPosition(row["instance_name"] or "master",
                                      row["stream_ordering"])
Example #3
    def test_get_rooms_for_user_with_stream_ordering(self):
        """Check that the cache on get_rooms_for_user_with_stream_ordering is invalidated
        by rows in the events stream
        """
        self.persist(type="m.room.create", key="", creator=USER_ID)
        self.persist(type="m.room.member", key=USER_ID, membership="join")
        self.replicate()
        self.check("get_rooms_for_user_with_stream_ordering", (USER_ID_2,), set())

        j2 = self.persist(
            type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join"
        )
        self.replicate()

        expected_pos = PersistedEventPosition(
            "master", j2.internal_metadata.stream_ordering
        )
        self.check(
            "get_rooms_for_user_with_stream_ordering",
            (USER_ID_2,),
            {GetRoomsForUserWithStreamOrdering(ROOM_ID, expected_pos)},
        )
Example #4
    async def persist_event(
        self,
        event: EventBase,
        context: EventContext,
        backfilled: bool = False
    ) -> Tuple[PersistedEventPosition, RoomStreamToken]:
        """
        Returns:
            The stream ordering of `event`, and the stream ordering of the
            latest persisted event
        """
        deferred = self._event_persist_queue.add_to_queue(
            event.room_id, [(event, context)], backfilled=backfilled)

        self._maybe_start_persisting(event.room_id)

        await make_deferred_yieldable(deferred)

        event_stream_id = event.internal_metadata.stream_ordering
        # stream ordering should have been assigned by now
        assert event_stream_id

        pos = PersistedEventPosition(self._instance_name, event_stream_id)
        return pos, self.main_store.get_room_max_token()
Example #5
    async def on_rdata(self, stream_name: str, instance_name: str, token: int,
                       rows: list):
        """Called to handle a batch of replication data with a given stream token.

        By default this just pokes the slave store. Can be overridden in subclasses to
        handle more.

        Args:
            stream_name: name of the replication stream for this batch of rows
            instance_name: the instance that wrote the rows.
            token: stream token for this batch of rows
            rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row.
        """
        self.store.process_replication_rows(stream_name, instance_name, token,
                                            rows)

        if self.send_handler:
            await self.send_handler.process_replication_rows(
                stream_name, token, rows)

        if stream_name == TypingStream.NAME:
            self._typing_handler.process_replication_rows(token, rows)
            self.notifier.on_new_event("typing_key",
                                       token,
                                       rooms=[row.room_id for row in rows])
        elif stream_name == PushRulesStream.NAME:
            self.notifier.on_new_event("push_rules_key",
                                       token,
                                       users=[row.user_id for row in rows])
        elif stream_name in (AccountDataStream.NAME,
                             TagAccountDataStream.NAME):
            self.notifier.on_new_event("account_data_key",
                                       token,
                                       users=[row.user_id for row in rows])
        elif stream_name == ReceiptsStream.NAME:
            self.notifier.on_new_event("receipt_key",
                                       token,
                                       rooms=[row.room_id for row in rows])
            await self._pusher_pool.on_new_receipts(
                token, token, {row.room_id for row in rows})
        elif stream_name == ToDeviceStream.NAME:
            entities = [
                row.entity for row in rows if row.entity.startswith("@")
            ]
            if entities:
                self.notifier.on_new_event("to_device_key",
                                           token,
                                           users=entities)
        elif stream_name == DeviceListsStream.NAME:
            all_room_ids: Set[str] = set()
            for row in rows:
                if row.entity.startswith("@"):
                    room_ids = await self.store.get_rooms_for_user(row.entity)
                    all_room_ids.update(room_ids)
            self.notifier.on_new_event("device_list_key",
                                       token,
                                       rooms=all_room_ids)
        elif stream_name == GroupServerStream.NAME:
            self.notifier.on_new_event("groups_key",
                                       token,
                                       users=[row.user_id for row in rows])
        elif stream_name == PushersStream.NAME:
            for row in rows:
                if row.deleted:
                    self.stop_pusher(row.user_id, row.app_id, row.pushkey)
                else:
                    await self.start_pusher(row.user_id, row.app_id,
                                            row.pushkey)
        elif stream_name == EventsStream.NAME:
            # We shouldn't get multiple rows per token for events stream, so
            # we don't need to optimise this for multiple rows.
            for row in rows:
                if row.type != EventsStreamEventRow.TypeId:
                    continue
                assert isinstance(row, EventsStreamRow)
                assert isinstance(row.data, EventsStreamEventRow)

                if row.data.rejected:
                    continue

                extra_users: Tuple[UserID, ...] = ()
                if row.data.type == EventTypes.Member and row.data.state_key:
                    extra_users = (UserID.from_string(row.data.state_key), )

                max_token = self.store.get_room_max_token()
                event_pos = PersistedEventPosition(instance_name, token)
                await self.notifier.on_new_room_event_args(
                    event_pos=event_pos,
                    max_room_stream_token=max_token,
                    extra_users=extra_users,
                    room_id=row.data.room_id,
                    event_id=row.data.event_id,
                    event_type=row.data.type,
                    state_key=row.data.state_key,
                    membership=row.data.membership,
                )

        await self._presence_handler.process_replication_rows(
            stream_name, instance_name, token, rows)

        # Notify any waiting deferreds. The list is ordered by position so we
        # just iterate through the list until we reach a position that is
        # greater than the received row position.
        waiting_list = self._streams_to_waiters.get(stream_name, [])

        # Index of the first item with a position after the current token,
        # i.e. we have called all deferreds before this index. If the loop
        # below does not overwrite it, then either a) the list is empty, so
        # this is a no-op, or b) every item in the list was called, so the
        # list should be cleared. Setting it to `len(list)` works for both
        # cases.
        index_of_first_deferred_not_called = len(waiting_list)

        for idx, (position, deferred) in enumerate(waiting_list):
            if position <= token:
                try:
                    with PreserveLoggingContext():
                        deferred.callback(None)
                except Exception:
                    # The deferred has been cancelled or timed out.
                    pass
            else:
                # The list is sorted by position so we don't need to continue
                # checking any further entries in the list.
                index_of_first_deferred_not_called = idx
                break

        # Drop all entries in the waiting list that were called in the above
        # loop. (This preserves the order, so there is no need to re-sort.)
        waiting_list[:] = waiting_list[index_of_first_deferred_not_called:]
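
The waiter bookkeeping at the end of on_rdata pairs with a registration step elsewhere in the handler (wait_for_stream_position in Synapse, though its exact signature varies by version). Here is a self-contained sketch of the pattern using plain Twisted deferreds; the StreamWaiters class and its method names are this sketch's own:

    from typing import Dict, List, Tuple

    from twisted.internet import defer

    class StreamWaiters:
        """Fires deferreds once a replication stream reaches their position."""

        def __init__(self) -> None:
            self._streams_to_waiters: Dict[str, List[Tuple[int, defer.Deferred]]] = {}

        def wait_for_position(self, stream_name: str, position: int) -> defer.Deferred:
            # Keep the list sorted by position so on_position can stop
            # scanning at the first waiter still ahead of the incoming token.
            d: defer.Deferred = defer.Deferred()
            waiting_list = self._streams_to_waiters.setdefault(stream_name, [])
            waiting_list.append((position, d))
            waiting_list.sort(key=lambda entry: entry[0])
            return d

        def on_position(self, stream_name: str, token: int) -> None:
            # Mirror of the loop above: fire every waiter at or below the
            # token, then drop the fired entries while preserving order.
            waiting_list = self._streams_to_waiters.get(stream_name, [])
            index_of_first_not_called = len(waiting_list)
            for idx, (position, deferred) in enumerate(waiting_list):
                if position > token:
                    index_of_first_not_called = idx
                    break
                deferred.callback(None)
            waiting_list[:] = waiting_list[index_of_first_not_called:]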
Example #6
    async def on_rdata(self, stream_name: str, instance_name: str, token: int,
                       rows: list):
        """Called to handle a batch of replication data with a given stream token.

        By default this just pokes the slave store. Can be overridden in subclasses to
        handle more.

        Args:
            stream_name: name of the replication stream for this batch of rows
            instance_name: the instance that wrote the rows.
            token: stream token for this batch of rows
            rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row.
        """
        self.store.process_replication_rows(stream_name, instance_name, token,
                                            rows)

        if stream_name == TypingStream.NAME:
            self._typing_handler.process_replication_rows(token, rows)
            self.notifier.on_new_event("typing_key",
                                       token,
                                       rooms=[row.room_id for row in rows])

        if stream_name == EventsStream.NAME:
            # We shouldn't get multiple rows per token for events stream, so
            # we don't need to optimise this for multiple rows.
            for row in rows:
                if row.type != EventsStreamEventRow.TypeId:
                    continue
                assert isinstance(row, EventsStreamRow)

                event = await self.store.get_event(row.data.event_id,
                                                   allow_rejected=True)
                if event.rejected_reason:
                    continue

                extra_users: Tuple[UserID, ...] = ()
                if event.type == EventTypes.Member:
                    extra_users = (UserID.from_string(event.state_key), )

                max_token = self.store.get_room_max_token()
                event_pos = PersistedEventPosition(instance_name, token)
                self.notifier.on_new_room_event(event, event_pos, max_token,
                                                extra_users)

        # Notify any waiting deferreds. The list is ordered by position so we
        # just iterate through the list until we reach a position that is
        # greater than the received row position.
        waiting_list = self._streams_to_waiters.get(stream_name, [])

        # Index of the first item with a position after the current token,
        # i.e. we have called all deferreds before this index. If the loop
        # below does not overwrite it, then either a) the list is empty, so
        # this is a no-op, or b) every item in the list was called, so the
        # list should be cleared. Setting it to `len(list)` works for both
        # cases.
        index_of_first_deferred_not_called = len(waiting_list)

        for idx, (position, deferred) in enumerate(waiting_list):
            if position <= token:
                try:
                    with PreserveLoggingContext():
                        deferred.callback(None)
                except Exception:
                    # The deferred has been cancelled or timed out.
                    pass
            else:
                # The list is sorted by position so we don't need to continue
                # checking any further entries in the list.
                index_of_first_deferred_not_called = idx
                break

        # Drop all entries in the waiting list that were called in the above
        # loop. (This preserves the order, so there is no need to re-sort.)
        waiting_list[:] = waiting_list[index_of_first_deferred_not_called:]
Example #7
    def test_get_rooms_for_user_with_stream_ordering_with_multi_event_persist(self):
        """Check that current_state invalidation happens correctly with multiple events
        in the persistence batch.

        This test attempts to reproduce a race condition between the event persistence
        loop and a worker-based Sync handler.

        The problem occurs when the master persists several events in one batch: it
        only updates the current_state at the end of the batch, so the obvious thing
        to do is to issue a current_state_delta stream update corresponding to the
        last stream_id in the batch.

        However, that raises the possibility that a worker will see the replication
        notification for a join event before the current_state caches are invalidated.

        The test involves:
         * creating a join and a message event for a user, and persisting them in the
           same batch

         * controlling the replication stream so that updates are sent gradually

         * between each batch of replication updates, checking that we see a
           consistent snapshot of the state.
        """
        self.persist(type="m.room.create", key="", creator=USER_ID)
        self.persist(type="m.room.member", key=USER_ID, membership="join")
        self.replicate()
        self.check("get_rooms_for_user_with_stream_ordering", (USER_ID_2,), set())

        # limit the replication rate
        repl_transport = self._server_transport
        assert isinstance(repl_transport, FakeTransport)
        repl_transport.autoflush = False

        # build the join and message events and persist them in the same batch.
        logger.info("----- build test events ------")
        j2, j2ctx = self.build_event(
            type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join"
        )
        msg, msgctx = self.build_event()
        self.get_success(
            self._storage_controllers.persistence.persist_events(
                [(j2, j2ctx), (msg, msgctx)]
            )
        )
        self.replicate()

        event_source = RoomEventSource(self.hs)
        event_source.store = self.slaved_store
        current_token = event_source.get_current_key()

        # gradually stream out the replication
        while repl_transport.buffer:
            logger.info("------ flush ------")
            repl_transport.flush(30)
            self.pump(0)

            prev_token = current_token
            current_token = event_source.get_current_key()

            # attempt to replicate the behaviour of the sync handler.
            #
            # First, we get a list of the rooms we are joined to
            joined_rooms = self.get_success(
                self.slaved_store.get_rooms_for_user_with_stream_ordering(USER_ID_2)
            )

            # Then, we get a list of the events since the last sync
            membership_changes = self.get_success(
                self.slaved_store.get_membership_changes_for_user(
                    USER_ID_2, prev_token, current_token
                )
            )

            logger.info(
                "%s->%s: joined_rooms=%r membership_changes=%r",
                prev_token,
                current_token,
                joined_rooms,
                membership_changes,
            )

            # the membership change is only any use to us if the room is in the
            # joined_rooms list.
            if membership_changes:
                expected_pos = PersistedEventPosition(
                    "master", j2.internal_metadata.stream_ordering
                )
                self.assertEqual(
                    joined_rooms,
                    {GetRoomsForUserWithStreamOrdering(ROOM_ID, expected_pos)},
                )
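
The invariant each flush iteration probes can be stated compactly: once the join appears in the membership changes, the rooms-for-user snapshot must already report the room at the join's persisted position. Here is a restatement as a standalone predicate, for illustration only: the namedtuple mirrors Synapse's GetRoomsForUserWithStreamOrdering, while snapshot_is_consistent is this sketch's own name:

    from collections import namedtuple
    from typing import Iterable, Set

    GetRoomsForUserWithStreamOrdering = namedtuple(
        "GetRoomsForUserWithStreamOrdering", ("room_id", "event_pos")
    )

    def snapshot_is_consistent(
        joined_rooms: Set[GetRoomsForUserWithStreamOrdering],
        membership_changes: Iterable,
        room_id: str,
        expected_pos: object,
    ) -> bool:
        # Nothing to check until the join is visible in the event stream;
        # after that, the cache must already contain the room at the join's
        # persisted position.
        if not membership_changes:
            return True
        return GetRoomsForUserWithStreamOrdering(room_id, expected_pos) in joined_rooms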