コード例 #1
0
    def _get_bare_e2e_cross_signing_keys_bulk_txn(
        self,
        txn: Connection,
        user_ids: List[str],
    ) -> Dict[str, Dict[str, dict]]:
        """Returns the cross-signing keys for a set of users.  The output of this
        function should be passed to _get_e2e_cross_signing_signatures_txn if
        the signatures for the calling user need to be fetched.

        Args:
            txn (twisted.enterprise.adbapi.Connection): db connection
            user_ids (list[str]): the users whose keys are being requested

        Returns:
            dict[str, dict[str, dict]]: mapping from user ID to key type to key
                data.  If a user's cross-signing keys were not found, their user
                ID will not be in the dict.

        """
        result = {}

        batch_size = 100
        chunks = [
            user_ids[i:i + batch_size]
            for i in range(0, len(user_ids), batch_size)
        ]
        for user_chunk in chunks:
            sql = """
                SELECT k.user_id, k.keytype, k.keydata, k.stream_id
                  FROM e2e_cross_signing_keys k
                  INNER JOIN (SELECT user_id, keytype, MAX(stream_id) AS stream_id
                                FROM e2e_cross_signing_keys
                               GROUP BY user_id, keytype) s
                 USING (user_id, stream_id, keytype)
                 WHERE k.user_id IN (%s)
            """ % (",".join("?" for u in user_chunk), )
            query_params = []
            query_params.extend(user_chunk)

            txn.execute(sql, query_params)
            rows = self.db.cursor_to_dict(txn)

            for row in rows:
                user_id = row["user_id"]
                key_type = row["keytype"]
                key = json.loads(row["keydata"])
                user_info = result.setdefault(user_id, {})
                user_info[key_type] = key

        return result
コード例 #2
0
    def _get_bare_e2e_cross_signing_keys_bulk_txn(
        self, txn: Connection, user_ids: List[str],
    ) -> Dict[str, Dict[str, dict]]:
        """Returns the cross-signing keys for a set of users.  The output of this
        function should be passed to _get_e2e_cross_signing_signatures_txn if
        the signatures for the calling user need to be fetched.

        For every (user, key type) pair only the row with the highest
        stream_id is returned.

        Args:
            txn (twisted.enterprise.adbapi.Connection): db connection
            user_ids (list[str]): the users whose keys are being requested

        Returns:
            dict[str, dict[str, dict]]: mapping from user ID to key type to key
                data (JSON-decoded).  If a user's cross-signing keys were not
                found, their user ID will not be in the dict.

        """
        result = {}

        for user_chunk in batch_iter(user_ids, 100):
            # The clause targets the unqualified `user_id` column because it
            # is placed inside the aggregate subquery, which reads a single
            # table.
            clause, params = make_in_list_sql_clause(
                txn.database_engine, "user_id", user_chunk
            )
            # Filtering inside the subquery keeps the GROUP BY from having to
            # aggregate the whole table for every chunk.  The join on
            # (user_id, stream_id, keytype) then restricts `k` to the same
            # users, so the outer query needs no WHERE.
            sql = (
                """
                SELECT k.user_id, k.keytype, k.keydata, k.stream_id
                  FROM e2e_cross_signing_keys k
                  INNER JOIN (SELECT user_id, keytype, MAX(stream_id) AS stream_id
                                FROM e2e_cross_signing_keys
                               WHERE """
                + clause
                + """
                               GROUP BY user_id, keytype) s
                 USING (user_id, stream_id, keytype)
            """
            )

            txn.execute(sql, params)
            rows = self.db.cursor_to_dict(txn)

            for row in rows:
                user_id = row["user_id"]
                key_type = row["keytype"]
                key = json.loads(row["keydata"])
                user_info = result.setdefault(user_id, {})
                user_info[key_type] = key

        return result
コード例 #3
0
    def test_rollbackErrorLogged(self):
        """
        If an error happens during rollback, L{ConnectionLost} is raised but
        the original error is logged.
        """
        # A stand-in connection whose rollback always fails.
        class BrokenRollbackConnection:
            def rollback(self):
                raise RuntimeError("problem!")

        conn = Connection(FakePool(BrokenRollbackConnection))
        self.assertRaises(ConnectionLost, conn.rollback)

        # Exactly one RuntimeError should have been logged, and it must be
        # the one raised by the underlying connection.
        logged = self.flushLoggedErrors(RuntimeError)
        self.assertEqual(1, len(logged))
        failure = logged[0]
        self.assertEqual("problem!", failure.value.args[0])
コード例 #4
0
    def _get_e2e_cross_signing_signatures_txn(
        self,
        txn: Connection,
        keys: Dict[str, Dict[str, dict]],
        from_user_id: str,
    ) -> Dict[str, Dict[str, dict]]:
        """Returns the cross-signing signatures made by a user on a set of keys.

        Args:
            txn (twisted.enterprise.adbapi.Connection): db connection
            keys (dict[str, dict[str, dict]]): a map of user ID to key type to
                key data.  The top-level dict is modified in place: entries
                that gain signatures are replaced by modified copies.  The
                nested dicts it originally referenced are never mutated.
            from_user_id (str): fetch the signatures made by this user

        Returns:
            dict[str, dict[str, dict]]: mapping from user ID to key type to key
                data.  The return value will be the same as the keys argument,
                with the modifications included.
        """

        # find out what cross-signing keys (a.k.a. devices) we need to get
        # signatures for.  This is a map of (user_id, device_id) to key type
        # (device_id is the key's public part).
        devices = {}

        for user_id, user_info in keys.items():
            if user_info is None:
                continue
            for key_type, key in user_info.items():
                # The last value in the key's "keys" map is used as the
                # device ID.
                device_id = None
                for k in key["keys"].values():
                    device_id = k
                devices[(user_id, device_id)] = key_type

        for batch in batch_iter(devices.keys(), size=100):
            sql = """
                SELECT target_user_id, target_device_id, key_id, signature
                  FROM e2e_cross_signing_signatures
                 WHERE user_id = ?
                   AND (%s)
            """ % (" OR ".join("(target_user_id = ? AND target_device_id = ?)"
                               for _ in batch))
            query_params = [from_user_id]
            for item in batch:
                # item is a (user_id, device_id) tuple
                query_params.extend(item)

            txn.execute(sql, query_params)
            rows = self.db_pool.cursor_to_dict(txn)

            # and add the signatures to the appropriate keys
            for row in rows:
                key_id = row["key_id"]
                target_user_id = row["target_user_id"]
                target_device_id = row["target_device_id"]
                key_type = devices[(target_user_id, target_device_id)]
                # We need to copy everything, because the result may have come
                # from the cache.  dict.copy only does a shallow copy, so every
                # level that is about to be written to gets its own copy,
                # including the per-signer signature dict.
                user_info = keys[target_user_id] = keys[target_user_id].copy()
                target_user_key = user_info[key_type] = user_info[
                    key_type].copy()
                signatures = target_user_key["signatures"] = dict(
                    target_user_key.get("signatures", {})
                )
                user_sigs = signatures[from_user_id] = dict(
                    signatures.get(from_user_id, {})
                )
                user_sigs[key_id] = row["signature"]

        return keys
コード例 #5
0
    def _get_bare_e2e_cross_signing_keys_bulk_txn(
        self,
        txn: Connection,
        user_ids: List[str],
    ) -> Dict[str, Dict[str, dict]]:
        """Fetch the newest cross-signing key of each type for many users.

        Only the bare keys are returned.  To attach the signatures made on
        them by a particular user, pass the result on to
        _get_e2e_cross_signing_signatures_txn.

        Args:
            txn (twisted.enterprise.adbapi.Connection): db connection
            user_ids (list[str]): the users to look up

        Returns:
            dict[str, dict[str, dict]]: user ID -> key type -> decoded key
                data.  Users with no cross-signing keys do not appear in the
                mapping.
        """
        keys_by_user = {}

        for batch in batch_iter(user_ids, 100):
            where_clause, query_args = make_in_list_sql_clause(
                txn.database_engine, "user_id", batch
            )

            # Each query yields only the highest-stream_id row per
            # (user_id, keytype).
            if not isinstance(self.database_engine, PostgresEngine):
                # SQLite lets "bare" columns in a MIN/MAX aggregate take their
                # values from the row that supplied the MAX.
                sql_template = """
                    SELECT user_id, keytype, keydata, MAX(stream_id)
                        FROM e2e_cross_signing_keys
                        WHERE %(clause)s
                        GROUP BY user_id, keytype
                """
            else:
                # DISTINCT ON keeps the first row of each group, and the
                # descending stream_id ordering makes that the latest one.
                sql_template = """
                    SELECT DISTINCT ON (user_id, keytype) user_id, keytype, keydata, stream_id
                        FROM e2e_cross_signing_keys
                        WHERE %(clause)s
                        ORDER BY user_id, keytype, stream_id DESC
                """

            txn.execute(sql_template % {"clause": where_clause}, query_args)

            for row in self.db_pool.cursor_to_dict(txn):
                owner = row["user_id"]
                kind = row["keytype"]
                decoded = db_to_json(row["keydata"])
                keys_by_user.setdefault(owner, {})[kind] = decoded

        return keys_by_user