Example #1
    def __init__(self, database: DatabasePool, db_conn, hs):
        super().__init__(database, db_conn, hs)

        # Originally the state store used a single DictionaryCache to cache the
        # event IDs for the state types in a given state group to avoid hammering
        # on the state_group* tables.
        #
        # The point of using a DictionaryCache is that it can cache a subset
        # of the state events for a given state group (i.e. a subset of the keys for a
        # given dict which is an entry in the cache for a given state group ID).
        #
        # However, this poses problems when performing complicated queries
        # on the store - for instance: "give me all the state for this group, but
        # limit members to this subset of users", as DictionaryCache's API isn't
        # rich enough to say "please cache any of these fields, apart from this subset".
        # This is problematic when lazy loading members, which requires this behaviour,
        # as without it the cache has no choice but to speculatively load all
        # state events for the group, which negates the efficiency being sought.
        #
        # Rather than overcomplicating DictionaryCache's API, we instead split the
        # state_group_cache into two halves - one for tracking non-member events,
        # and the other for tracking member_events.  This means that lazy loading
        # queries can be made in a cache-friendly manner by querying both caches
        # separately and then merging the result.  So for the example above, you
        # would query the members cache for a specific subset of state keys
        # (which DictionaryCache will handle efficiently and fine) and the non-members
        # cache for all state (which DictionaryCache will similarly handle fine)
        # and then just merge the results together.
        #
        # We size the non-members cache to be smaller than the members cache as the
        # vast majority of state in Matrix (today) is member events.
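        #
        # As a rough sketch (hypothetical names, not executed here), a lazy
        # loading lookup would do something like:
        #
        #   members = self._state_group_members_cache.get(
        #       group, [(EventTypes.Member, user_id)]
        #   )
        #   non_members = self._state_group_cache.get(group)
        #   state = {**non_members.value, **members.value}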

        self._state_group_cache = DictionaryCache(
            "*stateGroupCache*",
            # TODO: this hasn't been tuned yet
            50000,
        )
        self._state_group_members_cache = DictionaryCache(
            "*stateGroupMembersCache*",
            500000,
        )

        def get_max_state_group_txn(txn: Cursor):
            txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
            return txn.fetchone()[0]

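        # Build the generator used to allocate new state group IDs. The
        # `get_max_state_group_txn` callback above is presumably the fallback
        # used to seed the generator from the current maximum id (e.g. on
        # SQLite); on Postgres the `state_group_id_seq` sequence named below
        # is expected to back it.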
        self._state_group_seq_gen = build_sequence_generator(
            db_conn,
            self.database_engine,
            get_max_state_group_txn,
            "state_group_id_seq",
            table="state_groups",
            id_column="id",
        )
Example #2
    def __init__(self, hs):
        self.hs = hs
        self._db_pool = hs.get_db_pool()
        self._clock = hs.get_clock()

        self._previous_txn_total_time = 0
        self._current_txn_total_time = 0
        self._previous_loop_ts = 0

        # TODO(paul): These can eventually be removed once the metrics code
        #   is running in mainline, and we have some nice monitoring frontends
        #   to watch it
        self._txn_perf_counters = PerformanceCounters()
        self._get_event_counters = PerformanceCounters()

        self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True,
                                      max_entries=hs.config.event_cache_size)

        self._state_group_cache = DictionaryCache(
            "*stateGroupCache*", 2000 * CACHE_SIZE_FACTOR
        )
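        # Note: CACHE_SIZE_FACTOR is assumed to be a global scaling factor
        # (presumably derived from the SYNAPSE_CACHE_FACTOR environment
        # variable) applied to the base size of 2000 entries.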

        self._event_fetch_lock = threading.Condition()
        self._event_fetch_list = []
        self._event_fetch_ongoing = 0

        self._pending_ds = []

        self.database_engine = hs.database_engine
Example #3
    def __init__(self, hs):
        self.hs = hs
        self._db_pool = hs.get_db_pool()
        self._clock = hs.get_clock()

        self._previous_txn_total_time = 0
        self._current_txn_total_time = 0
        self._previous_loop_ts = 0

        # TODO(paul): These can eventually be removed once the metrics code
        #   is running in mainline, and we have some nice monitoring frontends
        #   to watch it
        self._txn_perf_counters = PerformanceCounters()
        self._get_event_counters = PerformanceCounters()

        self._get_event_cache = Cache("*getEvent*",
                                      keylen=3,
                                      lru=True,
                                      max_entries=hs.config.event_cache_size)

        self._state_group_cache = DictionaryCache("*stateGroupCache*", 2000)

        self._event_fetch_lock = threading.Condition()
        self._event_fetch_list = []
        self._event_fetch_ongoing = 0

        self._pending_ds = []

        self.database_engine = hs.database_engine

        self._stream_id_gen = StreamIdGenerator("events", "stream_ordering")
        self._transaction_id_gen = IdGenerator("sent_transactions", "id", self)
        self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
        self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
        self._pushers_id_gen = IdGenerator("pushers", "id", self)
        self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
        self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id",
                                                     self)
        self._receipts_id_gen = StreamIdGenerator("receipts_linearized",
                                                  "stream_id")
Example #4
    def setUp(self):
        self.cache = DictionaryCache("foobar")
Example #5
class DictCacheTestCase(unittest.TestCase):

    def setUp(self):
        self.cache = DictionaryCache("foobar")

    def test_simple_cache_hit_full(self):
        key = "test_simple_cache_hit_full"

        v = self.cache.get(key)
        self.assertEqual((False, {}), v)
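        # In this version of DictionaryCache a miss is reported as an entry
        # comparing equal to (False, {}), i.e. "not full" with an empty value
        # dict, which is what the assertion above checks.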

        seq = self.cache.sequence
        test_value = {"test": "test_simple_cache_hit_full"}
        self.cache.update(seq, key, test_value, full=True)

        c = self.cache.get(key)
        self.assertEqual(test_value, c.value)

    def test_simple_cache_hit_partial(self):
        key = "test_simple_cache_hit_partial"

        seq = self.cache.sequence
        test_value = {
            "test": "test_simple_cache_hit_partial"
        }
        self.cache.update(seq, key, test_value, full=True)

        c = self.cache.get(key, ["test"])
        self.assertEqual(test_value, c.value)

    def test_simple_cache_miss_partial(self):
        key = "test_simple_cache_miss_partial"

        seq = self.cache.sequence
        test_value = {
            "test": "test_simple_cache_miss_partial"
        }
        self.cache.update(seq, key, test_value, full=True)

        c = self.cache.get(key, ["test2"])
        self.assertEqual({}, c.value)

    def test_simple_cache_hit_miss_partial(self):
        key = "test_simple_cache_hit_miss_partial"

        seq = self.cache.sequence
        test_value = {
            "test": "test_simple_cache_hit_miss_partial",
            "test2": "test_simple_cache_hit_miss_partial2",
            "test3": "test_simple_cache_hit_miss_partial3",
        }
        self.cache.update(seq, key, test_value, full=True)

        c = self.cache.get(key, ["test2"])
        self.assertEqual({"test2": "test_simple_cache_hit_miss_partial2"}, c.value)

    def test_multi_insert(self):
        key = "test_simple_cache_hit_miss_partial"

        seq = self.cache.sequence
        test_value_1 = {
            "test": "test_simple_cache_hit_miss_partial",
        }
        self.cache.update(seq, key, test_value_1, full=False)

        seq = self.cache.sequence
        test_value_2 = {
            "test2": "test_simple_cache_hit_miss_partial2",
        }
        self.cache.update(seq, key, test_value_2, full=False)
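        # The two partial (full=False) updates above should be merged into a
        # single cached entry containing both "test" and "test2".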

        c = self.cache.get(key)
        self.assertEqual(
            {
                "test": "test_simple_cache_hit_miss_partial",
                "test2": "test_simple_cache_hit_miss_partial2",
            },
            c.value
        )
Example #6
class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
    """A data store for fetching/storing state groups.
    """
    def __init__(self, database: Database, db_conn, hs):
        super(StateGroupDataStore, self).__init__(database, db_conn, hs)

        # Originally the state store used a single DictionaryCache to cache the
        # event IDs for the state types in a given state group to avoid hammering
        # on the state_group* tables.
        #
        # The point of using a DictionaryCache is that it can cache a subset
        # of the state events for a given state group (i.e. a subset of the keys for a
        # given dict which is an entry in the cache for a given state group ID).
        #
        # However, this poses problems when performing complicated queries
        # on the store - for instance: "give me all the state for this group, but
        # limit members to this subset of users", as DictionaryCache's API isn't
        # rich enough to say "please cache any of these fields, apart from this subset".
        # This is problematic when lazy loading members, which requires this behaviour,
        # as without it the cache has no choice but to speculatively load all
        # state events for the group, which negates the efficiency being sought.
        #
        # Rather than overcomplicating DictionaryCache's API, we instead split the
        # state_group_cache into two halves - one for tracking non-member events,
        # and the other for tracking member_events.  This means that lazy loading
        # queries can be made in a cache-friendly manner by querying both caches
        # separately and then merging the result.  So for the example above, you
        # would query the members cache for a specific subset of state keys
        # (which DictionaryCache will handle efficiently and fine) and the non-members
        # cache for all state (which DictionaryCache will similarly handle fine)
        # and then just merge the results together.
        #
        # We size the non-members cache to be smaller than the members cache as the
        # vast majority of state in Matrix (today) is member events.

        self._state_group_cache = DictionaryCache(
            "*stateGroupCache*",
            # TODO: this hasn't been tuned yet
            50000,
        )
        self._state_group_members_cache = DictionaryCache(
            "*stateGroupMembersCache*",
            500000,
        )

    @cached(max_entries=10000, iterable=True)
    def get_state_group_delta(self, state_group):
        """Given a state group try to return a previous group and a delta between
        the old and the new.

        Returns:
            (prev_group, delta_ids), where both may be None.
        """
        def _get_state_group_delta_txn(txn):
            prev_group = self.db.simple_select_one_onecol_txn(
                txn,
                table="state_group_edges",
                keyvalues={"state_group": state_group},
                retcol="prev_state_group",
                allow_none=True,
            )

            if not prev_group:
                return _GetStateGroupDelta(None, None)

            delta_ids = self.db.simple_select_list_txn(
                txn,
                table="state_groups_state",
                keyvalues={"state_group": state_group},
                retcols=("type", "state_key", "event_id"),
            )

            return _GetStateGroupDelta(
                prev_group,
                {(row["type"], row["state_key"]): row["event_id"]
                 for row in delta_ids},
            )

        return self.db.runInteraction("get_state_group_delta",
                                      _get_state_group_delta_txn)

    @defer.inlineCallbacks
    def _get_state_groups_from_groups(self, groups: List[int],
                                      state_filter: StateFilter):
        """Returns the state groups for a given set of groups from the
        database, filtering on types of state events.

        Args:
            groups: list of state group IDs to query
            state_filter: The state filter used to fetch state
                from the database.
        Returns:
            Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map.
        """
        results = {}

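        # Process the groups in chunks of 100 so that each database
        # interaction (and the query it issues) stays reasonably small.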
        chunks = [groups[i:i + 100] for i in range(0, len(groups), 100)]
        for chunk in chunks:
            res = yield self.db.runInteraction(
                "_get_state_groups_from_groups",
                self._get_state_groups_from_groups_txn,
                chunk,
                state_filter,
            )
            results.update(res)

        return results

    def _get_state_for_group_using_cache(self, cache, group, state_filter):
        """Checks if group is in cache. See `_get_state_for_groups`

        Args:
            cache(DictionaryCache): the state group cache to use
            group(int): The state group to lookup
            state_filter (StateFilter): The state filter used to fetch state
                from the database.

        Returns 2-tuple (`state_dict`, `got_all`).
        `got_all` is a bool indicating if we successfully retrieved all
        requested state from the cache; if False we need to query the DB for the
        missing state.
        """
        is_all, known_absent, state_dict_ids = cache.get(group)

        if is_all or state_filter.is_full():
            # Either we have everything or want everything, either way
            # `is_all` tells us whether we've gotten everything.
            return state_filter.filter_state(state_dict_ids), is_all

        # tracks whether any of our requested types are missing from the cache
        missing_types = False

        if state_filter.has_wildcards():
            # We don't know if we fetched all the state keys for the types in
            # the filter that are wildcards, so we have to assume that we may
            # have missed some.
            missing_types = True
        else:
            # There aren't any wild cards, so `concrete_types()` returns the
            # complete list of event types we want.
            for key in state_filter.concrete_types():
                if key not in state_dict_ids and key not in known_absent:
                    missing_types = True
                    break

        return state_filter.filter_state(state_dict_ids), not missing_types

    @defer.inlineCallbacks
    def _get_state_for_groups(self,
                              groups: Iterable[int],
                              state_filter: StateFilter = StateFilter.all()):
        """Gets the state at each of a list of state groups, optionally
        filtering by type/state_key

        Args:
            groups: list of state groups for which we want
                to get the state.
            state_filter: The state filter used to fetch state
                from the database.
        Returns:
            Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map.
        """

        member_filter, non_member_filter = state_filter.get_member_split()

        # Now we look them up in the member and non-member caches
        (
            non_member_state,
            incomplete_groups_nm,
        ) = yield self._get_state_for_groups_using_cache(
            groups, self._state_group_cache, state_filter=non_member_filter)

        (
            member_state,
            incomplete_groups_m,
        ) = yield self._get_state_for_groups_using_cache(
            groups,
            self._state_group_members_cache,
            state_filter=member_filter)

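        # Merge the member and non-member results. The cache helper returns an
        # entry for every requested group (possibly empty), so both dicts are
        # keyed by the full set of `groups`.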
        state = dict(non_member_state)
        for group in groups:
            state[group].update(member_state[group])

        # Now fetch any missing groups from the database

        incomplete_groups = incomplete_groups_m | incomplete_groups_nm

        if not incomplete_groups:
            return state

        cache_sequence_nm = self._state_group_cache.sequence
        cache_sequence_m = self._state_group_members_cache.sequence

        # Help the cache hit ratio by expanding the filter a bit
        db_state_filter = state_filter.return_expanded()

        group_to_state_dict = yield self._get_state_groups_from_groups(
            list(incomplete_groups), state_filter=db_state_filter)

        # Now let's update the caches
        self._insert_into_cache(
            group_to_state_dict,
            db_state_filter,
            cache_seq_num_members=cache_sequence_m,
            cache_seq_num_non_members=cache_sequence_nm,
        )

        # And finally update the result dict, by filtering out any extra
        # stuff we pulled out of the database.
        for group, group_state_dict in iteritems(group_to_state_dict):
            # We just replace any existing entries, as we will have loaded
            # everything we need from the database anyway.
            state[group] = state_filter.filter_state(group_state_dict)

        return state

    def _get_state_for_groups_using_cache(
        self, groups: Iterable[int], cache: DictionaryCache,
        state_filter: StateFilter
    ) -> Tuple[Dict[int, StateMap[str]], Set[int]]:
        """Gets the state at each of a list of state groups, optionally
        filtering by type/state_key, querying from a specific cache.

        Args:
            groups: list of state groups for which we want to get the state.
            cache: the cache of group ids to state dicts which
                we will pass through - either the normal state cache or the
                specific members state cache.
            state_filter: The state filter used to fetch state from the
                database.

        Returns:
            Tuple of dict of state_group_id to state map of entries in the
            cache, and the state group ids either missing from the cache or
            incomplete.
        """
        results = {}
        incomplete_groups = set()
        for group in set(groups):
            state_dict_ids, got_all = self._get_state_for_group_using_cache(
                cache, group, state_filter)
            results[group] = state_dict_ids

            if not got_all:
                incomplete_groups.add(group)

        return results, incomplete_groups

    def _insert_into_cache(
        self,
        group_to_state_dict,
        state_filter,
        cache_seq_num_members,
        cache_seq_num_non_members,
    ):
        """Inserts results from querying the database into the relevant cache.

        Args:
            group_to_state_dict (dict): The new entries pulled from database.
                Map from state group to state dict
            state_filter (StateFilter): The state filter used to fetch state
                from the database.
            cache_seq_num_members (int): Sequence number of member cache since
                last lookup in cache
            cache_seq_num_non_members (int): Sequence number of non-member cache since
                last lookup in cache
        """

        # We need to work out which types we've fetched from the DB for the
        # member vs non-member caches. This should be as accurate as possible,
        # but can be an underestimate (e.g. when we have wild cards)

        member_filter, non_member_filter = state_filter.get_member_split()
        if member_filter.is_full():
            # We fetched all member events
            member_types = None
        else:
            # `concrete_types()` will only return a subset when there are wild
            # cards in the filter, but that's fine.
            member_types = member_filter.concrete_types()

        if non_member_filter.is_full():
            # We fetched all non member events
            non_member_types = None
        else:
            non_member_types = non_member_filter.concrete_types()

        for group, group_state_dict in iteritems(group_to_state_dict):
            state_dict_members = {}
            state_dict_non_members = {}

            for k, v in iteritems(group_state_dict):
                if k[0] == EventTypes.Member:
                    state_dict_members[k] = v
                else:
                    state_dict_non_members[k] = v

            self._state_group_members_cache.update(
                cache_seq_num_members,
                key=group,
                value=state_dict_members,
                fetched_keys=member_types,
            )

            self._state_group_cache.update(
                cache_seq_num_non_members,
                key=group,
                value=state_dict_non_members,
                fetched_keys=non_member_types,
            )

    def store_state_group(self, event_id, room_id, prev_group, delta_ids,
                          current_state_ids):
        """Store a new set of state, returning a newly assigned state group.

        Args:
            event_id (str): The event ID for which the state was calculated
            room_id (str)
            prev_group (int|None): A previous state group for the room, optional.
            delta_ids (dict|None): The delta between state at `prev_group` and
                `current_state_ids`, if `prev_group` was given. Same format as
                `current_state_ids`.
            current_state_ids (dict): The state to store. Map of (type, state_key)
                to event_id.

        Returns:
            Deferred[int]: The state group ID
        """
        def _store_state_group_txn(txn):
            if current_state_ids is None:
                # AFAIK, this can never happen
                raise Exception("current_state_ids cannot be None")

            state_group = self.database_engine.get_next_state_group_id(txn)

            self.db.simple_insert_txn(
                txn,
                table="state_groups",
                values={
                    "id": state_group,
                    "room_id": room_id,
                    "event_id": event_id
                },
            )

            # We persist as a delta if we can, while also ensuring the chain
            # of deltas isn't too long, as otherwise read performance degrades.
            if prev_group:
                is_in_db = self.db.simple_select_one_onecol_txn(
                    txn,
                    table="state_groups",
                    keyvalues={"id": prev_group},
                    retcol="id",
                    allow_none=True,
                )
                if not is_in_db:
                    raise Exception(
                        "Trying to persist state with unpersisted prev_group: %r"
                        % (prev_group, ))

                potential_hops = self._count_state_group_hops_txn(
                    txn, prev_group)
            if prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
                self.db.simple_insert_txn(
                    txn,
                    table="state_group_edges",
                    values={
                        "state_group": state_group,
                        "prev_state_group": prev_group
                    },
                )

                self.db.simple_insert_many_txn(
                    txn,
                    table="state_groups_state",
                    values=[{
                        "state_group": state_group,
                        "room_id": room_id,
                        "type": key[0],
                        "state_key": key[1],
                        "event_id": state_id,
                    } for key, state_id in iteritems(delta_ids)],
                )
            else:
                self.db.simple_insert_many_txn(
                    txn,
                    table="state_groups_state",
                    values=[{
                        "state_group": state_group,
                        "room_id": room_id,
                        "type": key[0],
                        "state_key": key[1],
                        "event_id": state_id,
                    } for key, state_id in iteritems(current_state_ids)],
                )

            # Prefill the state group caches with this group.
            # It's fine to use the sequence like this as the state group map
            # is immutable. (If the map wasn't immutable then this prefill could
            # race with another update)

            current_member_state_ids = {
                s: ev
                for (s, ev) in iteritems(current_state_ids)
                if s[0] == EventTypes.Member
            }
            txn.call_after(
                self._state_group_members_cache.update,
                self._state_group_members_cache.sequence,
                key=state_group,
                value=dict(current_member_state_ids),
            )

            current_non_member_state_ids = {
                s: ev
                for (s, ev) in iteritems(current_state_ids)
                if s[0] != EventTypes.Member
            }
            txn.call_after(
                self._state_group_cache.update,
                self._state_group_cache.sequence,
                key=state_group,
                value=dict(current_non_member_state_ids),
            )

            return state_group

        return self.db.runInteraction("store_state_group",
                                      _store_state_group_txn)

    def purge_unreferenced_state_groups(
            self, room_id: str, state_groups_to_delete) -> defer.Deferred:
        """Deletes no longer referenced state groups and de-deltas any state
        groups that reference them.

        Args:
            room_id: The room the state groups belong to (must all be in the
                same room).
            state_groups_to_delete (Collection[int]): Set of all state groups
                to delete.
        """

        return self.db.runInteraction(
            "purge_unreferenced_state_groups",
            self._purge_unreferenced_state_groups,
            room_id,
            state_groups_to_delete,
        )

    def _purge_unreferenced_state_groups(self, txn, room_id,
                                         state_groups_to_delete):
        logger.info("[purge] found %i state groups to delete",
                    len(state_groups_to_delete))

        rows = self.db.simple_select_many_txn(
            txn,
            table="state_group_edges",
            column="prev_state_group",
            iterable=state_groups_to_delete,
            keyvalues={},
            retcols=("state_group", ),
        )

        remaining_state_groups = {
            row["state_group"]
            for row in rows if row["state_group"] not in state_groups_to_delete
        }

        logger.info(
            "[purge] de-delta-ing %i remaining state groups",
            len(remaining_state_groups),
        )

        # Now we turn the state groups that reference to-be-deleted state
        # groups into non-delta versions.
        for sg in remaining_state_groups:
            logger.info("[purge] de-delta-ing remaining state group %s", sg)
            curr_state = self._get_state_groups_from_groups_txn(txn, [sg])
            curr_state = curr_state[sg]

            self.db.simple_delete_txn(txn,
                                      table="state_groups_state",
                                      keyvalues={"state_group": sg})

            self.db.simple_delete_txn(txn,
                                      table="state_group_edges",
                                      keyvalues={"state_group": sg})

            self.db.simple_insert_many_txn(
                txn,
                table="state_groups_state",
                values=[{
                    "state_group": sg,
                    "room_id": room_id,
                    "type": key[0],
                    "state_key": key[1],
                    "event_id": state_id,
                } for key, state_id in iteritems(curr_state)],
            )

        logger.info("[purge] removing redundant state groups")
        txn.executemany(
            "DELETE FROM state_groups_state WHERE state_group = ?",
            ((sg, ) for sg in state_groups_to_delete),
        )
        txn.executemany(
            "DELETE FROM state_groups WHERE id = ?",
            ((sg, ) for sg in state_groups_to_delete),
        )

    @defer.inlineCallbacks
    def get_previous_state_groups(self, state_groups):
        """Fetch the previous groups of the given state groups.

        Args:
            state_groups (Iterable[int])

        Returns:
            Deferred[dict[int, int]]: mapping from state group to previous
            state group.
        """

        rows = yield self.db.simple_select_many_batch(
            table="state_group_edges",
            column="prev_state_group",
            iterable=state_groups,
            keyvalues={},
            retcols=("prev_state_group", "state_group"),
            desc="get_previous_state_groups",
        )

        return {row["state_group"]: row["prev_state_group"] for row in rows}

    def purge_room_state(self, room_id, state_groups_to_delete):
        """Deletes all record of a room from state tables

        Args:
            room_id (str):
            state_groups_to_delete (list[int]): State groups to delete
        """

        return self.db.runInteraction(
            "purge_room_state",
            self._purge_room_state_txn,
            room_id,
            state_groups_to_delete,
        )

    def _purge_room_state_txn(self, txn, room_id, state_groups_to_delete):
        # first we have to delete the state groups states
        logger.info("[purge] removing %s from state_groups_state", room_id)

        self.db.simple_delete_many_txn(
            txn,
            table="state_groups_state",
            column="state_group",
            iterable=state_groups_to_delete,
            keyvalues={},
        )

        # ... and the state group edges
        logger.info("[purge] removing %s from state_group_edges", room_id)

        self.db.simple_delete_many_txn(
            txn,
            table="state_group_edges",
            column="state_group",
            iterable=state_groups_to_delete,
            keyvalues={},
        )

        # ... and the state groups
        logger.info("[purge] removing %s from state_groups", room_id)

        self.db.simple_delete_many_txn(
            txn,
            table="state_groups",
            column="id",
            iterable=state_groups_to_delete,
            keyvalues={},
        )
Example #7
class StateGroupWorkerStore(EventsWorkerStore, StateGroupBackgroundUpdateStore,
                            SQLBaseStore):
    """The parts of StateGroupStore that can be called from workers.
    """

    STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
    STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
    CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"

    def __init__(self, db_conn, hs):
        super(StateGroupWorkerStore, self).__init__(db_conn, hs)

        # Originally the state store used a single DictionaryCache to cache the
        # event IDs for the state types in a given state group to avoid hammering
        # on the state_group* tables.
        #
        # The point of using a DictionaryCache is that it can cache a subset
        # of the state events for a given state group (i.e. a subset of the keys for a
        # given dict which is an entry in the cache for a given state group ID).
        #
        # However, this poses problems when performing complicated queries
        # on the store - for instance: "give me all the state for this group, but
        # limit members to this subset of users", as DictionaryCache's API isn't
        # rich enough to say "please cache any of these fields, apart from this subset".
        # This is problematic when lazy loading members, which requires this behaviour,
        # as without it the cache has no choice but to speculatively load all
        # state events for the group, which negates the efficiency being sought.
        #
        # Rather than overcomplicating DictionaryCache's API, we instead split the
        # state_group_cache into two halves - one for tracking non-member events,
        # and the other for tracking member_events.  This means that lazy loading
        # queries can be made in a cache-friendly manner by querying both caches
        # separately and then merging the result.  So for the example above, you
        # would query the members cache for a specific subset of state keys
        # (which DictionaryCache will handle efficiently and fine) and the non-members
        # cache for all state (which DictionaryCache will similarly handle fine)
        # and then just merge the results together.
        #
        # We size the non-members cache to be smaller than the members cache as the
        # vast majority of state in Matrix (today) is member events.

        self._state_group_cache = DictionaryCache(
            "*stateGroupCache*",
            # TODO: this hasn't been tuned yet
            50000 * get_cache_factor_for("stateGroupCache"),
        )
        self._state_group_members_cache = DictionaryCache(
            "*stateGroupMembersCache*",
            500000 * get_cache_factor_for("stateGroupMembersCache"),
        )

    @defer.inlineCallbacks
    def get_room_version(self, room_id):
        """Get the room_version of a given room

        Args:
            room_id (str)

        Returns:
            Deferred[str]

        Raises:
            NotFoundError if the room is unknown
        """
        # for now we do this by looking at the create event. We may want to cache this
        # more intelligently in future.

        # Retrieve the room's create event
        create_event = yield self.get_create_event_for_room(room_id)
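        # Rooms that predate room versioning have no "room_version" key in
        # their create event content, so they are treated as version "1".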
        return create_event.content.get("room_version", "1")

    @defer.inlineCallbacks
    def get_room_predecessor(self, room_id):
        """Get the predecessor room of an upgraded room if one exists.
        Otherwise return None.

        Args:
            room_id (str)

        Returns:
            Deferred[dict|None]: A dictionary containing the structure of the predecessor
                field from the room's create event. The structure is determined by other servers,
                but it is expected to be:
                    * room_id (str): The room ID of the predecessor room
                    * event_id (str): The ID of the tombstone event in the predecessor room

        Raises:
            NotFoundError if the room is unknown
        """
        # Retrieve the room's create event
        create_event = yield self.get_create_event_for_room(room_id)

        # Return predecessor if present
        return create_event.content.get("predecessor", None)

    @defer.inlineCallbacks
    def get_create_event_for_room(self, room_id):
        """Get the create state event for a room.

        Args:
            room_id (str)

        Returns:
            Deferred[EventBase]: The room creation event.

        Raises:
            NotFoundError if the room is unknown
        """
        state_ids = yield self.get_current_state_ids(room_id)
        create_id = state_ids.get((EventTypes.Create, ""))

        # If we can't find the create event, assume we've hit a dead end
        if not create_id:
            raise NotFoundError("Unknown room %s" % (room_id,))

        # Retrieve the room's create event and return
        create_event = yield self.get_event(create_id)
        return create_event

    @cached(max_entries=100000, iterable=True)
    def get_current_state_ids(self, room_id):
        """Get the current state event ids for a room based on the
        current_state_events table.

        Args:
            room_id (str)

        Returns:
            deferred: dict of (type, state_key) -> event_id
        """
        def _get_current_state_ids_txn(txn):
            txn.execute(
                """SELECT type, state_key, event_id FROM current_state_events
                WHERE room_id = ?
                """,
                (room_id, ),
            )

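            # intern_string/to_ascii are presumably used to deduplicate the
            # (type, state_key) strings and keep this large cached dict as
            # small as possible in memory.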
            return {(intern_string(r[0]), intern_string(r[1])): to_ascii(r[2])
                    for r in txn}

        return self.runInteraction("get_current_state_ids",
                                   _get_current_state_ids_txn)

    # FIXME: how should this be cached?
    def get_filtered_current_state_ids(self,
                                       room_id,
                                       state_filter=StateFilter.all()):
        """Get the current state event of a given type for a room based on the
        current_state_events table.  This may not be as up-to-date as the result
        of doing a fresh state resolution as per state_handler.get_current_state

        Args:
            room_id (str)
            state_filter (StateFilter): The state filter used to fetch state
                from the database.

        Returns:
            Deferred[dict[tuple[str, str], str]]: Map from type/state_key to
            event ID.
        """

        where_clause, where_args = state_filter.make_sql_filter_clause()

        if not where_clause:
            # We delegate to the cached version
            return self.get_current_state_ids(room_id)

        def _get_filtered_current_state_ids_txn(txn):
            results = {}
            sql = """
                SELECT type, state_key, event_id FROM current_state_events
                WHERE room_id = ?
            """

            if where_clause:
                sql += " AND (%s)" % (where_clause, )

            args = [room_id]
            args.extend(where_args)
            txn.execute(sql, args)
            for row in txn:
                typ, state_key, event_id = row
                key = (intern_string(typ), intern_string(state_key))
                results[key] = event_id

            return results

        return self.runInteraction("get_filtered_current_state_ids",
                                   _get_filtered_current_state_ids_txn)

    @defer.inlineCallbacks
    def get_canonical_alias_for_room(self, room_id):
        """Get canonical alias for room, if any

        Args:
            room_id (str)

        Returns:
            Deferred[str|None]: The canonical alias, if any
        """

        state = yield self.get_filtered_current_state_ids(
            room_id, StateFilter.from_types([(EventTypes.CanonicalAlias, "")]))

        event_id = state.get((EventTypes.CanonicalAlias, ""))
        if not event_id:
            return

        event = yield self.get_event(event_id, allow_none=True)
        if not event:
            return

        return event.content.get("canonical_alias")

    @cached(max_entries=10000, iterable=True)
    def get_state_group_delta(self, state_group):
        """Given a state group try to return a previous group and a delta between
        the old and the new.

        Returns:
            (prev_group, delta_ids), where both may be None.
        """
        def _get_state_group_delta_txn(txn):
            prev_group = self._simple_select_one_onecol_txn(
                txn,
                table="state_group_edges",
                keyvalues={"state_group": state_group},
                retcol="prev_state_group",
                allow_none=True,
            )

            if not prev_group:
                return _GetStateGroupDelta(None, None)

            delta_ids = self._simple_select_list_txn(
                txn,
                table="state_groups_state",
                keyvalues={"state_group": state_group},
                retcols=("type", "state_key", "event_id"),
            )

            return _GetStateGroupDelta(
                prev_group,
                {(row["type"], row["state_key"]): row["event_id"]
                 for row in delta_ids},
            )

        return self.runInteraction("get_state_group_delta",
                                   _get_state_group_delta_txn)

    @defer.inlineCallbacks
    def get_state_groups_ids(self, _room_id, event_ids):
        """Get the event IDs of all the state for the state groups for the given events

        Args:
            _room_id (str): id of the room for these events
            event_ids (iterable[str]): ids of the events

        Returns:
            Deferred[dict[int, dict[tuple[str, str], str]]]:
                dict of state_group_id -> (dict of (type, state_key) -> event id)
        """
        if not event_ids:
            return {}

        event_to_groups = yield self._get_state_group_for_events(event_ids)

        groups = set(itervalues(event_to_groups))
        group_to_state = yield self._get_state_for_groups(groups)

        return group_to_state

    @defer.inlineCallbacks
    def get_state_ids_for_group(self, state_group):
        """Get the event IDs of all the state in the given state group

        Args:
            state_group (int)

        Returns:
            Deferred[dict]: Resolves to a map of (type, state_key) -> event_id
        """
        group_to_state = yield self._get_state_for_groups((state_group, ))

        return group_to_state[state_group]

    @defer.inlineCallbacks
    def get_state_groups(self, room_id, event_ids):
        """ Get the state groups for the given list of event_ids

        Returns:
            Deferred[dict[int, list[EventBase]]]:
                dict of state_group_id -> list of state events.
        """
        if not event_ids:
            return {}

        group_to_ids = yield self.get_state_groups_ids(room_id, event_ids)

        state_event_map = yield self.get_events(
            [
                ev_id for group_ids in itervalues(group_to_ids)
                for ev_id in itervalues(group_ids)
            ],
            get_prev_content=False,
        )

        return {
            group: [
                state_event_map[v] for v in itervalues(event_id_map)
                if v in state_event_map
            ]
            for group, event_id_map in iteritems(group_to_ids)
        }

    @defer.inlineCallbacks
    def _get_state_groups_from_groups(self, groups, state_filter):
        """Returns the state groups for a given set of groups, filtering on
        types of state events.

        Args:
            groups(list[int]): list of state group IDs to query
            state_filter (StateFilter): The state filter used to fetch state
                from the database.
        Returns:
            Deferred[dict[int, dict[tuple[str, str], str]]]:
                dict of state_group_id -> (dict of (type, state_key) -> event id)
        """
        results = {}

        chunks = [groups[i:i + 100] for i in range(0, len(groups), 100)]
        for chunk in chunks:
            res = yield self.runInteraction(
                "_get_state_groups_from_groups",
                self._get_state_groups_from_groups_txn,
                chunk,
                state_filter,
            )
            results.update(res)

        return results

    @defer.inlineCallbacks
    def get_state_for_events(self, event_ids, state_filter=StateFilter.all()):
        """Given a list of event_ids and type tuples, return a list of state
        dicts for each event.

        Args:
            event_ids (list[string])
            state_filter (StateFilter): The state filter used to fetch state
                from the database.

        Returns:
            deferred: A dict of event_id -> (type, state_key) -> state event
        """
        event_to_groups = yield self._get_state_group_for_events(event_ids)

        groups = set(itervalues(event_to_groups))
        group_to_state = yield self._get_state_for_groups(groups, state_filter)

        state_event_map = yield self.get_events(
            [
                ev_id for sd in itervalues(group_to_state)
                for ev_id in itervalues(sd)
            ],
            get_prev_content=False,
        )

        event_to_state = {
            event_id: {
                k: state_event_map[v]
                for k, v in iteritems(group_to_state[group])
                if v in state_event_map
            }
            for event_id, group in iteritems(event_to_groups)
        }

        return {event: event_to_state[event] for event in event_ids}

    @defer.inlineCallbacks
    def get_state_ids_for_events(self,
                                 event_ids,
                                 state_filter=StateFilter.all()):
        """
        Get the state dicts corresponding to a list of events, containing the event_ids
        of the state events (as opposed to the events themselves)

        Args:
            event_ids(list(str)): events whose state should be returned
            state_filter (StateFilter): The state filter used to fetch state
                from the database.

        Returns:
            A deferred dict from event_id -> (type, state_key) -> event_id
        """
        event_to_groups = yield self._get_state_group_for_events(event_ids)

        groups = set(itervalues(event_to_groups))
        group_to_state = yield self._get_state_for_groups(groups, state_filter)

        event_to_state = {
            event_id: group_to_state[group]
            for event_id, group in iteritems(event_to_groups)
        }

        return {event: event_to_state[event] for event in event_ids}

    @defer.inlineCallbacks
    def get_state_for_event(self, event_id, state_filter=StateFilter.all()):
        """
        Get the state dict corresponding to a particular event

        Args:
            event_id(str): event whose state should be returned
            state_filter (StateFilter): The state filter used to fetch state
                from the database.

        Returns:
            A deferred dict from (type, state_key) -> state_event
        """
        state_map = yield self.get_state_for_events([event_id], state_filter)
        return state_map[event_id]

    @defer.inlineCallbacks
    def get_state_ids_for_event(self, event_id,
                                state_filter=StateFilter.all()):
        """
        Get the state dict corresponding to a particular event

        Args:
            event_id(str): event whose state should be returned
            state_filter (StateFilter): The state filter used to fetch state
                from the database.

        Returns:
            A deferred dict from (type, state_key) -> event_id
        """
        state_map = yield self.get_state_ids_for_events([event_id],
                                                        state_filter)
        return state_map[event_id]

    @cached(max_entries=50000)
    def _get_state_group_for_event(self, event_id):
        return self._simple_select_one_onecol(
            table="event_to_state_groups",
            keyvalues={"event_id": event_id},
            retcol="state_group",
            allow_none=True,
            desc="_get_state_group_for_event",
        )

    @cachedList(
        cached_method_name="_get_state_group_for_event",
        list_name="event_ids",
        num_args=1,
        inlineCallbacks=True,
    )
    def _get_state_group_for_events(self, event_ids):
        """Returns mapping event_id -> state_group
        """
        rows = yield self._simple_select_many_batch(
            table="event_to_state_groups",
            column="event_id",
            iterable=event_ids,
            keyvalues={},
            retcols=("event_id", "state_group"),
            desc="_get_state_group_for_events",
        )

        return {row["event_id"]: row["state_group"] for row in rows}

    def _get_state_for_group_using_cache(self, cache, group, state_filter):
        """Checks if group is in cache. See `_get_state_for_groups`

        Args:
            cache(DictionaryCache): the state group cache to use
            group(int): The state group to lookup
            state_filter (StateFilter): The state filter used to fetch state
                from the database.

        Returns 2-tuple (`state_dict`, `got_all`).
        `got_all` is a bool indicating if we successfully retrieved all
        requested state from the cache; if False we need to query the DB for the
        missing state.
        """
        is_all, known_absent, state_dict_ids = cache.get(group)

        if is_all or state_filter.is_full():
            # Either we have everything or want everything, either way
            # `is_all` tells us whether we've gotten everything.
            return state_filter.filter_state(state_dict_ids), is_all

        # tracks whether any of our requested types are missing from the cache
        missing_types = False

        if state_filter.has_wildcards():
            # We don't know if we fetched all the state keys for the types in
            # the filter that are wildcards, so we have to assume that we may
            # have missed some.
            missing_types = True
        else:
            # There aren't any wild cards, so `concrete_types()` returns the
            # complete list of event types we want.
            for key in state_filter.concrete_types():
                if key not in state_dict_ids and key not in known_absent:
                    missing_types = True
                    break

        return state_filter.filter_state(state_dict_ids), not missing_types

    @defer.inlineCallbacks
    def _get_state_for_groups(self, groups, state_filter=StateFilter.all()):
        """Gets the state at each of a list of state groups, optionally
        filtering by type/state_key

        Args:
            groups (iterable[int]): list of state groups for which we want
                to get the state.
            state_filter (StateFilter): The state filter used to fetch state
                from the database.
        Returns:
            Deferred[dict[int, dict[tuple[str, str], str]]]:
                dict of state_group_id -> (dict of (type, state_key) -> event id)
        """

        member_filter, non_member_filter = state_filter.get_member_split()

        # Now we look them up in the member and non-member caches
        (
            non_member_state,
            incomplete_groups_nm,
        ) = yield self._get_state_for_groups_using_cache(
            groups, self._state_group_cache, state_filter=non_member_filter)

        (
            member_state,
            incomplete_groups_m,
        ) = yield self._get_state_for_groups_using_cache(
            groups,
            self._state_group_members_cache,
            state_filter=member_filter)

        state = dict(non_member_state)
        for group in groups:
            state[group].update(member_state[group])

        # Now fetch any missing groups from the database

        incomplete_groups = incomplete_groups_m | incomplete_groups_nm

        if not incomplete_groups:
            return state

        cache_sequence_nm = self._state_group_cache.sequence
        cache_sequence_m = self._state_group_members_cache.sequence

        # Help the cache hit ratio by expanding the filter a bit
        db_state_filter = state_filter.return_expanded()

        group_to_state_dict = yield self._get_state_groups_from_groups(
            list(incomplete_groups), state_filter=db_state_filter)

        # Now let's update the caches
        self._insert_into_cache(
            group_to_state_dict,
            db_state_filter,
            cache_seq_num_members=cache_sequence_m,
            cache_seq_num_non_members=cache_sequence_nm,
        )

        # And finally update the result dict, by filtering out any extra
        # stuff we pulled out of the database.
        for group, group_state_dict in iteritems(group_to_state_dict):
            # We just replace any existing entries, as we will have loaded
            # everything we need from the database anyway.
            state[group] = state_filter.filter_state(group_state_dict)

        return state

    def _get_state_for_groups_using_cache(self, groups, cache, state_filter):
        """Gets the state at each of a list of state groups, optionally
        filtering by type/state_key, querying from a specific cache.

        Args:
            groups (iterable[int]): list of state groups for which we want
                to get the state.
            cache (DictionaryCache): the cache of group ids to state dicts which
                we will pass through - either the normal state cache or the specific
                members state cache.
            state_filter (StateFilter): The state filter used to fetch state
                from the database.

        Returns:
            tuple[dict[int, dict[tuple[str, str], str]], set[int]]: Tuple of
            dict of state_group_id -> (dict of (type, state_key) -> event id)
            of entries in the cache, and the state group ids either missing
            from the cache or incomplete.
        """
        results = {}
        incomplete_groups = set()
        for group in set(groups):
            state_dict_ids, got_all = self._get_state_for_group_using_cache(
                cache, group, state_filter)
            results[group] = state_dict_ids

            if not got_all:
                incomplete_groups.add(group)

        return results, incomplete_groups

    def _insert_into_cache(
        self,
        group_to_state_dict,
        state_filter,
        cache_seq_num_members,
        cache_seq_num_non_members,
    ):
        """Inserts results from querying the database into the relevant cache.

        Args:
            group_to_state_dict (dict): The new entries pulled from database.
                Map from state group to state dict
            state_filter (StateFilter): The state filter used to fetch state
                from the database.
            cache_seq_num_members (int): Sequence number of member cache since
                last lookup in cache
            cache_seq_num_non_members (int): Sequence number of non-member cache since
                last lookup in cache
        """

        # We need to work out which types we've fetched from the DB for the
        # member vs non-member caches. This should be as accurate as possible,
        # but can be an underestimate (e.g. when we have wild cards)

        member_filter, non_member_filter = state_filter.get_member_split()
        if member_filter.is_full():
            # We fetched all member events
            member_types = None
        else:
            # `concrete_types()` will only return a subset when there are wild
            # cards in the filter, but that's fine.
            member_types = member_filter.concrete_types()

        if non_member_filter.is_full():
            # We fetched all non member events
            non_member_types = None
        else:
            non_member_types = non_member_filter.concrete_types()

        for group, group_state_dict in iteritems(group_to_state_dict):
            state_dict_members = {}
            state_dict_non_members = {}

            for k, v in iteritems(group_state_dict):
                if k[0] == EventTypes.Member:
                    state_dict_members[k] = v
                else:
                    state_dict_non_members[k] = v

            self._state_group_members_cache.update(
                cache_seq_num_members,
                key=group,
                value=state_dict_members,
                fetched_keys=member_types,
            )

            self._state_group_cache.update(
                cache_seq_num_non_members,
                key=group,
                value=state_dict_non_members,
                fetched_keys=non_member_types,
            )

    def store_state_group(self, event_id, room_id, prev_group, delta_ids,
                          current_state_ids):
        """Store a new set of state, returning a newly assigned state group.

        Args:
            event_id (str): The event ID for which the state was calculated
            room_id (str)
            prev_group (int|None): A previous state group for the room, optional.
            delta_ids (dict|None): The delta between state at `prev_group` and
                `current_state_ids`, if `prev_group` was given. Same format as
                `current_state_ids`.
            current_state_ids (dict): The state to store. Map of (type, state_key)
                to event_id.

        Returns:
            Deferred[int]: The state group ID
        """
        def _store_state_group_txn(txn):
            if current_state_ids is None:
                # AFAIK, this can never happen
                raise Exception("current_state_ids cannot be None")

            state_group = self.database_engine.get_next_state_group_id(txn)

            self._simple_insert_txn(
                txn,
                table="state_groups",
                values={
                    "id": state_group,
                    "room_id": room_id,
                    "event_id": event_id
                },
            )

            # We persist as a delta if we can, while also ensuring the chain
            # of deltas isn't too long, as otherwise read performance degrades.
            if prev_group:
                is_in_db = self._simple_select_one_onecol_txn(
                    txn,
                    table="state_groups",
                    keyvalues={"id": prev_group},
                    retcol="id",
                    allow_none=True,
                )
                if not is_in_db:
                    raise Exception(
                        "Trying to persist state with unpersisted prev_group: %r"
                        % (prev_group, ))

                potential_hops = self._count_state_group_hops_txn(
                    txn, prev_group)
            if prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
                self._simple_insert_txn(
                    txn,
                    table="state_group_edges",
                    values={
                        "state_group": state_group,
                        "prev_state_group": prev_group
                    },
                )

                self._simple_insert_many_txn(
                    txn,
                    table="state_groups_state",
                    values=[{
                        "state_group": state_group,
                        "room_id": room_id,
                        "type": key[0],
                        "state_key": key[1],
                        "event_id": state_id,
                    } for key, state_id in iteritems(delta_ids)],
                )
            else:
                self._simple_insert_many_txn(
                    txn,
                    table="state_groups_state",
                    values=[{
                        "state_group": state_group,
                        "room_id": room_id,
                        "type": key[0],
                        "state_key": key[1],
                        "event_id": state_id,
                    } for key, state_id in iteritems(current_state_ids)],
                )

            # Prefill the state group caches with this group.
            # It's fine to use the sequence like this as the state group map
            # is immutable. (If the map wasn't immutable then this prefill could
            # race with another update)

            current_member_state_ids = {
                s: ev
                for (s, ev) in iteritems(current_state_ids)
                if s[0] == EventTypes.Member
            }
            txn.call_after(
                self._state_group_members_cache.update,
                self._state_group_members_cache.sequence,
                key=state_group,
                value=dict(current_member_state_ids),
            )

            current_non_member_state_ids = {
                s: ev
                for (s, ev) in iteritems(current_state_ids)
                if s[0] != EventTypes.Member
            }
            txn.call_after(
                self._state_group_cache.update,
                self._state_group_cache.sequence,
                key=state_group,
                value=dict(current_non_member_state_ids),
            )

            return state_group

        return self.runInteraction("store_state_group", _store_state_group_txn)
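
    # --- illustrative sketch (not part of the original store) ----------------
    # The delta-vs-full decision above, restated as a hypothetical helper:
    # persist a delta only when there is a previous group and the existing
    # chain of deltas is still shorter than max_hops (MAX_STATE_DELTA_HOPS).
    @staticmethod
    def _example_should_persist_delta(prev_group, hops, max_hops):
        return bool(prev_group) and hops < max_hops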

    @defer.inlineCallbacks
    def get_referenced_state_groups(self, state_groups):
        """Check if the state groups are referenced by events.

        Args:
            state_groups (Iterable[int])

        Returns:
            Deferred[set[int]]: The subset of state groups that are
            referenced.
        """

        rows = yield self._simple_select_many_batch(
            table="event_to_state_groups",
            column="state_group",
            iterable=state_groups,
            keyvalues={},
            retcols=("DISTINCT state_group", ),
            desc="get_referenced_state_groups",
        )

        return set(row["state_group"] for row in rows)
Example #8
0
    def __init__(self, db_conn, hs):
        super(StateGroupWorkerStore, self).__init__(db_conn, hs)

        self._state_group_cache = DictionaryCache(
            "*stateGroupCache*",
            500000 * get_cache_factor_for("stateGroupCache"))
Example #9
0
class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
    """The parts of StateGroupStore that can be called from workers.
    """

    STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
    STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
    CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"

    def __init__(self, db_conn, hs):
        super(StateGroupWorkerStore, self).__init__(db_conn, hs)

        self._state_group_cache = DictionaryCache(
            "*stateGroupCache*",
            500000 * get_cache_factor_for("stateGroupCache"))

    @defer.inlineCallbacks
    def get_room_version(self, room_id):
        """Get the room_version of a given room

        Args:
            room_id (str)

        Returns:
            Deferred[str]

        Raises:
            NotFoundError if the room is unknown
        """
        # for now we do this by looking at the create event. We may want to cache this
        # more intelligently in future.
        state_ids = yield self.get_current_state_ids(room_id)
        create_id = state_ids.get((EventTypes.Create, ""))

        if not create_id:
            raise NotFoundError("Unknown room")

        create_event = yield self.get_event(create_id)
        defer.returnValue(create_event.content.get("room_version", "1"))

    @cached(max_entries=100000, iterable=True)
    def get_current_state_ids(self, room_id):
        """Get the current state event ids for a room based on the
        current_state_events table.

        Args:
            room_id (str)

        Returns:
            deferred: dict of (type, state_key) -> event_id
        """
        def _get_current_state_ids_txn(txn):
            txn.execute(
                """SELECT type, state_key, event_id FROM current_state_events
                WHERE room_id = ?
                """, (room_id, ))

            return {(intern_string(r[0]), intern_string(r[1])): to_ascii(r[2])
                    for r in txn}

        return self.runInteraction(
            "get_current_state_ids",
            _get_current_state_ids_txn,
        )

    # FIXME: how should this be cached?
    def get_filtered_current_state_ids(self,
                                       room_id,
                                       types,
                                       filtered_types=None):
        """Get the current state event of a given type for a room based on the
        current_state_events table.  This may not be as up-to-date as the result
        of doing a fresh state resolution as per state_handler.get_current_state
        Args:
            room_id (str)
            types (list[(str, str|None)]): List of (type, state_key) tuples
                which are used to filter the state fetched. `state_key` may be
                None, which matches any `state_key`
            filtered_types (list[str]|None): List of types to apply the above filter to.
        Returns:
            deferred: dict of (type, state_key) -> event
        """

        include_other_types = filtered_types is not None

        def _get_filtered_current_state_ids_txn(txn):
            results = {}
            sql = """SELECT type, state_key, event_id FROM current_state_events
                     WHERE room_id = ? %s"""
            # Turns out that postgres doesn't like doing a list of OR's and
            # is about 1000x slower, so we just issue a query for each specific
            # type separately.
            if types:
                clause_to_args = [("AND type = ? AND state_key = ?",
                                   (etype,
                                    state_key)) if state_key is not None else
                                  ("AND type = ?", (etype, ))
                                  for etype, state_key in types]

                if include_other_types:
                    unique_types = set(filtered_types)
                    clause_to_args.append(
                        ("AND type <> ? " * len(unique_types),
                         list(unique_types)))
            else:
                # If types is None we fetch all the state, and so just use an
                # empty where clause with no extra args.
                clause_to_args = [("", [])]
            for where_clause, where_args in clause_to_args:
                args = [room_id]
                args.extend(where_args)
                txn.execute(sql % (where_clause, ), args)
                for row in txn:
                    typ, state_key, event_id = row
                    key = (intern_string(typ), intern_string(state_key))
                    results[key] = event_id
            return results

        return self.runInteraction(
            "get_filtered_current_state_ids",
            _get_filtered_current_state_ids_txn,
        )
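
    # --- illustrative sketch (not part of the original store) ----------------
    # How the per-type clause list above is built for a concrete `types` input,
    # e.g. [("m.room.member", "@alice:hs"), ("m.room.name", None)]; the helper
    # name is hypothetical and mirrors the list comprehension in the method.
    @staticmethod
    def _example_clause_to_args(types):
        return [
            ("AND type = ? AND state_key = ?", (etype, state_key))
            if state_key is not None
            else ("AND type = ?", (etype,))
            for etype, state_key in types
        ]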

    @cached(max_entries=10000, iterable=True)
    def get_state_group_delta(self, state_group):
        """Given a state group try to return a previous group and a delta between
        the old and the new.

        Returns:
            (prev_group, delta_ids), where both may be None.
        """
        def _get_state_group_delta_txn(txn):
            prev_group = self._simple_select_one_onecol_txn(
                txn,
                table="state_group_edges",
                keyvalues={
                    "state_group": state_group,
                },
                retcol="prev_state_group",
                allow_none=True,
            )

            if not prev_group:
                return _GetStateGroupDelta(None, None)

            delta_ids = self._simple_select_list_txn(
                txn,
                table="state_groups_state",
                keyvalues={
                    "state_group": state_group,
                },
                retcols=(
                    "type",
                    "state_key",
                    "event_id",
                ))

            return _GetStateGroupDelta(
                prev_group, {(row["type"], row["state_key"]): row["event_id"]
                             for row in delta_ids})

        return self.runInteraction(
            "get_state_group_delta",
            _get_state_group_delta_txn,
        )

    @defer.inlineCallbacks
    def get_state_groups_ids(self, room_id, event_ids):
        if not event_ids:
            defer.returnValue({})

        event_to_groups = yield self._get_state_group_for_events(event_ids, )

        groups = set(itervalues(event_to_groups))
        group_to_state = yield self._get_state_for_groups(groups)

        defer.returnValue(group_to_state)

    @defer.inlineCallbacks
    def get_state_ids_for_group(self, state_group):
        """Get the state IDs for the given state group

        Args:
            state_group (int)

        Returns:
            Deferred[dict]: Resolves to a map of (type, state_key) -> event_id
        """
        group_to_state = yield self._get_state_for_groups((state_group, ))

        defer.returnValue(group_to_state[state_group])

    @defer.inlineCallbacks
    def get_state_groups(self, room_id, event_ids):
        """ Get the state groups for the given list of event_ids

        The return value is a dict mapping group names to lists of events.
        """
        if not event_ids:
            defer.returnValue({})

        group_to_ids = yield self.get_state_groups_ids(room_id, event_ids)

        state_event_map = yield self.get_events([
            ev_id for group_ids in itervalues(group_to_ids)
            for ev_id in itervalues(group_ids)
        ],
                                                get_prev_content=False)

        defer.returnValue({
            group: [
                state_event_map[v] for v in itervalues(event_id_map)
                if v in state_event_map
            ]
            for group, event_id_map in iteritems(group_to_ids)
        })
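
    # --- illustrative sketch (not part of the original store) ----------------
    # The final mapping above as a hypothetical standalone helper over plain
    # dicts: event IDs that could not be loaded are simply dropped.
    @staticmethod
    def _example_group_to_events(group_to_ids, state_event_map):
        return {
            group: [
                state_event_map[ev_id]
                for ev_id in event_id_map.values()
                if ev_id in state_event_map
            ]
            for group, event_id_map in group_to_ids.items()
        }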

    @defer.inlineCallbacks
    def _get_state_groups_from_groups(self, groups, types):
        """Returns the state groups for a given set of groups, filtering on
        types of state events.

        Args:
            groups(list[int]): list of state group IDs to query
            types (Iterable[(str, str|None)]|None): list of 2-tuples of the form
                (`type`, `state_key`), where a `state_key` of `None` matches all
                state_keys for the `type`. If None, all types are returned.

        Returns:
            dictionary state_group -> (dict of (type, state_key) -> event id)
        """
        results = {}

        chunks = [groups[i:i + 100] for i in range(0, len(groups), 100)]
        for chunk in chunks:
            res = yield self.runInteraction(
                "_get_state_groups_from_groups",
                self._get_state_groups_from_groups_txn,
                chunk,
                types,
            )
            results.update(res)

        defer.returnValue(results)

    def _get_state_groups_from_groups_txn(
        self,
        txn,
        groups,
        types=None,
    ):
        results = {group: {} for group in groups}

        if types is not None:
            types = list(set(types))  # deduplicate types list

        if isinstance(self.database_engine, PostgresEngine):
            # Temporarily disable sequential scans in this transaction. This is
            # a temporary hack until we can add the right indices in
            txn.execute("SET LOCAL enable_seqscan=off")

            # The below query walks the state_group tree so that the "state"
            # table includes all state_groups in the tree. It then joins
            # against `state_groups_state` to fetch the latest state.
            # It assumes that previous state groups are always numerically
            # lesser.
            # The PARTITION is used to get the event_id in the greatest state
            # group for the given type, state_key.
            # This may return multiple rows per (type, state_key), but last_value
            # should be the same.
            sql = ("""
                WITH RECURSIVE state(state_group) AS (
                    VALUES(?::bigint)
                    UNION ALL
                    SELECT prev_state_group FROM state_group_edges e, state s
                    WHERE s.state_group = e.state_group
                )
                SELECT type, state_key, last_value(event_id) OVER (
                    PARTITION BY type, state_key ORDER BY state_group ASC
                    ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
                ) AS event_id FROM state_groups_state
                WHERE state_group IN (
                    SELECT state_group FROM state
                )
                %s
            """)

            # Turns out that postgres doesn't like doing a list of OR's and
            # is about 1000x slower, so we just issue a query for each specific
            # type separately.
            if types is not None:
                clause_to_args = [("AND type = ? AND state_key = ?",
                                   (etype,
                                    state_key)) if state_key is not None else
                                  ("AND type = ?", (etype, ))
                                  for etype, state_key in types]
            else:
                # If types is None we fetch all the state, and so just use an
                # empty where clause with no extra args.
                clause_to_args = [("", [])]

            for where_clause, where_args in clause_to_args:
                for group in groups:
                    args = [group]
                    args.extend(where_args)

                    txn.execute(sql % (where_clause, ), args)
                    for row in txn:
                        typ, state_key, event_id = row
                        key = (typ, state_key)
                        results[group][key] = event_id
        else:
            where_args = []
            where_clauses = []
            wildcard_types = False
            if types is not None:
                for typ in types:
                    if typ[1] is None:
                        where_clauses.append("(type = ?)")
                        where_args.append(typ[0])
                        wildcard_types = True
                    else:
                        where_clauses.append("(type = ? AND state_key = ?)")
                        where_args.extend([typ[0], typ[1]])

                where_clause = "AND (%s)" % (" OR ".join(where_clauses))
            else:
                where_clause = ""

            # We don't use WITH RECURSIVE on sqlite3 as there are distributions
            # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
            for group in groups:
                next_group = group

                while next_group:
                    # We did this before by getting the list of group ids, and
                    # then passing that list to sqlite to get latest event for
                    # each (type, state_key). However, that was terribly slow
                    # without the right indices (which we can't add until
                    # after we finish deduping state, which requires this func)
                    args = [next_group]
                    if types:
                        args.extend(where_args)

                    txn.execute(
                        "SELECT type, state_key, event_id FROM state_groups_state"
                        " WHERE state_group = ? %s" % (where_clause, ), args)
                    results[group].update(
                        ((typ, state_key), event_id)
                        for typ, state_key, event_id in txn
                        if (typ, state_key) not in results[group])

                    # If the number of entries in the (type,state_key)->event_id dict
                    # matches the number of (type,state_keys) types we were searching
                    # for, then we must have found them all, so no need to go walk
                    # further down the tree... UNLESS our types filter contained
                    # wildcards (i.e. Nones) in which case we have to do an exhaustive
                    # search
                    if (types is not None and not wildcard_types
                            and len(results[group]) == len(types)):
                        break

                    next_group = self._simple_select_one_onecol_txn(
                        txn,
                        table="state_group_edges",
                        keyvalues={"state_group": next_group},
                        retcol="prev_state_group",
                        allow_none=True,
                    )

        return results
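
    # --- illustrative sketch (not part of the original store) ----------------
    # The sqlite branch above amounts to walking the delta chain child-first
    # and only filling in keys not already seen. `edges` is a hypothetical dict
    # of state_group -> prev_state_group and `group_state` a dict of
    # state_group -> {(type, state_key): event_id}.
    @staticmethod
    def _example_walk_delta_chain(group, edges, group_state):
        result = {}
        next_group = group
        while next_group is not None:
            for key, event_id in group_state.get(next_group, {}).items():
                result.setdefault(key, event_id)
            next_group = edges.get(next_group)
        return result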

    @defer.inlineCallbacks
    def get_state_for_events(self, event_ids, types, filtered_types=None):
        """Given a list of event_ids and type tuples, return a list of state
        dicts for each event. The state dicts will only have the type/state_keys
        that are in the `types` list.

        Args:
            event_ids (list[string])
            types (list[(str, str|None)]|None): List of (type, state_key) tuples
                which are used to filter the state fetched. If `state_key` is None,
                all events are returned of the given type.
                May be None, which matches any key.
            filtered_types(list[str]|None): Only apply filtering via `types` to this
                list of event types.  Other types of events are returned unfiltered.
                If None, `types` filtering is applied to all events.

        Returns:
            deferred: A dict of event_id -> (type, state_key) -> state_event
        """
        event_to_groups = yield self._get_state_group_for_events(event_ids, )

        groups = set(itervalues(event_to_groups))
        group_to_state = yield self._get_state_for_groups(
            groups, types, filtered_types)

        state_event_map = yield self.get_events([
            ev_id for sd in itervalues(group_to_state)
            for ev_id in itervalues(sd)
        ],
                                                get_prev_content=False)

        event_to_state = {
            event_id: {
                k: state_event_map[v]
                for k, v in iteritems(group_to_state[group])
                if v in state_event_map
            }
            for event_id, group in iteritems(event_to_groups)
        }

        defer.returnValue(
            {event: event_to_state[event]
             for event in event_ids})

    @defer.inlineCallbacks
    def get_state_ids_for_events(self,
                                 event_ids,
                                 types=None,
                                 filtered_types=None):
        """
        Get the state dicts corresponding to a list of events, containing the event_ids
        of the state events (as opposed to the events themselves)

        Args:
            event_ids(list(str)): events whose state should be returned
            types(list[(str, str|None)]|None): List of (type, state_key) tuples
                which are used to filter the state fetched. If `state_key` is None,
                all events are returned of the given type.
                May be None, which matches any key.
            filtered_types(list[str]|None): Only apply filtering via `types` to this
                list of event types.  Other types of events are returned unfiltered.
                If None, `types` filtering is applied to all events.

        Returns:
            A deferred dict from event_id -> (type, state_key) -> event_id
        """
        event_to_groups = yield self._get_state_group_for_events(event_ids, )

        groups = set(itervalues(event_to_groups))
        group_to_state = yield self._get_state_for_groups(
            groups, types, filtered_types)

        event_to_state = {
            event_id: group_to_state[group]
            for event_id, group in iteritems(event_to_groups)
        }

        defer.returnValue(
            {event: event_to_state[event]
             for event in event_ids})

    @defer.inlineCallbacks
    def get_state_for_event(self, event_id, types=None, filtered_types=None):
        """
        Get the state dict corresponding to a particular event

        Args:
            event_id(str): event whose state should be returned
            types(list[(str, str|None)]|None): List of (type, state_key) tuples
                which are used to filter the state fetched. If `state_key` is None,
                all events are returned of the given type.
                May be None, which matches any key.
            filtered_types(list[str]|None): Only apply filtering via `types` to this
                list of event types.  Other types of events are returned unfiltered.
                If None, `types` filtering is applied to all events.

        Returns:
            A deferred dict from (type, state_key) -> state_event
        """
        state_map = yield self.get_state_for_events([event_id], types,
                                                    filtered_types)
        defer.returnValue(state_map[event_id])

    @defer.inlineCallbacks
    def get_state_ids_for_event(self,
                                event_id,
                                types=None,
                                filtered_types=None):
        """
        Get the state dict corresponding to a particular event

        Args:
            event_id(str): event whose state should be returned
            types(list[(str, str|None)]|None): List of (type, state_key) tuples
                which are used to filter the state fetched. If `state_key` is None,
                all events are returned of the given type.
                May be None, which matches any key.
            filtered_types(list[str]|None): Only apply filtering via `types` to this
                list of event types.  Other types of events are returned unfiltered.
                If None, `types` filtering is applied to all events.

        Returns:
            A deferred dict from (type, state_key) -> state_event
        """
        state_map = yield self.get_state_ids_for_events([event_id], types,
                                                        filtered_types)
        defer.returnValue(state_map[event_id])

    @cached(max_entries=50000)
    def _get_state_group_for_event(self, event_id):
        return self._simple_select_one_onecol(
            table="event_to_state_groups",
            keyvalues={
                "event_id": event_id,
            },
            retcol="state_group",
            allow_none=True,
            desc="_get_state_group_for_event",
        )

    @cachedList(cached_method_name="_get_state_group_for_event",
                list_name="event_ids",
                num_args=1,
                inlineCallbacks=True)
    def _get_state_group_for_events(self, event_ids):
        """Returns mapping event_id -> state_group
        """
        rows = yield self._simple_select_many_batch(
            table="event_to_state_groups",
            column="event_id",
            iterable=event_ids,
            keyvalues={},
            retcols=(
                "event_id",
                "state_group",
            ),
            desc="_get_state_group_for_events",
        )

        defer.returnValue(
            {row["event_id"]: row["state_group"]
             for row in rows})

    def _get_some_state_from_cache(self, group, types, filtered_types=None):
        """Checks if group is in cache. See `_get_state_for_groups`

        Args:
            group(int): The state group to lookup
            types(list[str, str|None]): List of 2-tuples of the form
                (`type`, `state_key`), where a `state_key` of `None` matches all
                state_keys for the `type`.
            filtered_types(list[str]|None): Only apply filtering via `types` to this
                list of event types.  Other types of events are returned unfiltered.
                If None, `types` filtering is applied to all events.

        Returns 2-tuple (`state_dict`, `got_all`).
        `got_all` is a bool indicating if we successfully retrieved all
        requested state from the cache; if False we need to query the DB for the
        missing state.
        """
        is_all, known_absent, state_dict_ids = self._state_group_cache.get(
            group)

        type_to_key = {}

        # tracks whether any of our requested types are missing from the cache
        missing_types = False

        for typ, state_key in types:
            key = (typ, state_key)

            if (state_key is None or
                (filtered_types is not None and typ not in filtered_types)):
                type_to_key[typ] = None
                # we mark the type as missing from the cache because
                # when the cache was populated it might have been done with a
                # restricted set of state_keys, so the wildcard will not work
                # and the cache may be incomplete.
                missing_types = True
            else:
                if type_to_key.get(typ, object()) is not None:
                    type_to_key.setdefault(typ, set()).add(state_key)

                if key not in state_dict_ids and key not in known_absent:
                    missing_types = True

        sentinel = object()

        def include(typ, state_key):
            valid_state_keys = type_to_key.get(typ, sentinel)
            if valid_state_keys is sentinel:
                return filtered_types is not None and typ not in filtered_types
            if valid_state_keys is None:
                return True
            if state_key in valid_state_keys:
                return True
            return False

        got_all = is_all
        if not got_all:
            # the cache is incomplete. We may still have got all the results we need, if
            # we don't have any wildcards in the match list.
            if not missing_types and filtered_types is None:
                got_all = True

        return {
            k: v
            for k, v in iteritems(state_dict_ids) if include(k[0], k[1])
        }, got_all
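
    # --- illustrative sketch (not part of the original store) ----------------
    # The include() predicate above, restated over a prebuilt type_to_key map;
    # the helper name is hypothetical.
    @staticmethod
    def _example_include(type_to_key, filtered_types, typ, state_key):
        sentinel = object()
        valid_state_keys = type_to_key.get(typ, sentinel)
        if valid_state_keys is sentinel:
            # type was not requested at all: only keep it if it falls outside
            # the filtered_types list (i.e. it is returned unfiltered)
            return filtered_types is not None and typ not in filtered_types
        if valid_state_keys is None:
            # wildcard: keep every state_key of this type
            return True
        return state_key in valid_state_keys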

    def _get_all_state_from_cache(self, group):
        """Checks if group is in cache. See `_get_state_for_groups`

        Returns 2-tuple (`state_dict`, `got_all`). `got_all` is a bool
        indicating if we successfully retrieved all requested state from the
        cache; if False we need to query the DB for the missing state.

        Args:
            group: The state group to lookup
        """
        is_all, _, state_dict_ids = self._state_group_cache.get(group)

        return state_dict_ids, is_all

    @defer.inlineCallbacks
    def _get_state_for_groups(self, groups, types=None, filtered_types=None):
        """Gets the state at each of a list of state groups, optionally
        filtering by type/state_key

        Args:
            groups (iterable[int]): list of state groups for which we want
                to get the state.
            types (None|iterable[(str, None|str)]):
                indicates the state type/keys required. If None, the whole
                state is fetched and returned.

                Otherwise, each entry should be a `(type, state_key)` tuple to
                include in the response. A `state_key` of None is a wildcard
                meaning that we require all state with that type.
            filtered_types(list[str]|None): Only apply filtering via `types` to this
                list of event types.  Other types of events are returned unfiltered.
                If None, `types` filtering is applied to all events.

        Returns:
            Deferred[dict[int, dict[(type, state_key), EventBase]]]
                a dictionary mapping from state group to state dictionary.
        """
        if types:
            types = frozenset(types)
        results = {}
        missing_groups = []
        if types is not None:
            for group in set(groups):
                state_dict_ids, got_all = self._get_some_state_from_cache(
                    group, types, filtered_types)
                results[group] = state_dict_ids

                if not got_all:
                    missing_groups.append(group)
        else:
            for group in set(groups):
                state_dict_ids, got_all = self._get_all_state_from_cache(group)

                results[group] = state_dict_ids

                if not got_all:
                    missing_groups.append(group)

        if missing_groups:
            # Okay, so we have some missing groups, let's fetch them.
            cache_seq_num = self._state_group_cache.sequence

            # the DictionaryCache knows if it has *all* the state, but
            # does not know if it has all of the keys of a particular type,
            # which makes wildcard lookups expensive unless we have a complete
            # cache. Hence, if we are doing a wildcard lookup, populate the
            # cache fully so that we can do an efficient lookup next time.

            if filtered_types or (types and any(k is None
                                                for (t, k) in types)):
                types_to_fetch = None
            else:
                types_to_fetch = types

            group_to_state_dict = yield self._get_state_groups_from_groups(
                missing_groups, types_to_fetch)

            for group, group_state_dict in iteritems(group_to_state_dict):
                state_dict = results[group]

                # update the result, filtering by `types`.
                if types:
                    for k, v in iteritems(group_state_dict):
                        (typ, _) = k
                        if ((k in types or (typ, None) in types) or
                            (filtered_types and typ not in filtered_types)):
                            state_dict[k] = v
                else:
                    state_dict.update(group_state_dict)

                # update the cache with all the things we fetched from the
                # database.
                self._state_group_cache.update(
                    cache_seq_num,
                    key=group,
                    value=group_state_dict,
                    fetched_keys=types_to_fetch,
                )

        defer.returnValue(results)
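
    # --- illustrative sketch (not part of the original store) ----------------
    # The wildcard check above, as a hypothetical helper: fetch everything
    # (None) when filtered_types is set or any (type, None) wildcard is
    # present, otherwise fetch only the concrete (type, state_key) pairs.
    @staticmethod
    def _example_types_to_fetch(types, filtered_types):
        if filtered_types or (types and any(k is None for (_t, k) in types)):
            return None
        return types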

    def store_state_group(self, event_id, room_id, prev_group, delta_ids,
                          current_state_ids):
        """Store a new set of state, returning a newly assigned state group.

        Args:
            event_id (str): The event ID for which the state was calculated
            room_id (str)
            prev_group (int|None): A previous state group for the room, optional.
            delta_ids (dict|None): The delta between state at `prev_group` and
                `current_state_ids`, if `prev_group` was given. Same format as
                `current_state_ids`.
            current_state_ids (dict): The state to store. Map of (type, state_key)
                to event_id.

        Returns:
            Deferred[int]: The state group ID
        """
        def _store_state_group_txn(txn):
            if current_state_ids is None:
                # AFAIK, this can never happen
                raise Exception("current_state_ids cannot be None")

            state_group = self.database_engine.get_next_state_group_id(txn)

            self._simple_insert_txn(
                txn,
                table="state_groups",
                values={
                    "id": state_group,
                    "room_id": room_id,
                    "event_id": event_id,
                },
            )

            # We persist as a delta if we can, while also ensuring the chain
            # of deltas isn't too long, as otherwise read performance degrades.
            if prev_group:
                is_in_db = self._simple_select_one_onecol_txn(
                    txn,
                    table="state_groups",
                    keyvalues={"id": prev_group},
                    retcol="id",
                    allow_none=True,
                )
                if not is_in_db:
                    raise Exception(
                        "Trying to persist state with unpersisted prev_group: %r"
                        % (prev_group, ))

                potential_hops = self._count_state_group_hops_txn(
                    txn, prev_group)
            if prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
                self._simple_insert_txn(
                    txn,
                    table="state_group_edges",
                    values={
                        "state_group": state_group,
                        "prev_state_group": prev_group,
                    },
                )

                self._simple_insert_many_txn(
                    txn,
                    table="state_groups_state",
                    values=[{
                        "state_group": state_group,
                        "room_id": room_id,
                        "type": key[0],
                        "state_key": key[1],
                        "event_id": state_id,
                    } for key, state_id in iteritems(delta_ids)],
                )
            else:
                self._simple_insert_many_txn(
                    txn,
                    table="state_groups_state",
                    values=[{
                        "state_group": state_group,
                        "room_id": room_id,
                        "type": key[0],
                        "state_key": key[1],
                        "event_id": state_id,
                    } for key, state_id in iteritems(current_state_ids)],
                )

            # Prefill the state group cache with this group.
            # It's fine to use the sequence like this as the state group map
            # is immutable. (If the map wasn't immutable then this prefill could
            # race with another update)
            txn.call_after(
                self._state_group_cache.update,
                self._state_group_cache.sequence,
                key=state_group,
                value=dict(current_state_ids),
            )

            return state_group

        return self.runInteraction("store_state_group", _store_state_group_txn)

    def _count_state_group_hops_txn(self, txn, state_group):
        """Given a state group, count how many hops there are in the tree.

        This is used to ensure the delta chains don't get too long.
        """
        if isinstance(self.database_engine, PostgresEngine):
            sql = ("""
                WITH RECURSIVE state(state_group) AS (
                    VALUES(?::bigint)
                    UNION ALL
                    SELECT prev_state_group FROM state_group_edges e, state s
                    WHERE s.state_group = e.state_group
                )
                SELECT count(*) FROM state;
            """)

            txn.execute(sql, (state_group, ))
            row = txn.fetchone()
            if row and row[0]:
                return row[0]
            else:
                return 0
        else:
            # We don't use WITH RECURSIVE on sqlite3 as there are distributions
            # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
            next_group = state_group
            count = 0

            while next_group:
                next_group = self._simple_select_one_onecol_txn(
                    txn,
                    table="state_group_edges",
                    keyvalues={"state_group": next_group},
                    retcol="prev_state_group",
                    allow_none=True,
                )
                if next_group:
                    count += 1

            return count
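
    # --- illustrative sketch (not part of the original store) ----------------
    # Hop counting over a hypothetical in-memory parent map, mirroring the
    # sqlite branch above; `edges` maps state_group -> prev_state_group.
    @staticmethod
    def _example_count_hops(state_group, edges):
        count = 0
        next_group = edges.get(state_group)
        while next_group is not None:
            count += 1
            next_group = edges.get(next_group)
        return count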
class DictCacheTestCase(unittest.TestCase):
    def setUp(self):
        self.cache = DictionaryCache("foobar")

    def test_simple_cache_hit_full(self):
        key = "test_simple_cache_hit_full"

        v = self.cache.get(key)
        self.assertIs(v.full, False)
        self.assertEqual(v.known_absent, set())
        self.assertEqual({}, v.value)

        seq = self.cache.sequence
        test_value = {"test": "test_simple_cache_hit_full"}
        self.cache.update(seq, key, test_value)

        c = self.cache.get(key)
        self.assertEqual(test_value, c.value)

    def test_simple_cache_hit_partial(self):
        key = "test_simple_cache_hit_partial"

        seq = self.cache.sequence
        test_value = {"test": "test_simple_cache_hit_partial"}
        self.cache.update(seq, key, test_value)

        c = self.cache.get(key, ["test"])
        self.assertEqual(test_value, c.value)

    def test_simple_cache_miss_partial(self):
        key = "test_simple_cache_miss_partial"

        seq = self.cache.sequence
        test_value = {"test": "test_simple_cache_miss_partial"}
        self.cache.update(seq, key, test_value)

        c = self.cache.get(key, ["test2"])
        self.assertEqual({}, c.value)

    def test_simple_cache_hit_miss_partial(self):
        key = "test_simple_cache_hit_miss_partial"

        seq = self.cache.sequence
        test_value = {
            "test": "test_simple_cache_hit_miss_partial",
            "test2": "test_simple_cache_hit_miss_partial2",
            "test3": "test_simple_cache_hit_miss_partial3",
        }
        self.cache.update(seq, key, test_value)

        c = self.cache.get(key, ["test2"])
        self.assertEqual({"test2": "test_simple_cache_hit_miss_partial2"},
                         c.value)

    def test_multi_insert(self):
        key = "test_simple_cache_hit_miss_partial"

        seq = self.cache.sequence
        test_value_1 = {"test": "test_simple_cache_hit_miss_partial"}
        self.cache.update(seq, key, test_value_1, fetched_keys={"test"})

        seq = self.cache.sequence
        test_value_2 = {"test2": "test_simple_cache_hit_miss_partial2"}
        self.cache.update(seq, key, test_value_2, fetched_keys={"test2"})

        c = self.cache.get(key)
        self.assertEqual(
            {
                "test": "test_simple_cache_hit_miss_partial",
                "test2": "test_simple_cache_hit_miss_partial2",
            },
            c.value,
        )
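
    # --- illustrative sketch (not part of the original tests) ----------------
    # A hypothetical extra case: a partial update records which keys were
    # fetched, so the entry stays non-full but a later partial get for a
    # fetched-but-absent key cleanly returns an empty value.
    def test_partial_update_known_absent(self):
        key = "test_partial_update_known_absent"

        seq = self.cache.sequence
        self.cache.update(seq, key, {"a": "value_a"}, fetched_keys={"a", "b"})

        c = self.cache.get(key, ["b"])
        self.assertIs(c.full, False)
        self.assertEqual({}, c.value)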
Example #12
0
    def __init__(self, db_conn, hs):
        super(StateGroupWorkerStore, self).__init__(db_conn, hs)

        self._state_group_cache = DictionaryCache("*stateGroupCache*",
                                                  100000 * CACHE_SIZE_FACTOR)
class DictCacheTestCase(unittest.TestCase):
    def setUp(self):
        self.cache = DictionaryCache("foobar")

    def test_simple_cache_hit_full(self):
        key = "test_simple_cache_hit_full"

        v = self.cache.get(key)
        self.assertEqual((False, set(), {}), v)

        seq = self.cache.sequence
        test_value = {"test": "test_simple_cache_hit_full"}
        self.cache.update(seq, key, test_value, full=True)

        c = self.cache.get(key)
        self.assertEqual(test_value, c.value)

    def test_simple_cache_hit_partial(self):
        key = "test_simple_cache_hit_partial"

        seq = self.cache.sequence
        test_value = {"test": "test_simple_cache_hit_partial"}
        self.cache.update(seq, key, test_value, full=True)

        c = self.cache.get(key, ["test"])
        self.assertEqual(test_value, c.value)

    def test_simple_cache_miss_partial(self):
        key = "test_simple_cache_miss_partial"

        seq = self.cache.sequence
        test_value = {"test": "test_simple_cache_miss_partial"}
        self.cache.update(seq, key, test_value, full=True)

        c = self.cache.get(key, ["test2"])
        self.assertEqual({}, c.value)

    def test_simple_cache_hit_miss_partial(self):
        key = "test_simple_cache_hit_miss_partial"

        seq = self.cache.sequence
        test_value = {
            "test": "test_simple_cache_hit_miss_partial",
            "test2": "test_simple_cache_hit_miss_partial2",
            "test3": "test_simple_cache_hit_miss_partial3",
        }
        self.cache.update(seq, key, test_value, full=True)

        c = self.cache.get(key, ["test2"])
        self.assertEqual({"test2": "test_simple_cache_hit_miss_partial2"},
                         c.value)

    def test_multi_insert(self):
        key = "test_simple_cache_hit_miss_partial"

        seq = self.cache.sequence
        test_value_1 = {
            "test": "test_simple_cache_hit_miss_partial",
        }
        self.cache.update(seq, key, test_value_1, full=False)

        seq = self.cache.sequence
        test_value_2 = {
            "test2": "test_simple_cache_hit_miss_partial2",
        }
        self.cache.update(seq, key, test_value_2, full=False)

        c = self.cache.get(key)
        self.assertEqual(
            {
                "test": "test_simple_cache_hit_miss_partial",
                "test2": "test_simple_cache_hit_miss_partial2",
            }, c.value)
Example #14
0
    def __init__(self, db_conn, hs):
        super(StateGroupWorkerStore, self).__init__(db_conn, hs)

        self._state_group_cache = DictionaryCache(
            "*stateGroupCache*", 100000 * CACHE_SIZE_FACTOR
        )
Example #15
0
class StateGroupWorkerStore(SQLBaseStore):
    """The parts of StateGroupStore that can be called from workers.
    """

    STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
    STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
    CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"

    def __init__(self, db_conn, hs):
        super(StateGroupWorkerStore, self).__init__(db_conn, hs)

        self._state_group_cache = DictionaryCache(
            "*stateGroupCache*", 100000 * CACHE_SIZE_FACTOR
        )

    @cached(max_entries=100000, iterable=True)
    def get_current_state_ids(self, room_id):
        """Get the current state event ids for a room based on the
        current_state_events table.

        Args:
            room_id (str)

        Returns:
            deferred: dict of (type, state_key) -> event_id
        """
        def _get_current_state_ids_txn(txn):
            txn.execute(
                """SELECT type, state_key, event_id FROM current_state_events
                WHERE room_id = ?
                """,
                (room_id,)
            )

            return {
                (intern_string(r[0]), intern_string(r[1])): to_ascii(r[2]) for r in txn
            }

        return self.runInteraction(
            "get_current_state_ids",
            _get_current_state_ids_txn,
        )

    @cached(max_entries=10000, iterable=True)
    def get_state_group_delta(self, state_group):
        """Given a state group try to return a previous group and a delta between
        the old and the new.

        Returns:
            (prev_group, delta_ids), where both may be None.
        """
        def _get_state_group_delta_txn(txn):
            prev_group = self._simple_select_one_onecol_txn(
                txn,
                table="state_group_edges",
                keyvalues={
                    "state_group": state_group,
                },
                retcol="prev_state_group",
                allow_none=True,
            )

            if not prev_group:
                return _GetStateGroupDelta(None, None)

            delta_ids = self._simple_select_list_txn(
                txn,
                table="state_groups_state",
                keyvalues={
                    "state_group": state_group,
                },
                retcols=("type", "state_key", "event_id",)
            )

            return _GetStateGroupDelta(prev_group, {
                (row["type"], row["state_key"]): row["event_id"]
                for row in delta_ids
            })
        return self.runInteraction(
            "get_state_group_delta",
            _get_state_group_delta_txn,
        )

    @defer.inlineCallbacks
    def get_state_groups_ids(self, room_id, event_ids):
        if not event_ids:
            defer.returnValue({})

        event_to_groups = yield self._get_state_group_for_events(
            event_ids,
        )

        groups = set(event_to_groups.itervalues())
        group_to_state = yield self._get_state_for_groups(groups)

        defer.returnValue(group_to_state)

    @defer.inlineCallbacks
    def get_state_ids_for_group(self, state_group):
        """Get the state IDs for the given state group

        Args:
            state_group (int)

        Returns:
            Deferred[dict]: Resolves to a map of (type, state_key) -> event_id
        """
        group_to_state = yield self._get_state_for_groups((state_group,))

        defer.returnValue(group_to_state[state_group])

    @defer.inlineCallbacks
    def get_state_groups(self, room_id, event_ids):
        """ Get the state groups for the given list of event_ids

        The return value is a dict mapping group names to lists of events.
        """
        if not event_ids:
            defer.returnValue({})

        group_to_ids = yield self.get_state_groups_ids(room_id, event_ids)

        state_event_map = yield self.get_events(
            [
                ev_id for group_ids in group_to_ids.itervalues()
                for ev_id in group_ids.itervalues()
            ],
            get_prev_content=False
        )

        defer.returnValue({
            group: [
                state_event_map[v] for v in event_id_map.itervalues()
                if v in state_event_map
            ]
            for group, event_id_map in group_to_ids.iteritems()
        })

    @defer.inlineCallbacks
    def _get_state_groups_from_groups(self, groups, types):
        """Returns dictionary state_group -> (dict of (type, state_key) -> event id)
        """
        results = {}

        chunks = [groups[i:i + 100] for i in xrange(0, len(groups), 100)]
        for chunk in chunks:
            res = yield self.runInteraction(
                "_get_state_groups_from_groups",
                self._get_state_groups_from_groups_txn, chunk, types,
            )
            results.update(res)

        defer.returnValue(results)

    def _get_state_groups_from_groups_txn(self, txn, groups, types=None):
        results = {group: {} for group in groups}
        if types is not None:
            types = list(set(types))  # deduplicate types list

        if isinstance(self.database_engine, PostgresEngine):
            # Temporarily disable sequential scans in this transaction. This is
            # a temporary hack until we can add the right indices in
            txn.execute("SET LOCAL enable_seqscan=off")

            # The below query walks the state_group tree so that the "state"
            # table includes all state_groups in the tree. It then joins
            # against `state_groups_state` to fetch the latest state.
            # It assumes that previous state groups are always numerically
            # lesser.
            # The PARTITION is used to get the event_id in the greatest state
            # group for the given type, state_key.
            # This may return multiple rows per (type, state_key), but last_value
            # should be the same.
            sql = ("""
                WITH RECURSIVE state(state_group) AS (
                    VALUES(?::bigint)
                    UNION ALL
                    SELECT prev_state_group FROM state_group_edges e, state s
                    WHERE s.state_group = e.state_group
                )
                SELECT type, state_key, last_value(event_id) OVER (
                    PARTITION BY type, state_key ORDER BY state_group ASC
                    ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
                ) AS event_id FROM state_groups_state
                WHERE state_group IN (
                    SELECT state_group FROM state
                )
                %s
            """)

            # Turns out that postgres doesn't like doing a list of OR's and
            # is about 1000x slower, so we just issue a query for each specific
            # type separately.
            if types:
                clause_to_args = [
                    (
                        "AND type = ? AND state_key = ?",
                        (etype, state_key)
                    ) if state_key is not None else (
                        "AND type = ?",
                        (etype,)
                    )
                    for etype, state_key in types
                ]
            else:
                # If types is None we fetch all the state, and so just use an
                # empty where clause with no extra args.
                clause_to_args = [("", [])]

            for where_clause, where_args in clause_to_args:
                for group in groups:
                    args = [group]
                    args.extend(where_args)

                    txn.execute(sql % (where_clause,), args)
                    for row in txn:
                        typ, state_key, event_id = row
                        key = (typ, state_key)
                        results[group][key] = event_id
        else:
            where_args = []
            where_clauses = []
            wildcard_types = False
            if types is not None:
                for typ in types:
                    if typ[1] is None:
                        where_clauses.append("(type = ?)")
                        where_args.append(typ[0])
                        wildcard_types = True
                    else:
                        where_clauses.append("(type = ? AND state_key = ?)")
                        where_args.extend([typ[0], typ[1]])
                where_clause = "AND (%s)" % (" OR ".join(where_clauses))
            else:
                where_clause = ""

            # We don't use WITH RECURSIVE on sqlite3 as there are distributions
            # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
            for group in groups:
                next_group = group

                while next_group:
                    # We did this before by getting the list of group ids, and
                    # then passing that list to sqlite to get latest event for
                    # each (type, state_key). However, that was terribly slow
                    # without the right indices (which we can't add until
                    # after we finish deduping state, which requires this func)
                    args = [next_group]
                    if types:
                        args.extend(where_args)

                    txn.execute(
                        "SELECT type, state_key, event_id FROM state_groups_state"
                        " WHERE state_group = ? %s" % (where_clause,),
                        args
                    )
                    results[group].update(
                        ((typ, state_key), event_id)
                        for typ, state_key, event_id in txn
                        if (typ, state_key) not in results[group]
                    )

                    # If the number of entries in the (type,state_key)->event_id dict
                    # matches the number of (type,state_keys) types we were searching
                    # for, then we must have found them all, so no need to go walk
                    # further down the tree... UNLESS our types filter contained
                    # wildcards (i.e. Nones) in which case we have to do an exhaustive
                    # search
                    if (
                        types is not None and
                        not wildcard_types and
                        len(results[group]) == len(types)
                    ):
                        break

                    next_group = self._simple_select_one_onecol_txn(
                        txn,
                        table="state_group_edges",
                        keyvalues={"state_group": next_group},
                        retcol="prev_state_group",
                        allow_none=True,
                    )

        return results

    @defer.inlineCallbacks
    def get_state_for_events(self, event_ids, types):
        """Given a list of event_ids and type tuples, return a list of state
        dicts for each event. The state dicts will only have the type/state_keys
        that are in the `types` list.

        Args:
            event_ids (list)
            types (list): List of (type, state_key) tuples which are used to
                filter the state fetched. `state_key` may be None, which matches
                any `state_key`

        Returns:
            deferred: A dict of event_id -> (type, state_key) -> state_event,
            one entry for each of the event_ids given.
        """
        event_to_groups = yield self._get_state_group_for_events(
            event_ids,
        )

        groups = set(event_to_groups.itervalues())
        group_to_state = yield self._get_state_for_groups(groups, types)

        state_event_map = yield self.get_events(
            [ev_id for sd in group_to_state.itervalues() for ev_id in sd.itervalues()],
            get_prev_content=False
        )

        event_to_state = {
            event_id: {
                k: state_event_map[v]
                for k, v in group_to_state[group].iteritems()
                if v in state_event_map
            }
            for event_id, group in event_to_groups.iteritems()
        }

        defer.returnValue({event: event_to_state[event] for event in event_ids})

    @defer.inlineCallbacks
    def get_state_ids_for_events(self, event_ids, types=None):
        """
        Get the state dicts corresponding to a list of events

        Args:
            event_ids(list(str)): events whose state should be returned
            types(list[(str, str)]|None): List of (type, state_key) tuples
                which are used to filter the state fetched. May be None, which
                matches any key

        Returns:
            A deferred dict from event_id -> (type, state_key) -> state_event
        """
        event_to_groups = yield self._get_state_group_for_events(
            event_ids,
        )

        groups = set(event_to_groups.itervalues())
        group_to_state = yield self._get_state_for_groups(groups, types)

        event_to_state = {
            event_id: group_to_state[group]
            for event_id, group in event_to_groups.iteritems()
        }

        defer.returnValue({event: event_to_state[event] for event in event_ids})

    @defer.inlineCallbacks
    def get_state_for_event(self, event_id, types=None):
        """
        Get the state dict corresponding to a particular event

        Args:
            event_id(str): event whose state should be returned
            types(list[(str, str)]|None): List of (type, state_key) tuples
                which are used to filter the state fetched. May be None, which
                matches any key

        Returns:
            A deferred dict from (type, state_key) -> state_event
        """
        state_map = yield self.get_state_for_events([event_id], types)
        defer.returnValue(state_map[event_id])

    @defer.inlineCallbacks
    def get_state_ids_for_event(self, event_id, types=None):
        """
        Get the state dict corresponding to a particular event

        Args:
            event_id(str): event whose state should be returned
            types(list[(str, str)]|None): List of (type, state_key) tuples
                which are used to filter the state fetched. May be None, which
                matches any key

        Returns:
            A deferred dict from (type, state_key) -> state_event
        """
        state_map = yield self.get_state_ids_for_events([event_id], types)
        defer.returnValue(state_map[event_id])

    @cached(max_entries=50000)
    def _get_state_group_for_event(self, event_id):
        return self._simple_select_one_onecol(
            table="event_to_state_groups",
            keyvalues={
                "event_id": event_id,
            },
            retcol="state_group",
            allow_none=True,
            desc="_get_state_group_for_event",
        )

    @cachedList(cached_method_name="_get_state_group_for_event",
                list_name="event_ids", num_args=1, inlineCallbacks=True)
    def _get_state_group_for_events(self, event_ids):
        """Returns mapping event_id -> state_group
        """
        rows = yield self._simple_select_many_batch(
            table="event_to_state_groups",
            column="event_id",
            iterable=event_ids,
            keyvalues={},
            retcols=("event_id", "state_group",),
            desc="_get_state_group_for_events",
        )

        defer.returnValue({row["event_id"]: row["state_group"] for row in rows})

    def _get_some_state_from_cache(self, group, types):
        """Checks if group is in cache. See `_get_state_for_groups`

        Returns 3-tuple (`state_dict`, `missing_types`, `got_all`).
        `missing_types` is the list of types that aren't in the cache for that
        group. `got_all` is a bool indicating whether we successfully retrieved
        all of the requested state from the cache; if False we need to query the
        DB for the missing state.

        Args:
            group: The state group to lookup
            types (list): List of 2-tuples of the form (`type`, `state_key`),
                where a `state_key` of `None` matches all state_keys for the
                `type`.
        """
        is_all, known_absent, state_dict_ids = self._state_group_cache.get(group)

        type_to_key = {}
        missing_types = set()

        for typ, state_key in types:
            key = (typ, state_key)
            if state_key is None:
                type_to_key[typ] = None
                missing_types.add(key)
            else:
                if type_to_key.get(typ, object()) is not None:
                    type_to_key.setdefault(typ, set()).add(state_key)

                if key not in state_dict_ids and key not in known_absent:
                    missing_types.add(key)

        sentinel = object()
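        # `sentinel` lets `include` distinguish "this type was never requested"
        # (filter it out) from "this type was requested with a wildcard
        # state_key", which is stored as None in `type_to_key` (keep it).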

        def include(typ, state_key):
            valid_state_keys = type_to_key.get(typ, sentinel)
            if valid_state_keys is sentinel:
                return False
            if valid_state_keys is None:
                return True
            if state_key in valid_state_keys:
                return True
            return False

        got_all = is_all or not missing_types

        return {
            k: v for k, v in state_dict_ids.iteritems()
            if include(k[0], k[1])
        }, missing_types, got_all

    def _get_all_state_from_cache(self, group):
        """Checks if group is in cache. See `_get_state_for_groups`

        Returns 2-tuple (`state_dict`, `got_all`). `got_all` is a bool
        indicating whether we successfully retrieved all of the requested state
        from the cache; if False we need to query the DB for the missing state.

        Args:
            group: The state group to lookup
        """
        is_all, _, state_dict_ids = self._state_group_cache.get(group)

        return state_dict_ids, is_all

    @defer.inlineCallbacks
    def _get_state_for_groups(self, groups, types=None):
        """Given list of groups returns dict of group -> list of state events
        with matching types. `types` is a list of `(type, state_key)`, where
        a `state_key` of None matches all state_keys. If `types` is None then
        all events are returned.
        """
        if types:
            types = frozenset(types)
        results = {}
        missing_groups = []
        if types is not None:
            for group in set(groups):
                state_dict_ids, _, got_all = self._get_some_state_from_cache(
                    group, types
                )
                results[group] = state_dict_ids

                if not got_all:
                    missing_groups.append(group)
        else:
            for group in set(groups):
                state_dict_ids, got_all = self._get_all_state_from_cache(
                    group
                )

                results[group] = state_dict_ids

                if not got_all:
                    missing_groups.append(group)

        if missing_groups:
            # Okay, so some groups are missing some (or all) of the requested
            # state; let's fetch it from the database.
            cache_seq_num = self._state_group_cache.sequence
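            # Snapshotting the sequence before hitting the database means the
            # cache update below can be dropped if the cache is invalidated
            # while we're querying.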

            group_to_state_dict = yield self._get_state_groups_from_groups(
                missing_groups, types
            )

            # Now we want to update the cache with all the things we fetched
            # from the database.
            for group, group_state_dict in group_to_state_dict.iteritems():
                state_dict = results[group]

                state_dict.update(
                    ((intern_string(k[0]), intern_string(k[1])), to_ascii(v))
                    for k, v in group_state_dict.iteritems()
                )

                self._state_group_cache.update(
                    cache_seq_num,
                    key=group,
                    value=state_dict,
                    full=(types is None),
                    known_absent=types,
                )

        defer.returnValue(results)

    def store_state_group(self, event_id, room_id, prev_group, delta_ids,
                          current_state_ids):
        """Store a new set of state, returning a newly assigned state group.

        Args:
            event_id (str): The event ID for which the state was calculated
            room_id (str)
            prev_group (int|None): A previous state group for the room, optional.
            delta_ids (dict|None): The delta between state at `prev_group` and
                `current_state_ids`, if `prev_group` was given. Same format as
                `current_state_ids`.
            current_state_ids (dict): The state to store. Map of (type, state_key)
                to event_id.

        Returns:
            Deferred[int]: The state group ID
        """
        def _store_state_group_txn(txn):
            if current_state_ids is None:
                # AFAIK, this can never happen
                raise Exception("current_state_ids cannot be None")

            state_group = self.database_engine.get_next_state_group_id(txn)

            self._simple_insert_txn(
                txn,
                table="state_groups",
                values={
                    "id": state_group,
                    "room_id": room_id,
                    "event_id": event_id,
                },
            )

            # We persist as a delta if we can, while also ensuring the chain
            # of deltas isn't too long, as otherwise read performance degrades.
            if prev_group:
                is_in_db = self._simple_select_one_onecol_txn(
                    txn,
                    table="state_groups",
                    keyvalues={"id": prev_group},
                    retcol="id",
                    allow_none=True,
                )
                if not is_in_db:
                    raise Exception(
                        "Trying to persist state with unpersisted prev_group: %r"
                        % (prev_group,)
                    )

                potential_hops = self._count_state_group_hops_txn(
                    txn, prev_group
                )
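            # Note: `potential_hops` is only assigned when `prev_group` is set;
            # the short-circuit on `prev_group` below keeps this check safe.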
            if prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
                self._simple_insert_txn(
                    txn,
                    table="state_group_edges",
                    values={
                        "state_group": state_group,
                        "prev_state_group": prev_group,
                    },
                )

                self._simple_insert_many_txn(
                    txn,
                    table="state_groups_state",
                    values=[
                        {
                            "state_group": state_group,
                            "room_id": room_id,
                            "type": key[0],
                            "state_key": key[1],
                            "event_id": state_id,
                        }
                        for key, state_id in delta_ids.iteritems()
                    ],
                )
            else:
                self._simple_insert_many_txn(
                    txn,
                    table="state_groups_state",
                    values=[
                        {
                            "state_group": state_group,
                            "room_id": room_id,
                            "type": key[0],
                            "state_key": key[1],
                            "event_id": state_id,
                        }
                        for key, state_id in current_state_ids.iteritems()
                    ],
                )

            # Prefill the state group cache with this group.
            # It's fine to use the sequence like this as the state group map
            # is immutable. (If the map wasn't immutable then this prefill could
            # race with another update)
            txn.call_after(
                self._state_group_cache.update,
                self._state_group_cache.sequence,
                key=state_group,
                value=dict(current_state_ids),
                full=True,
            )

            return state_group

        return self.runInteraction("store_state_group", _store_state_group_txn)

    def _count_state_group_hops_txn(self, txn, state_group):
        """Given a state group, count how many hops there are in the tree.

        This is used to ensure the delta chains don't get too long.
        """
        if isinstance(self.database_engine, PostgresEngine):
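            # On postgres a single recursive CTE can walk the whole chain: the
            # VALUES row seeds the recursion with the starting group and each
            # step follows a prev_state_group edge, so counting the rows gives
            # the chain length.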
            sql = ("""
                WITH RECURSIVE state(state_group) AS (
                    VALUES(?::bigint)
                    UNION ALL
                    SELECT prev_state_group FROM state_group_edges e, state s
                    WHERE s.state_group = e.state_group
                )
                SELECT count(*) FROM state;
            """)

            txn.execute(sql, (state_group,))
            row = txn.fetchone()
            if row and row[0]:
                return row[0]
            else:
                return 0
        else:
            # We don't use WITH RECURSIVE on sqlite3 as there are distributions
            # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
            next_group = state_group
            count = 0

            while next_group:
                next_group = self._simple_select_one_onecol_txn(
                    txn,
                    table="state_group_edges",
                    keyvalues={"state_group": next_group},
                    retcol="prev_state_group",
                    allow_none=True,
                )
                if next_group:
                    count += 1

            return count
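
# A minimal, self-contained sketch (illustrative only) of the delta-chain walk
# that _count_state_group_hops_txn performs on sqlite, with a plain dict
# standing in for the state_group_edges table.
def count_hops(edges, state_group):
    """Count how many prev_state_group hops hang off `state_group`."""
    count = 0
    next_group = edges.get(state_group)
    while next_group is not None:
        count += 1
        next_group = edges.get(next_group)
    return count


# e.g. a chain 4 -> 3 -> 2 -> 1 has three hops from group 4.
assert count_hops({4: 3, 3: 2, 2: 1}, 4) == 3
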
Example #16
0
class StateGroupReadStore(SQLBaseStore):
    """The read-only parts of StateGroupStore

    None of these functions write to the state tables, so are suitable for
    including in the SlavedStores.
    """

    STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
    STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
    CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"

    def __init__(self, db_conn, hs):
        super(StateGroupReadStore, self).__init__(db_conn, hs)

        self._state_group_cache = DictionaryCache("*stateGroupCache*",
                                                  100000 * CACHE_SIZE_FACTOR)

    @cached(max_entries=100000, iterable=True)
    def get_current_state_ids(self, room_id):
        """Get the current state event ids for a room based on the
        current_state_events table.

        Args:
            room_id (str)

        Returns:
            deferred: dict of (type, state_key) -> event_id
        """
        def _get_current_state_ids_txn(txn):
            txn.execute(
                """SELECT type, state_key, event_id FROM current_state_events
                WHERE room_id = ?
                """, (room_id, ))

            return {(intern_string(r[0]), intern_string(r[1])): to_ascii(r[2])
                    for r in txn}

        return self.runInteraction(
            "get_current_state_ids",
            _get_current_state_ids_txn,
        )

    @cached(max_entries=10000, iterable=True)
    def get_state_group_delta(self, state_group):
        """Given a state group try to return a previous group and a delta between
        the old and the new.

        Returns:
            (prev_group, delta_ids), where both may be None.
        """
        def _get_state_group_delta_txn(txn):
            prev_group = self._simple_select_one_onecol_txn(
                txn,
                table="state_group_edges",
                keyvalues={
                    "state_group": state_group,
                },
                retcol="prev_state_group",
                allow_none=True,
            )

            if not prev_group:
                return _GetStateGroupDelta(None, None)

            delta_ids = self._simple_select_list_txn(
                txn,
                table="state_groups_state",
                keyvalues={
                    "state_group": state_group,
                },
                retcols=(
                    "type",
                    "state_key",
                    "event_id",
                ))

            return _GetStateGroupDelta(
                prev_group, {(row["type"], row["state_key"]): row["event_id"]
                             for row in delta_ids})

        return self.runInteraction(
            "get_state_group_delta",
            _get_state_group_delta_txn,
        )

    @defer.inlineCallbacks
    def get_state_groups_ids(self, room_id, event_ids):
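        """Get the state dicts for the state groups of the given events.

        Args:
            room_id (str)
            event_ids (list(str))

        Returns:
            A deferred dict from state_group -> (type, state_key) -> event_id
        """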
        if not event_ids:
            defer.returnValue({})

        event_to_groups = yield self._get_state_group_for_events(event_ids, )

        groups = set(event_to_groups.itervalues())
        group_to_state = yield self._get_state_for_groups(groups)

        defer.returnValue(group_to_state)

    @defer.inlineCallbacks
    def get_state_groups(self, room_id, event_ids):
        """ Get the state groups for the given list of event_ids

        The return value is a dict mapping state group IDs to lists of events.
        """
        if not event_ids:
            defer.returnValue({})

        group_to_ids = yield self.get_state_groups_ids(room_id, event_ids)

        state_event_map = yield self.get_events(
            [
                ev_id for group_ids in group_to_ids.itervalues()
                for ev_id in group_ids.itervalues()
            ],
            get_prev_content=False,
        )

        defer.returnValue({
            group: [
                state_event_map[v] for v in event_id_map.itervalues()
                if v in state_event_map
            ]
            for group, event_id_map in group_to_ids.iteritems()
        })

    @defer.inlineCallbacks
    def _get_state_groups_from_groups(self, groups, types):
        """Returns dictionary state_group -> (dict of (type, state_key) -> event id)
        """
        results = {}

        chunks = [groups[i:i + 100] for i in xrange(0, len(groups), 100)]
        for chunk in chunks:
            res = yield self.runInteraction(
                "_get_state_groups_from_groups",
                self._get_state_groups_from_groups_txn,
                chunk,
                types,
            )
            results.update(res)

        defer.returnValue(results)

    def _get_state_groups_from_groups_txn(self, txn, groups, types=None):
        results = {group: {} for group in groups}
        if types is not None:
            types = list(set(types))  # deduplicate types list

        if isinstance(self.database_engine, PostgresEngine):
            # Temporarily disable sequential scans in this transaction. This is
            # a temporary hack until we can add the right indices in
            txn.execute("SET LOCAL enable_seqscan=off")

            # The below query walks the state_group tree so that the "state"
            # table includes all state_groups in the tree. It then joins
            # against `state_groups_state` to fetch the latest state.
            # It assumes that previous state groups always have numerically
            # smaller IDs.
            # The PARTITION is used to get the event_id in the greatest state
            # group for the given type, state_key.
            # This may return multiple rows per (type, state_key), but last_value
            # should be the same.
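            # e.g. if group 12 is built on top of group 10 and both set
            # ("m.room.name", ""), every row returned for that key carries the
            # event_id from group 12, the greatest state_group in the chain.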
            sql = ("""
                WITH RECURSIVE state(state_group) AS (
                    VALUES(?::bigint)
                    UNION ALL
                    SELECT prev_state_group FROM state_group_edges e, state s
                    WHERE s.state_group = e.state_group
                )
                SELECT type, state_key, last_value(event_id) OVER (
                    PARTITION BY type, state_key ORDER BY state_group ASC
                    ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
                ) AS event_id FROM state_groups_state
                WHERE state_group IN (
                    SELECT state_group FROM state
                )
                %s
            """)

            # Turns out that postgres doesn't like doing a list of ORs and
            # is about 1000x slower, so we just issue a query for each specific
            # type separately.
            if types:
                clause_to_args = [("AND type = ? AND state_key = ?",
                                   (etype, state_key))
                                  for etype, state_key in types]
            else:
                # If types is None we fetch all the state, and so just use an
                # empty where clause with no extra args.
                clause_to_args = [("", [])]

            for where_clause, where_args in clause_to_args:
                for group in groups:
                    args = [group]
                    args.extend(where_args)

                    txn.execute(sql % (where_clause, ), args)
                    for row in txn:
                        typ, state_key, event_id = row
                        key = (typ, state_key)
                        results[group][key] = event_id
        else:
            if types is not None:
                where_clause = "AND (%s)" % (" OR ".join(
                    ["(type = ? AND state_key = ?)"] * len(types)), )
            else:
                where_clause = ""

            # We don't use WITH RECURSIVE on sqlite3 as there are distributions
            # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
            for group in groups:
                next_group = group

                while next_group:
                    # We did this before by getting the list of group ids, and
                    # then passing that list to sqlite to get latest event for
                    # each (type, state_key). However, that was terribly slow
                    # without the right indices (which we can't add until
                    # after we finish deduping state, which requires this func)
                    args = [next_group]
                    if types:
                        args.extend(i for typ in types for i in typ)

                    txn.execute(
                        "SELECT type, state_key, event_id FROM state_groups_state"
                        " WHERE state_group = ? %s" % (where_clause, ), args)
                    results[group].update(
                        ((typ, state_key), event_id)
                        for typ, state_key, event_id in txn
                        if (typ, state_key) not in results[group])

                    # If the lengths match then we must have all the types,
                    # so there is no need to walk further down the tree.
                    if types is not None and len(results[group]) == len(types):
                        break

                    next_group = self._simple_select_one_onecol_txn(
                        txn,
                        table="state_group_edges",
                        keyvalues={"state_group": next_group},
                        retcol="prev_state_group",
                        allow_none=True,
                    )

        return results

    @defer.inlineCallbacks
    def get_state_for_events(self, event_ids, types):
        """Given a list of event_ids and type tuples, return a list of state
        dicts for each event. The state dicts will only have the type/state_keys
        that are in the `types` list.

        Args:
            event_ids (list)
            types (list): List of (type, state_key) tuples which are used to
                filter the state fetched. `state_key` may be None, which matches
                any `state_key`

        Returns:
            deferred: A list of dicts corresponding to the event_ids given.
            The dicts are mappings from (type, state_key) -> state_events
        """
        event_to_groups = yield self._get_state_group_for_events(event_ids, )

        groups = set(event_to_groups.itervalues())
        group_to_state = yield self._get_state_for_groups(groups, types)

        state_event_map = yield self.get_events(
            [
                ev_id for sd in group_to_state.itervalues()
                for ev_id in sd.itervalues()
            ],
            get_prev_content=False,
        )

        event_to_state = {
            event_id: {
                k: state_event_map[v]
                for k, v in group_to_state[group].iteritems()
                if v in state_event_map
            }
            for event_id, group in event_to_groups.iteritems()
        }

        defer.returnValue(
            {event: event_to_state[event]
             for event in event_ids})

    @defer.inlineCallbacks
    def get_state_ids_for_events(self, event_ids, types=None):
        """
        Get the state dicts corresponding to a list of events

        Args:
            event_ids(list(str)): events whose state should be returned
            types(list[(str, str)]|None): List of (type, state_key) tuples
                which are used to filter the state fetched. May be None, which
                matches any key

        Returns:
            A deferred dict from event_id -> (type, state_key) -> state_event
        """
        event_to_groups = yield self._get_state_group_for_events(event_ids, )

        groups = set(event_to_groups.itervalues())
        group_to_state = yield self._get_state_for_groups(groups, types)

        event_to_state = {
            event_id: group_to_state[group]
            for event_id, group in event_to_groups.iteritems()
        }

        defer.returnValue(
            {event: event_to_state[event]
             for event in event_ids})

    @defer.inlineCallbacks
    def get_state_for_event(self, event_id, types=None):
        """
        Get the state dict corresponding to a particular event

        Args:
            event_id(str): event whose state should be returned
            types(list[(str, str)]|None): List of (type, state_key) tuples
                which are used to filter the state fetched. May be None, which
                matches any key

        Returns:
            A deferred dict from (type, state_key) -> state_event
        """
        state_map = yield self.get_state_for_events([event_id], types)
        defer.returnValue(state_map[event_id])

    @defer.inlineCallbacks
    def get_state_ids_for_event(self, event_id, types=None):
        """
        Get the state dict corresponding to a particular event

        Args:
            event_id(str): event whose state should be returned
            types(list[(str, str)]|None): List of (type, state_key) tuples
                which are used to filter the state fetched. May be None, which
                matches any key

        Returns:
            A deferred dict from (type, state_key) -> state_event
        """
        state_map = yield self.get_state_ids_for_events([event_id], types)
        defer.returnValue(state_map[event_id])

    @cached(max_entries=50000)
    def _get_state_group_for_event(self, event_id):
        return self._simple_select_one_onecol(
            table="event_to_state_groups",
            keyvalues={
                "event_id": event_id,
            },
            retcol="state_group",
            allow_none=True,
            desc="_get_state_group_for_event",
        )

    @cachedList(cached_method_name="_get_state_group_for_event",
                list_name="event_ids",
                num_args=1,
                inlineCallbacks=True)
    def _get_state_group_for_events(self, event_ids):
        """Returns mapping event_id -> state_group
        """
        rows = yield self._simple_select_many_batch(
            table="event_to_state_groups",
            column="event_id",
            iterable=event_ids,
            keyvalues={},
            retcols=(
                "event_id",
                "state_group",
            ),
            desc="_get_state_group_for_events",
        )

        defer.returnValue(
            {row["event_id"]: row["state_group"]
             for row in rows})

    def _get_some_state_from_cache(self, group, types):
        """Checks if group is in cache. See `_get_state_for_groups`

        Returns 3-tuple (`state_dict`, `missing_types`, `got_all`).
        `missing_types` is the list of types that aren't in the cache for that
        group. `got_all` is a bool indicating whether we successfully retrieved
        all of the requested state from the cache; if False we need to query the
        DB for the missing state.

        Args:
            group: The state group to lookup
            types (list): List of 2-tuples of the form (`type`, `state_key`),
                where a `state_key` of `None` matches all state_keys for the
                `type`.
        """
        is_all, known_absent, state_dict_ids = self._state_group_cache.get(
            group)

        type_to_key = {}
        missing_types = set()

        for typ, state_key in types:
            key = (typ, state_key)
            if state_key is None:
                type_to_key[typ] = None
                missing_types.add(key)
            else:
                if type_to_key.get(typ, object()) is not None:
                    type_to_key.setdefault(typ, set()).add(state_key)

                if key not in state_dict_ids and key not in known_absent:
                    missing_types.add(key)

        sentinel = object()

        def include(typ, state_key):
            valid_state_keys = type_to_key.get(typ, sentinel)
            if valid_state_keys is sentinel:
                return False
            if valid_state_keys is None:
                return True
            if state_key in valid_state_keys:
                return True
            return False

        got_all = is_all or not missing_types

        return {
            k: v
            for k, v in state_dict_ids.iteritems() if include(k[0], k[1])
        }, missing_types, got_all

    def _get_all_state_from_cache(self, group):
        """Checks if group is in cache. See `_get_state_for_groups`

        Returns 2-tuple (`state_dict`, `got_all`). `got_all` is a bool
        indicating whether we successfully retrieved all of the requested state
        from the cache; if False we need to query the DB for the missing state.

        Args:
            group: The state group to lookup
        """
        is_all, _, state_dict_ids = self._state_group_cache.get(group)

        return state_dict_ids, is_all

    @defer.inlineCallbacks
    def _get_state_for_groups(self, groups, types=None):
        """Given list of groups returns dict of group -> list of state events
        with matching types. `types` is a list of `(type, state_key)`, where
        a `state_key` of None matches all state_keys. If `types` is None then
        all events are returned.
        """
        if types:
            types = frozenset(types)
        results = {}
        missing_groups = []
        if types is not None:
            for group in set(groups):
                state_dict_ids, _, got_all = self._get_some_state_from_cache(
                    group, types)
                results[group] = state_dict_ids

                if not got_all:
                    missing_groups.append(group)
        else:
            for group in set(groups):
                state_dict_ids, got_all = self._get_all_state_from_cache(group)

                results[group] = state_dict_ids

                if not got_all:
                    missing_groups.append(group)

        if missing_groups:
            # Okay, so some groups are missing some (or all) of the requested
            # state; let's fetch it from the database.
            cache_seq_num = self._state_group_cache.sequence

            group_to_state_dict = yield self._get_state_groups_from_groups(
                missing_groups, types)

            # Now we want to update the cache with all the things we fetched
            # from the database.
            for group, group_state_dict in group_to_state_dict.iteritems():
                state_dict = results[group]

                state_dict.update(
                    ((intern_string(k[0]), intern_string(k[1])), to_ascii(v))
                    for k, v in group_state_dict.iteritems())

                self._state_group_cache.update(
                    cache_seq_num,
                    key=group,
                    value=state_dict,
                    full=(types is None),
                    known_absent=types,
                )

        defer.returnValue(results)
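
# A minimal, self-contained sketch (illustrative only) of the cache/DB merge
# pattern used by _get_state_for_groups above: serve what we can from a
# per-group cache, fetch only the missing groups from the backing store, then
# write the fetched entries back into the cache. Plain dicts stand in for
# DictionaryCache and the state_groups_state table.
def get_state_for_groups(cache, db, groups):
    results = {}
    missing = []
    for group in set(groups):
        if group in cache:
            results[group] = dict(cache[group])
        else:
            missing.append(group)

    for group in missing:
        state_dict = dict(db.get(group, {}))
        cache[group] = state_dict  # write back so the next call is a cache hit
        results[group] = state_dict

    return results


# e.g. group 1 is already cached, group 2 has to come from the "database".
cache = {1: {("m.room.name", ""): "$event_a"}}
db = {2: {("m.room.topic", ""): "$event_b"}}
assert get_state_for_groups(cache, db, [1, 2]) == {
    1: {("m.room.name", ""): "$event_a"},
    2: {("m.room.topic", ""): "$event_b"},
}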