def setUp(self):
    """Create a ``StateHandler`` backed by a mocked homeserver.

    The homeserver mock only permits the attribute names listed in
    ``allowed_attrs`` (``spec_set``). A single ``_DummyStore`` instance is
    exposed both through ``get_datastores().main`` and through the storage
    controllers' ``main``/``state`` attributes.
    """
    server_name = "tesths"
    allowed_attrs = [
        "config",
        "get_datastores",
        "get_storage_controllers",
        "get_auth",
        "get_state_handler",
        "get_clock",
        "get_state_resolution_handler",
        "get_account_validity_handler",
        "get_macaroon_generator",
        "hostname",
    ]

    self.dummy_store = _DummyStore()
    controllers = Mock(main=self.dummy_store, state=self.dummy_store)
    homeserver = Mock(spec_set=allowed_attrs)
    mock_clock = cast(Clock, MockClock())

    homeserver.config = default_config(server_name, True)
    homeserver.get_datastores.return_value = Mock(main=self.dummy_store)
    homeserver.get_state_handler.return_value = None
    homeserver.get_clock.return_value = mock_clock
    homeserver.get_macaroon_generator.return_value = MacaroonGenerator(
        mock_clock, server_name, b"verysecret"
    )
    # Auth receives the mock exactly as configured so far; keep this after
    # the assignments above and before the ones below.
    homeserver.get_auth.return_value = Auth(homeserver)
    homeserver.get_state_resolution_handler = lambda: StateResolutionHandler(
        homeserver
    )
    homeserver.get_storage_controllers.return_value = controllers

    self.state = StateHandler(homeserver)
    # Starting value for the per-test event counter.
    self.event_id = 0
def setUp(self):
    """Create a ``StateHandler`` backed by a mocked homeserver.

    A ``StateGroupStore`` is used as the datastore and is also exposed as
    both ``main`` and ``state`` on the object returned by ``get_storage``.
    """
    allowed_attrs = [
        "config",
        "get_datastore",
        "get_storage",
        "get_auth",
        "get_state_handler",
        "get_clock",
        "get_state_resolution_handler",
        "hostname",
    ]

    self.store = StateGroupStore()
    storage_mock = Mock(main=self.store, state=self.store)
    homeserver = Mock(spec_set=allowed_attrs)

    homeserver.config = default_config("tesths", True)
    homeserver.get_datastore.return_value = self.store
    homeserver.get_state_handler.return_value = None
    homeserver.get_clock.return_value = MockClock()
    # Auth receives the mock as configured up to this point.
    homeserver.get_auth.return_value = Auth(homeserver)
    homeserver.get_state_resolution_handler = lambda: StateResolutionHandler(
        homeserver
    )
    homeserver.get_storage.return_value = storage_mock

    self.state = StateHandler(homeserver)
    # Starting value for the per-test event counter.
    self.event_id = 0
def setUp(self):
    """Create a ``StateHandler`` whose homeserver returns a mock datastore.

    The datastore mock only allows ``get_state_groups`` and
    ``add_event_hashes`` (``spec_set``).
    """
    store_methods = ["get_state_groups", "add_event_hashes"]
    self.store = Mock(spec_set=store_methods)

    homeserver = Mock(spec=["get_datastore"])
    homeserver.get_datastore.return_value = self.store

    self.state = StateHandler(homeserver)
    # Starting value for the per-test event counter.
    self.event_id = 0
def setUp(self):
    """Create a ``StateHandler`` backed by a ``StateGroupStore``.

    The homeserver is a ``spec_set`` mock exposing only the accessors
    listed in ``allowed_attrs``.
    """
    allowed_attrs = [
        "get_datastore",
        "get_auth",
        "get_state_handler",
        "get_clock",
        "get_state_resolution_handler",
    ]

    self.store = StateGroupStore()
    homeserver = Mock(spec_set=allowed_attrs)

    homeserver.get_datastore.return_value = self.store
    homeserver.get_state_handler.return_value = None
    homeserver.get_clock.return_value = MockClock()
    # Auth receives the mock as configured up to this point.
    homeserver.get_auth.return_value = Auth(homeserver)
    homeserver.get_state_resolution_handler = lambda: StateResolutionHandler(
        homeserver
    )

    self.state = StateHandler(homeserver)
    # Starting value for the per-test event counter.
    self.event_id = 0
def setUp(self):
    """Create a ``StateHandler`` around a mock datastore and homeserver.

    The datastore mock only allows ``get_state_groups`` and
    ``add_event_hashes``; the homeserver mock is specced with the four
    accessors listed in ``hs_attrs``.
    """
    store_methods = ["get_state_groups", "add_event_hashes"]
    hs_attrs = ["get_datastore", "get_auth", "get_state_handler", "get_clock"]

    self.store = Mock(spec_set=store_methods)
    homeserver = Mock(spec=hs_attrs)

    homeserver.get_datastore.return_value = self.store
    homeserver.get_state_handler.return_value = None
    # Note the ordering: Auth is constructed *before* the clock return
    # value is installed, exactly as in the original fixture.
    homeserver.get_auth.return_value = Auth(homeserver)
    homeserver.get_clock.return_value = MockClock()

    self.state = StateHandler(homeserver)
    # Starting value for the per-test event counter.
    self.event_id = 0
def setUp(self):
    """Create a ``StateHandler`` with mocked persistence and replication.

    ``self.persistence`` is what the homeserver's ``get_datastore`` returns
    and ``self.replication`` is what ``get_replication_layer`` returns.
    """
    persistence_methods = [
        "get_unresolved_state_tree",
        "update_current_state",
        "get_latest_pdus_in_context",
        "get_current_state_pdu",
        "get_pdu",
        "get_power_level",
    ]
    self.persistence = Mock(spec=persistence_methods)
    self.replication = Mock(spec=["get_pdu"])

    homeserver = Mock(spec=["get_datastore", "get_replication_layer"])
    homeserver.get_datastore.return_value = self.persistence
    homeserver.get_replication_layer.return_value = self.replication
    # ``spec`` (unlike ``spec_set``) still allows setting extra attributes.
    homeserver.hostname = "bob.com"

    self.state = StateHandler(homeserver)
def setUp(self):
    """Create a ``StateHandler`` around strictly specced mocks.

    Both the datastore and the homeserver are ``spec_set`` mocks, so only
    the listed attribute names may be read or assigned.
    """
    store_methods = ["get_state_groups", "add_event_hashes"]
    hs_attrs = ["get_datastore", "get_auth", "get_state_handler", "get_clock"]

    self.store = Mock(spec_set=store_methods)
    homeserver = Mock(spec_set=hs_attrs)

    homeserver.get_datastore.return_value = self.store
    homeserver.get_state_handler.return_value = None
    homeserver.get_clock.return_value = MockClock()
    # Auth receives the mock as configured up to this point.
    homeserver.get_auth.return_value = Auth(homeserver)

    self.state = StateHandler(homeserver)
    # Starting value for the per-test event counter.
    self.event_id = 0
def setUp(self):
    """Create a ``StateHandler`` whose datastore only has ``get_state_groups``."""
    self.store = Mock(spec_set=["get_state_groups"])

    homeserver = Mock(spec=["get_datastore"])
    homeserver.get_datastore.return_value = self.store

    self.state = StateHandler(homeserver)
    # Starting value for the per-test event counter.
    self.event_id = 0
def setUp(self):
    """Create a ``StateHandler`` around a mock store with state-group stubs.

    After the homeserver mock is wired up, the store's
    ``get_next_state_group`` is set to produce a fresh ``Mock`` per call
    and ``get_state_group_delta`` to return ``(None, None)``.
    """
    store_methods = [
        "get_state_groups_ids",
        "add_event_hashes",
        "get_events",
        "get_next_state_group",
        "get_state_group_delta",
    ]
    hs_attrs = ["get_datastore", "get_auth", "get_state_handler", "get_clock"]

    self.store = Mock(spec_set=store_methods)
    homeserver = Mock(spec_set=hs_attrs)

    homeserver.get_datastore.return_value = self.store
    homeserver.get_state_handler.return_value = None
    homeserver.get_clock.return_value = MockClock()
    # Auth receives the mock as configured up to this point.
    homeserver.get_auth.return_value = Auth(homeserver)

    # ``side_effect = Mock`` means each call constructs and returns a new Mock.
    self.store.get_next_state_group.side_effect = Mock
    self.store.get_state_group_delta.return_value = (None, None)

    self.state = StateHandler(homeserver)
    # Starting value for the per-test event counter.
    self.event_id = 0
def setUp(self):
    """Create a ``StateHandler`` with mocked persistence and replication.

    ``self.persistence`` is what the homeserver's ``get_datastore`` returns
    and ``self.replication`` is what ``get_replication_layer`` returns.
    """
    persistence_methods = [
        "get_unresolved_state_tree",
        "update_current_state",
        "get_latest_pdus_in_context",
        "get_current_state_pdu",
        "get_pdu",
    ]
    self.persistence = Mock(spec=persistence_methods)
    self.replication = Mock(spec=["get_pdu"])

    homeserver = Mock(spec=["get_datastore", "get_replication_layer"])
    homeserver.get_datastore.return_value = self.persistence
    homeserver.get_replication_layer.return_value = self.replication
    # ``spec`` (unlike ``spec_set``) still allows setting extra attributes.
    homeserver.hostname = "bob.com"

    self.state = StateHandler(homeserver)
class StateTestCase(unittest.TestCase):
    # Tests for StateHandler.annotate_event_with_state, using Mock events
    # produced by create_event() and a Mock datastore.

    def setUp(self):
        """Create a ``StateHandler`` whose datastore only has ``get_state_groups``."""
        self.store = Mock(spec_set=["get_state_groups"])

        homeserver = Mock(spec=["get_datastore"])
        homeserver.get_datastore.return_value = self.store

        self.state = StateHandler(homeserver)
        # Counter used by create_event() to hand out event ids.
        self.event_id = 0

    def _basic_old_state(self):
        """Create the three state events shared by the non-conflict tests."""
        return [
            self.create_event(type="test1", state_key="1"),
            self.create_event(type="test1", state_key="2"),
            self.create_event(type="test2", state_key=""),
        ]

    def _assert_keys_match_events(self, state_map):
        """Assert each ``(type, state_key)`` key agrees with its event."""
        for (ev_type, ev_state_key), ev in state_map.items():
            self.assertEqual(ev_type, ev.type)
            self.assertEqual(ev_state_key, ev.state_key)

    @defer.inlineCallbacks
    def test_annotate_with_old_message(self):
        # Non-state event annotated with explicitly supplied old state.
        event = self.create_event(type="test_message", name="event")
        supplied_state = self._basic_old_state()

        yield self.state.annotate_event_with_state(
            event, old_state=supplied_state
        )

        self._assert_keys_match_events(event.old_state_events)

        self.assertEqual(
            set(supplied_state), set(event.old_state_events.values())
        )
        self.assertDictEqual(event.old_state_events, event.state_events)

        self.assertIsNone(event.state_group)

    @defer.inlineCallbacks
    def test_annotate_with_old_state(self):
        # State event annotated with explicitly supplied old state.
        event = self.create_event(type="state", state_key="", name="event")
        supplied_state = self._basic_old_state()

        yield self.state.annotate_event_with_state(
            event, old_state=supplied_state
        )

        self._assert_keys_match_events(event.old_state_events)

        self.assertEqual(
            set(supplied_state + [event]),
            set(event.old_state_events.values()),
        )
        self.assertDictEqual(event.old_state_events, event.state_events)

        self.assertIsNone(event.state_group)

    @defer.inlineCallbacks
    def test_trivial_annotate_message(self):
        # Non-state event; the store reports a single state group.
        event = self.create_event(type="test_message", name="event")
        event.prev_events = []

        stored_state = self._basic_old_state()

        group_name = "group_name_1"
        self.store.get_state_groups.return_value = {group_name: stored_state}

        yield self.state.annotate_event_with_state(event)

        self._assert_keys_match_events(event.old_state_events)

        self.assertEqual(
            {e.event_id for e in stored_state},
            {e.event_id for e in event.old_state_events.values()},
        )

        old_ids = {k: v.event_id for k, v in event.old_state_events.items()}
        new_ids = {k: v.event_id for k, v in event.state_events.items()}
        self.assertDictEqual(old_ids, new_ids)

        self.assertEqual(group_name, event.state_group)

    @defer.inlineCallbacks
    def test_trivial_annotate_state(self):
        # State event; the store reports a single state group.
        event = self.create_event(type="state", state_key="", name="event")
        event.prev_events = []

        stored_state = self._basic_old_state()

        group_name = "group_name_1"
        self.store.get_state_groups.return_value = {group_name: stored_state}

        yield self.state.annotate_event_with_state(event)

        self._assert_keys_match_events(event.old_state_events)

        stored_ids = [e.event_id for e in stored_state]
        self.assertEqual(
            set(stored_ids),
            {e.event_id for e in event.old_state_events.values()},
        )
        self.assertEqual(
            set(stored_ids + [event.event_id]),
            {e.event_id for e in event.state_events.values()},
        )

        actual_ids = {k: v.event_id for k, v in event.state_events.items()}
        expected_ids = {
            k: v.event_id for k, v in event.old_state_events.items()
        }
        # The new state is the old state plus the event itself.
        expected_ids[(event.type, event.state_key)] = event.event_id
        self.assertDictEqual(expected_ids, actual_ids)

        self.assertIsNone(event.state_group)

    @defer.inlineCallbacks
    def test_resolve_message_conflict(self):
        # Non-state event; the store reports two differing state groups.
        event = self.create_event(type="test_message", name="event")
        event.prev_events = []

        first_group_state = [
            self.create_event(type="test1", state_key="1"),
            self.create_event(type="test1", state_key="2"),
            self.create_event(type="test2", state_key=""),
        ]
        second_group_state = [
            self.create_event(type="test1", state_key="1"),
            self.create_event(type="test3", state_key="2"),
            self.create_event(type="test4", state_key=""),
        ]

        self.store.get_state_groups.return_value = {
            "group_name_1": first_group_state,
            "group_name_2": second_group_state,
        }

        yield self.state.annotate_event_with_state(event)

        self.assertEqual(len(event.old_state_events), 5)

        self.assertEqual(
            {e.event_id for e in event.state_events.values()},
            {e.event_id for e in event.old_state_events.values()},
        )

        self.assertIsNone(event.state_group)

    @defer.inlineCallbacks
    def test_resolve_state_conflict(self):
        # State event; the store reports two differing state groups.
        event = self.create_event(type="test4", state_key="", name="event")
        event.prev_events = []

        first_group_state = [
            self.create_event(type="test1", state_key="1"),
            self.create_event(type="test1", state_key="2"),
            self.create_event(type="test2", state_key=""),
        ]
        second_group_state = [
            self.create_event(type="test1", state_key="1"),
            self.create_event(type="test3", state_key="2"),
            self.create_event(type="test4", state_key=""),
        ]

        self.store.get_state_groups.return_value = {
            "group_name_1": first_group_state,
            "group_name_2": second_group_state,
        }

        yield self.state.annotate_event_with_state(event)

        self.assertEqual(len(event.old_state_events), 5)

        # NOTE: this aliases (and therefore mutates) event.old_state_events,
        # as the original test did.
        expected_state = event.old_state_events
        expected_state[(event.type, event.state_key)] = event

        self.assertEqual(
            {e.event_id for e in expected_state.values()},
            {e.event_id for e in event.state_events.values()},
        )

        self.assertIsNone(event.state_group)

    def create_event(self, name=None, type=None, state_key=None):
        """Build a Mock event with a fresh, sequential string ``event_id``.

        Args:
            name: Name for the Mock; when falsy, one is derived from
                ``type`` (and ``state_key`` if it is not None).
            type: Value assigned to the event's ``type`` attribute.
            state_key: When not None, assigned to ``state_key``; otherwise
                the attribute is left unset on the ``spec=[]`` Mock.

        Returns:
            The configured Mock event.
        """
        self.event_id += 1

        if not name:
            if state_key is None:
                name = "<%s>" % (type, )
            else:
                name = "<%s-%s>" % (type, state_key)

        mock_event = Mock(name=name, spec=[])
        mock_event.type = type
        if state_key is not None:
            mock_event.state_key = state_key
        mock_event.event_id = str(self.event_id)
        mock_event.user_id = "@user_id:example.com"
        mock_event.room_id = "!room_id:example.com"

        return mock_event
class StateTestCase(unittest.TestCase):
    # Tests for StateHandler.compute_event_context. The graph-based tests
    # feed events from a Graph through the handler, backing the mocked
    # datastore's get_state_groups with a StateGroupStore.

    def setUp(self):
        """Create a ``StateHandler`` around a mock datastore and homeserver."""
        store_methods = ["get_state_groups", "add_event_hashes"]
        hs_attrs = [
            "get_datastore",
            "get_auth",
            "get_state_handler",
            "get_clock",
        ]

        self.store = Mock(spec_set=store_methods)
        homeserver = Mock(spec=hs_attrs)

        homeserver.get_datastore.return_value = self.store
        homeserver.get_state_handler.return_value = None
        # Note the ordering: Auth is constructed before the clock return
        # value is installed, exactly as in the original fixture.
        homeserver.get_auth.return_value = Auth(homeserver)
        homeserver.get_clock.return_value = MockClock()

        self.state = StateHandler(homeserver)
        self.event_id = 0

    def _assert_keys_match_events(self, state_map):
        """Assert each ``(type, state_key)`` key agrees with its event."""
        for (ev_type, ev_state_key), ev in state_map.items():
            self.assertEqual(ev_type, ev.type)
            self.assertEqual(ev_state_key, ev.state_key)

    @defer.inlineCallbacks
    def test_branch_no_conflict(self):
        # Two branches off A; only one of them (C) carries state.
        nodes = {
            "START": DictObj(
                type=EventTypes.Create,
                state_key="",
                depth=1,
            ),
            "A": DictObj(
                type=EventTypes.Message,
                depth=2,
            ),
            "B": DictObj(
                type=EventTypes.Message,
                depth=3,
            ),
            "C": DictObj(
                type=EventTypes.Name,
                state_key="",
                depth=3,
            ),
            "D": DictObj(
                type=EventTypes.Message,
                depth=4,
            ),
        }
        edges = {
            "A": ["START"],
            "B": ["A"],
            "C": ["A"],
            "D": ["B", "C"],
        }
        graph = Graph(nodes=nodes, edges=edges)

        group_store = StateGroupStore()
        self.store.get_state_groups.side_effect = group_store.get_state_groups

        contexts = {}
        for ev in graph.walk():
            ctx = yield self.state.compute_event_context(ev)
            group_store.store_state_groups(ev, ctx)
            contexts[ev.event_id] = ctx

        self.assertEqual(2, len(contexts["D"].current_state))

    @defer.inlineCallbacks
    def test_branch_basic_conflict(self):
        # B and C both set the room name; D merges the two branches.
        nodes = {
            "START": DictObj(
                type=EventTypes.Create,
                state_key="",
                content={"creator": "@user_id:example.com"},
                depth=1,
            ),
            "A": DictObj(
                type=EventTypes.Member,
                state_key="@user_id:example.com",
                content={"membership": Membership.JOIN},
                membership=Membership.JOIN,
                depth=2,
            ),
            "B": DictObj(
                type=EventTypes.Name,
                state_key="",
                depth=3,
            ),
            "C": DictObj(
                type=EventTypes.Name,
                state_key="",
                depth=4,
            ),
            "D": DictObj(
                type=EventTypes.Message,
                depth=5,
            ),
        }
        edges = {
            "A": ["START"],
            "B": ["A"],
            "C": ["A"],
            "D": ["B", "C"],
        }
        graph = Graph(nodes=nodes, edges=edges)

        group_store = StateGroupStore()
        self.store.get_state_groups.side_effect = group_store.get_state_groups

        contexts = {}
        for ev in graph.walk():
            ctx = yield self.state.compute_event_context(ev)
            group_store.store_state_groups(ev, ctx)
            contexts[ev.event_id] = ctx

        resolved_ids = {
            e.event_id for e in contexts["D"].current_state.values()
        }
        self.assertSetEqual({"START", "A", "C"}, resolved_ids)

    @defer.inlineCallbacks
    def test_branch_have_banned_conflict(self):
        # C bans @user_id_2 while D is a name change sent by @user_id_2.
        nodes = {
            "START": DictObj(
                type=EventTypes.Create,
                state_key="",
                content={"creator": "@user_id:example.com"},
                depth=1,
            ),
            "A": DictObj(
                type=EventTypes.Member,
                state_key="@user_id:example.com",
                content={"membership": Membership.JOIN},
                membership=Membership.JOIN,
                depth=2,
            ),
            "B": DictObj(
                type=EventTypes.Name,
                state_key="",
                depth=3,
            ),
            "C": DictObj(
                type=EventTypes.Member,
                state_key="@user_id_2:example.com",
                content={"membership": Membership.BAN},
                membership=Membership.BAN,
                depth=4,
            ),
            "D": DictObj(
                type=EventTypes.Name,
                state_key="",
                depth=4,
                sender="@user_id_2:example.com",
            ),
            "E": DictObj(
                type=EventTypes.Message,
                depth=5,
            ),
        }
        edges = {
            "A": ["START"],
            "B": ["A"],
            "C": ["B"],
            "D": ["B"],
            "E": ["C", "D"],
        }
        graph = Graph(nodes=nodes, edges=edges)

        group_store = StateGroupStore()
        self.store.get_state_groups.side_effect = group_store.get_state_groups

        contexts = {}
        for ev in graph.walk():
            ctx = yield self.state.compute_event_context(ev)
            group_store.store_state_groups(ev, ctx)
            contexts[ev.event_id] = ctx

        resolved_ids = {
            e.event_id for e in contexts["E"].current_state.values()
        }
        self.assertSetEqual({"START", "A", "B", "C"}, resolved_ids)

    @defer.inlineCallbacks
    def test_branch_have_perms_conflict(self):
        # B lowers the second user's power level while C is a name change
        # sent by that user.
        first_user = "@user_id:example.com"
        second_user = "@user_id2:example.com"

        nodes = {
            "A1": DictObj(
                type=EventTypes.Create,
                state_key="",
                content={"creator": first_user},
                depth=1,
            ),
            "A2": DictObj(
                type=EventTypes.Member,
                state_key=first_user,
                content={"membership": Membership.JOIN},
                membership=Membership.JOIN,
            ),
            "A3": DictObj(
                type=EventTypes.Member,
                state_key=second_user,
                content={"membership": Membership.JOIN},
                membership=Membership.JOIN,
            ),
            "A4": DictObj(
                type=EventTypes.PowerLevels,
                state_key="",
                content={
                    "events": {"m.room.name": 50},
                    "users": {first_user: 100, second_user: 60},
                },
            ),
            "A5": DictObj(
                type=EventTypes.Name,
                state_key="",
            ),
            "B": DictObj(
                type=EventTypes.PowerLevels,
                state_key="",
                content={
                    "events": {"m.room.name": 50},
                    "users": {second_user: 30},
                },
            ),
            "C": DictObj(
                type=EventTypes.Name,
                state_key="",
                sender=second_user,
            ),
            "D": DictObj(
                type=EventTypes.Message,
            ),
        }
        edges = {
            "A2": ["A1"],
            "A3": ["A2"],
            "A4": ["A3"],
            "A5": ["A4"],
            "B": ["A5"],
            "C": ["A5"],
            "D": ["B", "C"],
        }
        # Only A1 has an explicit depth; fill in the rest from the edges.
        self._add_depths(nodes, edges)
        graph = Graph(nodes, edges)

        group_store = StateGroupStore()
        self.store.get_state_groups.side_effect = group_store.get_state_groups

        contexts = {}
        for ev in graph.walk():
            ctx = yield self.state.compute_event_context(ev)
            group_store.store_state_groups(ev, ctx)
            contexts[ev.event_id] = ctx

        resolved_ids = {
            e.event_id for e in contexts["D"].current_state.values()
        }
        self.assertSetEqual({"A1", "A2", "A3", "A5", "B"}, resolved_ids)

    def _add_depths(self, nodes, edges):
        """Fill in a ``depth`` entry on every node that lacks one.

        A node's depth is one more than the largest depth among its
        predecessors in ``edges``. ``nodes`` is modified in place.
        """
        def depth_of(node_id):
            node = nodes[node_id]
            if 'depth' not in node:
                parent_depths = (depth_of(p) for p in edges[node_id])
                node['depth'] = max(parent_depths) + 1
            return node['depth']

        for node_id in nodes:
            depth_of(node_id)

    @defer.inlineCallbacks
    def test_annotate_with_old_message(self):
        # Non-state event with explicitly supplied old state.
        event = create_event(type="test_message", name="event")

        supplied_state = [
            create_event(type="test1", state_key="1"),
            create_event(type="test1", state_key="2"),
            create_event(type="test2", state_key=""),
        ]

        ctx = yield self.state.compute_event_context(
            event, old_state=supplied_state
        )

        self._assert_keys_match_events(ctx.current_state)

        self.assertEqual(
            set(supplied_state), set(ctx.current_state.values())
        )

        self.assertIsNone(ctx.state_group)

    @defer.inlineCallbacks
    def test_annotate_with_old_state(self):
        # State event with explicitly supplied old state.
        event = create_event(type="state", state_key="", name="event")

        supplied_state = [
            create_event(type="test1", state_key="1"),
            create_event(type="test1", state_key="2"),
            create_event(type="test2", state_key=""),
        ]

        ctx = yield self.state.compute_event_context(
            event, old_state=supplied_state
        )

        self._assert_keys_match_events(ctx.current_state)

        self.assertEqual(
            set(supplied_state), set(ctx.current_state.values())
        )

        self.assertIsNone(ctx.state_group)

    @defer.inlineCallbacks
    def test_trivial_annotate_message(self):
        # Non-state event; the store reports a single state group.
        event = create_event(type="test_message", name="event")

        stored_state = [
            create_event(type="test1", state_key="1"),
            create_event(type="test1", state_key="2"),
            create_event(type="test2", state_key=""),
        ]

        group_name = "group_name_1"
        self.store.get_state_groups.return_value = {group_name: stored_state}

        ctx = yield self.state.compute_event_context(event)

        self._assert_keys_match_events(ctx.current_state)

        self.assertEqual(
            {e.event_id for e in stored_state},
            {e.event_id for e in ctx.current_state.values()},
        )

        self.assertEqual(group_name, ctx.state_group)

    @defer.inlineCallbacks
    def test_trivial_annotate_state(self):
        # State event; the store reports a single state group.
        event = create_event(type="state", state_key="", name="event")

        stored_state = [
            create_event(type="test1", state_key="1"),
            create_event(type="test1", state_key="2"),
            create_event(type="test2", state_key=""),
        ]

        group_name = "group_name_1"
        self.store.get_state_groups.return_value = {group_name: stored_state}

        ctx = yield self.state.compute_event_context(event)

        self._assert_keys_match_events(ctx.current_state)

        self.assertEqual(
            {e.event_id for e in stored_state},
            {e.event_id for e in ctx.current_state.values()},
        )

        self.assertIsNone(ctx.state_group)

    @defer.inlineCallbacks
    def test_resolve_message_conflict(self):
        # Non-state event whose two prior state groups differ.
        event = create_event(type="test_message", name="event")

        creation = create_event(
            type=EventTypes.Create, state_key=""
        )

        first_group_state = [
            creation,
            create_event(type="test1", state_key="1"),
            create_event(type="test1", state_key="2"),
            create_event(type="test2", state_key=""),
        ]
        second_group_state = [
            creation,
            create_event(type="test1", state_key="1"),
            create_event(type="test3", state_key="2"),
            create_event(type="test4", state_key=""),
        ]

        ctx = yield self._get_context(
            event, first_group_state, second_group_state
        )

        self.assertEqual(len(ctx.current_state), 6)

        self.assertIsNone(ctx.state_group)

    @defer.inlineCallbacks
    def test_resolve_state_conflict(self):
        # State event whose two prior state groups differ.
        event = create_event(type="test4", state_key="", name="event")

        creation = create_event(
            type=EventTypes.Create, state_key=""
        )

        first_group_state = [
            creation,
            create_event(type="test1", state_key="1"),
            create_event(type="test1", state_key="2"),
            create_event(type="test2", state_key=""),
        ]
        second_group_state = [
            creation,
            create_event(type="test1", state_key="1"),
            create_event(type="test3", state_key="2"),
            create_event(type="test4", state_key=""),
        ]

        ctx = yield self._get_context(
            event, first_group_state, second_group_state
        )

        self.assertEqual(len(ctx.current_state), 6)

        self.assertIsNone(ctx.state_group)

    @defer.inlineCallbacks
    def test_standard_depth_conflict(self):
        # The two groups disagree only on ("test1", "1"), with different
        # depths; the test expects the deeper event to be chosen.
        event = create_event(type="test4", name="event")

        member_event = create_event(
            type=EventTypes.Member,
            state_key="@user_id:example.com",
            content={
                "membership": Membership.JOIN,
            }
        )

        creation = create_event(
            type=EventTypes.Create, state_key="",
            content={"creator": "@foo:bar"}
        )

        first_group_state = [
            creation,
            member_event,
            create_event(type="test1", state_key="1", depth=1),
        ]
        second_group_state = [
            creation,
            member_event,
            create_event(type="test1", state_key="1", depth=2),
        ]

        ctx = yield self._get_context(
            event, first_group_state, second_group_state
        )

        self.assertEqual(
            second_group_state[2], ctx.current_state[("test1", "1")]
        )

        # Reverse the depth to make sure we are actually using the depths
        # during state resolution.

        first_group_state = [
            creation,
            member_event,
            create_event(type="test1", state_key="1", depth=2),
        ]
        second_group_state = [
            creation,
            member_event,
            create_event(type="test1", state_key="1", depth=1),
        ]

        ctx = yield self._get_context(
            event, first_group_state, second_group_state
        )

        self.assertEqual(
            first_group_state[2], ctx.current_state[("test1", "1")]
        )

    def _get_context(self, event, old_state_1, old_state_2):
        """Compute ``event``'s context with two state groups in the store.

        Makes the mocked ``get_state_groups`` return both state lists under
        fixed group names, then returns whatever
        ``compute_event_context(event)`` returns.
        """
        self.store.get_state_groups.return_value = {
            "group_name_1": old_state_1,
            "group_name_2": old_state_2,
        }

        return self.state.compute_event_context(event)
class StateTestCase(unittest.TestCase):
    # Tests for StateHandler.compute_event_context, backed by a
    # StateGroupStore that serves as both the datastore and the storage
    # layer's main/state stores.

    def setUp(self):
        """Create a ``StateHandler`` backed by a ``StateGroupStore``."""
        allowed_attrs = [
            "config",
            "get_datastore",
            "get_storage",
            "get_auth",
            "get_state_handler",
            "get_clock",
            "get_state_resolution_handler",
        ]

        self.store = StateGroupStore()
        storage_mock = Mock(main=self.store, state=self.store)
        homeserver = Mock(spec_set=allowed_attrs)

        homeserver.config = default_config("tesths", True)
        homeserver.get_datastore.return_value = self.store
        homeserver.get_state_handler.return_value = None
        homeserver.get_clock.return_value = MockClock()
        # Auth receives the mock as configured up to this point.
        homeserver.get_auth.return_value = Auth(homeserver)
        homeserver.get_state_resolution_handler = (
            lambda: StateResolutionHandler(homeserver)
        )
        homeserver.get_storage.return_value = storage_mock

        self.state = StateHandler(homeserver)
        self.event_id = 0

    @defer.inlineCallbacks
    def _compute_contexts(self, graph):
        """Register ``graph``'s events and compute a context for each one.

        Returns:
            A dict mapping each event's ``event_id`` to its computed context.
        """
        self.store.register_events(graph.walk())

        contexts = {}  # type: dict[str, EventContext]
        for ev in graph.walk():
            ctx = yield defer.ensureDeferred(
                self.state.compute_event_context(ev)
            )
            self.store.register_event_context(ev, ctx)
            contexts[ev.event_id] = ctx

        return contexts

    @defer.inlineCallbacks
    def _store_prev_state(self, prev_event_id, room_id, state_events):
        """Store ``state_events`` as a state group tied to ``prev_event_id``.

        Returns:
            The value produced by ``store_state_group``.
        """
        state_ids = {(e.type, e.state_key): e.event_id for e in state_events}
        group = yield defer.ensureDeferred(
            self.store.store_state_group(
                prev_event_id, room_id, None, None, state_ids
            )
        )
        self.store.register_event_id_state_group(prev_event_id, group)
        return group

    @defer.inlineCallbacks
    def test_branch_no_conflict(self):
        # Two branches off A; only one of them (C) carries state.
        nodes = {
            "START": DictObj(
                type=EventTypes.Create, state_key="", content={}, depth=1
            ),
            "A": DictObj(type=EventTypes.Message, depth=2),
            "B": DictObj(type=EventTypes.Message, depth=3),
            "C": DictObj(type=EventTypes.Name, state_key="", depth=3),
            "D": DictObj(type=EventTypes.Message, depth=4),
        }
        edges = {"A": ["START"], "B": ["A"], "C": ["A"], "D": ["B", "C"]}
        graph = Graph(nodes=nodes, edges=edges)

        contexts = yield self._compute_contexts(graph)

        ctx_c = contexts["C"]
        ctx_d = contexts["D"]

        prev_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids())
        self.assertEqual(2, len(prev_ids))

        self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event)
        self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group)

    @defer.inlineCallbacks
    def test_branch_basic_conflict(self):
        # B and C both set the room name; D merges the two branches.
        nodes = {
            "START": DictObj(
                type=EventTypes.Create,
                state_key="",
                content={"creator": "@user_id:example.com"},
                depth=1,
            ),
            "A": DictObj(
                type=EventTypes.Member,
                state_key="@user_id:example.com",
                content={"membership": Membership.JOIN},
                membership=Membership.JOIN,
                depth=2,
            ),
            "B": DictObj(type=EventTypes.Name, state_key="", depth=3),
            "C": DictObj(type=EventTypes.Name, state_key="", depth=4),
            "D": DictObj(type=EventTypes.Message, depth=5),
        }
        edges = {"A": ["START"], "B": ["A"], "C": ["A"], "D": ["B", "C"]}
        graph = Graph(nodes=nodes, edges=edges)

        contexts = yield self._compute_contexts(graph)

        # C ends up winning the resolution between B and C
        ctx_c = contexts["C"]
        ctx_d = contexts["D"]

        prev_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids())
        self.assertSetEqual({"START", "A", "C"}, set(prev_ids.values()))

        self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event)
        self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group)

    @defer.inlineCallbacks
    def test_branch_have_banned_conflict(self):
        # C bans @user_id_2 while D is a name change sent by @user_id_2.
        nodes = {
            "START": DictObj(
                type=EventTypes.Create,
                state_key="",
                content={"creator": "@user_id:example.com"},
                depth=1,
            ),
            "A": DictObj(
                type=EventTypes.Member,
                state_key="@user_id:example.com",
                content={"membership": Membership.JOIN},
                membership=Membership.JOIN,
                depth=2,
            ),
            "B": DictObj(type=EventTypes.Name, state_key="", depth=3),
            "C": DictObj(
                type=EventTypes.Member,
                state_key="@user_id_2:example.com",
                content={"membership": Membership.BAN},
                membership=Membership.BAN,
                depth=4,
            ),
            "D": DictObj(
                type=EventTypes.Name,
                state_key="",
                depth=4,
                sender="@user_id_2:example.com",
            ),
            "E": DictObj(type=EventTypes.Message, depth=5),
        }
        edges = {
            "A": ["START"],
            "B": ["A"],
            "C": ["B"],
            "D": ["B"],
            "E": ["C", "D"],
        }
        graph = Graph(nodes=nodes, edges=edges)

        contexts = yield self._compute_contexts(graph)

        # C ends up winning the resolution between C and D because bans win
        # over other changes
        ctx_c = contexts["C"]
        ctx_e = contexts["E"]

        prev_ids = yield defer.ensureDeferred(ctx_e.get_prev_state_ids())
        self.assertSetEqual({"START", "A", "B", "C"}, set(prev_ids.values()))
        self.assertEqual(ctx_c.state_group, ctx_e.state_group_before_event)
        self.assertEqual(ctx_e.state_group_before_event, ctx_e.state_group)

    @defer.inlineCallbacks
    def test_branch_have_perms_conflict(self):
        first_user = "@user_id:example.com"
        second_user = "@user_id2:example.com"

        nodes = {
            "A1": DictObj(
                type=EventTypes.Create,
                state_key="",
                content={"creator": first_user},
                depth=1,
            ),
            "A2": DictObj(
                type=EventTypes.Member,
                state_key=first_user,
                content={"membership": Membership.JOIN},
                membership=Membership.JOIN,
            ),
            "A3": DictObj(
                type=EventTypes.Member,
                state_key=second_user,
                content={"membership": Membership.JOIN},
                membership=Membership.JOIN,
            ),
            "A4": DictObj(
                type=EventTypes.PowerLevels,
                state_key="",
                content={
                    "events": {"m.room.name": 50},
                    "users": {first_user: 100, second_user: 60},
                },
            ),
            "A5": DictObj(type=EventTypes.Name, state_key=""),
            "B": DictObj(
                type=EventTypes.PowerLevels,
                state_key="",
                content={
                    "events": {"m.room.name": 50},
                    "users": {second_user: 30},
                },
            ),
            "C": DictObj(
                type=EventTypes.Name, state_key="", sender=second_user
            ),
            "D": DictObj(type=EventTypes.Message),
        }
        edges = {
            "A2": ["A1"],
            "A3": ["A2"],
            "A4": ["A3"],
            "A5": ["A4"],
            "B": ["A5"],
            "C": ["A5"],
            "D": ["B", "C"],
        }
        # Only A1 has an explicit depth; fill in the rest from the edges.
        self._add_depths(nodes, edges)
        graph = Graph(nodes, edges)

        contexts = yield self._compute_contexts(graph)

        # B ends up winning the resolution between B and C because power
        # levels win over other changes.
        ctx_b = contexts["B"]
        ctx_d = contexts["D"]

        prev_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids())
        self.assertSetEqual(
            {"A1", "A2", "A3", "A5", "B"}, set(prev_ids.values())
        )

        self.assertEqual(ctx_b.state_group, ctx_d.state_group_before_event)
        self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group)

    def _add_depths(self, nodes, edges):
        """Fill in a ``depth`` entry on every node that lacks one.

        A node's depth is one more than the largest depth among its
        predecessors in ``edges``. ``nodes`` is modified in place.
        """
        def depth_of(node_id):
            node = nodes[node_id]
            if "depth" not in node:
                parent_depths = (depth_of(p) for p in edges[node_id])
                node["depth"] = max(parent_depths) + 1
            return node["depth"]

        for node_id in nodes:
            depth_of(node_id)

    @defer.inlineCallbacks
    def test_annotate_with_old_message(self):
        # Non-state event with explicitly supplied old state.
        event = create_event(type="test_message", name="event")

        supplied_state = [
            create_event(type="test1", state_key="1"),
            create_event(type="test1", state_key="2"),
            create_event(type="test2", state_key=""),
        ]
        supplied_ids = [e.event_id for e in supplied_state]

        ctx = yield defer.ensureDeferred(
            self.state.compute_event_context(event, old_state=supplied_state)
        )

        prev_ids = yield defer.ensureDeferred(ctx.get_prev_state_ids())
        self.assertCountEqual(supplied_ids, prev_ids.values())

        current_ids = yield defer.ensureDeferred(ctx.get_current_state_ids())
        self.assertCountEqual(supplied_ids, current_ids.values())

        self.assertIsNotNone(ctx.state_group_before_event)
        self.assertEqual(ctx.state_group_before_event, ctx.state_group)

    @defer.inlineCallbacks
    def test_annotate_with_old_state(self):
        # State event with explicitly supplied old state.
        event = create_event(type="state", state_key="", name="event")

        supplied_state = [
            create_event(type="test1", state_key="1"),
            create_event(type="test1", state_key="2"),
            create_event(type="test2", state_key=""),
        ]

        ctx = yield defer.ensureDeferred(
            self.state.compute_event_context(event, old_state=supplied_state)
        )

        prev_ids = yield defer.ensureDeferred(ctx.get_prev_state_ids())
        self.assertCountEqual(
            [e.event_id for e in supplied_state], prev_ids.values()
        )

        current_ids = yield defer.ensureDeferred(ctx.get_current_state_ids())
        self.assertCountEqual(
            [e.event_id for e in supplied_state + [event]],
            current_ids.values(),
        )

        self.assertIsNotNone(ctx.state_group_before_event)
        self.assertNotEqual(ctx.state_group_before_event, ctx.state_group)
        self.assertEqual(ctx.state_group_before_event, ctx.prev_group)
        self.assertEqual({("state", ""): event.event_id}, ctx.delta_ids)

    @defer.inlineCallbacks
    def test_trivial_annotate_message(self):
        # Non-state event with a single prev event that has a stored group.
        prev_event_id = "prev_event_id"
        event = create_event(
            type="test_message",
            name="event2",
            prev_events=[(prev_event_id, {})],
        )

        stored_state = [
            create_event(type="test1", state_key="1"),
            create_event(type="test1", state_key="2"),
            create_event(type="test2", state_key=""),
        ]

        group_name = yield self._store_prev_state(
            prev_event_id, event.room_id, stored_state
        )

        ctx = yield defer.ensureDeferred(
            self.state.compute_event_context(event)
        )

        current_ids = yield defer.ensureDeferred(ctx.get_current_state_ids())

        self.assertEqual(
            {e.event_id for e in stored_state}, set(current_ids.values())
        )

        self.assertEqual(group_name, ctx.state_group)

    @defer.inlineCallbacks
    def test_trivial_annotate_state(self):
        # State event with a single prev event that has a stored group.
        prev_event_id = "prev_event_id"
        event = create_event(
            type="state",
            state_key="",
            name="event2",
            prev_events=[(prev_event_id, {})],
        )

        stored_state = [
            create_event(type="test1", state_key="1"),
            create_event(type="test1", state_key="2"),
            create_event(type="test2", state_key=""),
        ]

        yield self._store_prev_state(
            prev_event_id, event.room_id, stored_state
        )

        ctx = yield defer.ensureDeferred(
            self.state.compute_event_context(event)
        )

        prev_ids = yield defer.ensureDeferred(ctx.get_prev_state_ids())

        self.assertEqual(
            {e.event_id for e in stored_state}, set(prev_ids.values())
        )

        self.assertIsNotNone(ctx.state_group)

    @defer.inlineCallbacks
    def test_resolve_message_conflict(self):
        # Non-state event whose two prev events have differing state.
        prev_event_id1 = "event_id1"
        prev_event_id2 = "event_id2"
        event = create_event(
            type="test_message",
            name="event3",
            prev_events=[(prev_event_id1, {}), (prev_event_id2, {})],
        )

        creation = create_event(type=EventTypes.Create, state_key="")

        first_state = [
            creation,
            create_event(type="test1", state_key="1"),
            create_event(type="test1", state_key="2"),
            create_event(type="test2", state_key=""),
        ]
        second_state = [
            creation,
            create_event(type="test1", state_key="1"),
            create_event(type="test3", state_key="2"),
            create_event(type="test4", state_key=""),
        ]

        self.store.register_events(first_state)
        self.store.register_events(second_state)

        ctx = yield self._get_context(
            event, prev_event_id1, first_state, prev_event_id2, second_state
        )

        current_ids = yield defer.ensureDeferred(ctx.get_current_state_ids())

        self.assertEqual(len(current_ids), 6)

        self.assertIsNotNone(ctx.state_group)

    @defer.inlineCallbacks
    def test_resolve_state_conflict(self):
        # State event whose two prev events have differing state.
        prev_event_id1 = "event_id1"
        prev_event_id2 = "event_id2"
        event = create_event(
            type="test4",
            state_key="",
            name="event",
            prev_events=[(prev_event_id1, {}), (prev_event_id2, {})],
        )

        creation = create_event(type=EventTypes.Create, state_key="")

        first_state = [
            creation,
            create_event(type="test1", state_key="1"),
            create_event(type="test1", state_key="2"),
            create_event(type="test2", state_key=""),
        ]
        second_state = [
            creation,
            create_event(type="test1", state_key="1"),
            create_event(type="test3", state_key="2"),
            create_event(type="test4", state_key=""),
        ]

        # Events are registered on a separate store whose get_events is
        # then patched onto self.store.
        event_store = StateGroupStore()
        event_store.register_events(first_state)
        event_store.register_events(second_state)
        self.store.get_events = event_store.get_events

        ctx = yield self._get_context(
            event, prev_event_id1, first_state, prev_event_id2, second_state
        )

        current_ids = yield defer.ensureDeferred(ctx.get_current_state_ids())

        self.assertEqual(len(current_ids), 6)

        self.assertIsNotNone(ctx.state_group)

    @defer.inlineCallbacks
    def test_standard_depth_conflict(self):
        # The two states disagree only on ("test1", "1"), with different
        # depths; the test expects the deeper event to be chosen.
        prev_event_id1 = "event_id1"
        prev_event_id2 = "event_id2"
        event = create_event(
            type="test4",
            name="event",
            prev_events=[(prev_event_id1, {}), (prev_event_id2, {})],
        )

        member_event = create_event(
            type=EventTypes.Member,
            state_key="@user_id:example.com",
            content={"membership": Membership.JOIN},
        )

        power_levels = create_event(
            type=EventTypes.PowerLevels,
            state_key="",
            content={
                "users": {"@foo:bar": "100", "@user_id:example.com": "100"}
            },
        )

        creation = create_event(
            type=EventTypes.Create,
            state_key="",
            content={"creator": "@foo:bar"},
        )

        first_state = [
            creation,
            power_levels,
            member_event,
            create_event(type="test1", state_key="1", depth=1),
        ]
        second_state = [
            creation,
            power_levels,
            member_event,
            create_event(type="test1", state_key="1", depth=2),
        ]

        # Events are registered on a separate store whose get_events is
        # then patched onto self.store.
        event_store = StateGroupStore()
        event_store.register_events(first_state)
        event_store.register_events(second_state)
        self.store.get_events = event_store.get_events

        ctx = yield self._get_context(
            event, prev_event_id1, first_state, prev_event_id2, second_state
        )

        current_ids = yield defer.ensureDeferred(ctx.get_current_state_ids())

        self.assertEqual(
            second_state[3].event_id, current_ids[("test1", "1")]
        )

        # Reverse the depth to make sure we are actually using the depths
        # during state resolution.

        first_state = [
            creation,
            power_levels,
            member_event,
            create_event(type="test1", state_key="1", depth=2),
        ]
        second_state = [
            creation,
            power_levels,
            member_event,
            create_event(type="test1", state_key="1", depth=1),
        ]

        event_store.register_events(first_state)
        event_store.register_events(second_state)

        ctx = yield self._get_context(
            event, prev_event_id1, first_state, prev_event_id2, second_state
        )

        current_ids = yield defer.ensureDeferred(ctx.get_current_state_ids())

        self.assertEqual(
            first_state[3].event_id, current_ids[("test1", "1")]
        )

    @defer.inlineCallbacks
    def _get_context(
        self, event, prev_event_id_1, old_state_1, prev_event_id_2, old_state_2
    ):
        """Compute ``event``'s context after storing state for both prev events.

        Each old-state list is stored as a state group and registered
        against its prev event id (first, then second) before
        ``compute_event_context(event)`` is invoked.
        """
        yield self._store_prev_state(
            prev_event_id_1, event.room_id, old_state_1
        )
        yield self._store_prev_state(
            prev_event_id_2, event.room_id, old_state_2
        )

        context = yield defer.ensureDeferred(
            self.state.compute_event_context(event)
        )
        return context
class StateTestCase(unittest.TestCase):
    """Tests for ``StateHandler`` against mocked persistence and replication.

    Most tests stub ``persistence.get_unresolved_state_tree`` with a
    ``(ReturnType(first, second), marker)`` tuple and then check the result
    of ``handle_new_state``.
    NOTE(review): ``first``/``second`` look like the new and the current
    branch, and ``marker`` like the index of the branch with a missing PDU
    (``None`` when nothing is missing) -- confirm against ``StateHandler``.
    """

    def setUp(self):
        """Build a ``StateHandler`` wired to mock datastore/replication layers."""
        self.persistence = Mock(spec=[
            "get_unresolved_state_tree",
            "update_current_state",
            "get_latest_pdus_in_context",
            "get_current_state_pdu",
            "get_pdu",
            "get_power_level",
        ])

        self.replication = Mock(spec=["get_pdu"])

        hs = Mock(spec=["get_datastore", "get_replication_layer"])
        hs.get_datastore.return_value = self.persistence
        hs.get_replication_layer.return_value = self.replication
        hs.hostname = "bob.com"

        self.state = StateHandler(hs)

    def _assert_resolved_locally(self, new_pdu, expected_updates):
        """Assert the state tree was read once and no PDU was fetched remotely.

        Args:
            new_pdu: the PDU ``get_unresolved_state_tree`` must have been
                called with, exactly once.
            expected_updates: expected ``update_current_state`` call count.
        """
        self.persistence.get_unresolved_state_tree.assert_called_once_with(
            new_pdu
        )
        self.assertEqual(
            expected_updates, self.persistence.update_current_state.call_count
        )
        self.assertFalse(self.replication.get_pdu.called)

    @defer.inlineCallbacks
    def test_new_state_key(self):
        # We've never seen anything for this state before
        new_pdu = new_fake_pdu("A", "test", "mem", "x", None, "u")

        self.persistence.get_power_level.side_effect = _gen_get_power_level({})

        self.persistence.get_unresolved_state_tree.return_value = (
            ReturnType([new_pdu], []), None
        )

        is_new = yield self.state.handle_new_state(new_pdu)

        self.assertTrue(is_new)
        self._assert_resolved_locally(new_pdu, expected_updates=1)

    @defer.inlineCallbacks
    def test_direct_overwrite(self):
        # We do a direct overwriting of the old state, i.e., the new state
        # points to the old state.
        old_pdu = new_fake_pdu("A", "test", "mem", "x", None, "u1")
        new_pdu = new_fake_pdu("B", "test", "mem", "x", "A", "u2")

        self.persistence.get_power_level.side_effect = _gen_get_power_level({
            "u1": 10,
            "u2": 5,
        })

        self.persistence.get_unresolved_state_tree.return_value = (
            ReturnType([new_pdu, old_pdu], [old_pdu]), None
        )

        is_new = yield self.state.handle_new_state(new_pdu)

        self.assertTrue(is_new)
        self._assert_resolved_locally(new_pdu, expected_updates=1)

    @defer.inlineCallbacks
    def test_overwrite(self):
        # The new state extends a two-step chain (C -> B -> A) whose oldest
        # entry is the one on the other branch.
        old_pdu_1 = new_fake_pdu("A", "test", "mem", "x", None, "u1")
        old_pdu_2 = new_fake_pdu("B", "test", "mem", "x", "A", "u2")
        new_pdu = new_fake_pdu("C", "test", "mem", "x", "B", "u3")

        self.persistence.get_power_level.side_effect = _gen_get_power_level({
            "u1": 10,
            "u2": 5,
            "u3": 0,
        })

        self.persistence.get_unresolved_state_tree.return_value = (
            ReturnType([new_pdu, old_pdu_2, old_pdu_1], [old_pdu_1]), None
        )

        is_new = yield self.state.handle_new_state(new_pdu)

        self.assertTrue(is_new)
        self._assert_resolved_locally(new_pdu, expected_updates=1)

    @defer.inlineCallbacks
    def test_power_level_fail(self):
        # We try to update the state based on an outdated state, and have a
        # too low power level.
        old_pdu_1 = new_fake_pdu("A", "test", "mem", "x", None, "u1")
        old_pdu_2 = new_fake_pdu("B", "test", "mem", "x", None, "u2")
        new_pdu = new_fake_pdu("C", "test", "mem", "x", "A", "u3")

        self.persistence.get_power_level.side_effect = _gen_get_power_level({
            "u1": 10,
            "u2": 10,
            "u3": 5,
        })

        self.persistence.get_unresolved_state_tree.return_value = (
            ReturnType([new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1]), None
        )

        is_new = yield self.state.handle_new_state(new_pdu)

        self.assertFalse(is_new)
        self._assert_resolved_locally(new_pdu, expected_updates=0)

    @defer.inlineCallbacks
    def test_power_level_succeed(self):
        # We try to update the state based on an outdated state, but have
        # sufficient power level to force the update.
        old_pdu_1 = new_fake_pdu("A", "test", "mem", "x", None, "u1")
        old_pdu_2 = new_fake_pdu("B", "test", "mem", "x", None, "u2")
        new_pdu = new_fake_pdu("C", "test", "mem", "x", "A", "u3")

        self.persistence.get_power_level.side_effect = _gen_get_power_level({
            "u1": 10,
            "u2": 10,
            "u3": 15,
        })

        self.persistence.get_unresolved_state_tree.return_value = (
            ReturnType([new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1]), None
        )

        is_new = yield self.state.handle_new_state(new_pdu)

        self.assertTrue(is_new)
        self._assert_resolved_locally(new_pdu, expected_updates=1)

    @defer.inlineCallbacks
    def test_power_level_equal_same_len(self):
        # We try to update the state based on an outdated state, the power
        # levels are the same and so are the branch lengths
        old_pdu_1 = new_fake_pdu("A", "test", "mem", "x", None, "u1")
        old_pdu_2 = new_fake_pdu("B", "test", "mem", "x", None, "u2")
        new_pdu = new_fake_pdu("C", "test", "mem", "x", "A", "u3")

        self.persistence.get_power_level.side_effect = _gen_get_power_level({
            "u1": 10,
            "u2": 10,
            "u3": 10,
        })

        self.persistence.get_unresolved_state_tree.return_value = (
            ReturnType([new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1]), None
        )

        is_new = yield self.state.handle_new_state(new_pdu)

        self.assertTrue(is_new)
        self._assert_resolved_locally(new_pdu, expected_updates=1)

    @defer.inlineCallbacks
    def test_power_level_equal_diff_len(self):
        # We try to update the state based on an outdated state, the power
        # levels are the same but the branch length of the new one is longer.
        old_pdu_1 = new_fake_pdu("A", "test", "mem", "x", None, "u1")
        old_pdu_2 = new_fake_pdu("B", "test", "mem", "x", None, "u2")
        old_pdu_3 = new_fake_pdu("C", "test", "mem", "x", "A", "u3")
        new_pdu = new_fake_pdu("D", "test", "mem", "x", "C", "u4")

        self.persistence.get_power_level.side_effect = _gen_get_power_level({
            "u1": 10,
            "u2": 10,
            "u3": 10,
            "u4": 10,
        })

        self.persistence.get_unresolved_state_tree.return_value = (
            ReturnType([new_pdu, old_pdu_3, old_pdu_1], [old_pdu_2, old_pdu_1]),
            None,
        )

        is_new = yield self.state.handle_new_state(new_pdu)

        self.assertTrue(is_new)
        self._assert_resolved_locally(new_pdu, expected_updates=1)

    @defer.inlineCallbacks
    def test_missing_pdu(self):
        # We try to update state against a PDU we haven't yet seen,
        # triggering a get_pdu request

        # The pdu we haven't seen
        old_pdu_1 = new_fake_pdu("A", "test", "mem", "x", None, "u1", depth=0)

        old_pdu_2 = new_fake_pdu("B", "test", "mem", "x", "A", "u2", depth=1)
        new_pdu = new_fake_pdu("C", "test", "mem", "x", "A", "u3", depth=2)

        self.persistence.get_power_level.side_effect = _gen_get_power_level({
            "u1": 10,
            "u2": 10,
            "u3": 20,
        })

        # The return_value of `get_unresolved_state_tree`, which changes after
        # the call to get_pdu
        tree_to_return = [(ReturnType([new_pdu], [old_pdu_2]), 0)]

        def return_tree(p):
            return tree_to_return[0]

        def set_return_tree(destination, pdu_origin, pdu_id, outlier=False):
            tree_to_return[0] = (
                ReturnType([new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1]),
                None,
            )
            return defer.succeed(None)

        self.persistence.get_unresolved_state_tree.side_effect = return_tree

        self.replication.get_pdu.side_effect = set_return_tree

        self.persistence.get_pdu.return_value = None

        is_new = yield self.state.handle_new_state(new_pdu)

        self.assertTrue(is_new)

        self.replication.get_pdu.assert_called_with(
            destination=new_pdu.origin,
            pdu_origin=old_pdu_1.origin,
            pdu_id=old_pdu_1.pdu_id,
            outlier=True,
        )

        self.persistence.get_unresolved_state_tree.assert_called_with(new_pdu)

        self.assertEqual(
            2, self.persistence.get_unresolved_state_tree.call_count
        )

        self.assertEqual(1, self.persistence.update_current_state.call_count)

    @defer.inlineCallbacks
    def test_missing_pdu_depth_1(self):
        # We try to update state against a PDU we haven't yet seen,
        # triggering a get_pdu request

        # The pdu we haven't seen
        old_pdu_1 = new_fake_pdu("A", "test", "mem", "x", None, "u1", depth=0)

        old_pdu_2 = new_fake_pdu("B", "test", "mem", "x", "A", "u2", depth=2)
        old_pdu_3 = new_fake_pdu("C", "test", "mem", "x", "B", "u3", depth=3)
        new_pdu = new_fake_pdu("D", "test", "mem", "x", "A", "u4", depth=4)

        self.persistence.get_power_level.side_effect = _gen_get_power_level({
            "u1": 10,
            "u2": 10,
            "u3": 10,
            "u4": 20,
        })

        # The return_value of `get_unresolved_state_tree`, which changes after
        # the call to get_pdu
        tree_to_return = [
            (ReturnType([new_pdu], [old_pdu_3]), 0),
            (ReturnType([new_pdu, old_pdu_1], [old_pdu_3]), 1),
            (
                ReturnType(
                    [new_pdu, old_pdu_1], [old_pdu_3, old_pdu_2, old_pdu_1]
                ),
                None,
            ),
        ]

        # Index into tree_to_return; advanced by every get_pdu call.
        to_return = [0]

        def return_tree(p):
            return tree_to_return[to_return[0]]

        def set_return_tree(destination, pdu_origin, pdu_id, outlier=False):
            to_return[0] += 1
            return defer.succeed(None)

        self.persistence.get_unresolved_state_tree.side_effect = return_tree

        self.replication.get_pdu.side_effect = set_return_tree

        self.persistence.get_pdu.return_value = None

        is_new = yield self.state.handle_new_state(new_pdu)

        self.assertTrue(is_new)

        self.assertEqual(2, self.replication.get_pdu.call_count)

        self.replication.get_pdu.assert_has_calls([
            mock.call(
                destination=new_pdu.origin,
                pdu_origin=old_pdu_1.origin,
                pdu_id=old_pdu_1.pdu_id,
                outlier=True,
            ),
            mock.call(
                destination=old_pdu_3.origin,
                pdu_origin=old_pdu_2.origin,
                pdu_id=old_pdu_2.pdu_id,
                outlier=True,
            ),
        ])

        self.persistence.get_unresolved_state_tree.assert_called_with(new_pdu)

        self.assertEqual(
            3, self.persistence.get_unresolved_state_tree.call_count
        )

        self.assertEqual(1, self.persistence.update_current_state.call_count)

    @defer.inlineCallbacks
    def test_missing_pdu_depth_2(self):
        # We try to update state against a PDU we haven't yet seen,
        # triggering a get_pdu request

        # The pdu we haven't seen
        old_pdu_1 = new_fake_pdu("A", "test", "mem", "x", None, "u1", depth=0)

        old_pdu_2 = new_fake_pdu("B", "test", "mem", "x", "A", "u2", depth=2)
        old_pdu_3 = new_fake_pdu("C", "test", "mem", "x", "B", "u3", depth=3)
        new_pdu = new_fake_pdu("D", "test", "mem", "x", "A", "u4", depth=1)

        self.persistence.get_power_level.side_effect = _gen_get_power_level({
            "u1": 10,
            "u2": 10,
            "u3": 10,
            "u4": 20,
        })

        # The return_value of `get_unresolved_state_tree`, which changes after
        # the call to get_pdu
        tree_to_return = [
            (ReturnType([new_pdu], [old_pdu_3]), 1),
            (ReturnType([new_pdu], [old_pdu_3, old_pdu_2]), 0),
            (
                ReturnType(
                    [new_pdu, old_pdu_1], [old_pdu_3, old_pdu_2, old_pdu_1]
                ),
                None,
            ),
        ]

        # Index into tree_to_return; advanced by every get_pdu call.
        to_return = [0]

        def return_tree(p):
            return tree_to_return[to_return[0]]

        def set_return_tree(destination, pdu_origin, pdu_id, outlier=False):
            to_return[0] += 1
            return defer.succeed(None)

        self.persistence.get_unresolved_state_tree.side_effect = return_tree

        self.replication.get_pdu.side_effect = set_return_tree

        self.persistence.get_pdu.return_value = None

        is_new = yield self.state.handle_new_state(new_pdu)

        self.assertTrue(is_new)

        self.assertEqual(2, self.replication.get_pdu.call_count)

        self.replication.get_pdu.assert_has_calls([
            mock.call(
                destination=old_pdu_3.origin,
                pdu_origin=old_pdu_2.origin,
                pdu_id=old_pdu_2.pdu_id,
                outlier=True,
            ),
            mock.call(
                destination=new_pdu.origin,
                pdu_origin=old_pdu_1.origin,
                pdu_id=old_pdu_1.pdu_id,
                outlier=True,
            ),
        ])

        self.persistence.get_unresolved_state_tree.assert_called_with(new_pdu)

        self.assertEqual(
            3, self.persistence.get_unresolved_state_tree.call_count
        )

        self.assertEqual(1, self.persistence.update_current_state.call_count)

    @defer.inlineCallbacks
    def test_no_common_ancestor(self):
        # Neither PDU points at the other (both have no previous state), so
        # the two branches share no ancestor.  The new PDU comes from the
        # user with the higher power level and is expected to be accepted.
        old_pdu = new_fake_pdu("A", "test", "mem", "x", None, "u1")
        new_pdu = new_fake_pdu("B", "test", "mem", "x", None, "u2")

        self.persistence.get_power_level.side_effect = _gen_get_power_level({
            "u1": 5,
            "u2": 10,
        })

        self.persistence.get_unresolved_state_tree.return_value = (
            ReturnType([new_pdu], [old_pdu]), None
        )

        is_new = yield self.state.handle_new_state(new_pdu)

        self.assertTrue(is_new)
        self._assert_resolved_locally(new_pdu, expected_updates=1)

    @defer.inlineCallbacks
    def test_new_event(self):
        # handle_new_event should leave the event with the prev_events and
        # depth filled in by the snapshot, and with prev_state pointing at
        # the snapshot's prev_state_pdu.
        event = Mock()
        event.event_id = "12123123@test"

        state_pdu = new_fake_pdu("C", "test", "mem", "x", "A", 20)

        snapshot = Mock()
        snapshot.prev_state_pdu = state_pdu
        event_id = "*****@*****.**"

        def fill_out_prev_events(event):
            event.prev_events = [event_id]
            event.depth = 6

        snapshot.fill_out_prev_events = fill_out_prev_events

        yield self.state.handle_new_event(event, snapshot)

        self.assertLess(5, event.depth)

        self.assertEqual(1, len(event.prev_events))

        prev_id = event.prev_events[0]

        self.assertEqual(event_id, prev_id)

        self.assertEqual(
            encode_event_id(state_pdu.pdu_id, state_pdu.origin),
            event.prev_state,
        )
class StateTestCase(unittest.TestCase):
    """State-resolution tests for ``StateHandler`` using a ``StateGroupStore``."""

    def setUp(self):
        """Create the store and a mocked homeserver, then build the handler."""
        self.store = StateGroupStore()

        homeserver_attrs = [
            "get_datastore",
            "get_auth",
            "get_state_handler",
            "get_clock",
            "get_state_resolution_handler",
        ]
        hs = Mock(spec_set=homeserver_attrs)
        hs.get_datastore.return_value = self.store
        hs.get_state_handler.return_value = None
        hs.get_clock.return_value = MockClock()
        hs.get_auth.return_value = Auth(hs)
        hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs)

        self.state = StateHandler(hs)
        self.event_id = 0

    @defer.inlineCallbacks
    def _compute_contexts(self, graph):
        """Register ``graph``'s events and compute a context for each of them.

        Events are taken from ``graph.walk()``; each computed context is
        registered with the store before the next event is processed.
        Fires with a dict mapping event id to its context.
        """
        self.store.register_events(graph.walk())

        contexts = {}
        for ev in graph.walk():
            ctx = yield self.state.compute_event_context(ev)
            self.store.register_event_context(ev, ctx)
            contexts[ev.event_id] = ctx

        defer.returnValue(contexts)

    def _make_old_state(self):
        """Return three fresh state events: two ``test1`` and one ``test2``."""
        return [
            create_event(type="test1", state_key="1"),
            create_event(type="test1", state_key="2"),
            create_event(type="test2", state_key=""),
        ]

    def _make_conflicting_states(self):
        """Return two state lists sharing a create event but otherwise fresh."""
        creation = create_event(type=EventTypes.Create, state_key="")

        first = [
            creation,
            create_event(type="test1", state_key="1"),
            create_event(type="test1", state_key="2"),
            create_event(type="test2", state_key=""),
        ]
        second = [
            creation,
            create_event(type="test1", state_key="1"),
            create_event(type="test3", state_key="2"),
            create_event(type="test4", state_key=""),
        ]
        return first, second

    @defer.inlineCallbacks
    def test_branch_no_conflict(self):
        nodes = {
            "START": DictObj(type=EventTypes.Create, state_key="", depth=1),
            "A": DictObj(type=EventTypes.Message, depth=2),
            "B": DictObj(type=EventTypes.Message, depth=3),
            "C": DictObj(type=EventTypes.Name, state_key="", depth=3),
            "D": DictObj(type=EventTypes.Message, depth=4),
        }
        edges = {"A": ["START"], "B": ["A"], "C": ["A"], "D": ["B", "C"]}

        contexts = yield self._compute_contexts(Graph(nodes=nodes, edges=edges))

        state_before_d = yield contexts["D"].get_prev_state_ids(self.store)
        self.assertEqual(2, len(state_before_d))

    @defer.inlineCallbacks
    def test_branch_basic_conflict(self):
        nodes = {
            "START": DictObj(
                type=EventTypes.Create,
                state_key="",
                content={"creator": "@user_id:example.com"},
                depth=1,
            ),
            "A": DictObj(
                type=EventTypes.Member,
                state_key="@user_id:example.com",
                content={"membership": Membership.JOIN},
                membership=Membership.JOIN,
                depth=2,
            ),
            "B": DictObj(type=EventTypes.Name, state_key="", depth=3),
            "C": DictObj(type=EventTypes.Name, state_key="", depth=4),
            "D": DictObj(type=EventTypes.Message, depth=5),
        }
        edges = {"A": ["START"], "B": ["A"], "C": ["A"], "D": ["B", "C"]}

        contexts = yield self._compute_contexts(Graph(nodes=nodes, edges=edges))

        state_before_d = yield contexts["D"].get_prev_state_ids(self.store)
        self.assertSetEqual({"START", "A", "C"}, set(state_before_d.values()))

    @defer.inlineCallbacks
    def test_branch_have_banned_conflict(self):
        nodes = {
            "START": DictObj(
                type=EventTypes.Create,
                state_key="",
                content={"creator": "@user_id:example.com"},
                depth=1,
            ),
            "A": DictObj(
                type=EventTypes.Member,
                state_key="@user_id:example.com",
                content={"membership": Membership.JOIN},
                membership=Membership.JOIN,
                depth=2,
            ),
            "B": DictObj(type=EventTypes.Name, state_key="", depth=3),
            "C": DictObj(
                type=EventTypes.Member,
                state_key="@user_id_2:example.com",
                content={"membership": Membership.BAN},
                membership=Membership.BAN,
                depth=4,
            ),
            "D": DictObj(
                type=EventTypes.Name,
                state_key="",
                depth=4,
                sender="@user_id_2:example.com",
            ),
            "E": DictObj(type=EventTypes.Message, depth=5),
        }
        edges = {
            "A": ["START"],
            "B": ["A"],
            "C": ["B"],
            "D": ["B"],
            "E": ["C", "D"],
        }

        contexts = yield self._compute_contexts(Graph(nodes=nodes, edges=edges))

        state_before_e = yield contexts["E"].get_prev_state_ids(self.store)
        self.assertSetEqual(
            {"START", "A", "B", "C"}, set(state_before_e.values())
        )

    @defer.inlineCallbacks
    def test_branch_have_perms_conflict(self):
        userid1 = "@user_id:example.com"
        userid2 = "@user_id2:example.com"

        name_event_levels = {"m.room.name": 50}

        nodes = {
            "A1": DictObj(
                type=EventTypes.Create,
                state_key="",
                content={"creator": userid1},
                depth=1,
            ),
            "A2": DictObj(
                type=EventTypes.Member,
                state_key=userid1,
                content={"membership": Membership.JOIN},
                membership=Membership.JOIN,
            ),
            "A3": DictObj(
                type=EventTypes.Member,
                state_key=userid2,
                content={"membership": Membership.JOIN},
                membership=Membership.JOIN,
            ),
            "A4": DictObj(
                type=EventTypes.PowerLevels,
                state_key="",
                content={
                    "events": dict(name_event_levels),
                    "users": {userid1: 100, userid2: 60},
                },
            ),
            "A5": DictObj(type=EventTypes.Name, state_key=""),
            "B": DictObj(
                type=EventTypes.PowerLevels,
                state_key="",
                content={
                    "events": dict(name_event_levels),
                    "users": {userid2: 30},
                },
            ),
            "C": DictObj(type=EventTypes.Name, state_key="", sender=userid2),
            "D": DictObj(type=EventTypes.Message),
        }
        edges = {
            "A2": ["A1"],
            "A3": ["A2"],
            "A4": ["A3"],
            "A5": ["A4"],
            "B": ["A5"],
            "C": ["A5"],
            "D": ["B", "C"],
        }
        # Only A1 carries an explicit depth; derive the rest from the edges.
        self._add_depths(nodes, edges)

        contexts = yield self._compute_contexts(Graph(nodes, edges))

        state_before_d = yield contexts["D"].get_prev_state_ids(self.store)
        self.assertSetEqual(
            {"A1", "A2", "A3", "A5", "B"}, set(state_before_d.values())
        )

    def _add_depths(self, nodes, edges):
        """Fill in a ``depth`` on every node that lacks one.

        A missing depth is one more than the largest depth among the node's
        predecessors in ``edges``; nodes are updated in place.
        """
        def resolve(name):
            entry = nodes[name]
            if 'depth' in entry:
                return entry['depth']
            entry['depth'] = max(resolve(parent) for parent in edges[name]) + 1
            return entry['depth']

        for name in nodes:
            resolve(name)

    @defer.inlineCallbacks
    def test_annotate_with_old_message(self):
        event = create_event(type="test_message", name="event")
        old_state = self._make_old_state()

        context = yield self.state.compute_event_context(
            event, old_state=old_state
        )
        current_state_ids = yield context.get_current_state_ids(self.store)

        expected_ids = {e.event_id for e in old_state}
        self.assertEqual(expected_ids, set(current_state_ids.values()))
        self.assertIsNotNone(context.state_group)

    @defer.inlineCallbacks
    def test_annotate_with_old_state(self):
        event = create_event(type="state", state_key="", name="event")
        old_state = self._make_old_state()

        context = yield self.state.compute_event_context(
            event, old_state=old_state
        )
        prev_state_ids = yield context.get_prev_state_ids(self.store)

        expected_ids = {e.event_id for e in old_state}
        self.assertEqual(expected_ids, set(prev_state_ids.values()))

    @defer.inlineCallbacks
    def test_trivial_annotate_message(self):
        prev_event_id = "prev_event_id"
        event = create_event(
            type="test_message",
            name="event2",
            prev_events=[(prev_event_id, {})],
        )
        old_state = self._make_old_state()

        state_map = {(e.type, e.state_key): e.event_id for e in old_state}
        group_name = self.store.store_state_group(
            prev_event_id, event.room_id, None, None, state_map
        )
        self.store.register_event_id_state_group(prev_event_id, group_name)

        context = yield self.state.compute_event_context(event)
        current_state_ids = yield context.get_current_state_ids(self.store)

        expected_ids = {e.event_id for e in old_state}
        self.assertEqual(expected_ids, set(current_state_ids.values()))
        # A non-state event is expected to reuse its prev event's group.
        self.assertEqual(group_name, context.state_group)

    @defer.inlineCallbacks
    def test_trivial_annotate_state(self):
        prev_event_id = "prev_event_id"
        event = create_event(
            type="state",
            state_key="",
            name="event2",
            prev_events=[(prev_event_id, {})],
        )
        old_state = self._make_old_state()

        state_map = {(e.type, e.state_key): e.event_id for e in old_state}
        group_name = self.store.store_state_group(
            prev_event_id, event.room_id, None, None, state_map
        )
        self.store.register_event_id_state_group(prev_event_id, group_name)

        context = yield self.state.compute_event_context(event)
        prev_state_ids = yield context.get_prev_state_ids(self.store)

        expected_ids = {e.event_id for e in old_state}
        self.assertEqual(expected_ids, set(prev_state_ids.values()))
        self.assertIsNotNone(context.state_group)

    @defer.inlineCallbacks
    def test_resolve_message_conflict(self):
        first_prev, second_prev = "event_id1", "event_id2"
        event = create_event(
            type="test_message",
            name="event3",
            prev_events=[(first_prev, {}), (second_prev, {})],
        )
        old_state_1, old_state_2 = self._make_conflicting_states()

        for old_state in (old_state_1, old_state_2):
            self.store.register_events(old_state)

        context = yield self._get_context(
            event, first_prev, old_state_1, second_prev, old_state_2
        )
        current_state_ids = yield context.get_current_state_ids(self.store)

        self.assertEqual(len(current_state_ids), 6)
        self.assertIsNotNone(context.state_group)

    @defer.inlineCallbacks
    def test_resolve_state_conflict(self):
        first_prev, second_prev = "event_id1", "event_id2"
        event = create_event(
            type="test4",
            state_key="",
            name="event",
            prev_events=[(first_prev, {}), (second_prev, {})],
        )
        old_state_1, old_state_2 = self._make_conflicting_states()

        # Events are registered on a separate store whose get_events then
        # replaces the one on self.store.
        resolver_store = StateGroupStore()
        for old_state in (old_state_1, old_state_2):
            resolver_store.register_events(old_state)
        self.store.get_events = resolver_store.get_events

        context = yield self._get_context(
            event, first_prev, old_state_1, second_prev, old_state_2
        )
        current_state_ids = yield context.get_current_state_ids(self.store)

        self.assertEqual(len(current_state_ids), 6)
        self.assertIsNotNone(context.state_group)

    @defer.inlineCallbacks
    def test_standard_depth_conflict(self):
        first_prev, second_prev = "event_id1", "event_id2"
        event = create_event(
            type="test4",
            name="event",
            prev_events=[(first_prev, {}), (second_prev, {})],
        )

        member_event = create_event(
            type=EventTypes.Member,
            state_key="@user_id:example.com",
            content={"membership": Membership.JOIN},
        )
        power_levels = create_event(
            type=EventTypes.PowerLevels,
            state_key="",
            content={
                "users": {
                    "@foo:bar": "100",
                    "@user_id:example.com": "100",
                }
            },
        )
        creation = create_event(
            type=EventTypes.Create,
            state_key="",
            content={"creator": "@foo:bar"},
        )

        # State shared by both branches; only the fourth entry differs.
        shared_state = [creation, power_levels, member_event]

        resolver_store = StateGroupStore()
        self.store.get_events = resolver_store.get_events

        # (depths of the test1 event on branch 1 and 2, index of the branch
        # whose test1 event is expected to be selected).  The second round
        # reverses the depths to make sure they are actually being used
        # during state resolution.
        for depths, winner in (((1, 2), 1), ((2, 1), 0)):
            branches = [
                shared_state
                + [create_event(type="test1", state_key="1", depth=depth)]
                for depth in depths
            ]
            for branch in branches:
                resolver_store.register_events(branch)

            context = yield self._get_context(
                event, first_prev, branches[0], second_prev, branches[1]
            )
            current_state_ids = yield context.get_current_state_ids(self.store)

            self.assertEqual(
                branches[winner][3].event_id,
                current_state_ids[("test1", "1")],
            )

    def _get_context(self, event, prev_event_id_1, old_state_1,
                     prev_event_id_2, old_state_2):
        """Store a state group per prev event, then compute ``event``'s context.

        Returns whatever ``self.state.compute_event_context(event)`` returns
        (callers ``yield`` on it).
        """
        pairs = (
            (prev_event_id_1, old_state_1),
            (prev_event_id_2, old_state_2),
        )
        for prev_id, old_state in pairs:
            state_map = {(e.type, e.state_key): e.event_id for e in old_state}
            group = self.store.store_state_group(
                prev_id, event.room_id, None, None, state_map
            )
            self.store.register_event_id_state_group(prev_id, group)

        return self.state.compute_event_context(event)
def build_state_handler(self):
    """Construct a new ``StateHandler``, passing this object as its argument."""
    handler = StateHandler(self)
    return handler
class StateTestCase(unittest.TestCase): def setUp(self): self.store = StateGroupStore() hs = Mock( spec_set=[ "config", "get_datastore", "get_auth", "get_state_handler", "get_clock", "get_state_resolution_handler", ] ) hs.config = default_config("tesths", True) hs.get_datastore.return_value = self.store hs.get_state_handler.return_value = None hs.get_clock.return_value = MockClock() hs.get_auth.return_value = Auth(hs) hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs) self.state = StateHandler(hs) self.event_id = 0 @defer.inlineCallbacks def test_branch_no_conflict(self): graph = Graph( nodes={ "START": DictObj( type=EventTypes.Create, state_key="", content={}, depth=1 ), "A": DictObj(type=EventTypes.Message, depth=2), "B": DictObj(type=EventTypes.Message, depth=3), "C": DictObj(type=EventTypes.Name, state_key="", depth=3), "D": DictObj(type=EventTypes.Message, depth=4), }, edges={"A": ["START"], "B": ["A"], "C": ["A"], "D": ["B", "C"]}, ) self.store.register_events(graph.walk()) context_store = {} for event in graph.walk(): context = yield self.state.compute_event_context(event) self.store.register_event_context(event, context) context_store[event.event_id] = context prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store) self.assertEqual(2, len(prev_state_ids)) @defer.inlineCallbacks def test_branch_basic_conflict(self): graph = Graph( nodes={ "START": DictObj( type=EventTypes.Create, state_key="", content={"creator": "@user_id:example.com"}, depth=1, ), "A": DictObj( type=EventTypes.Member, state_key="@user_id:example.com", content={"membership": Membership.JOIN}, membership=Membership.JOIN, depth=2, ), "B": DictObj(type=EventTypes.Name, state_key="", depth=3), "C": DictObj(type=EventTypes.Name, state_key="", depth=4), "D": DictObj(type=EventTypes.Message, depth=5), }, edges={"A": ["START"], "B": ["A"], "C": ["A"], "D": ["B", "C"]}, ) self.store.register_events(graph.walk()) context_store = {} for event in graph.walk(): context 
= yield self.state.compute_event_context(event) self.store.register_event_context(event, context) context_store[event.event_id] = context prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store) self.assertSetEqual( {"START", "A", "C"}, {e_id for e_id in prev_state_ids.values()} ) @defer.inlineCallbacks def test_branch_have_banned_conflict(self): graph = Graph( nodes={ "START": DictObj( type=EventTypes.Create, state_key="", content={"creator": "@user_id:example.com"}, depth=1, ), "A": DictObj( type=EventTypes.Member, state_key="@user_id:example.com", content={"membership": Membership.JOIN}, membership=Membership.JOIN, depth=2, ), "B": DictObj(type=EventTypes.Name, state_key="", depth=3), "C": DictObj( type=EventTypes.Member, state_key="@user_id_2:example.com", content={"membership": Membership.BAN}, membership=Membership.BAN, depth=4, ), "D": DictObj( type=EventTypes.Name, state_key="", depth=4, sender="@user_id_2:example.com", ), "E": DictObj(type=EventTypes.Message, depth=5), }, edges={"A": ["START"], "B": ["A"], "C": ["B"], "D": ["B"], "E": ["C", "D"]}, ) self.store.register_events(graph.walk()) context_store = {} for event in graph.walk(): context = yield self.state.compute_event_context(event) self.store.register_event_context(event, context) context_store[event.event_id] = context prev_state_ids = yield context_store["E"].get_prev_state_ids(self.store) self.assertSetEqual( {"START", "A", "B", "C"}, {e for e in prev_state_ids.values()} ) @defer.inlineCallbacks def test_branch_have_perms_conflict(self): userid1 = "@user_id:example.com" userid2 = "@user_id2:example.com" nodes = { "A1": DictObj( type=EventTypes.Create, state_key="", content={"creator": userid1}, depth=1, ), "A2": DictObj( type=EventTypes.Member, state_key=userid1, content={"membership": Membership.JOIN}, membership=Membership.JOIN, ), "A3": DictObj( type=EventTypes.Member, state_key=userid2, content={"membership": Membership.JOIN}, membership=Membership.JOIN, ), "A4": DictObj( 
type=EventTypes.PowerLevels, state_key="", content={ "events": {"m.room.name": 50}, "users": {userid1: 100, userid2: 60}, }, ), "A5": DictObj(type=EventTypes.Name, state_key=""), "B": DictObj( type=EventTypes.PowerLevels, state_key="", content={"events": {"m.room.name": 50}, "users": {userid2: 30}}, ), "C": DictObj(type=EventTypes.Name, state_key="", sender=userid2), "D": DictObj(type=EventTypes.Message), } edges = { "A2": ["A1"], "A3": ["A2"], "A4": ["A3"], "A5": ["A4"], "B": ["A5"], "C": ["A5"], "D": ["B", "C"], } self._add_depths(nodes, edges) graph = Graph(nodes, edges) self.store.register_events(graph.walk()) context_store = {} for event in graph.walk(): context = yield self.state.compute_event_context(event) self.store.register_event_context(event, context) context_store[event.event_id] = context prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store) self.assertSetEqual( {"A1", "A2", "A3", "A5", "B"}, {e for e in prev_state_ids.values()} ) def _add_depths(self, nodes, edges): def _get_depth(ev): node = nodes[ev] if 'depth' not in node: prevs = edges[ev] depth = max(_get_depth(prev) for prev in prevs) + 1 node['depth'] = depth return node['depth'] for n in nodes: _get_depth(n) @defer.inlineCallbacks def test_annotate_with_old_message(self): event = create_event(type="test_message", name="event") old_state = [ create_event(type="test1", state_key="1"), create_event(type="test1", state_key="2"), create_event(type="test2", state_key=""), ] context = yield self.state.compute_event_context(event, old_state=old_state) current_state_ids = yield context.get_current_state_ids(self.store) self.assertEqual( set(e.event_id for e in old_state), set(current_state_ids.values()) ) self.assertIsNotNone(context.state_group) @defer.inlineCallbacks def test_annotate_with_old_state(self): event = create_event(type="state", state_key="", name="event") old_state = [ create_event(type="test1", state_key="1"), create_event(type="test1", state_key="2"), 
create_event(type="test2", state_key=""), ] context = yield self.state.compute_event_context(event, old_state=old_state) prev_state_ids = yield context.get_prev_state_ids(self.store) self.assertEqual( set(e.event_id for e in old_state), set(prev_state_ids.values()) ) @defer.inlineCallbacks def test_trivial_annotate_message(self): prev_event_id = "prev_event_id" event = create_event( type="test_message", name="event2", prev_events=[(prev_event_id, {})] ) old_state = [ create_event(type="test1", state_key="1"), create_event(type="test1", state_key="2"), create_event(type="test2", state_key=""), ] group_name = self.store.store_state_group( prev_event_id, event.room_id, None, None, {(e.type, e.state_key): e.event_id for e in old_state}, ) self.store.register_event_id_state_group(prev_event_id, group_name) context = yield self.state.compute_event_context(event) current_state_ids = yield context.get_current_state_ids(self.store) self.assertEqual( set([e.event_id for e in old_state]), set(current_state_ids.values()) ) self.assertEqual(group_name, context.state_group) @defer.inlineCallbacks def test_trivial_annotate_state(self): prev_event_id = "prev_event_id" event = create_event( type="state", state_key="", name="event2", prev_events=[(prev_event_id, {})] ) old_state = [ create_event(type="test1", state_key="1"), create_event(type="test1", state_key="2"), create_event(type="test2", state_key=""), ] group_name = self.store.store_state_group( prev_event_id, event.room_id, None, None, {(e.type, e.state_key): e.event_id for e in old_state}, ) self.store.register_event_id_state_group(prev_event_id, group_name) context = yield self.state.compute_event_context(event) prev_state_ids = yield context.get_prev_state_ids(self.store) self.assertEqual( set([e.event_id for e in old_state]), set(prev_state_ids.values()) ) self.assertIsNotNone(context.state_group) @defer.inlineCallbacks def test_resolve_message_conflict(self): prev_event_id1 = "event_id1" prev_event_id2 = "event_id2" 
event = create_event( type="test_message", name="event3", prev_events=[(prev_event_id1, {}), (prev_event_id2, {})], ) creation = create_event(type=EventTypes.Create, state_key="") old_state_1 = [ creation, create_event(type="test1", state_key="1"), create_event(type="test1", state_key="2"), create_event(type="test2", state_key=""), ] old_state_2 = [ creation, create_event(type="test1", state_key="1"), create_event(type="test3", state_key="2"), create_event(type="test4", state_key=""), ] self.store.register_events(old_state_1) self.store.register_events(old_state_2) context = yield self._get_context( event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 ) current_state_ids = yield context.get_current_state_ids(self.store) self.assertEqual(len(current_state_ids), 6) self.assertIsNotNone(context.state_group) @defer.inlineCallbacks def test_resolve_state_conflict(self): prev_event_id1 = "event_id1" prev_event_id2 = "event_id2" event = create_event( type="test4", state_key="", name="event", prev_events=[(prev_event_id1, {}), (prev_event_id2, {})], ) creation = create_event(type=EventTypes.Create, state_key="") old_state_1 = [ creation, create_event(type="test1", state_key="1"), create_event(type="test1", state_key="2"), create_event(type="test2", state_key=""), ] old_state_2 = [ creation, create_event(type="test1", state_key="1"), create_event(type="test3", state_key="2"), create_event(type="test4", state_key=""), ] store = StateGroupStore() store.register_events(old_state_1) store.register_events(old_state_2) self.store.get_events = store.get_events context = yield self._get_context( event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 ) current_state_ids = yield context.get_current_state_ids(self.store) self.assertEqual(len(current_state_ids), 6) self.assertIsNotNone(context.state_group) @defer.inlineCallbacks def test_standard_depth_conflict(self): prev_event_id1 = "event_id1" prev_event_id2 = "event_id2" event = create_event( type="test4", 
name="event", prev_events=[(prev_event_id1, {}), (prev_event_id2, {})], ) member_event = create_event( type=EventTypes.Member, state_key="@user_id:example.com", content={"membership": Membership.JOIN}, ) power_levels = create_event( type=EventTypes.PowerLevels, state_key="", content={"users": {"@foo:bar": "100", "@user_id:example.com": "100"}}, ) creation = create_event( type=EventTypes.Create, state_key="", content={"creator": "@foo:bar"} ) old_state_1 = [ creation, power_levels, member_event, create_event(type="test1", state_key="1", depth=1), ] old_state_2 = [ creation, power_levels, member_event, create_event(type="test1", state_key="1", depth=2), ] store = StateGroupStore() store.register_events(old_state_1) store.register_events(old_state_2) self.store.get_events = store.get_events context = yield self._get_context( event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 ) current_state_ids = yield context.get_current_state_ids(self.store) self.assertEqual(old_state_2[3].event_id, current_state_ids[("test1", "1")]) # Reverse the depth to make sure we are actually using the depths # during state resolution. 
old_state_1 = [ creation, power_levels, member_event, create_event(type="test1", state_key="1", depth=2), ] old_state_2 = [ creation, power_levels, member_event, create_event(type="test1", state_key="1", depth=1), ] store.register_events(old_state_1) store.register_events(old_state_2) context = yield self._get_context( event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 ) current_state_ids = yield context.get_current_state_ids(self.store) self.assertEqual(old_state_1[3].event_id, current_state_ids[("test1", "1")]) def _get_context( self, event, prev_event_id_1, old_state_1, prev_event_id_2, old_state_2 ): sg1 = self.store.store_state_group( prev_event_id_1, event.room_id, None, None, {(e.type, e.state_key): e.event_id for e in old_state_1}, ) self.store.register_event_id_state_group(prev_event_id_1, sg1) sg2 = self.store.store_state_group( prev_event_id_2, event.room_id, None, None, {(e.type, e.state_key): e.event_id for e in old_state_2}, ) self.store.register_event_id_state_group(prev_event_id_2, sg2) return self.state.compute_event_context(event)
def get_state_handler(self) -> StateHandler:
    """Return a newly constructed ``StateHandler`` that is given this object."""
    handler = StateHandler(self)
    return handler
class StateTestCase(unittest.TestCase):
    """Exercise ``StateHandler.handle_new_state`` / ``handle_new_event``.

    The datastore and the replication layer are both replaced by
    spec-restricted mocks, so each test only controls what
    ``get_unresolved_state_tree`` returns and then inspects which mock
    methods the handler called.
    """

    def setUp(self):
        """Build a ``StateHandler`` around a mocked homeserver."""
        self.persistence = Mock(spec=[
            "get_unresolved_state_tree",
            "update_current_state",
            "get_latest_pdus_in_context",
            "get_current_state_pdu",
            "get_pdu",
        ])
        self.replication = Mock(spec=["get_pdu"])

        hs = Mock(spec=["get_datastore", "get_replication_layer"])
        hs.get_datastore.return_value = self.persistence
        hs.get_replication_layer.return_value = self.replication
        hs.hostname = "bob.com"

        self.state = StateHandler(hs)

    @defer.inlineCallbacks
    def test_new_state_key(self):
        """We've never seen anything for this state before."""
        new_pdu = new_fake_pdu_entry("A", "test", "mem", "x", None, 10)

        self.persistence.get_unresolved_state_tree.return_value = (
            ReturnType([new_pdu], [])
        )

        is_new = yield self.state.handle_new_state(new_pdu)

        self.assertTrue(is_new)

        self.persistence.get_unresolved_state_tree.assert_called_once_with(
            new_pdu
        )

        self.assertEqual(1, self.persistence.update_current_state.call_count)

        self.assertFalse(self.replication.get_pdu.called)

    @defer.inlineCallbacks
    def test_direct_overwrite(self):
        """Directly overwrite the old state: the new state points to the old."""
        old_pdu = new_fake_pdu_entry("A", "test", "mem", "x", None, 10)
        new_pdu = new_fake_pdu_entry("B", "test", "mem", "x", "A", 5)

        self.persistence.get_unresolved_state_tree.return_value = (
            ReturnType([new_pdu, old_pdu], [old_pdu])
        )

        is_new = yield self.state.handle_new_state(new_pdu)

        self.assertTrue(is_new)

        self.persistence.get_unresolved_state_tree.assert_called_once_with(
            new_pdu
        )

        self.assertEqual(1, self.persistence.update_current_state.call_count)

        self.assertFalse(self.replication.get_pdu.called)

    @defer.inlineCallbacks
    def test_power_level_fail(self):
        """Update based on an outdated state with a too low power level."""
        old_pdu_1 = new_fake_pdu_entry("A", "test", "mem", "x", None, 10)
        old_pdu_2 = new_fake_pdu_entry("B", "test", "mem", "x", None, 10)
        new_pdu = new_fake_pdu_entry("C", "test", "mem", "x", "A", 5)

        self.persistence.get_unresolved_state_tree.return_value = (
            ReturnType([new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1])
        )

        is_new = yield self.state.handle_new_state(new_pdu)

        self.assertFalse(is_new)

        self.persistence.get_unresolved_state_tree.assert_called_once_with(
            new_pdu
        )

        self.assertEqual(0, self.persistence.update_current_state.call_count)

        self.assertFalse(self.replication.get_pdu.called)

    @defer.inlineCallbacks
    def test_power_level_succeed(self):
        """Update based on an outdated state, but with sufficient power level
        to force the update.
        """
        old_pdu_1 = new_fake_pdu_entry("A", "test", "mem", "x", None, 10)
        old_pdu_2 = new_fake_pdu_entry("B", "test", "mem", "x", None, 10)
        new_pdu = new_fake_pdu_entry("C", "test", "mem", "x", "A", 15)

        self.persistence.get_unresolved_state_tree.return_value = (
            ReturnType([new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1])
        )

        is_new = yield self.state.handle_new_state(new_pdu)

        self.assertTrue(is_new)

        self.persistence.get_unresolved_state_tree.assert_called_once_with(
            new_pdu
        )

        self.assertEqual(1, self.persistence.update_current_state.call_count)

        self.assertFalse(self.replication.get_pdu.called)

    @defer.inlineCallbacks
    def test_power_level_equal_same_len(self):
        """Update based on an outdated state; the power levels are the same
        and so are the branch lengths.
        """
        old_pdu_1 = new_fake_pdu_entry("A", "test", "mem", "x", None, 10)
        old_pdu_2 = new_fake_pdu_entry("B", "test", "mem", "x", None, 10)
        new_pdu = new_fake_pdu_entry("C", "test", "mem", "x", "A", 10)

        self.persistence.get_unresolved_state_tree.return_value = (
            ReturnType([new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1])
        )

        is_new = yield self.state.handle_new_state(new_pdu)

        self.assertTrue(is_new)

        self.persistence.get_unresolved_state_tree.assert_called_once_with(
            new_pdu
        )

        self.assertEqual(1, self.persistence.update_current_state.call_count)

        self.assertFalse(self.replication.get_pdu.called)

    @defer.inlineCallbacks
    def test_power_level_equal_diff_len(self):
        """Update based on an outdated state; the power levels are the same
        but the branch length of the new one is longer.
        """
        old_pdu_1 = new_fake_pdu_entry("A", "test", "mem", "x", None, 10)
        old_pdu_2 = new_fake_pdu_entry("B", "test", "mem", "x", None, 10)
        old_pdu_3 = new_fake_pdu_entry("C", "test", "mem", "x", "A", 10)
        new_pdu = new_fake_pdu_entry("D", "test", "mem", "x", "C", 10)

        self.persistence.get_unresolved_state_tree.return_value = (
            ReturnType([new_pdu, old_pdu_3, old_pdu_1], [old_pdu_2, old_pdu_1])
        )

        is_new = yield self.state.handle_new_state(new_pdu)

        self.assertTrue(is_new)

        self.persistence.get_unresolved_state_tree.assert_called_once_with(
            new_pdu
        )

        self.assertEqual(1, self.persistence.update_current_state.call_count)

        self.assertFalse(self.replication.get_pdu.called)

    @defer.inlineCallbacks
    def test_missing_pdu(self):
        """Update state against a PDU we haven't yet seen, triggering a
        get_pdu request.
        """
        # The pdu we haven't seen
        old_pdu_1 = new_fake_pdu_entry("A", "test", "mem", "x", None, 10)

        old_pdu_2 = new_fake_pdu_entry("B", "test", "mem", "x", None, 10)
        new_pdu = new_fake_pdu_entry("C", "test", "mem", "x", "A", 20)

        # The return_value of `get_unresolved_state_tree`, which changes after
        # the call to get_pdu. A one-element list is used so the nested
        # functions can rebind the value without `nonlocal`.
        tree_to_return = [ReturnType([new_pdu], [old_pdu_2])]

        def return_tree(p):
            return tree_to_return[0]

        def set_return_tree(*args, **kwargs):
            tree_to_return[0] = ReturnType(
                [new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1]
            )

        self.persistence.get_unresolved_state_tree.side_effect = return_tree

        self.replication.get_pdu.side_effect = set_return_tree

        self.persistence.get_pdu.return_value = None

        is_new = yield self.state.handle_new_state(new_pdu)

        self.assertTrue(is_new)

        self.persistence.get_unresolved_state_tree.assert_called_with(
            new_pdu
        )

        # `assertEqual`, not the deprecated `assertEquals` alias, to match
        # every other assertion in this class.
        self.assertEqual(
            2, self.persistence.get_unresolved_state_tree.call_count
        )

        self.assertEqual(1, self.persistence.update_current_state.call_count)

    @defer.inlineCallbacks
    def test_new_event(self):
        """``handle_new_event`` keeps the snapshot's prev_events/depth and sets
        ``prev_state`` from the snapshot's previous state PDU.
        """
        event = Mock()
        event.event_id = "12123123@test"

        state_pdu = new_fake_pdu_entry("C", "test", "mem", "x", "A", 20)

        snapshot = Mock()
        snapshot.prev_state_pdu = state_pdu
        event_id = "*****@*****.**"

        def fill_out_prev_events(event):
            event.prev_events = [event_id]
            event.depth = 6
        snapshot.fill_out_prev_events = fill_out_prev_events

        yield self.state.handle_new_event(event, snapshot)

        self.assertLess(5, event.depth)

        self.assertEqual(1, len(event.prev_events))

        prev_id = event.prev_events[0]

        self.assertEqual(event_id, prev_id)

        self.assertEqual(
            encode_event_id(state_pdu.pdu_id, state_pdu.origin),
            event.prev_state
        )
class StateTestCase(unittest.TestCase):
    """Tests for ``StateHandler.compute_event_context``.

    The datastore is a ``Mock`` restricted to ``get_state_groups`` and
    ``add_event_hashes``. The graph-based tests route ``get_state_groups``
    to a ``StateGroupStore``; the remaining tests set its return value
    directly.
    """

    def setUp(self):
        """Create the mocked homeserver and the ``StateHandler`` under test."""
        self.store = Mock(spec_set=[
            "get_state_groups",
            "add_event_hashes",
        ])
        hs = Mock(spec=[
            "get_datastore",
            "get_auth",
            "get_state_handler",
            "get_clock",
        ])
        hs.get_datastore.return_value = self.store
        hs.get_state_handler.return_value = None
        hs.get_auth.return_value = Auth(hs)
        hs.get_clock.return_value = MockClock()

        self.state = StateHandler(hs)
        self.event_id = 0

    @defer.inlineCallbacks
    def _compute_contexts(self, graph):
        """Walk ``graph`` and compute a context for every event in it.

        A fresh ``StateGroupStore`` backs ``self.store.get_state_groups`` and
        is handed each computed context, so later events can see the state
        groups of earlier ones.

        Returns (via the Deferred) a dict mapping event_id to its context.
        """
        store = StateGroupStore()
        self.store.get_state_groups.side_effect = store.get_state_groups

        context_store = {}
        for event in graph.walk():
            context = yield self.state.compute_event_context(event)
            store.store_state_groups(event, context)
            context_store[event.event_id] = context

        defer.returnValue(context_store)

    def _assert_state_keys_match(self, context):
        """Check every ``(type, state_key)`` key agrees with its event."""
        for (event_type, state_key), state_event in (
            context.current_state.items()
        ):
            self.assertEqual(event_type, state_event.type)
            self.assertEqual(state_key, state_event.state_key)

    @defer.inlineCallbacks
    def test_branch_no_conflict(self):
        """Two branches touching different state merge without conflict."""
        graph = Graph(
            nodes={
                "START": DictObj(
                    type=EventTypes.Create,
                    state_key="",
                    depth=1,
                ),
                "A": DictObj(
                    type=EventTypes.Message,
                    depth=2,
                ),
                "B": DictObj(
                    type=EventTypes.Message,
                    depth=3,
                ),
                "C": DictObj(
                    type=EventTypes.Name,
                    state_key="",
                    depth=3,
                ),
                "D": DictObj(
                    type=EventTypes.Message,
                    depth=4,
                ),
            },
            edges={
                "A": ["START"],
                "B": ["A"],
                "C": ["A"],
                "D": ["B", "C"]
            }
        )

        context_store = yield self._compute_contexts(graph)

        self.assertEqual(2, len(context_store["D"].current_state))

    @defer.inlineCallbacks
    def test_branch_basic_conflict(self):
        """Two branches each set the room name; "C" is expected to win."""
        graph = Graph(
            nodes={
                "START": DictObj(
                    type=EventTypes.Create,
                    state_key="",
                    content={"creator": "@user_id:example.com"},
                    depth=1,
                ),
                "A": DictObj(
                    type=EventTypes.Member,
                    state_key="@user_id:example.com",
                    content={"membership": Membership.JOIN},
                    membership=Membership.JOIN,
                    depth=2,
                ),
                "B": DictObj(
                    type=EventTypes.Name,
                    state_key="",
                    depth=3,
                ),
                "C": DictObj(
                    type=EventTypes.Name,
                    state_key="",
                    depth=4,
                ),
                "D": DictObj(
                    type=EventTypes.Message,
                    depth=5,
                ),
            },
            edges={
                "A": ["START"],
                "B": ["A"],
                "C": ["A"],
                "D": ["B", "C"]
            }
        )

        context_store = yield self._compute_contexts(graph)

        self.assertSetEqual(
            {"START", "A", "C"},
            {e.event_id for e in context_store["D"].current_state.values()}
        )

    @defer.inlineCallbacks
    def test_branch_have_banned_conflict(self):
        """A name change sent by a user banned on the other branch is
        expected to lose to the earlier name event "B".
        """
        graph = Graph(
            nodes={
                "START": DictObj(
                    type=EventTypes.Create,
                    state_key="",
                    content={"creator": "@user_id:example.com"},
                    depth=1,
                ),
                "A": DictObj(
                    type=EventTypes.Member,
                    state_key="@user_id:example.com",
                    content={"membership": Membership.JOIN},
                    membership=Membership.JOIN,
                    depth=2,
                ),
                "B": DictObj(
                    type=EventTypes.Name,
                    state_key="",
                    depth=3,
                ),
                "C": DictObj(
                    type=EventTypes.Member,
                    state_key="@user_id_2:example.com",
                    content={"membership": Membership.BAN},
                    membership=Membership.BAN,
                    depth=4,
                ),
                "D": DictObj(
                    type=EventTypes.Name,
                    state_key="",
                    depth=4,
                    sender="@user_id_2:example.com",
                ),
                "E": DictObj(
                    type=EventTypes.Message,
                    depth=5,
                ),
            },
            edges={
                "A": ["START"],
                "B": ["A"],
                "C": ["B"],
                "D": ["B"],
                "E": ["C", "D"]
            }
        )

        context_store = yield self._compute_contexts(graph)

        self.assertSetEqual(
            {"START", "A", "B", "C"},
            {e.event_id for e in context_store["E"].current_state.values()}
        )

    @defer.inlineCallbacks
    def test_branch_have_perms_conflict(self):
        """A power-level change on one branch conflicts with a name change
        sent by the affected user on the other; "A5" and "B" are expected
        to survive.
        """
        userid1 = "@user_id:example.com"
        userid2 = "@user_id2:example.com"

        nodes = {
            "A1": DictObj(
                type=EventTypes.Create,
                state_key="",
                content={"creator": userid1},
                depth=1,
            ),
            "A2": DictObj(
                type=EventTypes.Member,
                state_key=userid1,
                content={"membership": Membership.JOIN},
                membership=Membership.JOIN,
            ),
            "A3": DictObj(
                type=EventTypes.Member,
                state_key=userid2,
                content={"membership": Membership.JOIN},
                membership=Membership.JOIN,
            ),
            "A4": DictObj(
                type=EventTypes.PowerLevels,
                state_key="",
                content={
                    "events": {"m.room.name": 50},
                    "users": {userid1: 100, userid2: 60},
                },
            ),
            "A5": DictObj(
                type=EventTypes.Name,
                state_key="",
            ),
            "B": DictObj(
                type=EventTypes.PowerLevels,
                state_key="",
                content={
                    "events": {"m.room.name": 50},
                    "users": {userid2: 30},
                },
            ),
            "C": DictObj(
                type=EventTypes.Name,
                state_key="",
                sender=userid2,
            ),
            "D": DictObj(type=EventTypes.Message, ),
        }
        edges = {
            "A2": ["A1"],
            "A3": ["A2"],
            "A4": ["A3"],
            "A5": ["A4"],
            "B": ["A5"],
            "C": ["A5"],
            "D": ["B", "C"]
        }
        # Only "A1" carries an explicit depth; derive the rest from the edges.
        self._add_depths(nodes, edges)
        graph = Graph(nodes, edges)

        context_store = yield self._compute_contexts(graph)

        self.assertSetEqual(
            {"A1", "A2", "A3", "A5", "B"},
            {e.event_id for e in context_store["D"].current_state.values()}
        )

    def _add_depths(self, nodes, edges):
        """Fill in a ``depth`` on every node that lacks one.

        A node's depth is one more than the largest depth among the nodes
        listed for it in ``edges``; nodes that already have a depth keep it.
        """
        def _get_depth(ev):
            node = nodes[ev]
            if 'depth' not in node:
                prevs = edges[ev]
                depth = max(_get_depth(prev) for prev in prevs) + 1
                node['depth'] = depth
            return node['depth']

        for n in nodes:
            _get_depth(n)

    @defer.inlineCallbacks
    def test_annotate_with_old_message(self):
        """A message event given explicit ``old_state`` gets exactly that
        state and no state group.
        """
        event = create_event(type="test_message", name="event")

        old_state = [
            create_event(type="test1", state_key="1"),
            create_event(type="test1", state_key="2"),
            create_event(type="test2", state_key=""),
        ]

        context = yield self.state.compute_event_context(
            event, old_state=old_state
        )

        self._assert_state_keys_match(context)

        self.assertEqual(
            set(old_state), set(context.current_state.values())
        )
        self.assertIsNone(context.state_group)

    @defer.inlineCallbacks
    def test_annotate_with_old_state(self):
        """A state event given explicit ``old_state`` gets exactly that
        state and no state group.
        """
        event = create_event(type="state", state_key="", name="event")

        old_state = [
            create_event(type="test1", state_key="1"),
            create_event(type="test1", state_key="2"),
            create_event(type="test2", state_key=""),
        ]

        context = yield self.state.compute_event_context(
            event, old_state=old_state
        )

        self._assert_state_keys_match(context)

        self.assertEqual(
            set(old_state), set(context.current_state.values())
        )
        self.assertIsNone(context.state_group)

    @defer.inlineCallbacks
    def test_trivial_annotate_message(self):
        """With a single stored group, a message event reuses that group."""
        event = create_event(type="test_message", name="event")

        old_state = [
            create_event(type="test1", state_key="1"),
            create_event(type="test1", state_key="2"),
            create_event(type="test2", state_key=""),
        ]

        group_name = "group_name_1"
        self.store.get_state_groups.return_value = {
            group_name: old_state,
        }

        context = yield self.state.compute_event_context(event)

        self._assert_state_keys_match(context)

        self.assertEqual(
            {e.event_id for e in old_state},
            {e.event_id for e in context.current_state.values()}
        )

        self.assertEqual(group_name, context.state_group)

    @defer.inlineCallbacks
    def test_trivial_annotate_state(self):
        """With a single stored group, a state event gets that state but no
        state group.
        """
        event = create_event(type="state", state_key="", name="event")

        old_state = [
            create_event(type="test1", state_key="1"),
            create_event(type="test1", state_key="2"),
            create_event(type="test2", state_key=""),
        ]

        group_name = "group_name_1"
        self.store.get_state_groups.return_value = {
            group_name: old_state,
        }

        context = yield self.state.compute_event_context(event)

        self._assert_state_keys_match(context)

        self.assertEqual(
            {e.event_id for e in old_state},
            {e.event_id for e in context.current_state.values()}
        )

        self.assertIsNone(context.state_group)

    @defer.inlineCallbacks
    def test_resolve_message_conflict(self):
        """Two differing groups resolve to six entries for a message event."""
        event = create_event(type="test_message", name="event")

        creation = create_event(
            type=EventTypes.Create, state_key=""
        )

        old_state_1 = [
            creation,
            create_event(type="test1", state_key="1"),
            create_event(type="test1", state_key="2"),
            create_event(type="test2", state_key=""),
        ]

        old_state_2 = [
            creation,
            create_event(type="test1", state_key="1"),
            create_event(type="test3", state_key="2"),
            create_event(type="test4", state_key=""),
        ]

        context = yield self._get_context(event, old_state_1, old_state_2)

        self.assertEqual(len(context.current_state), 6)

        self.assertIsNone(context.state_group)

    @defer.inlineCallbacks
    def test_resolve_state_conflict(self):
        """Two differing groups resolve to six entries for a state event."""
        event = create_event(type="test4", state_key="", name="event")

        creation = create_event(
            type=EventTypes.Create, state_key=""
        )

        old_state_1 = [
            creation,
            create_event(type="test1", state_key="1"),
            create_event(type="test1", state_key="2"),
            create_event(type="test2", state_key=""),
        ]

        old_state_2 = [
            creation,
            create_event(type="test1", state_key="1"),
            create_event(type="test3", state_key="2"),
            create_event(type="test4", state_key=""),
        ]

        context = yield self._get_context(event, old_state_1, old_state_2)

        self.assertEqual(len(context.current_state), 6)

        self.assertIsNone(context.state_group)

    @defer.inlineCallbacks
    def test_standard_depth_conflict(self):
        """Of two conflicting events, the one with the greater depth wins."""
        event = create_event(type="test4", name="event")

        member_event = create_event(
            type=EventTypes.Member,
            state_key="@user_id:example.com",
            content={
                "membership": Membership.JOIN,
            }
        )

        creation = create_event(
            type=EventTypes.Create, state_key="",
            content={"creator": "@foo:bar"}
        )

        old_state_1 = [
            creation,
            member_event,
            create_event(type="test1", state_key="1", depth=1),
        ]

        old_state_2 = [
            creation,
            member_event,
            create_event(type="test1", state_key="1", depth=2),
        ]

        context = yield self._get_context(event, old_state_1, old_state_2)

        self.assertEqual(old_state_2[2], context.current_state[("test1", "1")])

        # Reverse the depth to make sure we are actually using the depths
        # during state resolution.

        old_state_1 = [
            creation,
            member_event,
            create_event(type="test1", state_key="1", depth=2),
        ]

        old_state_2 = [
            creation,
            member_event,
            create_event(type="test1", state_key="1", depth=1),
        ]

        context = yield self._get_context(event, old_state_1, old_state_2)

        self.assertEqual(old_state_1[2], context.current_state[("test1", "1")])

    def _get_context(self, event, old_state_1, old_state_2):
        """Make the store return two state groups, then compute ``event``'s
        context.

        Returns whatever ``compute_event_context`` returns, for the caller
        to ``yield`` on.
        """
        group_name_1 = "group_name_1"
        group_name_2 = "group_name_2"

        self.store.get_state_groups.return_value = {
            group_name_1: old_state_1,
            group_name_2: old_state_2,
        }

        return self.state.compute_event_context(event)
class StateTestCase(unittest.TestCase):
    """Tests for ``StateHandler.compute_event_context`` using mock events.

    The datastore is a ``Mock`` restricted to ``get_state_groups`` and
    ``add_event_hashes``; events are spec-less mocks built by
    ``create_event``.
    """

    def setUp(self):
        """Create the mocked datastore/homeserver and the handler under test."""
        self.store = Mock(
            spec_set=[
                "get_state_groups",
                "add_event_hashes",
            ]
        )
        hs = Mock(spec=["get_datastore"])
        hs.get_datastore.return_value = self.store

        self.state = StateHandler(hs)
        self.event_id = 0

    def _make_old_state(self):
        """Return three fresh state events: test1/"1", test1/"2", test2/""."""
        return [
            self.create_event(type="test1", state_key="1"),
            self.create_event(type="test1", state_key="2"),
            self.create_event(type="test2", state_key=""),
        ]

    def _assert_state_keys_match(self, context):
        """Check every ``(type, state_key)`` key agrees with its event."""
        for (event_type, state_key), state_event in (
            context.current_state.items()
        ):
            self.assertEqual(event_type, state_event.type)
            self.assertEqual(state_key, state_event.state_key)

    def _register_conflicting_groups(self):
        """Make the store return two state groups that differ in content."""
        old_state_1 = self._make_old_state()

        old_state_2 = [
            self.create_event(type="test1", state_key="1"),
            self.create_event(type="test3", state_key="2"),
            self.create_event(type="test4", state_key=""),
        ]

        group_name_1 = "group_name_1"
        group_name_2 = "group_name_2"

        self.store.get_state_groups.return_value = {
            group_name_1: old_state_1,
            group_name_2: old_state_2,
        }

    @defer.inlineCallbacks
    def test_annotate_with_old_message(self):
        """A message event given explicit ``old_state`` gets exactly that
        state and no state group.
        """
        event = self.create_event(type="test_message", name="event")

        old_state = self._make_old_state()

        context = yield self.state.compute_event_context(
            event, old_state=old_state
        )

        self._assert_state_keys_match(context)

        self.assertEqual(
            set(old_state), set(context.current_state.values())
        )
        self.assertIsNone(context.state_group)

    @defer.inlineCallbacks
    def test_annotate_with_old_state(self):
        """A state event given explicit ``old_state`` gets exactly that
        state and no state group.
        """
        event = self.create_event(type="state", state_key="", name="event")

        old_state = self._make_old_state()

        context = yield self.state.compute_event_context(
            event, old_state=old_state
        )

        self._assert_state_keys_match(context)

        self.assertEqual(
            set(old_state), set(context.current_state.values())
        )
        self.assertIsNone(context.state_group)

    @defer.inlineCallbacks
    def test_trivial_annotate_message(self):
        """With a single stored group, a message event reuses that group."""
        event = self.create_event(type="test_message", name="event")
        event.prev_events = []

        old_state = self._make_old_state()

        group_name = "group_name_1"
        self.store.get_state_groups.return_value = {
            group_name: old_state,
        }

        context = yield self.state.compute_event_context(event)

        self._assert_state_keys_match(context)

        self.assertEqual(
            {e.event_id for e in old_state},
            {e.event_id for e in context.current_state.values()}
        )

        self.assertEqual(group_name, context.state_group)

    @defer.inlineCallbacks
    def test_trivial_annotate_state(self):
        """With a single stored group, a state event gets that state but no
        state group.
        """
        event = self.create_event(type="state", state_key="", name="event")
        event.prev_events = []

        old_state = self._make_old_state()

        group_name = "group_name_1"
        self.store.get_state_groups.return_value = {
            group_name: old_state,
        }

        context = yield self.state.compute_event_context(event)

        self._assert_state_keys_match(context)

        self.assertEqual(
            {e.event_id for e in old_state},
            {e.event_id for e in context.current_state.values()}
        )

        self.assertIsNone(context.state_group)

    @defer.inlineCallbacks
    def test_resolve_message_conflict(self):
        """Two differing groups resolve to five entries for a message event."""
        event = self.create_event(type="test_message", name="event")
        event.prev_events = []

        self._register_conflicting_groups()

        context = yield self.state.compute_event_context(event)

        self.assertEqual(len(context.current_state), 5)

        self.assertIsNone(context.state_group)

    @defer.inlineCallbacks
    def test_resolve_state_conflict(self):
        """Two differing groups resolve to five entries for a state event."""
        event = self.create_event(type="test4", state_key="", name="event")
        event.prev_events = []

        self._register_conflicting_groups()

        context = yield self.state.compute_event_context(event)

        self.assertEqual(len(context.current_state), 5)

        self.assertIsNone(context.state_group)

    def create_event(self, name=None, type=None, state_key=None):
        """Build a spec-less ``Mock`` event with a fresh sequential event_id.

        Args:
            name: Name for the mock; when falsy, one is derived from ``type``
                and (if given) ``state_key``.
            type: Value assigned to ``event.type``. The parameter shadows the
                builtin, but callers pass it by keyword, so it is kept.
            state_key: When not None, assigned to ``event.state_key`` and
                makes ``event.is_state()`` return True. When None the
                attribute is left unset.

        Returns:
            The configured mock event.
        """
        self.event_id += 1
        event_id = str(self.event_id)

        if not name:
            if state_key is not None:
                name = "<%s-%s>" % (type, state_key)
            else:
                name = "<%s>" % (type, )

        # spec=[] means only attributes assigned below exist on the mock.
        event = Mock(name=name, spec=[])
        event.type = type

        if state_key is not None:
            event.state_key = state_key
        event.event_id = event_id

        event.is_state = lambda: (state_key is not None)
        event.unsigned = {}

        event.user_id = "@user_id:example.com"
        event.room_id = "!room_id:example.com"

        return event