def test_state_default_level(self):
    """
    Check that users above the state_default level can send state and those
    below cannot.

    ``state_default`` is set to 30: ``pleb`` (29) is one level short and
    ``king`` (30) meets it exactly.
    """
    creator = "@creator:example.com"
    pleb = "@joiner:example.com"
    king = "@joiner2:example.com"

    auth_events = {
        ("m.room.create", ""): _create_event(creator),
        ("m.room.member", creator): _join_event(creator),
        ("m.room.power_levels", ""): _power_levels_event(
            creator, {"state_default": "30", "users": {pleb: "29", king: "30"}}
        ),
        ("m.room.member", pleb): _join_event(pleb),
        ("m.room.member", king): _join_event(king),
    }

    # pleb should not be able to send state
    with self.assertRaises(AuthError):
        event_auth.check(
            RoomVersions.V1,
            _random_state_event(pleb),
            auth_events,
            do_sig_check=False,
        )

    # king should be able to send state
    event_auth.check(
        RoomVersions.V1,
        _random_state_event(king),
        auth_events,
        do_sig_check=False,
    )
def _iterative_auth_checks(room_id, room_version, event_ids, base_state, event_map, state_res_store):
    """Run the auth rules over ``event_ids`` in order, folding every event that
    passes into a running copy of the state.

    Args:
        room_id (str)
        room_version (str): looked up in KNOWN_ROOM_VERSIONS before checking.
        event_ids (list[str]): events to check, in the order to apply them
        base_state (StateMap[str]): starting state; copied, not mutated
        event_map (dict[str,FrozenEvent])
        state_res_store (StateResolutionStore)

    Returns:
        Deferred[StateMap[str]]: the state after applying every event that
        passed auth.
    """
    state = base_state.copy()
    version = KNOWN_ROOM_VERSIONS[room_version]

    for event_id in event_ids:
        event = event_map[event_id]
        auth_map = {}

        # Seed with the event's own non-rejected auth events, logging (and
        # skipping) any that cannot be found.
        for auth_id in event.auth_event_ids():
            auth_ev = yield _get_event(room_id, auth_id, event_map, state_res_store, allow_none=True)
            if not auth_ev:
                logger.warning("auth_event id %s for event %s is missing", auth_id, event_id)
                continue
            if auth_ev.rejected_reason is None:
                auth_map[(auth_ev.type, auth_ev.state_key)] = auth_ev

        # Entries from the state resolved so far take precedence.
        for auth_key in event_auth.auth_types_for_event(event):
            if auth_key not in state:
                continue
            current_id = state[auth_key]
            current = yield _get_event(room_id, current_id, event_map, state_res_store)
            if current.rejected_reason is None:
                auth_map[auth_key] = event_map[current_id]

        try:
            event_auth.check(version, event, auth_map, do_sig_check=False, do_size_check=False)
        except AuthError:
            continue
        state[(event.type, event.state_key)] = event_id

    return state
def _resolve_auth_events(
    room_version: RoomVersion, events: List[EventBase], auth_events: StateMap[EventBase]
) -> EventBase:
    """Pick the winner among conflicting auth events.

    The events are walked in reverse ``_ordered_events`` order. Starting from
    the first, each later event is auth-checked with the previous winner layered
    over the relevant subset of ``auth_events``. The walk stops at the first
    failure, and the last event that passed is returned.

    Args:
        room_version: used to decide which auth keys are relevant.
        events: the conflicting events; must be non-empty.
        auth_events: current auth state keyed by (type, state_key). Not mutated.

    Returns:
        The last event in the walk that passed auth.

    Raises:
        ValueError: if ``events`` is empty.
    """
    if not events:
        raise ValueError("_resolve_auth_events requires at least one event")

    reverse = list(reversed(_ordered_events(events)))

    auth_keys = {
        key
        for event in events
        for key in event_auth.auth_types_for_event(room_version, event)
    }

    # Keep only the relevant keys, in a fresh dict so the caller's mapping is
    # not modified by the updates below.
    new_auth_events = {}
    for key in auth_keys:
        auth_event = auth_events.get(key, None)
        if auth_event:
            new_auth_events[key] = auth_event

    auth_events = new_auth_events

    prev_event = reverse[0]
    for event in reverse[1:]:
        auth_events[(prev_event.type, prev_event.state_key)] = prev_event
        try:
            # The signatures have already been checked at this point.
            # NOTE(review): the rules are pinned to RoomVersions.V1 rather than
            # ``room_version``, presumably because this resolver only serves
            # v1-state-res rooms. Confirm before changing.
            event_auth.check(
                RoomVersions.V1,
                event,
                auth_events,
                do_sig_check=False,
                do_size_check=False,
            )
            prev_event = event
        except AuthError:
            return prev_event

    # If every event passed, ``prev_event`` is the last one. Returning it
    # (rather than the loop variable) also covers a single event, where the
    # loop body never runs and ``event`` would be unbound.
    return prev_event
def _iterative_auth_checks(event_ids, base_state, event_map):
    """Apply the auth rules to each event in ``event_ids`` in order, recording
    every event that passes in a copy of ``base_state``.

    Args:
        event_ids: ordered event IDs to check.
        base_state: (type, state_key) -> event_id mapping to start from.
        event_map: event_id -> event mapping.

    Returns:
        The updated (type, state_key) -> event_id mapping.
    """
    state = base_state.copy()

    for event_id in event_ids:
        event = event_map[event_id]

        # Start from the event's own auth events...
        auth_map = {}
        for auth_id, _ in event.auth_events:
            auth_ev = event_map[auth_id]
            auth_map[(auth_ev.type, auth_ev.state_key)] = auth_ev

        # ...then override with whatever has been resolved so far.
        auth_map.update(
            (key, event_map[state[key]])
            for key in event_auth.auth_types_for_event(event)
            if key in state
        )

        try:
            event_auth.check(event, auth_map, do_sig_check=False, do_size_check=False)
        except AuthError:
            continue
        state[(event.type, event.state_key)] = event_id

    return state
def test_random_users_cannot_send_state_before_first_pl(self):
    """
    Until a power levels event exists, only the room creator may send a state
    event; any other joined member is refused.
    """
    room_creator = "@creator:example.com"
    other_member = "@joiner:example.com"

    auth_events = {
        ("m.room.create", ""): _create_event(room_creator),
        ("m.room.member", room_creator): _join_event(room_creator),
        ("m.room.member", other_member): _join_event(other_member),
    }

    def check_state_from(sender):
        event_auth.check(
            RoomVersions.V1,
            _random_state_event(sender),
            auth_events,
            do_sig_check=False,
        )

    # The creator is allowed...
    check_state_from(room_creator)

    # ...an ordinary joiner is not.
    with self.assertRaises(AuthError):
        check_state_from(other_member)
def test_random_users_cannot_send_state_before_first_pl(self):
    """
    Check that, before the first PL lands, the creator is the only user
    that can send a state event.
    """
    creator = "@creator:example.com"
    joiner = "@joiner:example.com"
    auth_events = {
        ("m.room.create", ""): _create_event(creator),
        ("m.room.member", creator): _join_event(creator),
        ("m.room.member", joiner): _join_event(joiner),
    }

    # creator should be able to send state
    event_auth.check(
        RoomVersions.V1.identifier,
        _random_state_event(creator),
        auth_events,
        do_sig_check=False,
    )

    # joiner should not be able to send state (the context-manager form
    # avoids the stray trailing comma that made this a tuple expression)
    with self.assertRaises(AuthError):
        event_auth.check(
            RoomVersions.V1.identifier,
            _random_state_event(joiner),
            auth_events,
            do_sig_check=False,
        )
def _resolve_auth_events(events, auth_events):
    """Pick the winner among conflicting auth events.

    The events are walked in reverse ``_ordered_events`` order. Each later event
    is auth-checked with the previous winner layered over the relevant subset of
    ``auth_events``. The walk stops at the first failure.

    Args:
        events: the conflicting events; must be non-empty.
        auth_events: auth state keyed by (type, state_key). Not mutated.

    Returns:
        The last event in the walk that passed auth.

    Raises:
        ValueError: if ``events`` is empty.
    """
    if not events:
        raise ValueError("_resolve_auth_events requires at least one event")

    reverse = list(reversed(_ordered_events(events)))

    auth_keys = {
        key for event in events for key in event_auth.auth_types_for_event(event)
    }

    # Keep only the relevant keys, in a fresh dict so the caller's mapping is
    # not modified below.
    new_auth_events = {}
    for key in auth_keys:
        auth_event = auth_events.get(key, None)
        if auth_event:
            new_auth_events[key] = auth_event

    auth_events = new_auth_events

    prev_event = reverse[0]
    for event in reverse[1:]:
        auth_events[(prev_event.type, prev_event.state_key)] = prev_event
        try:
            # The signatures have already been checked at this point
            event_auth.check(
                RoomVersions.V1.identifier,
                event,
                auth_events,
                do_sig_check=False,
                do_size_check=False,
            )
            prev_event = event
        except AuthError:
            return prev_event

    # ``prev_event`` equals the last event when all passed, and is also
    # correct for a single event, where ``event`` would be unbound.
    return prev_event
def _resolve_auth_events(events, auth_events):
    """Pick the winner among conflicting auth events.

    The events are walked in reverse ``_ordered_events`` order. Each later event
    is auth-checked with the previous winner layered over the relevant subset of
    ``auth_events``. The walk stops at the first failure.

    Args:
        events: the conflicting events; must be non-empty.
        auth_events: auth state keyed by (type, state_key). Not mutated.

    Returns:
        The last event in the walk that passed auth.

    Raises:
        ValueError: if ``events`` is empty.
    """
    if not events:
        raise ValueError("_resolve_auth_events requires at least one event")

    reverse = list(reversed(_ordered_events(events)))

    auth_keys = {
        key for event in events for key in event_auth.auth_types_for_event(event)
    }

    # Keep only the relevant keys, in a fresh dict so the caller's mapping is
    # not modified below.
    new_auth_events = {}
    for key in auth_keys:
        auth_event = auth_events.get(key, None)
        if auth_event:
            new_auth_events[key] = auth_event

    auth_events = new_auth_events

    prev_event = reverse[0]
    for event in reverse[1:]:
        auth_events[(prev_event.type, prev_event.state_key)] = prev_event
        try:
            # The signatures have already been checked at this point
            event_auth.check(event, auth_events, do_sig_check=False, do_size_check=False)
            prev_event = event
        except AuthError:
            return prev_event

    # ``prev_event`` equals the last event when all passed, and is also
    # correct for a single event, where ``event`` would be unbound.
    return prev_event
def check_from_context(self, room_version, event, context, do_sig_check=True):
    """Auth-check ``event`` against the state preceding it in ``context``.

    The prior state IDs come from ``context``. ``compute_auth_events`` (with
    ``for_verification=True``) selects the relevant ones, which are loaded from
    the store and passed to ``event_auth.check``. An AuthError raised by the
    check is not caught here.
    """
    prev_state_ids = yield context.get_prev_state_ids(self.store)
    auth_ids = yield self.compute_auth_events(event, prev_state_ids, for_verification=True)
    fetched = yield self.store.get_events(auth_ids)

    keyed_auth_events = {}
    for auth_ev in itervalues(fetched):
        keyed_auth_events[(auth_ev.type, auth_ev.state_key)] = auth_ev

    event_auth.check(
        room_version, event, auth_events=keyed_auth_events, do_sig_check=do_sig_check
    )
def _resolve_normal_events(events, auth_events):
    """Pick the winner among conflicting non-auth state events.

    Args:
        events: the conflicting events; must be non-empty.
        auth_events: auth state keyed by (type, state_key).

    Returns:
        The first event in ``_ordered_events`` order that passes auth, or the
        last event in that order if none do.

    Raises:
        ValueError: if ``events`` is empty (the final return would otherwise
            reference an unbound loop variable).
    """
    if not events:
        raise ValueError("_resolve_normal_events requires at least one event")

    for event in _ordered_events(events):
        try:
            # The signatures have already been checked at this point
            event_auth.check(event, auth_events, do_sig_check=False, do_size_check=False)
            return event
        except AuthError:
            pass

    # Use the last event (the one with the least depth) if they all fail
    # the auth check.
    return event
def _resolve_normal_events(events, auth_events):
    """Pick the winner among conflicting non-auth state events.

    Args:
        events: the conflicting events; must be non-empty.
        auth_events: auth state keyed by (type, state_key).

    Returns:
        The first event in ``_ordered_events`` order that passes auth, or the
        last event in that order if none do.

    Raises:
        ValueError: if ``events`` is empty (the final return would otherwise
            reference an unbound loop variable).
    """
    if not events:
        raise ValueError("_resolve_normal_events requires at least one event")

    for event in _ordered_events(events):
        try:
            # The signatures have already been checked at this point
            event_auth.check(event, auth_events, do_sig_check=False, do_size_check=False)
            return event
        except AuthError:
            pass

    # Use the last event (the one with the least depth) if they all fail
    # the auth check.
    return event
def check(self, event, auth_events, do_sig_check=True):
    """Check that this event is correctly authed.

    Delegates to ``event_auth.check`` and times the call under the
    ``auth.check`` Measure block.

    Args:
        event: the event being checked.
        auth_events (dict: event-key -> event): the existing room state.
        do_sig_check (bool): forwarded to ``event_auth.check``.

    Returns:
        None. Success is signalled by returning normally; the result of
        ``event_auth.check`` is not passed back.

    Raises:
        AuthError: propagated from ``event_auth.check`` when the event fails
            the auth rules.
    """
    with Measure(self.clock, "auth.check"):
        event_auth.check(event, auth_events, do_sig_check=do_sig_check)
def check(self, event, auth_events, do_sig_check=True):
    """Check that this event is correctly authed.

    Delegates to ``event_auth.check`` and times the call under the
    ``auth.check`` Measure block.

    Args:
        event: the event being checked.
        auth_events (dict: event-key -> event): the existing room state.
        do_sig_check (bool): forwarded to ``event_auth.check``.

    Returns:
        None. Success is signalled by returning normally; the result of
        ``event_auth.check`` is not passed back.

    Raises:
        AuthError: propagated from ``event_auth.check`` when the event fails
            the auth rules.
    """
    with Measure(self.clock, "auth.check"):
        event_auth.check(event, auth_events, do_sig_check=do_sig_check)
async def check_from_context(
    self, room_version: str, event, context, do_sig_check=True
):
    """Auth-check ``event`` against the prior state recorded in ``context``.

    ``compute_auth_events`` (called synchronously, with
    ``for_verification=True``) selects the relevant auth event IDs. Those
    events are loaded from the store, keyed by (type, state_key), and passed to
    ``event_auth.check`` together with the RoomVersion for ``room_version``.
    """
    prev_state_ids = await context.get_prev_state_ids()
    wanted_ids = self.compute_auth_events(
        event, prev_state_ids, for_verification=True
    )
    events_by_id = await self.store.get_events(wanted_ids)

    keyed_auth_events = {}
    for auth_ev in events_by_id.values():
        keyed_auth_events[(auth_ev.type, auth_ev.state_key)] = auth_ev

    event_auth.check(
        KNOWN_ROOM_VERSIONS[room_version],
        event,
        auth_events=keyed_auth_events,
        do_sig_check=do_sig_check,
    )
def _iterative_auth_checks(room_version, event_ids, base_state, event_map, state_res_store):
    """Run the auth rules over ``event_ids`` in order, recording each event
    that passes in a running copy of the state.

    Args:
        room_version (str): passed straight through to ``event_auth.check``.
        event_ids (list[str]): events to check, in the order to apply them
        base_state (dict[tuple[str, str], str]): starting state; copied
        event_map (dict[str,FrozenEvent])
        state_res_store (StateResolutionStore)

    Returns:
        Deferred[dict[tuple[str, str], str]]: the final updated state
    """
    state = base_state.copy()

    for event_id in event_ids:
        event = event_map[event_id]
        auth_map = {}

        # The event's own non-rejected auth events...
        for auth_id in event.auth_event_ids():
            auth_ev = yield _get_event(auth_id, event_map, state_res_store)
            if auth_ev.rejected_reason is None:
                auth_map[(auth_ev.type, auth_ev.state_key)] = auth_ev

        # ...overridden by non-rejected entries from the state so far.
        for auth_key in event_auth.auth_types_for_event(event):
            if auth_key not in state:
                continue
            current_id = state[auth_key]
            current = yield _get_event(current_id, event_map, state_res_store)
            if current.rejected_reason is None:
                auth_map[auth_key] = event_map[current_id]

        try:
            event_auth.check(
                room_version, event, auth_map, do_sig_check=False, do_size_check=False,
            )
        except AuthError:
            continue
        state[(event.type, event.state_key)] = event_id

    defer.returnValue(state)
async def check_from_context(self, room_version: str, event, context, do_sig_check=True) -> None:
    """Auth-check ``event`` against the events named in its own auth_event_ids.

    ``context`` is accepted for interface compatibility but is not read here.
    """
    fetched = await self._store.get_events(event.auth_event_ids())
    keyed_auth_events = {(ev.type, ev.state_key): ev for ev in fetched.values()}
    version = KNOWN_ROOM_VERSIONS[room_version]
    event_auth.check(
        version, event, auth_events=keyed_auth_events, do_sig_check=do_sig_check
    )
def _iterative_auth_checks(room_version, event_ids, base_state, event_map, state_res_store):
    """Run the auth rules over ``event_ids`` in order, recording each event
    that passes in a running copy of the state.

    Args:
        room_version (str): passed straight through to ``event_auth.check``.
        event_ids (list[str]): events to check, in the order to apply them
        base_state (dict[tuple[str, str], str]): starting state; copied
        event_map (dict[str,FrozenEvent])
        state_res_store (StateResolutionStore)

    Returns:
        Deferred[dict[tuple[str, str], str]]: the final updated state
    """
    state = base_state.copy()

    for event_id in event_ids:
        event = event_map[event_id]
        auth_map = {}

        # The event's own non-rejected auth events...
        for auth_id in event.auth_event_ids():
            auth_ev = yield _get_event(auth_id, event_map, state_res_store)
            if auth_ev.rejected_reason is None:
                auth_map[(auth_ev.type, auth_ev.state_key)] = auth_ev

        # ...overridden by non-rejected entries from the state so far.
        for auth_key in event_auth.auth_types_for_event(event):
            if auth_key not in state:
                continue
            current_id = state[auth_key]
            current = yield _get_event(current_id, event_map, state_res_store)
            if current.rejected_reason is None:
                auth_map[auth_key] = event_map[current_id]

        try:
            event_auth.check(
                room_version, event, auth_map, do_sig_check=False, do_size_check=False
            )
        except AuthError:
            continue
        state[(event.type, event.state_key)] = event_id

    defer.returnValue(state)
def test_msc2209(self):
    """
    MSC2209 brings the notifications power level under the auth checks.
    """
    creator = "@creator:example.com"
    sender = "@joiner:example.com"

    auth_events = {
        ("m.room.create", ""): _create_event(creator),
        ("m.room.member", creator): _join_event(creator),
        ("m.room.power_levels", ""): _power_levels_event(
            creator, {"state_default": "30", "users": {sender: "30"}}
        ),
        ("m.room.member", sender): _join_event(sender),
    }

    def change_notifications_pl(room_version):
        event_auth.check(
            room_version,
            _power_levels_event(sender, {"notifications": {"room": 100}}),
            auth_events,
            do_sig_check=False,
        )

    # In a v1 room the sender may raise the notifications power level...
    change_notifications_pl(RoomVersions.V1)

    # ...but an MSC2209 room (v6) rejects the same change.
    with self.assertRaises(AuthError):
        change_notifications_pl(RoomVersions.V6)
def _resolve_normal_events(
    events: List[EventBase], auth_events: StateMap[EventBase]
) -> EventBase:
    """Pick the winner among conflicting non-auth state events.

    Args:
        events: the conflicting events; must be non-empty.
        auth_events: auth state keyed by (type, state_key).

    Returns:
        The first event in ``_ordered_events`` order that passes the
        RoomVersions.V1 auth rules, or the last event in that order if none do.

    Raises:
        ValueError: if ``events`` is empty (the final return would otherwise
            reference an unbound loop variable).
    """
    if not events:
        raise ValueError("_resolve_normal_events requires at least one event")

    for event in _ordered_events(events):
        try:
            # The signatures have already been checked at this point
            event_auth.check(
                RoomVersions.V1,
                event,
                auth_events,
                do_sig_check=False,
                do_size_check=False,
            )
            return event
        except AuthError:
            pass

    # Use the last event (the one with the least depth) if they all fail
    # the auth check.
    return event
def test_state_default_level(self):
    """
    Check that users above the state_default level can send state and those
    below cannot.

    ``state_default`` is set to 30: ``pleb`` (29) is one level short and
    ``king`` (30) meets it exactly.
    """
    creator = "@creator:example.com"
    pleb = "@joiner:example.com"
    king = "@joiner2:example.com"

    auth_events = {
        ("m.room.create", ""): _create_event(creator),
        ("m.room.member", creator): _join_event(creator),
        ("m.room.power_levels", ""): _power_levels_event(
            creator, {"state_default": "30", "users": {pleb: "29", king: "30"}}
        ),
        ("m.room.member", pleb): _join_event(pleb),
        ("m.room.member", king): _join_event(king),
    }

    # pleb should not be able to send state
    with self.assertRaises(AuthError):
        event_auth.check(
            RoomVersions.V1.identifier,
            _random_state_event(pleb),
            auth_events,
            do_sig_check=False,
        )

    # king should be able to send state
    event_auth.check(
        RoomVersions.V1.identifier,
        _random_state_event(king),
        auth_events,
        do_sig_check=False,
    )
def test_alias_event(self):
    """Up to and including room version 6, alias events get special-case auth rules."""
    creator = "@creator:example.com"
    other = "@other:example.com"
    auth_events = {
        ("m.room.create", ""): _create_event(creator),
        ("m.room.member", creator): _join_event(creator),
    }

    def check_alias(alias_event):
        event_auth.check(
            RoomVersions.V1, alias_event, auth_events, do_sig_check=False
        )

    # The creator may send aliases.
    check_alias(_alias_event(creator))

    # An empty state key is refused, and so is a state key that does not
    # match the sender's domain.
    for bad_state_key in ("", "test.com"):
        with self.assertRaises(AuthError):
            check_alias(_alias_event(creator, state_key=bad_state_key))

    # The sender does *not* need to be a member of the room.
    check_alias(_alias_event(other))
def test_msc2432_alias_event(self):
    """With MSC2432 (room version 6), alias events follow the ordinary auth rules."""
    creator = "@creator:example.com"
    other = "@other:example.com"
    auth_events = {
        ("m.room.create", ""): _create_event(creator),
        ("m.room.member", creator): _join_event(creator),
    }

    def check_v6(alias_event):
        event_auth.check(
            RoomVersions.V6, alias_event, auth_events, do_sig_check=False
        )

    # The creator may send aliases with any state key; nothing special is
    # checked about it.
    for extra_kwargs in ({}, {"state_key": ""}, {"state_key": "test.com"}):
        check_v6(_alias_event(creator, **extra_kwargs))

    # Under the standard rules the sender must be in the room.
    with self.assertRaises(AuthError):
        check_v6(_alias_event(other))
def test_join_rules_msc3083_restricted(self):
    """
    Test joining a restricted room from MSC3083.

    This is similar to the public test, but has some additional checks on
    signatures.

    The checks which care about signatures fake them by simply adding an
    object of the proper form, not generating valid signatures.
    """
    creator = "@creator:example.com"
    pleb = "@joiner:example.com"

    auth_events = {
        ("m.room.create", ""): _create_event(creator),
        ("m.room.member", creator): _join_event(creator),
        ("m.room.power_levels", ""): _power_levels_event(creator, {"invite": 0}),
        ("m.room.join_rules", ""): _join_rules_event(creator, "restricted"),
    }

    # Older room versions don't understand this join rule
    with self.assertRaises(AuthError):
        event_auth.check(
            RoomVersions.V6,
            _join_event(pleb),
            auth_events,
            do_sig_check=False,
        )

    # A properly formatted join event should work.
    authorised_join_event = _join_event(
        pleb,
        additional_content={
            "join_authorised_via_users_server": "@creator:example.com"
        },
    )
    event_auth.check(
        RoomVersions.V8,
        authorised_join_event,
        auth_events,
        do_sig_check=False,
    )

    # A join issued by a specific user works (i.e. the power level checks
    # are done properly).
    pl_auth_events = auth_events.copy()
    pl_auth_events[("m.room.power_levels", "")] = _power_levels_event(
        creator, {"invite": 100, "users": {"@inviter:foo.test": 150}}
    )
    pl_auth_events[("m.room.member", "@inviter:foo.test")] = _join_event(
        "@inviter:foo.test"
    )
    event_auth.check(
        RoomVersions.V8,
        _join_event(
            pleb,
            additional_content={
                "join_authorised_via_users_server": "@inviter:foo.test"
            },
        ),
        pl_auth_events,
        do_sig_check=False,
    )

    # A join which is missing an authorised server is rejected.
    with self.assertRaises(AuthError):
        event_auth.check(
            RoomVersions.V8,
            _join_event(pleb),
            auth_events,
            do_sig_check=False,
        )

    # A join authorised by a user who is not in the room is rejected. The
    # authorising user has enough power to invite, so their missing membership
    # is the only thing wrong. These power levels must actually be used
    # (previously the check ran against ``auth_events`` and
    # ``pl_auth_events`` was ignored).
    pl_auth_events = auth_events.copy()
    pl_auth_events[("m.room.power_levels", "")] = _power_levels_event(
        creator, {"invite": 100, "users": {"@other:example.com": 150}}
    )
    with self.assertRaises(AuthError):
        event_auth.check(
            RoomVersions.V8,
            _join_event(
                pleb,
                additional_content={
                    "join_authorised_via_users_server": "@other:example.com"
                },
            ),
            pl_auth_events,
            do_sig_check=False,
        )

    # A user cannot be force-joined to a room. (This uses an event which
    # *would* be valid, but is sent by a different user.)
    with self.assertRaises(AuthError):
        event_auth.check(
            RoomVersions.V8,
            _member_event(
                pleb,
                "join",
                sender=creator,
                additional_content={
                    "join_authorised_via_users_server": "@inviter:foo.test"
                },
            ),
            auth_events,
            do_sig_check=False,
        )

    # Banned should be rejected.
    auth_events[("m.room.member", pleb)] = _member_event(pleb, "ban")
    with self.assertRaises(AuthError):
        event_auth.check(
            RoomVersions.V8,
            authorised_join_event,
            auth_events,
            do_sig_check=False,
        )

    # A user who left can re-join.
    auth_events[("m.room.member", pleb)] = _member_event(pleb, "leave")
    event_auth.check(
        RoomVersions.V8,
        authorised_join_event,
        auth_events,
        do_sig_check=False,
    )

    # A user can send a join if they're in the room. (This doesn't need to
    # be authorised since the user is already joined.)
    auth_events[("m.room.member", pleb)] = _member_event(pleb, "join")
    event_auth.check(
        RoomVersions.V8,
        _join_event(pleb),
        auth_events,
        do_sig_check=False,
    )

    # A user can accept an invite. (This doesn't need to be authorised since
    # the user was invited.)
    auth_events[("m.room.member", pleb)] = _member_event(
        pleb, "invite", sender=creator
    )
    event_auth.check(
        RoomVersions.V8,
        _join_event(pleb),
        auth_events,
        do_sig_check=False,
    )
def resolver(state_sets, event_map):
    """Given a set of state return the resolved state.

    Works in two passes. First, power events (plus any auth dependencies
    pulled in with them) are ordered by sender power level and auth-checked in
    turn. Then the remaining events are ordered by their depth along the
    "mainline" of resolved power events and checked the same way. In each
    pass, the latest event that passed auth wins for each conflicted key.

    Args:
        state_sets(list[dict[tuple[str, str], str]]): A list of dicts from
            type/state_key tuples to event_id
        event_map(dict[str, FrozenEvent]): Map from event_id to event

    Returns:
        dict[tuple[str, str], str]: The resolved state map.
    """
    # First split up the un/conflicted state
    unconflicted_state, conflicted_state = _seperate(state_sets)

    # Also fetch all auth events that appear in only some of the state sets'
    # auth chains.
    auth_diff = _get_auth_chain_difference(state_sets, event_map)

    # Now order the conflicted state and auth_diff by power level (falling
    # back to event_id to tie break consistently).
    event_id_to_level = [
        (_get_power_level_for_sender(event_id, event_map), event_id)
        for event_id in set(
            itertools.chain(
                itertools.chain.from_iterable(conflicted_state.values()),
                auth_diff,
            ))
    ]
    event_id_to_level.sort()

    # Ascending by (power level, event_id); pop() below takes from the end,
    # i.e. the highest-power sender first.
    events_sorted_by_power = [eid for _, eid in event_id_to_level]

    # Now we reorder the list to ensure that auth dependencies of an event
    # appear before the event in the list
    sorted_events = []

    def add_to_list(event_id):
        # Recursively emit any auth dependency still waiting in
        # events_sorted_by_power (removing it there), then this event.
        # NOTE(review): the membership test and remove() on a list are O(n)
        # each.
        event = event_map[event_id]
        for aid, _ in event.auth_events:
            if aid in events_sorted_by_power:
                events_sorted_by_power.remove(aid)
                add_to_list(aid)

        sorted_events.append(event_id)

    # First, let's pick out all the events that (probably) require power.
    # Non-power events are deferred to the mainline pass below, unless they
    # were pulled in as an auth dependency of a power event.
    leftover_events = []
    while events_sorted_by_power:
        event_id = events_sorted_by_power.pop()
        if _is_power_event(event_map[event_id]):
            add_to_list(event_id)
        else:
            leftover_events.append(event_id)

    # Now we go through the sorted events and auth each one in turn, using any
    # previously successfully auth'ed events (falling back to their auth events
    # if they don't exist)
    overridden_state = {}
    event_id_to_auth = {}
    for event_id in sorted_events:
        event = event_map[event_id]
        auth_events = {}
        for aid, _ in event.auth_events:
            aev = event_map[aid]
            auth_events[(aev.type, aev.state_key)] = aev
        for key, eid in overridden_state.items():
            auth_events[key] = event_map[eid]

        try:
            event_auth.check(event, auth_events, do_sig_check=False, do_size_check=False)
            allowed = True
            overridden_state[(event.type, event.state_key)] = event_id
        except AuthError:
            allowed = False

        event_id_to_auth[event_id] = allowed

    resolved_state = {}

    # Now for each conflicted state type/state_key, pick the latest event (in
    # sorted_events order) that has passed auth above.
    # NOTE(review): despite the earlier wording, there is no fallback. If no
    # candidate passed, the key is not set here (it may still be set by the
    # leftover pass below).
    for key, conflicted_ids in conflicted_state.items():
        sorted_conflicts = []
        for eid in sorted_events:
            if eid in conflicted_ids:
                sorted_conflicts.append(eid)

        sorted_conflicts.reverse()

        for eid in sorted_conflicts:
            if event_id_to_auth[eid]:
                resolved_eid = eid
                resolved_state[key] = resolved_eid
                break

    resolved_state.update(unconflicted_state)

    # OK, so we've now resolved the power events. Now mainline them.
    sorted_power_resolved = sorted(resolved_state.values())

    mainline = []

    def add_to_list_two(event_id):
        # Emit auth ancestors first. Auth events never checked above count as
        # allowed (event_id_to_auth.get(..., True)); ones that failed are skipped.
        ev = event_map[event_id]
        for aid, _ in ev.auth_events:
            if aid not in mainline and event_id_to_auth.get(aid, True):
                add_to_list_two(aid)

        if event_id not in mainline:
            mainline.append(event_id)

    while sorted_power_resolved:
        ev_id = sorted_power_resolved.pop()
        ev = event_map[ev_id]
        if _is_power_event(ev):
            add_to_list_two(ev_id)

    # 1-based position of each event on the mainline.
    mainline_map = {ev_id: i + 1 for i, ev_id in enumerate(mainline)}

    def get_mainline_depth(event_id):
        # Mainline position of the event itself, else the deepest mainline
        # position reachable through its auth events; 0 if it has none.
        # NOTE(review): not memoised, so shared ancestors are recomputed.
        if event_id in mainline_map:
            return mainline_map[event_id]

        ev = event_map[event_id]
        if not ev.auth_events:
            return 0

        depth = max(
            get_mainline_depth(aid)
            for aid, _ in ev.auth_events
        )

        return depth

    leftover_events_map = {
        ev_id: get_mainline_depth(ev_id)
        for ev_id in leftover_events
    }

    leftover_events.sort(key=lambda ev_id: (leftover_events_map[ev_id], ev_id))

    # Second pass: auth the leftover events in mainline order, continuing to
    # build on overridden_state from the power pass.
    for event_id in leftover_events:
        event = event_map[event_id]
        auth_events = {}
        for aid, _ in event.auth_events:
            aev = event_map[aid]
            auth_events[(aev.type, aev.state_key)] = aev
        for key, eid in overridden_state.items():
            auth_events[key] = event_map[eid]

        try:
            event_auth.check(event, auth_events, do_sig_check=False, do_size_check=False)
            allowed = True
            overridden_state[(event.type, event.state_key)] = event_id
        except AuthError:
            allowed = False

        event_id_to_auth[event_id] = allowed

    # As above, but over the leftover events: the latest passing event wins.
    for key, conflicted_ids in conflicted_state.items():
        sorted_conflicts = []
        for eid in leftover_events:
            if eid in conflicted_ids:
                sorted_conflicts.append(eid)

        sorted_conflicts.reverse()

        for eid in sorted_conflicts:
            if event_id_to_auth[eid]:
                resolved_eid = eid
                resolved_state[key] = resolved_eid
                break

    resolved_state.update(unconflicted_state)

    return resolved_state
def resolver(state_sets, event_map):
    """Given a set of state return the resolved state.

    All conflicted and auth-difference events are ordered by sender power
    level (auth dependencies first) and auth-checked in turn. For each
    conflicted key, the latest event that passed wins.

    Args:
        state_sets(list[dict[tuple[str, str], str]]): A list of dicts from
            type/state_key tuples to event_id
        event_map(dict[str, FrozenEvent]): Map from event_id to event

    Returns:
        dict[tuple[str, str], str]: The resolved state map.
    """
    # First split up the un/conflicted state
    unconflicted_state, conflicted_state = _seperate(state_sets)

    # Also fetch all auth events that appear in only some of the state sets'
    # auth chains.
    auth_diff = _get_auth_chain_difference(state_sets, event_map)

    # Now order the conflicted state and auth_diff by power level (falling
    # back to event_id to tie break consistently).
    event_id_to_level = [
        (_get_power_level_for_sender(event_id, event_map), event_id)
        for event_id in set(
            itertools.chain(
                itertools.chain.from_iterable(conflicted_state.values()),
                auth_diff,
            ))
    ]
    event_id_to_level.sort()

    # Ascending by (power level, event_id); pop() below takes from the end,
    # i.e. the highest-power sender first.
    events_sorted_by_power = [eid for _, eid in event_id_to_level]

    # Now we reorder the list to ensure that auth dependencies of an event
    # appear before the event in the list
    sorted_events = []

    def add_to_list(event_id):
        # Recursively emit any auth dependency still waiting in
        # events_sorted_by_power (removing it there), then this event.
        # NOTE(review): the membership test and remove() on a list are O(n)
        # each.
        event = event_map[event_id]
        for aid, _ in event.auth_events:
            if aid in events_sorted_by_power:
                events_sorted_by_power.remove(aid)
                add_to_list(aid)

        sorted_events.append(event_id)

    while events_sorted_by_power:
        ev = events_sorted_by_power.pop()
        add_to_list(ev)

    # Now we go through the sorted events and auth each one in turn, using any
    # previously successfully auth'ed events (falling back to their auth events
    # if they don't exist)
    overridden_state = {}
    event_id_to_auth = {}
    for event_id in sorted_events:
        event = event_map[event_id]
        auth_events = {}
        for aid, _ in event.auth_events:
            aev = event_map[aid]
            auth_events[(aev.type, aev.state_key)] = aev
        for key, eid in overridden_state.items():
            auth_events[key] = event_map[eid]

        try:
            event_auth.check(event, auth_events, do_sig_check=False, do_size_check=False)
            allowed = True
            overridden_state[(event.type, event.state_key)] = event_id
        except AuthError:
            allowed = False

        event_id_to_auth[event_id] = allowed

    # NOTE(review): this aliases (and the loop below mutates) the dict
    # returned by _seperate rather than copying it. Presumably harmless if
    # _seperate builds a fresh dict; confirm.
    resolved_state = unconflicted_state

    # Now for each conflicted state type/state_key, pick the latest event that
    # has passed auth above.
    # NOTE(review): there is no fallback. If no candidate passed auth, the key
    # is simply absent from the result.
    for key, conflicted_ids in conflicted_state.items():
        sorted_conflicts = []
        for eid in sorted_events:
            if eid in conflicted_ids:
                sorted_conflicts.append(eid)

        sorted_conflicts.reverse()

        for eid in sorted_conflicts:
            if event_id_to_auth[eid]:
                resolved_eid = eid
                resolved_state[key] = resolved_eid
                break

    return resolved_state
def _iterative_auth_checks(clock, room_id, room_version, event_ids, base_state, event_map, state_res_store):
    """Run the auth rules over ``event_ids`` in order, recording each event
    that passes in a running copy of the state.

    Args:
        clock (Clock): used to hand control back to the reactor periodically
        room_id (str)
        room_version (str): looked up in KNOWN_ROOM_VERSIONS before checking
        event_ids (list[str]): events to check, in the order to apply them
        base_state (StateMap[str]): starting state; copied, not mutated
        event_map (dict[str,FrozenEvent])
        state_res_store (StateResolutionStore)

    Returns:
        Deferred[StateMap[str]]: the final updated state
    """
    state = base_state.copy()
    version = KNOWN_ROOM_VERSIONS[room_version]

    for count, event_id in enumerate(event_ids, start=1):
        event = event_map[event_id]
        auth_map = {}

        # The event's own non-rejected auth events; missing ones are logged.
        for auth_id in event.auth_event_ids():
            auth_ev = yield _get_event(room_id, auth_id, event_map, state_res_store, allow_none=True)
            if not auth_ev:
                logger.warning("auth_event id %s for event %s is missing", auth_id, event_id)
            elif auth_ev.rejected_reason is None:
                auth_map[(auth_ev.type, auth_ev.state_key)] = auth_ev

        # Non-rejected entries from the state so far take precedence.
        for auth_key in event_auth.auth_types_for_event(event):
            if auth_key in state:
                current_id = state[auth_key]
                current = yield _get_event(room_id, current_id, event_map, state_res_store)
                if current.rejected_reason is None:
                    auth_map[auth_key] = event_map[current_id]

        try:
            event_auth.check(version, event, auth_map, do_sig_check=False, do_size_check=False)
            state[(event.type, event.state_key)] = event_id
        except AuthError:
            pass

        # Every _YIELD_AFTER_ITERATIONS events, sleep(0) so a large batch does
        # not hold the reactor for too long.
        if count % _YIELD_AFTER_ITERATIONS == 0:
            yield clock.sleep(0)

    return state
def resolve(graph_desc, resolution_func):
    """Compute the end state of a described event graph using
    ``resolution_func`` and print how it compares with the graph's
    ``expected_state``.

    Events are processed in reverse topological order. The state after each
    event is derived from its prev events' states, merged with
    ``resolution_func`` when there is more than one. Processing stops, after
    printing a message, at the first event that fails auth.
    """
    graph, _, event_map = create_dag(graph_desc)

    state_after = {}
    for eid in reversed(list(topological_sort(graph))):
        event = event_map[eid]

        parent_states = [state_after[pid] for pid, _ in event.prev_events]

        if not parent_states:
            current = {}
        elif len(parent_states) == 1:
            current = parent_states[0]
        else:
            current = resolution_func(parent_states, event_map)

        auth_events = {
            key: event_map[current[key]]
            for key in event_auth.auth_types_for_event(event)
            if key in current
        }

        try:
            event_auth.check(event, auth_events, do_sig_check=False, do_size_check=False)
        except AuthError as e:
            print("Failed to auth event", eid, " because:", e)
            return

        # Copy before adding this event so the parent's state is untouched.
        if event.is_state():
            current = dict(current)
            current[(event.type, event.state_key)] = eid

        state_after[eid] = current

    start_state = state_after[to_event_id("START")]
    end_state = state_after[to_event_id("END")]

    expected_state = {}
    for short_id in graph_desc["expected_state"]:
        expected_ev = event_map[to_event_id(short_id)]
        expected_state[(expected_ev.type, expected_ev.state_key)] = to_event_id(short_id)

    mismatches = []
    for key in set(itertools.chain(end_state, expected_state)):
        want = expected_state.get(key)
        got = end_state.get(key)
        # Unchanged-from-start entries are fine when nothing is expected.
        if got == start_state.get(key) and not want:
            continue
        if want != got:
            mismatches.append((key[0], key[1], want, got))

    if not mismatches:
        print("Everything matched!")
        return

    print("Unexpected end state\n")
    print(tabulate(mismatches, headers=["Type", "State Key", "Expected", "Got"]))
def test_join_rules_invite(self):
    """
    Joining an invite-only room needs an invite or existing membership.
    """
    creator = "@creator:example.com"
    joiner = "@joiner:example.com"
    joiner_membership = ("m.room.member", joiner)

    auth_events = {
        ("m.room.create", ""): _create_event(creator),
        ("m.room.member", creator): _join_event(creator),
        ("m.room.join_rules", ""): _join_rules_event(creator, "invite"),
    }

    def check_v6(event):
        event_auth.check(RoomVersions.V6, event, auth_events, do_sig_check=False)

    # Joining without an invite is refused.
    with self.assertRaises(AuthError):
        check_v6(_join_event(joiner))

    # Nobody else can join the user on their behalf.
    with self.assertRaises(AuthError):
        check_v6(_member_event(joiner, "join", sender=creator))

    # Banned users, and users who left, may not (re-)join.
    for blocked_membership in ("ban", "leave"):
        auth_events[joiner_membership] = _member_event(joiner, blocked_membership)
        with self.assertRaises(AuthError):
            check_v6(_join_event(joiner))

    # A user already in the room may send a join.
    auth_events[joiner_membership] = _member_event(joiner, "join")
    check_v6(_join_event(joiner))

    # An invited user may accept the invite.
    auth_events[joiner_membership] = _member_event(joiner, "invite", sender=creator)
    check_v6(_join_event(joiner))
def test_join_rules_msc3083_restricted(self):
    """
    Joining an MSC3083 restricted room; this mirrors the public-room test.
    """
    creator = "@creator:example.com"
    joiner = "@joiner:example.com"
    joiner_membership = ("m.room.member", joiner)

    auth_events = {
        ("m.room.create", ""): _create_event(creator),
        ("m.room.member", creator): _join_event(creator),
        ("m.room.join_rules", ""): _join_rules_event(creator, "restricted"),
    }

    def run_check(room_version, event):
        event_auth.check(room_version, event, auth_events, do_sig_check=False)

    # Room versions before MSC3083 do not know this join rule.
    with self.assertRaises(AuthError):
        run_check(RoomVersions.V6, _join_event(joiner))

    # A plain join works.
    run_check(RoomVersions.MSC3083, _join_event(joiner))

    # Nobody else can join the user on their behalf.
    with self.assertRaises(AuthError):
        run_check(RoomVersions.MSC3083, _member_event(joiner, "join", sender=creator))

    # A banned user is refused.
    auth_events[joiner_membership] = _member_event(joiner, "ban")
    with self.assertRaises(AuthError):
        run_check(RoomVersions.MSC3083, _join_event(joiner))

    # Users who left, who are already joined, or who are invited may all join.
    for membership, extra in (
        ("leave", {}),
        ("join", {}),
        ("invite", {"sender": creator}),
    ):
        auth_events[joiner_membership] = _member_event(joiner, membership, **extra)
        run_check(RoomVersions.MSC3083, _join_event(joiner))
async def _iterative_auth_checks(
    clock: Clock,
    room_id: str,
    room_version: RoomVersion,
    event_ids: List[str],
    base_state: StateMap[str],
    event_map: Dict[str, EventBase],
    state_res_store: "synapse.state.StateResolutionStore",
) -> MutableStateMap[str]:
    """Run the auth rules over ``event_ids`` in order, recording each event
    that passes in a running copy of the state.

    Args:
        clock: used to hand control back to the reactor periodically
        room_id
        room_version
        event_ids: Ordered list of events to apply auth checks to
        base_state: The set of state to start with (copied, not mutated)
        event_map
        state_res_store

    Returns:
        The final updated state.
    """
    state = dict(base_state)

    for position, event_id in enumerate(event_ids, start=1):
        event = event_map[event_id]
        auth_map = {}

        # The event's own non-rejected auth events; missing ones are logged.
        for auth_id in event.auth_event_ids():
            auth_ev = await _get_event(
                room_id, auth_id, event_map, state_res_store, allow_none=True
            )
            if not auth_ev:
                logger.warning(
                    "auth_event id %s for event %s is missing", auth_id, event_id
                )
            elif auth_ev.rejected_reason is None:
                auth_map[(auth_ev.type, auth_ev.state_key)] = auth_ev

        # Non-rejected entries from the state so far take precedence.
        for auth_key in event_auth.auth_types_for_event(room_version, event):
            if auth_key not in state:
                continue
            current_id = state[auth_key]
            current = await _get_event(room_id, current_id, event_map, state_res_store)
            if current.rejected_reason is None:
                auth_map[auth_key] = event_map[current_id]

        try:
            event_auth.check(
                room_version,
                event,
                auth_map,
                do_sig_check=False,
                do_size_check=False,
            )
            state[(event.type, event.state_key)] = event_id
        except AuthError:
            pass

        # Every _AWAIT_AFTER_ITERATIONS events, sleep(0) so a large batch does
        # not hold the reactor for too long.
        if position % _AWAIT_AFTER_ITERATIONS == 0:
            await clock.sleep(0)

    return state