Example #1
0
        def _get_users_server_still_shares_room_with_txn(txn):
            """Return the given users that still have a joined membership.

            Looks in `current_state_events` for `m.room.member` rows with
            membership 'join' whose `state_key` is one of the enclosing
            scope's `user_ids`.

            Args:
                txn: The database transaction to run the query on.

            Returns:
                The set of matching `state_key` values.
            """
            in_clause, in_args = make_in_list_sql_clause(
                self.database_engine, "state_key", user_ids
            )

            query_template = """
                SELECT state_key FROM current_state_events
                WHERE
                    type = 'm.room.member'
                    AND membership = 'join'
                    AND %s
                GROUP BY state_key
            """
            txn.execute(query_template % (in_clause, ), in_args)

            # Each row holds a single column: the state_key.
            still_shared = set()
            for row in txn:
                still_shared.add(row[0])
            return still_shared
Example #2
0
        def _get_users_whose_devices_changed_txn(txn):
            """Find which users in `to_check` have device list changes.

            A user counts as changed when `device_lists_stream` holds a row for
            them with a `stream_id` greater than the enclosing scope's
            `from_key`.

            Args:
                txn: The database transaction to run the queries on.

            Returns:
                The set of user IDs with at least one matching row.
            """
            query_prefix = """
                SELECT DISTINCT user_id FROM device_lists_stream
                WHERE stream_id > ?
                AND
            """

            changed_users = set()

            # Query in batches of 100 users so the IN-list stays bounded.
            for user_batch in batch_iter(to_check, 100):
                in_clause, in_args = make_in_list_sql_clause(
                    txn.database_engine, "user_id", user_batch
                )
                txn.execute(query_prefix + in_clause, (from_key, *in_args))
                for (user_id, ) in txn:
                    changed_users.add(user_id)

            return changed_users
Example #3
0
            def graph_to_linear(txn):
                """Pick a single event ID to stand in for `event_ids`.

                Selects the event in `room_id` whose stream ordering equals
                the maximum stream ordering over the given `event_ids` (both
                taken from the enclosing scope).

                Args:
                    txn: The database transaction to run the query on.

                Returns:
                    The event ID in the first row returned by the query.

                Raises:
                    RuntimeError: If the query returns no rows.
                """
                clause, args = make_in_list_sql_clause(self.database_engine,
                                                       "event_id", event_ids)

                # Both the outer query and the sub-select need an explicit
                # FROM clause; without one the column references cannot be
                # resolved and the database rejects the statement.
                # NOTE(review): `events` is assumed to be the table holding
                # event_id / room_id / stream_ordering - confirm against the
                # schema.
                sql = """
                    SELECT event_id FROM events
                    WHERE room_id = ? AND stream_ordering IN (
                        SELECT max(stream_ordering) FROM events WHERE %s
                    )
                """ % (clause, )

                txn.execute(sql, [room_id] + list(args))
                rows = txn.fetchall()
                if not rows:
                    raise RuntimeError("Unrecognized event_ids: %r" %
                                       (event_ids, ))
                return rows[0][0]
Example #4
0
    def _update_presence_txn(self, txn, stream_orderings, presence_states):
        """Replace the stored presence rows for the given states.

        Schedules cache invalidations for every affected user, deletes the
        users' older `presence_stream` rows, then inserts one new row per
        (stream ordering, state) pair.

        Args:
            txn: The database transaction to run the statements on.
            stream_orderings: Stream IDs, paired positionally with
                `presence_states`.
            presence_states: Presence state objects to persist.
        """
        # NOTE: `stream_id` is deliberately left bound after this loop; the
        # clean-up below uses the value from the final iteration.
        for stream_id, state in zip(stream_orderings, presence_states):
            user_id = state.user_id
            txn.call_after(
                self.presence_stream_cache.entity_has_changed,
                user_id,
                stream_id,
            )
            txn.call_after(self._get_presence_for_user.invalidate, (user_id, ))

        # Delete old rows to stop database from getting really big
        delete_prefix = "DELETE FROM presence_stream WHERE stream_id < ? AND "

        for state_batch in batch_iter(presence_states, 50):
            batch_user_ids = [s.user_id for s in state_batch]
            in_clause, in_args = make_in_list_sql_clause(
                self.database_engine, "user_id", batch_user_ids
            )
            txn.execute(delete_prefix + in_clause, [stream_id, *in_args])

        # Actually insert new rows
        new_rows = []
        for row_stream_id, row_state in zip(stream_orderings, presence_states):
            new_rows.append((
                row_stream_id,
                row_state.user_id,
                row_state.state,
                row_state.last_active_ts,
                row_state.last_federation_update_ts,
                row_state.last_user_sync_ts,
                row_state.status_msg,
                row_state.currently_active,
                self._instance_name,
            ))

        self.db_pool.simple_insert_many_txn(
            txn,
            table="presence_stream",
            keys=(
                "stream_id",
                "user_id",
                "state",
                "last_active_ts",
                "last_federation_update_ts",
                "last_user_sync_ts",
                "status_msg",
                "currently_active",
                "instance_name",
            ),
            values=new_rows,
        )
Example #5
0
    def get_last_receipt_for_user_txn(
        self,
        txn: LoggingTransaction,
        user_id: str,
        room_id: str,
        receipt_types: Collection[str],
    ) -> Optional[Tuple[str, int]]:
        """
        Fetch the event ID and stream_ordering for the latest receipt in a room
        with one of the given receipt types.

        "Latest" means the highest `stream_ordering` of the receipt's event.

        Args:
            txn: The transaction to run the query on.
            user_id: The user to fetch receipts for.
            room_id: The room ID to fetch the receipt for.
            receipt_types: The receipt types to fetch.

        Returns:
            The latest receipt as (event_id, stream_ordering), if one exists.
        """

        type_clause, type_args = make_in_list_sql_clause(
            self.database_engine, "receipt_type", receipt_types
        )

        query = f"""
            SELECT event_id, stream_ordering
            FROM receipts_linearized
            INNER JOIN events USING (room_id, event_id)
            WHERE {type_clause}
            AND user_id = ?
            AND room_id = ?
            ORDER BY stream_ordering DESC
            LIMIT 1
        """

        # Placeholder order: receipt types first, then user, then room.
        txn.execute(query, [*type_args, user_id, room_id])

        latest = txn.fetchone()
        return cast(Optional[Tuple[str, int]], latest)
Example #6
0
    def _graph_to_linear(
        self, txn: LoggingTransaction, room_id: str, event_ids: List[str]
    ) -> str:
        """
        Generate a linearized event from a list of events (i.e. a list of forward
        extremities in the room).

        This should allow for calculation of the correct read receipt even if
        servers have different event ordering.

        The chosen event is the one in `room_id` whose stream ordering equals
        the maximum stream ordering over `event_ids`.

        Args:
            txn: The transaction
            room_id: The room ID the events are in.
            event_ids: The list of event IDs to linearize.

        Returns:
            The linearized event ID.

        Raises:
            RuntimeError: If the query returns no rows.
        """
        # TODO: Make this better.
        clause, args = make_in_list_sql_clause(
            self.database_engine, "event_id", event_ids
        )

        # Both the outer query and the sub-select need an explicit FROM
        # clause; without one the column references cannot be resolved and
        # the database rejects the statement.
        # NOTE(review): `events` is assumed to be the table holding
        # event_id / room_id / stream_ordering - confirm against the schema.
        sql = """
            SELECT event_id FROM events
            WHERE room_id = ? AND stream_ordering IN (
                SELECT max(stream_ordering) FROM events WHERE %s
            )
        """ % (
            clause,
        )

        txn.execute(sql, [room_id] + list(args))
        rows = txn.fetchall()
        if not rows:
            raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,))
        return rows[0][0]
Example #7
0
    def _get_rooms_for_local_user_where_membership_is_txn(
            self, txn, user_id, membership_list):
        """Fetch a local user's rooms whose membership is in `membership_list`.

        Args:
            txn: The database transaction to run the query on.
            user_id: The user to look up; must belong to this server.
            membership_list: The membership values to match.

        Returns:
            A list of `RoomsForUser`, one per matching row.

        Raises:
            Exception: If `user_id` is not a local user.
        """
        # Paranoia check.
        is_local = self.hs.is_mine_id(user_id)
        if not is_local:
            raise Exception(
                "Cannot call 'get_rooms_for_local_user_where_membership_is' on non-local user %r"
                % (user_id, ))

        membership_clause, membership_args = make_in_list_sql_clause(
            self.database_engine, "c.membership", membership_list)

        query = """
            SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering
            FROM local_current_membership AS c
            INNER JOIN events AS e USING (room_id, event_id)
            WHERE
                user_id = ?
                AND %s
        """ % (membership_clause, )

        txn.execute(query, (user_id, *membership_args))

        rooms = []
        for row in self.db.cursor_to_dict(txn):
            rooms.append(RoomsForUser(**row))
        return rooms
Example #8
0
    def _get_auth_chain_difference_txn(self, txn,
                                       state_sets: List[Set[str]]) -> Set[str]:
        """Compute the events that are not reachable from every state set.

        Walks the auth events of the events in `state_sets`, tracking for each
        event seen which state sets cannot reach it.

        Args:
            txn: The database transaction to run the queries on.
            state_sets: Sets of event IDs, one set per state.

        Returns:
            The IDs of all visited events (the initial events plus the auth
            events walked) that at least one state set cannot reach.
        """

        # Algorithm Description
        # ~~~~~~~~~~~~~~~~~~~~~
        #
        # The idea here is to basically walk the auth graph of each state set in
        # tandem, keeping track of which auth events are reachable by each state
        # set. If we reach an auth event we've already visited (via a different
        # state set) then we mark that auth event and all ancestors as reachable
        # by the state set. This requires that we keep track of the auth chains
        # in memory.
        #
        # Doing it in such a way means that we can stop early if all auth
        # events we're currently walking are reachable by all state sets.
        #
        # *Note*: We can't stop walking an event's auth chain if it is reachable
        # by all state sets. This is because other auth chains we're walking
        # might be reachable only via the original auth chain. For example,
        # given the following auth chain:
        #
        #       A -> C -> D -> E
        #           /         /
        #       B -´---------´
        #
        # and state sets {A} and {B} then walking the auth chains of A and B
        # would immediately show that C is reachable by both. However, if we
        # stopped at C then we'd only reach E via the auth chain of B and so E
        # would erroneously get included in the returned difference.
        #
        # The other thing that we do is limit the number of auth chains we walk
        # at once, due to practical limits (i.e. we can only query the database
        # with a limited set of parameters). We pick the auth chains we walk
        # each iteration based on their depth, in the hope that events with a
        # lower depth are likely reachable by those with higher depths.
        #
        # We could use any ordering that we believe would give a rough
        # topological ordering, e.g. origin server timestamp. If the ordering
        # chosen is not topological then the algorithm still produces the right
        # result, but perhaps a bit more inefficiently. This is why it is safe
        # to use "depth" here.

        initial_events = set(state_sets[0]).union(*state_sets[1:])

        # Dict from events in auth chains to which sets *cannot* reach them.
        # I.e. if the set is empty then all sets can reach the event.
        event_to_missing_sets = {
            event_id:
            {i
             for i, a in enumerate(state_sets) if event_id not in a}
            for event_id in initial_events
        }

        # The sorted list of events whose auth chains we should walk.
        search = []  # type: List[Tuple[int, str]]

        # We need to get the depth of the initial events for sorting purposes.
        sql = """
            SELECT depth, event_id FROM events
            WHERE %s
        """
        # the list can be huge, so let's avoid looking them all up in one massive
        # query.
        for batch in batch_iter(initial_events, 1000):
            clause, args = make_in_list_sql_clause(txn.database_engine,
                                                   "event_id", batch)
            txn.execute(sql % (clause, ), args)

            # I think building a temporary list with fetchall is more efficient than
            # just `search.extend(txn)`, but this is unconfirmed
            search.extend(txn.fetchall())

        # sort by depth
        search.sort()

        # Map from event to its auth events
        event_to_auth_events = {}  # type: Dict[str, Set[str]]

        base_sql = """
            SELECT a.event_id, auth_id, depth
            FROM event_auth AS a
            INNER JOIN events AS e ON (e.event_id = a.auth_id)
            WHERE
        """

        while search:
            # Check whether all our current walks are reachable by all state
            # sets. If so we can bail.
            if all(not event_to_missing_sets[eid] for _, eid in search):
                break

            # Fetch the auth events and their depths of the N last events we're
            # currently walking, either from cache or DB. As `search` is sorted
            # by depth, the last 100 entries are the deepest ones.
            search, chunk = search[:-100], search[-100:]

            found = []  # Results found  # type: List[Tuple[str, str, int]]
            to_fetch = []  # Event IDs to fetch from DB  # type: List[str]
            for _, event_id in chunk:
                res = self._event_auth_cache.get(event_id)
                if res is None:
                    to_fetch.append(event_id)
                else:
                    found.extend(
                        (event_id, auth_id, depth) for auth_id, depth in res)

            if to_fetch:
                clause, args = make_in_list_sql_clause(txn.database_engine,
                                                       "a.event_id", to_fetch)
                txn.execute(base_sql + clause, args)

                # We parse the results and add them to the `found` list and
                # the cache (note we need to batch up the results by event ID
                # before adding to the cache).
                to_cache = {}
                for event_id, auth_event_id, auth_event_depth in txn:
                    to_cache.setdefault(event_id, []).append(
                        (auth_event_id, auth_event_depth))
                    found.append((event_id, auth_event_id, auth_event_depth))

                for event_id, auth_events in to_cache.items():
                    self._event_auth_cache.set(event_id, auth_events)

            for event_id, auth_event_id, auth_event_depth in found:
                event_to_auth_events.setdefault(event_id,
                                                set()).add(auth_event_id)

                sets = event_to_missing_sets.get(auth_event_id)
                if sets is None:
                    # First time we're seeing this event, so we add it to the
                    # queue of things to fetch.
                    search.append((auth_event_depth, auth_event_id))

                    # Assume that this event is unreachable from any of the
                    # state sets until proven otherwise
                    sets = event_to_missing_sets[auth_event_id] = set(
                        range(len(state_sets)))
                else:
                    # We've previously seen this event, so look up its auth
                    # events and recursively mark all ancestors as reachable
                    # by the current event's state set.
                    a_ids = event_to_auth_events.get(auth_event_id)
                    while a_ids:
                        new_aids = set()
                        for a_id in a_ids:
                            event_to_missing_sets[a_id].intersection_update(
                                event_to_missing_sets[event_id])

                            b = event_to_auth_events.get(a_id)
                            if b:
                                new_aids.update(b)

                        a_ids = new_aids

                # Mark that the auth event is reachable by the appropriate sets.
                sets.intersection_update(event_to_missing_sets[event_id])

            # Newly queued auth events were appended unsorted; restore the
            # depth ordering before taking the next chunk.
            search.sort()

        # Return all events where not all sets can reach them.
        return {eid for eid, n in event_to_missing_sets.items() if n}
        def _cleanup_extremities_bg_update_txn(txn):
            """Check one batch of forward extremities and delete the stale ones.

            Takes up to `batch_size` (from the enclosing scope) event IDs from
            `_extremities_to_check`, follows their descendants through
            soft-failed/rejected events, and removes from
            `event_forward_extremities` every checked event that has a
            descendant which is neither rejected nor soft-failed. The checked
            event IDs are then removed from `_extremities_to_check`.

            Args:
                txn: The database transaction to run the statements on.

            Returns:
                The number of extremities checked in this batch.
            """
            # The set of extremity event IDs that we're checking this round
            original_set = set()

            # A dict[str, set[str]] of event ID to their prev events.
            graph = {}

            # The set of descendants of the original set that are not rejected
            # nor soft-failed. Ancestors of these events should be removed
            # from the forward extremities table.
            non_rejected_leaves = set()

            # Set of event IDs that have been soft failed, and for which we
            # should check if they have descendants which haven't been soft
            # failed.
            soft_failed_events_to_lookup = set()

            # First, we get `batch_size` events from the table, pulling out
            # their successor events, if any, and the successor events'
            # rejection status.
            txn.execute(
                """SELECT prev_event_id, event_id, internal_metadata,
                    rejections.event_id IS NOT NULL, events.outlier
                FROM (
                    SELECT event_id AS prev_event_id
                    FROM _extremities_to_check
                    LIMIT ?
                ) AS f
                LEFT JOIN event_edges USING (prev_event_id)
                LEFT JOIN events USING (event_id)
                LEFT JOIN event_json USING (event_id)
                LEFT JOIN rejections USING (event_id)
                """,
                (batch_size, ),
            )

            for prev_event_id, event_id, metadata, rejected, outlier in txn:
                original_set.add(prev_event_id)

                if not event_id or outlier:
                    # Common case where the forward extremity doesn't have any
                    # descendants.
                    continue

                graph.setdefault(event_id, set()).add(prev_event_id)

                # `event_json` is LEFT JOINed above, so the metadata may be
                # missing; treat that as "not soft-failed".
                soft_failed = False
                if metadata:
                    soft_failed = db_to_json(metadata).get("soft_failed")

                if soft_failed or rejected:
                    soft_failed_events_to_lookup.add(event_id)
                else:
                    non_rejected_leaves.add(event_id)

            # Now we recursively check all the soft-failed descendants we
            # found above in the same way, until we have nothing left to
            # check.
            while soft_failed_events_to_lookup:
                # We only want to do 100 at a time, so we split given list
                # into two.
                batch = list(soft_failed_events_to_lookup)
                to_check, to_defer = batch[:100], batch[100:]
                soft_failed_events_to_lookup = set(to_defer)

                sql = """SELECT prev_event_id, event_id, internal_metadata,
                    rejections.event_id IS NOT NULL
                    FROM event_edges
                    INNER JOIN events USING (event_id)
                    INNER JOIN event_json USING (event_id)
                    LEFT JOIN rejections USING (event_id)
                    WHERE
                        NOT events.outlier
                        AND
                """
                clause, args = make_in_list_sql_clause(self.database_engine,
                                                       "prev_event_id",
                                                       to_check)
                txn.execute(sql + clause, list(args))

                for prev_event_id, event_id, metadata, rejected in txn:
                    if event_id in graph:
                        # Already handled this event previously, but we still
                        # want to record the edge.
                        graph[event_id].add(prev_event_id)
                        continue

                    graph[event_id] = {prev_event_id}

                    soft_failed = db_to_json(metadata).get("soft_failed")
                    if soft_failed or rejected:
                        soft_failed_events_to_lookup.add(event_id)
                    else:
                        non_rejected_leaves.add(event_id)

            # We have a set of non-soft-failed descendants, so we recurse up
            # the graph to find all ancestors and add them to the set of event
            # IDs that we can delete from forward extremities table.
            to_delete = set()
            while non_rejected_leaves:
                event_id = non_rejected_leaves.pop()
                prev_event_ids = graph.get(event_id, set())
                non_rejected_leaves.update(prev_event_ids)
                to_delete.update(prev_event_ids)

            # Only events from this round's batch are candidates for deletion;
            # intermediate ancestors found during the walk are dropped here.
            to_delete.intersection_update(original_set)

            deleted = self.db_pool.simple_delete_many_txn(
                txn=txn,
                table="event_forward_extremities",
                column="event_id",
                iterable=to_delete,
                keyvalues={},
            )

            logger.info(
                "Deleted %d forward extremities of %d checked, to clean up #5269",
                deleted,
                len(original_set),
            )

            if deleted:
                # We now need to invalidate the caches of these rooms
                rows = self.db_pool.simple_select_many_txn(
                    txn,
                    table="events",
                    column="event_id",
                    iterable=to_delete,
                    keyvalues={},
                    retcols=("room_id", ),
                )
                room_ids = {row["room_id"] for row in rows}
                for room_id in room_ids:
                    txn.call_after(
                        self.get_latest_event_ids_in_room.invalidate,
                        (room_id, ))

            # Everything in this batch has now been checked, so drop it from
            # the work queue regardless of whether it was deleted above.
            self.db_pool.simple_delete_many_txn(
                txn=txn,
                table="_extremities_to_check",
                column="event_id",
                iterable=original_set,
                keyvalues={},
            )

            return len(original_set)
Example #10
0
    def search_rooms(self,
                     room_ids,
                     search_term,
                     keys,
                     limit,
                     pagination_token=None):
        """Performs a full text search over events with given keys.

        Results are ordered by origin_server_ts then stream_ordering, both
        descending.

        NOTE(review): the database calls are driven with `yield`, so this is
        presumably wrapped by a coroutine decorator outside this block.

        Args:
            room_ids (list): The room_ids to search in
            search_term (str): Search term to search for
            keys (list): List of keys to search in, currently supports
                "content.body", "content.name", "content.topic"
            limit: Value bound to the main query's LIMIT placeholder
            pagination_token (str): A pagination token previously returned,
                of the form "<origin_server_ts>,<stream_ordering>"

        Returns:
            dict with keys:
                "results": list of dicts with "event", "rank" and
                    "pagination_token"
                "highlights": the Postgres highlights, or None on other
                    engines
                "count": total number of matches in the given rooms

        Raises:
            SynapseError: (400) if `pagination_token` cannot be parsed.
            Exception: if the database engine is not recognised.
        """
        clauses = []

        search_query = _parse_query(self.database_engine, search_term)

        args = []

        # Make sure we don't explode because the person is in too many rooms.
        # We filter the results below regardless.
        if len(room_ids) < 500:
            clause, args = make_in_list_sql_clause(self.database_engine,
                                                   "room_id", room_ids)
            clauses = [clause]

        local_clauses = []
        for key in keys:
            local_clauses.append("key = ?")
            args.append(key)

        clauses.append("(%s)" % (" OR ".join(local_clauses), ))

        # take copies of the current args and clauses lists, before adding
        # pagination clauses to main query.
        count_args = list(args)
        count_clauses = list(clauses)

        if pagination_token:
            try:
                origin_server_ts, stream = pagination_token.split(",")
                origin_server_ts = int(origin_server_ts)
                stream = int(stream)
            except Exception:
                raise SynapseError(400, "Invalid pagination token")

            clauses.append(
                "(origin_server_ts < ?"
                " OR (origin_server_ts = ? AND stream_ordering < ?))")
            args.extend([origin_server_ts, origin_server_ts, stream])

        if isinstance(self.database_engine, PostgresEngine):
            sql = (
                "SELECT ts_rank_cd(vector, to_tsquery('english', ?)) as rank,"
                " origin_server_ts, stream_ordering, room_id, event_id"
                " FROM event_search"
                " WHERE vector @@ to_tsquery('english', ?) AND ")
            args = [search_query, search_query] + args

            count_sql = ("SELECT room_id, count(*) as count FROM event_search"
                         " WHERE vector @@ to_tsquery('english', ?) AND ")
            count_args = [search_query] + count_args
        elif isinstance(self.database_engine, Sqlite3Engine):
            # We use CROSS JOIN here to ensure we use the right indexes.
            # https://sqlite.org/optoverview.html#crossjoin
            #
            # We want to use the full text search index on event_search to
            # extract all possible matches first, then lookup those matches
            # in the events table to get the topological ordering. We need
            # to use the indexes in this order because sqlite refuses to
            # MATCH unless it uses the full text search index
            sql = (
                "SELECT rank(matchinfo) as rank, room_id, event_id,"
                " origin_server_ts, stream_ordering"
                " FROM (SELECT key, event_id, matchinfo(event_search) as matchinfo"
                " FROM event_search"
                " WHERE value MATCH ?"
                " )"
                " CROSS JOIN events USING (event_id)"
                " WHERE ")
            args = [search_query] + args

            count_sql = ("SELECT room_id, count(*) as count FROM event_search"
                         " WHERE value MATCH ? AND ")
            # The count must MATCH on the same parsed query as the main
            # search; binding the raw search term would count a different set
            # of rows than the ones returned.
            count_args = [search_query] + count_args
        else:
            # This should be unreachable.
            raise Exception("Unrecognized database engine")

        sql += " AND ".join(clauses)
        count_sql += " AND ".join(count_clauses)

        # We add an arbitrary limit here to ensure we don't try to pull the
        # entire table from the database.
        if isinstance(self.database_engine, PostgresEngine):
            sql += (" ORDER BY origin_server_ts DESC NULLS LAST,"
                    " stream_ordering DESC NULLS LAST LIMIT ?")
        elif isinstance(self.database_engine, Sqlite3Engine):
            sql += " ORDER BY origin_server_ts DESC, stream_ordering DESC LIMIT ?"
        else:
            raise Exception("Unrecognized database engine")

        args.append(limit)

        results = yield self.db.execute("search_rooms", self.db.cursor_to_dict,
                                        sql, *args)

        # The room filter is skipped in SQL when there are too many rooms, so
        # always re-apply it here.
        results = list(filter(lambda row: row["room_id"] in room_ids, results))

        # We set redact_behaviour to BLOCK here to prevent redacted events being returned in
        # search results (which is a data leak)
        events = yield self.get_events_as_list(
            [r["event_id"] for r in results],
            redact_behaviour=EventRedactBehaviour.BLOCK,
        )

        event_map = {ev.event_id: ev for ev in events}

        highlights = None
        if isinstance(self.database_engine, PostgresEngine):
            highlights = yield self._find_highlights_in_postgres(
                search_query, events)

        count_sql += " GROUP BY room_id"

        count_results = yield self.db.execute("search_rooms_count",
                                              self.db.cursor_to_dict,
                                              count_sql, *count_args)

        count = sum(row["count"] for row in count_results
                    if row["room_id"] in room_ids)

        return {
            "results": [{
                "event":
                event_map[r["event_id"]],
                "rank":
                r["rank"],
                "pagination_token":
                "%s,%s" % (r["origin_server_ts"], r["stream_ordering"]),
            } for r in results if r["event_id"] in event_map],
            "highlights":
            highlights,
            "count":
            count,
        }
Example #11
0
    def search_msgs(self, room_ids, search_term, keys):
        """Performs a full text search over events with given keys.

        Results are ordered by rank, descending, and capped at 500 rows.

        NOTE(review): the database calls are driven with `yield`, so this is
        presumably wrapped by a coroutine decorator outside this block.

        Args:
            room_ids (list): List of room ids to search in
            search_term (str): Search term to search for
            keys (list): List of keys to search in, currently supports
                "content.body", "content.name", "content.topic"

        Returns:
            dict with keys:
                "results": list of dicts with "event" and "rank"
                "highlights": the Postgres highlights, or None on other
                    engines
                "count": total number of matches in the given rooms

        Raises:
            Exception: if the database engine is not recognised.
        """
        clauses = []

        search_query = _parse_query(self.database_engine, search_term)

        args = []

        # Make sure we don't explode because the person is in too many rooms.
        # We filter the results below regardless.
        if len(room_ids) < 500:
            clause, args = make_in_list_sql_clause(self.database_engine,
                                                   "room_id", room_ids)
            clauses = [clause]

        local_clauses = []
        for key in keys:
            local_clauses.append("key = ?")
            args.append(key)

        clauses.append("(%s)" % (" OR ".join(local_clauses), ))

        # The count query shares the main query's clauses and arguments; both
        # argument lists are rebound (not mutated) below, so aliasing is safe.
        count_args = args
        count_clauses = clauses

        if isinstance(self.database_engine, PostgresEngine):
            sql = (
                "SELECT ts_rank_cd(vector, to_tsquery('english', ?)) AS rank,"
                " room_id, event_id"
                " FROM event_search"
                " WHERE vector @@ to_tsquery('english', ?)")
            args = [search_query, search_query] + args

            count_sql = ("SELECT room_id, count(*) as count FROM event_search"
                         " WHERE vector @@ to_tsquery('english', ?)")
            count_args = [search_query] + count_args
        elif isinstance(self.database_engine, Sqlite3Engine):
            sql = (
                "SELECT rank(matchinfo(event_search)) as rank, room_id, event_id"
                " FROM event_search"
                " WHERE value MATCH ?")
            args = [search_query] + args

            count_sql = ("SELECT room_id, count(*) as count FROM event_search"
                         " WHERE value MATCH ?")
            # The count must MATCH on the same parsed query as the main
            # search; binding the raw search term would count a different set
            # of rows than the ones returned.
            count_args = [search_query] + count_args
        else:
            # This should be unreachable.
            raise Exception("Unrecognized database engine")

        for clause in clauses:
            sql += " AND " + clause

        for clause in count_clauses:
            count_sql += " AND " + clause

        # We add an arbitrary limit here to ensure we don't try to pull the
        # entire table from the database.
        sql += " ORDER BY rank DESC LIMIT 500"

        results = yield self.db.execute("search_msgs", self.db.cursor_to_dict,
                                        sql, *args)

        # The room filter is skipped in SQL when there are too many rooms, so
        # always re-apply it here.
        results = list(filter(lambda row: row["room_id"] in room_ids, results))

        # We set redact_behaviour to BLOCK here to prevent redacted events being returned in
        # search results (which is a data leak)
        events = yield self.get_events_as_list(
            [r["event_id"] for r in results],
            redact_behaviour=EventRedactBehaviour.BLOCK,
        )

        event_map = {ev.event_id: ev for ev in events}

        highlights = None
        if isinstance(self.database_engine, PostgresEngine):
            highlights = yield self._find_highlights_in_postgres(
                search_query, events)

        count_sql += " GROUP BY room_id"

        count_results = yield self.db.execute("search_rooms_count",
                                              self.db.cursor_to_dict,
                                              count_sql, *count_args)

        count = sum(row["count"] for row in count_results
                    if row["room_id"] in room_ids)

        return {
            "results": [{
                "event": event_map[r["event_id"]],
                "rank": r["rank"]
            } for r in results if r["event_id"] in event_map],
            "highlights":
            highlights,
            "count":
            count,
        }
Example #12
0
    def _add_messages_to_local_device_inbox_txn(self, txn, stream_id,
                                                messages_by_user_then_device):
        # Compatible method of performing an upsert
        sql = "SELECT stream_id FROM device_max_stream_id"

        txn.execute(sql)
        rows = txn.fetchone()
        if rows:
            db_stream_id = rows[0]
            if db_stream_id < stream_id:
                # Insert the new stream_id
                sql = "UPDATE device_max_stream_id SET stream_id = ?"
        else:
            # No rows, perform an insert
            sql = "INSERT INTO device_max_stream_id (stream_id) VALUES (?)"

        txn.execute(sql, (stream_id, ))

        local_by_user_then_device = {}
        for user_id, messages_by_device in messages_by_user_then_device.items(
        ):
            messages_json_for_user = {}
            devices = list(messages_by_device.keys())
            if len(devices) == 1 and devices[0] == "*":
                # Handle wildcard device_ids.
                sql = "SELECT device_id FROM devices" " WHERE user_id = ?"
                txn.execute(sql, (user_id, ))
                message_json = json.dumps(messages_by_device["*"])
                for row in txn:
                    # Add the message for all devices for this user on this
                    # server.
                    device = row[0]
                    messages_json_for_user[device] = message_json
            else:
                if not devices:
                    continue

                clause, args = make_in_list_sql_clause(txn.database_engine,
                                                       "device_id", devices)
                sql = "SELECT device_id FROM devices WHERE user_id = ? AND " + clause

                # TODO: Maybe this needs to be done in batches if there are
                # too many local devices for a given user.
                txn.execute(sql, [user_id] + list(args))
                for row in txn:
                    # Only insert into the local inbox if the device exists on
                    # this server
                    device = row[0]
                    message_json = json.dumps(messages_by_device[device])
                    messages_json_for_user[device] = message_json

            if messages_json_for_user:
                local_by_user_then_device[user_id] = messages_json_for_user

        if not local_by_user_then_device:
            return

        sql = ("INSERT INTO device_inbox"
               " (user_id, device_id, stream_id, message_json)"
               " VALUES (?,?,?,?)")
        rows = []
        for user_id, messages_by_device in local_by_user_then_device.items():
            for device_id, message_json in messages_by_device.items():
                rows.append((user_id, device_id, stream_id, message_json))

        txn.executemany(sql, rows)
Example #13
0
    def _get_auth_chain_difference_using_cover_index_txn(
            self, txn: Cursor, room_id: str,
            state_sets: List[Set[str]]) -> Set[str]:
        """Calculates the auth chain difference using the chain index.

        See docs/auth_chain_difference_algorithm.md for details

        Args:
            txn: database cursor used to query `event_auth_chains` and
                `event_auth_chain_links`.
            room_id: only used for logging and for the exception raised when
                chain information is missing.
            state_sets: the sets of event IDs to compare. Must contain at
                least one set (`state_sets[0]` is read unconditionally).

        Returns:
            The event IDs that fall in the gap, per chain, between the
            sequence number reachable from *all* state sets and the sequence
            number reachable from *any* state set.

        Raises:
            _NoChainCoverIndex: if any of the given events has no row in
                `event_auth_chains`.
        """

        # First we look up the chain ID/sequence numbers for all the events, and
        # work out the chain/sequence numbers reachable from each state set.

        initial_events = set(state_sets[0]).union(*state_sets[1:])

        # Map from event_id -> (chain ID, seq no)
        chain_info = {}  # type: Dict[str, Tuple[int, int]]

        # Map from chain ID -> seq no -> event Id
        chain_to_event = {}  # type: Dict[int, Dict[int, str]]

        # All the chains that we've found that are reachable from the state
        # sets.
        seen_chains = set()  # type: Set[int]

        sql = """
            SELECT event_id, chain_id, sequence_number
            FROM event_auth_chains
            WHERE %s
        """
        # Query in batches of 1000 event IDs per IN-list.
        for batch in batch_iter(initial_events, 1000):
            clause, args = make_in_list_sql_clause(txn.database_engine,
                                                   "event_id", batch)
            txn.execute(sql % (clause, ), args)

            for event_id, chain_id, sequence_number in txn:
                chain_info[event_id] = (chain_id, sequence_number)
                seen_chains.add(chain_id)
                chain_to_event.setdefault(chain_id,
                                          {})[sequence_number] = event_id

        # Check that we actually have a chain ID for all the events.
        events_missing_chain_info = initial_events.difference(chain_info)
        if events_missing_chain_info:
            # This can happen due to e.g. downgrade/upgrade of the server. We
            # raise an exception and fall back to the previous algorithm.
            logger.info(
                "Unexpectedly found that events don't have chain IDs in room %s: %s",
                room_id,
                events_missing_chain_info,
            )
            raise _NoChainCoverIndex(room_id)

        # Corresponds to `state_sets`, except as a map from chain ID to max
        # sequence number reachable from the state set.
        set_to_chain = []  # type: List[Dict[int, int]]
        for state_set in state_sets:
            chains = {}  # type: Dict[int, int]
            set_to_chain.append(chains)

            for event_id in state_set:
                chain_id, seq_no = chain_info[event_id]

                # 0 stands for "nothing in this chain reached yet"
                # (presumably real sequence numbers start at 1 -- TODO confirm).
                chains[chain_id] = max(seq_no, chains.get(chain_id, 0))

        # Now we look up all links for the chains we have, adding chains to
        # set_to_chain that are reachable from each set.
        sql = """
            SELECT
                origin_chain_id, origin_sequence_number,
                target_chain_id, target_sequence_number
            FROM event_auth_chain_links
            WHERE %s
        """

        # (We need to take a copy of `seen_chains` as we want to mutate it in
        # the loop)
        for batch in batch_iter(set(seen_chains), 1000):
            clause, args = make_in_list_sql_clause(txn.database_engine,
                                                   "origin_chain_id", batch)
            txn.execute(sql % (clause, ), args)

            for (
                    origin_chain_id,
                    origin_sequence_number,
                    target_chain_id,
                    target_sequence_number,
            ) in txn:
                for chains in set_to_chain:
                    # chains are only reachable if the origin sequence number of
                    # the link is less than or equal to the max sequence number
                    # this state set reaches in the origin chain.
                    if origin_sequence_number <= chains.get(
                            origin_chain_id, 0):
                        chains[target_chain_id] = max(
                            target_sequence_number,
                            chains.get(target_chain_id, 0),
                        )

                # Recorded regardless of whether any state set could follow
                # the link; such chains end up with min == max == 0 below.
                seen_chains.add(target_chain_id)

        # Now for each chain we figure out the maximum sequence number reachable
        # from *any* state set and the minimum sequence number reachable from
        # *all* state sets. Events in that range are in the auth chain
        # difference.
        result = set()

        # Mapping from chain ID to the range of sequence numbers that should be
        # pulled from the database.
        chain_to_gap = {}  # type: Dict[int, Tuple[int, int]]

        for chain_id in seen_chains:
            min_seq_no = min(
                chains.get(chain_id, 0) for chains in set_to_chain)
            max_seq_no = max(
                chains.get(chain_id, 0) for chains in set_to_chain)

            if min_seq_no < max_seq_no:
                # We have a non empty gap, try and fill it from the events that
                # we have, otherwise add them to the list of gaps to pull out
                # from the DB.
                # The gap is the half-open range (min_seq_no, max_seq_no].
                for seq_no in range(min_seq_no + 1, max_seq_no + 1):
                    event_id = chain_to_event.get(chain_id, {}).get(seq_no)
                    if event_id:
                        result.add(event_id)
                    else:
                        # At least one event in the gap is unknown locally:
                        # fetch the whole gap from the DB instead. Events
                        # already added above are fetched again, which is
                        # harmless because `result` is a set.
                        chain_to_gap[chain_id] = (min_seq_no, max_seq_no)
                        break

        if not chain_to_gap:
            # If there are no gaps to fetch, we're done!
            return result

        if isinstance(self.database_engine, PostgresEngine):
            # We can use `execute_values` to efficiently fetch the gaps when
            # using postgres.
            sql = """
                SELECT event_id
                FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, min_seq, max_seq)
                WHERE
                    c.chain_id = l.chain_id
                    AND min_seq < sequence_number AND sequence_number <= max_seq
            """

            args = [(chain_id, min_no, max_no)
                    for chain_id, (min_no, max_no) in chain_to_gap.items()]

            rows = txn.execute_values(sql, args)
            result.update(r for r, in rows)
        else:
            # For SQLite we just fall back to doing a noddy for loop.
            sql = """
                SELECT event_id FROM event_auth_chains
                WHERE chain_id = ? AND ? < sequence_number AND sequence_number <= ?
            """
            for chain_id, (min_no, max_no) in chain_to_gap.items():
                txn.execute(sql, (chain_id, min_no, max_no))
                result.update(r for r, in txn)

        return result
Example #14
0
    def _fetch_event_rows(self, txn, event_ids):
        """Load the stored rows for a collection of events.

        Event IDs that have no row in `event_json` are simply absent from the
        result.

        Every value in the returned map is a dict with these keys:

         * event_id (str)

         * json (str): json-encoded event structure

         * internal_metadata (str): json-encoded internal metadata dict

         * format_version (int|None): The format of the event. Hopefully one
           of EventFormatVersions. 'None' means the event predates
           EventFormatVersions (so the event is format V1).

         * rejected_reason (str|None): the reason from the `rejections`
           table if the event was rejected.

         * redactions (List[str]): IDs of events which (claim to) redact
           this event.

        Args:
            txn: database transaction; it is executed against, iterated for
                result rows, and its `database_engine` is used to build the
                IN-list clauses.
            event_ids (Iterable[str]): event IDs to fetch

        Returns:
            Dict[str, Dict]: a map from event id to event info.
        """
        # Both statements are constant apart from the IN-list clause, which is
        # appended per batch.
        events_sql = ("SELECT "
                      " e.event_id, "
                      " e.internal_metadata,"
                      " e.json,"
                      " e.format_version, "
                      " rej.reason "
                      " FROM event_json as e"
                      " LEFT JOIN rejections as rej USING (event_id)"
                      " WHERE ")
        redactions_sql = "SELECT event_id, redacts FROM redactions WHERE "

        # Result keys, in the same order as the selected columns.
        column_keys = (
            "event_id",
            "internal_metadata",
            "json",
            "format_version",
            "rejected_reason",
        )

        rows_by_event_id = {}
        for chunk in batch_iter(event_ids, 200):
            in_clause, in_args = make_in_list_sql_clause(
                txn.database_engine, "e.event_id", chunk)
            txn.execute(events_sql + in_clause, in_args)

            for row in txn:
                info = dict(zip(column_keys, row))
                info["redactions"] = []
                rows_by_event_id[row[0]] = info

            # Attach the IDs of any events redacting the ones in this batch.
            in_clause, in_args = make_in_list_sql_clause(
                txn.database_engine, "redacts", chunk)
            txn.execute(redactions_sql + in_clause, in_args)

            for redacting_id, redacted_id in txn:
                target = rows_by_event_id.get(redacted_id)
                if not target:
                    # Redaction of an event we have no row for: ignore it.
                    continue
                target["redactions"].append(redacting_id)

        return rows_by_event_id
Example #15
0
    def _get_rooms_for_user_where_membership_is_txn(
        self, txn, user_id, membership_list
    ):
        """List the rooms in which a user has one of the given memberships.

        Invites are looked up in `local_invites`; every other membership is
        read from `current_state_events` (joined to `room_memberships` when
        the membership column of `current_state_events` is not yet up to
        date).

        Args:
            txn: database transaction to run the queries on.
            user_id: the user whose rooms are wanted.
            membership_list: the memberships to match.

        Returns:
            A list of `RoomsForUser`, non-invite matches first, followed by
            pending invites.
        """
        wants_invites = Membership.INVITE in membership_list
        non_invite_memberships = [
            m for m in membership_list if m != Membership.INVITE
        ]

        results = []
        if non_invite_memberships:
            # Pick which table supplies the membership column; everything
            # else about the lookup is shared.
            if self._current_state_events_membership_up_to_date:
                membership_column = "c.membership"
                query_template = """
                    SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering
                    FROM current_state_events AS c
                    INNER JOIN events AS e USING (room_id, event_id)
                    WHERE
                        c.type = 'm.room.member'
                        AND state_key = ?
                        AND %s
                """
            else:
                membership_column = "m.membership"
                query_template = """
                    SELECT room_id, e.sender, m.membership, event_id, e.stream_ordering
                    FROM current_state_events AS c
                    INNER JOIN room_memberships AS m USING (room_id, event_id)
                    INNER JOIN events AS e USING (room_id, event_id)
                    WHERE
                        c.type = 'm.room.member'
                        AND state_key = ?
                        AND %s
                """

            in_clause, in_args = make_in_list_sql_clause(
                self.database_engine, membership_column, non_invite_memberships
            )
            txn.execute(query_template % (in_clause,), (user_id, *in_args))
            for row in self.db.cursor_to_dict(txn):
                results.append(RoomsForUser(**row))

        if wants_invites:
            invite_sql = (
                "SELECT i.room_id, inviter, i.event_id, e.stream_ordering"
                " FROM local_invites as i"
                " INNER JOIN events as e USING (event_id)"
                " WHERE invitee = ? AND locally_rejected is NULL"
                " AND replaced_by is NULL"
            )

            txn.execute(invite_sql, (user_id,))
            for row in self.db.cursor_to_dict(txn):
                # The invite table has no membership column, so it is filled
                # in explicitly, and the inviter is reported as the sender.
                invite = RoomsForUser(
                    room_id=row["room_id"],
                    sender=row["inviter"],
                    event_id=row["event_id"],
                    stream_ordering=row["stream_ordering"],
                    membership=Membership.INVITE,
                )
                results.append(invite)

        return results
Example #16
0
    def _get_auth_chain_ids_using_cover_index_txn(
            self, txn: Cursor, room_id: str, event_ids: Collection[str],
            include_given: bool) -> List[str]:
        """Work out the auth chain of the given events via the chain index.

        Args:
            txn: database cursor used to query `event_auth_chains` and
                `event_auth_chain_links`.
            room_id: only used for logging and for the exception raised when
                chain information is missing.
            event_ids: the events whose auth chain is wanted.
            include_given: whether the given events themselves are always
                part of the result.

        Returns:
            The event IDs found, in no particular order.

        Raises:
            _NoChainCoverIndex: if any given event has no row in
                `event_auth_chains`.
        """

        # Step 1: find the chain ID / sequence number of each given event.

        requested = set(event_ids)

        # Given events for which a row exists in `event_auth_chains`.
        indexed_events = set()  # type: Set[str]

        # chain ID -> highest sequence number among the given events.
        max_seq_by_chain = {}  # type: Dict[int, int]

        chain_lookup_sql = """
            SELECT event_id, chain_id, sequence_number
            FROM event_auth_chains
            WHERE %s
        """
        for event_batch in batch_iter(requested, 1000):
            in_clause, in_args = make_in_list_sql_clause(
                txn.database_engine, "event_id", event_batch)
            txn.execute(chain_lookup_sql % (in_clause, ), in_args)

            for found_event, found_chain, found_seq in txn:
                indexed_events.add(found_event)
                max_seq_by_chain[found_chain] = max(
                    found_seq, max_seq_by_chain.get(found_chain, 0))

        # Every given event must be covered by the index.
        unindexed = requested.difference(indexed_events)
        if unindexed:
            # This can happen due to e.g. downgrade/upgrade of the server. We
            # raise an exception so the caller can fall back to the previous
            # algorithm.
            logger.info(
                "Unexpectedly found that events don't have chain IDs in room %s: %s",
                room_id,
                unindexed,
            )
            raise _NoChainCoverIndex(room_id)

        # Step 2: follow the links out of the chains found above.
        links_sql = """
            SELECT
                origin_chain_id, origin_sequence_number,
                target_chain_id, target_sequence_number
            FROM event_auth_chain_links
            WHERE %s
        """

        # chain ID -> highest sequence number *reachable* from any given event.
        reachable = {}  # type: Dict[int, int]

        for chain_batch in batch_iter(max_seq_by_chain, 1000):
            in_clause, in_args = make_in_list_sql_clause(
                txn.database_engine, "origin_chain_id", chain_batch)
            txn.execute(links_sql % (in_clause, ), in_args)

            for link in txn:
                origin_chain, origin_seq, target_chain, target_seq = link

                # A link can only be followed when it starts at or below the
                # highest sequence number we hold in its origin chain.
                if origin_seq > max_seq_by_chain.get(origin_chain, 0):
                    continue

                reachable[target_chain] = max(
                    target_seq, reachable.get(target_chain, 0))

        # Step 3: the chains of the given events are reachable too, up to but
        # excluding the sequence number of the given event itself.
        for chain_id, top_seq in max_seq_by_chain.items():
            reachable[chain_id] = max(top_seq - 1, reachable.get(chain_id, 0))

        # Step 4: everything at or below the reachable sequence number of each
        # chain is in the auth chain.
        results = requested if include_given else set()

        if not isinstance(self.database_engine, PostgresEngine):
            # For SQLite we just fall back to doing a noddy for loop.
            sql = """
                SELECT event_id FROM event_auth_chains
                WHERE chain_id = ? AND sequence_number <= ?
            """
            for chain_id, max_no in reachable.items():
                txn.execute(sql, (chain_id, max_no))
                for (event_id, ) in txn:
                    results.add(event_id)
        else:
            # We can use `execute_values` to fetch all chains in one go when
            # using postgres.
            sql = """
                SELECT event_id
                FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, max_seq)
                WHERE
                    c.chain_id = l.chain_id
                    AND sequence_number <= max_seq
            """

            rows = txn.execute_values(sql, reachable.items())
            for (event_id, ) in rows:
                results.add(event_id)

        return list(results)