Exemple #1
0
    def add_to_queue(self, room_id, events_and_contexts, backfilled):
        """Queue a batch of events for persistence in the given room.

        If the item at the tail of the room's queue has the same `backfilled`
        flag, the batch is merged into that item; otherwise a new queue item
        (with its own deferred) is appended.

        NB: due to the normal usage pattern of this method, it does *not*
        follow the synapse logcontext rules, and leaves the logcontext in
        place whether or not the returned deferred is ready.

        Args:
            room_id (str):
            events_and_contexts (list[(EventBase, EventContext)]):
            backfilled (bool):

        Returns:
            defer.Deferred: a deferred which will resolve once the events are
                persisted. Runs its callbacks *without* a logcontext.
        """
        room_queue = self._event_persist_queues.setdefault(room_id, deque())

        # Only the tail item is a candidate for merging, and only when its
        # `backfilled` setting matches that of the new batch.
        tail = room_queue[-1] if room_queue else None
        if tail is not None and tail.backfilled == backfilled:
            tail.events_and_contexts.extend(events_and_contexts)
            return tail.deferred.observe()

        # No compatible item: start a new one with a fresh observable deferred.
        observable = ObservableDeferred(defer.Deferred(), consumeErrors=True)
        new_item = self._EventPersistQueueItem(
            events_and_contexts=events_and_contexts,
            backfilled=backfilled,
            deferred=observable,
        )
        room_queue.append(new_item)

        return observable.observe()
    def fetch_or_execute(self, txn_key, fn, *args, **kwargs):
        """Return the response recorded for a transaction, running `fn` to
        produce (and record) it if the transaction has not been seen yet.

        Args:
            txn_key (str): A key to ensure idempotency should fetch_or_execute be
            called again at a later point in time.
            fn (function): A function which returns a tuple of
            (response_code, response_dict).
            *args: Arguments to pass to fn.
            **kwargs: Keyword arguments to pass to fn.
        Returns:
            Deferred which resolves to a tuple of (response_code, response_dict).
        """
        if txn_key in self.transactions:
            # Already known: hand out another observer on the stored result.
            known = self.transactions[txn_key][0]
            return make_deferred_yieldable(known.observe())

        # First time we see this key: run the function and remember the
        # observable together with the time it was started.
        in_flight = run_in_background(fn, *args, **kwargs)
        observable = ObservableDeferred(in_flight)
        self.transactions[txn_key] = (observable, self.clock.time_msec())

        def _forget_failed_txn(_failure):
            # A failed request is dropped from the map so that transient
            # errors (rate-limiting and the like) are not cached. The error
            # is deliberately not propagated any further, as the observers
            # are expected to have reported it.
            self.transactions.pop(txn_key, None)

        in_flight.addErrback(_forget_failed_txn)

        return make_deferred_yieldable(observable.observe())
Exemple #3
0
    def notify(
        self,
        stream_key: str,
        stream_id: Union[int, RoomStreamToken],
        time_now_ms: int,
    ):
        """Wake up every listener of this user stream for a new event.

        Advances the stored token, records when the notification happened,
        and fires the deferred that current listeners are observing.

        Args:
            stream_key: The stream the event came from.
            stream_id: The new id for the stream the event came from.
            time_now_ms: The current time in milliseconds.
        """
        advanced_token = self.current_token.copy_and_advance(stream_key, stream_id)
        self.current_token = advanced_token
        self.last_notified_token = advanced_token
        self.last_notified_ms = time_now_ms

        # Keep hold of the deferred the existing listeners are waiting on; it
        # is replaced by a fresh one below before being fired.
        waiting_deferred = self.notify_deferred

        trace_fields = {
            "notify": self.user_id,
            "stream": stream_key,
            "stream_id": stream_id,
            "listeners": self.count_listeners(),
        }
        log_kv(trace_fields)

        users_woken_by_stream_counter.labels(stream_key).inc()

        with PreserveLoggingContext():
            self.notify_deferred = ObservableDeferred(defer.Deferred())
            waiting_deferred.callback(self.current_token)
    async def _async_render_GET(self, request: SynapseRequest) -> None:
        """Serve a URL preview request.

        Authenticates the requester, reads the `url` (required) and `ts`
        (optional, defaults to the current time) query parameters, rejects
        URLs matching an entry of `url_preview_url_blacklist`, and otherwise
        responds with the preview data, sharing in-flight and completed
        results through the in-memory cache.

        Args:
            request: The incoming request.

        Raises:
            SynapseError: with code 403 if the URL matches a blacklist entry.
        """
        # XXX: if get_user_by_req fails, what should we do in an async render?
        requester = await self.auth.get_user_by_req(request)
        url = parse_string(request, "url", required=True)
        ts = parse_integer(request, "ts")
        if ts is None:
            ts = self.clock.time_msec()

        # XXX: we could move this into _do_preview if we wanted.
        url_tuple = urlparse.urlsplit(url)
        for entry in self.url_preview_url_blacklist:
            # An entry blocks the URL only if *every* attribute pattern in it
            # matches the corresponding component of the URL.
            match = True
            for attrib, pattern in entry.items():
                value = getattr(url_tuple, attrib)
                logger.debug(
                    "Matching attrib '%s' with value '%s' against pattern '%s'",
                    attrib,
                    value,
                    pattern,
                )

                if value is None:
                    match = False
                    continue

                # Not every attribute of the split URL is a string (the port
                # is parsed as an int), but both matchers below require a
                # string, so normalise before matching.
                value_str = str(value)

                if pattern.startswith("^"):
                    # Patterns starting with "^" are regular expressions...
                    if not re.match(pattern, value_str):
                        match = False
                        continue
                else:
                    # ...anything else is a shell-style glob.
                    if not fnmatch.fnmatch(value_str, pattern):
                        match = False
                        continue
            if match:
                logger.warning("URL %s blocked by url_blacklist entry %s", url,
                               entry)
                raise SynapseError(
                    403, "URL blocked by url pattern blacklist entry",
                    Codes.UNKNOWN)

        # the in-memory cache:
        # * ensures that only one request is active at a time
        # * takes load off the DB for the thundering herds
        # * also caches any failures (unlike the DB) so we don't keep
        #    requesting the same endpoint

        observable = self._cache.get(url)

        if not observable:
            download = run_in_background(self._do_preview, url, requester.user,
                                         ts)
            observable = ObservableDeferred(download, consumeErrors=True)
            self._cache[url] = observable
        else:
            logger.info("Returning cached response")

        og = await make_deferred_yieldable(observable.observe())
        respond_with_json_bytes(request, 200, og, send_cors=True)
    def set(self, key: T, deferred: defer.Deferred) -> defer.Deferred:
        """Register `deferred` as the pending result for `key`.

        *deferred* should run its callbacks in the sentinel logcontext (ie,
        you should wrap normal synapse deferreds with
        synapse.logging.context.run_in_background).

        Can return either a new Deferred (which also doesn't follow the synapse
        logcontext rules), or, if *deferred* was already complete, the actual
        result. You will probably want to make_deferred_yieldable the result.

        Args:
            key: key to get/set in the cache
            deferred: The deferred which resolves to the result.

        Returns:
            A new deferred which resolves to the actual result.
        """
        observable = ObservableDeferred(deferred, consumeErrors=True)
        self.pending_result_cache[key] = observable

        def _expire_entry(outcome):
            # With no timeout configured the entry goes away as soon as the
            # result is in; otherwise its removal is scheduled for later.
            if not self.timeout_sec:
                self.pending_result_cache.pop(key, None)
            else:
                self.clock.call_later(
                    self.timeout_sec, self.pending_result_cache.pop, key, None
                )
            # pass the result (or failure) through untouched
            return outcome

        observable.addBoth(_expire_entry)
        return observable.observe()
    def test_succeed(self):
        """Both observers receive the success value, the first one first."""
        source = Deferred()
        observable = ObservableDeferred(source)

        first = observable.observe()
        second = observable.observe()

        for observer in (first, second):
            self.assertFalse(observer.called)

        # when the first observer fires, the second must not have fired yet
        def second_still_pending(res):
            self.assertFalse(second.called)
            return res

        first.addBoth(second_still_pending)

        # slot i collects the value delivered to observer i
        seen = [None] * 2

        def record(res, slot):
            seen[slot] = res
            return res

        first.addCallback(record, 0)
        second.addCallback(record, 1)

        source.callback(123)
        self.assertEqual(seen[0], 123, "observer 1 callback result")
        self.assertEqual(seen[1], 123, "observer 2 callback result")
    def test_failure(self):
        """Both observers receive the failure, the first one first."""
        source = Deferred()
        observable = ObservableDeferred(source, consumeErrors=True)

        first = observable.observe()
        second = observable.observe()

        for observer in (first, second):
            self.assertFalse(observer.called)

        # when the first observer fires, the second must not have fired yet
        def second_still_pending(res):
            self.assertFalse(second.called)
            return res

        first.addBoth(second_still_pending)

        # slot i collects the failure delivered to observer i
        seen = [None] * 2

        def record(res, slot):
            seen[slot] = res
            # returning None marks the failure as handled
            return None

        first.addErrback(record, 0)
        second.addErrback(record, 1)

        try:
            raise Exception("gah!")
        except Exception as exc:
            source.errback(exc)
        self.assertEqual(str(seen[0].value), "gah!",
                         "observer 1 errback result")
        self.assertEqual(str(seen[1].value), "gah!",
                         "observer 2 errback result")
Exemple #8
0
    def set(self, key, deferred):
        """Register `deferred` as the pending result for `key`.

        *deferred* should run its callbacks in the sentinel logcontext (ie,
        you should wrap normal synapse deferreds with
        logcontext.run_in_background).

        Can return either a new Deferred (which also doesn't follow the synapse
        logcontext rules), or, if *deferred* was already complete, the actual
        result. You will probably want to make_deferred_yieldable the result.

        Args:
            key (hashable):
            deferred (twisted.internet.defer.Deferred[T]):

        Returns:
            twisted.internet.defer.Deferred[T]|T: a new deferred, or the actual
                result.
        """
        observable = ObservableDeferred(deferred, consumeErrors=True)
        self.pending_result_cache[key] = observable

        def _expire_entry(outcome):
            # Without a timeout the entry is dropped straight away; with one,
            # the removal is scheduled on the clock instead.
            if not self.timeout_sec:
                self.pending_result_cache.pop(key, None)
            else:
                self.clock.call_later(
                    self.timeout_sec,
                    self.pending_result_cache.pop, key, None,
                )
            # hand the result (or failure) on unchanged
            return outcome

        observable.addBoth(_expire_entry)
        return observable.observe()
    def test_cancellation(self):
        """Test that cancelling an observer does not affect other observers."""
        source: "Deferred[int]" = Deferred()
        observable = ObservableDeferred(source, consumeErrors=True)

        first = observable.observe()
        second = observable.observe()
        third = observable.observe()

        # nothing has fired yet
        for observer in (first, second, third):
            self.assertFalse(observer.called)

        # cancelling the middle observer fails it alone
        second.cancel()
        self.assertFalse(first.called)
        self.failureResultOf(second, CancelledError)
        self.assertFalse(third.called)

        # the remaining observers still see the result
        source.callback(123)
        self.assertEqual(first.result, 123, "observer 1 callback result")
        self.assertEqual(third.result, 123, "observer 3 callback result")

        # and so does an observer created after the fact
        late = observable.observe()
        self.assertEqual(late.result, 123, "observer 4 callback result")
Exemple #10
0
    def set(self, key, value, callback=None):
        """Record a pending (deferred) value for `key`.

        The deferred is wrapped in an ObservableDeferred and stored in
        `_pending_deferred_cache`. Once it completes successfully, and provided
        our entry is still the one stored for `key`, the result is moved into
        `self.cache`. On failure, or if the entry has been replaced in the
        meantime, the entry's `invalidate()` is called instead.

        Args:
            key: the cache key.
            value (defer.Deferred): the deferred that will produce the value.
            callback: optional callable stored on the cache entry and handed
                to `self.cache.set` along with the result
                (presumably an invalidation callback - confirm against
                CacheEntry).

        Returns:
            ObservableDeferred: the observable wrapping `value`.

        Raises:
            TypeError: if `value` is not a Deferred.
        """
        if not isinstance(value, defer.Deferred):
            raise TypeError("not a Deferred")

        callbacks = [callback] if callback else []
        self.check_thread()
        observable = ObservableDeferred(value, consumeErrors=True)
        # take our own observer now; cb/eb below are attached to it rather
        # than to the observable that is returned to the caller.
        observer = observable.observe()
        entry = CacheEntry(deferred=observable, callbacks=callbacks)

        # any entry already pending for this key is superseded: remove it and
        # invalidate it before installing ours.
        existing_entry = self._pending_deferred_cache.pop(key, None)
        if existing_entry:
            existing_entry.invalidate()

        self._pending_deferred_cache[key] = entry

        def compare_and_pop():
            """Check if our entry is still the one in _pending_deferred_cache, and
            if so, pop it.

            Returns true if the entries matched.
            """
            existing_entry = self._pending_deferred_cache.pop(key, None)
            if existing_entry is entry:
                return True

            # oops, the _pending_deferred_cache has been updated since
            # we started our query, so we are out of date.
            #
            # Better put back whatever we took out. (We do it this way
            # round, rather than peeking into the _pending_deferred_cache
            # and then removing on a match, to make the common case faster)
            if existing_entry is not None:
                self._pending_deferred_cache[key] = existing_entry

            return False

        def cb(result):
            """Success path: promote the result to the real cache if our entry
            is still current, otherwise invalidate the entry."""
            if compare_and_pop():
                self.cache.set(key, result, entry.callbacks)
            else:
                # we're not going to put this entry into the cache, so need
                # to make sure that the invalidation callbacks are called.
                # That was probably done when _pending_deferred_cache was
                # updated, but it's possible that `set` was called without
                # `invalidate` being previously called, in which case it may
                # not have been. Either way, let's double-check now.
                entry.invalidate()

        def eb(_fail):
            """Failure path: drop our entry (if still current) and invalidate it."""
            compare_and_pop()
            entry.invalidate()

        # once the deferred completes, we can move the entry from the
        # _pending_deferred_cache to the real cache.
        #
        observer.addCallbacks(cb, eb)
        return observable
Exemple #11
0
    def __init__(self, hs):
        """Set up the notifier's indexes, collaborators and metrics gauges.

        Args:
            hs: the homeserver object that supplies the event sources,
                datastore, clock and handlers used here.
        """
        self.user_to_user_stream = {}
        self.room_to_user_streams = {}

        self.hs = hs
        self.event_sources = hs.get_event_sources()
        self.store = hs.get_datastore()
        self.pending_new_room_events = []

        self.replication_callbacks = []

        self.clock = hs.get_clock()
        self.appservice_handler = hs.get_application_service_handler()

        # a federation sender is only fetched when this process should send
        # federation traffic
        self.federation_sender = (
            hs.get_federation_sender() if hs.should_send_federation() else None
        )

        self.state_handler = hs.get_state_handler()

        self.clock.looping_call(
            self.remove_expired_streams, self.UNUSED_STREAM_EXPIRY_MS
        )

        self.replication_deferred = ObservableDeferred(defer.Deferred())

        # This is not a very cheap test to perform, but it's only executed
        # when rendering the metrics page, which is likely once per minute at
        # most when scraping it.
        def count_listeners():
            # de-duplicate the streams reachable through either index
            distinct_streams = set()
            for room_streams in self.room_to_user_streams.values():
                distinct_streams |= room_streams
            distinct_streams.update(self.user_to_user_stream.values())

            return sum(
                user_stream.count_listeners() for user_stream in distinct_streams
            )

        def count_rooms():
            return count(bool, self.room_to_user_streams.values())

        def count_users():
            return len(self.user_to_user_stream)

        LaterGauge("synapse_notifier_listeners", "", [], count_listeners)
        LaterGauge("synapse_notifier_rooms", "", [], count_rooms)
        LaterGauge("synapse_notifier_users", "", [], count_users)
Exemple #12
0
    def notify_replication(self):
        """Notify any replication listeners that there's a new event."""
        with PreserveLoggingContext():
            # Install a fresh deferred for future listeners, then fire the
            # one that the current listeners are observing.
            fired, self.replication_deferred = (
                self.replication_deferred,
                ObservableDeferred(defer.Deferred()),
            )
            fired.callback(None)

            # the callbacks may well outlast the current request, so we run
            # them in the sentinel logcontext.
            #
            # (ideally it would be up to the callbacks to know if they were
            # starting off background processes and drop the logcontext
            # accordingly, but that requires more changes)
            for replication_cb in self.replication_callbacks:
                replication_cb()
Exemple #13
0
    def __init__(self, user_id, rooms, current_token, time_now_ms):
        """Create the stream state for one user.

        Args:
            user_id: the user this stream belongs to.
            rooms: iterable of rooms the stream listens in (copied into a set).
            current_token: the stream token to start from.
            time_now_ms: the current time in milliseconds.
        """
        self.user_id = user_id
        self.rooms = set(rooms)

        # `last_notified_token` is the last token for which any stream holding
        # an earlier token should be woken up; it is refreshed every time we
        # get poked. It starts out equal to the current token: for streams
        # with an older token we cannot tell whether they need waking, so we
        # just wake them up.
        self.current_token = self.last_notified_token = current_token
        self.last_notified_ms = time_now_ms

        with PreserveLoggingContext():
            self.notify_deferred = ObservableDeferred(defer.Deferred())
Exemple #14
0
    async def add_to_queue(
        self,
        room_id: str,
        events_and_contexts: Iterable[Tuple[EventBase, EventContext]],
        backfilled: bool,
    ) -> _PersistResult:
        """Add events to the queue, with the given persist_event options.

        If we are not already processing events in this room, starts off a
        background process to do so, calling the per_item_callback for each
        item.

        Args:
            room_id (str):
            events_and_contexts (list[(EventBase, EventContext)]):
            backfilled (bool):

        Returns:
            the result returned by the `_per_item_callback` passed to
            `__init__`.
        """
        queue = self._event_persist_queues.setdefault(room_id, deque())

        # The tail item can be reused only if it has the same `backfilled`
        # setting; otherwise a new, empty queue item is appended.
        target = queue[-1] if queue else None
        if target is None or target.backfilled != backfilled:
            target = _EventPersistQueueItem(
                events_and_contexts=[],
                backfilled=backfilled,
                deferred=ObservableDeferred(defer.Deferred(), consumeErrors=True),
            )
            queue.append(target)

        # our events go onto whichever item was selected
        target.events_and_contexts.extend(events_and_contexts)

        # record the active opentracing span (if any) on the item so that we
        # get a link back
        active = opentracing.active_span()
        if active:
            target.parent_opentracing_span_contexts.append(active.context)

        # make sure something is processing this room's queue
        self._handle_queue(room_id)

        # block until the item we joined has been dealt with
        persist_result = await make_deferred_yieldable(target.deferred.observe())

        # emit an (empty) span which links to the persist trace
        follows = (target.opentracing_span_context,)
        with opentracing.start_active_span_follows_from(
            "persist_event_batch_complete", follows
        ):
            pass

        return persist_result
Exemple #15
0
        def wrapped(*args, **kwargs):
            """Cache-aware wrapper around `self.function_to_call`.

            Looks the call up in `cache`; on a miss (signalled by KeyError) it
            invokes the wrapped function, stores the resulting deferred in the
            cache and returns an observer on it.

            `self`, `obj`, `cache` and `get_cache_key` come from the enclosing
            scope, which is not visible here.

            Keyword Args:
                on_invalidate: optional callback, removed from kwargs and
                    passed to the cache as `callback` on both get and set.

            Returns:
                A deferred made yieldable via make_deferred_yieldable when the
                cached/observed value is a Deferred, otherwise the value itself.
            """
            # If we're passed a cache_context then we'll want to call its invalidate()
            # whenever we are invalidated
            invalidate_callback = kwargs.pop("on_invalidate", None)

            cache_key = get_cache_key(args, kwargs)

            # Add our own `cache_context` to argument list if the wrapped function
            # has asked for one
            if self.add_cache_context:
                kwargs["cache_context"] = _CacheContext(cache, cache_key)

            try:
                cached_result_d = cache.get(cache_key, callback=invalidate_callback)

                # the cache may hold either an ObservableDeferred (observe it)
                # or some other value, which is used as-is
                if isinstance(cached_result_d, ObservableDeferred):
                    observer = cached_result_d.observe()
                else:
                    observer = cached_result_d

            except KeyError:
                # cache miss: call the underlying function
                ret = defer.maybeDeferred(
                    logcontext.preserve_fn(self.function_to_call),
                    obj, *args, **kwargs
                )

                def onErr(f):
                    # failures are not kept in the cache; the failure itself
                    # is passed on unchanged
                    cache.invalidate(cache_key)
                    return f

                ret.addErrback(onErr)

                # If our cache_key is a string on py2, try to convert to ascii
                # to save a bit of space in large caches. Py3 does this
                # internally automatically.
                if six.PY2 and isinstance(cache_key, string_types):
                    cache_key = to_ascii(cache_key)

                result_d = ObservableDeferred(ret, consumeErrors=True)
                cache.set(cache_key, result_d, callback=invalidate_callback)
                observer = result_d.observe()

            if isinstance(observer, defer.Deferred):
                return logcontext.make_deferred_yieldable(observer)
            else:
                return observer
Exemple #16
0
        def wrapped(*args, **kwargs):
            """Cache-aware wrapper around `self.function_to_call`.

            Looks the call up in `cache`; on a miss (signalled by KeyError) it
            invokes the wrapped function, stores the resulting deferred in the
            cache and returns an observer on it.

            `self`, `obj`, `cache` and `get_cache_key` come from the enclosing
            scope, which is not visible here.

            Keyword Args:
                on_invalidate: optional callback, removed from kwargs and
                    passed to the cache as `callback` on both get and set.

            Returns:
                A deferred made yieldable via make_deferred_yieldable when the
                cached/observed value is a Deferred, otherwise the value itself.
            """
            # If we're passed a cache_context then we'll want to call its invalidate()
            # whenever we are invalidated
            invalidate_callback = kwargs.pop("on_invalidate", None)

            cache_key = get_cache_key(args, kwargs)

            # Add our own `cache_context` to argument list if the wrapped function
            # has asked for one
            if self.add_cache_context:
                kwargs["cache_context"] = _CacheContext(cache, cache_key)

            try:
                cached_result_d = cache.get(cache_key,
                                            callback=invalidate_callback)

                # the cache may hold either an ObservableDeferred (observe it)
                # or some other value, which is used as-is
                if isinstance(cached_result_d, ObservableDeferred):
                    observer = cached_result_d.observe()
                else:
                    observer = cached_result_d

            except KeyError:
                # cache miss: call the underlying function
                ret = defer.maybeDeferred(
                    logcontext.preserve_fn(self.function_to_call), obj, *args,
                    **kwargs)

                def onErr(f):
                    # failures are not kept in the cache; the failure itself
                    # is passed on unchanged
                    cache.invalidate(cache_key)
                    return f

                ret.addErrback(onErr)

                # If our cache_key is a string on py2, try to convert to ascii
                # to save a bit of space in large caches. Py3 does this
                # internally automatically.
                if six.PY2 and isinstance(cache_key, string_types):
                    cache_key = to_ascii(cache_key)

                result_d = ObservableDeferred(ret, consumeErrors=True)
                cache.set(cache_key, result_d, callback=invalidate_callback)
                observer = result_d.observe()

            if isinstance(observer, defer.Deferred):
                return logcontext.make_deferred_yieldable(observer)
            else:
                return observer
Exemple #17
0
    def set(self, time_now_ms, key, deferred):
        """Store `deferred` as the pending result for `key`.

        Args:
            time_now_ms: current time, passed to `rotate` before storing.
            key: the cache key.
            deferred: the deferred producing the result.

        Returns:
            An observer on the stored deferred.
        """
        self.rotate(time_now_ms)

        observable = ObservableDeferred(deferred)
        self.pending_result_cache[key] = observable

        def _promote(outcome):
            # Once the deferred has completed, the entry moves from the
            # pending cache into the first generation of the result cache,
            # from where the cache rotation will eventually expire it.
            self.next_result_cache[key] = observable
            self.pending_result_cache.pop(key, None)
            return outcome

        observable.addBoth(_promote)

        return observable.observe()
Exemple #18
0
    def notify(self, stream_key, stream_id, time_now_ms):
        """Wake up every listener of this user stream for a new event.

        Args:
            stream_key(str): The stream the event came from.
            stream_id(str): The new id for the stream the event came from.
            time_now_ms(int): The current time in milliseconds.
        """
        advanced_token = self.current_token.copy_and_advance(stream_key, stream_id)
        self.current_token = advanced_token
        self.last_notified_token = advanced_token
        self.last_notified_ms = time_now_ms

        # the deferred current listeners are waiting on; replaced below
        # before it is fired
        waiting_deferred = self.notify_deferred

        users_woken_by_stream_counter.labels(stream_key).inc()

        with PreserveLoggingContext():
            self.notify_deferred = ObservableDeferred(defer.Deferred())
            waiting_deferred.callback(self.current_token)
    def _set(
        self,
        context: ResponseCacheContext[KV],
        deferred: "defer.Deferred[RV]",
        opentracing_span_context: "Optional[opentracing.SpanContext]",
    ) -> ResponseCacheEntry:
        """Store `deferred` in the result cache under the context's cache key.

        *deferred* should run its callbacks in the sentinel logcontext (ie,
        you should wrap normal synapse deferreds with
        synapse.logging.context.run_in_background).

        Args:
            context: Information about the cache miss
            deferred: The deferred which resolves to the result.
            opentracing_span_context: An opentracing span wrapping the calculation

        Returns:
            The cache entry object.
        """
        cache_key = context.cache_key
        observable = ObservableDeferred(deferred, consumeErrors=True)
        cache_entry = ResponseCacheEntry(observable, opentracing_span_context)
        self._result_cache[cache_key] = cache_entry

        def _on_result(outcome: RV) -> RV:
            keep_for_a_while = self.timeout_sec and context.should_cache
            if keep_for_a_while:
                # a non-zero timeout is configured and the callback has not
                # cleared the should_cache bit: keep the entry for now and
                # schedule its removal.
                self.clock.call_later(
                    self.timeout_sec, self._result_cache.pop, cache_key, None
                )
            else:
                # otherwise, remove the result immediately.
                self._result_cache.pop(cache_key, None)
            return outcome

        # This must happen *after* the entry has been put into _result_cache:
        # if the result is already complete the callback runs immediately, and
        # doing things the other way round would leave a stuck cache entry.
        observable.addBoth(_on_result)
        return cache_entry
Exemple #20
0
    def _set(self, context: ResponseCacheContext[KV],
             deferred: defer.Deferred) -> defer.Deferred:
        """Store `deferred` as the pending result for the context's cache key.

        *deferred* should run its callbacks in the sentinel logcontext (ie,
        you should wrap normal synapse deferreds with
        synapse.logging.context.run_in_background).

        Returns a new Deferred (which also doesn't follow the synapse logcontext rules).
        You will probably want to make_deferred_yieldable the result.

        Args:
            context: Information about the cache miss
            deferred: The deferred which resolves to the result.

        Returns:
            A new deferred which resolves to the actual result.
        """
        cache_key = context.cache_key
        observable = ObservableDeferred(deferred, consumeErrors=True)
        self.pending_result_cache[cache_key] = observable

        def _on_result(outcome):
            keep_for_a_while = self.timeout_sec and context.should_cache
            if keep_for_a_while:
                # a non-zero timeout is configured and the callback has not
                # cleared the should_cache bit: keep the entry for now and
                # schedule its removal.
                self.clock.call_later(
                    self.timeout_sec, self.pending_result_cache.pop, cache_key, None
                )
            else:
                # otherwise, remove the result immediately.
                self.pending_result_cache.pop(cache_key, None)
            return outcome

        # This must happen *after* the entry has been put into
        # pending_result_cache: if the result is already complete the callback
        # runs immediately, and doing things the other way round would leave a
        # stuck cache entry.
        observable.addBoth(_on_result)
        return observable.observe()
Exemple #21
0
    def __init__(self, hs):
        """Set up the notifier's indexes, collaborators and metrics gauges.

        Args:
            hs: the homeserver object that supplies the event sources,
                datastore, clock and handlers used here.
        """
        self.user_to_user_stream = {}
        self.room_to_user_streams = {}

        self.hs = hs
        self.event_sources = hs.get_event_sources()
        self.store = hs.get_datastore()
        self.pending_new_room_events = []

        self.replication_callbacks = []

        self.clock = hs.get_clock()
        self.appservice_handler = hs.get_application_service_handler()

        # a federation sender is only fetched when this process should send
        # federation traffic
        self.federation_sender = (
            hs.get_federation_sender() if hs.should_send_federation() else None
        )

        self.state_handler = hs.get_state_handler()

        self.clock.looping_call(
            self.remove_expired_streams, self.UNUSED_STREAM_EXPIRY_MS
        )

        self.replication_deferred = ObservableDeferred(defer.Deferred())

        # This is not a very cheap test to perform, but it's only executed
        # when rendering the metrics page, which is likely once per minute at
        # most when scraping it.
        def count_listeners():
            # de-duplicate the streams reachable through either index; both
            # are iterated over a list() snapshot of the dict values
            distinct_streams = set()
            for room_streams in list(self.room_to_user_streams.values()):
                distinct_streams |= room_streams
            distinct_streams.update(list(self.user_to_user_stream.values()))

            return sum(
                user_stream.count_listeners() for user_stream in distinct_streams
            )

        def count_rooms():
            return count(bool, list(self.room_to_user_streams.values()))

        def count_users():
            return len(self.user_to_user_stream)

        LaterGauge("synapse_notifier_listeners", "", [], count_listeners)
        LaterGauge("synapse_notifier_rooms", "", [], count_rooms)
        LaterGauge("synapse_notifier_users", "", [], count_users)
Exemple #22
0
    def __init__(self, user_id, rooms, current_token, time_now_ms):
        """Create the stream state for one user.

        Args:
            user_id: the user this stream belongs to.
            rooms: iterable of rooms the stream listens in (copied into a set).
            current_token: the stream token to start from.
            time_now_ms: the current time in milliseconds.
        """
        self.user_id = user_id
        self.rooms = set(rooms)

        # `last_notified_token` is the last token for which any stream holding
        # an earlier token should be woken up; it is refreshed every time we
        # get poked. It starts out equal to the current token: for streams
        # with an older token we cannot tell whether they need waking, so we
        # just wake them up.
        self.current_token = self.last_notified_token = current_token
        self.last_notified_ms = time_now_ms

        with PreserveLoggingContext():
            self.notify_deferred = ObservableDeferred(defer.Deferred())
Exemple #23
0
    def test_prefill(self):
        """A prefilled cache entry is returned without calling the function."""
        # single-element list so the nested method can mutate the counter
        callcount = [0]

        d = defer.succeed(123)

        class A(object):
            @cached()
            def func(self, key):
                callcount[0] += 1
                return d

        a = A()

        a.func.prefill(("foo", ), ObservableDeferred(d))

        # `assertEqual` rather than the `assertEquals` alias, which was
        # deprecated and then removed from unittest in Python 3.12.
        self.assertEqual(a.func("foo").result, d.result)
        self.assertEqual(callcount[0], 0)
Exemple #24
0
    def notify(self, stream_key, stream_id, time_now_ms):
        """Wake up every listener of this user stream for a new event.

        Args:
            stream_key(str): The stream the event came from.
            stream_id(str): The new id for the stream the event came from.
            time_now_ms(int): The current time in milliseconds.
        """
        advanced_token = self.current_token.copy_and_advance(stream_key, stream_id)
        self.current_token = advanced_token
        self.last_notified_token = advanced_token
        self.last_notified_ms = time_now_ms

        # the deferred current listeners are waiting on; replaced below
        # before it is fired
        waiting_deferred = self.notify_deferred

        users_woken_by_stream_counter.labels(stream_key).inc()

        with PreserveLoggingContext():
            self.notify_deferred = ObservableDeferred(defer.Deferred())
            waiting_deferred.callback(self.current_token)
Exemple #25
0
    def __init__(
        self,
        user_id: str,
        rooms: Collection[str],
        current_token: StreamToken,
        time_now_ms: int,
    ):
        """Create the stream state for one user.

        Args:
            user_id: the user this stream belongs to.
            rooms: the rooms the stream listens in (copied into a set).
            current_token: the stream token to start from.
            time_now_ms: the current time in milliseconds.
        """
        self.user_id = user_id
        self.rooms = set(rooms)

        # `last_notified_token` is the last token for which any stream holding
        # an earlier token should be woken up; it is refreshed every time we
        # get poked. It starts out equal to the current token: for streams
        # with an older token we cannot tell whether they need waking, so we
        # just wake them up.
        self.current_token = self.last_notified_token = current_token
        self.last_notified_ms = time_now_ms

        fresh_deferred = ObservableDeferred(defer.Deferred())
        self.notify_deferred: ObservableDeferred[StreamToken] = fresh_deferred
Exemple #26
0
class _NotifierUserStream:
    """This represents a user connected to the event stream.
    It tracks the most recent stream token for that user.
    At a given point a user may have a number of streams listening for
    events.

    This listener will also keep track of which rooms it is listening in
    so that it can remove itself from the indexes in the Notifier class.
    """

    def __init__(
        self,
        user_id: str,
        rooms: Collection[str],
        current_token: StreamToken,
        time_now_ms: int,
    ):
        """
        Args:
            user_id: Identifier stored on the stream as `user_id`.
            rooms: Rooms this stream listens in; copied into a set.
            current_token: Token the stream starts from.
            time_now_ms: Current time in milliseconds, recorded as the time
                of the last notification.
        """
        self.user_id = user_id
        self.rooms = set(rooms)
        self.current_token = current_token

        # The last token for which we should wake up any streams that have a
        # token that comes before it. This gets updated every time we get poked.
        # We start it at the current token since if we get any streams
        # that have a token from before we have no idea whether they should be
        # woken up or not, so lets just wake them up.
        self.last_notified_token = current_token
        self.last_notified_ms = time_now_ms

        with PreserveLoggingContext():
            self.notify_deferred = ObservableDeferred(defer.Deferred())

    def notify(self, stream_key: str, stream_id: int, time_now_ms: int) -> None:
        """Notify any listeners for this user of a new event from an
        event source.
        Args:
            stream_key: The stream the event came from.
            stream_id: The new id for the stream the event came from.
            time_now_ms: The current time in milliseconds.
        """
        self.current_token = self.current_token.copy_and_advance(stream_key, stream_id)
        self.last_notified_token = self.current_token
        self.last_notified_ms = time_now_ms
        notify_deferred = self.notify_deferred

        users_woken_by_stream_counter.labels(stream_key).inc()

        with PreserveLoggingContext():
            # Install the replacement deferred before firing the old one.
            self.notify_deferred = ObservableDeferred(defer.Deferred())
            notify_deferred.callback(self.current_token)

    def remove(self, notifier: "Notifier") -> None:
        """Remove this listener from all the indexes in the Notifier
        it knows about.

        Room entries that become empty are dropped from
        `notifier.room_to_user_streams` so that the index does not keep a
        dead key for every room that ever had a listener.

        Raises:
            KeyError: if `self.user_id` is not in `notifier.user_to_user_stream`.
        """
        for room in self.rooms:
            room_streams = notifier.room_to_user_streams.get(room)
            if room_streams is None:
                continue

            room_streams.discard(self)
            if not room_streams:
                notifier.room_to_user_streams.pop(room, None)

        notifier.user_to_user_stream.pop(self.user_id)

    def count_listeners(self) -> int:
        """Return the number of observers on the current notify deferred."""
        return len(self.notify_deferred.observers())

    def new_listener(self, token: StreamToken) -> _NotificationListener:
        """Returns a deferred that is resolved when there is a new token
        greater than the given token.

        Args:
            token: The token from which we are streaming from, i.e. we shouldn't
                notify for things that happened before this.
        """
        # Immediately wake up stream if something has already since happened
        # since their last token.
        if self.last_notified_token.is_after(token):
            return _NotificationListener(defer.succeed(self.current_token))
        else:
            return _NotificationListener(self.notify_deferred.observe())
    def _async_render_GET(self, request):
        """Handle a URL-preview GET: authenticate, apply the URL blacklist,
        then respond with the (possibly cached) preview.

        This is a generator that yields deferreds.

        Args:
            request: The incoming request; its `url` and optional `ts`
                query parameters are read.

        Raises:
            SynapseError: 403 if the URL matches an entry of
                `self.url_preview_url_blacklist`.
        """

        # XXX: if get_user_by_req fails, what should we do in an async render?
        requester = yield self.auth.get_user_by_req(request)
        url = parse_string(request, "url")
        if b"ts" in request.args:
            ts = parse_integer(request, "ts")
        else:
            ts = self.clock.time_msec()

        # XXX: we could move this into _do_preview if we wanted.
        url_tuple = urlparse.urlsplit(url)
        for entry in self.url_preview_url_blacklist:
            # An entry blocks the URL only if *every* attribute it names
            # matches, so we can stop at the first mismatch.
            match = True
            for attrib in entry:
                pattern = entry[attrib]
                value = getattr(url_tuple, attrib)
                logger.debug(
                    "Matching attrib '%s' with value '%s' against pattern '%s'",
                    attrib, value, pattern,
                )

                if value is None:
                    match = False
                    break

                # Patterns starting with '^' are regexes; anything else is
                # a glob.
                if pattern.startswith('^'):
                    if not re.match(pattern, value):
                        match = False
                        break
                elif not fnmatch.fnmatch(value, pattern):
                    match = False
                    break
            if match:
                logger.warning(
                    "URL %s blocked by url_blacklist entry %s", url, entry
                )
                raise SynapseError(
                    403, "URL blocked by url pattern blacklist entry",
                    Codes.UNKNOWN
                )

        # the in-memory cache:
        # * ensures that only one request is active at a time
        # * takes load off the DB for the thundering herds
        # * also caches any failures (unlike the DB) so we don't keep
        #    requesting the same endpoint

        observable = self._cache.get(url)

        if not observable:
            download = run_in_background(
                self._do_preview,
                url, requester.user, ts,
            )
            observable = ObservableDeferred(
                download,
                consumeErrors=True
            )
            self._cache[url] = observable
        else:
            logger.info("Returning cached response")

        og = yield make_deferred_yieldable(observable.observe())
        respond_with_json_bytes(request, 200, og, send_cors=True)
Exemple #28
0
class _NotifierUserStream(object):
    """This represents a user connected to the event stream.
    It tracks the most recent stream token for that user.
    At a given point a user may have a number of streams listening for
    events.

    This listener will also keep track of which rooms it is listening in
    so that it can remove itself from the indexes in the Notifier class.
    """

    def __init__(self, user_id, rooms, current_token, time_now_ms):
        """
        Args:
            user_id: Identifier stored on the stream as `user_id`.
            rooms: Iterable of rooms this stream listens in; copied into a set.
            current_token: Token the stream starts from.
            time_now_ms(int): Current time in milliseconds, recorded as the
                time of the last notification.
        """
        self.user_id = user_id
        self.rooms = set(rooms)
        self.current_token = current_token

        # The last token for which we should wake up any streams that have a
        # token that comes before it. This gets updated everytime we get poked.
        # We start it at the current token since if we get any streams
        # that have a token from before we have no idea whether they should be
        # woken up or not, so lets just wake them up.
        self.last_notified_token = current_token
        self.last_notified_ms = time_now_ms

        with PreserveLoggingContext():
            self.notify_deferred = ObservableDeferred(defer.Deferred())

    def notify(self, stream_key, stream_id, time_now_ms):
        """Notify any listeners for this user of a new event from an
        event source.
        Args:
            stream_key(str): The stream the event came from.
            stream_id(str): The new id for the stream the event came from.
            time_now_ms(int): The current time in milliseconds.
        """
        self.current_token = self.current_token.copy_and_advance(
            stream_key, stream_id
        )
        self.last_notified_token = self.current_token
        self.last_notified_ms = time_now_ms
        notify_deferred = self.notify_deferred

        users_woken_by_stream_counter.labels(stream_key).inc()

        with PreserveLoggingContext():
            # Install the replacement deferred before firing the old one.
            self.notify_deferred = ObservableDeferred(defer.Deferred())
            notify_deferred.callback(self.current_token)

    def remove(self, notifier):
        """ Remove this listener from all the indexes in the Notifier
        it knows about.

        Room entries that become empty are dropped from
        `notifier.room_to_user_streams` so that the index does not keep a
        dead key for every room that ever had a listener.

        Raises:
            KeyError: if `self.user_id` is not in `notifier.user_to_user_stream`.
        """

        for room in self.rooms:
            room_streams = notifier.room_to_user_streams.get(room)
            if room_streams is None:
                continue

            room_streams.discard(self)
            if not room_streams:
                notifier.room_to_user_streams.pop(room, None)

        notifier.user_to_user_stream.pop(self.user_id)

    def count_listeners(self):
        """Return the number of observers on the current notify deferred."""
        return len(self.notify_deferred.observers())

    def new_listener(self, token):
        """Returns a deferred that is resolved when there is a new token
        greater than the given token.

        Args:
            token: The token from which we are streaming from, i.e. we shouldn't
                notify for things that happened before this.
        """
        # Immediately wake up stream if something has already since happened
        # since their last token.
        if self.last_notified_token.is_after(token):
            return _NotificationListener(defer.succeed(self.current_token))
        else:
            return _NotificationListener(self.notify_deferred.observe())
Exemple #29
0
class Notifier(object):
    """ This class is responsible for notifying any listeners when there are
    new events available for it.

    Primarily used from the /events stream.
    """

    UNUSED_STREAM_EXPIRY_MS = 10 * 60 * 1000

    def __init__(self, hs):
        """
        Args:
            hs: The homeserver object; its `get_*` accessors and
                `should_send_federation()` are used to wire up dependencies.
        """
        # user_id -> user stream, and room -> set of user streams; both are
        # maintained by `_register_with_keys` / `_user_joined_room`.
        self.user_to_user_stream = {}
        self.room_to_user_streams = {}

        self.hs = hs
        self.event_sources = hs.get_event_sources()
        self.store = hs.get_datastore()
        # (room_stream_id, event, extra_users) tuples waiting for earlier
        # events to be persisted.
        self.pending_new_room_events = []

        self.replication_callbacks = []

        self.clock = hs.get_clock()
        self.appservice_handler = hs.get_application_service_handler()

        if hs.should_send_federation():
            self.federation_sender = hs.get_federation_sender()
        else:
            self.federation_sender = None

        self.state_handler = hs.get_state_handler()

        self.clock.looping_call(
            self.remove_expired_streams, self.UNUSED_STREAM_EXPIRY_MS
        )

        self.replication_deferred = ObservableDeferred(defer.Deferred())

        # This is not a very cheap test to perform, but it's only executed
        # when rendering the metrics page, which is likely once per minute at
        # most when scraping it.
        #
        # NOTE(review): the list() copies below presumably guard against the
        # dicts changing while metrics are being collected - confirm before
        # removing them.
        def count_listeners():
            all_user_streams = set()

            for x in list(self.room_to_user_streams.values()):
                all_user_streams |= x
            for x in list(self.user_to_user_stream.values()):
                all_user_streams.add(x)

            return sum(stream.count_listeners() for stream in all_user_streams)
        LaterGauge("synapse_notifier_listeners", "", [], count_listeners)

        LaterGauge(
            "synapse_notifier_rooms", "", [],
            lambda: count(bool, list(self.room_to_user_streams.values())),
        )
        LaterGauge(
            "synapse_notifier_users", "", [],
            lambda: len(self.user_to_user_stream),
        )

    def add_replication_callback(self, cb):
        """Add a callback that will be called when some new data is available.
        Callback is not given any arguments.
        """
        self.replication_callbacks.append(cb)

    def on_new_room_event(self, event, room_stream_id, max_room_stream_id,
                          extra_users=()):
        """ Used by handlers to inform the notifier something has happened
        in the room, room event wise.

        This triggers the notifier to wake up any listeners that are
        listening to the room, and any listeners for the users in the
        `extra_users` param.

        The events can be peristed out of order. The notifier will wait
        until all previous events have been persisted before notifying
        the client streams.

        Args:
            event: The new room event.
            room_stream_id: The stream id of `event`.
            max_room_stream_id: The highest stream id below which all events
                have been persisted.
            extra_users: Iterable of additional users whose streams should
                be woken. Defaults to an empty (immutable) tuple.
        """
        self.pending_new_room_events.append((
            room_stream_id, event, extra_users
        ))
        self._notify_pending_new_room_events(max_room_stream_id)

        self.notify_replication()

    def _notify_pending_new_room_events(self, max_room_stream_id):
        """Notify for the room events that were queued waiting for a previous
        event to be persisted.
        Args:
            max_room_stream_id(int): The highest stream_id below which all
                events have been persisted.
        """
        pending = self.pending_new_room_events
        self.pending_new_room_events = []
        for room_stream_id, event, extra_users in pending:
            if room_stream_id > max_room_stream_id:
                # Still waiting on an earlier event: put it back in the queue.
                self.pending_new_room_events.append((
                    room_stream_id, event, extra_users
                ))
            else:
                self._on_new_room_event(event, room_stream_id, extra_users)

    def _on_new_room_event(self, event, room_stream_id, extra_users=()):
        """Notify any user streams that are interested in this room event

        Args:
            event: The room event.
            room_stream_id: The stream id of `event`.
            extra_users: Iterable of additional users whose streams should
                be woken.
        """
        # poke any interested application service.
        run_as_background_process(
            "notify_app_services",
            self._notify_app_services, room_stream_id,
        )

        if self.federation_sender:
            self.federation_sender.notify_new_events(room_stream_id)

        if event.type == EventTypes.Member and event.membership == Membership.JOIN:
            self._user_joined_room(event.state_key, event.room_id)

        self.on_new_event(
            "room_key", room_stream_id,
            users=extra_users,
            rooms=[event.room_id],
        )

    @defer.inlineCallbacks
    def _notify_app_services(self, room_stream_id):
        """Tell the appservice handler about `room_stream_id`; any exception
        is logged rather than propagated.
        """
        try:
            yield self.appservice_handler.notify_interested_services(room_stream_id)
        except Exception:
            logger.exception("Error notifying application services of event")

    def on_new_event(self, stream_key, new_token, users=(), rooms=()):
        """ Used to inform listeners that something has happened event wise.

        Will wake up all listeners for the given users and rooms.

        Args:
            stream_key: The stream the event came from.
            new_token: The new position of that stream.
            users: Iterable of users whose streams should be woken; each is
                converted with `str()` before lookup.
            rooms: Iterable of rooms whose registered streams should be woken.
        """
        with PreserveLoggingContext():
            with Measure(self.clock, "on_new_event"):
                user_streams = set()

                for user in users:
                    user_stream = self.user_to_user_stream.get(str(user))
                    if user_stream is not None:
                        user_streams.add(user_stream)

                for room in rooms:
                    user_streams |= self.room_to_user_streams.get(room, set())

                time_now_ms = self.clock.time_msec()
                for user_stream in user_streams:
                    # One failing stream must not stop the others being woken.
                    try:
                        user_stream.notify(stream_key, new_token, time_now_ms)
                    except Exception:
                        logger.exception("Failed to notify listener")

                self.notify_replication()

    def on_new_replication_data(self):
        """Used to inform replication listeners that something has happend
        without waking up any of the normal user event streams"""
        self.notify_replication()

    @defer.inlineCallbacks
    def wait_for_events(self, user_id, timeout, callback, room_ids=None,
                        from_token=StreamToken.START):
        """Wait until the callback returns a non empty response or the
        timeout fires.

        Args:
            user_id: Key of the user stream to wait on; a stream is created
                and registered if none exists yet.
            timeout: How many milliseconds to wait; if falsy the callback is
                invoked once without waiting.
            callback: Called as `callback(prev_token, current_token)`; a
                truthy result ends the wait.
            room_ids: Rooms for a newly created stream; looked up from the
                store when None.
            from_token: The token to start waiting from.

        Returns:
            A deferred that resolves with the last value returned by the
            callback.
        """
        user_stream = self.user_to_user_stream.get(user_id)
        if user_stream is None:
            current_token = yield self.event_sources.get_current_token()
            if room_ids is None:
                room_ids = yield self.store.get_rooms_for_user(user_id)
            user_stream = _NotifierUserStream(
                user_id=user_id,
                rooms=room_ids,
                current_token=current_token,
                time_now_ms=self.clock.time_msec(),
            )
            self._register_with_keys(user_stream)

        result = None
        prev_token = from_token
        if timeout:
            end_time = self.clock.time_msec() + timeout

            while not result:
                try:
                    now = self.clock.time_msec()
                    if end_time <= now:
                        break

                    # Now we wait for the _NotifierUserStream to be told there
                    # is a new token.
                    listener = user_stream.new_listener(prev_token)
                    listener.deferred = timeout_deferred(
                        listener.deferred,
                        (end_time - now) / 1000.,
                        self.hs.get_reactor(),
                    )
                    with PreserveLoggingContext():
                        yield listener.deferred

                    current_token = user_stream.current_token

                    result = yield callback(prev_token, current_token)
                    if result:
                        break

                    # Update the prev_token to the current_token since nothing
                    # has happened between the old prev_token and the current_token
                    prev_token = current_token
                except defer.TimeoutError:
                    break
                except defer.CancelledError:
                    break

        if result is None:
            # This happened if there was no timeout or if the timeout had
            # already expired.
            current_token = user_stream.current_token
            result = yield callback(prev_token, current_token)

        defer.returnValue(result)

    @defer.inlineCallbacks
    def get_events_for(self, user, pagination_config, timeout,
                       only_keys=None,
                       is_guest=False, explicit_room_id=None):
        """ For the given user and rooms, return any new events for them. If
        there are no new events wait for up to `timeout` milliseconds for any
        new events to happen before returning.

        If `only_keys` is not None, events from keys will be sent down.

        If explicit_room_id is not set, the user's joined rooms will be polled
        for events.
        If explicit_room_id is set, that room will be polled for events only if
        it is world readable or the user has joined the room.
        """
        from_token = pagination_config.from_token
        if not from_token:
            from_token = yield self.event_sources.get_current_token()

        limit = pagination_config.limit

        room_ids, is_joined = yield self._get_room_ids(user, explicit_room_id)
        is_peeking = not is_joined

        @defer.inlineCallbacks
        def check_for_updates(before_token, after_token):
            # Callback for wait_for_events: gather events from every source
            # whose key moved between the two tokens.
            if not after_token.is_after(before_token):
                defer.returnValue(EventStreamResult([], (from_token, from_token)))

            events = []
            end_token = from_token

            for name, source in self.event_sources.sources.items():
                keyname = "%s_key" % name
                before_id = getattr(before_token, keyname)
                after_id = getattr(after_token, keyname)
                if before_id == after_id:
                    continue
                if only_keys and name not in only_keys:
                    continue

                new_events, new_key = yield source.get_new_events(
                    user=user,
                    from_key=getattr(from_token, keyname),
                    limit=limit,
                    is_guest=is_peeking,
                    room_ids=room_ids,
                    explicit_room_id=explicit_room_id,
                )

                if name == "room":
                    new_events = yield filter_events_for_client(
                        self.store,
                        user.to_string(),
                        new_events,
                        is_peeking=is_peeking,
                    )
                elif name == "presence":
                    now = self.clock.time_msec()
                    new_events[:] = [
                        {
                            "type": "m.presence",
                            "content": format_user_presence_state(event, now),
                        }
                        for event in new_events
                    ]

                events.extend(new_events)
                end_token = end_token.copy_and_replace(keyname, new_key)

            defer.returnValue(EventStreamResult(events, (from_token, end_token)))

        user_id_for_stream = user.to_string()
        if is_peeking:
            # Internally, the notifier keeps an event stream per user_id.
            # This is used by both /sync and /events.
            # We want /events to be used for peeking independently of /sync,
            # without polluting its contents. So we invent an illegal user ID
            # (which thus cannot clash with any real users) for keying peeking
            # over /events.
            #
            # I am sorry for what I have done.
            user_id_for_stream = "_PEEKING_%s_%s" % (
                explicit_room_id, user_id_for_stream
            )

        result = yield self.wait_for_events(
            user_id_for_stream,
            timeout,
            check_for_updates,
            room_ids=room_ids,
            from_token=from_token,
        )

        defer.returnValue(result)

    @defer.inlineCallbacks
    def _get_room_ids(self, user, explicit_room_id):
        """Work out which rooms to poll for `user`.

        Returns:
            A deferred resolving to `(room_ids, is_joined)`.

        Raises:
            AuthError: 403 if `explicit_room_id` is given but the user has
                not joined it and it is not world readable.
        """
        joined_room_ids = yield self.store.get_rooms_for_user(user.to_string())
        if explicit_room_id:
            if explicit_room_id in joined_room_ids:
                defer.returnValue(([explicit_room_id], True))
            if (yield self._is_world_readable(explicit_room_id)):
                defer.returnValue(([explicit_room_id], False))
            raise AuthError(403, "Non-joined access not allowed")
        defer.returnValue((joined_room_ids, True))

    @defer.inlineCallbacks
    def _is_world_readable(self, room_id):
        """Resolve to True iff the room's current history visibility state
        has `history_visibility` set to "world_readable".
        """
        state = yield self.state_handler.get_current_state(
            room_id,
            EventTypes.RoomHistoryVisibility,
            "",
        )
        if state and "history_visibility" in state.content:
            defer.returnValue(state.content["history_visibility"] == "world_readable")
        else:
            defer.returnValue(False)

    @log_function
    def remove_expired_streams(self):
        """Remove user streams that have no listeners and were last notified
        more than `UNUSED_STREAM_EXPIRY_MS` ago.
        """
        time_now_ms = self.clock.time_msec()
        expired_streams = []
        expire_before_ts = time_now_ms - self.UNUSED_STREAM_EXPIRY_MS
        # Collect first, remove afterwards: `remove` mutates the dict we are
        # iterating over.
        for stream in self.user_to_user_stream.values():
            if stream.count_listeners():
                continue
            if stream.last_notified_ms < expire_before_ts:
                expired_streams.append(stream)

        for expired_stream in expired_streams:
            expired_stream.remove(self)

    @log_function
    def _register_with_keys(self, user_stream):
        """Index `user_stream` by its user id and by each of its rooms."""
        self.user_to_user_stream[user_stream.user_id] = user_stream

        for room in user_stream.rooms:
            s = self.room_to_user_streams.setdefault(room, set())
            s.add(user_stream)

    def _user_joined_room(self, user_id, room_id):
        """If `user_id` has a stream, start tracking `room_id` for it."""
        new_user_stream = self.user_to_user_stream.get(user_id)
        if new_user_stream is not None:
            room_streams = self.room_to_user_streams.setdefault(room_id, set())
            room_streams.add(new_user_stream)
            new_user_stream.rooms.add(room_id)

    def notify_replication(self):
        """Notify the any replication listeners that there's a new event"""
        with PreserveLoggingContext():
            # Swap in a fresh deferred before firing the old one.
            deferred = self.replication_deferred
            self.replication_deferred = ObservableDeferred(defer.Deferred())
            deferred.callback(None)

            # the callbacks may well outlast the current request, so we run
            # them in the sentinel logcontext.
            #
            # (ideally it would be up to the callbacks to know if they were
            # starting off background processes and drop the logcontext
            # accordingly, but that requires more changes)
            for cb in self.replication_callbacks:
                cb()

    @defer.inlineCallbacks
    def wait_for_replication(self, callback, timeout):
        """Wait for an event to happen.

        Args:
            callback: Gets called whenever an event happens. If this returns a
                truthy value then ``wait_for_replication`` returns, otherwise
                it waits for another event.
            timeout: How many milliseconds to wait for callback return a truthy
                value.

        Returns:
            A deferred that resolves with the value returned by the callback.
        """
        listener = _NotificationListener(None)

        end_time = self.clock.time_msec() + timeout

        while True:
            # Start observing *before* running the callback, so a
            # notification fired while the callback runs is not missed.
            listener.deferred = self.replication_deferred.observe()
            result = yield callback()
            if result:
                break

            now = self.clock.time_msec()
            if end_time <= now:
                break

            listener.deferred = timeout_deferred(
                listener.deferred,
                timeout=(end_time - now) / 1000.,
                reactor=self.hs.get_reactor(),
            )

            try:
                with PreserveLoggingContext():
                    yield listener.deferred
            except defer.TimeoutError:
                break
            except defer.CancelledError:
                break

        defer.returnValue(result)
Exemple #30
0
        def wrapped(*args, **kwargs):
            """Look up each entry of the list argument in the cache, call the
            underlying function only for the entries that are missing, and
            return a dict keyed by list entry.

            `self`, `obj`, `cache` and `num_args` come from the enclosing
            scope.

            Returns:
                The results dict directly when every entry was already
                available, otherwise a deferred resolving to that dict.
            """
            # If we're passed a cache_context then we'll want to call its
            # invalidate() whenever we are invalidated
            invalidate_callback = kwargs.pop("on_invalidate", None)

            arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
            keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
            list_args = arg_dict[self.list_name]

            # list entry -> result; filled in synchronously for completed
            # cache hits and by callbacks for everything else.
            results = {}

            def update_results_dict(res, arg):
                results[arg] = res

            # list of deferreds to wait for
            cached_defers = []

            # list entries with no cache entry at all
            missing = set()

            # If the cache takes a single arg then that is used as the key,
            # otherwise a tuple is used.
            if num_args == 1:

                def arg_to_cache_key(arg):
                    return arg

            else:
                keylist = list(keyargs)

                def arg_to_cache_key(arg):
                    # `keylist` is shared between calls; only the list
                    # position changes, and a fresh tuple is returned.
                    keylist[self.list_pos] = arg
                    return tuple(keylist)

            for arg in list_args:
                try:
                    res = cache.get(arg_to_cache_key(arg),
                                    callback=invalidate_callback)
                    if not isinstance(res, ObservableDeferred):
                        # a plain cached value
                        results[arg] = res
                    elif not res.has_succeeded():
                        # not (yet) successfully completed: wait on it
                        res = res.observe()
                        res.addCallback(update_results_dict, arg)
                        cached_defers.append(res)
                    else:
                        results[arg] = res.get_result()
                except KeyError:
                    missing.add(arg)

            if missing:
                # we need an observable deferred for each entry in the list,
                # which we put in the cache. Each deferred resolves with the
                # relevant result for that key.
                deferreds_map = {}
                for arg in missing:
                    deferred = defer.Deferred()
                    deferreds_map[arg] = deferred
                    key = arg_to_cache_key(arg)
                    observable = ObservableDeferred(deferred)
                    cache.set(key, observable, callback=invalidate_callback)

                def complete_all(res):
                    # the wrapped function has completed. It returns a
                    # a dict. We can now resolve the observable deferreds in
                    # the cache and update our own result map.
                    for e in missing:
                        # entries absent from the returned dict resolve to None
                        val = res.get(e, None)
                        deferreds_map[e].callback(val)
                        results[e] = val

                def errback(f):
                    # the wrapped function has failed. Invalidate any cache
                    # entries we're supposed to be populating, and fail
                    # their deferreds.
                    for e in missing:
                        key = arg_to_cache_key(e)
                        cache.invalidate(key)
                        deferreds_map[e].errback(f)

                    # return the failure, to propagate to our caller.
                    return f

                # call the underlying function with the list argument narrowed
                # down to just the missing entries
                args_to_call = dict(arg_dict)
                args_to_call[self.list_name] = list(missing)

                cached_defers.append(
                    defer.maybeDeferred(preserve_fn(self.function_to_call),
                                        **args_to_call).addCallbacks(
                                            complete_all, errback))

            if cached_defers:
                # wait for everything outstanding, then hand back `results`
                d = defer.gatherResults(cached_defers,
                                        consumeErrors=True).addCallbacks(
                                            lambda _: results,
                                            unwrapFirstError)
                return make_deferred_yieldable(d)
            else:
                return results
Exemple #31
0
    def set(
        self,
        key: KT,
        value: "defer.Deferred[VT]",
        callback: Optional[Callable[[], None]] = None,
    ) -> defer.Deferred:
        """Adds a new entry to the cache (or updates an existing one).

        The given `value` *must* be a Deferred.

        First any existing entry for the same key is invalidated. Then a new entry
        is added to the cache for the given key.

        Until the `value` completes, calls to `get()` for the key will also result in an
        incomplete Deferred, which will ultimately complete with the same result as
        `value`.

        If `value` completes successfully, subsequent calls to `get()` will then return
        a completed deferred with the same result. If it *fails*, the cache is
        invalidated and subsequent calls to `get()` will raise a KeyError.

        If another call to `set()` happens before `value` completes, then (a) any
        invalidation callbacks registered in the interim will be called, (b) any
        `get()`s in the interim will continue to complete with the result from the
        *original* `value`, (c) any future calls to `get()` will complete with the
        result from the *new* `value`.

        It is expected that `value` does *not* follow the synapse logcontext rules - ie,
        if it is incomplete, it runs its callbacks in the sentinel context.

        Args:
            key: Key to be set
            value: a deferred which will complete with a result to add to the cache
            callback: An optional callback to be called when the entry is invalidated

        Returns:
            `value` itself if it has already completed; otherwise a new observer
            of `value`.

        Raises:
            TypeError: if `value` is not a Deferred.
        """
        if not isinstance(value, defer.Deferred):
            raise TypeError("not a Deferred")

        callbacks = [callback] if callback else []
        self.check_thread()

        # drop (and invalidate) any in-flight entry for this key
        existing_entry = self._pending_deferred_cache.pop(key, None)
        if existing_entry:
            existing_entry.invalidate()

        # XXX: why don't we invalidate the entry in `self.cache` yet?

        # we can save a whole load of effort if the deferred is ready.
        if value.called:
            result = value.result
            # an already-failed deferred is not cached at all
            if not isinstance(result, failure.Failure):
                self.cache.set(key, cast(VT, result), callbacks)
            return value

        # otherwise, we'll add an entry to the _pending_deferred_cache for now,
        # and add callbacks to add it to the cache properly later.

        observable = ObservableDeferred(value, consumeErrors=True)
        # taken before the observer we return below, so that `cb`/`eb` run
        # ahead of the caller's callbacks
        observer = observable.observe()
        entry = CacheEntry(deferred=observable, callbacks=callbacks)

        self._pending_deferred_cache[key] = entry

        def compare_and_pop() -> bool:
            """Check if our entry is still the one in _pending_deferred_cache, and
            if so, pop it.

            Returns true if the entries matched.
            """
            existing_entry = self._pending_deferred_cache.pop(key, None)
            if existing_entry is entry:
                return True

            # oops, the _pending_deferred_cache has been updated since
            # we started our query, so we are out of date.
            #
            # Better put back whatever we took out. (We do it this way
            # round, rather than peeking into the _pending_deferred_cache
            # and then removing on a match, to make the common case faster)
            if existing_entry is not None:
                self._pending_deferred_cache[key] = existing_entry

            return False

        def cb(result: VT) -> None:
            # success path: promote the entry to the real cache if it is
            # still the current one
            if compare_and_pop():
                self.cache.set(key, result, entry.callbacks)
            else:
                # we're not going to put this entry into the cache, so need
                # to make sure that the invalidation callbacks are called.
                # That was probably done when _pending_deferred_cache was
                # updated, but it's possible that `set` was called without
                # `invalidate` being previously called, in which case it may
                # not have been. Either way, let's double-check now.
                entry.invalidate()

        def eb(_fail: Failure) -> None:
            # failure path: drop our pending entry (if still current) and
            # fire its invalidation callbacks
            compare_and_pop()
            entry.invalidate()

        # once the deferred completes, we can move the entry from the
        # _pending_deferred_cache to the real cache.
        #
        observer.addCallbacks(cb, eb)

        # we return a new Deferred which will be called before any subsequent observers.
        return observable.observe()