Example #1
    async def on_rdata(self, stream_name, token, rows):
        await super(ASReplicationHandler,
                    self).on_rdata(stream_name, token, rows)

        if stream_name == "events":
            max_stream_id = self.store.get_room_max_stream_ordering()
            run_in_background(self._notify_app_services, max_stream_id)
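All of these examples center on Synapse's run_in_background helper: it starts a coroutine (or Deferred-returning function) immediately, in the caller's logcontext, and returns a Deferred that the caller usually does not await. A minimal sketch of the fire-and-forget pattern, with illustrative names (_notify_downstream and on_new_data are not Synapse functions):

    from synapse.logging.context import run_in_background

    async def _notify_downstream(stream_id: int) -> None:
        ...  # illustrative, non-critical follow-up work

    def on_new_data(stream_id: int) -> None:
        # Start the follow-up immediately but do not await it: the caller
        # returns while the work continues in the background. If the
        # returned Deferred fails and nothing handles it, Twisted only
        # logs the failure when the Deferred is garbage collected, so
        # important work should attach an explicit errback.
        run_in_background(_notify_downstream, stream_id)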
Example #2
    def start_purge_history(self, room_id, token, delete_local_events=False):
        """Start off a history purge on a room.

        Args:
            room_id (str): The room to purge from

            token (str): topological token to delete events before
            delete_local_events (bool): True to delete local events as well as
                remote ones

        Returns:
            str: unique ID for this purge transaction.
        """
        if room_id in self._purges_in_progress_by_room:
            raise SynapseError(
                400, "History purge already in progress for %s" % (room_id,)
            )

        purge_id = random_string(16)

        # we log the purge_id here so that it can be tied back to the
        # request id in the log lines.
        logger.info("[purge] starting purge_id %s", purge_id)

        self._purges_by_id[purge_id] = PurgeStatus()
        run_in_background(
            self._purge_history, purge_id, room_id, token, delete_local_events
        )
        return purge_id
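Usage sketch: the caller gets purge_id back immediately while the purge runs in the background, and can poll the in-memory PurgeStatus record later (the get_purge_status accessor is assumed here for illustration):

    # Kick off the purge; this call returns at once.
    purge_id = pagination_handler.start_purge_history(room_id, token)

    # Later, e.g. from an admin status endpoint (accessor name assumed):
    status = pagination_handler.get_purge_status(purge_id)
    if status is not None:
        print(status.asdict())  # e.g. {"status": "active"}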
Example #3
    async def store_file(self, path: str, file_info: FileInfo) -> None:
        if not file_info.server_name and not self.store_local:
            return None

        if file_info.server_name and not self.store_remote:
            return None

        if file_info.url_cache:
            # The URL preview cache is short lived and not worth offloading or
            # backing up.
            return None

        if self.store_synchronous:
            # store_file is supposed to return an Awaitable, but guard
            # against improper implementations.
            await maybe_awaitable(
                self.backend.store_file(path, file_info)
            )  # type: ignore
        else:
            # TODO: Handle errors.
            async def store() -> None:
                try:
                    return await maybe_awaitable(
                        self.backend.store_file(path, file_info))
                except Exception:
                    logger.exception("Error storing file")

            run_in_background(store)
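maybe_awaitable is what lets this code accept storage backends whose store_file is either synchronous or async. A rough functional equivalent, for illustration (not Synapse's exact implementation):

    import inspect

    async def maybe_awaitable(value):
        # Await coroutines/Deferreds; pass plain return values through.
        if inspect.isawaitable(value):
            return await value
        return value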
Example #4
    def _handle_timeouts(self):
        logger.debug("Checking for typing timeouts")

        now = self.clock.time_msec()

        members = set(self.wheel_timer.fetch(now))

        for member in members:
            if not self.is_typing(member):
                # Nothing to do if they're no longer typing
                continue

            until = self._member_typing_until.get(member, None)
            if not until or until <= now:
                logger.info("Timing out typing for: %s", member.user_id)
                self._stopped_typing(member)
                continue

            # Check if we need to resend a keep alive over federation for this
            # user.
            if self.hs.is_mine_id(member.user_id):
                last_fed_poke = self._member_last_federation_poke.get(
                    member, None)
                if not last_fed_poke or last_fed_poke + FEDERATION_PING_INTERVAL <= now:
                    run_in_background(self._push_remote,
                                      member=member,
                                      typing=True)

            # Add a paranoia timer to ensure that we always have a timer for
            # each person typing.
            self.wheel_timer.insert(now=now, obj=member, then=now + 60 * 1000)
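The wheel timer buckets entries by deadline so expired ones can be fetched cheaply on each tick. A small sketch of the interface this handler relies on (synapse.util.wheel_timer.WheelTimer):

    from synapse.util.wheel_timer import WheelTimer

    timer = WheelTimer()
    now_ms = 1_000_000
    # Schedule "member" to fire roughly 60 seconds from now.
    timer.insert(now=now_ms, obj="member", then=now_ms + 60 * 1000)
    # fetch() returns every object whose deadline has passed.
    due = timer.fetch(now_ms + 61 * 1000)  # -> ["member"]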
Example #5
    async def _handle_incoming_transaction(
        self, origin: str, transaction: Transaction, request_time: int
    ) -> Tuple[int, Dict[str, Any]]:
        """ Process an incoming transaction and return the HTTP response

        Args:
            origin: the server making the request
            transaction: incoming transaction
            request_time: timestamp that the HTTP request arrived at

        Returns:
            HTTP response code and body
        """
        response = await self.transaction_actions.have_responded(origin, transaction)

        if response:
            logger.debug(
                "[%s] We've already responded to this request",
                transaction.transaction_id,  # type: ignore
            )
            return response

        logger.debug("[%s] Transaction is new", transaction.transaction_id)  # type: ignore

        # Reject if PDU count > 50 or EDU count > 100
        if len(transaction.pdus) > 50 or (  # type: ignore
            hasattr(transaction, "edus") and len(transaction.edus) > 100  # type: ignore
        ):

            logger.info("Transaction PDU or EDU count too large. Returning 400")

            response = {}
            await self.transaction_actions.set_response(
                origin, transaction, 400, response
            )
            return 400, response

        # We process PDUs and EDUs in parallel. This is important as we don't
        # want to block things like to device messages from reaching clients
        # behind the potentially expensive handling of PDUs.
        pdu_results, _ = await make_deferred_yieldable(
            defer.gatherResults(
                [
                    run_in_background(
                        self._handle_pdus_in_txn, origin, transaction, request_time
                    ),
                    run_in_background(self._handle_edus_in_txn, origin, transaction),
                ],
                consumeErrors=True,
            ).addErrback(unwrapFirstError)
        )

        response = {"pdus": pdu_results}

        logger.debug("Returning: %s", str(response))

        await self.transaction_actions.set_response(origin, transaction, 200, response)
        return 200, response
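The last block is the standard Synapse fan-out idiom, which recurs in several examples below: start each task with run_in_background so it begins in the current logcontext, gather the Deferreds, then re-enter the logcontext with make_deferred_yieldable before awaiting. A minimal sketch (task_a and task_b are illustrative zero-argument coroutine functions):

    from twisted.internet import defer

    from synapse.logging.context import make_deferred_yieldable, run_in_background
    from synapse.util import unwrapFirstError

    async def fan_out(task_a, task_b):
        # Both tasks start immediately; gatherResults preserves order.
        # consumeErrors=True suppresses "unhandled error" noise from the
        # other failures, and unwrapFirstError re-raises the underlying
        # exception instead of Twisted's FirstError wrapper.
        result_a, result_b = await make_deferred_yieldable(
            defer.gatherResults(
                [run_in_background(task_a), run_in_background(task_b)],
                consumeErrors=True,
            ).addErrback(unwrapFirstError)
        )
        return result_a, result_b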
Example #6
    def _verify_objects(
        self, verify_requests: Iterable[VerifyJsonRequest]
    ) -> List[defer.Deferred]:
        """Does the work of verify_json_[objects_]for_server


        Args:
            verify_requests: Iterable of verification requests.

        Returns:
            List<Deferred[None]>: for each input item, a deferred indicating success
                or failure to verify each json object's signature for the given
                server_name. The deferreds run their callbacks in the sentinel
                logcontext.
        """
        # a list of VerifyJsonRequests which are awaiting a key lookup
        key_lookups = []
        handle = preserve_fn(_handle_key_deferred)

        def process(verify_request: VerifyJsonRequest) -> defer.Deferred:
            """Process an entry in the request list

            Adds a key request to key_lookups, and returns a deferred which
            will complete or fail (in the sentinel context) when verification completes.
            """
            if not verify_request.key_ids:
                return defer.fail(
                    SynapseError(
                        400,
                        "Not signed by %s" % (verify_request.server_name,),
                        Codes.UNAUTHORIZED,
                    )
                )

            logger.debug(
                "Verifying %s for %s with key_ids %s, min_validity %i",
                verify_request.request_name,
                verify_request.server_name,
                verify_request.key_ids,
                verify_request.minimum_valid_until_ts,
            )

            # add the key request to the queue, but don't start it off yet.
            key_lookups.append(verify_request)

            # now run _handle_key_deferred, which will wait for the key request
            # to complete and then do the verification.
            #
            # We want _handle_key_request to log to the right context, so we
            # wrap it with preserve_fn (aka run_in_background)
            return handle(verify_request)

        results = [process(r) for r in verify_requests]

        if key_lookups:
            run_in_background(self._start_key_lookups, key_lookups)

        return results
Example #7
    async def authenticate_request(self, request, content):
        now = self._clock.time_msec()
        json_request = {
            "method": request.method.decode("ascii"),
            "uri": request.uri.decode("ascii"),
            "destination": self.server_name,
            "signatures": {},
        }

        if content is not None:
            json_request["content"] = content

        origin = None

        auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")

        if not auth_headers:
            raise NoAuthenticationError(
                401, "Missing Authorization headers", Codes.UNAUTHORIZED
            )

        for auth in auth_headers:
            if auth.startswith(b"X-Matrix"):
                (origin, key, sig) = _parse_auth_header(auth)
                json_request["origin"] = origin
                json_request["signatures"].setdefault(origin, {})[key] = sig

        if (
            self.federation_domain_whitelist is not None
            and origin not in self.federation_domain_whitelist
        ):
            raise FederationDeniedError(origin)

        if origin is None or not json_request["signatures"]:
            raise NoAuthenticationError(
                401, "Missing Authorization headers", Codes.UNAUTHORIZED
            )

        await self.keyring.verify_json_for_server(
            origin,
            json_request,
            now,
        )

        logger.debug("Request from %s", origin)
        request.requester = origin

        # If we get a valid signed request from the other side, it's probably
        # alive
        retry_timings = await self.store.get_destination_retry_timings(origin)
        if retry_timings and retry_timings.retry_last_ts:
            run_in_background(self._reset_retry_timings, origin)

        return origin
Example #8
    def _get_server_verify_keys(self, verify_requests):
        """Tries to find at least one key for each verify request

        For each verify_request, verify_request.key_ready is called back with
        params (server_name, key_id, VerifyKey) if a key is found, or errbacked
        with a SynapseError if none of the keys are found.

        Args:
            verify_requests (list[VerifyJsonRequest]): list of verify requests
        """

        remaining_requests = {
            rq for rq in verify_requests if not rq.key_ready.called
        }

        @defer.inlineCallbacks
        def do_iterations():
            with Measure(self.clock, "get_server_verify_keys"):
                for f in self._key_fetchers:
                    if not remaining_requests:
                        return
                    yield self._attempt_key_fetches_with_fetcher(
                        f, remaining_requests)

                # look for any requests which weren't satisfied
                with PreserveLoggingContext():
                    for verify_request in remaining_requests:
                        verify_request.key_ready.errback(
                            SynapseError(
                                401,
                                "No key for %s with ids in %s (min_validity %i)"
                                % (
                                    verify_request.server_name,
                                    verify_request.key_ids,
                                    verify_request.minimum_valid_until_ts,
                                ),
                                Codes.UNAUTHORIZED,
                            ))

        def on_err(err):
            # we don't really expect to get here, because any errors should already
            # have been caught and logged. But if we do, let's log the error and make
            # sure that all of the deferreds are resolved.
            logger.error("Unexpected error in _get_server_verify_keys: %s",
                         err)
            with PreserveLoggingContext():
                for verify_request in remaining_requests:
                    if not verify_request.key_ready.called:
                        verify_request.key_ready.errback(err)

        run_in_background(do_iterations).addErrback(on_err)
Example #9
    @defer.inlineCallbacks
    def send(self, service, events):
        try:
            txn = yield self.store.create_appservice_txn(
                service=service, events=events
            )
            service_is_up = yield self._is_service_up(service)
            if service_is_up:
                sent = yield txn.send(self.as_api)
                if sent:
                    yield txn.complete(self.store)
                else:
                    run_in_background(self._start_recoverer, service)
        except Exception:
            logger.exception("Error creating appservice transaction")
            run_in_background(self._start_recoverer, service)
Example #10
    async def send(self, service, events):
        try:
            txn = await self.store.create_appservice_txn(
                service=service, events=events
            )
            service_is_up = await self._is_service_up(service)
            if service_is_up:
                sent = await txn.send(self.as_api)
                if sent:
                    await txn.complete(self.store)
                else:
                    run_in_background(self._on_txn_fail, service)
        except Exception:
            logger.exception("Error creating appservice transaction")
            run_in_background(self._on_txn_fail, service)
Example #11
    @defer.inlineCallbacks
    def _renew_attestations(self):
        """Called periodically to check if we need to update any of our attestations
        """

        now = self.clock.time_msec()

        rows = yield self.store.get_attestations_need_renewals(
            now + UPDATE_ATTESTATION_TIME_MS)

        @defer.inlineCallbacks
        def _renew_attestation(group_id, user_id):
            try:
                if not self.is_mine_id(group_id):
                    destination = get_domain_from_id(group_id)
                elif not self.is_mine_id(user_id):
                    destination = get_domain_from_id(user_id)
                else:
                    logger.warning(
                        "Incorrectly trying to do attestations for user: %r in %r",
                        user_id,
                        group_id,
                    )
                    yield self.store.remove_attestation_renewal(
                        group_id, user_id)
                    return

                attestation = self.attestations.create_attestation(
                    group_id, user_id)

                yield self.transport_client.renew_group_attestation(
                    destination,
                    group_id,
                    user_id,
                    content={"attestation": attestation})

                yield self.store.update_attestation_renewal(
                    group_id, user_id, attestation)
            except (RequestSendFailed, HttpResponseException) as e:
                logger.warning("Failed to renew attestation of %r in %r: %s",
                               user_id, group_id, e)
            except Exception:
                logger.exception("Error renewing attestation of %r in %r",
                                 user_id, group_id)

        for row in rows:
            group_id = row["group_id"]
            user_id = row["user_id"]

            run_in_background(_renew_attestation, group_id, user_id)
Example #12
    async def send(
        self,
        service: ApplicationService,
        events: List[EventBase],
        ephemeral: Optional[List[JsonDict]] = None,
        to_device_messages: Optional[List[JsonDict]] = None,
        one_time_key_counts: Optional[TransactionOneTimeKeyCounts] = None,
        unused_fallback_keys: Optional[TransactionUnusedFallbackKeys] = None,
        device_list_summary: Optional[DeviceListUpdates] = None,
    ) -> None:
        """
        Create a transaction with the given data and send to the provided
        application service.

        Args:
            service: The application service to send the transaction to.
            events: The persistent events to include in the transaction.
            ephemeral: The ephemeral events to include in the transaction.
            to_device_messages: The to-device messages to include in the transaction.
            one_time_key_counts: Counts of remaining one-time keys for relevant
                appservice devices in the transaction.
            unused_fallback_keys: Lists of unused fallback keys for relevant
                appservice devices in the transaction.
            device_list_summary: The device list summary to include in the transaction.
        """
        try:
            service_is_up = await self._is_service_up(service)
            # Don't create empty txns when in recovery mode (ephemeral events are dropped)
            if not service_is_up and not events:
                return

            txn = await self.store.create_appservice_txn(
                service=service,
                events=events,
                ephemeral=ephemeral or [],
                to_device_messages=to_device_messages or [],
                one_time_key_counts=one_time_key_counts or {},
                unused_fallback_keys=unused_fallback_keys or {},
                device_list_summary=device_list_summary or DeviceListUpdates(),
            )
            if service_is_up:
                sent = await txn.send(self.as_api)
                if sent:
                    await txn.complete(self.store)
                else:
                    run_in_background(self._on_txn_fail, service)
        except Exception:
            logger.exception("Error creating appservice transaction")
            run_in_background(self._on_txn_fail, service)
Example #13
def respond_with_json(
    request: SynapseRequest,
    code: int,
    json_object: Any,
    send_cors: bool = False,
    canonical_json: bool = True,
) -> Optional[int]:
    """Sends encoded JSON in response to the given request.

    Args:
        request: The http request to respond to.
        code: The HTTP response code.
        json_object: The object to serialize to JSON.
        send_cors: Whether to send Cross-Origin Resource Sharing headers
            https://fetch.spec.whatwg.org/#http-cors-protocol
        canonical_json: Whether to use the canonicaljson algorithm when encoding
            the JSON bytes.

    Returns:
        twisted.web.server.NOT_DONE_YET if the request is still active.
    """
    # The response code must always be set, for logging purposes.
    request.setResponseCode(code)

    # could alternatively use request.notifyFinish() and flip a flag when
    # the Deferred fires, but since the flag is RIGHT THERE it seems like
    # a waste.
    if request._disconnected:
        logger.warning(
            "Not sending response to request %s, already disconnected.", request
        )
        return None

    if canonical_json:
        encoder = encode_canonical_json
    else:
        encoder = _encode_json_bytes

    request.setHeader(b"Content-Type", b"application/json")
    request.setHeader(b"Cache-Control", b"no-cache, no-store, must-revalidate")

    if send_cors:
        set_cors_headers(request)

    run_in_background(
        _async_write_json_to_request_in_thread, request, encoder, json_object
    )
    return NOT_DONE_YET
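Usage sketch: a Twisted resource returns this value straight from render_GET, and NOT_DONE_YET tells Twisted that the background writer will finish the response (the resource shape here is illustrative):

    def render_GET(self, request):
        # The JSON is encoded and written by the background task queued
        # via run_in_background above; NOT_DONE_YET keeps the connection
        # open until request.finish() is eventually called.
        return respond_with_json(request, 200, {"ok": True}, send_cors=True)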
Example #14
    @defer.inlineCallbacks
    def _check_for_unknown_devices(
        self,
        message_type: str,
        sender_user_id: str,
        by_device: Dict[str, Dict[str, Any]],
    ):
        """Checks inbound device messages for unkown remote devices, and if
        found marks the remote cache for the user as stale.
        """

        if message_type != "m.room_key_request":
            return

        # Get the sending device IDs
        requesting_device_ids = set()
        for message_content in by_device.values():
            device_id = message_content.get("requesting_device_id")
            requesting_device_ids.add(device_id)

        # Check if we are tracking the devices of the remote user.
        room_ids = yield self.store.get_rooms_for_user(sender_user_id)
        if not room_ids:
            logger.info(
                "Received device message from remote device we don't"
                " share a room with: %s %s",
                sender_user_id,
                requesting_device_ids,
            )
            return

        # If we are tracking check that we know about the sending
        # devices.
        cached_devices = yield self.store.get_cached_devices_for_user(
            sender_user_id)

        unknown_devices = requesting_device_ids - set(cached_devices)
        if unknown_devices:
            logger.info(
                "Received device message from remote device not in our cache: %s %s",
                sender_user_id,
                unknown_devices,
            )
            yield self.store.mark_remote_user_device_cache_as_stale(
                sender_user_id)

            # Immediately attempt a resync in the background
            run_in_background(self._device_list_updater.user_device_resync,
                              sender_user_id)
Example #15
    async def get(self) -> TV:
        """Kick off the call if necessary, and return the result"""

        # Fire off the callable now if this is our first time
        if not self._deferred:
            self._deferred = run_in_background(self._callable)

            # we will never need the callable again, so make sure it can be GCed
            self._callable = None

            # once the deferred completes, store the result. We cannot simply leave the
            # result in the deferred, since if it's a Failure, GCing the deferred
            # would then log a critical error about unhandled Failures.
            def got_result(r):
                self._result = r

            self._deferred.addBoth(got_result)

        # TODO: consider cancellation semantics. Currently, if the call to get()
        #    is cancelled, the underlying call will continue (and any future calls
        #    will get the result/exception), which I think is *probably* ok, modulo
        #    the fact the underlying call may be logged to a cancelled logcontext,
        #    and any eventual exception may not be reported.

        # we can now await the deferred, and once it completes, return the result.
        await make_deferred_yieldable(self._deferred)

        # I *think* this is the easiest way to correctly raise a Failure without having
        # to gut-wrench into the implementation of Deferred.
        d = Deferred()
        d.callback(self._result)
        return await d
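This getter belongs to a lazy single-flight wrapper (Synapse's CachedCall): the wrapped callable runs at most once, and every later get() re-delivers the stored result or failure. A hedged usage sketch (fetch_config is illustrative):

    cached = CachedCall(fetch_config)  # fetch_config: async () -> TV

    value1 = await cached.get()  # first call fires fetch_config
    value2 = await cached.get()  # later calls reuse the stored result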
Example #16
    def fetch_or_execute(self, txn_key, fn, *args, **kwargs):
        """Fetches the response for this transaction, or executes the given function
        to produce a response for this transaction.

        Args:
            txn_key (str): A key to ensure idempotency should fetch_or_execute be
                called again at a later point in time.
            fn (function): A function which returns a tuple of
                (response_code, response_dict).
            *args: Arguments to pass to fn.
            **kwargs: Keyword arguments to pass to fn.
        Returns:
            Deferred which resolves to a tuple of (response_code, response_dict).
        """
        if txn_key in self.transactions:
            observable = self.transactions[txn_key][0]
        else:
            # execute the function instead.
            deferred = run_in_background(fn, *args, **kwargs)

            observable = ObservableDeferred(deferred)
            self.transactions[txn_key] = (observable, self.clock.time_msec())

            # if the request fails with an exception, remove it
            # from the transaction map. This is done to ensure that we don't
            # cache transient errors like rate-limiting errors, etc.
            def remove_from_map(err):
                self.transactions.pop(txn_key, None)
                # we deliberately do not propagate the error any further, as we
                # expect the observers to have reported it.

            deferred.addErrback(remove_from_map)

        return make_deferred_yieldable(observable.observe())
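The ObservableDeferred is what makes retried requests idempotent: every caller presenting the same txn_key observes the single underlying execution. A small sketch (names are illustrative):

    # Two racing requests with the same transaction key:
    d1 = txn_cache.fetch_or_execute("txn-abc", handle_request, request)
    d2 = txn_cache.fetch_or_execute("txn-abc", handle_request, request)

    # handle_request runs exactly once; both deferreds resolve to the same
    # (response_code, response_dict) tuple. A failure evicts the cache
    # entry, so a later retry executes afresh.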
Example #17
    def fire(self, *args: P.args,
             **kwargs: P.kwargs) -> "defer.Deferred[List[Any]]":
        """Invokes every callable in the observer list, passing in the args and
        kwargs. Exceptions thrown by observers are logged but ignored. It is
        not an error to fire a signal with no observers.

        Returns a Deferred that will complete when all the observers have
        completed."""
        async def do(
                observer: Callable[P, Union[R, Awaitable[R]]]) -> Optional[R]:
            try:
                return await maybe_awaitable(observer(*args, **kwargs))
            except Exception as e:
                logger.warning(
                    "%s signal observer %s failed: %r",
                    self.name,
                    observer,
                    e,
                )
                return None

        deferreds = [run_in_background(do, o) for o in self.observers]

        return make_deferred_yieldable(
            defer.gatherResults(deferreds, consumeErrors=True))
Example #18
def concurrently_execute(func: Callable[[T], Any], args: Iterable[T],
                         limit: int) -> defer.Deferred:
    """Executes the function with each argument concurrently while limiting
    the number of concurrent executions.

    Args:
        func: Function to execute, should return a deferred or coroutine.
        args: List of arguments to pass to func, each invocation of func
            gets a single argument.
        limit: Maximum number of concurrent executions.

    Returns:
        Deferred: Resolved when all function invocations have finished.
    """
    it = iter(args)

    async def _concurrently_execute_inner(value: T) -> None:
        try:
            while True:
                await maybe_awaitable(func(value))
                value = next(it)
        except StopIteration:
            pass

    # We use `itertools.islice` to handle the case where the number of args is
    # less than the limit, avoiding needlessly spawning unnecessary background
    # tasks.
    return make_deferred_yieldable(
        defer.gatherResults(
            [
                run_in_background(_concurrently_execute_inner, value)
                for value in itertools.islice(it, limit)
            ],
            consumeErrors=True,
        )).addErrback(unwrapFirstError)
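Each background worker loops, pulling the next argument from the shared iterator as soon as it finishes the current one, so at most `limit` invocations are in flight at any moment. Usage sketch (fetch_room_state and room_ids are illustrative):

    # Process every room, but never more than five at a time.
    await concurrently_execute(fetch_room_state, room_ids, limit=5)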
Example #19
    def _ensure_fetched(self, store):
        if not self._fetching_state_deferred:
            self._fetching_state_deferred = run_in_background(
                self._fill_out_state, store
            )

        return make_deferred_yieldable(self._fetching_state_deferred)
Example #20
    def process_replication_rows(self, stream_name, token, rows):
        # The federation stream contains things that we want to send out, e.g.
        # presence, typing, etc.
        if stream_name == "federation":
            send_queue.process_rows_for_federation(self.federation_sender,
                                                   rows)
            run_in_background(self.update_token, token)

        # We also need to poke the federation sender when new events happen
        elif stream_name == "events":
            self.federation_sender.notify_new_events(token)

        # ... and when new receipts happen
        elif stream_name == ReceiptsStream.NAME:
            run_as_background_process("process_receipts_for_federation",
                                      self._on_new_receipts, rows)
Example #21
    async def _async_render_GET(self, request: SynapseRequest) -> None:
        # XXX: if get_user_by_req fails, what should we do in an async render?
        requester = await self.auth.get_user_by_req(request)
        url = parse_string(request, "url", required=True)
        ts = parse_integer(request, "ts")
        if ts is None:
            ts = self.clock.time_msec()

        # XXX: we could move this into _do_preview if we wanted.
        url_tuple = urlparse.urlsplit(url)
        for entry in self.url_preview_url_blacklist:
            match = True
            for attrib in entry:
                pattern = entry[attrib]
                value = getattr(url_tuple, attrib)
                logger.debug(
                    "Matching attrib '%s' with value '%s' against pattern '%s'",
                    attrib,
                    value,
                    pattern,
                )

                if value is None:
                    match = False
                    continue

                if pattern.startswith("^"):
                    if not re.match(pattern, getattr(url_tuple, attrib)):
                        match = False
                        continue
                else:
                    if not fnmatch.fnmatch(getattr(url_tuple, attrib),
                                           pattern):
                        match = False
                        continue
            if match:
                logger.warning("URL %s blocked by url_blacklist entry %s", url,
                               entry)
                raise SynapseError(
                    403, "URL blocked by url pattern blacklist entry",
                    Codes.UNKNOWN)

        # the in-memory cache:
        # * ensures that only one request is active at a time
        # * takes load off the DB for the thundering herds
        # * also caches any failures (unlike the DB) so we don't keep
        #    requesting the same endpoint

        observable = self._cache.get(url)

        if not observable:
            download = run_in_background(self._do_preview, url, requester.user,
                                         ts)
            observable = ObservableDeferred(download, consumeErrors=True)
            self._cache[url] = observable
        else:
            logger.info("Returning cached response")

        og = await make_deferred_yieldable(observable.observe())
        respond_with_json_bytes(request, 200, og, send_cors=True)
Example #22
def concurrently_execute(func, args, limit):
    """Executes the function with each argument conncurrently while limiting
    the number of concurrent executions.

    Args:
        func (func): Function to execute, should return a deferred.
        args (list): List of arguments to pass to func, each invocation of func
            gets a signle argument.
        limit (int): Maximum number of conccurent executions.

    Returns:
        deferred: Resolved when all function invocations have finished.
    """
    it = iter(args)

    @defer.inlineCallbacks
    def _concurrently_execute_inner():
        try:
            while True:
                yield func(next(it))
        except StopIteration:
            pass

    return make_deferred_yieldable(
        defer.gatherResults(
            [
                run_in_background(_concurrently_execute_inner)
                for _ in range(limit)
            ],
            consumeErrors=True,
        )).addErrback(unwrapFirstError)
Example #23
    async def get_keys(self, keys_to_fetch):
        """see KeyFetcher.get_keys"""
        async def get_key(key_server):
            try:
                result = await self.get_server_verify_key_v2_indirect(
                    keys_to_fetch, key_server)
                return result
            except KeyLookupError as e:
                logger.warning("Key lookup failed from %r: %s",
                               key_server.server_name, e)
            except Exception as e:
                logger.exception(
                    "Unable to get key from %r: %s %s",
                    key_server.server_name,
                    type(e).__name__,
                    str(e),
                )

            return {}

        results = await make_deferred_yieldable(
            defer.gatherResults(
                [
                    run_in_background(get_key, server)
                    for server in self.key_servers
                ],
                consumeErrors=True,
            ).addErrback(unwrapFirstError))

        union_of_keys = {}
        for result in results:
            for server_name, keys in result.items():
                union_of_keys.setdefault(server_name, {}).update(keys)

        return union_of_keys
Example #24
    def fire(self, *args, **kwargs):
        """Invokes every callable in the observer list, passing in the args and
        kwargs. Exceptions thrown by observers are logged but ignored. It is
        not an error to fire a signal with no observers.

        Returns a Deferred that will complete when all the observers have
        completed."""
        def do(observer):
            def eb(failure):
                logger.warning(
                    "%s signal observer %s failed: %r",
                    self.name,
                    observer,
                    failure,
                    exc_info=(
                        failure.type,
                        failure.value,
                        failure.getTracebackObject(),
                    ),
                )

            return maybeAwaitableDeferred(observer, *args,
                                          **kwargs).addErrback(eb)

        deferreds = [run_in_background(do, o) for o in self.observers]

        return make_deferred_yieldable(
            defer.gatherResults(deferreds, consumeErrors=True))
Example #25
    def fire(self, *args, **kwargs):
        """Invokes every callable in the observer list, passing in the args and
        kwargs. Exceptions thrown by observers are logged but ignored. It is
        not an error to fire a signal with no observers.

        Returns a Deferred that will complete when all the observers have
        completed."""
        async def do(observer):
            try:
                result = observer(*args, **kwargs)
                if inspect.isawaitable(result):
                    result = await result
                return result
            except Exception as e:
                logger.warning(
                    "%s signal observer %s failed: %r",
                    self.name,
                    observer,
                    e,
                )

        deferreds = [run_in_background(do, o) for o in self.observers]

        return make_deferred_yieldable(
            defer.gatherResults(deferreds, consumeErrors=True))
Example #26
    def verify_json_objects_for_server(
        self, server_and_json: Iterable[Tuple[str, dict, int]]
    ) -> List[defer.Deferred]:
        """Bulk verifies signatures of json objects, bulk fetching keys as
        necessary.

        Args:
            server_and_json:
                Iterable of (server_name, json_object, validity_time)
                tuples.

                validity_time is a timestamp at which the signing key must be
                valid.

        Returns:
            List<Deferred[None]>: for each input triplet, a deferred indicating success
                or failure to verify each json object's signature for the given
                server_name. The deferreds run their callbacks in the sentinel
                logcontext.
        """
        return [
            run_in_background(
                self.process_request,
                VerifyJsonRequest.from_json_object(
                    server_name,
                    json_object,
                    validity_time,
                ),
            ) for server_name, json_object, validity_time in server_and_json
        ]
Example #27
def concurrently_execute(func: Callable, args: Iterable[Any],
                         limit: int) -> defer.Deferred:
    """Executes the function with each argument concurrently while limiting
    the number of concurrent executions.

    Args:
        func: Function to execute, should return a deferred or coroutine.
        args: List of arguments to pass to func, each invocation of func
            gets a single argument.
        limit: Maximum number of concurrent executions.

    Returns:
        Deferred[list]: Resolved when all function invocations have finished.
    """
    it = iter(args)

    async def _concurrently_execute_inner():
        try:
            while True:
                await maybe_awaitable(func(next(it)))
        except StopIteration:
            pass

    return make_deferred_yieldable(
        defer.gatherResults(
            [
                run_in_background(_concurrently_execute_inner)
                for _ in range(limit)
            ],
            consumeErrors=True,
        )).addErrback(unwrapFirstError)
Example #28
    @defer.inlineCallbacks
    def get_events_from_store_or_dest(self, destination, room_id, event_ids):
        """Fetch events from a remote destination, checking if we already have them.

        Args:
            destination (str)
            room_id (str)
            event_ids (list)

        Returns:
            Deferred: A deferred resolving to a 2-tuple where the first is a list of
            events and the second is a list of event ids that we failed to fetch.
        """
        seen_events = yield self.store.get_events(event_ids, allow_rejected=True)
        signed_events = list(seen_events.values())

        failed_to_fetch = set()

        missing_events = set(event_ids)
        for k in seen_events:
            missing_events.discard(k)

        if not missing_events:
            return signed_events, failed_to_fetch

        logger.debug(
            "Fetching unknown state/auth events %s for room %s",
            missing_events,
            room_id,
        )

        room_version = yield self.store.get_room_version(room_id)

        batch_size = 20
        missing_events = list(missing_events)
        for i in range(0, len(missing_events), batch_size):
            batch = set(missing_events[i : i + batch_size])

            deferreds = [
                run_in_background(
                    self.get_pdu,
                    destinations=[destination],
                    event_id=e_id,
                    room_version=room_version,
                )
                for e_id in batch
            ]

            res = yield make_deferred_yieldable(
                defer.DeferredList(deferreds, consumeErrors=True)
            )
            for success, result in res:
                if success and result:
                    signed_events.append(result)
                    batch.discard(result.event_id)

            # We removed all events we successfully fetched from `batch`
            failed_to_fetch.update(batch)

        return signed_events, failed_to_fetch
Example #29
    async def _get_total_count_to_port(self, table: str, forward_chunk: int,
                                       backward_chunk: int) -> Tuple[int, int]:
        remaining, done = await make_deferred_yieldable(
            defer.gatherResults([
                run_in_background(
                    self._get_remaining_count_to_port,
                    table,
                    forward_chunk,
                    backward_chunk,
                ),
                run_in_background(self._get_already_ported_count, table),
            ], ))

        remaining = int(remaining) if remaining else 0
        done = int(done) if done else 0

        return done, remaining + done
Example #30
    def test_backfill_with_many_backward_extremities(self):
        """
        Check that we can backfill with many backward extremities.
        The goal is to make sure that when we only use a portion
        of backward extremities (the magic number is more than 5),
        no errors are thrown.

        Regression test, see #11027
        """
        # create the room
        user_id = self.register_user("kermit", "test")
        tok = self.login("kermit", "test")
        requester = create_requester(user_id)

        room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)

        ev1 = self.helper.send(room_id, "first message", tok=tok)

        # Create "many" backward extremities. The magic number we're trying to
        # create more than is 5 which corresponds to the number of backward
        # extremities we slice off in `_maybe_backfill_inner`
        for _ in range(0, 8):
            event_handler = self.hs.get_event_creation_handler()
            event, context = self.get_success(
                event_handler.create_event(
                    requester,
                    {
                        "type": "m.room.message",
                        "content": {
                            "msgtype": "m.text",
                            "body": "message connected to fake event",
                        },
                        "room_id": room_id,
                        "sender": user_id,
                    },
                    prev_event_ids=[
                        ev1["event_id"],
                        # We're creating a backward extremity each time thanks
                        # to this fake event
                        generate_fake_event_id(),
                    ],
                )
            )
            self.get_success(
                event_handler.handle_new_client_event(requester, event, context)
            )

        current_depth = 1
        limit = 100
        with LoggingContext("receive_pdu"):
            # Make sure backfill still works
            d = run_in_background(
                self.hs.get_federation_handler().maybe_backfill,
                room_id,
                current_depth,
                limit,
            )
        self.get_success(d)
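Note the LoggingContext wrapper: in Synapse unit tests, run_in_background is called inside an explicit context, and the resulting Deferred is then driven to completion with get_success. The bare idiom, as a sketch (some_async_function and arg are illustrative):

    with LoggingContext("test"):
        d = run_in_background(some_async_function, arg)
    self.get_success(d)  # pumps the reactor until d resolves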