Example #1
    async def set_e2e_fallback_keys(self, user_id: str, device_id: str,
                                    fallback_keys: JsonDict) -> None:
        """Set the user's e2e fallback keys.

        Args:
            user_id: the user whose keys are being set
            device_id: the device whose keys are being set
            fallback_keys: the keys to set.  This is a map from key ID (which is
                of the form "algorithm:id") to key data.
        """
        # fallback_keys will usually only have one item in it, so using a for
        # loop (as opposed to calling simple_upsert_many_txn) won't be too bad
        # FIXME: make sure that only one key per algorithm is uploaded
        for key_id, fallback_key in fallback_keys.items():
            algorithm, key_id = key_id.split(":", 1)
            await self.db_pool.simple_upsert(
                "e2e_fallback_keys_json",
                keyvalues={
                    "user_id": user_id,
                    "device_id": device_id,
                    "algorithm": algorithm,
                },
                values={
                    "key_id": key_id,
                    "key_json": json_encoder.encode(fallback_key),
                    "used": False,
                },
                desc="set_e2e_fallback_key",
            )

        await self.invalidate_cache_and_stream(
            "get_e2e_unused_fallback_key_types", (user_id, device_id))
Example #2
    async def _received_remote_receipt(self, origin: str,
                                       content: JsonDict) -> None:
        """Called when we receive an EDU of type m.receipt from a remote HS."""
        receipts = []
        for room_id, room_values in content.items():
            for receipt_type, users in room_values.items():
                for user_id, user_values in users.items():
                    if get_domain_from_id(user_id) != origin:
                        logger.info(
                            "Received receipt for user %r from server %s, ignoring",
                            user_id,
                            origin,
                        )
                        continue

                    receipts.append(
                        ReadReceipt(
                            room_id=room_id,
                            receipt_type=receipt_type,
                            user_id=user_id,
                            event_ids=user_values["event_ids"],
                            data=user_values.get("data", {}),
                        ))

        await self._handle_new_receipts(receipts)
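
For context, a sketch of the m.receipt EDU content shape the loop above walks, inferred from how the code indexes it (room ID -> receipt type -> user ID -> event_ids/data); the concrete IDs and timestamp are illustrative:

example_receipt_content = {
    "!room:example.org": {
        "m.read": {
            "@alice:example.org": {
                "event_ids": ["$someeventid"],
                "data": {"ts": 1661384801651},
            }
        }
    }
}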
Example #3
    def _set_e2e_fallback_keys_txn(
        self,
        txn: LoggingTransaction,
        user_id: str,
        device_id: str,
        fallback_keys: JsonDict,
    ) -> None:
        # fallback_keys will usually only have one item in it, so using a for
        # loop (as opposed to calling simple_upsert_many_txn) won't be too bad
        # FIXME: make sure that only one key per algorithm is uploaded
        for key_id, fallback_key in fallback_keys.items():
            algorithm, key_id = key_id.split(":", 1)
            old_key_json = self.db_pool.simple_select_one_onecol_txn(
                txn,
                table="e2e_fallback_keys_json",
                keyvalues={
                    "user_id": user_id,
                    "device_id": device_id,
                    "algorithm": algorithm,
                },
                retcol="key_json",
                allow_none=True,
            )

            new_key_json = encode_canonical_json(fallback_key).decode("utf-8")

            # If the uploaded key is the same as the current fallback key,
            # don't do anything.  This prevents marking the key as unused if it
            # was already used.
            if old_key_json != new_key_json:
                self.db_pool.simple_upsert_txn(
                    txn,
                    table="e2e_fallback_keys_json",
                    keyvalues={
                        "user_id": user_id,
                        "device_id": device_id,
                        "algorithm": algorithm,
                    },
                    values={
                        "key_id": key_id,
                        "key_json": json_encoder.encode(fallback_key),
                        "used": False,
                    },
                )
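
A small aside on why the comparison above uses canonical JSON: the encoding is deterministic, so re-uploading the same key with a different dict ordering still compares equal and the existing row (and its `used` flag) is left alone. A minimal, self-contained illustration:

from canonicaljson import encode_canonical_json

# Same key material, different insertion order: canonical encoding makes the
# two byte strings identical, so the upsert above would be skipped.
a = encode_canonical_json({"key": "base64+key", "fallback": True})
b = encode_canonical_json({"fallback": True, "key": "base64+key"})
assert a == b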
Example #4
    async def _received_remote_receipt(self, origin: str, content: JsonDict) -> None:
        """Called when we receive an EDU of type m.receipt from a remote HS."""
        receipts = []
        for room_id, room_values in content.items():
            # If we're not in the room, just ditch the event entirely. This is
            # probably an old server that has come back and thinks we're still in
            # the room (or we've been rejoined to the room by a state reset).
            is_in_room = await self.event_auth_handler.check_host_in_room(
                room_id, self.server_name
            )
            if not is_in_room:
                logger.info(
                    "Ignoring receipt for room %r from server %s as we're not in the room",
                    room_id,
                    origin,
                )
                continue

            for receipt_type, users in room_values.items():
                for user_id, user_values in users.items():
                    if get_domain_from_id(user_id) != origin:
                        logger.info(
                            "Received receipt for user %r from server %s, ignoring",
                            user_id,
                            origin,
                        )
                        continue

                    receipts.append(
                        ReadReceipt(
                            room_id=room_id,
                            receipt_type=receipt_type,
                            user_id=user_id,
                            event_ids=user_values["event_ids"],
                            data=user_values.get("data", {}),
                        )
                    )

        await self._handle_new_receipts(receipts)
Example #5
def _to_postgres_options(options_dict: JsonDict) -> str:
    return "'%s'" % (",".join("%s=%s" % (k, v)
                              for k, v in options_dict.items()), )
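
A worked example of what this helper produces (the option names are illustrative):

# {"keepalives": 1, "keepalives_idle": 10} is joined into a single quoted
# string suitable for a postgres `options` value.
options = _to_postgres_options({"keepalives": 1, "keepalives_idle": 10})
assert options == "'keepalives=1,keepalives_idle=10'"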
Example #6
    async def query_keys(
        self,
        request: SynapseRequest,
        query: JsonDict,
        query_remote_on_cache_miss: bool = False,
    ) -> None:
        logger.info("Handling query for keys %r", query)

        store_queries = []
        for server_name, key_ids in query.items():
            if (self.federation_domain_whitelist is not None
                    and server_name not in self.federation_domain_whitelist):
                logger.debug("Federation denied with %s", server_name)
                continue

            if not key_ids:
                key_ids = (None, )
            for key_id in key_ids:
                store_queries.append((server_name, key_id, None))

        cached = await self.store.get_server_keys_json(store_queries)

        json_results: Set[bytes] = set()

        time_now_ms = self.clock.time_msec()

        # Note that the value is unused.
        cache_misses: Dict[str, Dict[str, int]] = {}
        for (server_name, key_id, _), key_results in cached.items():
            results = [(result["ts_added_ms"], result)
                       for result in key_results]

            if not results and key_id is not None:
                cache_misses.setdefault(server_name, {})[key_id] = 0
                continue

            if key_id is not None:
                ts_added_ms, most_recent_result = max(results)
                ts_valid_until_ms = most_recent_result["ts_valid_until_ms"]
                req_key = query.get(server_name, {}).get(key_id, {})
                req_valid_until = req_key.get("minimum_valid_until_ts")
                miss = False
                if req_valid_until is not None:
                    if ts_valid_until_ms < req_valid_until:
                        logger.debug(
                            "Cached response for %r/%r is older than requested"
                            ": valid_until (%r) < minimum_valid_until (%r)",
                            server_name,
                            key_id,
                            ts_valid_until_ms,
                            req_valid_until,
                        )
                        miss = True
                    else:
                        logger.debug(
                            "Cached response for %r/%r is newer than requested"
                            ": valid_until (%r) >= minimum_valid_until (%r)",
                            server_name,
                            key_id,
                            ts_valid_until_ms,
                            req_valid_until,
                        )
                elif (ts_added_ms + ts_valid_until_ms) / 2 < time_now_ms:
                    logger.debug(
                        "Cached response for %r/%r is too old"
                        ": (added (%r) + valid_until (%r)) / 2 < now (%r)",
                        server_name,
                        key_id,
                        ts_added_ms,
                        ts_valid_until_ms,
                        time_now_ms,
                    )
                    # We are more than halfway through the lifetime of the
                    # response. We should fetch a fresh copy.
                    miss = True
                else:
                    logger.debug(
                        "Cached response for %r/%r is still valid"
                        ": (added (%r) + valid_until (%r)) / 2 >= now (%r)",
                        server_name,
                        key_id,
                        ts_added_ms,
                        ts_valid_until_ms,
                        time_now_ms,
                    )

                if miss:
                    cache_misses.setdefault(server_name, {})[key_id] = 0
                # Cast to bytes since postgresql returns a memoryview.
                json_results.add(bytes(most_recent_result["key_json"]))
            else:
                for _, result in results:
                    # Cast to bytes since postgresql returns a memoryview.
                    json_results.add(bytes(result["key_json"]))

        # If there is a cache miss, request the missing keys, then recurse (and
        # ensure the result is sent).
        if cache_misses and query_remote_on_cache_miss:
            await yieldable_gather_results(
                lambda t: self.fetcher.get_keys(*t),
                ((server_name, list(keys), 0)
                 for server_name, keys in cache_misses.items()),
            )
            await self.query_keys(request,
                                  query,
                                  query_remote_on_cache_miss=False)
        else:
            signed_keys = []
            for key_json_raw in json_results:
                key_json = json_decoder.decode(key_json_raw.decode("utf-8"))
                for signing_key in self.config.key.key_server_signing_keys:
                    key_json = sign_json(key_json,
                                         self.config.server.server_name,
                                         signing_key)

                signed_keys.append(key_json)

            response = {"server_keys": signed_keys}

            respond_with_json(request, 200, response, canonical_json=True)
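
To make the cache-freshness rule above concrete, here is a standalone restatement (a hypothetical helper, not part of the handler): a cached key is refetched once the current time is past the midpoint between when the response was added and when it stops being valid.

def _is_past_half_lifetime(ts_added_ms: int, ts_valid_until_ms: int, now_ms: int) -> bool:
    # Mirrors the `elif` branch above: stale once we are more than halfway
    # through the response's lifetime.
    return (ts_added_ms + ts_valid_until_ms) / 2 < now_ms

assert _is_past_half_lifetime(0, 1000, 600)      # past the midpoint -> refetch
assert not _is_past_half_lifetime(0, 1000, 400)  # still fresh -> serve cached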
Example #7
def prune_event_dict(room_version: RoomVersion, event_dict: JsonDict) -> JsonDict:
    """Redacts the event_dict in the same way as `prune_event`, except it
    operates on dicts rather than event objects.

    Returns:
        A copy of the pruned event dict
    """

    allowed_keys = [
        "event_id",
        "sender",
        "room_id",
        "hashes",
        "signatures",
        "content",
        "type",
        "state_key",
        "depth",
        "prev_events",
        "auth_events",
        "origin",
        "origin_server_ts",
    ]

    # Room versions from before MSC2176 had additional allowed keys.
    if not room_version.msc2176_redaction_rules:
        allowed_keys.extend(["prev_state", "membership"])

    event_type = event_dict["type"]

    new_content = {}

    def add_fields(*fields: str) -> None:
        for field in fields:
            if field in event_dict["content"]:
                new_content[field] = event_dict["content"][field]

    if event_type == EventTypes.Member:
        add_fields("membership")
        if room_version.msc3375_redaction_rules:
            add_fields(EventContentFields.AUTHORISING_USER)
    elif event_type == EventTypes.Create:
        # MSC2176 rules state that create events cannot be redacted.
        if room_version.msc2176_redaction_rules:
            return event_dict

        add_fields("creator")
    elif event_type == EventTypes.JoinRules:
        add_fields("join_rule")
        if room_version.msc3083_join_rules:
            add_fields("allow")
    elif event_type == EventTypes.PowerLevels:
        add_fields(
            "users",
            "users_default",
            "events",
            "events_default",
            "state_default",
            "ban",
            "kick",
            "redact",
        )

        if room_version.msc2176_redaction_rules:
            add_fields("invite")

        if room_version.msc2716_historical:
            add_fields("historical")

    elif event_type == EventTypes.Aliases and room_version.special_case_aliases_auth:
        add_fields("aliases")
    elif event_type == EventTypes.RoomHistoryVisibility:
        add_fields("history_visibility")
    elif event_type == EventTypes.Redaction and room_version.msc2176_redaction_rules:
        add_fields("redacts")
    elif room_version.msc2716_redactions and event_type == EventTypes.MSC2716_INSERTION:
        add_fields(EventContentFields.MSC2716_NEXT_BATCH_ID)
    elif room_version.msc2716_redactions and event_type == EventTypes.MSC2716_BATCH:
        add_fields(EventContentFields.MSC2716_BATCH_ID)
    elif room_version.msc2716_redactions and event_type == EventTypes.MSC2716_MARKER:
        add_fields(EventContentFields.MSC2716_MARKER_INSERTION)

    allowed_fields = {k: v for k, v in event_dict.items() if k in allowed_keys}

    allowed_fields["content"] = new_content

    unsigned: JsonDict = {}
    allowed_fields["unsigned"] = unsigned

    event_unsigned = event_dict.get("unsigned", {})

    if "age_ts" in event_unsigned:
        unsigned["age_ts"] = event_unsigned["age_ts"]
    if "replaces_state" in event_unsigned:
        unsigned["replaces_state"] = event_unsigned["replaces_state"]

    return allowed_fields
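
A hedged usage sketch for prune_event_dict; the event fields are illustrative, and RoomVersions.V6 is assumed to be importable from synapse.api.room_versions. For an m.room.member event, only "membership" should survive in the pruned content, and keys outside the allow-list should be dropped.

from synapse.api.room_versions import RoomVersions  # assumed import path

pruned = prune_event_dict(
    RoomVersions.V6,
    {
        "type": "m.room.member",
        "event_id": "$someevent",
        "sender": "@alice:example.org",
        "room_id": "!room:example.org",
        "state_key": "@alice:example.org",
        "content": {"membership": "join", "displayname": "Alice"},
        "unsigned": {"age_ts": 1234, "other": "dropped"},
        "not_allowed": "also dropped",
    },
)
assert pruned["content"] == {"membership": "join"}
assert pruned["unsigned"] == {"age_ts": 1234}
assert "not_allowed" not in pruned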