Exemple #1
0
    def _invalidate_state_caches_and_stream(self, txn, room_id,
                                            members_changed):
        """Invalidate the current-state caches for a room and tell replication.

        Local invalidation is scheduled with ``txn.call_after``. The
        replication pokes are split into batches so that each one stays small
        enough to send over replication.

        Args:
            txn: the transaction in which the state change is being written.
            room_id (str): the room whose current state changed.
            members_changed (iterable[str]): user IDs of the members whose
                membership changed; may be empty.
        """
        txn.call_after(self._invalidate_state_caches, room_id, members_changed)

        if not members_changed:
            # No membership changes, but the room-wide caches on other
            # workers must still be invalidated.
            self._send_invalidation_to_replication(
                txn, CURRENT_STATE_CACHE_NAME, [room_id])
            return

        # Replication lines are capped at 16K and a user ID is at most 255
        # characters, so 50 user IDs (plus the room ID) per poke is safe.
        for member_batch in batch_iter(members_changed, 50):
            self._send_invalidation_to_replication(
                txn,
                CURRENT_STATE_CACHE_NAME,
                itertools.chain([room_id], member_batch),
            )
Exemple #2
0
    def _get_auth_chain_ids_txn(self, txn: LoggingTransaction,
                                event_ids: Collection[str],
                                include_given: bool) -> List[str]:
        """Walk the ``event_auth`` table breadth-first to collect an auth chain.

        Args:
            txn: transaction to run the queries in.
            event_ids: the events whose auth chains are wanted.
            include_given: whether ``event_ids`` themselves belong in the
                result.

        Returns:
            The IDs of every event reachable through ``event_auth`` from
            ``event_ids``, in no particular order.
        """
        results = set(event_ids) if include_given else set()

        base_sql = "SELECT DISTINCT auth_id FROM event_auth WHERE "

        front = set(event_ids)
        while front:
            discovered = set()
            # Query in chunks so that the IN (...) list stays bounded.
            for chunk in batch_iter(front, 100):
                clause, args = make_in_list_sql_clause(
                    txn.database_engine, "event_id", chunk)
                txn.execute(base_sql + clause, args)
                discovered.update(row[0] for row in txn)

            # Only continue the walk from events not already collected.
            front = discovered - results
            results |= front

        return list(results)
    def _get_auth_chain_ids_txn(self, txn, event_ids, include_given,
                                ignore_events=None):
        """Collect the auth chain of ``event_ids`` by walking ``event_auth``.

        Args:
            txn: transaction to run the queries in.
            event_ids (Iterable[str]): the events to start from.
            include_given (bool): whether ``event_ids`` themselves are included
                in the result.
            ignore_events (Iterable[str]|None): auth event IDs that are
                neither added to the result nor followed further. Optional;
                ``None`` (the default) means ignore nothing.

        Returns:
            list[str]: the collected event IDs, in no particular order.
        """
        # Normalise to a set so `-=` below works for any iterable the caller
        # passes (a plain list would otherwise raise TypeError).
        ignore_events = set(ignore_events) if ignore_events is not None else set()

        results = set(event_ids) if include_given else set()

        base_sql = "SELECT auth_id FROM event_auth WHERE "

        front = set(event_ids)
        while front:
            new_front = set()
            # Query in chunks so that the IN (...) list stays bounded.
            for chunk in batch_iter(front, 100):
                clause, args = make_in_list_sql_clause(txn.database_engine,
                                                       "event_id", chunk)
                txn.execute(base_sql + clause, args)
                new_front.update(r[0] for r in txn)

            new_front -= ignore_events
            new_front -= results

            front = new_front
            results.update(front)

        return list(results)
Exemple #4
0
    def _update_presence_txn(self, txn, stream_orderings, presence_states):
        """Persist a batch of presence updates into ``presence_stream``.

        ``stream_orderings`` and ``presence_states`` are paired positionally
        (via ``zip``). Older ``presence_stream`` rows for the affected users
        are deleted first, and then the new rows are inserted.

        Args:
            txn: the database transaction to run in.
            stream_orderings (list[int]): stream IDs for the new rows.
            presence_states (list): the new presence states. Each one must
                expose ``user_id``, ``state``, ``last_active_ts``,
                ``last_federation_update_ts``, ``last_user_sync_ts``,
                ``status_msg`` and ``currently_active``.
        """
        last_stream_id = None
        for stream_id, state in zip(stream_orderings, presence_states):
            txn.call_after(
                self.presence_stream_cache.entity_has_changed, state.user_id, stream_id
            )
            txn.call_after(self._get_presence_for_user.invalidate, (state.user_id,))
            last_stream_id = stream_id

        # Delete old rows to stop database from getting really big.
        #
        # This must happen *before* the new rows are inserted. The cut-off is
        # the batch's last stream ID, so deleting afterwards would also remove
        # the rows just inserted for every other user in the batch whose
        # stream ID is lower. The cut-off is bound explicitly instead of
        # relying on the loop variable leaking out of the loop above.
        if last_stream_id is not None:
            sql = "DELETE FROM presence_stream WHERE stream_id < ? AND "

            for states in batch_iter(presence_states, 50):
                clause, args = make_in_list_sql_clause(
                    self.database_engine, "user_id", [s.user_id for s in states]
                )
                txn.execute(sql + clause, [last_stream_id] + list(args))

        # Actually insert new rows
        self.db_pool.simple_insert_many_txn(
            txn,
            table="presence_stream",
            values=[
                {
                    "stream_id": stream_id,
                    "user_id": state.user_id,
                    "state": state.state,
                    "last_active_ts": state.last_active_ts,
                    "last_federation_update_ts": state.last_federation_update_ts,
                    "last_user_sync_ts": state.last_user_sync_ts,
                    "status_msg": state.status_msg,
                    "currently_active": state.currently_active,
                }
                for stream_id, state in zip(stream_orderings, presence_states)
            ],
        )
Exemple #5
0
    def _get_auth_chain_ids_txn(self, txn: LoggingTransaction,
                                event_ids: Collection[str],
                                include_given: bool) -> List[str]:
        """Calculates the auth chain IDs.

        This is used when we don't have a cover index for the room. The chain
        is walked breadth-first. Each event's direct auth events are taken
        from ``_event_auth_cache`` when present; otherwise they are read from
        the database and cached as ``(auth_event_id, depth)`` pairs.

        Args:
            txn: transaction to run the queries in.
            event_ids: the events whose auth chains are wanted.
            include_given: whether ``event_ids`` themselves are in the result.

        Returns:
            The collected event IDs, in no particular order.
        """
        results = set(event_ids) if include_given else set()

        # `depth` is selected only so that `_event_auth_cache` entries can be
        # populated with it.
        base_sql = """
            SELECT a.event_id, auth_id, depth
            FROM event_auth AS a
            INNER JOIN events AS e ON (e.event_id = a.auth_id)
            WHERE
        """

        front = set(event_ids)
        while front:
            discovered = set()
            for chunk in batch_iter(front, 100):
                # Event IDs with no cache entry, which must be fetched from
                # the DB.
                uncached = []  # type: List[str]
                for event_id in chunk:
                    cached = self._event_auth_cache.get(event_id)
                    if cached is None:
                        uncached.append(event_id)
                    else:
                        discovered.update(auth_id for auth_id, _ in cached)

                if not uncached:
                    continue

                clause, args = make_in_list_sql_clause(
                    txn.database_engine, "a.event_id", uncached)
                txn.execute(base_sql + clause, args)

                # Group the rows per event before caching, so that each cache
                # entry holds that event's complete list of auth events.
                fetched = {}
                for event_id, auth_event_id, auth_event_depth in txn:
                    fetched.setdefault(event_id, []).append(
                        (auth_event_id, auth_event_depth))
                    discovered.add(auth_event_id)

                for event_id, auth_events in fetched.items():
                    self._event_auth_cache.set(event_id, auth_events)

            front = discovered - results
            results |= front

        return list(results)
Exemple #6
0
        def _get_users_whose_devices_changed_txn(txn):
            # Returns the subset of `to_check` whose device lists have a row
            # in device_lists_stream with a stream_id greater than `from_key`.
            sql = """
                SELECT DISTINCT user_id FROM device_lists_stream
                WHERE stream_id > ?
                AND
            """

            changed_users = set()
            # Query in chunks so that the IN (...) list stays bounded.
            for user_batch in batch_iter(to_check, 100):
                clause, clause_args = make_in_list_sql_clause(
                    txn.database_engine, "user_id", user_batch)
                txn.execute(sql + clause, (from_key,) + tuple(clause_args))
                changed_users.update(row[0] for row in txn)

            return changed_users
Exemple #7
0
    def _get_bare_e2e_cross_signing_keys_bulk_txn(
        self, txn: Connection, user_ids: List[str],
    ) -> Dict[str, Dict[str, dict]]:
        """Returns the cross-signing keys for a set of users.  The output of this
        function should be passed to _get_e2e_cross_signing_signatures_txn if
        the signatures for the calling user need to be fetched.

        For each user and key type, only the row(s) with the highest
        ``stream_id`` are returned.

        Args:
            txn (twisted.enterprise.adbapi.Connection): db connection
            user_ids (list[str]): the users whose keys are being requested

        Returns:
            dict[str, dict[str, dict]]: mapping from user ID to key type to key
                data.  If a user's cross-signing keys were not found, their user
                ID will not be in the dict.

        """
        result = {}

        for user_chunk in batch_iter(user_ids, 100):
            # Filter inside the aggregating subquery so that MAX(stream_id) is
            # computed only for this chunk's users. Otherwise the whole table
            # is grouped and most of the groups are discarded by the join.
            # The join on user_id restricts the outer `k` rows to the same
            # users.
            clause, params = make_in_list_sql_clause(
                txn.database_engine, "user_id", user_chunk
            )
            sql = (
                """
                SELECT k.user_id, k.keytype, k.keydata, k.stream_id
                  FROM e2e_cross_signing_keys k
                  INNER JOIN (SELECT user_id, keytype, MAX(stream_id) AS stream_id
                                FROM e2e_cross_signing_keys
                               WHERE """
                + clause
                + """
                               GROUP BY user_id, keytype) s
                 USING (user_id, stream_id, keytype)
            """
            )

            txn.execute(sql, params)
            rows = self.db.cursor_to_dict(txn)

            for row in rows:
                user_id = row["user_id"]
                key_type = row["keytype"]
                key = json.loads(row["keydata"])
                user_info = result.setdefault(user_id, {})
                user_info[key_type] = key

        return result
Exemple #8
0
    async def get_metadata_for_events(
            self, event_ids: Collection[str]) -> Dict[str, EventMetadata]:
        """Look up room_id, type, state_key and rejection reason for events.

        This is cheaper than loading the full events and should be preferred
        when only this metadata is needed. Rejected and redacted events are
        included; events that are not persisted are simply absent from the
        returned mapping. Lookups are issued in batches of 1000 event IDs,
        each in its own database interaction.
        """
        def get_metadata_for_events_txn(
            txn: LoggingTransaction,
            batch_ids: Collection[str],
        ) -> Dict[str, EventMetadata]:
            clause, args = make_in_list_sql_clause(
                self.database_engine, "e.event_id", batch_ids)

            sql = f"""
                SELECT e.event_id, e.room_id, e.type, se.state_key, r.reason
                FROM events AS e
                LEFT JOIN state_events se USING (event_id)
                LEFT JOIN rejections r USING (event_id)
                WHERE {clause}
            """
            txn.execute(sql, args)

            batch_metadata: Dict[str, EventMetadata] = {}
            for event_id, room_id, event_type, state_key, rejection_reason in txn:
                batch_metadata[event_id] = EventMetadata(
                    room_id=room_id,
                    event_type=event_type,
                    state_key=state_key,
                    rejection_reason=rejection_reason,
                )
            return batch_metadata

        result_map: Dict[str, EventMetadata] = {}
        for batch_ids in batch_iter(event_ids, 1000):
            batch_metadata = await self.db_pool.runInteraction(
                "get_metadata_for_events",
                get_metadata_for_events_txn,
                batch_ids=batch_ids,
            )
            result_map.update(batch_metadata)

        return result_map
Exemple #9
0
    def _get_e2e_cross_signing_signatures_txn(
        self,
        txn: Connection,
        keys: Dict[str, Dict[str, dict]],
        from_user_id: str,
    ) -> Dict[str, Dict[str, dict]]:
        """Returns the cross-signing signatures made by a user on a set of keys.

        Args:
            txn (twisted.enterprise.adbapi.Connection): db connection
            keys (dict[str, dict[str, dict]]): a map of user ID to key type to
                key data.  This dict will be modified to add signatures; the
                nested dicts that are changed are replaced with copies, so
                the original nested objects are never mutated.
            from_user_id (str): fetch the signatures made by this user

        Returns:
            dict[str, dict[str, dict]]: mapping from user ID to key type to key
                data.  The return value will be the same as the keys argument,
                with the modifications included.
        """

        # find out what cross-signing keys (a.k.a. devices) we need to get
        # signatures for.  This is a map of (user_id, device_id) to key type
        # (device_id is the key's public part).
        devices = {}

        for user_id, user_info in keys.items():
            if user_info is None:
                continue
            for key_type, key in user_info.items():
                device_id = None
                for k in key["keys"].values():
                    device_id = k
                devices[(user_id, device_id)] = key_type

        for batch in batch_iter(devices.keys(), size=100):
            sql = """
                SELECT target_user_id, target_device_id, key_id, signature
                  FROM e2e_cross_signing_signatures
                 WHERE user_id = ?
                   AND (%s)
            """ % (" OR ".join("(target_user_id = ? AND target_device_id = ?)"
                               for _ in batch))
            query_params = [from_user_id]
            for item in batch:
                # item is a (user_id, device_id) tuple
                query_params.extend(item)

            txn.execute(sql, query_params)
            rows = self.db_pool.cursor_to_dict(txn)

            # and add the signatures to the appropriate keys
            for row in rows:
                key_id = row["key_id"]
                target_user_id = row["target_user_id"]
                target_device_id = row["target_device_id"]
                key_type = devices[(target_user_id, target_device_id)]
                # We need to copy everything, because the result may have come
                # from the cache.  dict.copy only does a shallow copy, so we
                # need to recursively copy every dict on the path to the one
                # being modified.
                user_info = keys[target_user_id] = keys[target_user_id].copy()
                target_user_key = user_info[key_type] = user_info[
                    key_type].copy()
                signatures = target_user_key["signatures"] = dict(
                    target_user_key.get("signatures", {}))
                # The per-signer map must be copied too: previously it was
                # re-assigned to itself, so the new signature was written into
                # a dict that could be shared with the cache.
                user_sigs = signatures[from_user_id] = dict(
                    signatures.get(from_user_id, {}))
                user_sigs[key_id] = row["signature"]

        return keys
Exemple #10
0
 def _txn(txn):
     """Run `_get_keys` over `server_name_and_key_ids` in batches of 50.

     Returns `keys`, which is presumably filled in by `_get_keys` (TODO:
     confirm against the enclosing function).
     """
     for key_batch in batch_iter(server_name_and_key_ids, 50):
         _get_keys(txn, key_batch)
     return keys
Exemple #11
0
 def _txn(txn: Cursor) -> Dict[Tuple[str, str], FetchKeyResult]:
     """Run `_get_keys` over `server_name_and_key_ids` in batches of 50.

     Returns `keys`, which is presumably filled in by `_get_keys` (TODO:
     confirm against the enclosing function).
     """
     for key_batch in batch_iter(server_name_and_key_ids, 50):
         _get_keys(txn, key_batch)
     return keys
    async def get_e2e_device_keys_and_signatures(
        self,
        query_list: List[Tuple[str, Optional[str]]],
        include_all_devices: bool = False,
        include_deleted_devices: bool = False,
    ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
        """Fetch device keys, together with the owner's cross-signatures on them.

        Cross-signatures made on a device's keys by the device's owner are
        merged into the ``signatures`` field of that device's ``keys`` object.

        Args:
            query_list: (user_id, device_id) pairs to look up. A device_id of
                None means "all devices for this user".

            include_all_devices: whether to return devices that have no
                device keys.

            include_deleted_devices: whether to include null entries for
                devices which no longer exist (but were in the query_list).
                This option only takes effect if include_all_devices is true.

        Returns:
            Dict mapping from user-id to dict mapping from device_id to
            key data.
        """
        set_tag("include_all_devices", include_all_devices)
        set_tag("include_deleted_devices", include_deleted_devices)

        result = await self.db_pool.runInteraction(
            "get_e2e_device_keys",
            self._get_e2e_device_keys_txn,
            query_list,
            include_all_devices,
            include_deleted_devices,
        )

        def _devices_with_keys():
            # Lazily yield (user_id, device_id) for each returned device that
            # has device keys; only those can carry cross-signatures.
            for owner_id, owner_devices in result.items():
                for dev_id, lookup in owner_devices.items():
                    if lookup is not None and lookup.keys is not None:
                        yield owner_id, dev_id

        for batch in batch_iter(_devices_with_keys(), 50):
            cross_sigs_result = await self.db_pool.runInteraction(
                "get_e2e_cross_signing_signatures",
                self._get_e2e_cross_signing_signatures_for_devices_txn,
                batch,
            )

            # Merge each cross-signature into the matching device's keys.
            for user_id, key_id, device_id, signature in cross_sigs_result:
                device_keys = result[user_id][device_id].keys
                all_signatures = device_keys.setdefault("signatures", {})
                all_signatures.setdefault(user_id, {})[key_id] = signature

        log_kv(result)
        return result
    def _get_e2e_cross_signing_signatures_txn(
        self,
        txn: LoggingTransaction,
        keys: Dict[str, Optional[Dict[str, JsonDict]]],
        from_user_id: str,
    ) -> Dict[str, Optional[Dict[str, JsonDict]]]:
        """Returns the cross-signing signatures made by a user on a set of keys.

        Args:
            txn: db connection
            keys: a map of user ID to key type to key data.
                This dict will be modified to add signatures; the nested
                dicts that are changed are replaced with copies, so the
                original nested objects are never mutated.
            from_user_id: fetch the signatures made by this user

        Returns:
            Mapping from user ID to key type to key data.
            The return value will be the same as the keys argument, with the
            modifications included.
        """

        # find out what cross-signing keys (a.k.a. devices) we need to get
        # signatures for.  This is a map of (user_id, device_id) to key type
        # (device_id is the key's public part).
        devices: Dict[Tuple[str, str], str] = {}

        for user_id, user_keys in keys.items():
            if user_keys is None:
                continue
            for key_type, key in user_keys.items():
                device_id = None
                for k in key["keys"].values():
                    device_id = k
                # `key` ought to be a `CrossSigningKey`, whose .keys property is a
                # dictionary with a single entry:
                #     "algorithm:base64_public_key": "base64_public_key"
                # See https://spec.matrix.org/v1.1/client-server-api/#cross-signing
                assert isinstance(device_id, str)
                devices[(user_id, device_id)] = key_type

        for batch in batch_iter(devices.keys(), size=100):
            sql = """
                SELECT target_user_id, target_device_id, key_id, signature
                  FROM e2e_cross_signing_signatures
                 WHERE user_id = ?
                   AND (%s)
            """ % (" OR ".join("(target_user_id = ? AND target_device_id = ?)"
                               for _ in batch))
            query_params = [from_user_id]
            for item in batch:
                # item is a (user_id, device_id) tuple
                query_params.extend(item)

            txn.execute(sql, query_params)
            rows = self.db_pool.cursor_to_dict(txn)

            # and add the signatures to the appropriate keys
            for row in rows:
                key_id: str = row["key_id"]
                target_user_id: str = row["target_user_id"]
                target_device_id: str = row["target_device_id"]
                key_type = devices[(target_user_id, target_device_id)]
                # We need to copy everything, because the result may have come
                # from the cache.  dict.copy only does a shallow copy, so we
                # need to recursively copy every dict on the path to the one
                # being modified.
                user_keys = keys[target_user_id]
                # `user_keys` cannot be `None` because we only fetched signatures for
                # users with keys
                assert user_keys is not None
                user_keys = keys[target_user_id] = user_keys.copy()

                target_user_key = user_keys[key_type] = user_keys[
                    key_type].copy()
                signatures = target_user_key["signatures"] = dict(
                    target_user_key.get("signatures", {}))
                # The per-signer map must be copied too: previously it was
                # re-assigned to itself, so the new signature was written into
                # a dict that could be shared with the cache.
                user_sigs = signatures[from_user_id] = dict(
                    signatures.get(from_user_id, {}))
                user_sigs[key_id] = row["signature"]

        return keys
Exemple #14
0
    def _get_auth_chain_difference_using_cover_index_txn(
            self, txn: Cursor, room_id: str,
            state_sets: List[Set[str]]) -> Set[str]:
        """Calculates the auth chain difference using the chain index.

        See docs/auth_chain_difference_algorithm.md for details

        Args:
            txn: cursor to run the queries on.
            room_id: the room the events belong to. Only used in the log
                message and in the raised exception.
            state_sets: the state sets to compare, each a set of event IDs.
                Must be non-empty, because ``state_sets[0]`` is indexed
                directly.

        Returns:
            The IDs of the events in the auth chain difference.

        Raises:
            _NoChainCoverIndex: if any event in ``state_sets`` has no row in
                ``event_auth_chains``, so that the caller can fall back to the
                previous algorithm.
        """

        # First we look up the chain ID/sequence numbers for all the events, and
        # work out the chain/sequence numbers reachable from each state set.

        initial_events = set(state_sets[0]).union(*state_sets[1:])

        # Map from event_id -> (chain ID, seq no)
        chain_info = {}  # type: Dict[str, Tuple[int, int]]

        # Map from chain ID -> seq no -> event Id
        chain_to_event = {}  # type: Dict[int, Dict[int, str]]

        # All the chains that we've found that are reachable from the state
        # sets.
        seen_chains = set()  # type: Set[int]

        # Fetch (chain ID, sequence number) for every event in the state sets,
        # 1000 events per query.
        sql = """
            SELECT event_id, chain_id, sequence_number
            FROM event_auth_chains
            WHERE %s
        """
        for batch in batch_iter(initial_events, 1000):
            clause, args = make_in_list_sql_clause(txn.database_engine,
                                                   "event_id", batch)
            txn.execute(sql % (clause, ), args)

            for event_id, chain_id, sequence_number in txn:
                chain_info[event_id] = (chain_id, sequence_number)
                seen_chains.add(chain_id)
                chain_to_event.setdefault(chain_id,
                                          {})[sequence_number] = event_id

        # Check that we actually have a chain ID for all the events.
        events_missing_chain_info = initial_events.difference(chain_info)
        if events_missing_chain_info:
            # This can happen due to e.g. downgrade/upgrade of the server. We
            # raise an exception and fall back to the previous algorithm.
            logger.info(
                "Unexpectedly found that events don't have chain IDs in room %s: %s",
                room_id,
                events_missing_chain_info,
            )
            raise _NoChainCoverIndex(room_id)

        # Corresponds to `state_sets`, except as a map from chain ID to max
        # sequence number reachable from the state set.
        set_to_chain = []  # type: List[Dict[int, int]]
        for state_set in state_sets:
            chains = {}  # type: Dict[int, int]
            set_to_chain.append(chains)

            for event_id in state_set:
                # Cannot KeyError: any event missing from chain_info raised
                # _NoChainCoverIndex above.
                chain_id, seq_no = chain_info[event_id]

                chains[chain_id] = max(seq_no, chains.get(chain_id, 0))

        # Now we look up all links for the chains we have, adding chains to
        # set_to_chain that are reachable from each set.
        sql = """
            SELECT
                origin_chain_id, origin_sequence_number,
                target_chain_id, target_sequence_number
            FROM event_auth_chain_links
            WHERE %s
        """

        # (We need to take a copy of `seen_chains` as we want to mutate it in
        # the loop)
        for batch in batch_iter(set(seen_chains), 1000):
            clause, args = make_in_list_sql_clause(txn.database_engine,
                                                   "origin_chain_id", batch)
            txn.execute(sql % (clause, ), args)

            for (
                    origin_chain_id,
                    origin_sequence_number,
                    target_chain_id,
                    target_sequence_number,
            ) in txn:
                for chains in set_to_chain:
                    # chains are only reachable if the origin sequence number of
                    # the link is less than or equal to the max sequence number
                    # in the origin chain.
                    if origin_sequence_number <= chains.get(
                            origin_chain_id, 0):
                        chains[target_chain_id] = max(
                            target_sequence_number,
                            chains.get(target_chain_id, 0),
                        )

                # Record the target chain so that its gap is considered below.
                seen_chains.add(target_chain_id)

        # Now for each chain we figure out the maximum sequence number reachable
        # from *any* state set and the minimum sequence number reachable from
        # *all* state sets. Events in that range are in the auth chain
        # difference.
        result = set()

        # Mapping from chain ID to the range of sequence numbers that should be
        # pulled from the database.
        chain_to_gap = {}  # type: Dict[int, Tuple[int, int]]

        for chain_id in seen_chains:
            min_seq_no = min(
                chains.get(chain_id, 0) for chains in set_to_chain)
            max_seq_no = max(
                chains.get(chain_id, 0) for chains in set_to_chain)

            if min_seq_no < max_seq_no:
                # We have a non empty gap, try and fill it from the events that
                # we have, otherwise add them to the list of gaps to pull out
                # from the DB.
                # As soon as one sequence number in (min, max] is not known
                # locally, the whole range is queued for the DB. Events already
                # added from this chain will be re-added, which is harmless
                # because `result` is a set.
                for seq_no in range(min_seq_no + 1, max_seq_no + 1):
                    event_id = chain_to_event.get(chain_id, {}).get(seq_no)
                    if event_id:
                        result.add(event_id)
                    else:
                        chain_to_gap[chain_id] = (min_seq_no, max_seq_no)
                        break

        if not chain_to_gap:
            # If there are no gaps to fetch, we're done!
            return result

        if isinstance(self.database_engine, PostgresEngine):
            # We can use `execute_values` to efficiently fetch the gaps when
            # using postgres.
            # NOTE(review): this assumes `execute_values` expands `VALUES ?`
            # into one row per tuple in `args` and returns the fetched rows.
            # Confirm against the txn wrapper.
            sql = """
                SELECT event_id
                FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, min_seq, max_seq)
                WHERE
                    c.chain_id = l.chain_id
                    AND min_seq < sequence_number AND sequence_number <= max_seq
            """

            args = [(chain_id, min_no, max_no)
                    for chain_id, (min_no, max_no) in chain_to_gap.items()]

            rows = txn.execute_values(sql, args)
            result.update(r for r, in rows)
        else:
            # For SQLite we just fall back to doing a noddy for loop.
            # One query per chain with a gap; the range is (min, max].
            sql = """
                SELECT event_id FROM event_auth_chains
                WHERE chain_id = ? AND ? < sequence_number AND sequence_number <= ?
            """
            for chain_id, (min_no, max_no) in chain_to_gap.items():
                txn.execute(sql, (chain_id, min_no, max_no))
                result.update(r for r, in txn)

        return result
    def _fetch_event_rows(self, txn, event_ids):
        """Load the raw database rows for a collection of events.

        Events that cannot be found are left out of the result. Lookups are
        done 200 events at a time.

        Each per-event dict has these keys:

         * event_id (str)

         * json (str): json-encoded event structure

         * internal_metadata (str): json-encoded internal metadata dict

         * format_version (int|None): The format of the event. Hopefully one
           of EventFormatVersions. 'None' means the event predates
           EventFormatVersions (so the event is format V1).

         * rejected_reason (str|None): the rejection reason if the event was
           rejected.

         * redactions (List[str]): IDs of the events which (claim to) redact
           this event.

        Args:
            txn (twisted.enterprise.adbapi.Connection):
            event_ids (Iterable[str]): event IDs to fetch

        Returns:
            Dict[str, Dict]: a map from event id to event info.
        """
        # Both statements are loop-invariant; only the IN clause varies.
        events_sql = ("SELECT "
                      " e.event_id, "
                      " e.internal_metadata,"
                      " e.json,"
                      " e.format_version, "
                      " rej.reason "
                      " FROM event_json as e"
                      " LEFT JOIN rejections as rej USING (event_id)"
                      " WHERE ")
        redactions_sql = "SELECT event_id, redacts FROM redactions WHERE "

        event_dict = {}
        for evs in batch_iter(event_ids, 200):
            clause, args = make_in_list_sql_clause(txn.database_engine,
                                                   "e.event_id", evs)
            txn.execute(events_sql + clause, args)

            for (event_id, internal_metadata, json_str, format_version,
                 rejected_reason) in txn:
                event_dict[event_id] = {
                    "event_id": event_id,
                    "internal_metadata": internal_metadata,
                    "json": json_str,
                    "format_version": format_version,
                    "rejected_reason": rejected_reason,
                    "redactions": [],
                }

            # Attach the IDs of any events redacting this batch.
            clause, args = make_in_list_sql_clause(txn.database_engine,
                                                   "redacts", evs)
            txn.execute(redactions_sql + clause, args)

            for redacter, redacted in txn:
                entry = event_dict.get(redacted)
                if entry:
                    entry["redactions"].append(redacter)

        return event_dict
Exemple #16
0
    def _get_auth_chain_ids_using_cover_index_txn(
            self, txn: Cursor, room_id: str, event_ids: Collection[str],
            include_given: bool) -> List[str]:
        """Compute the auth chain of ``event_ids`` using the chain cover index.

        Each given event is first resolved to a (chain ID, sequence number)
        pair via ``event_auth_chains``. ``event_auth_chain_links`` is then
        consulted to find, for every reachable chain, the highest reachable
        sequence number. Finally every event on those chains at or below that
        number is collected.

        Args:
            txn: the database cursor to run the queries on.
            room_id: the room the events are in; only used for logging and
                for the exception raised below.
            event_ids: the events whose auth chain should be computed.
            include_given: if true, ``event_ids`` themselves are included in
                the result.

        Returns:
            The event IDs of the auth chain, in no particular order.

        Raises:
            _NoChainCoverIndex: if any of ``event_ids`` has no entry in
                ``event_auth_chains``, so the caller can fall back to the
                previous algorithm.
        """
        initial_events = set(event_ids)

        # Events from `initial_events` that were found in the chain index.
        indexed_events = set()  # type: Set[str]

        # Chain ID -> highest sequence number among the given events on it.
        event_chains = {}  # type: Dict[int, int]

        sql = """
            SELECT event_id, chain_id, sequence_number
            FROM event_auth_chains
            WHERE %s
        """
        for id_batch in batch_iter(initial_events, 1000):
            in_clause, in_args = make_in_list_sql_clause(
                txn.database_engine, "event_id", id_batch)
            txn.execute(sql % (in_clause, ), in_args)

            for found_id, chain_id, seq_no in txn:
                indexed_events.add(found_id)
                best_so_far = event_chains.get(chain_id, 0)
                event_chains[chain_id] = max(seq_no, best_so_far)

        unindexed = initial_events.difference(indexed_events)
        if unindexed:
            # This can happen due to e.g. downgrade/upgrade of the server.
            # Raising lets the caller fall back to the previous algorithm.
            logger.info(
                "Unexpectedly found that events don't have chain IDs in room %s: %s",
                room_id,
                unindexed,
            )
            raise _NoChainCoverIndex(room_id)

        sql = """
            SELECT
                origin_chain_id, origin_sequence_number,
                target_chain_id, target_sequence_number
            FROM event_auth_chain_links
            WHERE %s
        """

        # Chain ID -> highest sequence number reachable from any given event.
        reachable = {}  # type: Dict[int, int]

        for chain_batch in batch_iter(event_chains, 1000):
            in_clause, in_args = make_in_list_sql_clause(
                txn.database_engine, "origin_chain_id", chain_batch)
            txn.execute(sql % (in_clause, ), in_args)

            for origin_id, origin_seq, target_id, target_seq in txn:
                # A link is only usable when its origin lies at or below the
                # highest given event on the origin chain.
                if origin_seq > event_chains.get(origin_id, 0):
                    continue
                reachable[target_id] = max(target_seq,
                                           reachable.get(target_id, 0))

        # The given events' own chains are reachable too, up to (but not
        # including) the highest given event on each chain.
        for chain_id, top_seq in event_chains.items():
            reachable[chain_id] = max(top_seq - 1, reachable.get(chain_id, 0))

        # NB: when include_given is set this deliberately reuses (and so
        # extends) `initial_events` rather than copying it.
        results = initial_events if include_given else set()

        if isinstance(self.database_engine, PostgresEngine):
            # `execute_values` lets postgres fetch the events for many
            # (chain, max sequence) ranges at once instead of one query per
            # chain.
            sql = """
                SELECT event_id
                FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, max_seq)
                WHERE
                    c.chain_id = l.chain_id
                    AND sequence_number <= max_seq
            """

            rows = txn.execute_values(sql, reachable.items())
            results.update(ev_id for (ev_id, ) in rows)
        else:
            # SQLite: issue one query per reachable chain.
            sql = """
                SELECT event_id FROM event_auth_chains
                WHERE chain_id = ? AND sequence_number <= ?
            """
            for chain_id, upper_seq in reachable.items():
                txn.execute(sql, (chain_id, upper_seq))
                results.update(ev_id for (ev_id, ) in txn)

        return list(results)
Exemple #17
0
    def _get_auth_chain_difference_txn(self, txn,
                                       state_sets: List[Set[str]]) -> Set[str]:
        """Calculates the auth chain difference of the given state sets.

        The result contains every event that is in, or in the auth chain of,
        at least one of ``state_sets`` but is not reachable from all of them.
        That is, it is the union of the state sets' auth chains (including the
        state set events themselves) minus their intersection.

        Args:
            txn: database cursor used to query the ``events`` and
                ``event_auth`` tables.
            state_sets: the state sets to compare, each a set of event IDs.
                Must be non-empty, since ``state_sets[0]`` is indexed directly.

        Returns:
            The IDs of the events that at least one state set cannot reach.

        Side effects:
            Reads from and populates ``self._event_auth_cache``. Its entries
            map an event ID to a list of ``(auth_event_id, auth_event_depth)``
            tuples.
        """

        # Algorithm Description
        # ~~~~~~~~~~~~~~~~~~~~~
        #
        # The idea here is to basically walk the auth graph of each state set in
        # tandem, keeping track of which auth events are reachable by each state
        # set. If we reach an auth event we've already visited (via a different
        # state set) then we mark that auth event and all ancestors as reachable
        # by the state set. This requires that we keep track of the auth chains
        # in memory.
        #
        # Doing it in such a way means that we can stop early if all auth
        # events we're currently walking are reachable by all state sets.
        #
        # *Note*: We can't stop walking an event's auth chain if it is reachable
        # by all state sets. This is because other auth chains we're walking
        # might be reachable only via the original auth chain. For example,
        # given the following auth chain:
        #
        #       A -> C -> D -> E
        #           /         /
        #       B -´---------´
        #
        # and state sets {A} and {B} then walking the auth chains of A and B
        # would immediately show that C is reachable by both. However, if we
        # stopped at C then we'd only reach E via the auth chain of B and so E
        # would erroneously get included in the returned difference.
        #
        # The other thing that we do is limit the number of auth chains we walk
        # at once, due to practical limits (i.e. we can only query the database
        # with a limited set of parameters). We pick the auth chains we walk
        # each iteration based on their depth, in the hope that events with a
        # lower depth are likely reachable by those with higher depths.
        #
        # We could use any ordering that we believe would give a rough
        # topological ordering, e.g. origin server timestamp. If the ordering
        # chosen is not topological then the algorithm still produces the right
        # result, but perhaps a bit more inefficiently. This is why it is safe
        # to use "depth" here.

        # Every event that appears in at least one of the state sets.
        initial_events = set(state_sets[0]).union(*state_sets[1:])

        # Dict from events in auth chains to which sets *cannot* reach them.
        # I.e. if the set is empty then all sets can reach the event.
        event_to_missing_sets = {
            event_id:
            {i
             for i, a in enumerate(state_sets) if event_id not in a}
            for event_id in initial_events
        }

        # The sorted list of events whose auth chains we should walk.
        search = []  # type: List[Tuple[int, str]]

        # We need to get the depth of the initial events for sorting purposes.
        sql = """
            SELECT depth, event_id FROM events
            WHERE %s
        """
        # the list can be huge, so let's avoid looking them all up in one massive
        # query.
        for batch in batch_iter(initial_events, 1000):
            clause, args = make_in_list_sql_clause(txn.database_engine,
                                                   "event_id", batch)
            txn.execute(sql % (clause, ), args)

            # I think building a temporary list with fetchall is more efficient than
            # just `search.extend(txn)`, but this is unconfirmed
            search.extend(txn.fetchall())

        # sort by depth (ascending; ties are broken by event ID since the
        # entries are `(depth, event_id)` tuples)
        search.sort()

        # Map from event to its auth events
        event_to_auth_events = {}  # type: Dict[str, Set[str]]

        base_sql = """
            SELECT a.event_id, auth_id, depth
            FROM event_auth AS a
            INNER JOIN events AS e ON (e.event_id = a.auth_id)
            WHERE
        """

        while search:
            # Check whether all our current walks are reachable by all state
            # sets. If so we can bail.
            if all(not event_to_missing_sets[eid] for _, eid in search):
                break

            # Fetch the auth events and their depths of the N last events we're
            # currently walking, either from cache or DB. `search` is sorted by
            # ascending depth, so this pops (up to) the 100 deepest events.
            search, chunk = search[:-100], search[-100:]

            found = []  # Results found  # type: List[Tuple[str, str, int]]
            to_fetch = []  # Event IDs to fetch from DB  # type: List[str]
            for _, event_id in chunk:
                res = self._event_auth_cache.get(event_id)
                if res is None:
                    to_fetch.append(event_id)
                else:
                    found.extend(
                        (event_id, auth_id, depth) for auth_id, depth in res)

            if to_fetch:
                clause, args = make_in_list_sql_clause(txn.database_engine,
                                                       "a.event_id", to_fetch)
                txn.execute(base_sql + clause, args)

                # We parse the results and add them to the `found` set and the
                # cache (note we need to batch up the results by event ID before
                # adding to the cache).
                to_cache = {}
                for event_id, auth_event_id, auth_event_depth in txn:
                    to_cache.setdefault(event_id, []).append(
                        (auth_event_id, auth_event_depth))
                    found.append((event_id, auth_event_id, auth_event_depth))

                for event_id, auth_events in to_cache.items():
                    self._event_auth_cache.set(event_id, auth_events)

            for event_id, auth_event_id, auth_event_depth in found:
                event_to_auth_events.setdefault(event_id,
                                                set()).add(auth_event_id)

                sets = event_to_missing_sets.get(auth_event_id)
                if sets is None:
                    # First time we're seeing this event, so we add it to the
                    # queue of things to fetch.
                    search.append((auth_event_depth, auth_event_id))

                    # Assume that this event is unreachable from any of the
                    # state sets until proven otherwise
                    sets = event_to_missing_sets[auth_event_id] = set(
                        range(len(state_sets)))
                else:
                    # We've previously seen this event, so look up its auth
                    # events and recursively mark all ancestors as reachable
                    # by the current event's state set.
                    a_ids = event_to_auth_events.get(auth_event_id)
                    while a_ids:
                        new_aids = set()
                        for a_id in a_ids:
                            event_to_missing_sets[a_id].intersection_update(
                                event_to_missing_sets[event_id])

                            b = event_to_auth_events.get(a_id)
                            if b:
                                new_aids.update(b)

                        a_ids = new_aids

                # Mark that the auth event is reachable by the appropriate sets.
                sets.intersection_update(event_to_missing_sets[event_id])

            # New entries were appended unsorted above; re-sort so the next
            # iteration again takes the deepest events.
            search.sort()

        # Return all events where not all sets can reach them.
        return {eid for eid, n in event_to_missing_sets.items() if n}
    def _get_bare_e2e_cross_signing_keys_bulk_txn(
        self,
        txn: Connection,
        user_ids: List[str],
    ) -> Dict[str, Dict[str, dict]]:
        """Look up the most recent cross-signing keys for many users at once.

        Pass the result to ``_get_e2e_cross_signing_signatures_txn`` if the
        calling user's signatures also need to be fetched.

        Args:
            txn (twisted.enterprise.adbapi.Connection): db connection
            user_ids (list[str]): users whose keys should be fetched; they are
                queried in batches of 100.

        Returns:
            dict[str, dict[str, dict]]: user ID -> key type -> decoded key
                data, taken from the row with the highest ``stream_id`` for
                each (user, key type) pair. Users with no cross-signing keys
                are absent from the mapping.
        """
        keys_by_user = {}

        for user_batch in batch_iter(user_ids, 100):
            in_clause, in_params = make_in_list_sql_clause(
                txn.database_engine, "user_id", user_batch)

            if isinstance(self.database_engine, PostgresEngine):
                # `DISTINCT ON` keeps the first row of each (user, key type)
                # group, and the ORDER BY puts the newest stream ID first.
                template = """
                    SELECT DISTINCT ON (user_id, keytype) user_id, keytype, keydata, stream_id
                        FROM e2e_cross_signing_keys
                        WHERE %(clause)s
                        ORDER BY user_id, keytype, stream_id DESC
                """
            else:
                # SQLite fills bare columns of a MIN/MAX `GROUP BY` query from
                # the row that supplied the MAX(), i.e. the newest key.
                template = """
                    SELECT user_id, keytype, keydata, MAX(stream_id)
                        FROM e2e_cross_signing_keys
                        WHERE %(clause)s
                        GROUP BY user_id, keytype
                """

            txn.execute(template % {"clause": in_clause}, in_params)

            for row in self.db_pool.cursor_to_dict(txn):
                decoded = db_to_json(row["keydata"])
                user_keys = keys_by_user.setdefault(row["user_id"], {})
                user_keys[row["keytype"]] = decoded

        return keys_by_user