Beispiel #1
0
        def get_new_messages_for_remote_destination_txn(
            txn: LoggingTransaction,
        ) -> Tuple[List[JsonDict], int]:
            """Read one batch of queued outgoing device messages for ``destination``.

            Selects at most ``limit`` rows of ``device_federation_outbox`` whose
            stream_id lies in ``(last_stream_id, current_stream_id]``, oldest
            first, decoding each ``messages_json`` value with ``db_to_json``.

            Returns:
                The decoded messages and the stream position they reach. When
                the batch is full (``limit`` rows) the position is the stream_id
                of the last row read, since more rows may follow; otherwise the
                whole range has been consumed and ``current_stream_id`` is
                returned.
            """
            query = (
                "SELECT stream_id, messages_json FROM device_federation_outbox"
                " WHERE destination = ?"
                " AND ? < stream_id AND stream_id <= ?"
                " ORDER BY stream_id ASC"
                " LIMIT ?"
            )
            txn.execute(
                query, (destination, last_stream_id, current_stream_id, limit)
            )

            batch = []
            reached = current_stream_id
            for row_stream_id, row_json in txn:
                reached = row_stream_id
                batch.append(db_to_json(row_json))

            if len(batch) >= limit:
                # Full batch: further rows may still be queued after `reached`.
                return batch, reached

            # A short batch means nothing more is queued for this destination up
            # to current_stream_id, so report that whole range as consumed.
            log_kv({"message": "Set stream position to current position"})
            return batch, current_stream_id
Beispiel #2
0
    def notify(
        self,
        stream_key: str,
        stream_id: Union[int, RoomStreamToken],
        time_now_ms: int,
    ):
        """Wake this user's listeners after a new event arrives on a stream.

        Args:
            stream_key: Name of the stream the event came from.
            stream_id: New position of that stream.
            time_now_ms: Current time, in milliseconds.
        """
        advanced = self.current_token.copy_and_advance(stream_key, stream_id)
        self.current_token = advanced
        self.last_notified_token = advanced
        self.last_notified_ms = time_now_ms

        # Keep hold of the deferred current listeners are waiting on; it is
        # replaced below and then fired.
        old_deferred = self.notify_deferred

        log_kv(
            {
                "notify": self.user_id,
                "stream": stream_key,
                "stream_id": stream_id,
                "listeners": self.count_listeners(),
            }
        )

        users_woken_by_stream_counter.labels(stream_key).inc()

        with PreserveLoggingContext():
            # The fresh deferred is installed before firing the old one,
            # presumably so anything that re-waits from inside the callback
            # attaches to the new deferred rather than the already-fired one.
            self.notify_deferred = ObservableDeferred(defer.Deferred())
            old_deferred.callback(self.current_token)
Beispiel #3
0
    async def delete_device(self, user_id: str, device_id: str) -> None:
        """Remove a device, its access tokens and its E2E keys, then notify.

        A device that the store reports as missing (StoreError with code 404)
        is tolerated: the fact is recorded on the tracing span and the rest of
        the cleanup still runs.

        Args:
            user_id: Owner of the device.
            device_id: ID of the device to remove.

        Raises:
            errors.StoreError: for any store failure other than a 404.
        """
        try:
            await self.store.delete_device(user_id, device_id)
        except errors.StoreError as err:
            if err.code != 404:
                raise
            # No such device in the store; note it and carry on with cleanup.
            set_tag("error", True)
            log_kv(
                {"reason": "User doesn't have device id.", "device_id": device_id}
            )

        await self._auth_handler.delete_access_tokens_for_user(
            user_id, device_id=device_id
        )
        await self.store.delete_e2e_keys_by_device(user_id=user_id, device_id=device_id)
        await self.notify_device_update(user_id, [device_id])
Beispiel #4
0
 def _claim_e2e_one_time_keys(txn):
     """Claim (select, then delete) one-time keys for each requested triple.

     For every ``(user_id, device_id, algorithm)`` in ``query_list`` at most one
     row of ``e2e_one_time_keys_json`` is selected (``LIMIT 1``). Each selected
     row is then deleted and the ``count_e2e_one_time_keys`` cache entry for
     that user/device is invalidated.

     Returns:
         dict of user_id -> device_id -> {"<algorithm>:<key_id>": key_json}.
         Every queried user/device gets an entry, which is empty when no key
         was found.
     """
     select_sql = (
         "SELECT key_id, key_json FROM e2e_one_time_keys_json"
         " WHERE user_id = ? AND device_id = ? AND algorithm = ?"
         " LIMIT 1"
     )
     delete_sql = (
         "DELETE FROM e2e_one_time_keys_json"
         " WHERE user_id = ? AND device_id = ? AND algorithm = ?"
         " AND key_id = ?"
     )

     claimed = {}
     to_delete = []
     for user_id, device_id, algorithm in query_list:
         per_device = claimed.setdefault(user_id, {}).setdefault(device_id, {})
         txn.execute(select_sql, (user_id, device_id, algorithm))
         for key_id, key_json in txn:
             per_device[algorithm + ":" + key_id] = key_json
             to_delete.append((user_id, device_id, algorithm, key_id))

     for delete_args in to_delete:
         owner, owner_device = delete_args[0], delete_args[1]
         log_kv(
             {
                 "message": "Executing claim e2e_one_time_keys transaction on database."
             }
         )
         txn.execute(delete_sql, delete_args)
         log_kv({"message": "finished executing and invalidating cache"})
         self._invalidate_cache_and_stream(
             txn, self.count_e2e_one_time_keys, (owner, owner_device)
         )
     return claimed
Beispiel #5
0
    async def on_POST(self, request, device_id):
        """Upload end-to-end keys for the requester's device.

        Passing ``device_id`` (presumably from the request path) is deprecated
        but still accepted for older clients. When it is omitted, the device
        the requester authenticated with is used. When it is given and differs
        from the authenticated device, a warning is logged and the passed
        value is still used.

        Returns:
            ``(200, result)``, where ``result`` comes from
            ``e2e_keys_handler.upload_keys_for_user``.

        Raises:
            SynapseError: (400) if no device ID is available from either source.
        """
        requester = await self.auth.get_user_by_req(request, allow_guest=True)
        user_id = requester.user.to_string()
        body = parse_json_object_from_request(request)

        logged_in_device = requester.device_id
        if device_id is None:
            device_id = logged_in_device
        elif logged_in_device is not None and device_id != logged_in_device:
            set_tag("error", True)
            log_kv(
                {
                    "message": "Client uploading keys for a different device",
                    "logged_in_id": logged_in_device,
                    "key_being_uploaded": device_id,
                }
            )
            logger.warning(
                "Client uploading keys for a different device "
                "(logged in as %s, uploading for %s)",
                logged_in_device,
                device_id,
            )

        if device_id is None:
            raise SynapseError(
                400, "To upload keys, you must pass device_id when authenticating"
            )

        result = await self.e2e_keys_handler.upload_keys_for_user(
            user_id, device_id, body
        )
        return 200, result
Beispiel #6
0
    def get_e2e_one_time_keys(self, user_id, device_id, key_ids):
        """Look up specific one-time keys belonging to one device.

        This is a generator yielding the store's Deferred; it is presumably
        wrapped by an inlineCallbacks-style decorator outside this view.

        Args:
            user_id(str): owner of the keys
            device_id(str): device the keys belong to
            key_ids(list[str]): key IDs to fetch, without the algorithm prefix

        Returns:
            deferred resolving to Dict[(str, str), str]: (algorithm, key_id) ->
            stored key JSON string, for the keys that were found
        """
        # NOTE(review): the desc label reads like it was copied from
        # add_e2e_one_time_keys; left unchanged since it feeds metrics/logs.
        rows = yield self._simple_select_many_batch(
            table="e2e_one_time_keys_json",
            column="key_id",
            iterable=key_ids,
            retcols=("algorithm", "key_id", "key_json"),
            keyvalues={"user_id": user_id, "device_id": device_id},
            desc="add_e2e_one_time_keys_check",
        )

        keys_by_id = {}
        for row in rows:
            keys_by_id[(row["algorithm"], row["key_id"])] = row["key_json"]

        log_kv({"message": "Fetched one time keys for user", "one_time_keys": keys_by_id})
        return keys_by_id
Beispiel #7
0
 def delete_e2e_keys_by_device_txn(txn):
     """Delete every stored E2E key of one device inside ``txn``.

     Rows for the device are removed from the device-key, one-time-key,
     dehydrated-device and fallback-key tables. The one-time key count cache
     and the unused fallback key types cache are invalidated for the device.
     """
     log_kv(
         {
             "message": "Deleting keys for device",
             "device_id": device_id,
             "user_id": user_id,
         }
     )
     cache_key = (user_id, device_id)

     def _delete_rows(table):
         # Remove this device's rows from one table.
         self.db_pool.simple_delete_txn(
             txn,
             table=table,
             keyvalues={"user_id": user_id, "device_id": device_id},
         )

     _delete_rows("e2e_device_keys_json")
     _delete_rows("e2e_one_time_keys_json")
     self._invalidate_cache_and_stream(txn, self.count_e2e_one_time_keys, cache_key)
     _delete_rows("dehydrated_devices")
     _delete_rows("e2e_fallback_keys_json")
     self._invalidate_cache_and_stream(
         txn, self.get_e2e_unused_fallback_key_types, cache_key
     )
Beispiel #8
0
    def delete_device(self, user_id, device_id):
        """Delete a device along with its access tokens and E2E keys.

        If the store reports the device as missing (StoreError with code 404)
        this is not fatal: the error is tagged on the tracing span and the
        remaining cleanup still runs. Device-list watchers are notified last.

        Args:
            user_id (str): owner of the device
            device_id (str): ID of the device to delete

        Returns:
            defer.Deferred: this generator is presumably wrapped by an
            inlineCallbacks-style decorator outside this view.
        """
        try:
            yield self.store.delete_device(user_id, device_id)
        except errors.StoreError as err:
            if err.code != 404:
                raise
            # Nothing to delete in the store; note it and continue the cleanup.
            set_tag("error", True)
            log_kv(
                {"reason": "User doesn't have device id.", "device_id": device_id}
            )

        yield self._auth_handler.delete_access_tokens_for_user(
            user_id, device_id=device_id
        )
        yield self.store.delete_e2e_keys_by_device(user_id=user_id, device_id=device_id)
        yield self.notify_device_update(user_id, [device_id])
Beispiel #9
0
        def _set_e2e_device_keys_txn(txn):
            """Store the device's keys unless identical JSON is already stored.

            Returns:
                bool: False if the stored ``key_json`` already equals the
                canonical JSON encoding of ``device_keys`` (nothing is
                written); True after upserting the new JSON together with
                ``time_now`` as ``ts_added_ms``.
            """
            set_tag("user_id", user_id)
            set_tag("device_id", device_id)
            set_tag("time_now", time_now)
            set_tag("device_keys", device_keys)

            old_key_json = self._simple_select_one_onecol_txn(
                txn,
                table="e2e_device_keys_json",
                keyvalues={"user_id": user_id, "device_id": device_id},
                retcol="key_json",
                allow_none=True,
            )

            # In py3 we need old_key_json to match new_key_json type. The DB
            # returns unicode while encode_canonical_json returns bytes.
            new_key_json = encode_canonical_json(device_keys).decode("utf-8")

            if old_key_json == new_key_json:
                # Use the lower-case "message" field like every other log_kv
                # call, so tracing consumers keyed on it see this event too.
                log_kv({"message": "Device key already stored."})
                return False

            self._simple_upsert_txn(
                txn,
                table="e2e_device_keys_json",
                keyvalues={"user_id": user_id, "device_id": device_id},
                values={"ts_added_ms": time_now, "key_json": new_key_json},
            )
            log_kv({"message": "Device keys stored."})
            return True
Beispiel #10
0
    def notify_device_update(self, user_id, device_ids):
        """Record a change to some of a user's devices and tell interested parties.

        The change is added to the device-list streams and the notifier is
        poked for every room the user is in. When the user is local, every
        other server sharing a room with them is also sent a device-list
        update via the federation sender.

        Args:
            user_id (str): the user whose devices changed
            device_ids (list[str]): the devices that changed
        """
        users_who_share_room = yield self.store.get_users_who_share_room_with_user(
            user_id
        )

        if self.hs.is_mine_id(user_id):
            # Servers of everyone we share a room with, excluding ourselves.
            hosts = {get_domain_from_id(u) for u in users_who_share_room}
            hosts.discard(self.server_name)
        else:
            # Changes to remote users' devices are not forwarded to other servers.
            hosts = set()

        set_tag("target_hosts", hosts)

        position = yield self.store.add_device_change_to_streams(
            user_id, device_ids, list(hosts)
        )

        for changed_device in device_ids:
            logger.debug(
                "Notifying about update %r/%r, ID: %r", user_id, changed_device, position
            )

        room_ids = yield self.store.get_rooms_for_user(user_id)
        yield self.notifier.on_new_event("device_list_key", position, rooms=room_ids)

        if not hosts:
            return

        logger.info(
            "Sending device list update notif for %r to: %r", user_id, hosts
        )
        for host in hosts:
            self.federation_sender.send_device_messages(host)
            log_kv({"message": "sent device update to host", "host": host})
 def delete_e2e_keys_by_device_txn(txn):
     """Delete the device keys and one-time keys of one device inside ``txn``.

     Afterwards the ``count_e2e_one_time_keys`` cache entry for the device is
     invalidated via ``_invalidate_cache_and_stream``.
     """
     log_kv({
         "message": "Deleting keys for device",
         "device_id": device_id,
         "user_id": user_id,
     })
     for table in ("e2e_device_keys_json", "e2e_one_time_keys_json"):
         self.db.simple_delete_txn(
             txn,
             table=table,
             keyvalues={"user_id": user_id, "device_id": device_id},
         )
     self._invalidate_cache_and_stream(
         txn, self.count_e2e_one_time_keys, (user_id, device_id)
     )
Beispiel #12
0
    async def on_claim_client_keys(self, origin: str,
                                   content: JsonDict) -> Dict[str, Any]:
        """Claim one-time keys for local devices on behalf of a remote request.

        Args:
            origin: presumably the requesting server's name; not used here.
            content: the request body. Its ``one_time_keys`` field maps
                user_id -> device_id -> algorithm.

        Returns:
            ``{"one_time_keys": {user_id: {device_id: {key_id: key}}}}``, where
            each claimed key's stored JSON has been decoded.
        """
        query = [
            (user_id, device_id, algorithm)
            for user_id, device_keys in content.get("one_time_keys", {}).items()
            for device_id, algorithm in device_keys.items()
        ]

        log_kv({
            "message": "Claiming one time keys.",
            "user, device pairs": query
        })
        results = await self.store.claim_e2e_one_time_keys(query)

        json_result = {}  # type: Dict[str, Dict[str, dict]]
        for user_id, device_keys in results.items():
            for device_id, keys in device_keys.items():
                for key_id, json_bytes in keys.items():
                    # Merge into the device's dict instead of replacing it, so
                    # a device with several claimed keys keeps all of them.
                    device_result = json_result.setdefault(user_id, {}).setdefault(
                        device_id, {}
                    )
                    device_result[key_id] = json.loads(json_bytes)

        logger.info(
            "Claimed one-time-keys: %s",
            ",".join(
                "%s for %s:%s" % (key_id, user_id, device_id)
                for user_id, user_keys in json_result.items()
                for device_id, device_keys in user_keys.items()
                for key_id in device_keys
            ),
        )

        return {"one_time_keys": json_result}
Beispiel #13
0
    async def send_device_message(
        self,
        sender_user_id: str,
        message_type: str,
        messages: Dict[str, Dict[str, JsonDict]],
    ) -> None:
        """Send to-device messages from a user.

        Messages for local users are wrapped with their type and sender, then
        written to the device inbox. Messages for remote users are grouped by
        destination server into EDU contents, which are stored in the same
        call. The notifier is poked for the local recipients. Each remote
        destination is queued with the federation sender, if there is one.

        Args:
            sender_user_id: The user sending the messages.
            message_type: The to-device event type.
            messages: Map of recipient user_id -> device_id -> message content.
        """
        # Note: this counts recipient users, not individual device messages.
        set_tag("number_of_messages", len(messages))
        set_tag("sender", sender_user_id)
        local_messages = {}
        remote_messages = {}  # type: Dict[str, Dict[str, Dict[str, JsonDict]]]
        for user_id, by_device in messages.items():
            # we use UserID.from_string to catch invalid user ids
            if self.is_mine(UserID.from_string(user_id)):
                messages_by_device = {
                    device_id: {
                        "content": message_content,
                        "type": message_type,
                        "sender": sender_user_id,
                    }
                    for device_id, message_content in by_device.items()
                }
                if messages_by_device:
                    local_messages[user_id] = messages_by_device
            else:
                destination = get_domain_from_id(user_id)
                remote_messages.setdefault(destination, {})[user_id] = by_device

        message_id = random_string(16)

        context = get_active_span_text_map()

        remote_edu_contents = {}
        # Use a distinct loop variable so the `messages` parameter is not rebound.
        for destination, destination_messages in remote_messages.items():
            with start_active_span("to_device_for_user"):
                set_tag("destination", destination)
                remote_edu_contents[destination] = {
                    "messages": destination_messages,
                    "sender": sender_user_id,
                    "type": message_type,
                    "message_id": message_id,
                    "org.matrix.opentracing_context":
                    json_encoder.encode(context),
                }

        log_kv({"local_messages": local_messages})
        stream_id = await self.store.add_messages_to_device_inbox(
            local_messages, remote_edu_contents)

        self.notifier.on_new_event("to_device_key",
                                   stream_id,
                                   users=local_messages.keys())

        log_kv({"remote_messages": remote_messages})
        if self.federation_sender:
            for destination in remote_messages.keys():
                # Enqueue a new federation transaction to send the new
                # device messages to each remote destination.
                self.federation_sender.send_device_messages(destination)
Beispiel #14
0
    def incoming_device_list_update(self, origin, edu_content):
        """Handle a device list update EDU received over federation.

        The EDU is dropped, with a warning and an error tag on the tracing
        span, in two cases: its user does not belong to ``origin``, or we share
        no rooms with that user. Otherwise the update is appended to
        ``_pending_updates`` and ``_handle_device_updates`` is run for the user.

        Note that ``user_id``, ``device_id``, ``stream_id`` and ``prev_id`` are
        popped from ``edu_content`` in place. Whatever remains is queued with
        the update.
        """
        set_tag("origin", origin)
        set_tag("edu_content", edu_content)

        user_id = edu_content.pop("user_id")
        device_id = edu_content.pop("device_id")
        # Stream IDs may arrive as ints; normalise them all to strings.
        stream_id = str(edu_content.pop("stream_id"))
        prev_ids = [str(prev) for prev in edu_content.pop("prev_id", [])]

        if get_domain_from_id(user_id) != origin:
            # TODO: Raise?
            logger.warning(
                "Got device list update edu for %r/%r from %r",
                user_id,
                device_id,
                origin,
            )
            set_tag("error", True)
            log_kv(
                {
                    "message": "Got a device list update edu from a user and "
                    "device which does not match the origin of the request.",
                    "user_id": user_id,
                    "device_id": device_id,
                }
            )
            return

        room_ids = yield self.store.get_rooms_for_user(user_id)
        if not room_ids:
            # No shared rooms, so the update is ignored; presumably no further
            # updates for this user will arrive either.
            set_tag("error", True)
            log_kv(
                {
                    "message": "Got an update from a user for which "
                    "we don't share any rooms",
                    "other user_id": user_id,
                }
            )
            logger.warning(
                "Got device list update edu for %r/%r, but don't share a room",
                user_id,
                device_id,
            )
            return

        logger.debug("Received device list update for %r/%r", user_id, device_id)

        pending = self._pending_updates.setdefault(user_id, [])
        pending.append((device_id, stream_id, prev_ids, edu_content))

        yield self._handle_device_updates(user_id)
Beispiel #15
0
    async def get_new_device_msgs_for_remote(self, destination, last_stream_id,
                                             current_stream_id,
                                             limit) -> Tuple[List[dict], int]:
        """Fetch device messages queued for sending to a remote server.

        Args:
            destination(str): The name of the remote server.
            last_stream_id(int): The device message stream position up to
                which this server has already been sent messages.
            current_stream_id(int): The current position of the device
                message stream.
            limit(int): The maximum number of messages to return.

        Returns:
            The messages found, and the stream position they got to. There are
            two shortcuts that skip the database:
            - If the outbox cache reports no change for ``destination``, or
              the two positions are equal, ``([], current_stream_id)`` is
              returned.
            - If ``limit`` is not positive, ``([], last_stream_id)`` is
              returned.
        """
        set_tag("destination", destination)
        set_tag("last_stream_id", last_stream_id)
        set_tag("current_stream_id", current_stream_id)
        set_tag("limit", limit)

        changed = self._device_federation_outbox_stream_cache.has_entity_changed(
            destination, last_stream_id
        )
        nothing_new = (not changed) or last_stream_id == current_stream_id
        if nothing_new:
            log_kv({"message": "No new messages in stream"})
            return [], current_stream_id

        if limit <= 0:
            # We end up here when the transaction has run out of room for EDUs.
            return [], last_stream_id

        @trace
        def get_new_messages_for_remote_destination_txn(txn):
            # Read up to `limit` rows in (last_stream_id, current_stream_id],
            # oldest first.
            query = (
                "SELECT stream_id, messages_json FROM device_federation_outbox"
                " WHERE destination = ?"
                " AND ? < stream_id AND stream_id <= ?"
                " ORDER BY stream_id ASC"
                " LIMIT ?"
            )
            txn.execute(
                query, (destination, last_stream_id, current_stream_id, limit)
            )

            batch = []
            reached = current_stream_id
            for row_stream_id, row_json in txn:
                reached = row_stream_id
                batch.append(db_to_json(row_json))

            if len(batch) >= limit:
                # Full batch: further rows may still be queued after `reached`.
                return batch, reached

            # A short batch means nothing more is queued for this destination
            # up to current_stream_id.
            log_kv({"message": "Set stream position to current position"})
            return batch, current_stream_id

        return await self.db_pool.runInteraction(
            "get_new_device_msgs_for_remote",
            get_new_messages_for_remote_destination_txn,
        )
Beispiel #16
0
    def on_new_event(
        self,
        stream_key: str,
        new_token: Union[int, RoomStreamToken],
        users: Optional[Collection[Union[str, UserID]]] = None,
        rooms: Optional[Collection[str]] = None,
    ):
        """Wake up the listeners of the given users and rooms after new activity.

        Each matching user stream is notified of ``new_token`` on
        ``stream_key``. A failure notifying one stream is logged and does not
        stop the others. ``notify_replication()`` and
        ``_notify_app_services_ephemeral()`` are then called.

        Args:
            stream_key: The stream that advanced.
            new_token: The new position of that stream.
            users: Users whose streams should be woken, if any.
            rooms: Rooms whose users' streams should be woken, if any.
        """
        users = users or []
        rooms = rooms or []

        with Measure(self.clock, "on_new_event"):
            log_kv(
                {
                    "waking_up_explicit_users": len(users),
                    "waking_up_explicit_rooms": len(rooms),
                }
            )

            # Gather each affected stream once, whether reached via a user or a room.
            to_wake = set()
            for user in users:
                stream = self.user_to_user_stream.get(str(user))
                if stream is not None:
                    to_wake.add(stream)
            for room in rooms:
                to_wake.update(self.room_to_user_streams.get(room, set()))

            if stream_key == "to_device_key":
                issue9533_logger.debug(
                    "to-device messages stream id %s, awaking streams for %s",
                    new_token,
                    users,
                )

            now_ms = self.clock.time_msec()
            for stream in to_wake:
                try:
                    stream.notify(stream_key, new_token, now_ms)
                except Exception:
                    logger.exception("Failed to notify listener")

            self.notify_replication()

            # Notify appservices
            self._notify_app_services_ephemeral(
                stream_key,
                new_token,
                users,
            )
Beispiel #17
0
    async def delete_messages_for_device(
        self, user_id: str, device_id: Optional[str], up_to_stream_id: int
    ) -> int:
        """Delete a device's inbox messages up to and including a stream position.

        Sometimes the database is not touched at all and 0 is returned. This
        happens when a previous deletion point is cached for this device and
        the inbox stream cache reports no change for the user since then.

        Args:
            user_id: The recipient user_id.
            device_id: The recipient device_id.
            up_to_stream_id: Messages with stream_id <= this value are deleted.

        Returns:
            The number of messages deleted.
        """
        cache_key = (user_id, device_id)

        previous_delete_pos = self._last_device_delete_cache.get(cache_key, None)
        set_tag("last_deleted_stream_id", previous_delete_pos)

        if previous_delete_pos and not self._device_inbox_stream_cache.has_entity_changed(
            user_id, previous_delete_pos
        ):
            log_kv({"message": "No changes in cache since last check"})
            return 0

        def delete_messages_for_device_txn(txn):
            txn.execute(
                "DELETE FROM device_inbox"
                " WHERE user_id = ? AND device_id = ?"
                " AND stream_id <= ?",
                (user_id, device_id, up_to_stream_id),
            )
            return txn.rowcount

        count = await self.db_pool.runInteraction(
            "delete_messages_for_device", delete_messages_for_device_txn
        )

        log_kv(
            {"message": "deleted {} messages for device".format(count), "count": count}
        )

        # Re-read the cached position after the await (presumably another call
        # may have advanced it meanwhile) and only ever move it forwards.
        already_deleted_to = self._last_device_delete_cache.get(cache_key, 0)
        self._last_device_delete_cache[cache_key] = max(
            already_deleted_to, up_to_stream_id
        )

        return count
Beispiel #18
0
    def query_local_devices(self, query):
        """Get E2E device keys for local users.

        Args:
            query (dict[str, list[str]|None]): map from user_id to the device
                IDs to query. An empty list or None means all of that user's
                devices.

        Returns:
            defer.Deferred: (resolves to dict[str, dict[str, dict]]): map from
                user_id -> device_id -> device details. Each queried user is
                present even if nothing was found. Each details dict is a copy
                of the stored keys plus an "unsigned" section, which carries
                the device display name when one is set.

        Raises:
            SynapseError: (400) if any queried user is not local to this server.
        """
        set_tag("local_query", query)
        local_query = []

        result_dict = {}
        for user_id, device_ids in query.items():
            # UserID.from_string also rejects invalid user IDs.
            if not self.is_mine(UserID.from_string(user_id)):
                logger.warning("Request for keys for non-local user %s",
                               user_id)
                log_kv({
                    "message": "Requested a local key for a user which"
                    " was not local to the homeserver",
                    "user_id": user_id,
                })
                set_tag("error", True)
                raise SynapseError(400, "Not a user here")

            if device_ids:
                local_query.extend((user_id, device_id) for device_id in device_ids)
            else:
                local_query.append((user_id, None))

            # Every queried user appears in the result, even with no devices.
            result_dict[user_id] = {}

        results = yield self.store.get_e2e_device_keys(local_query)

        # Copy each device's keys and attach an "unsigned" section.
        for user_id, device_keys in results.items():
            for device_id, device_info in device_keys.items():
                device_dict = dict(device_info["keys"])
                unsigned = {}
                device_dict["unsigned"] = unsigned
                display_name = device_info["device_display_name"]
                if display_name is not None:
                    unsigned["device_display_name"] = display_name
                result_dict[user_id][device_id] = device_dict

        log_kv(results)
        return result_dict
Beispiel #19
0
    def _get_e2e_device_keys_txn(
        self, txn, query_list, include_all_devices=False, include_deleted_devices=False
    ):
        """Fetch device display names and key JSON for (user_id, device_id) pairs.

        Args:
            txn: the database transaction to run the query on.
            query_list: collection of (user_id, device_id) pairs. A device_id
                of None selects all of that user's devices.
            include_all_devices: if true, the keys table is LEFT JOINed, so
                devices with no stored keys are returned too (with a null
                ``key_json``). Otherwise only devices that have keys are
                returned.
            include_deleted_devices: only honoured together with
                ``include_all_devices``. Queried pairs that match no row are
                returned with a value of None.

        Returns:
            dict mapping user_id -> device_id -> row dict. Each row dict has
            ``user_id``, ``device_id``, ``device_display_name`` and
            ``key_json``. The value is None for pairs reported as deleted.
        """
        set_tag("include_all_devices", include_all_devices)
        set_tag("include_deleted_devices", include_deleted_devices)

        if not query_list:
            # Nothing to look up; joining zero clauses would leave an empty
            # (invalid) WHERE clause.
            return {}

        # Use truthiness to match the LEFT/INNER choice below: detecting
        # deleted devices relies on the LEFT JOIN returning keyless devices,
        # otherwise those would be misreported as deleted.
        if not include_all_devices:
            include_deleted_devices = False

        # NOTE(review): assumes every pair has a concrete device_id whenever
        # include_deleted_devices is set; with a None device_id, .remove below
        # would raise KeyError for the rows it returns.
        deleted_devices = set(query_list) if include_deleted_devices else set()

        query_clauses = []
        query_params = []
        for (user_id, device_id) in query_list:
            query_clause = "user_id = ?"
            query_params.append(user_id)

            if device_id is not None:
                query_clause += " AND device_id = ?"
                query_params.append(device_id)

            query_clauses.append(query_clause)

        sql = (
            "SELECT user_id, device_id, "
            "    d.display_name AS device_display_name, "
            "    k.key_json"
            " FROM devices d"
            "    %s JOIN e2e_device_keys_json k USING (user_id, device_id)"
            " WHERE %s"
        ) % (
            "LEFT" if include_all_devices else "INNER",
            " OR ".join("(" + q + ")" for q in query_clauses),
        )

        txn.execute(sql, query_params)
        rows = self.cursor_to_dict(txn)

        result = {}
        for row in rows:
            if include_deleted_devices:
                deleted_devices.remove((row["user_id"], row["device_id"]))
            result.setdefault(row["user_id"], {})[row["device_id"]] = row

        # Anything queried but not found is reported as deleted.
        for user_id, device_id in deleted_devices:
            result.setdefault(user_id, {})[device_id] = None

        log_kv(result)
        return result
Beispiel #20
0
    def get_new_device_msgs_for_remote(self, destination, last_stream_id,
                                       current_stream_id, limit):
        """Fetch device messages queued for sending to a remote server.

        Args:
            destination(str): The name of the remote server.
            last_stream_id(int|long): The last position of the device message stream
                that the server sent up to.
            current_stream_id(int|long): The current position of the device
                message stream.
            limit(int): The maximum number of messages to return.
        Returns:
            Deferred ([dict], int|long): List of messages for the device and where
                in the stream the messages got to.
        """

        set_tag("destination", destination)
        set_tag("last_stream_id", last_stream_id)
        set_tag("current_stream_id", current_stream_id)
        set_tag("limit", limit)

        has_changed = self._device_federation_outbox_stream_cache.has_entity_changed(
            destination, last_stream_id)
        if not has_changed or last_stream_id == current_stream_id:
            log_kv({"message": "No new messages in stream"})
            return defer.succeed(([], current_stream_id))

        if limit <= 0:
            # This can happen if we run out of room for EDUs in the transaction.
            return defer.succeed(([], last_stream_id))

        @trace
        def get_new_messages_for_remote_destination_txn(txn):
            sql = (
                "SELECT stream_id, messages_json FROM device_federation_outbox"
                " WHERE destination = ?"
                " AND ? < stream_id AND stream_id <= ?"
                " ORDER BY stream_id ASC"
                " LIMIT ?")
            txn.execute(
                sql, (destination, last_stream_id, current_stream_id, limit))
            messages = []
            # Bind the position up front so it is always defined, even when no
            # rows come back.
            stream_pos = current_stream_id
            for row in txn:
                stream_pos = row[0]
                messages.append(json.loads(row[1]))
            # If the limit was not reached there is no more data for this
            # destination up to current_stream_id.
            if len(messages) < limit:
                log_kv({"message": "Set stream position to current position"})
                stream_pos = current_stream_id
            return messages, stream_pos

        return self.runInteraction(
            "get_new_device_msgs_for_remote",
            get_new_messages_for_remote_destination_txn,
        )
Beispiel #21
0
    async def add_e2e_room_keys(
            self, user_id: str, version: str,
            room_keys: Iterable[Tuple[str, str, RoomKey]]) -> None:
        """Bulk add room keys to a given backup.

        Args:
            user_id: the user whose backup we're adding to
            version: the version ID of the backup for the set of keys we're adding to
            room_keys: the keys to add, in the form (roomID, sessionID, keyData)

        Raises:
            StoreError: (404) if ``version`` is not an integer. Backup versions
                are all integers, so no such backup can exist.
        """
        try:
            version_int = int(version)
        except ValueError:
            # Our versions are all ints so if we can't convert it to an integer,
            # it doesn't exist.
            raise StoreError(404, "No backup with that version exists")

        values = []
        for (room_id, session_id, room_key) in room_keys:
            values.append((
                user_id,
                version_int,
                room_id,
                session_id,
                room_key["first_message_index"],
                room_key["forwarded_count"],
                room_key["is_verified"],
                json_encoder.encode(room_key["session_data"]),
            ))
            log_kv({
                "message": "Set room key",
                "room_id": room_id,
                "session_id": session_id,
                # A plain tracing field name; this is not a stream key, so it
                # must not be spelled via StreamKeyType.
                "room_key": room_key,
            })

        await self.db_pool.simple_insert_many(
            table="e2e_room_keys",
            keys=(
                "user_id",
                "version",
                "room_id",
                "session_id",
                "first_message_index",
                "forwarded_count",
                "is_verified",
                "session_data",
            ),
            values=values,
            desc="add_e2e_room_keys",
        )
Beispiel #22
0
    async def notify_device_update(
        self, user_id: str, device_ids: Collection[str]
    ) -> None:
        """Notify that a user's device(s) has changed. Pokes the notifier, and
        remote servers if the user is local.

        Args:
            user_id: The Matrix ID of the user whose device list has been updated.
            device_ids: The device IDs that have changed.
        """
        if not device_ids:
            # No changes to notify about, so this is a no-op.
            return

        hosts: Set[str] = set()
        if self.hs.is_mine_id(user_id):
            # Remote servers are only targeted for local users. So we only ask
            # the store who shares a room with this user when that answer will
            # actually be used; this skips the lookup entirely for remote users.
            users_who_share_room = await self.store.get_users_who_share_room_with_user(
                user_id
            )
            hosts.update(get_domain_from_id(u) for u in users_who_share_room)
            # Never target our own server.
            hosts.discard(self.server_name)

        set_tag("target_hosts", hosts)

        position = await self.store.add_device_change_to_streams(
            user_id, device_ids, list(hosts)
        )

        if not position:
            # This should only happen if there are no updates, so we bail.
            return

        for device_id in device_ids:
            logger.debug(
                "Notifying about update %r/%r, ID: %r", user_id, device_id, position
            )

        room_ids = await self.store.get_rooms_for_user(user_id)

        # specify the user ID too since the user should always get their own device list
        # updates, even if they aren't in any rooms.
        self.notifier.on_new_event(
            "device_list_key", position, users=[user_id], rooms=room_ids
        )

        if hosts:
            logger.info(
                "Sending device list update notif for %r to: %r", user_id, hosts
            )
            for host in hosts:
                self.federation_sender.send_device_messages(host)
                log_kv({"message": "sent device update to host", "host": host})
Beispiel #23
0
    def upload_keys_for_user(self, user_id, device_id, keys):
        """Store device keys and/or one-time keys uploaded for a device.

        NOTE(review): this function uses ``yield`` but no decorator (e.g.
        ``@defer.inlineCallbacks``) is visible on it here, so as written it is a
        plain generator function. Confirm the decorator is present where this is
        defined; an ``await`` on a bare generator object raises TypeError.

        Args:
            user_id: the user uploading the keys.
            device_id: the device the keys are being uploaded for.
            keys: the parsed upload body. Its optional ``"device_keys"`` and
                ``"one_time_keys"`` entries are each processed only if present
                and truthy.

        Returns:
            ``{"one_time_key_counts": counts}``, where ``counts`` is the result
            of ``store.count_e2e_one_time_keys`` for this device after the upload.
        """

        # Single timestamp shared by the device-key and one-time-key writes.
        time_now = self.clock.time_msec()

        # TODO: Validate the JSON to make sure it has the right keys.
        device_keys = keys.get("device_keys", None)
        if device_keys:
            logger.info(
                "Updating device_keys for device %r for user %s at %d",
                device_id,
                user_id,
                time_now,
            )
            log_kv(
                {
                    "message": "Updating device_keys for user.",
                    "user_id": user_id,
                    "device_id": device_id,
                }
            )
            # TODO: Sign the JSON with the server key
            changed = yield self.store.set_e2e_device_keys(
                user_id, device_id, time_now, device_keys
            )
            if changed:
                # Only notify about device updates *if* the keys actually changed
                yield self.device_handler.notify_device_update(user_id, [device_id])
        else:
            log_kv({"message": "Not updating device_keys for user", "user_id": user_id})
        one_time_keys = keys.get("one_time_keys", None)
        if one_time_keys:
            log_kv(
                {
                    "message": "Updating one_time_keys for device.",
                    "user_id": user_id,
                    "device_id": device_id,
                }
            )
            yield self._upload_one_time_keys_for_user(
                user_id, device_id, time_now, one_time_keys
            )
        else:
            log_kv(
                {"message": "Did not update one_time_keys", "reason": "no keys given"}
            )

        # the device should have been registered already, but it may have been
        # deleted due to a race with a DELETE request. Or we may be using an
        # old access_token without an associated device_id. Either way, we
        # need to double-check the device is registered to avoid ending up with
        # keys without a corresponding device.
        yield self.device_handler.check_device_registered(user_id, device_id)

        # Report the remaining one-time key counts back to the client.
        result = yield self.store.count_e2e_one_time_keys(user_id, device_id)

        set_tag("one_time_key_counts", result)
        return {"one_time_key_counts": result}
Beispiel #24
0
 def get_new_messages_for_remote_destination_txn(txn):
     """Fetch a batch of queued to-device messages for a remote destination.

     Selects rows of ``device_federation_outbox`` for ``destination`` whose
     ``stream_id`` lies in ``(last_stream_id, current_stream_id]``, in
     ascending order, returning at most ``limit`` of them. ``destination``,
     ``last_stream_id``, ``current_stream_id`` and ``limit`` are free variables
     captured from the enclosing scope.

     Args:
         txn: the database cursor to run the query on.

     Returns:
         ``(messages, stream_pos)``: the decoded message payloads, and the
         stream position reached. That position is:
         - the stream_id of the last row when a full batch was returned;
         - ``current_stream_id`` when fewer than ``limit`` rows were found;
         - ``last_stream_id`` (no progress) when no rows were read and
           ``limit`` is not positive.
     """
     sql = (
         "SELECT stream_id, messages_json FROM device_federation_outbox"
         " WHERE destination = ?"
         " AND ? < stream_id AND stream_id <= ?"
         " ORDER BY stream_id ASC"
         " LIMIT ?"
     )
     txn.execute(sql, (destination, last_stream_id, current_stream_id, limit))

     messages = []
     # Initialise so stream_pos is always bound, even when no rows come back
     # and the "limit not reached" branch below does not fire (limit <= 0).
     # Having read nothing, we have made no progress. So stay at
     # last_stream_id rather than jumping ahead and skipping unsent messages.
     stream_pos = last_stream_id
     for row in txn:
         stream_pos = row[0]
         messages.append(db_to_json(row[1]))

     # If the limit was not reached we know that there's no more data for this
     # destination up to current_stream_id.
     if len(messages) < limit:
         log_kv({"message": "Set stream position to current position"})
         stream_pos = current_stream_id
     return messages, stream_pos
    async def get_room_keys(
        self,
        user_id: str,
        version: str,
        room_id: Optional[str] = None,
        session_id: Optional[str] = None,
    ) -> Dict[
        Literal["rooms"], Dict[str, Dict[Literal["sessions"], Dict[str, RoomKey]]]
    ]:
        """Bulk get the E2E room keys for a given backup, optionally filtered to
        a given room, or a given session.

        See EndToEndRoomKeyStore.get_e2e_room_keys for full details.

        Args:
            user_id: the user whose keys we're getting
            version: the version ID of the backup we're getting keys from
            room_id: room ID to get keys for, or None to get keys for all rooms
            session_id: session ID to get keys for, or None to get keys for all
                sessions

        Raises:
            NotFoundError: if the backup version does not exist

        Returns:
            A dict giving the session_data and message metadata for these room keys.
            `{"rooms": {room_id: {"sessions": {session_id: room_key}}}}`
        """
        # The per-user upload lock is held while reading, so that changing the
        # backup version cannot interleave with this lookup.
        with (await self._upload_linearizer.queue(user_id)):
            # Check the requested backup version exists before fetching keys.
            try:
                await self.store.get_e2e_room_keys_version_info(user_id, version)
            except StoreError as err:
                if err.code != 404:
                    raise
                raise NotFoundError("Unknown backup version")

            room_keys = await self.store.get_e2e_room_keys(
                user_id, version, room_id, session_id
            )
            log_kv(room_keys)
            return room_keys
Beispiel #26
0
    async def on_POST(
        self, request: SynapseRequest, device_id: Optional[str]
    ) -> Tuple[int, JsonDict]:
        """Handle an E2E key upload request.

        The JSON body is passed to ``e2e_keys_handler.upload_keys_for_user``,
        and its result is returned with a 200 status.

        Raises:
            SynapseError: (400) if no device ID is given in the path and the
                requester's access token has no associated device.
        """
        requester = await self.auth.get_user_by_req(request, allow_guest=True)
        user_id = requester.user.to_string()
        body = parse_json_object_from_request(request)

        if device_id is None:
            # No explicit device: use the one tied to the access token.
            device_id = requester.device_id
        elif requester.device_id is not None and device_id != requester.device_id:
            # An explicit device_id is meant for dehydrated devices, but is
            # accepted for any device for compatibility with older clients.
            # We only record a warning when a dehydrated device exists and this
            # ID does not match it either.
            dehydrated = await self.device_handler.get_dehydrated_device(user_id)
            if dehydrated is not None and device_id != dehydrated[0]:
                set_tag("error", True)
                log_kv(
                    {
                        "message": "Client uploading keys for a different device",
                        "logged_in_id": requester.device_id,
                        "key_being_uploaded": device_id,
                    }
                )
                logger.warning(
                    "Client uploading keys for a different device "
                    "(logged in as %s, uploading for %s)",
                    requester.device_id,
                    device_id,
                )

        if device_id is None:
            raise SynapseError(
                400, "To upload keys, you must pass device_id when authenticating"
            )

        upload_result = await self.e2e_keys_handler.upload_keys_for_user(
            user_id, device_id, body
        )
        return 200, upload_result
Beispiel #27
0
    def _upload_one_time_keys_for_user(
        self, user_id, device_id, time_now, one_time_keys
    ):
        """Validate and persist a batch of one-time keys for a device.

        Keys already stored under the same ID are skipped if they match per
        ``_one_time_keys_match``. Only keys not already stored are inserted.

        NOTE(review): this uses ``yield`` but no ``@defer.inlineCallbacks``
        decorator is visible here; confirm it is decorated where defined.

        Args:
            user_id: the user the keys belong to.
            device_id: the device the keys belong to.
            time_now: timestamp passed through to the store with the new keys.
            one_time_keys: dict mapping ``"<algorithm>:<key_id>"`` to key objects.

        Raises:
            SynapseError: (400) if a key ID is not of the form
                ``"<algorithm>:<key_id>"``, or if a key with the same ID is
                already stored with different content.
        """
        logger.info(
            "Adding one_time_keys %r for device %r for user %r at %d",
            one_time_keys.keys(),
            device_id,
            user_id,
            time_now,
        )

        # make a list of (alg, id, key) tuples
        key_list = []
        for full_key_id, key_obj in one_time_keys.items():
            parts = full_key_id.split(":")
            if len(parts) != 2:
                # Reject malformed IDs with a client error instead of letting
                # them escape as an unhandled ValueError.
                raise SynapseError(
                    400,
                    "One time key ID %r is not of the form <algorithm>:<key_id>"
                    % (full_key_id,),
                )
            algorithm, key_id = parts
            key_list.append((algorithm, key_id, key_obj))

        # First we check if we have already persisted any of the keys.
        existing_key_map = yield self.store.get_e2e_one_time_keys(
            user_id, device_id, [k_id for _, k_id, _ in key_list]
        )

        new_keys = []  # Keys that we need to insert. (alg, id, json) tuples.
        for algorithm, key_id, key in key_list:
            ex_json = existing_key_map.get((algorithm, key_id), None)
            if ex_json:
                if not _one_time_keys_match(ex_json, key):
                    raise SynapseError(
                        400,
                        (
                            "One time key %s:%s already exists. "
                            "Old key: %s; new key: %r"
                        )
                        % (algorithm, key_id, ex_json, key),
                    )
            else:
                new_keys.append(
                    (algorithm, key_id, encode_canonical_json(key).decode("ascii"))
                )

        log_kv({"message": "Inserting new one_time_keys.", "keys": new_keys})
        yield self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys)
Beispiel #28
0
    async def get_devices_by_user(self, user_id: str) -> List[Dict[str, Any]]:
        """
        Look up every device registered to the given user.

        Args:
            user_id: The user ID to query for devices.
        Returns:
            One dict per device. Each dict is passed through
            _update_device_from_client_ips with the user's last-seen client IP
            data (presumably updating it in place).
        """
        set_tag("user_id", user_id)
        devices_by_id = await self.store.get_devices_by_user(user_id)
        last_ips = await self.store.get_last_client_ip_by_device(
            user_id, device_id=None
        )

        result = list(devices_by_id.values())
        for dev in result:
            _update_device_from_client_ips(dev, last_ips)

        # The list shares its dict objects with devices_by_id, so this logs the
        # post-update device info.
        log_kv(devices_by_id)
        return result
Beispiel #29
0
    def get_devices_by_user(self, user_id):
        """
        Look up every device registered to the given user.

        Args:
            user_id (str): the user whose devices are wanted
        Returns:
            defer.Deferred: list[dict[str, X]]: one dict per device, each passed
            through _update_device_from_client_ips with the user's last-seen
            client IP data
        """
        set_tag("user_id", user_id)
        devices_by_id = yield self.store.get_devices_by_user(user_id)
        last_ips = yield self.store.get_last_client_ip_by_device(
            user_id, device_id=None
        )

        result = list(devices_by_id.values())
        for dev in result:
            _update_device_from_client_ips(dev, last_ips)

        log_kv(devices_by_id)
        return result
    async def add_e2e_room_keys(self, user_id, version, room_keys):
        """Bulk add room keys to a given backup.

        Args:
            user_id (str): the user whose backup we're adding to
            version (str): the version ID of the backup for the set of keys we're adding to
            room_keys (iterable[(str, str, dict)]): the keys to add, in the form
                (roomID, sessionID, keyData)

        Raises:
            StoreError: (404) if ``version`` cannot be parsed as an integer.
        """
        try:
            version_int = int(version)
        except ValueError:
            # Backup versions are all ints, so if we can't convert it to an
            # integer, it doesn't exist. Fail cleanly rather than writing a
            # non-integer version into the table.
            raise StoreError(404, "No backup with that version exists")

        values = []
        for room_id, session_id, room_key in room_keys:
            values.append(
                {
                    "user_id": user_id,
                    "version": version_int,
                    "room_id": room_id,
                    "session_id": session_id,
                    "first_message_index": room_key["first_message_index"],
                    "forwarded_count": room_key["forwarded_count"],
                    "is_verified": room_key["is_verified"],
                    "session_data": json_encoder.encode(room_key["session_data"]),
                }
            )
            log_kv(
                {
                    "message": "Set room key",
                    "room_id": room_id,
                    "session_id": session_id,
                    "room_key": room_key,
                }
            )

        await self.db_pool.simple_insert_many(
            table="e2e_room_keys", values=values, desc="add_e2e_room_keys"
        )