Example #1
0
    def setUp(self):
        """Wire a ``StateHandler`` to a mocked homeserver backed by ``_DummyStore``."""
        self.dummy_store = _DummyStore()
        mock_clock = cast(Clock, MockClock())

        # spec_set limits the mock to exactly these attribute names.
        allowed_attrs = [
            "config",
            "get_datastores",
            "get_storage_controllers",
            "get_auth",
            "get_state_handler",
            "get_clock",
            "get_state_resolution_handler",
            "get_account_validity_handler",
            "get_macaroon_generator",
            "hostname",
        ]
        homeserver = Mock(spec_set=allowed_attrs)
        homeserver.config = default_config("tesths", True)
        homeserver.get_datastores.return_value = Mock(main=self.dummy_store)
        homeserver.get_state_handler.return_value = None
        homeserver.get_clock.return_value = mock_clock
        homeserver.get_macaroon_generator.return_value = MacaroonGenerator(
            mock_clock, "tesths", b"verysecret"
        )
        # Auth is built at this point in the sequence, before the attributes
        # below are configured; keep that ordering.
        homeserver.get_auth.return_value = Auth(homeserver)

        def make_state_resolution_handler():
            return StateResolutionHandler(homeserver)

        homeserver.get_state_resolution_handler = make_state_resolution_handler
        homeserver.get_storage_controllers.return_value = Mock(
            main=self.dummy_store, state=self.dummy_store
        )

        self.state = StateHandler(homeserver)
        self.event_id = 0
Example #2
0
    def setUp(self):
        """Build a ``StateHandler`` on a mocked homeserver using ``StateGroupStore``."""
        self.store = StateGroupStore()

        # spec_set limits the mock to exactly these attribute names.
        allowed_attrs = [
            "config",
            "get_datastore",
            "get_storage",
            "get_auth",
            "get_state_handler",
            "get_clock",
            "get_state_resolution_handler",
            "hostname",
        ]
        homeserver = Mock(spec_set=allowed_attrs)
        homeserver.config = default_config("tesths", True)
        homeserver.get_datastore.return_value = self.store
        homeserver.get_state_handler.return_value = None
        homeserver.get_clock.return_value = MockClock()
        # Auth is built at this point in the sequence, before the attributes
        # below are configured; keep that ordering.
        homeserver.get_auth.return_value = Auth(homeserver)

        def make_state_resolution_handler():
            return StateResolutionHandler(homeserver)

        homeserver.get_state_resolution_handler = make_state_resolution_handler
        # The same store object serves as both the "main" and "state" storage.
        homeserver.get_storage.return_value = Mock(
            main=self.store, state=self.store
        )

        self.state = StateHandler(homeserver)
        self.event_id = 0
Example #3
0
    def setUp(self):
        """Create a ``StateHandler`` whose homeserver exposes only a mock datastore."""
        store_methods = ["get_state_groups", "add_event_hashes"]
        self.store = Mock(spec_set=store_methods)

        homeserver = Mock(spec=["get_datastore"])
        homeserver.get_datastore.return_value = self.store

        self.state = StateHandler(homeserver)
        # Starting value of the per-test event id counter.
        self.event_id = 0
Example #4
0
    def setUp(self):
        """Build a ``StateHandler`` on a mocked homeserver using ``StateGroupStore``."""
        self.store = StateGroupStore()

        # spec_set limits the mock to exactly these attribute names.
        allowed_attrs = [
            "get_datastore",
            "get_auth",
            "get_state_handler",
            "get_clock",
            "get_state_resolution_handler",
        ]
        homeserver = Mock(spec_set=allowed_attrs)
        homeserver.get_datastore.return_value = self.store
        homeserver.get_state_handler.return_value = None
        homeserver.get_clock.return_value = MockClock()
        homeserver.get_auth.return_value = Auth(homeserver)

        def make_state_resolution_handler():
            return StateResolutionHandler(homeserver)

        homeserver.get_state_resolution_handler = make_state_resolution_handler

        self.state = StateHandler(homeserver)
        self.event_id = 0
Example #5
0
    def setUp(self):
        """Create a ``StateHandler`` over a mock datastore and a mocked homeserver."""
        store_methods = ["get_state_groups", "add_event_hashes"]
        self.store = Mock(spec_set=store_methods)

        homeserver_attrs = [
            "get_datastore",
            "get_auth",
            "get_state_handler",
            "get_clock",
        ]
        homeserver = Mock(spec=homeserver_attrs)
        homeserver.get_datastore.return_value = self.store
        homeserver.get_state_handler.return_value = None
        # Auth is built before get_clock is configured; keep that ordering.
        homeserver.get_auth.return_value = Auth(homeserver)
        homeserver.get_clock.return_value = MockClock()

        self.state = StateHandler(homeserver)
        self.event_id = 0
Example #6
0
    def setUp(self):
        """Create a ``StateHandler`` over mocked persistence and replication layers."""
        persistence_methods = [
            "get_unresolved_state_tree",
            "update_current_state",
            "get_latest_pdus_in_context",
            "get_current_state_pdu",
            "get_pdu",
            "get_power_level",
        ]
        self.persistence = Mock(spec=persistence_methods)
        self.replication = Mock(spec=["get_pdu"])

        homeserver = Mock(spec=["get_datastore", "get_replication_layer"])
        # ``spec`` (unlike ``spec_set``) still permits assigning attributes
        # that are not listed, which is why ``hostname`` can be set here.
        homeserver.hostname = "bob.com"
        homeserver.get_replication_layer.return_value = self.replication
        homeserver.get_datastore.return_value = self.persistence

        self.state = StateHandler(homeserver)
Example #7
0
    def setUp(self):
        """Create a ``StateHandler`` over a mock datastore and a mocked homeserver."""
        store_methods = ["get_state_groups", "add_event_hashes"]
        self.store = Mock(spec_set=store_methods)

        # spec_set limits the mock to exactly these attribute names.
        homeserver_attrs = [
            "get_datastore",
            "get_auth",
            "get_state_handler",
            "get_clock",
        ]
        homeserver = Mock(spec_set=homeserver_attrs)
        homeserver.get_datastore.return_value = self.store
        homeserver.get_state_handler.return_value = None
        homeserver.get_clock.return_value = MockClock()
        homeserver.get_auth.return_value = Auth(homeserver)

        self.state = StateHandler(homeserver)
        self.event_id = 0
Example #8
0
    def setUp(self):
        """Create a ``StateHandler`` whose homeserver exposes only a mock datastore."""
        # The datastore mock offers a single method, get_state_groups.
        self.store = Mock(spec_set=["get_state_groups"])

        homeserver = Mock(spec=["get_datastore"])
        homeserver.get_datastore.return_value = self.store

        self.state = StateHandler(homeserver)
        # Starting value of the per-test event id counter.
        self.event_id = 0
Example #9
0
    def setUp(self):
        """Build a ``StateHandler`` on a mocked homeserver using ``StateGroupStore``."""
        self.store = StateGroupStore()

        # spec_set limits the mock to exactly these attribute names.
        allowed_attrs = [
            "get_datastore",
            "get_auth",
            "get_state_handler",
            "get_clock",
            "get_state_resolution_handler",
        ]
        homeserver = Mock(spec_set=allowed_attrs)
        homeserver.get_datastore.return_value = self.store
        homeserver.get_state_handler.return_value = None
        homeserver.get_clock.return_value = MockClock()
        homeserver.get_auth.return_value = Auth(homeserver)

        def make_state_resolution_handler():
            return StateResolutionHandler(homeserver)

        homeserver.get_state_resolution_handler = make_state_resolution_handler

        self.state = StateHandler(homeserver)
        self.event_id = 0
Example #10
0
    def setUp(self):
        """Create a ``StateHandler`` over a mock datastore and a mocked homeserver."""
        store_methods = [
            "get_state_groups_ids",
            "add_event_hashes",
            "get_events",
            "get_next_state_group",
            "get_state_group_delta",
        ]
        self.store = Mock(spec_set=store_methods)

        # spec_set limits the mock to exactly these attribute names.
        homeserver_attrs = [
            "get_datastore",
            "get_auth",
            "get_state_handler",
            "get_clock",
        ]
        homeserver = Mock(spec_set=homeserver_attrs)
        homeserver.get_datastore.return_value = self.store
        homeserver.get_state_handler.return_value = None
        homeserver.get_clock.return_value = MockClock()
        homeserver.get_auth.return_value = Auth(homeserver)

        # Each get_next_state_group() call yields a brand-new Mock instance;
        # get_state_group_delta() always returns an empty (None, None) pair.
        self.store.get_state_group_delta.return_value = (None, None)
        self.store.get_next_state_group.side_effect = Mock

        self.state = StateHandler(homeserver)
        self.event_id = 0
Example #11
0
    def setUp(self):
        """Create a ``StateHandler`` over mocked persistence and replication layers."""
        persistence_methods = [
            "get_unresolved_state_tree",
            "update_current_state",
            "get_latest_pdus_in_context",
            "get_current_state_pdu",
            "get_pdu",
        ]
        self.persistence = Mock(spec=persistence_methods)
        self.replication = Mock(spec=["get_pdu"])

        homeserver = Mock(spec=["get_datastore", "get_replication_layer"])
        # ``spec`` (unlike ``spec_set``) still permits assigning attributes
        # that are not listed, which is why ``hostname`` can be set here.
        homeserver.hostname = "bob.com"
        homeserver.get_replication_layer.return_value = self.replication
        homeserver.get_datastore.return_value = self.persistence

        self.state = StateHandler(homeserver)
Example #12
0
class StateTestCase(unittest.TestCase):
    """Tests for ``StateHandler.annotate_event_with_state``.

    The handler is built on a mocked homeserver whose datastore exposes only
    ``get_state_groups``; tests that exercise state-group lookup set that
    mock's return value themselves.
    """

    def setUp(self):
        """Create the mocked datastore/homeserver and the handler under test."""
        self.store = Mock(
            spec_set=[
                "get_state_groups",
            ]
        )
        hs = Mock(spec=["get_datastore"])
        hs.get_datastore.return_value = self.store

        self.state = StateHandler(hs)
        # Counter used by create_event() to hand out unique event ids.
        self.event_id = 0

    def _assert_keys_match_events(self, state_map):
        """Assert each ``(type, state_key)`` key matches the event stored under it."""
        for (event_type, state_key), state_event in state_map.items():
            self.assertEqual(event_type, state_event.type)
            self.assertEqual(state_key, state_event.state_key)

    @staticmethod
    def _event_ids(events):
        """Return the set of ``event_id`` values of ``events``."""
        return {e.event_id for e in events}

    @defer.inlineCallbacks
    def test_annotate_with_old_message(self):
        """Annotate a non-state event using explicitly supplied old state."""
        event = self.create_event(type="test_message", name="event")

        old_state = [
            self.create_event(type="test1", state_key="1"),
            self.create_event(type="test1", state_key="2"),
            self.create_event(type="test2", state_key=""),
        ]

        yield self.state.annotate_event_with_state(event, old_state=old_state)

        self._assert_keys_match_events(event.old_state_events)

        self.assertEqual(set(old_state), set(event.old_state_events.values()))
        self.assertDictEqual(event.old_state_events, event.state_events)

        self.assertIsNone(event.state_group)

    @defer.inlineCallbacks
    def test_annotate_with_old_state(self):
        """Annotate a state event using explicitly supplied old state."""
        event = self.create_event(type="state", state_key="", name="event")

        old_state = [
            self.create_event(type="test1", state_key="1"),
            self.create_event(type="test1", state_key="2"),
            self.create_event(type="test2", state_key=""),
        ]

        yield self.state.annotate_event_with_state(event, old_state=old_state)

        self._assert_keys_match_events(event.old_state_events)

        self.assertEqual(
            set(old_state + [event]),
            set(event.old_state_events.values()),
        )

        self.assertDictEqual(event.old_state_events, event.state_events)

        self.assertIsNone(event.state_group)

    @defer.inlineCallbacks
    def test_trivial_annotate_message(self):
        """Annotate a non-state event when the store returns one state group."""
        event = self.create_event(type="test_message", name="event")
        event.prev_events = []

        old_state = [
            self.create_event(type="test1", state_key="1"),
            self.create_event(type="test1", state_key="2"),
            self.create_event(type="test2", state_key=""),
        ]

        group_name = "group_name_1"

        self.store.get_state_groups.return_value = {
            group_name: old_state,
        }

        yield self.state.annotate_event_with_state(event)

        self._assert_keys_match_events(event.old_state_events)

        self.assertEqual(
            self._event_ids(old_state),
            self._event_ids(event.old_state_events.values()),
        )

        self.assertDictEqual(
            {k: v.event_id for k, v in event.old_state_events.items()},
            {k: v.event_id for k, v in event.state_events.items()},
        )

        self.assertEqual(group_name, event.state_group)

    @defer.inlineCallbacks
    def test_trivial_annotate_state(self):
        """Annotate a state event when the store returns one state group."""
        event = self.create_event(type="state", state_key="", name="event")
        event.prev_events = []

        old_state = [
            self.create_event(type="test1", state_key="1"),
            self.create_event(type="test1", state_key="2"),
            self.create_event(type="test2", state_key=""),
        ]

        group_name = "group_name_1"

        self.store.get_state_groups.return_value = {
            group_name: old_state,
        }

        yield self.state.annotate_event_with_state(event)

        self._assert_keys_match_events(event.old_state_events)

        self.assertEqual(
            self._event_ids(old_state),
            self._event_ids(event.old_state_events.values()),
        )

        self.assertEqual(
            self._event_ids(old_state) | {event.event_id},
            self._event_ids(event.state_events.values()),
        )

        # The new state must be the old state plus the event itself.
        new_state_ids = {
            k: v.event_id for k, v in event.state_events.items()
        }
        expected_state_ids = {
            k: v.event_id for k, v in event.old_state_events.items()
        }
        expected_state_ids[(event.type, event.state_key)] = event.event_id
        self.assertDictEqual(expected_state_ids, new_state_ids)

        self.assertIsNone(event.state_group)

    @defer.inlineCallbacks
    def test_resolve_message_conflict(self):
        """Annotate a non-state event when the store returns two state groups."""
        event = self.create_event(type="test_message", name="event")
        event.prev_events = []

        old_state_1 = [
            self.create_event(type="test1", state_key="1"),
            self.create_event(type="test1", state_key="2"),
            self.create_event(type="test2", state_key=""),
        ]

        old_state_2 = [
            self.create_event(type="test1", state_key="1"),
            self.create_event(type="test3", state_key="2"),
            self.create_event(type="test4", state_key=""),
        ]

        group_name_1 = "group_name_1"
        group_name_2 = "group_name_2"

        self.store.get_state_groups.return_value = {
            group_name_1: old_state_1,
            group_name_2: old_state_2,
        }

        yield self.state.annotate_event_with_state(event)

        # The two groups contain five distinct (type, state_key) pairs.
        self.assertEqual(len(event.old_state_events), 5)

        self.assertEqual(
            self._event_ids(event.state_events.values()),
            self._event_ids(event.old_state_events.values()),
        )

        self.assertIsNone(event.state_group)

    @defer.inlineCallbacks
    def test_resolve_state_conflict(self):
        """Annotate a state event when the store returns two state groups."""
        event = self.create_event(type="test4", state_key="", name="event")
        event.prev_events = []

        old_state_1 = [
            self.create_event(type="test1", state_key="1"),
            self.create_event(type="test1", state_key="2"),
            self.create_event(type="test2", state_key=""),
        ]

        old_state_2 = [
            self.create_event(type="test1", state_key="1"),
            self.create_event(type="test3", state_key="2"),
            self.create_event(type="test4", state_key=""),
        ]

        group_name_1 = "group_name_1"
        group_name_2 = "group_name_2"

        self.store.get_state_groups.return_value = {
            group_name_1: old_state_1,
            group_name_2: old_state_2,
        }

        yield self.state.annotate_event_with_state(event)

        # The two groups contain five distinct (type, state_key) pairs.
        self.assertEqual(len(event.old_state_events), 5)

        # Work on a copy: adding the new event to the mapping the handler
        # produced would silently modify the object under test.
        expected_new = dict(event.old_state_events)
        expected_new[(event.type, event.state_key)] = event

        self.assertEqual(
            self._event_ids(expected_new.values()),
            self._event_ids(event.state_events.values()),
        )

        self.assertIsNone(event.state_group)

    def create_event(self, name=None, type=None, state_key=None):
        """Build a mock event with a fresh, sequential ``event_id``.

        Args:
            name: Name given to the underlying ``Mock``; derived from ``type``
                (and ``state_key``, when given) if falsy.
            type: Value for the event's ``type`` attribute.
            state_key: If not None, set as the event's ``state_key``;
                otherwise the attribute is left unset, and because the mock
                is created with ``spec=[]`` reading it raises
                ``AttributeError``.

        Returns:
            The mock event, with ``user_id`` and ``room_id`` set to fixed
            example values.
        """
        self.event_id += 1
        event_id = str(self.event_id)

        if not name:
            if state_key is not None:
                name = "<%s-%s>" % (type, state_key)
            else:
                name = "<%s>" % (type, )

        event = Mock(name=name, spec=[])
        event.type = type

        if state_key is not None:
            event.state_key = state_key
        event.event_id = event_id

        event.user_id = "@user_id:example.com"
        event.room_id = "!room_id:example.com"

        return event
Example #13
0
class StateTestCase(unittest.TestCase):
    """Tests for ``StateHandler.compute_event_context``.

    The handler is built on a mocked homeserver.  The datastore is a mock
    exposing ``get_state_groups`` and ``add_event_hashes``; each test either
    sets ``get_state_groups``'s return value directly or backs it with a
    ``StateGroupStore`` while walking an event graph.
    """

    def setUp(self):
        """Create the mocked datastore/homeserver and the handler under test."""
        self.store = Mock(
            spec_set=[
                "get_state_groups",
                "add_event_hashes",
            ]
        )
        hs = Mock(spec=[
            "get_datastore", "get_auth", "get_state_handler", "get_clock",
        ])
        hs.get_datastore.return_value = self.store
        hs.get_state_handler.return_value = None
        hs.get_auth.return_value = Auth(hs)
        hs.get_clock.return_value = MockClock()

        self.state = StateHandler(hs)
        self.event_id = 0

    @defer.inlineCallbacks
    def _compute_contexts(self, graph):
        """Compute an event context for every event yielded by ``graph.walk()``.

        Backs the mocked datastore's ``get_state_groups`` with a fresh
        ``StateGroupStore`` and records every computed context in that store
        as the graph is walked.

        Returns:
            A Deferred firing with a dict mapping ``event_id`` to the context
            computed for that event.
        """
        store = StateGroupStore()
        self.store.get_state_groups.side_effect = store.get_state_groups

        context_store = {}

        for event in graph.walk():
            context = yield self.state.compute_event_context(event)
            store.store_state_groups(event, context)
            context_store[event.event_id] = context

        # defer.returnValue (rather than ``return``) keeps this generator
        # valid on Python 2 as well.
        defer.returnValue(context_store)

    def _assert_keys_match_events(self, state_map):
        """Assert each ``(type, state_key)`` key matches the event stored under it."""
        for (event_type, state_key), state_event in state_map.items():
            self.assertEqual(event_type, state_event.type)
            self.assertEqual(state_key, state_event.state_key)

    @defer.inlineCallbacks
    def test_branch_no_conflict(self):
        """Merge a message branch (B) and a name branch (C) at D."""
        graph = Graph(
            nodes={
                "START": DictObj(
                    type=EventTypes.Create,
                    state_key="",
                    depth=1,
                ),
                "A": DictObj(
                    type=EventTypes.Message,
                    depth=2,
                ),
                "B": DictObj(
                    type=EventTypes.Message,
                    depth=3,
                ),
                "C": DictObj(
                    type=EventTypes.Name,
                    state_key="",
                    depth=3,
                ),
                "D": DictObj(
                    type=EventTypes.Message,
                    depth=4,
                ),
            },
            edges={
                "A": ["START"],
                "B": ["A"],
                "C": ["A"],
                "D": ["B", "C"]
            }
        )

        context_store = yield self._compute_contexts(graph)

        self.assertEqual(2, len(context_store["D"].current_state))

    @defer.inlineCallbacks
    def test_branch_basic_conflict(self):
        """Merge two branches that each set the room name (B and C) at D."""
        graph = Graph(
            nodes={
                "START": DictObj(
                    type=EventTypes.Create,
                    state_key="",
                    content={"creator": "@user_id:example.com"},
                    depth=1,
                ),
                "A": DictObj(
                    type=EventTypes.Member,
                    state_key="@user_id:example.com",
                    content={"membership": Membership.JOIN},
                    membership=Membership.JOIN,
                    depth=2,
                ),
                "B": DictObj(
                    type=EventTypes.Name,
                    state_key="",
                    depth=3,
                ),
                "C": DictObj(
                    type=EventTypes.Name,
                    state_key="",
                    depth=4,
                ),
                "D": DictObj(
                    type=EventTypes.Message,
                    depth=5,
                ),
            },
            edges={
                "A": ["START"],
                "B": ["A"],
                "C": ["A"],
                "D": ["B", "C"]
            }
        )

        context_store = yield self._compute_contexts(graph)

        self.assertSetEqual(
            {"START", "A", "C"},
            {e.event_id for e in context_store["D"].current_state.values()}
        )

    @defer.inlineCallbacks
    def test_branch_have_banned_conflict(self):
        """Merge a ban (C) with a name change sent by the banned user (D) at E."""
        graph = Graph(
            nodes={
                "START": DictObj(
                    type=EventTypes.Create,
                    state_key="",
                    content={"creator": "@user_id:example.com"},
                    depth=1,
                ),
                "A": DictObj(
                    type=EventTypes.Member,
                    state_key="@user_id:example.com",
                    content={"membership": Membership.JOIN},
                    membership=Membership.JOIN,
                    depth=2,
                ),
                "B": DictObj(
                    type=EventTypes.Name,
                    state_key="",
                    depth=3,
                ),
                "C": DictObj(
                    type=EventTypes.Member,
                    state_key="@user_id_2:example.com",
                    content={"membership": Membership.BAN},
                    membership=Membership.BAN,
                    depth=4,
                ),
                "D": DictObj(
                    type=EventTypes.Name,
                    state_key="",
                    depth=4,
                    sender="@user_id_2:example.com",
                ),
                "E": DictObj(
                    type=EventTypes.Message,
                    depth=5,
                ),
            },
            edges={
                "A": ["START"],
                "B": ["A"],
                "C": ["B"],
                "D": ["B"],
                "E": ["C", "D"]
            }
        )

        context_store = yield self._compute_contexts(graph)

        self.assertSetEqual(
            {"START", "A", "B", "C"},
            {e.event_id for e in context_store["E"].current_state.values()}
        )

    @defer.inlineCallbacks
    def test_branch_have_perms_conflict(self):
        """Merge a power-level change (B) with a name change by userid2 (C) at D."""
        userid1 = "@user_id:example.com"
        userid2 = "@user_id2:example.com"

        nodes = {
            "A1": DictObj(
                type=EventTypes.Create,
                state_key="",
                content={"creator": userid1},
                depth=1,
            ),
            "A2": DictObj(
                type=EventTypes.Member,
                state_key=userid1,
                content={"membership": Membership.JOIN},
                membership=Membership.JOIN,
            ),
            "A3": DictObj(
                type=EventTypes.Member,
                state_key=userid2,
                content={"membership": Membership.JOIN},
                membership=Membership.JOIN,
            ),
            "A4": DictObj(
                type=EventTypes.PowerLevels,
                state_key="",
                content={
                    "events": {"m.room.name": 50},
                    "users": {userid1: 100,
                              userid2: 60},
                },
            ),
            "A5": DictObj(
                type=EventTypes.Name,
                state_key="",
            ),
            "B": DictObj(
                type=EventTypes.PowerLevels,
                state_key="",
                content={
                    "events": {"m.room.name": 50},
                    "users": {userid2: 30},
                },
            ),
            "C": DictObj(
                type=EventTypes.Name,
                state_key="",
                sender=userid2,
            ),
            "D": DictObj(
                type=EventTypes.Message,
            ),
        }
        edges = {
            "A2": ["A1"],
            "A3": ["A2"],
            "A4": ["A3"],
            "A5": ["A4"],
            "B": ["A5"],
            "C": ["A5"],
            "D": ["B", "C"]
        }
        # Only A1 carries an explicit depth; fill in the rest from the edges.
        self._add_depths(nodes, edges)
        graph = Graph(nodes, edges)

        context_store = yield self._compute_contexts(graph)

        self.assertSetEqual(
            {"A1", "A2", "A3", "A5", "B"},
            {e.event_id for e in context_store["D"].current_state.values()}
        )

    def _add_depths(self, nodes, edges):
        """Fill in a ``'depth'`` entry on every node that lacks one.

        A node without a depth gets one more than the largest depth among
        the nodes listed for it in ``edges``.  ``nodes`` is modified in place.
        """
        def _get_depth(ev):
            node = nodes[ev]
            if 'depth' not in node:
                prevs = edges[ev]
                depth = max(_get_depth(prev) for prev in prevs) + 1
                node['depth'] = depth
            return node['depth']

        for n in nodes:
            _get_depth(n)

    @defer.inlineCallbacks
    def test_annotate_with_old_message(self):
        """Compute the context of a non-state event from supplied old state."""
        event = create_event(type="test_message", name="event")

        old_state = [
            create_event(type="test1", state_key="1"),
            create_event(type="test1", state_key="2"),
            create_event(type="test2", state_key=""),
        ]

        context = yield self.state.compute_event_context(
            event, old_state=old_state
        )

        self._assert_keys_match_events(context.current_state)

        self.assertEqual(
            set(old_state), set(context.current_state.values())
        )

        self.assertIsNone(context.state_group)

    @defer.inlineCallbacks
    def test_annotate_with_old_state(self):
        """Compute the context of a state event from supplied old state."""
        event = create_event(type="state", state_key="", name="event")

        old_state = [
            create_event(type="test1", state_key="1"),
            create_event(type="test1", state_key="2"),
            create_event(type="test2", state_key=""),
        ]

        context = yield self.state.compute_event_context(
            event, old_state=old_state
        )

        self._assert_keys_match_events(context.current_state)

        self.assertEqual(
            set(old_state),
            set(context.current_state.values())
        )

        self.assertIsNone(context.state_group)

    @defer.inlineCallbacks
    def test_trivial_annotate_message(self):
        """Compute a non-state event's context when one state group exists."""
        event = create_event(type="test_message", name="event")

        old_state = [
            create_event(type="test1", state_key="1"),
            create_event(type="test1", state_key="2"),
            create_event(type="test2", state_key=""),
        ]

        group_name = "group_name_1"

        self.store.get_state_groups.return_value = {
            group_name: old_state,
        }

        context = yield self.state.compute_event_context(event)

        self._assert_keys_match_events(context.current_state)

        self.assertEqual(
            {e.event_id for e in old_state},
            {e.event_id for e in context.current_state.values()}
        )

        self.assertEqual(group_name, context.state_group)

    @defer.inlineCallbacks
    def test_trivial_annotate_state(self):
        """Compute a state event's context when one state group exists."""
        event = create_event(type="state", state_key="", name="event")

        old_state = [
            create_event(type="test1", state_key="1"),
            create_event(type="test1", state_key="2"),
            create_event(type="test2", state_key=""),
        ]

        group_name = "group_name_1"

        self.store.get_state_groups.return_value = {
            group_name: old_state,
        }

        context = yield self.state.compute_event_context(event)

        self._assert_keys_match_events(context.current_state)

        self.assertEqual(
            {e.event_id for e in old_state},
            {e.event_id for e in context.current_state.values()}
        )

        self.assertIsNone(context.state_group)

    @defer.inlineCallbacks
    def test_resolve_message_conflict(self):
        """Compute a non-state event's context across two state groups."""
        event = create_event(type="test_message", name="event")

        creation = create_event(
            type=EventTypes.Create, state_key=""
        )

        old_state_1 = [
            creation,
            create_event(type="test1", state_key="1"),
            create_event(type="test1", state_key="2"),
            create_event(type="test2", state_key=""),
        ]

        old_state_2 = [
            creation,
            create_event(type="test1", state_key="1"),
            create_event(type="test3", state_key="2"),
            create_event(type="test4", state_key=""),
        ]

        context = yield self._get_context(event, old_state_1, old_state_2)

        # The two groups contain six distinct (type, state_key) pairs.
        self.assertEqual(len(context.current_state), 6)

        self.assertIsNone(context.state_group)

    @defer.inlineCallbacks
    def test_resolve_state_conflict(self):
        """Compute a state event's context across two state groups."""
        event = create_event(type="test4", state_key="", name="event")

        creation = create_event(
            type=EventTypes.Create, state_key=""
        )

        old_state_1 = [
            creation,
            create_event(type="test1", state_key="1"),
            create_event(type="test1", state_key="2"),
            create_event(type="test2", state_key=""),
        ]

        old_state_2 = [
            creation,
            create_event(type="test1", state_key="1"),
            create_event(type="test3", state_key="2"),
            create_event(type="test4", state_key=""),
        ]

        context = yield self._get_context(event, old_state_1, old_state_2)

        # The two groups contain six distinct (type, state_key) pairs.
        self.assertEqual(len(context.current_state), 6)

        self.assertIsNone(context.state_group)

    @defer.inlineCallbacks
    def test_standard_depth_conflict(self):
        """Check which of two conflicting events wins as their depths change."""
        event = create_event(type="test4", name="event")

        member_event = create_event(
            type=EventTypes.Member,
            state_key="@user_id:example.com",
            content={
                "membership": Membership.JOIN,
            }
        )

        creation = create_event(
            type=EventTypes.Create, state_key="",
            content={"creator": "@foo:bar"}
        )

        old_state_1 = [
            creation,
            member_event,
            create_event(type="test1", state_key="1", depth=1),
        ]

        old_state_2 = [
            creation,
            member_event,
            create_event(type="test1", state_key="1", depth=2),
        ]

        context = yield self._get_context(event, old_state_1, old_state_2)

        self.assertEqual(old_state_2[2], context.current_state[("test1", "1")])

        # Reverse the depth to make sure we are actually using the depths
        # during state resolution.

        old_state_1 = [
            creation,
            member_event,
            create_event(type="test1", state_key="1", depth=2),
        ]

        old_state_2 = [
            creation,
            member_event,
            create_event(type="test1", state_key="1", depth=1),
        ]

        context = yield self._get_context(event, old_state_1, old_state_2)

        self.assertEqual(old_state_1[2], context.current_state[("test1", "1")])

    def _get_context(self, event, old_state_1, old_state_2):
        """Compute ``event``'s context with the store returning two state groups.

        Sets the mocked ``get_state_groups`` to return ``old_state_1`` and
        ``old_state_2`` under two fixed group names, then returns whatever
        ``compute_event_context(event)`` returns.
        """
        group_name_1 = "group_name_1"
        group_name_2 = "group_name_2"

        self.store.get_state_groups.return_value = {
            group_name_1: old_state_1,
            group_name_2: old_state_2,
        }

        return self.state.compute_event_context(event)
Example #14
0
class StateTestCase(unittest.TestCase):
    def setUp(self):
        """Build a StateHandler wired to an in-memory StateGroupStore.

        The homeserver is a Mock restricted (via ``spec_set``) to the
        accessors listed below.
        """
        self.store = StateGroupStore()
        storage_layer = Mock(main=self.store, state=self.store)

        homeserver_attrs = [
            "config",
            "get_datastore",
            "get_storage",
            "get_auth",
            "get_state_handler",
            "get_clock",
            "get_state_resolution_handler",
        ]
        homeserver = Mock(spec_set=homeserver_attrs)
        homeserver.config = default_config("tesths", True)

        # The order below is deliberate: Auth is constructed from the
        # partially configured mock, so keep it after the accessors above it.
        homeserver.get_datastore.return_value = self.store
        homeserver.get_state_handler.return_value = None
        homeserver.get_clock.return_value = MockClock()
        homeserver.get_auth.return_value = Auth(homeserver)
        # Each call builds a fresh resolution handler bound to this mock.
        homeserver.get_state_resolution_handler = (
            lambda: StateResolutionHandler(homeserver)
        )
        homeserver.get_storage.return_value = storage_layer

        self.state = StateHandler(homeserver)
        self.event_id = 0

    @defer.inlineCallbacks
    def test_branch_no_conflict(self):
        """Merge a fork where only one branch (C) carries a state event.

        The state before D is expected to hold two entries, and D is expected
        to reuse C's state group both before and after the event.
        """
        events = {
            "START": DictObj(
                type=EventTypes.Create, state_key="", content={}, depth=1
            ),
            "A": DictObj(type=EventTypes.Message, depth=2),
            "B": DictObj(type=EventTypes.Message, depth=3),
            "C": DictObj(type=EventTypes.Name, state_key="", depth=3),
            "D": DictObj(type=EventTypes.Message, depth=4),
        }
        # Each entry lists the nodes that the key points back at: B and C
        # both follow A, and D joins the two branches.
        links = {
            "A": ["START"],
            "B": ["A"],
            "C": ["A"],
            "D": ["B", "C"],
        }
        graph = Graph(nodes=events, edges=links)

        self.store.register_events(graph.walk())

        contexts = {}  # type: dict[str, EventContext]

        for ev in graph.walk():
            ctx = yield defer.ensureDeferred(
                self.state.compute_event_context(ev)
            )
            self.store.register_event_context(ev, ctx)
            contexts[ev.event_id] = ctx

        ctx_c = contexts["C"]
        ctx_d = contexts["D"]

        state_before_d = yield defer.ensureDeferred(ctx_d.get_prev_state_ids())
        self.assertEqual(2, len(state_before_d))

        self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event)
        self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group)

    @defer.inlineCallbacks
    def test_branch_basic_conflict(self):
        """Two competing room-name events (B and C) fork from A and merge at D.

        The state before D is expected to contain C rather than B, and D is
        expected to reuse C's state group.
        """
        creator = "@user_id:example.com"

        events = {
            "START": DictObj(
                type=EventTypes.Create,
                state_key="",
                content={"creator": creator},
                depth=1,
            ),
            "A": DictObj(
                type=EventTypes.Member,
                state_key=creator,
                content={"membership": Membership.JOIN},
                membership=Membership.JOIN,
                depth=2,
            ),
            "B": DictObj(type=EventTypes.Name, state_key="", depth=3),
            "C": DictObj(type=EventTypes.Name, state_key="", depth=4),
            "D": DictObj(type=EventTypes.Message, depth=5),
        }
        links = {
            "A": ["START"],
            "B": ["A"],
            "C": ["A"],
            "D": ["B", "C"],
        }
        graph = Graph(nodes=events, edges=links)

        self.store.register_events(graph.walk())

        contexts = {}

        for ev in graph.walk():
            ctx = yield defer.ensureDeferred(
                self.state.compute_event_context(ev)
            )
            self.store.register_event_context(ev, ctx)
            contexts[ev.event_id] = ctx

        # C ends up winning the resolution between B and C

        ctx_c = contexts["C"]
        ctx_d = contexts["D"]

        state_before_d = yield defer.ensureDeferred(ctx_d.get_prev_state_ids())
        self.assertSetEqual({"START", "A", "C"}, set(state_before_d.values()))

        self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event)
        self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group)

    @defer.inlineCallbacks
    def test_branch_have_banned_conflict(self):
        """A ban (C) competes with a name change sent by the banned user (D).

        Both fork from B and merge at E. The state before E is expected to be
        START, A, B and C, with E reusing C's state group.
        """
        creator = "@user_id:example.com"
        banned_user = "@user_id_2:example.com"

        events = {
            "START": DictObj(
                type=EventTypes.Create,
                state_key="",
                content={"creator": creator},
                depth=1,
            ),
            "A": DictObj(
                type=EventTypes.Member,
                state_key=creator,
                content={"membership": Membership.JOIN},
                membership=Membership.JOIN,
                depth=2,
            ),
            "B": DictObj(type=EventTypes.Name, state_key="", depth=3),
            "C": DictObj(
                type=EventTypes.Member,
                state_key=banned_user,
                content={"membership": Membership.BAN},
                membership=Membership.BAN,
                depth=4,
            ),
            "D": DictObj(
                type=EventTypes.Name,
                state_key="",
                depth=4,
                sender=banned_user,
            ),
            "E": DictObj(type=EventTypes.Message, depth=5),
        }
        links = {
            "A": ["START"],
            "B": ["A"],
            "C": ["B"],
            "D": ["B"],
            "E": ["C", "D"],
        }
        graph = Graph(nodes=events, edges=links)

        self.store.register_events(graph.walk())

        contexts = {}

        for ev in graph.walk():
            ctx = yield defer.ensureDeferred(
                self.state.compute_event_context(ev)
            )
            self.store.register_event_context(ev, ctx)
            contexts[ev.event_id] = ctx

        # C ends up winning the resolution between C and D because bans win
        # over other changes

        ctx_c = contexts["C"]
        ctx_e = contexts["E"]

        state_before_e = yield defer.ensureDeferred(ctx_e.get_prev_state_ids())
        expected_ids = {"START", "A", "B", "C"}
        self.assertSetEqual(expected_ids, set(state_before_e.values()))
        self.assertEqual(ctx_c.state_group, ctx_e.state_group_before_event)
        self.assertEqual(ctx_e.state_group_before_event, ctx_e.state_group)

    @defer.inlineCallbacks
    def test_branch_have_perms_conflict(self):
        """A power-level change (B) competes with a name change (C).

        B lowers the second user's level to 30 while room names require 50;
        C is a name event sent by that second user. Both fork from A5 and
        merge at D. The state before D is expected to be A1, A2, A3, A5 and
        B, with D reusing B's state group.
        """
        first_user = "@user_id:example.com"
        second_user = "@user_id2:example.com"

        def _joined(user_id):
            # Membership event putting ``user_id`` into the room.
            return DictObj(
                type=EventTypes.Member,
                state_key=user_id,
                content={"membership": Membership.JOIN},
                membership=Membership.JOIN,
            )

        # Only A1 has an explicit depth; the others are filled in by
        # _add_depths below.
        nodes = {
            "A1": DictObj(
                type=EventTypes.Create,
                state_key="",
                content={"creator": first_user},
                depth=1,
            ),
            "A2": _joined(first_user),
            "A3": _joined(second_user),
            "A4": DictObj(
                type=EventTypes.PowerLevels,
                state_key="",
                content={
                    "events": {"m.room.name": 50},
                    "users": {first_user: 100, second_user: 60},
                },
            ),
            "A5": DictObj(type=EventTypes.Name, state_key=""),
            "B": DictObj(
                type=EventTypes.PowerLevels,
                state_key="",
                content={
                    "events": {"m.room.name": 50},
                    "users": {second_user: 30},
                },
            ),
            "C": DictObj(type=EventTypes.Name, state_key="", sender=second_user),
            "D": DictObj(type=EventTypes.Message),
        }
        edges = {
            "A2": ["A1"],
            "A3": ["A2"],
            "A4": ["A3"],
            "A5": ["A4"],
            "B": ["A5"],
            "C": ["A5"],
            "D": ["B", "C"],
        }
        self._add_depths(nodes, edges)
        graph = Graph(nodes, edges)

        self.store.register_events(graph.walk())

        contexts = {}

        for ev in graph.walk():
            ctx = yield defer.ensureDeferred(
                self.state.compute_event_context(ev)
            )
            self.store.register_event_context(ev, ctx)
            contexts[ev.event_id] = ctx

        # B ends up winning the resolution between B and C because power levels
        # win over other changes.

        ctx_b = contexts["B"]
        ctx_d = contexts["D"]

        state_before_d = yield defer.ensureDeferred(ctx_d.get_prev_state_ids())
        expected_ids = {"A1", "A2", "A3", "A5", "B"}
        self.assertSetEqual(expected_ids, set(state_before_d.values()))

        self.assertEqual(ctx_b.state_group, ctx_d.state_group_before_event)
        self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group)

    def _add_depths(self, nodes, edges):
        def _get_depth(ev):
            node = nodes[ev]
            if "depth" not in node:
                prevs = edges[ev]
                depth = max(_get_depth(prev) for prev in prevs) + 1
                node["depth"] = depth
            return node["depth"]

        for n in nodes:
            _get_depth(n)

    @defer.inlineCallbacks
    def test_annotate_with_old_message(self):
        """Compute the context of a non-state event from explicit old state.

        Both the previous and the current state are expected to equal the
        supplied old state, and the state group must not change across the
        event.
        """
        message = create_event(type="test_message", name="event")

        prior_state = [
            create_event(type="test1", state_key="1"),
            create_event(type="test1", state_key="2"),
            create_event(type="test2", state_key=""),
        ]
        prior_ids = [e.event_id for e in prior_state]

        ctx = yield defer.ensureDeferred(
            self.state.compute_event_context(message, old_state=prior_state)
        )

        ids_before = yield defer.ensureDeferred(ctx.get_prev_state_ids())
        self.assertCountEqual(prior_ids, ids_before.values())

        ids_after = yield defer.ensureDeferred(ctx.get_current_state_ids())
        self.assertCountEqual(prior_ids, ids_after.values())

        self.assertIsNotNone(ctx.state_group_before_event)
        self.assertEqual(ctx.state_group_before_event, ctx.state_group)

    @defer.inlineCallbacks
    def test_annotate_with_old_state(self):
        """Compute the context of a state event from explicit old state.

        The previous state must equal the supplied old state, the current
        state must additionally contain the new event, and the new state
        group must be a delta on top of the group before the event.
        """
        state_event = create_event(type="state", state_key="", name="event")

        prior_state = [
            create_event(type="test1", state_key="1"),
            create_event(type="test1", state_key="2"),
            create_event(type="test2", state_key=""),
        ]
        prior_ids = [e.event_id for e in prior_state]

        ctx = yield defer.ensureDeferred(
            self.state.compute_event_context(state_event, old_state=prior_state)
        )

        ids_before = yield defer.ensureDeferred(ctx.get_prev_state_ids())
        self.assertCountEqual(prior_ids, ids_before.values())

        ids_after = yield defer.ensureDeferred(ctx.get_current_state_ids())
        self.assertCountEqual(
            prior_ids + [state_event.event_id], ids_after.values()
        )

        self.assertIsNotNone(ctx.state_group_before_event)
        self.assertNotEqual(ctx.state_group_before_event, ctx.state_group)
        self.assertEqual(ctx.state_group_before_event, ctx.prev_group)
        self.assertEqual({("state", ""): state_event.event_id}, ctx.delta_ids)

    @defer.inlineCallbacks
    def test_trivial_annotate_message(self):
        """A non-state event with one prev event inherits that event's state.

        The single prev event is given a stored state group; the new event's
        current state and state group are expected to match it exactly.
        """
        parent_id = "prev_event_id"
        message = create_event(
            type="test_message",
            name="event2",
            prev_events=[(parent_id, {})],
        )

        prior_state = [
            create_event(type="test1", state_key="1"),
            create_event(type="test1", state_key="2"),
            create_event(type="test2", state_key=""),
        ]
        prior_state_ids = {
            (e.type, e.state_key): e.event_id for e in prior_state
        }

        parent_group = yield defer.ensureDeferred(
            self.store.store_state_group(
                parent_id, message.room_id, None, None, prior_state_ids
            )
        )
        self.store.register_event_id_state_group(parent_id, parent_group)

        ctx = yield defer.ensureDeferred(
            self.state.compute_event_context(message)
        )

        ids_after = yield defer.ensureDeferred(ctx.get_current_state_ids())

        self.assertEqual(
            {e.event_id for e in prior_state}, set(ids_after.values())
        )

        self.assertEqual(parent_group, ctx.state_group)

    @defer.inlineCallbacks
    def test_trivial_annotate_state(self):
        """A state event with one prev event sees that event's state as prior.

        The state before the new event is expected to match the stored group
        of its single prev event, and the context must carry a state group.
        """
        parent_id = "prev_event_id"
        state_event = create_event(
            type="state",
            state_key="",
            name="event2",
            prev_events=[(parent_id, {})],
        )

        prior_state = [
            create_event(type="test1", state_key="1"),
            create_event(type="test1", state_key="2"),
            create_event(type="test2", state_key=""),
        ]
        prior_state_ids = {
            (e.type, e.state_key): e.event_id for e in prior_state
        }

        parent_group = yield defer.ensureDeferred(
            self.store.store_state_group(
                parent_id, state_event.room_id, None, None, prior_state_ids
            )
        )
        self.store.register_event_id_state_group(parent_id, parent_group)

        ctx = yield defer.ensureDeferred(
            self.state.compute_event_context(state_event)
        )

        ids_before = yield defer.ensureDeferred(ctx.get_prev_state_ids())

        self.assertEqual(
            {e.event_id for e in prior_state}, set(ids_before.values())
        )

        self.assertIsNotNone(ctx.state_group)

    @defer.inlineCallbacks
    def test_resolve_message_conflict(self):
        """Resolve two differing prev states for a non-state event.

        The two prev events share the create event but otherwise carry
        different state; the resolved current state is expected to hold six
        entries and the context must have a state group.
        """
        first_parent = "event_id1"
        second_parent = "event_id2"
        message = create_event(
            type="test_message",
            name="event3",
            prev_events=[(first_parent, {}), (second_parent, {})],
        )

        room_creation = create_event(type=EventTypes.Create, state_key="")

        # Events are created in the same order as the lists are laid out.
        first_state = [
            room_creation,
            create_event(type="test1", state_key="1"),
            create_event(type="test1", state_key="2"),
            create_event(type="test2", state_key=""),
        ]

        second_state = [
            room_creation,
            create_event(type="test1", state_key="1"),
            create_event(type="test3", state_key="2"),
            create_event(type="test4", state_key=""),
        ]

        for branch in (first_state, second_state):
            self.store.register_events(branch)

        ctx = yield self._get_context(
            message, first_parent, first_state, second_parent, second_state
        )

        ids_after = yield defer.ensureDeferred(ctx.get_current_state_ids())

        self.assertEqual(len(ids_after), 6)

        self.assertIsNotNone(ctx.state_group)

    @defer.inlineCallbacks
    def test_resolve_state_conflict(self):
        """Resolve two differing prev states for a state event.

        Events are registered in a separate StateGroupStore whose
        ``get_events`` replaces the one on ``self.store``. The resolved
        current state is expected to hold six entries.
        """
        first_parent = "event_id1"
        second_parent = "event_id2"
        state_event = create_event(
            type="test4",
            state_key="",
            name="event",
            prev_events=[(first_parent, {}), (second_parent, {})],
        )

        room_creation = create_event(type=EventTypes.Create, state_key="")

        # Events are created in the same order as the lists are laid out.
        first_state = [
            room_creation,
            create_event(type="test1", state_key="1"),
            create_event(type="test1", state_key="2"),
            create_event(type="test2", state_key=""),
        ]

        second_state = [
            room_creation,
            create_event(type="test1", state_key="1"),
            create_event(type="test3", state_key="2"),
            create_event(type="test4", state_key=""),
        ]

        event_source = StateGroupStore()
        for branch in (first_state, second_state):
            event_source.register_events(branch)
        self.store.get_events = event_source.get_events

        ctx = yield self._get_context(
            state_event, first_parent, first_state, second_parent, second_state
        )

        ids_after = yield defer.ensureDeferred(ctx.get_current_state_ids())

        self.assertEqual(len(ids_after), 6)

        self.assertIsNotNone(ctx.state_group)

    @defer.inlineCallbacks
    def test_standard_depth_conflict(self):
        """Check that the deeper of two conflicting ("test1", "1") events wins.

        Resolution is run twice with the depths swapped between the two prev
        states, so the outcome cannot be an artefact of which prev event the
        deeper event hangs off.
        """
        first_parent = "event_id1"
        second_parent = "event_id2"
        incoming = create_event(
            type="test4",
            name="event",
            prev_events=[(first_parent, {}), (second_parent, {})],
        )

        joined = create_event(
            type=EventTypes.Member,
            state_key="@user_id:example.com",
            content={"membership": Membership.JOIN},
        )

        levels = create_event(
            type=EventTypes.PowerLevels,
            state_key="",
            content={
                "users": {"@foo:bar": "100", "@user_id:example.com": "100"},
            },
        )

        room_creation = create_event(
            type=EventTypes.Create,
            state_key="",
            content={"creator": "@foo:bar"},
        )

        # Both prev states share these three events and differ only in the
        # trailing ("test1", "1") event.
        shared = [room_creation, levels, joined]
        conflicted_key = ("test1", "1")

        first_state = shared + [
            create_event(type="test1", state_key="1", depth=1),
        ]
        second_state = shared + [
            create_event(type="test1", state_key="1", depth=2),
        ]

        event_source = StateGroupStore()
        event_source.register_events(first_state)
        event_source.register_events(second_state)
        self.store.get_events = event_source.get_events

        ctx = yield self._get_context(
            incoming, first_parent, first_state, second_parent, second_state
        )

        ids_after = yield defer.ensureDeferred(ctx.get_current_state_ids())

        self.assertEqual(second_state[3].event_id, ids_after[conflicted_key])

        # Reverse the depth to make sure we are actually using the depths
        # during state resolution.
        first_state = shared + [
            create_event(type="test1", state_key="1", depth=2),
        ]
        second_state = shared + [
            create_event(type="test1", state_key="1", depth=1),
        ]

        event_source.register_events(first_state)
        event_source.register_events(second_state)

        ctx = yield self._get_context(
            incoming, first_parent, first_state, second_parent, second_state
        )

        ids_after = yield defer.ensureDeferred(ctx.get_current_state_ids())

        self.assertEqual(first_state[3].event_id, ids_after[conflicted_key])

    @defer.inlineCallbacks
    def _get_context(self, event, prev_event_id_1, old_state_1,
                     prev_event_id_2, old_state_2):
        """Store a state group for each prev event, then compute a context.

        Args:
            event: the event whose context is computed.
            prev_event_id_1: id of the first prev event.
            old_state_1: events making up the state at the first prev event.
            prev_event_id_2: id of the second prev event.
            old_state_2: events making up the state at the second prev event.

        Returns:
            The result of ``self.state.compute_event_context(event)``.
        """
        parents = (
            (prev_event_id_1, old_state_1),
            (prev_event_id_2, old_state_2),
        )
        for parent_id, parent_state in parents:
            state_ids = {
                (e.type, e.state_key): e.event_id for e in parent_state
            }
            group = yield defer.ensureDeferred(
                self.store.store_state_group(
                    parent_id, event.room_id, None, None, state_ids
                )
            )
            # Link the prev event to the group that was just stored.
            self.store.register_event_id_state_group(parent_id, group)

        result = yield defer.ensureDeferred(
            self.state.compute_event_context(event)
        )
        return result
Example #15
0
class StateTestCase(unittest.TestCase):
    def setUp(self):
        """Create a StateHandler backed by mocked persistence and replication."""
        persistence_api = [
            "get_unresolved_state_tree",
            "update_current_state",
            "get_latest_pdus_in_context",
            "get_current_state_pdu",
            "get_pdu",
            "get_power_level",
        ]
        self.persistence = Mock(spec=persistence_api)
        self.replication = Mock(spec=["get_pdu"])

        homeserver = Mock(spec=["get_datastore", "get_replication_layer"])
        homeserver.get_datastore.return_value = self.persistence
        homeserver.get_replication_layer.return_value = self.replication
        # "hostname" is not in the spec list above; it is assigned directly.
        homeserver.hostname = "bob.com"

        self.state = StateHandler(homeserver)

    @defer.inlineCallbacks
    def test_new_state_key(self):
        """A PDU whose unresolved tree has an empty second branch is accepted."""
        # We've never seen anything for this state before
        pdu = new_fake_pdu("A", "test", "mem", "x", None, "u")

        self.persistence.get_power_level.side_effect = _gen_get_power_level({})

        tree = ReturnType([pdu], [])
        self.persistence.get_unresolved_state_tree.return_value = (tree, None)

        accepted = yield self.state.handle_new_state(pdu)

        self.assertTrue(accepted)

        self.persistence.get_unresolved_state_tree.assert_called_once_with(pdu)
        self.assertEqual(1, self.persistence.update_current_state.call_count)
        self.assertFalse(self.replication.get_pdu.called)

    @defer.inlineCallbacks
    def test_direct_overwrite(self):
        """A PDU that points straight at the old state is accepted."""
        # We do a direct overwriting of the old state, i.e., the new state
        # points to the old state.
        previous = new_fake_pdu("A", "test", "mem", "x", None, "u1")
        pdu = new_fake_pdu("B", "test", "mem", "x", "A", "u2")

        levels = {"u1": 10, "u2": 5}
        self.persistence.get_power_level.side_effect = _gen_get_power_level(
            levels
        )

        tree = ReturnType([pdu, previous], [previous])
        self.persistence.get_unresolved_state_tree.return_value = (tree, None)

        accepted = yield self.state.handle_new_state(pdu)

        self.assertTrue(accepted)

        self.persistence.get_unresolved_state_tree.assert_called_once_with(pdu)
        self.assertEqual(1, self.persistence.update_current_state.call_count)
        self.assertFalse(self.replication.get_pdu.called)

    @defer.inlineCallbacks
    def test_overwrite(self):
        """A PDU at the head of a three-PDU chain (C -> B -> A) is accepted."""
        oldest = new_fake_pdu("A", "test", "mem", "x", None, "u1")
        middle = new_fake_pdu("B", "test", "mem", "x", "A", "u2")
        pdu = new_fake_pdu("C", "test", "mem", "x", "B", "u3")

        levels = {"u1": 10, "u2": 5, "u3": 0}
        self.persistence.get_power_level.side_effect = _gen_get_power_level(
            levels
        )

        tree = ReturnType([pdu, middle, oldest], [oldest])
        self.persistence.get_unresolved_state_tree.return_value = (tree, None)

        accepted = yield self.state.handle_new_state(pdu)

        self.assertTrue(accepted)

        self.persistence.get_unresolved_state_tree.assert_called_once_with(pdu)
        self.assertEqual(1, self.persistence.update_current_state.call_count)
        self.assertFalse(self.replication.get_pdu.called)

    @defer.inlineCallbacks
    def test_power_level_fail(self):
        """A conflicting PDU from a lower-powered user is rejected."""
        # We try to update the state based on an outdated state, and have a
        # too low power level.
        base = new_fake_pdu("A", "test", "mem", "x", None, "u1")
        rival = new_fake_pdu("B", "test", "mem", "x", None, "u2")
        pdu = new_fake_pdu("C", "test", "mem", "x", "A", "u3")

        levels = {"u1": 10, "u2": 10, "u3": 5}
        self.persistence.get_power_level.side_effect = _gen_get_power_level(
            levels
        )

        tree = ReturnType([pdu, base], [rival, base])
        self.persistence.get_unresolved_state_tree.return_value = (tree, None)

        accepted = yield self.state.handle_new_state(pdu)

        self.assertFalse(accepted)

        self.persistence.get_unresolved_state_tree.assert_called_once_with(pdu)
        # A rejected PDU must not touch the current state.
        self.assertEqual(0, self.persistence.update_current_state.call_count)
        self.assertFalse(self.replication.get_pdu.called)

    @defer.inlineCallbacks
    def test_power_level_succeed(self):
        """A conflicting PDU from a higher-powered user is accepted."""
        # We try to update the state based on an outdated state, but have
        # sufficient power level to force the update.
        base = new_fake_pdu("A", "test", "mem", "x", None, "u1")
        rival = new_fake_pdu("B", "test", "mem", "x", None, "u2")
        pdu = new_fake_pdu("C", "test", "mem", "x", "A", "u3")

        levels = {"u1": 10, "u2": 10, "u3": 15}
        self.persistence.get_power_level.side_effect = _gen_get_power_level(
            levels
        )

        tree = ReturnType([pdu, base], [rival, base])
        self.persistence.get_unresolved_state_tree.return_value = (tree, None)

        accepted = yield self.state.handle_new_state(pdu)

        self.assertTrue(accepted)

        self.persistence.get_unresolved_state_tree.assert_called_once_with(pdu)
        self.assertEqual(1, self.persistence.update_current_state.call_count)
        self.assertFalse(self.replication.get_pdu.called)

    @defer.inlineCallbacks
    def test_power_level_equal_same_len(self):
        """With equal power levels and equal branch lengths the PDU is accepted."""
        # We try to update the state based on an outdated state, the power
        # levels are the same and so are the branch lengths
        base = new_fake_pdu("A", "test", "mem", "x", None, "u1")
        rival = new_fake_pdu("B", "test", "mem", "x", None, "u2")
        pdu = new_fake_pdu("C", "test", "mem", "x", "A", "u3")

        levels = {"u1": 10, "u2": 10, "u3": 10}
        self.persistence.get_power_level.side_effect = _gen_get_power_level(
            levels
        )

        tree = ReturnType([pdu, base], [rival, base])
        self.persistence.get_unresolved_state_tree.return_value = (tree, None)

        accepted = yield self.state.handle_new_state(pdu)

        self.assertTrue(accepted)

        self.persistence.get_unresolved_state_tree.assert_called_once_with(pdu)
        self.assertEqual(1, self.persistence.update_current_state.call_count)
        self.assertFalse(self.replication.get_pdu.called)

    @defer.inlineCallbacks
    def test_power_level_equal_diff_len(self):
        """With equal power levels, a PDU on the longer branch is accepted."""
        # We try to update the state based on an outdated state, the power
        # levels are the same but the branch length of the new one is longer.
        base = new_fake_pdu("A", "test", "mem", "x", None, "u1")
        rival = new_fake_pdu("B", "test", "mem", "x", None, "u2")
        intermediate = new_fake_pdu("C", "test", "mem", "x", "A", "u3")
        pdu = new_fake_pdu("D", "test", "mem", "x", "C", "u4")

        levels = {"u1": 10, "u2": 10, "u3": 10, "u4": 10}
        self.persistence.get_power_level.side_effect = _gen_get_power_level(
            levels
        )

        tree = ReturnType([pdu, intermediate, base], [rival, base])
        self.persistence.get_unresolved_state_tree.return_value = (tree, None)

        accepted = yield self.state.handle_new_state(pdu)

        self.assertTrue(accepted)

        self.persistence.get_unresolved_state_tree.assert_called_once_with(pdu)
        self.assertEqual(1, self.persistence.update_current_state.call_count)
        self.assertFalse(self.replication.get_pdu.called)

    @defer.inlineCallbacks
    def test_missing_pdu(self):
        """A gap in the state tree triggers one get_pdu and a second lookup.

        The first ``get_unresolved_state_tree`` call reports an incomplete
        tree; the mocked ``replication.get_pdu`` then swaps in the complete
        tree, which the handler is expected to fetch again before accepting
        the new PDU.
        """
        # We try to update state against a PDU we haven't yet seen,
        # triggering a get_pdu request

        # The pdu we haven't seen
        old_pdu_1 = new_fake_pdu("A", "test", "mem", "x", None, "u1", depth=0)

        old_pdu_2 = new_fake_pdu("B", "test", "mem", "x", "A", "u2", depth=1)
        new_pdu = new_fake_pdu("C", "test", "mem", "x", "A", "u3", depth=2)

        self.persistence.get_power_level.side_effect = _gen_get_power_level({
            "u1": 10,
            "u2": 10,
            "u3": 20,
        })

        # The return_value of `get_unresolved_state_tree`, which changes after
        # the call to get_pdu. A one-element list is used as a mutable cell
        # so the nested functions below can replace its content.
        tree_to_return = [(ReturnType([new_pdu], [old_pdu_2]), 0)]

        def return_tree(p):
            return tree_to_return[0]

        def set_return_tree(destination, pdu_origin, pdu_id, outlier=False):
            # Simulate the missing PDU having been fetched: from now on the
            # store reports the complete tree.
            tree_to_return[0] = (ReturnType([new_pdu, old_pdu_1],
                                            [old_pdu_2, old_pdu_1]), None)
            return defer.succeed(None)

        self.persistence.get_unresolved_state_tree.side_effect = return_tree

        self.replication.get_pdu.side_effect = set_return_tree

        self.persistence.get_pdu.return_value = None

        is_new = yield self.state.handle_new_state(new_pdu)

        self.assertTrue(is_new)

        self.replication.get_pdu.assert_called_with(
            destination=new_pdu.origin,
            pdu_origin=old_pdu_1.origin,
            pdu_id=old_pdu_1.pdu_id,
            outlier=True)

        self.persistence.get_unresolved_state_tree.assert_called_with(new_pdu)

        # assertEqual rather than the deprecated assertEquals alias, which
        # was removed in Python 3.12.
        self.assertEqual(
            2, self.persistence.get_unresolved_state_tree.call_count)

        self.assertEqual(1, self.persistence.update_current_state.call_count)

    @defer.inlineCallbacks
    def test_missing_pdu_depth_1(self):
        """Two gaps in the state tree trigger two get_pdu calls in sequence.

        ``get_unresolved_state_tree`` walks through three canned answers; each
        mocked ``replication.get_pdu`` call advances to the next one. The
        handler is expected to look the tree up three times, fetch two PDUs
        and then accept the new PDU.
        """
        # We try to update state against a PDU we haven't yet seen,
        # triggering a get_pdu request

        # The pdu we haven't seen
        old_pdu_1 = new_fake_pdu("A", "test", "mem", "x", None, "u1", depth=0)

        old_pdu_2 = new_fake_pdu("B", "test", "mem", "x", "A", "u2", depth=2)
        old_pdu_3 = new_fake_pdu("C", "test", "mem", "x", "B", "u3", depth=3)
        new_pdu = new_fake_pdu("D", "test", "mem", "x", "A", "u4", depth=4)

        self.persistence.get_power_level.side_effect = _gen_get_power_level({
            "u1": 10,
            "u2": 10,
            "u3": 10,
            "u4": 20,
        })

        # The return_value of `get_unresolved_state_tree`, which changes after
        # the call to get_pdu
        tree_to_return = [
            (ReturnType([new_pdu], [old_pdu_3]), 0),
            (ReturnType([new_pdu, old_pdu_1], [old_pdu_3]), 1),
            (ReturnType([new_pdu, old_pdu_1],
                        [old_pdu_3, old_pdu_2, old_pdu_1]), None),
        ]

        # Index of the canned answer currently being served; kept in a list
        # so the nested functions can update it.
        to_return = [0]

        def return_tree(p):
            return tree_to_return[to_return[0]]

        def set_return_tree(destination, pdu_origin, pdu_id, outlier=False):
            to_return[0] += 1
            return defer.succeed(None)

        self.persistence.get_unresolved_state_tree.side_effect = return_tree

        self.replication.get_pdu.side_effect = set_return_tree

        self.persistence.get_pdu.return_value = None

        is_new = yield self.state.handle_new_state(new_pdu)

        self.assertTrue(is_new)

        self.assertEqual(2, self.replication.get_pdu.call_count)

        self.replication.get_pdu.assert_has_calls([
            mock.call(destination=new_pdu.origin,
                      pdu_origin=old_pdu_1.origin,
                      pdu_id=old_pdu_1.pdu_id,
                      outlier=True),
            mock.call(destination=old_pdu_3.origin,
                      pdu_origin=old_pdu_2.origin,
                      pdu_id=old_pdu_2.pdu_id,
                      outlier=True),
        ])

        self.persistence.get_unresolved_state_tree.assert_called_with(new_pdu)

        # assertEqual rather than the deprecated assertEquals alias, which
        # was removed in Python 3.12.
        self.assertEqual(
            3, self.persistence.get_unresolved_state_tree.call_count)

        self.assertEqual(1, self.persistence.update_current_state.call_count)

    @defer.inlineCallbacks
    def test_missing_pdu_depth_2(self):
        """Two unseen PDUs are fetched, starting from the old branch's origin.

        ``get_unresolved_state_tree`` is stubbed to move on to the next entry
        of ``tree_to_return`` each time ``get_pdu`` is called. The first
        expected fetch goes to ``old_pdu_3.origin`` and the second to
        ``new_pdu.origin``; the tree is queried three times in total.
        """
        # The pdu we haven't seen
        old_pdu_1 = new_fake_pdu("A", "test", "mem", "x", None, "u1", depth=0)

        old_pdu_2 = new_fake_pdu("B", "test", "mem", "x", "A", "u2", depth=2)
        old_pdu_3 = new_fake_pdu("C", "test", "mem", "x", "B", "u3", depth=3)
        new_pdu = new_fake_pdu("D", "test", "mem", "x", "A", "u4", depth=1)

        self.persistence.get_power_level.side_effect = _gen_get_power_level(
            {"u1": 10, "u2": 10, "u3": 10, "u4": 20}
        )

        # Successive results of `get_unresolved_state_tree`; the active entry
        # advances after every call to get_pdu.
        # NOTE(review): the integer presumably selects the branch that is
        # missing a PDU (None meaning nothing is missing) - confirm against
        # the real get_unresolved_state_tree.
        tree_to_return = [
            (ReturnType([new_pdu], [old_pdu_3]), 1),
            (ReturnType([new_pdu], [old_pdu_3, old_pdu_2]), 0),
            (
                ReturnType(
                    [new_pdu, old_pdu_1], [old_pdu_3, old_pdu_2, old_pdu_1]
                ),
                None,
            ),
        ]

        # One-element list so the nested functions can mutate the index.
        to_return = [0]

        def return_tree(p):
            return tree_to_return[to_return[0]]

        def set_return_tree(destination, pdu_origin, pdu_id, outlier=False):
            to_return[0] += 1
            return defer.succeed(None)

        self.persistence.get_unresolved_state_tree.side_effect = return_tree

        self.replication.get_pdu.side_effect = set_return_tree

        self.persistence.get_pdu.return_value = None

        is_new = yield self.state.handle_new_state(new_pdu)

        self.assertTrue(is_new)

        self.assertEqual(2, self.replication.get_pdu.call_count)

        self.replication.get_pdu.assert_has_calls([
            mock.call(destination=old_pdu_3.origin,
                      pdu_origin=old_pdu_2.origin,
                      pdu_id=old_pdu_2.pdu_id,
                      outlier=True),
            mock.call(destination=new_pdu.origin,
                      pdu_origin=old_pdu_1.origin,
                      pdu_id=old_pdu_1.pdu_id,
                      outlier=True),
        ])

        self.persistence.get_unresolved_state_tree.assert_called_with(new_pdu)

        # assertEqual, not the deprecated assertEquals alias.
        self.assertEqual(
            3, self.persistence.get_unresolved_state_tree.call_count)

        self.assertEqual(1, self.persistence.update_current_state.call_count)

    @defer.inlineCallbacks
    def test_no_common_ancestor(self):
        """Accept a new PDU when the stubbed state tree needs no fetching.

        Both PDUs are built with ``None`` in the slot where the other tests
        pass a predecessor id, and the stubbed tree is paired with ``None``.
        The new state must be accepted with one tree lookup, one current-state
        update and no ``get_pdu`` request.
        """
        existing_pdu = new_fake_pdu("A", "test", "mem", "x", None, "u1")
        incoming_pdu = new_fake_pdu("B", "test", "mem", "x", None, "u2")

        power_levels = {"u1": 5, "u2": 10}
        self.persistence.get_power_level.side_effect = _gen_get_power_level(
            power_levels
        )

        unresolved = ReturnType([incoming_pdu], [existing_pdu])
        self.persistence.get_unresolved_state_tree.return_value = (
            unresolved,
            None,
        )

        accepted = yield self.state.handle_new_state(incoming_pdu)
        self.assertTrue(accepted)

        tree_lookup = self.persistence.get_unresolved_state_tree
        tree_lookup.assert_called_once_with(incoming_pdu)

        self.assertEqual(1, self.persistence.update_current_state.call_count)
        self.assertFalse(self.replication.get_pdu.called)

    @defer.inlineCallbacks
    def test_new_event(self):
        """Check the fields present on an event after ``handle_new_event``.

        The snapshot's ``fill_out_prev_events`` stub assigns ``prev_events``
        and ``depth``; afterwards ``prev_state`` must equal the encoded id of
        the snapshot's ``prev_state_pdu``.
        """
        event = Mock()
        event.event_id = "12123123@test"

        state_pdu = new_fake_pdu("C", "test", "mem", "x", "A", 20)

        snapshot = Mock()
        snapshot.prev_state_pdu = state_pdu
        event_id = "*****@*****.**"

        def fill_out_prev_events(event):
            event.prev_events = [event_id]
            event.depth = 6

        snapshot.fill_out_prev_events = fill_out_prev_events

        yield self.state.handle_new_event(event, snapshot)

        self.assertLess(5, event.depth)

        # assertEqual, not the deprecated assertEquals alias.
        self.assertEqual(1, len(event.prev_events))

        prev_id = event.prev_events[0]

        self.assertEqual(event_id, prev_id)

        self.assertEqual(encode_event_id(state_pdu.pdu_id, state_pdu.origin),
                         event.prev_state)
Example #16
0
class StateTestCase(unittest.TestCase):
    def setUp(self):
        """Build a StateHandler around a mocked homeserver.

        The mock exposes only the accessors named in ``allowed_attrs``; its
        datastore accessor returns a fresh ``StateGroupStore``.
        """
        self.store = StateGroupStore()

        allowed_attrs = [
            "get_datastore",
            "get_auth",
            "get_state_handler",
            "get_clock",
            "get_state_resolution_handler",
        ]
        homeserver = Mock(spec_set=allowed_attrs)
        homeserver.get_datastore.return_value = self.store
        homeserver.get_state_handler.return_value = None
        homeserver.get_clock.return_value = MockClock()
        # Auth is built only after the accessors above have been configured.
        homeserver.get_auth.return_value = Auth(homeserver)

        def make_resolution_handler():
            return StateResolutionHandler(homeserver)

        homeserver.get_state_resolution_handler = make_resolution_handler

        self.state = StateHandler(homeserver)
        self.event_id = 0

    @defer.inlineCallbacks
    def test_branch_no_conflict(self):
        """A fork with a single state change yields two prev-state entries.

        Graph: START <- A <- {B, C} <- D. Only START and C carry a
        ``state_key``; D's previous state must contain exactly two entries.
        """
        nodes = {
            "START": DictObj(type=EventTypes.Create, state_key="", depth=1),
            "A": DictObj(type=EventTypes.Message, depth=2),
            "B": DictObj(type=EventTypes.Message, depth=3),
            "C": DictObj(type=EventTypes.Name, state_key="", depth=3),
            "D": DictObj(type=EventTypes.Message, depth=4),
        }
        edges = {"A": ["START"], "B": ["A"], "C": ["A"], "D": ["B", "C"]}
        graph = Graph(nodes=nodes, edges=edges)

        self.store.register_events(graph.walk())

        # Compute and record a context for every event, in walk order.
        contexts = {}
        for ev in graph.walk():
            ctx = yield self.state.compute_event_context(ev)
            self.store.register_event_context(ev, ctx)
            contexts[ev.event_id] = ctx

        merge_ctx = contexts["D"]
        prev_state_ids = yield merge_ctx.get_prev_state_ids(self.store)
        self.assertEqual(2, len(prev_state_ids))

    @defer.inlineCallbacks
    def test_branch_basic_conflict(self):
        """Two competing name events on a fork resolve to event C.

        B and C both set ``EventTypes.Name`` with an empty state key on
        separate branches below A; D merges them and its previous state is
        expected to be exactly START, A and C.
        """
        creator = "@user_id:example.com"
        nodes = {
            "START": DictObj(
                type=EventTypes.Create,
                state_key="",
                content={"creator": creator},
                depth=1,
            ),
            "A": DictObj(
                type=EventTypes.Member,
                state_key="@user_id:example.com",
                content={"membership": Membership.JOIN},
                membership=Membership.JOIN,
                depth=2,
            ),
            "B": DictObj(type=EventTypes.Name, state_key="", depth=3),
            "C": DictObj(type=EventTypes.Name, state_key="", depth=4),
            "D": DictObj(type=EventTypes.Message, depth=5),
        }
        edges = {"A": ["START"], "B": ["A"], "C": ["A"], "D": ["B", "C"]}
        graph = Graph(nodes=nodes, edges=edges)

        self.store.register_events(graph.walk())

        # Compute and record a context for every event, in walk order.
        contexts = {}
        for ev in graph.walk():
            ctx = yield self.state.compute_event_context(ev)
            self.store.register_event_context(ev, ctx)
            contexts[ev.event_id] = ctx

        merge_ctx = contexts["D"]
        prev_state_ids = yield merge_ctx.get_prev_state_ids(self.store)

        self.assertSetEqual({"START", "A", "C"}, set(prev_state_ids.values()))

    @defer.inlineCallbacks
    def test_branch_have_banned_conflict(self):
        """A name change sent by a user banned on the sibling branch loses.

        C bans ``@user_id_2:example.com`` on one branch while D is a name
        event sent by that user on the other. E merges both, and its previous
        state is expected to be START, A, B and C, with D left out.
        """
        nodes = {
            "START": DictObj(
                type=EventTypes.Create,
                state_key="",
                content={"creator": "@user_id:example.com"},
                depth=1,
            ),
            "A": DictObj(
                type=EventTypes.Member,
                state_key="@user_id:example.com",
                content={"membership": Membership.JOIN},
                membership=Membership.JOIN,
                depth=2,
            ),
            "B": DictObj(type=EventTypes.Name, state_key="", depth=3),
            "C": DictObj(
                type=EventTypes.Member,
                state_key="@user_id_2:example.com",
                content={"membership": Membership.BAN},
                membership=Membership.BAN,
                depth=4,
            ),
            "D": DictObj(
                type=EventTypes.Name,
                state_key="",
                depth=4,
                sender="@user_id_2:example.com",
            ),
            "E": DictObj(type=EventTypes.Message, depth=5),
        }
        edges = {
            "A": ["START"],
            "B": ["A"],
            "C": ["B"],
            "D": ["B"],
            "E": ["C", "D"],
        }
        graph = Graph(nodes=nodes, edges=edges)

        self.store.register_events(graph.walk())

        # Compute and record a context for every event, in walk order.
        contexts = {}
        for ev in graph.walk():
            ctx = yield self.state.compute_event_context(ev)
            self.store.register_event_context(ev, ctx)
            contexts[ev.event_id] = ctx

        merge_ctx = contexts["E"]
        prev_state_ids = yield merge_ctx.get_prev_state_ids(self.store)

        self.assertSetEqual(
            {"START", "A", "B", "C"}, set(prev_state_ids.values())
        )

    @defer.inlineCallbacks
    def test_branch_have_perms_conflict(self):
        """A name change loses to a power-level change on the sibling branch.

        B sets ``userid2`` to power level 30 while ``m.room.name`` requires
        50; C is a name event sent by ``userid2`` on the other branch. D
        merges both, and its previous state is expected to be A1, A2, A3, A5
        and B. Only A1 has an explicit depth; the rest are filled in by
        ``_add_depths``.
        """
        userid1 = "@user_id:example.com"
        userid2 = "@user_id2:example.com"

        def joined(user_id):
            # Membership event that joins `user_id` to the room.
            return DictObj(
                type=EventTypes.Member,
                state_key=user_id,
                content={"membership": Membership.JOIN},
                membership=Membership.JOIN,
            )

        nodes = {
            "A1": DictObj(
                type=EventTypes.Create,
                state_key="",
                content={"creator": userid1},
                depth=1,
            ),
            "A2": joined(userid1),
            "A3": joined(userid2),
            "A4": DictObj(
                type=EventTypes.PowerLevels,
                state_key="",
                content={
                    "events": {"m.room.name": 50},
                    "users": {userid1: 100, userid2: 60},
                },
            ),
            "A5": DictObj(type=EventTypes.Name, state_key=""),
            "B": DictObj(
                type=EventTypes.PowerLevels,
                state_key="",
                content={
                    "events": {"m.room.name": 50},
                    "users": {userid2: 30},
                },
            ),
            "C": DictObj(type=EventTypes.Name, state_key="", sender=userid2),
            "D": DictObj(type=EventTypes.Message),
        }
        edges = {
            "A2": ["A1"],
            "A3": ["A2"],
            "A4": ["A3"],
            "A5": ["A4"],
            "B": ["A5"],
            "C": ["A5"],
            "D": ["B", "C"],
        }
        self._add_depths(nodes, edges)
        graph = Graph(nodes, edges)

        self.store.register_events(graph.walk())

        # Compute and record a context for every event, in walk order.
        contexts = {}
        for ev in graph.walk():
            ctx = yield self.state.compute_event_context(ev)
            self.store.register_event_context(ev, ctx)
            contexts[ev.event_id] = ctx

        merge_ctx = contexts["D"]
        prev_state_ids = yield merge_ctx.get_prev_state_ids(self.store)

        self.assertSetEqual(
            {"A1", "A2", "A3", "A5", "B"}, set(prev_state_ids.values())
        )

    def _add_depths(self, nodes, edges):
        def _get_depth(ev):
            node = nodes[ev]
            if 'depth' not in node:
                prevs = edges[ev]
                depth = max(_get_depth(prev) for prev in prevs) + 1
                node['depth'] = depth
            return node['depth']

        for n in nodes:
            _get_depth(n)

    @defer.inlineCallbacks
    def test_annotate_with_old_message(self):
        """A non-state event's current state equals the supplied old state.

        The context is computed with an explicit ``old_state``; its current
        state ids must match those events and a state group must be set.
        """
        message = create_event(type="test_message", name="event")

        supplied_state = [
            create_event(type="test1", state_key="1"),
            create_event(type="test1", state_key="2"),
            create_event(type="test2", state_key=""),
        ]

        ctx = yield self.state.compute_event_context(
            message, old_state=supplied_state
        )
        current_state_ids = yield ctx.get_current_state_ids(self.store)

        expected_ids = {e.event_id for e in supplied_state}
        self.assertEqual(expected_ids, set(current_state_ids.values()))

        self.assertIsNotNone(ctx.state_group)

    @defer.inlineCallbacks
    def test_annotate_with_old_state(self):
        """A state event's previous state equals the supplied old state.

        The context is computed with an explicit ``old_state``; its
        prev-state ids must match exactly those events.
        """
        state_event = create_event(type="state", state_key="", name="event")

        supplied_state = [
            create_event(type="test1", state_key="1"),
            create_event(type="test1", state_key="2"),
            create_event(type="test2", state_key=""),
        ]

        ctx = yield self.state.compute_event_context(
            state_event, old_state=supplied_state
        )
        prev_state_ids = yield ctx.get_prev_state_ids(self.store)

        expected_ids = {e.event_id for e in supplied_state}
        self.assertEqual(expected_ids, set(prev_state_ids.values()))

    @defer.inlineCallbacks
    def test_trivial_annotate_message(self):
        """A message with one prev event reuses that event's state group.

        The prev event's state group is stored up front; the computed context
        must expose the same state ids and the very same group.
        """
        parent_id = "prev_event_id"
        message = create_event(
            type="test_message",
            name="event2",
            prev_events=[(parent_id, {})],
        )

        parent_state = [
            create_event(type="test1", state_key="1"),
            create_event(type="test1", state_key="2"),
            create_event(type="test2", state_key=""),
        ]

        state_map = {(e.type, e.state_key): e.event_id for e in parent_state}
        group = self.store.store_state_group(
            parent_id, message.room_id, None, None, state_map
        )
        self.store.register_event_id_state_group(parent_id, group)

        ctx = yield self.state.compute_event_context(message)
        current_state_ids = yield ctx.get_current_state_ids(self.store)

        expected_ids = {e.event_id for e in parent_state}
        self.assertEqual(expected_ids, set(current_state_ids.values()))

        self.assertEqual(group, ctx.state_group)

    @defer.inlineCallbacks
    def test_trivial_annotate_state(self):
        """A state event with one prev event inherits that event's state.

        The prev event's state group is stored up front; the computed
        context's prev-state ids must match it and a state group must be set.
        """
        parent_id = "prev_event_id"
        state_event = create_event(
            type="state",
            state_key="",
            name="event2",
            prev_events=[(parent_id, {})],
        )

        parent_state = [
            create_event(type="test1", state_key="1"),
            create_event(type="test1", state_key="2"),
            create_event(type="test2", state_key=""),
        ]

        state_map = {(e.type, e.state_key): e.event_id for e in parent_state}
        group = self.store.store_state_group(
            parent_id, state_event.room_id, None, None, state_map
        )
        self.store.register_event_id_state_group(parent_id, group)

        ctx = yield self.state.compute_event_context(state_event)
        prev_state_ids = yield ctx.get_prev_state_ids(self.store)

        expected_ids = {e.event_id for e in parent_state}
        self.assertEqual(expected_ids, set(prev_state_ids.values()))

        self.assertIsNotNone(ctx.state_group)

    @defer.inlineCallbacks
    def test_resolve_message_conflict(self):
        """A message with two prev events gets the merged state of both.

        The two prev events share only the creation event object; the
        resolved current state is expected to hold six entries and a state
        group must be assigned.
        """
        first_parent = "event_id1"
        second_parent = "event_id2"
        message = create_event(
            type="test_message",
            name="event3",
            prev_events=[(first_parent, {}), (second_parent, {})],
        )

        creation = create_event(type=EventTypes.Create, state_key="")

        first_state = [
            creation,
            create_event(type="test1", state_key="1"),
            create_event(type="test1", state_key="2"),
            create_event(type="test2", state_key=""),
        ]
        second_state = [
            creation,
            create_event(type="test1", state_key="1"),
            create_event(type="test3", state_key="2"),
            create_event(type="test4", state_key=""),
        ]

        for branch_state in (first_state, second_state):
            self.store.register_events(branch_state)

        ctx = yield self._get_context(
            message, first_parent, first_state, second_parent, second_state
        )
        current_state_ids = yield ctx.get_current_state_ids(self.store)

        self.assertEqual(len(current_state_ids), 6)

        self.assertIsNotNone(ctx.state_group)

    @defer.inlineCallbacks
    def test_resolve_state_conflict(self):
        """A state event with two prev events gets the merged state of both.

        The old-state events are registered in a separate ``StateGroupStore``
        whose ``get_events`` is patched onto ``self.store``. The resolved
        current state is expected to hold six entries and a state group must
        be assigned.
        """
        first_parent = "event_id1"
        second_parent = "event_id2"
        state_event = create_event(
            type="test4",
            state_key="",
            name="event",
            prev_events=[(first_parent, {}), (second_parent, {})],
        )

        creation = create_event(type=EventTypes.Create, state_key="")

        first_state = [
            creation,
            create_event(type="test1", state_key="1"),
            create_event(type="test1", state_key="2"),
            create_event(type="test2", state_key=""),
        ]
        second_state = [
            creation,
            create_event(type="test1", state_key="1"),
            create_event(type="test3", state_key="2"),
            create_event(type="test4", state_key=""),
        ]

        event_store = StateGroupStore()
        for branch_state in (first_state, second_state):
            event_store.register_events(branch_state)
        self.store.get_events = event_store.get_events

        ctx = yield self._get_context(
            state_event, first_parent, first_state, second_parent, second_state
        )
        current_state_ids = yield ctx.get_current_state_ids(self.store)

        self.assertEqual(len(current_state_ids), 6)

        self.assertIsNotNone(ctx.state_group)

    @defer.inlineCallbacks
    def test_standard_depth_conflict(self):
        """The conflicting event with the greater depth wins, either way round.

        Both prev events share creation, power-level and membership events
        and differ only in a ``("test1", "1")`` event. The test runs once with
        the deeper event on the second branch and once on the first, and the
        deeper one must be selected each time.
        """
        first_parent = "event_id1"
        second_parent = "event_id2"
        event = create_event(
            type="test4",
            name="event",
            prev_events=[(first_parent, {}), (second_parent, {})],
        )

        member_event = create_event(
            type=EventTypes.Member,
            state_key="@user_id:example.com",
            content={"membership": Membership.JOIN},
        )

        power_levels = create_event(
            type=EventTypes.PowerLevels,
            state_key="",
            content={
                "users": {"@foo:bar": "100", "@user_id:example.com": "100"}
            },
        )

        creation = create_event(
            type=EventTypes.Create,
            state_key="",
            content={"creator": "@foo:bar"},
        )

        def branch_state(depth):
            # Shared events plus a fresh ("test1", "1") event at `depth`.
            return [
                creation,
                power_levels,
                member_event,
                create_event(type="test1", state_key="1", depth=depth),
            ]

        conflicted_key = ("test1", "1")

        first_state = branch_state(1)
        second_state = branch_state(2)

        event_store = StateGroupStore()
        event_store.register_events(first_state)
        event_store.register_events(second_state)
        self.store.get_events = event_store.get_events

        ctx = yield self._get_context(
            event, first_parent, first_state, second_parent, second_state
        )
        current_state_ids = yield ctx.get_current_state_ids(self.store)

        self.assertEqual(
            second_state[3].event_id, current_state_ids[conflicted_key]
        )

        # Reverse the depth to make sure we are actually using the depths
        # during state resolution.
        first_state = branch_state(2)
        second_state = branch_state(1)

        event_store.register_events(first_state)
        event_store.register_events(second_state)

        ctx = yield self._get_context(
            event, first_parent, first_state, second_parent, second_state
        )
        current_state_ids = yield ctx.get_current_state_ids(self.store)

        self.assertEqual(
            first_state[3].event_id, current_state_ids[conflicted_key]
        )

    def _get_context(self, event, prev_event_id_1, old_state_1,
                     prev_event_id_2, old_state_2):
        sg1 = self.store.store_state_group(
            prev_event_id_1,
            event.room_id,
            None,
            None,
            {(e.type, e.state_key): e.event_id
             for e in old_state_1},
        )
        self.store.register_event_id_state_group(prev_event_id_1, sg1)

        sg2 = self.store.store_state_group(
            prev_event_id_2,
            event.room_id,
            None,
            None,
            {(e.type, e.state_key): e.event_id
             for e in old_state_2},
        )
        self.store.register_event_id_state_group(prev_event_id_2, sg2)

        return self.state.compute_event_context(event)
Example #17
0
 def build_state_handler(self):
     """Return a new StateHandler constructed with this object."""
     handler = StateHandler(self)
     return handler
Example #18
0
class StateTestCase(unittest.TestCase):
    def setUp(self):
        """Build a StateHandler around a mocked, configured homeserver.

        The mock exposes only the attributes named in ``allowed_attrs``; its
        datastore accessor returns a fresh ``StateGroupStore`` and its config
        comes from ``default_config("tesths", True)``.
        """
        self.store = StateGroupStore()

        allowed_attrs = [
            "config",
            "get_datastore",
            "get_auth",
            "get_state_handler",
            "get_clock",
            "get_state_resolution_handler",
        ]
        homeserver = Mock(spec_set=allowed_attrs)
        homeserver.config = default_config("tesths", True)
        homeserver.get_datastore.return_value = self.store
        homeserver.get_state_handler.return_value = None
        homeserver.get_clock.return_value = MockClock()
        # Auth is built only after the attributes above have been configured.
        homeserver.get_auth.return_value = Auth(homeserver)

        def make_resolution_handler():
            return StateResolutionHandler(homeserver)

        homeserver.get_state_resolution_handler = make_resolution_handler

        self.state = StateHandler(homeserver)
        self.event_id = 0

    @defer.inlineCallbacks
    def test_branch_no_conflict(self):
        """A fork with a single state change yields two prev-state entries.

        Graph: START <- A <- {B, C} <- D. Only START and C carry a
        ``state_key``; D's previous state must contain exactly two entries.
        """
        room_events = {
            "START": DictObj(
                type=EventTypes.Create,
                state_key="",
                content={},
                depth=1,
            ),
            "A": DictObj(type=EventTypes.Message, depth=2),
            "B": DictObj(type=EventTypes.Message, depth=3),
            "C": DictObj(type=EventTypes.Name, state_key="", depth=3),
            "D": DictObj(type=EventTypes.Message, depth=4),
        }
        parents = {
            "A": ["START"],
            "B": ["A"],
            "C": ["A"],
            "D": ["B", "C"],
        }
        graph = Graph(nodes=room_events, edges=parents)

        self.store.register_events(graph.walk())

        # Compute and record a context for every event, in walk order.
        contexts_by_id = {}
        for node in graph.walk():
            node_ctx = yield self.state.compute_event_context(node)
            self.store.register_event_context(node, node_ctx)
            contexts_by_id[node.event_id] = node_ctx

        tip_ctx = contexts_by_id["D"]
        prev_state_ids = yield tip_ctx.get_prev_state_ids(self.store)
        self.assertEqual(2, len(prev_state_ids))

    @defer.inlineCallbacks
    def test_branch_basic_conflict(self):
        """Two competing name events on a fork resolve to event C.

        B and C both set ``EventTypes.Name`` with an empty state key on
        separate branches below A; D merges them and its previous state is
        expected to be exactly START, A and C.
        """
        room_events = {
            "START": DictObj(
                type=EventTypes.Create,
                state_key="",
                content={"creator": "@user_id:example.com"},
                depth=1,
            ),
            "A": DictObj(
                type=EventTypes.Member,
                state_key="@user_id:example.com",
                content={"membership": Membership.JOIN},
                membership=Membership.JOIN,
                depth=2,
            ),
            "B": DictObj(type=EventTypes.Name, state_key="", depth=3),
            "C": DictObj(type=EventTypes.Name, state_key="", depth=4),
            "D": DictObj(type=EventTypes.Message, depth=5),
        }
        parents = {
            "A": ["START"],
            "B": ["A"],
            "C": ["A"],
            "D": ["B", "C"],
        }
        graph = Graph(nodes=room_events, edges=parents)

        self.store.register_events(graph.walk())

        # Compute and record a context for every event, in walk order.
        contexts_by_id = {}
        for node in graph.walk():
            node_ctx = yield self.state.compute_event_context(node)
            self.store.register_event_context(node, node_ctx)
            contexts_by_id[node.event_id] = node_ctx

        tip_ctx = contexts_by_id["D"]
        prev_state_ids = yield tip_ctx.get_prev_state_ids(self.store)

        resolved_ids = set(prev_state_ids.values())
        self.assertSetEqual({"START", "A", "C"}, resolved_ids)

    @defer.inlineCallbacks
    def test_branch_have_banned_conflict(self):
        """A name change sent by a user banned on the sibling branch loses.

        C bans ``@user_id_2:example.com`` on one branch while D is a name
        event sent by that user on the other. E merges both, and its previous
        state is expected to be START, A, B and C, with D left out.
        """
        room_events = {
            "START": DictObj(
                type=EventTypes.Create,
                state_key="",
                content={"creator": "@user_id:example.com"},
                depth=1,
            ),
            "A": DictObj(
                type=EventTypes.Member,
                state_key="@user_id:example.com",
                content={"membership": Membership.JOIN},
                membership=Membership.JOIN,
                depth=2,
            ),
            "B": DictObj(type=EventTypes.Name, state_key="", depth=3),
            "C": DictObj(
                type=EventTypes.Member,
                state_key="@user_id_2:example.com",
                content={"membership": Membership.BAN},
                membership=Membership.BAN,
                depth=4,
            ),
            "D": DictObj(
                type=EventTypes.Name,
                state_key="",
                depth=4,
                sender="@user_id_2:example.com",
            ),
            "E": DictObj(type=EventTypes.Message, depth=5),
        }
        parents = {
            "A": ["START"],
            "B": ["A"],
            "C": ["B"],
            "D": ["B"],
            "E": ["C", "D"],
        }
        graph = Graph(nodes=room_events, edges=parents)

        self.store.register_events(graph.walk())

        # Compute and record a context for every event, in walk order.
        contexts_by_id = {}
        for node in graph.walk():
            node_ctx = yield self.state.compute_event_context(node)
            self.store.register_event_context(node, node_ctx)
            contexts_by_id[node.event_id] = node_ctx

        tip_ctx = contexts_by_id["E"]
        prev_state_ids = yield tip_ctx.get_prev_state_ids(self.store)

        resolved_ids = set(prev_state_ids.values())
        self.assertSetEqual({"START", "A", "B", "C"}, resolved_ids)

    @defer.inlineCallbacks
    def test_branch_have_perms_conflict(self):
        """A name change loses to a power-level change on the sibling branch.

        B sets ``userid2`` to power level 30 while ``m.room.name`` requires
        50; C is a name event sent by ``userid2`` on the other branch. D
        merges both, and its previous state is expected to be A1, A2, A3, A5
        and B. Only A1 has an explicit depth; the rest are filled in by
        ``_add_depths``.
        """
        userid1 = "@user_id:example.com"
        userid2 = "@user_id2:example.com"

        nodes = {
            "A1": DictObj(
                type=EventTypes.Create,
                state_key="",
                content={"creator": userid1},
                depth=1,
            ),
            "A2": DictObj(
                type=EventTypes.Member,
                state_key=userid1,
                content={"membership": Membership.JOIN},
                membership=Membership.JOIN,
            ),
            "A3": DictObj(
                type=EventTypes.Member,
                state_key=userid2,
                content={"membership": Membership.JOIN},
                membership=Membership.JOIN,
            ),
            "A4": DictObj(
                type=EventTypes.PowerLevels,
                state_key="",
                content={
                    "events": {
                        "m.room.name": 50,
                    },
                    "users": {
                        userid1: 100,
                        userid2: 60,
                    },
                },
            ),
            "A5": DictObj(type=EventTypes.Name, state_key=""),
            "B": DictObj(
                type=EventTypes.PowerLevels,
                state_key="",
                content={
                    "events": {
                        "m.room.name": 50,
                    },
                    "users": {
                        userid2: 30,
                    },
                },
            ),
            "C": DictObj(type=EventTypes.Name, state_key="", sender=userid2),
            "D": DictObj(type=EventTypes.Message),
        }
        edges = {
            "A2": ["A1"],
            "A3": ["A2"],
            "A4": ["A3"],
            "A5": ["A4"],
            "B": ["A5"],
            "C": ["A5"],
            "D": ["B", "C"],
        }
        self._add_depths(nodes, edges)
        graph = Graph(nodes, edges)

        self.store.register_events(graph.walk())

        # Compute and record a context for every event, in walk order.
        contexts_by_id = {}
        for node in graph.walk():
            node_ctx = yield self.state.compute_event_context(node)
            self.store.register_event_context(node, node_ctx)
            contexts_by_id[node.event_id] = node_ctx

        tip_ctx = contexts_by_id["D"]
        prev_state_ids = yield tip_ctx.get_prev_state_ids(self.store)

        resolved_ids = set(prev_state_ids.values())
        self.assertSetEqual({"A1", "A2", "A3", "A5", "B"}, resolved_ids)

    def _add_depths(self, nodes, edges):
        def _get_depth(ev):
            node = nodes[ev]
            if 'depth' not in node:
                prevs = edges[ev]
                depth = max(_get_depth(prev) for prev in prevs) + 1
                node['depth'] = depth
            return node['depth']

        for n in nodes:
            _get_depth(n)

    @defer.inlineCallbacks
    def test_annotate_with_old_message(self):
        """Compute a context for an event without a state_key from ``old_state``.

        The context's current state ids must be exactly the event ids of the
        supplied old state, and the context must have a state group.
        """
        message = create_event(type="test_message", name="event")

        state_specs = (("test1", "1"), ("test1", "2"), ("test2", ""))
        prior_state = [
            create_event(type=ev_type, state_key=key) for ev_type, key in state_specs
        ]

        context = yield self.state.compute_event_context(
            message, old_state=prior_state
        )
        state_ids = yield context.get_current_state_ids(self.store)

        expected_ids = {ev.event_id for ev in prior_state}
        self.assertEqual(expected_ids, set(state_ids.values()))

        self.assertIsNotNone(context.state_group)

    @defer.inlineCallbacks
    def test_annotate_with_old_state(self):
        """Compute a context for an event with a state_key from ``old_state``.

        The context's previous state ids must be exactly the event ids of the
        supplied old state.
        """
        state_event = create_event(type="state", state_key="", name="event")

        state_specs = (("test1", "1"), ("test1", "2"), ("test2", ""))
        prior_state = [
            create_event(type=ev_type, state_key=key) for ev_type, key in state_specs
        ]

        context = yield self.state.compute_event_context(
            state_event, old_state=prior_state
        )
        previous_ids = yield context.get_prev_state_ids(self.store)

        expected_ids = {ev.event_id for ev in prior_state}
        self.assertEqual(expected_ids, set(previous_ids.values()))

    @defer.inlineCallbacks
    def test_trivial_annotate_message(self):
        """Compute a context for an event with one prev event and a stored group.

        A state group is stored and registered for the single prev event; the
        resulting context must expose that group's event ids as current state
        and report the same state group.
        """
        parent_id = "prev_event_id"
        message = create_event(
            type="test_message", name="event2", prev_events=[(parent_id, {})]
        )

        state_specs = (("test1", "1"), ("test1", "2"), ("test2", ""))
        prior_state = [
            create_event(type=ev_type, state_key=key) for ev_type, key in state_specs
        ]

        room_id = message.room_id
        state_map = {(ev.type, ev.state_key): ev.event_id for ev in prior_state}
        stored_group = self.store.store_state_group(
            parent_id, room_id, None, None, state_map
        )
        self.store.register_event_id_state_group(parent_id, stored_group)

        context = yield self.state.compute_event_context(message)
        state_ids = yield context.get_current_state_ids(self.store)

        expected_ids = {ev.event_id for ev in prior_state}
        self.assertEqual(expected_ids, set(state_ids.values()))

        self.assertEqual(stored_group, context.state_group)

    @defer.inlineCallbacks
    def test_trivial_annotate_state(self):
        """Compute a context for a state-keyed event with one prev event.

        A state group is stored and registered for the single prev event; the
        resulting context must expose that group's event ids as previous state
        and must have a state group.
        """
        parent_id = "prev_event_id"
        state_event = create_event(
            type="state", state_key="", name="event2", prev_events=[(parent_id, {})]
        )

        state_specs = (("test1", "1"), ("test1", "2"), ("test2", ""))
        prior_state = [
            create_event(type=ev_type, state_key=key) for ev_type, key in state_specs
        ]

        room_id = state_event.room_id
        state_map = {(ev.type, ev.state_key): ev.event_id for ev in prior_state}
        stored_group = self.store.store_state_group(
            parent_id, room_id, None, None, state_map
        )
        self.store.register_event_id_state_group(parent_id, stored_group)

        context = yield self.state.compute_event_context(state_event)
        previous_ids = yield context.get_prev_state_ids(self.store)

        expected_ids = {ev.event_id for ev in prior_state}
        self.assertEqual(expected_ids, set(previous_ids.values()))

        self.assertIsNotNone(context.state_group)

    @defer.inlineCallbacks
    def test_resolve_message_conflict(self):
        """Resolve two differing prev-event states for an event with no state_key.

        Both branches share the creation event and one ``test1`` key; the
        resolved current state must have six entries and a state group.
        """
        first_parent = "event_id1"
        second_parent = "event_id2"
        message = create_event(
            type="test_message",
            name="event3",
            prev_events=[(first_parent, {}), (second_parent, {})],
        )

        creation = create_event(type=EventTypes.Create, state_key="")

        branch_one = [creation] + [
            create_event(type=ev_type, state_key=key)
            for ev_type, key in (("test1", "1"), ("test1", "2"), ("test2", ""))
        ]
        branch_two = [creation] + [
            create_event(type=ev_type, state_key=key)
            for ev_type, key in (("test1", "1"), ("test3", "2"), ("test4", ""))
        ]

        for branch in (branch_one, branch_two):
            self.store.register_events(branch)

        context = yield self._get_context(
            message, first_parent, branch_one, second_parent, branch_two
        )
        state_ids = yield context.get_current_state_ids(self.store)

        self.assertEqual(len(state_ids), 6)
        self.assertIsNotNone(context.state_group)

    @defer.inlineCallbacks
    def test_resolve_state_conflict(self):
        """Resolve two differing prev-event states for a state-keyed event.

        The old-state events are registered in a separate ``StateGroupStore``
        whose ``get_events`` is patched onto ``self.store``; the resolved
        current state must have six entries and a state group.
        """
        first_parent = "event_id1"
        second_parent = "event_id2"
        state_event = create_event(
            type="test4",
            state_key="",
            name="event",
            prev_events=[(first_parent, {}), (second_parent, {})],
        )

        creation = create_event(type=EventTypes.Create, state_key="")

        branch_one = [creation] + [
            create_event(type=ev_type, state_key=key)
            for ev_type, key in (("test1", "1"), ("test1", "2"), ("test2", ""))
        ]
        branch_two = [creation] + [
            create_event(type=ev_type, state_key=key)
            for ev_type, key in (("test1", "1"), ("test3", "2"), ("test4", ""))
        ]

        event_source = StateGroupStore()
        for branch in (branch_one, branch_two):
            event_source.register_events(branch)
        self.store.get_events = event_source.get_events

        context = yield self._get_context(
            state_event, first_parent, branch_one, second_parent, branch_two
        )
        state_ids = yield context.get_current_state_ids(self.store)

        self.assertEqual(len(state_ids), 6)
        self.assertIsNotNone(context.state_group)

    @defer.inlineCallbacks
    def test_standard_depth_conflict(self):
        """Check that the deeper of two conflicting ``test1`` events wins.

        Two branches differ only in the depth of their ``("test1", "1")``
        event. The resolved state must pick the depth-2 event, whichever
        branch it sits in.
        """
        first_parent = "event_id1"
        second_parent = "event_id2"
        event = create_event(
            type="test4",
            name="event",
            prev_events=[(first_parent, {}), (second_parent, {})],
        )

        member_event = create_event(
            type=EventTypes.Member,
            state_key="@user_id:example.com",
            content={"membership": Membership.JOIN},
        )
        power_levels = create_event(
            type=EventTypes.PowerLevels,
            state_key="",
            content={"users": {"@foo:bar": "100", "@user_id:example.com": "100"}},
        )
        creation = create_event(
            type=EventTypes.Create, state_key="", content={"creator": "@foo:bar"}
        )
        shared_state = [creation, power_levels, member_event]

        # First round: the deeper event sits in the second branch.
        shallow = create_event(type="test1", state_key="1", depth=1)
        deep = create_event(type="test1", state_key="1", depth=2)
        branch_one = shared_state + [shallow]
        branch_two = shared_state + [deep]

        event_source = StateGroupStore()
        event_source.register_events(branch_one)
        event_source.register_events(branch_two)
        self.store.get_events = event_source.get_events

        context = yield self._get_context(
            event, first_parent, branch_one, second_parent, branch_two
        )
        state_ids = yield context.get_current_state_ids(self.store)

        self.assertEqual(deep.event_id, state_ids[("test1", "1")])

        # Reverse the depth to make sure we are actually using the depths
        # during state resolution.
        deep = create_event(type="test1", state_key="1", depth=2)
        shallow = create_event(type="test1", state_key="1", depth=1)
        branch_one = shared_state + [deep]
        branch_two = shared_state + [shallow]

        event_source.register_events(branch_one)
        event_source.register_events(branch_two)

        context = yield self._get_context(
            event, first_parent, branch_one, second_parent, branch_two
        )
        state_ids = yield context.get_current_state_ids(self.store)

        self.assertEqual(deep.event_id, state_ids[("test1", "1")])

    def _get_context(
        self, event, prev_event_id_1, old_state_1, prev_event_id_2, old_state_2
    ):
        sg1 = self.store.store_state_group(
            prev_event_id_1,
            event.room_id,
            None,
            None,
            {(e.type, e.state_key): e.event_id for e in old_state_1},
        )
        self.store.register_event_id_state_group(prev_event_id_1, sg1)

        sg2 = self.store.store_state_group(
            prev_event_id_2,
            event.room_id,
            None,
            None,
            {(e.type, e.state_key): e.event_id for e in old_state_2},
        )
        self.store.register_event_id_state_group(prev_event_id_2, sg2)

        return self.state.compute_event_context(event)
Example #19
0
 def get_state_handler(self) -> StateHandler:
     """Construct and return a ``StateHandler`` built from this object."""
     handler = StateHandler(self)
     return handler
Example #20
0
class StateTestCase(unittest.TestCase):
    """Tests for ``StateHandler.handle_new_state`` and ``handle_new_event``.

    The datastore and the replication layer are mocks, so each test scripts
    what ``get_unresolved_state_tree`` returns and then inspects which
    persistence / replication calls the handler made.
    """

    def setUp(self):
        """Build a ``StateHandler`` around mocked persistence and replication."""
        self.persistence = Mock(
            spec=[
                "get_unresolved_state_tree",
                "update_current_state",
                "get_latest_pdus_in_context",
                "get_current_state_pdu",
                "get_pdu",
            ]
        )
        self.replication = Mock(spec=["get_pdu"])

        hs = Mock(spec=["get_datastore", "get_replication_layer"])
        hs.get_datastore.return_value = self.persistence
        hs.get_replication_layer.return_value = self.replication
        hs.hostname = "bob.com"

        self.state = StateHandler(hs)

    @defer.inlineCallbacks
    def test_new_state_key(self):
        """A PDU for previously unseen state is accepted and stored."""
        # We've never seen anything for this state before
        new_pdu = new_fake_pdu_entry("A", "test", "mem", "x", None, 10)

        self.persistence.get_unresolved_state_tree.return_value = ReturnType(
            [new_pdu], []
        )

        is_new = yield self.state.handle_new_state(new_pdu)

        self.assertTrue(is_new)

        self.persistence.get_unresolved_state_tree.assert_called_once_with(new_pdu)

        self.assertEqual(1, self.persistence.update_current_state.call_count)

        self.assertFalse(self.replication.get_pdu.called)

    @defer.inlineCallbacks
    def test_direct_overwrite(self):
        """A PDU that points directly at the old state replaces it."""
        # We do a direct overwriting of the old state, i.e., the new state
        # points to the old state.

        old_pdu = new_fake_pdu_entry("A", "test", "mem", "x", None, 10)
        new_pdu = new_fake_pdu_entry("B", "test", "mem", "x", "A", 5)

        self.persistence.get_unresolved_state_tree.return_value = ReturnType(
            [new_pdu, old_pdu], [old_pdu]
        )

        is_new = yield self.state.handle_new_state(new_pdu)

        self.assertTrue(is_new)

        self.persistence.get_unresolved_state_tree.assert_called_once_with(new_pdu)

        self.assertEqual(1, self.persistence.update_current_state.call_count)

        self.assertFalse(self.replication.get_pdu.called)

    @defer.inlineCallbacks
    def test_power_level_fail(self):
        """An update from outdated state with too low a power level is rejected."""
        # We try to update the state based on an outdated state, and have a
        # too low power level.

        old_pdu_1 = new_fake_pdu_entry("A", "test", "mem", "x", None, 10)
        old_pdu_2 = new_fake_pdu_entry("B", "test", "mem", "x", None, 10)
        new_pdu = new_fake_pdu_entry("C", "test", "mem", "x", "A", 5)

        self.persistence.get_unresolved_state_tree.return_value = ReturnType(
            [new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1]
        )

        is_new = yield self.state.handle_new_state(new_pdu)

        self.assertFalse(is_new)

        self.persistence.get_unresolved_state_tree.assert_called_once_with(new_pdu)

        self.assertEqual(0, self.persistence.update_current_state.call_count)

        self.assertFalse(self.replication.get_pdu.called)

    @defer.inlineCallbacks
    def test_power_level_succeed(self):
        """An update from outdated state with a sufficient power level wins."""
        # We try to update the state based on an outdated state, but have
        # sufficient power level to force the update.

        old_pdu_1 = new_fake_pdu_entry("A", "test", "mem", "x", None, 10)
        old_pdu_2 = new_fake_pdu_entry("B", "test", "mem", "x", None, 10)
        new_pdu = new_fake_pdu_entry("C", "test", "mem", "x", "A", 15)

        self.persistence.get_unresolved_state_tree.return_value = ReturnType(
            [new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1]
        )

        is_new = yield self.state.handle_new_state(new_pdu)

        self.assertTrue(is_new)

        self.persistence.get_unresolved_state_tree.assert_called_once_with(new_pdu)

        self.assertEqual(1, self.persistence.update_current_state.call_count)

        self.assertFalse(self.replication.get_pdu.called)

    @defer.inlineCallbacks
    def test_power_level_equal_same_len(self):
        """Equal power levels and equal branch lengths: the new PDU is accepted."""
        # We try to update the state based on an outdated state, the power
        # levels are the same and so are the branch lengths

        old_pdu_1 = new_fake_pdu_entry("A", "test", "mem", "x", None, 10)
        old_pdu_2 = new_fake_pdu_entry("B", "test", "mem", "x", None, 10)
        new_pdu = new_fake_pdu_entry("C", "test", "mem", "x", "A", 10)

        self.persistence.get_unresolved_state_tree.return_value = ReturnType(
            [new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1]
        )

        is_new = yield self.state.handle_new_state(new_pdu)

        self.assertTrue(is_new)

        self.persistence.get_unresolved_state_tree.assert_called_once_with(new_pdu)

        self.assertEqual(1, self.persistence.update_current_state.call_count)

        self.assertFalse(self.replication.get_pdu.called)

    @defer.inlineCallbacks
    def test_power_level_equal_diff_len(self):
        """Equal power levels but a longer new branch: the new PDU is accepted."""
        # We try to update the state based on an outdated state, the power
        # levels are the same but the branch length of the new one is longer.

        old_pdu_1 = new_fake_pdu_entry("A", "test", "mem", "x", None, 10)
        old_pdu_2 = new_fake_pdu_entry("B", "test", "mem", "x", None, 10)
        old_pdu_3 = new_fake_pdu_entry("C", "test", "mem", "x", "A", 10)
        new_pdu = new_fake_pdu_entry("D", "test", "mem", "x", "C", 10)

        self.persistence.get_unresolved_state_tree.return_value = ReturnType(
            [new_pdu, old_pdu_3, old_pdu_1], [old_pdu_2, old_pdu_1]
        )

        is_new = yield self.state.handle_new_state(new_pdu)

        self.assertTrue(is_new)

        self.persistence.get_unresolved_state_tree.assert_called_once_with(new_pdu)

        self.assertEqual(1, self.persistence.update_current_state.call_count)

        self.assertFalse(self.replication.get_pdu.called)

    @defer.inlineCallbacks
    def test_missing_pdu(self):
        """A reference to an unseen PDU triggers a fetch and a second tree lookup."""
        # We try to update state against a PDU we haven't yet seen,
        # triggering a get_pdu request

        # The pdu we haven't seen
        old_pdu_1 = new_fake_pdu_entry("A", "test", "mem", "x", None, 10)

        old_pdu_2 = new_fake_pdu_entry("B", "test", "mem", "x", None, 10)
        new_pdu = new_fake_pdu_entry("C", "test", "mem", "x", "A", 20)

        # The return_value of `get_unresolved_state_tree`, which changes after
        # the call to get_pdu. A one-element list is used so the nested
        # functions below can rebind the value.
        tree_to_return = [ReturnType([new_pdu], [old_pdu_2])]

        def return_tree(p):
            return tree_to_return[0]

        def set_return_tree(*args, **kwargs):
            tree_to_return[0] = ReturnType(
                [new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1]
            )

        self.persistence.get_unresolved_state_tree.side_effect = return_tree

        self.replication.get_pdu.side_effect = set_return_tree

        self.persistence.get_pdu.return_value = None

        is_new = yield self.state.handle_new_state(new_pdu)

        self.assertTrue(is_new)

        self.persistence.get_unresolved_state_tree.assert_called_with(new_pdu)

        # `assertEquals` is a deprecated alias that was removed in Python 3.12.
        self.assertEqual(2, self.persistence.get_unresolved_state_tree.call_count)

        self.assertEqual(1, self.persistence.update_current_state.call_count)

    @defer.inlineCallbacks
    def test_new_event(self):
        """``handle_new_event`` fills prev_events, depth and prev_state on the event."""
        event = Mock()
        event.event_id = "12123123@test"

        state_pdu = new_fake_pdu_entry("C", "test", "mem", "x", "A", 20)

        snapshot = Mock()
        snapshot.prev_state_pdu = state_pdu
        event_id = "*****@*****.**"

        def fill_out_prev_events(event):
            event.prev_events = [event_id]
            event.depth = 6

        snapshot.fill_out_prev_events = fill_out_prev_events

        yield self.state.handle_new_event(event, snapshot)

        self.assertLess(5, event.depth)

        # `assertEquals` is a deprecated alias that was removed in Python 3.12.
        self.assertEqual(1, len(event.prev_events))

        prev_id = event.prev_events[0]

        self.assertEqual(event_id, prev_id)

        self.assertEqual(
            encode_event_id(state_pdu.pdu_id, state_pdu.origin),
            event.prev_state,
        )
Example #21
0
class StateTestCase(unittest.TestCase):
    """Tests for ``StateHandler.compute_event_context`` against a mocked store.

    Graph-based tests walk a small event graph and feed each computed context
    into a ``StateGroupStore``; the remaining tests hand state to the handler
    either via ``old_state`` or via the mocked ``get_state_groups``.
    """

    def setUp(self):
        """Create a ``StateHandler`` backed by a mocked homeserver and datastore."""
        self.store = Mock(spec_set=["get_state_groups", "add_event_hashes"])
        hs = Mock(
            spec=[
                "get_datastore",
                "get_auth",
                "get_state_handler",
                "get_clock",
            ]
        )
        hs.get_datastore.return_value = self.store
        hs.get_state_handler.return_value = None
        hs.get_auth.return_value = Auth(hs)
        hs.get_clock.return_value = MockClock()

        self.state = StateHandler(hs)
        self.event_id = 0

    def _assert_state_keys_match(self, current_state):
        """Assert every ``(type, state_key)`` key matches the event stored under it."""
        for (event_type, state_key), state_event in current_state.items():
            self.assertEqual(event_type, state_event.type)
            self.assertEqual(state_key, state_event.state_key)

    @defer.inlineCallbacks
    def test_branch_no_conflict(self):
        """Two branches touching different state merge into two state entries."""
        graph = Graph(
            nodes={
                "START": DictObj(type=EventTypes.Create, state_key="", depth=1),
                "A": DictObj(type=EventTypes.Message, depth=2),
                "B": DictObj(type=EventTypes.Message, depth=3),
                "C": DictObj(type=EventTypes.Name, state_key="", depth=3),
                "D": DictObj(type=EventTypes.Message, depth=4),
            },
            edges={"A": ["START"], "B": ["A"], "C": ["A"], "D": ["B", "C"]},
        )

        store = StateGroupStore()
        self.store.get_state_groups.side_effect = store.get_state_groups

        context_store = {}

        for event in graph.walk():
            context = yield self.state.compute_event_context(event)
            store.store_state_groups(event, context)
            context_store[event.event_id] = context

        self.assertEqual(2, len(context_store["D"].current_state))

    @defer.inlineCallbacks
    def test_branch_basic_conflict(self):
        """Two branches each set the room name; the state at D must contain C."""
        graph = Graph(
            nodes={
                "START": DictObj(
                    type=EventTypes.Create,
                    state_key="",
                    content={"creator": "@user_id:example.com"},
                    depth=1,
                ),
                "A": DictObj(
                    type=EventTypes.Member,
                    state_key="@user_id:example.com",
                    content={"membership": Membership.JOIN},
                    membership=Membership.JOIN,
                    depth=2,
                ),
                "B": DictObj(type=EventTypes.Name, state_key="", depth=3),
                "C": DictObj(type=EventTypes.Name, state_key="", depth=4),
                "D": DictObj(type=EventTypes.Message, depth=5),
            },
            edges={"A": ["START"], "B": ["A"], "C": ["A"], "D": ["B", "C"]},
        )

        store = StateGroupStore()
        self.store.get_state_groups.side_effect = store.get_state_groups

        context_store = {}

        for event in graph.walk():
            context = yield self.state.compute_event_context(event)
            store.store_state_groups(event, context)
            context_store[event.event_id] = context

        self.assertSetEqual(
            {"START", "A", "C"},
            {e.event_id for e in context_store["D"].current_state.values()},
        )

    @defer.inlineCallbacks
    def test_branch_have_banned_conflict(self):
        """A name change sent by a user banned on the other branch is dropped."""
        graph = Graph(
            nodes={
                "START": DictObj(
                    type=EventTypes.Create,
                    state_key="",
                    content={"creator": "@user_id:example.com"},
                    depth=1,
                ),
                "A": DictObj(
                    type=EventTypes.Member,
                    state_key="@user_id:example.com",
                    content={"membership": Membership.JOIN},
                    membership=Membership.JOIN,
                    depth=2,
                ),
                "B": DictObj(type=EventTypes.Name, state_key="", depth=3),
                "C": DictObj(
                    type=EventTypes.Member,
                    state_key="@user_id_2:example.com",
                    content={"membership": Membership.BAN},
                    membership=Membership.BAN,
                    depth=4,
                ),
                "D": DictObj(
                    type=EventTypes.Name,
                    state_key="",
                    depth=4,
                    sender="@user_id_2:example.com",
                ),
                "E": DictObj(type=EventTypes.Message, depth=5),
            },
            edges={
                "A": ["START"],
                "B": ["A"],
                "C": ["B"],
                "D": ["B"],
                "E": ["C", "D"],
            },
        )

        store = StateGroupStore()
        self.store.get_state_groups.side_effect = store.get_state_groups

        context_store = {}

        for event in graph.walk():
            context = yield self.state.compute_event_context(event)
            store.store_state_groups(event, context)
            context_store[event.event_id] = context

        self.assertSetEqual(
            {"START", "A", "B", "C"},
            {e.event_id for e in context_store["E"].current_state.values()},
        )

    @defer.inlineCallbacks
    def test_branch_have_perms_conflict(self):
        """A power-level change on one branch beats a name change on the other.

        The expected state at D keeps the earlier name event A5 and the new
        power-levels event B, and omits the name event C sent by ``userid2``.
        """
        userid1 = "@user_id:example.com"
        userid2 = "@user_id2:example.com"

        nodes = {
            "A1": DictObj(
                type=EventTypes.Create,
                state_key="",
                content={"creator": userid1},
                depth=1,
            ),
            "A2": DictObj(
                type=EventTypes.Member,
                state_key=userid1,
                content={"membership": Membership.JOIN},
                membership=Membership.JOIN,
            ),
            "A3": DictObj(
                type=EventTypes.Member,
                state_key=userid2,
                content={"membership": Membership.JOIN},
                membership=Membership.JOIN,
            ),
            "A4": DictObj(
                type=EventTypes.PowerLevels,
                state_key="",
                content={
                    "events": {"m.room.name": 50},
                    "users": {userid1: 100, userid2: 60},
                },
            ),
            "A5": DictObj(type=EventTypes.Name, state_key=""),
            "B": DictObj(
                type=EventTypes.PowerLevels,
                state_key="",
                content={"events": {"m.room.name": 50}, "users": {userid2: 30}},
            ),
            "C": DictObj(type=EventTypes.Name, state_key="", sender=userid2),
            "D": DictObj(type=EventTypes.Message),
        }
        edges = {
            "A2": ["A1"],
            "A3": ["A2"],
            "A4": ["A3"],
            "A5": ["A4"],
            "B": ["A5"],
            "C": ["A5"],
            "D": ["B", "C"],
        }
        self._add_depths(nodes, edges)
        graph = Graph(nodes, edges)

        store = StateGroupStore()
        self.store.get_state_groups.side_effect = store.get_state_groups

        context_store = {}

        for event in graph.walk():
            context = yield self.state.compute_event_context(event)
            store.store_state_groups(event, context)
            context_store[event.event_id] = context

        self.assertSetEqual(
            {"A1", "A2", "A3", "A5", "B"},
            {e.event_id for e in context_store["D"].current_state.values()},
        )

    def _add_depths(self, nodes, edges):
        """Give every node lacking a ``depth`` one more than its deepest predecessor."""

        def _get_depth(ev):
            node = nodes[ev]
            if "depth" not in node:
                prevs = edges[ev]
                depth = max(_get_depth(prev) for prev in prevs) + 1
                node["depth"] = depth
            return node["depth"]

        for n in nodes:
            _get_depth(n)

    @defer.inlineCallbacks
    def test_annotate_with_old_message(self):
        """Explicit ``old_state`` becomes the current state of a non-state event."""
        event = create_event(type="test_message", name="event")

        old_state = [
            create_event(type="test1", state_key="1"),
            create_event(type="test1", state_key="2"),
            create_event(type="test2", state_key=""),
        ]

        context = yield self.state.compute_event_context(event, old_state=old_state)

        self._assert_state_keys_match(context.current_state)

        self.assertEqual(set(old_state), set(context.current_state.values()))

        self.assertIsNone(context.state_group)

    @defer.inlineCallbacks
    def test_annotate_with_old_state(self):
        """Explicit ``old_state`` becomes the current state of a state event."""
        event = create_event(type="state", state_key="", name="event")

        old_state = [
            create_event(type="test1", state_key="1"),
            create_event(type="test1", state_key="2"),
            create_event(type="test2", state_key=""),
        ]

        context = yield self.state.compute_event_context(event, old_state=old_state)

        self._assert_state_keys_match(context.current_state)

        self.assertEqual(set(old_state), set(context.current_state.values()))

        self.assertIsNone(context.state_group)

    @defer.inlineCallbacks
    def test_trivial_annotate_message(self):
        """A single stored state group is reused unchanged for a non-state event."""
        event = create_event(type="test_message", name="event")

        old_state = [
            create_event(type="test1", state_key="1"),
            create_event(type="test1", state_key="2"),
            create_event(type="test2", state_key=""),
        ]

        group_name = "group_name_1"

        self.store.get_state_groups.return_value = {
            group_name: old_state,
        }

        context = yield self.state.compute_event_context(event)

        self._assert_state_keys_match(context.current_state)

        self.assertEqual(
            {e.event_id for e in old_state},
            {e.event_id for e in context.current_state.values()},
        )

        self.assertEqual(group_name, context.state_group)

    @defer.inlineCallbacks
    def test_trivial_annotate_state(self):
        """A state event on top of a single stored group gets no state group."""
        event = create_event(type="state", state_key="", name="event")

        old_state = [
            create_event(type="test1", state_key="1"),
            create_event(type="test1", state_key="2"),
            create_event(type="test2", state_key=""),
        ]

        group_name = "group_name_1"

        self.store.get_state_groups.return_value = {
            group_name: old_state,
        }

        context = yield self.state.compute_event_context(event)

        self._assert_state_keys_match(context.current_state)

        self.assertEqual(
            {e.event_id for e in old_state},
            {e.event_id for e in context.current_state.values()},
        )

        self.assertIsNone(context.state_group)

    @defer.inlineCallbacks
    def test_resolve_message_conflict(self):
        """Two differing groups resolve to six state entries for a non-state event."""
        event = create_event(type="test_message", name="event")

        creation = create_event(type=EventTypes.Create, state_key="")

        old_state_1 = [
            creation,
            create_event(type="test1", state_key="1"),
            create_event(type="test1", state_key="2"),
            create_event(type="test2", state_key=""),
        ]

        old_state_2 = [
            creation,
            create_event(type="test1", state_key="1"),
            create_event(type="test3", state_key="2"),
            create_event(type="test4", state_key=""),
        ]

        context = yield self._get_context(event, old_state_1, old_state_2)

        self.assertEqual(len(context.current_state), 6)

        self.assertIsNone(context.state_group)

    @defer.inlineCallbacks
    def test_resolve_state_conflict(self):
        """Two differing groups resolve to six state entries for a state event."""
        event = create_event(type="test4", state_key="", name="event")

        creation = create_event(type=EventTypes.Create, state_key="")

        old_state_1 = [
            creation,
            create_event(type="test1", state_key="1"),
            create_event(type="test1", state_key="2"),
            create_event(type="test2", state_key=""),
        ]

        old_state_2 = [
            creation,
            create_event(type="test1", state_key="1"),
            create_event(type="test3", state_key="2"),
            create_event(type="test4", state_key=""),
        ]

        context = yield self._get_context(event, old_state_1, old_state_2)

        self.assertEqual(len(context.current_state), 6)

        self.assertIsNone(context.state_group)

    @defer.inlineCallbacks
    def test_standard_depth_conflict(self):
        """The deeper of two conflicting ``test1`` events wins, in either group."""
        event = create_event(type="test4", name="event")

        member_event = create_event(
            type=EventTypes.Member,
            state_key="@user_id:example.com",
            content={
                "membership": Membership.JOIN,
            },
        )

        creation = create_event(
            type=EventTypes.Create, state_key="", content={"creator": "@foo:bar"}
        )

        old_state_1 = [
            creation,
            member_event,
            create_event(type="test1", state_key="1", depth=1),
        ]

        old_state_2 = [
            creation,
            member_event,
            create_event(type="test1", state_key="1", depth=2),
        ]

        context = yield self._get_context(event, old_state_1, old_state_2)

        self.assertEqual(old_state_2[2], context.current_state[("test1", "1")])

        # Reverse the depth to make sure we are actually using the depths
        # during state resolution.

        old_state_1 = [
            creation,
            member_event,
            create_event(type="test1", state_key="1", depth=2),
        ]

        old_state_2 = [
            creation,
            member_event,
            create_event(type="test1", state_key="1", depth=1),
        ]

        context = yield self._get_context(event, old_state_1, old_state_2)

        self.assertEqual(old_state_1[2], context.current_state[("test1", "1")])

    def _get_context(self, event, old_state_1, old_state_2):
        """Serve both states as two state groups and compute the event's context."""
        group_name_1 = "group_name_1"
        group_name_2 = "group_name_2"

        self.store.get_state_groups.return_value = {
            group_name_1: old_state_1,
            group_name_2: old_state_2,
        }

        return self.state.compute_event_context(event)
Example #22
0
class StateTestCase(unittest.TestCase):
    """Tests for ``StateHandler.compute_event_context`` using mocked events.

    Events are ``Mock`` objects built by ``create_event``.  The datastore is
    a ``Mock`` whose ``get_state_groups`` return value is set by the tests
    that do not pass ``old_state`` explicitly.
    """

    def setUp(self):
        """Build a ``StateHandler`` around a mocked homeserver and store."""
        # spec_set limits the store mock to exactly these attributes.
        self.store = Mock(
            spec_set=[
                "get_state_groups",
                "add_event_hashes",
            ]
        )
        hs = Mock(spec=["get_datastore"])
        hs.get_datastore.return_value = self.store

        self.state = StateHandler(hs)
        # Counter used by create_event to hand out distinct event ids.
        self.event_id = 0

    def _assert_state_keys_match(self, context):
        """Assert each state key equals its event's (type, state_key)."""
        current_state = context.current_state
        for (event_type, state_key), state_event in current_state.items():
            self.assertEqual(event_type, state_event.type)
            self.assertEqual(state_key, state_event.state_key)

    @defer.inlineCallbacks
    def test_annotate_with_old_message(self):
        """A message event given explicit ``old_state`` gets exactly that state.

        Also asserts that the resulting context has no state group.
        """
        event = self.create_event(type="test_message", name="event")

        old_state = [
            self.create_event(type="test1", state_key="1"),
            self.create_event(type="test1", state_key="2"),
            self.create_event(type="test2", state_key=""),
        ]

        context = yield self.state.compute_event_context(
            event, old_state=old_state
        )

        self._assert_state_keys_match(context)
        self.assertEqual(
            set(old_state), set(context.current_state.values())
        )
        self.assertIsNone(context.state_group)

    @defer.inlineCallbacks
    def test_annotate_with_old_state(self):
        """A state event given explicit ``old_state`` gets exactly that state.

        Also asserts that the resulting context has no state group.
        """
        event = self.create_event(type="state", state_key="", name="event")

        old_state = [
            self.create_event(type="test1", state_key="1"),
            self.create_event(type="test1", state_key="2"),
            self.create_event(type="test2", state_key=""),
        ]

        context = yield self.state.compute_event_context(
            event, old_state=old_state
        )

        self._assert_state_keys_match(context)
        self.assertEqual(
            set(old_state), set(context.current_state.values())
        )
        self.assertIsNone(context.state_group)

    @defer.inlineCallbacks
    def test_trivial_annotate_message(self):
        """A message event with one stored state group reuses that group.

        Asserts the context holds the group's events (compared by event id)
        and that ``context.state_group`` is the group's name.
        """
        event = self.create_event(type="test_message", name="event")
        event.prev_events = []

        old_state = [
            self.create_event(type="test1", state_key="1"),
            self.create_event(type="test1", state_key="2"),
            self.create_event(type="test2", state_key=""),
        ]

        group_name = "group_name_1"

        self.store.get_state_groups.return_value = {
            group_name: old_state,
        }

        context = yield self.state.compute_event_context(event)

        self._assert_state_keys_match(context)
        self.assertEqual(
            {e.event_id for e in old_state},
            {e.event_id for e in context.current_state.values()},
        )
        self.assertEqual(group_name, context.state_group)

    @defer.inlineCallbacks
    def test_trivial_annotate_state(self):
        """A state event with one stored state group gets no state group.

        Asserts the context holds the group's events (compared by event id)
        and that ``context.state_group`` is ``None``.
        """
        event = self.create_event(type="state", state_key="", name="event")
        event.prev_events = []

        old_state = [
            self.create_event(type="test1", state_key="1"),
            self.create_event(type="test1", state_key="2"),
            self.create_event(type="test2", state_key=""),
        ]

        group_name = "group_name_1"

        self.store.get_state_groups.return_value = {
            group_name: old_state,
        }

        context = yield self.state.compute_event_context(event)

        self._assert_state_keys_match(context)
        self.assertEqual(
            {e.event_id for e in old_state},
            {e.event_id for e in context.current_state.values()},
        )
        self.assertIsNone(context.state_group)

    @defer.inlineCallbacks
    def test_resolve_message_conflict(self):
        """A message event over two differing state groups.

        Asserts the context ends up with five state entries and no state
        group.
        """
        event = self.create_event(type="test_message", name="event")
        event.prev_events = []

        old_state_1 = [
            self.create_event(type="test1", state_key="1"),
            self.create_event(type="test1", state_key="2"),
            self.create_event(type="test2", state_key=""),
        ]

        old_state_2 = [
            self.create_event(type="test1", state_key="1"),
            self.create_event(type="test3", state_key="2"),
            self.create_event(type="test4", state_key=""),
        ]

        group_name_1 = "group_name_1"
        group_name_2 = "group_name_2"

        self.store.get_state_groups.return_value = {
            group_name_1: old_state_1,
            group_name_2: old_state_2,
        }

        context = yield self.state.compute_event_context(event)

        self.assertEqual(len(context.current_state), 5)
        self.assertIsNone(context.state_group)

    @defer.inlineCallbacks
    def test_resolve_state_conflict(self):
        """A state event over two differing state groups.

        Asserts the context ends up with five state entries and no state
        group.
        """
        event = self.create_event(type="test4", state_key="", name="event")
        event.prev_events = []

        old_state_1 = [
            self.create_event(type="test1", state_key="1"),
            self.create_event(type="test1", state_key="2"),
            self.create_event(type="test2", state_key=""),
        ]

        old_state_2 = [
            self.create_event(type="test1", state_key="1"),
            self.create_event(type="test3", state_key="2"),
            self.create_event(type="test4", state_key=""),
        ]

        group_name_1 = "group_name_1"
        group_name_2 = "group_name_2"

        self.store.get_state_groups.return_value = {
            group_name_1: old_state_1,
            group_name_2: old_state_2,
        }

        context = yield self.state.compute_event_context(event)

        self.assertEqual(len(context.current_state), 5)
        self.assertIsNone(context.state_group)

    def create_event(self, name=None, type=None, state_key=None):
        """Return a ``Mock`` event carrying a fresh ``event_id``.

        Args:
            name: Name given to the mock.  When falsy, one is derived from
                ``type`` and, if it is not ``None``, ``state_key``.
            type: Value stored on ``event.type``.  The parameter shadows the
                builtin, but callers pass it by keyword so the name is kept.
            state_key: When not ``None``, stored on ``event.state_key`` and
                makes ``event.is_state()`` return ``True``.

        Returns:
            The configured ``Mock``.  It is created with ``spec=[]``, so
            reading an attribute that was not assigned here raises
            ``AttributeError``.
        """
        self.event_id += 1
        has_state_key = state_key is not None

        if not name:
            if has_state_key:
                name = "<%s-%s>" % (type, state_key)
            else:
                name = "<%s>" % (type, )

        event = Mock(name=name, spec=[])
        event.type = type

        # Non-state events deliberately get no state_key attribute at all.
        if has_state_key:
            event.state_key = state_key
        event.event_id = str(self.event_id)

        event.is_state = lambda: has_state_key
        event.unsigned = {}

        event.user_id = "@user_id:example.com"
        event.room_id = "!room_id:example.com"

        return event