Esempio n. 1
0
def _create_auth_events_from_maps(
    room_version: RoomVersion,
    unconflicted_state: StateMap[str],
    conflicted_state: StateMap[Set[str]],
    state_map: Dict[str, EventBase],
) -> StateMap[str]:
    """Collect, from the unconflicted state, the auth events needed to check
    the conflicted events.

    For every conflicted event ID that is present in ``state_map``, the state
    keys its auth rules depend on are computed. Each such key that has a
    truthy event ID in ``unconflicted_state`` is copied into the result.

    Args:
        room_version: The room version, used to work out each event's auth
            types.
        unconflicted_state: The unconflicted state map.
        conflicted_state: The conflicted state map: for each state key, the
            set of competing event IDs.
        state_map: Map from event ID to event. Conflicted event IDs that are
            not in this map are skipped.

    Returns:
        A map from state key to event id, drawn from ``unconflicted_state``.
    """
    result = {}
    # Lazily walk every conflicted event we actually have, in the same order
    # as the nested loops over ``conflicted_state`` would.
    candidates = (
        state_map[candidate_id]
        for candidate_ids in conflicted_state.values()
        for candidate_id in candidate_ids
        if candidate_id in state_map
    )
    for candidate in candidates:
        for auth_key in event_auth.auth_types_for_event(room_version, candidate):
            if auth_key in result:
                continue
            unconflicted_id = unconflicted_state.get(auth_key, None)
            if unconflicted_id:
                result[auth_key] = unconflicted_id
    return result
Esempio n. 2
0
def _iterative_auth_checks(room_id, room_version, event_ids, base_state,
                           event_map, state_res_store):
    """Run the auth rules over ``event_ids`` in order, folding each event that
    passes into a running copy of the state.

    Each event is checked against an auth map built in two steps:

    1. the event's own non-rejected auth events (missing ones are logged and
       skipped);
    2. for each key the event's auth rules depend on, the non-rejected event
       currently held in the resolved state. This overrides step 1.

    Events that raise ``AuthError`` are not added to the state.

    Args:
        room_id (str)
        room_version (str): looked up in ``KNOWN_ROOM_VERSIONS``
        event_ids (list[str]): Ordered list of events to apply auth checks to
        base_state (StateMap[str]): The set of state to start with; it is
            copied, not modified.
        event_map (dict[str,FrozenEvent])
        state_res_store (StateResolutionStore)

    Returns:
        Deferred[StateMap[str]]: Returns the final updated state
    """
    state = base_state.copy()
    version = KNOWN_ROOM_VERSIONS[room_version]

    for current_id in event_ids:
        current = event_map[current_id]
        checked_against = {}

        # Step 1: the event's own auth events.
        for auth_id in current.auth_event_ids():
            auth_ev = yield _get_event(room_id,
                                       auth_id,
                                       event_map,
                                       state_res_store,
                                       allow_none=True)
            if not auth_ev:
                logger.warning("auth_event id %s for event %s is missing",
                               auth_id, current_id)
            elif auth_ev.rejected_reason is None:
                checked_against[(auth_ev.type, auth_ev.state_key)] = auth_ev

        # Step 2: override with what the resolved state currently holds.
        for state_key in event_auth.auth_types_for_event(current):
            if state_key not in state:
                continue
            state_ev_id = state[state_key]
            state_ev = yield _get_event(room_id, state_ev_id, event_map,
                                        state_res_store)
            if state_ev.rejected_reason is None:
                checked_against[state_key] = event_map[state_ev_id]

        try:
            event_auth.check(
                version,
                current,
                checked_against,
                do_sig_check=False,
                do_size_check=False,
            )
        except AuthError:
            pass
        else:
            state[(current.type, current.state_key)] = current_id

    return state
Esempio n. 3
0
def _iterative_auth_checks(event_ids, base_state, event_map):
    """Apply the auth rules to each event in ``event_ids`` in order, adding
    every event that passes to a copy of ``base_state``.

    Each event is checked against its own auth events (looked up in
    ``event_map``). For every key its auth rules depend on, the event the
    resolved state currently holds for that key overrides them. Events that
    raise ``AuthError`` are left out.

    Returns:
        The updated copy of ``base_state``.
    """
    state = base_state.copy()

    for current_id in event_ids:
        current = event_map[current_id]

        auth_map = {}
        for auth_id, _ in current.auth_events:
            auth_ev = event_map[auth_id]
            auth_map[(auth_ev.type, auth_ev.state_key)] = auth_ev

        for key in event_auth.auth_types_for_event(current):
            if key in state:
                auth_map[key] = event_map[state[key]]

        try:
            event_auth.check(current, auth_map,
                             do_sig_check=False,
                             do_size_check=False)
        except AuthError:
            pass
        else:
            state[(current.type, current.state_key)] = current_id

    return state
Esempio n. 4
0
def _resolve_auth_events(room_version: RoomVersion, events: List[EventBase],
                         auth_events: StateMap[EventBase]) -> EventBase:
    """Pick the winning event among conflicting auth events.

    The candidates are ordered with ``_ordered_events`` and walked in reverse.
    The first event of the reversed list is the initial winner. Each later
    event is auth-checked against the relevant ``auth_events`` plus the
    current winner, and becomes the new winner if it passes. The walk stops
    at the first ``AuthError`` and returns the winner so far.

    Args:
        room_version: The room version, used to work out which auth types the
            candidates depend on.
        events: The conflicting candidate events. Must be non-empty;
            ``reverse[0]`` raises ``IndexError`` otherwise.
        auth_events: Auth events to check against. The caller's mapping is
            not modified; a filtered copy is built and updated instead.

    Returns:
        The winning event.
    """
    reverse = list(reversed(_ordered_events(events)))

    auth_keys = {
        key
        for event in events
        for key in event_auth.auth_types_for_event(room_version, event)
    }

    # Keep only the auth events the candidates can depend on. This builds a
    # new dict, so the updates in the loop below never touch the caller's map.
    new_auth_events = {}
    for key in auth_keys:
        auth_event = auth_events.get(key, None)
        if auth_event:
            new_auth_events[key] = auth_event

    auth_events = new_auth_events

    prev_event = reverse[0]
    for event in reverse[1:]:
        # Check each candidate with the current winner in its auth state.
        auth_events[(prev_event.type, prev_event.state_key)] = prev_event
        try:
            # The signatures have already been checked at this point.
            # NOTE(review): the rules are always applied as room version 1,
            # regardless of ``room_version``. Presumably this algorithm is
            # only used for v1 rooms; confirm before changing.
            event_auth.check_auth_rules_for_event(
                RoomVersions.V1,
                event,
                auth_events.values(),
            )
            prev_event = event
        except AuthError:
            return prev_event

    # Every check passed, so ``prev_event`` is the last candidate. Returning
    # it instead of the loop variable ``event`` also covers a single
    # candidate. In that case the loop body never runs, and ``event`` would
    # be unbound (the ``event`` in the set comprehension has its own scope).
    return prev_event
Esempio n. 5
0
    def compute_auth_events(
        self, event, current_state_ids: StateMap[str], for_verification: bool = False,
    ) -> List[str]:
        """Return the IDs of the events in ``current_state_ids`` that auth
        ``event``.

        A create event has no auth events, so an empty list is returned for
        it. Otherwise, for each (type, state_key) that the event's auth rules
        depend on, the matching ID from ``current_state_ids`` is included.
        Keys that are absent or map to a falsy value are skipped. IDs appear
        in the order ``auth_types_for_event`` yields the keys.

        Args:
            event: The event to compute auth events for.
            current_state_ids: Map from (type, state_key) to event ID.
            for_verification: Currently ignored (see the comment below).

        Returns:
            List of event IDs.
        """
        if event.type == EventTypes.Create:
            return []

        # ``for_verification`` is deliberately ignored for now. In principle,
        # some auth events could be dropped when populating an event's
        # ``auth_events`` (e.g. joins pointing to previous joins in a publicly
        # joinable room), which would slow the growth of the room's auth
        # chain. However, state resolution v2 uses the auth chain to order
        # events, so dropping entries risks undesirable "state reset"
        # behaviour. That is tricky enough that we don't attempt it.
        candidate_ids = (
            current_state_ids.get((etype, state_key))
            for etype, state_key in event_auth.auth_types_for_event(event)
        )
        return [auth_id for auth_id in candidate_ids if auth_id]
Esempio n. 6
0
def _resolve_auth_events(events, auth_events):
    """Pick the winning event among conflicting auth events.

    The candidates are ordered with ``_ordered_events`` and walked in reverse.
    The first event of the reversed list is the initial winner. Each later
    event is auth-checked against the relevant ``auth_events`` plus the
    current winner, and becomes the new winner if it passes. The walk stops
    at the first ``AuthError`` and returns the winner so far.

    Args:
        events (list): The conflicting candidate events. Must be non-empty;
            ``reverse[0]`` raises ``IndexError`` otherwise.
        auth_events (dict): Auth events to check against, keyed by
            (type, state_key). Not modified; a filtered copy is used.

    Returns:
        The winning event.
    """
    reverse = list(reversed(_ordered_events(events)))

    auth_keys = {
        key
        for event in events
        for key in event_auth.auth_types_for_event(event)
    }

    # Keep only the auth events the candidates can depend on. This builds a
    # new dict, so the loop below never mutates the caller's mapping.
    new_auth_events = {}
    for key in auth_keys:
        auth_event = auth_events.get(key, None)
        if auth_event:
            new_auth_events[key] = auth_event

    auth_events = new_auth_events

    prev_event = reverse[0]
    for event in reverse[1:]:
        # Check each candidate with the current winner in its auth state.
        auth_events[(prev_event.type, prev_event.state_key)] = prev_event
        try:
            # The signatures have already been checked at this point
            event_auth.check(
                RoomVersions.V1.identifier,
                event,
                auth_events,
                do_sig_check=False,
                do_size_check=False,
            )
            prev_event = event
        except AuthError:
            return prev_event

    # Every check passed, so ``prev_event`` is the last candidate. Returning
    # it instead of the loop variable ``event`` also covers a single
    # candidate, where the loop body never runs and ``event`` is unbound.
    return prev_event
Esempio n. 7
0
def _resolve_auth_events(events, auth_events):
    """Pick the winning event among conflicting auth events.

    The candidates are ordered with ``_ordered_events`` and walked in reverse.
    The first event of the reversed list is the initial winner. Each later
    event is auth-checked against the relevant ``auth_events`` plus the
    current winner, and becomes the new winner if it passes. The walk stops
    at the first ``AuthError`` and returns the winner so far.

    Args:
        events (list): The conflicting candidate events. Must be non-empty;
            ``reverse[0]`` raises ``IndexError`` otherwise.
        auth_events (dict): Auth events to check against, keyed by
            (type, state_key). Not modified; a filtered copy is used.

    Returns:
        The winning event.
    """
    reverse = list(reversed(_ordered_events(events)))

    auth_keys = {
        key
        for event in events
        for key in event_auth.auth_types_for_event(event)
    }

    # Keep only the auth events the candidates can depend on. This builds a
    # new dict, so the loop below never mutates the caller's mapping.
    new_auth_events = {}
    for key in auth_keys:
        auth_event = auth_events.get(key, None)
        if auth_event:
            new_auth_events[key] = auth_event

    auth_events = new_auth_events

    prev_event = reverse[0]
    for event in reverse[1:]:
        # Check each candidate with the current winner in its auth state.
        auth_events[(prev_event.type, prev_event.state_key)] = prev_event
        try:
            # The signatures have already been checked at this point
            event_auth.check(event, auth_events, do_sig_check=False, do_size_check=False)
            prev_event = event
        except AuthError:
            return prev_event

    # Every check passed, so ``prev_event`` is the last candidate. Returning
    # it instead of the loop variable ``event`` also covers a single
    # candidate, where the loop body never runs and ``event`` is unbound.
    return prev_event
Esempio n. 8
0
def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_map):
    """Collect, from the unconflicted state, the auth events needed to check
    the conflicted events.

    Args:
        unconflicted_state (dict): The unconflicted state: state key ->
            event ID.
        conflicted_state (dict): The conflicted state: state key -> iterable
            of competing event IDs.
        state_map (dict): Event ID -> event. Conflicted event IDs missing
            from this map are skipped.

    Returns:
        dict: state key -> event ID, for each key that a conflicted event's
        auth rules depend on and that has a truthy entry in
        ``unconflicted_state``.
    """
    auth_events = {}
    for event_ids in itervalues(conflicted_state):
        for event_id in event_ids:
            if event_id in state_map:
                keys = event_auth.auth_types_for_event(state_map[event_id])
                for key in keys:
                    if key not in auth_events:
                        # Use a distinct name so the enclosing loop's
                        # ``event_id`` is not clobbered.
                        auth_event_id = unconflicted_state.get(key, None)
                        if auth_event_id:
                            auth_events[key] = auth_event_id
    return auth_events
Esempio n. 9
0
def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_map):
    """Collect, from the unconflicted state, the auth events needed to check
    the conflicted events.

    Args:
        unconflicted_state (dict): The unconflicted state: state key ->
            event ID.
        conflicted_state (dict): The conflicted state: state key -> iterable
            of competing event IDs.
        state_map (dict): Event ID -> event. Conflicted event IDs missing
            from this map are skipped.

    Returns:
        dict: state key -> event ID, for each key that a conflicted event's
        auth rules depend on and that has a truthy entry in
        ``unconflicted_state``.
    """
    auth_events = {}
    for event_ids in itervalues(conflicted_state):
        for event_id in event_ids:
            if event_id in state_map:
                keys = event_auth.auth_types_for_event(state_map[event_id])
                for key in keys:
                    if key not in auth_events:
                        # Use a distinct name so the enclosing loop's
                        # ``event_id`` is not clobbered.
                        auth_event_id = unconflicted_state.get(key, None)
                        if auth_event_id:
                            auth_events[key] = auth_event_id
    return auth_events
Esempio n. 10
0
def _iterative_auth_checks(room_version, event_ids, base_state, event_map,
                           state_res_store):
    """Run the auth rules over ``event_ids`` in order, folding each event that
    passes into a running copy of the state.

    Each event's auth map starts with its own non-rejected auth events. For
    every key the event's auth rules depend on, the non-rejected event that
    the resolved state currently holds then overrides it. Events that raise
    ``AuthError`` are not added.

    Args:
        room_version (str)
        event_ids (list[str]): Ordered list of events to apply auth checks to
        base_state (dict[tuple[str, str], str]): The set of state to start
            with; copied, not modified.
        event_map (dict[str,FrozenEvent])
        state_res_store (StateResolutionStore)

    Returns:
        Deferred[dict[tuple[str, str], str]]: Returns the final updated state
    """
    state = base_state.copy()

    for current_id in event_ids:
        current = event_map[current_id]
        checked_against = {}

        for auth_id in current.auth_event_ids():
            auth_ev = yield _get_event(auth_id, event_map, state_res_store)
            if auth_ev.rejected_reason is None:
                checked_against[(auth_ev.type, auth_ev.state_key)] = auth_ev

        for state_key in event_auth.auth_types_for_event(current):
            if state_key not in state:
                continue
            state_ev_id = state[state_key]
            state_ev = yield _get_event(state_ev_id, event_map, state_res_store)
            if state_ev.rejected_reason is None:
                checked_against[state_key] = event_map[state_ev_id]

        try:
            event_auth.check(
                room_version,
                current,
                checked_against,
                do_sig_check=False,
                do_size_check=False,
            )
        except AuthError:
            pass
        else:
            state[(current.type, current.state_key)] = current_id

    defer.returnValue(state)
Esempio n. 11
0
def _iterative_auth_checks(room_version, event_ids, base_state, event_map,
                           state_res_store):
    """Run the auth rules over ``event_ids`` in order, folding each event that
    passes into a running copy of the state.

    Each event's auth map starts with its own non-rejected auth events. For
    every key the event's auth rules depend on, the non-rejected event that
    the resolved state currently holds then overrides it. Events that raise
    ``AuthError`` are not added.

    Args:
        room_version (str)
        event_ids (list[str]): Ordered list of events to apply auth checks to
        base_state (dict[tuple[str, str], str]): The set of state to start
            with; copied, not modified.
        event_map (dict[str,FrozenEvent])
        state_res_store (StateResolutionStore)

    Returns:
        Deferred[dict[tuple[str, str], str]]: Returns the final updated state
    """
    state = base_state.copy()

    for current_id in event_ids:
        current = event_map[current_id]
        checked_against = {}

        for auth_id in current.auth_event_ids():
            auth_ev = yield _get_event(auth_id, event_map, state_res_store)
            if auth_ev.rejected_reason is None:
                checked_against[(auth_ev.type, auth_ev.state_key)] = auth_ev

        for state_key in event_auth.auth_types_for_event(current):
            if state_key not in state:
                continue
            state_ev_id = state[state_key]
            state_ev = yield _get_event(state_ev_id, event_map, state_res_store)
            if state_ev.rejected_reason is None:
                checked_against[state_key] = event_map[state_ev_id]

        try:
            event_auth.check(
                room_version,
                current,
                checked_against,
                do_sig_check=False,
                do_size_check=False,
            )
        except AuthError:
            pass
        else:
            state[(current.type, current.state_key)] = current_id

    defer.returnValue(state)
Esempio n. 12
0
    async def _get_power_levels_and_sender_level(
            self, event: EventBase, context: EventContext) -> Tuple[dict, int]:
        """Return the power-levels content in force before ``event`` and the
        power level of ``event``'s sender.

        Only the prev-state entries for ``event``'s auth types are loaded.

        Returns:
            A ``(content, sender_level)`` tuple. ``content`` is ``{}`` when
            no power-levels event is found among the auth events.
        """
        state_filter = StateFilter.from_types(
            auth_types_for_event(event.room_version, event))
        prev_state_ids = await context.get_prev_state_ids(state_filter)

        pl_event_id = prev_state_ids.get(POWER_KEY)
        if not pl_event_id:
            # No power-levels event in the prev state: fall back to the full
            # set of auth events for this event.
            wanted_ids = self._event_auth_handler.compute_auth_events(
                event, prev_state_ids, for_verification=False)
            fetched = await self.store.get_events(wanted_ids)
            auth_events = {(e.type, e.state_key): e for e in fetched.values()}
        else:
            # fastpath: if there's a power level event, that's all we need, and
            # not having a power level event is an extreme edge case
            auth_events = {POWER_KEY: await self.store.get_event(pl_event_id)}

        sender_level = get_user_power_level(event.sender, auth_events)

        pl_event = auth_events.get(POWER_KEY)
        content = pl_event.content if pl_event else {}
        return content, sender_level
Esempio n. 13
0
def _iterative_auth_checks(clock, room_id, room_version, event_ids, base_state,
                           event_map, state_res_store):
    """Run the auth rules over ``event_ids`` in order, folding each event that
    passes into a running copy of the state.

    Each event's auth map is built in two steps:

    1. the event's own non-rejected auth events (missing ones are logged and
       skipped);
    2. for each key the event's auth rules depend on, the non-rejected event
       currently held in the resolved state. This overrides step 1.

    Events that raise ``AuthError`` are not added. After every
    ``_YIELD_AFTER_ITERATIONS`` events the loop yields ``clock.sleep(0)``.

    Args:
        clock (Clock)
        room_id (str)
        room_version (str): looked up in ``KNOWN_ROOM_VERSIONS``
        event_ids (list[str]): Ordered list of events to apply auth checks to
        base_state (StateMap[str]): The set of state to start with; copied,
            not modified.
        event_map (dict[str,FrozenEvent])
        state_res_store (StateResolutionStore)

    Returns:
        Deferred[StateMap[str]]: Returns the final updated state
    """
    state = base_state.copy()
    version = KNOWN_ROOM_VERSIONS[room_version]

    for position, current_id in enumerate(event_ids, start=1):
        current = event_map[current_id]
        checked_against = {}

        # Step 1: the event's own auth events.
        for auth_id in current.auth_event_ids():
            auth_ev = yield _get_event(room_id,
                                       auth_id,
                                       event_map,
                                       state_res_store,
                                       allow_none=True)
            if not auth_ev:
                logger.warning("auth_event id %s for event %s is missing",
                               auth_id, current_id)
            elif auth_ev.rejected_reason is None:
                checked_against[(auth_ev.type, auth_ev.state_key)] = auth_ev

        # Step 2: override with what the resolved state currently holds.
        for state_key in event_auth.auth_types_for_event(current):
            if state_key not in state:
                continue
            state_ev_id = state[state_key]
            state_ev = yield _get_event(room_id, state_ev_id, event_map,
                                        state_res_store)
            if state_ev.rejected_reason is None:
                checked_against[state_key] = event_map[state_ev_id]

        try:
            event_auth.check(
                version,
                current,
                checked_against,
                do_sig_check=False,
                do_size_check=False,
            )
        except AuthError:
            pass
        else:
            state[(current.type, current.state_key)] = current_id

        # Periodically hand control back so that large data sets don't block
        # the reactor loop for too long.
        if position % _YIELD_AFTER_ITERATIONS == 0:
            yield clock.sleep(0)

    return state
Esempio n. 14
0
    def do_check(
        self,
        events: List[FakeEvent],
        edges: List[List[str]],
        expected_state_ids: List[str],
    ) -> None:
        """Build the event graph from `events` and `edges`, calculate the state
        of the graph at END, and assert that it matches `expected_state_ids`.

        `events` and `edges` are added on top of INITIAL_EVENTS and
        INITIAL_EDGES. The state before each node is one of:

        - empty, if the node has no prev events;
        - a copy of its single prev event's state;
        - the result of `resolve_events_with_store` over all of its prev
          events' states.

        Args:
            events: Test events to add to the graph.
            edges: A list of chains of event edges, e.g.
                `[[A, B, C]]` are edges A->B and B->C. An edge A->B makes B
                one of A's prev events.
            expected_state_ids: Node IDs of the events expected in the state
                at END. Keys whose value is unchanged since START are left
                out of the comparison unless they appear here.
        """
        # We want to sort the events into topological order for processing.
        # node_id -> node_ids of its prev events.
        graph: Dict[str, Set[str]] = {}

        # node_id -> FakeEvent
        fake_event_map: Dict[str, FakeEvent] = {}

        for ev in itertools.chain(INITIAL_EVENTS, events):
            graph[ev.node_id] = set()
            fake_event_map[ev.node_id] = ev

        for a, b in pairwise(INITIAL_EDGES):
            graph[a].add(b)

        for edge_list in edges:
            for a, b in pairwise(edge_list):
                graph[a].add(b)

        # event_id -> built event (note: keyed by event ID, not node ID).
        event_map: Dict[str, EventBase] = {}
        # node_id -> state map after that node.
        state_at_event: Dict[str, StateMap[str]] = {}

        # We copy the map as the sort consumes the graph
        graph_copy = {k: set(v) for k, v in graph.items()}

        # NOTE(review): the state_at_event lookups below assume every prev
        # event of a node is yielded before the node itself.
        for node_id in lexicographical_topological_sort(graph_copy,
                                                        key=lambda e: e):
            fake_event = fake_event_map[node_id]
            event_id = fake_event.event_id

            prev_events = list(graph[node_id])

            state_before: StateMap[str]
            if len(prev_events) == 0:
                state_before = {}
            elif len(prev_events) == 1:
                state_before = dict(state_at_event[prev_events[0]])
            else:
                # Several prev events: run state resolution (v2 room
                # version) over their states.
                state_d = resolve_events_with_store(
                    FakeClock(),
                    ROOM_ID,
                    RoomVersions.V2,
                    [state_at_event[n] for n in prev_events],
                    event_map=event_map,
                    state_res_store=TestStateResolutionStore(event_map),
                )

                state_before = self.successResultOf(
                    defer.ensureDeferred(state_d))

            # A state event replaces its own (type, state_key) entry.
            state_after = dict(state_before)
            if fake_event.state_key is not None:
                state_after[(fake_event.type, fake_event.state_key)] = event_id

            # This type ignore is a bit sad. Things we have tried:
            # 1. Define a `GenericEvent` Protocol satisfied by FakeEvent, EventBase and
            #    EventBuilder. But this is Hard because the relevant attributes are
            #    DictProperty[T] descriptors on EventBase but normal Ts on FakeEvent.
            # 2. Define a `GenericEvent` Protocol describing `FakeEvent` only, and
            #    change this function to accept Union[Event, EventBase, EventBuilder].
            #    This seems reasonable to me, but mypy isn't happy. I think that's
            #    a mypy bug, see https://github.com/python/mypy/issues/5570
            # Instead, resort to a type-ignore.
            auth_types = set(auth_types_for_event(
                RoomVersions.V6, fake_event))  # type: ignore[arg-type]

            # The new event's auth events are the state_before entries for
            # each key its auth rules depend on.
            auth_events = []
            for key in auth_types:
                if key in state_before:
                    auth_events.append(state_before[key])

            event = fake_event.to_event(auth_events, prev_events)

            state_at_event[node_id] = state_after
            event_map[event_id] = event

        expected_state = {}
        for node_id in expected_state_ids:
            # expected_state_ids are node IDs rather than event IDs,
            # so we have to convert
            event_id = EventID(node_id, "example.com").to_string()
            event = event_map[event_id]

            key = (event.type, event.state_key)

            expected_state[key] = event_id

        # Compare only keys that are expected or that changed since START.
        start_state = state_at_event["START"]
        end_state = {
            key: value
            for key, value in state_at_event["END"].items()
            if key in expected_state or start_state.get(key) != value
        }

        self.assertEqual(expected_state, end_state)
Esempio n. 15
0
def resolve(graph_desc, resolution_func):
    """Replay the DAG in ``graph_desc`` using ``resolution_func`` and print how
    the resulting end state compares with the expected end state.

    Events are visited in the reverse of ``topological_sort``'s order, so each
    event's prev events must already have been processed. The state before an
    event is one of:

    - empty, if it has no prev events;
    - its single prev event's state;
    - ``resolution_func(prev_states, event_map)``, if it has several.

    If an event fails its auth check, a message is printed and the function
    returns without comparing states.

    Args:
        graph_desc (dict): the graph description passed to ``create_dag``.
            Its ``"expected_state"`` entry lists the node IDs expected in the
            end state.
        resolution_func (callable): state resolution algorithm under test.
    """
    graph, _, event_map = create_dag(graph_desc)

    # event ID -> state after that event
    state_after = {}
    for eid in reversed(list(topological_sort(graph))):
        event = event_map[eid]

        parent_states = [state_after[pid] for pid, _ in event.prev_events]

        if not parent_states:
            state_ids = {}
        elif len(parent_states) == 1:
            state_ids = parent_states[0]
        else:
            state_ids = resolution_func(
                parent_states,
                event_map,
            )

        auth_events = {}
        for key in event_auth.auth_types_for_event(event):
            if key in state_ids:
                auth_events[key] = event_map[state_ids[key]]

        try:
            event_auth.check(
                event,
                auth_events,
                do_sig_check=False,
                do_size_check=False,
            )
        except AuthError as e:
            print("Failed to auth event", eid, " because:", e)
            return

        if event.is_state():
            # Copy before updating so the parent's state map is left intact.
            state_ids = dict(state_ids)
            state_ids[(event.type, event.state_key)] = eid

        state_after[eid] = state_ids

    start_state = state_after[to_event_id("START")]
    end_state = state_after[to_event_id("END")]

    expected_state = {}
    for node in graph_desc["expected_state"]:
        event_id = to_event_id(node)
        ev = event_map[event_id]
        expected_state[(ev.type, ev.state_key)] = event_id

    mismatches = []
    for key in set(itertools.chain(end_state, expected_state)):
        expected_id = expected_state.get(key)
        actual_id = end_state.get(key)
        # Keys left untouched since START that nobody expects are ignored.
        untouched_and_unexpected = (
            actual_id == start_state.get(key) and not expected_id
        )
        if untouched_and_unexpected or expected_id == actual_id:
            continue
        mismatches.append((key[0], key[1], expected_id, actual_id))

    if not mismatches:
        print("Everything matched!")
        return

    print("Unexpected end state\n")
    print(
        tabulate(
            mismatches,
            headers=["Type", "State Key", "Expected", "Got"],
        ))
Esempio n. 16
0
async def _iterative_auth_checks(
    clock: Clock,
    room_id: str,
    room_version: RoomVersion,
    event_ids: List[str],
    base_state: StateMap[str],
    event_map: Dict[str, EventBase],
    state_res_store: StateResolutionStore,
) -> MutableStateMap[str]:
    """Run the auth rules over ``event_ids`` in order, folding each event that
    passes into a running copy of the state.

    Each event is checked with ``check_state_dependent_auth_rules`` against
    an auth set built in two steps:

    1. the event's own non-rejected auth events (missing ones are logged and
       skipped);
    2. for each key the event's auth rules depend on, the non-rejected event
       currently held in the resolved state. This overrides step 1.

    Events that raise ``AuthError`` are not added. After every
    ``_AWAIT_AFTER_ITERATIONS`` events the loop awaits ``clock.sleep(0)``.

    Args:
        clock
        room_id
        room_version
        event_ids: Ordered list of events to apply auth checks to
        base_state: The set of state to start with; copied, not modified.
        event_map
        state_res_store

    Returns:
        Returns the final updated state
    """
    state = dict(base_state)

    for position, current_id in enumerate(event_ids, start=1):
        current = event_map[current_id]
        checked_against = {}

        # Step 1: the event's own auth events.
        for auth_id in current.auth_event_ids():
            auth_ev = await _get_event(room_id,
                                       auth_id,
                                       event_map,
                                       state_res_store,
                                       allow_none=True)
            if not auth_ev:
                logger.warning("auth_event id %s for event %s is missing",
                               auth_id, current_id)
            elif auth_ev.rejected_reason is None:
                checked_against[(auth_ev.type, auth_ev.state_key)] = auth_ev

        # Step 2: override with what the resolved state currently holds.
        for state_key in event_auth.auth_types_for_event(room_version, current):
            if state_key not in state:
                continue
            state_ev_id = state[state_key]
            state_ev = await _get_event(room_id, state_ev_id, event_map,
                                        state_res_store)
            if state_ev.rejected_reason is None:
                checked_against[state_key] = event_map[state_ev_id]

        try:
            event_auth.check_state_dependent_auth_rules(
                current,
                checked_against.values(),
            )
        except AuthError:
            pass
        else:
            state[(current.type, current.state_key)] = current_id

        # Periodically hand control back so that large data sets don't block
        # the reactor loop for too long.
        if position % _AWAIT_AFTER_ITERATIONS == 0:
            await clock.sleep(0)

    return state
Esempio n. 17
0
    def do_check(self, events, edges, expected_state_ids):
        """Build the event graph from `events` and `edges`, calculate the state
        of the graph at END, and assert that it matches `expected_state_ids`.

        `events` and `edges` are added on top of INITIAL_EVENTS and
        INITIAL_EDGES. The state before each node is one of:

        - empty, if the node has no prev events;
        - a copy of its single prev event's state;
        - the result of `resolve_events_with_store` over all of its prev
          events' states.

        Args:
            events (list[FakeEvent])
            edges (list[list[str]]): A list of chains of event edges, e.g.
                `[[A, B, C]]` are edges A->B and B->C. An edge A->B makes B
                one of A's prev events.
            expected_state_ids (list[str]): Node IDs of the events expected in
                the state at END. Keys whose value is unchanged since START
                are left out of the comparison unless they appear here.
        """
        # We want to sort the events into topological order for processing.
        # node_id -> node_ids of its prev events.
        graph = {}

        # node_id -> FakeEvent
        fake_event_map = {}

        for ev in itertools.chain(INITIAL_EVENTS, events):
            graph[ev.node_id] = set()
            fake_event_map[ev.node_id] = ev

        for a, b in pairwise(INITIAL_EDGES):
            graph[a].add(b)

        for edge_list in edges:
            for a, b in pairwise(edge_list):
                graph[a].add(b)

        # event_id -> FrozenEvent
        event_map = {}
        # node_id -> state
        state_at_event = {}

        # We copy the map as the sort consumes the graph
        graph_copy = {k: set(v) for k, v in graph.items()}

        # NOTE(review): the state_at_event lookups below assume every prev
        # event of a node is yielded before the node itself.
        for node_id in lexicographical_topological_sort(graph_copy,
                                                        key=lambda e: e):
            fake_event = fake_event_map[node_id]
            event_id = fake_event.event_id

            prev_events = list(graph[node_id])

            if len(prev_events) == 0:
                state_before = {}
            elif len(prev_events) == 1:
                state_before = dict(state_at_event[prev_events[0]])
            else:
                # Several prev events: run state resolution (v2 room
                # version) over their states.
                state_d = resolve_events_with_store(
                    FakeClock(),
                    ROOM_ID,
                    RoomVersions.V2.identifier,
                    [state_at_event[n] for n in prev_events],
                    event_map=event_map,
                    state_res_store=TestStateResolutionStore(event_map),
                )

                state_before = self.successResultOf(
                    defer.ensureDeferred(state_d))

            # A state event replaces its own (type, state_key) entry.
            state_after = dict(state_before)
            if fake_event.state_key is not None:
                state_after[(fake_event.type, fake_event.state_key)] = event_id

            # The new event's auth events are the state_before entries for
            # each key its auth rules depend on.
            auth_types = set(auth_types_for_event(fake_event))

            auth_events = []
            for key in auth_types:
                if key in state_before:
                    auth_events.append(state_before[key])

            event = fake_event.to_event(auth_events, prev_events)

            state_at_event[node_id] = state_after
            event_map[event_id] = event

        expected_state = {}
        for node_id in expected_state_ids:
            # expected_state_ids are node IDs rather than event IDs,
            # so we have to convert
            event_id = EventID(node_id, "example.com").to_string()
            event = event_map[event_id]

            key = (event.type, event.state_key)

            expected_state[key] = event_id

        # Compare only keys that are expected or that changed since START.
        start_state = state_at_event["START"]
        end_state = {
            key: value
            for key, value in state_at_event["END"].items()
            if key in expected_state or start_state.get(key) != value
        }

        self.assertEqual(expected_state, end_state)
Esempio n. 18
0
    def do_check(self, events, edges, expected_state_ids):
        """Build the event graph from `events` and `edges`, calculate the state
        of the graph at END, and assert that it matches `expected_state_ids`.

        `events` and `edges` are added on top of INITIAL_EVENTS and
        INITIAL_EDGES. The state before each node is one of:

        - empty, if the node has no prev events;
        - a copy of its single prev event's state;
        - the result of `resolve_events_with_store` over all of its prev
          events' states.

        Args:
            events (list[FakeEvent])
            edges (list[list[str]]): A list of chains of event edges, e.g.
                `[[A, B, C]]` are edges A->B and B->C. An edge A->B makes B
                one of A's prev events.
            expected_state_ids (list[str]): Node IDs of the events expected in
                the state at END. Keys whose value is unchanged since START
                are left out of the comparison unless they appear here.
        """
        # We want to sort the events into topological order for processing.
        # node_id -> node_ids of its prev events.
        graph = {}

        # node_id -> FakeEvent
        fake_event_map = {}

        for ev in itertools.chain(INITIAL_EVENTS, events):
            graph[ev.node_id] = set()
            fake_event_map[ev.node_id] = ev

        for a, b in pairwise(INITIAL_EDGES):
            graph[a].add(b)

        for edge_list in edges:
            for a, b in pairwise(edge_list):
                graph[a].add(b)

        # event_id -> FrozenEvent
        event_map = {}
        # node_id -> state
        state_at_event = {}

        # We copy the map as the sort consumes the graph
        graph_copy = {k: set(v) for k, v in graph.items()}

        # NOTE(review): the state_at_event lookups below assume every prev
        # event of a node is yielded before the node itself.
        for node_id in lexicographical_topological_sort(graph_copy, key=lambda e: e):
            fake_event = fake_event_map[node_id]
            event_id = fake_event.event_id

            prev_events = list(graph[node_id])

            if len(prev_events) == 0:
                state_before = {}
            elif len(prev_events) == 1:
                state_before = dict(state_at_event[prev_events[0]])
            else:
                # Several prev events: run state resolution over their states.
                # resolve_events_with_store's result is passed straight to
                # successResultOf here (no ensureDeferred wrapper).
                state_d = resolve_events_with_store(
                    [state_at_event[n] for n in prev_events],
                    event_map=event_map,
                    state_res_store=TestStateResolutionStore(event_map),
                )

                state_before = self.successResultOf(state_d)

            # A state event replaces its own (type, state_key) entry.
            state_after = dict(state_before)
            if fake_event.state_key is not None:
                state_after[(fake_event.type, fake_event.state_key)] = event_id

            # The new event's auth events are the state_before entries for
            # each key its auth rules depend on.
            auth_types = set(auth_types_for_event(fake_event))

            auth_events = []
            for key in auth_types:
                if key in state_before:
                    auth_events.append(state_before[key])

            event = fake_event.to_event(auth_events, prev_events)

            state_at_event[node_id] = state_after
            event_map[event_id] = event

        expected_state = {}
        for node_id in expected_state_ids:
            # expected_state_ids are node IDs rather than event IDs,
            # so we have to convert
            event_id = EventID(node_id, "example.com").to_string()
            event = event_map[event_id]

            key = (event.type, event.state_key)

            expected_state[key] = event_id

        # Compare only keys that are expected or that changed since START.
        start_state = state_at_event["START"]
        end_state = {
            key: value
            for key, value in state_at_event["END"].items()
            if key in expected_state or start_state.get(key) != value
        }

        self.assertEqual(expected_state, end_state)