Beispiel #1
0
    def on_PUT(self, request, room_id, membership_action, txn_id):
        """Handle a PUT membership request keyed by a client transaction ID.

        Tags the active tracing span with ``txn_id`` and delegates to
        ``self.txns.fetch_or_execute_request`` together with ``self.on_POST``
        and the arguments it should be called with (presumably so a repeated
        transaction returns the cached response — confirm).
        """
        set_tag("txn_id", txn_id)
        fetch = self.txns.fetch_or_execute_request
        handler_args = (request, room_id, membership_action, txn_id)
        return fetch(request, self.on_POST, *handler_args)
Beispiel #2
0
 def on_PUT(self, request, txn_id):
     """Handle a PUT request keyed by a client transaction ID.

     Tags the active tracing span with ``txn_id``, then delegates to
     ``self.txns.fetch_or_execute_request`` with ``self.on_POST`` as the
     handler and ``request`` as its only argument.
     """
     set_tag("txn_id", txn_id)
     fetch = self.txns.fetch_or_execute_request
     return fetch(request, self.on_POST, request)
Beispiel #3
0
    def request(self, method, uri, data=None, headers=None):
        """Send an HTTP request through ``treq`` using ``self.agent``.

        A small wrapper so that counters, logging and an opentracing span are
        attached to every outgoing request. The request deferred is wrapped in
        ``timeout_deferred`` with a timeout of 60 (presumably seconds — confirm)
        and ``cancelled_to_request_timed_out_error``.

        This is a generator function; it is presumably driven by
        ``defer.inlineCallbacks`` at its definition site (decorator not visible
        here — confirm).

        Args:
            method (str): HTTP method to use.
            uri (str): URI to query.
            data (bytes): Data to send in the request body, if applicable.
            headers (t.w.http_headers.Headers): Request headers.

        Returns:
            The response produced by the request deferred.

        Raises:
            Whatever sending the request raises; the failure is counted, logged,
            tagged on the span and then re-raised unchanged.
        """
        # A small wrapper around self.agent.request() so we can easily attach
        # counters to it
        outgoing_requests_counter.labels(method).inc()

        # log request but strip `access_token` (AS requests for example include this)
        logger.info("Sending request %s %s", method, redact_uri(uri))

        with start_active_span(
            "outgoing-client-request",
            tags={
                tags.SPAN_KIND: tags.SPAN_KIND_RPC_CLIENT,
                tags.HTTP_METHOD: method,
                tags.HTTP_URL: uri,
            },
            finish_on_close=True,
        ):
            try:
                body_producer = None
                if data is not None:
                    body_producer = QuieterFileBodyProducer(BytesIO(data))

                request_deferred = treq.request(
                    method,
                    uri,
                    agent=self.agent,
                    data=body_producer,
                    headers=headers,
                    **self._extra_treq_args
                )
                request_deferred = timeout_deferred(
                    request_deferred,
                    60,
                    self.hs.get_reactor(),
                    cancelled_to_request_timed_out_error,
                )
                response = yield make_deferred_yieldable(request_deferred)

                incoming_responses_counter.labels(method, response.code).inc()
                logger.info(
                    "Received response to %s %s: %s",
                    method,
                    redact_uri(uri),
                    response.code,
                )
                return response
            except Exception as e:
                incoming_responses_counter.labels(method, "ERR").inc()
                # An exception raised without arguments has an empty `e.args`;
                # indexing it unconditionally would raise IndexError here and
                # replace the real failure with a misleading one.
                error_reason = e.args[0] if e.args else str(e)
                logger.info(
                    "Error sending request to  %s %s: %s %s",
                    method,
                    redact_uri(uri),
                    type(e).__name__,
                    error_reason,
                )
                set_tag(tags.ERROR, True)
                set_tag("error_reason", error_reason)
                raise
Beispiel #4
0
    async def send_device_message(
        self,
        requester: Requester,
        message_type: str,
        messages: Dict[str, Dict[str, JsonDict]],
    ) -> None:
        """Store and dispatch to-device messages from ``requester``.

        Messages for local users are written to the device inbox and the
        notifier is woken for them. Messages for remote users are grouped per
        destination server into EDU contents, stored alongside, and each
        destination is handed to the federation sender (if there is one).

        Args:
            requester: the sender of the messages.
            message_type: the to-device event type.
            messages: map from recipient user_id to device_id to message
                content.
        """
        sender_user_id = requester.user.to_string()

        set_tag("number_of_messages", len(messages))
        set_tag("sender", sender_user_id)
        local_messages = {}
        remote_messages = {}  # type: Dict[str, Dict[str, Dict[str, JsonDict]]]
        for user_id, by_device in messages.items():
            # Ratelimit local cross-user key requests by the sending device.
            if (message_type == EduTypes.RoomKeyRequest
                    and user_id != sender_user_id):
                # The message must be dropped only when the ratelimiter says
                # the action is NOT allowed. Assumes can_do_action returns an
                # (allowed, time_allowed) tuple, as the later version of this
                # handler unpacks it — confirm against Ratelimiter. Testing the
                # tuple itself for truthiness would always be True.
                allowed, _ = self._ratelimiter.can_do_action(
                    (sender_user_id, requester.device_id))
                if not allowed:
                    continue

            # we use UserID.from_string to catch invalid user ids
            if self.is_mine(UserID.from_string(user_id)):
                messages_by_device = {
                    device_id: {
                        "content": message_content,
                        "type": message_type,
                        "sender": sender_user_id,
                    }
                    for device_id, message_content in by_device.items()
                }
                if messages_by_device:
                    local_messages[user_id] = messages_by_device
            else:
                destination = get_domain_from_id(user_id)
                remote_messages.setdefault(destination,
                                           {})[user_id] = by_device

        message_id = random_string(16)

        context = get_active_span_text_map()

        remote_edu_contents = {}
        # `user_messages` rather than `messages` so the parameter is not shadowed.
        for destination, user_messages in remote_messages.items():
            with start_active_span("to_device_for_user"):
                set_tag("destination", destination)
                remote_edu_contents[destination] = {
                    "messages": user_messages,
                    "sender": sender_user_id,
                    "type": message_type,
                    "message_id": message_id,
                    "org.matrix.opentracing_context":
                    json_encoder.encode(context),
                }

        log_kv({"local_messages": local_messages})
        stream_id = await self.store.add_messages_to_device_inbox(
            local_messages, remote_edu_contents)

        self.notifier.on_new_event("to_device_key",
                                   stream_id,
                                   users=local_messages.keys())

        log_kv({"remote_messages": remote_messages})
        if self.federation_sender:
            for destination in remote_messages.keys():
                # Enqueue a new federation transaction to send the new
                # device messages to each remote destination.
                self.federation_sender.send_device_messages(destination)
Beispiel #5
0
    def query_devices(self, query_body, timeout, from_user_id):
        """ Handle a device key query from a client

        {
            "device_keys": {
                "<user_id>": ["<device_id>"]
            }
        }
        ->
        {
            "device_keys": {
                "<user_id>": {
                    "<device_id>": {
                        ...
                    }
                }
            }
        }

        Local users are answered from ``query_local_devices``. Remote users are
        answered from the device-list cache where possible; cache misses are
        grouped by destination server and queried concurrently.

        This is a generator function; it is presumably wrapped with
        ``defer.inlineCallbacks`` at its definition site (decorator not
        visible here — confirm).

        Args:
            query_body (dict): the request body; only its "device_keys" entry
                is read.
            timeout: passed through to ``federation.query_client_keys`` for
                remote queries.
            from_user_id (str): the user making the query.  This is used when
                adding cross-signing signatures to limit what signatures users
                can see.

        Returns:
            dict: ``{"device_keys": results, "failures": failures}`` updated
            with the cross-signing key maps from
            ``get_cross_signing_keys_from_cache``.
        """

        device_keys_query = query_body.get("device_keys", {})

        # Split the query into local and remote users; both map
        # user_id -> requested device_ids.
        local_query = {}
        remote_queries = {}

        for user_id, device_ids in device_keys_query.items():
            # we use UserID.from_string to catch invalid user ids
            if self.is_mine(UserID.from_string(user_id)):
                local_query[user_id] = device_ids
            else:
                remote_queries[user_id] = device_ids

        set_tag("local_key_query", local_query)
        set_tag("remote_key_query", remote_queries)

        # First get local devices.
        failures = {}
        results = {}
        if local_query:
            local_result = yield self.query_local_devices(local_query)
            for user_id, keys in local_result.items():
                if user_id in local_query:
                    results[user_id] = keys

        # Now attempt to get any remote devices from our local cache.
        # Cache misses end up in a map of destination -> user_id -> device_ids.
        remote_queries_not_in_cache = {}
        if remote_queries:
            query_list = []
            for user_id, device_ids in iteritems(remote_queries):
                if device_ids:
                    query_list.extend(
                        (user_id, device_id) for device_id in device_ids)
                else:
                    query_list.append((user_id, None))

            (
                user_ids_not_in_cache,
                remote_results,
            ) = yield self.store.get_user_devices_from_cache(query_list)
            for user_id, devices in iteritems(remote_results):
                user_devices = results.setdefault(user_id, {})
                for device_id, device in iteritems(devices):
                    keys = device.get("keys", None)
                    device_display_name = device.get("device_display_name",
                                                     None)
                    if keys:
                        result = dict(keys)
                        unsigned = result.setdefault("unsigned", {})
                        if device_display_name:
                            unsigned[
                                "device_display_name"] = device_display_name
                        user_devices[device_id] = result

            for user_id in user_ids_not_in_cache:
                domain = get_domain_from_id(user_id)
                r = remote_queries_not_in_cache.setdefault(domain, {})
                r[user_id] = remote_queries[user_id]

        # Get cached cross-signing keys
        cross_signing_keys = yield self.get_cross_signing_keys_from_cache(
            device_keys_query, from_user_id)

        # Now fetch any devices that we don't have in our cache
        @trace
        @defer.inlineCallbacks
        def do_remote_query(destination):
            """This is called when we are querying the device list of a user on
            a remote homeserver and their device list is not in the device list
            cache. If we share a room with this user and we're not querying for
            specific devices we will update the cache
            with their device list."""

            destination_query = remote_queries_not_in_cache[destination]

            # We first consider whether we wish to update the device list cache with
            # the users device list. We want to track a user's devices when the
            # authenticated user shares a room with the queried user and the query
            # has not specified a particular device.
            # If we update the cache for the queried user we remove them from further
            # queries. We use the more efficient batched query_client_keys for all
            # remaining users
            user_ids_updated = []
            for (user_id, device_list) in destination_query.items():
                if device_list:
                    continue

                room_ids = yield self.store.get_rooms_for_user(user_id)
                if not room_ids:
                    continue

                # We've decided we're sharing a room with this user and should
                # probably be tracking their device lists. However, we haven't
                # done an initial sync on the device list so we do it now.
                try:
                    if self._is_master:
                        user_devices = yield self.device_handler.device_list_updater.user_device_resync(
                            user_id)
                    else:
                        user_devices = yield self._user_device_resync_client(
                            user_id=user_id)

                    user_devices = user_devices["devices"]
                    # Collect every device for this user. Re-assigning a
                    # single-entry dict per device would keep only the last one.
                    results[user_id] = {
                        device["device_id"]: device["keys"]
                        for device in user_devices
                    }
                    user_ids_updated.append(user_id)
                except Exception as e:
                    failures[destination] = _exception_to_failure(e)

            if len(destination_query) == len(user_ids_updated):
                # We've updated all the users in the query and we do not need to
                # make any further remote calls.
                return

            # Remove all the users from the query which we have updated
            for user_id in user_ids_updated:
                destination_query.pop(user_id)

            try:
                remote_result = yield self.federation.query_client_keys(
                    destination, {"device_keys": destination_query},
                    timeout=timeout)

                for user_id, keys in remote_result["device_keys"].items():
                    if user_id in destination_query:
                        results[user_id] = keys

                if "master_keys" in remote_result:
                    for user_id, key in remote_result["master_keys"].items():
                        if user_id in destination_query:
                            cross_signing_keys["master_keys"][user_id] = key

                if "self_signing_keys" in remote_result:
                    for user_id, key in remote_result[
                            "self_signing_keys"].items():
                        if user_id in destination_query:
                            cross_signing_keys["self_signing_keys"][
                                user_id] = key

            except Exception as e:
                failure = _exception_to_failure(e)
                failures[destination] = failure
                set_tag("error", True)
                set_tag("reason", failure)

        yield make_deferred_yieldable(
            defer.gatherResults(
                [
                    run_in_background(do_remote_query, destination)
                    for destination in remote_queries_not_in_cache
                ],
                consumeErrors=True,
            ).addErrback(unwrapFirstError))

        ret = {"device_keys": results, "failures": failures}

        ret.update(cross_signing_keys)

        return ret
Beispiel #6
0
    def get_user_by_req(
        self,
        request: Request,
        allow_guest: bool = False,
        rights: str = "access",
        allow_expired: bool = False,
    ):
        """Work out which registered user (or appservice user) made ``request``.

        If ``_get_appservice_user_id`` identifies an application-service user,
        a requester for that user is returned straight away. Otherwise the
        access token is resolved to a user, and then account expiry, client-IP
        recording and the guest restriction are applied in that order.

        Args:
            request: An HTTP request with an access_token query parameter.
            allow_guest: If False, will raise an AuthError if the user making the
                request is a guest.
            rights: The operation being performed; the access token must allow this
            allow_expired: If True, allow the request through even if the account
                is expired, or session token lifetime has ended. Note that
                /login will deliver access tokens regardless of expiration.

        Returns:
            defer.Deferred: resolves to a `synapse.types.Requester` object
        Raises:
            MissingClientTokenError if any KeyError escapes the lookup
                (presumably raised when no access token is present — confirm).
            InvalidClientCredentialsError if no user by that token exists or the token
                is invalid.
            AuthError if access is denied for the user in the access token
        """
        try:
            client_ip = self.hs.get_ip_from_request(request)
            agent_headers = request.requestHeaders.getRawHeaders(
                b"User-Agent", default=[b""]
            )
            user_agent = agent_headers[0].decode("ascii", "surrogateescape")

            access_token = self.get_access_token_from_request(request)

            appservice_user_id, app_service = yield self._get_appservice_user_id(
                request
            )
            if appservice_user_id:
                request.authenticated_entity = appservice_user_id
                opentracing.set_tag("authenticated_entity", appservice_user_id)
                opentracing.set_tag("appservice_id", app_service.id)

                # Appservice client IPs are recorded only when
                # `_track_appservice_user_ips` is set.
                if client_ip and self._track_appservice_user_ips:
                    yield self.store.insert_client_ip(
                        user_id=appservice_user_id,
                        access_token=access_token,
                        ip=client_ip,
                        user_agent=user_agent,
                        device_id="dummy-device",  # stubbed
                    )

                return synapse.types.create_requester(
                    appservice_user_id, app_service=app_service
                )

            token_info = yield self.get_user_by_access_token(
                access_token, rights, allow_expired=allow_expired
            )
            user = token_info["user"]
            token_id = token_info["token_id"]
            is_guest = token_info["is_guest"]

            # Refuse the request when the account-validity period has lapsed.
            if self._account_validity.enabled and not allow_expired:
                expiration_ts = yield self.store.get_expiration_ts_for_user(
                    user.to_string()
                )
                has_expired = (
                    expiration_ts is not None
                    and self.clock.time_msec() >= expiration_ts
                )
                if has_expired:
                    raise AuthError(
                        403,
                        "User account has expired",
                        errcode=Codes.EXPIRED_ACCOUNT,
                    )

            # A stubbed-out get_user_by_access_token may omit "device_id".
            device_id = token_info.get("device_id")

            if user and access_token and client_ip:
                yield self.store.insert_client_ip(
                    user_id=user.to_string(),
                    access_token=access_token,
                    ip=client_ip,
                    user_agent=user_agent,
                    device_id=device_id,
                )

            if is_guest and not allow_guest:
                raise AuthError(
                    403,
                    "Guest access not allowed",
                    errcode=Codes.GUEST_ACCESS_FORBIDDEN,
                )

            request.authenticated_entity = user.to_string()
            opentracing.set_tag("authenticated_entity", user.to_string())
            if device_id:
                opentracing.set_tag("device_id", device_id)

            return synapse.types.create_requester(
                user, token_id, is_guest, device_id, app_service=app_service
            )
        except KeyError:
            raise MissingClientTokenError()
Beispiel #7
0
    async def send_device_message(
        self,
        requester: Requester,
        message_type: str,
        messages: Dict[str, Dict[str, JsonDict]],
    ) -> None:
        """Store and dispatch to-device messages from ``requester``.

        Every message shares one randomly generated ``message_id``. Cross-user
        room key requests are subject to the ratelimiter and dropped (with a log
        line) when not allowed. Messages for local users go straight into the
        device inbox. Messages for remote users are grouped per destination into
        EDU contents and then handed to the federation sender, if present.

        Args:
            requester: the sender of the messages.
            message_type: the to-device event type.
            messages: map from recipient user_id to device_id to message
                content.
        """
        sender_user_id = requester.user.to_string()

        message_id = random_string(16)
        set_tag(SynapseTags.TO_DEVICE_MESSAGE_ID, message_id)

        log_kv({"number_of_to_device_messages": len(messages)})
        set_tag("sender", sender_user_id)

        is_key_request = message_type == ToDeviceEventTypes.RoomKeyRequest

        local_messages = {}
        remote_messages: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
        for recipient, device_contents in messages.items():
            # Cross-user key requests are ratelimited per sending device.
            if is_key_request and recipient != sender_user_id:
                allowed, _ = await self._ratelimiter.can_do_action(
                    requester, (sender_user_id, requester.device_id)
                )
                if not allowed:
                    logger.info(
                        "Dropping room_key_request from %s to %s due to rate limit",
                        sender_user_id,
                        recipient,
                    )
                    continue

            # UserID.from_string is used so that invalid user ids are caught.
            if not self.is_mine(UserID.from_string(recipient)):
                destination = get_domain_from_id(recipient)
                per_destination = remote_messages.setdefault(destination, {})
                per_destination[recipient] = device_contents
                continue

            messages_by_device = {
                device_id: {
                    "content": content,
                    "type": message_type,
                    "sender": sender_user_id,
                    "message_id": message_id,
                }
                for device_id, content in device_contents.items()
            }
            if messages_by_device:
                local_messages[recipient] = messages_by_device
                log_kv({
                    "user_id": recipient,
                    "device_id": list(messages_by_device),
                })

        context = get_active_span_text_map()

        remote_edu_contents = {}
        for destination, user_messages in remote_messages.items():
            log_kv({"destination": destination})
            remote_edu_contents[destination] = {
                "messages": user_messages,
                "sender": sender_user_id,
                "type": message_type,
                "message_id": message_id,
                "org.matrix.opentracing_context": json_encoder.encode(context),
            }

        stream_id = await self.store.add_messages_to_device_inbox(
            local_messages, remote_edu_contents
        )

        self.notifier.on_new_event(
            "to_device_key", stream_id, users=local_messages.keys()
        )

        if not self.federation_sender:
            return
        # Queue a federation transaction per remote destination so the newly
        # stored device messages get sent.
        for destination in remote_messages:
            self.federation_sender.send_device_messages(destination)
Beispiel #8
0
        def do_remote_query(destination):
            """Query one remote destination for the device keys of its users.

            This is called when we are querying the device list of a user on
            a remote homeserver and their device list is not in the device list
            cache. If we share a room with this user and we're not querying for
            specific devices we resync and update the cache with their device
            list; any remaining users are fetched with one batched
            ``query_client_keys`` call.

            Results are written into the enclosing ``results``, ``failures`` and
            ``cross_signing_keys`` dicts rather than returned.

            Args:
                destination: the remote server name; must be a key of the
                    enclosing ``remote_queries_not_in_cache``.
            """

            destination_query = remote_queries_not_in_cache[destination]

            # We first consider whether we wish to update the device list cache with
            # the users device list. We want to track a user's devices when the
            # authenticated user shares a room with the queried user and the query
            # has not specified a particular device.
            # If we update the cache for the queried user we remove them from further
            # queries. We use the more efficient batched query_client_keys for all
            # remaining users
            user_ids_updated = []
            for (user_id, device_list) in destination_query.items():
                if device_list:
                    continue

                room_ids = yield self.store.get_rooms_for_user(user_id)
                if not room_ids:
                    continue

                # We've decided we're sharing a room with this user and should
                # probably be tracking their device lists. However, we haven't
                # done an initial sync on the device list so we do it now.
                try:
                    if self._is_master:
                        user_devices = yield self.device_handler.device_list_updater.user_device_resync(
                            user_id)
                    else:
                        user_devices = yield self._user_device_resync_client(
                            user_id=user_id)

                    user_devices = user_devices["devices"]
                    # Collect every device for this user. Re-assigning a
                    # single-entry dict per device would keep only the last one.
                    results[user_id] = {
                        device["device_id"]: device["keys"]
                        for device in user_devices
                    }
                    user_ids_updated.append(user_id)
                except Exception as e:
                    failures[destination] = _exception_to_failure(e)

            if len(destination_query) == len(user_ids_updated):
                # We've updated all the users in the query and we do not need to
                # make any further remote calls.
                return

            # Remove all the users from the query which we have updated
            for user_id in user_ids_updated:
                destination_query.pop(user_id)

            try:
                remote_result = yield self.federation.query_client_keys(
                    destination, {"device_keys": destination_query},
                    timeout=timeout)

                for user_id, keys in remote_result["device_keys"].items():
                    if user_id in destination_query:
                        results[user_id] = keys

                if "master_keys" in remote_result:
                    for user_id, key in remote_result["master_keys"].items():
                        if user_id in destination_query:
                            cross_signing_keys["master_keys"][user_id] = key

                if "self_signing_keys" in remote_result:
                    for user_id, key in remote_result[
                            "self_signing_keys"].items():
                        if user_id in destination_query:
                            cross_signing_keys["self_signing_keys"][
                                user_id] = key

            except Exception as e:
                failure = _exception_to_failure(e)
                failures[destination] = failure
                set_tag("error", True)
                set_tag("reason", failure)
    async def send_new_transaction(
        self,
        destination: str,
        pending_pdus: List[Tuple[EventBase, int]],
        pending_edus: List[Edu],
    ):
        """Build one federation transaction and send it to ``destination``.

        Args:
            destination: the server the transaction is sent to.
            pending_pdus: (event, order) pairs; this list is sorted IN PLACE by
                the order field before the events are packed.
            pending_edus: EDUs to include in the transaction.

        Returns:
            True if the transaction got a 200 response, False otherwise.

        Raises:
            HttpResponseException: re-raised when the remote answers 401, 404,
                429 or any 5xx. Other exceptions from the transport propagate.
        """
        # The transaction span "follows from" every EDU span it carries. There
        # is no active span here, so without this link an EDU the remote never
        # received would leave a span with no causality that gets forgotten.
        span_contexts = []
        keep_destination = whitelisted_homeserver(destination)

        for edu in pending_edus:
            edu_context = edu.get_context()
            if edu_context:
                span_contexts.append(extract_text_map(json.loads(edu_context)))
            # NOTE(review): the context is stripped when the destination IS
            # whitelisted, which looks inverted relative to the flag name
            # `keep_destination`. Confirm the intended direction before changing.
            if keep_destination:
                edu.strip_context()

        with start_active_span_follows_from("send_transaction", span_contexts):
            # Order PDUs by their order field, then drop the field.
            pending_pdus.sort(key=lambda pair: pair[1])
            pdus = [pair[0] for pair in pending_pdus]
            edus = pending_edus

            success = True

            logger.debug("TX [%s] _attempt_new_transaction", destination)

            txn_id = str(self._next_txn_id)

            logger.debug(
                "TX [%s] {%s} Attempting new transaction (pdus: %d, edus: %d)",
                destination,
                txn_id,
                len(pdus),
                len(edus),
            )

            transaction = Transaction.create_new(
                origin_server_ts=int(self.clock.time_msec()),
                transaction_id=txn_id,
                origin=self._server_name,
                destination=destination,
                pdus=pdus,
                edus=edus,
            )

            self._next_txn_id += 1

            logger.info(
                "TX [%s] {%s} Sending transaction [%s], (PDUs: %d, EDUs: %d)",
                destination,
                txn_id,
                transaction.transaction_id,
                len(pdus),
                len(edus),
            )

            def json_data_cb():
                # FIXME (erikj): hack to make the Pdu age keys work: rewrite
                # each PDU's absolute "age_ts" into a relative unsigned "age"
                # at serialisation time.
                data = transaction.get_dict()
                now = int(self.clock.time_msec())
                for pdu_json in data.get("pdus", []):
                    if "age_ts" not in pdu_json:
                        continue
                    unsigned = pdu_json.setdefault("unsigned", {})
                    unsigned["age"] = now - int(pdu_json["age_ts"])
                    del pdu_json["age_ts"]
                return data

            try:
                response = await self._transport_layer.send_transaction(
                    transaction, json_data_cb
                )
                code = 200
            except HttpResponseException as e:
                code = e.code
                response = e.response

                # These responses are surfaced to the caller as exceptions.
                if code in (401, 404, 429) or code >= 500:
                    logger.info(
                        "TX [%s] {%s} got %d response", destination, txn_id, code
                    )
                    raise e

            logger.info("TX [%s] {%s} got %d response", destination, txn_id, code)

            if code != 200:
                for pdu in pdus:
                    logger.warning(
                        "TX [%s] {%s} Failed to send event %s",
                        destination,
                        txn_id,
                        pdu.event_id,
                    )
                success = False
            else:
                for e_id, r in response.get("pdus", {}).items():
                    if "error" in r:
                        logger.warning(
                            "TX [%s] {%s} Remote returned error for %s: %s",
                            destination,
                            txn_id,
                            e_id,
                            r,
                        )

            set_tag(tags.ERROR, not success)
            return success
Beispiel #10
0
 def on_PUT(self, request, message_type, txn_id):
     """Handle a PUT of a message keyed by a client transaction ID.

     Tags the active tracing span with ``message_type`` and ``txn_id``, then
     delegates to ``self.txns.fetch_or_execute_request`` with ``self._put`` as
     the handler and ``(request, message_type, txn_id)`` as its arguments.
     """
     for tag_name, tag_value in (("message_type", message_type),
                                 ("txn_id", txn_id)):
         set_tag(tag_name, tag_value)
     fetch = self.txns.fetch_or_execute_request
     return fetch(request, self._put, request, message_type, txn_id)
Beispiel #11
0
    async def get_user_ids_changed(self, user_id: str,
                                   from_token: StreamToken) -> JsonDict:
        """Get list of users that have had the devices updated, or have newly
        joined a room, that `user_id` may be interested in.
        """

        set_tag("user_id", user_id)
        set_tag("from_token", from_token)
        now_room_key = self.store.get_room_max_token()

        room_ids = await self.store.get_rooms_for_user(user_id)

        changed = await self.get_device_changes_in_shared_rooms(
            user_id, room_ids, from_token)

        # Then work out if any users have since joined
        rooms_changed = self.store.get_rooms_that_changed(
            room_ids, from_token.room_key)

        member_events = await self.store.get_membership_changes_for_user(
            user_id, from_token.room_key, now_room_key)
        rooms_changed.update(event.room_id for event in member_events)

        stream_ordering = from_token.room_key.stream

        possibly_changed = set(changed)
        possibly_left = set()
        for room_id in rooms_changed:
            current_state_ids = await self._state_storage.get_current_state_ids(
                room_id)

            # The user may have left the room
            # TODO: Check if they actually did or if we were just invited.
            if room_id not in room_ids:
                for etype, state_key in current_state_ids.keys():
                    if etype != EventTypes.Member:
                        continue
                    possibly_left.add(state_key)
                continue

            # Fetch the current state at the time.
            try:
                event_ids = await self.store.get_forward_extremities_for_room_at_stream_ordering(
                    room_id, stream_ordering=stream_ordering)
            except errors.StoreError:
                # we have purged the stream_ordering index since the stream
                # ordering: treat it the same as a new room
                event_ids = []

            # special-case for an empty prev state: include all members
            # in the changed list
            if not event_ids:
                log_kv({
                    "event": "encountered empty previous state",
                    "room_id": room_id
                })
                for etype, state_key in current_state_ids.keys():
                    if etype != EventTypes.Member:
                        continue
                    possibly_changed.add(state_key)
                continue

            current_member_id = current_state_ids.get(
                (EventTypes.Member, user_id))
            if not current_member_id:
                continue

            # mapping from event_id -> state_dict
            prev_state_ids = await self._state_storage.get_state_ids_for_events(
                event_ids)

            # Check if we've joined the room? If so we just blindly add all the users to
            # the "possibly changed" users.
            for state_dict in prev_state_ids.values():
                member_event = state_dict.get((EventTypes.Member, user_id),
                                              None)
                if not member_event or member_event != current_member_id:
                    for etype, state_key in current_state_ids.keys():
                        if etype != EventTypes.Member:
                            continue
                        possibly_changed.add(state_key)
                    break

            # If there has been any change in membership, include them in the
            # possibly changed list. We'll check if they are joined below,
            # and we're not toooo worried about spuriously adding users.
            for key, event_id in current_state_ids.items():
                etype, state_key = key
                if etype != EventTypes.Member:
                    continue

                # check if this member has changed since any of the extremities
                # at the stream_ordering, and add them to the list if so.
                for state_dict in prev_state_ids.values():
                    prev_event_id = state_dict.get(key, None)
                    if not prev_event_id or prev_event_id != event_id:
                        if state_key != user_id:
                            possibly_changed.add(state_key)
                        break

        if possibly_changed or possibly_left:
            possibly_joined = possibly_changed
            possibly_left = possibly_changed | possibly_left

            # Double check if we still share rooms with the given user.
            users_rooms = await self.store.get_rooms_for_users_with_stream_ordering(
                possibly_left)
            for changed_user_id, entries in users_rooms.items():
                if any(e.room_id in room_ids for e in entries):
                    possibly_left.discard(changed_user_id)
                else:
                    possibly_joined.discard(changed_user_id)

        else:
            possibly_joined = set()
            possibly_left = set()

        result = {
            "changed": list(possibly_joined),
            "left": list(possibly_left)
        }

        log_kv(result)

        return result
Beispiel #12
0
    async def get_user_by_req(
        self,
        request: SynapseRequest,
        allow_guest: bool = False,
        rights: str = "access",
        allow_expired: bool = False,
    ) -> Requester:
        """Authenticate an incoming HTTP request and build its Requester.

        Application services are recognised first (via
        `_get_appservice_user_id`); otherwise the request's access token is
        resolved to a user through `get_user_by_access_token`.

        Args:
            request: the incoming request, carrying an access token.
            allow_guest: when False, a guest user is rejected with an AuthError.
            rights: the operation being performed; the access token must
                permit it.
            allow_expired: when True, the request is let through even if the
                account is expired or the session token lifetime has ended.
                Note that /login hands out access tokens regardless of expiry.

        Returns:
            The Requester for the authenticated entity.

        Raises:
            InvalidClientCredentialsError: no user matches the token, or the
                token is invalid.
            AuthError: access is denied for the user owning the token
                (expired account, or guest access when not allowed).
            MissingClientTokenError: any KeyError raised while authenticating
                is converted to this (presumably signalling a missing token).
        """
        try:
            client_ip = request.getClientIP()
            agent_string = get_request_user_agent(request)
            token = self.get_access_token_from_request(request)

            as_user_id, service = await self._get_appservice_user_id(request)
            if as_user_id and service:
                # Application-service path: identity comes from the AS, so no
                # per-user token lookup is performed.
                if client_ip and self._track_appservice_user_ips:
                    await self.store.insert_client_ip(
                        user_id=as_user_id,
                        access_token=token,
                        ip=client_ip,
                        user_agent=agent_string,
                        device_id="dummy-device",  # stubbed
                    )

                as_requester = create_requester(as_user_id, app_service=service)

                request.requester = as_user_id
                if as_user_id in self._force_tracing_for_users:
                    opentracing.force_tracing()
                opentracing.set_tag("authenticated_entity", as_user_id)
                opentracing.set_tag("user_id", as_user_id)
                opentracing.set_tag("appservice_id", service.id)
                return as_requester

            info = await self.get_user_by_access_token(
                token, rights, allow_expired=allow_expired
            )

            # Refuse expired accounts unless the caller tolerates them. The
            # validity handler covers both account-validity modules and the
            # legacy built-in implementation.
            expired = (
                not allow_expired
                and await self._account_validity_handler.is_user_expired(
                    info.user_id
                )
            )
            if expired:
                raise AuthError(
                    403,
                    "User account has expired",
                    errcode=Codes.EXPIRED_ACCOUNT,
                )

            if token and client_ip:
                await self.store.insert_client_ip(
                    user_id=info.token_owner,
                    access_token=token,
                    ip=client_ip,
                    user_agent=agent_string,
                    device_id=info.device_id,
                )

            if info.is_guest and not allow_guest:
                raise AuthError(
                    403,
                    "Guest access not allowed",
                    errcode=Codes.GUEST_ACCESS_FORBIDDEN,
                )

            # Record that the token has been used; this is what allows old
            # refresh tokens to be invalidated after some time.
            if token_id_needs_marking(info):
                await self.store.mark_access_token_as_used(info.token_id)

            user_requester = create_requester(
                info.user_id,
                info.token_id,
                info.is_guest,
                info.shadow_banned,
                info.device_id,
                app_service=service,
                authenticated_entity=info.token_owner,
            )

            request.requester = user_requester
            if info.token_owner in self._force_tracing_for_users:
                opentracing.force_tracing()
            opentracing.set_tag("authenticated_entity", info.token_owner)
            opentracing.set_tag("user_id", info.user_id)
            if info.device_id:
                opentracing.set_tag("device_id", info.device_id)

            return user_requester
        except KeyError:
            raise MissingClientTokenError()
Beispiel #13
0
        async def new_func(request: SynapseRequest, *args: Any,
                           **kwargs: str) -> Optional[Tuple[int, Any]]:
            """A callback which can be passed to HttpServer.RegisterPaths

            Args:
                request:
                *args: unused?
                **kwargs: the dict mapping keys to path components as specified
                    in the path match regexp.

            Returns:
                (response code, response object) as returned by the callback method.
                None if the request has already been handled.
            """
            content = None
            if request.method in [b"PUT", b"POST"]:
                # TODO: Handle other method types? other content types?
                content = parse_json_object_from_request(request)

            try:
                origin: Optional[
                    str] = await authenticator.authenticate_request(
                        request, content)
            except NoAuthenticationError:
                origin = None
                if self.REQUIRE_AUTH:
                    logger.warning(
                        "authenticate_request failed: missing authentication")
                    raise
            except Exception as e:
                logger.warning("authenticate_request failed: %s", e)
                raise

            # update the active opentracing span with the authenticated entity
            set_tag("authenticated_entity", origin)

            # if the origin is authenticated and whitelisted, link to its span context
            context = None
            if origin and whitelisted_homeserver(origin):
                context = span_context_from_request(request)

            scope = start_active_span_follows_from(
                "incoming-federation-request",
                contexts=(context, ) if context else ())

            with scope:
                if origin and self.RATELIMIT:
                    with ratelimiter.ratelimit(origin) as d:
                        await d
                        if request._disconnected:
                            logger.warning(
                                "client disconnected before we started processing "
                                "request")
                            return None
                        response = await func(origin, content, request.args,
                                              *args, **kwargs)
                else:
                    response = await func(origin, content, request.args, *args,
                                          **kwargs)

            return response
Beispiel #14
0
    async def get_user_ids_changed(self, user_id, from_token):
        """Get list of users that have had the devices updated, or have newly
        joined a room, that `user_id` may be interested in.

        Args:
            user_id (str)
            from_token (StreamToken)
        """

        set_tag("user_id", user_id)
        set_tag("from_token", from_token)
        now_room_key = await self.store.get_room_events_max_id()

        room_ids = await self.store.get_rooms_for_user(user_id)

        # First we check if any devices have changed for users that we share
        # rooms with.
        users_who_share_room = await self.store.get_users_who_share_room_with_user(
            user_id)

        tracked_users = set(users_who_share_room)

        # Always tell the user about their own devices
        tracked_users.add(user_id)

        changed = await self.store.get_users_whose_devices_changed(
            from_token.device_list_key, tracked_users)

        # Then work out if any users have since joined
        rooms_changed = self.store.get_rooms_that_changed(
            room_ids, from_token.room_key)

        member_events = await self.store.get_membership_changes_for_user(
            user_id, from_token.room_key, now_room_key)
        rooms_changed.update(event.room_id for event in member_events)

        stream_ordering = RoomStreamToken.parse_stream_token(
            from_token.room_key).stream

        possibly_changed = set(changed)
        possibly_left = set()
        for room_id in rooms_changed:
            current_state_ids = await self.store.get_current_state_ids(room_id)

            # The user may have left the room
            # TODO: Check if they actually did or if we were just invited.
            if room_id not in room_ids:
                for key, event_id in current_state_ids.items():
                    etype, state_key = key
                    if etype != EventTypes.Member:
                        continue
                    possibly_left.add(state_key)
                continue

            # Fetch the current state at the time.
            try:
                event_ids = await self.store.get_forward_extremeties_for_room(
                    room_id, stream_ordering=stream_ordering)
            except errors.StoreError:
                # we have purged the stream_ordering index since the stream
                # ordering: treat it the same as a new room
                event_ids = []

            # special-case for an empty prev state: include all members
            # in the changed list
            if not event_ids:
                log_kv({
                    "event": "encountered empty previous state",
                    "room_id": room_id
                })
                for key, event_id in current_state_ids.items():
                    etype, state_key = key
                    if etype != EventTypes.Member:
                        continue
                    possibly_changed.add(state_key)
                continue

            current_member_id = current_state_ids.get(
                (EventTypes.Member, user_id))
            if not current_member_id:
                continue

            # mapping from event_id -> state_dict
            prev_state_ids = await self.state_store.get_state_ids_for_events(
                event_ids)

            # Check if we've joined the room? If so we just blindly add all the users to
            # the "possibly changed" users.
            for state_dict in prev_state_ids.values():
                member_event = state_dict.get((EventTypes.Member, user_id),
                                              None)
                if not member_event or member_event != current_member_id:
                    for key, event_id in current_state_ids.items():
                        etype, state_key = key
                        if etype != EventTypes.Member:
                            continue
                        possibly_changed.add(state_key)
                    break

            # If there has been any change in membership, include them in the
            # possibly changed list. We'll check if they are joined below,
            # and we're not toooo worried about spuriously adding users.
            for key, event_id in current_state_ids.items():
                etype, state_key = key
                if etype != EventTypes.Member:
                    continue

                # check if this member has changed since any of the extremities
                # at the stream_ordering, and add them to the list if so.
                for state_dict in prev_state_ids.values():
                    prev_event_id = state_dict.get(key, None)
                    if not prev_event_id or prev_event_id != event_id:
                        if state_key != user_id:
                            possibly_changed.add(state_key)
                        break

        if possibly_changed or possibly_left:
            # Take the intersection of the users whose devices may have changed
            # and those that actually still share a room with the user
            possibly_joined = possibly_changed & users_who_share_room
            possibly_left = (possibly_changed
                             | possibly_left) - users_who_share_room
        else:
            possibly_joined = []
            possibly_left = []

        result = {
            "changed": list(possibly_joined),
            "left": list(possibly_left)
        }

        log_kv(result)

        return result
Beispiel #15
0
    async def get_user_by_req(
        self,
        request: Request,
        allow_guest: bool = False,
        rights: str = "access",
        allow_expired: bool = False,
    ) -> synapse.types.Requester:
        """ Get a registered user's ID.

        Args:
            request: An HTTP request with an access_token query parameter.
            allow_guest: If False, will raise an AuthError if the user making the
                request is a guest.
            rights: The operation being performed; the access token must allow this
            allow_expired: If True, allow the request through even if the account
                is expired, or session token lifetime has ended. Note that
                /login will deliver access tokens regardless of expiration.

        Returns:
            Resolves to the requester
        Raises:
            InvalidClientCredentialsError if no user by that token exists or the token
                is invalid.
            AuthError if access is denied for the user in the access token
        """
        try:
            ip_addr = self.hs.get_ip_from_request(request)
            user_agent = request.get_user_agent("")

            access_token = self.get_access_token_from_request(request)

            user_id, app_service = await self._get_appservice_user_id(request)
            if user_id:
                if ip_addr and self._track_appservice_user_ips:
                    await self.store.insert_client_ip(
                        user_id=user_id,
                        access_token=access_token,
                        ip=ip_addr,
                        user_agent=user_agent,
                        device_id="dummy-device",  # stubbed
                    )

                requester = synapse.types.create_requester(
                    user_id, app_service=app_service
                )

                request.requester = user_id
                opentracing.set_tag("authenticated_entity", user_id)
                opentracing.set_tag("user_id", user_id)
                opentracing.set_tag("appservice_id", app_service.id)

                return requester

            user_info = await self.get_user_by_access_token(
                access_token, rights, allow_expired=allow_expired
            )
            token_id = user_info.token_id
            is_guest = user_info.is_guest
            shadow_banned = user_info.shadow_banned

            # Deny the request if the user account has expired.
            if self._account_validity.enabled and not allow_expired:
                if await self.store.is_account_expired(
                    user_info.user_id, self.clock.time_msec()
                ):
                    raise AuthError(
                        403, "User account has expired", errcode=Codes.EXPIRED_ACCOUNT
                    )

            device_id = user_info.device_id

            if access_token and ip_addr:
                await self.store.insert_client_ip(
                    user_id=user_info.token_owner,
                    access_token=access_token,
                    ip=ip_addr,
                    user_agent=user_agent,
                    device_id=device_id,
                )

            if is_guest and not allow_guest:
                raise AuthError(
                    403,
                    "Guest access not allowed",
                    errcode=Codes.GUEST_ACCESS_FORBIDDEN,
                )

            requester = synapse.types.create_requester(
                user_info.user_id,
                token_id,
                is_guest,
                shadow_banned,
                device_id,
                app_service=app_service,
                authenticated_entity=user_info.token_owner,
            )

            request.requester = requester
            opentracing.set_tag("authenticated_entity", user_info.token_owner)
            opentracing.set_tag("user_id", user_info.user_id)
            if device_id:
                opentracing.set_tag("device_id", device_id)

            return requester
        except KeyError:
            raise MissingClientTokenError()
Beispiel #16
0
    async def user_device_resync(
            self,
            user_id: str,
            mark_failed_as_stale: bool = True) -> Optional[dict]:
        """Fetches all devices for a user and updates the device cache with them.

        Args:
            user_id: The user's id whose device_list will be updated.
            mark_failed_as_stale: Whether to mark the user's device list as stale
                if the attempt to resync failed.
        Returns:
            A dict with device info as under the "devices" in the result of this
            request:
            https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid
        """
        logger.debug("Attempting to resync the device list for %s", user_id)
        log_kv({"message": "Doing resync to update device list."})
        # Fetch all devices for the user.
        origin = get_domain_from_id(user_id)
        try:
            result = await self.federation.query_user_devices(origin, user_id)
        except NotRetryingDestination:
            if mark_failed_as_stale:
                # Mark the remote user's device list as stale so we know we need to retry
                # it later.
                await self.store.mark_remote_user_device_cache_as_stale(user_id
                                                                        )

            return
        except (RequestSendFailed, HttpResponseException) as e:
            logger.warning(
                "Failed to handle device list update for %s: %s",
                user_id,
                e,
            )

            if mark_failed_as_stale:
                # Mark the remote user's device list as stale so we know we need to retry
                # it later.
                await self.store.mark_remote_user_device_cache_as_stale(user_id
                                                                        )

            # We abort on exceptions rather than accepting the update
            # as otherwise synapse will 'forget' that its device list
            # is out of date. If we bail then we will retry the resync
            # next time we get a device list update for this user_id.
            # This makes it more likely that the device lists will
            # eventually become consistent.
            return
        except FederationDeniedError as e:
            set_tag("error", True)
            log_kv({"reason": "FederationDeniedError"})
            logger.info(e)
            return
        except Exception as e:
            set_tag("error", True)
            log_kv({
                "message": "Exception raised by federation request",
                "exception": e
            })
            logger.exception("Failed to handle device list update for %s",
                             user_id)

            if mark_failed_as_stale:
                # Mark the remote user's device list as stale so we know we need to retry
                # it later.
                await self.store.mark_remote_user_device_cache_as_stale(user_id
                                                                        )

            return
        log_kv({"result": result})
        stream_id = result["stream_id"]
        devices = result["devices"]

        # Get the master key and the self-signing key for this user if provided in the
        # response (None if not in the response).
        # The response will not contain the user signing key, as this key is only used by
        # its owner, thus it doesn't make sense to send it over federation.
        master_key = result.get("master_key")
        self_signing_key = result.get("self_signing_key")

        # If the remote server has more than ~1000 devices for this user
        # we assume that something is going horribly wrong (e.g. a bot
        # that logs in and creates a new device every time it tries to
        # send a message).  Maintaining lots of devices per user in the
        # cache can cause serious performance issues as if this request
        # takes more than 60s to complete, internal replication from the
        # inbound federation worker to the synapse master may time out
        # causing the inbound federation to fail and causing the remote
        # server to retry, causing a DoS.  So in this scenario we give
        # up on storing the total list of devices and only handle the
        # delta instead.
        if len(devices) > 1000:
            logger.warning(
                "Ignoring device list snapshot for %s as it has >1K devs (%d)",
                user_id,
                len(devices),
            )
            devices = []

        for device in devices:
            logger.debug(
                "Handling resync update %r/%r, ID: %r",
                user_id,
                device["device_id"],
                stream_id,
            )

        await self.store.update_remote_device_list_cache(
            user_id, devices, stream_id)
        device_ids = [device["device_id"] for device in devices]

        # Handle cross-signing keys.
        cross_signing_device_ids = await self.process_cross_signing_key_update(
            user_id,
            master_key,
            self_signing_key,
        )
        device_ids = device_ids + cross_signing_device_ids

        await self.device_handler.notify_device_update(user_id, device_ids)

        # We clobber the seen updates since we've re-synced from a given
        # point.
        self._seen_updates[user_id] = {stream_id}

        return result
Beispiel #17
0
    def _get_e2e_device_keys_txn(self,
                                 txn,
                                 query_list,
                                 include_all_devices=False,
                                 include_deleted_devices=False):
        set_tag("include_all_devices", include_all_devices)
        set_tag("include_deleted_devices", include_deleted_devices)

        query_clauses = []
        query_params = []
        signature_query_clauses = []
        signature_query_params = []

        if include_all_devices is False:
            include_deleted_devices = False

        if include_deleted_devices:
            deleted_devices = set(query_list)

        for (user_id, device_id) in query_list:
            query_clause = "user_id = ?"
            query_params.append(user_id)
            signature_query_clause = "target_user_id = ?"
            signature_query_params.append(user_id)

            if device_id is not None:
                query_clause += " AND device_id = ?"
                query_params.append(device_id)
                signature_query_clause += " AND target_device_id = ?"
                signature_query_params.append(device_id)

            signature_query_clause += " AND user_id = ?"
            signature_query_params.append(user_id)

            query_clauses.append(query_clause)
            signature_query_clauses.append(signature_query_clause)

        sql = ("SELECT user_id, device_id, "
               "    d.display_name AS device_display_name, "
               "    k.key_json"
               " FROM devices d"
               "    %s JOIN e2e_device_keys_json k USING (user_id, device_id)"
               " WHERE %s AND NOT d.hidden") % (
                   "LEFT" if include_all_devices else "INNER",
                   " OR ".join("(" + q + ")" for q in query_clauses),
               )

        txn.execute(sql, query_params)
        rows = self.db.cursor_to_dict(txn)

        result = {}
        for row in rows:
            if include_deleted_devices:
                deleted_devices.remove((row["user_id"], row["device_id"]))
            result.setdefault(row["user_id"], {})[row["device_id"]] = row

        if include_deleted_devices:
            for user_id, device_id in deleted_devices:
                result.setdefault(user_id, {})[device_id] = None

        # get signatures on the device
        signature_sql = ("SELECT *  FROM e2e_cross_signing_signatures WHERE %s"
                         ) % (" OR ".join("(" + q + ")"
                                          for q in signature_query_clauses))

        txn.execute(signature_sql, signature_query_params)
        rows = self.db.cursor_to_dict(txn)

        # add each cross-signing signature to the correct device in the result dict.
        for row in rows:
            signing_user_id = row["user_id"]
            signing_key_id = row["key_id"]
            target_user_id = row["target_user_id"]
            target_device_id = row["target_device_id"]
            signature = row["signature"]

            target_user_result = result.get(target_user_id)
            if not target_user_result:
                continue

            target_device_result = target_user_result.get(target_device_id)
            if not target_device_result:
                # note that target_device_result will be None for deleted devices.
                continue

            target_device_signatures = target_device_result.setdefault(
                "signatures", {})
            signing_user_signatures = target_device_signatures.setdefault(
                signing_user_id, {})
            signing_user_signatures[signing_key_id] = signature

        log_kv(result)
        return result
Beispiel #18
0
    def claim_one_time_keys(self, query, timeout):
        local_query = []
        remote_queries = {}

        for user_id, device_keys in query.get("one_time_keys", {}).items():
            # we use UserID.from_string to catch invalid user ids
            if self.is_mine(UserID.from_string(user_id)):
                for device_id, algorithm in device_keys.items():
                    local_query.append((user_id, device_id, algorithm))
            else:
                domain = get_domain_from_id(user_id)
                remote_queries.setdefault(domain, {})[user_id] = device_keys

        set_tag("local_key_query", local_query)
        set_tag("remote_key_query", remote_queries)

        results = yield self.store.claim_e2e_one_time_keys(local_query)

        json_result = {}
        failures = {}
        for user_id, device_keys in results.items():
            for device_id, keys in device_keys.items():
                for key_id, json_bytes in keys.items():
                    json_result.setdefault(user_id, {})[device_id] = {
                        key_id: json.loads(json_bytes)
                    }

        @trace
        @defer.inlineCallbacks
        def claim_client_keys(destination):
            set_tag("destination", destination)
            device_keys = remote_queries[destination]
            try:
                remote_result = yield self.federation.claim_client_keys(
                    destination, {"one_time_keys": device_keys},
                    timeout=timeout)
                for user_id, keys in remote_result["one_time_keys"].items():
                    if user_id in device_keys:
                        json_result[user_id] = keys

            except Exception as e:
                failure = _exception_to_failure(e)
                failures[destination] = failure
                set_tag("error", True)
                set_tag("reason", failure)

        yield make_deferred_yieldable(
            defer.gatherResults(
                [
                    run_in_background(claim_client_keys, destination)
                    for destination in remote_queries
                ],
                consumeErrors=True,
            ))

        logger.info(
            "Claimed one-time-keys: %s",
            ",".join(("%s for %s:%s" % (key_id, user_id, device_id)
                      for user_id, user_keys in iteritems(json_result)
                      for device_id, device_keys in iteritems(user_keys)
                      for key_id, _ in iteritems(device_keys))),
        )

        log_kv({"one_time_keys": json_result, "failures": failures})
        return {"one_time_keys": json_result, "failures": failures}
Beispiel #19
0
    def on_PUT(self, request: SynapseRequest, room_id: str,
               txn_id: str) -> Awaitable[Tuple[int, JsonDict]]:
        """Handle the PUT (client-supplied transaction ID) form of the endpoint.

        Records ``txn_id`` as a tracing tag, then hands the request to
        ``self.on_POST`` via ``self.txns.fetch_or_execute_request``, forwarding
        ``request``, ``room_id`` and ``txn_id`` as the handler's arguments.
        NOTE(review): presumably the transaction cache answers a repeated PUT
        with the same ``txn_id`` without re-running ``on_POST`` — confirm
        against the ``txns`` implementation.
        """
        set_tag("txn_id", txn_id)
        post_args = (request, room_id, txn_id)
        return self.txns.fetch_or_execute_request(
            request, self.on_POST, *post_args
        )
Beispiel #20
0
    async def request(
        self,
        method: str,
        uri: str,
        data: Optional[bytes] = None,
        headers: Optional[Headers] = None,
    ) -> IResponse:
        """
        Send an HTTP request through ``self.agent`` and wait for the response
        headers, with a fixed 60s timeout.

        Args:
            method: HTTP method to use.
            uri: URI to query.
            data: Data to send in the request body, if applicable.
            headers: Request headers.

        Returns:
            Response object, once the headers have been read.

        Raises:
            RequestTimedOutError if the request times out before the headers are read.
            Any other exception from sending the request is counted, logged,
            tagged on the tracing span and re-raised unchanged.
        """
        outgoing_requests_counter.labels(method).inc()

        # log request but strip `access_token` (AS requests for example include this)
        logger.debug("Sending request %s %s", method, redact_uri(uri))

        with start_active_span(
            "outgoing-client-request",
            tags={
                tags.SPAN_KIND: tags.SPAN_KIND_RPC_CLIENT,
                tags.HTTP_METHOD: method,
                tags.HTTP_URL: uri,
            },
            finish_on_close=True,
        ):
            try:
                body_producer = None
                if data is not None:
                    body_producer = QuieterFileBodyProducer(
                        BytesIO(data),
                        cooperator=self._cooperator,
                    )

                request_deferred = treq.request(
                    method,
                    uri,
                    agent=self.agent,
                    data=body_producer,
                    headers=headers,
                    **self._extra_treq_args,
                )  # type: defer.Deferred

                # we use our own timeout mechanism rather than treq's as a workaround
                # for https://twistedmatrix.com/trac/ticket/9534.
                request_deferred = timeout_deferred(
                    request_deferred,
                    60,
                    self.hs.get_reactor(),
                )

                # turn timeouts into RequestTimedOutErrors
                request_deferred.addErrback(
                    _timeout_to_request_timed_out_error)

                response = await make_deferred_yieldable(request_deferred)

                incoming_responses_counter.labels(method, response.code).inc()
                logger.info(
                    "Received response to %s %s: %s",
                    method,
                    redact_uri(uri),
                    response.code,
                )
                return response
            except Exception as e:
                incoming_responses_counter.labels(method, "ERR").inc()
                # An exception raised with no arguments has an empty `args`
                # tuple; indexing it unconditionally would raise IndexError
                # here and mask the original failure.
                error_reason = e.args[0] if e.args else repr(e)
                logger.info(
                    "Error sending request to  %s %s: %s %s",
                    method,
                    redact_uri(uri),
                    type(e).__name__,
                    error_reason,
                )
                set_tag(tags.ERROR, True)
                set_tag("error_reason", error_reason)
                raise
Beispiel #21
0
    async def _send_request(
        self,
        request: MatrixFederationRequest,
        retry_on_dns_fail: bool = True,
        timeout: Optional[int] = None,
        long_retries: bool = False,
        ignore_backoff: bool = False,
        backoff_on_404: bool = False,
    ) -> IResponse:
        """
        Sends a request to the given server.

        Args:
            request: details of request to be sent

            retry_on_dns_fail: true if the request should be retried on DNS failures

            timeout: number of milliseconds to wait for the response headers
                (including connecting to the server), *for each attempt*.
                60s by default.

                Note: automatic retries are only performed when `timeout` is
                not set (falsy). With an explicit timeout, a retryable failure
                is raised after the first attempt.

            long_retries: whether to use the long retry algorithm.

                The regular retry algorithm makes 4 attempts, with intervals
                [0.5s, 1s, 2s].

                The long retry algorithm makes 11 attempts, with intervals
                [4s, 16s, 60s, 60s, ...]

                Both algorithms add -20%/+40% jitter to the retry intervals.

                Note that the above intervals are *in addition* to the time spent
                waiting for the request to complete (up to `timeout` ms).

                NB: the long retry algorithm takes over 20 minutes to complete, with
                a default timeout of 60s!

            ignore_backoff: true to ignore the historical backoff data
                and try the request anyway.

            backoff_on_404: Back off if we get a 404

        Returns:
            Resolves with the HTTP response object on success (a 2xx
            response; its body has not been read yet).

        Raises:
            HttpResponseException: If we get an HTTP response code >= 300
                (except 429).
            NotRetryingDestination: If we are not yet ready to retry this
                server.
            FederationDeniedError: If this destination is not on our
                federation whitelist
            RequestSendFailed: If there were problems connecting to the
                remote, due to e.g. DNS failures, connection timeouts etc.
        """
        # `timeout` is in milliseconds; a falsy value (None or 0) falls back
        # to self.default_timeout.
        if timeout:
            _sec_timeout = timeout / 1000
        else:
            _sec_timeout = self.default_timeout

        # Refuse to contact servers outside the configured whitelist (if any).
        if (self.hs.config.federation_domain_whitelist is not None
                and request.destination
                not in self.hs.config.federation_domain_whitelist):
            raise FederationDeniedError(request.destination)

        # Per-destination retry limiter; it is entered around the whole
        # send/retry loop below.
        limiter = await synapse.util.retryutils.get_retry_limiter(
            request.destination,
            self.clock,
            self._store,
            backoff_on_404=backoff_on_404,
            ignore_backoff=ignore_backoff,
        )

        # Byte forms are needed both for signing and for the Agent call.
        method_bytes = request.method.encode("ascii")
        destination_bytes = request.destination.encode("ascii")
        path_bytes = request.path.encode("ascii")
        if request.query:
            query_bytes = encode_query_args(request.query)
        else:
            query_bytes = b""

        scope = start_active_span(
            "outgoing-federation-request",
            tags={
                tags.SPAN_KIND: tags.SPAN_KIND_RPC_CLIENT,
                tags.PEER_ADDRESS: request.destination,
                tags.HTTP_METHOD: request.method,
                tags.HTTP_URL: request.path,
            },
            finish_on_close=True,
        )

        # Inject the span into the headers
        headers_dict = {}  # type: Dict[bytes, List[bytes]]
        inject_active_span_byte_dict(headers_dict, request.destination)

        headers_dict[b"User-Agent"] = [self.version_string_bytes]

        with limiter, scope:
            # XXX: Would be much nicer to retry only at the transaction-layer
            # (once we have reliable transactions in place)
            if long_retries:
                retries_left = MAX_LONG_RETRIES
            else:
                retries_left = MAX_SHORT_RETRIES

            url_bytes = request.uri
            url_str = url_bytes.decode("ascii")

            # The URL handed to the signer is path + query only (no scheme
            # or host).
            url_to_sign_bytes = urllib.parse.urlunparse(
                (b"", b"", path_bytes, None, query_bytes, b""))

            while True:
                try:
                    # The body and the Authorization header are rebuilt on
                    # every attempt.
                    json = request.get_json()
                    if json:
                        headers_dict[b"Content-Type"] = [b"application/json"]
                        auth_headers = self.build_auth_headers(
                            destination_bytes, method_bytes, url_to_sign_bytes,
                            json)
                        data = encode_canonical_json(json)
                        producer = QuieterFileBodyProducer(
                            BytesIO(data), cooperator=self._cooperator
                        )  # type: Optional[IBodyProducer]
                    else:
                        producer = None
                        auth_headers = self.build_auth_headers(
                            destination_bytes, method_bytes, url_to_sign_bytes)

                    headers_dict[b"Authorization"] = auth_headers

                    logger.debug(
                        "{%s} [%s] Sending request: %s %s; timeout %fs",
                        request.txn_id,
                        request.destination,
                        request.method,
                        url_str,
                        _sec_timeout,
                    )

                    outgoing_requests_counter.labels(request.method).inc()

                    # Any failure to obtain the response headers is wrapped in
                    # RequestSendFailed so the retry handler below sees it.
                    # DNS failures are only retryable if retry_on_dns_fail.
                    try:
                        with Measure(self.clock, "outbound_request"):
                            # we don't want all the fancy cookie and redirect handling
                            # that treq.request gives: just use the raw Agent.
                            request_deferred = self.agent.request(
                                method_bytes,
                                url_bytes,
                                headers=Headers(headers_dict),
                                bodyProducer=producer,
                            )

                            request_deferred = timeout_deferred(
                                request_deferred,
                                timeout=_sec_timeout,
                                reactor=self.reactor,
                            )

                            response = await request_deferred
                    except DNSLookupError as e:
                        raise RequestSendFailed(
                            e, can_retry=retry_on_dns_fail) from e
                    except Exception as e:
                        raise RequestSendFailed(e, can_retry=True) from e

                    incoming_responses_counter.labels(request.method,
                                                      response.code).inc()

                    set_tag(tags.HTTP_STATUS_CODE, response.code)
                    response_phrase = response.phrase.decode("ascii",
                                                             errors="replace")

                    if 200 <= response.code < 300:
                        # Success: the response body is left unread for the
                        # caller.
                        logger.debug(
                            "{%s} [%s] Got response headers: %d %s",
                            request.txn_id,
                            request.destination,
                            response.code,
                            response_phrase,
                        )
                        pass
                    else:
                        logger.info(
                            "{%s} [%s] Got response headers: %d %s",
                            request.txn_id,
                            request.destination,
                            response.code,
                            response_phrase,
                        )
                        # :'(
                        # Update transactions table?
                        # Read the error body (bounded by the same per-attempt
                        # timeout) so it can be attached to the exception.
                        d = treq.content(response)
                        d = timeout_deferred(d,
                                             timeout=_sec_timeout,
                                             reactor=self.reactor)

                        try:
                            body = await make_deferred_yieldable(d)
                        except Exception as e:
                            # Eh, we're already going to raise an exception so lets
                            # ignore if this fails.
                            logger.warning(
                                "{%s} [%s] Failed to get error response: %s %s: %s",
                                request.txn_id,
                                request.destination,
                                request.method,
                                url_str,
                                _flatten_response_never_received(e),
                            )
                            body = None

                        exc = HttpResponseException(response.code,
                                                    response_phrase, body)

                        # Retry if the error is a 429 (Too Many Requests),
                        # otherwise just raise a standard HttpResponseException
                        if response.code == 429:
                            raise RequestSendFailed(exc,
                                                    can_retry=True) from exc
                        else:
                            raise exc

                    # 2xx response: leave the retry loop.
                    break
                except RequestSendFailed as e:
                    logger.info(
                        "{%s} [%s] Request failed: %s %s: %s",
                        request.txn_id,
                        request.destination,
                        request.method,
                        url_str,
                        _flatten_response_never_received(e.inner_exception),
                    )

                    # Non-retryable failures propagate immediately.
                    if not e.can_retry:
                        raise

                    # Automatic retries only happen when the caller did not
                    # pass an explicit timeout and attempts remain.
                    if retries_left and not timeout:
                        # Exponential backoff, capped, with -20%/+40% jitter.
                        if long_retries:
                            delay = 4**(MAX_LONG_RETRIES + 1 - retries_left)
                            delay = min(delay, 60)
                            delay *= random.uniform(0.8, 1.4)
                        else:
                            delay = 0.5 * 2**(MAX_SHORT_RETRIES - retries_left)
                            delay = min(delay, 2)
                            delay *= random.uniform(0.8, 1.4)

                        logger.debug(
                            "{%s} [%s] Waiting %ss before re-sending...",
                            request.txn_id,
                            request.destination,
                            delay,
                        )

                        await self.clock.sleep(delay)
                        retries_left -= 1
                    else:
                        raise

                except Exception as e:
                    # Any other exception is logged and re-raised without retrying.
                    logger.warning(
                        "{%s} [%s] Request failed: %s %s: %s",
                        request.txn_id,
                        request.destination,
                        request.method,
                        url_str,
                        _flatten_response_never_received(e),
                    )
                    raise
        return response
    async def get_e2e_device_keys_and_signatures(
        self,
        query_list: List[Tuple[str, Optional[str]]],
        include_all_devices: bool = False,
        include_deleted_devices: bool = False,
    ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
        """Look up device keys and merge in the owner's cross-signatures.

        Device keys are fetched with ``_get_e2e_device_keys_txn``. For every
        device whose lookup result is present and has a non-None ``keys``
        dict, cross-signatures are fetched in batches of 50. Each one is
        merged into ``keys["signatures"][user_id][key_id]`` of that device.

        Args:
            query_list: (user_id, device_id) pairs; a device_id of None
                means "all devices for this user".
            include_all_devices: whether to return devices without device keys.
            include_deleted_devices: whether to include null entries for
                devices which no longer exist (but were in the query_list).
                Only takes effect if include_all_devices is true.

        Returns:
            Dict of user_id -> device_id -> lookup result (or None), with the
            cross-signatures merged in place.
        """
        set_tag("include_all_devices", include_all_devices)
        set_tag("include_deleted_devices", include_deleted_devices)

        devices_by_user = await self.db_pool.runInteraction(
            "get_e2e_device_keys",
            self._get_e2e_device_keys_txn,
            query_list,
            include_all_devices,
            include_deleted_devices,
        )

        # Only devices with a result and key data can hold signatures, so
        # only those are looked up (lazily, batch by batch).
        devices_with_keys = (
            (owner, dev_id)
            for owner, per_device in devices_by_user.items()
            for dev_id, lookup in per_device.items()
            if lookup is not None and lookup.keys is not None
        )

        for chunk in batch_iter(devices_with_keys, 50):
            sig_rows = await self.db_pool.runInteraction(
                "get_e2e_cross_signing_signatures",
                self._get_e2e_cross_signing_signatures_for_devices_txn,
                chunk,
            )

            # Merge each signature into the matching device's key dict.
            for owner, key_id, dev_id, signature in sig_rows:
                device_keys = devices_by_user[owner][dev_id].keys
                sigs_by_user = device_keys.setdefault("signatures", {})
                sigs_by_user.setdefault(owner, {})[key_id] = signature

        log_kv(devices_by_user)
        return devices_by_user
Beispiel #23
0
    async def send_new_transaction(
        self,
        destination: str,
        pdus: List[EventBase],
        edus: List[Edu],
    ) -> None:
        """
        Args:
            destination: The destination to send to (e.g. 'example.org')
            pdus: In-order list of PDUs to send
            edus: List of EDUs to send
        """

        # Make a transaction-sending opentracing span. This span follows on from
        # all the edus in that transaction. This needs to be done since there is
        # no active span here, so if the edus were not received by the remote the
        # span would have no causality and it would be forgotten.

        span_contexts = []
        keep_destination = whitelisted_homeserver(destination)

        for edu in edus:
            context = edu.get_context()
            if context:
                span_contexts.append(extract_text_map(json_decoder.decode(context)))
            if keep_destination:
                edu.strip_context()

        with start_active_span_follows_from("send_transaction", span_contexts):
            logger.debug("TX [%s] _attempt_new_transaction", destination)

            txn_id = str(self._next_txn_id)

            logger.debug(
                "TX [%s] {%s} Attempting new transaction (pdus: %d, edus: %d)",
                destination,
                txn_id,
                len(pdus),
                len(edus),
            )

            transaction = Transaction(
                origin_server_ts=int(self.clock.time_msec()),
                transaction_id=txn_id,
                origin=self._server_name,
                destination=destination,
                pdus=[p.get_pdu_json() for p in pdus],
                edus=[edu.get_dict() for edu in edus],
            )

            self._next_txn_id += 1

            logger.info(
                "TX [%s] {%s} Sending transaction [%s], (PDUs: %d, EDUs: %d)",
                destination,
                txn_id,
                transaction.transaction_id,
                len(pdus),
                len(edus),
            )
            if issue_8631_logger.isEnabledFor(logging.DEBUG):
                DEVICE_UPDATE_EDUS = {
                    EduTypes.DEVICE_LIST_UPDATE,
                    EduTypes.SIGNING_KEY_UPDATE,
                }
                device_list_updates = [
                    edu.content for edu in edus if edu.edu_type in DEVICE_UPDATE_EDUS
                ]
                if device_list_updates:
                    issue_8631_logger.debug(
                        "about to send txn [%s] including device list updates: %s",
                        transaction.transaction_id,
                        device_list_updates,
                    )

            # Actually send the transaction

            # FIXME (erikj): This is a bit of a hack to make the Pdu age
            # keys work
            # FIXME (richardv): I also believe it no longer works. We (now?) store
            #  "age_ts" in "unsigned" rather than at the top level. See
            #  https://github.com/matrix-org/synapse/issues/8429.
            def json_data_cb() -> JsonDict:
                data = transaction.get_dict()
                now = int(self.clock.time_msec())
                if "pdus" in data:
                    for p in data["pdus"]:
                        if "age_ts" in p:
                            unsigned = p.setdefault("unsigned", {})
                            unsigned["age"] = now - int(p["age_ts"])
                            del p["age_ts"]
                return data

            try:
                response = await self._transport_layer.send_transaction(
                    transaction, json_data_cb
                )
            except HttpResponseException as e:
                code = e.code

                set_tag(tags.ERROR, True)

                logger.info("TX [%s] {%s} got %d response", destination, txn_id, code)
                raise

            logger.info("TX [%s] {%s} got 200 response", destination, txn_id)

            for e_id, r in response.get("pdus", {}).items():
                if "error" in r:
                    logger.warning(
                        "TX [%s] {%s} Remote returned error for %s: %s",
                        destination,
                        txn_id,
                        e_id,
                        r,
                    )

            if pdus and destination in self._federation_metrics_domains:
                last_pdu = pdus[-1]
                last_pdu_ts_metric.labels(server_name=destination).set(
                    last_pdu.origin_server_ts / 1000
                )