def _do_get_well_known(self, server_name):
        """Actually fetch and parse a .well-known, without checking the cache

        Any failure (request error, non-200 response, malformed JSON, missing
        or unusable "m.server" value) is logged and reported as "no result",
        with a cache period of WELL_KNOWN_INVALID_CACHE_PERIOD plus jitter.

        Args:
            server_name (bytes): name of the server, from the requested url

        Returns:
            Deferred[Tuple[bytes|None, int|float]]:
                (result, cache period), where result is one of:
                 - the new server name from the .well-known (as a `bytes`)
                 - None if the .well-known could not be fetched or was invalid
        """
        uri = b"https://%s/.well-known/matrix/server" % (server_name, )
        uri_str = uri.decode("ascii")
        logger.info("Fetching %s", uri_str)
        try:
            response = yield make_deferred_yieldable(
                self._well_known_agent.request(b"GET", uri), )
            body = yield make_deferred_yieldable(readBody(response))
            if response.code != 200:
                raise Exception("Non-200 response %s" % (response.code, ))

            parsed_body = json.loads(body.decode('utf-8'))
            logger.info("Response from .well-known: %s", parsed_body)
            if not isinstance(parsed_body, dict):
                raise Exception("not a dict")
            if "m.server" not in parsed_body:
                raise Exception("Missing key 'm.server'")

            # The value comes from the remote server, so it may be a
            # non-string or contain non-ASCII characters. Convert it inside
            # the try block so that such a response is handled like any other
            # invalid .well-known rather than raising out of this function.
            result = parsed_body["m.server"].encode("ascii")
        except Exception as e:
            logger.info("Error fetching %s: %s", uri_str, e)

            # add some randomness to the TTL to avoid a stampeding herd every hour
            # after startup
            cache_period = WELL_KNOWN_INVALID_CACHE_PERIOD
            cache_period += random.uniform(
                0, WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER)
            defer.returnValue((None, cache_period))

        # Only reached on success: defer.returnValue above raises to leave the
        # generator.
        cache_period = _cache_period_from_headers(
            response.headers,
            time_now=self._reactor.seconds,
        )
        if cache_period is None:
            cache_period = WELL_KNOWN_DEFAULT_CACHE_PERIOD
            # add some randomness to the TTL to avoid a stampeding herd every 24 hours
            # after startup
            cache_period += random.uniform(
                0, WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER)
        else:
            # never trust the response headers beyond our own upper bound
            cache_period = min(cache_period, WELL_KNOWN_MAX_CACHE_PERIOD)

        defer.returnValue((result, cache_period))
    def _on_new_receipts(self, min_stream_id, max_stream_id,
                         affected_room_ids):
        """Notify the pushers of every user who has new receipts.

        Looks up the receipts updated in the given stream range, and calls
        ``on_new_receipts`` in the background on each pusher registered in
        ``self.pushers`` for the affected users. Any exception is logged and
        swallowed.

        Args:
            min_stream_id: lower end of the stream range
            max_stream_id: upper end of the stream range
            affected_room_ids: not used by this method
        """
        try:
            # The lower bound of the store query is exclusive, hence the -1.
            receipt_rows = yield self.store.get_all_updated_receipts(
                min_stream_id - 1, max_stream_id)

            # Each row is a tuple with the user_id at index 3.
            user_ids = {row[3] for row in receipt_rows}

            pending = [
                run_in_background(
                    pusher.on_new_receipts,
                    min_stream_id,
                    max_stream_id,
                )
                for user_id in user_ids
                if user_id in self.pushers
                for pusher in self.pushers[user_id].values()
            ]

            gathered = defer.gatherResults(pending, consumeErrors=True)
            yield make_deferred_yieldable(gathered)
        except Exception:
            logger.exception("Exception in pusher on_new_receipts")
Пример #3
0
    def get_keys_from_perspectives(self, server_name_and_key_ids):
        """Fetch keys via every server in ``self.perspective_servers``.

        Each perspective server is queried in the background through
        ``get_server_verify_key_v2_indirect``. A perspective whose lookup
        raises is logged and contributes an empty result.

        Args:
            server_name_and_key_ids: passed through unchanged to
                ``get_server_verify_key_v2_indirect`` (presumably the
                (server_name, key_ids) pairs to look up; confirm against
                that method).

        Returns:
            Deferred[dict]: map from server_name to a dict of keys, formed by
            merging the per-server dicts returned for each perspective.
        """
        @defer.inlineCallbacks
        def get_key(perspective_name, perspective_keys):
            # Query a single perspective server, turning any failure into an
            # empty result.
            try:
                result = yield self.get_server_verify_key_v2_indirect(
                    server_name_and_key_ids, perspective_name, perspective_keys
                )
                defer.returnValue(result)
            except KeyLookupError as e:
                logger.warning(
                    "Key lookup failed from %r: %s", perspective_name, e,
                )
            except Exception as e:
                logger.exception(
                    "Unable to get key from %r: %s %s",
                    perspective_name,
                    type(e).__name__, str(e),
                )

            # Only reached when one of the handlers above ran.
            defer.returnValue({})

        # Query all perspectives concurrently and wait for every result.
        results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
            [
                run_in_background(get_key, p_name, p_keys)
                for p_name, p_keys in self.perspective_servers.items()
            ],
            consumeErrors=True,
        ).addErrback(unwrapFirstError))

        # Merge the per-perspective results; for a given server, entries from
        # later results overwrite same-named entries from earlier ones.
        union_of_keys = {}
        for result in results:
            for server_name, keys in result.items():
                union_of_keys.setdefault(server_name, {}).update(keys)

        defer.returnValue(union_of_keys)
Пример #4
0
    def get_raw(self, uri, args={}, headers=None):
        """Fetch the raw body of the given URI with a GET request.

        Args:
            uri (str): The URI to request, not including query parameters
            args (dict): Query parameters to append to ``uri``. Defaults to an
                empty dict, in which case no query string is added.
                **Note**: The value of each key is assumed to be an iterable
                and *not* a string.
            headers (dict[str, List[str]]|None): If not None, a map from
               header name to a list of values for that header
        Returns:
            Deferred: Succeeds with the HTTP body when we get *any* 2xx HTTP
            response.
        Raises:
            HttpResponseException on a non-2xx HTTP response.
        """
        if len(args) > 0:
            uri = "%s?%s" % (uri, urllib.parse.urlencode(args, True))

        # Caller-supplied headers take precedence over our defaults.
        request_headers = {b"User-Agent": [self.user_agent]}
        if headers:
            request_headers.update(headers)

        response = yield self.request(
            "GET", uri, headers=Headers(request_headers),
        )

        # The body is read in full before the status is checked, so that it
        # can be attached to the exception on failure.
        body = yield make_deferred_yieldable(readBody(response))

        if not (200 <= response.code < 300):
            raise HttpResponseException(response.code, response.phrase, body)

        defer.returnValue(body)
Пример #5
0
    def request(self, method, uri, data=b'', headers=None):
        """Send an HTTP request through treq, with metrics and logging.

        A small wrapper around treq.request() so we can easily attach
        counters to it. The request deferred is wrapped with
        ``timeout_deferred`` using a timeout value of 60.

        Args:
            method: HTTP method; also used as the metrics label
            uri: URI to request; logged after passing through ``redact_uri``
            data: request body handed to treq
            headers: request headers handed to treq

        Returns:
            Deferred: resolves to the response object.

        Raises:
            Whatever the request raised; the error is counted and logged
            before being re-raised.
        """
        outgoing_requests_counter.labels(method).inc()

        # log request but strip `access_token` (AS requests for example include this)
        logger.info("Sending request %s %s", method, redact_uri(uri))

        try:
            request_deferred = treq.request(method,
                                            uri,
                                            agent=self.agent,
                                            data=data,
                                            headers=headers)
            request_deferred = timeout_deferred(
                request_deferred,
                60,
                self.hs.get_reactor(),
                cancelled_to_request_timed_out_error,
            )
            response = yield make_deferred_yieldable(request_deferred)

            incoming_responses_counter.labels(method, response.code).inc()
            logger.info("Received response to  %s %s: %s", method,
                        redact_uri(uri), response.code)
            defer.returnValue(response)
        except Exception as e:
            incoming_responses_counter.labels(method, "ERR").inc()
            # Exceptions may be raised without any arguments; indexing
            # e.args[0] unconditionally would raise an IndexError here and
            # hide the real failure.
            logger.info("Error sending request to  %s %s: %s %s", method,
                        redact_uri(uri),
                        type(e).__name__, e.args[0] if e.args else "")
            raise
Пример #6
0
def concurrently_execute(func, args, limit):
    """Run ``func`` once per argument, limiting how many run concurrently.

    Args:
        func (func): Function to execute, should return a deferred.
        args (list): List of arguments to pass to func; each invocation of
            func gets a single argument.
        limit (int): Maximum number of concurrent executions.

    Returns:
        deferred: Resolved when all function invocations have finished.
    """
    # One iterator shared by all workers: each argument is taken by exactly
    # one of them.
    pending_args = iter(args)

    @defer.inlineCallbacks
    def _worker():
        try:
            while True:
                next_arg = next(pending_args)
                yield func(next_arg)
        except StopIteration:
            pass

    workers = [run_in_background(_worker) for _ in range(limit)]
    gathered = defer.gatherResults(workers, consumeErrors=True)
    return logcontext.make_deferred_yieldable(gathered).addErrback(
        unwrapFirstError
    )
Пример #7
0
def concurrently_execute(func, args, limit):
    """Executes the function with each argument concurrently while limiting
    the number of concurrent executions.

    Args:
        func (func): Function to execute, should return a deferred.
        args (list): List of arguments to pass to func, each invocation of func
            gets a single argument.
        limit (int): Maximum number of concurrent executions.

    Returns:
        deferred: Resolved when all function invocations have finished.
    """
    # A single iterator shared by every worker below, so each argument is
    # consumed by exactly one worker.
    it = iter(args)

    @defer.inlineCallbacks
    def _concurrently_execute_inner():
        # Keep pulling arguments until the shared iterator is exhausted.
        try:
            while True:
                yield func(next(it))
        except StopIteration:
            pass

    # Start `limit` workers and gather them; unwrapFirstError is attached as
    # an errback to the deferred returned by make_deferred_yieldable.
    return logcontext.make_deferred_yieldable(
        defer.gatherResults([
            run_in_background(_concurrently_execute_inner)
            for _ in range(limit)
        ],
                            consumeErrors=True)).addErrback(unwrapFirstError)
Пример #8
0
        def first_lookup():
            """Submit two objects for verification inside logcontext "11".

            The second object is an empty dict and is expected to have been
            rejected with a SynapseError already; the first is expected to be
            still pending, and is then waited for.
            """
            with LoggingContext("11") as context_11:
                context_11.request = "11"

                res_deferreds = kr.verify_json_objects_for_server([
                    ("server10", json1, 0, "test10"),
                    ("server11", {}, 0, "test11")
                ])

                # the unsigned json should be rejected pretty quickly
                self.assertTrue(res_deferreds[1].called)
                try:
                    yield res_deferreds[1]
                    # a non-empty string is truthy, so this assertion always
                    # fails if we get here without an exception
                    self.assertFalse("unsigned json didn't cause a failure")
                except SynapseError:
                    pass

                # the first lookup must still be in flight at this point
                self.assertFalse(res_deferreds[0].called)
                res_deferreds[0].addBoth(self.check_context, None)

                yield logcontext.make_deferred_yieldable(res_deferreds[0])

                # let verify_json_objects_for_server finish its work before we kill the
                # logcontext
                yield self.clock.sleep(0)
Пример #9
0
    def post_urlencoded_get_json(self, uri, args={}, headers=None):
        """POST ``args`` as a form-urlencoded body and parse the JSON reply.

        Args:
            uri (str): the URI to POST to
            args (dict[str, str|List[str]]): form fields, url-encoded into
                the request body
            headers (dict[str, List[str]]|None): If not None, a map from
               header name to a list of values for that header

        Returns:
            Deferred[object]: parsed json
        """

        # TODO: Do we ever want to log message contents?
        logger.debug("post_urlencoded_get_json args: %s", args)

        encoded_form = urllib.urlencode(encode_urlencode_args(args), True)

        # Caller-supplied headers take precedence over our defaults.
        request_headers = {
            b"Content-Type": [b"application/x-www-form-urlencoded"],
            b"User-Agent": [self.user_agent],
        }
        if headers:
            request_headers.update(headers)

        target = uri.encode("ascii")
        header_obj = Headers(request_headers)
        producer = FileBodyProducer(StringIO(encoded_form))

        response = yield self.request(
            "POST",
            target,
            headers=header_obj,
            bodyProducer=producer,
        )

        raw_reply = yield make_deferred_yieldable(readBody(response))

        defer.returnValue(json.loads(raw_reply))
Пример #10
0
        def handle_check_result(pdu, deferred):
            """Resolve the check result for ``pdu``, with fallbacks.

            Waits for ``deferred``; if that fails with a SynapseError or
            yields nothing, tries the local store, and then (when the pdu's
            origin differs from the enclosing scope's ``origin``) fetches the
            event from ``pdu.origin``.

            Args:
                pdu: the event being checked; ``event_id`` and ``origin`` are
                    read from it
                deferred: deferred for the check of ``pdu``

            Returns:
                Deferred: resolves to the event found, or a falsy value if
                every attempt came up empty.
            """
            try:
                res = yield logcontext.make_deferred_yieldable(deferred)
            except SynapseError:
                res = None

            if not res:
                # Check local db.
                res = yield self.store.get_event(
                    pdu.event_id,
                    allow_rejected=True,
                    allow_none=True,
                )

            if not res and pdu.origin != origin:
                # Ask the server the event claims to come from.
                try:
                    res = yield self.get_pdu(
                        destinations=[pdu.origin],
                        event_id=pdu.event_id,
                        outlier=outlier,
                        timeout=10000,
                    )
                except SynapseError:
                    pass

            if not res:
                # Logger.warn is a deprecated alias; use Logger.warning.
                logger.warning(
                    "Failed to find copy of %s with valid signature",
                    pdu.event_id,
                )

            defer.returnValue(res)
Пример #11
0
    def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids):
        """Tell the pushers of all users with new receipts about them.

        After yielding to ``run_on_reactor()``, fetches the receipts updated
        in the given stream range and calls ``on_new_receipts`` in the
        background on every pusher in ``self.pushers`` belonging to an
        affected user. Exceptions raised after that first yield are logged
        and swallowed.

        Args:
            min_stream_id: lower end of the stream range
            max_stream_id: upper end of the stream range
            affected_room_ids: not used by this method
        """
        yield run_on_reactor()
        try:
            # The store treats the lower bound as exclusive, so step back by
            # one to include min_stream_id itself.
            rows = yield self.store.get_all_updated_receipts(
                min_stream_id - 1, max_stream_id
            )
            # Rows are tuples; the user_id lives at index 3.
            affected_users = {row[3] for row in rows}

            in_flight = []
            for user_id in affected_users:
                if user_id not in self.pushers:
                    continue
                in_flight.extend(
                    run_in_background(
                        pusher.on_new_receipts,
                        min_stream_id, max_stream_id,
                    )
                    for pusher in self.pushers[user_id].values()
                )

            gathered = defer.gatherResults(in_flight, consumeErrors=True)
            yield make_deferred_yieldable(gathered)
        except Exception:
            logger.exception("Exception in pusher on_new_receipts")
Пример #12
0
    def copy_to_backup(self, path):
        """Copy a file from the primary to the backup media store.

        Does nothing when ``self.backup_base_path`` is not set.

        Args:
            path(str): Relative path of the file, under both base paths
        """
        if not self.backup_base_path:
            return

        src = os.path.join(self.primary_base_path, path)
        dst = os.path.join(self.backup_base_path, path)

        if not self.synchronous_backup_media_store:
            # Start the copy in the background and return immediately.
            preserve_fn(threads.deferToThread)(shutil.copyfile, src, dst)
            return

        # Wait for the copy to the backup repository to finish.
        copy_deferred = threads.deferToThread(shutil.copyfile, src, dst)
        yield make_deferred_yieldable(copy_deferred)
Пример #13
0
    def write_to_file_and_backup(self, source, path):
        """Store `source` in the primary media store, then copy it to the
        backup store via ``copy_to_backup``.

        Args:
            source: A file like object that should be written
            path (str): Relative path to write file to

        Returns:
            Deferred[str]: the file path written to in the primary media store
        """
        primary_path = os.path.join(self.primary_base_path, path)

        # The write itself happens on a thread.
        write_deferred = threads.deferToThread(
            self._write_file_synchronously, source, primary_path,
        )
        yield make_deferred_yieldable(write_deferred)

        yield self.copy_to_backup(path)

        defer.returnValue(primary_path)
Пример #14
0
    def generate_remote_exact_thumbnail(self, server_name, file_id, media_id,
                                        t_width, t_height, t_method, t_type):
        """Generate, store and record one thumbnail of a remote media file.

        The thumbnail is generated on a thread, written via
        ``write_to_file_and_backup`` and then recorded with
        ``store_remote_media_thumbnail``.

        Returns:
            Deferred: resolves to the path of the stored thumbnail, or None
            if thumbnail generation produced nothing.
        """
        source_path = self.filepaths.remote_media_filepath(server_name, file_id)
        thumbnailer = Thumbnailer(source_path)

        t_byte_source = yield make_deferred_yieldable(
            threads.deferToThread(
                self._generate_thumbnail,
                thumbnailer, t_width, t_height, t_method, t_type,
            )
        )

        if not t_byte_source:
            return

        try:
            rel_path = self.filepaths.remote_media_thumbnail_rel(
                server_name, file_id, t_width, t_height, t_type, t_method,
            )
            output_path = yield self.write_to_file_and_backup(
                t_byte_source, rel_path,
            )
        finally:
            # Always release the byte source, even if the write failed.
            t_byte_source.close()

        logger.info("Stored thumbnail in file %r", output_path)

        thumbnail_size = os.path.getsize(output_path)

        yield self.store.store_remote_media_thumbnail(
            server_name, media_id, file_id,
            t_width, t_height, t_type, t_method, thumbnail_size,
        )

        defer.returnValue(output_path)
Пример #15
0
    def fire(self, *args, **kwargs):
        """Call every observer in ``self.observers`` with the given args and
        kwargs. A failure in an observer is logged and then ignored. Firing
        a signal that has no observers is not an error.

        Returns a Deferred that will complete when all the observers have
        completed."""

        def _invoke(observer):
            def _log_failure(failure):
                logger.warning(
                    "%s signal observer %s failed: %r",
                    self.name, observer, failure,
                    exc_info=(
                        failure.type,
                        failure.value,
                        failure.getTracebackObject(),
                    ),
                )

            # maybeDeferred copes with observers that return plain values or
            # raise synchronously.
            d = defer.maybeDeferred(observer, *args, **kwargs)
            d.addErrback(_log_failure)
            return d

        pending = []
        for observer in self.observers:
            pending.append(run_in_background(_invoke, observer))

        gathered = defer.gatherResults(pending, consumeErrors=True)
        return make_deferred_yieldable(gathered)
Пример #16
0
    def get_raw(self, uri, args=None, headers=None):
        """ Gets raw text from the given URI.

        Args:
            uri (str): The URI to request, not including query parameters
            args (dict|None): A dictionary used to create query strings,
                defaults to None. When None or empty, no query string is
                appended.
                **Note**: The value of each key is assumed to be an iterable
                and *not* a string.
            headers (dict[str, List[str]]|None): If not None, a map from
               header name to a list of values for that header
        Returns:
            Deferred: Succeeds when we get *any* 2xx HTTP response, with the
            HTTP body at text.
        Raises:
            HttpResponseException on a non-2xx HTTP response.
        """
        # `args` defaults to None (as documented) rather than a shared
        # mutable dict; a truthiness check covers both None and empty.
        if args:
            query_bytes = urllib.parse.urlencode(args, True)
            uri = "%s?%s" % (uri, query_bytes)

        # Caller-supplied headers override the default User-Agent.
        actual_headers = {b"User-Agent": [self.user_agent]}
        if headers:
            actual_headers.update(headers)

        response = yield self.request("GET",
                                      uri,
                                      headers=Headers(actual_headers))

        # Read the body before checking the status so it can be attached to
        # the exception below.
        body = yield make_deferred_yieldable(readBody(response))

        if 200 <= response.code < 300:
            defer.returnValue(body)
        else:
            raise HttpResponseException(response.code, response.phrase, body)
Пример #17
0
    def get_keys_from_store(self, server_name_and_key_ids):
        """Look up verify keys for several servers in the store, concurrently.

        Args:
            server_name_and_key_ids (list[(str, iterable[str])]):
                list of (server_name, iterable[key_id]) tuples to fetch keys for

        Returns:
            Deferred: resolves to dict[str, dict[str, VerifyKey]]: map from
                server_name -> key_id -> VerifyKey
        """
        def _pair_with_server(keys, server):
            # Tag each store result with the server it belongs to.
            return (server, keys)

        lookups = []
        for server_name, key_ids in server_name_and_key_ids:
            lookup = run_in_background(
                self.store.get_server_verify_keys,
                server_name,
                key_ids,
            )
            lookup.addCallback(_pair_with_server, server_name)
            lookups.append(lookup)

        gathered = defer.gatherResults(lookups, consumeErrors=True)
        gathered.addErrback(unwrapFirstError)

        res = yield logcontext.make_deferred_yieldable(gathered)

        defer.returnValue(dict(res))
Пример #18
0
    def validate_hash(self, password, stored_hash):
        """Check a password against a stored bcrypt hash.

        The password is NFKC-normalised and the configured pepper appended
        before it is checked; the check runs on the reactor's thread pool.

        Args:
            password (unicode): Password to hash.
            stored_hash (unicode): Expected hash value.

        Returns:
            Deferred(bool): Whether self.hash(password) == stored_hash. An
            empty ``stored_hash`` never matches.
        """
        if not stored_hash:
            return defer.succeed(False)

        def _check_in_thread():
            # Normalise the Unicode in the password
            normalised = unicodedata.normalize("NFKC", password)
            peppered = (
                normalised.encode('utf8')
                + self.hs.config.password_pepper.encode("utf8")
            )
            return bcrypt.checkpw(peppered, stored_hash.encode('utf8'))

        return make_deferred_yieldable(
            threads.deferToThreadPool(
                self.hs.get_reactor(),
                self.hs.get_reactor().getThreadPool(),
                _check_in_thread,
            )
        )
Пример #19
0
    def get_events(self, destinations, room_id, event_ids, return_local=True):
        """Fetch events from some remote destinations, checking if we already
        have them.

        Events not already known locally are requested with ``get_pdu`` in
        batches of 20, each request using a freshly shuffled copy of
        ``destinations``.

        Args:
            destinations (list)
            room_id (str): not used by this method
            event_ids (list)
            return_local (bool): Whether to include events we already have in
                the DB in the returned list of events

        Returns:
            Deferred: A deferred resolving to a 2-tuple where the first is a list of
            events and the second is a set of event ids that we failed to fetch.
        """
        if return_local:
            seen_events = yield self.store.get_events(event_ids,
                                                      allow_rejected=True)
            # Copy into a real list: on Python 3 dict.values() is a view
            # with no append(), and we add fetched events to it below.
            signed_events = list(seen_events.values())
        else:
            seen_events = yield self.store.have_events(event_ids)
            signed_events = []

        failed_to_fetch = set()

        # Everything requested that we have not already seen.
        missing_events = set(event_ids)
        missing_events.difference_update(seen_events)

        if not missing_events:
            defer.returnValue((signed_events, failed_to_fetch))

        def random_server_list():
            # Shuffle a copy so that the caller's list is left untouched.
            srvs = list(destinations)
            random.shuffle(srvs)
            return srvs

        batch_size = 20
        missing_events = list(missing_events)
        # range() rather than xrange(): works on both Python 2 and 3.
        for i in range(0, len(missing_events), batch_size):
            batch = set(missing_events[i:i + batch_size])

            deferreds = [
                preserve_fn(self.get_pdu)(
                    destinations=random_server_list(),
                    event_id=e_id,
                ) for e_id in batch
            ]

            res = yield make_deferred_yieldable(
                defer.DeferredList(deferreds, consumeErrors=True))
            for success, result in res:
                if success and result:
                    signed_events.append(result)
                    batch.discard(result.event_id)

            # We removed all events we successfully fetched from `batch`
            failed_to_fetch.update(batch)

        defer.returnValue((signed_events, failed_to_fetch))
Пример #20
0
    def fetch_or_execute(self, txn_key, fn, *args, **kwargs):
        """Fetches the response for this transaction, or executes the given function
        to produce a response for this transaction.

        Args:
            txn_key (str): A key to ensure idempotency should fetch_or_execute be
                called again at a later point in time.
            fn (function): A function which returns a tuple of
                (response_code, response_dict).
            *args: Arguments to pass to fn.
            **kwargs: Keyword arguments to pass to fn.
        Returns:
            Deferred which resolves to a tuple of (response_code, response_dict).
        """
        if txn_key in self.transactions:
            # Already started (or finished) for this key: reuse the stored
            # observable rather than running fn again.
            observable = self.transactions[txn_key][0]
        else:
            # execute the function instead.
            deferred = run_in_background(fn, *args, **kwargs)

            # Store the observable together with the time it was created.
            observable = ObservableDeferred(deferred)
            self.transactions[txn_key] = (observable, self.clock.time_msec())

            # if the request fails with an exception, remove it
            # from the transaction map. This is done to ensure that we don't
            # cache transient errors like rate-limiting errors, etc.
            def remove_from_map(err):
                self.transactions.pop(txn_key, None)
                # we deliberately do not propagate the error any further, as we
                # expect the observers to have reported it.

            # Attached after the ObservableDeferred has been created.
            deferred.addErrback(remove_from_map)

        return make_deferred_yieldable(observable.observe())
    def get_file(self, destination, path, output_stream, args={},
                 retry_on_dns_fail=True, max_size=None,
                 ignore_backoff=False):
        """GETs a file from a given homeserver
        Args:
            destination (str): The remote server to send the HTTP request to.
            path (str): The HTTP path to GET.
            output_stream (file): File to write the response body to.
            args (dict): Optional dictionary used to create the query string.
            retry_on_dns_fail (bool): passed through to ``_send_request``.
            max_size (int|None): passed through to ``_readBodyToFile``
                (presumably the largest body we will accept; confirm there).
            ignore_backoff (bool): true to ignore the historical backoff data
                and try the request anyway.
        Returns:
            Deferred: resolves with an (int,dict) tuple of the file length and
            a dict of the response headers.

            Fails with ``HttpResponseException`` if we get an HTTP response code
            >= 300

            Fails with ``NotRetryingDestination`` if we are not yet ready
            to retry this server.

            Fails with ``FederationDeniedError`` if this destination
            is not on our federation whitelist
        """
        request = MatrixFederationRequest(
            method="GET",
            destination=destination,
            path=path,
            query=args,
        )

        response = yield self._send_request(
            request,
            retry_on_dns_fail=retry_on_dns_fail,
            ignore_backoff=ignore_backoff,
        )

        headers = dict(response.headers.getAllRawHeaders())

        try:
            # Reading the body is bounded by the default timeout.
            d = _readBodyToFile(response, output_stream, max_size)
            d.addTimeout(self.default_timeout, self.hs.get_reactor())
            length = yield make_deferred_yieldable(d)
        except Exception as e:
            # Logger.warn is a deprecated alias; use Logger.warning.
            logger.warning(
                "{%s} [%s] Error reading response: %s",
                request.txn_id,
                request.destination,
                e,
            )
            raise
        logger.info(
            "{%s} [%s] Completed: %d %s [%d bytes]",
            request.txn_id,
            request.destination,
            response.code,
            response.phrase.decode('ascii', errors='replace'),
            length,
        )
        defer.returnValue((length, headers))
Пример #22
0
        def handle_check_result(pdu, deferred):
            """Return a usable copy of ``pdu``, trying several sources in turn.

            The sources are, in order: the result of ``deferred``, the local
            store, and finally ``pdu.origin`` itself (only when it differs
            from the ``origin`` of the enclosing scope). A SynapseError from
            the first or last source is treated as "nothing found".

            Args:
                pdu: the event being checked; ``event_id`` and ``origin`` are
                    read from it
                deferred: deferred for the check of ``pdu``

            Returns:
                Deferred: resolves to the event found, or a falsy value if
                no source produced one.
            """
            try:
                res = yield logcontext.make_deferred_yieldable(deferred)
            except SynapseError:
                res = None

            if not res:
                # Check local db.
                res = yield self.store.get_event(
                    pdu.event_id,
                    allow_rejected=True,
                    allow_none=True,
                )

            if not res and pdu.origin != origin:
                # Last resort: ask the pdu's own origin server.
                try:
                    res = yield self.get_pdu(
                        destinations=[pdu.origin],
                        event_id=pdu.event_id,
                        outlier=outlier,
                        timeout=10000,
                    )
                except SynapseError:
                    pass

            if not res:
                # Logger.warn is a deprecated alias; use Logger.warning.
                logger.warning(
                    "Failed to find copy of %s with valid signature",
                    pdu.event_id,
                )

            defer.returnValue(res)
Пример #23
0
    def get_room_events_stream_for_rooms(self,
                                         room_ids,
                                         from_key,
                                         to_key,
                                         limit=0,
                                         order='DESC'):
        """Fetch the event stream for each of the given rooms that changed.

        Rooms are first filtered through
        ``self._events_stream_cache.get_entities_changed``; the remaining
        ones are queried with ``get_room_events_stream_for_room`` in chunks
        of 20 concurrent lookups.

        Returns:
            Deferred[dict]: map from room_id to the result of
            ``get_room_events_stream_for_room`` for that room; empty if no
            room changed.
        """
        from_id = RoomStreamToken.parse_stream_token(from_key).stream

        room_ids = yield self._events_stream_cache.get_entities_changed(
            room_ids, from_id)

        if not room_ids:
            defer.returnValue({})

        results = {}
        room_ids = list(room_ids)
        chunk_size = 20
        for start in range(0, len(room_ids), chunk_size):
            chunk = room_ids[start:start + chunk_size]

            fetches = [
                run_in_background(
                    self.get_room_events_stream_for_room,
                    room_id,
                    from_key,
                    to_key,
                    limit,
                    order=order,
                )
                for room_id in chunk
            ]
            chunk_results = yield make_deferred_yieldable(
                defer.gatherResults(fetches, consumeErrors=True))

            # gatherResults keeps the input order, so pair by position.
            for room_id, room_result in zip(chunk, chunk_results):
                results[room_id] = room_result

        defer.returnValue(results)
Пример #24
0
    def fetch_or_execute(self, txn_key, fn, *args, **kwargs):
        """Return the response for a transaction, running ``fn`` if this is
        the first time the transaction key has been seen.

        Args:
            txn_key (str): A key to ensure idempotency should fetch_or_execute
                be called again at a later point in time.
            fn (function): A function which returns a tuple of
                (response_code, response_dict).
            *args: Arguments to pass to fn.
            **kwargs: Keyword arguments to pass to fn.
        Returns:
            Deferred which resolves to a tuple of (response_code, response_dict).
        """
        if txn_key not in self.transactions:
            # First request for this key: run the function in the background.
            in_progress = run_in_background(fn, *args, **kwargs)

            observable = ObservableDeferred(in_progress)
            started_at = self.clock.time_msec()
            self.transactions[txn_key] = (observable, started_at)

            def _forget_failed_txn(err):
                # Failed requests are dropped from the map so that transient
                # errors (rate-limiting and the like) are not cached. The
                # error is deliberately not propagated any further, as we
                # expect the observers to have reported it.
                self.transactions.pop(txn_key, None)

            in_progress.addErrback(_forget_failed_txn)
        else:
            observable = self.transactions[txn_key][0]

        return make_deferred_yieldable(observable.observe())
Пример #25
0
    def get_file(self, destination, path, output_stream, args={},
                 retry_on_dns_fail=True, max_size=None,
                 ignore_backoff=False):
        """GETs a file from a given homeserver
        Args:
            destination (str): The remote server to send the HTTP request to.
            path (str): The HTTP path to GET.
            output_stream (file): File to write the response body to.
            args (dict): Optional dictionary used to create the query string.
            retry_on_dns_fail (bool): forwarded to ``_send_request``.
            max_size (int|None): forwarded to ``_readBodyToFile``
                (presumably a cap on the body size; confirm there).
            ignore_backoff (bool): true to ignore the historical backoff data
                and try the request anyway.
        Returns:
            Deferred: resolves with an (int,dict) tuple of the file length and
            a dict of the response headers.

            Fails with ``HttpResponseException`` if we get an HTTP response code
            >= 300

            Fails with ``NotRetryingDestination`` if we are not yet ready
            to retry this server.

            Fails with ``FederationDeniedError`` if this destination
            is not on our federation whitelist
        """
        request = MatrixFederationRequest(
            method="GET",
            destination=destination,
            path=path,
            query=args,
        )

        response = yield self._send_request(
            request,
            retry_on_dns_fail=retry_on_dns_fail,
            ignore_backoff=ignore_backoff,
        )

        # Captured up front so the headers can be returned with the length.
        headers = dict(response.headers.getAllRawHeaders())

        try:
            d = _readBodyToFile(response, output_stream, max_size)
            # Give up on the body if it takes longer than the default timeout.
            d.addTimeout(self.default_timeout, self.hs.get_reactor())
            length = yield make_deferred_yieldable(d)
        except Exception as e:
            # Logger.warn is a deprecated alias; use Logger.warning.
            logger.warning(
                "{%s} [%s] Error reading response: %s",
                request.txn_id,
                request.destination,
                e,
            )
            raise
        logger.info(
            "{%s} [%s] Completed: %d %s [%d bytes]",
            request.txn_id,
            request.destination,
            response.code,
            response.phrase.decode('ascii', errors='replace'),
            length,
        )
        defer.returnValue((length, headers))
Пример #26
0
    def validate_hash(self, password, stored_hash):
        """Check a password against a previously stored bcrypt hash.

        The comparison runs on the reactor's thread pool. An empty or missing
        `stored_hash` never matches.

        Args:
            password (str): Password to hash.
            stored_hash (str): Expected hash value.

        Returns:
            Deferred(bool): Whether self.hash(password) == stored_hash.
        """
        # Nothing to compare against: answer straight away.
        if not stored_hash:
            return defer.succeed(False)

        def _check_against_stored():
            # The configured pepper is appended to the encoded password
            # before it is handed to bcrypt.
            candidate = password.encode('utf8') + self.hs.config.password_pepper
            return bcrypt.checkpw(candidate, stored_hash.encode('utf8'))

        reactor = self.hs.get_reactor()
        thread_pool = self.hs.get_reactor().getThreadPool()
        pending = threads.deferToThreadPool(
            reactor, thread_pool, _check_against_stored,
        )
        return make_deferred_yieldable(pending)
Пример #27
0
    def fire(self, *args, **kwargs):
        """Call every registered observer with the given args and kwargs.

        A failing observer is logged and otherwise ignored; firing with no
        observers registered is fine.

        Returns a Deferred that will complete when all the observers have
        completed.
        """
        def invoke(observer):
            def log_failure(failure):
                # Returns None, so the failure is treated as handled and
                # never reaches gatherResults below.
                logger.warning(
                    "%s signal observer %s failed: %r",
                    self.name,
                    observer,
                    failure,
                    exc_info=(
                        failure.type,
                        failure.value,
                        failure.getTracebackObject(),
                    ),
                )

            outcome = defer.maybeDeferred(observer, *args, **kwargs)
            return outcome.addErrback(log_failure)

        pending = [run_in_background(invoke, obs) for obs in self.observers]
        gathered = defer.gatherResults(pending, consumeErrors=True)
        return make_deferred_yieldable(gathered)
Пример #28
0
    def request(self, method, uri, *args, **kwargs):
        """Send a request through ``self.agent`` and update the counters.

        A small wrapper around self.agent.request() so we can easily attach
        counters to it.

        Args:
            method: HTTP method; also used as the counter label.
            uri: URI to request.
            *args: extra positional arguments for ``self.agent.request``.
            **kwargs: extra keyword arguments for ``self.agent.request``.

        Returns:
            The response, handed back via ``defer.returnValue``.

        Raises:
            Any exception hit while sending the request or waiting for the
            response; it is counted as "ERR", logged and re-raised unchanged.
        """
        outgoing_requests_counter.inc(method)

        logger.info("Sending request %s %s", method, uri)

        try:
            request_deferred = self.agent.request(method, uri, *args, **kwargs)
            # Give up after 60 (units as defined by add_timeout_to_deferred).
            add_timeout_to_deferred(
                request_deferred,
                60,
                cancelled_to_request_timed_out_error,
            )
            response = yield make_deferred_yieldable(request_deferred)

            incoming_responses_counter.inc(method, response.code)
            logger.info("Received response to  %s %s: %s", method, uri,
                        response.code)
            defer.returnValue(response)
        except Exception as e:
            incoming_responses_counter.inc(method, "ERR")
            # Format the exception itself: `e.message` does not exist on
            # Python 3 and would raise AttributeError here, hiding the real
            # error.
            logger.info("Error sending request to  %s %s: %s %s", method, uri,
                        type(e).__name__, e)
            # Bare raise keeps the original traceback intact.
            raise
Пример #29
0
    def get_keys_from_perspectives(self, server_name_and_key_ids):
        """Ask every configured perspective server for the requested keys.

        Each entry of ``self.perspective_servers`` is queried in the
        background via ``get_server_verify_key_v2_indirect``. A perspective
        whose lookup raises is logged and contributes an empty result.

        Args:
            server_name_and_key_ids: passed unchanged to
                ``get_server_verify_key_v2_indirect``.

        Returns:
            Via ``defer.returnValue``: a dict mapping server name to the
            merged dict of keys reported for it across all perspectives.
        """
        @defer.inlineCallbacks
        def fetch_from_perspective(perspective_name, perspective_keys):
            try:
                fetched = yield self.get_server_verify_key_v2_indirect(
                    server_name_and_key_ids, perspective_name, perspective_keys
                )
                defer.returnValue(fetched)
            except KeyLookupError as e:
                logger.warning(
                    "Key lookup failed from %r: %s", perspective_name, e,
                )
            except Exception as e:
                logger.exception(
                    "Unable to get key from %r: %s %s",
                    perspective_name,
                    type(e).__name__, str(e),
                )

            # Only reached when one of the handlers above ran.
            defer.returnValue({})

        lookups = [
            run_in_background(fetch_from_perspective, name, keys)
            for name, keys in self.perspective_servers.items()
        ]
        gathered = defer.gatherResults(lookups, consumeErrors=True)
        gathered.addErrback(unwrapFirstError)
        per_perspective = yield logcontext.make_deferred_yieldable(gathered)

        # Merge the per-perspective answers, server by server.
        merged = {}
        for answer in per_perspective:
            for server_name, keys in answer.items():
                if server_name not in merged:
                    merged[server_name] = {}
                merged[server_name].update(keys)

        defer.returnValue(merged)
Пример #30
0
    def get_events_from_store_or_dest(self, destination, room_id, event_ids):
        """Fetch events from a remote destination, checking if we already have them.

        Events already present in the store (including rejected ones) are
        used directly; the remainder are requested from `destination` in
        batches of 20.

        Args:
            destination (str)
            room_id (str)
            event_ids (list)

        Returns:
            Deferred: A deferred resolving to a 2-tuple where the first is a list of
            events and the second is a set of event ids that we failed to fetch.
        """
        seen_events = yield self.store.get_events(event_ids, allow_rejected=True)
        signed_events = list(seen_events.values())

        failed_to_fetch = set()

        # Whatever the store did not return still has to be fetched remotely.
        missing_events = set(event_ids)
        for k in seen_events:
            missing_events.discard(k)

        if not missing_events:
            defer.returnValue((signed_events, failed_to_fetch))

        # Log the room id here: the message says "for room %s", so passing
        # the list of event ids would be misleading.
        logger.debug(
            "Fetching unknown state/auth events %s for room %s",
            missing_events,
            room_id,
        )

        room_version = yield self.store.get_room_version(room_id)

        batch_size = 20
        missing_events = list(missing_events)
        for i in range(0, len(missing_events), batch_size):
            batch = set(missing_events[i:i + batch_size])

            deferreds = [
                run_in_background(
                    self.get_pdu,
                    destinations=[destination],
                    event_id=e_id,
                    room_version=room_version,
                )
                for e_id in batch
            ]

            # DeferredList never fails; each entry is a (success, result)
            # pair, so one bad fetch does not abort the rest of the batch.
            res = yield make_deferred_yieldable(
                defer.DeferredList(deferreds, consumeErrors=True)
            )
            for success, result in res:
                if success and result:
                    signed_events.append(result)
                    batch.discard(result.event_id)

            # We removed all events we successfully fetched from `batch`
            failed_to_fetch.update(batch)

        defer.returnValue((signed_events, failed_to_fetch))
Пример #31
0
 def on_PUT(self, request, event_id):
     """Handle a PUT, reusing a cached response for `event_id` if present.

     Args:
         request: the incoming request, passed to ``_handle_request`` on a
             cache miss.
         event_id: key used to look up and store the response in
             ``self.response_cache``.

     Returns:
         Whatever ``make_deferred_yieldable`` returns for the cached or
         newly stored result.
     """
     result = self.response_cache.get(event_id)
     if not result:
         # Cache miss: start handling the request and store the result.
         result = self.response_cache.set(event_id,
                                          self._handle_request(request))
     else:
         # Logger.warn is a deprecated alias of Logger.warning.
         logger.warning("Returning cached response")
     return make_deferred_yieldable(result)
Пример #32
0
    def get_server_verify_key_v2_direct(self, server_name, key_ids):
        """Fetch verify keys by asking `server_name` directly.

        For each requested key id the server's
        ``/_matrix/key/v2/server/<key_id>`` endpoint is queried. A response
        is accepted only if it carries a signature entry for `server_name`
        and lists the SHA-256 fingerprint of the TLS certificate that was
        presented. Accepted responses go through ``process_v2_response`` and
        the resulting keys are persisted with ``store_keys``.

        Args:
            server_name: the server to query.
            key_ids: iterable of key ids to request.

        Returns:
            Via ``defer.returnValue``: the accumulated dict built from the
            ``process_v2_response`` results.

        Raises:
            KeyLookupError: if a response is unsigned, has no
                ``tls_fingerprints``, or does not list the certificate's
                fingerprint.
        """
        keys = {}

        for requested_key_id in key_ids:
            if requested_key_id in keys:
                continue

            key_path = "/_matrix/key/v2/server/%s" % (
                urllib.parse.quote(requested_key_id),
            )
            (response, tls_certificate) = yield fetch_server_key(
                server_name,
                self.hs.tls_client_options_factory,
                path=key_path.encode("ascii"),
            )

            signed_by_server = (
                u"signatures" in response
                and server_name in response[u"signatures"]
            )
            if not signed_by_server:
                raise KeyLookupError(
                    "Key response not signed by remote server")

            if "tls_fingerprints" not in response:
                raise KeyLookupError("Key response missing TLS fingerprints")

            # Fingerprint of the certificate that was actually presented.
            der_bytes = crypto.dump_certificate(
                crypto.FILETYPE_ASN1, tls_certificate)
            presented_fingerprint = encode_base64(
                hashlib.sha256(der_bytes).digest())

            # Fingerprints the server claims to use.
            advertised_fingerprints = {
                entry[u"sha256"]
                for entry in response[u"tls_fingerprints"]
                if u"sha256" in entry
            }

            if presented_fingerprint not in advertised_fingerprints:
                raise KeyLookupError(
                    "TLS certificate not allowed by fingerprints")

            response_keys = yield self.process_v2_response(
                from_server=server_name,
                requested_ids=[requested_key_id],
                response_json=response,
            )

            keys.update(response_keys)

        store_jobs = [
            run_in_background(
                self.store_keys,
                server_name=key_server_name,
                from_server=server_name,
                verify_keys=verify_keys,
            )
            for key_server_name, verify_keys in keys.items()
        ]
        stored = defer.gatherResults(store_jobs, consumeErrors=True)
        stored.addErrback(unwrapFirstError)
        yield logcontext.make_deferred_yieldable(stored)

        defer.returnValue(keys)
Пример #33
0
    def claim_one_time_keys(self, query, timeout):
        """Claim one-time keys for the devices listed in `query`.

        Users for whom ``self.is_mine`` is true are served from the store;
        the rest are grouped by domain and claimed over federation. A remote
        destination that raises is recorded under "failures" and does not
        abort the call.

        Args:
            query (dict): may contain "one_time_keys", mapping user id to a
                dict of device id -> algorithm.
            timeout: passed to the federation ``claim_client_keys`` call.

        Returns:
            Via ``defer.returnValue``: a dict with keys "one_time_keys" and
            "failures".
        """
        local_query = []
        remote_queries = {}

        requested = query.get("one_time_keys", {})
        for user_id, device_keys in requested.items():
            # we use UserID.from_string to catch invalid user ids
            if not self.is_mine(UserID.from_string(user_id)):
                domain = get_domain_from_id(user_id)
                remote_queries.setdefault(domain, {})[user_id] = device_keys
                continue

            for device_id, algorithm in device_keys.items():
                local_query.append((user_id, device_id, algorithm))

        results = yield self.store.claim_e2e_one_time_keys(local_query)

        json_result = {}
        failures = {}
        for user_id, device_keys in results.items():
            for device_id, keys in device_keys.items():
                for key_id, json_bytes in keys.items():
                    claimed = {key_id: json.loads(json_bytes)}
                    json_result.setdefault(user_id, {})[device_id] = claimed

        @defer.inlineCallbacks
        def claim_from_remote(destination):
            wanted = remote_queries[destination]
            try:
                remote_result = yield self.federation.claim_client_keys(
                    destination,
                    {"one_time_keys": wanted},
                    timeout=timeout
                )
                # Only accept answers for users we actually asked about.
                for user_id, keys in remote_result["one_time_keys"].items():
                    if user_id in wanted:
                        json_result[user_id] = keys
            except Exception as e:
                failures[destination] = _exception_to_failure(e)

        remote_claims = [
            run_in_background(claim_from_remote, destination)
            for destination in remote_queries
        ]
        yield make_deferred_yieldable(
            defer.gatherResults(remote_claims, consumeErrors=True)
        )

        claimed_descriptions = (
            "%s for %s:%s" % (key_id, user_id, device_id)
            for user_id, user_keys in iteritems(json_result)
            for device_id, per_device in iteritems(user_keys)
            for key_id, _ in iteritems(per_device)
        )
        logger.info(
            "Claimed one-time-keys: %s", ",".join(claimed_descriptions),
        )

        response = {
            "one_time_keys": json_result,
            "failures": failures
        }
        defer.returnValue(response)
Пример #34
0
    def get_file(self, url, output_stream, max_size=None, headers=None):
        """GETs a file from a given URL
        Args:
            url (str): The URL to GET
            output_stream (file): File to write the response body to.
            max_size (int|None): If not None, a response whose Content-Length
                header exceeds this value is rejected before the body is
                read. The value is also passed to ``_readBodyToFile``.
            headers (dict[str, List[str]]|None): If not None, a map from
               header name to a list of values for that header
        Returns:
            A (int,dict,string,int) tuple of the file length, dict of the response
            headers, absolute URI of the response and HTTP response code.
        Raises:
            SynapseError: 502 if the advertised length is too large, the
                response code is above 299, or reading the body fails.
        """

        actual_headers = {b"User-Agent": [self.user_agent]}
        if headers:
            actual_headers.update(headers)

        response = yield self.request("GET",
                                      url,
                                      headers=Headers(actual_headers))

        resp_headers = dict(response.headers.getAllRawHeaders())

        # Only enforce the limit when one was given: comparing against None
        # is an error on Python 3 and always true on Python 2. The messages
        # use the `max_size` argument, not an attribute on self.
        if (max_size is not None
                and b"Content-Length" in resp_headers
                and int(resp_headers[b"Content-Length"][0]) > max_size):
            logger.warning("Requested URL is too large > %r bytes", max_size)
            raise SynapseError(
                502,
                "Requested file is too large > %r bytes" % (max_size, ),
                Codes.TOO_LARGE,
            )

        if response.code > 299:
            logger.warning("Got %d when downloading %s", response.code, url)
            raise SynapseError(502, "Got error %d" % (response.code, ),
                               Codes.UNKNOWN)

        # TODO: if our Content-Type is HTML or something, just read the first
        # N bytes into RAM rather than saving it all to disk only to read it
        # straight back in again

        try:
            length = yield make_deferred_yieldable(
                _readBodyToFile(response, output_stream, max_size))
        except SynapseError:
            # This can happen e.g. because the body is too large.
            raise
        except Exception as e:
            # Wrap anything else in a 502, keeping the original as the cause.
            raise_from(
                SynapseError(502, ("Failed to download remote body: %s" % e)),
                e)

        defer.returnValue((
            length,
            resp_headers,
            response.request.absoluteURI.decode("ascii"),
            response.code,
        ))
Пример #35
0
    def get_file(self, url, output_stream, max_size=None, headers=None):
        """GETs a file from a given URL
        Args:
            url (str): The URL to GET
            output_stream (file): File to write the response body to.
            max_size (int|None): If not None, a response whose Content-Length
                header exceeds this value is rejected before the body is
                read. The value is also passed to ``_readBodyToFile``.
            headers (dict[str, List[str]]|None): If not None, a map from
               header name to a list of values for that header
        Returns:
            A (int,dict,string,int) tuple of the file length, dict of the response
            headers, absolute URI of the response and HTTP response code.
        Raises:
            SynapseError: 502 if the advertised length is too large, the
                response code is above 299, or reading the body fails.
        """

        actual_headers = {b"User-Agent": [self.user_agent]}
        if headers:
            actual_headers.update(headers)

        response = yield self.request("GET", url, headers=Headers(actual_headers))

        resp_headers = dict(response.headers.getAllRawHeaders())

        # Only enforce the limit when one was given: comparing against None
        # is an error on Python 3 and always true on Python 2. The messages
        # use the `max_size` argument, not an attribute on self.
        if (
            max_size is not None
            and b'Content-Length' in resp_headers
            and int(resp_headers[b'Content-Length'][0]) > max_size
        ):
            logger.warning("Requested URL is too large > %r bytes", max_size)
            raise SynapseError(
                502,
                "Requested file is too large > %r bytes" % (max_size,),
                Codes.TOO_LARGE,
            )

        if response.code > 299:
            logger.warning("Got %d when downloading %s", response.code, url)
            raise SynapseError(502, "Got error %d" % (response.code,), Codes.UNKNOWN)

        # TODO: if our Content-Type is HTML or something, just read the first
        # N bytes into RAM rather than saving it all to disk only to read it
        # straight back in again

        try:
            length = yield make_deferred_yieldable(
                _readBodyToFile(response, output_stream, max_size)
            )
        except Exception as e:
            # Every failure while reading the body is reported as a 502.
            logger.exception("Failed to download body")
            raise SynapseError(
                502, ("Failed to download remote body: %s" % e), Codes.UNKNOWN
            )

        defer.returnValue(
            (
                length,
                resp_headers,
                response.request.absoluteURI.decode('ascii'),
                response.code,
            )
        )
Пример #36
0
    def get_server_verify_key_v2_direct(self, server_name, key_ids):
        """Fetch verify keys by asking `server_name` directly.

        For each requested key id the server's
        ``/_matrix/key/v2/server/<key_id>`` endpoint is queried. A response
        is accepted only if it carries a signature entry for `server_name`
        and lists the SHA-256 fingerprint of the TLS certificate that was
        presented. Accepted responses go through ``process_v2_response`` and
        the resulting keys are persisted with ``store_keys``.

        Args:
            server_name: the server to query.
            key_ids: iterable of key ids to request.

        Returns:
            Via ``defer.returnValue``: the accumulated dict built from the
            ``process_v2_response`` results.

        Raises:
            KeyLookupError: if a response is unsigned, has no
                ``tls_fingerprints``, or does not list the certificate's
                fingerprint.
        """
        keys = {}

        for requested_key_id in key_ids:
            # NOTE(review): `keys` is later iterated as
            # key_server_name -> verify_keys, so this membership test may be
            # comparing a key id against server names - confirm intent.
            if requested_key_id in keys:
                continue

            # NOTE(review): `urllib.quote` and %-formatting a bytes literal
            # followed by .encode() look Python 2 specific - confirm the
            # target interpreter before porting.
            (response, tls_certificate) = yield fetch_server_key(
                server_name, self.hs.tls_server_context_factory,
                path=(b"/_matrix/key/v2/server/%s" % (
                    urllib.quote(requested_key_id),
                )).encode("ascii"),
            )

            # The response must carry a signature entry for the server we
            # asked.
            if (u"signatures" not in response
                    or server_name not in response[u"signatures"]):
                raise KeyLookupError("Key response not signed by remote server")

            if "tls_fingerprints" not in response:
                raise KeyLookupError("Key response missing TLS fingerprints")

            # Base64 SHA-256 fingerprint of the DER-encoded certificate that
            # was presented on the connection.
            certificate_bytes = crypto.dump_certificate(
                crypto.FILETYPE_ASN1, tls_certificate
            )
            sha256_fingerprint = hashlib.sha256(certificate_bytes).digest()
            sha256_fingerprint_b64 = encode_base64(sha256_fingerprint)

            # Collect the sha256 fingerprints advertised in the response;
            # entries without a "sha256" field are ignored.
            response_sha256_fingerprints = set()
            for fingerprint in response[u"tls_fingerprints"]:
                if u"sha256" in fingerprint:
                    response_sha256_fingerprints.add(fingerprint[u"sha256"])

            if sha256_fingerprint_b64 not in response_sha256_fingerprints:
                raise KeyLookupError("TLS certificate not allowed by fingerprints")

            response_keys = yield self.process_v2_response(
                from_server=server_name,
                requested_ids=[requested_key_id],
                response_json=response,
            )

            keys.update(response_keys)

        # Persist everything we collected, one background store_keys call
        # per entry; unwrapFirstError is applied if any of them fails.
        yield logcontext.make_deferred_yieldable(defer.gatherResults(
            [
                run_in_background(
                    self.store_keys,
                    server_name=key_server_name,
                    from_server=server_name,
                    verify_keys=verify_keys,
                )
                for key_server_name, verify_keys in keys.items()
            ],
            consumeErrors=True
        ).addErrback(unwrapFirstError))

        defer.returnValue(keys)
Пример #37
0
    def get_json(self,
                 destination,
                 path,
                 args=None,
                 retry_on_dns_fail=True,
                 timeout=None,
                 ignore_backoff=False):
        """ GETs some json from the given host homeserver and path

        Args:
            destination (str): The remote server to send the HTTP request
                to.
            path (str): The HTTP path.
            args (dict|None): A dictionary used to create query strings, defaults to
                None.
            retry_on_dns_fail (bool): forwarded unchanged to ``_request``.
            timeout (int): How long to try (in ms) the destination for before
                giving up. None indicates no timeout and that the request will
                be retried.
            ignore_backoff (bool): true to ignore the historical backoff data
                and try the request anyway.
        Returns:
            Deferred: Succeeds when we get a 2xx HTTP response. The result
            will be the decoded JSON body.

            Fails with ``HTTPRequestException`` if we get an HTTP response
            code >= 300.

            Fails with ``NotRetryingDestination`` if we are not yet ready
            to retry this server.

            Fails with ``FederationDeniedError`` if this destination
            is not on our federation whitelist
        """
        logger.debug("get_json args: %s", args)
        logger.debug("Query bytes: %s Retry DNS: %s", args, retry_on_dns_fail)

        request_options = dict(
            query=args,
            retry_on_dns_fail=retry_on_dns_fail,
            timeout=timeout,
            ignore_backoff=ignore_backoff,
        )
        response = yield self._request(
            destination, "GET", path, **request_options
        )

        status = response.code
        if status >= 200 and status < 300:
            # We need to update the transactions table to say it was sent?
            check_content_type_is_json(response.headers)

        with logcontext.PreserveLoggingContext():
            body_deferred = treq.json_content(response)
            # Stop waiting for the body after the default timeout.
            body_deferred.addTimeout(
                self.default_timeout, self.hs.get_reactor()
            )
            body = yield make_deferred_yieldable(body_deferred)

        defer.returnValue(body)
Пример #38
0
    def store_file(self, path, file_info):
        """See StorageProvider.store_file

        Uploads the file at ``<cache_directory>/<path>`` to the configured S3
        bucket under the key `path`, on a reactor thread-pool thread.

        Args:
            path (str): path relative to ``self.cache_directory``; also used
                as the S3 object key.
            file_info: not used by this implementation.

        Returns:
            The result of ``make_deferred_yieldable`` on a Deferred that
            fires once the upload has finished, or fails if it raised.
        """
        # Imported locally so the method does not rely on the module's import
        # block already providing `threads`.
        from twisted.internet import threads

        def _store_file():
            boto3.resource('s3').Bucket(self.bucket).upload_file(
                Filename=os.path.join(self.cache_directory, path),
                Key=path,
                ExtraArgs={"StorageClass": self.storage_class},
            )

        # reactor.callInThread() returns None, so the caller could neither
        # wait for the upload nor see it fail. deferToThreadPool runs the
        # work on the same reactor thread pool and returns a Deferred.
        return make_deferred_yieldable(
            threads.deferToThreadPool(
                reactor, reactor.getThreadPool(), _store_file,
            )
        )
Пример #39
0
    def claim_one_time_keys(self, query, timeout):
        """Claim one-time keys for the devices listed in `query`.

        Users for whom ``self.is_mine`` is true are served from the store;
        the rest are grouped by domain and claimed over federation. A remote
        destination that raises is recorded under "failures" and does not
        abort the call.

        Args:
            query (dict): may contain "one_time_keys", mapping user id to a
                dict of device id -> algorithm.
            timeout: passed to the federation ``claim_client_keys`` call.

        Returns:
            Via ``defer.returnValue``: a dict with keys "one_time_keys" and
            "failures".
        """
        local_query = []
        remote_queries = {}

        requested = query.get("one_time_keys", {})
        for user_id, device_keys in requested.items():
            # we use UserID.from_string to catch invalid user ids
            if not self.is_mine(UserID.from_string(user_id)):
                domain = get_domain_from_id(user_id)
                remote_queries.setdefault(domain, {})[user_id] = device_keys
                continue

            for device_id, algorithm in device_keys.items():
                local_query.append((user_id, device_id, algorithm))

        results = yield self.store.claim_e2e_one_time_keys(local_query)

        json_result = {}
        failures = {}
        for user_id, device_keys in results.items():
            for device_id, keys in device_keys.items():
                for key_id, json_bytes in keys.items():
                    claimed = {key_id: json.loads(json_bytes)}
                    json_result.setdefault(user_id, {})[device_id] = claimed

        @defer.inlineCallbacks
        def claim_from_remote(destination):
            wanted = remote_queries[destination]
            try:
                remote_result = yield self.federation.claim_client_keys(
                    destination, {"one_time_keys": wanted},
                    timeout=timeout)
                # Only accept answers for users we actually asked about.
                for user_id, keys in remote_result["one_time_keys"].items():
                    if user_id in wanted:
                        json_result[user_id] = keys
            except Exception as e:
                failures[destination] = _exception_to_failure(e)

        remote_claims = [
            run_in_background(claim_from_remote, destination)
            for destination in remote_queries
        ]
        yield make_deferred_yieldable(
            defer.gatherResults(remote_claims, consumeErrors=True)
        )

        claimed_descriptions = (
            "%s for %s:%s" % (key_id, user_id, device_id)
            for user_id, user_keys in iteritems(json_result)
            for device_id, per_device in iteritems(user_keys)
            for key_id, _ in iteritems(per_device)
        )
        logger.info(
            "Claimed one-time-keys: %s", ",".join(claimed_descriptions),
        )

        response = {"one_time_keys": json_result, "failures": failures}
        defer.returnValue(response)
Пример #40
0
    def on_context_state_request(self, origin, room_id, event_id):
        """Answer a state request from `origin` for `event_id` in `room_id`.

        Responses are cached in ``self._state_resp_cache`` under
        (room_id, event_id). On a cache miss the computation runs while
        holding the ``_server_linearizer`` queue for (origin, room_id).

        Returns:
            Via ``defer.returnValue``: the tuple (200, response).

        Raises:
            NotImplementedError: if `event_id` is falsy.
            AuthError: 403 if ``check_host_in_room`` reports that `origin` is
                not in the room.
        """
        if not event_id:
            raise NotImplementedError("Specify an event")

        in_room = yield self.auth.check_host_in_room(room_id, origin)
        if not in_room:
            raise AuthError(403, "Host not in room.")

        cache_key = (room_id, event_id)
        cached = self._state_resp_cache.get(cache_key)
        if cached:
            resp = yield make_deferred_yieldable(cached)
        else:
            with (yield self._server_linearizer.queue((origin, room_id))):
                compute = preserve_fn(self._on_context_state_request_compute)
                pending = self._state_resp_cache.set(
                    cache_key, compute(room_id, event_id),
                )
                resp = yield make_deferred_yieldable(pending)

        defer.returnValue((200, resp))
Пример #41
0
    def get_room_events_stream_for_rooms(self,
                                         room_ids,
                                         from_key,
                                         to_key,
                                         limit=0,
                                         order='DESC'):
        """Get new room events in stream ordering since `from_key`.

        The rooms are first narrowed with
        ``self._events_stream_cache.get_entities_changed``; the remaining
        rooms are then queried 20 at a time.

        Args:
            room_ids: the rooms to look at.
            from_key (str): Token from which no events are returned before
            to_key (str): Token from which no events are returned after. (This
                is typically the current stream token)
            limit (int): Maximum number of events to return
            order (str): Either "DESC" or "ASC". Determines which events are
                returned when the result is limited. If "DESC" then the most
                recent `limit` events are returned, otherwise returns the
                oldest `limit` events.

        Returns:
            Deferred[dict[str,tuple[list[FrozenEvent], str]]]
                A map from room id to a tuple containing:
                    - list of recent events in the room
                    - stream ordering key for the start of the chunk of events returned.
        """
        from_id = RoomStreamToken.parse_stream_token(from_key).stream

        room_ids = yield self._events_stream_cache.get_entities_changed(
            room_ids, from_id)

        if not room_ids:
            defer.returnValue({})

        results = {}
        room_ids = list(room_ids)
        chunk_size = 20
        for start in range(0, len(room_ids), chunk_size):
            chunk = room_ids[start:start + chunk_size]
            lookups = [
                run_in_background(
                    self.get_room_events_stream_for_room,
                    room_id,
                    from_key,
                    to_key,
                    limit,
                    order=order,
                )
                for room_id in chunk
            ]
            chunk_results = yield make_deferred_yieldable(
                defer.gatherResults(lookups, consumeErrors=True)
            )
            # gatherResults keeps input order, so pair each room with its
            # result positionally.
            results.update(dict(zip(chunk, chunk_results)))

        defer.returnValue(results)
Пример #42
0
    def _enqueue_events(self, events, check_redacted=True, allow_rejected=False):
        """Fetches events from the database using the _event_fetch_list. This
        allows batch and bulk fetching of events - it allows us to fetch events
        without having to create a new transaction for each request for events.

        Args:
            events: the events to fetch; an empty value short-circuits to an
                empty dict. (presumably a collection of event ids - confirm
                against callers)
            check_redacted (bool): not referenced anywhere in this method.
            allow_rejected (bool): if false, rows with a truthy "rejects"
                value are dropped before events are built.

        Returns:
            Via ``defer.returnValue``: a dict mapping event id to the truthy
            results of ``_get_event_from_row``.
        """
        if not events:
            defer.returnValue({})

        # Queue the request together with a Deferred that yields its rows.
        events_d = defer.Deferred()
        with self._event_fetch_lock:
            self._event_fetch_list.append(
                (events, events_d)
            )

            # Wake up anything waiting on the lock for new work.
            self._event_fetch_lock.notify()

            # Start another fetcher only while below the
            # EVENT_QUEUE_THREADS cap; the counter is bumped under the lock.
            if self._event_fetch_ongoing < EVENT_QUEUE_THREADS:
                self._event_fetch_ongoing += 1
                should_start = True
            else:
                should_start = False

        # The background process is launched after the lock is released.
        if should_start:
            run_as_background_process(
                "fetch_events",
                self.runWithConnection,
                self._do_fetch,
            )

        logger.debug("Loading %d events", len(events))
        with PreserveLoggingContext():
            rows = yield events_d
        logger.debug("Loaded %d events (%d rows)", len(events), len(rows))

        # Filter in place so the list object itself is kept.
        if not allow_rejected:
            rows[:] = [r for r in rows if not r["rejects"]]

        # Build one event per row, all in the background.
        res = yield make_deferred_yieldable(defer.gatherResults(
            [
                run_in_background(
                    self._get_event_from_row,
                    row["internal_metadata"], row["json"], row["redacts"],
                    rejected_reason=row["rejects"],
                )
                for row in rows
            ],
            consumeErrors=True
        ))

        # Falsy results are skipped.
        defer.returnValue({
            e.event.event_id: e
            for e in res if e
        })
Пример #43
0
    def test_make_deferred_yieldable_on_non_deferred(self):
        """make_deferred_yieldable must cope with a plain, non-deferred value:
        the value comes back unchanged and the logcontext stays in place."""

        with LoggingContext() as ctx:
            ctx.request = "one"

            wrapped = logcontext.make_deferred_yieldable("bum")
            self._check_test_key("one")

            outcome = yield wrapped
            self.assertEqual(outcome, "bum")
            self._check_test_key("one")
Пример #44
0
    def test_make_deferred_yieldable_with_chained_deferreds(self):
        """make_deferred_yieldable on a chained deferred should reset the
        logcontext while waiting and restore it once the result arrives."""
        context_before = LoggingContext.current_context()

        with LoggingContext() as ctx:
            ctx.request = "one"

            chained = _chained_deferred_function()
            wrapped = logcontext.make_deferred_yieldable(chained)
            # the context must have been reset by make_deferred_yieldable
            self.assertIs(LoggingContext.current_context(), context_before)

            yield wrapped

            # and it must be back once we resume
            self._check_test_key("one")
Пример #45
0
    def request(self, method, uri, data=b'', headers=None):
        """Send a request via treq using ``self.agent``, updating the request
        counters.

        Args:
            method (str): HTTP method to use.
            uri (str): URI to query.
            data (bytes): Data to send in the request body, if applicable.
            headers (t.w.http_headers.Headers): Request headers.

        Returns:
            The response, handed back via ``defer.returnValue``.

        Raises:
            SynapseError: If the IP is blacklisted.
            Any other exception hit while sending the request is counted as
            "ERR", logged and re-raised unchanged.
        """
        # A small wrapper around self.agent.request() so we can easily attach
        # counters to it
        outgoing_requests_counter.labels(method).inc()

        # log request but strip `access_token` (AS requests for example include this)
        logger.info("Sending request %s %s", method, redact_uri(uri))

        try:
            request_deferred = treq.request(
                method,
                uri,
                agent=self.agent,
                data=data,
                headers=headers,
                **self._extra_treq_args
            )
            # Give up after 60 (units as defined by timeout_deferred).
            request_deferred = timeout_deferred(
                request_deferred,
                60,
                self.hs.get_reactor(),
                cancelled_to_request_timed_out_error,
            )
            response = yield make_deferred_yieldable(request_deferred)

            incoming_responses_counter.labels(method, response.code).inc()
            logger.info(
                "Received response to %s %s: %s", method, redact_uri(uri), response.code
            )
            defer.returnValue(response)
        except Exception as e:
            incoming_responses_counter.labels(method, "ERR").inc()
            # Exceptions raised without arguments have an empty `args`, so
            # indexing it blindly would raise IndexError here and hide the
            # real error. Fall back to the exception itself.
            detail = e.args[0] if e.args else e
            logger.info(
                "Error sending request to  %s %s: %s %s",
                method,
                redact_uri(uri),
                type(e).__name__,
                detail,
            )
            raise
Пример #46
0
        def second_lookup():
            """Run a key verification for "server10" inside logcontext "12",
            with the mocked post_json returning a deferred that never fires
            on its own.
            """
            with LoggingContext("12") as ctx:
                ctx.request = "12"

                # start from a clean mock whose result is a fresh, unfired deferred
                self.http_client.post_json.reset_mock()
                self.http_client.post_json.return_value = defer.Deferred()

                pending = kr.verify_json_objects_for_server(
                    [("server10", json1)]
                )
                first = pending[0]
                first.addBoth(self.check_context, None)
                yield logcontext.make_deferred_yieldable(first)

                # give verify_json_objects_for_server a chance to finish before
                # the logcontext is torn down
                yield self.clock.sleep(0)
Пример #47
0
        def get_file(destination, path, output_stream, args=None, max_size=None):
            """Record a pending fetch and return a deferred for its result.

            A new Deferred is appended to ``self.fetches`` together with the
            destination, path and args, so the caller can fire it later with a
            ``(data, response)`` pair.  When fired, ``data`` is written to
            ``output_stream`` and the deferred resolves to ``response``.

            ``response`` is presumably a tuple[int,dict,str,int] of file
            length, response headers, absolute URI, and response code --
            confirm against whoever fires the deferred.  ``max_size`` is
            accepted but not used here.
            """

            def _store_body(fired_with):
                body, response = fired_with
                output_stream.write(body)
                return response

            pending = Deferred()
            pending.addCallback(_store_body)
            self.fetches.append((pending, destination, path, args))
            return make_deferred_yieldable(pending)
Пример #48
0
    def wrap(self, key, callback, *args, **kwargs):
        """Combine a cache lookup with computing and storing a missing entry.

        If *key* has a cached entry, that entry is returned (made to follow
        the synapse logcontext rules).  Otherwise *callback(*args, **kwargs)*
        is started in the background, its result is stored under *key*, and
        the stored entry is returned.  The callback should itself follow the
        synapse logcontext rules.

        Example usage:

            @defer.inlineCallbacks
            def handle_request(request):
                # etc
                defer.returnValue(result)

            result = yield response_cache.wrap(
                key,
                handle_request,
                request,
            )

        Args:
            key (hashable): key to get/set in the cache

            callback (callable): function to call if the key is not found in
                the cache

            *args: positional parameters to pass to the callback, if it is used

            **kwargs: named parameters to pass to the callback, if it is used

        Returns:
            twisted.internet.defer.Deferred: yieldable result
        """
        cached = self.get(key)
        # NOTE(review): any falsy value from self.get() is treated as a miss,
        # not just None -- confirm that is intended.
        if cached:
            still_running = isinstance(cached, defer.Deferred) and not cached.called
            if still_running:
                logger.info("[%s]: using incomplete cached result for [%s]",
                            self._name, key)
            else:
                logger.info("[%s]: using completed cached result for [%s]",
                            self._name, key)
            entry = cached
        else:
            logger.info("[%s]: no cached result for [%s], calculating new one",
                        self._name, key)
            in_progress = run_in_background(callback, *args, **kwargs)
            entry = self.set(key, in_progress)
        return make_deferred_yieldable(entry)
Пример #49
0
    def get_prev_state_ids(self, store):
        """Fetch the prev state IDs, loading the state first if needed.

        The state is filled out at most once: if no fetch has been started
        yet, one is kicked off in the background and remembered, and every
        caller then waits on that same deferred.

        Args:
            store: passed through to ``_fill_out_state``.

        Returns:
            Deferred[dict[(str, str), str]|None]: Returns None if state_group
            is None, which happens when the associated event is an outlier.
        """
        pending = self._fetching_state_deferred
        if not pending:
            pending = self._fetching_state_deferred = run_in_background(
                self._fill_out_state, store,
            )

        yield make_deferred_yieldable(pending)

        defer.returnValue(self._prev_state_ids)
Пример #50
0
def yieldable_gather_results(func, iter, *args, **kwargs):
    """Call ``func`` once per item, running all the calls concurrently.

    Args:
        func (func): Function to execute that returns a Deferred
        iter (iter): An iterable that yields items that get passed as the first
            argument to the function
        *args: Arguments to be passed to each call to func
        **kwargs: Keyword arguments to be passed to each call to func

    Returns
        Deferred[list]: Resolved when all functions have been invoked, or errors if
        one of the function calls fails.
    """
    # start every call in the background before waiting on any of them
    in_flight = [
        run_in_background(func, item, *args, **kwargs)
        for item in iter
    ]
    gathered = defer.gatherResults(in_flight, consumeErrors=True)
    yieldable = logcontext.make_deferred_yieldable(gathered)
    # failures from gatherResults are passed through unwrapFirstError
    return yieldable.addErrback(unwrapFirstError)
Пример #51
0
    def query_3pe(self, kind, protocol, fields):
        """Query every service found for ``protocol`` and merge the results.

        Each service is queried in the background via
        ``self.appservice_api.query_3pe``.  Failed lookups are dropped; the
        results of the successful ones are concatenated in order.

        Returns:
            Deferred[list]: the combined results of all successful lookups.
        """
        services = yield self._get_services_for_3pn(protocol)

        lookups = [
            run_in_background(
                self.appservice_api.query_3pe,
                service, kind, protocol, fields,
            )
            for service in services
        ]
        outcomes = yield make_deferred_yieldable(
            defer.DeferredList(lookups, consumeErrors=True)
        )

        # each outcome is a (success, result) pair; flatten the successful ones
        combined = [
            entry
            for succeeded, entries in outcomes
            if succeeded
            for entry in entries
        ]

        defer.returnValue(combined)
Пример #52
0
        def wrapped(*args, **kwargs):
            """Return the cached result for these arguments, computing and
            caching it on a miss.

            An optional ``on_invalidate`` keyword argument is removed from
            kwargs and handed to the cache as the invalidation callback.

            Returns either a Deferred (made yieldable) or, when the cache
            holds a plain value, that value directly.
            """
            # If we're passed a cache_context then we'll want to call its invalidate()
            # whenever we are invalidated
            invalidate_callback = kwargs.pop("on_invalidate", None)

            cache_key = get_cache_key(args, kwargs)

            # Add our own `cache_context` to argument list if the wrapped function
            # has asked for one
            if self.add_cache_context:
                kwargs["cache_context"] = _CacheContext(cache, cache_key)

            try:
                # cache.get signals a miss by raising KeyError (handled below)
                cached_result_d = cache.get(cache_key, callback=invalidate_callback)

                # an ObservableDeferred is observed; anything else is used as-is
                if isinstance(cached_result_d, ObservableDeferred):
                    observer = cached_result_d.observe()
                else:
                    observer = cached_result_d

            except KeyError:
                # Cache miss: call the underlying function (`obj` comes from
                # the enclosing scope) and wrap whatever it returns in a Deferred.
                ret = defer.maybeDeferred(
                    logcontext.preserve_fn(self.function_to_call),
                    obj, *args, **kwargs
                )

                # On failure, drop the cache entry and pass the failure on.
                # Note this closure reads `cache_key` when it runs, so it sees
                # any rebinding done below.
                def onErr(f):
                    cache.invalidate(cache_key)
                    return f

                ret.addErrback(onErr)

                # If our cache_key is a string on py2, try to convert to ascii
                # to save a bit of space in large caches. Py3 does this
                # internally automatically.
                if six.PY2 and isinstance(cache_key, string_types):
                    cache_key = to_ascii(cache_key)

                result_d = ObservableDeferred(ret, consumeErrors=True)
                cache.set(cache_key, result_d, callback=invalidate_callback)
                observer = result_d.observe()

            # Only Deferreds need the logcontext treatment; plain values are
            # returned untouched.
            if isinstance(observer, defer.Deferred):
                return logcontext.make_deferred_yieldable(observer)
            else:
                return observer
Пример #53
0
    def update_state(self, state_group, prev_state_ids, current_state_ids,
                     prev_group, delta_ids):
        """Overwrite the state held by this context with the given values.

        If a state fetch is already in progress it is waited for first, so
        that its results cannot overwrite the values stored here.
        """
        in_flight = self._fetching_state_deferred
        if in_flight:
            # let the ongoing fetch finish, otherwise it could clobber the
            # state we are about to set
            yield make_deferred_yieldable(in_flight)

        self.state_group = state_group
        self.prev_group = prev_group
        self.delta_ids = delta_ids
        self._prev_state_ids = prev_state_ids
        self._current_state_ids = current_state_ids

        # record the state as already fetched, using an already-fired deferred
        self._fetching_state_deferred = defer.succeed(None)
Пример #54
0
    def get_current_state_ids(self, store):
        """Fetch the current state IDs, loading the state first if needed.

        The state is filled out at most once: if no fetch has been started
        yet, one is kicked off in the background and remembered, and every
        caller then waits on that same deferred.

        Args:
            store: passed through to ``_fill_out_state``.

        Returns:
            Deferred[dict[(str, str), str]|None]: Returns None if state_group
                is None, which happens when the associated event is an outlier.
                Maps a (type, state_key) to the event ID of the state event matching
                this tuple.
        """
        pending = self._fetching_state_deferred
        if not pending:
            pending = self._fetching_state_deferred = run_in_background(
                self._fill_out_state, store,
            )

        yield make_deferred_yieldable(pending)

        defer.returnValue(self._current_state_ids)
Пример #55
0
def respond_with_file(request, media_type, file_path, file_size=None, upload_name=None):
    """Stream the file at ``file_path`` to ``request``, or respond with a 404.

    Args:
        request: the request to respond to
        media_type: passed to ``add_file_headers``
        file_path: path of the file to send
        file_size: size to advertise in the headers; when None it is read
            from the file's stat information
        upload_name: passed to ``add_file_headers``
    """
    logger.debug("Responding with %r", file_path)

    # anything that is not a regular file gets a 404
    if not os.path.isfile(file_path):
        respond_404(request)
        return

    if file_size is None:
        file_size = os.stat(file_path).st_size

    add_file_headers(request, media_type, file_size, upload_name)

    with open(file_path, "rb") as source:
        transfer = FileSender().beginFileTransfer(source, request)
        yield logcontext.make_deferred_yieldable(transfer)

    finish_request(request)
Пример #56
0
 def test_send_single_event_with_queue(self):
     """Events enqueued while a send is still in flight are held back and
     then sent together once that send completes.
     """
     # send() returns a deferred we control, so the first send stays pending
     d = defer.Deferred()
     self.txn_ctrl.send = Mock(side_effect=lambda x, y: make_deferred_yieldable(d))
     service = Mock(id=4)
     event = Mock(event_id="first")
     event2 = Mock(event_id="second")
     event3 = Mock(event_id="third")
     # Send an event and don't resolve it just yet.
     self.queuer.enqueue(service, event)
     # Send more events: expect send() to NOT be called multiple times.
     self.queuer.enqueue(service, event2)
     self.queuer.enqueue(service, event3)
     self.txn_ctrl.send.assert_called_with(service, [event])
     # assertEqual rather than the deprecated assertEquals alias
     self.assertEqual(1, self.txn_ctrl.send.call_count)
     # Resolve the send event: expect the queued events to be sent
     d.callback(service)
     self.txn_ctrl.send.assert_called_with(service, [event2, event3])
     self.assertEqual(2, self.txn_ctrl.send.call_count)
Пример #57
0
    def put_json(self, uri, json_body, args=None, headers=None):
        """ Puts some json to the given URI.

        Args:
            uri (str): The URI to request, not including query parameters
            json_body (dict): The JSON to put in the HTTP body,
            args (dict|None): A dictionary used to create query strings,
                defaults to None (no query string).
                **Note**: The value of each key is assumed to be an iterable
                and *not* a string.
            headers (dict[str, List[str]]|None): If not None, a map from
               header name to a list of values for that header
        Returns:
            Deferred: Succeeds when we get *any* 2xx HTTP response, with the
            HTTP body as JSON.
        Raises:
            HttpResponseException On a non-2xx HTTP response.

            ValueError: if the response was not JSON
        """
        # `args` defaults to None rather than a shared mutable dict; an empty
        # or missing mapping simply means "no query string".
        if args:
            query_bytes = urllib.parse.urlencode(args, True)
            uri = "%s?%s" % (uri, query_bytes)

        json_str = encode_canonical_json(json_body)

        actual_headers = {
            b"Content-Type": [b"application/json"],
            b"User-Agent": [self.user_agent],
        }
        # caller-supplied headers override the defaults above
        if headers:
            actual_headers.update(headers)

        response = yield self.request(
            "PUT", uri, headers=Headers(actual_headers), data=json_str
        )

        body = yield make_deferred_yieldable(readBody(response))

        if 200 <= response.code < 300:
            # The body arrives as bytes; decode explicitly so parsing does not
            # depend on json.loads accepting bytes (it does not before 3.6).
            defer.returnValue(json.loads(body.decode("utf-8")))
        else:
            raise HttpResponseException(response.code, response.phrase, body)
Пример #58
0
    def on_REPLICATE(self, cmd):
        """Handle a REPLICATE command by subscribing to the requested stream.

        A stream name of "ALL" subscribes to every stream in
        ``self.streamer.streams_by_name``; any other name subscribes to just
        that stream.  ``cmd.token`` is passed along to each subscription.

        Returns:
            the result of ``subscribe_to_stream`` for a single stream, or a
            yieldable Deferred gathering all subscriptions for "ALL".
        """
        requested = cmd.stream_name
        from_token = cmd.token

        if requested != "ALL":
            return self.subscribe_to_stream(requested, from_token)

        # Subscribe to all streams we're publishing to.
        subscriptions = []
        for name in iterkeys(self.streamer.streams_by_name):
            subscriptions.append(
                run_in_background(self.subscribe_to_stream, name, from_token)
            )

        gathered = defer.gatherResults(subscriptions, consumeErrors=True)
        return make_deferred_yieldable(gathered)
Пример #59
0
 def store_keys(self, server_name, from_server, verify_keys):
     """Persist a set of verify keys for a server.

     Args:
         server_name(str): The name of the server the keys are for.
         from_server(str): The server the keys were downloaded from.
         verify_keys(dict): A mapping of key_id to VerifyKey.
     Returns:
         A deferred that completes when the keys are stored.
     """
     # TODO(markjh): Store whether the keys have expired.
     # NOTE(review): from_server is not used; server_name is passed twice to
     # store_server_verify_key -- confirm whether that is intended.
     writes = [
         run_in_background(
             self.store.store_server_verify_key,
             server_name, server_name, key.time_added, key
         )
         for _key_id, key in verify_keys.items()
     ]
     gathered = defer.gatherResults(writes, consumeErrors=True)
     # failures from gatherResults are passed through unwrapFirstError
     return logcontext.make_deferred_yieldable(
         gathered.addErrback(unwrapFirstError)
     )
Пример #60
0
    def store_file(self, source, file_info):
        """Copy `source` into the file opened by `store_into_file`, then run
        that context's finish callback (presumably what pushes the file to any
        other configured storage providers -- confirm in `store_into_file`).

        Args:
            source: A file like object that should be written
            file_info (FileInfo): Info about the file to store

        Returns:
            Deferred[str]: the file path written to in the primary media store
        """

        with self.store_into_file(file_info) as (dest, written_path, on_written):
            # do the blocking copy on a thread rather than on the reactor
            copy_job = threads.deferToThread(
                _write_file_synchronously, source, dest,
            )
            yield make_deferred_yieldable(copy_job)
            yield on_written()

        defer.returnValue(written_path)